15 #define checkSpinor(a, b) \
17 if (a.Precision() != b.Precision()) \
18 errorQuda("precisions do not match: %d %d", a.Precision(), b.Precision()); \
19 if (a.Length() != b.Length()) \
20 errorQuda("lengths do not match: %d %d", a.Length(), b.Length()); \
21 if (a.Stride() != b.Stride()) \
22 errorQuda("strides do not match: %d %d", a.Stride(), b.Stride()); \
37 static cudaStream_t *blasStream;
73 template <
typename Float2,
typename FloatN>
77 axpby(
const Float2 &
a,
const Float2 &
b,
const Float2 &c) : a(a), b(b) { ; }
80 static int flops() {
return 3; }
84 blasCuda<axpby,0,1,0,0>(make_double2(a, 0.0), make_double2(b, 0.0), make_double2(0.0, 0.0),
91 template <
typename Float2,
typename FloatN>
93 xpy(
const Float2 &a,
const Float2 &b,
const Float2 &c) { ; }
96 static int flops() {
return 1; }
100 blasCuda<xpy,0,1,0,0>(make_double2(1.0, 0.0), make_double2(1.0, 0.0), make_double2(0.0, 0.0),
107 template <
typename Float2,
typename FloatN>
110 axpy(
const Float2 &
a,
const Float2 &b,
const Float2 &c) : a(a) { ; }
117 blasCuda<axpy,0,1,0,0>(make_double2(a, 0.0), make_double2(1.0, 0.0), make_double2(0.0, 0.0),
124 template <
typename Float2,
typename FloatN>
127 xpay(
const Float2 &
a,
const Float2 &b,
const Float2 &c) : a(a) { ; }
134 blasCuda<xpay,0,1,0,0>(make_double2(a,0.0), make_double2(0.0, 0.0), make_double2(0.0, 0.0),
141 template <
typename Float2,
typename FloatN>
143 mxpy(
const Float2 &a,
const Float2 &b,
const Float2 &c) { ; }
150 blasCuda<mxpy,0,1,0,0>(make_double2(1.0, 0.0), make_double2(1.0, 0.0),
151 make_double2(0.0, 0.0),
x, y,
x,
x);
157 template <
typename Float2,
typename FloatN>
160 ax(
const Float2 &
a,
const Float2 &b,
const Float2 &c) : a(a) { ; }
167 blasCuda<ax,1,0,0,0>(make_double2(a, 0.0), make_double2(0.0, 0.0),
168 make_double2(0.0, 0.0),
x,
x,
x,
x);
175 __device__
void caxpy_(
const float2 &a,
const float4 &
x, float4 &y) {
176 y.x += a.x*x.x; y.x -= a.y*x.y;
177 y.y += a.y*x.x; y.y += a.x*x.y;
178 y.z += a.x*x.z; y.z -= a.y*x.w;
179 y.w += a.y*x.z; y.w += a.x*x.w;
182 __device__
void caxpy_(
const float2 &a,
const float2 &
x, float2 &y) {
183 y.x += a.x*x.x; y.x -= a.y*x.y;
184 y.y += a.y*x.x; y.y += a.x*x.y;
187 __device__
void caxpy_(
const double2 &a,
const double2 &
x, double2 &y) {
188 y.x += a.x*x.x; y.x -= a.y*x.y;
189 y.y += a.y*x.x; y.y += a.x*x.y;
192 template <
typename Float2,
typename FloatN>
195 caxpy(
const Float2 &
a,
const Float2 &b,
const Float2 &c) : a(a) { ; }
202 blasCuda<caxpy,0,1,0,0>(make_double2(real(a),imag(a)), make_double2(0.0, 0.0),
203 make_double2(0.0, 0.0),
x, y,
x,
x);
210 __device__
void caxpby_(
const float2 &a,
const float4 &
x,
const float2 &b, float4 &y)
212 yy.x = a.x*x.x; yy.x -= a.y*x.y; yy.x += b.x*y.x; yy.x -= b.y*y.y;
213 yy.y = a.y*x.x; yy.y += a.x*x.y; yy.y += b.y*y.x; yy.y += b.x*y.y;
214 yy.z = a.x*x.z; yy.z -= a.y*x.w; yy.z += b.x*y.z; yy.z -= b.y*y.w;
215 yy.w = a.y*x.z; yy.w += a.x*x.w; yy.w += b.y*y.z; yy.w += b.x*y.w;
218 __device__
void caxpby_(
const float2 &a,
const float2 &
x,
const float2 &b, float2 &y)
220 yy.x = a.x*x.x; yy.x -= a.y*x.y; yy.x += b.x*y.x; yy.x -= b.y*y.y;
221 yy.y = a.y*x.x; yy.y += a.x*x.y; yy.y += b.y*y.x; yy.y += b.x*y.y;
224 __device__
void caxpby_(
const double2 &a,
const double2 &
x,
const double2 &b, double2 &y)
226 yy.x = a.x*x.x; yy.x -= a.y*x.y; yy.x += b.x*y.x; yy.x -= b.y*y.y;
227 yy.y = a.y*x.x; yy.y += a.x*x.y; yy.y += b.y*y.x; yy.y += b.x*y.y;
230 template <
typename Float2,
typename FloatN>
234 caxpby(
const Float2 &
a,
const Float2 &
b,
const Float2 &c) : a(a), b(b) { ; }
241 blasCuda<caxpby,0,1,0,0>(make_double2(a.real(),a.imag()), make_double2(b.real(), b.imag()),
242 make_double2(0.0, 0.0),
x, y,
x,
x);
249 __device__
void cxpaypbz_(
const float4 &
x,
const float2 &a,
const float4 &y,
const float2 &b, float4 &z) {
251 zz.x = x.x + a.x*y.x; zz.x -= a.y*y.y; zz.x += b.x*z.x; zz.x -= b.y*z.y;
252 zz.y = x.y + a.y*y.x; zz.y += a.x*y.y; zz.y += b.y*z.x; zz.y += b.x*z.y;
253 zz.z = x.z + a.x*y.z; zz.z -= a.y*y.w; zz.z += b.x*z.z; zz.z -= b.y*z.w;
254 zz.w = x.w + a.y*y.z; zz.w += a.x*y.w; zz.w += b.y*z.z; zz.w += b.x*z.w;
258 __device__
void cxpaypbz_(
const float2 &
x,
const float2 &a,
const float2 &y,
const float2 &b, float2 &z) {
260 zz.x = x.x + a.x*y.x; zz.x -= a.y*y.y; zz.x += b.x*z.x; zz.x -= b.y*z.y;
261 zz.y = x.y + a.y*y.x; zz.y += a.x*y.y; zz.y += b.y*z.x; zz.y += b.x*z.y;
265 __device__
void cxpaypbz_(
const double2 &
x,
const double2 &a,
const double2 &y,
const double2 &b, double2 &z) {
267 zz.x = x.x + a.x*y.x; zz.x -= a.y*y.y; zz.x += b.x*z.x; zz.x -= b.y*z.y;
268 zz.y = x.y + a.y*y.x; zz.y += a.x*y.y; zz.y += b.y*z.x; zz.y += b.x*z.y;
272 template <
typename Float2,
typename FloatN>
276 cxpaypbz(
const Float2 &
a,
const Float2 &
b,
const Float2 &c) : a(a), b(b) { ; }
285 blasCuda<cxpaypbz,0,0,1,0>(make_double2(a.real(),a.imag()), make_double2(b.real(), b.imag()),
286 make_double2(0.0, 0.0),
x, y, z, z);
292 template <
typename Float2,
typename FloatN>
297 axpyBzpcx(
const Float2 &
a,
const Float2 &
b,
const Float2 &
c) : a(a), b(b), c(c) { ; }
299 { y +=
a.x*
x; x =
b.x*z +
c.x*
x; }
306 blasCuda<axpyBzpcx,1,1,0,0>(make_double2(a,0.0), make_double2(b,0.0), make_double2(c,0.0),
313 template <
typename Float2,
typename FloatN>
317 axpyZpbx(
const Float2 &
a,
const Float2 &
b,
const Float2 &c) : a(a), b(b) { ; }
319 { y +=
a.x*
x; x = z +
b.x*
x; }
327 blasCuda<axpyZpbx,1,1,0,0>(make_double2(a,0.0), make_double2(b,0.0), make_double2(0.0,0.0),
334 template <
typename Float2,
typename FloatN>
338 caxpbypzYmbw(
const Float2 &
a,
const Float2 &
b,
const Float2 &c) : a(a), b(b) { ; }
348 blasCuda<caxpbypzYmbw,0,1,1,0>(make_double2(a.real(),a.imag()), make_double2(b.real(), b.imag()),
349 make_double2(0.0,0.0),
x, y, z, w);
355 template <
typename Float2,
typename FloatN>
359 cabxpyAx(
const Float2 &
a,
const Float2 &
b,
const Float2 &c) : a(a), b(b) { ; }
369 blasCuda<cabxpyAx,1,1,0,0>(make_double2(a,0.0), make_double2(b.real(),b.imag()),
370 make_double2(0.0,0.0),
x, y,
x,
x);
376 template <
typename Float2,
typename FloatN>
380 caxpbypz(
const Float2 &
a,
const Float2 &
b,
const Float2 &c) : a(a), b(b) { ; }
389 blasCuda<caxpbypz,0,0,1,0>(make_double2(a.real(),a.imag()), make_double2(b.real(),b.imag()),
390 make_double2(0.0,0.0),
x, y, z, z);
396 template <
typename Float2,
typename FloatN>
401 caxpbypczpw(
const Float2 &
a,
const Float2 &
b,
const Float2 &
c) : a(a), b(b), c(c) { ; }
412 blasCuda<caxpbypczpw,0,0,0,1>(make_double2(a.real(),a.imag()), make_double2(b.real(),b.imag()),
413 make_double2(c.real(), c.imag()), x, y, z, w);
422 template <
typename Float2,
typename FloatN>
425 caxpyxmaz(
const Float2 &
a,
const Float2 &b,
const Float2 &c) : a(a) { ; }
434 blasCuda<caxpyxmaz,1,1,0,0>(make_double2(a.real(), a.imag()), make_double2(0.0, 0.0),
435 make_double2(0.0, 0.0),
x, y, z,
x);
445 template <
typename Float2,
typename FloatN>
450 { y -=
a.x*
x; z +=
a.x*w; w = y +
b.x*w; }
457 blasCuda<tripleCGUpdate,0,1,1,1>(make_double2(a, 0.0), make_double2(b, 0.0),
458 make_double2(0.0, 0.0),
x, y, z, w);