1 #include <color_spinor_field.h>
3 #include <launch_kernel.cuh>
5 #include <jitify_helper.cuh>
6 #include <kernels/restrictor.cuh>
/**
   @brief Tunable launcher for the multigrid restriction of the fine
   field `in` onto the coarse field `out` using the vectors in `v`.

   - CPU fields are processed by the host-side
     Restrict<...>(arg) routine. Only QUDA_SPACE_SPIN_COLOR_FIELD_ORDER
     is accepted on this path.
   - Device fields are processed by RestrictKernel. Only
     QUDA_FLOAT2_FIELD_ORDER is accepted on this path. The kernel is
     launched through jitify when JITIFY is defined, and through
     LAUNCH_KERNEL_MG_BLOCK_SIZE otherwise.

   Launch geometry:
   - The thread block x dimension is fixed to block_size, the number
     of fine checkerboard sites per coarse site.
   - The y dimension spans the fine site subset.
   - The coarse colors are split between threads
     (coarse_colors_per_thread) and gridDim.z.

   Only the shared-memory size and the swizzle factor (aux.x) are tuned.

   @tparam Float Compute precision (the common precision of in and out)
   @tparam vFloat Storage type of v (short for half-precision v, otherwise Float)
   @tparam fineSpin Number of spins of the fine field
   @tparam fineColor Number of colors of the fine field
   @tparam coarseSpin Number of spins of the coarse field
   @tparam coarseColor Number of colors of the coarse field
   @tparam coarse_colors_per_thread Number of coarse colors computed per thread
*/
template <typename Float, typename vFloat, int fineSpin, int fineColor, int coarseSpin, int coarseColor,
          int coarse_colors_per_thread>
class RestrictLaunch : public Tunable {

protected:
  ColorSpinorField &out;            // coarse output field
  const ColorSpinorField &in;       // fine input field
  const ColorSpinorField &v;        // vectors defining the restriction
  const int *fine_to_coarse;        // fine-site -> coarse-site index map
  const int *coarse_to_fine;        // coarse-site -> fine-site index map
  const int parity;                 // forwarded unchanged to RestrictArg
  const QudaFieldLocation location; // common location of out, in and v
  const int block_size;             // fine checkerboard sites per coarse site
  char vol[TuneKey::volume_n];      // volume part of the tune key

  unsigned int sharedBytesPerThread() const { return 0; }
  unsigned int sharedBytesPerBlock(const TuneParam &param) const { return 0; }
  bool tuneGridDim() const { return false; } // Don't tune the grid dimensions.
  bool tuneAuxDim() const { return true; }   // Do tune the aux dimensions (swizzle factor).
  unsigned int minThreads() const { return in.VolumeCB(); } // fine parity is the block y dimension

public:
  RestrictLaunch(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &v,
                 const int *fine_to_coarse, const int *coarse_to_fine, int parity) :
    out(out),
    in(in),
    v(v),
    fine_to_coarse(fine_to_coarse),
    coarse_to_fine(coarse_to_fine),
    parity(parity),
    location(checkLocation(out, in, v)),
    // block size is checkerboard fine length / full coarse length
    block_size(in.VolumeCB() / (2 * out.VolumeCB()))
  {
    if (v.Location() == QUDA_CUDA_FIELD_LOCATION) {
#ifdef JITIFY
      create_jitify_program("kernels/restrictor.cuh");
#endif
    }

    // Separate the two field descriptions so that different (out, in)
    // pairs cannot concatenate into the same tune key.
    strcpy(aux, compile_type_str(in));
    strcat(aux, out.AuxString());
    strcat(aux, ",");
    strcat(aux, in.AuxString());

    strcpy(vol, out.VolString());
    strcat(vol, ",");
    strcat(vol, in.VolString());
  }

  /**
     @brief Run the restriction on the host or the device, depending on
     the common field location.
     @param stream Stream the device kernel is launched on. It is unused
     on the CPU path.
  */
  void apply(const qudaStream_t &stream)
  {
    if (location == QUDA_CPU_FIELD_LOCATION) {
      if (out.FieldOrder() == QUDA_SPACE_SPIN_COLOR_FIELD_ORDER) {
        RestrictArg<Float, vFloat, fineSpin, fineColor, coarseSpin, coarseColor, QUDA_SPACE_SPIN_COLOR_FIELD_ORDER>
          arg(out, in, v, fine_to_coarse, coarse_to_fine, parity);
        Restrict<Float, fineSpin, fineColor, coarseSpin, coarseColor, coarse_colors_per_thread>(arg);
      } else {
        errorQuda("Unsupported field order %d", out.FieldOrder());
      }
    } else {
      TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());

      if (out.FieldOrder() == QUDA_FLOAT2_FIELD_ORDER) {
        typedef RestrictArg<Float, vFloat, fineSpin, fineColor, coarseSpin, coarseColor, QUDA_FLOAT2_FIELD_ORDER> Arg;
        Arg arg(out, in, v, fine_to_coarse, coarse_to_fine, parity);
        arg.swizzle = tp.aux.x;

        // Exactly one of the two launch paths is compiled in, so the
        // kernel is launched once.
#ifdef JITIFY
        using namespace jitify::reflection;
        jitify_error = program->kernel("quda::RestrictKernel")
                         .instantiate((int)tp.block.x, Type<Float>(), fineSpin, fineColor, coarseSpin, coarseColor,
                                      coarse_colors_per_thread, Type<Arg>())
                         .configure(tp.grid, tp.block, tp.shared_bytes, stream)
                         .launch(arg);
#else
        LAUNCH_KERNEL_MG_BLOCK_SIZE(RestrictKernel, tp, stream, arg, Float, fineSpin, fineColor, coarseSpin,
                                    coarseColor, coarse_colors_per_thread, Arg);
#endif
      } else {
        errorQuda("Unsupported field order %d", out.FieldOrder());
      }
    }
  }

  // This block tuning tunes for the optimal amount of color
  // splitting between blockDim.z and gridDim.z. However, enabling
  // blockDim.z > 1 gives incorrect results due to cub reductions
  // being unable to do independent sliced reductions along
  // blockDim.z. So for now we only split between colors per thread
  // and gridDim.z. advanceTuneParam below does not call this method,
  // so it is presumably inactive while block.z tuning is disabled.
  bool advanceBlockDim(TuneParam &param) const
  {
    constexpr unsigned int n_color_blocks = coarseColor / coarse_colors_per_thread;

    // Advance to the next block.z that evenly divides the color blocks.
    while (param.block.z <= n_color_blocks) {
      param.block.z++;
      if (n_color_blocks % param.block.z == 0) {
        param.grid.z = n_color_blocks / param.block.z;
        break;
      }
    }

    if (param.block.z <= n_color_blocks) {
      return true; // we can advance spin/block-color since this is valid
    } else {       // we have run off the end so let's reset
      param.block.z = 1;
      param.grid.z = n_color_blocks;
      return false;
    }
  }

  int tuningIter() const { return 3; }

  /**
     @brief Step the swizzle factor aux.x through 1 .. 2 * (number of SMs).
     The swizzle factor is stepped only when SWIZZLE is defined;
     otherwise aux.x keeps its default value of 1.
     @return true while a new value was produced, false after resetting to 1
  */
  bool advanceAux(TuneParam &param) const
  {
#ifdef SWIZZLE
    if (param.aux.x < 2 * deviceProp.multiProcessorCount) {
      param.aux.x++;
      return true;
    } else {
      param.aux.x = 1;
      return false;
    }
#else
    return false;
#endif
  }

  // only tune shared memory per thread (disable tuning for block.z for now)
  bool advanceTuneParam(TuneParam &param) const { return advanceSharedBytes(param) || advanceAux(param); }

  TuneKey tuneKey() const { return TuneKey(vol, typeid(*this).name(), aux); }

  void initTuneParam(TuneParam &param) const { defaultTuneParam(param); }

  /** sets default values for when tuning is disabled */
  void defaultTuneParam(TuneParam &param) const
  {
    // x: fine sites of one aggregate, y: fine site subset, z: single color block per thread block
    param.block = dim3(block_size, in.SiteSubset(), 1);
    // the coarse colors not handled within a thread are spread over gridDim.z
    param.grid = dim3((minThreads() + param.block.x - 1) / param.block.x, 1, coarseColor / coarse_colors_per_thread);
    param.shared_bytes = 0;
    param.aux.x = 1; // swizzle factor
  }

  // 8 flops for every (fine spin, fine color, coarse color) term at each fine site
  long long flops() const
  {
    return 8 * fineSpin * fineColor * coarseColor * in.SiteSubset() * (long long)in.VolumeCB();
  }

  long long bytes() const
  {
    // only the parities of v matching `in` are read
    size_t v_bytes = v.Bytes() / (v.SiteSubset() == in.SiteSubset() ? 1 : 2);
    // plus one int per fine site (presumably the fine_to_coarse lookup)
    return in.Bytes() + out.Bytes() + v_bytes + in.SiteSubset() * in.VolumeCB() * sizeof(int);
  }
};
/**
   @brief Build and run the restriction launcher for fixed fine and
   coarse spin/color counts.

   Supported storage precisions for v:
   - half precision, stored as short (only when QUDA_PRECISION enables it);
   - the same precision as the fine field `in`, stored as Float.
   Any other v precision is an error.

   @param out Coarse output field
   @param in Fine input field
   @param v Vectors defining the restriction
   @param fine_to_coarse Fine-site -> coarse-site index map
   @param coarse_to_fine Coarse-site -> fine-site index map
   @param parity Forwarded unchanged to the launcher
*/
template <typename Float, int fineSpin, int fineColor, int coarseSpin, int coarseColor>
void Restrict(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &v,
              const int *fine_to_coarse, const int *coarse_to_fine, int parity)
{
  // for fine grids (Nc=3) have more parallelism so can use more coarse strategy
  constexpr int coarse_colors_per_thread = fineColor != 3 ? 2 : coarseColor >= 4 && coarseColor % 4 == 0 ? 4 : 2;
  // coarseColor >= 8 && coarseColor % 8 == 0 ? 8 : coarseColor >= 4 && coarseColor % 4 == 0 ? 4 : 2;

  // The launcher spreads coarseColor / coarse_colors_per_thread color blocks over gridDim.z.
  static_assert(coarseColor % coarse_colors_per_thread == 0,
                "coarseColor must be a multiple of coarse_colors_per_thread");

  if (v.Precision() == QUDA_HALF_PRECISION) {
#if QUDA_PRECISION & 2
    RestrictLaunch<Float, short, fineSpin, fineColor, coarseSpin, coarseColor, coarse_colors_per_thread>
      restrictor(out, in, v, fine_to_coarse, coarse_to_fine, parity);
    restrictor.apply(0); // null stream handle: the default stream
#else
    errorQuda("QUDA_PRECISION=%d does not enable half precision", QUDA_PRECISION);
#endif
  } else if (v.Precision() == in.Precision()) {
    RestrictLaunch<Float, Float, fineSpin, fineColor, coarseSpin, coarseColor, coarse_colors_per_thread>
      restrictor(out, in, v, fine_to_coarse, coarse_to_fine, parity);
    restrictor.apply(0); // null stream handle: the default stream
  } else {
    errorQuda("Unsupported V precision %d", v.Precision());
  }
}
/**
   @brief Abort unless the caller-supplied spin map agrees with
   spin_mapper<fineSpin, coarseSpin> for every fine spin and both
   chiralities/parities (second index 0..1).
*/
template <int fineSpin, int coarseSpin> static void checkRestrictSpinMap(const int *const *spin_map)
{
  spin_mapper<fineSpin, coarseSpin> mapper;
  for (int s = 0; s < fineSpin; s++)
    for (int p = 0; p < 2; p++)
      if (mapper(s, p) != spin_map[s][p]) errorQuda("Spin map does not match spin_mapper");
}

/**
   @brief Dispatch the restriction from runtime parameters to a
   compile-time instantiation. The dispatch is on the fine color count,
   the fine spin count and nVec. nVec becomes the coarse color count,
   and the coarse spin count must be 2.

   Supported combinations (fine nColor, fine nSpin -> nVec):
   - 3, 4 -> 6, 24, 32
   - 3, 1 -> 24, 64, 96
   - 6, 2 -> 6
   - 24, 2 -> 24, 32, 64, 96
   - 32, 2 -> 32
   - 64, 2 -> 64, 96
   - 96, 2 -> 96
   Any other combination is an error.
*/
template <typename Float>
void Restrict(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &v, int nVec,
              const int *fine_to_coarse, const int *coarse_to_fine, const int *const *spin_map, int parity)
{
  if (out.Nspin() != 2) errorQuda("Unsupported nSpin %d", out.Nspin());
  constexpr int coarseSpin = 2;

  // Template over fine color
  if (in.Ncolor() == 3) { // standard QCD
    constexpr int fineColor = 3;

    if (in.Nspin() == 4) {
      constexpr int fineSpin = 4;
      checkRestrictSpinMap<fineSpin, coarseSpin>(spin_map);

      if (nVec == 6) { // free field Wilson
        Restrict<Float, fineSpin, fineColor, coarseSpin, 6>(out, in, v, fine_to_coarse, coarse_to_fine, parity);
      } else if (nVec == 24) {
        Restrict<Float, fineSpin, fineColor, coarseSpin, 24>(out, in, v, fine_to_coarse, coarse_to_fine, parity);
      } else if (nVec == 32) {
        Restrict<Float, fineSpin, fineColor, coarseSpin, 32>(out, in, v, fine_to_coarse, coarse_to_fine, parity);
      } else {
        errorQuda("Unsupported nVec %d", nVec);
      }
    } else if (in.Nspin() == 1) {
      constexpr int fineSpin = 1;
      checkRestrictSpinMap<fineSpin, coarseSpin>(spin_map);

      if (nVec == 24) { // free field staggered
        Restrict<Float, fineSpin, fineColor, coarseSpin, 24>(out, in, v, fine_to_coarse, coarse_to_fine, parity);
      } else if (nVec == 64) {
        Restrict<Float, fineSpin, fineColor, coarseSpin, 64>(out, in, v, fine_to_coarse, coarse_to_fine, parity);
      } else if (nVec == 96) {
        Restrict<Float, fineSpin, fineColor, coarseSpin, 96>(out, in, v, fine_to_coarse, coarse_to_fine, parity);
      } else {
        errorQuda("Unsupported nVec %d", nVec);
      }
    } else {
      errorQuda("Unexpected nSpin = %d", in.Nspin());
    }

  } else {
    // Fine grids with nColor != 3 must have two spins.
    if (in.Nspin() != 2) errorQuda("Unexpected nSpin = %d", in.Nspin());
    constexpr int fineSpin = 2;
    checkRestrictSpinMap<fineSpin, coarseSpin>(spin_map);

    if (in.Ncolor() == 6) { // Coarsen coarsened Wilson free field
      constexpr int fineColor = 6;
      if (nVec == 6) {
        Restrict<Float, fineSpin, fineColor, coarseSpin, 6>(out, in, v, fine_to_coarse, coarse_to_fine, parity);
      } else {
        errorQuda("Unsupported nVec %d", nVec);
      }
    } else if (in.Ncolor() == 24) { // to keep compilation under control coarse grids have same or more colors
      constexpr int fineColor = 24;
      if (nVec == 24) {
        Restrict<Float, fineSpin, fineColor, coarseSpin, 24>(out, in, v, fine_to_coarse, coarse_to_fine, parity);
      } else if (nVec == 32) {
        Restrict<Float, fineSpin, fineColor, coarseSpin, 32>(out, in, v, fine_to_coarse, coarse_to_fine, parity);
      } else if (nVec == 64) {
        Restrict<Float, fineSpin, fineColor, coarseSpin, 64>(out, in, v, fine_to_coarse, coarse_to_fine, parity);
      } else if (nVec == 96) {
        Restrict<Float, fineSpin, fineColor, coarseSpin, 96>(out, in, v, fine_to_coarse, coarse_to_fine, parity);
      } else {
        errorQuda("Unsupported nVec %d", nVec);
      }
    } else if (in.Ncolor() == 32) {
      constexpr int fineColor = 32;
      if (nVec == 32) {
        Restrict<Float, fineSpin, fineColor, coarseSpin, 32>(out, in, v, fine_to_coarse, coarse_to_fine, parity);
      } else {
        errorQuda("Unsupported nVec %d", nVec);
      }
    } else if (in.Ncolor() == 64) {
      constexpr int fineColor = 64;
      if (nVec == 64) {
        Restrict<Float, fineSpin, fineColor, coarseSpin, 64>(out, in, v, fine_to_coarse, coarse_to_fine, parity);
      } else if (nVec == 96) {
        Restrict<Float, fineSpin, fineColor, coarseSpin, 96>(out, in, v, fine_to_coarse, coarse_to_fine, parity);
      } else {
        errorQuda("Unsupported nVec %d", nVec);
      }
    } else if (in.Ncolor() == 96) {
      constexpr int fineColor = 96;
      if (nVec == 96) {
        Restrict<Float, fineSpin, fineColor, coarseSpin, 96>(out, in, v, fine_to_coarse, coarse_to_fine, parity);
      } else {
        errorQuda("Unsupported nVec %d", nVec);
      }
    } else {
      errorQuda("Unsupported nColor %d", in.Ncolor());
    }
  }
}
/**
   @brief Public entry point for the multigrid restriction of the fine
   field `in` onto the coarse field `out`.

   The call proceeds in four steps:
   1. Check that out, in and v share one field order.
   2. Check that out and in share one precision (checkPrecision).
   3. Dispatch on that precision. Double precision is available only
      with GPU_MULTIGRID_DOUBLE; single precision is always available;
      any other precision is an error.
   4. Abort in builds without GPU_MULTIGRID.

   @param out Coarse output field
   @param in Fine input field
   @param v Vectors defining the restriction
   @param Nvec Number of vectors in v (the coarse color count)
   @param fine_to_coarse Fine-site -> coarse-site index map
   @param coarse_to_fine Coarse-site -> fine-site index map
   @param spin_map Fine-to-coarse spin map, checked against spin_mapper
   @param parity Forwarded unchanged to the launcher
*/
void Restrict(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &v, int Nvec,
              const int *fine_to_coarse, const int *coarse_to_fine, const int *const *spin_map, int parity)
{
#ifdef GPU_MULTIGRID
  if (out.FieldOrder() != in.FieldOrder() || out.FieldOrder() != v.FieldOrder())
    errorQuda("Field orders do not match (out=%d, in=%d, v=%d)", out.FieldOrder(), in.FieldOrder(), v.FieldOrder());

  QudaPrecision precision = checkPrecision(out, in);

  if (precision == QUDA_DOUBLE_PRECISION) {
#ifdef GPU_MULTIGRID_DOUBLE
    Restrict<double>(out, in, v, Nvec, fine_to_coarse, coarse_to_fine, spin_map, parity);
#else
    errorQuda("Double precision multigrid has not been enabled");
#endif
  } else if (precision == QUDA_SINGLE_PRECISION) {
    Restrict<float>(out, in, v, Nvec, fine_to_coarse, coarse_to_fine, spin_map, parity);
  } else {
    errorQuda("Unsupported precision %d", out.Precision());
  }
#else
  errorQuda("Multigrid has not been built");
#endif
}