QUDA  v1.1.0
A library for QCD on GPUs
reduce_quda.cu
Go to the documentation of this file.
1 #include <blas_quda.h>
2 #include <tune_quda.h>
3 #include <color_spinor_field_order.h>
4 #include <jitify_helper.cuh>
5 #include <kernels/reduce_core.cuh>
6 
7 namespace quda {
8 
9  namespace blas {
10 
11  qudaStream_t* getStream();
12 
13  template <int block_size, typename real, int len, typename Arg>
14  typename std::enable_if<block_size!=32, qudaError_t>::type launch(Arg &arg, const TuneParam &tp, const qudaStream_t &stream)
15  {
16  if (tp.block.x == block_size)
17  return qudaLaunchKernel(reduceKernel<block_size, real, len, Arg>, tp, stream, arg);
18  else
19  return launch<block_size - 32, real, len>(arg, tp, stream);
20  }
21 
22  template <int block_size, typename real, int len, typename Arg>
23  typename std::enable_if<block_size==32, qudaError_t>::type launch(Arg &arg, const TuneParam &tp, const qudaStream_t &stream)
24  {
25  if (block_size != tp.block.x) errorQuda("Unexpected block size %d\n", tp.block.x);
26  return qudaLaunchKernel(reduceKernel<block_size, real, len, Arg>, tp, stream, arg);
27  }
28 
/**
   Largest block size for which a reduction kernel is instantiated: 32 when
   QUDA_FAST_COMPILE_REDUCE is defined, 1024 otherwise.
*/
constexpr static unsigned int max_block_size()
{
#ifdef QUDA_FAST_COMPILE_REDUCE
  return 32;
#else
  return 1024;
#endif
}
34 
35  /**
36  Generic reduction kernel launcher
37  */
38  template <typename host_reduce_t, typename real, int len, typename Arg>
39  auto reduceLaunch(Arg &arg, const TuneParam &tp, const qudaStream_t &stream, Tunable &tunable)
40  {
41  using device_reduce_t = typename Arg::Reducer::reduce_t;
42  if (tp.grid.x > (unsigned int)deviceProp.maxGridSize[0])
43  errorQuda("Grid size %d greater than maximum %d\n", tp.grid.x, deviceProp.maxGridSize[0]);
44 
45 #ifdef JITIFY
46  using namespace jitify::reflection;
47  tunable.jitifyError() = program->kernel("quda::blas::reduceKernel")
48  .instantiate((int)tp.block.x, Type<real>(), len, Type<Arg>())
49  .configure(tp.grid, tp.block, tp.shared_bytes, stream)
50  .launch(arg);
51  arg.launch_error = tunable.jitifyError() == CUDA_SUCCESS ? QUDA_SUCCESS : QUDA_ERROR;
52 #else
53  arg.launch_error = launch<max_block_size(), real, len>(arg, tp, stream);
54 #endif
55 
56  host_reduce_t result;
57  ::quda::zero(result);
58  if (!commAsyncReduction()) arg.complete(result, stream);
59  return result;
60  }
61 
62  template <template <typename ReducerType, typename real> class Reducer,
63  typename store_t, typename y_store_t, int nSpin, typename coeff_t>
64  class Reduce : public Tunable
65  {
66  using real = typename mapper<y_store_t>::type;
67  using host_reduce_t = typename Reducer<double, real>::reduce_t;
68  Reducer<device_reduce_t, real> r;
69  const int nParity; // for composite fields this includes the number of composites
70  host_reduce_t &result;
71 
72  const coeff_t &a, &b;
73  ColorSpinorField &x, &y, &z, &w, &v;
74  QudaFieldLocation location;
75 
76  unsigned int sharedBytesPerThread() const { return 0; }
77  unsigned int sharedBytesPerBlock(const TuneParam &param) const { return 0; }
78 
79  bool advanceSharedBytes(TuneParam &param) const
80  {
81  TuneParam next(param);
82  advanceBlockDim(next); // to get next blockDim
83  int nthreads = next.block.x * next.block.y * next.block.z;
84  param.shared_bytes = sharedBytesPerThread() * nthreads > sharedBytesPerBlock(param) ?
85  sharedBytesPerThread() * nthreads :
86  sharedBytesPerBlock(param);
87  return false;
88  }
89 
90  unsigned int maxBlockSize(const TuneParam &param) const { return max_block_size(); }
91 
92  public:
93  Reduce(const coeff_t &a, const coeff_t &b, const coeff_t &c, ColorSpinorField &x, ColorSpinorField &y,
94  ColorSpinorField &z, ColorSpinorField &w, ColorSpinorField &v, host_reduce_t &result) :
95  r(a, b),
96  nParity((x.IsComposite() ? x.CompositeDim() : 1) * (x.SiteSubset())),
97  a(a),
98  b(b),
99  x(x),
100  y(y),
101  z(z),
102  w(w),
103  v(v),
104  result(result),
105  location(checkLocation(x, y, z, w, v))
106  {
107  checkLength(x, y, z, w, v);
108  auto x_prec = checkPrecision(x, z, w, v);
109  auto y_prec = y.Precision();
110  auto x_order = checkOrder(x, z, w, v);
111  auto y_order = y.FieldOrder();
112  if (sizeof(store_t) != x_prec) errorQuda("Expected precision %lu but received %d", sizeof(store_t), x_prec);
113  if (sizeof(y_store_t) != y_prec) errorQuda("Expected precision %lu but received %d", sizeof(y_store_t), y_prec);
114  if (x_prec == y_prec && x_order != y_order) errorQuda("Orders %d %d do not match", x_order, y_order);
115 
116  strcpy(aux, x.AuxString());
117  if (x_prec != y_prec) {
118  strcat(aux, ",");
119  strcat(aux, y.AuxString());
120  }
121  strcat(aux, nParity == 2 ? ",nParity=2" : ",nParity=1");
122  if (location == QUDA_CPU_FIELD_LOCATION) strcat(aux, ",CPU");
123  if (commAsyncReduction()) strcat(aux, ",async");
124 
125 #ifdef JITIFY
126  ::quda::create_jitify_program("kernels/reduce_core.cuh");
127 #endif
128 
129  apply(*(blas::getStream()));
130 
131  blas::bytes += bytes();
132  blas::flops += flops();
133 
134  const int Nreduce = sizeof(host_reduce_t) / sizeof(double);
135  reduceDoubleArray((double *)&result, Nreduce);
136  }
137 
138  TuneKey tuneKey() const { return TuneKey(x.VolString(), typeid(r).name(), aux); }
139 
140  void apply(const qudaStream_t &stream)
141  {
142  constexpr bool site_unroll_check = !std::is_same<store_t, y_store_t>::value || isFixed<store_t>::value || decltype(r)::site_unroll;
143  if (site_unroll_check && (x.Ncolor() != 3 || x.Nspin() == 2))
144  errorQuda("site unroll not supported for nSpin = %d nColor = %d", x.Nspin(), x.Ncolor());
145 
146  TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
147  if (location == QUDA_CUDA_FIELD_LOCATION) {
148  if (site_unroll_check) checkNative(x, y, z, w, v); // require native order when using site_unroll
149  using device_store_t = typename device_type_mapper<store_t>::type;
150  using device_y_store_t = typename device_type_mapper<y_store_t>::type;
151  using device_real_t = typename mapper<device_y_store_t>::type;
152  Reducer<device_reduce_t, device_real_t> r_(a, b);
153 
154  // redefine site_unroll with device_store types to ensure we have correct N/Ny/M values
155  constexpr bool site_unroll = !std::is_same<device_store_t, device_y_store_t>::value || isFixed<device_store_t>::value || decltype(r)::site_unroll;
156  constexpr int N = n_vector<device_store_t, true, nSpin, site_unroll>();
157  constexpr int Ny = n_vector<device_y_store_t, true, nSpin, site_unroll>();
158  constexpr int M = site_unroll ? (nSpin == 4 ? 24 : 6) : N; // real numbers per thread
159  const int length = x.Length() / (nParity * M);
160 
161  ReductionArg<device_store_t, N, device_y_store_t, Ny, decltype(r_)> arg(x, y, z, w, v, r_, length, nParity, tp);
162  result = reduceLaunch<host_reduce_t, device_real_t, M>(arg, tp, stream, *this);
163  } else {
164  if (checkOrder(x, y, z, w, v) != QUDA_SPACE_SPIN_COLOR_FIELD_ORDER) {
165  warningQuda("CPU Blas functions expect AoS field order");
166  return;
167  }
168 
169  using host_store_t = typename host_type_mapper<store_t>::type;
170  using host_y_store_t = typename host_type_mapper<y_store_t>::type;
171  using host_real_t = typename mapper<host_y_store_t>::type;
172  Reducer<double, host_real_t> r_(a, b);
173 
174  // redefine site_unroll with host_store types to ensure we have correct N/Ny/M values
175  constexpr bool site_unroll = !std::is_same<host_store_t, host_y_store_t>::value || isFixed<host_store_t>::value || decltype(r)::site_unroll;
176  constexpr int N = n_vector<host_store_t, false, nSpin, site_unroll>();
177  constexpr int Ny = n_vector<host_y_store_t, false, nSpin, site_unroll>();
178  constexpr int M = N; // if site unrolling then M=N will be 24/6, e.g., full AoS
179  const int length = x.Length() / (nParity * M);
180 
181  ReductionArg<host_store_t, N, host_y_store_t, Ny, decltype(r_)> arg(x, y, z, w, v, r_, length, nParity, tp);
182  result = reduceCPU<host_real_t, M>(arg);
183  }
184  }
185 
186  void preTune()
187  {
188  if (r.write.X) x.backup();
189  if (r.write.Y) y.backup();
190  if (r.write.Z) z.backup();
191  if (r.write.W) w.backup();
192  if (r.write.V) v.backup();
193  }
194 
195  void postTune()
196  {
197  if (r.write.X) x.restore();
198  if (r.write.Y) y.restore();
199  if (r.write.Z) z.restore();
200  if (r.write.W) w.restore();
201  if (r.write.V) v.restore();
202  }
203 
204  bool advanceTuneParam(TuneParam &param) const
205  {
206  return location == QUDA_CPU_FIELD_LOCATION ? false : Tunable::advanceTuneParam(param);
207  }
208 
209  void initTuneParam(TuneParam &param) const
210  {
211  Tunable::initTuneParam(param);
212  }
213 
214  void defaultTuneParam(TuneParam &param) const
215  {
216  Tunable::defaultTuneParam(param);
217  }
218 
219  long long flops() const { return r.flops() * x.Length(); }
220 
221  long long bytes() const
222  {
223  return (r.read.X + r.write.X) * x.Bytes() + (r.read.Y + r.write.Y) * y.Bytes() +
224  (r.read.Z + r.write.Z) * z.Bytes() + (r.read.W + r.write.W) * w.Bytes() + (r.read.V + r.write.V) * v.Bytes();
225  }
226 
227  int tuningIter() const { return 3; }
228  };
229 
230  template <template <typename reduce_t, typename real> class Functor, bool mixed, typename... Args>
231  auto instantiateReduce(Args &&... args)
232  {
233  using host_reduce_t = typename Functor<double, double>::reduce_t;
234  host_reduce_t value;
235  ::quda::zero(value); // no default constructor so we need to explicitly zero
236  instantiate<Functor, Reduce, mixed>(args..., value);
237  return value;
238  }
239 
240  double norm1(const ColorSpinorField &x)
241  {
242  ColorSpinorField &y = const_cast<ColorSpinorField &>(x); // FIXME
243  return instantiateReduce<Norm1, false>(0.0, 0.0, 0.0, y, y, y, y, y);
244  }
245 
246  double norm2(const ColorSpinorField &x)
247  {
248  ColorSpinorField &y = const_cast<ColorSpinorField &>(x);
249  return instantiateReduce<Norm2, false>(0.0, 0.0, 0.0, y, y, y, y, y);
250  }
251 
252  double reDotProduct(ColorSpinorField &x, ColorSpinorField &y)
253  {
254  return instantiateReduce<Dot, false>(0.0, 0.0, 0.0, x, y, x, x, x);
255  }
256 
257  double axpbyzNorm(double a, ColorSpinorField &x, double b, ColorSpinorField &y, ColorSpinorField &z)
258  {
259  return instantiateReduce<axpbyzNorm2, false>(a, b, 0.0, x, y, z, x, x);
260  }
261 
262  double axpyReDot(double a, ColorSpinorField &x, ColorSpinorField &y)
263  {
264  return instantiateReduce<AxpyReDot, false>(a, 0.0, 0.0, x, y, x, x, x);
265  }
266 
267  double caxpyNorm(const Complex &a, ColorSpinorField &x, ColorSpinorField &y)
268  {
269  return instantiateReduce<caxpyNorm2, false>(a, Complex(0.0), Complex(0.0), x, y, x, x, x);
270  }
271 
272  double caxpyXmazNormX(const Complex &a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
273  {
274  return instantiateReduce<caxpyxmaznormx, false>(a, Complex(0.0), Complex(0.0), x, y, z, x, x);
275  }
276 
277  double cabxpyzAxNorm(double a, const Complex &b, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
278  {
279  return instantiateReduce<cabxpyzaxnorm, false>(Complex(a), b, Complex(0.0), x, y, z, x, x);
280  }
281 
282  Complex cDotProduct(ColorSpinorField &x, ColorSpinorField &y)
283  {
284  auto cdot = instantiateReduce<Cdot, false>(0.0, 0.0, 0.0, x, y, x, x, x);
285  return Complex(cdot.x, cdot.y);
286  }
287 
288  Complex caxpyDotzy(const Complex &a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
289  {
290  auto cdot = instantiateReduce<caxpydotzy, false>(a, Complex(0.0), Complex(0.0), x, y, z, x, x);
291  return Complex(cdot.x, cdot.y);
292  }
293 
294  double3 cDotProductNormA(ColorSpinorField &x, ColorSpinorField &y)
295  {
296  return instantiateReduce<CdotNormA, false>(0.0, 0.0, 0.0, x, y, x, x, x);
297  }
298 
299  double3 caxpbypzYmbwcDotProductUYNormY(const Complex &a, ColorSpinorField &x, const Complex &b, ColorSpinorField &y,
300  ColorSpinorField &z, ColorSpinorField &w, ColorSpinorField &u)
301  {
302  return instantiateReduce<caxpbypzYmbwcDotProductUYNormY_, true>(a, b, Complex(0.0), x, z, y, w, u);
303  }
304 
305  Complex axpyCGNorm(double a, ColorSpinorField &x, ColorSpinorField &y)
306  {
307  double2 cg_norm = instantiateReduce<axpyCGNorm2, true>(a, 0.0, 0.0, x, y, x, x, x);
308  return Complex(cg_norm.x, cg_norm.y);
309  }
310 
311  double3 HeavyQuarkResidualNorm(ColorSpinorField &x, ColorSpinorField &r)
312  {
313  // in case of x.Ncolor()!=3 (MG mainly) reduce_core do not support this function.
314  if (x.Ncolor() != 3) return make_double3(0.0, 0.0, 0.0);
315  double3 rtn = instantiateReduce<HeavyQuarkResidualNorm_, false>(0.0, 0.0, 0.0, x, r, r, r, r);
316  rtn.z /= (x.Volume()*comm_size());
317  return rtn;
318  }
319 
320  double3 xpyHeavyQuarkResidualNorm(ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &r)
321  {
322  // in case of x.Ncolor()!=3 (MG mainly) reduce_core do not support this function.
323  if (x.Ncolor()!=3) return make_double3(0.0, 0.0, 0.0);
324  double3 rtn = instantiateReduce<xpyHeavyQuarkResidualNorm_, false>(0.0, 0.0, 0.0, x, y, r, r, r);
325  rtn.z /= (x.Volume()*comm_size());
326  return rtn;
327  }
328 
329  double3 tripleCGReduction(ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
330  {
331  return instantiateReduce<tripleCGReduction_, false>(0.0, 0.0, 0.0, x, y, z, x, x);
332  }
333 
334  double4 quadrupleCGReduction(ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
335  {
336  return instantiateReduce<quadrupleCGReduction_, false>(0.0, 0.0, 0.0, x, y, z, x, x);
337  }
338 
339  double quadrupleCG3InitNorm(double a, ColorSpinorField &x, ColorSpinorField &y,
340  ColorSpinorField &z, ColorSpinorField &w, ColorSpinorField &v)
341  {
342  return instantiateReduce<quadrupleCG3InitNorm_, false>(a, 0.0, 0.0, x, y, z, w, v);
343  }
344 
345  double quadrupleCG3UpdateNorm(double a, double b, ColorSpinorField &x, ColorSpinorField &y,
346  ColorSpinorField &z, ColorSpinorField &w, ColorSpinorField &v)
347  {
348  return instantiateReduce<quadrupleCG3UpdateNorm_, false>(a, b, 0.0, x, y, z, w, v);
349  }
350 
351  } // namespace blas
352 
353 } // namespace quda