QUDA v0.4.0
A library for QCD on GPUs
quda/tests/wilson_dslash_reference.cpp
Go to the documentation of this file.
00001 #include <stdio.h>
00002 #include <stdlib.h>
00003 #include <math.h>
00004 
00005 #include <util_quda.h>
00006 
00007 #include <test_util.h>
00008 #include <blas_reference.h>
00009 #include <wilson_dslash_reference.h>
00010 
00011 #include <gauge_field.h>
00012 #include <color_spinor_field.h>
00013 #include <face_quda.h>
00014 
// Number of real values stored per lattice site of a spinor:
// 4 spin components x 3 colors x 2 (real, imaginary).
// NOTE(review): this is defined ahead of the <dslash_util.h> include that
// follows; presumably that header refers to it -- confirm before moving it.
static int mySpinorSiteSize = 4 * 3 * 2;
00016 
00017 #include <dslash_util.h>
00018 
// Table of eight 4x4 complex spin matrices, indexed as
//   projector[projIdx][row][col][0 = real part, 1 = imaginary part].
// Every matrix is the identity plus an off-diagonal part, and entries 2k and
// 2k+1 differ only in the sign of that off-diagonal part.
// NOTE(review): presumably these are the Wilson projectors (1 -/+ gamma_mu),
// two per dimension -- confirm the basis and sign convention.
static const double projector[8][4][4][2] = {
  { // index 0
    { { 1, 0}, { 0, 0}, { 0, 0}, { 0,-1} },
    { { 0, 0}, { 1, 0}, { 0,-1}, { 0, 0} },
    { { 0, 0}, { 0, 1}, { 1, 0}, { 0, 0} },
    { { 0, 1}, { 0, 0}, { 0, 0}, { 1, 0} }
  },
  { // index 1
    { { 1, 0}, { 0, 0}, { 0, 0}, { 0, 1} },
    { { 0, 0}, { 1, 0}, { 0, 1}, { 0, 0} },
    { { 0, 0}, { 0,-1}, { 1, 0}, { 0, 0} },
    { { 0,-1}, { 0, 0}, { 0, 0}, { 1, 0} }
  },
  { // index 2
    { { 1, 0}, { 0, 0}, { 0, 0}, { 1, 0} },
    { { 0, 0}, { 1, 0}, {-1, 0}, { 0, 0} },
    { { 0, 0}, {-1, 0}, { 1, 0}, { 0, 0} },
    { { 1, 0}, { 0, 0}, { 0, 0}, { 1, 0} }
  },
  { // index 3
    { { 1, 0}, { 0, 0}, { 0, 0}, {-1, 0} },
    { { 0, 0}, { 1, 0}, { 1, 0}, { 0, 0} },
    { { 0, 0}, { 1, 0}, { 1, 0}, { 0, 0} },
    { {-1, 0}, { 0, 0}, { 0, 0}, { 1, 0} }
  },
  { // index 4
    { { 1, 0}, { 0, 0}, { 0,-1}, { 0, 0} },
    { { 0, 0}, { 1, 0}, { 0, 0}, { 0, 1} },
    { { 0, 1}, { 0, 0}, { 1, 0}, { 0, 0} },
    { { 0, 0}, { 0,-1}, { 0, 0}, { 1, 0} }
  },
  { // index 5
    { { 1, 0}, { 0, 0}, { 0, 1}, { 0, 0} },
    { { 0, 0}, { 1, 0}, { 0, 0}, { 0,-1} },
    { { 0,-1}, { 0, 0}, { 1, 0}, { 0, 0} },
    { { 0, 0}, { 0, 1}, { 0, 0}, { 1, 0} }
  },
  { // index 6
    { { 1, 0}, { 0, 0}, {-1, 0}, { 0, 0} },
    { { 0, 0}, { 1, 0}, { 0, 0}, {-1, 0} },
    { {-1, 0}, { 0, 0}, { 1, 0}, { 0, 0} },
    { { 0, 0}, {-1, 0}, { 0, 0}, { 1, 0} }
  },
  { // index 7
    { { 1, 0}, { 0, 0}, { 1, 0}, { 0, 0} },
    { { 0, 0}, { 1, 0}, { 0, 0}, { 1, 0} },
    { { 1, 0}, { 0, 0}, { 1, 0}, { 0, 0} },
    { { 0, 0}, { 1, 0}, { 0, 0}, { 1, 0} }
  }
};
00069 
00070 
00071 // todo pass projector
00072 template <typename Float>
00073 void multiplySpinorByDiracProjector(Float *res, int projIdx, Float *spinorIn) {
00074   for (int i=0; i<4*3*2; i++) res[i] = 0.0;
00075 
00076   for (int s = 0; s < 4; s++) {
00077     for (int t = 0; t < 4; t++) {
00078       Float projRe = projector[projIdx][s][t][0];
00079       Float projIm = projector[projIdx][s][t][1];
00080       
00081       for (int m = 0; m < 3; m++) {
00082         Float spinorRe = spinorIn[t*(3*2) + m*(2) + 0];
00083         Float spinorIm = spinorIn[t*(3*2) + m*(2) + 1];
00084         res[s*(3*2) + m*(2) + 0] += projRe*spinorRe - projIm*spinorIm;
00085         res[s*(3*2) + m*(2) + 1] += projRe*spinorIm + projIm*spinorRe;
00086       }
00087     }
00088   }
00089 }
00090 
00091 
//
// dslashReference()
//
// oddBit is the parity of the sites being computed, matching how wil_mat()
// in this file calls it:
//   oddBit == 0: calculate even parity spinor elements (using odd parity spinor)
//   oddBit == 1: calculate odd parity spinor elements (using even parity spinor)
//
// if daggerBit is zero: perform the ordinary dslash operator
// if daggerBit is one:  perform the Hermitian conjugate of dslash
//
00101 
00102 #ifndef MULTI_GPU
00103 
00104 template <typename sFloat, typename gFloat>
00105 void dslashReference(sFloat *res, gFloat **gaugeFull, sFloat *spinorField, int oddBit, int daggerBit) {
00106   for (int i=0; i<Vh*mySpinorSiteSize; i++) res[i] = 0.0;
00107   
00108   gFloat *gaugeEven[4], *gaugeOdd[4];
00109   for (int dir = 0; dir < 4; dir++) {  
00110     gaugeEven[dir] = gaugeFull[dir];
00111     gaugeOdd[dir]  = gaugeFull[dir]+Vh*gaugeSiteSize;
00112   }
00113   
00114   for (int i = 0; i < Vh; i++) {
00115     for (int dir = 0; dir < 8; dir++) {
00116       gFloat *gauge = gaugeLink(i, dir, oddBit, gaugeEven, gaugeOdd, 1);
00117       sFloat *spinor = spinorNeighbor(i, dir, oddBit, spinorField, 1);
00118       
00119       sFloat projectedSpinor[4*3*2], gaugedSpinor[4*3*2];
00120       int projIdx = 2*(dir/2)+(dir+daggerBit)%2;
00121       multiplySpinorByDiracProjector(projectedSpinor, projIdx, spinor);
00122       
00123       for (int s = 0; s < 4; s++) {
00124         if (dir % 2 == 0) su3Mul(&gaugedSpinor[s*(3*2)], gauge, &projectedSpinor[s*(3*2)]);
00125         else su3Tmul(&gaugedSpinor[s*(3*2)], gauge, &projectedSpinor[s*(3*2)]);
00126       }
00127       
00128       sum(&res[i*(4*3*2)], &res[i*(4*3*2)], gaugedSpinor, 4*3*2);
00129     }
00130   }
00131 }
00132 
00133 #else
00134 
00135 template <typename sFloat, typename gFloat>
00136 void dslashReference(sFloat *res, gFloat **gaugeFull,  gFloat **ghostGauge, sFloat *spinorField, 
00137                      sFloat **fwdSpinor, sFloat **backSpinor, int oddBit, int daggerBit) {
00138   for (int i=0; i<Vh*mySpinorSiteSize; i++) res[i] = 0.0;
00139   
00140   gFloat *gaugeEven[4], *gaugeOdd[4];
00141   gFloat *ghostGaugeEven[4], *ghostGaugeOdd[4];
00142   for (int dir = 0; dir < 4; dir++) {  
00143     gaugeEven[dir] = gaugeFull[dir];
00144     gaugeOdd[dir]  = gaugeFull[dir]+Vh*gaugeSiteSize;
00145 
00146     ghostGaugeEven[dir] = ghostGauge[dir];
00147     ghostGaugeOdd[dir] = ghostGauge[dir] + (faceVolume[dir]/2)*gaugeSiteSize;
00148   }
00149   
00150   for (int i = 0; i < Vh; i++) {
00151 
00152     for (int dir = 0; dir < 8; dir++) {
00153       gFloat *gauge = gaugeLink_mg4dir(i, dir, oddBit, gaugeEven, gaugeOdd, ghostGaugeEven, ghostGaugeOdd, 1, 1);
00154       sFloat *spinor = spinorNeighbor_mg4dir(i, dir, oddBit, spinorField, fwdSpinor, backSpinor, 1, 1);
00155       
00156       sFloat projectedSpinor[mySpinorSiteSize], gaugedSpinor[mySpinorSiteSize];
00157       int projIdx = 2*(dir/2)+(dir+daggerBit)%2;
00158       multiplySpinorByDiracProjector(projectedSpinor, projIdx, spinor);
00159       
00160       for (int s = 0; s < 4; s++) {
00161         if (dir % 2 == 0) su3Mul(&gaugedSpinor[s*(3*2)], gauge, &projectedSpinor[s*(3*2)]);
00162         else su3Tmul(&gaugedSpinor[s*(3*2)], gauge, &projectedSpinor[s*(3*2)]);
00163       }
00164       
00165       sum(&res[i*(4*3*2)], &res[i*(4*3*2)], gaugedSpinor, 4*3*2);
00166     }
00167 
00168   }
00169 }
00170 
00171 #endif
00172 
00173 // this actually applies the preconditioned dslash, e.g., D_ee^{-1} D_eo or D_oo^{-1} D_oe
00174 void wil_dslash(void *out, void **gauge, void *in, int oddBit, int daggerBit,
00175                 QudaPrecision precision, QudaGaugeParam &gauge_param) {
00176   
00177 #ifndef MULTI_GPU  
00178   if (precision == QUDA_DOUBLE_PRECISION)
00179     dslashReference((double*)out, (double**)gauge, (double*)in, oddBit, daggerBit);
00180   else
00181     dslashReference((float*)out, (float**)gauge, (float*)in, oddBit, daggerBit);
00182 #else
00183 
00184   GaugeFieldParam gauge_field_param(gauge, gauge_param);
00185   cpuGaugeField cpu(gauge_field_param);
00186   cpu.exchangeGhost();
00187   void **ghostGauge = (void**)cpu.Ghost();
00188 
00189   // Get spinor ghost fields
00190   // First wrap the input spinor into a ColorSpinorField
00191   ColorSpinorParam csParam;
00192   csParam.v = in;
00193   csParam.fieldLocation = QUDA_CPU_FIELD_LOCATION;
00194   csParam.nColor = 3;
00195   csParam.nSpin = 4;
00196   csParam.nDim = 4;
00197   for (int d=0; d<4; d++) csParam.x[d] = Z[d];
00198   csParam.precision = precision;
00199   csParam.pad = 0;
00200   csParam.siteSubset = QUDA_PARITY_SITE_SUBSET;
00201   csParam.x[0] /= 2;
00202   csParam.siteOrder = QUDA_EVEN_ODD_SITE_ORDER;
00203   csParam.fieldOrder = QUDA_SPACE_SPIN_COLOR_FIELD_ORDER;
00204   csParam.gammaBasis = QUDA_DEGRAND_ROSSI_GAMMA_BASIS;
00205   csParam.create = QUDA_REFERENCE_FIELD_CREATE;
00206   
00207   cpuColorSpinorField inField(csParam);
00208 
00209   {  // Now do the exchange
00210     QudaParity otherParity = QUDA_INVALID_PARITY;
00211     if (oddBit == QUDA_EVEN_PARITY) otherParity = QUDA_ODD_PARITY;
00212     else if (oddBit == QUDA_ODD_PARITY) otherParity = QUDA_EVEN_PARITY;
00213     else errorQuda("ERROR: full parity not supported in function %s", __FUNCTION__);
00214 
00215     int nFace = 1;
00216     FaceBuffer faceBuf(Z, 4, mySpinorSiteSize, nFace, precision);
00217     faceBuf.exchangeCpuSpinor(inField, otherParity, daggerBit); 
00218   }
00219   void** fwd_nbr_spinor = inField.fwdGhostFaceBuffer;
00220   void** back_nbr_spinor = inField.backGhostFaceBuffer;
00221 
00222   if (precision == QUDA_DOUBLE_PRECISION) {
00223     dslashReference((double*)out, (double**)gauge, (double**)ghostGauge, (double*)in, 
00224                     (double**)fwd_nbr_spinor, (double**)back_nbr_spinor, oddBit, daggerBit);
00225   } else{
00226     dslashReference((float*)out, (float**)gauge, (float**)ghostGauge, (float*)in, 
00227                     (float**)fwd_nbr_spinor, (float**)back_nbr_spinor, oddBit, daggerBit);
00228   }
00229 
00230 #endif
00231 
00232 }
00233 
00234 // applies b*(1 + i*a*gamma_5)
00235 template <typename sFloat>
00236 void twistGamma5(sFloat *out, sFloat *in, const int dagger, const sFloat kappa, const sFloat mu, 
00237                  const QudaTwistFlavorType flavor, const int V, QudaTwistGamma5Type twist) {
00238 
00239   sFloat a=0.0,b=0.0;
00240   if (twist == QUDA_TWIST_GAMMA5_DIRECT) { // applying the twist
00241     a = 2.0 * kappa * mu * flavor; // mu already includes the flavor
00242     b = 1.0;
00243   } else if (twist == QUDA_TWIST_GAMMA5_INVERSE) { // applying the inverse twist
00244     a = -2.0 * kappa * mu * flavor;
00245     b = 1.0 / (1.0 + a*a);
00246   } else {
00247     printf("Twist type %d not defined\n", twist);
00248     exit(0);
00249   }
00250 
00251   if (dagger) a *= -1.0;
00252 
00253   for(int i = 0; i < V; i++) {
00254     sFloat tmp[24];
00255     for(int s = 0; s < 4; s++)
00256       for(int c = 0; c < 3; c++) {
00257         sFloat a5 = ((s / 2) ? -1.0 : +1.0) * a;          
00258         tmp[s * 6 + c * 2 + 0] = b* (in[i * 24 + s * 6 + c * 2 + 0] - a5*in[i * 24 + s * 6 + c * 2 + 1]);
00259         tmp[s * 6 + c * 2 + 1] = b* (in[i * 24 + s * 6 + c * 2 + 1] + a5*in[i * 24 + s * 6 + c * 2 + 0]);
00260       }
00261 
00262     for (int j=0; j<24; j++) out[i*24+j] = tmp[j];
00263   }
00264   
00265 }
00266 
00267 void twist_gamma5(void *out, void *in,  int daggerBit, double kappa, double mu, QudaTwistFlavorType flavor, 
00268                  int V, QudaTwistGamma5Type twist, QudaPrecision precision) {
00269 
00270   if (precision == QUDA_DOUBLE_PRECISION) {
00271     twistGamma5((double*)out, (double*)in, daggerBit, kappa, mu, flavor, V, twist);
00272   } else {
00273     twistGamma5((float*)out, (float*)in, daggerBit, (float)kappa, (float)mu, flavor, V, twist);
00274   } 
00275 }
00276 
00277 
00278 void tm_dslash(void *res, void **gaugeFull, void *spinorField, double kappa, double mu, 
00279                QudaTwistFlavorType flavor, int oddBit, int daggerBit, QudaPrecision precision,
00280                QudaGaugeParam &gauge_param)
00281 {
00282 
00283   if (daggerBit) twist_gamma5(spinorField, spinorField, daggerBit, kappa, mu, 
00284                               flavor, Vh, QUDA_TWIST_GAMMA5_INVERSE, precision);
00285 
00286   wil_dslash(res, gaugeFull, spinorField, oddBit, daggerBit, precision, gauge_param);
00287 
00288   if (!daggerBit) {
00289     twist_gamma5(res, res, daggerBit, kappa, mu, flavor,
00290                  Vh, QUDA_TWIST_GAMMA5_INVERSE, precision);
00291   } else {
00292     twist_gamma5(spinorField, spinorField,  daggerBit, kappa, mu, flavor, 
00293                  Vh, QUDA_TWIST_GAMMA5_DIRECT, precision);
00294   }
00295 
00296 }
00297 
00298 void wil_mat(void *out, void **gauge, void *in, double kappa, int dagger_bit, QudaPrecision precision,
00299              QudaGaugeParam &gauge_param) {
00300 
00301   void *inEven = in;
00302   void *inOdd  = (char*)in + Vh*spinorSiteSize*precision;
00303   void *outEven = out;
00304   void *outOdd = (char*)out + Vh*spinorSiteSize*precision;
00305 
00306   wil_dslash(outOdd, gauge, inEven, 1, dagger_bit, precision, gauge_param);
00307   wil_dslash(outEven, gauge, inOdd, 0, dagger_bit, precision, gauge_param);
00308 
00309   // lastly apply the kappa term
00310   if (precision == QUDA_DOUBLE_PRECISION) xpay((double*)in, -kappa, (double*)out, V*spinorSiteSize);
00311   else xpay((float*)in, -(float)kappa, (float*)out, V*spinorSiteSize);
00312 }
00313 
00314 void tm_mat(void *out, void **gauge, void *in, double kappa, double mu, 
00315             QudaTwistFlavorType flavor, int dagger_bit, QudaPrecision precision,
00316             QudaGaugeParam &gauge_param) {
00317 
00318   void *inEven = in;
00319   void *inOdd  = (char*)in + Vh*spinorSiteSize*precision;
00320   void *outEven = out;
00321   void *outOdd = (char*)out + Vh*spinorSiteSize*precision;
00322   void *tmp = malloc(V*spinorSiteSize*precision);
00323 
00324   wil_dslash(outOdd, gauge, inEven, 1, dagger_bit, precision, gauge_param);
00325   wil_dslash(outEven, gauge, inOdd, 0, dagger_bit, precision, gauge_param);
00326 
00327   // apply the twist term to the full lattice
00328   twist_gamma5(tmp, in, dagger_bit, kappa, mu, flavor, V, QUDA_TWIST_GAMMA5_DIRECT, precision);
00329 
00330   // combine
00331   if (precision == QUDA_DOUBLE_PRECISION) xpay((double*)tmp, -kappa, (double*)out, V*spinorSiteSize);
00332   else xpay((float*)tmp, -(float)kappa, (float*)out, V*spinorSiteSize);
00333 
00334   free(tmp);
00335 }
00336 
00337 // Apply the even-odd preconditioned Dirac operator
00338 void wil_matpc(void *outEven, void **gauge, void *inEven, double kappa, 
00339                QudaMatPCType matpc_type, int daggerBit, QudaPrecision precision,
00340                QudaGaugeParam &gauge_param) {
00341 
00342   void *tmp = malloc(Vh*spinorSiteSize*precision);
00343     
00344   // full dslash operator
00345   if (matpc_type == QUDA_MATPC_EVEN_EVEN) {
00346     wil_dslash(tmp, gauge, inEven, 1, daggerBit, precision, gauge_param);
00347     wil_dslash(outEven, gauge, tmp, 0, daggerBit, precision, gauge_param);
00348   } else {
00349     wil_dslash(tmp, gauge, inEven, 0, daggerBit, precision, gauge_param);
00350     wil_dslash(outEven, gauge, tmp, 1, daggerBit, precision, gauge_param);
00351   }    
00352   
00353   // lastly apply the kappa term
00354   double kappa2 = -kappa*kappa;
00355   if (precision == QUDA_DOUBLE_PRECISION) xpay((double*)inEven, kappa2, (double*)outEven, Vh*spinorSiteSize);
00356   else xpay((float*)inEven, (float)kappa2, (float*)outEven, Vh*spinorSiteSize);
00357 
00358   free(tmp);
00359 }
00360 
00361 // Apply the even-odd preconditioned Dirac operator
00362 void tm_matpc(void *outEven, void **gauge, void *inEven, double kappa, double mu, QudaTwistFlavorType flavor,
00363               QudaMatPCType matpc_type, int daggerBit, QudaPrecision precision, QudaGaugeParam &gauge_param) {
00364 
00365   void *tmp = malloc(Vh*spinorSiteSize*precision);
00366     
00367   if (matpc_type == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC) {
00368     wil_dslash(tmp, gauge, inEven, 1, daggerBit, precision, gauge_param);
00369     twist_gamma5(tmp, tmp, daggerBit, kappa, mu, flavor, Vh, QUDA_TWIST_GAMMA5_INVERSE, precision);
00370     wil_dslash(outEven, gauge, tmp, 0, daggerBit, precision, gauge_param);
00371     twist_gamma5(tmp, inEven, daggerBit, kappa, mu, flavor, Vh, QUDA_TWIST_GAMMA5_DIRECT, precision);
00372   } else if (matpc_type == QUDA_MATPC_ODD_ODD_ASYMMETRIC) {
00373     wil_dslash(tmp, gauge, inEven, 0, daggerBit, precision, gauge_param);
00374     twist_gamma5(tmp, tmp, daggerBit, kappa, mu, flavor, Vh, QUDA_TWIST_GAMMA5_INVERSE, precision);
00375     wil_dslash(outEven, gauge, tmp, 1, daggerBit, precision, gauge_param);
00376     twist_gamma5(tmp, inEven, daggerBit, kappa, mu, flavor, Vh, QUDA_TWIST_GAMMA5_DIRECT, precision);
00377   } else if (!daggerBit) {
00378     if (matpc_type == QUDA_MATPC_EVEN_EVEN) {
00379       wil_dslash(tmp, gauge, inEven, 1, daggerBit, precision, gauge_param);
00380       twist_gamma5(tmp, tmp, daggerBit, kappa, mu, flavor, Vh, QUDA_TWIST_GAMMA5_INVERSE, precision);
00381       wil_dslash(outEven, gauge, tmp, 0, daggerBit, precision, gauge_param);
00382       twist_gamma5(outEven, outEven, daggerBit, kappa, mu, flavor, Vh, QUDA_TWIST_GAMMA5_INVERSE, precision);
00383     } else if (matpc_type == QUDA_MATPC_ODD_ODD) {
00384       wil_dslash(tmp, gauge, inEven, 0, daggerBit, precision, gauge_param);
00385       twist_gamma5(tmp, tmp, daggerBit, kappa, mu, flavor, Vh, QUDA_TWIST_GAMMA5_INVERSE, precision);
00386       wil_dslash(outEven, gauge, tmp, 1, daggerBit, precision, gauge_param);
00387       twist_gamma5(outEven, outEven, daggerBit, kappa, mu, flavor, Vh, QUDA_TWIST_GAMMA5_INVERSE, precision);
00388     }
00389   } else {
00390     if (matpc_type == QUDA_MATPC_EVEN_EVEN) {
00391       twist_gamma5(inEven, inEven, daggerBit, kappa, mu, flavor, Vh, QUDA_TWIST_GAMMA5_INVERSE, precision);
00392       wil_dslash(tmp, gauge, inEven, 1, daggerBit, precision, gauge_param);
00393       twist_gamma5(tmp, tmp, daggerBit, kappa, mu, flavor, Vh, QUDA_TWIST_GAMMA5_INVERSE, precision);
00394       wil_dslash(outEven, gauge, tmp, 0, daggerBit, precision, gauge_param);
00395       twist_gamma5(inEven, inEven, daggerBit, kappa, mu, flavor, Vh, QUDA_TWIST_GAMMA5_DIRECT, precision);
00396     } else if (matpc_type == QUDA_MATPC_ODD_ODD) {
00397       twist_gamma5(inEven, inEven, daggerBit, kappa, mu, flavor, Vh, QUDA_TWIST_GAMMA5_INVERSE, precision);
00398       wil_dslash(tmp, gauge, inEven, 0, daggerBit, precision, gauge_param);
00399       twist_gamma5(tmp, tmp, daggerBit, kappa, mu, flavor, Vh, QUDA_TWIST_GAMMA5_INVERSE, precision);
00400       wil_dslash(outEven, gauge, tmp, 1, daggerBit, precision, gauge_param);
00401       twist_gamma5(inEven, inEven, daggerBit, kappa, mu, flavor, Vh, QUDA_TWIST_GAMMA5_DIRECT, precision); // undo
00402     }
00403   }
00404   // lastly apply the kappa term
00405   double kappa2 = -kappa*kappa;
00406   if (matpc_type == QUDA_MATPC_EVEN_EVEN || matpc_type == QUDA_MATPC_ODD_ODD) {
00407     if (precision == QUDA_DOUBLE_PRECISION) xpay((double*)inEven, kappa2, (double*)outEven, Vh*spinorSiteSize);
00408     else xpay((float*)inEven, (float)kappa2, (float*)outEven, Vh*spinorSiteSize);
00409   } else {
00410     if (precision == QUDA_DOUBLE_PRECISION) xpay((double*)tmp, kappa2, (double*)outEven, Vh*spinorSiteSize);
00411     else xpay((float*)tmp, (float)kappa2, (float*)outEven, Vh*spinorSiteSize);
00412   }
00413 
00414   free(tmp);
00415 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines