|
QUDA v0.3.2
A library for QCD on GPUs
|
// tm_dslash_def.h - Twisted Mass Dslash kernel definitions
//
// There are currently 36 different variants of the Twisted Mass
// Wilson Dslash kernel, each one characterized by a set of 4 options
// (gauge reconstruction, dagger, xpay, and precision), where each option
// can take one of several values (3*2*2*3 = 36).
// This file is structured so that the C preprocessor loops through all 36
// variants (in a manner resembling a counter), sets the appropriate
// macros, and defines the corresponding functions.
//
// As an example of the function naming conventions, consider
//
//   twistedMassDslash12DaggerXpayKernel(float4* out, ...).
//
// This is a twisted mass Dslash^dagger kernel where the result is
// multiplied by "a" and summed with an input vector (Xpay), and the
// gauge matrix is reconstructed from 12 real numbers. More
// generally, each function name is given by the concatenation of the
// following 4 fields, with "Kernel" at the end:
//
//   DD_NAME_F  = twistedMassDslash
//   DD_RECON_F = 8, 12, 18
//   DD_DAG_F   = Dagger, [blank]
//   DD_XPAY_F  = Xpay, [blank]
//
// In addition, every variant is generated once per field precision
// (double, single, or half). The precision is not part of the kernel
// name; the precision variants are distinguished only by their parameter
// types, i.e. they are overloads of the same function name.
// ===========================================================================
// Odometer state.
//
// On the first inclusion DD_LOOP is not yet defined, so every option digit
// starts at 0. At the bottom of this file the digits are advanced and, while
// DD_LOOP is still defined, the file #includes itself again. Digits:
//   DD_DAG   : 0 = plain,   1 = dagger                 (fastest-changing)
//   DD_XPAY  : 0 = plain,   1 = xpay
//   DD_RECON : 0 = 8 reals, 1 = 12 reals, 2 = 18 reals
//   DD_PREC  : 0 = double,  1 = single,   2 = half     (slowest-changing)
// ===========================================================================
#ifndef DD_LOOP
#  define DD_LOOP
#  define DD_DAG 0
#  define DD_XPAY 0
#  define DD_RECON 0
#  define DD_PREC 0
#endif

// ===========================================================================
// Kernel-name fragments (pasted together by DD_FUNC further down).
// ===========================================================================
#define DD_NAME_F twistedMassDslash

#if (DD_DAG != 0)
#  define DD_DAG_F Dagger
#else
#  define DD_DAG_F
#endif

#if (DD_XPAY != 0)
#  define DSLASH_XPAY
#  define DD_XPAY_F Xpay
#else
#  define DD_XPAY_F
#endif

#if (DD_RECON == 0)
#  define DD_RECON_F 8
#elif (DD_RECON == 1)
#  define DD_RECON_F 12
#else
#  define DD_RECON_F 18
   /* 18-real links select the 2-component gauge textures */
#  define GAUGE_FLOAT2
#endif

// ===========================================================================
// Precision-dependent signatures, textures and access macros.
// ===========================================================================
#if (DD_PREC == 0)
   /* ----------------------------- double ----------------------------- */

   /* spinor out/in; the float* slots are unused here (presumably kept so
      that every precision shares the same parameter layout) */
#  define DD_PARAM1 double2 *out, float *null1
#  define DD_PARAM3 const double2 *in, const float *null4

   /* gauge links: always double2, whatever the reconstruction */
#  define DD_PARAM2 const double2 *gauge0, const double2 *gauge1
#  define GAUGE0TEX gauge0TexDouble2
#  define GAUGE1TEX gauge1TexDouble2
#  ifndef GAUGE_FLOAT2
#    define GAUGE_FLOAT2
#  endif
#  if (DD_RECON == 0)
#    define RECONSTRUCT_GAUGE_MATRIX RECONSTRUCT_MATRIX_8_DOUBLE
#    define READ_GAUGE_MATRIX READ_GAUGE_MATRIX_8_DOUBLE
#  elif (DD_RECON == 1)
#    define RECONSTRUCT_GAUGE_MATRIX RECONSTRUCT_MATRIX_12_DOUBLE
#    define READ_GAUGE_MATRIX READ_GAUGE_MATRIX_12_DOUBLE
#  else
#    define RECONSTRUCT_GAUGE_MATRIX RECONSTRUCT_MATRIX_18_DOUBLE
#    define READ_GAUGE_MATRIX READ_GAUGE_MATRIX_18_DOUBLE
#  endif

   /* spinor access */
#  define SPINOR_DOUBLE
#  define SPINORTEX spinorTexDouble
#  define READ_SPINOR READ_SPINOR_DOUBLE
#  define READ_SPINOR_UP READ_SPINOR_DOUBLE_UP
#  define READ_SPINOR_DOWN READ_SPINOR_DOUBLE_DOWN
#  define WRITE_SPINOR WRITE_SPINOR_DOUBLE2

   /* scalar coefficients, plus the accumulator input for xpay */
#  if (DD_XPAY == 0)
#    define DD_PARAM4 const int oddBit, const double a, const double b
#  else
#    define DD_PARAM4 const int oddBit, const double a, const double b, const double2 *x, const float *xNorm
#    if (DD_XPAY == 1)
#      define ACCUMTEX accumTexDouble
#      define READ_ACCUM READ_ACCUM_DOUBLE
#    endif
#  endif

#elif (DD_PREC == 1)
   /* ----------------------------- single ----------------------------- */

#  define DD_PARAM1 float4 *out, float *null1
#  define DD_PARAM3 const float4 *in, const float *null4

   /* gauge links; FIXME (carried over): for 18-real links the data read
      directly is really float2, but the signature still says float4 */
#  define DD_PARAM2 const float4 *gauge0, const float4 *gauge1
#  if (DD_RECON_F == 18)
#    define GAUGE0TEX gauge0TexSingle2
#    define GAUGE1TEX gauge1TexSingle2
#  else
#    define GAUGE0TEX gauge0TexSingle4
#    define GAUGE1TEX gauge1TexSingle4
#  endif
#  if (DD_RECON == 0)
#    define RECONSTRUCT_GAUGE_MATRIX RECONSTRUCT_MATRIX_8_SINGLE
#    define READ_GAUGE_MATRIX READ_GAUGE_MATRIX_8_SINGLE
#  elif (DD_RECON == 1)
#    define RECONSTRUCT_GAUGE_MATRIX RECONSTRUCT_MATRIX_12_SINGLE
#    define READ_GAUGE_MATRIX READ_GAUGE_MATRIX_12_SINGLE
#  else
#    define RECONSTRUCT_GAUGE_MATRIX RECONSTRUCT_MATRIX_18_SINGLE
#    define READ_GAUGE_MATRIX READ_GAUGE_MATRIX_18_SINGLE
#  endif

#  define SPINORTEX spinorTexSingle
#  define READ_SPINOR READ_SPINOR_SINGLE
#  define READ_SPINOR_UP READ_SPINOR_SINGLE_UP
#  define READ_SPINOR_DOWN READ_SPINOR_SINGLE_DOWN
#  define WRITE_SPINOR WRITE_SPINOR_FLOAT4

#  if (DD_XPAY == 0)
#    define DD_PARAM4 const int oddBit, const float a, const float b
#  else
#    define DD_PARAM4 const int oddBit, const float a, const float b, const float4 *x, const float *xNorm
#    if (DD_XPAY == 1)
#      define ACCUMTEX accumTexSingle
#      define READ_ACCUM READ_ACCUM_SINGLE
#    endif
#  endif

#else
   /* ------------------------------ half ------------------------------ */

   /* half-precision spinors carry a separate float norm array */
#  define DD_PARAM1 short4 *out, float *outNorm
#  define DD_PARAM3 const short4 *in, const float *inNorm

   /* gauge links; FIXME (carried over): for 18-real links the data read
      directly is really short2, but the signature still says short4 */
#  define DD_PARAM2 const short4 *gauge0, const short4 *gauge1
#  if (DD_RECON_F == 18)
#    define GAUGE0TEX gauge0TexHalf2
#    define GAUGE1TEX gauge1TexHalf2
#  else
#    define GAUGE0TEX gauge0TexHalf4
#    define GAUGE1TEX gauge1TexHalf4
#  endif
   /* Only the 8-real read has a dedicated half macro; every reconstruction
      and the 12/18-real reads reuse the single-precision macros.
      NOTE(review): presumably the half textures return normalized floats,
      making the single macros applicable; confirm in the core headers. */
#  if (DD_RECON == 0)
#    define RECONSTRUCT_GAUGE_MATRIX RECONSTRUCT_MATRIX_8_SINGLE
#    define READ_GAUGE_MATRIX READ_GAUGE_MATRIX_8_HALF
#  elif (DD_RECON == 1)
#    define RECONSTRUCT_GAUGE_MATRIX RECONSTRUCT_MATRIX_12_SINGLE
#    define READ_GAUGE_MATRIX READ_GAUGE_MATRIX_12_SINGLE
#  else
#    define RECONSTRUCT_GAUGE_MATRIX RECONSTRUCT_MATRIX_18_SINGLE
#    define READ_GAUGE_MATRIX READ_GAUGE_MATRIX_18_SINGLE
#  endif

#  define SPINORTEX spinorTexHalf
#  define READ_SPINOR READ_SPINOR_HALF
#  define READ_SPINOR_UP READ_SPINOR_HALF_UP
#  define READ_SPINOR_DOWN READ_SPINOR_HALF_DOWN
#  define WRITE_SPINOR WRITE_SPINOR_SHORT4

#  if (DD_XPAY == 0)
#    define DD_PARAM4 const int oddBit, const float a, const float b
#  else
#    define DD_PARAM4 const int oddBit, const float a, const float b, const short4 *x, const float *xNorm
#    if (DD_XPAY == 1)
#      define ACCUMTEX accumTexHalf
#      define READ_ACCUM READ_ACCUM_HALF
#    endif
#  endif

#endif

// ===========================================================================
// Kernel definition.
//
// Double-precision variants are emitted only when __CUDA_ARCH__ >= 130.
// NOTE(review): an undefined identifier in #if evaluates to 0, so in a
// compilation pass where __CUDA_ARCH__ is not defined (e.g. the host pass)
// the double-precision kernels are omitted too; confirm that the launch
// code is guarded the same way.
// ===========================================================================
#if (__CUDA_ARCH__ >= 130) || (DD_PREC != 0)

/* two-level paste so that the option macros expand before ## is applied;
   empty DD_DAG_F / DD_XPAY_F simply drop out of the name */
#  define DD_CONCAT(n, r, d, x) n ## r ## d ## x ## Kernel
#  define DD_FUNC(n, r, d, x) DD_CONCAT(n, r, d, x)

__global__ void DD_FUNC(DD_NAME_F, DD_RECON_F, DD_DAG_F, DD_XPAY_F)
    (DD_PARAM1, DD_PARAM2, DD_PARAM3, DD_PARAM4)
{
#  ifdef GPU_TWISTED_MASS_DIRAC
#    if (DD_DAG != 0)
#      include "tm_dslash_dagger_core.h"
#    else
#      include "tm_dslash_core.h"
#    endif
#  endif
}

#endif

// ===========================================================================
// Drop every per-variant macro so the next pass can define it afresh.
// ===========================================================================
/* name fragments and signature pieces */
#undef DD_NAME_F
#undef DD_RECON_F
#undef DD_DAG_F
#undef DD_XPAY_F
#undef DD_PARAM1
#undef DD_PARAM2
#undef DD_PARAM3
#undef DD_PARAM4
#undef DD_CONCAT
#undef DD_FUNC

/* macros consumed by the core headers */
#undef DSLASH_XPAY
#undef READ_GAUGE_MATRIX
#undef RECONSTRUCT_GAUGE_MATRIX
#undef GAUGE0TEX
#undef GAUGE1TEX
#undef READ_SPINOR
#undef READ_SPINOR_UP
#undef READ_SPINOR_DOWN
#undef SPINORTEX
#undef WRITE_SPINOR
#undef ACCUMTEX
#undef READ_ACCUM
#undef GAUGE_FLOAT2
#undef SPINOR_DOUBLE

// ===========================================================================
// Advance the odometer (DD_DAG fastest, DD_PREC slowest). When the last
// variant has been emitted, DD_LOOP and the digits are undefined, which
// ends the self-inclusion below.
// ===========================================================================
#if (DD_DAG == 0)
#  undef DD_DAG
#  define DD_DAG 1
#else
#  undef DD_DAG
#  define DD_DAG 0
#  if (DD_XPAY == 0)
#    undef DD_XPAY
#    define DD_XPAY 1
#  else
#    undef DD_XPAY
#    define DD_XPAY 0
#    if (DD_RECON == 0)
#      undef DD_RECON
#      define DD_RECON 1
#    elif (DD_RECON == 1)
#      undef DD_RECON
#      define DD_RECON 2
#    else
#      undef DD_RECON
#      define DD_RECON 0
#      if (DD_PREC == 0)
#        undef DD_PREC
#        define DD_PREC 1
#      elif (DD_PREC == 1)
#        undef DD_PREC
#        define DD_PREC 2
#      else
         /* every digit has wrapped: tear down the loop state */
#        undef DD_LOOP
#        undef DD_DAG
#        undef DD_XPAY
#        undef DD_RECON
#        undef DD_PREC
#      endif
#    endif
#  endif
#endif

#if defined(DD_LOOP)
#  include "tm_dslash_def.h"
#endif
1.7.3