QUDA  v1.1.0
A library for QCD on GPUs
dslash_domain_wall_4d.cu
Go to the documentation of this file.
1 #include <gauge_field.h>
2 #include <color_spinor_field.h>
3 #include <dslash.h>
4 #include <worker.h>
5 
6 #include <dslash_policy.cuh>
7 #include <kernels/dslash_domain_wall_4d.cuh>
8 
9 /**
10  This is the gauged domain-wall 4-d preconditioned operator.
11 
12  Note, for now, this just applies a batched 4-d dslash across the fifth
13  dimension.
14 */
15 
16 namespace quda
17 {
18 
19  template <typename Arg> class DomainWall4D : public Dslash<domainWall4D, Arg>
20  {
21  using Dslash = Dslash<domainWall4D, Arg>;
22  using Dslash::arg;
23  using Dslash::in;
24 
25  public:
26  DomainWall4D(Arg &arg, const ColorSpinorField &out, const ColorSpinorField &in) : Dslash(arg, out, in)
27  {
28  TunableVectorYZ::resizeVector(in.X(4), arg.nParity);
29  }
30 
31  void apply(const qudaStream_t &stream)
32  {
33  TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
34  Dslash::setParam(tp);
35  typedef typename mapper<typename Arg::Float>::type real;
36 #ifdef JITIFY
37  // we need to break the dslash launch abstraction here to get a handle on the constant memory pointer in the kernel module
38  auto instance = Dslash::template kernel_instance<packShmem>();
39  cuMemcpyHtoDAsync(instance.get_constant_ptr("quda::mobius_d"), arg.a_5, QUDA_MAX_DWF_LS * sizeof(complex<real>),
40  stream);
41  Tunable::jitify_error = instance.configure(tp.grid, tp.block, tp.shared_bytes, stream).launch(arg);
42 #else
43  cudaMemcpyToSymbolAsync(mobius_d, arg.a_5, QUDA_MAX_DWF_LS * sizeof(complex<real>), 0, cudaMemcpyHostToDevice,
44  streams[Nstream - 1]);
45  Dslash::template instantiate<packShmem>(tp, stream);
46 #endif
47  }
48  };
49 
50  template <typename Float, int nColor, QudaReconstructType recon> struct DomainWall4DApply {
51 
52  inline DomainWall4DApply(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a,
53  double m_5, const Complex *b_5, const Complex *c_5, const ColorSpinorField &x, int parity,
54  bool dagger, const int *comm_override, TimeProfile &profile)
55  {
56  constexpr int nDim = 4;
57  DomainWall4DArg<Float, nColor, nDim, recon> arg(out, in, U, a, m_5, b_5, c_5, a != 0.0, x, parity, dagger,
58  comm_override);
59  DomainWall4D<decltype(arg)> dwf(arg, out, in);
60 
61  dslash::DslashPolicyTune<decltype(dwf)> policy(
62  dwf, const_cast<cudaColorSpinorField *>(static_cast<const cudaColorSpinorField *>(&in)),
63  in.getDslashConstant().volume_4d_cb, in.getDslashConstant().ghostFaceCB, profile);
64  policy.apply(0);
65  }
66  };
67 
68  // Apply the 4-d preconditioned domain-wall Dslash operator
69  // out(x) = M*in = in(x) + a*\sum_mu U_{-\mu}(x)in(x+mu) + U^\dagger_mu(x-mu)in(x-mu)
70  void ApplyDomainWall4D(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a, double m_5,
71  const Complex *b_5, const Complex *c_5, const ColorSpinorField &x, int parity, bool dagger,
72  const int *comm_override, TimeProfile &profile)
73  {
74 #ifdef GPU_DOMAIN_WALL_DIRAC
75  instantiate<DomainWall4DApply>(out, in, U, a, m_5, b_5, c_5, x, parity, dagger, comm_override, profile);
76 #else
77  errorQuda("Domain-wall dslash has not been built");
78 #endif // GPU_DOMAIN_WALL_DIRAC
79  }
80 
81 } // namespace quda