QUDA  0.9.0
blas_quda.cu
Go to the documentation of this file.
1 
2 #include <stdlib.h>
3 #include <stdio.h>
4 #include <cstring> // needed for memset
5 
6 
7 
8 #include <tune_quda.h>
9 #include <typeinfo>
10 
11 #include <quda_internal.h>
12 #include <float_vector.h>
13 #include <blas_quda.h>
14 #include <color_spinor_field.h>
16 
// checkSpinor(a, b): report an error through errorQuda when the two fields
// differ in precision, length or stride.
// Wrapped in do { ... } while (0) so that the macro behaves like a single
// statement (safe inside un-braced if/else); arguments are parenthesized so
// that expression arguments expand correctly.
#define checkSpinor(a, b)                                                              \
  do {                                                                                 \
    if ((a).Precision() != (b).Precision())                                            \
      errorQuda("precisions do not match: %d %d", (a).Precision(), (b).Precision());   \
    if ((a).Length() != (b).Length())                                                  \
      errorQuda("lengths do not match: %lu %lu", (a).Length(), (b).Length());          \
    if ((a).Stride() != (b).Stride())                                                  \
      errorQuda("strides do not match: %d %d", (a).Stride(), (b).Stride());            \
  } while (0)
26 
// checkLength(a, b): report an error through errorQuda when the two fields
// differ in length or stride (precision is deliberately not compared).
// Wrapped in do { ... } while (0) so that the macro behaves like a single
// statement (safe inside un-braced if/else); arguments are parenthesized.
#define checkLength(a, b)                                                              \
  do {                                                                                 \
    if ((a).Length() != (b).Length())                                                  \
      errorQuda("lengths do not match: %lu %lu", (a).Length(), (b).Length());          \
    if ((a).Stride() != (b).Stride())                                                  \
      errorQuda("strides do not match: %d %d", (a).Stride(), (b).Stride());            \
  } while (0)
34 
35 namespace quda {
36 
37  namespace blas {
38 
39 #define BLAS_SPINOR // do not include ghost functions in Spinor class to reduce parameter space overhead
40 #include <texture.h>
41 
// Module-wide running totals of floating-point operations and bytes moved.
// NOTE(review): nothing in this listing updates them; presumably they are
// accumulated by the launch code pulled in from blas_core.h -- confirm.
42  unsigned long long flops;
43  unsigned long long bytes;
44 
46  if (typeid(a) == typeid(cudaColorSpinorField)) {
47  static_cast<cudaColorSpinorField&>(a).zero();
48  } else {
49  static_cast<cpuColorSpinorField&>(a).zero();
50  }
51  }
52 
// File-local stream pointer handed out by getStream().
// NOTE(review): its assignment is not visible in this listing -- confirm where it is set.
53  static cudaStream_t *blasStream;
54 
55  static struct {
56  const char *vol_str;
57  const char *aux_str;
59  } blasStrings;
60 
// Set-up / tear-down of the reduction module, defined in another translation
// unit (reduce_quda.cu per the doxygen cross-reference); init() and end() in
// this file delegate to them.
61  void initReduce();
62  void endReduce();
63 
64  void init()
65  {
67  initReduce();
68  }
69 
// Tear down the blas module; all work is delegated to endReduce().
void end() { endReduce(); }
74 
// Accessor for the file-local blas stream pointer.
cudaStream_t* getStream()
{
  return blasStream;
}
76 
77 #include <blas_core.cuh>
78 
79 #include <blas_core.h>
80 #include <blas_mixed_core.h>
81 
/**
   Common interface for the element-wise blas functors in this file.
   Concrete functors override operator() to update their four operands
   in place; init() is an optional hook whose default does nothing.
*/
template <typename Float2, typename FloatN>
struct BlasFunctor {

  //! pre-computation routine before the main loop (default: no-op)
  virtual __host__ __device__ void init() {}

  //! element-wise operation applied to the operand set (x, y, z, w)
  virtual __host__ __device__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w) = 0;
};
91 
/**
   Real axpby: y = a*x + b*y, using only the real parts a.x and b.x.
   The third constructor coefficient is ignored.
*/
template <typename Float2, typename FloatN>
struct axpby_ : public BlasFunctor<Float2,FloatN> {
  const Float2 a; //!< coefficient of x
  const Float2 b; //!< coefficient of y

  axpby_(const Float2 &alpha, const Float2 &beta, const Float2 & /*unused*/)
    : a(alpha), b(beta) {}

  __host__ __device__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
  {
    y = a.x*x + b.x*y;
  }

  //! streams per element: read x, read y, write y
  static int streams() { return 3; }
  //! flops per element: two multiplies and one add
  static int flops() { return 3; }
};
105 
// y = a*x + b*y. The mixed-precision kernel is used when x and y have
// different precisions, the uniform kernel otherwise.
void axpby(const double &a, ColorSpinorField &x, const double &b, ColorSpinorField &y)
{
  const double2 alpha = make_double2(a, 0.0);
  const double2 beta = make_double2(b, 0.0);
  const double2 unused = make_double2(0.0, 0.0);

  if (x.Precision() == y.Precision()) {
    blasCuda<axpby_,0,1,0,0>(alpha, beta, unused, x, y, x, x);
  } else {
    // call hacked mixed precision kernel
    mixed::blasCuda<axpby_,0,1,0,0>(alpha, beta, unused, x, y, x, x);
  }
}
116 
/**
   y = x + y. No coefficients are needed; all constructor arguments are ignored.
*/
template <typename Float2, typename FloatN>
struct xpy_ : public BlasFunctor<Float2,FloatN> {
  xpy_(const Float2 &, const Float2 &, const Float2 &) {}

  __host__ __device__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
  {
    y += x;
  }

  static int streams() { return 3; } //!< read x, read y, write y
  static int flops() { return 1; }   //!< one add per element
};
127 
129  if (x.Precision() != y.Precision()) {
130  mixed::blasCuda<xpy_,0,1,0,0>(make_double2(1.0, 0.0), make_double2(1.0, 0.0),
131  make_double2(0.0, 0.0), x, y, x, x);
132  } else {
133  blasCuda<xpy_,0,1,0,0>(make_double2(1.0, 0.0), make_double2(1.0, 0.0),
134  make_double2(0.0, 0.0), x, y, x, x);
135  }
136  }
137 
/**
   Real axpy: y = a*x + y, using only a.x. The remaining constructor
   coefficients are ignored.
*/
template <typename Float2, typename FloatN>
struct axpy_ : public BlasFunctor<Float2,FloatN> {
  const Float2 a; //!< coefficient of x

  axpy_(const Float2 &alpha, const Float2 &, const Float2 &) : a(alpha) {}

  __host__ __device__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
  {
    y = a.x*x + y;
  }

  static int streams() { return 3; } //!< read x, read y, write y
  static int flops() { return 2; }   //!< multiply + add per element
};
149 
// y = a*x + y, using the mixed-precision kernel when x and y differ in precision.
void axpy(const double &a, ColorSpinorField &x, ColorSpinorField &y)
{
  const double2 alpha = make_double2(a, 0.0);
  const double2 one = make_double2(1.0, 0.0);
  const double2 unused = make_double2(0.0, 0.0);

  if (x.Precision() == y.Precision()) {
    blasCuda<axpy_,0,1,0,0>(alpha, one, unused, x, y, x, x);
  } else {
    // call hacked mixed precision kernel
    mixed::blasCuda<axpy_,0,1,0,0>(alpha, one, unused, x, y, x, x);
  }
}
160 
/**
   z = x + a*y, using only a.x. Only z is written.
*/
template <typename Float2, typename FloatN>
struct xpayz_ : public BlasFunctor<Float2,FloatN> {
  const Float2 a; //!< coefficient of y

  xpayz_(const Float2 &alpha, const Float2 &, const Float2 &) : a(alpha) {}

  __host__ __device__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
  {
    z = x + a.x*y;
  }

  static int streams() { return 3; } //!< read x, read y, write z
  static int flops() { return 2; }   //!< multiply + add per element
};
172 
// y = x + a*y: xpayz_ with its output operand z bound to y.
void xpay(ColorSpinorField &x, const double &a, ColorSpinorField &y)
{
  const double2 alpha = make_double2(a, 0.0);
  const double2 unused = make_double2(0.0, 0.0);
  blasCuda<xpayz_,0,0,1,0>(alpha, unused, unused, x, y, y, x);
}
176 
178  blasCuda<xpayz_,0,0,1,0>(make_double2(a,0.0), make_double2(0.0, 0.0), make_double2(0.0, 0.0), x, y, z, x);
179  }
180 
/**
   y = y - x. All constructor arguments are ignored.
*/
template <typename Float2, typename FloatN>
struct mxpy_ : public BlasFunctor<Float2,FloatN> {
  mxpy_(const Float2 &, const Float2 &, const Float2 &) {}

  __host__ __device__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
  {
    y -= x;
  }

  static int streams() { return 3; } //!< read x, read y, write y
  static int flops() { return 1; }   //!< one subtraction per element
};
191 
193  blasCuda<mxpy_,0,1,0,0>(make_double2(1.0, 0.0), make_double2(1.0, 0.0),
194  make_double2(0.0, 0.0), x, y, x, x);
195  }
196 
/**
   In-place real scaling x = a*x, using only a.x.
*/
template <typename Float2, typename FloatN>
struct ax_ : public BlasFunctor<Float2,FloatN> {
  const Float2 a; //!< scale factor

  ax_(const Float2 &alpha, const Float2 &, const Float2 &) : a(alpha) {}

  __host__ __device__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
  {
    x *= a.x;
  }

  static int streams() { return 2; } //!< read x, write x
  static int flops() { return 1; }   //!< one multiply per element
};
208 
// x = a*x (x is the only field touched; it fills every operand slot).
void ax(const double &a, ColorSpinorField &x)
{
  const double2 alpha = make_double2(a, 0.0);
  const double2 unused = make_double2(0.0, 0.0);
  blasCuda<ax_,1,0,0,0>(alpha, unused, unused, x, x, x, x);
}
213 
214 
// _caxpy: complex in-place update y += a * x, where a is one complex number
// (a.x = real, a.y = imaginary) and each (re, im) pair of x and y is one
// complex number.
// NOTE: the statement order is intentional; reordering the real/imaginary
// accumulations may change floating-point rounding of the result.
// float4 overload: (x, y) and (z, w) are two independent complex numbers.
219  __device__ __host__ void _caxpy(const float2 &a, const float4 &x, float4 &y) {
220  y.x += a.x*x.x; y.x -= a.y*x.y;
221  y.y += a.y*x.x; y.y += a.x*x.y;
222  y.z += a.x*x.z; y.z -= a.y*x.w;
223  y.w += a.y*x.z; y.w += a.x*x.w;
224  }
225 
// float2 overload: a single complex number (x, y).
226  __device__ __host__ void _caxpy(const float2 &a, const float2 &x, float2 &y) {
227  y.x += a.x*x.x; y.x -= a.y*x.y;
228  y.y += a.y*x.x; y.y += a.x*x.y;
229  }
230 
// double2 overload: a single complex number (x, y) in double precision.
231  __device__ __host__ void _caxpy(const double2 &a, const double2 &x, double2 &y) {
232  y.x += a.x*x.x; y.x -= a.y*x.y;
233  y.y += a.y*x.x; y.y += a.x*x.y;
234  }
235 
/**
   Complex axpy: y = a*x + y with complex a (delegates to _caxpy).
   The remaining constructor coefficients are ignored.
*/
template <typename Float2, typename FloatN>
struct caxpy_ : public BlasFunctor<Float2,FloatN> {
  const Float2 a; //!< complex coefficient of x

  caxpy_(const Float2 &alpha, const Float2 &, const Float2 &) : a(alpha) {}

  __host__ __device__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
  {
    _caxpy(a, x, y);
  }

  static int streams() { return 3; } //!< read x, read y, write y
  static int flops() { return 4; }   //!< flops per element
};
245 
247  if (x.Precision() != y.Precision()) {
248  mixed::blasCuda<caxpy_,0,1,0,0>(make_double2(real(a),imag(a)), make_double2(0.0, 0.0),
249  make_double2(0.0, 0.0), x, y, x, x);
250  } else {
251  blasCuda<caxpy_,0,1,0,0>(make_double2(real(a),imag(a)), make_double2(0.0, 0.0),
252  make_double2(0.0, 0.0), x, y, x, x);
253  }
254  }
255 
256 
// _caxpby: complex update y = a*x + b*y, where a and b are single complex
// numbers and each (re, im) pair of x and y is one complex number.
// The result is built in a temporary (yy) so that the old value of y is used
// on the right-hand side throughout; statement order is kept deliberately
// since reordering may change floating-point rounding.
// float4 overload: (x, y) and (z, w) are two independent complex numbers.
261  __device__ __host__ void _caxpby(const float2 &a, const float4 &x, const float2 &b, float4 &y)
262  { float4 yy;
263  yy.x = a.x*x.x; yy.x -= a.y*x.y; yy.x += b.x*y.x; yy.x -= b.y*y.y;
264  yy.y = a.y*x.x; yy.y += a.x*x.y; yy.y += b.y*y.x; yy.y += b.x*y.y;
265  yy.z = a.x*x.z; yy.z -= a.y*x.w; yy.z += b.x*y.z; yy.z -= b.y*y.w;
266  yy.w = a.y*x.z; yy.w += a.x*x.w; yy.w += b.y*y.z; yy.w += b.x*y.w;
267  y = yy; }
268 
// float2 overload: a single complex number.
269  __device__ __host__ void _caxpby(const float2 &a, const float2 &x, const float2 &b, float2 &y)
270  { float2 yy;
271  yy.x = a.x*x.x; yy.x -= a.y*x.y; yy.x += b.x*y.x; yy.x -= b.y*y.y;
272  yy.y = a.y*x.x; yy.y += a.x*x.y; yy.y += b.y*y.x; yy.y += b.x*y.y;
273  y = yy; }
274 
// double2 overload: a single complex number in double precision.
275  __device__ __host__ void _caxpby(const double2 &a, const double2 &x, const double2 &b, double2 &y)
276  { double2 yy;
277  yy.x = a.x*x.x; yy.x -= a.y*x.y; yy.x += b.x*y.x; yy.x -= b.y*y.y;
278  yy.y = a.y*x.x; yy.y += a.x*x.y; yy.y += b.y*y.x; yy.y += b.x*y.y;
279  y = yy; }
280 
/**
   Complex axpby: y = a*x + b*y with complex a and b (delegates to _caxpby).
   The third constructor coefficient is ignored.
*/
template <typename Float2, typename FloatN>
struct caxpby_ : public BlasFunctor<Float2,FloatN> {
  const Float2 a; //!< complex coefficient of x
  const Float2 b; //!< complex coefficient of y

  caxpby_(const Float2 &alpha, const Float2 &beta, const Float2 &)
    : a(alpha), b(beta) {}

  __host__ __device__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
  {
    _caxpby(a, x, b, y);
  }

  static int streams() { return 3; } //!< read x, read y, write y
  static int flops() { return 7; }   //!< flops per element
};
291 
293  blasCuda<caxpby_,0,1,0,0>(make_double2(REAL(a),IMAG(a)), make_double2(REAL(b), IMAG(b)),
294  make_double2(0.0, 0.0), x, y, x, x);
295  }
296 
// _cxpaypbz: complex update z = x + a*y + b*z, where a and b are single
// complex numbers and each (re, im) pair of x, y, z is one complex number.
// The result is accumulated in a temporary (zz) so that the old z is used on
// the right-hand side; statement order is kept deliberately since reordering
// may change floating-point rounding.
// float4 overload: (x, y) and (z, w) are two independent complex numbers.
301  __device__ __host__ void _cxpaypbz(const float4 &x, const float2 &a, const float4 &y, const float2 &b, float4 &z) {
302  float4 zz;
303  zz.x = x.x + a.x*y.x; zz.x -= a.y*y.y; zz.x += b.x*z.x; zz.x -= b.y*z.y;
304  zz.y = x.y + a.y*y.x; zz.y += a.x*y.y; zz.y += b.y*z.x; zz.y += b.x*z.y;
305  zz.z = x.z + a.x*y.z; zz.z -= a.y*y.w; zz.z += b.x*z.z; zz.z -= b.y*z.w;
306  zz.w = x.w + a.y*y.z; zz.w += a.x*y.w; zz.w += b.y*z.z; zz.w += b.x*z.w;
307  z = zz;
308  }
309 
// float2 overload: a single complex number.
310  __device__ __host__ void _cxpaypbz(const float2 &x, const float2 &a, const float2 &y, const float2 &b, float2 &z) {
311  float2 zz;
312  zz.x = x.x + a.x*y.x; zz.x -= a.y*y.y; zz.x += b.x*z.x; zz.x -= b.y*z.y;
313  zz.y = x.y + a.y*y.x; zz.y += a.x*y.y; zz.y += b.y*z.x; zz.y += b.x*z.y;
314  z = zz;
315  }
316 
// double2 overload: a single complex number in double precision.
317  __device__ __host__ void _cxpaypbz(const double2 &x, const double2 &a, const double2 &y, const double2 &b, double2 &z) {
318  double2 zz;
319  zz.x = x.x + a.x*y.x; zz.x -= a.y*y.y; zz.x += b.x*z.x; zz.x -= b.y*z.y;
320  zz.y = x.y + a.y*y.x; zz.y += a.x*y.y; zz.y += b.y*z.x; zz.y += b.x*z.y;
321  z = zz;
322  }
323 
/**
   z = x + a*y + b*z with complex a and b (delegates to _cxpaypbz).
   The third constructor coefficient is ignored.
*/
template <typename Float2, typename FloatN>
struct cxpaypbz_ : public BlasFunctor<Float2,FloatN> {
  const Float2 a; //!< complex coefficient of y
  const Float2 b; //!< complex coefficient of z

  cxpaypbz_(const Float2 &alpha, const Float2 &beta, const Float2 &)
    : a(alpha), b(beta) {}

  __host__ __device__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
  {
    _cxpaypbz(x, a, y, b, z);
  }

  static int streams() { return 4; } //!< read x, y, z; write z
  static int flops() { return 8; }   //!< flops per element
};
334 
336  const Complex &b, ColorSpinorField &z) {
337  blasCuda<cxpaypbz_,0,0,1,0>(make_double2(REAL(a),IMAG(a)), make_double2(REAL(b), IMAG(b)),
338  make_double2(0.0, 0.0), x, y, z, z);
339  }
340 
/**
   Two fused real updates, evaluated in order:
     y = y + a*x
     x = b*z + c*x   (uses x before it is overwritten, after y has consumed it)
*/
template <typename Float2, typename FloatN>
struct axpyBzpcx_ : public BlasFunctor<Float2,FloatN> {
  const Float2 a; //!< coefficient of x in the y update
  const Float2 b; //!< coefficient of z in the x update
  const Float2 c; //!< coefficient of x in the x update

  axpyBzpcx_(const Float2 &alpha, const Float2 &beta, const Float2 &gamma)
    : a(alpha), b(beta), c(gamma) {}

  __host__ __device__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
  {
    y += a.x*x;
    x = b.x*z + c.x*x;
  }

  static int streams() { return 5; } //!< read x, y, z; write x, y
  static int flops() { return 5; }   //!< flops per element
};
355 
// y += a*x; x = b*z + c*x. The mixed-precision kernel is used when x and y
// differ in precision. Operands are passed in their natural order.
void axpyBzpcx(const double &a, ColorSpinorField& x, ColorSpinorField& y, const double &b,
               ColorSpinorField& z, const double &c)
{
  const double2 alpha = make_double2(a, 0.0);
  const double2 beta = make_double2(b, 0.0);
  const double2 gamma = make_double2(c, 0.0);

  if (x.Precision() == y.Precision()) {
    blasCuda<axpyBzpcx_,1,1,0,0>(alpha, beta, gamma, x, y, z, x);
  } else {
    // call hacked mixed precision kernel
    mixed::blasCuda<axpyBzpcx_,1,1,0,0>(alpha, beta, gamma, x, y, z, x);
  }
}
368 
369 
/**
   Two fused real updates, evaluated in order:
     y = y + a*x
     x = z + b*x
   The third constructor coefficient is ignored.
*/
template <typename Float2, typename FloatN>
struct axpyZpbx_ : public BlasFunctor<Float2,FloatN> {
  const Float2 a; //!< coefficient of x in the y update
  const Float2 b; //!< coefficient of x in the x update

  axpyZpbx_(const Float2 &alpha, const Float2 &beta, const Float2 &)
    : a(alpha), b(beta) {}

  __host__ __device__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
  {
    y += a.x*x;
    x = z + b.x*x;
  }

  static int streams() { return 5; } //!< read x, y, z; write x, y
  static int flops() { return 4; }   //!< flops per element
};
383 
// y += a*x; x = z + b*x. The mixed-precision kernel is used when x and y
// differ in precision. Operands are passed in their natural order.
void axpyZpbx(const double &a, ColorSpinorField& x, ColorSpinorField& y,
              ColorSpinorField& z, const double &b)
{
  const double2 alpha = make_double2(a, 0.0);
  const double2 beta = make_double2(b, 0.0);
  const double2 unused = make_double2(0.0, 0.0);

  if (x.Precision() == y.Precision()) {
    blasCuda<axpyZpbx_,1,1,0,0>(alpha, beta, unused, x, y, z, x);
  } else {
    // call hacked mixed precision kernel
    mixed::blasCuda<axpyZpbx_,1,1,0,0>(alpha, beta, unused, x, y, z, x);
  }
}
396 
/**
   Two fused complex updates, evaluated in order:
     y = y + a*x
     x = x + b*z
*/
template <typename Float2, typename FloatN>
struct caxpyBzpx_ : public BlasFunctor<Float2,FloatN> {
  const Float2 a; //!< complex coefficient of x in the y update
  const Float2 b; //!< complex coefficient of z in the x update

  caxpyBzpx_(const Float2 &alpha, const Float2 &beta, const Float2 &)
    : a(alpha), b(beta) {}

  __host__ __device__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
  {
    _caxpy(a, x, y);
    _caxpy(b, z, x);
  }

  static int streams() { return 5; } //!< read x, y, z; write x, y
  static int flops() { return 8; }   //!< flops per element
};
411 
414  if (x.Precision() != y.Precision()) {
415  mixed::blasCuda<caxpyBzpx_,1,1,0,0>(make_double2(REAL(a),IMAG(a)), make_double2(REAL(b), IMAG(b)),
416  make_double2(0.0,0.0), x, y, z, x);
417  } else {
418  blasCuda<caxpyBzpx_,1,1,0,0>(make_double2(REAL(a),IMAG(a)), make_double2(REAL(b), IMAG(b)),
419  make_double2(0.0,0.0), x, y, z, x);
420  }
421  }
422 
/**
   Two complex updates sharing the input x:
     y = y + a*x
     z = z + b*x
*/
template <typename Float2, typename FloatN>
struct caxpyBxpz_ : public BlasFunctor<Float2,FloatN> {
  const Float2 a; //!< complex coefficient of x in the y update
  const Float2 b; //!< complex coefficient of x in the z update

  caxpyBxpz_(const Float2 &alpha, const Float2 &beta, const Float2 &)
    : a(alpha), b(beta) {}

  __host__ __device__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
  {
    _caxpy(a, x, y);
    _caxpy(b, x, z);
  }

  static int streams() { return 5; } //!< read x, y, z; write y, z
  static int flops() { return 8; }   //!< flops per element
};
437 
440  if (x.Precision() != y.Precision()) {
441  mixed::blasCuda<caxpyBxpz_,0,1,1,0>(make_double2(REAL(a),IMAG(a)), make_double2(REAL(b), IMAG(b)),
442  make_double2(0.0,0.0), x, y, z, x);
443  } else {
444  blasCuda<caxpyBxpz_,0,1,1,0>(make_double2(REAL(a),IMAG(a)), make_double2(REAL(b), IMAG(b)),
445  make_double2(0.0,0.0), x, y, z, x);
446  }
447  }
448 
/**
   Fused complex updates, evaluated in order:
     z = z + a*x + b*y   (uses y before it is modified)
     y = y - b*w
*/
template <typename Float2, typename FloatN>
struct caxpbypzYmbw_ : public BlasFunctor<Float2,FloatN> {
  const Float2 a; //!< complex coefficient of x
  const Float2 b; //!< complex coefficient of y (added to z) and of w (subtracted from y)

  caxpbypzYmbw_(const Float2 &alpha, const Float2 &beta, const Float2 &)
    : a(alpha), b(beta) {}

  __host__ __device__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
  {
    _caxpy(a, x, z);
    _caxpy(b, y, z);
    _caxpy(-b, w, y);
  }

  static int streams() { return 6; } //!< read x, y, z, w; write y, z
  static int flops() { return 12; }  //!< flops per element
};
463 
466  blasCuda<caxpbypzYmbw_,0,1,1,0>(make_double2(REAL(a),IMAG(a)), make_double2(REAL(b), IMAG(b)),
467  make_double2(0.0,0.0), x, y, z, w);
468  }
469 
/**
   Fused updates, evaluated in order:
     x = a*x        (real a, only a.x used)
     y = y + b*x    (complex b, applied to the already-scaled x)
*/
template <typename Float2, typename FloatN>
struct cabxpyAx_ : public BlasFunctor<Float2,FloatN> {
  const Float2 a; //!< real scale of x
  const Float2 b; //!< complex coefficient of the scaled x

  cabxpyAx_(const Float2 &alpha, const Float2 &beta, const Float2 &)
    : a(alpha), b(beta) {}

  __host__ __device__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
  {
    x *= a.x;
    _caxpy(b, x, y);
  }

  static int streams() { return 4; } //!< read x, y; write x, y
  static int flops() { return 5; }   //!< flops per element
};
483 
484  void cabxpyAx(const double &a, const Complex &b,
486  // swap arguments around
487  blasCuda<cabxpyAx_,1,1,0,0>(make_double2(a,0.0), make_double2(REAL(b),IMAG(b)),
488  make_double2(0.0,0.0), x, y, x, x);
489  }
490 
/**
   z = z + a*x + b*y with complex a and b (two successive _caxpy calls).
*/
template <typename Float2, typename FloatN>
struct caxpbypz_ : public BlasFunctor<Float2,FloatN> {
  const Float2 a; //!< complex coefficient of x
  const Float2 b; //!< complex coefficient of y

  caxpbypz_(const Float2 &alpha, const Float2 &beta, const Float2 &)
    : a(alpha), b(beta) {}

  __host__ __device__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
  {
    _caxpy(a, x, z);
    _caxpy(b, y, z);
  }

  static int streams() { return 4; } //!< read x, y, z; write z
  static int flops() { return 8; }   //!< flops per element
};
504 
505  void caxpbypz(const Complex &a, ColorSpinorField &x, const Complex &b,
507  blasCuda<caxpbypz_,0,0,1,0>(make_double2(REAL(a),IMAG(a)), make_double2(REAL(b),IMAG(b)),
508  make_double2(0.0,0.0), x, y, z, z);
509  }
510 
/**
   w = w + a*x + b*y + c*z with complex a, b, c, applied as three successive
   _caxpy updates of w.
*/
template <typename Float2, typename FloatN>
struct caxpbypczpw_ : public BlasFunctor<Float2,FloatN> {
  const Float2 a; //!< complex coefficient of x
  const Float2 b; //!< complex coefficient of y
  const Float2 c; //!< complex coefficient of z

  caxpbypczpw_(const Float2 &a, const Float2 &b, const Float2 &c) : a(a), b(b), c(c) { ; }

  __device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
  { _caxpy(a, x, w); _caxpy(b, y, w); _caxpy(c, z, w); }

  //! total number of input and output streams: x, y, z and w are read and w
  //! is written back, i.e. 5. This was previously under-counted as 4, which
  //! disagreed with caxpbypz_ (3 reads + 1 write = 4). NOTE(review): presumably
  //! this feeds the bytes/bandwidth accounting in blas_core.h -- confirm.
  static int streams() { return 5; }
  //! flops per element: three complex axpy updates
  static int flops() { return 12; }
};
526 
527  void caxpbypczpw(const Complex &a, ColorSpinorField &x, const Complex &b,
529  ColorSpinorField &w) {
530  blasCuda<caxpbypczpw_,0,0,0,1>(make_double2(REAL(a),IMAG(a)), make_double2(REAL(b),IMAG(b)),
531  make_double2(REAL(c),IMAG(c)), x, y, z, w);
532  }
533 
/**
   Fused complex updates, evaluated in order:
     y = y + a*x
     x = x - a*z
   The coefficient is stored non-const, matching caxpyxmazMR_.
*/
template <typename Float2, typename FloatN>
struct caxpyxmaz_ : public BlasFunctor<Float2,FloatN> {
  Float2 a; //!< complex coefficient

  caxpyxmaz_(const Float2 &alpha, const Float2 &, const Float2 &) : a(alpha) {}

  __host__ __device__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
  {
    _caxpy(a, x, y);
    _caxpy(-a, z, x);
  }

  static int streams() { return 5; } //!< read x, y, z; write x, y
  static int flops() { return 8; }   //!< flops per element
};
548 
551  blasCuda<caxpyxmaz_,1,1,0,0>(make_double2(REAL(a), IMAG(a)), make_double2(0.0, 0.0),
552  make_double2(0.0, 0.0), x, y, z, x);
553  }
554 
// caxpyxmazMR_: same element-wise update as caxpyxmaz_ (y += a*x; x -= a*z),
// but the complex coefficient is rescaled on the device in init() from a
// double3 read out of the device reduce buffer.
560  template <typename Float2, typename FloatN>
561  struct caxpyxmazMR_ : public BlasFunctor<Float2,FloatN> {
// Coefficient; non-const because init() overwrites it.
562  Float2 a;
// Pointer into the device reduce buffer (from blas::getDeviceReduceBuffer()).
563  double3 *Ar3;
// Constructor arguments b and c are ignored.
564  caxpyxmazMR_(const Float2 &a, const Float2 &b, const Float2 &c)
565  : a(a), Ar3(static_cast<double3*>(blas::getDeviceReduceBuffer())) { ; }
566 
// Device-only: loads result = *Ar3 through __ldg and sets
//   a = a.x * (result.x + i*result.y) / result.z
// a.y is computed first so both lines use the original a.x; the original a.y
// is discarded. On the host (no __CUDA_ARCH__) this is a no-op.
// NOTE(review): assumes the reduce buffer already holds the needed result when
// the kernel runs -- the caller fragment checks commAsyncReduction(); confirm.
567  inline __device__ __host__ void init() {
568 #ifdef __CUDA_ARCH__
569  typedef decltype(a.x) real;
570  double3 result = __ldg(Ar3);
571  a.y = a.x * (real)(result.y) * ((real)1.0 / (real)result.z);
572  a.x = a.x * (real)(result.x) * ((real)1.0 / (real)result.z);
573 #endif
574  }
575 
// y = y + a*x, then x = x - a*z (uses the coefficient computed by init()).
576  __device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
577  { _caxpy(a, x, y); _caxpy(-a, z, x); }
578 
// total number of input and output streams: read x, y, z; write x, y
579  static int streams() { return 5; }
// flops per element
580  static int flops() { return 8; }
581  };
582 
585  if (!commAsyncReduction())
586  errorQuda("This kernel requires asynchronous reductions to be set");
587  if (x.Location() == QUDA_CPU_FIELD_LOCATION)
588  errorQuda("This kernel cannot be run on CPU fields");
589 
590  blasCuda<caxpyxmazMR_,1,1,0,0>(make_double2(REAL(a), IMAG(a)), make_double2(0.0, 0.0),
591  make_double2(0.0, 0.0), x, y, z, x);
592  }
593 
/**
   Three fused real updates, evaluated in order:
     y = y + a*w
     z = z - a*x
     w = z + b*w   (uses the freshly updated z)
*/
template <typename Float2, typename FloatN>
struct tripleCGUpdate_ : public BlasFunctor<Float2,FloatN> {
  Float2 a; //!< real coefficient (a.x) for the y and z updates
  Float2 b; //!< real coefficient (b.x) for the w update

  tripleCGUpdate_(const Float2 &alpha, const Float2 &beta, const Float2 &)
    : a(alpha), b(beta) {}

  __host__ __device__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
  {
    y += a.x*w;
    z -= a.x*x;
    w = z + b.x*w;
  }

  static int streams() { return 7; } //!< read x, y, z, w; write y, z, w
  static int flops() { return 6; }   //!< flops per element
};
609 
610  void tripleCGUpdate(const double &a, const double &b, ColorSpinorField &x,
612  if (x.Precision() != y.Precision()) {
613  // call hacked mixed precision kernel
614  mixed::blasCuda<tripleCGUpdate_,0,1,1,1>(make_double2(a,0.0), make_double2(b,0.0),
615  make_double2(0.0,0.0), x, y, z, w);
616  } else {
617  blasCuda<tripleCGUpdate_,0,1,1,1>(make_double2(a, 0.0), make_double2(b, 0.0),
618  make_double2(0.0, 0.0), x, y, z, w);
619  }
620  }
621 
622  } // namespace blas
623 
624 } // namespace quda
caxpbypzYmbw_(const Float2 &a, const Float2 &b, const Float2 &c)
Definition: blas_quda.cu:456
tripleCGUpdate_(const Float2 &a, const Float2 &b, const Float2 &c)
Definition: blas_quda.cu:603
static int flops()
number of floating-point operations per element
Definition: blas_quda.cu:125
static int streams()
Definition: blas_quda.cu:380
const char * aux_str
Definition: blas_quda.cu:57
axpyBzpcx_(const Float2 &a, const Float2 &b, const Float2 &c)
Definition: blas_quda.cu:349
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
where the reduction is usually computed and any auxiliary operations
Definition: blas_quda.cu:576
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
where the reduction is usually computed and any auxiliary operations
Definition: blas_quda.cu:100
void xpay(ColorSpinorField &x, const double &a, ColorSpinorField &y)
Definition: blas_quda.cu:173
__device__ __host__ void _caxpby(const float2 &a, const float4 &x, const float2 &b, float4 &y)
Definition: blas_quda.cu:261
void caxpyXmazMR(const Complex &a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
Definition: blas_quda.cu:583
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
where the reduction is usually computed and any auxiliary operations
Definition: blas_quda.cu:378
const Float2 a
Definition: blas_quda.cu:97
bool commAsyncReduction()
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
where the reduction is usually computed and any auxiliary operations
Definition: blas_quda.cu:350
static int flops()
number of floating-point operations per element
Definition: blas_quda.cu:524
char aux_tmp[TuneKey::aux_n]
Definition: blas_quda.cu:58
static int streams()
Definition: blas_quda.cu:501
const Float2 a
Definition: blas_quda.cu:143
void end(void)
Definition: blas_quda.cu:70
const Float2 b
Definition: blas_quda.cu:284
#define errorQuda(...)
Definition: util_quda.h:90
void init()
Definition: blas_quda.cu:64
static int streams()
Definition: blas_quda.cu:124
caxpyxmaz_(const Float2 &a, const Float2 &b, const Float2 &c)
Definition: blas_quda.cu:542
std::complex< double > Complex
Definition: eig_variables.h:13
cudaStream_t * streams
static int flops()
number of floating-point operations per element
Definition: blas_quda.cu:409
void xpayz(ColorSpinorField &x, const double &a, ColorSpinorField &y, ColorSpinorField &z)
Definition: blas_quda.cu:177
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
where the reduction is usually computed and any auxiliary operations
Definition: blas_quda.cu:405
const int Nstream
caxpyBzpx_(const Float2 &a, const Float2 &b, const Float2 &c)
Definition: blas_quda.cu:404
mxpy_(const Float2 &a, const Float2 &b, const Float2 &c)
Definition: blas_quda.cu:186
void ax(const double &a, ColorSpinorField &x)
Definition: blas_quda.cu:209
static int streams()
Definition: blas_quda.cu:102
static int flops()
number of floating-point operations per element
Definition: blas_quda.cu:289
caxpyBxpz_(const Float2 &a, const Float2 &b, const Float2 &c)
Definition: blas_quda.cu:430
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
where the reduction is usually computed and any auxiliary operations
Definition: blas_quda.cu:457
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
where the reduction is usually computed and any auxiliary operations
Definition: blas_quda.cu:240
caxpbypczpw_(const Float2 &a, const Float2 &b, const Float2 &c)
Definition: blas_quda.cu:519
virtual __device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)=0
where the reduction is usually computed and any auxiliary operations
static int flops()
number of floating-point operations per element
Definition: blas_quda.cu:170
void caxpyBzpx(const Complex &, ColorSpinorField &, ColorSpinorField &, const Complex &, ColorSpinorField &)
Definition: blas_quda.cu:412
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
where the reduction is usually computed and any auxiliary operations
Definition: blas_quda.cu:123
static int flops()
number of floating-point operations per element
Definition: blas_quda.cu:147
void caxpyBxpz(const Complex &, ColorSpinorField &, ColorSpinorField &, const Complex &, ColorSpinorField &)
Definition: blas_quda.cu:438
static int streams()
Definition: blas_quda.cu:352
xpayz_(const Float2 &a, const Float2 &b, const Float2 &c)
Definition: blas_quda.cu:167
__device__ __host__ void _caxpy(const float2 &a, const float4 &x, float4 &y)
Definition: blas_quda.cu:219
static int streams()
Definition: blas_quda.cu:288
#define b
cudaStream_t * getStream()
Definition: blas_quda.cu:75
caxpy_(const Float2 &a, const Float2 &b, const Float2 &c)
Definition: blas_quda.cu:239
void cabxpyAx(const double &a, const Complex &b, ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.cu:484
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
where the reduction is usually computed and any auxiliary operations
Definition: blas_quda.cu:604
#define IMAG(a)
Definition: blas_quda.h:14
virtual __device__ __host__ void init()
pre-computation routine before the main loop
Definition: blas_quda.cu:86
static cudaStream_t * blasStream
Definition: blas_quda.cu:53
void axpyZpbx(const double &a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, const double &b)
Definition: blas_quda.cu:384
static int streams()
Definition: blas_quda.cu:434
static int flops()
number of floating-point operations per element
Definition: blas_quda.cu:481
caxpbypz_(const Float2 &a, const Float2 &b, const Float2 &c)
Definition: blas_quda.cu:498
static struct quda::blas::@4 blasStrings
static int streams()
Definition: blas_quda.cu:408
static int streams()
Definition: blas_quda.cu:188
int int int w
void caxpbypzYmbw(const Complex &, ColorSpinorField &, const Complex &, ColorSpinorField &, ColorSpinorField &, ColorSpinorField &)
Definition: blas_quda.cu:464
static int flops()
number of floating-point operations per element
Definition: blas_quda.cu:353
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
where the reduction is usually computed and any auxiliary operations
Definition: blas_quda.cu:187
static int flops()
number of floating-point operations per element
Definition: blas_quda.cu:502
void initReduce()
Definition: reduce_quda.cu:78
static int streams()
Definition: blas_quda.cu:331
void tripleCGUpdate(const double &alpha, const double &beta, ColorSpinorField &q, ColorSpinorField &r, ColorSpinorField &x, ColorSpinorField &p)
Definition: blas_quda.cu:610
cxpaypbz_(const Float2 &a, const Float2 &b, const Float2 &c)
Definition: blas_quda.cu:328
const Float2 a
Definition: blas_quda.cu:238
static int flops()
number of floating-point operations per element
Definition: blas_quda.cu:435
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
where the reduction is usually computed and any auxiliary operations
Definition: blas_quda.cu:431
axpyZpbx_(const Float2 &a, const Float2 &b, const Float2 &c)
Definition: blas_quda.cu:377
static int streams()
Definition: blas_quda.cu:480
static int flops()
number of floating-point operations per element
Definition: blas_quda.cu:381
const Float2 a
Definition: blas_quda.cu:283
static int flops()
number of floating-point operations per element
Definition: blas_quda.cu:461
void caxpy(const Complex &a, ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.cu:246
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
where the reduction is usually computed and any auxiliary operations
Definition: blas_quda.cu:478
void zero(ColorSpinorField &a)
Definition: blas_quda.cu:45
const char * vol_str
Definition: blas_quda.cu:56
ax_(const Float2 &a, const Float2 &b, const Float2 &c)
Definition: blas_quda.cu:203
void caxpbypczpw(const Complex &, ColorSpinorField &, const Complex &, ColorSpinorField &, const Complex &, ColorSpinorField &, ColorSpinorField &)
Definition: blas_quda.cu:527
void axpy(const double &a, ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.cu:150
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
where the reduction is usually computed and any auxiliary operations
Definition: blas_quda.cu:168
static int flops()
number of floating-point operations per element
Definition: blas_quda.cu:243
#define REAL(a)
Definition: blas_quda.h:13
cabxpyAx_(const Float2 &a, const Float2 &b, const Float2 &c)
Definition: blas_quda.cu:477
static int flops()
number of floating-point operations per element
Definition: blas_quda.cu:103
__device__ __host__ void init()
pre-computation routine before the main loop
Definition: blas_quda.cu:567
void axpby(const double &a, ColorSpinorField &x, const double &b, ColorSpinorField &y)
Definition: blas_quda.cu:106
axpby_(const Float2 &a, const Float2 &b, const Float2 &c)
Definition: blas_quda.cu:99
void caxpbypz(const Complex &, ColorSpinorField &, const Complex &, ColorSpinorField &, ColorSpinorField &)
Definition: blas_quda.cu:505
static int flops()
number of floating-point operations per element
Definition: blas_quda.cu:580
xpy_(const Float2 &a, const Float2 &b, const Float2 &c)
Definition: blas_quda.cu:122
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
where the reduction is usually computed and any auxiliary operations
Definition: blas_quda.cu:286
const Float2 a
Definition: blas_quda.cu:166
static int streams()
Definition: blas_quda.cu:545
static int streams()
Definition: blas_quda.cu:242
void axpyBzpcx(const double &a, ColorSpinorField &x, ColorSpinorField &y, const double &b, ColorSpinorField &z, const double &c)
Definition: blas_quda.cu:356
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
where the reduction is usually computed and any auxiliary operations
Definition: blas_quda.cu:499
void caxpyXmaz(const Complex &a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
Definition: blas_quda.cu:549
static const int aux_n
Definition: tune_key.h:12
void * getDeviceReduceBuffer()
Definition: reduce_quda.cu:73
__device__ __host__ void _cxpaypbz(const float4 &x, const float2 &a, const float4 &y, const float2 &b, float4 &z)
Definition: blas_quda.cu:301
static int flops()
number of floating-point operations per element
Definition: blas_quda.cu:189
axpy_(const Float2 &a, const Float2 &b, const Float2 &c)
Definition: blas_quda.cu:144
unsigned long long flops
Definition: blas_quda.cu:42
void xpy(ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.cu:128
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
where the reduction is usually computed and any auxiliary operations
Definition: blas_quda.cu:204
static int flops()
number of floating-point operations per element
Definition: blas_quda.cu:546
void caxpby(const Complex &a, ColorSpinorField &x, const Complex &b, ColorSpinorField &y)
Definition: blas_quda.cu:292
caxpyxmazMR_(const Float2 &a, const Float2 &b, const Float2 &c)
Definition: blas_quda.cu:564
const void * c
void mxpy(ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.cu:192
void cxpaypbz(ColorSpinorField &, const Complex &b, ColorSpinorField &y, const Complex &c, ColorSpinorField &z)
Definition: blas_quda.cu:335
static int streams()
Definition: blas_quda.cu:205
static int streams()
Definition: blas_quda.cu:146
static int flops()
number of floating-point operations per element
Definition: blas_quda.cu:332
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
where the reduction is usually computed and any auxiliary operations
Definition: blas_quda.cu:145
const Float2 b
Definition: blas_quda.cu:98
static int streams()
Definition: blas_quda.cu:169
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
where the reduction is usually computed and any auxiliary operations
Definition: blas_quda.cu:520
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
where the reduction is usually computed and any auxiliary operations
Definition: blas_quda.cu:329
const Float2 a
Definition: blas_quda.cu:202
#define a
static int flops()
number of floating-point operations per element
Definition: blas_quda.cu:206
void endReduce()
Definition: reduce_quda.cu:134
unsigned long long bytes
Definition: blas_quda.cu:43
caxpby_(const Float2 &a, const Float2 &b, const Float2 &c)
Definition: blas_quda.cu:285
static int flops()
number of floating-point operations per element
Definition: blas_quda.cu:607
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w)
where the reduction is usually computed and any auxiliary operations
Definition: blas_quda.cu:543