QUDA v1.1.0
A library for QCD on GPUs
coarse_op_preconditioned.cu
#include <typeinfo>
#include <gauge_field.h>
#include <blas_lapack.h>
#include <blas_quda.h>
#include <tune_quda.h>

#include <jitify_helper.cuh>
#include <kernels/coarse_op_preconditioned.cuh>

#include <coarse_op_preconditioned_mma_launch.h>

namespace quda
{

  /**
     @brief Launcher for CPU instantiations of preconditioned coarse-link construction
  */
  template <QudaFieldLocation location, typename Arg>
  struct Launch {
    Launch(Arg &arg, CUresult &error, bool compute_max_only, TuneParam &tp, bool use_mma, const qudaStream_t &stream)
    {
      if (compute_max_only)
        CalculateYhatCPU<true, Arg>(arg);
      else
        CalculateYhatCPU<false, Arg>(arg);
    }
  };

  /**
     @brief Launcher for GPU instantiations of preconditioned coarse-link construction
  */
  template <typename Arg>
  struct Launch<QUDA_CUDA_FIELD_LOCATION, Arg> {
    Launch(Arg &arg, CUresult &error, bool compute_max_only, TuneParam &tp, bool use_mma, const qudaStream_t &stream)
    {
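      // when computing the max, zero the device-side accumulator before the kernel
      // runs; this is skipped while tuning is active, since tuning launches discard
      // their results (see the matching copy-back guard below)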
      if (compute_max_only) {
        if (!activeTuning()) {
          qudaMemsetAsync(arg.max_d, 0, sizeof(typename Arg::Float), stream);
        }
      }
#ifdef JITIFY
      if (use_mma) {
        errorQuda("MMA kernels haven't been jitify'ed.");
      } else {
        using namespace jitify::reflection;
        error = program->kernel("quda::CalculateYhatGPU")
                  .instantiate(compute_max_only, Type<Arg>())
                  .configure(tp.grid, tp.block, tp.shared_bytes, stream)
                  .launch(arg);
      }
#else
      if (use_mma) {
        if (compute_max_only) {
          mma::launch_yhat_kernel<true>(arg, arg.Y.VolumeCB(), tp, stream);
        } else {
          mma::launch_yhat_kernel<false>(arg, arg.Y.VolumeCB(), tp, stream);
        }
      } else {
        if (compute_max_only) {
          qudaLaunchKernel(CalculateYhatGPU<true, Arg>, tp, stream, arg);
        } else {
          qudaLaunchKernel(CalculateYhatGPU<false, Arg>, tp, stream, arg);
        }
      }
#endif
      if (compute_max_only) {
        if (!activeTuning()) { // only do copy once tuning is done
          qudaMemcpyAsync(arg.max_h, arg.max_d, sizeof(typename Arg::Float), cudaMemcpyDeviceToHost, stream);
          qudaStreamSynchronize(const_cast<qudaStream_t &>(stream));
        }
      }
    }
  };

  template <QudaFieldLocation location, typename Arg>
  class CalculateYhat : public TunableVectorYZ {

    using Float = typename Arg::Float;
    Arg &arg;
    const LatticeField &meta;
    const int n;

    bool compute_max_only;

    bool use_mma;

    // 2 * VolumeCB spans the full volume; 8 from dir; each link product is n*n
    // complex dot products of length n at (8*n - 2) flops each
    long long flops() const { return 2l * arg.Y.VolumeCB() * 8 * n * n * (8 * n - 2); }
    long long bytes() const { return 2l * (arg.Xinv.Bytes() + 8 * arg.Y.Bytes() + !compute_max_only * 8 * arg.Yhat.Bytes()) * n; }

    unsigned int minThreads() const { return arg.Y.VolumeCB(); }
    bool tuneGridDim() const { return false; } // don't tune the grid dimension

    // all the tuning done is only in matrix tile size (Y/Z block.grid)
    int blockMin() const { return 8; }
    int blockStep() const { return 8; }
    unsigned int maxBlockSize(const TuneParam &param) const { return 8u; }
    bool tuneAuxDim() const { return use_mma; } // tune aux if doing mma
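    // for the MMA path the aux.x index selects among the available tensor-core
    // launch configurations, stepped through in advanceAux below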

  public:
    CalculateYhat(Arg &arg, const LatticeField &meta, bool use_mma) :
      TunableVectorYZ(2 * arg.tile.M_tiles, 4 * arg.tile.N_tiles),
      arg(arg),
      meta(meta),
      n(arg.tile.n),
      compute_max_only(false),
      use_mma(use_mma)
    {
      if (meta.Location() == QUDA_CUDA_FIELD_LOCATION) {
#ifdef JITIFY
        create_jitify_program("kernels/coarse_op_preconditioned.cuh");
#endif
        arg.max_d = static_cast<Float *>(pool_device_malloc(sizeof(Float)));
      }
      arg.max_h = static_cast<Float *>(pool_pinned_malloc(sizeof(Float)));
      strcpy(aux, compile_type_str(meta));
      strcat(aux, comm_dim_partitioned_string());
    }

    virtual ~CalculateYhat()
    {
      if (meta.Location() == QUDA_CUDA_FIELD_LOCATION) {
        pool_device_free(arg.max_d);
      }
      pool_pinned_free(arg.max_h);
    }

    void apply(const qudaStream_t &stream)
    {
      TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
      Launch<location, Arg>(arg, jitify_error, compute_max_only, tp, use_mma, stream);
    }

    /**
       Set if we're doing a max-only compute (fixed point only)
    */
    void setComputeMaxOnly(bool compute_max_only_) { compute_max_only = compute_max_only_; }

    bool advanceSharedBytes(TuneParam &param) const { return false; }

    bool advanceAux(TuneParam &param) const
    {
      if (use_mma) {
        constexpr bool compute_max_only_dummy = true;
        constexpr bool query_max = true;
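        // with query_max set, the launch helper reports the number of available
        // MMA configurations rather than launching a kernel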
        int max = mma::template launch_yhat_kernel<compute_max_only_dummy, query_max>(arg, 1, param, 0);
        if (param.aux.x < max) {
          param.aux.x++;
          return true;
        }
        return false;
      } else {
        return false;
      }
    }

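    // regular block/grid tuning applies only to device-resident GPU fields;
    // the MMA path is tuned exclusively through the aux dimension above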
    bool advanceTuneParam(TuneParam &param) const
    {
      if (!use_mma) {
        if (meta.Location() == QUDA_CUDA_FIELD_LOCATION && meta.MemType() == QUDA_MEMORY_DEVICE)
          return Tunable::advanceTuneParam(param);
        else
          return false;
      } else {
        return false;
      }
    }

    void initTuneParam(TuneParam &param) const
    {
      TunableVectorYZ::initTuneParam(param);
      param.aux = make_int4(0, 0, 0, 0);
    }

    void defaultTuneParam(TuneParam &param) const
    {
      TunableVectorYZ::defaultTuneParam(param);
      param.aux = make_int4(0, 0, 0, 0);
    }

    TuneKey tuneKey() const {
      char Aux[TuneKey::aux_n];
      strcpy(Aux, aux);
      if (compute_max_only) strcat(Aux, ",compute_max_only");
      if (meta.Location() == QUDA_CUDA_FIELD_LOCATION) {
        strcat(Aux, meta.MemType() == QUDA_MEMORY_MAPPED ? ",GPU-mapped" : ",GPU-device");
      } else if (meta.Location() == QUDA_CPU_FIELD_LOCATION) {
        strcat(Aux, ",CPU");
        strcat(Aux, getOmpThreadStr());
      }
      if (use_mma) { strcat(Aux, ",MMA"); }
      return TuneKey(meta.VolString(), typeid(*this).name(), Aux);
    }
  };

  /**
     @brief Calculate the preconditioned coarse-link field and the coarse clover inverse.
     @param Yhat[out] Preconditioned coarse link field
     @param Xinv[out] Coarse clover inverse field
     @param Y[in] Coarse link field
     @param X[in] Coarse clover field
     @param use_mma[in] Whether or not to use MMA (tensor cores) for the calculation
  */
  template <QudaFieldLocation location, typename storeFloat, typename Float, int N, QudaGaugeFieldOrder gOrder>
  void calculateYhat(GaugeField &Yhat, GaugeField &Xinv, const GaugeField &Y, const GaugeField &X, bool use_mma)
  {
    using namespace blas_lapack;
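    // select the batched matrix-inversion backend at runtime: the native (GPU)
    // implementation when enabled, otherwise the generic fallback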
    auto invert = use_native() ? native::BatchInvertMatrix : generic::BatchInvertMatrix;

    constexpr QudaGaugeFieldOrder gOrder_milc = QUDA_MILC_GAUGE_ORDER;
    GaugeField *Xinv_aos = nullptr;

    // invert the clover matrix field
    const int n = X.Ncolor();

    if (X.Location() == QUDA_CUDA_FIELD_LOCATION) {

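      // the batched inversion expects site-contiguous (MILC/AoS) ordering at single
      // precision or better, so copy into such a field when X or Xinv does not
      // already satisfy this; otherwise the field is used in place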
      auto create_gauge_copy = [](const GaugeField &X, bool copy_content) -> auto
      {
        GaugeField *output = nullptr;
        if (X.Order() == gOrder_milc && X.Precision() >= QUDA_SINGLE_PRECISION) {
          output = const_cast<GaugeField *>(&X);
        } else {
          GaugeFieldParam param(X);
          param.order = gOrder_milc;
          param.setPrecision(X.Precision() < QUDA_SINGLE_PRECISION ? QUDA_SINGLE_PRECISION : X.Precision());
          output = cudaGaugeField::Create(param);
          if (copy_content) output->copy(X);
        }
        return output;
      };

      GaugeField *X_aos = create_gauge_copy(X, true);
      Xinv_aos = create_gauge_copy(Xinv, false);

      blas::flops += invert((void *)Xinv_aos->Gauge_p(), (void *)X_aos->Gauge_p(), n, X_aos->Volume(),
                            X_aos->Precision(), X.Location());

      if (&Xinv != Xinv_aos) {
        if (Xinv.Precision() < QUDA_SINGLE_PRECISION) Xinv.Scale(Xinv_aos->abs_max());
        Xinv.copy(*Xinv_aos);
      }
      if (&X != X_aos) { delete X_aos; }

      // only delete here if an actual copy was made: Xinv_aos may alias Xinv itself
      if (!use_mma && Xinv_aos != &Xinv) { delete Xinv_aos; }

    } else if (X.Location() == QUDA_CPU_FIELD_LOCATION && X.Order() == QUDA_QDP_GAUGE_ORDER) {
      const cpuGaugeField *X_h = static_cast<const cpuGaugeField *>(&X);
      cpuGaugeField *Xinv_h = static_cast<cpuGaugeField *>(&Xinv);
      blas::flops += invert(*(void **)Xinv_h->Gauge_p(), *(void **)X_h->Gauge_p(), n, X_h->Volume(), X.Precision(), X.Location());
    } else {
      errorQuda("Unsupported location=%d and order=%d", X.Location(), X.Order());
    }

    // now exchange Y halos of both forwards and backwards links for multi-process dslash
    const_cast<GaugeField &>(Y).exchangeGhost(QUDA_LINK_BIDIRECTIONAL);

    // compute the preconditioned links
    //   Yhat_back(x-\mu) = Y_back(x-\mu) * Xinv^dagger(x)  (positive projector)
    //   Yhat_fwd(x)      = Xinv(x) * Y_fwd(x)              (negative projector)
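    // equivalently: each backwards link is multiplied on the right by Xinv^dagger at
    // the site it points to, and each forwards link on the left by Xinv at its own site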
    {
      int xc_size[5];
      for (int i = 0; i < 4; i++) xc_size[i] = X.X()[i];
      xc_size[4] = 1;

      if (use_mma) {

        auto create_gauge_copy = [](const GaugeField &X, QudaGaugeFieldOrder order, bool copy_content) -> auto
        {
          GaugeField *output = nullptr;
          if (X.Order() == order) {
            output = const_cast<GaugeField *>(&X);
          } else {
            GaugeFieldParam param(X);
            param.order = order;
            output = cudaGaugeField::Create(param);
            if (copy_content) output->copy(X);
          }
          return output;
        };

        GaugeField *Y_aos = create_gauge_copy(Y, gOrder_milc, true);
        GaugeField *Yhat_aos = create_gauge_copy(Yhat, gOrder_milc, false);

        constexpr bool use_native_ghosts = true;
        // use spin-ignorant accessor to make multiplication simpler
        typedef typename gauge::FieldOrder<Float, N, 1, gOrder_milc, use_native_ghosts, storeFloat> gCoarse;
        typedef typename gauge::FieldOrder<Float, N, 1, gOrder_milc, use_native_ghosts, storeFloat> gPreconditionedCoarse;
        gCoarse yAccessor(*Y_aos);
        gPreconditionedCoarse yHatAccessor(*Yhat_aos);

        // XXX: This doesn't work for double precision.
        using gCoarseInv = gauge::FieldOrder<float, N, 1, gOrder_milc, use_native_ghosts, float>;
        gCoarseInv xInvAccessor(*Xinv_aos);
        if (getVerbosity() >= QUDA_VERBOSE) printfQuda("Xinv = %e\n", Xinv_aos->norm2(0));

        int comm_dim[4];
        for (int i = 0; i < 4; i++) comm_dim[i] = comm_dim_partitioned(i);

        using yHatArg = CalculateYhatArg<Float, gPreconditionedCoarse, gCoarse, gCoarseInv, N, 4, 2>;
        yHatArg arg(yHatAccessor, yAccessor, xInvAccessor, xc_size, comm_dim, 1);

        CalculateYhat<location, yHatArg> yHat(arg, Y, use_mma);
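        // fixed-point (half/quarter) output needs its scale set before the real
        // computation: run a max-only pass, reduce the local maxima across ranks,
        // then rescale Yhat accordingly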
        if (Yhat.Precision() == QUDA_HALF_PRECISION || Yhat.Precision() == QUDA_QUARTER_PRECISION) {
          yHat.setComputeMaxOnly(true);
          yHat.apply(0);

          double max_h_double = *arg.max_h;
          comm_allreduce_max(&max_h_double);
          *arg.max_h = static_cast<Float>(max_h_double);

          if (getVerbosity() >= QUDA_VERBOSE) printfQuda("Yhat Max = %e\n", *arg.max_h);

          Yhat_aos->Scale(*arg.max_h);
          arg.Yhat.resetScale(*arg.max_h);
        }
        yHat.setComputeMaxOnly(false);
        yHat.apply(0);

        if (&Y != Y_aos) { delete Y_aos; }

        if (&Yhat != Yhat_aos) {
          Yhat.copy(*Yhat_aos);
          delete Yhat_aos;
        }

        if (Xinv_aos != &Xinv) { delete Xinv_aos; }

      } else {

        // use spin-ignorant accessor to make multiplication simpler
        typedef typename gauge::FieldOrder<Float, N, 1, gOrder, true, storeFloat> gCoarse;
        typedef typename gauge::FieldOrder<Float, N, 1, gOrder, true, storeFloat> gPreconditionedCoarse;
        gCoarse yAccessor(const_cast<GaugeField &>(Y));
        gPreconditionedCoarse yHatAccessor(const_cast<GaugeField &>(Yhat));
        gCoarse xInvAccessor(const_cast<GaugeField &>(Xinv));
        if (getVerbosity() >= QUDA_VERBOSE) printfQuda("Xinv = %e\n", Xinv.norm2(0));

        int comm_dim[4];
        for (int i = 0; i < 4; i++) comm_dim[i] = comm_dim_partitioned(i);
        typedef CalculateYhatArg<Float, gPreconditionedCoarse, gCoarse, gCoarse, N, 4, 2> yHatArg;
        yHatArg arg(yHatAccessor, yAccessor, xInvAccessor, xc_size, comm_dim, 1);

        CalculateYhat<location, yHatArg> yHat(arg, Y, use_mma);
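        // as above: a max-only pass plus global reduction sets the fixed-point scale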
        if (Yhat.Precision() == QUDA_HALF_PRECISION || Yhat.Precision() == QUDA_QUARTER_PRECISION) {
          yHat.setComputeMaxOnly(true);
          yHat.apply(0);

          double max_h_double = *arg.max_h;
          comm_allreduce_max(&max_h_double);
          *arg.max_h = static_cast<Float>(max_h_double);

          if (getVerbosity() >= QUDA_VERBOSE) printfQuda("Yhat Max = %e\n", *arg.max_h);

          Yhat.Scale(*arg.max_h);
          arg.Yhat.resetScale(*arg.max_h);
        }
        yHat.setComputeMaxOnly(false);
        yHat.apply(0);
      }

      if (getVerbosity() >= QUDA_VERBOSE) {
        if (use_mma && X.Location() == QUDA_CUDA_FIELD_LOCATION)
          warningQuda("There is a known issue with Yhat norms 0 through 3 for CUDA+MMA builds. These are harmless and will be addressed in the future.\n");
        for (int d = 0; d < 8; d++)
          printfQuda("Yhat[%d] = %e (%e %e = %e x %e)\n", d, Yhat.norm2(d), Yhat.abs_max(d),
                     Y.abs_max(d) * Xinv.abs_max(0), Y.abs_max(d), Xinv.abs_max(0));
      }
    }

    // fill back in the bulk of Yhat so that the backward link is updated on the previous node;
    // this must go into the bulk of the previous node, but only the backwards links are sent
    // backwards so as not to overwrite the forwards bulk
    Yhat.injectGhost(QUDA_LINK_BACKWARDS);

    // exchange forwards links for multi-process dslash dagger;
    // this must go into the ghost zone of the next node, but only the forwards links are sent
    // forwards so as not to overwrite the backwards ghost
    Yhat.exchangeGhost(QUDA_LINK_FORWARDS);
  }

  template <typename storeFloat, typename Float, int N>
  void calculateYhat(GaugeField &Yhat, GaugeField &Xinv, const GaugeField &Y, const GaugeField &X, bool use_mma)
  {
    if (Y.Location() == QUDA_CPU_FIELD_LOCATION) {
      constexpr QudaGaugeFieldOrder gOrder = QUDA_QDP_GAUGE_ORDER;
      if (Y.FieldOrder() != gOrder) errorQuda("Unsupported field order %d\n", Y.FieldOrder());
      calculateYhat<QUDA_CPU_FIELD_LOCATION, storeFloat, Float, N, gOrder>(Yhat, Xinv, Y, X, use_mma);
    } else {
      constexpr QudaGaugeFieldOrder gOrder = QUDA_FLOAT2_GAUGE_ORDER;
      // if (Y.FieldOrder() != gOrder) errorQuda("Unsupported field order %d\n", Y.FieldOrder());
      calculateYhat<QUDA_CUDA_FIELD_LOCATION, storeFloat, Float, N, gOrder>(Yhat, Xinv, Y, X, use_mma);
    }
  }

  // template on the number of coarse degrees of freedom
  template <typename storeFloat, typename Float>
  void calculateYhat(GaugeField &Yhat, GaugeField &Xinv, const GaugeField &Y, const GaugeField &X, bool use_mma)
  {
    switch (Y.Ncolor()) {
    case 48: calculateYhat<storeFloat, Float, 48>(Yhat, Xinv, Y, X, use_mma); break;
#ifdef NSPIN4
    case 12: calculateYhat<storeFloat, Float, 12>(Yhat, Xinv, Y, X, use_mma); break;
    case 64: calculateYhat<storeFloat, Float, 64>(Yhat, Xinv, Y, X, use_mma); break;
#endif // NSPIN4
#ifdef NSPIN1
    case 128: calculateYhat<storeFloat, Float, 128>(Yhat, Xinv, Y, X, use_mma); break;
    case 192: calculateYhat<storeFloat, Float, 192>(Yhat, Xinv, Y, X, use_mma); break;
#endif // NSPIN1
    default: errorQuda("Unsupported number of coarse dof %d\n", Y.Ncolor()); break;
    }
  }

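  // A minimal usage sketch (hypothetical field setup elided): given coarse fields Y
  // and X built by the coarse-operator construction, and Yhat/Xinv allocated to
  // match, one would call
  //
  //   calculateYhat(Yhat, Xinv, Y, X, /*use_mma=*/true);
  //
  // after which Yhat holds the preconditioned links with halos exchanged.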
  // Does the heavy lifting of creating the preconditioned coarse links Yhat
  // and the coarse clover inverse Xinv
  void calculateYhat(GaugeField &Yhat, GaugeField &Xinv, const GaugeField &Y, const GaugeField &X, bool use_mma)
  {
#ifdef GPU_MULTIGRID
    QudaPrecision precision = checkPrecision(Xinv, Y, X);
    if (getVerbosity() >= QUDA_SUMMARIZE) printfQuda("Computing Yhat field......\n");

    if (precision == QUDA_DOUBLE_PRECISION) {
#ifdef GPU_MULTIGRID_DOUBLE
      if (Yhat.Precision() != QUDA_DOUBLE_PRECISION) errorQuda("Unsupported precision %d\n", Yhat.Precision());
      if (use_mma) errorQuda("MG-MMA does not support double precision, yet.");
      calculateYhat<double, double>(Yhat, Xinv, Y, X, use_mma);
#else
      errorQuda("Double precision multigrid has not been enabled");
#endif
    } else if (precision == QUDA_SINGLE_PRECISION) {
      if (Yhat.Precision() == QUDA_SINGLE_PRECISION) {
        calculateYhat<float, float>(Yhat, Xinv, Y, X, use_mma);
      } else {
        errorQuda("Unsupported precision %d\n", precision);
      }
    } else if (precision == QUDA_HALF_PRECISION) {
      if (Yhat.Precision() == QUDA_HALF_PRECISION) {
        calculateYhat<short, float>(Yhat, Xinv, Y, X, use_mma);
      } else {
        errorQuda("Unsupported precision %d\n", precision);
      }
    } else {
      errorQuda("Unsupported precision %d\n", precision);
    }

    if (getVerbosity() >= QUDA_SUMMARIZE) printfQuda("....done computing Yhat field\n");
#else
    errorQuda("Multigrid has not been built");
#endif
  }

} // namespace quda