QUDA  1.0.0
dslash_twisted_clover.cu
Go to the documentation of this file.
1 #include <gauge_field.h>
2 #include <color_spinor_field.h>
3 #include <clover_field.h>
4 #include <dslash.h>
5 #include <worker.h>
6 
7 #include <dslash_policy.cuh>
9 
14 namespace quda
15 {
16 
21  template <typename Float, int nDim, int nColor, int nParity, bool dagger, bool xpay, KernelType kernel_type, typename Arg>
23  static constexpr const char *kernel = "quda::wilsonCloverGPU"; // kernel name for jit compilation
24  template <typename Dslash>
25  inline static void launch(Dslash &dslash, TuneParam &tp, Arg &arg, const cudaStream_t &stream)
26  {
    // Compile-time guard: this launcher may only be instantiated with xpay == true.
27  static_assert(xpay == true, "Twisted-clover operator only defined for xpay");
    // Hand the fully specialised wilsonCloverGPU kernel to dslash.launch(), together with
    // the tuning parameters, the argument struct and the stream received from the caller.
28  dslash.launch(wilsonCloverGPU<Float, nDim, nColor, nParity, dagger, xpay, kernel_type, Arg>, tp, arg, stream);
29  }
30  };
31 
32  template <typename Float, int nDim, int nColor, typename Arg> class TwistedClover : public Dslash<Float>
33  {
34 
35 protected:
36  Arg &arg;
38 
39 public:
41  Dslash<Float>(arg, out, in, "kernels/dslash_wilson_clover.cuh"),
42  arg(arg),
43  in(in)
44  {
45  }
46 
47  virtual ~TwistedClover() {} // virtual destructor with an empty body: nothing is released here
48 
49  void apply(const cudaStream_t &stream)
50  {
    // Ask tuneLaunch() for the launch parameters of this object, passing the current
    // tuning and verbosity settings.
51  TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
    // NOTE(review): the listing jumps from line 51 to line 53, so one line of this body
    // is not visible here - check it against the original file.
    // Only one instantiation is requested: the final template argument is pinned to true
    // (presumably xpay - confirm). Any other request is reported through errorQuda().
53  if (arg.xpay)
54  Dslash<Float>::template instantiate<TwistedCloverLaunch, nDim, nColor, true>(tp, arg, stream);
55  else
56  errorQuda("Twisted-clover operator only defined for xpay=true");
57  }
58 
59  long long flops() const
60  {
    // Returns the base-class flop count, plus a per-site clover contribution for the
    // interior and policy kernel types.
    // NOTE(review): the split of the per-site cost into 504 + 48 is not explained here;
    // presumably clover application plus twist term - confirm.
61  int clover_flops = 504 + 48;
62  long long flops = Dslash<Float>::flops();
63  switch (arg.kernel_type) {
64  case EXTERIOR_KERNEL_X:
65  case EXTERIOR_KERNEL_Y:
66  case EXTERIOR_KERNEL_Z:
67  case EXTERIOR_KERNEL_T:
68  case EXTERIOR_KERNEL_ALL: break; // all clover flops are in the interior kernel
69  case INTERIOR_KERNEL:
    // NOTE(review): clover_flops is an int; if in.Volume() also returns a 32-bit int, the
    // product is evaluated in int before being added to the long long total and could
    // overflow for very large volumes - confirm Volume()'s return type.
70  case KERNEL_POLICY: flops += clover_flops * in.Volume(); break;
71  }
72  return flops;
73  }
74 
75  long long bytes() const
76  {
    // Returns the base-class byte count, plus a per-site clover contribution for the
    // interior and policy kernel types.
    // isFixed is true when the field precision equals sizeof(short) or sizeof(char).
77  bool isFixed = (in.Precision() == sizeof(short) || in.Precision() == sizeof(char)) ? true : false;
    // Per-site clover traffic: 72 values at the field precision, plus two floats when
    // isFixed is set (presumably the scale factors of the fixed-point format - confirm).
78  int clover_bytes = 72 * in.Precision() + (isFixed ? 2 * sizeof(float) : 0);
79 
80  long long bytes = Dslash<Float>::bytes();
81  switch (arg.kernel_type) {
82  case EXTERIOR_KERNEL_X:
83  case EXTERIOR_KERNEL_Y:
84  case EXTERIOR_KERNEL_Z:
85  case EXTERIOR_KERNEL_T:
    // Exterior kernel types add nothing beyond the base-class count.
86  case EXTERIOR_KERNEL_ALL: break;
87  case INTERIOR_KERNEL:
    // NOTE(review): clover_bytes is an int; if in.Volume() also returns a 32-bit int, the
    // product could overflow before it is widened to long long - confirm Volume()'s
    // return type.
88  case KERNEL_POLICY: bytes += clover_bytes * in.Volume(); break;
89  }
90 
91  return bytes;
92  }
93 
94  TuneKey tuneKey() const
95  {
    // Build the tuning key from three parts: the input field's volume string, the
    // run-time type name of this object, and the base-class aux entry selected by the
    // current kernel type.
96  return TuneKey(in.VolString(), typeid(*this).name(), Dslash<Float>::aux[arg.kernel_type]);
97  }
98  };
99 
100  template <typename Float, int nColor, QudaReconstructType recon> struct TwistedCloverApply {
    // Struct templated on Float, nColor and recon. ApplyTwistedClover names this template
    // in its instantiate<> call.
    // NOTE(review): the first line of the constructor signature (listing line 102) is not
    // visible here; only its trailing parameters are shown below.
101 
103  const CloverField &C, double a, double b, const ColorSpinorField &x, int parity,
104  bool dagger, const int *comm_override, TimeProfile &profile)
105  {
    // The number of dimensions is fixed at 4 for this operator.
106  constexpr int nDim = 4;
    // Kernel argument struct built from the fields and coefficients; its last template
    // argument is hard-wired to true.
107  WilsonCloverArg<Float, nColor, recon, true> arg(out, in, U, C, a, b, x, parity, dagger, comm_override);
109 
    // NOTE(review): listing lines 108 and 110 are absent from this view. They presumably
    // declare `twisted` and begin the declaration of `policy`, whose remaining constructor
    // arguments are the two lines below - confirm against the original file.
111  twisted, const_cast<cudaColorSpinorField *>(static_cast<const cudaColorSpinorField *>(&in)), in.VolumeCB(),
112  in.GhostFaceCB(), profile);
    // Run the policy; the meaning of the argument 0 is not visible here (presumably a
    // stream index - confirm).
113  policy.apply(0);
114 
115  checkCudaError();
116  }
117  };
118 
119  // Apply the twisted-clover Dslash operator
120  // out(x) = M*in = (A + i*b*gamma_5)*in(x) + a*\sum_mu [U_{-\mu}(x)in(x+mu) + U^\dagger_mu(x-mu)in(x-mu)]
121  // Uses the kappa normalization for the Wilson operator, with a = -kappa.
123  double a, double b, const ColorSpinorField &x, int parity, bool dagger,
124  const int *comm_override, TimeProfile &profile)
125  {
    // The whole body is compiled only when GPU_TWISTED_CLOVER_DIRAC is defined; otherwise
    // the call reports an error through errorQuda().
126 #ifdef GPU_TWISTED_CLOVER_DIRAC
    // Input and output must not share the same data pointer.
127  if (in.V() == out.V()) errorQuda("Aliasing pointers");
    // Input and output must use the same field order.
128  if (in.FieldOrder() != out.FieldOrder())
129  errorQuda("Field order mismatch in = %d, out = %d", in.FieldOrder(), out.FieldOrder());
130 
131  // check all precisions match
132  checkPrecision(out, in, U, C);
133 
134  // check all locations match
135  checkLocation(out, in, U, C);
136 
    // Forward every argument to instantiate<>, naming TwistedCloverApply as the template
    // to be instantiated.
137  instantiate<TwistedCloverApply>(out, in, U, C, a, b, x, parity, dagger, comm_override, profile);
138 #else
139  errorQuda("Twisted-clover dslash has not been built");
140 #endif // GPU_TWISTED_CLOVER_DIRAC
141  }
142 
143 } // namespace quda
void apply(const cudaStream_t &stream)
void launch(T *f, const TuneParam &tp, Arg &arg, const cudaStream_t &stream)
Definition: dslash.h:101
void setParam(Arg &arg)
Definition: dslash.h:66
void apply(const cudaStream_t &stream)
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:21
TwistedClover(Arg &arg, const ColorSpinorField &out, const ColorSpinorField &in)
#define checkPrecision(...)
#define errorQuda(...)
Definition: util_quda.h:121
cudaStream_t * stream
void ApplyTwistedClover(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, const CloverField &C, double a, double b, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile)
Driver for applying the twisted-clover stencil.
const char * VolString() const
void xpay(ColorSpinorField &x, double a, ColorSpinorField &y)
Definition: blas_quda.h:37
static void launch(Dslash &dslash, TuneParam &tp, Arg &arg, const cudaStream_t &stream)
virtual long long bytes() const
Definition: dslash.h:364
cpuColorSpinorField * in
const int * GhostFaceCB() const
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
Definition: tune.cpp:643
#define checkLocation(...)
const ColorSpinorField & in
cpuColorSpinorField * out
static constexpr const char * kernel
unsigned long long flops
Definition: blas_quda.cu:22
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
TwistedCloverApply(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, const CloverField &C, double a, double b, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile)
This is a helper class that is used to instantiate the correct templated kernel for the dslash...
#define checkCudaError()
Definition: util_quda.h:161
virtual long long flops() const
Definition: dslash.h:316
QudaTune getTuning()
Query whether autotuning is enabled or not. Default is enabled but can be overridden by setting QUDA_...
Definition: util_quda.cpp:52
QudaPrecision Precision() const
QudaDagType dagger
Definition: test_util.cpp:1620
QudaParity parity
Definition: covdev_test.cpp:54
QudaFieldOrder FieldOrder() const
unsigned long long bytes
Definition: blas_quda.cu:23