QUDA v0.4.0
A library for QCD on GPUs
quda/tests/llfat_reference.cpp
Go to the documentation of this file.
00001 #include <stdio.h>
00002 #include <stdlib.h>
00003 #include <math.h>
00004 
00005 #include <quda.h>
00006 #include <test_util.h>
00007 #include "llfat_reference.h"
00008 #include "misc.h"
00009 #include <string.h>
00010 
00011 #include <quda_internal.h>
00012 #include "face_quda.h"
00013 
/* Direction indices for the four lattice dimensions.  They are used below as
 * the bounds of the direction loops (XUP..TUP) and to index the per-direction
 * link arrays (sitelink[dir], fatlink[dir]). */
#define XUP 0
#define YUP 1
#define ZUP 2
#define TUP 3
00018 
/* Single-precision complex number (real part, imaginary part). */
struct fcomplex {
  float real;
  float imag;
};
00023 
/* Double-precision complex number (real part, imaginary part). */
struct dcomplex {
  double real;
  double imag;
};
00029 
00030 typedef struct { fcomplex e[3][3]; } fsu3_matrix;
00031 typedef struct { fcomplex c[3]; } fsu3_vector;
00032 typedef struct { dcomplex e[3][3]; } dsu3_matrix;
00033 typedef struct { dcomplex c[3]; } dsu3_vector;
00034 
00035 
/* Complex helpers operating on any objects with .real / .imag members.
 *
 * Each macro is wrapped in do { ... } while (0) so that `MACRO(...);` is a
 * single statement and can be used safely as the body of an if/else.
 * Arguments are evaluated more than once, so they must be side-effect free.
 */

/* c = a + b */
#define CADD(a,b,c) do { (c).real = (a).real + (b).real;        \
    (c).imag = (a).imag + (b).imag; } while (0)

/* c = a * b.  c must not be the same object as a or b: c.real is written
 * before c.imag is computed. */
#define CMUL(a,b,c) do { (c).real = (a).real*(b).real - (a).imag*(b).imag; \
    (c).imag = (a).real*(b).imag + (a).imag*(b).real; } while (0)

/* a += b */
#define CSUM(a,b) do { (a).real += (b).real; (a).imag += (b).imag; } while (0)
00041 
/* c = conj(a) * b.
 * Wrapped in do { ... } while (0) so that `CMULJ_(...);` is one statement and
 * is safe inside if/else.  c must not alias a or b (c.real is written before
 * c.imag is computed); arguments are evaluated more than once. */
#define CMULJ_(a,b,c) do { (c).real = (a).real*(b).real + (a).imag*(b).imag; \
    (c).imag = (a).real*(b).imag - (a).imag*(b).real; } while (0)
00045 
/* c = a * conj(b).
 * Wrapped in do { ... } while (0) so that `CMUL_J(...);` is one statement and
 * is safe inside if/else.  c must not alias a or b (c.real is written before
 * c.imag is computed); arguments are evaluated more than once. */
#define CMUL_J(a,b,c) do { (c).real = (a).real*(b).real + (a).imag*(b).imag; \
    (c).imag = (a).imag*(b).real - (a).real*(b).imag; } while (0)
00049 
/* Lattice geometry globals; they are only declared here and are defined
 * outside this file. */
extern int Z[4];   /* per-direction extents; Z[0..3] are read as X1..X4 below */
extern int V;      /* upper bound of the per-site loops in the staple kernels */
extern int Vh;     /* site indices >= Vh are treated as odd parity (oddBit = 1) */
extern int Vs[];   /* per-direction stride into the ghost buffers -- presumably the 3-d face volume; TODO confirm */
extern int Vsh[];  /* per-direction parity offset into the ghost buffers -- presumably Vs[]/2; TODO confirm */
/* The scalar variants below are not referenced in the code shown here. */
extern int Vs_x, Vs_y, Vs_z, Vs_t;
extern int Vsh_x, Vsh_y, Vsh_z, Vsh_t;
00057 
00058 
/* b = s * a
 *
 * Scales every complex entry of the 3x3 matrix *a by the real scalar s and
 * stores the result in *b.  The operation is element-wise, so a and b may
 * point to the same matrix.
 */
template<typename su3_matrix, typename Real>
void 
llfat_scalar_mult_su3_matrix( su3_matrix *a, Real s, su3_matrix *b )
{
  for (int row = 0; row < 3; row++) {
    for (int col = 0; col < 3; col++) {
      b->e[row][col].real = s*a->e[row][col].real;
      b->e[row][col].imag = s*a->e[row][col].imag;
    }
  }
}
00072 
/* c = a + s * b
 *
 * Element-wise on the nine complex entries of the 3x3 matrices; s is a real
 * scalar.  Because each output entry depends only on the matching input
 * entries, c may be the same matrix as a or b.
 */
template<typename su3_matrix, typename Real>
void
llfat_scalar_mult_add_su3_matrix(su3_matrix *a,su3_matrix *b, Real s, su3_matrix *c)
{
  for (int row = 0; row < 3; row++) {
    for (int col = 0; col < 3; col++) {
      c->e[row][col].real = a->e[row][col].real + s*b->e[row][col].real;
      c->e[row][col].imag = a->e[row][col].imag + s*b->e[row][col].imag;
    }
  }
}
00084 
/* c = a * b^dagger   ("na": a normal, b adjoint)
 *
 *   c[i][j] = sum_k a[i][k] * conj(b[j][k])
 *
 * The product is accumulated in a local matrix and copied to *c at the end,
 * so c may alias a or b.  No compiler extension (typeof) or file-level macro
 * is needed: the accumulator takes its element type from su3_matrix itself.
 */
template <typename su3_matrix>
void 
llfat_mult_su3_na(  su3_matrix *a, su3_matrix *b, su3_matrix *c )
{
  su3_matrix prod;
  for (int i = 0; i < 3; i++) {
    for (int j = 0; j < 3; j++) {
      prod.e[i][j].real = 0.0;
      prod.e[i][j].imag = 0.0;
      for (int k = 0; k < 3; k++) {
        /* a[i][k] * conj(b[j][k]) */
        prod.e[i][j].real += a->e[i][k].real*b->e[j][k].real + a->e[i][k].imag*b->e[j][k].imag;
        prod.e[i][j].imag += a->e[i][k].imag*b->e[j][k].real - a->e[i][k].real*b->e[j][k].imag;
      }
    }
  }
  *c = prod;
}
00100 
/* c = a * b   ("nn": both factors normal)
 *
 *   c[i][j] = sum_k a[i][k] * b[k][j]
 *
 * The product is accumulated in a local matrix and copied to *c at the end,
 * so c may alias a or b.  No compiler extension (typeof) or file-level macro
 * is needed: the accumulator takes its element type from su3_matrix itself.
 */
template <typename su3_matrix>
void
llfat_mult_su3_nn( su3_matrix *a, su3_matrix *b, su3_matrix *c )
{
  su3_matrix prod;
  for (int i = 0; i < 3; i++) {
    for (int j = 0; j < 3; j++) {
      prod.e[i][j].real = 0.0;
      prod.e[i][j].imag = 0.0;
      for (int k = 0; k < 3; k++) {
        /* a[i][k] * b[k][j] */
        prod.e[i][j].real += a->e[i][k].real*b->e[k][j].real - a->e[i][k].imag*b->e[k][j].imag;
        prod.e[i][j].imag += a->e[i][k].real*b->e[k][j].imag + a->e[i][k].imag*b->e[k][j].real;
      }
    }
  }
  *c = prod;
}
00116 
/* c = a^dagger * b   ("an": a adjoint, b normal)
 *
 *   c[i][j] = sum_k conj(a[k][i]) * b[k][j]
 *
 * The product is accumulated in a local matrix and copied to *c at the end,
 * so c may alias a or b.  No compiler extension (typeof) or file-level macro
 * is needed: the accumulator takes its element type from su3_matrix itself.
 */
template<typename su3_matrix>
void
llfat_mult_su3_an( su3_matrix *a, su3_matrix *b, su3_matrix *c )
{
  su3_matrix prod;
  for (int i = 0; i < 3; i++) {
    for (int j = 0; j < 3; j++) {
      prod.e[i][j].real = 0.0;
      prod.e[i][j].imag = 0.0;
      for (int k = 0; k < 3; k++) {
        /* conj(a[k][i]) * b[k][j] */
        prod.e[i][j].real += a->e[k][i].real*b->e[k][j].real + a->e[k][i].imag*b->e[k][j].imag;
        prod.e[i][j].imag += a->e[k][i].real*b->e[k][j].imag - a->e[k][i].imag*b->e[k][j].real;
      }
    }
  }
  *c = prod;
}
00132 
00133 
00134 
00135 
00136 
/* c = a + b
 *
 * Element-wise complex addition of two 3x3 matrices, delegating each entry
 * to the file-level CADD macro.  Entries are independent, so c may be the
 * same matrix as a or b.
 */
template<typename su3_matrix>
void 
llfat_add_su3_matrix( su3_matrix *a, su3_matrix *b, su3_matrix *c ) 
{
  for (int row = 0; row < 3; row++) {
    for (int col = 0; col < 3; col++) {
      CADD( a->e[row][col], b->e[row][col], c->e[row][col] );
    }
  }
}
00146 
00147 
00148 
00149 template<typename su3_matrix, typename Real>
00150 void 
00151 llfat_compute_gen_staple_field(su3_matrix *staple, int mu, int nu, 
00152                                su3_matrix* mulink, su3_matrix** sitelink, void** fatlink, Real coef,
00153                                int use_staple) 
00154 {
00155   su3_matrix tmat1,tmat2;
00156   int i ;
00157   su3_matrix *fat1;
00158     
00159   /* Upper staple */
00160   /* Computes the staple :
00161    *                mu (B)
00162    *               +-------+
00163    *       nu      |       | 
00164    *         (A)   |       |(C)
00165    *               X       X
00166    *
00167    * Where the mu link can be any su3_matrix. The result is saved in staple.
00168    * if staple==NULL then the result is not saved.
00169    * It also adds the computed staple to the fatlink[mu] with weight coef.
00170    */
00171     
00172   int dx[4];
00173 
00174   /* upper staple */
00175     
00176   for(i=0;i < V;i++){       
00177         
00178     fat1 = ((su3_matrix*)fatlink[mu]) + i;
00179     su3_matrix* A = sitelink[nu] + i;
00180         
00181     memset(dx, 0, sizeof(dx));
00182     dx[nu] =1;
00183     int nbr_idx = neighborIndexFullLattice(i, dx[3], dx[2], dx[1], dx[0]);
00184     su3_matrix* B;
00185     if (use_staple){
00186       B = mulink + nbr_idx;
00187     }else{
00188       B = mulink + nbr_idx;
00189     }
00190         
00191     memset(dx, 0, sizeof(dx));
00192     dx[mu] =1;
00193     nbr_idx = neighborIndexFullLattice(i, dx[3], dx[2],dx[1],dx[0]);
00194     su3_matrix* C = sitelink[nu] + nbr_idx;
00195         
00196     llfat_mult_su3_nn( A, B,&tmat1);
00197         
00198     if(staple!=NULL){/* Save the staple */
00199       llfat_mult_su3_na( &tmat1, C, &staple[i]);            
00200     } else{ /* No need to save the staple. Add it to the fatlinks */
00201       llfat_mult_su3_na( &tmat1, C, &tmat2);        
00202       llfat_scalar_mult_add_su3_matrix(fat1, &tmat2, coef, fat1);           
00203     }
00204   }    
00205   /***************lower staple****************
00206    *
00207    *               X       X
00208    *       nu      |       | 
00209    *         (A)   |       |(C)
00210    *               +-------+
00211    *                mu (B)
00212    *
00213    *********************************************/
00214 
00215   for(i=0;i < V;i++){       
00216         
00217     fat1 = ((su3_matrix*)fatlink[mu]) + i;
00218     memset(dx, 0, sizeof(dx));
00219     dx[nu] = -1;
00220     int nbr_idx = neighborIndexFullLattice(i, dx[3], dx[2], dx[1], dx[0]);      
00221     if (nbr_idx >= V || nbr_idx <0){
00222       fprintf(stderr, "ERROR: invliad nbr_idx(%d), line=%d\n", nbr_idx, __LINE__);
00223       exit(1);
00224     }
00225     su3_matrix* A = sitelink[nu] + nbr_idx;
00226         
00227     su3_matrix* B;
00228     if (use_staple){
00229       B = mulink + nbr_idx;
00230     }else{
00231       B = mulink + nbr_idx;
00232     }
00233         
00234     memset(dx, 0, sizeof(dx));
00235     dx[mu] = 1;
00236     nbr_idx = neighborIndexFullLattice(nbr_idx, dx[3], dx[2],dx[1],dx[0]);
00237     su3_matrix* C = sitelink[nu] + nbr_idx;
00238 
00239     llfat_mult_su3_an( A, B,&tmat1);    
00240     llfat_mult_su3_nn( &tmat1, C,&tmat2);
00241         
00242     if(staple!=NULL){/* Save the staple */
00243       llfat_add_su3_matrix(&staple[i], &tmat2, &staple[i]);
00244       llfat_scalar_mult_add_su3_matrix(fat1, &staple[i], coef, fat1);
00245             
00246     } else{ /* No need to save the staple. Add it to the fatlinks */
00247       llfat_scalar_mult_add_su3_matrix(fat1, &tmat2, coef, fat1);           
00248     }
00249   } 
00250     
00251 } /* compute_gen_staple_site */
00252 
00253 
00254 
00255 /*  Optimized fattening code for the Asq and Asqtad actions.           
00256  *  I assume that: 
00257  *  path 0 is the one link
00258  *  path 2 the 3-staple
00259  *  path 3 the 5-staple 
00260  *  path 4 the 7-staple
00261  *  path 5 the Lapage term.
00262  *  Path 1 is the Naik term
00263  *
00264  */
00265 template <typename su3_matrix, typename Float>
00266 void llfat_cpu(void** fatlink, su3_matrix** sitelink, Float* act_path_coeff)
00267 {
00268 
00269   su3_matrix* staple = (su3_matrix *)malloc(V*sizeof(su3_matrix));
00270   if(staple == NULL){
00271     fprintf(stderr, "Error: malloc failed for staple in function %s\n", __FUNCTION__);
00272     exit(1);
00273   }
00274     
00275   su3_matrix* tempmat1 = (su3_matrix *)malloc(V*sizeof(su3_matrix));
00276   if(tempmat1 == NULL){
00277     fprintf(stderr, "ERROR:  malloc failed for tempmat1 in function %s\n", __FUNCTION__);
00278     exit(1);
00279   }
00280     
00281   /* to fix up the Lepage term, included by a trick below */
00282   Float one_link = (act_path_coeff[0] - 6.0*act_path_coeff[5]);
00283     
00284 
00285   for (int dir=XUP; dir<=TUP; dir++){
00286 
00287     /* Intialize fat links with c_1*U_\mu(x) */
00288     for(int i=0;i < V;i ++){
00289       su3_matrix* fat1 = ((su3_matrix*)fatlink[dir]) +  i;
00290       llfat_scalar_mult_su3_matrix(sitelink[dir] + i, one_link, fat1 );
00291     }
00292   }
00293 
00294 
00295 
00296 
00297   for (int dir=XUP; dir<=TUP; dir++){
00298     for(int nu=XUP; nu<=TUP; nu++){
00299       if(nu!=dir){
00300         llfat_compute_gen_staple_field(staple,dir,nu,sitelink[dir], sitelink,fatlink, act_path_coeff[2], 0);
00301 
00302         /* The Lepage term */
00303         /* Note this also involves modifying c_1 (above) */
00304                 
00305         llfat_compute_gen_staple_field((su3_matrix*)NULL,dir,nu,staple,sitelink, fatlink, act_path_coeff[5],1);
00306                 
00307         for(int rho=XUP; rho<=TUP; rho++) {
00308           if((rho!=dir)&&(rho!=nu)){
00309             llfat_compute_gen_staple_field( tempmat1, dir, rho, staple,sitelink,fatlink, act_path_coeff[3], 1);
00310             
00311             for(int sig=XUP; sig<=TUP; sig++){
00312               if((sig!=dir)&&(sig!=nu)&&(sig!=rho)){
00313                 llfat_compute_gen_staple_field((su3_matrix*)NULL,dir,sig,tempmat1,sitelink,fatlink, act_path_coeff[4], 1);
00314               } 
00315             }/* sig */
00316 
00317           } 
00318 
00319         }/* rho */
00320       } 
00321 
00322     }/* nu */
00323         
00324   }/* dir */      
00325 
00326 
00327   free(staple);
00328   free(tempmat1);
00329 
00330 }
00331 
00332 
00333 
00334 void
00335 llfat_reference(void** fatlink, void** sitelink, QudaPrecision prec, void* act_path_coeff)
00336 {
00337   switch(prec){
00338   case QUDA_DOUBLE_PRECISION:{
00339     llfat_cpu((void**)fatlink, (dsu3_matrix**)sitelink, (double*) act_path_coeff);
00340     break;
00341   }
00342   case QUDA_SINGLE_PRECISION:{
00343     llfat_cpu((void**)fatlink, (fsu3_matrix**)sitelink, (float*) act_path_coeff);
00344     break;
00345   }
00346   default:
00347     fprintf(stderr, "ERROR: unsupported precision(%d)\n", prec);
00348     exit(1);
00349     break;
00350         
00351   }
00352 
00353   return;
00354 
00355 }
00356 
00357 #ifdef MULTI_GPU
00358 
00359 template<typename su3_matrix, typename Real>
00360 void 
00361 llfat_compute_gen_staple_field_mg(su3_matrix *staple, int mu, int nu, 
00362                                   su3_matrix* mulink, su3_matrix** ghost_mulink, 
00363                                   su3_matrix** sitelink, su3_matrix** ghost_sitelink, su3_matrix** ghost_sitelink_diag, 
00364                                   void** fatlink, Real coef,
00365                                   int use_staple) 
00366 {
00367   su3_matrix tmat1,tmat2;
00368   int i ;
00369   su3_matrix *fat1;
00370   
00371 
00372   int X1 = Z[0];  
00373   int X2 = Z[1];
00374   int X3 = Z[2];
00375   int X4 = Z[3];
00376   int X1h =X1/2;
00377   
00378   int X2X1 = X1*X2;
00379   int X3X2 = X3*X2;
00380   int X3X1 = X3*X1;  
00381 
00382   /* Upper staple */
00383   /* Computes the staple :
00384    *                mu (B)
00385    *               +-------+
00386    *       nu      |       | 
00387    *         (A)   |       |(C)
00388    *               X       X
00389    *
00390    * Where the mu link can be any su3_matrix. The result is saved in staple.
00391    * if staple==NULL then the result is not saved.
00392    * It also adds the computed staple to the fatlink[mu] with weight coef.
00393    */
00394     
00395   int dx[4];
00396 
00397   /* upper staple */
00398     
00399   for(i=0;i < V;i++){       
00400         
00401     int half_index = i;
00402     int oddBit =0;
00403     if (i >= Vh){
00404       oddBit = 1;
00405       half_index = i -Vh;
00406     }
00407     //int x4 = x4_from_full_index(i);
00408 
00409 
00410     
00411     int sid =half_index;
00412     int za = sid/X1h;
00413     int x1h = sid - za*X1h;
00414     int zb = za/X2;
00415     int x2 = za - zb*X2;
00416     int x4 = zb/X3;
00417     int x3 = zb - x4*X3;
00418     int x1odd = (x2 + x3 + x4 + oddBit) & 1;
00419     int x1 = 2*x1h + x1odd;
00420     int x[4] = {x1,x2,x3,x4};
00421     int space_con[4]={
00422       (x4*X3X2+x3*X2+x2)/2,
00423       (x4*X3X1+x3*X1+x1)/2,
00424       (x4*X2X1+x2*X1+x1)/2,
00425       (x3*X2X1+x2*X1+x1)/2
00426     };
00427 
00428     fat1 = ((su3_matrix*)fatlink[mu]) + i;
00429     su3_matrix* A = sitelink[nu] + i;
00430         
00431     memset(dx, 0, sizeof(dx));
00432     dx[nu] =1;
00433     int nbr_idx;
00434     
00435     su3_matrix* B;  
00436     if (use_staple){
00437       if (x[nu] + dx[nu]  >= Z[nu]){
00438         B =  ghost_mulink[nu] + Vs[nu] + (1-oddBit)*Vsh[nu] + space_con[nu];
00439       }else{
00440         nbr_idx = neighborIndexFullLattice_mg(i, dx[3], dx[2], dx[1], dx[0]);     
00441         B = mulink + nbr_idx;
00442       }
00443     }else{      
00444       if(x[nu]+dx[nu] >= Z[nu]){ //out of boundary, use ghost data
00445         B = ghost_sitelink[nu] + 4*Vs[nu] + mu*Vs[nu] + (1-oddBit)*Vsh[nu] + space_con[nu];
00446       }else{
00447         nbr_idx = neighborIndexFullLattice_mg(i, dx[3], dx[2], dx[1], dx[0]);   
00448         B = sitelink[mu] + nbr_idx;
00449       }
00450     }
00451         
00452 
00453     //we could be in the ghost link area if mu is T and we are at high T boundary
00454     su3_matrix* C;
00455     memset(dx, 0, sizeof(dx));
00456     dx[mu] =1;    
00457     if(x[mu] + dx[mu] >= Z[mu]){ //out of boundary, use ghost data
00458       C = ghost_sitelink[mu] + 4*Vs[mu] + nu*Vs[mu] + (1-oddBit)*Vsh[mu] + space_con[mu];
00459     }else{
00460       nbr_idx = neighborIndexFullLattice_mg(i, dx[3], dx[2],dx[1],dx[0]);    
00461       C = sitelink[nu] + nbr_idx;
00462     }
00463 
00464     llfat_mult_su3_nn( A, B,&tmat1);
00465         
00466     if(staple!=NULL){/* Save the staple */
00467       llfat_mult_su3_na( &tmat1, C, &staple[i]);            
00468     } else{ /* No need to save the staple. Add it to the fatlinks */
00469       llfat_mult_su3_na( &tmat1, C, &tmat2);        
00470       llfat_scalar_mult_add_su3_matrix(fat1, &tmat2, coef, fat1);           
00471     }
00472   }    
00473   /***************lower staple****************
00474    *
00475    *               X       X
00476    *       nu      |       | 
00477    *         (A)   |       |(C)
00478    *               +-------+
00479    *                mu (B)
00480    *
00481    *********************************************/
00482 
00483   for(i=0;i < V;i++){
00484             
00485     int half_index = i;
00486     int oddBit =0;
00487     if (i >= Vh){
00488       oddBit = 1;
00489       half_index = i -Vh;
00490     }
00491 
00492     int sid =half_index;
00493     int za = sid/X1h;
00494     int x1h = sid - za*X1h;
00495     int zb = za/X2;
00496     int x2 = za - zb*X2;
00497     int x4 = zb/X3;
00498     int x3 = zb - x4*X3;
00499     int x1odd = (x2 + x3 + x4 + oddBit) & 1;
00500     int x1 = 2*x1h + x1odd;
00501     int x[4] = {x1,x2,x3,x4};
00502     int space_con[4]={
00503       (x4*X3X2+x3*X2+x2)/2,
00504       (x4*X3X1+x3*X1+x1)/2,
00505       (x4*X2X1+x2*X1+x1)/2,
00506       (x3*X2X1+x2*X1+x1)/2
00507     };
00508 
00509     //int x4 = x4_from_full_index(i);
00510 
00511     fat1 = ((su3_matrix*)fatlink[mu]) + i;
00512 
00513     //we could be in the ghost link area if nu is T and we are at low T boundary    
00514     su3_matrix* A;
00515     memset(dx, 0, sizeof(dx));
00516     dx[nu] = -1;
00517 
00518     int nbr_idx;
00519     if(x[nu] + dx[nu] < 0){ //out of boundary, use ghost data
00520       A = ghost_sitelink[nu] + nu*Vs[nu] + (1-oddBit)*Vsh[nu] + space_con[nu];
00521     }else{
00522       nbr_idx = neighborIndexFullLattice_mg(i, dx[3], dx[2], dx[1], dx[0]);     
00523       A = sitelink[nu] + nbr_idx;
00524     }
00525     
00526     su3_matrix* B;
00527     if (use_staple){
00528       nbr_idx = neighborIndexFullLattice_mg(i, dx[3], dx[2], dx[1], dx[0]);     
00529       if (x[nu] + dx[nu]  < 0){
00530         B =  ghost_mulink[nu] + (1-oddBit)*Vsh[nu] + space_con[nu];
00531       }else{
00532         B = mulink + nbr_idx;
00533       }
00534     }else{      
00535       if(x[nu] + dx[nu] < 0){ //out of boundary, use ghost data
00536         B = ghost_sitelink[nu] + mu*Vs[nu] + (1-oddBit)*Vsh[nu] + space_con[nu];        
00537       }else{
00538         nbr_idx = neighborIndexFullLattice_mg(i, dx[3], dx[2], dx[1], dx[0]);
00539         B = sitelink[mu] + nbr_idx;
00540       }
00541     }
00542 
00543     //we could be in the ghost link area if nu is T and we are at low T boundary
00544     // or mu is T and we are on high T boundary
00545     su3_matrix* C;
00546     memset(dx, 0, sizeof(dx));
00547     dx[nu] = -1;
00548     dx[mu] = 1;
00549     nbr_idx = neighborIndexFullLattice_mg(i, dx[3], dx[2],dx[1],dx[0]);
00550     
00551     //space con must be recomputed because we have coodinates change in 2 directions
00552     int new_x1, new_x2, new_x3, new_x4;
00553     new_x1 = (x[0] + dx[0] + Z[0])%Z[0];
00554     new_x2 = (x[1] + dx[1] + Z[1])%Z[1];
00555     new_x3 = (x[2] + dx[2] + Z[2])%Z[2];
00556     new_x4 = (x[3] + dx[3] + Z[3])%Z[3];
00557     int new_x[4] = {new_x1, new_x2, new_x3, new_x4};
00558     space_con[0] = (new_x4*X3X2 + new_x3*X2 + new_x2)/2;
00559     space_con[1] = (new_x4*X3X1 + new_x3*X1 + new_x1)/2;
00560     space_con[2] = (new_x4*X2X1 + new_x2*X1 + new_x1)/2;
00561     space_con[3] = (new_x3*X2X1 + new_x2*X1 + new_x1)/2;
00562 
00563     if( (x[nu] + dx[nu]) < 0  && (x[mu] + dx[mu] >= Z[mu])){
00564       //find the other 2 directions, dir1, dir2
00565       //with dir2 the slowest changing direction
00566       int dir1, dir2; //other two dimensions
00567       for(dir1=0; dir1 < 4; dir1 ++){
00568         if(dir1 != nu && dir1 != mu){
00569           break;
00570         }
00571       }
00572       for(dir2=0; dir2 < 4; dir2 ++){
00573         if(dir2 != nu && dir2 != mu && dir2 != dir1){
00574           break;
00575         }
00576       }  
00577       C = ghost_sitelink_diag[nu*4+mu] +  oddBit*Z[dir1]*Z[dir2]/2 + (new_x[dir2]*Z[dir1]+new_x[dir1])/2;       
00578     }else if (x[nu] + dx[nu] < 0){
00579       C = ghost_sitelink[nu] + nu*Vs[nu] + oddBit*Vsh[nu]+ space_con[nu];
00580     }else if (x[mu] + dx[mu] >= Z[mu]){
00581       C = ghost_sitelink[mu] + 4*Vs[mu] + nu*Vs[mu] + oddBit*Vsh[mu]+space_con[mu];
00582     }else{
00583       C = sitelink[nu] + nbr_idx;
00584     }
00585     llfat_mult_su3_an( A, B,&tmat1);    
00586     llfat_mult_su3_nn( &tmat1, C,&tmat2);
00587         
00588     if(staple!=NULL){/* Save the staple */
00589       llfat_add_su3_matrix(&staple[i], &tmat2, &staple[i]);
00590       llfat_scalar_mult_add_su3_matrix(fat1, &staple[i], coef, fat1);
00591             
00592     } else{ /* No need to save the staple. Add it to the fatlinks */
00593       llfat_scalar_mult_add_su3_matrix(fat1, &tmat2, coef, fat1);           
00594     }
00595   } 
00596     
00597 } /* compute_gen_staple_site */
00598 
00599 
00600 template <typename su3_matrix, typename Float>
00601 void llfat_cpu_mg(void** fatlink, su3_matrix** sitelink, su3_matrix** ghost_sitelink,
00602                   su3_matrix** ghost_sitelink_diag, Float* act_path_coeff)
00603 {
00604   QudaPrecision prec;
00605   if (sizeof(Float) == 4){
00606     prec = QUDA_SINGLE_PRECISION;
00607   }else{
00608     prec = QUDA_DOUBLE_PRECISION;
00609   }
00610   
00611   su3_matrix* staple = (su3_matrix *)malloc(V*sizeof(su3_matrix));
00612   if(staple == NULL){
00613     fprintf(stderr, "Error: malloc failed for staple in function %s\n", __FUNCTION__);
00614     exit(1);
00615   }
00616   
00617 
00618   su3_matrix* ghost_staple[4];
00619   su3_matrix* ghost_staple1[4];
00620 
00621   for(int i=0;i < 4;i++){
00622     ghost_staple[i] = (su3_matrix*)malloc(2*Vs[i]*sizeof(su3_matrix));
00623     if (ghost_staple[i] == NULL){
00624       fprintf(stderr, "Error: malloc failed for ghost staple in function %s\n", __FUNCTION__);
00625       exit(1);
00626     }
00627     
00628     ghost_staple1[i] = (su3_matrix*)malloc(2*Vs[i]*sizeof(su3_matrix));
00629     if (ghost_staple1[i] == NULL){ 
00630       fprintf(stderr, "Error: malloc failed for ghost staple1 in function %s\n", __FUNCTION__);
00631       exit(1);
00632     }     
00633   }
00634 
00635   su3_matrix* tempmat1 = (su3_matrix *)malloc(V*sizeof(su3_matrix));
00636   if(tempmat1 == NULL){
00637     fprintf(stderr, "ERROR:  malloc failed for tempmat1 in function %s\n", __FUNCTION__);
00638     exit(1);
00639   }
00640     
00641   /* to fix up the Lepage term, included by a trick below */
00642   Float one_link = (act_path_coeff[0] - 6.0*act_path_coeff[5]);
00643     
00644 
00645   for (int dir=XUP; dir<=TUP; dir++){
00646 
00647     /* Intialize fat links with c_1*U_\mu(x) */
00648     for(int i=0;i < V;i ++){
00649       su3_matrix* fat1 = ((su3_matrix*)fatlink[dir]) +  i;
00650       llfat_scalar_mult_su3_matrix(sitelink[dir] + i, one_link, fat1 );
00651     }
00652   }
00653 
00654   for (int dir=XUP; dir<=TUP; dir++){
00655     for(int nu=XUP; nu<=TUP; nu++){
00656       if(nu!=dir){
00657         llfat_compute_gen_staple_field_mg(staple,dir,nu,
00658                                           sitelink[dir], (su3_matrix**)NULL, 
00659                                           sitelink, ghost_sitelink, ghost_sitelink_diag, 
00660                                           fatlink, act_path_coeff[2], 0);       
00661         /* The Lepage term */
00662         /* Note this also involves modifying c_1 (above) */
00663 
00664         exchange_cpu_staple(Z, staple, (void**)ghost_staple, prec);
00665         
00666         llfat_compute_gen_staple_field_mg((su3_matrix*)NULL,dir,nu,
00667                                           staple,ghost_staple, 
00668                                           sitelink, ghost_sitelink, ghost_sitelink_diag, 
00669                                           fatlink, act_path_coeff[5],1);
00670 
00671         for(int rho=XUP; rho<=TUP; rho++) {
00672           if((rho!=dir)&&(rho!=nu)){
00673             llfat_compute_gen_staple_field_mg( tempmat1, dir, rho, 
00674                                                staple,ghost_staple, 
00675                                                sitelink, ghost_sitelink, ghost_sitelink_diag, 
00676                                                fatlink, act_path_coeff[3], 1);
00677 
00678 
00679             exchange_cpu_staple(Z, tempmat1, (void**)ghost_staple1, prec);
00680             
00681             for(int sig=XUP; sig<=TUP; sig++){
00682               if((sig!=dir)&&(sig!=nu)&&(sig!=rho)){
00683 
00684                 llfat_compute_gen_staple_field_mg((su3_matrix*)NULL,dir,sig,
00685                                                   tempmat1, ghost_staple1,
00686                                                   sitelink, ghost_sitelink, ghost_sitelink_diag, 
00687                                                   fatlink, act_path_coeff[4], 1);
00688                 //FIXME
00689                 //return;
00690 
00691               } 
00692             }/* sig */          
00693           } 
00694         }/* rho */
00695       } 
00696 
00697     }/* nu */
00698         
00699   }/* dir */      
00700   
00701   free(staple);
00702   for(int i=0;i < 4;i++){
00703     free(ghost_staple[i]);
00704     free(ghost_staple1[i]);
00705   }
00706   free(tempmat1);
00707 
00708 }
00709 
00710 
00711 
00712 void
00713 llfat_reference_mg(void** fatlink, void** sitelink, void** ghost_sitelink,
00714                    void** ghost_sitelink_diag, QudaPrecision prec, void* act_path_coeff)
00715 {
00716 
00717   switch(prec){
00718   case QUDA_DOUBLE_PRECISION:{
00719     llfat_cpu_mg((void**)fatlink, (dsu3_matrix**)sitelink, (dsu3_matrix**)ghost_sitelink, 
00720                  (dsu3_matrix**)ghost_sitelink_diag, (double*) act_path_coeff);
00721     break;
00722   }
00723   case QUDA_SINGLE_PRECISION:{
00724     llfat_cpu_mg((void**)fatlink, (fsu3_matrix**)sitelink, (fsu3_matrix**)ghost_sitelink, 
00725                  (fsu3_matrix**)ghost_sitelink_diag, (float*) act_path_coeff);
00726     break;
00727   }
00728   default:
00729     fprintf(stderr, "ERROR: unsupported precision(%d)\n", prec);
00730     exit(1);
00731     break;
00732         
00733   }
00734 
00735   return;
00736 
00737 }
00738 
00739 
00740 #endif
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines