19 template <
int writeX,
int writeY,
int writeZ,
int writeW>
21 static constexpr
int X = writeX;
22 static constexpr
int Y = writeY;
23 static constexpr
int Z = writeZ;
24 static constexpr
int W = writeW;
29 template <
unsigned... digits>
struct to_chars {
30 static const char value[];
33 template <
unsigned... digits>
const char to_chars<digits...>::value[] = {(
'0' + digits)..., 0};
35 template <
unsigned rem,
unsigned... digits>
struct explode :
explode<rem / 10, rem % 10, digits...> {
38 template <
unsigned... digits>
struct explode<0, digits...> :
to_chars<digits...> {
45 template <
int NXZ,
typename FloatN,
int M,
typename SpinorX,
typename SpinorY,
typename SpinorZ,
typename SpinorW,
46 typename Functor,
typename T>
56 std::vector<ColorSpinorField *> &x, &y, &
z, &w;
67 std::vector<ColorSpinorField *> &y, std::vector<ColorSpinorField *> &z, std::vector<ColorSpinorField *> &w,
71 nParity(x[0]->SiteSubset()),
72 arg(X, Y, Z, W, f, NYW, length / nParity),
85 Amatrix_h =
reinterpret_cast<signed char *
>(
const_cast<T *
>(a.
data));
86 Bmatrix_h =
reinterpret_cast<signed char *
>(
const_cast<T *
>(b.
data));
87 Cmatrix_h =
reinterpret_cast<signed char *
>(
const_cast<T *
>(c.
data));
89 strcpy(aux, x[0]->AuxString());
90 if (x[0]->Precision() != y[0]->Precision()) {
92 strcat(aux, y[0]->AuxString());
96 ::quda::create_jitify_program(
"kernels/multi_blas_core.cuh");
106 strcat(name, std::to_string(NYW).c_str());
107 strcat(name,
typeid(arg.
f).name());
108 return TuneKey(x[0]->VolString(), name, aux);
116 typedef typename vector<Float, 2>::type Float2;
118 using namespace jitify::reflection;
120 = program->kernel(
"quda::blas::multiBlasKernel").instantiate(Type<FloatN>(), M, NXZ, Type<decltype(arg)>());
128 for (
int i = 0; i < NXZ; i++)
129 for (
int j = 0; j < NYW; j++)
132 auto Amatrix_d = instance.get_constant_ptr(
"quda::blas::Amatrix_d");
140 for (
int i = 0; i < NXZ; i++)
141 for (
int j = 0; j < NYW; j++)
144 auto Bmatrix_d = instance.get_constant_ptr(
"quda::blas::Bmatrix_d");
152 for (
int i = 0; i < NXZ; i++)
153 for (
int j = 0; j < NYW; j++)
156 auto Cmatrix_d = instance.get_constant_ptr(
"quda::blas::Cmatrix_d");
168 for (
int i = 0; i < NXZ; i++)
169 for (
int j = 0; j < NYW; j++)
179 for (
int i = 0; i < NXZ; i++)
180 for (
int j = 0; j < NYW; j++)
190 for (
int i = 0; i < NXZ; i++)
191 for (
int j = 0; j < NYW; j++)
196 #if CUDA_VERSION < 9000 197 cudaMemcpyToSymbolAsync(
arg_buffer, reinterpret_cast<char *>(&arg),
sizeof(arg), 0, cudaMemcpyHostToDevice,
206 for (
int i = 0; i < NYW; ++i) {
207 arg.
Y[i].backup(&Y_h[i], &Ynorm_h[i], y[i]->Bytes(), y[i]->NormBytes());
208 arg.
W[i].backup(&W_h[i], &Wnorm_h[i], w[i]->Bytes(), w[i]->NormBytes());
214 for (
int i = 0; i < NYW; ++i) {
215 arg.
Y[i].restore(&Y_h[i], &Ynorm_h[i], y[i]->Bytes(), y[i]->NormBytes());
216 arg.
W[i].restore(&W_h[i], &Wnorm_h[i], w[i]->Bytes(), w[i]->NormBytes());
223 param.
grid.z = nParity;
229 param.
grid.z = nParity;
237 return ((arg.
f.streams() - 2) * x[0]->Bytes() + 2 * y[0]->Bytes());
243 template <
int NXZ,
typename RegType,
typename StoreType,
typename yType,
int M,
244 template <
int,
typename,
typename>
class Functor,
typename write,
typename T>
246 std::vector<ColorSpinorField *> &x, std::vector<ColorSpinorField *> &y, std::vector<ColorSpinorField *> &z,
247 std::vector<ColorSpinorField *> &w,
int length)
249 const int NYW = y.size();
251 const int N = NXZ > NYW ? NXZ : NYW;
258 typedef typename vector<Float, 2>::type Float2;
259 typedef vector<Float, 2> vec2;
266 for (
int i = 0; i < NXZ; i++) {
267 X[i].
set(*dynamic_cast<cudaColorSpinorField *>(x[i]));
268 Z[i].
set(*dynamic_cast<cudaColorSpinorField *>(z[i]));
270 for (
int i = 0; i < NYW; i++) {
271 Y[i].
set(*dynamic_cast<cudaColorSpinorField *>(y[i]));
272 W[i].
set(*dynamic_cast<cudaColorSpinorField *>(w[i]));
277 Functor<NXZ, Float2, RegType> f(a, b, c, NYW);
279 MultiBlas<NXZ, RegType, M, SpinorTexture<RegType, StoreType, M>,
Spinor<RegType, yType, M, write::Y>,
281 blas(X, Y, Z, W, f, a, b, c, x, y, z, w, NYW, length);
293 template <
int NXZ,
template <
int MXZ,
typename Float,
typename FloatN>
class Functor,
typename write,
typename T>
303 #if QUDA_PRECISION & 8 304 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC) || defined(GPU_STAGGERED_DIRAC) 306 multiBlas<NXZ, double2, double2, double2, M, Functor, write>(a, b, c, x, y, z, w, x[0]->Length() / (2 * M));
308 errorQuda(
"blas has not been built for Nspin=%d fields", x[0]->
Nspin());
311 errorQuda(
"QUDA_PRECISION=%d does not enable precision %d", QUDA_PRECISION, x[0]->Precision());
316 #if QUDA_PRECISION & 4 317 if (x[0]->
Nspin() == 4) {
318 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC) 320 multiBlas<NXZ, float4, float4, float4, M, Functor, write>(a, b, c, x, y, z, w, x[0]->Length() / (4 * M));
322 errorQuda(
"blas has not been built for Nspin=%d fields", x[0]->
Nspin());
325 }
else if (x[0]->
Nspin() == 2 || x[0]->
Nspin() == 1) {
327 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC) || defined(GPU_STAGGERED_DIRAC) 329 multiBlas<NXZ, float2, float2, float2, M, Functor, write>(a, b, c, x, y, z, w, x[0]->Length() / (2 * M));
331 errorQuda(
"blas has not been built for Nspin=%d fields", x[0]->
Nspin());
337 errorQuda(
"QUDA_PRECISION=%d does not enable precision %d", QUDA_PRECISION, x[0]->Precision());
342 #if QUDA_PRECISION & 2 344 if (x[0]->
Nspin() == 4) {
345 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC) 347 multiBlas<NXZ, float4, short4, short4, M, Functor, write>(a, b, c, x, y, z, w, x[0]->Volume());
349 errorQuda(
"blas has not been built for Nspin=%d fields", x[0]->
Nspin());
351 }
else if (x[0]->
Nspin() == 1) {
352 #ifdef GPU_STAGGERED_DIRAC 354 multiBlas<NXZ, float2, short2, short2, M, Functor, write>(a, b, c, x, y, z, w, x[0]->Volume());
356 errorQuda(
"blas has not been built for Nspin=%d fields", x[0]->
Nspin());
362 errorQuda(
"QUDA_PRECISION=%d does not enable precision %d", QUDA_PRECISION, x[0]->Precision());
367 #if QUDA_PRECISION & 1 369 if (x[0]->
Nspin() == 4) {
370 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC) 372 multiBlas<NXZ, float4, char4, char4, M, Functor, write>(a, b, c, x, y, z, w, x[0]->Volume());
374 errorQuda(
"blas has not been built for Nspin=%d fields", x[0]->
Nspin());
376 }
else if (x[0]->
Nspin() == 1) {
377 #ifdef GPU_STAGGERED_DIRAC 379 multiBlas<NXZ, float2, char2, char2, M, Functor, write>(a, b, c, x, y, z, w, x[0]->Volume());
381 errorQuda(
"blas has not been built for Nspin=%d fields", x[0]->
Nspin());
387 errorQuda(
"QUDA_PRECISION=%d does not enable precision %d", QUDA_PRECISION, x[0]->Precision());
392 errorQuda(
"Precision combination x=%d not supported\n", x[0]->Precision());
402 template <
int NXZ,
template <
int MXZ,
typename Float,
typename FloatN>
class Functor,
typename write,
typename T>
411 #if QUDA_PRECISION & 8 414 #if QUDA_PRECISION & 4 415 if (x[0]->
Nspin() == 4) {
416 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC) 418 multiBlas<NXZ, double2, float4, double2, M, Functor, write>(a, b, c, x, y, z, w, x[0]->Volume());
420 errorQuda(
"blas has not been built for Nspin=%d fields", x[0]->
Nspin());
422 }
else if (x[0]->
Nspin() == 1) {
424 #if defined(GPU_STAGGERED_DIRAC) 426 multiBlas<NXZ, double2, float2, double2, M, Functor, write>(a, b, c, x, y, z, w, x[0]->Volume());
428 errorQuda(
"blas has not been built for Nspin=%d fields", x[0]->
Nspin());
433 errorQuda(
"QUDA_PRECISION=%d does not enable precision %d", QUDA_PRECISION, x[0]->Precision());
438 #if QUDA_PRECISION & 2 439 if (x[0]->
Nspin() == 4) {
440 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC) 442 multiBlas<NXZ, double2, short4, double2, M, Functor, write>(a, b, c, x, y, z, w, x[0]->Volume());
444 errorQuda(
"blas has not been built for Nspin=%d fields", x[0]->
Nspin());
447 }
else if (x[0]->
Nspin() == 1) {
449 #if defined(GPU_STAGGERED_DIRAC) 451 multiBlas<NXZ, double2, short2, double2, M, Functor, write>(a, b, c, x, y, z, w, x[0]->Volume());
453 errorQuda(
"blas has not been built for Nspin=%d fields", x[0]->
Nspin());
457 errorQuda(
"QUDA_PRECISION=%d does not enable precision %d", QUDA_PRECISION, x[0]->Precision());
462 #if QUDA_PRECISION & 1 463 if (x[0]->
Nspin() == 4) {
464 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC) 466 multiBlas<NXZ, double2, char4, double2, M, Functor, write>(a, b, c, x, y, z, w, x[0]->Volume());
468 errorQuda(
"blas has not been built for Nspin=%d fields", x[0]->
Nspin());
471 }
else if (x[0]->
Nspin() == 1) {
473 #if defined(GPU_STAGGERED_DIRAC) 475 multiBlas<NXZ, double2, char2, double2, M, Functor, write>(a, b, c, x, y, z, w, x[0]->Volume());
477 errorQuda(
"blas has not been built for Nspin=%d fields", x[0]->
Nspin());
481 errorQuda(
"QUDA_PRECISION=%d does not enable precision %d", QUDA_PRECISION, x[0]->Precision());
485 errorQuda(
"Not implemented for this precision combination %d %d", x[0]->Precision(), y[0]->Precision());
488 errorQuda(
"QUDA_PRECISION=%d does not enable precision %d", QUDA_PRECISION, y[0]->Precision());
493 #if (QUDA_PRECISION & 4) 496 #if (QUDA_PRECISION & 2) 497 if (x[0]->
Nspin() == 4) {
498 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC) 500 multiBlas<NXZ, float4, short4, float4, M, Functor, write>(a, b, c, x, y, z, w, x[0]->Volume());
502 errorQuda(
"blas has not been built for Nspin=%d fields", x[0]->
Nspin());
505 }
else if (x[0]->
Nspin() == 2 || x[0]->
Nspin() == 1) {
507 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC) || defined(GPU_STAGGERED_DIRAC) 509 multiBlas<NXZ, float2, short2, float2, M, Functor, write>(a, b, c, x, y, z, w, x[0]->Volume());
511 errorQuda(
"blas has not been built for Nspin=%d fields", x[0]->
Nspin());
518 errorQuda(
"QUDA_PRECISION=%d does not enable precision %d", QUDA_PRECISION, y[0]->Precision());
523 #if (QUDA_PRECISION & 1) 524 if (x[0]->
Nspin() == 4) {
525 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC) 527 multiBlas<NXZ, float4, char4, float4, M, Functor, write>(a, b, c, x, y, z, w, x[0]->Volume());
529 errorQuda(
"blas has not been built for Nspin=%d fields", x[0]->
Nspin());
532 }
else if (x[0]->
Nspin() == 2 || x[0]->
Nspin() == 1) {
534 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC) || defined(GPU_STAGGERED_DIRAC) 536 multiBlas<NXZ, float2, char2, float2, M, Functor, write>(a, b, c, x, y, z, w, x[0]->Volume());
538 errorQuda(
"blas has not been built for Nspin=%d fields", x[0]->
Nspin());
545 errorQuda(
"QUDA_PRECISION=%d does not enable precision %d", QUDA_PRECISION, y[0]->Precision());
549 errorQuda(
"Precision combination x=%d y=%d not supported\n", x[0]->Precision(), y[0]->Precision());
552 errorQuda(
"QUDA_PRECISION=%d does not enable precision %d", QUDA_PRECISION, y[0]->Precision());
555 errorQuda(
"Precision combination x=%d y=%d not supported\n", x[0]->Precision(), y[0]->Precision());
563 int i_idx ,
int j_idx,
int upper) {
569 Complex* tmpmajor0 = &tmpmajor[0];
570 Complex* tmpmajor1 = &tmpmajor[x.size()*(y.size()/2)];
571 std::vector<ColorSpinorField*> y0(y.begin(), y.begin() + y.size()/2);
572 std::vector<ColorSpinorField*> y1(y.begin() + y.size()/2, y.end());
574 const unsigned int xlen = x.size();
575 const unsigned int ylen0 = y.size()/2;
576 const unsigned int ylen1 = y.size() - y.size()/2;
578 int count = 0, count0 = 0, count1 = 0;
579 for (
unsigned int i = 0; i < xlen; i++)
581 for (
unsigned int j = 0; j < ylen0; j++)
582 tmpmajor0[count0++] = a_[count++];
583 for (
unsigned int j = 0; j < ylen1; j++)
584 tmpmajor1[count1++] = a_[count++];
598 if (upper == 1 && j_idx < i_idx) {
return; }
599 if (upper == -1 && j_idx > i_idx) {
return; }
605 if (x[0]->Precision() == y[0]->Precision())
608 case 1: multiBlas<1, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y);
break;
609 #if MAX_MULTI_BLAS_N >= 2 610 case 2: multiBlas<2, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y);
break;
611 #if MAX_MULTI_BLAS_N >= 3 612 case 3: multiBlas<3, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y);
break;
613 #if MAX_MULTI_BLAS_N >= 4 614 case 4: multiBlas<4, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y);
break;
615 #if MAX_MULTI_BLAS_N >= 5 616 case 5: multiBlas<5, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y);
break;
617 #if MAX_MULTI_BLAS_N >= 6 618 case 6: multiBlas<6, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y);
break;
619 #if MAX_MULTI_BLAS_N >= 7 620 case 7: multiBlas<7, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y);
break;
621 #if MAX_MULTI_BLAS_N >= 8 622 case 8: multiBlas<8, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y);
break;
623 #if MAX_MULTI_BLAS_N >= 9 624 case 9: multiBlas<9, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y);
break;
625 #if MAX_MULTI_BLAS_N >= 10 626 case 10: multiBlas<10, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y);
break;
627 #if MAX_MULTI_BLAS_N >= 11 628 case 11: multiBlas<11, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y);
break;
629 #if MAX_MULTI_BLAS_N >= 12 630 case 12: multiBlas<12, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y);
break;
631 #if MAX_MULTI_BLAS_N >= 13 632 case 13: multiBlas<13, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y);
break;
633 #if MAX_MULTI_BLAS_N >= 14 634 case 14: multiBlas<14, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y);
break;
635 #if MAX_MULTI_BLAS_N >= 15 636 case 15: multiBlas<15, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y);
break;
637 #if MAX_MULTI_BLAS_N >= 16 638 case 16: multiBlas<16, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y);
break;
657 const Complex *a1 = &a_[(x.size()/2)*y.size()];
659 std::vector<ColorSpinorField*> x0(x.begin(), x.begin() + x.size()/2);
660 std::vector<ColorSpinorField*> x1(x.begin() + x.size()/2, x.end());
670 case 1: mixedMultiBlas<1, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y);
break;
671 #if MAX_MULTI_BLAS_N >= 2 672 case 2: mixedMultiBlas<2, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y);
break;
673 #if MAX_MULTI_BLAS_N >= 3 674 case 3: mixedMultiBlas<3, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y);
break;
675 #if MAX_MULTI_BLAS_N >= 4 676 case 4: mixedMultiBlas<4, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y);
break;
677 #if MAX_MULTI_BLAS_N >= 5 678 case 5: mixedMultiBlas<5, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y);
break;
679 #if MAX_MULTI_BLAS_N >= 6 680 case 6: mixedMultiBlas<6, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y);
break;
681 #if MAX_MULTI_BLAS_N >= 7 682 case 7: mixedMultiBlas<7, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y);
break;
683 #if MAX_MULTI_BLAS_N >= 8 684 case 8: mixedMultiBlas<8, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y);
break;
685 #if MAX_MULTI_BLAS_N >= 9 686 case 9: mixedMultiBlas<9, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y);
break;
687 #if MAX_MULTI_BLAS_N >= 10 688 case 10: mixedMultiBlas<10, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y);
break;
689 #if MAX_MULTI_BLAS_N >= 11 690 case 11: mixedMultiBlas<11, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y);
break;
691 #if MAX_MULTI_BLAS_N >= 12 692 case 12: mixedMultiBlas<12, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y);
break;
693 #if MAX_MULTI_BLAS_N >= 13 694 case 13: mixedMultiBlas<13, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y);
break;
695 #if MAX_MULTI_BLAS_N >= 14 696 case 14: mixedMultiBlas<14, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y);
break;
697 #if MAX_MULTI_BLAS_N >= 15 698 case 15: mixedMultiBlas<15, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y);
break;
699 #if MAX_MULTI_BLAS_N >= 16 700 case 16: mixedMultiBlas<16, multicaxpy_, write<0, 1, 0, 0>>(a, b, c, x, y, x, y);
break;
719 const Complex *a1 = &a_[(x.size()/2)*y.size()];
721 std::vector<ColorSpinorField*> x0(x.begin(), x.begin() + x.size()/2);
722 std::vector<ColorSpinorField*> x1(x.begin() + x.size()/2, x.end());
732 void caxpy(
const Complex *a_, std::vector<ColorSpinorField*> &x, std::vector<ColorSpinorField*> &y) {
738 void caxpy_U(
const Complex *a_, std::vector<ColorSpinorField*> &x, std::vector<ColorSpinorField*> &y) {
742 if (x.size() != y.size())
744 errorQuda(
"An optimal block caxpy_U with non-square 'a' has not yet been implemented. Use block caxpy instead.\n");
750 void caxpy_L(
const Complex *a_, std::vector<ColorSpinorField*> &x, std::vector<ColorSpinorField*> &y) {
754 if (x.size() != y.size())
756 errorQuda(
"An optimal block caxpy_L with non-square 'a' has not yet been implemented. Use block caxpy instead.\n");
770 void caxpyz_recurse(
const Complex *a_, std::vector<ColorSpinorField*> &x, std::vector<ColorSpinorField*> &y, std::vector<ColorSpinorField*> &z,
int i,
int j,
int pass,
int upper) {
776 Complex* tmpmajor0 = &tmpmajor[0];
777 Complex* tmpmajor1 = &tmpmajor[x.size()*(y.size()/2)];
778 std::vector<ColorSpinorField*> y0(y.begin(), y.begin() + y.size()/2);
779 std::vector<ColorSpinorField*> y1(y.begin() + y.size()/2, y.end());
781 std::vector<ColorSpinorField*> z0(z.begin(), z.begin() + z.size()/2);
782 std::vector<ColorSpinorField*> z1(z.begin() + z.size()/2, z.end());
784 const unsigned int xlen = x.size();
785 const unsigned int ylen0 = y.size()/2;
786 const unsigned int ylen1 = y.size() - y.size()/2;
788 int count = 0, count0 = 0, count1 = 0;
789 for (
unsigned int i_ = 0; i_ < xlen; i_++)
791 for (
unsigned int j = 0; j < ylen0; j++)
792 tmpmajor0[count0++] = a_[count++];
793 for (
unsigned int j = 0; j < ylen1; j++)
794 tmpmajor1[count1++] = a_[count++];
809 if (upper == 1 && j < i) {
return; }
810 if (upper == -1 && i < j) {
return; }
811 caxpy(a_, x, z);
return;
822 if (x[0]->Precision() == y[0]->Precision())
825 case 1: multiBlas<1, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z);
break;
826 #if MAX_MULTI_BLAS_N >= 2 827 case 2: multiBlas<2, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z);
break;
828 #if MAX_MULTI_BLAS_N >= 3 829 case 3: multiBlas<3, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z);
break;
830 #if MAX_MULTI_BLAS_N >= 4 831 case 4: multiBlas<4, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z);
break;
832 #if MAX_MULTI_BLAS_N >= 5 833 case 5: multiBlas<5, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z);
break;
834 #if MAX_MULTI_BLAS_N >= 6 835 case 6: multiBlas<6, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z);
break;
836 #if MAX_MULTI_BLAS_N >= 7 837 case 7: multiBlas<7, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z);
break;
838 #if MAX_MULTI_BLAS_N >= 8 839 case 8: multiBlas<8, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z);
break;
840 #if MAX_MULTI_BLAS_N >= 9 841 case 9: multiBlas<9, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z);
break;
842 #if MAX_MULTI_BLAS_N >= 10 843 case 10: multiBlas<10, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z);
break;
844 #if MAX_MULTI_BLAS_N >= 11 845 case 11: multiBlas<11, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z);
break;
846 #if MAX_MULTI_BLAS_N >= 12 847 case 12: multiBlas<12, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z);
break;
848 #if MAX_MULTI_BLAS_N >= 13 849 case 13: multiBlas<13, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z);
break;
850 #if MAX_MULTI_BLAS_N >= 14 851 case 14: multiBlas<14, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z);
break;
852 #if MAX_MULTI_BLAS_N >= 15 853 case 15: multiBlas<15, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z);
break;
854 #if MAX_MULTI_BLAS_N >= 16 855 case 16: multiBlas<16, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z);
break;
874 const Complex *a1 = &a_[(x.size()/2)*y.size()];
876 std::vector<ColorSpinorField*> x0(x.begin(), x.begin() + x.size()/2);
877 std::vector<ColorSpinorField*> x1(x.begin() + x.size()/2, x.end());
887 case 1: mixedMultiBlas<1, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z);
break;
888 #if MAX_MULTI_BLAS_N >= 2 889 case 2: mixedMultiBlas<2, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z);
break;
890 #if MAX_MULTI_BLAS_N >= 3 891 case 3: mixedMultiBlas<3, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z);
break;
892 #if MAX_MULTI_BLAS_N >= 4 893 case 4: mixedMultiBlas<4, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z);
break;
894 #if MAX_MULTI_BLAS_N >= 5 895 case 5: mixedMultiBlas<5, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z);
break;
896 #if MAX_MULTI_BLAS_N >= 6 897 case 6: mixedMultiBlas<6, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z);
break;
898 #if MAX_MULTI_BLAS_N >= 7 899 case 7: mixedMultiBlas<7, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z);
break;
900 #if MAX_MULTI_BLAS_N >= 8 901 case 8: mixedMultiBlas<8, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z);
break;
902 #if MAX_MULTI_BLAS_N >= 9 903 case 9: mixedMultiBlas<9, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z);
break;
904 #if MAX_MULTI_BLAS_N >= 10 905 case 10: mixedMultiBlas<10, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z);
break;
906 #if MAX_MULTI_BLAS_N >= 11 907 case 11: mixedMultiBlas<11, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z);
break;
908 #if MAX_MULTI_BLAS_N >= 12 909 case 12: mixedMultiBlas<12, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z);
break;
910 #if MAX_MULTI_BLAS_N >= 13 911 case 13: mixedMultiBlas<13, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z);
break;
912 #if MAX_MULTI_BLAS_N >= 14 913 case 14: mixedMultiBlas<14, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z);
break;
914 #if MAX_MULTI_BLAS_N >= 15 915 case 15: mixedMultiBlas<15, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z);
break;
916 #if MAX_MULTI_BLAS_N >= 16 917 case 16: mixedMultiBlas<16, multicaxpyz_, write<0, 0, 0, 1>>(a, b, c, x, y, x, z);
break;
936 const Complex *a1 = &a_[(x.size()/2)*y.size()];
938 std::vector<ColorSpinorField*> x0(x.begin(), x.begin() + x.size()/2);
939 std::vector<ColorSpinorField*> x1(x.begin() + x.size()/2, x.end());
949 void caxpyz(
const Complex *a, std::vector<ColorSpinorField*> &x, std::vector<ColorSpinorField*> &y, std::vector<ColorSpinorField*> &z) {
956 void caxpyz_U(
const Complex *a, std::vector<ColorSpinorField*> &x, std::vector<ColorSpinorField*> &y, std::vector<ColorSpinorField*> &z) {
964 void caxpyz_L(
const Complex *a, std::vector<ColorSpinorField*> &x, std::vector<ColorSpinorField*> &y, std::vector<ColorSpinorField*> &z) {
985 void axpyBzpcx(
const double *a_, std::vector<ColorSpinorField*> &x_, std::vector<ColorSpinorField*> &y_,
992 std::vector<ColorSpinorField*> &y = y_;
993 std::vector<ColorSpinorField*> &w = x_;
996 std::vector<ColorSpinorField*> x;
1002 if (x[0]->Precision() != y[0]->Precision() ) {
1003 mixedMultiBlas<1, multi_axpyBzpcx_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w);
1005 multiBlas<1, multi_axpyBzpcx_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w);
1009 const double *a0 = &a_[0];
1010 const double *b0 = &b_[0];
1011 const double *c0 = &c_[0];
1013 std::vector<ColorSpinorField*> x0(x_.begin(), x_.begin() + x_.size()/2);
1014 std::vector<ColorSpinorField*> y0(y_.begin(), y_.begin() + y_.size()/2);
1018 const double *a1 = &a_[y_.size()/2];
1019 const double *b1 = &b_[y_.size()/2];
1020 const double *c1 = &c_[y_.size()/2];
1022 std::vector<ColorSpinorField*> x1(x_.begin() + x_.size()/2, x_.end());
1023 std::vector<ColorSpinorField*> y1(y_.begin() + y_.size()/2, y_.end());
1033 const int xsize = x_.size();
1040 std::vector<ColorSpinorField*> y;
1042 std::vector<ColorSpinorField*> w;
1046 std::vector<ColorSpinorField*> &x = x_;
1051 if (x[0]->Precision() != y[0]->Precision() )
1055 case 1: mixedMultiBlas<1, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w);
break;
1056 #if MAX_MULTI_BLAS_N >= 2 1057 case 2: mixedMultiBlas<2, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w);
break;
1058 #if MAX_MULTI_BLAS_N >= 3 1059 case 3: mixedMultiBlas<3, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w);
break;
1060 #if MAX_MULTI_BLAS_N >= 4 1061 case 4: mixedMultiBlas<4, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w);
break;
1062 #if MAX_MULTI_BLAS_N >= 5 1063 case 5: mixedMultiBlas<5, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w);
break;
1064 #if MAX_MULTI_BLAS_N >= 6 1065 case 6: mixedMultiBlas<6, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w);
break;
1066 #if MAX_MULTI_BLAS_N >= 7 1067 case 7: mixedMultiBlas<7, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w);
break;
1068 #if MAX_MULTI_BLAS_N >= 8 1069 case 8: mixedMultiBlas<8, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w);
break;
1070 #if MAX_MULTI_BLAS_N >= 9 1071 case 9: mixedMultiBlas<9, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w);
break;
1072 #if MAX_MULTI_BLAS_N >= 10 1073 case 10: mixedMultiBlas<10, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w);
break;
1074 #if MAX_MULTI_BLAS_N >= 11 1075 case 11: mixedMultiBlas<11, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w);
break;
1076 #if MAX_MULTI_BLAS_N >= 12 1077 case 12: mixedMultiBlas<12, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w);
break;
1078 #if MAX_MULTI_BLAS_N >= 13 1079 case 13: mixedMultiBlas<13, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w);
break;
1080 #if MAX_MULTI_BLAS_N >= 14 1081 case 14: mixedMultiBlas<14, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w);
break;
1082 #if MAX_MULTI_BLAS_N >= 15 1083 case 15: mixedMultiBlas<15, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w);
break;
1084 #if MAX_MULTI_BLAS_N >= 16 1085 case 16: mixedMultiBlas<16, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w);
break;
1110 case 1: multiBlas<1, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w);
break;
1111 #if MAX_MULTI_BLAS_N >= 2 1112 case 2: multiBlas<2, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w);
break;
1113 #if MAX_MULTI_BLAS_N >= 3 1114 case 3: multiBlas<3, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w);
break;
1115 #if MAX_MULTI_BLAS_N >= 4 1116 case 4: multiBlas<4, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w);
break;
1117 #if MAX_MULTI_BLAS_N >= 5 1118 case 5: multiBlas<5, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w);
break;
1119 #if MAX_MULTI_BLAS_N >= 6 1120 case 6: multiBlas<6, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w);
break;
1121 #if MAX_MULTI_BLAS_N >= 7 1122 case 7: multiBlas<7, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w);
break;
1123 #if MAX_MULTI_BLAS_N >= 8 1124 case 8: multiBlas<8, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w);
break;
1125 #if MAX_MULTI_BLAS_N >= 9 1126 case 9: multiBlas<9, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w);
break;
1127 #if MAX_MULTI_BLAS_N >= 10 1128 case 10: multiBlas<10, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w);
break;
1129 #if MAX_MULTI_BLAS_N >= 11 1130 case 11: multiBlas<11, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w);
break;
1131 #if MAX_MULTI_BLAS_N >= 12 1132 case 12: multiBlas<12, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w);
break;
1133 #if MAX_MULTI_BLAS_N >= 13 1134 case 13: multiBlas<13, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w);
break;
1135 #if MAX_MULTI_BLAS_N >= 14 1136 case 14: multiBlas<14, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w);
break;
1137 #if MAX_MULTI_BLAS_N >= 15 1138 case 15: multiBlas<15, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w);
break;
1139 #if MAX_MULTI_BLAS_N >= 16 1140 case 16: multiBlas<16, multi_caxpyBxpz_, write<0, 1, 0, 1>>(a, b, c, x, y, x, w);
break;
1166 std::vector<ColorSpinorField*> x0(x_.begin(), x_.begin() + x_.size()/2);
1170 const Complex *a1 = &a_[x_.size()/2];
1171 const Complex *b1 = &b_[x_.size()/2];
1173 std::vector<ColorSpinorField*> x1(x_.begin() + x_.size()/2, x_.end());
void caxpyz(const Complex *a, std::vector< ColorSpinorField *> &x, std::vector< ColorSpinorField *> &y, std::vector< ColorSpinorField *> &z)
Compute the block "caxpyz" over the set of ColorSpinorFields. E.g., it computes z = x * a + y.
void caxpyz_U(const Complex *a, std::vector< ColorSpinorField *> &x, std::vector< ColorSpinorField *> &y, std::vector< ColorSpinorField *> &z)
Compute the block "caxpyz_U" over the set of ColorSpinorFields. E.g., it computes z = x * a + y, where 'a' is treated as upper triangular.
const coeff_array< T > & c
static __constant__ signed char Cmatrix_d[MAX_MATRIX_SIZE]
QudaVerbosity getVerbosity()
SpinorY Y[MAX_MULTI_BLAS_N]
Parameter struct for generic multi-blas kernel.
Helper file when using jitify run-time compilation. This file should be included in source code...
static __constant__ signed char Amatrix_d[MAX_MATRIX_SIZE]
CompositeColorSpinorField & Components()
void set(const cudaColorSpinorField &x)
void caxpy_U(const Complex *a, std::vector< ColorSpinorField *> &x, std::vector< ColorSpinorField *> &y)
Compute the block "caxpy_U" over the set of ColorSpinorFields. E.g., it computes y = x * a + y, where 'a' is a square, upper triangular matrix.
void initTuneParam(TuneParam &param) const
std::vector< ColorSpinorField * > & z
void caxpyBxpz(const Complex &, ColorSpinorField &, ColorSpinorField &, const Complex &, ColorSpinorField &)
cudaStream_t * getStream()
std::vector< ColorSpinorField * > CompositeColorSpinorField
void defaultTuneParam(TuneParam &param) const
void apply(const cudaStream_t &stream)
void mixedMultiBlas(const coeff_array< T > &a, const coeff_array< T > &b, const coeff_array< T > &c, CompositeColorSpinorField &x, CompositeColorSpinorField &y, CompositeColorSpinorField &z, CompositeColorSpinorField &w)
MultiBlasArg< NXZ, SpinorX, SpinorY, SpinorZ, SpinorW, Functor > arg
void initTuneParam(TuneParam &param) const
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
void caxpy_recurse(const Complex *a_, std::vector< ColorSpinorField *> &x, std::vector< ColorSpinorField *> &y, int i_idx, int j_idx, int upper)
#define checkLocation(...)
void axpyBzpcx(double a, ColorSpinorField &x, ColorSpinorField &y, double b, ColorSpinorField &z, double c)
static signed char * Bmatrix_h
void defaultTuneParam(TuneParam &param) const
static __constant__ signed char Bmatrix_d[MAX_MATRIX_SIZE]
std::complex< double > Complex
void caxpy(const Complex &a, ColorSpinorField &x, ColorSpinorField &y)
bool tuneSharedBytes() const
void caxpy_L(const Complex *a, std::vector< ColorSpinorField *> &x, std::vector< ColorSpinorField *> &y)
Compute the block "caxpy_L" over the set of ColorSpinorFields. E.g., it computes y = x * a + y, where 'a' is a square, lower triangular matrix.
void set(const cudaColorSpinorField &x, int nFace=1)
void multiBlas(const coeff_array< T > &a, const coeff_array< T > &b, const coeff_array< T > &c, std::vector< ColorSpinorField *> &x, std::vector< ColorSpinorField *> &y, std::vector< ColorSpinorField *> &z, std::vector< ColorSpinorField *> &w, int length)
void caxpyz_L(const Complex *a, std::vector< ColorSpinorField *> &x, std::vector< ColorSpinorField *> &y, std::vector< ColorSpinorField *> &z)
Compute the block "caxpyz_L" over the set of ColorSpinorFields. E.g., it computes z = x * a + y, where 'a' is treated as lower triangular.
static __constant__ signed char arg_buffer[MAX_MATRIX_SIZE]
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
SpinorW W[MAX_MULTI_BLAS_N]
void caxpyz_recurse(const Complex *a_, std::vector< ColorSpinorField *> &x, std::vector< ColorSpinorField *> &y, std::vector< ColorSpinorField *> &z, int i, int j, int pass, int upper)
QudaTune getTuning()
Query whether autotuning is enabled or not. Default is enabled but can be overridden by setting QUDA_...
__device__ unsigned int count[QUDA_MAX_MULTI_REDUCE]
static signed char * Amatrix_h
static signed char * Cmatrix_h
MultiBlas(SpinorX X[], SpinorY Y[], SpinorZ Z[], SpinorW W[], Functor &f, const coeff_array< T > &a, const coeff_array< T > &b, const coeff_array< T > &c, std::vector< ColorSpinorField *> &x, std::vector< ColorSpinorField *> &y, std::vector< ColorSpinorField *> &z, std::vector< ColorSpinorField *> &w, int NYW, int length)