QUDA  1.0.0
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
gauge_update_quda.cu
Go to the documentation of this file.
1 #include <cstdio>
2 #include <cstdlib>
3 #include <cuda.h>
4 #include <quda_internal.h>
5 #include <tune_quda.h>
6 #include <gauge_field.h>
7 #include <gauge_field_order.h>
8 #include <quda_matrix.h>
9 #include <float_vector.h>
10 #include <complex_quda.h>
11 
12 namespace quda {
13 
14 #ifdef GPU_GAUGE_TOOLS
15 
/**
 * Bundle of everything the gauge-update routines need for one call:
 * the output and input link accessors, the momentum accessor, the step
 * size and the number of link directions to loop over.
 *
 * The object is stored by value inside UpdateGaugeField and passed by
 * value to the kernel, so the accessor types must be copyable.
 */
template <typename Float, typename Gauge, typename Mom>
struct UpdateGaugeArg {
  Gauge out;    // accessor the updated links are written through
  Gauge in;     // accessor the current links are read through
  Mom momentum; // accessor the momentum matrices are read through
  Float dt;     // step size that scales the momentum
  int nDim;     // number of directions visited at every site

  UpdateGaugeArg(const Gauge &out_, const Gauge &in_,
                 const Mom &momentum_, Float dt_, int nDim_)
    : out(out_),
      in(in_),
      momentum(momentum_),
      dt(dt_),
      nDim(nDim_)
  {
  }
};
27 
/**
 * Update all links attached to one checkerboard site.
 *
 * For every direction below arg.nDim the link and momentum are loaded,
 * one third of the momentum's trace is subtracted from each diagonal
 * element, and the new link is written to arg.out.
 *
 * Template switches (all resolved at compile time):
 *   N        - order of the series used when exact is false
 *   conj_mom - multiply by conj(mom) instead of mom
 *   exact    - use expsu3 on dt*mom instead of the order-N series
 *
 * @param arg    accessors, step size and direction count
 * @param x      checkerboard site index
 * @param parity checkerboard parity of the site
 */
template <typename Float, typename Gauge, typename Mom, int N,
          bool conj_mom, bool exact>
__device__ __host__ void updateGaugeFieldCompute
(UpdateGaugeArg<Float,Gauge,Mom> &arg, int x, int parity) {
  typedef complex<Float> Complex;

  Matrix<Complex,3> link, result, mom;
  for (int mu = 0; mu < arg.nDim; ++mu) {
    link = arg.in(mu, x, parity);
    mom = arg.momentum(mu, x, parity);

    // remove the trace: the same trace/3 is taken off each diagonal entry
    const Complex third_of_trace = getTrace(mom) / static_cast<Float>(3.0);
    for (int c = 0; c < 3; ++c) mom(c,c) -= third_of_trace;

    if (exact) {
      // scale by the step size, then let expsu3 act on the scaled matrix
      // (the matrix is used as the multiplier afterwards)
      mom = arg.dt * mom;
      expsu3<Float>(mom);

      if (conj_mom) {
        result = conj(mom) * link;
      } else {
        result = mom * link;
      }
    } else {
      // Nth order expansion of exponential, evaluated from the
      // highest-order term inwards
      result = link;
      for (int order = N; order > 0; --order) {
        if (conj_mom) {
          result = (arg.dt/order)*conj(mom)*result + link;
        } else {
          result = (arg.dt/order)*mom*result + link;
        }
      }
    }

    arg.out(mu, x, parity) = result;
  } // mu

}
72 
/**
 * Host-side driver: visits both parities and, within each parity, every
 * checkerboard site below arg.out.volumeCB in increasing order, calling
 * updateGaugeFieldCompute for each one.
 *
 * @param arg accessors, step size and direction count (taken by value)
 */
template <typename Float, typename Gauge, typename Mom, int N,
          bool conj_mom, bool exact>
void updateGaugeField(UpdateGaugeArg<Float,Gauge,Mom> arg) {

  for (unsigned int par = 0; par < 2; ++par)
    for (int site = 0; site < arg.out.volumeCB; ++site)
      updateGaugeFieldCompute<Float,Gauge,Mom,N,conj_mom,exact>(arg, site, par);

}
84 
/**
 * Kernel entry point: one thread per (parity, checkerboard site) pair.
 *
 * The flat thread index computed from blockIdx.x/blockDim.x/threadIdx.x
 * covers [0, 2*volumeCB); indices in the upper half map to parity 1.
 * Threads whose index falls outside that range exit immediately.
 *
 * @param arg accessors, step size and direction count (passed by value)
 */
template <typename Float, typename Gauge, typename Mom, int N,
          bool conj_mom, bool exact>
__global__ void updateGaugeFieldKernel(UpdateGaugeArg<Float,Gauge,Mom> arg) {
  const int tid = blockIdx.x*blockDim.x + threadIdx.x;
  if (tid >= 2*arg.out.volumeCB) return; // surplus threads at the grid tail

  // split the flat index into parity and checkerboard site
  const int parity = (tid >= arg.out.volumeCB) ? 1 : 0;
  const int x_cb = tid - parity*arg.out.volumeCB;

  updateGaugeFieldCompute<Float,Gauge,Mom,N,conj_mom,exact>(arg, x_cb, parity);
}
95 
/**
 * Tunable wrapper around the gauge update.
 *
 * apply() either launches updateGaugeFieldKernel with the launch
 * parameters returned by tuneLaunch (device location) or runs the host
 * loop (any other location).  No shared memory is requested and the
 * grid dimension is not tuned.
 *
 * Template parameters mirror updateGaugeFieldCompute:
 *   N        - order of the series used when exact is false
 *   conj_mom - multiply by conj(mom) instead of mom
 *   exact    - use expsu3 instead of the series
 */
template <typename Float, typename Gauge, typename Mom, int N,
          bool conj_mom, bool exact>
class UpdateGaugeField : public Tunable {
private:
  UpdateGaugeArg<Float,Gauge,Mom> arg;
  const GaugeField &meta; // meta data
  const QudaFieldLocation location; // location of the lattice fields

  unsigned int sharedBytesPerThread() const { return 0; }
  unsigned int sharedBytesPerBlock(const TuneParam &) const { return 0; }

  // one thread per (parity, checkerboard site) pair
  unsigned int minThreads() const { return 2*arg.in.volumeCB; }
  bool tuneGridDim() const { return false; }

public:
  /**
   * @param arg      accessors, step size and direction count (copied)
   * @param meta     field whose VolString() is used in the tune key
   * @param location where the update runs
   */
  UpdateGaugeField(const UpdateGaugeArg<Float,Gauge,Mom> &arg,
                   const GaugeField &meta, QudaFieldLocation location)
    : arg(arg), meta(meta), location(location) {
    writeAuxString("threads=%d,prec=%lu,stride=%d",
                   2*arg.in.volumeCB, sizeof(Float), arg.in.stride);
  }
  virtual ~UpdateGaugeField() { }

  /**
   * Run the update.
   * @param stream stream the kernel is launched on (device location only)
   */
  void apply(const cudaStream_t &stream){
    if (location == QUDA_CUDA_FIELD_LOCATION) {
      TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
      // launch on the caller's stream rather than the implicit default one
      updateGaugeFieldKernel<Float,Gauge,Mom,N,conj_mom,exact>
        <<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
    } else { // run the CPU code
      updateGaugeField<Float,Gauge,Mom,N,conj_mom,exact>(arg);
    }
  } // apply

  /**
   * Flop estimate: N series steps per link, each counted as one
   * scalar-matrix multiply, one matrix-matrix multiply and one matrix
   * addition.  The same formula is reported for every instantiation.
   */
  long long flops() const {
    const int Nc = 3;
    // widen before multiplying: the product exceeds the int range for
    // local volumes of a few hundred thousand sites
    const long long links = static_cast<long long>(arg.nDim)*2*arg.in.volumeCB;
    return links*N*(Nc*Nc*2 +                 // scalar-matrix multiply
                    (8*Nc*Nc*Nc - 2*Nc*Nc) +  // matrix-matrix multiply
                    Nc*Nc*2);                 // matrix-matrix addition
  }

  /** Traffic estimate: one input link, one output link and one momentum per link. */
  long long bytes() const {
    const long long links = static_cast<long long>(arg.nDim)*2*arg.in.volumeCB;
    return links*(arg.in.Bytes() + arg.out.Bytes() + arg.momentum.Bytes());
  }

  TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), aux); }
};
140 
/**
 * Turn the run-time flags conj_mom and exact into template arguments
 * and run the matching UpdateGaugeField instantiation.
 *
 * The update is always set up for 4 directions and a series of degree
 * N = 8, and is applied with stream argument 0.  When the location is
 * QUDA_CUDA_FIELD_LOCATION, checkCudaError() is called afterwards.
 *
 * @param out      accessor for the updated links
 * @param in       accessor for the current links
 * @param mom      accessor for the momentum
 * @param dt       step size (converted to Float)
 * @param meta     field handed to UpdateGaugeField as meta data
 * @param conj_mom selects the conj_mom template switch
 * @param exact    selects the exact template switch
 * @param location where the update runs
 */
template <typename Float, typename Gauge, typename Mom>
void updateGaugeField(Gauge &out, const Gauge &in, const Mom &mom,
                      double dt, const GaugeField &meta, bool conj_mom, bool exact,
                      QudaFieldLocation location) {
  // degree of exponential expansion
  const int N = 8;

  // the argument bundle is the same for every instantiation
  UpdateGaugeArg<Float, Gauge, Mom> arg(out, in, mom, dt, 4);

  if (conj_mom && exact) {
    UpdateGaugeField<Float,Gauge,Mom,N,true,true> updater(arg, meta, location);
    updater.apply(0);
  } else if (conj_mom) {
    UpdateGaugeField<Float,Gauge,Mom,N,true,false> updater(arg, meta, location);
    updater.apply(0);
  } else if (exact) {
    UpdateGaugeField<Float,Gauge,Mom,N,false,true> updater(arg, meta, location);
    updater.apply(0);
  } else {
    UpdateGaugeField<Float,Gauge,Mom,N,false,false> updater(arg, meta, location);
    updater.apply(0);
  }

  if (location == QUDA_CUDA_FIELD_LOCATION) checkCudaError();

}
173 
/**
 * Pick the momentum accessor that matches the momentum field's order
 * and forward to the accessor-level update.
 *
 * Accepted layouts: QUDA_FLOAT2_GAUGE_ORDER with QUDA_RECONSTRUCT_10,
 * and QUDA_MILC_GAUGE_ORDER.  Anything else goes to errorQuda.
 *
 * @param out      accessor for the updated links (taken by value)
 * @param in       accessor for the current links
 * @param mom      momentum field; also forwarded as the meta field
 * @param dt       step size
 * @param conj_mom forwarded unchanged
 * @param exact    forwarded unchanged
 * @param location where the update runs
 */
template <typename Float, typename Gauge>
void updateGaugeField(Gauge out, const Gauge &in, const GaugeField &mom,
                      double dt, bool conj_mom, bool exact,
                      QudaFieldLocation location) {
  switch (mom.Order()) {
  case QUDA_FLOAT2_GAUGE_ORDER:
    if (mom.Reconstruct() != QUDA_RECONSTRUCT_10) {
      errorQuda("Reconstruction type not supported");
    } else {
      // FIX ME - 11 is a misnomer to avoid confusion in template instantiation
      updateGaugeField<Float>(out, in, gauge::FloatNOrder<Float,18,2,11>(mom), dt, mom, conj_mom, exact, location);
    }
    break;

  case QUDA_MILC_GAUGE_ORDER:
    updateGaugeField<Float>(out, in, gauge::MILCOrder<Float,10>(mom), dt, mom, conj_mom, exact, location);
    break;

  default:
    errorQuda("Gauge Field order %d not supported", mom.Order());
  }

}
192 
/**
 * Pick the link accessor that matches the output field's layout and
 * forward to the momentum-accessor selection.
 *
 * Requirements checked here: the output field has 3 colors, and input
 * and output agree in order and reconstruction.  Native-order fields
 * are accepted with QUDA_RECONSTRUCT_NO or QUDA_RECONSTRUCT_12;
 * non-native fields are accepted only in QUDA_MILC_GAUGE_ORDER.
 * Every violation is reported through errorQuda.
 *
 * @param out      field receiving the updated links
 * @param in       field holding the current links
 * @param mom      momentum field
 * @param dt       step size
 * @param conj_mom forwarded unchanged
 * @param exact    forwarded unchanged
 * @param location where the update runs
 */
template <typename Float>
void updateGaugeField(GaugeField &out, const GaugeField &in, const GaugeField &mom,
                      double dt, bool conj_mom, bool exact,
                      QudaFieldLocation location) {

  const int Nc = 3;
  if (out.Ncolor() != Nc)
    errorQuda("Ncolor=%d not supported at this time", out.Ncolor());

  const bool same_layout = (out.Order() == in.Order()) && (out.Reconstruct() == in.Reconstruct());
  if (!same_layout) {
    errorQuda("Input and output gauge field ordering and reconstruction must match");
  }

  if (!out.isNative()) {
    // non-native storage: only the MILC order has an accessor here
    if (out.Order() == QUDA_MILC_GAUGE_ORDER) {
      updateGaugeField<Float>(gauge::MILCOrder<Float, Nc*Nc*2>(out),
                              gauge::MILCOrder<Float, Nc*Nc*2>(in),
                              mom, dt, conj_mom, exact, location);
    } else {
      errorQuda("Gauge Field order %d not supported", out.Order());
    }
  } else {
    // native storage: the accessor type depends on the reconstruction
    switch (out.Reconstruct()) {
    case QUDA_RECONSTRUCT_NO: {
      typedef typename gauge_mapper<Float,QUDA_RECONSTRUCT_NO>::type G;
      updateGaugeField<Float>(G(out), G(in), mom, dt, conj_mom, exact, location);
      break;
    }
    case QUDA_RECONSTRUCT_12: {
      typedef typename gauge_mapper<Float,QUDA_RECONSTRUCT_12>::type G;
      updateGaugeField<Float>(G(out), G(in), mom, dt, conj_mom, exact, location);
      break;
    }
    default:
      errorQuda("Reconstruction type not supported");
    }
  }

}
225 #endif
226 
/**
 * Public entry point of the gauge update.
 *
 * When GPU_GAUGE_TOOLS is defined, the three fields must share one
 * precision and one location; the update is then dispatched on the
 * output precision (double or single) and runs at the output field's
 * location.  Without GPU_GAUGE_TOOLS the call only reports an error.
 *
 * @param out      field receiving the updated links
 * @param dt       step size
 * @param in       field holding the current links
 * @param mom      momentum field
 * @param conj_mom forwarded to the templated implementation
 * @param exact    forwarded to the templated implementation
 */
void updateGaugeField(GaugeField &out, double dt, const GaugeField& in,
                      const GaugeField& mom, bool conj_mom, bool exact)
{
#ifdef GPU_GAUGE_TOOLS
  const bool same_precision = (out.Precision() == in.Precision()) && (out.Precision() == mom.Precision());
  if (!same_precision)
    errorQuda("Gauge and momentum fields must have matching precision");

  const bool same_location = (out.Location() == in.Location()) && (out.Location() == mom.Location());
  if (!same_location)
    errorQuda("Gauge and momentum fields must have matching location");

  switch (out.Precision()) {
  case QUDA_DOUBLE_PRECISION:
    updateGaugeField<double>(out, in, mom, dt, conj_mom, exact, out.Location());
    break;
  case QUDA_SINGLE_PRECISION:
    updateGaugeField<float>(out, in, mom, dt, conj_mom, exact, out.Location());
    break;
  default:
    errorQuda("Precision %d not supported", out.Precision());
  }
#else
  errorQuda("Gauge tools are not build");
#endif

}
249 
250 } // namespace quda
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:21
#define errorQuda(...)
Definition: util_quda.h:121
cudaStream_t * stream
cpuColorSpinorField * in
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
Definition: tune.cpp:643
Main header file for host and device accessors to GaugeFields.
std::complex< double > Complex
Definition: quda_internal.h:46
QudaFieldLocation Location() const
__device__ __host__ T getTrace(const Matrix< T, 3 > &a)
Definition: quda_matrix.h:415
enum QudaFieldLocation_s QudaFieldLocation
cpuColorSpinorField * out
unsigned long long flops
Definition: blas_quda.cu:22
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
void updateGaugeField(GaugeField &out, double dt, const GaugeField &in, const GaugeField &mom, bool conj_mom, bool exact)
#define checkCudaError()
Definition: util_quda.h:161
__host__ __device__ ValueType conj(ValueType x)
Definition: complex_quda.h:130
QudaTune getTuning()
Query whether autotuning is enabled or not. Default is enabled but can be overridden by setting QUDA_...
Definition: util_quda.cpp:52
QudaPrecision Precision() const
QudaParity parity
Definition: covdev_test.cpp:54
unsigned long long bytes
Definition: blas_quda.cu:23