QUDA v1.1.0
A library for QCD on GPUs
gauge_stout.cu
Go to the documentation of this file.
1 #include <quda_internal.h>
2 #include <tune_quda.h>
3 #include <gauge_field.h>
4 
5 #include <jitify_helper.cuh>
6 #include <kernels/gauge_stout.cuh>
7 #include <instantiate.h>
8 
9 namespace quda {
10 
/**
 * @brief Tunable launcher for one stout-smearing step restricted to the
 *        spatial directions (stoutDim = 3).
 *
 * Building an instance runs the step immediately. The constructor:
 *   - fills the kernel argument from (out, in, rho),
 *   - launches computeSTOUTStep via apply(0) (stream 0),
 *   - then calls qudaDeviceSynchronize().
 *
 * Launch geometry comes from TunableVectorYZ(2, stoutDim). The y thread
 * dimension covers the two parities and the z thread dimension covers the
 * stoutDim smeared directions.
 */
template <typename Float, int nColor, QudaReconstructType recon> class GaugeSTOUT : TunableVectorYZ
{
public:
  /**
   * @param out Field receiving the smeared links (also saved/restored around tuning)
   * @param in  Field the kernel argument reads from; also supplies the tuning metadata
   * @param rho Smearing parameter forwarded to GaugeSTOUTArg
   */
  GaugeSTOUT(GaugeField &out, const GaugeField &in, double rho) :
    TunableVectorYZ(2, stoutDim), // y: parity (2), z: smeared direction (stoutDim)
    arg(out, in, rho),
    meta(in)
  {
    // Auxiliary tune-key string: field metadata followed by the comm-partitioning pattern.
    strcpy(aux, meta.AuxString());
    strcat(aux, comm_dim_partitioned_string());
#ifdef JITIFY
    create_jitify_program("kernels/gauge_stout.cuh");
#endif
    apply(0);
    qudaDeviceSynchronize();
  }

  /** Look up (or tune) the launch configuration and run computeSTOUTStep on @p stream. */
  void apply(const qudaStream_t &stream)
  {
    TuneParam tuned = tuneLaunch(*this, getTuning(), getVerbosity());
#ifdef JITIFY
    using namespace jitify::reflection;
    // Kept as a single expression: the intermediate jitify objects are temporaries.
    jitify_error = program->kernel("quda::computeSTOUTStep")
                     .instantiate(Type<decltype(arg)>())
                     .configure(tuned.grid, tuned.block, tuned.shared_bytes, stream)
                     .launch(arg);
#else
    qudaLaunchKernel(computeSTOUTStep<decltype(arg)>, tuned, stream, arg);
#endif
  }

  /** Tune key: lattice volume string, concrete class name, and the aux string built in the constructor. */
  TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), aux); }

  // Snapshot the output before tuning and restore it afterwards, so repeated
  // trial launches cannot leave a modified result (defensive in case in/out alias).
  void preTune() { arg.out.save(); }
  void postTune() { arg.out.load(); }

  // Counts matrix multiplications only; 198 is presumably the flop count of one
  // 3x3 complex matrix product (TODO confirm the remaining factors).
  long long flops() const { return 3 * (2 + 2 * 4) * 198ll * arg.threads; }

  long long bytes() const
  {
    // Per thread: 1 link plus 6 links per smeared direction read, 1 link written.
    const int linksRead = 1 + stoutDim * 6;
    return (linksRead * arg.in.Bytes() + arg.out.Bytes()) * arg.threads;
  }

private:
  static constexpr int stoutDim = 3; // smear the spatial directions only
  GaugeSTOUTArg<Float, nColor, recon, stoutDim> arg;
  const GaugeField &meta; // source of the VolString()/AuxString() used for tuning

  bool tuneGridDim() const { return false; } // grid dimensions are not tuned
  unsigned int minThreads() const { return arg.threads; }
}; // GaugeSTOUT
56 
/**
 * @brief Perform one stout-smearing step in the spatial directions
 *        (dispatches to GaugeSTOUT, which uses stoutDim = 3).
 *
 * Sequence of operations:
 *   1. Check that @p out and @p in share precision and reconstruction, and
 *      that both are in a native order.
 *   2. Call copyExtendedGauge(in, out, ...) and exchange the extended ghost
 *      zones of @p in.
 *   3. Instantiate GaugeSTOUT, which computes into @p out.
 *   4. Exchange the extended ghost zones of @p out.
 *
 * NOTE(review): assuming copyExtendedGauge takes (destination, source), step 2
 * overwrites @p in with @p out. In that case @p out acts as both the input and
 * the result, and @p in is scratch. Confirm against the copyExtendedGauge
 * declaration.
 *
 * @param out Field holding the smeared result.
 * @param in  Extended field the kernel reads from; modified by this call.
 * @param rho Smearing parameter forwarded to the kernel argument.
 *
 * Reports an error via errorQuda if the checks fail, or if GPU_GAUGE_TOOLS was
 * not enabled at build time.
 */
void STOUTStep(GaugeField &out, GaugeField &in, double rho)
{
#ifdef GPU_GAUGE_TOOLS
  checkPrecision(out, in);
  checkReconstruct(out, in);

  // Each diagnostic reports the order/reconstruct of the field that failed the check.
  if (!out.isNative()) errorQuda("Order %d with %d reconstruct not supported", out.Order(), out.Reconstruct());
  if (!in.isNative()) errorQuda("Order %d with %d reconstruct not supported", in.Order(), in.Reconstruct());

  copyExtendedGauge(in, out, QUDA_CUDA_FIELD_LOCATION);
  in.exchangeExtendedGhost(in.R(), false);
  instantiate<GaugeSTOUT>(out, in, rho);
  out.exchangeExtendedGhost(out.R(), false);
#else
  errorQuda("Gauge tools are not built");
#endif
}
74 
/**
 * @brief Tunable launcher for one computeOvrImpSTOUTStep smearing step applied
 *        in all four directions (stoutDim = 4).
 *
 * Building an instance runs the step immediately. The constructor:
 *   - fills the kernel argument from (out, in, rho, epsilon),
 *   - launches the kernel via apply(0) (stream 0),
 *   - then calls qudaDeviceSynchronize().
 *
 * Launch geometry comes from TunableVectorYZ(2, stoutDim). The y thread
 * dimension covers the two parities and the z thread dimension covers the
 * stoutDim smeared directions.
 */
template <typename Float, int nColor, QudaReconstructType recon> class GaugeOvrImpSTOUT : TunableVectorYZ
{
public:
  /**
   * @param out     Field receiving the smeared links (also saved/restored around tuning)
   * @param in      Field the kernel argument reads from; also supplies the tuning metadata
   * @param rho     Smearing parameter forwarded to GaugeSTOUTArg
   * @param epsilon Additional parameter forwarded to GaugeSTOUTArg
   *                (presumably the over-improvement coefficient; TODO confirm)
   */
  GaugeOvrImpSTOUT(GaugeField &out, const GaugeField &in, double rho, double epsilon) :
    TunableVectorYZ(2, stoutDim), // y: parity (2), z: smeared direction (stoutDim)
    arg(out, in, rho, epsilon),
    meta(in)
  {
    // Auxiliary tune-key string: field metadata followed by the comm-partitioning pattern.
    strcpy(aux, meta.AuxString());
    strcat(aux, comm_dim_partitioned_string());
#ifdef JITIFY
    create_jitify_program("kernels/gauge_stout.cuh");
#endif
    apply(0);
    qudaDeviceSynchronize();
  }

  /** Look up (or tune) the launch configuration and run computeOvrImpSTOUTStep on @p stream. */
  void apply(const qudaStream_t &stream)
  {
    TuneParam tuned = tuneLaunch(*this, getTuning(), getVerbosity());
#ifdef JITIFY
    using namespace jitify::reflection;
    // Kept as a single expression: the intermediate jitify objects are temporaries.
    jitify_error = program->kernel("quda::computeOvrImpSTOUTStep")
                     .instantiate(Type<decltype(arg)>())
                     .configure(tuned.grid, tuned.block, tuned.shared_bytes, stream)
                     .launch(arg);
#else
    qudaLaunchKernel(computeOvrImpSTOUTStep<decltype(arg)>, tuned, stream, arg);
#endif
  }

  /** Tune key: lattice volume string, concrete class name, and the aux string built in the constructor. */
  TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), aux); }

  // Snapshot the output before tuning and restore it afterwards, so repeated
  // trial launches cannot leave a modified result (defensive in case in/out alias).
  void preTune() { arg.out.save(); }
  void postTune() { arg.out.load(); }

  // Counts matrix multiplications only; 198 is presumably the flop count of one
  // 3x3 complex matrix product (TODO confirm the remaining factors).
  long long flops() const { return 4 * (18 + 2 + 2 * 4) * 198ll * arg.threads; }

  long long bytes() const
  {
    // Per thread: 1 link plus 24 links per smeared direction read, 1 link written.
    const int linksRead = 1 + stoutDim * 24;
    return (linksRead * arg.in.Bytes() + arg.out.Bytes()) * arg.threads;
  }

private:
  static constexpr int stoutDim = 4; // smear in all four directions
  GaugeSTOUTArg<Float, nColor, recon, stoutDim> arg;
  const GaugeField &meta; // source of the VolString()/AuxString() used for tuning

  bool tuneGridDim() const { return false; } // grid dimensions are not tuned
  unsigned int minThreads() const { return arg.threads; }
}; // GaugeOvrImpSTOUT
119 
/**
 * @brief Perform one computeOvrImpSTOUTStep smearing step in all four
 *        directions (dispatches to GaugeOvrImpSTOUT, which uses stoutDim = 4).
 *
 * Sequence of operations:
 *   1. Check that @p out and @p in share precision and reconstruction, and
 *      that both are in a native order.
 *   2. Call copyExtendedGauge(in, out, ...) and exchange the extended ghost
 *      zones of @p in.
 *   3. Instantiate GaugeOvrImpSTOUT, which computes into @p out.
 *   4. Exchange the extended ghost zones of @p out.
 *
 * NOTE(review): assuming copyExtendedGauge takes (destination, source), step 2
 * overwrites @p in with @p out. In that case @p out acts as both the input and
 * the result, and @p in is scratch. Confirm against the copyExtendedGauge
 * declaration.
 *
 * @param out     Field holding the smeared result.
 * @param in      Extended field the kernel reads from; modified by this call.
 * @param rho     Smearing parameter forwarded to the kernel argument.
 * @param epsilon Additional parameter forwarded to the kernel argument
 *                (presumably the over-improvement coefficient; TODO confirm).
 *
 * Reports an error via errorQuda if the checks fail, or if GPU_GAUGE_TOOLS was
 * not enabled at build time.
 */
void OvrImpSTOUTStep(GaugeField &out, GaugeField& in, double rho, double epsilon)
{
#ifdef GPU_GAUGE_TOOLS
  checkPrecision(out, in);
  checkReconstruct(out, in);

  // Each diagnostic reports the order/reconstruct of the field that failed the check.
  if (!out.isNative()) errorQuda("Order %d with %d reconstruct not supported", out.Order(), out.Reconstruct());
  if (!in.isNative()) errorQuda("Order %d with %d reconstruct not supported", in.Order(), in.Reconstruct());

  copyExtendedGauge(in, out, QUDA_CUDA_FIELD_LOCATION);
  in.exchangeExtendedGhost(in.R(), false);
  instantiate<GaugeOvrImpSTOUT>(out, in, rho, epsilon);
  out.exchangeExtendedGhost(out.R(), false);
#else
  errorQuda("Gauge tools are not built");
#endif
}
138 }