QUDA v1.1.0
A library for QCD on GPUs
dslash_wilson_clover_hasenbusch_twist_preconditioned.cu
Go to the documentation of this file.
1 #ifndef USE_LEGACY_DSLASH
2 
3 #include <gauge_field.h>
4 #include <color_spinor_field.h>
5 #include <clover_field.h>
6 #include <dslash.h>
7 #include <worker.h>
8 
9 #include <dslash_policy.cuh>
10 #include <kernels/dslash_wilson_clover_hasenbusch_twist_preconditioned.cuh>
11 
12 namespace quda
13 {
14 
/* ***************************
 * No Clov Inv: 1 - k^2 D - i mu gamma_5 A
 * **************************/
/**
   @brief Dslash driver for the Hasenbusch-twisted, clover-preconditioned
   Wilson-clover operator in the variant that does NOT apply the clover
   inverse to the hopping term (schematically 1 - k^2 D - i mu gamma_5 A).

   Provides the launch logic plus the extra flop / byte counts of the clover
   and twist terms on top of what the Dslash base class accounts for.
*/
template <typename Arg>
class WilsonCloverHasenbuschTwistPCNoClovInv : public Dslash<cloverHasenbuschPreconditioned, Arg>
{
  using Dslash = Dslash<cloverHasenbuschPreconditioned, Arg>;
  using Dslash::arg;
  using Dslash::in;

public:
  WilsonCloverHasenbuschTwistPCNoClovInv(Arg &arg, const ColorSpinorField &out, const ColorSpinorField &in) :
    Dslash(arg, out, in)
  {
  }

  /**
     @brief Tune and launch the kernel on the given stream.
     Only the single-parity (nParity == 1), xpay == true specialization is
     instantiated; every other configuration raises errorQuda.
  */
  void apply(const qudaStream_t &stream)
  {
    TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
    Dslash::setParam(tp);

    // specialize here to constrain the template instantiation
    if (arg.nParity == 1) {
      if (arg.xpay)
        Dslash::template instantiate<packShmem, 1, true>(tp, stream);
      else
        errorQuda("Operator only defined for xpay=true");
    } else {
      errorQuda("Operator not defined nParity=%d", arg.nParity);
    }
  }

  /**
     @brief Flop count: base Dslash flops plus, per site, one clover term
     (504 flops) and 48 flops for the - mu (i gamma_5) A contribution.
     @note The per-site constant is long long so that its products with the
     site counts from Volume() / GhostFace() are evaluated in 64-bit
     arithmetic instead of (possibly 32-bit) int, which overflows for large
     local volumes.
  */
  long long flops() const
  {
    const long long clover_flops = 504;
    // there is no A^{-1} D here, only D: one clover term plus the 48 flops of - mu (i gamma_5) A
    const long long site_flops = clover_flops + 48;

    long long flops = Dslash::flops();
    switch (arg.kernel_type) {
    case EXTERIOR_KERNEL_X:
    case EXTERIOR_KERNEL_Y:
    case EXTERIOR_KERNEL_Z:
    case EXTERIOR_KERNEL_T:
      // factor of 2 from the fwd / back faces
      flops += 2 * site_flops * in.GhostFace()[arg.kernel_type];
      break;
    case EXTERIOR_KERNEL_ALL:
      flops += 2 * site_flops * (in.GhostFace()[0] + in.GhostFace()[1] + in.GhostFace()[2] + in.GhostFace()[3]);
      break;
    case INTERIOR_KERNEL:
    case UBER_KERNEL:
    case KERNEL_POLICY: {
      flops += site_flops * in.Volume();

      if (arg.kernel_type == KERNEL_POLICY) break;
      // now correct for flops done by exterior kernel
      long long ghost_sites = 0;
      for (int d = 0; d < 4; d++)
        if (arg.commDim[d]) ghost_sites += 2 * in.GhostFace()[d];
      flops -= site_flops * ghost_sites;
      break;
    }
    }
    return flops;
  }

  /**
     @brief Byte count: base Dslash bytes plus one clover-field read per site.
     Per-site clover storage is 72 * precision bytes, plus 2 * sizeof(float)
     for fixed-point formats (presumably the per-site scale/norm values —
     confirm against the clover field layout).
     @note Held as long long for the same overflow reason as flops().
  */
  long long bytes() const
  {
    const long long clover_bytes
      = 72 * in.Precision() + (isFixed<typename Arg::Float>::value ? 2 * sizeof(float) : 0);

    long long bytes = Dslash::bytes();
    switch (arg.kernel_type) {
    case EXTERIOR_KERNEL_X:
    case EXTERIOR_KERNEL_Y:
    case EXTERIOR_KERNEL_Z:
    case EXTERIOR_KERNEL_T:
      // Factor of 2 is from the fwd/back faces.
      bytes += clover_bytes * 2 * in.GhostFace()[arg.kernel_type];
      break;
    case EXTERIOR_KERNEL_ALL:
      // Factor of 2 is from the fwd/back faces
      bytes += clover_bytes * 2 * (in.GhostFace()[0] + in.GhostFace()[1] + in.GhostFace()[2] + in.GhostFace()[3]);
      break;
    case INTERIOR_KERNEL:
    case UBER_KERNEL:
    case KERNEL_POLICY: {
      bytes += clover_bytes * in.Volume();

      if (arg.kernel_type == KERNEL_POLICY) break;
      // now correct for bytes done by exterior kernel
      long long ghost_sites = 0;
      for (int d = 0; d < 4; d++)
        if (arg.commDim[d]) ghost_sites += 2 * in.GhostFace()[d];
      bytes -= clover_bytes * ghost_sites;
      break;
    }
    }

    return bytes;
  }
};
116 
/**
   @brief Per-(precision, nColor, reconstruct) launcher for the variant
   without clover inverse. Constructing it builds the kernel argument, wraps
   it in the Dslash driver and runs it through the dslash policy tuner.
*/
template <typename Float, int nColor, QudaReconstructType recon> struct WilsonCloverHasenbuschTwistPCNoClovInvApply {

  inline WilsonCloverHasenbuschTwistPCNoClovInvApply(ColorSpinorField &out, const ColorSpinorField &in,
                                                     const GaugeField &U, const CloverField &A, double a, double b,
                                                     const ColorSpinorField &x, int parity, bool dagger,
                                                     const int *comm_override, TimeProfile &profile)
  {
    constexpr int nDim = 4;
    // trailing `false` (the ClovInv launcher passes `true`) presumably selects
    // the no-clover-inverse argument specialization
    using ArgType = WilsonCloverHasenbuschTwistPCArg<Float, nColor, nDim, recon, false>;

    ArgType kernel_arg(out, in, U, A, a, b, x, parity, dagger, comm_override);
    WilsonCloverHasenbuschTwistPCNoClovInv<ArgType> dslash_op(kernel_arg, out, in);

    // the policy tuner takes a mutable cudaColorSpinorField pointer to the input
    auto *in_field = const_cast<cudaColorSpinorField *>(static_cast<const cudaColorSpinorField *>(&in));
    dslash::DslashPolicyTune<decltype(dslash_op)> tuner(dslash_op, in_field, in.VolumeCB(), in.GhostFaceCB(),
                                                        profile);
    tuner.apply(0);
  }
};
136 
// Apply the Hasenbusch-twisted, clover-preconditioned Wilson-clover operator in
// the variant WITHOUT the clover inverse on the hopping term, schematically
//   out = 1 - k^2 D - i mu gamma_5 A
// (see the "No Clov Inv" header of WilsonCloverHasenbuschTwistPCNoClovInv).
// Uses the kappa normalization for the Wilson operator.
//
// a, b are forwarded unchanged to the kernel argument; the ClovInv launcher
// names the same slots kappa and mu, so presumably a = kappa, b = mu (confirm).
// x is the xpay source field. Errors out if in/out alias, if their field
// orders differ, if precisions/locations of out, in, U, A disagree, or if the
// library was built without GPU_CLOVER_HASENBUSCH_TWIST.
void ApplyWilsonCloverHasenbuschTwistPCNoClovInv(ColorSpinorField &out, const ColorSpinorField &in,
                                                 const GaugeField &U, const CloverField &A, double a, double b,
                                                 const ColorSpinorField &x, int parity, bool dagger,
                                                 const int *comm_override, TimeProfile &profile)
{
#ifdef GPU_CLOVER_HASENBUSCH_TWIST
  if (in.V() == out.V()) errorQuda("Aliasing pointers");
  if (in.FieldOrder() != out.FieldOrder())
    errorQuda("Field order mismatch in = %d, out = %d", in.FieldOrder(), out.FieldOrder());

  // check all precisions match
  checkPrecision(out, in, U, A);

  // check all locations match
  checkLocation(out, in, U, A);

  instantiate<WilsonCloverHasenbuschTwistPCNoClovInvApply>(out, in, U, A, a, b, x, parity, dagger, comm_override,
                                                           profile);
#else
  errorQuda("Clover Hasenbusch Twist dslash has not been built");
#endif
}
162 
/* ***************************
 * Clov Inv
 *
 * M = psi_p - k^2 A^{-1} D_p\not{p} - i mu gamma_5 A_{pp} psi_{p}
 * **************************/
/**
   @brief Dslash driver for the Hasenbusch-twisted, clover-preconditioned
   Wilson-clover operator in the variant that applies the clover inverse
   A^{-1} to the hopping term.

   Provides the launch logic plus the extra flop / byte counts of the clover,
   clover-inverse and twist terms on top of what the Dslash base accounts for.
*/
template <typename Arg>
class WilsonCloverHasenbuschTwistPCClovInv : public Dslash<cloverHasenbuschPreconditioned, Arg>
{
  using Dslash = Dslash<cloverHasenbuschPreconditioned, Arg>;
  using Dslash::arg;
  using Dslash::in;

public:
  WilsonCloverHasenbuschTwistPCClovInv(Arg &arg, const ColorSpinorField &out, const ColorSpinorField &in) :
    Dslash(arg, out, in)
  {
  }

  /**
     @brief Tune and launch the kernel on the given stream.
     Only the single-parity (nParity == 1), xpay == true specialization is
     instantiated; every other configuration raises errorQuda.
  */
  void apply(const qudaStream_t &stream)
  {
    TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
    Dslash::setParam(tp);

    // specialize here to constrain the template instantiation
    if (arg.nParity == 1) {
      if (arg.xpay)
        Dslash::template instantiate<packShmem, 1, true>(tp, stream);
      else
        errorQuda("Operator only defined for xpay=true");
    } else {
      errorQuda("Operator not defined nParity=%d", arg.nParity);
    }
  }

  /**
     @brief Flop count: base Dslash flops plus, per site, two clover-term
     applications (504 flops each: one for A^{-1} D, one for the twist term)
     and 48 flops for - mu (i gamma_5) A.
     @note The per-site constant is long long so that its products with the
     site counts from Volume() / GhostFace() are evaluated in 64-bit
     arithmetic instead of (possibly 32-bit) int; (2*504+48) * Volume()
     overflows 32 bits already at moderate local volumes.
  */
  long long flops() const
  {
    const long long clover_flops = 504;
    // one clover term from the A^{-1} D, a second clover term plus 48 for the - mu (i gamma_5) A
    const long long site_flops = 2 * clover_flops + 48;

    long long flops = Dslash::flops();
    switch (arg.kernel_type) {
    case EXTERIOR_KERNEL_X:
    case EXTERIOR_KERNEL_Y:
    case EXTERIOR_KERNEL_Z:
    case EXTERIOR_KERNEL_T:
      // factor of 2 from the fwd / back faces
      flops += 2 * site_flops * in.GhostFace()[arg.kernel_type];
      break;
    case EXTERIOR_KERNEL_ALL:
      flops += 2 * site_flops * (in.GhostFace()[0] + in.GhostFace()[1] + in.GhostFace()[2] + in.GhostFace()[3]);
      break;
    case INTERIOR_KERNEL:
    case UBER_KERNEL:
    case KERNEL_POLICY: {
      flops += site_flops * in.Volume();

      if (arg.kernel_type == KERNEL_POLICY) break;
      // now correct for flops done by exterior kernel
      long long ghost_sites = 0;
      for (int d = 0; d < 4; d++)
        if (arg.commDim[d]) ghost_sites += 2 * in.GhostFace()[d];
      flops -= site_flops * ghost_sites;
      break;
    }
    }
    return flops;
  }

  /**
     @brief Byte count: base Dslash bytes plus the clover-field reads per site.
     Per-site clover storage is 72 * precision bytes, plus 2 * sizeof(float)
     for fixed-point formats (presumably the per-site scale/norm values —
     confirm against the clover field layout).
     @note Held as long long for the same overflow reason as flops().
  */
  long long bytes() const
  {
    const long long clover_bytes
      = 72 * in.Precision() + (isFixed<typename Arg::Float>::value ? 2 * sizeof(float) : 0);

    // if we use dynamic clover we read only A (even for A^{-1}),
    // otherwise we read both A and A^{-1}
    const long long dyn_factor = arg.dynamic_clover ? 1 : 2;
    const long long site_bytes = dyn_factor * clover_bytes;

    long long bytes = Dslash::bytes();
    switch (arg.kernel_type) {
    case EXTERIOR_KERNEL_X:
    case EXTERIOR_KERNEL_Y:
    case EXTERIOR_KERNEL_Z:
    case EXTERIOR_KERNEL_T:
      // Factor of 2 is from the fwd/back faces.
      bytes += site_bytes * 2 * in.GhostFace()[arg.kernel_type];
      break;
    case EXTERIOR_KERNEL_ALL:
      // Factor of 2 is from the fwd/back faces
      bytes += site_bytes * 2 * (in.GhostFace()[0] + in.GhostFace()[1] + in.GhostFace()[2] + in.GhostFace()[3]);
      break;
    case INTERIOR_KERNEL:
    case UBER_KERNEL:
    case KERNEL_POLICY: {
      bytes += site_bytes * in.Volume();

      if (arg.kernel_type == KERNEL_POLICY) break;
      // now correct for bytes done by exterior kernel
      long long ghost_sites = 0;
      for (int d = 0; d < 4; d++)
        if (arg.commDim[d]) ghost_sites += 2 * in.GhostFace()[d];
      bytes -= site_bytes * ghost_sites;
      break;
    }
    }

    return bytes;
  }
};
271 
/**
   @brief Per-(precision, nColor, reconstruct) launcher for the variant with
   clover inverse. Constructing it builds the kernel argument, wraps it in the
   Dslash driver and runs it through the dslash policy tuner.
*/
template <typename Float, int nColor, QudaReconstructType recon> struct WilsonCloverHasenbuschTwistPCClovInvApply {

  inline WilsonCloverHasenbuschTwistPCClovInvApply(ColorSpinorField &out, const ColorSpinorField &in,
                                                   const GaugeField &U, const CloverField &A, double kappa, double mu,
                                                   const ColorSpinorField &x, int parity, bool dagger,
                                                   const int *comm_override, TimeProfile &profile)
  {
    constexpr int nDim = 4;
    // trailing `true` (the NoClovInv launcher passes `false`) presumably selects
    // the clover-inverse argument specialization
    using ArgType = WilsonCloverHasenbuschTwistPCArg<Float, nColor, nDim, recon, true>;

    ArgType kernel_arg(out, in, U, A, kappa, mu, x, parity, dagger, comm_override);
    WilsonCloverHasenbuschTwistPCClovInv<ArgType> dslash_op(kernel_arg, out, in);

    // the policy tuner takes a mutable cudaColorSpinorField pointer to the input
    auto *in_field = const_cast<cudaColorSpinorField *>(static_cast<const cudaColorSpinorField *>(&in));
    dslash::DslashPolicyTune<decltype(dslash_op)> tuner(dslash_op, in_field, in.VolumeCB(), in.GhostFaceCB(),
                                                        profile);
    tuner.apply(0);
  }
};
290 
// Apply the Hasenbusch-twisted, clover-preconditioned Wilson-clover operator in
// the variant WITH the clover inverse on the hopping term, schematically
//   M = psi_p - k^2 A^{-1} D_p\not{p} - i mu gamma_5 A_{pp} psi_{p}
// Uses the kappa normalization for the Wilson operator.
//
// a, b are forwarded to the launcher, which names them kappa and mu; x is the
// xpay source field. Performs the same argument validation as the
// no-clover-inverse entry point: errors out if in/out alias, if their field
// orders differ, if precisions/locations of out, in, U, A disagree, or if the
// library was built without GPU_CLOVER_HASENBUSCH_TWIST.
void ApplyWilsonCloverHasenbuschTwistPCClovInv(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U,
                                               const CloverField &A, double a, double b, const ColorSpinorField &x,
                                               int parity, bool dagger, const int *comm_override, TimeProfile &profile)
{
#ifdef GPU_CLOVER_HASENBUSCH_TWIST
  if (in.V() == out.V()) errorQuda("Aliasing pointers");
  if (in.FieldOrder() != out.FieldOrder())
    errorQuda("Field order mismatch in = %d, out = %d", in.FieldOrder(), out.FieldOrder());

  // check all precisions match
  checkPrecision(out, in, U, A);

  // check all locations match
  checkLocation(out, in, U, A);

  instantiate<WilsonCloverHasenbuschTwistPCClovInvApply>(out, in, U, A, a, b, x, parity, dagger, comm_override,
                                                         profile);
#else
  errorQuda("Clover Hasenbusch Twist dslash has not been built");
#endif
}
301 #else
302  errorQuda("Clover Hasenbusch Twist dslash has not been built");
303 #endif
304  }
305 
306 } // namespace quda
307 
308 #endif