#if (CUDA_VERSION >= 10010 && __COMPUTE_CAPABILITY__ >= 700)
#include <kernels/coarse_op_preconditioned_mma.cuh>
#endif
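
// The guarded include above supplies the MMA kernel machinery (presumably
// mma::CalculateYhatGPU and the shared_memory_bytes() helper used below);
// the tensor-core path requires CUDA >= 10.1 and compute capability >= 7.0.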
#if (CUDA_VERSION >= 10010 && __COMPUTE_CAPABILITY__ >= 700)
template <bool compute_max_only, int bM, int bN, int bK, int block_y, int block_z, int min_block_cta = 1, class Arg>
typename std::enable_if<!Arg::is_mma_compatible, void>::type launch_kernel(Arg &arg, int min_threads, TuneParam &tp,
                                                                           const cudaStream_t &stream)
{
  errorQuda("MMA implementation is ONLY built for AoS order.");
}
template <bool compute_max_only, int bM, int bN, int bK, int block_y, int block_z, int min_block_cta = 1, class Arg>
typename std::enable_if<Arg::is_mma_compatible, void>::type launch_kernel(Arg &arg, int min_threads, TuneParam &tp,
                                                                          const cudaStream_t &stream)
{
  constexpr int shared_bytes = shared_memory_bytes(bM, bN, bK);
  tp.shared_bytes = shared_bytes;

  // If the tile already spans the full K and N extents, no extra blocking of
  // the output over M and N is needed.
  constexpr bool divide_b_no = bM < Arg::M && bK == Arg::K && bN == Arg::N;
  constexpr int t_m = divide_b_no ? 1 : (Arg::M + bM - 1) / bM;
  constexpr int t_n = divide_b_no ? 1 : (Arg::N + bN - 1) / bN;

  tp.grid = dim3(min_threads * t_m * t_n, 2, 4);

  auto kernel = mma::CalculateYhatGPU<compute_max_only, Arg, bM, bN, bK, block_y, block_z, min_block_cta>;
  tp.set_max_shared_bytes = true;
  // Launch through the QUDA wrapper around cudaLaunchKernel.
  void *args[] = {&arg};
  qudaLaunchKernel((const void *)kernel, tp, args, stream);
}
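
// Each launch_yhat_kernel overload below is enabled for a single coarse
// dimension Arg::N. It dispatches on tp.aux.x to one entry of a fixed menu of
// tile shapes, instantiating launch_kernel as
// <compute_max_only, bM, bN, bK, block_y, block_z[, min_block_cta]>.
// With query_max == true an overload only reports the largest valid tp.aux.x,
// which lets the autotuner enumerate the candidate configurations.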
template <bool compute_max_only, bool query_max = false, class Arg>
typename std::enable_if<Arg::N == 48, int>::type launch_yhat_kernel(Arg &arg, int min_threads, TuneParam &tp,
                                                                    const cudaStream_t &stream)
{
  if (query_max) return 2;
  switch (tp.aux.x) {
  case 0: launch_kernel<compute_max_only, 48, 48, 48, 24, 12>(arg, min_threads, tp, stream); break;
  case 1: launch_kernel<compute_max_only, 48, 48, 48, 6, 48>(arg, min_threads, tp, stream); break;
  case 2: launch_kernel<compute_max_only, 48, 48, 48, 12, 24>(arg, min_threads, tp, stream); break;
  default: errorQuda("tp.aux.x(=%d) is NOT supported by N = 48", tp.aux.x);
  }
  return -1; // only the query_max return value is meaningful
}
template <bool compute_max_only, bool query_max = false, class Arg>
typename std::enable_if<Arg::N == 12, int>::type launch_yhat_kernel(Arg &arg, int min_threads, TuneParam &tp,
                                                                    const cudaStream_t &stream)
{
  if (query_max) return 1;
  switch (tp.aux.x) {
  case 0: launch_kernel<compute_max_only, 16, 16, 16, 4, 8>(arg, min_threads, tp, stream); break;
  case 1: launch_kernel<compute_max_only, 16, 16, 16, 8, 4>(arg, min_threads, tp, stream); break;
  default: errorQuda("tp.aux.x(=%d) is NOT supported by N = 12", tp.aux.x);
  }
  return -1;
}
template <bool compute_max_only, bool query_max = false, class Arg>
typename std::enable_if<Arg::N == 64, int>::type launch_yhat_kernel(Arg &arg, int min_threads, TuneParam &tp,
                                                                    const cudaStream_t &stream)
{
  if (query_max) return 6;
  switch (tp.aux.x) {
  case 0: launch_kernel<compute_max_only, 64, 64, 16, 32, 8>(arg, min_threads, tp, stream); break;
  case 1: launch_kernel<compute_max_only, 64, 64, 16, 16, 16>(arg, min_threads, tp, stream); break;
  case 2: launch_kernel<compute_max_only, 64, 64, 16, 32, 16>(arg, min_threads, tp, stream); break;
  case 3: launch_kernel<compute_max_only, 64, 64, 32, 32, 16>(arg, min_threads, tp, stream); break;
  case 4: launch_kernel<compute_max_only, 64, 64, 64, 8, 64>(arg, min_threads, tp, stream); break;
  case 5: launch_kernel<compute_max_only, 64, 64, 64, 16, 32>(arg, min_threads, tp, stream); break;
  case 6: launch_kernel<compute_max_only, 64, 64, 64, 32, 16>(arg, min_threads, tp, stream); break;
  default: errorQuda("tp.aux.x(=%d) is NOT supported by N = 64", tp.aux.x);
  }
  return -1;
}
template <bool compute_max_only, bool query_max = false, class Arg>
typename std::enable_if<Arg::N == 128, int>::type launch_yhat_kernel(Arg &arg, int min_threads, TuneParam &tp,
                                                                     const cudaStream_t &stream)
{
  if (query_max) return 7;
  switch (tp.aux.x) {
  case 0: launch_kernel<compute_max_only, 64, 64, 16, 32, 16, 2>(arg, min_threads, tp, stream); break;
#if (__COMPUTE_CAPABILITY__ >= 750)
  // Tile shape differs by compute capability (the shared-memory budget differs
  // across architectures).
  case 1: launch_kernel<compute_max_only, 16, 128, 128, 32, 16, 2>(arg, min_threads, tp, stream); break;
#else
  case 1: launch_kernel<compute_max_only, 32, 128, 128, 32, 16, 2>(arg, min_threads, tp, stream); break;
#endif
  case 2: launch_kernel<compute_max_only, 128, 128, 16, 64, 8>(arg, min_threads, tp, stream); break;
  case 3: launch_kernel<compute_max_only, 128, 128, 16, 32, 16>(arg, min_threads, tp, stream); break;
  case 4: launch_kernel<compute_max_only, 128, 128, 32, 16, 32>(arg, min_threads, tp, stream); break;
  case 5: launch_kernel<compute_max_only, 128, 128, 32, 64, 8>(arg, min_threads, tp, stream); break;
  case 6: launch_kernel<compute_max_only, 128, 128, 32, 32, 16>(arg, min_threads, tp, stream); break;
  case 7: launch_kernel<compute_max_only, 128, 128, 32, 32, 32>(arg, min_threads, tp, stream); break;
  default: errorQuda("tp.aux.x(=%d) is NOT supported by N = 128", tp.aux.x);
  }
  return -1;
}
template <bool compute_max_only, bool query_max = false, class Arg>
typename std::enable_if<Arg::N == 192, int>::type launch_yhat_kernel(Arg &arg, int min_threads, TuneParam &tp,
                                                                     const cudaStream_t &stream)
{
  if (query_max) return 4;
  switch (tp.aux.x) {
  case 0: launch_kernel<compute_max_only, 64, 64, 16, 16, 16, 2>(arg, min_threads, tp, stream); break;
  case 1: launch_kernel<compute_max_only, 64, 64, 64, 16, 16, 2>(arg, min_threads, tp, stream); break;
  case 2: launch_kernel<compute_max_only, 16, 192, 192, 24, 16>(arg, min_threads, tp, stream); break;
  case 3: launch_kernel<compute_max_only, 64, 64, 32, 16, 16, 2>(arg, min_threads, tp, stream); break;
#if (__COMPUTE_CAPABILITY__ >= 750)
  case 4: launch_kernel<compute_max_only, 16, 192, 192, 96, 8>(arg, min_threads, tp, stream); break;
#else
  case 4: launch_kernel<compute_max_only, 16, 192, 192, 48, 8>(arg, min_threads, tp, stream); break;
#endif
  default: errorQuda("tp.aux.x(=%d) is NOT supported by N = 192", tp.aux.x);
  }
  return -1;
}
#else

template <bool compute_max_only, bool query_max = false, class Arg>
int launch_yhat_kernel(Arg &arg, int min_threads, TuneParam &tp, const cudaStream_t &stream)
{
  errorQuda("MMA multigrid is not available for this setup.");
  return -1;
}

#endif