QUDA  1.0.0
dslash_coarse.cu
#include <gauge_field.h>
#include <color_spinor_field.h>
#include <uint_to_char.h>
#include <worker.h>
#include <tune_quda.h>

#include <jitify_helper.cuh>
#include <kernels/dslash_coarse.cuh> // kernel bodies and argument-struct definitions

namespace quda {

#ifdef GPU_MULTIGRID

  template <typename Float, typename yFloat, typename ghostFloat, int nDim, int Ns, int Nc, int Mc, bool dslash, bool clover, bool dagger, DslashType type>
  class DslashCoarse : public TunableVectorY {

  protected:
    ColorSpinorField &out;
    const ColorSpinorField &inA;
    const ColorSpinorField &inB;
    const GaugeField &Y;
    const GaugeField &X;
    const double kappa;
    const int parity;
    const int nParity;
    const int nSrc;

    const int max_color_col_stride = 8;
    mutable int color_col_stride;
    mutable int dim_threads;
    char *saveOut;

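    // Flop-counting convention (a reading of the expression below, not a
    // separate source of truth): applying one coarse link is a complex
    // matrix-vector product of dimension Ns*Nc, i.e. 8*(Ns*Nc)^2 real ops
    // (6 per complex multiply plus 2 per accumulate); the stencil applies
    // 2*nDim such links (forward and backward hops in each dimension) plus
    // one more for the clover/X term, and the subtracted 2*Ns*Nc removes the
    // accumulation that the first product does not perform. For example, with
    // Ns=2, Nc=24, nDim=4 and both terms enabled this gives
    // (8+1)*8*48*48 - 2*48 = 165,792 flops per site.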
    long long flops() const
    {
      return ((dslash*2*nDim+clover*1)*(8*Ns*Nc*Ns*Nc)-2*Ns*Nc)*nParity*(long long)out.VolumeCB();
    }
    long long bytes() const
    {
      return (dslash||clover) * out.Bytes() + dslash*8*inA.Bytes() + clover*inB.Bytes() +
        nSrc*nParity*(dslash*Y.Bytes()*Y.VolumeCB()/(2*Y.Stride()) + clover*X.Bytes()/2);
    }
    unsigned int sharedBytesPerThread() const { return (sizeof(complex<Float>) * Mc); }
    unsigned int sharedBytesPerBlock(const TuneParam &param) const { return 0; }
    bool tuneGridDim() const { return false; } // Don't tune the grid dimensions
    bool tuneAuxDim() const { return true; } // Do tune the aux dimensions
    unsigned int minThreads() const { return color_col_stride * X.VolumeCB(); } // 4-d volume since this x threads only

    bool advanceBlockDim(TuneParam &param) const
    {
      dim3 grid = param.grid;
      bool ret = TunableVectorY::advanceBlockDim(param);
      param.grid.z = grid.z;

      if (ret) { // we advanced the block.x so we're done
        return true;
      } else { // block.x (spacetime) was reset

        // let's try to advance spin/block-color
        while (param.block.z <= (unsigned int)(dim_threads * 2 * 2 * (Nc/Mc))) {
          param.block.z += dim_threads * 2;
          if ( (dim_threads*2*2*(Nc/Mc)) % param.block.z == 0) {
            param.grid.z = (dim_threads * 2 * 2 * (Nc/Mc)) / param.block.z;
            break;
          }
        }

        // we can advance spin/block-color since this is valid
        if (param.block.z <= (unsigned int)(dim_threads * 2 * 2 * (Nc/Mc)) &&
            param.block.z <= (unsigned int)deviceProp.maxThreadsDim[2] &&
            param.block.x*param.block.y*param.block.z <= (unsigned int)deviceProp.maxThreadsPerBlock ) {
          return true;
        } else { // we have run off the end so let's reset
          param.block.z = dim_threads * 2;
          param.grid.z = 2 * (Nc/Mc);
          return false;
        }
      }
    }
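    // The z-axis work carved up here has a total extent of
    // dim_threads * 2 * 2 * (Nc/Mc) threads (the initial block.z = dim_threads*2
    // times the initial grid.z = 2*(Nc/Mc)); the loop above walks block.z
    // through the multiples of dim_threads*2 that evenly divide this extent,
    // subject to the device block-size limits checked above. For example, with
    // dim_threads=1, Nc=24, Mc=1 the extent is 96 and block.z visits
    // 2, 4, 6, 8, 12, 16, 24, 32, 48, 96.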

    // FIXME: understand why this leads to slower perf and variable correctness
    //int blockStep() const { return deviceProp.warpSize/4; }
    //int blockMin() const { return deviceProp.warpSize/4; }

    // Experimental autotuning of the color column stride
    bool advanceAux(TuneParam &param) const
    {
#ifdef DOT_PRODUCT_SPLIT
      // we can only split the dot product on Kepler and later since we need the __shfl instruction
      if (2*param.aux.x <= max_color_col_stride && Nc % (2*param.aux.x) == 0 &&
          param.block.x % deviceProp.warpSize == 0) {
        // An x-dimension block size that is not a multiple of the
        // warp size is incompatible with splitting the dot product
        // across the warp so we must skip this

        param.aux.x *= 2; // safe to advance
        color_col_stride = param.aux.x;

        // recompute grid size since minThreads() has now been updated
        param.grid.x = (minThreads()+param.block.x-1)/param.block.x;

        // check this grid size is valid before returning
        if (param.grid.x < (unsigned int)deviceProp.maxGridSize[0]) return true;
      }
#endif

      // reset color column stride if too large or not divisible
      param.aux.x = 1;
      color_col_stride = param.aux.x;

      // recompute grid size since minThreads() has now been updated
      param.grid.x = (minThreads()+param.block.x-1)/param.block.x;

      if (2*param.aux.y <= nDim &&
          param.block.x*param.block.y*dim_threads*2 <= (unsigned int)deviceProp.maxThreadsPerBlock) {
        param.aux.y *= 2;
        dim_threads = param.aux.y;

        // need to reset z-block/grid size/shared_bytes since dim_threads has changed
        param.block.z = dim_threads * 2;
        param.grid.z = 2 * (Nc / Mc);

        param.shared_bytes = sharedBytesPerThread()*param.block.x*param.block.y*param.block.z > sharedBytesPerBlock(param) ?
          sharedBytesPerThread()*param.block.x*param.block.y*param.block.z : sharedBytesPerBlock(param);

        return true;
      } else {
        param.aux.y = 1;
        dim_threads = param.aux.y;

        // need to reset z-block/grid size/shared_bytes since
        // dim_threads has changed. Strictly speaking this isn't needed
        // since this is the outer dimension to tune, but would be
        // needed if we added an aux.z tuning dimension
        param.block.z = dim_threads * 2;
        param.grid.z = 2 * (Nc / Mc);

        param.shared_bytes = sharedBytesPerThread()*param.block.x*param.block.y*param.block.z > sharedBytesPerBlock(param) ?
          sharedBytesPerThread()*param.block.x*param.block.y*param.block.z : sharedBytesPerBlock(param);

        return false;
      }
    }
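    // Two aux dimensions are tuned here: aux.x is the color column stride,
    // doubling through {1,2,4,8} (bounded by max_color_col_stride and, when
    // DOT_PRODUCT_SPLIT is enabled, requiring a warp-aligned block.x and a
    // valid grid.x), and aux.y is dim_threads, doubling through {1,2,4}
    // (bounded by 2*aux.y <= nDim). These correspond exactly to the template
    // instantiations dispatched in apply() below.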
144 
145  virtual void initTuneParam(TuneParam &param) const
146  {
147  param.aux = make_int4(1,1,1,1);
148  color_col_stride = param.aux.x;
149  dim_threads = param.aux.y;
150 
152  param.block.z = dim_threads * 2;
153  param.grid.z = 2*(Nc/Mc);
154  param.shared_bytes = sharedBytesPerThread()*param.block.x*param.block.y*param.block.z > sharedBytesPerBlock(param) ?
155  sharedBytesPerThread()*param.block.x*param.block.y*param.block.z : sharedBytesPerBlock(param);
156  }
157 
159  virtual void defaultTuneParam(TuneParam &param) const
160  {
161  param.aux = make_int4(1,1,1,1);
162  color_col_stride = param.aux.x;
163  dim_threads = param.aux.y;
164 
166  // ensure that the default x block size is divisible by the warpSize
167  param.block.x = deviceProp.warpSize;
168  param.grid.x = (minThreads()+param.block.x-1)/param.block.x;
169  param.block.z = dim_threads * 2;
170  param.grid.z = 2*(Nc/Mc);
171  param.shared_bytes = sharedBytesPerThread()*param.block.x*param.block.y*param.block.z > sharedBytesPerBlock(param) ?
172  sharedBytesPerThread()*param.block.x*param.block.y*param.block.z : sharedBytesPerBlock(param);
173  }
174 
  public:
    inline DslashCoarse(ColorSpinorField &out, const ColorSpinorField &inA, const ColorSpinorField &inB,
                        const GaugeField &Y, const GaugeField &X, double kappa, int parity,
                        MemoryLocation *halo_location)
      : TunableVectorY(out.SiteSubset() * (out.Ndim()==5 ? out.X(4) : 1)),
        out(out), inA(inA), inB(inB), Y(Y), X(X), kappa(kappa), parity(parity),
        nParity(out.SiteSubset()), nSrc(out.Ndim()==5 ? out.X(4) : 1)
    {
      strcpy(aux, "policy_kernel,");
      if (out.Location() == QUDA_CUDA_FIELD_LOCATION) {
#ifdef JITIFY
        create_jitify_program("kernels/dslash_coarse.cuh");
#endif
      }
      strcat(aux, compile_type_str(out));
      strcat(aux, out.AuxString());
      strcat(aux, comm_dim_partitioned_string());

      switch (type) {
      case DSLASH_INTERIOR: strcat(aux,",interior"); break;
      case DSLASH_EXTERIOR: strcat(aux,",exterior"); break;
      case DSLASH_FULL:     strcat(aux,",full"); break;
      }

      // record the location of where each pack buffer is in [2*dim+dir] ordering
      // 0 - no packing
      // 1 - pack to local GPU memory
      // 2 - pack to local mapped CPU memory
      // 3 - pack to remote mapped GPU memory
      if (doHalo<type>()) {
        char label[15] = ",halo=";
        for (int dim=0; dim<4; dim++) {
          for (int dir=0; dir<2; dir++) {
            label[2*dim+dir+6] = !comm_dim_partitioned(dim) ? '0' :
              halo_location[2*dim+dir] == Device ? '1' :
              halo_location[2*dim+dir] == Host ? '2' : '3';
          }
        }
        label[14] = '\0';
        strcat(aux, label);
      }
    }
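    // For example, with all four dimensions partitioned and every halo read
    // from local GPU memory, the string ",halo=11111111" is appended to the
    // aux tuning key; an unpartitioned dimension contributes '0' for both of
    // its directions.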
    virtual ~DslashCoarse() { }

    inline void apply(const cudaStream_t &stream) {

      if (out.Location() == QUDA_CPU_FIELD_LOCATION) {

        if (out.FieldOrder() != QUDA_SPACE_SPIN_COLOR_FIELD_ORDER || Y.FieldOrder() != QUDA_QDP_GAUGE_ORDER)
          errorQuda("Unsupported field order colorspinor=%d gauge=%d combination\n", inA.FieldOrder(), Y.FieldOrder());

        DslashCoarseArg<Float,yFloat,ghostFloat,Ns,Nc,QUDA_SPACE_SPIN_COLOR_FIELD_ORDER,QUDA_QDP_GAUGE_ORDER> arg(out, inA, inB, Y, X, (Float)kappa, parity);
        coarseDslash<Float,nDim,Ns,Nc,Mc,dslash,clover,dagger,type>(arg);
      } else {

        const TuneParam &tp = tuneLaunch(*this, getTuning(), getVerbosity());

        if (out.FieldOrder() != QUDA_FLOAT2_FIELD_ORDER || Y.FieldOrder() != QUDA_FLOAT2_GAUGE_ORDER)
          errorQuda("Unsupported field order colorspinor=%d gauge=%d combination\n", inA.FieldOrder(), Y.FieldOrder());

        typedef DslashCoarseArg<Float,yFloat,ghostFloat,Ns,Nc,QUDA_FLOAT2_FIELD_ORDER,QUDA_FLOAT2_GAUGE_ORDER> Arg;
        Arg arg(out, inA, inB, Y, X, (Float)kappa, parity);

#ifdef JITIFY
        using namespace jitify::reflection;
        jitify_error = program->kernel("quda::coarseDslashKernel")
          .instantiate(Type<Float>(),nDim,Ns,Nc,Mc,(int)tp.aux.x,(int)tp.aux.y,dslash,clover,dagger,type,Type<Arg>())
          .configure(tp.grid,tp.block,tp.shared_bytes,stream).launch(arg);
#else
        switch (tp.aux.y) { // dimension gather parallelisation
        case 1:
          switch (tp.aux.x) { // this is color_col_stride
          case 1:
            coarseDslashKernel<Float,nDim,Ns,Nc,Mc,1,1,dslash,clover,dagger,type> <<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
            break;
#ifdef DOT_PRODUCT_SPLIT
          case 2:
            coarseDslashKernel<Float,nDim,Ns,Nc,Mc,2,1,dslash,clover,dagger,type> <<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
            break;
          case 4:
            coarseDslashKernel<Float,nDim,Ns,Nc,Mc,4,1,dslash,clover,dagger,type> <<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
            break;
          case 8:
            coarseDslashKernel<Float,nDim,Ns,Nc,Mc,8,1,dslash,clover,dagger,type> <<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
            break;
#endif // DOT_PRODUCT_SPLIT
          default:
            errorQuda("Color column stride %d not valid", tp.aux.x);
          }
          break;
        case 2:
          switch (tp.aux.x) { // this is color_col_stride
          case 1:
            coarseDslashKernel<Float,nDim,Ns,Nc,Mc,1,2,dslash,clover,dagger,type> <<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
            break;
#ifdef DOT_PRODUCT_SPLIT
          case 2:
            coarseDslashKernel<Float,nDim,Ns,Nc,Mc,2,2,dslash,clover,dagger,type> <<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
            break;
          case 4:
            coarseDslashKernel<Float,nDim,Ns,Nc,Mc,4,2,dslash,clover,dagger,type> <<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
            break;
          case 8:
            coarseDslashKernel<Float,nDim,Ns,Nc,Mc,8,2,dslash,clover,dagger,type> <<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
            break;
#endif // DOT_PRODUCT_SPLIT
          default:
            errorQuda("Color column stride %d not valid", tp.aux.x);
          }
          break;
        case 4:
          switch (tp.aux.x) { // this is color_col_stride
          case 1:
            coarseDslashKernel<Float,nDim,Ns,Nc,Mc,1,4,dslash,clover,dagger,type> <<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
            break;
#ifdef DOT_PRODUCT_SPLIT
          case 2:
            coarseDslashKernel<Float,nDim,Ns,Nc,Mc,2,4,dslash,clover,dagger,type> <<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
            break;
          case 4:
            coarseDslashKernel<Float,nDim,Ns,Nc,Mc,4,4,dslash,clover,dagger,type> <<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
            break;
          case 8:
            coarseDslashKernel<Float,nDim,Ns,Nc,Mc,8,4,dslash,clover,dagger,type> <<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
            break;
#endif // DOT_PRODUCT_SPLIT
          default:
            errorQuda("Color column stride %d not valid", tp.aux.x);
          }
          break;
        default:
          errorQuda("Invalid dimension thread splitting %d", tp.aux.y);
        }
#endif
      }
    }
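    // Each case above is a distinct compile-time instantiation: the tuned
    // values aux.x (color_col_stride in {1,2,4,8}, the latter three only with
    // DOT_PRODUCT_SPLIT) and aux.y (dim_threads in {1,2,4}) select among the
    // precompiled kernels, while the JITIFY path instead compiles the
    // requested combination at run time.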

    TuneKey tuneKey() const {
      return TuneKey(out.VolString(), typeid(*this).name(), aux);
    }

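    // The kernel writes to the output field, so the repeated launches made
    // while autotuning must save and restore its contents to leave the
    // caller-visible result unchanged.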
    void preTune() {
      saveOut = new char[out.Bytes()];
      cudaMemcpy(saveOut, out.V(), out.Bytes(), cudaMemcpyDeviceToHost);
    }

    void postTune()
    {
      cudaMemcpy(out.V(), saveOut, out.Bytes(), cudaMemcpyHostToDevice);
      delete[] saveOut;
    }

  };


  template <typename Float, typename yFloat, typename ghostFloat, int coarseColor, int coarseSpin>
  inline void ApplyCoarse(ColorSpinorField &out, const ColorSpinorField &inA, const ColorSpinorField &inB,
                          const GaugeField &Y, const GaugeField &X, double kappa, int parity, bool dslash,
                          bool clover, bool dagger, DslashType type, MemoryLocation *halo_location) {

    const int colors_per_thread = 1;
    const int nDim = 4;

    if (dagger) {
      if (dslash) {
        if (clover) {

          if (type == DSLASH_FULL) {
            DslashCoarse<Float,yFloat,ghostFloat,nDim,coarseSpin,coarseColor,colors_per_thread,true,true,true,DSLASH_FULL> dslash(out, inA, inB, Y, X, kappa, parity, halo_location);
            dslash.apply(0);
          } else if (type == DSLASH_INTERIOR) {
            DslashCoarse<Float,yFloat,ghostFloat,nDim,coarseSpin,coarseColor,colors_per_thread,true,true,true,DSLASH_INTERIOR> dslash(out, inA, inB, Y, X, kappa, parity, halo_location);
            dslash.apply(0);
          } else { errorQuda("Dslash type %d not instantiated", type); }

        } else { // plain dslash

          if (type == DSLASH_FULL) {
            DslashCoarse<Float,yFloat,ghostFloat,nDim,coarseSpin,coarseColor,colors_per_thread,true,false,true,DSLASH_FULL> dslash(out, inA, inB, Y, X, kappa, parity, halo_location);
            dslash.apply(0);
          } else if (type == DSLASH_INTERIOR) {
            DslashCoarse<Float,yFloat,ghostFloat,nDim,coarseSpin,coarseColor,colors_per_thread,true,false,true,DSLASH_INTERIOR> dslash(out, inA, inB, Y, X, kappa, parity, halo_location);
            dslash.apply(0);
          } else { errorQuda("Dslash type %d not instantiated", type); }

        }
      } else {

        if (type == DSLASH_EXTERIOR) errorQuda("Cannot call halo on pure clover kernel");
        if (clover) {
          DslashCoarse<Float,yFloat,ghostFloat,nDim,coarseSpin,coarseColor,colors_per_thread,false,true,true,DSLASH_FULL> dslash(out, inA, inB, Y, X, kappa, parity, halo_location);
          dslash.apply(0);
        } else {
          errorQuda("Unsupported dslash=false clover=false");
        }

      }
    } else {

      if (dslash) {
        if (clover) {

          if (type == DSLASH_FULL) {
            DslashCoarse<Float,yFloat,ghostFloat,nDim,coarseSpin,coarseColor,colors_per_thread,true,true,false,DSLASH_FULL> dslash(out, inA, inB, Y, X, kappa, parity, halo_location);
            dslash.apply(0);
          } else if (type == DSLASH_INTERIOR) {
            DslashCoarse<Float,yFloat,ghostFloat,nDim,coarseSpin,coarseColor,colors_per_thread,true,true,false,DSLASH_INTERIOR> dslash(out, inA, inB, Y, X, kappa, parity, halo_location);
            dslash.apply(0);
          } else { errorQuda("Dslash type %d not instantiated", type); }

        } else { // plain dslash

          if (type == DSLASH_FULL) {
            DslashCoarse<Float,yFloat,ghostFloat,nDim,coarseSpin,coarseColor,colors_per_thread,true,false,false,DSLASH_FULL> dslash(out, inA, inB, Y, X, kappa, parity, halo_location);
            dslash.apply(0);
          } else if (type == DSLASH_INTERIOR) {
            DslashCoarse<Float,yFloat,ghostFloat,nDim,coarseSpin,coarseColor,colors_per_thread,true,false,false,DSLASH_INTERIOR> dslash(out, inA, inB, Y, X, kappa, parity, halo_location);
            dslash.apply(0);
          } else { errorQuda("Dslash type %d not instantiated", type); }

        }
      } else {
        if (type == DSLASH_EXTERIOR) errorQuda("Cannot call halo on pure clover kernel");
        if (clover) {
          DslashCoarse<Float,yFloat,ghostFloat,nDim,coarseSpin,coarseColor,colors_per_thread,false,true,false,DSLASH_FULL> dslash(out, inA, inB, Y, X, kappa, parity, halo_location);
          dslash.apply(0);
        } else {
          errorQuda("Unsupported dslash=false clover=false");
        }
      }
    }
  }

  // template on the number of coarse colors
  template <typename Float, typename yFloat, typename ghostFloat>
  inline void ApplyCoarse(ColorSpinorField &out, const ColorSpinorField &inA, const ColorSpinorField &inB,
                          const GaugeField &Y, const GaugeField &X, double kappa, int parity, bool dslash,
                          bool clover, bool dagger, DslashType type, MemoryLocation *halo_location) {

    if (Y.FieldOrder() != X.FieldOrder())
      errorQuda("Field order mismatch Y = %d, X = %d", Y.FieldOrder(), X.FieldOrder());

    if (inA.FieldOrder() != out.FieldOrder())
      errorQuda("Field order mismatch inA = %d, out = %d", inA.FieldOrder(), out.FieldOrder());

    if (inA.Nspin() != 2)
      errorQuda("Unsupported number of coarse spins %d\n", inA.Nspin());

#if 0
    } else if (inA.Ncolor() == 4) {
      ApplyCoarse<Float,yFloat,ghostFloat,4,2>(out, inA, inB, Y, X, kappa, parity, dslash, clover, dagger, type, halo_location);
#endif
    if (inA.Ncolor() == 6) { // free field Wilson
      ApplyCoarse<Float,yFloat,ghostFloat,6,2>(out, inA, inB, Y, X, kappa, parity, dslash, clover, dagger, type, halo_location);
#if 0
    } else if (inA.Ncolor() == 8) {
      ApplyCoarse<Float,yFloat,ghostFloat,8,2>(out, inA, inB, Y, X, kappa, parity, dslash, clover, dagger, type, halo_location);
    } else if (inA.Ncolor() == 12) {
      ApplyCoarse<Float,yFloat,ghostFloat,12,2>(out, inA, inB, Y, X, kappa, parity, dslash, clover, dagger, type, halo_location);
    } else if (inA.Ncolor() == 16) {
      ApplyCoarse<Float,yFloat,ghostFloat,16,2>(out, inA, inB, Y, X, kappa, parity, dslash, clover, dagger, type, halo_location);
    } else if (inA.Ncolor() == 20) {
      ApplyCoarse<Float,yFloat,ghostFloat,20,2>(out, inA, inB, Y, X, kappa, parity, dslash, clover, dagger, type, halo_location);
#endif
    } else if (inA.Ncolor() == 24) {
      ApplyCoarse<Float,yFloat,ghostFloat,24,2>(out, inA, inB, Y, X, kappa, parity, dslash, clover, dagger, type, halo_location);
#if 0
    } else if (inA.Ncolor() == 28) {
      ApplyCoarse<Float,yFloat,ghostFloat,28,2>(out, inA, inB, Y, X, kappa, parity, dslash, clover, dagger, type, halo_location);
#endif
    } else if (inA.Ncolor() == 32) {
      ApplyCoarse<Float,yFloat,ghostFloat,32,2>(out, inA, inB, Y, X, kappa, parity, dslash, clover, dagger, type, halo_location);
    } else {
      errorQuda("Unsupported number of coarse dof %d\n", Y.Ncolor());
    }
  }

  // this is the Worker pointer that may issue additional work
  // while we're waiting on communication to finish
  namespace dslash {
    extern Worker* aux_worker;
  }

#endif // GPU_MULTIGRID

  enum class DslashCoarsePolicy {
    DSLASH_COARSE_BASIC, // stage both sends and recvs in host memory using memcpys
    DSLASH_COARSE_ZERO_COPY_PACK, // zero copy write pack buffers
    DSLASH_COARSE_ZERO_COPY_READ, // zero copy read halos in dslash kernel
    DSLASH_COARSE_ZERO_COPY, // full zero copy
    DSLASH_COARSE_GDR_SEND, // GDR send
    DSLASH_COARSE_GDR_RECV, // GDR recv
    DSLASH_COARSE_GDR, // full GDR
    DSLASH_COARSE_ZERO_COPY_PACK_GDR_RECV, // zero copy write and GDR recv
    DSLASH_COARSE_GDR_SEND_ZERO_COPY_READ, // GDR send and zero copy read
    DSLASH_COARSE_POLICY_DISABLED // this must be the last element
  };
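  // The enumerator values double as the integers accepted by the
  // QUDA_ENABLE_DSLASH_COARSE_POLICY environment variable parsed below:
  // e.g. QUDA_ENABLE_DSLASH_COARSE_POLICY=0,3 restricts the policy tuner to
  // DSLASH_COARSE_BASIC and DSLASH_COARSE_ZERO_COPY.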

  struct DslashCoarseLaunch {

    ColorSpinorField &out;
    const ColorSpinorField &inA;
    const ColorSpinorField &inB;
    const GaugeField &Y;
    const GaugeField &X;
    double kappa;
    int parity;
    bool dslash;
    bool clover;
    bool dagger;
    const int *commDim;
    const QudaPrecision halo_precision;

    inline DslashCoarseLaunch(ColorSpinorField &out, const ColorSpinorField &inA, const ColorSpinorField &inB,
                              const GaugeField &Y, const GaugeField &X, double kappa, int parity,
                              bool dslash, bool clover, bool dagger, const int *commDim, QudaPrecision halo_precision)
      : out(out), inA(inA), inB(inB), Y(Y), X(X), kappa(kappa), parity(parity),
        dslash(dslash), clover(clover), dagger(dagger), commDim(commDim),
        halo_precision(halo_precision == QUDA_INVALID_PRECISION ? Y.Precision() : halo_precision) { }

    /**
       @brief Execute the coarse dslash using the given policy
     */
    inline void operator()(DslashCoarsePolicy policy) {
#ifdef GPU_MULTIGRID
      if (inA.V() == out.V()) errorQuda("Aliasing pointers");

      // check all precisions match
      QudaPrecision precision = checkPrecision(out, inA, inB);
      checkPrecision(Y, X);

      // check all locations match
      checkLocation(out, inA, inB, Y, X);

      int comm_sum = 4;
      if (commDim) for (int i=0; i<4; i++) comm_sum -= (1-commDim[i]);
      if (comm_sum != 4 && comm_sum != 0) errorQuda("Unsupported comms %d", comm_sum);
      bool comms = comm_sum;

      MemoryLocation pack_destination[2*QUDA_MAX_DIM]; // where we will pack the ghost buffer to
      MemoryLocation halo_location[2*QUDA_MAX_DIM]; // where we load the halo from
      for (int i=0; i<2*QUDA_MAX_DIM; i++) {
        pack_destination[i] = (policy == DslashCoarsePolicy::DSLASH_COARSE_ZERO_COPY_PACK ||
                               policy == DslashCoarsePolicy::DSLASH_COARSE_ZERO_COPY ||
                               policy == DslashCoarsePolicy::DSLASH_COARSE_ZERO_COPY_PACK_GDR_RECV) ? Host : Device;
        halo_location[i] = (policy == DslashCoarsePolicy::DSLASH_COARSE_ZERO_COPY_READ ||
                            policy == DslashCoarsePolicy::DSLASH_COARSE_ZERO_COPY ||
                            policy == DslashCoarsePolicy::DSLASH_COARSE_GDR_SEND_ZERO_COPY_READ) ? Host : Device;
      }
      bool gdr_send = (policy == DslashCoarsePolicy::DSLASH_COARSE_GDR_SEND ||
                       policy == DslashCoarsePolicy::DSLASH_COARSE_GDR ||
                       policy == DslashCoarsePolicy::DSLASH_COARSE_GDR_SEND_ZERO_COPY_READ);
      bool gdr_recv = (policy == DslashCoarsePolicy::DSLASH_COARSE_GDR_RECV ||
                       policy == DslashCoarsePolicy::DSLASH_COARSE_GDR ||
                       policy == DslashCoarsePolicy::DSLASH_COARSE_ZERO_COPY_PACK_GDR_RECV);

      // disable peer-to-peer if doing a zero-copy policy (temporary)
      if (policy == DslashCoarsePolicy::DSLASH_COARSE_ZERO_COPY_PACK ||
          policy == DslashCoarsePolicy::DSLASH_COARSE_ZERO_COPY_READ ||
          policy == DslashCoarsePolicy::DSLASH_COARSE_ZERO_COPY ||
          policy == DslashCoarsePolicy::DSLASH_COARSE_ZERO_COPY_PACK_GDR_RECV ||
          policy == DslashCoarsePolicy::DSLASH_COARSE_GDR_SEND_ZERO_COPY_READ)
        comm_enable_peer2peer(false);

      if (dslash && comm_partitioned() && comms) {
        const int nFace = 1;
        inA.exchangeGhost((QudaParity)(inA.SiteSubset() == QUDA_PARITY_SITE_SUBSET ? (1 - parity) : 0), nFace, dagger,
                          pack_destination, halo_location, gdr_send, gdr_recv, halo_precision);
      }

      if (dslash::aux_worker) dslash::aux_worker->apply(0);

      if (precision == QUDA_DOUBLE_PRECISION) {
#ifdef GPU_MULTIGRID_DOUBLE
        if (Y.Precision() != QUDA_DOUBLE_PRECISION)
          errorQuda("Y Precision %d not supported", Y.Precision());
        if (halo_precision != QUDA_DOUBLE_PRECISION)
          errorQuda("Halo precision %d not supported with field precision %d and link precision %d", halo_precision, precision, Y.Precision());
        ApplyCoarse<double,double,double>(out, inA, inB, Y, X, kappa, parity, dslash, clover,
                                          dagger, comms ? DSLASH_FULL : DSLASH_INTERIOR, halo_location);
        //if (dslash && comm_partitioned()) ApplyCoarse<double>(out, inA, inB, Y, X, kappa, parity, dslash, clover, dagger, true, halo_location);
#else
        errorQuda("Double precision multigrid has not been enabled");
#endif
      } else if (precision == QUDA_SINGLE_PRECISION) {
        if (Y.Precision() == QUDA_SINGLE_PRECISION) {
          if (halo_precision == QUDA_SINGLE_PRECISION) {
            ApplyCoarse<float,float,float>(out, inA, inB, Y, X, kappa, parity, dslash, clover,
                                           dagger, comms ? DSLASH_FULL : DSLASH_INTERIOR, halo_location);
          } else {
            errorQuda("Halo precision %d not supported with field precision %d and link precision %d", halo_precision, precision, Y.Precision());
          }
        } else if (Y.Precision() == QUDA_HALF_PRECISION) {
#if QUDA_PRECISION & 2
          if (halo_precision == QUDA_HALF_PRECISION) {
            ApplyCoarse<float,short,short>(out, inA, inB, Y, X, kappa, parity, dslash, clover,
                                           dagger, comms ? DSLASH_FULL : DSLASH_INTERIOR, halo_location);
          } else if (halo_precision == QUDA_QUARTER_PRECISION) {
#if QUDA_PRECISION & 1
            ApplyCoarse<float,short,char>(out, inA, inB, Y, X, kappa, parity, dslash, clover,
                                          dagger, comms ? DSLASH_FULL : DSLASH_INTERIOR, halo_location);
#else
            errorQuda("QUDA_PRECISION=%d does not enable quarter precision", QUDA_PRECISION);
#endif
          } else {
            errorQuda("Halo precision %d not supported with field precision %d and link precision %d", halo_precision, precision, Y.Precision());
          }
#else
          errorQuda("QUDA_PRECISION=%d does not enable half precision", QUDA_PRECISION);
#endif
        } else {
          errorQuda("Unsupported precision %d\n", Y.Precision());
        }
        //if (dslash && comm_partitioned()) ApplyCoarse<float>(out, inA, inB, Y, X, kappa, parity, dslash, clover, dagger, true, halo_location);
      } else {
        errorQuda("Unsupported precision %d\n", Y.Precision());
      }

      if (dslash && comm_partitioned() && comms) inA.bufferIndex = (1 - inA.bufferIndex);
      comm_enable_peer2peer(true);
#else
      errorQuda("Multigrid has not been built");
#endif
    }

  };

  static bool dslash_init = false;
  static int first_active_policy = static_cast<int>(DslashCoarsePolicy::DSLASH_COARSE_POLICY_DISABLED);
  static std::vector<DslashCoarsePolicy> policies(static_cast<int>(DslashCoarsePolicy::DSLASH_COARSE_POLICY_DISABLED), DslashCoarsePolicy::DSLASH_COARSE_POLICY_DISABLED);

  // string used as a tunekey to ensure we retune if the dslash policy env changes
  static char policy_string[TuneKey::aux_n];

  void enable_policy(DslashCoarsePolicy p) {
    policies[static_cast<std::size_t>(p)] = p;
  }

  void disable_policy(DslashCoarsePolicy p) {
    policies[static_cast<std::size_t>(p)] = DslashCoarsePolicy::DSLASH_COARSE_POLICY_DISABLED;
  }

  class DslashCoarsePolicyTune : public Tunable {

    DslashCoarseLaunch &dslash;

    bool tuneGridDim() const { return false; } // Don't tune the grid dimensions.
    bool tuneAuxDim() const { return true; } // Do tune the aux dimensions.
    unsigned int sharedBytesPerThread() const { return 0; }
    unsigned int sharedBytesPerBlock(const TuneParam &param) const { return 0; }

  public:
    inline DslashCoarsePolicyTune(DslashCoarseLaunch &dslash) : dslash(dslash)
    {
      if (!dslash_init) {

        static char *dslash_policy_env = getenv("QUDA_ENABLE_DSLASH_COARSE_POLICY");

        if (dslash_policy_env) { // set the policies to tune for explicitly
          std::stringstream policy_list(dslash_policy_env);

          int policy_;
          while (policy_list >> policy_) {
            DslashCoarsePolicy dslash_policy = static_cast<DslashCoarsePolicy>(policy_);

            // check this is a valid policy choice
            if ( (dslash_policy == DslashCoarsePolicy::DSLASH_COARSE_GDR_SEND ||
                  dslash_policy == DslashCoarsePolicy::DSLASH_COARSE_GDR_RECV ||
                  dslash_policy == DslashCoarsePolicy::DSLASH_COARSE_GDR ||
                  dslash_policy == DslashCoarsePolicy::DSLASH_COARSE_ZERO_COPY_PACK_GDR_RECV ||
                  dslash_policy == DslashCoarsePolicy::DSLASH_COARSE_GDR_SEND_ZERO_COPY_READ) && !comm_gdr_enabled() ) {
              errorQuda("Cannot select a GDR policy %d unless QUDA_ENABLE_GDR is set", static_cast<int>(dslash_policy));
            }

            enable_policy(dslash_policy);
            first_active_policy = policy_ < first_active_policy ? policy_ : first_active_policy;
            if (policy_list.peek() == ',') policy_list.ignore();
          }
          if (first_active_policy == static_cast<int>(DslashCoarsePolicy::DSLASH_COARSE_POLICY_DISABLED))
            errorQuda("No valid policy found in QUDA_ENABLE_DSLASH_COARSE_POLICY");
        } else {
          first_active_policy = 0;
          enable_policy(DslashCoarsePolicy::DSLASH_COARSE_BASIC);
          enable_policy(DslashCoarsePolicy::DSLASH_COARSE_ZERO_COPY_PACK);
          enable_policy(DslashCoarsePolicy::DSLASH_COARSE_ZERO_COPY_READ);
          enable_policy(DslashCoarsePolicy::DSLASH_COARSE_ZERO_COPY);
          if (comm_gdr_enabled()) {
            enable_policy(DslashCoarsePolicy::DSLASH_COARSE_GDR_SEND);
            enable_policy(DslashCoarsePolicy::DSLASH_COARSE_GDR_RECV);
            enable_policy(DslashCoarsePolicy::DSLASH_COARSE_GDR);
            enable_policy(DslashCoarsePolicy::DSLASH_COARSE_ZERO_COPY_PACK_GDR_RECV);
            enable_policy(DslashCoarsePolicy::DSLASH_COARSE_GDR_SEND_ZERO_COPY_READ);
          }
        }

        // construct string specifying which policies have been enabled
        strcat(policy_string, ",pol=");
        for (int i = 0; i < (int)DslashCoarsePolicy::DSLASH_COARSE_POLICY_DISABLED; i++) {
          strcat(policy_string, (int)policies[i] == i ? "1" : "0");
        }

        dslash_init = true;
      }

      strcpy(aux, "policy,");
      if (dslash.dslash) strcat(aux, "dslash");
      strcat(aux, dslash.clover ? "clover," : ",");
      strcat(aux, dslash.inA.AuxString());
      strcat(aux, ",gauge_prec=");

      char prec_str[8];
      i32toa(prec_str, dslash.Y.Precision());
      strcat(aux, prec_str);
      strcat(aux, ",halo_prec=");
      i32toa(prec_str, dslash.halo_precision);
      strcat(aux, prec_str);
      strcat(aux, comm_dim_partitioned_string(dslash.commDim));
      strcat(aux, comm_dim_topology_string());
      strcat(aux, comm_config_string()); // any change in P2P/GDR will be stored as a separate tunecache entry
      strcat(aux, policy_string); // any change in policies enabled will be stored as a separate entry

      int comm_sum = 4;
      if (dslash.commDim)
        for (int i = 0; i < 4; i++) comm_sum -= (1 - dslash.commDim[i]);
      strcat(aux, comm_sum ? ",full" : ",interior");

      // before we do policy tuning we must ensure the kernel
      // constituents have been tuned since we can't do nested tuning
      if (getTuning() && getTuneCache().find(tuneKey()) == getTuneCache().end()) {
        disableProfileCount();
        for (auto &i : policies) if (i != DslashCoarsePolicy::DSLASH_COARSE_POLICY_DISABLED) dslash(i);
        enableProfileCount();
        setPolicyTuning(true);
      }
    }

    virtual ~DslashCoarsePolicyTune() { setPolicyTuning(false); }

    inline void apply(const cudaStream_t &stream) {
      TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());

      if (tp.aux.x >= (int)policies.size()) errorQuda("Requested policy that is outside of range");
      if (policies[tp.aux.x] == DslashCoarsePolicy::DSLASH_COARSE_POLICY_DISABLED) errorQuda("Requested policy is disabled");
      dslash(policies[tp.aux.x]);
    }

    int tuningIter() const { return 10; }

    bool advanceAux(TuneParam &param) const
    {
      while ((unsigned)param.aux.x < policies.size()-1) {
        param.aux.x++;
        if (policies[param.aux.x] != DslashCoarsePolicy::DSLASH_COARSE_POLICY_DISABLED) return true;
      }
      param.aux.x = 0;
      return false;
    }

    bool advanceTuneParam(TuneParam &param) const { return advanceAux(param); }

    void initTuneParam(TuneParam &param) const {
      Tunable::initTuneParam(param);
      param.aux.x = first_active_policy;
      param.aux.y = 0;
      param.aux.z = 0;
      param.aux.w = 0;
    }

    void defaultTuneParam(TuneParam &param) const {
      Tunable::defaultTuneParam(param);
      param.aux.x = first_active_policy;
      param.aux.y = 0;
      param.aux.z = 0;
      param.aux.w = 0;
    }

    TuneKey tuneKey() const {
      return TuneKey(dslash.inA.VolString(), typeid(*this).name(), aux);
    }

    long long flops() const {
      int nDim = 4;
      int Ns = dslash.inA.Nspin();
      int Nc = dslash.inA.Ncolor();
      int nParity = dslash.inA.SiteSubset();
      int volumeCB = dslash.inA.VolumeCB();
      return ((dslash.dslash*2*nDim+dslash.clover*1)*(8*Ns*Nc*Ns*Nc)-2*Ns*Nc)*nParity*volumeCB;
    }

    long long bytes() const {
      int nParity = dslash.inA.SiteSubset();
      return (dslash.dslash||dslash.clover) * dslash.out.Bytes() +
        dslash.dslash*8*dslash.inA.Bytes() + dslash.clover*dslash.inB.Bytes() +
        nParity*(dslash.dslash*dslash.Y.Bytes()*dslash.Y.VolumeCB()/(2*dslash.Y.Stride())
                 + dslash.clover*dslash.X.Bytes()/2);
      // multiply Y by volume / stride to correct for pad
    }
  };

  // Apply the coarse Dirac matrix to a coarse-grid vector:
  //   out(x) = M*in = X*in - kappa*\sum_mu [ Y_{-\mu}(x)*in(x+\mu) + Y^\dagger_\mu(x-\mu)*in(x-\mu) ]
  // or, for the daggered operator,
  //   out(x) = M^\dagger*in = X^\dagger*in - kappa*\sum_mu [ Y^\dagger_{-\mu}(x)*in(x+\mu) + Y_\mu(x-\mu)*in(x-\mu) ]
  // Both use the kappa normalization of the Wilson operator.
  void ApplyCoarse(ColorSpinorField &out, const ColorSpinorField &inA, const ColorSpinorField &inB,
                   const GaugeField &Y, const GaugeField &X, double kappa, int parity,
                   bool dslash, bool clover, bool dagger, const int *commDim, QudaPrecision halo_precision) {

    DslashCoarseLaunch Dslash(out, inA, inB, Y, X, kappa, parity, dslash, clover, dagger, commDim, halo_precision);

    DslashCoarsePolicyTune policy(Dslash);
    policy.apply(0);

  } // ApplyCoarse
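  // Usage sketch (hypothetical coarse fields y_out, y_in and coarse links Y, X
  // produced by the multigrid setup; all arguments passed explicitly):
  //
  //   ApplyCoarse(y_out, y_in, y_in, Y, X, kappa, parity,
  //               true /*dslash*/, true /*clover*/, false /*dagger*/,
  //               nullptr /*commDim*/, QUDA_INVALID_PRECISION /*halo: default to Y's precision*/);
  //
  // This applies the full operator, y_out = X*y_in - kappa*(hopping terms on
  // y_in), exchanging halos in every partitioned dimension and letting
  // DslashCoarsePolicyTune pick the communication policy.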


} // namespace quda