1 #include <quda_internal.h>
3 #include <gauge_field.h>
5 #include <jitify_helper.cuh>
6 #include <kernels/gauge_ape.cuh>
7 #include <instantiate.h>
// Host-side tunable wrapper that launches the computeAPEStep kernel.
// Template parameters (Float, nColor, recon) are forwarded unchanged to the
// GaugeAPEArg kernel-argument type. Derives from TunableVectorYZ, which is
// constructed below with (2, apeDim).
11 template <typename Float, int nColor, QudaReconstructType recon> class GaugeAPE : TunableVectorYZ
13 static constexpr int apeDim = 3; // apply APE in space only
// Kernel argument bundle; this is the object handed to the kernel at launch.
14 GaugeAPEArg<Float,nColor,recon, apeDim> arg;
// Field used only for its AuxString()/VolString() when building the tuning key.
15 const GaugeField &meta;
17 bool tuneGridDim() const { return false; } // Don't tune the grid dimensions.
// Lower bound on the thread count reported to the tuner: arg.threads.
18 unsigned int minThreads() const { return arg.threads; }
21 // (2,3): 2 for parity in the y thread dim, 3 corresponds to mapping direction to the z thread dim
// Constructor: sets up the tuning aux string, registers the Jitify program
// and synchronizes the device before returning.
// NOTE(review): alpha is not used in any statement visible here; presumably it
// is forwarded to `arg` by a member initialiser -- confirm.
22 GaugeAPE(GaugeField &out, const GaugeField &in, double alpha) :
23 TunableVectorYZ(2, apeDim),
// aux = meta.AuxString() followed by the comm-partitioning string; aux is part
// of the TuneKey returned by tuneKey().
27 strcpy(aux, meta.AuxString());
28 strcat(aux, comm_dim_partitioned_string());
30 create_jitify_program("kernels/gauge_ape.cuh");
33 qudaDeviceSynchronize();
// Launches computeAPEStep on `stream` using the launch parameters returned by
// tuneLaunch (grid, block and shared-memory size).
36 void apply(const qudaStream_t &stream)
38 TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
// Jitify launch path: instantiate the kernel on the type of `arg`, configure
// it with the tuned parameters and launch; the result is stored in jitify_error.
// NOTE(review): presumably this path and the qudaLaunchKernel call below are
// build-time alternatives rather than both executed -- confirm.
40 using namespace jitify::reflection;
41 jitify_error = program->kernel("quda::computeAPEStep").instantiate(Type<decltype(arg)>())
42 .configure(tp.grid, tp.block, tp.shared_bytes, stream).launch(arg);
// Direct launch path for the same kernel with the same tuned parameters.
44 qudaLaunchKernel(computeAPEStep<decltype(arg)>, tp, stream, arg);
// Tuning-cache key: volume string, the dynamic type name of this object, and aux.
48 TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), aux); }
// The output field is saved before tuning and restored afterwards.
50 void preTune() { arg.out.save(); } // defensive measure in case they alias
51 void postTune() { arg.out.load(); }
// Performance-model estimates reported to the tuner; both scale with arg.threads.
53 long long flops() const { return apeDim * (2 + 2 * 4) * 198ll * arg.threads; } // just counts matrix multiplication
54 long long bytes() const { return ((1 + 6 * apeDim) * arg.in.Bytes() + arg.out.Bytes()) * arg.threads; } // 6 links per dim, 1 in, 1 out.
/**
 * @brief Host entry point for one APE step on a pair of gauge fields.
 *
 * When built with GPU_GAUGE_TOOLS the function:
 *  1. checks that `out` and `in` agree in precision and reconstruct type;
 *  2. rejects either field if it is not in a native order, reporting the
 *     order/reconstruct of the field that actually failed the check;
 *  3. calls copyExtendedGauge(in, out, QUDA_CUDA_FIELD_LOCATION)
 *     (presumably destination-first, i.e. `out` is copied into `in` --
 *     confirm against the copyExtendedGauge declaration);
 *  4. exchanges the extended ghost zone of `in`;
 *  5. dispatches GaugeAPE through instantiate<>;
 *  6. exchanges the extended ghost zone of `out`.
 *
 * Without GPU_GAUGE_TOOLS it only raises an error via errorQuda.
 *
 * @param out   Field handed to GaugeAPE as its output; its ghost zone is
 *              exchanged after the launch.
 * @param in    Field handed to GaugeAPE as its input; it is overwritten by the
 *              copy in step 3 and its ghost zone exchanged before the launch.
 * @param alpha Coefficient forwarded unchanged to GaugeAPE.
 */
void APEStep(GaugeField &out, GaugeField& in, double alpha) {
#ifdef GPU_GAUGE_TOOLS
  checkPrecision(out, in);
  checkReconstruct(out, in);

  // Each diagnostic must describe the field whose check failed; the two
  // messages previously reported the other field's order and reconstruct.
  if (!out.isNative()) errorQuda("Order %d with %d reconstruct not supported", out.Order(), out.Reconstruct());
  if (!in.isNative()) errorQuda("Order %d with %d reconstruct not supported", in.Order(), in.Reconstruct());

  copyExtendedGauge(in, out, QUDA_CUDA_FIELD_LOCATION);
  in.exchangeExtendedGhost(in.R(), false); // halo of the input must be current before the kernel reads it
  instantiate<GaugeAPE>(out, in, alpha);
  out.exchangeExtendedGhost(out.R(), false); // refresh the halo of the freshly written output
#else
  errorQuda("Gauge tools are not built");
#endif
}