QUDA  0.9.0
gauge_update_quda.cu
Go to the documentation of this file.
1 #include <cstdio>
2 #include <cstdlib>
3 #include <cuda.h>
4 #include <quda_internal.h>
5 #include <tune_quda.h>
6 #include <gauge_field.h>
7 #include <gauge_field_order.h>
8 #include <quda_matrix.h>
9 #include <float_vector.h>
10 #include <complex_quda.h>
11 
12 namespace quda {
13 
14 #ifdef GPU_GAUGE_TOOLS
15 
/**
   @brief Parameter bundle handed (by value) to the gauge-update kernel
   and to its CPU counterpart.

   @tparam Float Real type of the time step
   @tparam Gauge Accessor type used for both the input and output links
   @tparam Mom Accessor type used for the momentum field
*/
template <typename Float, typename Gauge, typename Mom>
struct UpdateGaugeArg {
  Gauge out;     // accessor the updated links are saved through
  Gauge in;      // accessor the current links are loaded through
  Mom momentum;  // accessor the momentum matrices are loaded through
  Float dt;      // step size multiplying the momentum
  int nDim;      // number of link directions looped over per site

  /** Copies every argument into the member of the same name. */
  UpdateGaugeArg(const Gauge &out, const Gauge &in,
                 const Mom &momentum, Float dt, int nDim)
    : out(out),
      in(in),
      momentum(momentum),
      dt(dt),
      nDim(nDim)
  { }
};
27 
/**
   @brief In-place exponential of a 3x3 complex matrix: q is overwritten
   with the result.

   The code solves the cubic w^3 = 3*a2*w + a3 in closed form for three
   values w1[k], builds two sets of ratio vectors (wr*, wl*) from the
   entries of q for each w1[k], and finally sums exp(w1[k]) weighted by
   those vectors.  NOTE(review): this looks like a spectral
   decomposition of a traceless matrix (the caller removes the trace
   beforehand) -- confirm against the algorithm's reference.

   The routine divides by q(6), q(2) and by (z2 + al*q(6|2)) without any
   guard, so inputs that make those vanish are not handled here.

   @param q Matrix to exponentiate; replaced by the result
   @param x Unused by this routine
*/
template <typename Float>
__device__ __host__ void expsu3(Matrix<complex<Float>,3> &q, int x) {
  typedef complex<Float> Complex;

  // a3 is the determinant expansion of q; a2 collects the pairwise
  // products of entries (off-diagonal minus diagonal), divided by three
  Complex a2 = (q(3)*q(1)+q(7)*q(5)+q(6)*q(2) -
                (q(0)*q(4)+(q(0)+q(4))*q(8))) / (Float)3.0 ;
  Complex a3 = q(0)*q(4)*q(8) + q(1)*q(5)*q(6) + q(2)*q(3)*q(7) -
    q(6)*q(4)*q(2) - q(3)*q(1)*q(8) - q(0)*q(7)*q(5);

  // closed-form cube root: cp^3 = (a3 + sqrt(a3^2 - 4 a2^3))/2, cp*cm = a2
  Complex sg2h3 = sqrt(a3*a3-(Float)4.*a2*a2*a2);
  Complex cp = exp( log((Float)0.5*(a3+sg2h3)) / (Float)3.0);
  Complex cm = a2/cp;

  // the two non-trivial cube roots of unity
  Complex r1 = exp( Complex(0.0,1.0)*(Float)(2.0*M_PI/3.0));
  Complex r2 = exp(-Complex(0.0,1.0)*(Float)(2.0*M_PI/3.0));

  // the three roots of the cubic
  Complex w1[3];
  w1[0]=cm+cp;
  w1[1]=r1*cp+r2*cm;
  w1[2]=r2*cp+r1*cm;

  // "right" ratios, one pair per root (wr2[k], wr3[k])
  Complex wr2[3], wr3[3];
  {
    const Complex z1 = q(1)*q(6)-q(0)*q(7);
    const Complex z2 = q(3)*q(7)-q(4)*q(6);
    for (int k=0; k<3; k++) {
      const Complex al = w1[k];
      wr2[k] = (z1+al*q(7)) / (z2+al*q(6));
      wr3[k] = (al-q(0)-wr2[k]*q(3))/q(6);
    }
  }

  // "left" ratios, stored conjugated, one pair per root (wl2[k], wl3[k])
  Complex wl2[3], wl3[3];
  {
    const Complex z1 = q(3)*q(2) - q(0)*q(5);
    const Complex z2 = q(1)*q(5) - q(4)*q(2);
    for (int k=0; k<3; k++) {
      const Complex al = w1[k];
      wl2[k] = conj((z1+al*q(5))/(z2+al*q(2)));
      wl3[k] = conj((al-q(0)-conj(wl2[k])*q(1))/q(2));
    }
  }

  // exp(root) scaled by the left/right normalisation, per root
  Complex y1[3], y2[3], y3[3];
  for (int k=0; k<3; k++) {
    const Complex xn = (Float)1. + wr2[k]*conj(wl2[k]) + wr3[k]*conj(wl3[k]);
    const Complex d = exp(w1[k]);
    y1[k] = d/xn;
    y2[k] = wr2[k]*d/xn;
    y3[k] = wr3[k]*d/xn;
  }

  // every read of q happened above, so it is safe to overwrite it now
  q(0) = y1[0] + y1[1] + y1[2];
  q(1) = y2[0] + y2[1] + y2[2];
  q(2) = y3[0] + y3[1] + y3[2];
  q(3) = y1[0]*conj(wl2[0]) + y1[1]*conj(wl2[1]) + y1[2]*conj(wl2[2]);
  q(4) = y2[0]*conj(wl2[0]) + y2[1]*conj(wl2[1]) + y2[2]*conj(wl2[2]);
  q(5) = y3[0]*conj(wl2[0]) + y3[1]*conj(wl2[1]) + y3[2]*conj(wl2[2]);
  q(6) = y1[0]*conj(wl3[0]) + y1[1]*conj(wl3[1]) + y1[2]*conj(wl3[2]);
  q(7) = y2[0]*conj(wl3[0]) + y2[1]*conj(wl3[1]) + y2[2]*conj(wl3[2]);
  q(8) = y3[0]*conj(wl3[0]) + y3[1]*conj(wl3[1]) + y3[2]*conj(wl3[2]);
}
108 
/**
   @brief Update every link attached to one checkerboard site.

   For each direction the link and momentum are loaded, the momentum has
   a third of its trace subtracted from each diagonal entry, and the new
   link is formed as (exponential of dt*mom) * link.  The exponential is
   either the closed form from expsu3 (exact == true) or an Nth-order
   nested polynomial (exact == false).  When conj_mom is set, conj() is
   applied to the momentum-side matrix before it multiplies the link.

   @param arg Accessors, step size and direction count
   @param x Checkerboard site index
   @param parity Parity of the site
*/
template<typename Float, typename Gauge, typename Mom, int N,
         bool conj_mom, bool exact>
__device__ __host__ void updateGaugeFieldCompute
(UpdateGaugeArg<Float,Gauge,Mom> &arg, int x, int parity) {
  typedef complex<Float> Complex;

  Matrix<Complex,3> link, result, mom;
  for (int dir = 0; dir < arg.nDim; ++dir) {
    arg.in.load((Float*)(link.data), x, dir, parity);
    arg.momentum.load((Float*)(mom.data), x, dir, parity);

    // subtract trace/3 from each diagonal element
    const Complex third = getTrace(mom)/static_cast<Float>(3.0);
    mom(0,0) -= third;
    mom(1,1) -= third;
    mom(2,2) -= third;

    if (exact) {
      // closed-form exponential of dt*mom, computed in place in mom
      mom = arg.dt * mom;
      expsu3<Float>(mom, x+dir+parity);

      if (conj_mom) {
        result = conj(mom) * link;
      } else {
        result = mom * link;
      }
    } else {
      // Nth order expansion of exponential, evaluated innermost-first
      result = link;
      for (int r = N; r > 0; r--) {
        if (conj_mom) {
          result = (arg.dt/r)*conj(mom)*result + link;
        } else {
          result = (arg.dt/r)*mom*result + link;
        }
      }
    }

    arg.out.save((Float*)(result.data), x, dir, parity);
  } // dir
}
153 
/**
   @brief Host-side driver: applies updateGaugeFieldCompute to every
   checkerboard site of both parities, one after another.

   @param arg Accessors, step size and direction count (taken by value)
*/
template<typename Float, typename Gauge, typename Mom, int N,
         bool conj_mom, bool exact>
void updateGaugeField(UpdateGaugeArg<Float,Gauge,Mom> arg) {
  // parity is the outer loop, checkerboard site the inner one
  for (unsigned int parity = 0; parity < 2; ++parity)
    for (int site = 0; site < arg.out.volumeCB; ++site)
      updateGaugeFieldCompute<Float,Gauge,Mom,N,conj_mom,exact>(arg, site, parity);
}
165 
/**
   @brief Device entry point: one thread per (parity, checkerboard site).

   The flat thread index blockIdx.x*blockDim.x + threadIdx.x is mapped
   so that indices below volumeCB are parity 0 and the next volumeCB
   indices are parity 1; threads beyond 2*volumeCB do nothing.

   @param arg Accessors, step size and direction count
*/
template<typename Float, typename Gauge, typename Mom, int N,
         bool conj_mom, bool exact>
__global__ void updateGaugeFieldKernel(UpdateGaugeArg<Float,Gauge,Mom> arg) {
  int tid = blockIdx.x*blockDim.x + threadIdx.x;

  if (tid < 2*arg.out.volumeCB) {
    // split the flat index into parity and checkerboard offset
    const int parity = (tid < arg.out.volumeCB) ? 0 : 1;
    const int x_cb = tid - parity*arg.out.volumeCB;

    updateGaugeFieldCompute<Float,Gauge,Mom,N,conj_mom,exact>(arg, x_cb, parity);
  }
}
176 
/**
   @brief Tunable wrapper that runs the gauge update either as a CUDA
   kernel (autotuned launch geometry) or as the host loop, depending on
   the requested field location.
*/
template <typename Float, typename Gauge, typename Mom, int N,
          bool conj_mom, bool exact>
class UpdateGaugeField : public Tunable {
private:
  UpdateGaugeArg<Float,Gauge,Mom> arg;
  const GaugeField &meta; // meta data
  const QudaFieldLocation location; // location of the lattice fields

  // the kernel uses no shared memory
  unsigned int sharedBytesPerThread() const { return 0; }
  unsigned int sharedBytesPerBlock(const TuneParam &) const { return 0; }

  // one thread per checkerboard site and parity
  unsigned int minThreads() const { return 2*arg.in.volumeCB; }
  bool tuneGridDim() const { return false; }

  // Number of (site, direction) pairs processed, held in 64 bits so the
  // products formed in flops() and bytes() cannot overflow int.
  long long linkCount() const {
    return static_cast<long long>(arg.nDim) * 2 * arg.in.volumeCB;
  }

public:
  UpdateGaugeField(const UpdateGaugeArg<Float,Gauge,Mom> &arg,
                   const GaugeField &meta, QudaFieldLocation location)
    : arg(arg), meta(meta), location(location) {
    writeAuxString("threads=%d,prec=%lu,stride=%d",
                   2*arg.in.volumeCB, sizeof(Float), arg.in.stride);
  }
  virtual ~UpdateGaugeField() { }

  /**
     @brief Run the update.
     @param stream CUDA stream the kernel is launched on (ignored for
     host fields)
  */
  void apply(const cudaStream_t &stream){
    if (location == QUDA_CUDA_FIELD_LOCATION) {
      TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
      // launch on the caller's stream rather than always on the default one
      updateGaugeFieldKernel<Float,Gauge,Mom,N,conj_mom,exact>
        <<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
    } else { // run the CPU code
      updateGaugeField<Float,Gauge,Mom,N,conj_mom,exact>(arg);
    }
  } // apply

  /** Flop count of the N-term expansion over all links. */
  long long flops() const {
    const int Nc = 3;
    const long long perTerm = Nc*Nc*2 +                 // scalar-matrix multiply
                              (8*Nc*Nc*Nc - 2*Nc*Nc) +  // matrix-matrix multiply
                              Nc*Nc*2;                  // matrix-matrix addition
    return linkCount() * N * perTerm;
  }

  /** Byte count reported to the tuner, summed over in, out and momentum. */
  long long bytes() const {
    const long long perLink = static_cast<long long>
      (arg.in.Bytes() + arg.out.Bytes() + arg.momentum.Bytes());
    return linkCount() * perLink;
  }

  TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), aux); }
};
221 
/**
   @brief Turn the run-time conj_mom / exact flags into template
   parameters and execute the matching UpdateGaugeField instance.

   The expansion order is fixed to 8 and the number of directions to 4.
   After a device-side update, checkCudaError() is called.

   @param out Accessor for the updated links
   @param in Accessor for the current links
   @param mom Accessor for the momentum
   @param dt Step size
   @param meta Gauge field passed to UpdateGaugeField as meta data
   @param conj_mom Selects the conj_mom=true instantiation
   @param exact Selects the exact=true instantiation
   @param location Where the fields live (selects kernel or host loop)
*/
template <typename Float, typename Gauge, typename Mom>
void updateGaugeField(Gauge &out, const Gauge &in, const Mom &mom,
                      double dt, const GaugeField &meta, bool conj_mom, bool exact,
                      QudaFieldLocation location) {
  // degree of exponential expansion
  const int N = 8;

  // the argument bundle is the same for all four instantiations
  UpdateGaugeArg<Float, Gauge, Mom> arg(out, in, mom, dt, 4);

  if (conj_mom && exact) {
    UpdateGaugeField<Float,Gauge,Mom,N,true,true> updater(arg, meta, location);
    updater.apply(0);
  } else if (conj_mom) {
    UpdateGaugeField<Float,Gauge,Mom,N,true,false> updater(arg, meta, location);
    updater.apply(0);
  } else if (exact) {
    UpdateGaugeField<Float,Gauge,Mom,N,false,true> updater(arg, meta, location);
    updater.apply(0);
  } else {
    UpdateGaugeField<Float,Gauge,Mom,N,false,false> updater(arg, meta, location);
    updater.apply(0);
  }

  if (location == QUDA_CUDA_FIELD_LOCATION) checkCudaError();
}
254 
/**
   @brief Pick the momentum accessor that matches the momentum field's
   order (and reconstruction) and forward to the accessor-level update.

   Supported: FLOAT2 order with RECONSTRUCT_10, and MILC order.
   Anything else is reported through errorQuda.

   @param out Accessor for the updated links
   @param in Accessor for the current links
   @param mom Momentum field (also forwarded as the meta-data field)
   @param dt Step size
   @param conj_mom Forwarded unchanged
   @param exact Forwarded unchanged
   @param location Forwarded unchanged
*/
template <typename Float, typename Gauge>
void updateGaugeField(Gauge out, const Gauge &in, const GaugeField &mom,
                      double dt, bool conj_mom, bool exact,
                      QudaFieldLocation location) {
  switch (mom.Order()) {
  case QUDA_FLOAT2_GAUGE_ORDER:
    if (mom.Reconstruct() != QUDA_RECONSTRUCT_10) {
      errorQuda("Reconstruction type not supported");
    } else {
      // FIX ME - the 11 is a misnomer, kept to avoid confusion in template instantiation
      typedef gauge::FloatNOrder<Float,18,2,11> MomAccessor;
      updateGaugeField<Float>(out, in, MomAccessor(mom), dt, mom, conj_mom, exact, location);
    }
    break;

  case QUDA_MILC_GAUGE_ORDER:
    {
      typedef gauge::MILCOrder<Float,10> MomAccessor;
      updateGaugeField<Float>(out, in, MomAccessor(mom), dt, mom, conj_mom, exact, location);
    }
    break;

  default:
    errorQuda("Gauge Field order %d not supported", mom.Order());
  }
}
273 
/**
   @brief Validate the gauge fields and choose the link accessor that
   matches the output field's order and reconstruction.

   Requires Ncolor == 3 and identical order/reconstruction for in and
   out.  Native fields are supported with RECONSTRUCT_NO and
   RECONSTRUCT_12; non-native fields only in MILC order.  Unsupported
   combinations are reported through errorQuda.

   @param out Field receiving the updated links
   @param in Field holding the current links
   @param mom Momentum field
   @param dt Step size
   @param conj_mom Forwarded unchanged
   @param exact Forwarded unchanged
   @param location Forwarded unchanged
*/
template <typename Float>
void updateGaugeField(GaugeField &out, const GaugeField &in, const GaugeField &mom,
                      double dt, bool conj_mom, bool exact,
                      QudaFieldLocation location) {

  const int Nc = 3;
  if (out.Ncolor() != Nc)
    errorQuda("Ncolor=%d not supported at this time", out.Ncolor());

  if (!(out.Order() == in.Order() && out.Reconstruct() == in.Reconstruct())) {
    errorQuda("Input and output gauge field ordering and reconstruction must match");
  }

  if (!out.isNative()) {
    // non-native layouts: only MILC order is handled
    if (out.Order() == QUDA_MILC_GAUGE_ORDER) {
      typedef gauge::MILCOrder<Float, Nc*Nc*2> G;
      updateGaugeField<Float>(G(out), G(in), mom, dt, conj_mom, exact, location);
    } else {
      errorQuda("Gauge Field order %d not supported", out.Order());
    }
  } else {
    // native layouts: the accessor type depends on the reconstruction
    switch (out.Reconstruct()) {
    case QUDA_RECONSTRUCT_NO:
      {
        typedef typename gauge_mapper<Float,QUDA_RECONSTRUCT_NO>::type G;
        updateGaugeField<Float>(G(out), G(in), mom, dt, conj_mom, exact, location);
      }
      break;

    case QUDA_RECONSTRUCT_12:
      {
        typedef typename gauge_mapper<Float,QUDA_RECONSTRUCT_12>::type G;
        updateGaugeField<Float>(G(out), G(in), mom, dt, conj_mom, exact, location);
      }
      break;

    default:
      errorQuda("Reconstruction type not supported");
    }
  }

}
306 #endif
307 
/**
   @brief Public entry point: evolve the gauge field by one step using
   the momentum field.

   Checks that out, in and mom agree in precision and location, then
   dispatches on the precision of out (double or single).  When the
   library is compiled without GPU_GAUGE_TOOLS the call only reports an
   error.

   @param out Field receiving the updated links
   @param dt Step size
   @param in Field holding the current links
   @param mom Momentum field
   @param conj_mom Forwarded to the templated implementation
   @param exact Forwarded to the templated implementation
*/
void updateGaugeField(GaugeField &out, double dt, const GaugeField& in,
                      const GaugeField& mom, bool conj_mom, bool exact)
{
#ifdef GPU_GAUGE_TOOLS
  if (out.Precision() != in.Precision() || out.Precision() != mom.Precision())
    errorQuda("Gauge and momentum fields must have matching precision");

  if (out.Location() != in.Location() || out.Location() != mom.Location())
    errorQuda("Gauge and momentum fields must have matching location");

  // restored opening branch of the precision dispatch (it was missing,
  // leaving a dangling "} else if")
  if (out.Precision() == QUDA_DOUBLE_PRECISION) {
    updateGaugeField<double>(out, in, mom, dt, conj_mom, exact, out.Location());
  } else if (out.Precision() == QUDA_SINGLE_PRECISION) {
    updateGaugeField<float>(out, in, mom, dt, conj_mom, exact, out.Location());
  } else {
    errorQuda("Precision %d not supported", out.Precision());
  }
#else
  errorQuda("Gauge tools are not build");
#endif

}
330 
331 } // namespace quda
dim3 dim3 blockDim
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:20
__host__ __device__ ValueType exp(ValueType x)
Definition: complex_quda.h:85
#define errorQuda(...)
Definition: util_quda.h:90
__host__ __device__ ValueType sqrt(ValueType x)
Definition: complex_quda.h:105
std::complex< double > Complex
Definition: eig_variables.h:13
cudaStream_t * stream
cpuColorSpinorField * in
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
Definition: tune.cpp:603
Main header file for host and device accessors to GaugeFields.
QudaFieldLocation Location() const
__host__ __device__ ValueType log(ValueType x)
Definition: complex_quda.h:90
__device__ __host__ T getTrace(const Matrix< T, 3 > &a)
Definition: quda_matrix.h:305
enum QudaFieldLocation_s QudaFieldLocation
cpuColorSpinorField * out
unsigned long long flops
Definition: blas_quda.cu:42
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
Definition: complex_quda.h:880
void updateGaugeField(GaugeField &out, double dt, const GaugeField &in, const GaugeField &mom, bool conj_mom, bool exact)
#define checkCudaError()
Definition: util_quda.h:129
__host__ __device__ ValueType conj(ValueType x)
Definition: complex_quda.h:115
QudaTune getTuning()
Query whether autotuning is enabled or not. Default is enabled but can be overridden by setting QUDA_...
Definition: util_quda.cpp:51
QudaPrecision Precision() const
QudaParity parity
Definition: covdev_test.cpp:53
unsigned long long bytes
Definition: blas_quda.cu:43