QUDA  1.0.0
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
dslash_wilson_clover_preconditioned.cu
Go to the documentation of this file.
1 #include <gauge_field.h>
2 #include <color_spinor_field.h>
3 #include <clover_field.h>
4 #include <dslash.h>
5 #include <worker.h>
6 
7 #include <dslash_policy.cuh>
9 
14 namespace quda
15 {
16 
/**
   @brief Helper used to instantiate and launch the correct templated
   kernel for the preconditioned Wilson-clover dslash.

   The template parameters select the kernel instantiation; launch()
   hands the resulting kernel function to dslash.launch() together with
   the tuning parameters, the argument struct and the stream.
*/
template <typename Float, int nDim, int nColor, int nParity, bool dagger, bool xpay, KernelType kernel_type, typename Arg>
struct WilsonCloverPreconditionedLaunch {
  // NOTE(review): the "struct WilsonCloverPreconditionedLaunch {" line was absent from the
  // listing; the name is taken from its use in WilsonCloverPreconditioned::apply().

  static constexpr const char *kernel = "quda::wilsonCloverPreconditionedGPU"; // kernel name for jit compilation

  /**
     @brief Launch the kernel instantiation selected by the template parameters.
     @param[in,out] dslash Object whose launch() member performs the actual kernel launch
     @param[in] tp Tuning parameters forwarded to dslash.launch()
     @param[in] arg Kernel argument struct forwarded to dslash.launch()
     @param[in] stream Stream forwarded to dslash.launch()
  */
  template <typename Dslash>
  inline static void launch(Dslash &dslash, TuneParam &tp, Arg &arg, const cudaStream_t &stream)
  {
    static_assert(nParity == 1, "preconditioned wilson-clover operator only defined for nParity=1");

    // xpay and dagger are mutually exclusive for this operator
    if (xpay && dagger) errorQuda("xpay operator only defined for not dagger");

    // The masked flags guarantee the (xpay && dagger) kernel is never instantiated,
    // even though the runtime check above already rejects that combination.
    constexpr bool kernel_dagger = dagger && !xpay;
    constexpr bool kernel_xpay = xpay && !dagger;

    dslash.launch(
      wilsonCloverPreconditionedGPU<Float, nDim, nColor, nParity, kernel_dagger, kernel_xpay, kernel_type, Arg>, tp,
      arg, stream);
  }
};
33 
/**
   @brief Host-side driver for the preconditioned Wilson-clover dslash.

   Derives from Dslash<Float> and adds the clover contribution to the
   flops() and bytes() estimates reported by the base class.
*/
template <typename Float, int nDim, int nColor, typename Arg> class WilsonCloverPreconditioned : public Dslash<Float>
{

protected:
  Arg &arg;
  // NOTE(review): this member declaration was absent from the listing; it is initialised in
  // the constructor and queried in flops()/bytes()/tuneKey(). Type taken from the
  // constructor parameter - confirm against upstream.
  const ColorSpinorField &in;

public:
  /**
     @brief Constructor
     @param[in] arg Kernel argument struct (held by reference)
     @param[in] out Output field, forwarded to the Dslash<Float> base class
     @param[in] in Input field (held by reference)
  */
  WilsonCloverPreconditioned(Arg &arg, const ColorSpinorField &out, const ColorSpinorField &in) :
    Dslash<Float>(arg, out, in, "kernels/dslash_wilson_clover_preconditioned.cuh"),
    arg(arg),
    in(in)
  {
  }

  /**
     @brief Tune and launch the kernel on the given stream.
     @param[in] stream Stream forwarded to the kernel launch
  */
  void apply(const cudaStream_t &stream)
  {
    TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
    // NOTE(review): this call was absent from the listing (gap between the two surrounding
    // lines) but setParam is cross-referenced by the page - confirm against upstream.
    Dslash<Float>::setParam(arg);
    if (arg.nParity == 1) {
      if (arg.xpay)
        Dslash<Float>::template instantiate<WilsonCloverPreconditionedLaunch, nDim, nColor, 1, true>(tp, arg, stream);
      else
        Dslash<Float>::template instantiate<WilsonCloverPreconditionedLaunch, nDim, nColor, 1, false>(tp, arg, stream);
    } else {
      errorQuda("Preconditioned Wilson-clover operator not defined nParity=%d", arg.nParity);
    }
  }

  /**
     @brief Flop estimate: the base-class count plus the clover term for
     the sites handled by the current kernel type.
     @return Number of floating-point operations
  */
  long long flops() const
  {
    // per-site clover cost; kept as long long so the products with site counts are not
    // evaluated in 32-bit int arithmetic
    const long long clover_flops = 504;
    long long flops = Dslash<Float>::flops();
    switch (arg.kernel_type) {
    case EXTERIOR_KERNEL_X:
    case EXTERIOR_KERNEL_Y:
    case EXTERIOR_KERNEL_Z:
    case EXTERIOR_KERNEL_T: flops += clover_flops * 2 * in.GhostFace()[arg.kernel_type]; break;
    case EXTERIOR_KERNEL_ALL: // label restored: without it the statement below is unreachable (cf. bytes())
      flops += clover_flops * 2 * (in.GhostFace()[0] + in.GhostFace()[1] + in.GhostFace()[2] + in.GhostFace()[3]);
      break;
    case INTERIOR_KERNEL:
    case KERNEL_POLICY: {
      flops += clover_flops * in.Volume();

      if (arg.kernel_type == KERNEL_POLICY) break;
      // now correct for flops done by exterior kernel
      long long ghost_sites = 0;
      for (int d = 0; d < 4; d++)
        if (arg.commDim[d]) ghost_sites += 2 * in.GhostFace()[d];
      flops -= clover_flops * ghost_sites;

      break;
    }
    }
    return flops;
  }

  /**
     @brief Memory-traffic estimate: the base-class count plus the clover
     term for the sites handled by the current kernel type.
     @return Number of bytes
  */
  long long bytes() const
  {
    // fixed-point storage is detected by the precision matching sizeof(short) or sizeof(char)
    const bool isFixed = (in.Precision() == sizeof(short) || in.Precision() == sizeof(char));
    // 72 values per site at the field precision, plus two floats when fixed-point
    // (presumably 72 reals of clover data and the fixed-point norms - TODO confirm);
    // long long avoids 32-bit products with the site counts below
    const long long clover_bytes = 72 * in.Precision() + (isFixed ? 2 * sizeof(float) : 0);

    long long bytes = Dslash<Float>::bytes();
    switch (arg.kernel_type) {
    case EXTERIOR_KERNEL_X:
    case EXTERIOR_KERNEL_Y:
    case EXTERIOR_KERNEL_Z:
    case EXTERIOR_KERNEL_T: bytes += clover_bytes * 2 * in.GhostFace()[arg.kernel_type]; break;
    case EXTERIOR_KERNEL_ALL:
      bytes += clover_bytes * 2 * (in.GhostFace()[0] + in.GhostFace()[1] + in.GhostFace()[2] + in.GhostFace()[3]);
      break;
    case INTERIOR_KERNEL:
    case KERNEL_POLICY: {
      bytes += clover_bytes * in.Volume();

      if (arg.kernel_type == KERNEL_POLICY) break;
      // now correct for bytes done by exterior kernel
      long long ghost_sites = 0;
      for (int d = 0; d < 4; d++)
        if (arg.commDim[d]) ghost_sites += 2 * in.GhostFace()[d];
      bytes -= clover_bytes * ghost_sites;

      break;
    }
    }

    return bytes;
  }

  /**
     @brief Autotuning key built from the input field's volume string, the
     dynamic type name of this object and the base-class aux string for
     the current kernel type.
  */
  TuneKey tuneKey() const
  {
    return TuneKey(in.VolString(), typeid(*this).name(), Dslash<Float>::aux[arg.kernel_type]);
  }
};
129 
/**
   @brief Functor whose constructor builds the kernel argument struct and
   the dslash driver for a given precision / colour count / gauge
   reconstruction, then runs it through the dslash policy.
*/
template <typename Float, int nColor, QudaReconstructType recon> struct WilsonCloverPreconditionedApply {

  /**
     @brief Construct the arguments and apply the operator.
     Parameters are forwarded unchanged to the WilsonCloverArg constructor
     (out, in, U, A, a, x, parity, dagger, comm_override) and to the policy
     object (in, profile).
  */
  // NOTE(review): the constructor header was absent from the listing; signature taken from
  // the page's cross-reference entry for WilsonCloverPreconditionedApply.
  WilsonCloverPreconditionedApply(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U,
                                  const CloverField &A, double a, const ColorSpinorField &x, int parity, bool dagger,
                                  const int *comm_override, TimeProfile &profile)
  {
    constexpr int nDim = 4;
#ifdef DYNAMIC_CLOVER
    constexpr bool dynamic_clover = true;
#else
    constexpr bool dynamic_clover = false;
#endif
    using Arg = WilsonCloverArg<Float, nColor, recon, dynamic_clover>;
    Arg arg(out, in, U, A, a, x, parity, dagger, comm_override);

    // NOTE(review): the declarator of this object was absent from the listing (only the
    // trailing "arg, out, in);" survived) - reconstructed, confirm against upstream.
    WilsonCloverPreconditioned<Float, nDim, nColor, Arg> wilson(arg, out, in);

    // NOTE(review): the declarator of `policy` was likewise absent (only its trailing
    // arguments survived) - reconstructed from the dslash_policy.cuh include and the
    // policy.apply() call below, confirm against upstream.
    dslash::DslashPolicyTune<decltype(wilson)> policy(
      wilson, const_cast<cudaColorSpinorField *>(static_cast<const cudaColorSpinorField *>(&in)), in.VolumeCB(),
      in.GhostFaceCB(), profile);
    policy.apply(0); // presumably the default stream - TODO confirm

    checkCudaError();
  }
};
154 
/**
   @brief Apply the preconditioned Wilson-clover operator

   out(x) = M*in = a * A(x)^{-1} (\sum_mu U_{-\mu}(x)in(x+mu) + U^\dagger_mu(x-mu)in(x-mu))

   Uses the kappa normalization for the Wilson operator.

   Raises an error via errorQuda if in and out alias, if their field
   orders differ, or if the library was built without GPU_CLOVER_DIRAC.
   Precision and location of out, in, U and A are checked for agreement.

   @param[out] out Result field
   @param[in] in Input field
   @param[in] U Gauge field
   @param[in] A Clover field
   @param[in] a Scale factor in the expression above
   @param[in] x Additional input field forwarded to the kernel arguments
   (presumably the xpay accumulation vector - TODO confirm)
   @param[in] parity Parity forwarded to the kernel arguments
   @param[in] dagger Whether the dagger operator is requested
   @param[in] comm_override Forwarded to the kernel arguments
   @param[in] profile Profiler forwarded to the dslash policy
*/
// NOTE(review): the function header was absent from the listing; signature taken from the
// page's cross-reference entry, with the scale parameter named `a` as the body uses it.
void ApplyWilsonCloverPreconditioned(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U,
                                     const CloverField &A, double a, const ColorSpinorField &x, int parity,
                                     bool dagger, const int *comm_override, TimeProfile &profile)
{
#ifdef GPU_CLOVER_DIRAC
  // the operator cannot be applied in place
  if (in.V() == out.V()) errorQuda("Aliasing pointers");
  if (in.FieldOrder() != out.FieldOrder())
    errorQuda("Field order mismatch in = %d, out = %d", in.FieldOrder(), out.FieldOrder());

  // check all precisions match
  checkPrecision(out, in, U, A);

  // check all locations match
  checkLocation(out, in, U, A);

  // dispatch to WilsonCloverPreconditionedApply for the matching template parameters
  instantiate<WilsonCloverPreconditionedApply>(out, in, U, A, a, x, parity, dagger, comm_override, profile);
#else
  errorQuda("Clover dslash has not been built");
#endif
}
178 
179 } // namespace quda
void launch(T *f, const TuneParam &tp, Arg &arg, const cudaStream_t &stream)
Definition: dslash.h:101
static void launch(Dslash &dslash, TuneParam &tp, Arg &arg, const cudaStream_t &stream)
void setParam(Arg &arg)
Definition: dslash.h:66
void apply(const cudaStream_t &stream)
WilsonCloverPreconditionedApply(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, const CloverField &A, double a, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile)
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:21
#define checkPrecision(...)
#define errorQuda(...)
Definition: util_quda.h:121
cudaStream_t * stream
void ApplyWilsonCloverPreconditioned(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, const CloverField &A, double kappa, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile)
Driver for applying the preconditioned Wilson-clover stencil.
const char * VolString() const
void xpay(ColorSpinorField &x, double a, ColorSpinorField &y)
Definition: blas_quda.h:37
virtual long long bytes() const
Definition: dslash.h:364
__device__ __host__ void wilson(Arg &arg, int idx, int s, int parity)
const int nColor
Definition: covdev_test.cpp:75
This is a helper class that is used to instantiate the correct templated kernel for the dslash.
cpuColorSpinorField * in
const int * GhostFaceCB() const
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
Definition: tune.cpp:643
#define checkLocation(...)
cpuColorSpinorField * out
const int nParity
Definition: spinor_noise.cu:25
__global__ void wilsonCloverPreconditionedGPU(Arg arg)
unsigned long long flops
Definition: blas_quda.cu:22
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
const int * GhostFace() const
#define checkCudaError()
Definition: util_quda.h:161
virtual long long flops() const
Definition: dslash.h:316
QudaTune getTuning()
Query whether autotuning is enabled or not. Default is enabled but can be overridden by setting QUDA_ENABLE_TUNING=0.
Definition: util_quda.cpp:52
QudaPrecision Precision() const
QudaDagType dagger
Definition: test_util.cpp:1620
QudaParity parity
Definition: covdev_test.cpp:54
QudaFieldOrder FieldOrder() const
unsigned long long bytes
Definition: blas_quda.cu:23
WilsonCloverPreconditioned(Arg &arg, const ColorSpinorField &out, const ColorSpinorField &in)