QUDA  v0.5.0
A library for QCD on GPUs
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
blas_core.h
Go to the documentation of this file.
1 
/**
 * Generic element-wise BLAS kernel over four spinor accessors.
 *
 * Each thread walks the index range [0, length) in a grid-stride loop. For
 * every index it loads M FloatN elements from each of X, Y, Z and W, applies
 * the functor f to the four elements at each of the M positions, and then
 * saves all four register arrays back through the same accessors.
 *
 * @tparam FloatN  vector type held in registers for each element
 * @tparam M       number of FloatN elements loaded/saved per index
 * @param X,Y,Z,W  accessor objects providing load(FloatN[], i) / save(FloatN[], i)
 * @param f        functor invoked as f(x[j], y[j], z[j], w[j])
 * @param length   number of indices to process
 */
template <typename FloatN, int M, typename SpinorX, typename SpinorY,
          typename SpinorZ, typename SpinorW, typename Functor>
__global__ void blasKernel(SpinorX X, SpinorY Y, SpinorZ Z, SpinorW W, Functor f,
                           int length) {
  // total number of threads in the launch; used as the loop stride
  const unsigned int stride = gridDim.x*blockDim.x;

  for (unsigned int idx = blockIdx.x*(blockDim.x) + threadIdx.x; idx < length; idx += stride) {
    FloatN x[M], y[M], z[M], w[M];

    // fetch the operands for this index
    X.load(x, idx);
    Y.load(y, idx);
    Z.load(z, idx);
    W.load(w, idx);

#pragma unroll
    for (int k=0; k<M; k++) {
      f(x[k], y[k], z[k], w[k]);
    }

    // all four fields are written back unconditionally at this level
    X.save(x, idx);
    Y.save(y, idx);
    Z.save(z, idx);
    W.save(w, idx);
  }
}
27 
/**
 * Tunable wrapper that launches blasKernel<FloatN,M> on four spinor accessors.
 *
 * The object keeps references to the accessors and the functor, obtains the
 * launch geometry from tuneLaunch(), and provides preTune()/postTune() hooks
 * that save and reload the four fields through host-side buffers.
 */
template <typename FloatN, int M, typename SpinorX, typename SpinorY,
          typename SpinorZ, typename SpinorW, typename Functor>
class BlasCuda : public Tunable {

private:
  SpinorX &X;
  SpinorY &Y;
  SpinorZ &Z;
  SpinorW &W;

  // Host pointers used to back up the fields while tuning.
  // They are held here instead of inside the Spinors because of the
  // Tesla kernel-argument length restriction.
  char *X_h, *Y_h, *Z_h, *W_h;
  char *Xnorm_h, *Ynorm_h, *Znorm_h, *Wnorm_h;

  Functor &f;
  const int length;

  // This kernel requests no shared memory, per thread or per block.
  int sharedBytesPerThread() const { return 0; }
  int sharedBytesPerBlock(const TuneParam &param) const { return 0; }

  /**
   * Sets param.shared_bytes to the larger of the per-thread requirement
   * (scaled by the thread count of the *next* block size) and the per-block
   * requirement. Always returns false.
   */
  virtual bool advanceSharedBytes(TuneParam &param) const
  {
    TuneParam lookahead(param);
    advanceBlockDim(lookahead); // advance a copy to learn the next blockDim
    const int threadsPerBlock = lookahead.block.x * lookahead.block.y * lookahead.block.z;

    const int fromThreads = sharedBytesPerThread()*threadsPerBlock;
    const int fromBlock = sharedBytesPerBlock(param);
    if (fromThreads > fromBlock) {
      param.shared_bytes = fromThreads;
    } else {
      param.shared_bytes = fromBlock;
    }
    return false;
  }

public:
  /**
   * @param X,Y,Z,W  spinor accessors (stored by reference)
   * @param f        functor applied by the kernel (stored by reference)
   * @param length   number of indices the kernel iterates over
   */
  BlasCuda(SpinorX &X, SpinorY &Y, SpinorZ &Z, SpinorW &W, Functor &f,
           int length) :
    X(X), Y(Y), Z(Z), W(W),
    X_h(0), Y_h(0), Z_h(0), W_h(0),
    Xnorm_h(0), Ynorm_h(0), Znorm_h(0), Wnorm_h(0),
    f(f), length(length)
  { }

  virtual ~BlasCuda() { }

  /**
   * Builds the tuning key from the dimensions and stride held in
   * blasConstants, the functor's type name and the precision of X.
   */
  TuneKey tuneKey() const {
    std::stringstream volume, extra;
    // "x0xx1xx2xx3"-style volume string: four extents joined by 'x'
    for (int d=0; d<4; d++) {
      if (d > 0) volume << "x";
      volume << blasConstants.x[d];
    }
    extra << "stride=" << blasConstants.stride << ",prec=" << X.Precision();
    return TuneKey(volume.str(), typeid(f).name(), extra.str());
  }

  /** Looks up the launch parameters and launches blasKernel on the given stream. */
  void apply(const cudaStream_t &stream) {
    TuneParam launch = tuneLaunch(*this, blasTuning, verbosity);
    blasKernel<FloatN,M> <<<launch.grid, launch.block, launch.shared_bytes, stream>>>
      (X, Y, Z, W, f, length);
  }

  /** Saves all four fields into the host backup buffers before tuning. */
  void preTune() {
    const size_t fieldBytes = X.Precision()*(sizeof(FloatN)/sizeof(((FloatN*)0)->x))*M*X.Stride();
    size_t normBytes = 0;
    if (X.Precision() == QUDA_HALF_PRECISION) normBytes = sizeof(float)*length;

    X.save(&X_h, &Xnorm_h, fieldBytes, normBytes);
    Y.save(&Y_h, &Ynorm_h, fieldBytes, normBytes);
    Z.save(&Z_h, &Znorm_h, fieldBytes, normBytes);
    W.save(&W_h, &Wnorm_h, fieldBytes, normBytes);
  }

  /** Reloads all four fields from the host backup buffers after tuning. */
  void postTune() {
    const size_t fieldBytes = X.Precision()*(sizeof(FloatN)/sizeof(((FloatN*)0)->x))*M*X.Stride();
    size_t normBytes = 0;
    if (X.Precision() == QUDA_HALF_PRECISION) normBytes = sizeof(float)*length;

    X.load(&X_h, &Xnorm_h, fieldBytes, normBytes);
    Y.load(&Y_h, &Ynorm_h, fieldBytes, normBytes);
    Z.load(&Z_h, &Znorm_h, fieldBytes, normBytes);
    W.load(&W_h, &Wnorm_h, fieldBytes, normBytes);
  }

  /** Functor flop count scaled by components per FloatN, length and M. */
  long long flops() const {
    return f.flops()*(sizeof(FloatN)/sizeof(((FloatN*)0)->x))*length*M;
  }

  /**
   * Functor stream count times the bytes moved per index times length.
   * In half precision an extra sizeof(float) per index is included
   * (presumably the norm entry -- confirm against the Spinor accessor).
   */
  long long bytes() const {
    size_t perIndex = X.Precision()*(sizeof(FloatN)/sizeof(((FloatN*)0)->x))*M;
    if (X.Precision() == QUDA_HALF_PRECISION) {
      perIndex += sizeof(float);
    }
    return f.streams()*perIndex*length;
  }
};
107 
// Host-side driver: selects the vector type and element count for the
// field's precision / spin, builds the Functor from the coefficients a, b, c,
// wraps everything in a BlasCuda object and applies it on *blasStream.
//
// Template parameters:
//   Functor        - functor template, instantiated below with a pair of
//                    vector types and constructed from (a, b, c)
//   writeX..writeW - not referenced in the lines visible here other than being
//                    forwarded on recursion. NOTE(review): presumably consumed
//                    by the spinor accessor declarations that this listing
//                    elides -- confirm against the full source.
//
// NOTE(review): the accessors X, Y, Z, W passed to BlasCuda are not declared in
// any visible line; the listing's own numbering skips lines inside each branch
// (e.g. 132 -> 137, 138 -> 141), so their declarations and the remaining
// BlasCuda template arguments are presumably in the elided lines.
111 template <template <typename Float, typename FloatN> class Functor,
112  int writeX, int writeY, int writeZ, int writeW>
113 void blasCuda(const double2 &a, const double2 &b, const double2 &c,
114  cudaColorSpinorField &x, cudaColorSpinorField &y,
115  cudaColorSpinorField &z, cudaColorSpinorField &w) {
// x is the reference field: y, z and w are each checked against it
116  checkSpinor(x, y);
117  checkSpinor(x, z);
118  checkSpinor(x, w);
119 
// publish the dimensions and stride of x in blasConstants
// (BlasCuda::tuneKey reads these to build the tuning key)
120  for (int d=0; d<QUDA_MAX_DIM; d++) blasConstants.x[d] = x.X()[d];
121  blasConstants.stride = x.Stride();
122 
// full-site fields: recurse on the Even() and Odd() subfields, then return;
// each recursive call re-executes the blasConstants setup above for its subfield
123  if (x.SiteSubset() == QUDA_FULL_SITE_SUBSET) {
124  blasCuda<Functor,writeX,writeY,writeZ,writeW>
125  (a, b, c, x.Even(), y.Even(), z.Even(), w.Even());
126  blasCuda<Functor,writeX,writeY,writeZ,writeW>
127  (a, b, c, x.Odd(), y.Odd(), z.Odd(), w.Odd());
128  return;
129  }
130 
// double precision: double2 vectors, M = 1, x.Length()/(2*M) kernel indices
131  if (x.Precision() == QUDA_DOUBLE_PRECISION) {
132  const int M = 1;
137  Functor<double2, double2> f(a,b,c);
138  BlasCuda<double2,M,
141  Functor<double2, double2> > blas(X, Y, Z, W, f, x.Length()/(2*M));
142  blas.apply(*blasStream);
// single precision: coefficients are narrowed to float2
143  } else if (x.Precision() == QUDA_SINGLE_PRECISION) {
144  const int M = 1;
// Nspin == 4 uses float4 vectors (Length()/(4*M) indices);
// every other Nspin uses float2 vectors (Length()/(2*M) indices)
145  if (x.Nspin() == 4) {
150  Functor<float2, float4> f(make_float2(a.x, a.y), make_float2(b.x, b.y), make_float2(c.x, c.y));
151  BlasCuda<float4,M,
154  Functor<float2, float4> > blas(X, Y, Z, W, f, x.Length()/(4*M));
155  blas.apply(*blasStream);
156  } else {
161  Functor<float2, float2> f(make_float2(a.x, a.y), make_float2(b.x, b.y), make_float2(c.x, c.y));
162  BlasCuda<float2,M,
165  Functor<float2, float2> > blas(X, Y, Z, W, f, x.Length()/(2*M));
166  blas.apply(*blasStream);
167  }
// any precision that is neither double nor single lands here
// (presumably QUDA_HALF_PRECISION -- confirm); the kernel length is y.Volume()
168  } else {
// Nspin == 4: float4 registers with M = 6 per index
169  if (x.Nspin() == 4){ //wilson
174  Functor<float2, float4> f(make_float2(a.x, a.y), make_float2(b.x, b.y), make_float2(c.x, c.y));
175  BlasCuda<float4, 6,
178  Functor<float2, float4> > blas(X, Y, Z, W, f, y.Volume());
179  blas.apply(*blasStream);
// Nspin == 1: float2 registers with M = 3 per index
180  } else if (x.Nspin() == 1) {//staggered
185  Functor<float2, float2> f(make_float2(a.x, a.y), make_float2(b.x, b.y), make_float2(c.x, c.y));
186  BlasCuda<float2, 3,
189  Functor<float2, float2> > blas(X, Y, Z, W, f, y.Volume());
190  blas.apply(*blasStream);
191  } else { errorQuda("ERROR: nSpin=%d is not supported\n", x.Nspin()); }
// extra traffic counted only on this branch: one float per site per stream
// (presumably the norm array -- confirm)
192  blas_bytes += Functor<double2,double2>::streams()*(unsigned long long)x.Volume()*sizeof(float);
193  }
// global byte / flop counters; streams() and flops() are always taken from the
// Functor<double2,double2> instantiation, whatever precision was dispatched
194  blas_bytes += Functor<double2,double2>::streams()*(unsigned long long)x.RealLength()*x.Precision();
195  blas_flops += Functor<double2,double2>::flops()*(unsigned long long)x.RealLength();
196 
// the kernel launch itself is not checked in BlasCuda::apply; errors are checked here
197  checkCudaError();
198 }
199