QUDA  v1.1.0
A library for QCD on GPUs
blas_quda.cu
Go to the documentation of this file.
1 #include <stdlib.h>
2 #include <stdio.h>
3 #include <cstring> // needed for memset
4 
5 #include <tune_quda.h>
6 #include <quda_internal.h>
7 #include <blas_quda.h>
8 #include <color_spinor_field.h>
9 
10 #include <jitify_helper.cuh>
11 #include <kernels/blas_core.cuh>
12 
13 namespace quda {
14 
namespace reducer
{
  // Forward declarations only; the definitions are not in this file.
  // blas::init() calls reducer::init() and blas::destroy() calls reducer::destroy().
  void init();
  void destroy();
} // namespace reducer
19 
20  namespace blas {
21 
22  unsigned long long flops;
23  unsigned long long bytes;
24 
25  static qudaStream_t *blasStream;
26 
27  template <template <typename real> class Functor, typename store_t, typename y_store_t,
28  int nSpin, typename coeff_t>
29  class Blas : public Tunable
30  {
31  using real = typename mapper<y_store_t>::type;
32  Functor<real> f;
33  const int nParity; // for composite fields this includes the number of composites
34 
35  const coeff_t &a, &b, &c;
36  ColorSpinorField &x, &y, &z, &w, &v;
37  const QudaFieldLocation location;
38 
39  unsigned int sharedBytesPerThread() const { return 0; }
40  unsigned int sharedBytesPerBlock(const TuneParam &param) const { return 0; }
41 
42  bool tuneSharedBytes() const { return false; }
43 
44  // for these streaming kernels, there is no need to tune the grid size, just use max
45  unsigned int minGridSize() const { return maxGridSize(); }
46 
47  public:
48  Blas(const coeff_t &a, const coeff_t &b, const coeff_t &c, ColorSpinorField &x,
49  ColorSpinorField &y, ColorSpinorField &z, ColorSpinorField &w, ColorSpinorField &v) :
50  f(a, b, c),
51  nParity((x.IsComposite() ? x.CompositeDim() : 1) * x.SiteSubset()),
52  a(a),
53  b(b),
54  c(c),
55  x(x),
56  y(y),
57  z(z),
58  w(w),
59  v(v),
60  location(checkLocation(x, y, z, w, v))
61  {
62  checkLength(x, y, z, w, v);
63  auto x_prec = checkPrecision(x, z, w);
64  auto y_prec = checkPrecision(y, v);
65  auto x_order = checkOrder(x, z, w);
66  auto y_order = checkOrder(y, v);
67  if (sizeof(store_t) != x_prec) errorQuda("Expected precision %lu but received %d", sizeof(store_t), x_prec);
68  if (sizeof(y_store_t) != y_prec) errorQuda("Expected precision %lu but received %d", sizeof(y_store_t), y_prec);
69  if (x_prec == y_prec && x_order != y_order) errorQuda("Orders %d %d do not match", x_order, y_order);
70 
71  strcpy(aux, x.AuxString());
72  if (x_prec != y_prec) {
73  strcat(aux, ",");
74  strcat(aux, y.AuxString());
75  }
76  if (location == QUDA_CPU_FIELD_LOCATION) strcat(aux, ",CPU");
77 
78 #ifdef JITIFY
79  ::quda::create_jitify_program("kernels/blas_core.cuh");
80 #endif
81 
82  apply(*blasStream);
83 
84  blas::bytes += bytes();
85  blas::flops += flops();
86  }
87 
88  TuneKey tuneKey() const { return TuneKey(x.VolString(), typeid(f).name(), aux); }
89 
90  void apply(const qudaStream_t &stream)
91  {
92  constexpr bool site_unroll_check = !std::is_same<store_t, y_store_t>::value || isFixed<store_t>::value;
93  if (site_unroll_check && (x.Ncolor() != 3 || x.Nspin() == 2))
94  errorQuda("site unroll not supported for nSpin = %d nColor = %d", x.Nspin(), x.Ncolor());
95 
96  TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
97  if (location == QUDA_CUDA_FIELD_LOCATION) {
98  if (site_unroll_check) checkNative(x, y, z, w, v); // require native order when using site_unroll
99  using device_store_t = typename device_type_mapper<store_t>::type;
100  using device_y_store_t = typename device_type_mapper<y_store_t>::type;
101  using device_real_t = typename mapper<device_y_store_t>::type;
102  Functor<device_real_t> f_(a, b, c);
103 
104  // redefine site_unroll with device_store types to ensure we have correct N/Ny/M values
105  constexpr bool site_unroll = !std::is_same<device_store_t, device_y_store_t>::value || isFixed<device_store_t>::value;
106  constexpr int N = n_vector<device_store_t, true, nSpin, site_unroll>();
107  constexpr int Ny = n_vector<device_y_store_t, true, nSpin, site_unroll>();
108  constexpr int M = site_unroll ? (nSpin == 4 ? 24 : 6) : N; // real numbers per thread
109  const int length = x.Length() / (nParity * M);
110 
111  BlasArg<device_store_t, N, device_y_store_t, Ny, decltype(f_)> arg(x, y, z, w, v, f_, length, nParity);
112 #ifdef JITIFY
113  using namespace jitify::reflection;
114  jitify_error = program->kernel("quda::blas::blasKernel")
115  .instantiate(Type<device_real_t>(), M, Type<decltype(arg)>())
116  .configure(tp.grid, tp.block, tp.shared_bytes, stream)
117  .launch(arg);
118 #else
119  qudaLaunchKernel(blasKernel<device_real_t, M, decltype(arg)>, tp, stream, arg);
120 #endif
121  } else {
122  if (checkOrder(x, y, z, w, v) != QUDA_SPACE_SPIN_COLOR_FIELD_ORDER)
123  errorQuda("CPU Blas functions expect AoS field order");
124 
125  using host_store_t = typename host_type_mapper<store_t>::type;
126  using host_y_store_t = typename host_type_mapper<y_store_t>::type;
127  using host_real_t = typename mapper<host_y_store_t>::type;
128  Functor<host_real_t> f_(a, b, c);
129 
130  // redefine site_unroll with host_store types to ensure we have correct N/Ny/M values
131  constexpr bool site_unroll = !std::is_same<host_store_t, host_y_store_t>::value || isFixed<host_store_t>::value;
132  constexpr int N = n_vector<host_store_t, false, nSpin, site_unroll>();
133  constexpr int Ny = n_vector<host_y_store_t, false, nSpin, site_unroll>();
134  constexpr int M = N; // if site unrolling then M=N will be 24/6, e.g., full AoS
135  const int length = x.Length() / (nParity * M);
136 
137  BlasArg<host_store_t, N, host_y_store_t, Ny, decltype(f_)> arg(x, y, z, w, v, f_, length, nParity);
138  blasCPU<host_real_t, M>(arg);
139  }
140  }
141 
142  void preTune()
143  {
144  if (f.write.X) x.backup();
145  if (f.write.Y) y.backup();
146  if (f.write.Z) z.backup();
147  if (f.write.W) w.backup();
148  if (f.write.V) v.backup();
149  }
150 
151  void postTune()
152  {
153  if (f.write.X) x.restore();
154  if (f.write.Y) y.restore();
155  if (f.write.Z) z.restore();
156  if (f.write.W) w.restore();
157  if (f.write.V) v.restore();
158  }
159 
160  bool advanceTuneParam(TuneParam &param) const
161  {
162  return location == QUDA_CPU_FIELD_LOCATION ? false : Tunable::advanceTuneParam(param);
163  }
164 
165  void initTuneParam(TuneParam &param) const
166  {
167  Tunable::initTuneParam(param);
168  param.grid.y = nParity;
169  }
170 
171  void defaultTuneParam(TuneParam &param) const
172  {
173  Tunable::initTuneParam(param);
174  param.grid.y = nParity;
175  }
176 
177  long long flops() const { return f.flops() * x.Length(); }
178  long long bytes() const
179  {
180  return (f.read.X + f.write.X) * x.Bytes() + (f.read.Y + f.write.Y) * y.Bytes() +
181  (f.read.Z + f.write.Z) * z.Bytes() + (f.read.W + f.write.W) * w.Bytes() + (f.read.V + f.write.V) * v.Bytes();
182  }
183  int tuningIter() const { return 3; }
184  };
185 
186  void zero(ColorSpinorField &a) {
187  if (typeid(a) == typeid(cudaColorSpinorField)) {
188  static_cast<cudaColorSpinorField&>(a).zero();
189  } else {
190  static_cast<cpuColorSpinorField&>(a).zero();
191  }
192  }
193 
194  void init()
195  {
196  blasStream = &streams[Nstream-1];
197  reducer::init();
198  }
199 
200  void destroy(void)
201  {
202  reducer::destroy();
203  }
204 
205  qudaStream_t* getStream() { return blasStream; }
206 
207  void axpbyz(double a, ColorSpinorField &x, double b, ColorSpinorField &y, ColorSpinorField &z)
208  {
209  instantiate<axpbyz_, Blas, true>(a, b, 0.0, x, y, x, x, z);
210  }
211 
212  void ax(double a, ColorSpinorField &x)
213  {
214  instantiate<ax_, Blas, false>(a, 0.0, 0.0, x, x, x, x, x);
215  }
216 
217  void caxpy(const Complex &a, ColorSpinorField &x, ColorSpinorField &y)
218  {
219  instantiate<caxpy_, Blas, true>(a, Complex(0.0), Complex(0.0), x, y, x, x, y);
220  }
221 
222  void caxpby(const Complex &a, ColorSpinorField &x, const Complex &b, ColorSpinorField &y)
223  {
224  instantiate<caxpby_, Blas, false>(a, b, Complex(0.0), x, y, x, x, y);
225  }
226 
227  void caxpbypczw(const Complex &a, ColorSpinorField &x, const Complex &b, ColorSpinorField &y, const Complex &c,
228  ColorSpinorField &z, ColorSpinorField &w)
229  {
230  instantiate<caxpbypczw_, Blas, false>(a, b, c, x, y, z, w, y);
231  }
232 
233  void cxpaypbz(ColorSpinorField &x, const Complex &a, ColorSpinorField &y, const Complex &b, ColorSpinorField &z)
234  {
235  instantiate<caxpbypczw_, Blas, false>(Complex(1.0), a, b, x, y, z, z, y);
236  }
237 
238  void axpyBzpcx(double a, ColorSpinorField& x, ColorSpinorField& y, double b, ColorSpinorField& z, double c)
239  {
240  instantiate<axpyBzpcx_, Blas, true>(a, b, c, x, y, z, x, y);
241  }
242 
243  void axpyZpbx(double a, ColorSpinorField& x, ColorSpinorField& y, ColorSpinorField& z, double b)
244  {
245  instantiate<axpyZpbx_, Blas, true>(a, b, 0.0, x, y, z, x, y);
246  }
247 
248  void caxpyBzpx(const Complex &a, ColorSpinorField &x, ColorSpinorField &y, const Complex &b, ColorSpinorField &z)
249  {
250  instantiate<caxpyBzpx_, Blas, true>(a, b, Complex(0.0), x, y, z, x, y);
251  }
252 
253  void caxpyBxpz(const Complex &a, ColorSpinorField &x, ColorSpinorField &y, const Complex &b, ColorSpinorField &z)
254  {
255  instantiate<caxpyBxpz_, Blas, true>(a, b, Complex(0.0), x, y, z, x, y);
256  }
257 
258  void caxpbypzYmbw(const Complex &a, ColorSpinorField &x, const Complex &b, ColorSpinorField &y, ColorSpinorField &z, ColorSpinorField &w)
259  {
260  instantiate<caxpbypzYmbw_, Blas, false>(a, b, Complex(0.0), x, y, z, w, y);
261  }
262 
263  void cabxpyAx(double a, const Complex &b, ColorSpinorField &x, ColorSpinorField &y)
264  {
265  instantiate<cabxpyAx_, Blas, false>(Complex(a), b, Complex(0.0), x, y, x, x, y);
266  }
267 
268  void caxpyXmaz(const Complex &a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
269  {
270  instantiate<caxpyxmaz_, Blas, false>(a, Complex(0.0), Complex(0.0), x, y, z, x, y);
271  }
272 
273  void caxpyXmazMR(const double &a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
274  {
275  if (!commAsyncReduction())
276  errorQuda("This kernel requires asynchronous reductions to be set");
277  if (x.Location() == QUDA_CPU_FIELD_LOCATION)
278  errorQuda("This kernel cannot be run on CPU fields");
279  instantiate<caxpyxmazMR_, Blas, false>(a, 0.0, 0.0, x, y, z, y, y);
280  }
281 
282  void tripleCGUpdate(double a, double b, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, ColorSpinorField &w)
283  {
284  instantiate<tripleCGUpdate_, Blas, true>(a, b, 0.0, x, y, z, w, y);
285  }
286 
287  } // namespace blas
288 
289 } // namespace quda