QUDA v0.4.0
A library for QCD on GPUs
quda/tests/domain_wall_dslash_reference.cpp
Go to the documentation of this file.
00001 #include <iostream>
00002 #include <stdio.h>
00003 #include <stdlib.h>
00004 #include <math.h>
00005 
00006 #include <quda.h>
00007 #include <test_util.h>
00008 #include <domain_wall_dslash_reference.h>
00009 #include <blas_reference.h>
00010 
// Lattice geometry shared by every reference routine in this file.
// These are plain globals written only by setDims(); the other routines
// read them, so setDims() must be called first.
int Z[4];   // 4-d extents, Z[0] is the fastest-running coordinate (x1)
int V;      // 4-d volume, Z[0]*Z[1]*Z[2]*Z[3]
int Vh;     // 4-d single-parity (half) volume, V/2

int Ls;     // extent of the fifth dimension
int V5;     // 5-d volume, V*Ls
int V5h;    // 5-d single-parity (half) volume, Vh*Ls

// setDims()
// Store the 4-d extents X[0..3] and the fifth-dimension extent L5 in the
// globals above, and derive the full and half volumes from them.
void setDims(int *X, const int L5) {
  int vol = 1;
  for (int mu = 0; mu < 4; ++mu) {
    Z[mu] = X[mu];
    vol *= X[mu];
  }
  V = vol;
  Vh = vol / 2;

  Ls = L5;
  V5 = V * Ls;
  V5h = Vh * Ls;
}
00031 
// sum(): element-wise dst[k] = a[k] + b[k] for k in [0, cnt).
// dst may be the very same array as a or b; the dslash kernels use this to
// accumulate in place.
template <typename Float>
void sum(Float *dst, Float *a, Float *b, int cnt) {
  int k = 0;
  while (k < cnt) {
    dst[k] = a[k] + b[k];
    ++k;
  }
}
00037 
// product(): scale an array, dst[k] = a * b[k] for k in [0, cnt).
// dst may be the same array as b (in-place scaling).
template <typename Float>
void product(Float *dst, Float a, Float *b, int cnt) {
  for (int k = 0; k < cnt; ++k) {
    const Float scaled = a * b[k];
    dst[k] = scaled;
  }
}
00043 
// xpay(): y[k] = x[k] + a*y[k] for k in [0, len); y is updated in place.
template <typename Float>
void xpay(Float *x, Float a, Float *y, int len) {
  Float *xp = x;
  Float *yp = y;
  for (int remaining = len; remaining > 0; --remaining, ++xp, ++yp) {
    *yp = *xp + a * *yp;
  }
}
00049 
00050 
00051 // i represents a "half index" into an even or odd "half lattice".
00052 // when oddBit={0,1} the half lattice is {even,odd}.
00053 // 
00054 // the displacements, such as dx, refer to the full lattice coordinates. 
00055 //
00056 // neighborIndex() takes a "half index", displaces it, and returns the
00057 // new "half index", which can be an index into either the even or odd lattices.
00058 // displacements of magnitude one always interchange odd and even lattices.
00059 //
00060 //
00061 int neighborIndex_5d(int i, int oddBit, int dxs, int dx4, int dx3, int dx2, int dx1) {
00062   // fullLatticeIndex was modified for fullLatticeIndex_4d.  It is in util_quda.cpp.
00063   // This code bit may not properly perform 5dPC.
00064   int X = fullLatticeIndex_5d(i, oddBit);
00065   // Checked that this matches code in dslash_core_ante.h.
00066   int xs = X/(Z[3]*Z[2]*Z[1]*Z[0]);
00067   int x4 = (X/(Z[2]*Z[1]*Z[0])) % Z[3];
00068   int x3 = (X/(Z[1]*Z[0])) % Z[2];
00069   int x2 = (X/Z[0]) % Z[1];
00070   int x1 = X % Z[0];
00071   // Displace and project back into domain 0,...,Ls-1.
00072   // Note that we add Ls to avoid the negative problem
00073   // of the C % operator.
00074   xs = (xs+dxs+Ls) % Ls;
00075   // Etc.
00076   x4 = (x4+dx4+Z[3]) % Z[3];
00077   x3 = (x3+dx3+Z[2]) % Z[2];
00078   x2 = (x2+dx2+Z[1]) % Z[1];
00079   x1 = (x1+dx1+Z[0]) % Z[0];
00080   // Return linear half index.  Remember that integer division
00081   // rounds down.
00082   return (xs*(Z[3]*Z[2]*Z[1]*Z[0]) + x4*(Z[2]*Z[1]*Z[0]) + x3*(Z[1]*Z[0]) + x2*(Z[0]) + x1) / 2;
00083 }
00084 
00085 // i represents a "half index" into an even or odd "half lattice".
00086 // when oddBit={0,1} the half lattice is {even,odd}.
00087 // 
00088 // the displacements, such as dx, refer to the full lattice coordinates. 
00089 //
00090 // neighborIndex() takes a "half index", displaces it, and returns the
00091 // new "half index", which can be an index into either the even or odd lattices.
00092 // displacements of magnitude one always interchange odd and even lattices.
00093 //
00094 //
00095 int neighborIndex_4d(int i, int oddBit, int dx4, int dx3, int dx2, int dx1) {
00096   // On input i should be in the range [0 , ... , Z[0]*Z[1]*Z[2]*Z[3]/2-1].
00097   if (i < 0 || i >= (Z[0]*Z[1]*Z[2]*Z[3]/2)) 
00098     { printf("i out of range in neighborIndex_4d\n"); exit(-1); }
00099   // Compute the linear index.  Then dissect.
00100   // fullLatticeIndex_4d is in util_quda.cpp.
00101   // The gauge fields live on a 4d sublattice.  
00102   int X = fullLatticeIndex_4d(i, oddBit);
00103   int x4 = X/(Z[2]*Z[1]*Z[0]);
00104   int x3 = (X/(Z[1]*Z[0])) % Z[2];
00105   int x2 = (X/Z[0]) % Z[1];
00106   int x1 = X % Z[0];
00107   
00108   x4 = (x4+dx4+Z[3]) % Z[3];
00109   x3 = (x3+dx3+Z[2]) % Z[2];
00110   x2 = (x2+dx2+Z[1]) % Z[1];
00111   x1 = (x1+dx1+Z[0]) % Z[0];
00112   
00113   return (x4*(Z[2]*Z[1]*Z[0]) + x3*(Z[1]*Z[0]) + x2*(Z[0]) + x1) / 2;
00114 }
00115 
00116 // This is just a copy of gaugeLink() from the quda code, except
00117 // that neighborIndex() is replaced by the renamed version
00118 // neighborIndex_4d().
00119 //ok
00120 template <typename Float>
00121 Float *gaugeLink(int i, int dir, int oddBit, Float **gaugeEven,
00122                 Float **gaugeOdd) {
00123   Float **gaugeField;
00124   int j;
00125   
00126   // If going forward, just grab link at site, U_\mu(x).
00127   if (dir % 2 == 0) {
00128     j = i;
00129     // j will get used in the return statement below.
00130     gaugeField = (oddBit ? gaugeOdd : gaugeEven);
00131   } else {
00132     // If going backward, a shift must occur, U_\mu(x-\muhat)^\dagger;
00133     // dagger happens elsewhere, here we're just doing index gymnastics.
00134     switch (dir) {
00135     case 1: j = neighborIndex_4d(i, oddBit, 0, 0, 0, -1); break;
00136     case 3: j = neighborIndex_4d(i, oddBit, 0, 0, -1, 0); break;
00137     case 5: j = neighborIndex_4d(i, oddBit, 0, -1, 0, 0); break;
00138     case 7: j = neighborIndex_4d(i, oddBit, -1, 0, 0, 0); break;
00139     default: j = -1; break;
00140     }
00141     gaugeField = (oddBit ? gaugeEven : gaugeOdd);
00142   }
00143   
00144   return &gaugeField[dir/2][j*(3*3*2)];
00145 }
00146 
00147 template <typename Float>
00148 Float *spinorNeighbor_5d(int i, int dir, int oddBit, Float *spinorField) {
00149   int j;
00150   switch (dir) {
00151   case 0: j = neighborIndex_5d(i, oddBit, 0, 0, 0, 0, +1); break;
00152   case 1: j = neighborIndex_5d(i, oddBit, 0, 0, 0, 0, -1); break;
00153   case 2: j = neighborIndex_5d(i, oddBit, 0, 0, 0, +1, 0); break;
00154   case 3: j = neighborIndex_5d(i, oddBit, 0, 0, 0, -1, 0); break;
00155   case 4: j = neighborIndex_5d(i, oddBit, 0, 0, +1, 0, 0); break;
00156   case 5: j = neighborIndex_5d(i, oddBit, 0, 0, -1, 0, 0); break;
00157   case 6: j = neighborIndex_5d(i, oddBit, 0, +1, 0, 0, 0); break;
00158   case 7: j = neighborIndex_5d(i, oddBit, 0, -1, 0, 0, 0); break;
00159   case 8: j = neighborIndex_5d(i, oddBit, +1, 0, 0, 0, 0); break;
00160   case 9: j = neighborIndex_5d(i, oddBit, -1, 0, 0, 0, 0); break;
00161   default: j = -1; break;
00162   }
00163   
00164   return &spinorField[j*(4*3*2)];
00165 }
00166 
00167 
// dot()
// Complex row-times-column product of two 3-component complex vectors,
// without conjugation: res = sum_c a[c] * b[c].  Vectors are stored as
// interleaved (re, im) pairs.  Arithmetic is done in the spinor precision
// sFloat; a's gFloat entries are converted to sFloat first.
template <typename sFloat, typename gFloat>
void dot(sFloat* res, gFloat* a, sFloat* b) {
  sFloat *re = &res[0];
  sFloat *im = &res[1];
  *im = 0;
  *re = 0;
  for (int c = 0; c < 3; ++c) {
    const sFloat ar = a[2*c + 0];
    const sFloat ai = a[2*c + 1];
    const sFloat br = b[2*c + 0];
    const sFloat bi = b[2*c + 1];
    *re += ar * br - ai * bi;
    *im += ar * bi + ai * br;
  }
}
00180 
// su3Transpose()
// Despite the name, this writes the conjugate transpose (Hermitian adjoint)
// of the 3x3 complex matrix mat into res: res(row,col) = conj(mat(col,row)).
// Matrices are row-major with interleaved (re, im) pairs (3*3*2 values).
template <typename Float>
void su3Transpose(Float *res, Float *mat) {
  for (int row = 0; row < 3; ++row) {
    Float *out = res + row*(3*2);
    for (int col = 0; col < 3; ++col) {
      const Float *in = mat + col*(3*2) + row*2;
      out[col*2 + 0] = + in[0];
      out[col*2 + 1] = - in[1];
    }
  }
}
00190 
// su3Mul()
// res = mat * vec for a 3x3 complex matrix (row-major, gFloat precision)
// and a 3-component complex vector (sFloat precision).  Each output
// component is the dot() of one matrix row with vec.
template <typename sFloat, typename gFloat>
void su3Mul(sFloat *res, gFloat *mat, sFloat *vec) {
  for (int row = 0; row < 3; ++row) {
    gFloat *matRow = mat + row*(3*2);
    dot(res + row*2, matRow, vec);
  }
}
00195 
// su3Tmul()
// res = mat^dagger * vec: forms the conjugate transpose of mat with
// su3Transpose() into a local copy, then multiplies with su3Mul().
// Used for backward hops, which need U^dagger.
template <typename sFloat, typename gFloat>
void su3Tmul(sFloat *res, gFloat *mat, sFloat *vec) {
  gFloat adjoint[3*3*2];
  su3Transpose(adjoint, mat);
  su3Mul(res, adjoint, vec);
}
00202 
//J  Directions 0..7 were used in the 4d code.
//J  Directions 8,9 are for the chiral projectors: index 8 is P_+ (P_R),
//J  index 9 is P_- (P_L), matching the labels inside the table below.
//
// Layout: projector[projIdx][row][col][re/im].  multiplySpinorByDiracProjector()
// computes res[row] += projector[projIdx][row][col] * spinor[col] over the 4
// spin components.  For 0..7, entries 2*mu and 2*mu+1 have unit diagonal and
// off-diagonal parts of opposite sign, i.e. the form (1 -/+ gamma) for the
// hop direction mu (the gamma-matrix basis is not stated here).
// Entries 8 and 9 carry a factor 2 on their non-zero diagonal
// (2 * P_+/-, i.e. 1 +/- gamma_5 in this basis).
const double projector[10][4][4][2] = {
  {
    {{1,0}, {0,0}, {0,0}, {0,-1}},
    {{0,0}, {1,0}, {0,-1}, {0,0}},
    {{0,0}, {0,1}, {1,0}, {0,0}},
    {{0,1}, {0,0}, {0,0}, {1,0}}
  },
  {
    {{1,0}, {0,0}, {0,0}, {0,1}},
    {{0,0}, {1,0}, {0,1}, {0,0}},
    {{0,0}, {0,-1}, {1,0}, {0,0}},
    {{0,-1}, {0,0}, {0,0}, {1,0}}
  },
  {
    {{1,0}, {0,0}, {0,0}, {1,0}},
    {{0,0}, {1,0}, {-1,0}, {0,0}},
    {{0,0}, {-1,0}, {1,0}, {0,0}},
    {{1,0}, {0,0}, {0,0}, {1,0}}
  },
  {
    {{1,0}, {0,0}, {0,0}, {-1,0}},
    {{0,0}, {1,0}, {1,0}, {0,0}},
    {{0,0}, {1,0}, {1,0}, {0,0}},
    {{-1,0}, {0,0}, {0,0}, {1,0}}
  },
  {
    {{1,0}, {0,0}, {0,-1}, {0,0}},
    {{0,0}, {1,0}, {0,0}, {0,1}},
    {{0,1}, {0,0}, {1,0}, {0,0}},
    {{0,0}, {0,-1}, {0,0}, {1,0}}
  },
  {
    {{1,0}, {0,0}, {0,1}, {0,0}},
    {{0,0}, {1,0}, {0,0}, {0,-1}},
    {{0,-1}, {0,0}, {1,0}, {0,0}},
    {{0,0}, {0,1}, {0,0}, {1,0}}
  },
  {
    {{1,0}, {0,0}, {-1,0}, {0,0}},
    {{0,0}, {1,0}, {0,0}, {-1,0}},
    {{-1,0}, {0,0}, {1,0}, {0,0}},
    {{0,0}, {-1,0}, {0,0}, {1,0}}
  },
  {
    {{1,0}, {0,0}, {1,0}, {0,0}},
    {{0,0}, {1,0}, {0,0}, {1,0}},
    {{1,0}, {0,0}, {1,0}, {0,0}},
    {{0,0}, {1,0}, {0,0}, {1,0}}
  },
  // P_+ = P_R (index 8): used by the forward s-hop when daggerBit = 0
  {
    {{2,0}, {0,0}, {0,0}, {0,0}},
    {{0,0}, {2,0}, {0,0}, {0,0}},
    {{0,0}, {0,0}, {0,0}, {0,0}},
    {{0,0}, {0,0}, {0,0}, {0,0}}
  },
  // P_- = P_L (index 9): used by the backward s-hop when daggerBit = 0
  {
    {{0,0}, {0,0}, {0,0}, {0,0}},
    {{0,0}, {0,0}, {0,0}, {0,0}},
    {{0,0}, {0,0}, {2,0}, {0,0}},
    {{0,0}, {0,0}, {0,0}, {2,0}}
  }
};
00270 
00271 
00272 // todo pass projector
00273 template <typename Float>
00274 void multiplySpinorByDiracProjector(Float *res, int projIdx, Float *spinorIn) {
00275   for (int i=0; i<4*3*2; i++) res[i] = 0.0;
00276 
00277   for (int s = 0; s < 4; s++) {
00278     for (int t = 0; t < 4; t++) {
00279       Float projRe = projector[projIdx][s][t][0];
00280       Float projIm = projector[projIdx][s][t][1];
00281       
00282       for (int m = 0; m < 3; m++) {
00283         Float spinorRe = spinorIn[t*(3*2) + m*(2) + 0];
00284         Float spinorIm = spinorIn[t*(3*2) + m*(2) + 1];
00285         res[s*(3*2) + m*(2) + 0] += projRe*spinorRe - projIm*spinorIm;
00286         res[s*(3*2) + m*(2) + 1] += projRe*spinorIm + projIm*spinorRe;
00287       }
00288     }
00289   }
00290 }
00291 
00292 
00293 
00294 // dslashReference_4d()
00295 //J  This is just the 4d wilson dslash of quda code, with a
00296 //J  few small changes to take into account that the spinors
00297 //J  are 5d and the gauge fields are 4d.
00298 //
00299 // if oddBit is zero: calculate odd parity spinor elements (using even parity spinor)
00300 // if oddBit is one:  calculate even parity spinor elements
00301 //
00302 // if daggerBit is zero: perform ordinary dslash operator
00303 // if daggerBit is one:  perform hermitian conjugate of dslash
00304 //
00305 //An "ok" will only be granted once check2.tex is deemed complete,
00306 //since the logic in this function is important and nontrivial.
00307 template <typename sFloat, typename gFloat>
00308 void dslashReference_4d(sFloat *res, gFloat **gaugeFull, sFloat *spinorField, 
00309                 int oddBit, int daggerBit) {
00310   
00311   // Initialize the return half-spinor to zero.  Note that it is a
00312   // 5d spinor, hence the use of V5h.
00313   for (int i=0; i<V5h*4*3*2; i++) res[i] = 0.0;
00314   
00315   // Some pointers that we use to march through arrays.
00316   gFloat *gaugeEven[4], *gaugeOdd[4];
00317   // Initialize to beginning of even and odd parts of
00318   // gauge array.
00319   for (int dir = 0; dir < 4; dir++) {  
00320     gaugeEven[dir] = gaugeFull[dir];
00321     // Note the use of Vh here, since the gauge fields
00322     // are 4-dim'l.
00323     gaugeOdd[dir]  = gaugeFull[dir]+Vh*gaugeSiteSize;
00324   }
00325   int sp_idx,oddBit_gge;
00326   for (int xs=0;xs<Ls;xs++) {
00327     for (int gge_idx = 0; gge_idx < Vh; gge_idx++) {
00328       for (int dir = 0; dir < 8; dir++) {
00329         sp_idx=gge_idx+Vh*xs;
00330         // Here is a function call to study.  It is defined near
00331         // Line 90 of this file.
00332         // Here we have to switch oddBit depending on the value of xs.  E.g., suppose
00333         // xs=1.  Then the odd spinor site x1=x2=x3=x4=0 wants the even gauge array
00334         // element 0, so that we get U_\mu(0).
00335         if ((xs % 2) == 0) oddBit_gge=oddBit;
00336         else oddBit_gge= (oddBit+1) % 2;
00337         gFloat *gauge = gaugeLink(gge_idx, dir, oddBit_gge, gaugeEven, gaugeOdd);
00338         
00339         // Even though we're doing the 4d part of the dslash, we need
00340         // to use a 5d neighbor function, to get the offsets right.
00341         sFloat *spinor = spinorNeighbor_5d(sp_idx, dir, oddBit, spinorField);
00342       
00343         sFloat projectedSpinor[4*3*2], gaugedSpinor[4*3*2];
00344         int projIdx = 2*(dir/2)+(dir+daggerBit)%2;
00345         multiplySpinorByDiracProjector(projectedSpinor, projIdx, spinor);
00346       
00347         for (int s = 0; s < 4; s++) {
00348                 if (dir % 2 == 0) {
00349                   su3Mul(&gaugedSpinor[s*(3*2)], gauge, &projectedSpinor[s*(3*2)]);
00350 #ifdef DBUG_VERBOSE            
00351                   std::cout << "spinor:" << std::endl;
00352                   printSpinorElement(&projectedSpinor[s*(3*2)],0,QUDA_DOUBLE_PRECISION);
00353                   std::cout << "gauge:" << std::endl;
00354 #endif
00355           } else {
00356                   su3Tmul(&gaugedSpinor[s*(3*2)], gauge, &projectedSpinor[s*(3*2)]);
00357           }
00358         }
00359       
00360         sum(&res[sp_idx*(4*3*2)], &res[sp_idx*(4*3*2)], gaugedSpinor, 4*3*2);
00361       }
00362     }
00363   }
00364 }
00365 
00366 template <typename sFloat>
00367 void dslashReference_5th(sFloat *res, sFloat *spinorField, 
00368                 int oddBit, int daggerBit, sFloat mferm) {
00369   for (int i = 0; i < V5h; i++) {
00370     for (int dir = 8; dir < 10; dir++) {
00371       // Calls for an extension of the original function.
00372       // 8 is forward hop, which wants P_+, 9 is backward hop,
00373       // which wants P_-.  Dagger reverses these.
00374       sFloat *spinor = spinorNeighbor_5d(i, dir, oddBit, spinorField);
00375       sFloat projectedSpinor[4*3*2];
00376       int projIdx = 2*(dir/2)+(dir+daggerBit)%2;
00377       multiplySpinorByDiracProjector(projectedSpinor, projIdx, spinor);
00378       //J  Need a conditional here for s=0 and s=Ls-1.
00379       int X = fullLatticeIndex_5d(i, oddBit);
00380       int xs = X/(Z[3]*Z[2]*Z[1]*Z[0]);
00381       if ( (xs == 0 && dir == 9) || (xs == Ls-1 && dir == 8) ) {
00382         product(projectedSpinor,(sFloat)(-mferm),projectedSpinor,4*3*2);
00383       } 
00384       sum(&res[i*(4*3*2)], &res[i*(4*3*2)], projectedSpinor, 4*3*2);
00385     }
00386   }
00387 }
00388 
00389 // Recall that dslash is only the off-diagonal parts, so m0_dwf is not needed.
00390 //
00391 void dslash(void *res, void **gaugeFull, void *spinorField, 
00392             int oddBit, int daggerBit, 
00393             QudaPrecision sPrecision, QudaPrecision gPrecision, double mferm) {
00394   
00395   if (sPrecision == QUDA_DOUBLE_PRECISION)  {
00396     if (gPrecision == QUDA_DOUBLE_PRECISION) {
00397       // Do the 4d part, which hasn't changed.
00398       printf("doing 4d part\n"); fflush(stdout);
00399       dslashReference_4d<double,double>((double*)res, (double**)gaugeFull,
00400                       (double*)spinorField, oddBit, daggerBit);
00401       // Now add in the 5th dim.
00402       printf("doing 5th dimen. part\n"); fflush(stdout);
00403       dslashReference_5th<double>((double*)res, (double*)spinorField, 
00404                       oddBit, daggerBit, mferm);
00405     } else {
00406       dslashReference_4d<double,float>((double*)res, (float**)gaugeFull, (double*)spinorField, oddBit, daggerBit);
00407       dslashReference_5th<double>((double*)res, (double*)spinorField, oddBit, daggerBit, mferm);
00408     }
00409   } else {
00410     // Single-precision spinor.
00411     if (gPrecision == QUDA_DOUBLE_PRECISION) {
00412       dslashReference_4d<float,double>((float*)res, (double**)gaugeFull, (float*)spinorField, oddBit, daggerBit);
00413       dslashReference_5th<float>((float*)res, (float*)spinorField, oddBit, daggerBit, mferm);
00414     } else {
00415       // Do the 4d part, which hasn't changed.
00416       printf("CPU reference:  doing 4d part all single precision\n"); fflush(stdout);
00417       dslashReference_4d<float,float>((float*)res, (float**)gaugeFull, (float*)spinorField, oddBit, daggerBit);
00418       // Now add in the 5th dim.
00419       printf("CPU reference:  doing 5th dimen. part all single precision\n"); fflush(stdout);
00420       dslashReference_5th<float>((float*)res, (float*)spinorField, oddBit, daggerBit, mferm);
00421     }
00422   }
00423 }
00424 
00425 
00426 template <typename sFloat, typename gFloat>
00427 void Mat(sFloat *out, gFloat **gauge, sFloat *in, sFloat kappa, sFloat mferm) {
00428   sFloat *inEven = in;
00429   sFloat *inOdd  = in + V5h*spinorSiteSize;
00430   sFloat *outEven = out;
00431   sFloat *outOdd = out + V5h*spinorSiteSize;
00432   
00433   // full dslash operator
00434   dslashReference_4d(outOdd, gauge, inEven, 1, 0);
00435   dslashReference_5th(outOdd, inEven, 1, 0, mferm);
00436   dslashReference_4d(outEven, gauge, inOdd, 0, 0);
00437   dslashReference_5th(outEven, inOdd, 0, 0, mferm);
00438   
00439   // lastly apply the kappa term
00440   xpay(in, -kappa, out, V5*spinorSiteSize);
00441 }
00442 
00443 template <typename sFloat, typename gFloat>
00444 void MatDag(sFloat *out, gFloat **gauge, sFloat *in, sFloat kappa, sFloat mferm) {
00445   sFloat *inEven = in;
00446   sFloat *inOdd  = in + V5h*spinorSiteSize;
00447   sFloat *outEven = out;
00448   sFloat *outOdd = out + V5h*spinorSiteSize;
00449   
00450   // full dslash operator
00451   dslashReference_4d(outOdd, gauge, inEven, 1, 1);
00452   dslashReference_5th(outOdd, inEven, 1, 1, mferm);
00453   dslashReference_4d(outEven, gauge, inOdd, 0, 1);
00454   dslashReference_5th(outEven, inOdd, 0, 1, mferm);
00455   
00456   // lastly apply the kappa term
00457   xpay(in, -kappa, out, V5*spinorSiteSize);
00458 }
00459 
00460 void mat(void *out, void **gauge, void *in, double kappa, int dagger_bit, 
00461          QudaPrecision sPrecision, QudaPrecision gPrecision, double mferm) {
00462   if (!dagger_bit) {
00463     if (sPrecision == QUDA_DOUBLE_PRECISION)
00464       if (gPrecision == QUDA_DOUBLE_PRECISION) Mat((double*)out, (double**)gauge, (double*)in, (double)kappa,
00465                       (double)mferm);
00466       else Mat((double*)out, (float**)gauge, (double*)in, (double)kappa, (double)mferm);
00467     else
00468       if (gPrecision == QUDA_DOUBLE_PRECISION) Mat((float*)out, (double**)gauge, (float*)in, (float)kappa,
00469                       (float)mferm);
00470       else Mat((float*)out, (float**)gauge, (float*)in, (float)kappa, (float)mferm);
00471   } else {
00472     if (sPrecision == QUDA_DOUBLE_PRECISION)
00473       if (gPrecision == QUDA_DOUBLE_PRECISION) MatDag((double*)out, (double**)gauge, (double*)in, (double)kappa,
00474                       (double)mferm);
00475       else MatDag((float*)out, (double**)gauge, (float*)in, (float)kappa, (float)mferm);
00476     else
00477       if (gPrecision == QUDA_DOUBLE_PRECISION) MatDag((float*)out, (double**)gauge, (float*)in, (float)kappa,
00478                       (float)mferm);
00479       else MatDag((float*)out, (float**)gauge, (float*)in, (float)kappa, (float)mferm);
00480   }
00481 }
00482 
00483 // Apply the even-odd preconditioned Dirac operator
00484 template <typename sFloat, typename gFloat>
00485 void MatPC(sFloat *outEven, gFloat **gauge, sFloat *inEven, sFloat kappa,
00486            QudaMatPCType matpc_type, sFloat mferm) {
00487   
00488   sFloat *tmp = (sFloat*)malloc(V5h*spinorSiteSize*sizeof(sFloat));
00489     
00490   // full dslash operator
00491   if (matpc_type == QUDA_MATPC_EVEN_EVEN) {
00492     dslashReference_4d(tmp, gauge, inEven, 1, 0);
00493     dslashReference_5th(tmp, inEven, 1, 0, mferm);
00494     dslashReference_4d(outEven, gauge, tmp, 0, 0);
00495     dslashReference_5th(outEven, tmp, 0, 0, mferm);
00496   } else {
00497     dslashReference_4d(tmp, gauge, inEven, 0, 0);
00498     dslashReference_5th(tmp, inEven, 0, 0, mferm);
00499     dslashReference_4d(outEven, gauge, tmp, 1, 0);
00500     dslashReference_5th(outEven, tmp, 1, 0, mferm);
00501   }    
00502   
00503   // lastly apply the kappa term
00504   sFloat kappa2 = -kappa*kappa;
00505   xpay(inEven, kappa2, outEven, V5h*spinorSiteSize);
00506   free(tmp);
00507 }
00508 
00509 // Apply the even-odd preconditioned Dirac operator
00510 template <typename sFloat, typename gFloat>
00511 void MatPCDag(sFloat *outEven, gFloat **gauge, sFloat *inEven, sFloat kappa, 
00512               QudaMatPCType matpc_type, sFloat mferm) {
00513   
00514   sFloat *tmp = (sFloat*)malloc(V5h*spinorSiteSize*sizeof(sFloat));    
00515   
00516   // full dslash operator
00517   if (matpc_type == QUDA_MATPC_EVEN_EVEN) {
00518     dslashReference_4d(tmp, gauge, inEven, 1, 1);
00519     dslashReference_5th(tmp, inEven, 1, 1, mferm);
00520     dslashReference_4d(outEven, gauge, tmp, 0, 1);
00521     dslashReference_5th(outEven, tmp, 0, 1, mferm);
00522   } else {
00523     dslashReference_4d(tmp, gauge, inEven, 0, 1);
00524     dslashReference_5th(tmp, inEven, 0, 1, mferm);
00525     dslashReference_4d(outEven, gauge, tmp, 1, 1);
00526     dslashReference_5th(outEven, tmp, 1, 1, mferm);
00527   }
00528   
00529   sFloat kappa2 = -kappa*kappa;
00530   xpay(inEven, kappa2, outEven, V5h*spinorSiteSize);
00531   free(tmp);
00532 }
00533 
00534 void matpc(void *outEven, void **gauge, void *inEven, double kappa, 
00535            QudaMatPCType matpc_type, int dagger_bit, QudaPrecision sPrecision, QudaPrecision gPrecision,
00536      double mferm) {
00537   if (!dagger_bit) {
00538     if (sPrecision == QUDA_DOUBLE_PRECISION)
00539       if (gPrecision == QUDA_DOUBLE_PRECISION) 
00540         MatPC((double*)outEven, (double**)gauge, (double*)inEven, (double)kappa, matpc_type, (double)mferm);
00541       else
00542         MatPC((double*)outEven, (float**)gauge, (double*)inEven, (double)kappa, matpc_type, (double)mferm);
00543     else
00544       if (gPrecision == QUDA_DOUBLE_PRECISION) 
00545         MatPC((float*)outEven, (double**)gauge, (float*)inEven, (float)kappa, matpc_type, (float)mferm);
00546       else
00547         MatPC((float*)outEven, (float**)gauge, (float*)inEven, (float)kappa, matpc_type, (float)mferm);
00548   } else {
00549     if (sPrecision == QUDA_DOUBLE_PRECISION)
00550       if (gPrecision == QUDA_DOUBLE_PRECISION) 
00551         MatPCDag((double*)outEven, (double**)gauge, (double*)inEven, (double)kappa, matpc_type, (double)mferm);
00552       else
00553         MatPCDag((double*)outEven, (float**)gauge, (double*)inEven, (double)kappa, matpc_type, (double)mferm);
00554     else
00555       if (gPrecision == QUDA_DOUBLE_PRECISION) 
00556         MatPCDag((float*)outEven, (double**)gauge, (float*)inEven, (float)kappa, matpc_type, (float)mferm);
00557       else
00558         MatPCDag((float*)outEven, (float**)gauge, (float*)inEven, (float)kappa, matpc_type, (float)mferm);
00559   }
00560 }
00561 
00562 
00563 template <typename sFloat, typename gFloat> 
00564 void MatDagMat(sFloat *out, gFloat **gauge, sFloat *in, sFloat kappa, sFloat mferm) 
00565 {
00566   // Allocate a full spinor.        
00567   sFloat *tmp = (sFloat*)malloc(V5*spinorSiteSize*sizeof(sFloat));
00568   // Call templates above.
00569   Mat(tmp, gauge, in, kappa, mferm);
00570   MatDag(out, gauge, tmp, kappa, mferm);
00571   free(tmp);
00572 }
00573 
00574 template <typename sFloat, typename gFloat> 
00575 void MatPCDagMatPC(sFloat *out, gFloat **gauge, sFloat *in, sFloat kappa, 
00576                    QudaMatPCType matpc_type, sFloat mferm)
00577 {
00578   
00579   // Allocate half spinor
00580   sFloat *tmp = (sFloat*)malloc(V5h*spinorSiteSize*sizeof(sFloat));
00581   // Apply the PC templates above
00582   MatPC(tmp, gauge, in, kappa, matpc_type, mferm);
00583   MatPCDag(out, gauge, tmp, kappa, matpc_type, mferm);
00584   free(tmp);
00585 }
00586 
00587 // Wrapper to templates that handles different precisions.
00588 void matdagmat(void *out, void **gauge, void *in, double kappa,
00589          QudaPrecision sPrecision, QudaPrecision gPrecision, double mferm) 
00590 {
00591   if (sPrecision == QUDA_DOUBLE_PRECISION) {
00592     if (gPrecision == QUDA_DOUBLE_PRECISION) 
00593       MatDagMat((double*)out, (double**)gauge, (double*)in, (double)kappa,
00594           (double)mferm);
00595     else 
00596       MatDagMat((double*)out, (float**)gauge, (double*)in, (double)kappa, (double)mferm);
00597   } else {
00598     if (gPrecision == QUDA_DOUBLE_PRECISION) 
00599       MatDagMat((float*)out, (double**)gauge, (float*)in, (float)kappa,
00600           (float)mferm);
00601     else 
00602       MatDagMat((float*)out, (float**)gauge, (float*)in, (float)kappa, (float)mferm);
00603   }
00604 }
00605 
00606 // Wrapper to templates that handles different precisions.
00607 void matpcdagmatpc(void *out, void **gauge, void *in, double kappa,
00608          QudaPrecision sPrecision, QudaPrecision gPrecision, double mferm, QudaMatPCType matpc_type) 
00609 {
00610   if (sPrecision == QUDA_DOUBLE_PRECISION) {
00611     if (gPrecision == QUDA_DOUBLE_PRECISION) 
00612       MatPCDagMatPC((double*)out, (double**)gauge, (double*)in, (double)kappa,
00613         matpc_type, (double)mferm);
00614     else 
00615       MatPCDagMatPC((double*)out, (float**)gauge, (double*)in, (double)kappa,
00616                       matpc_type, (double)mferm);
00617   } else {
00618     if (gPrecision == QUDA_DOUBLE_PRECISION) 
00619       MatPCDagMatPC((float*)out, (double**)gauge, (float*)in, (float)kappa,
00620         matpc_type, (float)mferm);
00621     else 
00622       MatPCDagMatPC((float*)out, (float**)gauge, (float*)in, (float)kappa, 
00623                       matpc_type, (float)mferm);
00624   }
00625 }
00626 
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines