/*
 * QUDA v0.3.2
 * A library for QCD on GPUs
 */
00001 #include <stdio.h> 00002 #include <stdlib.h> 00003 #include <math.h> 00004 #include <string.h> 00005 00006 #include <test_util.h> 00007 #include <quda_internal.h> 00008 #include <quda.h> 00009 #include <util_quda.h> 00010 #include <staggered_dslash_reference.h> 00011 #include "misc.h" 00012 extern void *memset(void *s, int c, size_t n); 00013 00014 static int mySpinorSiteSize = 6; 00015 00016 int Z[4]; 00017 int V; 00018 int Vh; 00019 00020 void setDims(int *X) { 00021 V = 1; 00022 for (int d=0; d< 4; d++) { 00023 V *= X[d]; 00024 Z[d] = X[d]; 00025 } 00026 Vh = V/2; 00027 } 00028 00029 template <typename Float> 00030 void sum(Float *dst, Float *a, Float *b, int cnt) { 00031 for (int i = 0; i < cnt; i++) 00032 dst[i] = a[i] + b[i]; 00033 } 00034 template <typename Float> 00035 void sub(Float *dst, Float *a, Float *b, int cnt) { 00036 for (int i = 0; i < cnt; i++) 00037 dst[i] = a[i] - b[i]; 00038 } 00039 // performs the operation y[i] = x[i] + a*y[i] 00040 template <typename Float> 00041 void xpay(Float *x, Float a, Float *y, int len) { 00042 for (int i=0; i<len; i++) y[i] = x[i] + a*y[i]; 00043 } 00044 // performs the operation y[i] = a*x[i] - y[i] 00045 template <typename Float> 00046 void axmy(Float *x, Float a, Float *y, int len) { 00047 for (int i=0; i<len; i++) y[i] = a*x[i] - y[i]; 00048 } 00049 00050 template <typename Float> 00051 void negx(Float *x, int len) { 00052 for (int i=0; i<len; i++) x[i] = -x[i]; 00053 } 00054 00055 // i represents a "half index" into an even or odd "half lattice". 00056 // when oddBit={0,1} the half lattice is {even,odd}. 00057 // 00058 // the displacements, such as dx, refer to the full lattice coordinates. 00059 // 00060 // neighborIndex() takes a "half index", displaces it, and returns the 00061 // new "half index", which can be an index into either the even or odd lattices. 00062 // displacements of magnitude one always interchange odd and even lattices. 
00063 // 00064 00065 00066 template <typename Float> 00067 Float *gaugeLink(int i, int dir, int oddBit, Float **gaugeEven, Float **gaugeOdd, int nbr_distance) { 00068 Float **gaugeField; 00069 int j; 00070 int d = nbr_distance; 00071 if (dir % 2 == 0) { 00072 j = i; 00073 gaugeField = (oddBit ? gaugeOdd : gaugeEven); 00074 } 00075 else { 00076 switch (dir) { 00077 case 1: j = neighborIndex(i, oddBit, 0, 0, 0, -d); break; 00078 case 3: j = neighborIndex(i, oddBit, 0, 0, -d, 0); break; 00079 case 5: j = neighborIndex(i, oddBit, 0, -d, 0, 0); break; 00080 case 7: j = neighborIndex(i, oddBit, -d, 0, 0, 0); break; 00081 default: j = -1; break; 00082 } 00083 gaugeField = (oddBit ? gaugeEven : gaugeOdd); 00084 } 00085 00086 return &gaugeField[dir/2][j*(3*3*2)]; 00087 } 00088 00089 00090 00091 template <typename Float> 00092 Float *spinorNeighbor(int i, int dir, int oddBit, Float *spinorField, int neighbor_distance) 00093 { 00094 int j; 00095 int nb = neighbor_distance; 00096 switch (dir) { 00097 case 0: j = neighborIndex(i, oddBit, 0, 0, 0, +nb); break; 00098 case 1: j = neighborIndex(i, oddBit, 0, 0, 0, -nb); break; 00099 case 2: j = neighborIndex(i, oddBit, 0, 0, +nb, 0); break; 00100 case 3: j = neighborIndex(i, oddBit, 0, 0, -nb, 0); break; 00101 case 4: j = neighborIndex(i, oddBit, 0, +nb, 0, 0); break; 00102 case 5: j = neighborIndex(i, oddBit, 0, -nb, 0, 0); break; 00103 case 6: j = neighborIndex(i, oddBit, +nb, 0, 0, 0); break; 00104 case 7: j = neighborIndex(i, oddBit, -nb, 0, 0, 0); break; 00105 default: j = -1; break; 00106 } 00107 00108 return &spinorField[j*(mySpinorSiteSize)]; 00109 } 00110 00111 template <typename sFloat, typename gFloat> 00112 void dot(sFloat* res, gFloat* a, sFloat* b) { 00113 res[0] = res[1] = 0; 00114 for (int m = 0; m < 3; m++) { 00115 sFloat a_re = a[2*m+0]; 00116 sFloat a_im = a[2*m+1]; 00117 sFloat b_re = b[2*m+0]; 00118 sFloat b_im = b[2*m+1]; 00119 res[0] += a_re * b_re - a_im * b_im; 00120 res[1] += a_re * b_im + a_im * b_re; 
00121 } 00122 } 00123 00124 template <typename Float> 00125 void su3Transpose(Float *res, Float *mat) { 00126 for (int m = 0; m < 3; m++) { 00127 for (int n = 0; n < 3; n++) { 00128 res[m*(3*2) + n*(2) + 0] = + mat[n*(3*2) + m*(2) + 0]; 00129 res[m*(3*2) + n*(2) + 1] = - mat[n*(3*2) + m*(2) + 1]; 00130 } 00131 } 00132 } 00133 00134 00135 template <typename sFloat, typename gFloat> 00136 void su3Mul(sFloat *res, gFloat *mat, sFloat *vec) { 00137 for (int n = 0; n < 3; n++) dot(&res[n*(2)], &mat[n*(3*2)], vec); 00138 } 00139 00140 template <typename sFloat, typename gFloat> 00141 void su3Tmul(sFloat *res, gFloat *mat, sFloat *vec) { 00142 gFloat matT[3*3*2]; 00143 su3Transpose(matT, mat); 00144 su3Mul(res, matT, vec); 00145 } 00146 00147 00148 // 00149 // dslashReference() 00150 // 00151 // if oddBit is zero: calculate even parity spinor elements (using odd parity spinor) 00152 // if oddBit is one: calculate odd parity spinor elements 00153 // 00154 // if daggerBit is zero: perform ordinary dslash operator 00155 // if daggerBit is one: perform hermitian conjugate of dslash 00156 // 00157 template<typename Float> 00158 void display_link_internal(Float* link) 00159 { 00160 int i, j; 00161 00162 for (i = 0;i < 3; i++){ 00163 for(j=0;j < 3; j++){ 00164 printf("(%10f,%10f) \t", link[i*3*2 + j*2], link[i*3*2 + j*2 + 1]); 00165 } 00166 printf("\n"); 00167 } 00168 printf("\n"); 00169 return; 00170 } 00171 00172 00173 template <typename sFloat, typename gFloat> 00174 void dslashReference(sFloat *res, gFloat **fatlink, gFloat** longlink, sFloat *spinorField, int oddBit, int daggerBit) 00175 { 00176 for (int i=0; i<Vh*1*3*2; i++) res[i] = 0.0; 00177 00178 gFloat *fatlinkEven[4], *fatlinkOdd[4]; 00179 gFloat *longlinkEven[4], *longlinkOdd[4]; 00180 00181 for (int dir = 0; dir < 4; dir++) { 00182 fatlinkEven[dir] = fatlink[dir]; 00183 fatlinkOdd[dir] = fatlink[dir] + Vh*gaugeSiteSize; 00184 longlinkEven[dir] =longlink[dir]; 00185 longlinkOdd[dir] = longlink[dir] + 
Vh*gaugeSiteSize; 00186 } 00187 00188 for (int i = 0; i < Vh; i++) { 00189 memset(res + i*mySpinorSiteSize, 0, mySpinorSiteSize*sizeof(sFloat)); 00190 for (int dir = 0; dir < 8; dir++) { 00191 gFloat* fatlnk = gaugeLink(i, dir, oddBit, fatlinkEven, fatlinkOdd, 1); 00192 gFloat* longlnk = gaugeLink(i, dir, oddBit, longlinkEven, longlinkOdd, 3); 00193 00194 sFloat *first_neighbor_spinor = spinorNeighbor(i, dir, oddBit, spinorField, 1); 00195 sFloat *third_neighbor_spinor = spinorNeighbor(i, dir, oddBit, spinorField, 3); 00196 00197 00198 sFloat gaugedSpinor[mySpinorSiteSize]; 00199 00200 if (dir % 2 == 0){ 00201 su3Mul(gaugedSpinor, fatlnk, first_neighbor_spinor); 00202 sum(&res[i*mySpinorSiteSize], &res[i*mySpinorSiteSize], gaugedSpinor, mySpinorSiteSize); 00203 su3Mul(gaugedSpinor, longlnk, third_neighbor_spinor); 00204 sum(&res[i*mySpinorSiteSize], &res[i*mySpinorSiteSize], gaugedSpinor, mySpinorSiteSize); 00205 } 00206 else{ 00207 su3Tmul(gaugedSpinor, fatlnk, first_neighbor_spinor); 00208 sub(&res[i*mySpinorSiteSize], &res[i*mySpinorSiteSize], gaugedSpinor, mySpinorSiteSize); 00209 00210 su3Tmul(gaugedSpinor, longlnk, third_neighbor_spinor); 00211 sub(&res[i*mySpinorSiteSize], &res[i*mySpinorSiteSize], gaugedSpinor, mySpinorSiteSize); 00212 00213 } 00214 } 00215 if (daggerBit){ 00216 negx(&res[i*mySpinorSiteSize], mySpinorSiteSize); 00217 } 00218 } 00219 00220 } 00221 00222 00223 void staggered_dslash(void *res, void **fatlink, void** longlink, void *spinorField, int oddBit, int daggerBit, 00224 QudaPrecision sPrecision, QudaPrecision gPrecision) { 00225 00226 if (sPrecision == QUDA_DOUBLE_PRECISION) { 00227 if (gPrecision == QUDA_DOUBLE_PRECISION){ 00228 dslashReference((double*)res, (double**)fatlink, (double**)longlink, (double*)spinorField, oddBit, daggerBit); 00229 }else{ 00230 dslashReference((double*)res, (float**)fatlink, (float**)longlink, (double*)spinorField, oddBit, daggerBit); 00231 } 00232 } 00233 else{ 00234 if (gPrecision == 
QUDA_DOUBLE_PRECISION){ 00235 dslashReference((float*)res, (double**)fatlink, (double**)longlink, (float*)spinorField, oddBit, daggerBit); 00236 }else{ 00237 dslashReference((float*)res, (float**)fatlink, (float**)longlink, (float*)spinorField, oddBit, daggerBit); 00238 } 00239 } 00240 } 00241 00242 00243 00244 template <typename sFloat, typename gFloat> 00245 void Mat(sFloat *out, gFloat **fatlink, gFloat** longlink, sFloat *in, sFloat kappa, int daggerBit) 00246 { 00247 sFloat *inEven = in; 00248 sFloat *inOdd = in + Vh*mySpinorSiteSize; 00249 sFloat *outEven = out; 00250 sFloat *outOdd = out + Vh*mySpinorSiteSize; 00251 00252 // full dslash operator 00253 dslashReference(outOdd, fatlink, longlink, inEven, 1, daggerBit); 00254 dslashReference(outEven, fatlink, longlink, inOdd, 0, daggerBit); 00255 00256 // lastly apply the kappa term 00257 xpay(in, -kappa, out, V*mySpinorSiteSize); 00258 } 00259 00260 00261 void 00262 mat(void *out, void **fatlink, void** longlink, void *in, double kappa, int dagger_bit, 00263 QudaPrecision sPrecision, QudaPrecision gPrecision) 00264 { 00265 00266 if (sPrecision == QUDA_DOUBLE_PRECISION){ 00267 if (gPrecision == QUDA_DOUBLE_PRECISION) { 00268 Mat((double*)out, (double**)fatlink, (double**)longlink, (double*)in, (double)kappa, dagger_bit); 00269 }else { 00270 Mat((double*)out, (float**)fatlink, (float**)longlink, (double*)in, (double)kappa, dagger_bit); 00271 } 00272 }else{ 00273 if (gPrecision == QUDA_DOUBLE_PRECISION){ 00274 Mat((float*)out, (double**)fatlink, (double**)longlink, (float*)in, (float)kappa, dagger_bit); 00275 }else { 00276 Mat((float*)out, (float**)fatlink, (float**)longlink, (float*)in, (float)kappa, dagger_bit); 00277 } 00278 } 00279 } 00280 00281 00282 00283 template <typename sFloat, typename gFloat> 00284 void 00285 Matdagmat_milc(sFloat *out, gFloat **fatlink, gFloat** longlink, sFloat *in, sFloat mass, int daggerBit, sFloat* tmp, MyQudaParity parity) 00286 { 00287 00288 sFloat msq_x4 = mass*mass*4; 00289 
00290 switch(parity){ 00291 case QUDA_EVEN: 00292 { 00293 sFloat *inEven = in; 00294 sFloat *outEven = out; 00295 dslashReference(tmp, fatlink, longlink, inEven, 1, daggerBit); 00296 dslashReference(outEven, fatlink, longlink, tmp, 0, daggerBit); 00297 00298 // lastly apply the mass term 00299 axmy(inEven, msq_x4, outEven, Vh*mySpinorSiteSize); 00300 break; 00301 } 00302 case QUDA_ODD: 00303 { 00304 sFloat *inOdd = in; 00305 sFloat *outOdd = out; 00306 dslashReference(tmp, fatlink, longlink, inOdd, 0, daggerBit); 00307 dslashReference(outOdd, fatlink, longlink, tmp, 1, daggerBit); 00308 00309 // lastly apply the mass term 00310 axmy(inOdd, msq_x4, outOdd, Vh*mySpinorSiteSize); 00311 break; 00312 } 00313 00314 case QUDA_EVENODD: 00315 { 00316 sFloat *inEven = in; 00317 sFloat *inOdd = in + Vh*mySpinorSiteSize; 00318 sFloat *outEven = out; 00319 sFloat *outOdd = out + Vh*mySpinorSiteSize; 00320 sFloat *tmpEven = tmp; 00321 sFloat *tmpOdd = tmp + Vh*mySpinorSiteSize; 00322 00323 dslashReference(tmpOdd, fatlink, longlink, inEven, 1, daggerBit); 00324 dslashReference(tmpEven, fatlink, longlink, inOdd, 0, daggerBit); 00325 00326 dslashReference(outOdd, fatlink, longlink, tmpEven, 1, daggerBit); 00327 dslashReference(outEven, fatlink, longlink, tmpOdd, 0, daggerBit); 00328 00329 // lastly apply the mass term 00330 axmy(in, msq_x4, out, V*mySpinorSiteSize); 00331 break; 00332 } 00333 default: 00334 fprintf(stderr, "ERROR: invalid parity in %s,line %d\n", __FUNCTION__, __LINE__); 00335 break; 00336 } 00337 00338 } 00339 00340 00341 void 00342 matdagmat_milc(void *out, void **fatlink, void** longlink, void *in, double mass, int dagger_bit, 00343 QudaPrecision sPrecision, QudaPrecision gPrecision, void* tmp, MyQudaParity parity) 00344 { 00345 00346 if (sPrecision == QUDA_DOUBLE_PRECISION){ 00347 if (gPrecision == QUDA_DOUBLE_PRECISION) { 00348 Matdagmat_milc((double*)out, (double**)fatlink, (double**)longlink, (double*)in, (double)mass, dagger_bit, (double*)tmp, parity); 
00349 }else { 00350 Matdagmat_milc((double*)out, (float**)fatlink, (float**)longlink, (double*)in, (double)mass, dagger_bit, (double*) tmp, parity); 00351 } 00352 }else{ 00353 if (gPrecision == QUDA_DOUBLE_PRECISION){ 00354 Matdagmat_milc((float*)out, (double**)fatlink, (double**)longlink, (float*)in, (float)mass, dagger_bit, (float*)tmp, parity); 00355 }else { 00356 Matdagmat_milc((float*)out, (float**)fatlink, (float**)longlink, (float*)in, (float)mass, dagger_bit, (float*)tmp, parity); 00357 } 00358 } 00359 } 00360 00361 00362 00363 // Apply the even-odd preconditioned Dirac operator 00364 template <typename sFloat, typename gFloat> 00365 static void MatPC(sFloat *outEven, gFloat **fatlink, gFloat** longlink, sFloat *inEven, sFloat kappa, 00366 int daggerBit, MatPCType matpc_type) { 00367 00368 sFloat *tmp = (sFloat*)malloc(Vh*mySpinorSiteSize*sizeof(sFloat)); 00369 00370 // full dslash operator 00371 if (matpc_type == QUDA_MATPC_EVEN_EVEN) { 00372 dslashReference(tmp, fatlink, longlink, inEven, 1, daggerBit); 00373 dslashReference(outEven, fatlink, longlink, tmp, 0, daggerBit); 00374 00375 //dslashReference(outEven, fatlink, longlink, inEven, 1, daggerBit); 00376 } else { 00377 dslashReference(tmp, fatlink, longlink, inEven, 0, daggerBit); 00378 dslashReference(outEven, fatlink, longlink, tmp, 1, daggerBit); 00379 } 00380 00381 // lastly apply the kappa term 00382 00383 sFloat kappa2 = -kappa*kappa; 00384 xpay(inEven, kappa2, outEven, Vh*mySpinorSiteSize); 00385 00386 free(tmp); 00387 } 00388 00389 00390 void 00391 staggered_matpc(void *outEven, void **fatlink, void**longlink, void *inEven, double kappa, 00392 MatPCType matpc_type, int dagger_bit, QudaPrecision sPrecision, QudaPrecision gPrecision) 00393 { 00394 00395 if (sPrecision == QUDA_DOUBLE_PRECISION) 00396 if (gPrecision == QUDA_DOUBLE_PRECISION) { 00397 MatPC((double*)outEven, (double**)fatlink, (double**)longlink, (double*)inEven, (double)kappa, dagger_bit, matpc_type); 00398 } 00399 else{ 00400 
MatPC((double*)outEven, (double**)fatlink, (double**)longlink, (double*)inEven, (double)kappa, dagger_bit, matpc_type); 00401 } 00402 else { 00403 if (gPrecision == QUDA_DOUBLE_PRECISION){ 00404 MatPC((float*)outEven, (double**)fatlink, (double**)longlink, (float*)inEven, (float)kappa, dagger_bit, matpc_type); 00405 }else{ 00406 MatPC((float*)outEven, (float**)fatlink, (float**)longlink, (float*)inEven, (float)kappa, dagger_bit, matpc_type); 00407 } 00408 } 00409 }
1.7.3