// QUDA v1.1.0 — A library for QCD on GPUs
// unitarize_force_quda.cu: unitarization of the HISQ fermion force
1 #include <cstdlib>
2 #include <cstdio>
3 
4 #include <gauge_field.h>
5 #include <tune_quda.h>
6 #include <quda_matrix.h>
7 #include <gauge_field_order.h>
8 #include <instantiate.h>
9 #include <color_spinor.h>
10 
11 namespace quda {
12 
13  namespace { // anonymous
14 #include <svd_quda.h>
15  }
16 
// pi and 2*pi/3, used by the trigonometric (Cardano) eigenvalue solve below
#define HISQ_UNITARIZE_PI 3.14159265358979323846
#define HISQ_UNITARIZE_PI23 HISQ_UNITARIZE_PI*2.0/3.0

 // File-scope unitarization parameters; set once via setUnitarizeForceConstants()
 // and copied into each UnitarizeForceArg on construction.
 static double unitarize_eps;  // |s| below this is treated as the degenerate-eigenvalue case
 static double force_filter;   // regulator added to dangerously small eigenvalues
 static double max_det_error;  // max tolerated |prod(SVD eigenvalues) - det| before flagging a failure
 static bool allow_svd;        // permit SVD fallback when the analytic solve looks inaccurate
 static bool svd_only;         // always use the SVD, skipping the analytic eigen solve
 static double svd_rel_error;  // relative det tolerance for accepting the analytic eigenvalues
 static double svd_abs_error;  // |det| below this forces the SVD path
27 
28  namespace fermion_force {
29 
30  template <typename Float_, int nColor_, QudaReconstructType recon_, QudaGaugeFieldOrder order = QUDA_NATIVE_GAUGE_ORDER>
31  struct UnitarizeForceArg {
32  using Float = Float_;
33  static constexpr int nColor = nColor_;
34  static constexpr QudaReconstructType recon = recon_;
35  // use long form here to allow specification of order
36  typedef typename gauge_mapper<Float,QUDA_RECONSTRUCT_NO,2*nColor*nColor,QUDA_STAGGERED_PHASE_NO,gauge::default_huge_alloc, QUDA_GHOST_EXCHANGE_INVALID,false,order>::type F;
37  typedef typename gauge_mapper<Float,QUDA_RECONSTRUCT_NO,2*nColor*nColor,QUDA_STAGGERED_PHASE_NO,gauge::default_huge_alloc, QUDA_GHOST_EXCHANGE_INVALID,false,order>::type G;
38  F force;
39  const F force_old;
40  const G u;
41  int *fails;
42  int threads;
43  const double unitarize_eps;
44  const double force_filter;
45  const double max_det_error;
46  const int allow_svd;
47  const int svd_only;
48  const double svd_rel_error;
49  const double svd_abs_error;
50 
51  UnitarizeForceArg(GaugeField &force, const GaugeField &force_old, const GaugeField &u, int *fails,
52  double unitarize_eps, double force_filter, double max_det_error, int allow_svd,
53  int svd_only, double svd_rel_error, double svd_abs_error) :
54  force(force),
55  force_old(force_old),
56  u(u),
57  fails(fails),
58  unitarize_eps(unitarize_eps),
59  force_filter(force_filter),
60  max_det_error(max_det_error),
61  allow_svd(allow_svd),
62  svd_only(svd_only),
63  svd_rel_error(svd_rel_error),
64  svd_abs_error(svd_abs_error),
65  threads(u.VolumeCB()) { }
66  };
67 
68  void setUnitarizeForceConstants(double unitarize_eps_, double force_filter_,
69  double max_det_error_, bool allow_svd_, bool svd_only_,
70  double svd_rel_error_, double svd_abs_error_)
71  {
72  unitarize_eps = unitarize_eps_;
73  force_filter = force_filter_;
74  max_det_error = max_det_error_;
75  allow_svd = allow_svd_;
76  svd_only = svd_only_;
77  svd_rel_error = svd_rel_error_;
78  svd_abs_error = svd_abs_error_;
79  }
80 
 /**
    Holds the six independent coefficients b_{ij} used when differentiating
    the unitarized link (consumed by getUnitarizeForceSite).  Call set(u,v,w)
    to evaluate the closed-form polynomials in the symmetric functions u, v, w
    (built by the caller from the eigenvalue square roots), then read the
    results back via the getBij() accessors.
  */
 template <class Real> class DerivativeCoefficients {
   Real b[6]; // packed storage: b00, b01, b02, b11, b12, b22
   // The computeCij helpers evaluate the polynomial numerator of each
   // coefficient; the common denominator is applied once in set().
   constexpr Real computeC00(const Real &u, const Real &v, const Real &w)
   {
     return -pow(w,3) * pow(u,6) + 3*v*pow(w,3)*pow(u,4) + 3*pow(v,4)*w*pow(u,4)
       - pow(v,6)*pow(u,3) - 4*pow(w,4)*pow(u,3) - 12*pow(v,3)*pow(w,2)*pow(u,3)
       + 16*pow(v,2)*pow(w,3)*pow(u,2) + 3*pow(v,5)*w*pow(u,2) - 8*v*pow(w,4)*u
       - 3*pow(v,4)*pow(w,2)*u + pow(w,5) + pow(v,3)*pow(w,3);
   }

   constexpr Real computeC01(const Real & u, const Real & v, const Real & w)
   {
     return -pow(w,2)*pow(u,7) - pow(v,2)*w*pow(u,6) + pow(v,4)*pow(u,5) + 6*v*pow(w,2)*pow(u,5)
       - 5*pow(w,3)*pow(u,4) - pow(v,3)*w*pow(u,4)- 2*pow(v,5)*pow(u,3) - 6*pow(v,2)*pow(w,2)*pow(u,3)
       + 10*v*pow(w,3)*pow(u,2) + 6*pow(v,4)*w*pow(u,2) - 3*pow(w,4)*u - 6*pow(v,3)*pow(w,2)*u + 2*pow(v,2)*pow(w,3);
   }

   constexpr Real computeC02(const Real & u, const Real & v, const Real & w)
   {
     return pow(w,2)*pow(u,5) + pow(v,2)*w*pow(u,4)- pow(v,4)*pow(u,3)- 4*v*pow(w,2)*pow(u,3)
       + 4*pow(w,3)*pow(u,2) + 3*pow(v,3)*w*pow(u,2) - 3*pow(v,2)*pow(w,2)*u + v*pow(w,3);
   }

   constexpr Real computeC11(const Real & u, const Real & v, const Real & w)
   {
     return -w*pow(u,8) - pow(v,2)*pow(u,7) + 7*v*w*pow(u,6) + 4*pow(v,3)*pow(u,5)
       - 5*pow(w,2)*pow(u,5) - 16*pow(v,2)*w*pow(u,4) - 4*pow(v,4)*pow(u,3) + 16*v*pow(w,2)*pow(u,3)
       - 3*pow(w,3)*pow(u,2) + 12*pow(v,3)*w*pow(u,2) - 12*pow(v,2)*pow(w,2)*u + 3*v*pow(w,3);
   }

   constexpr Real computeC12(const Real &u, const Real &v, const Real &w)
   {
     return w*pow(u,6) + pow(v,2)*pow(u,5) - 5*v*w*pow(u,4) - 2*pow(v,3)*pow(u,3)
       + 4*pow(w,2)*pow(u,3) + 6*pow(v,2)*w*pow(u,2) - 6*v*pow(w,2)*u + pow(w,3);
   }

   constexpr Real computeC22(const Real &u, const Real &v, const Real &w)
   {
     return -w*pow(u,4) - pow(v,2)*pow(u,3) + 3*v*w*pow(u,2) - 3*pow(w,2)*u;
   }

 public:
   // Evaluate all six coefficients for the given u, v, w.
   constexpr void set(const Real &u, const Real &v, const Real &w)
   {
     // note: despite the name, "denominator" holds the *reciprocal* of the
     // common denominator 2*(w*(u*v-w))^3, applied by multiplication below
     const Real denominator = 1.0 / (2.0*pow(w*(u*v-w),3));
     b[0] = computeC00(u,v,w) * denominator;
     b[1] = computeC01(u,v,w) * denominator;
     b[2] = computeC02(u,v,w) * denominator;
     b[3] = computeC11(u,v,w) * denominator;
     b[4] = computeC12(u,v,w) * denominator;
     b[5] = computeC22(u,v,w) * denominator;
   }

   constexpr Real getB00() const { return b[0]; }
   constexpr Real getB01() const { return b[1]; }
   constexpr Real getB02() const { return b[2]; }
   constexpr Real getB11() const { return b[3]; }
   constexpr Real getB12() const { return b[4]; }
   constexpr Real getB22() const { return b[5]; }
 };
141 
142  template <typename mat>
143  __device__ __host__ void accumBothDerivatives(mat &result, const mat &left, const mat &right, const mat &outer_prod)
144  {
145  auto temp = (2.0*getTrace(left*outer_prod)).real();
146  for (int k=0; k<3; ++k) {
147  for (int l=0; l<3; ++l) {
148  result(k,l) += temp*right(k,l);
149  }
150  }
151  }
152 
 // Set result(k,l) = tr(left*outer_prod) * right(k,l), with a complex trace
 // weight (unlike accumBothDerivatives, which reduces to the real part).
 // NOTE(review): despite the "accum" name this *assigns* ("=") rather than
 // accumulates ("+="); no caller is visible in this file, so confirm intent
 // at the call sites before changing.
 template <class mat>
 __device__ __host__ void accumDerivatives(mat &result, const mat &left, const mat &right, const mat &outer_prod)
 {
   auto temp = getTrace(left*outer_prod);
   for(int k=0; k<3; ++k){
     for(int l=0; l<3; ++l){
       result(k,l) = temp*right(k,l);
     }
   }
 }
163 
164  template<class T> constexpr T getAbsMin(const T* const array, int size)
165  {
166  T min = fabs(array[0]);
167  for (int i=1; i<size; ++i) {
168  T abs_val = fabs(array[i]);
169  if ((abs_val) < min){ min = abs_val; }
170  }
171  return min;
172  }
173 
174  template<class Real> constexpr bool checkAbsoluteError(Real a, Real b, Real epsilon) { return fabs(a-b) < epsilon; }
175  template<class Real> constexpr bool checkRelativeError(Real a, Real b, Real epsilon) { return fabs((a-b)/b) < epsilon; }
176 
177  // Compute the reciprocal square root of the matrix q
178  // Also modify q if the eigenvalues are dangerously small.
 // Compute the reciprocal square root of the matrix q.
 // Also modify q if the eigenvalues are dangerously small.
 // Outputs: *res = c0*I + c1*q + c2*q^2 (the rational form of q^{-1/2}),
 // f[] = {c0,c1,c2}, and the derivative coefficients via deriv_coeffs->set().
 template<class Float, typename Arg>
 __device__ __host__ void reciprocalRoot(Matrix<complex<Float>,3>* res, DerivativeCoefficients<Float>* deriv_coeffs,
                                         Float f[3], Matrix<complex<Float>,3> & q, Arg &arg)
 {
   Matrix<complex<Float>,3> qsq, tempq;

   Float c[3]; // scaled traces of q, q^2, q^3; later sqrt(eigenvalues); finally the f coefficients
   Float g[3]; // eigenvalues of q; later the symmetric functions u, v, w of their square roots

   if(!arg.svd_only){
     qsq = q*q;
     tempq = qsq*q;

     // c[i] = tr(q^{i+1})/(i+1): characteristic-polynomial data
     c[0] = getTrace(q).x;
     c[1] = getTrace(qsq).x/2.0;
     c[2] = getTrace(tempq).x/3.0;

     // analytic eigenvalues via the trigonometric (Cardano) solution of the cubic
     g[0] = g[1] = g[2] = c[0]/3.;
     Float r,s,theta;
     s = c[1]/3. - c[0]*c[0]/18;
     r = c[2]/2. - (c[0]/3.)*(c[1] - c[0]*c[0]/9.);

     Float cosTheta = r*rsqrt(s*s*s);
     if (fabs(s) < arg.unitarize_eps) { // near-degenerate case: treat s as zero
       cosTheta = 1.;
       s = 0.0;
     }
     // clamp rounding overshoot of |cosTheta| past 1 to the appropriate endpoint
     if(fabs(cosTheta)>1.0){ r>0 ? theta=0.0 : theta=HISQ_UNITARIZE_PI/3.0; }
     else{ theta = acos(cosTheta)/3.0; }

     s = 2.0*sqrt(s);
     for (int i=0; i<3; ++i) {
       g[i] += s*cos(theta + (i-1)*HISQ_UNITARIZE_PI23);
     }

   } // !REUNIT_SVD_ONLY?

   //
   // Compare the product of the eigenvalues computed thus far to the
   // absolute value of the determinant.
   // If the determinant is very small or the relative error is greater than some predefined value
   // then recompute the eigenvalues using a singular-value decomposition.
   // Note that this particular calculation contains multiple branches,
   // so it doesn't appear to be particularly well-suited to the GPU
   // programming model. However, the analytic calculation of the
   // unitarization is extremely fast, and if the SVD routine is not called
   // too often, we expect pretty good performance.
   //

   if (arg.allow_svd) {
     bool perform_svd = true;
     if (!arg.svd_only) {
       const Float det = getDeterminant(q).x;
       // skip the SVD only when |det| is large enough AND the analytic
       // eigenvalue product agrees with det to within svd_rel_error
       if( fabs(det) >= arg.svd_abs_error) {
         if( checkRelativeError(g[0]*g[1]*g[2],det,arg.svd_rel_error) ) perform_svd = false;
       }
     }

     if(perform_svd){
       Matrix<complex<Float>,3> tmp2;
       // compute the eigenvalues using the singular value decomposition
       computeSVD<Float>(q,tempq,tmp2,g);
       // The array g contains the eigenvalues of the matrix q
       // The determinant is the product of the eigenvalues, and I can use this
       // to check the SVD
       const Float determinant = getDeterminant(q).x;
       const Float gprod = g[0]*g[1]*g[2];
       // Check the svd result for errors
       if (fabs(gprod - determinant) > arg.max_det_error) {
         printf("Warning: Error in determinant computed by SVD : %g > %g\n", fabs(gprod-determinant), arg.max_det_error);
         printLink(q);

#ifdef __CUDA_ARCH__
         atomicAdd(arg.fails, 1); // device path: atomic failure count
#else
         (*arg.fails)++;          // host path: plain increment
#endif
       }
     } // perform_svd?

   } // REUNIT_ALLOW_SVD?

   // Force filter: if any eigenvalue is dangerously small, shift all
   // eigenvalues (and the diagonal of q) by force_filter to regulate the force
   Float delta = getAbsMin(g,3);
   if (delta < arg.force_filter) {
     for (int i=0; i<3; ++i) {
       g[i] += arg.force_filter;
       q(i,i).x += arg.force_filter;
     }
     qsq = q*q; // recalculate Q^2
   }


   // At this point we have finished with the c's
   // use these to store sqrt(g)
   for (int i=0; i<3; ++i) c[i] = sqrt(g[i]);

   // done with the g's, use these to store u, v, w
   g[0] = c[0]+c[1]+c[2];
   g[1] = c[0]*c[1] + c[0]*c[2] + c[1]*c[2];
   g[2] = c[0]*c[1]*c[2];

   // set the derivative coefficients!
   deriv_coeffs->set(g[0], g[1], g[2]);

   // coefficients of the rational expression res = c0*I + c1*q + c2*q^2
   const Float& denominator = g[2]*(g[0]*g[1]-g[2]);
   c[0] = (g[0]*g[1]*g[1] - g[2]*(g[0]*g[0]+g[1]))/denominator;
   c[1] = (-g[0]*g[0]*g[0] - g[2] + 2.*g[0]*g[1])/denominator;
   c[2] = g[0]/denominator;

   tempq = c[1]*q + c[2]*qsq;
   // Add a real scalar
   tempq(0,0).x += c[0];
   tempq(1,1).x += c[0];
   tempq(2,2).x += c[0];

   f[0] = c[0];
   f[1] = c[1];
   f[2] = c[2];

   *res = tempq;
 }
300 
301  // "v" denotes a "fattened" link variable
302  template <class Float, typename Arg>
303  __device__ __host__ void getUnitarizeForceSite(Matrix<complex<Float>,3>& result, const Matrix<complex<Float>,3> & v,
304  const Matrix<complex<Float>,3> & outer_prod, Arg &arg)
305  {
306  typedef Matrix<complex<Float>,3> Link;
307  Float f[3];
308  Float b[6];
309 
310  Link v_dagger = conj(v); // okay!
311  Link q = v_dagger*v; // okay!
312  Link rsqrt_q;
313 
314  DerivativeCoefficients<Float> deriv_coeffs;
315 
316  reciprocalRoot<Float>(&rsqrt_q, &deriv_coeffs, f, q, arg); // approx 529 flops (assumes no SVD)
317 
318  // Pure hack here
319  b[0] = deriv_coeffs.getB00();
320  b[1] = deriv_coeffs.getB01();
321  b[2] = deriv_coeffs.getB02();
322  b[3] = deriv_coeffs.getB11();
323  b[4] = deriv_coeffs.getB12();
324  b[5] = deriv_coeffs.getB22();
325 
326  result = rsqrt_q*outer_prod;
327 
328  // We are now finished with rsqrt_q
329  Link qv_dagger = q*v_dagger;
330  Link vv_dagger = v*v_dagger;
331  Link vqv_dagger = v*qv_dagger;
332  Link temp = f[1]*vv_dagger + f[2]*vqv_dagger;
333 
334  temp = f[1]*v_dagger + f[2]*qv_dagger;
335  Link conj_outer_prod = conj(outer_prod);
336 
337  temp = f[1]*v + f[2]*v*q;
338  result = result + outer_prod*temp*v_dagger + f[2]*q*outer_prod*vv_dagger;
339  result = result + v_dagger*conj_outer_prod*conj(temp) + f[2]*qv_dagger*conj_outer_prod*v_dagger;
340 
341  Link qsqv_dagger = q*qv_dagger;
342  Link pv_dagger = b[0]*v_dagger + b[1]*qv_dagger + b[2]*qsqv_dagger;
343  accumBothDerivatives(result, v, pv_dagger, outer_prod); // 41 flops
344 
345  Link rv_dagger = b[1]*v_dagger + b[3]*qv_dagger + b[4]*qsqv_dagger;
346  Link vq = v*q;
347  accumBothDerivatives(result, vq, rv_dagger, outer_prod); // 41 flops
348 
349  Link sv_dagger = b[2]*v_dagger + b[4]*qv_dagger + b[5]*qsqv_dagger;
350  Link vqsq = vq*q;
351  accumBothDerivatives(result, vqsq, sv_dagger, outer_prod); // 41 flops
352 
353  // 4528 flops - 17 matrix multiplies (198 flops each) + reciprocal root (approx 529 flops) + accumBothDerivatives (41 each) + miscellaneous
354  } // get unit force term
355 
 // Kernel: one thread per checkerboard site (x dimension) and parity
 // (y dimension).  The launch is expected to supply exactly two y threads
 // in total — configured via TunableVectorY(2) in UnitarizeForce below.
 template <typename Arg> __global__ void getUnitarizeForceField(Arg arg)
 {
   using real = typename Arg::Float;
   int x_cb = blockIdx.x*blockDim.x + threadIdx.x;
   if (x_cb >= arg.threads) return; // guard the grid tail
   int parity = blockIdx.y*blockDim.y + threadIdx.y;

   // This part of the calculation is always done in double precision
   Matrix<complex<double>,3> v, result, oprod;
   Matrix<complex<real>,3> v_tmp, result_tmp, oprod_tmp;

   for (int dir=0; dir<4; ++dir) {
     // load old force and fattened link in storage precision, promote to double
     oprod_tmp = arg.force_old(dir, x_cb, parity);
     v_tmp = arg.u(dir, x_cb, parity);
     v = v_tmp;
     oprod = oprod_tmp;

     getUnitarizeForceSite<double>(result, v, oprod, arg);
     // demote back to storage precision and write out
     result_tmp = result;

     arg.force(dir, x_cb, parity) = result_tmp;
   } // 4*4528 flops per site
 } // getUnitarizeForceField
379 
 /**
    Tunable launcher for getUnitarizeForceField.  The y dimension of the
    launch carries the two parities (TunableVectorY(2)); the kernel is run
    from the constructor, followed by a device synchronize so that any
    failure count written by the kernel is visible to the caller.
  */
 template <typename Float, int nColor, QudaReconstructType recon> class UnitarizeForce : public TunableVectorY {
   UnitarizeForceArg<Float, nColor, recon> arg;
   const GaugeField &meta; // reference field for tuning keys and flop/byte counts

   // don't tune the grid dimension
   bool tuneGridDim() const { return false; }
   unsigned int minThreads() const { return arg.threads; }

 public:
   UnitarizeForce(GaugeField &newForce, const GaugeField &oldForce, const GaugeField &u, int* fails) :
     TunableVectorY(2),
     arg(newForce, oldForce, u, fails, unitarize_eps, force_filter,
         max_det_error, allow_svd, svd_only, svd_rel_error, svd_abs_error),
     meta(u)
   {
     apply(0);
     qudaDeviceSynchronize(); // need to synchronize to ensure failure write has completed
   }

   void apply(const qudaStream_t &stream) {
     TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
     qudaLaunchKernel(getUnitarizeForceField<decltype(arg)>, tp, stream, arg);
   }

   // NOTE(review): preTune saves no state — presumably safe because tuning
   // launches fully overwrite the force output; confirm against the accessors
   void preTune() { ; }
   void postTune() { qudaMemset(arg.fails, 0, sizeof(int)); } // reset fails counter

   long long flops() const { return 4ll*4528*meta.Volume(); }
   long long bytes() const { return 4ll * meta.Volume() * (arg.force.Bytes() + arg.force_old.Bytes() + arg.u.Bytes()); }

   TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), meta.AuxString()); }
 }; // UnitarizeForce
412 
 /**
    GPU entry point: unitarize the fattened-link force.
    @param newForce output force field (native order required)
    @param oldForce input force accumulated on the fattened links
    @param u        fattened gauge field
    @param fails    counter incremented when the SVD determinant check fails
  */
 void unitarizeForce(GaugeField &newForce, const GaugeField &oldForce, const GaugeField &u,
                     int* fails)
 {
#ifdef GPU_HISQ_FORCE
   // all three fields must agree in reconstruct type and precision
   checkReconstruct(u, oldForce, newForce);
   checkPrecision(u, oldForce, newForce);

   if (!u.isNative() || !oldForce.isNative() || !newForce.isNative())
     errorQuda("Only native order supported");

   // dispatch on precision; construction of UnitarizeForce runs the kernel
   instantiate<UnitarizeForce,ReconstructNone>(newForce, oldForce, u, fails);
#else
   errorQuda("HISQ force has not been built");
#endif
 }
428 
429  template <typename Float, typename Arg> void unitarizeForceCPU(Arg &arg)
430  {
431 #ifdef GPU_HISQ_FORCE
432  Matrix<complex<double>, 3> v, result, oprod;
433  Matrix<complex<Float>, 3> v_tmp, result_tmp, oprod_tmp;
434 
435  for (int parity = 0; parity < 2; parity++) {
436  for (int i = 0; i < arg.threads; i++) {
437  for (int dir = 0; dir < 4; dir++) {
438  oprod_tmp = arg.force_old(dir, i, parity);
439  v_tmp = arg.u(dir, i, parity);
440  v = v_tmp;
441  oprod = oprod_tmp;
442 
443  getUnitarizeForceSite<double>(result, v, oprod, arg);
444 
445  result_tmp = result;
446  arg.force(dir, i, parity) = result_tmp;
447  }
448  }
449  }
450 #else
451  errorQuda("HISQ force has not been built");
452 #endif
453  }
454 
455  void unitarizeForceCPU(GaugeField &newForce, const GaugeField &oldForce, const GaugeField &u)
456  {
457  if (checkLocation(newForce, oldForce, u) != QUDA_CPU_FIELD_LOCATION) errorQuda("Location must be CPU");
458  int num_failures = 0;
459  constexpr int nColor = 3;
460  Matrix<complex<double>, nColor> old_force, new_force, v;
461  if (u.Order() == QUDA_MILC_GAUGE_ORDER) {
462  if (u.Precision() == QUDA_DOUBLE_PRECISION) {
463  UnitarizeForceArg<double, nColor, QUDA_RECONSTRUCT_NO, QUDA_MILC_GAUGE_ORDER> arg(
464  newForce, oldForce, u, &num_failures, unitarize_eps, force_filter, max_det_error, allow_svd, svd_only,
465  svd_rel_error, svd_abs_error);
466  unitarizeForceCPU<double>(arg);
467  } else if (u.Precision() == QUDA_SINGLE_PRECISION) {
468  UnitarizeForceArg<float, nColor, QUDA_RECONSTRUCT_NO, QUDA_MILC_GAUGE_ORDER> arg(
469  newForce, oldForce, u, &num_failures, unitarize_eps, force_filter, max_det_error, allow_svd, svd_only,
470  svd_rel_error, svd_abs_error);
471  unitarizeForceCPU<float>(arg);
472  } else {
473  errorQuda("Precision = %d not supported", u.Precision());
474  }
475  } else if (u.Order() == QUDA_QDP_GAUGE_ORDER) {
476  if (u.Precision() == QUDA_DOUBLE_PRECISION) {
477  UnitarizeForceArg<double, nColor, QUDA_RECONSTRUCT_NO, QUDA_QDP_GAUGE_ORDER> arg(
478  newForce, oldForce, u, &num_failures, unitarize_eps, force_filter, max_det_error, allow_svd, svd_only,
479  svd_rel_error, svd_abs_error);
480  unitarizeForceCPU<double>(arg);
481  } else if (u.Precision() == QUDA_SINGLE_PRECISION) {
482  UnitarizeForceArg<float, nColor, QUDA_RECONSTRUCT_NO, QUDA_QDP_GAUGE_ORDER> arg(
483  newForce, oldForce, u, &num_failures, unitarize_eps, force_filter, max_det_error, allow_svd, svd_only,
484  svd_rel_error, svd_abs_error);
485  unitarizeForceCPU<float>(arg);
486  } else {
487  errorQuda("Precision = %d not supported", u.Precision());
488  }
489  } else {
490  errorQuda("Only MILC and QDP gauge orders supported\n");
491  }
492  if (num_failures) errorQuda("Unitarization failed, failures = %d", num_failures);
493  } // unitarize_force_cpu
494 
495  } // namespace fermion_force
496 
497 } // namespace quda