1 #include <quda_internal.h>
3 #include <gauge_field.h>
5 #include <jitify_helper.cuh>
6 #include <kernels/gauge_wilson_flow.cuh>
7 #include <instantiate.h>
/**
   @brief Tunable launcher for one stage of the gauge-field flow update.

   Constructing an instance builds the kernel argument struct, composes
   the autotuning aux string from the flow type and the step type,
   launches the kernel and then synchronizes the device.  Callers
   therefore only need to construct the object.

   @tparam Float  Forwarded to GaugeWFlowArg (presumably the field precision type)
   @tparam nColor Forwarded to GaugeWFlowArg; also used in the flop estimate
   @tparam recon  Forwarded to GaugeWFlowArg (gauge reconstruction type)
*/
template <typename Float, int nColor, QudaReconstructType recon>
class GaugeWFlowStep : TunableVectorYZ
{
  static constexpr int wflow_dim = 4; // apply flow in all dims
  GaugeWFlowArg<Float, nColor, recon, wflow_dim> arg;
  const GaugeField &meta; // supplies the aux and volume strings for tuning

  // Autotuner policy: shared memory and grid size are not tuned, and the
  // block size is searched from 8 to 32 in increments of 8.
  bool tuneSharedBytes() const { return false; }
  bool tuneGridDim() const { return false; }
  unsigned int minThreads() const { return arg.threads; }
  unsigned int maxBlockSize(const TuneParam &param) const { return 32; }
  int blockStep() const { return 8; }
  int blockMin() const { return 8; }

public:
  /**
     @brief Set up and immediately run one flow stage.
     @param[out] out Field written by this stage
     @param[in,out] temp Temporary field shared between the stages
     @param[in] in Field read by this stage; also used as the tuning metadata
     @param[in] epsilon Flow step size, passed through to the kernel argument
     @param[in] wflow_type Wilson or Symanzik flow
     @param[in] step_type Which stage (W1, W2 or VT) to run
  */
  GaugeWFlowStep(GaugeField &out, GaugeField &temp, const GaugeField &in, const double epsilon,
                 const QudaWFlowType wflow_type, const WFlowStepType step_type) :
    TunableVectorYZ(2, wflow_dim),
    arg(out, temp, in, epsilon, wflow_type, step_type),
    meta(in)
  {
    strcpy(aux, meta.AuxString());
    strcat(aux, comm_dim_partitioned_string());

    // the flow type and the step type both select a distinct kernel
    // instantiation, so both are encoded into the tune key
    switch (wflow_type) {
    case QUDA_WFLOW_TYPE_WILSON: strcat(aux, ",computeWFlowStepWilson"); break;
    case QUDA_WFLOW_TYPE_SYMANZIK: strcat(aux, ",computeWFlowStepSymanzik"); break;
    default: errorQuda("Unknown Wilson Flow type %d", wflow_type);
    }

    switch (step_type) {
    case WFLOW_STEP_W1: strcat(aux, "_W1"); break;
    case WFLOW_STEP_W2: strcat(aux, "_W2"); break;
    case WFLOW_STEP_VT: strcat(aux, "_VT"); break;
    default: errorQuda("Unknown Wilson Flow step type %d", step_type);
    }

#ifdef JITIFY
    create_jitify_program("kernels/gauge_wilson_flow.cuh");
#endif
    apply(0);
    qudaDeviceSynchronize();
  }

  /**
     @brief Launch the kernel matching arg.wflow_type and arg.step_type
     with the launch parameters returned by the autotuner.
     @param[in] stream Stream the kernel is launched on
  */
  void apply(const qudaStream_t &stream)
  {
    TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
#ifdef JITIFY
    using namespace jitify::reflection;
    jitify_error = program->kernel("quda::computeWFlowStep")
                     .instantiate(arg.wflow_type, arg.step_type, Type<decltype(arg)>())
                     .configure(tp.grid, tp.block, tp.shared_bytes, stream)
                     .launch(arg);
#else
    // map the run-time enums onto the compile-time template parameters
    switch (arg.wflow_type) {
    case QUDA_WFLOW_TYPE_WILSON:
      switch (arg.step_type) {
      case WFLOW_STEP_W1: qudaLaunchKernel(computeWFlowStep<QUDA_WFLOW_TYPE_WILSON, WFLOW_STEP_W1, decltype(arg)>, tp, stream, arg); break;
      case WFLOW_STEP_W2: qudaLaunchKernel(computeWFlowStep<QUDA_WFLOW_TYPE_WILSON, WFLOW_STEP_W2, decltype(arg)>, tp, stream, arg); break;
      case WFLOW_STEP_VT: qudaLaunchKernel(computeWFlowStep<QUDA_WFLOW_TYPE_WILSON, WFLOW_STEP_VT, decltype(arg)>, tp, stream, arg); break;
      }
      break;
    case QUDA_WFLOW_TYPE_SYMANZIK:
      switch (arg.step_type) {
      case WFLOW_STEP_W1: qudaLaunchKernel(computeWFlowStep<QUDA_WFLOW_TYPE_SYMANZIK, WFLOW_STEP_W1, decltype(arg)>, tp, stream, arg); break;
      case WFLOW_STEP_W2: qudaLaunchKernel(computeWFlowStep<QUDA_WFLOW_TYPE_SYMANZIK, WFLOW_STEP_W2, decltype(arg)>, tp, stream, arg); break;
      case WFLOW_STEP_VT: qudaLaunchKernel(computeWFlowStep<QUDA_WFLOW_TYPE_SYMANZIK, WFLOW_STEP_VT, decltype(arg)>, tp, stream, arg); break;
      }
      break;
    default: errorQuda("Unknown Wilson Flow type %d", arg.wflow_type);
    }
#endif
  }

  /** @brief Tuning key: field volume string, concrete class name and aux string. */
  TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), aux); }

  /** @brief Back up the fields the kernel writes before the autotuner's trial launches. */
  void preTune()
  {
    arg.out.save(); // defensive measure in case out aliases in
    arg.temp.save();
  }

  /** @brief Restore the fields backed up by preTune(). */
  void postTune()
  {
    arg.out.load();
    arg.temp.load();
  }

  /** @brief Flop estimate for one launch. */
  long long flops() const
  {
    // only counts number of mat-muls per thread
    long long threads = 2ll * arg.threads * wflow_dim;
    long long mat_flops = arg.nColor * arg.nColor * (8 * arg.nColor - 2);
    long long mat_muls = 1; // 1 comes from Z * conj(U) term
    switch (arg.wflow_type) {
    case QUDA_WFLOW_TYPE_WILSON: mat_muls += 4 * (wflow_dim - 1); break;
    case QUDA_WFLOW_TYPE_SYMANZIK: mat_muls += 28 * (wflow_dim - 1); break;
    default: errorQuda("Unknown Wilson Flow type");
    }
    return mat_muls * mat_flops * threads;
  }

  /** @brief Memory-traffic estimate for one launch. */
  long long bytes() const
  {
    int links = 0; // links counted per orthogonal direction
    switch (arg.wflow_type) {
    case QUDA_WFLOW_TYPE_WILSON: links = 6; break;
    case QUDA_WFLOW_TYPE_SYMANZIK: links = 24; break;
    default: errorQuda("Unknown Wilson Flow type");
    }
    // temp accesses counted per site: none for W1, two for W2, one for VT
    auto temp_io = arg.step_type == WFLOW_STEP_W2 ? 2 : arg.step_type == WFLOW_STEP_VT ? 1 : 0;
    return ((1 + (wflow_dim - 1) * links) * arg.in.Bytes() + arg.out.Bytes() + temp_io * arg.temp.Bytes()) * 2ll
      * arg.threads * wflow_dim;
  }
}; // GaugeWFlowStep
/**
   @brief Advance a gauge field by one flow step, executed as three
   consecutive kernel stages (W1, W2, VT) with a halo exchange after each.

   @param[out] out Receives the result of the W1 stage and of the final VT stage
   @param[in,out] temp Temporary field shared by the stages; must not use reconstruction
   @param[in,out] in Input field; note that it is overwritten by the W2 stage
   @param[in] epsilon Flow step size
   @param[in] wflow_type Wilson or Symanzik flow

   Raises an error through errorQuda if the fields fail the precision,
   reconstruct or native-order checks, or if the library was built
   without GPU_GAUGE_TOOLS.
*/
void WFlowStep(GaugeField &out, GaugeField &temp, GaugeField &in, const double epsilon, const QudaWFlowType wflow_type)
{
#ifdef GPU_GAUGE_TOOLS
  checkPrecision(out, temp, in);
  checkReconstruct(out, in);
  if (temp.Reconstruct() != QUDA_RECONSTRUCT_NO) errorQuda("Temporary vector must not use reconstruct");
  // each message reports the field that actually failed the check
  if (!out.isNative()) errorQuda("Order %d with %d reconstruct not supported", out.Order(), out.Reconstruct());
  if (!in.isNative()) errorQuda("Order %d with %d reconstruct not supported", in.Order(), in.Reconstruct());

  // Set each step type as an arg parameter, update halos if needed
  // Stage W1: reads in, writes out
  instantiate<GaugeWFlowStep,WilsonReconstruct>(out, temp, in, epsilon, wflow_type, WFLOW_STEP_W1);
  out.exchangeExtendedGhost(out.R(), false);

  // Stage W2: reads out, writes in
  instantiate<GaugeWFlowStep,WilsonReconstruct>(in, temp, out, epsilon, wflow_type, WFLOW_STEP_W2);
  in.exchangeExtendedGhost(in.R(), false);

  // Stage VT: reads in, writes out
  instantiate<GaugeWFlowStep,WilsonReconstruct>(out, temp, in, epsilon, wflow_type, WFLOW_STEP_VT);
  out.exchangeExtendedGhost(out.R(), false);
#else
  errorQuda("Gauge tools are not built");
#endif
}