QUDA  1.0.0
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
unitarize_force_quda.cu
Go to the documentation of this file.
1 #include <cstdlib>
2 #include <cstdio>
3 #include <iostream>
4 #include <iomanip>
5 #include <cuda.h>
6 #include <gauge_field.h>
7 #include <tune_quda.h>
8 
9 #include <tune_quda.h>
10 #include <quda_matrix.h>
11 #include <gauge_field_order.h>
12 
13 #ifdef GPU_HISQ_FORCE
14 
15 // work around for CUDA 7.0 bug on OSX
16 #if defined(__APPLE__) && CUDA_VERSION >= 7000 && CUDA_VERSION < 7050
17 #define EXPONENT_TYPE Real
18 #else
19 #define EXPONENT_TYPE int
20 #endif
21 
22 namespace quda{
23 
24  namespace { // anonymous
25 #include <svd_quda.h>
26  }
27 
28 #define HISQ_UNITARIZE_PI 3.14159265358979323846
29 #define HISQ_UNITARIZE_PI23 HISQ_UNITARIZE_PI*2.0/3.0
30 
// File-scope tuning parameters for the force unitarization. They are written by
// setUnitarizeForceConstants() and copied into every UnitarizeForceArg that the
// host entry points construct.

// Threshold on |s| in reciprocalRoot(): below it cosTheta is forced to 1 and s to 0.
31  static double unitarize_eps;
// When the smallest |eigenvalue| is below this value, it is added to every
// eigenvalue and to the diagonal of q in reciprocalRoot().
32  static double force_filter;
// Largest tolerated |product of SVD values - determinant| before a failure is counted.
33  static double max_det_error;
// Enables the SVD branch in reciprocalRoot().
34  static bool allow_svd;
// Skips the analytic eigenvalue computation in reciprocalRoot().
35  static bool svd_only;
// Relative tolerance used when comparing the analytic eigenvalue product to the determinant.
36  static double svd_rel_error;
// If |determinant| is below this value the SVD is always performed (when allowed).
37  static double svd_abs_error;
38 
39 
40  namespace fermion_force {
41 
// Parameter bundle handed to the unitarize-force kernel and to the CPU loop.
//   F : accessor type used for the new and old force fields
//   G : accessor type used for the gauge field
template <typename F, typename G>
struct UnitarizeForceArg {
  int threads;                  // product of the four extents reported by meta.X()
  F force;                      // output force accessor
  F force_old;                  // input force accessor
  G gauge;                      // gauge-link accessor
  int *fails;                   // counter incremented when the SVD determinant check fails
  const double unitarize_eps;
  const double force_filter;
  const double max_det_error;
  const int allow_svd;
  const int svd_only;
  const double svd_rel_error;
  const double svd_abs_error;

  UnitarizeForceArg(const F &force, const F &force_old, const G &gauge, const GaugeField &meta, int *fails,
                    double unitarize_eps, double force_filter, double max_det_error, int allow_svd,
                    int svd_only, double svd_rel_error, double svd_abs_error)
    : threads(meta.X()[0] * meta.X()[1] * meta.X()[2] * meta.X()[3]),
      force(force),
      force_old(force_old),
      gauge(gauge),
      fails(fails),
      unitarize_eps(unitarize_eps),
      force_filter(force_filter),
      max_det_error(max_det_error),
      allow_svd(allow_svd),
      svd_only(svd_only),
      svd_rel_error(svd_rel_error),
      svd_abs_error(svd_abs_error)
  {
  }
};
67 
68 
// Store the unitarization parameters in the file-scope statics; every argument
// is copied verbatim into the static of the same name (minus the trailing '_').
void setUnitarizeForceConstants(double unitarize_eps_, double force_filter_,
                                double max_det_error_, bool allow_svd_, bool svd_only_,
                                double svd_rel_error_, double svd_abs_error_)
{
  // SVD switches and tolerances
  allow_svd     = allow_svd_;
  svd_only      = svd_only_;
  svd_rel_error = svd_rel_error_;
  svd_abs_error = svd_abs_error_;
  max_det_error = max_det_error_;

  // thresholds used by the analytic path and the eigenvalue filter
  unitarize_eps = unitarize_eps_;
  force_filter  = force_filter_;
}
81 
82 
// Holds the six coefficients b[0..5] produced by set(u, v, w). Each one is a
// polynomial in (u, v, w) (computeC00 ... computeC22) divided by a common
// denominator; the getters expose them as B00, B01, B02, B11, B12, B22.
template<class Real>
class DerivativeCoefficients{
public:
  // Fill all six coefficients from u, v, w.
  __device__ __host__ void set(const Real & u, const Real & v, const Real & w);

  __device__ __host__ Real getB00() const { return b[0]; }
  __device__ __host__ Real getB01() const { return b[1]; }
  __device__ __host__ Real getB02() const { return b[2]; }
  __device__ __host__ Real getB11() const { return b[3]; }
  __device__ __host__ Real getB12() const { return b[4]; }
  __device__ __host__ Real getB22() const { return b[5]; }

private:
  // storage order: B00, B01, B02, B11, B12, B22
  Real b[6];

  // numerators of the six coefficients
  __device__ __host__ Real computeC00(const Real &, const Real &, const Real &);
  __device__ __host__ Real computeC01(const Real &, const Real &, const Real &);
  __device__ __host__ Real computeC02(const Real &, const Real &, const Real &);
  __device__ __host__ Real computeC11(const Real &, const Real &, const Real &);
  __device__ __host__ Real computeC12(const Real &, const Real &, const Real &);
  __device__ __host__ Real computeC22(const Real &, const Real &, const Real &);
};
115 
// Numerator polynomial of coefficient B00 in (u, v, w).
// Term order and grouping are kept exactly as in the reference expression.
template<class Real>
__device__ __host__
Real DerivativeCoefficients<Real>::computeC00(const Real & u, const Real & v, const Real & w){
  typedef EXPONENT_TYPE Exp; // exponent type selected by the file-level workaround macro

  return -pow(w,Exp(3)) * pow(u,Exp(6))
    + 3*v*pow(w,Exp(3))*pow(u,Exp(4))
    + 3*pow(v,Exp(4))*w*pow(u,Exp(4))
    - pow(v,Exp(6))*pow(u,Exp(3))
    - 4*pow(w,Exp(4))*pow(u,Exp(3))
    - 12*pow(v,Exp(3))*pow(w,Exp(2))*pow(u,Exp(3))
    + 16*pow(v,Exp(2))*pow(w,Exp(3))*pow(u,Exp(2))
    + 3*pow(v,Exp(5))*w*pow(u,Exp(2))
    - 8*v*pow(w,Exp(4))*u
    - 3*pow(v,Exp(4))*pow(w,Exp(2))*u
    + pow(w,Exp(5))
    + pow(v,Exp(3))*pow(w,Exp(3));
}
134 
// Numerator polynomial of coefficient B01 in (u, v, w).
// Term order and grouping are kept exactly as in the reference expression.
template<class Real>
__device__ __host__
Real DerivativeCoefficients<Real>::computeC01(const Real & u, const Real & v, const Real & w){
  typedef EXPONENT_TYPE Exp; // exponent type selected by the file-level workaround macro

  return - pow(w,Exp(2))*pow(u,Exp(7))
    - pow(v,Exp(2))*w*pow(u,Exp(6))
    + pow(v,Exp(4))*pow(u,Exp(5))   // term marked as corrected in the original source
    + 6*v*pow(w,Exp(2))*pow(u,Exp(5))
    - 5*pow(w,Exp(3))*pow(u,Exp(4)) // term marked as corrected in the original source
    - pow(v,Exp(3))*w*pow(u,Exp(4))
    - 2*pow(v,Exp(5))*pow(u,Exp(3))
    - 6*pow(v,Exp(2))*pow(w,Exp(2))*pow(u,Exp(3))
    + 10*v*pow(w,Exp(3))*pow(u,Exp(2))
    + 6*pow(v,Exp(4))*w*pow(u,Exp(2))
    - 3*pow(w,Exp(4))*u
    - 6*pow(v,Exp(3))*pow(w,Exp(2))*u
    + 2*pow(v,Exp(2))*pow(w,Exp(3));
}
153 
// Numerator polynomial of coefficient B02 in (u, v, w).
// Term order and grouping are kept exactly as in the reference expression.
template<class Real>
__device__ __host__
Real DerivativeCoefficients<Real>::computeC02(const Real & u, const Real & v, const Real & w){
  typedef EXPONENT_TYPE Exp; // exponent type selected by the file-level workaround macro

  return pow(w,Exp(2))*pow(u,Exp(5))
    + pow(v,Exp(2))*w*pow(u,Exp(4))
    - pow(v,Exp(4))*pow(u,Exp(3))
    - 4*v*pow(w,Exp(2))*pow(u,Exp(3))
    + 4*pow(w,Exp(3))*pow(u,Exp(2))
    + 3*pow(v,Exp(3))*w*pow(u,Exp(2))
    - 3*pow(v,Exp(2))*pow(w,Exp(2))*u
    + v*pow(w,Exp(3));
}
167 
// Numerator polynomial of coefficient B11 in (u, v, w).
// Term order and grouping are kept exactly as in the reference expression.
template<class Real>
__device__ __host__
Real DerivativeCoefficients<Real>::computeC11(const Real & u, const Real & v, const Real & w){
  typedef EXPONENT_TYPE Exp; // exponent type selected by the file-level workaround macro

  return - w*pow(u,Exp(8))
    - pow(v,Exp(2))*pow(u,Exp(7))
    + 7*v*w*pow(u,Exp(6))
    + 4*pow(v,Exp(3))*pow(u,Exp(5))
    - 5*pow(w,Exp(2))*pow(u,Exp(5))
    - 16*pow(v,Exp(2))*w*pow(u,Exp(4))
    - 4*pow(v,Exp(4))*pow(u,Exp(3))
    + 16*v*pow(w,Exp(2))*pow(u,Exp(3))
    - 3*pow(w,Exp(3))*pow(u,Exp(2))
    + 12*pow(v,Exp(3))*w*pow(u,Exp(2))
    - 12*pow(v,Exp(2))*pow(w,Exp(2))*u
    + 3*v*pow(w,Exp(3));
}
185 
// Numerator polynomial of coefficient B12 in (u, v, w).
// Term order and grouping are kept exactly as in the reference expression.
template<class Real>
__device__ __host__
Real DerivativeCoefficients<Real>::computeC12(const Real & u, const Real & v, const Real & w){
  typedef EXPONENT_TYPE Exp; // exponent type selected by the file-level workaround macro

  return w*pow(u,Exp(6))
    + pow(v,Exp(2))*pow(u,Exp(5)) // term marked as fixed in the original source
    - 5*v*w*pow(u,Exp(4))         // term marked as fixed in the original source
    - 2*pow(v,Exp(3))*pow(u,Exp(3))
    + 4*pow(w,Exp(2))*pow(u,Exp(3))
    + 6*pow(v,Exp(2))*w*pow(u,Exp(2))
    - 6*v*pow(w,Exp(2))*u
    + pow(w,Exp(3));
}
199 
// Numerator polynomial of coefficient B22 in (u, v, w).
// Term order and grouping are kept exactly as in the reference expression.
template<class Real>
__device__ __host__
Real DerivativeCoefficients<Real>::computeC22(const Real & u, const Real & v, const Real & w){
  typedef EXPONENT_TYPE Exp; // exponent type selected by the file-level workaround macro

  return - w*pow(u,Exp(4))
    - pow(v,Exp(2))*pow(u,Exp(3))
    + 3*v*w*pow(u,Exp(2))
    - 3*pow(w,Exp(2))*u;
}
209 
// Evaluate the six numerator polynomials at (u, v, w) and divide each by the
// shared denominator 2*(w*(u*v - w))^3. No guard is applied for a vanishing
// denominator.
template <class Real>
__device__ __host__
void DerivativeCoefficients<Real>::set(const Real & u, const Real & v, const Real & w){
  typedef EXPONENT_TYPE Exp;
  const Real common_denom = 2.0*pow(w*(u*v-w),Exp(3));

  b[0] = computeC00(u,v,w)/common_denom; // B00
  b[1] = computeC01(u,v,w)/common_denom; // B01
  b[2] = computeC02(u,v,w)/common_denom; // B02
  b[3] = computeC11(u,v,w)/common_denom; // B11
  b[4] = computeC12(u,v,w)/common_denom; // B12
  b[5] = computeC22(u,v,w)/common_denom; // B22
}
222 
223 
// Adds  2*Re(Tr(left*outer_prod)) * right  to *result, element by element.
template<class Float>
__device__ __host__
void accumBothDerivatives(Matrix<complex<Float>,3>* result, const Matrix<complex<Float>,3> &left,
                          const Matrix<complex<Float>,3> &right, const Matrix<complex<Float>,3> &outer_prod)
{
  // real scale factor shared by all nine elements
  const Float scale = (2.0*getTrace(left*outer_prod)).real();

  Matrix<complex<Float>,3> &acc = *result;
  for (int row = 0; row < 3; ++row) {
    for (int col = 0; col < 3; ++col) {
      // The original notes that the real and imaginary parts had to be updated
      // separately for the CPU build to work; that form is kept here.
      acc(row,col).x += scale*right(row,col).x;
      acc(row,col).y += scale*right(row,col).y;
    }
  }
}
241 
242 
// Writes  Tr(left*outer_prod) * right  into *result, element by element.
// NOTE(review): despite the name this assigns to *result rather than adding to
// it; no caller is visible in this file, so the intent cannot be confirmed here.
template<class Cmplx>
__device__ __host__
void accumDerivatives(Matrix<Cmplx,3>* result, const Matrix<Cmplx,3> & left, const Matrix<Cmplx,3> & right, const Matrix<Cmplx,3> & outer_prod)
{
  const Cmplx trace_factor = getTrace(left*outer_prod);

  Matrix<Cmplx,3> &out = *result;
  for (int row = 0; row < 3; ++row) {
    for (int col = 0; col < 3; ++col) {
      out(row,col) = trace_factor*right(row,col);
    }
  }
}
255 
256 
// Returns the smallest absolute value among array[0 .. size-1].
// array[0] is always read, so size is expected to be at least 1.
template<class T>
__device__ __host__
T getAbsMin(const T* const array, int size){
  T smallest = fabs(array[0]);
  int pos = 1;
  while (pos < size) {
    const T magnitude = fabs(array[pos]);
    if (magnitude < smallest) smallest = magnitude;
    ++pos;
  }
  return smallest;
}
267 
268 
// True when |a - b| is strictly smaller than epsilon.
template<class Real>
__device__ __host__
inline bool checkAbsoluteError(Real a, Real b, Real epsilon)
{
  return fabs(a-b) < epsilon;
}
276 
277 
// True when |(a - b)/b| is strictly smaller than epsilon.
// b is used as the reference value and is not checked for zero.
template<class Real>
__device__ __host__
inline bool checkRelativeError(Real a, Real b, Real epsilon)
{
  return fabs((a-b)/b) < epsilon;
}
285 
286 
287 
288 
// Compute the reciprocal square root of the matrix q
// Also modify q if the eigenvalues are dangerously small.
//
//   res          : receives f[0]*I + f[1]*q + f[2]*q^2
//   deriv_coeffs : filled through set(u, v, w), where u, v, w are the elementary
//                  symmetric combinations of the square roots of the eigenvalues
//   f            : receives the three coefficients used to build *res
//   q            : input matrix; arg.force_filter is added to its diagonal when
//                  the smallest |eigenvalue| falls below arg.force_filter
//   arg          : supplies svd_only, allow_svd, unitarize_eps, svd_abs_error,
//                  svd_rel_error, max_det_error, force_filter; *arg.fails is
//                  incremented when the SVD determinant check fails
template<class Float, typename Arg>
__device__ __host__
void reciprocalRoot(Matrix<complex<Float>,3>* res, DerivativeCoefficients<Float>* deriv_coeffs,
                    Float f[3], Matrix<complex<Float>,3> & q, Arg &arg) {

  Matrix<complex<Float>,3> qsq, tempq;

  Float c[3];
  Float g[3];

  if(!arg.svd_only){
    // Analytic eigenvalues from the traces of q, q^2 and q^3.
    qsq = q*q;
    tempq = qsq*q;

    c[0] = getTrace(q).x;
    c[1] = getTrace(qsq).x/2.0;
    c[2] = getTrace(tempq).x/3.0;

    g[0] = g[1] = g[2] = c[0]/3.;
    Float r,s,theta;
    s = c[1]/3. - c[0]*c[0]/18;
    r = c[2]/2. - (c[0]/3.)*(c[1] - c[0]*c[0]/9.);

    Float cosTheta = r/sqrt(s*s*s);
    if (fabs(s) < arg.unitarize_eps) {
      // s is treated as zero: discard whatever the division above produced
      cosTheta = 1.;
      s = 0.0;
    }
    if (fabs(cosTheta) > 1.0) {
      if (r > 0) theta = 0.0;
      else       theta = HISQ_UNITARIZE_PI/3.0;
    } else {
      theta = acos(cosTheta)/3.0;
    }

    s = 2.0*sqrt(s);
    for(int i=0; i<3; ++i){
      g[i] += s*cos(theta + (i-1)*HISQ_UNITARIZE_PI23);
    }

  } // !svd_only

  //
  // Compare the product of the eigenvalues computed thus far to the
  // absolute value of the determinant.
  // If the determinant is very small or the relative error is greater than some predefined value
  // then recompute the eigenvalues using a singular-value decomposition.
  // Note that this particular calculation contains multiple branches,
  // so it doesn't appear to be particularly well-suited to the GPU
  // programming model. However, the analytic calculation of the
  // unitarization is extremely fast, and if the SVD routine is not called
  // too often, we expect pretty good performance.
  //

  if (arg.allow_svd) {
    bool perform_svd = true;
    if (!arg.svd_only) {
      const Float det = getDeterminant(q).x;
      if( fabs(det) >= arg.svd_abs_error) {
        if( checkRelativeError(g[0]*g[1]*g[2],det,arg.svd_rel_error) ) perform_svd = false;
      }
    }

    if(perform_svd){
      // Scratch matrix for the third computeSVD argument. It was referenced
      // without a local declaration; declare it here so the call is self-contained.
      Matrix<complex<Float>,3> tmp2;
      // compute the eigenvalues using the singular value decomposition
      computeSVD<Float>(q,tempq,tmp2,g);
      // The array g contains the eigenvalues of the matrix q
      // The determinant is the product of the eigenvalues, and I can use this
      // to check the SVD
      const Float determinant = getDeterminant(q).x;
      const Float gprod = g[0]*g[1]*g[2];
      // Check the svd result for errors
      if (fabs(gprod - determinant) > arg.max_det_error) {
        printf("Warning: Error in determinant computed by SVD : %g > %g\n", fabs(gprod-determinant), arg.max_det_error);
        printLink(q);

#ifdef __CUDA_ARCH__
        atomicAdd(arg.fails, 1);
#else
        (*arg.fails)++;
#endif
      }
    } // perform_svd?

  } // allow_svd?

  const Float delta = getAbsMin(g,3);
  const bool filtered = (delta < arg.force_filter);
  if (filtered) {
    for (int i=0; i<3; ++i) {
      g[i] += arg.force_filter;
      q(i,i).x += arg.force_filter;
    }
  }
  // q^2 is needed below for the final polynomial. Recompute it when q was just
  // shifted, and also when svd_only skipped the analytic branch, because in that
  // case qsq has not been computed at all.
  if (filtered || arg.svd_only) qsq = q*q;

  // At this point we have finished with the c's
  // use these to store sqrt(g)
  for (int i=0; i<3; ++i) c[i] = sqrt(g[i]);

  // done with the g's, use these to store u, v, w
  g[0] = c[0]+c[1]+c[2];
  g[1] = c[0]*c[1] + c[0]*c[2] + c[1]*c[2];
  g[2] = c[0]*c[1]*c[2];

  // set the derivative coefficients!
  deriv_coeffs->set(g[0], g[1], g[2]);

  const Float denominator = g[2]*(g[0]*g[1]-g[2]);
  c[0] = (g[0]*g[1]*g[1] - g[2]*(g[0]*g[0]+g[1]))/denominator;
  c[1] = (-g[0]*g[0]*g[0] - g[2] + 2.*g[0]*g[1])/denominator;
  c[2] = g[0]/denominator;

  tempq = c[1]*q + c[2]*qsq;
  // Add a real scalar
  tempq(0,0).x += c[0];
  tempq(1,1).x += c[0];
  tempq(2,2).x += c[0];

  f[0] = c[0];
  f[1] = c[1];
  f[2] = c[2];

  *res = tempq;
  return;
}
414 
415 
416 
// "v" denotes a "fattened" link variable
//
// Computes the unitarized force contribution for one link.
//   result     : output matrix (fully overwritten)
//   v          : input link
//   outer_prod : input force (outer product) for the same link
//   arg        : forwarded to reciprocalRoot() for thresholds and the fail counter
//
// Intermediate products that the original computed and then overwritten before
// use (v*q*v^dagger and two throw-away values of temp) have been removed; every
// expression that contributes to result is evaluated in its original order.
template<class Float, typename Arg>
__device__ __host__
void getUnitarizeForceSite(Matrix<complex<Float>,3>& result, const Matrix<complex<Float>,3> & v,
                           const Matrix<complex<Float>,3> & outer_prod, Arg &arg)
{
  typedef Matrix<complex<Float>,3> Link;
  Float f[3];
  Float b[6];

  Link v_dagger = conj(v);
  Link q = v_dagger*v;
  Link rsqrt_q;

  DerivativeCoefficients<Float> deriv_coeffs;

  // May shift the diagonal of q (see reciprocalRoot); q is used in its
  // possibly-modified form from here on.
  reciprocalRoot<Float>(&rsqrt_q, &deriv_coeffs, f, q, arg); // approx 529 flops (assumes no SVD)

  // Local copy of the six derivative coefficients
  b[0] = deriv_coeffs.getB00();
  b[1] = deriv_coeffs.getB01();
  b[2] = deriv_coeffs.getB02();
  b[3] = deriv_coeffs.getB11();
  b[4] = deriv_coeffs.getB12();
  b[5] = deriv_coeffs.getB22();

  result = rsqrt_q*outer_prod;

  // We are now finished with rsqrt_q
  Link qv_dagger = q*v_dagger;
  Link vv_dagger = v*v_dagger;
  Link conj_outer_prod = conj(outer_prod);

  Link temp = f[1]*v + f[2]*v*q;
  result = result + outer_prod*temp*v_dagger + f[2]*q*outer_prod*vv_dagger;
  result = result + v_dagger*conj_outer_prod*conj(temp) + f[2]*qv_dagger*conj_outer_prod*v_dagger;

  Link qsqv_dagger = q*qv_dagger;
  Link pv_dagger = b[0]*v_dagger + b[1]*qv_dagger + b[2]*qsqv_dagger;
  accumBothDerivatives(&result, v, pv_dagger, outer_prod); // 41 flops

  Link rv_dagger = b[1]*v_dagger + b[3]*qv_dagger + b[4]*qsqv_dagger;
  Link vq = v*q;
  accumBothDerivatives(&result, vq, rv_dagger, outer_prod); // 41 flops

  Link sv_dagger = b[2]*v_dagger + b[4]*qv_dagger + b[5]*qsqv_dagger;
  Link vqsq = vq*q;
  accumBothDerivatives(&result, vqsq, sv_dagger, outer_prod); // 41 flops

  return;
  // The original quotes 4528 flops per call (17 matrix multiplies at 198 flops each
  // + reciprocal root + accumBothDerivatives + miscellaneous). That estimate has not
  // been re-derived after removing the dead products.
} // get unit force term
473 
474 
// Kernel: one thread per lattice site, looping over the four directions.
// Threads with global index >= arg.threads exit immediately. The first
// arg.threads/2 indices are mapped to parity 0 and the rest to parity 1.
// Fields are loaded in precision Float, promoted to double for the site
// computation, and the result is demoted back to Float on store.
template<typename Float, typename Arg>
__global__ void getUnitarizeForceField(Arg arg)
{
  const int tid = blockIdx.x*blockDim.x + threadIdx.x;
  if (tid >= arg.threads) return;

  const int half_volume = arg.threads/2;
  const bool second_half = (tid >= half_volume);
  const int parity = second_half ? 1 : 0;
  const int site = second_half ? (tid - half_volume) : tid;

  // This part of the calculation is always done in double precision
  Matrix<complex<double>,3> link, force_out, force_in;
  Matrix<complex<Float>,3> link_io, force_out_io, force_in_io;

  for (int dir = 0; dir < 4; ++dir) {
    force_in_io = arg.force_old(dir, site, parity);
    link_io = arg.gauge(dir, site, parity);
    link = link_io;
    force_in = force_in_io;

    getUnitarizeForceSite<double>(force_out, link, force_in, arg);
    force_out_io = force_out;

    arg.force(dir, site, parity) = force_out_io;
  } // 4*4528 flops per site
} // getUnitarizeForceField
503 
504 
// Host loop equivalent of the kernel: visits both parities, arg.threads/2 sites
// per parity and four directions per site. The site computation is carried out
// in double precision; Float is the storage precision of the accessors.
template <typename Float, typename Arg>
void unitarizeForceCPU(Arg &arg) {
  Matrix<complex<double>,3> link, force_out, force_in;
  Matrix<complex<Float>,3> link_io, force_out_io, force_in_io;

  for (int parity = 0; parity < 2; parity++) {
    for (int site = 0; site < arg.threads/2; site++) {
      for (int dir = 0; dir < 4; dir++) {
        force_in_io = arg.force_old(dir, site, parity);
        link_io = arg.gauge(dir, site, parity);
        link = link_io;
        force_in = force_in_io;

        getUnitarizeForceSite<double>(force_out, link, force_in, arg);

        force_out_io = force_out;
        arg.force(dir, site, parity) = force_out_io;
      }
    }
  }
}
526 
// Builds the argument struct for accessor type G and runs the host loop with
// storage precision Float. Private helper of unitarizeForceCPU() below.
template <typename Float, typename G>
static void runUnitarizeForceCPU(cpuGaugeField& newForce, const cpuGaugeField& oldForce,
                                 const cpuGaugeField& gauge, int *fails)
{
  UnitarizeForceArg<G,G> arg(G(newForce), G(oldForce), G(gauge), gauge, fails, unitarize_eps, force_filter,
                             max_det_error, allow_svd, svd_only, svd_rel_error, svd_abs_error);
  unitarizeForceCPU<Float>(arg);
}

// Unitarize the fermion force on the CPU.
//   newForce : output force field
//   oldForce : input force field
//   gauge    : gauge field; its order and precision select the accessor type
// Only MILC and QDP orders in single or double precision are handled; anything
// else is reported through errorQuda. errorQuda is also raised afterwards if any
// site failed the SVD determinant check.
void unitarizeForceCPU(cpuGaugeField& newForce, const cpuGaugeField& oldForce, const cpuGaugeField& gauge)
{
  int num_failures = 0;

  if (gauge.Order() == QUDA_MILC_GAUGE_ORDER) {
    if (gauge.Precision() == QUDA_DOUBLE_PRECISION) {
      runUnitarizeForceCPU<double, gauge::MILCOrder<double,18> >(newForce, oldForce, gauge, &num_failures);
    } else if (gauge.Precision() == QUDA_SINGLE_PRECISION) {
      runUnitarizeForceCPU<float, gauge::MILCOrder<float,18> >(newForce, oldForce, gauge, &num_failures);
    } else {
      errorQuda("Precision = %d not supported", gauge.Precision());
    }
  } else if (gauge.Order() == QUDA_QDP_GAUGE_ORDER) {
    if (gauge.Precision() == QUDA_DOUBLE_PRECISION) {
      runUnitarizeForceCPU<double, gauge::QDPOrder<double,18> >(newForce, oldForce, gauge, &num_failures);
    } else if (gauge.Precision() == QUDA_SINGLE_PRECISION) {
      runUnitarizeForceCPU<float, gauge::QDPOrder<float,18> >(newForce, oldForce, gauge, &num_failures);
    } else {
      errorQuda("Precision = %d not supported", gauge.Precision());
    }
  } else {
    errorQuda("Only MILC and QDP gauge orders supported\n");
  }

  if (num_failures) errorQuda("Unitarization failed, failures = %d", num_failures);
  return;
} // unitarize_force_cpu
567 
// Tunable wrapper that launches getUnitarizeForceField<Float> with the block
// and grid chosen by tuneLaunch(). The kernel uses no dynamic shared memory
// (both shared-bytes hooks return 0) and the grid dimension is not tuned.
template <typename Float, typename Arg>
class UnitarizeForce : public Tunable {
private:
  Arg &arg;
  const GaugeField &meta;

  unsigned int sharedBytesPerThread() const { return 0; }
  unsigned int sharedBytesPerBlock(const TuneParam &) const { return 0; }

  // don't tune the grid dimension
  bool tuneGridDim() const { return false; }
  unsigned int minThreads() const { return arg.threads; }

public:
  UnitarizeForce(Arg &arg, const GaugeField& meta) : arg(arg), meta(meta) {
    // %lu expects an unsigned long; convert the precision value explicitly so
    // the variadic argument matches the conversion specifier.
    writeAuxString("threads=%d,prec=%lu,stride=%d", meta.Volume(),
                   static_cast<unsigned long>(meta.Precision()), meta.Stride());
  }
  virtual ~UnitarizeForce() { ; }

  // Launch on the stream supplied by the caller (previously the argument was
  // ignored and the launch always went to the default stream).
  void apply(const cudaStream_t &stream) {
    TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
    getUnitarizeForceField<Float><<<tp.grid,tp.block,0,stream>>>(arg);
  }

  void preTune() { ; }
  void postTune() { cudaMemset(arg.fails, 0, sizeof(int)); } // reset fails counter

  long long flops() const { return 4ll*4528*meta.Volume(); }
  long long bytes() const { return 4ll * arg.threads * (arg.force.Bytes() + arg.force_old.Bytes() + arg.gauge.Bytes()); }

  TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), aux); }
}; // UnitarizeForce
600 
// Packs the accessors and the file-scope parameters into a UnitarizeForceArg,
// launches the tunable kernel wrapper on stream 0 and waits for completion.
template<typename Float, typename Gauge>
void unitarizeForce(Gauge newForce, const Gauge oldForce, const Gauge gauge,
                    const GaugeField &meta, int* fails) {

  typedef UnitarizeForceArg<Gauge,Gauge> ForceArg;

  ForceArg arg(newForce, oldForce, gauge, meta, fails, unitarize_eps, force_filter,
               max_det_error, allow_svd, svd_only, svd_rel_error, svd_abs_error);
  UnitarizeForce<Float,ForceArg> launcher(arg, meta);
  launcher.apply(0);

  qudaDeviceSynchronize(); // need to synchronize to ensure failure write has completed
  checkCudaError();
}
612 
// Unitarize the fermion force on the device.
//   newForce : output force field
//   oldForce : input force field
//   gauge    : gauge field
//   fails    : counter passed through to the kernel argument struct
// All three fields must use QUDA_RECONSTRUCT_NO and share one precision and one
// data order. Only QUDA_FLOAT2_GAUGE_ORDER in single or double precision is
// handled; every other combination is reported through errorQuda.
void unitarizeForce(cudaGaugeField &newForce, const cudaGaugeField &oldForce, const cudaGaugeField &gauge,
                    int* fails) {

  if (oldForce.Reconstruct() != QUDA_RECONSTRUCT_NO)
    errorQuda("Force field should not use reconstruct %d", oldForce.Reconstruct());

  if (newForce.Reconstruct() != QUDA_RECONSTRUCT_NO)
    errorQuda("Force field should not use reconstruct %d", newForce.Reconstruct());

  // This check previously tested oldForce a second time, so a reconstructed
  // gauge field was never rejected.
  if (gauge.Reconstruct() != QUDA_RECONSTRUCT_NO)
    errorQuda("Gauge field should not use reconstruct %d", gauge.Reconstruct());

  if (gauge.Precision() != oldForce.Precision() || gauge.Precision() != newForce.Precision())
    errorQuda("Mixed precision not supported");

  if (gauge.Order() != oldForce.Order() || gauge.Order() != newForce.Order())
    errorQuda("Mixed data ordering not supported");

  if (gauge.Order() == QUDA_FLOAT2_GAUGE_ORDER) {
    if (gauge.Precision() == QUDA_DOUBLE_PRECISION) {
      typedef typename gauge_mapper<double,QUDA_RECONSTRUCT_NO>::type G;
      unitarizeForce<double>(G(newForce), G(oldForce), G(gauge), gauge, fails);
    } else if (gauge.Precision() == QUDA_SINGLE_PRECISION) {
      typedef typename gauge_mapper<float,QUDA_RECONSTRUCT_NO>::type G;
      unitarizeForce<float>(G(newForce), G(oldForce), G(gauge), gauge, fails);
    } else {
      // Report unsupported precisions instead of silently doing nothing,
      // matching the CPU entry point.
      errorQuda("Precision = %d not supported", gauge.Precision());
    }
  } else {
    errorQuda("Data order %d not supported", gauge.Order());
  }

}
644 
645  } // namespace fermion_force
646 
647 } // namespace quda
648 
649 
650 #endif
cudaColorSpinorField * tmp2
__host__ __device__ double set(double &x)
Definition: blas_helper.cuh:58
void setUnitarizeForceConstants(double unitarize_eps, double hisq_force_filter, double max_det_error, bool allow_svd, bool svd_only, double svd_rel_error, double svd_abs_error)
Set the constant parameters for the force unitarization.
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:21
#define errorQuda(...)
Definition: util_quda.h:121
__host__ __device__ ValueType sqrt(ValueType x)
Definition: complex_quda.h:120
double epsilon
Definition: test_util.cpp:1649
cudaStream_t * stream
void unitarizeForce(cudaGaugeField &newForce, const cudaGaugeField &oldForce, const cudaGaugeField &gauge, int *unitarization_failed)
Unitarize the fermion force.
void unitarizeForceCPU(cpuGaugeField &newForce, const cpuGaugeField &oldForce, const cpuGaugeField &gauge)
Unitarize the fermion force on CPU.
int num_failures
__host__ __device__ void printLink(const Matrix< Cmplx, 3 > &link)
Definition: quda_matrix.h:1149
static __device__ double2 atomicAdd(double2 *addr, double2 val)
Implementation of double2 atomic addition using two double-precision additions.
Definition: atomic.cuh:51
static double svd_rel_error
#define qudaDeviceSynchronize()
constexpr int size
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
Definition: tune.cpp:643
__host__ __device__ ValueType pow(ValueType x, ExponentType e)
Definition: complex_quda.h:111
static double svd_abs_error
Main header file for host and device accessors to GaugeFields.
__device__ __host__ T getTrace(const Matrix< T, 3 > &a)
Definition: quda_matrix.h:415
static double unitarize_eps
__shared__ float s[]
unsigned long long flops
Definition: blas_quda.cu:22
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
__host__ __device__ ValueType cos(ValueType x)
Definition: complex_quda.h:46
__host__ __device__ ValueType acos(ValueType x)
Definition: complex_quda.h:61
#define checkCudaError()
Definition: util_quda.h:161
__device__ __host__ T getDeterminant(const Mat< T, 3 > &a)
Definition: quda_matrix.h:422
__host__ __device__ ValueType conj(ValueType x)
Definition: complex_quda.h:130
QudaTune getTuning()
Query whether autotuning is enabled or not. Default is enabled but can be overridden by setting QUDA_...
Definition: util_quda.cpp:52
QudaParity parity
Definition: covdev_test.cpp:54
unsigned long long bytes
Definition: blas_quda.cu:23