QUDA  v1.1.0
A library for QCD on GPUs
dslash_quda.cu
#include <stack>

#include <color_spinor_field.h>
#include <clover_field.h>
#include <dslash_quda.h>
#include <color_spinor_field_order.h>
#include <clover_field_order.h>
#include <index_helper.cuh>
#include <color_spinor.h>
#include <linalg.cuh>
#include <dslash_policy.cuh>
#include <instantiate.h>
#ifdef NVSHMEM_COMMS
#include <cuda/atomic>
#endif
namespace quda {

  // these should not be namespaced!!
  // determines whether the temporal ghost zones are packed with a gather kernel,
  // as opposed to multiple memcpys
  static bool kernelPackT = false;

  void setKernelPackT(bool packT) { kernelPackT = packT; }

  bool getKernelPackT() { return kernelPackT; }

  static std::stack<bool> kptstack;

  void pushKernelPackT(bool packT)
  {
    kptstack.push(getKernelPackT());
    setKernelPackT(packT);

    if (kptstack.size() > 10) {
      warningQuda("KernelPackT stack contains %u elements. Is there a missing popKernelPackT() somewhere?",
                  static_cast<unsigned int>(kptstack.size()));
    }
  }

  void popKernelPackT()
  {
    if (kptstack.empty()) { errorQuda("popKernelPackT() called with empty stack"); }
    setKernelPackT(kptstack.top());
    kptstack.pop();
  }
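
  /* Usage sketch (illustrative, not a library API addition): callers that
     need to override the packing strategy temporarily are expected to pair
     each push with a pop, e.g.

       pushKernelPackT(true);  // force kernel packing, remembering the old value
       // ... launch a dslash that requires gather-kernel packing ...
       popKernelPackT();       // restore the previous setting

     The warning in pushKernelPackT() fires once the stack depth exceeds 10,
     which usually indicates a missing popKernelPackT(). */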

  namespace dslash {
    int it = 0;

    cudaEvent_t packEnd[2];
    cudaEvent_t gatherStart[Nstream];
    cudaEvent_t gatherEnd[Nstream];
    cudaEvent_t scatterStart[Nstream];
    cudaEvent_t scatterEnd[Nstream];
    cudaEvent_t dslashStart[2];

    // for shmem lightweight sync
    shmem_sync_t sync_counter = 10;
    shmem_sync_t get_shmem_sync_counter() { return sync_counter; }
    shmem_sync_t set_shmem_sync_counter(shmem_sync_t count) { return sync_counter = count; }
    shmem_sync_t inc_shmem_sync_counter() { return sync_counter++; }
#ifdef NVSHMEM_COMMS
    shmem_sync_t *sync_arr = nullptr;
    shmem_retcount_intra_t *_retcount_intra = nullptr;
    shmem_retcount_inter_t *_retcount_inter = nullptr;
    shmem_interior_done_t *_interior_done = nullptr;
    shmem_interior_count_t *_interior_count = nullptr;
    shmem_sync_t *get_shmem_sync_arr() { return sync_arr; }
    shmem_retcount_intra_t *get_shmem_retcount_intra() { return _retcount_intra; }
    shmem_retcount_inter_t *get_shmem_retcount_inter() { return _retcount_inter; }
    shmem_interior_done_t *get_shmem_interior_done() { return _interior_done; }
    shmem_interior_count_t *get_shmem_interior_count() { return _interior_count; }
#endif

    // these variables are used for benchmarking the dslash components in isolation
    bool dslash_pack_compute;
    bool dslash_interior_compute;
    bool dslash_exterior_compute;
    bool dslash_comms;
    bool dslash_copy;

    // whether the dslash policy tuner has been enabled
    bool dslash_policy_init;

    // used to keep track of which policy to start the autotuning with
    int first_active_policy;
    int first_active_p2p_policy;

    // list of dslash policies that are enabled
    std::vector<QudaDslashPolicy> policies;

    // list of p2p policies that are enabled
    std::vector<QudaP2PPolicy> p2p_policies;

    // string used as a tunekey to ensure we retune if the dslash policy env changes
    char policy_string[TuneKey::aux_n];

    // FIX this is a hack from hell
    // Auxiliary work that can be done while waiting on comms to finish
    Worker *aux_worker;
  }

  // we need to use the placement-new constructor to initialize the atomic counters
  template <typename T> __global__ void init_dslash_atomic(T *counter, int max)
  {
    for (int i = 0; i < max; i++) new (counter + i) T {0};
  }

  // initialize the shmem sync array to the given value
  template <typename T> __global__ void init_sync_arr(T *arr, T val, int max)
  {
    for (int i = 0; i < max; i++) *(arr + i) = val;
  }
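
  /* Why placement new above: the counters live in raw device-visible
     allocations, so the atomic objects must be constructed in place before
     first use; writing raw bytes would not formally begin their lifetime.
     A minimal launch sketch (ptr is a placeholder for a buffer obtained from
     device_pinned_malloc, TuneParam set up as in createDslashEvents() below):

       TuneParam tp;
       tp.grid = dim3(1, 1, 1);
       tp.block = dim3(1, 1, 1);
       qudaLaunchKernel(init_dslash_atomic<shmem_interior_done_t>, tp, 0, ptr, 1);
  */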

  void createDslashEvents()
  {
    using namespace dslash;
    // add cudaEventDisableTiming for lower sync overhead
    for (int i = 0; i < Nstream; i++) {
      cudaEventCreateWithFlags(&gatherStart[i], cudaEventDisableTiming);
      cudaEventCreateWithFlags(&gatherEnd[i], cudaEventDisableTiming);
      cudaEventCreateWithFlags(&scatterStart[i], cudaEventDisableTiming);
      cudaEventCreateWithFlags(&scatterEnd[i], cudaEventDisableTiming);
    }
    for (int i = 0; i < 2; i++) {
      cudaEventCreateWithFlags(&packEnd[i], cudaEventDisableTiming);
      cudaEventCreateWithFlags(&dslashStart[i], cudaEventDisableTiming);
    }
#ifdef NVSHMEM_COMMS
    sync_arr = static_cast<shmem_sync_t *>(device_comms_pinned_malloc(2 * QUDA_MAX_DIM * sizeof(shmem_sync_t)));
    TuneParam tp;
    tp.grid = dim3(1, 1, 1);
    tp.block = dim3(1, 1, 1);

    /* initialize to 9 here, one less than the initial sync counter of 10, so
       that when tuning we can skip the wait if necessary by using smaller
       counter values */
    qudaLaunchKernel(init_sync_arr<shmem_sync_t>, tp, 0, sync_arr, static_cast<shmem_sync_t>(9), 2 * QUDA_MAX_DIM);
    sync_counter = 10;

    // atomics for controlling signaling in nvshmem packing
    _retcount_intra
      = static_cast<shmem_retcount_intra_t *>(device_pinned_malloc(2 * QUDA_MAX_DIM * sizeof(shmem_retcount_intra_t)));
    qudaLaunchKernel(init_dslash_atomic<shmem_retcount_intra_t>, tp, 0, _retcount_intra, 2 * QUDA_MAX_DIM);
    _retcount_inter
      = static_cast<shmem_retcount_inter_t *>(device_pinned_malloc(2 * QUDA_MAX_DIM * sizeof(shmem_retcount_inter_t)));
    qudaLaunchKernel(init_dslash_atomic<shmem_retcount_inter_t>, tp, 0, _retcount_inter, 2 * QUDA_MAX_DIM);
    // workspace for the interior-done sync in the uber kernel
    _interior_done = static_cast<shmem_interior_done_t *>(device_pinned_malloc(sizeof(shmem_interior_done_t)));
    qudaLaunchKernel(init_dslash_atomic<shmem_interior_done_t>, tp, 0, _interior_done, 1);
    _interior_count = static_cast<shmem_interior_count_t *>(device_pinned_malloc(sizeof(shmem_interior_count_t)));
    qudaLaunchKernel(init_dslash_atomic<shmem_interior_count_t>, tp, 0, _interior_count, 1);
#endif

    aux_worker = NULL;

    checkCudaError();

    dslash_pack_compute = true;
    dslash_interior_compute = true;
    dslash_exterior_compute = true;
    dslash_comms = true;
    dslash_copy = true;

    dslash_policy_init = false;
    first_active_policy = 0;
    first_active_p2p_policy = 0;

    // list of dslash policies that are enabled
    policies = std::vector<QudaDslashPolicy>(
      static_cast<int>(QudaDslashPolicy::QUDA_DSLASH_POLICY_DISABLED), QudaDslashPolicy::QUDA_DSLASH_POLICY_DISABLED);

    // list of p2p policies that are enabled
    p2p_policies = std::vector<QudaP2PPolicy>(
      static_cast<int>(QudaP2PPolicy::QUDA_P2P_POLICY_DISABLED), QudaP2PPolicy::QUDA_P2P_POLICY_DISABLED);

    strcat(policy_string, ",pol=");
  }

  void destroyDslashEvents()
  {
    using namespace dslash;

    for (int i = 0; i < Nstream; i++) {
      cudaEventDestroy(gatherStart[i]);
      cudaEventDestroy(gatherEnd[i]);
      cudaEventDestroy(scatterStart[i]);
      cudaEventDestroy(scatterEnd[i]);
    }

    for (int i = 0; i < 2; i++) {
      cudaEventDestroy(packEnd[i]);
      cudaEventDestroy(dslashStart[i]);
    }
#ifdef NVSHMEM_COMMS
    device_comms_pinned_free(sync_arr);
    device_pinned_free(_retcount_intra);
    device_pinned_free(_retcount_inter);
    device_pinned_free(_interior_done);
    device_pinned_free(_interior_count);
#endif
    checkCudaError();
  }
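
  /* Lifetime sketch (illustrative): createDslashEvents() and
     destroyDslashEvents() manage global resources and are expected to be
     called once each per library lifetime, e.g.

       createDslashEvents();   // during library initialization
       // ... dslash applications ...
       destroyDslashEvents();  // during shutdown
  */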

  /**
     @brief Parameter structure for driving the Gamma operator
  */
  template <typename Float, int nColor>
  struct GammaArg {
    typedef typename colorspinor_mapper<Float, 4, nColor>::type F;
    typedef typename mapper<Float>::type RegType;

    F out;              // output vector field
    const F in;         // input vector field
    const int d;        // which gamma matrix we are applying
    const int nParity;  // number of parities we're working on
    bool doublet;       // whether we're applying the operator to a doublet
    const int volumeCB; // checkerboarded volume
    RegType a;          // scale factor
    RegType b;          // chiral twist
    RegType c;          // flavor twist

    GammaArg(ColorSpinorField &out, const ColorSpinorField &in, int d,
             RegType kappa = 0.0, RegType mu = 0.0, RegType epsilon = 0.0,
             bool dagger = false, QudaTwistGamma5Type twist = QUDA_TWIST_GAMMA5_INVALID)
      : out(out), in(in), d(d), nParity(in.SiteSubset()),
        doublet(in.TwistFlavor() == QUDA_TWIST_DEG_DOUBLET || in.TwistFlavor() == QUDA_TWIST_NONDEG_DOUBLET),
        volumeCB(doublet ? in.VolumeCB() / 2 : in.VolumeCB()), a(0.0), b(0.0), c(0.0)
    {
      checkPrecision(out, in);
      checkLocation(out, in);
      if (d < 0 || d > 4) errorQuda("Undefined gamma matrix %d", d);
      if (in.Nspin() != 4) errorQuda("Cannot apply gamma5 to nSpin=%d field", in.Nspin());
      if (!in.isNative() || !out.isNative()) errorQuda("Unsupported field order out=%d in=%d", out.FieldOrder(), in.FieldOrder());

      if (in.TwistFlavor() == QUDA_TWIST_SINGLET) {
        if (twist == QUDA_TWIST_GAMMA5_DIRECT) {
          b = 2.0 * kappa * mu;
          a = 1.0;
        } else if (twist == QUDA_TWIST_GAMMA5_INVERSE) {
          b = -2.0 * kappa * mu;
          a = 1.0 / (1.0 + b * b);
        }
        c = 0.0;
        if (dagger) b *= -1.0;
      } else if (doublet) {
        if (twist == QUDA_TWIST_GAMMA5_DIRECT) {
          b = 2.0 * kappa * mu;
          c = -2.0 * kappa * epsilon;
          a = 1.0;
        } else if (twist == QUDA_TWIST_GAMMA5_INVERSE) {
          b = -2.0 * kappa * mu;
          c = 2.0 * kappa * epsilon;
          a = 1.0 / (1.0 + b * b - c * c);
          if (a <= 0) errorQuda("Invalid twisted mass parameters (kappa=%e, mu=%e, epsilon=%e)", kappa, mu, epsilon);
        }
        if (dagger) b *= -1.0;
      }
    }
  };
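
  /* Worked form of the twist coefficients above (degenerate case; a sketch,
     not new functionality): the twist kernels below compute
     out = a * (in + i b gamma_5 in). For QUDA_TWIST_GAMMA5_DIRECT this is
     (1 + i b gamma_5) with b = 2*kappa*mu, and since gamma_5^2 = 1,

       (1 + i b gamma_5)(1 - i b gamma_5) = 1 + b^2,

     so the inverse is a * (1 - i b gamma_5) with a = 1 / (1 + b^2), matching
     the QUDA_TWIST_GAMMA5_INVERSE branch. The doublet case adds the flavor
     twist c, and the analogous normalization becomes 1 / (1 + b^2 - c^2). */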

  // CPU kernel for applying the gamma matrix to a colorspinor
  template <typename Float, int nColor, typename Arg>
  void gammaCPU(Arg arg)
  {
    typedef typename mapper<Float>::type RegType;
    for (int parity = 0; parity < arg.nParity; parity++) {
      for (int x_cb = 0; x_cb < arg.volumeCB; x_cb++) { // 4-d volume
        ColorSpinor<RegType, nColor, 4> in = arg.in(x_cb, parity);
        arg.out(x_cb, parity) = in.gamma(arg.d);
      } // 4-d volumeCB
    }   // parity
  }

  // GPU kernel for applying the gamma matrix to a colorspinor
  template <typename Float, int nColor, int d, typename Arg>
  __global__ void gammaGPU(Arg arg)
  {
    typedef typename mapper<Float>::type RegType;
    int x_cb = blockIdx.x * blockDim.x + threadIdx.x;
    int parity = blockDim.y * blockIdx.y + threadIdx.y;

    if (x_cb >= arg.volumeCB) return;
    if (parity >= arg.nParity) return;

    ColorSpinor<RegType, nColor, 4> in = arg.in(x_cb, parity);
    arg.out(x_cb, parity) = in.gamma(d);
  }

  template <typename Float, int nColor>
  class Gamma : public TunableVectorY {

    GammaArg<Float, nColor> arg;
    const ColorSpinorField &meta;

    long long flops() const { return 0; }
    long long bytes() const { return arg.out.Bytes() + arg.in.Bytes(); }
    bool tuneGridDim() const { return false; }
    unsigned int minThreads() const { return arg.volumeCB; }

  public:
    Gamma(ColorSpinorField &out, const ColorSpinorField &in, int d) :
      TunableVectorY(in.SiteSubset()),
      arg(out, in, d),
      meta(in)
    {
      strcpy(aux, meta.AuxString());

      apply(streams[Nstream - 1]);
    }

    void apply(const qudaStream_t &stream)
    {
      if (meta.Location() == QUDA_CPU_FIELD_LOCATION) {
        gammaCPU<Float, nColor>(arg);
      } else {
        TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
        switch (arg.d) {
        case 4: qudaLaunchKernel(gammaGPU<Float, nColor, 4, decltype(arg)>, tp, stream, arg); break;
        default: errorQuda("%d not instantiated", arg.d);
        }
      }
    }

    TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), aux); }

    void preTune() { arg.out.save(); }
    void postTune() { arg.out.load(); }
  };

  // Apply the Gamma matrix to a colorspinor field
  // out(x) = gamma_d * in
  void ApplyGamma(ColorSpinorField &out, const ColorSpinorField &in, int d)
  {
    instantiate<Gamma>(out, in, d);
  }

  // CPU kernel for applying the twisted gamma matrix to a colorspinor
  template <bool doublet, typename Float, int nColor, typename Arg>
  void twistGammaCPU(Arg arg)
  {
    typedef typename mapper<Float>::type RegType;
    for (int parity = 0; parity < arg.nParity; parity++) {
      for (int x_cb = 0; x_cb < arg.volumeCB; x_cb++) { // 4-d volume
        if (!doublet) {
          ColorSpinor<RegType, nColor, 4> in = arg.in(x_cb, parity);
          arg.out(x_cb, parity) = arg.a * (in + arg.b * in.igamma(arg.d));
        } else {
          ColorSpinor<RegType, nColor, 4> in_1 = arg.in(x_cb + 0 * arg.volumeCB, parity);
          ColorSpinor<RegType, nColor, 4> in_2 = arg.in(x_cb + 1 * arg.volumeCB, parity);
          arg.out(x_cb + 0 * arg.volumeCB, parity) = arg.a * (in_1 + arg.b * in_1.igamma(arg.d) + arg.c * in_2);
          arg.out(x_cb + 1 * arg.volumeCB, parity) = arg.a * (in_2 - arg.b * in_2.igamma(arg.d) + arg.c * in_1);
        }
      } // 4-d volumeCB
    }   // parity
  }

  // GPU kernel for applying the twisted gamma matrix to a colorspinor
  template <bool doublet, typename Float, int nColor, int d, typename Arg>
  __global__ void twistGammaGPU(Arg arg)
  {
    typedef typename mapper<Float>::type RegType;
    int x_cb = blockIdx.x * blockDim.x + threadIdx.x;
    int parity = blockDim.y * blockIdx.y + threadIdx.y;
    if (x_cb >= arg.volumeCB) return;

    if (!doublet) {
      ColorSpinor<RegType, nColor, 4> in = arg.in(x_cb, parity);
      arg.out(x_cb, parity) = arg.a * (in + arg.b * in.igamma(d));
    } else {
      ColorSpinor<RegType, nColor, 4> in_1 = arg.in(x_cb + 0 * arg.volumeCB, parity);
      ColorSpinor<RegType, nColor, 4> in_2 = arg.in(x_cb + 1 * arg.volumeCB, parity);
      arg.out(x_cb + 0 * arg.volumeCB, parity) = arg.a * (in_1 + arg.b * in_1.igamma(d) + arg.c * in_2);
      arg.out(x_cb + 1 * arg.volumeCB, parity) = arg.a * (in_2 - arg.b * in_2.igamma(d) + arg.c * in_1);
    }
  }

  template <typename Float, int nColor>
  class TwistGamma : public TunableVectorY {

    GammaArg<Float, nColor> arg;
    const ColorSpinorField &meta;

    long long flops() const { return 0; }
    long long bytes() const { return arg.out.Bytes() + arg.in.Bytes(); }
    bool tuneGridDim() const { return false; }
    unsigned int minThreads() const { return arg.volumeCB; }

  public:
    TwistGamma(ColorSpinorField &out, const ColorSpinorField &in, int d, double kappa, double mu, double epsilon,
               int dagger, QudaTwistGamma5Type type) :
      TunableVectorY(in.SiteSubset()),
      arg(out, in, d, kappa, mu, epsilon, dagger, type),
      meta(in)
    {
      strcpy(aux, meta.AuxString());

      apply(streams[Nstream - 1]);
    }

    void apply(const qudaStream_t &stream)
    {
      if (meta.Location() == QUDA_CPU_FIELD_LOCATION) {
        if (arg.doublet) twistGammaCPU<true, Float, nColor>(arg);
        else twistGammaCPU<false, Float, nColor>(arg);
      } else {
        TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
        if (arg.doublet)
          switch (arg.d) {
          case 4: qudaLaunchKernel(twistGammaGPU<true, Float, nColor, 4, decltype(arg)>, tp, stream, arg); break;
          default: errorQuda("%d not instantiated", arg.d);
          }
        else
          switch (arg.d) {
          case 4: qudaLaunchKernel(twistGammaGPU<false, Float, nColor, 4, decltype(arg)>, tp, stream, arg); break;
          default: errorQuda("%d not instantiated", arg.d);
          }
      }
    }

    TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), aux); }
    void preTune() { if (arg.out.field == arg.in.field) arg.out.save(); }
    void postTune() { if (arg.out.field == arg.in.field) arg.out.load(); }
  };

  // Apply the twisted gamma matrix to a colorspinor field
  // out(x) = a * (1 + i b gamma_d) * in
  void ApplyTwistGamma(ColorSpinorField &out, const ColorSpinorField &in, int d, double kappa, double mu,
                       double epsilon, int dagger, QudaTwistGamma5Type type)
  {
#ifdef GPU_TWISTED_MASS_DIRAC
    instantiate<TwistGamma>(out, in, d, kappa, mu, epsilon, dagger, type);
#else
    errorQuda("Twisted mass dslash has not been built");
#endif // GPU_TWISTED_MASS_DIRAC
  }
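
  /* Illustrative call (parameter values are placeholders): applying the
     inverse twist rotation to a degenerate twisted-mass field,

       ApplyTwistGamma(out, in, 4, kappa, mu, 0.0, 0, QUDA_TWIST_GAMMA5_INVERSE);

     computes out = (1 - 2 i kappa mu gamma_5) / (1 + (2 kappa mu)^2) * in,
     per the coefficient assignments in GammaArg above. */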

  // Applies a gamma5 matrix to a spinor (wrapper to ApplyGamma)
  void gamma5(ColorSpinorField &out, const ColorSpinorField &in) { ApplyGamma(out, in, 4); }
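
  /* Usage note (illustrative): index d = 4 selects gamma_5 in this
     convention, so the two calls below are equivalent:

       gamma5(out, in);        // out = gamma_5 * in
       ApplyGamma(out, in, 4); // same operation, explicit index
  */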

  /**
     @brief Parameter structure for driving the clover and twist-clover application kernels.
     The compile-time flag dynamic_clover (whether we invert the clover field on the fly)
     is set from dynamic_clover_inverse().
     @tparam Float Underlying storage precision
     @tparam nSpin Number of spin components
     @tparam nColor Number of colors
  */
  template <typename Float, int nSpin, int nColor>
  struct CloverArg {
    static constexpr int length = (nSpin / (nSpin / 2)) * 2 * nColor * nColor * (nSpin / 2) * (nSpin / 2) / 2;
    static constexpr bool dynamic_clover = dynamic_clover_inverse();

    typedef typename colorspinor_mapper<Float, nSpin, nColor>::type F;
    typedef typename clover_mapper<Float, length>::type C;
    typedef typename mapper<Float>::type RegType;

    F out;              // output vector field
    const F in;         // input vector field
    const C clover;     // clover field
    const C cloverInv;  // inverse clover field (only set if not dynamic clover and doing twisted clover)
    const int nParity;  // number of parities we're working on
    const int parity;   // which parity we're acting on (if nParity=1)
    bool inverse;       // whether we are applying the inverse
    bool doublet;       // whether we're applying the operator to a doublet
    const int volumeCB; // checkerboarded volume
    RegType a;
    RegType b;
    RegType c;
    QudaTwistGamma5Type twist;

    CloverArg(ColorSpinorField &out, const ColorSpinorField &in, const CloverField &clover,
              bool inverse, int parity, RegType kappa = 0.0, RegType mu = 0.0, RegType epsilon = 0.0,
              bool dagger = false, QudaTwistGamma5Type twist = QUDA_TWIST_GAMMA5_INVALID)
      : out(out), clover(clover, twist == QUDA_TWIST_GAMMA5_INVALID ? inverse : false),
        cloverInv(clover, (twist != QUDA_TWIST_GAMMA5_INVALID && !dynamic_clover) ? true : false),
        in(in), nParity(in.SiteSubset()), parity(parity), inverse(inverse),
        doublet(in.TwistFlavor() == QUDA_TWIST_DEG_DOUBLET || in.TwistFlavor() == QUDA_TWIST_NONDEG_DOUBLET),
        volumeCB(doublet ? in.VolumeCB() / 2 : in.VolumeCB()), a(0.0), b(0.0), c(0.0), twist(twist)
    {
      checkPrecision(out, in, clover);
      checkLocation(out, in, clover);
      if (in.TwistFlavor() == QUDA_TWIST_SINGLET) {
        if (twist == QUDA_TWIST_GAMMA5_DIRECT) {
          a = 2.0 * kappa * mu;
          b = 1.0;
        } else if (twist == QUDA_TWIST_GAMMA5_INVERSE) {
          a = -2.0 * kappa * mu;
          b = 1.0 / (1.0 + a * a);
        }
        c = 0.0;
        if (dagger) a *= -1.0;
      } else if (doublet) {
        errorQuda("Non-degenerate twisted mass is not supported in this regularization");
      }
    }
  };
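
  /* Worked example of the length constant above for the usual Wilson-clover
     case (nSpin = 4, nColor = 3):

       (4/2) * 2 * 3 * 3 * (4/2) * (4/2) / 2 = 2 * 2 * 9 * 2 * 2 / 2 = 72

     i.e. 72 reals per site: two chiral blocks, each a 6x6 Hermitian matrix
     (6 real diagonal entries + 15 complex off-diagonal entries = 36 reals). */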

  template <typename Float, int nSpin, int nColor, typename Arg>
  __device__ __host__ inline void cloverApply(Arg &arg, int x_cb, int parity)
  {
    using namespace linalg; // for Cholesky
    typedef typename mapper<Float>::type RegType;
    typedef ColorSpinor<RegType, nColor, nSpin> Spinor;
    typedef ColorSpinor<RegType, nColor, nSpin / 2> HalfSpinor;
    int spinor_parity = arg.nParity == 2 ? parity : 0;
    Spinor in = arg.in(x_cb, spinor_parity);
    Spinor out;

    in.toRel(); // change to chiral basis here

#pragma unroll
    for (int chirality = 0; chirality < 2; chirality++) {

      HMatrix<RegType, nColor * nSpin / 2> A = arg.clover(x_cb, parity, chirality);
      HalfSpinor chi = in.chiral_project(chirality);

      if (arg.dynamic_clover) {
        Cholesky<HMatrix, RegType, nColor * nSpin / 2> cholesky(A);
        chi = static_cast<RegType>(0.25) * cholesky.backward(cholesky.forward(chi));
      } else {
        chi = A * chi;
      }

      out += chi.chiral_reconstruct(chirality);
    }

    out.toNonRel(); // change basis back

    arg.out(x_cb, spinor_parity) = out;
  }
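
  /* Sketch of the dynamic-clover branch above: with the Cholesky
     factorization A = L L^dagger, the application of A^{-1} to chi is a
     forward solve (L y = chi) followed by a backward solve
     (L^dagger chi' = y), which replaces loading a precomputed inverse clover
     field; the 0.25 prefactor compensates for the clover-field storage
     normalization (cf. the factor-of-2 comment in twistCloverApply below). */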

  template <typename Float, int nSpin, int nColor, typename Arg>
  void cloverCPU(Arg &arg)
  {
    for (int parity = 0; parity < arg.nParity; parity++) {
      parity = (arg.nParity == 2) ? parity : arg.parity;
      for (int x_cb = 0; x_cb < arg.volumeCB; x_cb++) cloverApply<Float, nSpin, nColor>(arg, x_cb, parity);
    }
  }

  template <typename Float, int nSpin, int nColor, typename Arg>
  __global__ void cloverGPU(Arg arg)
  {
    int x_cb = blockIdx.x * blockDim.x + threadIdx.x;
    int parity = (arg.nParity == 2) ? blockDim.y * blockIdx.y + threadIdx.y : arg.parity;
    if (x_cb >= arg.volumeCB) return;
    cloverApply<Float, nSpin, nColor>(arg, x_cb, parity);
  }

  template <typename Float, int nColor>
  class Clover : public TunableVectorY {

    static constexpr int nSpin = 4;
    CloverArg<Float, nSpin, nColor> arg;
    const ColorSpinorField &meta;

    long long flops() const { return arg.nParity * arg.volumeCB * 504ll; }
    long long bytes() const { return arg.out.Bytes() + arg.in.Bytes() + arg.nParity * arg.volumeCB * arg.clover.Bytes(); }
    bool tuneGridDim() const { return false; }
    unsigned int minThreads() const { return arg.volumeCB; }

  public:
    Clover(ColorSpinorField &out, const ColorSpinorField &in, const CloverField &clover, bool inverse, int parity) :
      TunableVectorY(in.SiteSubset()),
      arg(out, in, clover, inverse, parity),
      meta(in)
    {
      if (in.Nspin() != 4 || out.Nspin() != 4) errorQuda("Unsupported nSpin=%d %d", out.Nspin(), in.Nspin());
      if (!inverse) errorQuda("Unsupported direct application");
      strcpy(aux, meta.AuxString());

      apply(streams[Nstream - 1]);
    }

    void apply(const qudaStream_t &stream)
    {
      TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
      if (meta.Location() == QUDA_CPU_FIELD_LOCATION) {
        cloverCPU<Float, nSpin, nColor>(arg);
      } else {
        qudaLaunchKernel(cloverGPU<Float, nSpin, nColor, decltype(arg)>, tp, stream, arg);
      }
    }

    TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), aux); }
    void preTune() { if (arg.out.field == arg.in.field) arg.out.save(); }  // save the out field if it aliases the in field
    void postTune() { if (arg.out.field == arg.in.field) arg.out.load(); } // restore if the in and out fields alias
  };

  // Apply the clover matrix field to a colorspinor field
  // out(x) = clover * in
  void ApplyClover(ColorSpinorField &out, const ColorSpinorField &in, const CloverField &clover, bool inverse, int parity)
  {
#ifdef GPU_CLOVER_DIRAC
    instantiate<Clover>(out, in, clover, inverse, parity);
#else
    errorQuda("Clover dslash has not been built");
#endif // GPU_CLOVER_DIRAC
  }
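
  /* Illustrative call: only the inverse application is currently supported
     here (the Clover constructor above rejects inverse == false), so a
     typical use is

       ApplyClover(out, in, clover, true, parity); // out = A^{-1} in
  */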

  // if (!inverse) apply (Clover + i*a*gamma_5) to the input spinor
  // else apply (Clover + i*a*gamma_5)/(Clover^2 + a^2) to the input spinor
  template <bool inverse, typename Float, int nSpin, int nColor, typename Arg>
  __device__ __host__ inline void twistCloverApply(Arg &arg, int x_cb, int parity)
  {
    using namespace linalg; // for Cholesky
    constexpr int N = nColor * nSpin / 2;
    typedef typename mapper<Float>::type RegType;
    typedef ColorSpinor<RegType, nColor, nSpin> Spinor;
    typedef ColorSpinor<RegType, nColor, nSpin / 2> HalfSpinor;
    typedef HMatrix<RegType, N> Mat;
    int spinor_parity = arg.nParity == 2 ? parity : 0;
    Spinor in = arg.in(x_cb, spinor_parity);
    Spinor out;

    in.toRel(); // change to chiral basis here

#pragma unroll
    for (int chirality = 0; chirality < 2; chirality++) {
      // factor of 2 comes from clover normalization we need to correct for
      const complex<RegType> j(0.0, chirality == 0 ? static_cast<RegType>(0.5) : -static_cast<RegType>(0.5));

      Mat A = arg.clover(x_cb, parity, chirality);

      HalfSpinor in_chi = in.chiral_project(chirality);
      HalfSpinor out_chi = A * in_chi + j * arg.a * in_chi;

      if (inverse) {
        if (arg.dynamic_clover) {
          Mat A2 = A.square();
          A2 += arg.a * arg.a * static_cast<RegType>(0.25);
          Cholesky<HMatrix, RegType, N> cholesky(A2);
          out_chi = static_cast<RegType>(0.25) * cholesky.backward(cholesky.forward(out_chi));
        } else {
          Mat Ainv = arg.cloverInv(x_cb, parity, chirality);
          out_chi = static_cast<RegType>(2.0) * (Ainv * out_chi);
        }
      }

      out += (out_chi).chiral_reconstruct(chirality);
    }

    out.toNonRel(); // change basis back

    arg.out(x_cb, spinor_parity) = out;
  }
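
  /* Worked form of the dynamic inverse above: per chirality the twisted
     operator is (A +/- i a/2), and since A is Hermitian,

       (A + i a/2)(A - i a/2) = A^2 + (a/2)^2,

     which is exactly the matrix A2 = A.square() + 0.25 * a^2 that the
     Cholesky solve factorizes in the dynamic-clover branch. */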

  template <bool inverse, typename Float, int nSpin, int nColor, typename Arg>
  void twistCloverCPU(Arg &arg)
  {
    for (int parity = 0; parity < arg.nParity; parity++) {
      parity = (arg.nParity == 2) ? parity : arg.parity;
      for (int x_cb = 0; x_cb < arg.volumeCB; x_cb++) twistCloverApply<inverse, Float, nSpin, nColor>(arg, x_cb, parity);
    }
  }

  template <bool inverse, typename Float, int nSpin, int nColor, typename Arg>
  __global__ void twistCloverGPU(Arg arg)
  {
    int x_cb = blockIdx.x * blockDim.x + threadIdx.x;
    int parity = (arg.nParity == 2) ? blockDim.y * blockIdx.y + threadIdx.y : arg.parity;
    if (x_cb >= arg.volumeCB) return;
    twistCloverApply<inverse, Float, nSpin, nColor>(arg, x_cb, parity);
  }

  template <typename Float, int nColor>
  class TwistClover : public TunableVectorY {

    static constexpr int nSpin = 4;
    CloverArg<Float, nSpin, nColor> arg;
    const ColorSpinorField &meta;

    long long flops() const { return (arg.inverse ? 1056ll : 552ll) * arg.nParity * arg.volumeCB; }
    long long bytes() const
    {
      long long rtn = arg.out.Bytes() + arg.in.Bytes() + arg.nParity * arg.volumeCB * arg.clover.Bytes();
      if (arg.twist == QUDA_TWIST_GAMMA5_INVERSE && !arg.dynamic_clover)
        rtn += arg.nParity * arg.volumeCB * arg.cloverInv.Bytes();
      return rtn;
    }
    bool tuneGridDim() const { return false; }
    unsigned int minThreads() const { return arg.volumeCB; }

  public:
    TwistClover(ColorSpinorField &out, const ColorSpinorField &in, const CloverField &clover,
                double kappa, double mu, double epsilon, int parity, int dagger, QudaTwistGamma5Type twist) :
      TunableVectorY(in.SiteSubset()),
      arg(out, in, clover, twist != QUDA_TWIST_GAMMA5_DIRECT, parity, kappa, mu, epsilon, dagger, twist),
      meta(in)
    {
      if (in.Nspin() != 4 || out.Nspin() != 4) errorQuda("Unsupported nSpin=%d %d", out.Nspin(), in.Nspin());
      strcpy(aux, meta.AuxString());
      strcat(aux, arg.inverse ? ",inverse" : ",direct");

      apply(streams[Nstream - 1]);
    }

    void apply(const qudaStream_t &stream)
    {
      TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
      if (meta.Location() == QUDA_CPU_FIELD_LOCATION) {
        if (arg.inverse) twistCloverCPU<true, Float, nSpin, nColor>(arg);
        else twistCloverCPU<false, Float, nSpin, nColor>(arg);
      } else {
        if (arg.inverse) qudaLaunchKernel(twistCloverGPU<true, Float, nSpin, nColor, decltype(arg)>, tp, stream, arg);
        else qudaLaunchKernel(twistCloverGPU<false, Float, nSpin, nColor, decltype(arg)>, tp, stream, arg);
      }
    }

    TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), aux); }
    void preTune() { if (arg.out.field == arg.in.field) arg.out.save(); }  // save the out field if it aliases the in field
    void postTune() { if (arg.out.field == arg.in.field) arg.out.load(); } // restore if the in and out fields alias
  };

  // Apply the twisted-clover matrix field to a colorspinor field
  void ApplyTwistClover(ColorSpinorField &out, const ColorSpinorField &in, const CloverField &clover,
                        double kappa, double mu, double epsilon, int parity, int dagger, QudaTwistGamma5Type twist)
  {
#ifdef GPU_CLOVER_DIRAC
    instantiate<TwistClover>(out, in, clover, kappa, mu, epsilon, parity, dagger, twist);
#else
    errorQuda("Twisted-clover dslash has not been built");
#endif // GPU_CLOVER_DIRAC
  }
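
  /* Illustrative call (twist parameters are placeholders): applying the
     inverse twisted-clover term,

       ApplyTwistClover(out, in, clover, kappa, mu, 0.0, parity, 0,
                        QUDA_TWIST_GAMMA5_INVERSE);

     computes out = (A + i a gamma_5 / 2) / (A^2 + (a/2)^2) * in per
     chirality, with a = -2 * kappa * mu as set up in CloverArg. */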

} // namespace quda