14 template <
typename Float,
typename yFloat,
typename ghostFloat,
int nDim,
int Ns,
int Nc,
int Mc,
bool dslash,
bool clover,
bool dagger, DslashType type>
28 const int max_color_col_stride = 8;
29 mutable int color_col_stride;
30 mutable int dim_threads;
33 long long flops()
const 35 return ((dslash*2*nDim+clover*1)*(8*Ns*Nc*Ns*Nc)-2*Ns*Nc)*nParity*(
long long)out.
VolumeCB();
37 long long bytes()
const 39 return (dslash||clover) * out.
Bytes() + dslash*8*inA.
Bytes() + clover*inB.
Bytes() +
42 unsigned int sharedBytesPerThread()
const {
return (
sizeof(complex<Float>) * Mc); }
43 unsigned int sharedBytesPerBlock(
const TuneParam &
param)
const {
return 0; }
44 bool tuneGridDim()
const {
return false; }
45 bool tuneAuxDim()
const {
return true; }
46 unsigned int minThreads()
const {
return color_col_stride * X.
VolumeCB(); }
50 dim3 grid = param.
grid;
52 param.
grid.z = grid.z;
59 while(param.
block.z <= (
unsigned int)(dim_threads * 2 * 2 * (Nc/Mc))) {
60 param.
block.z+=dim_threads * 2;
61 if ( (dim_threads*2*2*(Nc/Mc)) % param.
block.z == 0) {
62 param.
grid.z = (dim_threads * 2 * 2 * (Nc/Mc)) / param.
block.z;
68 if (param.
block.z <= (
unsigned int)(dim_threads * 2 * 2 * (Nc/Mc)) &&
73 param.
block.z = dim_threads * 2;
74 param.
grid.z = 2 * (Nc/Mc);
88 #ifdef DOT_PRODUCT_SPLIT 90 if (2*param.
aux.x <= max_color_col_stride && Nc % (2*param.
aux.x) == 0 &&
97 color_col_stride = param.
aux.x;
103 if (param.
grid.x < (
unsigned int)
deviceProp.maxGridSize[0])
return true;
109 color_col_stride = param.
aux.x;
114 if (2*param.
aux.y <= nDim &&
117 dim_threads = param.
aux.y;
120 param.
block.z = dim_threads * 2;
121 param.
grid.z = 2* (Nc / Mc);
124 sharedBytesPerThread()*param.
block.x*param.
block.y*param.
block.z : sharedBytesPerBlock(param);
129 dim_threads = param.
aux.y;
135 param.
block.z = dim_threads * 2;
136 param.
grid.z = 2* (Nc / Mc);
139 sharedBytesPerThread()*param.
block.x*param.
block.y*param.
block.z : sharedBytesPerBlock(param);
145 virtual void initTuneParam(
TuneParam ¶m)
const 147 param.
aux = make_int4(1,1,1,1);
148 color_col_stride = param.
aux.x;
149 dim_threads = param.
aux.y;
152 param.
block.z = dim_threads * 2;
153 param.
grid.z = 2*(Nc/Mc);
155 sharedBytesPerThread()*param.
block.x*param.
block.y*param.
block.z : sharedBytesPerBlock(param);
159 virtual void defaultTuneParam(
TuneParam ¶m)
const 161 param.
aux = make_int4(1,1,1,1);
162 color_col_stride = param.
aux.x;
163 dim_threads = param.
aux.y;
169 param.
block.z = dim_threads * 2;
170 param.
grid.z = 2*(Nc/Mc);
172 sharedBytesPerThread()*param.
block.x*param.
block.y*param.
block.z : sharedBytesPerBlock(param);
179 :
TunableVectorY(out.SiteSubset() * (out.Ndim()==5 ? out.X(4) : 1)),
180 out(out), inA(inA), inB(inB), Y(Y), X(X), kappa(kappa), parity(parity),
181 nParity(out.SiteSubset()), nSrc(out.Ndim()==5 ? out.X(4) : 1)
183 strcpy(aux,
"policy_kernel,");
186 create_jitify_program(
"kernels/dslash_coarse.cuh");
204 if (doHalo<type>()) {
205 char label[15] =
",halo=";
206 for (
int dim=0; dim<4; dim++) {
207 for (
int dir=0; dir<2; dir++) {
215 virtual ~DslashCoarse() { }
217 inline void apply(
const cudaStream_t &
stream) {
224 DslashCoarseArg<Float,yFloat,ghostFloat,Ns,Nc,QUDA_SPACE_SPIN_COLOR_FIELD_ORDER,QUDA_QDP_GAUGE_ORDER> arg(out, inA, inB, Y, X, (Float)kappa, parity);
225 coarseDslash<Float,nDim,Ns,Nc,Mc,dslash,clover,dagger,type>(
arg);
234 Arg
arg(out, inA, inB, Y, X, (Float)kappa, parity);
237 using namespace jitify::reflection;
238 jitify_error = program->kernel(
"quda::coarseDslashKernel")
239 .instantiate(Type<Float>(),nDim,Ns,Nc,Mc,(
int)tp.
aux.x,(
int)tp.
aux.y,dslash,clover,
dagger,type,Type<Arg>())
246 coarseDslashKernel<Float,nDim,Ns,Nc,Mc,1,1,dslash,clover,dagger,type> <<<tp.
grid,tp.
block,tp.
shared_bytes,stream>>>(
arg);
248 #ifdef DOT_PRODUCT_SPLIT 250 coarseDslashKernel<Float,nDim,Ns,Nc,Mc,2,1,dslash,clover,dagger,type> <<<tp.
grid,tp.
block,tp.
shared_bytes,stream>>>(
arg);
253 coarseDslashKernel<Float,nDim,Ns,Nc,Mc,4,1,dslash,clover,dagger,type> <<<tp.
grid,tp.
block,tp.
shared_bytes,stream>>>(
arg);
256 coarseDslashKernel<Float,nDim,Ns,Nc,Mc,8,1,dslash,clover,dagger,type> <<<tp.
grid,tp.
block,tp.
shared_bytes,stream>>>(
arg);
258 #endif // DOT_PRODUCT_SPLIT 260 errorQuda(
"Color column stride %d not valid", tp.
aux.x);
266 coarseDslashKernel<Float,nDim,Ns,Nc,Mc,1,2,dslash,clover,dagger,type> <<<tp.
grid,tp.
block,tp.
shared_bytes,stream>>>(
arg);
268 #ifdef DOT_PRODUCT_SPLIT 270 coarseDslashKernel<Float,nDim,Ns,Nc,Mc,2,2,dslash,clover,dagger,type> <<<tp.
grid,tp.
block,tp.
shared_bytes,stream>>>(
arg);
273 coarseDslashKernel<Float,nDim,Ns,Nc,Mc,4,2,dslash,clover,dagger,type> <<<tp.
grid,tp.
block,tp.
shared_bytes,stream>>>(
arg);
276 coarseDslashKernel<Float,nDim,Ns,Nc,Mc,8,2,dslash,clover,dagger,type> <<<tp.
grid,tp.
block,tp.
shared_bytes,stream>>>(
arg);
278 #endif // DOT_PRODUCT_SPLIT 280 errorQuda(
"Color column stride %d not valid", tp.
aux.x);
286 coarseDslashKernel<Float,nDim,Ns,Nc,Mc,1,4,dslash,clover,dagger,type> <<<tp.
grid,tp.
block,tp.
shared_bytes,stream>>>(
arg);
288 #ifdef DOT_PRODUCT_SPLIT 290 coarseDslashKernel<Float,nDim,Ns,Nc,Mc,2,4,dslash,clover,dagger,type> <<<tp.
grid,tp.
block,tp.
shared_bytes,stream>>>(
arg);
293 coarseDslashKernel<Float,nDim,Ns,Nc,Mc,4,4,dslash,clover,dagger,type> <<<tp.
grid,tp.
block,tp.
shared_bytes,stream>>>(
arg);
296 coarseDslashKernel<Float,nDim,Ns,Nc,Mc,8,4,dslash,clover,dagger,type> <<<tp.
grid,tp.
block,tp.
shared_bytes,stream>>>(
arg);
298 #endif // DOT_PRODUCT_SPLIT 300 errorQuda(
"Color column stride %d not valid", tp.
aux.x);
304 errorQuda(
"Invalid dimension thread splitting %d", tp.
aux.y);
315 saveOut =
new char[out.
Bytes()];
316 cudaMemcpy(saveOut, out.
V(), out.
Bytes(), cudaMemcpyDeviceToHost);
321 cudaMemcpy(out.
V(), saveOut, out.
Bytes(), cudaMemcpyHostToDevice);
328 template <
typename Float,
typename yFloat,
typename ghostFloat,
int coarseColor,
int coarseSpin>
333 const int colors_per_thread = 1;
341 DslashCoarse<Float,yFloat,ghostFloat,nDim,coarseSpin,coarseColor,colors_per_thread,true,true,true,DSLASH_FULL> dslash(out, inA, inB, Y, X, kappa, parity, halo_location);
344 DslashCoarse<Float,yFloat,ghostFloat,nDim,coarseSpin,coarseColor,colors_per_thread,true,true,true,DSLASH_INTERIOR> dslash(out, inA, inB, Y, X, kappa, parity, halo_location);
346 }
else {
errorQuda(
"Dslash type %d not instantiated", type); }
351 DslashCoarse<Float,yFloat,ghostFloat,nDim,coarseSpin,coarseColor,colors_per_thread,true,false,true,DSLASH_FULL> dslash(out, inA, inB, Y, X, kappa, parity, halo_location);
354 DslashCoarse<Float,yFloat,ghostFloat,nDim,coarseSpin,coarseColor,colors_per_thread,true,false,true,DSLASH_INTERIOR> dslash(out, inA, inB, Y, X, kappa, parity, halo_location);
356 }
else {
errorQuda(
"Dslash type %d not instantiated", type); }
363 DslashCoarse<Float,yFloat,ghostFloat,nDim,coarseSpin,coarseColor,colors_per_thread,false,true,true,DSLASH_FULL> dslash(out, inA, inB, Y, X, kappa, parity, halo_location);
366 errorQuda(
"Unsupported dslash=false clover=false");
376 DslashCoarse<Float,yFloat,ghostFloat,nDim,coarseSpin,coarseColor,colors_per_thread,true,true,false,DSLASH_FULL> dslash(out, inA, inB, Y, X, kappa, parity, halo_location);
379 DslashCoarse<Float,yFloat,ghostFloat,nDim,coarseSpin,coarseColor,colors_per_thread,true,true,false,DSLASH_INTERIOR> dslash(out, inA, inB, Y, X, kappa, parity, halo_location);
381 }
else {
errorQuda(
"Dslash type %d not instantiated", type); }
386 DslashCoarse<Float,yFloat,ghostFloat,nDim,coarseSpin,coarseColor,colors_per_thread,true,false,false,DSLASH_FULL> dslash(out, inA, inB, Y, X, kappa, parity, halo_location);
389 DslashCoarse<Float,yFloat,ghostFloat,nDim,coarseSpin,coarseColor,colors_per_thread,true,false,false,DSLASH_INTERIOR> dslash(out, inA, inB, Y, X, kappa, parity, halo_location);
391 }
else {
errorQuda(
"Dslash type %d not instantiated", type); }
397 DslashCoarse<Float,yFloat,ghostFloat,nDim,coarseSpin,coarseColor,colors_per_thread,false,true,false,DSLASH_FULL> dslash(out, inA, inB, Y, X, kappa, parity, halo_location);
400 errorQuda(
"Unsupported dslash=false clover=false");
407 template <
typename Float,
typename yFloat,
typename ghostFloat>
418 if (inA.
Nspin() != 2)
422 }
else if (inA.
Ncolor() == 4) {
423 ApplyCoarse<Float,yFloat,ghostFloat,4,2>(
out, inA, inB, Y,
X,
kappa,
parity, dslash, clover,
dagger, type, halo_location);
425 if (inA.Ncolor() == 6) {
426 ApplyCoarse<Float,yFloat,ghostFloat,6,2>(
out, inA, inB, Y,
X,
kappa,
parity, dslash, clover,
dagger, type, halo_location);
428 }
else if (inA.Ncolor() == 8) {
429 ApplyCoarse<Float,yFloat,ghostFloat,8,2>(
out, inA, inB, Y,
X,
kappa,
parity, dslash, clover,
dagger, type, halo_location);
430 }
else if (inA.Ncolor() == 12) {
431 ApplyCoarse<Float,yFloat,ghostFloat,12,2>(
out, inA, inB, Y,
X,
kappa,
parity, dslash, clover,
dagger, type, halo_location);
432 }
else if (inA.Ncolor() == 16) {
433 ApplyCoarse<Float,yFloat,ghostFloat,16,2>(
out, inA, inB, Y,
X,
kappa,
parity, dslash, clover,
dagger, type, halo_location);
434 }
else if (inA.Ncolor() == 20) {
435 ApplyCoarse<Float,yFloat,ghostFloat,20,2>(
out, inA, inB, Y,
X,
kappa,
parity, dslash, clover,
dagger, type, halo_location);
437 }
else if (inA.Ncolor() == 24) {
438 ApplyCoarse<Float,yFloat,ghostFloat,24,2>(
out, inA, inB, Y,
X,
kappa,
parity, dslash, clover,
dagger, type, halo_location);
440 }
else if (inA.Ncolor() == 28) {
441 ApplyCoarse<Float,yFloat,ghostFloat,28,2>(
out, inA, inB, Y,
X,
kappa,
parity, dslash, clover,
dagger, type, halo_location);
443 }
else if (inA.Ncolor() == 32) {
444 ApplyCoarse<Float,yFloat,ghostFloat,32,2>(
out, inA, inB, Y,
X,
kappa,
parity, dslash, clover,
dagger, type, halo_location);
446 errorQuda(
"Unsupported number of coarse dof %d\n", Y.Ncolor());
456 #endif // GPU_MULTIGRID 488 bool dslash,
bool clover,
bool dagger,
const int *commDim,
QudaPrecision halo_precision)
489 : out(out), inA(inA), inB(inB), Y(Y), X(X), kappa(kappa), parity(parity),
490 dslash(dslash), clover(clover), dagger(dagger), commDim(commDim),
498 if (inA.
V() == out.
V())
errorQuda(
"Aliasing pointers");
508 if (commDim)
for (
int i=0; i<4; i++) comm_sum -= (1-commDim[i]);
509 if (comm_sum != 4 && comm_sum != 0)
errorQuda(
"Unsupported comms %d", comm_sum);
510 bool comms = comm_sum;
535 pack_destination, halo_location, gdr_send, gdr_recv, halo_precision);
541 #ifdef GPU_MULTIGRID_DOUBLE 545 errorQuda(
"Halo precision %d not supported with field precision %d and link precision %d", halo_precision, precision, Y.
Precision());
546 ApplyCoarse<double,double,double>(
out, inA, inB, Y,
X,
kappa,
parity, dslash, clover,
550 errorQuda(
"Double precision multigrid has not been enabled");
555 ApplyCoarse<float,float,float>(
out, inA, inB, Y,
X,
kappa,
parity, dslash, clover,
558 errorQuda(
"Halo precision %d not supported with field precision %d and link precision %d", halo_precision, precision, Y.
Precision());
561 #if QUDA_PRECISION & 2 563 ApplyCoarse<float,short,short>(
out, inA, inB, Y,
X,
kappa,
parity, dslash, clover,
566 #if QUDA_PRECISION & 1 567 ApplyCoarse<float,short,char>(
out, inA, inB, Y,
X,
kappa,
parity, dslash, clover,
570 errorQuda(
"QUDA_PRECISION=%d does not enable quarter precision", QUDA_PRECISION);
573 errorQuda(
"Halo precision %d not supported with field precision %d and link precision %d", halo_precision, precision, Y.
Precision());
576 errorQuda(
"QUDA_PRECISION=%d does not enable half precision", QUDA_PRECISION);
589 errorQuda(
"Multigrid has not been built");
603 policies[
static_cast<std::size_t
>(p)] = p;
624 static char *dslash_policy_env = getenv(
"QUDA_ENABLE_DSLASH_COARSE_POLICY");
626 if (dslash_policy_env) {
627 std::stringstream policy_list(dslash_policy_env);
630 while (policy_list >> policy_) {
639 errorQuda(
"Cannot select a GDR policy %d unless QUDA_ENABLE_GDR is set", static_cast<int>(dslash_policy));
644 if (policy_list.peek() ==
',') policy_list.ignore();
648 first_active_policy = 0;
663 strcat(policy_string,
",pol=");
665 strcat(policy_string, (
int)
policies[i] == i ?
"1" :
"0");
671 strcpy(aux,
"policy,");
672 if (dslash.
dslash) strcat(aux,
"dslash");
673 strcat(aux, dslash.
clover ?
"clover," :
",");
675 strcat(aux,
",gauge_prec=");
679 strcat(aux, prec_str);
680 strcat(aux,
",halo_prec=");
682 strcat(aux, prec_str);
686 strcat(aux, policy_string);
690 for (
int i = 0; i < 4; i++) comm_sum -= (1 - dslash.
commDim[i]);
691 strcat(aux, comm_sum ?
",full" :
",interior");
705 inline void apply(
const cudaStream_t &stream) {
717 while ((
unsigned)param.
aux.x <
policies.size()-1) {
776 DslashCoarseLaunch Dslash(out, inA, inB, Y, X, kappa, parity, dslash, clover, dagger, commDim, halo_precision);
virtual void apply(const cudaStream_t &stream)=0
void operator()(DslashCoarsePolicy policy)
Execute the coarse dslash using the given policy.
void ApplyCoarse(ColorSpinorField &out, const ColorSpinorField &inA, const ColorSpinorField &inB, const GaugeField &Y, const GaugeField &X, double kappa, int parity=QUDA_INVALID_PARITY, bool dslash=true, bool clover=true, bool dagger=false, const int *commDim=0, QudaPrecision halo_precision=QUDA_INVALID_PRECISION)
Apply the coarse dslash stencil. This single driver accounts for all variations with and without the ...
enum QudaPrecision_s QudaPrecision
QudaGaugeFieldOrder FieldOrder() const
const char * AuxString() const
cudaDeviceProp deviceProp
void disableProfileCount()
Disable the profile kernel counting.
QudaVerbosity getVerbosity()
#define checkPrecision(...)
const ColorSpinorField & inB
Helper file when using jitify run-time compilation. This file should be included in source code...
static char policy_string[TuneKey::aux_n]
int comm_partitioned()
Loop over comm_dim_partitioned(dim) for all comms dimensions.
DslashCoarseLaunch(ColorSpinorField &out, const ColorSpinorField &inA, const ColorSpinorField &inB, const GaugeField &Y, const GaugeField &X, double kappa, int parity, bool dslash, bool clover, bool dagger, const int *commDim, QudaPrecision halo_precision)
unsigned int sharedBytesPerThread() const
const char * VolString() const
unsigned int sharedBytesPerBlock(const TuneParam ¶m) const
void initTuneParam(TuneParam ¶m) const
bool advanceBlockDim(TuneParam ¶m) const
static int first_active_policy
const char * comm_dim_partitioned_string(const int *comm_dim_override=0)
Return a string that defines the comm partitioning (used as a tuneKey)
const char * compile_type_str(const LatticeField &meta, QudaFieldLocation location_=QUDA_INVALID_FIELD_LOCATION)
Helper function for setting auxilary string.
void enableProfileCount()
Enable the profile kernel counting.
DslashCoarseLaunch & dslash
void comm_enable_peer2peer(bool enable)
Enable / disable peer-to-peer communication: used for dslash policies that do not presently support p...
void i32toa(char *buffer, int32_t value)
const char * comm_dim_topology_string()
Return a string that defines the comm topology (for use as a tuneKey)
QudaSiteSubset SiteSubset() const
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
#define checkLocation(...)
void defaultTuneParam(TuneParam ¶m) const
void disable_policy(DslashCoarsePolicy p)
bool advanceAux(TuneParam ¶m) const
enum QudaParity_s QudaParity
void setPolicyTuning(bool)
Enable / disable whether are tuning a policy.
QudaFieldLocation Location() const
void apply(const cudaStream_t &stream)
static int commDim[QUDA_MAX_DIM]
cpuColorSpinorField * out
virtual ~DslashCoarsePolicyTune()
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
const QudaPrecision halo_precision
bool comm_gdr_enabled()
Query if GPU Direct RDMA communication is enabled (global setting)
const ColorSpinorField & inA
virtual void initTuneParam(TuneParam ¶m) const
#define QUDA_MAX_DIM
Maximum number of dimensions supported by QUDA. In practice, no routines make use of more than 5...
void enable_policy(DslashCoarsePolicy p)
void defaultTuneParam(TuneParam ¶m) const
virtual void exchangeGhost(QudaParity parity, int nFace, int dagger, const MemoryLocation *pack_destination=nullptr, const MemoryLocation *halo_location=nullptr, bool gdr_send=false, bool gdr_recv=false, QudaPrecision ghost_precision=QUDA_INVALID_PRECISION) const =0
const std::map< TuneKey, TuneParam > & getTuneCache()
Returns a reference to the tunecache map.
static std::vector< DslashCoarsePolicy > policies(static_cast< int >(DslashCoarsePolicy::DSLASH_COARSE_POLICY_DISABLED), DslashCoarsePolicy::DSLASH_COARSE_POLICY_DISABLED)
QudaTune getTuning()
Query whether autotuning is enabled or not. Default is enabled but can be overridden by setting QUDA_...
bool advanceTuneParam(TuneParam ¶m) const
QudaPrecision Precision() const
void initTuneParam(TuneParam ¶m) const
DslashCoarsePolicyTune(DslashCoarseLaunch &dslash)
QudaFieldOrder FieldOrder() const
int comm_dim_partitioned(int dim)
const char * comm_config_string()
Return a string that defines the P2P/GDR environment variable configuration (for use as a tuneKey to ...
virtual void defaultTuneParam(TuneParam ¶m) const