QUDA v1.1.0
A library for QCD on GPUs
dslash_pack2.cu
#include <color_spinor_field.h>

// STRIPED - spread the blocks throughout the workload to ensure we
// work on all directions/dimensions simultaneously to maximize NVLink saturation
// if STRIPED is not defined then we assign one thread block per direction / dimension
// currently does not work with NVSHMEM
#ifndef NVSHMEM_COMMS
#define STRIPED 1
#endif
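// Illustration (numbers are hypothetical): with 3 partitioned dimensions there
// are 2 * 3 = 6 faces to pack.  With STRIPED a block covers its share
// (sites_per_block) of the whole workload and so may touch several faces;
// without STRIPED each block is pinned to one direction/dimension pair
// (blocks_per_dir blocks per face), so the grid size is kept a multiple of 6.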

#include <dslash_quda.h>
#include <kernels/dslash_pack.cuh>
#include <instantiate.h>

namespace quda
{

  int* getPackComms() { return commDim; }

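  // Record which of the four space-time dimensions are partitioned between
  // processes and therefore need their halos packed; entries beyond the first
  // four are cleared so stale values cannot enable packing in unused dimensions.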
  void setPackComms(const int *comm_dim)
  {
    for (int i = 0; i < 4; i++) commDim[i] = comm_dim[i];
    for (int i = 4; i < QUDA_MAX_DIM; i++) commDim[i] = 0;
  }

  template <typename Float, int nSpin, int nColor, bool spin_project>
  std::ostream &operator<<(std::ostream &out, const PackArg<Float, nSpin, nColor, spin_project> &arg)
  {
    out << "parity = " << arg.parity << std::endl;
    out << "nParity = " << arg.nParity << std::endl;
    out << "pc_type = " << arg.pc_type << std::endl;
    out << "nFace = " << arg.nFace << std::endl;
    out << "dagger = " << arg.dagger << std::endl;
    out << "a = " << arg.a << std::endl;
    out << "b = " << arg.b << std::endl;
    out << "c = " << arg.c << std::endl;
    out << "twist = " << arg.twist << std::endl;
    out << "threads = " << arg.threads << std::endl;
    out << "threadDimMapLower = { ";
    for (int i = 0; i < 4; i++) out << arg.threadDimMapLower[i] << (i < 3 ? ", " : " }");
    out << std::endl;
    out << "threadDimMapUpper = { ";
    for (int i = 0; i < 4; i++) out << arg.threadDimMapUpper[i] << (i < 3 ? ", " : " }");
    out << std::endl;
    out << "sites_per_block = " << arg.sites_per_block << std::endl;
    return out;
  }

  // FIXME - add CPU variant

  template <typename Float, int nColor, bool spin_project> class Pack : TunableVectorYZ
  {

protected:
    void **ghost;
    const ColorSpinorField &in;
    MemoryLocation location;
    const int nFace;
    const bool dagger; // only has meaning for nSpin=4
    const int parity;
    const int nParity;
    int threads;
    const double a;
    const double b;
    const double c;
    int twist; // only has meaning for nSpin=4
#ifdef NVSHMEM_COMMS
    const int shmem;
#else
    static constexpr int shmem = 0;
#endif

    bool tuneGridDim() const { return true; } // If striping, always tune grid dimension

    unsigned int maxGridSize() const
    {
      if (location & Host || location & Shmem) {
#ifdef STRIPED
        // if zero-copy policy then set the maximum number of blocks to be
        // 3 * the number of dimensions we are communicating
        int max = 3;
#else
        // if zero-copy policy then assign exactly up to four thread blocks
        // per direction per dimension (effectively no grid-size tuning)
        int max = 2 * 4;
#endif
        int nDimComms = 0;
        for (int d = 0; d < in.Ndim(); d++) nDimComms += commDim[d];
        return max * nDimComms;
      } else {
        return TunableVectorYZ::maxGridSize();
      }
    } // use no more than a quarter of the GPU
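    // Worked example (illustrative): with all four dimensions partitioned,
    // nDimComms = 4, so the zero-copy/shmem path caps the grid at 3 * 4 = 12
    // blocks when STRIPED, or (2 * 4) * 4 = 32 blocks otherwise.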

    unsigned int minGridSize() const
    {
      if (location & Host || location & Shmem) {
#ifdef STRIPED
        // if zero-copy policy then set the minimum number of blocks to be
        // 1 * the number of dimensions we are communicating
        int min = 1;
#else
        // if zero-copy policy then assign exactly one thread block
        // per direction per dimension (effectively no grid-size tuning)
        int min = 2;
#endif
        int nDimComms = 0;
        for (int d = 0; d < in.Ndim(); d++) nDimComms += commDim[d];
        return min * nDimComms;
      } else {
        return TunableVectorYZ::minGridSize();
      }
    }

    int gridStep() const
    {
#ifdef STRIPED
      return TunableVectorYZ::gridStep();
#else
      if (location & Host || location & Shmem) {
        // the shmem kernel requires the grid-size autotuner to increment in
        // steps of 2 * the number of partitioned dimensions so that the blocks
        // divide equally among the directions/dimensions
        int nDimComms = 0;
        for (int d = 0; d < in.Ndim(); d++) nDimComms += commDim[d];
        return 2 * nDimComms;
      } else {
        return TunableVectorYZ::gridStep();
      }
#endif
    }
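    // Worked example (illustrative): with nDimComms = 4 the non-STRIPED
    // zero-copy path tunes the grid over 8, 16, 24 and 32 blocks, i.e. 1 to 4
    // blocks per direction per dimension, consistent with the
    // minGridSize/maxGridSize bounds above.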

    bool tuneAuxDim() const { return true; } // Do tune the aux dimensions.
    unsigned int minThreads() const { return threads; }

    void fillAux()
    {
      strcpy(aux, "policy_kernel,");
      strcat(aux, in.AuxString());
      char comm[5];
      for (int i = 0; i < 4; i++) comm[i] = (commDim[i] ? '1' : '0');
      comm[4] = '\0';
      strcat(aux, ",comm=");
      strcat(aux, comm);
      strcat(aux, comm_dim_topology_string());
      if (in.PCType() == QUDA_5D_PC) { strcat(aux, ",5D_pc"); }
      if (dagger && in.Nspin() == 4) { strcat(aux, ",dagger"); }
      if (getKernelPackT()) { strcat(aux, ",kernelPackT"); }
      switch (nFace) {
      case 1: strcat(aux, ",nFace=1"); break;
      case 3: strcat(aux, ",nFace=3"); break;
      default: errorQuda("Number of faces not supported");
      }

      twist = ((b != 0.0) ? (c != 0.0 ? 2 : 1) : 0);
      if (twist && a == 0.0) errorQuda("Twisted packing requires non-zero scale factor a");
      if (twist) strcat(aux, twist == 2 ? ",twist-doublet" : ",twist-singlet");

      // label the locations we are packing to
      // location label is nonp2p-p2p
      switch ((int)location) {
      case Device | Remote: strcat(aux, ",device-remote"); break;
      case Host | Remote: strcat(aux, ",host-remote"); break;
      case Device: strcat(aux, ",device-device"); break;
      case Host: strcat(aux, comm_peer2peer_enabled_global() ? ",host-device" : ",host-host"); break;
      case Shmem: strcat(aux, ",shmem"); break;
      default: errorQuda("Unknown pack target location %d\n", location);
      }
    }
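    // Example of a resulting tuning string (illustrative; the field part comes
    // from in.AuxString() and the topology part from comm_dim_topology_string()):
    //   policy_kernel,<field aux>,comm=1111<topology>,dagger,nFace=1,device-device
    // so kernels packing to different targets are tuned independently.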

public:
    Pack(void *ghost[], const ColorSpinorField &in, MemoryLocation location, int nFace, bool dagger, int parity, double a,
         double b, double c, int shmem) :
      TunableVectorYZ((in.Ndim() == 5 ? in.X(4) : 1), in.SiteSubset()),
      ghost(ghost),
      in(in),
      location(location),
      nFace(nFace),
      dagger(dagger),
      parity(parity),
      nParity(in.SiteSubset()),
      threads(0),
      a(a),
      b(b),
      c(c)
#ifdef NVSHMEM_COMMS
      ,
      shmem(shmem)
#endif
    {
      fillAux();

      // compute number of threads - really number of active work items we have to do
      for (int i = 0; i < 4; i++) {
        if (!commDim[i]) continue;
        if (i == 3 && !getKernelPackT()) continue;
        threads += 2 * nFace * in.getDslashConstant().ghostFaceCB[i]; // 2 for forwards and backwards faces
      }
    }
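    // Worked example (hypothetical 16x16x16x32 local volume, only the T
    // dimension partitioned, getKernelPackT() true, nFace = 1):
    // ghostFaceCB[3] = (16*16*16)/2 = 2048, so threads = 2 * 1 * 2048 = 4096
    // work items, one per site on the forward and backward T faces.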

    virtual ~Pack() { }

    template <typename T, typename Arg>
    inline void launch(T *f, const TuneParam &tp, Arg &arg, const qudaStream_t &stream)
    {
      qudaLaunchKernel(f, tp, stream, arg);
    }

    void apply(const qudaStream_t &stream)
    {
      TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
      // enable max shared memory mode on GPUs that support it
      if (deviceProp.major >= 7) tp.set_max_shared_bytes = true;

      if (in.Nspin() == 4) {
        using Arg = PackArg<Float, nColor, 4, spin_project>;
        Arg arg(ghost, in, nFace, dagger, parity, threads, a, b, c, shmem);
        arg.counter = dslash::get_shmem_sync_counter();
        arg.swizzle = tp.aux.x;
        arg.sites_per_block = (arg.threads + tp.grid.x - 1) / tp.grid.x;
        arg.blocks_per_dir = tp.grid.x / (2 * arg.active_dims); // set number of blocks per direction

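        // Continuing the hypothetical example above: if threads = 4096 and the
        // tuner picks tp.grid.x = 8 with one active dimension, then
        // sites_per_block = ceil(4096 / 8) = 512 and blocks_per_dir = 8 / (2 * 1) = 4.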
#ifdef STRIPED
        if (in.PCType() == QUDA_4D_PC) {
          if (arg.dagger) {
            switch (arg.twist) {
            case 0: launch(packKernel<true, 0, QUDA_4D_PC, Arg>, tp, arg, stream); break;
            case 1: launch(packKernel<true, 1, QUDA_4D_PC, Arg>, tp, arg, stream); break;
            case 2: launch(packKernel<true, 2, QUDA_4D_PC, Arg>, tp, arg, stream); break;
            }
          } else {
            switch (arg.twist) {
            case 0: launch(packKernel<false, 0, QUDA_4D_PC, Arg>, tp, arg, stream); break;
            default: errorQuda("Twisted packing only for dagger");
            }
          }
        } else if (arg.pc_type == QUDA_5D_PC) {
          if (arg.twist) errorQuda("Twist packing not defined");
          if (arg.dagger) {
            launch(packKernel<true, 0, QUDA_5D_PC, Arg>, tp, arg, stream);
          } else {
            launch(packKernel<false, 0, QUDA_5D_PC, Arg>, tp, arg, stream);
          }
        } else {
          errorQuda("Unexpected preconditioning type %d", in.PCType());
        }
#else
        if (in.PCType() == QUDA_4D_PC) {
          if (arg.dagger) {
            switch (arg.twist) {
            case 0:
              launch((location & Host || location & Shmem) ? packShmemKernel<true, 0, QUDA_4D_PC, Arg> :
                                                             packKernel<true, 0, QUDA_4D_PC, Arg>,
                     tp, arg, stream);
              break;
            case 1:
              launch((location & Host || location & Shmem) ? packShmemKernel<true, 1, QUDA_4D_PC, Arg> :
                                                             packKernel<true, 1, QUDA_4D_PC, Arg>,
                     tp, arg, stream);
              break;
            case 2:
              launch((location & Host || location & Shmem) ? packShmemKernel<true, 2, QUDA_4D_PC, Arg> :
                                                             packKernel<true, 2, QUDA_4D_PC, Arg>,
                     tp, arg, stream);
              break;
            }
          } else {
            switch (arg.twist) {
            case 0:
              launch((location & Host || location & Shmem) ? packShmemKernel<false, 0, QUDA_4D_PC, Arg> :
                                                             packKernel<false, 0, QUDA_4D_PC, Arg>,
                     tp, arg, stream);
              break;
            default: errorQuda("Twisted packing only for dagger");
            }
          }
        } else if (arg.pc_type == QUDA_5D_PC) {
          if (arg.twist) errorQuda("Twist packing not defined");
          if (arg.dagger) {
            launch(packKernel<true, 0, QUDA_5D_PC, Arg>, tp, arg, stream);
          } else {
            launch(packKernel<false, 0, QUDA_5D_PC, Arg>, tp, arg, stream);
          }
        }
#endif
      } else if (in.Nspin() == 1) {
        using Arg = PackArg<Float, nColor, 1, false>;
        Arg arg(ghost, in, nFace, dagger, parity, threads, a, b, c, shmem);
        arg.counter = dslash::get_shmem_sync_counter();
        arg.swizzle = tp.aux.x;
        arg.sites_per_block = (arg.threads + tp.grid.x - 1) / tp.grid.x;
        arg.blocks_per_dir = tp.grid.x / (2 * arg.active_dims); // set number of blocks per direction

#ifdef STRIPED
        launch(packStaggeredKernel<Arg>, tp, arg, stream);
#else
        launch((location & Host || location & Shmem) ? packStaggeredShmemKernel<Arg> : packStaggeredKernel<Arg>, tp,
               arg, stream);
#endif
      } else {
        errorQuda("Unsupported nSpin = %d\n", in.Nspin());
      }
    }

    bool tuneSharedBytes() const { return false; }

#if 0
    // not used at present, but if tuneSharedBytes is enabled then
    // this allows tuning up the full dynamic shared memory if needed
    unsigned int maxSharedBytesPerBlock() const { return maxDynamicSharedBytesPerBlock(); }
#endif

    void initTuneParam(TuneParam &param) const
    {
      TunableVectorYZ::initTuneParam(param);
      // if doing a zero-copy policy then ensure that each thread block
      // runs exclusively on a given SM - this is to ensure quality of
      // service for the packing kernel when running concurrently.
      if (location & Host) param.shared_bytes = maxDynamicSharedBytesPerBlock() / 2 + 1;
#ifndef STRIPED
      if (location & Host) param.grid.x = minGridSize();
#endif
    }

    void defaultTuneParam(TuneParam &param) const
    {
      TunableVectorYZ::defaultTuneParam(param);
      // if doing a zero-copy policy then ensure that each thread block
      // runs exclusively on a given SM - this is to ensure quality of
      // service for the packing kernel when running concurrently.
      if (location & Host) param.shared_bytes = maxDynamicSharedBytesPerBlock() / 2 + 1;
#ifndef STRIPED
      if (location & Host) param.grid.x = minGridSize();
#endif
    }

    TuneKey tuneKey() const { return TuneKey(in.VolString(), typeid(*this).name(), aux); }

    int tuningIter() const { return 3; }

    long long flops() const
    {
      // unless we are spin projecting (nSpin = 4), there are no flops to do
      return in.Nspin() == 4 ? 2 * in.Nspin() / 2 * nColor * nParity * in.getDslashConstant().Ls * threads : 0;
    }
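    // The count above reflects that spin projection halves the spinor: each of
    // the nSpin/2 * nColor projected components is formed from one complex
    // addition (2 real flops) per work item, per parity, per fifth-dimensional slice.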

    long long bytes() const
    {
      size_t precision = sizeof(Float);
      size_t faceBytes = 2 * ((in.Nspin() == 4 ? in.Nspin() / 2 : in.Nspin()) + in.Nspin()) * nColor * precision;
      if (precision == QUDA_HALF_PRECISION || precision == QUDA_QUARTER_PRECISION)
        faceBytes += 2 * sizeof(float); // 2 is from input and output
      return faceBytes * nParity * in.getDslashConstant().Ls * threads;
    }
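    // Per face site we read the full spinor (nSpin * nColor complex numbers)
    // and write the projected half (nSpin/2 components when spin projecting),
    // hence the (write + read) * nColor * precision factor; half and quarter
    // precision add two floats per site for the input and output norm fields.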
  };

  template <typename Float, int nColor> struct GhostPack {
    GhostPack(const ColorSpinorField &in, void *ghost[], MemoryLocation location, int nFace, bool dagger, int parity,
              bool spin_project, double a, double b, double c, int shmem, const qudaStream_t &stream)
    {
      if (spin_project) {
        Pack<Float, nColor, true> pack(ghost, in, location, nFace, dagger, parity, a, b, c, shmem);
        pack.apply(stream);
      } else {
        Pack<Float, nColor, false> pack(ghost, in, location, nFace, dagger, parity, a, b, c, shmem);
        pack.apply(stream);
      }
    }
  };

  // Pack the ghost for the Dslash operator
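  // ghost points to the 2 * QUDA_MAX_DIM destination buffers (backwards and
  // forwards for each dimension), location selects where they live
  // (Device/Host/Remote/Shmem), and nFace is the face depth (1, or 3 e.g. for
  // improved staggered).  The a, b, c parameters appear to be the twisted-mass
  // scale factors that select single- or double-flavour twist packing, and
  // shmem is only meaningful when compiled with NVSHMEM_COMMS.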
  void PackGhost(void *ghost[2 * QUDA_MAX_DIM], const ColorSpinorField &in, MemoryLocation location, int nFace,
                 bool dagger, int parity, bool spin_project, double a, double b, double c, int shmem,
                 const qudaStream_t &stream)
  {
    int nDimPack = 0;
    for (int d = 0; d < 4; d++) {
      if (!commDim[d]) continue;
      if (d != 3 || getKernelPackT()) nDimPack++;
    }
    if (!nDimPack) return; // if zero then we have nothing to pack

    instantiate<GhostPack>(in, ghost, location, nFace, dagger, parity, spin_project, a, b, c, shmem, stream);
  }

} // namespace quda