// Template parameter list: an integer NXZ, four spinor accessor types
// (SpinorX/Y/Z/W) and a functor type.
// NOTE(review): the declaration these parameters belong to is elided from this
// listing; presumably the MultiBlasArg parameter struct -- confirm against the full file.
11 template <
int NXZ,
typename SpinorX,
typename SpinorY,
typename SpinorZ,
12 typename SpinorW,
typename Functor>
// Loop over the NXZ index range [0, NXZ); the loop body is not visible in this listing.
25 for(
int i=0;
i<NXZ; ++
i){
// MAX_MATRIX_SIZE is defined as 4096.
// The template below is parameterised on a compile-time index k, the extent NXZ,
// the vector type FloatN, the per-thread element count M and the argument struct type Arg.
// NOTE(review): the function signature and the load/store statements are elided from
// this listing; only the local arrays and the functor application are visible.
38 #define MAX_MATRIX_SIZE 4096 47 template<
int k,
int NXZ,
typename FloatN,
int M,
typename Arg>
// Local working arrays, M elements each, one per operand (x, y, z, w).
52 FloatN
x[M],
y[M],
z[M],
w[M];
// Outer loop: l runs over [0, NXZ).
57 for (
int l=0; l < NXZ; l++) {
// Inner loop: apply the functor arg.f to each of the M elements, passing the
// compile-time index k and the loop index l as the last two arguments.
62 for (
int j=0; j < M; j++)
arg.f(
x[j],
y[j],
z[j],
w[j], k, l);
// Kernel fragment: derives per-thread indices from the launch configuration and
// dispatches to compute<K, NXZ, FloatN, M> with a compile-time K in 0..15.
//   i      - from the x dimension of the grid/block (blockIdx.x, blockDim.x, threadIdx.x)
//   k      - from the y dimension of the grid/block (blockIdx.y, blockDim.y, threadIdx.y)
//   parity - taken directly from blockIdx.z
// Threads with k >= arg.NYW return immediately.
// Each case after the first is preceded by `#if MAX_MULTI_BLAS_N >= K+1`.
// NOTE(review): the kernel signature, the switch expression and the matching #endif
// lines are elided from this listing; presumably the switch is on k -- confirm.
76 template <
typename FloatN,
int M,
int NXZ,
typename SpinorX,
typename SpinorY,
77 typename SpinorZ,
typename SpinorW,
typename Functor>
81 unsigned int i = blockIdx.x *
blockDim.x + threadIdx.x;
82 unsigned int k = blockIdx.y *
blockDim.y + threadIdx.y;
83 unsigned int parity = blockIdx.z;
// Bounds guard on the y-dimension index.
86 if (k >=
arg.NYW)
return;
89 case 0: compute< 0,NXZ,FloatN,M>(
arg,
i,
parity);
break;
90 #if MAX_MULTI_BLAS_N >= 2 91 case 1: compute< 1,NXZ,FloatN,M>(
arg,
i,
parity);
break;
92 #if MAX_MULTI_BLAS_N >= 3 93 case 2: compute< 2,NXZ,FloatN,M>(
arg,
i,
parity);
break;
94 #if MAX_MULTI_BLAS_N >= 4 95 case 3: compute< 3,NXZ,FloatN,M>(
arg,
i,
parity);
break;
96 #if MAX_MULTI_BLAS_N >= 5 97 case 4: compute< 4,NXZ,FloatN,M>(
arg,
i,
parity);
break;
98 #if MAX_MULTI_BLAS_N >= 6 99 case 5: compute< 5,NXZ,FloatN,M>(
arg,
i,
parity);
break;
100 #if MAX_MULTI_BLAS_N >= 7 101 case 6: compute< 6,NXZ,FloatN,M>(
arg,
i,
parity);
break;
102 #if MAX_MULTI_BLAS_N >= 8 103 case 7: compute< 7,NXZ,FloatN,M>(
arg,
i,
parity);
break;
104 #if MAX_MULTI_BLAS_N >= 9 105 case 8: compute< 8,NXZ,FloatN,M>(
arg,
i,
parity);
break;
106 #if MAX_MULTI_BLAS_N >= 10 107 case 9: compute< 9,NXZ,FloatN,M>(
arg,
i,
parity);
break;
108 #if MAX_MULTI_BLAS_N >= 11 109 case 10: compute<10,NXZ,FloatN,M>(
arg,
i,
parity);
break;
110 #if MAX_MULTI_BLAS_N >= 12 111 case 11: compute<11,NXZ,FloatN,M>(
arg,
i,
parity);
break;
112 #if MAX_MULTI_BLAS_N >= 13 113 case 12: compute<12,NXZ,FloatN,M>(
arg,
i,
parity);
break;
114 #if MAX_MULTI_BLAS_N >= 14 115 case 13: compute<13,NXZ,FloatN,M>(
arg,
i,
parity);
break;
116 #if MAX_MULTI_BLAS_N >= 15 117 case 14: compute<14,NXZ,FloatN,M>(
arg,
i,
parity);
break;
118 #if MAX_MULTI_BLAS_N >= 16 119 case 15: compute<15,NXZ,FloatN,M>(
arg,
i,
parity);
break;
141 template<
unsigned... digits>
144 template<
unsigned... digits>
147 template<
unsigned rem,
unsigned... digits>
150 template<
unsigned... digits>
154 template<
unsigned num>
// Fragments of a class template parameterised on NXZ, FloatN, M, the four spinor
// accessor types and the functor type.
// NOTE(review): the class header, most members and all method signatures are elided
// from this listing; presumably this is the MultiBlasCuda launcher class -- confirm.
158 template <
int NXZ,
typename FloatN,
int M,
typename SpinorX,
typename SpinorY,
159 typename SpinorZ,
typename SpinorW,
typename Functor>
// Reference members y and w: vectors of ColorSpinorField pointers.
170 std::vector<ColorSpinorField*> &
y, &
w;
// Tail of a parameter list taking the y and w field vectors by reference.
177 std::vector<ColorSpinorField*> &
y, std::vector<ColorSpinorField*> &
w)
// Fixed-size name buffer of TuneKey::name_n chars; NYW is appended to it as decimal text.
// NOTE(review): strcat performs no bounds check; whatever is written to `name` before
// this point is not visible here -- confirm the combined length fits in name_n.
184 char name[TuneKey::name_n];
186 strcat(name, std::to_string(
NYW).c_str());
// Launch multiblasKernel with the grid, block and shared-memory size held in tp,
// on the given stream.
193 multiblasKernel<FloatN,M,NXZ> <<<tp.grid, tp.block, tp.shared_bytes,
stream>>>(
arg);
// Tuning parameter setup is delegated to the TunableVectorY base implementations.
211 TunableVectorY::initTuneParam(
param);
216 TunableVectorY::defaultTuneParam(
param);
// Byte-count estimate: (streams - 2) terms of base_bytes plus 2 terms of extra_bytes,
// scaled by arg.length and nParity. The definitions of base_bytes and extra_bytes are
// elided from this listing.
233 return ((
arg.f.streams()-2)*base_bytes + 2*extra_bytes)*
arg.length*
nParity;
239 template <
typename T>
// Host-side driver fragment. Template parameters: NXZ, the register type RegType,
// the storage types StoreType and yType, the element count M, a Functor class
// template, a `write` policy type and the coefficient scalar type T.
// NOTE(review): the function name/return type, the first parameters and several
// statement bodies are elided from this listing; comments describe visible lines only.
247 template <
int NXZ,
typename RegType,
typename StoreType,
typename yType,
int M,
248 template <
int,
typename,
typename>
class Functor,
249 typename write,
typename T>
// The four operand sets are passed as vectors of ColorSpinorField pointers.
251 std::vector<ColorSpinorField*> &
x, std::vector<ColorSpinorField*> &
y,
252 std::vector<ColorSpinorField*> &
z, std::vector<ColorSpinorField*> &
w,
// NYW is the number of y fields; N is the larger of NXZ and NYW.
255 const int NYW =
y.size();
257 const int N = NXZ > NYW ? NXZ : NYW;
263 typedef typename scalar<RegType>::type Float;
// Coefficient set a: only handled here when a.data is non-null and a.use_const is set.
// The body of the NXZ x NYW double loop is elided from this listing.
269 if (
a.data &&
a.use_const) {
273 for (
int i=0;
i<NXZ;
i++)
for (
int j=0; j<NYW; j++)
// Amatrix_h is pointed at a.data, with constness cast away and reinterpreted as signed char*.
277 Amatrix_h =
reinterpret_cast<signed char*
>(
const_cast<T*
>(
a.data));
// Coefficient set b: same pattern as for a, targeting Bmatrix_h.
280 if (
b.data &&
b.use_const) {
284 for (
int i=0;
i<NXZ;
i++)
for (
int j=0; j<NYW; j++)
288 Bmatrix_h =
reinterpret_cast<signed char*
>(
const_cast<T*
>(
b.data));
// Coefficient set c: same pattern as for a, targeting Cmatrix_h.
291 if (
c.data &&
c.use_const) {
295 for (
int i=0;
i<NXZ;
i++)
for (
int j=0; j<NYW; j++)
299 Cmatrix_h =
reinterpret_cast<signed char*
>(
const_cast<T*
>(
c.data));
// Runtime check for differing storage types between StoreType and yType;
// the body of this branch is elided from this listing.
308 if (
typeid(StoreType) !=
typeid(yType)) {
// Accessor arrays for the x and z operands, NXZ entries each.
// The declarations of Y and W are elided from this listing.
313 multi::SpinorTexture<RegType,StoreType,M,0>
X[NXZ];
315 multi::SpinorTexture<RegType,StoreType,M,2>
Z[NXZ];
// Bind each accessor to its field.
// NOTE(review): the dynamic_cast results are dereferenced without a null check; if any
// entry is not a cudaColorSpinorField the cast yields nullptr and the dereference is
// undefined behaviour -- confirm callers guarantee the dynamic type.
319 for (
int i=0;
i<NXZ;
i++) {
X[
i].set(*dynamic_cast<cudaColorSpinorField *>(
x[
i]));
Z[
i].set(*dynamic_cast<cudaColorSpinorField *>(
z[
i]));}
320 for (
int i=0;
i<NYW;
i++) { Y[
i].set(*dynamic_cast<cudaColorSpinorField *>(
y[
i])); W[
i].set(*dynamic_cast<cudaColorSpinorField *>(
w[
i]));}
// Functor instance built from the three coefficient sets and NYW.
324 Functor<NXZ,Float2, RegType>
f(
a,
b,
c, NYW);
// Accessor types used to instantiate the launcher: x and z use SpinorTexture,
// y uses Spinor with yType storage and write::Y, w uses Spinor with write::W.
327 multi::SpinorTexture<RegType,StoreType,M,0>,
328 multi::Spinor<RegType, yType,M,write::Y,1>,
329 multi::SpinorTexture<RegType,StoreType,M,2>,
330 multi::Spinor<RegType,StoreType,M,write::W,3>,
// Construct the launcher; x[0]->SiteSubset() is passed in the position after `length`.
// NOTE(review): x[0] is accessed unconditionally -- assumes x is non-empty; confirm.
332 blas(
X, Y,
Z, W,
f, NYW,
length,
x[0]->SiteSubset(),
y,
w);
// Site-loop fragment of a generic multi-blas routine templated on Float2, the `write`
// policy and the four accessor types.
// NOTE(review): the end of the template header, the function signature, the source of
// `parity`, and anything after the functor call are elided from this listing.
348 template <
typename Float2,
typename write,
349 typename SpinorX,
typename SpinorY,
typename SpinorZ,
typename SpinorW,
// Triple loop over every checkerboard site, spin and colour index reported by X.
354 for (
int x=0;
x<
X.VolumeCB();
x++) {
355 for (
int s=0;
s<
X.Nspin();
s++) {
356 for (
int c=0;
c<
X.Ncolor();
c++) {
// Read the four operands at (parity, x, s, c) and convert each to Float2.
357 Float2 X2 = make_Float2<Float2>(
X(
parity,
x,
s,
c) );
358 Float2 Y2 = make_Float2<Float2>( Y(
parity,
x,
s,
c) );
359 Float2 Z2 = make_Float2<Float2>(
Z(
parity,
x,
s,
c) );
360 Float2 W2 = make_Float2<Float2>( W(
parity,
x,
s,
c) );
// Apply the functor. NOTE(review): the last two arguments are the literals 1 and 1
// rather than loop indices -- confirm this is intended.
361 f(X2, Y2, Z2, W2, 1 , 1);
// Wrapper fragment: builds field accessors for the four operands and forwards them,
// together with the functor, to the accessor-level genericMultiBlas.
// NOTE(review): the start of the template header and of the parameter list are elided
// from this listing.
374 typename write,
typename Functor>
376 ColorSpinorField &
w, Functor
f) {
// X, Z and W are accessed with element type Float; Y uses yFloat.
377 colorspinor::FieldOrderCB<Float,nSpin,nColor,1,order>
X(
x),
Z(
z), W(
w);
378 colorspinor::FieldOrderCB<yFloat,nSpin,nColor,1,order> Y(
y);
380 genericMultiBlas<Float2,write>(
X, Y,
Z, W,
f);
// Colour-count dispatch: turns the runtime value x.Ncolor() into the compile-time
// nColor template argument of genericMultiBlas.
// Supported values are 2, 3, 4, 8, 12, 16, 20, 24 and 32; the trailing errorQuda call
// reports the value of x.Ncolor().
// Only x is inspected; y, z, w and f are forwarded unchanged.
// NOTE(review): the final `else` and the closing braces are elided from this listing.
// NOTE(review): the error message spells "implemented" as "implemeneted".
383 template <
typename Float,
typename yFloat,
int nSpin,
QudaFieldOrder order,
384 typename write,
typename Functor>
385 void genericMultiBlas(ColorSpinorField &
x, ColorSpinorField &
y, ColorSpinorField &
z, ColorSpinorField &
w, Functor
f) {
386 if (
x.Ncolor() == 2) {
387 genericMultiBlas<Float,yFloat,nSpin,2,order,write,Functor>(
x,
y,
z,
w,
f);
388 }
else if (
x.Ncolor() == 3) {
389 genericMultiBlas<Float,yFloat,nSpin,3,order,write,Functor>(
x,
y,
z,
w,
f);
390 }
else if (
x.Ncolor() == 4) {
391 genericMultiBlas<Float,yFloat,nSpin,4,order,write,Functor>(
x,
y,
z,
w,
f);
392 }
else if (
x.Ncolor() == 8) {
393 genericMultiBlas<Float,yFloat,nSpin,8,order,write,Functor>(
x,
y,
z,
w,
f);
394 }
else if (
x.Ncolor() == 12) {
395 genericMultiBlas<Float,yFloat,nSpin,12,order,write,Functor>(
x,
y,
z,
w,
f);
396 }
else if (
x.Ncolor() == 16) {
397 genericMultiBlas<Float,yFloat,nSpin,16,order,write,Functor>(
x,
y,
z,
w,
f);
398 }
else if (
x.Ncolor() == 20) {
399 genericMultiBlas<Float,yFloat,nSpin,20,order,write,Functor>(
x,
y,
z,
w,
f);
400 }
else if (
x.Ncolor() == 24) {
401 genericMultiBlas<Float,yFloat,nSpin,24,order,write,Functor>(
x,
y,
z,
w,
f);
402 }
else if (
x.Ncolor() == 32) {
403 genericMultiBlas<Float,yFloat,nSpin,32,order,write,Functor>(
x,
y,
z,
w,
f);
405 errorQuda(
"nColor = %d not implemeneted",
x.Ncolor());
// Spin-count dispatch: turns the runtime value x.Nspin() into the compile-time nSpin
// template argument of genericMultiBlas.
// Nspin values 4 and 2 are always handled; the Nspin == 1 branch follows the
// `#ifdef GPU_STAGGERED_DIRAC` line. The trailing errorQuda call reports x.Nspin().
// Only x is inspected; y, z, w and f are forwarded unchanged.
// NOTE(review): the matching #endif, the final `else` and the closing braces are
// elided from this listing.
// NOTE(review): the error message spells "implemented" as "implemeneted".
409 template <
typename Float,
typename yFloat, QudaFieldOrder order,
typename write,
typename Functor>
410 void genericMultiBlas(ColorSpinorField &
x, ColorSpinorField &
y, ColorSpinorField &
z, ColorSpinorField &
w, Functor
f) {
411 if (
x.Nspin() == 4) {
412 genericMultiBlas<Float,yFloat,4,order,write,Functor>(
x,
y,
z,
w,
f);
413 }
else if (
x.Nspin() == 2) {
414 genericMultiBlas<Float,yFloat,2,order,write,Functor>(
x,
y,
z,
w,
f);
415 #ifdef GPU_STAGGERED_DIRAC 416 }
else if (
x.Nspin() == 1) {
417 genericMultiBlas<Float,yFloat,1,order,write,Functor>(
x,
y,
z,
w,
f);
420 errorQuda(
"nSpin = %d not implemeneted",
x.Nspin());
// Entry-point overload: selects the genericMultiBlas instantiation with the field
// order fixed to QUDA_SPACE_SPIN_COLOR_FIELD_ORDER.
// NOTE(review): the argument list of the call and the closing brace are elided from
// this listing; presumably x, y, z, w and f are forwarded -- confirm.
424 template <
typename Float,
typename yFloat,
typename write,
typename Functor>
425 void genericMultiBlas(ColorSpinorField &
x, ColorSpinorField &
y, ColorSpinorField &
z, ColorSpinorField &
w, Functor
f) {
427 genericMultiBlas<Float,yFloat,QUDA_SPACE_SPIN_COLOR_FIELD_ORDER,write,Functor>
static __constant__ signed char Bmatrix_d[MAX_MATRIX_SIZE]
MultiBlasArg< NXZ, SpinorX, SpinorY, SpinorZ, SpinorW, Functor > arg
QudaVerbosity getVerbosity()
char * Wnorm_h[MAX_MULTI_BLAS_N]
static __constant__ signed char Amatrix_d[MAX_MATRIX_SIZE]
enum QudaFieldOrder_s QudaFieldOrder
std::complex< double > Complex
char * strcpy(char *__dst, const char *__src)
char * strcat(char *__s1, const char *__s2)
SpinorY Y[MAX_MULTI_BLAS_N]
coeff_array(const T *data, bool use_const)
complex< double > make_Complex(const double2 &a)
cudaStream_t * getStream()
__device__ void compute(Arg &arg, int idx, int parity)
static __constant__ signed char Cmatrix_d[MAX_MATRIX_SIZE]
void apply(const cudaStream_t &stream)
void initTuneParam(TuneParam &param) const
MultiBlasCuda(SpinorX X[], SpinorY Y[], SpinorZ Z[], SpinorW W[], Functor &f, int NYW, int length, int nParity, std::vector< ColorSpinorField *> &y, std::vector< ColorSpinorField *> &w)
static struct quda::blas::@4 blasStrings
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
void multiblasCuda(const coeff_array< T > &a, const coeff_array< T > &b, const coeff_array< T > &c, std::vector< ColorSpinorField *> &x, std::vector< ColorSpinorField *> &y, std::vector< ColorSpinorField *> &z, std::vector< ColorSpinorField *> &w, int length)
int int int enum cudaChannelFormatKind f
__global__ void multiblasKernel(MultiBlasArg< NXZ, SpinorX, SpinorY, SpinorZ, SpinorW, Functor > arg)
Generic multi-blas kernel with four loads and up to four stores.
static signed char * Cmatrix_h
std::vector< ColorSpinorField * > & w
static signed char * Amatrix_h
void genericMultiBlas(SpinorX &X, SpinorY &Y, SpinorZ &Z, SpinorW &W, Functor f)
std::vector< ColorSpinorField * > & y
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
Parameter struct for generic multi-blas kernel.
SpinorW W[MAX_MULTI_BLAS_N]
char * Y_h[MAX_MULTI_BLAS_N]
static const char value[]
char * W_h[MAX_MULTI_BLAS_N]
bool tuneSharedBytes() const
void defaultTuneParam(TuneParam &param) const
QudaTune getTuning()
Query whether autotuning is enabled or not. Default is enabled but can be overridden by setting QUDA_...
char * Ynorm_h[MAX_MULTI_BLAS_N]
MultiBlasArg(SpinorX X[NXZ], SpinorY Y[], SpinorZ Z[NXZ], SpinorW W[], Functor f, int NYW, int length)
static signed char * Bmatrix_h