#include <gauge_field.h>
#include <gauge_field_order.h>
#include <color_spinor_field.h>
#include <dslash_quda.h>
#include <jitify_helper.cuh>
#include <instantiate_dslash.h>
#if (CUDA_VERSION >= 10010 && __COMPUTE_CAPABILITY__ >= 700)
#include <mdw_dslash5_tensor_core.cuh>
namespace mobius_tensor_core
#if (CUDA_VERSION >= 10010 && __COMPUTE_CAPABILITY__ >= 700)
constexpr int sm_m_pad_size(int m)
return quda::mma::pad_size(m);
constexpr int sm_n_pad_size(int n)
return quda::mma::pad_size(n);
@brief Parameter structure for applying the Dslash
template <class storage_type_, QudaReconstructType recon_,
int Ls_> // storage_type is the usual "Float" in other places in QUDA
struct FusedDslashArg {
using storage_type = storage_type_;
using real = typename mapper<storage_type>::type; // the compute type for the in kernel computation
static constexpr QudaReconstructType recon = recon_;
static constexpr int Ls = Ls_;
static constexpr bool spin_project = true;
static constexpr bool spinor_direct_load = true; // false means texture load
= colorspinor::FloatNOrder<storage_type, 4, 3, 8, spin_project, spinor_direct_load>; // color spin field order
= colorspinor::FloatNOrder<storage_type, 4, 3, 4, spin_project, spinor_direct_load>; // color spin field order
static constexpr bool gauge_direct_load = true; // false means texture load
static constexpr QudaGhostExchange ghost = QUDA_GHOST_EXCHANGE_EXTENDED; // gauge field used is an extended one
using G = typename gauge_mapper<storage_type, recon, 18, QUDA_STAGGERED_PHASE_NO, gauge_direct_load, ghost>::type; // gauge field order
F out; // output vector field
const F in; // input vector field
F y; // auxiliary output vector field
const F x; // auxiliary input vector field
const G U; // The gauge field
const int nParity; // number of parities we're working on
const int parity; // output parity of this dslash operator
const int volume_cb; // checkerboarded volume
const int volume_4d_cb; // 4-d checkerboarded volume
const int shift[4]; // sites where we actually calculate.
const int halo_shift[4]; // halo means zero: when expanding, the halo of the color-spinor field contains zeros.
const int_fastdiv shrinked_dim[4]; // dimensions after the shifts are applied.
// partial kernel and expansion parameters
const int volume_4d_cb_shift; // number of 4d sites we need to calculate
// const int volume_4d_cb_expansive; //
const real m_f; // fermion mass parameter
const real m_5; // Wilson mass shift
const bool dagger; // dagger
// const bool xpay; // whether we are doing xpay or not
real b; // real constant Mobius coefficient
real c; // real constant Mobius coefficient
real a; // real xpay coefficient
// (beta + alpha*m5inv) * in
real m_scale = 1.; // scale factor for the matrix
bool small_kappa = false;
MdwfFusedDslashType type;
FusedDslashArg(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, ColorSpinorField &y,
const ColorSpinorField &x, double m_f_, double m_5_, const Complex *b_5, const Complex *c_5,
bool dagger_, int parity, int shift_[4], int halo_shift_[4], MdwfFusedDslashType type_) :
nParity(in.SiteSubset()),
volume_cb(in.VolumeCB() > out.VolumeCB() ? in.VolumeCB() : out.VolumeCB()),
volume_4d_cb(volume_cb / Ls_),
shift {shift_[0], shift_[1], shift_[2], shift_[3]},
halo_shift {halo_shift_[0], halo_shift_[1], halo_shift_[2], halo_shift_[3]},
dim {(3 - nParity) * (in.VolumeCB() > out.VolumeCB() ? in.X(0) : out.X(0)),
in.VolumeCB() > out.VolumeCB() ? in.X(1) : out.X(1), in.VolumeCB() > out.VolumeCB() ? in.X(2) : out.X(2),
in.VolumeCB() > out.VolumeCB() ? in.X(3) : out.X(3)},
shrinked_dim {dim[0] - 2 * shift[0], dim[1] - 2 * shift[1], dim[2] - 2 * shift[2], dim[3] - 2 * shift[3]},
volume_4d_cb_shift(shrinked_dim[0] * shrinked_dim[1] * shrinked_dim[2] * shrinked_dim[3] / 2),
comm {static_cast<bool>(comm_dim_partitioned(0)), static_cast<bool>(comm_dim_partitioned(1)),
static_cast<bool>(comm_dim_partitioned(2)), static_cast<bool>(comm_dim_partitioned(3))}
if (in.Nspin() != 4) { errorQuda("nSpin = %d NOT supported.\n", in.Nspin()); }
if (nParity == 2) { errorQuda("nParity = 2 NOT supported, yet.\n"); }
if (b_5[0] != b_5[1] || b_5[0].imag() != 0) { errorQuda("zMobius is NOT supported yet.\n"); }
kappa = -(c * (4. + m_5) - 1.) / (b * (4. + m_5) + 1.); // This is actually -kappa in my (Jiqun Tu) notes.
if (kappa * kappa < 1e-6) { small_kappa = true; }
= 0.5 / (1. + std::pow(kappa, (int)Ls) * m_f); // 0.5 to normalize the (1 +/- gamma5) in the chiral projector.
case MdwfFusedDslashType::D4_D5INV_D5PRE:
case MdwfFusedDslashType::D4DAG_D5PREDAG_D5INVDAG:
alpha = (c - b * kappa) / (2. * b);
m_scale = b + c / kappa;
beta = -1. / (1. + (kappa * b) / c);
case MdwfFusedDslashType::D4_D5INV_D5INVDAG:
m_scale = -0.25 / ((b * (4. + m_5) + 1.) * (b * (4. + m_5) + 1.)); // -kappa_b^2
case MdwfFusedDslashType::D4DAG_D5PREDAG:
m_scale = -0.25 / ((b * (4. + m_5) + 1.) * (b * (4. + m_5) + 1.)) * b; // -kappa_b^2
alpha = c / (2. * b); // 2 to compensate for the spin projection
case MdwfFusedDslashType::D5PRE:
alpha = c / (2. * b);
default: errorQuda("Unknown MdwfFusedDslashType");
__device__ inline int index_4d_cb_from_coordinate_4d(const int coordinate[4], const int dim[4])
return (((coordinate[3] * dim[2] + coordinate[2]) * dim[1] + coordinate[1]) * dim[0] + coordinate[0]) / 2;
__device__ inline bool is_halo_4d(const int coordinate[4], const int dim[4], const int halo_shift[4])
for (int d = 0; d < 4; d++) {
ret = ret or (coordinate[d] >= dim[d] - halo_shift[d] or coordinate[d] < halo_shift[d]);
__device__ inline int index_from_extended_coordinate(const int x[4], const int dim[4], const bool comm[4], const int y)
constexpr int pad = 2;
for (int d = 0; d < 4; d++) {
back_x[d] = comm[d] ? x[d] - pad : x[d];
back_dim[d] = comm[d] ? dim[d] - pad * 2 : dim[d];
bool is_center = true;
for (int d = 0; d < 4; d++) { is_center = is_center && (back_x[d] >= 0 && back_x[d] < back_dim[d]); }
int volume_4d_cb_back = back_dim[0] * back_dim[1] * back_dim[2] * back_dim[3] / 2;
return y * volume_4d_cb_back
+ index_4d_cb_from_coordinate_4d(back_x, back_dim); // the input coordinate is in the center region
-> Everything should be understood in a 4d checkerboarding sense.
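-> In outline (as read from the body below): for each of the four dimensions the routine gathers the forward and
   backward neighbors, applies the spin projector selected by proj_dir, multiplies by the gauge link (conjugated
   for the backward hop), and accumulates into out; with halo = true, neighbors that fall inside the halo region
   defined by arg.halo_shift are skipped.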
template <class storage_type, bool dagger, bool halo, bool back, class Vector, class Arg>
__device__ inline void apply_wilson_5d(Vector &out, int coordinate[4], Arg &arg, int s)
typedef typename mapper<storage_type>::type compute_type;
typedef Matrix<complex<compute_type>, 3> Link;
const int their_spinor_parity = arg.nParity == 2 ? 1 - arg.parity : 0;
const int index_4d_cb = index_4d_cb_from_coordinate_4d(coordinate, arg.dim);
for (int d = 0; d < 4; d++) // loop over dimension
int x[4] = {coordinate[0], coordinate[1], coordinate[2], coordinate[3]};
x[d] = (coordinate[d] == arg.dim[d] - 1 && !arg.comm[d]) ? 0 : coordinate[d] + 1;
if (!halo || !is_halo_4d(x, arg.dim, arg.halo_shift)) {
// Forward gather - compute fwd offset for vector fetch
fwd_idx = index_from_extended_coordinate(x, arg.dim, arg.comm, s);
fwd_idx = s * arg.volume_4d_cb + index_4d_cb_from_coordinate_4d(x, arg.dim);
constexpr int proj_dir = dagger ? +1 : -1;
const Link U = arg.U(d, index_4d_cb, arg.parity);
const Vector in = arg.in(fwd_idx, their_spinor_parity);
out += (U * in.project(d, proj_dir)).reconstruct(d, proj_dir);
x[d] = (coordinate[d] == 0 && !arg.comm[d]) ? arg.dim[d] - 1 : coordinate[d] - 1;
if (!halo || !is_halo_4d(x, arg.dim, arg.halo_shift)) {
// Backward gather - compute back offset for spinor and gauge fetch
const int gauge_idx = index_4d_cb_from_coordinate_4d(x, arg.dim);
back_idx = index_from_extended_coordinate(x, arg.dim, arg.comm, s);
back_idx = s * arg.volume_4d_cb + gauge_idx;
constexpr int proj_dir = dagger ? -1 : +1;
const Link U = arg.U(d, gauge_idx, 1 - arg.parity);
const Vector in = arg.in(back_idx, their_spinor_parity);
out += (conj(U) * in.project(d, proj_dir)).reconstruct(d, proj_dir);
-> Everything should be understood in a 4d checkerboarding sense.
Given an index in the shrinked block, calculate the coordinate in the shrinked block,
then shift the coordinate to the un-shrinked coordinate, e.g. (0,0,4,1) -> (2,2,6,3) with shift = (2,2,2,2)
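The parity term added to coordinate[0] below ensures that, once the shifts are applied, the full coordinate
lands on the requested checkerboard parity.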
__device__ inline void coordinate_from_shrinked_index(int coordinate[4], int shrinked_index,
const T shrinked_dim[4], const int shift[4], int parity)
aux[0] = shrinked_index * 2;
for (int i = 0; i < 3; i++) { aux[i + 1] = aux[i] / shrinked_dim[i]; }
coordinate[0] = aux[0] - aux[1] * shrinked_dim[0];
coordinate[1] = aux[1] - aux[2] * shrinked_dim[1];
coordinate[2] = aux[2] - aux[3] * shrinked_dim[2];
coordinate[3] = aux[3];
// Find the full coordinate in the shrinked volume.
coordinate[0]
+= (shift[0] + shift[1] + shift[2] + shift[3] + parity + coordinate[3] + coordinate[2] + coordinate[1]) & 1;
// Now go back to the extended volume.
for (int d = 0; d < 4; d++) { coordinate[d] += shift[d]; }
@brief Tensor core kernel for applying Wilson hopping term and then the beta + alpha * M5inv operator
The integer kernel types correspond to the enum MdwfFusedDslashType.
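From the dispatch in FusedDslash::apply below: type 0 = D4_D5INV_D5PRE, type 1 = D4_D5INV_D5INVDAG,
type 2 = D4DAG_D5PREDAG_D5INVDAG, type 3 = D4DAG_D5PREDAG, type 4 = D5PRE.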
template <int block_dim_x, int minBlocksPerMultiprocessor, bool reload, class Arg, int type>
__global__ void __launch_bounds__(block_dim_x *Arg::Ls, minBlocksPerMultiprocessor) fused_tensor_core(Arg arg)
using storage_type = typename Arg::storage_type;
using real = typename mapper<storage_type>::type;
using Vector = ColorSpinor<real, 3, 4>;
constexpr int Ls = Arg::Ls;
const int explicit_parity = arg.nParity == 2 ? arg.parity : 0;
TensorCoreSharedMemory<float> shared_memory_data;
static_assert(block_dim_x * Ls / 32 < 32, "Number of threads in a threadblock should be less than 1024.");
constexpr int M = 4 * Ls;
constexpr int N = 6 * block_dim_x;
constexpr int N_sm = N + sm_n_pad_size(N);
constexpr int M_sm = M + sm_m_pad_size(M);
float *smem_scale = shared_memory_data;
half2 *sm_b = reinterpret_cast<half2 *>(smem_scale + 32);
half *sm_c = reinterpret_cast<half *>(sm_b);
half *sm_a = reload ? sm_c + M * N_sm : sm_c;
// This is for type == 1 (D4_D5INV_D5INVDAG) ONLY.
half *sm_a_black = sm_a + M * M_sm;
if (arg.small_kappa) {
construct_matrix_a_d5<block_dim_x, Ls, M_sm, false, Arg>(arg, sm_a); // dagger = false
construct_matrix_a_m5inv<block_dim_x, Ls, M_sm, false, Arg>(arg, sm_a); // dagger = false
} else if (type == 2) {
if (arg.small_kappa) {
construct_matrix_a_d5<block_dim_x, Ls, M_sm, true, Arg>(arg, sm_a); // dagger = true
construct_matrix_a_m5inv<block_dim_x, Ls, M_sm, true, Arg>(arg, sm_a); // dagger = true
} else if (type == 1) {
construct_matrix_a_m5inv<block_dim_x, Ls, M_sm, false, Arg>(arg, sm_a); // dagger = false
} else if (type == 3) {
construct_matrix_a_d5<block_dim_x, Ls, M_sm, true, Arg>(arg, sm_a); // dagger = true
} else if (type == 4) {
construct_matrix_a_d5<block_dim_x, Ls, M_sm, false, Arg>(arg, sm_a); // dagger = false
int s4_shift_base = blockIdx.x * blockDim.x; // base.
constexpr int tm_dim = M / mma::MMA_M;
constexpr int tn_dim = N / mma::MMA_N;
constexpr int tk_dim = M / mma::MMA_K;
constexpr int total_warp = block_dim_x * Ls >> 5;
const int this_warp = (threadIdx.y * block_dim_x + threadIdx.x) >> 5;
constexpr int total_tile = tm_dim * tn_dim;
constexpr int warp_cycle = total_tile / total_warp;
const int warp_m = this_warp * warp_cycle / tn_dim;
mma::WarpRegisterMapping wrm(threadIdx.y * blockDim.x + threadIdx.x);
mma::MmaOperandA op_a[reload ? 1 : tk_dim];
mma::MmaOperandA op_a_aux[reload ? 1 : tk_dim];
if (!reload) { // the data in registers can be reused.
for (int tile_k = 0; tile_k < tk_dim; tile_k++) { op_a[tile_k].template load<M_sm>(sm_a, tile_k, warp_m, wrm); }
if (!reload) { // in the preload case we preload ...
construct_matrix_a_m5inv<block_dim_x, Ls, M_sm, true, Arg>(arg, sm_a); // dagger = true
for (int tile_k = 0; tile_k < tk_dim; tile_k++) {
op_a_aux[tile_k].template load<M_sm>(sm_a, tile_k, warp_m, wrm);
construct_matrix_a_m5inv<block_dim_x, Ls, M_sm, true, Arg>(arg, sm_a_black); // dagger = true
while (s4_shift_base < arg.volume_4d_cb_shift) {
s4_shift = s4_shift_base + threadIdx.x;
coordinate_from_shrinked_index(x, s4_shift, arg.shrinked_dim, arg.shift, arg.parity);
sid = threadIdx.y * arg.volume_4d_cb + index_4d_cb_from_coordinate_4d(x, arg.dim);
if (s4_shift >= arg.volume_4d_cb_shift) { idle = true; }
// the Wilson hopping terms
apply_wilson_5d<storage_type, false, true, true>(in_vec, x, arg, threadIdx.y); // dagger = false; halo = true
} else if (type == 2) {
apply_wilson_5d<storage_type, true, false, false>(in_vec, x, arg,
threadIdx.y); // dagger = true; halo = false
} else if (type == 1) {
apply_wilson_5d<storage_type, false, true, false>(in_vec, x, arg, threadIdx.y); // dagger = false; halo = true
} else if (type == 3) {
apply_wilson_5d<storage_type, true, false, false>(in_vec, x, arg,
threadIdx.y); // dagger = true; halo = false
} else if (type == 4) {
int sid_shift = threadIdx.y * arg.volume_4d_cb_shift + s4_shift;
in_vec = arg.in(sid_shift, explicit_parity);
// store result to shared memory
load_matrix_b_vector<block_dim_x, Ls, N_sm / 2, false>(in_vec, sm_b, smem_scale); // acc(accumulation) = false
mma_sync_gemm<block_dim_x, Ls, M, N, M_sm, N_sm, reload>(op_a, sm_a, sm_c, sm_c, wrm);
sid_back = index_from_extended_coordinate(x, arg.dim, arg.comm, threadIdx.y);
aux_in_vec = arg.x(sid_back, explicit_parity);
load_matrix_b_vector<block_dim_x, Ls, N_sm / 2, true>(aux_in_vec, sm_b, smem_scale, arg.m_scale); // acc = true
if (!idle && center) { store_matrix_c<storage_type, N_sm>(arg.y, sm_b, sid_back, smem_scale[0]); }
mma_sync_gemm<block_dim_x, Ls, M, N, M_sm, N_sm, reload>(op_a_aux, sm_a_black, sm_c, sm_c, wrm);
} else if (type == 3) {
int sid_shift = threadIdx.y * arg.volume_4d_cb_shift + s4_shift;
if (!idle) { aux_in_vec = arg.x(sid_shift, explicit_parity); }
load_matrix_b_vector<block_dim_x, Ls, N_sm / 2, true, false>(aux_in_vec, sm_b, smem_scale, arg.m_scale);
if (!idle) { arg.out(sid_shift, explicit_parity) = aux_in_vec; }
} else if (type == 1) {
if (!idle) { store_matrix_c<storage_type, N_sm>(arg.out, sm_b, sid, smem_scale[0]); }
if (!idle) { store_matrix_c<storage_type, N_sm>(arg.out, sm_b, sid, smem_scale[0] * arg.m_scale); }
s4_shift_base += gridDim.x * blockDim.x;
template <class Arg> class FusedDslash : public Tunable
const ColorSpinorField &meta;
/** Whether to use variable or fixed coefficient algorithm. Must be true if using ZMOBIUS */
static constexpr bool var_inverse = true;
long long flops() const
constexpr long long hop = 7ll * 8ll;
constexpr long long mat = 2ll * 4ll * Arg::Ls - 1ll;
long long volume_4d_cb_halo_shift = (arg.dim[0] - 2 * arg.halo_shift[0]) * (arg.dim[1] - 2 * arg.halo_shift[1])
* (arg.dim[2] - 2 * arg.halo_shift[2]) * (arg.dim[3] - 2 * arg.halo_shift[3]) / 2;
long long flops_ = 0;
case MdwfFusedDslashType::D4_D5INV_D5PRE:
flops_ = volume_4d_cb_halo_shift * 6ll * 4ll * arg.Ls * hop + arg.volume_4d_cb_shift * 24ll * arg.Ls * mat;
case MdwfFusedDslashType::D4_D5INV_D5INVDAG:
flops_
= volume_4d_cb_halo_shift * 6ll * 4ll * arg.Ls * hop + arg.volume_4d_cb_shift * 24ll * arg.Ls * 2ll * mat;
case MdwfFusedDslashType::D4DAG_D5PREDAG_D5INVDAG:
case MdwfFusedDslashType::D4DAG_D5PREDAG:
flops_ = arg.volume_4d_cb_shift * 6ll * 4ll * arg.Ls
* (hop + mat); // for 2 and 3 we don't have the halo complication.
case MdwfFusedDslashType::D5PRE: flops_ = arg.volume_4d_cb_shift * 6ll * 4ll * arg.Ls * (mat); break;
default: errorQuda("Unknown MdwfFusedDslashType");
long long bytes() const
auto site_size = arg.Ls * (2ll * meta.Nspin() * meta.Ncolor() * meta.Precision() + sizeof(float));
auto b_m0 = ((dim[0] - 0) * (dim[1] - 0) * (dim[2] - 0) * (dim[3] - 0) / 2) * site_size;
auto b_m1 = ((dim[0] - 1) * (dim[1] - 1) * (dim[2] - 1) * (dim[3] - 1) / 2) * site_size;
auto b_m2 = ((dim[0] - 2) * (dim[1] - 2) * (dim[2] - 2) * (dim[3] - 2) / 2) * site_size;
case MdwfFusedDslashType::D4_D5INV_D5PRE: return b_m1 + b_m2 + arg.U.Bytes();
case MdwfFusedDslashType::D4_D5INV_D5INVDAG: return 2 * b_m2 + b_m1 + b_m0 + arg.U.Bytes();
case MdwfFusedDslashType::D4DAG_D5PREDAG_D5INVDAG: return b_m1 + b_m0 + arg.U.Bytes();
case MdwfFusedDslashType::D4DAG_D5PREDAG: return 2 * b_m2 + b_m1 + arg.U.Bytes();
case MdwfFusedDslashType::D5PRE: return 2 * b_m0;
default: errorQuda("Unknown MdwfFusedDslashType");
bool tuneAuxDim() const { return true; }
int blockStep() const { return 16; }
int blockMin() const { return 16; }
unsigned int maxBlockSize(const TuneParam &param) const { return 32; }
int gridStep() const { return deviceProp.multiProcessorCount; }
unsigned int maxGridSize() const { return (arg.volume_4d_cb_shift + blockMin() - 1) / blockMin(); }
unsigned int minGridSize() const { return deviceProp.multiProcessorCount; }
unsigned int sharedBytesPerBlock(const TuneParam &param) const
const int a_size = (param.block.y * 4) * (param.block.y * 4 + sm_m_pad_size(param.block.y * 4));
const int b_size = (param.block.y * 4) * (param.block.x * 6 + sm_n_pad_size(param.block.x * 6));
// (Ls*4) by (Ls*4), (Ls*4) by (volume_4d*6 + 16)
if (param.aux.x == 1) { // aux.x == 1 --> reload == true
if (arg.type == MdwfFusedDslashType::D4_D5INV_D5INVDAG) {
return (a_size * 2 + b_size) * sizeof(half) + 128;
return (a_size + b_size) * sizeof(half) + 128;
return (a_size > b_size ? a_size : b_size) * sizeof(half) + 128;
unsigned int sharedBytesPerThread() const { return 0; }
bool advanceAux(TuneParam &param) const
bool aux_advanced = false;
if (param.aux.x == 0) { // first see if aux.x(ONLY 0(false) or 1(true))
if (param.aux.y < 3) { // second see if aux.y
// shared bytes depends on aux, so update if changed
if (aux_advanced) param.shared_bytes = sharedBytesPerBlock(param);
// overloaded to return max dynamic shared memory if doing shared-memory inverse
unsigned int maxSharedBytesPerBlock() const { return maxDynamicSharedBytesPerBlock(); }
FusedDslash(Arg &arg, const ColorSpinorField &meta) : arg(arg), meta(meta)
strcpy(aux, meta.AuxString());
if (arg.dagger) strcat(aux, ",Dagger");
case MdwfFusedDslashType::D4_D5INV_D5PRE: sprintf(config, ",f0"); break;
case MdwfFusedDslashType::D4DAG_D5PREDAG_D5INVDAG: sprintf(config, ",f2"); break;
case MdwfFusedDslashType::D4_D5INV_D5INVDAG: sprintf(config, ",f1"); break;
case MdwfFusedDslashType::D4DAG_D5PREDAG: sprintf(config, ",f3"); break;
case MdwfFusedDslashType::D5PRE: sprintf(config, ",f4"); break;
default: errorQuda("Unknown MdwfFusedDslashType");
sprintf(config, "shift%d%d%d%d,halo%d%d%d%d,comm%d%d%d%d", arg.shift[0], arg.shift[1], arg.shift[2],
arg.shift[3], arg.halo_shift[0], arg.halo_shift[1], arg.halo_shift[2], arg.halo_shift[3], arg.comm[0],
arg.comm[1], arg.comm[2], arg.comm[3]);
template <typename T> inline void launch(T *f, const TuneParam &tp, Arg &arg, const qudaStream_t &stream)
const_cast<TuneParam &>(tp).set_max_shared_bytes = true;
qudaLaunchKernel(f, tp, stream, arg);
// The following apply<...> functions are used to turn the tune parameters into template arguments.
// Specifically, tp.aux.y dictates the minBlocksPerMultiprocessor in __launch_bounds__(..).
// tp.aux.x dictates whether or not to reload.
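// tp.block.x selects the block_dim_x template parameter (16 or 32, matching blockMin()/blockStep()/maxBlockSize() above).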
template <int block_dim_x, bool reload, int type>
void apply(const TuneParam &tp, Arg &arg, const qudaStream_t &stream)
case 1: launch(fused_tensor_core<block_dim_x, 1, reload, Arg, type>, tp, arg, stream); break;
case 2: launch(fused_tensor_core<block_dim_x, 2, reload, Arg, type>, tp, arg, stream); break;
case 3: launch(fused_tensor_core<block_dim_x, 3, reload, Arg, type>, tp, arg, stream); break;
default: errorQuda("NOT valid tp.aux.y(=%d)\n", tp.aux.y);
template <bool reload, int type> void apply(const TuneParam &tp, Arg &arg, const qudaStream_t &stream)
switch (tp.block.x) {
case 16: apply<16, reload, type>(tp, arg, stream); break;
case 32: apply<32, reload, type>(tp, arg, stream); break;
default: errorQuda("NOT valid tp.block.x(=%d)\n", tp.block.x);
template <int type> void apply(const TuneParam &tp, Arg &arg, const qudaStream_t &stream)
apply<false, type>(tp, arg, stream); // reload = false
apply<true, type>(tp, arg, stream); // reload = true
void apply(const qudaStream_t &stream)
TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
case MdwfFusedDslashType::D4_D5INV_D5PRE: apply<0>(tp, arg, stream); break;
case MdwfFusedDslashType::D4_D5INV_D5INVDAG: apply<1>(tp, arg, stream); break;
case MdwfFusedDslashType::D4DAG_D5PREDAG_D5INVDAG: apply<2>(tp, arg, stream); break;
case MdwfFusedDslashType::D4DAG_D5PREDAG: apply<3>(tp, arg, stream); break;
case MdwfFusedDslashType::D5PRE: apply<4>(tp, arg, stream); break;
default: errorQuda("Unknown MdwfFusedDslashType");
void initTuneParam(TuneParam &param) const
Tunable::initTuneParam(param);
param.block = dim3(blockMin(), arg.Ls, 1); // Ls must be contained in the block
param.grid = dim3(minGridSize(), 1, 1);
param.shared_bytes = sharedBytesPerBlock(param);
void defaultTuneParam(TuneParam &param) const { initTuneParam(param); }
TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), aux); }
// Apply the 5th dimension dslash operator to a colorspinor field
// out = Dslash5 * in
template <typename storage_type, int nColor, QudaReconstructType recon> struct FusedApply {
inline FusedApply(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, ColorSpinorField &y,
const ColorSpinorField &x, double m_f, double m_5, const Complex *b_5, const Complex *c_5,
bool dagger, int parity, int shift[4], int halo_shift[4], MdwfFusedDslashType type)
// Only multiples of 4 are supported, since the tensor core MMA only supports multiple-of-16 shapes and we get a
// factor of 4 for free.
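// (In the kernel the matrix dimension is M = 4 * Ls, so Ls being a multiple of 4 makes M a multiple of 16.)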
FusedDslashArg<storage_type, recon, 4> arg(out, in, U, y, x, m_f, m_5, b_5, c_5, dagger, parity, shift,
FusedDslash<decltype(arg)> dslash(arg, in);
dslash.apply(streams[Nstream - 1]);
FusedDslashArg<storage_type, recon, 8> arg(out, in, U, y, x, m_f, m_5, b_5, c_5, dagger, parity, shift,
FusedDslash<decltype(arg)> dslash(arg, in);
dslash.apply(streams[Nstream - 1]);
FusedDslashArg<storage_type, recon, 12> arg(out, in, U, y, x, m_f, m_5, b_5, c_5, dagger, parity, shift,
FusedDslash<decltype(arg)> dslash(arg, in);
dslash.apply(streams[Nstream - 1]);
FusedDslashArg<storage_type, recon, 16> arg(out, in, U, y, x, m_f, m_5, b_5, c_5, dagger, parity, shift,
FusedDslash<decltype(arg)> dslash(arg, in);
dslash.apply(streams[Nstream - 1]);
FusedDslashArg<storage_type, recon, 20> arg(out, in, U, y, x, m_f, m_5, b_5, c_5, dagger, parity, shift,
FusedDslash<decltype(arg)> dslash(arg, in);
dslash.apply(streams[Nstream - 1]);
default: errorQuda("Ls = %d is NOT supported.\n", in.X(4));
#endif // #if (CUDA_VERSION >= 10010 && __COMPUTE_CAPABILITY__ >= 700)
void apply_fused_dslash(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, ColorSpinorField &y,
const ColorSpinorField &x, double m_f, double m_5, const Complex *b_5, const Complex *c_5,
bool dagger, int parity, int shift[4], int halo_shift[4], MdwfFusedDslashType type)
#if defined(GPU_DOMAIN_WALL_DIRAC) && (CUDA_VERSION >= 10010 && __COMPUTE_CAPABILITY__ >= 700)
checkLocation(out, in); // check all locations match
instantiatePreconditioner<FusedApply>(out, in, U, y, x, m_f, m_5, b_5, c_5, dagger, parity, shift, halo_shift,
errorQuda("Domain wall dslash with tensor cores has not been built");
} // namespace mobius_tensor_core