QUDA v0.4.0
A library for QCD on GPUs
|
// dw_dslash_def.h - Domain Wall Dslash kernel definitions

// There are currently 36 different variants of the Domain Wall Dslash
// kernel, each one characterized by a set of 4 options, where each
// option can take one of several values: gauge reconstruction (3),
// dagger (2), xpay (2) and field precision (3), i.e. 3*2*2*3 = 36.
// This file is structured so that the C preprocessor loops through
// all 36 variants (in a manner resembling a counter), sets the
// appropriate macros, and defines the corresponding functions.
// (Double-precision variants are skipped when compiling for a
// __COMPUTE_CAPABILITY__ below 130.)
//
// As an example of the function naming conventions, consider
//
// domainWallDslash12DaggerXpayKernel(float4* out, ...).
//
// This is a dw Dslash^dagger kernel where the result is
// multiplied by "a" and summed with an input vector (Xpay), and the
// gauge matrix is reconstructed from 12 real numbers.  More
// generally, each function name is given by the concatenation of the
// following 4 fields, with "Kernel" at the end:
//
// DD_NAME_F = domainWallDslash
// DD_RECON_F = 8, 12, 18
// DD_DAG_F = Dagger, [blank]
// DD_XPAY_F = Xpay, [blank]
//
// The field precision (double, single, or half) is not part of the
// name: each name is defined once per precision, and the three
// definitions are overloads distinguished by their pointer parameter
// types (double2, float4, short4).  Every kernel is additionally a
// template on KernelType.
// Preprocessor "loop body": each inclusion of this file defines (at most)
// one kernel variant, undefines its per-variant macros, advances a
// four-digit counter held in macros, and re-includes itself until the
// counter runs out.  Counter digits, least significant first:
//
//   DD_DAG   : 0 = no dagger, 1 = dagger (DD_DAG_F = Dagger)
//   DD_XPAY  : 0 = no xpay,   1 = xpay   (defines DSLASH_XPAY, DD_XPAY_F = Xpay)
//   DD_RECON : 0 / 1 / 2 = gauge matrix from 8 / 12 / 18 reals
//   DD_PREC  : 0 / 1 / 2 = double / single / half-precision fields
//
// After the last variant (DD_PREC == 2 with every other digit wrapping),
// DD_LOOP and all counter macros are undefined, which ends the recursion.

// initialize on first iteration

#ifndef DD_LOOP
#define DD_LOOP
#define DD_DAG 0
#define DD_XPAY 0
#define DD_RECON 0
#define DD_PREC 0
#endif

// set options for current iteration

#define DD_NAME_F domainWallDslash

#if (DD_DAG==0) // no dagger
#define DD_DAG_F
#else // dagger
#define DD_DAG_F Dagger
#endif

#if (DD_XPAY==0) // no xpay
#define DD_XPAY_F
#else // xpay
#define DSLASH_XPAY
#define DD_XPAY_F Xpay
#endif

// trailing kernel parameters; scalar types follow the field precision
// (half precision uses float scalars and short4 storage)
#if (DD_PREC == 0)
#define DD_PARAM4 const double mferm, const double2 *x, const float *xNorm, const double a, const DslashParam param
#elif (DD_PREC == 1)
#define DD_PARAM4 const float mferm, const float4 *x, const float *xNorm, const float a, const DslashParam param
#else
#define DD_PARAM4 const float mferm, const short4 *x, const float *xNorm, const float a, const DslashParam param
#endif

// gauge-link parameters and link read/reconstruct macros,
// selected by (DD_RECON, DD_PREC); half precision (DD_PREC==2) reuses
// the *_SINGLE reconstruct macros with short4/short2 reads
#if (DD_RECON==0) // reconstruct from 8 reals
#define DD_RECON_F 8

#if (DD_PREC==0)
#define DD_PARAM2 const double2 *gauge0, const double2 *gauge1
#define RECONSTRUCT_GAUGE_MATRIX RECONSTRUCT_MATRIX_8_DOUBLE
#ifdef DIRECT_ACCESS_LINK
#define READ_GAUGE_MATRIX READ_GAUGE_MATRIX_8_DOUBLE2
#else
#define READ_GAUGE_MATRIX READ_GAUGE_MATRIX_8_DOUBLE2_TEX
#endif // DIRECT_ACCESS_LINK

#elif (DD_PREC==1)
#define DD_PARAM2 const float4 *gauge0, const float4 *gauge1
#define RECONSTRUCT_GAUGE_MATRIX RECONSTRUCT_MATRIX_8_SINGLE
#ifdef DIRECT_ACCESS_LINK
#define READ_GAUGE_MATRIX READ_GAUGE_MATRIX_8_FLOAT4
#else
#define READ_GAUGE_MATRIX READ_GAUGE_MATRIX_8_FLOAT4_TEX
#endif // DIRECT_ACCESS_LINK

#else
#define DD_PARAM2 const short4 *gauge0, const short4 *gauge1
#define RECONSTRUCT_GAUGE_MATRIX RECONSTRUCT_MATRIX_8_SINGLE
#ifdef DIRECT_ACCESS_LINK
#define READ_GAUGE_MATRIX READ_GAUGE_MATRIX_8_SHORT4
#else
#define READ_GAUGE_MATRIX READ_GAUGE_MATRIX_8_SHORT4_TEX
#endif // DIRECT_ACCESS_LINK
#endif // DD_PREC

#elif (DD_RECON==1) // reconstruct from 12 reals
#define DD_RECON_F 12

#if (DD_PREC==0)
#define RECONSTRUCT_GAUGE_MATRIX RECONSTRUCT_MATRIX_12_DOUBLE
#ifdef DIRECT_ACCESS_LINK
#define READ_GAUGE_MATRIX READ_GAUGE_MATRIX_12_DOUBLE2
#else
#define READ_GAUGE_MATRIX READ_GAUGE_MATRIX_12_DOUBLE2_TEX
#endif // DIRECT_ACCESS_LINK
#define DD_PARAM2 const double2 *gauge0, const double2 *gauge1

#elif (DD_PREC==1)
#define DD_PARAM2 const float4 *gauge0, const float4 *gauge1
#define RECONSTRUCT_GAUGE_MATRIX RECONSTRUCT_MATRIX_12_SINGLE
#ifdef DIRECT_ACCESS_LINK
#define READ_GAUGE_MATRIX READ_GAUGE_MATRIX_12_FLOAT4
#else
#define READ_GAUGE_MATRIX READ_GAUGE_MATRIX_12_FLOAT4_TEX
#endif // DIRECT_ACCESS_LINK

#else
#define DD_PARAM2 const short4 *gauge0, const short4 *gauge1
#define RECONSTRUCT_GAUGE_MATRIX RECONSTRUCT_MATRIX_12_SINGLE
#ifdef DIRECT_ACCESS_LINK
#define READ_GAUGE_MATRIX READ_GAUGE_MATRIX_12_SHORT4
#else
#define READ_GAUGE_MATRIX READ_GAUGE_MATRIX_12_SHORT4_TEX
#endif // DIRECT_ACCESS_LINK
#endif // DD_PREC

#else // no reconstruct, load all components
#define DD_RECON_F 18
#define GAUGE_FLOAT2
#if (DD_PREC==0)
#define RECONSTRUCT_GAUGE_MATRIX RECONSTRUCT_MATRIX_18_DOUBLE
#ifdef DIRECT_ACCESS_LINK
#define READ_GAUGE_MATRIX READ_GAUGE_MATRIX_18_DOUBLE2
#else
#define READ_GAUGE_MATRIX READ_GAUGE_MATRIX_18_DOUBLE2_TEX
#endif // DIRECT_ACCESS_LINK
#define DD_PARAM2 const double2 *gauge0, const double2 *gauge1

#elif (DD_PREC==1)
#define DD_PARAM2 const float4 *gauge0, const float4 *gauge1 // FIXME for direct reading, really float2
#define RECONSTRUCT_GAUGE_MATRIX RECONSTRUCT_MATRIX_18_SINGLE
#ifdef DIRECT_ACCESS_LINK
#define READ_GAUGE_MATRIX READ_GAUGE_MATRIX_18_FLOAT2
#else
#define READ_GAUGE_MATRIX READ_GAUGE_MATRIX_18_FLOAT2_TEX
#endif // DIRECT_ACCESS_LINK

#else
#define DD_PARAM2 const short4 *gauge0, const short4 *gauge1 // FIXME for direct reading, really short2
#define RECONSTRUCT_GAUGE_MATRIX RECONSTRUCT_MATRIX_18_SINGLE
#ifdef DIRECT_ACCESS_LINK
#define READ_GAUGE_MATRIX READ_GAUGE_MATRIX_18_SHORT2
#else
#define READ_GAUGE_MATRIX READ_GAUGE_MATRIX_18_SHORT2_TEX
#endif //DIRECT_ACCESS_LINK
#endif // DD_PREC
#endif // DD_RECON

#if (DD_PREC==0) // double-precision fields

// double-precision gauge field; FERMI_NO_DBLE_TEX also forces direct
// pointer access instead of the double2 textures
#if (defined DIRECT_ACCESS_LINK) || (defined FERMI_NO_DBLE_TEX)
#define GAUGE0TEX gauge0
#define GAUGE1TEX gauge1
#else
#define GAUGE0TEX gauge0TexDouble2
#define GAUGE1TEX gauge1TexDouble2
#endif

// GAUGE_FLOAT2 is set for every double-precision variant; for 18-real
// links it is already defined above, so only define it if still missing
#ifndef GAUGE_FLOAT2
#define GAUGE_FLOAT2
#endif

// double-precision spinor fields; the float* slots (null1/null4) only
// carry norms in half precision -- presumably kept so every precision
// shares the same argument layout (TODO confirm with host-side callers)
#define DD_PARAM1 double2* out, float *null1
#define DD_PARAM3 const double2* in, const float *null4
#if (defined DIRECT_ACCESS_WILSON_SPINOR) || (defined FERMI_NO_DBLE_TEX)
#define READ_SPINOR READ_SPINOR_DOUBLE
#define READ_SPINOR_UP READ_SPINOR_DOUBLE_UP
#define READ_SPINOR_DOWN READ_SPINOR_DOUBLE_DOWN
#define SPINORTEX in
#else
#define READ_SPINOR READ_SPINOR_DOUBLE_TEX
#define READ_SPINOR_UP READ_SPINOR_DOUBLE_UP_TEX
#define READ_SPINOR_DOWN READ_SPINOR_DOUBLE_DOWN_TEX
#define SPINORTEX spinorTexDouble
#endif
#define WRITE_SPINOR WRITE_SPINOR_DOUBLE2
#define SPINOR_DOUBLE
#if (DD_XPAY==1)
#if (defined DIRECT_ACCESS_WILSON_ACCUM) || (defined FERMI_NO_DBLE_TEX)
#define ACCUMTEX x
#define READ_ACCUM READ_ACCUM_DOUBLE
#else
#define ACCUMTEX accumTexDouble
#define READ_ACCUM READ_ACCUM_DOUBLE_TEX
#endif

#endif

#elif (DD_PREC==1) // single-precision fields

// single-precision gauge field (float2 textures for 18 reals, float4 otherwise)
#ifdef DIRECT_ACCESS_LINK
#define GAUGE0TEX gauge0
#define GAUGE1TEX gauge1
#else
#if (DD_RECON_F == 18)
#define GAUGE0TEX gauge0TexSingle2
#define GAUGE1TEX gauge1TexSingle2
#else
#define GAUGE0TEX gauge0TexSingle4
#define GAUGE1TEX gauge1TexSingle4
#endif
#endif

// single-precision spinor fields
#define DD_PARAM1 float4* out, float *null1
#define DD_PARAM3 const float4* in, const float *null4
#ifdef DIRECT_ACCESS_WILSON_SPINOR
#define READ_SPINOR READ_SPINOR_SINGLE
#define READ_SPINOR_UP READ_SPINOR_SINGLE_UP
#define READ_SPINOR_DOWN READ_SPINOR_SINGLE_DOWN
#define SPINORTEX in
#else
#define READ_SPINOR READ_SPINOR_SINGLE_TEX
#define READ_SPINOR_UP READ_SPINOR_SINGLE_UP_TEX
#define READ_SPINOR_DOWN READ_SPINOR_SINGLE_DOWN_TEX
#define SPINORTEX spinorTexSingle
#endif
#define WRITE_SPINOR WRITE_SPINOR_FLOAT4
#if (DD_XPAY==1)
#ifdef DIRECT_ACCESS_WILSON_ACCUM
#define ACCUMTEX x
#define READ_ACCUM READ_ACCUM_SINGLE
#else
#define ACCUMTEX accumTexSingle
#define READ_ACCUM READ_ACCUM_SINGLE_TEX
#endif

#endif

#else // half-precision fields

// half-precision gauge field (short2 textures for 18 reals, short4 otherwise)
#ifdef DIRECT_ACCESS_LINK
#define GAUGE0TEX gauge0
#define GAUGE1TEX gauge1
#else
#if (DD_RECON_F == 18)
#define GAUGE0TEX gauge0TexHalf2
#define GAUGE1TEX gauge1TexHalf2
#else
#define GAUGE0TEX gauge0TexHalf4
#define GAUGE1TEX gauge1TexHalf4
#endif
#endif

// half-precision spinor fields (short4 data plus float norm arrays)
#ifdef DIRECT_ACCESS_WILSON_SPINOR
#define READ_SPINOR READ_SPINOR_HALF
#define READ_SPINOR_UP READ_SPINOR_HALF_UP
#define READ_SPINOR_DOWN READ_SPINOR_HALF_DOWN
#define SPINORTEX in
#else
#define READ_SPINOR READ_SPINOR_HALF_TEX
#define READ_SPINOR_UP READ_SPINOR_HALF_UP_TEX
#define READ_SPINOR_DOWN READ_SPINOR_HALF_DOWN_TEX
#define SPINORTEX spinorTexHalf
#endif
#define DD_PARAM1 short4* out, float *outNorm
#define DD_PARAM3 const short4* in, const float *inNorm
#define WRITE_SPINOR WRITE_SPINOR_SHORT4
#if (DD_XPAY==1)
#ifdef DIRECT_ACCESS_WILSON_ACCUM
#define ACCUMTEX x
#define READ_ACCUM READ_ACCUM_HALF
#else
#define ACCUMTEX accumTexHalf
#define READ_ACCUM READ_ACCUM_HALF_TEX
#endif

#endif

#endif

// only build double precision if supported
#if !(__COMPUTE_CAPABILITY__ < 130 && DD_PREC == 0)

// Two-level paste: DD_FUNC forces its arguments (DD_NAME_F, DD_RECON_F,
// ...) to be macro-expanded before DD_CONCAT applies ##, which would
// otherwise paste the unexpanded macro names.
#define DD_CONCAT(n,r,d,x) n ## r ## d ## x ## Kernel
#define DD_FUNC(n,r,d,x) DD_CONCAT(n,r,d,x)

// define the kernel; its body comes from the dagger / non-dagger core
// headers, and is empty unless GPU_DOMAIN_WALL_DIRAC is defined

template <KernelType kernel_type>
__global__ void DD_FUNC(DD_NAME_F, DD_RECON_F, DD_DAG_F, DD_XPAY_F)
  (DD_PARAM1, DD_PARAM2, DD_PARAM3, DD_PARAM4) {

#ifdef GPU_DOMAIN_WALL_DIRAC
#if DD_DAG
#include "dw_dslash_dagger_core.h"
#else
#include "dw_dslash_core.h"
#endif
#endif

}

#endif

// clean up

#undef DD_NAME_F
#undef DD_RECON_F
#undef DD_DAG_F
#undef DD_XPAY_F
#undef DD_PARAM1
#undef DD_PARAM2
#undef DD_PARAM3
#undef DD_PARAM4
#undef DD_CONCAT
#undef DD_FUNC

#undef DSLASH_XPAY
#undef READ_GAUGE_MATRIX
#undef RECONSTRUCT_GAUGE_MATRIX
#undef GAUGE0TEX
#undef GAUGE1TEX
#undef READ_SPINOR
#undef READ_SPINOR_UP
#undef READ_SPINOR_DOWN
#undef SPINORTEX
#undef WRITE_SPINOR
#undef ACCUMTEX
#undef READ_ACCUM
#undef GAUGE_FLOAT2
#undef SPINOR_DOUBLE

// prepare next set of options, or clean up after final iteration:
// DD_DAG is the least-significant digit; each digit that wraps back to
// 0 carries into the next (DD_XPAY, then DD_RECON, then DD_PREC)

#if (DD_DAG==0)
#undef DD_DAG
#define DD_DAG 1
#else
#undef DD_DAG
#define DD_DAG 0

#if (DD_XPAY==0)
#undef DD_XPAY
#define DD_XPAY 1
#else
#undef DD_XPAY
#define DD_XPAY 0

#if (DD_RECON==0)
#undef DD_RECON
#define DD_RECON 1
#elif (DD_RECON==1)
#undef DD_RECON
#define DD_RECON 2
#else
#undef DD_RECON
#define DD_RECON 0

#if (DD_PREC==0)
#undef DD_PREC
#define DD_PREC 1
#elif (DD_PREC==1)
#undef DD_PREC
#define DD_PREC 2

#else

// last variant emitted: drop the loop state so the include below stops
#undef DD_LOOP
#undef DD_DAG
#undef DD_XPAY
#undef DD_RECON
#undef DD_PREC

#endif // DD_PREC
#endif // DD_RECON
#endif // DD_XPAY
#endif // DD_DAG

// recurse to emit the next variant
#ifdef DD_LOOP
#include "dw_dslash_def.h"
#endif