#include <gauge_field.h>
#include <blas_lapack.h>
#include <tune_quda.h>

#ifdef JITIFY
#include <jitify_helper.cuh>
#endif
#include <kernels/coarse_op_preconditioned.cuh>

#include <coarse_op_preconditioned_mma_launch.h>

namespace quda
{

  /**
     @brief Launcher for CPU instantiations of the preconditioned coarse-link construction
  */
  template <QudaFieldLocation location, typename Arg> struct Launch {
    Launch(Arg &arg, CUresult &error, bool compute_max_only, TuneParam &tp, bool use_mma, const qudaStream_t &stream)
    {
      if (compute_max_only) {
        CalculateYhatCPU<true, Arg>(arg);
      } else {
        CalculateYhatCPU<false, Arg>(arg);
      }
    }
  };
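
  // note: the CPU path uses none of error, tp, use_mma or stream; those parameters exist
  // only to keep the interface uniform with the GPU specialization below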

  /**
     @brief Launcher for GPU instantiations of the preconditioned coarse-link construction
  */
  template <typename Arg> struct Launch<QUDA_CUDA_FIELD_LOCATION, Arg> {
    Launch(Arg &arg, CUresult &error, bool compute_max_only, TuneParam &tp, bool use_mma, const qudaStream_t &stream)
    {
      if (compute_max_only) {
        if (!activeTuning()) { // only zero the device-side max buffer once tuning is done
          qudaMemsetAsync(arg.max_d, 0, sizeof(typename Arg::Float), stream);
        }
      }

#ifdef JITIFY
      if (use_mma) {
        errorQuda("MMA kernels haven't been jitify'ed.");
      } else {
        using namespace jitify::reflection;
        error = program->kernel("quda::CalculateYhatGPU")
                  .instantiate(compute_max_only, Type<Arg>())
                  .configure(tp.grid, tp.block, tp.shared_bytes, stream)
                  .launch(arg);
      }
#else
      if (use_mma) {
        if (compute_max_only) {
          mma::launch_yhat_kernel<true>(arg, arg.Y.VolumeCB(), tp, stream);
        } else {
          mma::launch_yhat_kernel<false>(arg, arg.Y.VolumeCB(), tp, stream);
        }
      } else {
        if (compute_max_only) {
          qudaLaunchKernel(CalculateYhatGPU<true, Arg>, tp, stream, arg);
        } else {
          qudaLaunchKernel(CalculateYhatGPU<false, Arg>, tp, stream, arg);
        }
      }
#endif

      if (compute_max_only) {
        if (!activeTuning()) { // only do copy once tuning is done
          qudaMemcpyAsync(arg.max_h, arg.max_d, sizeof(typename Arg::Float), cudaMemcpyDeviceToHost, stream);
          qudaStreamSynchronize(const_cast<qudaStream_t &>(stream));
        }
      }
    }
  };
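
  /*
     Dispatch summary for the GPU launcher above: under JITIFY the kernel is compiled and
     launched at runtime (the MMA path is not supported there); otherwise we launch either
     the pre-instantiated MMA kernel or the regular CUDA kernel.  In every case the
     compute_max_only flag selects the variant that only reduces max |Yhat| into arg.max_d.
  */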

  template <QudaFieldLocation location, typename Arg>
  class CalculateYhat : public TunableVectorYZ {

    using Float = typename Arg::Float;
    Arg &arg;
    const LatticeField &meta;
    static constexpr int n = Arg::n; // matrix dimension of the coarse links (assumption: exposed by Arg)

    bool compute_max_only;
    bool use_mma;

    // 2 from parity, 8 from dir; each of the n x n output elements is a length-n complex
    // dot product costing 8n-2 flops
    long long flops() const { return 2l * arg.Y.VolumeCB() * 8 * n * n * (8 * n - 2); }
    long long bytes() const { return 2l * (arg.Xinv.Bytes() + 8 * arg.Y.Bytes() + !compute_max_only * 8 * arg.Yhat.Bytes()) * n; }

    unsigned int minThreads() const { return arg.Y.VolumeCB(); }
    bool tuneGridDim() const { return false; } // don't tune the grid dimension

    // all the tuning done is only in matrix tile size (Y/Z block.grid)
    int blockMin() const { return 8; }
    int blockStep() const { return 8; }
    unsigned int maxBlockSize(const TuneParam &param) const { return 8u; }
    bool tuneAuxDim() const { return use_mma; } // tune aux if doing mma

  public:
    CalculateYhat(Arg &arg, const LatticeField &meta, bool use_mma) :
      TunableVectorYZ(2 * arg.tile.M_tiles, 4 * arg.tile.N_tiles),
      arg(arg),
      meta(meta),
      compute_max_only(false),
      use_mma(use_mma)
    {
      if (meta.Location() == QUDA_CUDA_FIELD_LOCATION) {
#ifdef JITIFY
        create_jitify_program("kernels/coarse_op_preconditioned.cuh");
#endif
        arg.max_d = static_cast<Float *>(pool_device_malloc(sizeof(Float)));
      }
      arg.max_h = static_cast<Float *>(pool_pinned_malloc(sizeof(Float)));
      strcpy(aux, compile_type_str(meta));
      strcat(aux, comm_dim_partitioned_string());
    }

    virtual ~CalculateYhat()
    {
      if (meta.Location() == QUDA_CUDA_FIELD_LOCATION) {
        pool_device_free(arg.max_d);
      }
      pool_pinned_free(arg.max_h);
    }

    void apply(const qudaStream_t &stream)
    {
      TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
      Launch<location, Arg>(arg, jitify_error, compute_max_only, tp, use_mma, stream);
    }

    /**
       @brief Set whether we're doing a max-only compute (fixed-point precision only)
    */
    void setComputeMaxOnly(bool compute_max_only_) { compute_max_only = compute_max_only_; }
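
    /*
       Sketch of the intended two-pass flow for fixed-point output, which the driver
       further below follows:

         yHat.setComputeMaxOnly(true);  // pass 1: reduce global max |Yhat| into arg.max_h
         yHat.apply(0);
         // ... allreduce the max across ranks and reset the field scale ...
         yHat.setComputeMaxOnly(false); // pass 2: compute Yhat itself
         yHat.apply(0);
    */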

    bool advanceSharedBytes(TuneParam &param) const { return false; }

    bool advanceAux(TuneParam &param) const
    {
      if (use_mma) {
        constexpr bool compute_max_only_dummy = true;
        constexpr bool query_max = true;
        // in query mode the launcher only reports how many MMA configurations exist
        int max = mma::template launch_yhat_kernel<compute_max_only_dummy, query_max>(arg, 1, param, 0);
        if (param.aux.x < max) {
          param.aux.x++;
          return true;
        } else {
          param.aux.x = 0;
          return false;
        }
      } else {
        return false;
      }
    }

    bool advanceTuneParam(TuneParam &param) const
    {
      // only tune if the fields are in device memory
      if (meta.Location() == QUDA_CUDA_FIELD_LOCATION && meta.MemType() == QUDA_MEMORY_DEVICE)
        return Tunable::advanceTuneParam(param);
      else
        return false;
    }

    void initTuneParam(TuneParam &param) const
    {
      TunableVectorYZ::initTuneParam(param);
      param.aux = make_int4(0, 0, 0, 0);
    }

    void defaultTuneParam(TuneParam &param) const
    {
      TunableVectorYZ::defaultTuneParam(param);
      param.aux = make_int4(0, 0, 0, 0);
    }

    TuneKey tuneKey() const {
      char Aux[TuneKey::aux_n];
      strcpy(Aux, aux);
      if (compute_max_only) strcat(Aux, ",compute_max_only");
      if (meta.Location() == QUDA_CUDA_FIELD_LOCATION) {
        strcat(Aux, meta.MemType() == QUDA_MEMORY_MAPPED ? ",GPU-mapped" : ",GPU-device");
      } else if (meta.Location() == QUDA_CPU_FIELD_LOCATION) {
        strcat(Aux, ",CPU");
        strcat(Aux, getOmpThreadStr());
      }
      if (use_mma) { strcat(Aux, ",MMA"); }
      return TuneKey(meta.VolString(), typeid(*this).name(), Aux);
    }
  };
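
  /*
     Tuning summary for the class above: the x block dimension is pinned to 8 threads, the
     y/z dimensions sweep the 2 * M_tiles x 4 * N_tiles matrix tiling, and when MMA is
     enabled param.aux.x additionally sweeps the precompiled MMA configurations reported
     by mma::launch_yhat_kernel in query mode.
  */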

  /**
     @brief Calculate the preconditioned coarse-link field and the clover inverse.
     @param Yhat[out] Preconditioned coarse link field
     @param Xinv[out] Coarse clover inverse field
     @param Y[in] Coarse link field
     @param X[in] Coarse clover field
     @param use_mma[in] Whether or not to use MMA (tensor cores) for the calculation
  */
  template <QudaFieldLocation location, typename storeFloat, typename Float, int N, QudaGaugeFieldOrder gOrder>
  void calculateYhat(GaugeField &Yhat, GaugeField &Xinv, const GaugeField &Y, const GaugeField &X, bool use_mma)
  {
    using namespace blas_lapack;
    auto invert = use_native() ? native::BatchInvertMatrix : generic::BatchInvertMatrix;

    constexpr QudaGaugeFieldOrder gOrder_milc = QUDA_MILC_GAUGE_ORDER;
    GaugeField *Xinv_aos = nullptr;

    // invert the clover matrix field
    const int n = X.Ncolor();
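
    // BatchInvertMatrix inverts n x n complex matrices batched over the volume; which
    // backend sits behind the native/generic split is a build/runtime choice (assumption:
    // a device library such as cuBLAS for native, a host fallback for generic)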

    if (X.Location() == QUDA_CUDA_FIELD_LOCATION) {
      // batched inversion requires an AoS (MILC-order) field at single precision or
      // higher, so copy into a suitable temporary if X or Xinv does not already comply
      auto create_gauge_copy = [](const GaugeField &X, bool copy_content) -> auto {
        GaugeField *output = nullptr;
        if (X.Order() == gOrder_milc && X.Precision() >= QUDA_SINGLE_PRECISION) {
          output = const_cast<GaugeField *>(&X);
        } else {
          GaugeFieldParam param(X);
          param.order = gOrder_milc;
          param.setPrecision(X.Precision() < QUDA_SINGLE_PRECISION ? QUDA_SINGLE_PRECISION : X.Precision());
          output = cudaGaugeField::Create(param);
          if (copy_content) output->copy(X);
        }
        return output;
      };

      GaugeField *X_aos = create_gauge_copy(X, true);
      Xinv_aos = create_gauge_copy(Xinv, false);

      blas::flops += invert((void *)Xinv_aos->Gauge_p(), (void *)X_aos->Gauge_p(), n, X_aos->Volume(),
                            X_aos->Precision(), X.Location());

      if (&Xinv != Xinv_aos) {
        if (Xinv.Precision() < QUDA_SINGLE_PRECISION) Xinv.Scale(Xinv_aos->abs_max());
        Xinv.copy(*Xinv_aos);
      }

      if (&X != X_aos) { delete X_aos; }
      // the MMA path below consumes the AoS copy of Xinv, so only free it here otherwise
      // (and never free it if it aliases Xinv itself)
      if (!use_mma && Xinv_aos != &Xinv) { delete Xinv_aos; }

    } else if (X.Location() == QUDA_CPU_FIELD_LOCATION && X.Order() == QUDA_QDP_GAUGE_ORDER) {
      const cpuGaugeField *X_h = static_cast<const cpuGaugeField *>(&X);
      cpuGaugeField *Xinv_h = static_cast<cpuGaugeField *>(&Xinv);
      blas::flops
        += invert(*(void **)Xinv_h->Gauge_p(), *(void **)X_h->Gauge_p(), n, X_h->Volume(), X.Precision(), X.Location());
    } else {
      errorQuda("Unsupported location=%d and order=%d", X.Location(), X.Order());
    }

    // now exchange Y halos of both forwards and backwards links for multi-process dslash
    const_cast<GaugeField &>(Y).exchangeGhost(QUDA_LINK_BIDIRECTIONAL);

    // compute the preconditioned links
    // Yhat_back(x-\mu) = Y_back(x-\mu) * Xinv^dagger(x) (positive projector)
    // Yhat_fwd(x) = Xinv(x) * Y_fwd(x) (negative projector)

    int xc_size[5];
    for (int i = 0; i < 4; i++) xc_size[i] = X.X()[i];
    xc_size[4] = 1;

    if (use_mma) {

      // the MMA kernel requires AoS (MILC-order) fields, so copy Y and Yhat as needed
      auto create_gauge_copy = [](const GaugeField &X, QudaGaugeFieldOrder order, bool copy_content) -> auto {
        GaugeField *output = nullptr;
        if (X.Order() == order) {
          output = const_cast<GaugeField *>(&X);
        } else {
          GaugeFieldParam param(X);
          param.order = order;
          output = cudaGaugeField::Create(param);
          if (copy_content) output->copy(X);
        }
        return output;
      };

      GaugeField *Y_aos = create_gauge_copy(Y, gOrder_milc, true);
      GaugeField *Yhat_aos = create_gauge_copy(Yhat, gOrder_milc, false);

      constexpr bool use_native_ghosts = true;
      // use spin-ignorant accessor to make multiplication simpler
      typedef typename gauge::FieldOrder<Float, N, 1, gOrder_milc, use_native_ghosts, storeFloat> gCoarse;
      typedef typename gauge::FieldOrder<Float, N, 1, gOrder_milc, use_native_ghosts, storeFloat> gPreconditionedCoarse;
      gCoarse yAccessor(*Y_aos);
      gPreconditionedCoarse yHatAccessor(*Yhat_aos);

      // XXX: the Xinv accessor is hard-coded to single precision, so this path does not
      // work for double precision
      using gCoarseInv = gauge::FieldOrder<float, N, 1, gOrder_milc, use_native_ghosts, float>;
      gCoarseInv xInvAccessor(*Xinv_aos);
      if (getVerbosity() >= QUDA_VERBOSE) printfQuda("Xinv = %e\n", Xinv_aos->norm2(0));

      int comm_dim[4];
      for (int i = 0; i < 4; i++) comm_dim[i] = comm_dim_partitioned(i);

      using yHatArg = CalculateYhatArg<Float, gPreconditionedCoarse, gCoarse, gCoarseInv, N, 4, 2>;
      yHatArg arg(yHatAccessor, yAccessor, xInvAccessor, xc_size, comm_dim, 1);

      CalculateYhat<location, yHatArg> yHat(arg, Y, use_mma);
      if (Yhat.Precision() == QUDA_HALF_PRECISION || Yhat.Precision() == QUDA_QUARTER_PRECISION) {
        yHat.setComputeMaxOnly(true);
        yHat.apply(0);

        double max_h_double = *arg.max_h;
        comm_allreduce_max(&max_h_double);
        *arg.max_h = static_cast<Float>(max_h_double);

        if (getVerbosity() >= QUDA_VERBOSE) printfQuda("Yhat Max = %e\n", *arg.max_h);

        Yhat_aos->Scale(*arg.max_h);
        arg.Yhat.resetScale(*arg.max_h);

        yHat.setComputeMaxOnly(false);
      }
      yHat.apply(0);

      if (&Y != Y_aos) { delete Y_aos; }

      if (&Yhat != Yhat_aos) {
        Yhat.copy(*Yhat_aos);
        delete Yhat_aos;
      }

      if (Xinv_aos != &Xinv) { delete Xinv_aos; }

    } else {

      // use spin-ignorant accessor to make multiplication simpler
      typedef typename gauge::FieldOrder<Float, N, 1, gOrder, true, storeFloat> gCoarse;
      typedef typename gauge::FieldOrder<Float, N, 1, gOrder, true, storeFloat> gPreconditionedCoarse;
      gCoarse yAccessor(const_cast<GaugeField &>(Y));
      gPreconditionedCoarse yHatAccessor(const_cast<GaugeField &>(Yhat));
      gCoarse xInvAccessor(const_cast<GaugeField &>(Xinv));
      if (getVerbosity() >= QUDA_VERBOSE) printfQuda("Xinv = %e\n", Xinv.norm2(0));

      int comm_dim[4];
      for (int i = 0; i < 4; i++) comm_dim[i] = comm_dim_partitioned(i);

      typedef CalculateYhatArg<Float, gPreconditionedCoarse, gCoarse, gCoarse, N, 4, 2> yHatArg;
      yHatArg arg(yHatAccessor, yAccessor, xInvAccessor, xc_size, comm_dim, 1);

      CalculateYhat<location, yHatArg> yHat(arg, Y, use_mma);
      if (Yhat.Precision() == QUDA_HALF_PRECISION || Yhat.Precision() == QUDA_QUARTER_PRECISION) {
        yHat.setComputeMaxOnly(true);
        yHat.apply(0);

        double max_h_double = *arg.max_h;
        comm_allreduce_max(&max_h_double);
        *arg.max_h = static_cast<Float>(max_h_double);

        if (getVerbosity() >= QUDA_VERBOSE) printfQuda("Yhat Max = %e\n", *arg.max_h);

        Yhat.Scale(*arg.max_h);
        arg.Yhat.resetScale(*arg.max_h);

        yHat.setComputeMaxOnly(false);
      }
      yHat.apply(0);
    }
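
    // sanity output: for each of the 8 links print |Yhat|^2 and the abs-max, alongside
    // the rough scale estimate Y.abs_max(d) * Xinv.abs_max(0) for comparison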
    if (getVerbosity() >= QUDA_VERBOSE) {
      if (use_mma && X.Location() == QUDA_CUDA_FIELD_LOCATION)
        warningQuda("There is a known issue with Yhat norms 0 through 3 for CUDA+MMA builds. These are harmless and will be addressed in the future.");
      for (int d = 0; d < 8; d++)
        printfQuda("Yhat[%d] = %e (%e %e = %e x %e)\n", d, Yhat.norm2(d), Yhat.abs_max(d),
                   Y.abs_max(d) * Xinv.abs_max(0), Y.abs_max(d), Xinv.abs_max(0));
    }

    // fill back in the bulk of Yhat so that the backward link is updated on the previous node;
    // this needs to go into the bulk of the previous node - but only send the backwards
    // links backwards, and do not overwrite the forwards bulk
    Yhat.injectGhost(QUDA_LINK_BACKWARDS);

    // exchange forwards links for multi-process dslash dagger;
    // this needs to go into the ghost zone of the next node - but only send the forwards
    // links forwards, and do not overwrite the backwards ghost
    Yhat.exchangeGhost(QUDA_LINK_FORWARDS);
  }

  template <typename storeFloat, typename Float, int N>
  void calculateYhat(GaugeField &Yhat, GaugeField &Xinv, const GaugeField &Y, const GaugeField &X, bool use_mma)
  {
    if (Y.Location() == QUDA_CPU_FIELD_LOCATION) {
      constexpr QudaGaugeFieldOrder gOrder = QUDA_QDP_GAUGE_ORDER;
      if (Y.FieldOrder() != gOrder) errorQuda("Unsupported field order %d\n", Y.FieldOrder());
      calculateYhat<QUDA_CPU_FIELD_LOCATION, storeFloat, Float, N, gOrder>(Yhat, Xinv, Y, X, use_mma);
    } else {
      constexpr QudaGaugeFieldOrder gOrder = QUDA_FLOAT2_GAUGE_ORDER;
      // no field-order check here: the MMA path copies to MILC order internally
      // if (Y.FieldOrder() != gOrder) errorQuda("Unsupported field order %d\n", Y.FieldOrder());
      calculateYhat<QUDA_CUDA_FIELD_LOCATION, storeFloat, Float, N, gOrder>(Yhat, Xinv, Y, X, use_mma);
    }
  }

  // template on the number of coarse degrees of freedom
  template <typename storeFloat, typename Float>
  void calculateYhat(GaugeField &Yhat, GaugeField &Xinv, const GaugeField &Y, const GaugeField &X, bool use_mma)
  {
    switch (Y.Ncolor()) {
    case 48: calculateYhat<storeFloat, Float, 48>(Yhat, Xinv, Y, X, use_mma); break;
#ifdef NSPIN4
    case 12: calculateYhat<storeFloat, Float, 12>(Yhat, Xinv, Y, X, use_mma); break;
    case 64: calculateYhat<storeFloat, Float, 64>(Yhat, Xinv, Y, X, use_mma); break;
#endif // NSPIN4
#ifdef NSPIN1
    case 128: calculateYhat<storeFloat, Float, 128>(Yhat, Xinv, Y, X, use_mma); break;
    case 192: calculateYhat<storeFloat, Float, 192>(Yhat, Xinv, Y, X, use_mma); break;
#endif // NSPIN1
    default: errorQuda("Unsupported number of coarse dof %d\n", Y.Ncolor()); break;
    }
  }

  // Calculate the preconditioned coarse-link field Yhat and the coarse clover inverse Xinv
  void calculateYhat(GaugeField &Yhat, GaugeField &Xinv, const GaugeField &Y, const GaugeField &X, bool use_mma)
  {
#ifdef GPU_MULTIGRID
    QudaPrecision precision = checkPrecision(Xinv, Y, X);
    if (getVerbosity() >= QUDA_SUMMARIZE) printfQuda("Computing Yhat field......\n");

    if (precision == QUDA_DOUBLE_PRECISION) {
#ifdef GPU_MULTIGRID_DOUBLE
      if (Yhat.Precision() != QUDA_DOUBLE_PRECISION) errorQuda("Unsupported precision %d\n", Yhat.Precision());
      if (use_mma) errorQuda("MG-MMA does not support double precision, yet.");
      calculateYhat<double, double>(Yhat, Xinv, Y, X, use_mma);
#else
      errorQuda("Double precision multigrid has not been enabled");
#endif
    } else if (precision == QUDA_SINGLE_PRECISION) {
      if (Yhat.Precision() == QUDA_SINGLE_PRECISION) {
        calculateYhat<float, float>(Yhat, Xinv, Y, X, use_mma);
      } else {
        errorQuda("Unsupported precision %d\n", precision);
      }
    } else if (precision == QUDA_HALF_PRECISION) {
      if (Yhat.Precision() == QUDA_HALF_PRECISION) {
        calculateYhat<short, float>(Yhat, Xinv, Y, X, use_mma);
      } else {
        errorQuda("Unsupported precision %d\n", precision);
      }
    } else {
      errorQuda("Unsupported precision %d\n", precision);
    }

    if (getVerbosity() >= QUDA_SUMMARIZE) printfQuda("....done computing Yhat field\n");
#else
    errorQuda("Multigrid has not been built");
#endif
  }

} // namespace quda