#include <color_spinor_field_order.h>
#include <jitify_helper.cuh>
#include <kernels/reduce_core.cuh>
qudaStream_t* getStream();
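/**
   Recursive launch helper: instantiates reduceKernel for every candidate block
   size (multiples of 32 up to max_block_size()) and launches the instantiation
   that matches the block size selected by the autotuner.
*/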
template <int block_size, typename real, int len, typename Arg>
typename std::enable_if<block_size!=32, qudaError_t>::type launch(Arg &arg, const TuneParam &tp, const qudaStream_t &stream)
{
  if (tp.block.x == block_size)
    return qudaLaunchKernel(reduceKernel<block_size, real, len, Arg>, tp, stream, arg);
  else
    return launch<block_size - 32, real, len>(arg, tp, stream);
}
template <int block_size, typename real, int len, typename Arg>
typename std::enable_if<block_size==32, qudaError_t>::type launch(Arg &arg, const TuneParam &tp, const qudaStream_t &stream)
{
  if (block_size != tp.block.x) errorQuda("Unexpected block size %d\n", tp.block.x);
  return qudaLaunchKernel(reduceKernel<block_size, real, len, Arg>, tp, stream, arg);
}
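// When QUDA_FAST_COMPILE_REDUCE is defined only the 32-thread block size is
// instantiated, which shortens compile time at the cost of reduction performance.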
#ifdef QUDA_FAST_COMPILE_REDUCE
constexpr static unsigned int max_block_size() { return 32; }
#else
constexpr static unsigned int max_block_size() { return 1024; }
#endif
/**
   Generic reduction kernel launcher
*/
template <typename host_reduce_t, typename real, int len, typename Arg>
auto reduceLaunch(Arg &arg, const TuneParam &tp, const qudaStream_t &stream, Tunable &tunable)
{
  using device_reduce_t = typename Arg::Reducer::reduce_t;
  if (tp.grid.x > (unsigned int)deviceProp.maxGridSize[0])
    errorQuda("Grid size %d greater than maximum %d\n", tp.grid.x, deviceProp.maxGridSize[0]);

#ifdef JITIFY
  using namespace jitify::reflection;
  tunable.jitifyError() = program->kernel("quda::blas::reduceKernel")
                            .instantiate((int)tp.block.x, Type<real>(), len, Type<Arg>())
                            .configure(tp.grid, tp.block, tp.shared_bytes, stream)
                            .launch(arg);
  arg.launch_error = tunable.jitifyError() == CUDA_SUCCESS ? QUDA_SUCCESS : QUDA_ERROR;
#else
  arg.launch_error = launch<max_block_size(), real, len>(arg, tp, stream);
#endif

  host_reduce_t result;
  ::quda::zero(result);
  if (!commAsyncReduction()) arg.complete(result, stream);
  return result;
}
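/**
   Tunable wrapper for the blas reduction kernels: checks field compatibility,
   dispatches the reduction to the device or host backend, and accumulates the
   result across processes.
*/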
template <template <typename ReducerType, typename real> class Reducer,
          typename store_t, typename y_store_t, int nSpin, typename coeff_t>
class Reduce : public Tunable
{
  using real = typename mapper<y_store_t>::type;
  using host_reduce_t = typename Reducer<double, real>::reduce_t;
  Reducer<device_reduce_t, real> r;
  const int nParity; // for composite fields this includes the number of composites
  host_reduce_t &result;

  const coeff_t &a, &b, &c;
  ColorSpinorField &x, &y, &z, &w, &v;
  QudaFieldLocation location;
  unsigned int sharedBytesPerThread() const { return 0; }
  unsigned int sharedBytesPerBlock(const TuneParam &param) const { return 0; }

  bool advanceSharedBytes(TuneParam &param) const
  {
    TuneParam next(param);
    advanceBlockDim(next); // to get next blockDim
    int nthreads = next.block.x * next.block.y * next.block.z;
    param.shared_bytes = sharedBytesPerThread() * nthreads > sharedBytesPerBlock(param) ?
      sharedBytesPerThread() * nthreads :
      sharedBytesPerBlock(param);
    return false;
  }

  unsigned int maxBlockSize(const TuneParam &param) const { return max_block_size(); }

public:
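  /**
     The constructor performs the field consistency checks, builds the aux
     string used in the tuning key, runs the reduction on the current blas
     stream, and finally combines the partial results from all processes.
  */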
  Reduce(const coeff_t &a, const coeff_t &b, const coeff_t &c, ColorSpinorField &x, ColorSpinorField &y,
         ColorSpinorField &z, ColorSpinorField &w, ColorSpinorField &v, host_reduce_t &result) :
    r(a, b),
    nParity((x.IsComposite() ? x.CompositeDim() : 1) * (x.SiteSubset())),
    result(result),
    a(a), b(b), c(c),
    x(x), y(y), z(z), w(w), v(v),
    location(checkLocation(x, y, z, w, v))
  {
    checkLength(x, y, z, w, v);
    auto x_prec = checkPrecision(x, z, w, v);
    auto y_prec = y.Precision();
    auto x_order = checkOrder(x, z, w, v);
    auto y_order = y.FieldOrder();
    if (sizeof(store_t) != x_prec) errorQuda("Expected precision %lu but received %d", sizeof(store_t), x_prec);
    if (sizeof(y_store_t) != y_prec) errorQuda("Expected precision %lu but received %d", sizeof(y_store_t), y_prec);
    if (x_prec == y_prec && x_order != y_order) errorQuda("Orders %d %d do not match", x_order, y_order);

    strcpy(aux, x.AuxString());
    if (x_prec != y_prec) {
      strcat(aux, ",");
      strcat(aux, y.AuxString());
    }
    strcat(aux, nParity == 2 ? ",nParity=2" : ",nParity=1");
    if (location == QUDA_CPU_FIELD_LOCATION) strcat(aux, ",CPU");
    if (commAsyncReduction()) strcat(aux, ",async");

#ifdef JITIFY
    ::quda::create_jitify_program("kernels/reduce_core.cuh");
#endif

    apply(*(blas::getStream()));

    blas::bytes += bytes();
    blas::flops += flops();

    const int Nreduce = sizeof(host_reduce_t) / sizeof(double); // number of doubles packed in the reduction result
    reduceDoubleArray((double *)&result, Nreduce);               // sum the partial results over all processes
  }

  TuneKey tuneKey() const { return TuneKey(x.VolString(), typeid(r).name(), aux); }

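  /**
     Launch the reduction with the autotuned parameters, dispatching either to
     the device kernel via reduceLaunch or to the host-side reduceCPU path.
     The vector widths N/Ny and the per-thread length M are recomputed from
     the device/host storage types so that site unrolling is applied consistently.
  */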
  void apply(const qudaStream_t &stream)
  {
    constexpr bool site_unroll_check = !std::is_same<store_t, y_store_t>::value || isFixed<store_t>::value || decltype(r)::site_unroll;
    if (site_unroll_check && (x.Ncolor() != 3 || x.Nspin() == 2))
      errorQuda("site unroll not supported for nSpin = %d nColor = %d", x.Nspin(), x.Ncolor());

    TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
    if (location == QUDA_CUDA_FIELD_LOCATION) {
      if (site_unroll_check) checkNative(x, y, z, w, v); // require native order when using site_unroll
      using device_store_t = typename device_type_mapper<store_t>::type;
      using device_y_store_t = typename device_type_mapper<y_store_t>::type;
      using device_real_t = typename mapper<device_y_store_t>::type;
      Reducer<device_reduce_t, device_real_t> r_(a, b);

      // redefine site_unroll with device_store types to ensure we have correct N/Ny/M values
      constexpr bool site_unroll = !std::is_same<device_store_t, device_y_store_t>::value || isFixed<device_store_t>::value || decltype(r)::site_unroll;
      constexpr int N = n_vector<device_store_t, true, nSpin, site_unroll>();
      constexpr int Ny = n_vector<device_y_store_t, true, nSpin, site_unroll>();
      constexpr int M = site_unroll ? (nSpin == 4 ? 24 : 6) : N; // real numbers per thread
      const int length = x.Length() / (nParity * M);

      ReductionArg<device_store_t, N, device_y_store_t, Ny, decltype(r_)> arg(x, y, z, w, v, r_, length, nParity, tp);
      result = reduceLaunch<host_reduce_t, device_real_t, M>(arg, tp, stream, *this);
    } else {
      if (checkOrder(x, y, z, w, v) != QUDA_SPACE_SPIN_COLOR_FIELD_ORDER) {
        warningQuda("CPU Blas functions expect AoS field order");
        return;
      }

      using host_store_t = typename host_type_mapper<store_t>::type;
      using host_y_store_t = typename host_type_mapper<y_store_t>::type;
      using host_real_t = typename mapper<host_y_store_t>::type;
      Reducer<double, host_real_t> r_(a, b);

      // redefine site_unroll with host_store types to ensure we have correct N/Ny/M values
      constexpr bool site_unroll = !std::is_same<host_store_t, host_y_store_t>::value || isFixed<host_store_t>::value || decltype(r)::site_unroll;
      constexpr int N = n_vector<host_store_t, false, nSpin, site_unroll>();
      constexpr int Ny = n_vector<host_y_store_t, false, nSpin, site_unroll>();
      constexpr int M = N; // if site unrolling then M=N will be 24/6, i.e., full AoS
      const int length = x.Length() / (nParity * M);

      ReductionArg<host_store_t, N, host_y_store_t, Ny, decltype(r_)> arg(x, y, z, w, v, r_, length, nParity, tp);
      result = reduceCPU<host_real_t, M>(arg);
    }
  }

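  // Fields that the reducer writes to are backed up before tuning and restored
  // afterwards, so repeated tuning launches do not corrupt the final result.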
  void preTune()
  {
    if (r.write.X) x.backup();
    if (r.write.Y) y.backup();
    if (r.write.Z) z.backup();
    if (r.write.W) w.backup();
    if (r.write.V) v.backup();
  }

  void postTune()
  {
    if (r.write.X) x.restore();
    if (r.write.Y) y.restore();
    if (r.write.Z) z.restore();
    if (r.write.W) w.restore();
    if (r.write.V) v.restore();
  }

  bool advanceTuneParam(TuneParam &param) const
  {
    // host-side reductions are not autotuned
    return location == QUDA_CPU_FIELD_LOCATION ? false : Tunable::advanceTuneParam(param);
  }

  void initTuneParam(TuneParam &param) const
  {
    Tunable::initTuneParam(param);
  }

  void defaultTuneParam(TuneParam &param) const
  {
    Tunable::defaultTuneParam(param);
  }

  long long flops() const { return r.flops() * x.Length(); }

  long long bytes() const
  {
    return (r.read.X + r.write.X) * x.Bytes() + (r.read.Y + r.write.Y) * y.Bytes() +
      (r.read.Z + r.write.Z) * z.Bytes() + (r.read.W + r.write.W) * w.Bytes() + (r.read.V + r.write.V) * v.Bytes();
  }

  int tuningIter() const { return 3; }
};

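/**
   Zero-initialize the reduction result, instantiate the Reduce class for the
   precision combination of the supplied fields, and return the completed result.
*/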
template <template <typename reduce_t, typename real> class Functor, bool mixed, typename... Args>
auto instantiateReduce(Args &&... args)
{
  using host_reduce_t = typename Functor<double, double>::reduce_t;
  host_reduce_t value;
  ::quda::zero(value); // no default constructor so we need to explicitly zero
  instantiate<Functor, Reduce, mixed>(args..., value);
  return value;
}

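// The wrappers below map the public blas interface onto instantiateReduce;
// unused field slots repeat an existing field and unused coefficients are
// passed as zero. A hypothetical new real-valued reduction would follow the
// same pattern (MyReducer is assumed to be defined in kernels/reduce_core.cuh):
//
//   double myReduction(double a, ColorSpinorField &x, ColorSpinorField &y)
//   {
//     return instantiateReduce<MyReducer, false>(a, 0.0, 0.0, x, y, x, x, x);
//   }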
double norm1(const ColorSpinorField &x)
{
  ColorSpinorField &y = const_cast<ColorSpinorField &>(x); // FIXME: the reduction interface takes non-const references
  return instantiateReduce<Norm1, false>(0.0, 0.0, 0.0, y, y, y, y, y);
}

double norm2(const ColorSpinorField &x)
{
  ColorSpinorField &y = const_cast<ColorSpinorField &>(x);
  return instantiateReduce<Norm2, false>(0.0, 0.0, 0.0, y, y, y, y, y);
}

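// Hypothetical usage sketch (caller code, not part of this file):
//
//   ColorSpinorField &x = ...;  // any field managed by the caller
//   double l2 = blas::norm2(x); // sum of |x_i|^2 over all sites and ranks
//   double l1 = blas::norm1(x); // sum of |x_i| over all sites and ranks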
double reDotProduct(ColorSpinorField &x, ColorSpinorField &y)
{
  return instantiateReduce<Dot, false>(0.0, 0.0, 0.0, x, y, x, x, x);
}

double axpbyzNorm(double a, ColorSpinorField &x, double b, ColorSpinorField &y, ColorSpinorField &z)
{
  return instantiateReduce<axpbyzNorm2, false>(a, b, 0.0, x, y, z, x, x);
}

double axpyReDot(double a, ColorSpinorField &x, ColorSpinorField &y)
{
  return instantiateReduce<AxpyReDot, false>(a, 0.0, 0.0, x, y, x, x, x);
}

double caxpyNorm(const Complex &a, ColorSpinorField &x, ColorSpinorField &y)
{
  return instantiateReduce<caxpyNorm2, false>(a, Complex(0.0), Complex(0.0), x, y, x, x, x);
}

double caxpyXmazNormX(const Complex &a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
{
  return instantiateReduce<caxpyxmaznormx, false>(a, Complex(0.0), Complex(0.0), x, y, z, x, x);
}

double cabxpyzAxNorm(double a, const Complex &b, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
{
  return instantiateReduce<cabxpyzaxnorm, false>(Complex(a), b, Complex(0.0), x, y, z, x, x);
}

Complex cDotProduct(ColorSpinorField &x, ColorSpinorField &y)
{
  auto cdot = instantiateReduce<Cdot, false>(0.0, 0.0, 0.0, x, y, x, x, x);
  return Complex(cdot.x, cdot.y);
}

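// The reduction result for the complex dot products is a double2 whose components
// are repacked as the real and imaginary parts of the returned Complex.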
Complex caxpyDotzy(const Complex &a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
{
  auto cdot = instantiateReduce<caxpydotzy, false>(a, Complex(0.0), Complex(0.0), x, y, z, x, x);
  return Complex(cdot.x, cdot.y);
}

double3 cDotProductNormA(ColorSpinorField &x, ColorSpinorField &y)
{
  return instantiateReduce<CdotNormA, false>(0.0, 0.0, 0.0, x, y, x, x, x);
}

double3 caxpbypzYmbwcDotProductUYNormY(const Complex &a, ColorSpinorField &x, const Complex &b, ColorSpinorField &y,
                                       ColorSpinorField &z, ColorSpinorField &w, ColorSpinorField &u)
{
  return instantiateReduce<caxpbypzYmbwcDotProductUYNormY_, true>(a, b, Complex(0.0), x, z, y, w, u);
}

Complex axpyCGNorm(double a, ColorSpinorField &x, ColorSpinorField &y)
{
  double2 cg_norm = instantiateReduce<axpyCGNorm2, true>(a, 0.0, 0.0, x, y, x, x, x);
  return Complex(cg_norm.x, cg_norm.y);
}

double3 HeavyQuarkResidualNorm(ColorSpinorField &x, ColorSpinorField &r)
{
  // reduce_core does not support this kernel when x.Ncolor() != 3 (which mainly arises in multigrid), so return zero
  if (x.Ncolor() != 3) return make_double3(0.0, 0.0, 0.0);
  double3 rtn = instantiateReduce<HeavyQuarkResidualNorm_, false>(0.0, 0.0, 0.0, x, r, r, r, r);
  rtn.z /= (x.Volume() * comm_size()); // normalize by the global volume to obtain the per-site average
  return rtn;
}

double3 xpyHeavyQuarkResidualNorm(ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &r)
{
  // reduce_core does not support this kernel when x.Ncolor() != 3 (which mainly arises in multigrid), so return zero
  if (x.Ncolor() != 3) return make_double3(0.0, 0.0, 0.0);
  double3 rtn = instantiateReduce<xpyHeavyQuarkResidualNorm_, false>(0.0, 0.0, 0.0, x, y, r, r, r);
  rtn.z /= (x.Volume() * comm_size()); // normalize by the global volume to obtain the per-site average
  return rtn;
}

double3 tripleCGReduction(ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
{
  return instantiateReduce<tripleCGReduction_, false>(0.0, 0.0, 0.0, x, y, z, x, x);
}

double4 quadrupleCGReduction(ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
{
  return instantiateReduce<quadrupleCGReduction_, false>(0.0, 0.0, 0.0, x, y, z, x, x);
}

double quadrupleCG3InitNorm(double a, ColorSpinorField &x, ColorSpinorField &y,
                            ColorSpinorField &z, ColorSpinorField &w, ColorSpinorField &v)
{
  return instantiateReduce<quadrupleCG3InitNorm_, false>(a, 0.0, 0.0, x, y, z, w, v);
}

double quadrupleCG3UpdateNorm(double a, double b, ColorSpinorField &x, ColorSpinorField &y,
                              ColorSpinorField &z, ColorSpinorField &w, ColorSpinorField &v)
{
  return instantiateReduce<quadrupleCG3UpdateNorm_, false>(a, b, 0.0, x, y, z, w, v);
}