1 #include <gauge_field.h>
2 #include <color_spinor_field.h>
6 #include <dslash_policy.cuh>
7 #include <kernels/dslash_ndeg_twisted_mass_preconditioned.cuh>
10 This is the preconditioned twisted-mass operator acting on a non-degenerate
17 // trait to ensure we don't instantiate asymmetric & xpay
// xpay_<B>() evaluates to !B: the primary template yields true and the <true> specialization yields false.
// NOTE(review): the template parameter is spelled `symmetric`, but the call sites visible in this file
// pass Arg::asymmetric, so the `false` result corresponds to the asymmetric operator.
18 template <bool symmetric> constexpr bool xpay_() { return true; }
19 template <> constexpr bool xpay_<true>() { return false; }
21 // trait to ensure we don't instantiate asymmetric & !dagger
// not_dagger_<B>() evaluates to B: the primary template yields false and the <true> specialization yields true.
// It is used as the third template argument of Dslash::instantiate in apply().
// NOTE(review): the template parameter is spelled `symmetric`, but the call sites visible in this file
// pass Arg::asymmetric, so the `true` result corresponds to the asymmetric operator.
22 template <bool symmetric> constexpr bool not_dagger_() { return false; }
23 template <> constexpr bool not_dagger_<true>() { return true; }
// Dslash-derived class for the preconditioned non-degenerate twisted-mass operator, parameterised on the
// kernel argument type Arg. The members visible in this file are the constructor, apply(), the two
// tuning-parameter hooks, sharedBytesPerThread() and flops().
25 template <typename Arg> class NdegTwistedMassPreconditioned : public Dslash<nDegTwistedMassPreconditioned, Arg>
// Short alias for the base class; used below to reach the base-class implementations
// (Dslash::instantiate, Dslash::initTuneParam, Dslash::defaultTuneParam, Dslash::flops).
27 using Dslash = Dslash<nDegTwistedMassPreconditioned, Arg>;
33 unsigned int sharedBytesPerThread() const
35 return shared ? 2 * in.Ncolor() * 4 * sizeof(typename mapper<typename Arg::Float>::type) : 0;
// Constructor: records whether shared memory is used and configures the y/z tuning extents.
// NOTE(review): part of the member-initializer list (and the braces) is not visible in this listing.
39 NdegTwistedMassPreconditioned(Arg &arg, const ColorSpinorField &out, const ColorSpinorField &in) :
// shared is set for the asymmetric operator and for every non-dagger application,
// i.e. it is false only for the symmetric dagger operator
41 shared(arg.asymmetric || !arg.dagger)
// y extent is 2 and z extent is arg.nParity (presumably y indexes the two flavors -- TODO confirm)
43 TunableVectorYZ::resizeVector(2, arg.nParity);
44 if (shared) TunableVectorY::resizeStep(2); // this will force flavor to be contained in the block
// Launch the operator on the given stream: obtain tuned launch parameters, reject unsupported
// argument combinations, then dispatch to one of four compile-time instantiations.
47 void apply(const qudaStream_t &stream)
49 TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
// Runtime guards: the asymmetric operator is only accepted with dagger and without xpay,
// and only single-parity application (nParity == 1) is accepted.
51 if (arg.asymmetric && !arg.dagger) errorQuda("asymmetric operator only defined for dagger");
52 if (arg.asymmetric && arg.xpay) errorQuda("asymmetric operator not defined for xpay");
53 if (arg.nParity != 1) errorQuda("Preconditioned non-degenerate twisted-mass operator not defined nParity=%d", arg.nParity);
// NOTE(review): the branch conditions that select among the four instantiations below are not
// visible in this listing. Judging by the template arguments, the first pair is the dagger case
// and the second pair the non-dagger case, each split into xpay / no-xpay -- confirm against the full file.
// The xpay_ / not_dagger_ traits replace the asymmetric+xpay and asymmetric+non-dagger combinations,
// which the guards above reject, with arguments that are already instantiated elsewhere.
57 Dslash::template instantiate<packShmem, 1, true, xpay_<Arg::asymmetric>()>(tp, stream);
59 Dslash::template instantiate<packShmem, 1, true, false>(tp, stream);
62 Dslash::template instantiate<packShmem, 1, not_dagger_<Arg::asymmetric>(), xpay_<Arg::asymmetric>()>(tp, stream);
64 Dslash::template instantiate<packShmem, 1, not_dagger_<Arg::asymmetric>(), false>(tp, stream);
68 void initTuneParam(TuneParam ¶m) const
70 Dslash::initTuneParam(param);
71 if (shared) param.shared_bytes = sharedBytesPerThread() * param.block.x * param.block.y * param.block.z;
74 void defaultTuneParam(TuneParam ¶m) const
76 Dslash::defaultTuneParam(param);
77 if (shared) param.shared_bytes = sharedBytesPerThread() * param.block.x * param.block.y * param.block.z;
// Floating-point operation count: starts from the base-class Dslash::flops() and adds the
// twist-term contribution for the kernel types handled by the switch below.
80 long long flops() const
82 long long flops = Dslash::flops();
// NOTE(review): the case labels guarding the increment (and the final return) are not visible in
// this listing; per the comment on the default branch they presumably cover the interior kernel.
83 switch (arg.kernel_type) {
87 flops += 2 * in.Ncolor() * 4 * 4 * in.Volume(); // complex * Nc * Ns * fma * vol
89 default: break; // twisted-mass flops are in the interior kernel
// Functor passed to instantiate<> by ApplyNdegTwistedMassPreconditioned. For a given
// Float / nColor / recon combination its constructor builds the kernel argument struct, wraps it in a
// NdegTwistedMassPreconditioned object and hands that to a DslashPolicyTune.
95 template <typename Float, int nColor, QudaReconstructType recon> struct NdegTwistedMassPreconditionedApply {
// All arguments except asymmetric and profile are forwarded to the NdegTwistedMassArg constructor;
// profile is passed to the DslashPolicyTune.
97 inline NdegTwistedMassPreconditionedApply(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U,
98 double a, double b, double c, bool xpay, const ColorSpinorField &x, int parity, bool dagger, bool asymmetric,
99 const int *comm_override, TimeProfile &profile)
101 constexpr int nDim = 4;
// Variant 1: NdegTwistedMassArg with its last template argument set to true.
// NOTE(review): arg / twisted / policy are declared twice, so the two variants must live in separate
// scopes; the selecting condition is not visible here -- presumably `asymmetric` picks this one. Confirm.
103 NdegTwistedMassArg<Float, nColor, nDim, recon, true> arg(out, in, U, a, b, c, xpay, x, parity, dagger, comm_override);
104 NdegTwistedMassPreconditioned<decltype(arg)> twisted(arg, out, in);
// The policy takes a non-const cudaColorSpinorField pointer, hence the const_cast of `in`.
106 dslash::DslashPolicyTune<decltype(twisted)> policy(twisted,
107 const_cast<cudaColorSpinorField *>(static_cast<const cudaColorSpinorField *>(&in)),
108 in.getDslashConstant().volume_4d_cb, in.getDslashConstant().ghostFaceCB, profile);
// Variant 2: identical construction, but with the last NdegTwistedMassArg template argument set to false.
111 NdegTwistedMassArg<Float, nColor, nDim, recon, false> arg(out, in, U, a, b, c, xpay, x, parity, dagger, comm_override);
112 NdegTwistedMassPreconditioned<decltype(arg)> twisted(arg, out, in);
114 dslash::DslashPolicyTune<decltype(twisted)> policy(twisted,
115 const_cast<cudaColorSpinorField *>(static_cast<const cudaColorSpinorField *>(&in)),
116 in.getDslashConstant().volume_4d_cb, in.getDslashConstant().ghostFaceCB, profile);
122 // Apply the non-degenerate twisted-mass Dslash operator
123 // out(x) = M*in = a*(1 + i*b*gamma_5*tau_3 + c*tau_1)*D + x
124 // Uses the kappa normalization for the Wilson operator, with a = -kappa.
125 void ApplyNdegTwistedMassPreconditioned(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U,
126 double a, double b, double c, bool xpay, const ColorSpinorField &x, int parity, bool dagger, bool asymmetric,
127 const int *comm_override, TimeProfile &profile)
129 #ifdef GPU_NDEG_TWISTED_MASS_DIRAC
130 // with symmetric dagger operator we must use kernel packing
131 if (dagger && !asymmetric) pushKernelPackT(true);
133 instantiate<NdegTwistedMassPreconditionedApply>(
134 out, in, U, a, b, c, xpay, x, parity, dagger, asymmetric, comm_override, profile);
136 if (dagger && !asymmetric) popKernelPackT();
138 errorQuda("Non-degenerate twisted-mass dslash has not been built");
139 #endif // GPU_NDEG_TWISTED_MASS_DIRAC