QUDA  v1.1.0
A library for QCD on GPUs
prolongator.cu
Go to the documentation of this file.
1 #include <color_spinor_field.h>
2 #include <color_spinor_field_order.h>
3 #include <tune_quda.h>
4 #include <multigrid_helper.cuh>
5 
6 namespace quda {
7 
8 #ifdef GPU_MULTIGRID
9  using namespace quda::colorspinor;
10 
/**
   Kernel argument struct for the prolongator. Bundles the field accessors for
   the fine output, the coarse input and the prolongator vectors V, together
   with the fine-to-coarse geometry map and the spin mapping.
*/
template <typename Float, typename vFloat, int fineSpin, int fineColor, int coarseSpin, int coarseColor, QudaFieldOrder order>
struct ProlongateArg {
  FieldOrderCB<Float,fineSpin,fineColor,1,order> out;                      // fine-grid output accessor
  const FieldOrderCB<Float,coarseSpin,coarseColor,1,order> in;             // coarse-grid input accessor
  const FieldOrderCB<Float,fineSpin,fineColor,coarseColor,order,vFloat> V; // prolongator accessor, storage type vFloat
  const int *geo_map;                              // fine-to-coarse site map; NOTE: need to make a device copy of this
  const spin_mapper<fineSpin,coarseSpin> spin_map; // compile-time fine-to-coarse spin mapping
  const int parity;                                // the parity of the output field (if single parity)
  const int nParity;                               // number of parities of the output fine field (out.SiteSubset())

  ProlongateArg(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &V,
                const int *geo_map, const int parity) :
    out(out),
    in(in),
    V(V),
    geo_map(geo_map),
    spin_map(),
    parity(parity),
    nParity(out.SiteSubset())
  {
  }

  // Copy constructor: copies the accessors and scalars, default-constructs spin_map.
  ProlongateArg(const ProlongateArg &arg) :
    out(arg.out),
    in(arg.in),
    V(arg.V),
    geo_map(arg.geo_map),
    spin_map(),
    parity(arg.parity),
    nParity(arg.nParity)
  {
  }
};
32 
/**
   Applies the grid prolongation operator (coarse to fine), first step.

   For the fine site (parity, x_cb) it looks up the coarse site via geo_map.
   geo_map is indexed by the full fine index parity*fineVolumeCB + x_cb.
   The coarse site is split into coarse parity / checkerboard index using
   in.VolumeCB(). For each fine spin s and coarse color c it gathers

     out[s*coarseColor + c] = in(coarse_parity, coarse_cb, spin_map(s, parity), c)

   @param out fineSpin*coarseColor gathered coarse coefficients
   @param in Coarse-field accessor
   @param parity Fine-site parity
   @param x_cb Fine checkerboard site index
   @param geo_map Fine-to-coarse site map
   @param spin_map Maps (fine spin, fine parity) to a coarse spin
   @param fineVolumeCB Fine checkerboard volume
*/
template <typename Float, int fineSpin, int coarseColor, class Coarse, typename S>
__device__ __host__ inline void prolongate(complex<Float> out[fineSpin*coarseColor], const Coarse &in,
                                           int parity, int x_cb, const int *geo_map, const S& spin_map, int fineVolumeCB) {
  const int coarse_site = geo_map[x_cb + parity * fineVolumeCB];
  const int coarse_volume_cb = in.VolumeCB();
  const int coarse_parity = (coarse_site < coarse_volume_cb) ? 0 : 1;
  const int coarse_cb = (coarse_parity == 1) ? coarse_site - coarse_volume_cb : coarse_site;

#pragma unroll
  for (int spin = 0; spin < fineSpin; spin++) {
    const int coarse_spin = spin_map(spin, parity);
    complex<Float> *row = out + spin * coarseColor;
#pragma unroll
    for (int color = 0; color < coarseColor; color++) row[color] = in(coarse_parity, coarse_cb, coarse_spin, color);
  }
}
52 
/**
   Rotates from the coarse-color basis into the fine-color basis. This
   is the second step of applying the prolongator.

   The call handles the fine colors i in [fine_color_block, fine_color_block + fine_colors_per_thread).
   For each fine spin s and each such i it computes

     out(spinor_parity, x_cb, s, i) = sum_j V(v_parity, x_cb, s, i, j) * in[s*coarseColor + j]

   and writes each output element exactly once.

   @param out Fine-grid accessor, written at (spinor_parity, x_cb, s, i)
   @param in fineSpin*coarseColor coarse coefficients (layout s*coarseColor + j)
   @param V Prolongator accessor, read at (v_parity, x_cb, s, i, j)
   @param parity Fine-site parity
   @param nParity Number of parities in out; when it is not 2, parity index 0 is used for out
   @param x_cb Fine checkerboard site index
   @param fine_color_block First fine color handled by this call
*/
template <typename Float, int fineSpin, int fineColor, int coarseColor, int fine_colors_per_thread,
          class FineColor, class Rotator>
__device__ __host__ inline void rotateFineColor(FineColor &out, const complex<Float> in[fineSpin*coarseColor],
                                                const Rotator &V, int parity, int nParity, int x_cb, int fine_color_block) {
  // Number of independent partial sums carried through the j-loop (exposes ILP).
  constexpr int color_unroll = 2;
  // The j-loop steps by color_unroll and reads columns j..j+color_unroll-1, so a
  // coarseColor that is not a multiple of color_unroll would read past the end of
  // `in` and past V's last column.
  static_assert(coarseColor % color_unroll == 0, "coarseColor must be a multiple of color_unroll");
  // Callers step fine_color_block by fine_colors_per_thread up to fineColor; a
  // remainder would make fine_color_block + fine_color_local exceed fineColor.
  static_assert(fineColor % fine_colors_per_thread == 0, "fineColor must be a multiple of fine_colors_per_thread");

  const int spinor_parity = (nParity == 2) ? parity : 0;
  const int v_parity = (V.Nparity() == 2) ? parity : 0;

#pragma unroll
  for (int s=0; s<fineSpin; s++) {
#pragma unroll
    for (int fine_color_local=0; fine_color_local<fine_colors_per_thread; fine_color_local++) {
      const int i = fine_color_block + fine_color_local; // global fine color index

      complex<Float> partial[color_unroll];
#pragma unroll
      for (int k=0; k<color_unroll; k++) partial[k] = 0.0;

#pragma unroll
      for (int j=0; j<coarseColor; j+=color_unroll) {
        // V is a ColorMatrixField with internal dimensions Ns * Nc * Nvec
#pragma unroll
        for (int k=0; k<color_unroll; k++)
          partial[k] += V(v_parity, x_cb, s, i, j+k) * in[s*coarseColor + j + k];
      }

      // Reduce the partial sums in registers and store once, instead of
      // zero-filling out and then doing a read-modify-write per partial sum.
      complex<Float> sum = partial[0];
#pragma unroll
      for (int k=1; k<color_unroll; k++) sum += partial[k];
      out(spinor_parity, x_cb, s, i) = sum;
    }
  }
}
96 
/**
   Host (CPU) reference implementation of the prolongator.

   It loops serially over the output parities, the fine checkerboard sites and
   the fine-color blocks. At each site it applies prolongate() (gather from the
   coarse grid) followed by rotateFineColor() (rotate into fine colors).

   @param arg Argument struct. When arg.nParity == 2 both parities are
              processed; otherwise only arg.parity is processed.
*/
template <typename Float, int fineSpin, int fineColor, int coarseSpin, int coarseColor, int fine_colors_per_thread, typename Arg>
void Prolongate(Arg &arg) {
  const int volume_cb = arg.out.VolumeCB();

  // Keep the loop index separate from the parity being processed. When
  // nParity == 1 the processed parity is arg.parity. Writing that value back
  // into the induction variable would make the trip count depend on arg.parity,
  // and a negative value would never terminate.
  for (int p = 0; p < arg.nParity; p++) {
    const int parity = (arg.nParity == 2) ? p : arg.parity;

    for (int x_cb = 0; x_cb < volume_cb; x_cb++) {
      complex<Float> tmp[fineSpin*coarseColor];
      prolongate<Float,fineSpin,coarseColor>(tmp, arg.in, parity, x_cb, arg.geo_map, arg.spin_map, volume_cb);
      for (int fine_color_block = 0; fine_color_block < fineColor; fine_color_block += fine_colors_per_thread) {
        rotateFineColor<Float,fineSpin,fineColor,coarseColor,fine_colors_per_thread>
          (arg.out, tmp, arg.V, parity, arg.nParity, x_cb, fine_color_block);
      }
    }
  }
}
112 
/**
   Device kernel for the prolongator.

   Thread mapping:
     x: fine checkerboard site x_cb (threads with x_cb >= out.VolumeCB() exit)
     y: fine parity when arg.nParity == 2; otherwise arg.parity is used
     z: fine-color block of fine_colors_per_thread colors (blocks >= fineColor exit)

   Each surviving thread gathers its coarse coefficients with prolongate() and
   rotates them into its fine colors with rotateFineColor().
*/
template <typename Float, int fineSpin, int fineColor, int coarseSpin, int coarseColor, int fine_colors_per_thread, typename Arg>
__global__ void ProlongateKernel(Arg arg) {
  const int x_cb = threadIdx.x + blockDim.x * blockIdx.x;
  if (x_cb >= arg.out.VolumeCB()) return;

  const int color_thread = threadIdx.z + blockDim.z * blockIdx.z;
  const int fine_color_block = color_thread * fine_colors_per_thread;
  if (fine_color_block >= fineColor) return;

  const int parity = (arg.nParity == 2) ? static_cast<int>(threadIdx.y + blockDim.y * blockIdx.y) : arg.parity;

  complex<Float> coarse_coeffs[fineSpin*coarseColor];
  prolongate<Float,fineSpin,coarseColor>(coarse_coeffs, arg.in, parity, x_cb, arg.geo_map, arg.spin_map,
                                         arg.out.VolumeCB());
  rotateFineColor<Float,fineSpin,fineColor,coarseColor,fine_colors_per_thread>
    (arg.out, coarse_coeffs, arg.V, parity, arg.nParity, x_cb, fine_color_block);
}
127 
/**
   Tunable launcher for the prolongator.

   For CPU-located fields it runs the serial host Prolongate() (space-spin-color
   order only). Otherwise it launches ProlongateKernel (FLOAT2 order only) with
   a tuned launch configuration.

   Tuning extents: y = out.SiteSubset() (fine parities),
   z = fineColor / fine_colors_per_thread (fine-color blocks).
*/
template <typename Float, typename vFloat, int fineSpin, int fineColor, int coarseSpin, int coarseColor, int fine_colors_per_thread>
class ProlongateLaunch : public TunableVectorYZ {

  ColorSpinorField &out;        // fine output field
  const ColorSpinorField &in;   // coarse input field
  const ColorSpinorField &V;    // prolongator field
  const int *fine_to_coarse;    // fine-to-coarse site map (becomes ProlongateArg::geo_map)
  int parity;                   // output parity used when out is single parity
  QudaFieldLocation location;   // common location of out, in and V
  char vol[TuneKey::volume_n];  // tuning-key volume string "<out volume>,<in volume>"

  bool tuneGridDim() const { return false; } // Don't tune the grid dimensions.
  unsigned int minThreads() const { return out.VolumeCB(); } // fine parity is the block y dimension

  /** Writes "<first>,<second>" into dst via strcpy/strcat (no bounds checking). */
  static void joinWithComma(char *dst, const char *first, const char *second) {
    strcpy(dst, first);
    strcat(dst, ",");
    strcat(dst, second);
  }

public:
  ProlongateLaunch(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &V,
                   const int *fine_to_coarse, int parity) :
    TunableVectorYZ(out.SiteSubset(), fineColor / fine_colors_per_thread),
    out(out),
    in(in),
    V(V),
    fine_to_coarse(fine_to_coarse),
    parity(parity),
    location(checkLocation(out, in, V))
  {
    joinWithComma(vol, out.VolString(), in.VolString());
    joinWithComma(aux, out.AuxString(), in.AuxString());
  }

  void apply(const qudaStream_t &stream) {
    if (location != QUDA_CPU_FIELD_LOCATION) {
      // Device path: only the FLOAT2 ordering has a kernel instantiation.
      if (out.FieldOrder() != QUDA_FLOAT2_FIELD_ORDER) {
        errorQuda("Unsupported field order %d", out.FieldOrder());
      } else {
        TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
        using Arg = ProlongateArg<Float,vFloat,fineSpin,fineColor,coarseSpin,coarseColor,QUDA_FLOAT2_FIELD_ORDER>;
        Arg arg(out, in, V, fine_to_coarse, parity);
        qudaLaunchKernel(ProlongateKernel<Float,fineSpin,fineColor,coarseSpin,coarseColor,fine_colors_per_thread,Arg>,
                         tp, stream, arg);
      }
    } else {
      // Host path: serial reference implementation, space-spin-color ordering only.
      if (out.FieldOrder() != QUDA_SPACE_SPIN_COLOR_FIELD_ORDER) {
        errorQuda("Unsupported field order %d", out.FieldOrder());
      } else {
        using Arg = ProlongateArg<Float,vFloat,fineSpin,fineColor,coarseSpin,coarseColor,QUDA_SPACE_SPIN_COLOR_FIELD_ORDER>;
        Arg arg(out, in, V, fine_to_coarse, parity);
        Prolongate<Float,fineSpin,fineColor,coarseSpin,coarseColor,fine_colors_per_thread>(arg);
      }
    }
  }

  TuneKey tuneKey() const { return TuneKey(vol, typeid(*this).name(), aux); }

  long long flops() const {
    const long long fine_sites = out.SiteSubset() * (long long)out.VolumeCB();
    return 8ll * fineSpin * fineColor * coarseColor * fine_sites;
  }

  long long bytes() const {
    // V is read in full only when its parity layout matches out; otherwise only half is touched.
    const size_t v_bytes = (V.SiteSubset() == out.SiteSubset()) ? V.Bytes() : V.Bytes() / 2;
    const size_t map_bytes = out.SiteSubset() * out.VolumeCB() * sizeof(int);
    return in.Bytes() + out.Bytes() + v_bytes + map_bytes;
  }

};
189 
/**
   Selects the V storage type and runs the prolongator for fixed spin/color
   template parameters.

   Half-precision V uses `short` storage and requires half precision to be
   enabled in QUDA_PRECISION. Otherwise V must have the same precision as `in`.
   Any other V precision is an error.
*/
template <typename Float, int fineSpin, int fineColor, int coarseSpin, int coarseColor>
void Prolongate(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &v,
                const int *fine_to_coarse, int parity) {

  // every grid level currently uses a single fine color per thread
  constexpr int fine_colors_per_thread = 1;
  const QudaPrecision v_precision = v.Precision();

  if (v_precision == QUDA_HALF_PRECISION) {
#if QUDA_PRECISION & 2
    ProlongateLaunch<Float, short, fineSpin, fineColor, coarseSpin, coarseColor, fine_colors_per_thread>
      launcher(out, in, v, fine_to_coarse, parity);
    launcher.apply(0);
#else
    errorQuda("QUDA_PRECISION=%d does not enable half precision", QUDA_PRECISION);
#endif
  } else if (v_precision == in.Precision()) {
    ProlongateLaunch<Float, Float, fineSpin, fineColor, coarseSpin, coarseColor, fine_colors_per_thread>
      launcher(out, in, v, fine_to_coarse, parity);
    launcher.apply(0);
  } else {
    errorQuda("Unsupported V precision %d", v_precision);
  }
}
213 
/**
   Dispatches on the fine color count (out.Ncolor()) and the coarse color count
   (nVec) to the fully templated Prolongate.

   The coarse spin must be 2. The runtime spin_map ([fineSpin][2]) must agree
   with the compile-time spin_mapper that the kernels use. Which (fineColor,
   nVec) pairs are available depends on NSPIN4 / NSPIN1.
*/
template <typename Float, int fineSpin>
void Prolongate(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &v,
                int nVec, const int *fine_to_coarse, const int * const * spin_map, int parity) {

  if (in.Nspin() != 2) errorQuda("Coarse spin %d is not supported", in.Nspin());
  constexpr int coarseSpin = 2;

  // the kernels use the compile-time spin_mapper, so the runtime map must match it
  spin_mapper<fineSpin,coarseSpin> mapper;
  for (int s = 0; s < fineSpin; s++) {
    for (int p = 0; p < 2; p++) {
      if (spin_map[s][p] != mapper(s, p)) errorQuda("Spin map does not match spin_mapper");
    }
  }

  switch (out.Ncolor()) {
  case 3: {
    constexpr int fineColor = 3;
    switch (nVec) {
#ifdef NSPIN4
    case 6: // Free field Wilson
      Prolongate<Float,fineSpin,fineColor,coarseSpin,6>(out, in, v, fine_to_coarse, parity);
      break;
#endif // NSPIN4
    case 24:
      Prolongate<Float,fineSpin,fineColor,coarseSpin,24>(out, in, v, fine_to_coarse, parity);
      break;
#ifdef NSPIN4
    case 32:
      Prolongate<Float,fineSpin,fineColor,coarseSpin,32>(out, in, v, fine_to_coarse, parity);
      break;
#endif // NSPIN4
#ifdef NSPIN1
    case 64:
      Prolongate<Float,fineSpin,fineColor,coarseSpin,64>(out, in, v, fine_to_coarse, parity);
      break;
    case 96:
      Prolongate<Float,fineSpin,fineColor,coarseSpin,96>(out, in, v, fine_to_coarse, parity);
      break;
#endif // NSPIN1
    default:
      errorQuda("Unsupported nVec %d", nVec);
    }
    break;
  }
#ifdef NSPIN4
  case 6: { // for coarsening coarsened Wilson free field.
    constexpr int fineColor = 6;
    switch (nVec) {
    case 6: // these are probably only for debugging only
      Prolongate<Float,fineSpin,fineColor,coarseSpin,6>(out, in, v, fine_to_coarse, parity);
      break;
    default:
      errorQuda("Unsupported nVec %d", nVec);
    }
    break;
  }
#endif // NSPIN4
  case 24: {
    constexpr int fineColor = 24;
    // to keep compilation under control coarse grids have same or more colors
    switch (nVec) {
    case 24:
      Prolongate<Float,fineSpin,fineColor,coarseSpin,24>(out, in, v, fine_to_coarse, parity);
      break;
#ifdef NSPIN4
    case 32:
      Prolongate<Float,fineSpin,fineColor,coarseSpin,32>(out, in, v, fine_to_coarse, parity);
      break;
#endif // NSPIN4
#ifdef NSPIN1
    case 64:
      Prolongate<Float,fineSpin,fineColor,coarseSpin,64>(out, in, v, fine_to_coarse, parity);
      break;
    case 96:
      Prolongate<Float,fineSpin,fineColor,coarseSpin,96>(out, in, v, fine_to_coarse, parity);
      break;
#endif // NSPIN1
    default:
      errorQuda("Unsupported nVec %d", nVec);
    }
    break;
  }
#ifdef NSPIN4
  case 32: {
    constexpr int fineColor = 32;
    switch (nVec) {
    case 32:
      Prolongate<Float,fineSpin,fineColor,coarseSpin,32>(out, in, v, fine_to_coarse, parity);
      break;
    default:
      errorQuda("Unsupported nVec %d", nVec);
    }
    break;
  }
#endif // NSPIN4
#ifdef NSPIN1
  case 64: {
    constexpr int fineColor = 64;
    switch (nVec) {
    case 64:
      Prolongate<Float,fineSpin,fineColor,coarseSpin,64>(out, in, v, fine_to_coarse, parity);
      break;
    case 96:
      Prolongate<Float,fineSpin,fineColor,coarseSpin,96>(out, in, v, fine_to_coarse, parity);
      break;
    default:
      errorQuda("Unsupported nVec %d", nVec);
    }
    break;
  }
  case 96: {
    constexpr int fineColor = 96;
    switch (nVec) {
    case 96:
      Prolongate<Float,fineSpin,fineColor,coarseSpin,96>(out, in, v, fine_to_coarse, parity);
      break;
    default:
      errorQuda("Unsupported nVec %d", nVec);
    }
    break;
  }
#endif // NSPIN1
  default:
    errorQuda("Unsupported nColor %d", out.Ncolor());
  }
}
306 
/**
   Dispatches on the fine spin count (out.Nspin()).

   Spin 2 is always available. Spin 4 requires NSPIN4 and spin 1 requires
   NSPIN1. Any other value is an error.
*/
template <typename Float>
void Prolongate(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &v,
                int Nvec, const int *fine_to_coarse, const int * const * spin_map, int parity) {

  switch (out.Nspin()) {
  case 2:
    Prolongate<Float,2>(out, in, v, Nvec, fine_to_coarse, spin_map, parity);
    break;
#ifdef NSPIN4
  case 4:
    Prolongate<Float,4>(out, in, v, Nvec, fine_to_coarse, spin_map, parity);
    break;
#endif
#ifdef NSPIN1
  case 1:
    Prolongate<Float,1>(out, in, v, Nvec, fine_to_coarse, spin_map, parity);
    break;
#endif
  default:
    errorQuda("Unsupported nSpin %d", out.Nspin());
  }
}
325 
326 #endif // GPU_MULTIGRID
327 
/**
   Entry point: prolongates the coarse field `in` onto the fine field `out`
   using the vectors in `v`.

   @param out Fine output field
   @param in Coarse input field
   @param v Prolongator vectors; must have the same field order as out and in
   @param Nvec Number of coarse colors (selects the coarseColor instantiation)
   @param fine_to_coarse Fine-to-coarse site map
   @param spin_map [fineSpin][2] fine-to-coarse spin map; must agree with spin_mapper
   @param parity Output parity when out is single parity

   Errors (via errorQuda) are raised in these cases:
   - multigrid is not built;
   - the field orders differ;
   - the common precision of out and in is neither single nor double;
   - double precision is requested without GPU_MULTIGRID_DOUBLE.
*/
void Prolongate(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &v,
                int Nvec, const int *fine_to_coarse, const int * const * spin_map, int parity) {
#ifndef GPU_MULTIGRID
  errorQuda("Multigrid has not been built");
#else
  const bool orders_match = (out.FieldOrder() == in.FieldOrder()) && (out.FieldOrder() == v.FieldOrder());
  if (!orders_match)
    errorQuda("Field orders do not match (out=%d, in=%d, v=%d)",
              out.FieldOrder(), in.FieldOrder(), v.FieldOrder());

  switch (checkPrecision(out, in)) {
  case QUDA_DOUBLE_PRECISION:
#ifdef GPU_MULTIGRID_DOUBLE
    Prolongate<double>(out, in, v, Nvec, fine_to_coarse, spin_map, parity);
#else
    errorQuda("Double precision multigrid has not been enabled");
#endif
    break;
  case QUDA_SINGLE_PRECISION:
    Prolongate<float>(out, in, v, Nvec, fine_to_coarse, spin_map, parity);
    break;
  default:
    errorQuda("Unsupported precision %d", out.Precision());
  }
#endif
}
352 
353 } // end namespace quda