QUDA  v1.1.0
A library for QCD on GPUs
gauge_wilson_flow.cu
Go to the documentation of this file.
1 #include <quda_internal.h>
2 #include <tune_quda.h>
3 #include <gauge_field.h>
4 
5 #include <jitify_helper.cuh>
6 #include <kernels/gauge_wilson_flow.cuh>
7 #include <instantiate.h>
8 
9 namespace quda {
10 
/**
   @brief Tunable launcher for one sub-step of the gauge flow update.

   Constructing an instance immediately launches the kernel selected
   by (wflow_type, step_type) on stream 0 and then synchronizes the
   device, so the object is used purely for its constructor's side
   effects.

   @tparam Float Precision type forwarded to GaugeWFlowArg
   @tparam nColor Number of colors forwarded to GaugeWFlowArg
   @tparam recon Reconstruction type forwarded to GaugeWFlowArg
*/
template <typename Float, int nColor, QudaReconstructType recon>
class GaugeWFlowStep : TunableVectorYZ
{
  static constexpr int wflow_dim = 4; // apply flow in all dims
  GaugeWFlowArg<Float, nColor, recon, wflow_dim> arg;
  const GaugeField &meta; // input field; supplies the aux and volume strings used for the tune key

  // Autotuner policy: shared memory and grid size are not tuned; the
  // block size is searched from 8 up to 32 in steps of 8.
  bool tuneSharedBytes() const { return false; }
  bool tuneGridDim() const { return false; }
  unsigned int minThreads() const { return arg.threads; }
  unsigned int maxBlockSize(const TuneParam &param) const { return 32; }
  int blockStep() const { return 8; }
  int blockMin() const { return 8; }

public:
  /**
     @brief Build the kernel argument and tuning key, then run the step.
     @param[out] out Field handed to the kernel argument as output
     @param[in,out] temp Temporary field handed to the kernel argument
     @param[in] in Field handed to the kernel argument as input
     @param[in] epsilon Step size forwarded to the kernel argument
     @param[in] wflow_type Flow action type (Wilson or Symanzik)
     @param[in] step_type Which sub-step (W1, W2 or VT) to execute

     Calls errorQuda if wflow_type or step_type is not one of the
     values handled below.
  */
  GaugeWFlowStep(GaugeField &out, GaugeField &temp, const GaugeField &in, const double epsilon,
                 const QudaWFlowType wflow_type, const WFlowStepType step_type) :
    // NOTE(review): (2, wflow_dim) presumably sets the y/z thread extents
    // to parity x direction -- confirm against TunableVectorYZ
    TunableVectorYZ(2, wflow_dim),
    arg(out, temp, in, epsilon, wflow_type, step_type),
    meta(in)
  {
    // Every (flow type, step type) pair gets a distinct aux string and
    // hence a distinct entry in the tune cache.
    strcpy(aux, meta.AuxString());
    strcat(aux, comm_dim_partitioned_string());
    switch (wflow_type) {
    case QUDA_WFLOW_TYPE_WILSON: strcat(aux, ",computeWFlowStepWilson"); break;
    case QUDA_WFLOW_TYPE_SYMANZIK: strcat(aux, ",computeWFlowStepSymanzik"); break;
    default: errorQuda("Unknown Wilson Flow type %d", wflow_type);
    }
    switch (step_type) {
    case WFLOW_STEP_W1: strcat(aux, "_W1"); break;
    case WFLOW_STEP_W2: strcat(aux, "_W2"); break;
    case WFLOW_STEP_VT: strcat(aux, "_VT"); break;
    default: errorQuda("Unknown Wilson Flow step type %d", step_type);
    }

#ifdef JITIFY
    create_jitify_program("kernels/gauge_wilson_flow.cuh");
#endif
    apply(0);
    qudaDeviceSynchronize();
  }

  /**
     @brief Launch the kernel matching arg.wflow_type / arg.step_type
     with the launch parameters returned by tuneLaunch.
     @param[in] stream Stream on which the kernel is launched
  */
  void apply(const qudaStream_t &stream)
  {
    TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
#ifdef JITIFY
    using namespace jitify::reflection;
    jitify_error = program->kernel("quda::computeWFlowStep")
                     .instantiate(arg.wflow_type, arg.step_type, Type<decltype(arg)>())
                     .configure(tp.grid, tp.block, tp.shared_bytes, stream)
                     .launch(arg);
#else
    // The flow type and step type are compile-time template parameters
    // of the kernel, so the run-time values are dispatched by hand.
    switch (arg.wflow_type) {
    case QUDA_WFLOW_TYPE_WILSON:
      switch (arg.step_type) {
      case WFLOW_STEP_W1: qudaLaunchKernel(computeWFlowStep<QUDA_WFLOW_TYPE_WILSON, WFLOW_STEP_W1, decltype(arg)>, tp, stream, arg); break;
      case WFLOW_STEP_W2: qudaLaunchKernel(computeWFlowStep<QUDA_WFLOW_TYPE_WILSON, WFLOW_STEP_W2, decltype(arg)>, tp, stream, arg); break;
      case WFLOW_STEP_VT: qudaLaunchKernel(computeWFlowStep<QUDA_WFLOW_TYPE_WILSON, WFLOW_STEP_VT, decltype(arg)>, tp, stream, arg); break;
      // fail loudly rather than silently skipping the launch
      default: errorQuda("Unknown Wilson Flow step type %d", arg.step_type);
      }
      break;
    case QUDA_WFLOW_TYPE_SYMANZIK:
      switch (arg.step_type) {
      case WFLOW_STEP_W1: qudaLaunchKernel(computeWFlowStep<QUDA_WFLOW_TYPE_SYMANZIK, WFLOW_STEP_W1, decltype(arg)>, tp, stream, arg); break;
      case WFLOW_STEP_W2: qudaLaunchKernel(computeWFlowStep<QUDA_WFLOW_TYPE_SYMANZIK, WFLOW_STEP_W2, decltype(arg)>, tp, stream, arg); break;
      case WFLOW_STEP_VT: qudaLaunchKernel(computeWFlowStep<QUDA_WFLOW_TYPE_SYMANZIK, WFLOW_STEP_VT, decltype(arg)>, tp, stream, arg); break;
      // fail loudly rather than silently skipping the launch
      default: errorQuda("Unknown Wilson Flow step type %d", arg.step_type);
      }
      break;
    default: errorQuda("Unknown Wilson Flow type %d", arg.wflow_type);
    }
#endif
  }

  /** @brief Tune-cache key: volume string, dynamic type name and aux string. */
  TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), aux); }

  /**
     @brief Back up the out and temp fields before tuning.
     NOTE(review): presumably needed because the tuner launches the
     kernel repeatedly -- confirm against the Tunable base class.
  */
  void preTune()
  {
    arg.out.save(); // defensive measure in case out aliases in
    arg.temp.save();
  }

  /** @brief Restore the out and temp fields backed up by preTune(). */
  void postTune()
  {
    arg.out.load();
    arg.temp.load();
  }

  /**
     @brief Flop estimate for one launch.
     @return mat_muls * mat_flops * threads, with the per-thread
     matrix-multiply count chosen by the flow type
  */
  long long flops() const
  {
    // only counts number of mat-muls per thread
    long long threads = 2ll * arg.threads * wflow_dim;
    long long mat_flops = arg.nColor * arg.nColor * (8 * arg.nColor - 2);
    long long mat_muls = 1; // 1 comes from Z * conj(U) term
    switch (arg.wflow_type) {
    case QUDA_WFLOW_TYPE_WILSON: mat_muls += 4 * (wflow_dim - 1); break;
    case QUDA_WFLOW_TYPE_SYMANZIK: mat_muls += 28 * (wflow_dim - 1); break;
    default: errorQuda("Unknown Wilson Flow type");
    }
    return mat_muls * mat_flops * threads;
  }

  /**
     @brief Memory-traffic estimate for one launch.
     @return Per-thread bytes (input links, one output link, plus any
     temp-field traffic) multiplied by 2 * arg.threads * wflow_dim
  */
  long long bytes() const
  {
    int links = 0;
    switch (arg.wflow_type) {
    case QUDA_WFLOW_TYPE_WILSON: links = 6; break;
    case QUDA_WFLOW_TYPE_SYMANZIK: links = 24; break;
    default: errorQuda("Unknown Wilson Flow type");
    }
    // temp-field accesses counted per step: W1 -> 0, W2 -> 2, VT -> 1
    auto temp_io = arg.step_type == WFLOW_STEP_W2 ? 2 : arg.step_type == WFLOW_STEP_VT ? 1 : 0;
    return ((1 + (wflow_dim - 1) * links) * arg.in.Bytes() + arg.out.Bytes() + temp_io * arg.temp.Bytes()) * 2ll
      * arg.threads * wflow_dim;
  }
}; // GaugeWFlowStep
117 
/**
   @brief Apply one full flow step to a gauge field by running the
   W1, W2 and VT sub-steps in sequence.

   The sub-steps ping-pong between the two fields: W1 writes out from
   in, W2 writes in from out, and VT writes out from in. The caller's
   in field is therefore overwritten; the final result is left in out.
   After each sub-step the extended ghost zone of the field that was
   just written is exchanged.

   @param[out] out Field receiving the result
   @param[in,out] temp Temporary field; must use QUDA_RECONSTRUCT_NO
   @param[in,out] in Starting field (overwritten by the W2 sub-step)
   @param[in] epsilon Step size forwarded to each sub-step
   @param[in] wflow_type Flow action type forwarded to each sub-step

   Calls errorQuda if the library was built without GPU_GAUGE_TOOLS,
   if temp uses reconstruction, or if out or in is not in a native
   order.
*/
void WFlowStep(GaugeField &out, GaugeField &temp, GaugeField &in, const double epsilon, const QudaWFlowType wflow_type)
{
#ifdef GPU_GAUGE_TOOLS
  checkPrecision(out, temp, in);
  checkReconstruct(out, in);
  if (temp.Reconstruct() != QUDA_RECONSTRUCT_NO) errorQuda("Temporary vector must not use reconstruct");
  // report the properties of the field that actually failed the check
  if (!out.isNative()) errorQuda("Order %d with %d reconstruct not supported", out.Order(), out.Reconstruct());
  if (!in.isNative()) errorQuda("Order %d with %d reconstruct not supported", in.Order(), in.Reconstruct());

  // Set each step type as an arg parameter, update halos if needed
  // Step W1
  instantiate<GaugeWFlowStep, WilsonReconstruct>(out, temp, in, epsilon, wflow_type, WFLOW_STEP_W1);
  out.exchangeExtendedGhost(out.R(), false);

  // Step W2 (roles swapped: reads out, writes in)
  instantiate<GaugeWFlowStep, WilsonReconstruct>(in, temp, out, epsilon, wflow_type, WFLOW_STEP_W2);
  in.exchangeExtendedGhost(in.R(), false);

  // Step Vt
  instantiate<GaugeWFlowStep, WilsonReconstruct>(out, temp, in, epsilon, wflow_type, WFLOW_STEP_VT);
  out.exchangeExtendedGhost(out.R(), false);
#else
  errorQuda("Gauge tools are not built");
#endif
}
143 }