QUDA  1.0.0
dslash_wilson.cu
Go to the documentation of this file.
1 #include <gauge_field.h>
2 #include <color_spinor_field.h>
3 #include <dslash.h>
4 #include <worker.h>
5 
6 #include <dslash_policy.cuh>
8 
18 namespace quda
19 {
20 
/**
   @brief Launch helper for the Wilson dslash.  It names the wilsonGPU
   kernel specialisation selected by the template parameters and hands
   it to the launch() method of the calling dslash object.

   @tparam Float Floating-point type forwarded to the kernel
   @tparam nDim Dimension parameter forwarded to the kernel
   @tparam nColor Colour parameter forwarded to the kernel
   @tparam nParity Parity parameter forwarded to the kernel
   @tparam dagger Forwarded to the kernel
   @tparam xpay Forwarded to the kernel
   @tparam kernel_type Forwarded to the kernel
   @tparam Arg Type of the kernel argument struct
*/
template <typename Float, int nDim, int nColor, int nParity, bool dagger, bool xpay, KernelType kernel_type, typename Arg>
struct WilsonLaunch {
  // kernel name for jit compilation
  static constexpr const char *kernel = "quda::wilsonGPU";

  /**
     @brief Launch the selected kernel instantiation.
     @param[in,out] dslash Object whose launch() method is invoked
     @param[in] tp Launch parameters passed through to dslash.launch
     @param[in,out] arg Kernel argument struct passed through to dslash.launch
     @param[in] stream Stream passed through to dslash.launch
  */
  template <typename Launcher>
  inline static void launch(Launcher &dslash, TuneParam &tp, Arg &arg, const cudaStream_t &stream)
  {
    // Bind the fully specialised kernel first so the call below stays readable.
    auto *wilson_kernel = wilsonGPU<Float, nDim, nColor, nParity, dagger, xpay, kernel_type, Arg>;
    dslash.launch(wilson_kernel, tp, arg, stream);
  }
};
34 
/**
   @brief Host-side driver for the Wilson dslash.  Derives from
   Dslash<Float>, obtains launch parameters from the autotuner and
   dispatches to the kernel instantiation through WilsonLaunch.

   @tparam Float Floating-point type
   @tparam nDim Dimension parameter forwarded to the kernel instantiation
   @tparam nColor Colour parameter forwarded to the kernel instantiation
   @tparam Arg Type of the kernel argument struct
*/
template <typename Float, int nDim, int nColor, typename Arg> class Wilson : public Dslash<Float>
{

protected:
  Arg &arg;                   // kernel argument struct, held by reference
  const ColorSpinorField &in; // input field, used to build the tune key

public:
  /**
     @param[in,out] arg Kernel argument struct (stored by reference)
     @param[in] out Output field, forwarded to the Dslash base class
     @param[in] in Input field, forwarded to the base class and stored by reference
  */
  Wilson(Arg &arg, const ColorSpinorField &out, const ColorSpinorField &in) :
    Dslash<Float>(arg, out, in, "kernels/dslash_wilson.cuh"),
    arg(arg),
    in(in)
  {
  }

  virtual ~Wilson() {}

  /**
     @brief Tune (or look up) the launch parameters and launch the kernel.
     @param[in] stream Stream on which the kernel is launched
  */
  void apply(const cudaStream_t &stream)
  {
    TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
    // Let the base class update arg before it is handed to the kernel
    // (see Dslash<Float>::setParam in dslash.h for what is set).
    Dslash<Float>::setParam(arg);
    Dslash<Float>::template instantiate<WilsonLaunch, nDim, nColor>(tp, arg, stream);
  }

  /**
     @brief Key used by the autotuner: volume string of the input field,
     the dynamic type name of this object, and the base-class aux string
     selected by arg.kernel_type.
  */
  TuneKey tuneKey() const
  {
    return TuneKey(in.VolString(), typeid(*this).name(), Dslash<Float>::aux[arg.kernel_type]);
  }
};
64 
/**
   @brief Functor whose constructor builds the kernel argument struct
   and the Wilson driver for a given precision / colour / reconstruct
   combination, then runs it through the dslash policy.

   @tparam Float Floating-point type
   @tparam nColor Colour parameter
   @tparam recon Gauge-field reconstruction type
*/
template <typename Float, int nColor, QudaReconstructType recon> struct WilsonApply {

  /**
     @param[out] out Result field
     @param[in] in Input field
     @param[in] U Gauge field
     @param[in] a Scale factor passed to WilsonArg
     @param[in] x Additional field passed to WilsonArg
     @param[in] parity Parity passed to WilsonArg
     @param[in] dagger Dagger flag passed to WilsonArg
     @param[in] comm_override Communication override array passed to WilsonArg
     @param[in] profile Profile object passed to the dslash policy
  */
  inline WilsonApply(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a,
                     const ColorSpinorField &x, int parity, bool dagger, const int *comm_override,
                     TimeProfile &profile)
  {
    constexpr int nDim = 4;
    WilsonArg<Float, nColor, recon> arg(out, in, U, a, x, parity, dagger, comm_override);
    Wilson<Float, nDim, nColor, WilsonArg<Float, nColor, recon>> wilson(arg, out, in);

    // The policy takes a non-const cudaColorSpinorField pointer, so the
    // const input reference is down-cast and its constness removed here.
    dslash::DslashPolicyTune<decltype(wilson)> policy(
      wilson, const_cast<cudaColorSpinorField *>(static_cast<const cudaColorSpinorField *>(&in)), in.VolumeCB(),
      in.GhostFaceCB(), profile);
    policy.apply(0);

    // Report any CUDA error raised while the policy ran.
    checkCudaError();
  }
};
82 
// Apply the Wilson operator
// out(x) = M*in = - a*\sum_mu U_{-\mu}(x)in(x+mu) + U^\dagger_mu(x-mu)in(x-mu)
// The hopping term is scaled by the coefficient a.
//
// Raises an error via errorQuda if the library was built without
// GPU_WILSON_DIRAC, if in and out share the same data pointer, or if
// their field orders differ.  Precision and location consistency of
// out, in and U are verified by checkPrecision / checkLocation.
void ApplyWilson(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a,
                 const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile)
{
#ifndef GPU_WILSON_DIRAC
  errorQuda("Wilson dslash has not been built");
#else
  // the operator cannot be applied in place
  if (out.V() == in.V()) { errorQuda("Aliasing pointers"); }

  // both fields must share the same field order
  if (out.FieldOrder() != in.FieldOrder()) {
    errorQuda("Field order mismatch in = %d, out = %d", in.FieldOrder(), out.FieldOrder());
  }

  // check all precisions match
  checkPrecision(out, in, U);

  // check all locations match
  checkLocation(out, in, U);

  // dispatch to the WilsonApply instantiation matching the field parameters
  instantiate<WilsonApply, WilsonReconstruct>(out, in, U, a, x, parity, dagger, comm_override, profile);
#endif // GPU_WILSON_DIRAC
}
105 
106 } // namespace quda
void launch(T *f, const TuneParam &tp, Arg &arg, const cudaStream_t &stream)
Definition: dslash.h:101
void setParam(Arg &arg)
Definition: dslash.h:66
void apply(const cudaStream_t &stream)
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:21
#define checkPrecision(...)
#define errorQuda(...)
Definition: util_quda.h:121
cudaStream_t * stream
static void launch(Dslash &dslash, TuneParam &tp, Arg &arg, const cudaStream_t &stream)
const char * VolString() const
void apply(const cudaStream_t &stream)
Parameter structure for driving the Wilson operator.
WilsonApply(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile)
__device__ __host__ void wilson(Arg &arg, int idx, int s, int parity)
const ColorSpinorField & in
TuneKey tuneKey() const
cpuColorSpinorField * in
const int * GhostFaceCB() const
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
Definition: tune.cpp:643
static constexpr const char * kernel
#define checkLocation(...)
Wilson(Arg &arg, const ColorSpinorField &out, const ColorSpinorField &in)
cpuColorSpinorField * out
virtual ~Wilson()
void ApplyWilson(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double kappa, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile)
Driver for applying the Wilson stencil.
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
#define checkCudaError()
Definition: util_quda.h:161
QudaTune getTuning()
Query whether autotuning is enabled or not. Default is enabled but can be overridden by setting QUDA_...
Definition: util_quda.cpp:52
QudaDagType dagger
Definition: test_util.cpp:1620
This is a helper class that is used to instantiate the correct templated kernel for the dslash...
QudaParity parity
Definition: covdev_test.cpp:54
QudaFieldOrder FieldOrder() const