21 template <
typename Float,
int nDim,
int nColor,
int nParity,
bool dagger,
bool xpay, KernelType kernel_type,
typename Arg>
23 static constexpr
const char *
kernel =
"quda::staggeredGPU";
24 template <
typename Dslash>
27 dslash.
launch(staggeredGPU<Float, nDim, nColor, nParity, dagger, xpay, kernel_type, Arg>, tp, arg, stream);
31 template <
typename Float,
int nDim,
int nColor,
typename Arg>
class Staggered :
public Dslash<Float>
40 Dslash<Float>(arg, out, in,
"kernels/dslash_staggered.cuh"),
51 errorQuda(
"Staggered Dslash not implemented on CPU");
73 int ghost_flops = (3 + 1) * (mv_flops + 2 * in.
Ncolor() * in.
Nspin());
79 switch (arg.kernel_type) {
86 flops_ = ghost_flops * ghost_sites;
91 long long sites = in.
Volume();
92 flops_ = (2 * num_dir * mv_flops +
95 if (arg.xpay) flops_ += xpay_flops * sites;
99 long long ghost_sites = 0;
100 for (
int d = 0; d < 4; d++)
101 if (arg.commDim[d]) ghost_sites += 2 * in.
GhostFace()[d];
102 flops_ -= ghost_flops * ghost_sites;
113 int gauge_bytes_long = arg.reconstruct * in.
Precision();
115 int spinor_bytes = 2 * in.
Ncolor() * in.
Nspin() * in.
Precision() + (isFixed ?
sizeof(float) : 0);
116 int ghost_bytes = 3 * (spinor_bytes + gauge_bytes_long) + (spinor_bytes + gauge_bytes_fat)
117 + 3 * 2 * spinor_bytes;
120 long long bytes_ = 0;
122 switch (arg.kernel_type) {
129 bytes_ = ghost_bytes * ghost_sites;
134 long long sites = in.
Volume();
135 bytes_ = (num_dir * (gauge_bytes_fat + gauge_bytes_long) +
136 num_dir * 2 * spinor_bytes +
139 if (arg.xpay) bytes_ += spinor_bytes;
143 long long ghost_sites = 0;
144 for (
int d = 0; d < 4; d++)
145 if (arg.commDim[d]) ghost_sites += 2 * in.
GhostFace()[d];
146 bytes_ -= ghost_bytes * ghost_sites;
166 constexpr
int nDim = 4;
167 constexpr
bool improved =
true;
169 StaggeredArg<Float, nColor, recon_u, recon_l, improved> arg(out, in, U, L, a, x, parity, dagger, comm_override);
173 staggered, const_cast<cudaColorSpinorField *>(static_cast<const cudaColorSpinorField *>(&in)), in.
VolumeCB(),
186 #ifdef GPU_STAGGERED_DIRAC 187 if (in.
V() == out.
V())
errorQuda(
"Aliasing pointers");
197 for (
int i = 0; i < 4; i++) {
200 "ERROR: partitioned dimension with local size less than 6 is not supported in improved staggered dslash\n");
205 instantiate<ImprovedStaggeredApply, StaggeredReconstruct>(
out,
in, L, U, a, x,
parity,
dagger, comm_override,
208 errorQuda(
"Staggered dslash has not been built");
void launch(T *f, const TuneParam &tp, Arg &arg, const cudaStream_t &stream)
void apply(const cudaStream_t &stream)
QudaVerbosity getVerbosity()
#define checkPrecision(...)
static void launch(Dslash &dslash, TuneParam &tp, Arg &arg, const cudaStream_t &stream)
const char * VolString() const
ImprovedStaggeredApply(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &L, const GaugeField &U, double a, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile)
const ColorSpinorField & in
__device__ __host__ void staggered(Arg &arg, int idx, int parity)
const int * GhostFaceCB() const
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
#define checkLocation(...)
Main header file for host and device accessors to GaugeFields.
QudaFieldLocation Location() const
Parameter structure for driving the Staggered Dslash operator.
cpuColorSpinorField * out
enum QudaReconstructType_s QudaReconstructType
static constexpr const char * kernel
void apply(const cudaStream_t &stream)
void ApplyImprovedStaggered(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, const GaugeField &L, double a, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile)
Apply the improved staggered dslash operator to a color-spinor field.
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
const int * GhostFace() const
Staggered(Arg &arg, const ColorSpinorField &out, const ColorSpinorField &in)
QudaTune getTuning()
Query whether autotuning is enabled or not. Default is enabled but can be overridden by setting QUDA_ENABLE_TUNING=0.
QudaPrecision Precision() const
QudaFieldOrder FieldOrder() const
int comm_dim_partitioned(int dim)