// Template header and constructor initialiser list of the kernel-argument
// bundle.  The type is parameterised on one accessor type per operand
// (SpinorX..SpinorW) plus the Functor type.
// NOTE(review): the struct declaration, its data members and the constructor
// signature are not visible in this listing.
4 template <
typename SpinorX,
typename SpinorY,
typename SpinorZ,
5 typename SpinorW,
typename Functor>
// Each member is initialised from the constructor parameter of the same name;
// the constructor body itself is empty.
14 : X(X), Y(Y), Z(Z), W(W), f(f), length(length) { ; }
// Fragment of the generic BLAS kernel.  FloatN is the per-thread vector type
// and M the number of FloatN elements each thread holds per operand.
// NOTE(review): the kernel signature, the loads, the enclosing loop over i and
// the stores are not visible in this listing.
20 template <
typename FloatN,
int M,
typename SpinorX,
typename SpinorY,
21 typename SpinorZ,
typename SpinorW,
typename Functor>
// Flat global thread index; only the x components of the launch geometry are used.
23 unsigned int i = blockIdx.x*(blockDim.x) + threadIdx.x;
// Total number of threads in the grid (blocks in x times threads per block in x).
// Presumably the stride of a loop over i that is not visible here -- TODO confirm.
24 unsigned int gridSize = gridDim.x*blockDim.x;
// Per-thread working arrays: M FloatN elements for each of the four operands.
26 FloatN
x[M],
y[M], z[M], w[M];
// Apply the functor to each of the M element tuples held by this thread.
33 for (
int j=0; j<M; j++) arg.
f(x[j], y[j], z[j], w[j]);
// Template header and some data members of the host-side driver class that
// launches blasKernel.
// NOTE(review): the class line, the `arg` member and the `bytes_` member
// (all used by code further down) are not visible in this listing.
43 template <
typename FloatN,
int M,
typename SpinorX,
typename SpinorY,
44 typename SpinorZ,
typename SpinorW,
typename Functor>
// Buffers for the four fields; their addresses are handed to the accessors'
// save()/load() calls.  The _h suffix presumably means host memory -- TODO confirm.
52 char *X_h, *Y_h, *Z_h, *W_h;
// Matching buffers for the norm arrays, passed to the same save()/load() calls.
53 char *Xnorm_h, *Ynorm_h, *Znorm_h, *Wnorm_h;
// Norm-array sizes in bytes, indexed 0..3 for X, Y, Z, W.  The pointer is
// stored as given by the constructor; the data is not copied.
55 const size_t *norm_bytes_;
57 unsigned int sharedBytesPerThread()
const {
return 0; }
58 unsigned int sharedBytesPerBlock(
const TuneParam &
param)
const {
return 0; }
// Recomputes param.shared_bytes for the block size that would follow the
// current one.
// Fix: the parameter declaration had been mangled to "TuneParam <pilcrow>m"
// (the text "&para" decoded as an HTML entity); it is restored to
// "TuneParam &param", the name the body already uses.
// NOTE(review): the opening brace (listing line 61) and the rest of the body
// after line 66, including the return statement, are not visible here.
60 virtual bool advanceSharedBytes(TuneParam &param)
const
// Work on a copy so that advancing the block dimensions leaves param.block untouched.
62 TuneParam next(param);
63 advanceBlockDim(next);
// Thread count of the advanced block.
64 int nthreads = next.block.x * next.block.y * next.block.z;
// shared_bytes becomes the larger of the per-thread requirement scaled by
// nthreads and the per-block requirement.
65 param.shared_bytes = sharedBytesPerThread()*nthreads > sharedBytesPerBlock(param) ?
66 sharedBytesPerThread()*nthreads : sharedBytesPerBlock(param);
71 BlasCuda(SpinorX &
X, SpinorY &
Y, SpinorZ &
Z, SpinorW &W, Functor &f,
72 int length,
const size_t *
bytes,
const size_t *norm_bytes) :
73 arg(X, Y, Z, W, f, length), X_h(0), Y_h(0), Z_h(0), W_h(0),
74 Xnorm_h(0), Ynorm_h(0), Znorm_h(0), Wnorm_h(0), bytes_(bytes), norm_bytes_(norm_bytes) { }
// Builds the tuning key from the globally recorded volume string, the
// implementation-defined type name of the functor, and the globally recorded
// auxiliary string.
// NOTE(review): the enclosing method's signature is not visible in this listing.
79 return TuneKey(blasStrings.vol_str,
typeid(arg.f).name(), blasStrings.aux_str);
// Launches blasKernel on `stream` with the grid size, block size and dynamic
// shared-memory size taken from tp; the argument bundle is passed by value.
// NOTE(review): no launch-error check is visible in this fragment, and the
// code that produces tp is not shown.
84 blasKernel<FloatN,M> <<<tp.grid, tp.block, tp.shared_bytes, stream>>>(arg);
// Calls save() on each of the four field accessors, passing the matching
// backup-buffer pointers and the field / norm sizes (index 0..3 = X, Y, Z, W).
// Presumably the tuner's backup hook run before trial launches; the enclosing
// method signature is not visible -- TODO confirm.
88 arg.X.save(&X_h, &Xnorm_h, bytes_[0], norm_bytes_[0]);
89 arg.Y.save(&Y_h, &Ynorm_h, bytes_[1], norm_bytes_[1]);
90 arg.Z.save(&Z_h, &Znorm_h, bytes_[2], norm_bytes_[2]);
91 arg.W.save(&W_h, &Wnorm_h, bytes_[3], norm_bytes_[3]);
// Calls load() on each of the four field accessors with the same buffers and
// sizes that were handed to save() (index 0..3 = X, Y, Z, W).
// Presumably the tuner's restore hook run after trial launches; the enclosing
// method signature is not visible -- TODO confirm.
95 arg.X.load(&X_h, &Xnorm_h, bytes_[0], norm_bytes_[0]);
96 arg.Y.load(&Y_h, &Ynorm_h, bytes_[1], norm_bytes_[1]);
97 arg.Z.load(&Z_h, &Znorm_h, bytes_[2], norm_bytes_[2]);
98 arg.W.load(&W_h, &Wnorm_h, bytes_[3], norm_bytes_[3]);
101 long long flops()
const {
return arg.f.flops()*(
sizeof(FloatN)/
sizeof(((FloatN*)0)->x))*arg.length*M; }
// Memory-traffic estimate for one launch.
// NOTE(review): the method signature and listing line 104 are not visible here.
// Per-element byte count: X's precision value times the number of scalar
// components in a FloatN times M.  Precision() presumably returns bytes per
// real number -- TODO confirm.
103 size_t bytes = arg.X.Precision()*(
sizeof(FloatN)/
sizeof(((FloatN*)0)->x))*M;
// Scale by the functor's stream count and by the element count.
105 return arg.f.streams()*bytes*arg.length; }
// Generic host-side driver: builds a functor from the coefficients a, b, c,
// wraps the four fields in Spinor accessors, and runs a BlasCuda instance on
// *blasStream.  The integer flags writeX..writeW are forwarded as the fifth
// template argument of the corresponding Spinor accessor (presumably marking
// which fields are written back -- TODO confirm).
// NOTE(review): many lines of this function are missing from the listing,
// including the conditions that select between the instantiations below, the
// declarations of M and X/Y/Z/W, and the #else / #endif lines.
112 template <
template <
typename Float,
typename FloatN>
class Functor,
113 int writeX,
int writeY,
int writeZ,
int writeW>
114 inline void blasCuda(
const double2 &a,
const double2 &b,
const double2 &c,
115 cudaColorSpinorField &
x, cudaColorSpinorField &
y,
116 cudaColorSpinorField &z, cudaColorSpinorField &w) {
118 static TimeProfile head(
"head");
// Warning for non-native fields; the guarding condition is not visible here.
125 warningQuda(
"Blas on non-native fields is not supported\n");
// Record x's volume and auxiliary strings in the shared blasStrings object.
129 blasStrings.vol_str = x.VolString();
130 blasStrings.aux_str = x.AuxString();
// Recurse on the even and then the odd sub-fields of every operand; the
// guarding condition is not visible here.
133 blasCuda<Functor,writeX,writeY,writeZ,writeW>
134 (a, b, c, x.Even(), y.Even(), z.Even(), w.Even());
135 blasCuda<Functor,writeX,writeY,writeZ,writeW>
136 (a, b, c, x.Odd(), y.Odd(), z.Odd(), w.Odd());
// Field and norm-array sizes in bytes, indexed 0..3 = x, y, z, w; both arrays
// are passed to every BlasCuda instance below.
143 size_t bytes[] = {x.Bytes(), y.Bytes(), z.Bytes(), w.Bytes()};
144 size_t norm_bytes[] = {x.NormBytes(), y.NormBytes(), z.NormBytes(), w.NormBytes()};
// double2 instantiation: coefficients are passed through unchanged and the
// element count is x.Length()/(2*M).
152 Functor<double2, double2> f(a,b,c);
154 Spinor<double2,double2,double2,M,writeX,0>,
Spinor<double2,double2,double2,M,writeY,1>,
155 Spinor<double2,double2,double2,M,writeZ,2>,
Spinor<double2,double2,double2,M,writeW,3>,
156 Functor<double2, double2> > blas(X, Y, Z, W, f, x.Length()/(2*M), bytes, norm_bytes);
157 blas.apply(*blasStream);
160 if (x.Nspin() == 4) {
161 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC)
// float4 instantiation: coefficients are narrowed to float2 and the element
// count is x.Length()/(4*M).
166 Functor<float2, float4> f(make_float2(a.x, a.y), make_float2(b.x, b.y), make_float2(c.x, c.y));
168 Spinor<float4,float4,float4,M,writeX,0>,
Spinor<float4,float4,float4,M,writeY,1>,
169 Spinor<float4,float4,float4,M,writeZ,2>,
Spinor<float4,float4,float4,M,writeW,3>,
170 Functor<float2, float4> > blas(X, Y, Z, W, f, x.Length()/(4*M), bytes, norm_bytes);
171 blas.apply(*blasStream);
173 errorQuda(
"blas has not been built for Nspin=%d fields", x.Nspin());
176 #ifdef GPU_STAGGERED_DIRAC
// float2 instantiation: coefficients are narrowed to float2 and the element
// count is x.Length()/(2*M).
181 Functor<float2, float2> f(make_float2(a.x, a.y), make_float2(b.x, b.y), make_float2(c.x, c.y));
183 Spinor<float2,float2,float2,M,writeX,0>,
Spinor<float2,float2,float2,M,writeY,1>,
184 Spinor<float2,float2,float2,M,writeZ,2>,
Spinor<float2,float2,float2,M,writeW,3>,
185 Functor<float2, float2> > blas(X, Y, Z, W, f, x.Length()/(2*M), bytes, norm_bytes);
186 blas.apply(*blasStream);
188 errorQuda(
"blas has not been built for Nspin=%d fields", x.Nspin());
193 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC)
// float4/short4 instantiation: the fourth Spinor template argument is the
// literal 6 (the instantiations above pass M) and the element count is y.Volume().
198 Functor<float2, float4> f(make_float2(a.x, a.y), make_float2(b.x, b.y), make_float2(c.x, c.y));
200 Spinor<float4,float4,short4,6,writeX,0>,
Spinor<float4,float4,short4,6,writeY,1>,
201 Spinor<float4,float4,short4,6,writeZ,2>,
Spinor<float4,float4,short4,6,writeW,3>,
202 Functor<float2, float4> > blas(X, Y, Z, W, f, y.Volume(), bytes, norm_bytes);
203 blas.apply(*blasStream);
205 errorQuda(
"blas has not been built for Nspin=%d fields", x.Nspin());
207 }
else if (x.Nspin() == 1) {
208 #ifdef GPU_STAGGERED_DIRAC
// float2/short2 instantiation: the fourth Spinor template argument is the
// literal 3 and the element count is y.Volume().
213 Functor<float2, float2> f(make_float2(a.x, a.y), make_float2(b.x, b.y), make_float2(c.x, c.y));
215 Spinor<float2,float2,short2,3,writeX,0>,
Spinor<float2,float2,short2,3,writeY,1>,
216 Spinor<float2,float2,short2,3,writeZ,2>,
Spinor<float2,float2,short2,3,writeW,3>,
217 Functor<float2, float2> > blas(X, Y, Z, W, f, y.Volume(), bytes, norm_bytes);
218 blas.apply(*blasStream);
220 errorQuda(
"blas has not been built for Nspin=%d fields", x.Nspin());
222 }
else {
errorQuda(
"ERROR: nSpin=%d is not supported\n", x.Nspin()); }
// Add this call's cost to the global flop counter: the functor's static flop
// figure times x's real length, widened to unsigned long long.
227 blas_flops += Functor<double2,double2>::flops()*(
unsigned long long)x.RealLength();
void blasCuda(const double2 &a, const double2 &b, const double2 &c, cudaColorSpinorField &x, cudaColorSpinorField &y, cudaColorSpinorField &z, cudaColorSpinorField &w)
BlasArg(SpinorX X, SpinorY Y, SpinorZ Z, SpinorW W, Functor f, int length)
QudaVerbosity getVerbosity()
unsigned long long blas_bytes
__global__ void blasKernel(BlasArg< SpinorX, SpinorY, SpinorZ, SpinorW, Functor > arg)
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
unsigned long long blas_flops
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
#define checkSpinor(a, b)
BlasCuda(SpinorX &X, SpinorY &Y, SpinorZ &Z, SpinorW &W, Functor &f, int length, const size_t *bytes, const size_t *norm_bytes)
void apply(const cudaStream_t &stream)