/**
   @brief Argument bundle for the prolongation kernels.

   Holds the fine output field (out), the coarse input field (in), the
   rotation field (V), the fine-to-coarse site map (geo_map), a compile-time
   spin map and the parity information that prolongate() and
   rotateFineColor() consume.

   NOTE(review): the member declarations of out, in, V, geo_map, parity and
   nParity are not visible in this listing; their types are only implied by
   the constructor initializers below.
*/
15 template <
typename Float,
typename vFloat,
int fineSpin,
int fineColor,
int coarseSpin,
int coarseColor, QudaFieldOrder order>
16 struct ProlongateArg {
// Compile-time fine-spin -> coarse-spin mapping. Both constructors
// default-construct it (assumes spin_mapper carries no per-instance state - confirm).
21 const spin_mapper<fineSpin,coarseSpin> spin_map;
// nParity is taken from out.SiteSubset(); the kernels treat nParity == 2 as
// "out holds both parities" and otherwise work on the single parity `parity`.
25 ProlongateArg(ColorSpinorField &out,
const ColorSpinorField &in,
const ColorSpinorField &V,
26 const int *geo_map,
const int parity)
27 : out(out), in(in), V(V), geo_map(geo_map), spin_map(), parity(parity), nParity(out.SiteSubset()) { }
// Copy constructor: copies every member except spin_map, which is rebuilt
// by default construction.
29 ProlongateArg(
const ProlongateArg<Float,vFloat,fineSpin,fineColor,coarseSpin,coarseColor,order> &
arg)
30 : out(arg.out), in(arg.in), V(arg.V), geo_map(arg.geo_map), spin_map(),
31 parity(arg.parity), nParity(arg.nParity) { }
/**
   @brief Gather the coarse-grid values feeding one fine site.

   For fine site (parity, x_cb), look up the coarse site through geo_map.
   Split that full-lattice coarse index into (parity_coarse, x_coarse_cb).
   Then, for every fine spin s and coarse color c, store
   in(parity_coarse, x_coarse_cb, spin_map(s, parity), c) into
   out[s*coarseColor + c]. No multiplication by V happens here; the callers
   pass the result to rotateFineColor().

   @param out           Per-site scratch of fineSpin*coarseColor values, spin-major
   @param in            Coarse field accessor, called as in(parity, x_cb, spin, color)
   @param parity        Parity of the fine site
   @param x_cb          Checkerboard index of the fine site
   @param geo_map       Fine-to-coarse map, indexed by parity*fineVolumeCB + x_cb
   @param spin_map      Callable as spin_map(fine_spin, parity); presumably yields the coarse spin index
   @param fineVolumeCB  Checkerboard volume of the fine field
*/
37 template <
typename Float,
int fineSpin,
int coarseColor,
class Coarse,
typename S>
38 __device__ __host__
inline void prolongate(complex<Float>
out[fineSpin*coarseColor],
const Coarse &
in,
39 int parity,
int x_cb,
const int *geo_map,
const S& spin_map,
int fineVolumeCB) {
// Full-lattice fine index, then the corresponding full-lattice coarse index.
40 int x = parity*fineVolumeCB + x_cb;
41 int x_coarse = geo_map[x];
// Coarse indices at or beyond in.VolumeCB() are treated as parity 1;
// subtract that offset to get the coarse checkerboard index.
42 int parity_coarse = (x_coarse >= in.VolumeCB()) ? 1 : 0;
43 int x_coarse_cb = x_coarse - parity_coarse*in.
VolumeCB();
46 for (
int s=0;
s<fineSpin;
s++) {
48 for (
int c=0; c<coarseColor; c++) {
49 out[
s*coarseColor+c] =
in(parity_coarse, x_coarse_cb, spin_map(
s,parity), c);
/**
   @brief Apply V to the gathered coarse values for one fine site.

   For each fine spin s and each fine color i in
   [fine_color_block, fine_color_block + fine_colors_per_thread), compute
     out(spinor_parity, x_cb, s, i) = sum_j V(v_parity, x_cb, s, i, j) * in[s*coarseColor + j]
   where `in` is the buffer filled by prolongate().

   @param out               Fine output accessor, called as out(parity, x_cb, spin, color)
   @param in                Gathered coarse values, spin-major (s*coarseColor + j)
   @param V                 Rotation accessor, called as V(parity, x_cb, spin, fine_color, coarse_color)
   @param parity            Parity of the fine site
   @param nParity           Number of parities held by out; when it is not 2, out is indexed with parity 0
   @param x_cb              Checkerboard index of the fine site
   @param fine_color_block  First fine color handled by this call

   NOTE(review): the j loop advances by color_unroll (2) with no remainder
   loop, so coarseColor must be a multiple of 2. All coarseColor values
   instantiated by the visible dispatcher (4, 6, 24, 32) satisfy this.
*/
58 template <
typename Float,
int fineSpin,
int fineColor,
int coarseColor,
int fine_colors_per_thread,
59 class FineColor,
class Rotator>
60 __device__ __host__
inline void rotateFineColor(FineColor &
out,
const complex<Float> in[fineSpin*coarseColor],
61 const Rotator &
V,
int parity,
int nParity,
int x_cb,
int fine_color_block) {
// Single-parity fields (out or V) are always addressed with parity 0.
62 const int spinor_parity = (nParity == 2) ? parity : 0;
63 const int v_parity = (V.Nparity() == 2) ? parity : 0;
// Number of independent partial sums kept per output element.
65 constexpr
int color_unroll = 2;
// Zero this thread's output elements before accumulating into them.
68 for (
int s=0;
s<fineSpin;
s++)
70 for (
int fine_color_local=0; fine_color_local<fine_colors_per_thread; fine_color_local++)
71 out(spinor_parity, x_cb,
s, fine_color_block+fine_color_local) = 0.0;
74 for (
int s=0;
s<fineSpin;
s++) {
76 for (
int fine_color_local=0; fine_color_local<fine_colors_per_thread; fine_color_local++) {
77 int i = fine_color_block + fine_color_local;
79 complex<Float> partial[color_unroll];
81 for (
int k=0; k<color_unroll; k++) partial[k] = 0.0;
// Contract over coarse colors, color_unroll at a time.
84 for (
int j=0; j<coarseColor; j+=color_unroll) {
87 for (
int k=0; k<color_unroll; k++)
88 partial[k] +=
V(v_parity, x_cb,
s, i, j+k) * in[
s*coarseColor + j + k];
// Fold the partial sums into the (previously zeroed) output element.
92 for (
int k=0; k<color_unroll; k++)
out(spinor_parity, x_cb,
s, i) += partial[k];
/**
   @brief Serial (host-loop) prolongation over every fine site.

   Loops over parities, fine checkerboard sites and fine-color blocks. For
   each site it calls prolongate() into a local buffer, then
   rotateFineColor() for each block. ProlongateLaunch::apply calls this for
   QUDA_SPACE_SPIN_COLOR_FIELD_ORDER; it is the loop counterpart of
   ProlongateKernel.

   NOTE(review): the function signature line (presumably taking `Arg &arg`)
   is missing from this listing.

   NOTE(review): the loop variable `parity` is overwritten inside the body.
   With nParity == 1 the loop exits after one pass only because
   arg.parity + 1 >= 1. If arg.parity were ever negative (e.g. an
   invalid-parity sentinel), the increment could never reach 1 and the loop
   would not terminate. A separate loop index would be safer.
*/
98 template <
typename Float,
int fineSpin,
int fineColor,
int coarseSpin,
int coarseColor,
int fine_colors_per_thread,
typename Arg>
100 for (
int parity=0; parity<arg.nParity; parity++) {
// Single-parity fields use the fixed arg.parity instead of the loop index.
101 parity = (arg.nParity == 2) ? parity : arg.parity;
103 for (
int x_cb=0; x_cb<arg.out.VolumeCB(); x_cb++) {
// Gathered coarse values for this fine site (spin-major).
104 complex<Float>
tmp[fineSpin*coarseColor];
105 prolongate<Float,fineSpin,coarseColor>(
tmp, arg.in,
parity, x_cb, arg.geo_map, arg.spin_map, arg.out.VolumeCB());
106 for (
int fine_color_block=0; fine_color_block<fineColor; fine_color_block+=fine_colors_per_thread) {
107 rotateFineColor<Float,fineSpin,fineColor,coarseColor,fine_colors_per_thread>
108 (arg.out,
tmp, arg.V,
parity, arg.nParity, x_cb, fine_color_block);
/**
   @brief GPU prolongation kernel: one thread per (fine site, parity, fine-color block).

   Thread layout:
     x -> fine checkerboard site x_cb; threads with x_cb >= out.VolumeCB() exit.
     y -> parity, used only when arg.nParity == 2; otherwise arg.parity is used.
     z -> fine-color block of size fine_colors_per_thread; threads past fineColor exit.
   Each active thread gathers its site's coarse values with prolongate() and
   writes its fine colors with rotateFineColor().

   NOTE(review): the y index is not bounds-checked. The kernel relies on the
   launch's y extent matching nParity, which presumably comes from the
   TunableVectorYZ(out.SiteSubset(), ...) base of ProlongateLaunch.
*/
114 template <
typename Float,
int fineSpin,
int fineColor,
int coarseSpin,
int coarseColor,
int fine_colors_per_thread,
typename Arg>
115 __global__
void ProlongateKernel(Arg arg) {
116 int x_cb = blockIdx.x*blockDim.x + threadIdx.x;
117 int parity = arg.nParity == 2 ? blockDim.y*blockIdx.y + threadIdx.y : arg.parity;
118 if (x_cb >= arg.out.VolumeCB())
return;
// First fine color handled by this thread.
120 int fine_color_block = (blockDim.z*blockIdx.z + threadIdx.z) * fine_colors_per_thread;
121 if (fine_color_block >= fineColor)
return;
// Per-thread buffer of gathered coarse values (spin-major).
123 complex<Float>
tmp[fineSpin*coarseColor];
124 prolongate<Float,fineSpin,coarseColor>(
tmp, arg.in,
parity, x_cb, arg.geo_map, arg.spin_map, arg.out.VolumeCB());
125 rotateFineColor<Float,fineSpin,fineColor,coarseColor,fine_colors_per_thread>
126 (arg.out,
tmp, arg.V,
parity, arg.nParity, x_cb, fine_color_block);
/**
   @brief Tunable launcher for the prolongator.

   Holds references to the fine output (out), the coarse input (in), the
   rotation field (V) and the fine-to-coarse map. vFloat is the storage type
   used for V. The base is constructed as
   TunableVectorYZ(out.SiteSubset(), fineColor/fine_colors_per_thread),
   which presumably sets the y extent (parities) and z extent (fine-color
   blocks) that ProlongateKernel indexes.

   NOTE(review): this listing omits several lines, including the parity,
   location, vol and aux member declarations, the constructor's opening
   brace, the location/field-order conditions in apply(), and the
   definition of tp.
*/
129 template <
typename Float,
typename vFloat,
int fineSpin,
int fineColor,
int coarseSpin,
int coarseColor,
int fine_colors_per_thread>
130 class ProlongateLaunch :
public TunableVectorYZ {
133 ColorSpinorField &
out;
134 const ColorSpinorField &
in;
135 const ColorSpinorField &
V;
136 const int *fine_to_coarse;
// Grid dimensions are not autotuned.
141 bool tuneGridDim()
const {
return false; }
// The x extent must cover every fine checkerboard site.
142 unsigned int minThreads()
const {
return out.VolumeCB(); }
// The tuning key strings are built from the fine and coarse volume/aux strings.
145 ProlongateLaunch(ColorSpinorField &out,
const ColorSpinorField &in,
const ColorSpinorField &V,
146 const int *fine_to_coarse,
int parity)
147 : TunableVectorYZ(out.SiteSubset(), fineColor/fine_colors_per_thread), out(out), in(in), V(V),
148 fine_to_coarse(fine_to_coarse), parity(parity), location(
checkLocation(out, in, V))
150 strcpy(vol, out.VolString());
152 strcat(vol, in.VolString());
154 strcpy(aux, out.AuxString());
156 strcat(aux, in.AuxString());
159 virtual ~ProlongateLaunch() { }
// Visible paths: QUDA_SPACE_SPIN_COLOR_FIELD_ORDER -> serial Prolongate(arg);
// QUDA_FLOAT2_FIELD_ORDER -> ProlongateKernel launched with tuned params tp;
// any other field order -> errorQuda.
161 void apply(
const cudaStream_t &
stream) {
164 ProlongateArg<Float,vFloat,fineSpin,fineColor,coarseSpin,coarseColor,QUDA_SPACE_SPIN_COLOR_FIELD_ORDER>
165 arg(out, in, V, fine_to_coarse, parity);
166 Prolongate<Float,fineSpin,fineColor,coarseSpin,coarseColor,fine_colors_per_thread>(
arg);
168 errorQuda(
"Unsupported field order %d", out.FieldOrder());
173 ProlongateArg<Float,vFloat,fineSpin,fineColor,coarseSpin,coarseColor,QUDA_FLOAT2_FIELD_ORDER>
174 arg(out, in, V, fine_to_coarse, parity);
175 ProlongateKernel<Float,fineSpin,fineColor,coarseSpin,coarseColor,fine_colors_per_thread>
176 <<<tp.grid, tp.block, tp.shared_bytes, stream>>>(
arg);
178 errorQuda(
"Unsupported field order %d", out.FieldOrder());
// Tuning key: fine+coarse volume strings, this class's type name, fine+coarse aux strings.
183 TuneKey tuneKey()
const {
return TuneKey(vol,
typeid(*this).name(), aux); }
// One complex multiply-add (8 real flops) per (fine spin, fine color, coarse color)
// for every fine site in out.
185 long long flops()
const {
return 8 * fineSpin * fineColor * coarseColor * out.SiteSubset()*(
long long)out.VolumeCB(); }
// Traffic: in + out + V, plus one int per fine site for the fine-to-coarse map.
// V is counted at half size when its site subset differs from out's.
187 long long bytes()
const {
188 size_t v_bytes = V.Bytes() / (V.SiteSubset() == out.SiteSubset() ? 1 : 2);
189 return in.Bytes() + out.Bytes() + v_bytes + out.SiteSubset()*out.VolumeCB()*
sizeof(int);
/**
   @brief Select V's storage type and run the prolongator.

   Visible paths:
   - a short-storage V (vFloat = short), compiled only when QUDA_PRECISION
     has bit 2 set; otherwise this path raises the "does not enable half
     precision" error;
   - V stored in the same precision as `in` (vFloat = Float);
   - any other V precision raises "Unsupported V precision".
   Each path constructs a ProlongateLaunch and calls apply(0).
   fine_colors_per_thread is fixed at 1.

   NOTE(review): the condition selecting the short-V path and the
   #else/#endif lines are not visible in this listing.
*/
194 template <
typename Float,
int fineSpin,
int fineColor,
int coarseSpin,
int coarseColor>
195 void Prolongate(ColorSpinorField &out,
const ColorSpinorField &in,
const ColorSpinorField &v,
196 const int *fine_to_coarse,
int parity) {
// Each thread produces one fine color.
199 constexpr
int fine_colors_per_thread = 1;
202 #if QUDA_PRECISION & 2 203 ProlongateLaunch<Float, short, fineSpin, fineColor, coarseSpin, coarseColor, fine_colors_per_thread>
204 prolongator(out, in, v, fine_to_coarse, parity);
205 prolongator.apply(0);
207 errorQuda(
"QUDA_PRECISION=%d does not enable half precision", QUDA_PRECISION);
209 }
else if (v.Precision() == in.Precision()) {
210 ProlongateLaunch<Float, Float, fineSpin, fineColor, coarseSpin, coarseColor, fine_colors_per_thread>
211 prolongator(out, in, v, fine_to_coarse, parity);
212 prolongator.apply(0);
214 errorQuda(
"Unsupported V precision %d", v.Precision());
/**
   @brief Validate the spin map, then dispatch on fine color count and nVec.

   Only coarse spin 2 is accepted. The runtime spin_map (indexed
   [fine_spin][parity]) must equal spin_mapper<fineSpin,2> for every fine
   spin and both parities, otherwise errorQuda is raised. The code then
   picks fineColor from out.Ncolor() (3, 6, 24 or 32) and coarseColor from
   nVec, and calls the fully templated Prolongate. Other nColor values
   raise "Unsupported nColor".

   NOTE(review): the nVec test guarding the first inner branch of each
   nColor case (e.g. the branch instantiating coarseColor 4 for
   fineColor 3) is not visible in this listing.
*/
221 template <
typename Float,
int fineSpin>
222 void Prolongate(ColorSpinorField &out,
const ColorSpinorField &in,
const ColorSpinorField &v,
223 int nVec,
const int *fine_to_coarse,
const int *
const * spin_map,
int parity) {
225 if (in.Nspin() != 2)
errorQuda(
"Coarse spin %d is not supported", in.Nspin());
226 const int coarseSpin = 2;
// The kernels use the compile-time mapper, so the caller's map must match it exactly.
229 spin_mapper<fineSpin,coarseSpin> mapper;
230 for (
int s=0;
s<fineSpin;
s++)
231 for (
int p=0; p<2; p++)
232 if (mapper(
s,p) != spin_map[
s][p])
errorQuda(
"Spin map does not match spin_mapper");
234 if (out.Ncolor() == 3) {
235 const int fineColor = 3;
237 Prolongate<Float,fineSpin,fineColor,coarseSpin,4>(
out,
in, v, fine_to_coarse,
parity);
238 }
else if (nVec == 6) {
239 Prolongate<Float,fineSpin,fineColor,coarseSpin,6>(
out,
in, v, fine_to_coarse,
parity);
240 }
else if (nVec == 24) {
241 Prolongate<Float,fineSpin,fineColor,coarseSpin,24>(
out,
in, v, fine_to_coarse,
parity);
242 }
else if (nVec == 32) {
243 Prolongate<Float,fineSpin,fineColor,coarseSpin,32>(
out,
in, v, fine_to_coarse,
parity);
247 }
else if (out.Ncolor() == 6) {
248 const int fineColor = 6;
250 Prolongate<Float,fineSpin,fineColor,coarseSpin,6>(
out,
in, v, fine_to_coarse,
parity);
254 }
else if (out.Ncolor() == 24) {
255 const int fineColor = 24;
257 Prolongate<Float,fineSpin,fineColor,coarseSpin,24>(
out,
in, v, fine_to_coarse,
parity);
258 }
else if (nVec == 32) {
259 Prolongate<Float,fineSpin,fineColor,coarseSpin,32>(
out,
in, v, fine_to_coarse,
parity);
263 }
else if (out.Ncolor() == 32) {
264 const int fineColor = 32;
266 Prolongate<Float,fineSpin,fineColor,coarseSpin,32>(
out,
in, v, fine_to_coarse,
parity);
271 errorQuda(
"Unsupported nColor %d", out.Ncolor());
/**
   @brief Dispatch on the fine field's spin count.

   out.Nspin() == 2 is always handled. Nspin 4 and Nspin 1 appear inside
   GPU_WILSON_DIRAC and GPU_STAGGERED_DIRAC conditional regions
   respectively. Any other spin count raises "Unsupported nSpin".

   NOTE(review): the matching #endif and final else lines are not visible
   in this listing.
*/
275 template <
typename Float>
276 void Prolongate(ColorSpinorField &out,
const ColorSpinorField &in,
const ColorSpinorField &v,
277 int Nvec,
const int *fine_to_coarse,
const int *
const * spin_map,
int parity) {
279 if (out.Nspin() == 2) {
280 Prolongate<Float,2>(
out,
in, v, Nvec, fine_to_coarse, spin_map,
parity);
281 #ifdef GPU_WILSON_DIRAC 282 }
else if (out.Nspin() == 4) {
283 Prolongate<Float,4>(
out,
in, v, Nvec, fine_to_coarse, spin_map,
parity);
285 #ifdef GPU_STAGGERED_DIRAC 286 }
else if (out.Nspin() == 1) {
287 Prolongate<Float,1>(
out,
in, v, Nvec, fine_to_coarse, spin_map,
parity);
290 errorQuda(
"Unsupported nSpin %d", out.Nspin());
294 #endif // GPU_MULTIGRID 297 int Nvec,
const int *fine_to_coarse,
const int *
const * spin_map,
int parity) {
300 errorQuda(
"Field orders do not match (out=%d, in=%d, v=%d)",
306 #ifdef GPU_MULTIGRID_DOUBLE 307 Prolongate<double>(
out,
in, v, Nvec, fine_to_coarse, spin_map,
parity);
309 errorQuda(
"Double precision multigrid has not been enabled");
312 Prolongate<float>(
out,
in, v, Nvec, fine_to_coarse, spin_map,
parity);
319 errorQuda(
"Multigrid has not been built");
enum QudaPrecision_s QudaPrecision
QudaVerbosity getVerbosity()
#define checkPrecision(...)
cudaColorSpinorField * tmp
This is just a dummy structure we use for trove to define the required structure size.
__device__ __host__ int VolumeCB() const
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
#define checkLocation(...)
enum QudaFieldLocation_s QudaFieldLocation
cpuColorSpinorField * out
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
static const int volume_n
QudaTune getTuning()
Query whether autotuning is enabled or not. Default is enabled but can be overridden by setting QUDA_...
QudaPrecision Precision() const
void Prolongate(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &v, int Nvec, const int *fine_to_coarse, const int *const *spin_map, int parity=QUDA_INVALID_PARITY)
Apply the prolongation operator.
QudaFieldOrder FieldOrder() const