1 #include <gauge_field.h>
2 #include <color_spinor_field.h>
6 #include <dslash_policy.cuh>
7 #include <kernels/dslash_domain_wall_4d.cuh>
10 This is the gauged 4-d preconditioned domain-wall operator.
12 Note: for now, this just applies a batched 4-d dslash across the fifth dimension.
// DomainWall4D: Dslash specialization for the domainWall4D kernel (the kernel
// header kernels/dslash_domain_wall_4d.cuh is included by this file). It applies
// a 4-d dslash batched across the fifth dimension.
// NOTE(review): this excerpt shows no braces, access specifiers or preprocessor
// guards for this class; the comments below describe only the visible lines.
19 template <typename Arg> class DomainWall4D : public Dslash<domainWall4D, Arg>
// Local alias for the templated base class.
21 using Dslash = Dslash<domainWall4D, Arg>;
// Constructor: forwards arg/out/in to the Dslash base, then resizes the tunable
// launch vector using in.X(4) (the fifth entry of the field dimensions,
// presumably the fifth-dimension extent Ls) and arg.nParity (presumably the
// y and z extents of the TunableVectorYZ grid -- confirm).
26 DomainWall4D(Arg &arg, const ColorSpinorField &out, const ColorSpinorField &in) : Dslash(arg, out, in)
28 TunableVectorYZ::resizeVector(in.X(4), arg.nParity);
// apply: selects the tuned launch configuration. It then uploads arg.a_5
// (QUDA_MAX_DWF_LS values of complex<real>) into the device constant symbol
// mobius_d and launches the kernel.
31 void apply(const qudaStream_t &stream)
// Look up (or tune) the launch parameters for this dslash.
33 TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
// Type used to size the coefficient copy; presumably the compute precision
// mapped from the storage type Arg::Float.
35 typedef typename mapper<typename Arg::Float>::type real;
// Path 1 (jitify-style launch): copy into the module's constant symbol with the
// driver API, then configure and launch the kernel instance directly, storing
// the result in Tunable::jitify_error.
37 // we need to break the dslash launch abstraction here to get a handle on the constant memory pointer in the kernel module
38 auto instance = Dslash::template kernel_instance<packShmem>();
// NOTE(review): the trailing argument(s) of the following call are not visible
// in this excerpt (presumably the stream); its return code is not checked.
39 cuMemcpyHtoDAsync(instance.get_constant_ptr("quda::mobius_d"), arg.a_5, QUDA_MAX_DWF_LS * sizeof(complex<real>),
41 Tunable::jitify_error = instance.configure(tp.grid, tp.block, tp.shared_bytes, stream).launch(arg);
// Path 2 (non-jitify launch): copy into the mobius_d symbol with the runtime
// API, then instantiate the kernel through the Dslash launch abstraction.
// NOTE(review): Path 1 and Path 2 look like alternatives, but no
// #ifdef/#else/#endif separating them is visible here. Confirm the guards
// exist; otherwise both copies and both launches would execute.
// NOTE(review): this copy is enqueued on streams[Nstream - 1], while the launch
// below uses `stream` (and Path 1 launches on `stream`). If the two streams can
// differ, nothing visible orders the copy before the kernel reads mobius_d.
// Consider enqueueing the copy on the launch stream. The return code of
// cudaMemcpyToSymbolAsync is not checked.
43 cudaMemcpyToSymbolAsync(mobius_d, arg.a_5, QUDA_MAX_DWF_LS * sizeof(complex<real>), 0, cudaMemcpyHostToDevice,
44 streams[Nstream - 1]);
45 Dslash::template instantiate<packShmem>(tp, stream);
// DomainWall4DApply: functor template passed to instantiate<> by
// ApplyDomainWall4D and parameterized on (Float, nColor, recon). Its constructor
// builds the kernel argument, the DomainWall4D dslash object and the
// policy tuner for that dslash.
50 template <typename Float, int nColor, QudaReconstructType recon> struct DomainWall4DApply {
52 inline DomainWall4DApply(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a,
53 double m_5, const Complex *b_5, const Complex *c_5, const ColorSpinorField &x, int parity,
54 bool dagger, const int *comm_override, TimeProfile &profile)
// The dslash itself is 4-d; the fifth dimension is handled by batching.
56 constexpr int nDim = 4;
// `a != 0.0` is passed as a flag together with the field x. Presumably this
// enables an x + a*(...) accumulation -- confirm against DomainWall4DArg.
// NOTE(review): the remaining constructor arguments (after `dagger,`) are not
// visible in this excerpt; comm_override is presumably among them.
57 DomainWall4DArg<Float, nColor, nDim, recon> arg(out, in, U, a, m_5, b_5, c_5, a != 0.0, x, parity, dagger,
59 DomainWall4D<decltype(arg)> dwf(arg, out, in);
// Policy tuner: receives the dslash, the input field (cast to the non-const
// cudaColorSpinorField pointer it expects), the 4-d volume_4d_cb and
// ghostFaceCB sizes from the input's dslash constants, and the profiler.
// NOTE(review): the static_cast assumes `in` really is a cudaColorSpinorField;
// the const_cast suggests the policy interface is not const-correct.
// NOTE(review): no call that actually applies `policy` is visible in this
// excerpt. Confirm the launch is present in the full source.
61 dslash::DslashPolicyTune<decltype(dwf)> policy(
62 dwf, const_cast<cudaColorSpinorField *>(static_cast<const cudaColorSpinorField *>(&in)),
63 in.getDslashConstant().volume_4d_cb, in.getDslashConstant().ghostFaceCB, profile);
68 // Apply the 4-d preconditioned domain-wall Dslash operator (batched over the fifth dimension):
69 // out(x) = M*in = in(x) + a * \sum_\mu [ U_{-\mu}(x) in(x+\mu) + U^\dagger_\mu(x-\mu) in(x-\mu) ]
// NOTE(review): this function also takes a field x, and DomainWall4DApply
// forwards `a != 0.0` to the kernel argument together with x. The leading
// in(x) term above may therefore really be the field x -- confirm against the
// kernel.
//
// All arguments are forwarded unchanged to instantiate<DomainWall4DApply>:
//   out            output field
//   in             input field
//   U              gauge field
//   a              scale factor; a != 0.0 is also passed on as a flag
//   m_5, b_5, c_5  domain-wall parameters (b_5/c_5 presumably the Mobius
//                  coefficients) forwarded to the kernel argument
//   x              field passed alongside the a != 0.0 flag
//   parity         parity selector (presumably the output parity)
//   dagger         whether to apply the daggered operator (by name; confirm)
//   comm_override  presumably per-dimension communication overrides
//   profile        profiler handed to the policy tuner
// When GPU_DOMAIN_WALL_DIRAC is not defined, this reports
// "Domain-wall dslash has not been built" through errorQuda.
// NOTE(review): no #else is visible between the instantiate call and errorQuda
// in this excerpt. Confirm it exists; otherwise errorQuda would follow every
// application.
70 void ApplyDomainWall4D(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a, double m_5,
71 const Complex *b_5, const Complex *c_5, const ColorSpinorField &x, int parity, bool dagger,
72 const int *comm_override, TimeProfile &profile)
74 #ifdef GPU_DOMAIN_WALL_DIRAC
75 instantiate<DomainWall4DApply>(out, in, U, a, m_5, b_5, c_5, x, parity, dagger, comm_override, profile);
77 errorQuda("Domain-wall dslash has not been built");
78 #endif // GPU_DOMAIN_WALL_DIRAC