QUDA  1.0.0
dslash_domain_wall_4d.cu
Go to the documentation of this file.
1 #include <gauge_field.h>
2 #include <color_spinor_field.h>
3 #include <dslash.h>
4 #include <worker.h>
5 
6 #include <dslash_policy.cuh>
8 
16 namespace quda
17 {
18 
23  template <typename Float, int nDim, int nColor, int nParity, bool dagger, bool xpay, KernelType kernel_type, typename Arg>
25  static constexpr const char *kernel = "quda::domainWall4DGPU"; // kernel name for jit compilation
26  template <typename Dslash>
27  inline static void launch(Dslash &dslash, TuneParam &tp, Arg &arg, const cudaStream_t &stream)
28  {
29  dslash.launch(domainWall4DGPU<Float, nDim, nColor, nParity, dagger, xpay, kernel_type, Arg>, tp, arg, stream);
30  }
31  };
32 
33  template <typename Float, int nDim, int nColor, typename Arg> class DomainWall4D : public Dslash<Float>
34  {
35 
36 protected:
37  Arg &arg;
39 
40 public:
42  Dslash<Float>(arg, out, in, "kernels/dslash_domain_wall_4d.cuh"),
43  arg(arg),
44  in(in)
45  {
47  }
48 
49  virtual ~DomainWall4D() {}
50 
51  void apply(const cudaStream_t &stream)
52  {
53  TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
55  typedef typename mapper<Float>::type real;
56 #ifdef JITIFY
57  // we need to break the dslash launch abstraction here to get a handle on the constant memory pointer in the kernel module
58  using namespace jitify::reflection;
60  auto instance = Dslash<Float>::program_->kernel(kernel).instantiate(
61  Type<Float>(), nDim, nColor, arg.nParity, arg.dagger, arg.xpay, arg.kernel_type, Type<Arg>());
62  cuMemcpyHtoDAsync(instance.get_constant_ptr("quda::mobius_d"), arg.a_5, QUDA_MAX_DWF_LS * sizeof(complex<real>),
63  stream);
64  Tunable::jitify_error = instance.configure(tp.grid, tp.block, tp.shared_bytes, stream).launch(arg);
65 #else
66  cudaMemcpyToSymbolAsync(mobius_d, arg.a_5, QUDA_MAX_DWF_LS * sizeof(complex<real>), 0, cudaMemcpyHostToDevice,
67  streams[Nstream - 1]);
68  Dslash<Float>::template instantiate<DomainWall4DLaunch, nDim, nColor>(tp, arg, stream);
69 #endif
70  }
71 
72  TuneKey tuneKey() const
73  {
74  return TuneKey(in.VolString(), typeid(*this).name(), Dslash<Float>::aux[arg.kernel_type]);
75  }
76  };
77 
78  template <typename Float, int nColor, QudaReconstructType recon> struct DomainWall4DApply {
79 
80  inline DomainWall4DApply(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a,
81  double m_5, const Complex *b_5, const Complex *c_5, const ColorSpinorField &x, int parity,
82  bool dagger, const int *comm_override, TimeProfile &profile)
83  {
84  constexpr int nDim = 4;
85  DomainWall4DArg<Float, nColor, recon> arg(out, in, U, a, m_5, b_5, c_5, a != 0.0, x, parity, dagger, comm_override);
87 
89  twisted, const_cast<cudaColorSpinorField *>(static_cast<const cudaColorSpinorField *>(&in)),
91  policy.apply(0);
92 
94  }
95  };
96 
97  // Apply the 4-d preconditioned domain-wall Dslash operator
98  // out(x) = M*in = in(x) + a*\sum_mu U_{-\mu}(x)in(x+mu) + U^\dagger_mu(x-mu)in(x-mu)
99  void ApplyDomainWall4D(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a, double m_5,
100  const Complex *b_5, const Complex *c_5, const ColorSpinorField &x, int parity, bool dagger,
101  const int *comm_override, TimeProfile &profile)
102  {
103 #ifdef GPU_DOMAIN_WALL_DIRAC
104  if (in.V() == out.V()) errorQuda("Aliasing pointers");
105  if (in.FieldOrder() != out.FieldOrder())
106  errorQuda("Field order mismatch in = %d, out = %d", in.FieldOrder(), out.FieldOrder());
107 
108  // check all precisions match
109  checkPrecision(out, in, x, U);
110 
111  // check all locations match
112  checkLocation(out, in, x, U);
113 
114  instantiate<DomainWall4DApply>(out, in, U, a, m_5, b_5, c_5, x, parity, dagger, comm_override, profile);
115 #else
116  errorQuda("Domain-wall dslash has not been built");
117 #endif // GPU_DOMAIN_WALL_DIRAC
118  }
119 
120 } // namespace quda
static __constant__ char mobius_d[size]
void launch(T *f, const TuneParam &tp, Arg &arg, const cudaStream_t &stream)
Definition: dslash.h:101
void setParam(Arg &arg)
Definition: dslash.h:66
void apply(const cudaStream_t &stream)
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:21
#define checkPrecision(...)
#define errorQuda(...)
Definition: util_quda.h:121
cudaStream_t * streams
cudaStream_t * stream
DomainWall4D(Arg &arg, const ColorSpinorField &out, const ColorSpinorField &in)
const int Nstream
Definition: quda_internal.h:83
const char * VolString() const
void apply(const cudaStream_t &stream)
static constexpr const char * kernel
const int nColor
Definition: covdev_test.cpp:75
cpuColorSpinorField * in
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
Definition: tune.cpp:643
CUresult jitify_error
Definition: tune_quda.h:276
#define checkLocation(...)
This is a helper class that is used to instantiate the correct templated kernel for the dslash...
std::complex< double > Complex
Definition: quda_internal.h:46
void ApplyDomainWall4D(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a, double m_5, const Complex *b_5, const Complex *c_5, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile)
Driver for applying the batched Wilson 4-d stencil to a 5-d vector with 4-d preconditioned data order...
const ColorSpinorField & in
int ghostFaceCB[QUDA_MAX_DIM+1]
DomainWall4DApply(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a, double m_5, const Complex *b_5, const Complex *c_5, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile)
const DslashConstant & getDslashConstant() const
Get the dslash_constant structure from this field.
cpuColorSpinorField * out
const int nParity
Definition: spinor_noise.cu:25
#define QUDA_MAX_DWF_LS
Maximum length of the Ls dimension for domain-wall fermions.
void resizeVector(int y, int z) const
Definition: tune_quda.h:538
void instantiate(TuneParam &tp, Arg &arg, const cudaStream_t &stream)
This instantiate function is used to instantiate the KernelType template required for the multi-G...
Definition: dslash.h:119
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
static void launch(Dslash &dslash, TuneParam &tp, Arg &arg, const cudaStream_t &stream)
const int * X() const
#define checkCudaError()
Definition: util_quda.h:161
QudaTune getTuning()
Query whether autotuning is enabled or not. Default is enabled but can be overridden by setting QUDA_...
Definition: util_quda.cpp:52
QudaDagType dagger
Definition: test_util.cpp:1620
QudaParity parity
Definition: covdev_test.cpp:54
QudaFieldOrder FieldOrder() const