QUDA  v0.7.0
A library for QCD on GPUs
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
blas_core.h
Go to the documentation of this file.
1 
/**
 * Argument bundle for the generic blas kernel.
 *
 * Holds the four spinor accessors (X, Y, Z, W), the element-wise
 * functor f applied to them, and the number of indices to process.
 * The whole struct is copied by value into the kernel launch.
 */
template <typename SpinorX, typename SpinorY, typename SpinorZ,
          typename SpinorW, typename Functor>
struct BlasArg {
  SpinorX X;         // first operand accessor
  SpinorY Y;         // second operand accessor
  SpinorZ Z;         // third operand accessor
  SpinorW W;         // fourth operand accessor
  Functor f;         // operation applied element-wise to (x, y, z, w)
  const int length;  // number of indices the kernel iterates over

  BlasArg(SpinorX X, SpinorY Y, SpinorZ Z, SpinorW W, Functor f, int length)
    : X(X),
      Y(Y),
      Z(Z),
      W(W),
      f(f),
      length(length)
  {
  }
};
16 
/**
 * Generic blas kernel with four loads and four saves per index.
 *
 * For each index i in [0, arg.length) it loads M FloatN values from each
 * of arg.X, arg.Y, arg.Z and arg.W, applies arg.f element-wise to the
 * j-th values of all four, and saves all four arrays back at index i.
 *
 * Launch layout: any 1-D grid. Each thread starts at its global thread id
 * and strides by gridDim.x*blockDim.x (grid-stride loop), so the whole
 * range is covered regardless of grid size. No shared memory is used.
 *
 * NOTE(review): whether a save() really writes to memory is decided by the
 * Spinor types; presumably selected by the writeX..writeW flags of the
 * host-side caller. Confirm against the Spinor implementation.
 */
template <typename FloatN, int M, typename SpinorX, typename SpinorY,
          typename SpinorZ, typename SpinorW, typename Functor>
__global__ void blasKernel(BlasArg<SpinorX, SpinorY, SpinorZ, SpinorW, Functor> arg) {
  unsigned int i = blockIdx.x*(blockDim.x) + threadIdx.x;
  unsigned int gridSize = gridDim.x*blockDim.x;
  while (i < arg.length) {
    FloatN x[M], y[M], z[M], w[M];
    arg.X.load(x, i);
    arg.Y.load(y, i);
    arg.Z.load(z, i);
    arg.W.load(w, i);

#pragma unroll
    for (int j=0; j<M; j++) arg.f(x[j], y[j], z[j], w[j]);

    arg.X.save(x, i);
    arg.Y.save(y, i);
    arg.Z.save(z, i);
    arg.W.save(w, i);
    i += gridSize;
  }
}
42 
/**
 * Tunable launcher for blasKernel<FloatN, M>.
 *
 * Owns the kernel argument bundle and launches the kernel with a launch
 * geometry chosen by tuneLaunch(). Before and after tuning it backs up and
 * restores the four fields through host buffers, so repeated tuning launches
 * do not leave the fields modified.
 *
 * Template parameters:
 *   FloatN  - vector register type the kernel works on (e.g. double2, float4)
 *   M       - number of FloatN values processed per index
 *   SpinorX..SpinorW - accessor types for the four fields
 *   Functor - element-wise operation applied by the kernel
 */
template <typename FloatN, int M, typename SpinorX, typename SpinorY,
          typename SpinorZ, typename SpinorW, typename Functor>
class BlasCuda : public Tunable {

private:
  // Kernel argument bundle passed by value to blasKernel. It is declared
  // first because the constructor initializes it before the host pointers.
  BlasArg<SpinorX, SpinorY, SpinorZ, SpinorW, Functor> arg;

  // host pointers used for backing up fields when tuning
  // these can't be curried into the Spinors because of Tesla argument length restriction
  char *X_h, *Y_h, *Z_h, *W_h;
  char *Xnorm_h, *Ynorm_h, *Znorm_h, *Wnorm_h;
  const size_t *bytes_;       // per-field byte counts, index 0..3 = X..W
  const size_t *norm_bytes_;  // per-field norm byte counts, same indexing

  unsigned int sharedBytesPerThread() const { return 0; }
  unsigned int sharedBytesPerBlock(const TuneParam &param) const { return 0; }

  // Sets param.shared_bytes for the next block size and always returns
  // false, so the tuner never iterates over shared-memory sizes here.
  virtual bool advanceSharedBytes(TuneParam &param) const
  {
    TuneParam next(param);
    advanceBlockDim(next); // to get next blockDim
    int nthreads = next.block.x * next.block.y * next.block.z;
    param.shared_bytes = sharedBytesPerThread()*nthreads > sharedBytesPerBlock(param) ?
      sharedBytesPerThread()*nthreads : sharedBytesPerBlock(param);
    return false;
  }

public:
  /**
   * @param X,Y,Z,W    field accessors (copied into the argument bundle)
   * @param f          element-wise functor (copied)
   * @param length     number of indices the kernel processes
   * @param bytes      array of 4 byte counts for X..W, used for backup
   * @param norm_bytes array of 4 norm byte counts for X..W, used for backup
   * The two arrays are stored by pointer and must outlive this object.
   */
  BlasCuda(SpinorX &X, SpinorY &Y, SpinorZ &Z, SpinorW &W, Functor &f,
           int length, const size_t *bytes, const size_t *norm_bytes) :
    arg(X, Y, Z, W, f, length), X_h(0), Y_h(0), Z_h(0), W_h(0),
    Xnorm_h(0), Ynorm_h(0), Znorm_h(0), Wnorm_h(0), bytes_(bytes), norm_bytes_(norm_bytes) { }

  virtual ~BlasCuda() { }

  // Tuning cache key: field volume string, functor type name, auxiliary string.
  inline TuneKey tuneKey() const {
    return TuneKey(blasStrings.vol_str, typeid(arg.f).name(), blasStrings.aux_str);
  }

  // Launch blasKernel on `stream` with the tuned launch parameters.
  inline void apply(const cudaStream_t &stream) {
    TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
    blasKernel<FloatN,M> <<<tp.grid, tp.block, tp.shared_bytes, stream>>>(arg);
  }

  // Back up all four fields (and their norms) into the host buffers.
  void preTune() {
    arg.X.save(&X_h, &Xnorm_h, bytes_[0], norm_bytes_[0]);
    arg.Y.save(&Y_h, &Ynorm_h, bytes_[1], norm_bytes_[1]);
    arg.Z.save(&Z_h, &Znorm_h, bytes_[2], norm_bytes_[2]);
    arg.W.save(&W_h, &Wnorm_h, bytes_[3], norm_bytes_[3]);
  }

  // Restore all four fields (and their norms) from the host buffers.
  void postTune() {
    arg.X.load(&X_h, &Xnorm_h, bytes_[0], norm_bytes_[0]);
    arg.Y.load(&Y_h, &Ynorm_h, bytes_[1], norm_bytes_[1]);
    arg.Z.load(&Z_h, &Znorm_h, bytes_[2], norm_bytes_[2]);
    arg.W.load(&W_h, &Wnorm_h, bytes_[3], norm_bytes_[3]);
  }

  // Functor flops per real component x components per FloatN x M x length.
  long long flops() const { return arg.f.flops()*(sizeof(FloatN)/sizeof(((FloatN*)0)->x))*arg.length*M; }

  // Bytes moved: per index, M FloatN vectors at the storage precision of X,
  // plus one float per index for half precision (presumably the norm),
  // multiplied by the number of streamed fields and by length.
  long long bytes() const {
    size_t bytes_per_index = arg.X.Precision()*(sizeof(FloatN)/sizeof(((FloatN*)0)->x))*M;
    if (arg.X.Precision() == QUDA_HALF_PRECISION) bytes_per_index += sizeof(float);
    return arg.f.streams()*bytes_per_index*arg.length;
  }

  int tuningIter() const { return 3; }
};
108 
// Host driver for the generic blas operations.
//
// Applies Functor element-wise to the fields x, y, z and w by building a
// BlasCuda launcher and calling apply(*blasStream). Dispatch is on
// x.Precision() (double / single / any other, presumably half) and on
// x.Nspin().
//
// a, b, c            - scalar coefficients handed to the Functor constructor
//                      (narrowed to float2 for the single and half paths)
// writeX..writeW     - template flags forwarded to the (elided) Spinor
//                      construction; presumably select which fields are
//                      written back. TODO confirm.
//
// NOTE(review): this Doxygen listing drops the lines that construct the
// X, Y, Z, W Spinor objects and the Spinor template arguments of each
// BlasCuda<...> (embedded numbers skip 148-151, 154-155, 162-165, ...).
112 template <template <typename Float, typename FloatN> class Functor,
113  int writeX, int writeY, int writeZ, int writeW>
114 inline void blasCuda(const double2 &a, const double2 &b, const double2 &c,
115  cudaColorSpinorField &x, cudaColorSpinorField &y,
116  cudaColorSpinorField &z, cudaColorSpinorField &w) {
117 
// NOTE(review): `head` is not referenced anywhere in the visible body and
// appears to be unused.
118  static TimeProfile head("head");
119 
// Validate y, z and w against x (checkSpinor is a macro defined elsewhere).
120  checkSpinor(x, y);
121  checkSpinor(x, z);
122  checkSpinor(x, w);
123 
// Non-native field orders are rejected with a warning; nothing is computed.
124  if (!x.isNative()) {
125  warningQuda("Blas on non-native fields is not supported\n");
126  return;
127  }
128 
// These strings feed BlasCuda::tuneKey().
129  blasStrings.vol_str = x.VolString();
130  blasStrings.aux_str = x.AuxString();
131 
// Full-subset fields are processed as two parity halves by recursion. The
// early return means flop/byte counters are incremented by the recursive calls.
132  if (x.SiteSubset() == QUDA_FULL_SITE_SUBSET) {
133  blasCuda<Functor,writeX,writeY,writeZ,writeW>
134  (a, b, c, x.Even(), y.Even(), z.Even(), w.Even());
135  blasCuda<Functor,writeX,writeY,writeZ,writeW>
136  (a, b, c, x.Odd(), y.Odd(), z.Odd(), w.Odd());
137  return;
138  }
139 
140  // FIXME: use traits to encapsulate register type for shorts -
141  // will reduce template type parameters from 3 to 2
142 
// Per-field sizes (X..W order) used by BlasCuda to back up fields while tuning.
143  size_t bytes[] = {x.Bytes(), y.Bytes(), z.Bytes(), w.Bytes()};
144  size_t norm_bytes[] = {x.NormBytes(), y.NormBytes(), z.NormBytes(), w.NormBytes()};
145 
146  if (x.Precision() == QUDA_DOUBLE_PRECISION) {
// Double: one double2 per index, so length = reals / 2.
147  const int M = 1;
152  Functor<double2, double2> f(a,b,c);
153  BlasCuda<double2,M,
156  Functor<double2, double2> > blas(X, Y, Z, W, f, x.Length()/(2*M), bytes, norm_bytes);
157  blas.apply(*blasStream);
158  } else if (x.Precision() == QUDA_SINGLE_PRECISION) {
159  const int M = 1;
160  if (x.Nspin() == 4) {
// Single, Nspin == 4: float4 registers, length = reals / 4.
161 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC)
166  Functor<float2, float4> f(make_float2(a.x, a.y), make_float2(b.x, b.y), make_float2(c.x, c.y));
167  BlasCuda<float4,M,
170  Functor<float2, float4> > blas(X, Y, Z, W, f, x.Length()/(4*M), bytes, norm_bytes);
171  blas.apply(*blasStream);
172 #else
173  errorQuda("blas has not been built for Nspin=%d fields", x.Nspin());
174 #endif
175  } else {
// Single, any Nspin other than 4: float2 registers, length = reals / 2.
// NOTE(review): unlike the half-precision path below, no Nspin == 1 check
// is made here.
176 #ifdef GPU_STAGGERED_DIRAC
181  Functor<float2, float2> f(make_float2(a.x, a.y), make_float2(b.x, b.y), make_float2(c.x, c.y));
182  BlasCuda<float2,M,
185  Functor<float2, float2> > blas(X, Y, Z, W, f, x.Length()/(2*M), bytes, norm_bytes);
186  blas.apply(*blasStream);
187 #else
188  errorQuda("blas has not been built for Nspin=%d fields", x.Nspin());
189 #endif
190  }
191  } else {
// Remaining precision (presumably half): one index per site, with M = 6
// float4 (Nspin 4) or M = 3 float2 (Nspin 1) values per index.
// NOTE(review): length is taken from y.Volume() while the other branches use x;
// presumably equal after checkSpinor. Confirm.
192  if (x.Nspin() == 4){ //wilson
193 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC)
198  Functor<float2, float4> f(make_float2(a.x, a.y), make_float2(b.x, b.y), make_float2(c.x, c.y));
199  BlasCuda<float4, 6,
202  Functor<float2, float4> > blas(X, Y, Z, W, f, y.Volume(), bytes, norm_bytes);
203  blas.apply(*blasStream);
204 #else
205  errorQuda("blas has not been built for Nspin=%d fields", x.Nspin());
206 #endif
207  } else if (x.Nspin() == 1) {//staggered
208 #ifdef GPU_STAGGERED_DIRAC
213  Functor<float2, float2> f(make_float2(a.x, a.y), make_float2(b.x, b.y), make_float2(c.x, c.y));
214  BlasCuda<float2, 3,
217  Functor<float2, float2> > blas(X, Y, Z, W, f, y.Volume(), bytes, norm_bytes);
218  blas.apply(*blasStream);
219 #else
220  errorQuda("blas has not been built for Nspin=%d fields", x.Nspin());
221 #endif
222  } else { errorQuda("ERROR: nSpin=%d is not supported\n", x.Nspin()); }
// Extra traffic of one float per site per stream (presumably the half-precision norm).
223  blas_bytes += Functor<double2,double2>::streams()*(unsigned long long)x.Volume()*sizeof(float);
224  }
225 
// Global counters. The double2 instantiation's static streams()/flops()
// are used for every precision.
226  blas_bytes += Functor<double2,double2>::streams()*(unsigned long long)x.RealLength()*x.Precision();
227  blas_flops += Functor<double2,double2>::flops()*(unsigned long long)x.RealLength();
228 
229  checkCudaError();
230 }
231 
void blasCuda(const double2 &a, const double2 &b, const double2 &c, cudaColorSpinorField &x, cudaColorSpinorField &y, cudaColorSpinorField &z, cudaColorSpinorField &w)
Definition: blas_core.h:114
BlasArg(SpinorX X, SpinorY Y, SpinorZ Z, SpinorW W, Functor f, int length)
Definition: blas_core.h:13
int y[4]
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:20
#define errorQuda(...)
Definition: util_quda.h:73
unsigned long long blas_bytes
Definition: blas_quda.cu:38
cudaStream_t * streams
cudaStream_t * stream
__global__ void blasKernel(BlasArg< SpinorX, SpinorY, SpinorZ, SpinorW, Functor > arg)
Definition: blas_core.h:22
int length[]
QudaGaugeParam param
Definition: pack_test.cpp:17
SpinorX X
Definition: blas_core.h:7
SpinorZ Z
Definition: blas_core.h:9
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
Definition: tune.cpp:271
#define warningQuda(...)
Definition: util_quda.h:84
void preTune()
Definition: blas_core.h:87
long long bytes() const
Definition: blas_core.h:102
int x[4]
virtual ~BlasCuda()
Definition: blas_core.h:76
const int length
Definition: blas_core.h:12
void postTune()
Definition: blas_core.h:94
unsigned long long blas_flops
Definition: blas_quda.cu:37
SpinorW W
Definition: blas_core.h:10
int Z[4]
Definition: test_util.cpp:28
int tuningIter() const
Definition: blas_core.h:106
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
Definition: complex_quda.h:843
#define checkSpinor(a, b)
Definition: blas_quda.cu:15
#define checkCudaError()
Definition: util_quda.h:110
long long flops() const
Definition: blas_core.h:101
QudaTune getTuning()
Definition: util_quda.cpp:32
BlasCuda(SpinorX &X, SpinorY &Y, SpinorZ &Z, SpinorW &W, Functor &f, int length, const size_t *bytes, const size_t *norm_bytes)
Definition: blas_core.h:71
Functor f
Definition: blas_core.h:11
void apply(const cudaStream_t &stream)
Definition: blas_core.h:82
TuneKey tuneKey() const
Definition: blas_core.h:78
SpinorY Y
Definition: blas_core.h:8