QUDA v1.1.0
A library for QCD on GPUs
dslash_twisted_mass_preconditioned.cu
Go to the documentation of this file.
1 #include <gauge_field.h>
2 #include <color_spinor_field.h>
3 #include <dslash.h>
4 #include <worker.h>
5 
6 #include <dslash_policy.cuh>
7 #include <kernels/dslash_twisted_mass_preconditioned.cuh>
8 
9 /**
10  This is the preconditioned gauged twisted-mass operator
11 */
12 
namespace quda
{

  /**
     @brief Trait giving the xpay template flag to instantiate for a given
     value of Arg::asymmetric.

     The asymmetric operator is not defined with xpay (apply() rejects that
     combination at run time). For asymmetric == true this therefore returns
     false, so the asymmetric & xpay kernel is never instantiated.
  */
  template <bool asymmetric> constexpr bool xpay_() { return true; }
  template <> constexpr bool xpay_<true>() { return false; }

  /**
     @brief Trait giving the dagger template flag to instantiate on the
     non-dagger path for a given value of Arg::asymmetric.

     The asymmetric operator is only defined for dagger (apply() rejects
     asymmetric & !dagger at run time). For asymmetric == true this therefore
     returns true, so the asymmetric & !dagger kernel is never instantiated.
  */
  template <bool asymmetric> constexpr bool not_dagger_() { return false; }
  template <> constexpr bool not_dagger_<true>() { return true; }

  /**
     @brief Dslash driver for the preconditioned twisted-mass operator.

     apply() validates the runtime flags carried by the kernel arguments.
     It then dispatches to the kernel instantiation whose (dagger, xpay)
     template flags match, always with nParity = 1.

     @tparam Arg Kernel argument type. In this file it is a TwistedMassArg
     specialization exposing a compile-time Arg::asymmetric flag.
  */
  template <typename Arg> class TwistedMassPreconditioned : public Dslash<twistedMassPreconditioned, Arg>
  {
    using Dslash = Dslash<twistedMassPreconditioned, Arg>;
    using Dslash::arg;
    using Dslash::in;

  public:
    TwistedMassPreconditioned(Arg &arg, const ColorSpinorField &out, const ColorSpinorField &in) : Dslash(arg, out, in)
    {
    }

    /**
       @brief Validate the operator configuration, fetch the tuned launch
       parameters and launch the matching kernel instantiation.
       @param[in] stream Stream passed through to the kernel launch
    */
    void apply(const qudaStream_t &stream)
    {
      // These checks depend only on arg, so do them before any tuning/launch work.
      if (arg.asymmetric && !arg.dagger) errorQuda("asymmetric operator only defined for dagger");
      if (arg.asymmetric && arg.xpay) errorQuda("asymmetric operator not defined for xpay");
      if (arg.nParity != 1) errorQuda("Preconditioned twisted-mass operator not defined for nParity=%d", arg.nParity);

      TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
      Dslash::setParam(tp);

      // The traits collapse the unsupported asymmetric combinations onto
      // supported flags. Those paths are unreachable after the checks above,
      // but this keeps invalid kernels from being compiled.
      if (arg.dagger) {
        if (arg.xpay)
          Dslash::template instantiate<packShmem, 1, true, xpay_<Arg::asymmetric>()>(tp, stream);
        else
          Dslash::template instantiate<packShmem, 1, true, false>(tp, stream);
      } else {
        if (arg.xpay)
          Dslash::template instantiate<packShmem, 1, not_dagger_<Arg::asymmetric>(), xpay_<Arg::asymmetric>()>(tp, stream);
        else
          Dslash::template instantiate<packShmem, 1, not_dagger_<Arg::asymmetric>(), false>(tp, stream);
      }
    }

    /**
       @brief Floating-point operation count.

       Returns the base Dslash count, plus an extra per-site term for the
       interior, uber and policy kernel types. That term is presumably the
       cost of the twist a*(1 + i*b*gamma_5); confirm against the kernel.
    */
    long long flops() const
    {
      long long flops = Dslash::flops();
      switch (arg.kernel_type) {
      case INTERIOR_KERNEL:
      case UBER_KERNEL:
      case KERNEL_POLICY:
        // 2ll promotes the whole product to long long so large volumes cannot overflow an int
        flops += 2ll * in.Ncolor() * 4 * 2 * in.Volume(); // complex * Nc * Ns * fma * vol
        break;
      default: break;
      }
      return flops;
    }
  };

  /**
     @brief Instantiation helper for the preconditioned twisted-mass Dslash.

     The constructor maps the runtime asymmetric flag onto the compile-time
     TwistedMassArg variant and runs the Dslash policy tuner on it.
  */
  template <typename Float, int nColor, QudaReconstructType recon> struct TwistedMassPreconditionedApply {

    inline TwistedMassPreconditionedApply(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U,
                                          double a, double b, bool xpay, const ColorSpinorField &x, int parity,
                                          bool dagger, bool asymmetric, const int *comm_override, TimeProfile &profile)
    {
      if (asymmetric)
        launch<true>(out, in, U, a, b, xpay, x, parity, dagger, comm_override, profile);
      else
        launch<false>(out, in, U, a, b, xpay, x, parity, dagger, comm_override, profile);
    }

  private:
    /**
       @brief Build the kernel arguments for the given compile-time asymmetric
       flag, wrap them in the Dslash driver, and apply the Dslash policy.
    */
    template <bool is_asymmetric>
    static void launch(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a, double b,
                       bool xpay, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override,
                       TimeProfile &profile)
    {
      constexpr int nDim = 4;
      TwistedMassArg<Float, nColor, nDim, recon, is_asymmetric> arg(out, in, U, a, b, xpay, x, parity, dagger,
                                                                    comm_override);
      TwistedMassPreconditioned<decltype(arg)> twisted(arg, out, in);

      // DslashPolicyTune takes a mutable cudaColorSpinorField pointer.
      // NOTE(review): this assumes `in` is actually a cudaColorSpinorField.
      dslash::DslashPolicyTune<decltype(twisted)> policy(
        twisted, const_cast<cudaColorSpinorField *>(static_cast<const cudaColorSpinorField *>(&in)), in.VolumeCB(),
        in.GhostFaceCB(), profile);
      policy.apply(0);
    }
  };

  /**
     @brief Apply the preconditioned twisted-mass Dslash operator

       out = x + A^{-1} D * in
           = x + a*(1 + i*b*gamma_5) * \sum_mu [ U_{-\mu}(x) in(x+mu) + U^\dagger_mu(x-mu) in(x-mu) ]

     The x term is presumably added only when xpay is set; confirm against the kernel.

     @param[out] out Output field
     @param[in] in Input field
     @param[in] U Gauge field
     @param[in] a Scale factor of the twist
     @param[in] b Twist parameter multiplying i*gamma_5
     @param[in] xpay Whether to use the xpay variant (not allowed together with asymmetric)
     @param[in] x Field used by the xpay variant
     @param[in] parity Parity of the output field
     @param[in] dagger Whether to apply the dagger operator
     @param[in] asymmetric Whether to use the asymmetric variant (requires dagger and !xpay)
     @param[in] comm_override Per-dimension communication override passed to the kernel arguments
     @param[in] profile Time profile passed to the Dslash policy tuner

     Raises an error if the library was built without GPU_TWISTED_MASS_DIRAC.
  */
  void ApplyTwistedMassPreconditioned(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a,
                                      double b, bool xpay, const ColorSpinorField &x, int parity, bool dagger,
                                      bool asymmetric, const int *comm_override, TimeProfile &profile)
  {
#ifdef GPU_TWISTED_MASS_DIRAC
    // with symmetric dagger operator we must use kernel packing
    const bool force_kernel_pack = dagger && !asymmetric;
    if (force_kernel_pack) pushKernelPackT(true);

    instantiate<TwistedMassPreconditionedApply>(out, in, U, a, b, xpay, x, parity, dagger, asymmetric, comm_override,
                                                profile);

    // restore the previous kernel-packing state
    if (force_kernel_pack) popKernelPackT();
#else
    errorQuda("Twisted-mass dslash has not been built");
#endif // GPU_TWISTED_MASS_DIRAC
  }

} // namespace quda