QUDA v0.3.2
A library for QCD on GPUs

quda/lib/staggered_dslash_def.h

Go to the documentation of this file.
00001 // staggered_dslash_def.h - Dslash kernel definitions
00002 //
00003 // See comments in wilson_dslash_def.h
00004 
00005 // First pass only: DD_LOOP is undefined when an outside file includes this one,
00006 // so mark the loop as started and set all four option counters to 0.
00007 #ifndef DD_LOOP
00008 #define DD_LOOP
00009 #define DD_DAG 0
00010 #define DD_XPAY 0
00011 #define DD_RECON 0
00012 #define DD_PREC 0
00013 #endif
00014 
00015 // set options for current iteration
00016 // DD_FNAME is the base of the kernel name; the *_F macros are the pieces pasted after it.
00017 #define DD_FNAME staggeredDslash
00018 // DD_DAG picks the dagger name fragment: empty for 0, `Dagger` otherwise.
00019 #if (DD_DAG==0) // no dagger
00020 #define DD_DAG_F
00021 #else           // dagger
00022 #define DD_DAG_F Dagger
00023 #endif
00024 // DD_XPAY: 0 = plain kernel (oddBit only); 1 = Xpay; any other value = Axpy. Nonzero adds x, xNorm, a typed per DD_PREC.
00025 #if (DD_XPAY==0) // no xpay 
00026 #define DD_XPAY_F 
00027 #define DD_PARAM5 const int oddBit
00028 #else            // xpay
00029 #if (DD_PREC == 0)
00030 #define DD_PARAM5 const int oddBit, const double2 *x, const float *xNorm, const double a
00031 #elif (DD_PREC == 1) 
00032 #define DD_PARAM5 const int oddBit, const float2 *x, const float *xNorm, const float a
00033 #else
00034 #define DD_PARAM5 const int oddBit, const short2 *x, const float *xNorm, const float a
00035 #endif
00036 #if (DD_XPAY==1)
00037 #define DD_XPAY_F Xpay
00038 #define DSLASH_XPAY
00039 #else
00040 #define DD_XPAY_F Axpy
00041 #define DSLASH_AXPY
00042 #endif
00043 #endif
00044 // DD_RECON picks the name suffix (8/12/18), the gauge pointer types in DD_PARAM2 and the gauge read/reconstruct macros.
00045 #if (DD_RECON==0) // reconstruct from 8 reals
00046 #define DD_RECON_F 8
00047 #if (DD_PREC==0)
00048 #define DD_PARAM2 const double2 *fatGauge0, const double2 *fatGauge1, const double2* longGauge0, const double2* longGauge1
00049 #define RECONSTRUCT_GAUGE_MATRIX RECONSTRUCT_GAUGE_MATRIX_8_DOUBLE
00050 #define READ_FAT_MATRIX READ_FAT_MATRIX_18_DOUBLE
00051 #define READ_LONG_MATRIX READ_LONG_MATRIX_8_DOUBLE
00052 #elif (DD_PREC==1)
00053 #define DD_PARAM2 const float2 *fatGauge0, const float2 *fatGauge1, const float4* longGauge0, const float4* longGauge1
00054 #define RECONSTRUCT_GAUGE_MATRIX RECONSTRUCT_GAUGE_MATRIX_8_SINGLE
00055 #define READ_FAT_MATRIX READ_FAT_MATRIX_18_SINGLE
00056 #define READ_LONG_MATRIX READ_LONG_MATRIX_8_SINGLE
00057 #else
00058 #define DD_PARAM2 const short2 *fatGauge0, const short2* fatGauge1, const short4* longGauge0, const short4* longGauge1
00059 #define RECONSTRUCT_GAUGE_MATRIX RECONSTRUCT_GAUGE_MATRIX_8_SINGLE
00060 #define READ_FAT_MATRIX READ_FAT_MATRIX_18_HALF
00061 #define READ_LONG_MATRIX READ_LONG_MATRIX_8_HALF
00062 #endif
00063 // Note: READ_FAT_MATRIX maps to a READ_FAT_MATRIX_18_* reader in every DD_RECON branch; only the long-link reader varies.
00064 #elif (DD_RECON ==1)// reconstruct from 12 reals
00065 // As in the 8-real case, the half-precision branch below maps RECONSTRUCT_GAUGE_MATRIX to the _SINGLE variant.
00066 #define DD_RECON_F 12
00067 #if (DD_PREC==0)
00068 #define RECONSTRUCT_GAUGE_MATRIX RECONSTRUCT_GAUGE_MATRIX_12_DOUBLE
00069 #define READ_FAT_MATRIX READ_FAT_MATRIX_18_DOUBLE
00070 #define READ_LONG_MATRIX READ_LONG_MATRIX_12_DOUBLE
00071 #define DD_PARAM2 const double2 *fatGauge0, const double2 *fatGauge1,  const double2* longGauge0, const double2* longGauge1
00072 #elif (DD_PREC==1)
00073 #define DD_PARAM2 const float2 *fatGauge0, const float2 *fatGauge1, const float4* longGauge0, const float4* longGauge1
00074 #define RECONSTRUCT_GAUGE_MATRIX RECONSTRUCT_GAUGE_MATRIX_12_SINGLE
00075 #define READ_FAT_MATRIX READ_FAT_MATRIX_18_SINGLE
00076 #define READ_LONG_MATRIX READ_LONG_MATRIX_12_SINGLE
00077 #else
00078 #define DD_PARAM2 const short2 *fatGauge0, const short2 *fatGauge1, const short4* longGauge0, const short4* longGauge1
00079 #define RECONSTRUCT_GAUGE_MATRIX RECONSTRUCT_GAUGE_MATRIX_12_SINGLE
00080 #define READ_FAT_MATRIX READ_FAT_MATRIX_18_HALF
00081 #define READ_LONG_MATRIX READ_LONG_MATRIX_12_HALF
00082 #endif
00083 // Remaining case: RECONSTRUCT_GAUGE_MATRIX expands to nothing; single/half long links are float2/short2 here (float4/short4 above).
00084 #else //18 reconstruct
00085 #define DD_RECON_F 18
00086 #define RECONSTRUCT_GAUGE_MATRIX(dir, gauge, idx, sign)
00087 #if (DD_PREC==0)
00088 #define READ_FAT_MATRIX READ_FAT_MATRIX_18_DOUBLE
00089 #define READ_LONG_MATRIX READ_LONG_MATRIX_18_DOUBLE
00090 #define DD_PARAM2 const double2 *fatGauge0, const double2 *fatGauge1,  const double2* longGauge0, const double2* longGauge1
00091 #elif (DD_PREC==1)
00092 #define DD_PARAM2 const float2 *fatGauge0, const float2 *fatGauge1, const float2* longGauge0, const float2* longGauge1
00093 #define READ_FAT_MATRIX READ_FAT_MATRIX_18_SINGLE
00094 #define READ_LONG_MATRIX READ_LONG_MATRIX_18_SINGLE
00095 #else 
00096 #define DD_PARAM2 const short2 *fatGauge0, const short2 *fatGauge1, const short2* longGauge0, const short2* longGauge1
00097 #define READ_FAT_MATRIX READ_FAT_MATRIX_18_HALF
00098 #define READ_LONG_MATRIX READ_LONG_MATRIX_18_HALF
00099 #endif
00100 
00101 #endif
00102 // Per-precision settings (DD_PREC 0 / 1 / other = double / single / half): gauge and spinor sources, spinor pointer types, spinor I/O macros.
00103 #if (DD_PREC==0) // double-precision fields
00104 // When a DIRECT_ACCESS_* macro is defined, the matching *TEX name aliases the kernel's pointer parameter instead of a *Tex* name.
00105 // gauge field
00106 #define DD_PREC_F D
00107 #ifndef DIRECT_ACCESS_FAT_LINK
00108 #define FATLINK0TEX fatGauge0TexDouble
00109 #define FATLINK1TEX fatGauge1TexDouble
00110 #else
00111 #define FATLINK0TEX fatGauge0
00112 #define FATLINK1TEX fatGauge1
00113 #endif
00114 
00115 #ifndef DIRECT_ACCESS_LONG_LINK //longlink access
00116 #define LONGLINK0TEX longGauge0TexDouble
00117 #define LONGLINK1TEX longGauge1TexDouble
00118 #else
00119 #define LONGLINK0TEX longGauge0
00120 #define LONGLINK1TEX longGauge1
00121 #endif
00122 // GAUGE_DOUBLE (and SPINOR_DOUBLE below) are defined only in this double-precision branch.
00123 #define GAUGE_DOUBLE
00124 
00125 // spinor fields
00126 #define DD_PARAM1 double2* g_out, float *null1
00127 #define DD_PARAM4 const double2* in, const float *null4
00128 #define READ_SPINOR READ_SPINOR_DOUBLE
00129 #define READ_SPINOR_UP READ_SPINOR_DOUBLE_UP
00130 #define READ_SPINOR_DOWN READ_SPINOR_DOUBLE_DOWN
00131 #ifndef DIRECT_ACCESS_SPINOR
00132 #define SPINORTEX spinorTexDouble
00133 #else
00134 #define SPINORTEX in
00135 #endif
00136 #define WRITE_SPINOR WRITE_ST_SPINOR_DOUBLE2
00137 #define READ_AND_SUM_SPINOR READ_AND_SUM_ST_SPINOR
00138 #define READ_1ST_NBR_SPINOR READ_1ST_NBR_SPINOR_DOUBLE
00139 #define READ_3RD_NBR_SPINOR READ_3RD_NBR_SPINOR_DOUBLE
00140 #define SPINOR_DOUBLE
00141 #if (DD_XPAY==1 || DD_XPAY == 2)
00142 #define ACCUMTEX accumTexDouble
00143 #define READ_ACCUM READ_ACCUM_DOUBLE
00144 #endif
00145 // NOTE(review): double maps READ_ACCUM to READ_ACCUM_DOUBLE, while single/half use READ_ST_ACCUM_*; confirm this is intended.
00146 
00147 #elif (DD_PREC==1) // single-precision fields
00148 
00149 // gauge fields
00150 #define DD_PREC_F S
00151 
00152 #ifndef DIRECT_ACCESS_FAT_LINK
00153 #define FATLINK0TEX fatGauge0TexSingle
00154 #define FATLINK1TEX fatGauge1TexSingle
00155 #else
00156 #define FATLINK0TEX fatGauge0
00157 #define FATLINK1TEX fatGauge1
00158 #endif
00159 // DD_RECON == 2 (18 reals) selects separate *_norecon long-link names.
00160 #ifndef DIRECT_ACCESS_LONG_LINK //longlink access
00161 #if (DD_RECON ==2)
00162 #define LONGLINK0TEX longGauge0TexSingle_norecon
00163 #define LONGLINK1TEX longGauge1TexSingle_norecon
00164 #else
00165 #define LONGLINK0TEX longGauge0TexSingle
00166 #define LONGLINK1TEX longGauge1TexSingle
00167 #endif
00168 #else
00169 #define LONGLINK0TEX longGauge0
00170 #define LONGLINK1TEX longGauge1
00171 #endif
00172 
00173 // spinor fields
00174 #define DD_PARAM1 float2* g_out, float *null1
00175 #define DD_PARAM4 const float2* in, const float *null4
00176 #define READ_SPINOR READ_SPINOR_SINGLE
00177 #define READ_SPINOR_UP READ_SPINOR_SINGLE_UP
00178 #define READ_SPINOR_DOWN READ_SPINOR_SINGLE_DOWN
00179 #define READ_1ST_NBR_SPINOR READ_1ST_NBR_SPINOR_SINGLE
00180 #define READ_3RD_NBR_SPINOR READ_3RD_NBR_SPINOR_SINGLE
00181 #ifndef DIRECT_ACCESS_SPINOR
00182 #define SPINORTEX spinorTexSingle2
00183 #else
00184 #define SPINORTEX in
00185 #endif
00186 #define WRITE_SPINOR WRITE_ST_SPINOR_FLOAT2
00187 #define READ_AND_SUM_SPINOR READ_AND_SUM_ST_SPINOR
00188 #if (DD_XPAY==1 || DD_XPAY == 2)
00189 #define ACCUMTEX accumTexSingle2
00190 #define READ_ACCUM READ_ST_ACCUM_SINGLE
00191 #endif
00192 
00193 
00194 #else             // half-precision fields
00195 // This branch has no DIRECT_ACCESS_* alternatives: the *TEX macros always use the *TexHalf* names.
00196 // gauge fields
00197 #define DD_PREC_F H
00198 #define FATLINK0TEX fatGauge0TexHalf
00199 #define FATLINK1TEX fatGauge1TexHalf
00200 #if (DD_RECON ==2)
00201 #define LONGLINK0TEX longGauge0TexHalf_norecon
00202 #define LONGLINK1TEX longGauge1TexHalf_norecon
00203 #else
00204 #define LONGLINK0TEX longGauge0TexHalf
00205 #define LONGLINK1TEX longGauge1TexHalf
00206 #endif
00207 // spinor fields (the second pointer in DD_PARAM1/DD_PARAM4 is named outNorm/inNorm here, null1/null4 in the other precisions)
00208 #define READ_SPINOR READ_ST_SPINOR_HALF
00209 #define READ_SPINOR_UP READ_SPINOR_HALF_UP
00210 #define READ_SPINOR_DOWN READ_SPINOR_HALF_DOWN
00211 #define READ_1ST_NBR_SPINOR READ_1ST_NBR_SPINOR_HALF
00212 #define READ_3RD_NBR_SPINOR READ_3RD_NBR_SPINOR_HALF
00213 #define SPINORTEX spinorTexHalf2
00214 #define DD_PARAM1 short2* g_out, float *outNorm
00215 #define DD_PARAM4 const short2* in, const float *inNorm
00216 #define WRITE_SPINOR WRITE_ST_SPINOR_SHORT2
00217 #define READ_AND_SUM_SPINOR READ_AND_SUM_ST_SPINOR_HALF
00218 #if (DD_XPAY==1 || DD_XPAY == 2)
00219 #define ACCUMTEX accumTexHalf2
00220 #define READ_ACCUM READ_ST_ACCUM_HALF
00221 #endif
00222 
00223 #endif
00224 // Emit one __global__ kernel for the current DD_RECON / DD_DAG / DD_XPAY / DD_PREC combination.
00225 // only build double precision if supported: the kernel is skipped when __CUDA_ARCH__ < 130 and DD_PREC == 0
00226 #if !(__CUDA_ARCH__ < 130 && DD_PREC == 0) 
00227 // NOTE(review): an undefined __CUDA_ARCH__ evaluates to 0 in #if, so double kernels are also skipped wherever it is undefined - confirm intended.
00228 #define DD_CONCAT(n,r,d,x) n ## r ## d ## x ## Kernel
00229 #define DD_FUNC(n,r,d,x) DD_CONCAT(n,r,d,x)
00230 // DD_FUNC lets its arguments expand before DD_CONCAT pastes them: <fname><recon><dag><xpay>Kernel. DD_PREC_F is not part of the name.
00231 // define the kernel
00232 __global__ void DD_FUNC(DD_FNAME, DD_RECON_F, DD_DAG_F, DD_XPAY_F)
00233   (DD_PARAM1, DD_PARAM2,  DD_PARAM4, DD_PARAM5) {
00234 // The body is the included core file; without GPU_STAGGERED_DIRAC the kernel body is empty.
00235 #ifdef GPU_STAGGERED_DIRAC
00236 #include "staggered_dslash_core.h"
00237 #endif
00238 
00239 }
00240 
00241 #endif
00242 
00243 
00244 
00245 
00246 // clean up: undefine the per-iteration macros so the next inclusion of this file can define them afresh
00247 // (#undef of a name that is not currently defined is harmless, e.g. ACCUMTEX when DD_XPAY == 0)
00248 #undef DD_PREC_F
00249 #undef DD_RECON_F
00250 #undef DD_DAG_F
00251 #undef DD_XPAY_F
00252 #undef DD_PARAM1
00253 #undef DD_PARAM2
00254 #undef DD_PARAM4
00255 #undef DD_PARAM5
00256 #undef DD_FNAME
00257 #undef DD_CONCAT
00258 #undef DD_FUNC
00259 // Some names below (READ_GAUGE_MATRIX, CLOVERTEX, READ_CLOVER, DSLASH_CLOVER, CLOVER_DOUBLE) are never defined in this listing.
00260 #undef DSLASH_XPAY
00261 #undef DSLASH_AXPY
00262 #undef READ_GAUGE_MATRIX
00263 #undef RECONSTRUCT_GAUGE_MATRIX
00264 #undef FATLINK0TEX
00265 #undef FATLINK1TEX
00266 #undef LONGLINK0TEX
00267 #undef LONGLINK1TEX
00268 #undef READ_SPINOR
00269 #undef READ_SPINOR_UP
00270 #undef READ_SPINOR_DOWN
00271 #undef SPINORTEX
00272 #undef WRITE_SPINOR
00273 #undef READ_AND_SUM_SPINOR
00274 #undef ACCUMTEX
00275 #undef READ_ACCUM
00276 #undef CLOVERTEX
00277 #undef READ_CLOVER
00278 #undef DSLASH_CLOVER
00279 #undef GAUGE_DOUBLE
00280 #undef SPINOR_DOUBLE
00281 #undef CLOVER_DOUBLE
00282 #undef READ_FAT_MATRIX
00283 #undef READ_LONG_MATRIX
00284 #undef READ_1ST_NBR_SPINOR
00285 #undef READ_3RD_NBR_SPINOR
00286 // Advance the option counters like an odometer: DD_DAG (0-1) changes fastest, then DD_XPAY, DD_RECON, DD_PREC (each 0-2).
00287 // prepare next set of options, or clean up after final iteration
00288 // A counter that wraps back to 0 carries into the next one; 2*3*3*3 = 54 combinations are visited in total.
00289 #if (DD_DAG==0)
00290 #undef DD_DAG
00291 #define DD_DAG 1
00292 #else
00293 #undef DD_DAG
00294 #define DD_DAG 0
00295 // DD_DAG wrapped: carry into DD_XPAY
00296 #if (DD_XPAY==0)
00297 #undef DD_XPAY
00298 #define DD_XPAY 1
00299 #elif (DD_XPAY==1)
00300 #undef DD_XPAY
00301 #define DD_XPAY 2
00302 #else
00303 #undef DD_XPAY
00304 #define DD_XPAY 0
00305 // DD_XPAY wrapped: carry into DD_RECON
00306 #if (DD_RECON==0)
00307 #undef DD_RECON
00308 #define DD_RECON 1
00309 #elif (DD_RECON ==1)
00310 #undef DD_RECON
00311 #define DD_RECON 2
00312 #else
00313 #undef DD_RECON
00314 #define DD_RECON 0
00315 // DD_RECON wrapped: carry into DD_PREC
00316 #if (DD_PREC==0)
00317 #undef DD_PREC
00318 #define DD_PREC 1
00319 #elif (DD_PREC==1)
00320 #undef DD_PREC
00321 #define DD_PREC 2
00322 #else
00323 #undef DD_PREC
00324 #define DD_PREC 0
00325 // DD_PREC wrapped: every combination is done, so drop DD_LOOP and the counters to stop the self-inclusion
00326 #undef DD_LOOP
00327 #undef DD_DAG
00328 #undef DD_XPAY
00329 #undef DD_RECON
00330 #undef DD_PREC
00331 
00332 #endif // DD_PREC
00333 #endif // DD_RECON
00334 #endif // DD_XPAY
00335 #endif // DD_DAG
00336 // While DD_LOOP is still defined, include this file again to emit the kernel for the next combination.
00337 #ifdef DD_LOOP
00338 #include "staggered_dslash_def.h"
00339 #endif // each pass nests one include level deeper, so this relies on the preprocessor allowing that include depth
 All Classes Files Functions Variables Typedefs Enumerations Enumerator Friends Defines