QUDA  v1.1.0
A library for QCD on GPUs
gauge_update_quda.cu
Go to the documentation of this file.
1 #include <cstdio>
2 #include <cstdlib>
3 #include <quda_internal.h>
4 #include <tune_quda.h>
5 #include <gauge_field.h>
6 #include <gauge_field_order.h>
7 #include <quda_matrix.h>
8 #include <float_vector.h>
9 #include <complex_quda.h>
10 #include <instantiate.h>
11 
12 namespace quda {
13 
/**
   @brief Argument bundle handed to the gauge-update kernel.

   @tparam Float_ Real type, re-exported as Float
   @tparam nColor_ Number of colors, re-exported as nColor (the
   static_assert below only accepts 3)
   @tparam recon_u Reconstruct type used for the in/out accessors
   @tparam recon_m Reconstruct type used for the momentum accessor
   @tparam N_ Re-exported as N; compute() uses it as the order of
   the series expansion of the exponential
*/
template <typename Float_, int nColor_, QudaReconstructType recon_u, QudaReconstructType recon_m, int N_>
struct UpdateGaugeArg {
  using Float = Float_;
  using Gauge = typename gauge_mapper<Float, recon_u>::type;
  using Mom = typename gauge_mapper<Float, recon_m>::type;

  static constexpr int nColor = nColor_;
  static constexpr int N = N_;
  static_assert(nColor == 3, "Only nColor=3 enabled at this time");

  Gauge out; // accessor the updated links are written through
  Gauge in;  // accessor the current links are read through
  Mom mom;   // accessor the momentum is read through
  Float dt;  // scalar that multiplies the momentum in the exponent
  int nDim;  // number of directions compute() loops over

  UpdateGaugeArg(GaugeField &out, const GaugeField &in, const GaugeField &mom, Float dt, int nDim) :
    out(out),
    in(in),
    mom(mom),
    dt(dt),
    nDim(nDim)
  {
  }
};
30 
/**
   @brief Update every link attached to one site: out = exp(dt * P) * in,
   where P is the momentum with its trace removed.

   @tparam conj_mom If true, conj() of the momentum (non-exact path) or
   of the exponentiated momentum (exact path) is used in place of the
   matrix itself
   @tparam exact If true the exponential is obtained from expsu3();
   otherwise an order-Arg::N series expansion is evaluated
   @param arg Kernel argument (accessors in/mom/out, dt, nDim)
   @param x Site index forwarded to the accessors
   @param parity Parity index forwarded to the accessors
*/
template <bool conj_mom, bool exact, typename Arg>
__device__ __host__ void compute(Arg &arg, int x, int parity)
{
  using Float = typename Arg::Float;
  using Complex = complex<Float>;
  using Link = Matrix<Complex, Arg::nColor>;

  Link u, p, updated;

  for (int dir = 0; dir < arg.nDim; ++dir) {
    u = arg.in(dir, x, parity);
    p = arg.mom(dir, x, parity);

    // subtract trace/nColor from the diagonal so that p is traceless
    const Complex tr = getTrace(p);
    for (int c = 0; c < Arg::nColor; c++) p(c, c) -= tr / static_cast<Float>(Arg::nColor);

    if (exact) {
      p = arg.dt * p;
      // NOTE(review): expsu3 presumably replaces p by its matrix
      // exponential in place - confirm against its definition
      expsu3<Float>(p);

      // here the conjugation acts on the exponentiated matrix
      if (conj_mom) {
        updated = conj(p) * u;
      } else {
        updated = p * u;
      }
    } else {
      // conj() does not depend on the loop counter, so take it once up front
      if (conj_mom) p = conj(p);

      // Nth order expansion of exponential, evaluated innermost term
      // first: result <- (dt/r) * p * result + u for r = N, ..., 1
      updated = u;
      for (int r = Arg::N; r > 0; r--) updated = (arg.dt / r) * p * updated + u;
    }

    arg.out(dir, x, parity) = updated;
  } // dir
}
72 
/**
   @brief Kernel entry point: one thread per (site, parity) pair.

   The x thread index selects the site and the y thread index selects
   the parity.  Threads whose site index is not below
   arg.out.volumeCB do nothing.  No bound is applied to the parity
   index here, so the launch geometry has to keep it in range.

   @tparam conj_mom Forwarded to compute()
   @tparam exact Forwarded to compute()
   @param arg Kernel argument, passed by value
*/
template <bool conj_mom, bool exact, typename Arg>
__global__ void updateGaugeFieldKernel(Arg arg)
{
  const int x_cb = blockIdx.x * blockDim.x + threadIdx.x;
  const int parity = blockIdx.y * blockDim.y + threadIdx.y;

  if (x_cb < arg.out.volumeCB) compute<conj_mom, exact>(arg, x_cb, parity);
}
81 
/**
   @brief Tunable wrapper that launches updateGaugeFieldKernel.

   @tparam Arg Kernel argument type (an UpdateGaugeArg instantiation)
   @tparam conj_mom Forwarded to the kernel
   @tparam exact Forwarded to the kernel
*/
template <typename Arg, bool conj_mom, bool exact>
class UpdateGaugeField : public TunableVectorY {
  Arg &arg;
  const GaugeField &meta; // meta data

  bool tuneGridDim() const { return false; }
  unsigned int minThreads() const { return arg.in.volumeCB; }

public:
  /**
     @param arg Kernel argument; held by reference, so it must outlive this object
     @param meta Field whose VolString()/AuxString() form the tune key
  */
  UpdateGaugeField(Arg &arg, const GaugeField &meta) :
    TunableVectorY(2), // y-dimension of length 2: the kernel reads its parity index from y
    arg(arg),
    meta(meta)
  {
  }

  /**
     @brief Obtain launch parameters from the autotuner and launch the kernel
     @param stream Stream the kernel is launched on
  */
  void apply(const qudaStream_t &stream)
  {
    TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
    qudaLaunchKernel(updateGaugeFieldKernel<conj_mom, exact, Arg>, tp, stream, arg);
  } // apply

  /**
     @brief Floating-point operation count reported to the tuner.

     The product is accumulated in long long from the first factor on.
     Evaluated in int, nDim * 2 * volumeCB * N * 234 exceeds INT_MAX
     already for a 24^4 lattice.

     NOTE(review): the count always models the N-term expansion, also
     when exact == true.
  */
  long long flops() const
  {
    const long long Nc = Arg::nColor;
    // nDim directions times the two y-indices (parities) times sites per parity
    const long long links = static_cast<long long>(arg.nDim) * 2 * arg.in.volumeCB;
    const long long per_term = Nc * Nc * 2 +             // scalar-matrix multiply
      (8 * Nc * Nc * Nc - 2 * Nc * Nc) +                 // matrix-matrix multiply
      Nc * Nc * 2;                                       // matrix-matrix addition
    return links * Arg::N * per_term;
  }

  /**
     @brief Memory traffic reported to the tuner: one read of in and
     mom and one write of out per link.  The link count is promoted to
     long long before any multiplication, as in flops().
  */
  long long bytes() const
  {
    const long long links = static_cast<long long>(arg.nDim) * 2 * arg.in.volumeCB;
    return links * (arg.in.Bytes() + arg.out.Bytes() + arg.mom.Bytes());
  }

  TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), meta.AuxString()); }
};
112 
/**
   @brief Functor instantiated per (Float, nColor, recon_u) that builds
   the kernel argument and turns the run-time conj_mom/exact flags into
   the matching compile-time kernel variant.
*/
template <typename Float, int nColor, QudaReconstructType recon_u> struct UpdateGauge
{
  /**
     @param out Field receiving the updated links
     @param in Field holding the current links (also used as tuning meta data)
     @param mom Momentum field; must use QUDA_RECONSTRUCT_10
     @param dt Scalar multiplying the momentum in the exponent
     @param conj_mom Selects the conj_mom kernel variants
     @param exact Selects the exact kernel variants
  */
  UpdateGauge(GaugeField &out, const GaugeField &in, const GaugeField &mom, double dt, bool conj_mom, bool exact)
  {
    if (mom.Reconstruct() != QUDA_RECONSTRUCT_10) errorQuda("Reconstruction type %d not supported", mom.Reconstruct());
    constexpr QudaReconstructType recon_m = QUDA_RECONSTRUCT_10;
    constexpr int N = 8; // degree of exponential expansion
    UpdateGaugeArg<Float, nColor, recon_u, recon_m, N> arg(out, in, mom, dt, 4);

    // map the two run-time flags onto the four template instantiations
    if (conj_mom && exact) {
      launch<true, true>(arg, in);
    } else if (conj_mom) {
      launch<true, false>(arg, in);
    } else if (exact) {
      launch<false, true>(arg, in);
    } else {
      launch<false, false>(arg, in);
    }
  }

private:
  /** Construct the tunable for one (conj, exact) variant and run it on stream 0. */
  template <bool conj_, bool exact_, typename Arg> static void launch(Arg &arg, const GaugeField &meta)
  {
    UpdateGaugeField<Arg, conj_, exact_> updateGauge(arg, meta);
    updateGauge.apply(0);
  }
};
140 
/**
   @brief Host entry point for the gauge-field update.

   When GPU_GAUGE_TOOLS is defined, the fields are validated with
   checkPrecision / checkLocation / checkReconstruct and the update is
   dispatched through instantiate<UpdateGauge, ReconstructNo12>.
   Otherwise errorQuda() is called.

   @param out Field receiving the updated links
   @param dt Scalar multiplying the momentum in the exponent
   @param in Field holding the current links
   @param mom Momentum field
   @param conj_mom Selects the conj_mom kernel variants
   @param exact Selects the exact kernel variants
*/
void updateGaugeField(GaugeField &out, double dt, const GaugeField &in, const GaugeField &mom, bool conj_mom, bool exact)
{
#ifdef GPU_GAUGE_TOOLS
  checkPrecision(out, in, mom);
  checkLocation(out, in, mom);
  checkReconstruct(out, in);
  instantiate<UpdateGauge, ReconstructNo12>(out, in, mom, dt, conj_mom, exact);
#else
  // message corrected from "not build" to "not built"
  errorQuda("Gauge tools are not built");
#endif
}
152 
153 } // namespace quda