QUDA  v1.1.0
A library for QCD on GPUs
dslash5_domain_wall.cu
Go to the documentation of this file.
1 #include <color_spinor_field.h>
2 #include <color_spinor_field_order.h>
3 #include <dslash_quda.h>
4 #include <index_helper.cuh>
5 #include <instantiate.h>
6 
7 #include <kernels/dslash_domain_wall_m5.cuh>
8 
9 namespace quda
10 {
11 
12  template <typename Float, int nColor> class Dslash5 : public TunableVectorYZ
13  {
14  Dslash5Arg<Float, nColor> arg;
15  const ColorSpinorField &meta;
16  static constexpr bool shared = true; // whether to use shared memory cache blocking for M5inv
17 
18  /** Whether to use variable or fixed coefficient algorithm. Must be true if using ZMOBIUS */
19  static constexpr bool var_inverse = true;
20 
21  long long flops() const
22  {
23  long long Ls = meta.X(4);
24  long long bulk = (Ls - 2) * (meta.Volume() / Ls);
25  long long wall = 2 * meta.Volume() / Ls;
26  long long n = meta.Ncolor() * meta.Nspin();
27 
28  long long flops_ = 0;
29  switch (arg.type) {
30  case DSLASH5_DWF: flops_ = n * (8ll * bulk + 10ll * wall + (arg.xpay ? 4ll * meta.Volume() : 0)); break;
31  case DSLASH5_MOBIUS_PRE:
32  flops_ = n * (8ll * bulk + 10ll * wall + 14ll * meta.Volume() + (arg.xpay ? 8ll * meta.Volume() : 0));
33  break;
34  case DSLASH5_MOBIUS:
35  flops_ = n * (8ll * bulk + 10ll * wall + 8ll * meta.Volume() + (arg.xpay ? 8ll * meta.Volume() : 0));
36  break;
37  case M5_INV_DWF:
38  case M5_INV_MOBIUS: // FIXME flops
39  flops_ = ((2 + 8 * n) * Ls + (arg.xpay ? 4ll : 0)) * meta.Volume();
40  break;
41  case M5_INV_ZMOBIUS: flops_ = ((12 + 16 * n) * Ls + (arg.xpay ? 8ll : 0)) * meta.Volume(); break;
42  default: errorQuda("Unknown Dslash5Type %d", arg.type);
43  }
44 
45  return flops_;
46  }
47 
48  long long bytes() const
49  {
50  long long Ls = meta.X(4);
51  switch (arg.type) {
52  case DSLASH5_DWF: return arg.out.Bytes() + 2 * arg.in.Bytes() + (arg.xpay ? arg.x.Bytes() : 0);
53  case DSLASH5_MOBIUS_PRE: return arg.out.Bytes() + 3 * arg.in.Bytes() + (arg.xpay ? arg.x.Bytes() : 0);
54  case DSLASH5_MOBIUS: return arg.out.Bytes() + 3 * arg.in.Bytes() + (arg.xpay ? arg.x.Bytes() : 0);
55  case M5_INV_DWF: return arg.out.Bytes() + Ls * arg.in.Bytes() + (arg.xpay ? arg.x.Bytes() : 0);
56  case M5_INV_MOBIUS: return arg.out.Bytes() + Ls * arg.in.Bytes() + (arg.xpay ? arg.x.Bytes() : 0);
57  case M5_INV_ZMOBIUS: return arg.out.Bytes() + Ls * arg.in.Bytes() + (arg.xpay ? arg.x.Bytes() : 0);
58  default: errorQuda("Unknown Dslash5Type %d", arg.type);
59  }
60  return 0ll;
61  }
62 
63  bool tuneGridDim() const { return false; }
64  unsigned int minThreads() const { return arg.volume_4d_cb; }
65  int blockStep() const { return 4; }
66  int blockMin() const { return 4; }
67  unsigned int sharedBytesPerThread() const
68  {
69  if (shared && (arg.type == M5_INV_DWF || arg.type == M5_INV_MOBIUS || arg.type == M5_INV_ZMOBIUS)) {
70  // spin components in shared depend on inversion algorithm
71  int nSpin = var_inverse ? meta.Nspin() / 2 : meta.Nspin();
72  return 2 * nSpin * nColor * sizeof(typename mapper<Float>::type);
73  } else {
74  return 0;
75  }
76  }
77 
78  // overloaded to return max dynamic shared memory if doing shared-memory inverse
79  unsigned int maxSharedBytesPerBlock() const
80  {
81  if (shared && (arg.type == M5_INV_DWF || arg.type == M5_INV_MOBIUS || arg.type == M5_INV_ZMOBIUS)) {
82  return maxDynamicSharedBytesPerBlock();
83  } else {
84  return TunableVectorYZ::maxSharedBytesPerBlock();
85  }
86  }
87 
88 public:
89  Dslash5(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &x, double m_f,
90  double m_5, const Complex *b_5, const Complex *c_5, double a, bool dagger, Dslash5Type type) :
91  TunableVectorYZ(in.X(4), in.SiteSubset()),
92  arg(out, in, x, m_f, m_5, b_5, c_5, a, dagger, type),
93  meta(in)
94  {
95  strcpy(aux, meta.AuxString());
96  if (arg.dagger) strcat(aux, ",Dagger");
97  if (arg.xpay) strcat(aux, ",xpay");
98  switch (arg.type) {
99  case DSLASH5_DWF: strcat(aux, ",DSLASH5_DWF"); break;
100  case DSLASH5_MOBIUS_PRE: strcat(aux, ",DSLASH5_MOBIUS_PRE"); break;
101  case DSLASH5_MOBIUS: strcat(aux, ",DSLASH5_MOBIUS"); break;
102  case M5_INV_DWF: strcat(aux, ",M5_INV_DWF"); break;
103  case M5_INV_MOBIUS: strcat(aux, ",M5_INV_MOBIUS"); break;
104  case M5_INV_ZMOBIUS: strcat(aux, ",M5_INV_ZMOBIUS"); break;
105  default: errorQuda("Unknown Dslash5Type %d", arg.type);
106  }
107 
108  apply(streams[Nstream - 1]);
109  }
110 
111  template <typename T, typename Arg> inline void launch(T *f, TuneParam &tp, Arg &arg, const qudaStream_t &stream)
112  {
113  if (shared && (arg.type == M5_INV_DWF || arg.type == M5_INV_MOBIUS || arg.type == M5_INV_ZMOBIUS)) {
114  // if inverse kernel uses shared memory then maximize total shared memory pool
115  tp.set_max_shared_bytes = true;
116  }
117  qudaLaunchKernel(f, tp, stream, arg);
118  }
119 
120  void apply(const qudaStream_t &stream)
121  {
122  using Arg = decltype(arg);
123  if (meta.Location() == QUDA_CPU_FIELD_LOCATION) {
124  errorQuda("CPU variant not instantiated");
125  } else {
126  TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
127  if (arg.type == DSLASH5_DWF) {
128  if (arg.xpay)
129  arg.dagger ? launch(dslash5GPU<Float, nColor, true, true, DSLASH5_DWF, Arg>, tp, arg, stream) :
130  launch(dslash5GPU<Float, nColor, false, true, DSLASH5_DWF, Arg>, tp, arg, stream);
131  else
132  arg.dagger ? launch(dslash5GPU<Float, nColor, true, false, DSLASH5_DWF, Arg>, tp, arg, stream) :
133  launch(dslash5GPU<Float, nColor, false, false, DSLASH5_DWF, Arg>, tp, arg, stream);
134  } else if (arg.type == DSLASH5_MOBIUS_PRE) {
135  if (arg.xpay)
136  arg.dagger ? launch(dslash5GPU<Float, nColor, true, true, DSLASH5_MOBIUS_PRE, Arg>, tp, arg, stream) :
137  launch(dslash5GPU<Float, nColor, false, true, DSLASH5_MOBIUS_PRE, Arg>, tp, arg, stream);
138  else
139  arg.dagger ? launch(dslash5GPU<Float, nColor, true, false, DSLASH5_MOBIUS_PRE, Arg>, tp, arg, stream) :
140  launch(dslash5GPU<Float, nColor, false, false, DSLASH5_MOBIUS_PRE, Arg>, tp, arg, stream);
141  } else if (arg.type == DSLASH5_MOBIUS) {
142  if (arg.xpay)
143  arg.dagger ? launch(dslash5GPU<Float, nColor, true, true, DSLASH5_MOBIUS, Arg>, tp, arg, stream) :
144  launch(dslash5GPU<Float, nColor, false, true, DSLASH5_MOBIUS, Arg>, tp, arg, stream);
145  else
146  arg.dagger ? launch(dslash5GPU<Float, nColor, true, false, DSLASH5_MOBIUS, Arg>, tp, arg, stream) :
147  launch(dslash5GPU<Float, nColor, false, false, DSLASH5_MOBIUS, Arg>, tp, arg, stream);
148  } else if (arg.type == M5_INV_DWF) {
149  if (arg.xpay)
150  arg.dagger ?
151  launch(dslash5invGPU<Float, nColor, true, true, M5_INV_DWF, shared, var_inverse, Arg>, tp, arg, stream) :
152  launch(dslash5invGPU<Float, nColor, false, true, M5_INV_DWF, shared, var_inverse, Arg>, tp, arg, stream);
153  else
154  arg.dagger ?
155  launch(dslash5invGPU<Float, nColor, true, false, M5_INV_DWF, shared, var_inverse, Arg>, tp, arg, stream) :
156  launch(dslash5invGPU<Float, nColor, false, false, M5_INV_DWF, shared, var_inverse, Arg>, tp, arg, stream);
157  } else if (arg.type == M5_INV_MOBIUS) {
158  if (arg.xpay)
159  arg.dagger ? launch(
160  dslash5invGPU<Float, nColor, true, true, M5_INV_MOBIUS, shared, var_inverse, Arg>, tp, arg, stream) :
161  launch(dslash5invGPU<Float, nColor, false, true, M5_INV_MOBIUS, shared, var_inverse, Arg>, tp,
162  arg, stream);
163  else
164  arg.dagger ? launch(
165  dslash5invGPU<Float, nColor, true, false, M5_INV_MOBIUS, shared, var_inverse, Arg>, tp, arg, stream) :
166  launch(dslash5invGPU<Float, nColor, false, false, M5_INV_MOBIUS, shared, var_inverse, Arg>, tp,
167  arg, stream);
168  } else if (arg.type == M5_INV_ZMOBIUS) {
169  if (arg.xpay)
170  arg.dagger ? launch(
171  dslash5invGPU<Float, nColor, true, true, M5_INV_ZMOBIUS, shared, var_inverse, Arg>, tp, arg, stream) :
172  launch(dslash5invGPU<Float, nColor, false, true, M5_INV_ZMOBIUS, shared, var_inverse, Arg>, tp,
173  arg, stream);
174  else
175  arg.dagger ? launch(
176  dslash5invGPU<Float, nColor, true, false, M5_INV_ZMOBIUS, shared, var_inverse, Arg>, tp, arg, stream) :
177  launch(dslash5invGPU<Float, nColor, false, false, M5_INV_ZMOBIUS, shared, var_inverse, Arg>,
178  tp, arg, stream);
179  }
180  }
181  }
182 
183  void initTuneParam(TuneParam &param) const
184  {
185  TunableVectorYZ::initTuneParam(param);
186  if (shared && (arg.type == M5_INV_DWF || arg.type == M5_INV_MOBIUS || arg.type == M5_INV_ZMOBIUS)) {
187  param.block.y = arg.Ls; // Ls must be contained in the block
188  param.grid.y = 1;
189  param.shared_bytes = sharedBytesPerThread() * param.block.x * param.block.y * param.block.z;
190  }
191  }
192 
193  void defaultTuneParam(TuneParam &param) const
194  {
195  TunableVectorYZ::defaultTuneParam(param);
196  if (shared && (arg.type == M5_INV_DWF || arg.type == M5_INV_MOBIUS || arg.type == M5_INV_ZMOBIUS)) {
197  param.block.y = arg.Ls; // Ls must be contained in the block
198  param.grid.y = 1;
199  param.shared_bytes = sharedBytesPerThread() * param.block.x * param.block.y * param.block.z;
200  }
201  }
202 
203  TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), aux); }
204  };
205 
206  // Apply the 5th dimension dslash operator to a colorspinor field
207  // out = Dslash5*in
208  void ApplyDslash5(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &x, double m_f,
209  double m_5, const Complex *b_5, const Complex *c_5, double a, bool dagger, Dslash5Type type)
210  {
211 #ifdef GPU_DOMAIN_WALL_DIRAC
212  if (in.PCType() != QUDA_4D_PC) errorQuda("Only 4-d preconditioned fields are supported");
213  checkLocation(out, in, x); // check all locations match
214  instantiate<Dslash5>(out, in, x, m_f, m_5, b_5, c_5, a, dagger, type);
215 #else
216  errorQuda("Domain wall dslash has not been built");
217 #endif
218  }
219 
220 } // namespace quda