mdw_fused_dslash.cu
1 #include <gauge_field.h>
2 #include <gauge_field_order.h>
3 
4 #include <typeinfo>
5 
6 #include <color_spinor_field.h>
7 #include <tune_quda.h>
8 #include <dslash_quda.h>
9 #include <jitify_helper.cuh>
10 #include <instantiate_dslash.h>
11 
12 #if (CUDA_VERSION >= 10010 && __COMPUTE_CAPABILITY__ >= 700)
13 #include <mdw_dslash5_tensor_core.cuh>
14 #endif
15 
16 namespace quda
17 {
18  namespace mobius_tensor_core
19  {
20 
21 #if (CUDA_VERSION >= 10010 && __COMPUTE_CAPABILITY__ >= 700)
22 
23  constexpr int sm_m_pad_size(int m)
24  {
25  return quda::mma::pad_size(m);
26  }
27 
28  constexpr int sm_n_pad_size(int n)
29  {
30  return quda::mma::pad_size(n);
31  }
32 
33  /**
34  @brief Parameter structure for applying the Dslash
35  */
36  template <class storage_type_, QudaReconstructType recon_,
37  int Ls_> // storage_type is the usual "Float" in other places in QUDA
38  struct FusedDslashArg {
39  using storage_type = storage_type_;
40  using real = typename mapper<storage_type>::type; // the compute type for the in-kernel computation
41  static constexpr QudaReconstructType recon = recon_;
42  static constexpr int Ls = Ls_;
43  static constexpr bool spin_project = true;
44  static constexpr bool spinor_direct_load = true; // false means texture load
45 #ifdef FLOAT8
46  using F
47  = colorspinor::FloatNOrder<storage_type, 4, 3, 8, spin_project, spinor_direct_load>; // color-spinor field order
48 #else
49  using F
50  = colorspinor::FloatNOrder<storage_type, 4, 3, 4, spin_project, spinor_direct_load>; // color-spinor field order
51 #endif
52  static constexpr bool gauge_direct_load = true; // false means texture load
53  static constexpr QudaGhostExchange ghost = QUDA_GHOST_EXCHANGE_EXTENDED; // gauge field used is an extended one
54  using G = typename gauge_mapper<storage_type, recon, 18, QUDA_STAGGERED_PHASE_NO, gauge_direct_load, ghost>::type; // gauge field order
55 
56  F out; // output vector field
57  const F in; // input vector field
58  F y; // auxiliary output vector field
59  const F x; // auxiliary input vector field
60 
61  const G U; // The gauge field
62 
63  const int nParity; // number of parities we're working on
64  const int parity; // output parity of this dslash operator
65  const int volume_cb; // checkerboarded volume
66  const int volume_4d_cb; // 4-d checkerboarded volume
67 
68  const int dim[4];
69  const int shift[4]; // shifts defining the sites on which we actually calculate.
70  const int halo_shift[4]; // the halo is zero: when expanding, the color-spinor field has a halo region whose values are zero.
71 
72  const int_fastdiv shrinked_dim[4]; // dimensions after the shifts are applied.
73 
74  // partial kernel and expansion parameters
75  const int volume_4d_cb_shift; // number of 4-d sites we need to calculate
76  // const int volume_4d_cb_expansive; //
77 
78  const real m_f; // fermion mass parameter
79  const real m_5; // Wilson mass shift
80 
81  const bool dagger; // dagger
82  // const bool xpay; // whether we are doing xpay or not
83 
84  real b; // real constant Mobius coefficient
85  real c; // real constant Mobius coefficient
86  real a; // real xpay coefficient
87 
88  real kappa;
89  real fac_inv;
90 
91  // (beta + alpha*m5inv) * in
92  real alpha = 1.;
93  real beta = 0.;
94 
95  real m_scale = 1.; // scale factor for the matrix
96 
97  bool small_kappa = false;
98 
99  const bool comm[4];
100 
101  MdwfFusedDslashType type;
102  FusedDslashArg(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, ColorSpinorField &y,
103  const ColorSpinorField &x, double m_f_, double m_5_, const Complex *b_5, const Complex *c_5,
104  bool dagger_, int parity, int shift_[4], int halo_shift_[4], MdwfFusedDslashType type_) :
105  out(out),
106  in(in),
107  U(U),
108  y(y),
109  x(x),
110  nParity(in.SiteSubset()),
111  parity(parity),
112  volume_cb(in.VolumeCB() > out.VolumeCB() ? in.VolumeCB() : out.VolumeCB()),
113  volume_4d_cb(volume_cb / Ls_),
114  m_f(m_f_),
115  m_5(m_5_),
116  dagger(dagger_),
117  shift {shift_[0], shift_[1], shift_[2], shift_[3]},
118  halo_shift {halo_shift_[0], halo_shift_[1], halo_shift_[2], halo_shift_[3]},
119  dim {(3 - nParity) * (in.VolumeCB() > out.VolumeCB() ? in.X(0) : out.X(0)),
120  in.VolumeCB() > out.VolumeCB() ? in.X(1) : out.X(1), in.VolumeCB() > out.VolumeCB() ? in.X(2) : out.X(2),
121  in.VolumeCB() > out.VolumeCB() ? in.X(3) : out.X(3)},
122  shrinked_dim {dim[0] - 2 * shift[0], dim[1] - 2 * shift[1], dim[2] - 2 * shift[2], dim[3] - 2 * shift[3]},
123  volume_4d_cb_shift(shrinked_dim[0] * shrinked_dim[1] * shrinked_dim[2] * shrinked_dim[3] / 2),
124  type(type_),
125  comm {static_cast<bool>(comm_dim_partitioned(0)), static_cast<bool>(comm_dim_partitioned(1)),
126  static_cast<bool>(comm_dim_partitioned(2)), static_cast<bool>(comm_dim_partitioned(3))}
127  {
128  if (in.Nspin() != 4) { errorQuda("nSpin = %d NOT supported.\n", in.Nspin()); }
129 
130  if (nParity == 2) { errorQuda("nParity = 2 NOT supported, yet.\n"); }
131 
132  if (b_5[0] != b_5[1] || b_5[0].imag() != 0) { errorQuda("zMobius is NOT supported yet.\n"); }
133 
134  b = b_5[0].real();
135  c = c_5[0].real();
136  kappa = -(c * (4. + m_5) - 1.) / (b * (4. + m_5) + 1.); // This is actually -kappa in my (Jiqun Tu) notes.
137 
138  if (kappa * kappa < 1e-6) { small_kappa = true; }
139 
140  fac_inv
141  = 0.5 / (1. + std::pow(kappa, (int)Ls) * m_f); // 0.5 to normalize the (1 +/- gamma5) in the chiral projector.
142  switch (type) {
143  case MdwfFusedDslashType::D4_D5INV_D5PRE:
144  case MdwfFusedDslashType::D4DAG_D5PREDAG_D5INVDAG:
145  if (small_kappa) {
146  m_scale = b;
147  alpha = (c - b * kappa) / (2. * b);
148  beta = 1.;
149  } else {
150  m_scale = b + c / kappa;
151  alpha = 1.;
152  beta = -1. / (1. + (kappa * b) / c);
153  }
154  break;
155  case MdwfFusedDslashType::D4_D5INV_D5INVDAG:
156  m_scale = -0.25 / ((b * (4. + m_5) + 1.) * (b * (4. + m_5) + 1.)); // -kappa_b^2
157  break;
158  case MdwfFusedDslashType::D4DAG_D5PREDAG:
159  m_scale = -0.25 / ((b * (4. + m_5) + 1.) * (b * (4. + m_5) + 1.)) * b; // -kappa_b^2 * b
160  alpha = c / (2. * b); // 2 to compensate for the spin projection
161  beta = 1.;
162  break;
163  case MdwfFusedDslashType::D5PRE:
164  m_scale = b;
165  alpha = c / (2. * b);
166  beta = 1.;
167  break;
168  default: errorQuda("Unknown MdwfFusedDslashType");
169  }
170  }
171  };
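  /*
     Illustrative only: a minimal host-side sketch of the coefficient setup in the constructor above,
     assuming the same real (non-zMobius) conventions. The helper name check_mobius_coefficients is
     hypothetical and not part of QUDA.

       #include <cmath>
       #include <cstdio>

       void check_mobius_coefficients(double b, double c, double m_5, double m_f, int Ls)
       {
         // -kappa, as used above (note the sign convention mentioned in the constructor comment)
         double kappa = -(c * (4. + m_5) - 1.) / (b * (4. + m_5) + 1.);
         // 0.5 normalizes the (1 +/- gamma5) factor in the chiral projector
         double fac_inv = 0.5 / (1. + std::pow(kappa, Ls) * m_f);
         printf("kappa = %e, fac_inv = %e, small_kappa = %d\n", kappa, fac_inv, kappa * kappa < 1e-6);
       }
  */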
172 
173  __device__ inline int index_4d_cb_from_coordinate_4d(const int coordinate[4], const int dim[4])
174  {
175  return (((coordinate[3] * dim[2] + coordinate[2]) * dim[1] + coordinate[1]) * dim[0] + coordinate[0]) / 2;
176  }
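  /*
     Worked example for the mapping above (it follows directly from the formula): with dim = {4, 4, 4, 4}
     and coordinate = {1, 2, 3, 0}, the lexicographical site index is ((0 * 4 + 3) * 4 + 2) * 4 + 1 = 57,
     so the returned 4-d checkerboard index is 57 / 2 = 28.
  */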
177 
178  __device__ inline bool is_halo_4d(const int coordinate[4], const int dim[4], const int halo_shift[4])
179  {
180  bool ret = false;
181 #pragma unroll
182  for (int d = 0; d < 4; d++) {
183  ret = ret or (coordinate[d] >= dim[d] - halo_shift[d] or coordinate[d] < halo_shift[d]);
184  }
185  return ret;
186  }
187 
188  __device__ inline int index_from_extended_coordinate(const int x[4], const int dim[4], const bool comm[4], const int y)
189  {
190  constexpr int pad = 2;
191  int back_x[4];
192  int back_dim[4];
193 
194 #pragma unroll
195  for (int d = 0; d < 4; d++) {
196  back_x[d] = comm[d] ? x[d] - pad : x[d];
197  back_dim[d] = comm[d] ? dim[d] - pad * 2 : dim[d];
198  }
199 
200  bool is_center = true;
201 #pragma unroll
202  for (int d = 0; d < 4; d++) { is_center = is_center && (back_x[d] >= 0 && back_x[d] < back_dim[d]); }
203 
204  if (is_center) {
205  int volume_4d_cb_back = back_dim[0] * back_dim[1] * back_dim[2] * back_dim[3] / 2;
206  return y * volume_4d_cb_back
207  + index_4d_cb_from_coordinate_4d(back_x, back_dim); // the input coordinate is in the center region
208  } else {
209  return -1;
210  }
211  }
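  /*
     Worked example for the mapping above, assuming (for illustration) comm = {true, false, false, false}
     and an extended dim = {8, 4, 4, 4}: only the partitioned direction is padded, so back_dim = {4, 4, 4, 4}
     and volume_4d_cb_back = 4 * 4 * 4 * 4 / 2 = 128. An extended coordinate x = {3, 1, 2, 0} maps to
     back_x = {1, 1, 2, 0}, which lies in the center region, so the function returns
     y * 128 + index_4d_cb_from_coordinate_4d(back_x, back_dim); any coordinate with x[0] < 2 or x[0] >= 6
     falls outside the center region and returns -1.
  */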
212 
213  /**
214  Everything should be understood in a 4-d checkerboarding sense.
215  */
216  template <class storage_type, bool dagger, bool halo, bool back, class Vector, class Arg>
217  __device__ inline void apply_wilson_5d(Vector &out, int coordinate[4], Arg &arg, int s)
218  {
219  typedef typename mapper<storage_type>::type compute_type;
220  typedef Matrix<complex<compute_type>, 3> Link;
221  const int their_spinor_parity = arg.nParity == 2 ? 1 - arg.parity : 0;
222 
223  const int index_4d_cb = index_4d_cb_from_coordinate_4d(coordinate, arg.dim);
224 
225 #pragma unroll
226  for (int d = 0; d < 4; d++) // loop over dimension
227  {
228  int x[4] = {coordinate[0], coordinate[1], coordinate[2], coordinate[3]};
229  x[d] = (coordinate[d] == arg.dim[d] - 1 && !arg.comm[d]) ? 0 : coordinate[d] + 1;
230  if (!halo || !is_halo_4d(x, arg.dim, arg.halo_shift)) {
231  // Forward gather - compute fwd offset for vector fetch
232  int fwd_idx;
233  if (back) {
234  fwd_idx = index_from_extended_coordinate(x, arg.dim, arg.comm, s);
235  } else {
236  fwd_idx = s * arg.volume_4d_cb + index_4d_cb_from_coordinate_4d(x, arg.dim);
237  }
238  constexpr int proj_dir = dagger ? +1 : -1;
239 
240  const Link U = arg.U(d, index_4d_cb, arg.parity);
241  const Vector in = arg.in(fwd_idx, their_spinor_parity);
242  out += (U * in.project(d, proj_dir)).reconstruct(d, proj_dir);
243  }
244  x[d] = (coordinate[d] == 0 && !arg.comm[d]) ? arg.dim[d] - 1 : coordinate[d] - 1;
245  if (!halo || !is_halo_4d(x, arg.dim, arg.halo_shift)) {
246  // Backward gather - compute back offset for spinor and gauge fetch
247  const int gauge_idx = index_4d_cb_from_coordinate_4d(x, arg.dim);
248 
249  int back_idx;
250  if (back) {
251  back_idx = index_from_extended_coordinate(x, arg.dim, arg.comm, s);
252  } else {
253  back_idx = s * arg.volume_4d_cb + gauge_idx;
254  }
255  constexpr int proj_dir = dagger ? -1 : +1;
256 
257  const Link U = arg.U(d, gauge_idx, 1 - arg.parity);
258  const Vector in = arg.in(back_idx, their_spinor_parity);
259  out += (conj(U) * in.project(d, proj_dir)).reconstruct(d, proj_dir);
260  }
261  } // nDim
262  }
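  /*
     Note on the projector convention used above (assuming QUDA's convention that project(d, sign)
     projects onto (1 + sign * gamma_d)): for the non-dagger case the forward hop applies
     U_d(x) (1 - gamma_d) psi(x + d) and the backward hop applies U_d(x - d)^dagger (1 + gamma_d) psi(x - d);
     the dagger case flips both projector signs.
  */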
263 
264  /**
265  Everything should be understood in a 4-d checkerboarding sense.
266  Given an index in the shrinked block, calculate the coordinate in the shrinked block,
267  then shift the coordinate to the un-shrinked coordinate, e.g. (0,0,4,1) -> (2,2,6,3) with shift = (2,2,2,2).
268  */
269  template <class T>
270  __device__ inline void coordinate_from_shrinked_index(int coordinate[4], int shrinked_index,
271  const T shrinked_dim[4], const int shift[4], int parity)
272  {
273  int aux[4];
274  aux[0] = shrinked_index * 2;
275 
276 #pragma unroll
277  for (int i = 0; i < 3; i++) { aux[i + 1] = aux[i] / shrinked_dim[i]; }
278 
279  coordinate[0] = aux[0] - aux[1] * shrinked_dim[0];
280  coordinate[1] = aux[1] - aux[2] * shrinked_dim[1];
281  coordinate[2] = aux[2] - aux[3] * shrinked_dim[2];
282  coordinate[3] = aux[3];
283 
284  // Find the full coordinate in the shrinked volume.
285  coordinate[0]
286  += (shift[0] + shift[1] + shift[2] + shift[3] + parity + coordinate[3] + coordinate[2] + coordinate[1]) & 1;
287 
288 // Now go back to the extended volume.
289 #pragma unroll
290  for (int d = 0; d < 4; d++) { coordinate[d] += shift[d]; }
291  }
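  /*
     Worked example, assuming (for illustration) shrinked_dim = {4, 4, 8, 4}, shift = {2, 2, 2, 2},
     parity = 1 and shrinked_index = 96: aux = {192, 48, 12, 1} gives the shrinked coordinate
     (0, 0, 4, 1); the parity term (2 + 2 + 2 + 2 + 1 + 1 + 4 + 0) & 1 = 0 leaves coordinate[0]
     unchanged, and adding the shift yields the un-shrinked coordinate (2, 2, 6, 3), matching the
     example in the comment above.
  */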
292 
293  /**
294  @brief Tensor core kernel for applying the Wilson hopping term and then the beta + alpha * M5inv operator.
295  The integer kernel type corresponds to the enum MdwfFusedDslashType: 0 = D4_D5INV_D5PRE, 1 = D4_D5INV_D5INVDAG, 2 = D4DAG_D5PREDAG_D5INVDAG, 3 = D4DAG_D5PREDAG, 4 = D5PRE.
296  */
297  template <int block_dim_x, int minBlocksPerMultiprocessor, bool reload, class Arg, int type>
298  __global__ void __launch_bounds__(block_dim_x *Arg::Ls, minBlocksPerMultiprocessor) fused_tensor_core(Arg arg)
299  {
300  using storage_type = typename Arg::storage_type;
301  using real = typename mapper<storage_type>::type;
302  using Vector = ColorSpinor<real, 3, 4>;
303  constexpr int Ls = Arg::Ls;
304  const int explicit_parity = arg.nParity == 2 ? arg.parity : 0;
305 
306  TensorCoreSharedMemory<float> shared_memory_data;
307 
308  static_assert(block_dim_x * Ls / 32 < 32, "Number of threads in a threadblock should be less than 1024.");
309 
310  constexpr int M = 4 * Ls;
311  constexpr int N = 6 * block_dim_x;
312 
313  constexpr int N_sm = N + sm_n_pad_size(N);
314  constexpr int M_sm = M + sm_m_pad_size(M);
315 
316  float *smem_scale = shared_memory_data;
317 
318  half2 *sm_b = reinterpret_cast<half2 *>(smem_scale + 32);
319  half *sm_c = reinterpret_cast<half *>(sm_b);
320 
321  half *sm_a = reload ? sm_c + M * N_sm : sm_c;
322  // This is for type == 1 ONLY.
323  half *sm_a_black = sm_a + M * M_sm;
324 
325  if (type == 0) {
326  if (arg.small_kappa) {
327  construct_matrix_a_d5<block_dim_x, Ls, M_sm, false, Arg>(arg, sm_a); // dagger = false
328  } else {
329  construct_matrix_a_m5inv<block_dim_x, Ls, M_sm, false, Arg>(arg, sm_a); // dagger = false
330  }
331  } else if (type == 2) {
332  if (arg.small_kappa) {
333  construct_matrix_a_d5<block_dim_x, Ls, M_sm, true, Arg>(arg, sm_a); // dagger = true
334  } else {
335  construct_matrix_a_m5inv<block_dim_x, Ls, M_sm, true, Arg>(arg, sm_a); // dagger = true
336  }
337  } else if (type == 1) {
338  construct_matrix_a_m5inv<block_dim_x, Ls, M_sm, false, Arg>(arg, sm_a); // dagger = false
339  } else if (type == 3) {
340  construct_matrix_a_d5<block_dim_x, Ls, M_sm, true, Arg>(arg, sm_a); // dagger = true
341  } else if (type == 4) {
342  construct_matrix_a_d5<block_dim_x, Ls, M_sm, false, Arg>(arg, sm_a); // dagger = false
343  }
344  __syncthreads();
345 
346  bool idle = false;
347  int s4_shift_base = blockIdx.x * blockDim.x; // base.
348  int s4_shift, sid;
349 
350  constexpr int tm_dim = M / mma::MMA_M;
351  constexpr int tn_dim = N / mma::MMA_N;
352  constexpr int tk_dim = M / mma::MMA_K;
353 
354  constexpr int total_warp = block_dim_x * Ls >> 5;
355  const int this_warp = (threadIdx.y * block_dim_x + threadIdx.x) >> 5;
356 
357  constexpr int total_tile = tm_dim * tn_dim;
358 
359  constexpr int warp_cycle = total_tile / total_warp;
360  const int warp_m = this_warp * warp_cycle / tn_dim;
361 
362  mma::WarpRegisterMapping wrm(threadIdx.y * blockDim.x + threadIdx.x);
363  mma::MmaOperandA op_a[reload ? 1 : tk_dim];
364  mma::MmaOperandA op_a_aux[reload ? 1 : tk_dim];
365  if (!reload) { // the data in registers can be reused.
366 #pragma unroll
367  for (int tile_k = 0; tile_k < tk_dim; tile_k++) { op_a[tile_k].template load<M_sm>(sm_a, tile_k, warp_m, wrm); }
368  }
369 
370  if (type == 1) {
371  arg.alpha = 1.;
372  if (!reload) { // when not reloading, construct the dagger matrix once and keep it in registers (op_a_aux)
373  construct_matrix_a_m5inv<block_dim_x, Ls, M_sm, true, Arg>(arg, sm_a); // dagger = true
374  __syncthreads();
375 
376 #pragma unroll
377  for (int tile_k = 0; tile_k < tk_dim; tile_k++) {
378  op_a_aux[tile_k].template load<M_sm>(sm_a, tile_k, warp_m, wrm);
379  }
380 
381  } else {
382  construct_matrix_a_m5inv<block_dim_x, Ls, M_sm, true, Arg>(arg, sm_a_black); // dagger = true
383  __syncthreads();
384  }
385  }
386 
387  while (s4_shift_base < arg.volume_4d_cb_shift) {
388  int x[4];
389  s4_shift = s4_shift_base + threadIdx.x;
390  coordinate_from_shrinked_index(x, s4_shift, arg.shrinked_dim, arg.shift, arg.parity);
391  sid = threadIdx.y * arg.volume_4d_cb + index_4d_cb_from_coordinate_4d(x, arg.dim);
392 
393  if (s4_shift >= arg.volume_4d_cb_shift) { idle = true; }
394 
395  Vector in_vec;
396  if (!idle) {
397  // the Wilson hopping terms
398  if (type == 0) {
399  apply_wilson_5d<storage_type, false, true, true>(in_vec, x, arg, threadIdx.y); // dagger = false; halo = true
400  } else if (type == 2) {
401  apply_wilson_5d<storage_type, true, false, false>(in_vec, x, arg,
402  threadIdx.y); // dagger = true; halo = false
403  } else if (type == 1) {
404  apply_wilson_5d<storage_type, false, true, false>(in_vec, x, arg, threadIdx.y); // dagger = false; halo = true
405  } else if (type == 3) {
406  apply_wilson_5d<storage_type, true, false, false>(in_vec, x, arg,
407  threadIdx.y); // dagger = true; halo = false
408  } else if (type == 4) {
409  int sid_shift = threadIdx.y * arg.volume_4d_cb_shift + s4_shift;
410  in_vec = arg.in(sid_shift, explicit_parity);
411  }
412  // store result to shared memory
413  }
414  load_matrix_b_vector<block_dim_x, Ls, N_sm / 2, false>(in_vec, sm_b, smem_scale); // acc(accumulation) = false
415 
416  __syncthreads();
417  mma_sync_gemm<block_dim_x, Ls, M, N, M_sm, N_sm, reload>(op_a, sm_a, sm_c, sm_c, wrm);
418  __syncthreads();
419 
420  if (type == 1) {
421  Vector aux_in_vec;
422  int sid_back;
423  bool center = false;
424  if (!idle) {
425  sid_back = index_from_extended_coordinate(x, arg.dim, arg.comm, threadIdx.y);
426  if (sid_back >= 0) {
427  center = true;
428  aux_in_vec = arg.x(sid_back, explicit_parity);
429  }
430  }
431  load_matrix_b_vector<block_dim_x, Ls, N_sm / 2, true>(aux_in_vec, sm_b, smem_scale, arg.m_scale); // acc = true
432  if (!idle && center) { store_matrix_c<storage_type, N_sm>(arg.y, sm_b, sid_back, smem_scale[0]); }
433  __syncthreads();
434  mma_sync_gemm<block_dim_x, Ls, M, N, M_sm, N_sm, reload>(op_a_aux, sm_a_black, sm_c, sm_c, wrm);
435  __syncthreads();
436 
437  } else if (type == 3) {
438  Vector aux_in_vec;
439  int sid_shift = threadIdx.y * arg.volume_4d_cb_shift + s4_shift;
440  if (!idle) { aux_in_vec = arg.x(sid_shift, explicit_parity); }
441  load_matrix_b_vector<block_dim_x, Ls, N_sm / 2, true, false>(aux_in_vec, sm_b, smem_scale, arg.m_scale);
442  if (!idle) { arg.out(sid_shift, explicit_parity) = aux_in_vec; }
443  }
444 
445  if (type == 3) {
446 
447  } else if (type == 1) {
448  if (!idle) { store_matrix_c<storage_type, N_sm>(arg.out, sm_b, sid, smem_scale[0]); }
449  } else {
450  if (!idle) { store_matrix_c<storage_type, N_sm>(arg.out, sm_b, sid, smem_scale[0] * arg.m_scale); }
451  }
452 
453  s4_shift_base += gridDim.x * blockDim.x;
454 
455  } // while
456  }
457 
458  template <class Arg> class FusedDslash : public Tunable
459  {
460 
461  protected:
462  Arg &arg;
463  const ColorSpinorField &meta;
464 
465  /** Whether to use variable or fixed coefficient algorithm. Must be true if using ZMOBIUS */
466  static constexpr bool var_inverse = true;
467 
468  long long flops() const
469  {
470  constexpr long long hop = 7ll * 8ll;
471  constexpr long long mat = 2ll * 4ll * Arg::Ls - 1ll;
472  long long volume_4d_cb_halo_shift = (arg.dim[0] - 2 * arg.halo_shift[0]) * (arg.dim[1] - 2 * arg.halo_shift[1])
473  * (arg.dim[2] - 2 * arg.halo_shift[2]) * (arg.dim[3] - 2 * arg.halo_shift[3]) / 2;
474 
475  long long flops_ = 0;
476  switch (arg.type) {
477  case MdwfFusedDslashType::D4_D5INV_D5PRE:
478  flops_ = volume_4d_cb_halo_shift * 6ll * 4ll * arg.Ls * hop + arg.volume_4d_cb_shift * 24ll * arg.Ls * mat;
479  break;
480  case MdwfFusedDslashType::D4_D5INV_D5INVDAG:
481  flops_
482  = volume_4d_cb_halo_shift * 6ll * 4ll * arg.Ls * hop + arg.volume_4d_cb_shift * 24ll * arg.Ls * 2ll * mat;
483  break;
484  case MdwfFusedDslashType::D4DAG_D5PREDAG_D5INVDAG:
485  case MdwfFusedDslashType::D4DAG_D5PREDAG:
486  flops_ = arg.volume_4d_cb_shift * 6ll * 4ll * arg.Ls
487  * (hop + mat); // for types 2 and 3 we don't have the halo complication.
488  break;
489  case MdwfFusedDslashType::D5PRE: flops_ = arg.volume_4d_cb_shift * 6ll * 4ll * arg.Ls * (mat); break;
490  default: errorQuda("Unknown MdwfFusedDslashType");
491  }
492 
493  return flops_;
494  }
495 
496  long long bytes() const
497  {
498  auto site_size = arg.Ls * (2ll * meta.Nspin() * meta.Ncolor() * meta.Precision() + sizeof(float));
499  auto dim = arg.dim;
500  auto b_m0 = ((dim[0] - 0) * (dim[1] - 0) * (dim[2] - 0) * (dim[3] - 0) / 2) * site_size;
501  auto b_m1 = ((dim[0] - 1) * (dim[1] - 1) * (dim[2] - 1) * (dim[3] - 1) / 2) * site_size;
502  auto b_m2 = ((dim[0] - 2) * (dim[1] - 2) * (dim[2] - 2) * (dim[3] - 2) / 2) * site_size;
503  switch (arg.type) {
504  case MdwfFusedDslashType::D4_D5INV_D5PRE: return b_m1 + b_m2 + arg.U.Bytes();
505  case MdwfFusedDslashType::D4_D5INV_D5INVDAG: return 2 * b_m2 + b_m1 + b_m0 + arg.U.Bytes();
506  case MdwfFusedDslashType::D4DAG_D5PREDAG_D5INVDAG: return b_m1 + b_m0 + arg.U.Bytes();
507  case MdwfFusedDslashType::D4DAG_D5PREDAG: return 2 * b_m2 + b_m1 + arg.U.Bytes();
508  case MdwfFusedDslashType::D5PRE: return 2 * b_m0;
509  default: errorQuda("Unknown MdwfFusedDslashType");
510  }
511  return 0ll;
512  }
513 
514  bool tuneAuxDim() const { return true; }
515 
516  int blockStep() const { return 16; }
517  int blockMin() const { return 16; }
518  unsigned int maxBlockSize(const TuneParam &param) const { return 32; }
519 
520  int gridStep() const { return deviceProp.multiProcessorCount; }
521  unsigned int maxGridSize() const { return (arg.volume_4d_cb_shift + blockMin() - 1) / blockMin(); }
522  unsigned int minGridSize() const { return deviceProp.multiProcessorCount; }
523 
524  unsigned int sharedBytesPerBlock(const TuneParam &param) const
525  {
526  const int a_size = (param.block.y * 4) * (param.block.y * 4 + sm_m_pad_size(param.block.y * 4));
527  const int b_size = (param.block.y * 4) * (param.block.x * 6 + sm_n_pad_size(param.block.x * 6));
528  // (Ls*4) by (Ls*4), (Ls*4) by (volume_4d*6 + 16)
529  if (param.aux.x == 1) { // aux.x == 1 --> reload == true
530  if (arg.type == MdwfFusedDslashType::D4_D5INV_D5INVDAG) {
531  return (a_size * 2 + b_size) * sizeof(half) + 128;
532  } else {
533  return (a_size + b_size) * sizeof(half) + 128;
534  }
535  } else {
536  return (a_size > b_size ? a_size : b_size) * sizeof(half) + 128;
537  }
538  }
539 
540  unsigned int sharedBytesPerThread() const { return 0; }
541 
542  bool advanceAux(TuneParam &param) const
543  {
544  bool aux_advanced = false;
545  if (param.aux.x == 0) { // first advance aux.x (only 0 (false) or 1 (true))
546  param.aux.x++;
547  aux_advanced = true;
548  } else {
549  if (param.aux.y < 3) { // then advance aux.y
550  param.aux.y++;
551  aux_advanced = true;
552  param.aux.x = 0;
553  }
554  }
555  // shared bytes depend on aux, so update if changed
556  if (aux_advanced) param.shared_bytes = sharedBytesPerBlock(param);
557  return aux_advanced;
558  }
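  /*
     The aux dimension enumerates the (reload, minBlocksPerMultiprocessor) pairs used by apply() below:
     starting from (aux.x, aux.y) = (0, 1) as set in initTuneParam, successive calls to advanceAux step
     through (0,1) -> (1,1) -> (0,2) -> (1,2) -> (0,3) -> (1,3) and then return false, i.e. both
     reload = false and reload = true are tried for each of minBlocksPerMultiprocessor = 1, 2, 3.
  */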
559 
560  // overloaded to return max dynamic shared memory if doing shared-memory inverse
561  unsigned int maxSharedBytesPerBlock() const { return maxDynamicSharedBytesPerBlock(); }
562 
563  public:
564  FusedDslash(Arg &arg, const ColorSpinorField &meta) : arg(arg), meta(meta)
565  {
566  strcpy(aux, meta.AuxString());
567  if (arg.dagger) strcat(aux, ",Dagger");
568  char config[512];
569  switch (arg.type) {
570  case MdwfFusedDslashType::D4_D5INV_D5PRE: sprintf(config, ",f0"); break;
571  case MdwfFusedDslashType::D4DAG_D5PREDAG_D5INVDAG: sprintf(config, ",f2"); break;
572  case MdwfFusedDslashType::D4_D5INV_D5INVDAG: sprintf(config, ",f1"); break;
573  case MdwfFusedDslashType::D4DAG_D5PREDAG: sprintf(config, ",f3"); break;
574  case MdwfFusedDslashType::D5PRE: sprintf(config, ",f4"); break;
575  default: errorQuda("Unknown MdwfFusedDslashType");
576  }
577  strcat(aux, config);
578  sprintf(config, "shift%d%d%d%d,halo%d%d%d%d,comm%d%d%d%d", arg.shift[0], arg.shift[1], arg.shift[2],
579  arg.shift[3], arg.halo_shift[0], arg.halo_shift[1], arg.halo_shift[2], arg.halo_shift[3], arg.comm[0],
580  arg.comm[1], arg.comm[2], arg.comm[3]);
581  strcat(aux, config);
582  }
583 
584  template <typename T> inline void launch(T *f, const TuneParam &tp, Arg &arg, const qudaStream_t &stream)
585  {
586  const_cast<TuneParam &>(tp).set_max_shared_bytes = true;
587  qudaLaunchKernel(f, tp, stream, arg);
588  }
589 
590  // The following apply<...> functions are used to turn the tune parameters into template arguments.
591  // Specifically tp.aux.y dictates the minBlocksPerMultiprocessor in __launch_bounds__(..).
592  // tp.aux.x dictates whether or not to reload.
593  template <int block_dim_x, bool reload, int type>
594  void apply(const TuneParam &tp, Arg &arg, const qudaStream_t &stream)
595  {
596  switch (tp.aux.y) {
597  case 1: launch(fused_tensor_core<block_dim_x, 1, reload, Arg, type>, tp, arg, stream); break;
598  case 2: launch(fused_tensor_core<block_dim_x, 2, reload, Arg, type>, tp, arg, stream); break;
599  case 3: launch(fused_tensor_core<block_dim_x, 3, reload, Arg, type>, tp, arg, stream); break;
600  default: errorQuda("NOT valid tp.aux.y(=%d)\n", tp.aux.y);
601  }
602  }
603 
604  template <bool reload, int type> void apply(const TuneParam &tp, Arg &arg, const qudaStream_t &stream)
605  {
606  switch (tp.block.x) {
607  case 16: apply<16, reload, type>(tp, arg, stream); break;
608  case 32: apply<32, reload, type>(tp, arg, stream); break;
609  default: errorQuda("NOT valid tp.block.x(=%d)\n", tp.block.x);
610  }
611  }
612 
613  template <int type> void apply(const TuneParam &tp, Arg &arg, const qudaStream_t &stream)
614  {
615  if (tp.aux.x == 0) {
616  apply<false, type>(tp, arg, stream); // reload = false
617  } else {
618  apply<true, type>(tp, arg, stream); // reload = true
619  }
620  }
621 
622  void apply(const qudaStream_t &stream)
623  {
624  TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
625  switch (arg.type) {
626  case MdwfFusedDslashType::D4_D5INV_D5PRE: apply<0>(tp, arg, stream); break;
627  case MdwfFusedDslashType::D4_D5INV_D5INVDAG: apply<1>(tp, arg, stream); break;
628  case MdwfFusedDslashType::D4DAG_D5PREDAG_D5INVDAG: apply<2>(tp, arg, stream); break;
629  case MdwfFusedDslashType::D4DAG_D5PREDAG: apply<3>(tp, arg, stream); break;
630  case MdwfFusedDslashType::D5PRE: apply<4>(tp, arg, stream); break;
631  default: errorQuda("Unknown MdwfFusedDslashType");
632  }
633  }
634 
635  void initTuneParam(TuneParam &param) const
636  {
637  Tunable::initTuneParam(param);
638  param.block = dim3(blockMin(), arg.Ls, 1); // Ls must be contained in the block
639  param.grid = dim3(minGridSize(), 1, 1);
640  param.shared_bytes = sharedBytesPerBlock(param);
641  param.aux.x = 0;
642  param.aux.y = 1;
643  }
644 
645  void defaultTuneParam(TuneParam &param) const { initTuneParam(param); }
646 
647  TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), aux); }
648  };
649 
650  // Apply the fused 5th-dimension dslash operator to a color-spinor field
651  // out = Dslash5 * in
652  template <typename storage_type, int nColor, QudaReconstructType recon> struct FusedApply {
653 
654  inline FusedApply(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, ColorSpinorField &y,
655  const ColorSpinorField &x, double m_f, double m_5, const Complex *b_5, const Complex *c_5,
656  bool dagger, int parity, int shift[4], int halo_shift[4], MdwfFusedDslashType type)
657  {
658  // switch over Ls
659  // Only multiples of 4 are supported, since the tensor core MMA only supports shapes that are multiples of 16
660  // and we get a factor of 4 for free.
661  switch (in.X(4)) {
662  case 4: {
663  FusedDslashArg<storage_type, recon, 4> arg(out, in, U, y, x, m_f, m_5, b_5, c_5, dagger, parity, shift,
664  halo_shift, type);
665  FusedDslash<decltype(arg)> dslash(arg, in);
666  dslash.apply(streams[Nstream - 1]);
667  } break;
668  case 8: {
669  FusedDslashArg<storage_type, recon, 8> arg(out, in, U, y, x, m_f, m_5, b_5, c_5, dagger, parity, shift,
670  halo_shift, type);
671  FusedDslash<decltype(arg)> dslash(arg, in);
672  dslash.apply(streams[Nstream - 1]);
673  } break;
674  case 12: {
675  FusedDslashArg<storage_type, recon, 12> arg(out, in, U, y, x, m_f, m_5, b_5, c_5, dagger, parity, shift,
676  halo_shift, type);
677  FusedDslash<decltype(arg)> dslash(arg, in);
678  dslash.apply(streams[Nstream - 1]);
679  } break;
680  case 16: {
681  FusedDslashArg<storage_type, recon, 16> arg(out, in, U, y, x, m_f, m_5, b_5, c_5, dagger, parity, shift,
682  halo_shift, type);
683  FusedDslash<decltype(arg)> dslash(arg, in);
684  dslash.apply(streams[Nstream - 1]);
685  } break;
686  case 20: {
687  FusedDslashArg<storage_type, recon, 20> arg(out, in, U, y, x, m_f, m_5, b_5, c_5, dagger, parity, shift,
688  halo_shift, type);
689  FusedDslash<decltype(arg)> dslash(arg, in);
690  dslash.apply(streams[Nstream - 1]);
691  } break;
692  default: errorQuda("Ls = %d is NOT supported.\n", in.X(4));
693  }
694  }
695  };
696 #endif // #if (CUDA_VERSION >= 10010 && __COMPUTE_CAPABILITY__ >= 700)
697 
698  void apply_fused_dslash(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, ColorSpinorField &y,
699  const ColorSpinorField &x, double m_f, double m_5, const Complex *b_5, const Complex *c_5,
700  bool dagger, int parity, int shift[4], int halo_shift[4], MdwfFusedDslashType type)
701  {
702 #if defined(GPU_DOMAIN_WALL_DIRAC) && (CUDA_VERSION >= 10010 && __COMPUTE_CAPABILITY__ >= 700)
703  checkLocation(out, in); // check all locations match
704  instantiatePreconditioner<FusedApply>(out, in, U, y, x, m_f, m_5, b_5, c_5, dagger, parity, shift, halo_shift,
705  type);
706 #else
707  errorQuda("Domain wall dslash with tensor cores has not been built");
708 #endif
709  }
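  /*
     A hypothetical call site, illustrative only (the fields out, in, U, y, x and the Mobius parameters
     are assumed to be set up elsewhere; they are not defined in this file):

       int shift[4] = {0, 0, 0, 0};
       int halo_shift[4] = {0, 0, 0, 0};
       mobius_tensor_core::apply_fused_dslash(out, in, U, y, x, m_f, m_5, b_5, c_5,
                                              false, parity, shift, halo_shift,
                                              MdwfFusedDslashType::D4_D5INV_D5PRE);
  */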
710  } // namespace mobius_tensor_core
711 } // namespace quda