QUDA  v1.1.0
A library for QCD on GPUs
dslash_ndeg_twisted_mass.cu
Go to the documentation of this file.
1 #include <gauge_field.h>
2 #include <color_spinor_field.h>
3 #include <dslash.h>
4 #include <worker.h>
5 
6 #include <dslash_policy.cuh>
7 #include <kernels/dslash_ndeg_twisted_mass.cuh>
8 
/**
   This is the gauged twisted-mass operator acting on a non-degenerate
   quark doublet.
*/
13 
14 namespace quda
15 {
16 
/**
   @brief Dslash driver for the twisted-mass operator acting on a
   non-degenerate quark doublet.  All launch/tuning machinery comes from the
   Dslash<nDegTwistedMass, Arg> base; this class only sets the vector launch
   dimensions, enforces xpay, and adds the twist contribution to the flop count.
 */
template <typename Arg> class NdegTwistedMass : public Dslash<nDegTwistedMass, Arg>
{
  // Alias the base under a distinct name rather than shadowing the Dslash template.
  using base_t = Dslash<nDegTwistedMass, Arg>;
  using base_t::arg;
  using base_t::in;

public:
  /**
     @param[in,out] arg Kernel argument struct, handed to the base Dslash
     @param[in] out Output spinor field
     @param[in] in Input spinor field
   */
  NdegTwistedMass(Arg &arg, const ColorSpinorField &out, const ColorSpinorField &in) : base_t(arg, out, in)
  {
    // y extent fixed at 2 (presumably the two flavors of the doublet -- confirm
    // against the kernel), z extent is the number of parities being applied.
    TunableVectorYZ::resizeVector(2, arg.nParity);
  }

  /**
     @brief Tune, set launch parameters, then launch the kernel.  Only the
     xpay=true variant is instantiated; any other request is reported via
     errorQuda.
     @param[in] stream Stream to launch on
   */
  void apply(const qudaStream_t &stream)
  {
    TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
    base_t::setParam(tp);
    if (!arg.xpay) {
      errorQuda("Non-degenerate twisted-mass operator only defined for xpay=true");
    } else {
      base_t::template instantiate<packShmem, true>(tp, stream);
    }
  }

  /**
     @brief Flop count: the base dslash flops, plus the twist term for the
     kernel types that include the interior.
   */
  long long flops() const
  {
    long long total = base_t::flops();
    const auto kernel = arg.kernel_type;
    // Exterior-only kernels get nothing extra: the twist is applied in the interior kernel.
    const bool has_interior = kernel == INTERIOR_KERNEL || kernel == UBER_KERNEL || kernel == KERNEL_POLICY;
    if (has_interior) total += 2 * in.Ncolor() * 4 * 4 * in.Volume(); // complex * Nc * Ns * fma * vol
    return total;
  }
};
53 
/**
   @brief Instantiation functor for one (Float, nColor, recon) combination.
   It builds the NdegTwistedMassArg and the NdegTwistedMass operator, then
   hands the operator to the dslash policy autotuner and applies it.
 */
template <typename Float, int nColor, QudaReconstructType recon> struct NdegTwistedMassApply {

  inline NdegTwistedMassApply(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a,
                              double b, double c, const ColorSpinorField &x, int parity, bool dagger,
                              const int *comm_override, TimeProfile &profile)
  {
    constexpr int n_dim = 4; // four-dimensional lattice operator
    using arg_t = NdegTwistedMassArg<Float, nColor, n_dim, recon>;
    using op_t = NdegTwistedMass<arg_t>;

    arg_t arg(out, in, U, a, b, c, x, parity, dagger, comm_override);
    op_t op(arg, out, in);

    // The policy tuner takes a mutable cudaColorSpinorField pointer, so the
    // const input is cast (assumes `in` is actually a cudaColorSpinorField).
    auto *in_cuda = const_cast<cudaColorSpinorField *>(static_cast<const cudaColorSpinorField *>(&in));
    const auto &constants = in.getDslashConstant();

    dslash::DslashPolicyTune<op_t> policy(op, in_cuda, constants.volume_4d_cb, constants.ghostFaceCB, profile);
    policy.apply(0);
  }
};
70 
/**
   @brief Apply the non-degenerate twisted-mass dslash.  When the library is
   built without GPU_NDEG_TWISTED_MASS_DIRAC, the call reports an error instead.
   @param[out] out Output spinor field
   @param[in] in Input spinor field
   @param[in] U Gauge field
   @param[in] a,b,c Scalar coefficients forwarded unchanged to NdegTwistedMassArg
                    (their meaning is defined in the kernel header -- not visible here)
   @param[in] x Auxiliary spinor field forwarded to NdegTwistedMassArg
                (presumably the xpay accumuland -- confirm)
   @param[in] parity Parity forwarded to NdegTwistedMassArg
   @param[in] dagger Whether to apply the daggered operator
   @param[in] comm_override Per-dimension communication override, forwarded unchanged
   @param[in,out] profile Timing profile passed to the policy tuner
 */
void ApplyNdegTwistedMass(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a, double b,
                          double c, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override,
                          TimeProfile &profile)
{
#ifndef GPU_NDEG_TWISTED_MASS_DIRAC
  errorQuda("Non-degenerate twisted-mass dslash has not been built");
#else
  // Dispatch to NdegTwistedMassApply<Float, nColor, recon> for the field configuration.
  instantiate<NdegTwistedMassApply>(out, in, U, a, b, c, x, parity, dagger, comm_override, profile);
#endif // GPU_NDEG_TWISTED_MASS_DIRAC
}
81 
82 } // namespace quda