// CADD: complex addition, c = a + b, on any structs exposing .real/.imag.
// Each output component depends only on the same component of the inputs,
// so c may be the same object as a or b.
#define CADD(a,b,c) {                                                   \
    (c).real = (a).real + (b).real;                                     \
    (c).imag = (a).imag + (b).imag;                                     \
  }
// CMUL: complex multiplication, c = a * b, on structs exposing .real/.imag.
// Both result components are computed into temporaries (typed like c's
// components, via the GNU __typeof__ extension the file already relies on)
// before c is written. This keeps the macro correct when c is the same object
// as a or b; the previous form overwrote c.real before the imaginary part
// read it. Arguments are still evaluated more than once, so do not pass
// expressions with side effects.
#define CMUL(a,b,c) {                                                         \
    const __typeof__((c).real) cmul_re_ = (a).real*(b).real - (a).imag*(b).imag; \
    const __typeof__((c).imag) cmul_im_ = (a).real*(b).imag + (a).imag*(b).real; \
    (c).real = cmul_re_;                                                      \
    (c).imag = cmul_im_;                                                      \
  }
// CSUM: complex accumulate, a += b, component-wise on .real/.imag.
#define CSUM(a,b) {                                                     \
    (a).real += (b).real;                                               \
    (a).imag += (b).imag;                                               \
  }
// CMULJ_: complex multiply with the FIRST operand conjugated,
// c = conj(a) * b, on structs exposing .real/.imag.
// Both result components are computed into temporaries before c is written,
// so c may be the same object as a or b. The previous form overwrote c.real
// before the imaginary part read it. Arguments are evaluated more than once.
#define CMULJ_(a,b,c) {                                                       \
    const __typeof__((c).real) cmulj_re_ = (a).real*(b).real + (a).imag*(b).imag; \
    const __typeof__((c).imag) cmulj_im_ = (a).real*(b).imag - (a).imag*(b).real; \
    (c).real = cmulj_re_;                                                     \
    (c).imag = cmulj_im_;                                                     \
  }
// CMUL_J: complex multiply with the SECOND operand conjugated,
// c = a * conj(b), on structs exposing .real/.imag.
// Both result components are computed into temporaries before c is written,
// so c may be the same object as a or b. The previous form overwrote c.real
// before the imaginary part read it. Arguments are evaluated more than once.
#define CMUL_J(a,b,c) {                                                       \
    const __typeof__((c).real) cmul_j_re_ = (a).real*(b).real + (a).imag*(b).imag; \
    const __typeof__((c).imag) cmul_j_im_ = (a).imag*(b).real - (a).real*(b).imag; \
    (c).real = cmul_j_re_;                                                    \
    (c).imag = cmul_j_im_;                                                    \
  }
// Fragment of llfat_scalar_mult_su3_matrix (a, s, b); the signature and
// closing lines are not visible in this listing.
// Computes b = s * a: every entry of the 3x3 complex matrix a has its real
// and imaginary parts scaled by the real scalar s and stored into b.
53 template<
typename su3_matrix,
typename Real>
59 for(i=0;i<3;i++)
for(j=0;j<3;j++){
60 b->e[i][j].real = s*a->e[i][j].real;
61 b->e[i][j].imag = s*a->e[i][j].imag;
// Fragment of llfat_scalar_mult_add_su3_matrix (a, b, s, c); the signature and
// closing lines are not visible in this listing.
// Computes c = a + s * b entry-wise over the 3x3 complex matrices, with s a
// real scalar applied to both the real and imaginary parts of b.
67 template<
typename su3_matrix,
typename Real>
72 for(i=0;i<3;i++)
for(j=0;j<3;j++){
73 c->e[i][j].real = a->e[i][j].real + s*b->e[i][j].real;
74 c->e[i][j].imag = a->e[i][j].imag + s*b->e[i][j].imag;
// Fragment of a 3x3 complex matrix product (presumably llfat_mult_su3_na;
// the name/signature lines are not visible).
// The visible term is y = a[i][k] * conj(b[j][k]) (CMUL_J conjugates its
// second operand). Summed over k, presumably accumulated into x in lines not
// shown, this forms c = a * b^dagger.
// x and y take the element type of a via the GNU `typeof` extension.
79 template <
typename su3_matrix>
84 typeof(a->e[0][0])
x,
y;
85 for(i=0;i<3;i++)
for(j=0;j<3;j++){
88 CMUL_J( a->e[i][k] , b->e[j][k] , y );
// Fragment of a 3x3 complex matrix product (presumably llfat_mult_su3_nn;
// the name/signature lines are not visible).
// The visible term is y = a[i][k] * b[k][j] (plain CMUL). Summed over k,
// presumably accumulated into x in lines not shown, this forms c = a * b.
// x and y take the element type of a via the GNU `typeof` extension.
95 template <
typename su3_matrix>
100 typeof(a->e[0][0])
x,
y;
101 for(i=0;i<3;i++)
for(j=0;j<3;j++){
104 CMUL( a->e[i][k] , b->e[k][j] , y );
// Fragment of a 3x3 complex matrix product (presumably llfat_mult_su3_an;
// the name/signature lines are not visible).
// The visible term is y = conj(a[k][i]) * b[k][j] (CMULJ_ conjugates its
// first operand, and a is indexed transposed). Summed over k, presumably
// accumulated into x in lines not shown, this forms c = a^dagger * b.
// x and y take the element type of a via the GNU `typeof` extension.
111 template<
typename su3_matrix>
116 typeof(a->e[0][0])
x,
y;
117 for(i=0;i<3;i++)
for(j=0;j<3;j++){
120 CMULJ_( a->e[k][i] , b->e[k][j], y );
// Fragment of llfat_add_su3_matrix (a, b, c); the signature and closing lines
// are not visible in this listing.
// Computes c = a + b entry-wise over the 3x3 complex matrices using CADD.
131 template<
typename su3_matrix>
136 for(i=0;i<3;i++)
for(j=0;j<3;j++){
137 CADD( a->e[i][j], b->e[i][j], c->e[i][j] );
// Fragment of the single-domain staple builder llfat_compute_gen_staple_field
// (staple, mu, nu, mulink, sitelink, fatlink, coef, use_staple). Much of the
// body is missing from this listing.
// From the visible lines: for each site i it addresses fatlink[mu] at i, and
// it reads sitelink[nu] links and mulink at neighbour indices nbr_idx. How
// coef and use_staple are applied is not visible here.
143 template<
typename su3_matrix,
typename Real>
146 su3_matrix* mulink, su3_matrix** sitelink,
void**
fatlink, Real coef,
// Local matrices, presumably scratch space for intermediate products.
149 su3_matrix tmat1,tmat2;
// First visible pass: fat1 -> fatlink[mu] at site i, A = sitelink[nu] at site i.
172 fat1 = ((su3_matrix*)fatlink[mu]) + i;
173 su3_matrix* A = sitelink[nu] + i;
// dx is a displacement vector, cleared before each neighbour lookup (the
// lookup that produces nbr_idx is not visible).
175 memset(dx, 0,
sizeof(dx));
// Two alternative branches both take mulink at the neighbour. The condition
// selecting between them (and any other differences) is not visible here.
180 B = mulink + nbr_idx;
182 B = mulink + nbr_idx;
185 memset(dx, 0,
sizeof(dx));
188 su3_matrix* C = sitelink[nu] + nbr_idx;
// Second visible pass: again targets fatlink[mu] at i, but A below is taken
// at a neighbour site rather than at i.
211 fat1 = ((su3_matrix*)fatlink[mu]) + i;
212 memset(dx, 0,
sizeof(dx));
// Sanity check: the neighbour index must lie in [0, V).
// NOTE(review): the runtime message misspells "invalid" as "invliad"; it is
// left unchanged here because it is program output.
215 if (nbr_idx >= V || nbr_idx <0){
216 fprintf(stderr,
"ERROR: invliad nbr_idx(%d), line=%d\n", nbr_idx, __LINE__);
219 su3_matrix* A = sitelink[nu] + nbr_idx;
223 B = mulink + nbr_idx;
225 B = mulink + nbr_idx;
228 memset(dx, 0,
sizeof(dx));
231 su3_matrix* C = sitelink[nu] + nbr_idx;
// Fragment of the single-domain fat-link driver (presumably llfat_cpu(fatlink,
// sitelink, act_path_coeff)). The signature, loop bodies and cleanup are not
// visible in this listing.
259 template <
typename su3_matrix,
typename Float>
// Scratch field of V matrices for staples; allocation failure is reported on
// stderr (the exit/cleanup path is not visible).
263 su3_matrix* staple = (su3_matrix *)malloc(
V*
sizeof(su3_matrix));
265 fprintf(stderr,
"Error: malloc failed for staple in function %s\n", __FUNCTION__);
// Second scratch field of V matrices.
269 su3_matrix* tempmat1 = (su3_matrix *)malloc(
V*
sizeof(su3_matrix));
270 if(tempmat1 == NULL){
271 fprintf(stderr,
"ERROR: malloc failed for tempmat1 in function %s\n", __FUNCTION__);
// Coefficient of the plain one-link term: coeff[0] - 6*coeff[5].
// NOTE(review): the -6*coeff[5] offset is presumably the Lepage-term fix-up
// used in MILC-style asqtad code; confirm against the path definitions.
276 Float one_link = (act_path_coeff[0] - 6.0*act_path_coeff[5]);
// For every direction and site, fat1 -> fatlink[dir] at site i (body not
// visible; presumably initialised from the one-link term).
279 for (
int dir=
XUP; dir<=
TUP; dir++){
282 for(
int i=0;i <
V;i ++){
283 su3_matrix* fat1 = ((su3_matrix*)fatlink[dir]) + i;
// Staple accumulation over direction pairs (dir, nu), with a further inner
// loop over rho distinct from both dir and nu (bodies not visible).
291 for (
int dir=
XUP; dir<=
TUP; dir++){
292 for(
int nu=
XUP; nu<=
TUP; nu++){
301 for(
int rho=
XUP; rho<=
TUP; rho++) {
302 if((rho!=dir)&&(rho!=nu)){
// Fragment of computeLongLinkCPU(longlink, sitelink, act_path_coeff); most of
// the body is not visible in this listing.
// For each direction dir and each of the V sites, llink -> longlink[dir] at
// site i. dx is a per-direction displacement vector initialised to zero (its
// use is not visible; presumably neighbour lookups along dir).
327 template<
typename su3_matrix,
typename Float>
329 Float* act_path_coeff)
333 for(
int dir=
XUP; dir<=
TUP; ++dir){
334 int dx[4] = {0,0,0,0};
335 for(
int i=0; i<
V; ++i){
337 su3_matrix* llink = ((su3_matrix*)longlink[dir]) + i;
// Fragment of a long-link routine working on an extended (halo-padded) field;
// its name and signature are not visible in this listing.
350 template<
typename su3_matrix,
typename Float>
// Extended dimensions: 4 extra sites per direction. The +2 offsets in
// large_index below show two halo sites on each side.
354 for(
int dir=0; dir<4; ++dir) E[dir] =
Z[dir]+4;
357 const int extended_volume = E[3]*E[2]*E[1]*E[0];
// Loop over every local site (x fastest, t slowest).
360 for(
int t=0; t<
Z[3]; ++t){
361 for(
int z=0; z<Z[2]; ++z){
362 for(
int y=0;
y<Z[1]; ++
y){
363 for(
int x=0;
x<Z[0]; ++
x){
// Checkerboarded local index: lexicographic site / 2, offset by Vh for the
// odd half. oddBit is not computed in the visible lines (presumably the
// parity of x+y+z+t).
365 int little_index = ((((t*Z[2] + z)*Z[1] +
y)*Z[0] +
x)/2) + oddBit*
Vh;
// Same site in the extended field: coordinates shifted by +2, odd half
// offset by extended_volume/2.
366 int large_index = (((((t+2)*E[2] + (z+2))*E[1] + (
y+2))*E[0] +
x+2)/2) + oddBit*(extended_volume/2);
369 for(
int dir=
XUP; dir<=
TUP; ++dir){
370 int dx[4] = {0,0,0,0};
// Output long link for this direction at the local (non-extended) index.
371 su3_matrix* llink = ((su3_matrix*)longlink[dir]) + little_index;
// Fallback of a precision dispatch: a precision value not handled by the
// preceding (not visible) branches is reported on stderr.
401 fprintf(stderr,
"ERROR: unsupported precision(%d)\n", prec);
// Fallback of a second precision dispatch: a precision value not handled by
// the preceding (not visible) branches is reported on stderr.
434 fprintf(stderr,
"ERROR: unsupported precision(%d)\n", prec);
// Fragment of llfat_compute_gen_staple_field_mg, the multi-domain staple
// builder. Many body lines are not visible in this listing.
// Whenever a displaced coordinate x[d] + dx[d] leaves the local sublattice
// [0, Z[d]), the link is taken from a ghost buffer instead of the local
// field. Otherwise it is read at nbr_idx as in the single-domain version.
446 template<
typename su3_matrix,
typename Real>
448 llfat_compute_gen_staple_field_mg(su3_matrix *staple,
int mu,
int nu,
449 su3_matrix* mulink, su3_matrix** ghost_mulink,
450 su3_matrix** sitelink, su3_matrix** ghost_sitelink, su3_matrix** ghost_sitelink_diag,
// Local matrices, presumably scratch space for intermediate products.
454 su3_matrix tmat1,tmat2;
// Initialisers of space_con[d]: the checkerboarded (halved) index of the site
// within the 3-D face orthogonal to direction d (each entry omits x_{d+1}).
// Used to address the ghost buffers.
509 (x4*X3X2+x3*X2+
x2)/2,
510 (x4*X3X1+x3*X1+x1)/2,
511 (x4*X2X1+x2*X1+
x1)/2,
// First visible pass: fat1 -> fatlink[mu] at i, A = sitelink[nu] at i.
515 fat1 = ((su3_matrix*)fatlink[mu]) + i;
516 su3_matrix* A = sitelink[nu] + i;
518 memset(dx, 0,
sizeof(dx));
// Forward crossing in nu: take mulink from ghost_mulink[nu]. The neighbour is
// one step away and so has opposite parity, hence (1-oddBit). The visible
// offsets (forward face after Vs[nu] entries; backward face at offset 0
// further below) are consistent with a [backward | forward] layout of
// 2*Vs[nu] matrices. Confirm against the ghost exchange code.
524 if (x[nu] + dx[nu] >= Z[nu]){
525 B = ghost_mulink[nu] +
Vs[nu] + (1-
oddBit)*
Vsh[nu] + space_con[nu];
528 B = mulink + nbr_idx;
// Forward crossing in nu for a mu-link: ghost_sitelink[nu] offset
// 4*Vs[nu] + mu*Vs[nu]. Presumably the backward face holds 4 direction slabs
// of Vs[nu] each, followed by the forward face; confirm.
531 if(x[nu]+dx[nu] >= Z[nu]){
532 B = ghost_sitelink[nu] + 4*
Vs[nu] + mu*
Vs[nu] + (1-
oddBit)*
Vsh[nu] + space_con[nu];
535 B = sitelink[
mu] + nbr_idx;
542 memset(dx, 0,
sizeof(dx));
// Forward crossing in mu (ghost branch body not visible).
544 if(x[mu] + dx[mu] >= Z[mu]){
548 C = sitelink[nu] + nbr_idx;
// Reconstruct full coordinates from the checkerboard index. x1 is chosen so
// that x1+x2+x3+x4 has parity oddBit.
581 int x1h = sid - za*
X1h;
586 int x1odd = (x2 + x3 + x4 +
oddBit) & 1;
587 int x1 = 2*x1h +
x1odd;
// space_con face indices for this site (same formulas as above).
590 (x4*X3X2+x3*X2+
x2)/2,
591 (x4*X3X1+x3*X1+x1)/2,
592 (x4*X2X1+x2*X1+
x1)/2,
// Second visible pass, targeting fatlink[mu] at i.
598 fat1 = ((su3_matrix*)fatlink[mu]) + i;
602 memset(dx, 0,
sizeof(dx));
// Backward crossing in nu: the nu-link comes from the backward-face slab
// nu*Vs[nu] of ghost_sitelink[nu]; the neighbour has opposite parity.
606 if(x[nu] + dx[nu] < 0){
607 A = ghost_sitelink[nu] + nu*
Vs[nu] + (1-
oddBit)*
Vsh[nu] + space_con[nu];
610 A = sitelink[nu] + nbr_idx;
// Backward crossing in nu for mulink: backward face sits at offset 0 of
// ghost_mulink[nu].
616 if (x[nu] + dx[nu] < 0){
617 B = ghost_mulink[nu] + (1-
oddBit)*
Vsh[nu] + space_con[nu];
619 B = mulink + nbr_idx;
// Backward crossing in nu for a mu-link: backward-face slab mu*Vs[nu].
622 if(x[nu] + dx[nu] < 0){
623 B = ghost_sitelink[nu] + mu*
Vs[nu] + (1-
oddBit)*
Vsh[nu] + space_con[nu];
626 B = sitelink[
mu] + nbr_idx;
633 memset(dx, 0,
sizeof(dx));
// Periodically wrap the displaced coordinates into [0, Z[d]) and recompute
// the face indices for the displaced site.
639 int new_x1, new_x2, new_x3, new_x4;
640 new_x1 = (x[0] + dx[0] + Z[0])%Z[0];
641 new_x2 = (x[1] + dx[1] + Z[1])%Z[1];
642 new_x3 = (x[2] + dx[2] + Z[2])%Z[2];
643 new_x4 = (x[3] + dx[3] + Z[3])%Z[3];
644 int new_x[4] = {new_x1, new_x2, new_x3, new_x4};
645 space_con[0] = (new_x4*X3X2 + new_x3*X2 + new_x2)/2;
646 space_con[1] = (new_x4*X3X1 + new_x3*X1 + new_x1)/2;
647 space_con[2] = (new_x4*X2X1 + new_x2*X1 + new_x1)/2;
648 space_con[3] = (new_x3*X2X1 + new_x2*X1 + new_x1)/2;
// Corner case: the displacement leaves the sublattice backward in nu AND
// forward in mu, so the link lives in the diagonal ghost buffer
// ghost_sitelink_diag[nu*4+mu]. dir1/dir2 are the two remaining directions,
// and the buffer is indexed as a 2-D (dir1, dir2) face split by parity.
650 if( (x[nu] + dx[nu]) < 0 && (x[
mu] + dx[
mu] >= Z[
mu])){
654 for(dir1=0; dir1 < 4; dir1 ++){
655 if(dir1 != nu && dir1 != mu){
659 for(dir2=0; dir2 < 4; dir2 ++){
660 if(dir2 != nu && dir2 != mu && dir2 != dir1){
664 C = ghost_sitelink_diag[nu*4+
mu] + oddBit*Z[dir1]*Z[dir2]/2 + (new_x[dir2]*Z[dir1]+new_x[dir1])/2;
665 }
// Backward crossing in nu only. Parity term uses oddBit (not 1-oddBit),
// presumably because this site is displaced in both nu and mu; confirm.
else if (x[nu] + dx[nu] < 0){
666 C = ghost_sitelink[nu] + nu*
Vs[nu] + oddBit*
Vsh[nu]+ space_con[nu];
667 }
// Forward crossing in mu only (ghost branch body not visible).
else if (x[mu] + dx[mu] >= Z[mu]){
670 C = sitelink[nu] + nbr_idx;
// Fragment of llfat_cpu_mg, the multi-domain fat-link driver. It mirrors the
// single-domain driver but routes every staple computation through
// llfat_compute_gen_staple_field_mg together with the ghost link buffers.
// Several body lines (including staple/tempmat1 cleanup) are not visible.
687 template <
typename su3_matrix,
typename Float>
688 void llfat_cpu_mg(
void** fatlink, su3_matrix** sitelink, su3_matrix** ghost_sitelink,
689 su3_matrix** ghost_sitelink_diag,
Float* act_path_coeff)
// Branch taken for single-precision Float (body not visible).
692 if (
sizeof(
Float) == 4){
// Scratch field of V staple matrices; allocation failure reported on stderr.
698 su3_matrix* staple = (su3_matrix *)malloc(V*
sizeof(su3_matrix));
700 fprintf(stderr,
"Error: malloc failed for staple in function %s\n", __FUNCTION__);
// Per-direction ghost buffers of 2*Vs[i] matrices each, presumably one face
// in each direction (backward and forward). Both are freed at the end of the
// visible code.
705 su3_matrix* ghost_staple[4];
706 su3_matrix* ghost_staple1[4];
708 for(
int i=0;i < 4;i++){
709 ghost_staple[i] = (su3_matrix*)malloc(2*
Vs[i]*
sizeof(su3_matrix));
710 if (ghost_staple[i] == NULL){
711 fprintf(stderr,
"Error: malloc failed for ghost staple in function %s\n", __FUNCTION__);
715 ghost_staple1[i] = (su3_matrix*)malloc(2*
Vs[i]*
sizeof(su3_matrix));
716 if (ghost_staple1[i] == NULL){
717 fprintf(stderr,
"Error: malloc failed for ghost staple1 in function %s\n", __FUNCTION__);
// Second scratch field of V matrices.
722 su3_matrix* tempmat1 = (su3_matrix *)malloc(V*
sizeof(su3_matrix));
723 if(tempmat1 == NULL){
724 fprintf(stderr,
"ERROR: malloc failed for tempmat1 in function %s\n", __FUNCTION__);
// One-link coefficient, identical to the single-domain driver:
// coeff[0] - 6*coeff[5].
729 Float one_link = (act_path_coeff[0] - 6.0*act_path_coeff[5]);
// For every direction and site, fat1 -> fatlink[dir] at site i (body not
// visible; presumably initialised from the one-link term).
732 for (
int dir=
XUP; dir<=
TUP; dir++){
735 for(
int i=0;i <
V;i ++){
736 su3_matrix* fat1 = ((su3_matrix*)fatlink[dir]) + i;
741 for (
int dir=
XUP; dir<=
TUP; dir++){
742 for(
int nu=
XUP; nu<=
TUP; nu++){
// Staple from the bare links: mulink = sitelink[dir], no ghost mulink,
// coefficient act_path_coeff[2], last argument 0 (presumably use_staple, as
// in the single-domain routine). The result is written to `staple`.
744 llfat_compute_gen_staple_field_mg(staple,dir,nu,
745 sitelink[dir], (su3_matrix**)NULL,
746 sitelink, ghost_sitelink, ghost_sitelink_diag,
747 fatlink, act_path_coeff[2], 0);
// Next level with coefficient act_path_coeff[5]; no staple output is kept
// (NULL). Its mulink arguments are on a line not visible here.
753 llfat_compute_gen_staple_field_mg((su3_matrix*)NULL,dir,nu,
755 sitelink, ghost_sitelink, ghost_sitelink_diag,
756 fatlink, act_path_coeff[5],1);
// For every rho distinct from dir and nu: build a staple into tempmat1 with
// coefficient act_path_coeff[3] ...
758 for(
int rho=
XUP; rho<=
TUP; rho++) {
759 if((rho!=dir)&&(rho!=nu)){
760 llfat_compute_gen_staple_field_mg( tempmat1, dir, rho,
762 sitelink, ghost_sitelink, ghost_sitelink_diag,
763 fatlink, act_path_coeff[3], 1);
// ... then extend it along sig (loop not visible) with coefficient
// act_path_coeff[4], using tempmat1 and ghost_staple1 as the mu-link field and
// its ghost zone.
771 llfat_compute_gen_staple_field_mg((su3_matrix*)NULL,dir,
sig,
772 tempmat1, ghost_staple1,
773 sitelink, ghost_sitelink, ghost_sitelink_diag,
774 fatlink, act_path_coeff[4], 1);
// Release the per-direction ghost buffers.
789 for(
int i=0;i < 4;i++){
790 free(ghost_staple[i]);
791 free(ghost_staple1[i]);
// Tail of the multi-domain entry point llfat_reference_mg(fatlink, sitelink,
// ghost_sitelink, ghost_sitelink_diag, prec, act_path_coeff).
// The type-erased void* arguments are cast to dsu3_matrix/double or to
// fsu3_matrix/float before dispatch (presumably selected by prec; the switch
// itself is not visible). Any other precision is reported on stderr.
801 void** ghost_sitelink_diag,
QudaPrecision prec,
void* act_path_coeff)
// Double-precision dispatch (argument tail).
817 (
dsu3_matrix**)ghost_sitelink_diag, (
double*) act_path_coeff);
// Single-precision dispatch (argument tail).
822 (
fsu3_matrix**)ghost_sitelink_diag, (
float*) act_path_coeff);
826 fprintf(stderr,
"ERROR: unsupported precision(%d)\n", prec);
enum QudaPrecision_s QudaPrecision
void exchange_cpu_staple(int *X, void *staple, void **ghost_staple, QudaPrecision gPrecision)
void llfat_mult_su3_nn(su3_matrix *a, su3_matrix *b, su3_matrix *c)
__global__ void const RealA *const const RealA *const const RealA *const const RealB *const const RealB *const int int mu
int neighborIndexFullLattice(int i, int dx4, int dx3, int dx2, int dx1)
void llfat_add_su3_matrix(su3_matrix *a, su3_matrix *b, su3_matrix *c)
void llfat_compute_gen_staple_field(su3_matrix *staple, int mu, int nu, su3_matrix *mulink, su3_matrix **sitelink, void **fatlink, Real coef, int use_staple)
void llfat_scalar_mult_su3_matrix(su3_matrix *a, Real s, su3_matrix *b)
FloatingPoint< float > Float
void llfat_scalar_mult_add_su3_matrix(su3_matrix *a, su3_matrix *b, Real s, su3_matrix *c)
void llfat_mult_su3_na(su3_matrix *a, su3_matrix *b, su3_matrix *c)
void computeLongLinkCPU(void **longlink, su3_matrix **sitelink, Float *act_path_coeff)
int neighborIndexFullLattice_mg(int i, int dx4, int dx3, int dx2, int dx1)
void llfat_cpu(void **fatlink, su3_matrix **sitelink, Float *act_path_coeff)
void * memset(void *s, int c, size_t n)
Main header file for the QUDA library.
void llfat_reference_mg(void **fatlink, void **sitelink, void **ghost_sitelink, void **ghost_sitelink_diag, QudaPrecision prec, void *act_path_coeff)
void llfat_mult_su3_an(su3_matrix *a, su3_matrix *b, su3_matrix *c)
__global__ void const RealA *const const RealA *const const RealA *const const RealB *const const RealB *const int sig
void llfat_reference(void **fatlink, void **sitelink, QudaPrecision prec, void *act_path_coeff)