QUDA v1.1.0
A library for QCD on GPUs
multi_blas_quda.cu
#include <stdlib.h>
#include <stdio.h>
#include <cstring> // needed for strcpy/strcat below

#include <tune_quda.h>
#include <blas_quda.h>
#include <color_spinor_field.h>

#include <jitify_helper.cuh>
#include <kernels/multi_blas_core.cuh>

namespace quda {

  namespace blas {

    qudaStream_t* getStream();

    template <template <typename ...> class Functor, typename store_t, typename y_store_t, int nSpin, typename T>
    class MultiBlas : public TunableVectorY
    {
      using real = typename mapper<y_store_t>::type;
      const int NXZ;
      const int NYW;
      Functor<real> f;
      int max_warp_split;
      mutable int warp_split; // helper used to keep track of current warp splitting
      const int nParity;
      const T &a, &b, &c;
      std::vector<ColorSpinorField *> &x, &y, &z, &w;
      const QudaFieldLocation location;

      bool tuneSharedBytes() const { return false; }

      // for these streaming kernels, there is no need to tune the grid size, just use max
      unsigned int minGridSize() const { return maxGridSize(); }

    public:
      MultiBlas(const T &a, const T &b, const T &c, const ColorSpinorField &x_meta, const ColorSpinorField &y_meta,
                std::vector<ColorSpinorField *> &x, std::vector<ColorSpinorField *> &y,
                std::vector<ColorSpinorField *> &z, std::vector<ColorSpinorField *> &w) :
        TunableVectorY(y.size()),
        NXZ(x.size()),
        NYW(y.size()),
        f(NXZ, NYW),
        warp_split(1),
        nParity(x[0]->SiteSubset()),
        a(a),
        b(b),
        c(c),
        x(x),
        y(y),
        z(z),
        w(w),
        location(checkLocation(*x[0], *y[0], *z[0], *w[0]))
      {
        checkLength(*x[0], *y[0], *z[0], *w[0]);
        auto x_prec = checkPrecision(*x[0], *z[0], *w[0]);
        auto y_prec = y[0]->Precision();
        auto x_order = checkOrder(*x[0], *z[0], *w[0]);
        auto y_order = y[0]->FieldOrder();
        if (sizeof(store_t) != x_prec) errorQuda("Expected precision %lu but received %d", sizeof(store_t), x_prec);
        if (sizeof(y_store_t) != y_prec) errorQuda("Expected precision %lu but received %d", sizeof(y_store_t), y_prec);
        if (x_prec == y_prec && x_order != y_order) errorQuda("Orders %d %d do not match", x_order, y_order);

        // heuristic for deciding whether we need the warp-splitting optimization
        const int gpu_size = 2 * deviceProp.maxThreadsPerBlock * deviceProp.multiProcessorCount;
        switch (gpu_size / (x[0]->Length() * NYW)) {
        case 0: max_warp_split = 1; break; // we have plenty of work, no need to split
        case 1: max_warp_split = 2; break; // double the thread count
        case 2: // quadruple the thread count
        default: max_warp_split = 4;
        }
        max_warp_split = std::min(NXZ, max_warp_split); // ensure we only split if valid
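        // Worked example of the heuristic above (illustrative numbers, not from the
        // source): with maxThreadsPerBlock = 1024 and 80 multiprocessors, gpu_size =
        // 2 * 1024 * 80 = 163840. For x[0]->Length() * NYW = 98304 the switch sees
        // 163840 / 98304 = 1 (integer division), so the thread count is doubled
        // (max_warp_split = 2), provided NXZ >= 2.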

        Amatrix_h = reinterpret_cast<signed char *>(const_cast<typename T::type *>(a.data));
        Bmatrix_h = reinterpret_cast<signed char *>(const_cast<typename T::type *>(b.data));
        Cmatrix_h = reinterpret_cast<signed char *>(const_cast<typename T::type *>(c.data));

        strcpy(aux, x[0]->AuxString());
        if (x_prec != y_prec) {
          strcat(aux, ",");
          strcat(aux, y[0]->AuxString());
        }

#ifdef JITIFY
        ::quda::create_jitify_program("kernels/multi_blas_core.cuh");
#endif

        apply(*getStream());

        blas::bytes += bytes();
        blas::flops += flops();
      }

      TuneKey tuneKey() const
      {
        char name[TuneKey::name_n];
        char NXZ_str[8];
        char NYW_str[8];
        u32toa(NXZ_str, NXZ);
        u32toa(NYW_str, NYW);
        strcpy(name, "Nxz");
        strcat(name, NXZ_str);
        strcat(name, "Nyw");
        strcat(name, NYW_str);
        strcat(name, typeid(f).name());
        return TuneKey(x[0]->VolString(), name, aux);
      }
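
      // Example of the key assembled above (illustrative): for NXZ = 8 and NYW = 4
      // the name becomes "Nxz8Nyw4" followed by the implementation-defined
      // typeid(f).name() string of the functor, so distinct problem shapes and
      // functors tune independently.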

      template <bool multi_1d, typename device_buffer_t, typename Arg> typename std::enable_if<multi_1d, void>::type
      set_param(device_buffer_t &&buf_d, Arg &arg, char select, const T &h, const qudaStream_t &stream)
      {
        using coeff_t = typename decltype(arg.f)::coeff_t;
        coeff_t *buf_arg = nullptr;
        switch (select) {
        case 'a': buf_arg = arg.f.a; break;
        case 'b': buf_arg = arg.f.b; break;
        case 'c': buf_arg = arg.f.c; break;
        default: errorQuda("Unknown buffer %c", select);
        }
        const auto N = std::max(NXZ, NYW);
        for (int i = 0; i < N; i++) buf_arg[i] = coeff_t(h.data[i]);
      }

      template <bool multi_1d, typename device_buffer_t, typename Arg> typename std::enable_if<!multi_1d, void>::type
      set_param(device_buffer_t &&buf_d, Arg &arg, char dummy, const T &h, const qudaStream_t &stream)
      {
        using coeff_t = typename decltype(arg.f)::coeff_t;
        constexpr size_t n_coeff = MAX_MATRIX_SIZE / sizeof(coeff_t);

        coeff_t tmp[n_coeff];
        for (int i = 0; i < NXZ; i++)
          for (int j = 0; j < NYW; j++) tmp[NYW * i + j] = coeff_t(h.data[NYW * i + j]);

#ifdef JITIFY
        cuMemcpyHtoDAsync(buf_d, tmp, NXZ * NYW * sizeof(coeff_t), stream);
#else
        cudaMemcpyToSymbolAsync(buf_d, tmp, NXZ * NYW * sizeof(coeff_t), 0, cudaMemcpyHostToDevice, stream);
#endif
      }
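
      // The 2-d overload above stages the full NXZ x NYW coefficient matrix in
      // row-major order: element (i, j) lives at tmp[NYW * i + j]. For example,
      // with NXZ = 2 and NYW = 3 the staged buffer holds
      // { a(0,0), a(0,1), a(0,2), a(1,0), a(1,1), a(1,2) }.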

      template <int NXZ> void compute(const qudaStream_t &stream)
      {
        staticCheck<NXZ, store_t, y_store_t, decltype(f)>(f, x, y);

        constexpr bool site_unroll_check = !std::is_same<store_t, y_store_t>::value || isFixed<store_t>::value;
        if (site_unroll_check && (x[0]->Ncolor() != 3 || x[0]->Nspin() == 2))
          errorQuda("site unroll not supported for nSpin = %d nColor = %d", x[0]->Nspin(), x[0]->Ncolor());

        TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());

        if (location == QUDA_CUDA_FIELD_LOCATION) {
          if (site_unroll_check) checkNative(*x[0], *y[0], *z[0], *w[0]); // require native order when using site_unroll
          using device_store_t = typename device_type_mapper<store_t>::type;
          using device_y_store_t = typename device_type_mapper<y_store_t>::type;
          using device_real_t = typename mapper<device_y_store_t>::type;
          Functor<device_real_t> f_(NXZ, NYW);

          // redefine site_unroll with device_store types to ensure we have correct N/Ny/M values
          constexpr bool site_unroll = !std::is_same<device_store_t, device_y_store_t>::value || isFixed<device_store_t>::value;
          constexpr int N = n_vector<device_store_t, true, nSpin, site_unroll>();
          constexpr int Ny = n_vector<device_y_store_t, true, nSpin, site_unroll>();
          constexpr int M = site_unroll ? (nSpin == 4 ? 24 : 6) : N; // real numbers per thread
          const int length = x[0]->Length() / (nParity * M);

          tp.block.x *= tp.aux.x; // include warp-split factor

          MultiBlasArg<NXZ, device_store_t, N, device_y_store_t, Ny, decltype(f_)> arg(x, y, z, w, f_, NYW, length);
#ifdef JITIFY
          using namespace jitify::reflection;
          auto instance = program->kernel("quda::blas::multiBlasKernel")
                            .instantiate(Type<device_real_t>(), M, NXZ, tp.aux.x, Type<decltype(arg)>());

          if (a.data) set_param<decltype(f_)::multi_1d>(instance.get_constant_ptr("quda::blas::Amatrix_d"), arg, 'a', a, stream);
          if (b.data) set_param<decltype(f_)::multi_1d>(instance.get_constant_ptr("quda::blas::Bmatrix_d"), arg, 'b', b, stream);
          if (c.data) set_param<decltype(f_)::multi_1d>(instance.get_constant_ptr("quda::blas::Cmatrix_d"), arg, 'c', c, stream);

          jitify_error = instance.configure(tp.grid, tp.block, tp.shared_bytes, stream).launch(arg);
#else
          if (a.data) { set_param<decltype(f_)::multi_1d>(Amatrix_d, arg, 'a', a, stream); }
          if (b.data) { set_param<decltype(f_)::multi_1d>(Bmatrix_d, arg, 'b', b, stream); }
          if (c.data) { set_param<decltype(f_)::multi_1d>(Cmatrix_d, arg, 'c', c, stream); }
          switch (tp.aux.x) {
          case 1: qudaLaunchKernel(multiBlasKernel<device_real_t, M, NXZ, 1, decltype(arg)>, tp, stream, arg); break;
#ifdef WARP_SPLIT
          case 2: qudaLaunchKernel(multiBlasKernel<device_real_t, M, NXZ, 2, decltype(arg)>, tp, stream, arg); break;
          case 4: qudaLaunchKernel(multiBlasKernel<device_real_t, M, NXZ, 4, decltype(arg)>, tp, stream, arg); break;
#endif
          default: errorQuda("warp-split factor %d not instantiated", tp.aux.x);
          }
#endif

          tp.block.x /= tp.aux.x; // restore block size
        } else {
          errorQuda("Only implemented for GPU fields");
        }
      }

      template <int n> typename std::enable_if<n != 1, void>::type instantiateLinear(const qudaStream_t &stream)
      {
        if (NXZ == n) compute<n>(stream);
        else instantiateLinear<n - 1>(stream);
      }

      template <int n> typename std::enable_if<n == 1, void>::type instantiateLinear(const qudaStream_t &stream)
      {
        compute<1>(stream);
      }

      template <int n> typename std::enable_if<n != 1, void>::type instantiatePow2(const qudaStream_t &stream)
      {
        if (NXZ == n) compute<n>(stream);
        else instantiatePow2<n / 2>(stream);
      }

      template <int n> typename std::enable_if<n == 1, void>::type instantiatePow2(const qudaStream_t &stream)
      {
        compute<1>(stream);
      }

      // instantiate the loop unrolling template
      template <int NXZ_max> typename std::enable_if<NXZ_max != 1, void>::type instantiate(const qudaStream_t &stream)
      {
        // if multi-1d then constrain the templates to no larger than max-1d size
        constexpr int pow2_max = !decltype(f)::multi_1d ? max_NXZ_power2<false, isFixed<store_t>::value>() :
          std::min(max_N_multi_1d(), max_NXZ_power2<false, isFixed<store_t>::value>());
        constexpr int linear_max = !decltype(f)::multi_1d ? MAX_MULTI_BLAS_N : std::min(max_N_multi_1d(), MAX_MULTI_BLAS_N);

        if (NXZ <= pow2_max && is_power2(NXZ)) instantiatePow2<pow2_max>(stream);
        else if (NXZ <= linear_max) instantiateLinear<linear_max>(stream);
        else errorQuda("x.size %lu greater than maximum supported size (pow2 = %d, linear = %d)", x.size(), pow2_max, linear_max);
      }

      template <int NXZ_max> typename std::enable_if<NXZ_max == 1, void>::type instantiate(const qudaStream_t &stream)
      {
        compute<1>(stream);
      }
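
      // How the compile-time dispatch above resolves (worked example): with
      // pow2_max = 16 and a runtime NXZ of 8, instantiatePow2<16> sees 16 != 8 and
      // recurses to instantiatePow2<8>, which matches and calls compute<8>. A
      // non-power-of-two NXZ of 5 instead walks instantiateLinear<linear_max>,
      // <linear_max - 1>, ... down to instantiateLinear<5>, which calls compute<5>.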

      void apply(const qudaStream_t &stream) { instantiate<decltype(f)::NXZ_max>(stream); }

      void preTune()
      {
        for (int i = 0; i < NYW; ++i) {
          if (f.write.X) x[i]->backup();
          if (f.write.Y) y[i]->backup();
          if (f.write.Z) z[i]->backup();
          if (f.write.W) w[i]->backup();
        }
      }

      void postTune()
      {
        for (int i = 0; i < NYW; ++i) {
          if (f.write.X) x[i]->restore();
          if (f.write.Y) y[i]->restore();
          if (f.write.Z) z[i]->restore();
          if (f.write.W) w[i]->restore();
        }
      }

      bool advanceAux(TuneParam &param) const
      {
#ifdef WARP_SPLIT
        if (2 * param.aux.x <= max_warp_split) {
          param.aux.x *= 2;
          warp_split = param.aux.x;
          return true;
        } else {
          param.aux.x = 1;
          warp_split = param.aux.x;
          // reset the block dimension manually here to pick up the warp_split parameter
          resetBlockDim(param);
          return false;
        }
#else
        warp_split = 1;
        return false;
#endif
      }

      int blockStep() const { return deviceProp.warpSize / warp_split; }
      int blockMin() const { return deviceProp.warpSize / warp_split; }

      void initTuneParam(TuneParam &param) const
      {
        TunableVectorY::initTuneParam(param);
        param.grid.z = nParity;
        param.aux = make_int4(1, 0, 0, 0); // warp-split parameter
      }

      void defaultTuneParam(TuneParam &param) const
      {
        TunableVectorY::defaultTuneParam(param);
        param.grid.z = nParity;
        param.aux = make_int4(1, 0, 0, 0); // warp-split parameter
      }

      long long flops() const
      {
        return NYW * NXZ * f.flops() * x[0]->Length();
      }

      long long bytes() const
      {
        // X and Z reads are repeated (and hopefully cached) across NYW
        // each Y and W read/write is done once
        return NYW * NXZ * (f.read.X + f.write.X) * x[0]->Bytes() +
          NYW * (f.read.Y + f.write.Y) * y[0]->Bytes() +
          NYW * NXZ * (f.read.Z + f.write.Z) * z[0]->Bytes() +
          NYW * (f.read.W + f.write.W) * w[0]->Bytes();
      }
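
      // Worked example of the traffic model above (illustrative numbers): for a
      // caxpy-like functor with read.X = read.Y = write.Y = 1 and all other
      // entries 0, NXZ = 4, NYW = 2, and fields of 10 MiB each, bytes() =
      // 2*4*1*10 MiB (X reads) + 2*(1+1)*10 MiB (Y read + write) = 120 MiB.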

      int tuningIter() const { return 3; }
    };

    using range = std::pair<size_t, size_t>;

    template <template <typename...> class Functor, typename T>
    void axpy_recurse(const T *a_, std::vector<ColorSpinorField *> &x, std::vector<ColorSpinorField *> &y,
                      const range &range_x, const range &range_y, int upper, int coeff_width)
    {
      // if greater than max single-kernel size, recurse
      if (y.size() > (size_t)max_YW_size(x.size(), x[0]->Precision(), y[0]->Precision(), false, false, coeff_width, false)) {
        // We need to split up 'a' carefully since it's row-major.
        T *tmpmajor = new T[x.size() * y.size()];
        T *tmpmajor0 = &tmpmajor[0];
        T *tmpmajor1 = &tmpmajor[x.size() * (y.size() / 2)];
        std::vector<ColorSpinorField *> y0(y.begin(), y.begin() + y.size() / 2);
        std::vector<ColorSpinorField *> y1(y.begin() + y.size() / 2, y.end());

        const unsigned int xlen = x.size();
        const unsigned int ylen0 = y.size() / 2;
        const unsigned int ylen1 = y.size() - y.size() / 2;

        int count = 0, count0 = 0, count1 = 0;
        for (unsigned int i = 0; i < xlen; i++) {
          for (unsigned int j = 0; j < ylen0; j++) tmpmajor0[count0++] = a_[count++];
          for (unsigned int j = 0; j < ylen1; j++) tmpmajor1[count1++] = a_[count++];
        }
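        // Illustrative split: for xlen = 2 and y.size() = 4, the row-major input
        //   a_ = { a00, a01, a02, a03, a10, a11, a12, a13 }
        // becomes tmpmajor0 = { a00, a01, a10, a11 } (the y0 columns) and
        // tmpmajor1 = { a02, a03, a12, a13 } (the y1 columns), each still row-major.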

        axpy_recurse<Functor>(tmpmajor0, x, y0, range_x, range(range_y.first, range_y.first + y0.size()), upper, coeff_width);
        axpy_recurse<Functor>(tmpmajor1, x, y1, range_x, range(range_y.first + y0.size(), range_y.second), upper, coeff_width);

        delete[] tmpmajor;
      } else {
        // if we're at the bottom of the recursion
        if (is_valid_NXZ(x.size(), false, x[0]->Precision() < QUDA_SINGLE_PRECISION)) {
          // since the tile range is [first, second), i.e., [first, second-1], we need >= here
          // if upper triangular and upper-right tile corner is below diagonal return
          if (upper == 1 && range_y.first >= range_x.second) { return; }
          // if lower triangular and lower-left tile corner is above diagonal return
          if (upper == -1 && range_x.first >= range_y.second) { return; }

          // mark true since we will copy the "a" matrix into constant memory
          coeff_array<T> a(a_), b, c;
          constexpr bool mixed = true;
          instantiate<Functor, MultiBlas, mixed>(a, b, c, *x[0], *y[0], x, y, x, x);
        } else {
          // split the problem in half and recurse
          const T *a0 = &a_[0];
          const T *a1 = &a_[(x.size() / 2) * y.size()];

          std::vector<ColorSpinorField *> x0(x.begin(), x.begin() + x.size() / 2);
          std::vector<ColorSpinorField *> x1(x.begin() + x.size() / 2, x.end());

          axpy_recurse<Functor>(a0, x0, y, range(range_x.first, range_x.first + x0.size()), range_y, upper, coeff_width);
          axpy_recurse<Functor>(a1, x1, y, range(range_x.first + x0.size(), range_x.second), range_y, upper, coeff_width);
        }
      } // end if (y.size() > max_YW_size())
    }

    void caxpy(const Complex *a_, std::vector<ColorSpinorField *> &x, std::vector<ColorSpinorField *> &y)
    {
      // Enter a recursion.
      // Pass a, x, y. (0,0) indexes the tiles. 0 indicates the matrix is unstructured.
      axpy_recurse<multicaxpy_>(a_, x, y, range(0, x.size()), range(0, y.size()), 0, 2);
    }
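
    // Usage sketch for block caxpy (a minimal illustration; the containers are
    // assumed to be populated elsewhere). It computes y[j] += sum_i a(i,j) * x[i]
    // with 'a' stored row-major as a[i * y.size() + j]:
    //
    //   std::vector<ColorSpinorField *> x, y;        // NXZ and NYW fields
    //   std::vector<Complex> a(x.size() * y.size()); // row-major coefficients
    //   // ... fill a ...
    //   blas::caxpy(a.data(), x, y);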

    void caxpy_U(const Complex *a_, std::vector<ColorSpinorField *> &x, std::vector<ColorSpinorField *> &y)
    {
      // Enter a recursion.
      // Pass a, x, y. (0,0) indexes the tiles. 1 indicates the matrix is upper-triangular,
      // which lets us skip some tiles.
      if (x.size() != y.size()) {
        errorQuda("An optimal block caxpy_U with non-square 'a' has not yet been implemented. Use block caxpy instead");
      }
      axpy_recurse<multicaxpy_>(a_, x, y, range(0, x.size()), range(0, y.size()), 1, 2);
    }

    void caxpy_L(const Complex *a_, std::vector<ColorSpinorField *> &x, std::vector<ColorSpinorField *> &y)
    {
      // Enter a recursion.
      // Pass a, x, y. (0,0) indexes the tiles. -1 indicates the matrix is lower-triangular,
      // which lets us skip some tiles.
      if (x.size() != y.size()) {
        errorQuda("An optimal block caxpy_L with non-square 'a' has not yet been implemented. Use block caxpy instead");
      }
      axpy_recurse<multicaxpy_>(a_, x, y, range(0, x.size()), range(0, y.size()), -1, 2);
    }
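
    // Worked example of the triangular short-circuit in axpy_recurse: with
    // x.size() = y.size() = 8 split into two 4-wide halves, the tile with
    // range_x = [0, 4) and range_y = [4, 8) has range_y.first (4) >=
    // range_x.second (4), so for upper == 1 it lies below the diagonal and
    // returns before launching any kernel; for large square problems roughly
    // half the tiles are skipped this way.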

    void caxpy(const Complex *a, ColorSpinorField &x, ColorSpinorField &y) { caxpy(a, x.Components(), y.Components()); }

    void caxpy_U(const Complex *a, ColorSpinorField &x, ColorSpinorField &y) { caxpy_U(a, x.Components(), y.Components()); }

    void caxpy_L(const Complex *a, ColorSpinorField &x, ColorSpinorField &y) { caxpy_L(a, x.Components(), y.Components()); }

    void caxpyz_recurse(const Complex *a_, std::vector<ColorSpinorField *> &x, std::vector<ColorSpinorField *> &y,
                        std::vector<ColorSpinorField *> &z, const range &range_x, const range &range_y,
                        int pass, int upper)
    {
      // if greater than max single-kernel size, recurse
      if (y.size() > (size_t)max_YW_size(x.size(), x[0]->Precision(), y[0]->Precision(), false, true, 2, false)) {
        // We need to split up 'a' carefully since it's row-major.
        Complex *tmpmajor = new Complex[x.size() * y.size()];
        Complex *tmpmajor0 = &tmpmajor[0];
        Complex *tmpmajor1 = &tmpmajor[x.size() * (y.size() / 2)];
        std::vector<ColorSpinorField *> y0(y.begin(), y.begin() + y.size() / 2);
        std::vector<ColorSpinorField *> y1(y.begin() + y.size() / 2, y.end());

        std::vector<ColorSpinorField *> z0(z.begin(), z.begin() + z.size() / 2);
        std::vector<ColorSpinorField *> z1(z.begin() + z.size() / 2, z.end());

        const unsigned int xlen = x.size();
        const unsigned int ylen0 = y.size() / 2;
        const unsigned int ylen1 = y.size() - y.size() / 2;

        int count = 0, count0 = 0, count1 = 0;
        for (unsigned int i_ = 0; i_ < xlen; i_++) {
          for (unsigned int j = 0; j < ylen0; j++) tmpmajor0[count0++] = a_[count++];
          for (unsigned int j = 0; j < ylen1; j++) tmpmajor1[count1++] = a_[count++];
        }

        caxpyz_recurse(tmpmajor0, x, y0, z0, range_x, range(range_y.first, range_y.first + y0.size()), pass, upper);
        caxpyz_recurse(tmpmajor1, x, y1, z1, range_x, range(range_y.first + y0.size(), range_y.second), pass, upper);

        delete[] tmpmajor;
      } else {
        // if at the bottom of the recursion, check where we are
        if (is_valid_NXZ(x.size(), false, x[0]->Precision() < QUDA_SINGLE_PRECISION)) {
          // check if tile straddles diagonal
          bool is_diagonal = (range_x.first < range_y.second) && (range_y.first < range_x.second);
          if (pass == 1) {
            if (!is_diagonal) {
              // if upper triangular and upper-right tile corner is below diagonal return
              if (upper == 1 && range_y.first >= range_x.second) { return; }
              // if lower triangular and lower-left tile corner is above diagonal return
              if (upper == -1 && range_x.first >= range_y.second) { return; }
              caxpy(a_, x, z); return; // off diagonal
            }
            return;
          } else {
            if (!is_diagonal) return; // We're on the first pass, so we only want to update the diagonal.
          }

          coeff_array<Complex> a(a_), b, c;
          constexpr bool mixed = false;
          instantiate<multicaxpyz_, MultiBlas, mixed>(a, b, c, *x[0], *y[0], x, y, x, z);
        } else {
          // split the problem in half and recurse
          const Complex *a0 = &a_[0];
          const Complex *a1 = &a_[(x.size() / 2) * y.size()];

          std::vector<ColorSpinorField *> x0(x.begin(), x.begin() + x.size() / 2);
          std::vector<ColorSpinorField *> x1(x.begin() + x.size() / 2, x.end());

          caxpyz_recurse(a0, x0, y, z, range(range_x.first, range_x.first + x0.size()), range_y, pass, upper);
          caxpyz_recurse(a1, x1, y, z, range(range_x.first + x0.size(), range_x.second), range_y, pass, upper);
        }
      } // end if (y.size() > max_YW_size())
    }

    void caxpyz(const Complex *a, std::vector<ColorSpinorField *> &x, std::vector<ColorSpinorField *> &y, std::vector<ColorSpinorField *> &z)
    {
      // first pass does the caxpyz on the diagonal
      caxpyz_recurse(a, x, y, z, range(0, x.size()), range(0, y.size()), 0, 0);
      // second pass does caxpy on the off diagonals
      caxpyz_recurse(a, x, y, z, range(0, x.size()), range(0, y.size()), 1, 0);
    }
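
    // Usage sketch (illustrative; containers assumed populated elsewhere): caxpyz
    // computes z[j] = y[j] + sum_i a(i,j) * x[i] while leaving y unchanged. The
    // first pass seeds z from y on the diagonal tiles, then the second pass
    // accumulates the off-diagonal contributions into z via caxpy:
    //
    //   std::vector<ColorSpinorField *> x, y, z;     // equal-size containers
    //   std::vector<Complex> a(x.size() * y.size()); // row-major coefficients
    //   blas::caxpyz(a.data(), x, y, z);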

    void caxpyz_U(const Complex *a, std::vector<ColorSpinorField *> &x, std::vector<ColorSpinorField *> &y, std::vector<ColorSpinorField *> &z)
    {
      // a is upper triangular.
      // first pass does the caxpyz on the diagonal
      caxpyz_recurse(a, x, y, z, range(0, x.size()), range(0, y.size()), 0, 1);
      // second pass does caxpy on the off diagonals
      caxpyz_recurse(a, x, y, z, range(0, x.size()), range(0, y.size()), 1, 1);
    }

    void caxpyz_L(const Complex *a, std::vector<ColorSpinorField *> &x, std::vector<ColorSpinorField *> &y, std::vector<ColorSpinorField *> &z)
    {
      // a is lower triangular.
      // first pass does the caxpyz on the diagonal
      caxpyz_recurse(a, x, y, z, range(0, x.size()), range(0, y.size()), 0, -1);
      // second pass does caxpy on the off diagonals
      caxpyz_recurse(a, x, y, z, range(0, x.size()), range(0, y.size()), 1, -1);
    }

    void caxpyz(const Complex *a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
    {
      caxpyz(a, x.Components(), y.Components(), z.Components());
    }

    void caxpyz_U(const Complex *a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
    {
      caxpyz_U(a, x.Components(), y.Components(), z.Components());
    }

    void caxpyz_L(const Complex *a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
    {
      caxpyz_L(a, x.Components(), y.Components(), z.Components());
    }

    void axpyBzpcx(const double *a_, std::vector<ColorSpinorField *> &x_, std::vector<ColorSpinorField *> &y_,
                   const double *b_, ColorSpinorField &z_, const double *c_)
    {
      if (y_.size() <= (size_t)max_N_multi_1d()) {
        // swizzle order since we are writing to x_ and y_, but the
        // multi-blas kernels only allow writing to y and w; moreover the
        // block width of y and w must match, and x and z must match.
        std::vector<ColorSpinorField *> &y = y_;
        std::vector<ColorSpinorField *> &w = x_;

        // wrap a container around the third solo vector
        std::vector<ColorSpinorField *> x;
        x.push_back(&z_);

        coeff_array<double> a(a_), b(b_), c(c_);
        constexpr bool mixed = true;
        instantiate<multi_axpyBzpcx_, MultiBlas, mixed>(a, b, c, *x[0], *y[0], x, y, x, w);
      } else {
        // split the problem in half and recurse
        const double *a0 = &a_[0];
        const double *b0 = &b_[0];
        const double *c0 = &c_[0];

        std::vector<ColorSpinorField *> x0(x_.begin(), x_.begin() + x_.size() / 2);
        std::vector<ColorSpinorField *> y0(y_.begin(), y_.begin() + y_.size() / 2);

        axpyBzpcx(a0, x0, y0, b0, z_, c0);

        const double *a1 = &a_[y_.size() / 2];
        const double *b1 = &b_[y_.size() / 2];
        const double *c1 = &c_[y_.size() / 2];

        std::vector<ColorSpinorField *> x1(x_.begin() + x_.size() / 2, x_.end());
        std::vector<ColorSpinorField *> y1(y_.begin() + y_.size() / 2, y_.end());

        axpyBzpcx(a1, x1, y1, b1, z_, c1);
      }
    }
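
    // Usage sketch (illustrative; the per-vector operation is, for each i,
    // y_[i] += a_[i] * x_[i] and x_[i] = b_[i] * z_ + c_[i] * x_[i]):
    //
    //   std::vector<ColorSpinorField *> x, y;                    // equal-size containers
    //   std::vector<double> a(x.size()), b(x.size()), c(x.size());
    //   blas::axpyBzpcx(a.data(), x, y, b.data(), z, c.data()); // z is a single field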

    void caxpyBxpz(const Complex *a_, std::vector<ColorSpinorField *> &x_, ColorSpinorField &y_,
                   const Complex *b_, ColorSpinorField &z_)
    {
      if (x_.size() <= (size_t)max_N_multi_1d() &&
          is_valid_NXZ(x_.size(), false, x_[0]->Precision() < QUDA_SINGLE_PRECISION)) // only split if we have to.
      {
        // swizzle order since we are writing to y_ and z_, but the
        // multi-blas kernels only allow writing to y and w; moreover the
        // block width of y and w must match, and x and z must match.
        // Also, wrap a container around them.
        std::vector<ColorSpinorField *> y;
        y.push_back(&y_);
        std::vector<ColorSpinorField *> w;
        w.push_back(&z_);

        // we're reading from x
        std::vector<ColorSpinorField *> &x = x_;

        coeff_array<Complex> a(a_), b(b_), c;
        constexpr bool mixed = true;
        instantiate<multi_caxpyBxpz_, MultiBlas, mixed>(a, b, c, *x[0], *y[0], x, y, x, w);
      } else {
        // split the problem in half and recurse
        const Complex *a0 = &a_[0];
        const Complex *b0 = &b_[0];

        std::vector<ColorSpinorField *> x0(x_.begin(), x_.begin() + x_.size() / 2);

        caxpyBxpz(a0, x0, y_, b0, z_);

        const Complex *a1 = &a_[x_.size() / 2];
        const Complex *b1 = &b_[x_.size() / 2];

        std::vector<ColorSpinorField *> x1(x_.begin() + x_.size() / 2, x_.end());

        caxpyBxpz(a1, x1, y_, b1, z_);
      }
    }
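
    // Usage sketch (illustrative; the operation accumulates y_ += sum_i a_[i] * x_[i]
    // and z_ += sum_i b_[i] * x_[i] over the block of x vectors):
    //
    //   std::vector<ColorSpinorField *> x;            // the block of input vectors
    //   std::vector<Complex> a(x.size()), b(x.size());
    //   blas::caxpyBxpz(a.data(), x, y, b.data(), z); // y, z are single fields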

    void axpy(const double *a_, std::vector<ColorSpinorField *> &x, std::vector<ColorSpinorField *> &y)
    {
      // Enter a recursion.
      // Pass a, x, y. (0,0) indexes the tiles. 0 indicates the matrix is unstructured.
      axpy_recurse<multiaxpy_>(a_, x, y, range(0, x.size()), range(0, y.size()), 0, 1);
    }

    void axpy_U(const double *a_, std::vector<ColorSpinorField *> &x, std::vector<ColorSpinorField *> &y)
    {
      // Enter a recursion.
      // Pass a, x, y. (0,0) indexes the tiles. 1 indicates the matrix is upper-triangular,
      // which lets us skip some tiles.
      if (x.size() != y.size()) {
        errorQuda("An optimal block axpy_U with non-square 'a' has not yet been implemented. Use block axpy instead");
      }
      axpy_recurse<multiaxpy_>(a_, x, y, range(0, x.size()), range(0, y.size()), 1, 1);
    }

    void axpy_L(const double *a_, std::vector<ColorSpinorField *> &x, std::vector<ColorSpinorField *> &y)
    {
      // Enter a recursion.
      // Pass a, x, y. (0,0) indexes the tiles. -1 indicates the matrix is lower-triangular,
      // which lets us skip some tiles.
      if (x.size() != y.size()) {
        errorQuda("An optimal block axpy_L with non-square 'a' has not yet been implemented. Use block axpy instead");
      }
      axpy_recurse<multiaxpy_>(a_, x, y, range(0, x.size()), range(0, y.size()), -1, 1);
    }
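
    // Usage sketch (illustrative): the real-coefficient analogue of block caxpy,
    // computing y[j] += sum_i a(i,j) * x[i] with a real-valued row-major 'a':
    //
    //   std::vector<ColorSpinorField *> x, y;
    //   std::vector<double> a(x.size() * y.size());
    //   blas::axpy(a.data(), x, y);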

    // Composite field version
    void axpy(const double *a, ColorSpinorField &x, ColorSpinorField &y) { axpy(a, x.Components(), y.Components()); }

    void axpy_U(const double *a, ColorSpinorField &x, ColorSpinorField &y) { axpy_U(a, x.Components(), y.Components()); }

    void axpy_L(const double *a, ColorSpinorField &x, ColorSpinorField &y) { axpy_L(a, x.Components(), y.Components()); }

  } // namespace blas

} // namespace quda