QUDA  0.9.0
blas_core.cuh
Go to the documentation of this file.
1 
/**
   Kernel argument bundle consumed by blasKernel: the four operand
   accessors, the element-wise functor, and the element count used as
   the kernel's loop bound.
*/
template <typename SpinorX, typename SpinorY, typename SpinorZ,
          typename SpinorW, typename Functor>
struct BlasArg {
  SpinorX X;         // first operand accessor
  SpinorY Y;         // second operand accessor
  SpinorZ Z;         // third operand accessor
  SpinorW W;         // fourth operand accessor
  Functor f;         // element-wise operation applied by the kernel
  const int length;  // loop extent of the kernel's flat index

  BlasArg(SpinorX X_, SpinorY Y_, SpinorZ Z_, SpinorW W_, Functor f_, int length_)
    : X(X_), Y(Y_), Z(Z_), W(W_), f(f_), length(length_)
  {
  }
};
16 
// blasKernel -- device kernel that applies an element-wise BLAS functor to
// four operands. Its declaration line (original line 22) is not shown in this
// extract; per the cross-reference index on this page it reads:
//   __global__ void blasKernel(BlasArg<SpinorX, SpinorY, SpinorZ, SpinorW, Functor> arg)
//
// Template parameters:
//   FloatN        - element type of the per-thread working arrays
//   M             - number of FloatN elements loaded per operand per iteration
//   SpinorX/Y/Z/W - operand accessor types (must provide load/save)
//   Functor       - element-wise operation, invoked as f(x, y, z, w)
//
// Launch layout assumed by the indexing below:
//   blockIdx.y             -> parity index passed to every load/save
//   blockIdx.x/threadIdx.x -> flat index i, advanced by gridDim.x*blockDim.x
//                             (grid-stride loop) while i < arg.length
// The kernel body declares no shared memory.
20 template <typename FloatN, int M, typename SpinorX, typename SpinorY,
21  typename SpinorZ, typename SpinorW, typename Functor>
23  unsigned int i = blockIdx.x*(blockDim.x) + threadIdx.x;
24  unsigned int parity = blockIdx.y;
25  unsigned int gridSize = gridDim.x*blockDim.x;
26 
// per-thread functor setup, done once before the loop (semantics defined by the functor)
27  arg.f.init();
28 
29  while (i < arg.length) {
// stage M elements of every operand at (i, parity) in per-thread arrays
30  FloatN x[M], y[M], z[M], w[M];
31  arg.X.load(x, i, parity);
32  arg.Y.load(y, i, parity);
33  arg.Z.load(z, i, parity);
34  arg.W.load(w, i, parity);
35 
36 #pragma unroll
37  for (int j=0; j<M; j++) arg.f(x[j], y[j], z[j], w[j]);
38 
// all four operands are saved back unconditionally; NOTE(review): presumably
// read-only operands use an accessor whose save is a no-op -- confirm
39  arg.X.save(x, i, parity);
40  arg.Y.save(y, i, parity);
41  arg.Z.save(z, i, parity);
42  arg.W.save(w, i, parity);
43  i += gridSize;
44  }
45 }
46 
// BlasCuda -- Tunable wrapper that owns the kernel argument for one BLAS
// operation and launches blasKernel with autotuned launch parameters.
//
// Template parameters match blasKernel: FloatN is the per-thread working
// element type, M the number of FloatN elements per operand per iteration,
// SpinorX/Y/Z/W the operand accessor types and Functor the element-wise op.
47 template <typename FloatN, int M, typename SpinorX, typename SpinorY,
48  typename SpinorZ, typename SpinorW, typename Functor>
49 class BlasCuda : public Tunable {
50 
51 private:
// Original line 52 is not shown in this extract; per the cross-reference index
// it declares the kernel argument used throughout this class:
//   BlasArg<SpinorX, SpinorY, SpinorZ, SpinorW, Functor> arg;
53 
// Number of parity slices; used as grid.y of every launch (see initTuneParam).
54  const int nParity;
55 
56  // host pointers used for backing up fields when tuning
57  // kept here rather than in BlasArg to minimize the size of the kernel argument
58  char *X_h, *Y_h, *Z_h, *W_h;
// Original line 59 (not shown) declares the matching norm backups:
//   char *Xnorm_h, *Ynorm_h, *Znorm_h, *Wnorm_h;
// Per-operand field sizes (bytes_) and norm sizes (norm_bytes_), indexed 0..3
// in X, Y, Z, W order. Only the pointers are stored, so the arrays must
// outlive this object.
60  const size_t *bytes_;
61  const size_t *norm_bytes_;
62 
// No shared memory is requested: both hooks report 0 bytes.
63  unsigned int sharedBytesPerThread() const { return 0; }
64  unsigned int sharedBytesPerBlock(const TuneParam &param) const { return 0; }
65 
// Shared-memory step of the tuning search: sizes param.shared_bytes for the
// block dimension advanceBlockDim would pick next, then returns false
// (NOTE(review): presumably "no further shared-memory configurations" --
// confirm against Tunable). The remainder of the conditional expression is
// on original line 72, which is not shown in this extract.
66  virtual bool advanceSharedBytes(TuneParam &param) const
67  {
68  TuneParam next(param);
69  advanceBlockDim(next); // to get next blockDim
70  int nthreads = next.block.x * next.block.y * next.block.z;
71  param.shared_bytes = sharedBytesPerThread()*nthreads > sharedBytesPerBlock(param) ?
73  return false;
74  }
75 
76 public:
// length is the total element count across all parity slices; the kernel
// argument stores length/nParity (integer division), the extent handled by
// each blockIdx.y slice. All host backup pointers start out null.
77  BlasCuda(SpinorX &X, SpinorY &Y, SpinorZ &Z, SpinorW &W, Functor &f,
78  int length, int nParity, const size_t *bytes, const size_t *norm_bytes) :
79  arg(X, Y, Z, W, f, length/nParity), nParity(nParity), X_h(0), Y_h(0), Z_h(0), W_h(0),
80  Xnorm_h(0), Ynorm_h(0), Znorm_h(0), Wnorm_h(0), bytes_(bytes), norm_bytes_(norm_bytes) { }
81 
// NOTE(review): does not free the host backup pointers; presumably restore()
// in postTune releases them -- confirm against the Spinor implementation.
82  virtual ~BlasCuda() { }
83 
// Tuning key: volume string, the functor's typeid name, and the aux string.
// The blasStrings fields are filled in by the caller before construction.
84  inline TuneKey tuneKey() const {
85  return TuneKey(blasStrings.vol_str, typeid(arg.f).name(), blasStrings.aux_tmp);
86  }
87 
// Obtains launch parameters from tuneLaunch and launches blasKernel on `stream`.
88  inline void apply(const cudaStream_t &stream) {
89  TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
90  blasKernel<FloatN,M> <<<tp.grid, tp.block, tp.shared_bytes, stream>>>(arg);
91  }
92 
// Backs up all four operands (fields and norms) to the host pointers before
// tuning, using sizes bytes_[k] / norm_bytes_[k] with k = 0..3 for X, Y, Z, W;
// presumably because tuning launches overwrite the output operands.
93  void preTune() {
94  arg.X.backup(&X_h, &Xnorm_h, bytes_[0], norm_bytes_[0]);
95  arg.Y.backup(&Y_h, &Ynorm_h, bytes_[1], norm_bytes_[1]);
96  arg.Z.backup(&Z_h, &Znorm_h, bytes_[2], norm_bytes_[2]);
97  arg.W.backup(&W_h, &Wnorm_h, bytes_[3], norm_bytes_[3]);
98  }
99 
// Restores the four operands from the host backups taken in preTune.
100  void postTune() {
101  arg.X.restore(&X_h, &Xnorm_h, bytes_[0], norm_bytes_[0]);
102  arg.Y.restore(&Y_h, &Ynorm_h, bytes_[1], norm_bytes_[1]);
103  arg.Z.restore(&Z_h, &Znorm_h, bytes_[2], norm_bytes_[2]);
104  arg.W.restore(&W_h, &Wnorm_h, bytes_[3], norm_bytes_[3]);
105  }
106 
// grid.y spans the parity slices, matching blockIdx.y in blasKernel.
107  void initTuneParam(TuneParam &param) const {
108  Tunable::initTuneParam(param);
109  param.grid.y = nParity;
110  }
111 
// NOTE(review): calls Tunable::initTuneParam rather than
// Tunable::defaultTuneParam -- confirm this is intentional.
112  void defaultTuneParam(TuneParam &param) const {
113  Tunable::initTuneParam(param);
114  param.grid.y = nParity;
115  }
116 
// Functor flops per component x vector length x elements per parity x parities x M.
// NOTE(review): the product is evaluated in the operand types before widening to
// long long; if arg.f.flops() returns int this can overflow for large volumes.
117  long long flops() const { return arg.f.flops()*vec_length<FloatN>::value*arg.length*nParity*M; }
118  long long bytes() const
119  {
120  // bytes for low-precision vector
121  size_t base_bytes = arg.X.Precision()*vec_length<FloatN>::value*M;
// half precision adds sizeof(float) per element (presumably the norm)
122  if (arg.X.Precision() == QUDA_HALF_PRECISION) base_bytes += sizeof(float);
123 
124  // bytes for high precision vector
125  size_t extra_bytes = arg.Y.Precision()*vec_length<FloatN>::value*M;
126  if (arg.Y.Precision() == QUDA_HALF_PRECISION) extra_bytes += sizeof(float);
127 
128  // the factor two here assumes we are reading and writing to the high precision vector
129  return ((arg.f.streams()-2)*base_bytes + 2*extra_bytes)*arg.length*nParity;
130  }
// Tuner iteration count for this kernel (presumably timing runs per candidate).
131  int tuningIter() const { return 3; }
132 };
133 
// blasCuda -- host driver for a device BLAS operation on up to four fields.
//
// Template parameters:
//   RegType          - per-thread working type of the kernel (FloatN of BlasCuda)
//   StoreType, yType - storage types; when they differ, y's aux string is
//                      appended to the tuning key
//   M                - FloatN elements per operand per loop iteration
//   Functor          - functor template, instantiated as Functor<Float2, RegType>
//   writeX..writeW   - not referenced in the lines shown here (NOTE(review):
//                      presumably used by the operand declarations on the
//                      missing lines 158-161 -- confirm)
// Parameters:
//   a, b, c    - coefficients, converted to Float2 and passed to the functor
//   x, y, z, w - operand fields
//   length     - total element count, divided by the partition count in BlasCuda
//
// Non-native fields are rejected with a warning, and the function then returns
// without launching anything.
134 template <typename RegType, typename StoreType, typename yType, int M,
135  template <typename,typename> class Functor,
136  int writeX, int writeY, int writeZ, int writeW>
137 void blasCuda(const double2 &a, const double2 &b, const double2 &c,
138  ColorSpinorField &x, ColorSpinorField &y,
139  ColorSpinorField &z, ColorSpinorField &w, int length) {
140 
// Original line 141 is not shown in this extract.
142 
143  if (!x.isNative()) {
144  warningQuda("Device blas on non-native fields is not supported\n");
145  return;
146  }
147 
// Build the tuning-key strings in blasStrings (namespace-scope static state).
// NOTE(review): strcpy/strcat into aux_tmp are unbounded -- confirm aux_tmp is
// large enough for two aux strings plus a comma.
148  blasStrings.vol_str = x.VolString();
149  strcpy(blasStrings.aux_tmp, x.AuxString());
150  if (typeid(StoreType) != typeid(yType)) {
151  strcat(blasStrings.aux_tmp, ",");
152  strcat(blasStrings.aux_tmp, y.AuxString());
153  }
154 
// Field and norm sizes in X, Y, Z, W order, used by BlasCuda's preTune/postTune backups.
155  size_t bytes[] = {x.Bytes(), y.Bytes(), z.Bytes(), w.Bytes()};
156  size_t norm_bytes[] = {x.NormBytes(), y.NormBytes(), z.NormBytes(), w.NormBytes()};
157 
// Original lines 158-161 are not shown; they declare the operand accessors
// X, Y, Z, W used below.
162 
// The functor works on two-component vectors of RegType's scalar type.
163  typedef typename scalar<RegType>::type Float;
164  typedef typename vector<Float,2>::type Float2;
165  typedef vector<Float,2> vec2;
166  Functor<Float2, RegType> f( (Float2)vec2(a), (Float2)vec2(b), (Float2)vec2(c));
167 
// One grid.y slice per partition: composite dimension (if composite) times site subset.
168  int partitions = (x.IsComposite() ? x.CompositeDim() : 1) * (x.SiteSubset());
169  BlasCuda<RegType,M,
170  decltype(X), decltype(Y), decltype(Z), decltype(W),
171  Functor<Float2, RegType> >
172  blas(X, Y, Z, W, f, length, partitions, bytes, norm_bytes);
173  blas.apply(*blasStream);
174 
// `blas::` names the namespace-level counters: lookup of a name before `::`
// ignores the local variable `blas`, whose members supply the increments.
175  blas::bytes += blas.bytes();
176  blas::flops += blas.flops();
177 
178  checkCudaError();
179 }
180 
181 
/**
   Host-side reference implementation of a blas operation over field
   accessors.  For every (parity, checkerboard site, spin, color) element,
   all four operands are converted to Float2 and passed to the functor.
   Then the operands whose write flag is non-zero are stored back.

   @tparam Float2          Two-component vector type the functor operates on
   @tparam writeX..writeW  Non-zero if the corresponding operand is written back
   @param X, Y, Z, W       Field accessors; loop extents are taken from X
   @param f                Functor applied to each element
*/
template <typename Float2, int writeX, int writeY, int writeZ, int writeW,
          typename SpinorX, typename SpinorY, typename SpinorZ, typename SpinorW,
          typename Functor>
void genericBlas(SpinorX &X, SpinorY &Y, SpinorZ &Z, SpinorW &W, Functor f) {
  const auto n_parity = X.Nparity();
  const auto volume_cb = X.VolumeCB();
  const auto n_spin = X.Nspin();
  const auto n_color = X.Ncolor();

  for (int p = 0; p < n_parity; ++p) {
    for (int x_cb = 0; x_cb < volume_cb; ++x_cb) {
      for (int spin = 0; spin < n_spin; ++spin) {
        for (int color = 0; color < n_color; ++color) {
          // promote all four operands to the functor's working type
          Float2 xv = make_Float2<Float2>( X(p, x_cb, spin, color) );
          Float2 yv = make_Float2<Float2>( Y(p, x_cb, spin, color) );
          Float2 zv = make_Float2<Float2>( Z(p, x_cb, spin, color) );
          Float2 wv = make_Float2<Float2>( W(p, x_cb, spin, color) );

          f(xv, yv, zv, wv);

          // store back only the operands flagged as outputs
          if (writeX) X(p, x_cb, spin, color) = make_Complex(xv);
          if (writeY) Y(p, x_cb, spin, color) = make_Complex(yv);
          if (writeZ) Z(p, x_cb, spin, color) = make_Complex(zv);
          if (writeW) W(p, x_cb, spin, color) = make_Complex(wv);
        }
      }
    }
  }
}
212 
/**
   Builds field accessors for a fixed (nSpin, nColor, order) and runs the
   host blas loop.  X, Z and W are accessed with precision Float and Y with
   yFloat; the functor works on two-component vectors of yFloat.
*/
template <typename Float, typename yFloat, int nSpin, int nColor, QudaFieldOrder order,
          int writeX, int writeY, int writeZ, int writeW, typename Functor>
void genericBlas(ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z,
                 ColorSpinorField &w, Functor f) {
  using MainAccessor = colorspinor::FieldOrderCB<Float, nSpin, nColor, 1, order>;
  using YAccessor = colorspinor::FieldOrderCB<yFloat, nSpin, nColor, 1, order>;
  using Float2 = typename vector<yFloat, 2>::type;

  // construction order kept as X, Z, W, then Y
  MainAccessor X(x), Z(z), W(w);
  YAccessor Y(y);

  genericBlas<Float2, writeX, writeY, writeZ, writeW>(X, Y, Z, W, f);
}
222 
/**
   Host blas dispatch on the color dimension of x for a fixed spin count and
   field order.  Supported nColor values are 2, 3, 4, 8, 12, 16, 20, 24 and 32;
   any other value is reported via errorQuda.
*/
template <typename Float, typename yFloat, int nSpin, QudaFieldOrder order,
          int writeX, int writeY, int writeZ, int writeW, typename Functor>
void genericBlas(ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, ColorSpinorField &w, Functor f) {
  switch (x.Ncolor()) {
  case 2:
    genericBlas<Float,yFloat,nSpin,2,order,writeX,writeY,writeZ,writeW,Functor>(x, y, z, w, f);
    break;
  case 3:
    genericBlas<Float,yFloat,nSpin,3,order,writeX,writeY,writeZ,writeW,Functor>(x, y, z, w, f);
    break;
  case 4:
    genericBlas<Float,yFloat,nSpin,4,order,writeX,writeY,writeZ,writeW,Functor>(x, y, z, w, f);
    break;
  case 8:
    genericBlas<Float,yFloat,nSpin,8,order,writeX,writeY,writeZ,writeW,Functor>(x, y, z, w, f);
    break;
  case 12:
    genericBlas<Float,yFloat,nSpin,12,order,writeX,writeY,writeZ,writeW,Functor>(x, y, z, w, f);
    break;
  case 16:
    genericBlas<Float,yFloat,nSpin,16,order,writeX,writeY,writeZ,writeW,Functor>(x, y, z, w, f);
    break;
  case 20:
    genericBlas<Float,yFloat,nSpin,20,order,writeX,writeY,writeZ,writeW,Functor>(x, y, z, w, f);
    break;
  case 24:
    genericBlas<Float,yFloat,nSpin,24,order,writeX,writeY,writeZ,writeW,Functor>(x, y, z, w, f);
    break;
  case 32:
    genericBlas<Float,yFloat,nSpin,32,order,writeX,writeY,writeZ,writeW,Functor>(x, y, z, w, f);
    break;
  default:
    errorQuda("nColor = %d not implemeneted",x.Ncolor());
  }
}
248 
/**
   Host blas dispatch on the spin dimension of x for a fixed field order.
   Supported nSpin values are 4 and 2, plus 1 when GPU_STAGGERED_DIRAC is
   defined; any other value is reported via errorQuda.
*/
template <typename Float, typename yFloat, QudaFieldOrder order, int writeX, int writeY, int writeZ, int writeW, typename Functor>
void genericBlas(ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, ColorSpinorField &w, Functor f) {
  switch (x.Nspin()) {
  case 4:
    genericBlas<Float,yFloat,4,order,writeX,writeY,writeZ,writeW,Functor>(x, y, z, w, f);
    break;
  case 2:
    genericBlas<Float,yFloat,2,order,writeX,writeY,writeZ,writeW,Functor>(x, y, z, w, f);
    break;
#ifdef GPU_STAGGERED_DIRAC
  case 1:
    genericBlas<Float,yFloat,1,order,writeX,writeY,writeZ,writeW,Functor>(x, y, z, w, f);
    break;
#endif
  default:
    errorQuda("nSpin = %d not implemeneted",x.Nspin());
  }
}
263 
/**
   Host blas entry point: checks the field order of x and forwards to the
   spin dispatch.  Only QUDA_SPACE_SPIN_COLOR_FIELD_ORDER is supported; any
   other order is reported via errorQuda.
*/
template <typename Float, typename yFloat, int writeX, int writeY, int writeZ, int writeW, typename Functor>
void genericBlas(ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, ColorSpinorField &w, Functor f) {
  if (x.FieldOrder() != QUDA_SPACE_SPIN_COLOR_FIELD_ORDER) {
    errorQuda("Not implemeneted");
    return;
  }

  genericBlas<Float,yFloat,QUDA_SPACE_SPIN_COLOR_FIELD_ORDER,writeX,writeY,writeZ,writeW,Functor>(x, y, z, w, f);
}
void blasCuda(const double2 &a, const double2 &b, const double2 &c, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, ColorSpinorField &w, int length)
Definition: blas_core.cuh:137
dim3 dim3 blockDim
int tuningIter() const
Definition: blas_core.cuh:131
BlasArg(SpinorX X, SpinorY Y, SpinorZ Z, SpinorW W, Functor f, int length)
Definition: blas_core.cuh:13
void initTuneParam(TuneParam &param) const
Definition: blas_core.cuh:107
cudaStream_t stream
char * Xnorm_h
Definition: blas_core.cuh:59
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:20
#define errorQuda(...)
Definition: util_quda.h:90
unsigned int sharedBytesPerThread() const
Definition: blas_core.cuh:63
enum QudaFieldOrder_s QudaFieldOrder
void checkLength(const ColorSpinorField &a, ColorSpinorField &b)
Definition: reduce_quda.cu:49
char * Znorm_h
Definition: blas_core.cuh:59
char * Wnorm_h
Definition: blas_core.cuh:59
char * strcpy(char *__dst, const char *__src)
__global__ void blasKernel(BlasArg< SpinorX, SpinorY, SpinorZ, SpinorW, Functor > arg)
Definition: blas_core.cuh:22
char * Ynorm_h
Definition: blas_core.cuh:59
BlasCuda(SpinorX &X, SpinorY &Y, SpinorZ &Z, SpinorW &W, Functor &f, int length, int nParity, const size_t *bytes, const size_t *norm_bytes)
Definition: blas_core.cuh:77
char * strcat(char *__s1, const char *__s2)
long long flops() const
Definition: blas_core.cuh:117
const size_t * bytes_
Definition: blas_core.cuh:60
char * Y_h
Definition: blas_core.cuh:58
virtual bool advanceSharedBytes(TuneParam &param) const
Definition: blas_core.cuh:66
QudaGaugeParam param
Definition: pack_test.cpp:17
#define b
complex< double > make_Complex(const double2 &a)
Definition: float_vector.h:278
long long bytes() const
Definition: blas_core.cuh:118
SpinorX X
Definition: blas_core.cuh:7
SpinorZ Z
Definition: blas_core.cuh:9
static cudaStream_t * blasStream
Definition: blas_quda.cu:53
const int nColor
Definition: covdev_test.cpp:77
static struct quda::blas::@4 blasStrings
unsigned int sharedBytesPerBlock(const TuneParam &param) const
Definition: blas_core.cuh:64
BlasArg< SpinorX, SpinorY, SpinorZ, SpinorW, Functor > arg
Definition: blas_core.cuh:52
int int int w
TuneKey tuneKey() const
Definition: blas_core.cuh:84
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
Definition: tune.cpp:603
#define warningQuda(...)
Definition: util_quda.h:101
int int int enum cudaChannelFormatKind f
int Z[4]
Definition: test_util.cpp:27
void preTune()
Definition: blas_core.cuh:93
char * W_h
Definition: blas_core.cuh:58
virtual ~BlasCuda()
Definition: blas_core.cuh:82
const int length
Definition: blas_core.cuh:12
void postTune()
Definition: blas_core.cuh:100
SpinorW W
Definition: blas_core.cuh:10
char * X_h
Definition: blas_core.cuh:58
void genericBlas(SpinorX &X, SpinorY &Y, SpinorZ &Z, SpinorW &W, Functor f)
Definition: blas_core.cuh:192
void defaultTuneParam(TuneParam &param) const
Definition: blas_core.cuh:112
const size_t * norm_bytes_
Definition: blas_core.cuh:61
const int nParity
Definition: blas_core.cuh:54
char * Z_h
Definition: blas_core.cuh:58
unsigned long long flops
Definition: blas_quda.cu:42
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
Definition: complex_quda.h:880
void size_t length
const void * c
#define checkCudaError()
Definition: util_quda.h:129
QudaTune getTuning()
Query whether autotuning is enabled or not. Default is enabled but can be overridden by setting QUDA_...
Definition: util_quda.cpp:51
Functor f
Definition: blas_core.cuh:11
void apply(const cudaStream_t &stream)
Definition: blas_core.cuh:88
QudaParity parity
Definition: covdev_test.cpp:53
#define a
unsigned long long bytes
Definition: blas_quda.cu:43
SpinorY Y
Definition: blas_core.cuh:8