// blasKernel: device kernel that applies the element-wise functor `f` to
// corresponding elements of the four fields X, Y, Z and W.
//
// Template parameters (as used in the visible lines):
//   FloatN        - vector type of the per-thread working elements
//                   (presumably double2/float2/float4, matching the host
//                   dispatch; confirm)
//   M             - number of FloatN elements handed to `f` per site; the
//                   loop below calls f(x[j], y[j], z[j], w[j]) for j = 0..M-1
//   SpinorX/Y/Z/W - accessor types for the four fields
//   Functor       - callable invoked as f(x[j], y[j], z[j], w[j])
//
// NOTE(review): this is an incomplete extract. The leading integers on some
// lines (4, 5, 6, 8, 9, 18) look like line numbers from a scraped listing,
// not code, and the lines between them are missing. Not visible here:
//   - the end of the parameter list;
//   - the loop over sites;
//   - the declarations/loads of x, y, z, w;
//   - the stores back.
// The launch site for this kernel passes `length` as the argument after `f`.
4 template <
typename FloatN,
int M,
typename SpinorX,
typename SpinorY,
5 typename SpinorZ,
typename SpinorW,
typename Functor>
6 __global__
void blasKernel(SpinorX
X, SpinorY Y, SpinorZ
Z, SpinorW W, Functor f,
// Flat global thread index, computed from the x-dimension of the launch only.
8 unsigned int i = blockIdx.x*(blockDim.x) + threadIdx.x;
// Total number of threads in the grid (x-dimension only). Presumably used as
// the stride of a loop over sites that is not visible in this extract.
9 unsigned int gridSize = gridDim.x*blockDim.x;
// Apply the functor to each of the M vector elements of the current site.
18 for (
int j=0; j<M; j++) f(x[j], y[j], z[j], w[j]);
// Template parameters of the BlasCuda launcher class (the class head itself
// is not visible in this extract). They mirror blasKernel's parameters, and
// the class launches blasKernel<FloatN,M> with them.
// NOTE(review): leading integers (28, 29, 40, 41) look like scraped listing
// line numbers, not code; the lines between 29 and 40 are missing.
28 template <
typename FloatN,
int M,
typename SpinorX,
typename SpinorY,
29 typename SpinorZ,
typename SpinorW,
typename Functor>
// Backup buffers for the field data (X_h..W_h) and norm data
// (Xnorm_h..Wnorm_h) of the four fields.
// - They are set to 0 in the constructor's initializer list.
// - They are passed by address to the fields' save()/load() calls.
// Presumably they are host-side buffers (`_h` suffix); confirm in the
// Spinor save/load implementation.
40 char *X_h, *Y_h, *Z_h, *W_h;
41 char *Xnorm_h, *Ynorm_h, *Znorm_h, *Wnorm_h;
// Shared memory requested per thread by this launcher: none.
int sharedBytesPerThread() const { return 0; }
// Shared memory requested per block by this launcher: none, whatever the
// tuning parameters in `param` are.
int sharedBytesPerBlock(const TuneParam &param) const { return 0; }
// Sets param.shared_bytes for the block shape that advanceBlockDim() moves
// to next. The value is the larger of:
//   - per-thread bytes times the thread count of that block;
//   - the per-block requirement.
// NOTE(review): the opening brace, the return statement and the closing
// brace of this method are not present in this extract. Restore them from
// the original source rather than guessing the return value.
virtual bool advanceSharedBytes(TuneParam &param) const
  // Advance a copy, so only param.shared_bytes is modified below
  // (assumes advanceBlockDim takes its argument by reference; confirm).
  TuneParam next(param);
  advanceBlockDim(next);
  int nthreads = next.block.x * next.block.y * next.block.z;
  param.shared_bytes = sharedBytesPerThread()*nthreads > sharedBytesPerBlock(param) ?
    sharedBytesPerThread()*nthreads : sharedBytesPerBlock(param);
// Constructor (fragment):
// - stores the four field accessors and the functor;
// - sets every backup-buffer pointer to 0;
// - stores `length`.
// NOTE(review): listing lines 61 and 64 onward are missing from this extract.
// They hold the remaining parameters, the initializer-list start and the body.
// The leading integers are listing line numbers, not code.
60 BlasCuda(SpinorX &X, SpinorY &Y, SpinorZ &Z, SpinorW &W, Functor &f,
62 X(X), Y(Y), Z(Z), W(W), f(f), X_h(0), Y_h(0), Z_h(0), W_h(0),
63 Xnorm_h(0), Ynorm_h(0), Znorm_h(0), Wnorm_h(0), length(length)
// Builds and returns the TuneKey for this launch (fragment; the enclosing
// method header is not visible). The key combines:
//   - a volume string "<x0>x<x1>x<x2>x<x3>" built from blasConstants.x[0..3];
//   - the functor's typeid name;
//   - an aux string "stride=<blasConstants.stride>,prec=<X.Precision()>".
// Presumably this makes each combination of these values tuned separately.
// NOTE(review): only the first four entries of blasConstants.x go into the
// key, although they are filled up to QUDA_MAX_DIM. If QUDA_MAX_DIM > 4,
// the extra extents do not distinguish keys — confirm this is intended.
68 std::stringstream vol, aux;
69 vol << blasConstants.x[0] <<
"x";
70 vol << blasConstants.x[1] <<
"x";
71 vol << blasConstants.x[2] <<
"x";
72 vol << blasConstants.x[3];
73 aux <<
"stride=" << blasConstants.stride <<
",prec=" << X.Precision();
74 return TuneKey(vol.str(),
typeid(f).name(), aux.str());
// Launch (fragment; the enclosing method header is not visible).
// - Obtains the launch configuration from tuneLaunch().
// - Runs blasKernel<FloatN,M> with the tuned grid, block and dynamic
//   shared-memory size.
// `stream` is presumably a parameter of the enclosing method; callers invoke
// blas.apply(*blasStream).
// NOTE(review): no cudaGetLastError() check follows the launch in the
// visible lines, so a bad launch configuration would surface only at a later
// API call.
78 TuneParam tp =
tuneLaunch(*
this, blasTuning, verbosity);
79 blasKernel<FloatN,M> <<<tp.grid, tp.block, tp.shared_bytes, stream>>>
80 (X, Y, Z, W, f, length);
// Saves the contents of X, Y, Z and W into the backup buffers (fragment;
// the method header is not visible). Presumably this lets trial launches
// during autotuning be undone by the matching load() calls.
//
// bytes = X.Precision()
//       * (number of components per FloatN, i.e. sizeof(FloatN)/sizeof(FloatN::x))
//       * M * X.Stride()
// X.Precision() is presumably the storage size per real number in bytes
// (confirm).
// The `sizeof((FloatN*)0)->x` operand is unevaluated, so the null pointer is
// never dereferenced.
// NOTE(review): `bytes` is derived from X alone but used for all four fields,
// which assumes Y, Z and W share X's layout. `norm_bytes` is defined on a
// line missing from this extract (listing line 85). Leading integers are
// listing line numbers, not code.
84 size_t bytes = X.Precision()*(
sizeof(
FloatN)/
sizeof(((
FloatN*)0)->x))*M*X.Stride();
86 X.save(&X_h, &Xnorm_h, bytes, norm_bytes);
87 Y.save(&Y_h, &Ynorm_h, bytes, norm_bytes);
88 Z.save(&Z_h, &Znorm_h, bytes, norm_bytes);
89 W.save(&W_h, &Wnorm_h, bytes, norm_bytes);
// Restores X, Y, Z and W from the backup buffers (fragment; the method header
// is not visible). This mirrors the save() sequence and uses the same `bytes`
// formula:
//   X.Precision() * (components per FloatN) * M * X.Stride()
// NOTE(review): as with the save path, `bytes` comes from X alone and
// `norm_bytes` is defined on a missing line (listing line 94).
94-less gaps and leading integers are listing artifacts, not code.
93 size_t bytes = X.Precision()*(
sizeof(
FloatN)/
sizeof(((
FloatN*)0)->x))*M*X.Stride();
95 X.load(&X_h, &Xnorm_h, bytes, norm_bytes);
96 Y.load(&Y_h, &Ynorm_h, bytes, norm_bytes);
97 Z.load(&Z_h, &Znorm_h, bytes, norm_bytes);
98 W.load(&W_h, &Wnorm_h, bytes, norm_bytes);
// Byte count reported for a launch (fragment; the method header and the
// definition of the local `bytes` are not visible). The value is the
// functor's f.streams() times `bytes` times `length`.
// Presumably f.streams() is the number of fields the functor reads or writes
// per site; confirm in the functor definitions.
105 return f.streams()*bytes*length; }
// blasCuda: host-side dispatcher that runs the functor template `Functor`
// over the fields x, y, z, w. It selects vector and storage types by
// precision and spin count (fragment). Many lines are missing from this
// extract, and the leading integers are listing line numbers, not code.
//
// Template parameters:
//   Functor    - class template Functor<Float, FloatN>. It is built from the
//                coefficients a, b, c, which are narrowed to float2 on the
//                float-based paths.
//   writeX..W  - forwarded to the Spinor accessor types. Presumably they
//                select which fields are written back (confirm in Spinor).
// Parameters:
//   a, b, c    - complex coefficients forwarded to the functor.
//   x, y, z, w - the fields. x supplies the extents, stride, spin count and
//                flop count. The short-storage branches take their length
//                from y.Volume().
// `M` is used below but declared on a line not visible here.
111 template <
template <
typename Float,
typename FloatN>
class Functor,
112 int writeX,
int writeY,
int writeZ,
int writeW>
113 void blasCuda(
const double2 &a,
const double2 &b,
const double2 &c,
114 cudaColorSpinorField &
x, cudaColorSpinorField &y,
115 cudaColorSpinorField &z, cudaColorSpinorField &w) {
// Publish x's per-dimension extents and stride in blasConstants. They are
// read when the tuning key is built.
120 for (
int d=0; d<
QUDA_MAX_DIM; d++) blasConstants.x[d] = x.X()[d];
121 blasConstants.stride = x.Stride();
// Process the even and odd parity subfields separately by recursing.
// NOTE(review): the condition guarding this recursion is not visible
// (presumably a full-site-subset check), and neither is anything that follows
// it. Confirm the code returns afterwards. Otherwise the full field would be
// processed again and its flops counted again.
124 blasCuda<Functor,writeX,writeY,writeZ,writeW>
125 (a, b, c, x.Even(), y.Even(), z.Even(), w.Even());
126 blasCuda<Functor,writeX,writeY,writeZ,writeW>
127 (a, b, c, x.Odd(), y.Odd(), z.Odd(), w.Odd());
// double2 path (branch condition and the Spinor X..W declarations are not
// visible). Length passed is x.Length()/(2*M).
137 Functor<double2, double2> f(a,b,c);
139 Spinor<double2,double2,double2,M,writeX,0>,
Spinor<double2,double2,double2,M,writeY,1>,
140 Spinor<double2,double2,double2,M,writeZ,2>,
Spinor<double2,double2,double2,M,writeW,3>,
141 Functor<double2, double2> > blas(X, Y, Z, W, f, x.Length()/(2*M));
142 blas.apply(*blasStream);
// float storage paths (presumably single precision; the outer condition is not
// visible):
//   - 4 spins: float4 with length x.Length()/(4*M);
//   - other spin counts: float2 with length x.Length()/(2*M).
145 if (x.Nspin() == 4) {
150 Functor<float2, float4> f(make_float2(a.x, a.y), make_float2(b.x, b.y), make_float2(c.x, c.y));
152 Spinor<float4,float4,float4,M,writeX,0>,
Spinor<float4,float4,float4,M,writeY,1>,
153 Spinor<float4,float4,float4,M,writeZ,2>,
Spinor<float4,float4,float4,M,writeW,3>,
154 Functor<float2, float4> > blas(X, Y, Z, W, f, x.Length()/(4*M));
155 blas.apply(*blasStream);
161 Functor<float2, float2> f(make_float2(a.x, a.y), make_float2(b.x, b.y), make_float2(c.x, c.y));
163 Spinor<float2,float2,float2,M,writeX,0>,
Spinor<float2,float2,float2,M,writeY,1>,
164 Spinor<float2,float2,float2,M,writeZ,2>,
Spinor<float2,float2,float2,M,writeW,3>,
165 Functor<float2, float2> > blas(X, Y, Z, W, f, x.Length()/(2*M));
166 blas.apply(*blasStream);
// short storage paths (presumably half precision; the outer condition is not
// visible). Computation happens in float4/float2 and storage is short4/short2.
// NOTE(review): these branches hard-code 6 (4 spins) and 3 (1 spin) where the
// other paths use M; confirm they agree for every functor. They also take the
// length from y.Volume() rather than from x like the other branches, which
// assumes x and y have the same volume.
174 Functor<float2, float4> f(make_float2(a.x, a.y), make_float2(b.x, b.y), make_float2(c.x, c.y));
176 Spinor<float4,float4,short4,6,writeX,0>,
Spinor<float4,float4,short4,6,writeY,1>,
177 Spinor<float4,float4,short4,6,writeZ,2>,
Spinor<float4,float4,short4,6,writeW,3>,
178 Functor<float2, float4> > blas(X, Y, Z, W, f, y.Volume());
179 blas.apply(*blasStream);
180 }
else if (x.Nspin() == 1) {
185 Functor<float2, float2> f(make_float2(a.x, a.y), make_float2(b.x, b.y), make_float2(c.x, c.y));
187 Spinor<float2,float2,short2,3,writeX,0>,
Spinor<float2,float2,short2,3,writeY,1>,
188 Spinor<float2,float2,short2,3,writeZ,2>,
Spinor<float2,float2,short2,3,writeW,3>,
189 Functor<float2, float2> > blas(X, Y, Z, W, f, y.Volume());
190 blas.apply(*blasStream);
191 }
// Any other spin count on this path is rejected.
else {
errorQuda(
"ERROR: nSpin=%d is not supported\n", x.Nspin()); }
// Add to blas_flops the double-precision functor's flop count times
// x.RealLength().
195 blas_flops += Functor<double2,double2>::flops()*(
unsigned long long)x.RealLength();