QUDA  v1.1.0
A library for QCD on GPUs
coarse_op_mma_launch.h
Go to the documentation of this file.
1 #pragma once
2 
3 #include <tune_quda.h>
4 
5 #if (CUDA_VERSION >= 10010 && __COMPUTE_CAPABILITY__ >= 700)
6 
7 #include <kernels/coarse_op_kernel_mma.cuh>
8 
9 #endif
10 
17 namespace quda
18 {
19 
20  namespace mma
21  {
22 
23 #if (CUDA_VERSION >= 10010 && __COMPUTE_CAPABILITY__ >= 700)
24 
25  template <bool from_coarse, int dim, QudaDirection dir, int bM, int bN, int bK, int block_y, int block_z, class Arg>
26  typename std::enable_if<!Arg::is_mma_compatible, void>::type
27  launch_compute_uv_kernel(TuneParam &tp, const Arg &arg, int min_threads, const cudaStream_t &stream)
28  {
29  errorQuda("MMA implementation is ONLY built for AoS order.");
30  }
31 
32  template <bool from_coarse, int dim, QudaDirection dir, int bM, int bN, int bK, int block_y, int block_z, class Arg>
33  typename std::enable_if<Arg::is_mma_compatible, void>::type
34  launch_compute_uv_kernel(TuneParam &tp, const Arg &arg, int min_threads, const cudaStream_t &stream)
35  {
36  tp.block.x = 1;
37  tp.block.y = block_y;
38  tp.block.z = block_z;
39  constexpr int shared_bytes = shared_memory_bytes(bM, bN, bK);
40  tp.shared_bytes = shared_bytes;
41 
42  // TODO: Fix the split M/N.
43  constexpr int t_m = 1;
44  constexpr int t_n = 1;
45 
46  tp.grid = dim3(min_threads * t_m * t_n, 2, 1);
47 
48  auto kernel = ComputeUVMMA<from_coarse, dim, dir, bM, bN, bK, block_y, block_z, Arg>;
49  tp.set_max_shared_bytes = true;
50  qudaLaunchKernel(kernel, tp, stream, arg);
51  }
52 
53  template <bool from_coarse, int bM, int bN, int bK, int block_y, int block_z, class Arg>
54  void launch_compute_uv_kernel(TuneParam &tp, const Arg &arg, int min_threads, const cudaStream_t &stream)
55  {
56  if (arg.dir == QUDA_BACKWARDS) {
57  switch (arg.dim) {
58  case 0:
59  launch_compute_uv_kernel<from_coarse, 0, QUDA_BACKWARDS, bM, bN, bK, block_y, block_z>(tp, arg, min_threads,
60  stream);
61  break;
62  case 1:
63  launch_compute_uv_kernel<from_coarse, 1, QUDA_BACKWARDS, bM, bN, bK, block_y, block_z>(tp, arg, min_threads,
64  stream);
65  break;
66  case 2:
67  launch_compute_uv_kernel<from_coarse, 2, QUDA_BACKWARDS, bM, bN, bK, block_y, block_z>(tp, arg, min_threads,
68  stream);
69  break;
70  case 3:
71  launch_compute_uv_kernel<from_coarse, 3, QUDA_BACKWARDS, bM, bN, bK, block_y, block_z>(tp, arg, min_threads,
72  stream);
73  break;
74  default: errorQuda("arg.dim(=%d) is NOT supported.", arg.dim);
75  }
76  } else {
77  switch (arg.dim) {
78  case 0:
79  launch_compute_uv_kernel<from_coarse, 0, QUDA_FORWARDS, bM, bN, bK, block_y, block_z>(tp, arg, min_threads,
80  stream);
81  break;
82  case 1:
83  launch_compute_uv_kernel<from_coarse, 1, QUDA_FORWARDS, bM, bN, bK, block_y, block_z>(tp, arg, min_threads,
84  stream);
85  break;
86  case 2:
87  launch_compute_uv_kernel<from_coarse, 2, QUDA_FORWARDS, bM, bN, bK, block_y, block_z>(tp, arg, min_threads,
88  stream);
89  break;
90  case 3:
91  launch_compute_uv_kernel<from_coarse, 3, QUDA_FORWARDS, bM, bN, bK, block_y, block_z>(tp, arg, min_threads,
92  stream);
93  break;
94  default: errorQuda("arg.dim(=%d) is NOT supported.", arg.dim);
95  }
96  }
97  }
98 
99  template <bool from_coarse, bool query_max = false, class Arg>
100  typename std::enable_if<!from_coarse, int>::type
101  launch_compute_uv_kernel(TuneParam &tp, const Arg &arg, int min_threads, const cudaStream_t &stream)
102  {
103  errorQuda("MMA implementation is ONLY built for !from_coarse.");
104  return -1;
105  }
106 
113  template <bool from_coarse, bool query_max = false, class Arg>
114  typename std::enable_if<Arg::fineColor == 6 && Arg::coarseColor == 6 && Arg::fineSpin == 2 && Arg::coarseSpin == 2, int>::type
115  launch_compute_uv_kernel(TuneParam &tp, const Arg &arg, int min_threads, const cudaStream_t &stream)
116  {
117  if (query_max) return 1;
118  switch (tp.aux.x) {
119  // clang-format off
120  case 0: launch_compute_uv_kernel<from_coarse, 16, 16, 8, 4, 8>(tp, arg, min_threads, stream); break;
121  case 1: launch_compute_uv_kernel<from_coarse, 16, 16, 8, 8, 4>(tp, arg, min_threads, stream); break;
122  // clang-format on
123  default:
124  errorQuda("tp.aux.x(=%d) is NOT supported by (%d, %d, %d, %d).", tp.aux.x, Arg::fineSpin, Arg::coarseSpin,
125  Arg::fineColor, Arg::coarseColor);
126  }
127  return -1;
128  }
129 
130  template <bool from_coarse, bool query_max = false, class Arg>
131  typename std::enable_if<Arg::fineColor == 24 && Arg::coarseColor == 24 && Arg::fineSpin == 2 && Arg::coarseSpin == 2,
132  int>::type
133  launch_compute_uv_kernel(TuneParam &tp, const Arg &arg, int min_threads, const cudaStream_t &stream)
134  {
135 #if (__COMPUTE_CAPABILITY__ >= 750) // Turing or above
136  if (query_max) return 5;
137  switch (tp.aux.x) {
138  // clang-format off
139  case 0: launch_compute_uv_kernel<from_coarse, 48, 24, 24, 24, 12>(tp, arg, min_threads, stream); break;
140  case 1: launch_compute_uv_kernel<from_coarse, 48, 24, 24, 16, 6>(tp, arg, min_threads, stream); break;
141  case 2: launch_compute_uv_kernel<from_coarse, 48, 24, 24, 16, 2>(tp, arg, min_threads, stream); break;
142  case 3: launch_compute_uv_kernel<from_coarse, 48, 24, 24, 8, 12>(tp, arg, min_threads, stream); break;
143  case 4: launch_compute_uv_kernel<from_coarse, 48, 24, 24, 8, 4>(tp, arg, min_threads, stream); break;
144  case 5: launch_compute_uv_kernel<from_coarse, 48, 24, 24, 4, 8>(tp, arg, min_threads, stream); break;
145  // clang-format on
146  default:
147  errorQuda("tp.aux.x(=%d) is NOT supported by (%d, %d, %d, %d).", tp.aux.x, Arg::fineSpin, Arg::coarseSpin,
148  Arg::fineColor, Arg::coarseColor);
149  }
150 #else
151  if (query_max) return 4;
152  switch (tp.aux.x) {
153  // clang-format off
154  case 0: launch_compute_uv_kernel<from_coarse, 48, 32, 24, 8, 8>(tp, arg, min_threads, stream); break;
155  case 1: launch_compute_uv_kernel<from_coarse, 48, 32, 24, 8, 12>(tp, arg, min_threads, stream); break;
156  case 2: launch_compute_uv_kernel<from_coarse, 48, 32, 24, 8, 24>(tp, arg, min_threads, stream); break;
157  case 3: launch_compute_uv_kernel<from_coarse, 48, 32, 24, 16, 4>(tp, arg, min_threads, stream); break;
158  case 4: launch_compute_uv_kernel<from_coarse, 48, 32, 24, 16, 12>(tp, arg, min_threads, stream); break;
159  // clang-format on
160  default:
161  errorQuda("tp.aux.x(=%d) is NOT supported by (%d, %d, %d, %d).", tp.aux.x, Arg::fineSpin, Arg::coarseSpin,
162  Arg::fineColor, Arg::coarseColor);
163  }
164 #endif
165  return -1;
166  }
167 
168  template <bool from_coarse, bool query_max = false, class Arg>
169  typename std::enable_if<Arg::fineColor == 24 && Arg::coarseColor == 32 && Arg::fineSpin == 2 && Arg::coarseSpin == 2,
170  int>::type
171  launch_compute_uv_kernel(TuneParam &tp, const Arg &arg, int min_threads, const cudaStream_t &stream)
172  {
173  if (query_max) return 3;
174  switch (tp.aux.x) {
175  // clang-format off
176  case 0: launch_compute_uv_kernel<from_coarse, 48, 32, 24, 8, 12>(tp, arg, min_threads, stream); break;
177  case 1: launch_compute_uv_kernel<from_coarse, 48, 32, 24, 8, 12>(tp, arg, min_threads, stream); break;
178  case 2: launch_compute_uv_kernel<from_coarse, 48, 32, 24, 8, 24>(tp, arg, min_threads, stream); break;
179  case 3: launch_compute_uv_kernel<from_coarse, 48, 32, 24, 16, 12>(tp, arg, min_threads, stream); break;
180  // clang-format on
181  default:
182  errorQuda("tp.aux.x(=%d) is NOT supported by (%d, %d, %d, %d).", tp.aux.x, Arg::fineSpin, Arg::coarseSpin,
183  Arg::fineColor, Arg::coarseColor);
184  }
185  return -1;
186  }
187 
188  template <bool from_coarse, bool query_max = false, class Arg>
189  typename std::enable_if<Arg::fineColor == 24 && Arg::coarseColor == 64 && Arg::fineSpin == 2 && Arg::coarseSpin == 2,
190  int>::type
191  launch_compute_uv_kernel(TuneParam &tp, const Arg &arg, int min_threads, const cudaStream_t &stream)
192  {
193  if (query_max) return 5;
194  switch (tp.aux.x) {
195  // clang-format off
196  case 0: launch_compute_uv_kernel<from_coarse, 48, 64, 24, 8, 12>(tp, arg, min_threads, stream); break;
197  case 1: launch_compute_uv_kernel<from_coarse, 48, 64, 24, 8, 12>(tp, arg, min_threads, stream); break;
198  case 2: launch_compute_uv_kernel<from_coarse, 48, 64, 24, 8, 24>(tp, arg, min_threads, stream); break;
199  case 3: launch_compute_uv_kernel<from_coarse, 48, 64, 24, 16, 12>(tp, arg, min_threads, stream); break;
200  case 4: launch_compute_uv_kernel<from_coarse, 48, 64, 24, 32, 12>(tp, arg, min_threads, stream); break;
201  case 5: launch_compute_uv_kernel<from_coarse, 48, 64, 24, 16, 24>(tp, arg, min_threads, stream); break;
202  // clang-format on
203  default:
204  errorQuda("tp.aux.x(=%d) is NOT supported by (%d, %d, %d, %d).", tp.aux.x, Arg::fineSpin, Arg::coarseSpin,
205  Arg::fineColor, Arg::coarseColor);
206  }
207  return -1;
208  }
209 
210  // note --- currently unused, may be revisited in the future
211  template <bool from_coarse, bool query_max = false, class Arg>
212  typename std::enable_if<Arg::fineColor == 24 && Arg::coarseColor == 96 && Arg::fineSpin == 2 && Arg::coarseSpin == 2,
213  int>::type
214  launch_compute_uv_kernel(TuneParam &tp, const Arg &arg, int min_threads, const cudaStream_t &stream)
215  {
216  if (query_max) return 6;
217  switch (tp.aux.x) {
218  // clang-format off
219  case 0: launch_compute_uv_kernel<from_coarse, 48, 96, 24, 8, 12>(tp, arg, min_threads, stream); break;
220  case 1: launch_compute_uv_kernel<from_coarse, 48, 96, 24, 8, 24>(tp, arg, min_threads, stream); break;
221  case 2: launch_compute_uv_kernel<from_coarse, 48, 96, 24, 16, 6>(tp, arg, min_threads, stream); break;
222  case 3: launch_compute_uv_kernel<from_coarse, 48, 96, 24, 16, 12>(tp, arg, min_threads, stream); break;
223  case 4: launch_compute_uv_kernel<from_coarse, 48, 96, 24, 16, 12>(tp, arg, min_threads, stream); break;
224  case 5: launch_compute_uv_kernel<from_coarse, 48, 96, 24, 24, 12>(tp, arg, min_threads, stream); break;
225  case 6: launch_compute_uv_kernel<from_coarse, 48, 96, 24, 24, 24>(tp, arg, min_threads, stream); break;
226  // clang-format on
227  default:
228  errorQuda("tp.aux.x(=%d) is NOT supported by (%d, %d, %d, %d).", tp.aux.x, Arg::fineSpin, Arg::coarseSpin,
229  Arg::fineColor, Arg::coarseColor);
230  }
231  return -1;
232  }
233 
234  template <bool from_coarse, bool query_max = false, class Arg>
235  typename std::enable_if<Arg::fineColor == 32 && Arg::coarseColor == 32 && Arg::fineSpin == 2 && Arg::coarseSpin == 2,
236  int>::type
237  launch_compute_uv_kernel(TuneParam &tp, const Arg &arg, int min_threads, const cudaStream_t &stream)
238  {
239  if (query_max) return 2;
240  switch (tp.aux.x) {
241  // clang-format off
242  case 0: launch_compute_uv_kernel<from_coarse, 64, 32, 32, 8, 16>(tp, arg, min_threads, stream); break;
243  case 1: launch_compute_uv_kernel<from_coarse, 64, 32, 32, 8, 32>(tp, arg, min_threads, stream); break;
244  case 2: launch_compute_uv_kernel<from_coarse, 64, 32, 32, 16, 16>(tp, arg, min_threads, stream); break;
245  // clang-format on
246  default:
247  errorQuda("tp.aux.x(=%d) is NOT supported by (%d, %d, %d, %d).", tp.aux.x, Arg::fineSpin, Arg::coarseSpin,
248  Arg::fineColor, Arg::coarseColor);
249  }
250  return -1;
251  }
252 
253  template <bool from_coarse, bool query_max = false, class Arg>
254  typename std::enable_if<Arg::fineColor == 64 && Arg::coarseColor == 64 && Arg::fineSpin == 2 && Arg::coarseSpin == 2,
255  int>::type
256  launch_compute_uv_kernel(TuneParam &tp, const Arg &arg, int min_threads, const cudaStream_t &stream)
257  {
258  if (query_max) return 6;
259  switch (tp.aux.x) {
260  // clang-format off
261  case 0: launch_compute_uv_kernel<from_coarse, 128, 64, 64, 8, 16>(tp, arg, min_threads, stream); break;
262  case 1: launch_compute_uv_kernel<from_coarse, 128, 64, 64, 8, 32>(tp, arg, min_threads, stream); break;
263  case 2: launch_compute_uv_kernel<from_coarse, 128, 64, 64, 16, 8>(tp, arg, min_threads, stream); break;
264  case 3: launch_compute_uv_kernel<from_coarse, 128, 64, 64, 16, 16>(tp, arg, min_threads, stream); break;
265  case 4: launch_compute_uv_kernel<from_coarse, 128, 64, 64, 16, 32>(tp, arg, min_threads, stream); break;
266  case 5: launch_compute_uv_kernel<from_coarse, 128, 64, 64, 32, 8>(tp, arg, min_threads, stream); break;
267  case 6: launch_compute_uv_kernel<from_coarse, 128, 64, 64, 32, 16>(tp, arg, min_threads, stream); break;
268  // clang-format on
269  default:
270  errorQuda("tp.aux.x(=%d) is NOT supported by (%d, %d, %d, %d).", tp.aux.x, Arg::fineSpin, Arg::coarseSpin,
271  Arg::fineColor, Arg::coarseColor);
272  }
273  return -1;
274  }
275 
276  template <bool from_coarse, bool query_max = false, class Arg>
277  typename std::enable_if<Arg::fineColor == 64 && Arg::coarseColor == 96 && Arg::fineSpin == 2 && Arg::coarseSpin == 2,
278  int>::type
279  launch_compute_uv_kernel(TuneParam &tp, const Arg &arg, int min_threads, const cudaStream_t &stream)
280  {
281  if (query_max) return 6;
282  switch (tp.aux.x) {
283  // clang-format off
284  case 0: launch_compute_uv_kernel<from_coarse, 64, 96, 64, 32, 24>(tp, arg, min_threads, stream); break;
285  case 1: launch_compute_uv_kernel<from_coarse, 64, 96, 64, 12, 32>(tp, arg, min_threads, stream); break;
286  case 2: launch_compute_uv_kernel<from_coarse, 64, 96, 64, 32, 12>(tp, arg, min_threads, stream); break;
287  case 3: launch_compute_uv_kernel<from_coarse, 64, 96, 64, 16, 24>(tp, arg, min_threads, stream); break;
288  case 4: launch_compute_uv_kernel<from_coarse, 64, 96, 64, 16, 48>(tp, arg, min_threads, stream); break;
289  case 5: launch_compute_uv_kernel<from_coarse, 64, 96, 64, 32, 6>(tp, arg, min_threads, stream); break;
290  case 6: launch_compute_uv_kernel<from_coarse, 64, 96, 64, 32, 8>(tp, arg, min_threads, stream); break;
291  // clang-format on
292  default:
293  errorQuda("tp.aux.x(=%d) is NOT supported by (%d, %d, %d, %d).", tp.aux.x, Arg::fineSpin, Arg::coarseSpin,
294  Arg::fineColor, Arg::coarseColor);
295  }
296  return -1;
297  }
298 
299  // note --- currently unused, may be revisited in the future
300  template <bool from_coarse, bool query_max = false, class Arg>
301  typename std::enable_if<Arg::fineColor == 96 && Arg::coarseColor == 96 && Arg::fineSpin == 2 && Arg::coarseSpin == 2,
302  int>::type
303  launch_compute_uv_kernel(TuneParam &tp, const Arg &arg, int min_threads, const cudaStream_t &stream)
304  {
305  if (query_max) return 5;
306  switch (tp.aux.x) {
307  // clang-format off
308  case 0: launch_compute_uv_kernel<from_coarse, 192, 96, 48, 24, 12>(tp, arg, min_threads, stream); break;
309  case 1: launch_compute_uv_kernel<from_coarse, 192, 96, 48, 24, 24>(tp, arg, min_threads, stream); break;
310  case 2: launch_compute_uv_kernel<from_coarse, 96, 96, 96, 24, 12>(tp, arg, min_threads, stream); break;
311  case 3: launch_compute_uv_kernel<from_coarse, 96, 96, 96, 24, 24>(tp, arg, min_threads, stream); break;
312  case 4: launch_compute_uv_kernel<from_coarse, 96, 96, 96, 32, 12>(tp, arg, min_threads, stream); break;
313  case 5: launch_compute_uv_kernel<from_coarse, 96, 96, 96, 12, 32>(tp, arg, min_threads, stream); break;
314  // clang-format on
315  default:
316  errorQuda("tp.aux.x(=%d) is NOT supported by (%d, %d, %d, %d).", tp.aux.x, Arg::fineSpin, Arg::coarseSpin,
317  Arg::fineColor, Arg::coarseColor);
318  }
319  return -1;
320  }
321 
322  template <bool from_coarse, int dim, QudaDirection dir, int bM, int bN, int bK, int block_y, int block_z, class Arg>
323  typename std::enable_if<!Arg::is_mma_compatible, void>::type
324  launch_compute_vuv_kernel(TuneParam &tp, const Arg &arg, int min_threads, const cudaStream_t &stream)
325  {
326  errorQuda("MMA implementation is ONLY built for AoS order.");
327  }
328 
329  template <bool from_coarse, int dim, QudaDirection dir, int bM, int bN, int bK, int block_y, int block_z, class Arg>
330  typename std::enable_if<Arg::is_mma_compatible, void>::type
331  launch_compute_vuv_kernel(TuneParam &tp, const Arg &arg, int min_threads, const cudaStream_t &stream)
332  {
333  tp.block.x = 1;
334  tp.block.y = block_y;
335  tp.block.z = block_z;
336  constexpr int shared_bytes = shared_memory_bytes(bM, bN, bK);
337  tp.shared_bytes = shared_bytes;
338 
339  // TODO: Fix the split M/N.
340  constexpr int t_m = 1;
341  constexpr int t_n = 1;
342 
343  tp.grid = dim3(min_threads * t_m * t_n, 2, 1);
344 
345  auto kernel = ComputeVUVMMA<from_coarse, dim, dir, bM, bN, bK, block_y, block_z, Arg>;
346  tp.set_max_shared_bytes = true;
347  qudaLaunchKernel(kernel, tp, stream, arg);
348  }
349 
350  template <bool from_coarse, int bM, int bN, int bK, int block_y, int block_z, class Arg>
351  void launch_compute_vuv_kernel(TuneParam &tp, const Arg &arg, int min_threads, const cudaStream_t &stream)
352  {
353  if (arg.dir == QUDA_BACKWARDS) {
354  switch (arg.dim) {
355  case 0:
356  launch_compute_vuv_kernel<from_coarse, 0, QUDA_BACKWARDS, bM, bN, bK, block_y, block_z>(tp, arg, min_threads,
357  stream);
358  break;
359  case 1:
360  launch_compute_vuv_kernel<from_coarse, 1, QUDA_BACKWARDS, bM, bN, bK, block_y, block_z>(tp, arg, min_threads,
361  stream);
362  break;
363  case 2:
364  launch_compute_vuv_kernel<from_coarse, 2, QUDA_BACKWARDS, bM, bN, bK, block_y, block_z>(tp, arg, min_threads,
365  stream);
366  break;
367  case 3:
368  launch_compute_vuv_kernel<from_coarse, 3, QUDA_BACKWARDS, bM, bN, bK, block_y, block_z>(tp, arg, min_threads,
369  stream);
370  break;
371  default: errorQuda("arg.dim(=%d) is NOT supported.", arg.dim);
372  }
373  } else {
374  switch (arg.dim) {
375  case 0:
376  launch_compute_vuv_kernel<from_coarse, 0, QUDA_FORWARDS, bM, bN, bK, block_y, block_z>(tp, arg, min_threads,
377  stream);
378  break;
379  case 1:
380  launch_compute_vuv_kernel<from_coarse, 1, QUDA_FORWARDS, bM, bN, bK, block_y, block_z>(tp, arg, min_threads,
381  stream);
382  break;
383  case 2:
384  launch_compute_vuv_kernel<from_coarse, 2, QUDA_FORWARDS, bM, bN, bK, block_y, block_z>(tp, arg, min_threads,
385  stream);
386  break;
387  case 3:
388  launch_compute_vuv_kernel<from_coarse, 3, QUDA_FORWARDS, bM, bN, bK, block_y, block_z>(tp, arg, min_threads,
389  stream);
390  break;
391  default: errorQuda("arg.dim(=%d) is NOT supported.", arg.dim);
392  }
393  }
394  }
395 
396  template <bool from_coarse, bool query_max = false, class Arg>
397  typename std::enable_if<!from_coarse, int>::type
398  launch_compute_vuv_kernel(TuneParam &tp, const Arg &arg, int min_threads, const cudaStream_t &stream)
399  {
400  errorQuda("MMA implementation is ONLY built for !from_coarse.");
401  return -1;
402  }
403 
404  template <bool from_coarse, bool query_max = false, class Arg>
405  typename std::enable_if<Arg::fineColor == 6 && Arg::coarseColor == 6 && Arg::fineSpin == 2 && Arg::coarseSpin == 2, int>::type
406  launch_compute_vuv_kernel(TuneParam &tp, const Arg &arg, int min_threads, const cudaStream_t &stream)
407  {
408  if (query_max) return 2;
409  switch (tp.aux.x) {
410  // clang-format off
411  case 0: launch_compute_vuv_kernel<from_coarse, 16, 16, 8, 8, 4>(tp, arg, min_threads, stream); break;
412  case 1: launch_compute_vuv_kernel<from_coarse, 16, 16, 8, 4, 8>(tp, arg, min_threads, stream); break;
413  // clang-format on
414  default:
415  errorQuda("tp.aux.x(=%d) is NOT supported by (%d, %d, %d, %d).", tp.aux.x, Arg::fineSpin, Arg::coarseSpin,
416  Arg::fineColor, Arg::coarseColor);
417  }
418  return -1;
419  }
420 
421  template <bool from_coarse, bool query_max = false, class Arg>
422  typename std::enable_if<Arg::fineColor == 24 && Arg::coarseColor == 24 && Arg::fineSpin == 2 && Arg::coarseSpin == 2,
423  int>::type
424  launch_compute_vuv_kernel(TuneParam &tp, const Arg &arg, int min_threads, const cudaStream_t &stream)
425  {
426 #if (__COMPUTE_CAPABILITY__ >= 750) // Turing or above
427  if (query_max) return 1;
428  switch (tp.aux.x) {
429  // clang-format off
430  case 0: launch_compute_vuv_kernel<from_coarse, 32, 24, 24, 16, 6>(tp, arg, min_threads, stream); break;
431  case 1: launch_compute_vuv_kernel<from_coarse, 32, 24, 24, 16, 12>(tp, arg, min_threads, stream); break;
432  // clang-format on
433  default:
434  errorQuda("tp.aux.x(=%d) is NOT supported by (%d, %d, %d, %d).", tp.aux.x, Arg::fineSpin, Arg::coarseSpin,
435  Arg::fineColor, Arg::coarseColor);
436  }
437 #else
438  if (query_max) return 4;
439  switch (tp.aux.x) {
440  // clang-format off
441  case 0: launch_compute_vuv_kernel<from_coarse, 32, 32, 24, 8, 8>(tp, arg, min_threads, stream); break;
442  case 1: launch_compute_vuv_kernel<from_coarse, 32, 32, 24, 8, 16>(tp, arg, min_threads, stream); break;
443  case 2: launch_compute_vuv_kernel<from_coarse, 32, 32, 24, 8, 16>(tp, arg, min_threads, stream); break;
444  case 3: launch_compute_vuv_kernel<from_coarse, 32, 32, 24, 16, 8>(tp, arg, min_threads, stream); break;
445  case 4: launch_compute_vuv_kernel<from_coarse, 32, 32, 24, 32, 4>(tp, arg, min_threads, stream); break;
446  // clang-format on
447  default:
448  errorQuda("tp.aux.x(=%d) is NOT supported by (%d, %d, %d, %d).", tp.aux.x, Arg::fineSpin, Arg::coarseSpin,
449  Arg::fineColor, Arg::coarseColor);
450  }
451 #endif
452  return -1;
453  }
454 
455  template <bool from_coarse, bool query_max = false, class Arg>
456  typename std::enable_if<Arg::fineColor == 24 && Arg::coarseColor == 32 && Arg::fineSpin == 2 && Arg::coarseSpin == 2,
457  int>::type
458  launch_compute_vuv_kernel(TuneParam &tp, const Arg &arg, int min_threads, const cudaStream_t &stream)
459  {
460  if (query_max) return 4;
461  switch (tp.aux.x) {
462  // clang-format off
463  case 0: launch_compute_vuv_kernel<from_coarse, 32, 32, 24, 8, 8>(tp, arg, min_threads, stream); break;
464  case 1: launch_compute_vuv_kernel<from_coarse, 32, 32, 24, 8, 16>(tp, arg, min_threads, stream); break;
465  case 2: launch_compute_vuv_kernel<from_coarse, 32, 32, 24, 8, 16>(tp, arg, min_threads, stream); break;
466  case 3: launch_compute_vuv_kernel<from_coarse, 32, 32, 24, 16, 8>(tp, arg, min_threads, stream); break;
467  case 4: launch_compute_vuv_kernel<from_coarse, 32, 32, 24, 32, 4>(tp, arg, min_threads, stream); break;
468  // clang-format on
469  default:
470  errorQuda("tp.aux.x(=%d) is NOT supported by (%d, %d, %d, %d).", tp.aux.x, Arg::fineSpin, Arg::coarseSpin,
471  Arg::fineColor, Arg::coarseColor);
472  }
473  return -1;
474  }
475 
476  template <bool from_coarse, bool query_max = false, class Arg>
477  typename std::enable_if<Arg::fineColor == 24 && Arg::coarseColor == 64 && Arg::fineSpin == 2 && Arg::coarseSpin == 2,
478  int>::type
479  launch_compute_vuv_kernel(TuneParam &tp, const Arg &arg, int min_threads, const cudaStream_t &stream)
480  {
481  if (query_max) return 7;
482  switch (tp.aux.x) {
483  // clang-format off
484  case 0: launch_compute_vuv_kernel<from_coarse, 64, 64, 24, 8, 8>(tp, arg, min_threads, stream); break;
485  case 1: launch_compute_vuv_kernel<from_coarse, 64, 64, 24, 8, 16>(tp, arg, min_threads, stream); break;
486  case 2: launch_compute_vuv_kernel<from_coarse, 64, 64, 24, 8, 32>(tp, arg, min_threads, stream); break;
487  case 3: launch_compute_vuv_kernel<from_coarse, 64, 64, 24, 16, 8>(tp, arg, min_threads, stream); break;
488  case 4: launch_compute_vuv_kernel<from_coarse, 64, 64, 24, 16, 16>(tp, arg, min_threads, stream); break;
489  case 5: launch_compute_vuv_kernel<from_coarse, 64, 64, 24, 16, 32>(tp, arg, min_threads, stream); break;
490  case 6: launch_compute_vuv_kernel<from_coarse, 64, 64, 24, 32, 8>(tp, arg, min_threads, stream); break;
491  case 7: launch_compute_vuv_kernel<from_coarse, 64, 64, 24, 32, 16>(tp, arg, min_threads, stream); break;
492  // clang-format on
493  default:
494  errorQuda("tp.aux.x(=%d) is NOT supported by (%d, %d, %d, %d).", tp.aux.x, Arg::fineSpin, Arg::coarseSpin,
495  Arg::fineColor, Arg::coarseColor);
496  }
497  return -1;
498  }
499 
500  // note -- currently unused, may be used in the future
501  template <bool from_coarse, bool query_max = false, class Arg>
502  typename std::enable_if<Arg::fineColor == 24 && Arg::coarseColor == 96 && Arg::fineSpin == 2 && Arg::coarseSpin == 2,
503  int>::type
504  launch_compute_vuv_kernel(TuneParam &tp, const Arg &arg, int min_threads, const cudaStream_t &stream)
505  {
506  if (query_max) return 6;
507  // clang-format off
508  switch (tp.aux.x) {
509  case 0: launch_compute_vuv_kernel<from_coarse, 96, 96, 24, 12, 8>(tp, arg, min_threads, stream); break;
510  case 1: launch_compute_vuv_kernel<from_coarse, 96, 96, 24, 24, 8>(tp, arg, min_threads, stream); break;
511  case 2: launch_compute_vuv_kernel<from_coarse, 96, 96, 24, 6, 16>(tp, arg, min_threads, stream); break;
512  case 3: launch_compute_vuv_kernel<from_coarse, 96, 96, 24, 12, 16>(tp, arg, min_threads, stream); break;
513  case 4: launch_compute_vuv_kernel<from_coarse, 96, 96, 24, 24, 16>(tp, arg, min_threads, stream); break;
514  case 5: launch_compute_vuv_kernel<from_coarse, 96, 96, 24, 12, 24>(tp, arg, min_threads, stream); break;
515  case 6: launch_compute_vuv_kernel<from_coarse, 96, 96, 24, 24, 24>(tp, arg, min_threads, stream); break;
516  default: errorQuda("tp.aux.x(=%d) is NOT supported by (%d, %d, %d, %d).", tp.aux.x, Arg::fineSpin, Arg::coarseSpin, Arg::fineColor, Arg::coarseColor);
517  }
518  // clang-format on
519  return -1;
520  }
521 
522  template <bool from_coarse, bool query_max = false, class Arg>
523  typename std::enable_if<Arg::fineColor == 32 && Arg::coarseColor == 32 && Arg::fineSpin == 2 && Arg::coarseSpin == 2,
524  int>::type
525  launch_compute_vuv_kernel(TuneParam &tp, const Arg &arg, int min_threads, const cudaStream_t &stream)
526  {
527  if (query_max) return 3;
528  // clang-format off
529  switch (tp.aux.x) {
530  case 0: launch_compute_vuv_kernel<from_coarse, 32, 32, 32, 8, 8>(tp, arg, min_threads, stream); break;
531  case 1: launch_compute_vuv_kernel<from_coarse, 32, 32, 32, 8, 16>(tp, arg, min_threads, stream); break;
532  case 2: launch_compute_vuv_kernel<from_coarse, 32, 32, 32, 16, 8>(tp, arg, min_threads, stream); break;
533  case 3: launch_compute_vuv_kernel<from_coarse, 32, 32, 32, 32, 4>(tp, arg, min_threads, stream); break;
534  default: errorQuda("tp.aux.x(=%d) is NOT supported by (%d, %d, %d, %d).", tp.aux.x, Arg::fineSpin, Arg::coarseSpin, Arg::fineColor, Arg::coarseColor);
535  }
536  // clang-format on
537  return -1;
538  }
539 
540  template <bool from_coarse, bool query_max = false, class Arg>
541  typename std::enable_if<Arg::fineColor == 64 && Arg::coarseColor == 64 && Arg::fineSpin == 2 && Arg::coarseSpin == 2,
542  int>::type
543  launch_compute_vuv_kernel(TuneParam &tp, const Arg &arg, int min_threads, const cudaStream_t &stream)
544  {
545  if (query_max) return 7;
546  // clang-format off
547  switch (tp.aux.x) {
548  case 0: launch_compute_vuv_kernel<from_coarse, 64, 64, 64, 8, 8>(tp, arg, min_threads, stream); break;
549  case 1: launch_compute_vuv_kernel<from_coarse, 64, 64, 64, 8, 16>(tp, arg, min_threads, stream); break;
550  case 2: launch_compute_vuv_kernel<from_coarse, 64, 64, 64, 16, 8>(tp, arg, min_threads, stream); break;
551  case 3: launch_compute_vuv_kernel<from_coarse, 64, 64, 64, 16, 16>(tp, arg, min_threads, stream); break;
552  case 4: launch_compute_vuv_kernel<from_coarse, 64, 64, 64, 16, 32>(tp, arg, min_threads, stream); break;
553  case 5: launch_compute_vuv_kernel<from_coarse, 64, 64, 64, 32, 4>(tp, arg, min_threads, stream); break;
554  case 6: launch_compute_vuv_kernel<from_coarse, 64, 64, 64, 32, 8>(tp, arg, min_threads, stream); break;
555  case 7: launch_compute_vuv_kernel<from_coarse, 64, 64, 64, 32, 16>(tp, arg, min_threads, stream); break;
556  default: errorQuda("tp.aux.x(=%d) is NOT supported by (%d, %d, %d, %d).", tp.aux.x, Arg::fineSpin, Arg::coarseSpin, Arg::fineColor, Arg::coarseColor);
557  }
558  // clang-format on
559  return -1;
560  }
561 
562  template <bool from_coarse, bool query_max = false, class Arg>
563  typename std::enable_if<Arg::fineColor == 64 && Arg::coarseColor == 96 && Arg::fineSpin == 2 && Arg::coarseSpin == 2,
564  int>::type
565  launch_compute_vuv_kernel(TuneParam &tp, const Arg &arg, int min_threads, const cudaStream_t &stream)
566  {
567  if (query_max) return 6;
568  switch (tp.aux.x) {
569  // clang-format off
570  case 0: launch_compute_vuv_kernel<from_coarse, 96, 96, 64, 8, 8>(tp, arg, min_threads, stream); break;
571  case 1: launch_compute_vuv_kernel<from_coarse, 96, 96, 64, 8, 12>(tp, arg, min_threads, stream); break;
572  case 2: launch_compute_vuv_kernel<from_coarse, 96, 96, 64, 16, 6>(tp, arg, min_threads, stream); break;
573  case 3: launch_compute_vuv_kernel<from_coarse, 96, 96, 64, 16, 8>(tp, arg, min_threads, stream); break;
574  case 4: launch_compute_vuv_kernel<from_coarse, 96, 96, 64, 16, 12>(tp, arg, min_threads, stream); break;
575  case 5: launch_compute_vuv_kernel<from_coarse, 96, 96, 64, 32, 4>(tp, arg, min_threads, stream); break;
576  case 6: launch_compute_vuv_kernel<from_coarse, 96, 96, 64, 32, 6>(tp, arg, min_threads, stream); break;
577  // clang-format on
578  default:
579  errorQuda("tp.aux.x(=%d) is NOT supported by (%d, %d, %d, %d).", tp.aux.x, Arg::fineSpin, Arg::coarseSpin,
580  Arg::fineColor, Arg::coarseColor);
581  }
582  return -1;
583  }
584 
585  // note -- currently unused, may be revisited in the future
586  template <bool from_coarse, bool query_max = false, class Arg>
587  typename std::enable_if<Arg::fineColor == 96 && Arg::coarseColor == 96 && Arg::fineSpin == 2 && Arg::coarseSpin == 2,
588  int>::type
589  launch_compute_vuv_kernel(TuneParam &tp, const Arg &arg, int min_threads, const cudaStream_t &stream)
590  {
591  if (query_max) return 6;
592  switch (tp.aux.x) {
593  // clang-format off
594  case 0: launch_compute_vuv_kernel<from_coarse, 96, 96, 96, 12, 8>(tp, arg, min_threads, stream); break;
595  case 1: launch_compute_vuv_kernel<from_coarse, 96, 96, 96, 24, 8>(tp, arg, min_threads, stream); break;
596  case 2: launch_compute_vuv_kernel<from_coarse, 96, 96, 96, 6, 16>(tp, arg, min_threads, stream); break;
597  case 3: launch_compute_vuv_kernel<from_coarse, 96, 96, 96, 12, 16>(tp, arg, min_threads, stream); break;
598  case 4: launch_compute_vuv_kernel<from_coarse, 96, 96, 96, 24, 16>(tp, arg, min_threads, stream); break;
599  case 5: launch_compute_vuv_kernel<from_coarse, 96, 96, 96, 12, 24>(tp, arg, min_threads, stream); break;
600  case 6: launch_compute_vuv_kernel<from_coarse, 96, 96, 96, 24, 24>(tp, arg, min_threads, stream); break;
601  // clang-format on
602  default:
603  errorQuda("tp.aux.x(=%d) is NOT supported by (%d, %d, %d, %d).", tp.aux.x, Arg::fineSpin, Arg::coarseSpin,
604  Arg::fineColor, Arg::coarseColor);
605  }
606  return -1;
607  }
608 
609 #else
610 
611  template <bool from_coarse, bool query_max = false, class Arg>
612  int launch_compute_uv_kernel(TuneParam &tp, const Arg &arg, int min_threads, const cudaStream_t &stream)
613  {
614  errorQuda("MMA multigrid is not available for this setup.");
615  return -1;
616  }
617 
618  template <bool from_coarse, bool query_max = false, class Arg>
619  int launch_compute_vuv_kernel(TuneParam &tp, const Arg &arg, int min_threads, const cudaStream_t &stream)
620  {
621  errorQuda("MMA multigrid is not available for this setup.");
622  return -1;
623  }
624 
625 #endif // compute capability >= 700, CUDA >= 10.1
626 
627  } // namespace mma
628 
629 } // namespace quda
bool set_max_shared_bytes
Definition: tune_quda.h:31
@ QUDA_BACKWARDS
Definition: enum_quda.h:491
int launch_compute_vuv_kernel(TuneParam &tp, const Arg &arg, int min_threads, const cudaStream_t &stream)
int launch_compute_uv_kernel(TuneParam &tp, const Arg &arg, int min_threads, const cudaStream_t &stream)
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
qudaStream_t * stream
qudaError_t qudaLaunchKernel(const void *func, const TuneParam &tp, void **args, qudaStream_t stream)
Wrapper around cudaLaunchKernel.
Definition: quda_api.cpp:57
#define errorQuda(...)
Definition: util_quda.h:120