QUDA  1.0.0
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
dslash_ndeg_twisted_mass.cu
Go to the documentation of this file.
1 #include <gauge_field.h>
2 #include <color_spinor_field.h>
3 #include <dslash.h>
4 #include <worker.h>
5 
6 #include <dslash_policy.cuh>
8 
14 namespace quda
15 {
16 
// Launch helper handed to Dslash<Float>::instantiate by NdegTwistedMass::apply:
// it names the kernel for JIT compilation and forwards the fully-specialised
// ndegTwistedMassGPU kernel to dslash.launch().
// NOTE(review): listing line 22 (the struct header that opens this scope) is
// absent from this extract; the name NdegTwistedMassLaunch is taken from its
// use in NdegTwistedMass::apply -- confirm against the original file.
21  template <typename Float, int nDim, int nColor, int nParity, bool dagger, bool xpay, KernelType kernel_type, typename Arg>
23  static constexpr const char *kernel = "quda::ndegTwistedMassGPU"; // kernel name for jit compilation
24  template <typename Dslash>
25  inline static void launch(Dslash &dslash, TuneParam &tp, Arg &arg, const cudaStream_t &stream)
26  {
// Compile-time guard: only the xpay == true specialisation may be instantiated.
// Message corrected from "Non-generate" to match the other diagnostics in this file.
27  static_assert(xpay == true, "Non-degenerate twisted-mass operator only defined for xpay");
28  dslash.launch(ndegTwistedMassGPU<Float, nDim, nColor, nParity, dagger, xpay, kernel_type, Arg>, tp, arg, stream);
29  }
30  };
31 
// Host-side driver class for the non-degenerate twisted-mass dslash. It derives
// from Dslash<Float>, launches the kernel through the autotuner in apply(), and
// reports its flop count and tuning key.
32  template <typename Float, int nDim, int nColor, typename Arg> class NdegTwistedMass : public Dslash<Float>
33  {
34 
35 protected:
36  Arg &arg;
// NOTE(review): listing line 37 is absent from this extract; the constructor
// below initialises a member `in`, so its declaration presumably sat here -- confirm.
38 
39 public:
// NOTE(review): listing line 40 (the constructor signature) is absent from this
// extract; only the member-initialiser list and body are visible.
41  Dslash<Float>(arg, out, in, "kernels/dslash_ndeg_twisted_mass.cuh"),
42  arg(arg),
43  in(in)
44  {
46  }
47 
48  virtual ~NdegTwistedMass() {}
49 
// Fetch launch parameters from the autotuner, then instantiate and launch the
// kernel. Only arg.xpay == true is supported; any other value is reported
// through errorQuda.
50  void apply(const cudaStream_t &stream)
51  {
52  TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
54  if (arg.xpay)
55  Dslash<Float>::template instantiate<NdegTwistedMassLaunch, nDim, nColor, true>(tp, arg, stream);
56  else
57  errorQuda("Non-degenerate twisted-mass operator only defined for xpay=true");
58  }
59 
// Flop count: the base-class Dslash<Float>::flops() plus, for the interior and
// policy kernel types only, the extra twisted-mass term.
60  long long flops() const
61  {
62  long long flops = Dslash<Float>::flops();
63  switch (arg.kernel_type) {
64  case EXTERIOR_KERNEL_X:
65  case EXTERIOR_KERNEL_Y:
66  case EXTERIOR_KERNEL_Z:
67  case EXTERIOR_KERNEL_T:
68  case EXTERIOR_KERNEL_ALL: break; // twisted-mass flops are in the interior kernel
69  case INTERIOR_KERNEL:
70  case KERNEL_POLICY:
// The leading factor is a long long literal so the whole product is evaluated
// in 64-bit arithmetic; with a plain int literal the multiplication could
// overflow before being added to `flops` if Volume() returns a 32-bit int.
71  flops += 2LL * nColor * 4 * 4 * in.Volume(); // complex * Nc * Ns * fma * vol
72  break;
73  }
74  return flops;
75  }
76 
// Tuning key built from the input field's volume string, this class's
// typeid name, and the aux string selected by the current kernel type.
77  TuneKey tuneKey() const
78  {
79  return TuneKey(in.VolString(), typeid(*this).name(), Dslash<Float>::aux[arg.kernel_type]);
80  }
81  };
82 
// Functor-style struct passed to instantiate<> by ApplyNdegTwistedMass: its
// constructor builds the kernel argument object and runs the dslash through a
// policy object.
// NOTE(review): listing lines 85, 91, 93, 95 and 98 are absent from this
// extract (the constructor head and several statements), so only the visible
// fragments are annotated below -- confirm against the original file.
83  template <typename Float, int nColor, QudaReconstructType recon> struct NdegTwistedMassApply {
84 
86  double b, double c, const ColorSpinorField &x, int parity, bool dagger,
87  const int *comm_override, TimeProfile &profile)
88  {
// The stencil is instantiated for four dimensions.
89  constexpr int nDim = 4;
// Bundle the fields, the coefficients a, b, c and the parity/dagger/comm
// settings into the argument object consumed by the kernel.
90  NdegTwistedMassArg<Float, nColor, recon> arg(out, in, U, a, b, c, x, parity, dagger, comm_override);
92 
// `in` is cast to a cudaColorSpinorField and its constness removed before it is
// handed on; the call this argument list belongs to is not visible here.
94  twisted, const_cast<cudaColorSpinorField *>(static_cast<const cudaColorSpinorField *>(&in)),
// NOTE(review): the literal 0 is presumably the stream argument -- confirm.
96  policy.apply(0);
97 
99  }
100  };
101 
// Driver for applying the non-degenerate twisted-mass stencil.
// Validates the fields, then dispatches to NdegTwistedMassApply through
// instantiate<>. The meaning of the coefficients a, b, c and of the field x is
// not visible in this file; they are forwarded unchanged.
// When GPU_NDEG_TWISTED_MASS_DIRAC is not defined the function only reports an
// error through errorQuda.
102  void ApplyNdegTwistedMass(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a, double b,
103  double c, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override,
104  TimeProfile &profile)
105  {
106 #ifdef GPU_NDEG_TWISTED_MASS_DIRAC
// Reject in-place application: input and output must not share a data pointer.
107  if (in.V() == out.V()) errorQuda("Aliasing pointers");
// Input and output must use the same field order.
108  if (in.FieldOrder() != out.FieldOrder())
109  errorQuda("Field order mismatch in = %d, out = %d", in.FieldOrder(), out.FieldOrder());
110 
111  // check all precisions match
112  checkPrecision(out, in, x, U);
113 
114  // check all locations match
115  checkLocation(out, in, x, U);
116 
// NOTE(review): instantiate<> presumably selects the Float / nColor / recon
// template arguments of NdegTwistedMassApply from the fields -- confirm.
117  instantiate<NdegTwistedMassApply>(out, in, U, a, b, c, x, parity, dagger, comm_override, profile);
118 #else
119  errorQuda("Non-degenerate twisted-mass dslash has not been built");
120 #endif // GPU_NDEG_TWISTED_MASS_DIRAC
121  }
122 
123 } // namespace quda
void launch(T *f, const TuneParam &tp, Arg &arg, const cudaStream_t &stream)
Definition: dslash.h:101
void setParam(Arg &arg)
Definition: dslash.h:66
void apply(const cudaStream_t &stream)
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:21
#define checkPrecision(...)
#define errorQuda(...)
Definition: util_quda.h:121
cudaStream_t * stream
const char * VolString() const
static void launch(Dslash &dslash, TuneParam &tp, Arg &arg, const cudaStream_t &stream)
static constexpr const char * kernel
void xpay(ColorSpinorField &x, double a, ColorSpinorField &y)
Definition: blas_quda.h:37
void ApplyNdegTwistedMass(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a, double b, double c, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile)
Driver for applying the non-degenerate twisted-mass stencil.
const int nColor
Definition: covdev_test.cpp:75
cpuColorSpinorField * in
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
Definition: tune.cpp:643
#define checkLocation(...)
int ghostFaceCB[QUDA_MAX_DIM+1]
const DslashConstant & getDslashConstant() const
Get the dslash_constant structure from this field.
cpuColorSpinorField * out
const int nParity
Definition: spinor_noise.cu:25
void apply(const cudaStream_t &stream)
void resizeVector(int y, int z) const
Definition: tune_quda.h:538
This is a helper class that is used to instantiate the correct templated kernel for the dslash...
unsigned long long flops
Definition: blas_quda.cu:22
const ColorSpinorField & in
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
#define checkCudaError()
Definition: util_quda.h:161
virtual long long flops() const
Definition: dslash.h:316
NdegTwistedMassApply(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a, double b, double c, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile)
QudaTune getTuning()
Query whether autotuning is enabled or not. Default is enabled but can be overridden by setting QUDA_...
Definition: util_quda.cpp:52
QudaDagType dagger
Definition: test_util.cpp:1620
NdegTwistedMass(Arg &arg, const ColorSpinorField &out, const ColorSpinorField &in)
QudaParity parity
Definition: covdev_test.cpp:54
QudaFieldOrder FieldOrder() const