QUDA v0.4.0
A library for QCD on GPUs
#include <stdio.h>
#include <stdlib.h>
#include <math.h>

#include <util_quda.h>

#include <test_util.h>
#include <blas_reference.h>
#include <wilson_dslash_reference.h>

#include <gauge_field.h>
#include <color_spinor_field.h>
#include <face_quda.h>

static int mySpinorSiteSize = 24;

#include <dslash_util.h>

static const double projector[8][4][4][2] = {
  {
    {{1,0}, {0,0}, {0,0}, {0,-1}},
    {{0,0}, {1,0}, {0,-1}, {0,0}},
    {{0,0}, {0,1}, {1,0}, {0,0}},
    {{0,1}, {0,0}, {0,0}, {1,0}}
  },
  {
    {{1,0}, {0,0}, {0,0}, {0,1}},
    {{0,0}, {1,0}, {0,1}, {0,0}},
    {{0,0}, {0,-1}, {1,0}, {0,0}},
    {{0,-1}, {0,0}, {0,0}, {1,0}}
  },
  {
    {{1,0}, {0,0}, {0,0}, {1,0}},
    {{0,0}, {1,0}, {-1,0}, {0,0}},
    {{0,0}, {-1,0}, {1,0}, {0,0}},
    {{1,0}, {0,0}, {0,0}, {1,0}}
  },
  {
    {{1,0}, {0,0}, {0,0}, {-1,0}},
    {{0,0}, {1,0}, {1,0}, {0,0}},
    {{0,0}, {1,0}, {1,0}, {0,0}},
    {{-1,0}, {0,0}, {0,0}, {1,0}}
  },
  {
    {{1,0}, {0,0}, {0,-1}, {0,0}},
    {{0,0}, {1,0}, {0,0}, {0,1}},
    {{0,1}, {0,0}, {1,0}, {0,0}},
    {{0,0}, {0,-1}, {0,0}, {1,0}}
  },
  {
    {{1,0}, {0,0}, {0,1}, {0,0}},
    {{0,0}, {1,0}, {0,0}, {0,-1}},
    {{0,-1}, {0,0}, {1,0}, {0,0}},
    {{0,0}, {0,1}, {0,0}, {1,0}}
  },
  {
    {{1,0}, {0,0}, {-1,0}, {0,0}},
    {{0,0}, {1,0}, {0,0}, {-1,0}},
    {{-1,0}, {0,0}, {1,0}, {0,0}},
    {{0,0}, {-1,0}, {0,0}, {1,0}}
  },
  {
    {{1,0}, {0,0}, {1,0}, {0,0}},
    {{0,0}, {1,0}, {0,0}, {1,0}},
    {{1,0}, {0,0}, {1,0}, {0,0}},
    {{0,0}, {1,0}, {0,0}, {1,0}}
  }
};

// todo pass projector
template <typename Float>
void multiplySpinorByDiracProjector(Float *res, int projIdx, Float *spinorIn) {
  for (int i=0; i<4*3*2; i++) res[i] = 0.0;

  for (int s = 0; s < 4; s++) {
    for (int t = 0; t < 4; t++) {
      Float projRe = projector[projIdx][s][t][0];
      Float projIm = projector[projIdx][s][t][1];

      for (int m = 0; m < 3; m++) {
        Float spinorRe = spinorIn[t*(3*2) + m*(2) + 0];
        Float spinorIm = spinorIn[t*(3*2) + m*(2) + 1];
        res[s*(3*2) + m*(2) + 0] += projRe*spinorRe - projIm*spinorIm;
        res[s*(3*2) + m*(2) + 1] += projRe*spinorIm + projIm*spinorRe;
      }
    }
  }
}
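
// Editor's sketch (not part of the original file): the table above stores, for
// each dimension mu (x, y, z, t), projector[2*mu] = 1 - gamma_mu and
// projector[2*mu+1] = 1 + gamma_mu in the DeGrand-Rossi basis, laid out as
// [projIdx][spin row][spin column][re/im].  Since gamma_mu^2 = 1, every entry
// satisfies P^2 = 2P, which gives a cheap sanity check on the table and on
// multiplySpinorByDiracProjector().  The helper below is hypothetical and
// purely illustrative.
template <typename Float>
void checkProjectorIdempotence(int projIdx) {
  Float in[4*3*2], once[4*3*2], twice[4*3*2];
  for (int i = 0; i < 4*3*2; i++) in[i] = (Float)(i % 5) - 2; // arbitrary test data

  multiplySpinorByDiracProjector(once, projIdx, in);
  multiplySpinorByDiracProjector(twice, projIdx, once);

  // (1 +/- gamma_mu)^2 = 2*(1 +/- gamma_mu), so the double application must
  // equal twice the single application
  for (int i = 0; i < 4*3*2; i++)
    if (fabs(twice[i] - 2*once[i]) > 1e-6) printf("projector %d check failed at %d\n", projIdx, i);
}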

//
// dslashReference()
//
// if oddBit is zero: calculate even parity spinor elements (using the odd parity spinor as input)
// if oddBit is one: calculate odd parity spinor elements (using the even parity spinor as input)
//
// if daggerBit is zero: perform ordinary dslash operator
// if daggerBit is one: perform hermitian conjugate of dslash
//

#ifndef MULTI_GPU

template <typename sFloat, typename gFloat>
void dslashReference(sFloat *res, gFloat **gaugeFull, sFloat *spinorField, int oddBit, int daggerBit) {
  for (int i=0; i<Vh*mySpinorSiteSize; i++) res[i] = 0.0;

  gFloat *gaugeEven[4], *gaugeOdd[4];
  for (int dir = 0; dir < 4; dir++) {
    gaugeEven[dir] = gaugeFull[dir];
    gaugeOdd[dir] = gaugeFull[dir]+Vh*gaugeSiteSize;
  }

  for (int i = 0; i < Vh; i++) {
    for (int dir = 0; dir < 8; dir++) {
      gFloat *gauge = gaugeLink(i, dir, oddBit, gaugeEven, gaugeOdd, 1);
      sFloat *spinor = spinorNeighbor(i, dir, oddBit, spinorField, 1);

      sFloat projectedSpinor[4*3*2], gaugedSpinor[4*3*2];
      int projIdx = 2*(dir/2)+(dir+daggerBit)%2;
      multiplySpinorByDiracProjector(projectedSpinor, projIdx, spinor);

      for (int s = 0; s < 4; s++) {
        if (dir % 2 == 0) su3Mul(&gaugedSpinor[s*(3*2)], gauge, &projectedSpinor[s*(3*2)]);
        else su3Tmul(&gaugedSpinor[s*(3*2)], gauge, &projectedSpinor[s*(3*2)]);
      }

      sum(&res[i*(4*3*2)], &res[i*(4*3*2)], gaugedSpinor, 4*3*2);
    }
  }
}

#else

template <typename sFloat, typename gFloat>
void dslashReference(sFloat *res, gFloat **gaugeFull, gFloat **ghostGauge, sFloat *spinorField,
                     sFloat **fwdSpinor, sFloat **backSpinor, int oddBit, int daggerBit) {
  for (int i=0; i<Vh*mySpinorSiteSize; i++) res[i] = 0.0;

  gFloat *gaugeEven[4], *gaugeOdd[4];
  gFloat *ghostGaugeEven[4], *ghostGaugeOdd[4];
  for (int dir = 0; dir < 4; dir++) {
    gaugeEven[dir] = gaugeFull[dir];
    gaugeOdd[dir] = gaugeFull[dir]+Vh*gaugeSiteSize;

    ghostGaugeEven[dir] = ghostGauge[dir];
    ghostGaugeOdd[dir] = ghostGauge[dir] + (faceVolume[dir]/2)*gaugeSiteSize;
  }

  for (int i = 0; i < Vh; i++) {

    for (int dir = 0; dir < 8; dir++) {
      gFloat *gauge = gaugeLink_mg4dir(i, dir, oddBit, gaugeEven, gaugeOdd, ghostGaugeEven, ghostGaugeOdd, 1, 1);
      sFloat *spinor = spinorNeighbor_mg4dir(i, dir, oddBit, spinorField, fwdSpinor, backSpinor, 1, 1);

      sFloat projectedSpinor[mySpinorSiteSize], gaugedSpinor[mySpinorSiteSize];
      int projIdx = 2*(dir/2)+(dir+daggerBit)%2;
      multiplySpinorByDiracProjector(projectedSpinor, projIdx, spinor);

      for (int s = 0; s < 4; s++) {
        if (dir % 2 == 0) su3Mul(&gaugedSpinor[s*(3*2)], gauge, &projectedSpinor[s*(3*2)]);
        else su3Tmul(&gaugedSpinor[s*(3*2)], gauge, &projectedSpinor[s*(3*2)]);
      }

      sum(&res[i*(4*3*2)], &res[i*(4*3*2)], gaugedSpinor, 4*3*2);
    }

  }
}

#endif
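
// Editor's note (not part of the original file): in both versions of
// dslashReference() above, the projector for hop direction dir (dir = 2*mu for
// the forward hop, 2*mu+1 for the backward hop) is selected by
// projIdx = 2*(dir/2) + (dir+daggerBit)%2, so the forward hop uses 1 - gamma_mu,
// the backward hop uses 1 + gamma_mu, and daggerBit simply swaps the two within
// each dimension.  A hypothetical helper that prints this mapping:
void printProjectorMap() {
  for (int daggerBit = 0; daggerBit < 2; daggerBit++) {
    for (int dir = 0; dir < 8; dir++) {
      int projIdx = 2*(dir/2) + (dir+daggerBit)%2;
      printf("dagger=%d dir=%d (mu=%d, %s hop) -> projIdx=%d (1 %c gamma_mu)\n",
             daggerBit, dir, dir/2, (dir % 2 == 0) ? "forward" : "backward",
             projIdx, (projIdx % 2 == 0) ? '-' : '+');
    }
  }
}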

// this applies the single-parity (hopping term) dslash, D_eo or D_oe; for plain
// Wilson fermions D_ee = D_oo = 1, so it coincides with the preconditioned
// dslash D_ee^{-1} D_eo or D_oo^{-1} D_oe
void wil_dslash(void *out, void **gauge, void *in, int oddBit, int daggerBit,
                QudaPrecision precision, QudaGaugeParam &gauge_param) {

#ifndef MULTI_GPU
  if (precision == QUDA_DOUBLE_PRECISION)
    dslashReference((double*)out, (double**)gauge, (double*)in, oddBit, daggerBit);
  else
    dslashReference((float*)out, (float**)gauge, (float*)in, oddBit, daggerBit);
#else

  GaugeFieldParam gauge_field_param(gauge, gauge_param);
  cpuGaugeField cpu(gauge_field_param);
  cpu.exchangeGhost();
  void **ghostGauge = (void**)cpu.Ghost();

  // Get spinor ghost fields
  // First wrap the input spinor into a ColorSpinorField
  ColorSpinorParam csParam;
  csParam.v = in;
  csParam.fieldLocation = QUDA_CPU_FIELD_LOCATION;
  csParam.nColor = 3;
  csParam.nSpin = 4;
  csParam.nDim = 4;
  for (int d=0; d<4; d++) csParam.x[d] = Z[d];
  csParam.precision = precision;
  csParam.pad = 0;
  csParam.siteSubset = QUDA_PARITY_SITE_SUBSET;
  csParam.x[0] /= 2;
  csParam.siteOrder = QUDA_EVEN_ODD_SITE_ORDER;
  csParam.fieldOrder = QUDA_SPACE_SPIN_COLOR_FIELD_ORDER;
  csParam.gammaBasis = QUDA_DEGRAND_ROSSI_GAMMA_BASIS;
  csParam.create = QUDA_REFERENCE_FIELD_CREATE;

  cpuColorSpinorField inField(csParam);

  { // Now do the exchange
    QudaParity otherParity = QUDA_INVALID_PARITY;
    if (oddBit == QUDA_EVEN_PARITY) otherParity = QUDA_ODD_PARITY;
    else if (oddBit == QUDA_ODD_PARITY) otherParity = QUDA_EVEN_PARITY;
    else errorQuda("ERROR: full parity not supported in function %s", __FUNCTION__);

    int nFace = 1;
    FaceBuffer faceBuf(Z, 4, mySpinorSiteSize, nFace, precision);
    faceBuf.exchangeCpuSpinor(inField, otherParity, daggerBit);
  }
  void **fwd_nbr_spinor = inField.fwdGhostFaceBuffer;
  void **back_nbr_spinor = inField.backGhostFaceBuffer;

  if (precision == QUDA_DOUBLE_PRECISION) {
    dslashReference((double*)out, (double**)gauge, (double**)ghostGauge, (double*)in,
                    (double**)fwd_nbr_spinor, (double**)back_nbr_spinor, oddBit, daggerBit);
  } else {
    dslashReference((float*)out, (float**)gauge, (float**)ghostGauge, (float*)in,
                    (float**)fwd_nbr_spinor, (float**)back_nbr_spinor, oddBit, daggerBit);
  }

#endif

}
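
// Editor's sketch (not part of the original file): a typical host-side use of
// wil_dslash(), computing the odd-parity half D_oe * psi_e from an even-parity
// input.  Vh and spinorSiteSize are the globals already used throughout this
// file; the function and variable names here are hypothetical.
void example_apply_Doe(void **gauge, void *spinorEven, QudaPrecision precision,
                       QudaGaugeParam &gauge_param) {
  void *spinorOdd = malloc(Vh*spinorSiteSize*precision);

  // oddBit = 1: the result lives on the odd sublattice, the input on the even one
  wil_dslash(spinorOdd, gauge, spinorEven, 1, 0, precision, gauge_param);

  // ... compare against the GPU dslash, take norms, etc. ...

  free(spinorOdd);
}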

// applies b*(1 + i*a*gamma_5)
template <typename sFloat>
void twistGamma5(sFloat *out, sFloat *in, const int dagger, const sFloat kappa, const sFloat mu,
                 const QudaTwistFlavorType flavor, const int V, QudaTwistGamma5Type twist) {

  sFloat a=0.0, b=0.0;
  if (twist == QUDA_TWIST_GAMMA5_DIRECT) { // applying the twist
    a = 2.0 * kappa * mu * flavor; // the sign of the twist follows the sign of flavor
    b = 1.0;
  } else if (twist == QUDA_TWIST_GAMMA5_INVERSE) { // applying the inverse twist
    a = -2.0 * kappa * mu * flavor;
    b = 1.0 / (1.0 + a*a);
  } else {
    printf("Twist type %d not defined\n", twist);
    exit(0);
  }

  if (dagger) a *= -1.0;

  for (int i = 0; i < V; i++) {
    sFloat tmp[24];
    for (int s = 0; s < 4; s++)
      for (int c = 0; c < 3; c++) {
        sFloat a5 = ((s / 2) ? -1.0 : +1.0) * a;
        tmp[s * 6 + c * 2 + 0] = b * (in[i * 24 + s * 6 + c * 2 + 0] - a5*in[i * 24 + s * 6 + c * 2 + 1]);
        tmp[s * 6 + c * 2 + 1] = b * (in[i * 24 + s * 6 + c * 2 + 1] + a5*in[i * 24 + s * 6 + c * 2 + 0]);
      }

    for (int j=0; j<24; j++) out[i*24+j] = tmp[j];
  }

}

void twist_gamma5(void *out, void *in, int daggerBit, double kappa, double mu, QudaTwistFlavorType flavor,
                  int V, QudaTwistGamma5Type twist, QudaPrecision precision) {

  if (precision == QUDA_DOUBLE_PRECISION) {
    twistGamma5((double*)out, (double*)in, daggerBit, kappa, mu, flavor, V, twist);
  } else {
    twistGamma5((float*)out, (float*)in, daggerBit, (float)kappa, (float)mu, flavor, V, twist);
  }
}


void tm_dslash(void *res, void **gaugeFull, void *spinorField, double kappa, double mu,
               QudaTwistFlavorType flavor, int oddBit, int daggerBit, QudaPrecision precision,
               QudaGaugeParam &gauge_param)
{

  if (daggerBit) twist_gamma5(spinorField, spinorField, daggerBit, kappa, mu,
                              flavor, Vh, QUDA_TWIST_GAMMA5_INVERSE, precision);

  wil_dslash(res, gaugeFull, spinorField, oddBit, daggerBit, precision, gauge_param);

  if (!daggerBit) {
    twist_gamma5(res, res, daggerBit, kappa, mu, flavor,
                 Vh, QUDA_TWIST_GAMMA5_INVERSE, precision);
  } else {
    twist_gamma5(spinorField, spinorField, daggerBit, kappa, mu, flavor,
                 Vh, QUDA_TWIST_GAMMA5_DIRECT, precision);
  }

}

void wil_mat(void *out, void **gauge, void *in, double kappa, int dagger_bit, QudaPrecision precision,
             QudaGaugeParam &gauge_param) {

  void *inEven = in;
  void *inOdd = (char*)in + Vh*spinorSiteSize*precision;
  void *outEven = out;
  void *outOdd = (char*)out + Vh*spinorSiteSize*precision;

  wil_dslash(outOdd, gauge, inEven, 1, dagger_bit, precision, gauge_param);
  wil_dslash(outEven, gauge, inOdd, 0, dagger_bit, precision, gauge_param);

  // lastly apply the kappa term
  if (precision == QUDA_DOUBLE_PRECISION) xpay((double*)in, -kappa, (double*)out, V*spinorSiteSize);
  else xpay((float*)in, -(float)kappa, (float*)out, V*spinorSiteSize);
}

void tm_mat(void *out, void **gauge, void *in, double kappa, double mu,
            QudaTwistFlavorType flavor, int dagger_bit, QudaPrecision precision,
            QudaGaugeParam &gauge_param) {

  void *inEven = in;
  void *inOdd = (char*)in + Vh*spinorSiteSize*precision;
  void *outEven = out;
  void *outOdd = (char*)out + Vh*spinorSiteSize*precision;
  void *tmp = malloc(V*spinorSiteSize*precision);

  wil_dslash(outOdd, gauge, inEven, 1, dagger_bit, precision, gauge_param);
  wil_dslash(outEven, gauge, inOdd, 0, dagger_bit, precision, gauge_param);

  // apply the twist term to the full lattice
  twist_gamma5(tmp, in, dagger_bit, kappa, mu, flavor, V, QUDA_TWIST_GAMMA5_DIRECT, precision);

  // combine
  if (precision == QUDA_DOUBLE_PRECISION) xpay((double*)tmp, -kappa, (double*)out, V*spinorSiteSize);
  else xpay((float*)tmp, -(float)kappa, (float*)out, V*spinorSiteSize);

  free(tmp);
}
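
// Editor's sketch (not part of the original file): wil_mat() and tm_mat() both
// use the kappa normalization M = 1 - kappa*D (plus the twist term
// 2*i*kappa*mu*flavor*gamma_5 for twisted mass), so with mu = 0 the two
// operators should agree site by site.  A hedged consistency check, written
// for double precision only; the function name is hypothetical.
void check_tm_reduces_to_wilson(void **gauge, void *in, double kappa,
                                QudaTwistFlavorType flavor, QudaGaugeParam &gauge_param) {
  double *outWil = (double*)malloc(V*spinorSiteSize*sizeof(double));
  double *outTm = (double*)malloc(V*spinorSiteSize*sizeof(double));

  wil_mat(outWil, gauge, in, kappa, 0, QUDA_DOUBLE_PRECISION, gauge_param);
  tm_mat(outTm, gauge, in, kappa, 0.0, flavor, 0, QUDA_DOUBLE_PRECISION, gauge_param);

  for (int i = 0; i < V*spinorSiteSize; i++)
    if (fabs(outWil[i] - outTm[i]) > 1e-10) { printf("tm/wilson mismatch at %d\n", i); break; }

  free(outWil);
  free(outTm);
}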

// Apply the even-odd preconditioned Dirac operator
void wil_matpc(void *outEven, void **gauge, void *inEven, double kappa,
               QudaMatPCType matpc_type, int daggerBit, QudaPrecision precision,
               QudaGaugeParam &gauge_param) {

  void *tmp = malloc(Vh*spinorSiteSize*precision);

  // full dslash operator
  if (matpc_type == QUDA_MATPC_EVEN_EVEN) {
    wil_dslash(tmp, gauge, inEven, 1, daggerBit, precision, gauge_param);
    wil_dslash(outEven, gauge, tmp, 0, daggerBit, precision, gauge_param);
  } else {
    wil_dslash(tmp, gauge, inEven, 0, daggerBit, precision, gauge_param);
    wil_dslash(outEven, gauge, tmp, 1, daggerBit, precision, gauge_param);
  }

  // lastly apply the kappa term
  double kappa2 = -kappa*kappa;
  if (precision == QUDA_DOUBLE_PRECISION) xpay((double*)inEven, kappa2, (double*)outEven, Vh*spinorSiteSize);
  else xpay((float*)inEven, (float)kappa2, (float*)outEven, Vh*spinorSiteSize);

  free(tmp);
}
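
// Editor's sketch (not part of the original file): wil_matpc() above implements
// the Schur-complement operator M_ee = 1 - kappa^2 D_eo D_oe (or the odd-odd
// analogue), which is why kappa2 = -kappa*kappa is fed to xpay.  After solving
// M_ee x_e = b_e + kappa D_eo b_o, the odd half of the full solution follows
// from x_o = b_o + kappa D_oe x_e.  A hedged sketch of that reconstruction,
// assuming the xpay(x, a, y) convention y = x + a*y used elsewhere in this
// file; the function name is hypothetical and daggerBit is taken to be zero.
void reconstruct_full_solution(void *xFull, void *bFull, void **gauge, double kappa,
                               QudaPrecision precision, QudaGaugeParam &gauge_param) {
  void *xEven = xFull;
  void *xOdd = (char*)xFull + Vh*spinorSiteSize*precision;
  void *bOdd = (char*)bFull + Vh*spinorSiteSize*precision;

  // x_o = b_o + kappa * D_oe x_e
  wil_dslash(xOdd, gauge, xEven, 1, 0, precision, gauge_param);
  if (precision == QUDA_DOUBLE_PRECISION) xpay((double*)bOdd, kappa, (double*)xOdd, Vh*spinorSiteSize);
  else xpay((float*)bOdd, (float)kappa, (float*)xOdd, Vh*spinorSiteSize);
}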

// Apply the even-odd preconditioned Dirac operator
void tm_matpc(void *outEven, void **gauge, void *inEven, double kappa, double mu, QudaTwistFlavorType flavor,
              QudaMatPCType matpc_type, int daggerBit, QudaPrecision precision, QudaGaugeParam &gauge_param) {

  void *tmp = malloc(Vh*spinorSiteSize*precision);

  if (matpc_type == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC) {
    wil_dslash(tmp, gauge, inEven, 1, daggerBit, precision, gauge_param);
    twist_gamma5(tmp, tmp, daggerBit, kappa, mu, flavor, Vh, QUDA_TWIST_GAMMA5_INVERSE, precision);
    wil_dslash(outEven, gauge, tmp, 0, daggerBit, precision, gauge_param);
    twist_gamma5(tmp, inEven, daggerBit, kappa, mu, flavor, Vh, QUDA_TWIST_GAMMA5_DIRECT, precision);
  } else if (matpc_type == QUDA_MATPC_ODD_ODD_ASYMMETRIC) {
    wil_dslash(tmp, gauge, inEven, 0, daggerBit, precision, gauge_param);
    twist_gamma5(tmp, tmp, daggerBit, kappa, mu, flavor, Vh, QUDA_TWIST_GAMMA5_INVERSE, precision);
    wil_dslash(outEven, gauge, tmp, 1, daggerBit, precision, gauge_param);
    twist_gamma5(tmp, inEven, daggerBit, kappa, mu, flavor, Vh, QUDA_TWIST_GAMMA5_DIRECT, precision);
  } else if (!daggerBit) {
    if (matpc_type == QUDA_MATPC_EVEN_EVEN) {
      wil_dslash(tmp, gauge, inEven, 1, daggerBit, precision, gauge_param);
      twist_gamma5(tmp, tmp, daggerBit, kappa, mu, flavor, Vh, QUDA_TWIST_GAMMA5_INVERSE, precision);
      wil_dslash(outEven, gauge, tmp, 0, daggerBit, precision, gauge_param);
      twist_gamma5(outEven, outEven, daggerBit, kappa, mu, flavor, Vh, QUDA_TWIST_GAMMA5_INVERSE, precision);
    } else if (matpc_type == QUDA_MATPC_ODD_ODD) {
      wil_dslash(tmp, gauge, inEven, 0, daggerBit, precision, gauge_param);
      twist_gamma5(tmp, tmp, daggerBit, kappa, mu, flavor, Vh, QUDA_TWIST_GAMMA5_INVERSE, precision);
      wil_dslash(outEven, gauge, tmp, 1, daggerBit, precision, gauge_param);
      twist_gamma5(outEven, outEven, daggerBit, kappa, mu, flavor, Vh, QUDA_TWIST_GAMMA5_INVERSE, precision);
    }
  } else {
    if (matpc_type == QUDA_MATPC_EVEN_EVEN) {
      twist_gamma5(inEven, inEven, daggerBit, kappa, mu, flavor, Vh, QUDA_TWIST_GAMMA5_INVERSE, precision);
      wil_dslash(tmp, gauge, inEven, 1, daggerBit, precision, gauge_param);
      twist_gamma5(tmp, tmp, daggerBit, kappa, mu, flavor, Vh, QUDA_TWIST_GAMMA5_INVERSE, precision);
      wil_dslash(outEven, gauge, tmp, 0, daggerBit, precision, gauge_param);
      twist_gamma5(inEven, inEven, daggerBit, kappa, mu, flavor, Vh, QUDA_TWIST_GAMMA5_DIRECT, precision);
    } else if (matpc_type == QUDA_MATPC_ODD_ODD) {
      twist_gamma5(inEven, inEven, daggerBit, kappa, mu, flavor, Vh, QUDA_TWIST_GAMMA5_INVERSE, precision);
      wil_dslash(tmp, gauge, inEven, 0, daggerBit, precision, gauge_param);
      twist_gamma5(tmp, tmp, daggerBit, kappa, mu, flavor, Vh, QUDA_TWIST_GAMMA5_INVERSE, precision);
      wil_dslash(outEven, gauge, tmp, 1, daggerBit, precision, gauge_param);
      twist_gamma5(inEven, inEven, daggerBit, kappa, mu, flavor, Vh, QUDA_TWIST_GAMMA5_DIRECT, precision); // undo
    }
  }

  // lastly apply the kappa term
  double kappa2 = -kappa*kappa;
  if (matpc_type == QUDA_MATPC_EVEN_EVEN || matpc_type == QUDA_MATPC_ODD_ODD) {
    if (precision == QUDA_DOUBLE_PRECISION) xpay((double*)inEven, kappa2, (double*)outEven, Vh*spinorSiteSize);
    else xpay((float*)inEven, (float)kappa2, (float*)outEven, Vh*spinorSiteSize);
  } else {
    if (precision == QUDA_DOUBLE_PRECISION) xpay((double*)tmp, kappa2, (double*)outEven, Vh*spinorSiteSize);
    else xpay((float*)tmp, (float)kappa2, (float*)outEven, Vh*spinorSiteSize);
  }

  free(tmp);
}
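
// Editor's sketch (not part of the original file): twistGamma5() applies
// b*(1 + i*a*gamma_5) per site, with the upper two spin components picking up
// +a and the lower two -a, and the INVERSE variant carries the extra factor
// b = 1/(1 + a*a).  Applying DIRECT followed by INVERSE should therefore
// reproduce the input (up to rounding), which gives a simple self-test of the
// twist code.  The helper below is hypothetical and purely illustrative
// (single site, double precision, no dagger).
void twist_roundtrip_check(double kappa, double mu, QudaTwistFlavorType flavor) {
  double in[24], tmp[24], out[24];
  for (int i = 0; i < 24; i++) in[i] = 0.1*i - 1.0; // arbitrary test data

  twistGamma5(tmp, in, 0, kappa, mu, flavor, 1, QUDA_TWIST_GAMMA5_DIRECT);
  twistGamma5(out, tmp, 0, kappa, mu, flavor, 1, QUDA_TWIST_GAMMA5_INVERSE);

  for (int i = 0; i < 24; i++)
    if (fabs(out[i] - in[i]) > 1e-12) printf("twist round-trip failed at %d\n", i);
}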