// QUDA v0.3.2
// A library for QCD on GPUs
00001 #include <stdio.h> 00002 #include <stdlib.h> 00003 #include <math.h> 00004 00005 #include <util_quda.h> 00006 00007 #include <test_util.h> 00008 #include <blas_reference.h> 00009 #include <wilson_dslash_reference.h> 00010 00011 int Z[4]; 00012 int V; 00013 int Vh; 00014 00015 void setDims(int *X) { 00016 V = 1; 00017 for (int d=0; d< 4; d++) { 00018 V *= X[d]; 00019 Z[d] = X[d]; 00020 } 00021 Vh = V/2; 00022 } 00023 00024 template <typename Float> 00025 void sum(Float *dst, Float *a, Float *b, int cnt) { 00026 for (int i = 0; i < cnt; i++) 00027 dst[i] = a[i] + b[i]; 00028 } 00029 00030 // performs the operation y[i] = x[i] + a*y[i] 00031 template <typename Float> 00032 void xpay(Float *x, Float a, Float *y, int len) { 00033 for (int i=0; i<len; i++) y[i] = x[i] + a*y[i]; 00034 } 00035 00036 00037 00038 template <typename Float> 00039 Float *gaugeLink(int i, int dir, int oddBit, Float **gaugeEven, Float **gaugeOdd) { 00040 Float **gaugeField; 00041 int j; 00042 00043 if (dir % 2 == 0) { 00044 j = i; 00045 gaugeField = (oddBit ? gaugeOdd : gaugeEven); 00046 } 00047 else { 00048 switch (dir) { 00049 case 1: j = neighborIndex(i, oddBit, 0, 0, 0, -1); break; 00050 case 3: j = neighborIndex(i, oddBit, 0, 0, -1, 0); break; 00051 case 5: j = neighborIndex(i, oddBit, 0, -1, 0, 0); break; 00052 case 7: j = neighborIndex(i, oddBit, -1, 0, 0, 0); break; 00053 default: j = -1; break; 00054 } 00055 gaugeField = (oddBit ? 
gaugeEven : gaugeOdd); 00056 } 00057 00058 return &gaugeField[dir/2][j*(3*3*2)]; 00059 } 00060 00061 template <typename Float> 00062 Float *spinorNeighbor(int i, int dir, int oddBit, Float *spinorField) { 00063 int j; 00064 switch (dir) { 00065 case 0: j = neighborIndex(i, oddBit, 0, 0, 0, +1); break; 00066 case 1: j = neighborIndex(i, oddBit, 0, 0, 0, -1); break; 00067 case 2: j = neighborIndex(i, oddBit, 0, 0, +1, 0); break; 00068 case 3: j = neighborIndex(i, oddBit, 0, 0, -1, 0); break; 00069 case 4: j = neighborIndex(i, oddBit, 0, +1, 0, 0); break; 00070 case 5: j = neighborIndex(i, oddBit, 0, -1, 0, 0); break; 00071 case 6: j = neighborIndex(i, oddBit, +1, 0, 0, 0); break; 00072 case 7: j = neighborIndex(i, oddBit, -1, 0, 0, 0); break; 00073 default: j = -1; break; 00074 } 00075 00076 return &spinorField[j*(4*3*2)]; 00077 } 00078 00079 template <typename sFloat, typename gFloat> 00080 void dot(sFloat* res, gFloat* a, sFloat* b) { 00081 res[0] = res[1] = 0; 00082 for (int m = 0; m < 3; m++) { 00083 sFloat a_re = a[2*m+0]; 00084 sFloat a_im = a[2*m+1]; 00085 sFloat b_re = b[2*m+0]; 00086 sFloat b_im = b[2*m+1]; 00087 res[0] += a_re * b_re - a_im * b_im; 00088 res[1] += a_re * b_im + a_im * b_re; 00089 } 00090 } 00091 00092 template <typename Float> 00093 void su3Transpose(Float *res, Float *mat) { 00094 for (int m = 0; m < 3; m++) { 00095 for (int n = 0; n < 3; n++) { 00096 res[m*(3*2) + n*(2) + 0] = + mat[n*(3*2) + m*(2) + 0]; 00097 res[m*(3*2) + n*(2) + 1] = - mat[n*(3*2) + m*(2) + 1]; 00098 } 00099 } 00100 } 00101 00102 template <typename sFloat, typename gFloat> 00103 void su3Mul(sFloat *res, gFloat *mat, sFloat *vec) { 00104 for (int n = 0; n < 3; n++) dot(&res[n*(2)], &mat[n*(3*2)], vec); 00105 } 00106 00107 template <typename sFloat, typename gFloat> 00108 void su3Tmul(sFloat *res, gFloat *mat, sFloat *vec) { 00109 gFloat matT[3*3*2]; 00110 su3Transpose(matT, mat); 00111 su3Mul(res, matT, vec); 00112 } 00113 00114 const double projector[8][4][4][2] = { 
00115 { 00116 {{1,0}, {0,0}, {0,0}, {0,-1}}, 00117 {{0,0}, {1,0}, {0,-1}, {0,0}}, 00118 {{0,0}, {0,1}, {1,0}, {0,0}}, 00119 {{0,1}, {0,0}, {0,0}, {1,0}} 00120 }, 00121 { 00122 {{1,0}, {0,0}, {0,0}, {0,1}}, 00123 {{0,0}, {1,0}, {0,1}, {0,0}}, 00124 {{0,0}, {0,-1}, {1,0}, {0,0}}, 00125 {{0,-1}, {0,0}, {0,0}, {1,0}} 00126 }, 00127 { 00128 {{1,0}, {0,0}, {0,0}, {1,0}}, 00129 {{0,0}, {1,0}, {-1,0}, {0,0}}, 00130 {{0,0}, {-1,0}, {1,0}, {0,0}}, 00131 {{1,0}, {0,0}, {0,0}, {1,0}} 00132 }, 00133 { 00134 {{1,0}, {0,0}, {0,0}, {-1,0}}, 00135 {{0,0}, {1,0}, {1,0}, {0,0}}, 00136 {{0,0}, {1,0}, {1,0}, {0,0}}, 00137 {{-1,0}, {0,0}, {0,0}, {1,0}} 00138 }, 00139 { 00140 {{1,0}, {0,0}, {0,-1}, {0,0}}, 00141 {{0,0}, {1,0}, {0,0}, {0,1}}, 00142 {{0,1}, {0,0}, {1,0}, {0,0}}, 00143 {{0,0}, {0,-1}, {0,0}, {1,0}} 00144 }, 00145 { 00146 {{1,0}, {0,0}, {0,1}, {0,0}}, 00147 {{0,0}, {1,0}, {0,0}, {0,-1}}, 00148 {{0,-1}, {0,0}, {1,0}, {0,0}}, 00149 {{0,0}, {0,1}, {0,0}, {1,0}} 00150 }, 00151 { 00152 {{1,0}, {0,0}, {-1,0}, {0,0}}, 00153 {{0,0}, {1,0}, {0,0}, {-1,0}}, 00154 {{-1,0}, {0,0}, {1,0}, {0,0}}, 00155 {{0,0}, {-1,0}, {0,0}, {1,0}} 00156 }, 00157 { 00158 {{1,0}, {0,0}, {1,0}, {0,0}}, 00159 {{0,0}, {1,0}, {0,0}, {1,0}}, 00160 {{1,0}, {0,0}, {1,0}, {0,0}}, 00161 {{0,0}, {1,0}, {0,0}, {1,0}} 00162 } 00163 }; 00164 00165 00166 // todo pass projector 00167 template <typename Float> 00168 void multiplySpinorByDiracProjector(Float *res, int projIdx, Float *spinorIn) { 00169 for (int i=0; i<4*3*2; i++) res[i] = 0.0; 00170 00171 for (int s = 0; s < 4; s++) { 00172 for (int t = 0; t < 4; t++) { 00173 Float projRe = projector[projIdx][s][t][0]; 00174 Float projIm = projector[projIdx][s][t][1]; 00175 00176 for (int m = 0; m < 3; m++) { 00177 Float spinorRe = spinorIn[t*(3*2) + m*(2) + 0]; 00178 Float spinorIm = spinorIn[t*(3*2) + m*(2) + 1]; 00179 res[s*(3*2) + m*(2) + 0] += projRe*spinorRe - projIm*spinorIm; 00180 res[s*(3*2) + m*(2) + 1] += projRe*spinorIm + projIm*spinorRe; 00181 } 00182 } 00183 
} 00184 } 00185 00186 00187 // 00188 // dslashReference() 00189 // 00190 // if oddBit is zero: calculate odd parity spinor elements (using even parity spinor) 00191 // if oddBit is one: calculate even parity spinor elements 00192 // 00193 // if daggerBit is zero: perform ordinary dslash operator 00194 // if daggerBit is one: perform hermitian conjugate of dslash 00195 // 00196 template <typename sFloat, typename gFloat> 00197 void dslashReference(sFloat *res, gFloat **gaugeFull, sFloat *spinorField, int oddBit, int daggerBit) { 00198 for (int i=0; i<Vh*4*3*2; i++) res[i] = 0.0; 00199 00200 gFloat *gaugeEven[4], *gaugeOdd[4]; 00201 for (int dir = 0; dir < 4; dir++) { 00202 gaugeEven[dir] = gaugeFull[dir]; 00203 gaugeOdd[dir] = gaugeFull[dir]+Vh*gaugeSiteSize; 00204 } 00205 00206 for (int i = 0; i < Vh; i++) { 00207 for (int dir = 0; dir < 8; dir++) { 00208 gFloat *gauge = gaugeLink(i, dir, oddBit, gaugeEven, gaugeOdd); 00209 sFloat *spinor = spinorNeighbor(i, dir, oddBit, spinorField); 00210 00211 sFloat projectedSpinor[4*3*2], gaugedSpinor[4*3*2]; 00212 int projIdx = 2*(dir/2)+(dir+daggerBit)%2; 00213 multiplySpinorByDiracProjector(projectedSpinor, projIdx, spinor); 00214 00215 for (int s = 0; s < 4; s++) { 00216 if (dir % 2 == 0) 00217 su3Mul(&gaugedSpinor[s*(3*2)], gauge, &projectedSpinor[s*(3*2)]); 00218 else 00219 su3Tmul(&gaugedSpinor[s*(3*2)], gauge, &projectedSpinor[s*(3*2)]); 00220 } 00221 00222 sum(&res[i*(4*3*2)], &res[i*(4*3*2)], gaugedSpinor, 4*3*2); 00223 } 00224 } 00225 } 00226 00227 void dslash(void *res, void **gaugeFull, void *spinorField, int oddBit, int daggerBit, 00228 QudaPrecision sPrecision, QudaPrecision gPrecision) { 00229 00230 if (sPrecision == QUDA_DOUBLE_PRECISION) 00231 if (gPrecision == QUDA_DOUBLE_PRECISION) 00232 dslashReference((double*)res, (double**)gaugeFull, (double*)spinorField, oddBit, daggerBit); 00233 else 00234 dslashReference((double*)res, (float**)gaugeFull, (double*)spinorField, oddBit, daggerBit); 00235 else 00236 if 
(gPrecision == QUDA_DOUBLE_PRECISION) 00237 dslashReference((float*)res, (double**)gaugeFull, (float*)spinorField, oddBit, daggerBit); 00238 else 00239 dslashReference((float*)res, (float**)gaugeFull, (float*)spinorField, oddBit, daggerBit); 00240 00241 } 00242 00243 template <typename sFloat, typename gFloat> 00244 void Mat(sFloat *out, gFloat **gauge, sFloat *in, sFloat kappa, int daggerBit) { 00245 sFloat *inEven = in; 00246 sFloat *inOdd = in + Vh*spinorSiteSize; 00247 sFloat *outEven = out; 00248 sFloat *outOdd = out + Vh*spinorSiteSize; 00249 00250 // full dslash operator 00251 dslashReference(outOdd, gauge, inEven, 1, daggerBit); 00252 dslashReference(outEven, gauge, inOdd, 0, daggerBit); 00253 00254 // lastly apply the kappa term 00255 xpay(in, -kappa, out, V*spinorSiteSize); 00256 } 00257 00258 void mat(void *out, void **gauge, void *in, double kappa, int dagger_bit, 00259 QudaPrecision sPrecision, QudaPrecision gPrecision) { 00260 00261 if (sPrecision == QUDA_DOUBLE_PRECISION) 00262 if (gPrecision == QUDA_DOUBLE_PRECISION) 00263 Mat((double*)out, (double**)gauge, (double*)in, (double)kappa, dagger_bit); 00264 else 00265 Mat((double*)out, (float**)gauge, (double*)in, (double)kappa, dagger_bit); 00266 else 00267 if (gPrecision == QUDA_DOUBLE_PRECISION) 00268 Mat((float*)out, (double**)gauge, (float*)in, (float)kappa, dagger_bit); 00269 else 00270 Mat((float*)out, (float**)gauge, (float*)in, (float)kappa, dagger_bit); 00271 } 00272 00273 // Apply the even-odd preconditioned Dirac operator 00274 template <typename sFloat, typename gFloat> 00275 void MatPC(sFloat *outEven, gFloat **gauge, sFloat *inEven, sFloat kappa, 00276 int daggerBit, QudaMatPCType matpc_type) { 00277 00278 sFloat *tmp = (sFloat*)malloc(Vh*spinorSiteSize*sizeof(sFloat)); 00279 00280 // full dslash operator 00281 if (matpc_type == QUDA_MATPC_EVEN_EVEN) { 00282 dslashReference(tmp, gauge, inEven, 1, daggerBit); 00283 dslashReference(outEven, gauge, tmp, 0, daggerBit); 00284 } else { 00285 
dslashReference(tmp, gauge, inEven, 0, daggerBit); 00286 dslashReference(outEven, gauge, tmp, 1, daggerBit); 00287 } 00288 00289 // lastly apply the kappa term 00290 sFloat kappa2 = -kappa*kappa; 00291 xpay(inEven, kappa2, outEven, Vh*spinorSiteSize); 00292 free(tmp); 00293 } 00294 00295 void matpc(void *outEven, void **gauge, void *inEven, double kappa, 00296 QudaMatPCType matpc_type, int dagger_bit, QudaPrecision sPrecision, QudaPrecision gPrecision) { 00297 00298 if (sPrecision == QUDA_DOUBLE_PRECISION) 00299 if (gPrecision == QUDA_DOUBLE_PRECISION) 00300 MatPC((double*)outEven, (double**)gauge, (double*)inEven, (double)kappa, dagger_bit, matpc_type); 00301 else 00302 MatPC((double*)outEven, (float**)gauge, (double*)inEven, (double)kappa, dagger_bit, matpc_type); 00303 else 00304 if (gPrecision == QUDA_DOUBLE_PRECISION) 00305 MatPC((float*)outEven, (double**)gauge, (float*)inEven, (float)kappa, dagger_bit, matpc_type); 00306 else 00307 MatPC((float*)outEven, (float**)gauge, (float*)inEven, (float)kappa, dagger_bit, matpc_type); 00308 }
// Listing footer: 1.7.3 (presumably the version of the documentation generator - TODO confirm)