QUDA v0.4.0
A library for QCD on GPUs
quda/tests/test_util.cpp
Go to the documentation of this file.
00001 #include <complex>
00002 #include <stdlib.h>
00003 #include <stdio.h>
00004 #include <string.h>
00005 
00006 #include <short.h>
00007 
00008 #include <wilson_dslash_reference.h>
00009 #include <test_util.h>
00010 
00011 #include <face_quda.h>
00012 #include "misc.h"
00013 
// Named lattice direction indices (x, y, z, t). Index 3 is the time
// direction, consistent with the code below treating dir == 3 as time.
#define XUP 0
#define YUP 1
#define ZUP 2
#define TUP 3

// NOTE(review): not referenced in the visible part of this file; presumably
// defined in another translation unit of the test suite - confirm.
extern float fat_link_max;
using namespace std;
00021 
// Print three consecutive complex numbers stored as (re, im) pairs, i.e. six
// reals starting at v, in the form {(re im) (re im) (re im)}.
template <typename Float>
static void printVector(Float *v) {
  const Float *c0 = v;      // first complex entry
  const Float *c1 = v + 2;  // second complex entry
  const Float *c2 = v + 4;  // third complex entry
  printfQuda("{(%f %f) (%f %f) (%f %f)}\n", c0[0], c0[1], c1[0], c1[1], c2[0], c2[1]);
}
00026 
00027 // X indexes the lattice site
00028 void printSpinorElement(void *spinor, int X, QudaPrecision precision) {
00029   if (precision == QUDA_DOUBLE_PRECISION)
00030     for (int s=0; s<4; s++) printVector((double*)spinor+X*24+s*6);
00031   else
00032     for (int s=0; s<4; s++) printVector((float*)spinor+X*24+s*6);
00033 }
00034 
00035 // X indexes the full lattice
00036 void printGaugeElement(void *gauge, int X, QudaPrecision precision) {
00037   if (getOddBit(X) == 0) {
00038     if (precision == QUDA_DOUBLE_PRECISION)
00039       for (int m=0; m<3; m++) printVector((double*)gauge +(X/2)*gaugeSiteSize + m*3*2);
00040     else
00041       for (int m=0; m<3; m++) printVector((float*)gauge +(X/2)*gaugeSiteSize + m*3*2);
00042       
00043   } else {
00044     if (precision == QUDA_DOUBLE_PRECISION)
00045       for (int m = 0; m < 3; m++) printVector((double*)gauge + (X/2+Vh)*gaugeSiteSize + m*3*2);
00046     else
00047       for (int m = 0; m < 3; m++) printVector((float*)gauge + (X/2+Vh)*gaugeSiteSize + m*3*2);
00048   }
00049 }
00050 
00051 // returns 0 or 1 if the full lattice index X is even or odd
00052 int getOddBit(int Y) {
00053   int x4 = Y/(Z[2]*Z[1]*Z[0]);
00054   int x3 = (Y/(Z[1]*Z[0])) % Z[2];
00055   int x2 = (Y/Z[0]) % Z[1];
00056   int x1 = Y % Z[0];
00057   return (x4+x3+x2+x1) % 2;
00058 }
00059 
// a += b for complex numbers stored as (re, im) pairs.
template <typename Float>
inline void complexAddTo(Float *a, Float *b) {
  for (int k = 0; k < 2; k++) a[k] += b[k];
}
00066 
// a = b*c for complex numbers stored as (re, im) pairs.
// The real part is written before the imaginary part is computed, exactly as
// before, so in-place use (a aliasing b or c) behaves the same.
template <typename Float>
inline void complexProduct(Float *a, Float *b, Float *c) {
  Float &re = a[0];
  Float &im = a[1];
  re = b[0]*c[0] - b[1]*c[1];
  im = b[0]*c[1] + b[1]*c[0];
}
00073 
// a = conj(b)*conj(c) = conj(b*c) for complex numbers stored as (re, im)
// pairs. The real part is written first, so in-place use behaves as before.
template <typename Float>
inline void complexConjugateProduct(Float *a, Float *b, Float *c) {
  Float &re = a[0];
  Float &im = a[1];
  re = b[0]*c[0] - b[1]*c[1];
  im = -b[0]*c[1] - b[1]*c[0];
}
00080 
// a = conj(b)*c for complex numbers stored as (re, im) pairs.
// The real part is written first, so in-place use behaves as before.
template <typename Float>
inline void complexDotProduct(Float *a, Float *b, Float *c) {
  Float &re = a[0];
  Float &im = a[1];
  re = b[0]*c[0] + b[1]*c[1];
  im = b[0]*c[1] - b[1]*c[0];
}
00087 
// a += sign*(b*c) for complex numbers stored as (re, im) pairs.
template <typename Float>
inline void accumulateComplexProduct(Float *a, Float *b, Float *c, Float sign) {
  Float &re = a[0];
  Float &im = a[1];
  re += sign*(b[0]*c[0] - b[1]*c[1]);
  im += sign*(b[0]*c[1] + b[1]*c[0]);
}
00094 
// a += conj(b)*c for complex numbers stored as (re, im) pairs.
template <typename Float>
inline void accumulateComplexDotProduct(Float *a, Float *b, Float *c) {
  Float &re = a[0];
  Float &im = a[1];
  re += b[0]*c[0] + b[1]*c[1];
  im += b[0]*c[1] - b[1]*c[0];
}
00101 
// a += sign*conj(b*c) (equivalently sign*conj(b)*conj(c)) for complex numbers
// stored as (re, im) pairs.
template <typename Float>
inline void accumulateConjugateProduct(Float *a, Float *b, Float *c, int sign) {
  Float &re = a[0];
  Float &im = a[1];
  re += sign * (b[0]*c[0] - b[1]*c[1]);
  im -= sign * (b[0]*c[1] + b[1]*c[0]);
}
00107 
// Compress a link for 12-real storage by zeroing its third row
// (reals 12..17); the first two rows are left untouched.
template <typename Float>
inline void su3Construct12(Float *mat) {
  Float *third_row = mat + 12;
  for (int k = 0; k < 6; k++) third_row[k] = 0.0;
}
00118 
// Stabilized Bunk and Sommer 8-parameter compression, in place:
//   mat[0] <- arg(U00) = atan2(mat[1], mat[0])
//   mat[1] <- arg(U20) = atan2(mat[13], mat[12])
//   mat[2..7] (U01, U02, U10) are kept, mat[8..17] are zeroed.
template <typename Float>
inline void su3Construct8(Float *mat) {
  const Float phase00 = atan2(mat[1], mat[0]);
  const Float phase20 = atan2(mat[13], mat[12]);
  mat[0] = phase00;
  mat[1] = phase20;
  for (int k = 8; k < 18; k++) mat[k] = 0.0;
}
00126 
00127 void su3_construct(void *mat, QudaReconstructType reconstruct, QudaPrecision precision) {
00128   if (reconstruct == QUDA_RECONSTRUCT_12) {
00129     if (precision == QUDA_DOUBLE_PRECISION) su3Construct12((double*)mat);
00130     else su3Construct12((float*)mat);
00131   } else {
00132     if (precision == QUDA_DOUBLE_PRECISION) su3Construct8((double*)mat);
00133     else su3Construct8((float*)mat);
00134   }
00135 }
00136 
00137 // given first two rows (u,v) of SU(3) matrix mat, reconstruct the third row
00138 // as the cross product of the conjugate vectors: w = u* x v*
00139 // 
00140 // 48 flops
00141 template <typename Float>
00142 static void su3Reconstruct12(Float *mat, int dir, int ga_idx, QudaGaugeParam *param) {
00143   Float *u = &mat[0*(3*2)];
00144   Float *v = &mat[1*(3*2)];
00145   Float *w = &mat[2*(3*2)];
00146   w[0] = 0.0; w[1] = 0.0; w[2] = 0.0; w[3] = 0.0; w[4] = 0.0; w[5] = 0.0;
00147   accumulateConjugateProduct(w+0*(2), u+1*(2), v+2*(2), +1);
00148   accumulateConjugateProduct(w+0*(2), u+2*(2), v+1*(2), -1);
00149   accumulateConjugateProduct(w+1*(2), u+2*(2), v+0*(2), +1);
00150   accumulateConjugateProduct(w+1*(2), u+0*(2), v+2*(2), -1);
00151   accumulateConjugateProduct(w+2*(2), u+0*(2), v+1*(2), +1);
00152   accumulateConjugateProduct(w+2*(2), u+1*(2), v+0*(2), -1);
00153   Float u0 = (dir < 3 ? param->anisotropy :
00154               (ga_idx >= (Z[3]-1)*Z[0]*Z[1]*Z[2]/2 ? param->t_boundary : 1));
00155   w[0]*=u0; w[1]*=u0; w[2]*=u0; w[3]*=u0; w[4]*=u0; w[5]*=u0;
00156 }
00157 
// Undo su3Construct8 in place (stabilized Bunk and Sommer 8-parameter form).
//
// Input layout (reals): mat[0] = arg(U00), mat[1] = arg(U20),
// mat[2..7] = U01, U02, U10; mat[8..17] are overwritten.
// Output: the full 3x3 complex matrix, row-major, element (r, c) at 6r + 2c.
//
// u0 is the link normalization factor: param->anisotropy for spatial links
// (dir < 3), param->t_boundary for temporal links on the last time slice of the
// half lattice, and 1 otherwise. Rows and columns are normalized to 1/u0^2.
template <typename Float>
static void su3Reconstruct8(Float *mat, int dir, int ga_idx, QudaGaugeParam *param) {
  // First reconstruct first row: |U00| from the norm of row 0
  Float row_sum = 0.0;
  row_sum += mat[2]*mat[2];
  row_sum += mat[3]*mat[3];
  row_sum += mat[4]*mat[4];
  row_sum += mat[5]*mat[5];
  Float u0 = (dir < 3 ? param->anisotropy :
              (ga_idx >= (Z[3]-1)*Z[0]*Z[1]*Z[2]/2 ? param->t_boundary : 1));
  Float U00_mag = sqrt(1.f/(u0*u0) - row_sum);

  // stash the two stored phases in U21's slot (overwritten further down)
  mat[14] = mat[0];
  mat[15] = mat[1];

  mat[0] = U00_mag * cos(mat[14]);
  mat[1] = U00_mag * sin(mat[14]);

  // |U20| from the norm of column 0 (|U00|^2 + |U10|^2 + |U20|^2 = 1/u0^2)
  Float column_sum = 0.0;
  for (int i=0; i<2; i++) column_sum += mat[i]*mat[i];
  for (int i=6; i<8; i++) column_sum += mat[i]*mat[i];
  Float U20_mag = sqrt(1.f/(u0*u0) - column_sum);

  mat[12] = U20_mag * cos(mat[15]);
  mat[13] = U20_mag * sin(mat[15]);

  // First column now restored

  // finally reconstruct last elements from SU(2) rotation
  // r_inv2 = 1/(u0 * (|U01|^2 + |U02|^2))
  Float r_inv2 = 1.0/(u0*row_sum);

  // U11 = -r_inv2 * (conj(U20)*conj(U02) + u0*conj(U00)*U10*U01)
  Float A[2];
  complexDotProduct(A, mat+0, mat+6);
  complexConjugateProduct(mat+8, mat+12, mat+4);
  accumulateComplexProduct(mat+8, A, mat+2, u0);
  mat[8] *= -r_inv2;
  mat[9] *= -r_inv2;

  // U12 = r_inv2 * (conj(U20)*conj(U01) - u0*conj(U00)*U10*U02)
  complexConjugateProduct(mat+10, mat+12, mat+2);
  accumulateComplexProduct(mat+10, A, mat+4, -u0);
  mat[10] *= r_inv2;
  mat[11] *= r_inv2;

  // U21 = r_inv2 * (conj(U10)*conj(U02) - u0*conj(U00)*U20*U01)
  complexDotProduct(A, mat+0, mat+12);
  complexConjugateProduct(mat+14, mat+6, mat+4);
  accumulateComplexProduct(mat+14, A, mat+2, -u0);
  mat[14] *= r_inv2;
  mat[15] *= r_inv2;

  // U22 = -r_inv2 * (conj(U10)*conj(U01) + u0*conj(U00)*U20*U02)
  complexConjugateProduct(mat+16, mat+6, mat+2);
  accumulateComplexProduct(mat+16, A, mat+4, u0);
  mat[16] *= -r_inv2;
  mat[17] *= -r_inv2;
}
00216 
00217 void su3_reconstruct(void *mat, int dir, int ga_idx, QudaReconstructType reconstruct, QudaPrecision precision, QudaGaugeParam *param) {
00218   if (reconstruct == QUDA_RECONSTRUCT_12) {
00219     if (precision == QUDA_DOUBLE_PRECISION) su3Reconstruct12((double*)mat, dir, ga_idx, param);
00220     else su3Reconstruct12((float*)mat, dir, ga_idx, param);
00221   } else {
00222     if (precision == QUDA_DOUBLE_PRECISION) su3Reconstruct8((double*)mat, dir, ga_idx, param);
00223     else su3Reconstruct8((float*)mat, dir, ga_idx, param);
00224   }
00225 }
00226 
00227 /*
00228   void su3_construct_8_half(float *mat, short *mat_half) {
00229   su3Construct8(mat);
00230 
00231   mat_half[0] = floatToShort(mat[0] / M_PI);
00232   mat_half[1] = floatToShort(mat[1] / M_PI);
00233   for (int i=2; i<18; i++) {
00234   mat_half[i] = floatToShort(mat[i]);
00235   }
00236   }
00237 
00238   void su3_reconstruct_8_half(float *mat, short *mat_half, int dir, int ga_idx, QudaGaugeParam *param) {
00239 
00240   for (int i=0; i<18; i++) {
00241   mat[i] = shortToFloat(mat_half[i]);
00242   }
00243   mat[0] *= M_PI;
00244   mat[1] *= M_PI;
00245 
00246   su3Reconstruct8(mat, dir, ga_idx, param);
00247   }*/
00248 
// Compare the first len reals of a and b.
// Returns 1 if every |a[i] - b[i]| <= epsilon. On the first element that fails
// the test, prints it through printfQuda and returns 0.
//
// The test is written as !(diff <= epsilon) rather than diff > epsilon. A NaN
// in either array makes diff NaN, and every comparison with NaN is false, so
// the old form let NaNs pass as matches. Now they are reported as mismatches.
template <typename Float>
static int compareFloats(Float *a, Float *b, int len, double epsilon) {
  for (int i = 0; i < len; i++) {
    double diff = fabs(a[i] - b[i]);
    if (!(diff <= epsilon)) {
      printfQuda("error: i=%d, a[%d]=%f, b[%d]=%f\n", i, i, a[i], i, b[i]);
      return 0;
    }
  }
  return 1;
}
00260 
00261 int compare_floats(void *a, void *b, int len, double epsilon, QudaPrecision precision) {
00262   if  (precision == QUDA_DOUBLE_PRECISION) return compareFloats((double*)a, (double*)b, len, epsilon);
00263   else return compareFloats((float*)a, (float*)b, len, epsilon);
00264 }
00265 
00266 
00267 
00268 // given a "half index" i into either an even or odd half lattice (corresponding
00269 // to oddBit = {0, 1}), returns the corresponding full lattice index.
00270 int fullLatticeIndex(int i, int oddBit) {
00271   /*
00272     int boundaryCrossings = i/(Z[0]/2) + i/(Z[1]*Z[0]/2) + i/(Z[2]*Z[1]*Z[0]/2);
00273     return 2*i + (boundaryCrossings + oddBit) % 2;
00274   */
00275 
00276   int X1 = Z[0];  
00277   int X2 = Z[1];
00278   int X3 = Z[2];
00279   //int X4 = Z[3];
00280   int X1h =X1/2;
00281 
00282   int sid =i;
00283   int za = sid/X1h;
00284   //int x1h = sid - za*X1h;
00285   int zb = za/X2;
00286   int x2 = za - zb*X2;
00287   int x4 = zb/X3;
00288   int x3 = zb - x4*X3;
00289   int x1odd = (x2 + x3 + x4 + oddBit) & 1;
00290   //int x1 = 2*x1h + x1odd;
00291   int X = 2*sid + x1odd; 
00292 
00293   return X;
00294 }
00295 
00296 
00297 // i represents a "half index" into an even or odd "half lattice".
00298 // when oddBit={0,1} the half lattice is {even,odd}.
00299 // 
00300 // the displacements, such as dx, refer to the full lattice coordinates. 
00301 //
00302 // neighborIndex() takes a "half index", displaces it, and returns the
00303 // new "half index", which can be an index into either the even or odd lattices.
00304 // displacements of magnitude one always interchange odd and even lattices.
00305 //
00306 
00307 int neighborIndex(int i, int oddBit, int dx4, int dx3, int dx2, int dx1) {
00308   int Y = fullLatticeIndex(i, oddBit);
00309   int x4 = Y/(Z[2]*Z[1]*Z[0]);
00310   int x3 = (Y/(Z[1]*Z[0])) % Z[2];
00311   int x2 = (Y/Z[0]) % Z[1];
00312   int x1 = Y % Z[0];
00313   
00314   // assert (oddBit == (x+y+z+t)%2);
00315   
00316   x4 = (x4+dx4+Z[3]) % Z[3];
00317   x3 = (x3+dx3+Z[2]) % Z[2];
00318   x2 = (x2+dx2+Z[1]) % Z[1];
00319   x1 = (x1+dx1+Z[0]) % Z[0];
00320   
00321   return (x4*(Z[2]*Z[1]*Z[0]) + x3*(Z[1]*Z[0]) + x2*(Z[0]) + x1) / 2;
00322 }
00323 
00324 int
00325 neighborIndex_mg(int i, int oddBit, int dx4, int dx3, int dx2, int dx1)
00326 {
00327   int ret;
00328   
00329   int Y = fullLatticeIndex(i, oddBit);
00330   int x4 = Y/(Z[2]*Z[1]*Z[0]);
00331   int x3 = (Y/(Z[1]*Z[0])) % Z[2];
00332   int x2 = (Y/Z[0]) % Z[1];
00333   int x1 = Y % Z[0];
00334   
00335   int ghost_x4 = x4+ dx4;
00336   
00337   // assert (oddBit == (x+y+z+t)%2);
00338   
00339   x4 = (x4+dx4+Z[3]) % Z[3];
00340   x3 = (x3+dx3+Z[2]) % Z[2];
00341   x2 = (x2+dx2+Z[1]) % Z[1];
00342   x1 = (x1+dx1+Z[0]) % Z[0];
00343   
00344   if ( ghost_x4 >= 0 && ghost_x4 < Z[3]){
00345     ret = (x4*(Z[2]*Z[1]*Z[0]) + x3*(Z[1]*Z[0]) + x2*(Z[0]) + x1) / 2;
00346   }else{
00347     ret = (x3*(Z[1]*Z[0]) + x2*(Z[0]) + x1) / 2;    
00348   }
00349 
00350   
00351   return ret;
00352 }
00353 
00354 
00355 /*  
00356  * This is a computation of neighbor using the full index and the displacement in each direction
00357  *
00358  */
00359 
00360 int
00361 neighborIndexFullLattice(int i, int dx4, int dx3, int dx2, int dx1) 
00362 {
00363   int oddBit = 0;
00364   int half_idx = i;
00365   if (i >= Vh){
00366     oddBit =1;
00367     half_idx = i - Vh;
00368   }
00369     
00370   int nbr_half_idx = neighborIndex(half_idx, oddBit, dx4,dx3,dx2,dx1);
00371   int oddBitChanged = (dx4+dx3+dx2+dx1)%2;
00372   if (oddBitChanged){
00373     oddBit = 1 - oddBit;
00374   }
00375   int ret = nbr_half_idx;
00376   if (oddBit){
00377     ret = Vh + nbr_half_idx;
00378   }
00379     
00380   return ret;
00381 }
00382 
00383 
00384 int
00385 neighborIndexFullLattice_mg(int i, int dx4, int dx3, int dx2, int dx1) 
00386 {
00387   int ret;
00388   int oddBit = 0;
00389   int half_idx = i;
00390   if (i >= Vh){
00391     oddBit =1;
00392     half_idx = i - Vh;
00393   }
00394     
00395   int Y = fullLatticeIndex(half_idx, oddBit);
00396   int x4 = Y/(Z[2]*Z[1]*Z[0]);
00397   int x3 = (Y/(Z[1]*Z[0])) % Z[2];
00398   int x2 = (Y/Z[0]) % Z[1];
00399   int x1 = Y % Z[0];
00400   int ghost_x4 = x4+ dx4;
00401     
00402   x4 = (x4+dx4+Z[3]) % Z[3];
00403   x3 = (x3+dx3+Z[2]) % Z[2];
00404   x2 = (x2+dx2+Z[1]) % Z[1];
00405   x1 = (x1+dx1+Z[0]) % Z[0];
00406 
00407   if ( ghost_x4 >= 0 && ghost_x4 < Z[3]){
00408     ret = (x4*(Z[2]*Z[1]*Z[0]) + x3*(Z[1]*Z[0]) + x2*(Z[0]) + x1) / 2;
00409   }else{
00410     ret = (x3*(Z[1]*Z[0]) + x2*(Z[0]) + x1) / 2;    
00411     return ret;
00412   }
00413 
00414   int oddBitChanged = (dx4+dx3+dx2+dx1)%2;
00415   if (oddBitChanged){
00416     oddBit = 1 - oddBit;
00417   }
00418     
00419   if (oddBit){
00420     ret += Vh;
00421   }
00422     
00423   return ret;
00424 }
00425 
00426 
00427 // 4d checkerboard.
00428 // given a "half index" i into either an even or odd half lattice (corresponding
00429 // to oddBit = {0, 1}), returns the corresponding full lattice index.
00430 // Cf. GPGPU code in dslash_core_ante.h.
00431 // There, i is the thread index.
00432 int fullLatticeIndex_4d(int i, int oddBit) {
00433   if (i >= Vh || i < 0) {printf("i out of range in fullLatticeIndex_4d"); exit(-1);}
00434   /*
00435     int boundaryCrossings = i/(Z[0]/2) + i/(Z[1]*Z[0]/2) + i/(Z[2]*Z[1]*Z[0]/2);
00436     return 2*i + (boundaryCrossings + oddBit) % 2;
00437   */
00438 
00439   int X1 = Z[0];  
00440   int X2 = Z[1];
00441   int X3 = Z[2];
00442   //int X4 = Z[3];
00443   int X1h =X1/2;
00444 
00445   int sid =i;
00446   int za = sid/X1h;
00447   //int x1h = sid - za*X1h;
00448   int zb = za/X2;
00449   int x2 = za - zb*X2;
00450   int x4 = zb/X3;
00451   int x3 = zb - x4*X3;
00452   int x1odd = (x2 + x3 + x4 + oddBit) & 1;
00453   //int x1 = 2*x1h + x1odd;
00454   int X = 2*sid + x1odd; 
00455 
00456   return X;
00457 }
00458 
00459 // 5d checkerboard.
00460 // given a "half index" i into either an even or odd half lattice (corresponding
00461 // to oddBit = {0, 1}), returns the corresponding full lattice index.
00462 // Cf. GPGPU code in dslash_core_ante.h.
00463 // There, i is the thread index sid.
00464 // This function is used by neighborIndex_5d in dslash_reference.cpp.
00465 //ok
00466 int fullLatticeIndex_5d(int i, int oddBit) {
00467   int boundaryCrossings = i/(Z[0]/2) + i/(Z[1]*Z[0]/2) + i/(Z[2]*Z[1]*Z[0]/2) + i/(Z[3]*Z[2]*Z[1]*Z[0]/2);
00468   return 2*i + (boundaryCrossings + oddBit) % 2;
00469 }
00470 
00471 int 
00472 x4_from_full_index(int i)
00473 {
00474   int oddBit = 0;
00475   int half_idx = i;
00476   if (i >= Vh){
00477     oddBit =1;
00478     half_idx = i - Vh;
00479   }
00480   
00481   int Y = fullLatticeIndex(half_idx, oddBit);
00482   int x4 = Y/(Z[2]*Z[1]*Z[0]);
00483   
00484   return x4;
00485 }
00486 
00487 template <typename Float>
00488 static void applyGaugeFieldScaling(Float **gauge, int Vh, QudaGaugeParam *param) {
00489   // Apply spatial scaling factor (u0) to spatial links
00490   for (int d = 0; d < 3; d++) {
00491     for (int i = 0; i < gaugeSiteSize*Vh*2; i++) {
00492       gauge[d][i] /= param->anisotropy;
00493     }
00494   }
00495     
00496   // only apply T-boundary at edge nodes
00497   bool Ntm1 = (commCoords(3) == commDim(3)-1) ? true : false;
00498 
00499   // Apply boundary conditions to temporal links
00500   if (param->t_boundary == QUDA_ANTI_PERIODIC_T && Ntm1) {
00501     for (int j = (Z[0]/2)*Z[1]*Z[2]*(Z[3]-1); j < Vh; j++) {
00502       for (int i = 0; i < gaugeSiteSize; i++) {
00503         gauge[3][j*gaugeSiteSize+i] *= -1.0;
00504         gauge[3][(Vh+j)*gaugeSiteSize+i] *= -1.0;
00505       }
00506     }
00507   }
00508     
00509   if (param->gauge_fix) {
00510     // set all gauge links (except for the last Z[0]*Z[1]*Z[2]/2) to the identity,
00511     // to simulate fixing to the temporal gauge.
00512     int iMax = ( Ntm1 ? (Z[0]/2)*Z[1]*Z[2]*(Z[3]-1) : Vh );
00513     int dir = 3; // time direction only
00514     Float *even = gauge[dir];
00515     Float *odd  = gauge[dir]+Vh*gaugeSiteSize;
00516     for (int i = 0; i< iMax; i++) {
00517       for (int m = 0; m < 3; m++) {
00518         for (int n = 0; n < 3; n++) {
00519           even[i*(3*3*2) + m*(3*2) + n*(2) + 0] = (m==n) ? 1 : 0;
00520           even[i*(3*3*2) + m*(3*2) + n*(2) + 1] = 0.0;
00521           odd [i*(3*3*2) + m*(3*2) + n*(2) + 0] = (m==n) ? 1 : 0;
00522           odd [i*(3*3*2) + m*(3*2) + n*(2) + 1] = 0.0;
00523         }
00524       }
00525     }
00526   }
00527 }
00528 
00529 template <typename Float>
00530 void applyGaugeFieldScaling_long(Float **gauge, int Vh, QudaGaugeParam *param)
00531 {
00532 
00533   int X1h=param->X[0]/2;
00534   int X1 =param->X[0];
00535   int X2 =param->X[1];
00536   int X3 =param->X[2];
00537   int X4 =param->X[3];
00538 
00539   // rescale long links by the appropriate coefficient
00540   for(int d=0; d<4; d++){
00541     for(int i=0; i < V*gaugeSiteSize; i++){
00542       gauge[d][i] /= (-24*param->tadpole_coeff*param->tadpole_coeff);
00543     }
00544   }
00545 
00546   // apply the staggered phases
00547   for (int d = 0; d < 3; d++) {
00548 
00549     //even
00550     for (int i = 0; i < Vh; i++) {
00551 
00552       int index = fullLatticeIndex(i, 0);
00553       int i4 = index /(X3*X2*X1);
00554       int i3 = (index - i4*(X3*X2*X1))/(X2*X1);
00555       int i2 = (index - i4*(X3*X2*X1) - i3*(X2*X1))/X1;
00556       int i1 = index - i4*(X3*X2*X1) - i3*(X2*X1) - i2*X1;
00557       int sign=1;
00558 
00559       if (d == 0) {
00560         if (i4 % 2 == 1){
00561           sign= -1;
00562         }
00563       }
00564 
00565       if (d == 1){
00566         if ((i4+i1) % 2 == 1){
00567           sign= -1;
00568         }
00569       }
00570       if (d == 2){
00571         if ( (i4+i1+i2) % 2 == 1){
00572           sign= -1;
00573         }
00574       }
00575 
00576       for (int j=0;j < 6; j++){
00577         gauge[d][i*gaugeSiteSize + 12+ j] *= sign;
00578       }
00579     }
00580     //odd
00581     for (int i = 0; i < Vh; i++) {
00582       int index = fullLatticeIndex(i, 1);
00583       int i4 = index /(X3*X2*X1);
00584       int i3 = (index - i4*(X3*X2*X1))/(X2*X1);
00585       int i2 = (index - i4*(X3*X2*X1) - i3*(X2*X1))/X1;
00586       int i1 = index - i4*(X3*X2*X1) - i3*(X2*X1) - i2*X1;
00587       int sign=1;
00588 
00589       if (d == 0) {
00590         if (i4 % 2 == 1){
00591           sign= -1;
00592         }
00593       }
00594 
00595       if (d == 1){
00596         if ((i4+i1) % 2 == 1){
00597           sign= -1;
00598         }
00599       }
00600       if (d == 2){
00601         if ( (i4+i1+i2) % 2 == 1){
00602           sign = -1;
00603         }
00604       }
00605 
00606       for (int j=0;j < 6; j++){
00607         gauge[d][(Vh+i)*gaugeSiteSize + 12 + j] *= sign;
00608       }
00609     }
00610 
00611   }
00612 
00613   // Apply boundary conditions to temporal links
00614   if (param->t_boundary == QUDA_ANTI_PERIODIC_T) {
00615     for (int j = 0; j < Vh; j++) {
00616       int sign =1;
00617       if (j >= (X4-3)*X1h*X2*X3 ){
00618         sign= -1;
00619       }
00620 
00621       for (int i = 0; i < 6; i++) {
00622         gauge[3][j*gaugeSiteSize+ 12+ i ] *= sign;
00623         gauge[3][(Vh+j)*gaugeSiteSize+12 +i] *= sign;
00624       }
00625     }
00626   }
00627 }
00628 
00629 
00630 
00631 template <typename Float>
00632 static void constructUnitGaugeField(Float **res, QudaGaugeParam *param) {
00633   Float *resOdd[4], *resEven[4];
00634   for (int dir = 0; dir < 4; dir++) {  
00635     resEven[dir] = res[dir];
00636     resOdd[dir]  = res[dir]+Vh*gaugeSiteSize;
00637   }
00638     
00639   for (int dir = 0; dir < 4; dir++) {
00640     for (int i = 0; i < Vh; i++) {
00641       for (int m = 0; m < 3; m++) {
00642         for (int n = 0; n < 3; n++) {
00643           resEven[dir][i*(3*3*2) + m*(3*2) + n*(2) + 0] = (m==n) ? 1 : 0;
00644           resEven[dir][i*(3*3*2) + m*(3*2) + n*(2) + 1] = 0.0;
00645           resOdd[dir][i*(3*3*2) + m*(3*2) + n*(2) + 0] = (m==n) ? 1 : 0;
00646           resOdd[dir][i*(3*3*2) + m*(3*2) + n*(2) + 1] = 0.0;
00647         }
00648       }
00649     }
00650   }
00651     
00652   applyGaugeFieldScaling(res, Vh, param);
00653 }
00654 
// Scale the complex vector a (len entries) to unit Euclidean norm. The squared
// norm is accumulated in double precision.
template <typename Float>
static void normalize(complex<Float> *a, int len) {
  double sum = 0.0;
  for (int i = 0; i < len; i++) sum += norm(a[i]);
  const double length = sqrt(sum);
  for (int i = 0; i < len; i++) a[i] /= length;
}
00662 
// Orthogonalize b against a: b -= <a, b> a, where <a, b> = sum conj(a_i) b_i
// is accumulated in double precision. a is assumed to be normalized already
// (TODO confirm at call sites; the visible callers normalize a first).
template <typename Float>
static void orthogonalize(complex<Float> *a, complex<Float> *b, int len) {
  complex<double> overlap = 0.0;
  for (int i = 0; i < len; i++) overlap += conj(a[i])*b[i];
  const complex<Float> proj = (complex<Float>)overlap;
  for (int i = 0; i < len; i++) b[i] -= proj*a[i];
}
00670 
// Fill res with a random gauge field and post-process it according to
// param->type.
//
// res: four arrays (one per direction), each Vh even-site links followed by
// Vh odd-site links; 18 reals (3x3 complex, row-major) per link.
//
// For every link, rows 1 and 2 are filled with rand()/RAND_MAX values. Row 1 is
// normalized, row 2 is orthogonalized against row 1 and normalized, and row 0
// is set to conj(row1 x row2), which completes the two orthonormal rows to a
// unitary (SU(3)) matrix. Then:
//   - QUDA_WILSON_LINKS:      applyGaugeFieldScaling
//   - QUDA_ASQTAD_LONG_LINKS: applyGaugeFieldScaling_long
//   - QUDA_ASQTAD_FAT_LINKS:  all 18 reals of every link are overwritten with
//                             fresh random values (even re in [0,1], even im in
//                             [0,2], odd re in [0,3], odd im in [0,4]). The
//                             unitary matrices built above are discarded.
//   - any other type:         no post-processing
// The result depends on the global rand() sequence, so the order of the rand()
// calls below is significant for reproducibility.
template <typename Float>
static void constructGaugeField(Float **res, QudaGaugeParam *param) {
  Float *resOdd[4], *resEven[4];
  for (int dir = 0; dir < 4; dir++) {
    resEven[dir] = res[dir];
    resOdd[dir]  = res[dir]+Vh*gaugeSiteSize;
  }

  for (int dir = 0; dir < 4; dir++) {
    for (int i = 0; i < Vh; i++) {
      for (int m = 1; m < 3; m++) { // last 2 rows
        for (int n = 0; n < 3; n++) { // 3 columns
          resEven[dir][i*(3*3*2) + m*(3*2) + n*(2) + 0] = rand() / (Float)RAND_MAX;
          resEven[dir][i*(3*3*2) + m*(3*2) + n*(2) + 1] = rand() / (Float)RAND_MAX;
          resOdd[dir][i*(3*3*2) + m*(3*2) + n*(2) + 0] = rand() / (Float)RAND_MAX;
          resOdd[dir][i*(3*3*2) + m*(3*2) + n*(2) + 1] = rand() / (Float)RAND_MAX;
        }
      }
      // Gram-Schmidt on rows 1 and 2
      normalize((complex<Float>*)(resEven[dir] + (i*3+1)*3*2), 3);
      orthogonalize((complex<Float>*)(resEven[dir] + (i*3+1)*3*2), (complex<Float>*)(resEven[dir] + (i*3+2)*3*2), 3);
      normalize((complex<Float>*)(resEven[dir] + (i*3 + 2)*3*2), 3);

      normalize((complex<Float>*)(resOdd[dir] + (i*3+1)*3*2), 3);
      orthogonalize((complex<Float>*)(resOdd[dir] + (i*3+1)*3*2), (complex<Float>*)(resOdd[dir] + (i*3+2)*3*2), 3);
      normalize((complex<Float>*)(resOdd[dir] + (i*3 + 2)*3*2), 3);

      // row 0 = conj(row1 x row2), even site
      {
        Float *w = resEven[dir]+(i*3+0)*3*2;
        Float *u = resEven[dir]+(i*3+1)*3*2;
        Float *v = resEven[dir]+(i*3+2)*3*2;

        for (int n = 0; n < 6; n++) w[n] = 0.0;
        accumulateConjugateProduct(w+0*(2), u+1*(2), v+2*(2), +1);
        accumulateConjugateProduct(w+0*(2), u+2*(2), v+1*(2), -1);
        accumulateConjugateProduct(w+1*(2), u+2*(2), v+0*(2), +1);
        accumulateConjugateProduct(w+1*(2), u+0*(2), v+2*(2), -1);
        accumulateConjugateProduct(w+2*(2), u+0*(2), v+1*(2), +1);
        accumulateConjugateProduct(w+2*(2), u+1*(2), v+0*(2), -1);
      }

      // row 0 = conj(row1 x row2), odd site
      {
        Float *w = resOdd[dir]+(i*3+0)*3*2;
        Float *u = resOdd[dir]+(i*3+1)*3*2;
        Float *v = resOdd[dir]+(i*3+2)*3*2;

        for (int n = 0; n < 6; n++) w[n] = 0.0;
        accumulateConjugateProduct(w+0*(2), u+1*(2), v+2*(2), +1);
        accumulateConjugateProduct(w+0*(2), u+2*(2), v+1*(2), -1);
        accumulateConjugateProduct(w+1*(2), u+2*(2), v+0*(2), +1);
        accumulateConjugateProduct(w+1*(2), u+0*(2), v+2*(2), -1);
        accumulateConjugateProduct(w+2*(2), u+0*(2), v+1*(2), +1);
        accumulateConjugateProduct(w+2*(2), u+1*(2), v+0*(2), -1);
      }

    }
  }

  if (param->type == QUDA_WILSON_LINKS){
    applyGaugeFieldScaling(res, Vh, param);
  } else if (param->type == QUDA_ASQTAD_LONG_LINKS){
    applyGaugeFieldScaling_long(res, Vh, param);
  } else if (param->type == QUDA_ASQTAD_FAT_LINKS){
    // fat links: replace every link with non-unitary random values
    for (int dir = 0; dir < 4; dir++){
      for (int i = 0; i < Vh; i++) {
        for (int m = 0; m < 3; m++) { // all 3 rows
          for (int n = 0; n < 3; n++) { // 3 columns
            resEven[dir][i*(3*3*2) + m*(3*2) + n*(2) + 0] =1.0* rand() / (Float)RAND_MAX;
            resEven[dir][i*(3*3*2) + m*(3*2) + n*(2) + 1] =2.0* rand() / (Float)RAND_MAX;
            resOdd[dir][i*(3*3*2) + m*(3*2) + n*(2) + 0] = 3.0*rand() / (Float)RAND_MAX;
            resOdd[dir][i*(3*3*2) + m*(3*2) + n*(2) + 1] = 4.0*rand() / (Float)RAND_MAX;
          }
        }
      }
    }

  }

}
00749 
// Fill res with random unitary (SU(3)) link matrices. No anisotropy, boundary
// or staggered-phase scaling is applied.
//
// res: four arrays (one per direction), each Vh even-site links followed by
// Vh odd-site links; 18 reals (3x3 complex, row-major) per link.
// Rows 1 and 2 are filled with rand()/RAND_MAX values and Gram-Schmidt
// orthonormalized. Row 0 is set to conj(row1 x row2). The result depends on the
// global rand() sequence, so the order of the rand() calls is significant.
template <typename Float>
void constructUnitaryGaugeField(Float **res)
{
  Float *resOdd[4], *resEven[4];
  for (int dir = 0; dir < 4; dir++) {
    resEven[dir] = res[dir];
    resOdd[dir]  = res[dir]+Vh*gaugeSiteSize;
  }

  for (int dir = 0; dir < 4; dir++) {
    for (int i = 0; i < Vh; i++) {
      for (int m = 1; m < 3; m++) { // last 2 rows
        for (int n = 0; n < 3; n++) { // 3 columns
          resEven[dir][i*(3*3*2) + m*(3*2) + n*(2) + 0] = rand() / (Float)RAND_MAX;
          resEven[dir][i*(3*3*2) + m*(3*2) + n*(2) + 1] = rand() / (Float)RAND_MAX;
          resOdd[dir][i*(3*3*2) + m*(3*2) + n*(2) + 0] = rand() / (Float)RAND_MAX;
          resOdd[dir][i*(3*3*2) + m*(3*2) + n*(2) + 1] = rand() / (Float)RAND_MAX;
        }
      }
      // Gram-Schmidt on rows 1 and 2
      normalize((complex<Float>*)(resEven[dir] + (i*3+1)*3*2), 3);
      orthogonalize((complex<Float>*)(resEven[dir] + (i*3+1)*3*2), (complex<Float>*)(resEven[dir] + (i*3+2)*3*2), 3);
      normalize((complex<Float>*)(resEven[dir] + (i*3 + 2)*3*2), 3);

      normalize((complex<Float>*)(resOdd[dir] + (i*3+1)*3*2), 3);
      orthogonalize((complex<Float>*)(resOdd[dir] + (i*3+1)*3*2), (complex<Float>*)(resOdd[dir] + (i*3+2)*3*2), 3);
      normalize((complex<Float>*)(resOdd[dir] + (i*3 + 2)*3*2), 3);

      // row 0 = conj(row1 x row2), even site
      {
        Float *w = resEven[dir]+(i*3+0)*3*2;
        Float *u = resEven[dir]+(i*3+1)*3*2;
        Float *v = resEven[dir]+(i*3+2)*3*2;

        for (int n = 0; n < 6; n++) w[n] = 0.0;
        accumulateConjugateProduct(w+0*(2), u+1*(2), v+2*(2), +1);
        accumulateConjugateProduct(w+0*(2), u+2*(2), v+1*(2), -1);
        accumulateConjugateProduct(w+1*(2), u+2*(2), v+0*(2), +1);
        accumulateConjugateProduct(w+1*(2), u+0*(2), v+2*(2), -1);
        accumulateConjugateProduct(w+2*(2), u+0*(2), v+1*(2), +1);
        accumulateConjugateProduct(w+2*(2), u+1*(2), v+0*(2), -1);
      }

      // row 0 = conj(row1 x row2), odd site
      {
        Float *w = resOdd[dir]+(i*3+0)*3*2;
        Float *u = resOdd[dir]+(i*3+1)*3*2;
        Float *v = resOdd[dir]+(i*3+2)*3*2;

        for (int n = 0; n < 6; n++) w[n] = 0.0;
        accumulateConjugateProduct(w+0*(2), u+1*(2), v+2*(2), +1);
        accumulateConjugateProduct(w+0*(2), u+2*(2), v+1*(2), -1);
        accumulateConjugateProduct(w+1*(2), u+2*(2), v+0*(2), +1);
        accumulateConjugateProduct(w+1*(2), u+0*(2), v+2*(2), -1);
        accumulateConjugateProduct(w+2*(2), u+0*(2), v+1*(2), +1);
        accumulateConjugateProduct(w+2*(2), u+1*(2), v+0*(2), -1);
      }

    }
  }
}
00808 
00809 
00810 void construct_gauge_field(void **gauge, int type, QudaPrecision precision, QudaGaugeParam *param) {
00811   if (type == 0) {
00812     if (precision == QUDA_DOUBLE_PRECISION) constructUnitGaugeField((double**)gauge, param);
00813     else constructUnitGaugeField((float**)gauge, param);
00814   } else if (type == 1) {
00815     if (precision == QUDA_DOUBLE_PRECISION) constructGaugeField((double**)gauge, param);
00816     else constructGaugeField((float**)gauge, param);
00817   } else {
00818     if (precision == QUDA_DOUBLE_PRECISION) applyGaugeFieldScaling((double**)gauge, Vh, param);
00819     else applyGaugeFieldScaling((float**)gauge, Vh, param);    
00820   }
00821 
00822 }
00823 
00824 void
00825 construct_fat_long_gauge_field(void **fatlink, void** longlink,  
00826                                int type, QudaPrecision precision, QudaGaugeParam* param)
00827 {
00828   if (type == 0) {
00829     if (precision == QUDA_DOUBLE_PRECISION) {
00830       constructUnitGaugeField((double**)fatlink, param);
00831       constructUnitGaugeField((double**)longlink, param);
00832     }else {
00833       constructUnitGaugeField((float**)fatlink, param);
00834       constructUnitGaugeField((float**)longlink, param);
00835     }
00836   } else {
00837     if (precision == QUDA_DOUBLE_PRECISION) {
00838       param->type = QUDA_ASQTAD_FAT_LINKS;
00839       constructGaugeField((double**)fatlink, param);
00840       param->type = QUDA_ASQTAD_LONG_LINKS;
00841       constructGaugeField((double**)longlink, param);
00842     }else {
00843       param->type = QUDA_ASQTAD_FAT_LINKS;
00844       constructGaugeField((float**)fatlink, param);
00845       param->type = QUDA_ASQTAD_LONG_LINKS;
00846       constructGaugeField((float**)longlink, param);
00847     }
00848   }
00849 }
00850 
00851 
00852 template <typename Float>
00853 static void constructCloverField(Float *res, double norm, double diag) {
00854 
00855   Float c = 2.0 * norm / RAND_MAX;
00856 
00857   for(int i = 0; i < V; i++) {
00858     for (int j = 0; j < 72; j++) {
00859       res[i*72 + j] = c*rand() - norm;
00860     }
00861     for (int j = 0; j< 6; j++) {
00862       res[i*72 + j] += diag;
00863       res[i*72 + j+36] += diag;
00864     }
00865   }
00866 }
00867 
00868 void construct_clover_field(void *clover, double norm, double diag, QudaPrecision precision) {
00869 
00870   if (precision == QUDA_DOUBLE_PRECISION) constructCloverField((double *)clover, norm, diag);
00871   else constructCloverField((float *)clover, norm, diag);
00872 }
00873 
00874 /*void strong_check(void *spinorRef, void *spinorGPU, int len, QudaPrecision prec) {
00875   printf("Reference:\n");
00876   printSpinorElement(spinorRef, 0, prec); printf("...\n");
00877   printSpinorElement(spinorRef, len-1, prec); printf("\n");    
00878     
00879   printf("\nCUDA:\n");
00880   printSpinorElement(spinorGPU, 0, prec); printf("...\n");
00881   printSpinorElement(spinorGPU, len-1, prec); printf("\n");
00882 
00883   compare_spinor(spinorRef, spinorGPU, len, prec);
00884   }*/
00885 
00886 template <typename Float>
00887 static void checkGauge(Float **oldG, Float **newG, double epsilon) {
00888 
00889   const int fail_check = 17;
00890   int fail[4][fail_check];
00891   int iter[4][18];
00892   for (int d=0; d<4; d++) for (int i=0; i<fail_check; i++) fail[d][i] = 0;
00893   for (int d=0; d<4; d++) for (int i=0; i<18; i++) iter[d][i] = 0;
00894 
00895   for (int d=0; d<4; d++) {
00896     for (int eo=0; eo<2; eo++) {
00897       for (int i=0; i<Vh; i++) {
00898         int ga_idx = (eo*Vh+i);
00899         for (int j=0; j<18; j++) {
00900           double diff = fabs(newG[d][ga_idx*18+j] - oldG[d][ga_idx*18+j]);
00901 
00902           for (int f=0; f<fail_check; f++) if (diff > pow(10.0,-(f+1))) fail[d][f]++;
00903           if (diff > epsilon) iter[d][j]++;
00904         }
00905       }
00906     }
00907   }
00908 
00909   printf("Component fails (X, Y, Z, T)\n");
00910   for (int i=0; i<18; i++) printf("%d fails = (%8d, %8d, %8d, %8d)\n", i, iter[0][i], iter[1][i], iter[2][i], iter[3][i]);
00911 
00912   printf("\nDeviation Failures = (X, Y, Z, T)\n");
00913   for (int f=0; f<fail_check; f++) {
00914     printf("%e Failures = (%9d, %9d, %9d, %9d) = (%e, %e, %e, %e)\n", pow(10.0,-(f+1)), 
00915            fail[0][f], fail[1][f], fail[2][f], fail[3][f],
00916            fail[0][f]/(double)(V*18), fail[1][f]/(double)(V*18), fail[2][f]/(double)(V*18), fail[3][f]/(double)(V*18));
00917   }
00918 
00919 }
00920 
00921 void check_gauge(void **oldG, void **newG, double epsilon, QudaPrecision precision) {
00922   if (precision == QUDA_DOUBLE_PRECISION) 
00923     checkGauge((double**)oldG, (double**)newG, epsilon);
00924   else 
00925     checkGauge((float**)oldG, (float**)newG, epsilon);
00926 }
00927 
00928 
00929 
00930 void 
00931 createSiteLinkCPU(void** link,  QudaPrecision precision, int phase) 
00932 {
00933     
00934   if (precision == QUDA_DOUBLE_PRECISION) {
00935     constructUnitaryGaugeField((double**)link);
00936   }else {
00937     constructUnitaryGaugeField((float**)link);
00938   }
00939 
00940   if(phase){
00941         
00942     for(int i=0;i < V;i++){
00943       for(int dir =XUP; dir <= TUP; dir++){
00944         int idx = i;
00945         int oddBit =0;
00946         if (i >= Vh) {
00947           idx = i - Vh;
00948           oddBit = 1;
00949         }
00950 
00951         int X1 = Z[0];
00952         int X2 = Z[1];
00953         int X3 = Z[2];
00954         int X4 = Z[3];
00955 
00956         int full_idx = fullLatticeIndex(idx, oddBit);
00957         int i4 = full_idx /(X3*X2*X1);
00958         int i3 = (full_idx - i4*(X3*X2*X1))/(X2*X1);
00959         int i2 = (full_idx - i4*(X3*X2*X1) - i3*(X2*X1))/X1;
00960         int i1 = full_idx - i4*(X3*X2*X1) - i3*(X2*X1) - i2*X1;     
00961 
00962         double coeff= 1.0;
00963         switch(dir){
00964         case XUP:
00965           if ( (i4 & 1) != 0){
00966             coeff *= -1;
00967           }
00968           break;
00969 
00970         case YUP:
00971           if ( ((i4+i1) & 1) != 0){
00972             coeff *= -1;
00973           }
00974           break;
00975 
00976         case ZUP:
00977           if ( ((i4+i1+i2) & 1) != 0){
00978             coeff *= -1;
00979           }
00980           break;
00981                 
00982         case TUP:
00983           if ((commCoords(3) == commDim(3) -1) && i4 == (X4-1) ){
00984             coeff *= -1;
00985           }
00986           break;
00987 
00988         default:
00989           printf("ERROR: wrong dir(%d)\n", dir);
00990           exit(1);
00991         }
00992             
00993         if (precision == QUDA_DOUBLE_PRECISION){
00994           //double* mylink = (double*)link;
00995           //mylink = mylink + (4*i + dir)*gaugeSiteSize;
00996           double* mylink = (double*)link[dir];
00997           mylink = mylink + i*gaugeSiteSize;
00998 
00999           mylink[12] *= coeff;
01000           mylink[13] *= coeff;
01001           mylink[14] *= coeff;
01002           mylink[15] *= coeff;
01003           mylink[16] *= coeff;
01004           mylink[17] *= coeff;
01005                 
01006         }else{
01007           //float* mylink = (float*)link;
01008           //mylink = mylink + (4*i + dir)*gaugeSiteSize;
01009           float* mylink = (float*)link[dir];
01010           mylink = mylink + i*gaugeSiteSize;
01011                   
01012           mylink[12] *= coeff;
01013           mylink[13] *= coeff;
01014           mylink[14] *= coeff;
01015           mylink[15] *= coeff;
01016           mylink[16] *= coeff;
01017           mylink[17] *= coeff;
01018                 
01019         }
01020       }
01021     }
01022   }    
01023 
01024     
01025 #if 1
01026   for(int dir= 0;dir < 4;dir++){
01027     for(int i=0;i< V*gaugeSiteSize;i++){
01028       if (precision ==QUDA_SINGLE_PRECISION){
01029         float* f = (float*)link[dir];
01030         if (f[i] != f[i] || (fabsf(f[i]) > 1.e+3) ){
01031           fprintf(stderr, "ERROR:  %dth: bad number(%f) in function %s \n",i, f[i], __FUNCTION__);
01032           exit(1);
01033         }
01034       }else{
01035         double* f = (double*)link[dir];
01036         if (f[i] != f[i] || (fabs(f[i]) > 1.e+3)){
01037           fprintf(stderr, "ERROR:  %dth: bad number(%f) in function %s \n",i, f[i], __FUNCTION__);
01038           exit(1);
01039         }
01040           
01041       }
01042         
01043     }
01044   }
01045 #endif
01046 
01047   return;
01048 }
01049 
01050 
01051 
01052 
01053 template <typename Float>
01054 int compareLink(Float **linkA, Float **linkB, int len) {
01055   const int fail_check = 16;
01056   int fail[fail_check];
01057   for (int f=0; f<fail_check; f++) fail[f] = 0;
01058 
01059   int iter[18];
01060   for (int i=0; i<18; i++) iter[i] = 0;
01061   
01062   for(int dir=0;dir < 4; dir++){
01063     for (int i=0; i<len; i++) {
01064       for (int j=0; j<18; j++) {
01065         int is = i*18+j;
01066         double diff = fabs(linkA[dir][is]-linkB[dir][is]);
01067         for (int f=0; f<fail_check; f++)
01068           if (diff > pow(10.0,-(f+1))) fail[f]++;
01069         //if (diff > 1e-1) printf("%d %d %e\n", i, j, diff);
01070         if (diff > 1e-3) iter[j]++;
01071       }
01072     }
01073   }
01074   
01075   for (int i=0; i<18; i++) printfQuda("%d fails = %d\n", i, iter[i]);
01076   
01077   int accuracy_level = 0;
01078   for(int f =0; f < fail_check; f++){
01079     if(fail[f] == 0){
01080       accuracy_level =f;
01081     }
01082   }
01083 
01084   for (int f=0; f<fail_check; f++) {
01085     printfQuda("%e Failures: %d / %d  = %e\n", pow(10.0,-(f+1)), fail[f], len*gaugeSiteSize, fail[f] / (double)(len*6));
01086   }
01087   
01088   return accuracy_level;
01089 }
01090 
01091 static int
01092 compare_link(void **linkA, void **linkB, int len, QudaPrecision precision)
01093 {
01094   int ret;
01095 
01096   if (precision == QUDA_DOUBLE_PRECISION){    
01097     ret = compareLink((double**)linkA, (double**)linkB, len);
01098   }else {
01099     ret = compareLink((float**)linkA, (float**)linkB, len);
01100   }    
01101 
01102   return ret;
01103 }
01104 
01105 
01106 // X indexes the lattice site
01107 static void 
01108 printLinkElement(void *link, int X, QudaPrecision precision) 
01109 {
01110   if (precision == QUDA_DOUBLE_PRECISION){
01111     for(int i=0; i < 3;i++){
01112       printVector((double*)link+ X*gaugeSiteSize + i*6);
01113     }
01114         
01115   }
01116   else{
01117     for(int i=0;i < 3;i++){
01118       printVector((float*)link+X*gaugeSiteSize + i*6);
01119     }
01120   }
01121 }
01122 
01123 int strong_check_link(void** linkA, const char* msgA, 
01124                       void **linkB, const char* msgB, 
01125                       int len, QudaPrecision prec) 
01126 {
01127   printfQuda("%s\n", msgA);
01128   printLinkElement(linkA[0], 0, prec); 
01129   printfQuda("\n");
01130   printLinkElement(linkA[0], 1, prec); 
01131   printfQuda("...\n");
01132   printLinkElement(linkA[3], len-1, prec); 
01133   printfQuda("\n");    
01134     
01135   printfQuda("\n%s\n", msgB);
01136   printLinkElement(linkB[0], 0, prec); 
01137   printfQuda("\n");
01138   printLinkElement(linkB[0], 1, prec); 
01139   printfQuda("...\n");
01140   printLinkElement(linkB[3], len-1, prec); 
01141   printfQuda("\n");
01142     
01143   int ret = compare_link(linkA, linkB, len, prec);
01144   return ret;
01145 }
01146 
01147 
01148 void 
01149 createMomCPU(void* mom,  QudaPrecision precision) 
01150 {
01151   void* temp;
01152     
01153   size_t gSize = (precision == QUDA_DOUBLE_PRECISION) ? sizeof(double) : sizeof(float);
01154   temp = malloc(4*V*gaugeSiteSize*gSize);
01155   if (temp == NULL){
01156     fprintf(stderr, "Error: malloc failed for temp in function %s\n", __FUNCTION__);
01157     exit(1);
01158   }
01159     
01160     
01161     
01162   for(int i=0;i < V;i++){
01163     if (precision == QUDA_DOUBLE_PRECISION){
01164       for(int dir=0;dir < 4;dir++){
01165         double* thismom = (double*)mom;     
01166         for(int k=0; k < momSiteSize; k++){
01167           thismom[ (4*i+dir)*momSiteSize + k ]= 1.0* rand() /RAND_MAX;                          
01168           if (k==momSiteSize-1) thismom[ (4*i+dir)*momSiteSize + k ]= 0.0;
01169         }           
01170       }     
01171     }else{
01172       for(int dir=0;dir < 4;dir++){
01173         float* thismom=(float*)mom;
01174         for(int k=0; k < momSiteSize; k++){
01175           thismom[ (4*i+dir)*momSiteSize + k ]= 1.0* rand() /RAND_MAX;          
01176           if (k==momSiteSize-1) thismom[ (4*i+dir)*momSiteSize + k ]= 0.0;
01177         }           
01178       }
01179     }
01180   }
01181     
01182   free(temp);
01183   return;
01184 }
01185 
01186 void
01187 createHwCPU(void* hw,  QudaPrecision precision)
01188 {
01189   for(int i=0;i < V;i++){
01190     if (precision == QUDA_DOUBLE_PRECISION){
01191       for(int dir=0;dir < 4;dir++){
01192         double* thishw = (double*)hw;
01193         for(int k=0; k < hwSiteSize; k++){
01194           thishw[ (4*i+dir)*hwSiteSize + k ]= 1.0* rand() /RAND_MAX;
01195         }
01196       }
01197     }else{
01198       for(int dir=0;dir < 4;dir++){
01199         float* thishw=(float*)hw;
01200         for(int k=0; k < hwSiteSize; k++){
01201           thishw[ (4*i+dir)*hwSiteSize + k ]= 1.0* rand() /RAND_MAX;
01202         }
01203       }
01204     }
01205   }
01206 
01207   return;
01208 }
01209 
01210 
01211 template <typename Float>
01212 int compare_mom(Float *momA, Float *momB, int len) {
01213   const int fail_check = 16;
01214   int fail[fail_check];
01215   for (int f=0; f<fail_check; f++) fail[f] = 0;
01216 
01217   int iter[momSiteSize];
01218   for (int i=0; i<momSiteSize; i++) iter[i] = 0;
01219   
01220   for (int i=0; i<len; i++) {
01221     for (int j=0; j<momSiteSize; j++) {
01222       int is = i*momSiteSize+j;
01223       double diff = fabs(momA[is]-momB[is]);
01224       for (int f=0; f<fail_check; f++)
01225         if (diff > pow(10.0,-(f+1))) fail[f]++;
01226       //if (diff > 1e-1) printf("%d %d %e\n", i, j, diff);
01227       if (diff > 1e-3) iter[j]++;
01228     }
01229   }
01230   
01231   int accuracy_level = 0;
01232   for(int f =0; f < fail_check; f++){
01233     if(fail[f] == 0){
01234       accuracy_level =f+1;
01235     }
01236   }
01237 
01238   for (int i=0; i<momSiteSize; i++) printf("%d fails = %d\n", i, iter[i]);
01239   
01240   for (int f=0; f<fail_check; f++) {
01241     printf("%e Failures: %d / %d  = %e\n", pow(10.0,-(f+1)), fail[f], len*momSiteSize, fail[f] / (double)(len*6));
01242   }
01243   
01244   return accuracy_level;
01245 }
01246 
01247 static void 
01248 printMomElement(void *mom, int X, QudaPrecision precision) 
01249 {
01250   if (precision == QUDA_DOUBLE_PRECISION){
01251     double* thismom = ((double*)mom)+ X*momSiteSize;
01252     printVector(thismom);
01253     printf("(%9f,%9f) (%9f,%9f)\n", thismom[6], thismom[7], thismom[8], thismom[9]);
01254   }else{
01255     float* thismom = ((float*)mom)+ X*momSiteSize;
01256     printVector(thismom);
01257     printf("(%9f,%9f) (%9f,%9f)\n", thismom[6], thismom[7], thismom[8], thismom[9]);    
01258   }
01259 }
01260 int strong_check_mom(void * momA, void *momB, int len, QudaPrecision prec) 
01261 {    
01262   printf("mom:\n");
01263   printMomElement(momA, 0, prec); 
01264   printf("\n");
01265   printMomElement(momA, 1, prec); 
01266   printf("\n");
01267   printMomElement(momA, 2, prec); 
01268   printf("\n");
01269   printMomElement(momA, 3, prec); 
01270   printf("...\n");
01271   
01272   printf("\nreference mom:\n");
01273   printMomElement(momB, 0, prec); 
01274   printf("\n");
01275   printMomElement(momB, 1, prec); 
01276   printf("\n");
01277   printMomElement(momB, 2, prec); 
01278   printf("\n");
01279   printMomElement(momB, 3, prec); 
01280   printf("\n");
01281   
01282   int ret;
01283   if (prec == QUDA_DOUBLE_PRECISION){
01284     ret = compare_mom((double*)momA, (double*)momB, len);
01285   }else{
01286     ret = compare_mom((float*)momA, (float*)momB, len);
01287   }
01288   
01289   return ret;
01290 }
01291 
01292 
/************
 * process_command_line_option() return value
 *
 * 0: command line option matched and processed successfully
 * non-zero: command line option does not match
 *
 */
01300 
01301 #ifdef MULTI_GPU
01302 int device = -1;
01303 #else
01304 int device = 0;
01305 #endif
01306 
01307 QudaReconstructType link_recon = QUDA_RECONSTRUCT_12;
01308 QudaReconstructType link_recon_sloppy = QUDA_RECONSTRUCT_INVALID;
01309 QudaPrecision prec = QUDA_SINGLE_PRECISION;
01310 QudaPrecision  prec_sloppy = QUDA_INVALID_PRECISION;
01311 int xdim = 24;
01312 int ydim = 24;
01313 int zdim = 24;
01314 int tdim = 24;
01315 QudaDagType dagger = QUDA_DAG_NO;
01316 extern bool kernelPackT;
01317 int gridsize_from_cmdline[4]={1,1,1,1};
01318 QudaDslashType dslash_type = QUDA_WILSON_DSLASH;
01319 char latfile[256] = "";
01320 bool tune = true;
01321 
// Default hook for printing program-specific usage lines.  Declared weak so a
// test program can supply its own strong definition; this one prints nothing.
void __attribute__((weak)) usage_extra(char** argv)
{
  (void)argv;
}
01323 
01324 void usage(char** argv )
01325 {
01326   printf("Usage: %s [options]\n", argv[0]);
01327   printf("Common options: \n");
01328 #ifndef MULTI_GPU
01329   printf("    --device <n>                              # Set the CUDA device to use (default 0, single GPU only)\n");     
01330 #endif
01331   printf("    --prec <double/single/half>               # Precision in GPU\n"); 
01332   printf("    --prec_sloppy <double/single/half>        # Sloppy precision in GPU\n"); 
01333   printf("    --recon <8/12/18>                         # Link reconstruction type\n"); 
01334   printf("    --recon_sloppy <8/12/18>                  # Sloppy link reconstruction type\n"); 
01335   printf("    --dagger                                  # Set the dagger to 1 (default 0)\n"); 
01336   printf("    --sdim <n>                                # Set space dimention(X/Y/Z) size\n"); 
01337   printf("    --xdim <n>                                # Set X dimension size(default 24)\n");     
01338   printf("    --ydim <n>                                # Set X dimension size(default 24)\n");     
01339   printf("    --zdim <n>                                # Set X dimension size(default 24)\n");     
01340   printf("    --tdim <n>                                # Set T dimension size(default 24)\n");  
01341   printf("    --xgridsize <n>                           # Set grid size in X dimension (default 1)\n");
01342   printf("    --ygridsize <n>                           # Set grid size in Y dimension (default 1)\n");
01343   printf("    --zgridsize <n>                           # Set grid size in Z dimension (default 1)\n");
01344   printf("    --tgridsize <n>                           # Set grid size in T dimension (default 1)\n");
01345   printf("    --partition <mask>                        # Set the communication topology (X=1, Y=2, Z=4, T=8, and combinations of these)\n");
01346   printf("    --kernel_pack_t                           # Set T dimension kernel packing to be true (default false)\n");
01347   printf("    --dslash_type <type>                      # Set the dslash type, the following values are valid\n"
01348          "                                                  wilson/clover/twisted_mass/asqtad/domain_wall\n");
01349   printf("    --load-gauge file                         # Load gauge field \"file\" for the test (requires QIO)\n");
01350   printf("    --tune <true/false>                       # Whether to autotune or not (default true)\n");     
01351   printf("    --help                                    # Print out this message\n"); 
01352   usage_extra(argv); 
01353 #ifdef MULTI_GPU
01354   char msg[]="multi";
01355 #else
01356   char msg[]="single";
01357 #endif  
01358   printf("Note: this program is %s GPU build\n", msg);
01359   exit(1);
01360   return ;
01361 }
01362 
01363 int process_command_line_option(int argc, char** argv, int* idx)
01364 {
01365 #ifdef MULTI_GPU
01366   char msg[]="multi";
01367 #else
01368   char msg[]="single";
01369 #endif
01370 
01371   int ret = -1;
01372   
01373   int i = *idx;
01374 
01375   if( strcmp(argv[i], "--help")== 0){
01376     usage(argv);
01377   }
01378 
01379   if( strcmp(argv[i], "--device") == 0){
01380     if (i+1 >= argc){
01381       usage(argv);
01382     }
01383 #ifdef MULTI_GPU
01384     printf("Warning: Ignoring --device argument since this is a multi-GPU build.\n");
01385 #else
01386     device = atoi(argv[i+1]);
01387     if (device < 0 || device > 16){
01388       printf("Error: Invalid CUDA device number (%d)\n", device);
01389       usage(argv);
01390     }
01391 #endif
01392     i++;
01393     ret = 0;
01394     goto out;
01395   }
01396 
01397   if( strcmp(argv[i], "--prec") == 0){
01398     if (i+1 >= argc){
01399       usage(argv);
01400     }       
01401     prec =  get_prec(argv[i+1]);
01402     i++;
01403     ret = 0;
01404     goto out;
01405   }
01406 
01407   if( strcmp(argv[i], "--prec_sloppy") == 0){
01408     if (i+1 >= argc){
01409       usage(argv);
01410     }       
01411     prec_sloppy =  get_prec(argv[i+1]);
01412     i++;
01413     ret = 0;
01414     goto out;
01415   }
01416   
01417   if( strcmp(argv[i], "--recon") == 0){
01418     if (i+1 >= argc){
01419       usage(argv);
01420     }       
01421     link_recon =  get_recon(argv[i+1]);
01422     i++;
01423     ret = 0;
01424     goto out;
01425   }
01426 
01427   if( strcmp(argv[i], "--recon_sloppy") == 0){
01428     if (i+1 >= argc){
01429       usage(argv);
01430     }       
01431     link_recon_sloppy =  get_recon(argv[i+1]);
01432     i++;
01433     ret = 0;
01434     goto out;
01435   }
01436   
01437   if( strcmp(argv[i], "--xdim") == 0){
01438     if (i+1 >= argc){
01439       usage(argv);
01440     }
01441     xdim= atoi(argv[i+1]);
01442     if (xdim < 0 || xdim > 128){
01443       printf("ERROR: invalid X dimension (%d)\n", xdim);
01444       usage(argv);
01445     }
01446     i++;
01447     ret = 0;
01448     goto out;
01449   }
01450 
01451   if( strcmp(argv[i], "--ydim") == 0){
01452     if (i+1 >= argc){
01453       usage(argv);
01454     }
01455     ydim= atoi(argv[i+1]);
01456     if (ydim < 0 || ydim > 128){
01457       printf("ERROR: invalid T dimension (%d)\n", ydim);
01458       usage(argv);
01459     }
01460     i++;
01461     ret = 0;
01462     goto out;
01463   }
01464 
01465 
01466   if( strcmp(argv[i], "--zdim") == 0){
01467     if (i+1 >= argc){
01468       usage(argv);
01469     }
01470     zdim= atoi(argv[i+1]);
01471     if (zdim < 0 || zdim > 128){
01472       printf("ERROR: invalid T dimension (%d)\n", zdim);
01473       usage(argv);
01474     }
01475     i++;
01476     ret = 0;
01477     goto out;
01478   }
01479 
01480   if( strcmp(argv[i], "--tdim") == 0){
01481     if (i+1 >= argc){
01482       usage(argv);
01483     }       
01484     tdim =  atoi(argv[i+1]);
01485     if (tdim < 0 || tdim > 128){
01486       errorQuda("Error: invalid t dimension");
01487     }
01488     i++;
01489     ret = 0;
01490     goto out;
01491   }
01492 
01493   if( strcmp(argv[i], "--sdim") == 0){
01494     if (i+1 >= argc){
01495       usage(argv);
01496     }       
01497     int sdim =  atoi(argv[i+1]);
01498     if (sdim < 0 || sdim > 128){
01499       printfQuda("Error: invalid S dimension\n");
01500     }
01501     xdim=ydim=zdim=sdim;
01502     i++;
01503     ret = 0;
01504     goto out;
01505   }
01506   
01507   if( strcmp(argv[i], "--dagger") == 0){
01508     dagger = QUDA_DAG_YES;
01509     ret = 0;
01510     goto out;
01511   }     
01512   
01513   if( strcmp(argv[i], "--partition") == 0){
01514     if (i+1 >= argc){
01515       usage(argv);
01516     }     
01517     int value  =  atoi(argv[i+1]);
01518     for(int j=0; j < 4;j++){
01519       if (value &  (1 << j)){
01520         commDimPartitionedSet(j);
01521       }
01522     }
01523     i++;
01524     ret = 0;
01525     goto out;
01526   }
01527   
01528   if( strcmp(argv[i], "--kernel_pack_t") == 0){
01529     kernelPackT = true;
01530     ret= 0;
01531     goto out;
01532   }
01533 
01534 
01535   if( strcmp(argv[i], "--tune") == 0){
01536     if (i+1 >= argc){
01537       usage(argv);
01538     }       
01539 
01540     if (strcmp(argv[i+1], "true") == 0){
01541       tune = true;
01542     }else if (strcmp(argv[i+1], "false") == 0){
01543       tune = false;
01544     }else{
01545       fprintf(stderr, "Error: invalid tuning type\n");  
01546       exit(1);
01547     }
01548 
01549     i++;
01550     ret = 0;
01551     goto out;
01552   }
01553 
01554   if( strcmp(argv[i], "--xgridsize") == 0){
01555     if (i+1 >= argc){ 
01556       usage(argv);
01557     }     
01558     int xsize =  atoi(argv[i+1]);
01559     if (xsize <= 0 ){
01560       errorQuda("Error: invalid X grid size");
01561     }
01562     gridsize_from_cmdline[0] = xsize;
01563     i++;
01564     ret = 0;
01565     goto out;
01566   }
01567 
01568   if( strcmp(argv[i], "--ygridsize") == 0){
01569     if (i+1 >= argc){
01570       usage(argv);
01571     }     
01572     int ysize =  atoi(argv[i+1]);
01573     if (ysize <= 0 ){
01574       errorQuda("Error: invalid Y grid size");
01575     }
01576     gridsize_from_cmdline[1] = ysize;
01577     i++;
01578     ret = 0;
01579     goto out;
01580   }
01581 
01582   if( strcmp(argv[i], "--zgridsize") == 0){
01583     if (i+1 >= argc){
01584       usage(argv);
01585     }     
01586     int zsize =  atoi(argv[i+1]);
01587     if (zsize <= 0 ){
01588       errorQuda("Error: invalid Z grid size");
01589     }
01590     gridsize_from_cmdline[2] = zsize;
01591     i++;
01592     ret = 0;
01593     goto out;
01594   }
01595   
01596   if( strcmp(argv[i], "--tgridsize") == 0){
01597     if (i+1 >= argc){
01598       usage(argv);
01599     }     
01600     int tsize =  atoi(argv[i+1]);
01601     if (tsize <= 0 ){
01602       errorQuda("Error: invalid T grid size");
01603     }
01604     gridsize_from_cmdline[3] = tsize;
01605     i++;
01606     ret = 0;
01607     goto out;
01608   }
01609   
01610   if( strcmp(argv[i], "--dslash_type") == 0){
01611     if (i+1 >= argc){
01612       usage(argv);
01613     }     
01614     dslash_type =  get_dslash_type(argv[i+1]);
01615     i++;
01616     ret = 0;
01617     goto out;
01618   }
01619   
01620   if( strcmp(argv[i], "--load-gauge") == 0){
01621     if (i+1 >= argc){
01622       usage(argv);
01623     }     
01624     strcpy(latfile, argv[i+1]);
01625     i++;
01626     ret = 0;
01627     goto out;
01628   }
01629   
01630   if( strcmp(argv[i], "--version") == 0){
01631     printf("This program is linked with QUDA library, version %s,", 
01632            get_quda_ver_str());
01633     printf(" %s GPU build\n", msg);
01634     exit(0);
01635   }
01636 
01637  out:
01638   *idx = i;
01639   return ret ;
01640 
01641 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines