QUDA  v1.1.0
A library for QCD on GPUs
dslash_ndeg_twisted_mass_preconditioned.cu
Go to the documentation of this file.
1 #include <gauge_field.h>
2 #include <color_spinor_field.h>
3 #include <dslash.h>
4 #include <worker.h>
5 
6 #include <dslash_policy.cuh>
7 #include <kernels/dslash_ndeg_twisted_mass_preconditioned.cuh>
8 
9 /**
10  This is the preconditioned twisted-mass operator acting on a non-degenerate
11  quark doublet.
12 */
13 
14 namespace quda
15 {
16 
// Compile-time traits used to pick kernel template arguments in
// NdegTwistedMassPreconditioned::apply.  Both are called with Arg::asymmetric.

// Trait to ensure we don't instantiate asymmetric & xpay: yields the xpay
// template argument used on the xpay branch -- false for the asymmetric
// operator (so no asymmetric+xpay kernel is generated), true otherwise.
template <bool asymmetric> constexpr bool xpay_() { return !asymmetric; }

// Trait to ensure we don't instantiate asymmetric & !dagger: yields the dagger
// template argument used on the !dagger branch -- true for the asymmetric
// operator (so no asymmetric+!dagger kernel is generated), false otherwise.
template <bool asymmetric> constexpr bool not_dagger_() { return asymmetric; }
24 
25  template <typename Arg> class NdegTwistedMassPreconditioned : public Dslash<nDegTwistedMassPreconditioned, Arg>
26  {
27  using Dslash = Dslash<nDegTwistedMassPreconditioned, Arg>;
28  using Dslash::arg;
29  using Dslash::in;
30 
31  protected:
32  bool shared;
33  unsigned int sharedBytesPerThread() const
34  {
35  return shared ? 2 * in.Ncolor() * 4 * sizeof(typename mapper<typename Arg::Float>::type) : 0;
36  }
37 
38  public:
39  NdegTwistedMassPreconditioned(Arg &arg, const ColorSpinorField &out, const ColorSpinorField &in) :
40  Dslash(arg, out, in),
41  shared(arg.asymmetric || !arg.dagger)
42  {
43  TunableVectorYZ::resizeVector(2, arg.nParity);
44  if (shared) TunableVectorY::resizeStep(2); // this will force flavor to be contained in the block
45  }
46 
47  void apply(const qudaStream_t &stream)
48  {
49  TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
50  Dslash::setParam(tp);
51  if (arg.asymmetric && !arg.dagger) errorQuda("asymmetric operator only defined for dagger");
52  if (arg.asymmetric && arg.xpay) errorQuda("asymmetric operator not defined for xpay");
53  if (arg.nParity != 1) errorQuda("Preconditioned non-degenerate twisted-mass operator not defined nParity=%d", arg.nParity);
54 
55  if (arg.dagger) {
56  if (arg.xpay)
57  Dslash::template instantiate<packShmem, 1, true, xpay_<Arg::asymmetric>()>(tp, stream);
58  else
59  Dslash::template instantiate<packShmem, 1, true, false>(tp, stream);
60  } else {
61  if (arg.xpay)
62  Dslash::template instantiate<packShmem, 1, not_dagger_<Arg::asymmetric>(), xpay_<Arg::asymmetric>()>(tp, stream);
63  else
64  Dslash::template instantiate<packShmem, 1, not_dagger_<Arg::asymmetric>(), false>(tp, stream);
65  }
66  }
67 
68  void initTuneParam(TuneParam &param) const
69  {
70  Dslash::initTuneParam(param);
71  if (shared) param.shared_bytes = sharedBytesPerThread() * param.block.x * param.block.y * param.block.z;
72  }
73 
74  void defaultTuneParam(TuneParam &param) const
75  {
76  Dslash::defaultTuneParam(param);
77  if (shared) param.shared_bytes = sharedBytesPerThread() * param.block.x * param.block.y * param.block.z;
78  }
79 
80  long long flops() const
81  {
82  long long flops = Dslash::flops();
83  switch (arg.kernel_type) {
84  case INTERIOR_KERNEL:
85  case UBER_KERNEL:
86  case KERNEL_POLICY:
87  flops += 2 * in.Ncolor() * 4 * 4 * in.Volume(); // complex * Nc * Ns * fma * vol
88  break;
89  default: break; // twisted-mass flops are in the interior kernel
90  }
91  return flops;
92  }
93  };
94 
95  template <typename Float, int nColor, QudaReconstructType recon> struct NdegTwistedMassPreconditionedApply {
96 
97  inline NdegTwistedMassPreconditionedApply(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U,
98  double a, double b, double c, bool xpay, const ColorSpinorField &x, int parity, bool dagger, bool asymmetric,
99  const int *comm_override, TimeProfile &profile)
100  {
101  constexpr int nDim = 4;
102  if (asymmetric) {
103  NdegTwistedMassArg<Float, nColor, nDim, recon, true> arg(out, in, U, a, b, c, xpay, x, parity, dagger, comm_override);
104  NdegTwistedMassPreconditioned<decltype(arg)> twisted(arg, out, in);
105 
106  dslash::DslashPolicyTune<decltype(twisted)> policy(twisted,
107  const_cast<cudaColorSpinorField *>(static_cast<const cudaColorSpinorField *>(&in)),
108  in.getDslashConstant().volume_4d_cb, in.getDslashConstant().ghostFaceCB, profile);
109  policy.apply(0);
110  } else {
111  NdegTwistedMassArg<Float, nColor, nDim, recon, false> arg(out, in, U, a, b, c, xpay, x, parity, dagger, comm_override);
112  NdegTwistedMassPreconditioned<decltype(arg)> twisted(arg, out, in);
113 
114  dslash::DslashPolicyTune<decltype(twisted)> policy(twisted,
115  const_cast<cudaColorSpinorField *>(static_cast<const cudaColorSpinorField *>(&in)),
116  in.getDslashConstant().volume_4d_cb, in.getDslashConstant().ghostFaceCB, profile);
117  policy.apply(0);
118  }
119  }
120  };
121 
122  // Apply the non-degenerate twisted-mass Dslash operator
123  // out(x) = M*in = a*(1 + i*b*gamma_5*tau_3 + c*tau_1)*D + x
124  // Uses the kappa normalization for the Wilson operator, with a = -kappa.
125  void ApplyNdegTwistedMassPreconditioned(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U,
126  double a, double b, double c, bool xpay, const ColorSpinorField &x, int parity, bool dagger, bool asymmetric,
127  const int *comm_override, TimeProfile &profile)
128  {
129 #ifdef GPU_NDEG_TWISTED_MASS_DIRAC
130  // with symmetric dagger operator we must use kernel packing
131  if (dagger && !asymmetric) pushKernelPackT(true);
132 
133  instantiate<NdegTwistedMassPreconditionedApply>(
134  out, in, U, a, b, c, xpay, x, parity, dagger, asymmetric, comm_override, profile);
135 
136  if (dagger && !asymmetric) popKernelPackT();
137 #else
138  errorQuda("Non-degenerate twisted-mass dslash has not been built");
139 #endif // GPU_NDEG_TWISTED_MASS_DIRAC
140  }
141 
142 } // namespace quda