QUDA v1.1.0
A library for QCD on GPUs
coarse_op_preconditioned_mma_launch.h
#pragma once

#include <gauge_field.h>
#include <tune_quda.h>

#if (CUDA_VERSION >= 10010 && __COMPUTE_CAPABILITY__ >= 700)

#include <kernels/coarse_op_preconditioned_mma.cuh>

#endif

namespace quda
{

  namespace mma
  {

#if (CUDA_VERSION >= 10010 && __COMPUTE_CAPABILITY__ >= 700)

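    // The templates below dispatch the tensor-core (MMA) kernels (CalculateYhatGPU) that
    // assemble the preconditioned coarse links Yhat.  The tile sizes (bM, bN, bK) and the
    // thread-block shape are fixed at compile time; at run time tp.aux.x selects one of the
    // pre-instantiated configurations, and calling with query_max = true simply reports the
    // largest valid tp.aux.x so the autotuner knows how many candidates exist.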
    template <bool compute_max_only, int bM, int bN, int bK, int block_y, int block_z, int min_block_cta = 1, class Arg>
    typename std::enable_if<!Arg::is_mma_compatible, void>::type launch_kernel(Arg &arg, int min_threads, TuneParam &tp,
                                                                               const cudaStream_t &stream)
    {
      errorQuda("MMA implementation is ONLY built for AoS order.");
    }

    template <bool compute_max_only, int bM, int bN, int bK, int block_y, int block_z, int min_block_cta = 1, class Arg>
    typename std::enable_if<Arg::is_mma_compatible, void>::type launch_kernel(Arg &arg, int min_threads, TuneParam &tp,
                                                                              const cudaStream_t &stream)
    {
      tp.block.x = 1;
      tp.block.y = block_y;
      tp.block.z = block_z;
      constexpr int shared_bytes = shared_memory_bytes(bM, bN, bK);
      tp.shared_bytes = shared_bytes;

      constexpr bool divide_b_no = bM < Arg::M && bK == Arg::K && bN == Arg::N;

      constexpr int t_m = divide_b_no ? 1 : (Arg::M + bM - 1) / bM;
      constexpr int t_n = divide_b_no ? 1 : (Arg::N + bN - 1) / bN;

      tp.grid = dim3(min_threads * t_m * t_n, 2, 4);

      auto kernel = mma::CalculateYhatGPU<compute_max_only, Arg, bM, bN, bK, block_y, block_z, min_block_cta>;
      tp.set_max_shared_bytes = true;
      qudaLaunchKernel(kernel, tp, stream, arg);
    }
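    // Worked example of the resulting geometry, assuming Arg::M == Arg::N == Arg::K == 128
    // (a square coarse-link matrix): launch_kernel<..., 64, 64, 16, 32, 16, 2> gives
    // block = (1, 32, 16), i.e. 512 threads, requests shared_memory_bytes(64, 64, 16) of
    // dynamic shared memory, and, because divide_b_no evaluates to false, tiles the output
    // into t_m * t_n = 2 * 2 pieces, so grid = (min_threads * 4, 2, 4).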

    template <bool compute_max_only, bool query_max = false, class Arg>
    typename std::enable_if<Arg::N == 48, int>::type launch_yhat_kernel(Arg &arg, int min_threads, TuneParam &tp,
                                                                        const cudaStream_t &stream)
    {
      if (query_max) return 2;
      // clang-format off
      switch (tp.aux.x) {
      case 0: launch_kernel<compute_max_only, 48, 48, 48, 24, 12>(arg, min_threads, tp, stream); break;
      case 1: launch_kernel<compute_max_only, 48, 48, 48, 6, 48>(arg, min_threads, tp, stream); break;
      case 2: launch_kernel<compute_max_only, 48, 48, 48, 12, 24>(arg, min_threads, tp, stream); break;
      default: errorQuda("tp.aux.x(=%d) is NOT supported by N = 48", tp.aux.x);
      }
      // clang-format on
      return -1;
    }

    template <bool compute_max_only, bool query_max = false, class Arg>
    typename std::enable_if<Arg::N == 12, int>::type launch_yhat_kernel(Arg &arg, int min_threads, TuneParam &tp,
                                                                        const cudaStream_t &stream)
    {
      if (query_max) return 1;
      // clang-format off
      switch (tp.aux.x) {
      case 0: launch_kernel<compute_max_only, 16, 16, 16, 4, 8>(arg, min_threads, tp, stream); break;
      case 1: launch_kernel<compute_max_only, 16, 16, 16, 8, 4>(arg, min_threads, tp, stream); break;
      default: errorQuda("tp.aux.x(=%d) is NOT supported by N = 12", tp.aux.x);
      }
      // clang-format on
      return -1;
    }
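    // For N = 12 the matrices are smaller than a single 16 x 16 x 16 tensor-core tile, so
    // the tiles are presumably zero-padded up to 16 in each dimension and only the
    // thread-block shape (4 x 8 vs. 8 x 4) is left to the tuner.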

    template <bool compute_max_only, bool query_max = false, class Arg>
    typename std::enable_if<Arg::N == 64, int>::type launch_yhat_kernel(Arg &arg, int min_threads, TuneParam &tp,
                                                                        const cudaStream_t &stream)
    {
      if (query_max) return 6;
      // clang-format off
      switch (tp.aux.x) {
      case 0: launch_kernel<compute_max_only, 64, 64, 16, 32, 8>(arg, min_threads, tp, stream); break;
      case 1: launch_kernel<compute_max_only, 64, 64, 16, 16, 16>(arg, min_threads, tp, stream); break;
      case 2: launch_kernel<compute_max_only, 64, 64, 16, 32, 16>(arg, min_threads, tp, stream); break;
      case 3: launch_kernel<compute_max_only, 64, 64, 32, 32, 16>(arg, min_threads, tp, stream); break;
      case 4: launch_kernel<compute_max_only, 64, 64, 64, 8, 64>(arg, min_threads, tp, stream); break;
      case 5: launch_kernel<compute_max_only, 64, 64, 64, 16, 32>(arg, min_threads, tp, stream); break;
      case 6: launch_kernel<compute_max_only, 64, 64, 64, 32, 16>(arg, min_threads, tp, stream); break;
      default: errorQuda("tp.aux.x(=%d) is NOT supported by N = 64", tp.aux.x);
      }
      // clang-format on
      return -1;
    }

    template <bool compute_max_only, bool query_max = false, class Arg>
    typename std::enable_if<Arg::N == 128, int>::type launch_yhat_kernel(Arg &arg, int min_threads, TuneParam &tp,
                                                                         const cudaStream_t &stream)
    {
      if (query_max) return 7;
      // clang-format off
      switch (tp.aux.x) {
      case 0: launch_kernel<compute_max_only, 64, 64, 16, 32, 16, 2>(arg, min_threads, tp, stream); break;
#if (__COMPUTE_CAPABILITY__ >= 750) // Turing or above
      case 1: launch_kernel<compute_max_only, 16, 128, 128, 32, 16, 2>(arg, min_threads, tp, stream); break;
#else
      case 1: launch_kernel<compute_max_only, 32, 128, 128, 32, 16, 2>(arg, min_threads, tp, stream); break;
#endif
      case 2: launch_kernel<compute_max_only, 128, 128, 16, 64, 8 >(arg, min_threads, tp, stream); break;
      case 3: launch_kernel<compute_max_only, 128, 128, 16, 32, 16 >(arg, min_threads, tp, stream); break;
      case 4: launch_kernel<compute_max_only, 128, 128, 32, 16, 32 >(arg, min_threads, tp, stream); break;
      case 5: launch_kernel<compute_max_only, 128, 128, 32, 64, 8 >(arg, min_threads, tp, stream); break;
      case 6: launch_kernel<compute_max_only, 128, 128, 32, 32, 16 >(arg, min_threads, tp, stream); break;
      case 7: launch_kernel<compute_max_only, 128, 128, 32, 32, 32 >(arg, min_threads, tp, stream); break;
      default: errorQuda("tp.aux.x(=%d) is NOT supported by N = 128", tp.aux.x);
      }
      // clang-format on
      return -1;
    }

    template <bool compute_max_only, bool query_max = false, class Arg>
    typename std::enable_if<Arg::N == 192, int>::type launch_yhat_kernel(Arg &arg, int min_threads, TuneParam &tp,
                                                                         const cudaStream_t &stream)
    {
      if (query_max) return 4;
      // clang-format off
      switch (tp.aux.x) {
      case 0: launch_kernel<compute_max_only, 64, 64, 16, 16, 16, 2>(arg, min_threads, tp, stream); break;
      case 1: launch_kernel<compute_max_only, 64, 64, 64, 16, 16, 2>(arg, min_threads, tp, stream); break;
      case 2: launch_kernel<compute_max_only, 16, 192, 192, 24, 16 >(arg, min_threads, tp, stream); break;
      case 3: launch_kernel<compute_max_only, 64, 64, 32, 16, 16, 2>(arg, min_threads, tp, stream); break;
#if (__COMPUTE_CAPABILITY__ >= 750) // Turing or above
      case 4: launch_kernel<compute_max_only, 16, 192, 192, 96, 8 >(arg, min_threads, tp, stream); break;
#else
      case 4: launch_kernel<compute_max_only, 16, 192, 192, 48, 8 >(arg, min_threads, tp, stream); break;
#endif
      default: errorQuda("tp.aux.x(=%d) is NOT supported by N = 192", tp.aux.x);
      }
      // clang-format on
      return -1;
    }
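    // The __COMPUTE_CAPABILITY__ branches in the N = 128 and N = 192 cases pick different
    // tile/block shapes for Turing-and-newer versus Volta, presumably to stay within each
    // architecture's shared-memory and occupancy limits, while keeping the same tp.aux.x
    // numbering on both.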

#else

    template <bool compute_max_only, bool query_max = false, class Arg>
    int launch_yhat_kernel(Arg &arg, int min_threads, TuneParam &tp, const cudaStream_t &stream)
    {
      errorQuda("MMA multigrid is not available for this setup.");
      return -1;
    }

#endif // compute capability >= 700, CUDA >= 10.1

  } // namespace mma

} // namespace quda
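
How the header is consumed is not shown on this page; the sketch below (hypothetical helper name, not part of QUDA) illustrates the intended pattern: calling launch_yhat_kernel with query_max = true returns the largest valid tp.aux.x for the given Arg::N, which bounds the aux tuning dimension, after which the chosen configuration is launched with the default query_max = false.

// Hypothetical sketch only: advance_aux_sketch is illustrative and not part of QUDA.
// It assumes the usual quda::TuneParam and the same Arg type the launchers above expect.
template <bool compute_max_only, class Arg>
bool advance_aux_sketch(Arg &arg, quda::TuneParam &tp, const cudaStream_t &stream)
{
  // query_max = true: nothing is launched; the return value is the largest tp.aux.x
  // accepted by the switch for this Arg::N (e.g. 2 for N = 48, 7 for N = 128).
  int aux_max = quda::mma::launch_yhat_kernel<compute_max_only, true>(arg, 0, tp, stream);
  if (tp.aux.x < aux_max) {
    tp.aux.x++; // advance to the next candidate tile/block configuration
    return true;
  }
  tp.aux.x = 0; // exhausted: reset so an outer tuning dimension can advance
  return false;
}

// Once tuning has settled on tp.aux.x, the selected configuration is dispatched with the
// default query_max = false:
//   quda::mma::launch_yhat_kernel<compute_max_only>(arg, min_threads, tp, stream);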