27 template <
typename FloatN,
int M,
typename SpinorX,
typename SpinorY,
typename SpinorZ,
typename SpinorW,
28 typename SpinorV,
typename Functor>
60 nParity((x.IsComposite() ? x.CompositeDim() : 1) * x.SiteSubset()),
61 arg(X, Y, Z, W, V, f, length / nParity),
85 ::quda::create_jitify_program(
"kernels/blas_core.cuh");
97 using namespace jitify::reflection;
99 .instantiate(Type<FloatN>(), M, Type<decltype(arg)>())
142 return (arg.
f.streams() - 2) * x.
Bytes() + 2 * y.
Bytes();
147 template <
typename RegType,
typename StoreType,
typename yType,
int M,
template <
typename,
typename>
class Functor,
148 int writeX,
int writeY,
int writeZ,
int writeW,
int writeV>
165 typedef typename vector<Float, 2>::type Float2;
166 typedef vector<Float, 2> vec2;
167 Functor<Float2, RegType> f((Float2)vec2(a), (Float2)vec2(b), (Float2)vec2(c));
170 X, Y, Z, W, V, f, x, y, z, w, v, length);
171 blas.
apply(*blasStream);
183 template <
template <
typename Float,
typename FloatN>
class Functor,
int writeX = 0,
int writeY = 0,
int writeZ = 0,
184 int writeW = 0,
int writeV = 0>
196 warningQuda(
"Device blas on non-native fields is not supported\n");
202 #if QUDA_PRECISION & 8 203 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC) || defined(GPU_STAGGERED_DIRAC) 205 nativeBlas<double2, double2, double2, M, Functor, writeX, writeY, writeZ, writeW, writeV>(
208 errorQuda(
"blas has not been built for Nspin=%d fields", x.Nspin());
211 errorQuda(
"QUDA_PRECISION=%d does not enable precision %d", QUDA_PRECISION, x.Precision());
216 #if QUDA_PRECISION & 4 218 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC) 220 nativeBlas<float4, float4, float4, M, Functor, writeX, writeY, writeZ, writeW, writeV>(
223 errorQuda(
"blas has not been built for Nspin=%d fields", x.Nspin());
226 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC) || defined(GPU_STAGGERED_DIRAC) 228 nativeBlas<float2, float2, float2, M, Functor, writeX, writeY, writeZ, writeW, writeV>(
231 errorQuda(
"blas has not been built for Nspin=%d fields", x.Nspin());
237 errorQuda(
"QUDA_PRECISION=%d does not enable precision %d", QUDA_PRECISION, x.
Precision());
242 #if QUDA_PRECISION & 2 245 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC) 247 nativeBlas<float4, short4, short4, M, Functor, writeX, writeY, writeZ, writeW, writeV>(
250 errorQuda(
"blas has not been built for Nspin=%d fields", x.Nspin());
253 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC) 255 nativeBlas<float2, short2, short2, M, Functor, writeX, writeY, writeZ, writeW, writeV>(
258 errorQuda(
"blas has not been built for Nspin=%d fields", x.Nspin());
260 }
else if (x.
Nspin() == 1) {
261 #ifdef GPU_STAGGERED_DIRAC 263 nativeBlas<float2, short2, short2, M, Functor, writeX, writeY, writeZ, writeW, writeV>(
266 errorQuda(
"blas has not been built for Nspin=%d fields", x.Nspin());
272 errorQuda(
"QUDA_PRECISION=%d does not enable precision %d", QUDA_PRECISION, x.
Precision());
277 #if QUDA_PRECISION & 1 279 if (x.
Nspin() == 4) {
280 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC) 282 nativeBlas<float4, char4, char4, M, Functor, writeX, writeY, writeZ, writeW, writeV>(
285 errorQuda(
"blas has not been built for Nspin=%d fields", x.Nspin());
287 }
else if (x.
Nspin() == 1) {
288 #ifdef GPU_STAGGERED_DIRAC 290 nativeBlas<float2, char2, char2, M, Functor, writeX, writeY, writeZ, writeW, writeV>(
293 errorQuda(
"blas has not been built for Nspin=%d fields", x.Nspin());
299 errorQuda(
"QUDA_PRECISION=%d does not enable precision %d", QUDA_PRECISION, x.
Precision());
307 Functor<double2, double2> f(a, b, c);
308 genericBlas<double, double, writeX, writeY, writeZ, writeW, writeV>(
x,
y,
z,
w,
v, f);
310 Functor<float2, float2> f(make_float2(a.x, a.y), make_float2(b.x, b.y), make_float2(c.x, c.y));
311 genericBlas<float, float, writeX, writeY, writeZ, writeW, writeV>(
x,
y,
z,
w,
v, f);
324 template <
template <
typename Float,
typename FloatN>
class Functor,
int writeX = 0,
int writeY = 0,
int writeZ = 0,
325 int writeW = 0,
int writeV = 0>
336 warningQuda(
"Device blas on non-native fields is not supported\n");
342 #if QUDA_PRECISION & 4 343 if (x.
Nspin() == 4) {
345 nativeBlas<double2, float4, double2, M, Functor, writeX, writeY, writeZ, writeW, writeV>(
347 }
else if (x.
Nspin() == 1) {
349 nativeBlas<double2, float2, double2, M, Functor, writeX, writeY, writeZ, writeW, writeV>(
353 errorQuda(
"QUDA_PRECISION=%d does not enable precision %d", QUDA_PRECISION, x.
Precision());
358 #if QUDA_PRECISION & 2 361 #if QUDA_PRECISION & 8 362 if (x.
Nspin() == 4) {
364 nativeBlas<double2, short4, double2, M, Functor, writeX, writeY, writeZ, writeW, writeV>(
366 }
else if (x.
Nspin() == 1) {
368 nativeBlas<double2, short2, double2, M, Functor, writeX, writeY, writeZ, writeW, writeV>(
372 errorQuda(
"QUDA_PRECISION=%d does not enable precision %d", QUDA_PRECISION, y.
Precision());
377 #if QUDA_PRECISION & 4 378 if (x.
Nspin() == 4) {
380 nativeBlas<float4, short4, float4, M, Functor, writeX, writeY, writeZ, writeW, writeV>(
382 }
else if (x.
Nspin() == 1) {
384 nativeBlas<float2, short2, float2, M, Functor, writeX, writeY, writeZ, writeW, writeV>(
388 errorQuda(
"QUDA_PRECISION=%d does not enable precision %d", QUDA_PRECISION, y.
Precision());
395 errorQuda(
"QUDA_PRECISION=%d does not enable precision %d", QUDA_PRECISION, x.
Precision());
400 #if QUDA_PRECISION & 1 404 #if QUDA_PRECISION & 8 405 if (x.
Nspin() == 4) {
407 nativeBlas<double2, char4, double2, M, Functor, writeX, writeY, writeZ, writeW, writeV>(
409 }
else if (x.
Nspin() == 1) {
411 nativeBlas<double2, char2, double2, M, Functor, writeX, writeY, writeZ, writeW, writeV>(
415 errorQuda(
"QUDA_PRECISION=%d does not enable precision %d", QUDA_PRECISION, y.
Precision());
420 #if QUDA_PRECISION & 4 421 if (x.
Nspin() == 4) {
423 nativeBlas<float4, char4, float4, M, Functor, writeX, writeY, writeZ, writeW, writeV>(
425 }
else if (x.
Nspin() == 1) {
427 nativeBlas<float2, char2, float2, M, Functor, writeX, writeY, writeZ, writeW, writeV>(
431 errorQuda(
"QUDA_PRECISION=%d does not enable precision %d", QUDA_PRECISION, y.
Precision());
436 #if QUDA_PRECISION & 2 437 if (x.
Nspin() == 4) {
439 nativeBlas<float4, char4, short4, M, Functor, writeX, writeY, writeZ, writeW, writeV>(
441 }
else if (x.
Nspin() == 1) {
443 nativeBlas<float2, char2, short2, M, Functor, writeX, writeY, writeZ, writeW, writeV>(
447 errorQuda(
"QUDA_PRECISION=%d does not enable precision %d", QUDA_PRECISION, y.
Precision());
454 errorQuda(
"QUDA_PRECISION=%d does not enable precision %d", QUDA_PRECISION, x.
Precision());
464 Functor<double2, double2> f(a, b, c);
465 genericBlas<float, double, writeX, writeY, writeZ, writeW, writeV>(
x,
y,
z,
w,
v, f);
500 mixed_blas<axpbyz_, 0, 0, 0, 0, 1>(
501 make_double2(a, 0.0), make_double2(b, 0.0), make_double2(0.0, 0.0),
x,
y,
x,
x,
z);
503 uni_blas<axpbyz_, 0, 0, 0, 0, 1>(
504 make_double2(a, 0.0), make_double2(b, 0.0), make_double2(0.0, 0.0),
x,
y,
x,
x,
z);
509 uni_blas<ax_, 1>(make_double2(a, 0.0), make_double2(0.0, 0.0), make_double2(0.0, 0.0),
x,
x,
x,
x,
x);
514 mixed_blas<caxpy_, 0, 1>(
515 make_double2(real(a), imag(a)), make_double2(0.0, 0.0), make_double2(0.0, 0.0),
x,
y,
x,
x,
y);
517 uni_blas<caxpy_, 0, 1>(
518 make_double2(real(a), imag(a)), make_double2(0.0, 0.0), make_double2(0.0, 0.0),
x,
y,
x,
x,
y);
524 uni_blas<caxpby_, 0, 1>(
525 make_double2(
REAL(a),
IMAG(a)), make_double2(
REAL(b),
IMAG(b)), make_double2(0.0, 0.0),
x,
y,
x,
x,
y);
531 uni_blas<caxpbypczw_, 0, 0, 0, 1>(make_double2(
REAL(a),
IMAG(a)), make_double2(
REAL(b),
IMAG(b)),
537 uni_blas<caxpbypczw_, 0, 0, 0, 1>(make_double2(1.0, 0.0), make_double2(
REAL(a),
IMAG(a)),
545 mixed_blas<axpyBzpcx_, 1, 1>(make_double2(a, 0.0), make_double2(b, 0.0), make_double2(c, 0.0),
x,
y,
z,
x,
y);
548 uni_blas<axpyBzpcx_, 1, 1>(make_double2(a, 0.0), make_double2(b, 0.0), make_double2(c, 0.0),
x,
y,
z,
x,
y);
556 mixed_blas<axpyZpbx_, 1, 1>(make_double2(a, 0.0), make_double2(b, 0.0), make_double2(0.0, 0.0),
x,
y,
z,
x,
y);
559 uni_blas<axpyZpbx_, 1, 1>(make_double2(a, 0.0), make_double2(b, 0.0), make_double2(0.0, 0.0),
x,
y,
z,
x,
y);
566 mixed_blas<caxpyBzpx_, 1, 1>(
567 make_double2(
REAL(a),
IMAG(a)), make_double2(
REAL(b),
IMAG(b)), make_double2(0.0, 0.0),
x,
y,
z,
x,
y);
569 uni_blas<caxpyBzpx_, 1, 1>(
570 make_double2(
REAL(a),
IMAG(a)), make_double2(
REAL(b),
IMAG(b)), make_double2(0.0, 0.0),
x,
y,
z,
x,
y);
577 mixed_blas<caxpyBxpz_, 0, 1, 1>(
578 make_double2(
REAL(a),
IMAG(a)), make_double2(
REAL(b),
IMAG(b)), make_double2(0.0, 0.0),
x,
y,
z,
x,
y);
580 uni_blas<caxpyBxpz_, 0, 1, 1>(
581 make_double2(
REAL(a),
IMAG(a)), make_double2(
REAL(b),
IMAG(b)), make_double2(0.0, 0.0),
x,
y,
z,
x,
y);
587 uni_blas<caxpbypzYmbw_, 0, 1, 1>(
588 make_double2(
REAL(a),
IMAG(a)), make_double2(
REAL(b),
IMAG(b)), make_double2(0.0, 0.0),
x,
y,
z,
w,
y);
593 uni_blas<cabxpyAx_, 1, 1>(
594 make_double2(a, 0.0), make_double2(
REAL(b),
IMAG(b)), make_double2(0.0, 0.0),
x,
y,
x,
x,
y);
599 uni_blas<caxpyxmaz_, 1, 1>(
600 make_double2(
REAL(a),
IMAG(a)), make_double2(0.0, 0.0), make_double2(0.0, 0.0),
x,
y,
z,
x,
y);
606 errorQuda(
"This kernel requires asynchronous reductions to be set");
608 errorQuda(
"This kernel cannot be run on CPU fields");
610 uni_blas<caxpyxmazMR_, 1, 1>(
611 make_double2(
REAL(a),
IMAG(a)), make_double2(0.0, 0.0), make_double2(0.0, 0.0),
x,
y,
z,
x,
y);
618 mixed_blas<tripleCGUpdate_, 0, 1, 1, 1>(
619 make_double2(a, 0.0), make_double2(b, 0.0), make_double2(0.0, 0.0),
x,
y,
z,
w,
y);
621 uni_blas<tripleCGUpdate_, 0, 1, 1, 1>(
622 make_double2(a, 0.0), make_double2(b, 0.0), make_double2(0.0, 0.0),
x,
y,
z,
w,
y);
627 uni_blas<doubleCG3Init_, 1, 1, 0, 0>(
628 make_double2(a, 0.0), make_double2(0.0, 0.0), make_double2(0.0, 0.0),
x,
y,
z,
z,
y);
632 uni_blas<doubleCG3Update_, 1, 1, 0, 0>(
633 make_double2(a, 0.0), make_double2(b, 1.0 - b), make_double2(0.0, 0.0),
x,
y,
z,
z,
y);
void ax(double a, ColorSpinorField &x)
void caxpyXmazMR(const Complex &a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
bool commAsyncReduction()
void axpyZpbx(double a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, double b)
const char * AuxString() const
QudaVerbosity getVerbosity()
const ColorSpinorField & x
#define checkPrecision(...)
Helper file when using jitify run-time compilation. This file should be included in source code...
const ColorSpinorField & y
void cabxpyAx(double a, const Complex &b, ColorSpinorField &x, ColorSpinorField &y)
void caxpbypczw(const Complex &a, ColorSpinorField &x, const Complex &b, ColorSpinorField &y, const Complex &c, ColorSpinorField &z, ColorSpinorField &w)
virtual bool advanceSharedBytes(TuneParam &param) const
const char * VolString() const
const ColorSpinorField & w
void caxpyBzpx(const Complex &, ColorSpinorField &, ColorSpinorField &, const Complex &, ColorSpinorField &)
void caxpyBxpz(const Complex &, ColorSpinorField &, ColorSpinorField &, const Complex &, ColorSpinorField &)
void doubleCG3Update(double a, double b, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
cudaStream_t * getStream()
void apply(const cudaStream_t &stream)
static cudaStream_t * blasStream
const ColorSpinorField & v
void mixed_blas(const double2 &a, const double2 &b, const double2 &c, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, ColorSpinorField &w, ColorSpinorField &v)
void caxpbypzYmbw(const Complex &, ColorSpinorField &, const Complex &, ColorSpinorField &, ColorSpinorField &, ColorSpinorField &)
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
#define checkLocation(...)
void axpyBzpcx(double a, ColorSpinorField &x, ColorSpinorField &y, double b, ColorSpinorField &z, double c)
unsigned int sharedBytesPerBlock(const TuneParam &param) const
std::complex< double > Complex
void initTuneParam(TuneParam &param) const
void tripleCGUpdate(double alpha, double beta, ColorSpinorField &q, ColorSpinorField &r, ColorSpinorField &x, ColorSpinorField &p)
void caxpy(const Complex &a, ColorSpinorField &x, ColorSpinorField &y)
void axpbyz(double a, ColorSpinorField &x, double b, ColorSpinorField &y, ColorSpinorField &z)
void zero(ColorSpinorField &a)
void doubleCG3Init(double a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
void checkLength(const ColorSpinorField &a, const ColorSpinorField &b)
QudaFieldLocation Location() const
BlasArg< SpinorX, SpinorY, SpinorZ, SpinorW, SpinorV, Functor > arg
BlasCuda(SpinorX &X, SpinorY &Y, SpinorZ &Z, SpinorW &W, SpinorV &V, Functor &f, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, ColorSpinorField &w, ColorSpinorField &v, int length)
void caxpyXmaz(const Complex &a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
void uni_blas(const double2 &a, const double2 &b, const double2 &c, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, ColorSpinorField &w, ColorSpinorField &v)
void caxpby(const Complex &a, ColorSpinorField &x, const Complex &b, ColorSpinorField &y)
virtual void initTuneParam(TuneParam &param) const
void cxpaypbz(ColorSpinorField &, const Complex &b, ColorSpinorField &y, const Complex &c, ColorSpinorField &z)
void nativeBlas(const double2 &a, const double2 &b, const double2 &c, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, ColorSpinorField &w, ColorSpinorField &v, int length)
virtual bool advanceBlockDim(TuneParam &param) const
const ColorSpinorField & z
QudaTune getTuning()
Query whether autotuning is enabled or not. Default is enabled but can be overridden by setting QUDA_...
QudaPrecision Precision() const
void defaultTuneParam(TuneParam &param) const
QudaFieldOrder FieldOrder() const
unsigned int sharedBytesPerThread() const