QUDA  0.9.0
multi_blas_core.cuh
Go to the documentation of this file.
1 
/**
   Parameter struct for the generic multi-blas kernel.

   NXZ is a compile-time count of X/Z vectors; NYW is a run-time count
   of Y/W vectors bounded by MAX_MULTI_BLAS_N (the Y/W arrays are sized
   to the maximum so a single kernel instantiation covers all NYW).
*/
template <int NXZ, typename SpinorX, typename SpinorY, typename SpinorZ,
          typename SpinorW, typename Functor>
struct MultiBlasArg {
  const int NYW;               // run-time number of Y/W vectors (<= MAX_MULTI_BLAS_N)
  SpinorX X[NXZ];              // input vectors (compile-time count)
  SpinorY Y[MAX_MULTI_BLAS_N]; // output/accumulate vectors
  SpinorZ Z[NXZ];              // input vectors (compile-time count)
  SpinorW W[MAX_MULTI_BLAS_N]; // output/accumulate vectors
  Functor f;                   // the element-wise functor applied by the kernel
  const int length;            // checkerboard length of each vector

  MultiBlasArg(SpinorX X[NXZ], SpinorY Y[], SpinorZ Z[NXZ], SpinorW W[], Functor f, int NYW, int length)
    : NYW(NYW), f(f), length(length)
  {
    // copy the spinor accessors by value into the arg struct so the
    // kernel receives them by parameter rather than by pointer
    for (int i = 0; i < NXZ; ++i) {
      this->X[i] = X[i];
      this->Z[i] = Z[i];
    }
    for (int i = 0; i < NYW; ++i) {
      this->Y[i] = Y[i];
      this->W[i] = W[i];
    }
  }
};
35 
36 
// storage for matrix coefficients: raw constant-memory buffers into which
// the (padded) coefficient matrices are copied before a kernel launch.
// Declared as signed char so one buffer serves any coefficient type.
#define MAX_MATRIX_SIZE 4096
static __constant__ signed char Amatrix_d[MAX_MATRIX_SIZE];
static __constant__ signed char Bmatrix_d[MAX_MATRIX_SIZE];
static __constant__ signed char Cmatrix_d[MAX_MATRIX_SIZE];

// host-side shadows of the coefficient pointers, used so CPU reference
// code / tuning paths can see the same coefficients that were uploaded
static signed char *Amatrix_h;
static signed char *Bmatrix_h;
static signed char *Cmatrix_h;
46 
/**
   Device-side worker for one Y/W pair (compile-time index k).

   Grid-stride loop over the vector elements: loads y[k]/w[k] once per
   site, then streams all NXZ x/z vectors through the functor, and
   stores y[k]/w[k] back.  idx is the starting global element index.
*/
template <int k, int NXZ, typename FloatN, int M, typename Arg>
__device__ inline void compute(Arg &arg, int idx, int parity)
{
  while (idx < arg.length) {

    FloatN x[M], y[M], z[M], w[M];
    arg.Y[k].load(y, idx, parity);
    arg.W[k].load(w, idx, parity);

#pragma unroll
    for (int l = 0; l < NXZ; l++) {
      arg.X[l].load(x, idx, parity);
      arg.Z[l].load(z, idx, parity);

#pragma unroll
      for (int j = 0; j < M; j++) arg.f(x[j], y[j], z[j], w[j], k, l);
    }
    arg.Y[k].save(y, idx, parity);
    arg.W[k].save(w, idx, parity);

    // grid-stride advance (x-dimension of the launch covers the vector length)
    idx += gridDim.x*blockDim.x;
  }
}
70 
/**
   Generic multi-blas kernel with four loads and up to four stores.

   Launch layout: x-dimension strides the vector elements, y-dimension
   indexes the NYW output vectors, z-dimension indexes parity.  The
   switch converts the run-time y-index k into a compile-time template
   parameter so compute<k,...> can index the spinor arrays statically.

   NOTE(review): the __global__ signature line was lost in extraction and
   is reconstructed from the documentation index — confirm against the
   original header.
*/
template <typename FloatN, int M, int NXZ, typename SpinorX, typename SpinorY,
          typename SpinorZ, typename SpinorW, typename Functor>
__global__ void multiblasKernel(MultiBlasArg<NXZ,SpinorX,SpinorY,SpinorZ,SpinorW,Functor> arg)
{
  // use i to loop over elements in kernel
  unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
  unsigned int k = blockIdx.y * blockDim.y + threadIdx.y;
  unsigned int parity = blockIdx.z;

  arg.f.init();
  if (k >= arg.NYW) return;

  switch (k) {
  case  0: compute< 0,NXZ,FloatN,M>(arg,i,parity); break;
#if MAX_MULTI_BLAS_N >= 2
  case  1: compute< 1,NXZ,FloatN,M>(arg,i,parity); break;
#if MAX_MULTI_BLAS_N >= 3
  case  2: compute< 2,NXZ,FloatN,M>(arg,i,parity); break;
#if MAX_MULTI_BLAS_N >= 4
  case  3: compute< 3,NXZ,FloatN,M>(arg,i,parity); break;
#if MAX_MULTI_BLAS_N >= 5
  case  4: compute< 4,NXZ,FloatN,M>(arg,i,parity); break;
#if MAX_MULTI_BLAS_N >= 6
  case  5: compute< 5,NXZ,FloatN,M>(arg,i,parity); break;
#if MAX_MULTI_BLAS_N >= 7
  case  6: compute< 6,NXZ,FloatN,M>(arg,i,parity); break;
#if MAX_MULTI_BLAS_N >= 8
  case  7: compute< 7,NXZ,FloatN,M>(arg,i,parity); break;
#if MAX_MULTI_BLAS_N >= 9
  case  8: compute< 8,NXZ,FloatN,M>(arg,i,parity); break;
#if MAX_MULTI_BLAS_N >= 10
  case  9: compute< 9,NXZ,FloatN,M>(arg,i,parity); break;
#if MAX_MULTI_BLAS_N >= 11
  case 10: compute<10,NXZ,FloatN,M>(arg,i,parity); break;
#if MAX_MULTI_BLAS_N >= 12
  case 11: compute<11,NXZ,FloatN,M>(arg,i,parity); break;
#if MAX_MULTI_BLAS_N >= 13
  case 12: compute<12,NXZ,FloatN,M>(arg,i,parity); break;
#if MAX_MULTI_BLAS_N >= 14
  case 13: compute<13,NXZ,FloatN,M>(arg,i,parity); break;
#if MAX_MULTI_BLAS_N >= 15
  case 14: compute<14,NXZ,FloatN,M>(arg,i,parity); break;
#if MAX_MULTI_BLAS_N >= 16
  case 15: compute<15,NXZ,FloatN,M>(arg,i,parity); break;
#endif //16
#endif //15
#endif //14
#endif //13
#endif //12
#endif //11
#endif //10
#endif //9
#endif //8
#endif //7
#endif //6
#endif //5
#endif //4
#endif //3
#endif //2
  }
}
138 
namespace detail
{
  // compile-time unsigned -> NUL-terminated character array
  template <unsigned... digits>
  struct to_chars { static const char value[]; };

  // pack expansion turns each digit into its ASCII character, then
  // appends the terminating NUL
  template <unsigned... digits>
  const char to_chars<digits...>::value[] = {('0' + digits)..., 0};

  // peel digits off the number, least-significant first, accumulating
  // them most-significant first in the pack
  template <unsigned rem, unsigned... digits>
  struct explode : explode<rem / 10, rem % 10, digits...> {};

  template <unsigned... digits>
  struct explode<0, digits...> : to_chars<digits...> {};
}

/** Compile-time decimal string for an unsigned constant, e.g.
    num_to_string<42>::value == "42".  The seeding with num/10, num%10
    also makes num_to_string<0>::value yield "0". */
template <unsigned num>
struct num_to_string : detail::explode<num / 10, num % 10> {};
156 
157 
/**
   Tunable launcher for the generic multi-blas kernel.

   NOTE(review): several member declarations (arg, Y_h/W_h/Ynorm_h/Wnorm_h)
   and the strcpy in tuneKey() were lost in extraction and are
   reconstructed from the documentation index and the constructor
   initializer list — confirm against the original header.
*/
template <int NXZ, typename FloatN, int M, typename SpinorX, typename SpinorY,
          typename SpinorZ, typename SpinorW, typename Functor>
class MultiBlasCuda : public TunableVectorY {

private:
  const int NYW; // number of Y/W output vectors
  // kernel argument struct; declared before nParity to match the
  // constructor's initializer-list order
  MultiBlasArg<NXZ,SpinorX,SpinorY,SpinorZ,SpinorW,Functor> arg;
  const int nParity;

  // host pointers used for backing up fields when tuning
  // don't curry into the Spinors to minimize parameter size
  char *Y_h[MAX_MULTI_BLAS_N];
  char *W_h[MAX_MULTI_BLAS_N];
  char *Ynorm_h[MAX_MULTI_BLAS_N];
  char *Wnorm_h[MAX_MULTI_BLAS_N];
  std::vector<ColorSpinorField*> &y, &w;

  // shared memory is not a tuning dimension for this kernel
  bool tuneSharedBytes() const { return false; }

public:
  MultiBlasCuda(SpinorX X[], SpinorY Y[], SpinorZ Z[], SpinorW W[], Functor &f,
                int NYW, int length, int nParity,
                std::vector<ColorSpinorField*> &y, std::vector<ColorSpinorField*> &w)
    : TunableVectorY(NYW), NYW(NYW), arg(X, Y, Z, W, f, NYW, length/nParity),
      nParity(nParity), Y_h(), W_h(), Ynorm_h(), Wnorm_h(), y(y), w(w) { }

  virtual ~MultiBlasCuda() { }

  inline TuneKey tuneKey() const {
    char name[TuneKey::name_n];
    // initialize the buffer before the strcat calls below; NXZ and NYW
    // are part of the tuning key since they change the kernel shape
    strcpy(name, num_to_string<NXZ>::value);
    strcat(name, std::to_string(NYW).c_str());
    strcat(name, typeid(arg.f).name());
    return TuneKey(blasStrings.vol_str, name, blasStrings.aux_tmp);
  }

  inline void apply(const cudaStream_t &stream) {
    TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
    multiblasKernel<FloatN,M,NXZ> <<<tp.grid, tp.block, tp.shared_bytes, stream>>>(arg);
  }

  // back up the output vectors so failed tuning launches can be undone
  void preTune() {
    for (int i = 0; i < NYW; ++i) {
      arg.Y[i].backup(&Y_h[i], &Ynorm_h[i], y[i]->Bytes(), y[i]->NormBytes());
      arg.W[i].backup(&W_h[i], &Wnorm_h[i], w[i]->Bytes(), w[i]->NormBytes());
    }
  }

  void postTune() {
    for (int i = 0; i < NYW; ++i) {
      arg.Y[i].restore(&Y_h[i], &Ynorm_h[i], y[i]->Bytes(), y[i]->NormBytes());
      arg.W[i].restore(&W_h[i], &Wnorm_h[i], w[i]->Bytes(), w[i]->NormBytes());
    }
  }

  void initTuneParam(TuneParam &param) const {
    TunableVectorY::initTuneParam(param);
    param.grid.z = nParity; // z-dimension of the grid indexes parity
  }

  void defaultTuneParam(TuneParam &param) const {
    TunableVectorY::defaultTuneParam(param);
    param.grid.z = nParity;
  }

  long long flops() const { return arg.f.flops()*vec_length<FloatN>::value*(long)arg.length*nParity*M; }

  long long bytes() const
  {
    // bytes for low-precision vector
    size_t base_bytes = arg.X[0].Precision()*vec_length<FloatN>::value*M;
    if (arg.X[0].Precision() == QUDA_HALF_PRECISION) base_bytes += sizeof(float);

    // bytes for high precision vector
    size_t extra_bytes = arg.Y[0].Precision()*vec_length<FloatN>::value*M;
    if (arg.Y[0].Precision() == QUDA_HALF_PRECISION) extra_bytes += sizeof(float);

    // the factor two here assumes we are reading and writing to the high precision vector
    return ((arg.f.streams()-2)*base_bytes + 2*extra_bytes)*arg.length*nParity;
  }

  int tuningIter() const { return 3; }
};
238 
/**
   Lightweight holder for a host-side coefficient matrix.

   NOTE(review): the two-argument constructor was lost in extraction and
   is reconstructed from the documentation index — confirm against the
   original header.
*/
template <typename T>
struct coeff_array {
  const T *data;        // host pointer to the coefficients (may be null)
  const bool use_const; // whether the coefficients are staged through constant memory
  coeff_array() : data(nullptr), use_const(false) { }
  coeff_array(const T *data, bool use_const) : data(data), use_const(use_const) { }
};
246 
/**
   Host driver: stages the coefficient matrices into constant memory,
   binds the spinor accessors, and launches the tuned multi-blas kernel.

   NOTE(review): the function signature line was lost in extraction and
   is reconstructed from the documentation index — confirm against the
   original header.
*/
template <int NXZ, typename RegType, typename StoreType, typename yType, int M,
          template <int,typename,typename> class Functor,
          typename write, typename T>
void multiblasCuda(const coeff_array<T> &a, const coeff_array<T> &b, const coeff_array<T> &c,
                   std::vector<ColorSpinorField*> &x, std::vector<ColorSpinorField*> &y,
                   std::vector<ColorSpinorField*> &z, std::vector<ColorSpinorField*> &w,
                   int length)
{
  const int NYW = y.size();

  const int N = NXZ > NYW ? NXZ : NYW;
  if (N > MAX_MULTI_BLAS_N) errorQuda("Spinor vector length exceeds max size (%d > %d)", N, MAX_MULTI_BLAS_N);

  if (NXZ*NYW*sizeof(Complex) > MAX_MATRIX_SIZE)
    errorQuda("A matrix exceeds max size (%lu > %d)", NXZ*NYW*sizeof(Complex), MAX_MATRIX_SIZE);

  typedef typename scalar<RegType>::type Float;
  typedef typename vector<Float,2>::type Float2;
  typedef vector<Float,2> vec2;

  // FIXME - if NXZ=1 no need to copy entire array
  // FIXME - do we really need strided access here?
  if (a.data && a.use_const) {
    Float2 A[MAX_MATRIX_SIZE/sizeof(Float2)];
    // since the kernel doesn't know the width of the matrix at compile
    // time we stride it (row pitch MAX_MULTI_BLAS_N) and copy the
    // padded matrix to the GPU
    for (int i=0; i<NXZ; i++) for (int j=0; j<NYW; j++)
      A[MAX_MULTI_BLAS_N * i + j] = make_Float2<Float2>(Complex(a.data[NYW * i + j]));

    cudaMemcpyToSymbolAsync(Amatrix_d, A, MAX_MATRIX_SIZE, 0, cudaMemcpyHostToDevice, *getStream());
    Amatrix_h = reinterpret_cast<signed char*>(const_cast<T*>(a.data));
  }

  if (b.data && b.use_const) {
    Float2 B[MAX_MATRIX_SIZE/sizeof(Float2)];
    // same staging as the A matrix above
    for (int i=0; i<NXZ; i++) for (int j=0; j<NYW; j++)
      B[MAX_MULTI_BLAS_N * i + j] = make_Float2<Float2>(Complex(b.data[NYW * i + j]));

    cudaMemcpyToSymbolAsync(Bmatrix_d, B, MAX_MATRIX_SIZE, 0, cudaMemcpyHostToDevice, *getStream());
    Bmatrix_h = reinterpret_cast<signed char*>(const_cast<T*>(b.data));
  }

  if (c.data && c.use_const) {
    Float2 C[MAX_MATRIX_SIZE/sizeof(Float2)];
    // same staging as the A matrix above
    for (int i=0; i<NXZ; i++) for (int j=0; j<NYW; j++)
      C[MAX_MULTI_BLAS_N * i + j] = make_Float2<Float2>(Complex(c.data[NYW * i + j]));

    cudaMemcpyToSymbolAsync(Cmatrix_d, C, MAX_MATRIX_SIZE, 0, cudaMemcpyHostToDevice, *getStream());
    Cmatrix_h = reinterpret_cast<signed char*>(const_cast<T*>(c.data));
  }

  // for (int i=0; i<N; i++) {
  //   checkLength(*x[i],*y[i]); checkLength(*x[i],*z[i]); checkLength(*x[i],*w[i]);
  // }

  // build the tuning key strings: volume from x, aux from x (and y when
  // the store and y precisions differ)
  blasStrings.vol_str = x[0]->VolString();
  strcpy(blasStrings.aux_tmp, x[0]->AuxString());
  if (typeid(StoreType) != typeid(yType)) {
    strcat(blasStrings.aux_tmp, ",");
    strcat(blasStrings.aux_tmp, y[0]->AuxString());
  }

  multi::SpinorTexture<RegType,StoreType,M,0> X[NXZ];
  multi::Spinor<RegType, yType,M,write::Y,1> Y[MAX_MULTI_BLAS_N];
  multi::SpinorTexture<RegType,StoreType,M,2> Z[NXZ];
  multi::Spinor<RegType,StoreType,M,write::W,3> W[MAX_MULTI_BLAS_N];

  //MWFIXME
  for (int i=0; i<NXZ; i++) { X[i].set(*dynamic_cast<cudaColorSpinorField *>(x[i])); Z[i].set(*dynamic_cast<cudaColorSpinorField *>(z[i]));}
  for (int i=0; i<NYW; i++) { Y[i].set(*dynamic_cast<cudaColorSpinorField *>(y[i])); W[i].set(*dynamic_cast<cudaColorSpinorField *>(w[i]));}

  // if block caxpy is an 'outer product of caxpy' where 'x'

  Functor<NXZ,Float2, RegType> f(a, b, c, NYW);

  MultiBlasCuda<NXZ,RegType,M,
                multi::SpinorTexture<RegType,StoreType,M,0>,
                multi::Spinor<RegType, yType,M,write::Y,1>,
                multi::SpinorTexture<RegType,StoreType,M,2>,
                multi::Spinor<RegType,StoreType,M,write::W,3>,
                decltype(f) >
    blas(X, Y, Z, W, f, NYW, length, x[0]->SiteSubset(), y, w);
  blas.apply(*getStream());

  blas::bytes += blas.bytes();
  blas::flops += blas.flops();

  checkCudaError();
}
340 
341 
/**
   CPU reference implementation of the multi-blas operation: applies the
   functor element-by-element over parity, site, spin and color, writing
   back only the Y and W fields (X/Z writes are not supported here).
*/
template <typename Float2, typename write,
          typename SpinorX, typename SpinorY, typename SpinorZ, typename SpinorW,
          typename Functor>
void genericMultiBlas(SpinorX &X, SpinorY &Y, SpinorZ &Z, SpinorW &W, Functor f)
{
  for (int parity=0; parity<X.Nparity(); parity++) {
    for (int x=0; x<X.VolumeCB(); x++) {
      for (int s=0; s<X.Nspin(); s++) {
        for (int c=0; c<X.Ncolor(); c++) {
          Float2 X2 = make_Float2<Float2>( X(parity, x, s, c) );
          Float2 Y2 = make_Float2<Float2>( Y(parity, x, s, c) );
          Float2 Z2 = make_Float2<Float2>( Z(parity, x, s, c) );
          Float2 W2 = make_Float2<Float2>( W(parity, x, s, c) );
          // NOTE(review): the functor is invoked with fixed indices
          // (1, 1) here rather than real i/j indices — presumably this
          // reference path only supports index-independent functors;
          // confirm against callers
          f(X2, Y2, Z2, W2, 1 , 1);
          // if (writeX) X(parity, x, s, c) = make_Complex(X2);
          if (write::X) errorQuda("writeX not supported in multiblas.");
          if (write::Y) Y(parity, x, s, c) = make_Complex(Y2);
          if (write::Z) errorQuda("writeZ not supported in multiblas.");
          if (write::W) W(parity, x, s, c) = make_Complex(W2);
        }
      }
    }
  }
}
372 
// Instantiate the field-order accessors for a known nSpin/nColor and
// forward to the element-wise reference implementation above.
template <typename Float, typename yFloat, int nSpin, int nColor, QudaFieldOrder order,
          typename write, typename Functor>
void genericMultiBlas(ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z,
                      ColorSpinorField &w, Functor f)
{
  colorspinor::FieldOrderCB<Float,nSpin,nColor,1,order> X(x), Z(z), W(w);
  colorspinor::FieldOrderCB<yFloat,nSpin,nColor,1,order> Y(y);
  typedef typename vector<yFloat,2>::type Float2;
  genericMultiBlas<Float2,write>(X, Y, Z, W, f);
}
382 
// Dispatch on the run-time color count to the compile-time nColor
// instantiation.  Fix: corrected "implemeneted" typo in the error message.
template <typename Float, typename yFloat, int nSpin, QudaFieldOrder order,
          typename write, typename Functor>
void genericMultiBlas(ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, ColorSpinorField &w, Functor f)
{
  if (x.Ncolor() == 2) {
    genericMultiBlas<Float,yFloat,nSpin,2,order,write,Functor>(x, y, z, w, f);
  } else if (x.Ncolor() == 3) {
    genericMultiBlas<Float,yFloat,nSpin,3,order,write,Functor>(x, y, z, w, f);
  } else if (x.Ncolor() == 4) {
    genericMultiBlas<Float,yFloat,nSpin,4,order,write,Functor>(x, y, z, w, f);
  } else if (x.Ncolor() == 8) {
    genericMultiBlas<Float,yFloat,nSpin,8,order,write,Functor>(x, y, z, w, f);
  } else if (x.Ncolor() == 12) {
    genericMultiBlas<Float,yFloat,nSpin,12,order,write,Functor>(x, y, z, w, f);
  } else if (x.Ncolor() == 16) {
    genericMultiBlas<Float,yFloat,nSpin,16,order,write,Functor>(x, y, z, w, f);
  } else if (x.Ncolor() == 20) {
    genericMultiBlas<Float,yFloat,nSpin,20,order,write,Functor>(x, y, z, w, f);
  } else if (x.Ncolor() == 24) {
    genericMultiBlas<Float,yFloat,nSpin,24,order,write,Functor>(x, y, z, w, f);
  } else if (x.Ncolor() == 32) {
    genericMultiBlas<Float,yFloat,nSpin,32,order,write,Functor>(x, y, z, w, f);
  } else {
    errorQuda("nColor = %d not implemented",x.Ncolor());
  }
}
408 
// Dispatch on the run-time spin count to the compile-time nSpin
// instantiation (nSpin=1 only when staggered fermions are compiled in).
// Fix: corrected "implemeneted" typo in the error message.
template <typename Float, typename yFloat, QudaFieldOrder order, typename write, typename Functor>
void genericMultiBlas(ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, ColorSpinorField &w, Functor f)
{
  if (x.Nspin() == 4) {
    genericMultiBlas<Float,yFloat,4,order,write,Functor>(x, y, z, w, f);
  } else if (x.Nspin() == 2) {
    genericMultiBlas<Float,yFloat,2,order,write,Functor>(x, y, z, w, f);
#ifdef GPU_STAGGERED_DIRAC
  } else if (x.Nspin() == 1) {
    genericMultiBlas<Float,yFloat,1,order,write,Functor>(x, y, z, w, f);
#endif
  } else {
    errorQuda("nSpin = %d not implemented",x.Nspin());
  }
}
423 
// Dispatch on the run-time field order; only the space-spin-color order
// is supported by the reference path.
// Fix: corrected "implemeneted" typo in the error message.
template <typename Float, typename yFloat, typename write, typename Functor>
void genericMultiBlas(ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, ColorSpinorField &w, Functor f)
{
  if (x.FieldOrder() == QUDA_SPACE_SPIN_COLOR_FIELD_ORDER) {
    genericMultiBlas<Float,yFloat,QUDA_SPACE_SPIN_COLOR_FIELD_ORDER,write,Functor>
      (x, y, z, w, f);
  } else {
    errorQuda("Not implemented");
  }
}
dim3 dim3 blockDim
cudaStream_t stream
static __constant__ signed char Bmatrix_d[MAX_MATRIX_SIZE]
MultiBlasArg< NXZ, SpinorX, SpinorY, SpinorZ, SpinorW, Functor > arg
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:20
char * Wnorm_h[MAX_MULTI_BLAS_N]
#define errorQuda(...)
Definition: util_quda.h:90
static __constant__ signed char Amatrix_d[MAX_MATRIX_SIZE]
enum QudaFieldOrder_s QudaFieldOrder
std::complex< double > Complex
Definition: eig_variables.h:13
long long flops() const
char * strcpy(char *__dst, const char *__src)
char * strcat(char *__s1, const char *__s2)
SpinorY Y[MAX_MULTI_BLAS_N]
coeff_array(const T *data, bool use_const)
virtual ~MultiBlasCuda()
QudaGaugeParam param
Definition: pack_test.cpp:17
#define b
complex< double > make_Complex(const double2 &a)
Definition: float_vector.h:278
cudaStream_t * getStream()
Definition: blas_quda.cu:75
__device__ void compute(Arg &arg, int idx, int parity)
static __constant__ signed char Cmatrix_d[MAX_MATRIX_SIZE]
void apply(const cudaStream_t &stream)
void initTuneParam(TuneParam &param) const
MultiBlasCuda(SpinorX X[], SpinorY Y[], SpinorZ Z[], SpinorW W[], Functor &f, int NYW, int length, int nParity, std::vector< ColorSpinorField *> &y, std::vector< ColorSpinorField *> &w)
const int nColor
Definition: covdev_test.cpp:77
static struct quda::blas::@4 blasStrings
int int int w
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
Definition: tune.cpp:603
void multiblasCuda(const coeff_array< T > &a, const coeff_array< T > &b, const coeff_array< T > &c, std::vector< ColorSpinorField *> &x, std::vector< ColorSpinorField *> &y, std::vector< ColorSpinorField *> &z, std::vector< ColorSpinorField *> &w, int length)
int int int enum cudaChannelFormatKind f
int Z[4]
Definition: test_util.cpp:27
__global__ void multiblasKernel(MultiBlasArg< NXZ, SpinorX, SpinorY, SpinorZ, SpinorW, Functor > arg)
Generic multi-blas kernel with four loads and up to four stores.
static signed char * Cmatrix_h
std::vector< ColorSpinorField * > & w
long long bytes() const
SpinorZ Z[NXZ]
static signed char * Amatrix_h
#define MAX_MULTI_BLAS_N
Definition: quda_internal.h:49
void genericMultiBlas(SpinorX &X, SpinorY &Y, SpinorZ &Z, SpinorW &W, Functor f)
unsigned long long flops
Definition: blas_quda.cu:42
std::vector< ColorSpinorField * > & y
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
Definition: complex_quda.h:880
void size_t length
Parameter struct for generic multi-blas kernel.
SpinorW W[MAX_MULTI_BLAS_N]
char * Y_h[MAX_MULTI_BLAS_N]
static const char value[]
const void * c
char * W_h[MAX_MULTI_BLAS_N]
#define MAX_MATRIX_SIZE
#define checkCudaError()
Definition: util_quda.h:129
int tuningIter() const
bool tuneSharedBytes() const
TuneKey tuneKey() const
void defaultTuneParam(TuneParam &param) const
QudaTune getTuning()
Query whether autotuning is enabled or not. Default is enabled but can be overridden by setting QUDA_...
Definition: util_quda.cpp:51
char * Ynorm_h[MAX_MULTI_BLAS_N]
QudaParity parity
Definition: covdev_test.cpp:53
const bool use_const
MultiBlasArg(SpinorX X[NXZ], SpinorY Y[], SpinorZ Z[NXZ], SpinorW W[], Functor f, int NYW, int length)
const int length
#define a
unsigned long long bytes
Definition: blas_quda.cu:43
SpinorX X[NXZ]
static signed char * Bmatrix_h