// QUDA v0.4.0
// A library for QCD on GPUs
00001 #include <stdio.h> 00002 #include <stdlib.h> 00003 #include <math.h> 00004 #include <string.h> 00005 00006 #include <test_util.h> 00007 #include <quda_internal.h> 00008 #include <quda.h> 00009 #include <util_quda.h> 00010 #include <staggered_dslash_reference.h> 00011 #include "misc.h" 00012 #include <blas_quda.h> 00013 00014 #include <face_quda.h> 00015 00016 extern void *memset(void *s, int c, size_t n); 00017 00018 static int mySpinorSiteSize = 6; 00019 00020 #include <dslash_util.h> 00021 00022 // 00023 // dslashReference() 00024 // 00025 // if oddBit is zero: calculate even parity spinor elements (using odd parity spinor) 00026 // if oddBit is one: calculate odd parity spinor elements 00027 // 00028 // if daggerBit is zero: perform ordinary dslash operator 00029 // if daggerBit is one: perform hermitian conjugate of dslash 00030 // 00031 template<typename Float> 00032 void display_link_internal(Float* link) 00033 { 00034 int i, j; 00035 00036 for (i = 0;i < 3; i++){ 00037 for(j=0;j < 3; j++){ 00038 printf("(%10f,%10f) \t", link[i*3*2 + j*2], link[i*3*2 + j*2 + 1]); 00039 } 00040 printf("\n"); 00041 } 00042 printf("\n"); 00043 return; 00044 } 00045 00046 00047 template <typename sFloat, typename gFloat> 00048 void dslashReference(sFloat *res, gFloat **fatlink, gFloat** longlink, sFloat *spinorField, 00049 int oddBit, int daggerBit) 00050 { 00051 for (int i=0; i<Vh*1*3*2; i++) res[i] = 0.0; 00052 00053 gFloat *fatlinkEven[4], *fatlinkOdd[4]; 00054 gFloat *longlinkEven[4], *longlinkOdd[4]; 00055 00056 for (int dir = 0; dir < 4; dir++) { 00057 fatlinkEven[dir] = fatlink[dir]; 00058 fatlinkOdd[dir] = fatlink[dir] + Vh*gaugeSiteSize; 00059 longlinkEven[dir] =longlink[dir]; 00060 longlinkOdd[dir] = longlink[dir] + Vh*gaugeSiteSize; 00061 } 00062 00063 for (int i = 0; i < Vh; i++) { 00064 memset(res + i*mySpinorSiteSize, 0, mySpinorSiteSize*sizeof(sFloat)); 00065 for (int dir = 0; dir < 8; dir++) { 00066 gFloat* fatlnk = gaugeLink(i, dir, oddBit, fatlinkEven, 
fatlinkOdd, 1); 00067 gFloat* longlnk = gaugeLink(i, dir, oddBit, longlinkEven, longlinkOdd, 3); 00068 00069 sFloat *first_neighbor_spinor = spinorNeighbor(i, dir, oddBit, spinorField, 1); 00070 sFloat *third_neighbor_spinor = spinorNeighbor(i, dir, oddBit, spinorField, 3); 00071 00072 00073 sFloat gaugedSpinor[mySpinorSiteSize]; 00074 00075 if (dir % 2 == 0){ 00076 su3Mul(gaugedSpinor, fatlnk, first_neighbor_spinor); 00077 sum(&res[i*mySpinorSiteSize], &res[i*mySpinorSiteSize], gaugedSpinor, mySpinorSiteSize); 00078 su3Mul(gaugedSpinor, longlnk, third_neighbor_spinor); 00079 sum(&res[i*mySpinorSiteSize], &res[i*mySpinorSiteSize], gaugedSpinor, mySpinorSiteSize); 00080 } else { 00081 su3Tmul(gaugedSpinor, fatlnk, first_neighbor_spinor); 00082 sub(&res[i*mySpinorSiteSize], &res[i*mySpinorSiteSize], gaugedSpinor, mySpinorSiteSize); 00083 00084 su3Tmul(gaugedSpinor, longlnk, third_neighbor_spinor); 00085 sub(&res[i*mySpinorSiteSize], &res[i*mySpinorSiteSize], gaugedSpinor, mySpinorSiteSize); 00086 } 00087 } 00088 if (daggerBit){ 00089 negx(&res[i*mySpinorSiteSize], mySpinorSiteSize); 00090 } 00091 } 00092 00093 } 00094 00095 00096 00097 00098 void staggered_dslash(void *res, void **fatlink, void** longlink, void *spinorField, int oddBit, int daggerBit, 00099 QudaPrecision sPrecision, QudaPrecision gPrecision) { 00100 00101 if (sPrecision == QUDA_DOUBLE_PRECISION) { 00102 if (gPrecision == QUDA_DOUBLE_PRECISION){ 00103 dslashReference((double*)res, (double**)fatlink, (double**)longlink, (double*)spinorField, oddBit, daggerBit); 00104 }else{ 00105 dslashReference((double*)res, (float**)fatlink, (float**)longlink, (double*)spinorField, oddBit, daggerBit); 00106 } 00107 } 00108 else{ 00109 if (gPrecision == QUDA_DOUBLE_PRECISION){ 00110 dslashReference((float*)res, (double**)fatlink, (double**)longlink, (float*)spinorField, oddBit, daggerBit); 00111 }else{ 00112 dslashReference((float*)res, (float**)fatlink, (float**)longlink, (float*)spinorField, oddBit, daggerBit); 
00113 } 00114 } 00115 } 00116 00117 00118 00119 00120 template <typename sFloat, typename gFloat> 00121 void Mat(sFloat *out, gFloat **fatlink, gFloat** longlink, sFloat *in, sFloat kappa, int daggerBit) 00122 { 00123 sFloat *inEven = in; 00124 sFloat *inOdd = in + Vh*mySpinorSiteSize; 00125 sFloat *outEven = out; 00126 sFloat *outOdd = out + Vh*mySpinorSiteSize; 00127 00128 // full dslash operator 00129 dslashReference(outOdd, fatlink, longlink, inEven, 1, daggerBit); 00130 dslashReference(outEven, fatlink, longlink, inOdd, 0, daggerBit); 00131 00132 // lastly apply the kappa term 00133 xpay(in, -kappa, out, V*mySpinorSiteSize); 00134 } 00135 00136 00137 void 00138 mat(void *out, void **fatlink, void** longlink, void *in, double kappa, int dagger_bit, 00139 QudaPrecision sPrecision, QudaPrecision gPrecision) 00140 { 00141 00142 if (sPrecision == QUDA_DOUBLE_PRECISION){ 00143 if (gPrecision == QUDA_DOUBLE_PRECISION) { 00144 Mat((double*)out, (double**)fatlink, (double**)longlink, (double*)in, (double)kappa, dagger_bit); 00145 }else { 00146 Mat((double*)out, (float**)fatlink, (float**)longlink, (double*)in, (double)kappa, dagger_bit); 00147 } 00148 }else{ 00149 if (gPrecision == QUDA_DOUBLE_PRECISION){ 00150 Mat((float*)out, (double**)fatlink, (double**)longlink, (float*)in, (float)kappa, dagger_bit); 00151 }else { 00152 Mat((float*)out, (float**)fatlink, (float**)longlink, (float*)in, (float)kappa, dagger_bit); 00153 } 00154 } 00155 } 00156 00157 00158 00159 template <typename sFloat, typename gFloat> 00160 void 00161 Matdagmat(sFloat *out, gFloat **fatlink, gFloat** longlink, sFloat *in, sFloat mass, int daggerBit, sFloat* tmp, QudaParity parity) 00162 { 00163 00164 sFloat msq_x4 = mass*mass*4; 00165 00166 switch(parity){ 00167 case QUDA_EVEN_PARITY: 00168 { 00169 sFloat *inEven = in; 00170 sFloat *outEven = out; 00171 dslashReference(tmp, fatlink, longlink, inEven, 1, daggerBit); 00172 dslashReference(outEven, fatlink, longlink, tmp, 0, daggerBit); 00173 00174 // 
lastly apply the mass term 00175 axmy(inEven, msq_x4, outEven, Vh*mySpinorSiteSize); 00176 break; 00177 } 00178 case QUDA_ODD_PARITY: 00179 { 00180 sFloat *inOdd = in; 00181 sFloat *outOdd = out; 00182 dslashReference(tmp, fatlink, longlink, inOdd, 0, daggerBit); 00183 dslashReference(outOdd, fatlink, longlink, tmp, 1, daggerBit); 00184 00185 // lastly apply the mass term 00186 axmy(inOdd, msq_x4, outOdd, Vh*mySpinorSiteSize); 00187 break; 00188 } 00189 00190 default: 00191 fprintf(stderr, "ERROR: invalid parity in %s,line %d\n", __FUNCTION__, __LINE__); 00192 break; 00193 } 00194 00195 } 00196 00197 00198 00199 void 00200 matdagmat(void *out, void **fatlink, void** longlink, void *in, double mass, int dagger_bit, 00201 QudaPrecision sPrecision, QudaPrecision gPrecision, void* tmp, QudaParity parity) 00202 { 00203 00204 if (sPrecision == QUDA_DOUBLE_PRECISION){ 00205 if (gPrecision == QUDA_DOUBLE_PRECISION) { 00206 Matdagmat((double*)out, (double**)fatlink, (double**)longlink, (double*)in, (double)mass, dagger_bit, (double*)tmp, parity); 00207 }else { 00208 Matdagmat((double*)out, (float**)fatlink, (float**)longlink, (double*)in, (double)mass, dagger_bit, (double*) tmp, parity); 00209 } 00210 }else{ 00211 if (gPrecision == QUDA_DOUBLE_PRECISION){ 00212 Matdagmat((float*)out, (double**)fatlink, (double**)longlink, (float*)in, (float)mass, dagger_bit, (float*)tmp, parity); 00213 }else { 00214 Matdagmat((float*)out, (float**)fatlink, (float**)longlink, (float*)in, (float)mass, dagger_bit, (float*)tmp, parity); 00215 } 00216 } 00217 } 00218 00219 00220 00221 00222 00223 // Apply the even-odd preconditioned Dirac operator 00224 template <typename sFloat, typename gFloat> 00225 static void MatPC(sFloat *outEven, gFloat **fatlink, gFloat** longlink, sFloat *inEven, sFloat kappa, 00226 int daggerBit, MatPCType matpc_type) { 00227 00228 sFloat *tmp = (sFloat*)malloc(Vh*mySpinorSiteSize*sizeof(sFloat)); 00229 00230 // full dslash operator 00231 if (matpc_type == 
QUDA_MATPC_EVEN_EVEN) { 00232 dslashReference(tmp, fatlink, longlink, inEven, 1, daggerBit); 00233 dslashReference(outEven, fatlink, longlink, tmp, 0, daggerBit); 00234 00235 //dslashReference(outEven, fatlink, longlink, inEven, 1, daggerBit); 00236 } else { 00237 dslashReference(tmp, fatlink, longlink, inEven, 0, daggerBit); 00238 dslashReference(outEven, fatlink, longlink, tmp, 1, daggerBit); 00239 } 00240 00241 // lastly apply the kappa term 00242 00243 sFloat kappa2 = -kappa*kappa; 00244 xpay(inEven, kappa2, outEven, Vh*mySpinorSiteSize); 00245 00246 free(tmp); 00247 } 00248 00249 00250 void 00251 staggered_matpc(void *outEven, void **fatlink, void**longlink, void *inEven, double kappa, 00252 MatPCType matpc_type, int dagger_bit, QudaPrecision sPrecision, QudaPrecision gPrecision) 00253 { 00254 00255 if (sPrecision == QUDA_DOUBLE_PRECISION) 00256 if (gPrecision == QUDA_DOUBLE_PRECISION) { 00257 MatPC((double*)outEven, (double**)fatlink, (double**)longlink, (double*)inEven, (double)kappa, dagger_bit, matpc_type); 00258 } 00259 else{ 00260 MatPC((double*)outEven, (double**)fatlink, (double**)longlink, (double*)inEven, (double)kappa, dagger_bit, matpc_type); 00261 } 00262 else { 00263 if (gPrecision == QUDA_DOUBLE_PRECISION){ 00264 MatPC((float*)outEven, (double**)fatlink, (double**)longlink, (float*)inEven, (float)kappa, dagger_bit, matpc_type); 00265 }else{ 00266 MatPC((float*)outEven, (float**)fatlink, (float**)longlink, (float*)inEven, (float)kappa, dagger_bit, matpc_type); 00267 } 00268 } 00269 } 00270 00271 #ifdef MULTI_GPU 00272 00273 template <typename sFloat, typename gFloat> 00274 void dslashReference_mg4dir(sFloat *res, gFloat **fatlink, gFloat** longlink, 00275 gFloat** ghostFatlink, gFloat** ghostLonglink, 00276 sFloat *spinorField, sFloat** fwd_nbr_spinor, 00277 sFloat** back_nbr_spinor, int oddBit, int daggerBit) 00278 { 00279 for (int i=0; i<Vh*1*3*2; i++) res[i] = 0.0; 00280 00281 int Vsh[4] = {Vsh_x, Vsh_y, Vsh_z, Vsh_t}; 00282 gFloat 
*fatlinkEven[4], *fatlinkOdd[4]; 00283 gFloat *longlinkEven[4], *longlinkOdd[4]; 00284 gFloat *ghostFatlinkEven[4], *ghostFatlinkOdd[4]; 00285 gFloat *ghostLonglinkEven[4], *ghostLonglinkOdd[4]; 00286 00287 for (int dir = 0; dir < 4; dir++) { 00288 fatlinkEven[dir] = fatlink[dir]; 00289 fatlinkOdd[dir] = fatlink[dir] + Vh*gaugeSiteSize; 00290 longlinkEven[dir] =longlink[dir]; 00291 longlinkOdd[dir] = longlink[dir] + Vh*gaugeSiteSize; 00292 00293 ghostFatlinkEven[dir] = ghostFatlink[dir]; 00294 ghostFatlinkOdd[dir] = ghostFatlink[dir] + Vsh[dir]*gaugeSiteSize; 00295 ghostLonglinkEven[dir] = ghostLonglink[dir]; 00296 ghostLonglinkOdd[dir] = ghostLonglink[dir] + 3*Vsh[dir]*gaugeSiteSize; 00297 } 00298 00299 for (int i = 0; i < Vh; i++) { 00300 memset(res + i*mySpinorSiteSize, 0, mySpinorSiteSize*sizeof(sFloat)); 00301 for (int dir = 0; dir < 8; dir++) { 00302 gFloat* fatlnk = gaugeLink_mg4dir(i, dir, oddBit, fatlinkEven, fatlinkOdd, ghostFatlinkEven, ghostFatlinkOdd, 1, 1); 00303 gFloat* longlnk = gaugeLink_mg4dir(i, dir, oddBit, longlinkEven, longlinkOdd, ghostLonglinkEven, ghostLonglinkOdd, 3, 3); 00304 00305 sFloat *first_neighbor_spinor = spinorNeighbor_mg4dir(i, dir, oddBit, spinorField, fwd_nbr_spinor, back_nbr_spinor, 1, 3); 00306 sFloat *third_neighbor_spinor = spinorNeighbor_mg4dir(i, dir, oddBit, spinorField, fwd_nbr_spinor, back_nbr_spinor, 3, 3); 00307 00308 sFloat gaugedSpinor[mySpinorSiteSize]; 00309 00310 00311 if (dir % 2 == 0){ 00312 su3Mul(gaugedSpinor, fatlnk, first_neighbor_spinor); 00313 sum(&res[i*mySpinorSiteSize], &res[i*mySpinorSiteSize], gaugedSpinor, mySpinorSiteSize); 00314 su3Mul(gaugedSpinor, longlnk, third_neighbor_spinor); 00315 sum(&res[i*mySpinorSiteSize], &res[i*mySpinorSiteSize], gaugedSpinor, mySpinorSiteSize); 00316 } 00317 else{ 00318 su3Tmul(gaugedSpinor, fatlnk, first_neighbor_spinor); 00319 sub(&res[i*mySpinorSiteSize], &res[i*mySpinorSiteSize], gaugedSpinor, mySpinorSiteSize); 00320 00321 su3Tmul(gaugedSpinor, longlnk, 
third_neighbor_spinor); 00322 sub(&res[i*mySpinorSiteSize], &res[i*mySpinorSiteSize], gaugedSpinor, mySpinorSiteSize); 00323 00324 } 00325 00326 } 00327 if (daggerBit){ 00328 negx(&res[i*mySpinorSiteSize], mySpinorSiteSize); 00329 } 00330 } 00331 00332 } 00333 00334 00335 00336 void staggered_dslash_mg4dir(cpuColorSpinorField* out, void **fatlink, void** longlink, void** ghost_fatlink, 00337 void** ghost_longlink, cpuColorSpinorField* in, int oddBit, int daggerBit, 00338 QudaPrecision sPrecision, QudaPrecision gPrecision) 00339 { 00340 00341 QudaParity otherparity = QUDA_INVALID_PARITY; 00342 if (oddBit == QUDA_EVEN_PARITY){ 00343 otherparity = QUDA_ODD_PARITY; 00344 }else if (oddBit == QUDA_ODD_PARITY){ 00345 otherparity = QUDA_EVEN_PARITY; 00346 }else{ 00347 errorQuda("ERROR: full parity not supported in function %s", __FUNCTION__); 00348 } 00349 00350 int Nc = 3; 00351 int nFace = 3; 00352 FaceBuffer faceBuf(Z, 4, 2*Nc, nFace, sPrecision); 00353 faceBuf.exchangeCpuSpinor(*in, otherparity, daggerBit); 00354 00355 void** fwd_nbr_spinor = in->fwdGhostFaceBuffer; 00356 void** back_nbr_spinor = in->backGhostFaceBuffer; 00357 00358 if (sPrecision == QUDA_DOUBLE_PRECISION) { 00359 if (gPrecision == QUDA_DOUBLE_PRECISION){ 00360 dslashReference_mg4dir((double*)out->V(), (double**)fatlink, (double**)longlink, 00361 (double**)ghost_fatlink, (double**)ghost_longlink, (double*)in->V(), 00362 (double**)fwd_nbr_spinor, (double**)back_nbr_spinor, oddBit, daggerBit); 00363 } else { 00364 dslashReference_mg4dir((double*)out->V(), (float**)fatlink, (float**)longlink, (float**)ghost_fatlink, (float**)ghost_longlink, 00365 (double*)in->V(), (double**)fwd_nbr_spinor, (double**)back_nbr_spinor, oddBit, daggerBit); 00366 } 00367 } 00368 else{ 00369 if (gPrecision == QUDA_DOUBLE_PRECISION){ 00370 dslashReference_mg4dir((float*)out->V(), (double**)fatlink, (double**)longlink, (double**)ghost_fatlink, (double**)ghost_longlink, 00371 (float*)in->V(), (float**)fwd_nbr_spinor, 
(float**)back_nbr_spinor, oddBit, daggerBit); 00372 }else{ 00373 dslashReference_mg4dir((float*)out->V(), (float**)fatlink, (float**)longlink, (float**)ghost_fatlink, (float**)ghost_longlink, 00374 (float*)in->V(), (float**)fwd_nbr_spinor, (float**)back_nbr_spinor, oddBit, daggerBit); 00375 } 00376 } 00377 00378 00379 } 00380 00381 void 00382 matdagmat_mg4dir(cpuColorSpinorField* out, void **fatlink, void** longlink, void** ghost_fatlink, void** ghost_longlink, 00383 cpuColorSpinorField* in, double mass, int dagger_bit, 00384 QudaPrecision sPrecision, QudaPrecision gPrecision, cpuColorSpinorField* tmp, QudaParity parity) 00385 { 00386 //assert sPrecision and gPrecision must be the same 00387 if (sPrecision != gPrecision){ 00388 errorQuda("Spinor precision and gPrecison is not the same"); 00389 } 00390 00391 QudaParity otherparity = QUDA_INVALID_PARITY; 00392 if (parity == QUDA_EVEN_PARITY){ 00393 otherparity = QUDA_ODD_PARITY; 00394 }else if (parity == QUDA_ODD_PARITY){ 00395 otherparity = QUDA_EVEN_PARITY; 00396 }else{ 00397 errorQuda("ERROR: full parity not supported in function %s\n", __FUNCTION__); 00398 } 00399 00400 staggered_dslash_mg4dir(tmp, fatlink, longlink, ghost_fatlink, ghost_longlink, 00401 in, otherparity, dagger_bit, sPrecision, gPrecision); 00402 00403 staggered_dslash_mg4dir(out, fatlink, longlink, ghost_fatlink, ghost_longlink, 00404 tmp, parity, dagger_bit, sPrecision, gPrecision); 00405 00406 double msq_x4 = mass*mass*4; 00407 if (sPrecision == QUDA_DOUBLE_PRECISION){ 00408 axmy((double*)in->V(), (double)msq_x4, (double*)out->V(), Vh*mySpinorSiteSize); 00409 }else{ 00410 axmy((float*)in->V(), (float)msq_x4, (float*)out->V(), Vh*mySpinorSiteSize); 00411 } 00412 00413 } 00414 00415 #endif 00416