QUDA v0.3.2
A library for QCD on GPUs

quda/lib/dslash_quda.cu

Go to the documentation of this file.
00001 
00002 #include <stdlib.h>
00003 #include <stdio.h>
00004 
00005 //these are access control for staggered action
00006 #if (__CUDA_ARCH__ >= 200)
00007 //#define DIRECT_ACCESS_FAT_LINK
00008 //#define DIRECT_ACCESS_LONG_LINK
00009 #define DIRECT_ACCESS_SPINOR
00010 #else
00011 #define DIRECT_ACCESS_FAT_LINK
00012 //#define DIRECT_ACCESS_LONG_LINK
00013 //#define DIRECT_ACCESS_SPINOR
00014 #endif
00015 
00016 #include <quda_internal.h>
00017 #include <dslash_quda.h>
00018 
00019 #define BLOCK_DIM 64
00020 
00021 #include <dslash_textures.h>
00022 #include <dslash_constants.h>
00023 
00024 #include <staggered_dslash_def.h> // staggered Dslash kernels
00025 #include <wilson_dslash_def.h>    // Wilson Dslash kernels (including clover)
00026 #include <dw_dslash_def.h>        // Domain Wall kernels
00027 #include <tm_dslash_def.h>        // Twisted Mass kernels
00028 #include <dslash_core/tm_core.h>  // solo twisted mass kernel
00029 #include <clover_def.h>           // kernels for applying the clover term alone
00030 
00031 #ifndef SHARED_FLOATS_PER_THREAD
00032 #define SHARED_FLOATS_PER_THREAD 0
00033 #endif
00034 
00035 #include <blas_quda.h>
00036 
/**
 * Empty kernel with no parameters and no body.
 * Its only user in this file is initCache(), which sets a cache preference on
 * it and launches it once with a single thread.
 */
__global__ void dummyKernel()
{
  // intentionally empty
}
00040 
/**
 * One-time setup hook.
 *
 * When compiled with __CUDA_ARCH__ >= 200 it does two things on the first call:
 * - requests cudaFuncCachePreferL1 for dummyKernel;
 * - launches dummyKernel once with a single thread.
 *
 * A function-local static flag makes later calls return immediately. The flag
 * is not protected against concurrent callers. With any other __CUDA_ARCH__
 * value the function does nothing.
 *
 * NOTE(review): this is host code. __CUDA_ARCH__ is presumably injected by the
 * build (e.g. -D__CUDA_ARCH__=...) rather than by nvcc's device pass — confirm.
 * The cache preference is applied to dummyKernel only, and neither CUDA call's
 * status is checked here.
 */
void initCache() {
#if (__CUDA_ARCH__ >= 200)
  static int firsttime = 1;
  if (!firsttime) return;   // already initialised

  cudaFuncSetCacheConfig(dummyKernel, cudaFuncCachePreferL1);
  dummyKernel<<<1,1>>>();
  firsttime = 0;
#endif
}
00055 
/**
 * Dynamic shared-memory request for a BLOCK_DIM-thread block.
 * The result is BLOCK_DIM * SHARED_FLOATS_PER_THREAD * precision.
 * NOTE(review): assumes the numeric value of QudaPrecision is the element size
 * in bytes — confirm against the enum definition.
 */
int dslashCudaSharedBytes(QudaPrecision precision)
{
  const int threadsPerBlock = BLOCK_DIM;
  return threadsPerBlock * SHARED_FLOATS_PER_THREAD * precision;
}
00059 
/**
 * Launch the Wilson dslash kernel that matches the gauge reconstruction type,
 * the dagger flag and the xpay mode.
 *
 * Template parameters:
 *   spinorN     - vector width forwarded to bindSpinorTex / unbindSpinorTex
 *   spinorFloat - spinor storage type (double2 / float4 / short4 at the call sites)
 *   gaugeFloat  - gauge-field storage type
 *
 * @param out, outNorm    output spinor and its norm array
 * @param gauge0, gauge1  gauge pointers obtained from bindGaugeTex
 * @param reconstruct     QUDA_RECONSTRUCT_NO / _12 / _8; any other value is
 *                        reported through errorQuda
 * @param in, inNorm      input spinor and its norm array
 * @param parity          forwarded to the kernel
 * @param dagger          non-zero selects the *Dagger* kernels
 * @param x, xNorm        when x is non-null the *Xpay* kernels are launched
 *                        with x, xNorm and a
 * @param a               xpay coefficient (ignored when x == 0)
 * @param volume          grid is volume/BLOCK_DIM blocks of BLOCK_DIM threads.
 *                        NOTE(review): assumed to be a multiple of BLOCK_DIM,
 *                        since the division truncates.
 * @param length          forwarded to bindSpinorTex
 */
template <int spinorN, typename spinorFloat, typename gaugeFloat>
void dslashCuda(spinorFloat *out, float *outNorm, const gaugeFloat *gauge0, const gaugeFloat *gauge1, 
                const QudaReconstructType reconstruct, const spinorFloat *in, const float *inNorm,
                const int parity, const int dagger, const spinorFloat *x, const float *xNorm, 
                const double &a, const int volume, const int length) {

  // Validate once, so an unexpected value can neither silently fall back to
  // the 8-parameter kernels nor silently skip the launch.
  if (reconstruct != QUDA_RECONSTRUCT_NO && reconstruct != QUDA_RECONSTRUCT_12 &&
      reconstruct != QUDA_RECONSTRUCT_8) {
    errorQuda("Invalid reconstruct value(%d) in function %s\n", reconstruct, __FUNCTION__);
  }

  dim3 gridDim(volume/BLOCK_DIM, 1, 1);
  dim3 blockDim(BLOCK_DIM, 1, 1);

  // bindSpinorTex's return value scales the dynamic shared-memory request
  int shared_bytes = blockDim.x*SHARED_FLOATS_PER_THREAD*bindSpinorTex<spinorN>(length, in, inNorm, x, xNorm);

  if (x==0) { // not doing xpay
    if (reconstruct == QUDA_RECONSTRUCT_NO) {
      if (!dagger) {
        dslash18Kernel <<<gridDim, blockDim, shared_bytes>>> 
          (out, outNorm, gauge0, gauge1, in, inNorm, parity);
      } else {
        dslash18DaggerKernel <<<gridDim, blockDim, shared_bytes>>> 
          (out, outNorm, gauge0, gauge1, in, inNorm, parity);
      }
    } else if (reconstruct == QUDA_RECONSTRUCT_12) {
      if (!dagger) {
        dslash12Kernel <<<gridDim, blockDim, shared_bytes>>> 
          (out, outNorm, gauge0, gauge1, in, inNorm, parity);
      } else {
        dslash12DaggerKernel <<<gridDim, blockDim, shared_bytes>>> 
          (out, outNorm, gauge0, gauge1, in, inNorm, parity);
      }
    } else { // QUDA_RECONSTRUCT_8 (validated above)
      if (!dagger) {
        dslash8Kernel <<<gridDim, blockDim, shared_bytes>>> 
          (out, outNorm, gauge0, gauge1, in, inNorm, parity);
      } else {
        dslash8DaggerKernel <<<gridDim, blockDim, shared_bytes>>> 
          (out, outNorm, gauge0, gauge1, in, inNorm, parity);
      }
    }
  } else { // doing xpay
    if (reconstruct == QUDA_RECONSTRUCT_NO) {
      if (!dagger) {
        dslash18XpayKernel <<<gridDim, blockDim, shared_bytes>>> 
          (out, outNorm, gauge0, gauge1, in, inNorm, parity, x, xNorm, a);
      } else {
        dslash18DaggerXpayKernel <<<gridDim, blockDim, shared_bytes>>> 
          (out, outNorm, gauge0, gauge1, in, inNorm, parity, x, xNorm, a);
      }
    } else if (reconstruct == QUDA_RECONSTRUCT_12) {
      if (!dagger) {
        dslash12XpayKernel <<<gridDim, blockDim, shared_bytes>>> 
          (out, outNorm, gauge0, gauge1, in, inNorm, parity, x, xNorm, a);
      } else {
        dslash12DaggerXpayKernel <<<gridDim, blockDim, shared_bytes>>> 
          (out, outNorm, gauge0, gauge1, in, inNorm, parity, x, xNorm, a);
      }
    } else { // QUDA_RECONSTRUCT_8 (validated above)
      if (!dagger) {
        dslash8XpayKernel <<<gridDim, blockDim, shared_bytes>>> 
          (out, outNorm, gauge0, gauge1, in, inNorm, parity, x, xNorm, a);
      } else {
        dslash8DaggerXpayKernel <<<gridDim, blockDim, shared_bytes>>>
          (out, outNorm, gauge0, gauge1, in, inNorm, parity, x, xNorm, a);
      }
    }
  }
 
  unbindSpinorTex<spinorN>(in, inNorm, x, xNorm);
}
00128 
00129 // Wilson wrappers
/**
 * Host entry point for the Wilson dslash.
 *
 * Steps:
 * 1. Bind the gauge textures for `parity`.
 * 2. Require the gauge precision to equal `precision`.
 * 3. Cast the untyped pointers to the storage types for that precision
 *    (double2 / float4 / short4).
 * 4. Call the templated launcher.
 * 5. Unbind the gauge textures and call checkCudaError().
 *
 * `k` is forwarded as the launcher's xpay coefficient; xpay is used only when
 * `x` is non-null. A precision other than double, single or half is reported
 * through errorQuda instead of being silently ignored.
 * Builds without GPU_WILSON_DIRAC only report an error.
 */
void dslashCuda(void *out, void *outNorm, const FullGauge gauge, const void *in, const void *inNorm, 
                const int parity, const int dagger, const void *x, const void *xNorm, 
                const double k, const int volume, const int length, const QudaPrecision precision) {

#ifdef GPU_WILSON_DIRAC
  void *gauge0, *gauge1;
  bindGaugeTex(gauge, parity, &gauge0, &gauge1);

  if (precision != gauge.precision)
    errorQuda("Mixing gauge and spinor precision not supported");

  if (precision == QUDA_DOUBLE_PRECISION) {
#if (__CUDA_ARCH__ >= 130)
    dslashCuda<2>((double2*)out, (float*)outNorm, (double2*)gauge0, (double2*)gauge1, 
                  gauge.reconstruct, (double2*)in, (float*)inNorm, parity, dagger, 
                  (double2*)x, (float*)xNorm, k, volume, length);
#else
    errorQuda("Double precision not supported on this GPU");
#endif
  } else if (precision == QUDA_SINGLE_PRECISION) {
    dslashCuda<4>((float4*)out, (float*)outNorm, (float4*)gauge0, (float4*)gauge1,
                  gauge.reconstruct, (float4*)in, (float*)inNorm, parity, dagger, 
                  (float4*)x, (float*)xNorm, k, volume, length);
  } else if (precision == QUDA_HALF_PRECISION) {
    dslashCuda<4>((short4*)out, (float*)outNorm, (short4*)gauge0, (short4*)gauge1,
                  gauge.reconstruct, (short4*)in, (float*)inNorm, parity, dagger, 
                  (short4*)x, (float*)xNorm, k, volume, length);
  } else {
    // previously an unknown precision fell through without launching anything
    errorQuda("Unsupported precision %d in function %s\n", precision, __FUNCTION__);
  }
  unbindGaugeTex(gauge);

  checkCudaError();
#else
  errorQuda("Wilson dslash has not been built");
#endif // GPU_WILSON_DIRAC

}
00166 
00167 
/**
 * Launch cloverKernel, which applies the clover term alone.
 *
 * Grid: volume/BLOCK_DIM blocks of BLOCK_DIM threads.
 * NOTE(review): volume is assumed to be a multiple of BLOCK_DIM.
 *
 * The input spinor is bound as a texture around the launch. The value returned
 * by bindSpinorTex<N> scales the dynamic shared-memory request.
 */
template <int N, typename spinorFloat, typename cloverFloat>
void cloverCuda(spinorFloat *out, float *outNorm, const cloverFloat *clover,
                const float *cloverNorm, const spinorFloat *in, const float *inNorm, 
                const int parity, const int volume, const int length)
{
  dim3 threads(BLOCK_DIM, 1, 1);
  dim3 blocks(volume / BLOCK_DIM, 1, 1);

  const int shared_bytes =
    threads.x * SHARED_FLOATS_PER_THREAD * bindSpinorTex<N>(length, in, inNorm);

  cloverKernel<<<blocks, threads, shared_bytes>>>
    (out, outNorm, clover, cloverNorm, in, inNorm, parity);

  unbindSpinorTex<N>(in, inNorm);
}
00181 
/**
 * Host entry point for applying the clover term alone.
 *
 * Steps:
 * 1. Bind the clover textures for `parity`.
 * 2. Require the clover precision to equal `precision`.
 * 3. Call the typed launcher.
 * 4. Unbind the clover textures and call checkCudaError().
 *
 * The `gauge` argument is not used by this function. An unsupported precision
 * is reported through errorQuda rather than silently skipped. The path is
 * compiled only under GPU_WILSON_DIRAC.
 */
void cloverCuda(void *out, void *outNorm, const FullGauge gauge, const FullClover clover, 
                const void *in, const void *inNorm, const int parity, const int volume,
                const int length, const QudaPrecision precision) {

#ifdef GPU_WILSON_DIRAC
  void *cloverP, *cloverNormP;
  QudaPrecision clover_prec = bindCloverTex(clover, parity, &cloverP, &cloverNormP);

  if (precision != clover_prec)
    errorQuda("Mixing clover and spinor precision not supported");

  if (precision == QUDA_DOUBLE_PRECISION) {
#if (__CUDA_ARCH__ >= 130)
    cloverCuda<2>((double2*)out, (float*)outNorm, (double2*)cloverP, 
                  (float*)cloverNormP, (double2*)in, 
                  (float*)inNorm, parity, volume, length);
#else
    errorQuda("Double precision not supported on this GPU");
#endif
  } else if (precision == QUDA_SINGLE_PRECISION) {
    cloverCuda<4>((float4*)out, (float*)outNorm, (float4*)cloverP, 
                  (float*)cloverNormP, (float4*)in, 
                  (float*)inNorm, parity, volume, length);
  } else if (precision == QUDA_HALF_PRECISION) {
    cloverCuda<4>((short4*)out, (float*)outNorm, (short4*)cloverP, 
                  (float*)cloverNormP, (short4*)in,
                  (float*)inNorm, parity, volume, length);
  } else {
    // previously an unknown precision fell through without launching anything
    errorQuda("Unsupported precision %d in function %s\n", precision, __FUNCTION__);
  }
  unbindCloverTex(clover);

  checkCudaError();
#else
  errorQuda("Clover dslash has not been built");
#endif

}
00218 
00219 // Clover wrappers
/**
 * Launch the clover-improved Wilson dslash.
 *
 * Kernel selection uses three inputs:
 * - `reconstruct` picks the 18 / 12 / 8 kernel family;
 * - whether `x` is null picks plain versus *Xpay*;
 * - `dagger` picks plain versus *Dagger*.
 *
 * Every variant receives the clover pointers. Only the Xpay variants also
 * receive x, xNorm and `a`. Any reconstruct value other than NO or 12 is sent
 * to the 8-parameter kernels.
 *
 * gaugeFloat is deduced from the call sites as a pointer type (e.g. double2*).
 * Grid: volume/BLOCK_DIM blocks of BLOCK_DIM threads.
 * NOTE(review): volume is assumed to be a multiple of BLOCK_DIM.
 */
template <int N, typename spinorFloat, typename cloverFloat, typename gaugeFloat>
void cloverDslashCuda(spinorFloat *out, float *outNorm, const gaugeFloat gauge0, 
                      const gaugeFloat gauge1, const QudaReconstructType reconstruct, 
                      const cloverFloat *clover, const float *cloverNorm, const spinorFloat *in, 
                      const float* inNorm, const int parity, const int dagger, const spinorFloat *x, 
                      const float* xNorm, const double &a, const int volume, const int length)
{
  dim3 threads(BLOCK_DIM, 1, 1);
  dim3 blocks(volume / BLOCK_DIM, 1, 1);

  const int shared_bytes =
    threads.x * SHARED_FLOATS_PER_THREAD * bindSpinorTex<N>(length, in, inNorm, x, xNorm);
  const bool xpay = (x != 0);

  switch (reconstruct) {
  case QUDA_RECONSTRUCT_NO:
    if (xpay) {
      if (dagger) {
        cloverDslash18DaggerXpayKernel<<<blocks, threads, shared_bytes>>>
          (out, outNorm, gauge0, gauge1, clover, cloverNorm, in, inNorm, parity, x, xNorm, a);
      } else {
        cloverDslash18XpayKernel<<<blocks, threads, shared_bytes>>>
          (out, outNorm, gauge0, gauge1, clover, cloverNorm, in, inNorm, parity, x, xNorm, a);
      }
    } else if (dagger) {
      cloverDslash18DaggerKernel<<<blocks, threads, shared_bytes>>>
        (out, outNorm, gauge0, gauge1, clover, cloverNorm, in, inNorm, parity);
    } else {
      cloverDslash18Kernel<<<blocks, threads, shared_bytes>>>
        (out, outNorm, gauge0, gauge1, clover, cloverNorm, in, inNorm, parity);
    }
    break;

  case QUDA_RECONSTRUCT_12:
    if (xpay) {
      if (dagger) {
        cloverDslash12DaggerXpayKernel<<<blocks, threads, shared_bytes>>>
          (out, outNorm, gauge0, gauge1, clover, cloverNorm, in, inNorm, parity, x, xNorm, a);
      } else {
        cloverDslash12XpayKernel<<<blocks, threads, shared_bytes>>>
          (out, outNorm, gauge0, gauge1, clover, cloverNorm, in, inNorm, parity, x, xNorm, a);
      }
    } else if (dagger) {
      cloverDslash12DaggerKernel<<<blocks, threads, shared_bytes>>>
        (out, outNorm, gauge0, gauge1, clover, cloverNorm, in, inNorm, parity);
    } else {
      cloverDslash12Kernel<<<blocks, threads, shared_bytes>>>
        (out, outNorm, gauge0, gauge1, clover, cloverNorm, in, inNorm, parity);
    }
    break;

  default: // QUDA_RECONSTRUCT_8, and any other value
    if (xpay) {
      if (dagger) {
        cloverDslash8DaggerXpayKernel<<<blocks, threads, shared_bytes>>>
          (out, outNorm, gauge0, gauge1, clover, cloverNorm, in, inNorm, parity, x, xNorm, a);
      } else {
        cloverDslash8XpayKernel<<<blocks, threads, shared_bytes>>>
          (out, outNorm, gauge0, gauge1, clover, cloverNorm, in, inNorm, parity, x, xNorm, a);
      }
    } else if (dagger) {
      cloverDslash8DaggerKernel<<<blocks, threads, shared_bytes>>>
        (out, outNorm, gauge0, gauge1, clover, cloverNorm, in, inNorm, parity);
    } else {
      cloverDslash8Kernel<<<blocks, threads, shared_bytes>>>
        (out, outNorm, gauge0, gauge1, clover, cloverNorm, in, inNorm, parity);
    }
    break;
  }

  unbindSpinorTex<N>(in, inNorm, x, xNorm);
}
00288 
/**
 * Host entry point for the clover-improved Wilson dslash.
 *
 * Steps:
 * 1. Bind the clover textures (from `cloverInv`) and the gauge textures for
 *    `parity`.
 * 2. Require both the gauge and clover precisions to equal `precision`.
 * 3. Dispatch to the typed launcher.
 * 4. Unbind both textures and call checkCudaError().
 *
 * `a` is the xpay coefficient, used only when `x` is non-null. An unsupported
 * precision is reported through errorQuda rather than silently skipped. The
 * path is compiled only under GPU_WILSON_DIRAC.
 */
void cloverDslashCuda(void *out, void *outNorm, const FullGauge gauge, const FullClover cloverInv,
                      const void *in, const void *inNorm, const int parity, const int dagger, 
                      const void *x, const void *xNorm, const double a, const int volume, 
                      const int length, const QudaPrecision precision) {

#ifdef GPU_WILSON_DIRAC
  void *cloverP, *cloverNormP;
  QudaPrecision clover_prec = bindCloverTex(cloverInv, parity, &cloverP, &cloverNormP);

  void *gauge0, *gauge1;

  bindGaugeTex(gauge, parity, &gauge0, &gauge1);

  if (precision != gauge.precision)
    errorQuda("Mixing gauge and spinor precision not supported");

  if (precision != clover_prec)
    errorQuda("Mixing clover and spinor precision not supported");

  if (precision == QUDA_DOUBLE_PRECISION) {
#if (__CUDA_ARCH__ >= 130)
    cloverDslashCuda<2>((double2*)out, (float*)outNorm, (double2*)gauge0, (double2*)gauge1, 
                        gauge.reconstruct, (double2*)cloverP, (float*)cloverNormP, (double2*)in, 
                        (float*)inNorm, parity, dagger, (double2*)x, (float*)xNorm, a, volume, length);
#else
    errorQuda("Double precision not supported on this GPU");
#endif
  } else if (precision == QUDA_SINGLE_PRECISION) {
    cloverDslashCuda<4>((float4*)out, (float*)outNorm, (float4*)gauge0, (float4*)gauge1, 
                        gauge.reconstruct, (float4*)cloverP, (float*)cloverNormP, (float4*)in, 
                        (float*)inNorm, parity, dagger, (float4*)x, (float*)xNorm, a, volume, length);
  } else if (precision == QUDA_HALF_PRECISION) {
    cloverDslashCuda<4>((short4*)out, (float*)outNorm, (short4*)gauge0, (short4*)gauge1, 
                        gauge.reconstruct, (short4*)cloverP, (float*)cloverNormP, (short4*)in,
                        (float*)inNorm, parity, dagger, (short4*)x, (float*)xNorm, a, volume, length);
  } else {
    // previously an unknown precision fell through without launching anything
    errorQuda("Unsupported precision %d in function %s\n", precision, __FUNCTION__);
  }

  unbindGaugeTex(gauge);
  unbindCloverTex(cloverInv);

  checkCudaError();
#else
  errorQuda("Clover dslash has not been built");
#endif

}
00336 
00337 // Domain wall wrappers
/**
 * Launch the domain-wall dslash.
 *
 * Kernel selection uses three inputs:
 * - `reconstruct` picks the 18 / 12 / 8 kernel family;
 * - whether `x` is null picks plain versus *Xpay*;
 * - `dagger` picks plain versus *Dagger*.
 *
 * Every variant receives m_f. Only the Xpay variants also receive x, xNorm and
 * k2. Any reconstruct value other than NO or 12 is sent to the 8-parameter
 * kernels.
 *
 * Grid: volume_5d/BLOCK_DIM blocks of BLOCK_DIM threads.
 * NOTE(review): volume_5d is assumed to be a multiple of BLOCK_DIM.
 */
template <int N, typename spinorFloat, typename gaugeFloat>
void domainWallDslashCuda(spinorFloat *out, float *outNorm, const gaugeFloat gauge0, 
                          const gaugeFloat gauge1, const QudaReconstructType reconstruct, 
                          const spinorFloat *in, const float* inNorm, const int parity, const int dagger, const spinorFloat *x, 
                          const float* xNorm, const double &m_f, const double &k2, const int volume_5d, const int length)
{
  dim3 threads(BLOCK_DIM, 1, 1);
  dim3 blocks(volume_5d / BLOCK_DIM, 1, 1);

  const int shared_bytes =
    threads.x * SHARED_FLOATS_PER_THREAD * bindSpinorTex<N>(length, in, inNorm, x, xNorm);
  const bool xpay = (x != 0);

  switch (reconstruct) {
  case QUDA_RECONSTRUCT_NO:
    if (xpay) {
      if (dagger) {
        domainWallDslash18DaggerXpayKernel<<<blocks, threads, shared_bytes>>>
          (out, outNorm, gauge0, gauge1, in, inNorm, parity, m_f, x, xNorm, k2);
      } else {
        domainWallDslash18XpayKernel<<<blocks, threads, shared_bytes>>>
          (out, outNorm, gauge0, gauge1, in, inNorm, parity, m_f, x, xNorm, k2);
      }
    } else if (dagger) {
      domainWallDslash18DaggerKernel<<<blocks, threads, shared_bytes>>>
        (out, outNorm, gauge0, gauge1, in, inNorm, parity, m_f);
    } else {
      domainWallDslash18Kernel<<<blocks, threads, shared_bytes>>>
        (out, outNorm, gauge0, gauge1, in, inNorm, parity, m_f);
    }
    break;

  case QUDA_RECONSTRUCT_12:
    if (xpay) {
      if (dagger) {
        domainWallDslash12DaggerXpayKernel<<<blocks, threads, shared_bytes>>>
          (out, outNorm, gauge0, gauge1, in, inNorm, parity, m_f, x, xNorm, k2);
      } else {
        domainWallDslash12XpayKernel<<<blocks, threads, shared_bytes>>>
          (out, outNorm, gauge0, gauge1, in, inNorm, parity, m_f, x, xNorm, k2);
      }
    } else if (dagger) {
      domainWallDslash12DaggerKernel<<<blocks, threads, shared_bytes>>>
        (out, outNorm, gauge0, gauge1, in, inNorm, parity, m_f);
    } else {
      domainWallDslash12Kernel<<<blocks, threads, shared_bytes>>>
        (out, outNorm, gauge0, gauge1, in, inNorm, parity, m_f);
    }
    break;

  default: // QUDA_RECONSTRUCT_8, and any other value
    if (xpay) {
      if (dagger) {
        domainWallDslash8DaggerXpayKernel<<<blocks, threads, shared_bytes>>>
          (out, outNorm, gauge0, gauge1, in, inNorm, parity, m_f, x, xNorm, k2);
      } else {
        domainWallDslash8XpayKernel<<<blocks, threads, shared_bytes>>>
          (out, outNorm, gauge0, gauge1, in, inNorm, parity, m_f, x, xNorm, k2);
      }
    } else if (dagger) {
      domainWallDslash8DaggerKernel<<<blocks, threads, shared_bytes>>>
        (out, outNorm, gauge0, gauge1, in, inNorm, parity, m_f);
    } else {
      domainWallDslash8Kernel<<<blocks, threads, shared_bytes>>>
        (out, outNorm, gauge0, gauge1, in, inNorm, parity, m_f);
    }
    break;
  }

  unbindSpinorTex<N>(in, inNorm, x, xNorm);
}
00406 
/**
 * Host entry point for the domain-wall dslash.
 *
 * Steps:
 * 1. Bind the gauge textures for `parity`.
 * 2. Require the gauge precision to equal `precision`.
 * 3. Dispatch to the typed launcher with m_f, k2 and volume5d.
 * 4. Unbind the gauge textures and call checkCudaError().
 *
 * k2 is used only when `x` is non-null (the xpay path). An unsupported
 * precision is reported through errorQuda rather than silently skipped. The
 * path is compiled only under GPU_DOMAIN_WALL_DIRAC.
 */
void domainWallDslashCuda(void *out, void *outNorm, const FullGauge gauge, 
                          const void *in, const void *inNorm, const int parity, const int dagger, 
                          const void *x, const void *xNorm, const double m_f, const double k2, const int volume5d, 
                          const int length, const QudaPrecision precision) {

#ifdef GPU_DOMAIN_WALL_DIRAC
  void *gauge0, *gauge1;
  bindGaugeTex(gauge, parity, &gauge0, &gauge1);

  if (precision != gauge.precision)
    errorQuda("Mixing gauge and spinor precision not supported");

  if (precision == QUDA_DOUBLE_PRECISION) {
#if (__CUDA_ARCH__ >= 130)
    domainWallDslashCuda<2>((double2*)out, (float*)outNorm, (double2*)gauge0, (double2*)gauge1, 
                            gauge.reconstruct, (double2*)in, (float*)inNorm, parity, dagger, 
                            (double2*)x, (float*)xNorm, m_f, k2, volume5d, length);
#else
    errorQuda("Double precision not supported on this GPU");
#endif
  } else if (precision == QUDA_SINGLE_PRECISION) {
    domainWallDslashCuda<4>((float4*)out, (float*)outNorm, (float4*)gauge0, (float4*)gauge1, 
                            gauge.reconstruct, (float4*)in, (float*)inNorm, parity, dagger, 
                            (float4*)x, (float*)xNorm, m_f, k2, volume5d, length);
  } else if (precision == QUDA_HALF_PRECISION) {
    domainWallDslashCuda<4>((short4*)out, (float*)outNorm, (short4*)gauge0, (short4*)gauge1, 
                            gauge.reconstruct, (short4*)in, (float*)inNorm, parity, dagger, 
                            (short4*)x, (float*)xNorm, m_f, k2, volume5d, length);
  } else {
    // previously an unknown precision fell through without launching anything
    errorQuda("Unsupported precision %d in function %s\n", precision, __FUNCTION__);
  }

  unbindGaugeTex(gauge);

  checkCudaError();
#else
  errorQuda("Domain wall dslash has not been built");
#endif

}
00445 
/**
 * Launch the staggered dslash for a compressed long-link field
 * (QUDA_RECONSTRUCT_12 or QUDA_RECONSTRUCT_8).
 *
 * The uncompressed case is handled by staggeredDslashNoReconCuda. Any other
 * reconstruct value is reported once, before textures are bound, with a single
 * message that includes the offending value. Previously the two branches used
 * different messages, and the xpay message omitted the value.
 *
 * Kernel choice: whether `x` is null picks plain versus *Axpy*, and `dagger`
 * picks plain versus *Dagger*. Only the Axpy variants receive x, xNorm and `a`.
 *
 * Launch configuration:
 * - BLOCK_DIM threads per block by default;
 * - half precision switches to 128 threads per block when volume % 128 == 0;
 * - dynamic shared memory is blockDim.x * 6 * (value returned by bindSpinorTex).
 *
 * The host waits for the device before the spinor textures are unbound.
 * NOTE(review): volume is assumed to be a multiple of the chosen block size.
 */
template <int spinorN, typename spinorFloat, typename fatGaugeFloat, typename longGaugeFloat>
void staggeredDslashCuda(spinorFloat *out, float *outNorm, const fatGaugeFloat *fatGauge0, const fatGaugeFloat *fatGauge1, 
                         const longGaugeFloat* longGauge0, const longGaugeFloat* longGauge1, 
                         const QudaReconstructType reconstruct, const spinorFloat *in, const float *inNorm,
                         const int parity, const int dagger, const spinorFloat *x, const float *xNorm, 
                         const double &a, const int volume, const int length, const QudaPrecision precision) {

  if (reconstruct != QUDA_RECONSTRUCT_12 && reconstruct != QUDA_RECONSTRUCT_8) {
    errorQuda("Invalid reconstruct value(%d) in function %s\n", reconstruct, __FUNCTION__);
  }

  dim3 gridDim(volume/BLOCK_DIM, 1, 1);
  dim3 blockDim(BLOCK_DIM, 1, 1);
  if (precision == QUDA_HALF_PRECISION && (volume % 128 == 0)) {
    blockDim.x = 128;
    gridDim.x = volume/blockDim.x;
  }
  
  int shared_bytes = blockDim.x*6*bindSpinorTex<spinorN>(length, in, inNorm, x, xNorm);
  
  if (x==0) { // not doing xpay
    if (reconstruct == QUDA_RECONSTRUCT_12) {
      if (!dagger) {
        staggeredDslash12Kernel <<<gridDim, blockDim, shared_bytes>>> 
          (out, outNorm, fatGauge0, fatGauge1, longGauge0, longGauge1, in, inNorm, parity);
      } else {
        staggeredDslash12DaggerKernel <<<gridDim, blockDim, shared_bytes>>> 
          (out, outNorm, fatGauge0, fatGauge1, longGauge0, longGauge1, in, inNorm, parity);
      }
    } else { // QUDA_RECONSTRUCT_8 (validated above)
      if (!dagger) {
        staggeredDslash8Kernel <<<gridDim, blockDim, shared_bytes>>> 
          (out, outNorm, fatGauge0, fatGauge1, longGauge0, longGauge1, in, inNorm, parity);
      } else {
        staggeredDslash8DaggerKernel <<<gridDim, blockDim, shared_bytes>>> 
          (out, outNorm, fatGauge0, fatGauge1, longGauge0, longGauge1, in, inNorm, parity);
      }
    }
  } else { // doing xpay
    if (reconstruct == QUDA_RECONSTRUCT_12) {
      if (!dagger) {
        staggeredDslash12AxpyKernel <<<gridDim, blockDim, shared_bytes>>> 
          (out, outNorm, fatGauge0, fatGauge1, longGauge0, longGauge1, in, inNorm, parity, x, xNorm, a);
      } else {
        staggeredDslash12DaggerAxpyKernel <<<gridDim, blockDim, shared_bytes>>> 
          (out, outNorm, fatGauge0, fatGauge1, longGauge0, longGauge1, in, inNorm, parity, x, xNorm, a);
      }
    } else { // QUDA_RECONSTRUCT_8 (validated above)
      if (!dagger) {
        staggeredDslash8AxpyKernel <<<gridDim, blockDim, shared_bytes>>> 
          (out, outNorm, fatGauge0, fatGauge1, longGauge0, longGauge1, in, inNorm, parity, x, xNorm, a);
      } else {
        staggeredDslash8DaggerAxpyKernel <<<gridDim, blockDim, shared_bytes>>>
          (out, outNorm, fatGauge0, fatGauge1, longGauge0, longGauge1, in, inNorm, parity, x, xNorm, a);
      }
    }
  }
  
  // Wait for the device before releasing the spinor textures.
  // NOTE(review): cudaThreadSynchronize is deprecated in later CUDA releases
  // in favour of cudaDeviceSynchronize.
  cudaThreadSynchronize();
  
  unbindSpinorTex<spinorN>(in, inNorm, x, xNorm);
}
00510 
00511 
/**
 * Launch the staggered dslash for an uncompressed (18-real) long-link field.
 *
 * Kernel choice: whether `x` is null picks plain versus *Axpy*, and `dagger`
 * picks plain versus *Dagger*. Only the Axpy variants receive x, xNorm and `a`.
 * `reconstruct` is accepted for signature symmetry with staggeredDslashCuda but
 * is not used.
 *
 * Launch configuration:
 * - BLOCK_DIM threads per block by default.
 * - Half precision switches to 128 threads per block, but only when
 *   volume % 128 == 0. This matches staggeredDslashCuda. Previously the switch
 *   was unconditional, so a volume that was a multiple of 64 but not of 128
 *   truncated the grid and left the tail sites unprocessed.
 * - Dynamic shared memory is blockDim.x * 6 * (value returned by bindSpinorTex).
 *
 * The host waits for the device before the spinor textures are unbound.
 * NOTE(review): volume is assumed to be a multiple of the chosen block size.
 */
template <int spinorN, typename spinorFloat, typename fatGaugeFloat, typename longGaugeFloat>
void staggeredDslashNoReconCuda(spinorFloat *out, float *outNorm, const fatGaugeFloat *fatGauge0, const fatGaugeFloat *fatGauge1, 
                                const longGaugeFloat* longGauge0, const longGaugeFloat* longGauge1, 
                                const QudaReconstructType reconstruct, const spinorFloat *in, const float *inNorm,
                                const int parity, const int dagger, const spinorFloat *x, const float *xNorm, 
                                const double &a, const int volume, const int length, const QudaPrecision precision) 
{  
  dim3 gridDim(volume/BLOCK_DIM, 1, 1);
  dim3 blockDim(BLOCK_DIM, 1, 1);

  if (precision == QUDA_HALF_PRECISION && (volume % 128 == 0)) {
    blockDim.x = 128;
    gridDim.x = volume/blockDim.x;
  }
  int shared_bytes = blockDim.x*6*bindSpinorTex<spinorN>(length, in, inNorm, x, xNorm);
  
  if (x==0) { // not doing xpay
    if (!dagger) {
      staggeredDslash18Kernel <<<gridDim, blockDim, shared_bytes>>> 
        (out, outNorm, fatGauge0, fatGauge1, longGauge0, longGauge1, in, inNorm, parity);
    } else {
      staggeredDslash18DaggerKernel <<<gridDim, blockDim, shared_bytes>>> 
        (out, outNorm, fatGauge0, fatGauge1, longGauge0, longGauge1, in, inNorm, parity);
    }    
  } else { // doing xpay
    if (!dagger) {
      staggeredDslash18AxpyKernel <<<gridDim, blockDim, shared_bytes>>> 
        (out, outNorm, fatGauge0, fatGauge1, longGauge0, longGauge1, in, inNorm, parity, x, xNorm, a);
    } else {
      staggeredDslash18DaggerAxpyKernel <<<gridDim, blockDim, shared_bytes>>>
        (out, outNorm, fatGauge0, fatGauge1, longGauge0, longGauge1, in, inNorm, parity, x, xNorm, a);
    }          
  }
  
  // Wait for the device before releasing the spinor textures.
  cudaThreadSynchronize();

  unbindSpinorTex<spinorN>(in, inNorm, x, xNorm);
}
00551 
00552 
/**
 * Host entry point for the staggered dslash.
 *
 * Steps:
 * 1. Bind the fat-link and long-link textures for `parity`.
 * 2. Require both link precisions to equal `precision`.
 * 3. Dispatch on longGauge.reconstruct:
 *    - QUDA_RECONSTRUCT_NO goes to staggeredDslashNoReconCuda, with the long
 *      links cast to the same 2-vector type as the spinor;
 *    - anything else goes to staggeredDslashCuda, with the long links cast to
 *      the 4-vector type in single/half precision (double uses double2).
 * 4. Unbind both textures and call checkCudaError().
 *
 * `k` is forwarded as the Axpy coefficient, used only when `x` is non-null. An
 * unsupported precision is reported through errorQuda rather than silently
 * skipped. The path is compiled only under GPU_STAGGERED_DIRAC.
 */
void staggeredDslashCuda(void *out, void *outNorm, const FullGauge fatGauge, const FullGauge longGauge, 
                         const void *in, const void *inNorm, 
                         const int parity, const int dagger, const void *x, const void *xNorm, 
                         const double k, const int volume, const int length, const QudaPrecision precision) 
{

#ifdef GPU_STAGGERED_DIRAC
  void *fatGauge0, *fatGauge1;
  void* longGauge0, *longGauge1;
  bindFatGaugeTex(fatGauge, parity, &fatGauge0, &fatGauge1);
  bindLongGaugeTex(longGauge, parity, &longGauge0, &longGauge1);
    
  if (precision != fatGauge.precision || precision != longGauge.precision){
    errorQuda("Mixing gauge and spinor precision not supported");
  }
    
  if (precision == QUDA_DOUBLE_PRECISION) {
#if (__CUDA_ARCH__ >= 130)
    if (longGauge.reconstruct == QUDA_RECONSTRUCT_NO){
      staggeredDslashNoReconCuda<2>((double2*)out, (float*)outNorm, (double2*)fatGauge0, (double2*)fatGauge1,
                                    (double2*)longGauge0, (double2*)longGauge1,
                                    longGauge.reconstruct, (double2*)in, (float*)inNorm, parity, dagger, 
                                    (double2*)x, (float*)xNorm, k, volume, length, precision);
    }else{
      staggeredDslashCuda<2>((double2*)out, (float*)outNorm, (double2*)fatGauge0, (double2*)fatGauge1,
                             (double2*)longGauge0, (double2*)longGauge1,
                             longGauge.reconstruct, (double2*)in, (float*)inNorm, parity, dagger, 
                             (double2*)x, (float*)xNorm, k, volume, length, precision);
    }
#else
    errorQuda("Double precision not supported on this GPU");
#endif
  } else if (precision == QUDA_SINGLE_PRECISION) {
    if (longGauge.reconstruct == QUDA_RECONSTRUCT_NO){
      staggeredDslashNoReconCuda<2>((float2*)out, (float*)outNorm, (float2*)fatGauge0, (float2*)fatGauge1,
                                    (float2*)longGauge0, (float2*)longGauge1,
                                    longGauge.reconstruct, (float2*)in, (float*)inNorm, parity, dagger, 
                                    (float2*)x, (float*)xNorm, k, volume, length, precision);
    }else{
      staggeredDslashCuda<2>((float2*)out, (float*)outNorm, (float2*)fatGauge0, (float2*)fatGauge1,
                             (float4*)longGauge0, (float4*)longGauge1,
                             longGauge.reconstruct, (float2*)in, (float*)inNorm, parity, dagger, 
                             (float2*)x, (float*)xNorm, k, volume, length, precision);
    }
  } else if (precision == QUDA_HALF_PRECISION) {        
    if (longGauge.reconstruct == QUDA_RECONSTRUCT_NO){
      staggeredDslashNoReconCuda<2>((short2*)out, (float*)outNorm, (short2*)fatGauge0, (short2*)fatGauge1,
                                    (short2*)longGauge0, (short2*)longGauge1,
                                    longGauge.reconstruct, (short2*)in, (float*)inNorm, parity, dagger, 
                                    (short2*)x, (float*)xNorm, k, volume, length, precision);
    }else{
      staggeredDslashCuda<2>((short2*)out, (float*)outNorm, (short2*)fatGauge0, (short2*)fatGauge1,
                             (short4*)longGauge0, (short4*)longGauge1,
                             longGauge.reconstruct, (short2*)in, (float*)inNorm, parity, dagger, 
                             (short2*)x, (float*)xNorm, k, volume, length, precision);
    }
  } else {
    // previously an unknown precision fell through without launching anything
    errorQuda("Unsupported precision %d in function %s\n", precision, __FUNCTION__);
  }

  unbindLongGaugeTex(longGauge);
  unbindFatGaugeTex(fatGauge);

  checkCudaError();
#else
  errorQuda("Staggered dslash has not been built");
#endif  

}
00621 
/**
 * Compute the twisted-mass coefficients (a, b) written through the reference
 * outputs.
 *
 * - QUDA_TWIST_GAMMA5_DIRECT:  a = 2*kappa*mu,  b = 1.
 * - QUDA_TWIST_GAMMA5_INVERSE: a = -2*kappa*mu, b = 1 / (1 + a*a).
 * - Any other twist value is reported through errorQuda.
 *
 * When dagger is non-zero the sign of `a` is flipped afterwards. b depends only
 * on a*a, so it is unaffected by the flip.
 */
void setTwistParam(double &a, double &b, const double &kappa, const double &mu, 
                   const int dagger, const QudaTwistGamma5Type twist) {
  switch (twist) {
  case QUDA_TWIST_GAMMA5_DIRECT:
    a = 2.0 * kappa * mu;
    b = 1.0;
    break;
  case QUDA_TWIST_GAMMA5_INVERSE:
    a = -2.0 * kappa * mu;
    b = 1.0 / (1.0 + a*a);
    break;
  default:
    errorQuda("Twist type %d not defined\n", twist);
    break;
  }

  if (dagger != 0) {
    a *= -1.0;
  }
}
00636 
/**
 * Apply the twisted-mass gamma5 operator alone.
 *
 * (a, b) are obtained from setTwistParam for the given kappa, mu, dagger and
 * twist, then twistGamma5Kernel is launched with volume/BLOCK_DIM blocks of
 * BLOCK_DIM threads and no dynamic shared memory.
 *
 * `in` is not passed to the kernel directly. It is bound as a texture around
 * the launch, so the kernel presumably reads it through that texture — confirm
 * in tm_core.h.
 * NOTE(review): volume is assumed to be a multiple of BLOCK_DIM.
 */
template <int N, typename spinorFloat>
void twistGamma5Cuda(spinorFloat *out, float *outNorm, const spinorFloat *in, 
                     const float *inNorm, const int dagger, const double &kappa, 
                     const double &mu, const int volume, const int length, 
                     const QudaTwistGamma5Type twist)
{
  double a = 0.0;
  double b = 0.0;
  setTwistParam(a, b, kappa, mu, dagger, twist);

  dim3 threads(BLOCK_DIM, 1, 1);
  dim3 blocks(volume / BLOCK_DIM, 1, 1);

  bindSpinorTex<N>(length, in, inNorm);
  twistGamma5Kernel<<<blocks, threads, 0>>>(out, outNorm, a, b);
  unbindSpinorTex<N>(in, inNorm);
}
00653 
/**
 * Host entry point for the stand-alone twisted-mass gamma5 operator.
 *
 * Casts the untyped pointers to the storage type for `precision` (double2 /
 * float4 / short4), calls the typed launcher, then calls checkCudaError().
 * An unsupported precision is reported through errorQuda rather than silently
 * skipped. The path is compiled only under GPU_TWISTED_MASS_DIRAC.
 */
void twistGamma5Cuda(void *out, void *outNorm, const void *in, const void *inNorm,
                     const int dagger, const double kappa, const double mu, const int volume, 
                     const int length, const QudaPrecision precision, 
                     const QudaTwistGamma5Type twist) {

#ifdef GPU_TWISTED_MASS_DIRAC
  if (precision == QUDA_DOUBLE_PRECISION) {
#if (__CUDA_ARCH__ >= 130)
    twistGamma5Cuda<2>((double2*)out, (float*)outNorm, (double2*)in, 
                       (float*)inNorm, dagger, kappa, mu, volume, length, twist);
#else
    errorQuda("Double precision not supported on this GPU");
#endif
  } else if (precision == QUDA_SINGLE_PRECISION) {
    twistGamma5Cuda<4>((float4*)out, (float*)outNorm, (float4*)in, 
                       (float*)inNorm, dagger, kappa, mu, volume, length, twist);
  } else if (precision == QUDA_HALF_PRECISION) {
    twistGamma5Cuda<4>((short4*)out, (float*)outNorm, (short4*)in,
                       (float*)inNorm, dagger, kappa, mu, volume, length, twist);
  } else {
    // previously an unknown precision fell through without launching anything
    errorQuda("Unsupported precision %d in function %s\n", precision, __FUNCTION__);
  }
  checkCudaError();
#else
  errorQuda("Twisted mass dslash has not been built");
#endif // GPU_TWISTED_MASS_DIRAC

}
00680 
00681 // Twisted mass wrappers
// Typed launcher for the twisted-mass Wilson dslash.
//
// Chooses one of twelve kernels along three axes:
//   * reconstruct: QUDA_RECONSTRUCT_NO -> "18" kernels,
//                  QUDA_RECONSTRUCT_12 -> "12" kernels,
//                  anything else       -> "8" kernels;
//   * xpay:        `x` non-null selects the *Xpay* kernels, which also read x/xNorm;
//   * dagger:      non-zero selects the *Dagger* kernels.
// Twist coefficients always use the INVERSE form (see setTwistParam).  In the
// xpay case `b` is additionally scaled by `k` before launch.
// Spinor textures for `in` (and `x`) are bound around the launch.  The
// caller is responsible for binding gauge textures and for error checking.
// NOTE(review): volume/BLOCK_DIM truncates; assumes volume is a multiple of
// BLOCK_DIM -- TODO confirm.
template <int N, typename spinorFloat, typename gaugeFloat>
void twistedMassDslashCuda(spinorFloat *out, float *outNorm, const gaugeFloat gauge0, 
                           const gaugeFloat gauge1, const QudaReconstructType reconstruct, 
                           const spinorFloat *in, const float* inNorm, const int parity, 
                           const int dagger, const spinorFloat *x, const float* xNorm, 
                           const double &kappa, const double &mu, const double &k, 
                           const int volume, const int length)
{
  const dim3 threads(BLOCK_DIM, 1, 1);
  const dim3 blocks(volume / BLOCK_DIM, 1, 1);

  // Binding the spinor textures also yields the per-float size factor used
  // for the dynamic shared-memory request (it plays the same role as the
  // precision factor in dslashCudaSharedBytes).
  const int shared_bytes =
    threads.x * SHARED_FLOATS_PER_THREAD * bindSpinorTex<N>(length, in, inNorm, x, xNorm);

  double a = 0.0, b = 0.0;
  setTwistParam(a, b, kappa, mu, dagger, QUDA_TWIST_GAMMA5_INVERSE);

  const bool xpay = (x != 0);
  if (xpay) b *= k;  // fold the xpay coefficient into the twist normalization

  switch (reconstruct) {
  case QUDA_RECONSTRUCT_NO:
    if (xpay) {
      if (dagger) {
        twistedMassDslash18DaggerXpayKernel<<<blocks, threads, shared_bytes>>>
          (out, outNorm, gauge0, gauge1, in, inNorm, parity, a, b, x, xNorm);
      } else {
        twistedMassDslash18XpayKernel<<<blocks, threads, shared_bytes>>>
          (out, outNorm, gauge0, gauge1, in, inNorm, parity, a, b, x, xNorm);
      }
    } else {
      if (dagger) {
        twistedMassDslash18DaggerKernel<<<blocks, threads, shared_bytes>>>
          (out, outNorm, gauge0, gauge1, in, inNorm, parity, a, b);
      } else {
        twistedMassDslash18Kernel<<<blocks, threads, shared_bytes>>>
          (out, outNorm, gauge0, gauge1, in, inNorm, parity, a, b);
      }
    }
    break;

  case QUDA_RECONSTRUCT_12:
    if (xpay) {
      if (dagger) {
        twistedMassDslash12DaggerXpayKernel<<<blocks, threads, shared_bytes>>>
          (out, outNorm, gauge0, gauge1, in, inNorm, parity, a, b, x, xNorm);
      } else {
        twistedMassDslash12XpayKernel<<<blocks, threads, shared_bytes>>>
          (out, outNorm, gauge0, gauge1, in, inNorm, parity, a, b, x, xNorm);
      }
    } else {
      if (dagger) {
        twistedMassDslash12DaggerKernel<<<blocks, threads, shared_bytes>>>
          (out, outNorm, gauge0, gauge1, in, inNorm, parity, a, b);
      } else {
        twistedMassDslash12Kernel<<<blocks, threads, shared_bytes>>>
          (out, outNorm, gauge0, gauge1, in, inNorm, parity, a, b);
      }
    }
    break;

  default:  // every other reconstruct value uses the "8" kernels
    if (xpay) {
      if (dagger) {
        twistedMassDslash8DaggerXpayKernel<<<blocks, threads, shared_bytes>>>
          (out, outNorm, gauge0, gauge1, in, inNorm, parity, a, b, x, xNorm);
      } else {
        twistedMassDslash8XpayKernel<<<blocks, threads, shared_bytes>>>
          (out, outNorm, gauge0, gauge1, in, inNorm, parity, a, b, x, xNorm);
      }
    } else {
      if (dagger) {
        twistedMassDslash8DaggerKernel<<<blocks, threads, shared_bytes>>>
          (out, outNorm, gauge0, gauge1, in, inNorm, parity, a, b);
      } else {
        twistedMassDslash8Kernel<<<blocks, threads, shared_bytes>>>
          (out, outNorm, gauge0, gauge1, in, inNorm, parity, a, b);
      }
    }
    break;
  }

  unbindSpinorTex<N>(in, inNorm, x, xNorm);
}
00756 
// Precision-dispatching entry point for the twisted-mass dslash.
//
// Binds the gauge field for `parity` to texture, reinterprets the untyped
// spinor buffers according to `precision`, and forwards to the templated
// twistedMassDslashCuda.  `a` is passed through as that launcher's xpay
// coefficient `k`.  It is only used when `x` is non-null.
// Gauge and spinor precision must match.  That is now checked before any
// texture is bound.  An unrecognized precision is reported with errorQuda
// instead of being silently skipped.
// NOTE(review): errorQuda is presumed to abort; if it returns, execution
// continues exactly as before.
void twistedMassDslashCuda(void *out, void *outNorm, const FullGauge gauge, 
                           const void *in, const void *inNorm, const int parity, const int dagger, 
                           const void *x, const void *xNorm, const double kappa, const double mu, 
                           const double a, const int volume, const int length, 
                           const QudaPrecision precision) {

#ifdef GPU_TWISTED_MASS_DIRAC
  // Reject mixed precision up front, before touching any texture state.
  if (precision != gauge.precision)
    errorQuda("Mixing gauge and spinor precision not supported");

  void *gauge0, *gauge1;
  bindGaugeTex(gauge, parity, &gauge0, &gauge1);

  if (precision == QUDA_DOUBLE_PRECISION) {
    // NOTE(review): __CUDA_ARCH__ is not defined by nvcc during the host
    // compilation pass; confirm the build defines it for this host-side guard.
#if (__CUDA_ARCH__ >= 130)
    twistedMassDslashCuda<2>((double2*)out, (float*)outNorm, (double2*)gauge0, (double2*)gauge1, 
                             gauge.reconstruct, (double2*)in, (float*)inNorm, parity, dagger, 
                             (double2*)x, (float*)xNorm, kappa, mu, a, volume, length);
#else
    errorQuda("Double precision not supported on this GPU");
#endif
  } else if (precision == QUDA_SINGLE_PRECISION) {
    twistedMassDslashCuda<4>((float4*)out, (float*)outNorm, (float4*)gauge0, (float4*)gauge1, 
                             gauge.reconstruct, (float4*)in, (float*)inNorm, parity, dagger, 
                             (float4*)x, (float*)xNorm, kappa, mu, a, volume, length);
  } else if (precision == QUDA_HALF_PRECISION) {
    twistedMassDslashCuda<4>((short4*)out, (float*)outNorm, (short4*)gauge0, (short4*)gauge1, 
                             gauge.reconstruct, (short4*)in, (float*)inNorm, parity, dagger, 
                             (short4*)x, (float*)xNorm, kappa, mu, a, volume, length);
  } else {
    // Do not silently skip the dslash for an unknown precision.
    errorQuda("Unsupported precision %d", (int)precision);
  }

  unbindGaugeTex(gauge);

  checkCudaError();
#else
  errorQuda("Twisted mass dslash has not been built");
#endif

}
00796 
00797 #if defined(GPU_FATLINK)||defined(GPU_GAUGE_FORCE)|| defined(GPU_FERMION_FORCE)
00798 #include <force_common.h>
00799 #include "force_kernel_common.cu"
00800 #endif
00801 
00802 #ifdef GPU_FATLINK
00803 #include "llfat_quda.cu"
00804 #endif
00805 
00806 #ifdef GPU_GAUGE_FORCE
00807 #include "gauge_force_quda.cu"
00808 #endif
00809 
00810 #ifdef GPU_FERMION_FORCE
00811 #include "fermion_force_quda.cu"
00812 #endif
 All Classes Files Functions Variables Typedefs Enumerations Enumerator Friends Defines