1 #include <gauge_field.h>
2 #include <color_spinor_field.h>
6 #include <dslash_policy.cuh>
7 #include <kernels/dslash_ndeg_twisted_mass.cuh>
10 This is the gauged twisted-mass operator acting on a non-degenerate
// Host-side driver class for the non-degenerate twisted-mass dslash. It derives from
// Dslash<nDegTwistedMass, Arg> and provides a constructor, a launch routine (apply)
// and a flop-count routine (flops).
// NOTE(review): several lines of this class (braces, access specifiers, case labels,
// branch conditions, return statements) are not visible in this view of the file;
// confirm the notes below against the full source.
17 template <typename Arg> class NdegTwistedMass : public Dslash<nDegTwistedMass, Arg>
// Local alias so the templated base class can be named simply `Dslash` below.
19 using Dslash = Dslash<nDegTwistedMass, Arg>;
// Constructor: forwards arg, out and in unchanged to the base-class constructor.
24 NdegTwistedMass(Arg &arg, const ColorSpinorField &out, const ColorSpinorField &in) : Dslash(arg, out, in)
// Resizes the tunable vector dimensions to (2, arg.nParity).
// NOTE(review): the literal 2 presumably corresponds to the two flavors of the doublet - TODO confirm.
26 TunableVectorYZ::resizeVector(2, arg.nParity);
// Launches the operator on the supplied stream.
29 void apply(const qudaStream_t &stream)
// Obtain launch parameters for this object from tuneLaunch, using the current tuning and verbosity settings.
31 TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
// Instantiate via the base class with template arguments <packShmem, true>.
// NOTE(review): the second argument is presumably the xpay flag, and this call is presumably
// guarded by a check on arg.xpay that is not visible here - confirm.
34 Dslash::template instantiate<packShmem, true>(tp, stream);
// Error path: the operator is reported as only defined for xpay=true.
36 errorQuda("Non-degenerate twisted-mass operator only defined for xpay=true");
// Returns the flop count: the base-class count plus an extra twisted-mass term
// that is added depending on arg.kernel_type.
39 long long flops() const
41 long long flops = Dslash::flops();
// NOTE(review): the case labels selecting which kernel types receive the extra term are not visible here.
42 switch (arg.kernel_type) {
// Extra term: 2 * Ncolor * 4 * 4 * Volume, with factors as labelled in the trailing comment.
46 flops += 2 * in.Ncolor() * 4 * 4 * in.Volume(); // complex * Nc * Ns * fma * vol
// All other kernel types add nothing.
48 default: break; // twisted-mass flops are in the interior kernel
// Helper struct templated on precision (Float), number of colors (nColor) and gauge
// reconstruction type (recon). Its constructor builds the kernel argument struct,
// the NdegTwistedMass driver and a DslashPolicyTune policy object.
// NOTE(review): the closing lines of this struct, including any call that actually applies
// the policy, are not visible in this view of the file - confirm against the full source.
54 template <typename Float, int nColor, QudaReconstructType recon> struct NdegTwistedMassApply {
// Constructor: out, in, U, a, b, c, x, parity, dagger and comm_override are forwarded to the
// argument struct; profile is forwarded to the policy object.
// NOTE(review): the meaning of the scalar coefficients a, b, c is defined by NdegTwistedMassArg,
// which is not shown here.
56 inline NdegTwistedMassApply(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a,
57 double b, double c, const ColorSpinorField &x, int parity, bool dagger,
58 const int *comm_override, TimeProfile &profile)
// The dimension template argument is fixed to 4.
60 constexpr int nDim = 4;
// Build the kernel argument struct from the inputs.
61 NdegTwistedMassArg<Float, nColor, nDim, recon> arg(out, in, U, a, b, c, x, parity, dagger, comm_override);
// Wrap the argument struct in the driver class; decltype keeps the template types in sync.
62 NdegTwistedMass<decltype(arg)> twisted(arg, out, in);
// Construct the policy object. `in` is cast to cudaColorSpinorField and its constness is
// removed because the policy constructor takes a non-const pointer.
// The checkerboarded 4-d volume and ghost-face sizes are read from in.getDslashConstant().
64 dslash::DslashPolicyTune<decltype(twisted)> policy(
65 twisted, const_cast<cudaColorSpinorField *>(static_cast<const cudaColorSpinorField *>(&in)),
66 in.getDslashConstant().volume_4d_cb, in.getDslashConstant().ghostFaceCB, profile);
// Entry point for applying the non-degenerate twisted-mass dslash.
// When GPU_NDEG_TWISTED_MASS_DIRAC is defined, all arguments are forwarded to
// instantiate<NdegTwistedMassApply>; the errorQuda call below reports that the operator
// has not been built.
// NOTE(review): the final parameter line (presumably `TimeProfile &profile`, since `profile`
// is used below), the `#else` separating the two branches and the function braces are not
// visible in this view of the file - confirm against the full source.
71 void ApplyNdegTwistedMass(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a, double b,
72 double c, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override,
75 #ifdef GPU_NDEG_TWISTED_MASS_DIRAC
// Dispatch to NdegTwistedMassApply; the template parameters are chosen inside instantiate, which is not shown here.
76 instantiate<NdegTwistedMassApply>(out, in, U, a, b, c, x, parity, dagger, comm_override, profile);
// Error reported when support for this operator was not compiled in.
78 errorQuda("Non-degenerate twisted-mass dslash has not been built");
79 #endif // GPU_NDEG_TWISTED_MASS_DIRAC