3 #include <dslash_helper.cuh>
4 #include <color_spinor_field_order.h>
5 #include <gauge_field_order.h>
6 #include <color_spinor.h>
7 #include <index_helper.cuh>
8 #include <gauge_field.h>
10 #include <dslash_policy.cuh>
11 #include <kernels/dslash_staggered.cuh>
14 This is a staggered Dirac operator
// Host-side launcher class for the staggered operator. It derives from
// Dslash<staggered, Arg> and forwards (arg, out, in) to that base constructor.
// NOTE(review): this listing appears to omit some lines of the original file
// (braces, access specifiers and possibly statements); confirm against the
// complete source before relying on the exact control flow shown here.
20 template <typename Arg> class Staggered : public Dslash<staggered, Arg>
// Short alias for the base class; used for the base constructor call and for
// the Dslash::template instantiate<...> calls in apply().
22 using Dslash = Dslash<staggered, Arg>;
26 Staggered(Arg &arg, const ColorSpinorField &out, const ColorSpinorField &in) : Dslash(arg, out, in) {}
// Obtains launch parameters from tuneLaunch() and then dispatches to a
// compile-time instantiation chosen from arg.nParity (1 or 2), passing the
// tuned parameters and the caller's stream.
28 void apply(const qudaStream_t &stream)
30 TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
32 // operator is anti-Hermitian so do not instantiate dagger
33 if (arg.nParity == 1) {
// NOTE(review): the two calls below differ only in the last template argument
// (true / false). Presumably a condition selecting between them sits on lines
// that are not visible in this listing -- TODO confirm. The third template
// argument is always false, which looks like the "no dagger" choice described
// in the comment above -- verify against Dslash::instantiate.
35 Dslash::template instantiate<packStaggeredShmem, 1, false, true>(tp, stream);
37 Dslash::template instantiate<packStaggeredShmem, 1, false, false>(tp, stream);
38 } else if (arg.nParity == 2) {
// Same pair of instantiations as above, with the second template argument set to 2.
40 Dslash::template instantiate<packStaggeredShmem, 2, false, true>(tp, stream);
42 Dslash::template instantiate<packStaggeredShmem, 2, false, false>(tp, stream);
// Functor whose constructor sets up the staggered operator for one combination
// of the template parameters Float, nColor and recon_u. It branches on
// U.StaggeredPhase():
//   - QUDA_STAGGERED_PHASE_MILC: code compiled only under BUILD_MILC_INTERFACE
//   - QUDA_STAGGERED_PHASE_TIFR: code compiled only under BUILD_TIFR_INTERFACE
//   - anything else: errorQuda reporting the unsupported phase type
// In each supported branch it builds a StaggeredArg (nDim = 4, improved = false,
// with U passed for both gauge-field arguments), wraps it in a Staggered
// launcher and constructs a dslash::DslashPolicyTune object from it.
// NOTE(review): this listing appears to omit some lines of the original file
// (e.g. the end of the parameter list that declares `profile`, the
// #else/#endif directives, braces, and possibly statements following the
// policy construction); confirm against the complete source.
47 template <typename Float, int nColor, QudaReconstructType recon_u> struct StaggeredApply {
49 inline StaggeredApply(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a,
50 const ColorSpinorField &x, int parity, bool dagger, const int *comm_override,
53 if (U.StaggeredPhase() == QUDA_STAGGERED_PHASE_MILC) {
54 #ifdef BUILD_MILC_INTERFACE
55 constexpr int nDim = 4; // MWTODO: this probably should be 5 for mrhs Dslash
56 constexpr bool improved = false;
// U is passed twice: the unimproved operator has no separate second gauge
// field, so the same field fills both slots (presumably the "long link" slot
// is unused when improved == false -- TODO confirm in StaggeredArg).
58 StaggeredArg<Float, nColor, nDim, recon_u, QUDA_RECONSTRUCT_NO, improved, QUDA_STAGGERED_PHASE_MILC> arg(
59 out, in, U, U, a, x, parity, dagger, comm_override);
60 Staggered<decltype(arg)> staggered(arg, out, in);
// The const_cast strips constness from `in` because DslashPolicyTune takes a
// non-const cudaColorSpinorField pointer.
62 dslash::DslashPolicyTune<decltype(staggered)> policy(
63 staggered, const_cast<cudaColorSpinorField *>(static_cast<const cudaColorSpinorField *>(&in)), in.VolumeCB(),
64 in.GhostFaceCB(), profile);
// Reported when BUILD_MILC_INTERFACE is not defined (presumably the #else
// branch; the directive itself is not visible in this listing).
67 errorQuda("MILC interface has not been built so MILC phase staggered fermions not enabled");
69 } else if (U.StaggeredPhase() == QUDA_STAGGERED_PHASE_TIFR) {
70 #ifdef BUILD_TIFR_INTERFACE
71 constexpr int nDim = 4; // MWTODO: this probably should be 5 for mrhs Dslash
72 constexpr bool improved = false;
// Same construction as the MILC branch, with the TIFR phase convention as the
// last StaggeredArg template argument.
74 StaggeredArg<Float, nColor, nDim, recon_u, QUDA_RECONSTRUCT_NO, improved, QUDA_STAGGERED_PHASE_TIFR> arg(
75 out, in, U, U, a, x, parity, dagger, comm_override);
76 Staggered<decltype(arg)> staggered(arg, out, in);
78 dslash::DslashPolicyTune<decltype(staggered)> policy(
79 staggered, const_cast<cudaColorSpinorField *>(static_cast<const cudaColorSpinorField *>(&in)), in.VolumeCB(),
80 in.GhostFaceCB(), profile);
// Reported when BUILD_TIFR_INTERFACE is not defined. The message wording now
// matches the MILC branch ("staggered", previously misspelled "taggered").
83 errorQuda("TIFR interface has not been built so TIFR phase staggered fermions not enabled");
// Fallback for any phase convention other than MILC or TIFR.
86 errorQuda("Unsupported staggered phase type %d", U.StaggeredPhase());
// Entry point for applying the staggered operator.
// When GPU_STAGGERED_DIRAC is defined, every argument is forwarded unchanged to
// instantiate<StaggeredApply, StaggeredReconstruct>, which selects the concrete
// StaggeredApply specialisation (presumably from the precision, colour count
// and reconstruction type of the fields -- TODO confirm in instantiate).
// The errorQuda line reports that the staggered dslash was not built; it is
// presumably the #else branch of the guard, whose directive is not visible in
// this listing -- TODO confirm.
//
// Parameters, as they appear in the signature:
//   out, in        output and input ColorSpinorField
//   U              gauge field
//   a, x           scalar and additional ColorSpinorField passed through to the
//                  kernel argument struct (presumably the xpay term out = a*x + D*in
//                  -- verify against StaggeredArg)
//   parity         parity argument passed through
//   dagger         dagger flag passed through
//   comm_override  pointer to communication override flags passed through
//   profile        TimeProfile handed to the dslash policy
91 void ApplyStaggered(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a,
92 const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile)
94 #ifdef GPU_STAGGERED_DIRAC
95 instantiate<StaggeredApply, StaggeredReconstruct>(out, in, U, a, x, parity, dagger, comm_override, profile);
97 errorQuda("Staggered dslash has not been built");