QUDA v0.4.0
A library for QCD on GPUs
|
// tm_dslash_def.h - Twisted Mass Dslash kernel definitions

// There are currently 36 different variants of the Twisted Mass
// Wilson Dslash kernel, each one characterized by a set of 4 options,
// where each option can take one of several values (3*2*2*3 = 36).
// This file is structured so that the C preprocessor loops through all 36
// variants (in a manner resembling a counter), sets the appropriate
// macros, and defines the corresponding functions.
//
// As an example of the function naming conventions, consider
//
// twistedMassDslash12DaggerXpayKernel(float4* out, ...).
//
// This is a twisted mass Dslash^dagger kernel where the result is
// multiplied by "a" and summed with an input vector (Xpay), and the
// gauge matrix is reconstructed from 12 real numbers.  More
// generally, each function name is given by the concatenation of the
// following 4 fields, with "Kernel" at the end:
//
// DD_NAME_F = twistedMassDslash
// DD_RECON_F = 8, 12, 18
// DD_DAG_F = Dagger, [blank]
// DD_XPAY_F = Xpay, [blank]
//
// In addition, a kernel is defined for each precision of the fields
// (double, single, or half).  The precision does not appear in the
// function name: the variants share a name and differ only in their
// parameter types.

// ---------------------------------------------------------------------------
// Body of the variant loop.
//
// This file includes itself (see the bottom of the file) once per variant.
// Each pass:
//   1. translates the loop counters DD_DAG / DD_XPAY / DD_RECON / DD_PREC
//      into the name fields and helper macros used by the kernel body,
//   2. defines one kernel,
//   3. undefines every per-variant macro,
//   4. advances the counters (DD_DAG fastest, DD_PREC slowest).
// ---------------------------------------------------------------------------

// initialize on first iteration

#ifndef DD_LOOP
#define DD_LOOP
#define DD_DAG 0
#define DD_XPAY 0
#define DD_RECON 0
#define DD_PREC 0
#endif

// set options for current iteration

#define DD_NAME_F twistedMassDslash

#if (DD_DAG==0) // no dagger
#define DD_DAG_F
#else           // dagger
#define DD_DAG_F Dagger
#endif

#if (DD_XPAY==0) // no xpay
#define DD_XPAY_F
#else
#define DSLASH_XPAY
#define DD_XPAY_F Xpay
#endif

// DD_PARAM4: trailing kernel parameters (scalars a and b, the vector x and
// its norm array, and the DslashParam struct); the types follow DD_PREC.
#if (DD_PREC == 0)
#define DD_PARAM4 const double a, const double b, const double2 *x, const float *xNorm, const DslashParam param
#elif (DD_PREC == 1)
#define DD_PARAM4 const float a, const float b, const float4 *x, const float *xNorm, const DslashParam param
#else
#define DD_PARAM4 const float a, const float b, const short4 *x, const float *xNorm, const DslashParam param
#endif

// DD_PARAM2 (gauge-field parameters), RECONSTRUCT_GAUGE_MATRIX and
// READ_GAUGE_MATRIX are selected by reconstruction type and precision.
// DIRECT_ACCESS_LINK picks the non-_TEX read macro.

#if (DD_RECON==0) // reconstruct from 8 reals
#define DD_RECON_F 8

#if (DD_PREC==0)
#define DD_PARAM2 const double2 *gauge0, const double2 *gauge1
#define RECONSTRUCT_GAUGE_MATRIX RECONSTRUCT_MATRIX_8_DOUBLE
#ifdef DIRECT_ACCESS_LINK
#define READ_GAUGE_MATRIX READ_GAUGE_MATRIX_8_DOUBLE2
#else
#define READ_GAUGE_MATRIX READ_GAUGE_MATRIX_8_DOUBLE2_TEX
#endif // DIRECT_ACCESS_LINK

#elif (DD_PREC==1)
#define DD_PARAM2 const float4 *gauge0, const float4 *gauge1
#define RECONSTRUCT_GAUGE_MATRIX RECONSTRUCT_MATRIX_8_SINGLE
#ifdef DIRECT_ACCESS_LINK
#define READ_GAUGE_MATRIX READ_GAUGE_MATRIX_8_FLOAT4
#else
#define READ_GAUGE_MATRIX READ_GAUGE_MATRIX_8_FLOAT4_TEX
#endif // DIRECT_ACCESS_LINK

#else
#define DD_PARAM2 const short4 *gauge0, const short4 *gauge1
#define RECONSTRUCT_GAUGE_MATRIX RECONSTRUCT_MATRIX_8_SINGLE
#ifdef DIRECT_ACCESS_LINK
#define READ_GAUGE_MATRIX READ_GAUGE_MATRIX_8_SHORT4
#else
#define READ_GAUGE_MATRIX READ_GAUGE_MATRIX_8_SHORT4_TEX
#endif // DIRECT_ACCESS_LINK
#endif // DD_PREC

#elif (DD_RECON==1) // reconstruct from 12 reals
#define DD_RECON_F 12

#if (DD_PREC==0)
#define RECONSTRUCT_GAUGE_MATRIX RECONSTRUCT_MATRIX_12_DOUBLE
#ifdef DIRECT_ACCESS_LINK
#define READ_GAUGE_MATRIX READ_GAUGE_MATRIX_12_DOUBLE2
#else
#define READ_GAUGE_MATRIX READ_GAUGE_MATRIX_12_DOUBLE2_TEX
#endif // DIRECT_ACCESS_LINK
#define DD_PARAM2 const double2 *gauge0, const double2 *gauge1

#elif (DD_PREC==1)
#define DD_PARAM2 const float4 *gauge0, const float4 *gauge1
#define RECONSTRUCT_GAUGE_MATRIX RECONSTRUCT_MATRIX_12_SINGLE
#ifdef DIRECT_ACCESS_LINK
#define READ_GAUGE_MATRIX READ_GAUGE_MATRIX_12_FLOAT4
#else
#define READ_GAUGE_MATRIX READ_GAUGE_MATRIX_12_FLOAT4_TEX
#endif // DIRECT_ACCESS_LINK

#else
#define DD_PARAM2 const short4 *gauge0, const short4 *gauge1
#define RECONSTRUCT_GAUGE_MATRIX RECONSTRUCT_MATRIX_12_SINGLE
#ifdef DIRECT_ACCESS_LINK
#define READ_GAUGE_MATRIX READ_GAUGE_MATRIX_12_SHORT4
#else
#define READ_GAUGE_MATRIX READ_GAUGE_MATRIX_12_SHORT4_TEX
#endif // DIRECT_ACCESS_LINK
#endif // DD_PREC

#else // no reconstruct, load all components
#define DD_RECON_F 18
#define GAUGE_FLOAT2

#if (DD_PREC==0)
#define RECONSTRUCT_GAUGE_MATRIX RECONSTRUCT_MATRIX_18_DOUBLE
#ifdef DIRECT_ACCESS_LINK
#define READ_GAUGE_MATRIX READ_GAUGE_MATRIX_18_DOUBLE2
#else
#define READ_GAUGE_MATRIX READ_GAUGE_MATRIX_18_DOUBLE2_TEX
#endif // DIRECT_ACCESS_LINK
#define DD_PARAM2 const double2 *gauge0, const double2 *gauge1

#elif (DD_PREC==1)
#define DD_PARAM2 const float4 *gauge0, const float4 *gauge1 // FIXME for direct reading, really float2
#define RECONSTRUCT_GAUGE_MATRIX RECONSTRUCT_MATRIX_18_SINGLE
#ifdef DIRECT_ACCESS_LINK
#define READ_GAUGE_MATRIX READ_GAUGE_MATRIX_18_FLOAT2
#else
#define READ_GAUGE_MATRIX READ_GAUGE_MATRIX_18_FLOAT2_TEX
#endif // DIRECT_ACCESS_LINK

#else
#define DD_PARAM2 const short4 *gauge0, const short4 *gauge1 // FIXME for direct reading, really short2
#define RECONSTRUCT_GAUGE_MATRIX RECONSTRUCT_MATRIX_18_SINGLE
#ifdef DIRECT_ACCESS_LINK
#define READ_GAUGE_MATRIX READ_GAUGE_MATRIX_18_SHORT2
#else
#define READ_GAUGE_MATRIX READ_GAUGE_MATRIX_18_SHORT2_TEX
#endif // DIRECT_ACCESS_LINK
#endif // DD_PREC
#endif // DD_RECON

// Per-precision settings: projector scale, gauge/spinor/intermediate/accum
// sources (direct pointer or texture reference), spinor read/write macros,
// the output/input parameter lists DD_PARAM1/DD_PARAM3, and SPINOR_HOP.

#if (DD_PREC==0) // double-precision fields

#define TPROJSCALE tProjScale

// double-precision gauge field
#if (defined DIRECT_ACCESS_LINK) || (defined FERMI_NO_DBLE_TEX)
#define GAUGE0TEX gauge0
#define GAUGE1TEX gauge1
#else
#define GAUGE0TEX gauge0TexDouble2
#define GAUGE1TEX gauge1TexDouble2
#endif

// Also defined (identically, as empty) in the 18-real branch above; an
// identical redefinition is permitted by the preprocessor.
#define GAUGE_FLOAT2

// double-precision spinor fields
#define DD_PARAM1 double2* out, float *null1
#define DD_PARAM3 const double2* in, const float *null4
#if (defined DIRECT_ACCESS_WILSON_SPINOR) || (defined FERMI_NO_DBLE_TEX)
#define READ_SPINOR READ_SPINOR_DOUBLE
#define READ_SPINOR_UP READ_SPINOR_DOUBLE_UP
#define READ_SPINOR_DOWN READ_SPINOR_DOUBLE_DOWN
#define SPINORTEX in
#else
#define READ_SPINOR READ_SPINOR_DOUBLE_TEX
#define READ_SPINOR_UP READ_SPINOR_DOUBLE_UP_TEX
#define READ_SPINOR_DOWN READ_SPINOR_DOUBLE_DOWN_TEX
#define SPINORTEX spinorTexDouble
#endif
#if (defined DIRECT_ACCESS_WILSON_INTER) || (defined FERMI_NO_DBLE_TEX)
#define READ_INTERMEDIATE_SPINOR READ_SPINOR_DOUBLE
#define INTERTEX out
#else
#define READ_INTERMEDIATE_SPINOR READ_SPINOR_DOUBLE_TEX
#define INTERTEX interTexDouble
#endif
#define WRITE_SPINOR WRITE_SPINOR_DOUBLE2
#define SPINOR_DOUBLE
#if (DD_XPAY==1)
#if (defined DIRECT_ACCESS_WILSON_ACCUM) || (defined FERMI_NO_DBLE_TEX)
#define ACCUMTEX x
#define READ_ACCUM READ_ACCUM_DOUBLE
#else
#define ACCUMTEX accumTexDouble
#define READ_ACCUM READ_ACCUM_DOUBLE_TEX
#endif

#endif

#define SPINOR_HOP 12

#elif (DD_PREC==1) // single-precision fields

#define TPROJSCALE tProjScale_f

// single-precision gauge field
#ifdef DIRECT_ACCESS_LINK
#define GAUGE0TEX gauge0
#define GAUGE1TEX gauge1
#else
#if (DD_RECON_F == 18)
#define GAUGE0TEX gauge0TexSingle2
#define GAUGE1TEX gauge1TexSingle2
#else
#define GAUGE0TEX gauge0TexSingle4
#define GAUGE1TEX gauge1TexSingle4
#endif
#endif


// single-precision spinor fields
#define DD_PARAM1 float4* out, float *null1
#define DD_PARAM3 const float4* in, const float *null4
#ifdef DIRECT_ACCESS_WILSON_SPINOR
#define READ_SPINOR READ_SPINOR_SINGLE
#define READ_SPINOR_UP READ_SPINOR_SINGLE_UP
#define READ_SPINOR_DOWN READ_SPINOR_SINGLE_DOWN
#define SPINORTEX in
#else
#define READ_SPINOR READ_SPINOR_SINGLE_TEX
#define READ_SPINOR_UP READ_SPINOR_SINGLE_UP_TEX
#define READ_SPINOR_DOWN READ_SPINOR_SINGLE_DOWN_TEX
#define SPINORTEX spinorTexSingle
#endif
#ifdef DIRECT_ACCESS_WILSON_INTER
#define READ_INTERMEDIATE_SPINOR READ_SPINOR_SINGLE
#define INTERTEX out
#else
#define READ_INTERMEDIATE_SPINOR READ_SPINOR_SINGLE_TEX
#define INTERTEX interTexSingle
#endif
#define WRITE_SPINOR WRITE_SPINOR_FLOAT4
#if (DD_XPAY==1)
#ifdef DIRECT_ACCESS_WILSON_ACCUM
#define ACCUMTEX x
#define READ_ACCUM READ_ACCUM_SINGLE
#else
#define ACCUMTEX accumTexSingle
#define READ_ACCUM READ_ACCUM_SINGLE_TEX
#endif
#endif

#define SPINOR_HOP 6

#else // half-precision fields

#define TPROJSCALE tProjScale_f

// half-precision gauge field
#ifdef DIRECT_ACCESS_LINK
#define GAUGE0TEX gauge0
#define GAUGE1TEX gauge1
#else
#if (DD_RECON_F == 18)
#define GAUGE0TEX gauge0TexHalf2
#define GAUGE1TEX gauge1TexHalf2
#else
#define GAUGE0TEX gauge0TexHalf4
#define GAUGE1TEX gauge1TexHalf4
#endif
#endif


// half-precision spinor fields
#ifdef DIRECT_ACCESS_WILSON_SPINOR
#define READ_SPINOR READ_SPINOR_HALF
#define READ_SPINOR_UP READ_SPINOR_HALF_UP
#define READ_SPINOR_DOWN READ_SPINOR_HALF_DOWN
#define SPINORTEX in
#else
#define READ_SPINOR READ_SPINOR_HALF_TEX
#define READ_SPINOR_UP READ_SPINOR_HALF_UP_TEX
#define READ_SPINOR_DOWN READ_SPINOR_HALF_DOWN_TEX
#define SPINORTEX spinorTexHalf
#endif
#ifdef DIRECT_ACCESS_WILSON_INTER
#define READ_INTERMEDIATE_SPINOR READ_SPINOR_HALF
#define INTERTEX out
#else
#define READ_INTERMEDIATE_SPINOR READ_SPINOR_HALF_TEX
#define INTERTEX interTexHalf
#endif
// Unlike the double and single cases, the float* parameters are named
// outNorm / inNorm here instead of null1 / null4.
#define DD_PARAM1 short4* out, float *outNorm
#define DD_PARAM3 const short4* in, const float *inNorm
#define WRITE_SPINOR WRITE_SPINOR_SHORT4
#if (DD_XPAY==1)
#ifdef DIRECT_ACCESS_WILSON_ACCUM
#define ACCUMTEX x
#define READ_ACCUM READ_ACCUM_HALF
#else
#define ACCUMTEX accumTexHalf
#define READ_ACCUM READ_ACCUM_HALF_TEX
#endif

#endif

#define SPINOR_HOP 6

#endif

// only build double precision if supported
#if !(__COMPUTE_CAPABILITY__ < 130 && DD_PREC == 0)

// Two-level concatenation so that the DD_*_F arguments are macro-expanded
// before being pasted together with the "Kernel" suffix.
#define DD_CONCAT(n,r,d,x) n ## r ## d ## x ## Kernel
#define DD_FUNC(n,r,d,x) DD_CONCAT(n,r,d,x)

// define the kernel

template <KernelType kernel_type>
__global__ void DD_FUNC(DD_NAME_F, DD_RECON_F, DD_DAG_F, DD_XPAY_F)
     (DD_PARAM1, DD_PARAM2, DD_PARAM3, DD_PARAM4) {

// The kernel body is empty unless GPU_TWISTED_MASS_DIRAC is defined.
#ifdef GPU_TWISTED_MASS_DIRAC

// NOTE: the Fermi branch and the GT200 branch currently include the same
// core files.
#if (__COMPUTE_CAPABILITY__ >= 200 && defined(SHARED_WILSON_DSLASH)) // Fermi optimal code

#if DD_DAG
#include "tm_dslash_dagger_gt200_core.h"
#else
#include "tm_dslash_gt200_core.h"
#endif

#elif (__COMPUTE_CAPABILITY__ >= 120) // GT200 optimal code

#if DD_DAG
#include "tm_dslash_dagger_gt200_core.h"
#else
#include "tm_dslash_gt200_core.h"
#endif

#else  // fall-back is original G80

#if DD_DAG
#include "tm_dslash_dagger_g80_core.h"
#else
#include "tm_dslash_g80_core.h"
#endif

#endif

#endif

}

#endif

// clean up

#undef DD_NAME_F
#undef DD_RECON_F
#undef DD_DAG_F
#undef DD_XPAY_F
#undef DD_PARAM1
#undef DD_PARAM2
#undef DD_PARAM3
#undef DD_PARAM4
#undef DD_CONCAT
#undef DD_FUNC

#undef DSLASH_XPAY
#undef READ_GAUGE_MATRIX
#undef RECONSTRUCT_GAUGE_MATRIX
#undef GAUGE0TEX
#undef GAUGE1TEX
#undef READ_SPINOR
#undef READ_SPINOR_UP
#undef READ_SPINOR_DOWN
#undef SPINORTEX
#undef READ_INTERMEDIATE_SPINOR
#undef INTERTEX
#undef READ_ACCUM
#undef ACCUMTEX
#undef WRITE_SPINOR
#undef GAUGE_FLOAT2
#undef SPINOR_DOUBLE

#undef SPINOR_HOP

#undef TPROJSCALE

// prepare next set of options, or clean up after final iteration
//
// The counters advance like an odometer: DD_DAG toggles every pass; when it
// wraps to 0, DD_XPAY advances; then DD_RECON (0,1,2); then DD_PREC (0,1,2).
// When DD_PREC would wrap, all loop macros are undefined, which ends the loop.

#if (DD_DAG==0)
#undef DD_DAG
#define DD_DAG 1
#else
#undef DD_DAG
#define DD_DAG 0

#if (DD_XPAY==0)
#undef DD_XPAY
#define DD_XPAY 1
#else
#undef DD_XPAY
#define DD_XPAY 0

#if (DD_RECON==0)
#undef DD_RECON
#define DD_RECON 1
#elif (DD_RECON==1)
#undef DD_RECON
#define DD_RECON 2
#else
#undef DD_RECON
#define DD_RECON 0

#if (DD_PREC==0)
#undef DD_PREC
#define DD_PREC 1
#elif (DD_PREC==1)
#undef DD_PREC
#define DD_PREC 2

#else

#undef DD_LOOP
#undef DD_DAG
#undef DD_XPAY
#undef DD_RECON
#undef DD_PREC

#endif // DD_PREC
#endif // DD_RECON
#endif // DD_XPAY
#endif // DD_DAG

// Re-include this file for the next variant while the loop is still active.
#ifdef DD_LOOP
#include "tm_dslash_def.h"
#endif