QUDA  v1.1.0
A library for QCD on GPUs
dslash_staggered.cu
Go to the documentation of this file.
1 #include <dslash.h>
2 #include <worker.h>
3 #include <dslash_helper.cuh>
4 #include <color_spinor_field_order.h>
5 #include <gauge_field_order.h>
6 #include <color_spinor.h>
7 #include <index_helper.cuh>
8 #include <gauge_field.h>
9 
10 #include <dslash_policy.cuh>
11 #include <kernels/dslash_staggered.cuh>
12 
13 /**
14  This is a staggered Dirac operator
15 */
16 
17 namespace quda
18 {
19 
20  template <typename Arg> class Staggered : public Dslash<staggered, Arg>
21  {
22  using Dslash = Dslash<staggered, Arg>;
23  using Dslash::arg;
24 
25  public:
26  Staggered(Arg &arg, const ColorSpinorField &out, const ColorSpinorField &in) : Dslash(arg, out, in) {}
27 
28  void apply(const qudaStream_t &stream)
29  {
30  TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
31  Dslash::setParam(tp);
32  // operator is anti-Hermitian so do not instantiate dagger
33  if (arg.nParity == 1) {
34  if (arg.xpay)
35  Dslash::template instantiate<packStaggeredShmem, 1, false, true>(tp, stream);
36  else
37  Dslash::template instantiate<packStaggeredShmem, 1, false, false>(tp, stream);
38  } else if (arg.nParity == 2) {
39  if (arg.xpay)
40  Dslash::template instantiate<packStaggeredShmem, 2, false, true>(tp, stream);
41  else
42  Dslash::template instantiate<packStaggeredShmem, 2, false, false>(tp, stream);
43  }
44  }
45  };
46 
47  template <typename Float, int nColor, QudaReconstructType recon_u> struct StaggeredApply {
48 
49  inline StaggeredApply(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a,
50  const ColorSpinorField &x, int parity, bool dagger, const int *comm_override,
51  TimeProfile &profile)
52  {
53  if (U.StaggeredPhase() == QUDA_STAGGERED_PHASE_MILC) {
54 #ifdef BUILD_MILC_INTERFACE
55  constexpr int nDim = 4; // MWTODO: this probably should be 5 for mrhs Dslash
56  constexpr bool improved = false;
57 
58  StaggeredArg<Float, nColor, nDim, recon_u, QUDA_RECONSTRUCT_NO, improved, QUDA_STAGGERED_PHASE_MILC> arg(
59  out, in, U, U, a, x, parity, dagger, comm_override);
60  Staggered<decltype(arg)> staggered(arg, out, in);
61 
62  dslash::DslashPolicyTune<decltype(staggered)> policy(
63  staggered, const_cast<cudaColorSpinorField *>(static_cast<const cudaColorSpinorField *>(&in)), in.VolumeCB(),
64  in.GhostFaceCB(), profile);
65  policy.apply(0);
66 #else
67  errorQuda("MILC interface has not been built so MILC phase staggered fermions not enabled");
68 #endif
69  } else if (U.StaggeredPhase() == QUDA_STAGGERED_PHASE_TIFR) {
70 #ifdef BUILD_TIFR_INTERFACE
71  constexpr int nDim = 4; // MWTODO: this probably should be 5 for mrhs Dslash
72  constexpr bool improved = false;
73 
74  StaggeredArg<Float, nColor, nDim, recon_u, QUDA_RECONSTRUCT_NO, improved, QUDA_STAGGERED_PHASE_TIFR> arg(
75  out, in, U, U, a, x, parity, dagger, comm_override);
76  Staggered<decltype(arg)> staggered(arg, out, in);
77 
78  dslash::DslashPolicyTune<decltype(staggered)> policy(
79  staggered, const_cast<cudaColorSpinorField *>(static_cast<const cudaColorSpinorField *>(&in)), in.VolumeCB(),
80  in.GhostFaceCB(), profile);
81  policy.apply(0);
82 #else
83  errorQuda("TIFR interface has not been built so TIFR phase taggered fermions not enabled");
84 #endif
85  } else {
86  errorQuda("Unsupported staggered phase type %d", U.StaggeredPhase());
87  }
88  }
89  };
90 
91  void ApplyStaggered(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a,
92  const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile)
93  {
94 #ifdef GPU_STAGGERED_DIRAC
95  instantiate<StaggeredApply, StaggeredReconstruct>(out, in, U, a, x, parity, dagger, comm_override, profile);
96 #else
97  errorQuda("Staggered dslash has not been built");
98 #endif
99  }
100 
101 } // namespace quda