4 template <
typename SpinorX,
typename SpinorY,
typename SpinorZ,
5 typename SpinorW,
typename Functor>
20 template <
typename FloatN,
int M,
typename SpinorX,
typename SpinorY,
21 typename SpinorZ,
typename SpinorW,
typename Functor>
23 unsigned int i = blockIdx.x*(
blockDim.x) + threadIdx.x;
24 unsigned int parity = blockIdx.y;
29 while (
i <
arg.length) {
30 FloatN
x[M],
y[M],
z[M],
w[M];
37 for (
int j=0; j<M; j++)
arg.f(
x[j],
y[j],
z[j],
w[j]);
47 template <
typename FloatN,
int M,
typename SpinorX,
typename SpinorY,
48 typename SpinorZ,
typename SpinorW,
typename Functor>
68 TuneParam next(
param);
69 advanceBlockDim(next);
70 int nthreads = next.block.x * next.block.y * next.block.z;
77 BlasCuda(SpinorX &
X, SpinorY &Y, SpinorZ &
Z, SpinorW &W, Functor &
f,
79 arg(
X, Y,
Z, W,
f,
length/
nParity),
nParity(
nParity),
X_h(0),
Y_h(0),
Z_h(0),
W_h(0),
90 blasKernel<FloatN,M> <<<tp.grid, tp.block, tp.shared_bytes,
stream>>>(
arg);
108 Tunable::initTuneParam(
param);
113 Tunable::initTuneParam(
param);
129 return ((
arg.f.streams()-2)*base_bytes + 2*extra_bytes)*
arg.length*
nParity;
134 template <
typename RegType,
typename StoreType,
typename yType,
int M,
135 template <
typename,
typename>
class Functor,
136 int writeX,
int writeY,
int writeZ,
int writeW>
137 void blasCuda(
const double2 &
a,
const double2 &
b,
const double2 &
c,
138 ColorSpinorField &
x, ColorSpinorField &
y,
139 ColorSpinorField &
z, ColorSpinorField &
w,
int length) {
144 warningQuda(
"Device blas on non-native fields is not supported\n");
150 if (
typeid(StoreType) !=
typeid(yType)) {
155 size_t bytes[] = {
x.Bytes(),
y.Bytes(),
z.Bytes(),
w.Bytes()};
156 size_t norm_bytes[] = {
x.NormBytes(),
y.NormBytes(),
z.NormBytes(),
w.NormBytes()};
163 typedef typename scalar<RegType>::type Float;
166 Functor<Float2, RegType>
f( (Float2)vec2(
a), (Float2)vec2(
b), (Float2)vec2(
c));
168 int partitions = (
x.IsComposite() ?
x.CompositeDim() : 1) * (
x.SiteSubset());
170 decltype(
X), decltype(Y), decltype(
Z), decltype(W),
171 Functor<Float2, RegType> >
189 template <
typename Float2,
int writeX,
int writeY,
int writeZ,
int writeW,
190 typename SpinorX,
typename SpinorY,
typename SpinorZ,
typename SpinorW,
195 for (
int x=0;
x<
X.VolumeCB();
x++) {
196 for (
int s=0;
s<
X.Nspin();
s++) {
197 for (
int c=0;
c<
X.Ncolor();
c++) {
198 Float2 X2 = make_Float2<Float2>(
X(
parity,
x,
s,
c) );
199 Float2 Y2 = make_Float2<Float2>( Y(
parity,
x,
s,
c) );
200 Float2 Z2 = make_Float2<Float2>(
Z(
parity,
x,
s,
c) );
201 Float2 W2 = make_Float2<Float2>( W(
parity,
x,
s,
c) );
214 int writeX,
int writeY,
int writeZ,
int writeW,
typename Functor>
216 ColorSpinorField &
w, Functor
f) {
217 colorspinor::FieldOrderCB<Float,nSpin,nColor,1,order>
X(
x),
Z(
z), W(
w);
218 colorspinor::FieldOrderCB<yFloat,nSpin,nColor,1,order> Y(
y);
220 genericBlas<Float2,writeX,writeY,writeZ,writeW>(
X, Y,
Z, W,
f);
223 template <
typename Float,
typename yFloat,
int nSpin,
QudaFieldOrder order,
224 int writeX,
int writeY,
int writeZ,
int writeW,
typename Functor>
225 void genericBlas(ColorSpinorField &
x, ColorSpinorField &
y, ColorSpinorField &
z, ColorSpinorField &
w, Functor
f) {
226 if (
x.Ncolor() == 2) {
227 genericBlas<Float,yFloat,nSpin,2,order,writeX,writeY,writeZ,writeW,Functor>(
x,
y,
z,
w,
f);
228 }
else if (
x.Ncolor() == 3) {
229 genericBlas<Float,yFloat,nSpin,3,order,writeX,writeY,writeZ,writeW,Functor>(
x,
y,
z,
w,
f);
230 }
else if (
x.Ncolor() == 4) {
231 genericBlas<Float,yFloat,nSpin,4,order,writeX,writeY,writeZ,writeW,Functor>(
x,
y,
z,
w,
f);
232 }
else if (
x.Ncolor() == 8) {
233 genericBlas<Float,yFloat,nSpin,8,order,writeX,writeY,writeZ,writeW,Functor>(
x,
y,
z,
w,
f);
234 }
else if (
x.Ncolor() == 12) {
235 genericBlas<Float,yFloat,nSpin,12,order,writeX,writeY,writeZ,writeW,Functor>(
x,
y,
z,
w,
f);
236 }
else if (
x.Ncolor() == 16) {
237 genericBlas<Float,yFloat,nSpin,16,order,writeX,writeY,writeZ,writeW,Functor>(
x,
y,
z,
w,
f);
238 }
else if (
x.Ncolor() == 20) {
239 genericBlas<Float,yFloat,nSpin,20,order,writeX,writeY,writeZ,writeW,Functor>(
x,
y,
z,
w,
f);
240 }
else if (
x.Ncolor() == 24) {
241 genericBlas<Float,yFloat,nSpin,24,order,writeX,writeY,writeZ,writeW,Functor>(
x,
y,
z,
w,
f);
242 }
else if (
x.Ncolor() == 32) {
243 genericBlas<Float,yFloat,nSpin,32,order,writeX,writeY,writeZ,writeW,Functor>(
x,
y,
z,
w,
f);
245 errorQuda(
"nColor = %d not implemeneted",
x.Ncolor());
249 template <
typename Float,
typename yFloat, QudaFieldOrder order,
int writeX,
int writeY,
int writeZ,
int writeW,
typename Functor>
250 void genericBlas(ColorSpinorField &
x, ColorSpinorField &
y, ColorSpinorField &
z, ColorSpinorField &
w, Functor
f) {
251 if (
x.Nspin() == 4) {
252 genericBlas<Float,yFloat,4,order,writeX,writeY,writeZ,writeW,Functor>(
x,
y,
z,
w,
f);
253 }
else if (
x.Nspin() == 2) {
254 genericBlas<Float,yFloat,2,order,writeX,writeY,writeZ,writeW,Functor>(
x,
y,
z,
w,
f);
255 #ifdef GPU_STAGGERED_DIRAC 256 }
else if (
x.Nspin() == 1) {
257 genericBlas<Float,yFloat,1,order,writeX,writeY,writeZ,writeW,Functor>(
x,
y,
z,
w,
f);
260 errorQuda(
"nSpin = %d not implemeneted",
x.Nspin());
264 template <
typename Float,
typename yFloat,
int writeX,
int writeY,
int writeZ,
int writeW,
typename Functor>
265 void genericBlas(ColorSpinorField &
x, ColorSpinorField &
y, ColorSpinorField &
z, ColorSpinorField &
w, Functor
f) {
267 genericBlas<Float,yFloat,QUDA_SPACE_SPIN_COLOR_FIELD_ORDER,writeX,writeY,writeZ,writeW,Functor>
void blasCuda(const double2 &a, const double2 &b, const double2 &c, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, ColorSpinorField &w, int length)
BlasArg(SpinorX X, SpinorY Y, SpinorZ Z, SpinorW W, Functor f, int length)
void initTuneParam(TuneParam &param) const
QudaVerbosity getVerbosity()
unsigned int sharedBytesPerThread() const
enum QudaFieldOrder_s QudaFieldOrder
void checkLength(const ColorSpinorField &a, ColorSpinorField &b)
char * strcpy(char *__dst, const char *__src)
__global__ void blasKernel(BlasArg< SpinorX, SpinorY, SpinorZ, SpinorW, Functor > arg)
BlasCuda(SpinorX &X, SpinorY &Y, SpinorZ &Z, SpinorW &W, Functor &f, int length, int nParity, const size_t *bytes, const size_t *norm_bytes)
char * strcat(char *__s1, const char *__s2)
virtual bool advanceSharedBytes(TuneParam &param) const
complex< double > make_Complex(const double2 &a)
static cudaStream_t * blasStream
static struct quda::blas::@4 blasStrings
unsigned int sharedBytesPerBlock(const TuneParam &param) const
BlasArg< SpinorX, SpinorY, SpinorZ, SpinorW, Functor > arg
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
int int int enum cudaChannelFormatKind f
void genericBlas(SpinorX &X, SpinorY &Y, SpinorZ &Z, SpinorW &W, Functor f)
void defaultTuneParam(TuneParam &param) const
const size_t * norm_bytes_
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
QudaTune getTuning()
Query whether autotuning is enabled or not. Default is enabled but can be overridden by setting QUDA_...
void apply(const cudaStream_t &stream)