1 #include <color_spinor_field.h>
2 #include <color_spinor_field_order.h>
4 #include <multigrid_helper.cuh>
9 using namespace quda::colorspinor;
/**
   Kernel argument struct for the prolongator (coarse -> fine).  Bundles
   the accessors for the fine output spinor, the coarse input spinor and
   the prolongator vectors V, together with the fine-to-coarse site map.
*/
template <typename Float, typename vFloat, int fineSpin, int fineColor, int coarseSpin, int coarseColor, QudaFieldOrder order>
struct ProlongateArg {
  FieldOrderCB<Float,fineSpin,fineColor,1,order> out;                        // fine output spinor
  const FieldOrderCB<Float,coarseSpin,coarseColor,1,order> in;               // coarse input spinor
  const FieldOrderCB<Float,fineSpin,fineColor,coarseColor,order,vFloat> V;   // coarseColor vectors per site, stored as vFloat
  const int *geo_map; // full-lattice fine index -> full-lattice coarse index; need to make a device copy of this
  const spin_mapper<fineSpin,coarseSpin> spin_map;
  const int parity;  // the parity of the output field (if single parity)
  const int nParity; // number of parities of the fine output field (out.SiteSubset())

  ProlongateArg(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &V,
                const int *geo_map, const int parity)
    : out(out), in(in), V(V), geo_map(geo_map), spin_map(), parity(parity), nParity(out.SiteSubset()) { }

  ProlongateArg(const ProlongateArg<Float,vFloat,fineSpin,fineColor,coarseSpin,coarseColor,order> &arg)
    : out(arg.out), in(arg.in), V(arg.V), geo_map(arg.geo_map), spin_map(),
      parity(arg.parity), nParity(arg.nParity) { }
};
/**
   Applies the grid prolongation operator (coarse to fine).  This is the
   first step of applying the prolongator: for the fine site (parity, x_cb)
   it gathers the coarse spinor at coarse site geo_map[x], taking fine spin
   s from coarse spin spin_map(s, parity).  The result is still expressed in
   the coarse-color basis.

   @param[out] out          fineSpin*coarseColor values, indexed s*coarseColor + c
   @param[in]  in           coarse spinor accessor
   @param[in]  parity       parity of the fine site
   @param[in]  x_cb         checkerboard index of the fine site
   @param[in]  geo_map      full-lattice fine index -> full-lattice coarse index
   @param[in]  spin_map     fine spin (and parity) -> coarse spin mapper
   @param[in]  fineVolumeCB checkerboard volume of the fine lattice
*/
template <typename Float, int fineSpin, int coarseColor, class Coarse, typename S>
__device__ __host__ inline void prolongate(complex<Float> out[fineSpin*coarseColor], const Coarse &in,
                                           int parity, int x_cb, const int *geo_map, const S& spin_map, int fineVolumeCB) {
  const int x = parity*fineVolumeCB + x_cb; // full-lattice fine index
  const int x_coarse = geo_map[x];          // full-lattice coarse index
  // full-lattice coarse indices at or beyond in.VolumeCB() belong to coarse parity 1
  const int parity_coarse = (x_coarse >= in.VolumeCB()) ? 1 : 0;
  const int x_coarse_cb = x_coarse - parity_coarse*in.VolumeCB();

#pragma unroll
  for (int s=0; s<fineSpin; s++) {
#pragma unroll
    for (int c=0; c<coarseColor; c++) {
      out[s*coarseColor+c] = in(parity_coarse, x_coarse_cb, spin_map(s,parity), c);
    }
  }
}
/**
   Rotates from the coarse-color basis into the fine-color basis.  This
   is the second step of applying the prolongator:

     out(s, i) = sum_j V(s, i, j) * in[s*coarseColor + j]

   for every fine spin s and for the fine_colors_per_thread fine colors
   i = fine_color_block, fine_color_block + 1, ...

   @param[out] out              fine spinor accessor (each touched element is overwritten)
   @param[in]  in               output of prolongate(), indexed s*coarseColor + j
   @param[in]  V                prolongator vectors, internal dimensions Ns * Nc * Nvec
   @param[in]  parity           parity of the fine site
   @param[in]  nParity          number of parities stored in out
   @param[in]  x_cb             checkerboard index of the fine site
   @param[in]  fine_color_block first fine color handled by this call
*/
template <typename Float, int fineSpin, int fineColor, int coarseColor, int fine_colors_per_thread,
          class FineColor, class Rotator>
__device__ __host__ inline void rotateFineColor(FineColor &out, const complex<Float> in[fineSpin*coarseColor],
                                                const Rotator &V, int parity, int nParity, int x_cb, int fine_color_block) {
  // a single-parity field stores its only parity at index 0
  const int spinor_parity = (nParity == 2) ? parity : 0;
  const int v_parity = (V.Nparity() == 2) ? parity : 0;

  constexpr int color_unroll = 2;
  // the j-loop below advances by color_unroll and reads j+k unconditionally,
  // so an odd coarseColor would index past the end of V and in
  static_assert(coarseColor % color_unroll == 0, "coarseColor must be a multiple of color_unroll");

#pragma unroll
  for (int s=0; s<fineSpin; s++) {
#pragma unroll
    for (int fine_color_local=0; fine_color_local<fine_colors_per_thread; fine_color_local++) {
      const int i = fine_color_block + fine_color_local; // global fine color index

      // independent partial sums give instruction-level parallelism over the unrolled coarse colors
      complex<Float> partial[color_unroll];
#pragma unroll
      for (int k=0; k<color_unroll; k++) partial[k] = 0.0;

#pragma unroll
      for (int j=0; j<coarseColor; j+=color_unroll) {
        // V is a ColorMatrixField with internal dimensions Ns * Nc * Nvec
#pragma unroll
        for (int k=0; k<color_unroll; k++)
          partial[k] += V(v_parity, x_cb, s, i, j+k) * in[s*coarseColor + j + k];
      }

      // reduce the partial sums in registers and store once, rather than
      // zeroing out first and doing a read-modify-write per partial sum
      complex<Float> sum = partial[0];
#pragma unroll
      for (int k=1; k<color_unroll; k++) sum += partial[k];
      out(spinor_parity, x_cb, s, i) = sum;
    }
  }
}
/**
   Host (CPU) implementation of the prolongator, used by ProlongateLaunch
   when the fields live on the CPU.  Visits every output parity, fine
   checkerboard site and fine-color block, applying prolongate() followed
   by rotateFineColor().

   @param[in,out] arg kernel argument struct (arg.out is written)
*/
template <typename Float, int fineSpin, int fineColor, int coarseSpin, int coarseColor, int fine_colors_per_thread, typename Arg>
void Prolongate(Arg &arg) {
  for (int p=0; p<arg.nParity; p++) {
    // Single-parity output uses the parity requested by the caller.  Keep this in
    // a separate variable: overwriting the loop counter made the trip count depend
    // on arg.parity (e.g. arg.parity == -1 never terminated).
    const int parity = (arg.nParity == 2) ? p : arg.parity;

    for (int x_cb=0; x_cb<arg.out.VolumeCB(); x_cb++) {
      complex<Float> tmp[fineSpin*coarseColor];
      prolongate<Float,fineSpin,coarseColor>(tmp, arg.in, parity, x_cb, arg.geo_map, arg.spin_map, arg.out.VolumeCB());
      for (int fine_color_block=0; fine_color_block<fineColor; fine_color_block+=fine_colors_per_thread) {
        rotateFineColor<Float,fineSpin,fineColor,coarseColor,fine_colors_per_thread>
          (arg.out, tmp, arg.V, parity, arg.nParity, x_cb, fine_color_block);
      }
    }
  }
}
/**
   GPU kernel applying the prolongator.  Thread layout:
     x : fine checkerboard site x_cb (threads with x_cb >= out.VolumeCB() exit)
     y : fine parity when the output holds both parities; otherwise arg.parity is used
     z : fine-color block, each thread handling fine_colors_per_thread colors
         (threads whose block starts at or past fineColor exit)

   @param[in] arg kernel argument struct, passed by value
*/
template <typename Float, int fineSpin, int fineColor, int coarseSpin, int coarseColor, int fine_colors_per_thread, typename Arg>
__global__ void ProlongateKernel(Arg arg) {
  const int x_cb = blockIdx.x*blockDim.x + threadIdx.x;
  const int parity = arg.nParity == 2 ? blockDim.y*blockIdx.y + threadIdx.y : arg.parity;
  if (x_cb >= arg.out.VolumeCB()) return;

  const int fine_color_block = (blockDim.z*blockIdx.z + threadIdx.z) * fine_colors_per_thread;
  if (fine_color_block >= fineColor) return;

  complex<Float> tmp[fineSpin*coarseColor];
  prolongate<Float,fineSpin,coarseColor>(tmp, arg.in, parity, x_cb, arg.geo_map, arg.spin_map, arg.out.VolumeCB());
  rotateFineColor<Float,fineSpin,fineColor,coarseColor,fine_colors_per_thread>
    (arg.out, tmp, arg.V, parity, arg.nParity, x_cb, fine_color_block);
}
/**
   Tunable launcher for the prolongator.  The fine checkerboard volume is the
   x dimension, the fine parity (out.SiteSubset()) the y dimension and the
   fine-color blocks (fineColor/fine_colors_per_thread) the z dimension.
   CPU fields (space-spin-color order) run the host Prolongate(); GPU fields
   (FLOAT2 order) launch ProlongateKernel with the tuned launch parameters.
*/
template <typename Float, typename vFloat, int fineSpin, int fineColor, int coarseSpin, int coarseColor, int fine_colors_per_thread>
class ProlongateLaunch : public TunableVectorYZ {

protected:
  ColorSpinorField &out;        // fine output field
  const ColorSpinorField &in;   // coarse input field
  const ColorSpinorField &V;    // prolongator vectors
  const int *fine_to_coarse;    // fine-to-coarse site map
  int parity;                   // output parity (used when out is single parity)
  QudaFieldLocation location;
  char vol[TuneKey::volume_n];

  bool tuneGridDim() const { return false; } // Don't tune the grid dimensions.
  unsigned int minThreads() const { return out.VolumeCB(); } // fine parity is the block y dimension

public:
  ProlongateLaunch(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &V,
                   const int *fine_to_coarse, int parity)
    : TunableVectorYZ(out.SiteSubset(), fineColor/fine_colors_per_thread), out(out), in(in), V(V),
      fine_to_coarse(fine_to_coarse), parity(parity), location(checkLocation(out, in, V))
  {
    // tune key strings are "<fine>,<coarse>" so the two volumes/aux strings stay separable
    strcpy(vol, out.VolString());
    strcat(vol, ",");
    strcat(vol, in.VolString());

    strcpy(aux, out.AuxString());
    strcat(aux, ",");
    strcat(aux, in.AuxString());
  }

  void apply(const qudaStream_t &stream) {
    if (location == QUDA_CPU_FIELD_LOCATION) {
      if (out.FieldOrder() == QUDA_SPACE_SPIN_COLOR_FIELD_ORDER) {
        ProlongateArg<Float,vFloat,fineSpin,fineColor,coarseSpin,coarseColor,QUDA_SPACE_SPIN_COLOR_FIELD_ORDER>
          arg(out, in, V, fine_to_coarse, parity);
        Prolongate<Float,fineSpin,fineColor,coarseSpin,coarseColor,fine_colors_per_thread>(arg);
      } else {
        errorQuda("Unsupported field order %d", out.FieldOrder());
      }
    } else {
      if (out.FieldOrder() == QUDA_FLOAT2_FIELD_ORDER) {
        TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
        ProlongateArg<Float,vFloat,fineSpin,fineColor,coarseSpin,coarseColor,QUDA_FLOAT2_FIELD_ORDER>
          arg(out, in, V, fine_to_coarse, parity);
        qudaLaunchKernel(ProlongateKernel<Float,fineSpin,fineColor,coarseSpin,coarseColor,fine_colors_per_thread,decltype(arg)>,
                         tp, stream, arg);
      } else {
        errorQuda("Unsupported field order %d", out.FieldOrder());
      }
    }
  }

  TuneKey tuneKey() const { return TuneKey(vol, typeid(*this).name(), aux); }

  // one complex multiply-add (8 real flops) per fine spin x fine color x coarse color per fine site
  long long flops() const { return 8 * fineSpin * fineColor * coarseColor * out.SiteSubset()*(long long)out.VolumeCB(); }

  // in + out + V (halved when V and out have different site subsets) + one geo-map int per fine site
  long long bytes() const {
    size_t v_bytes = V.Bytes() / (V.SiteSubset() == out.SiteSubset() ? 1 : 2);
    return in.Bytes() + out.Bytes() + v_bytes + out.SiteSubset()*out.VolumeCB()*sizeof(int);
  }

};
/**
   Dispatches on the storage precision of the prolongator vectors v:
   half-precision v (only when QUDA_PRECISION enables half), or v stored
   at the same precision as the coarse input.  Any other v precision is
   an error.

   @param[out] out            fine output field
   @param[in]  in             coarse input field
   @param[in]  v              prolongator vectors
   @param[in]  fine_to_coarse fine-to-coarse site map
   @param[in]  parity         output parity (used when out is single parity)
*/
template <typename Float, int fineSpin, int fineColor, int coarseSpin, int coarseColor>
void Prolongate(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &v,
                const int *fine_to_coarse, int parity) {

  // for all grids use 1 color per thread
  constexpr int fine_colors_per_thread = 1;

  if (v.Precision() == QUDA_HALF_PRECISION) {
#if QUDA_PRECISION & 2
    ProlongateLaunch<Float, short, fineSpin, fineColor, coarseSpin, coarseColor, fine_colors_per_thread>
      prolongator(out, in, v, fine_to_coarse, parity);
    prolongator.apply(0);
#else
    errorQuda("QUDA_PRECISION=%d does not enable half precision", QUDA_PRECISION);
#endif
  } else if (v.Precision() == in.Precision()) {
    ProlongateLaunch<Float, Float, fineSpin, fineColor, coarseSpin, coarseColor, fine_colors_per_thread>
      prolongator(out, in, v, fine_to_coarse, parity);
    prolongator.apply(0);
  } else {
    errorQuda("Unsupported V precision %d", v.Precision());
  }
}
// Dispatches the prolongator on the fine color count (out.Ncolor()) and on nVec,
// which becomes the coarseColor template argument of Prolongate<Float,fineSpin,fineColor,coarseSpin,nVec>.
// Only a coarse spin count of 2 is accepted, and the runtime spin_map table must agree
// with the compile-time spin_mapper<fineSpin,2>; otherwise errorQuda is called.
// Unsupported (nColor, nVec) combinations also end in errorQuda.
// NOTE(review): the embedded numbering in this listing skips lines (e.g. 232-234, 236,
// 239-240, 278, 287, 296 and several closing-brace lines), so some branch headers and
// preprocessor guards are not visible here -- for instance the calls at 235, 279, 288
// and 297 have no visible `if`. Confirm against the complete file before editing.
214 template <typename Float, int fineSpin>
215 void Prolongate(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &v,
216 int nVec, const int *fine_to_coarse, const int * const * spin_map, int parity) {
// the coarse grid is required to have exactly 2 spins
218 if (in.Nspin() != 2) errorQuda("Coarse spin %d is not supported", in.Nspin());
219 const int coarseSpin = 2;
221 // first check that the spin_map matches the spin_mapper
222 spin_mapper<fineSpin,coarseSpin> mapper;
223 for (int s=0; s<fineSpin; s++)
224 for (int p=0; p<2; p++)
225 if (mapper(s,p) != spin_map[s][p]) errorQuda("Spin map does not match spin_mapper");
// fine color 3: nVec in {6, 24 (guard not visible), 32, 64, 96}
227 if (out.Ncolor() == 3) {
228 const int fineColor = 3;
230 if (nVec == 6) { // Free field Wilson
231 Prolongate<Float,fineSpin,fineColor,coarseSpin,6>(out, in, v, fine_to_coarse, parity);
235 Prolongate<Float,fineSpin,fineColor,coarseSpin,24>(out, in, v, fine_to_coarse, parity);
237 } else if (nVec == 32) {
238 Prolongate<Float,fineSpin,fineColor,coarseSpin,32>(out, in, v, fine_to_coarse, parity);
241 } else if (nVec == 64) {
242 Prolongate<Float,fineSpin,fineColor,coarseSpin,64>(out, in, v, fine_to_coarse, parity);
243 } else if (nVec == 96) {
244 Prolongate<Float,fineSpin,fineColor,coarseSpin,96>(out, in, v, fine_to_coarse, parity);
247 errorQuda("Unsupported nVec %d", nVec);
// fine color 6: nVec == 6 only
250 } else if (out.Ncolor() == 6) { // for coarsening coarsened Wilson free field.
251 const int fineColor = 6;
252 if (nVec == 6) { // these are probably for debugging only
253 Prolongate<Float,fineSpin,fineColor,coarseSpin,6>(out, in, v, fine_to_coarse, parity);
255 errorQuda("Unsupported nVec %d", nVec);
// fine color 24: nVec in {24, 32, 64, 96}
258 } else if (out.Ncolor() == 24) {
259 const int fineColor = 24;
260 if (nVec == 24) { // to keep compilation under control coarse grids have same or more colors
261 Prolongate<Float,fineSpin,fineColor,coarseSpin,24>(out, in, v, fine_to_coarse, parity);
263 } else if (nVec == 32) {
264 Prolongate<Float,fineSpin,fineColor,coarseSpin,32>(out, in, v, fine_to_coarse, parity);
267 } else if (nVec == 64) {
268 Prolongate<Float,fineSpin,fineColor,coarseSpin,64>(out, in, v, fine_to_coarse, parity);
269 } else if (nVec == 96) {
270 Prolongate<Float,fineSpin,fineColor,coarseSpin,96>(out, in, v, fine_to_coarse, parity);
273 errorQuda("Unsupported nVec %d", nVec);
// fine color 32: only the nVec == 32 instantiation is visible (its `if` header is not)
276 } else if (out.Ncolor() == 32) {
277 const int fineColor = 32;
279 Prolongate<Float,fineSpin,fineColor,coarseSpin,32>(out, in, v, fine_to_coarse, parity);
281 errorQuda("Unsupported nVec %d", nVec);
// fine color 64: nVec 64 (header not visible) or 96
285 } else if (out.Ncolor() == 64) {
286 const int fineColor = 64;
288 Prolongate<Float,fineSpin,fineColor,coarseSpin,64>(out, in, v, fine_to_coarse, parity);
289 } else if (nVec == 96) {
290 Prolongate<Float,fineSpin,fineColor,coarseSpin,96>(out, in, v, fine_to_coarse, parity);
292 errorQuda("Unsupported nVec %d", nVec);
// fine color 96: only the nVec == 96 instantiation is visible (its `if` header is not)
294 } else if (out.Ncolor() == 96) {
295 const int fineColor = 96;
297 Prolongate<Float,fineSpin,fineColor,coarseSpin,96>(out, in, v, fine_to_coarse, parity);
299 errorQuda("Unsupported nVec %d", nVec);
// any other fine color count is rejected
303 errorQuda("Unsupported nColor %d", out.Ncolor());
// Dispatches the prolongator on the number of fine spins (out.Nspin() of 2, 4 or 1),
// forwarding to Prolongate<Float,fineSpin>; any other spin count ends in errorQuda.
// NOTE(review): the embedded numbering skips 313, 316-317 and 320-321, which presumably
// held preprocessor guards and the final `else`; confirm against the complete file.
307 template <typename Float>
308 void Prolongate(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &v,
309 int Nvec, const int *fine_to_coarse, const int * const * spin_map, int parity) {
311 if (out.Nspin() == 2) {
312 Prolongate<Float,2>(out, in, v, Nvec, fine_to_coarse, spin_map, parity);
314 } else if (out.Nspin() == 4) {
315 Prolongate<Float,4>(out, in, v, Nvec, fine_to_coarse, spin_map, parity);
318 } else if (out.Nspin() == 1) {
319 Prolongate<Float,1>(out, in, v, Nvec, fine_to_coarse, spin_map, parity);
322 errorQuda("Unsupported nSpin %d", out.Nspin());
326 #endif // GPU_MULTIGRID
/**
   Apply the prolongation operator: map the coarse field `in` onto the fine
   field `out` using the prolongator vectors `v`.

   @param[out] out            fine output field
   @param[in]  in             coarse input field (must have 2 spins)
   @param[in]  v              prolongator vectors
   @param[in]  Nvec           number of vectors, i.e. the coarse color count
   @param[in]  fine_to_coarse full-lattice fine index -> full-lattice coarse index
   @param[in]  spin_map       fine spin/parity -> coarse spin table; must match spin_mapper
   @param[in]  parity         output parity (used when out is single parity)

   Errors out if the field orders differ, the precision is unsupported, or
   multigrid support was not compiled in.
*/
void Prolongate(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &v,
                int Nvec, const int *fine_to_coarse, const int * const * spin_map, int parity) {
#ifdef GPU_MULTIGRID
  if (out.FieldOrder() != in.FieldOrder() || out.FieldOrder() != v.FieldOrder())
    errorQuda("Field orders do not match (out=%d, in=%d, v=%d)",
              out.FieldOrder(), in.FieldOrder(), v.FieldOrder());

  // common precision of out and in, as reported by checkPrecision
  const QudaPrecision precision = checkPrecision(out, in);

  if (precision == QUDA_DOUBLE_PRECISION) {
#ifdef GPU_MULTIGRID_DOUBLE
    Prolongate<double>(out, in, v, Nvec, fine_to_coarse, spin_map, parity);
#else
    errorQuda("Double precision multigrid has not been enabled");
#endif
  } else if (precision == QUDA_SINGLE_PRECISION) {
    Prolongate<float>(out, in, v, Nvec, fine_to_coarse, spin_map, parity);
  } else {
    // report the precision that was actually dispatched on
    errorQuda("Unsupported precision %d", precision);
  }
#else
  errorQuda("Multigrid has not been built");
#endif
}
353 } // end namespace quda