QUDA  v0.5.0
A library for QCD on GPUs
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
llfat_reference.cpp
Go to the documentation of this file.
1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <math.h>
4 
5 #include <quda.h>
6 #include <test_util.h>
7 #include "llfat_reference.h"
8 #include "misc.h"
9 #include <string.h>
10 
11 #include <quda_internal.h>
12 #include "face_quda.h"
13 
/* Indices of the four lattice directions. */
#define XUP 0
#define YUP 1
#define ZUP 2
#define TUP 3

/* Single-precision complex number. */
typedef struct {
  float real;
  float imag;
} fcomplex;

/* Double-precision complex number. */
typedef struct {
  double real;
  double imag;
} dcomplex;

/* 3x3 complex matrices and 3-component complex vectors, in both precisions. */
typedef struct { fcomplex e[3][3]; } fsu3_matrix;
typedef struct { fcomplex c[3]; } fsu3_vector;
typedef struct { dcomplex e[3][3]; } dsu3_matrix;
typedef struct { dcomplex c[3]; } dsu3_vector;

/* out = lhs + rhs */
#define CADD(lhs,rhs,out) { (out).real = (lhs).real + (rhs).real; \
                            (out).imag = (lhs).imag + (rhs).imag; }

/* out = lhs * rhs */
#define CMUL(lhs,rhs,out) { (out).real = (lhs).real*(rhs).real - (lhs).imag*(rhs).imag; \
                            (out).imag = (lhs).real*(rhs).imag + (lhs).imag*(rhs).real; }

/* acc += inc */
#define CSUM(acc,inc) { (acc).real += (inc).real; (acc).imag += (inc).imag; }

/* out = conj(lhs) * rhs */
#define CMULJ_(lhs,rhs,out) { (out).real = (lhs).real*(rhs).real + (lhs).imag*(rhs).imag; \
                              (out).imag = (lhs).real*(rhs).imag - (lhs).imag*(rhs).real; }

/* out = lhs * conj(rhs) */
#define CMUL_J(lhs,rhs,out) { (out).real = (lhs).real*(rhs).real + (lhs).imag*(rhs).imag; \
                              (out).imag = (lhs).imag*(rhs).real - (lhs).real*(rhs).imag; }

/*
 * Per-direction copies of Vs_x..Vs_t and Vsh_x..Vsh_t.  They are filled in by
 * the llfat_reference entry points and used as offsets into the ghost buffers.
 */
static int Vs[4];
static int Vsh[4];
52 
/*
 * b = s * a
 *
 * Multiplies the real and imaginary part of every entry of the 3x3 matrix
 * *a by the real scalar s and stores the result in *b.
 */
template<typename su3_matrix, typename Real>
void
llfat_scalar_mult_su3_matrix( su3_matrix *a, Real s, su3_matrix *b )
{
  for (int row = 0; row < 3; row++) {
    for (int col = 0; col < 3; col++) {
      b->e[row][col].real = s * a->e[row][col].real;
      b->e[row][col].imag = s * a->e[row][col].imag;
    }
  }
}
66 
/*
 * c = a + s * b
 *
 * Entry-wise on 3x3 complex matrices; s is a real scalar.  Each output entry
 * depends only on the matching input entries.
 */
template<typename su3_matrix, typename Real>
void
llfat_scalar_mult_add_su3_matrix(su3_matrix *a,su3_matrix *b, Real s, su3_matrix *c)
{
  for (int row = 0; row < 3; row++) {
    for (int col = 0; col < 3; col++) {
      c->e[row][col].real = a->e[row][col].real + s * b->e[row][col].real;
      c->e[row][col].imag = a->e[row][col].imag + s * b->e[row][col].imag;
    }
  }
}
78 
79 template <typename su3_matrix>
80 void
81 llfat_mult_su3_na( su3_matrix *a, su3_matrix *b, su3_matrix *c )
82 {
83  int i,j,k;
84  typeof(a->e[0][0]) x,y;
85  for(i=0;i<3;i++)for(j=0;j<3;j++){
86  x.real=x.imag=0.0;
87  for(k=0;k<3;k++){
88  CMUL_J( a->e[i][k] , b->e[j][k] , y );
89  CSUM( x , y );
90  }
91  c->e[i][j] = x;
92  }
93 }
94 
95 template <typename su3_matrix>
96 void
97 llfat_mult_su3_nn( su3_matrix *a, su3_matrix *b, su3_matrix *c )
98 {
99  int i,j,k;
100  typeof(a->e[0][0]) x,y;
101  for(i=0;i<3;i++)for(j=0;j<3;j++){
102  x.real=x.imag=0.0;
103  for(k=0;k<3;k++){
104  CMUL( a->e[i][k] , b->e[k][j] , y );
105  CSUM( x , y );
106  }
107  c->e[i][j] = x;
108  }
109 }
110 
111 template<typename su3_matrix>
112 void
113 llfat_mult_su3_an( su3_matrix *a, su3_matrix *b, su3_matrix *c )
114 {
115  int i,j,k;
116  typeof(a->e[0][0]) x,y;
117  for(i=0;i<3;i++)for(j=0;j<3;j++){
118  x.real=x.imag=0.0;
119  for(k=0;k<3;k++){
120  CMULJ_( a->e[k][i] , b->e[k][j], y );
121  CSUM( x , y );
122  }
123  c->e[i][j] = x;
124  }
125 }
126 
127 
128 
129 
130 
/*
 * c = a + b, entry by entry, on 3x3 complex matrices.
 */
template<typename su3_matrix>
void
llfat_add_su3_matrix( su3_matrix *a, su3_matrix *b, su3_matrix *c )
{
  for (int row = 0; row < 3; row++) {
    for (int col = 0; col < 3; col++) {
      CADD( a->e[row][col], b->e[row][col], c->e[row][col] );
    }
  }
}
140 
141 
142 
143 template<typename su3_matrix, typename Real>
144 void
145 llfat_compute_gen_staple_field(su3_matrix *staple, int mu, int nu,
146  su3_matrix* mulink, su3_matrix** sitelink, void** fatlink, Real coef,
147  int use_staple)
148 {
149  su3_matrix tmat1,tmat2;
150  int i ;
151  su3_matrix *fat1;
152 
153  /* Upper staple */
154  /* Computes the staple :
155  * mu (B)
156  * +-------+
157  * nu | |
158  * (A) | |(C)
159  * X X
160  *
161  * Where the mu link can be any su3_matrix. The result is saved in staple.
162  * if staple==NULL then the result is not saved.
163  * It also adds the computed staple to the fatlink[mu] with weight coef.
164  */
165 
166  int dx[4];
167 
168  /* upper staple */
169 
170  for(i=0;i < V;i++){
171 
172  fat1 = ((su3_matrix*)fatlink[mu]) + i;
173  su3_matrix* A = sitelink[nu] + i;
174 
175  memset(dx, 0, sizeof(dx));
176  dx[nu] =1;
177  int nbr_idx = neighborIndexFullLattice(i, dx[3], dx[2], dx[1], dx[0]);
178  su3_matrix* B;
179  if (use_staple){
180  B = mulink + nbr_idx;
181  }else{
182  B = mulink + nbr_idx;
183  }
184 
185  memset(dx, 0, sizeof(dx));
186  dx[mu] =1;
187  nbr_idx = neighborIndexFullLattice(i, dx[3], dx[2],dx[1],dx[0]);
188  su3_matrix* C = sitelink[nu] + nbr_idx;
189 
190  llfat_mult_su3_nn( A, B,&tmat1);
191 
192  if(staple!=NULL){/* Save the staple */
193  llfat_mult_su3_na( &tmat1, C, &staple[i]);
194  } else{ /* No need to save the staple. Add it to the fatlinks */
195  llfat_mult_su3_na( &tmat1, C, &tmat2);
196  llfat_scalar_mult_add_su3_matrix(fat1, &tmat2, coef, fat1);
197  }
198  }
199  /***************lower staple****************
200  *
201  * X X
202  * nu | |
203  * (A) | |(C)
204  * +-------+
205  * mu (B)
206  *
207  *********************************************/
208 
209  for(i=0;i < V;i++){
210 
211  fat1 = ((su3_matrix*)fatlink[mu]) + i;
212  memset(dx, 0, sizeof(dx));
213  dx[nu] = -1;
214  int nbr_idx = neighborIndexFullLattice(i, dx[3], dx[2], dx[1], dx[0]);
215  if (nbr_idx >= V || nbr_idx <0){
216  fprintf(stderr, "ERROR: invliad nbr_idx(%d), line=%d\n", nbr_idx, __LINE__);
217  exit(1);
218  }
219  su3_matrix* A = sitelink[nu] + nbr_idx;
220 
221  su3_matrix* B;
222  if (use_staple){
223  B = mulink + nbr_idx;
224  }else{
225  B = mulink + nbr_idx;
226  }
227 
228  memset(dx, 0, sizeof(dx));
229  dx[mu] = 1;
230  nbr_idx = neighborIndexFullLattice(nbr_idx, dx[3], dx[2],dx[1],dx[0]);
231  su3_matrix* C = sitelink[nu] + nbr_idx;
232 
233  llfat_mult_su3_an( A, B,&tmat1);
234  llfat_mult_su3_nn( &tmat1, C,&tmat2);
235 
236  if(staple!=NULL){/* Save the staple */
237  llfat_add_su3_matrix(&staple[i], &tmat2, &staple[i]);
238  llfat_scalar_mult_add_su3_matrix(fat1, &staple[i], coef, fat1);
239 
240  } else{ /* No need to save the staple. Add it to the fatlinks */
241  llfat_scalar_mult_add_su3_matrix(fat1, &tmat2, coef, fat1);
242  }
243  }
244 
245 } /* compute_gen_staple_site */
246 
247 
248 
249 /* Optimized fattening code for the Asq and Asqtad actions.
250  * I assume that:
251  * path 0 is the one link
252  * path 2 the 3-staple
253  * path 3 the 5-staple
254  * path 4 the 7-staple
255  * path 5 the Lapage term.
256  * Path 1 is the Naik term
257  *
258  */
259 template <typename su3_matrix, typename Float>
260 void llfat_cpu(void** fatlink, su3_matrix** sitelink, Float* act_path_coeff)
261 {
262 
263  su3_matrix* staple = (su3_matrix *)malloc(V*sizeof(su3_matrix));
264  if(staple == NULL){
265  fprintf(stderr, "Error: malloc failed for staple in function %s\n", __FUNCTION__);
266  exit(1);
267  }
268 
269  su3_matrix* tempmat1 = (su3_matrix *)malloc(V*sizeof(su3_matrix));
270  if(tempmat1 == NULL){
271  fprintf(stderr, "ERROR: malloc failed for tempmat1 in function %s\n", __FUNCTION__);
272  exit(1);
273  }
274 
275  /* to fix up the Lepage term, included by a trick below */
276  Float one_link = (act_path_coeff[0] - 6.0*act_path_coeff[5]);
277 
278 
279  for (int dir=XUP; dir<=TUP; dir++){
280 
281  /* Intialize fat links with c_1*U_\mu(x) */
282  for(int i=0;i < V;i ++){
283  su3_matrix* fat1 = ((su3_matrix*)fatlink[dir]) + i;
284  llfat_scalar_mult_su3_matrix(sitelink[dir] + i, one_link, fat1 );
285  }
286  }
287 
288 
289 
290 
291  for (int dir=XUP; dir<=TUP; dir++){
292  for(int nu=XUP; nu<=TUP; nu++){
293  if(nu!=dir){
294  llfat_compute_gen_staple_field(staple,dir,nu,sitelink[dir], sitelink,fatlink, act_path_coeff[2], 0);
295 
296  /* The Lepage term */
297  /* Note this also involves modifying c_1 (above) */
298 
299  llfat_compute_gen_staple_field((su3_matrix*)NULL,dir,nu,staple,sitelink, fatlink, act_path_coeff[5],1);
300 
301  for(int rho=XUP; rho<=TUP; rho++) {
302  if((rho!=dir)&&(rho!=nu)){
303  llfat_compute_gen_staple_field( tempmat1, dir, rho, staple,sitelink,fatlink, act_path_coeff[3], 1);
304 
305  for(int sig=XUP; sig<=TUP; sig++){
306  if((sig!=dir)&&(sig!=nu)&&(sig!=rho)){
307  llfat_compute_gen_staple_field((su3_matrix*)NULL,dir,sig,tempmat1,sitelink,fatlink, act_path_coeff[4], 1);
308  }
309  }/* sig */
310 
311  }
312 
313  }/* rho */
314  }
315 
316  }/* nu */
317 
318  }/* dir */
319 
320 
321  free(staple);
322  free(tempmat1);
323 
324 }
325 
326 
327 
328 void
329 llfat_reference(void** fatlink, void** sitelink, QudaPrecision prec, void* act_path_coeff)
330 {
331  Vs[0] = Vs_x;
332  Vs[1] = Vs_y;
333  Vs[2] = Vs_z;
334  Vs[3] = Vs_t;
335 
336  Vsh[0] = Vsh_x;
337  Vsh[1] = Vsh_y;
338  Vsh[2] = Vsh_z;
339  Vsh[3] = Vsh_t;
340 
341 
342  switch(prec){
343  case QUDA_DOUBLE_PRECISION:{
344  llfat_cpu((void**)fatlink, (dsu3_matrix**)sitelink, (double*) act_path_coeff);
345  break;
346  }
347  case QUDA_SINGLE_PRECISION:{
348  llfat_cpu((void**)fatlink, (fsu3_matrix**)sitelink, (float*) act_path_coeff);
349  break;
350  }
351  default:
352  fprintf(stderr, "ERROR: unsupported precision(%d)\n", prec);
353  exit(1);
354  break;
355 
356  }
357 
358  return;
359 
360 }
361 
362 #ifdef MULTI_GPU
363 
364 template<typename su3_matrix, typename Real>
365 void
366 llfat_compute_gen_staple_field_mg(su3_matrix *staple, int mu, int nu,
367  su3_matrix* mulink, su3_matrix** ghost_mulink,
368  su3_matrix** sitelink, su3_matrix** ghost_sitelink, su3_matrix** ghost_sitelink_diag,
369  void** fatlink, Real coef,
370  int use_staple)
371 {
372  su3_matrix tmat1,tmat2;
373  int i ;
374  su3_matrix *fat1;
375 
376 
377  int X1 = Z[0];
378  int X2 = Z[1];
379  int X3 = Z[2];
380  //int X4 = Z[3];
381  int X1h =X1/2;
382 
383  int X2X1 = X1*X2;
384  int X3X2 = X3*X2;
385  int X3X1 = X3*X1;
386 
387  /* Upper staple */
388  /* Computes the staple :
389  * mu (B)
390  * +-------+
391  * nu | |
392  * (A) | |(C)
393  * X X
394  *
395  * Where the mu link can be any su3_matrix. The result is saved in staple.
396  * if staple==NULL then the result is not saved.
397  * It also adds the computed staple to the fatlink[mu] with weight coef.
398  */
399 
400  int dx[4];
401 
402  /* upper staple */
403 
404  for(i=0;i < V;i++){
405 
406  int half_index = i;
407  int oddBit =0;
408  if (i >= Vh){
409  oddBit = 1;
410  half_index = i -Vh;
411  }
412  //int x4 = x4_from_full_index(i);
413 
414 
415 
416  int sid =half_index;
417  int za = sid/X1h;
418  int x1h = sid - za*X1h;
419  int zb = za/X2;
420  int x2 = za - zb*X2;
421  int x4 = zb/X3;
422  int x3 = zb - x4*X3;
423  int x1odd = (x2 + x3 + x4 + oddBit) & 1;
424  int x1 = 2*x1h + x1odd;
425  int x[4] = {x1,x2,x3,x4};
426  int space_con[4]={
427  (x4*X3X2+x3*X2+x2)/2,
428  (x4*X3X1+x3*X1+x1)/2,
429  (x4*X2X1+x2*X1+x1)/2,
430  (x3*X2X1+x2*X1+x1)/2
431  };
432 
433  fat1 = ((su3_matrix*)fatlink[mu]) + i;
434  su3_matrix* A = sitelink[nu] + i;
435 
436  memset(dx, 0, sizeof(dx));
437  dx[nu] =1;
438  int nbr_idx;
439 
440  su3_matrix* B;
441  if (use_staple){
442  if (x[nu] + dx[nu] >= Z[nu]){
443  B = ghost_mulink[nu] + Vs[nu] + (1-oddBit)*Vsh[nu] + space_con[nu];
444  }else{
445  nbr_idx = neighborIndexFullLattice_mg(i, dx[3], dx[2], dx[1], dx[0]);
446  B = mulink + nbr_idx;
447  }
448  }else{
449  if(x[nu]+dx[nu] >= Z[nu]){ //out of boundary, use ghost data
450  B = ghost_sitelink[nu] + 4*Vs[nu] + mu*Vs[nu] + (1-oddBit)*Vsh[nu] + space_con[nu];
451  }else{
452  nbr_idx = neighborIndexFullLattice_mg(i, dx[3], dx[2], dx[1], dx[0]);
453  B = sitelink[mu] + nbr_idx;
454  }
455  }
456 
457 
458  //we could be in the ghost link area if mu is T and we are at high T boundary
459  su3_matrix* C;
460  memset(dx, 0, sizeof(dx));
461  dx[mu] =1;
462  if(x[mu] + dx[mu] >= Z[mu]){ //out of boundary, use ghost data
463  C = ghost_sitelink[mu] + 4*Vs[mu] + nu*Vs[mu] + (1-oddBit)*Vsh[mu] + space_con[mu];
464  }else{
465  nbr_idx = neighborIndexFullLattice_mg(i, dx[3], dx[2],dx[1],dx[0]);
466  C = sitelink[nu] + nbr_idx;
467  }
468 
469  llfat_mult_su3_nn( A, B,&tmat1);
470 
471  if(staple!=NULL){/* Save the staple */
472  llfat_mult_su3_na( &tmat1, C, &staple[i]);
473  } else{ /* No need to save the staple. Add it to the fatlinks */
474  llfat_mult_su3_na( &tmat1, C, &tmat2);
475  llfat_scalar_mult_add_su3_matrix(fat1, &tmat2, coef, fat1);
476  }
477  }
478  /***************lower staple****************
479  *
480  * X X
481  * nu | |
482  * (A) | |(C)
483  * +-------+
484  * mu (B)
485  *
486  *********************************************/
487 
488  for(i=0;i < V;i++){
489 
490  int half_index = i;
491  int oddBit =0;
492  if (i >= Vh){
493  oddBit = 1;
494  half_index = i -Vh;
495  }
496 
497  int sid =half_index;
498  int za = sid/X1h;
499  int x1h = sid - za*X1h;
500  int zb = za/X2;
501  int x2 = za - zb*X2;
502  int x4 = zb/X3;
503  int x3 = zb - x4*X3;
504  int x1odd = (x2 + x3 + x4 + oddBit) & 1;
505  int x1 = 2*x1h + x1odd;
506  int x[4] = {x1,x2,x3,x4};
507  int space_con[4]={
508  (x4*X3X2+x3*X2+x2)/2,
509  (x4*X3X1+x3*X1+x1)/2,
510  (x4*X2X1+x2*X1+x1)/2,
511  (x3*X2X1+x2*X1+x1)/2
512  };
513 
514  //int x4 = x4_from_full_index(i);
515 
516  fat1 = ((su3_matrix*)fatlink[mu]) + i;
517 
518  //we could be in the ghost link area if nu is T and we are at low T boundary
519  su3_matrix* A;
520  memset(dx, 0, sizeof(dx));
521  dx[nu] = -1;
522 
523  int nbr_idx;
524  if(x[nu] + dx[nu] < 0){ //out of boundary, use ghost data
525  A = ghost_sitelink[nu] + nu*Vs[nu] + (1-oddBit)*Vsh[nu] + space_con[nu];
526  }else{
527  nbr_idx = neighborIndexFullLattice_mg(i, dx[3], dx[2], dx[1], dx[0]);
528  A = sitelink[nu] + nbr_idx;
529  }
530 
531  su3_matrix* B;
532  if (use_staple){
533  nbr_idx = neighborIndexFullLattice_mg(i, dx[3], dx[2], dx[1], dx[0]);
534  if (x[nu] + dx[nu] < 0){
535  B = ghost_mulink[nu] + (1-oddBit)*Vsh[nu] + space_con[nu];
536  }else{
537  B = mulink + nbr_idx;
538  }
539  }else{
540  if(x[nu] + dx[nu] < 0){ //out of boundary, use ghost data
541  B = ghost_sitelink[nu] + mu*Vs[nu] + (1-oddBit)*Vsh[nu] + space_con[nu];
542  }else{
543  nbr_idx = neighborIndexFullLattice_mg(i, dx[3], dx[2], dx[1], dx[0]);
544  B = sitelink[mu] + nbr_idx;
545  }
546  }
547 
548  //we could be in the ghost link area if nu is T and we are at low T boundary
549  // or mu is T and we are on high T boundary
550  su3_matrix* C;
551  memset(dx, 0, sizeof(dx));
552  dx[nu] = -1;
553  dx[mu] = 1;
554  nbr_idx = neighborIndexFullLattice_mg(i, dx[3], dx[2],dx[1],dx[0]);
555 
556  //space con must be recomputed because we have coodinates change in 2 directions
557  int new_x1, new_x2, new_x3, new_x4;
558  new_x1 = (x[0] + dx[0] + Z[0])%Z[0];
559  new_x2 = (x[1] + dx[1] + Z[1])%Z[1];
560  new_x3 = (x[2] + dx[2] + Z[2])%Z[2];
561  new_x4 = (x[3] + dx[3] + Z[3])%Z[3];
562  int new_x[4] = {new_x1, new_x2, new_x3, new_x4};
563  space_con[0] = (new_x4*X3X2 + new_x3*X2 + new_x2)/2;
564  space_con[1] = (new_x4*X3X1 + new_x3*X1 + new_x1)/2;
565  space_con[2] = (new_x4*X2X1 + new_x2*X1 + new_x1)/2;
566  space_con[3] = (new_x3*X2X1 + new_x2*X1 + new_x1)/2;
567 
568  if( (x[nu] + dx[nu]) < 0 && (x[mu] + dx[mu] >= Z[mu])){
569  //find the other 2 directions, dir1, dir2
570  //with dir2 the slowest changing direction
571  int dir1, dir2; //other two dimensions
572  for(dir1=0; dir1 < 4; dir1 ++){
573  if(dir1 != nu && dir1 != mu){
574  break;
575  }
576  }
577  for(dir2=0; dir2 < 4; dir2 ++){
578  if(dir2 != nu && dir2 != mu && dir2 != dir1){
579  break;
580  }
581  }
582  C = ghost_sitelink_diag[nu*4+mu] + oddBit*Z[dir1]*Z[dir2]/2 + (new_x[dir2]*Z[dir1]+new_x[dir1])/2;
583  }else if (x[nu] + dx[nu] < 0){
584  C = ghost_sitelink[nu] + nu*Vs[nu] + oddBit*Vsh[nu]+ space_con[nu];
585  }else if (x[mu] + dx[mu] >= Z[mu]){
586  C = ghost_sitelink[mu] + 4*Vs[mu] + nu*Vs[mu] + oddBit*Vsh[mu]+space_con[mu];
587  }else{
588  C = sitelink[nu] + nbr_idx;
589  }
590  llfat_mult_su3_an( A, B,&tmat1);
591  llfat_mult_su3_nn( &tmat1, C,&tmat2);
592 
593  if(staple!=NULL){/* Save the staple */
594  llfat_add_su3_matrix(&staple[i], &tmat2, &staple[i]);
595  llfat_scalar_mult_add_su3_matrix(fat1, &staple[i], coef, fat1);
596 
597  } else{ /* No need to save the staple. Add it to the fatlinks */
598  llfat_scalar_mult_add_su3_matrix(fat1, &tmat2, coef, fat1);
599  }
600  }
601 
602 } /* compute_gen_staple_site */
603 
604 
605 template <typename su3_matrix, typename Float>
606 void llfat_cpu_mg(void** fatlink, su3_matrix** sitelink, su3_matrix** ghost_sitelink,
607  su3_matrix** ghost_sitelink_diag, Float* act_path_coeff)
608 {
610  if (sizeof(Float) == 4){
611  prec = QUDA_SINGLE_PRECISION;
612  }else{
613  prec = QUDA_DOUBLE_PRECISION;
614  }
615 
616  su3_matrix* staple = (su3_matrix *)malloc(V*sizeof(su3_matrix));
617  if(staple == NULL){
618  fprintf(stderr, "Error: malloc failed for staple in function %s\n", __FUNCTION__);
619  exit(1);
620  }
621 
622 
623  su3_matrix* ghost_staple[4];
624  su3_matrix* ghost_staple1[4];
625 
626  for(int i=0;i < 4;i++){
627  ghost_staple[i] = (su3_matrix*)malloc(2*Vs[i]*sizeof(su3_matrix));
628  if (ghost_staple[i] == NULL){
629  fprintf(stderr, "Error: malloc failed for ghost staple in function %s\n", __FUNCTION__);
630  exit(1);
631  }
632 
633  ghost_staple1[i] = (su3_matrix*)malloc(2*Vs[i]*sizeof(su3_matrix));
634  if (ghost_staple1[i] == NULL){
635  fprintf(stderr, "Error: malloc failed for ghost staple1 in function %s\n", __FUNCTION__);
636  exit(1);
637  }
638  }
639 
640  su3_matrix* tempmat1 = (su3_matrix *)malloc(V*sizeof(su3_matrix));
641  if(tempmat1 == NULL){
642  fprintf(stderr, "ERROR: malloc failed for tempmat1 in function %s\n", __FUNCTION__);
643  exit(1);
644  }
645 
646  /* to fix up the Lepage term, included by a trick below */
647  Float one_link = (act_path_coeff[0] - 6.0*act_path_coeff[5]);
648 
649 
650  for (int dir=XUP; dir<=TUP; dir++){
651 
652  /* Intialize fat links with c_1*U_\mu(x) */
653  for(int i=0;i < V;i ++){
654  su3_matrix* fat1 = ((su3_matrix*)fatlink[dir]) + i;
655  llfat_scalar_mult_su3_matrix(sitelink[dir] + i, one_link, fat1 );
656  }
657  }
658 
659  for (int dir=XUP; dir<=TUP; dir++){
660  for(int nu=XUP; nu<=TUP; nu++){
661  if(nu!=dir){
662  llfat_compute_gen_staple_field_mg(staple,dir,nu,
663  sitelink[dir], (su3_matrix**)NULL,
664  sitelink, ghost_sitelink, ghost_sitelink_diag,
665  fatlink, act_path_coeff[2], 0);
666  /* The Lepage term */
667  /* Note this also involves modifying c_1 (above) */
668 
669  exchange_cpu_staple(Z, staple, (void**)ghost_staple, prec);
670 
671  llfat_compute_gen_staple_field_mg((su3_matrix*)NULL,dir,nu,
672  staple,ghost_staple,
673  sitelink, ghost_sitelink, ghost_sitelink_diag,
674  fatlink, act_path_coeff[5],1);
675 
676  for(int rho=XUP; rho<=TUP; rho++) {
677  if((rho!=dir)&&(rho!=nu)){
678  llfat_compute_gen_staple_field_mg( tempmat1, dir, rho,
679  staple,ghost_staple,
680  sitelink, ghost_sitelink, ghost_sitelink_diag,
681  fatlink, act_path_coeff[3], 1);
682 
683 
684  exchange_cpu_staple(Z, tempmat1, (void**)ghost_staple1, prec);
685 
686  for(int sig=XUP; sig<=TUP; sig++){
687  if((sig!=dir)&&(sig!=nu)&&(sig!=rho)){
688 
689  llfat_compute_gen_staple_field_mg((su3_matrix*)NULL,dir,sig,
690  tempmat1, ghost_staple1,
691  sitelink, ghost_sitelink, ghost_sitelink_diag,
692  fatlink, act_path_coeff[4], 1);
693  //FIXME
694  //return;
695 
696  }
697  }/* sig */
698  }
699  }/* rho */
700  }
701 
702  }/* nu */
703 
704  }/* dir */
705 
706  free(staple);
707  for(int i=0;i < 4;i++){
708  free(ghost_staple[i]);
709  free(ghost_staple1[i]);
710  }
711  free(tempmat1);
712 
713 }
714 
715 
716 
717 void
718 llfat_reference_mg(void** fatlink, void** sitelink, void** ghost_sitelink,
719  void** ghost_sitelink_diag, QudaPrecision prec, void* act_path_coeff)
720 {
721 
722  Vs[0] = Vs_x;
723  Vs[1] = Vs_y;
724  Vs[2] = Vs_z;
725  Vs[3] = Vs_t;
726 
727  Vsh[0] = Vsh_x;
728  Vsh[1] = Vsh_y;
729  Vsh[2] = Vsh_z;
730  Vsh[3] = Vsh_t;
731 
732  switch(prec){
733  case QUDA_DOUBLE_PRECISION:{
734  llfat_cpu_mg((void**)fatlink, (dsu3_matrix**)sitelink, (dsu3_matrix**)ghost_sitelink,
735  (dsu3_matrix**)ghost_sitelink_diag, (double*) act_path_coeff);
736  break;
737  }
738  case QUDA_SINGLE_PRECISION:{
739  llfat_cpu_mg((void**)fatlink, (fsu3_matrix**)sitelink, (fsu3_matrix**)ghost_sitelink,
740  (fsu3_matrix**)ghost_sitelink_diag, (float*) act_path_coeff);
741  break;
742  }
743  default:
744  fprintf(stderr, "ERROR: unsupported precision(%d)\n", prec);
745  exit(1);
746  break;
747 
748  }
749 
750  return;
751 
752 }
753 
754 
755 #endif