QUDA  v0.5.0
A library for QCD on GPUs
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
unitarize_force_quda.cu
Go to the documentation of this file.
1 #include <cstdlib>
2 #include <cstdio>
3 #include <iostream>
4 #include <iomanip>
5 #include <cuda.h>
6 #include <gauge_field.h>
7 
8 #include <quda_matrix.h>
9 #include <svd_quda.h>
10 
11 namespace quda{
12 
// Pi, and 2*pi/3 (the per-index phase offset added to theta in reciprocalRoot).
// The derived constant is wrapped in parentheses so that it expands as a
// single value in any surrounding expression (e.g. as a divisor or after a
// subtraction), instead of being re-associated with neighbouring operators.
#define HISQ_UNITARIZE_PI 3.14159265358979323846
#define HISQ_UNITARIZE_PI23 (HISQ_UNITARIZE_PI*2.0/3.0)
15 
16 // constants - File scope only
// Device-side copies of the unitarization parameters. They are written with
// cudaMemcpyToSymbol in setUnitarizeForceConstants and read in reciprocalRoot
// when compiling for the device (__CUDA_ARCH__ defined).
17 __constant__ double DEV_HISQ_UNITARIZE_EPS;
18 __constant__ double DEV_HISQ_FORCE_FILTER;
19 __constant__ double DEV_MAX_DET_ERROR;
20 __constant__ bool DEV_REUNIT_ALLOW_SVD;
21 __constant__ bool DEV_REUNIT_SVD_ONLY;
22 __constant__ double DEV_REUNIT_SVD_REL_ERROR;
23 __constant__ double DEV_REUNIT_SVD_ABS_ERROR;
24 
25 
26 
// Host-side mirrors of the same parameters, assigned in
// setUnitarizeForceConstants and read in reciprocalRoot on the host path.
27 static double HOST_HISQ_UNITARIZE_EPS;
28 static double HOST_HISQ_FORCE_FILTER;
29 static double HOST_MAX_DET_ERROR;
30 static bool HOST_REUNIT_ALLOW_SVD;
31 static bool HOST_REUNIT_SVD_ONLY;
32 static double HOST_REUNIT_SVD_REL_ERROR;
33 static double HOST_REUNIT_SVD_ABS_ERROR;
34 
35 
36 
37 
38  namespace fermion_force{
39 
40 
// Store the unitarization / SVD control parameters in both the device
// __constant__ symbols (DEV_*) and their host-side mirrors (HOST_*).
//
// Only the first call has any effect: a function-local latch turns every
// later call into a no-op, so the values passed on the first call stay in
// force for the lifetime of the process.
//
//   unitarize_eps_h      -> *_HISQ_UNITARIZE_EPS
//   hisq_force_filter_h  -> *_HISQ_FORCE_FILTER
//   max_det_error_h      -> *_MAX_DET_ERROR
//   allow_svd_h          -> *_REUNIT_ALLOW_SVD
//   svd_only_h           -> *_REUNIT_SVD_ONLY
//   svd_rel_error_h      -> *_REUNIT_SVD_REL_ERROR
//   svd_abs_error_h      -> *_REUNIT_SVD_ABS_ERROR
//
// A failed host-to-device copy is reported through checkCudaError() rather
// than being silently ignored.
void setUnitarizeForceConstants(double unitarize_eps_h, double hisq_force_filter_h,
                                double max_det_error_h, bool allow_svd_h, bool svd_only_h,
                                double svd_rel_error_h, double svd_abs_error_h)
{
  // not_set is only initialised once
  static bool not_set=true;

  if(not_set){
    cudaMemcpyToSymbol(DEV_HISQ_UNITARIZE_EPS, &unitarize_eps_h, sizeof(double));
    cudaMemcpyToSymbol(DEV_HISQ_FORCE_FILTER, &hisq_force_filter_h, sizeof(double));
    cudaMemcpyToSymbol(DEV_MAX_DET_ERROR, &max_det_error_h, sizeof(double));
    cudaMemcpyToSymbol(DEV_REUNIT_ALLOW_SVD, &allow_svd_h, sizeof(bool));
    cudaMemcpyToSymbol(DEV_REUNIT_SVD_ONLY, &svd_only_h, sizeof(bool));
    cudaMemcpyToSymbol(DEV_REUNIT_SVD_REL_ERROR, &svd_rel_error_h, sizeof(double));
    cudaMemcpyToSymbol(DEV_REUNIT_SVD_ABS_ERROR, &svd_abs_error_h, sizeof(double));
    // Surface any error raised by the copies above before the values are relied upon.
    checkCudaError();

    HOST_HISQ_UNITARIZE_EPS = unitarize_eps_h;
    HOST_HISQ_FORCE_FILTER = hisq_force_filter_h;
    HOST_MAX_DET_ERROR = max_det_error_h;
    HOST_REUNIT_ALLOW_SVD = allow_svd_h;
    HOST_REUNIT_SVD_ONLY = svd_only_h;
    HOST_REUNIT_SVD_REL_ERROR = svd_rel_error_h;
    HOST_REUNIT_SVD_ABS_ERROR = svd_abs_error_h;

    not_set = false;
  }
  return;
}
71 
72 
// Holds six coefficients b[0..5], filled by set(u, v, w) and exposed through
// the getBxx() accessors (B00, B01, B02, B11, B12, B22 map to b[0]..b[5]).
// Each coefficient is one of the computeCxx polynomials divided by a common
// denominator; see set().
// The class-head line is restored here: the name is the one used by the
// out-of-class member definitions (DerivativeCoefficients<Real>::...).
template<class Real>
class DerivativeCoefficients{
  private:
    Real b[6];
    // Numerator polynomials in (u, v, w), one per coefficient.
    __device__ __host__
    Real computeC00(const Real &, const Real &, const Real &);
    __device__ __host__
    Real computeC01(const Real &, const Real &, const Real &);
    __device__ __host__
    Real computeC02(const Real &, const Real &, const Real &);
    __device__ __host__
    Real computeC11(const Real &, const Real &, const Real &);
    __device__ __host__
    Real computeC12(const Real &, const Real &, const Real &);
    __device__ __host__
    Real computeC22(const Real &, const Real &, const Real &);

  public:
    // Compute and store all six coefficients for the given u, v, w.
    __device__ __host__ void set(const Real & u, const Real & v, const Real & w);
    __device__ __host__
    Real getB00() const { return b[0]; }
    __device__ __host__
    Real getB01() const { return b[1]; }
    __device__ __host__
    Real getB02() const { return b[2]; }
    __device__ __host__
    Real getB11() const { return b[3]; }
    __device__ __host__
    Real getB12() const { return b[4]; }
    __device__ __host__
    Real getB22() const { return b[5]; }
};
105 
// Numerator of coefficient b[0]: a fixed polynomial in u, v and w.
// set() divides the returned value by 2*(w*(u*v-w))^3 before storing it.
// The caller in reciprocalRoot passes u, v, w built from the square roots
// c[0..2] as their sum, sum of pairwise products, and product.
106  template<class Real>
107  __device__ __host__
108  Real DerivativeCoefficients<Real>::computeC00(const Real & u, const Real & v, const Real & w){
109  Real result = -pow(w,3)*pow(u,6)
110  + 3*v*pow(w,3)*pow(u,4)
111  + 3*pow(v,4)*w*pow(u,4)
112  - pow(v,6)*pow(u,3)
113  - 4*pow(w,4)*pow(u,3)
114  - 12*pow(v,3)*pow(w,2)*pow(u,3)
115  + 16*pow(v,2)*pow(w,3)*pow(u,2)
116  + 3*pow(v,5)*w*pow(u,2)
117  - 8*v*pow(w,4)*u
118  - 3*pow(v,4)*pow(w,2)*u
119  + pow(w,5)
120  + pow(v,3)*pow(w,3);
121 
122  return result;
123  }
124 
// Numerator of coefficient b[1]: a fixed polynomial in u, v and w.
// set() divides the returned value by 2*(w*(u*v-w))^3 before storing it.
// Two terms carry an author note that they were corrected at some point.
125  template<class Real>
126  __device__ __host__
127  Real DerivativeCoefficients<Real>::computeC01(const Real & u, const Real & v, const Real & w){
128  Real result = - pow(w,2)*pow(u,7)
129  - pow(v,2)*w*pow(u,6)
130  + pow(v,4)*pow(u,5) // This was corrected!
131  + 6*v*pow(w,2)*pow(u,5)
132  - 5*pow(w,3)*pow(u,4) // This was corrected!
133  - pow(v,3)*w*pow(u,4)
134  - 2*pow(v,5)*pow(u,3)
135  - 6*pow(v,2)*pow(w,2)*pow(u,3)
136  + 10*v*pow(w,3)*pow(u,2)
137  + 6*pow(v,4)*w*pow(u,2)
138  - 3*pow(w,4)*u
139  - 6*pow(v,3)*pow(w,2)*u
140  + 2*pow(v,2)*pow(w,3);
141  return result;
142  }
143 
// Numerator of coefficient b[2]: a fixed polynomial in u, v and w.
// set() divides the returned value by 2*(w*(u*v-w))^3 before storing it.
144  template<class Real>
145  __device__ __host__
146  Real DerivativeCoefficients<Real>::computeC02(const Real & u, const Real & v, const Real & w){
147  Real result = pow(w,2)*pow(u,5)
148  + pow(v,2)*w*pow(u,4)
149  - pow(v,4)*pow(u,3)
150  - 4*v*pow(w,2)*pow(u,3)
151  + 4*pow(w,3)*pow(u,2)
152  + 3*pow(v,3)*w*pow(u,2)
153  - 3*pow(v,2)*pow(w,2)*u
154  + v*pow(w,3);
155  return result;
156  }
157 
// Numerator of coefficient b[3]: a fixed polynomial in u, v and w.
// set() divides the returned value by 2*(w*(u*v-w))^3 before storing it.
158  template<class Real>
159  __device__ __host__
160  Real DerivativeCoefficients<Real>::computeC11(const Real & u, const Real & v, const Real & w){
161  Real result = - w*pow(u,8)
162  - pow(v,2)*pow(u,7)
163  + 7*v*w*pow(u,6)
164  + 4*pow(v,3)*pow(u,5)
165  - 5*pow(w,2)*pow(u,5)
166  - 16*pow(v,2)*w*pow(u,4)
167  - 4*pow(v,4)*pow(u,3)
168  + 16*v*pow(w,2)*pow(u,3)
169  - 3*pow(w,3)*pow(u,2)
170  + 12*pow(v,3)*w*pow(u,2)
171  - 12*pow(v,2)*pow(w,2)*u
172  + 3*v*pow(w,3);
173  return result;
174  }
175 
// Numerator of coefficient b[4]: a fixed polynomial in u, v and w.
// set() divides the returned value by 2*(w*(u*v-w))^3 before storing it.
// Two terms carry an author note that they were fixed at some point.
176  template<class Real>
177  __device__ __host__
178  Real DerivativeCoefficients<Real>::computeC12(const Real & u, const Real & v, const Real & w){
179  Real result = w*pow(u,6)
180  + pow(v,2)*pow(u,5) // Fixed this!
181  - 5*v*w*pow(u,4) // Fixed this!
182  - 2*pow(v,3)*pow(u,3)
183  + 4*pow(w,2)*pow(u,3)
184  + 6*pow(v,2)*w*pow(u,2)
185  - 6*v*pow(w,2)*u
186  + pow(w,3);
187  return result;
188  }
189 
// Numerator of coefficient b[5]: a fixed polynomial in u, v and w.
// set() divides the returned value by 2*(w*(u*v-w))^3 before storing it.
190  template<class Real>
191  __device__ __host__
192  Real DerivativeCoefficients<Real>::computeC22(const Real & u, const Real & v, const Real & w){
193  Real result = - w*pow(u,4)
194  - pow(v,2)*pow(u,3)
195  + 3*v*w*pow(u,2)
196  - 3*pow(w,2)*u;
197  return result;
198  }
199 
// Fill b[0..5] from u, v, w: each entry is the matching computeCxx
// polynomial divided by the common factor 2*(w*(u*v - w))^3.
// No guard is applied against that factor being zero.
template <class Real>
__device__ __host__
void DerivativeCoefficients<Real>::set(const Real & u, const Real & v, const Real & w){
  // Shared divisor for all six coefficients.
  const Real common_divisor = 2.0*pow(w*(u*v-w),3);

  // Numerators in storage order: C00, C01, C02, C11, C12, C22.
  const Real numerators[6] = {
    computeC00(u,v,w),
    computeC01(u,v,w),
    computeC02(u,v,w),
    computeC11(u,v,w),
    computeC12(u,v,w),
    computeC22(u,v,w)
  };

  for(int idx=0; idx<6; ++idx){
    b[idx] = numerators[idx]/common_divisor;
  }
}
212 
213 
// Adds 2*Re(Tr(left*outer_prod)) * right to *result, element by element.
// Only the real part (.x) of the trace is used as the scale factor.
template<class Cmplx>
__device__ __host__
void accumBothDerivatives(Matrix<Cmplx,3>* result, const Matrix<Cmplx,3> & left, const Matrix<Cmplx,3> & right, const Matrix<Cmplx,3> & outer_prod)
{
  const typename RealTypeId<Cmplx>::Type scale = 2.0*getTrace(left*outer_prod).x;
  for(int row=0; row<3; ++row){
    for(int col=0; col<3; ++col){
      // The real and imaginary components are updated separately on
      // purpose: the original author found this form was required for
      // the host build to work.
      result->operator()(row,col).x += scale*right(row,col).x;
      result->operator()(row,col).y += scale*right(row,col).y;
    }
  }
}
229 
230 
// Sets every element of *result to Tr(left*outer_prod) * right(k,l),
// using the full complex trace.
// NOTE(review): despite the "accum" name this overwrites *result rather than
// adding to it (compare accumBothDerivatives, which uses +=) - confirm intent.
template<class Cmplx>
__device__ __host__
void accumDerivatives(Matrix<Cmplx,3>* result, const Matrix<Cmplx,3> & left, const Matrix<Cmplx,3> & right, const Matrix<Cmplx,3> & outer_prod)
{
  Cmplx trace_value = getTrace(left*outer_prod);
  for(int row=0; row<3; ++row){
    for(int col=0; col<3; ++col){
      result->operator()(row,col) = trace_value*right(row,col);
    }
  }
}
243 
244 
// Returns the smallest absolute value among array[0..size-1].
// array[0] is always read, so the caller must supply at least one element.
template<class T>
__device__ __host__
T getAbsMin(const T* const array, int size){
  T smallest = fabs(array[0]);
  for(int idx=1; idx<size; ++idx){
    const T magnitude = fabs(array[idx]);
    smallest = (magnitude < smallest) ? magnitude : smallest;
  }
  return smallest;
}
255 
256 
// True when |a - b| is strictly smaller than epsilon.
template<class Real>
__device__ __host__
inline bool checkAbsoluteError(Real a, Real b, Real epsilon)
{
  const Real difference = fabs(a-b);
  return difference < epsilon;
}
264 
265 
// True when |(a - b)/b| is strictly smaller than epsilon.
// b is used as the divisor without a zero check; a zero b yields a
// non-finite ratio, for which the comparison is false.
template<class Real>
__device__ __host__
inline bool checkRelativeError(Real a, Real b, Real epsilon)
{
  const Real ratio = fabs((a-b)/b);
  return ratio < epsilon;
}
273 
274 
275 
276 
277  // Compute the reciprocal square root of the matrix q
278  // Also modify q if the eigenvalues are dangerously small.
//
// Visible effects of the body:
//   *res          <- c[0]*I + c[1]*q + c[2]*q^2
//   f[0..2]       <- the same three scalars c[0..2]
//   deriv_coeffs  <- set(u, v, w), with u, v, w built from sqrt(g[i])
//   q             <- diagonal shifted by HISQ_FORCE_FILTER when the smallest
//                    |g[i]| is below that filter
//   *unitarization_failed is incremented (atomicAdd on the device) when the
//   SVD eigenvalue product disagrees with the determinant by more than
//   MAX_DET_ERROR.
// Each tunable is read from the DEV_* symbol when compiling for the device
// and from the HOST_* variable otherwise.
// NOTE(review): the first line of the parameter list (declaring res and
// deriv_coeffs) is not present in this text - confirm against the full source.
279  template<class Cmplx>
280  __device__ __host__
282  typename RealTypeId<Cmplx>::Type f[3], Matrix<Cmplx,3> & q, int *unitarization_failed){
283 
284  Matrix<Cmplx,3> qsq, tempq;
285 
286  typename RealTypeId<Cmplx>::Type c[3];
287  typename RealTypeId<Cmplx>::Type g[3];
288 
289 #ifdef __CUDA_ARCH__
290 #define REUNIT_SVD_ONLY DEV_REUNIT_SVD_ONLY
291 #else
292 #define REUNIT_SVD_ONLY HOST_REUNIT_SVD_ONLY
293 #endif
// Analytic path: g[] is computed in closed form from the real parts of
// Tr(q), Tr(q^2)/2 and Tr(q^3)/3.
294  if(!REUNIT_SVD_ONLY){
295  qsq = q*q;
296  tempq = qsq*q;
297 
298  c[0] = getTrace(q).x;
299  c[1] = getTrace(qsq).x/2.0;
300  c[2] = getTrace(tempq).x/3.0;
301 
302  g[0] = g[1] = g[2] = c[0]/3.;
303  typename RealTypeId<Cmplx>::Type r,s,theta;
304  s = c[1]/3. - c[0]*c[0]/18;
305  r = c[2]/2. - (c[0]/3.)*(c[1] - c[0]*c[0]/9.);
306 
307 #ifdef __CUDA_ARCH__
308 #define HISQ_UNITARIZE_EPS DEV_HISQ_UNITARIZE_EPS
309 #else
310 #define HISQ_UNITARIZE_EPS HOST_HISQ_UNITARIZE_EPS
311 #endif
312 
// cosTheta is evaluated before the |s| guard; when the guard fires the value
// is replaced by 1 and s is reset to 0.
313  typename RealTypeId<Cmplx>::Type cosTheta = r/sqrt(s*s*s);
314  if(fabs(s) < HISQ_UNITARIZE_EPS){
315  cosTheta = 1.;
316  s = 0.0;
317  }
// Out-of-range cosTheta is mapped to theta = 0 or pi/3 according to the sign of r.
318  if(fabs(cosTheta)>1.0){ r>0 ? theta=0.0 : theta=HISQ_UNITARIZE_PI/3.0; }
319  else{ theta = acos(cosTheta)/3.0; }
320 
321  s = 2.0*sqrt(s);
322  for(int i=0; i<3; ++i){
323  g[i] += s*cos(theta + (i-1)*HISQ_UNITARIZE_PI23);
324  }
325 
326  } // !REUNIT_SVD_ONLY?
327 
328  //
329  // Compare the product of the eigenvalues computed thus far to the
330  // absolute value of the determinant.
331  // If the determinant is very small or the relative error is greater than some predefined value
332  // then recompute the eigenvalues using a singular-value decomposition.
333  // Note that this particular calculation contains multiple branches,
334  // so it doesn't appear to be particularly well-suited to the GPU
335  // programming model. However, the analytic calculation of the
336  // unitarization is extremely fast, and if the SVD routine is not called
337  // too often, we expect pretty good performance.
338  //
339 
340 #ifdef __CUDA_ARCH__
341 #define REUNIT_ALLOW_SVD DEV_REUNIT_ALLOW_SVD
342 #define REUNIT_SVD_REL_ERROR DEV_REUNIT_SVD_REL_ERROR
343 #define REUNIT_SVD_ABS_ERROR DEV_REUNIT_SVD_ABS_ERROR
344 #else // cpu
345 #define REUNIT_ALLOW_SVD HOST_REUNIT_ALLOW_SVD
346 #define REUNIT_SVD_REL_ERROR HOST_REUNIT_SVD_REL_ERROR
347 #define REUNIT_SVD_ABS_ERROR HOST_REUNIT_SVD_ABS_ERROR
348 #endif
349 
// NOTE(review): if REUNIT_SVD_ONLY is set while REUNIT_ALLOW_SVD is not,
// neither path assigns g[] before it is read below - confirm that this
// combination is excluded by the caller.
350  if(REUNIT_ALLOW_SVD){
351  bool perform_svd = true;
352  if(!REUNIT_SVD_ONLY){
// The SVD is skipped only when |det| is at least the absolute threshold and
// the analytic product g0*g1*g2 matches det within the relative tolerance.
353  const typename RealTypeId<Cmplx>::Type det = getDeterminant(q).x;
354  if( fabs(det) >= REUNIT_SVD_ABS_ERROR){
355  if( checkRelativeError(g[0]*g[1]*g[2],det,REUNIT_SVD_REL_ERROR) ) perform_svd = false;
356  }
357  }
358 
359  if(perform_svd){
// NOTE(review): tmp2 is used below but its declaration is not present in
// this text - confirm against the full source.
361  // compute the eigenvalues using the singular value decomposition
362  computeSVD<Cmplx>(q,tempq,tmp2,g);
363  // The array g contains the eigenvalues of the matrix q
364  // The determinant is the product of the eigenvalues, and I can use this
365  // to check the SVD
366  const typename RealTypeId<Cmplx>::Type determinant = getDeterminant(q).x;
367  const typename RealTypeId<Cmplx>::Type gprod = g[0]*g[1]*g[2];
368  // Check the svd result for errors
369 #ifdef __CUDA_ARCH__
370 #define MAX_DET_ERROR DEV_MAX_DET_ERROR
371 #else
372 #define MAX_DET_ERROR HOST_MAX_DET_ERROR
373 #endif
374  if(fabs(gprod - determinant) > MAX_DET_ERROR){
// The warning is compiled for the host, or for the device only when
// __COMPUTE_CAPABILITY__ >= 200.
375 #if (!defined(__CUDA_ARCH__) || (__COMPUTE_CAPABILITY__ >= 200))
376  printf("Warning: Error in determinant computed by SVD : %g > %g\n", fabs(gprod-determinant), MAX_DET_ERROR);
377  printLink(q);
378 #endif
379 
// Count the failure: atomically on the device, plain increment on the host.
380 #ifdef __CUDA_ARCH__
381  atomicAdd(unitarization_failed,1);
382 #else
383  (*unitarization_failed)++;
384 #endif
385  }
386  } // perform_svd?
387 
388  } // REUNIT_ALLOW_SVD?
389 
390 #ifdef __CUDA_ARCH__
391 #define HISQ_FORCE_FILTER DEV_HISQ_FORCE_FILTER
392 #else
393 #define HISQ_FORCE_FILTER HOST_HISQ_FORCE_FILTER
394 #endif
// Regularisation: if the smallest |g[i]| is below the filter, shift every
// g[i] and the real part of q's diagonal by the filter, then refresh q^2.
395  typename RealTypeId<Cmplx>::Type delta = getAbsMin(g,3);
396  if(delta < HISQ_FORCE_FILTER){
397  for(int i=0; i<3; ++i){
398  g[i] += HISQ_FORCE_FILTER;
399  q(i,i).x += HISQ_FORCE_FILTER;
400  }
401  qsq = q*q; // recalculate Q^2
402  }
403 
404 
405  // At this point we have finished with the c's
406  // use these to store sqrt(g)
407  for(int i=0; i<3; ++i) c[i] = sqrt(g[i]);
408 
409  // done with the g's, use these to store u, v, w
410  g[0] = c[0]+c[1]+c[2];
411  g[1] = c[0]*c[1] + c[0]*c[2] + c[1]*c[2];
412  g[2] = c[0]*c[1]*c[2];
413 
414  // set the derivative coefficients!
415  deriv_coeffs->set(g[0], g[1], g[2]);
416 
// With u=g[0], v=g[1], w=g[2]: the three scalars below share the
// denominator w*(u*v - w).
417  const typename RealTypeId<Cmplx>::Type & denominator = g[2]*(g[0]*g[1]-g[2]);
418  c[0] = (g[0]*g[1]*g[1] - g[2]*(g[0]*g[0]+g[1]))/denominator;
419  c[1] = (-g[0]*g[0]*g[0] - g[2] + 2.*g[0]*g[1])/denominator;
420  c[2] = g[0]/denominator;
421 
// NOTE(review): qsq is assigned only in the !REUNIT_SVD_ONLY branch or in
// the filter branch above; in SVD-only mode without filtering it appears to
// be read here unset - confirm.
422  tempq = c[1]*q + c[2]*qsq;
423  // Add a real scalar
424  tempq(0,0).x += c[0];
425  tempq(1,1).x += c[0];
426  tempq(2,2).x += c[0];
427 
428  f[0] = c[0];
429  f[1] = c[1];
430  f[2] = c[2];
431 
432  *res = tempq;
433  return;
434  }
435 
436 
437 
438  // "v" denotes a "fattened" link variable
// Computes the per-link force term into *result from the link v and the
// matrix outer_prod. It forms q = conj(v)*v, calls reciprocalRoot to obtain
// rsqrt_q, the scalars f[0..2] and the derivative coefficients, then sums
// the products below. unitarization_failed is forwarded to reciprocalRoot.
439  template<class Cmplx>
440  __device__ __host__
441  void getUnitarizeForceSite(const Matrix<Cmplx,3> & v, const Matrix<Cmplx,3> & outer_prod, Matrix<Cmplx,3>* result, int *unitarization_failed)
442  {
443  typename RealTypeId<Cmplx>::Type f[3];
444  typename RealTypeId<Cmplx>::Type b[6];
445 
446  Matrix<Cmplx,3> v_dagger = conj(v); // okay!
447  Matrix<Cmplx,3> q = v_dagger*v; // okay!
448 
449  Matrix<Cmplx,3> rsqrt_q;
450 
// NOTE(review): deriv_coeffs is used below but its declaration is not
// present in this text - confirm against the full source.
452 
// q is passed by reference and may have its diagonal shifted inside reciprocalRoot.
453  reciprocalRoot<Cmplx>(&rsqrt_q, &deriv_coeffs, f, q, unitarization_failed);
454 
455  // Pure hack here
// Copy the six coefficients into a local array in storage order
// (B00, B01, B02, B11, B12, B22).
456  b[0] = deriv_coeffs.getB00();
457  b[1] = deriv_coeffs.getB01();
458  b[2] = deriv_coeffs.getB02();
459  b[3] = deriv_coeffs.getB11();
460  b[4] = deriv_coeffs.getB12();
461  b[5] = deriv_coeffs.getB22();
462 
463 
464  Matrix<Cmplx,3> & local_result = *result;
465 
466  local_result = rsqrt_q*outer_prod;
467 
468  // We are now finished with rsqrt_q
469  Matrix<Cmplx,3> qv_dagger = q*v_dagger;
470  Matrix<Cmplx,3> vv_dagger = v*v_dagger;
471  Matrix<Cmplx,3> vqv_dagger = v*qv_dagger;
// NOTE(review): the next two values assigned to temp are overwritten before
// being read, and vqv_dagger is used only in the first of them - these look
// like dead stores; confirm before removing.
472  Matrix<Cmplx,3> temp = f[1]*vv_dagger + f[2]*vqv_dagger;
473 
474 
475  temp = f[1]*v_dagger + f[2]*qv_dagger;
476  Matrix<Cmplx,3> conj_outer_prod = conj(outer_prod);
477 
478 
479  temp = f[1]*v + f[2]*v*q;
480  local_result = local_result + outer_prod*temp*v_dagger + f[2]*q*outer_prod*vv_dagger;
481 
482  local_result = local_result + v_dagger*conj_outer_prod*conj(temp) + f[2]*qv_dagger*conj_outer_prod*v_dagger;
483 
484 
485  // now done with vv_dagger, I think
// Three trace-weighted terms, one per power of q (v, v*q, v*q^2 on the
// left), each paired with a b-weighted combination of v_dagger, q*v_dagger
// and q^2*v_dagger.
486  Matrix<Cmplx,3> qsqv_dagger = q*qv_dagger;
487  Matrix<Cmplx,3> pv_dagger = b[0]*v_dagger + b[1]*qv_dagger + b[2]*qsqv_dagger;
488  accumBothDerivatives(&local_result, v, pv_dagger, outer_prod);
489 
490  Matrix<Cmplx,3> rv_dagger = b[1]*v_dagger + b[3]*qv_dagger + b[4]*qsqv_dagger;
491  Matrix<Cmplx,3> vq = v*q;
492  accumBothDerivatives(&local_result, vq, rv_dagger, outer_prod);
493 
494  Matrix<Cmplx,3> sv_dagger = b[2]*v_dagger + b[4]*qv_dagger + b[5]*qsqv_dagger;
495  Matrix<Cmplx,3> vqsq = vq*q;
496  accumBothDerivatives(&local_result, vqsq, sv_dagger, outer_prod);
497  return;
498  } // get unit force term
499 
500 
501 
// Kernel: one thread per lattice site, four directions per thread.
//
// threads is the total number of sites handled by the launch. Flat thread
// indices below threads/2 use the *_even arrays; the remaining indices use
// the *_odd arrays with the index reduced by threads/2. Threads whose flat
// index is >= threads exit immediately, so the grid may overshoot.
//
// For each direction the old force and the link are loaded, the site force
// is computed in double precision regardless of Cmplx, and the result is
// written to the output force array. unitarization_failed is passed through
// to getUnitarizeForceSite.
template<class Cmplx>
__global__ void getUnitarizeForceField(const int threads, const Cmplx* link_even, const Cmplx* link_odd,
                                       const Cmplx* old_force_even, const Cmplx* old_force_odd,
                                       Cmplx* force_even, Cmplx* force_odd,
                                       int* unitarization_failed)
{
  const int global_idx = blockIdx.x*blockDim.x + threadIdx.x;
  if(global_idx >= threads) return;

  // The number of GPU threads is equal to the local volume
  const int HALF_VOLUME = threads/2;

  // Pick the parity-specific arrays and the index within that parity.
  const bool odd_parity = (global_idx >= HALF_VOLUME);
  int site = odd_parity ? (global_idx - HALF_VOLUME) : global_idx;
  const Cmplx* link = odd_parity ? link_odd : link_even;
  const Cmplx* old_force = odd_parity ? old_force_odd : old_force_even;
  Cmplx* force = odd_parity ? force_odd : force_even;

  // This part of the calculation is always done in double precision
  Matrix<double2,3> v, result, oprod;

  for(int dir=0; dir<4; ++dir){
    loadLinkVariableFromArray(old_force, dir, site, HALF_VOLUME, &oprod);
    loadLinkVariableFromArray(link, dir, site, HALF_VOLUME, &v);

    getUnitarizeForceSite<double2>(v, oprod, &result, unitarization_failed);

    writeLinkVariableToArray(result, dir, site, HALF_VOLUME, force);
  }
} // getUnitarizeForceField
542 
543 
545  {
// Host reference path: for every site i and direction dir, load the old
// force and the link into double2 matrices, run getUnitarizeForceSite, and
// store the result in cpuNewForce.
// NOTE(review): the function header is not present in this text; param,
// cpuOldForce, cpuGauge and cpuNewForce are presumably its parameters - confirm.
546 
// NOTE(review): num_failures is incremented through getUnitarizeForceSite
// but is not read or returned anywhere in the visible body.
547  int num_failures = 0;
548  Matrix<double2,3> old_force, new_force, v;
549 
550  // I can change this code to make it much more compact
551 
552  const QudaGaugeFieldOrder order = cpuGauge.Order();
553 
// MILC order: Gauge_p() is treated as one flat array, with the link for
// (site i, direction dir) starting at element (i*4 + dir)*18.
// Sites whose cpu_prec is neither single nor double are skipped silently.
554  if(order == QUDA_MILC_GAUGE_ORDER){
555  for(int i=0; i<cpuGauge.Volume(); ++i){
556  for(int dir=0; dir<4; ++dir){
557  if(param.cpu_prec == QUDA_SINGLE_PRECISION){
558  copyArrayToLink(&old_force, ((float*)(cpuOldForce.Gauge_p()) + (i*4 + dir)*18));
559  copyArrayToLink(&v, ((float*)(cpuGauge.Gauge_p()) + (i*4 + dir)*18));
560  getUnitarizeForceSite<double2>(v, old_force, &new_force, &num_failures);
561  copyLinkToArray(((float*)(cpuNewForce->Gauge_p()) + (i*4 + dir)*18), new_force);
562  }else if(param.cpu_prec == QUDA_DOUBLE_PRECISION){
563  copyArrayToLink(&old_force, ((double*)(cpuOldForce.Gauge_p()) + (i*4 + dir)*18));
564  copyArrayToLink(&v, ((double*)(cpuGauge.Gauge_p()) + (i*4 + dir)*18));
565  getUnitarizeForceSite<double2>(v, old_force, &new_force, &num_failures);
566  copyLinkToArray(((double*)(cpuNewForce->Gauge_p()) + (i*4 + dir)*18), new_force);
567  } // precision?
568  } // dir
569  } // i
// QDP order: Gauge_p() is treated as an array of per-direction pointers,
// with the link for site i starting at element i*18 of array [dir].
570  }else if(order == QUDA_QDP_GAUGE_ORDER){
571  for(int dir=0; dir<4; ++dir){
572  for(int i=0; i<cpuGauge.Volume(); ++i){
573  if(param.cpu_prec == QUDA_SINGLE_PRECISION){
574  copyArrayToLink(&old_force, ((float**)(cpuOldForce.Gauge_p()))[dir] + i*18);
575  copyArrayToLink(&v, ((float**)(cpuGauge.Gauge_p()))[dir] + i*18);
576  getUnitarizeForceSite<double2>(v, old_force, &new_force, &num_failures);
577  copyLinkToArray(((float**)(cpuNewForce->Gauge_p()))[dir] + i*18, new_force);
578  }else if(param.cpu_prec == QUDA_DOUBLE_PRECISION){
579  copyArrayToLink(&old_force, ((double**)(cpuOldForce.Gauge_p()))[dir] + i*18);
580  copyArrayToLink(&v, ((double**)(cpuGauge.Gauge_p()))[dir] + i*18);
581  getUnitarizeForceSite<double2>(v, old_force, &new_force, &num_failures);
582  copyLinkToArray(((double**)(cpuNewForce->Gauge_p()))[dir] + i*18, new_force);
583  }
584  }
585  }
586  }else{
// Any other field order is rejected.
587  errorQuda("Only MILC and QDP gauge orders supported\n");
588  }
589  return;
590  } // unitarize_force_cpu
591 
// Tunable wrapper that launches the getUnitarizeForceField kernel.
//
// The block size is stepped in units of the warp size; the grid size is
// always derived from the block size so that grid*block covers
// gauge.Volume() threads. The grid dimension itself is not tuned, and no
// shared memory is requested.
//
// apply() launches on the stream it is given, with the shared-memory size
// recorded in the tuning parameters, so callers that pass a non-default
// stream get the ordering they asked for.
class UnitarizeForceCuda : public Tunable {
  private:
    const cudaGaugeField &oldForce;  // input force field
    const cudaGaugeField &gauge;     // link field; also supplies volume, precision and stride
    cudaGaugeField &newForce;        // output force field
    int *fails;                      // failure counter passed to the kernel; cleared with cudaMemset in postTune()

    int sharedBytesPerThread() const { return 0; }
    int sharedBytesPerBlock(const TuneParam &) const { return 0; }

    // don't tune the grid dimension
    bool advanceGridDim(TuneParam &param) const { return false; }

    // generalize Tunable::advanceBlockDim() to also set gridDim, with extra checking to ensure that gridDim isn't too large for the device
    // Returns true while a larger block size is available; on wrap-around
    // the block size is reset to the smallest legal value and false is returned.
    bool advanceBlockDim(TuneParam &param) const
    {
      const unsigned int max_threads = deviceProp.maxThreadsDim[0];
      const unsigned int max_blocks = deviceProp.maxGridSize[0];
      const unsigned int max_shared = 16384; // FIXME: use deviceProp.sharedMemPerBlock;
      const int step = deviceProp.warpSize;
      const int threads = gauge.Volume();
      bool ret;
      param.block.x += step;
      if (param.block.x > max_threads || sharedBytesPerThread()*param.block.x > max_shared) {
        param.block = dim3((threads+max_blocks-1)/max_blocks, 1, 1); // ensure the blockDim is large enough, given the limit on gridDim
        param.block.x = ((param.block.x+step-1) / step) * step; // round up to the nearest "step"
        if (param.block.x > max_threads) errorQuda("Local lattice volume is too large for device");
        ret = false;
      } else {
        ret = true;
      }
      param.grid = dim3((threads+param.block.x-1)/param.block.x, 1, 1);
      return ret;
    }

  public:
    UnitarizeForceCuda(const cudaGaugeField& oldForce, const cudaGaugeField& gauge,
                       cudaGaugeField& newForce, int* fails) :
      oldForce(oldForce), gauge(gauge), newForce(newForce), fails(fails) { ; }
    virtual ~UnitarizeForceCuda() { ; }

    // Launch the kernel with the tuned configuration on the given stream.
    // Only single and double precision are dispatched; any other precision
    // launches nothing.
    void apply(const cudaStream_t &stream) {
      TuneParam tp = tuneLaunch(*this, dslashTuning, verbosity);

      if(gauge.Precision() == QUDA_SINGLE_PRECISION){
        getUnitarizeForceField<<<tp.grid,tp.block,tp.shared_bytes,stream>>>(gauge.Volume(), (const float2*)gauge.Even_p(), (const float2*)gauge.Odd_p(),
            (const float2*)oldForce.Even_p(), (const float2*)oldForce.Odd_p(),
            (float2*)newForce.Even_p(), (float2*)newForce.Odd_p(),
            fails);
      }else if(gauge.Precision() == QUDA_DOUBLE_PRECISION){
        getUnitarizeForceField<<<tp.grid,tp.block,tp.shared_bytes,stream>>>(gauge.Volume(), (const double2*)gauge.Even_p(), (const double2*)gauge.Odd_p(),
            (const double2*)oldForce.Even_p(), (const double2*)oldForce.Odd_p(),
            (double2*)newForce.Even_p(), (double2*)newForce.Odd_p(),
            fails);
      }
    }

    void preTune() { ; }
    void postTune() { cudaMemset(fails, 0, sizeof(int)); } // reset fails counter

    // Initial launch geometry: the smallest warp-multiple block size for
    // which the grid stays within the device limit.
    virtual void initTuneParam(TuneParam &param) const
    {
      const unsigned int max_threads = deviceProp.maxThreadsDim[0];
      const unsigned int max_blocks = deviceProp.maxGridSize[0];
      const int threads = gauge.Volume();
      const int step = deviceProp.warpSize;
      param.block = dim3((threads+max_blocks-1)/max_blocks, 1, 1); // ensure the blockDim is large enough, given the limit on gridDim
      param.block.x = ((param.block.x+step-1) / step) * step; // round up to the nearest "step"
      if (param.block.x > max_threads) errorQuda("Local lattice volume is too large for device");
      param.grid = dim3((threads+param.block.x-1)/param.block.x, 1, 1);
      param.shared_bytes = sharedBytesPerThread()*param.block.x > sharedBytesPerBlock(param) ?
        sharedBytesPerThread()*param.block.x : sharedBytesPerBlock(param);
    }

    // The default (untuned) parameters are the same as the initial tuning parameters.
    void defaultTuneParam(TuneParam &param) const {
      initTuneParam(param);
    }

    long long flops() const { return 0; } // FIXME: add flops counter

    // Key built from the lattice dimensions, this type's name, and the
    // volume, precision and stride of the gauge field.
    TuneKey tuneKey() const {
      std::stringstream vol, aux;
      vol << gauge.X()[0] << "x";
      vol << gauge.X()[1] << "x";
      vol << gauge.X()[2] << "x";
      vol << gauge.X()[3] << "x";
      aux << "threads=" << gauge.Volume() << ",prec=" << gauge.Precision();
      aux << "stride=" << gauge.Stride();
      return TuneKey(vol.str(), typeid(*this).name(), aux.str());
    }
}; // UnitarizeForceCuda
684 
686  cudaGaugeField &cudaGauge, cudaGaugeField *cudaNewForce, int* unitarization_failed) {
// Device entry point: wrap the fields in a UnitarizeForceCuda, launch it on
// stream 0, then check for CUDA errors.
// NOTE(review): the first line of the signature (function name and the
// cudaOldForce parameter) is not present in this text - confirm against the
// full source.
687  UnitarizeForceCuda unitarizeForce(cudaOldForce, cudaGauge, *cudaNewForce, unitarization_failed);
688  unitarizeForce.apply(0);
689  checkCudaError();
690  }
691 
692 
693  } // namespace fermion_force
694 } // namespace quda
695 
696