QUDA  v0.7.0
A library for QCD on GPUs
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
unitarize_force_quda.cu
Go to the documentation of this file.
1 #include <cstdlib>
2 #include <cstdio>
3 #include <iostream>
4 #include <iomanip>
5 #include <cuda.h>
6 #include <gauge_field.h>
7 #include <tune_quda.h>
8 
9 #include <tune_quda.h>
10 #include <quda_matrix.h>
11 
12 #ifdef GPU_HISQ_FORCE
13 
14 namespace quda{
namespace { // anonymous
  // svd_quda.h is included inside an unnamed namespace so that whatever it
  // defines (presumably computeSVD, which reciprocalRoot calls) has internal
  // linkage in this translation unit. Include it exactly once: all unnamed
  // namespaces at this scope are the same namespace, so a second inclusion
  // is at best a no-op (include guard) and at worst a redefinition error.
#include <svd_quda.h>
}
24 
#define HISQ_UNITARIZE_PI 3.14159265358979323846
// 2*pi/3: phase spacing between the three trigonometric eigenvalue roots in
// reciprocalRoot(). Parenthesized so the macro expands correctly inside any
// surrounding expression (e.g. 1.0/HISQ_UNITARIZE_PI23).
#define HISQ_UNITARIZE_PI23 (HISQ_UNITARIZE_PI*2.0/3.0)
27 
28 // constants - File scope only
// Unitarization / force-filter parameters as read by device code.
// Written by setUnitarizeForceConstants() through cudaMemcpyToSymbol and
// selected inside reciprocalRoot() when __CUDA_ARCH__ is defined.
__constant__ double DEV_HISQ_UNITARIZE_EPS;
__constant__ double DEV_HISQ_FORCE_FILTER;
__constant__ double DEV_MAX_DET_ERROR;
__constant__ bool   DEV_REUNIT_ALLOW_SVD;
__constant__ bool   DEV_REUNIT_SVD_ONLY;
__constant__ double DEV_REUNIT_SVD_REL_ERROR;
__constant__ double DEV_REUNIT_SVD_ABS_ERROR;

// File-local host mirrors of the same parameters, also written by
// setUnitarizeForceConstants(); the host compilation path of
// reciprocalRoot() (reached from unitarizeForceCPU) reads these instead.
static double HOST_HISQ_UNITARIZE_EPS;
static double HOST_HISQ_FORCE_FILTER;
static double HOST_MAX_DET_ERROR;
static double HOST_REUNIT_SVD_REL_ERROR;
static double HOST_REUNIT_SVD_ABS_ERROR;
static bool   HOST_REUNIT_ALLOW_SVD;
static bool   HOST_REUNIT_SVD_ONLY;
44 
45 
46 
47  namespace fermion_force{
48 
49 
void setUnitarizeForceConstants(double unitarize_eps_h, double hisq_force_filter_h,
                                double max_det_error_h, bool allow_svd_h, bool svd_only_h,
                                double svd_rel_error_h, double svd_abs_error_h)
{
  // Copy the unitarization parameters to the device (__constant__ symbols)
  // and to the file-local host mirrors.
  //
  // The parameters are latched on the first call: not_set is a function-local
  // static, so every later call returns immediately and its arguments are
  // ignored.
  static bool not_set = true;
  if(!not_set) return;

  cudaMemcpyToSymbol(DEV_HISQ_UNITARIZE_EPS, &unitarize_eps_h, sizeof(double));
  cudaMemcpyToSymbol(DEV_HISQ_FORCE_FILTER, &hisq_force_filter_h, sizeof(double));
  cudaMemcpyToSymbol(DEV_MAX_DET_ERROR, &max_det_error_h, sizeof(double));
  cudaMemcpyToSymbol(DEV_REUNIT_ALLOW_SVD, &allow_svd_h, sizeof(bool));
  cudaMemcpyToSymbol(DEV_REUNIT_SVD_ONLY, &svd_only_h, sizeof(bool));
  cudaMemcpyToSymbol(DEV_REUNIT_SVD_REL_ERROR, &svd_rel_error_h, sizeof(double));
  cudaMemcpyToSymbol(DEV_REUNIT_SVD_ABS_ERROR, &svd_abs_error_h, sizeof(double));
  // Report a failed symbol copy here (same check unitarizeForceCuda uses)
  // instead of silently running later kernels with stale device constants.
  checkCudaError();

  HOST_HISQ_UNITARIZE_EPS = unitarize_eps_h;
  HOST_HISQ_FORCE_FILTER = hisq_force_filter_h;
  HOST_MAX_DET_ERROR = max_det_error_h;
  HOST_REUNIT_ALLOW_SVD = allow_svd_h;
  HOST_REUNIT_SVD_ONLY = svd_only_h;
  HOST_REUNIT_SVD_REL_ERROR = svd_rel_error_h;
  HOST_REUNIT_SVD_ABS_ERROR = svd_abs_error_h;
  not_set = false;
}
80 
81 
// Six derivative coefficients B00, B01, B02, B11, B12, B22 consumed by
// getUnitarizeForceSite(). set(u, v, w) computes each one as a numerator
// polynomial C_ij(u, v, w) divided by 2 [w (u v - w)]^3; the getters return
// the stored values.
template<class Real>
class DerivativeCoefficients{
  public:
    __device__ __host__ void set(const Real & u, const Real & v, const Real & w);

    __device__ __host__ Real getB00() const { return b[0]; }
    __device__ __host__ Real getB01() const { return b[1]; }
    __device__ __host__ Real getB02() const { return b[2]; }
    __device__ __host__ Real getB11() const { return b[3]; }
    __device__ __host__ Real getB12() const { return b[4]; }
    __device__ __host__ Real getB22() const { return b[5]; }

  private:
    // Storage order: B00, B01, B02, B11, B12, B22.
    Real b[6];

    // Numerator polynomials, one per coefficient (defined out of line).
    __device__ __host__ Real computeC00(const Real &, const Real &, const Real &);
    __device__ __host__ Real computeC01(const Real &, const Real &, const Real &);
    __device__ __host__ Real computeC02(const Real &, const Real &, const Real &);
    __device__ __host__ Real computeC11(const Real &, const Real &, const Real &);
    __device__ __host__ Real computeC12(const Real &, const Real &, const Real &);
    __device__ __host__ Real computeC22(const Real &, const Real &, const Real &);
};
114 
// ---------------------------------------------------------------------------
// Numerator polynomials C_ij(u, v, w) of DerivativeCoefficients.
// DerivativeCoefficients::set divides each one by 2 [w (u v - w)]^3 to get
// the corresponding B_ij. In this file u, v, w are assembled in
// reciprocalRoot from the square roots s_i of the eigenvalues of q:
//   u = s0 + s1 + s2,  v = s0 s1 + s0 s2 + s1 s2,  w = s0 s1 s2.
// NOTE(review): the written term order fixes the floating-point evaluation
// order; re-validate results before refactoring (e.g. into Horner form).
// ---------------------------------------------------------------------------
// C00: numerator of B00 (stored in b[0]).
115  template<class Real>
116  __device__ __host__
117  Real DerivativeCoefficients<Real>::computeC00(const Real & u, const Real & v, const Real & w){
118  Real result = -pow(w,3)*pow(u,6)
119  + 3*v*pow(w,3)*pow(u,4)
120  + 3*pow(v,4)*w*pow(u,4)
121  - pow(v,6)*pow(u,3)
122  - 4*pow(w,4)*pow(u,3)
123  - 12*pow(v,3)*pow(w,2)*pow(u,3)
124  + 16*pow(v,2)*pow(w,3)*pow(u,2)
125  + 3*pow(v,5)*w*pow(u,2)
126  - 8*v*pow(w,4)*u
127  - 3*pow(v,4)*pow(w,2)*u
128  + pow(w,5)
129  + pow(v,3)*pow(w,3);
130 
131  return result;
132  }
133 
// C01: numerator of B01 (stored in b[1]). The two terms marked "corrected"
// were changed from an earlier version of this polynomial.
134  template<class Real>
135  __device__ __host__
136  Real DerivativeCoefficients<Real>::computeC01(const Real & u, const Real & v, const Real & w){
137  Real result = - pow(w,2)*pow(u,7)
138  - pow(v,2)*w*pow(u,6)
139  + pow(v,4)*pow(u,5) // This was corrected!
140  + 6*v*pow(w,2)*pow(u,5)
141  - 5*pow(w,3)*pow(u,4) // This was corrected!
142  - pow(v,3)*w*pow(u,4)
143  - 2*pow(v,5)*pow(u,3)
144  - 6*pow(v,2)*pow(w,2)*pow(u,3)
145  + 10*v*pow(w,3)*pow(u,2)
146  + 6*pow(v,4)*w*pow(u,2)
147  - 3*pow(w,4)*u
148  - 6*pow(v,3)*pow(w,2)*u
149  + 2*pow(v,2)*pow(w,3);
150  return result;
151  }
152 
// C02: numerator of B02 (stored in b[2]).
153  template<class Real>
154  __device__ __host__
155  Real DerivativeCoefficients<Real>::computeC02(const Real & u, const Real & v, const Real & w){
156  Real result = pow(w,2)*pow(u,5)
157  + pow(v,2)*w*pow(u,4)
158  - pow(v,4)*pow(u,3)
159  - 4*v*pow(w,2)*pow(u,3)
160  + 4*pow(w,3)*pow(u,2)
161  + 3*pow(v,3)*w*pow(u,2)
162  - 3*pow(v,2)*pow(w,2)*u
163  + v*pow(w,3);
164  return result;
165  }
166 
// C11: numerator of B11 (stored in b[3]).
167  template<class Real>
168  __device__ __host__
169  Real DerivativeCoefficients<Real>::computeC11(const Real & u, const Real & v, const Real & w){
170  Real result = - w*pow(u,8)
171  - pow(v,2)*pow(u,7)
172  + 7*v*w*pow(u,6)
173  + 4*pow(v,3)*pow(u,5)
174  - 5*pow(w,2)*pow(u,5)
175  - 16*pow(v,2)*w*pow(u,4)
176  - 4*pow(v,4)*pow(u,3)
177  + 16*v*pow(w,2)*pow(u,3)
178  - 3*pow(w,3)*pow(u,2)
179  + 12*pow(v,3)*w*pow(u,2)
180  - 12*pow(v,2)*pow(w,2)*u
181  + 3*v*pow(w,3);
182  return result;
183  }
184 
// C12: numerator of B12 (stored in b[4]). The two terms marked "Fixed" were
// changed from an earlier version of this polynomial.
185  template<class Real>
186  __device__ __host__
187  Real DerivativeCoefficients<Real>::computeC12(const Real & u, const Real & v, const Real & w){
188  Real result = w*pow(u,6)
189  + pow(v,2)*pow(u,5) // Fixed this!
190  - 5*v*w*pow(u,4) // Fixed this!
191  - 2*pow(v,3)*pow(u,3)
192  + 4*pow(w,2)*pow(u,3)
193  + 6*pow(v,2)*w*pow(u,2)
194  - 6*v*pow(w,2)*u
195  + pow(w,3);
196  return result;
197  }
198 
// C22: numerator of B22 (stored in b[5]).
199  template<class Real>
200  __device__ __host__
201  Real DerivativeCoefficients<Real>::computeC22(const Real & u, const Real & v, const Real & w){
202  Real result = - w*pow(u,4)
203  - pow(v,2)*pow(u,3)
204  + 3*v*w*pow(u,2)
205  - 3*pow(w,2)*u;
206  return result;
207  }
208 
// Fill b[] with C_ij(u, v, w) / (2 [w (u v - w)]^3) in the storage order
// B00, B01, B02, B11, B12, B22. Each coefficient is divided separately by the
// shared factor.
template <class Real>
__device__ __host__
void DerivativeCoefficients<Real>::set(const Real & u, const Real & v, const Real & w){
  const Real common = 2.0*pow(w*(u*v-w),3);
  b[0] = computeC00(u,v,w)/common;
  b[1] = computeC01(u,v,w)/common;
  b[2] = computeC02(u,v,w)/common;
  b[3] = computeC11(u,v,w)/common;
  b[4] = computeC12(u,v,w)/common;
  b[5] = computeC22(u,v,w)/common;
}
221 
222 
// Accumulate into *result:  result += 2 * Re[tr(left * outer_prod)] * right.
// Only the real part of the trace is used. The same real scale multiplies
// both the real (.x) and imaginary (.y) parts of every element of `right`.
template<class Cmplx>
__device__ __host__
void accumBothDerivatives(Matrix<Cmplx,3>* result, const Matrix<Cmplx,3> & left, const Matrix<Cmplx,3> & right, const Matrix<Cmplx,3> & outer_prod)
{
  typedef typename RealTypeId<Cmplx>::Type Real;
  const Real scale = 2.0*getTrace(left*outer_prod).x;

  for(int row=0; row<3; ++row){
    for(int col=0; col<3; ++col){
      // Elements are addressed via result->operator()(); the plain
      // (*result)(row,col) spelling reportedly failed in the CPU build.
      result->operator()(row,col).x += scale*right(row,col).x;
      result->operator()(row,col).y += scale*right(row,col).y;
    }
  }
}
238 
239 
// Overwrite *result with tr(left * outer_prod) * right, using the full complex
// trace.
// NOTE(review): despite its name this assigns rather than accumulates (unlike
// accumBothDerivatives), and nothing in this file calls it — confirm the
// intended semantics before using it.
template<class Cmplx>
__device__ __host__
void accumDerivatives(Matrix<Cmplx,3>* result, const Matrix<Cmplx,3> & left, const Matrix<Cmplx,3> & right, const Matrix<Cmplx,3> & outer_prod)
{
  Cmplx trace_val = getTrace(left*outer_prod);
  for(int idx=0; idx<9; ++idx){
    const int row = idx/3;
    const int col = idx%3;
    result->operator()(row,col) = trace_val*right(row,col);
  }
}
252 
253 
// Return the smallest absolute value among array[0] .. array[size-1].
// array[0] is always read, so size must be at least 1.
template<class T>
__device__ __host__
T getAbsMin(const T* const array, int size){
  T smallest = fabs(array[0]);
  int i = size - 1;
  while(i >= 1){
    const T candidate = fabs(array[i]);
    if(candidate < smallest) smallest = candidate;
    --i;
  }
  return smallest;
}
264 
265 
// True when |a - b| is strictly smaller than epsilon (false for NaN inputs).
template<class Real>
__device__ __host__
inline bool checkAbsoluteError(Real a, Real b, Real epsilon)
{
  return fabs(a-b) < epsilon;
}
273 
274 
// True when |(a - b) / b| is strictly smaller than epsilon, i.e. a agrees with
// the reference value b to relative tolerance epsilon. When b == 0 the
// quotient is inf or NaN, so the comparison yields false.
template<class Real>
__device__ __host__
inline bool checkRelativeError(Real a, Real b, Real epsilon)
{
  const bool within = fabs((a-b)/b) < epsilon;
  return within;
}
282 
283 
284 
285 
// Compute the reciprocal square root of the matrix q.
// Also modify q if the eigenvalues are dangerously small.
//
//   res          : receives f[0]*I + f[1]*q + f[2]*q^2 (q after any filtering)
//   deriv_coeffs : set from u, v, w = elementary symmetric functions of the
//                  square roots of the eigenvalues of q
//   f            : receives the three polynomial coefficients used for res
//   q            : Hermitian input (V^dagger V at the call site in this file);
//                  HISQ_FORCE_FILTER is added to its diagonal when the smallest
//                  |eigenvalue| is below that threshold
//   unitarization_failed : incremented when the SVD determinant check fails
//
// Eigenvalues come from the analytic trigonometric solution unless
// REUNIT_SVD_ONLY is set. computeSVD() is used instead in SVD-only mode, or,
// when REUNIT_ALLOW_SVD is set, whenever det(q) is tiny or disagrees with the
// product of the analytic eigenvalues.
template<class Cmplx>
__device__ __host__
void reciprocalRoot(Matrix<Cmplx,3>* res, DerivativeCoefficients<typename RealTypeId<Cmplx>::Type>* deriv_coeffs,
                    typename RealTypeId<Cmplx>::Type f[3], Matrix<Cmplx,3> & q, int *unitarization_failed){
  typedef typename RealTypeId<Cmplx>::Type Real;

  Matrix<Cmplx,3> qsq, tempq;

  Real c[3];
  Real g[3];

#ifdef __CUDA_ARCH__
#define REUNIT_SVD_ONLY DEV_REUNIT_SVD_ONLY
#else
#define REUNIT_SVD_ONLY HOST_REUNIT_SVD_ONLY
#endif

  // q^2 enters the final polynomial below even when the analytic eigenvalue
  // step is skipped (SVD-only mode), so it must always be computed.
  qsq = q*q;

  if(!REUNIT_SVD_ONLY){
    tempq = qsq*q;

    // c[k] = tr(q^(k+1)) / (k+1)
    c[0] = getTrace(q).x;
    c[1] = getTrace(qsq).x/2.0;
    c[2] = getTrace(tempq).x/3.0;

    g[0] = g[1] = g[2] = c[0]/3.;
    Real r,s,theta;
    s = c[1]/3. - c[0]*c[0]/18;
    r = c[2]/2. - (c[0]/3.)*(c[1] - c[0]*c[0]/9.);

#ifdef __CUDA_ARCH__
#define HISQ_UNITARIZE_EPS DEV_HISQ_UNITARIZE_EPS
#else
#define HISQ_UNITARIZE_EPS HOST_HISQ_UNITARIZE_EPS
#endif

    Real cosTheta = r/sqrt(s*s*s);
    if(fabs(s) < HISQ_UNITARIZE_EPS){
      // Degenerate spectrum: all three roots collapse onto c[0]/3.
      cosTheta = 1.;
      s = 0.0;
    }
    if(fabs(cosTheta)>1.0){
      theta = (r>0) ? 0.0 : HISQ_UNITARIZE_PI/3.0;
    }else{
      theta = acos(cosTheta)/3.0;
    }

    s = 2.0*sqrt(s);
    for(int i=0; i<3; ++i){
      g[i] += s*cos(theta + (i-1)*HISQ_UNITARIZE_PI23);
    }
  } // !REUNIT_SVD_ONLY?

  //
  // Compare the product of the eigenvalues computed thus far to the
  // absolute value of the determinant.
  // If the determinant is very small or the relative error is greater than some predefined value
  // then recompute the eigenvalues using a singular-value decomposition.
  // Note that this particular calculation contains multiple branches,
  // so it doesn't appear to be particularly well-suited to the GPU
  // programming model. However, the analytic calculation of the
  // unitarization is extremely fast, and if the SVD routine is not called
  // too often, we expect pretty good performance.
  //

#ifdef __CUDA_ARCH__
#define REUNIT_ALLOW_SVD DEV_REUNIT_ALLOW_SVD
#define REUNIT_SVD_REL_ERROR DEV_REUNIT_SVD_REL_ERROR
#define REUNIT_SVD_ABS_ERROR DEV_REUNIT_SVD_ABS_ERROR
#else // cpu
#define REUNIT_ALLOW_SVD HOST_REUNIT_ALLOW_SVD
#define REUNIT_SVD_REL_ERROR HOST_REUNIT_SVD_REL_ERROR
#define REUNIT_SVD_ABS_ERROR HOST_REUNIT_SVD_ABS_ERROR
#endif

  // In SVD-only mode the analytic step was skipped and g[] holds no
  // eigenvalues yet, so the SVD has to run even if REUNIT_ALLOW_SVD is false.
  if(REUNIT_ALLOW_SVD || REUNIT_SVD_ONLY){
    bool perform_svd = true;
    if(!REUNIT_SVD_ONLY){
      const Real det = getDeterminant(q).x;
      if( fabs(det) >= REUNIT_SVD_ABS_ERROR){
        if( checkRelativeError(g[0]*g[1]*g[2],det,REUNIT_SVD_REL_ERROR) ) perform_svd = false;
      }
    }

    if(perform_svd){
      // Second matrix output of computeSVD; not needed afterwards.
      Matrix<Cmplx,3> tmp2;
      // compute the eigenvalues using the singular value decomposition
      computeSVD<Cmplx>(q,tempq,tmp2,g);
      // The array g contains the eigenvalues of the matrix q
      // The determinant is the product of the eigenvalues, and I can use this
      // to check the SVD
      const Real determinant = getDeterminant(q).x;
      const Real gprod = g[0]*g[1]*g[2];
      // Check the svd result for errors
#ifdef __CUDA_ARCH__
#define MAX_DET_ERROR DEV_MAX_DET_ERROR
#else
#define MAX_DET_ERROR HOST_MAX_DET_ERROR
#endif
      if(fabs(gprod - determinant) > MAX_DET_ERROR){
#if (!defined(__CUDA_ARCH__) || (__COMPUTE_CAPABILITY__ >= 200))
        printf("Warning: Error in determinant computed by SVD : %g > %g\n", fabs(gprod-determinant), MAX_DET_ERROR);
        printLink(q);
#endif

#ifdef __CUDA_ARCH__
        atomicAdd(unitarization_failed,1);
#else
        (*unitarization_failed)++;
#endif
      }
    } // perform_svd?

  } // REUNIT_ALLOW_SVD || REUNIT_SVD_ONLY?

#ifdef __CUDA_ARCH__
#define HISQ_FORCE_FILTER DEV_HISQ_FORCE_FILTER
#else
#define HISQ_FORCE_FILTER HOST_HISQ_FORCE_FILTER
#endif
  // Regulate nearly singular q: shift its spectrum (and q itself) upward.
  Real delta = getAbsMin(g,3);
  if(delta < HISQ_FORCE_FILTER){
    for(int i=0; i<3; ++i){
      g[i] += HISQ_FORCE_FILTER;
      q(i,i).x += HISQ_FORCE_FILTER;
    }
    qsq = q*q; // recalculate Q^2
  }

  // At this point we have finished with the c's; use them to store sqrt(g)
  for(int i=0; i<3; ++i) c[i] = sqrt(g[i]);

  // done with the g's, use these to store u, v, w
  g[0] = c[0]+c[1]+c[2];
  g[1] = c[0]*c[1] + c[0]*c[2] + c[1]*c[2];
  g[2] = c[0]*c[1]*c[2];

  // set the derivative coefficients!
  deriv_coeffs->set(g[0], g[1], g[2]);

  const Real denominator = g[2]*(g[0]*g[1]-g[2]);
  c[0] = (g[0]*g[1]*g[1] - g[2]*(g[0]*g[0]+g[1]))/denominator;
  c[1] = (-g[0]*g[0]*g[0] - g[2] + 2.*g[0]*g[1])/denominator;
  c[2] = g[0]/denominator;

  tempq = c[1]*q + c[2]*qsq;
  // Add a real scalar
  tempq(0,0).x += c[0];
  tempq(1,1).x += c[0];
  tempq(2,2).x += c[0];

  f[0] = c[0];
  f[1] = c[1];
  f[2] = c[2];

  *res = tempq;
}
444 
445 
446 
447  // "v" denotes a "fattened" link variable
// Unitarization force for a single link.
//   v            : the link being unitarized ("fattened" link)
//   outer_prod   : incoming force term for this link
//   result       : receives the force after the chain rule through the
//                  reciprocal-square-root unitarization of v
//   unitarization_failed : forwarded to reciprocalRoot (SVD failure counter)
//
// Nominal cost, as used by UnitarizeForceCuda::flops(): 4528 flops per call
// (17 matrix multiplies at 198 flops, reciprocal root ~529 flops, three
// accumBothDerivatives at 41 flops, plus miscellaneous). That figure counts
// one product the body no longer forms, so it is a slight overestimate.
template<class Cmplx>
__device__ __host__
void getUnitarizeForceSite(const Matrix<Cmplx,3> & v, const Matrix<Cmplx,3> & outer_prod, Matrix<Cmplx,3>* result, int *unitarization_failed)
{
  typedef typename RealTypeId<Cmplx>::Type Real;
  Real f[3];
  Real b[6];

  Matrix<Cmplx,3> v_dagger = conj(v);
  Matrix<Cmplx,3> q = v_dagger*v;

  Matrix<Cmplx,3> rsqrt_q;
  DerivativeCoefficients<Real> deriv_coeffs;

  // Note: reciprocalRoot may shift the diagonal of q (force filter); the
  // possibly-modified q is what the rest of this function uses.
  reciprocalRoot<Cmplx>(&rsqrt_q, &deriv_coeffs, f, q, unitarization_failed); // approx 529 flops (assumes no SVD)

  // Local copies of the derivative coefficients, in storage order
  // B00, B01, B02, B11, B12, B22.
  b[0] = deriv_coeffs.getB00();
  b[1] = deriv_coeffs.getB01();
  b[2] = deriv_coeffs.getB02();
  b[3] = deriv_coeffs.getB11();
  b[4] = deriv_coeffs.getB12();
  b[5] = deriv_coeffs.getB22();

  Matrix<Cmplx,3> & local_result = *result;

  local_result = rsqrt_q*outer_prod;

  // Terms from differentiating the f-polynomial coefficients' matrices.
  Matrix<Cmplx,3> qv_dagger = q*v_dagger;
  Matrix<Cmplx,3> vv_dagger = v*v_dagger;
  Matrix<Cmplx,3> conj_outer_prod = conj(outer_prod);
  Matrix<Cmplx,3> temp = f[1]*v + f[2]*v*q;

  local_result = local_result + outer_prod*temp*v_dagger + f[2]*q*outer_prod*vv_dagger;

  local_result = local_result + v_dagger*conj_outer_prod*conj(temp) + f[2]*qv_dagger*conj_outer_prod*v_dagger;

  // Terms from differentiating the f coefficients themselves (via B_ij).
  Matrix<Cmplx,3> qsqv_dagger = q*qv_dagger;
  Matrix<Cmplx,3> pv_dagger = b[0]*v_dagger + b[1]*qv_dagger + b[2]*qsqv_dagger;
  accumBothDerivatives(&local_result, v, pv_dagger, outer_prod); // 41 flops

  Matrix<Cmplx,3> rv_dagger = b[1]*v_dagger + b[3]*qv_dagger + b[4]*qsqv_dagger;
  Matrix<Cmplx,3> vq = v*q;
  accumBothDerivatives(&local_result, vq, rv_dagger, outer_prod); // 41 flops

  Matrix<Cmplx,3> sv_dagger = b[2]*v_dagger + b[4]*qv_dagger + b[5]*qsqv_dagger;
  Matrix<Cmplx,3> vqsq = vq*q;
  accumBothDerivatives(&local_result, vqsq, sv_dagger, outer_prod); // 41 flops
} // get unit force term
509 
510 
511 
// Kernel: one thread per lattice site (`threads` = local volume). Threads in
// [0, threads/2) work on the even-parity arrays and the rest on the odd-parity
// arrays. Each thread processes all 4 directions of its site. The per-link
// computation is always done in double precision, whatever Cmplx is.
// unitarization_failed is a device counter passed through to
// getUnitarizeForceSite.
template<class Cmplx>
__global__ void getUnitarizeForceField(const int threads, const Cmplx* link_even, const Cmplx* link_odd,
                                       const Cmplx* old_force_even, const Cmplx* old_force_odd,
                                       Cmplx* force_even, Cmplx* force_odd,
                                       int* unitarization_failed)
{
  const int global_idx = blockIdx.x*blockDim.x + threadIdx.x;
  if(global_idx >= threads) return;

  const int half_volume = threads/2;
  const bool odd = (global_idx >= half_volume);
  const int site = odd ? (global_idx - half_volume) : global_idx;

  const Cmplx* link      = odd ? link_odd      : link_even;
  const Cmplx* old_force = odd ? old_force_odd : old_force_even;
  Cmplx* force           = odd ? force_odd     : force_even;

  Matrix<double2,3> v, result, oprod;

  for(int dir=0; dir<4; ++dir){
    loadLinkVariableFromArray(old_force, dir, site, half_volume, &oprod);
    loadLinkVariableFromArray(link, dir, site, half_volume, &v);

    getUnitarizeForceSite<double2>(v, oprod, &result, unitarization_failed);

    writeLinkVariableToArray(result, dir, site, half_volume, force);
  } // 4*4528 flops per site
} // getUnitarizeForceField
552 
553 
// Host reference kernel body for one element type: apply
// getUnitarizeForceSite to every link of the field, reading the old force and
// gauge links and writing the new force. Each link is stored as 18 reals.
// MILC order uses a single array with 4 links per site; QDP order uses one
// array per direction.
template<class Float>
static void unitarizeForceCPUImpl(cpuGaugeField& cpuOldForce, cpuGaugeField& cpuGauge,
                                  cpuGaugeField* cpuNewForce, int* num_failures)
{
  Matrix<double2,3> old_force, new_force, v;
  const QudaGaugeFieldOrder order = cpuGauge.Order();

  if(order == QUDA_MILC_GAUGE_ORDER){
    Float* const old_p   = (Float*)(cpuOldForce.Gauge_p());
    Float* const gauge_p = (Float*)(cpuGauge.Gauge_p());
    Float* const new_p   = (Float*)(cpuNewForce->Gauge_p());
    for(int i=0; i<cpuGauge.Volume(); ++i){
      for(int dir=0; dir<4; ++dir){
        const int offset = (i*4 + dir)*18;
        copyArrayToLink(&old_force, old_p + offset);
        copyArrayToLink(&v, gauge_p + offset);
        getUnitarizeForceSite<double2>(v, old_force, &new_force, num_failures);
        copyLinkToArray(new_p + offset, new_force);
      } // dir
    } // i
  }else if(order == QUDA_QDP_GAUGE_ORDER){
    Float** const old_pp   = (Float**)(cpuOldForce.Gauge_p());
    Float** const gauge_pp = (Float**)(cpuGauge.Gauge_p());
    Float** const new_pp   = (Float**)(cpuNewForce->Gauge_p());
    for(int dir=0; dir<4; ++dir){
      for(int i=0; i<cpuGauge.Volume(); ++i){
        copyArrayToLink(&old_force, old_pp[dir] + i*18);
        copyArrayToLink(&v, gauge_pp[dir] + i*18);
        getUnitarizeForceSite<double2>(v, old_force, &new_force, num_failures);
        copyLinkToArray(new_pp[dir] + i*18, new_force);
      } // i
    } // dir
  }else{
    errorQuda("Only MILC and QDP gauge orders supported\n");
  }
}

// Host implementation of the unitarization force. Dispatches once on the
// gauge field's precision; single and double precision are supported and any
// other precision is an error.
// NOTE(review): the old/new force fields are assumed to share cpuGauge's
// precision and order (only cpuGauge is inspected), and the SVD failure
// count accumulated here is not reported to the caller.
void unitarizeForceCPU(cpuGaugeField& cpuOldForce, cpuGaugeField& cpuGauge, cpuGaugeField* cpuNewForce)
{
  int num_failures = 0;

  if(cpuGauge.Precision() == QUDA_SINGLE_PRECISION){
    unitarizeForceCPUImpl<float>(cpuOldForce, cpuGauge, cpuNewForce, &num_failures);
  }else if(cpuGauge.Precision() == QUDA_DOUBLE_PRECISION){
    unitarizeForceCPUImpl<double>(cpuOldForce, cpuGauge, cpuNewForce, &num_failures);
  }else{
    errorQuda("Unsupported precision %d\n", static_cast<int>(cpuGauge.Precision()));
  }
} // unitarize_force_cpu
601 
// Tunable wrapper that launches getUnitarizeForceField over the whole local
// volume (one thread per site; the kernel splits even/odd halves itself).
// Only the block size is tuned; the grid dimension is not.
class UnitarizeForceCuda : public Tunable {
 private:
  const cudaGaugeField &oldForce;
  const cudaGaugeField &gauge;
  cudaGaugeField &newForce;
  int *fails;   // device counter the kernel increments on SVD determinant failures

  unsigned int sharedBytesPerThread() const { return 0; }
  unsigned int sharedBytesPerBlock(const TuneParam &) const { return 0; }

  // don't tune the grid dimension
  bool tuneGridDim() const { return false; }
  unsigned int minThreads() const { return gauge.Volume(); }

 public:
  UnitarizeForceCuda(const cudaGaugeField& oldForce, const cudaGaugeField& gauge,
                     cudaGaugeField& newForce, int* fails) :
    oldForce(oldForce), gauge(gauge), newForce(newForce), fails(fails) {
    // gauge.Precision() returns a QudaPrecision enumerator (int-sized), so it
    // is printed as an int; passing it for %lu is a format/argument mismatch.
    writeAuxString("threads=%d,prec=%d,stride=%d",
                   gauge.Volume(), static_cast<int>(gauge.Precision()), gauge.Stride());
  }
  virtual ~UnitarizeForceCuda() { ; }

  // Launch the kernel on `stream` with the tuned block size. Unsupported
  // precisions are reported instead of being silently skipped.
  void apply(const cudaStream_t &stream) {
    TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());

    if(gauge.Precision() == QUDA_SINGLE_PRECISION){
      getUnitarizeForceField<<<tp.grid,tp.block,0,stream>>>(gauge.Volume(), (const float2*)gauge.Even_p(), (const float2*)gauge.Odd_p(),
                                                            (const float2*)oldForce.Even_p(), (const float2*)oldForce.Odd_p(),
                                                            (float2*)newForce.Even_p(), (float2*)newForce.Odd_p(),
                                                            fails);
    }else if(gauge.Precision() == QUDA_DOUBLE_PRECISION){
      getUnitarizeForceField<<<tp.grid,tp.block,0,stream>>>(gauge.Volume(), (const double2*)gauge.Even_p(), (const double2*)gauge.Odd_p(),
                                                            (const double2*)oldForce.Even_p(), (const double2*)oldForce.Odd_p(),
                                                            (double2*)newForce.Even_p(), (double2*)newForce.Odd_p(),
                                                            fails);
    }else{
      errorQuda("Unsupported precision %d\n", static_cast<int>(gauge.Precision()));
    }
  }

  void preTune() { ; }
  void postTune() { cudaMemset(fails, 0, sizeof(int)); } // reset fails counter

  // Nominal cost: 4 links per site x 4528 flops per link. The product is formed
  // in long long so that large volumes cannot overflow int arithmetic.
  long long flops() const { return static_cast<long long>(4*4528)*gauge.Volume(); }

  TuneKey tuneKey() const { return TuneKey(gauge.VolString(), typeid(*this).name(), aux); }
}; // UnitarizeForceCuda
648 
// Run the GPU unitarization force:
//   newForce = unitarize-derivative(cudaOldForce, cudaGauge).
// The launch goes on stream 0. unitarization_failed is the device counter
// the kernel bumps on SVD failures. When flops is non-null it receives the
// nominal flop count. Any pending CUDA error is checked afterwards.
void unitarizeForceCuda(cudaGaugeField &cudaOldForce,
                        cudaGaugeField &cudaGauge, cudaGaugeField *cudaNewForce, int* unitarization_failed, long long *flops) {
  UnitarizeForceCuda worker(cudaOldForce, cudaGauge, *cudaNewForce, unitarization_failed);
  worker.apply(0);
  if(flops != NULL){
    *flops = worker.flops();
  }
  checkCudaError();
}
657 
658 
659  } // namespace fermion_force
660 
661 //#endif
662 } // namespace quda
663 
664 #endif
int y[4]
void setUnitarizeForceConstants(double unitarize_eps, double hisq_force_filter, double max_det_error, bool allow_svd, bool svd_only, double svd_rel_error, double svd_abs_error)
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:20
void unitarizeForceCuda(cudaGaugeField &cudaOldForce, cudaGaugeField &cudaGauge, cudaGaugeField *cudaNewForce, int *unitarization_failed, long long *flops=NULL)
#define errorQuda(...)
Definition: util_quda.h:73
void unitarizeForceCPU(cpuGaugeField &cpuOldForce, cpuGaugeField &cpuGauge, cpuGaugeField *cpuNewForce)
__host__ __device__ ValueType sqrt(ValueType x)
Definition: complex_quda.h:105
cudaStream_t * stream
__global__ void const FloatN FloatM FloatM Float Float int threads
Definition: llfat_core.h:1099
__device__ __host__ T getDeterminant(const Matrix< T, 3 > &a)
Definition: quda_matrix.h:385
cudaGaugeField * cudaGauge
__host__ __device__ void printLink(const Matrix< Cmplx, 3 > &link)
Definition: quda_matrix.h:1012
void copyLinkToArray(float *array, const Matrix< float2, 3 > &link)
Definition: quda_matrix.h:985
cudaColorSpinorField * tmp2
Definition: dslash_test.cpp:41
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
Definition: tune.cpp:271
__host__ __device__ ValueType pow(ValueType x, ExponentType e)
Definition: complex_quda.h:100
enum QudaGaugeFieldOrder_s QudaGaugeFieldOrder
cpuGaugeField * cpuGauge
int mem_idx
Definition: llfat_core.h:812
int x[4]
void copyArrayToLink(Matrix< float2, 3 > *link, float *array)
Definition: quda_matrix.h:962
__device__ __host__ T getTrace(const Matrix< T, 3 > &a)
Definition: quda_matrix.h:378
__device__ void loadLinkVariableFromArray(const T *const array, const int dir, const int idx, const int stride, Matrix< T, 3 > *link)
Definition: quda_matrix.h:767
__host__ __device__ ValueType cos(ValueType x)
Definition: complex_quda.h:35
__host__ __device__ ValueType acos(ValueType x)
Definition: complex_quda.h:50
#define checkCudaError()
Definition: util_quda.h:110
__device__ void writeLinkVariableToArray(const Matrix< T, 3 > &link, const int dir, const int idx, const int stride, T *const array)
Definition: quda_matrix.h:830
__host__ __device__ ValueType conj(ValueType x)
Definition: complex_quda.h:115
QudaTune getTuning()
Definition: util_quda.cpp:32
VOLATILE spinorFloat * s
void * gauge[4]
Definition: su3_test.cpp:15