QUDA v0.3.2
A library for QCD on GPUs

quda/tests/wilson_dslash_reference.cpp

Go to the documentation of this file.
00001 #include <stdio.h>
00002 #include <stdlib.h>
00003 #include <math.h>
00004 
00005 #include <util_quda.h>
00006 
00007 #include <test_util.h>
00008 #include <blas_reference.h>
00009 #include <wilson_dslash_reference.h>
00010 
// Lattice geometry shared by every reference routine in this file; filled in by setDims().
int Z[4];  // extent of each of the four lattice dimensions (copy of the array given to setDims)
int V;     // total number of lattice sites: Z[0]*Z[1]*Z[2]*Z[3]
int Vh;    // V/2 (integer division); used in this file as the site count of a single parity

// Record the four lattice extents X[0..3] and derive the site counts V and Vh from them.
void setDims(int *X) {
  int volume = 1;
  for (int mu = 0; mu < 4; ++mu) {
    Z[mu] = X[mu];
    volume *= Z[mu];
  }
  V = volume;
  Vh = volume / 2;
}
00023 
// Element-wise addition over cnt entries: dst[k] = a[k] + b[k].
// Entries are processed front to back and each pair is read before its result is stored,
// so dst may be the same buffer as a or b. A non-positive cnt does nothing.
template <typename Float>
void sum(Float *dst, Float *a, Float *b, int cnt) {
  for (; cnt > 0; --cnt) {
    *dst++ = *a++ + *b++;
  }
}
00029 
// In-place update of y over len entries: y[k] <- x[k] + a*y[k].
// x is only read; y is overwritten. A non-positive len does nothing.
template <typename Float>
void xpay(Float *x, Float a, Float *y, int len) {
  for (int remaining = len; remaining > 0; --remaining, ++x, ++y) {
    *y = *x + a * *y;
  }
}
00035 
00036 
00037 
00038 template <typename Float>
00039 Float *gaugeLink(int i, int dir, int oddBit, Float **gaugeEven, Float **gaugeOdd) {
00040   Float **gaugeField;
00041   int j;
00042   
00043   if (dir % 2 == 0) {
00044     j = i;
00045     gaugeField = (oddBit ? gaugeOdd : gaugeEven);
00046   }
00047   else {
00048     switch (dir) {
00049     case 1: j = neighborIndex(i, oddBit, 0, 0, 0, -1); break;
00050     case 3: j = neighborIndex(i, oddBit, 0, 0, -1, 0); break;
00051     case 5: j = neighborIndex(i, oddBit, 0, -1, 0, 0); break;
00052     case 7: j = neighborIndex(i, oddBit, -1, 0, 0, 0); break;
00053     default: j = -1; break;
00054     }
00055     gaugeField = (oddBit ? gaugeEven : gaugeOdd);
00056   }
00057   
00058   return &gaugeField[dir/2][j*(3*3*2)];
00059 }
00060 
00061 template <typename Float>
00062 Float *spinorNeighbor(int i, int dir, int oddBit, Float *spinorField) {
00063   int j;
00064   switch (dir) {
00065   case 0: j = neighborIndex(i, oddBit, 0, 0, 0, +1); break;
00066   case 1: j = neighborIndex(i, oddBit, 0, 0, 0, -1); break;
00067   case 2: j = neighborIndex(i, oddBit, 0, 0, +1, 0); break;
00068   case 3: j = neighborIndex(i, oddBit, 0, 0, -1, 0); break;
00069   case 4: j = neighborIndex(i, oddBit, 0, +1, 0, 0); break;
00070   case 5: j = neighborIndex(i, oddBit, 0, -1, 0, 0); break;
00071   case 6: j = neighborIndex(i, oddBit, +1, 0, 0, 0); break;
00072   case 7: j = neighborIndex(i, oddBit, -1, 0, 0, 0); break;
00073   default: j = -1; break;
00074   }
00075   
00076   return &spinorField[j*(4*3*2)];
00077 }
00078 
// Unconjugated complex dot product of two 3-component complex vectors stored as interleaved
// (re, im) pairs: res = sum_c a[c] * b[c]. Entries of a are converted to sFloat before the
// arithmetic, which is carried out in sFloat. res is zeroed first and then accumulated in place.
template <typename sFloat, typename gFloat>
void dot(sFloat* res, gFloat* a, sFloat* b) {
  sFloat &re = res[0];
  sFloat &im = res[1];
  re = 0;
  im = 0;
  for (int c = 0; c < 3; ++c) {
    const sFloat ar = a[2*c];
    const sFloat ai = a[2*c + 1];
    const sFloat br = b[2*c];
    const sFloat bi = b[2*c + 1];
    re += ar*br - ai*bi;
    im += ar*bi + ai*br;
  }
}
00091 
// Conjugate transpose (Hermitian adjoint) of a 3x3 complex matrix stored row-major as
// interleaved (re, im) pairs: res[row][col] = conj(mat[col][row]).
// Despite the name, the imaginary parts are negated, so this is the adjoint rather than the
// plain transpose. res must be a different buffer from mat: in-place use would overwrite
// entries before they are read.
template <typename Float>
void su3Transpose(Float *res, Float *mat) {
  for (int row = 0; row < 3; ++row) {
    Float *dst = res + row*(3*2);
    for (int col = 0; col < 3; ++col) {
      const Float *src = mat + col*(3*2) + row*2;
      dst[2*col + 0] = +src[0];
      dst[2*col + 1] = -src[1];
    }
  }
}
00101 
// Matrix-vector product res = mat * vec for a 3x3 complex matrix (row-major, interleaved
// re/im) and a 3-component complex vector. Each output entry is one row of mat dotted with vec.
// res must not overlap vec, because res entries are written while vec is still being read.
template <typename sFloat, typename gFloat>
void su3Mul(sFloat *res, gFloat *mat, sFloat *vec) {
  for (int row = 0; row < 3; ++row) {
    sFloat *outEntry = res + 2*row;
    gFloat *matRow = mat + (3*2)*row;
    dot(outEntry, matRow, vec);
  }
}
00106 
// Multiply vec by the conjugate transpose of mat: res = mat^dagger * vec.
// The adjoint is first built in a local gFloat scratch matrix by su3Transpose (which negates
// the imaginary parts) and then applied with su3Mul.
template <typename sFloat, typename gFloat>
void su3Tmul(sFloat *res, gFloat *mat, sFloat *vec) {
  gFloat adjoint[3*3*2];
  su3Transpose(adjoint, mat);
  su3Mul(res, adjoint, vec);
}
00113 
// Eight 4x4 complex spin matrices, indexed projector[projIdx][row][col][re/im].
// dslashReference selects entry 2*(dir/2) + (dir+daggerBit)%2 for hop direction dir.
// Every matrix has 1 on its diagonal. Within each pair (2*mu, 2*mu+1) the off-diagonal entries
// differ only in sign, i.e. projector[2*mu] = 1 + G and projector[2*mu+1] = 1 - G for one
// off-diagonal matrix G per mu. These are presumably the (1 -/+ gamma_mu) spin projectors in the
// gamma-matrix basis used by QUDA; confirm against the GPU kernels before relying on the basis.
const double projector[8][4][4][2] = {
  // [0]
  {
    { { 1, 0}, { 0, 0}, { 0, 0}, { 0,-1} },
    { { 0, 0}, { 1, 0}, { 0,-1}, { 0, 0} },
    { { 0, 0}, { 0, 1}, { 1, 0}, { 0, 0} },
    { { 0, 1}, { 0, 0}, { 0, 0}, { 1, 0} },
  },
  // [1]
  {
    { { 1, 0}, { 0, 0}, { 0, 0}, { 0, 1} },
    { { 0, 0}, { 1, 0}, { 0, 1}, { 0, 0} },
    { { 0, 0}, { 0,-1}, { 1, 0}, { 0, 0} },
    { { 0,-1}, { 0, 0}, { 0, 0}, { 1, 0} },
  },
  // [2]
  {
    { { 1, 0}, { 0, 0}, { 0, 0}, { 1, 0} },
    { { 0, 0}, { 1, 0}, {-1, 0}, { 0, 0} },
    { { 0, 0}, {-1, 0}, { 1, 0}, { 0, 0} },
    { { 1, 0}, { 0, 0}, { 0, 0}, { 1, 0} },
  },
  // [3]
  {
    { { 1, 0}, { 0, 0}, { 0, 0}, {-1, 0} },
    { { 0, 0}, { 1, 0}, { 1, 0}, { 0, 0} },
    { { 0, 0}, { 1, 0}, { 1, 0}, { 0, 0} },
    { {-1, 0}, { 0, 0}, { 0, 0}, { 1, 0} },
  },
  // [4]
  {
    { { 1, 0}, { 0, 0}, { 0,-1}, { 0, 0} },
    { { 0, 0}, { 1, 0}, { 0, 0}, { 0, 1} },
    { { 0, 1}, { 0, 0}, { 1, 0}, { 0, 0} },
    { { 0, 0}, { 0,-1}, { 0, 0}, { 1, 0} },
  },
  // [5]
  {
    { { 1, 0}, { 0, 0}, { 0, 1}, { 0, 0} },
    { { 0, 0}, { 1, 0}, { 0, 0}, { 0,-1} },
    { { 0,-1}, { 0, 0}, { 1, 0}, { 0, 0} },
    { { 0, 0}, { 0, 1}, { 0, 0}, { 1, 0} },
  },
  // [6]
  {
    { { 1, 0}, { 0, 0}, {-1, 0}, { 0, 0} },
    { { 0, 0}, { 1, 0}, { 0, 0}, {-1, 0} },
    { {-1, 0}, { 0, 0}, { 1, 0}, { 0, 0} },
    { { 0, 0}, {-1, 0}, { 0, 0}, { 1, 0} },
  },
  // [7]
  {
    { { 1, 0}, { 0, 0}, { 1, 0}, { 0, 0} },
    { { 0, 0}, { 1, 0}, { 0, 0}, { 1, 0} },
    { { 1, 0}, { 0, 0}, { 1, 0}, { 0, 0} },
    { { 0, 0}, { 1, 0}, { 0, 0}, { 1, 0} },
  },
};
00164 
00165 
00166 // todo pass projector
00167 template <typename Float>
00168 void multiplySpinorByDiracProjector(Float *res, int projIdx, Float *spinorIn) {
00169   for (int i=0; i<4*3*2; i++) res[i] = 0.0;
00170 
00171   for (int s = 0; s < 4; s++) {
00172     for (int t = 0; t < 4; t++) {
00173       Float projRe = projector[projIdx][s][t][0];
00174       Float projIm = projector[projIdx][s][t][1];
00175       
00176       for (int m = 0; m < 3; m++) {
00177         Float spinorRe = spinorIn[t*(3*2) + m*(2) + 0];
00178         Float spinorIm = spinorIn[t*(3*2) + m*(2) + 1];
00179         res[s*(3*2) + m*(2) + 0] += projRe*spinorRe - projIm*spinorIm;
00180         res[s*(3*2) + m*(2) + 1] += projRe*spinorIm + projIm*spinorRe;
00181       }
00182     }
00183   }
00184 }
00185 
00186 
00187 //
00188 // dslashReference()
00189 //
00190 // if oddBit is zero: calculate odd parity spinor elements (using even parity spinor)
00191 // if oddBit is one:  calculate even parity spinor elements
00192 //
00193 // if daggerBit is zero: perform ordinary dslash operator
00194 // if daggerBit is one:  perform hermitian conjugate of dslash
00195 //
00196 template <typename sFloat, typename gFloat>
00197 void dslashReference(sFloat *res, gFloat **gaugeFull, sFloat *spinorField, int oddBit, int daggerBit) {
00198   for (int i=0; i<Vh*4*3*2; i++) res[i] = 0.0;
00199   
00200   gFloat *gaugeEven[4], *gaugeOdd[4];
00201   for (int dir = 0; dir < 4; dir++) {  
00202     gaugeEven[dir] = gaugeFull[dir];
00203     gaugeOdd[dir]  = gaugeFull[dir]+Vh*gaugeSiteSize;
00204   }
00205   
00206   for (int i = 0; i < Vh; i++) {
00207     for (int dir = 0; dir < 8; dir++) {
00208       gFloat *gauge = gaugeLink(i, dir, oddBit, gaugeEven, gaugeOdd);
00209       sFloat *spinor = spinorNeighbor(i, dir, oddBit, spinorField);
00210       
00211       sFloat projectedSpinor[4*3*2], gaugedSpinor[4*3*2];
00212       int projIdx = 2*(dir/2)+(dir+daggerBit)%2;
00213       multiplySpinorByDiracProjector(projectedSpinor, projIdx, spinor);
00214       
00215       for (int s = 0; s < 4; s++) {
00216         if (dir % 2 == 0)
00217           su3Mul(&gaugedSpinor[s*(3*2)], gauge, &projectedSpinor[s*(3*2)]);
00218         else
00219           su3Tmul(&gaugedSpinor[s*(3*2)], gauge, &projectedSpinor[s*(3*2)]);
00220       }
00221       
00222       sum(&res[i*(4*3*2)], &res[i*(4*3*2)], gaugedSpinor, 4*3*2);
00223     }
00224   }
00225 }
00226 
00227 void dslash(void *res, void **gaugeFull, void *spinorField, int oddBit, int daggerBit,
00228             QudaPrecision sPrecision, QudaPrecision gPrecision) {
00229   
00230   if (sPrecision == QUDA_DOUBLE_PRECISION) 
00231     if (gPrecision == QUDA_DOUBLE_PRECISION)
00232       dslashReference((double*)res, (double**)gaugeFull, (double*)spinorField, oddBit, daggerBit);
00233     else
00234       dslashReference((double*)res, (float**)gaugeFull, (double*)spinorField, oddBit, daggerBit);
00235   else
00236     if (gPrecision == QUDA_DOUBLE_PRECISION)
00237       dslashReference((float*)res, (double**)gaugeFull, (float*)spinorField, oddBit, daggerBit);
00238     else
00239       dslashReference((float*)res, (float**)gaugeFull, (float*)spinorField, oddBit, daggerBit);
00240 
00241 }
00242 
00243 template <typename sFloat, typename gFloat>
00244 void Mat(sFloat *out, gFloat **gauge, sFloat *in, sFloat kappa, int daggerBit) {
00245   sFloat *inEven = in;
00246   sFloat *inOdd  = in + Vh*spinorSiteSize;
00247   sFloat *outEven = out;
00248   sFloat *outOdd = out + Vh*spinorSiteSize;
00249   
00250   // full dslash operator
00251   dslashReference(outOdd, gauge, inEven, 1, daggerBit);
00252   dslashReference(outEven, gauge, inOdd, 0, daggerBit);
00253   
00254   // lastly apply the kappa term
00255   xpay(in, -kappa, out, V*spinorSiteSize);
00256 }
00257 
00258 void mat(void *out, void **gauge, void *in, double kappa, int dagger_bit,
00259          QudaPrecision sPrecision, QudaPrecision gPrecision) {
00260 
00261   if (sPrecision == QUDA_DOUBLE_PRECISION)
00262     if (gPrecision == QUDA_DOUBLE_PRECISION) 
00263       Mat((double*)out, (double**)gauge, (double*)in, (double)kappa, dagger_bit);
00264     else 
00265       Mat((double*)out, (float**)gauge, (double*)in, (double)kappa, dagger_bit);
00266   else
00267     if (gPrecision == QUDA_DOUBLE_PRECISION) 
00268       Mat((float*)out, (double**)gauge, (float*)in, (float)kappa, dagger_bit);
00269     else 
00270       Mat((float*)out, (float**)gauge, (float*)in, (float)kappa, dagger_bit);
00271 }
00272 
00273 // Apply the even-odd preconditioned Dirac operator
00274 template <typename sFloat, typename gFloat>
00275 void MatPC(sFloat *outEven, gFloat **gauge, sFloat *inEven, sFloat kappa, 
00276            int daggerBit, QudaMatPCType matpc_type) {
00277   
00278   sFloat *tmp = (sFloat*)malloc(Vh*spinorSiteSize*sizeof(sFloat));
00279     
00280   // full dslash operator
00281   if (matpc_type == QUDA_MATPC_EVEN_EVEN) {
00282     dslashReference(tmp, gauge, inEven, 1, daggerBit);
00283     dslashReference(outEven, gauge, tmp, 0, daggerBit);
00284   } else {
00285     dslashReference(tmp, gauge, inEven, 0, daggerBit);
00286     dslashReference(outEven, gauge, tmp, 1, daggerBit);
00287   }    
00288   
00289   // lastly apply the kappa term
00290   sFloat kappa2 = -kappa*kappa;
00291   xpay(inEven, kappa2, outEven, Vh*spinorSiteSize);
00292   free(tmp);
00293 }
00294 
00295 void matpc(void *outEven, void **gauge, void *inEven, double kappa, 
00296            QudaMatPCType matpc_type, int dagger_bit, QudaPrecision sPrecision, QudaPrecision gPrecision) {
00297 
00298   if (sPrecision == QUDA_DOUBLE_PRECISION)
00299     if (gPrecision == QUDA_DOUBLE_PRECISION) 
00300       MatPC((double*)outEven, (double**)gauge, (double*)inEven, (double)kappa, dagger_bit, matpc_type);
00301     else
00302       MatPC((double*)outEven, (float**)gauge, (double*)inEven, (double)kappa, dagger_bit, matpc_type);
00303   else
00304     if (gPrecision == QUDA_DOUBLE_PRECISION) 
00305       MatPC((float*)outEven, (double**)gauge, (float*)inEven, (float)kappa, dagger_bit, matpc_type);
00306     else
00307       MatPC((float*)outEven, (float**)gauge, (float*)inEven, (float)kappa, dagger_bit, matpc_type);
00308 }
 All Classes Files Functions Variables Typedefs Enumerations Enumerator Friends Defines