QUDA v0.3.2
A library for QCD on GPUs

quda/tests/twisted_mass_dslash_reference.cpp

Go to the documentation of this file.
00001 #include <stdio.h>
00002 #include <stdlib.h>
00003 #include <math.h>
00004 
00005 #include <util_quda.h>
00006 
00007 #include <test_util.h>
00008 #include <blas_reference.h>
00009 #include <twisted_mass_dslash_reference.h>
00010 
// Lattice extents recorded by setDims().
int Z[4];
// Total number of lattice sites (product of the four extents).
int V;
// Half the lattice volume; used below as the number of sites of a single parity.
int Vh;

// Store the four lattice extents X[0..3] in Z and derive the volumes V and Vh = V/2.
void setDims(int *X) {
  int volume = 1;
  for (int d = 0; d < 4; d++) {
    Z[d] = X[d];
    volume *= X[d];
  }
  V = volume;
  Vh = volume/2;
}
00023 
// Element-wise addition: dst[k] = a[k] + b[k] for k in [0, cnt).
// dst may alias a or b (each element depends only on the same index of the inputs).
template <typename Float>
void sum(Float *dst, Float *a, Float *b, int cnt) {
  int k = 0;
  while (k < cnt) {
    dst[k] = a[k] + b[k];
    ++k;
  }
}
00029 
// In-place update y[k] = x[k] + a*y[k] for k in [0, len); x is only read.
template <typename Float>
void xpay(Float *x, Float a, Float *y, int len) {
  for (int k = 0; k != len && k < len; ++k) {
    y[k] = x[k] + a*y[k];
  }
}
00035 
00036 
00037 
00038 template <typename Float>
00039 Float *gaugeLink(int i, int dir, int oddBit, Float **gaugeEven, Float **gaugeOdd) {
00040   Float **gaugeField;
00041   int j;
00042   
00043   if (dir % 2 == 0) {
00044     j = i;
00045     gaugeField = (oddBit ? gaugeOdd : gaugeEven);
00046   }
00047   else {
00048     switch (dir) {
00049     case 1: j = neighborIndex(i, oddBit, 0, 0, 0, -1); break;
00050     case 3: j = neighborIndex(i, oddBit, 0, 0, -1, 0); break;
00051     case 5: j = neighborIndex(i, oddBit, 0, -1, 0, 0); break;
00052     case 7: j = neighborIndex(i, oddBit, -1, 0, 0, 0); break;
00053     default: j = -1; break;
00054     }
00055     gaugeField = (oddBit ? gaugeEven : gaugeOdd);
00056   }
00057   
00058   return &gaugeField[dir/2][j*(3*3*2)];
00059 }
00060 
00061 template <typename Float>
00062 Float *spinorNeighbor(int i, int dir, int oddBit, Float *spinorField) {
00063   int j;
00064   switch (dir) {
00065   case 0: j = neighborIndex(i, oddBit, 0, 0, 0, +1); break;
00066   case 1: j = neighborIndex(i, oddBit, 0, 0, 0, -1); break;
00067   case 2: j = neighborIndex(i, oddBit, 0, 0, +1, 0); break;
00068   case 3: j = neighborIndex(i, oddBit, 0, 0, -1, 0); break;
00069   case 4: j = neighborIndex(i, oddBit, 0, +1, 0, 0); break;
00070   case 5: j = neighborIndex(i, oddBit, 0, -1, 0, 0); break;
00071   case 6: j = neighborIndex(i, oddBit, +1, 0, 0, 0); break;
00072   case 7: j = neighborIndex(i, oddBit, -1, 0, 0, 0); break;
00073   default: j = -1; break;
00074   }
00075   
00076   return &spinorField[j*(4*3*2)];
00077 }
00078 
// Complex dot product (no conjugation) of two 3-component complex vectors stored as
// interleaved (re, im) pairs: res = sum_k a[k]*b[k]. res receives {re, im}.
// The a entries are converted to sFloat before multiplying.
template <typename sFloat, typename gFloat>
void dot(sFloat* res, gFloat* a, sFloat* b) {
  res[1] = 0;
  res[0] = res[1];
  for (int k = 0; k < 3; ++k) {
    const sFloat ar = a[2*k];
    const sFloat ai = a[2*k + 1];
    const sFloat br = b[2*k];
    const sFloat bi = b[2*k + 1];
    res[0] += ar * br - ai * bi;
    res[1] += ar * bi + ai * br;
  }
}
00091 
// Despite the name, writes the conjugate transpose (hermitian conjugate) of the 3x3
// complex matrix `mat` into `res`: res(row,col) = conj(mat(col,row)).
// Matrices are row-major with interleaved (re, im) entries.
template <typename Float>
void su3Transpose(Float *res, Float *mat) {
  for (int row = 0; row < 3; row++) {
    for (int col = 0; col < 3; col++) {
      Float *dst = res + 2*(3*row + col);
      const Float *src = mat + 2*(3*col + row);
      dst[0] = +src[0];
      dst[1] = -src[1];
    }
  }
}
00101 
// Matrix-vector product res = mat * vec for a 3x3 complex matrix (row-major, interleaved
// re/im) and a 3-component complex vector; each output component is the dot() of one
// row of mat with vec.
template <typename sFloat, typename gFloat>
void su3Mul(sFloat *res, gFloat *mat, sFloat *vec) {
  for (int row = 0; row < 3; ++row) {
    gFloat *matRow = mat + row*(3*2);
    dot(res + row*2, matRow, vec);
  }
}
00106 
// res = mat^dagger * vec: forms the conjugate transpose of mat (via su3Transpose) in a
// local buffer, then multiplies it into vec with su3Mul. mat itself is not modified.
template <typename sFloat, typename gFloat>
void su3Tmul(sFloat *res, gFloat *mat, sFloat *vec) {
  gFloat matDagger[3*3*2];
  su3Transpose(matDagger, mat);
  su3Mul(res, matDagger, vec);
}
00113 
// Eight 4x4 complex spin matrices applied by multiplySpinorByDiracProjector().
// Layout: projector[projIdx][row][col] = {re, im}.
// Every matrix has 1 on the diagonal; the two matrices of each pair (2k, 2k+1) have
// off-diagonal entries of equal magnitude and opposite sign, i.e. they are 1 - G_k and
// 1 + G_k for some spin matrix G_k (presumably the Dirac gamma matrix for axis k in the
// basis QUDA uses -- TODO confirm the basis against the GPU kernels).
// dslashReference() picks projIdx = 2*(dir/2) + (dir+daggerBit)%2, so pair k serves the
// forward/backward hops along axis k.
const double projector[8][4][4][2] = {
  { // projIdx 0
    {{1,0}, {0,0}, {0,0}, {0,-1}},
    {{0,0}, {1,0}, {0,-1}, {0,0}},
    {{0,0}, {0,1}, {1,0}, {0,0}},
    {{0,1}, {0,0}, {0,0}, {1,0}}
  },
  { // projIdx 1
    {{1,0}, {0,0}, {0,0}, {0,1}},
    {{0,0}, {1,0}, {0,1}, {0,0}},
    {{0,0}, {0,-1}, {1,0}, {0,0}},
    {{0,-1}, {0,0}, {0,0}, {1,0}}
  },
  { // projIdx 2
    {{1,0}, {0,0}, {0,0}, {1,0}},
    {{0,0}, {1,0}, {-1,0}, {0,0}},
    {{0,0}, {-1,0}, {1,0}, {0,0}},
    {{1,0}, {0,0}, {0,0}, {1,0}}
  },
  { // projIdx 3
    {{1,0}, {0,0}, {0,0}, {-1,0}},
    {{0,0}, {1,0}, {1,0}, {0,0}},
    {{0,0}, {1,0}, {1,0}, {0,0}},
    {{-1,0}, {0,0}, {0,0}, {1,0}}
  },
  { // projIdx 4
    {{1,0}, {0,0}, {0,-1}, {0,0}},
    {{0,0}, {1,0}, {0,0}, {0,1}},
    {{0,1}, {0,0}, {1,0}, {0,0}},
    {{0,0}, {0,-1}, {0,0}, {1,0}}
  },
  { // projIdx 5
    {{1,0}, {0,0}, {0,1}, {0,0}},
    {{0,0}, {1,0}, {0,0}, {0,-1}},
    {{0,-1}, {0,0}, {1,0}, {0,0}},
    {{0,0}, {0,1}, {0,0}, {1,0}}
  },
  { // projIdx 6
    {{1,0}, {0,0}, {-1,0}, {0,0}},
    {{0,0}, {1,0}, {0,0}, {-1,0}},
    {{-1,0}, {0,0}, {1,0}, {0,0}},
    {{0,0}, {-1,0}, {0,0}, {1,0}}
  },
  { // projIdx 7
    {{1,0}, {0,0}, {1,0}, {0,0}},
    {{0,0}, {1,0}, {0,0}, {1,0}},
    {{1,0}, {0,0}, {1,0}, {0,0}},
    {{0,0}, {1,0}, {0,0}, {1,0}}
  }
};
00164 
00165 
00166 // todo pass projector
00167 template <typename Float>
00168 void multiplySpinorByDiracProjector(Float *res, int projIdx, Float *spinorIn) {
00169   for (int i=0; i<4*3*2; i++) res[i] = 0.0;
00170 
00171   for (int s = 0; s < 4; s++) {
00172     for (int t = 0; t < 4; t++) {
00173       Float projRe = projector[projIdx][s][t][0];
00174       Float projIm = projector[projIdx][s][t][1];
00175       
00176       for (int m = 0; m < 3; m++) {
00177         Float spinorRe = spinorIn[t*(3*2) + m*(2) + 0];
00178         Float spinorIm = spinorIn[t*(3*2) + m*(2) + 1];
00179         res[s*(3*2) + m*(2) + 0] += projRe*spinorRe - projIm*spinorIm;
00180         res[s*(3*2) + m*(2) + 1] += projRe*spinorIm + projIm*spinorRe;
00181       }
00182     }
00183   }
00184 }
00185 
00186 
00187 //
00188 // dslashReference()
00189 //
00190 // if oddBit is zero: calculate odd parity spinor elements (using even parity spinor)
00191 // if oddBit is one:  calculate even parity spinor elements
00192 //
00193 // if daggerBit is zero: perform ordinary dslash operator
00194 // if daggerBit is one:  perform hermitian conjugate of dslash
00195 //
00196 template <typename sFloat, typename gFloat>
00197 void dslashReference(sFloat *res, gFloat **gaugeFull, sFloat *spinorField, int oddBit, int daggerBit) {
00198   for (int i=0; i<Vh*4*3*2; i++) res[i] = 0.0;
00199   
00200   gFloat *gaugeEven[4], *gaugeOdd[4];
00201   for (int dir = 0; dir < 4; dir++) {  
00202     gaugeEven[dir] = gaugeFull[dir];
00203     gaugeOdd[dir]  = gaugeFull[dir]+Vh*gaugeSiteSize;
00204   }
00205   
00206   for (int i = 0; i < Vh; i++) {
00207     for (int dir = 0; dir < 8; dir++) {
00208       gFloat *gauge = gaugeLink(i, dir, oddBit, gaugeEven, gaugeOdd);
00209       sFloat *spinor = spinorNeighbor(i, dir, oddBit, spinorField);
00210       
00211       sFloat projectedSpinor[4*3*2], gaugedSpinor[4*3*2];
00212       int projIdx = 2*(dir/2)+(dir+daggerBit)%2;
00213       multiplySpinorByDiracProjector(projectedSpinor, projIdx, spinor);
00214       
00215       for (int s = 0; s < 4; s++) {
00216         if (dir % 2 == 0)
00217           su3Mul(&gaugedSpinor[s*(3*2)], gauge, &projectedSpinor[s*(3*2)]);
00218         else
00219           su3Tmul(&gaugedSpinor[s*(3*2)], gauge, &projectedSpinor[s*(3*2)]);
00220       }
00221       
00222       sum(&res[i*(4*3*2)], &res[i*(4*3*2)], gaugedSpinor, 4*3*2);
00223     }
00224   }
00225 }
00226 
00227 // applies b*(1 + i*a*gamma_5)
00228 template <typename sFloat>
00229 void twistGamma5(sFloat *out, sFloat *in, const int dagger, const sFloat kappa, const sFloat mu, 
00230                  const QudaTwistFlavorType flavor, const int V, QudaTwistGamma5Type twist) {
00231 
00232   sFloat a=0.0,b=0.0;
00233   if (twist == QUDA_TWIST_GAMMA5_DIRECT) { // applying the twist
00234     a = 2.0 * kappa * mu * flavor; // mu already includes the flavor
00235     b = 1.0;
00236   } else if (twist == QUDA_TWIST_GAMMA5_INVERSE) { // applying the inverse twist
00237     a = -2.0 * kappa * mu * flavor;
00238     b = 1.0 / (1.0 + a*a);
00239   } else {
00240     printf("Twist type %d not defined\n", twist);
00241     exit(0);
00242   }
00243 
00244   if (dagger) a *= -1.0;
00245 
00246   for(int i = 0; i < V; i++) {
00247     sFloat tmp[24];
00248     for(int s = 0; s < 4; s++)
00249       for(int c = 0; c < 3; c++) {
00250         sFloat a5 = ((s / 2) ? -1.0 : +1.0) * a;          
00251         tmp[s * 6 + c * 2 + 0] = b* (in[i * 24 + s * 6 + c * 2 + 0] - a5*in[i * 24 + s * 6 + c * 2 + 1]);
00252         tmp[s * 6 + c * 2 + 1] = b* (in[i * 24 + s * 6 + c * 2 + 1] + a5*in[i * 24 + s * 6 + c * 2 + 0]);
00253       }
00254 
00255     for (int j=0; j<24; j++) out[i*24+j] = tmp[j];
00256   }
00257   
00258 }
00259 
00260 // this actually applies the preconditioned dslash, e.g., D_ee^{-1} D_eo or D_oo^{-1} D_oe
00261 void dslash(void *res, void **gaugeFull, void *spinorField, double kappa, double mu, 
00262             QudaTwistFlavorType flavor, int oddBit, int daggerBit,
00263             QudaPrecision sPrecision, QudaPrecision gPrecision) {
00264 
00265   if (!daggerBit) {
00266     if (sPrecision == QUDA_DOUBLE_PRECISION) {
00267       if (gPrecision == QUDA_DOUBLE_PRECISION) {
00268         dslashReference((double*)res, (double**)gaugeFull, (double*)spinorField, oddBit, daggerBit);
00269       } else {
00270         dslashReference((double*)res, (float**)gaugeFull, (double*)spinorField, oddBit, daggerBit);
00271       } 
00272       twistGamma5((double*)res, (double*)res, daggerBit, kappa, mu, 
00273                   flavor, Vh, QUDA_TWIST_GAMMA5_INVERSE);
00274     } else {
00275       if (gPrecision == QUDA_DOUBLE_PRECISION) {
00276         dslashReference((float*)res, (double**)gaugeFull, (float*)spinorField, oddBit, daggerBit);
00277       } else {
00278         dslashReference((float*)res, (float**)gaugeFull, (float*)spinorField, oddBit, daggerBit);
00279       }
00280       twistGamma5((float*)res, (float*)res, daggerBit, (float)kappa, (float)mu, 
00281                   flavor, Vh, QUDA_TWIST_GAMMA5_INVERSE);
00282     }
00283   } else {
00284     if (sPrecision == QUDA_DOUBLE_PRECISION) {
00285       twistGamma5((double*)spinorField, (double*)spinorField, daggerBit, kappa, mu, 
00286                   flavor, Vh, QUDA_TWIST_GAMMA5_INVERSE);
00287       if (gPrecision == QUDA_DOUBLE_PRECISION) {
00288         dslashReference((double*)res, (double**)gaugeFull, (double*)spinorField, oddBit, daggerBit);
00289       } else {
00290         dslashReference((double*)res, (float**)gaugeFull, (double*)spinorField, oddBit, daggerBit);
00291       }
00292       twistGamma5((double*)spinorField, (double*)spinorField, daggerBit, kappa, mu, 
00293                   flavor, Vh, QUDA_TWIST_GAMMA5_DIRECT);
00294     } else {
00295       twistGamma5((float*)spinorField, (float*)spinorField, daggerBit, (float)kappa, (float)mu, 
00296                   flavor, Vh, QUDA_TWIST_GAMMA5_INVERSE);
00297       if (gPrecision == QUDA_DOUBLE_PRECISION) {
00298         dslashReference((float*)res, (double**)gaugeFull, (float*)spinorField, oddBit, daggerBit);
00299       } else {
00300         dslashReference((float*)res, (float**)gaugeFull, (float*)spinorField, oddBit, daggerBit);
00301       }
00302       twistGamma5((float*)spinorField, (float*)spinorField, daggerBit, (float)kappa, (float)mu, 
00303                   flavor, Vh, QUDA_TWIST_GAMMA5_DIRECT);
00304     }
00305   }
00306 }
00307 
00308 template <typename sFloat, typename gFloat>
00309 void Mat(sFloat *out, gFloat **gauge, sFloat *in, sFloat kappa, sFloat mu, 
00310          QudaTwistFlavorType flavor, int daggerBit) {
00311 
00312   sFloat *inEven = in;
00313   sFloat *inOdd  = in + Vh*spinorSiteSize;
00314   sFloat *outEven = out;
00315   sFloat *outOdd = out + Vh*spinorSiteSize;
00316   
00317   sFloat *tmp = (sFloat*)malloc(V*spinorSiteSize*sizeof(sFloat));
00318 
00319   // full dslash operator
00320   dslashReference(outOdd, gauge, inEven, 1, daggerBit);
00321   dslashReference(outEven, gauge, inOdd, 0, daggerBit);
00322   // apply the twist term
00323   twistGamma5(tmp, in, daggerBit, kappa, mu, flavor, V, QUDA_TWIST_GAMMA5_DIRECT);
00324 
00325   // combine
00326   xpay(tmp, -kappa, out, V*spinorSiteSize);
00327 
00328   free(tmp);
00329 }
00330 
00331 void mat(void *out, void **gauge, void *in, double kappa, double mu, 
00332          QudaTwistFlavorType flavor, int dagger_bit,
00333          QudaPrecision sPrecision, QudaPrecision gPrecision) {
00334 
00335   if (sPrecision == QUDA_DOUBLE_PRECISION)
00336     if (gPrecision == QUDA_DOUBLE_PRECISION) 
00337       Mat((double*)out, (double**)gauge, (double*)in, (double)kappa, (double)mu, flavor, dagger_bit);
00338     else 
00339       Mat((double*)out, (float**)gauge, (double*)in, (double)kappa, (double)mu, flavor, dagger_bit);
00340   else
00341     if (gPrecision == QUDA_DOUBLE_PRECISION) 
00342       Mat((float*)out, (double**)gauge, (float*)in, (float)kappa, (float)mu, flavor, dagger_bit);
00343     else 
00344       Mat((float*)out, (float**)gauge, (float*)in, (float)kappa, (float)mu, flavor, dagger_bit);
00345 }
00346 
// Squared Euclidean norm of the first len reals of v. Each square is formed in Float
// and accumulated in double.
template <typename Float>
double norm2(Float *v, int len) {
  double acc = 0.0;
  for (int k = 0; k < len; ++k) {
    const Float x = v[k];
    acc += x*x;
  }
  return acc;
}
00353 
00354 // Apply the even-odd preconditioned Dirac operator
00355 template <typename sFloat, typename gFloat>
00356 void MatPC(sFloat *outEven, gFloat **gauge, sFloat *inEven, sFloat kappa, sFloat mu, 
00357            QudaTwistFlavorType flavor, int daggerBit, QudaMatPCType matpc_type) {
00358   
00359   sFloat *tmp = (sFloat*)malloc(Vh*spinorSiteSize*sizeof(sFloat));
00360     
00361   if (!daggerBit) {
00362     if (matpc_type == QUDA_MATPC_EVEN_EVEN) {
00363       dslashReference(tmp, gauge, inEven, 1, daggerBit);
00364       twistGamma5(tmp, tmp, daggerBit, kappa, mu, flavor, Vh, QUDA_TWIST_GAMMA5_INVERSE);
00365       dslashReference(outEven, gauge, tmp, 0, daggerBit);
00366       twistGamma5(outEven, outEven, daggerBit, kappa, mu, flavor, Vh, QUDA_TWIST_GAMMA5_INVERSE);
00367     } else if (matpc_type == QUDA_MATPC_ODD_ODD) {
00368       dslashReference(tmp, gauge, inEven, 0, daggerBit);
00369       twistGamma5(tmp, tmp, daggerBit, kappa, mu, flavor, Vh, QUDA_TWIST_GAMMA5_INVERSE);
00370       dslashReference(outEven, gauge, tmp, 1, daggerBit);
00371       twistGamma5(outEven, outEven, daggerBit, kappa, mu, flavor, Vh, QUDA_TWIST_GAMMA5_INVERSE);
00372     }
00373   } else {
00374     if (matpc_type == QUDA_MATPC_EVEN_EVEN) {
00375       twistGamma5(inEven, inEven, daggerBit, kappa, mu, flavor, Vh, QUDA_TWIST_GAMMA5_INVERSE);
00376       dslashReference(tmp, gauge, inEven, 1, daggerBit);
00377       twistGamma5(tmp, tmp, daggerBit, kappa, mu, flavor, Vh, QUDA_TWIST_GAMMA5_INVERSE);
00378       dslashReference(outEven, gauge, tmp, 0, daggerBit);
00379       twistGamma5(inEven, inEven, daggerBit, kappa, mu, flavor, Vh, QUDA_TWIST_GAMMA5_DIRECT);
00380     } else if (matpc_type == QUDA_MATPC_ODD_ODD) {
00381       twistGamma5(inEven, inEven, daggerBit, kappa, mu, flavor, Vh, QUDA_TWIST_GAMMA5_INVERSE);
00382       dslashReference(tmp, gauge, inEven, 0, daggerBit);
00383       twistGamma5(tmp, tmp, daggerBit, kappa, mu, flavor, Vh, QUDA_TWIST_GAMMA5_INVERSE);
00384       dslashReference(outEven, gauge, tmp, 1, daggerBit);
00385       twistGamma5(inEven, inEven, daggerBit, kappa, mu, flavor, Vh, QUDA_TWIST_GAMMA5_DIRECT); // undo
00386     }
00387   }
00388   // lastly apply the kappa term
00389   sFloat kappa2 = -kappa*kappa;
00390   xpay(inEven, kappa2, outEven, Vh*spinorSiteSize);
00391   free(tmp);
00392 
00393 }
00394 
00395 void matpc(void *outEven, void **gauge, void *inEven, double kappa, double mu, QudaTwistFlavorType flavor,
00396            QudaMatPCType matpc_type, int dagger_bit, QudaPrecision sPrecision, QudaPrecision gPrecision) {
00397 
00398   if (matpc_type != QUDA_MATPC_EVEN_EVEN && matpc_type != QUDA_MATPC_ODD_ODD) {
00399     printf("Only symmetric preconditioning is implemented in reference\n");
00400     exit(-1);
00401   }
00402 
00403   if (sPrecision == QUDA_DOUBLE_PRECISION)
00404     if (gPrecision == QUDA_DOUBLE_PRECISION) 
00405       MatPC((double*)outEven, (double**)gauge, (double*)inEven, (double)kappa, (double)mu, 
00406             flavor, dagger_bit, matpc_type);
00407     else
00408       MatPC((double*)outEven, (float**)gauge, (double*)inEven, (double)kappa, (double)mu, 
00409             flavor, dagger_bit, matpc_type);
00410   else
00411     if (gPrecision == QUDA_DOUBLE_PRECISION) 
00412       MatPC((float*)outEven, (double**)gauge, (float*)inEven, (float)kappa, (float)mu, 
00413             flavor, dagger_bit, matpc_type);
00414     else
00415       MatPC((float*)outEven, (float**)gauge, (float*)inEven, (float)kappa, (float)mu,
00416             flavor, dagger_bit, matpc_type);
00417 }
 All Classes Files Functions Variables Typedefs Enumerations Enumerator Friends Defines