QUDA  v0.7.0
A library for QCD on GPUs
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
blas_mixed_core.h
Go to the documentation of this file.
1 namespace mixed {
2 
/**
 * Argument bundle handed by value to the generic four-field BLAS kernel.
 *
 * @tparam SpinorX  accessor type for field X
 * @tparam SpinorY  accessor type for field Y
 * @tparam SpinorZ  accessor type for field Z
 * @tparam SpinorW  accessor type for field W
 * @tparam Functor  element-wise operation applied to (x, y, z, w)
 *
 * All members are stored by value; `length` is immutable after construction
 * and is used by the kernel as the upper bound of its index loop.
 */
template <typename SpinorX, typename SpinorY, typename SpinorZ,
          typename SpinorW, typename Functor>
struct BlasArg {
  SpinorX X;        // accessor for the first field
  SpinorY Y;        // accessor for the second field
  SpinorZ Z;        // accessor for the third field
  SpinorW W;        // accessor for the fourth field
  Functor f;        // operation applied to each element tuple
  const int length; // number of indices the kernel iterates over

  /** Copies every argument into the member of the same name. */
  BlasArg(SpinorX X, SpinorY Y, SpinorZ Z, SpinorW W, Functor f, int length)
    : X(X),
      Y(Y),
      Z(Z),
      W(W),
      f(f),
      length(length)
  { }
};
18 
/**
 * Generic element-wise BLAS kernel over four fields.
 *
 * Each thread starts at its global x-index and advances by the total number
 * of launched threads (a grid-stride loop), so any 1-D launch configuration
 * covers all `arg.length` indices. Only the x components of the block/grid
 * dimensions are used.
 *
 * For every index i the kernel:
 *   1. loads M register-typed elements from each of X, Y, Z and W,
 *   2. applies `arg.f` to each of the M element tuples in place,
 *   3. calls save() on all four accessors.
 *
 * NOTE(review): save() is invoked unconditionally on every accessor here;
 * presumably the accessor type itself decides whether a store is actually
 * issued -- confirm against the Spinor implementation.
 *
 * @tparam FloatN  register vector type of one element
 * @tparam M       number of FloatN elements held per index
 * @param  arg     accessors, functor and loop bound (passed by value)
 */
template <typename FloatN, int M, typename SpinorX, typename SpinorY,
          typename SpinorZ, typename SpinorW, typename Functor>
__global__ void blasKernel(BlasArg<SpinorX,SpinorY,SpinorZ,SpinorW,Functor> arg) {
  unsigned int i = blockIdx.x*(blockDim.x) + threadIdx.x;
  unsigned int gridSize = gridDim.x*blockDim.x;
  while (i < arg.length) {
    // per-thread staging arrays for the M elements at index i
    FloatN x[M], y[M], z[M], w[M];
    arg.X.load(x, i);
    arg.Y.load(y, i);
    arg.Z.load(z, i);
    arg.W.load(w, i);

    // M is a compile-time constant, so the loop can be fully unrolled
#pragma unroll
    for (int j=0; j<M; j++) arg.f(x[j], y[j], z[j], w[j]);

    arg.X.save(x, i);
    arg.Y.save(y, i);
    arg.Z.save(z, i);
    arg.W.save(w, i);
    i += gridSize;
  }
}
44 
/**
 * Tunable host-side wrapper that launches blasKernel for one set of four
 * field accessors and a functor.
 *
 * The object owns a BlasArg (accessors, functor, length) plus host-side
 * backup pointers that preTune()/postTune() use to save and restore the
 * fields around autotuning runs.
 *
 * @tparam FloatN  register vector type used by the kernel
 * @tparam M       number of FloatN elements per index
 */
template <typename FloatN, int M, typename SpinorX, typename SpinorY,
          typename SpinorZ, typename SpinorW, typename Functor>
class BlasCuda : public Tunable {

private:
  // Kernel argument bundle. Its type is fixed by the constructor initializer
  // arg(X, Y, Z, W, f, length) and by the blasKernel parameter type.
  // NOTE(review): this declaration was absent from the listing and has been
  // reinstated; `mutable` is chosen so that const members (flops(), bytes(),
  // tuneKey()) may call accessor methods regardless of their constness --
  // confirm against the upstream header.
  mutable BlasArg<SpinorX,SpinorY,SpinorZ,SpinorW,Functor> arg;

  // host pointers used for backing up fields when tuning
  // these can't be curried into the Spinors because of Tesla argument length restriction
  char *X_h, *Y_h, *Z_h, *W_h;
  char *Xnorm_h, *Ynorm_h, *Znorm_h, *Wnorm_h;
  const size_t *bytes_;      // per-field byte counts, indexed 0..3 for X, Y, Z, W
  const size_t *norm_bytes_; // per-field norm byte counts, same indexing

  // This kernel requests no shared memory, per thread or per block.
  unsigned int sharedBytesPerThread() const { return 0; }
  unsigned int sharedBytesPerBlock(const TuneParam &param) const { return 0; }

  /**
   * Sets param.shared_bytes to the larger of the per-thread requirement
   * (evaluated for the *next* block size) and the per-block requirement,
   * then reports that there is no further shared-memory setting to try.
   *
   * @return always false
   */
  virtual bool advanceSharedBytes(TuneParam &param) const
  {
    TuneParam next(param);
    advanceBlockDim(next); // to get next blockDim
    int nthreads = next.block.x * next.block.y * next.block.z;
    param.shared_bytes = sharedBytesPerThread()*nthreads > sharedBytesPerBlock(param) ?
      sharedBytesPerThread()*nthreads : sharedBytesPerBlock(param);
    return false;
  }

public:
  /**
   * @param X,Y,Z,W     field accessors, copied into the kernel argument
   * @param f           functor, copied into the kernel argument
   * @param length      loop bound passed to the kernel
   * @param bytes       array of 4 byte counts (X, Y, Z, W); only the pointer
   *                    is stored, so it must outlive this object
   * @param norm_bytes  array of 4 norm byte counts; same lifetime requirement
   */
  BlasCuda(SpinorX &X, SpinorY &Y, SpinorZ &Z, SpinorW &W, Functor &f,
           int length, const size_t *bytes, const size_t *norm_bytes) :
    arg(X, Y, Z, W, f, length), X_h(0), Y_h(0), Z_h(0), W_h(0),
    Xnorm_h(0), Ynorm_h(0), Znorm_h(0), Wnorm_h(0),
    bytes_(bytes), norm_bytes_(norm_bytes)
  { ; }
  virtual ~BlasCuda() { }

  /** Tuning-cache key built from the volume string, the functor's type name and the aux string. */
  inline TuneKey tuneKey() const {
    return TuneKey(blasStrings.vol_str, typeid(arg.f).name(), blasStrings.aux_tmp);
  }

  /** Obtains launch parameters from the tuner and launches blasKernel on `stream`. */
  void apply(const cudaStream_t &stream) {
    TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
    blasKernel<FloatN,M> <<<tp.grid, tp.block, tp.shared_bytes, stream>>>(arg);
  }

  /** Backs up all four fields (data and norm) through the accessors before tuning. */
  void preTune() {
    arg.X.save(&X_h, &Xnorm_h, bytes_[0], norm_bytes_[0]);
    arg.Y.save(&Y_h, &Ynorm_h, bytes_[1], norm_bytes_[1]);
    arg.Z.save(&Z_h, &Znorm_h, bytes_[2], norm_bytes_[2]);
    arg.W.save(&W_h, &Wnorm_h, bytes_[3], norm_bytes_[3]);
  }

  /** Restores all four fields from the backups taken by preTune(). */
  void postTune() {
    arg.X.load(&X_h, &Xnorm_h, bytes_[0], norm_bytes_[0]);
    arg.Y.load(&Y_h, &Ynorm_h, bytes_[1], norm_bytes_[1]);
    arg.Z.load(&Z_h, &Znorm_h, bytes_[2], norm_bytes_[2]);
    arg.W.load(&W_h, &Wnorm_h, bytes_[3], norm_bytes_[3]);
  }

  // sizeof(FloatN)/sizeof(FloatN::x) is the number of scalar components in FloatN
  long long flops() const { return arg.f.flops()*(sizeof(FloatN)/sizeof(((FloatN*)0)->x))*arg.length*M; }

  /** Byte-traffic estimate, derived from X's precision only; half precision adds one float per index. */
  long long bytes() const {
    size_t bytes = arg.X.Precision()*(sizeof(FloatN)/sizeof(((FloatN*)0)->x))*M;
    if (arg.X.Precision() == QUDA_HALF_PRECISION) bytes += sizeof(float);
    return arg.f.streams()*bytes*arg.length; }

  /** Number of iterations requested from the tuner. */
  int tuningIter() const { return 3; }
};
111 
// Host driver for a mixed-precision four-field BLAS operation.
//
// Template parameters:
//   Functor            - functor template instantiated per precision branch
//   writeX .. writeW   - integer flags forwarded as template arguments to the
//                        Spinor accessors (visible in the Spinor<...> type below)
// Parameters:
//   a, b, c            - coefficients passed to the functor constructor; in the
//                        single-precision branch they are narrowed via make_float2
//   x, y, z, w         - fields operated on
//
// Flow: validate the fields, refuse non-native layouts with a warning, build
// the tuning strings, recurse on the even/odd halves for full-site-subset
// fields, otherwise dispatch on (precision, Nspin), launch on *blasStream,
// then update the global blas_bytes / blas_flops counters and check for
// CUDA errors.
//
// NOTE(review): this listing omits several source lines (see the gaps in the
// embedded listing numbers, e.g. 150 -> 155): the Spinor X/Y/Z/W declarations
// and parts of the BlasCuda<...> template argument lists are not shown here.
// Consult the original header before editing the dispatch branches.
115 template <template <typename Float, typename FloatN> class Functor,
116  int writeX, int writeY, int writeZ, int writeW>
117 void blasCuda(const double2 &a, const double2 &b, const double2 &c,
118  cudaColorSpinorField &x, cudaColorSpinorField &y,
119  cudaColorSpinorField &z, cudaColorSpinorField &w) {
// x is checked against y with checkLength, but against z and w with checkSpinor
120  checkLength(x, y);
121  checkSpinor(x, z);
122  checkSpinor(x, w);
123 
// Only x's layout is tested; the call is a warned no-op for non-native x.
124  if (!x.isNative()) {
125  warningQuda("Blas on non-native fields is not supported\n");
126  return;
127  }
128 
// Tuning key strings: volume from x, aux string is "<x aux>,<y aux>".
// NOTE(review): strcpy/strcat perform no bounds check here; assumes
// blasStrings.aux_tmp is large enough for both aux strings -- confirm.
129  blasStrings.vol_str = x.VolString();
130  strcpy(blasStrings.aux_tmp, x.AuxString());
131  strcat(blasStrings.aux_tmp, ",");
132  strcat(blasStrings.aux_tmp, y.AuxString());
133 
// Full-subset fields are handled as two recursive calls (even, then odd).
// The early return skips the accounting at the bottom of this function,
// which is instead performed inside each recursive call.
134  if (x.SiteSubset() == QUDA_FULL_SITE_SUBSET) {
135  mixed::blasCuda<Functor,writeX,writeY,writeZ,writeW>
136  (a, b, c, x.Even(), y.Even(), z.Even(), w.Even());
137  mixed::blasCuda<Functor,writeX,writeY,writeZ,writeW>
138  (a, b, c, x.Odd(), y.Odd(), z.Odd(), w.Odd());
139  return;
140  }
141 
142  // FIXME: use traits to encapsulate register type for shorts -
143  // will reduce template type parameters from 3 to 2
144 
// Byte counts in X, Y, Z, W order; BlasCuda stores only pointers to these
// arrays, so the BlasCuda objects below must not outlive this scope.
145  size_t bytes[] = {x.Bytes(), y.Bytes(), z.Bytes(), w.Bytes()};
146  size_t norm_bytes[] = {x.NormBytes(), y.NormBytes(), z.NormBytes(), w.NormBytes()};
147 
// Dispatch. In every branch an Nspin value other than 4 or 1 launches
// nothing and raises no error.
// Branch 1: x single, y double.
148  if (x.Precision() == QUDA_SINGLE_PRECISION && y.Precision() == QUDA_DOUBLE_PRECISION) {
149  if (x.Nspin() == 4) {
150  const int M = 12;
155  Functor<double2, double2> f(a, b, c);
158  Spinor<double2,double4,float4,M,writeW,3>, Functor<double2, double2> >
159  blas(X, Y, Z, W, f, y.Volume(), bytes, norm_bytes);
160  blas.apply(*blasStream);
161  } else if (x.Nspin() == 1) {
162  const int M = 3;
167  Functor<double2, double2> f(a, b, c);
168  BlasCuda<double2,M,
171  Functor<double2, double2> > blas(X, Y, Z, W, f, y.Volume(), bytes, norm_bytes);
172  blas.apply(*blasStream);
173  }
// Branch 2: x half, y double.
174  } else if (x.Precision() == QUDA_HALF_PRECISION && y.Precision() == QUDA_DOUBLE_PRECISION) {
175  if (x.Nspin() == 4) {
176  const int M = 12;
181  Functor<double2, double2> f(a, b, c);
182  BlasCuda<double2,M,
185  Functor<double2, double2> > blas(X, Y, Z, W, f, y.Volume(), bytes, norm_bytes);
186  blas.apply(*blasStream);
187  } else if (x.Nspin() == 1) {
188  const int M = 3;
193  Functor<double2, double2> f(a, b, c);
194  BlasCuda<double2,M,
197  Functor<double2, double2> > blas(X, Y, Z, W, f, y.Volume(), bytes, norm_bytes);
198  blas.apply(*blasStream);
199  }
// Branch 3: y single. Only y's precision is tested here; x's precision is
// not checked by this condition. Coefficients are narrowed to float2.
200  } else if (y.Precision() == QUDA_SINGLE_PRECISION) {
201  if (x.Nspin() == 4) {
202  const int M = 6;
207  Functor<float2, float4> f(make_float2(a.x, a.y), make_float2(b.x, b.y), make_float2(c.x, c.y));
208  BlasCuda<float4,M,
211  Functor<float2, float4> > blas(X, Y, Z, W, f, y.Volume(), bytes, norm_bytes);
212  blas.apply(*blasStream);
213  } else if (x.Nspin() == 1) {
214  const int M = 3;
219  Functor<float2, float2> f(make_float2(a.x, a.y), make_float2(b.x, b.y), make_float2(c.x, c.y));
220  BlasCuda<float2, M,
223  Functor<float2, float2> > blas(X, Y, Z, W, f, y.Volume(), bytes, norm_bytes);
224  blas.apply(*blasStream);
225  }
226  } else {
227  errorQuda("Not implemented for this precision combination");
228  }
229 
// Global performance counters. Both use x's RealLength, and the byte count
// uses x's precision only, even in branches where y has a different precision.
230  blas_bytes += Functor<double2,double2>::streams()*(unsigned long long)x.RealLength()*x.Precision();
231  blas_flops += Functor<double2,double2>::flops()*(unsigned long long)x.RealLength();
232 
// The kernel launch itself reports nothing; errors are picked up here.
233  checkCudaError();
234 }
235 
236 }
long long bytes() const
long long flops() const
int y[4]
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:20
#define errorQuda(...)
Definition: util_quda.h:73
unsigned long long blas_bytes
Definition: blas_quda.cu:38
cudaStream_t * streams
cudaStream_t * stream
BlasArg(SpinorX X, SpinorY Y, SpinorZ Z, SpinorW W, Functor f, int length)
__global__ void blasKernel(BlasArg< SpinorX, SpinorY, SpinorZ, SpinorW, Functor > arg)
const int length
int length[]
QudaGaugeParam param
Definition: pack_test.cpp:17
void apply(const cudaStream_t &stream)
TuneKey tuneKey() const
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
Definition: tune.cpp:271
#define warningQuda(...)
Definition: util_quda.h:84
BlasCuda(SpinorX &X, SpinorY &Y, SpinorZ &Z, SpinorW &W, Functor &f, int length, const size_t *bytes, const size_t *norm_bytes)
#define checkLength(a, b)
Definition: blas_quda.cu:25
int x[4]
unsigned long long blas_flops
Definition: blas_quda.cu:37
int tuningIter() const
int Z[4]
Definition: test_util.cpp:28
virtual ~BlasCuda()
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
Definition: complex_quda.h:843
void blasCuda(const double2 &a, const double2 &b, const double2 &c, cudaColorSpinorField &x, cudaColorSpinorField &y, cudaColorSpinorField &z, cudaColorSpinorField &w)
#define checkSpinor(a, b)
Definition: blas_quda.cu:15
#define checkCudaError()
Definition: util_quda.h:110
QudaTune getTuning()
Definition: util_quda.cpp:32