#include <dslash_helper.cuh>
#include <jitify_helper.cuh>
template <template <int, bool, bool, KernelType, typename> class D, typename Arg>
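// D is the dslash kernel functor template, parameterized (per the reflection
// and launch code below) over an int (nParity), two bools (dagger, xpay),
// the KernelType, and the parameter struct Arg carrying all runtime state.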
// fillAuxBase(): encode which dimensions are comm-partitioned as a string
// that is folded into the autotuning key
comm[0] = (arg.commDim[0] ? '1' : '0');
comm[1] = (arg.commDim[1] ? '1' : '0');
comm[2] = (arg.commDim[2] ? '1' : '0');
comm[3] = (arg.commDim[3] ? '1' : '0');
inline void fillAux(KernelType kernel_type, const char *kernel_str)
{
  strcpy(aux[kernel_type], kernel_str);
  // ...
}
virtual bool tuneGridDim() const { return arg.kernel_type == EXTERIOR_KERNEL_ALL && arg.shmem > 0; }
// grid-size tuning helpers: with shmem enabled, the fused exterior kernel
// tunes its grid size explicitly
if (arg.kernel_type == EXTERIOR_KERNEL_ALL && arg.shmem > 0) {
  // ...
}

// ...

if (arg.kernel_type == EXTERIOR_KERNEL_ALL && arg.shmem > 0) {
  int nDimComms = 0;
  for (int d = 0; d < 4; d++) nDimComms += arg.commDim[d];
  // ...
// setParam(): reset the ghost pointers prior to each call, since the ghost
// buffer may have been changed during policy tuning
static void *ghost[8] = {};
for (int dir = 0; dir < 2; dir++) {
  // ...
}
arg.in.resetGhost(in, ghost);
if (arg.pack_threads && (arg.kernel_type == INTERIOR_KERNEL || arg.kernel_type == UBER_KERNEL)) {
  arg.blocks_per_dir = tp.aux.x;
  arg.setPack(true, this->packBuffer);
  arg.in_pack.resetGhost(in, this->packBuffer);
}
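// When packing is fused into the interior/uber kernel, the tuned aux.x
// dimension sets how many thread blocks pack each partitioned direction,
// and the input field's ghost pointers are retargeted at the pack buffers
// instead of the regular ghost buffers.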
if (arg.shmem > 0 && arg.kernel_type == EXTERIOR_KERNEL_ALL) {
  // ...
}
if (arg.shmem > 0 && (arg.kernel_type == INTERIOR_KERNEL || arg.kernel_type == UBER_KERNEL)) {
  arg.exterior_blocks = ((arg.shmem & 64) && arg.exterior_dims > 0) ?
    (deviceProp.multiProcessorCount / (2 * arg.exterior_dims)) * (2 * arg.exterior_dims * tp.aux.y) :
    0;
  tp.grid.x += arg.exterior_blocks;
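// Bit 64 in arg.shmem selects the uber-kernel variant that processes the
// exterior in dedicated blocks: a multiple of 2 * exterior_dims blocks,
// scaled by the tuned aux.y, is appended to the interior grid.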
// advanceAux(): advance the number of packing blocks per direction (aux.x)
if (arg.pack_threads && (arg.kernel_type == INTERIOR_KERNEL || arg.kernel_type == UBER_KERNEL)) {
  // ...
  int max_threads_per_dir = 0;
  for (int i = 0; i < 4; ++i) {
    max_threads_per_dir = std::max(max_threads_per_dir, (arg.threadDimMapUpper[i] - arg.threadDimMapLower[i]) / 2);
  }
  int nDimComms = 0;
  for (int d = 0; d < 4; d++) nDimComms += arg.commDim[d];
  // ...
  const int max_blocks_per_dir = std::max(deviceProp.multiProcessorCount / (8 * nDimComms), 4);
  if (param.aux.x + 1 <= max_blocks_per_dir
      && (param.aux.x + 1) * param.block.x < (max_threads_per_dir + param.block.x - 1)) {
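// Illustrative bound (numbers not from the source): on an 80-SM device with
// all four dimensions partitioned, max_blocks_per_dir = max(80 / (8 * 4), 4)
// = max(2, 4) = 4, so aux.x advances through 1..4 blocks per direction, as
// long as the added blocks still map onto unprocessed pack threads.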
if (arg.exterior_dims > 0 && (arg.shmem & 64)) {
  // ...
  if (param.aux.y < 4) {
    // ...
// initTuneParam(): seed the aux tuning dimensions
if (arg.shmem & 64) {
  // ...
}
if (arg.pack_threads && (arg.kernel_type == INTERIOR_KERNEL || arg.kernel_type == UBER_KERNEL)) {
  // ...
}
if (arg.exterior_dims && arg.kernel_type == UBER_KERNEL) param.aux.y = 1;
// defaultTuneParam(): same seeding as initTuneParam()
if (arg.shmem & 64) {
  // ...
}
if (arg.pack_threads && (arg.kernel_type == INTERIOR_KERNEL || arg.kernel_type == UBER_KERNEL)) {
  // ...
}
if (arg.exterior_dims && arg.kernel_type == UBER_KERNEL) param.aux.y = 1;
template <template <bool, QudaPCType, typename> class P, int nParity, bool dagger, bool xpay, KernelType kernel_type>
template <template <bool, QudaPCType, typename> class P> auto kernel_instance()
{
  if (!program) errorQuda("Jitify program has not been created");
  using namespace jitify::reflection;
  const auto kernel = "quda::dslashGPU";

  // strip the template arguments from the reflected functor names, so that
  // jitify can re-instantiate the "naked" D and P templates with the
  // runtime parameters below
  auto D_instance = reflect<D<0, false, false, INTERIOR_KERNEL, Arg>>();
  auto D_naked = D_instance.substr(0, D_instance.find("<"));
  auto P_instance = reflect<P<false, QUDA_4D_PC, Arg>>();
  auto P_naked = P_instance.substr(0, P_instance.find("<"));

  auto instance = program->kernel(kernel).instantiate({D_naked, P_naked, reflect(arg.nParity), reflect(arg.dagger),
                                                       reflect(arg.xpay), reflect(arg.kernel_type), reflect<Arg>()});
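// A minimal sketch of how such an instance is typically used (standard
// jitify API; the actual call site is elided from this excerpt):
//   instance.configure(tp.grid, tp.block, tp.shared_bytes, stream).launch(arg);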
template <template <bool, QudaPCType, typename> class P, int nParity, bool dagger, bool xpay>
switch (arg.kernel_type) {
case INTERIOR_KERNEL: launch<P, nParity, dagger, xpay, INTERIOR_KERNEL>(tp, stream); break;
#ifdef MULTI_GPU
case UBER_KERNEL: launch<P, nParity, dagger, xpay, UBER_KERNEL>(tp, stream); break;
case EXTERIOR_KERNEL_X: launch<P, nParity, dagger, xpay, EXTERIOR_KERNEL_X>(tp, stream); break;
case EXTERIOR_KERNEL_Y: launch<P, nParity, dagger, xpay, EXTERIOR_KERNEL_Y>(tp, stream); break;
case EXTERIOR_KERNEL_Z: launch<P, nParity, dagger, xpay, EXTERIOR_KERNEL_Z>(tp, stream); break;
case EXTERIOR_KERNEL_T: launch<P, nParity, dagger, xpay, EXTERIOR_KERNEL_T>(tp, stream); break;
case EXTERIOR_KERNEL_ALL: launch<P, nParity, dagger, xpay, EXTERIOR_KERNEL_ALL>(tp, stream); break;
default: errorQuda("Unexpected kernel type %d", arg.kernel_type);
#else
default: errorQuda("Unexpected kernel type %d for single-GPU build", arg.kernel_type);
#endif
}
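// A minimal sketch of what each launch<...>() call resolves to, assuming
// the qudaLaunchKernel wrapper declared by this framework; the kernel name
// dslashGPU matches the jitify string above, but the template-argument
// order shown here is illustrative:
//   void *args[] = {&arg};
//   qudaLaunchKernel((const void *)dslashGPU<D, P, nParity, dagger, xpay, kernel_type, Arg>,
//                    tp, args, stream);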
template <template <bool, QudaPCType, typename> class P, int nParity, bool xpay>
if (arg.dagger)
  instantiate<P, nParity, true, xpay>(tp, stream);
else
  instantiate<P, nParity, false, xpay>(tp, stream);
template <template <bool, QudaPCType, typename> class P, bool xpay>
switch (arg.nParity) {
case 1: instantiate<P, 1, xpay>(tp, stream); break;
case 2: instantiate<P, 2, xpay>(tp, stream); break;
default: errorQuda("nParity = %d undefined\n", arg.nParity);
}
template <template <bool, QudaPCType, typename> class P>
if (arg.xpay)
  instantiate<P, true>(tp, stream);
else
  instantiate<P, false>(tp, stream);
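// Taken together, the instantiate() overloads peel off one runtime
// parameter per level, xpay -> nParity -> dagger -> kernel_type, until
// launch<P, nParity, dagger, xpay, kernel_type>() is reached with every
// template argument fixed at compile time.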
errorQuda("CPU Fields not supported in Dslash framework yet");
#ifdef MULTI_GPU
fillAux(INTERIOR_KERNEL, "policy_kernel=interior");
fillAux(UBER_KERNEL, "policy_kernel=uber");
fillAux(EXTERIOR_KERNEL_ALL, "policy_kernel=exterior_all");
fillAux(EXTERIOR_KERNEL_X, "policy_kernel=exterior_x");
fillAux(EXTERIOR_KERNEL_Y, "policy_kernel=exterior_y");
fillAux(EXTERIOR_KERNEL_Z, "policy_kernel=exterior_z");
fillAux(EXTERIOR_KERNEL_T, "policy_kernel=exterior_t");
#else
fillAux(INTERIOR_KERNEL, "policy_kernel=single-GPU");
#endif
fillAux(KERNEL_POLICY, "policy");
using D_ = D<0, false, false, INTERIOR_KERNEL, Arg>;
// setPack(): select the destination for packed ghost data
for (int dir = 0; dir < 2; dir++) {
  // ...
  if (/* ... */) {
    // ...
  } else if (location & Shmem) {
    // ...
// label the pack variant in the autotuning string
switch ((int)location) {
// ...
  strcat(aux_pack, arg.exterior_dims > 0 ? ",shmemuber" : ",shmem");
  // ...
default: errorQuda("Unknown pack target location %d\n", location);
}
return 2 * arg.nFace; // factor of 2 accounts for forwards and backwards faces
const char *getAux(KernelType type) const { return aux[type]; }

void setAux(KernelType type, const char *aux_) { strcpy(aux[type], aux_); }

void augmentAux(KernelType type, const char *extra) { strcat(aux[type], extra); }
// tuneKey(): use the pack-augmented aux string when packing is fused in
auto aux_ = (arg.pack_blocks > 0 && (arg.kernel_type == INTERIOR_KERNEL || arg.kernel_type == UBER_KERNEL)) ?
  aux_pack :
  aux[arg.kernel_type];
// preTune(): save the output field, since the exterior kernels both read
// from and write to it
if (arg.kernel_type != INTERIOR_KERNEL && arg.kernel_type != UBER_KERNEL && arg.kernel_type != KERNEL_POLICY)
  this->backup();
// postTune(): restore the output field if doing an exterior kernel
if (arg.kernel_type != INTERIOR_KERNEL && arg.kernel_type != UBER_KERNEL && arg.kernel_type != KERNEL_POLICY)
  this->restore();
int num_mv_multiply = in.Nspin() == 4 ? 2 : 1;
int ghost_flops = (num_mv_multiply * mv_flops + 2 * in.Ncolor() * in.Nspin());
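// Worked example (illustrative, assuming the usual complex matrix-vector
// count mv_flops = (8 * Ncolor - 2) * Ncolor): for Wilson fermions with
// Nspin = 4 and Ncolor = 3, mv_flops = 66 and num_mv_multiply = 2, giving
// ghost_flops = 2 * 66 + 2 * 3 * 4 = 156 flops per ghost site.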
long long flops_ = 0;
switch (arg.kernel_type) {
case EXTERIOR_KERNEL_X:
case EXTERIOR_KERNEL_Y:
case EXTERIOR_KERNEL_Z:
case EXTERIOR_KERNEL_T:
  flops_ = (ghost_flops + (arg.xpay ? xpay_flops : xpay_flops / 2)) * 2 * in.GhostFace()[arg.kernel_type];
  break;
case EXTERIOR_KERNEL_ALL: {
  // ...
  flops_ = (ghost_flops + (arg.xpay ? xpay_flops : xpay_flops / 2)) * ghost_sites;
  break;
}
case INTERIOR_KERNEL:
// ...
case KERNEL_POLICY: {
  // ...
  flops_ = (num_dir * num_mv_multiply * mv_flops + // SU(3) matrix-vector multiplies
            /* ... remaining per-site contributions ... */) * sites;
  if (arg.xpay) flops_ += xpay_flops * sites;

  if (arg.kernel_type == KERNEL_POLICY) break;
  // now correct for the flops done by the exterior kernels
  long long ghost_sites = 0;
  for (int d = 0; d < 4; d++)
    if (arg.commDim[d]) ghost_sites += 2 * in.GhostFace()[d];
  flops_ -= ghost_flops * ghost_sites;
int proj_spinor_bytes = in.Nspin() == 4 ? spinor_bytes / 2 : spinor_bytes;
int ghost_bytes = (proj_spinor_bytes + gauge_bytes) + 2 * spinor_bytes;
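// Worked example (illustrative): for a double-precision Wilson spinor,
// spinor_bytes = 2 * Nspin * Ncolor * sizeof(double) = 2 * 4 * 3 * 8 = 192,
// so proj_spinor_bytes = 96; assuming an unreconstructed gauge link,
// gauge_bytes = 2 * 9 * 8 = 144, giving ghost_bytes = (96 + 144) + 2 * 192
// = 624 bytes per ghost site.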
long long bytes_ = 0;
switch (arg.kernel_type) {
case EXTERIOR_KERNEL_X:
case EXTERIOR_KERNEL_Y:
case EXTERIOR_KERNEL_Z:
case EXTERIOR_KERNEL_T: bytes_ = ghost_bytes * 2 * in.GhostFace()[arg.kernel_type]; break;
case EXTERIOR_KERNEL_ALL: {
  // ...
  bytes_ = ghost_bytes * ghost_sites;
  break;
}
case INTERIOR_KERNEL:
// ...
case KERNEL_POLICY: {
  // ...
  bytes_ = (num_dir * gauge_bytes + ((num_dir - 2) * spinor_bytes + 2 * proj_spinor_bytes) + spinor_bytes) * sites;
  if (arg.xpay) bytes_ += spinor_bytes;

  if (arg.kernel_type == KERNEL_POLICY) break;
  // now correct for the bytes accessed by the exterior kernels
  long long ghost_sites = 0;
  for (int d = 0; d < 4; d++)
    if (arg.commDim[d]) ghost_sites += 2 * in.GhostFace()[d];
  bytes_ -= ghost_bytes * ghost_sites;
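// For context, a minimal sketch of how a concrete operator derives from
// this base class, modeled on the Wilson dslash in the QUDA sources (names
// such as wilson and packShmem are taken from those sources; details may
// differ between versions):
//
// template <typename Arg> class Wilson : public Dslash<wilson, Arg>
// {
//   using Dslash = quda::Dslash<wilson, Arg>;
//   using Dslash::arg;
//
// public:
//   Wilson(Arg &arg, const ColorSpinorField &out, const ColorSpinorField &in) : Dslash(arg, out, in) {}
//
//   void apply(const qudaStream_t &stream)
//   {
//     TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
//     Dslash::setParam(tp);                                 // reset ghost/pack pointers
//     Dslash::template instantiate<packShmem>(tp, stream);  // dispatch on runtime parameters
//   }
// };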