QUDA  v1.1.0
A library for QCD on GPUs
dslash_wilson_clover_hasenbusch_twist.cu
Go to the documentation of this file.
1 #ifndef USE_LEGACY_DSLASH
2 
3 #include <gauge_field.h>
4 #include <color_spinor_field.h>
5 #include <clover_field.h>
6 #include <dslash.h>
7 #include <worker.h>
8 
9 #include <dslash_policy.cuh>
10 #include <kernels/dslash_wilson_clover_hasenbusch_twist.cuh>
11 
12 /**
13  This is the Wilson-clover linear operator with a Hasenbusch twist term
14 */
15 
16 namespace quda
17 {
18 
19  template <typename Arg> class WilsonCloverHasenbuschTwist : public Dslash<cloverHasenbusch, Arg>
20  {
21  using Dslash = Dslash<cloverHasenbusch, Arg>;
22  using Dslash::arg;
23  using Dslash::in;
24 
25  public:
26  WilsonCloverHasenbuschTwist(Arg &arg, const ColorSpinorField &out, const ColorSpinorField &in) :
27  Dslash(arg, out, in) {}
28 
29  void apply(const qudaStream_t &stream)
30  {
31  TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
32  Dslash::setParam(tp);
33  if (arg.xpay)
34  Dslash::template instantiate<packShmem, true>(tp, stream);
35  else
36  errorQuda("Wilson-clover - Hasenbusch Twist operator only defined for xpay=true");
37  }
38 
39  long long flops() const
40  {
41  int clover_flops = 504;
42  long long flops = Dslash::flops();
43  switch (arg.kernel_type) {
44  case INTERIOR_KERNEL:
45  case UBER_KERNEL:
46  case KERNEL_POLICY:
47  flops += clover_flops * in.Volume();
48 
49  // -mu * (i gamma_5 A) (A x)
50  flops += ((clover_flops + 48) * in.Volume());
51 
52  break;
53  default: break; // all clover flops are in the interior kernel
54  }
55  return flops;
56  }
57 
58  long long bytes() const
59  {
60  int clover_bytes = 72 * in.Precision() + (isFixed<typename Arg::Float>::value ? 2 * sizeof(float) : 0);
61 
62  long long bytes = Dslash::bytes();
63  switch (arg.kernel_type) {
64  case INTERIOR_KERNEL:
65  case UBER_KERNEL:
66  case KERNEL_POLICY: bytes += clover_bytes * in.Volume(); break;
67  default: break;
68  }
69 
70  return bytes;
71  }
72  };
73 
74  template <typename Float, int nColor, QudaReconstructType recon> struct WilsonCloverHasenbuschTwistApply {
75 
76  inline WilsonCloverHasenbuschTwistApply(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U,
77  const CloverField &A, double a, double b, const ColorSpinorField &x,
78  int parity, bool dagger, const int *comm_override, TimeProfile &profile)
79  {
80  constexpr int nDim = 4;
81  WilsonCloverHasenbuschTwistArg<Float, nColor, nDim, recon> arg(out, in, U, A, a, b, x, parity, dagger, comm_override);
82  WilsonCloverHasenbuschTwist<decltype(arg)> wilson(arg, out, in);
83 
84  dslash::DslashPolicyTune<decltype(wilson)> policy(
85  wilson, const_cast<cudaColorSpinorField *>(static_cast<const cudaColorSpinorField *>(&in)), in.VolumeCB(),
86  in.GhostFaceCB(), profile);
87  policy.apply(0);
88  }
89  };
90 
91  // Apply the Wilson-clover operator
92  // out(x) = M*in = (A(x) + a * \sum_mu U_{-\mu}(x)in(x+mu) + U^\dagger_mu(x-mu)in(x-mu))
93  // Uses the kappa normalization for the Wilson operator.
94  void ApplyWilsonCloverHasenbuschTwist(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U,
95  const CloverField &A, double a, double b, const ColorSpinorField &x, int parity,
96  bool dagger, const int *comm_override, TimeProfile &profile)
97  {
98 #ifdef GPU_CLOVER_HASENBUSCH_TWIST
99  instantiate<WilsonCloverHasenbuschTwistApply>(out, in, U, A, a, b, x, parity, dagger, comm_override, profile);
100 #else
101  errorQuda("Clover Hasensbuch Twist dslash has not been built");
102 #endif
103  }
104 
105 } // namespace quda
106 
107 #endif