1 #include <gauge_field.h>
2 #include <color_spinor_field.h>
3 #include <clover_field.h>
7 #include <dslash_policy.cuh>
8 #include <kernels/dslash_wilson_clover.cuh>
11 This is the Wilson-clover linear operator
// Tunable launcher for the wilsonClover kernel family, parameterized on the kernel argument struct Arg.
// It derives from Dslash<wilsonClover, Arg> and layers clover-specific flop/byte accounting on top of
// the base-class counts.
// NOTE(review): this listing appears to omit some short lines (braces, access specifiers, conditionals,
// return statements, some case labels); the comments below describe only the statements visible here.
17 template <typename Arg> class WilsonClover : public Dslash<wilsonClover, Arg>
// Shorthand for the base class so its members can be referred to as Dslash::...
19 using Dslash = Dslash<wilsonClover, Arg>;
// Constructor: forwards the kernel argument struct and the output/input fields to the base class unchanged.
24 WilsonClover(Arg &arg, const ColorSpinorField &out, const ColorSpinorField &in) : Dslash(arg, out, in) {}
// Launch entry point: takes the stream to launch on.
26 void apply(const qudaStream_t &stream)
// Obtain the launch parameters for this object from tuneLaunch, using the current tuning and verbosity settings.
28 TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
// Instantiate the kernel through the base class with template arguments <packShmem, true>.
// NOTE(review): the second template argument is presumably the xpay flag, given the error text below - confirm.
31 Dslash::template instantiate<packShmem, true>(tp, stream);
// Error path: the message states the operator is only defined for xpay=true.
// NOTE(review): the condition selecting between the launch above and this error is not visible in this listing.
33 errorQuda("Wilson-clover operator only defined for xpay=true");
// Flop count: starts from the base-class count and adds the clover contribution for selected kernel types.
36 long long flops() const
// Clover flop constant; it is multiplied by in.Volume() below.
// NOTE(review): presumably the per-site cost of applying the clover term - confirm the origin of 504.
38 int clover_flops = 504;
39 long long flops = Dslash::flops();
41 switch (arg.kernel_type) {
// KERNEL_POLICY adds clover_flops for every unit of in.Volume().
// NOTE(review): other case labels may fall through to this line but are not visible here.
44 case KERNEL_POLICY: flops += clover_flops * in.Volume(); break;
45 default: break; // all clover flops are in the interior kernel
// Byte count: starts from the base-class count and adds the clover-field traffic for selected kernel types.
50 long long bytes() const
// Clover byte constant: 72 * in.Precision(), plus 2 * sizeof(float) when Arg::Float is a fixed-point type
// (isFixed<...>::value). It is multiplied by in.Volume() below.
// NOTE(review): 72 is presumably the number of real values in the clover matrix per site, and the extra two
// floats presumably cover fixed-point norm data - confirm both.
52 int clover_bytes = 72 * in.Precision() + (isFixed<typename Arg::Float>::value ? 2 * sizeof(float) : 0);
53 long long bytes = Dslash::bytes();
55 switch (arg.kernel_type) {
// KERNEL_POLICY adds clover_bytes for every unit of in.Volume().
// NOTE(review): other case labels may fall through to this line but are not visible here.
57 case KERNEL_POLICY: bytes += clover_bytes * in.Volume(); break;
// Functor-style struct whose constructor sets up the Wilson-clover operator for one combination of
// floating-point type (Float), colour count (nColor) and gauge reconstruction type (recon).
// NOTE(review): this listing appears to omit some short lines; any statement following the policy
// construction (for example the call that actually runs the policy) is not visible here.
65 template <typename Float, int nColor, QudaReconstructType recon> struct WilsonCloverApply {
// Parameters, as forwarded below: out/in are the output and input spinor fields, U the gauge field,
// A the clover field, a a scalar coefficient, x an additional spinor field; parity, dagger and
// comm_override are passed through to the argument struct; profile is handed to the policy object.
67 inline WilsonCloverApply(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, const CloverField &A,
68 double a, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile)
// The operator is set up for four dimensions.
70 constexpr int nDim = 4;
// Build the kernel argument struct. The literal 0.0 fills the parameter slot directly after `a`.
// NOTE(review): that slot is presumably the twist coefficient, which is unused without twist - confirm.
71 WilsonCloverArg<Float, nColor, nDim, recon> arg(out, in, U, A, a, 0.0, x, parity, dagger, comm_override);
// Wrap the argument struct in the launcher class; decltype keeps the two types in sync.
72 WilsonClover<decltype(arg)> wilson(arg, out, in);
// Construct the policy object from the launcher, the input field, its checkerboard volume, its ghost face
// sizes and the profiler.
// The static_cast is an unchecked downcast and the const_cast removes constness, so `in` must really be a
// cudaColorSpinorField.
// NOTE(review): confirm that the policy only modifies ghost/communication buffers of `in`.
74 dslash::DslashPolicyTune<decltype(wilson)> policy(wilson,
75 const_cast<cudaColorSpinorField *>(static_cast<const cudaColorSpinorField *>(&in)), in.VolumeCB(),
76 in.GhostFaceCB(), profile);
// Functor-style struct whose constructor sets up the Wilson-clover operator with an additional scalar
// coefficient b, for one combination of floating-point type (Float), colour count (nColor) and gauge
// reconstruction type (recon).
// NOTE(review): this listing appears to omit some short lines; any statement following the policy
// construction (for example the call that actually runs the policy) is not visible here.
81 template <typename Float, int nColor, QudaReconstructType recon> struct WilsonCloverWithTwistApply {
// Parameters, as forwarded below: out/in are the output and input spinor fields, U the gauge field,
// A the clover field, a and b scalar coefficients, x an additional spinor field; parity, dagger and
// comm_override are passed through to the argument struct; profile is handed to the policy object.
// NOTE(review): b is presumably the twist coefficient, going by the struct name - confirm.
83 inline WilsonCloverWithTwistApply(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U,
84 const CloverField &A, double a, double b, const ColorSpinorField &x, int parity,
85 bool dagger, const int *comm_override, TimeProfile &profile)
// The operator is set up for four dimensions.
87 constexpr int nDim = 4;
// Build the kernel argument struct with a fifth template argument of `true` and both coefficients a and b.
// NOTE(review): the `true` template argument presumably enables the twist term in the kernel - confirm.
88 WilsonCloverArg<Float, nColor, nDim, recon, true> arg(out, in, U, A, a, b, x, parity, dagger, comm_override);
// Wrap the argument struct in the launcher class; decltype keeps the two types in sync.
89 WilsonClover<decltype(arg)> wilson(arg, out, in);
// Construct the policy object from the launcher, the input field, its checkerboard volume, its ghost face
// sizes and the profiler.
// The static_cast is an unchecked downcast and the const_cast removes constness, so `in` must really be a
// cudaColorSpinorField.
// NOTE(review): confirm that the policy only modifies ghost/communication buffers of `in`.
91 dslash::DslashPolicyTune<decltype(wilson)> policy(
92 wilson, const_cast<cudaColorSpinorField *>(static_cast<const cudaColorSpinorField *>(&in)), in.VolumeCB(),
93 in.GhostFaceCB(), profile);
98 // Apply the Wilson-clover operator
99 // out(x) = M*in = (A(x) + a * \sum_mu [U_{-\mu}(x)in(x+mu) + U^\dagger_mu(x-mu)in(x-mu)])
100 // Uses the kappa normalization for the Wilson operator.
101 void ApplyWilsonClover(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, const CloverField &A,
102 double a, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile)
104 #ifdef GPU_CLOVER_DIRAC
105 instantiate<WilsonCloverApply>(out, in, U, A, a, x, parity, dagger, comm_override, profile);
107 errorQuda("Clover dslash has not been built");