17 #define checkSpinor(a, b) \ 19 if (a.Precision() != b.Precision()) \ 20 errorQuda("precisions do not match: %d %d", a.Precision(), b.Precision()); \ 21 if (a.Length() != b.Length()) \ 22 errorQuda("lengths do not match: %lu %lu", a.Length(), b.Length()); \ 23 if (a.Stride() != b.Stride()) \ 24 errorQuda("strides do not match: %d %d", a.Stride(), b.Stride()); \ 27 #define checkLength(a, b) \ 29 if (a.Length() != b.Length()) \ 30 errorQuda("lengths do not match: %lu %lu", a.Length(), b.Length()); \ 31 if (a.Stride() != b.Stride()) \ 32 errorQuda("strides do not match: %d %d", a.Stride(), b.Stride()); \ 39 #define BLAS_SPINOR // do not include ghost functions in Spinor class to reduce parameter space overhead 82 template <
typename Float2,
typename FloatN>
// Optional setup hook with an empty default body; being virtual, a derived
// functor can replace it with its own pre-computation.
86 virtual __device__ __host__
void init() { ; }
// Pure-virtual application operator: every concrete functor must define what
// is done to the four mutable FloatN operands x, y, z and w.
89 virtual __device__ __host__
void operator()(FloatN &
x, FloatN &
y, FloatN &
z, FloatN &
w) = 0;
95 template <
typename Float2,
typename FloatN>
// Constructor: initialises `a` and `b` from the first two arguments.
// The third argument `c` is accepted but not used by this constructor.
99 axpby_(
const Float2 &
a,
const Float2 &
b,
const Float2 &
c) :
a(
a),
b(
b) { ; }
101 {
y =
a.x*
x +
b.x*
y; }
107 if (
x.Precision() !=
y.Precision()) {
109 mixed::blasCuda<axpby_,0,1,0,0>(make_double2(
a,0.0), make_double2(
b,0.0), make_double2(0.0,0.0),
112 blasCuda<axpby_,0,1,0,0>(make_double2(
a, 0.0), make_double2(
b, 0.0), make_double2(0.0, 0.0),
120 template <
typename Float2,
typename FloatN>
// Constructor: takes the common three-coefficient argument list but stores
// none of them; the body is empty.
122 xpy_(
const Float2 &
a,
const Float2 &
b,
const Float2 &
c) { ; }
// Adds x into y in place (y += x). z and w are part of the signature but are
// not read or written by this body.
123 __device__ __host__
void operator()(FloatN &
x, FloatN &
y, FloatN &
z, FloatN &
w) {
y +=
x ; }
129 if (
x.Precision() !=
y.Precision()) {
130 mixed::blasCuda<xpy_,0,1,0,0>(make_double2(1.0, 0.0), make_double2(1.0, 0.0),
131 make_double2(0.0, 0.0),
x,
y,
x,
x);
133 blasCuda<xpy_,0,1,0,0>(make_double2(1.0, 0.0), make_double2(1.0, 0.0),
134 make_double2(0.0, 0.0),
x,
y,
x,
x);
141 template <
typename Float2,
typename FloatN>
// Constructor: initialises `a` from the first argument; `b` and `c` are
// accepted but not used by this constructor.
144 axpy_(
const Float2 &
a,
const Float2 &
b,
const Float2 &
c) :
a(
a) { ; }
151 if (
x.Precision() !=
y.Precision()) {
153 mixed::blasCuda<axpy_,0,1,0,0>(make_double2(
a,0.0), make_double2(1.0,0.0), make_double2(0.0,0.0),
156 blasCuda<axpy_,0,1,0,0>(make_double2(
a, 0.0), make_double2(1.0, 0.0), make_double2(0.0, 0.0),
164 template <
typename Float2,
typename FloatN>
// Constructor: initialises `a` from the first argument; `b` and `c` are
// accepted but not used by this constructor.
167 xpayz_(
const Float2 &
a,
const Float2 &
b,
const Float2 &
c) :
a(
a) { ; }
174 blasCuda<xpayz_,0,0,1,0>(make_double2(
a,0.0), make_double2(0.0, 0.0), make_double2(0.0, 0.0),
x,
y,
y,
x);
178 blasCuda<xpayz_,0,0,1,0>(make_double2(
a,0.0), make_double2(0.0, 0.0), make_double2(0.0, 0.0),
x,
y,
z,
x);
184 template <
typename Float2,
typename FloatN>
// Constructor: takes the common three-coefficient argument list but stores
// none of them; the body is empty.
186 mxpy_(
const Float2 &
a,
const Float2 &
b,
const Float2 &
c) { ; }
// Subtracts x from y in place (y -= x). z and w are part of the signature but
// are not read or written by this body.
187 __device__ __host__
void operator()(FloatN &
x, FloatN &
y, FloatN &
z, FloatN &
w) {
y -=
x; }
193 blasCuda<mxpy_,0,1,0,0>(make_double2(1.0, 0.0), make_double2(1.0, 0.0),
194 make_double2(0.0, 0.0),
x,
y,
x,
x);
200 template <
typename Float2,
typename FloatN>
// Constructor: initialises `a` from the first argument; `b` and `c` are
// accepted but not used by this constructor.
203 ax_(
const Float2 &
a,
const Float2 &
b,
const Float2 &
c) :
a(
a) { ; }
// Scales x in place by a.x. Only the .x component of the coefficient is read
// here; y, z and w are not touched by this body.
204 __device__ __host__
void operator()(FloatN &
x, FloatN &
y, FloatN &
z, FloatN &
w) {
x *=
a.x; }
210 blasCuda<ax_,1,0,0,0>(make_double2(
a, 0.0), make_double2(0.0, 0.0),
211 make_double2(0.0, 0.0),
x,
x,
x,
x);
// float4 overload: accumulates a*x into y. The arithmetic is a complex
// multiply-accumulate with .x as the real and .y as the imaginary part of `a`,
// applied independently to the (x, y) component pair and the (z, w) component
// pair of the float4 operands.
219 __device__ __host__
void _caxpy(
const float2 &
a,
const float4 &
x, float4 &
y) {
220 y.x +=
a.x*
x.x;
y.x -=
a.y*
x.y;
221 y.y +=
a.y*
x.x;
y.y +=
a.x*
x.y;
222 y.z +=
a.x*
x.z;
y.z -=
a.y*
x.w;
223 y.w +=
a.y*
x.z;
y.w +=
a.x*
x.w;
// float2 overload: accumulates a*x into y, where the arithmetic is a single
// complex multiply-accumulate with .x as the real and .y as the imaginary part.
226 __device__ __host__
void _caxpy(
const float2 &
a,
const float2 &
x, float2 &
y) {
227 y.x +=
a.x*
x.x;
y.x -=
a.y*
x.y;
228 y.y +=
a.y*
x.x;
y.y +=
a.x*
x.y;
// double2 overload: same complex multiply-accumulate as the float2 form
// (y += a*x with .x real, .y imaginary), carried out on double2 operands.
231 __device__ __host__
void _caxpy(
const double2 &
a,
const double2 &
x, double2 &
y) {
232 y.x +=
a.x*
x.x;
y.x -=
a.y*
x.y;
233 y.y +=
a.y*
x.x;
y.y +=
a.x*
x.y;
236 template <
typename Float2,
typename FloatN>
// Constructor: initialises `a` from the first argument; `b` and `c` are
// accepted but not used by this constructor.
239 caxpy_(
const Float2 &
a,
const Float2 &
b,
const Float2 &
c) :
a(
a) { ; }
247 if (
x.Precision() !=
y.Precision()) {
248 mixed::blasCuda<caxpy_,0,1,0,0>(make_double2(real(
a),imag(
a)), make_double2(0.0, 0.0),
249 make_double2(0.0, 0.0),
x,
y,
x,
x);
251 blasCuda<caxpy_,0,1,0,0>(make_double2(real(
a),imag(
a)), make_double2(0.0, 0.0),
252 make_double2(0.0, 0.0),
x,
y,
x,
x);
261 __device__ __host__
void _caxpby(
const float2 &
a,
const float4 &
x,
const float2 &
b, float4 &
y)
263 yy.x =
a.x*
x.x; yy.x -=
a.y*
x.y; yy.x +=
b.x*
y.x; yy.x -=
b.y*
y.y;
264 yy.y =
a.y*
x.x; yy.y +=
a.x*
x.y; yy.y +=
b.y*
y.x; yy.y +=
b.x*
y.y;
265 yy.z =
a.x*
x.z; yy.z -=
a.y*
x.w; yy.z +=
b.x*
y.z; yy.z -=
b.y*
y.w;
266 yy.w =
a.y*
x.z; yy.w +=
a.x*
x.w; yy.w +=
b.y*
y.z; yy.w +=
b.x*
y.w;
269 __device__ __host__
void _caxpby(
const float2 &
a,
const float2 &
x,
const float2 &
b, float2 &
y)
271 yy.x =
a.x*
x.x; yy.x -=
a.y*
x.y; yy.x +=
b.x*
y.x; yy.x -=
b.y*
y.y;
272 yy.y =
a.y*
x.x; yy.y +=
a.x*
x.y; yy.y +=
b.y*
y.x; yy.y +=
b.x*
y.y;
275 __device__ __host__
void _caxpby(
const double2 &
a,
const double2 &
x,
const double2 &
b, double2 &
y)
277 yy.x =
a.x*
x.x; yy.x -=
a.y*
x.y; yy.x +=
b.x*
y.x; yy.x -=
b.y*
y.y;
278 yy.y =
a.y*
x.x; yy.y +=
a.x*
x.y; yy.y +=
b.y*
y.x; yy.y +=
b.x*
y.y;
281 template <
typename Float2,
typename FloatN>
294 make_double2(0.0, 0.0),
x,
y,
x,
x);
301 __device__ __host__
void _cxpaypbz(
const float4 &
x,
const float2 &
a,
const float4 &
y,
const float2 &
b, float4 &
z) {
303 zz.x =
x.x +
a.x*
y.x; zz.x -=
a.y*
y.y; zz.x +=
b.x*
z.x; zz.x -=
b.y*
z.y;
304 zz.y =
x.y +
a.y*
y.x; zz.y +=
a.x*
y.y; zz.y +=
b.y*
z.x; zz.y +=
b.x*
z.y;
305 zz.z =
x.z +
a.x*
y.z; zz.z -=
a.y*
y.w; zz.z +=
b.x*
z.z; zz.z -=
b.y*
z.w;
306 zz.w =
x.w +
a.y*
y.z; zz.w +=
a.x*
y.w; zz.w +=
b.y*
z.z; zz.w +=
b.x*
z.w;
310 __device__ __host__
void _cxpaypbz(
const float2 &
x,
const float2 &
a,
const float2 &
y,
const float2 &
b, float2 &
z) {
312 zz.x =
x.x +
a.x*
y.x; zz.x -=
a.y*
y.y; zz.x +=
b.x*
z.x; zz.x -=
b.y*
z.y;
313 zz.y =
x.y +
a.y*
y.x; zz.y +=
a.x*
y.y; zz.y +=
b.y*
z.x; zz.y +=
b.x*
z.y;
317 __device__ __host__
void _cxpaypbz(
const double2 &
x,
const double2 &
a,
const double2 &
y,
const double2 &
b, double2 &
z) {
319 zz.x =
x.x +
a.x*
y.x; zz.x -=
a.y*
y.y; zz.x +=
b.x*
z.x; zz.x -=
b.y*
z.y;
320 zz.y =
x.y +
a.y*
y.x; zz.y +=
a.x*
y.y; zz.y +=
b.y*
z.x; zz.y +=
b.x*
z.y;
324 template <
typename Float2,
typename FloatN>
338 make_double2(0.0, 0.0),
x,
y,
z,
z);
344 template <
typename Float2,
typename FloatN>
358 if (
x.Precision() !=
y.Precision()) {
360 mixed::blasCuda<axpyBzpcx_,1,1,0,0>(make_double2(
a,0.0), make_double2(
b,0.0),
361 make_double2(
c,0.0),
x,
y,
z,
x);
364 blasCuda<axpyBzpcx_,1,1,0,0>(make_double2(
a,0.0), make_double2(
b,0.0),
365 make_double2(
c,0.0),
x,
y,
z,
x);
373 template <
typename Float2,
typename FloatN>
386 if (
x.Precision() !=
y.Precision()) {
388 mixed::blasCuda<axpyZpbx_,1,1,0,0>(make_double2(
a,0.0), make_double2(
b,0.0), make_double2(0.0,0.0),
392 blasCuda<axpyZpbx_,1,1,0,0>(make_double2(
a,0.0), make_double2(
b,0.0), make_double2(0.0,0.0),
400 template <
typename Float2,
typename FloatN>
414 if (
x.Precision() !=
y.Precision()) {
415 mixed::blasCuda<caxpyBzpx_,1,1,0,0>(make_double2(
REAL(
a),
IMAG(
a)), make_double2(
REAL(
b),
IMAG(
b)),
416 make_double2(0.0,0.0),
x,
y,
z,
x);
419 make_double2(0.0,0.0),
x,
y,
z,
x);
426 template <
typename Float2,
typename FloatN>
440 if (
x.Precision() !=
y.Precision()) {
441 mixed::blasCuda<caxpyBxpz_,0,1,1,0>(make_double2(
REAL(
a),
IMAG(
a)), make_double2(
REAL(
b),
IMAG(
b)),
442 make_double2(0.0,0.0),
x,
y,
z,
x);
445 make_double2(0.0,0.0),
x,
y,
z,
x);
452 template <
typename Float2,
typename FloatN>
467 make_double2(0.0,0.0),
x,
y,
z,
w);
473 template <
typename Float2,
typename FloatN>
487 blasCuda<cabxpyAx_,1,1,0,0>(make_double2(
a,0.0), make_double2(
REAL(
b),
IMAG(
b)),
488 make_double2(0.0,0.0),
x,
y,
x,
x);
494 template <
typename Float2,
typename FloatN>
508 make_double2(0.0,0.0),
x,
y,
z,
z);
514 template <
typename Float2,
typename FloatN>
539 template <
typename Float2,
typename FloatN>
551 blasCuda<caxpyxmaz_,1,1,0,0>(make_double2(
REAL(
a),
IMAG(
a)), make_double2(0.0, 0.0),
552 make_double2(0.0, 0.0),
x,
y,
z,
x);
560 template <
typename Float2,
typename FloatN>
567 inline __device__ __host__
void init() {
569 typedef decltype(
a.x) real;
570 double3 result = __ldg(
Ar3);
571 a.y =
a.x * (real)(result.y) * ((real)1.0 / (real)result.z);
572 a.x =
a.x * (real)(result.x) * ((real)1.0 / (real)result.z);
586 errorQuda(
"This kernel requires asynchronous reductions to be set");
588 errorQuda(
"This kernel cannot be run on CPU fields");
590 blasCuda<caxpyxmazMR_,1,1,0,0>(make_double2(
REAL(
a),
IMAG(
a)), make_double2(0.0, 0.0),
591 make_double2(0.0, 0.0),
x,
y,
z,
x);
600 template <
typename Float2,
typename FloatN>
612 if (
x.Precision() !=
y.Precision()) {
614 mixed::blasCuda<tripleCGUpdate_,0,1,1,1>(make_double2(
a,0.0), make_double2(
b,0.0),
615 make_double2(0.0,0.0),
x,
y,
z,
w);
617 blasCuda<tripleCGUpdate_,0,1,1,1>(make_double2(
a, 0.0), make_double2(
b, 0.0),
618 make_double2(0.0, 0.0),
x,
y,
z,
w);
caxpbypzYmbw_(const Float2 &a, const Float2 &b, const Float2 &c)
tripleCGUpdate_(const Float2 &a, const Float2 &b, const Float2 &c)
static int flops()
total number of input and output streams
axpyBzpcx_(const Float2 &a, const Float2 &b, const Float2 &c)
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
where the reduction is usually computed and any auxiliary operations
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
where the reduction is usually computed and any auxiliary operations
void xpay(ColorSpinorField &x, const double &a, ColorSpinorField &y)
__device__ __host__ void _caxpby(const float2 &a, const float4 &x, const float2 &b, float4 &y)
void caxpyXmazMR(const Complex &a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
where the reduction is usually computed and any auxiliary operations
bool commAsyncReduction()
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
where the reduction is usually computed and any auxiliary operations
static int flops()
total number of input and output streams
char aux_tmp[TuneKey::aux_n]
caxpyxmaz_(const Float2 &a, const Float2 &b, const Float2 &c)
std::complex< double > Complex
static int flops()
total number of input and output streams
void xpayz(ColorSpinorField &x, const double &a, ColorSpinorField &y, ColorSpinorField &z)
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
where the reduction is usually computed and any auxiliary operations
caxpyBzpx_(const Float2 &a, const Float2 &b, const Float2 &c)
mxpy_(const Float2 &a, const Float2 &b, const Float2 &c)
void ax(const double &a, ColorSpinorField &x)
static int flops()
total number of input and output streams
caxpyBxpz_(const Float2 &a, const Float2 &b, const Float2 &c)
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
where the reduction is usually computed and any auxiliary operations
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
where the reduction is usually computed and any auxiliary operations
caxpbypczpw_(const Float2 &a, const Float2 &b, const Float2 &c)
virtual __device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)=0
where the reduction is usually computed and any auxiliary operations
static int flops()
total number of input and output streams
void caxpyBzpx(const Complex &, ColorSpinorField &, ColorSpinorField &, const Complex &, ColorSpinorField &)
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
where the reduction is usually computed and any auxiliary operations
static int flops()
total number of input and output streams
void caxpyBxpz(const Complex &, ColorSpinorField &, ColorSpinorField &, const Complex &, ColorSpinorField &)
xpayz_(const Float2 &a, const Float2 &b, const Float2 &c)
__device__ __host__ void _caxpy(const float2 &a, const float4 &x, float4 &y)
cudaStream_t * getStream()
caxpy_(const Float2 &a, const Float2 &b, const Float2 &c)
void cabxpyAx(const double &a, const Complex &b, ColorSpinorField &x, ColorSpinorField &y)
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
where the reduction is usually computed and any auxiliary operations
virtual __device__ __host__ void init()
pre-computation routine before the main loop
static cudaStream_t * blasStream
void axpyZpbx(const double &a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, const double &b)
static int flops()
total number of input and output streams
caxpbypz_(const Float2 &a, const Float2 &b, const Float2 &c)
static struct quda::blas::@4 blasStrings
void caxpbypzYmbw(const Complex &, ColorSpinorField &, const Complex &, ColorSpinorField &, ColorSpinorField &, ColorSpinorField &)
static int flops()
total number of input and output streams
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
where the reduction is usually computed and any auxiliary operations
static int flops()
total number of input and output streams
void tripleCGUpdate(const double &alpha, const double &beta, ColorSpinorField &q, ColorSpinorField &r, ColorSpinorField &x, ColorSpinorField &p)
cxpaypbz_(const Float2 &a, const Float2 &b, const Float2 &c)
static int flops()
total number of input and output streams
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
where the reduction is usually computed and any auxiliary operations
axpyZpbx_(const Float2 &a, const Float2 &b, const Float2 &c)
static int flops()
total number of input and output streams
static int flops()
total number of input and output streams
void caxpy(const Complex &a, ColorSpinorField &x, ColorSpinorField &y)
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
where the reduction is usually computed and any auxiliary operations
void zero(ColorSpinorField &a)
ax_(const Float2 &a, const Float2 &b, const Float2 &c)
void caxpbypczpw(const Complex &, ColorSpinorField &, const Complex &, ColorSpinorField &, const Complex &, ColorSpinorField &, ColorSpinorField &)
void axpy(const double &a, ColorSpinorField &x, ColorSpinorField &y)
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
where the reduction is usually computed and any auxiliary operations
static int flops()
total number of input and output streams
cabxpyAx_(const Float2 &a, const Float2 &b, const Float2 &c)
static int flops()
total number of input and output streams
__device__ __host__ void init()
pre-computation routine before the main loop
void axpby(const double &a, ColorSpinorField &x, const double &b, ColorSpinorField &y)
axpby_(const Float2 &a, const Float2 &b, const Float2 &c)
void caxpbypz(const Complex &, ColorSpinorField &, const Complex &, ColorSpinorField &, ColorSpinorField &)
static int flops()
total number of input and output streams
xpy_(const Float2 &a, const Float2 &b, const Float2 &c)
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
where the reduction is usually computed and any auxiliary operations
void axpyBzpcx(const double &a, ColorSpinorField &x, ColorSpinorField &y, const double &b, ColorSpinorField &z, const double &c)
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
where the reduction is usually computed and any auxiliary operations
void caxpyXmaz(const Complex &a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
void * getDeviceReduceBuffer()
__device__ __host__ void _cxpaypbz(const float4 &x, const float2 &a, const float4 &y, const float2 &b, float4 &z)
static int flops()
total number of input and output streams
axpy_(const Float2 &a, const Float2 &b, const Float2 &c)
void xpy(ColorSpinorField &x, ColorSpinorField &y)
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
where the reduction is usually computed and any auxiliary operations
static int flops()
total number of input and output streams
void caxpby(const Complex &a, ColorSpinorField &x, const Complex &b, ColorSpinorField &y)
caxpyxmazMR_(const Float2 &a, const Float2 &b, const Float2 &c)
void mxpy(ColorSpinorField &x, ColorSpinorField &y)
void cxpaypbz(ColorSpinorField &, const Complex &b, ColorSpinorField &y, const Complex &c, ColorSpinorField &z)
static int flops()
total number of input and output streams
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
where the reduction is usually computed and any auxiliary operations
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
where the reduction is usually computed and any auxiliary operations
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
where the reduction is usually computed and any auxiliary operations
static int flops()
total number of input and output streams
caxpby_(const Float2 &a, const Float2 &b, const Float2 &c)
static int flops()
total number of input and output streams
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
where the reduction is usually computed and any auxiliary operations