QUDA  v1.1.0
A library for QCD on GPUs
dslash_wilson_clover.cu
Go to the documentation of this file.
1 #include <gauge_field.h>
2 #include <color_spinor_field.h>
3 #include <clover_field.h>
4 #include <dslash.h>
5 #include <worker.h>
6 
7 #include <dslash_policy.cuh>
8 #include <kernels/dslash_wilson_clover.cuh>
9 
10 /**
11  This is the Wilson-clover linear operator
12 */
13 
14 namespace quda
15 {
16 
17  template <typename Arg> class WilsonClover : public Dslash<wilsonClover, Arg>
18  {
19  using Dslash = Dslash<wilsonClover, Arg>;
20  using Dslash::arg;
21  using Dslash::in;
22 
23  public:
24  WilsonClover(Arg &arg, const ColorSpinorField &out, const ColorSpinorField &in) : Dslash(arg, out, in) {}
25 
26  void apply(const qudaStream_t &stream)
27  {
28  TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
29  Dslash::setParam(tp);
30  if (arg.xpay)
31  Dslash::template instantiate<packShmem, true>(tp, stream);
32  else
33  errorQuda("Wilson-clover operator only defined for xpay=true");
34  }
35 
36  long long flops() const
37  {
38  int clover_flops = 504;
39  long long flops = Dslash::flops();
40 
41  switch (arg.kernel_type) {
42  case INTERIOR_KERNEL:
43  case UBER_KERNEL:
44  case KERNEL_POLICY: flops += clover_flops * in.Volume(); break;
45  default: break; // all clover flops are in the interior kernel
46  }
47  return flops;
48  }
49 
50  long long bytes() const
51  {
52  int clover_bytes = 72 * in.Precision() + (isFixed<typename Arg::Float>::value ? 2 * sizeof(float) : 0);
53  long long bytes = Dslash::bytes();
54 
55  switch (arg.kernel_type) {
56  case INTERIOR_KERNEL:
57  case KERNEL_POLICY: bytes += clover_bytes * in.Volume(); break;
58  default: break;
59  }
60 
61  return bytes;
62  }
63  };
64 
65  template <typename Float, int nColor, QudaReconstructType recon> struct WilsonCloverApply {
66 
67  inline WilsonCloverApply(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, const CloverField &A,
68  double a, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile)
69  {
70  constexpr int nDim = 4;
71  WilsonCloverArg<Float, nColor, nDim, recon> arg(out, in, U, A, a, 0.0, x, parity, dagger, comm_override);
72  WilsonClover<decltype(arg)> wilson(arg, out, in);
73 
74  dslash::DslashPolicyTune<decltype(wilson)> policy(wilson,
75  const_cast<cudaColorSpinorField *>(static_cast<const cudaColorSpinorField *>(&in)), in.VolumeCB(),
76  in.GhostFaceCB(), profile);
77  policy.apply(0);
78  }
79  };
80 
81  template <typename Float, int nColor, QudaReconstructType recon> struct WilsonCloverWithTwistApply {
82 
83  inline WilsonCloverWithTwistApply(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U,
84  const CloverField &A, double a, double b, const ColorSpinorField &x, int parity,
85  bool dagger, const int *comm_override, TimeProfile &profile)
86  {
87  constexpr int nDim = 4;
88  WilsonCloverArg<Float, nColor, nDim, recon, true> arg(out, in, U, A, a, b, x, parity, dagger, comm_override);
89  WilsonClover<decltype(arg)> wilson(arg, out, in);
90 
91  dslash::DslashPolicyTune<decltype(wilson)> policy(
92  wilson, const_cast<cudaColorSpinorField *>(static_cast<const cudaColorSpinorField *>(&in)), in.VolumeCB(),
93  in.GhostFaceCB(), profile);
94  policy.apply(0);
95  }
96  };
97 
98  // Apply the Wilson-clover operator
99  // out(x) = M*in = (A(x) + a * \sum_mu U_{-\mu}(x)in(x+mu) + U^\dagger_mu(x-mu)in(x-mu))
100  // Uses the kappa normalization for the Wilson operator.
101  void ApplyWilsonClover(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, const CloverField &A,
102  double a, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile)
103  {
104 #ifdef GPU_CLOVER_DIRAC
105  instantiate<WilsonCloverApply>(out, in, U, A, a, x, parity, dagger, comm_override, profile);
106 #else
107  errorQuda("Clover dslash has not been built");
108 #endif
109  }
110 
111 } // namespace quda