QUDA  1.0.0
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
gauge_stout.cu
Go to the documentation of this file.
1 #include <quda_internal.h>
2 #include <tune_quda.h>
3 #include <gauge_field.h>
4 
5 #define DOUBLE_TOL 1e-15
6 #define SINGLE_TOL 2e-6
7 
8 #include <jitify_helper.cuh>
10 
11 namespace quda {
12 
13 #ifdef GPU_GAUGE_TOOLS
14 
/**
   @brief Tunable launcher for the computeSTOUTStep kernel.

   The thread layout passed to TunableVectorYZ is (2,3): 2 for parity in
   the y thread dimension, 3 for the direction mapped to the z thread
   dimension.  The grid dimension is not tuned.

   @tparam Float Floating-point type the kernel is instantiated with
   @tparam Arg Kernel argument struct; this class reads arg.threads,
   arg.origin and arg.dest
*/
template <typename Float, typename Arg> class GaugeSTOUT : TunableVectorYZ
{
  Arg &arg;
  const GaugeField &meta;

private:
  bool tuneGridDim() const { return false; } // Don't tune the grid dimensions.
  unsigned int minThreads() const { return arg.threads; }

public:
  /**
     @param[in,out] arg Kernel argument struct (held by reference)
     @param[in] meta Gauge field used for the location check and the tune key
  */
  GaugeSTOUT(Arg &arg, const GaugeField &meta) : TunableVectorYZ(2, 3), arg(arg), meta(meta)
  {
#ifdef JITIFY
    create_jitify_program("kernels/gauge_stout.cuh");
#endif
  }
  virtual ~GaugeSTOUT() {}

  /**
     @brief Tune (if enabled) and launch the kernel.
     @param[in] stream Stream the kernel is launched on
  */
  void apply(const cudaStream_t &stream)
  {
    if (meta.Location() == QUDA_CUDA_FIELD_LOCATION) {
      TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
#ifdef JITIFY
      using namespace jitify::reflection;
      jitify_error = program->kernel("quda::computeSTOUTStep")
                       .instantiate(Type<Float>(), Type<Arg>())
                       .configure(tp.grid, tp.block, tp.shared_bytes, stream)
                       .launch(arg);
#else
      // Launch on the requested stream so that this path agrees with the
      // JITIFY path above, which already forwards the stream.
      computeSTOUTStep<Float><<<tp.grid, tp.block, tp.shared_bytes, stream>>>(arg);
#endif
    } else {
      errorQuda("CPU not supported yet\n");
      // computeSTOUTStepCPU(arg);
    }
  }

  /** @brief Key identifying this kernel configuration in the tune cache. */
  TuneKey tuneKey() const
  {
    std::stringstream aux;
    aux << "threads=" << arg.threads << ",prec=" << sizeof(Float);
    return TuneKey(meta.VolString(), typeid(*this).name(), aux.str().c_str());
  }

  // The destination is saved before tuning and reloaded afterwards:
  // defensive measure in case origin and destination alias.
  void preTune() { arg.dest.save(); }
  void postTune() { arg.dest.load(); }

  long long flops() const { return 3 * (2 + 2 * 4) * 198ll * arg.threads; } // just counts matrix multiplication
  long long bytes() const { return 3 * ((1 + 2 * 6) * arg.origin.Bytes() + arg.dest.Bytes()) * arg.threads; }
}; // GaugeSTOUT
66 
/**
   @brief Build the kernel arguments for a single STOUT step and launch
   the kernel through GaugeSTOUT.

   @param[in] origin Accessor for the input gauge field
   @param[out] dest Accessor for the output gauge field
   @param[in] dataOr Input gauge field; its precision selects the tolerance
   @param[in] rho Smearing parameter forwarded to the kernel arguments
*/
template <typename Float, typename GaugeOr, typename GaugeDs>
void STOUTStep(GaugeOr origin, GaugeDs dest, const GaugeField &dataOr, Float rho)
{
  typedef GaugeSTOUTArg<Float, GaugeOr, GaugeDs> StoutArg;

  // DOUBLE_TOL for double-precision fields, SINGLE_TOL for everything else
  const bool doublePrec = (dataOr.Precision() == QUDA_DOUBLE_PRECISION);
  StoutArg stoutArg(origin, dest, dataOr, rho, doublePrec ? DOUBLE_TOL : SINGLE_TOL);

  GaugeSTOUT<Float, StoutArg> launcher(stoutArg, dataOr);
  launcher.apply(0);
}
74 
/**
   @brief Select the gauge accessors matching the reconstruction types of
   the destination and origin fields and forward to the accessor-level
   STOUTStep.

   Every combination of QUDA_RECONSTRUCT_NO, QUDA_RECONSTRUCT_12 and
   QUDA_RECONSTRUCT_8 is handled; any other type raises an error.

   @param[out] dataDs Destination gauge field
   @param[in] dataOr Origin gauge field
   @param[in] rho Smearing parameter
*/
template <typename Float> void STOUTStep(GaugeField &dataDs, const GaugeField &dataOr, Float rho)
{
  typedef typename gauge_mapper<Float, QUDA_RECONSTRUCT_NO>::type GNo;
  typedef typename gauge_mapper<Float, QUDA_RECONSTRUCT_12>::type G12;
  typedef typename gauge_mapper<Float, QUDA_RECONSTRUCT_8>::type G8;

  switch (dataDs.Reconstruct()) {
  case QUDA_RECONSTRUCT_NO:
    switch (dataOr.Reconstruct()) {
    case QUDA_RECONSTRUCT_NO: STOUTStep(GNo(dataOr), GNo(dataDs), dataOr, rho); break;
    case QUDA_RECONSTRUCT_12: STOUTStep(G12(dataOr), GNo(dataDs), dataOr, rho); break;
    case QUDA_RECONSTRUCT_8: STOUTStep(G8(dataOr), GNo(dataDs), dataOr, rho); break;
    default:
      errorQuda("Reconstruction type %d of origin gauge field not supported", dataOr.Reconstruct());
      break;
    }
    break;

  case QUDA_RECONSTRUCT_12:
    switch (dataOr.Reconstruct()) {
    case QUDA_RECONSTRUCT_NO: STOUTStep(GNo(dataOr), G12(dataDs), dataOr, rho); break;
    case QUDA_RECONSTRUCT_12: STOUTStep(G12(dataOr), G12(dataDs), dataOr, rho); break;
    case QUDA_RECONSTRUCT_8: STOUTStep(G8(dataOr), G12(dataDs), dataOr, rho); break;
    default:
      errorQuda("Reconstruction type %d of origin gauge field not supported", dataOr.Reconstruct());
      break;
    }
    break;

  case QUDA_RECONSTRUCT_8:
    switch (dataOr.Reconstruct()) {
    case QUDA_RECONSTRUCT_NO: STOUTStep(GNo(dataOr), G8(dataDs), dataOr, rho); break;
    case QUDA_RECONSTRUCT_12: STOUTStep(G12(dataOr), G8(dataDs), dataOr, rho); break;
    case QUDA_RECONSTRUCT_8: STOUTStep(G8(dataOr), G8(dataDs), dataOr, rho); break;
    default:
      errorQuda("Reconstruction type %d of origin gauge field not supported", dataOr.Reconstruct());
      break;
    }
    break;

  default:
    errorQuda("Reconstruction type %d of destination gauge field not supported", dataDs.Reconstruct());
    break;
  }
}
126 
127 #endif
128 
/**
   @brief Apply STOUT smearing to the gauge field.

   Checks that both fields share the same precision, that the precision
   is not half, and that both fields use a native order, then dispatches
   on precision to the templated implementation.  Raises an error if the
   library was built without GPU_GAUGE_TOOLS.

   @param[out] dataDs Destination gauge field
   @param[in] dataOr Origin gauge field
   @param[in] rho Smearing parameter
*/
void STOUTStep(GaugeField &dataDs, const GaugeField &dataOr, double rho)
{
#ifdef GPU_GAUGE_TOOLS
  const bool precisionMatches = (dataOr.Precision() == dataDs.Precision());
  if (!precisionMatches) { errorQuda("Origin and destination fields must have the same precision\n"); }

  if (dataDs.Precision() == QUDA_HALF_PRECISION) { errorQuda("Half precision not supported\n"); }

  if (!dataOr.isNative()) {
    errorQuda("Order %d with %d reconstruct not supported", dataOr.Order(), dataOr.Reconstruct());
  }

  if (!dataDs.isNative()) {
    errorQuda("Order %d with %d reconstruct not supported", dataDs.Order(), dataDs.Reconstruct());
  }

  switch (dataDs.Precision()) {
  case QUDA_SINGLE_PRECISION: STOUTStep<float>(dataDs, dataOr, static_cast<float>(rho)); break;
  case QUDA_DOUBLE_PRECISION: STOUTStep<double>(dataDs, dataOr, rho); break;
  default: errorQuda("Precision %d not supported", dataDs.Precision()); break;
  }
#else
  errorQuda("Gauge tools are not built");
#endif
}
159 
/**
   @brief Tunable launcher for the computeOvrImpSTOUTStep kernel.

   The thread layout passed to TunableVectorYZ is (2,3): 2 for parity in
   the y thread dimension, 3 for the direction mapped to the z thread
   dimension.  The grid dimension is not tuned.

   @tparam Float Floating-point type the kernel is instantiated with
   @tparam Arg Kernel argument struct; this class reads arg.threads,
   arg.origin and arg.dest
*/
template <typename Float, typename Arg> class GaugeOvrImpSTOUT : TunableVectorYZ
{
  Arg &arg; // bound by the constructor and used by every method below
  const GaugeField &meta;

private:
  bool tuneGridDim() const { return false; } // Don't tune the grid dimensions.
  unsigned int minThreads() const { return arg.threads; }

public:
  /**
     @param[in,out] arg Kernel argument struct (held by reference)
     @param[in] meta Gauge field used for the location check and the tune key
  */
  GaugeOvrImpSTOUT(Arg &arg, const GaugeField &meta) : TunableVectorYZ(2, 3), arg(arg), meta(meta)
  {
#ifdef JITIFY
    // apply() dereferences `program` in the JITIFY path, so make sure it has
    // been created here as well, exactly as the GaugeSTOUT constructor does.
    create_jitify_program("kernels/gauge_stout.cuh");
#endif
  }
  virtual ~GaugeOvrImpSTOUT() {}

  /**
     @brief Tune (if enabled) and launch the kernel.
     @param[in] stream Stream the kernel is launched on
  */
  void apply(const cudaStream_t &stream)
  {
    if (meta.Location() == QUDA_CUDA_FIELD_LOCATION) {
      TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
#ifdef JITIFY
      using namespace jitify::reflection;
      jitify_error = program->kernel("quda::computeOvrImpSTOUTStep")
                       .instantiate(Type<Float>(), Type<Arg>())
                       .configure(tp.grid, tp.block, tp.shared_bytes, stream)
                       .launch(arg);
#else
      // Launch on the requested stream so that this path agrees with the
      // JITIFY path above, which already forwards the stream.
      computeOvrImpSTOUTStep<Float><<<tp.grid, tp.block, tp.shared_bytes, stream>>>(arg);
#endif
    } else {
      errorQuda("CPU not supported yet\n");
      // computeOvrImpSTOUTStepCPU(arg);
    }
  }

  /** @brief Key identifying this kernel configuration in the tune cache. */
  TuneKey tuneKey() const
  {
    std::stringstream aux;
    aux << "threads=" << arg.threads << ",prec=" << sizeof(Float);
    return TuneKey(meta.VolString(), typeid(*this).name(), aux.str().c_str());
  }

  // The destination is saved before tuning and reloaded afterwards:
  // defensive measure in case origin and destination alias.
  void preTune() { arg.dest.save(); }
  void postTune() { arg.dest.load(); }

  long long flops() const { return 4 * (18 + 2 + 2 * 4) * 198ll * arg.threads; } // just counts matrix multiplication
  long long bytes() const { return 4 * ((1 + 2 * 12) * arg.origin.Bytes() + arg.dest.Bytes()) * arg.threads; }
}; // GaugeOvrImpSTOUT
206 
// Accessor-level entry point for one over-improved STOUT step: forwards the
// origin/destination accessors, rho, epsilon and a precision-dependent
// tolerance (DOUBLE_TOL for double-precision fields, SINGLE_TOL otherwise)
// to a kernel-argument object, then calls apply(0) on gaugeOvrImpSTOUT.
// NOTE(review): this listing is missing lines 209, 211 and 213. The statement
// that opens the argument-object construction and the declaration of
// gaugeOvrImpSTOUT are not visible here, so the argument struct's type name
// cannot be confirmed from this view -- restore them from the real source.
207  template<typename Float,typename GaugeOr, typename GaugeDs>
208  void OvrImpSTOUTStep(GaugeOr origin, GaugeDs dest, const GaugeField& dataOr, Float rho, Float epsilon) {
210  origin, dest, dataOr, rho, epsilon, dataOr.Precision() == QUDA_DOUBLE_PRECISION ? DOUBLE_TOL : SINGLE_TOL);
212  gaugeOvrImpSTOUT.apply(0);
214  }
215 
/**
   @brief Select the gauge accessors matching the reconstruction types of
   the destination and origin fields and forward to the accessor-level
   OvrImpSTOUTStep.

   Every combination of QUDA_RECONSTRUCT_NO, QUDA_RECONSTRUCT_12 and
   QUDA_RECONSTRUCT_8 is handled; any other type raises an error.  The
   GDs/GOr accessor typedefs follow the same gauge_mapper pattern used by
   the STOUTStep dispatch in this file.

   @param[out] dataDs Destination gauge field
   @param[in] dataOr Origin gauge field
   @param[in] rho Smearing parameter
   @param[in] epsilon Second smearing parameter, forwarded unchanged
*/
template <typename Float>
void OvrImpSTOUTStep(GaugeField &dataDs, const GaugeField &dataOr, Float rho, Float epsilon)
{
  if (dataDs.Reconstruct() == QUDA_RECONSTRUCT_NO) {
    typedef typename gauge_mapper<Float, QUDA_RECONSTRUCT_NO>::type GDs;

    if (dataOr.Reconstruct() == QUDA_RECONSTRUCT_NO) {
      typedef typename gauge_mapper<Float, QUDA_RECONSTRUCT_NO>::type GOr;
      OvrImpSTOUTStep(GOr(dataOr), GDs(dataDs), dataOr, rho, epsilon);
    } else if (dataOr.Reconstruct() == QUDA_RECONSTRUCT_12) {
      typedef typename gauge_mapper<Float, QUDA_RECONSTRUCT_12>::type GOr;
      OvrImpSTOUTStep(GOr(dataOr), GDs(dataDs), dataOr, rho, epsilon);
    } else if (dataOr.Reconstruct() == QUDA_RECONSTRUCT_8) {
      typedef typename gauge_mapper<Float, QUDA_RECONSTRUCT_8>::type GOr;
      OvrImpSTOUTStep(GOr(dataOr), GDs(dataDs), dataOr, rho, epsilon);
    } else {
      errorQuda("Reconstruction type %d of origin gauge field not supported", dataOr.Reconstruct());
    }
  } else if (dataDs.Reconstruct() == QUDA_RECONSTRUCT_12) {
    typedef typename gauge_mapper<Float, QUDA_RECONSTRUCT_12>::type GDs;

    if (dataOr.Reconstruct() == QUDA_RECONSTRUCT_NO) {
      typedef typename gauge_mapper<Float, QUDA_RECONSTRUCT_NO>::type GOr;
      OvrImpSTOUTStep(GOr(dataOr), GDs(dataDs), dataOr, rho, epsilon);
    } else if (dataOr.Reconstruct() == QUDA_RECONSTRUCT_12) {
      typedef typename gauge_mapper<Float, QUDA_RECONSTRUCT_12>::type GOr;
      OvrImpSTOUTStep(GOr(dataOr), GDs(dataDs), dataOr, rho, epsilon);
    } else if (dataOr.Reconstruct() == QUDA_RECONSTRUCT_8) {
      typedef typename gauge_mapper<Float, QUDA_RECONSTRUCT_8>::type GOr;
      OvrImpSTOUTStep(GOr(dataOr), GDs(dataDs), dataOr, rho, epsilon);
    } else {
      errorQuda("Reconstruction type %d of origin gauge field not supported", dataOr.Reconstruct());
    }
  } else if (dataDs.Reconstruct() == QUDA_RECONSTRUCT_8) {
    typedef typename gauge_mapper<Float, QUDA_RECONSTRUCT_8>::type GDs;

    if (dataOr.Reconstruct() == QUDA_RECONSTRUCT_NO) {
      typedef typename gauge_mapper<Float, QUDA_RECONSTRUCT_NO>::type GOr;
      OvrImpSTOUTStep(GOr(dataOr), GDs(dataDs), dataOr, rho, epsilon);
    } else if (dataOr.Reconstruct() == QUDA_RECONSTRUCT_12) {
      typedef typename gauge_mapper<Float, QUDA_RECONSTRUCT_12>::type GOr;
      OvrImpSTOUTStep(GOr(dataOr), GDs(dataDs), dataOr, rho, epsilon);
    } else if (dataOr.Reconstruct() == QUDA_RECONSTRUCT_8) {
      typedef typename gauge_mapper<Float, QUDA_RECONSTRUCT_8>::type GOr;
      OvrImpSTOUTStep(GOr(dataOr), GDs(dataDs), dataOr, rho, epsilon);
    } else {
      errorQuda("Reconstruction type %d of origin gauge field not supported", dataOr.Reconstruct());
    }
  } else {
    errorQuda("Reconstruction type %d of destination gauge field not supported", dataDs.Reconstruct());
  }
}
267 
268 
/**
   @brief Apply Over Improved STOUT smearing to the gauge field.

   Checks that both fields share the same precision, that the precision
   is not half, and that both fields use a native order, then dispatches
   on precision to the templated implementation.  Raises an error if the
   library was built without GPU_GAUGE_TOOLS.

   @param[out] dataDs Destination gauge field
   @param[in] dataOr Origin gauge field
   @param[in] rho Smearing parameter
   @param[in] epsilon Second smearing parameter
*/
void OvrImpSTOUTStep(GaugeField &dataDs, const GaugeField &dataOr, double rho, double epsilon)
{
#ifdef GPU_GAUGE_TOOLS

  if (dataOr.Precision() != dataDs.Precision()) {
    errorQuda("Origin and destination fields must have the same precision\n");
  }

  if (dataDs.Precision() == QUDA_HALF_PRECISION) { errorQuda("Half precision not supported\n"); }

  if (!dataOr.isNative())
    errorQuda("Order %d with %d reconstruct not supported", dataOr.Order(), dataOr.Reconstruct());

  if (!dataDs.isNative())
    errorQuda("Order %d with %d reconstruct not supported", dataDs.Order(), dataDs.Reconstruct());

  if (dataDs.Precision() == QUDA_SINGLE_PRECISION) {
    // Narrow both parameters explicitly; previously only rho was cast and
    // epsilon was converted to float implicitly at the call.
    OvrImpSTOUTStep<float>(dataDs, dataOr, static_cast<float>(rho), static_cast<float>(epsilon));
  } else if (dataDs.Precision() == QUDA_DOUBLE_PRECISION) {
    OvrImpSTOUTStep<double>(dataDs, dataOr, rho, epsilon);
  } else {
    errorQuda("Precision %d not supported", dataDs.Precision());
  }
  return;
#else
  errorQuda("Gauge tools are not built");
#endif
}
299 }
GaugeOvrImpSTOUT(Arg &arg, const GaugeField &meta)
Definition: gauge_stout.cu:171
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:21
#define errorQuda(...)
Definition: util_quda.h:121
Helper file when using jitify run-time compilation. This file should be included in source code...
TuneKey tuneKey() const
Definition: gauge_stout.cu:193
void STOUTStep(GaugeField &dataDs, const GaugeField &dataOr, double rho)
Apply STOUT smearing to the gauge field.
Definition: gauge_stout.cu:129
double epsilon
Definition: test_util.cpp:1649
cudaStream_t * stream
const GaugeField & meta
Definition: gauge_stout.cu:163
const char * VolString() const
long long bytes() const
Definition: gauge_stout.cu:204
#define qudaDeviceSynchronize()
#define SINGLE_TOL
Definition: gauge_stout.cu:6
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
Definition: tune.cpp:643
void OvrImpSTOUTStep(GaugeField &dataDs, const GaugeField &dataOr, double rho, double epsilon)
Apply Over Improved STOUT smearing to the gauge field.
Definition: gauge_stout.cu:269
QudaFieldLocation Location() const
long long flops() const
Definition: gauge_stout.cu:203
#define DOUBLE_TOL
Definition: gauge_stout.cu:5
unsigned long long flops
Definition: blas_quda.cu:22
unsigned int minThreads() const
Definition: gauge_stout.cu:167
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
bool tuneGridDim() const
Definition: gauge_stout.cu:166
QudaReconstructType Reconstruct() const
Definition: gauge_field.h:250
QudaGaugeFieldOrder Order() const
Definition: gauge_field.h:251
QudaTune getTuning()
Query whether autotuning is enabled or not. Default is enabled but can be overridden by setting QUDA_...
Definition: util_quda.cpp:52
QudaPrecision Precision() const
bool isNative() const
unsigned long long bytes
Definition: blas_quda.cu:23
void apply(const cudaStream_t &stream)
Definition: gauge_stout.cu:174