QUDA  1.0.0
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
prolongator.cu
Go to the documentation of this file.
1 #include <color_spinor_field.h>
3 #include <tune_quda.h>
4 #include <typeinfo>
5 #include <multigrid_helper.cuh>
6 
7 namespace quda {
8 
9 #ifdef GPU_MULTIGRID
10  using namespace quda::colorspinor;
11 
/**
   Argument bundle handed to the prolongation routines: the field
   handles (out, in, V), the site map geo_map, the spin mapper and the
   parity bookkeeping.

   NOTE(review): the constructors initialize members named out, in and
   V, but their declarations are not visible in this listing - confirm
   their accessor types against the original file.
*/
15  template <typename Float, typename vFloat, int fineSpin, int fineColor, int coarseSpin, int coarseColor, QudaFieldOrder order>
16  struct ProlongateArg {
20  const int *geo_map; // need to make a device copy of this
21  const spin_mapper<fineSpin,coarseSpin> spin_map;
22  const int parity; // the parity of the output field (if single parity)
// NOTE(review): the constructor below sets nParity from out.SiteSubset(),
// i.e. from the output field - confirm the wording of the comment on the next line.
23  const int nParity; // number of parities of input fine field
24 
// Build the argument from the fields; spin_map is default-constructed
// and nParity is read from out.SiteSubset().
25  ProlongateArg(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &V,
26  const int *geo_map, const int parity)
27  : out(out), in(in), V(V), geo_map(geo_map), spin_map(), parity(parity), nParity(out.SiteSubset()) { }
28 
// Copy constructor: copies every member from arg except spin_map,
// which is default-constructed again rather than copied.
29  ProlongateArg(const ProlongateArg<Float,vFloat,fineSpin,fineColor,coarseSpin,coarseColor,order> &arg)
30  : out(arg.out), in(arg.in), V(arg.V), geo_map(arg.geo_map), spin_map(),
31  parity(arg.parity), nParity(arg.nParity) { }
32  };
33 
/**
   Gather the coarse-grid values that belong to one fine-grid site.

   The full fine index (parity*fineVolumeCB + x_cb) is looked up in
   geo_map to obtain the full coarse index, which is then split into a
   coarse parity and a checkerboard index using in.VolumeCB().

   @param[out] out Thread-local buffer of fineSpin*coarseColor entries,
   stored as out[s*coarseColor + c]
   @param[in] in Coarse field accessor
   @param[in] parity Parity of the fine site
   @param[in] x_cb Checkerboard index of the fine site
   @param[in] geo_map Map from full fine index to full coarse index
   @param[in] spin_map Functor giving the coarse spin for (fine spin, parity)
   @param[in] fineVolumeCB Checkerboard volume of the fine field
*/
template <typename Float, int fineSpin, int coarseColor, class Coarse, typename S>
__device__ __host__ inline void prolongate(complex<Float> out[fineSpin*coarseColor], const Coarse &in,
                                           int parity, int x_cb, const int *geo_map, const S& spin_map, int fineVolumeCB) {
  const int coarse_site = geo_map[parity*fineVolumeCB + x_cb];
  const int coarse_volume_cb = in.VolumeCB();

  // indices at or beyond the checkerboard volume belong to parity 1
  const int coarse_parity = (coarse_site >= coarse_volume_cb) ? 1 : 0;
  const int coarse_site_cb = coarse_site - coarse_parity*coarse_volume_cb;

#pragma unroll
  for (int spin=0; spin<fineSpin; spin++) {
    // the coarse spin is the same for every color of this fine spin
    const int coarse_spin = spin_map(spin,parity);
    complex<Float> *row = out + spin*coarseColor;
#pragma unroll
    for (int color=0; color<coarseColor; color++) {
      row[color] = in(coarse_parity, coarse_site_cb, coarse_spin, color);
    }
  }
}
53 
/**
   Apply the rotation V to the gathered coarse values and store the
   result in the fine field:

     out(p, x_cb, s, i) = sum_j V(vp, x_cb, s, i, j) * in[s*coarseColor + j]

   for every fine spin s and for the fine colors
   i = fine_color_block ... fine_color_block + fine_colors_per_thread - 1.

   The sum over j is accumulated in color_unroll independent partial
   sums, so coarseColor has to be a multiple of color_unroll; this is
   enforced at compile time because otherwise the last iteration would
   read past the end of in[] and past the last column of V.

   @param[out] out Fine field accessor that receives the result
   @param[in] in Thread-local buffer of fineSpin*coarseColor coarse values
   @param[in] V Rotation accessor, indexed (parity, x_cb, spin, fine color, coarse color)
   @param[in] parity Parity of the fine site
   @param[in] nParity Number of parities of the output field; when it is
   not 2 the output is addressed with parity 0
   @param[in] x_cb Checkerboard index of the fine site
   @param[in] fine_color_block First fine color handled by this call
*/
template <typename Float, int fineSpin, int fineColor, int coarseColor, int fine_colors_per_thread,
          class FineColor, class Rotator>
__device__ __host__ inline void rotateFineColor(FineColor &out, const complex<Float> in[fineSpin*coarseColor],
                                                const Rotator &V, int parity, int nParity, int x_cb, int fine_color_block) {
  const int spinor_parity = (nParity == 2) ? parity : 0;
  const int v_parity = (V.Nparity() == 2) ? parity : 0;

  constexpr int color_unroll = 2;
  static_assert(coarseColor % color_unroll == 0,
                "rotateFineColor: coarseColor must be a multiple of color_unroll");

  // clear the output entries owned by this call before accumulating into them
#pragma unroll
  for (int s=0; s<fineSpin; s++)
#pragma unroll
    for (int fine_color_local=0; fine_color_local<fine_colors_per_thread; fine_color_local++)
      out(spinor_parity, x_cb, s, fine_color_block+fine_color_local) = 0.0; // global fine color index

#pragma unroll
  for (int s=0; s<fineSpin; s++) {
#pragma unroll
    for (int fine_color_local=0; fine_color_local<fine_colors_per_thread; fine_color_local++) {
      int i = fine_color_block + fine_color_local; // global fine color index

      // independent accumulators, one per unrolled coarse color
      complex<Float> partial[color_unroll];
#pragma unroll
      for (int k=0; k<color_unroll; k++) partial[k] = 0.0;

#pragma unroll
      for (int j=0; j<coarseColor; j+=color_unroll) {
        // V is a ColorMatrixField with internal dimensions Ns * Nc * Nvec
#pragma unroll
        for (int k=0; k<color_unroll; k++)
          partial[k] += V(v_parity, x_cb, s, i, j+k) * in[s*coarseColor + j + k];
      }

      // fold the partial sums into the output entry
#pragma unroll
      for (int k=0; k<color_unroll; k++) out(spinor_parity, x_cb, s, i) += partial[k];
    }
  }

}
97 
/**
   Host loop over the fine field: for every parity and checkerboard
   site, gather the coarse values with prolongate() and rotate them
   into the output with rotateFineColor(), one block of
   fine_colors_per_thread fine colors at a time.

   The loop counter is kept separate from the parity that is actually
   used.  For a single-parity field the effective parity is arg.parity,
   and writing it back into the counter would make the loop's
   termination depend on that value (a negative arg.parity would never
   terminate).

   @param[in,out] arg Argument bundle; arg.out is written
*/
template <typename Float, int fineSpin, int fineColor, int coarseSpin, int coarseColor, int fine_colors_per_thread, typename Arg>
void Prolongate(Arg &arg) {
  for (int parity_it=0; parity_it<arg.nParity; parity_it++) {
    // full fields iterate over both parities, single-parity fields use the stored one
    const int parity = (arg.nParity == 2) ? parity_it : arg.parity;

    for (int x_cb=0; x_cb<arg.out.VolumeCB(); x_cb++) {
      complex<Float> tmp[fineSpin*coarseColor];
      prolongate<Float,fineSpin,coarseColor>(tmp, arg.in, parity, x_cb, arg.geo_map, arg.spin_map, arg.out.VolumeCB());
      for (int fine_color_block=0; fine_color_block<fineColor; fine_color_block+=fine_colors_per_thread) {
        rotateFineColor<Float,fineSpin,fineColor,coarseColor,fine_colors_per_thread>
          (arg.out, tmp, arg.V, parity, arg.nParity, x_cb, fine_color_block);
      }
    }
  }
}
113 
/**
   Device kernel for the prolongation.

   Thread layout: the x dimension indexes the fine checkerboard site,
   the y dimension supplies the parity when arg.nParity == 2 (otherwise
   arg.parity is used), and the z dimension selects a block of
   fine_colors_per_thread fine colors.  Threads whose site or color
   block is out of range exit without doing any work.

   @param[in] arg Argument bundle, passed by value
*/
template <typename Float, int fineSpin, int fineColor, int coarseSpin, int coarseColor, int fine_colors_per_thread, typename Arg>
__global__ void ProlongateKernel(Arg arg) {
  const int site = blockIdx.x*blockDim.x + threadIdx.x;
  const int site_parity = (arg.nParity == 2) ? blockDim.y*blockIdx.y + threadIdx.y : arg.parity;
  const int first_color = (blockDim.z*blockIdx.z + threadIdx.z) * fine_colors_per_thread;
  const int fine_volume_cb = arg.out.VolumeCB();

  // nothing to do for threads beyond the site or color range
  if (site >= fine_volume_cb || first_color >= fineColor) return;

  complex<Float> coarse[fineSpin*coarseColor];
  prolongate<Float,fineSpin,coarseColor>(coarse, arg.in, site_parity, site, arg.geo_map, arg.spin_map, fine_volume_cb);
  rotateFineColor<Float,fineSpin,fineColor,coarseColor,fine_colors_per_thread>
    (arg.out, coarse, arg.V, site_parity, arg.nParity, site, first_color);
}
128 
/**
   Tunable wrapper that runs the prolongation either through the host
   loop or through ProlongateKernel, depending on where the fields
   live (as reported by checkLocation at construction).
*/
template <typename Float, typename vFloat, int fineSpin, int fineColor, int coarseSpin, int coarseColor, int fine_colors_per_thread>
class ProlongateLaunch : public TunableVectorYZ {

protected:
  ColorSpinorField &out;
  const ColorSpinorField &in;
  const ColorSpinorField &V;
  const int *fine_to_coarse;
  int parity;
  QudaFieldLocation location;
  char vol[TuneKey::volume_n];

  // argument bundle for a given field order
  template <QudaFieldOrder order>
  using ArgT = ProlongateArg<Float,vFloat,fineSpin,fineColor,coarseSpin,coarseColor,order>;

  // Don't tune the grid dimensions.
  bool tuneGridDim() const { return false; }

  // fine parity is the block y dimension
  unsigned int minThreads() const { return out.VolumeCB(); }

public:
  /**
     @param[out] out Fine field that receives the result
     @param[in] in Coarse input field
     @param[in] V Field passed to the kernels as the rotation
     @param[in] fine_to_coarse Site map handed to the kernels as geo_map
     @param[in] parity Parity used when the output is a single-parity field
  */
  ProlongateLaunch(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &V,
                   const int *fine_to_coarse, int parity)
    : TunableVectorYZ(out.SiteSubset(), fineColor/fine_colors_per_thread),
      out(out),
      in(in),
      V(V),
      fine_to_coarse(fine_to_coarse),
      parity(parity),
      location(checkLocation(out, in, V))
  {
    // volume string: "<out volume>,<in volume>"
    strcpy(vol, out.VolString());
    strcat(vol, ",");
    strcat(vol, in.VolString());

    // aux string: "<out aux>,<in aux>"
    strcpy(aux, out.AuxString());
    strcat(aux, ",");
    strcat(aux, in.AuxString());
  }

  virtual ~ProlongateLaunch() { }

  /**
     Run the prolongation.  Host fields must use
     QUDA_SPACE_SPIN_COLOR_FIELD_ORDER and all other fields
     QUDA_FLOAT2_FIELD_ORDER; any other order is reported through
     errorQuda.

     @param[in] stream Stream the device kernel is launched on
  */
  void apply(const cudaStream_t &stream) {
    const bool on_host = (location == QUDA_CPU_FIELD_LOCATION);
    const QudaFieldOrder required_order = on_host ? QUDA_SPACE_SPIN_COLOR_FIELD_ORDER : QUDA_FLOAT2_FIELD_ORDER;

    if (out.FieldOrder() != required_order) {
      errorQuda("Unsupported field order %d", out.FieldOrder());
    } else if (on_host) {
      ArgT<QUDA_SPACE_SPIN_COLOR_FIELD_ORDER> arg(out, in, V, fine_to_coarse, parity);
      Prolongate<Float,fineSpin,fineColor,coarseSpin,coarseColor,fine_colors_per_thread>(arg);
    } else {
      TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
      ArgT<QUDA_FLOAT2_FIELD_ORDER> arg(out, in, V, fine_to_coarse, parity);
      ProlongateKernel<Float,fineSpin,fineColor,coarseSpin,coarseColor,fine_colors_per_thread>
        <<<tp.grid, tp.block, tp.shared_bytes, stream>>>(arg);
    }
  }

  // key built from the volume string, the dynamic type name and the aux string
  TuneKey tuneKey() const { return TuneKey(vol, typeid(*this).name(), aux); }

  // 8 * fineSpin * fineColor * coarseColor per output site
  long long flops() const { return 8 * fineSpin * fineColor * coarseColor * out.SiteSubset()*(long long)out.VolumeCB(); }

  // in + out + V (halved when its site subset differs from out's) + one int per output site
  long long bytes() const {
    size_t v_bytes = V.Bytes();
    if (V.SiteSubset() != out.SiteSubset()) v_bytes /= 2;
    return in.Bytes() + out.Bytes() + v_bytes + out.SiteSubset()*out.VolumeCB()*sizeof(int);
  }

};
193 
/**
   Choose the storage type used for V and run the launcher.

   V in half precision is instantiated with short storage (only when
   QUDA_PRECISION enables it); otherwise V must have the same precision
   as the input field.  Any other combination is reported through
   errorQuda.
*/
template <typename Float, int fineSpin, int fineColor, int coarseSpin, int coarseColor>
void Prolongate(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &v,
                const int *fine_to_coarse, int parity) {

  // for all grids use 1 color per thread
  constexpr int fine_colors_per_thread = 1;

  const auto v_precision = v.Precision();

  if (v_precision == QUDA_HALF_PRECISION) {
#if QUDA_PRECISION & 2
    using HalfLauncher = ProlongateLaunch<Float, short, fineSpin, fineColor, coarseSpin, coarseColor, fine_colors_per_thread>;
    HalfLauncher prolongator(out, in, v, fine_to_coarse, parity);
    prolongator.apply(0);
#else
    errorQuda("QUDA_PRECISION=%d does not enable half precision", QUDA_PRECISION);
#endif
  } else if (v_precision == in.Precision()) {
    using NativeLauncher = ProlongateLaunch<Float, Float, fineSpin, fineColor, coarseSpin, coarseColor, fine_colors_per_thread>;
    NativeLauncher prolongator(out, in, v, fine_to_coarse, parity);
    prolongator.apply(0);
  } else {
    errorQuda("Unsupported V precision %d", v.Precision());
  }

}
219 
220 
/**
   Turn the run-time fine color count (out.Ncolor()) and nVec into
   template parameters.

   The coarse spin must be 2 and the supplied spin_map table must
   agree with spin_mapper<fineSpin,2>.  Supported (fine color, nVec)
   pairs: 3 -> {4, 6, 24, 32}, 6 -> {6}, 24 -> {24, 32}, 32 -> {32}.
   Everything else is reported through errorQuda.
*/
template <typename Float, int fineSpin>
void Prolongate(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &v,
                int nVec, const int *fine_to_coarse, const int * const * spin_map, int parity) {

  if (in.Nspin() != 2) errorQuda("Coarse spin %d is not supported", in.Nspin());
  const int coarseSpin = 2;

  // first check that the spin_map matches the spin_mapper
  spin_mapper<fineSpin,coarseSpin> mapper;
  for (int spin=0; spin<fineSpin; spin++) {
    for (int par=0; par<2; par++) {
      if (mapper(spin,par) != spin_map[spin][par]) errorQuda("Spin map does not match spin_mapper");
    }
  }

  switch (out.Ncolor()) {
  case 3:
    switch (nVec) {
    case 4:
      Prolongate<Float,fineSpin,3,coarseSpin,4>(out, in, v, fine_to_coarse, parity);
      break;
    case 6: // Free field Wilson
      Prolongate<Float,fineSpin,3,coarseSpin,6>(out, in, v, fine_to_coarse, parity);
      break;
    case 24:
      Prolongate<Float,fineSpin,3,coarseSpin,24>(out, in, v, fine_to_coarse, parity);
      break;
    case 32:
      Prolongate<Float,fineSpin,3,coarseSpin,32>(out, in, v, fine_to_coarse, parity);
      break;
    default:
      errorQuda("Unsupported nVec %d", nVec);
    }
    break;

  case 6: // for coarsening coarsened Wilson free field.
    switch (nVec) {
    case 6: // these are probably only for debugging only
      Prolongate<Float,fineSpin,6,coarseSpin,6>(out, in, v, fine_to_coarse, parity);
      break;
    default:
      errorQuda("Unsupported nVec %d", nVec);
    }
    break;

  case 24:
    switch (nVec) {
    case 24: // to keep compilation under control coarse grids have same or more colors
      Prolongate<Float,fineSpin,24,coarseSpin,24>(out, in, v, fine_to_coarse, parity);
      break;
    case 32:
      Prolongate<Float,fineSpin,24,coarseSpin,32>(out, in, v, fine_to_coarse, parity);
      break;
    default:
      errorQuda("Unsupported nVec %d", nVec);
    }
    break;

  case 32:
    switch (nVec) {
    case 32:
      Prolongate<Float,fineSpin,32,coarseSpin,32>(out, in, v, fine_to_coarse, parity);
      break;
    default:
      errorQuda("Unsupported nVec %d", nVec);
    }
    break;

  default:
    errorQuda("Unsupported nColor %d", out.Ncolor());
  }
}
274 
/**
   Turn the run-time fine spin count (out.Nspin()) into a template
   parameter.  Spin 2 is always available; spin 4 requires
   GPU_WILSON_DIRAC and spin 1 requires GPU_STAGGERED_DIRAC.  Any other
   value is reported through errorQuda.
*/
template <typename Float>
void Prolongate(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &v,
                int Nvec, const int *fine_to_coarse, const int * const * spin_map, int parity) {

  switch (out.Nspin()) {
  case 2:
    Prolongate<Float,2>(out, in, v, Nvec, fine_to_coarse, spin_map, parity);
    break;
#ifdef GPU_WILSON_DIRAC
  case 4:
    Prolongate<Float,4>(out, in, v, Nvec, fine_to_coarse, spin_map, parity);
    break;
#endif
#ifdef GPU_STAGGERED_DIRAC
  case 1:
    Prolongate<Float,1>(out, in, v, Nvec, fine_to_coarse, spin_map, parity);
    break;
#endif
  default:
    errorQuda("Unsupported nSpin %d", out.Nspin());
  }
}
293 
294 #endif // GPU_MULTIGRID
295 
/**
   Apply the prolongation operator.

   Checks that out, in and v share one field order, then selects the
   arithmetic precision from checkPrecision(out, in).  Double precision
   is only available when GPU_MULTIGRID_DOUBLE is defined.  When
   GPU_MULTIGRID is not defined the call only reports an error.

   @param[out] out Fine field that receives the result
   @param[in] in Coarse input field
   @param[in] v Field passed down as the rotation V
   @param[in] Nvec Selects the coarse color count of the instantiation
   @param[in] fine_to_coarse Site map from the fine to the coarse grid
   @param[in] spin_map Spin map table, validated against spin_mapper downstream
   @param[in] parity Parity used when out is a single-parity field
*/
void Prolongate(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &v,
                int Nvec, const int *fine_to_coarse, const int * const * spin_map, int parity) {
#ifdef GPU_MULTIGRID
  if (out.FieldOrder() != in.FieldOrder() || out.FieldOrder() != v.FieldOrder())
    errorQuda("Field orders do not match (out=%d, in=%d, v=%d)",
              out.FieldOrder(), in.FieldOrder(), v.FieldOrder());

  QudaPrecision precision = checkPrecision(out, in);

  if (precision == QUDA_DOUBLE_PRECISION) {
#ifdef GPU_MULTIGRID_DOUBLE
    Prolongate<double>(out, in, v, Nvec, fine_to_coarse, spin_map, parity);
#else
    errorQuda("Double precision multigrid has not been enabled");
#endif
  } else if (precision == QUDA_SINGLE_PRECISION) {
    Prolongate<float>(out, in, v, Nvec, fine_to_coarse, spin_map, parity);
  } else {
    errorQuda("Unsupported precision %d", out.Precision());
  }

#else
  errorQuda("Multigrid has not been built");
#endif
}
322 
323 } // end namespace quda
enum QudaPrecision_s QudaPrecision
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:21
#define checkPrecision(...)
#define errorQuda(...)
Definition: util_quda.h:121
cudaStream_t * stream
cudaColorSpinorField * tmp
Definition: covdev_test.cpp:44
cpuColorSpinorField * in
This is just a dummy structure we use for trove to define the required structure size.
__device__ __host__ int VolumeCB() const
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
Definition: tune.cpp:643
#define checkLocation(...)
int V
Definition: test_util.cpp:27
enum QudaFieldLocation_s QudaFieldLocation
cpuColorSpinorField * out
__shared__ float s[]
unsigned long long flops
Definition: blas_quda.cu:22
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
#define checkCudaError()
Definition: util_quda.h:161
static const int volume_n
Definition: tune_key.h:10
QudaTune getTuning()
Query whether autotuning is enabled or not. Default is enabled but can be overridden by setting QUDA_...
Definition: util_quda.cpp:52
QudaPrecision Precision() const
void Prolongate(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &v, int Nvec, const int *fine_to_coarse, const int *const *spin_map, int parity=QUDA_INVALID_PARITY)
Apply the prolongation operator.
Definition: prolongator.cu:296
QudaParity parity
Definition: covdev_test.cpp:54
QudaFieldOrder FieldOrder() const
unsigned long long bytes
Definition: blas_quda.cu:23