// QUDA v0.4.0
// A library for QCD on GPUs
00001 #include <iostream> 00002 #include <stdio.h> 00003 #include <stdlib.h> 00004 #include <math.h> 00005 00006 #include <quda.h> 00007 #include <test_util.h> 00008 #include <domain_wall_dslash_reference.h> 00009 #include <blas_reference.h> 00010 00011 int Z[4]; 00012 int V; 00013 int Vh; 00014 00015 int Ls; 00016 int V5; 00017 int V5h; 00018 00019 void setDims(int *X, const int L5) { 00020 V = 1; 00021 for (int d=0; d<4; d++) { 00022 V *= X[d]; 00023 Z[d] = X[d]; 00024 } 00025 Vh = V/2; 00026 00027 Ls = L5; 00028 V5 = V*Ls; 00029 V5h = Vh*Ls; 00030 } 00031 00032 template <typename Float> 00033 void sum(Float *dst, Float *a, Float *b, int cnt) { 00034 for (int i = 0; i < cnt; i++) 00035 dst[i] = a[i] + b[i]; 00036 } 00037 00038 template <typename Float> 00039 void product(Float *dst, Float a, Float *b, int cnt) { 00040 for (int i = 0; i < cnt; i++) 00041 dst[i] = a * b[i]; 00042 } 00043 00044 // performs the operation y[i] = x[i] + a*y[i] 00045 template <typename Float> 00046 void xpay(Float *x, Float a, Float *y, int len) { 00047 for (int i=0; i<len; i++) y[i] = x[i] + a*y[i]; 00048 } 00049 00050 00051 // i represents a "half index" into an even or odd "half lattice". 00052 // when oddBit={0,1} the half lattice is {even,odd}. 00053 // 00054 // the displacements, such as dx, refer to the full lattice coordinates. 00055 // 00056 // neighborIndex() takes a "half index", displaces it, and returns the 00057 // new "half index", which can be an index into either the even or odd lattices. 00058 // displacements of magnitude one always interchange odd and even lattices. 00059 // 00060 // 00061 int neighborIndex_5d(int i, int oddBit, int dxs, int dx4, int dx3, int dx2, int dx1) { 00062 // fullLatticeIndex was modified for fullLatticeIndex_4d. It is in util_quda.cpp. 00063 // This code bit may not properly perform 5dPC. 00064 int X = fullLatticeIndex_5d(i, oddBit); 00065 // Checked that this matches code in dslash_core_ante.h. 
00066 int xs = X/(Z[3]*Z[2]*Z[1]*Z[0]); 00067 int x4 = (X/(Z[2]*Z[1]*Z[0])) % Z[3]; 00068 int x3 = (X/(Z[1]*Z[0])) % Z[2]; 00069 int x2 = (X/Z[0]) % Z[1]; 00070 int x1 = X % Z[0]; 00071 // Displace and project back into domain 0,...,Ls-1. 00072 // Note that we add Ls to avoid the negative problem 00073 // of the C % operator. 00074 xs = (xs+dxs+Ls) % Ls; 00075 // Etc. 00076 x4 = (x4+dx4+Z[3]) % Z[3]; 00077 x3 = (x3+dx3+Z[2]) % Z[2]; 00078 x2 = (x2+dx2+Z[1]) % Z[1]; 00079 x1 = (x1+dx1+Z[0]) % Z[0]; 00080 // Return linear half index. Remember that integer division 00081 // rounds down. 00082 return (xs*(Z[3]*Z[2]*Z[1]*Z[0]) + x4*(Z[2]*Z[1]*Z[0]) + x3*(Z[1]*Z[0]) + x2*(Z[0]) + x1) / 2; 00083 } 00084 00085 // i represents a "half index" into an even or odd "half lattice". 00086 // when oddBit={0,1} the half lattice is {even,odd}. 00087 // 00088 // the displacements, such as dx, refer to the full lattice coordinates. 00089 // 00090 // neighborIndex() takes a "half index", displaces it, and returns the 00091 // new "half index", which can be an index into either the even or odd lattices. 00092 // displacements of magnitude one always interchange odd and even lattices. 00093 // 00094 // 00095 int neighborIndex_4d(int i, int oddBit, int dx4, int dx3, int dx2, int dx1) { 00096 // On input i should be in the range [0 , ... , Z[0]*Z[1]*Z[2]*Z[3]/2-1]. 00097 if (i < 0 || i >= (Z[0]*Z[1]*Z[2]*Z[3]/2)) 00098 { printf("i out of range in neighborIndex_4d\n"); exit(-1); } 00099 // Compute the linear index. Then dissect. 00100 // fullLatticeIndex_4d is in util_quda.cpp. 00101 // The gauge fields live on a 4d sublattice. 
00102 int X = fullLatticeIndex_4d(i, oddBit); 00103 int x4 = X/(Z[2]*Z[1]*Z[0]); 00104 int x3 = (X/(Z[1]*Z[0])) % Z[2]; 00105 int x2 = (X/Z[0]) % Z[1]; 00106 int x1 = X % Z[0]; 00107 00108 x4 = (x4+dx4+Z[3]) % Z[3]; 00109 x3 = (x3+dx3+Z[2]) % Z[2]; 00110 x2 = (x2+dx2+Z[1]) % Z[1]; 00111 x1 = (x1+dx1+Z[0]) % Z[0]; 00112 00113 return (x4*(Z[2]*Z[1]*Z[0]) + x3*(Z[1]*Z[0]) + x2*(Z[0]) + x1) / 2; 00114 } 00115 00116 // This is just a copy of gaugeLink() from the quda code, except 00117 // that neighborIndex() is replaced by the renamed version 00118 // neighborIndex_4d(). 00119 //ok 00120 template <typename Float> 00121 Float *gaugeLink(int i, int dir, int oddBit, Float **gaugeEven, 00122 Float **gaugeOdd) { 00123 Float **gaugeField; 00124 int j; 00125 00126 // If going forward, just grab link at site, U_\mu(x). 00127 if (dir % 2 == 0) { 00128 j = i; 00129 // j will get used in the return statement below. 00130 gaugeField = (oddBit ? gaugeOdd : gaugeEven); 00131 } else { 00132 // If going backward, a shift must occur, U_\mu(x-\muhat)^\dagger; 00133 // dagger happens elsewhere, here we're just doing index gymnastics. 00134 switch (dir) { 00135 case 1: j = neighborIndex_4d(i, oddBit, 0, 0, 0, -1); break; 00136 case 3: j = neighborIndex_4d(i, oddBit, 0, 0, -1, 0); break; 00137 case 5: j = neighborIndex_4d(i, oddBit, 0, -1, 0, 0); break; 00138 case 7: j = neighborIndex_4d(i, oddBit, -1, 0, 0, 0); break; 00139 default: j = -1; break; 00140 } 00141 gaugeField = (oddBit ? 
gaugeEven : gaugeOdd); 00142 } 00143 00144 return &gaugeField[dir/2][j*(3*3*2)]; 00145 } 00146 00147 template <typename Float> 00148 Float *spinorNeighbor_5d(int i, int dir, int oddBit, Float *spinorField) { 00149 int j; 00150 switch (dir) { 00151 case 0: j = neighborIndex_5d(i, oddBit, 0, 0, 0, 0, +1); break; 00152 case 1: j = neighborIndex_5d(i, oddBit, 0, 0, 0, 0, -1); break; 00153 case 2: j = neighborIndex_5d(i, oddBit, 0, 0, 0, +1, 0); break; 00154 case 3: j = neighborIndex_5d(i, oddBit, 0, 0, 0, -1, 0); break; 00155 case 4: j = neighborIndex_5d(i, oddBit, 0, 0, +1, 0, 0); break; 00156 case 5: j = neighborIndex_5d(i, oddBit, 0, 0, -1, 0, 0); break; 00157 case 6: j = neighborIndex_5d(i, oddBit, 0, +1, 0, 0, 0); break; 00158 case 7: j = neighborIndex_5d(i, oddBit, 0, -1, 0, 0, 0); break; 00159 case 8: j = neighborIndex_5d(i, oddBit, +1, 0, 0, 0, 0); break; 00160 case 9: j = neighborIndex_5d(i, oddBit, -1, 0, 0, 0, 0); break; 00161 default: j = -1; break; 00162 } 00163 00164 return &spinorField[j*(4*3*2)]; 00165 } 00166 00167 00168 template <typename sFloat, typename gFloat> 00169 void dot(sFloat* res, gFloat* a, sFloat* b) { 00170 res[0] = res[1] = 0; 00171 for (int m = 0; m < 3; m++) { 00172 sFloat a_re = a[2*m+0]; 00173 sFloat a_im = a[2*m+1]; 00174 sFloat b_re = b[2*m+0]; 00175 sFloat b_im = b[2*m+1]; 00176 res[0] += a_re * b_re - a_im * b_im; 00177 res[1] += a_re * b_im + a_im * b_re; 00178 } 00179 } 00180 00181 template <typename Float> 00182 void su3Transpose(Float *res, Float *mat) { 00183 for (int m = 0; m < 3; m++) { 00184 for (int n = 0; n < 3; n++) { 00185 res[m*(3*2) + n*(2) + 0] = + mat[n*(3*2) + m*(2) + 0]; 00186 res[m*(3*2) + n*(2) + 1] = - mat[n*(3*2) + m*(2) + 1]; 00187 } 00188 } 00189 } 00190 00191 template <typename sFloat, typename gFloat> 00192 void su3Mul(sFloat *res, gFloat *mat, sFloat *vec) { 00193 for (int n = 0; n < 3; n++) dot(&res[n*(2)], &mat[n*(3*2)], vec); 00194 } 00195 00196 template <typename sFloat, typename gFloat> 00197 void 
su3Tmul(sFloat *res, gFloat *mat, sFloat *vec) { 00198 gFloat matT[3*3*2]; 00199 su3Transpose(matT, mat); 00200 su3Mul(res, matT, vec); 00201 } 00202 00203 //J Directions 0..7 were used in the 4d code. 00204 //J Directions 8,9 will be for P_- and P_+, chiral 00205 //J projectors. 00206 const double projector[10][4][4][2] = { 00207 { 00208 {{1,0}, {0,0}, {0,0}, {0,-1}}, 00209 {{0,0}, {1,0}, {0,-1}, {0,0}}, 00210 {{0,0}, {0,1}, {1,0}, {0,0}}, 00211 {{0,1}, {0,0}, {0,0}, {1,0}} 00212 }, 00213 { 00214 {{1,0}, {0,0}, {0,0}, {0,1}}, 00215 {{0,0}, {1,0}, {0,1}, {0,0}}, 00216 {{0,0}, {0,-1}, {1,0}, {0,0}}, 00217 {{0,-1}, {0,0}, {0,0}, {1,0}} 00218 }, 00219 { 00220 {{1,0}, {0,0}, {0,0}, {1,0}}, 00221 {{0,0}, {1,0}, {-1,0}, {0,0}}, 00222 {{0,0}, {-1,0}, {1,0}, {0,0}}, 00223 {{1,0}, {0,0}, {0,0}, {1,0}} 00224 }, 00225 { 00226 {{1,0}, {0,0}, {0,0}, {-1,0}}, 00227 {{0,0}, {1,0}, {1,0}, {0,0}}, 00228 {{0,0}, {1,0}, {1,0}, {0,0}}, 00229 {{-1,0}, {0,0}, {0,0}, {1,0}} 00230 }, 00231 { 00232 {{1,0}, {0,0}, {0,-1}, {0,0}}, 00233 {{0,0}, {1,0}, {0,0}, {0,1}}, 00234 {{0,1}, {0,0}, {1,0}, {0,0}}, 00235 {{0,0}, {0,-1}, {0,0}, {1,0}} 00236 }, 00237 { 00238 {{1,0}, {0,0}, {0,1}, {0,0}}, 00239 {{0,0}, {1,0}, {0,0}, {0,-1}}, 00240 {{0,-1}, {0,0}, {1,0}, {0,0}}, 00241 {{0,0}, {0,1}, {0,0}, {1,0}} 00242 }, 00243 { 00244 {{1,0}, {0,0}, {-1,0}, {0,0}}, 00245 {{0,0}, {1,0}, {0,0}, {-1,0}}, 00246 {{-1,0}, {0,0}, {1,0}, {0,0}}, 00247 {{0,0}, {-1,0}, {0,0}, {1,0}} 00248 }, 00249 { 00250 {{1,0}, {0,0}, {1,0}, {0,0}}, 00251 {{0,0}, {1,0}, {0,0}, {1,0}}, 00252 {{1,0}, {0,0}, {1,0}, {0,0}}, 00253 {{0,0}, {1,0}, {0,0}, {1,0}} 00254 }, 00255 // P_+ = P_R 00256 { 00257 {{2,0}, {0,0}, {0,0}, {0,0}}, 00258 {{0,0}, {2,0}, {0,0}, {0,0}}, 00259 {{0,0}, {0,0}, {0,0}, {0,0}}, 00260 {{0,0}, {0,0}, {0,0}, {0,0}} 00261 }, 00262 // P_- = P_L 00263 { 00264 {{0,0}, {0,0}, {0,0}, {0,0}}, 00265 {{0,0}, {0,0}, {0,0}, {0,0}}, 00266 {{0,0}, {0,0}, {2,0}, {0,0}}, 00267 {{0,0}, {0,0}, {0,0}, {2,0}} 00268 } 00269 }; 00270 
00271 00272 // todo pass projector 00273 template <typename Float> 00274 void multiplySpinorByDiracProjector(Float *res, int projIdx, Float *spinorIn) { 00275 for (int i=0; i<4*3*2; i++) res[i] = 0.0; 00276 00277 for (int s = 0; s < 4; s++) { 00278 for (int t = 0; t < 4; t++) { 00279 Float projRe = projector[projIdx][s][t][0]; 00280 Float projIm = projector[projIdx][s][t][1]; 00281 00282 for (int m = 0; m < 3; m++) { 00283 Float spinorRe = spinorIn[t*(3*2) + m*(2) + 0]; 00284 Float spinorIm = spinorIn[t*(3*2) + m*(2) + 1]; 00285 res[s*(3*2) + m*(2) + 0] += projRe*spinorRe - projIm*spinorIm; 00286 res[s*(3*2) + m*(2) + 1] += projRe*spinorIm + projIm*spinorRe; 00287 } 00288 } 00289 } 00290 } 00291 00292 00293 00294 // dslashReference_4d() 00295 //J This is just the 4d wilson dslash of quda code, with a 00296 //J few small changes to take into account that the spinors 00297 //J are 5d and the gauge fields are 4d. 00298 // 00299 // if oddBit is zero: calculate odd parity spinor elements (using even parity spinor) 00300 // if oddBit is one: calculate even parity spinor elements 00301 // 00302 // if daggerBit is zero: perform ordinary dslash operator 00303 // if daggerBit is one: perform hermitian conjugate of dslash 00304 // 00305 //An "ok" will only be granted once check2.tex is deemed complete, 00306 //since the logic in this function is important and nontrivial. 00307 template <typename sFloat, typename gFloat> 00308 void dslashReference_4d(sFloat *res, gFloat **gaugeFull, sFloat *spinorField, 00309 int oddBit, int daggerBit) { 00310 00311 // Initialize the return half-spinor to zero. Note that it is a 00312 // 5d spinor, hence the use of V5h. 00313 for (int i=0; i<V5h*4*3*2; i++) res[i] = 0.0; 00314 00315 // Some pointers that we use to march through arrays. 00316 gFloat *gaugeEven[4], *gaugeOdd[4]; 00317 // Initialize to beginning of even and odd parts of 00318 // gauge array. 
00319 for (int dir = 0; dir < 4; dir++) { 00320 gaugeEven[dir] = gaugeFull[dir]; 00321 // Note the use of Vh here, since the gauge fields 00322 // are 4-dim'l. 00323 gaugeOdd[dir] = gaugeFull[dir]+Vh*gaugeSiteSize; 00324 } 00325 int sp_idx,oddBit_gge; 00326 for (int xs=0;xs<Ls;xs++) { 00327 for (int gge_idx = 0; gge_idx < Vh; gge_idx++) { 00328 for (int dir = 0; dir < 8; dir++) { 00329 sp_idx=gge_idx+Vh*xs; 00330 // Here is a function call to study. It is defined near 00331 // Line 90 of this file. 00332 // Here we have to switch oddBit depending on the value of xs. E.g., suppose 00333 // xs=1. Then the odd spinor site x1=x2=x3=x4=0 wants the even gauge array 00334 // element 0, so that we get U_\mu(0). 00335 if ((xs % 2) == 0) oddBit_gge=oddBit; 00336 else oddBit_gge= (oddBit+1) % 2; 00337 gFloat *gauge = gaugeLink(gge_idx, dir, oddBit_gge, gaugeEven, gaugeOdd); 00338 00339 // Even though we're doing the 4d part of the dslash, we need 00340 // to use a 5d neighbor function, to get the offsets right. 
00341 sFloat *spinor = spinorNeighbor_5d(sp_idx, dir, oddBit, spinorField); 00342 00343 sFloat projectedSpinor[4*3*2], gaugedSpinor[4*3*2]; 00344 int projIdx = 2*(dir/2)+(dir+daggerBit)%2; 00345 multiplySpinorByDiracProjector(projectedSpinor, projIdx, spinor); 00346 00347 for (int s = 0; s < 4; s++) { 00348 if (dir % 2 == 0) { 00349 su3Mul(&gaugedSpinor[s*(3*2)], gauge, &projectedSpinor[s*(3*2)]); 00350 #ifdef DBUG_VERBOSE 00351 std::cout << "spinor:" << std::endl; 00352 printSpinorElement(&projectedSpinor[s*(3*2)],0,QUDA_DOUBLE_PRECISION); 00353 std::cout << "gauge:" << std::endl; 00354 #endif 00355 } else { 00356 su3Tmul(&gaugedSpinor[s*(3*2)], gauge, &projectedSpinor[s*(3*2)]); 00357 } 00358 } 00359 00360 sum(&res[sp_idx*(4*3*2)], &res[sp_idx*(4*3*2)], gaugedSpinor, 4*3*2); 00361 } 00362 } 00363 } 00364 } 00365 00366 template <typename sFloat> 00367 void dslashReference_5th(sFloat *res, sFloat *spinorField, 00368 int oddBit, int daggerBit, sFloat mferm) { 00369 for (int i = 0; i < V5h; i++) { 00370 for (int dir = 8; dir < 10; dir++) { 00371 // Calls for an extension of the original function. 00372 // 8 is forward hop, which wants P_+, 9 is backward hop, 00373 // which wants P_-. Dagger reverses these. 00374 sFloat *spinor = spinorNeighbor_5d(i, dir, oddBit, spinorField); 00375 sFloat projectedSpinor[4*3*2]; 00376 int projIdx = 2*(dir/2)+(dir+daggerBit)%2; 00377 multiplySpinorByDiracProjector(projectedSpinor, projIdx, spinor); 00378 //J Need a conditional here for s=0 and s=Ls-1. 00379 int X = fullLatticeIndex_5d(i, oddBit); 00380 int xs = X/(Z[3]*Z[2]*Z[1]*Z[0]); 00381 if ( (xs == 0 && dir == 9) || (xs == Ls-1 && dir == 8) ) { 00382 product(projectedSpinor,(sFloat)(-mferm),projectedSpinor,4*3*2); 00383 } 00384 sum(&res[i*(4*3*2)], &res[i*(4*3*2)], projectedSpinor, 4*3*2); 00385 } 00386 } 00387 } 00388 00389 // Recall that dslash is only the off-diagonal parts, so m0_dwf is not needed. 
00390 // 00391 void dslash(void *res, void **gaugeFull, void *spinorField, 00392 int oddBit, int daggerBit, 00393 QudaPrecision sPrecision, QudaPrecision gPrecision, double mferm) { 00394 00395 if (sPrecision == QUDA_DOUBLE_PRECISION) { 00396 if (gPrecision == QUDA_DOUBLE_PRECISION) { 00397 // Do the 4d part, which hasn't changed. 00398 printf("doing 4d part\n"); fflush(stdout); 00399 dslashReference_4d<double,double>((double*)res, (double**)gaugeFull, 00400 (double*)spinorField, oddBit, daggerBit); 00401 // Now add in the 5th dim. 00402 printf("doing 5th dimen. part\n"); fflush(stdout); 00403 dslashReference_5th<double>((double*)res, (double*)spinorField, 00404 oddBit, daggerBit, mferm); 00405 } else { 00406 dslashReference_4d<double,float>((double*)res, (float**)gaugeFull, (double*)spinorField, oddBit, daggerBit); 00407 dslashReference_5th<double>((double*)res, (double*)spinorField, oddBit, daggerBit, mferm); 00408 } 00409 } else { 00410 // Single-precision spinor. 00411 if (gPrecision == QUDA_DOUBLE_PRECISION) { 00412 dslashReference_4d<float,double>((float*)res, (double**)gaugeFull, (float*)spinorField, oddBit, daggerBit); 00413 dslashReference_5th<float>((float*)res, (float*)spinorField, oddBit, daggerBit, mferm); 00414 } else { 00415 // Do the 4d part, which hasn't changed. 00416 printf("CPU reference: doing 4d part all single precision\n"); fflush(stdout); 00417 dslashReference_4d<float,float>((float*)res, (float**)gaugeFull, (float*)spinorField, oddBit, daggerBit); 00418 // Now add in the 5th dim. 00419 printf("CPU reference: doing 5th dimen. 
part all single precision\n"); fflush(stdout); 00420 dslashReference_5th<float>((float*)res, (float*)spinorField, oddBit, daggerBit, mferm); 00421 } 00422 } 00423 } 00424 00425 00426 template <typename sFloat, typename gFloat> 00427 void Mat(sFloat *out, gFloat **gauge, sFloat *in, sFloat kappa, sFloat mferm) { 00428 sFloat *inEven = in; 00429 sFloat *inOdd = in + V5h*spinorSiteSize; 00430 sFloat *outEven = out; 00431 sFloat *outOdd = out + V5h*spinorSiteSize; 00432 00433 // full dslash operator 00434 dslashReference_4d(outOdd, gauge, inEven, 1, 0); 00435 dslashReference_5th(outOdd, inEven, 1, 0, mferm); 00436 dslashReference_4d(outEven, gauge, inOdd, 0, 0); 00437 dslashReference_5th(outEven, inOdd, 0, 0, mferm); 00438 00439 // lastly apply the kappa term 00440 xpay(in, -kappa, out, V5*spinorSiteSize); 00441 } 00442 00443 template <typename sFloat, typename gFloat> 00444 void MatDag(sFloat *out, gFloat **gauge, sFloat *in, sFloat kappa, sFloat mferm) { 00445 sFloat *inEven = in; 00446 sFloat *inOdd = in + V5h*spinorSiteSize; 00447 sFloat *outEven = out; 00448 sFloat *outOdd = out + V5h*spinorSiteSize; 00449 00450 // full dslash operator 00451 dslashReference_4d(outOdd, gauge, inEven, 1, 1); 00452 dslashReference_5th(outOdd, inEven, 1, 1, mferm); 00453 dslashReference_4d(outEven, gauge, inOdd, 0, 1); 00454 dslashReference_5th(outEven, inOdd, 0, 1, mferm); 00455 00456 // lastly apply the kappa term 00457 xpay(in, -kappa, out, V5*spinorSiteSize); 00458 } 00459 00460 void mat(void *out, void **gauge, void *in, double kappa, int dagger_bit, 00461 QudaPrecision sPrecision, QudaPrecision gPrecision, double mferm) { 00462 if (!dagger_bit) { 00463 if (sPrecision == QUDA_DOUBLE_PRECISION) 00464 if (gPrecision == QUDA_DOUBLE_PRECISION) Mat((double*)out, (double**)gauge, (double*)in, (double)kappa, 00465 (double)mferm); 00466 else Mat((double*)out, (float**)gauge, (double*)in, (double)kappa, (double)mferm); 00467 else 00468 if (gPrecision == QUDA_DOUBLE_PRECISION) 
Mat((float*)out, (double**)gauge, (float*)in, (float)kappa, 00469 (float)mferm); 00470 else Mat((float*)out, (float**)gauge, (float*)in, (float)kappa, (float)mferm); 00471 } else { 00472 if (sPrecision == QUDA_DOUBLE_PRECISION) 00473 if (gPrecision == QUDA_DOUBLE_PRECISION) MatDag((double*)out, (double**)gauge, (double*)in, (double)kappa, 00474 (double)mferm); 00475 else MatDag((float*)out, (double**)gauge, (float*)in, (float)kappa, (float)mferm); 00476 else 00477 if (gPrecision == QUDA_DOUBLE_PRECISION) MatDag((float*)out, (double**)gauge, (float*)in, (float)kappa, 00478 (float)mferm); 00479 else MatDag((float*)out, (float**)gauge, (float*)in, (float)kappa, (float)mferm); 00480 } 00481 } 00482 00483 // Apply the even-odd preconditioned Dirac operator 00484 template <typename sFloat, typename gFloat> 00485 void MatPC(sFloat *outEven, gFloat **gauge, sFloat *inEven, sFloat kappa, 00486 QudaMatPCType matpc_type, sFloat mferm) { 00487 00488 sFloat *tmp = (sFloat*)malloc(V5h*spinorSiteSize*sizeof(sFloat)); 00489 00490 // full dslash operator 00491 if (matpc_type == QUDA_MATPC_EVEN_EVEN) { 00492 dslashReference_4d(tmp, gauge, inEven, 1, 0); 00493 dslashReference_5th(tmp, inEven, 1, 0, mferm); 00494 dslashReference_4d(outEven, gauge, tmp, 0, 0); 00495 dslashReference_5th(outEven, tmp, 0, 0, mferm); 00496 } else { 00497 dslashReference_4d(tmp, gauge, inEven, 0, 0); 00498 dslashReference_5th(tmp, inEven, 0, 0, mferm); 00499 dslashReference_4d(outEven, gauge, tmp, 1, 0); 00500 dslashReference_5th(outEven, tmp, 1, 0, mferm); 00501 } 00502 00503 // lastly apply the kappa term 00504 sFloat kappa2 = -kappa*kappa; 00505 xpay(inEven, kappa2, outEven, V5h*spinorSiteSize); 00506 free(tmp); 00507 } 00508 00509 // Apply the even-odd preconditioned Dirac operator 00510 template <typename sFloat, typename gFloat> 00511 void MatPCDag(sFloat *outEven, gFloat **gauge, sFloat *inEven, sFloat kappa, 00512 QudaMatPCType matpc_type, sFloat mferm) { 00513 00514 sFloat *tmp = 
(sFloat*)malloc(V5h*spinorSiteSize*sizeof(sFloat)); 00515 00516 // full dslash operator 00517 if (matpc_type == QUDA_MATPC_EVEN_EVEN) { 00518 dslashReference_4d(tmp, gauge, inEven, 1, 1); 00519 dslashReference_5th(tmp, inEven, 1, 1, mferm); 00520 dslashReference_4d(outEven, gauge, tmp, 0, 1); 00521 dslashReference_5th(outEven, tmp, 0, 1, mferm); 00522 } else { 00523 dslashReference_4d(tmp, gauge, inEven, 0, 1); 00524 dslashReference_5th(tmp, inEven, 0, 1, mferm); 00525 dslashReference_4d(outEven, gauge, tmp, 1, 1); 00526 dslashReference_5th(outEven, tmp, 1, 1, mferm); 00527 } 00528 00529 sFloat kappa2 = -kappa*kappa; 00530 xpay(inEven, kappa2, outEven, V5h*spinorSiteSize); 00531 free(tmp); 00532 } 00533 00534 void matpc(void *outEven, void **gauge, void *inEven, double kappa, 00535 QudaMatPCType matpc_type, int dagger_bit, QudaPrecision sPrecision, QudaPrecision gPrecision, 00536 double mferm) { 00537 if (!dagger_bit) { 00538 if (sPrecision == QUDA_DOUBLE_PRECISION) 00539 if (gPrecision == QUDA_DOUBLE_PRECISION) 00540 MatPC((double*)outEven, (double**)gauge, (double*)inEven, (double)kappa, matpc_type, (double)mferm); 00541 else 00542 MatPC((double*)outEven, (float**)gauge, (double*)inEven, (double)kappa, matpc_type, (double)mferm); 00543 else 00544 if (gPrecision == QUDA_DOUBLE_PRECISION) 00545 MatPC((float*)outEven, (double**)gauge, (float*)inEven, (float)kappa, matpc_type, (float)mferm); 00546 else 00547 MatPC((float*)outEven, (float**)gauge, (float*)inEven, (float)kappa, matpc_type, (float)mferm); 00548 } else { 00549 if (sPrecision == QUDA_DOUBLE_PRECISION) 00550 if (gPrecision == QUDA_DOUBLE_PRECISION) 00551 MatPCDag((double*)outEven, (double**)gauge, (double*)inEven, (double)kappa, matpc_type, (double)mferm); 00552 else 00553 MatPCDag((double*)outEven, (float**)gauge, (double*)inEven, (double)kappa, matpc_type, (double)mferm); 00554 else 00555 if (gPrecision == QUDA_DOUBLE_PRECISION) 00556 MatPCDag((float*)outEven, (double**)gauge, (float*)inEven, 
(float)kappa, matpc_type, (float)mferm); 00557 else 00558 MatPCDag((float*)outEven, (float**)gauge, (float*)inEven, (float)kappa, matpc_type, (float)mferm); 00559 } 00560 } 00561 00562 00563 template <typename sFloat, typename gFloat> 00564 void MatDagMat(sFloat *out, gFloat **gauge, sFloat *in, sFloat kappa, sFloat mferm) 00565 { 00566 // Allocate a full spinor. 00567 sFloat *tmp = (sFloat*)malloc(V5*spinorSiteSize*sizeof(sFloat)); 00568 // Call templates above. 00569 Mat(tmp, gauge, in, kappa, mferm); 00570 MatDag(out, gauge, tmp, kappa, mferm); 00571 free(tmp); 00572 } 00573 00574 template <typename sFloat, typename gFloat> 00575 void MatPCDagMatPC(sFloat *out, gFloat **gauge, sFloat *in, sFloat kappa, 00576 QudaMatPCType matpc_type, sFloat mferm) 00577 { 00578 00579 // Allocate half spinor 00580 sFloat *tmp = (sFloat*)malloc(V5h*spinorSiteSize*sizeof(sFloat)); 00581 // Apply the PC templates above 00582 MatPC(tmp, gauge, in, kappa, matpc_type, mferm); 00583 MatPCDag(out, gauge, tmp, kappa, matpc_type, mferm); 00584 free(tmp); 00585 } 00586 00587 // Wrapper to templates that handles different precisions. 00588 void matdagmat(void *out, void **gauge, void *in, double kappa, 00589 QudaPrecision sPrecision, QudaPrecision gPrecision, double mferm) 00590 { 00591 if (sPrecision == QUDA_DOUBLE_PRECISION) { 00592 if (gPrecision == QUDA_DOUBLE_PRECISION) 00593 MatDagMat((double*)out, (double**)gauge, (double*)in, (double)kappa, 00594 (double)mferm); 00595 else 00596 MatDagMat((double*)out, (float**)gauge, (double*)in, (double)kappa, (double)mferm); 00597 } else { 00598 if (gPrecision == QUDA_DOUBLE_PRECISION) 00599 MatDagMat((float*)out, (double**)gauge, (float*)in, (float)kappa, 00600 (float)mferm); 00601 else 00602 MatDagMat((float*)out, (float**)gauge, (float*)in, (float)kappa, (float)mferm); 00603 } 00604 } 00605 00606 // Wrapper to templates that handles different precisions. 
00607 void matpcdagmatpc(void *out, void **gauge, void *in, double kappa, 00608 QudaPrecision sPrecision, QudaPrecision gPrecision, double mferm, QudaMatPCType matpc_type) 00609 { 00610 if (sPrecision == QUDA_DOUBLE_PRECISION) { 00611 if (gPrecision == QUDA_DOUBLE_PRECISION) 00612 MatPCDagMatPC((double*)out, (double**)gauge, (double*)in, (double)kappa, 00613 matpc_type, (double)mferm); 00614 else 00615 MatPCDagMatPC((double*)out, (float**)gauge, (double*)in, (double)kappa, 00616 matpc_type, (double)mferm); 00617 } else { 00618 if (gPrecision == QUDA_DOUBLE_PRECISION) 00619 MatPCDagMatPC((float*)out, (double**)gauge, (float*)in, (float)kappa, 00620 matpc_type, (float)mferm); 00621 else 00622 MatPCDagMatPC((float*)out, (float**)gauge, (float*)in, (float)kappa, 00623 matpc_type, (float)mferm); 00624 } 00625 } 00626