QUDA  1.0.0
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
dslash_twisted_mass.cu
Go to the documentation of this file.
1 #include <gauge_field.h>
2 #include <color_spinor_field.h>
3 #include <dslash.h>
4 #include <worker.h>
5 
6 #include <dslash_policy.cuh>
8 
13 namespace quda
14 {
15 
// Launch helper for the twisted-mass kernel: its static launch() forwards the
// twistedMassGPU instantiation selected by the template parameters below to
// dslash.launch().
// NOTE(review): listing line 21, which opens this definition, is missing from
// this extract. The closing `};` and the use of the name TwistedMassLaunch in
// TwistedMass::apply suggest it is `struct TwistedMassLaunch {` -- confirm
// against the original file.
20  template <typename Float, int nDim, int nColor, int nParity, bool dagger, bool xpay, KernelType kernel_type, typename Arg>
22  static constexpr const char *kernel = "quda::twistedMassGPU"; // kernel name for jit compilation
    // Launch the kernel through the supplied dslash object.
    //   dslash : object whose launch() member performs the actual kernel launch
    //   tp     : tuning parameters, passed through unchanged
    //   arg    : kernel argument struct, passed through unchanged
    //   stream : CUDA stream, passed through unchanged
23  template <typename Dslash>
24  inline static void launch(Dslash &dslash, TuneParam &tp, Arg &arg, const cudaStream_t &stream)
25  {
    // Instantiating this helper with xpay == false is a compile-time error.
26  static_assert(xpay == true, "Twisted-mass operator only defined for xpay");
27  dslash.launch(twistedMassGPU<Float, nDim, nColor, nParity, dagger, xpay, kernel_type, Arg>, tp, arg, stream);
28  }
29  };
30 
// Dslash<Float>-derived driver class for the twisted-mass operator. It launches
// the kernel via apply() and reports a flop count and a tuning key.
// NOTE(review): listing lines 36, 39 and 51 are missing from this extract. The
// declaration of the member `in`, the constructor head and one statement inside
// apply() are therefore not visible here -- consult the original file.
31  template <typename Float, int nDim, int nColor, typename Arg> class TwistedMass : public Dslash<Float>
32  {
33 
34 protected:
    // Kernel argument struct; read for xpay / kernel_type and passed to the launch.
35  Arg &arg;
37 
38 public:
    // Constructor initialiser list (the constructor head is not visible here).
    // The base class receives arg, out, in and the path of the kernel header;
    // presumably that path is used for jit compilation -- TODO confirm.
40  Dslash<Float>(arg, out, in, "kernels/dslash_twisted_mass.cuh"),
41  arg(arg),
42  in(in)
43  {
44  }
45 
46  virtual ~TwistedMass() {}
47 
    // Launch the operator on the given stream. Launch parameters come from
    // tuneLaunch(). Only the xpay form exists: when arg.xpay is false the call
    // ends in errorQuda() instead of a launch.
48  void apply(const cudaStream_t &stream)
49  {
50  TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
52  if (arg.xpay)
    // The trailing `true` template argument is presumably the xpay flag -- TODO confirm.
53  Dslash<Float>::template instantiate<TwistedMassLaunch, nDim, nColor, true>(tp, arg, stream);
54  else
55  errorQuda("Twisted-mass operator only defined for xpay=true");
56  }
57 
    // Flop count: the base-class count plus, for the interior and policy kernel
    // types only, 2 * nColor * 4 * 2 * in.Volume() for the twist term.
58  long long flops() const
59  {
60  long long flops = Dslash<Float>::flops();
61  switch (arg.kernel_type) {
62  case EXTERIOR_KERNEL_X:
63  case EXTERIOR_KERNEL_Y:
64  case EXTERIOR_KERNEL_Z:
65  case EXTERIOR_KERNEL_T:
66  case EXTERIOR_KERNEL_ALL: break; // twisted-mass flops are in the interior kernel
67  case INTERIOR_KERNEL:
68  case KERNEL_POLICY:
69  flops += 2 * nColor * 4 * 2 * in.Volume(); // complex * Nc * Ns * fma * vol
70  break;
71  }
72  return flops;
73  }
74 
    // Tuning key built from the volume string of `in`, the dynamic type name of
    // this object, and the base-class aux entry indexed by the kernel type.
75  TuneKey tuneKey() const
76  {
77  return TuneKey(in.VolString(), typeid(*this).name(), Dslash<Float>::aux[arg.kernel_type]);
78  }
79  };
80 
// Helper whose constructor does all the work: it builds the kernel argument
// struct for the given precision (Float), colour count and gauge
// reconstruction type, then applies the operator through a policy object.
// This is the template handed to instantiate<> in ApplyTwistedMass.
// NOTE(review): listing lines 89, 91 and 96 are missing from this extract, so
// the declarations of `twisted` and `policy` and one trailing statement are not
// visible here -- consult the original file.
81  template <typename Float, int nColor, QudaReconstructType recon> struct TwistedMassApply {
82 
    // All parameters are forwarded to TwistedMassArg except `profile`, which is
    // passed to the policy object together with `in` and its checkerboard sizes.
83  inline TwistedMassApply(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a, double b,
84  const ColorSpinorField &x, int parity, bool dagger, const int *comm_override,
85  TimeProfile &profile)
86  {
87  constexpr int nDim = 4;
88  TwistedMassArg<Float, nColor, recon> arg(out, in, U, a, b, x, parity, dagger, comm_override);
90 
    // `in` is downcast to cudaColorSpinorField and its constness is cast away
    // before it is handed to the policy object.
92  twisted, const_cast<cudaColorSpinorField *>(static_cast<const cudaColorSpinorField *>(&in)), in.VolumeCB(),
93  in.GhostFaceCB(), profile);
    // NOTE(review): the argument 0 is presumably the stream -- TODO confirm.
94  policy.apply(0);
95 
97  }
98  };
99 
100  // Apply the twisted-mass Dslash operator
101  // out(x) = M*in = (1 + i*b*gamma_5)*in(x) + a*\sum_mu U_{-\mu}(x)in(x+mu) + U^\dagger_mu(x-mu)in(x-mu)
102  // Uses the kappa normalization for the Wilson operator, with a = -kappa.
     //
     // NOTE(review): in the formula above `x` denotes a lattice site, not the
     // parameter `x`; the sum over mu presumably covers both hopping terms.
     //
     // Parameters:
     //   out           result field; must not share storage with `in`
     //   in            input field; must have the same field order as `out`
     //   U             gauge field
     //   a             hopping-term scale factor (a = -kappa, see above)
     //   b             coefficient of the i*gamma_5 twist term (see the formula above)
     //   x             additional input field forwarded to the kernel arguments;
     //                 presumably the field accumulated in the xpay form -- TODO confirm
     //   parity        forwarded unchanged to the kernel arguments
     //   dagger        forwarded unchanged to the kernel arguments
     //   comm_override forwarded unchanged to the kernel arguments
     //   profile       time profile forwarded to the dslash policy
     //
     // Calls errorQuda() when `in` and `out` alias, when their field orders
     // differ, or when the library was built without GPU_TWISTED_MASS_DIRAC.
103  void ApplyTwistedMass(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a, double b,
104  const ColorSpinorField &x, int parity, bool dagger, const int *comm_override,
105  TimeProfile &profile)
106  {
107 #ifdef GPU_TWISTED_MASS_DIRAC
     // In-place application (out and in sharing one buffer) is rejected.
108  if (in.V() == out.V()) errorQuda("Aliasing pointers");
109  if (in.FieldOrder() != out.FieldOrder())
110  errorQuda("Field order mismatch in = %d, out = %d", in.FieldOrder(), out.FieldOrder());
111 
     // NOTE(review): `x` is not included in the precision/location checks below.
112  // check all precisions match
113  checkPrecision(out, in, U);
114 
115  // check all locations match
116  checkLocation(out, in, U);
117 
     // Dispatch to the TwistedMassApply specialisation selected by instantiate<>.
118  instantiate<TwistedMassApply>(out, in, U, a, b, x, parity, dagger, comm_override, profile);
119 #else
120  errorQuda("Twisted-mass dslash has not been built");
121 #endif // GPU_TWISTED_MASS_DIRAC
122  }
123 
124 } // namespace quda
void launch(T *f, const TuneParam &tp, Arg &arg, const cudaStream_t &stream)
Definition: dslash.h:101
static constexpr const char * kernel
void setParam(Arg &arg)
Definition: dslash.h:66
void apply(const cudaStream_t &stream)
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:21
#define checkPrecision(...)
#define errorQuda(...)
Definition: util_quda.h:121
TuneKey tuneKey() const
cudaStream_t * stream
const ColorSpinorField & in
const char * VolString() const
TwistedMassApply(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a, double b, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile)
void xpay(ColorSpinorField &x, double a, ColorSpinorField &y)
Definition: blas_quda.h:37
This is a helper class that is used to instantiate the correct templated kernel for the dslash...
TwistedMass(Arg &arg, const ColorSpinorField &out, const ColorSpinorField &in)
const int nColor
Definition: covdev_test.cpp:75
void apply(const cudaStream_t &stream)
cpuColorSpinorField * in
const int * GhostFaceCB() const
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
Definition: tune.cpp:643
#define checkLocation(...)
cpuColorSpinorField * out
void ApplyTwistedMass(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a, double b, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile)
Driver for applying the twisted-mass stencil.
static void launch(Dslash &dslash, TuneParam &tp, Arg &arg, const cudaStream_t &stream)
unsigned long long flops
Definition: blas_quda.cu:22
long long flops() const
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
#define checkCudaError()
Definition: util_quda.h:161
virtual long long flops() const
Definition: dslash.h:316
QudaTune getTuning()
Query whether autotuning is enabled or not. Default is enabled but can be overridden by setting QUDA_ENABLE_TUNING=0.
Definition: util_quda.cpp:52
QudaDagType dagger
Definition: test_util.cpp:1620
QudaParity parity
Definition: covdev_test.cpp:54
QudaFieldOrder FieldOrder() const