QUDA  1.0.0
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
multi_blas_quda.cu
Go to the documentation of this file.
1 #include <stdlib.h>
2 #include <stdio.h>
3 #include <cstring> // needed for memset
4 #include <typeinfo>
5 
6 #include <tune_quda.h>
7 #include <blas_quda.h>
8 #include <color_spinor_field.h>
9 
10 #include <jitify_helper.cuh>
12 
13 namespace quda {
14 
15  namespace blas {
16 
// Forward declaration of the blas stream accessor; its definition is not part of this listing.
// MultiBlas::apply issues its constant-memory copies on *getStream(), and multiBlas launches
// MultiBlas::apply(*getStream()).
17  cudaStream_t* getStream();
18 
/**
 * Compile-time bundle of four integer flags, one per operand set (X, Y, Z, W).
 * It is passed as the `write` template argument of the multi-blas drivers; e.g. block caxpy
 * uses write<0, 1, 0, 0>, i.e. only the Y flag is set (presumably "Y is written back").
 */
template <int writeX, int writeY, int writeZ, int writeW> struct write {
  static constexpr int X = writeX, Y = writeY, Z = writeZ, W = writeW;
};
26 
namespace detail
{
  // Terminal step: a NUL-terminated char array spelling the digit pack ds in order ('0' + d each).
  template <unsigned... ds> struct to_chars {
    static const char value[];
  };

  template <unsigned... ds> const char to_chars<ds...>::value[] = {static_cast<char>('0' + ds)..., '\0'};

  // Recursive step: move the least-significant decimal digit of `rest` to the front of the pack.
  template <unsigned rest, unsigned... ds> struct explode : explode<rest / 10, rest % 10, ds...> {
  };

  // Once nothing is left to peel off, expose the accumulated digits as a string.
  template <unsigned... ds> struct explode<0, ds...> : to_chars<ds...> {
  };
} // namespace detail

// Compile-time decimal spelling of `num`: num_to_string<N>::value is a NUL-terminated char array,
// e.g. num_to_string<12>::value holds "12" (and num_to_string<0>::value holds "0").
template <unsigned num> struct num_to_string : detail::explode<num / 10, num % 10> {
};
44 
// MultiBlas: autotuned launcher for multiBlasKernel<FloatN, M, NXZ>.
// It couples NXZ x/z fields with NYW y/w fields (NYW is also handed to TunableVectorY as the
// tuning grid's y extent). In apply() it optionally stages the coefficient arrays a, b, c into
// constant memory. Around autotuning it backs up and restores the Y and W fields.
// NOTE(review): this listing is missing original lines (the embedded numbering skips 53, 60, 220,
// 222, 226, 228). Presumably these declared `arg` and the Y_h/W_h/Ynorm_h/Wnorm_h backup members,
// and held the initTuneParam/defaultTuneParam signatures -- confirm against the full source.
45  template <int NXZ, typename FloatN, int M, typename SpinorX, typename SpinorY, typename SpinorZ, typename SpinorW,
46  typename Functor, typename T>
47  class MultiBlas : public TunableVectorY
48  {
49 
50  private:
51  const int NYW;
52  const int nParity;
54  const coeff_array<T> &a, &b, &c;
55 
56  std::vector<ColorSpinorField *> &x, &y, &z, &w;
57 
58  // host pointers used for backing up fields when tuning
59  // don't curry into the Spinors to minimize parameter size
61 
62  bool tuneSharedBytes() const { return false; }
63 
64  public:
  // Wraps the spinor accessors X..W and the functor f into `arg`, using the per-parity length
  // length / nParity. nParity is taken from x[0]->SiteSubset().
  // The aux tuning string is x[0]'s aux string, with y[0]'s appended when the x and y precisions differ.
65  MultiBlas(SpinorX X[], SpinorY Y[], SpinorZ Z[], SpinorW W[], Functor &f, const coeff_array<T> &a,
66  const coeff_array<T> &b, const coeff_array<T> &c, std::vector<ColorSpinorField *> &x,
67  std::vector<ColorSpinorField *> &y, std::vector<ColorSpinorField *> &z, std::vector<ColorSpinorField *> &w,
68  int NYW, int length) :
69  TunableVectorY(NYW),
70  NYW(NYW),
71  nParity(x[0]->SiteSubset()),
72  arg(X, Y, Z, W, f, NYW, length / nParity),
73  a(a),
74  b(b),
75  c(c),
76  x(x),
77  y(y),
78  z(z),
79  w(w),
80  Y_h(),
81  W_h(),
82  Ynorm_h(),
83  Wnorm_h()
84  {
  // Byte views of the host coefficient data; the declarations of Amatrix_h/Bmatrix_h/Cmatrix_h
  // are not visible in this listing.
85  Amatrix_h = reinterpret_cast<signed char *>(const_cast<T *>(a.data));
86  Bmatrix_h = reinterpret_cast<signed char *>(const_cast<T *>(b.data));
87  Cmatrix_h = reinterpret_cast<signed char *>(const_cast<T *>(c.data));
88 
  // NOTE(review): unchecked strcpy/strcat into the fixed-size `aux` buffer -- assumes the combined
  // aux strings fit.
89  strcpy(aux, x[0]->AuxString());
90  if (x[0]->Precision() != y[0]->Precision()) {
91  strcat(aux, ",");
92  strcat(aux, y[0]->AuxString());
93  }
94 
95 #ifdef JITIFY
96  ::quda::create_jitify_program("kernels/multi_blas_core.cuh");
97 #endif
98  }
99 
100  virtual ~MultiBlas() {}
101 
  // Tune key = (x[0]->VolString(), NXZ digits + NYW digits + typeid name of the functor, aux).
  // NOTE(review): `name` holds TuneKey::name_n chars and the strcat calls are unchecked; a long
  // typeid name could overflow it -- consider a bounded copy.
102  inline TuneKey tuneKey() const
103  {
104  char name[TuneKey::name_n];
105  strcpy(name, num_to_string<NXZ>::value);
106  strcat(name, std::to_string(NYW).c_str());
107  strcat(name, typeid(arg.f).name());
108  return TuneKey(x[0]->VolString(), name, aux);
109  }
110 
  // Launches the kernel with the tuned launch parameters.
  // Coefficient staging: for each coefficient array that is present and flagged use_const, the
  // entries are converted to Float2 and copied into the matching constant-memory symbol
  // (Amatrix_d/Bmatrix_d/Cmatrix_d). Row i starts at element MAX_MULTI_BLAS_N * i.
  // The JITIFY build looks those symbols up through the jitify kernel instance.
  // NOTE(review): the coefficient copies go to *getStream() but the kernel is launched on `stream`;
  // they are only ordered if both are the same stream.
  // NOTE(review): the copy sources are stack arrays. This relies on the async host-to-device copy
  // having consumed pageable memory before returning -- confirm.
  // NOTE(review): the strided layout assumes MAX_MULTI_BLAS_N * (NXZ - 1) + NYW elements fit in
  // MAX_MATRIX_SIZE / sizeof(Float2) -- confirm for all Float2 types.
111  inline void apply(const cudaStream_t &stream)
112  {
113  TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
114 
115  typedef typename scalar<FloatN>::type Float;
116  typedef typename vector<Float, 2>::type Float2;
117 #ifdef JITIFY
118  using namespace jitify::reflection;
119  auto instance
120  = program->kernel("quda::blas::multiBlasKernel").instantiate(Type<FloatN>(), M, NXZ, Type<decltype(arg)>());
121 
122  // FIXME - if NXZ=1 no need to copy entire array
123  // FIXME - do we really need strided access here?
124  if (a.data && a.use_const) {
125  Float2 A[MAX_MATRIX_SIZE / sizeof(Float2)];
126  // since the kernel doesn't know the width of the matrix at compile
127  // time we stride it and copy the padded matrix to the GPU
128  for (int i = 0; i < NXZ; i++)
129  for (int j = 0; j < NYW; j++)
130  A[MAX_MULTI_BLAS_N * i + j] = make_Float2<Float2>(Complex(a.data[NYW * i + j]));
131 
132  auto Amatrix_d = instance.get_constant_ptr("quda::blas::Amatrix_d");
133  cuMemcpyHtoDAsync(Amatrix_d, A, MAX_MATRIX_SIZE, *getStream());
134  }
135 
136  if (b.data && b.use_const) {
137  Float2 B[MAX_MATRIX_SIZE / sizeof(Float2)];
138  // since the kernel doesn't know the width of the matrix at compile
139  // time we stride it and copy the padded matrix to the GPU
140  for (int i = 0; i < NXZ; i++)
141  for (int j = 0; j < NYW; j++)
142  B[MAX_MULTI_BLAS_N * i + j] = make_Float2<Float2>(Complex(b.data[NYW * i + j]));
143 
144  auto Bmatrix_d = instance.get_constant_ptr("quda::blas::Bmatrix_d");
145  cuMemcpyHtoDAsync(Bmatrix_d, B, MAX_MATRIX_SIZE, *getStream());
146  }
147 
148  if (c.data && c.use_const) {
149  Float2 C[MAX_MATRIX_SIZE / sizeof(Float2)];
150  // since the kernel doesn't know the width of the matrix at compile
151  // time we stride it and copy the padded matrix to the GPU
152  for (int i = 0; i < NXZ; i++)
153  for (int j = 0; j < NYW; j++)
154  C[MAX_MULTI_BLAS_N * i + j] = make_Float2<Float2>(Complex(c.data[NYW * i + j]));
155 
156  auto Cmatrix_d = instance.get_constant_ptr("quda::blas::Cmatrix_d");
157  cuMemcpyHtoDAsync(Cmatrix_d, C, MAX_MATRIX_SIZE, *getStream());
158  }
159 
160  jitify_error = instance.configure(tp.grid, tp.block, tp.shared_bytes, stream).launch(arg);
161 #else
162  // FIXME - if NXZ=1 no need to copy entire array
163  // FIXME - do we really need strided access here?
164  if (a.data && a.use_const) {
165  Float2 A[MAX_MATRIX_SIZE / sizeof(Float2)];
166  // since the kernel doesn't know the width of the matrix at compile
167  // time we stride it and copy the padded matrix to the GPU
168  for (int i = 0; i < NXZ; i++)
169  for (int j = 0; j < NYW; j++)
170  A[MAX_MULTI_BLAS_N * i + j] = make_Float2<Float2>(Complex(a.data[NYW * i + j]));
171 
172  cudaMemcpyToSymbolAsync(Amatrix_d, A, MAX_MATRIX_SIZE, 0, cudaMemcpyHostToDevice, *getStream());
173  }
174 
175  if (b.data && b.use_const) {
176  Float2 B[MAX_MATRIX_SIZE / sizeof(Float2)];
177  // since the kernel doesn't know the width of the matrix at compile
178  // time we stride it and copy the padded matrix to the GPU
179  for (int i = 0; i < NXZ; i++)
180  for (int j = 0; j < NYW; j++)
181  B[MAX_MULTI_BLAS_N * i + j] = make_Float2<Float2>(Complex(b.data[NYW * i + j]));
182 
183  cudaMemcpyToSymbolAsync(Bmatrix_d, B, MAX_MATRIX_SIZE, 0, cudaMemcpyHostToDevice, *getStream());
184  }
185 
186  if (c.data && c.use_const) {
187  Float2 C[MAX_MATRIX_SIZE / sizeof(Float2)];
188  // since the kernel doesn't know the width of the matrix at compile
189  // time we stride it and copy the padded matrix to the GPU
190  for (int i = 0; i < NXZ; i++)
191  for (int j = 0; j < NYW; j++)
192  C[MAX_MULTI_BLAS_N * i + j] = make_Float2<Float2>(Complex(c.data[NYW * i + j]));
193 
194  cudaMemcpyToSymbolAsync(Cmatrix_d, C, MAX_MATRIX_SIZE, 0, cudaMemcpyHostToDevice, *getStream());
195  }
  // For CUDA < 9.0 the argument struct is additionally copied into the arg_buffer symbol.
196 #if CUDA_VERSION < 9000
197  cudaMemcpyToSymbolAsync(arg_buffer, reinterpret_cast<char *>(&arg), sizeof(arg), 0, cudaMemcpyHostToDevice,
198  *getStream());
199 #endif
200  multiBlasKernel<FloatN, M, NXZ><<<tp.grid, tp.block, tp.shared_bytes, stream>>>(arg);
201 #endif
202  }
203 
  // Back up every Y and W field (data and norm) into host storage before tuning. postTune() restores
  // them, presumably because tuning launches would otherwise overwrite the outputs.
204  void preTune()
205  {
206  for (int i = 0; i < NYW; ++i) {
207  arg.Y[i].backup(&Y_h[i], &Ynorm_h[i], y[i]->Bytes(), y[i]->NormBytes());
208  arg.W[i].backup(&W_h[i], &Wnorm_h[i], w[i]->Bytes(), w[i]->NormBytes());
209  }
210  }
211 
  // Restore the Y and W fields saved by preTune().
212  void postTune()
213  {
214  for (int i = 0; i < NYW; ++i) {
215  arg.Y[i].restore(&Y_h[i], &Ynorm_h[i], y[i]->Bytes(), y[i]->NormBytes());
216  arg.W[i].restore(&W_h[i], &Wnorm_h[i], w[i]->Bytes(), w[i]->NormBytes());
217  }
218  }
219 
  // NOTE(review): the signatures of the two tuning-parameter initializers are missing from this
  // listing. Both visible bodies set grid.z = nParity (presumably one grid slice per parity).
221  {
223  param.grid.z = nParity;
224  }
225 
227  {
229  param.grid.z = nParity;
230  }
231 
  // Flop count: functor flops * vec_length<FloatN> * arg.length (per-parity length) * nParity * M.
232  long long flops() const { return arg.f.flops() * vec_length<FloatN>::value * (long)arg.length * nParity * M; }
233 
  // Byte count: (streams() - 2) x-sized transfers plus a read and a write of a y-sized field.
234  long long bytes() const
235  {
236  // the factor two here assumes we are reading and writing to the high precision vector
237  return ((arg.f.streams() - 2) * x[0]->Bytes() + 2 * y[0]->Bytes());
238  }
239 
  // Returns 3; presumably the number of timed launches per candidate during autotuning.
240  int tuningIter() const { return 3; }
241  };
242 
// Host-side driver for one fused multi-blas launch:
//   1. validates the NXZ/NYW sizes;
//   2. binds x/z (NXZ each) and y/w (NYW each) to the spinor accessors X, Z, Y, W;
//   3. constructs Functor<NXZ, Float2, RegType>;
//   4. runs the tuned MultiBlas launch via apply(*getStream());
//   5. adds its bytes/flops to blas::bytes / blas::flops, then calls checkCudaError().
// NOTE(review): the declarations of X, Y, Z, W and of the `blas` MultiBlas object are missing from
// this listing (the embedded numbering skips 261-264 and 279-280).
243  template <int NXZ, typename RegType, typename StoreType, typename yType, int M,
244  template <int, typename, typename> class Functor, typename write, typename T>
245  void multiBlas(const coeff_array<T> &a, const coeff_array<T> &b, const coeff_array<T> &c,
246  std::vector<ColorSpinorField *> &x, std::vector<ColorSpinorField *> &y, std::vector<ColorSpinorField *> &z,
247  std::vector<ColorSpinorField *> &w, int length)
248  {
249  const int NYW = y.size();
250 
  // Both the x-side width (NXZ) and the y-side width (NYW) must fit in a single kernel.
251  const int N = NXZ > NYW ? NXZ : NYW;
252  if (N > MAX_MULTI_BLAS_N) errorQuda("Spinor vector length exceeds max size (%d > %d)", N, MAX_MULTI_BLAS_N);
253 
  // NOTE(review): this checks the dense NXZ x NYW size. MultiBlas::apply, however, stores the
  // coefficients with a row stride of MAX_MULTI_BLAS_N.
  // NOTE(review): "%lu" is used for a size_t argument ("%zu" is the portable spelling).
254  if (NXZ * NYW * sizeof(Complex) > MAX_MATRIX_SIZE)
255  errorQuda("A matrix exceeds max size (%lu > %d)", NXZ * NYW * sizeof(Complex), MAX_MATRIX_SIZE);
256 
257  typedef typename scalar<RegType>::type Float;
258  typedef typename vector<Float, 2>::type Float2;
  // NOTE(review): vec2 is not referenced by any visible line.
259  typedef vector<Float, 2> vec2;
260 
265 
  // NOTE(review): the dynamic_cast results are dereferenced without a null check. This assumes every
  // field is a cudaColorSpinorField; the visible dispatchers only call this for
  // QUDA_CUDA_FIELD_LOCATION fields.
266  for (int i = 0; i < NXZ; i++) {
267  X[i].set(*dynamic_cast<cudaColorSpinorField *>(x[i]));
268  Z[i].set(*dynamic_cast<cudaColorSpinorField *>(z[i]));
269  }
270  for (int i = 0; i < NYW; i++) {
271  Y[i].set(*dynamic_cast<cudaColorSpinorField *>(y[i]));
272  W[i].set(*dynamic_cast<cudaColorSpinorField *>(w[i]));
273  }
274 
275  // if block caxpy is an 'outer product of caxpy' where 'x'
276 
277  Functor<NXZ, Float2, RegType> f(a, b, c, NYW);
278 
281  blas(X, Y, Z, W, f, a, b, c, x, y, z, w, NYW, length);
282  blas.apply(*getStream());
283 
284  blas::bytes += blas.bytes();
285  blas::flops += blas.flops();
286 
287  checkCudaError();
288  }
289 
// Uniform-precision dispatcher for CUDA-resident fields whose x and y precisions match.
// It picks RegType/StoreType/yType and M from the precision and x[0]->Nspin(), then forwards to the
// multiBlas driver. The length argument passed is:
//   - double: x[0]->Length() / (2 * M)
//   - single: x[0]->Length() / (4 * M) or / (2 * M)
//   - half, quarter: x[0]->Volume()
// Unsupported or not-built combinations, and CPU fields, call errorQuda.
// The double branch does not inspect Nspin. Half and quarter also require Ncolor == 3.
// NOTE(review): the remainder of this signature (the x, y, z, w parameters; the embedded numbering
// skips 295-296) is missing from this listing.
293  template <int NXZ, template <int MXZ, typename Float, typename FloatN> class Functor, typename write, typename T>
294  void multiBlas(const coeff_array<T> &a, const coeff_array<T> &b, const coeff_array<T> &c,
297  {
298 
299  if (checkLocation(*x[0], *y[0], *z[0], *w[0]) == QUDA_CUDA_FIELD_LOCATION) {
300 
301  if (y[0]->Precision() == QUDA_DOUBLE_PRECISION && x[0]->Precision() == QUDA_DOUBLE_PRECISION) {
302 
303 #if QUDA_PRECISION & 8
304 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC) || defined(GPU_STAGGERED_DIRAC)
305  const int M = 1;
306  multiBlas<NXZ, double2, double2, double2, M, Functor, write>(a, b, c, x, y, z, w, x[0]->Length() / (2 * M));
307 #else
308  errorQuda("blas has not been built for Nspin=%d fields", x[0]->Nspin());
309 #endif
310 #else
311  errorQuda("QUDA_PRECISION=%d does not enable precision %d", QUDA_PRECISION, x[0]->Precision());
312 #endif
313 
314  } else if (y[0]->Precision() == QUDA_SINGLE_PRECISION && x[0]->Precision() == QUDA_SINGLE_PRECISION) {
315 
316 #if QUDA_PRECISION & 4
317  if (x[0]->Nspin() == 4) {
318 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC)
319  const int M = 1;
320  multiBlas<NXZ, float4, float4, float4, M, Functor, write>(a, b, c, x, y, z, w, x[0]->Length() / (4 * M));
321 #else
322  errorQuda("blas has not been built for Nspin=%d fields", x[0]->Nspin());
323 #endif
324 
325  } else if (x[0]->Nspin() == 2 || x[0]->Nspin() == 1) {
326 
327 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC) || defined(GPU_STAGGERED_DIRAC)
328  const int M = 1;
329  multiBlas<NXZ, float2, float2, float2, M, Functor, write>(a, b, c, x, y, z, w, x[0]->Length() / (2 * M));
330 #else
331  errorQuda("blas has not been built for Nspin=%d fields", x[0]->Nspin());
332 #endif
333  } else {
334  errorQuda("nSpin=%d is not supported\n", x[0]->Nspin());
335  }
336 #else
337  errorQuda("QUDA_PRECISION=%d does not enable precision %d", QUDA_PRECISION, x[0]->Precision());
338 #endif
339 
340  } else if (y[0]->Precision() == QUDA_HALF_PRECISION && x[0]->Precision() == QUDA_HALF_PRECISION) {
341 
342 #if QUDA_PRECISION & 2
343  if (x[0]->Ncolor() != 3) { errorQuda("nColor = %d is not supported", x[0]->Ncolor()); }
344  if (x[0]->Nspin() == 4) { // wilson
345 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC)
346  const int M = 6;
347  multiBlas<NXZ, float4, short4, short4, M, Functor, write>(a, b, c, x, y, z, w, x[0]->Volume());
348 #else
349  errorQuda("blas has not been built for Nspin=%d fields", x[0]->Nspin());
350 #endif
351  } else if (x[0]->Nspin() == 1) { // staggered
352 #ifdef GPU_STAGGERED_DIRAC
353  const int M = 3;
354  multiBlas<NXZ, float2, short2, short2, M, Functor, write>(a, b, c, x, y, z, w, x[0]->Volume());
355 #else
356  errorQuda("blas has not been built for Nspin=%d fields", x[0]->Nspin());
357 #endif
358  } else {
359  errorQuda("nSpin=%d is not supported\n", x[0]->Nspin());
360  }
361 #else
362  errorQuda("QUDA_PRECISION=%d does not enable precision %d", QUDA_PRECISION, x[0]->Precision());
363 #endif
364 
365  } else if (y[0]->Precision() == QUDA_QUARTER_PRECISION && x[0]->Precision() == QUDA_QUARTER_PRECISION) {
366 
367 #if QUDA_PRECISION & 1
368  if (x[0]->Ncolor() != 3) { errorQuda("nColor = %d is not supported", x[0]->Ncolor()); }
369  if (x[0]->Nspin() == 4) { // wilson
370 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC)
371  const int M = 6;
372  multiBlas<NXZ, float4, char4, char4, M, Functor, write>(a, b, c, x, y, z, w, x[0]->Volume());
373 #else
374  errorQuda("blas has not been built for Nspin=%d fields", x[0]->Nspin());
375 #endif
376  } else if (x[0]->Nspin() == 1) { // staggered
377 #ifdef GPU_STAGGERED_DIRAC
378  const int M = 3;
379  multiBlas<NXZ, float2, char2, char2, M, Functor, write>(a, b, c, x, y, z, w, x[0]->Volume());
380 #else
381  errorQuda("blas has not been built for Nspin=%d fields", x[0]->Nspin());
382 #endif
383  } else {
384  errorQuda("nSpin=%d is not supported\n", x[0]->Nspin());
385  }
386 #else
387  errorQuda("QUDA_PRECISION=%d does not enable precision %d", QUDA_PRECISION, x[0]->Precision());
388 #endif
389 
390  } else {
391 
392  errorQuda("Precision combination x=%d not supported\n", x[0]->Precision());
393  }
394  } else { // fields on the cpu
395  errorQuda("Not implemented");
396  }
397  }
398 
// Mixed-precision dispatcher for CUDA-resident fields: y is double or single and x has a lower
// precision. RegType/yType follow y's precision, StoreType follows x's, and M is chosen from
// x[0]->Nspin(). x[0]->Volume() is passed as the length.
// NOTE(review): in the y=double branches only Nspin 4 and 1 are handled and there is no trailing
// else. Any other Nspin (e.g. 2) therefore silently does nothing instead of raising an error;
// consider an errorQuda like the y=single branches have.
// NOTE(review): the y=single "does not enable precision" messages report y[0]->Precision() even
// though the guarding QUDA_PRECISION bit is for x's precision.
// NOTE(review): the remainder of this signature (the embedded numbering skips 404-405) is missing
// from this listing.
402  template <int NXZ, template <int MXZ, typename Float, typename FloatN> class Functor, typename write, typename T>
403  void mixedMultiBlas(const coeff_array<T> &a, const coeff_array<T> &b, const coeff_array<T> &c,
406  {
407  if (checkLocation(*x[0], *y[0], *z[0], *w[0]) == QUDA_CUDA_FIELD_LOCATION) {
408 
409  if (y[0]->Precision() == QUDA_DOUBLE_PRECISION) {
410 
411 #if QUDA_PRECISION & 8
412  if (x[0]->Precision() == QUDA_SINGLE_PRECISION) {
413 
414 #if QUDA_PRECISION & 4
415  if (x[0]->Nspin() == 4) {
416 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC)
417  const int M = 12;
418  multiBlas<NXZ, double2, float4, double2, M, Functor, write>(a, b, c, x, y, z, w, x[0]->Volume());
419 #else
420  errorQuda("blas has not been built for Nspin=%d fields", x[0]->Nspin());
421 #endif
422  } else if (x[0]->Nspin() == 1) {
423 
424 #if defined(GPU_STAGGERED_DIRAC)
425  const int M = 3;
426  multiBlas<NXZ, double2, float2, double2, M, Functor, write>(a, b, c, x, y, z, w, x[0]->Volume());
427 #else
428  errorQuda("blas has not been built for Nspin=%d fields", x[0]->Nspin());
429 #endif
430  }
431 
432 #else
433  errorQuda("QUDA_PRECISION=%d does not enable precision %d", QUDA_PRECISION, x[0]->Precision());
434 #endif
435 
436  } else if (x[0]->Precision() == QUDA_HALF_PRECISION) {
437 
438 #if QUDA_PRECISION & 2
439  if (x[0]->Nspin() == 4) {
440 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC)
441  const int M = 12;
442  multiBlas<NXZ, double2, short4, double2, M, Functor, write>(a, b, c, x, y, z, w, x[0]->Volume());
443 #else
444  errorQuda("blas has not been built for Nspin=%d fields", x[0]->Nspin());
445 #endif
446 
447  } else if (x[0]->Nspin() == 1) {
448 
449 #if defined(GPU_STAGGERED_DIRAC)
450  const int M = 3;
451  multiBlas<NXZ, double2, short2, double2, M, Functor, write>(a, b, c, x, y, z, w, x[0]->Volume());
452 #else
453  errorQuda("blas has not been built for Nspin=%d fields", x[0]->Nspin());
454 #endif
455  }
456 #else
457  errorQuda("QUDA_PRECISION=%d does not enable precision %d", QUDA_PRECISION, x[0]->Precision());
458 #endif
459 
460  } else if (x[0]->Precision() == QUDA_QUARTER_PRECISION) {
461 
462 #if QUDA_PRECISION & 1
463  if (x[0]->Nspin() == 4) {
464 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC)
465  const int M = 12;
466  multiBlas<NXZ, double2, char4, double2, M, Functor, write>(a, b, c, x, y, z, w, x[0]->Volume());
467 #else
468  errorQuda("blas has not been built for Nspin=%d fields", x[0]->Nspin());
469 #endif
470 
471  } else if (x[0]->Nspin() == 1) {
472 
473 #if defined(GPU_STAGGERED_DIRAC)
474  const int M = 3;
475  multiBlas<NXZ, double2, char2, double2, M, Functor, write>(a, b, c, x, y, z, w, x[0]->Volume());
476 #else
477  errorQuda("blas has not been built for Nspin=%d fields", x[0]->Nspin());
478 #endif
479  }
480 #else
481  errorQuda("QUDA_PRECISION=%d does not enable precision %d", QUDA_PRECISION, x[0]->Precision());
482 #endif
483 
484  } else {
485  errorQuda("Not implemented for this precision combination %d %d", x[0]->Precision(), y[0]->Precision());
486  }
487 #else
488  errorQuda("QUDA_PRECISION=%d does not enable precision %d", QUDA_PRECISION, y[0]->Precision());
489 #endif
490 
491  } else if (y[0]->Precision() == QUDA_SINGLE_PRECISION) {
492 
493 #if (QUDA_PRECISION & 4)
494  if (x[0]->Precision() == QUDA_HALF_PRECISION) {
495 
496 #if (QUDA_PRECISION & 2)
497  if (x[0]->Nspin() == 4) {
498 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC)
499  const int M = 6;
500  multiBlas<NXZ, float4, short4, float4, M, Functor, write>(a, b, c, x, y, z, w, x[0]->Volume());
501 #else
502  errorQuda("blas has not been built for Nspin=%d fields", x[0]->Nspin());
503 #endif
504 
505  } else if (x[0]->Nspin() == 2 || x[0]->Nspin() == 1) {
506 
507 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC) || defined(GPU_STAGGERED_DIRAC)
508  const int M = 3;
509  multiBlas<NXZ, float2, short2, float2, M, Functor, write>(a, b, c, x, y, z, w, x[0]->Volume());
510 #else
511  errorQuda("blas has not been built for Nspin=%d fields", x[0]->Nspin());
512 #endif
513  } else {
514  errorQuda("nSpin=%d is not supported\n", x[0]->Nspin());
515  }
516 
517 #else
518  errorQuda("QUDA_PRECISION=%d does not enable precision %d", QUDA_PRECISION, y[0]->Precision());
519 #endif
520 
521  } else if (x[0]->Precision() == QUDA_QUARTER_PRECISION) {
522 
523 #if (QUDA_PRECISION & 1)
524  if (x[0]->Nspin() == 4) {
525 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC)
526  const int M = 6;
527  multiBlas<NXZ, float4, char4, float4, M, Functor, write>(a, b, c, x, y, z, w, x[0]->Volume());
528 #else
529  errorQuda("blas has not been built for Nspin=%d fields", x[0]->Nspin());
530 #endif
531 
532  } else if (x[0]->Nspin() == 2 || x[0]->Nspin() == 1) {
533 
534 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC) || defined(GPU_STAGGERED_DIRAC)
535  const int M = 3;
536  multiBlas<NXZ, float2, char2, float2, M, Functor, write>(a, b, c, x, y, z, w, x[0]->Volume());
537 #else
538  errorQuda("blas has not been built for Nspin=%d fields", x[0]->Nspin());
539 #endif
540  } else {
541  errorQuda("nSpin=%d is not supported\n", x[0]->Nspin());
542  }
543 
544 #else
545  errorQuda("QUDA_PRECISION=%d does not enable precision %d", QUDA_PRECISION, y[0]->Precision());
546 #endif
547 
548  } else {
549  errorQuda("Precision combination x=%d y=%d not supported\n", x[0]->Precision(), y[0]->Precision());
550  }
551 #else
552  errorQuda("QUDA_PRECISION=%d does not enable precision %d", QUDA_PRECISION, y[0]->Precision());
553 #endif
554  } else {
555  errorQuda("Precision combination x=%d y=%d not supported\n", x[0]->Precision(), y[0]->Precision());
556  }
557  } else { // fields on the cpu
558  errorQuda("Not implemented");
559  }
560  }
561 
562  void caxpy_recurse(const Complex *a_, std::vector<ColorSpinorField*> &x, std::vector<ColorSpinorField*> &y,
563  int i_idx ,int j_idx, int upper) {
564 
565  if (y.size() > MAX_MULTI_BLAS_N) // if greater than max single-kernel size, recurse.
566  {
567  // We need to split up 'a' carefully since it's row-major.
568  Complex* tmpmajor = new Complex[x.size()*y.size()];
569  Complex* tmpmajor0 = &tmpmajor[0];
570  Complex* tmpmajor1 = &tmpmajor[x.size()*(y.size()/2)];
571  std::vector<ColorSpinorField*> y0(y.begin(), y.begin() + y.size()/2);
572  std::vector<ColorSpinorField*> y1(y.begin() + y.size()/2, y.end());
573 
574  const unsigned int xlen = x.size();
575  const unsigned int ylen0 = y.size()/2;
576  const unsigned int ylen1 = y.size() - y.size()/2;
577 
578  int count = 0, count0 = 0, count1 = 0;
579  for (unsigned int i = 0; i < xlen; i++)
580  {
581  for (unsigned int j = 0; j < ylen0; j++)
582  tmpmajor0[count0++] = a_[count++];
583  for (unsigned int j = 0; j < ylen1; j++)
584  tmpmajor1[count1++] = a_[count++];
585  }
586 
587  caxpy_recurse(tmpmajor0, x, y0, i_idx, 2*j_idx+0, upper);
588  caxpy_recurse(tmpmajor1, x, y1, i_idx, 2*j_idx+1, upper);
589 
590  delete[] tmpmajor;
591  }
592  else
593  {
594  // if at the bottom of recursion,
595  // return if on lower left for upper triangular,
596  // return if on upper right for lower triangular.
597  if (x.size() <= MAX_MULTI_BLAS_N) {
598  if (upper == 1 && j_idx < i_idx) { return; }
599  if (upper == -1 && j_idx > i_idx) { return; }
600  }
601 
602  // mark true since we will copy the "a" matrix into constant memory
603  coeff_array<Complex> a(a_, true), b, c;
604 
605  if (x[0]->Precision() == y[0]->Precision())
606  {
607  switch (x.size()) {
608  case 1: multiBlas<1, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y); break;
609 #if MAX_MULTI_BLAS_N >= 2
610  case 2: multiBlas<2, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y); break;
611 #if MAX_MULTI_BLAS_N >= 3
612  case 3: multiBlas<3, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y); break;
613 #if MAX_MULTI_BLAS_N >= 4
614  case 4: multiBlas<4, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y); break;
615 #if MAX_MULTI_BLAS_N >= 5
616  case 5: multiBlas<5, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y); break;
617 #if MAX_MULTI_BLAS_N >= 6
618  case 6: multiBlas<6, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y); break;
619 #if MAX_MULTI_BLAS_N >= 7
620  case 7: multiBlas<7, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y); break;
621 #if MAX_MULTI_BLAS_N >= 8
622  case 8: multiBlas<8, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y); break;
623 #if MAX_MULTI_BLAS_N >= 9
624  case 9: multiBlas<9, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y); break;
625 #if MAX_MULTI_BLAS_N >= 10
626  case 10: multiBlas<10, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y); break;
627 #if MAX_MULTI_BLAS_N >= 11
628  case 11: multiBlas<11, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y); break;
629 #if MAX_MULTI_BLAS_N >= 12
630  case 12: multiBlas<12, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y); break;
631 #if MAX_MULTI_BLAS_N >= 13
632  case 13: multiBlas<13, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y); break;
633 #if MAX_MULTI_BLAS_N >= 14
634  case 14: multiBlas<14, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y); break;
635 #if MAX_MULTI_BLAS_N >= 15
636  case 15: multiBlas<15, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y); break;
637 #if MAX_MULTI_BLAS_N >= 16
638  case 16: multiBlas<16, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y); break;
639 #endif // 16
640  #endif // 15
641  #endif // 14
642  #endif // 13
643  #endif // 12
644  #endif // 11
645  #endif // 10
646  #endif // 9
647  #endif // 8
648  #endif // 7
649  #endif // 6
650  #endif // 5
651  #endif // 4
652  #endif // 3
653  #endif // 2
654  default:
655  // split the problem in half and recurse
656  const Complex *a0 = &a_[0];
657  const Complex *a1 = &a_[(x.size()/2)*y.size()];
658 
659  std::vector<ColorSpinorField*> x0(x.begin(), x.begin() + x.size()/2);
660  std::vector<ColorSpinorField*> x1(x.begin() + x.size()/2, x.end());
661 
662  caxpy_recurse(a0, x0, y, 2*i_idx+0, j_idx, upper);
663  caxpy_recurse(a1, x1, y, 2*i_idx+1, j_idx, upper);
664  break;
665  }
666  }
667  else // precisions don't agree.
668  {
669  switch (x.size()) {
670  case 1: mixedMultiBlas<1, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y); break;
671 #if MAX_MULTI_BLAS_N >= 2
672  case 2: mixedMultiBlas<2, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y); break;
673 #if MAX_MULTI_BLAS_N >= 3
674  case 3: mixedMultiBlas<3, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y); break;
675 #if MAX_MULTI_BLAS_N >= 4
676  case 4: mixedMultiBlas<4, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y); break;
677 #if MAX_MULTI_BLAS_N >= 5
678  case 5: mixedMultiBlas<5, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y); break;
679 #if MAX_MULTI_BLAS_N >= 6
680  case 6: mixedMultiBlas<6, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y); break;
681 #if MAX_MULTI_BLAS_N >= 7
682  case 7: mixedMultiBlas<7, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y); break;
683 #if MAX_MULTI_BLAS_N >= 8
684  case 8: mixedMultiBlas<8, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y); break;
685 #if MAX_MULTI_BLAS_N >= 9
686  case 9: mixedMultiBlas<9, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y); break;
687 #if MAX_MULTI_BLAS_N >= 10
688  case 10: mixedMultiBlas<10, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y); break;
689 #if MAX_MULTI_BLAS_N >= 11
690  case 11: mixedMultiBlas<11, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y); break;
691 #if MAX_MULTI_BLAS_N >= 12
692  case 12: mixedMultiBlas<12, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y); break;
693 #if MAX_MULTI_BLAS_N >= 13
694  case 13: mixedMultiBlas<13, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y); break;
695 #if MAX_MULTI_BLAS_N >= 14
696  case 14: mixedMultiBlas<14, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y); break;
697 #if MAX_MULTI_BLAS_N >= 15
698  case 15: mixedMultiBlas<15, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y); break;
699 #if MAX_MULTI_BLAS_N >= 16
700  case 16: mixedMultiBlas<16, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y); break;
701 #endif // 16
702  #endif // 15
703  #endif // 14
704  #endif // 13
705  #endif // 12
706  #endif // 11
707  #endif // 10
708  #endif // 9
709  #endif // 8
710  #endif // 7
711  #endif // 6
712  #endif // 5
713  #endif // 4
714  #endif // 3
715  #endif // 2
716  default:
717  // split the problem in half and recurse
718  const Complex *a0 = &a_[0];
719  const Complex *a1 = &a_[(x.size()/2)*y.size()];
720 
721  std::vector<ColorSpinorField*> x0(x.begin(), x.begin() + x.size()/2);
722  std::vector<ColorSpinorField*> x1(x.begin() + x.size()/2, x.end());
723 
724  caxpy_recurse(a0, x0, y, 2*i_idx+0, j_idx, upper);
725  caxpy_recurse(a1, x1, y, 2*i_idx+1, j_idx, upper);
726  break;
727  }
728  }
729  } // end if (y.size() > MAX_MULTI_BLAS_N)
730  }
731 
732  void caxpy(const Complex *a_, std::vector<ColorSpinorField*> &x, std::vector<ColorSpinorField*> &y) {
733  // Enter a recursion.
734  // Pass a, x, y. (0,0) indexes the tiles. false specifies the matrix is unstructured.
735  caxpy_recurse(a_, x, y, 0, 0, 0);
736  }
737 
738  void caxpy_U(const Complex *a_, std::vector<ColorSpinorField*> &x, std::vector<ColorSpinorField*> &y) {
739  // Enter a recursion.
740  // Pass a, x, y. (0,0) indexes the tiles. 1 indicates the matrix is upper-triangular,
741  // which lets us skip some tiles.
742  if (x.size() != y.size())
743  {
744  errorQuda("An optimal block caxpy_U with non-square 'a' has not yet been implemented. Use block caxpy instead.\n");
745  return;
746  }
747  caxpy_recurse(a_, x, y, 0, 0, 1);
748  }
749 
750  void caxpy_L(const Complex *a_, std::vector<ColorSpinorField*> &x, std::vector<ColorSpinorField*> &y) {
751  // Enter a recursion.
752  // Pass a, x, y. (0,0) indexes the tiles. -1 indicates the matrix is lower-triangular
753  // which lets us skip some tiles.
754  if (x.size() != y.size())
755  {
756  errorQuda("An optimal block caxpy_L with non-square 'a' has not yet been implemented. Use block caxpy instead.\n");
757  return;
758  }
759  caxpy_recurse(a_, x, y, 0, 0, -1);
760  }
761 
762 
763  void caxpy(const Complex *a, ColorSpinorField &x, ColorSpinorField &y) { caxpy(a, x.Components(), y.Components()); }
764 
766 
768 
769 
770  void caxpyz_recurse(const Complex *a_, std::vector<ColorSpinorField*> &x, std::vector<ColorSpinorField*> &y, std::vector<ColorSpinorField*> &z, int i, int j, int pass, int upper) {
771 
772  if (y.size() > MAX_MULTI_BLAS_N) // if greater than max single-kernel size, recurse.
773  {
774  // We need to split up 'a' carefully since it's row-major.
775  Complex* tmpmajor = new Complex[x.size()*y.size()];
776  Complex* tmpmajor0 = &tmpmajor[0];
777  Complex* tmpmajor1 = &tmpmajor[x.size()*(y.size()/2)];
778  std::vector<ColorSpinorField*> y0(y.begin(), y.begin() + y.size()/2);
779  std::vector<ColorSpinorField*> y1(y.begin() + y.size()/2, y.end());
780 
781  std::vector<ColorSpinorField*> z0(z.begin(), z.begin() + z.size()/2);
782  std::vector<ColorSpinorField*> z1(z.begin() + z.size()/2, z.end());
783 
784  const unsigned int xlen = x.size();
785  const unsigned int ylen0 = y.size()/2;
786  const unsigned int ylen1 = y.size() - y.size()/2;
787 
788  int count = 0, count0 = 0, count1 = 0;
789  for (unsigned int i_ = 0; i_ < xlen; i_++)
790  {
791  for (unsigned int j = 0; j < ylen0; j++)
792  tmpmajor0[count0++] = a_[count++];
793  for (unsigned int j = 0; j < ylen1; j++)
794  tmpmajor1[count1++] = a_[count++];
795  }
796 
797  caxpyz_recurse(tmpmajor0, x, y0, z0, i, 2*j+0, pass, upper);
798  caxpyz_recurse(tmpmajor1, x, y1, z1, i, 2*j+1, pass, upper);
799 
800  delete[] tmpmajor;
801  }
802  else
803  {
804  // if at bottom of recursion check where we are
805  if (x.size() <= MAX_MULTI_BLAS_N) {
806  if (pass==1) {
807  if (i!=j)
808  {
809  if (upper == 1 && j < i) { return; } // upper right, don't need to update lower left.
810  if (upper == -1 && i < j) { return; } // lower left, don't need to update upper right.
811  caxpy(a_, x, z); return; // off diagonal
812  }
813  return;
814  } else {
815  if (i!=j) return; // We're on the first pass, so we only want to update the diagonal.
816  }
817  }
818 
819  // mark true since we will copy the "a" matrix into constant memory
820  coeff_array<Complex> a(a_, true), b, c;
821 
822  if (x[0]->Precision() == y[0]->Precision())
823  {
824  switch (x.size()) {
825  case 1: multiBlas<1, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z); break;
826 #if MAX_MULTI_BLAS_N >= 2
827  case 2: multiBlas<2, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z); break;
828 #if MAX_MULTI_BLAS_N >= 3
829  case 3: multiBlas<3, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z); break;
830 #if MAX_MULTI_BLAS_N >= 4
831  case 4: multiBlas<4, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z); break;
832 #if MAX_MULTI_BLAS_N >= 5
833  case 5: multiBlas<5, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z); break;
834 #if MAX_MULTI_BLAS_N >= 6
835  case 6: multiBlas<6, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z); break;
836 #if MAX_MULTI_BLAS_N >= 7
837  case 7: multiBlas<7, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z); break;
838 #if MAX_MULTI_BLAS_N >= 8
839  case 8: multiBlas<8, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z); break;
840 #if MAX_MULTI_BLAS_N >= 9
841  case 9: multiBlas<9, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z); break;
842 #if MAX_MULTI_BLAS_N >= 10
843  case 10: multiBlas<10, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z); break;
844 #if MAX_MULTI_BLAS_N >= 11
845  case 11: multiBlas<11, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z); break;
846 #if MAX_MULTI_BLAS_N >= 12
847  case 12: multiBlas<12, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z); break;
848 #if MAX_MULTI_BLAS_N >= 13
849  case 13: multiBlas<13, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z); break;
850 #if MAX_MULTI_BLAS_N >= 14
851  case 14: multiBlas<14, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z); break;
852 #if MAX_MULTI_BLAS_N >= 15
853  case 15: multiBlas<15, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z); break;
854 #if MAX_MULTI_BLAS_N >= 16
855  case 16: multiBlas<16, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z); break;
856 #endif // 16
857  #endif // 15
858  #endif // 14
859  #endif // 13
860  #endif // 12
861  #endif // 11
862  #endif // 10
863  #endif // 9
864  #endif // 8
865  #endif // 7
866  #endif // 6
867  #endif // 5
868  #endif // 4
869  #endif // 3
870  #endif // 2
871  default:
872  // split the problem in half and recurse
873  const Complex *a0 = &a_[0];
874  const Complex *a1 = &a_[(x.size()/2)*y.size()];
875 
876  std::vector<ColorSpinorField*> x0(x.begin(), x.begin() + x.size()/2);
877  std::vector<ColorSpinorField*> x1(x.begin() + x.size()/2, x.end());
878 
879  caxpyz_recurse(a0, x0, y, z, 2*i+0, j, pass, upper);
880  caxpyz_recurse(a1, x1, y, z, 2*i+1, j, pass, upper); // b/c we don't want to re-zero z.
881  break;
882  }
883  }
884  else // precisions don't agree.
885  {
886  switch (x.size()) {
887  case 1: mixedMultiBlas<1, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z); break;
888 #if MAX_MULTI_BLAS_N >= 2
889  case 2: mixedMultiBlas<2, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z); break;
890 #if MAX_MULTI_BLAS_N >= 3
891  case 3: mixedMultiBlas<3, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z); break;
892 #if MAX_MULTI_BLAS_N >= 4
893  case 4: mixedMultiBlas<4, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z); break;
894 #if MAX_MULTI_BLAS_N >= 5
895  case 5: mixedMultiBlas<5, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z); break;
896 #if MAX_MULTI_BLAS_N >= 6
897  case 6: mixedMultiBlas<6, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z); break;
898 #if MAX_MULTI_BLAS_N >= 7
899  case 7: mixedMultiBlas<7, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z); break;
900 #if MAX_MULTI_BLAS_N >= 8
901  case 8: mixedMultiBlas<8, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z); break;
902 #if MAX_MULTI_BLAS_N >= 9
903  case 9: mixedMultiBlas<9, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z); break;
904 #if MAX_MULTI_BLAS_N >= 10
905  case 10: mixedMultiBlas<10, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z); break;
906 #if MAX_MULTI_BLAS_N >= 11
907  case 11: mixedMultiBlas<11, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z); break;
908 #if MAX_MULTI_BLAS_N >= 12
909  case 12: mixedMultiBlas<12, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z); break;
910 #if MAX_MULTI_BLAS_N >= 13
911  case 13: mixedMultiBlas<13, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z); break;
912 #if MAX_MULTI_BLAS_N >= 14
913  case 14: mixedMultiBlas<14, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z); break;
914 #if MAX_MULTI_BLAS_N >= 15
915  case 15: mixedMultiBlas<15, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z); break;
916 #if MAX_MULTI_BLAS_N >= 16
917  case 16: mixedMultiBlas<16, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z); break;
918 #endif // 16
919  #endif // 15
920  #endif // 14
921  #endif // 13
922  #endif // 12
923  #endif // 11
924  #endif // 10
925  #endif // 9
926  #endif // 8
927  #endif // 7
928  #endif // 6
929  #endif // 5
930  #endif // 4
931  #endif // 3
932  #endif // 2
933  default:
934  // split the problem in half and recurse
935  const Complex *a0 = &a_[0];
936  const Complex *a1 = &a_[(x.size()/2)*y.size()];
937 
938  std::vector<ColorSpinorField*> x0(x.begin(), x.begin() + x.size()/2);
939  std::vector<ColorSpinorField*> x1(x.begin() + x.size()/2, x.end());
940 
941  caxpyz_recurse(a0, x0, y, z, 2*i+0, j, pass, upper);
942  caxpyz_recurse(a1, x1, y, z, 2*i+1, j, pass, upper);
943  break;
944  }
945  }
946  } // end if (y.size() > MAX_MULTI_BLAS_N)
947  }
948 
/**
   Block caxpyz with a full (non-triangular) coefficient matrix a.

   The work is done in two passes of caxpyz_recurse:
     - pass 0: caxpyz on the diagonal blocks (writes z from y),
     - pass 1: caxpy on the off-diagonal blocks (accumulates into z).
   Pass 0 must therefore run before pass 1.
   The final argument (0) tells caxpyz_recurse not to skip any
   off-diagonal block.
*/
void caxpyz(const Complex *a, std::vector<ColorSpinorField*> &x, std::vector<ColorSpinorField*> &y, std::vector<ColorSpinorField*> &z) {
  constexpr int full_matrix = 0;
  for (int pass = 0; pass < 2; ++pass) caxpyz_recurse(a, x, y, z, 0, 0, pass, full_matrix);
}
955 
/**
   Block caxpyz where the coefficient matrix a is upper triangular.

   Two passes of caxpyz_recurse:
     - pass 0: caxpyz on the diagonal blocks,
     - pass 1: caxpy on the off-diagonal blocks.
   Passing upper = 1 makes caxpyz_recurse skip the lower-left blocks
   (those with j < i), which need no update for an upper-triangular a.
*/
void caxpyz_U(const Complex *a, std::vector<ColorSpinorField*> &x, std::vector<ColorSpinorField*> &y, std::vector<ColorSpinorField*> &z) {
  constexpr int upper_triangular = 1;
  for (int pass = 0; pass < 2; ++pass) caxpyz_recurse(a, x, y, z, 0, 0, pass, upper_triangular);
}
963 
/**
   Block caxpyz where the coefficient matrix a is LOWER triangular.

   Two passes of caxpyz_recurse:
     - pass 0: caxpyz on the diagonal blocks,
     - pass 1: caxpy on the off-diagonal blocks.
   Passing upper = -1 makes caxpyz_recurse skip the upper-right blocks
   (those with i < j), which need no update for a lower-triangular a.
*/
void caxpyz_L(const Complex *a, std::vector<ColorSpinorField*> &x, std::vector<ColorSpinorField*> &y, std::vector<ColorSpinorField*> &z) {
  constexpr int lower_triangular = -1;
  for (int pass = 0; pass < 2; ++pass) caxpyz_recurse(a, x, y, z, 0, 0, pass, lower_triangular);
}
971 
972 
974  caxpyz(a, x.Components(), y.Components(), z.Components());
975  }
976 
978  caxpyz_U(a, x.Components(), y.Components(), z.Components());
979  }
980 
982  caxpyz_L(a, x.Components(), y.Components(), z.Components());
983  }
984 
/**
   Multi-vector axpyBzpcx over the pairs (x_[k], y_[k]), all sharing a single z_.

   Judging by the name, this presumably computes x_[k] += a_[k]*y_[k] and
   y_[k] = b_[k]*z_ + c_[k]*x_[k]. TODO: confirm against multi_axpyBzpcx_.
   Both x_ and y_ are written (they are mapped to the writable "w" and "y"
   operands via write<0, 1, 0, 1>); z_ is only read.

   @param a_  coefficients, indexed like y_ (length >= y_.size())
   @param x_  first set of fields; must have the same size as y_
              (the block widths of y and w must match)
   @param y_  second set of fields
   @param b_  coefficients, indexed like y_ (length >= y_.size())
   @param z_  shared third field
   @param c_  coefficients, indexed like y_ (length >= y_.size())

   If y_.size() exceeds MAX_MULTI_BLAS_N, the sets are split in half and
   each half is processed recursively. The coefficient pointers are offset
   to match each half.
*/
void axpyBzpcx(const double *a_, std::vector<ColorSpinorField*> &x_, std::vector<ColorSpinorField*> &y_,
               const double *b_, ColorSpinorField &z_, const double *c_) {

  // Guard: y[0] is dereferenced below, which is undefined behaviour on an
  // empty vector. With no pairs there is no work to launch.
  if (y_.empty()) return;

  if (y_.size() <= MAX_MULTI_BLAS_N) {
    // Swizzle the operand order. We are writing to x_ and y_, but the
    // multi-blas only allows writing to y and w. Moreover, the block width
    // of y and w must match, and so must that of x and z.
    std::vector<ColorSpinorField*> &y = y_;
    std::vector<ColorSpinorField*> &w = x_;

    // wrap a container around the third solo vector
    std::vector<ColorSpinorField*> x;
    x.push_back(&z_);

    // We will curry the parameter arrays into the functor
    // (flag false: not staged through constant memory).
    coeff_array<double> a(a_, false), b(b_, false), c(c_, false);

    if (x[0]->Precision() != y[0]->Precision()) {
      mixedMultiBlas<1, multi_axpyBzpcx_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w);
    } else {
      multiBlas<1, multi_axpyBzpcx_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w);
    }
  } else {
    // Split the problem in half and recurse. The coefficients are indexed
    // like y_, so the second half starts at offset y_half.
    const size_t x_half = x_.size() / 2;
    const size_t y_half = y_.size() / 2;

    std::vector<ColorSpinorField*> x0(x_.begin(), x_.begin() + x_half);
    std::vector<ColorSpinorField*> y0(y_.begin(), y_.begin() + y_half);

    axpyBzpcx(&a_[0], x0, y0, &b_[0], z_, &c_[0]);

    std::vector<ColorSpinorField*> x1(x_.begin() + x_half, x_.end());
    std::vector<ColorSpinorField*> y1(y_.begin() + y_half, y_.end());

    axpyBzpcx(&a_[y_half], x1, y1, &b_[y_half], z_, &c_[y_half]);
  }
}
1028 
/**
   Multi-vector caxpyBxpz. Every x_[i] is combined, with coefficients a_[i]
   and b_[i], into the single pair of fields y_ and z_.

   y_ and z_ are the written fields (they are mapped to the writable "y" and
   "w" operands of multi_caxpyBxpz_ via write<0, 1, 0, 1>); x_ is only read.

   @param a_  coefficients, one per x_ entry (length >= x_.size())
   @param x_  input fields
   @param y_  first output field
   @param b_  coefficients, one per x_ entry (length >= x_.size())
   @param z_  second output field

   If x_.size() exceeds MAX_MULTI_BLAS_N, x_ is split in half and each half
   is applied in turn to the same y_ and z_. This presumably relies on the
   functor accumulating into y_ and z_. TODO: confirm in multi_caxpyBxpz_.
*/
void caxpyBxpz(const Complex *a_, std::vector<ColorSpinorField*> &x_, ColorSpinorField &y_,
               const Complex *b_, ColorSpinorField &z_)
{
  // Guard: x[0] is dereferenced below, which is undefined behaviour on an
  // empty vector. With no input fields there is no work to launch.
  if (x_.empty()) return;

  const int xsize = x_.size();
  if (xsize <= MAX_MULTI_BLAS_N) // only swizzle if we have to.
  {
    // Swizzle the operand order. We are writing to y_ and z_, but the
    // multi-blas only allows writing to y and w. Moreover, the block width
    // of y and w must match, and so must that of x and z.
    // Also, wrap a container around them.
    std::vector<ColorSpinorField*> y;
    y.push_back(&y_);
    std::vector<ColorSpinorField*> w;
    w.push_back(&z_);

    // we're reading from x
    std::vector<ColorSpinorField*> &x = x_;

    // put a and b into constant space
    coeff_array<Complex> a(a_, true), b(b_, true), c;

    if (x[0]->Precision() != y[0]->Precision())
    {
      switch (xsize)
      {
      case 1: mixedMultiBlas<1, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w); break;
#if MAX_MULTI_BLAS_N >= 2
      case 2: mixedMultiBlas<2, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w); break;
#if MAX_MULTI_BLAS_N >= 3
      case 3: mixedMultiBlas<3, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w); break;
#if MAX_MULTI_BLAS_N >= 4
      case 4: mixedMultiBlas<4, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w); break;
#if MAX_MULTI_BLAS_N >= 5
      case 5: mixedMultiBlas<5, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w); break;
#if MAX_MULTI_BLAS_N >= 6
      case 6: mixedMultiBlas<6, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w); break;
#if MAX_MULTI_BLAS_N >= 7
      case 7: mixedMultiBlas<7, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w); break;
#if MAX_MULTI_BLAS_N >= 8
      case 8: mixedMultiBlas<8, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w); break;
#if MAX_MULTI_BLAS_N >= 9
      case 9: mixedMultiBlas<9, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w); break;
#if MAX_MULTI_BLAS_N >= 10
      case 10: mixedMultiBlas<10, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w); break;
#if MAX_MULTI_BLAS_N >= 11
      case 11: mixedMultiBlas<11, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w); break;
#if MAX_MULTI_BLAS_N >= 12
      case 12: mixedMultiBlas<12, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w); break;
#if MAX_MULTI_BLAS_N >= 13
      case 13: mixedMultiBlas<13, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w); break;
#if MAX_MULTI_BLAS_N >= 14
      case 14: mixedMultiBlas<14, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w); break;
#if MAX_MULTI_BLAS_N >= 15
      case 15: mixedMultiBlas<15, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w); break;
#if MAX_MULTI_BLAS_N >= 16
      case 16: mixedMultiBlas<16, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w); break;
#endif // 16
#endif // 15
#endif // 14
#endif // 13
#endif // 12
#endif // 11
#endif // 10
#endif // 9
#endif // 8
#endif // 7
#endif // 6
#endif // 5
#endif // 4
#endif // 3
#endif // 2
      default:
        // Unreachable: 1 <= xsize <= MAX_MULTI_BLAS_N is guaranteed here.
        // Larger sizes take the split-and-recurse branch below.
        break;
      }
    }
    else
    {
      switch (xsize)
      {
      case 1: multiBlas<1, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w); break;
#if MAX_MULTI_BLAS_N >= 2
      case 2: multiBlas<2, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w); break;
#if MAX_MULTI_BLAS_N >= 3
      case 3: multiBlas<3, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w); break;
#if MAX_MULTI_BLAS_N >= 4
      case 4: multiBlas<4, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w); break;
#if MAX_MULTI_BLAS_N >= 5
      case 5: multiBlas<5, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w); break;
#if MAX_MULTI_BLAS_N >= 6
      case 6: multiBlas<6, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w); break;
#if MAX_MULTI_BLAS_N >= 7
      case 7: multiBlas<7, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w); break;
#if MAX_MULTI_BLAS_N >= 8
      case 8: multiBlas<8, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w); break;
#if MAX_MULTI_BLAS_N >= 9
      case 9: multiBlas<9, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w); break;
#if MAX_MULTI_BLAS_N >= 10
      case 10: multiBlas<10, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w); break;
#if MAX_MULTI_BLAS_N >= 11
      case 11: multiBlas<11, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w); break;
#if MAX_MULTI_BLAS_N >= 12
      case 12: multiBlas<12, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w); break;
#if MAX_MULTI_BLAS_N >= 13
      case 13: multiBlas<13, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w); break;
#if MAX_MULTI_BLAS_N >= 14
      case 14: multiBlas<14, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w); break;
#if MAX_MULTI_BLAS_N >= 15
      case 15: multiBlas<15, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w); break;
#if MAX_MULTI_BLAS_N >= 16
      case 16: multiBlas<16, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w); break;
#endif // 16
#endif // 15
#endif // 14
#endif // 13
#endif // 12
#endif // 11
#endif // 10
#endif // 9
#endif // 8
#endif // 7
#endif // 6
#endif // 5
#endif // 4
#endif // 3
#endif // 2
      default:
        // Unreachable: 1 <= xsize <= MAX_MULTI_BLAS_N is guaranteed here.
        // Larger sizes take the split-and-recurse branch below.
        break;
      }
    }
  } else {
    // Split the problem in half and recurse. Coefficients are indexed like
    // x_, so the second half starts at offset half.
    const size_t half = x_.size() / 2;

    std::vector<ColorSpinorField*> x0(x_.begin(), x_.begin() + half);

    caxpyBxpz(&a_[0], x0, y_, &b_[0], z_);

    std::vector<ColorSpinorField*> x1(x_.begin() + half, x_.end());

    caxpyBxpz(&a_[half], x1, y_, &b_[half], z_);
  }
}
1178 
1179  } // namespace blas
1180 
1181 } // namespace quda
void caxpyz(const Complex *a, std::vector< ColorSpinorField *> &x, std::vector< ColorSpinorField *> &y, std::vector< ColorSpinorField *> &z)
Compute the block "caxpyz" over the set of ColorSpinorFields. E.g., it computes.
void caxpyz_U(const Complex *a, std::vector< ColorSpinorField *> &x, std::vector< ColorSpinorField *> &y, std::vector< ColorSpinorField *> &z)
Compute the block "caxpyz" over the set of ColorSpinorFields. E.g., it computes.
const coeff_array< T > & c
static __constant__ signed char Cmatrix_d[MAX_MATRIX_SIZE]
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:21
SpinorY Y[MAX_MULTI_BLAS_N]
#define errorQuda(...)
Definition: util_quda.h:121
Parameter struct for generic multi-blas kernel.
Helper file when using jitify run-time compilation. This file should be included in source code...
static __constant__ signed char Amatrix_d[MAX_MATRIX_SIZE]
cudaStream_t * stream
CompositeColorSpinorField & Components()
void set(const cudaColorSpinorField &x)
Definition: texture.h:321
void caxpy_U(const Complex *a, std::vector< ColorSpinorField *> &x, std::vector< ColorSpinorField *> &y)
Compute the block "caxpy_U" over the set of ColorSpinorFields. E.g., it computes.
void initTuneParam(TuneParam &param) const
Definition: tune_quda.h:466
std::vector< ColorSpinorField * > & z
int length[]
static constexpr int Y
int Nspin
Definition: blas_test.cu:45
cpuGaugeField * Y_h
void caxpyBxpz(const Complex &, ColorSpinorField &, ColorSpinorField &, const Complex &, ColorSpinorField &)
Definition: blas_quda.cu:574
QudaGaugeParam param
Definition: pack_test.cpp:17
cudaStream_t * getStream()
Definition: blas_quda.cu:494
std::vector< ColorSpinorField * > CompositeColorSpinorField
void defaultTuneParam(TuneParam &param) const
void apply(const cudaStream_t &stream)
void mixedMultiBlas(const coeff_array< T > &a, const coeff_array< T > &b, const coeff_array< T > &c, CompositeColorSpinorField &x, CompositeColorSpinorField &y, CompositeColorSpinorField &z, CompositeColorSpinorField &w)
MultiBlasArg< NXZ, SpinorX, SpinorY, SpinorZ, SpinorW, Functor > arg
static constexpr int Z
void initTuneParam(TuneParam &param) const
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
Definition: tune.cpp:643
void caxpy_recurse(const Complex *a_, std::vector< ColorSpinorField *> &x, std::vector< ColorSpinorField *> &y, int i_idx, int j_idx, int upper)
#define checkLocation(...)
void axpyBzpcx(double a, ColorSpinorField &x, ColorSpinorField &y, double b, ColorSpinorField &z, double c)
Definition: blas_quda.cu:541
static signed char * Bmatrix_h
void defaultTuneParam(TuneParam &param) const
Definition: tune_quda.h:474
static constexpr int W
static __constant__ signed char Bmatrix_d[MAX_MATRIX_SIZE]
std::complex< double > Complex
Definition: quda_internal.h:46
#define MAX_MATRIX_SIZE
long long bytes() const
static constexpr int X
void caxpy(const Complex &a, ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.cu:512
bool tuneSharedBytes() const
void caxpy_L(const Complex *a, std::vector< ColorSpinorField *> &x, std::vector< ColorSpinorField *> &y)
Compute the block "caxpy_L" over the set of ColorSpinorFields. E.g., it computes.
long long flops() const
void set(const cudaColorSpinorField &x, int nFace=1)
Definition: texture.h:196
void multiBlas(const coeff_array< T > &a, const coeff_array< T > &b, const coeff_array< T > &c, std::vector< ColorSpinorField *> &x, std::vector< ColorSpinorField *> &y, std::vector< ColorSpinorField *> &z, std::vector< ColorSpinorField *> &w, int length)
void caxpyz_L(const Complex *a, std::vector< ColorSpinorField *> &x, std::vector< ColorSpinorField *> &y, std::vector< ColorSpinorField *> &z)
Compute the block "caxpyz" over the set of ColorSpinorFields. E.g., it computes.
int Ncolor
Definition: blas_test.cu:46
static __constant__ signed char arg_buffer[MAX_MATRIX_SIZE]
unsigned long long flops
Definition: blas_quda.cu:22
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
TuneKey tuneKey() const
SpinorW W[MAX_MULTI_BLAS_N]
#define checkCudaError()
Definition: util_quda.h:161
void caxpyz_recurse(const Complex *a_, std::vector< ColorSpinorField *> &x, std::vector< ColorSpinorField *> &y, std::vector< ColorSpinorField *> &z, int i, int j, int pass, int upper)
QudaTune getTuning()
Query whether autotuning is enabled or not. Default is enabled but can be overridden by setting QUDA_...
Definition: util_quda.cpp:52
__device__ unsigned int count[QUDA_MAX_MULTI_REDUCE]
Definition: cub_helper.cuh:90
static signed char * Amatrix_h
static signed char * Cmatrix_h
static const int name_n
Definition: tune_key.h:11
unsigned long long bytes
Definition: blas_quda.cu:23
MultiBlas(SpinorX X[], SpinorY Y[], SpinorZ Z[], SpinorW W[], Functor &f, const coeff_array< T > &a, const coeff_array< T > &b, const coeff_array< T > &c, std::vector< ColorSpinorField *> &x, std::vector< ColorSpinorField *> &y, std::vector< ColorSpinorField *> &z, std::vector< ColorSpinorField *> &w, int NYW, int length)
#define MAX_MULTI_BLAS_N