QUDA  1.0.0
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
covDev.cu
Go to the documentation of this file.
1 #include <dslash.h>
2 #include <worker.h>
3 #include <dslash_helper.cuh>
5 #include <gauge_field_order.h>
6 #include <color_spinor.h>
7 #include <dslash_helper.cuh>
8 #include <index_helper.cuh>
9 #include <gauge_field.h>
10 #include <uint_to_char.h>
11 
12 #include <dslash_policy.cuh>
13 #include <kernels/covDev.cuh>
14 
19 namespace quda
20 {
21 
22 #ifdef GPU_COVDEV
23 
28  template <typename Float, int nDim, int nColor, int nParity, bool dagger, bool xpay, KernelType kernel_type, typename Arg>
29  struct CovDevLaunch {
30 
31  // kernel name for jit compilation
32  static constexpr const char *kernel = "quda::covDevGPU";
33 
34  template <typename Dslash>
35  inline static void launch(Dslash &dslash, TuneParam &tp, Arg &arg, const cudaStream_t &stream)
36  {
37  static_assert(xpay == false, "Covariant derivative operator only defined without xpay");
38  static_assert(nParity == 2, "Covariant derivative operator only defined for full field");
39  dslash.launch(covDevGPU<Float, nDim, nColor, nParity, dagger, xpay, kernel_type, Arg>, tp, arg, stream);
40  }
41  };
42 
43  template <typename Float, int nDim, int nColor, typename Arg> class CovDev : public Dslash<Float>
44  {
45 
46 protected:
47  Arg &arg;
48  const ColorSpinorField &in;
49 
50 public:
51  CovDev(Arg &arg, const ColorSpinorField &out, const ColorSpinorField &in) :
52  Dslash<Float>(arg, out, in, "kernels/covDev.cuh"),
53  arg(arg),
54  in(in)
55  {
56  }
57 
58  virtual ~CovDev() {}
59 
60  void apply(const cudaStream_t &stream)
61  {
62  TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
64  if (arg.xpay) errorQuda("Covariant derivative operator only defined without xpay");
65  if (arg.nParity != 2) errorQuda("Covariant derivative operator only defined for full field");
66 
67  constexpr bool xpay = false;
68  constexpr int nParity = 2;
69  Dslash<Float>::template instantiate<CovDevLaunch, nDim, nColor, nParity, xpay>(tp, arg, stream);
70  }
71 
72  long long flops() const
73  {
74  int mv_flops = (8 * in.Ncolor() - 2) * in.Ncolor(); // SU(3) matrix-vector flops
75  int num_mv_multiply = in.Nspin();
76  int ghost_flops = num_mv_multiply * mv_flops;
77  int dim = arg.mu % 4;
78  long long flops_ = 0;
79 
80  switch (arg.kernel_type) {
81  case EXTERIOR_KERNEL_X:
82  case EXTERIOR_KERNEL_Y:
83  case EXTERIOR_KERNEL_Z:
84  case EXTERIOR_KERNEL_T:
85  if (arg.kernel_type != dim) break;
86  flops_ = (ghost_flops)*in.GhostFace()[dim];
87  break;
88  case EXTERIOR_KERNEL_ALL: {
89  long long ghost_sites = in.GhostFace()[dim];
90  flops_ = ghost_flops * ghost_sites;
91  break;
92  }
93  case INTERIOR_KERNEL:
94  case KERNEL_POLICY: {
95  long long sites = in.Volume();
96  flops_ = num_mv_multiply * mv_flops * sites; // SU(3) matrix-vector multiplies
97 
98  if (arg.kernel_type == KERNEL_POLICY) break;
99  // now correct for flops done by exterior kernel
100  long long ghost_sites = arg.commDim[dim] ? in.GhostFace()[dim] : 0;
101  flops_ -= ghost_flops * ghost_sites;
102 
103  break;
104  }
105  }
106 
107  return flops_;
108  }
109 
110  long long bytes() const
111  {
112  int gauge_bytes = arg.reconstruct * in.Precision();
113  bool isFixed = (in.Precision() == sizeof(short) || in.Precision() == sizeof(char)) ? true : false;
114  int spinor_bytes = 2 * in.Ncolor() * in.Nspin() * in.Precision() + (isFixed ? sizeof(float) : 0);
115  int ghost_bytes = gauge_bytes + 3 * spinor_bytes; // 3 since we have to load the partial
116  int dim = arg.mu % 4;
117  long long bytes_ = 0;
118 
119  switch (arg.kernel_type) {
120  case EXTERIOR_KERNEL_X:
121  case EXTERIOR_KERNEL_Y:
122  case EXTERIOR_KERNEL_Z:
123  case EXTERIOR_KERNEL_T:
124  if (arg.kernel_type != dim) break;
125  bytes_ = ghost_bytes * in.GhostFace()[dim];
126  break;
127  case EXTERIOR_KERNEL_ALL: {
128  long long ghost_sites = in.GhostFace()[dim];
129  bytes_ = ghost_bytes * ghost_sites;
130  break;
131  }
132  case INTERIOR_KERNEL:
133  case KERNEL_POLICY: {
134  long long sites = in.Volume();
135  bytes_ = (gauge_bytes + 2 * spinor_bytes) * sites;
136 
137  if (arg.kernel_type == KERNEL_POLICY) break;
138  // now correct for bytes done by exterior kernel
139  long long ghost_sites = arg.commDim[dim] ? in.GhostFace()[dim] : 0;
140  bytes_ -= ghost_bytes * ghost_sites;
141 
142  break;
143  }
144  }
145  return bytes_;
146  }
147 
148  TuneKey tuneKey() const
149  {
150  // add mu to the key
151  char aux[TuneKey::aux_n];
152  strcpy(aux, Dslash<Float>::aux[arg.kernel_type]);
153  strcat(aux, ",mu=");
154  char mu[8];
155  u32toa(mu, arg.mu);
156  strcat(aux, mu);
157  return TuneKey(in.VolString(), typeid(*this).name(), aux);
158  }
159  };
160 
161  template <typename Float, int nColor, QudaReconstructType recon> struct CovDevApply {
162 
163  inline CovDevApply(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, int mu, int parity,
164  bool dagger, const int *comm_override, TimeProfile &profile)
165 
166  {
167  constexpr int nDim = 4;
168  CovDevArg<Float, nColor, recon> arg(out, in, U, mu, parity, dagger, comm_override);
169  CovDev<Float, nDim, nColor, CovDevArg<Float, nColor, recon>> covDev(arg, out, in);
170 
171  dslash::DslashPolicyTune<decltype(covDev)> policy(
172  covDev, const_cast<cudaColorSpinorField *>(static_cast<const cudaColorSpinorField *>(&in)), in.VolumeCB(),
173  in.GhostFaceCB(), profile);
174  policy.apply(0);
175 
176  checkCudaError();
177  }
178  };
179 
180 #endif
181 
182  // Apply the covariant derivative operator
183  // out(x) = U_{\mu}(x)in(x+mu) for mu = 0...3
184  // out(x) = U^\dagger_mu'(x-mu')in(x-mu') for mu = 4...7 and we set mu' = mu-4
185  void ApplyCovDev(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, int mu, int parity,
186  bool dagger, const int *comm_override, TimeProfile &profile)
187  {
188 #ifdef GPU_COVDEV
189  if (in.V() == out.V()) errorQuda("Aliasing pointers");
190  if (in.FieldOrder() != out.FieldOrder())
191  errorQuda("Field order mismatch in = %d, out = %d", in.FieldOrder(), out.FieldOrder());
192 
193  // check all precisions match
194  checkPrecision(out, in, U);
195 
196  // check all locations match
197  checkLocation(out, in, U);
198 
199  pushKernelPackT(true); // non-spin projection requires kernel packing
200 
201  instantiate<CovDevApply>(out, in, U, mu, parity, dagger, comm_override, profile);
202 
203  popKernelPackT();
204 #else
205  errorQuda("Covariant derivative kernels have not been built");
206 #endif
207  }
208 } // namespace quda
double mu
Definition: test_util.cpp:1648
void setParam(Arg &arg)
Definition: dslash.h:66
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:21
#define checkPrecision(...)
#define errorQuda(...)
Definition: util_quda.h:121
cudaStream_t * stream
void ApplyCovDev(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, int mu, int parity, bool dagger, const int *comm_override, TimeProfile &profile)
Driver for applying the covariant derivative.
Definition: covDev.cu:185
void xpay(ColorSpinorField &x, double a, ColorSpinorField &y)
Definition: blas_quda.h:37
void popKernelPackT()
Definition: dslash_quda.cu:42
__device__ __host__ void covDev(Arg &arg, int idx, int parity)
Definition: covDev.cuh:119
cpuColorSpinorField * in
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
Definition: tune.cpp:643
#define checkLocation(...)
Main header file for host and device accessors to GaugeFields.
cpuColorSpinorField * out
void u32toa(char *buffer, uint32_t value)
Definition: uint_to_char.h:45
static const int aux_n
Definition: tune_key.h:12
unsigned long long flops
Definition: blas_quda.cu:22
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
void pushKernelPackT(bool pack)
Definition: dslash_quda.cu:30
#define checkCudaError()
Definition: util_quda.h:161
char aux[8][TuneKey::aux_n]
Definition: dslash.h:23
QudaTune getTuning()
Query whether autotuning is enabled or not. Default is enabled but can be overridden by setting QUDA_...
Definition: util_quda.cpp:52
QudaDagType dagger
Definition: test_util.cpp:1620
QudaParity parity
Definition: covdev_test.cpp:54
QudaFieldOrder FieldOrder() const
unsigned long long bytes
Definition: blas_quda.cu:23