QUDA  v1.1.0
A library for QCD on GPUs
dslash_improved_staggered.cu
#include <dslash.h>
#include <worker.h>
#include <dslash_helper.cuh>
#include <color_spinor_field_order.h>
#include <gauge_field_order.h>
#include <color_spinor.h>
#include <index_helper.cuh>
#include <gauge_field.h>

#include <dslash_policy.cuh>
#include <kernels/dslash_staggered.cuh>

/**
   This is the improved staggered Dirac operator, built from fat and long links (e.g. asqtad)
*/

namespace quda
{

  template <typename Arg> class Staggered : public Dslash<staggered, Arg>
  {
    using Dslash = Dslash<staggered, Arg>;
    using Dslash::arg;
    using Dslash::in;

  public:
    Staggered(Arg &arg, const ColorSpinorField &out, const ColorSpinorField &in) : Dslash(arg, out, in) {}

    void apply(const qudaStream_t &stream)
    {
      TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
      Dslash::setParam(tp);
      // operator is anti-Hermitian so do not instantiate dagger
      if (arg.nParity == 1) {
        if (arg.xpay)
          Dslash::template instantiate<packStaggeredShmem, 1, false, true>(tp, stream);
        else
          Dslash::template instantiate<packStaggeredShmem, 1, false, false>(tp, stream);
      } else if (arg.nParity == 2) {
        if (arg.xpay)
          Dslash::template instantiate<packStaggeredShmem, 2, false, true>(tp, stream);
        else
          Dslash::template instantiate<packStaggeredShmem, 2, false, false>(tp, stream);
      }
    }

    /*
      per direction / dimension flops
        SU(3) matrix-vector flops = (8 Nc - 2) * Nc
        xpay = 2 * 2 * Nc * Ns

      So for the full dslash we have

        flops = (2 * 2 * Nd * (8 * Nc - 2) * Nc) + ((2 * 2 * Nd - 1) * 2 * Nc * Ns)
        flops_xpay = flops + 2 * 2 * Nc * Ns

      For Asqtad this should give 1146 for Nc = 3, Ns = 1 and 1158 for the axpy equivalent
    */
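
    /*
      Worked check of the formula above (arithmetic sketch only): for Nd = 4, Nc = 3, Ns = 1,
        matrix-vector: (8 * 3 - 2) * 3 = 66 flops, applied 2 * 2 * 4 = 16 times
                       (fat + long links, forward + backward, 4 dimensions) -> 1056 flops
        accumulation:  (16 - 1) * 2 * 3 * 1 = 90 flops
      so 1056 + 90 = 1146 flops per site, and 1146 + 2 * 2 * 3 * 1 = 1158 with the axpy term.
    */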
    long long flops() const
    {
      int mv_flops = (8 * in.Ncolor() - 2) * in.Ncolor(); // SU(3) matrix-vector flops
      int ghost_flops = (3 + 1) * (mv_flops + 2 * in.Ncolor() * in.Nspin());
      int xpay_flops = 2 * 2 * in.Ncolor() * in.Nspin(); // multiply and add per real component
      int num_dir = 2 * 4; // hard code factor of 4 in direction since fields may be 5-d

      long long flops_ = 0;

      switch (arg.kernel_type) {
      case EXTERIOR_KERNEL_X:
      case EXTERIOR_KERNEL_Y:
      case EXTERIOR_KERNEL_Z:
      case EXTERIOR_KERNEL_T: flops_ = ghost_flops * 2 * in.GhostFace()[arg.kernel_type]; break;
      case EXTERIOR_KERNEL_ALL: {
        long long ghost_sites = 2 * (in.GhostFace()[0] + in.GhostFace()[1] + in.GhostFace()[2] + in.GhostFace()[3]);
        flops_ = ghost_flops * ghost_sites;
        break;
      }
      case INTERIOR_KERNEL:
      case UBER_KERNEL:
      case KERNEL_POLICY: {
        long long sites = in.Volume();
        flops_ = (2 * num_dir * mv_flops +                            // SU(3) matrix-vector multiplies
                  (2 * num_dir - 1) * 2 * in.Ncolor() * in.Nspin()) * // accumulation
          sites;
        if (arg.xpay) flops_ += xpay_flops * sites; // axpy is always on interior

        if (arg.kernel_type == KERNEL_POLICY) break;
        // now correct for flops done by exterior kernel
        long long ghost_sites = 0;
        for (int d = 0; d < 4; d++)
          if (arg.commDim[d]) ghost_sites += 2 * in.GhostFace()[d];
        flops_ -= ghost_flops * ghost_sites;

        break;
      }
      }
      return flops_;
    }

    long long bytes() const
    {
      int gauge_bytes_fat = QUDA_RECONSTRUCT_NO * in.Precision();
      int gauge_bytes_long = arg.reconstruct * in.Precision();
      int spinor_bytes = 2 * in.Ncolor() * in.Nspin() * in.Precision()
        + (isFixed<typename Arg::Float>::value ? sizeof(float) : 0);
      int ghost_bytes = 3 * (spinor_bytes + gauge_bytes_long) + (spinor_bytes + gauge_bytes_fat)
        + 3 * 2 * spinor_bytes; // last term is the accumulator load/store through the face
      int num_dir = 2 * 4; // set to 4-d since we take care of 5-d fermions in derived classes where necessary
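
      /*
        Worked example of the interior-kernel traffic (a sketch only, assuming single
        precision, Nc = 3, Ns = 1 and no long-link reconstruction, i.e. 18 reals per link):
          gauge reads:  8 * (72 + 72) = 1152 bytes  (fat + long links, 8 directions)
          spinor reads: 8 * 2 * 24    =  384 bytes
          spinor write:                   24 bytes
        giving 1560 bytes per site before the xpay and halo corrections below.
      */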
108 
109  long long bytes_ = 0;
110 
111  switch (arg.kernel_type) {
112  case EXTERIOR_KERNEL_X:
113  case EXTERIOR_KERNEL_Y:
114  case EXTERIOR_KERNEL_Z:
115  case EXTERIOR_KERNEL_T: bytes_ = ghost_bytes * 2 * in.GhostFace()[arg.kernel_type]; break;
116  case EXTERIOR_KERNEL_ALL: {
117  long long ghost_sites = 2 * (in.GhostFace()[0] + in.GhostFace()[1] + in.GhostFace()[2] + in.GhostFace()[3]);
118  bytes_ = ghost_bytes * ghost_sites;
119  break;
120  }
121  case INTERIOR_KERNEL:
122  case UBER_KERNEL:
123  case KERNEL_POLICY: {
124  long long sites = in.Volume();
125  bytes_ = (num_dir * (gauge_bytes_fat + gauge_bytes_long) + // gauge reads
126  num_dir * 2 * spinor_bytes + // spinor reads
127  spinor_bytes)
128  * sites; // spinor write
129  if (arg.xpay) bytes_ += spinor_bytes;
130 
131  if (arg.kernel_type == KERNEL_POLICY) break;
132  // now correct for bytes done by exterior kernel
133  long long ghost_sites = 0;
134  for (int d = 0; d < 4; d++)
135  if (arg.commDim[d]) ghost_sites += 2 * in.GhostFace()[d];
136  bytes_ -= ghost_bytes * ghost_sites;
137 
138  break;
139  }
140  }
141  return bytes_;
142  }
143 
144  };
145 
  template <typename Float, int nColor, QudaReconstructType recon_l> struct ImprovedStaggeredApply {

    inline ImprovedStaggeredApply(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &L,
                                  const GaugeField &U, double a, const ColorSpinorField &x, int parity, bool dagger,
                                  const int *comm_override, TimeProfile &profile)
    {
      constexpr int nDim = 4; // MWTODO: this probably should be 5 for mrhs Dslash
      constexpr bool improved = true;
      constexpr QudaReconstructType recon_u = QUDA_RECONSTRUCT_NO;
      StaggeredArg<Float, nColor, nDim, recon_u, recon_l, improved> arg(out, in, U, L, a, x, parity, dagger,
                                                                        comm_override);
      Staggered<decltype(arg)> staggered(arg, out, in);

      dslash::DslashPolicyTune<decltype(staggered)> policy(
        staggered, const_cast<cudaColorSpinorField *>(static_cast<const cudaColorSpinorField *>(&in)), in.VolumeCB(),
        in.GhostFaceCB(), profile);
      policy.apply(0);
    }
  };

  void ApplyImprovedStaggered(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U,
                              const GaugeField &L, double a, const ColorSpinorField &x, int parity, bool dagger,
                              const int *comm_override, TimeProfile &profile)
  {

#ifdef GPU_STAGGERED_DIRAC
    for (int i = 0; i < 4; i++) {
      if (comm_dim_partitioned(i) && (U.X()[i] < 6)) {
        errorQuda("partitioned dimension with local size less than 6 is not supported in improved staggered dslash");
      }
    }

    // L must be the first gauge field argument since we template on the long-link reconstruct
    instantiate<ImprovedStaggeredApply, StaggeredReconstruct>(out, in, L, U, a, x, parity, dagger, comm_override,
                                                              profile);
#else
    errorQuda("Staggered dslash has not been built");
#endif
  }
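
  /*
    Example call (a minimal sketch, not taken from this file; `out`, `in` and `x` are assumed to be
    ColorSpinorFields, `fatLinks` and `longLinks` the fat- and long-link GaugeFields, and
    `commOverride` an int[4] communication-override array):

      quda::ApplyImprovedStaggered(out, in, fatLinks, longLinks, a, x, parity, dagger,
                                   commOverride, profile);

    The fat links are passed as U and the long links as L; internally L is forwarded first so
    that the kernel is templated on the long-link reconstruction.
  */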

} // namespace quda