QUDA  1.0.0
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
dslash_twisted_clover_preconditioned.cu
Go to the documentation of this file.
1 #include <gauge_field.h>
2 #include <color_spinor_field.h>
3 #include <clover_field.h>
4 #include <dslash.h>
5 #include <worker.h>
6 
7 #include <dslash_policy.cuh>
9 
14 namespace quda
15 {
16 
21  template <typename Float, int nDim, int nColor, int nParity, bool dagger, bool xpay, KernelType kernel_type, typename Arg>
23  static constexpr const char *kernel = "quda::twistedCloverPreconditionedGPU"; // kernel name for jit compilation
24  template <typename Dslash>
25  inline static void launch(Dslash &dslash, TuneParam &tp, Arg &arg, const cudaStream_t &stream)
26  {
27  static_assert(nParity == 1, "preconditioned twisted-mass operator only defined for nParity=1");
28  if (xpay && dagger) errorQuda("xpay operator only defined for not dagger");
29  dslash.launch(twistedCloverPreconditionedGPU < Float, nDim, nColor, nParity, dagger && !xpay, xpay && !dagger,
30  kernel_type, Arg >, tp, arg, stream);
31  }
32  };
33 
34  template <typename Float, int nDim, int nColor, typename Arg> class TwistedCloverPreconditioned : public Dslash<Float>
35  {
36 
37 protected:
38  Arg &arg;
40 
41 public:
43  Dslash<Float>(arg, out, in, "kernels/dslash_twisted_clover_preconditioned.cuh"),
44  arg(arg),
45  in(in)
46  {
47  }
48 
50 
// Tune and launch this dslash on the given stream.
//
// The launch parameters are obtained from tuneLaunch(). Only arg.nParity == 1 is
// handled: the final boolean template argument passed to instantiate mirrors
// arg.xpay. Any other nParity value is reported through errorQuda.
//
// NOTE(review): the embedded listing numbers jump from 53 to 55, so one source line
// is absent from this view (possibly the setParam call that is cross-referenced at
// the end of this listing) — confirm against the upstream file before relying on
// this body as complete.
51  void apply(const cudaStream_t &stream)
52  {
53  TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
55  if (arg.nParity == 1) {
56  if (arg.xpay)
57  Dslash<Float>::template instantiate<TwistedCloverPreconditionedLaunch, nDim, nColor, 1, true>(tp, arg, stream);
58  else
59  Dslash<Float>::template instantiate<TwistedCloverPreconditionedLaunch, nDim, nColor, 1, false>(tp, arg, stream);
60  } else {
// Any parity count other than 1 is rejected here.
61  errorQuda("Preconditioned twisted-clover operator not defined nParity=%d", arg.nParity);
62  }
63  }
64 
65  long long flops() const
66  {
67  int clover_flops = 504 + 48;
68  long long flops = Dslash<Float>::flops();
69  switch (arg.kernel_type) {
70  case EXTERIOR_KERNEL_X:
71  case EXTERIOR_KERNEL_Y:
72  case EXTERIOR_KERNEL_Z:
73  case EXTERIOR_KERNEL_T: flops += clover_flops * 2 * in.GhostFace()[arg.kernel_type]; break;
75  flops += clover_flops * 2 * (in.GhostFace()[0] + in.GhostFace()[1] + in.GhostFace()[2] + in.GhostFace()[3]);
76  break;
77  case INTERIOR_KERNEL:
78  case KERNEL_POLICY:
79  flops += clover_flops * in.Volume();
80 
81  if (arg.kernel_type == KERNEL_POLICY) break;
82  // now correct for flops done by exterior kernel
83  long long ghost_sites = 0;
84  for (int d = 0; d < 4; d++)
85  if (arg.commDim[d]) ghost_sites += 2 * in.GhostFace()[d];
86  flops -= clover_flops * ghost_sites;
87 
88  break;
89  }
90  return flops;
91  }
92 
93  long long bytes() const
94  {
95  bool isFixed = (in.Precision() == sizeof(short) || in.Precision() == sizeof(char)) ? true : false;
96  int clover_bytes = 72 * in.Precision() + (isFixed ? 2 * sizeof(float) : 0);
97  if (!arg.dynamic_clover) clover_bytes *= 2;
98 
99  long long bytes = Dslash<Float>::bytes();
100  switch (arg.kernel_type) {
101  case EXTERIOR_KERNEL_X:
102  case EXTERIOR_KERNEL_Y:
103  case EXTERIOR_KERNEL_Z:
104  case EXTERIOR_KERNEL_T: bytes += clover_bytes * 2 * in.GhostFace()[arg.kernel_type]; break;
105  case EXTERIOR_KERNEL_ALL:
106  bytes += clover_bytes * 2 * (in.GhostFace()[0] + in.GhostFace()[1] + in.GhostFace()[2] + in.GhostFace()[3]);
107  break;
108  case INTERIOR_KERNEL:
109  case KERNEL_POLICY:
110  bytes += clover_bytes * in.Volume();
111 
112  if (arg.kernel_type == KERNEL_POLICY) break;
113  // now correct for bytes done by exterior kernel
114  long long ghost_sites = 0;
115  for (int d = 0; d < 4; d++)
116  if (arg.commDim[d]) ghost_sites += 2 * in.GhostFace()[d];
117  bytes -= clover_bytes * ghost_sites;
118 
119  break;
120  }
121 
122  return bytes;
123  }
124 
125  TuneKey tuneKey() const
126  {
127  return TuneKey(in.VolString(), typeid(*this).name(), Dslash<Float>::aux[arg.kernel_type]);
128  }
129  };
130 
131  template <typename Float, int nColor, QudaReconstructType recon> struct TwistedCloverPreconditionedApply {
132 
134  const CloverField &C, double a, double b, bool xpay, const ColorSpinorField &x, int parity, bool dagger,
135  const int *comm_override, TimeProfile &profile)
136  {
137  constexpr int nDim = 4;
138 #ifdef DYNAMIC_CLOVER
139  constexpr bool dynamic_clover = true;
140 #else
141  constexpr bool dynamic_clover = false;
142 #endif
144  out, in, U, C, a, b, xpay, x, parity, dagger, comm_override);
146  arg, out, in);
147 
149  const_cast<cudaColorSpinorField *>(static_cast<const cudaColorSpinorField *>(&in)), in.VolumeCB(),
150  in.GhostFaceCB(), profile);
151  policy.apply(0);
152  checkCudaError();
153  }
154  };
155 
156  /*
157  Apply the preconditioned twisted-clover Dslash operator
158 
159  out = x + a*A^{-1} D * in = x + a*(C + i*b*gamma_5)^{-1}*\sum_mu U_{-\mu}(x)in(x+mu) + U^\dagger_mu(x-mu)in(x-mu)
160  */
162  const CloverField &C, double a, double b, bool xpay, const ColorSpinorField &x, int parity, bool dagger,
163  const int *comm_override, TimeProfile &profile)
164  {
165 #ifdef GPU_TWISTED_CLOVER_DIRAC
166  if (in.V() == out.V()) errorQuda("Aliasing pointers");
167  if (in.FieldOrder() != out.FieldOrder())
168  errorQuda("Field order mismatch in = %d, out = %d", in.FieldOrder(), out.FieldOrder());
169 
170  // check all precisions match
171  checkPrecision(out, in, U, C);
172 
173  // check all locations match
174  checkLocation(out, in, U, C);
175 
176  instantiate<TwistedCloverPreconditionedApply>(out, in, U, C, a, b, xpay, x, parity, dagger, comm_override, profile);
177 #else
178  errorQuda("Twisted-clover dslash has not been built");
179 #endif // GPU_TWISTED_CLOVER_DIRAC
180  }
181 
182 } // namespace quda
void launch(T *f, const TuneParam &tp, Arg &arg, const cudaStream_t &stream)
Definition: dslash.h:101
void setParam(Arg &arg)
Definition: dslash.h:66
void apply(const cudaStream_t &stream)
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:21
#define checkPrecision(...)
#define errorQuda(...)
Definition: util_quda.h:121
cudaStream_t * stream
__global__ void twistedCloverPreconditionedGPU(Arg arg)
const char * VolString() const
void xpay(ColorSpinorField &x, double a, ColorSpinorField &y)
Definition: blas_quda.h:37
virtual long long bytes() const
Definition: dslash.h:364
const int nColor
Definition: covdev_test.cpp:75
cpuColorSpinorField * in
const int * GhostFaceCB() const
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
Definition: tune.cpp:643
#define checkLocation(...)
cpuColorSpinorField * out
This is a helper class that is used to instantiate the correct templated kernel for the dslash...
const int nParity
Definition: spinor_noise.cu:25
void ApplyTwistedCloverPreconditioned(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, const CloverField &C, double a, double b, bool xpay, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile)
Driver for applying the preconditioned twisted-clover stencil.
static void launch(Dslash &dslash, TuneParam &tp, Arg &arg, const cudaStream_t &stream)
unsigned long long flops
Definition: blas_quda.cu:22
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
const int * GhostFace() const
#define checkCudaError()
Definition: util_quda.h:161
virtual long long flops() const
Definition: dslash.h:316
QudaTune getTuning()
Query whether autotuning is enabled or not. Default is enabled but can be overridden by setting QUDA_...
Definition: util_quda.cpp:52
TwistedCloverPreconditioned(Arg &arg, const ColorSpinorField &out, const ColorSpinorField &in)
QudaPrecision Precision() const
QudaDagType dagger
Definition: test_util.cpp:1620
QudaParity parity
Definition: covdev_test.cpp:54
TwistedCloverPreconditionedApply(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, const CloverField &C, double a, double b, bool xpay, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile)
QudaFieldOrder FieldOrder() const
unsigned long long bytes
Definition: blas_quda.cu:23