3 #include <cstring> // needed for memset
6 #include <quda_internal.h>
8 #include <color_spinor_field.h>
10 #include <jitify_helper.cuh>
11 #include <kernels/blas_core.cuh>
// Running flop total. The Blas constructor adds its flops() to blas::flops
// (presumably this variable -- the enclosing namespace lines are not visible here).
22 unsigned long long flops;
// Running byte-traffic total. The Blas constructor adds its bytes() to blas::bytes
// (presumably this variable -- TODO confirm against the namespace declaration).
23 unsigned long long bytes;
// Stream pointer used by this file: it is pointed at streams[Nstream-1] further
// down and handed out through getStream().
25 static qudaStream_t *blasStream;
/**
   Tunable wrapper that runs one functor over up to five ColorSpinorFields
   (x, y, z, w, v).

   Template parameters, as used by the visible code:
   - Functor:   functor template, instantiated on a real type
   - store_t:   storage type whose size must equal the precision of x, z, w
   - y_store_t: storage type whose size must equal the precision of y, v
   - nSpin:     forwarded to n_vector<> when choosing the per-thread vector length
   - coeff_t:   type of the scalar coefficients a, b, c

   NOTE(review): this view of the class is missing lines (for example the member
   `f` used by the methods below is never declared here) -- check the full file.
*/
27 template <template <typename real> class Functor, typename store_t, typename y_store_t,
28 int nSpin, typename coeff_t>
29 class Blas : public Tunable
// real type obtained by mapping the y-side storage type
31 using real = typename mapper<y_store_t>::type;
33 const int nParity; // for composite fields this includes the number of composites
// coefficients are held by reference, so the caller's objects must outlive this instance
35 const coeff_t &a, &b, &c;
// the fields operated on, also held by reference
36 ColorSpinorField &x, &y, &z, &w, &v;
// value returned by checkLocation() for the five fields
37 const QudaFieldLocation location;
39 unsigned int sharedBytesPerThread() const { return 0; }
40 unsigned int sharedBytesPerBlock(const TuneParam ¶m) const { return 0; }
42 bool tuneSharedBytes() const { return false; }
44 // for these streaming kernels, there is no need to tune the grid size, just use max
45 unsigned int minGridSize() const { return maxGridSize(); }
/**
   Validates the five fields, builds the tuning aux string and adds this
   operation's byte and flop counts to the blas:: counters.

   NOTE(review): part of the initializer list and some statements of this
   constructor are not visible in this view -- consult the full file before
   changing it.

   @param a, b, c Scalar coefficients
   @param x, y, z, w, v Fields handed to the checks below
*/
48 Blas(const coeff_t &a, const coeff_t &b, const coeff_t &c, ColorSpinorField &x,
49 ColorSpinorField &y, ColorSpinorField &z, ColorSpinorField &w, ColorSpinorField &v) :
// composite count (1 when x is not composite) multiplied by x.SiteSubset()
51 nParity((x.IsComposite() ? x.CompositeDim() : 1) * x.SiteSubset()),
60 location(checkLocation(x, y, z, w, v))
62 checkLength(x, y, z, w, v);
// x, z, w are checked as one precision/order group; y, v as a second group
63 auto x_prec = checkPrecision(x, z, w);
64 auto y_prec = checkPrecision(y, v);
65 auto x_order = checkOrder(x, z, w);
66 auto y_order = checkOrder(y, v);
// the compile-time storage types must agree with the runtime precisions
// NOTE(review): %lu is paired with a sizeof() result (size_t); %zu is the matching
// conversion if errorQuda forwards to printf-style formatting -- confirm.
67 if (sizeof(store_t) != x_prec) errorQuda("Expected precision %lu but received %d", sizeof(store_t), x_prec);
68 if (sizeof(y_store_t) != y_prec) errorQuda("Expected precision %lu but received %d", sizeof(y_store_t), y_prec);
// when both groups share a precision their field orders must agree as well
69 if (x_prec == y_prec && x_order != y_order) errorQuda("Orders %d %d do not match", x_order, y_order);
// aux starts as x's aux string; y's aux string is appended only when the precisions differ
71 strcpy(aux, x.AuxString());
72 if (x_prec != y_prec) {
74 strcat(aux, y.AuxString());
// host-resident fields get a ",CPU" suffix in aux
76 if (location == QUDA_CPU_FIELD_LOCATION) strcat(aux, ",CPU");
// NOTE(review): presumably guarded by a JITIFY preprocessor check that is not visible here
79 ::quda::create_jitify_program("kernels/blas_core.cuh");
// accumulate this operation's traffic and work into the blas:: counters
84 blas::bytes += bytes();
85 blas::flops += flops();
88 TuneKey tuneKey() const { return TuneKey(x.VolString(), typeid(f).name(), aux); }
/**
   Runs the functor: blasKernel on `stream` when the fields are at
   QUDA_CUDA_FIELD_LOCATION, otherwise blasCPU (the `else` line separating the
   two paths is not visible in this view).

   @param stream Stream handed to the kernel launch
*/
90 void apply(const qudaStream_t &stream)
// true when the two storage types differ or isFixed<store_t> holds
92 constexpr bool site_unroll_check = !std::is_same<store_t, y_store_t>::value || isFixed<store_t>::value;
// in that case only Ncolor == 3 with Nspin != 2 is accepted
93 if (site_unroll_check && (x.Ncolor() != 3 || x.Nspin() == 2))
94 errorQuda("site unroll not supported for nSpin = %d nColor = %d", x.Nspin(), x.Ncolor());
// launch parameters come from tuneLaunch()
96 TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
97 if (location == QUDA_CUDA_FIELD_LOCATION) {
98 if (site_unroll_check) checkNative(x, y, z, w, v); // require native order when using site_unroll
// map the template storage types to their device-side equivalents
99 using device_store_t = typename device_type_mapper<store_t>::type;
100 using device_y_store_t = typename device_type_mapper<y_store_t>::type;
101 using device_real_t = typename mapper<device_y_store_t>::type;
102 Functor<device_real_t> f_(a, b, c);
104 // redefine site_unroll with device_store types to ensure we have correct N/Ny/M values
105 constexpr bool site_unroll = !std::is_same<device_store_t, device_y_store_t>::value || isFixed<device_store_t>::value;
106 constexpr int N = n_vector<device_store_t, true, nSpin, site_unroll>();
107 constexpr int Ny = n_vector<device_y_store_t, true, nSpin, site_unroll>();
108 constexpr int M = site_unroll ? (nSpin == 4 ? 24 : 6) : N; // real numbers per thread
// x.Length() split across nParity and into groups of M reals
109 const int length = x.Length() / (nParity * M);
111 BlasArg<device_store_t, N, device_y_store_t, Ny, decltype(f_)> arg(x, y, z, w, v, f_, length, nParity);
// NOTE(review): both a jitify launch and a direct qudaLaunchKernel call appear below;
// the preprocessor lines selecting between them, and the end of the jitify call
// chain, are not visible in this view.
113 using namespace jitify::reflection;
114 jitify_error = program->kernel("quda::blas::blasKernel")
115 .instantiate(Type<device_real_t>(), M, Type<decltype(arg)>())
116 .configure(tp.grid, tp.block, tp.shared_bytes, stream)
119 qudaLaunchKernel(blasKernel<device_real_t, M, decltype(arg)>, tp, stream, arg);
// host path: every field must be in QUDA_SPACE_SPIN_COLOR_FIELD_ORDER
122 if (checkOrder(x, y, z, w, v) != QUDA_SPACE_SPIN_COLOR_FIELD_ORDER)
123 errorQuda("CPU Blas functions expect AoS field order");
// map the template storage types to their host-side equivalents
125 using host_store_t = typename host_type_mapper<store_t>::type;
126 using host_y_store_t = typename host_type_mapper<y_store_t>::type;
127 using host_real_t = typename mapper<host_y_store_t>::type;
128 Functor<host_real_t> f_(a, b, c);
130 // redefine site_unroll with host_store types to ensure we have correct N/Ny/M values
131 constexpr bool site_unroll = !std::is_same<host_store_t, host_y_store_t>::value || isFixed<host_store_t>::value;
132 constexpr int N = n_vector<host_store_t, false, nSpin, site_unroll>();
133 constexpr int Ny = n_vector<host_y_store_t, false, nSpin, site_unroll>();
134 constexpr int M = N; // if site unrolling then M=N will be 24/6, e.g., full AoS
135 const int length = x.Length() / (nParity * M);
137 BlasArg<host_store_t, N, host_y_store_t, Ny, decltype(f_)> arg(x, y, z, w, v, f_, length, nParity);
138 blasCPU<host_real_t, M>(arg);
// Calls backup() on every field the functor flags as written (f.write.*).
// NOTE(review): presumably the body of the tuner's pre-tuning hook; its
// signature line is not visible in this view -- confirm in the full file.
144 if (f.write.X) x.backup();
145 if (f.write.Y) y.backup();
146 if (f.write.Z) z.backup();
147 if (f.write.W) w.backup();
148 if (f.write.V) v.backup();
// Calls restore() on every field the functor flags as written (f.write.*).
// NOTE(review): presumably the body of the tuner's post-tuning hook; its
// signature line is not visible in this view -- confirm in the full file.
153 if (f.write.X) x.restore();
154 if (f.write.Y) y.restore();
155 if (f.write.Z) z.restore();
156 if (f.write.W) w.restore();
157 if (f.write.V) v.restore();
160 bool advanceTuneParam(TuneParam ¶m) const
162 return location == QUDA_CPU_FIELD_LOCATION ? false : Tunable::advanceTuneParam(param);
165 void initTuneParam(TuneParam ¶m) const
167 Tunable::initTuneParam(param);
168 param.grid.y = nParity;
171 void defaultTuneParam(TuneParam ¶m) const
173 Tunable::initTuneParam(param);
174 param.grid.y = nParity;
177 long long flops() const { return f.flops() * x.Length(); }
178 long long bytes() const
180 return (f.read.X + f.write.X) * x.Bytes() + (f.read.Y + f.write.Y) * y.Bytes() +
181 (f.read.Z + f.write.Z) * z.Bytes() + (f.read.W + f.write.W) * w.Bytes() + (f.read.V + f.write.V) * v.Bytes();
183 int tuningIter() const { return 3; }
186 void zero(ColorSpinorField &a) {
187 if (typeid(a) == typeid(cudaColorSpinorField)) {
188 static_cast<cudaColorSpinorField&>(a).zero();
190 static_cast<cpuColorSpinorField&>(a).zero();
// Point blasStream at the last entry of the streams array.
// NOTE(review): the signature of the enclosing function is not visible in this view.
196 blasStream = &streams[Nstream-1];
205 qudaStream_t* getStream() { return blasStream; }
207 void axpbyz(double a, ColorSpinorField &x, double b, ColorSpinorField &y, ColorSpinorField &z)
209 instantiate<axpbyz_, Blas, true>(a, b, 0.0, x, y, x, x, z);
212 void ax(double a, ColorSpinorField &x)
214 instantiate<ax_, Blas, false>(a, 0.0, 0.0, x, x, x, x, x);
217 void caxpy(const Complex &a, ColorSpinorField &x, ColorSpinorField &y)
219 instantiate<caxpy_, Blas, true>(a, Complex(0.0), Complex(0.0), x, y, x, x, y);
222 void caxpby(const Complex &a, ColorSpinorField &x, const Complex &b, ColorSpinorField &y)
224 instantiate<caxpby_, Blas, false>(a, b, Complex(0.0), x, y, x, x, y);
227 void caxpbypczw(const Complex &a, ColorSpinorField &x, const Complex &b, ColorSpinorField &y, const Complex &c,
228 ColorSpinorField &z, ColorSpinorField &w)
230 instantiate<caxpbypczw_, Blas, false>(a, b, c, x, y, z, w, y);
233 void cxpaypbz(ColorSpinorField &x, const Complex &a, ColorSpinorField &y, const Complex &b, ColorSpinorField &z)
235 instantiate<caxpbypczw_, Blas, false>(Complex(1.0), a, b, x, y, z, z, y);
238 void axpyBzpcx(double a, ColorSpinorField& x, ColorSpinorField& y, double b, ColorSpinorField& z, double c)
240 instantiate<axpyBzpcx_, Blas, true>(a, b, c, x, y, z, x, y);
243 void axpyZpbx(double a, ColorSpinorField& x, ColorSpinorField& y, ColorSpinorField& z, double b)
245 instantiate<axpyZpbx_, Blas, true>(a, b, 0.0, x, y, z, x, y);
248 void caxpyBzpx(const Complex &a, ColorSpinorField &x, ColorSpinorField &y, const Complex &b, ColorSpinorField &z)
250 instantiate<caxpyBzpx_, Blas, true>(a, b, Complex(0.0), x, y, z, x, y);
253 void caxpyBxpz(const Complex &a, ColorSpinorField &x, ColorSpinorField &y, const Complex &b, ColorSpinorField &z)
255 instantiate<caxpyBxpz_, Blas, true>(a, b, Complex(0.0), x, y, z, x, y);
258 void caxpbypzYmbw(const Complex &a, ColorSpinorField &x, const Complex &b, ColorSpinorField &y, ColorSpinorField &z, ColorSpinorField &w)
260 instantiate<caxpbypzYmbw_, Blas, false>(a, b, Complex(0.0), x, y, z, w, y);
263 void cabxpyAx(double a, const Complex &b, ColorSpinorField &x, ColorSpinorField &y)
265 instantiate<cabxpyAx_, Blas, false>(Complex(a), b, Complex(0.0), x, y, x, x, y);
268 void caxpyXmaz(const Complex &a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
270 instantiate<caxpyxmaz_, Blas, false>(a, Complex(0.0), Complex(0.0), x, y, z, x, y);
273 void caxpyXmazMR(const double &a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
275 if (!commAsyncReduction())
276 errorQuda("This kernel requires asynchronous reductions to be set");
277 if (x.Location() == QUDA_CPU_FIELD_LOCATION)
278 errorQuda("This kernel cannot be run on CPU fields");
279 instantiate<caxpyxmazMR_, Blas, false>(a, 0.0, 0.0, x, y, z, y, y);
282 void tripleCGUpdate(double a, double b, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, ColorSpinorField &w)
284 instantiate<tripleCGUpdate_, Blas, true>(a, b, 0.0, x, y, z, w, y);