QUDA  v1.1.0
A library for QCD on GPUs
instantiate.h
Go to the documentation of this file.
1 #pragma once
2 
3 #include <array>
4 #include <enum_quda.h>
5 #include <util_quda.h>
6 
7 namespace quda
8 {
9 
/**
 * Compile-time query: is reconstruct type `recon` enabled in this build?
 * The primary template answers true for every type; the explicit specializations
 * below switch individual types to false depending on the QUDA_RECONSTRUCT bitmask.
 */
10  template <QudaReconstructType recon> constexpr bool is_enabled() { return true; }
// bit value 4 not set in QUDA_RECONSTRUCT: QUDA_RECONSTRUCT_NO is reported as disabled
11 #if !(QUDA_RECONSTRUCT & 4)
12  template <> constexpr bool is_enabled<QUDA_RECONSTRUCT_NO>() { return false; }
13 #endif
// bit value 2 not set: QUDA_RECONSTRUCT_13 and QUDA_RECONSTRUCT_12 are reported as disabled
14 #if !(QUDA_RECONSTRUCT & 2)
15  template <> constexpr bool is_enabled<QUDA_RECONSTRUCT_13>() { return false; }
16  template <> constexpr bool is_enabled<QUDA_RECONSTRUCT_12>() { return false; }
17 #endif
// bit value 1 not set: QUDA_RECONSTRUCT_9 and QUDA_RECONSTRUCT_8 are reported as disabled
18 #if !(QUDA_RECONSTRUCT & 1)
19  template <> constexpr bool is_enabled<QUDA_RECONSTRUCT_9>() { return false; }
20  template <> constexpr bool is_enabled<QUDA_RECONSTRUCT_8>() { return false; }
21 #endif
22 
/**
 * Reconstruct list with five entries; it is the default `Recon` argument of the
 * precision-level instantiate() for reconstruct-aware Apply classes.
 * NOTE(review): this listing jumps from line 24 to line 26, so the initializer of
 * `recon` is not visible here — confirm the five values against the real header.
 */
23  struct ReconstructFull {
24  static constexpr std::array<QudaReconstructType, 5> recon
26  };
27 
// NOTE(review): the listing omits line 28 (presumably the `struct` header) and line 30
// (presumably the initializer), so neither the type name nor the three values of this
// reconstruct list can be read from here — confirm against the real header.
29  static constexpr std::array<QudaReconstructType, 3> recon
31  };
32 
// NOTE(review): the listing omits line 33 (presumably the `struct` header) and line 35
// (presumably the initializer), so neither the type name nor the three values of this
// reconstruct list can be read from here — confirm against the real header.
34  static constexpr std::array<QudaReconstructType, 3> recon
36  };
37 
38  struct ReconstructNo12 {
39  static constexpr std::array<QudaReconstructType, 2> recon = {QUDA_RECONSTRUCT_NO, QUDA_RECONSTRUCT_12};
40  };
41 
42  struct ReconstructNone {
43  static constexpr std::array<QudaReconstructType, 1> recon = {QUDA_RECONSTRUCT_NO};
44  };
45 
46  struct ReconstructMom {
47  static constexpr std::array<QudaReconstructType, 2> recon = {QUDA_RECONSTRUCT_NO, QUDA_RECONSTRUCT_10};
48  };
49 
50  struct Reconstruct10 {
51  static constexpr std::array<QudaReconstructType, 1> recon = {QUDA_RECONSTRUCT_10};
52  };
53 
/**
 * Primary template: constructing this object constructs Apply<Float, nColor, recon>
 * from the field U and the remaining arguments.
 * The leading `enabled` flag exists so that a separate partial specialization for
 * `false` can replace the instantiation with an error.
 * NOTE(review): the listing jumps from line 59 to line 61; the omitted line is presumably
 * the `struct instantiateApply {` header — confirm against the real header.
 */
58  template <bool enabled, template <typename, int, QudaReconstructType> class Apply, typename Float, int nColor,
59  QudaReconstructType recon, typename G, typename... Args>
// args are passed on as lvalues (no std::forward)
61  instantiateApply(G &U, Args &&... args) { Apply<Float, nColor, recon>(U, args...); }
62  };
63 
68  template <template <typename, int, QudaReconstructType> class Apply, typename Float, int nColor,
69  QudaReconstructType recon, typename G, typename... Args>
70  struct instantiateApply<false, Apply, Float, nColor, recon, G, Args...> {
71  instantiateApply(G &U, Args &&... args)
72  {
73  errorQuda("QUDA_RECONSTRUCT=%d does not enable %d", QUDA_RECONSTRUCT, recon);
74  }
75  };
76 
/**
 * Instantiate the reconstruction template at index i and recurse to the prior element.
 * The constructor compares U.Reconstruct() with Recon::recon[i]; on a match it constructs
 * instantiateApply for that type, otherwise it constructs instantiateReconstruct for i - 1.
 * NOTE(review): the listing jumps from line 82 to line 84; the omitted line is presumably
 * the `struct instantiateReconstruct {` header — confirm against the real header.
 */
81  template <template <typename, int, QudaReconstructType> class Apply, typename Float, int nColor, typename Recon,
82  int i, typename G, typename... Args>
84  instantiateReconstruct(G &U, Args &&... args)
85  {
86  if (U.Reconstruct() == Recon::recon[i]) {
// is_enabled<...>() selects between the instantiating and the error-reporting instantiateApply
87  instantiateApply<is_enabled<Recon::recon[i]>(), Apply, Float, nColor, Recon::recon[i], G, Args...>(U, args...);
88  } else {
// no match at this index: try the previous entry of Recon::recon
89  instantiateReconstruct<Apply, Float, nColor, Recon, i - 1, G, Args...>(U, args...);
90  }
91  }
92  };
93 
97  template <template <typename, int, QudaReconstructType> class Apply, typename Float, int nColor, typename Recon,
98  typename G, typename... Args>
99  struct instantiateReconstruct<Apply, Float, nColor, Recon, 0, G, Args...> {
100  instantiateReconstruct(G &U, Args &&... args)
101  {
102  if (U.Reconstruct() == Recon::recon[0]) {
103  instantiateApply<is_enabled<Recon::recon[0]>(), Apply, Float, nColor, Recon::recon[0], G, Args...>(U, args...);
104  } else {
105  errorQuda("Unsupported reconstruct type %d\n", U.Reconstruct());
106  }
107  }
108  };
109 
115  template <template <typename, int, QudaReconstructType> class Apply, typename Recon, typename Float, typename G,
116  typename... Args>
117  constexpr void instantiate(G &U, Args &&... args)
118  {
119  if (U.Ncolor() == 3) {
120  constexpr int i = Recon::recon.size() - 1;
121  instantiateReconstruct<Apply, Float, 3, Recon, i, G, Args...>(U, args...);
122  } else {
123  errorQuda("Unsupported number of colors %d\n", U.Ncolor());
124  }
125  }
126 
132  template <template <typename, int, QudaReconstructType> class Apply, typename Recon = ReconstructFull, typename G,
133  typename... Args>
134  constexpr void instantiate(G &U, Args &&... args)
135  {
136  if (U.Precision() == QUDA_DOUBLE_PRECISION) {
137 #if QUDA_PRECISION & 8
138  instantiate<Apply, Recon, double>(U, args...);
139 #else
140  errorQuda("QUDA_PRECISION=%d does not enable double precision", QUDA_PRECISION);
141 #endif
142  } else if (U.Precision() == QUDA_SINGLE_PRECISION) {
143 #if QUDA_PRECISION & 4
144  instantiate<Apply, Recon, float>(U, args...);
145 #else
146  errorQuda("QUDA_PRECISION=%d does not enable single precision", QUDA_PRECISION);
147 #endif
148  } else {
149  errorQuda("Unsupported precision %d\n", U.Precision());
150  }
151  }
152 
// When __CUDA_ARCH__ is defined and __CUDACC_VER_MAJOR__ <= 9, turn the `constexpr`
// keyword into an empty macro so that the definitions which follow are parsed without it.
// NOTE(review): presumably a workaround for a toolchain limitation — the reason is not stated here.
153 #if defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ <= 9
154 #define constexpr
155 #endif
156 
162  template <template <typename> class Apply, typename C, typename... Args>
163  constexpr void instantiate(C &c, Args &&... args)
164  {
165  if (c.Precision() == QUDA_DOUBLE_PRECISION) {
166 #if QUDA_PRECISION & 8
167  Apply<double>(c, args...);
168 #else
169  errorQuda("QUDA_PRECISION=%d does not enable double precision", QUDA_PRECISION);
170 #endif
171  } else if (c.Precision() == QUDA_SINGLE_PRECISION) {
172 #if QUDA_PRECISION & 4
173  Apply<float>(c, args...);
174 #else
175  errorQuda("QUDA_PRECISION=%d does not enable single precision", QUDA_PRECISION);
176 #endif
177  } else if (c.Precision() == QUDA_HALF_PRECISION) {
178 #if QUDA_PRECISION & 2
179  Apply<short>(c, args...);
180 #else
181  errorQuda("QUDA_PRECISION=%d does not enable half precision", QUDA_PRECISION);
182 #endif
183  } else if (c.Precision() == QUDA_QUARTER_PRECISION) {
184 #if QUDA_PRECISION & 1
185  Apply<int8_t>(c, args...);
186 #else
187  errorQuda("QUDA_PRECISION=%d does not enable quarter precision", QUDA_PRECISION);
188 #endif
189  } else {
190  errorQuda("Unsupported precision %d\n", c.Precision());
191  }
192  }
193 
/**
 * Color dispatch for Apply classes templated on storage type and color count.
 * Only field.Ncolor() == 3 is handled, constructing Apply<store_t, 3>; any other
 * color count is reported through errorQuda.
 *
 * @param field field whose Ncolor() selects the instantiation
 * @param args  remaining arguments, passed on unchanged
 */
template <template <typename, int> class Apply, typename store_t, typename F, typename... Args>
constexpr void instantiate(F &field, Args &&... args)
{
  if (field.Ncolor() == 3) {
    Apply<store_t, 3>(field, args...);
    return;
  }

  errorQuda("Unsupported number of colors %d\n", field.Ncolor());
}
208 
// Under the same condition as the earlier empty `constexpr` macro, remove that macro again.
// `constexpr` is then redefined as a macro that expands to itself, so the code that follows
// sees the ordinary keyword.
209 #if defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ <= 9
210 #undef constexpr
211 #define constexpr constexpr
212 #endif
213 
220  template <template <typename, int> class Apply, typename F, typename... Args>
221  constexpr void instantiate(F &field, Args &&... args)
222  {
223  if (field.Precision() == QUDA_DOUBLE_PRECISION) {
224 #if QUDA_PRECISION & 8
225  instantiate<Apply, double>(field, args...);
226 #else
227  errorQuda("QUDA_PRECISION=%d does not enable double precision", QUDA_PRECISION);
228 #endif
229  } else if (field.Precision() == QUDA_SINGLE_PRECISION) {
230 #if QUDA_PRECISION & 4
231  instantiate<Apply, float>(field, args...);
232 #else
233  errorQuda("QUDA_PRECISION=%d does not enable single precision", QUDA_PRECISION);
234 #endif
235  } else if (field.Precision() == QUDA_HALF_PRECISION) {
236 #if QUDA_PRECISION & 2
237  instantiate<Apply, short>(field, args...);
238 #else
239  errorQuda("QUDA_PRECISION=%d does not enable half precision", QUDA_PRECISION);
240 #endif
241  } else if (field.Precision() == QUDA_QUARTER_PRECISION) {
242 #if QUDA_PRECISION & 1
243  instantiate<Apply, int8_t>(field, args...);
244 #else
245  errorQuda("QUDA_PRECISION=%d does not enable quarter precision", QUDA_PRECISION);
246 #endif
247  } else {
248  errorQuda("Unsupported precision %d\n", field.Precision());
249  }
250  }
251 
263  template <template <typename> class Apply, typename F, typename... Args>
264  constexpr void instantiatePrecision(F &field, Args &&... args)
265  {
266  if (field.Precision() == QUDA_DOUBLE_PRECISION) {
267  // always instantiate double precision
268  Apply<double>(field, args...);
269  } else if (field.Precision() == QUDA_SINGLE_PRECISION) {
270 #if QUDA_PRECISION & 4
271  Apply<float>(field, args...);
272 #else
273  errorQuda("QUDA_PRECISION=%d does not enable single precision", QUDA_PRECISION);
274 #endif
275  } else if (field.Precision() == QUDA_HALF_PRECISION) {
276 #if QUDA_PRECISION & 2
277  Apply<short>(field, args...);
278 #else
279  errorQuda("QUDA_PRECISION=%d does not enable half precision", QUDA_PRECISION);
280 #endif
281  } else if (field.Precision() == QUDA_QUARTER_PRECISION) {
282 #if QUDA_PRECISION & 1
283  Apply<int8_t>(field, args...);
284 #else
285  errorQuda("QUDA_PRECISION=%d does not enable quarter precision", QUDA_PRECISION);
286 #endif
287  } else {
288  errorQuda("Unsupported precision %d\n", field.Precision());
289  }
290  }
291 
308  template <template <typename, typename> class Apply, typename T, typename F, typename... Args>
309  constexpr void instantiatePrecision2(F &field, Args &&... args)
310  {
311  if (field.Precision() == QUDA_DOUBLE_PRECISION) {
312  // always instantiate double precision
313  Apply<double, T>(field, args...);
314  } else if (field.Precision() == QUDA_SINGLE_PRECISION) {
315 #if QUDA_PRECISION & 4
316  Apply<float, T>(field, args...);
317 #else
318  errorQuda("QUDA_PRECISION=%d does not enable single precision", QUDA_PRECISION);
319 #endif
320  } else if (field.Precision() == QUDA_HALF_PRECISION) {
321 #if QUDA_PRECISION & 2
322  Apply<short, T>(field, args...);
323 #else
324  errorQuda("QUDA_PRECISION=%d does not enable half precision", QUDA_PRECISION);
325 #endif
326  } else if (field.Precision() == QUDA_QUARTER_PRECISION) {
327 #if QUDA_PRECISION & 1
328  Apply<int8_t, T>(field, args...);
329 #else
330  errorQuda("QUDA_PRECISION=%d does not enable quarter precision", QUDA_PRECISION);
331 #endif
332  } else {
333  errorQuda("Unsupported precision %d\n", field.Precision());
334  }
335  }
336 
344  template <template <typename> class Apply, typename F, typename... Args>
345  constexpr void instantiatePrecisionMG(F &field, Args &&... args)
346  {
347  if (field.Precision() == QUDA_DOUBLE_PRECISION) {
348 #ifdef GPU_MULTIGRID_DOUBLE
349  Apply<double>(field, args...);
350 #else
351  errorQuda("Multigrid not support in double precision");
352 #endif
353  } else if (field.Precision() == QUDA_SINGLE_PRECISION) {
354 #if QUDA_PRECISION & 4
355  Apply<float>(field, args...);
356 #else
357  errorQuda("QUDA_PRECISION=%d does not enable single precision", QUDA_PRECISION);
358 #endif
359  } else if (field.Precision() == QUDA_HALF_PRECISION) {
360 #if QUDA_PRECISION & 2
361  Apply<short>(field, args...);
362 #else
363  errorQuda("QUDA_PRECISION=%d does not enable half precision", QUDA_PRECISION);
364 #endif
365  } else if (field.Precision() == QUDA_QUARTER_PRECISION) {
366 #if QUDA_PRECISION & 1
367  Apply<int8_t>(field, args...);
368 #else
369  errorQuda("QUDA_PRECISION=%d does not enable quarter precision", QUDA_PRECISION);
370 #endif
371  } else {
372  errorQuda("Unsupported precision %d\n", field.Precision());
373  }
374  }
375 
376  // these are used in dslash.h
377 
// NOTE(review): the listing omits line 378 (presumably the `struct` header) and line 380
// (presumably the initializer), so neither the type name nor the three values of this
// reconstruct list can be read from here — confirm against the real header.
379  static constexpr std::array<QudaReconstructType, 3> recon
381  };
382 
// NOTE(review): the listing omits line 383 (presumably the `struct` header) and line 385
// (presumably the initializer), so neither the type name nor the three values of this
// reconstruct list can be read from here — confirm against the real header.
384  static constexpr std::array<QudaReconstructType, 3> recon
386  };
387 
388 } // namespace quda
const int nColor
Definition: covdev_test.cpp:44
@ QUDA_RECONSTRUCT_NO
Definition: enum_quda.h:70
@ QUDA_RECONSTRUCT_12
Definition: enum_quda.h:71
@ QUDA_RECONSTRUCT_13
Definition: enum_quda.h:74
@ QUDA_RECONSTRUCT_8
Definition: enum_quda.h:72
@ QUDA_RECONSTRUCT_10
Definition: enum_quda.h:75
@ QUDA_RECONSTRUCT_9
Definition: enum_quda.h:73
enum QudaReconstructType_s QudaReconstructType
@ QUDA_DOUBLE_PRECISION
Definition: enum_quda.h:65
@ QUDA_SINGLE_PRECISION
Definition: enum_quda.h:64
@ QUDA_QUARTER_PRECISION
Definition: enum_quda.h:62
@ QUDA_HALF_PRECISION
Definition: enum_quda.h:63
constexpr void instantiate(G &U, Args &&... args)
This instantiate function is used to instantiate the colors.
Definition: instantiate.h:117
constexpr bool is_enabled()
Definition: instantiate.h:10
constexpr bool is_enabled< QUDA_RECONSTRUCT_13 >()
Definition: instantiate.h:15
constexpr bool is_enabled< QUDA_RECONSTRUCT_12 >()
Definition: instantiate.h:16
constexpr void instantiatePrecisionMG(F &field, Args &&... args)
The instantiatePrecision function is used to instantiate the precision.
Definition: instantiate.h:345
constexpr bool is_enabled< QUDA_RECONSTRUCT_NO >()
Definition: instantiate.h:12
constexpr bool is_enabled< QUDA_RECONSTRUCT_9 >()
Definition: instantiate.h:19
constexpr void instantiatePrecision2(F &field, Args &&... args)
The instantiatePrecision2 function is used to instantiate the precision for a class that accepts 2 ty...
Definition: instantiate.h:309
constexpr void instantiatePrecision(F &field, Args &&... args)
The instantiatePrecision function is used to instantiate the precision. Note unlike the "instantiate"...
Definition: instantiate.h:264
constexpr bool is_enabled< QUDA_RECONSTRUCT_8 >()
Definition: instantiate.h:20
FloatingPoint< float > Float
static constexpr std::array< QudaReconstructType, 1 > recon
Definition: instantiate.h:51
static constexpr std::array< QudaReconstructType, 5 > recon
Definition: instantiate.h:25
static constexpr std::array< QudaReconstructType, 2 > recon
Definition: instantiate.h:47
static constexpr std::array< QudaReconstructType, 2 > recon
Definition: instantiate.h:39
static constexpr std::array< QudaReconstructType, 1 > recon
Definition: instantiate.h:43
static constexpr std::array< QudaReconstructType, 3 > recon
Definition: instantiate.h:35
static constexpr std::array< QudaReconstructType, 3 > recon
Definition: instantiate.h:30
static constexpr std::array< QudaReconstructType, 3 > recon
Definition: instantiate.h:385
static constexpr std::array< QudaReconstructType, 3 > recon
Definition: instantiate.h:380
This class instantiates the Apply class based on the instantiated templates below.
Definition: instantiate.h:60
instantiateApply(G &U, Args &&... args)
Definition: instantiate.h:61
Instantiate the reconstruction template at index i and recurse to prior element.
Definition: instantiate.h:83
instantiateReconstruct(G &U, Args &&... args)
Definition: instantiate.h:84
#define errorQuda(...)
Definition: util_quda.h:120