// su3_matrix: plain aggregate holding a 3x3 array `e` of std::complex<real>.
// `real` selects the floating-point type of the real and imaginary parts.
// The helper routines in this file read and write entries as e[i][j] with
// i, j in [0, 3).
// NOTE(review): presumably stores an SU(3) gauge link; nothing in this struct
// enforces unitarity or unit determinant -- confirm against callers.
25 template<
typename real>
struct su3_matrix { std::complex<real> e[3][3]; };
// su3_vector: plain aggregate holding a 3-element array `e` of
// std::complex<real>. `real` selects the floating-point type of the real and
// imaginary parts.
// NOTE(review): presumably the vector counterpart of su3_matrix (e.g. a color
// vector); no use of this type is visible in this excerpt -- confirm.
26 template<
typename real>
struct su3_vector { std::complex<real> e[3]; };
31 template<
typename su3_matrix,
typename Real>
37 for(i=0;i<3;i++)
for(j=0;j<3;j++){
38 b->
e[i][j] = s*a->
e[i][j];
44 template<
typename su3_matrix,
typename Real>
49 for(i=0;i<3;i++)
for(j=0;j<3;j++){
50 c->
e[i][j] = a->
e[i][j] + s*b->
e[i][j];
55 template <
typename su3_matrix>
60 typename std::remove_reference<decltype(a->
e[0][0])>::type x,y;
61 for(i=0;i<3;i++)
for(j=0;j<3;j++){
64 y = a->
e[i][k] *
conj(b->
e[j][k]);
71 template <
typename su3_matrix>
76 typename std::remove_reference<decltype(a->
e[0][0])>::type x,y;
77 for(i=0;i<3;i++)
for(j=0;j<3;j++){
80 y = a->
e[i][k] * b->
e[k][j];
87 template<
typename su3_matrix>
92 typename std::remove_reference<decltype(a->
e[0][0])>::type x,y;
93 for(i=0;i<3;i++)
for(j=0;j<3;j++){
96 y =
conj(a->
e[k][i]) * b->
e[k][j];
107 template<
typename su3_matrix>
112 for(i=0;i<3;i++)
for(j=0;j<3;j++){
113 c->
e[i][j] = a->
e[i][j] + b->
e[i][j];
119 template<
typename su3_matrix,
typename Real>
151 memset(dx, 0,
sizeof(dx));
156 B = mulink + nbr_idx;
158 B = mulink + nbr_idx;
161 memset(dx, 0,
sizeof(dx));
188 memset(dx, 0,
sizeof(dx));
191 if (nbr_idx >= V || nbr_idx <0){
192 fprintf(stderr,
"ERROR: invliad nbr_idx(%d), line=%d\n", nbr_idx, __LINE__);
199 B = mulink + nbr_idx;
201 B = mulink + nbr_idx;
204 memset(dx, 0,
sizeof(dx));
235 template <
typename su3_matrix,
typename Float>
241 fprintf(stderr,
"Error: malloc failed for staple in function %s\n", __FUNCTION__);
246 if(tempmat1 == NULL){
247 fprintf(stderr,
"ERROR: malloc failed for tempmat1 in function %s\n", __FUNCTION__);
252 Float one_link = (act_path_coeff[0] - 6.0*act_path_coeff[5]);
255 for (
int dir=
XUP; dir<=
TUP; dir++){
258 for(
int i=0;i <
V;i ++){
267 for (
int dir=
XUP; dir<=
TUP; dir++){
268 for(
int nu=
XUP; nu<=
TUP; nu++){
277 for(
int rho=
XUP; rho<=
TUP; rho++) {
278 if((rho!=dir)&&(rho!=nu)){
281 for(
int sig=
XUP; sig<=
TUP; sig++){
282 if((sig!=dir)&&(sig!=nu)&&(sig!=rho)){
303 template<
typename su3_matrix,
typename Float>
305 Float* act_path_coeff)
309 for(
int dir=
XUP; dir<=
TUP; ++dir){
310 int dx[4] = {0,0,0,0};
311 for(
int i=0; i<
V; ++i){
326 template<
typename su3_matrix,
typename Float>
330 for(
int dir=0; dir<4; ++dir) E[dir] =
Z[dir]+4;
333 const int extended_volume = E[3]*E[2]*E[1]*E[0];
336 for(
int t=0; t<
Z[3]; ++t){
337 for(
int z=0; z<Z[2]; ++z){
338 for(
int y=0; y<Z[1]; ++y){
339 for(
int x=0; x<Z[0]; ++x){
340 const int oddBit = (x+y+z+t)&1;
341 int little_index = ((((t*Z[2] + z)*Z[1] + y)*Z[0] + x)/2) + oddBit*
Vh;
342 int large_index = (((((t+2)*E[2] + (z+2))*E[1] + (y+2))*E[0] + x+2)/2) + oddBit*(extended_volume/2);
345 for(
int dir=
XUP; dir<=
TUP; ++dir){
346 int dx[4] = {0,0,0,0};
377 fprintf(stderr,
"ERROR: unsupported precision(%d)\n", prec);
410 fprintf(stderr,
"ERROR: unsupported precision(%d)\n", prec);
422 template<
typename su3_matrix,
typename Real>
424 llfat_compute_gen_staple_field_mg(
su3_matrix *staple,
int mu,
int nu,
476 int x1h = sid - za*X1h;
481 int x1odd = (x2 + x3 + x4 + oddBit) & 1;
482 int x1 = 2*x1h + x1odd;
483 int x[4] = {x1,x2,x3,x4};
485 (x4*X3X2+x3*X2+x2)/2,
486 (x4*X3X1+x3*X1+x1)/2,
487 (x4*X2X1+x2*X1+x1)/2,
494 memset(dx, 0,
sizeof(dx));
500 if (x[nu] + dx[nu] >=
Z[nu]){
501 B = ghost_mulink[nu] +
Vs[nu] + (1-oddBit)*
Vsh[nu] + space_con[nu];
504 B = mulink + nbr_idx;
507 if(x[nu]+dx[nu] >=
Z[nu]){
508 B = ghost_sitelink[nu] + 4*
Vs[nu] + mu*
Vs[nu] + (1-oddBit)*
Vsh[nu] + space_con[nu];
511 B = sitelink[
mu] + nbr_idx;
518 memset(dx, 0,
sizeof(dx));
520 if(x[mu] + dx[mu] >=
Z[mu]){
521 C = ghost_sitelink[
mu] + 4*
Vs[
mu] + nu*
Vs[
mu] + (1-oddBit)*
Vsh[mu] + space_con[mu];
524 C = sitelink[nu] + nbr_idx;
557 int x1h = sid - za*X1h;
562 int x1odd = (x2 + x3 + x4 + oddBit) & 1;
563 int x1 = 2*x1h + x1odd;
564 int x[4] = {x1,x2,x3,x4};
566 (x4*X3X2+x3*X2+x2)/2,
567 (x4*X3X1+x3*X1+x1)/2,
568 (x4*X2X1+x2*X1+x1)/2,
578 memset(dx, 0,
sizeof(dx));
582 if(x[nu] + dx[nu] < 0){
583 A = ghost_sitelink[nu] + nu*
Vs[nu] + (1-oddBit)*
Vsh[nu] + space_con[nu];
586 A = sitelink[nu] + nbr_idx;
592 if (x[nu] + dx[nu] < 0){
593 B = ghost_mulink[nu] + (1-oddBit)*
Vsh[nu] + space_con[nu];
595 B = mulink + nbr_idx;
598 if(x[nu] + dx[nu] < 0){
599 B = ghost_sitelink[nu] + mu*
Vs[nu] + (1-oddBit)*
Vsh[nu] + space_con[nu];
602 B = sitelink[
mu] + nbr_idx;
609 memset(dx, 0,
sizeof(dx));
615 int new_x1, new_x2, new_x3, new_x4;
616 new_x1 = (x[0] + dx[0] +
Z[0])%
Z[0];
617 new_x2 = (x[1] + dx[1] +
Z[1])%
Z[1];
618 new_x3 = (x[2] + dx[2] +
Z[2])%
Z[2];
619 new_x4 = (x[3] + dx[3] +
Z[3])%
Z[3];
620 int new_x[4] = {new_x1, new_x2, new_x3, new_x4};
621 space_con[0] = (new_x4*X3X2 + new_x3*X2 + new_x2)/2;
622 space_con[1] = (new_x4*X3X1 + new_x3*X1 + new_x1)/2;
623 space_con[2] = (new_x4*X2X1 + new_x2*X1 + new_x1)/2;
624 space_con[3] = (new_x3*X2X1 + new_x2*X1 + new_x1)/2;
626 if( (x[nu] + dx[nu]) < 0 && (x[
mu] + dx[
mu] >=
Z[
mu])){
630 for(dir1=0; dir1 < 4; dir1 ++){
631 if(dir1 != nu && dir1 != mu){
635 for(dir2=0; dir2 < 4; dir2 ++){
636 if(dir2 != nu && dir2 != mu && dir2 != dir1){
640 C = ghost_sitelink_diag[nu*4+
mu] + oddBit*
Z[dir1]*
Z[dir2]/2 + (new_x[dir2]*
Z[dir1]+new_x[dir1])/2;
641 }
else if (x[nu] + dx[nu] < 0){
642 C = ghost_sitelink[nu] + nu*
Vs[nu] + oddBit*
Vsh[nu]+ space_con[nu];
643 }
else if (x[mu] + dx[mu] >=
Z[mu]){
646 C = sitelink[nu] + nbr_idx;
663 template <
typename su3_matrix,
typename Float>
665 su3_matrix** ghost_sitelink_diag, Float* act_path_coeff)
668 if (
sizeof(Float) == 4){
676 fprintf(stderr,
"Error: malloc failed for staple in function %s\n", __FUNCTION__);
684 for(
int i=0;i < 4;i++){
686 if (ghost_staple[i] == NULL){
687 fprintf(stderr,
"Error: malloc failed for ghost staple in function %s\n", __FUNCTION__);
692 if (ghost_staple1[i] == NULL){
693 fprintf(stderr,
"Error: malloc failed for ghost staple1 in function %s\n", __FUNCTION__);
699 if(tempmat1 == NULL){
700 fprintf(stderr,
"ERROR: malloc failed for tempmat1 in function %s\n", __FUNCTION__);
705 Float one_link = (act_path_coeff[0] - 6.0*act_path_coeff[5]);
708 for (
int dir=
XUP; dir<=
TUP; dir++){
711 for(
int i=0;i <
V;i ++){
717 for (
int dir=
XUP; dir<=
TUP; dir++){
718 for(
int nu=
XUP; nu<=
TUP; nu++){
720 llfat_compute_gen_staple_field_mg(staple,dir,nu,
722 sitelink, ghost_sitelink, ghost_sitelink_diag,
723 fatlink, act_path_coeff[2], 0);
729 llfat_compute_gen_staple_field_mg((
su3_matrix*)NULL,dir,nu,
731 sitelink, ghost_sitelink, ghost_sitelink_diag,
734 for(
int rho=
XUP; rho<=
TUP; rho++) {
735 if((rho!=dir)&&(rho!=nu)){
736 llfat_compute_gen_staple_field_mg( tempmat1, dir, rho,
738 sitelink, ghost_sitelink, ghost_sitelink_diag,
739 fatlink, act_path_coeff[3], 1);
744 for(
int sig=
XUP; sig<=
TUP; sig++){
745 if((sig!=dir)&&(sig!=nu)&&(sig!=rho)){
747 llfat_compute_gen_staple_field_mg((
su3_matrix*)NULL,dir,sig,
748 tempmat1, ghost_staple1,
749 sitelink, ghost_sitelink, ghost_sitelink_diag,
750 fatlink, act_path_coeff[4], 1);
765 for(
int i=0;i < 4;i++){
766 free(ghost_staple[i]);
767 free(ghost_staple1[i]);
802 fprintf(stderr,
"ERROR: unsupported precision(%d)\n",
prec);
819 double* dst = (
double*)y;
820 double* src = (
double*)x;
821 for (
int i = 0; i <
size; i++)
826 float* dst = (
float*)y;
827 float* src = (
float*)x;
828 for (
int i = 0; i <
size; i++)
838 double* dst = (
double*)y;
839 double* src = (
double*)x;
840 for (
int i = 0; i <
size; i++)
845 float* dst = (
float*)y;
846 float* src = (
float*)x;
847 for (
int i = 0; i <
size; i++)
855 template <
typename Out,
typename In>
857 for (
int i = 0; i <
V; i++) {
858 for (
int dir = 0; dir < 4; dir++) {
859 for (
int j = 0; j < siteSize; j++) {
860 milc_out[(i*4+dir)*siteSize+j] = static_cast<Out>(qdp_in[dir][i*siteSize+j]);
869 reorderQDPtoMILC<float,float>((
float*)milc_out, (
float**)qdp_in,
V, siteSize);
871 reorderQDPtoMILC<float,double>((
float*)milc_out, (
double**)qdp_in,
V, siteSize);
875 reorderQDPtoMILC<double,float>((
double*)milc_out, (
float**)qdp_in,
V, siteSize);
877 reorderQDPtoMILC<double,double>((
double*)milc_out, (
double**)qdp_in,
V, siteSize);
882 template <
typename Out,
typename In>
884 for (
int i = 0; i <
V; i++) {
885 for (
int dir = 0; dir < 4; dir++) {
886 for (
int j = 0; j < siteSize; j++) {
887 qdp_out[dir][i*siteSize+j] =
static_cast<Out
>(milc_in[(i*4+dir)*siteSize+j]);
896 reorderMILCtoQDP<float,float>((
float**)qdp_out, (
float*)milc_in,
V, siteSize);
898 reorderMILCtoQDP<float,double>((
float**)qdp_out, (
double*)milc_in,
V, siteSize);
902 reorderMILCtoQDP<double,float>((
double**)qdp_out, (
float*)milc_in,
V, siteSize);
904 reorderMILCtoQDP<double,double>((
double**)qdp_out, (
double*)milc_in,
V, siteSize);
913 void** fatlink_eps,
void** longlink_eps,
914 void** sitelink,
void* qudaGaugeParamPtr,
915 double** act_path_coeffs,
double eps_naik)
930 const int n_naiks = (eps_naik == 0.0 ? 1 : 2);
936 void* sitelink_ex[4];
941 void* ghost_sitelink[4];
942 void* ghost_sitelink_diag[16];
950 for(
int i=0; i <
V_ex; i++){
959 int x1h = sid - za*
E1h;
964 int x1odd = (x2 + x3 + x4 + oddBit) & 1;
965 int x1 = 2*x1h + x1odd;
968 if( x1< 2 || x1 >= X1 +2
969 || x2< 2 || x2 >= X2 +2
970 || x3< 2 || x3 >= X3 +2
971 || x4< 2 || x4 >= X4 +2){
979 x1 = (x1 - 2 +
X1) % X1;
980 x2 = (x2 - 2 +
X2) % X2;
981 x3 = (x3 - 2 +
X3) % X3;
982 x4 = (x4 - 2 +
X4) % X4;
984 int idx = (x4*X3*X2*X1+x3*X2*X1+x2*X1+x1)>>1;
988 for(
int dir= 0; dir < 4; dir++){
989 char* src = (
char*)sitelink[dir];
990 char* dst = (
char*)sitelink_ex[dir];
1001 void* w_reflink_ex[4];
1002 for(
int i=0;i < 4;i++){
1009 void* ghost_wlink[4];
1010 void* ghost_wlink_diag[16];
1027 for (
int i=0; i < 6;i++) coeff_sp[i] = coeff_dp[i] = act_path_coeffs[0][i];
1044 for(
int nu=0;nu < 4;nu++){
1045 for(
int mu=0;
mu < 4;
mu++){
1047 ghost_sitelink_diag[nu*4+
mu] = NULL;
1051 for(dir1= 0; dir1 < 4; dir1++){
1052 if(dir1 !=nu && dir1 !=
mu){
1056 for(dir2=0; dir2 < 4; dir2++){
1057 if(dir2 != nu && dir2 !=
mu && dir2 != dir1){
1067 exchange_cpu_sitelink(gParam.
x, sitelink, ghost_sitelink, ghost_sitelink_diag, prec, &qudaGaugeParam, optflag);
1068 llfat_reference_mg(v_reflink, sitelink, ghost_sitelink, ghost_sitelink_diag, prec, coeff);
1105 gParam.
gauge = v_sitelink;
1148 for(
int i=0; i <
V_ex; i++) {
1157 int x1h = sid - za*
E1h;
1159 int x2 = za - zb*
E2;
1161 int x3 = zb - x4*
E3;
1162 int x1odd = (x2 + x3 + x4 + oddBit) & 1;
1163 int x1 = 2*x1h + x1odd;
1166 if( x1< 2 || x1 >= X1 +2
1167 || x2< 2 || x2 >= X2 +2
1168 || x3< 2 || x3 >= X3 +2
1169 || x4< 2 || x4 >= X4 +2){
1177 x1 = (x1 - 2 +
X1) % X1;
1178 x2 = (x2 - 2 +
X2) % X2;
1179 x3 = (x3 - 2 +
X3) % X3;
1180 x4 = (x4 - 2 +
X4) % X4;
1182 int idx = (x4*X3*X2*X1+x3*X2*X1+x2*X1+x1)>>1;
1186 for(
int dir= 0; dir < 4; dir++){
1187 char* src = (
char*)w_reflink[dir];
1188 char* dst = (
char*)w_reflink_ex[dir];
1189 memcpy(dst+i*gaugeSiteSize*gSize, src+idx*gaugeSiteSize*gSize, gaugeSiteSize*gSize);
1201 for (
int i=0; i < 4; i++) ghost_wlink[i] =
safe_malloc(8*Vs[i]*gaugeSiteSize*gSize);
1209 for(
int nu=0;nu < 4;nu++){
1210 for(
int mu=0;
mu < 4;
mu++){
1212 ghost_wlink_diag[nu*4+
mu] = NULL;
1216 for(dir1= 0; dir1 < 4; dir1++){
1217 if(dir1 !=nu && dir1 !=
mu){
1221 for(dir2=0; dir2 < 4; dir2++){
1222 if(dir2 != nu && dir2 !=
mu && dir2 != dir1){
1226 ghost_wlink_diag[nu*4+
mu] =
safe_malloc(
Z[dir1]*
Z[dir2]*gaugeSiteSize*gSize);
1227 memset(ghost_wlink_diag[nu*4+
mu], 0,
Z[dir1]*
Z[dir2]*gaugeSiteSize*gSize);
1240 for (
int i=0; i < 6;i++) coeff_sp[i] = coeff_dp[i] = act_path_coeffs[2][i];
1249 int R[4] = {2,2,2,2};
1259 for (
int i = 0; i < 4; i++) {
1260 cpu_axy(prec, eps_naik, fatlink[i], fatlink_eps[i],
V*gaugeSiteSize);
1261 cpu_axy(prec, eps_naik, longlink[i], longlink_eps[i],
V*gaugeSiteSize);
1269 for (
int i=0; i < 6;i++) coeff_sp[i] = coeff_dp[i] = act_path_coeffs[1][i];
1281 int R[4] = {2,2,2,2};
1292 for (
int i = 0; i < 4; i++) {
1293 cpu_xpy(prec, fatlink[i], fatlink_eps[i],
V*gaugeSiteSize);
1294 cpu_xpy(prec, longlink[i], longlink_eps[i],
V*gaugeSiteSize);
1302 for(
int i=0; i < 4; i++){
1311 for(
int i=0; i<4; i++){
1314 for(
int j=0;j <4; j++){
static QudaGaugeParam qudaGaugeParam
QudaGhostExchange ghostExchange
#define pinned_malloc(size)
enum QudaPrecision_s QudaPrecision
void llfat_mult_su3_nn(su3_matrix *a, su3_matrix *b, su3_matrix *c)
int neighborIndexFullLattice(int i, int dx4, int dx3, int dx2, int dx1)
void exchange_cpu_sitelink(int *X, void **sitelink, void **ghost_sitelink, void **ghost_sitelink_diag, QudaPrecision gPrecision, QudaGaugeParam *param, int optflag)
void llfat_add_su3_matrix(su3_matrix *a, su3_matrix *b, su3_matrix *c)
void llfat_compute_gen_staple_field(su3_matrix *staple, int mu, int nu, su3_matrix *mulink, su3_matrix **sitelink, void **fatlink, Real coef, int use_staple)
void llfat_scalar_mult_su3_matrix(su3_matrix *a, Real s, su3_matrix *b)
void exchange_cpu_staple(int *X, void *staple, void **ghost_staple, QudaPrecision gPrecision)
void reorderQDPtoMILC(Out *milc_out, In **qdp_in, int V, int siteSize)
void cpu_xpy(QudaPrecision prec, void *x, void *y, int size)
void llfat_scalar_mult_add_su3_matrix(su3_matrix *a, su3_matrix *b, Real s, su3_matrix *c)
QudaGaugeFieldOrder order
void llfat_mult_su3_na(su3_matrix *a, su3_matrix *b, su3_matrix *c)
void computeHISQLinksCPU(void **fatlink, void **longlink, void **fatlink_eps, void **longlink_eps, void **sitelink, void *qudaGaugeParamPtr, double **act_path_coeffs, double eps_naik)
#define safe_malloc(size)
void reorderMILCtoQDP(Out **qdp_out, In *milc_in, int V, int siteSize)
std::complex< real > e[3][3]
void * memset(void *s, int c, size_t n)
void computeLongLinkCPU(void **longlink, su3_matrix **sitelink, Float *act_path_coeff)
int neighborIndexFullLattice_mg(int i, int dx4, int dx3, int dx2, int dx1)
void llfat_cpu(void **fatlink, su3_matrix **sitelink, Float *act_path_coeff)
void unitarizeLinksCPU(cpuGaugeField &outfield, const cpuGaugeField &infield)
Main header file for the QUDA library.
void llfat_reference_mg(void **fatlink, void **sitelink, void **ghost_sitelink, void **ghost_sitelink_diag, QudaPrecision prec, void *act_path_coeff)
void cpu_axy(QudaPrecision prec, double a, void *x, void *y, int size)
void llfat_mult_su3_an(su3_matrix *a, su3_matrix *b, su3_matrix *c)
__host__ __device__ ValueType conj(ValueType x)
void exchange_cpu_sitelink_ex(int *X, int *R, void **sitelink, QudaGaugeFieldOrder cpu_order, QudaPrecision gPrecision, int optflag, int geometry)
void llfat_reference(void **fatlink, void **sitelink, QudaPrecision prec, void *act_path_coeff)