|
QUDA v0.3.2
A library for QCD on GPUs
|
// staggered_dslash_def.h - staggered Dslash kernel definitions
//
// This header includes itself (see the end of the file) once for every
// combination of the four iteration variables below and defines one
// __global__ kernel per pass:
//
//   DD_DAG    0 = no dagger, 1 = dagger                  (advances first)
//   DD_XPAY   0 = none,      1 = Xpay,     2 = Axpy
//   DD_RECON  0 = 8 reals,   1 = 12 reals, otherwise 18
//   DD_PREC   0 = double,    1 = single,   otherwise half (advances last)
//
// The kernel name is built from DD_FNAME, DD_RECON_F, DD_DAG_F and DD_XPAY_F,
// e.g. staggeredDslash12DaggerXpayKernel.  DD_PREC_F is defined but is not
// part of the name; the precisions differ only in their parameter types.
//
// See also the comments in wilson_dslash_def.h.

// ---------------------------------------------------------------------------
// First pass only: seed the iteration variables.
// ---------------------------------------------------------------------------

#if !defined(DD_LOOP)
  #define DD_LOOP
  #define DD_DAG   0
  #define DD_XPAY  0
  #define DD_RECON 0
  #define DD_PREC  0
#endif

// ---------------------------------------------------------------------------
// Options for the current pass.
// ---------------------------------------------------------------------------

#define DD_FNAME staggeredDslash

// Dagger: contributes only a name suffix here.
#if DD_DAG == 0
  #define DD_DAG_F
#else
  #define DD_DAG_F Dagger
#endif

// Xpay / Axpy: name suffix, feature flag, and the trailing kernel parameters.
// Without xpay the kernel takes only oddBit; with it, the accumulate field x,
// its norm pointer and the scalar a are appended (typed per precision).
#if DD_XPAY == 0
  #define DD_XPAY_F
  #define DD_PARAM5 const int oddBit
#else
  #if DD_XPAY == 1
    #define DD_XPAY_F Xpay
    #define DSLASH_XPAY
  #else
    #define DD_XPAY_F Axpy
    #define DSLASH_AXPY
  #endif

  #if DD_PREC == 0
    #define DD_PARAM5 const int oddBit, const double2 *x, const float *xNorm, const double a
  #elif DD_PREC == 1
    #define DD_PARAM5 const int oddBit, const float2 *x, const float *xNorm, const float a
  #else
    #define DD_PARAM5 const int oddBit, const short2 *x, const float *xNorm, const float a
  #endif
#endif

// Long-link reconstruction: name suffix, the reconstruct macro and the
// long-link reader.  Half precision selects the *_SINGLE reconstruct macro.
#if DD_RECON == 0            // reconstruct from 8 reals
  #define DD_RECON_F 8
  #if DD_PREC == 0
    #define RECONSTRUCT_GAUGE_MATRIX RECONSTRUCT_GAUGE_MATRIX_8_DOUBLE
    #define READ_LONG_MATRIX READ_LONG_MATRIX_8_DOUBLE
  #elif DD_PREC == 1
    #define RECONSTRUCT_GAUGE_MATRIX RECONSTRUCT_GAUGE_MATRIX_8_SINGLE
    #define READ_LONG_MATRIX READ_LONG_MATRIX_8_SINGLE
  #else
    #define RECONSTRUCT_GAUGE_MATRIX RECONSTRUCT_GAUGE_MATRIX_8_SINGLE
    #define READ_LONG_MATRIX READ_LONG_MATRIX_8_HALF
  #endif
#elif DD_RECON == 1          // reconstruct from 12 reals
  #define DD_RECON_F 12
  #if DD_PREC == 0
    #define RECONSTRUCT_GAUGE_MATRIX RECONSTRUCT_GAUGE_MATRIX_12_DOUBLE
    #define READ_LONG_MATRIX READ_LONG_MATRIX_12_DOUBLE
  #elif DD_PREC == 1
    #define RECONSTRUCT_GAUGE_MATRIX RECONSTRUCT_GAUGE_MATRIX_12_SINGLE
    #define READ_LONG_MATRIX READ_LONG_MATRIX_12_SINGLE
  #else
    #define RECONSTRUCT_GAUGE_MATRIX RECONSTRUCT_GAUGE_MATRIX_12_SINGLE
    #define READ_LONG_MATRIX READ_LONG_MATRIX_12_HALF
  #endif
#else                        // 18 reals: the reconstruct step expands to nothing
  #define DD_RECON_F 18
  #define RECONSTRUCT_GAUGE_MATRIX(dir, gauge, idx, sign)
  #if DD_PREC == 0
    #define READ_LONG_MATRIX READ_LONG_MATRIX_18_DOUBLE
  #elif DD_PREC == 1
    #define READ_LONG_MATRIX READ_LONG_MATRIX_18_SINGLE
  #else
    #define READ_LONG_MATRIX READ_LONG_MATRIX_18_HALF
  #endif
#endif

// Precision: gauge/spinor parameter lists, reader/writer macros and the
// symbols the fields are fetched through.  When a DIRECT_ACCESS_* macro is
// defined the corresponding *TEX name resolves to the kernel's own pointer
// parameter instead of the *Tex* symbol (declared outside this file).
#if DD_PREC == 0             // double-precision fields

  #define DD_PREC_F D

  // gauge field
  #define DD_PARAM2 const double2 *fatGauge0, const double2 *fatGauge1, const double2* longGauge0, const double2* longGauge1
  #define READ_FAT_MATRIX READ_FAT_MATRIX_18_DOUBLE
  #define GAUGE_DOUBLE

  #if defined(DIRECT_ACCESS_FAT_LINK)
    #define FATLINK0TEX fatGauge0
    #define FATLINK1TEX fatGauge1
  #else
    #define FATLINK0TEX fatGauge0TexDouble
    #define FATLINK1TEX fatGauge1TexDouble
  #endif

  #if defined(DIRECT_ACCESS_LONG_LINK)
    #define LONGLINK0TEX longGauge0
    #define LONGLINK1TEX longGauge1
  #else
    #define LONGLINK0TEX longGauge0TexDouble
    #define LONGLINK1TEX longGauge1TexDouble
  #endif

  // spinor fields
  #define DD_PARAM1 double2* g_out, float *null1
  #define DD_PARAM4 const double2* in, const float *null4
  #define SPINOR_DOUBLE
  #define READ_SPINOR READ_SPINOR_DOUBLE
  #define READ_SPINOR_UP READ_SPINOR_DOUBLE_UP
  #define READ_SPINOR_DOWN READ_SPINOR_DOUBLE_DOWN
  #define READ_1ST_NBR_SPINOR READ_1ST_NBR_SPINOR_DOUBLE
  #define READ_3RD_NBR_SPINOR READ_3RD_NBR_SPINOR_DOUBLE
  #define READ_AND_SUM_SPINOR READ_AND_SUM_ST_SPINOR
  #define WRITE_SPINOR WRITE_ST_SPINOR_DOUBLE2

  #if defined(DIRECT_ACCESS_SPINOR)
    #define SPINORTEX in
  #else
    #define SPINORTEX spinorTexDouble
  #endif

  #if DD_XPAY == 1 || DD_XPAY == 2
    #define ACCUMTEX accumTexDouble
    #define READ_ACCUM READ_ACCUM_DOUBLE
  #endif

#elif DD_PREC == 1           // single-precision fields

  #define DD_PREC_F S

  // gauge fields: long links are float4 for 8/12 reconstruction, float2 for 18
  #if DD_RECON == 0 || DD_RECON == 1
    #define DD_PARAM2 const float2 *fatGauge0, const float2 *fatGauge1, const float4* longGauge0, const float4* longGauge1
  #else
    #define DD_PARAM2 const float2 *fatGauge0, const float2 *fatGauge1, const float2* longGauge0, const float2* longGauge1
  #endif
  #define READ_FAT_MATRIX READ_FAT_MATRIX_18_SINGLE

  #if defined(DIRECT_ACCESS_FAT_LINK)
    #define FATLINK0TEX fatGauge0
    #define FATLINK1TEX fatGauge1
  #else
    #define FATLINK0TEX fatGauge0TexSingle
    #define FATLINK1TEX fatGauge1TexSingle
  #endif

  #if defined(DIRECT_ACCESS_LONG_LINK)
    #define LONGLINK0TEX longGauge0
    #define LONGLINK1TEX longGauge1
  #else
    #if DD_RECON == 2
      #define LONGLINK0TEX longGauge0TexSingle_norecon
      #define LONGLINK1TEX longGauge1TexSingle_norecon
    #else
      #define LONGLINK0TEX longGauge0TexSingle
      #define LONGLINK1TEX longGauge1TexSingle
    #endif
  #endif

  // spinor fields
  #define DD_PARAM1 float2* g_out, float *null1
  #define DD_PARAM4 const float2* in, const float *null4
  #define READ_SPINOR READ_SPINOR_SINGLE
  #define READ_SPINOR_UP READ_SPINOR_SINGLE_UP
  #define READ_SPINOR_DOWN READ_SPINOR_SINGLE_DOWN
  #define READ_1ST_NBR_SPINOR READ_1ST_NBR_SPINOR_SINGLE
  #define READ_3RD_NBR_SPINOR READ_3RD_NBR_SPINOR_SINGLE
  #define READ_AND_SUM_SPINOR READ_AND_SUM_ST_SPINOR
  #define WRITE_SPINOR WRITE_ST_SPINOR_FLOAT2

  #if defined(DIRECT_ACCESS_SPINOR)
    #define SPINORTEX in
  #else
    #define SPINORTEX spinorTexSingle2
  #endif

  #if DD_XPAY == 1 || DD_XPAY == 2
    #define ACCUMTEX accumTexSingle2
    #define READ_ACCUM READ_ST_ACCUM_SINGLE
  #endif

#else                        // half-precision fields

  #define DD_PREC_F H

  // gauge fields: long links are short4 for 8/12 reconstruction, short2 for 18
  #if DD_RECON == 0 || DD_RECON == 1
    #define DD_PARAM2 const short2 *fatGauge0, const short2 *fatGauge1, const short4* longGauge0, const short4* longGauge1
  #else
    #define DD_PARAM2 const short2 *fatGauge0, const short2 *fatGauge1, const short2* longGauge0, const short2* longGauge1
  #endif
  #define READ_FAT_MATRIX READ_FAT_MATRIX_18_HALF

  // no DIRECT_ACCESS_* alternative is provided for half precision
  #define FATLINK0TEX fatGauge0TexHalf
  #define FATLINK1TEX fatGauge1TexHalf
  #if DD_RECON == 2
    #define LONGLINK0TEX longGauge0TexHalf_norecon
    #define LONGLINK1TEX longGauge1TexHalf_norecon
  #else
    #define LONGLINK0TEX longGauge0TexHalf
    #define LONGLINK1TEX longGauge1TexHalf
  #endif

  // spinor fields: the float* slots carry the in/out norm pointers here
  #define DD_PARAM1 short2* g_out, float *outNorm
  #define DD_PARAM4 const short2* in, const float *inNorm
  #define READ_SPINOR READ_ST_SPINOR_HALF
  #define READ_SPINOR_UP READ_SPINOR_HALF_UP
  #define READ_SPINOR_DOWN READ_SPINOR_HALF_DOWN
  #define READ_1ST_NBR_SPINOR READ_1ST_NBR_SPINOR_HALF
  #define READ_3RD_NBR_SPINOR READ_3RD_NBR_SPINOR_HALF
  #define READ_AND_SUM_SPINOR READ_AND_SUM_ST_SPINOR_HALF
  #define WRITE_SPINOR WRITE_ST_SPINOR_SHORT2
  #define SPINORTEX spinorTexHalf2

  #if DD_XPAY == 1 || DD_XPAY == 2
    #define ACCUMTEX accumTexHalf2
    #define READ_ACCUM READ_ST_ACCUM_HALF
  #endif

#endif

// ---------------------------------------------------------------------------
// Kernel definition.  Double-precision kernels are skipped when
// __CUDA_ARCH__ is below 130 (an undefined __CUDA_ARCH__ evaluates as 0 here).
// ---------------------------------------------------------------------------

#if (__CUDA_ARCH__ >= 130) || (DD_PREC != 0)

  // Two levels so the arguments are macro-expanded before being pasted.
  #define DD_CONCAT(name, recon, dag, xpay) name ## recon ## dag ## xpay ## Kernel
  #define DD_FUNC(name, recon, dag, xpay) DD_CONCAT(name, recon, dag, xpay)

__global__ void DD_FUNC(DD_FNAME, DD_RECON_F, DD_DAG_F, DD_XPAY_F)
  (DD_PARAM1, DD_PARAM2, DD_PARAM4, DD_PARAM5)
{
  // The body is empty unless the staggered operator is enabled.
  #if defined(GPU_STAGGERED_DIRAC)
    #include "staggered_dslash_core.h"
  #endif
}

#endif

// ---------------------------------------------------------------------------
// Clean up everything set for this pass (some of these names are never
// defined in this file; #undef of an undefined macro is a no-op).
// ---------------------------------------------------------------------------

// kernel-name pieces and parameter lists
#undef DD_FNAME
#undef DD_PREC_F
#undef DD_RECON_F
#undef DD_DAG_F
#undef DD_XPAY_F
#undef DD_PARAM1
#undef DD_PARAM2
#undef DD_PARAM4
#undef DD_PARAM5
#undef DD_CONCAT
#undef DD_FUNC

// feature flags
#undef DSLASH_XPAY
#undef DSLASH_AXPY
#undef DSLASH_CLOVER
#undef GAUGE_DOUBLE
#undef SPINOR_DOUBLE
#undef CLOVER_DOUBLE

// gauge access
#undef READ_GAUGE_MATRIX
#undef RECONSTRUCT_GAUGE_MATRIX
#undef READ_FAT_MATRIX
#undef READ_LONG_MATRIX
#undef FATLINK0TEX
#undef FATLINK1TEX
#undef LONGLINK0TEX
#undef LONGLINK1TEX

// spinor / accumulate / clover access
#undef READ_SPINOR
#undef READ_SPINOR_UP
#undef READ_SPINOR_DOWN
#undef READ_1ST_NBR_SPINOR
#undef READ_3RD_NBR_SPINOR
#undef READ_AND_SUM_SPINOR
#undef WRITE_SPINOR
#undef SPINORTEX
#undef ACCUMTEX
#undef READ_ACCUM
#undef CLOVERTEX
#undef READ_CLOVER

// ---------------------------------------------------------------------------
// Advance to the next combination.  Each variable steps to its next value;
// when it wraps to 0 the next one steps.  When DD_PREC wraps, every iteration
// variable (including DD_LOOP) is removed, which ends the self-inclusion.
// ---------------------------------------------------------------------------

#if DD_DAG == 0
  #undef DD_DAG
  #define DD_DAG 1
#else
  #undef DD_DAG
  #define DD_DAG 0

  #if DD_XPAY == 0
    #undef DD_XPAY
    #define DD_XPAY 1
  #elif DD_XPAY == 1
    #undef DD_XPAY
    #define DD_XPAY 2
  #else
    #undef DD_XPAY
    #define DD_XPAY 0

    #if DD_RECON == 0
      #undef DD_RECON
      #define DD_RECON 1
    #elif DD_RECON == 1
      #undef DD_RECON
      #define DD_RECON 2
    #else
      #undef DD_RECON
      #define DD_RECON 0

      #if DD_PREC == 0
        #undef DD_PREC
        #define DD_PREC 1
      #elif DD_PREC == 1
        #undef DD_PREC
        #define DD_PREC 2
      #else
        #undef DD_PREC
        #define DD_PREC 0

        // final combination done: tear the loop down
        #undef DD_LOOP
        #undef DD_DAG
        #undef DD_XPAY
        #undef DD_RECON
        #undef DD_PREC

      #endif // DD_PREC
    #endif // DD_RECON
  #endif // DD_XPAY
#endif // DD_DAG

// Re-enter for the next combination while the loop is still active.
#if defined(DD_LOOP)
  #include "staggered_dslash_def.h"
#endif
1.7.3