QUDA  v0.5.0
A library for QCD on GPUs
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
blas_quda.cu
Go to the documentation of this file.
1 #include <stdlib.h>
2 #include <stdio.h>
3 #include <cstring> // needed for memset
4 
5 #include <float_vector.h>
6 
7 #include <tune_quda.h>
8 #include <typeinfo>
9 
10 #include <quda_internal.h>
11 #include <blas_quda.h>
12 #include <color_spinor_field.h>
13 #include <face_quda.h> // this is where the MPI / QMP dependent code is
14 
// checkSpinor(s1, s2): compares two spinor fields and reports any mismatch in
// precision, length or stride through errorQuda.  The three properties are
// checked independently, in that order.  The expansion is a bare braced block,
// so use it as a standalone statement.
#define checkSpinor(s1, s2)                                               \
  {                                                                       \
    if (s1.Precision() != s2.Precision())                                 \
      errorQuda("precisions do not match: %d %d",                         \
                s1.Precision(), s2.Precision());                          \
    if (s1.Length() != s2.Length())                                       \
      errorQuda("lengths do not match: %d %d",                            \
                s1.Length(), s2.Length());                                \
    if (s1.Stride() != s2.Stride())                                       \
      errorQuda("strides do not match: %d %d",                            \
                s1.Stride(), s2.Stride());                                \
  }
24 
25 namespace quda {
26 
27 #include <texture.h>
28 
// Global tallies of floating-point operations and bytes attributed to the
// BLAS kernels (presumably accumulated by blasCuda in blas_core.h -- confirm).
unsigned long long blas_flops;
unsigned long long blas_bytes;

// Module-private BLAS settings; assigned by the parameter setter in this file
// and read back through getBlasTuning() / getBlasVerbosity().
// blasTuning = 1 turns off error checking
static QudaTune blasTuning = QUDA_TUNE_NO;
static QudaVerbosity verbosity = QUDA_SILENT;

// Stream used for BLAS launches; pointed at an entry of `streams` by initBlas().
static cudaStream_t *blasStream;
38 
39  static struct {
41  int stride;
42  } blasConstants;
43 
// Reduction setup/teardown, defined in another translation unit.
void initReduce();
void endReduce();

// Points blasStream at the last entry of the global `streams` array and then
// calls initReduce().
void initBlas()
{
  cudaStream_t *lastStream = &streams[Nstream - 1];
  blasStream = lastStream;
  initReduce();
}

// Tears down BLAS state by calling endReduce().
void endBlas()
{
  endReduce();
}
57 
59  {
60  blasTuning = tune;
61  verbosity = verbose;
62  }
63 
// Read-only accessors for the module-private BLAS settings.
QudaTune getBlasTuning()
{
  return blasTuning;
}

QudaVerbosity getBlasVerbosity()
{
  return verbosity;
}

cudaStream_t* getBlasStream()
{
  return blasStream;
}
67 
68 #include <blas_core.h>
69 
/**
   Functor computing y = a*x + b*y with real coefficients
   (only the .x components of a and b are used).
*/
template <typename Float2, typename FloatN>
struct axpby {
  const Float2 a;
  const Float2 b;

  axpby(const Float2 &alpha, const Float2 &beta, const Float2 & /*unused*/)
    : a(alpha), b(beta) {}

  __device__ void operator()(const FloatN &x, FloatN &y,
                             const FloatN & /*z*/, const FloatN & /*w*/)
  {
    y = a.x*x + b.x*y;
  }

  // x and y are read, y is written.
  static int streams() { return 3; }
  // Two multiplies and one add per real component.
  static int flops() { return 3; }
};
82 
/**
   y = a*x + b*y for real a, b.  x is also passed in the z and w slots, which
   axpby ignores.  The <0,1,0,0> flags appear to mark y as the only output
   field -- confirm against blasCuda in blas_core.h.
*/
void axpbyCuda(const double &a, cudaColorSpinorField &x, const double &b, cudaColorSpinorField &y) {
  double2 coeffA = make_double2(a, 0.0);
  double2 coeffB = make_double2(b, 0.0);
  double2 unusedC = make_double2(0.0, 0.0);
  blasCuda<axpby,0,1,0,0>(coeffA, coeffB, unusedC, x, y, x, x);
}
87 
/**
   Functor computing y += x.  All three coefficients are ignored.
*/
template <typename Float2, typename FloatN>
struct xpy {
  xpy(const Float2 &, const Float2 &, const Float2 &) {}

  __device__ void operator()(const FloatN &x, FloatN &y,
                             const FloatN &, const FloatN &)
  {
    y += x;
  }

  static int streams() { return 3; }  // x, y read; y written
  static int flops() { return 1; }    // one add per real component
};
98 
100  blasCuda<xpy,0,1,0,0>(make_double2(1.0, 0.0), make_double2(1.0, 0.0), make_double2(0.0, 0.0),
101  x, y, x, x);
102  }
103 
/**
   Functor computing y = a*x + y with a real coefficient (a.x).
*/
template <typename Float2, typename FloatN>
struct axpy {
  const Float2 a;

  axpy(const Float2 &alpha, const Float2 &, const Float2 &) : a(alpha) {}

  __device__ void operator()(const FloatN &x, FloatN &y,
                             const FloatN &, const FloatN &)
  {
    y = a.x*x + y;
  }

  static int streams() { return 3; }  // x, y read; y written
  static int flops() { return 2; }    // one multiply + one add per real component
};
115 
/**
   y = a*x + y for real a.  x doubles as the unused z and w arguments.
*/
void axpyCuda(const double &a, cudaColorSpinorField &x, cudaColorSpinorField &y) {
  double2 coeffA = make_double2(a, 0.0);
  double2 coeffB = make_double2(1.0, 0.0);
  double2 coeffC = make_double2(0.0, 0.0);
  blasCuda<axpy,0,1,0,0>(coeffA, coeffB, coeffC, x, y, x, x);
}
120 
/**
   Functor computing y = x + a*y with a real coefficient (a.x).
*/
template <typename Float2, typename FloatN>
struct xpay {
  const Float2 a;

  xpay(const Float2 &alpha, const Float2 &, const Float2 &) : a(alpha) {}

  __device__ void operator()(const FloatN &x, FloatN &y,
                             const FloatN &, const FloatN &)
  {
    y = x + a.x*y;
  }

  static int streams() { return 3; }  // x, y read; y written
  static int flops() { return 2; }    // one multiply + one add per real component
};
132 
/**
   y = x + a*y for real a.  x doubles as the unused z and w arguments.
*/
void xpayCuda(cudaColorSpinorField &x, const double &a, cudaColorSpinorField &y) {
  double2 coeffA = make_double2(a, 0.0);
  double2 coeffB = make_double2(0.0, 0.0);
  double2 coeffC = make_double2(0.0, 0.0);
  blasCuda<xpay,0,1,0,0>(coeffA, coeffB, coeffC, x, y, x, x);
}
137 
/**
   Functor computing y -= x.  All three coefficients are ignored.
*/
template <typename Float2, typename FloatN>
struct mxpy {
  mxpy(const Float2 &, const Float2 &, const Float2 &) {}

  __device__ void operator()(const FloatN &x, FloatN &y,
                             const FloatN &, const FloatN &)
  {
    y -= x;
  }

  static int streams() { return 3; }  // x, y read; y written
  static int flops() { return 1; }    // one subtract per real component
};
148 
150  blasCuda<mxpy,0,1,0,0>(make_double2(1.0, 0.0), make_double2(1.0, 0.0),
151  make_double2(0.0, 0.0), x, y, x, x);
152  }
153 
/**
   Functor scaling x in place: x *= a (real coefficient a.x).
*/
template <typename Float2, typename FloatN>
struct ax {
  const Float2 a;

  ax(const Float2 &alpha, const Float2 &, const Float2 &) : a(alpha) {}

  __device__ void operator()(FloatN &x, const FloatN &,
                             const FloatN &, const FloatN &)
  {
    x *= a.x;
  }

  static int streams() { return 2; }  // x read and written
  static int flops() { return 1; }    // one multiply per real component
};
165 
/**
   x *= a for real a.  x fills every field slot; only x is modified.
*/
void axCuda(const double &a, cudaColorSpinorField &x) {
  double2 scale = make_double2(a, 0.0);
  double2 zeroB = make_double2(0.0, 0.0);
  double2 zeroC = make_double2(0.0, 0.0);
  blasCuda<ax,1,0,0,0>(scale, zeroB, zeroC, x, x, x, x);
}
170 
/**
   Complex multiply-accumulate helpers: acc += alpha * in, where alpha is the
   complex number (alpha.x, alpha.y) and the vectors hold interleaved
   (re, im) pairs.  The float4 overload handles the two pairs (x,y) and (z,w).
   The statement order is kept fixed so rounding is unchanged.
*/
__device__ void caxpy_(const float2 &alpha, const float4 &in, float4 &acc) {
  // first complex pair: (x, y)
  acc.x += alpha.x*in.x;
  acc.x -= alpha.y*in.y;
  acc.y += alpha.y*in.x;
  acc.y += alpha.x*in.y;
  // second complex pair: (z, w)
  acc.z += alpha.x*in.z;
  acc.z -= alpha.y*in.w;
  acc.w += alpha.y*in.z;
  acc.w += alpha.x*in.w;
}

__device__ void caxpy_(const float2 &alpha, const float2 &in, float2 &acc) {
  acc.x += alpha.x*in.x;
  acc.x -= alpha.y*in.y;
  acc.y += alpha.y*in.x;
  acc.y += alpha.x*in.y;
}

__device__ void caxpy_(const double2 &alpha, const double2 &in, double2 &acc) {
  acc.x += alpha.x*in.x;
  acc.x -= alpha.y*in.y;
  acc.y += alpha.y*in.x;
  acc.y += alpha.x*in.y;
}

/**
   Functor computing y += a*x with complex a.
*/
template <typename Float2, typename FloatN>
struct caxpy {
  const Float2 a;

  caxpy(const Float2 &alpha, const Float2 &, const Float2 &) : a(alpha) {}

  __device__ void operator()(const FloatN &x, FloatN &y,
                             const FloatN &, const FloatN &)
  {
    caxpy_(a, x, y);
  }

  static int streams() { return 3; }  // x, y read; y written
  static int flops() { return 4; }    // complex madd: 4 flops per real component
};
200 
202  blasCuda<caxpy,0,1,0,0>(make_double2(real(a),imag(a)), make_double2(0.0, 0.0),
203  make_double2(0.0, 0.0), x, y, x, x);
204  }
205 
/**
   Complex helpers computing io = alpha*in + beta*io for complex alpha, beta
   on interleaved (re, im) data.  The result is built in a temporary so that
   every term reads the original io.
*/
__device__ void caxpby_(const float2 &alpha, const float4 &in, const float2 &beta, float4 &io)
{
  float4 out;
  out.x = alpha.x*in.x; out.x -= alpha.y*in.y; out.x += beta.x*io.x; out.x -= beta.y*io.y;
  out.y = alpha.y*in.x; out.y += alpha.x*in.y; out.y += beta.y*io.x; out.y += beta.x*io.y;
  out.z = alpha.x*in.z; out.z -= alpha.y*in.w; out.z += beta.x*io.z; out.z -= beta.y*io.w;
  out.w = alpha.y*in.z; out.w += alpha.x*in.w; out.w += beta.y*io.z; out.w += beta.x*io.w;
  io = out;
}

__device__ void caxpby_(const float2 &alpha, const float2 &in, const float2 &beta, float2 &io)
{
  float2 out;
  out.x = alpha.x*in.x; out.x -= alpha.y*in.y; out.x += beta.x*io.x; out.x -= beta.y*io.y;
  out.y = alpha.y*in.x; out.y += alpha.x*in.y; out.y += beta.y*io.x; out.y += beta.x*io.y;
  io = out;
}

__device__ void caxpby_(const double2 &alpha, const double2 &in, const double2 &beta, double2 &io)
{
  double2 out;
  out.x = alpha.x*in.x; out.x -= alpha.y*in.y; out.x += beta.x*io.x; out.x -= beta.y*io.y;
  out.y = alpha.y*in.x; out.y += alpha.x*in.y; out.y += beta.y*io.x; out.y += beta.x*io.y;
  io = out;
}

/**
   Functor computing y = a*x + b*y with complex a, b.
*/
template <typename Float2, typename FloatN>
struct caxpby {
  const Float2 a;
  const Float2 b;

  caxpby(const Float2 &alpha, const Float2 &beta, const Float2 &) : a(alpha), b(beta) {}

  __device__ void operator()(const FloatN &x, FloatN &y,
                             const FloatN &, const FloatN &)
  {
    caxpby_(a, x, b, y);
  }

  static int streams() { return 3; }  // x, y read; y written
  static int flops() { return 7; }
};
239 
241  blasCuda<caxpby,0,1,0,0>(make_double2(a.real(),a.imag()), make_double2(b.real(), b.imag()),
242  make_double2(0.0, 0.0), x, y, x, x);
243  }
244 
/**
   Complex helpers computing zio = xin + alpha*yin + beta*zio for complex
   alpha, beta on interleaved (re, im) data.  The result goes through a
   temporary so every term sees the original zio.
*/
__device__ void cxpaypbz_(const float4 &xin, const float2 &alpha, const float4 &yin, const float2 &beta, float4 &zio) {
  float4 res;
  res.x = xin.x + alpha.x*yin.x; res.x -= alpha.y*yin.y; res.x += beta.x*zio.x; res.x -= beta.y*zio.y;
  res.y = xin.y + alpha.y*yin.x; res.y += alpha.x*yin.y; res.y += beta.y*zio.x; res.y += beta.x*zio.y;
  res.z = xin.z + alpha.x*yin.z; res.z -= alpha.y*yin.w; res.z += beta.x*zio.z; res.z -= beta.y*zio.w;
  res.w = xin.w + alpha.y*yin.z; res.w += alpha.x*yin.w; res.w += beta.y*zio.z; res.w += beta.x*zio.w;
  zio = res;
}

__device__ void cxpaypbz_(const float2 &xin, const float2 &alpha, const float2 &yin, const float2 &beta, float2 &zio) {
  float2 res;
  res.x = xin.x + alpha.x*yin.x; res.x -= alpha.y*yin.y; res.x += beta.x*zio.x; res.x -= beta.y*zio.y;
  res.y = xin.y + alpha.y*yin.x; res.y += alpha.x*yin.y; res.y += beta.y*zio.x; res.y += beta.x*zio.y;
  zio = res;
}

__device__ void cxpaypbz_(const double2 &xin, const double2 &alpha, const double2 &yin, const double2 &beta, double2 &zio) {
  double2 res;
  res.x = xin.x + alpha.x*yin.x; res.x -= alpha.y*yin.y; res.x += beta.x*zio.x; res.x -= beta.y*zio.y;
  res.y = xin.y + alpha.y*yin.x; res.y += alpha.x*yin.y; res.y += beta.y*zio.x; res.y += beta.x*zio.y;
  zio = res;
}

/**
   Functor computing z = x + a*y + b*z with complex a, b.
*/
template <typename Float2, typename FloatN>
struct cxpaypbz {
  const Float2 a;
  const Float2 b;

  cxpaypbz(const Float2 &alpha, const Float2 &beta, const Float2 &) : a(alpha), b(beta) {}

  __device__ void operator()(const FloatN &x, const FloatN &y, FloatN &z, FloatN &)
  {
    cxpaypbz_(x, a, y, b, z);
  }

  static int streams() { return 4; }  // x, y, z read; z written
  static int flops() { return 8; }
};
282 
284  const Complex &b, cudaColorSpinorField &z) {
285  blasCuda<cxpaypbz,0,0,1,0>(make_double2(a.real(),a.imag()), make_double2(b.real(), b.imag()),
286  make_double2(0.0, 0.0), x, y, z, z);
287  }
288 
/**
   Functor performing, with real coefficients (only .x of a, b, c is used):
     y = y + a*x
     x = b*z + c*x
   Both statements read the original x, because x is overwritten last.

   flops() counts floating-point operations per real component, using the same
   convention as the other real-coefficient functors in this file (axpy = 2,
   axpby = 3).  Here that is 2 (y update) + 3 (x update) = 5.
*/
template <typename Float2, typename FloatN>
struct axpyBzpcx {
  const Float2 a;
  const Float2 b;
  const Float2 c;
  axpyBzpcx(const Float2 &a, const Float2 &b, const Float2 &c) : a(a), b(b), c(c) { ; }
  __device__ void operator()(FloatN &x, FloatN &y, const FloatN &z, const FloatN &w)
  { y += a.x*x; x = b.x*z + c.x*x; }
  // x, y, z read; x, y written.
  static int streams() { return 5; }
  // 1 mul + 1 add for y, 2 mul + 1 add for x.
  static int flops() { return 5; }
};
303 
/**
   y += a*x, then x = b*z + c*x (real a, b, c).  Both x and y are modified.
   x is passed again in the unused w slot.
*/
void axpyBzpcxCuda(const double &a, cudaColorSpinorField& x, cudaColorSpinorField& y, const double &b,
                   cudaColorSpinorField& z, const double &c) {
  double2 coeffA = make_double2(a, 0.0);
  double2 coeffB = make_double2(b, 0.0);
  double2 coeffC = make_double2(c, 0.0);
  blasCuda<axpyBzpcx,1,1,0,0>(coeffA, coeffB, coeffC, x, y, z, x);
}
309 
/**
   Functor performing, with real coefficients (only .x of a, b is used):
     y = y + a*x
     x = z + b*x
   Both statements read the original x, because x is overwritten last.

   flops() counts floating-point operations per real component (same
   convention as axpy = 2 in this file).  That is 2 + 2 = 4.
*/
template <typename Float2, typename FloatN>
struct axpyZpbx {
  const Float2 a;
  const Float2 b;
  axpyZpbx(const Float2 &a, const Float2 &b, const Float2 &c) : a(a), b(b) { ; }
  __device__ void operator()(FloatN &x, FloatN &y, const FloatN &z, const FloatN &w)
  { y += a.x*x; x = z + b.x*x; }
  // x, y, z read; x, y written.
  static int streams() { return 5; }
  // 1 mul + 1 add for y, 1 mul + 1 add for x.
  static int flops() { return 4; }
};
323 
325  cudaColorSpinorField& z, const double &b) {
326  // swap arguments around
327  blasCuda<axpyZpbx,1,1,0,0>(make_double2(a,0.0), make_double2(b,0.0), make_double2(0.0,0.0),
328  x, y, z, x);
329  }
330 
/**
   Functor performing, with complex a, b:
     z = z + a*x + b*y   (uses y before it is updated)
     y = y - b*w
*/
template <typename Float2, typename FloatN>
struct caxpbypzYmbw {
  const Float2 a;
  const Float2 b;

  caxpbypzYmbw(const Float2 &alpha, const Float2 &beta, const Float2 &) : a(alpha), b(beta) {}

  __device__ void operator()(const FloatN &x, FloatN &y, FloatN &z, const FloatN &w)
  {
    caxpy_(a, x, z);   // z += a*x
    caxpy_(b, y, z);   // z += b*y, with the not-yet-updated y
    caxpy_(-b, w, y);  // y -= b*w
  }

  static int streams() { return 6; }  // x, y, z, w read; y, z written
  static int flops() { return 12; }   // three complex madds at 4 each
};
345 
348  blasCuda<caxpbypzYmbw,0,1,1,0>(make_double2(a.real(),a.imag()), make_double2(b.real(), b.imag()),
349  make_double2(0.0,0.0), x, y, z, w);
350  }
351 
/**
   Functor performing x = a*x (real a.x), then y += b*x with complex b,
   using the already-scaled x.
*/
template <typename Float2, typename FloatN>
struct cabxpyAx {
  const Float2 a;
  const Float2 b;

  cabxpyAx(const Float2 &alpha, const Float2 &beta, const Float2 &) : a(alpha), b(beta) {}

  __device__ void operator()(FloatN &x, FloatN &y, const FloatN &, const FloatN &)
  {
    x *= a.x;          // real scale first
    caxpy_(b, x, y);   // then accumulate the scaled x into y
  }

  static int streams() { return 4; }  // x, y read; x, y written
  static int flops() { return 5; }    // 1 (scale) + 4 (complex madd)
};
365 
366  void cabxpyAxCuda(const double &a, const Complex &b,
368  // swap arguments around
369  blasCuda<cabxpyAx,1,1,0,0>(make_double2(a,0.0), make_double2(b.real(),b.imag()),
370  make_double2(0.0,0.0), x, y, x, x);
371  }
372 
/**
   Functor computing z = z + a*x + b*y with complex a, b.

   flops() counts floating-point operations per real component.  Each caxpy_
   (complex multiply-accumulate) costs 4, matching caxpy::flops() in this
   file, so two of them cost 8.
*/
template <typename Float2, typename FloatN>
struct caxpbypz {
  const Float2 a;
  const Float2 b;
  caxpbypz(const Float2 &a, const Float2 &b, const Float2 &c) : a(a), b(b) { ; }
  __device__ void operator()(const FloatN &x, const FloatN &y, FloatN &z, const FloatN &w)
  { caxpy_(a, x, z); caxpy_(b, y, z); }
  // x, y, z read; z written.
  static int streams() { return 4; }
  // Two complex madds at 4 flops per real component each.
  static int flops() { return 8; }
};
386 
387  void caxpbypzCuda(const Complex &a, cudaColorSpinorField &x, const Complex &b,
389  blasCuda<caxpbypz,0,0,1,0>(make_double2(a.real(),a.imag()), make_double2(b.real(),b.imag()),
390  make_double2(0.0,0.0), x, y, z, z);
391  }
392 
/**
   Functor computing w = w + a*x + b*y + c*z with complex a, b, c.

   streams() counts each read and each write of a field, as the other functors
   here do (e.g. caxpbypz: x, y, z read + z written = 4).  Here x, y, z, w are
   read and w is written, which gives 5.
   flops() counts 4 per complex madd per real component (see caxpy::flops()),
   and there are three of them, which gives 12.
*/
template <typename Float2, typename FloatN>
struct caxpbypczpw {
  const Float2 a;
  const Float2 b;
  const Float2 c;
  caxpbypczpw(const Float2 &a, const Float2 &b, const Float2 &c) : a(a), b(b), c(c) { ; }
  __device__ void operator()(const FloatN &x, const FloatN &y, const FloatN &z, FloatN &w)
  { caxpy_(a, x, w); caxpy_(b, y, w); caxpy_(c, z, w); }

  // x, y, z, w read; w written.
  static int streams() { return 5; }
  // Three complex madds at 4 flops per real component each.
  static int flops() { return 12; }
};
408 
412  blasCuda<caxpbypczpw,0,0,0,1>(make_double2(a.real(),a.imag()), make_double2(b.real(),b.imag()),
413  make_double2(c.real(), c.imag()), x, y, z, w);
414  }
415 
/**
   Functor performing, with complex a = (a.x, a.y):
     y = y + a*x   (uses x before it is modified)
     x = x - a*z

   The second update must use the full complex coefficient.  The wrapper passes
   both the real and imaginary parts of a, and subtracting only a.x*z would
   silently drop Im(a).  It is written as caxpy_(-a, z, x), the same
   negated-coefficient idiom caxpbypzYmbw uses.  With two complex madds,
   flops() = 8 is now consistent with caxpy::flops() = 4.
*/
template <typename Float2, typename FloatN>
struct caxpyxmaz {
  Float2 a;
  caxpyxmaz(const Float2 &a, const Float2 &b, const Float2 &c) : a(a) { ; }
  __device__ void operator()(FloatN &x, FloatN &y, const FloatN &z, const FloatN &w)
  {
    caxpy_(a, x, y);   // y += a*x
    caxpy_(-a, z, x);  // x -= a*z (complex a)
  }
  // x, y, z read; x, y written.
  static int streams() { return 5; }
  // Two complex madds at 4 flops per real component each.
  static int flops() { return 8; }
};
431 
434  blasCuda<caxpyxmaz,1,1,0,0>(make_double2(a.real(), a.imag()), make_double2(0.0, 0.0),
435  make_double2(0.0, 0.0), x, y, z, x);
436  }
437 
/**
   Functor for the fused CG update with real a, b:
     y = y - a*x
     z = z + a*w   (uses w before it is overwritten)
     w = y + b*w   (uses the freshly updated y)
*/
template <typename Float2, typename FloatN>
struct tripleCGUpdate {
  Float2 a;
  Float2 b;

  tripleCGUpdate(const Float2 &alpha, const Float2 &beta, const Float2 &) : a(alpha), b(beta) {}

  __device__ void operator()(const FloatN &x, FloatN &y, FloatN &z, FloatN &w)
  {
    y -= a.x*x;
    z += a.x*w;
    w = y + b.x*w;
  }

  static int streams() { return 7; }  // x, y, z, w read; y, z, w written
  static int flops() { return 6; }
};
454 
455  void tripleCGUpdateCuda(const double &a, const double &b, cudaColorSpinorField &x,
457  blasCuda<tripleCGUpdate,0,1,1,1>(make_double2(a, 0.0), make_double2(b, 0.0),
458  make_double2(0.0, 0.0), x, y, z, w);
459  }
460 
461 } // namespace quda