QUDA  1.0.0
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
blas_quda.cu
Go to the documentation of this file.
1 #include <stdlib.h>
2 #include <stdio.h>
3 #include <cstring> // needed for memset
4 #include <typeinfo>
5 
6 #include <tune_quda.h>
7 
8 #include <quda_internal.h>
9 #include <float_vector.h>
10 #include <blas_quda.h>
11 #include <color_spinor_field.h>
12 
13 #include <jitify_helper.cuh>
14 #include <kernels/blas_core.cuh>
15 
16 namespace quda {
17 
18  namespace blas {
19 
20 #include <generic_blas.cuh>
21 
// Running totals summed over every device blas launch in this file
// (nativeBlas adds each kernel's flop and byte estimates to them).
unsigned long long flops, bytes;

// Stream on which all blas kernels are enqueued; assigned by init().
static cudaStream_t *blasStream;
26 
27  template <typename FloatN, int M, typename SpinorX, typename SpinorY, typename SpinorZ, typename SpinorW,
28  typename SpinorV, typename Functor>
29  class BlasCuda : public Tunable
30  {
31 
32  private:
// Launch extent along grid.y: x.SiteSubset() times the composite dimension
// (computed first in the constructor, copied into param.grid.y when tuning).
33  const int nParity; // for composite fields this includes the number of composites
35 
// The participating fields. x provides the tune-key volume string; x and y feed
// the byte estimate; all five size the pre/post-tune backups.
36  const ColorSpinorField &x, &y, &z, &w, &v;
37 
38  // host pointers used for backing up fields when tuning
39  // don't carry these in the Arg struct, to minimize its size
40  char *X_h, *Y_h, *Z_h, *W_h, *V_h;
42 
// These kernels request no shared memory, either per thread or per block.
unsigned int sharedBytesPerThread() const { return 0; }
unsigned int sharedBytesPerBlock(const TuneParam &param) const { return 0; }

// Shared memory is not a tuning dimension here. Size param.shared_bytes for the
// block shape advanceBlockDim would move to next, then report (by returning false)
// that there is no further shared-memory configuration to try.
virtual bool advanceSharedBytes(TuneParam &param) const
{
  TuneParam probe(param);
  advanceBlockDim(probe); // only used to learn the next block shape
  const int threadCount = probe.block.x * probe.block.y * probe.block.z;
  const unsigned int perThreadTotal = sharedBytesPerThread() * threadCount;
  const unsigned int perBlockTotal = sharedBytesPerBlock(param);
  if (perThreadTotal > perBlockTotal) {
    param.shared_bytes = perThreadTotal;
  } else {
    param.shared_bytes = perBlockTotal;
  }
  return false;
}
56 
57  public:
// Build the tunable launcher.
// - arg receives the five accessors and the functor, with the element count
//   divided by nParity.
// - The field references are kept for the tune key, the byte estimate and the
//   tuning backups.
// - The aux tune-key string is composed from x's aux string, plus y's aux string
//   when the two precisions differ.
BlasCuda(SpinorX &X, SpinorY &Y, SpinorZ &Z, SpinorW &W, SpinorV &V, Functor &f, ColorSpinorField &x,
         ColorSpinorField &y, ColorSpinorField &z, ColorSpinorField &w, ColorSpinorField &v, int length) :
  nParity((x.IsComposite() ? x.CompositeDim() : 1) * x.SiteSubset()), // must be first: arg divides by it
  arg(X, Y, Z, W, V, f, length / nParity),
  x(x),
  y(y),
  z(z),
  w(w),
  v(v),
  X_h(0),
  Y_h(0),
  Z_h(0),
  W_h(0),
  V_h(0),
  Xnorm_h(0),
  Ynorm_h(0),
  Znorm_h(0),
  Wnorm_h(0),
  Vnorm_h(0)
{
  // aux is a fixed-size buffer. Write it with bounded calls so that unusually
  // long field aux strings are truncated instead of overflowing it. For normal
  // lengths the result is identical to the former strcpy/strcat sequence.
  if (x.Precision() != y.Precision()) {
    snprintf(aux, sizeof(aux), "%s,%s", x.AuxString(), y.AuxString());
  } else {
    snprintf(aux, sizeof(aux), "%s", x.AuxString());
  }

#ifdef JITIFY
  ::quda::create_jitify_program("kernels/blas_core.cuh");
#endif
}
88 
virtual ~BlasCuda() = default;

// Tune key: x's volume string, the functor's type name, and the aux string.
inline TuneKey tuneKey() const
{
  const char *functorName = typeid(arg.f).name();
  return TuneKey(x.VolString(), functorName, aux);
}

// Fetch (or search for) the launch configuration, then launch blasKernel on
// `stream`. With JITIFY the kernel is compiled and launched through jitify.
inline void apply(const cudaStream_t &stream)
{
  TuneParam launch = tuneLaunch(*this, getTuning(), getVerbosity());
#ifdef JITIFY
  using namespace jitify::reflection;
  // kept as one expression so no intermediate jitify object outlives its source
  jitify_error = program->kernel("quda::blas::blasKernel")
                   .instantiate(Type<FloatN>(), M, Type<decltype(arg)>())
                   .configure(launch.grid, launch.block, launch.shared_bytes, stream)
                   .launch(arg);
#else
  blasKernel<FloatN, M><<<launch.grid, launch.block, launch.shared_bytes, stream>>>(arg);
#endif
}
106 
// Snapshot all five fields (data plus norm, sized by Bytes()/NormBytes())
// through their accessors into the host-side X_h..V_h / Xnorm_h..Vnorm_h
// buffers. Per the member comment these are the tuning backups: presumably the
// tuner's trial launches may overwrite the fields.
107  void preTune()
108  {
109  arg.X.backup(&X_h, &Xnorm_h, x.Bytes(), x.NormBytes());
110  arg.Y.backup(&Y_h, &Ynorm_h, y.Bytes(), y.NormBytes());
111  arg.Z.backup(&Z_h, &Znorm_h, z.Bytes(), z.NormBytes());
112  arg.W.backup(&W_h, &Wnorm_h, w.Bytes(), w.NormBytes());
113  arg.V.backup(&V_h, &Vnorm_h, v.Bytes(), v.NormBytes());
114  }
115 
// Restore each field from the snapshot taken in preTune, with the same sizes.
116  void postTune()
117  {
118  arg.X.restore(&X_h, &Xnorm_h, x.Bytes(), x.NormBytes());
119  arg.Y.restore(&Y_h, &Ynorm_h, y.Bytes(), y.NormBytes());
120  arg.Z.restore(&Z_h, &Znorm_h, z.Bytes(), z.NormBytes());
121  arg.W.restore(&W_h, &Wnorm_h, w.Bytes(), w.NormBytes());
122  arg.V.restore(&V_h, &Vnorm_h, v.Bytes(), v.NormBytes());
123  }
124 
126  {
// Start from the generic Tunable settings, then make the grid nParity rows
// tall along y (presumably one row per parity/composite slice — confirm
// against blasKernel).
127  Tunable::initTuneParam(param);
128  param.grid.y = nParity;
129  }
130 
132  {
// NOTE(review): the default-parameter path also calls Tunable::initTuneParam
// rather than Tunable::defaultTuneParam; presumably intentional, confirm.
133  Tunable::initTuneParam(param);
134  param.grid.y = nParity;
135  }
136 
// Flop estimate for one launch: the functor's flop count, times the FloatN
// vector length, times arg.length, nParity and M.
// The first factor is promoted to long long so that the whole product is
// evaluated in 64 bits. nParity and M are ints (and arg.f.flops()/arg.length
// presumably are too), so the former all-int product could overflow 32 bits on
// large lattices before being widened to the return type.
long long flops() const
{
  return static_cast<long long>(arg.f.flops()) * vec_length<FloatN>::value * arg.length * nParity * M;
}
// Byte estimate. All streams but two are charged at x's size; the remaining two
// are charged at y's size, i.e. y is assumed to be both read and written. When
// x and y have the same precision the sizes are equal and this reduces to
// streams * x.Bytes().
long long bytes() const
{
  const auto otherStreams = arg.f.streams() - 2;
  return otherStreams * x.Bytes() + 2 * y.Bytes();
}

// Iteration count reported to the autotuner.
int tuningIter() const { return 3; }
145  };
146 
// Device execution path shared by uni_blas and mixed_blas. It:
// - checks that y, z, w and v all have x's length;
// - builds Functor<Float2, RegType> from a, b, c converted to the scalar type
//   of RegType;
// - launches it through BlasCuda on *blasStream;
// - adds the kernel's byte and flop estimates to blas::bytes / blas::flops.
// Template parameters, as the callers in this file use them:
// - RegType is the register (compute) vector type.
// - StoreType and yType are storage vector types. mixed_blas passes y's storage
//   type as yType, so presumably it types the y/v accessors — TODO confirm.
// - M is forwarded to BlasCuda.
// - writeX..writeV are the wrappers' per-field flags, presumably "write this
//   field back".
// length: element count handed to BlasCuda. Callers pass x.Length()/(2*M),
// x.Length()/(4*M) or x.Volume().
147  template <typename RegType, typename StoreType, typename yType, int M, template <typename, typename> class Functor,
148  int writeX, int writeY, int writeZ, int writeW, int writeV>
149  void nativeBlas(const double2 &a, const double2 &b, const double2 &c, ColorSpinorField &x, ColorSpinorField &y,
151  {
152 
153  checkLength(x, y);
154  checkLength(x, z);
155  checkLength(x, w);
156  checkLength(x, v);
157 
163 
// Convert the double-precision coefficients to the compute precision.
164  typedef typename scalar<RegType>::type Float;
165  typedef typename vector<Float, 2>::type Float2;
166  typedef vector<Float, 2> vec2;
167  Functor<Float2, RegType> f((Float2)vec2(a), (Float2)vec2(b), (Float2)vec2(c));
168 
170  X, Y, Z, W, V, f, x, y, z, w, v, length);
171  blas.apply(*blasStream);
172 
// Accumulate this launch's estimates into the global counters.
173  blas::bytes += blas.bytes();
174  blas::flops += blas.flops();
175 
176  checkCudaError();
177  }
178 
// Blas driver for fields that all share one precision (enforced by
// checkPrecision).
// Device fields:
// - Non-native field orders are rejected with a warning, subject to the
//   exemptions in the condition below.
// - Otherwise the call is dispatched on precision, nSpin and field order to
//   nativeBlas with matching register/storage types and M.
// - Each branch is compiled only when QUDA_PRECISION and the Dirac build flags
//   enable it; otherwise it reports an error.
// Host fields: handled by genericBlas, for double and single precision only.
// a, b, c: functor coefficients as (real, imag) pairs.
// writeX..writeV: per-field flags forwarded unchanged to nativeBlas/genericBlas.
183  template <template <typename Float, typename FloatN> class Functor, int writeX = 0, int writeY = 0, int writeZ = 0,
184  int writeW = 0, int writeV = 0>
185  void uni_blas(const double2 &a, const double2 &b, const double2 &c, ColorSpinorField &x, ColorSpinorField &y,
187  {
188 
189  checkPrecision(x, y, z, w, v);
190 
191  if (checkLocation(x, y, z, w, v) == QUDA_CUDA_FIELD_LOCATION) {
192 
// Unsupported device field order: warn and return without touching the fields.
193  if (!x.isNative()
195  || x.Nspin() == 4 && x.FieldOrder() == QUDA_FLOAT2_FIELD_ORDER && x.Precision() == QUDA_HALF_PRECISION)) {
196  warningQuda("Device blas on non-native fields is not supported\n");
197  return;
198  }
199 
200  if (x.Precision() == QUDA_DOUBLE_PRECISION) {
201 
202 #if QUDA_PRECISION & 8
203 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC) || defined(GPU_STAGGERED_DIRAC)
204  const int M = 1;
205  nativeBlas<double2, double2, double2, M, Functor, writeX, writeY, writeZ, writeW, writeV>(
206  a, b, c, x, y, z, w, v, x.Length() / (2 * M));
207 #else
208  errorQuda("blas has not been built for Nspin=%d fields", x.Nspin());
209 #endif
210 #else
211  errorQuda("QUDA_PRECISION=%d does not enable precision %d", QUDA_PRECISION, x.Precision());
212 #endif
213 
214  } else if (x.Precision() == QUDA_SINGLE_PRECISION) {
215 
216 #if QUDA_PRECISION & 4
217  if (x.Nspin() == 4 && x.FieldOrder() == QUDA_FLOAT4_FIELD_ORDER) {
218 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC)
219  const int M = 1;
220  nativeBlas<float4, float4, float4, M, Functor, writeX, writeY, writeZ, writeW, writeV>(
221  a, b, c, x, y, z, w, v, x.Length() / (4 * M));
222 #else
223  errorQuda("blas has not been built for Nspin=%d fields", x.Nspin());
224 #endif
225  } else if (x.Nspin() == 2 || x.Nspin() == 1 || (x.Nspin() == 4 && x.FieldOrder() == QUDA_FLOAT2_FIELD_ORDER)) {
226 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC) || defined(GPU_STAGGERED_DIRAC)
227  const int M = 1;
228  nativeBlas<float2, float2, float2, M, Functor, writeX, writeY, writeZ, writeW, writeV>(
229  a, b, c, x, y, z, w, v, x.Length() / (2 * M));
230 #else
231  errorQuda("blas has not been built for Nspin=%d fields", x.Nspin());
232 #endif
233  } else {
234  errorQuda("nSpin=%d is not supported\n", x.Nspin());
235  }
236 #else
237  errorQuda("QUDA_PRECISION=%d does not enable precision %d", QUDA_PRECISION, x.Precision());
238 #endif
239 
240  } else if (x.Precision() == QUDA_HALF_PRECISION) {
241 
// Half precision: only nColor == 3 is supported; the element count is x.Volume().
242 #if QUDA_PRECISION & 2
243  if (x.Ncolor() != 3) { errorQuda("nColor = %d is not supported", x.Ncolor()); }
244  if (x.Nspin() == 4 && x.FieldOrder() == QUDA_FLOAT4_FIELD_ORDER) { // wilson
245 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC)
246  const int M = 6;
247  nativeBlas<float4, short4, short4, M, Functor, writeX, writeY, writeZ, writeW, writeV>(
248  a, b, c, x, y, z, w, v, x.Volume());
249 #else
250  errorQuda("blas has not been built for Nspin=%d fields", x.Nspin());
251 #endif
252  } else if (x.Nspin() == 4 && x.FieldOrder() == QUDA_FLOAT2_FIELD_ORDER) { // wilson
253 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC)
254  const int M = 12;
255  nativeBlas<float2, short2, short2, M, Functor, writeX, writeY, writeZ, writeW, writeV>(
256  a, b, c, x, y, z, w, v, x.Volume());
257 #else
258  errorQuda("blas has not been built for Nspin=%d fields", x.Nspin());
259 #endif
260  } else if (x.Nspin() == 1) { // staggered
261 #ifdef GPU_STAGGERED_DIRAC
262  const int M = 3;
263  nativeBlas<float2, short2, short2, M, Functor, writeX, writeY, writeZ, writeW, writeV>(
264  a, b, c, x, y, z, w, v, x.Volume());
265 #else
266  errorQuda("blas has not been built for Nspin=%d fields", x.Nspin());
267 #endif
268  } else {
269  errorQuda("nSpin=%d is not supported\n", x.Nspin());
270  }
271 #else
272  errorQuda("QUDA_PRECISION=%d does not enable precision %d", QUDA_PRECISION, x.Precision());
273 #endif
274 
275  } else if (x.Precision() == QUDA_QUARTER_PRECISION) {
276 
// Quarter precision: same shape as the half-precision branch, with char storage.
277 #if QUDA_PRECISION & 1
278  if (x.Ncolor() != 3) { errorQuda("nColor = %d is not supported", x.Ncolor()); }
279  if (x.Nspin() == 4) { // wilson
280 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC)
281  const int M = 6;
282  nativeBlas<float4, char4, char4, M, Functor, writeX, writeY, writeZ, writeW, writeV>(
283  a, b, c, x, y, z, w, v, x.Volume());
284 #else
285  errorQuda("blas has not been built for Nspin=%d fields", x.Nspin());
286 #endif
287  } else if (x.Nspin() == 1) { // staggered
288 #ifdef GPU_STAGGERED_DIRAC
289  const int M = 3;
290  nativeBlas<float2, char2, char2, M, Functor, writeX, writeY, writeZ, writeW, writeV>(
291  a, b, c, x, y, z, w, v, x.Volume());
292 #else
293  errorQuda("blas has not been built for Nspin=%d fields", x.Nspin());
294 #endif
295  } else {
296  errorQuda("nSpin=%d is not supported\n", x.Nspin());
297  }
298 #else
299  errorQuda("QUDA_PRECISION=%d does not enable precision %d", QUDA_PRECISION, x.Precision());
300 #endif
301 
302  } else {
303  errorQuda("precision=%d is not supported\n", x.Precision());
304  }
// Host fields: generic (non-kernel) implementation; only double and single
// precision are handled.
305  } else { // fields on the cpu
306  if (x.Precision() == QUDA_DOUBLE_PRECISION) {
307  Functor<double2, double2> f(a, b, c);
308  genericBlas<double, double, writeX, writeY, writeZ, writeW, writeV>(x, y, z, w, v, f);
309  } else if (x.Precision() == QUDA_SINGLE_PRECISION) {
310  Functor<float2, float2> f(make_float2(a.x, a.y), make_float2(b.x, b.y), make_float2(c.x, c.y));
311  genericBlas<float, float, writeX, writeY, writeZ, writeW, writeV>(x, y, z, w, v, f);
312  } else {
313  errorQuda("Not implemented");
314  }
315  }
316  }
317 
// Mixed-precision blas driver. x, z and w must share one precision, and y and v
// another (the two checkPrecision calls below).
// Device fields:
// - Non-native orders are rejected with a warning.
// - Otherwise the call is dispatched on (x precision, y precision, nSpin) to
//   nativeBlas, with y's storage vector type passed as yType.
// - The first branch below (double2 registers, float4/float2 storage) appears to
//   handle single-precision x with double-precision y. NOTE(review): it is only
//   guarded by QUDA_PRECISION & 4 — confirm whether & 8 is also required.
// Host fields: genericBlas<float, double>, presumably for single x / double y —
// confirm the guarding condition.
// NOTE(review): each precision branch handles only nSpin 4 and 1 and has no else,
// so any other nSpin silently launches nothing.
324  template <template <typename Float, typename FloatN> class Functor, int writeX = 0, int writeY = 0, int writeZ = 0,
325  int writeW = 0, int writeV = 0>
326  void mixed_blas(const double2 &a, const double2 &b, const double2 &c, ColorSpinorField &x, ColorSpinorField &y,
328  {
329 
330  checkPrecision(x, z, w);
331  checkPrecision(y, v);
332 
333  if (checkLocation(x, y, z, w, v) == QUDA_CUDA_FIELD_LOCATION) {
334 
335  if (!x.isNative()) {
336  warningQuda("Device blas on non-native fields is not supported\n");
337  return;
338  }
339 
341 
342 #if QUDA_PRECISION & 4
343  if (x.Nspin() == 4) {
344  const int M = 12;
345  nativeBlas<double2, float4, double2, M, Functor, writeX, writeY, writeZ, writeW, writeV>(
346  a, b, c, x, y, z, w, v, x.Volume());
347  } else if (x.Nspin() == 1) {
348  const int M = 3;
349  nativeBlas<double2, float2, double2, M, Functor, writeX, writeY, writeZ, writeW, writeV>(
350  a, b, c, x, y, z, w, v, x.Volume());
351  }
352 #else
353  errorQuda("QUDA_PRECISION=%d does not enable precision %d", QUDA_PRECISION, x.Precision());
354 #endif
355 
356  } else if (x.Precision() == QUDA_HALF_PRECISION) {
357 
// Half-precision x paired with double or single y.
358 #if QUDA_PRECISION & 2
359  if (y.Precision() == QUDA_DOUBLE_PRECISION) {
360 
361 #if QUDA_PRECISION & 8
362  if (x.Nspin() == 4) {
363  const int M = 12;
364  nativeBlas<double2, short4, double2, M, Functor, writeX, writeY, writeZ, writeW, writeV>(
365  a, b, c, x, y, z, w, v, x.Volume());
366  } else if (x.Nspin() == 1) {
367  const int M = 3;
368  nativeBlas<double2, short2, double2, M, Functor, writeX, writeY, writeZ, writeW, writeV>(
369  a, b, c, x, y, z, w, v, x.Volume());
370  }
371 #else
372  errorQuda("QUDA_PRECISION=%d does not enable precision %d", QUDA_PRECISION, y.Precision());
373 #endif
374 
375  } else if (y.Precision() == QUDA_SINGLE_PRECISION) {
376 
377 #if QUDA_PRECISION & 4
378  if (x.Nspin() == 4) {
379  const int M = 6;
380  nativeBlas<float4, short4, float4, M, Functor, writeX, writeY, writeZ, writeW, writeV>(
381  a, b, c, x, y, z, w, v, x.Volume());
382  } else if (x.Nspin() == 1) {
383  const int M = 3;
384  nativeBlas<float2, short2, float2, M, Functor, writeX, writeY, writeZ, writeW, writeV>(
385  a, b, c, x, y, z, w, v, x.Volume());
386  }
387 #else
388  errorQuda("QUDA_PRECISION=%d does not enable precision %d", QUDA_PRECISION, y.Precision());
389 #endif
390 
391  } else {
392  errorQuda("Not implemented for this precision combination %d %d", x.Precision(), y.Precision());
393  }
394 #else
395  errorQuda("QUDA_PRECISION=%d does not enable precision %d", QUDA_PRECISION, x.Precision());
396 #endif
397 
398  } else if (x.Precision() == QUDA_QUARTER_PRECISION) {
399 
// Quarter-precision x paired with double, single or half y.
400 #if QUDA_PRECISION & 1
401 
402  if (y.Precision() == QUDA_DOUBLE_PRECISION) {
403 
404 #if QUDA_PRECISION & 8
405  if (x.Nspin() == 4) {
406  const int M = 12;
407  nativeBlas<double2, char4, double2, M, Functor, writeX, writeY, writeZ, writeW, writeV>(
408  a, b, c, x, y, z, w, v, x.Volume());
409  } else if (x.Nspin() == 1) {
410  const int M = 3;
411  nativeBlas<double2, char2, double2, M, Functor, writeX, writeY, writeZ, writeW, writeV>(
412  a, b, c, x, y, z, w, v, x.Volume());
413  }
414 #else
415  errorQuda("QUDA_PRECISION=%d does not enable precision %d", QUDA_PRECISION, y.Precision());
416 #endif
417 
418  } else if (y.Precision() == QUDA_SINGLE_PRECISION) {
419 
420 #if QUDA_PRECISION & 4
421  if (x.Nspin() == 4) {
422  const int M = 6;
423  nativeBlas<float4, char4, float4, M, Functor, writeX, writeY, writeZ, writeW, writeV>(
424  a, b, c, x, y, z, w, v, x.Volume());
425  } else if (x.Nspin() == 1) {
426  const int M = 3;
427  nativeBlas<float2, char2, float2, M, Functor, writeX, writeY, writeZ, writeW, writeV>(
428  a, b, c, x, y, z, w, v, x.Volume());
429  }
430 #else
431  errorQuda("QUDA_PRECISION=%d does not enable precision %d", QUDA_PRECISION, y.Precision());
432 #endif
433 
434  } else if (y.Precision() == QUDA_HALF_PRECISION) {
435 
436 #if QUDA_PRECISION & 2
437  if (x.Nspin() == 4) {
438  const int M = 6;
439  nativeBlas<float4, char4, short4, M, Functor, writeX, writeY, writeZ, writeW, writeV>(
440  a, b, c, x, y, z, w, v, x.Volume());
441  } else if (x.Nspin() == 1) {
442  const int M = 3;
443  nativeBlas<float2, char2, short2, M, Functor, writeX, writeY, writeZ, writeW, writeV>(
444  a, b, c, x, y, z, w, v, x.Volume());
445  }
446 #else
447  errorQuda("QUDA_PRECISION=%d does not enable precision %d", QUDA_PRECISION, y.Precision());
448 #endif
449 
450  } else {
451  errorQuda("Not implemented for this precision combination %d %d", x.Precision(), y.Precision());
452  }
453 #else
454  errorQuda("QUDA_PRECISION=%d does not enable precision %d", QUDA_PRECISION, x.Precision());
455 #endif
456 
457  } else {
458  errorQuda("Not implemented for this precision combination %d %d", x.Precision(), y.Precision());
459  }
460 
// Host fields: a single supported combination through genericBlas<float, double>.
461  } else { // fields on the cpu
462  using namespace quda::colorspinor;
464  Functor<double2, double2> f(a, b, c);
465  genericBlas<float, double, writeX, writeY, writeZ, writeW, writeV>(x, y, z, w, v, f);
466  } else {
467  errorQuda("Not implemented");
468  }
469  }
470  }
471 
// Zero the field by forwarding to the concrete class's zero(). `a` is treated
// as a cudaColorSpinorField only when that is its exact dynamic type; in every
// other case it is treated as a cpuColorSpinorField.
void zero(ColorSpinorField &a)
{
  if (typeid(a) != typeid(cudaColorSpinorField)) {
    static_cast<cpuColorSpinorField &>(a).zero();
    return;
  }
  static_cast<cudaColorSpinorField &>(a).zero();
}
479 
// Reduction-module setup/teardown, defined in reduce_quda.cu.
void initReduce();
void endReduce();

// Route blas launches to the last entry of the global streams array, then set
// up the reduction side.
void init()
{
  const int lastStream = Nstream - 1;
  blasStream = &streams[lastStream];
  initReduce();
}

// Tear down the reduction side.
void end(void) { endReduce(); }

// Stream used for blas kernel launches (assigned by init()).
cudaStream_t *getStream() { return blasStream; }
495 
// axpbyz_ functor with real coefficients a and b (presumably z = a*x + b*y).
// z occupies the v slot, which is the only slot written (writeV = 1); x fills
// the unused z/w slots. Differing x/y precisions go through mixed_blas.
void axpbyz(double a, ColorSpinorField &x, double b, ColorSpinorField &y, ColorSpinorField &z)
{
  const double2 coeffA = make_double2(a, 0.0);
  const double2 coeffB = make_double2(b, 0.0);
  const double2 unused = make_double2(0.0, 0.0);
  if (x.Precision() == y.Precision()) {
    uni_blas<axpbyz_, 0, 0, 0, 0, 1>(coeffA, coeffB, unused, x, y, x, x, z);
    return;
  }
  mixed_blas<axpbyz_, 0, 0, 0, 0, 1>(coeffA, coeffB, unused, x, y, x, x, z);
}
507 
// ax_ functor with real factor a (presumably x = a*x). x fills every slot and
// only the x slot is written (writeX = 1).
void ax(double a, ColorSpinorField &x)
{
  const double2 nil = make_double2(0.0, 0.0);
  uni_blas<ax_, 1>(make_double2(a, 0.0), nil, nil, x, x, x, x, x);
}
511 
// caxpy_ functor with complex coefficient a (presumably y += a*x); writes only
// the y slot (writeY = 1). Differing x/y precisions go through mixed_blas.
void caxpy(const Complex &a, ColorSpinorField &x, ColorSpinorField &y)
{
  const double2 alpha = make_double2(real(a), imag(a));
  const double2 nil = make_double2(0.0, 0.0);
  if (x.Precision() == y.Precision())
    uni_blas<caxpy_, 0, 1>(alpha, nil, nil, x, y, x, x, y);
  else
    mixed_blas<caxpy_, 0, 1>(alpha, nil, nil, x, y, x, x, y);
}
521 
522 
// caxpby_ functor with complex coefficients a and b over (x, y); writes only
// the y slot (writeY = 1).
void caxpby(const Complex &a, ColorSpinorField &x, const Complex &b, ColorSpinorField &y)
{
  const double2 alpha = make_double2(REAL(a), IMAG(a));
  const double2 beta = make_double2(REAL(b), IMAG(b));
  uni_blas<caxpby_, 0, 1>(alpha, beta, make_double2(0.0, 0.0), x, y, x, x, y);
}
527 
// caxpbypczw_ functor with complex coefficients a, b, c over fields (x, y, z, w);
// only the w slot is written (writeW = 1), and y is passed in the v slot.
void caxpbypczw(const Complex &a, ColorSpinorField &x, const Complex &b, ColorSpinorField &y, const Complex &c,
                ColorSpinorField &z, ColorSpinorField &w)
{
  const double2 alpha = make_double2(REAL(a), IMAG(a));
  const double2 beta = make_double2(REAL(b), IMAG(b));
  const double2 gamma = make_double2(REAL(c), IMAG(c));
  uni_blas<caxpbypczw_, 0, 0, 0, 1>(alpha, beta, gamma, x, y, z, w, y);
}
534 
536  const Complex &b, ColorSpinorField &z) {
// cxpaypbz reuses the caxpbypczw_ functor with coefficients (1, a, b) over
// (x, y, z), passing z again in the w slot. Only the w slot is written
// (writeW = 1), so z is the output. y is passed in the v slot.
537  uni_blas<caxpbypczw_, 0, 0, 0, 1>(make_double2(1.0, 0.0), make_double2(REAL(a), IMAG(a)),
538  make_double2(REAL(b), IMAG(b)), x, y, z, z, y);
539  }
540 
// axpyBzpcx_ functor with real coefficients a, b, c over (x, y, z); writes the
// x and y slots (writeX = writeY = 1). Differing x/y precisions go through the
// mixed-precision driver; the argument order is the same on both paths.
void axpyBzpcx(double a, ColorSpinorField &x, ColorSpinorField &y, double b, ColorSpinorField &z, double c)
{
  const double2 ca = make_double2(a, 0.0);
  const double2 cb = make_double2(b, 0.0);
  const double2 cc = make_double2(c, 0.0);
  if (x.Precision() == y.Precision()) {
    uni_blas<axpyBzpcx_, 1, 1>(ca, cb, cc, x, y, z, x, y);
  } else {
    mixed_blas<axpyBzpcx_, 1, 1>(ca, cb, cc, x, y, z, x, y);
  }
}
551 
// axpyZpbx_ functor with real coefficients a and b over (x, y, z); writes the
// x and y slots (writeX = writeY = 1). Differing x/y precisions go through the
// mixed-precision driver; the argument order is the same on both paths.
void axpyZpbx(double a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, double b)
{
  const double2 ca = make_double2(a, 0.0);
  const double2 cb = make_double2(b, 0.0);
  const double2 nil = make_double2(0.0, 0.0);
  if (x.Precision() == y.Precision()) {
    uni_blas<axpyZpbx_, 1, 1>(ca, cb, nil, x, y, z, x, y);
  } else {
    mixed_blas<axpyZpbx_, 1, 1>(ca, cb, nil, x, y, z, x, y);
  }
}
562 
// caxpyBzpx body: caxpyBzpx_ functor with complex coefficients a and b over
// (x, y, z), writing the x and y slots (writeX = writeY = 1). Differing x/y
// precisions route through mixed_blas.
565  if (x.Precision() != y.Precision()) {
566  mixed_blas<caxpyBzpx_, 1, 1>(
567  make_double2(REAL(a), IMAG(a)), make_double2(REAL(b), IMAG(b)), make_double2(0.0, 0.0), x, y, z, x, y);
568  } else {
569  uni_blas<caxpyBzpx_, 1, 1>(
570  make_double2(REAL(a), IMAG(a)), make_double2(REAL(b), IMAG(b)), make_double2(0.0, 0.0), x, y, z, x, y);
571  }
572  }
573 
// caxpyBxpz body: caxpyBxpz_ functor with complex coefficients a and b over
// (x, y, z), writing the y and z slots (writeY = writeZ = 1; x is read only).
// Differing x/y precisions route through mixed_blas.
576  if (x.Precision() != y.Precision()) {
577  mixed_blas<caxpyBxpz_, 0, 1, 1>(
578  make_double2(REAL(a), IMAG(a)), make_double2(REAL(b), IMAG(b)), make_double2(0.0, 0.0), x, y, z, x, y);
579  } else {
580  uni_blas<caxpyBxpz_, 0, 1, 1>(
581  make_double2(REAL(a), IMAG(a)), make_double2(REAL(b), IMAG(b)), make_double2(0.0, 0.0), x, y, z, x, y);
582  }
583  }
584 
// caxpbypzYmbw_ functor with complex coefficients a and b over (x, y, z, w),
// writing the y and z slots (writeY = writeZ = 1). Uniform precision only.
585  void caxpbypzYmbw(const Complex &a, ColorSpinorField &x, const Complex &b,
587  uni_blas<caxpbypzYmbw_, 0, 1, 1>(
588  make_double2(REAL(a), IMAG(a)), make_double2(REAL(b), IMAG(b)), make_double2(0.0, 0.0), x, y, z, w, y);
589  }
590 
// cabxpyAx_ functor with real a and complex b over (x, y); writes the x and y
// slots (writeX = writeY = 1). Uniform precision only.
void cabxpyAx(double a, const Complex &b, ColorSpinorField &x, ColorSpinorField &y)
{
  const double2 ca = make_double2(a, 0.0);
  const double2 cb = make_double2(REAL(b), IMAG(b));
  uni_blas<cabxpyAx_, 1, 1>(ca, cb, make_double2(0.0, 0.0), x, y, x, x, y);
}
596 
// caxpyxmaz_ functor with complex a over (x, y, z); writes the x and y slots
// (writeX = writeY = 1).
void caxpyXmaz(const Complex &a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
{
  const double2 alpha = make_double2(REAL(a), IMAG(a));
  const double2 nil = make_double2(0.0, 0.0);
  uni_blas<caxpyxmaz_, 1, 1>(alpha, nil, nil, x, y, z, x, y);
}
602 
// caxpyXmazMR body. It errors unless asynchronous reductions are enabled
// (commAsyncReduction()), and it refuses CPU fields. NOTE(review): the test that
// guards the CPU-field error is presumably a location check — confirm. It then
// applies caxpyxmazMR_ with complex a over (x, y, z), writing x and y.
605  if (!commAsyncReduction())
606  errorQuda("This kernel requires asynchronous reductions to be set");
608  errorQuda("This kernel cannot be run on CPU fields");
609 
610  uni_blas<caxpyxmazMR_, 1, 1>(
611  make_double2(REAL(a), IMAG(a)), make_double2(0.0, 0.0), make_double2(0.0, 0.0), x, y, z, x, y);
612  }
613 
// tripleCGUpdate_ functor with real coefficients a and b over (x, y, z, w).
// Writes the y, z and w slots (writeY = writeZ = writeW = 1); x is read only.
// Differing x/y precisions route through mixed_blas.
614  void tripleCGUpdate(double a, double b, ColorSpinorField &x,
616  if (x.Precision() != y.Precision()) {
617  // x and y precisions differ: use the mixed-precision driver
618  mixed_blas<tripleCGUpdate_, 0, 1, 1, 1>(
619  make_double2(a, 0.0), make_double2(b, 0.0), make_double2(0.0, 0.0), x, y, z, w, y);
620  } else {
621  uni_blas<tripleCGUpdate_, 0, 1, 1, 1>(
622  make_double2(a, 0.0), make_double2(b, 0.0), make_double2(0.0, 0.0), x, y, z, w, y);
623  }
624  }
625 
// doubleCG3Init_ functor with real a over (x, y, z); writes the x and y slots
// (writeX = writeY = 1), with z repeated in the w slot.
void doubleCG3Init(double a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
{
  const double2 nil = make_double2(0.0, 0.0);
  uni_blas<doubleCG3Init_, 1, 1, 0, 0>(make_double2(a, 0.0), nil, nil, x, y, z, z, y);
}
630 
// doubleCG3Update_ functor over (x, y, z). The second coefficient packs
// (b, 1 - b) into one double2. Writes the x and y slots (writeX = writeY = 1).
void doubleCG3Update(double a, double b, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
{
  const double2 ca = make_double2(a, 0.0);
  const double2 bPair = make_double2(b, 1.0 - b);
  uni_blas<doubleCG3Update_, 1, 1, 0, 0>(ca, bPair, make_double2(0.0, 0.0), x, y, z, z, y);
}
635 
636  } // namespace blas
637 
638 } // namespace quda
void ax(double a, ColorSpinorField &x)
Definition: blas_quda.cu:508
int Z[4]
Definition: test_util.cpp:26
void caxpyXmazMR(const Complex &a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
Definition: blas_quda.cu:603
bool commAsyncReduction()
void axpyZpbx(double a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, double b)
Definition: blas_quda.cu:552
const char * AuxString() const
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:21
void end(void)
Definition: blas_quda.cu:489
const ColorSpinorField & x
Definition: blas_quda.cu:36
#define checkPrecision(...)
#define errorQuda(...)
Definition: util_quda.h:121
void init()
Definition: blas_quda.cu:483
Helper file when using jitify run-time compilation. This file should be included in source code...
const ColorSpinorField & y
Definition: blas_quda.cu:36
cudaStream_t * streams
cudaStream_t * stream
void cabxpyAx(double a, const Complex &b, ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.cu:591
const int Nstream
Definition: quda_internal.h:83
void caxpbypczw(const Complex &a, ColorSpinorField &x, const Complex &b, ColorSpinorField &y, const Complex &c, ColorSpinorField &z, ColorSpinorField &w)
Definition: blas_quda.cu:528
virtual bool advanceSharedBytes(TuneParam &param) const
Definition: blas_quda.cu:46
const char * VolString() const
const ColorSpinorField & w
Definition: blas_quda.cu:36
int length[]
void caxpyBzpx(const Complex &, ColorSpinorField &, ColorSpinorField &, const Complex &, ColorSpinorField &)
Definition: blas_quda.cu:563
void caxpyBxpz(const Complex &, ColorSpinorField &, ColorSpinorField &, const Complex &, ColorSpinorField &)
Definition: blas_quda.cu:574
void doubleCG3Update(double a, double b, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
Definition: blas_quda.cu:631
QudaGaugeParam param
Definition: pack_test.cpp:17
cudaStream_t * getStream()
Definition: blas_quda.cu:494
void apply(const cudaStream_t &stream)
Definition: blas_quda.cu:93
static cudaStream_t * blasStream
Definition: blas_quda.cu:25
const ColorSpinorField & v
Definition: blas_quda.cu:36
void mixed_blas(const double2 &a, const double2 &b, const double2 &c, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, ColorSpinorField &w, ColorSpinorField &v)
Definition: blas_quda.cu:326
void caxpbypzYmbw(const Complex &, ColorSpinorField &, const Complex &, ColorSpinorField &, ColorSpinorField &, ColorSpinorField &)
Definition: blas_quda.cu:585
TuneKey tuneKey() const
Definition: blas_quda.cu:91
void initReduce()
Definition: reduce_quda.cu:64
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
Definition: tune.cpp:643
CUresult jitify_error
Definition: tune_quda.h:276
#define warningQuda(...)
Definition: util_quda.h:133
#define checkLocation(...)
void axpyBzpcx(double a, ColorSpinorField &x, ColorSpinorField &y, double b, ColorSpinorField &z, double c)
Definition: blas_quda.cu:541
long long bytes() const
Definition: blas_quda.cu:138
unsigned int sharedBytesPerBlock(const TuneParam &param) const
Definition: blas_quda.cu:44
#define REAL(a)
Definition: blas_helper.cuh:14
int X[4]
Definition: covdev_test.cpp:70
std::complex< double > Complex
Definition: quda_internal.h:46
void initTuneParam(TuneParam &param) const
Definition: blas_quda.cu:125
void tripleCGUpdate(double alpha, double beta, ColorSpinorField &q, ColorSpinorField &r, ColorSpinorField &x, ColorSpinorField &p)
Definition: blas_quda.cu:614
int tuningIter() const
Definition: blas_quda.cu:144
void caxpy(const Complex &a, ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.cu:512
void axpbyz(double a, ColorSpinorField &x, double b, ColorSpinorField &y, ColorSpinorField &z)
Definition: blas_quda.cu:496
virtual ~BlasCuda()
Definition: blas_quda.cu:89
void zero(ColorSpinorField &a)
Definition: blas_quda.cu:472
void doubleCG3Init(double a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
Definition: blas_quda.cu:626
int V
Definition: test_util.cpp:27
void checkLength(const ColorSpinorField &a, const ColorSpinorField &b)
Definition: blas_helper.cuh:26
QudaFieldLocation Location() const
BlasArg< SpinorX, SpinorY, SpinorZ, SpinorW, SpinorV, Functor > arg
Definition: blas_quda.cu:34
BlasCuda(SpinorX &X, SpinorY &Y, SpinorZ &Z, SpinorW &W, SpinorV &V, Functor &f, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, ColorSpinorField &w, ColorSpinorField &v, int length)
Definition: blas_quda.cu:58
void caxpyXmaz(const Complex &a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
Definition: blas_quda.cu:597
long long flops() const
Definition: blas_quda.cu:137
void uni_blas(const double2 &a, const double2 &b, const double2 &c, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, ColorSpinorField &w, ColorSpinorField &v)
Definition: blas_quda.cu:185
unsigned long long flops
Definition: blas_quda.cu:22
void caxpby(const Complex &a, ColorSpinorField &x, const Complex &b, ColorSpinorField &y)
Definition: blas_quda.cu:523
virtual void initTuneParam(TuneParam &param) const
Definition: tune_quda.h:304
void cxpaypbz(ColorSpinorField &, const Complex &b, ColorSpinorField &y, const Complex &c, ColorSpinorField &z)
Definition: blas_quda.cu:535
void nativeBlas(const double2 &a, const double2 &b, const double2 &c, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, ColorSpinorField &w, ColorSpinorField &v, int length)
Definition: blas_quda.cu:149
#define checkCudaError()
Definition: util_quda.h:161
virtual bool advanceBlockDim(TuneParam &param) const
Definition: tune_quda.h:124
const ColorSpinorField & z
Definition: blas_quda.cu:36
QudaTune getTuning()
Query whether autotuning is enabled or not. Default is enabled but can be overridden by setting QUDA_...
Definition: util_quda.cpp:52
QudaPrecision Precision() const
void defaultTuneParam(TuneParam &param) const
Definition: blas_quda.cu:131
QudaFieldOrder FieldOrder() const
char aux[TuneKey::aux_n]
Definition: tune_quda.h:265
unsigned int sharedBytesPerThread() const
Definition: blas_quda.cu:43
void endReduce()
Definition: reduce_quda.cu:120
unsigned long long bytes
Definition: blas_quda.cu:23
#define IMAG(a)
Definition: blas_helper.cuh:15