1 #ifndef USE_LEGACY_DSLASH
3 #include <gauge_field.h>
4 #include <color_spinor_field.h>
5 #include <clover_field.h>
9 #include <dslash_policy.cuh>
10 #include <kernels/dslash_wilson_clover_hasenbusch_twist_preconditioned.cuh>
/* ***************************
 * No Clov Inv: 1 - k^2 D - i mu gamma_5 A
 * **************************/
/**
   @brief Dslash driver for the preconditioned ("PC") Hasenbusch-twisted
   Wilson-clover operator in the variant that does NOT apply A^{-1} to the
   hopping term (see the formula above).

   Only single-parity application (arg.nParity == 1) with xpay enabled is
   supported; any other configuration is rejected with errorQuda.
*/
template <typename Arg>
class WilsonCloverHasenbuschTwistPCNoClovInv : public Dslash<cloverHasenbuschPreconditioned, Arg>
{
  using Dslash = Dslash<cloverHasenbuschPreconditioned, Arg>;
  // members of the dependent base must be brought into scope explicitly
  using Dslash::arg;
  using Dslash::in;

public:
  WilsonCloverHasenbuschTwistPCNoClovInv(Arg &arg, const ColorSpinorField &out, const ColorSpinorField &in) :
    Dslash(arg, out, in)
  {
  }

  /**
     @brief Tune, configure and launch the kernel on the given stream.
     @param[in] stream Stream to launch on
  */
  void apply(const qudaStream_t &stream)
  {
    TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
    // Push the tuned parameters into the Dslash before launching, exactly as
    // WilsonCloverHasenbuschTwistPCClovInv::apply does.
    Dslash::setParam(tp);

    // specialize here to constrain the template instantiation
    if (arg.nParity == 1) {
      if (arg.xpay) {
        Dslash::template instantiate<packShmem, 1, true>(tp, stream);
      } else {
        errorQuda("Operator only defined for xpay=true");
      }
    } else {
      errorQuda("Operator not defined nParity=%d", arg.nParity);
    }
  }

  /**
     @brief Flop count: base Dslash flops plus, per site, one clover term
     (clover_flops) and 48 flops for the - mu (i gamma_5) A term.
     There is no A^{-1} D in this variant, only D.
     The per-site constants are long long so that products with Volume()
     and GhostFace() are evaluated in 64-bit and cannot overflow int.
  */
  long long flops() const
  {
    constexpr long long clover_flops = 504;
    constexpr long long site_flops = clover_flops + 48;

    long long flops = Dslash::flops();
    switch (arg.kernel_type) {
    case EXTERIOR_KERNEL_X:
    case EXTERIOR_KERNEL_Y:
    case EXTERIOR_KERNEL_Z:
    case EXTERIOR_KERNEL_T:
      // 2 from fwd / back face
      flops += 2 * site_flops * in.GhostFace()[arg.kernel_type];
      break;
    case EXTERIOR_KERNEL_ALL:
      flops += 2 * site_flops * totalGhostFaceSites();
      break;
    case INTERIOR_KERNEL:
    case KERNEL_POLICY: {
      flops += site_flops * in.Volume();

      if (arg.kernel_type == KERNEL_POLICY) break;
      // now correct for flops done by exterior kernel
      flops -= site_flops * commGhostSites();
      break;
    }
    }
    return flops;
  }

  /**
     @brief Byte count: base Dslash bytes plus one clover term per site.
     A clover term is 72 elements of in.Precision() bytes; fixed-point
     storage (isFixed) adds 2 * sizeof(float) bytes per site (presumably
     the norm/scale values -- TODO confirm).
  */
  long long bytes() const
  {
    const long long clover_bytes
      = 72 * in.Precision() + (isFixed<typename Arg::Float>::value ? 2 * sizeof(float) : 0);

    long long bytes = Dslash::bytes();
    switch (arg.kernel_type) {
    case EXTERIOR_KERNEL_X:
    case EXTERIOR_KERNEL_Y:
    case EXTERIOR_KERNEL_Z:
    case EXTERIOR_KERNEL_T:
      // Factor of 2 is from the fwd/back faces.
      bytes += clover_bytes * 2 * in.GhostFace()[arg.kernel_type];
      break;
    case EXTERIOR_KERNEL_ALL:
      // Factor of 2 is from the fwd/back faces
      bytes += clover_bytes * 2 * totalGhostFaceSites();
      break;
    case INTERIOR_KERNEL:
    case KERNEL_POLICY: {
      bytes += clover_bytes * in.Volume();

      if (arg.kernel_type == KERNEL_POLICY) break;
      // now correct for bytes done by exterior kernel
      bytes -= clover_bytes * commGhostSites();
      break;
    }
    }
    return bytes;
  }

private:
  /** @return sum of in.GhostFace()[d] over all four dimensions */
  long long totalGhostFaceSites() const
  {
    long long sites = 0;
    for (int d = 0; d < 4; d++) sites += in.GhostFace()[d];
    return sites;
  }

  /**
     @return 2 * in.GhostFace()[d] summed over the dimensions with
     arg.commDim[d] set, i.e. the fwd + back ghost sites whose work is
     accounted to the exterior kernels
  */
  long long commGhostSites() const
  {
    long long sites = 0;
    for (int d = 0; d < 4; d++)
      if (arg.commDim[d]) sites += 2 * in.GhostFace()[d];
    return sites;
  }
};
/**
   @brief Functor dispatched by instantiate<>() for the variant without the
   clover inverse. It builds the kernel argument (the final template flag
   false selects this variant; the ClovInv functor passes true), wraps it in
   WilsonCloverHasenbuschTwistPCNoClovInv, and runs it through the dslash
   policy tuner.
*/
template <typename Float, int nColor, QudaReconstructType recon> struct WilsonCloverHasenbuschTwistPCNoClovInvApply {

  inline WilsonCloverHasenbuschTwistPCNoClovInvApply(ColorSpinorField &out, const ColorSpinorField &in,
                                                     const GaugeField &U, const CloverField &A, double a, double b,
                                                     const ColorSpinorField &x, int parity, bool dagger,
                                                     const int *comm_override, TimeProfile &profile)
  {
    constexpr int nDim = 4;
    using ArgType = WilsonCloverHasenbuschTwistPCArg<Float, nColor, nDim, recon, false>;

    ArgType arg(out, in, U, A, a, b, x, parity, dagger, comm_override);
    WilsonCloverHasenbuschTwistPCNoClovInv<ArgType> wilson(arg, out, in);

    // NOTE(review): static_cast assumes `in` really is a cudaColorSpinorField (unchecked)
    dslash::DslashPolicyTune<decltype(wilson)> policy(
      wilson, const_cast<cudaColorSpinorField *>(static_cast<const cudaColorSpinorField *>(&in)), in.VolumeCB(),
      in.GhostFaceCB(), profile);
    // Run the operator through the selected policy (stream 0); without this
    // call the constructed policy object is simply discarded.
    policy.apply(0);
  }
};
// Apply the preconditioned Hasenbusch-twisted Wilson-clover operator in the
// variant without the clover inverse on the hopping term:
//   1 - k^2 D - i mu gamma_5 A
// a and b are forwarded to WilsonCloverHasenbuschTwistPCArg in the positions
// that the ClovInv variant names kappa and mu (presumably the same meaning
// here -- TODO confirm).
// Uses the kappa normalization for the Wilson operator.
// Errors out if in/out alias, if their field orders differ, or if the
// precisions/locations of out, in, U and A do not match.
void ApplyWilsonCloverHasenbuschTwistPCNoClovInv(ColorSpinorField &out, const ColorSpinorField &in,
                                                 const GaugeField &U, const CloverField &A, double a, double b,
                                                 const ColorSpinorField &x, int parity, bool dagger,
                                                 const int *comm_override, TimeProfile &profile)
{
#ifdef GPU_CLOVER_HASENBUSCH_TWIST
  if (in.V() == out.V()) errorQuda("Aliasing pointers");
  if (in.FieldOrder() != out.FieldOrder())
    errorQuda("Field order mismatch in = %d, out = %d", in.FieldOrder(), out.FieldOrder());

  // check all precisions match
  checkPrecision(out, in, U, A);

  // check all locations match
  checkLocation(out, in, U, A);

  instantiate<WilsonCloverHasenbuschTwistPCNoClovInvApply>(out, in, U, A, a, b, x, parity, dagger, comm_override,
                                                           profile);
#else
  errorQuda("Clover Hasenbusch Twist dslash has not been built");
#endif
}
/* ***************************
 * Clov Inv:
 * M = psi_p - k^2 A^{-1} D_p\not{p} - i mu gamma_5 A_{pp} psi_{p}
 * **************************/
/**
   @brief Dslash driver for the clover-inverse variant of the preconditioned
   ("PC") Hasenbusch-twisted Wilson-clover operator (formula above): the
   hopping term carries A^{-1}, and the - i mu gamma_5 A term is added.

   Only single-parity application (arg.nParity == 1) with xpay enabled is
   supported; any other configuration is rejected with errorQuda.
*/
template <typename Arg>
class WilsonCloverHasenbuschTwistPCClovInv : public Dslash<cloverHasenbuschPreconditioned, Arg>
{
  using Dslash = Dslash<cloverHasenbuschPreconditioned, Arg>;
  // members of the dependent base must be brought into scope explicitly
  using Dslash::arg;
  using Dslash::in;

public:
  WilsonCloverHasenbuschTwistPCClovInv(Arg &arg, const ColorSpinorField &out, const ColorSpinorField &in) :
    Dslash(arg, out, in)
  {
  }

  /**
     @brief Tune, configure and launch the kernel on the given stream.
     @param[in] stream Stream to launch on
  */
  void apply(const qudaStream_t &stream)
  {
    TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
    Dslash::setParam(tp);

    // specialize here to constrain the template instantiation
    if (arg.nParity == 1) {
      if (arg.xpay) {
        Dslash::template instantiate<packShmem, 1, true>(tp, stream);
      } else {
        errorQuda("Operator only defined for xpay=true");
      }
    } else {
      errorQuda("Operator not defined nParity=%d", arg.nParity);
    }
  }

  /**
     @brief Flop count: base Dslash flops plus, per site, two clover terms
     (one from A^{-1} D, one inside the - mu (i gamma_5) A term) and 48 flops
     for the - mu (i gamma_5) A term itself.
     The per-site constants are long long so that products with Volume()
     and GhostFace() are evaluated in 64-bit and cannot overflow int.
  */
  long long flops() const
  {
    constexpr long long clover_flops = 504;
    constexpr long long site_flops = 2 * clover_flops + 48;

    long long flops = Dslash::flops();
    switch (arg.kernel_type) {
    case EXTERIOR_KERNEL_X:
    case EXTERIOR_KERNEL_Y:
    case EXTERIOR_KERNEL_Z:
    case EXTERIOR_KERNEL_T:
      // 2 from fwd / back face
      flops += 2 * site_flops * in.GhostFace()[arg.kernel_type];
      break;
    case EXTERIOR_KERNEL_ALL:
      flops += 2 * site_flops * totalGhostFaceSites();
      break;
    case INTERIOR_KERNEL:
    case KERNEL_POLICY: {
      flops += site_flops * in.Volume();

      if (arg.kernel_type == KERNEL_POLICY) break;
      // now correct for flops done by exterior kernel
      flops -= site_flops * commGhostSites();
      break;
    }
    }
    return flops;
  }

  /**
     @brief Byte count: base Dslash bytes plus the clover data read per site.
     A clover term is 72 elements of in.Precision() bytes; fixed-point
     storage (isFixed) adds 2 * sizeof(float) bytes per site (presumably
     the norm/scale values -- TODO confirm).
     With arg.dynamic_clover only A is read (even for A^{-1}); otherwise
     both A and A^{-1} are read.
  */
  long long bytes() const
  {
    const long long clover_bytes
      = 72 * in.Precision() + (isFixed<typename Arg::Float>::value ? 2 * sizeof(float) : 0);
    const long long dyn_factor = arg.dynamic_clover ? 1 : 2;
    const long long site_bytes = dyn_factor * clover_bytes;

    long long bytes = Dslash::bytes();
    switch (arg.kernel_type) {
    case EXTERIOR_KERNEL_X:
    case EXTERIOR_KERNEL_Y:
    case EXTERIOR_KERNEL_Z:
    case EXTERIOR_KERNEL_T:
      // Factor of 2 is from the fwd/back faces.
      bytes += site_bytes * 2 * in.GhostFace()[arg.kernel_type];
      break;
    case EXTERIOR_KERNEL_ALL:
      // Factor of 2 is from the fwd/back faces
      bytes += site_bytes * 2 * totalGhostFaceSites();
      break;
    case INTERIOR_KERNEL:
    case KERNEL_POLICY: {
      bytes += site_bytes * in.Volume();

      if (arg.kernel_type == KERNEL_POLICY) break;
      // now correct for bytes done by exterior kernel
      bytes -= site_bytes * commGhostSites();
      break;
    }
    }
    return bytes;
  }

private:
  /** @return sum of in.GhostFace()[d] over all four dimensions */
  long long totalGhostFaceSites() const
  {
    long long sites = 0;
    for (int d = 0; d < 4; d++) sites += in.GhostFace()[d];
    return sites;
  }

  /**
     @return 2 * in.GhostFace()[d] summed over the dimensions with
     arg.commDim[d] set, i.e. the fwd + back ghost sites whose work is
     accounted to the exterior kernels
  */
  long long commGhostSites() const
  {
    long long sites = 0;
    for (int d = 0; d < 4; d++)
      if (arg.commDim[d]) sites += 2 * in.GhostFace()[d];
    return sites;
  }
};
/**
   @brief Functor dispatched by instantiate<>() for the clover-inverse variant.
   It builds the kernel argument (the final template flag true selects this
   variant; the NoClovInv functor passes false), wraps it in
   WilsonCloverHasenbuschTwistPCClovInv, and runs it through the dslash
   policy tuner.
*/
template <typename Float, int nColor, QudaReconstructType recon> struct WilsonCloverHasenbuschTwistPCClovInvApply {

  inline WilsonCloverHasenbuschTwistPCClovInvApply(ColorSpinorField &out, const ColorSpinorField &in,
                                                   const GaugeField &U, const CloverField &A, double kappa, double mu,
                                                   const ColorSpinorField &x, int parity, bool dagger,
                                                   const int *comm_override, TimeProfile &profile)
  {
    constexpr int nDim = 4;
    using ArgType = WilsonCloverHasenbuschTwistPCArg<Float, nColor, nDim, recon, true>;

    ArgType arg(out, in, U, A, kappa, mu, x, parity, dagger, comm_override);
    WilsonCloverHasenbuschTwistPCClovInv<ArgType> wilson(arg, out, in);

    // NOTE(review): static_cast assumes `in` really is a cudaColorSpinorField (unchecked)
    dslash::DslashPolicyTune<decltype(wilson)> policy(
      wilson, const_cast<cudaColorSpinorField *>(static_cast<const cudaColorSpinorField *>(&in)), in.VolumeCB(),
      in.GhostFaceCB(), profile);
    // Run the operator through the selected policy (stream 0); without this
    // call the constructed policy object is simply discarded.
    policy.apply(0);
  }
};
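// Apply the clover-inverse variant of the preconditioned Hasenbusch-twisted
// Wilson-clover operator:
//   M = psi_p - k^2 A^{-1} D_p\not{p} - i mu gamma_5 A_{pp} psi_{p}
// a and b are forwarded to the parameters that
// WilsonCloverHasenbuschTwistPCClovInvApply names kappa and mu.
// Uses the kappa normalization for the Wilson operator.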
294 void ApplyWilsonCloverHasenbuschTwistPCClovInv(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U,
295 const CloverField &A, double a, double b, const ColorSpinorField &x,
296 int parity, bool dagger, const int *comm_override, TimeProfile &profile)
298 #ifdef GPU_CLOVER_HASENBUSCH_TWIST
299 instantiate<WilsonCloverHasenbuschTwistPCClovInvApply>(out, in, U, A, a, b, x, parity, dagger, comm_override,
302 errorQuda("Clover Hasenbusch Twist dslash has not been built");