QUDA  v1.1.0
A library for QCD on GPUs
unitarize_links_quda.cu
Go to the documentation of this file.
1 #include <cstdlib>
2 #include <cstdio>
3 
4 #include <gauge_field.h>
5 #include <gauge_field_order.h>
6 #include <tune_quda.h>
7 #include <quda_matrix.h>
8 #include <unitarization_links.h>
9 #include <su3_project.cuh>
10 #include <index_helper.cuh>
11 #include <instantiate.h>
12 #include <color_spinor.h>
13 
14 namespace quda {
15 
16 namespace {
17 #include <svd_quda.h>
18 }
19 
// Fallback definitions of pi and 2*pi/3, used by the eigenvalue solve in reciprocalRoot().
// FL_UNITARIZE_PI23 expands to a parenthesized expression so it composes safely inside
// larger expressions (e.g. division by it).
#ifndef FL_UNITARIZE_PI
#define FL_UNITARIZE_PI 3.14159265358979323846
#endif
#ifndef FL_UNITARIZE_PI23
#define FL_UNITARIZE_PI23 (FL_UNITARIZE_PI*0.66666666666666666666)
#endif


// Suppress compiler warnings about unused variables when GPU_UNITARIZE is not set.
// When we switch to C++17 consider [[maybe_unused]].

// Number of Newton iterations performed per link by unitarizeLinksCPU().
__attribute__((unused)) static const int max_iter_newton = 20;
// Iteration count forwarded into UnitarizeLinksArg::max_iter by UnitarizeLinks.
__attribute__((unused)) static const int max_iter = 20;

// Parameters of the device unitarization. They are overwritten by
// setUnitarizeLinksConstants() and copied into the kernel argument when a
// UnitarizeLinks instance is constructed.
__attribute__((unused)) static double unitarize_eps = 1e-14;
__attribute__((unused)) static double max_error = 1e-10;
__attribute__((unused)) static int reunit_allow_svd = 1;
__attribute__((unused)) static int reunit_svd_only = 0;
__attribute__((unused)) static double svd_rel_error = 1e-6;
__attribute__((unused)) static double svd_abs_error = 1e-6;
39 
/**
   Argument struct for the DoUnitarizedLink kernel: accessors for the input and
   output gauge fields plus a copy of the unitarization parameters.  The kernel
   takes this struct by value.
*/
template <typename Float_, int nColor_, QudaReconstructType recon_>
struct UnitarizeLinksArg {
  using Float = Float_;
  static constexpr int nColor = nColor_;
  static constexpr QudaReconstructType recon = recon_;
  typedef typename gauge_mapper<Float,recon>::type Gauge;
  Gauge out;        // accessor for the unitarized (output) links
  const Gauge in;   // accessor for the input links

  int threads; // number of active threads required (checkerboard volume)
  int X[4]; // grid dimensions
  int *fails; // counter incremented once per link that fails unitarization
  const int max_iter;
  const double unitarize_eps;
  const double max_error;
  const int reunit_allow_svd;
  const int reunit_svd_only;
  const double svd_rel_error;
  const double svd_abs_error;
  const static bool check_unitarization = true;

  // Initializers are listed in declaration order, which is the order in which
  // members are actually initialized.
  UnitarizeLinksArg(GaugeField &out, const GaugeField &in, int* fails, int max_iter,
                    double unitarize_eps, double max_error, int reunit_allow_svd,
                    int reunit_svd_only, double svd_rel_error, double svd_abs_error) :
    out(out),
    in(in),
    threads(in.VolumeCB()),
    fails(fails),
    max_iter(max_iter),
    unitarize_eps(unitarize_eps),
    max_error(max_error),
    reunit_allow_svd(reunit_allow_svd),
    reunit_svd_only(reunit_svd_only),
    svd_rel_error(svd_rel_error),
    svd_abs_error(svd_abs_error)
  {
    for (int dir=0; dir<4; ++dir) X[dir] = in.X()[dir];
  }
};
79 
/**
   Overwrite the file-scope unitarization parameters.  A UnitarizeLinks instance
   constructed afterwards copies these values into its kernel argument.
   The two boolean flags are stored as int (false -> 0, true -> 1).
*/
void setUnitarizeLinksConstants(double unitarize_eps_, double max_error_,
                                bool reunit_allow_svd_, bool reunit_svd_only_,
                                double svd_rel_error_, double svd_abs_error_)
{
  // SVD fallback controls
  reunit_svd_only = reunit_svd_only_;
  reunit_allow_svd = reunit_allow_svd_;
  svd_abs_error = svd_abs_error_;
  svd_rel_error = svd_rel_error_;

  // general tolerances
  max_error = max_error_;
  unitarize_eps = unitarize_eps_;
}
90 
/**
   Check that unitary_matrix U is consistent with initial_matrix A by comparing
   (conj(A)*U)^2 against conj(A)*A element by element.
   @param initial_matrix  the link before unitarization
   @param unitary_matrix  the candidate unitarized link
   @param max_error       allowed absolute deviation of each real/imaginary component
   @return true only if every component of the difference is within max_error;
           a NaN component (e.g. from a diverged iteration) counts as a failure
*/
template <typename mat>
__device__ __host__ bool isUnitarizedLinkConsistent(const mat &initial_matrix,
                                                    const mat &unitary_matrix, double max_error)
{
  auto n = initial_matrix.size();
  mat temporary = conj(initial_matrix)*unitary_matrix;
  temporary = temporary*temporary - conj(initial_matrix)*initial_matrix;

  for (int i=0; i<n; ++i) {
    for (int j=0; j<n; ++j) {
      // Written as !(|x| <= max_error) rather than |x| > max_error so that NaN
      // (for which every ordered comparison is false) is rejected.
      if (!(fabs(temporary(i,j).x) <= max_error) || !(fabs(temporary(i,j).y) <= max_error)) {
        return false;
      }
    }
  }
  return true;
}
108 
109 
/**
   Return the smallest absolute value among array[0 .. size-1].
   array[0] is always read, so the caller must supply at least one element.
*/
template <class T> constexpr T getAbsMin(const T* const array, int size)
{
  T smallest = fabs(array[0]);
  int i = 1;
  while (i < size) {
    const T candidate = fabs(array[i++]);
    if (candidate < smallest) smallest = candidate;
  }
  return smallest;
}
119 
/// @return true when |a - b| is strictly less than epsilon
template <class Real> constexpr bool checkAbsoluteError(Real a, Real b, Real epsilon)
{
  const Real difference = a - b;
  return fabs(difference) < epsilon;
}

/// @return true when |(a - b) / b| is strictly less than epsilon.  For floating-point
/// Real this is always false when b == 0, since the quotient is then inf or NaN.
template <class Real> constexpr bool checkRelativeError(Real a, Real b, Real epsilon)
{
  const Real relative = (a - b) / b;
  return fabs(relative) < epsilon;
}
123 
// Compute the reciprocal square root res of the matrix q using the Cayley-Hamilton
// theorem.  The three eigenvalues g[] of q are obtained from the trigonometric
// solution of its characteristic cubic, and res is assembled as
// c0*I + c1*q + c2*q^2.  q itself is not modified.
// q is expected to be Hermitian (the caller passes conj(in)*in).
// Returns false, leaving res untouched, when |det(q)| < arg.svd_abs_error or when
// the product of the computed eigenvalues does not match det(q) to relative
// accuracy arg.svd_rel_error; the caller should then fall back to an SVD.
template <typename real, typename mat, typename Arg>
__device__ __host__ bool reciprocalRoot(mat &res, const mat& q, Arg &arg)
{
  mat qsq, tempq;

  real c[3];
  real g[3];

  const real one_third = 0.333333333333333333333;
  const real one_ninth = 0.111111111111111111111;
  const real one_eighteenth = 0.055555555555555555555;

  qsq = q*q;
  tempq = qsq*q;

  // c[k] = Tr(q^(k+1)) / (k+1)
  c[0] = getTrace(q).x;
  c[1] = getTrace(qsq).x * 0.5;
  c[2] = getTrace(tempq).x * one_third;

  // Start from three degenerate eigenvalues Tr(q)/3; refined below unless |s| is tiny.
  g[0] = g[1] = g[2] = c[0] * one_third;
  real r,s,theta;
  s = c[1]*one_third - c[0]*c[0]*one_eighteenth;

  real cosTheta;
  if (fabs(s) >= arg.unitarize_eps) { // faster when this conditional is removed?
    const real rsqrt_s = rsqrt(s);
    r = c[2]*0.5 - (c[0]*one_third)*(c[1] - c[0]*c[0]*one_ninth);
    cosTheta = r*rsqrt_s*rsqrt_s*rsqrt_s;

    // keep the argument of acos inside its domain
    if(fabs(cosTheta) >= 1.0){
      theta = (r > 0) ? 0.0 : FL_UNITARIZE_PI;
    }else{
      theta = acos(cosTheta); // this is the primary performance limiter
    }

    const real sqrt_s = s*rsqrt_s;

#if 0 // experimental version
    real as, ac;
    sincos( theta*one_third, &as, &ac );
    g[0] = c[0]*one_third + 2*sqrt_s*ac;
    //g[1] = c[0]*one_third + 2*sqrt_s*(ac*cos(1*FL_UNITARIZE_PI23) - as*sin(1*FL_UNITARIZE_PI23));
    g[1] = c[0]*one_third - 2*sqrt_s*(0.5*ac + as*0.8660254037844386467637);
    //g[2] = c[0]*one_third + 2*sqrt_s*(ac*cos(2*FL_UNITARIZE_PI23) - as*sin(2*FL_UNITARIZE_PI23));
    g[2] = c[0]*one_third + 2*sqrt_s*(-0.5*ac + as*0.8660254037844386467637);
#else
    g[0] = c[0]*one_third + 2*sqrt_s*cos( theta*one_third );
    g[1] = c[0]*one_third + 2*sqrt_s*cos( theta*one_third + FL_UNITARIZE_PI23 );
    g[2] = c[0]*one_third + 2*sqrt_s*cos( theta*one_third + 2*FL_UNITARIZE_PI23 );
#endif
  }

  // Check the eigenvalues: if the determinant is too small, or does not match the
  // product of the eigenvalues, return false so that SVD is used instead.
  // A NaN eigenvalue product also fails the relative-error test.
  real det = getDeterminant(q).x;
  if (fabs(det) < arg.svd_abs_error) return false;
  if (!checkRelativeError<double>(g[0]*g[1]*g[2], det, arg.svd_rel_error)) return false;

  // At this point we have finished with the c's
  // use these to store sqrt(g)
  for(int i=0; i<3; ++i) c[i] = sqrt(g[i]);

  // done with the g's, use these to store u, v, w
  // (the elementary symmetric polynomials of the sqrt(g))
  g[0] = c[0]+c[1]+c[2];
  g[1] = c[0]*c[1] + c[0]*c[2] + c[1]*c[2];
  g[2] = c[0]*c[1]*c[2];

  // coefficients of res = c[0]*I + c[1]*q + c[2]*q^2
  const real denominator = 1.0 / ( g[2]*(g[0]*g[1]-g[2]) );
  c[0] = (g[0]*g[1]*g[1] - g[2]*(g[0]*g[0]+g[1])) * denominator;
  c[1] = (-g[0]*g[0]*g[0] - g[2] + 2.*g[0]*g[1]) * denominator;
  c[2] = g[0] * denominator;

  tempq = c[1]*q + c[2]*qsq;
  // Add a real scalar
  tempq(0,0).x += c[0];
  tempq(1,1).x += c[0];
  tempq(2,2).x += c[0];

  res = tempq;

  return true;
}
208 
/**
   MILC-style unitarization: out = in * (conj(in)*in)^(-1/2).
   The Cayley-Hamilton path (reciprocalRoot) is tried first unless arg.reunit_svd_only
   is set.  If that path is skipped or fails, an SVD in = u * S * conj(v) is used and
   out = u * conj(v), provided arg.reunit_allow_svd is set.
   @return false only when the Cayley-Hamilton path was not used successfully and the
           SVD fallback is disallowed (out is then left untouched)
*/
template <typename real, typename mat, typename Arg>
__host__ __device__ bool unitarizeLinkMILC(mat &out, const mat &in, Arg &arg)
{
  mat u;
  const bool try_cayley_hamilton = !arg.reunit_svd_only;
  if (try_cayley_hamilton && reciprocalRoot<real>(u, conj(in)*in, arg)) {
    out = in * u;
    return true;
  }

  // Cayley-Hamilton was skipped or failed: SVD is the only remaining option.
  if (!arg.reunit_allow_svd) return false;

  mat v;
  real singular_values[3];
  computeSVD<real>(in, u, v, singular_values);
  out = u * conj(v);
  return true;
} // unitarizeLinkMILC
230 
/**
   Unitarize a link with a fixed number of Newton steps u <- 0.5*(u + conj(inverse(u))),
   starting from u = in.  There is no early exit on convergence.
   @param out       unitarized link; written only on success
   @param in        link to unitarize
   @param max_iter  number of Newton iterations to perform
   @param tol       element-wise tolerance for isUnitarizedLinkConsistent(in, u, tol);
                    defaults to the previously hard-coded 1e-7
   @return true on success; false otherwise (an error is printed and out is untouched)
*/
template <typename mat>
__host__ __device__ bool unitarizeLinkNewton(mat &out, const mat& in, int max_iter, double tol = 0.0000001)
{
  mat u = in;

  for (int i=0; i<max_iter; ++i) {
    mat uinv = inverse(u);
    u = 0.5*(u + conj(uinv));
  }

  if (isUnitarizedLinkConsistent(in,u,tol)==false) {
    printf("ERROR: Unitarized link is not consistent with incoming link\n");
    return false;
  }
  out = u;

  return true;
}
249 
/**
   Unitarize every link of a CPU gauge field with unitarizeLinkNewton
   (max_iter_newton iterations) and write the result into outfield.
   Both fields must be CPU-resident and share the same (single or double) precision.
   A link whose unitarization fails is copied through unchanged, and the total number
   of failures is reported with a warning.
   NOTE(review): the raw indexing (site*4 + dir)*18 assumes a site-major layout of four
   uncompressed 3x3 complex links per site (e.g. MILC order); the order is not checked.
*/
void unitarizeLinksCPU(GaugeField &outfield, const GaugeField& infield)
{
#ifdef GPU_UNITARIZE
  if (checkLocation(outfield, infield) != QUDA_CPU_FIELD_LOCATION) errorQuda("Location must be CPU");
  checkPrecision(outfield, infield);
  // Only single and double precision are handled; fail loudly instead of silently
  // leaving outfield untouched.
  if (infield.Precision() != QUDA_SINGLE_PRECISION && infield.Precision() != QUDA_DOUBLE_PRECISION)
    errorQuda("Unsupported precision %d", infield.Precision());

  int num_failures = 0;
  Matrix<complex<double>,3> inlink, outlink;

  for (unsigned int i = 0; i < infield.Volume(); ++i) {
    for (int dir=0; dir<4; ++dir){
      // Compute the offset in size_t: (i*4 + dir)*18 in 32-bit unsigned arithmetic
      // wraps around for volumes larger than about 6e7 sites.
      const size_t offset = (static_cast<size_t>(i)*4 + dir)*18;
      if (infield.Precision() == QUDA_SINGLE_PRECISION) {
        copyArrayToLink(&inlink, ((float*)(infield.Gauge_p()) + offset)); // order of arguments?
      } else {
        copyArrayToLink(&inlink, ((double*)(infield.Gauge_p()) + offset)); // order of arguments?
      }

      if (unitarizeLinkNewton(outlink, inlink, max_iter_newton) == false) {
        num_failures++;
        // outlink still holds the previous link's result; pass this link through instead
        outlink = inlink;
      }

      if (infield.Precision() == QUDA_SINGLE_PRECISION) {
        copyLinkToArray(((float*)(outfield.Gauge_p()) + offset), outlink);
      } else {
        copyLinkToArray(((double*)(outfield.Gauge_p()) + offset), outlink);
      }
    } // dir
  } // loop over volume

  if (num_failures > 0) warningQuda("Newton unitarization failed on %d links", num_failures);
#else
  errorQuda("Unitarization has not been built");
#endif
}
276 
/**
   CPU function which checks that the gauge field is unitary.
   Each link is tested with Matrix::isUnitary(max_error).  At the first failure, the
   site index, direction, the link and conj(link)*link are printed and false is returned.
   NOTE(review): like unitarizeLinksCPU, the raw (site*4 + dir)*18 indexing assumes a
   site-major layout of uncompressed links; the field order is not checked here.
*/
bool isUnitary(const GaugeField& field, double max_error)
{
#ifdef GPU_UNITARIZE
  if (field.Location() != QUDA_CPU_FIELD_LOCATION) errorQuda("Location must be CPU");
  // the precision is loop invariant, so validate it once up front
  if (field.Precision() != QUDA_SINGLE_PRECISION && field.Precision() != QUDA_DOUBLE_PRECISION)
    errorQuda("Unsupported precision\n");
  Matrix<complex<double>,3> link, identity;

  for (unsigned int i = 0; i < field.Volume(); ++i) {
    for (int dir=0; dir<4; ++dir) {
      // size_t offset: (i*4 + dir)*18 in 32-bit unsigned arithmetic overflows for large volumes
      const size_t offset = (static_cast<size_t>(i)*4 + dir)*18;
      if (field.Precision() == QUDA_SINGLE_PRECISION) {
        copyArrayToLink(&link, ((float*)(field.Gauge_p()) + offset)); // order of arguments?
      } else {
        copyArrayToLink(&link, ((double*)(field.Gauge_p()) + offset)); // order of arguments?
      }
      if (link.isUnitary(max_error) == false) {
        printf("Unitarity failure\n");
        printf("site index = %u,\t direction = %d\n", i, dir);
        printLink(link);
        identity = conj(link)*link;
        printLink(identity);
        return false;
      }
    } // dir
  } // i
  return true;
#else
  errorQuda("Unitarization has not been built");
  return false;
#endif
} // is unitary
309 
310 
/**
   Kernel: unitarize one link per thread.
   Thread layout: x indexes the checkerboard site (guarded by arg.threads), y the
   parity, and z the direction mu (guarded by mu < 4).  The link is promoted to
   double precision, unitarized with unitarizeLinkMILC, and stored back in the
   field's storage precision.  Every failing link increments *arg.fails atomically.
*/
template <typename Arg> __global__ void DoUnitarizedLink(Arg arg)
{
  int idx = threadIdx.x + blockIdx.x*blockDim.x;
  int parity = threadIdx.y + blockIdx.y*blockDim.y;
  int mu = threadIdx.z + blockIdx.z*blockDim.z;
  if (idx >= arg.threads) return;
  if (mu >= 4) return;

  // result is always in double precision
  Matrix<complex<double>,Arg::nColor> v, result;
  Matrix<complex<typename Arg::Float>,Arg::nColor> tmp = arg.in(mu, idx, parity);

  v = tmp;
  // Count an explicit failure (e.g. Cayley-Hamilton failed with SVD disallowed) directly,
  // rather than relying on the untouched result happening to fail the unitarity test.
  bool failed = !unitarizeLinkMILC<double>(result, v, arg);
  if (!failed && arg.check_unitarization) failed = (result.isUnitary(arg.max_error) == false);
  if (failed) atomicAdd(arg.fails, 1);
  tmp = result;

  arg.out(mu, idx, parity) = tmp;
}
332 
/**
   Tunable launcher for DoUnitarizedLink.  The thread grid spans the checkerboard
   volume in x, the 2 parities in y and the 4 directions in z.  The kernel is run
   once from the constructor, followed by a device synchronization.
*/
template <typename Float, int nColor, QudaReconstructType recon>
class UnitarizeLinks : TunableVectorYZ {
  UnitarizeLinksArg<Float, nColor, recon> arg;
  const GaugeField &meta;

  unsigned int minThreads() const { return arg.threads; }
  bool tuneGridDim() const { return false; }

public:
  UnitarizeLinks(GaugeField &out, const GaugeField &in, int* fails) :
    TunableVectorYZ(2,4),
    arg(out, in, fails, max_iter, unitarize_eps, max_error, reunit_allow_svd,
        reunit_svd_only, svd_rel_error, svd_abs_error),
    meta(in)
  {
    apply(0);
    qudaDeviceSynchronize(); // need to synchronize to ensure failure write has completed
  }

  void apply(const qudaStream_t &stream) {
    TuneParam launch_param = tuneLaunch(*this, getTuning(), getVerbosity());
    qudaLaunchKernel(DoUnitarizedLink<decltype(arg)>, launch_param, stream, arg);
  }

  TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), meta.AuxString()); }

  // Tuning launches the kernel repeatedly: when the update is in place, back up and
  // restore the field, and clear the failure counter afterwards.
  void preTune() {
    if (arg.in.gauge == arg.out.gauge) arg.out.save();
  }
  void postTune() {
    if (arg.in.gauge == arg.out.gauge) arg.out.load();
    qudaMemset(arg.fails, 0, sizeof(int)); // reset fails counter
  }

  long long flops() const {
    // Accounted only the minimum flops for the case reunitarize_svd_only=0
    const long long links = 4ll * 2 * arg.threads;
    return links * 1147;
  }
  long long bytes() const {
    const long long links = 4ll * 2 * arg.threads;
    return links * (arg.in.Bytes() + arg.out.Bytes());
  }
};
371 
/**
   Unitarize the links of in and store them in out (device path).
   *fails must be device accessible; it is incremented once per failing link.
*/
void unitarizeLinks(GaugeField& out, const GaugeField &in, int* fails)
{
#ifndef GPU_UNITARIZE
  errorQuda("Unitarization has not been built");
#else
  checkPrecision(out, in);
  instantiate<UnitarizeLinks, ReconstructWilson>(out, in, fails);
#endif
}
381 
// In-place overload: the links are both the source and the destination.
void unitarizeLinks(GaugeField &links, int* fails)
{
  const GaugeField &source = links;
  unitarizeLinks(links, source, fails);
}
383 
/**
   Argument struct for ProjectSU3kernel: an accessor for the links, which are
   projected in place, plus the tolerance and the failure counter.
*/
template <typename Float_, int nColor_, QudaReconstructType recon_>
struct ProjectSU3Arg {
  using Float = Float_;
  static constexpr int nColor = nColor_;
  static constexpr QudaReconstructType recon = recon_;
  typedef typename gauge_mapper<Float,recon>::type Gauge;
  Gauge u; // accessor for the links (read and overwritten)

  int threads; // number of active threads required (checkerboard volume)
  Float tol;   // tolerance for polarSu3 and for the unitarity check
  int *fails;  // counter incremented once per link that fails the check

  // Initializers are listed in declaration order (u is declared before threads).
  ProjectSU3Arg(GaugeField &u, Float tol, int *fails) :
    u(u),
    threads(u.VolumeCB()),
    tol(tol),
    fails(fails) { }
};
401 
/**
   Kernel: project one link per thread onto SU(3) in place with polarSu3.
   Thread layout: x = checkerboard site (< arg.threads), y = parity, z = direction (< 4).
   Links that are still not unitary to within arg.tol increment *arg.fails.
*/
template<typename Arg>
__global__ void ProjectSU3kernel(Arg arg){
  using real = typename Arg::Float;
  const int x_cb = blockIdx.x*blockDim.x + threadIdx.x;
  const int parity = blockIdx.y*blockDim.y + threadIdx.y;
  const int dir = blockIdx.z*blockDim.z + threadIdx.z;
  if (x_cb >= arg.threads || dir >= 4) return;

  Matrix<complex<real>, Arg::nColor> link = arg.u(dir, x_cb, parity);
  polarSu3<real>(link, arg.tol);

  // count number of failures
  if (!link.isUnitary(arg.tol)) atomicAdd(arg.fails, 1);

  arg.u(dir, x_cb, parity) = link;
}
422 
/**
   Tunable launcher for ProjectSU3kernel: projects every link of u onto SU(3) in place.
   The thread grid spans the checkerboard volume in x, the 2 parities in y and the
   4 directions in z.  The kernel runs once from the constructor, followed by a
   device synchronization.
*/
template <typename Float, int nColor, QudaReconstructType recon>
class ProjectSU3 : TunableVectorYZ {
  ProjectSU3Arg<Float, nColor, recon> arg;
  const GaugeField &meta;

  bool tuneGridDim() const { return false; }
  unsigned int minThreads() const { return arg.threads; }

public:
  // The base class is always initialized before the members, so it is listed first
  // (matching the actual order, as in UnitarizeLinks, and avoiding -Wreorder).
  ProjectSU3(GaugeField &u, double tol, int *fails) :
    TunableVectorYZ(2, 4),
    arg(u, static_cast<Float>(tol), fails),
    meta(u)
  {
    apply(0);
    qudaDeviceSynchronize(); // ensure the failure counter write has completed
  }

  void apply(const qudaStream_t &stream) {
    TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
    qudaLaunchKernel(ProjectSU3kernel<decltype(arg)>, tp, stream, arg);
  }

  // The kernel overwrites u, so tuning must back it up and restore it.
  void preTune() { arg.u.save(); }
  void postTune() {
    arg.u.load();
    qudaMemset(arg.fails, 0, sizeof(int)); // reset fails counter
  }

  long long flops() const { return 0; } // depends on number of iterations
  long long bytes() const { return 4ll * 2 * arg.threads * 2 * arg.u.Bytes(); }
  TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), meta.AuxString()); }
};
456 
/**
   Project every link of u onto SU(3) in place.  *fails is incremented once per link
   that is still not unitary to within tol afterwards.
*/
void projectSU3(GaugeField &u, double tol, int *fails) {
#ifndef GPU_GAUGE_TOOLS
  errorQuda("Gauge tools have not been built");
#else
  // check that the field doesn't have staggered phases applied
  const bool phases_applied = u.StaggeredPhaseApplied();
  if (phases_applied) errorQuda("Cannot project gauge field with staggered phases applied");

  instantiate<ProjectSU3, ReconstructWilson>(u, tol, fails);
#endif
}
468 
469 } // namespace quda