QUDA  1.0.0
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
blas_core.cuh
Go to the documentation of this file.
1 #pragma once
2 
4 #include <blas_helper.cuh>
5 
6 namespace quda
7 {
8 
9  namespace blas
10  {
11 
12 #define BLAS_SPINOR // do not include ghost functions in Spinor class to reduce parameter space overhead
13 #include <texture.h>
14 
/**
   Argument bundle handed by value to blasKernel: the five field accessors,
   the functor that is applied to their elements, and the bound of the
   kernel's index loop.
 */
template <typename SpinorX, typename SpinorY, typename SpinorZ, typename SpinorW, typename SpinorV, typename Functor>
struct BlasArg {
  SpinorX X;        //!< accessor loaded into / saved from the x registers
  SpinorY Y;        //!< accessor loaded into / saved from the y registers
  SpinorZ Z;        //!< accessor loaded into / saved from the z registers
  SpinorW W;        //!< accessor loaded into / saved from the w registers
  SpinorV V;        //!< accessor loaded into / saved from the v registers
  Functor f;        //!< functor invoked on every register element
  const int length; //!< exclusive upper bound of the index loop in blasKernel

  //! Stores copies of all accessors, the functor and the loop bound.
  BlasArg(SpinorX X, SpinorY Y, SpinorZ Z, SpinorW W, SpinorV V, Functor f, int length) :
    X(X), Y(Y), Z(Z), W(W), V(V), f(f), length(length)
  {
  }
};
39 
/**
   Generic blas kernel.  Each thread starts at its global x-index and walks
   the index range [0, arg.length) in a grid-stride loop; blockIdx.y is
   forwarded to the accessors as the parity.  For every index all five fields
   are loaded into registers, the functor is applied to each of the M register
   elements, and all five fields are saved back.

   @tparam FloatN register element type passed to the functor
   @tparam M number of FloatN elements held per index
   @param arg argument struct providing X, Y, Z, W, V, f and length
 */
template <typename FloatN, int M, typename Arg> __global__ void blasKernel(Arg arg)
{
  const unsigned int stride = gridDim.x * blockDim.x;
  unsigned int parity = blockIdx.y;

  // functor pre-computation, executed once per thread before the loop
  arg.f.init();

  for (unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < arg.length; idx += stride) {
    FloatN x[M], y[M], z[M], w[M], v[M];

    arg.X.load(x, idx, parity);
    arg.Y.load(y, idx, parity);
    arg.Z.load(z, idx, parity);
    arg.W.load(w, idx, parity);
    arg.V.load(v, idx, parity);

#pragma unroll
    for (int j = 0; j < M; j++) { arg.f(x[j], y[j], z[j], w[j], v[j]); }

    arg.X.save(x, idx, parity);
    arg.Y.save(y, idx, parity);
    arg.Z.save(z, idx, parity);
    arg.W.save(w, idx, parity);
    arg.V.save(v, idx, parity);
  }
}
70 
/**
   Abstract base for the blas functors: an optional pre-computation hook and
   the element-wise operation itself.

   NOTE(review): objects of derived types are passed by value to a __global__
   kernel; CUDA restricts kernel arguments of class types with virtual
   functions.  Calls made on the concrete functor type do not need the
   vtable, but confirm nothing dispatches through a base reference on the
   device.
 */
template <typename Float2, typename FloatN> struct BlasFunctor {

  //! pre-computation routine before the main loop (default: nothing to do)
  virtual __device__ __host__ void init() { }

  //! operation applied to one register element of each of the five fields
  virtual __device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w, FloatN &v) = 0;
};
79 
/**
   Functor computing v = a.x * x + b.x * y.  Only the .x components of the
   coefficients are used; x and y are left unmodified.
 */
template <typename Float2, typename FloatN> struct axpbyz_ : public BlasFunctor<Float2, FloatN> {
  const Float2 a;
  const Float2 b;

  //! c is not used by this functor
  axpbyz_(const Float2 &a, const Float2 &b, const Float2 &c) : a(a), b(b) { }

  // the result is written to v rather than z to ensure the same precision as y
  __device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w, FloatN &v)
  {
    v = a.x * x + b.x * y;
  }

  static int streams() { return 3; } //!< total number of input and output streams
  static int flops() { return 3; }   //!< flop count (presumably per element)
};
94 
/**
   Functor scaling x in place: x *= a.x.  Only the .x component of a is used.
 */
template <typename Float2, typename FloatN> struct ax_ : public BlasFunctor<Float2, FloatN> {
  const Float2 a;

  //! b and c are not used by this functor
  ax_(const Float2 &a, const Float2 &b, const Float2 &c) : a(a) { }

  __device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w, FloatN &v)
  {
    x *= a.x;
  }

  static int streams() { return 2; } //!< total number of input and output streams
  static int flops() { return 1; }   //!< flop count (presumably per element)
};
105 
/**
   Complex y += a * x where the float4 operands hold two complex numbers each,
   (x, y) and (z, w), with real part first.

   Declared inline because this is a non-template function defined in a
   header: without it, including the header from more than one translation
   unit produces multiple definitions.

   @param a complex coefficient (a.x real part, a.y imaginary part)
   @param x input pair of complex numbers
   @param y pair of complex numbers updated in place
 */
inline __device__ __host__ void _caxpy(const float2 &a, const float4 &x, float4 &y)
{
  // first complex number: (y.x, y.y) += a * (x.x, x.y)
  y.x += a.x * x.x;
  y.x -= a.y * x.y;
  y.y += a.y * x.x;
  y.y += a.x * x.y;
  // second complex number: (y.z, y.w) += a * (x.z, x.w)
  y.z += a.x * x.z;
  y.z -= a.y * x.w;
  y.w += a.y * x.z;
  y.w += a.x * x.w;
}
121 
/**
   Single-precision complex y += a * x, with .x the real and .y the imaginary
   part of each operand.

   Declared inline because this is a non-template function defined in a
   header: without it, including the header from more than one translation
   unit produces multiple definitions.

   @param a complex coefficient
   @param x complex input
   @param y complex number updated in place
 */
inline __device__ __host__ void _caxpy(const float2 &a, const float2 &x, float2 &y)
{
  // real part: Re(a) Re(x) - Im(a) Im(x)
  y.x += a.x * x.x;
  y.x -= a.y * x.y;
  // imaginary part: Im(a) Re(x) + Re(a) Im(x)
  y.y += a.y * x.x;
  y.y += a.x * x.y;
}
129 
/**
   Double-precision complex y += a * x, with .x the real and .y the imaginary
   part of each operand.

   Declared inline because this is a non-template function defined in a
   header: without it, including the header from more than one translation
   unit produces multiple definitions.

   @param a complex coefficient
   @param x complex input
   @param y complex number updated in place
 */
inline __device__ __host__ void _caxpy(const double2 &a, const double2 &x, double2 &y)
{
  // real part: Re(a) Re(x) - Im(a) Im(x)
  y.x += a.x * x.x;
  y.x -= a.y * x.y;
  // imaginary part: Im(a) Re(x) + Re(a) Im(x)
  y.y += a.y * x.x;
  y.y += a.x * x.y;
}
137 
/**
   Functor computing the complex update y += a * x by forwarding to _caxpy.
 */
template <typename Float2, typename FloatN> struct caxpy_ : public BlasFunctor<Float2, FloatN> {
  const Float2 a;

  //! b and c are not used by this functor
  caxpy_(const Float2 &a, const Float2 &b, const Float2 &c) : a(a) { }

  __device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w, FloatN &v)
  {
    _caxpy(a, x, y);
  }

  static int streams() { return 3; } //!< total number of input and output streams
  static int flops() { return 4; }   //!< flop count (presumably per element)
};
145 
/**
   Complex y = a * x + b * y where the float4 operands hold two complex
   numbers each, (x, y) and (z, w), with real part first.

   Declared inline because this is a non-template function defined in a
   header: without it, including the header from more than one translation
   unit produces multiple definitions.

   @param a complex coefficient multiplying x
   @param x input pair of complex numbers
   @param b complex coefficient multiplying the incoming y
   @param y pair of complex numbers, overwritten with the result
 */
inline __device__ __host__ void _caxpby(const float2 &a, const float4 &x, const float2 &b, float4 &y)
{
  // accumulate into a temporary: every output component needs two incoming y components
  float4 yy;
  yy.x = a.x * x.x;
  yy.x -= a.y * x.y;
  yy.x += b.x * y.x;
  yy.x -= b.y * y.y;
  yy.y = a.y * x.x;
  yy.y += a.x * x.y;
  yy.y += b.y * y.x;
  yy.y += b.x * y.y;
  yy.z = a.x * x.z;
  yy.z -= a.y * x.w;
  yy.z += b.x * y.z;
  yy.z -= b.y * y.w;
  yy.w = a.y * x.z;
  yy.w += a.x * x.w;
  yy.w += b.y * y.z;
  yy.w += b.x * y.w;
  y = yy;
}
171 
/**
   Single-precision complex y = a * x + b * y, with .x the real and .y the
   imaginary part of each operand.

   Declared inline because this is a non-template function defined in a
   header: without it, including the header from more than one translation
   unit produces multiple definitions.

   @param a complex coefficient multiplying x
   @param x complex input
   @param b complex coefficient multiplying the incoming y
   @param y complex number, overwritten with the result
 */
inline __device__ __host__ void _caxpby(const float2 &a, const float2 &x, const float2 &b, float2 &y)
{
  // accumulate into a temporary: both output components need both incoming y components
  float2 yy;
  yy.x = a.x * x.x;
  yy.x -= a.y * x.y;
  yy.x += b.x * y.x;
  yy.x -= b.y * y.y;
  yy.y = a.y * x.x;
  yy.y += a.x * x.y;
  yy.y += b.y * y.x;
  yy.y += b.x * y.y;
  y = yy;
}
185 
/**
   Double-precision complex y = a * x + b * y, with .x the real and .y the
   imaginary part of each operand.

   Declared inline because this is a non-template function defined in a
   header: without it, including the header from more than one translation
   unit produces multiple definitions.

   @param a complex coefficient multiplying x
   @param x complex input
   @param b complex coefficient multiplying the incoming y
   @param y complex number, overwritten with the result
 */
inline __device__ __host__ void _caxpby(const double2 &a, const double2 &x, const double2 &b, double2 &y)
{
  // accumulate into a temporary: both output components need both incoming y components
  double2 yy;
  yy.x = a.x * x.x;
  yy.x -= a.y * x.y;
  yy.x += b.x * y.x;
  yy.x -= b.y * y.y;
  yy.y = a.y * x.x;
  yy.y += a.x * x.y;
  yy.y += b.y * y.x;
  yy.y += b.x * y.y;
  y = yy;
}
199 
/**
   Functor computing the complex update y = a * x + b * y by forwarding to
   _caxpby.
 */
template <typename Float2, typename FloatN> struct caxpby_ : public BlasFunctor<Float2, FloatN> {
  const Float2 a;
  const Float2 b;

  //! c is not used by this functor
  caxpby_(const Float2 &a, const Float2 &b, const Float2 &c) : a(a), b(b) { }

  __device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w, FloatN &v)
  {
    _caxpby(a, x, b, y);
  }

  static int streams() { return 3; } //!< total number of input and output streams
  static int flops() { return 7; }   //!< flop count (presumably per element)
};
211 
/**
   Functor computing the complex combination w = a * x + b * y + c * z.
   x, y and z are left unmodified.
 */
template <typename Float2, typename FloatN> struct caxpbypczw_ : public BlasFunctor<Float2, FloatN> {
  const Float2 a;
  const Float2 b;
  const Float2 c;

  caxpbypczw_(const Float2 &a, const Float2 &b, const Float2 &c) : a(a), b(b), c(c) { }

  __device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w, FloatN &v)
  {
    w = y;               // seed w with y so that y itself is preserved
    _caxpby(a, x, b, w); // w = a * x + b * y
    _caxpy(c, z, w);     // w += c * z
  }

  static int streams() { return 4; } //!< total number of input and output streams
  static int flops() { return 8; }   //!< flop count (presumably per element)
};
226 
/**
   Functor performing two real-coefficient updates in sequence:
   y += a.x * x, then x = b.x * z + c.x * x.
   Only the .x components of the coefficients are used.
 */
template <typename Float2, typename FloatN> struct axpyBzpcx_ : public BlasFunctor<Float2, FloatN> {
  const Float2 a;
  const Float2 b;
  const Float2 c;

  axpyBzpcx_(const Float2 &a, const Float2 &b, const Float2 &c) : a(a), b(b), c(c) { }

  __device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w, FloatN &v)
  {
    y += a.x * x;          // uses x before it is overwritten below
    x = b.x * z + c.x * x;
  }

  static int streams() { return 5; } //!< total number of input and output streams
  static int flops() { return 5; }   //!< flop count (presumably per element)
};
243 
/**
   Functor performing two real-coefficient updates in sequence:
   y += a.x * x, then x = z + b.x * x.
   Only the .x components of the coefficients are used.
 */
template <typename Float2, typename FloatN> struct axpyZpbx_ : public BlasFunctor<Float2, FloatN> {
  const Float2 a;
  const Float2 b;

  //! c is not used by this functor
  axpyZpbx_(const Float2 &a, const Float2 &b, const Float2 &c) : a(a), b(b) { }

  __device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w, FloatN &v)
  {
    y += a.x * x;    // uses x before it is overwritten below
    x = z + b.x * x;
  }

  static int streams() { return 5; } //!< total number of input and output streams
  static int flops() { return 4; }   //!< flop count (presumably per element)
};
259 
/**
   Functor performing two complex updates in sequence:
   y += a * x, then x += b * z.
 */
template <typename Float2, typename FloatN> struct caxpyBzpx_ : public BlasFunctor<Float2, FloatN> {
  const Float2 a;
  const Float2 b;

  //! c is not used by this functor
  caxpyBzpx_(const Float2 &a, const Float2 &b, const Float2 &c) : a(a), b(b) { }

  __device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w, FloatN &v)
  {
    _caxpy(a, x, y); // y += a * x, using x before it is updated
    _caxpy(b, z, x); // x += b * z
  }

  static int streams() { return 5; } //!< total number of input and output streams
  static int flops() { return 8; }   //!< flop count (presumably per element)
};
276 
/**
   Functor performing two complex updates that share the same input x:
   y += a * x and z += b * x.  x is left unmodified.
 */
template <typename Float2, typename FloatN> struct caxpyBxpz_ : public BlasFunctor<Float2, FloatN> {
  const Float2 a;
  const Float2 b;

  //! c is not used by this functor
  caxpyBxpz_(const Float2 &a, const Float2 &b, const Float2 &c) : a(a), b(b) { }

  __device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w, FloatN &v)
  {
    _caxpy(a, x, y); // y += a * x
    _caxpy(b, x, z); // z += b * x
  }

  static int streams() { return 5; } //!< total number of input and output streams
  static int flops() { return 8; }   //!< flop count (presumably per element)
};
293 
/**
   Functor performing the complex updates z += a * x + b * y followed by
   y -= b * w.  z is updated with the incoming y, before y is modified.
 */
template <typename Float2, typename FloatN> struct caxpbypzYmbw_ : public BlasFunctor<Float2, FloatN> {
  const Float2 a;
  const Float2 b;

  //! c is not used by this functor
  caxpbypzYmbw_(const Float2 &a, const Float2 &b, const Float2 &c) : a(a), b(b) { }

  __device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w, FloatN &v)
  {
    _caxpy(a, x, z);  // z += a * x
    _caxpy(b, y, z);  // z += b * y
    _caxpy(-b, w, y); // y -= b * w
  }

  static int streams() { return 6; } //!< total number of input and output streams
  static int flops() { return 12; }  //!< flop count (presumably per element)
};
311 
/**
   Functor scaling x by the real coefficient a.x and then accumulating the
   scaled x into y with the complex coefficient b:
   x *= a.x, then y += b * x.
 */
template <typename Float2, typename FloatN> struct cabxpyAx_ : public BlasFunctor<Float2, FloatN> {
  const Float2 a;
  const Float2 b;

  //! c is not used by this functor
  cabxpyAx_(const Float2 &a, const Float2 &b, const Float2 &c) : a(a), b(b) { }

  __device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w, FloatN &v)
  {
    x *= a.x;        // only the .x component of a is used
    _caxpy(b, x, y); // y += b * x with the already scaled x
  }

  static int streams() { return 4; } //!< total number of input and output streams
  static int flops() { return 5; }   //!< flop count (presumably per element)
};
327 
/**
   Functor performing two complex updates with the same coefficient:
   y += a * x, then x -= a * z.
 */
template <typename Float2, typename FloatN> struct caxpyxmaz_ : public BlasFunctor<Float2, FloatN> {
  Float2 a;

  //! b and c are not used by this functor
  caxpyxmaz_(const Float2 &a, const Float2 &b, const Float2 &c) : a(a) { }

  __device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w, FloatN &v)
  {
    _caxpy(a, x, y);  // y += a * x, using x before it is updated
    _caxpy(-a, z, x); // x -= a * z
  }

  static int streams() { return 5; } //!< total number of input and output streams
  static int flops() { return 8; }   //!< flop count (presumably per element)
};
344 
/**
   Variant of caxpyxmaz_ (y += a * x, then x -= a * z) whose complex
   coefficient is completed on the device: init() reads a double3 from the
   buffer returned by blas::getDeviceReduceBuffer() and rescales a with it.
   NOTE(review): the double3 presumably holds the result of a preceding
   reduction as (real, imaginary, norm) -- confirm against the caller.
 */
template <typename Float2, typename FloatN> struct caxpyxmazMR_ : public BlasFunctor<Float2, FloatN> {
  Float2 a;     //!< coefficient; only a.x is meaningful until init() has run on the device
  double3 *Ar3; //!< reduce buffer viewed as double3

  //! b and c are not used by this functor
  caxpyxmazMR_(const Float2 &a, const Float2 &b, const Float2 &c) :
    a(a), Ar3(static_cast<double3 *>(blas::getDeviceReduceBuffer()))
  {
  }

  /**
     Device-only pre-computation: a = a.x * (result.x, result.y) / result.z.
     In host compilation this is a no-op.
   */
  inline __device__ __host__ void init()
  {
#ifdef __CUDA_ARCH__
    typedef decltype(a.x) real;
    double3 result = __ldg(Ar3);
    const real inv_z = (real)1.0 / (real)result.z;
    // a.y must be assigned first: both components depend on the incoming a.x
    a.y = a.x * (real)(result.y) * inv_z;
    a.x = a.x * (real)(result.x) * inv_z;
#endif
  }

  __device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w, FloatN &v)
  {
    _caxpy(a, x, y);  // y += a * x, using x before it is updated
    _caxpy(-a, z, x); // x -= a * z
  }

  static int streams() { return 5; } //!< total number of input and output streams
  static int flops() { return 8; }   //!< flop count (presumably per element)
};
379 
/**
   Functor performing three real-coefficient updates in sequence:
   y += a.x * w, z -= a.x * x, then w = z + b.x * w (with the updated z).
   Only the .x components of the coefficients are used.
 */
template <typename Float2, typename FloatN> struct tripleCGUpdate_ : public BlasFunctor<Float2, FloatN> {
  Float2 a, b;

  //! c is not used by this functor
  tripleCGUpdate_(const Float2 &a, const Float2 &b, const Float2 &c) : a(a), b(b) { }

  __device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w, FloatN &v)
  {
    y += a.x * w;    // uses w before it is overwritten below
    z -= a.x * x;
    w = z + b.x * w; // uses the z computed on the previous line
  }

  static int streams() { return 7; } //!< total number of input and output streams
  static int flops() { return 6; }   //!< flop count (presumably per element)
};
398 
/**
   Functor that copies x into y and then advances x: y = x, x += a.x * z.
   Only the .x component of a is used.
 */
template <typename Float2, typename FloatN> struct doubleCG3Init_ : public BlasFunctor<Float2, FloatN> {
  Float2 a;

  //! b and c are not used by this functor
  doubleCG3Init_(const Float2 &a, const Float2 &b, const Float2 &c) : a(a) { }

  __device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w, FloatN &v)
  {
    y = x;        // keep the incoming x in y
    x += a.x * z;
  }

  static int streams() { return 3; } //!< total number of input and output streams
  static int flops() { return 3; }   //!< flop count (presumably per element)
};
415 
/**
   Functor computing x = b.x * (x + a.x * z) + b.y * y while moving the
   incoming x into y.  Uses a.x, b.x and b.y; c is ignored.

   The incoming x is held in a local variable instead of a data member, so
   operator() no longer writes to the functor's own state (the functor
   reaches the kernel as part of the kernel argument) and the functor
   carries no scratch storage.
 */
template <typename Float2, typename FloatN> struct doubleCG3Update_ : public BlasFunctor<Float2, FloatN> {
  Float2 a, b;

  //! c is not used by this functor
  doubleCG3Update_(const Float2 &a, const Float2 &b, const Float2 &c) : a(a), b(b) { ; }

  __device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w, FloatN &v)
  {
    FloatN x_old = x; // saved because x is overwritten before y receives it
    x = b.x * (x + a.x * z) + b.y * y;
    y = x_old;
  }

  static int streams() { return 4; } //!< total number of input and output streams
  static int flops() { return 7; }   //!< flop count (presumably per element)
};
435 
436  } // namespace blas
437 } // namespace quda
caxpbypzYmbw_(const Float2 &a, const Float2 &b, const Float2 &c)
Definition: blas_core.cuh:300
tripleCGUpdate_(const Float2 &a, const Float2 &b, const Float2 &c)
Definition: blas_core.cuh:388
static int streams()
Definition: blas_core.cuh:256
axpyBzpcx_(const Float2 &a, const Float2 &b, const Float2 &c)
Definition: blas_core.cuh:234
__device__ __host__ void _caxpby(const float2 &a, const float4 &x, const float2 &b, float4 &y)
Definition: blas_core.cuh:150
axpbyz_(const Float2 &a, const Float2 &b, const Float2 &c)
Definition: blas_core.cuh:86
static int flops()
number of floating-point operations counted for this functor
Definition: blas_core.cuh:224
doubleCG3Update_(const Float2 &a, const Float2 &b, const Float2 &c)
Definition: blas_core.cuh:424
const Float2 b
Definition: blas_core.cuh:202
doubleCG3Init_(const Float2 &a, const Float2 &b, const Float2 &c)
Definition: blas_core.cuh:406
caxpyxmaz_(const Float2 &a, const Float2 &b, const Float2 &c)
Definition: blas_core.cuh:335
static int flops()
number of floating-point operations counted for this functor
Definition: blas_core.cuh:274
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w, FloatN &v)
where the reduction is usually computed and any auxiliary operations
Definition: blas_core.cuh:101
cudaColorSpinorField * tmp
Definition: covdev_test.cpp:44
caxpyBzpx_(const Float2 &a, const Float2 &b, const Float2 &c)
Definition: blas_core.cuh:266
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w, FloatN &v)
where the reduction is usually computed and any auxiliary operations
Definition: blas_core.cuh:284
static int flops()
number of floating-point operations counted for this functor
Definition: blas_core.cuh:209
caxpyBxpz_(const Float2 &a, const Float2 &b, const Float2 &c)
Definition: blas_core.cuh:283
caxpbypczw_(const Float2 &a, const Float2 &b, const Float2 &c)
Definition: blas_core.cuh:216
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w, FloatN &v)
where the reduction is usually computed and any auxiliary operations
Definition: blas_core.cuh:251
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w, FloatN &v)
where the reduction is usually computed and any auxiliary operations
Definition: blas_core.cuh:235
static int streams()
Definition: blas_core.cuh:240
__device__ __host__ void _caxpy(const float2 &a, const float4 &x, float4 &y)
Definition: blas_core.cuh:110
static int streams()
Definition: blas_core.cuh:208
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w, FloatN &v)
where the reduction is usually computed and any auxiliary operations
Definition: blas_core.cuh:87
BlasArg(SpinorX X, SpinorY Y, SpinorZ Z, SpinorW W, SpinorV V, Functor f, int length)
Definition: blas_core.cuh:27
caxpy_(const Float2 &a, const Float2 &b, const Float2 &c)
Definition: blas_core.cuh:140
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w, FloatN &v)
where the reduction is usually computed and any auxiliary operations
Definition: blas_core.cuh:204
virtual __device__ __host__ void init()
pre-computation routine before the main loop
Definition: blas_core.cuh:74
static int streams()
Definition: blas_core.cuh:290
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w, FloatN &v)
where the reduction is usually computed and any auxiliary operations
Definition: blas_core.cuh:336
static int flops()
number of floating-point operations counted for this functor
Definition: blas_core.cuh:325
__global__ void blasKernel(Arg arg)
Definition: blas_core.cuh:43
static int streams()
Definition: blas_core.cuh:273
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w, FloatN &v)
where the reduction is usually computed and any auxiliary operations
Definition: blas_core.cuh:389
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w, FloatN &v)
where the reduction is usually computed and any auxiliary operations
Definition: blas_core.cuh:426
static int flops()
number of floating-point operations counted for this functor
Definition: blas_core.cuh:241
const Float2 a
Definition: blas_core.cuh:139
static int flops()
number of floating-point operations counted for this functor
Definition: blas_core.cuh:291
const Float2 b
Definition: blas_core.cuh:85
static int flops()
number of floating-point operations counted for this functor
Definition: blas_core.cuh:413
axpyZpbx_(const Float2 &a, const Float2 &b, const Float2 &c)
Definition: blas_core.cuh:250
static int streams()
Definition: blas_core.cuh:324
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w, FloatN &v)
where the reduction is usually computed and any auxiliary operations
Definition: blas_core.cuh:370
static int flops()
number of floating-point operations counted for this functor
Definition: blas_core.cuh:257
const Float2 a
Definition: blas_core.cuh:201
static int flops()
number of floating-point operations counted for this functor
Definition: blas_core.cuh:309
ax_(const Float2 &a, const Float2 &b, const Float2 &c)
Definition: blas_core.cuh:100
static int flops()
number of floating-point operations counted for this functor
Definition: blas_core.cuh:143
cabxpyAx_(const Float2 &a, const Float2 &b, const Float2 &c)
Definition: blas_core.cuh:318
__device__ __host__ void init()
pre-computation routine before the main loop
Definition: blas_core.cuh:360
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w, FloatN &v)
where the reduction is usually computed and any auxiliary operations
Definition: blas_core.cuh:141
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w, FloatN &v)
where the reduction is usually computed and any auxiliary operations
Definition: blas_core.cuh:217
static int flops()
number of floating-point operations counted for this functor
Definition: blas_core.cuh:377
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w, FloatN &v)
where the reduction is usually computed and any auxiliary operations
Definition: blas_core.cuh:267
static int streams()
Definition: blas_core.cuh:341
static int streams()
Definition: blas_core.cuh:142
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w, FloatN &v)
where the reduction is usually computed and any auxiliary operations
Definition: blas_core.cuh:407
void * getDeviceReduceBuffer()
Definition: reduce_quda.cu:26
const Float2 a
Definition: blas_core.cuh:84
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
colorspinor::FieldOrderCB< real, Ns, Nc, 1, order > V
Definition: spinor_noise.cu:23
static int flops()
number of floating-point operations counted for this functor
Definition: blas_core.cuh:92
static int streams()
Definition: blas_core.cuh:91
static int flops()
number of floating-point operations counted for this functor
Definition: blas_core.cuh:342
caxpyxmazMR_(const Float2 &a, const Float2 &b, const Float2 &c)
Definition: blas_core.cuh:353
static int streams()
Definition: blas_core.cuh:102
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w, FloatN &v)
where the reduction is usually computed and any auxiliary operations
Definition: blas_core.cuh:301
static int flops()
number of floating-point operations counted for this functor
Definition: blas_core.cuh:433
QudaParity parity
Definition: covdev_test.cpp:54
__device__ __host__ void operator()(FloatN &x, FloatN &y, FloatN &z, FloatN &w, FloatN &v)
where the reduction is usually computed and any auxiliary operations
Definition: blas_core.cuh:319
const Float2 a
Definition: blas_core.cuh:99
static int flops()
number of floating-point operations counted for this functor
Definition: blas_core.cuh:103
caxpby_(const Float2 &a, const Float2 &b, const Float2 &c)
Definition: blas_core.cuh:203
static int flops()
number of floating-point operations counted for this functor
Definition: blas_core.cuh:396