1 #include <gauge_field.h>
2 #include <color_spinor_field.h>
6 #include <dslash_policy.cuh>
7 #include <kernels/dslash_twisted_mass.cuh>
10 This is the basic gauged twisted-mass operator
// TwistedMass: driver class for the twisted-mass operator. It derives from the
// generic Dslash driver specialised on the twistedMass kernel functor and is
// templated on the kernel argument struct Arg.
16 template <typename Arg> class TwistedMass : public Dslash<twistedMass, Arg>
// Local alias for the base class so its members can be referred to as Dslash::...
18 using Dslash = Dslash<twistedMass, Arg>;
// Constructor: forwards the argument struct and the output/input fields
// unchanged to the base-class constructor; it does no work of its own.
23 TwistedMass(Arg &arg, const ColorSpinorField &out, const ColorSpinorField &in) : Dslash(arg, out, in) {}
// apply: launch the operator on the supplied stream.
25 void apply(const qudaStream_t &stream)
// Ask tuneLaunch for the launch parameters (TuneParam) to use for this object.
27 TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
// Instantiate/launch through the base class with template arguments <packShmem, true>.
// NOTE(review): the second template argument is presumably the xpay flag, given
// the error message on the next line - confirm against Dslash::instantiate.
30 Dslash::template instantiate<packShmem, true>(tp, stream);
// Error path: reports that only the xpay=true form of this operator exists.
// NOTE(review): the condition that selects between the launch above and this
// error is not visible in this excerpt.
32 errorQuda("Twisted-mass operator only defined for xpay=true");
// flops: flop count reported for this operator.
35 long long flops() const
// Start from the count reported by the base class ...
37 long long flops = Dslash::flops();
// ... and add the extra term below depending on the kernel type. The case
// labels that select the addition are not visible in this excerpt.
38 switch (arg.kernel_type) {
42 flops += 2 * in.Ncolor() * 4 * 2 * in.Volume(); // complex * Nc * Ns * fma * vol
44 default: break; // twisted-mass flops are in the interior kernel
// TwistedMassApply: functor-style struct whose constructor performs the whole
// operator application for one (Float, nColor, recon) combination. It is the
// template handed to instantiate<> by ApplyTwistedMass.
50 template <typename Float, int nColor, QudaReconstructType recon> struct TwistedMassApply {
// Constructor arguments are forwarded as-is into the kernel argument struct
// below. NOTE(review): `profile`, used further down, is declared on a
// signature line that is not visible in this excerpt.
52 inline TwistedMassApply(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a, double b,
53 const ColorSpinorField &x, int parity, bool dagger, const int *comm_override,
// The dimension template parameter of the argument struct is fixed to 4 here.
56 constexpr int nDim = 4;
// Bundle the fields and scalar parameters into the kernel argument struct.
57 TwistedMassArg<Float, nColor, nDim, recon> arg(out, in, U, a, b, x, parity, dagger, comm_override);
// Build the driver object for exactly this argument type.
58 TwistedMass<decltype(arg)> twisted(arg, out, in);
// Construct the policy tuner around the driver. The input field is handed over
// as a non-const cudaColorSpinorField pointer, hence the static_cast/const_cast
// chain on &in; the checkerboarded volume and ghost-face sizes of `in` and the
// profile are passed alongside.
60 dslash::DslashPolicyTune<decltype(twisted)> policy(
61 twisted, const_cast<cudaColorSpinorField *>(static_cast<const cudaColorSpinorField *>(&in)), in.VolumeCB(),
62 in.GhostFaceCB(), profile);
67 // Apply the twisted-mass Dslash operator
68 // out(x) = M*in = (1 + i*b*gamma_5)*in(x) + a*\sum_\mu [ U_\mu(x) in(x+\mu) + U^\dagger_\mu(x-\mu) in(x-\mu) ]
69 // Uses the kappa normalization for the Wilson operator, with a = -kappa.
// ApplyTwistedMass: public entry point for the twisted-mass operator. All
// arguments are passed straight through to TwistedMassApply via instantiate<>.
// NOTE(review): the final parameter (`profile`, used in the call below) is
// declared on a signature line that is not visible in this excerpt.
70 void ApplyTwistedMass(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a, double b,
71 const ColorSpinorField &x, int parity, bool dagger, const int *comm_override,
// The operator is only compiled in when GPU_TWISTED_MASS_DIRAC is defined.
74 #ifdef GPU_TWISTED_MASS_DIRAC
// Dispatch to the TwistedMassApply specialisation chosen by instantiate<>.
75 instantiate<TwistedMassApply>(out, in, U, a, b, x, parity, dagger, comm_override, profile);
// Error reported when the operator was not built. NOTE(review): the #else that
// presumably separates this line from the call above is not visible here.
77 errorQuda("Twisted-mass dslash has not been built");
78 #endif // GPU_TWISTED_MASS_DIRAC