1 #include <quda_internal.h>
3 #include <gauge_field.h>
5 #include <jitify_helper.cuh>
6 #include <kernels/gauge_stout.cuh>
7 #include <instantiate.h>
template <typename Float, int nColor, QudaReconstructType recon> class GaugeSTOUT : TunableVectorYZ
{
  static constexpr int stoutDim = 3; // apply stouting in space only
  GaugeSTOUTArg<Float, nColor, recon, stoutDim> arg;
  const GaugeField &meta;

  bool tuneGridDim() const { return false; } // Don't tune the grid dimensions.
  unsigned int minThreads() const { return arg.threads; }

public:
  /**
     @brief Tunable wrapper that runs one spatial stout-smearing step.
     The kernel is launched (and the device synchronized) from the constructor.
     @param[out] out Field written by the kernel through arg.out
     @param[in] in Field read by the kernel through arg.in; also supplies the
                   volume/aux strings used in the tune key
     @param[in] rho Smearing parameter forwarded to GaugeSTOUTArg
  */
  // (2, stoutDim): 2 for parity in the y thread dim, stoutDim (= 3) corresponds to mapping direction to the z thread dim
  GaugeSTOUT(GaugeField &out, const GaugeField &in, double rho) :
    TunableVectorYZ(2, stoutDim),
    arg(out, in, rho),
    meta(in)
  {
    strcpy(aux, meta.AuxString());
    strcat(aux, comm_dim_partitioned_string());
#ifdef JITIFY
    create_jitify_program("kernels/gauge_stout.cuh");
#endif
    apply(0);
    qudaDeviceSynchronize();
  }

  /**
     @brief Launch computeSTOUTStep with the tuned launch parameters.
     @param[in] stream Stream to launch on
  */
  void apply(const qudaStream_t &stream)
  {
    TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
#ifdef JITIFY
    using namespace jitify::reflection;
    jitify_error = program->kernel("quda::computeSTOUTStep").instantiate(Type<decltype(arg)>())
      .configure(tp.grid, tp.block, tp.shared_bytes, stream).launch(arg);
#else
    qudaLaunchKernel(computeSTOUTStep<decltype(arg)>, tp, stream, arg);
#endif
  }

  // Key: local volume string, this class's type name, and aux (field aux string + comm partitioning).
  TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), aux); }

  void preTune() { arg.out.save(); } // defensive measure in case they alias
  void postTune() { arg.out.load(); }

  // Just counts matrix multiplications (198 flops per 3x3 complex product); the leading
  // factor is the number of stouted directions. The ll literal keeps the product 64-bit.
  long long flops() const { return stoutDim * (2 + 2 * 4) * 198ll * arg.threads; }

  // 6 links per stouted dim plus the link itself are read, 1 link is written.
  // The 1ll promotes the whole product to 64-bit (as 198ll does in flops()) so it cannot
  // overflow a 32-bit intermediate for large local volumes.
  long long bytes() const { return ((1ll + stoutDim * 6) * arg.in.Bytes() + arg.out.Bytes()) * arg.threads; }
}; // GaugeSTOUT
/**
   @brief Apply one stout-smearing step using GaugeSTOUT (stoutDim = 3, space only).
   @param[out] out Field receiving the smeared result; its ghost zones are exchanged afterwards
   @param[in,out] in Working field: copyExtendedGauge(in, out, ...) is called on it and its
                  ghost zones are exchanged before the kernel runs (presumably `in` is the
                  destination, i.e. it receives a copy of `out` -- confirm argument order)
   @param[in] rho Smearing parameter
   Errors out if precisions or reconstructs differ, if either field is not in a native
   order, or if QUDA was built without GPU_GAUGE_TOOLS.
*/
void STOUTStep(GaugeField &out, GaugeField &in, double rho)
{
#ifdef GPU_GAUGE_TOOLS
  checkPrecision(out, in);
  checkReconstruct(out, in);

  // Report the order/reconstruct of the field that actually failed the native check.
  if (!out.isNative()) errorQuda("Order %d with %d reconstruct not supported", out.Order(), out.Reconstruct());
  if (!in.isNative()) errorQuda("Order %d with %d reconstruct not supported", in.Order(), in.Reconstruct());

  copyExtendedGauge(in, out, QUDA_CUDA_FIELD_LOCATION);
  in.exchangeExtendedGhost(in.R(), false);
  instantiate<GaugeSTOUT>(out, in, rho);
  out.exchangeExtendedGhost(out.R(), false);
#else
  errorQuda("Gauge tools are not built");
#endif
}
template <typename Float, int nColor, QudaReconstructType recon> class GaugeOvrImpSTOUT : TunableVectorYZ
{
  static constexpr int stoutDim = 4; // apply stouting in all dims
  GaugeSTOUTArg<Float, nColor, recon, stoutDim> arg;
  const GaugeField &meta;

  bool tuneGridDim() const { return false; } // Don't tune the grid dimensions.
  unsigned int minThreads() const { return arg.threads; }

public:
  /**
     @brief Tunable wrapper that runs one over-improved stout step in all four dimensions.
     The kernel is launched (and the device synchronized) from the constructor.
     @param[out] out Field written by the kernel through arg.out
     @param[in] in Field read by the kernel through arg.in; also supplies the
                   volume/aux strings used in the tune key
     @param[in] rho Smearing parameter forwarded to GaugeSTOUTArg
     @param[in] epsilon Coefficient forwarded to GaugeSTOUTArg (presumably the
                        over-improvement parameter -- confirm in kernels/gauge_stout.cuh)
  */
  // (2, stoutDim): 2 for parity in the y thread dim, stoutDim (= 4) corresponds to mapping direction to the z thread dim
  GaugeOvrImpSTOUT(GaugeField &out, const GaugeField &in, double rho, double epsilon) :
    TunableVectorYZ(2, stoutDim),
    arg(out, in, rho, epsilon),
    meta(in)
  {
    strcpy(aux, meta.AuxString());
    strcat(aux, comm_dim_partitioned_string());
#ifdef JITIFY
    create_jitify_program("kernels/gauge_stout.cuh");
#endif
    apply(0);
    qudaDeviceSynchronize();
  }

  /**
     @brief Launch computeOvrImpSTOUTStep with the tuned launch parameters.
     @param[in] stream Stream to launch on
  */
  void apply(const qudaStream_t &stream)
  {
    TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
#ifdef JITIFY
    using namespace jitify::reflection;
    jitify_error = program->kernel("quda::computeOvrImpSTOUTStep").instantiate(Type<decltype(arg)>())
      .configure(tp.grid, tp.block, tp.shared_bytes, stream).launch(arg);
#else
    qudaLaunchKernel(computeOvrImpSTOUTStep<decltype(arg)>, tp, stream, arg);
#endif
  }

  // Key: local volume string, this class's type name, and aux (field aux string + comm partitioning).
  TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), aux); }

  void preTune() { arg.out.save(); } // defensive measure in case they alias
  void postTune() { arg.out.load(); }

  // Just counts matrix multiplications (198 flops per 3x3 complex product); the leading
  // factor is the number of stouted directions. The ll literal keeps the product 64-bit.
  long long flops() const { return stoutDim * (18 + 2 + 2 * 4) * 198ll * arg.threads; }

  // 24 links per stouted dim plus the link itself are read, 1 link is written.
  // The 1ll promotes the whole product to 64-bit (as 198ll does in flops()) so it cannot
  // overflow a 32-bit intermediate for large local volumes.
  long long bytes() const { return ((1ll + stoutDim * 24) * arg.in.Bytes() + arg.out.Bytes()) * arg.threads; }
}; // GaugeOvrImpSTOUT
/**
   @brief Apply one over-improved stout-smearing step using GaugeOvrImpSTOUT (stoutDim = 4).
   @param[out] out Field receiving the smeared result; its ghost zones are exchanged afterwards
   @param[in,out] in Working field: copyExtendedGauge(in, out, ...) is called on it and its
                  ghost zones are exchanged before the kernel runs (presumably `in` is the
                  destination, i.e. it receives a copy of `out` -- confirm argument order)
   @param[in] rho Smearing parameter
   @param[in] epsilon Coefficient forwarded to the kernel argument struct
   Errors out if precisions or reconstructs differ, if either field is not in a native
   order, or if QUDA was built without GPU_GAUGE_TOOLS.
*/
void OvrImpSTOUTStep(GaugeField &out, GaugeField &in, double rho, double epsilon)
{
#ifdef GPU_GAUGE_TOOLS
  checkPrecision(out, in);
  checkReconstruct(out, in);

  // Report the order/reconstruct of the field that actually failed the native check.
  if (!out.isNative()) errorQuda("Order %d with %d reconstruct not supported", out.Order(), out.Reconstruct());
  if (!in.isNative()) errorQuda("Order %d with %d reconstruct not supported", in.Order(), in.Reconstruct());

  copyExtendedGauge(in, out, QUDA_CUDA_FIELD_LOCATION);
  in.exchangeExtendedGhost(in.R(), false);
  instantiate<GaugeOvrImpSTOUT>(out, in, rho, epsilon);
  out.exchangeExtendedGhost(out.R(), false);
#else
  errorQuda("Gauge tools are not built");
#endif
}