QUDA  1.0.0
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
dslash_staggered.cu
Go to the documentation of this file.
1 #include <dslash.h>
2 #include <worker.h>
3 #include <dslash_helper.cuh>
5 #include <gauge_field_order.h>
6 #include <color_spinor.h>
7 #include <dslash_helper.cuh>
8 #include <index_helper.cuh>
9 #include <gauge_field.h>
10 
11 #include <dslash_policy.cuh>
13 
18 namespace quda
19 {
20 
// Launch helper for the staggered dslash kernel.
// The template parameters are forwarded unchanged to the staggeredGPU
// instantiation, so each StaggeredLaunch specialization selects exactly one
// kernel variant.
21  template <typename Float, int nDim, int nColor, int nParity, bool dagger, bool xpay, KernelType kernel_type, typename Arg>
22  struct StaggeredLaunch {
23  static constexpr const char *kernel = "quda::staggeredGPU"; // kernel name for jit compilation
// Hands the selected staggeredGPU instantiation to dslash.launch() together
// with the tuning parameters, the kernel argument struct and the stream.
// The Dslash type is deduced from the caller; this function does nothing
// beyond that single forwarding call.
24  template <typename Dslash>
25  inline static void launch(Dslash &dslash, TuneParam &tp, Arg &arg, const cudaStream_t &stream)
26  {
27  dslash.launch(staggeredGPU<Float, nDim, nColor, nParity, dagger, xpay, kernel_type, Arg>, tp, arg, stream);
28  }
29  };
30 
// Driver class for the staggered dslash, derived from Dslash<Float>.
// It holds references to the kernel argument struct and to the input field,
// and provides the apply() and tuneKey() entry points defined below.
31  template <typename Float, int nDim, int nColor, typename Arg> class Staggered : public Dslash<Float>
32  {
33 
34 protected:
// Kernel argument struct, held by reference (not owned by this object).
35  Arg &arg;
// Input field, held by reference; used for the location check and tune key.
36  const ColorSpinorField &in;
37 
38 public:
// Forwards arg, out and in to the Dslash<Float> base together with the
// string "kernels/dslash_staggered.cuh", then stores the arg and in
// references.
// NOTE(review): the string is presumably the kernel source file used for
// jit compilation (cf. the "kernel name for jit compilation" comment on
// StaggeredLaunch::kernel) -- confirm against the Dslash base class.
39  Staggered(Arg &arg, const ColorSpinorField &out, const ColorSpinorField &in) :
40  Dslash<Float>(arg, out, in, "kernels/dslash_staggered.cuh"),
41  arg(arg),
42  in(in)
43  {
44  }
45 
46  virtual ~Staggered() {}
47 
// Runs the dslash on the given stream.
// Raises errorQuda if the input field is located on the CPU. Otherwise it
// obtains launch parameters from tuneLaunch() and dispatches through
// Dslash<Float>::instantiate with StaggeredLaunch as the launch helper.
48  void apply(const cudaStream_t &stream)
49  {
50  if (in.Location() == QUDA_CPU_FIELD_LOCATION) {
51  errorQuda("Staggered Dslash not implemented on CPU");
52  } else {
53  TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
// NOTE(review): the listing jumps from line 53 to line 55 here, so one
// source line is missing from this extraction. The cross-reference index
// mentions setParam(Arg &arg); verify against the original file whether a
// setParam call belongs at this point.
55  Dslash<Float>::template instantiate<StaggeredLaunch, nDim, nColor>(tp, arg, stream);
56  }
57  }
58 
// Builds the autotuning key from the input field's volume string, the
// run-time type name of this object, and the base-class aux entry indexed
// by arg.kernel_type.
59  TuneKey tuneKey() const
60  {
61  return TuneKey(in.VolString(), typeid(*this).name(), Dslash<Float>::aux[arg.kernel_type]);
62  }
63  };
64 
// Functor whose constructor sets up and runs the staggered dslash for one
// (Float, nColor, recon_u) combination. The work is selected by the
// staggered phase type reported by the gauge field U; each supported phase
// type is only available when the matching interface was compiled in.
//
// NOTE(review): this listing is missing several source lines (listing
// numbers 72, 77, 79, 81, 93, 95 and 97 are absent). The missing lines
// include the opening `if` that the `} else if` below pairs with, and the
// declarations that the dangling argument lists belong to. They are left
// as gaps here rather than guessed; restore them from the original file.
65  template <typename Float, int nColor, QudaReconstructType recon_u> struct StaggeredApply {
66 
// Parameters are passed through to the (missing) argument-struct
// construction as (out, in, U, U, a, x, parity, dagger, comm_override);
// profile is passed to the (missing) policy construction.
67  inline StaggeredApply(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a,
68  const ColorSpinorField &x, int parity, bool dagger, const int *comm_override,
69  TimeProfile &profile)
70  {
71 
// NOTE(review): listing line 72 is missing; presumably the first phase-type
// test that opens this branch -- confirm against the original source.
73 #ifdef BUILD_MILC_INTERFACE
74  constexpr int nDim = 4; // MWTODO: this probably should be 5 for mrhs Dslash
75  constexpr bool improved = false;
76 
// NOTE(review): listing line 77 (start of this declaration) is missing.
78  out, in, U, U, a, x, parity, dagger, comm_override);
80 
// NOTE(review): listing lines 79 and 81 are missing; the lines below are the
// tail of the declaration of `policy`, which is applied right after.
82  staggered, const_cast<cudaColorSpinorField *>(static_cast<const cudaColorSpinorField *>(&in)), in.VolumeCB(),
83  in.GhostFaceCB(), profile);
84  policy.apply(0);
85 #else
86  errorQuda("MILC interface has not been built so MILC phase staggered fermions not enabled");
87 #endif
88  } else if (U.StaggeredPhase() == QUDA_STAGGERED_PHASE_TIFR) {
89 #ifdef BUILD_TIFR_INTERFACE
90  constexpr int nDim = 4; // MWTODO: this probably should be 5 for mrhs Dslash
91  constexpr bool improved = false;
92 
// NOTE(review): listing line 93 (start of this declaration) is missing.
94  out, in, U, U, a, x, parity, dagger, comm_override);
96 
// NOTE(review): listing lines 95 and 97 are missing; the lines below are the
// tail of the declaration of `policy`, which is applied right after.
98  staggered, const_cast<cudaColorSpinorField *>(static_cast<const cudaColorSpinorField *>(&in)), in.VolumeCB(),
99  in.GhostFaceCB(), profile);
100  policy.apply(0);
101 #else
102  errorQuda("TIFR interface has not been built so TIFR phase staggered fermions not enabled");
103 #endif
104  } else {
// Any other phase type is rejected, reporting the numeric phase value.
105  errorQuda("Unsupported staggered phase type %d", U.StaggeredPhase());
106  }
107 
// Surface any CUDA error raised by the work above before returning.
108  checkCudaError();
109  }
110  };
111 
// NOTE(review): the first line of this signature (listing line 112) is
// missing from this extraction. The cross-reference index further down gives
// the full signature as
//   void ApplyStaggered(ColorSpinorField &out, const ColorSpinorField &in,
//                       const GaugeField &U, double a, const ColorSpinorField &x,
//                       int parity, bool dagger, const int *comm_override,
//                       TimeProfile &profile)
// and describes it as applying the staggered dslash operator to a
// color-spinor field.
113  const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile)
114  {
115 
// The whole implementation is compiled only when GPU_STAGGERED_DIRAC is
// defined; otherwise the call fails with errorQuda (see the #else branch).
116 #ifdef GPU_STAGGERED_DIRAC
// in and out must not share the same data pointer.
117  if (in.V() == out.V()) errorQuda("Aliasing pointers");
// in and out must use the same field order.
118  if (in.FieldOrder() != out.FieldOrder())
119  errorQuda("Field order mismatch in = %d, out = %d", in.FieldOrder(), out.FieldOrder());
120 
121  // check all precisions match
122  checkPrecision(out, in, U);
123 
124  // check all locations match
125  checkLocation(out, in, U);
126 
// Dispatch to StaggeredApply via instantiate<>, forwarding every argument.
// NOTE(review): instantiate<> presumably selects the Float / nColor /
// reconstruct template arguments from U -- confirm against its definition.
127  instantiate<StaggeredApply, StaggeredReconstruct>(out, in, U, a, x, parity, dagger, comm_override, profile);
128 #else
129  errorQuda("Staggered dslash has not been built");
130 #endif
131  }
132 
133 } // namespace quda
void launch(T *f, const TuneParam &tp, Arg &arg, const cudaStream_t &stream)
Definition: dslash.h:101
void setParam(Arg &arg)
Definition: dslash.h:66
void apply(const cudaStream_t &stream)
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:21
#define checkPrecision(...)
#define errorQuda(...)
Definition: util_quda.h:121
cudaStream_t * stream
static void launch(Dslash &dslash, TuneParam &tp, Arg &arg, const cudaStream_t &stream)
const char * VolString() const
StaggeredApply(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile)
__device__ __host__ void staggered(Arg &arg, int idx, int parity)
cpuColorSpinorField * in
const int * GhostFaceCB() const
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
Definition: tune.cpp:643
#define checkLocation(...)
Main header file for host and device accessors to GaugeFields.
TuneKey tuneKey() const
QudaFieldLocation Location() const
Parameter structure for driving the Staggered Dslash operator.
cpuColorSpinorField * out
static constexpr const char * kernel
void ApplyStaggered(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile)
Apply the staggered dslash operator to a color-spinor field.
void apply(const cudaStream_t &stream)
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
Staggered(Arg &arg, const ColorSpinorField &out, const ColorSpinorField &in)
#define checkCudaError()
Definition: util_quda.h:161
QudaStaggeredPhase StaggeredPhase() const
Definition: gauge_field.h:259
QudaTune getTuning()
Query whether autotuning is enabled or not. Default is enabled but can be overridden by setting QUDA_...
Definition: util_quda.cpp:52
QudaDagType dagger
Definition: test_util.cpp:1620
QudaParity parity
Definition: covdev_test.cpp:54
QudaFieldOrder FieldOrder() const