QUDA  v0.7.0
A library for QCD on GPUs
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
gauge_update_quda.cu
Go to the documentation of this file.
1 #include <cstdio>
2 #include <cstdlib>
3 #include <cuda.h>
4 #include <quda_internal.h>
5 #include <tune_quda.h>
6 #include <gauge_field.h>
7 #include <gauge_field_order.h>
8 #include <quda_matrix.h>
9 #include <float_vector.h>
10 #include <complex_quda.h>
11 
12 namespace quda {
13 
14 #ifdef GPU_GAUGE_TOOLS
15 
/**
 * Bundle of everything the gauge-update routines need, passed by value
 * to the kernel and by reference to the per-site compute function.
 *
 * Complex is only used to derive the matching real type; Gauge and Mom
 * are accessor types providing load()/save() for links and momenta.
 */
template <typename Complex, typename Gauge, typename Mom>
struct UpdateGaugeArg {
  typedef typename RealTypeId<Complex>::Type real;

  Gauge out;     // accessor the updated links are saved through
  Gauge in;      // accessor the current links are loaded from
  Mom momentum;  // accessor the momentum matrices are loaded from
  real dt;       // step size that scales the momentum
  int nDim;      // number of directions processed per site

  UpdateGaugeArg(const Gauge &out_, const Gauge &in_,
                 const Mom &momentum_, real dt_, int nDim_)
    : out(out_),
      in(in_),
      momentum(momentum_),
      dt(dt_),
      nDim(nDim_)
  { }
};
28 
/**
 * Overwrite the 3x3 matrix q, in place, with a matrix built from the
 * exponentials of three roots of a cubic formed from q's entries.
 *
 * Steps, as implemented below:
 *  1. form the coefficients a2 and a3 (a3 is the determinant of q),
 *  2. obtain three roots from a closed-form cubic solution,
 *  3. for every root build a "right" (wr*) and a "left" (wl*) vector
 *     whose first component is implicitly 1,
 *  4. accumulate exp(root) * right * conj(left) / normalisation over
 *     the three roots and store the nine entries back into q.
 *
 * NOTE(review): the closed form looks like it assumes q is traceless
 * (the caller removes the trace first) - confirm before reusing it
 * elsewhere. The code divides by q(2) and q(6) without checking them.
 *
 * @param q  matrix to exponentiate; replaced by the result
 * @param x  unused by this function
 */
template <typename complex, typename Cmplx>
__device__ __host__ void expsu3(Matrix<Cmplx,3> &q, int x) {
  typedef typename RealTypeId<Cmplx>::Type real;

  // coefficients of the cubic
  complex a2 = (q(3)*q(1)+q(7)*q(5)+q(6)*q(2) -
                (q(0)*q(4)+(q(0)+q(4))*q(8))) / (real)3.0 ;
  complex a3 = q(0)*q(4)*q(8) + q(1)*q(5)*q(6) + q(2)*q(3)*q(7) -
    q(6)*q(4)*q(2) - q(3)*q(1)*q(8) - q(0)*q(7)*q(5);

  // closed-form solution: cp is a complex cube root, cm its partner
  complex disc = sqrt(a3*a3-(real)4.*a2*a2*a2);
  complex cp = exp( log((real)0.5*(a3+disc)) / (real)3.0);
  complex cm = a2/cp;

  // the two non-trivial cube roots of unity
  complex r1 = exp( complex(0.0,1.0)*(real)(2.0*M_PI/3.0));
  complex r2 = exp(-complex(0.0,1.0)*(real)(2.0*M_PI/3.0));

  complex root[3];
  root[0]=cm+cp;
  root[1]=r1*cp+r2*cm;
  root[2]=r2*cp+r1*cm;

  // root-independent pieces of the right (zr*) and left (zl*) vectors;
  // q is only read here and is not written until the very end
  complex zr1=q(1)*q(6)-q(0)*q(7);
  complex zr2=q(3)*q(7)-q(4)*q(6);
  complex zl1=q(3)*q(2) - q(0)*q(5);
  complex zl2=q(1)*q(5) - q(4)*q(2);

  complex wl2[3], wl3[3];        // left-vector components, kept for the final sums
  complex y1[3], y2[3], y3[3];   // exp(root)-weighted right-vector components

  for (int k=0; k<3; k++) {
    complex al = root[k];

    // right vector for this root
    complex wr2 = (zr1+al*q(7)) / (zr2+al*q(6));
    complex wr3 = (al-q(0)-wr2*q(3))/q(6);

    // left vector for this root (stored conjugated)
    wl2[k] = conj((zl1+al*q(5))/(zl2+al*q(2)));
    wl3[k] = conj((al-q(0)-conj(wl2[k])*q(1))/q(2));

    // normalisation and exponential weight
    complex xn = (real)1. + wr2*conj(wl2[k]) + wr3*conj(wl3[k]);
    complex d = exp(root[k]);

    y1[k] = d/xn;
    y2[k] = wr2*d/xn;
    y3[k] = wr3*d/xn;
  }

  // assemble the result; summation order is roots 0, 1, 2
  q(0) = y1[0] + y1[1] + y1[2];
  q(1) = y2[0] + y2[1] + y2[2];
  q(2) = y3[0] + y3[1] + y3[2];
  q(3) = y1[0]*conj(wl2[0]) + y1[1]*conj(wl2[1]) + y1[2]*conj(wl2[2]);
  q(4) = y2[0]*conj(wl2[0]) + y2[1]*conj(wl2[1]) + y2[2]*conj(wl2[2]);
  q(5) = y3[0]*conj(wl2[0]) + y3[1]*conj(wl2[1]) + y3[2]*conj(wl2[2]);
  q(6) = y1[0]*conj(wl3[0]) + y1[1]*conj(wl3[1]) + y1[2]*conj(wl3[2]);
  q(7) = y2[0]*conj(wl3[0]) + y2[1]*conj(wl3[1]) + y2[2]*conj(wl3[2]);
  q(8) = y3[0]*conj(wl3[0]) + y3[1]*conj(wl3[1]) + y3[2]*conj(wl3[2]);
}
109 
/**
 * Replace b with a product of conj(a) and b, using the flat element
 * indexing of Matrix<Cmplx,3>:
 *
 *   out(i + 3*j) = sum_k conj(a(3*i + k)) * b(3*j + k),   i,j,k in 0..2
 *
 * The product is built in a temporary so that b can be overwritten.
 *
 * @param a  left factor (taken by value, entries are conjugated)
 * @param b  right factor on entry, the product on exit
 */
template <typename complex, typename Cmplx>
__device__ __host__ void vmcm2(Matrix<Cmplx,3> a, Matrix<Cmplx,3> &b) {
  complex prod[9];

  for (int j=0; j<3; j++) {
    for (int i=0; i<3; i++) {
      prod[i + 3*j] = conj(a(3*i))*b(3*j)+conj(a(3*i+1))*b(3*j+1) + conj(a(3*i+2))*b(3*j+2);
    }
  }

  // copy back only after every entry has been formed from the old b
  for (int n=0; n<9; n++) b(n) = prod[n];
}
128 
/**
 * Update all links attached to one checkerboard site.
 *
 * For each direction dir < arg.nDim the link and momentum are loaded,
 * one third of the momentum's trace is subtracted from each diagonal
 * entry, and the new link is written through arg.out:
 *
 *  - exact == false: the link is multiplied by an N-term nested
 *    polynomial in (dt * mom), i.e. the truncated series
 *    sum_{k=0..N} (dt*mom)^k / k! applied to the link;
 *  - exact == true: dt * mom is exponentiated with expsu3() and the
 *    link is multiplied from the left by the result.
 *
 * With conj_mom == true, conj(mom) (resp. conj of the exponential) is
 * used in place of mom.
 *
 * @param arg     accessors, step size and number of directions
 * @param x       checkerboard site index
 * @param parity  parity of the site
 */
template<typename Cmplx, typename Gauge, typename Mom, int N,
         bool conj_mom, bool exact>
__device__ __host__ void updateGaugeFieldCompute
(UpdateGaugeArg<Cmplx,Gauge,Mom> &arg, int x, int parity) {

  typedef typename RealTypeId<Cmplx>::Type real;
  Matrix<Cmplx,3> link, result, mom;

  for(int dir=0; dir<arg.nDim; ++dir){
    arg.in.load((real*)(link.data), x, dir, parity);
    arg.momentum.load((real*)(mom.data), x, dir, parity);

    // remove trace/3 from every diagonal entry of the momentum
    Cmplx trace = getTrace(mom);
    mom(0,0) -= trace/3.0;
    mom(1,1) -= trace/3.0;
    mom(2,2) -= trace/3.0;

    if (exact) {
      // exponentiate dt*mom, then rotate the link with it
      mom = arg.dt * mom;
      expsu3<complex<real> >(mom, x+dir+parity);

      if (conj_mom) {
        link = conj(mom) * link;
      } else {
        link = mom * link;
      }

      result = link;
    } else {
      // Nth order expansion of the exponential, evaluated from the
      // highest order term inwards
      result = link;

      if (conj_mom) {
        for(int order=N; order>0; order--)
          result = (arg.dt/order)*conj(mom)*result + link;
      } else {
        for(int order=N; order>0; order--)
          result = (arg.dt/order)*mom*result + link;
      }
    }

    arg.out.save((real*)(result.data), x, dir, parity);
  } // dir

}
174 
/**
 * Host-side driver: visits both parities and every checkerboard site
 * below arg.out.volumeCB, calling updateGaugeFieldCompute() on each.
 *
 * @param arg  accessors, step size and number of directions
 */
template<typename Cmplx, typename Gauge, typename Mom, int N,
         bool conj_mom, bool exact>
void updateGaugeField(UpdateGaugeArg<Cmplx,Gauge,Mom> arg) {

  const unsigned int sitesPerParity = arg.out.volumeCB;

  for (unsigned int par=0; par<2; par++) {
    for (unsigned int cb=0; cb<sitesPerParity; cb++) {
      updateGaugeFieldCompute<Cmplx,Gauge,Mom,N,conj_mom,exact>(arg, cb, par);
    }
  }
}
186 
/**
 * Device-side driver: one thread per (parity, checkerboard site).
 *
 * The flat thread index runs over 2*arg.out.volumeCB entries; the
 * first volumeCB of them map to parity 0 and the rest to parity 1.
 * Threads with an index at or beyond that range return immediately.
 *
 * @param arg  accessors, step size and number of directions
 */
template<typename Cmplx, typename Gauge, typename Mom, int N,
         bool conj_mom, bool exact>
__global__ void updateGaugeFieldKernel(UpdateGaugeArg<Cmplx,Gauge,Mom> arg) {
  int site = blockIdx.x*blockDim.x + threadIdx.x;
  if (site >= 2*arg.out.volumeCB) return;

  // split the flat index into parity and checkerboard site
  int parity = 0;
  if (site >= arg.out.volumeCB) {
    parity = 1;
    site -= arg.out.volumeCB;
  }

  updateGaugeFieldCompute<Cmplx,Gauge,Mom,N,conj_mom,exact>(arg, site, parity);
}
198 
/**
 * Tunable wrapper that runs the gauge update either as a CUDA kernel
 * or as the host loop, depending on the field location.
 *
 * Template parameters mirror those of updateGaugeFieldKernel():
 * N is the expansion order, conj_mom selects conj(mom), and exact
 * selects the expsu3() path instead of the series.
 */
template <typename Complex, typename Gauge, typename Mom, int N,
          bool conj_mom, bool exact>
class UpdateGaugeField : public Tunable {
private:
  UpdateGaugeArg<Complex,Gauge,Mom> arg;
  const GaugeField &meta; // meta data
  const QudaFieldLocation location; // location of the lattice fields

  // the kernel uses no shared memory
  unsigned int sharedBytesPerThread() const { return 0; }
  unsigned int sharedBytesPerBlock(const TuneParam &) const { return 0; }

  // one thread per (parity, checkerboard site)
  unsigned int minThreads() const { return 2*arg.in.volumeCB; }
  bool tuneGridDim() const { return false; }

public:
  /**
   * @param arg       accessors, step size and number of directions
   * @param meta      field whose VolString() is used in the tune key
   * @param location  selects the CUDA kernel or the host loop in apply()
   */
  UpdateGaugeField(const UpdateGaugeArg<Complex,Gauge,Mom> &arg,
                   const GaugeField &meta, QudaFieldLocation location)
    : arg(arg), meta(meta), location(location) {
    writeAuxString("threads=%d,prec=%lu,stride=%d",
                   2*arg.in.volumeCB, sizeof(Complex)/2, arg.in.stride);
  }
  virtual ~UpdateGaugeField() { }

  /**
   * Run the update.
   *
   * @param stream  stream the kernel is launched on when the fields
   *                are device resident; not used by the host path
   */
  void apply(const cudaStream_t &stream){
    if (location == QUDA_CUDA_FIELD_LOCATION) {
#if __COMPUTE_CAPABILITY__ >= 200
      TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
      // launch on the caller's stream rather than silently ignoring it
      updateGaugeFieldKernel<Complex,Gauge,Mom,N,conj_mom,exact>
        <<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
#else
      errorQuda("Not supported on pre-Fermi architecture");
#endif
    } else { // run the CPU code
      updateGaugeField<Complex,Gauge,Mom,N,conj_mom,exact>(arg);
    }
  } // apply

  void preTune(){}
  void postTune(){}

  /**
   * Floating point operation count of the N-term series update.
   * Computed in long long so the product cannot overflow int on
   * large lattices.
   */
  long long flops() const {
    const long long Nc = 3;
    const long long links = static_cast<long long>(arg.nDim)*2*arg.in.volumeCB;
    return links*N*(Nc*Nc*2 +                   // scalar-matrix multiply
                    (8*Nc*Nc*Nc - 2*Nc*Nc) +    // matrix-matrix multiply
                    Nc*Nc*2);                   // matrix-matrix addition
  }

  /**
   * Bytes moved: links in, links out and momentum for every direction
   * and site. The leading factor is widened first to avoid int overflow.
   */
  long long bytes() const {
    const long long links = static_cast<long long>(arg.nDim)*2*arg.in.volumeCB;
    return links*(arg.in.Bytes() + arg.out.Bytes() + arg.momentum.Bytes());
  }

  TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), aux); }
};
250 
/**
 * Turn the run-time flags conj_mom and exact into template arguments
 * and run the matching UpdateGaugeField instance on stream 0.
 *
 * @param out       accessor the updated links are written through
 * @param in        accessor the current links are read from
 * @param mom       accessor the momentum is read from
 * @param dt        step size
 * @param meta      field used for the tune key
 * @param conj_mom  use conj(mom) instead of mom
 * @param exact     use the expsu3() path instead of the series
 * @param location  where the fields live; also decides whether
 *                  checkCudaError() is called afterwards
 */
template <typename Float, typename Gauge, typename Mom>
void updateGaugeField(Gauge &out, const Gauge &in, const Mom &mom,
                      double dt, const GaugeField &meta, bool conj_mom, bool exact,
                      QudaFieldLocation location) {
  // degree of exponential expansion
  const int N = 8;

  typedef typename ComplexTypeId<Float>::Type Complex;

  // the argument bundle is the same for every template combination;
  // the number of directions is fixed to 4 here
  UpdateGaugeArg<Complex, Gauge, Mom> arg(out, in, mom, dt, 4);

  if (conj_mom) {
    if (exact) {
      UpdateGaugeField<Complex,Gauge,Mom,N,true,true> updateGauge(arg, meta, location);
      updateGauge.apply(0);
    } else {
      UpdateGaugeField<Complex,Gauge,Mom,N,true,false> updateGauge(arg, meta, location);
      updateGauge.apply(0);
    }
  } else {
    if (exact) {
      UpdateGaugeField<Complex,Gauge,Mom,N,false,true> updateGauge(arg, meta, location);
      updateGauge.apply(0);
    } else {
      UpdateGaugeField<Complex,Gauge,Mom,N,false,false> updateGauge(arg, meta, location);
      updateGauge.apply(0);
    }
  }

  if (location == QUDA_CUDA_FIELD_LOCATION) checkCudaError();

}
284 
/**
 * Pick the momentum accessor that matches mom's storage order and
 * forward to the accessor-level updateGaugeField().
 *
 * Supported: FLOAT2 order with reconstruct-10, and MILC order.
 * Anything else is reported through errorQuda().
 *
 * @param out       link accessor for the result
 * @param in        link accessor for the input
 * @param mom       momentum field (also passed on as meta data)
 * @param dt        step size
 * @param conj_mom  use conj(mom) instead of mom
 * @param exact     use the exact exponential instead of the series
 * @param location  where the fields live
 */
template <typename Float, typename Gauge>
void updateGaugeField(Gauge out, const Gauge &in, const GaugeField &mom,
                      double dt, bool conj_mom, bool exact,
                      QudaFieldLocation location) {
  switch (mom.Order()) {
  case QUDA_FLOAT2_GAUGE_ORDER:
    if (mom.Reconstruct() != QUDA_RECONSTRUCT_10) {
      errorQuda("Reconstruction type not supported");
    } else {
      // FIX ME - 11 is a misnomer to avoid confusion in template instantiation
      updateGaugeField<Float>(out, in, FloatNOrder<Float,18,2,11>(mom), dt, mom, conj_mom, exact, location);
    }
    break;

  case QUDA_MILC_GAUGE_ORDER:
    updateGaugeField<Float>(out, in, MILCOrder<Float,10>(mom), dt, mom, conj_mom, exact, location);
    break;

  default:
    errorQuda("Gauge Field order %d not supported", mom.Order());
    break;
  }

}
303 
/**
 * Check that out and in are compatible, build the link accessors that
 * match their order and reconstruction, and forward to the overload
 * that selects the momentum accessor.
 *
 * Supported combinations:
 *  - FLOAT2 order with no reconstruction or reconstruct-12,
 *  - FLOAT4 order with reconstruct-12,
 *  - MILC order.
 *
 * @param out       gauge field receiving the updated links
 * @param in        gauge field holding the current links
 * @param mom       momentum field
 * @param dt        step size
 * @param conj_mom  use conj(mom) instead of mom
 * @param exact     use the exact exponential instead of the series
 * @param location  where the fields live
 */
template <typename Float>
void updateGaugeField(GaugeField &out, const GaugeField &in, const GaugeField &mom,
                      double dt, bool conj_mom, bool exact,
                      QudaFieldLocation location) {

  const int Nc = 3;
  if (out.Ncolor() != Nc)
    errorQuda("Ncolor=%d not supported at this time", out.Ncolor());

  if (out.Order() != in.Order() || out.Reconstruct() != in.Reconstruct()) {
    errorQuda("Input and output gauge field ordering and reconstruction must match");
  }

  if (out.Order() == QUDA_FLOAT2_GAUGE_ORDER) {
    if (out.Reconstruct() == QUDA_RECONSTRUCT_NO) {
      updateGaugeField<Float>(FloatNOrder<Float, Nc*Nc*2, 2, 18>(out),
                              FloatNOrder<Float, Nc*Nc*2, 2, 18>(in),
                              mom, dt, conj_mom, exact, location);
    } else if (out.Reconstruct() == QUDA_RECONSTRUCT_12) {
      updateGaugeField<Float>(FloatNOrder<Float, Nc*Nc*2, 2, 12>(out),
                              FloatNOrder<Float, Nc*Nc*2, 2, 12>(in),
                              mom, dt, conj_mom, exact, location);
    } else {
      errorQuda("Reconstruction type not supported");
    }
  } else if (out.Order() == QUDA_FLOAT4_GAUGE_ORDER) {
    if (out.Reconstruct() == QUDA_RECONSTRUCT_12) {
      updateGaugeField<Float>(FloatNOrder<Float, Nc*Nc*2, 4, 12>(out),
                              FloatNOrder<Float, Nc*Nc*2, 4, 12>(in),
                              mom, dt, conj_mom, exact, location);
    } else {
      // report the offending reconstruction, not the field order
      errorQuda("Reconstruction type %d not supported", out.Reconstruct());
    }
  } else if (out.Order() == QUDA_MILC_GAUGE_ORDER) {
    updateGaugeField<Float>(MILCOrder<Float, Nc*Nc*2>(out),
                            MILCOrder<Float, Nc*Nc*2>(in),
                            mom, dt, conj_mom, exact, location);
  } else {
    errorQuda("Gauge Field order %d not supported", out.Order());
  }

}
346 #endif
347 
/**
 * Public entry point of the gauge update.
 *
 * Verifies that the three fields share precision and location, then
 * dispatches on precision (double or single) to the templated
 * implementation, using out.Location() as the execution location.
 * When the library is compiled without GPU_GAUGE_TOOLS the call only
 * reports an error.
 *
 * Note the argument order: dt comes second here, whereas the internal
 * overloads take it after mom.
 *
 * @param out       gauge field receiving the updated links
 * @param dt        step size
 * @param in        gauge field holding the current links
 * @param mom       momentum field
 * @param conj_mom  use conj(mom) instead of mom
 * @param exact     use the exact exponential instead of the series
 */
void updateGaugeField(GaugeField &out, double dt, const GaugeField& in,
                      const GaugeField& mom, bool conj_mom, bool exact)
{
#ifdef GPU_GAUGE_TOOLS
  if (out.Precision() != in.Precision() || out.Precision() != mom.Precision())
    errorQuda("Gauge and momentum fields must have matching precision");

  if (out.Location() != in.Location() || out.Location() != mom.Location())
    errorQuda("Gauge and momentum fields must have matching location");

  if (out.Precision() == QUDA_DOUBLE_PRECISION) {
    updateGaugeField<double>(out, in, mom, dt, conj_mom, exact, out.Location());
  } else if (out.Precision() == QUDA_SINGLE_PRECISION) {
    updateGaugeField<float>(out, in, mom, dt, conj_mom, exact, out.Location());
  } else {
    errorQuda("Precision %d not supported", out.Precision());
  }
#else
  errorQuda("Gauge tools are not built");
#endif

}
370 
371 } // namespace quda
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:20
__host__ __device__ ValueType exp(ValueType x)
Definition: complex_quda.h:85
#define errorQuda(...)
Definition: util_quda.h:73
QudaFieldLocation Location() const
__host__ __device__ ValueType sqrt(ValueType x)
Definition: complex_quda.h:105
std::complex< double > Complex
Definition: eig_variables.h:13
cudaStream_t * stream
QudaPrecision Precision() const
const QudaFieldLocation location
Definition: pack_test.cpp:46
int z1
Definition: llfat_core.h:814
cpuColorSpinorField * in
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
Definition: tune.cpp:271
int x[4]
__host__ __device__ ValueType log(ValueType x)
Definition: complex_quda.h:90
__device__ __host__ T getTrace(const Matrix< T, 3 > &a)
Definition: quda_matrix.h:378
enum QudaFieldLocation_s QudaFieldLocation
cpuColorSpinorField * out
int z2
Definition: llfat_core.h:816
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
Definition: complex_quda.h:843
void updateGaugeField(GaugeField &out, double dt, const GaugeField &in, const GaugeField &mom, bool conj_mom, bool exact)
#define checkCudaError()
Definition: util_quda.h:110
__host__ __device__ ValueType conj(ValueType x)
Definition: complex_quda.h:115
QudaTune getTuning()
Definition: util_quda.cpp:32
const QudaParity parity
Definition: dslash_test.cpp:29