QUDA  1.0.0
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
dslash_domain_wall_5d.cu
Go to the documentation of this file.
1 #include <gauge_field.h>
2 #include <color_spinor_field.h>
3 #include <dslash.h>
4 #include <worker.h>
5 
6 #include <dslash_policy.cuh>
8 
13 namespace quda
14 {
15 
// Helper that binds the template parameters (Float, nDim, nColor, nParity, dagger, xpay,
// kernel_type, Arg) to the domainWall5DGPU kernel so that a Dslash object can launch it.
// NOTE(review): this listing omits the struct header line (and its doc comment); the name
// DomainWall5DLaunch is inferred from its use in DomainWall5D::apply -- confirm against the source.
20  template <typename Float, int nDim, int nColor, int nParity, bool dagger, bool xpay, KernelType kernel_type, typename Arg>
22  static constexpr const char *kernel = "quda::domainWall5DGPU"; // kernel name for jit compilation
// Hand the fully specialised kernel, the tuned launch parameters, the kernel arguments and the
// stream to the Dslash object's launch().
23  template <typename Dslash>
24  inline static void launch(Dslash &dslash, TuneParam &tp, Arg &arg, const cudaStream_t &stream)
25  {
26  dslash.launch(domainWall5DGPU<Float, nDim, nColor, nParity, dagger, xpay, kernel_type, Arg>, tp, arg, stream);
27  }
28  };
29 
// Dslash driver for the 5-d preconditioned domain-wall operator. It keeps references to the kernel
// argument struct and the input field, launches the kernel through DomainWall5DLaunch, and reports
// the flop/byte counts and the tune key for the current kernel type.
30  template <typename Float, int nDim, int nColor, typename Arg> class DomainWall5D : public Dslash<Float>
31  {
32 
33 protected:
34  Arg &arg;
// NOTE(review): the declaration of member `in` (used below; the symbol index lists it as
// `const ColorSpinorField & in`) is missing from this listing.
36 
37 public:
// NOTE(review): the constructor signature line is missing from this listing; the symbol index lists
// DomainWall5D(Arg &arg, const ColorSpinorField &out, const ColorSpinorField &in). It forwards the
// fields and the kernel source path to the Dslash<Float> base and stores `arg` and `in`.
// Its body line is also omitted here.
39  Dslash<Float>(arg, out, in, "kernels/dslash_domain_wall_5d.cuh"),
40  arg(arg),
41  in(in)
42  {
44  }
45 
46  virtual ~DomainWall5D() {}
47 
// Obtain the launch parameters for this instance (tuning them if enabled), then instantiate the
// kernel specialised on nDim/nColor via DomainWall5DLaunch and launch it on `stream`.
// NOTE(review): one line between tuneLaunch and instantiate is missing from this listing.
48  void apply(const cudaStream_t &stream)
49  {
50  TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
52  Dslash<Float>::template instantiate<DomainWall5DLaunch, nDim, nColor>(tp, arg, stream);
53  }
54 
// Flop count: the base Dslash<Float> count plus, for INTERIOR_KERNEL and KERNEL_POLICY, the
// fifth-dimension work. That is 96 flops per site on the Ls-2 bulk slices and 120 flops per site
// on the 2 wall slices, where in.Volume() / Ls is the number of sites per slice.
// Exterior (halo) kernels add nothing here.
55  long long flops() const
56  {
57  long long flops = Dslash<Float>::flops();
58  switch (arg.kernel_type) {
59  case EXTERIOR_KERNEL_X:
60  case EXTERIOR_KERNEL_Y:
61  case EXTERIOR_KERNEL_Z:
62  case EXTERIOR_KERNEL_T:
63  case EXTERIOR_KERNEL_ALL: break; // 5-d flops are in the interior kernel
64  case INTERIOR_KERNEL:
65  case KERNEL_POLICY:
// Ls, bulk and wall are declared directly under a case label without braces. This is only legal
// because no later label follows them; wrap them in a block before adding another case below.
66  int Ls = in.X(4);
67  long long bulk = (Ls - 2) * (in.Volume() / Ls);
68  long long wall = 2 * (in.Volume() / Ls);
69  flops += 96ll * bulk + 120ll * wall;
70  break;
71  }
72  return flops;
73  }
74 
// Byte count: the base Dslash<Float> count plus, for INTERIOR_KERNEL and KERNEL_POLICY, two
// spinors' worth of traffic per checkerboard site (presumably the extra fifth-dimension neighbour
// accesses -- TODO confirm).
// One spinor is 2 * Ncolor * Nspin values of in.Precision() bytes. The 2-byte and 1-byte
// (fixed-point) precisions add one extra float per site, presumably a norm.
// NOTE(review): flops() scales by in.Volume() but this uses in.VolumeCB(). These presumably
// coincide for a single-parity field -- confirm the two-parity case is intended.
75  long long bytes() const
76  {
77  bool isFixed = (in.Precision() == sizeof(short) || in.Precision() == sizeof(char)) ? true : false;
78  int spinor_bytes = 2 * in.Ncolor() * in.Nspin() * in.Precision() + (isFixed ? sizeof(float) : 0);
79  long long bytes = Dslash<Float>::bytes();
80  switch (arg.kernel_type) {
81  case EXTERIOR_KERNEL_X:
82  case EXTERIOR_KERNEL_Y:
83  case EXTERIOR_KERNEL_Z:
84  case EXTERIOR_KERNEL_T:
85  case EXTERIOR_KERNEL_ALL: break;
86  case INTERIOR_KERNEL:
87  case KERNEL_POLICY: bytes += 2 * spinor_bytes * in.VolumeCB(); break;
88  }
89  return bytes;
90  }
91 
// Tune key built from the field's volume string, the dynamic class name, and the base-class
// aux string selected by the current kernel type.
92  TuneKey tuneKey() const
93  {
94  return TuneKey(in.VolString(), typeid(*this).name(), Dslash<Float>::aux[arg.kernel_type]);
95  }
96  };
97 
// Instantiation functor used by ApplyDomainWall5D (through instantiate<DomainWall5DApply>).
// Its constructor performs the whole application:
//   1. build the DomainWall5DArg kernel arguments with nDim = 5;
//   2. drive the operator through a dslash policy with policy.apply(0);
//   3. check for CUDA errors.
// The `a != 0.0` argument is presumably the xpay flag, so xpay is active only for non-zero `a`
// -- confirm against DomainWall5DArg.
98  template <typename Float, int nColor, QudaReconstructType recon> struct DomainWall5DApply {
99 
// NOTE(review): the first line of the constructor signature is missing from this listing; the
// symbol index lists DomainWall5DApply(ColorSpinorField &out, const ColorSpinorField &in,
// const GaugeField &U, double a, double m_f, const ColorSpinorField &x, int parity, bool dagger,
// const int *comm_override, TimeProfile &profile).
101  double m_f, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile)
102  {
103  constexpr int nDim = 5;
104  DomainWall5DArg<Float, nColor, recon> arg(out, in, U, a, m_f, a != 0.0, x, parity, dagger, comm_override);
// NOTE(review): the construction of the DomainWall5D driver and the start of the policy
// declaration are missing from this listing. The const_cast below strips const from `in` for the
// policy's pointer argument; this assumes the policy never writes through it -- TODO confirm.
// Passing 0 to policy.apply presumably selects the default stream.
106 
108  const_cast<cudaColorSpinorField *>(static_cast<const cudaColorSpinorField *>(&in)),
110  policy.apply(0);
111 
112  checkCudaError();
113  }
114  };
115 
116  // Apply the 4-d preconditioned domain-wall Dslash operator
117  // out(x) = M*in = in(x) + a*\sum_mu U_{-\mu}(x)in(x+mu) + U^\dagger_mu(x-mu)in(x-mu)
118  void ApplyDomainWall5D(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a, double m_f,
119  const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile)
120  {
121 #ifdef GPU_DOMAIN_WALL_DIRAC
122  if (in.V() == out.V()) errorQuda("Aliasing pointers");
123  if (in.FieldOrder() != out.FieldOrder())
124  errorQuda("Field order mismatch in = %d, out = %d", in.FieldOrder(), out.FieldOrder());
125 
126  // check all precisions match
127  checkPrecision(out, in, x, U);
128 
129  // check all locations match
130  checkLocation(out, in, x, U);
131 
132  // with 5-d checkerboarding we must use kernel packing
133  pushKernelPackT(true);
134 
135  instantiate<DomainWall5DApply>(out, in, U, a, m_f, x, parity, dagger, comm_override, profile);
136 
137  popKernelPackT();
138 #else
139  errorQuda("Domain-wall dslash has not been built");
140 #endif // GPU_DOMAIN_WALL_DIRAC
141  }
142 
143 } // namespace quda
void launch(T *f, const TuneParam &tp, Arg &arg, const cudaStream_t &stream)
Definition: dslash.h:101
void setParam(Arg &arg)
Definition: dslash.h:66
void apply(const cudaStream_t &stream)
This is a helper class that is used to instantiate the correct templated kernel for the dslash...
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:21
#define checkPrecision(...)
#define errorQuda(...)
Definition: util_quda.h:121
DomainWall5DApply(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a, double m_f, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile)
cudaStream_t * stream
const char * VolString() const
void ApplyDomainWall5D(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a, double m_f, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile)
Driver for applying the Domain-wall 5-d stencil to a 5-d vector with 5-d preconditioned data order...
void apply(const cudaStream_t &stream)
int Ls
Definition: test_util.cpp:38
void popKernelPackT()
Definition: dslash_quda.cu:42
virtual long long bytes() const
Definition: dslash.h:364
const ColorSpinorField & in
cpuColorSpinorField * in
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
Definition: tune.cpp:643
#define checkLocation(...)
static void launch(Dslash &dslash, TuneParam &tp, Arg &arg, const cudaStream_t &stream)
long long flops() const
static constexpr const char * kernel
int ghostFaceCB[QUDA_MAX_DIM+1]
const DslashConstant & getDslashConstant() const
Get the dslash_constant structure from this field.
DomainWall5D(Arg &arg, const ColorSpinorField &out, const ColorSpinorField &in)
cpuColorSpinorField * out
const int nParity
Definition: spinor_noise.cu:25
void resizeVector(int y, int z) const
Definition: tune_quda.h:538
unsigned long long flops
Definition: blas_quda.cu:22
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
const int * X() const
void pushKernelPackT(bool pack)
Definition: dslash_quda.cu:30
#define checkCudaError()
Definition: util_quda.h:161
virtual long long flops() const
Definition: dslash.h:316
QudaTune getTuning()
Query whether autotuning is enabled or not. Default is enabled but can be overridden by setting QUDA_...
Definition: util_quda.cpp:52
QudaPrecision Precision() const
QudaDagType dagger
Definition: test_util.cpp:1620
QudaParity parity
Definition: covdev_test.cpp:54
long long bytes() const
QudaFieldOrder FieldOrder() const
unsigned long long bytes
Definition: blas_quda.cu:23