5 #if (CUDA_VERSION >= 10010 && __COMPUTE_CAPABILITY__ >= 700)
7 #include <kernels/coarse_op_kernel_mma.cuh>
23 #if (CUDA_VERSION >= 10010 && __COMPUTE_CAPABILITY__ >= 700)
25 template <
bool from_coarse,
int dim, QudaDirection dir,
int bM,
int bN,
int bK,
int block_y,
int block_z,
class Arg>
26 typename std::enable_if<!Arg::is_mma_compatible, void>::type
29 errorQuda(
"MMA implementation is ONLY built for AoS order.");
32 template <
bool from_coarse,
int dim, QudaDirection dir,
int bM,
int bN,
int bK,
int block_y,
int block_z,
class Arg>
33 typename std::enable_if<Arg::is_mma_compatible, void>::type
39 constexpr
int shared_bytes = shared_memory_bytes(bM, bN, bK);
43 constexpr
int t_m = 1;
44 constexpr
int t_n = 1;
46 tp.
grid = dim3(min_threads * t_m * t_n, 2, 1);
48 auto kernel = ComputeUVMMA<from_coarse, dim, dir, bM, bN, bK, block_y, block_z, Arg>;
53 template <
bool from_coarse,
int bM,
int bN,
int bK,
int block_y,
int block_z,
class Arg>
59 launch_compute_uv_kernel<from_coarse, 0, QUDA_BACKWARDS, bM, bN, bK, block_y, block_z>(tp,
arg, min_threads,
63 launch_compute_uv_kernel<from_coarse, 1, QUDA_BACKWARDS, bM, bN, bK, block_y, block_z>(tp,
arg, min_threads,
67 launch_compute_uv_kernel<from_coarse, 2, QUDA_BACKWARDS, bM, bN, bK, block_y, block_z>(tp,
arg, min_threads,
71 launch_compute_uv_kernel<from_coarse, 3, QUDA_BACKWARDS, bM, bN, bK, block_y, block_z>(tp,
arg, min_threads,
74 default:
errorQuda(
"arg.dim(=%d) is NOT supported.",
arg.dim);
79 launch_compute_uv_kernel<from_coarse, 0, QUDA_FORWARDS, bM, bN, bK, block_y, block_z>(tp,
arg, min_threads,
83 launch_compute_uv_kernel<from_coarse, 1, QUDA_FORWARDS, bM, bN, bK, block_y, block_z>(tp,
arg, min_threads,
87 launch_compute_uv_kernel<from_coarse, 2, QUDA_FORWARDS, bM, bN, bK, block_y, block_z>(tp,
arg, min_threads,
91 launch_compute_uv_kernel<from_coarse, 3, QUDA_FORWARDS, bM, bN, bK, block_y, block_z>(tp,
arg, min_threads,
94 default:
errorQuda(
"arg.dim(=%d) is NOT supported.",
arg.dim);
99 template <
bool from_coarse,
bool query_max = false,
class Arg>
100 typename std::enable_if<!from_coarse, int>::type
103 errorQuda(
"MMA implementation is ONLY built for !from_coarse.");
113 template <
bool from_coarse,
bool query_max = false,
class Arg>
114 typename std::enable_if<Arg::fineColor == 6 && Arg::coarseColor == 6 && Arg::fineSpin == 2 && Arg::coarseSpin == 2, int>::type
117 if (query_max)
return 1;
120 case 0: launch_compute_uv_kernel<from_coarse, 16, 16, 8, 4, 8>(tp,
arg, min_threads,
stream);
break;
121 case 1: launch_compute_uv_kernel<from_coarse, 16, 16, 8, 8, 4>(tp,
arg, min_threads,
stream);
break;
124 errorQuda(
"tp.aux.x(=%d) is NOT supported by (%d, %d, %d, %d).", tp.
aux.x, Arg::fineSpin, Arg::coarseSpin,
125 Arg::fineColor, Arg::coarseColor);
130 template <
bool from_coarse,
bool query_max = false,
class Arg>
131 typename std::enable_if<Arg::fineColor == 24 && Arg::coarseColor == 24 && Arg::fineSpin == 2 && Arg::coarseSpin == 2,
135 #if (__COMPUTE_CAPABILITY__ >= 750)
136 if (query_max)
return 5;
139 case 0: launch_compute_uv_kernel<from_coarse, 48, 24, 24, 24, 12>(tp,
arg, min_threads,
stream);
break;
140 case 1: launch_compute_uv_kernel<from_coarse, 48, 24, 24, 16, 6>(tp,
arg, min_threads,
stream);
break;
141 case 2: launch_compute_uv_kernel<from_coarse, 48, 24, 24, 16, 2>(tp,
arg, min_threads,
stream);
break;
142 case 3: launch_compute_uv_kernel<from_coarse, 48, 24, 24, 8, 12>(tp,
arg, min_threads,
stream);
break;
143 case 4: launch_compute_uv_kernel<from_coarse, 48, 24, 24, 8, 4>(tp,
arg, min_threads,
stream);
break;
144 case 5: launch_compute_uv_kernel<from_coarse, 48, 24, 24, 4, 8>(tp,
arg, min_threads,
stream);
break;
147 errorQuda(
"tp.aux.x(=%d) is NOT supported by (%d, %d, %d, %d).", tp.
aux.x, Arg::fineSpin, Arg::coarseSpin,
148 Arg::fineColor, Arg::coarseColor);
151 if (query_max)
return 4;
154 case 0: launch_compute_uv_kernel<from_coarse, 48, 32, 24, 8, 8>(tp,
arg, min_threads,
stream);
break;
155 case 1: launch_compute_uv_kernel<from_coarse, 48, 32, 24, 8, 12>(tp,
arg, min_threads,
stream);
break;
156 case 2: launch_compute_uv_kernel<from_coarse, 48, 32, 24, 8, 24>(tp,
arg, min_threads,
stream);
break;
157 case 3: launch_compute_uv_kernel<from_coarse, 48, 32, 24, 16, 4>(tp,
arg, min_threads,
stream);
break;
158 case 4: launch_compute_uv_kernel<from_coarse, 48, 32, 24, 16, 12>(tp,
arg, min_threads,
stream);
break;
161 errorQuda(
"tp.aux.x(=%d) is NOT supported by (%d, %d, %d, %d).", tp.
aux.x, Arg::fineSpin, Arg::coarseSpin,
162 Arg::fineColor, Arg::coarseColor);
168 template <
bool from_coarse,
bool query_max = false,
class Arg>
169 typename std::enable_if<Arg::fineColor == 24 && Arg::coarseColor == 32 && Arg::fineSpin == 2 && Arg::coarseSpin == 2,
173 if (query_max)
return 3;
176 case 0: launch_compute_uv_kernel<from_coarse, 48, 32, 24, 8, 12>(tp,
arg, min_threads,
stream);
break;
177 case 1: launch_compute_uv_kernel<from_coarse, 48, 32, 24, 8, 12>(tp,
arg, min_threads,
stream);
break;
178 case 2: launch_compute_uv_kernel<from_coarse, 48, 32, 24, 8, 24>(tp,
arg, min_threads,
stream);
break;
179 case 3: launch_compute_uv_kernel<from_coarse, 48, 32, 24, 16, 12>(tp,
arg, min_threads,
stream);
break;
182 errorQuda(
"tp.aux.x(=%d) is NOT supported by (%d, %d, %d, %d).", tp.
aux.x, Arg::fineSpin, Arg::coarseSpin,
183 Arg::fineColor, Arg::coarseColor);
188 template <
bool from_coarse,
bool query_max = false,
class Arg>
189 typename std::enable_if<Arg::fineColor == 24 && Arg::coarseColor == 64 && Arg::fineSpin == 2 && Arg::coarseSpin == 2,
193 if (query_max)
return 5;
196 case 0: launch_compute_uv_kernel<from_coarse, 48, 64, 24, 8, 12>(tp,
arg, min_threads,
stream);
break;
197 case 1: launch_compute_uv_kernel<from_coarse, 48, 64, 24, 8, 12>(tp,
arg, min_threads,
stream);
break;
198 case 2: launch_compute_uv_kernel<from_coarse, 48, 64, 24, 8, 24>(tp,
arg, min_threads,
stream);
break;
199 case 3: launch_compute_uv_kernel<from_coarse, 48, 64, 24, 16, 12>(tp,
arg, min_threads,
stream);
break;
200 case 4: launch_compute_uv_kernel<from_coarse, 48, 64, 24, 32, 12>(tp,
arg, min_threads,
stream);
break;
201 case 5: launch_compute_uv_kernel<from_coarse, 48, 64, 24, 16, 24>(tp,
arg, min_threads,
stream);
break;
204 errorQuda(
"tp.aux.x(=%d) is NOT supported by (%d, %d, %d, %d).", tp.
aux.x, Arg::fineSpin, Arg::coarseSpin,
205 Arg::fineColor, Arg::coarseColor);
211 template <
bool from_coarse,
bool query_max = false,
class Arg>
212 typename std::enable_if<Arg::fineColor == 24 && Arg::coarseColor == 96 && Arg::fineSpin == 2 && Arg::coarseSpin == 2,
216 if (query_max)
return 6;
219 case 0: launch_compute_uv_kernel<from_coarse, 48, 96, 24, 8, 12>(tp,
arg, min_threads,
stream);
break;
220 case 1: launch_compute_uv_kernel<from_coarse, 48, 96, 24, 8, 24>(tp,
arg, min_threads,
stream);
break;
221 case 2: launch_compute_uv_kernel<from_coarse, 48, 96, 24, 16, 6>(tp,
arg, min_threads,
stream);
break;
222 case 3: launch_compute_uv_kernel<from_coarse, 48, 96, 24, 16, 12>(tp,
arg, min_threads,
stream);
break;
223 case 4: launch_compute_uv_kernel<from_coarse, 48, 96, 24, 16, 12>(tp,
arg, min_threads,
stream);
break;
224 case 5: launch_compute_uv_kernel<from_coarse, 48, 96, 24, 24, 12>(tp,
arg, min_threads,
stream);
break;
225 case 6: launch_compute_uv_kernel<from_coarse, 48, 96, 24, 24, 24>(tp,
arg, min_threads,
stream);
break;
228 errorQuda(
"tp.aux.x(=%d) is NOT supported by (%d, %d, %d, %d).", tp.
aux.x, Arg::fineSpin, Arg::coarseSpin,
229 Arg::fineColor, Arg::coarseColor);
234 template <
bool from_coarse,
bool query_max = false,
class Arg>
235 typename std::enable_if<Arg::fineColor == 32 && Arg::coarseColor == 32 && Arg::fineSpin == 2 && Arg::coarseSpin == 2,
239 if (query_max)
return 2;
242 case 0: launch_compute_uv_kernel<from_coarse, 64, 32, 32, 8, 16>(tp,
arg, min_threads,
stream);
break;
243 case 1: launch_compute_uv_kernel<from_coarse, 64, 32, 32, 8, 32>(tp,
arg, min_threads,
stream);
break;
244 case 2: launch_compute_uv_kernel<from_coarse, 64, 32, 32, 16, 16>(tp,
arg, min_threads,
stream);
break;
247 errorQuda(
"tp.aux.x(=%d) is NOT supported by (%d, %d, %d, %d).", tp.
aux.x, Arg::fineSpin, Arg::coarseSpin,
248 Arg::fineColor, Arg::coarseColor);
253 template <
bool from_coarse,
bool query_max = false,
class Arg>
254 typename std::enable_if<Arg::fineColor == 64 && Arg::coarseColor == 64 && Arg::fineSpin == 2 && Arg::coarseSpin == 2,
258 if (query_max)
return 6;
261 case 0: launch_compute_uv_kernel<from_coarse, 128, 64, 64, 8, 16>(tp,
arg, min_threads,
stream);
break;
262 case 1: launch_compute_uv_kernel<from_coarse, 128, 64, 64, 8, 32>(tp,
arg, min_threads,
stream);
break;
263 case 2: launch_compute_uv_kernel<from_coarse, 128, 64, 64, 16, 8>(tp,
arg, min_threads,
stream);
break;
264 case 3: launch_compute_uv_kernel<from_coarse, 128, 64, 64, 16, 16>(tp,
arg, min_threads,
stream);
break;
265 case 4: launch_compute_uv_kernel<from_coarse, 128, 64, 64, 16, 32>(tp,
arg, min_threads,
stream);
break;
266 case 5: launch_compute_uv_kernel<from_coarse, 128, 64, 64, 32, 8>(tp,
arg, min_threads,
stream);
break;
267 case 6: launch_compute_uv_kernel<from_coarse, 128, 64, 64, 32, 16>(tp,
arg, min_threads,
stream);
break;
270 errorQuda(
"tp.aux.x(=%d) is NOT supported by (%d, %d, %d, %d).", tp.
aux.x, Arg::fineSpin, Arg::coarseSpin,
271 Arg::fineColor, Arg::coarseColor);
276 template <
bool from_coarse,
bool query_max = false,
class Arg>
277 typename std::enable_if<Arg::fineColor == 64 && Arg::coarseColor == 96 && Arg::fineSpin == 2 && Arg::coarseSpin == 2,
281 if (query_max)
return 6;
284 case 0: launch_compute_uv_kernel<from_coarse, 64, 96, 64, 32, 24>(tp,
arg, min_threads,
stream);
break;
285 case 1: launch_compute_uv_kernel<from_coarse, 64, 96, 64, 12, 32>(tp,
arg, min_threads,
stream);
break;
286 case 2: launch_compute_uv_kernel<from_coarse, 64, 96, 64, 32, 12>(tp,
arg, min_threads,
stream);
break;
287 case 3: launch_compute_uv_kernel<from_coarse, 64, 96, 64, 16, 24>(tp,
arg, min_threads,
stream);
break;
288 case 4: launch_compute_uv_kernel<from_coarse, 64, 96, 64, 16, 48>(tp,
arg, min_threads,
stream);
break;
289 case 5: launch_compute_uv_kernel<from_coarse, 64, 96, 64, 32, 6>(tp,
arg, min_threads,
stream);
break;
290 case 6: launch_compute_uv_kernel<from_coarse, 64, 96, 64, 32, 8>(tp,
arg, min_threads,
stream);
break;
293 errorQuda(
"tp.aux.x(=%d) is NOT supported by (%d, %d, %d, %d).", tp.
aux.x, Arg::fineSpin, Arg::coarseSpin,
294 Arg::fineColor, Arg::coarseColor);
300 template <
bool from_coarse,
bool query_max = false,
class Arg>
301 typename std::enable_if<Arg::fineColor == 96 && Arg::coarseColor == 96 && Arg::fineSpin == 2 && Arg::coarseSpin == 2,
305 if (query_max)
return 5;
308 case 0: launch_compute_uv_kernel<from_coarse, 192, 96, 48, 24, 12>(tp,
arg, min_threads,
stream);
break;
309 case 1: launch_compute_uv_kernel<from_coarse, 192, 96, 48, 24, 24>(tp,
arg, min_threads,
stream);
break;
310 case 2: launch_compute_uv_kernel<from_coarse, 96, 96, 96, 24, 12>(tp,
arg, min_threads,
stream);
break;
311 case 3: launch_compute_uv_kernel<from_coarse, 96, 96, 96, 24, 24>(tp,
arg, min_threads,
stream);
break;
312 case 4: launch_compute_uv_kernel<from_coarse, 96, 96, 96, 32, 12>(tp,
arg, min_threads,
stream);
break;
313 case 5: launch_compute_uv_kernel<from_coarse, 96, 96, 96, 12, 32>(tp,
arg, min_threads,
stream);
break;
316 errorQuda(
"tp.aux.x(=%d) is NOT supported by (%d, %d, %d, %d).", tp.
aux.x, Arg::fineSpin, Arg::coarseSpin,
317 Arg::fineColor, Arg::coarseColor);
322 template <
bool from_coarse,
int dim, QudaDirection dir,
int bM,
int bN,
int bK,
int block_y,
int block_z,
class Arg>
323 typename std::enable_if<!Arg::is_mma_compatible, void>::type
326 errorQuda(
"MMA implementation is ONLY built for AoS order.");
329 template <
bool from_coarse,
int dim, QudaDirection dir,
int bM,
int bN,
int bK,
int block_y,
int block_z,
class Arg>
330 typename std::enable_if<Arg::is_mma_compatible, void>::type
334 tp.
block.y = block_y;
335 tp.
block.z = block_z;
336 constexpr
int shared_bytes = shared_memory_bytes(bM, bN, bK);
340 constexpr
int t_m = 1;
341 constexpr
int t_n = 1;
343 tp.
grid = dim3(min_threads * t_m * t_n, 2, 1);
345 auto kernel = ComputeVUVMMA<from_coarse, dim, dir, bM, bN, bK, block_y, block_z, Arg>;
350 template <
bool from_coarse,
int bM,
int bN,
int bK,
int block_y,
int block_z,
class Arg>
356 launch_compute_vuv_kernel<from_coarse, 0, QUDA_BACKWARDS, bM, bN, bK, block_y, block_z>(tp,
arg, min_threads,
360 launch_compute_vuv_kernel<from_coarse, 1, QUDA_BACKWARDS, bM, bN, bK, block_y, block_z>(tp,
arg, min_threads,
364 launch_compute_vuv_kernel<from_coarse, 2, QUDA_BACKWARDS, bM, bN, bK, block_y, block_z>(tp,
arg, min_threads,
368 launch_compute_vuv_kernel<from_coarse, 3, QUDA_BACKWARDS, bM, bN, bK, block_y, block_z>(tp,
arg, min_threads,
371 default:
errorQuda(
"arg.dim(=%d) is NOT supported.",
arg.dim);
376 launch_compute_vuv_kernel<from_coarse, 0, QUDA_FORWARDS, bM, bN, bK, block_y, block_z>(tp,
arg, min_threads,
380 launch_compute_vuv_kernel<from_coarse, 1, QUDA_FORWARDS, bM, bN, bK, block_y, block_z>(tp,
arg, min_threads,
384 launch_compute_vuv_kernel<from_coarse, 2, QUDA_FORWARDS, bM, bN, bK, block_y, block_z>(tp,
arg, min_threads,
388 launch_compute_vuv_kernel<from_coarse, 3, QUDA_FORWARDS, bM, bN, bK, block_y, block_z>(tp,
arg, min_threads,
391 default:
errorQuda(
"arg.dim(=%d) is NOT supported.",
arg.dim);
396 template <
bool from_coarse,
bool query_max = false,
class Arg>
397 typename std::enable_if<!from_coarse, int>::type
400 errorQuda(
"MMA implementation is ONLY built for !from_coarse.");
404 template <
bool from_coarse,
bool query_max = false,
class Arg>
405 typename std::enable_if<Arg::fineColor == 6 && Arg::coarseColor == 6 && Arg::fineSpin == 2 && Arg::coarseSpin == 2, int>::type
// Only case 0 and case 1 are instantiated below for this (6, 6, 2, 2) configuration.
// Every other overload in this file returns its highest case label when query_max is set,
// so report 1 here; returning 2 would advertise an aux.x value that falls through to the
// "NOT supported" error path.
408 if (query_max)
return 1;
411 case 0: launch_compute_vuv_kernel<from_coarse, 16, 16, 8, 8, 4>(tp,
arg, min_threads,
stream);
break;
412 case 1: launch_compute_vuv_kernel<from_coarse, 16, 16, 8, 4, 8>(tp,
arg, min_threads,
stream);
break;
415 errorQuda(
"tp.aux.x(=%d) is NOT supported by (%d, %d, %d, %d).", tp.
aux.x, Arg::fineSpin, Arg::coarseSpin,
416 Arg::fineColor, Arg::coarseColor);
421 template <
bool from_coarse,
bool query_max = false,
class Arg>
422 typename std::enable_if<Arg::fineColor == 24 && Arg::coarseColor == 24 && Arg::fineSpin == 2 && Arg::coarseSpin == 2,
426 #if (__COMPUTE_CAPABILITY__ >= 750)
427 if (query_max)
return 1;
430 case 0: launch_compute_vuv_kernel<from_coarse, 32, 24, 24, 16, 6>(tp,
arg, min_threads,
stream);
break;
431 case 1: launch_compute_vuv_kernel<from_coarse, 32, 24, 24, 16, 12>(tp,
arg, min_threads,
stream);
break;
434 errorQuda(
"tp.aux.x(=%d) is NOT supported by (%d, %d, %d, %d).", tp.
aux.x, Arg::fineSpin, Arg::coarseSpin,
435 Arg::fineColor, Arg::coarseColor);
438 if (query_max)
return 4;
441 case 0: launch_compute_vuv_kernel<from_coarse, 32, 32, 24, 8, 8>(tp,
arg, min_threads,
stream);
break;
442 case 1: launch_compute_vuv_kernel<from_coarse, 32, 32, 24, 8, 16>(tp,
arg, min_threads,
stream);
break;
443 case 2: launch_compute_vuv_kernel<from_coarse, 32, 32, 24, 8, 16>(tp,
arg, min_threads,
stream);
break;
444 case 3: launch_compute_vuv_kernel<from_coarse, 32, 32, 24, 16, 8>(tp,
arg, min_threads,
stream);
break;
445 case 4: launch_compute_vuv_kernel<from_coarse, 32, 32, 24, 32, 4>(tp,
arg, min_threads,
stream);
break;
448 errorQuda(
"tp.aux.x(=%d) is NOT supported by (%d, %d, %d, %d).", tp.
aux.x, Arg::fineSpin, Arg::coarseSpin,
449 Arg::fineColor, Arg::coarseColor);
455 template <
bool from_coarse,
bool query_max = false,
class Arg>
456 typename std::enable_if<Arg::fineColor == 24 && Arg::coarseColor == 32 && Arg::fineSpin == 2 && Arg::coarseSpin == 2,
460 if (query_max)
return 4;
463 case 0: launch_compute_vuv_kernel<from_coarse, 32, 32, 24, 8, 8>(tp,
arg, min_threads,
stream);
break;
464 case 1: launch_compute_vuv_kernel<from_coarse, 32, 32, 24, 8, 16>(tp,
arg, min_threads,
stream);
break;
465 case 2: launch_compute_vuv_kernel<from_coarse, 32, 32, 24, 8, 16>(tp,
arg, min_threads,
stream);
break;
466 case 3: launch_compute_vuv_kernel<from_coarse, 32, 32, 24, 16, 8>(tp,
arg, min_threads,
stream);
break;
467 case 4: launch_compute_vuv_kernel<from_coarse, 32, 32, 24, 32, 4>(tp,
arg, min_threads,
stream);
break;
470 errorQuda(
"tp.aux.x(=%d) is NOT supported by (%d, %d, %d, %d).", tp.
aux.x, Arg::fineSpin, Arg::coarseSpin,
471 Arg::fineColor, Arg::coarseColor);
476 template <
bool from_coarse,
bool query_max = false,
class Arg>
477 typename std::enable_if<Arg::fineColor == 24 && Arg::coarseColor == 64 && Arg::fineSpin == 2 && Arg::coarseSpin == 2,
481 if (query_max)
return 7;
484 case 0: launch_compute_vuv_kernel<from_coarse, 64, 64, 24, 8, 8>(tp,
arg, min_threads,
stream);
break;
485 case 1: launch_compute_vuv_kernel<from_coarse, 64, 64, 24, 8, 16>(tp,
arg, min_threads,
stream);
break;
486 case 2: launch_compute_vuv_kernel<from_coarse, 64, 64, 24, 8, 32>(tp,
arg, min_threads,
stream);
break;
487 case 3: launch_compute_vuv_kernel<from_coarse, 64, 64, 24, 16, 8>(tp,
arg, min_threads,
stream);
break;
488 case 4: launch_compute_vuv_kernel<from_coarse, 64, 64, 24, 16, 16>(tp,
arg, min_threads,
stream);
break;
489 case 5: launch_compute_vuv_kernel<from_coarse, 64, 64, 24, 16, 32>(tp,
arg, min_threads,
stream);
break;
490 case 6: launch_compute_vuv_kernel<from_coarse, 64, 64, 24, 32, 8>(tp,
arg, min_threads,
stream);
break;
491 case 7: launch_compute_vuv_kernel<from_coarse, 64, 64, 24, 32, 16>(tp,
arg, min_threads,
stream);
break;
494 errorQuda(
"tp.aux.x(=%d) is NOT supported by (%d, %d, %d, %d).", tp.
aux.x, Arg::fineSpin, Arg::coarseSpin,
495 Arg::fineColor, Arg::coarseColor);
501 template <
bool from_coarse,
bool query_max = false,
class Arg>
502 typename std::enable_if<Arg::fineColor == 24 && Arg::coarseColor == 96 && Arg::fineSpin == 2 && Arg::coarseSpin == 2,
506 if (query_max)
return 6;
509 case 0: launch_compute_vuv_kernel<from_coarse, 96, 96, 24, 12, 8>(tp,
arg, min_threads,
stream);
break;
510 case 1: launch_compute_vuv_kernel<from_coarse, 96, 96, 24, 24, 8>(tp,
arg, min_threads,
stream);
break;
511 case 2: launch_compute_vuv_kernel<from_coarse, 96, 96, 24, 6, 16>(tp,
arg, min_threads,
stream);
break;
512 case 3: launch_compute_vuv_kernel<from_coarse, 96, 96, 24, 12, 16>(tp,
arg, min_threads,
stream);
break;
513 case 4: launch_compute_vuv_kernel<from_coarse, 96, 96, 24, 24, 16>(tp,
arg, min_threads,
stream);
break;
514 case 5: launch_compute_vuv_kernel<from_coarse, 96, 96, 24, 12, 24>(tp,
arg, min_threads,
stream);
break;
515 case 6: launch_compute_vuv_kernel<from_coarse, 96, 96, 24, 24, 24>(tp,
arg, min_threads,
stream);
break;
516 default:
errorQuda(
"tp.aux.x(=%d) is NOT supported by (%d, %d, %d, %d).", tp.
aux.x, Arg::fineSpin, Arg::coarseSpin, Arg::fineColor, Arg::coarseColor);
522 template <
bool from_coarse,
bool query_max = false,
class Arg>
523 typename std::enable_if<Arg::fineColor == 32 && Arg::coarseColor == 32 && Arg::fineSpin == 2 && Arg::coarseSpin == 2,
527 if (query_max)
return 3;
530 case 0: launch_compute_vuv_kernel<from_coarse, 32, 32, 32, 8, 8>(tp,
arg, min_threads,
stream);
break;
531 case 1: launch_compute_vuv_kernel<from_coarse, 32, 32, 32, 8, 16>(tp,
arg, min_threads,
stream);
break;
532 case 2: launch_compute_vuv_kernel<from_coarse, 32, 32, 32, 16, 8>(tp,
arg, min_threads,
stream);
break;
533 case 3: launch_compute_vuv_kernel<from_coarse, 32, 32, 32, 32, 4>(tp,
arg, min_threads,
stream);
break;
534 default:
errorQuda(
"tp.aux.x(=%d) is NOT supported by (%d, %d, %d, %d).", tp.
aux.x, Arg::fineSpin, Arg::coarseSpin, Arg::fineColor, Arg::coarseColor);
540 template <
bool from_coarse,
bool query_max = false,
class Arg>
541 typename std::enable_if<Arg::fineColor == 64 && Arg::coarseColor == 64 && Arg::fineSpin == 2 && Arg::coarseSpin == 2,
545 if (query_max)
return 7;
548 case 0: launch_compute_vuv_kernel<from_coarse, 64, 64, 64, 8, 8>(tp,
arg, min_threads,
stream);
break;
549 case 1: launch_compute_vuv_kernel<from_coarse, 64, 64, 64, 8, 16>(tp,
arg, min_threads,
stream);
break;
550 case 2: launch_compute_vuv_kernel<from_coarse, 64, 64, 64, 16, 8>(tp,
arg, min_threads,
stream);
break;
551 case 3: launch_compute_vuv_kernel<from_coarse, 64, 64, 64, 16, 16>(tp,
arg, min_threads,
stream);
break;
552 case 4: launch_compute_vuv_kernel<from_coarse, 64, 64, 64, 16, 32>(tp,
arg, min_threads,
stream);
break;
553 case 5: launch_compute_vuv_kernel<from_coarse, 64, 64, 64, 32, 4>(tp,
arg, min_threads,
stream);
break;
554 case 6: launch_compute_vuv_kernel<from_coarse, 64, 64, 64, 32, 8>(tp,
arg, min_threads,
stream);
break;
555 case 7: launch_compute_vuv_kernel<from_coarse, 64, 64, 64, 32, 16>(tp,
arg, min_threads,
stream);
break;
556 default:
errorQuda(
"tp.aux.x(=%d) is NOT supported by (%d, %d, %d, %d).", tp.
aux.x, Arg::fineSpin, Arg::coarseSpin, Arg::fineColor, Arg::coarseColor);
562 template <
bool from_coarse,
bool query_max = false,
class Arg>
563 typename std::enable_if<Arg::fineColor == 64 && Arg::coarseColor == 96 && Arg::fineSpin == 2 && Arg::coarseSpin == 2,
567 if (query_max)
return 6;
570 case 0: launch_compute_vuv_kernel<from_coarse, 96, 96, 64, 8, 8>(tp,
arg, min_threads,
stream);
break;
571 case 1: launch_compute_vuv_kernel<from_coarse, 96, 96, 64, 8, 12>(tp,
arg, min_threads,
stream);
break;
572 case 2: launch_compute_vuv_kernel<from_coarse, 96, 96, 64, 16, 6>(tp,
arg, min_threads,
stream);
break;
573 case 3: launch_compute_vuv_kernel<from_coarse, 96, 96, 64, 16, 8>(tp,
arg, min_threads,
stream);
break;
574 case 4: launch_compute_vuv_kernel<from_coarse, 96, 96, 64, 16, 12>(tp,
arg, min_threads,
stream);
break;
575 case 5: launch_compute_vuv_kernel<from_coarse, 96, 96, 64, 32, 4>(tp,
arg, min_threads,
stream);
break;
576 case 6: launch_compute_vuv_kernel<from_coarse, 96, 96, 64, 32, 6>(tp,
arg, min_threads,
stream);
break;
579 errorQuda(
"tp.aux.x(=%d) is NOT supported by (%d, %d, %d, %d).", tp.
aux.x, Arg::fineSpin, Arg::coarseSpin,
580 Arg::fineColor, Arg::coarseColor);
586 template <
bool from_coarse,
bool query_max = false,
class Arg>
587 typename std::enable_if<Arg::fineColor == 96 && Arg::coarseColor == 96 && Arg::fineSpin == 2 && Arg::coarseSpin == 2,
591 if (query_max)
return 6;
594 case 0: launch_compute_vuv_kernel<from_coarse, 96, 96, 96, 12, 8>(tp,
arg, min_threads,
stream);
break;
595 case 1: launch_compute_vuv_kernel<from_coarse, 96, 96, 96, 24, 8>(tp,
arg, min_threads,
stream);
break;
596 case 2: launch_compute_vuv_kernel<from_coarse, 96, 96, 96, 6, 16>(tp,
arg, min_threads,
stream);
break;
597 case 3: launch_compute_vuv_kernel<from_coarse, 96, 96, 96, 12, 16>(tp,
arg, min_threads,
stream);
break;
598 case 4: launch_compute_vuv_kernel<from_coarse, 96, 96, 96, 24, 16>(tp,
arg, min_threads,
stream);
break;
599 case 5: launch_compute_vuv_kernel<from_coarse, 96, 96, 96, 12, 24>(tp,
arg, min_threads,
stream);
break;
600 case 6: launch_compute_vuv_kernel<from_coarse, 96, 96, 96, 24, 24>(tp,
arg, min_threads,
stream);
break;
603 errorQuda(
"tp.aux.x(=%d) is NOT supported by (%d, %d, %d, %d).", tp.
aux.x, Arg::fineSpin, Arg::coarseSpin,
604 Arg::fineColor, Arg::coarseColor);
611 template <
bool from_coarse,
bool query_max = false,
class Arg>
614 errorQuda(
"MMA multigrid is not available for this setup.");
618 template <
bool from_coarse,
bool query_max = false,
class Arg>
621 errorQuda(
"MMA multigrid is not available for this setup.");
bool set_max_shared_bytes
int launch_compute_vuv_kernel(TuneParam &tp, const Arg &arg, int min_threads, const cudaStream_t &stream)
int launch_compute_uv_kernel(TuneParam &tp, const Arg &arg, int min_threads, const cudaStream_t &stream)
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
qudaError_t qudaLaunchKernel(const void *func, const TuneParam &tp, void **args, qudaStream_t stream)
Wrapper around cudaLaunchKernel.