QUDA  v1.1.0
A library for QCD on GPUs
restrictor.cu
#include <color_spinor_field.h>
#include <tune_quda.h>
#include <launch_kernel.cuh>

#include <jitify_helper.cuh>
#include <kernels/restrictor.cuh>

namespace quda {

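  // Restriction is the fine-to-coarse transfer of the multigrid
  // hierarchy: in outline, the coarse field "out" is formed by
  // projecting the fine field "in" onto the packed near-null basis "v",
  // with the fine_to_coarse / coarse_to_fine tables giving the site
  // mapping between the two grids.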
  template <typename Float, typename vFloat, int fineSpin, int fineColor, int coarseSpin, int coarseColor,
            int coarse_colors_per_thread>
  class RestrictLaunch : public Tunable {

  protected:
    ColorSpinorField &out;
    const ColorSpinorField &in;
    const ColorSpinorField &v;
    const int *fine_to_coarse;
    const int *coarse_to_fine;
    const int parity;
    const QudaFieldLocation location;
    const int block_size;
    char vol[TuneKey::volume_n];

    unsigned int sharedBytesPerThread() const { return 0; }
    unsigned int sharedBytesPerBlock(const TuneParam &param) const { return 0; }
    bool tuneGridDim() const { return false; } // don't tune the grid dimensions
    bool tuneAuxDim() const { return true; }   // do tune the aux dimensions
    unsigned int minThreads() const { return in.VolumeCB(); } // fine parity is the block y dimension

  public:
    RestrictLaunch(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &v,
                   const int *fine_to_coarse, const int *coarse_to_fine, int parity)
      : out(out), in(in), v(v), fine_to_coarse(fine_to_coarse), coarse_to_fine(coarse_to_fine),
        parity(parity), location(checkLocation(out,in,v)), block_size(in.VolumeCB()/(2*out.VolumeCB()))
    {
      if (v.Location() == QUDA_CUDA_FIELD_LOCATION) {
#ifdef JITIFY
        create_jitify_program("kernels/restrictor.cuh");
#endif
      }
      strcpy(aux, compile_type_str(in));
      strcat(aux, out.AuxString());
      strcat(aux, ",");
      strcat(aux, in.AuxString());

      strcpy(vol, out.VolString());
      strcat(vol, ",");
      strcat(vol, in.VolString());
    } // block_size = checkerboarded fine-grid volume / full coarse-grid volume
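    // A hypothetical example of the block_size computation: restricting a
    // 16^4 fine lattice to a 4^4 coarse lattice gives
    // in.VolumeCB() = 16^4/2 = 32768 and out.VolumeCB() = 4^4/2 = 128, so
    // block_size = 32768 / (2*128) = 128, i.e. each 4^4 aggregate of 256
    // fine sites contributes 128 sites per parity to one coarse site.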

    void apply(const qudaStream_t &stream) {
      if (location == QUDA_CPU_FIELD_LOCATION) {
        if (out.FieldOrder() == QUDA_SPACE_SPIN_COLOR_FIELD_ORDER) {
          RestrictArg<Float,vFloat,fineSpin,fineColor,coarseSpin,coarseColor,QUDA_SPACE_SPIN_COLOR_FIELD_ORDER>
            arg(out, in, v, fine_to_coarse, coarse_to_fine, parity);
          Restrict<Float,fineSpin,fineColor,coarseSpin,coarseColor,coarse_colors_per_thread>(arg);
        } else {
          errorQuda("Unsupported field order %d", out.FieldOrder());
        }
      } else {
        TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());

        if (out.FieldOrder() == QUDA_FLOAT2_FIELD_ORDER) {
          typedef RestrictArg<Float,vFloat,fineSpin,fineColor,coarseSpin,coarseColor,QUDA_FLOAT2_FIELD_ORDER> Arg;
          Arg arg(out, in, v, fine_to_coarse, coarse_to_fine, parity);
          arg.swizzle = tp.aux.x;

#ifdef JITIFY
          using namespace jitify::reflection;
          jitify_error = program->kernel("quda::RestrictKernel")
            .instantiate((int)tp.block.x,Type<Float>(),fineSpin,fineColor,coarseSpin,coarseColor,coarse_colors_per_thread,Type<Arg>())
            .configure(tp.grid,tp.block,tp.shared_bytes,stream).launch(arg);
#else
          LAUNCH_KERNEL_MG_BLOCK_SIZE(RestrictKernel,tp,stream,arg,Float,fineSpin,fineColor,
                                      coarseSpin,coarseColor,coarse_colors_per_thread,Arg);
#endif
        } else {
          errorQuda("Unsupported field order %d", out.FieldOrder());
        }
      }
    }
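    // A note on the two launch paths above: with JITIFY the kernel is
    // compiled and instantiated at run time from kernels/restrictor.cuh,
    // with the tuned block size baked in as a template argument; otherwise
    // LAUNCH_KERNEL_MG_BLOCK_SIZE (from launch_kernel.cuh) dispatches at
    // run time to a precompiled instantiation whose compile-time block
    // size matches tp.block.x.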

    // This block tuning searches for the optimal split of the coarse
    // colors between blockDim.z and gridDim.z.  However, enabling
    // blockDim.z > 1 gives incorrect results, because the cub reductions
    // cannot perform independent sliced reductions along blockDim.z, so
    // for now we only split between colors per thread and grid.z.
    bool advanceBlockDim(TuneParam &param) const
    {
      // try to advance the spin/block-color split
      while (param.block.z <= coarseColor/coarse_colors_per_thread) {
        param.block.z++;
        if ((coarseColor/coarse_colors_per_thread) % param.block.z == 0) {
          param.grid.z = (coarseColor/coarse_colors_per_thread) / param.block.z;
          break;
        }
      }

      // we can advance the spin/block-color split since this is valid
      if (param.block.z <= (coarseColor/coarse_colors_per_thread)) {
        return true;
      } else { // we have run off the end, so reset
        param.block.z = 1;
        param.grid.z = coarseColor/coarse_colors_per_thread;
        return false;
      }
    }
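    // Worked example (hypothetical sizes): with coarseColor = 24 and
    // coarse_colors_per_thread = 4 there are 6 color blocks, so starting
    // from block.z = 1 successive calls visit (block.z, grid.z) =
    // (2,3), (3,2), (6,1), and then reset to (1,6) and return false.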

    int tuningIter() const { return 3; }

    bool advanceAux(TuneParam &param) const
    {
#ifdef SWIZZLE
      if (param.aux.x < 2*deviceProp.multiProcessorCount) {
        param.aux.x++;
        return true;
      } else {
        param.aux.x = 1;
        return false;
      }
#else
      return false;
#endif
    }
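    // aux.x is the swizzle factor copied into arg.swizzle in apply();
    // when SWIZZLE is defined the autotuner sweeps it from 1 up to twice
    // the number of multiprocessors, and the kernel uses it to remap
    // block indices onto coarse sites (a load-balancing optimization).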

    // only tune shared memory per thread (disable tuning for block.z for now)
    bool advanceTuneParam(TuneParam &param) const { return advanceSharedBytes(param) || advanceAux(param); }

    TuneKey tuneKey() const { return TuneKey(vol, typeid(*this).name(), aux); }

    void initTuneParam(TuneParam &param) const { defaultTuneParam(param); }

    /** sets default values for when tuning is disabled */
    void defaultTuneParam(TuneParam &param) const {
      param.block = dim3(block_size, in.SiteSubset(), 1);
      param.grid = dim3((minThreads() + param.block.x - 1) / param.block.x, 1, 1);
      param.shared_bytes = 0;

      param.block.z = 1;
      param.grid.z = coarseColor / coarse_colors_per_thread;
      param.aux.x = 1; // swizzle factor
    }
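    // The default launch geometry is thus: block.x covers one aggregate
    // (block_size fine sites per parity), block.y covers the fine
    // parities, grid.x covers the fine checkerboarded volume, grid.z
    // covers the coarseColor/coarse_colors_per_thread color blocks, and
    // swizzling is disabled (aux.x = 1).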

    long long flops() const { return 8 * fineSpin * fineColor * coarseColor * in.SiteSubset() * (long long)in.VolumeCB(); }

    long long bytes() const {
      size_t v_bytes = v.Bytes() / (v.SiteSubset() == in.SiteSubset() ? 1 : 2);
      return in.Bytes() + out.Bytes() + v_bytes + in.SiteSubset()*in.VolumeCB()*sizeof(int);
    }
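    // Flop count: each of the coarseColor coarse components accumulates an
    // inner product over fineSpin*fineColor complex values, and a complex
    // multiply-add costs 8 real flops, giving
    // 8 * fineSpin * fineColor * coarseColor per fine site.  The byte
    // count charges reading in and v, writing out, and one int per fine
    // site for the fine_to_coarse lookup; v's bytes are halved when its
    // site subset differs from in's, since only the matching parity is read.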

  };

  template <typename Float, int fineSpin, int fineColor, int coarseSpin, int coarseColor>
  void Restrict(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &v,
                const int *fine_to_coarse, const int *coarse_to_fine, int parity) {

    // the fine (Nc=3) grid has more site-level parallelism, so each thread can handle more coarse colors
    constexpr int coarse_colors_per_thread = fineColor != 3 ? 2 : coarseColor >= 4 && coarseColor % 4 == 0 ? 4 : 2;
    //coarseColor >= 8 && coarseColor % 8 == 0 ? 8 : coarseColor >= 4 && coarseColor % 4 == 0 ? 4 : 2;

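    // Illustration (hypothetical sizes): with fineColor = 3 and
    // coarseColor = 24, each thread handles 4 coarse colors (6 color
    // blocks in total); on already-coarsened levels (fineColor != 3)
    // each thread handles 2.

    // QUDA_PRECISION is a compile-time bitmask of the enabled precisions,
    // with the half-precision bit equal to 2, so the check below tests
    // whether half-precision fields were built in.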
    if (v.Precision() == QUDA_HALF_PRECISION) {
#if QUDA_PRECISION & 2
      RestrictLaunch<Float, short, fineSpin, fineColor, coarseSpin, coarseColor, coarse_colors_per_thread>
        restrictor(out, in, v, fine_to_coarse, coarse_to_fine, parity);
      restrictor.apply(0);
#else
      errorQuda("QUDA_PRECISION=%d does not enable half precision", QUDA_PRECISION);
#endif
    } else if (v.Precision() == in.Precision()) {
      RestrictLaunch<Float, Float, fineSpin, fineColor, coarseSpin, coarseColor, coarse_colors_per_thread>
        restrictor(out, in, v, fine_to_coarse, coarse_to_fine, parity);
      restrictor.apply(0);
    } else {
      errorQuda("Unsupported V precision %d", v.Precision());
    }
  }

  template <typename Float>
  void Restrict(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &v,
                int nVec, const int *fine_to_coarse, const int *coarse_to_fine, const int * const * spin_map, int parity)
  {
    if (out.Nspin() != 2) errorQuda("Unsupported nSpin %d", out.Nspin());
    constexpr int coarseSpin = 2;

    // template over fine color
    if (in.Ncolor() == 3) { // standard QCD
      constexpr int fineColor = 3;
#ifdef NSPIN4
      if (in.Nspin() == 4) {
        constexpr int fineSpin = 4;

        // first check that the spin_map matches the spin_mapper
        spin_mapper<fineSpin,coarseSpin> mapper;
        for (int s=0; s<fineSpin; s++)
          for (int p=0; p<2; p++)
            if (mapper(s,p) != spin_map[s][p]) errorQuda("Spin map does not match spin_mapper");

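        // The spin_mapper encodes the fine-to-coarse spin blocking: for
        // fineSpin = 4 the four fine spins map pairwise onto the
        // coarseSpin = 2 chiral components, while for staggered fields
        // (fineSpin = 1, below) the site parity plays the role of the
        // coarse spin index.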
        if (nVec == 6) { // free field Wilson
          Restrict<Float,fineSpin,fineColor,coarseSpin,6>(out, in, v, fine_to_coarse, coarse_to_fine, parity);
        } else if (nVec == 24) {
          Restrict<Float,fineSpin,fineColor,coarseSpin,24>(out, in, v, fine_to_coarse, coarse_to_fine, parity);
        } else if (nVec == 32) {
          Restrict<Float,fineSpin,fineColor,coarseSpin,32>(out, in, v, fine_to_coarse, coarse_to_fine, parity);
        } else {
          errorQuda("Unsupported nVec %d", nVec);
        }
      } else
#endif // NSPIN4
#ifdef NSPIN1
      if (in.Nspin() == 1) {
        constexpr int fineSpin = 1;

        // first check that the spin_map matches the spin_mapper
        spin_mapper<fineSpin,coarseSpin> mapper;
        for (int s=0; s<fineSpin; s++)
          for (int p=0; p<2; p++)
            if (mapper(s,p) != spin_map[s][p]) errorQuda("Spin map does not match spin_mapper");

        if (nVec == 24) { // free field staggered
          Restrict<Float,fineSpin,fineColor,coarseSpin,24>(out, in, v, fine_to_coarse, coarse_to_fine, parity);
        } else if (nVec == 64) {
          Restrict<Float,fineSpin,fineColor,coarseSpin,64>(out, in, v, fine_to_coarse, coarse_to_fine, parity);
        } else if (nVec == 96) {
          Restrict<Float,fineSpin,fineColor,coarseSpin,96>(out, in, v, fine_to_coarse, coarse_to_fine, parity);
        } else {
          errorQuda("Unsupported nVec %d", nVec);
        }
      } else
#endif
      {
        errorQuda("Unexpected nSpin = %d", in.Nspin());
      }

    } else { // Nc != 3

      if (in.Nspin() != 2) errorQuda("Unexpected nSpin = %d", in.Nspin());
      constexpr int fineSpin = 2;

      // first check that the spin_map matches the spin_mapper
      spin_mapper<fineSpin,coarseSpin> mapper;
      for (int s=0; s<fineSpin; s++)
        for (int p=0; p<2; p++)
          if (mapper(s,p) != spin_map[s][p]) errorQuda("Spin map does not match spin_mapper");

#ifdef NSPIN4
      if (in.Ncolor() == 6) { // coarsen the coarsened Wilson free field
        const int fineColor = 6;
        if (nVec == 6) {
          Restrict<Float,fineSpin,fineColor,coarseSpin,6>(out, in, v, fine_to_coarse, coarse_to_fine, parity);
        } else {
          errorQuda("Unsupported nVec %d", nVec);
        }
      } else
#endif // NSPIN4
      if (in.Ncolor() == 24) { // to keep compilation time under control, coarse grids have the same or more colors
        const int fineColor = 24;
        if (nVec == 24) {
          Restrict<Float,fineSpin,fineColor,coarseSpin,24>(out, in, v, fine_to_coarse, coarse_to_fine, parity);
#ifdef NSPIN4
        } else if (nVec == 32) {
          Restrict<Float,fineSpin,fineColor,coarseSpin,32>(out, in, v, fine_to_coarse, coarse_to_fine, parity);
#endif // NSPIN4
#ifdef NSPIN1
        } else if (nVec == 64) {
          Restrict<Float,fineSpin,fineColor,coarseSpin,64>(out, in, v, fine_to_coarse, coarse_to_fine, parity);
        } else if (nVec == 96) {
          Restrict<Float,fineSpin,fineColor,coarseSpin,96>(out, in, v, fine_to_coarse, coarse_to_fine, parity);
#endif // NSPIN1
        } else {
          errorQuda("Unsupported nVec %d", nVec);
        }
#ifdef NSPIN4
      } else if (in.Ncolor() == 32) {
        const int fineColor = 32;
        if (nVec == 32) {
          Restrict<Float,fineSpin,fineColor,coarseSpin,32>(out, in, v, fine_to_coarse, coarse_to_fine, parity);
        } else {
          errorQuda("Unsupported nVec %d", nVec);
        }
#endif // NSPIN4
#ifdef NSPIN1
      } else if (in.Ncolor() == 64) {
        const int fineColor = 64;
        if (nVec == 64) {
          Restrict<Float,fineSpin,fineColor,coarseSpin,64>(out, in, v, fine_to_coarse, coarse_to_fine, parity);
        } else if (nVec == 96) {
          Restrict<Float,fineSpin,fineColor,coarseSpin,96>(out, in, v, fine_to_coarse, coarse_to_fine, parity);
        } else {
          errorQuda("Unsupported nVec %d", nVec);
        }
      } else if (in.Ncolor() == 96) {
        const int fineColor = 96;
        if (nVec == 96) {
          Restrict<Float,fineSpin,fineColor,coarseSpin,96>(out, in, v, fine_to_coarse, coarse_to_fine, parity);
        } else {
          errorQuda("Unsupported nVec %d", nVec);
        }
#endif // NSPIN1
      } else {
        errorQuda("Unsupported nColor %d", in.Ncolor());
      }
    } // Nc != 3
  }

  void Restrict(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &v,
                int Nvec, const int *fine_to_coarse, const int *coarse_to_fine, const int * const * spin_map, int parity)
  {
#ifdef GPU_MULTIGRID
    if (out.FieldOrder() != in.FieldOrder() || out.FieldOrder() != v.FieldOrder())
      errorQuda("Field orders do not match (out=%d, in=%d, v=%d)",
                out.FieldOrder(), in.FieldOrder(), v.FieldOrder());

    QudaPrecision precision = checkPrecision(out, in);

    if (precision == QUDA_DOUBLE_PRECISION) {
#ifdef GPU_MULTIGRID_DOUBLE
      Restrict<double>(out, in, v, Nvec, fine_to_coarse, coarse_to_fine, spin_map, parity);
#else
      errorQuda("Double precision multigrid has not been enabled");
#endif
    } else if (precision == QUDA_SINGLE_PRECISION) {
      Restrict<float>(out, in, v, Nvec, fine_to_coarse, coarse_to_fine, spin_map, parity);
    } else {
      errorQuda("Unsupported precision %d", out.Precision());
    }
#else
    errorQuda("Multigrid has not been built");
#endif
  }
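  // A hypothetical usage sketch: this entry point is invoked by the
  // multigrid transfer operator, with V packing Nvec near-null vectors,
  // e.g.
  //
  //   Restrict(coarse, fine, V, Nvec, fine_to_coarse, coarse_to_fine,
  //            spin_map, parity);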

} // namespace quda