3 #include <cstring> // needed for memset
7 #include <color_spinor_field.h>
9 #include <jitify_helper.cuh>
10 #include <kernels/multi_blas_core.cuh>
// Forward declaration of getStream(); its definition is not part of this chunk.
// Returns a pointer to a qudaStream_t (what it points at -- e.g. a stream array -- should be
// confirmed at the definition).
16 qudaStream_t* getStream();
// Tunable launcher for the fused multi-vector BLAS kernel multiBlasKernel.
//   Functor   - per-site operation template; instantiated on the device real type in compute()
//   store_t   - storage type of x, z and w (its size is checked against their precision)
//   y_store_t - storage type of y (may differ from store_t)
//   nSpin     - spin count, used to size the per-thread site-unrolled vectors
//   T         - coefficient container type for a, b, c (exposes `data` and a nested `type`)
// The TunableVectorY base is constructed with y.size(); the tuner additionally explores a
// warp-split factor stored in TuneParam::aux.x.
18 template <template <typename ...> class Functor, typename store_t, typename y_store_t, int nSpin, typename T>
19 class MultiBlas : public TunableVectorY
// host real type associated with y's storage type
21 using real = typename mapper<y_store_t>::type;
// current warp-split factor; mutable because the const tuning hook advanceAux() updates it
26 mutable int warp_split; // helper used to keep track of current warp splitting
// non-owning references to the caller's vector sets; passed to MultiBlasArg in compute()
29 std::vector<ColorSpinorField *> &x, &y, &z, &w;
// common location of x, y, z, w (validated by checkLocation in the constructor)
30 const QudaFieldLocation location;
bool tuneSharedBytes() const
{
  // The shared-memory size is never a tuning dimension for this launcher.
  return false;
}
// Streaming kernels gain nothing from grid-size tuning, so the smallest grid the tuner may
// consider is pinned to the largest one.
unsigned int minGridSize() const
{
  return maxGridSize();
}
// Validates the vector sets and prepares the tuning state.
//   a, b, c        - coefficient containers; raw host pointers to their data are recorded
//   x_meta, y_meta - not referenced in the lines of this constructor visible here
//   x, y, z, w     - vector sets; only element 0 of each is checked below
// Calls errorQuda if the sizes of store_t / y_store_t disagree with the field precisions, or if
// x and y share a precision but not a field order.
38 MultiBlas(const T &a, const T &b, const T &c, const ColorSpinorField &x_meta, const ColorSpinorField &y_meta,
39 std::vector<ColorSpinorField *> &x, std::vector<ColorSpinorField *> &y,
40 std::vector<ColorSpinorField *> &z, std::vector<ColorSpinorField *> &w) :
41 TunableVectorY(y.size()),
// parity count from x[0]'s site subset; later used for grid.z and for the per-parity length
46 nParity(x[0]->SiteSubset()),
// x, y, z and w must all live in the same location
54 location(checkLocation(*x[0], *y[0], *z[0], *w[0]))
56 checkLength(*x[0], *y[0], *z[0], *w[0]);
// x, z and w are checked to share precision and order; y may differ (mixed precision)
57 auto x_prec = checkPrecision(*x[0], *z[0], *w[0]);
58 auto y_prec = y[0]->Precision();
59 auto x_order = checkOrder(*x[0], *z[0], *w[0]);
60 auto y_order = y[0]->FieldOrder();
61 if (sizeof(store_t) != x_prec) errorQuda("Expected precision %lu but received %d", sizeof(store_t), x_prec);
62 if (sizeof(y_store_t) != y_prec) errorQuda("Expected precision %lu but received %d", sizeof(y_store_t), y_prec);
63 if (x_prec == y_prec && x_order != y_order) errorQuda("Orders %d %d do not match", x_order, y_order);
65 // heuristic for enabling if we need the warp-splitting optimization
// 2 * maxThreadsPerBlock * multiProcessorCount serves as a proxy for device thread capacity;
// the less work (x[0]->Length() * NYW) relative to it, the more ways splitting is allowed
66 const int gpu_size = 2 * deviceProp.maxThreadsPerBlock * deviceProp.multiProcessorCount;
67 switch (gpu_size / (x[0]->Length() * NYW)) {
68 case 0: max_warp_split = 1; break; // we have plenty of work, no need to split
69 case 1: max_warp_split = 2; break; // double the thread count
70 case 2: // quadruple the thread count
// case 2 falls through: any ratio >= 2 allows 4-way splitting
71 default: max_warp_split = 4;
73 max_warp_split = std::min(NXZ, max_warp_split); // ensure we only split if valid
// byte-level views of the host coefficient arrays (const_cast because a, b, c are const
// references); their consumers are not visible in this chunk
75 Amatrix_h = reinterpret_cast<signed char *>(const_cast<typename T::type *>(a.data));
76 Bmatrix_h = reinterpret_cast<signed char *>(const_cast<typename T::type *>(b.data));
77 Cmatrix_h = reinterpret_cast<signed char *>(const_cast<typename T::type *>(c.data));
// tuning aux string: x[0]'s aux string, extended with y[0]'s when the precisions differ
79 strcpy(aux, x[0]->AuxString());
80 if (x_prec != y_prec) {
82 strcat(aux, y[0]->AuxString());
// JIT-compile the kernel header (presumably only in JITIFY builds; the guard is not visible here)
86 ::quda::create_jitify_program("kernels/multi_blas_core.cuh");
// add this operation's byte and flop counts to the global blas counters
91 blas::bytes += bytes();
92 blas::flops += flops();
// Tuning key: x[0]'s volume string, a name that contains NXZ, NYW and the functor's typeid
// name, and the aux string built in the constructor. The NXZ_str / NYW_str buffers and the
// start of the name are set up on lines not visible here.
95 TuneKey tuneKey() const
97 char name[TuneKey::name_n];
100 u32toa(NXZ_str, NXZ);
101 u32toa(NYW_str, NYW);
103 strcat(name, NXZ_str);
105 strcat(name, NYW_str);
106 strcat(name, typeid(f).name());
107 return TuneKey(x[0]->VolString(), name, aux);
// multi-1d overload: copies the coefficients in h, converted to coeff_t, into the storage
// referenced by arg.f.a / arg.f.b / arg.f.c (chosen by `select`). Copies max(NXZ, NYW) entries.
// An unknown `select` value is reported through errorQuda.
110 template <bool multi_1d, typename device_buffer_t, typename Arg> typename std::enable_if<multi_1d, void>::type
111 set_param(device_buffer_t &&buf_d, Arg &arg, char select, const T &h, const qudaStream_t &stream)
113 using coeff_t = typename decltype(arg.f)::coeff_t;
114 coeff_t *buf_arg = nullptr;
116 case 'a': buf_arg = arg.f.a; break;
117 case 'b': buf_arg = arg.f.b; break;
118 case 'c': buf_arg = arg.f.c; break;
119 default: errorQuda("Unknown buffer %c", select);
// a 1-d coefficient set has one entry per vector of the larger block
121 const auto N = std::max(NXZ,NYW);
122 for (int i = 0; i < N; i++) buf_arg[i] = coeff_t(h.data[i]);
// Matrix overload: converts the row-major NXZ x NYW coefficient matrix in h to coeff_t in a
// host staging array sized MAX_MATRIX_SIZE bytes. It then copies NXZ * NYW elements
// asynchronously on `stream` to buf_d, using one of two paths:
//   - cuMemcpyHtoDAsync when buf_d is a jitify constant pointer;
//   - cudaMemcpyToSymbolAsync when buf_d is a device symbol.
// The preprocessor branch that selects the path is not visible here. `dummy` is not referenced
// in the visible lines; the caller already picks the target buffer.
// NOTE(review): confirm that NXZ * NYW <= n_coeff is enforced somewhere (no check in the visible
// lines); otherwise tmp would overflow silently.
// NOTE(review): tmp is pageable stack memory. The CUDA sync-behavior docs say async H2D copies
// from pageable memory are staged before the call returns -- confirm this holds for both paths.
125 template <bool multi_1d, typename device_buffer_t, typename Arg> typename std::enable_if<!multi_1d, void>::type
126 set_param(device_buffer_t &&buf_d, Arg &arg, char dummy, const T &h, const qudaStream_t &stream)
128 using coeff_t = typename decltype(arg.f)::coeff_t;
129 constexpr size_t n_coeff = MAX_MATRIX_SIZE / sizeof(coeff_t);
131 coeff_t tmp[n_coeff];
// row i (x index) has stride NYW; column j is the y index
132 for (int i = 0; i < NXZ; i++)
133 for (int j = 0; j < NYW; j++) tmp[NYW * i + j] = coeff_t(h.data[NYW * i + j]);
136 cuMemcpyHtoDAsync(buf_d, tmp, NXZ * NYW * sizeof(coeff_t), stream);
138 cudaMemcpyToSymbolAsync(buf_d, tmp, NXZ * NYW * sizeof(coeff_t), 0, cudaMemcpyHostToDevice, stream);
// Launches multiBlasKernel on `stream` with a compile-time NXZ; NYW stays a runtime value.
// Only QUDA_CUDA_FIELD_LOCATION fields are supported; any other location ends in errorQuda.
142 template <int NXZ> void compute(const qudaStream_t &stream)
144 staticCheck<NXZ, store_t, y_store_t, decltype(f)>(f, x, y);
// Site unrolling is required when x and y use different storage types or store_t is fixed
// point. It is only supported for nColor == 3 with nSpin != 2.
146 constexpr bool site_unroll_check = !std::is_same<store_t, y_store_t>::value || isFixed<store_t>::value;
147 if (site_unroll_check && (x[0]->Ncolor() != 3 || x[0]->Nspin() == 2))
148 errorQuda("site unroll not supported for nSpin = %d nColor = %d", x[0]->Nspin(), x[0]->Ncolor());
// launch configuration from the tuner; aux.x carries the warp-split factor
150 TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
152 if (location == QUDA_CUDA_FIELD_LOCATION) {
153 if (site_unroll_check) checkNative(*x[0], *y[0], *z[0], *w[0]); // require native order when using site_unroll
// map the host storage types to the types used on the device
154 using device_store_t = typename device_type_mapper<store_t>::type;
155 using device_y_store_t = typename device_type_mapper<y_store_t>::type;
156 using device_real_t = typename mapper<device_y_store_t>::type;
// functor instance handed to the kernel through MultiBlasArg
157 Functor<device_real_t> f_(NXZ, NYW);
159 // redefine site_unroll with device_store types to ensure we have correct N/Ny/M values
160 constexpr bool site_unroll = !std::is_same<device_store_t, device_y_store_t>::value || isFixed<device_store_t>::value;
161 constexpr int N = n_vector<device_store_t, true, nSpin, site_unroll>();
162 constexpr int Ny = n_vector<device_y_store_t, true, nSpin, site_unroll>();
163 constexpr int M = site_unroll ? (nSpin == 4 ? 24 : 6) : N; // real numbers per thread
// work items per parity: x[0]'s length divided by the parity count and by reals per thread
164 const int length = x[0]->Length() / (nParity * M);
166 tp.block.x *= tp.aux.x; // include warp-split factor
168 MultiBlasArg<NXZ, device_store_t, N, device_y_store_t, Ny, decltype(f_)> arg(x, y, z, w, f_, NYW, length);
// JIT path (the preprocessor branch selecting it is on lines not visible here): instantiate
// the kernel, upload each non-null coefficient array to its constant buffer, then launch
170 using namespace jitify::reflection;
171 auto instance = program->kernel("quda::blas::multiBlasKernel")
172 .instantiate(Type<device_real_t>(), M, NXZ, tp.aux.x, Type<decltype(arg)>());
174 if (a.data) set_param<decltype(f_)::multi_1d>(instance.get_constant_ptr("quda::blas::Amatrix_d"), arg, 'a', a, stream);
175 if (b.data) set_param<decltype(f_)::multi_1d>(instance.get_constant_ptr("quda::blas::Bmatrix_d"), arg, 'b', b, stream);
176 if (c.data) set_param<decltype(f_)::multi_1d>(instance.get_constant_ptr("quda::blas::Cmatrix_d"), arg, 'c', c, stream);
178 jitify_error = instance.configure(tp.grid, tp.block, tp.shared_bytes, stream).launch(arg);
// precompiled path: upload the non-null coefficient arrays, then launch the kernel
// instantiation whose warp-split template argument matches tp.aux.x (cases 1, 2 and 4)
180 if (a.data) { set_param<decltype(f_)::multi_1d>(Amatrix_d, arg, 'a', a, stream); }
181 if (b.data) { set_param<decltype(f_)::multi_1d>(Bmatrix_d, arg, 'b', b, stream); }
182 if (c.data) { set_param<decltype(f_)::multi_1d>(Cmatrix_d, arg, 'c', c, stream); }
184 case 1: qudaLaunchKernel(multiBlasKernel<device_real_t, M, NXZ, 1, decltype(arg)>, tp, stream, arg); break;
186 case 2: qudaLaunchKernel(multiBlasKernel<device_real_t, M, NXZ, 2, decltype(arg)>, tp, stream, arg); break;
187 case 4: qudaLaunchKernel(multiBlasKernel<device_real_t, M, NXZ, 4, decltype(arg)>, tp, stream, arg); break;
189 default: errorQuda("warp-split factor %d not instantiated", tp.aux.x);
193 tp.block.x /= tp.aux.x; // restore block size
// reached for non-device field locations (the enclosing else is on a line not visible here)
195 errorQuda("Only implemented for GPU fields");
// Linear search over the compile-time NXZ: tries n, n-1, ... and runs compute<n> once n equals
// the runtime NXZ.
199 template <int n> typename std::enable_if<n!=1, void>::type instantiateLinear(const qudaStream_t &stream)
201 if (NXZ == n) compute<n>(stream);
202 else instantiateLinear<n-1>(stream);
// Recursion terminator of the linear search (n == 1); its body is on lines not visible here.
205 template <int n> typename std::enable_if<n==1, void>::type instantiateLinear(const qudaStream_t &stream)
// Power-of-two search over the compile-time NXZ: tries n, n/2, n/4, ... and runs compute<n>
// once n equals the runtime NXZ.
210 template <int n> typename std::enable_if<n!=1, void>::type instantiatePow2(const qudaStream_t &stream)
212 if (NXZ == n) compute<n>(stream);
213 else instantiatePow2<n/2>(stream);
// Recursion terminator of the power-of-two search (n == 1); its body is on lines not visible here.
216 template <int n> typename std::enable_if<n==1, void>::type instantiatePow2(const qudaStream_t &stream)
221 // instantiate the loop unrolling template
// Picks the compile-time NXZ for compute():
//   - the power-of-two search when NXZ is a power of two no larger than pow2_max;
//   - otherwise the linear search up to linear_max.
// Larger x sets are reported through errorQuda.
222 template <int NXZ_max> typename std::enable_if<NXZ_max!=1, void>::type instantiate(const qudaStream_t &stream)
224 // if multi-1d then constrain the templates to no larger than max-1d size
225 constexpr int pow2_max = !decltype(f)::multi_1d ? max_NXZ_power2<false, isFixed<store_t>::value>() :
226 std::min(max_N_multi_1d(), max_NXZ_power2<false, isFixed<store_t>::value>());
227 constexpr int linear_max = !decltype(f)::multi_1d ? MAX_MULTI_BLAS_N : std::min(max_N_multi_1d(), MAX_MULTI_BLAS_N);
229 if (NXZ <= pow2_max && is_power2(NXZ)) instantiatePow2<pow2_max>(stream);
230 else if (NXZ <= linear_max) instantiateLinear<linear_max>(stream);
231 else errorQuda("x.size %lu greater than maximum supported size (pow2 = %d, linear = %d)", x.size(), pow2_max, linear_max);
// Overload for functors whose NXZ_max is 1; its body is on lines not visible here.
234 template <int NXZ_max> typename std::enable_if<NXZ_max==1, void>::type instantiate(const qudaStream_t &stream)
// Launch entry point: resolve the runtime NXZ to a compile-time value, bounded above by the
// functor's NXZ_max, and run the kernel on `stream`.
void apply(const qudaStream_t &stream)
{
  constexpr int nxz_upper_bound = decltype(f)::NXZ_max;
  instantiate<nxz_upper_bound>(stream);
}
// Backs up every vector set the functor writes, so tuning launches can be undone. The enclosing
// function header is not visible here; it is presumably preTune().
// NOTE(review): all four sets are indexed up to NYW, but per the call-site comments x and z
// share the NXZ block width. This is only safe while write.X / write.Z stay false, or while
// NXZ == NYW -- confirm.
243 for (int i = 0; i < NYW; ++i) {
244 if (f.write.X) x[i]->backup();
245 if (f.write.Y) y[i]->backup();
246 if (f.write.Z) z[i]->backup();
247 if (f.write.W) w[i]->backup();
// Restores every vector set the functor writes, undoing the tuning launches. The enclosing
// function header is not visible here; it is presumably postTune().
// NOTE(review): same NYW bound for x and z as in the backup loop -- safe only while
// write.X / write.Z stay false, or while NXZ == NYW.
253 for (int i = 0; i < NYW; ++i) {
254 if (f.write.X) x[i]->restore();
255 if (f.write.Y) y[i]->restore();
256 if (f.write.Z) z[i]->restore();
257 if (f.write.W) w[i]->restore();
// Tuner hook that steps the warp-split factor held in param.aux.x. The visible branch is taken
// while doubling the factor would stay within max_warp_split (chosen in the constructor). Both
// visible assignments synchronize warp_split with param.aux.x. resetBlockDim(param) is then
// called so that blockStep()/blockMin(), which depend on warp_split, take effect. The update of
// aux.x itself and the return values are on lines not visible here.
261 bool advanceAux(TuneParam &param) const
264 if (2 * param.aux.x <= max_warp_split) {
266 warp_split = param.aux.x;
270 warp_split = param.aux.x;
271 // reset the block dimension manually here to pick up the warp_split parameter
272 resetBlockDim(param);
int blockStep() const
{
  // block-size search step: one warp, divided by the current warp-split factor
  const int step = deviceProp.warpSize / warp_split;
  return step;
}
int blockMin() const
{
  // smallest block size considered: one warp, divided by the current warp-split factor
  const int smallest = deviceProp.warpSize / warp_split;
  return smallest;
}
// Initial tuning parameters: the TunableVectorY defaults, with grid.z spanning the parities and
// the warp-split factor (aux.x) starting at 1, i.e. no splitting.
// NOTE(review): warp_split itself is not reset in the visible lines, although blockStep()/blockMin()
// read it -- confirm it is resynchronized before the block size is chosen.
284 void initTuneParam(TuneParam &param) const
286 TunableVectorY::initTuneParam(param);
287 param.grid.z = nParity;
288 param.aux = make_int4(1, 0, 0, 0); // warp-split parameter
// Default (untuned) parameters: the TunableVectorY defaults, with grid.z spanning the parities
// and the warp-split factor (aux.x) set to 1, i.e. no splitting.
291 void defaultTuneParam(TuneParam &param) const
293 TunableVectorY::defaultTuneParam(param);
294 param.grid.z = nParity;
295 param.aux = make_int4(1, 0, 0, 0); // warp-split parameter
// Flop count: f.flops() for every (x, y) pair, per unit of x[0]->Length(). The constructor adds
// this to blas::flops.
298 long long flops() const
300 return NYW * NXZ * f.flops() * x[0]->Length();
// Byte count for one application, weighted by which sets the functor reads and writes. The
// constructor adds this to blas::bytes.
303 long long bytes() const
305 // X and Z reads are repeated (and hopefully cached) across NYW
306 // each Y and W read/write is done once
307 return NYW * NXZ * (f.read.X + f.write.X) * x[0]->Bytes() +
308 NYW * (f.read.Y + f.write.Y) * y[0]->Bytes() +
309 NYW * NXZ * (f.read.Z + f.write.Z) * z[0]->Bytes() +
310 NYW * (f.read.W + f.write.W) * w[0]->Bytes();
int tuningIter() const
{
  // Tunable hook: iteration count the tuner uses when timing a candidate (presumably per candidate)
  const int iterations = 3;
  return iterations;
}
// Half-open [first, second) index range of a tile within the full coefficient matrix.
typedef std::pair<size_t, size_t> range;
// Recursively tiles a block axpy-type update so that each tile fits a single MultiBlas launch.
//   a_          - row-major coefficient matrix with x.size() rows and y.size() columns
//   x, y        - vector sets of this tile (x indexes rows of a_, y indexes columns)
//   range_x/y   - this tile's global [first, second) row / column ranges in the full matrix
//   upper       - 0: every tile is applied. 1: skip tiles with range_y.first >= range_x.second.
//                 -1: skip tiles with range_x.first >= range_y.second.
//   coeff_width - forwarded to max_YW_size (callers pass 2 for Complex, 1 for double coefficients)
// Recursion order:
//   - too many y vectors for one kernel: split the columns in half;
//   - otherwise, if x.size() is a valid NXZ: launch;
//   - otherwise: split the rows in half.
318 template <template <typename...> class Functor, typename T>
319 void axpy_recurse(const T *a_, std::vector<ColorSpinorField *> &x, std::vector<ColorSpinorField *> &y,
320 const range &range_x, const range &range_y, int upper, int coeff_width)
322 // if greater than max single-kernel size, recurse
323 if (y.size() > (size_t)max_YW_size(x.size(), x[0]->Precision(), y[0]->Precision(), false, false, coeff_width, false)) {
324 // We need to split up 'a' carefully since it's row-major.
// Heap staging copy of a_, reordered so each column half is its own contiguous row-major block.
// NOTE(review): its release is on lines not visible here (presumably delete[] after the recursive
// calls) -- confirm, or consider a std::vector.
325 T *tmpmajor = new T[x.size() * y.size()];
326 T *tmpmajor0 = &tmpmajor[0];
327 T *tmpmajor1 = &tmpmajor[x.size() * (y.size() / 2)];
328 std::vector<ColorSpinorField*> y0(y.begin(), y.begin() + y.size()/2);
329 std::vector<ColorSpinorField*> y1(y.begin() + y.size()/2, y.end());
331 const unsigned int xlen = x.size();
332 const unsigned int ylen0 = y.size()/2;
333 const unsigned int ylen1 = y.size() - y.size()/2;
// walk a_ row by row; the first ylen0 entries of each row go to the left half, the rest to the right
335 int count = 0, count0 = 0, count1 = 0;
336 for (unsigned int i = 0; i < xlen; i++)
338 for (unsigned int j = 0; j < ylen0; j++)
339 tmpmajor0[count0++] = a_[count++];
340 for (unsigned int j = 0; j < ylen1; j++)
341 tmpmajor1[count1++] = a_[count++];
344 axpy_recurse<Functor>(tmpmajor0, x, y0, range_x, range(range_y.first, range_y.first + y0.size()), upper, coeff_width);
345 axpy_recurse<Functor>(tmpmajor1, x, y1, range_x, range(range_y.first + y0.size(), range_y.second), upper, coeff_width);
349 // if at the bottom of recursion,
350 if (is_valid_NXZ(x.size(), false, x[0]->Precision() < QUDA_SINGLE_PRECISION)) {
351 // since tile range is [first,second), i.e., [first,second-1], we need >= here
352 // if upper triangular and upper-right tile corner is below diagonal return
353 if (upper == 1 && range_y.first >= range_x.second) { return; }
354 // if lower triangular and lower-left tile corner is above diagonal return
355 if (upper == -1 && range_x.first >= range_y.second) { return; }
357 // mark true since we will copy the "a" matrix into constant memory
358 coeff_array<T> a(a_), b, c;
359 constexpr bool mixed = true;
360 instantiate<Functor, MultiBlas, mixed>(a, b, c, *x[0], *y[0], x, y, x, x);
362 // split the problem in half and recurse
// rows are contiguous in row-major order, so the second row half starts at row x.size()/2
363 const T *a0 = &a_[0];
364 const T *a1 = &a_[(x.size() / 2) * y.size()];
366 std::vector<ColorSpinorField *> x0(x.begin(), x.begin() + x.size() / 2);
367 std::vector<ColorSpinorField *> x1(x.begin() + x.size() / 2, x.end());
369 axpy_recurse<Functor>(a0, x0, y, range(range_x.first, range_x.first + x0.size()), range_y, upper, coeff_width);
370 axpy_recurse<Functor>(a1, x1, y, range(range_x.first + x0.size(), range_x.second), range_y, upper, coeff_width);
372 } // end if (y.size() > max_YW_size())
// Block caxpy over vector sets with functor multicaxpy_. The coefficients are the row-major
// x.size() x y.size() Complex matrix a_.
375 void caxpy(const Complex *a_, std::vector<ColorSpinorField*> &x, std::vector<ColorSpinorField*> &y) {
376 // Enter a recursion.
377 // Pass a, x, y. The ranges start at (0,0) and cover every tile; upper = 0 specifies the matrix is unstructured, and 2 is the Complex coefficient width.
378 axpy_recurse<multicaxpy_>(a_, x, y, range(0,x.size()), range(0,y.size()), 0, 2);
// Block caxpy for a square, upper-triangular Complex matrix a_ (row-major). Calls errorQuda if
// x and y have different sizes.
381 void caxpy_U(const Complex *a_, std::vector<ColorSpinorField*> &x, std::vector<ColorSpinorField*> &y) {
382 // Enter a recursion.
383 // Pass a, x, y. The ranges start at (0,0). 1 indicates the matrix is upper-triangular,
384 // which lets us skip some tiles.
385 if (x.size() != y.size())
387 errorQuda("An optimal block caxpy_U with non-square 'a' has not yet been implemented. Use block caxpy instead");
389 axpy_recurse<multicaxpy_>(a_, x, y, range(0,x.size()), range(0,y.size()), 1, 2);
// Block caxpy for a square, lower-triangular Complex matrix a_ (row-major). Calls errorQuda if
// x and y have different sizes.
392 void caxpy_L(const Complex *a_, std::vector<ColorSpinorField*> &x, std::vector<ColorSpinorField*> &y) {
393 // Enter a recursion.
394 // Pass a, x, y. The ranges start at (0,0). -1 indicates the matrix is lower-triangular,
395 // which lets us skip some tiles.
396 if (x.size() != y.size())
398 errorQuda("An optimal block caxpy_L with non-square 'a' has not yet been implemented. Use block caxpy instead");
400 axpy_recurse<multicaxpy_>(a_, x, y, range(0,x.size()), range(0,y.size()), -1, 2);
// Composite-field overload: runs the block caxpy on the component vectors of x and y.
void caxpy(const Complex *a, ColorSpinorField &x, ColorSpinorField &y)
{
  auto &x_parts = x.Components();
  auto &y_parts = y.Components();
  caxpy(a, x_parts, y_parts);
}
// Composite-field overload: runs the upper-triangular block caxpy on the component vectors of x and y.
void caxpy_U(const Complex *a, ColorSpinorField &x, ColorSpinorField &y)
{
  auto &x_parts = x.Components();
  auto &y_parts = y.Components();
  caxpy_U(a, x_parts, y_parts);
}
// Composite-field overload: runs the lower-triangular block caxpy on the component vectors of x and y.
void caxpy_L(const Complex *a, ColorSpinorField &x, ColorSpinorField &y)
{
  auto &x_parts = x.Components();
  auto &y_parts = y.Components();
  caxpy_L(a, x_parts, y_parts);
}
// Tiled driver for caxpyz. It uses the same row-major tiling of a_ (x.size() rows, y.size()
// columns) as axpy_recurse, with z split alongside y. The caxpyz wrappers call it twice:
//   - pass 0 handles the tiles that straddle the diagonal with the fused multicaxpyz_ kernel;
//   - pass 1 applies the unstructured caxpy(a_, x, z) to the off-diagonal tiles.
// upper = 1 / -1 skips off-diagonal tiles on one side, with the same conditions as axpy_recurse.
// The trailing parameters (pass, upper) and part of the branch structure are on lines not
// visible here.
// NOTE(review): z is split at z.size()/2 while a_ is split at y.size()/2; this assumes
// z.size() == y.size().
409 void caxpyz_recurse(const Complex *a_, std::vector<ColorSpinorField*> &x, std::vector<ColorSpinorField*> &y,
410 std::vector<ColorSpinorField*> &z, const range &range_x, const range &range_y,
413 // if greater than max single-kernel size, recurse
414 if (y.size() > (size_t)max_YW_size(x.size(), x[0]->Precision(), y[0]->Precision(), false, true, 2, false)) {
415 // We need to split up 'a' carefully since it's row-major.
// heap staging copy of a_ reordered into two contiguous column halves (its release is on lines
// not visible here)
416 Complex* tmpmajor = new Complex[x.size()*y.size()];
417 Complex* tmpmajor0 = &tmpmajor[0];
418 Complex* tmpmajor1 = &tmpmajor[x.size()*(y.size()/2)];
419 std::vector<ColorSpinorField*> y0(y.begin(), y.begin() + y.size()/2);
420 std::vector<ColorSpinorField*> y1(y.begin() + y.size()/2, y.end());
422 std::vector<ColorSpinorField*> z0(z.begin(), z.begin() + z.size()/2);
423 std::vector<ColorSpinorField*> z1(z.begin() + z.size()/2, z.end());
425 const unsigned int xlen = x.size();
426 const unsigned int ylen0 = y.size()/2;
427 const unsigned int ylen1 = y.size() - y.size()/2;
429 int count = 0, count0 = 0, count1 = 0;
430 for (unsigned int i_ = 0; i_ < xlen; i_++)
432 for (unsigned int j = 0; j < ylen0; j++)
433 tmpmajor0[count0++] = a_[count++];
434 for (unsigned int j = 0; j < ylen1; j++)
435 tmpmajor1[count1++] = a_[count++];
438 caxpyz_recurse(tmpmajor0, x, y0, z0, range_x, range(range_y.first, range_y.first + y0.size()), pass, upper);
439 caxpyz_recurse(tmpmajor1, x, y1, z1, range_x, range(range_y.first + y0.size(), range_y.second), pass, upper);
443 // if at bottom of recursion check where we are
444 if (is_valid_NXZ(x.size(), false, x[0]->Precision() < QUDA_SINGLE_PRECISION)) {
445 // check if tile straddles diagonal
// true when the tile's row range and column range overlap
446 bool is_diagonal = (range_x.first < range_y.second) && (range_y.first < range_x.second);
449 // if upper triangular and upper-right tile corner is below diagonal return
450 if (upper == 1 && range_y.first >= range_x.second) { return; }
451 // if lower triangular and lower-left tile corner is above diagonal return
452 if (upper == -1 && range_x.first >= range_y.second) { return; }
453 caxpy(a_, x, z); return; // off diagonal
457 if (!is_diagonal) return; // We're on the first pass, so we only want to update the diagonal.
460 coeff_array<Complex> a(a_), b, c;
461 constexpr bool mixed = false;
462 instantiate<multicaxpyz_, MultiBlas, mixed>(a, b, c, *x[0], *y[0], x, y, x, z);
464 // split the problem in half and recurse
465 const Complex *a0 = &a_[0];
466 const Complex *a1 = &a_[(x.size() / 2) * y.size()];
468 std::vector<ColorSpinorField *> x0(x.begin(), x.begin() + x.size() / 2);
469 std::vector<ColorSpinorField *> x1(x.begin() + x.size() / 2, x.end());
471 caxpyz_recurse(a0, x0, y, z, range(range_x.first, range_x.first + x0.size()), range_y, pass, upper);
472 caxpyz_recurse(a1, x1, y, z, range(range_x.first + x0.size(), range_x.second), range_y, pass, upper);
474 } // end if (y.size() > max_YW_size())
// Block caxpyz with an unstructured (upper = 0) row-major Complex matrix a of size
// x.size() x y.size(). It makes two passes over the full tile range.
477 void caxpyz(const Complex *a, std::vector<ColorSpinorField*> &x, std::vector<ColorSpinorField*> &y, std::vector<ColorSpinorField*> &z)
479 // first pass does the caxpyz on the diagonal
480 caxpyz_recurse(a, x, y, z, range(0, x.size()), range(0, y.size()), 0, 0);
481 // second pass does caxpy on the off diagonals
482 caxpyz_recurse(a, x, y, z, range(0, x.size()), range(0, y.size()), 1, 0);
// Block caxpyz for an upper-triangular (upper = 1) row-major Complex matrix a.
// NOTE(review): unlike caxpy_U, no check that x.size() == y.size() is visible here.
485 void caxpyz_U(const Complex *a, std::vector<ColorSpinorField*> &x, std::vector<ColorSpinorField*> &y, std::vector<ColorSpinorField*> &z)
487 // a is upper triangular.
488 // first pass does the caxpyz on the diagonal
489 caxpyz_recurse(a, x, y, z, range(0, x.size()), range(0, y.size()), 0, 1);
490 // second pass does caxpy on the off diagonals
491 caxpyz_recurse(a, x, y, z, range(0, x.size()), range(0, y.size()), 1, 1);
// Block caxpyz for a lower-triangular (upper = -1) row-major Complex matrix a.
// NOTE(review): unlike caxpy_L, no check that x.size() == y.size() is visible here.
494 void caxpyz_L(const Complex *a, std::vector<ColorSpinorField*> &x, std::vector<ColorSpinorField*> &y, std::vector<ColorSpinorField*> &z)
496 // a is lower triangular.
497 // first pass does the caxpyz on the diagonal
498 caxpyz_recurse(a, x, y, z, range(0, x.size()), range(0, y.size()), 0, -1);
499 // second pass does caxpy on the off diagonals
500 caxpyz_recurse(a, x, y, z, range(0, x.size()), range(0, y.size()), 1, -1);
// Composite-field overload: forwards the component vectors of x, y and z to the vector-set caxpyz.
504 void caxpyz(const Complex *a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
506 caxpyz(a, x.Components(), y.Components(), z.Components());
// Composite-field overload: forwards the component vectors of x, y and z to the vector-set caxpyz_U.
509 void caxpyz_U(const Complex *a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
511 caxpyz_U(a, x.Components(), y.Components(), z.Components());
// Composite-field overload: forwards the component vectors of x, y and z to the vector-set caxpyz_L.
514 void caxpyz_L(const Complex *a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
516 caxpyz_L(a, x.Components(), y.Components(), z.Components());
// Applies multi_axpyBzpcx_ across the vector sets x_ and y_ with per-vector coefficients
// a_[i], b_[i], c_[i] and the single field z_; it writes to x_ and y_.
// If y_ has more than max_N_multi_1d() vectors, both sets and the coefficient arrays are split
// in half and processed recursively.
// NOTE(review): x_ is split at x_.size()/2 but the coefficients at y_.size()/2, which assumes
// x_.size() == y_.size().
519 void axpyBzpcx(const double *a_, std::vector<ColorSpinorField *> &x_, std::vector<ColorSpinorField *> &y_,
520 const double *b_, ColorSpinorField &z_, const double *c_)
522 if (y_.size() <= (size_t)max_N_multi_1d()) {
523 // swizzle order since we are writing to x_ and y_, but the
524 // multi-blas only allow writing to y and w, and moreover the
525 // block width of y and w must match, and x and z must match.
526 std::vector<ColorSpinorField*> &y = y_;
527 std::vector<ColorSpinorField*> &w = x_;
529 // wrap a container around the third solo vector
// (z_ is presumably added to x on a line not visible here; x[0] is dereferenced below)
530 std::vector<ColorSpinorField*> x;
533 coeff_array<double> a(a_), b(b_), c(c_);
534 constexpr bool mixed = true;
535 instantiate<multi_axpyBzpcx_, MultiBlas, mixed>(a, b, c, *x[0], *y[0], x, y, x, w);
537 // split the problem in half and recurse
538 const double *a0 = &a_[0];
539 const double *b0 = &b_[0];
540 const double *c0 = &c_[0];
542 std::vector<ColorSpinorField*> x0(x_.begin(), x_.begin() + x_.size()/2);
543 std::vector<ColorSpinorField*> y0(y_.begin(), y_.begin() + y_.size()/2);
545 axpyBzpcx(a0, x0, y0, b0, z_, c0);
547 const double *a1 = &a_[y_.size()/2];
548 const double *b1 = &b_[y_.size()/2];
549 const double *c1 = &c_[y_.size()/2];
551 std::vector<ColorSpinorField*> x1(x_.begin() + x_.size()/2, x_.end());
552 std::vector<ColorSpinorField*> y1(y_.begin() + y_.size()/2, y_.end());
554 axpyBzpcx(a1, x1, y1, b1, z_, c1);
// Applies multi_caxpyBxpz_, reading the vector set x_ with per-vector coefficients a_[i] and
// b_[i] and writing the single fields y_ and z_.
// A single kernel runs when x_ fits (size <= max_N_multi_1d() and a valid NXZ). Otherwise x_
// and the coefficients are split in half and processed recursively; both halves update the same
// y_ and z_.
558 void caxpyBxpz(const Complex *a_, std::vector<ColorSpinorField*> &x_, ColorSpinorField &y_,
559 const Complex *b_, ColorSpinorField &z_)
561 if (x_.size() <= (size_t)max_N_multi_1d() &&
562 is_valid_NXZ(x_.size(), false, x_[0]->Precision() < QUDA_SINGLE_PRECISION)) // only split if we have to.
564 // swizzle order since we are writing to y_ and z_, but the
565 // multi-blas only allow writing to y and w, and moreover the
566 // block width of y and w must match, and x and z must match.
567 // Also, wrap a container around them.
// (y_ and z_ are presumably added to y and w on lines not visible here)
568 std::vector<ColorSpinorField*> y;
570 std::vector<ColorSpinorField*> w;
573 // we're reading from x
574 std::vector<ColorSpinorField*> &x = x_;
576 coeff_array<Complex> a(a_), b(b_), c;
577 constexpr bool mixed = true;
578 instantiate<multi_caxpyBxpz_, MultiBlas, mixed>(a, b, c, *x[0], *y[0], x, y, x, w);
580 // split the problem in half and recurse
581 const Complex *a0 = &a_[0];
582 const Complex *b0 = &b_[0];
584 std::vector<ColorSpinorField*> x0(x_.begin(), x_.begin() + x_.size()/2);
586 caxpyBxpz(a0, x0, y_, b0, z_);
588 const Complex *a1 = &a_[x_.size()/2];
589 const Complex *b1 = &b_[x_.size()/2];
591 std::vector<ColorSpinorField*> x1(x_.begin() + x_.size()/2, x_.end());
593 caxpyBxpz(a1, x1, y_, b1, z_);
// Block axpy over vector sets with functor multiaxpy_. The coefficients are the row-major
// x.size() x y.size() double matrix a_.
597 void axpy(const double *a_, std::vector<ColorSpinorField*> &x, std::vector<ColorSpinorField*> &y)
599 // Enter a recursion.
600 // Pass a, x, y. The ranges start at (0,0) and cover every tile; upper = 0 specifies the matrix is unstructured, and 1 is the real coefficient width.
601 axpy_recurse<multiaxpy_>(a_, x, y, range(0, x.size()), range(0, y.size()), 0, 1);
// Block axpy for a square, upper-triangular double matrix a_ (row-major). Calls errorQuda if
// x and y have different sizes.
604 void axpy_U(const double *a_, std::vector<ColorSpinorField*> &x, std::vector<ColorSpinorField*> &y)
606 // Enter a recursion.
607 // Pass a, x, y. The ranges start at (0,0). 1 indicates the matrix is upper-triangular,
608 // which lets us skip some tiles.
609 if (x.size() != y.size())
611 errorQuda("An optimal block axpy_U with non-square 'a' has not yet been implemented. Use block axpy instead");
613 axpy_recurse<multiaxpy_>(a_, x, y, range(0, x.size()), range(0, y.size()), 1, 1);
// Block axpy for a square, lower-triangular double matrix a_ (row-major). Calls errorQuda if
// x and y have different sizes.
616 void axpy_L(const double *a_, std::vector<ColorSpinorField*> &x, std::vector<ColorSpinorField*> &y)
618 // Enter a recursion.
619 // Pass a, x, y. The ranges start at (0,0). -1 indicates the matrix is lower-triangular,
620 // which lets us skip some tiles.
621 if (x.size() != y.size())
623 errorQuda("An optimal block axpy_L with non-square 'a' has not yet been implemented. Use block axpy instead");
625 axpy_recurse<multiaxpy_>(a_, x, y, range(0, x.size()), range(0, y.size()), -1, 1);
// Composite-field overload: runs the block axpy on the component vectors of x and y.
void axpy(const double *a, ColorSpinorField &x, ColorSpinorField &y)
{
  auto &x_parts = x.Components();
  auto &y_parts = y.Components();
  axpy(a, x_parts, y_parts);
}
// Composite-field overload: runs the upper-triangular block axpy on the component vectors of x and y.
void axpy_U(const double *a, ColorSpinorField &x, ColorSpinorField &y)
{
  auto &x_parts = x.Components();
  auto &y_parts = y.Components();
  axpy_U(a, x_parts, y_parts);
}
// Composite-field overload: runs the lower-triangular block axpy on the component vectors of x and y.
void axpy_L(const double *a, ColorSpinorField &x, ColorSpinorField &y)
{
  auto &x_parts = x.Components();
  auto &y_parts = y.Components();
  axpy_L(a, x_parts, y_parts);
}