#include <dslash_helper.cuh>
#include <color_spinor_field_order.h>
#include <gauge_field_order.h>
#include <color_spinor.h>
#include <index_helper.cuh>
#include <gauge_field.h>

#include <dslash_policy.cuh>
#include <kernels/dslash_staggered.cuh>

/**
   This is the improved staggered (asqtad) Dirac operator
*/

namespace quda
{

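  /*
     Schematic action of the operator, as a sketch only; the exact phase and
     sign conventions live in kernels/dslash_staggered.cuh:

       out(x) = sum_{mu=0..3} [ U_fat_mu(x)  in(x+mu)   - U_fat_mu^dag(x-mu)   in(x-mu)
                              + U_long_mu(x) in(x+3*mu) - U_long_mu^dag(x-3*mu) in(x-3*mu) ]

     When xpay is enabled the result is additionally accumulated against the
     field x with scale factor a.
  */
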
  template <typename Arg> class Staggered : public Dslash<staggered, Arg>
  {
    using Dslash = Dslash<staggered, Arg>;
    using Dslash::arg;
    using Dslash::in;

  public:
    Staggered(Arg &arg, const ColorSpinorField &out, const ColorSpinorField &in) : Dslash(arg, out, in) {}

    void apply(const qudaStream_t &stream)
    {
      TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
      Dslash::setParam(tp);
      // the operator is anti-Hermitian, so a daggered application differs only
      // by an overall sign; do not instantiate separate dagger kernels
      // template arguments are <pack kernel, nParity, dagger, xpay>
      if (arg.nParity == 1) {
        if (arg.xpay)
          Dslash::template instantiate<packStaggeredShmem, 1, false, true>(tp, stream);
        else
          Dslash::template instantiate<packStaggeredShmem, 1, false, false>(tp, stream);
      } else if (arg.nParity == 2) {
        if (arg.xpay)
          Dslash::template instantiate<packStaggeredShmem, 2, false, true>(tp, stream);
        else
          Dslash::template instantiate<packStaggeredShmem, 2, false, false>(tp, stream);
      }
    }

    /*
      per direction / dimension flops
      SU(3) matrix-vector flops = (8 * Nc - 2) * Nc
      xpay = 2 * 2 * Nc * Ns

      So for the full dslash we have (the leading factor of 2 on top of the
      2 * Nd directions counts the separate fat- and long-link hops)
      flops = (2 * 2 * Nd * (8 * Nc - 2) * Nc) + ((2 * 2 * Nd - 1) * 2 * Nc * Ns)
      flops_xpay = flops + 2 * 2 * Nc * Ns

      For Asqtad this gives 1146 for Nc = 3, Ns = 1, and 1158 for the xpay equivalent
    */
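    /*
      Worked check, plugging Nc = 3, Ns = 1, Nd = 4 into the formulas above:
        matrix-vector: 2 * 2 * 4 * ((8 * 3 - 2) * 3) = 16 * 66 = 1056
        accumulation:  (2 * 2 * 4 - 1) * 2 * 3 * 1   = 15 * 6  = 90
        total = 1056 + 90 = 1146; with xpay: 1146 + 2 * 2 * 3 * 1 = 1158
    */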
    long long flops() const
    {
      int mv_flops = (8 * in.Ncolor() - 2) * in.Ncolor(); // SU(3) matrix-vector flops
      int ghost_flops = (3 + 1) * (mv_flops + 2 * in.Ncolor() * in.Nspin()); // one fat-link and three long-link hops per halo site
      int xpay_flops = 2 * 2 * in.Ncolor() * in.Nspin(); // multiply and add per real component
      int num_dir = 2 * 4; // hard-code the factor of 4 directions, since fields may be 5-d

      long long flops_ = 0;

      switch (arg.kernel_type) {
      case EXTERIOR_KERNEL_X:
      case EXTERIOR_KERNEL_Y:
      case EXTERIOR_KERNEL_Z:
      case EXTERIOR_KERNEL_T: flops_ = ghost_flops * 2 * in.GhostFace()[arg.kernel_type]; break;
      case EXTERIOR_KERNEL_ALL: {
        long long ghost_sites = 2 * (in.GhostFace()[0] + in.GhostFace()[1] + in.GhostFace()[2] + in.GhostFace()[3]);
        flops_ = ghost_flops * ghost_sites;
        break;
      }
      case INTERIOR_KERNEL:
      case UBER_KERNEL:
      case KERNEL_POLICY: {
        long long sites = in.Volume();
        flops_ = (2 * num_dir * mv_flops + // SU(3) matrix-vector multiplies
                  (2 * num_dir - 1) * 2 * in.Ncolor() * in.Nspin()) // accumulation
            * sites;
        if (arg.xpay) flops_ += xpay_flops * sites; // xpay is always done on the interior
        if (arg.kernel_type == KERNEL_POLICY) break;
        // now correct for flops done by exterior kernel
        long long ghost_sites = 0;
        for (int d = 0; d < 4; d++)
          if (arg.commDim[d]) ghost_sites += 2 * in.GhostFace()[d];
        flops_ -= ghost_flops * ghost_sites;

        break;
      }
      }
      return flops_;
    }

    long long bytes() const
    {
      // a QudaReconstructType enum value equals the number of reals stored per
      // link, so multiplying by the precision gives the bytes per gauge-link load
      int gauge_bytes_fat = QUDA_RECONSTRUCT_NO * in.Precision();
      int gauge_bytes_long = arg.reconstruct * in.Precision();
      int spinor_bytes = 2 * in.Ncolor() * in.Nspin() * in.Precision()
          + (isFixed<typename Arg::Float>::value ? sizeof(float) : 0); // fixed-point fields also load a float norm
      int ghost_bytes = 3 * (spinor_bytes + gauge_bytes_long) + (spinor_bytes + gauge_bytes_fat)
          + 3 * 2 * spinor_bytes; // last term is the accumulator load/store through the face
      int num_dir = 2 * 4; // set to 4-d since we take care of 5-d fermions in derived classes where necessary
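      // Worked example (for illustration only): for double precision, Nc = 3,
      // Ns = 1, with recon-13 chosen as one possible long-link reconstruct:
      //   gauge_bytes_fat  = 18 * 8 = 144, gauge_bytes_long = 13 * 8 = 104,
      //   spinor_bytes     = 2 * 3 * 1 * 8 = 48,
      //   ghost_bytes      = 3 * (48 + 104) + (48 + 144) + 3 * 2 * 48 = 936 per halo site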

      long long bytes_ = 0;

      switch (arg.kernel_type) {
      case EXTERIOR_KERNEL_X:
      case EXTERIOR_KERNEL_Y:
      case EXTERIOR_KERNEL_Z:
      case EXTERIOR_KERNEL_T: bytes_ = ghost_bytes * 2 * in.GhostFace()[arg.kernel_type]; break;
      case EXTERIOR_KERNEL_ALL: {
        long long ghost_sites = 2 * (in.GhostFace()[0] + in.GhostFace()[1] + in.GhostFace()[2] + in.GhostFace()[3]);
        bytes_ = ghost_bytes * ghost_sites;
        break;
      }
      case INTERIOR_KERNEL:
      case UBER_KERNEL:
      case KERNEL_POLICY: {
        long long sites = in.Volume();
        bytes_ = (num_dir * (gauge_bytes_fat + gauge_bytes_long) + // gauge reads
                  num_dir * 2 * spinor_bytes + // spinor reads
                  spinor_bytes) // spinor write
            * sites;
        if (arg.xpay) bytes_ += spinor_bytes * sites; // read of the accumulator field x
        if (arg.kernel_type == KERNEL_POLICY) break;
        // now correct for bytes done by exterior kernel
        long long ghost_sites = 0;
        for (int d = 0; d < 4; d++)
          if (arg.commDim[d]) ghost_sites += 2 * in.GhostFace()[d];
        bytes_ -= ghost_bytes * ghost_sites;

        break;
      }
      }
      return bytes_;
    }
  };

  template <typename Float, int nColor, QudaReconstructType recon_l> struct ImprovedStaggeredApply {

    inline ImprovedStaggeredApply(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &L,
                                  const GaugeField &U, double a, const ColorSpinorField &x, int parity, bool dagger,
                                  const int *comm_override, TimeProfile &profile)
    {
      constexpr int nDim = 4; // MWTODO: this probably should be 5 for mrhs Dslash
      constexpr bool improved = true;
      constexpr QudaReconstructType recon_u = QUDA_RECONSTRUCT_NO; // fat links are not unitary, so no reconstruction
      StaggeredArg<Float, nColor, nDim, recon_u, recon_l, improved> arg(out, in, U, L, a, x, parity, dagger,
                                                                        comm_override);
      Staggered<decltype(arg)> staggered(arg, out, in);

      dslash::DslashPolicyTune<decltype(staggered)> policy(
        staggered, const_cast<cudaColorSpinorField *>(static_cast<const cudaColorSpinorField *>(&in)), in.VolumeCB(),
        in.GhostFaceCB(), profile);
      policy.apply(0);
    }
  };

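  /**
     Apply the improved staggered dslash, out = D * in, with an optional xpay
     accumulation. Parameter roles are inferred from the call sites in this file:
     @param[out] out Result field
     @param[in] in Input field
     @param[in] U Fat gauge links
     @param[in] L Long gauge links
     @param[in] a Scale factor applied in the xpay accumulation
     @param[in] x Accumulator field used when xpay is enabled
     @param[in] parity Destination parity (when nParity == 1)
     @param[in] dagger Whether this is a daggered application
     @param[in] comm_override Override for which dimensions are communicated
     @param[in] profile Profile used for timing
  */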
  void ApplyImprovedStaggered(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U,
                              const GaugeField &L, double a, const ColorSpinorField &x, int parity, bool dagger,
                              const int *comm_override, TimeProfile &profile)
  {
#ifdef GPU_STAGGERED_DIRAC
    // the long links hop three sites, so a partitioned dimension needs a local extent of at least 6
    for (int i = 0; i < 4; i++) {
      if (comm_dim_partitioned(i) && (U.X()[i] < 6)) {
        errorQuda("Partitioned dimension with local size less than 6 is not supported in improved staggered dslash");
      }
    }

    // L must be the first gauge field argument since we template on the long-link reconstruct
    instantiate<ImprovedStaggeredApply, StaggeredReconstruct>(out, in, L, U, a, x, parity, dagger, comm_override,
                                                              profile);
#else
    errorQuda("Staggered dslash has not been built");
#endif
  }

} // namespace quda