4 #include <gauge_field.h>
6 #include <quda_matrix.h>
7 #include <gauge_field_order.h>
8 #include <instantiate.h>
9 #include <color_spinor.h>
13 namespace { // anonymous
// Numerical constants for the closed-form eigenvalue computation in
// reciprocalRoot(). HISQ_UNITARIZE_PI23 is 2*pi/3.
// The composite expansion is fully parenthesized so it behaves as a single
// operand wherever it is substituted: without the outer parentheses,
// `x / HISQ_UNITARIZE_PI23` would expand to `x / 3.14159... * 2.0 / 3.0`.
#define HISQ_UNITARIZE_PI 3.14159265358979323846
#define HISQ_UNITARIZE_PI23 (HISQ_UNITARIZE_PI * 2.0 / 3.0)
// File-scope unitarization parameters shared by the device and host paths.
// They are assigned by setUnitarizeForceConstants() and passed into
// UnitarizeForceArg by the UnitarizeForce constructor and by the CPU
// unitarizeForceCPU(GaugeField&, ...) overload.
// NOTE(review): svd_only is passed alongside these values, but its
// declaration is not visible in this excerpt.
20 static double unitarize_eps;
21 static double force_filter;
22 static double max_det_error;
23 static bool allow_svd;
25 static double svd_rel_error;
26 static double svd_abs_error;
28 namespace fermion_force {
// Argument bundle for the unitarized-link force kernel and its host counterpart.
//   Float_, nColor_ : storage precision and number of colors of the fields
//   recon_          : reconstruction type, exposed as `recon`
//   order           : gauge field ordering (native by default; the CPU entry
//                     point instantiates MILC and QDP orderings)
// Both accessor types F and G are built with QUDA_RECONSTRUCT_NO, i.e. full
// 2*nColor*nColor reals per link, independent of recon_.
// NOTE(review): F and G are currently the same type. The accessor members,
// the fails pointer, the threads member and several initializers are in lines
// not visible in this excerpt.
30 template <typename Float_, int nColor_, QudaReconstructType recon_, QudaGaugeFieldOrder order = QUDA_NATIVE_GAUGE_ORDER>
31 struct UnitarizeForceArg {
33 static constexpr int nColor = nColor_;
34 static constexpr QudaReconstructType recon = recon_;
35 // use long form here to allow specification of order
36 typedef typename gauge_mapper<Float,QUDA_RECONSTRUCT_NO,2*nColor*nColor,QUDA_STAGGERED_PHASE_NO,gauge::default_huge_alloc, QUDA_GHOST_EXCHANGE_INVALID,false,order>::type F;
37 typedef typename gauge_mapper<Float,QUDA_RECONSTRUCT_NO,2*nColor*nColor,QUDA_STAGGERED_PHASE_NO,gauge::default_huge_alloc, QUDA_GHOST_EXCHANGE_INVALID,false,order>::type G;
// Per-instance copies of the unitarization parameters (read by reciprocalRoot
// through arg.unitarize_eps, arg.force_filter, arg.max_det_error,
// arg.svd_rel_error and arg.svd_abs_error).
43 const double unitarize_eps;
44 const double force_filter;
45 const double max_det_error;
48 const double svd_rel_error;
49 const double svd_abs_error;
// force is written by the kernels; force_old and u are read. The
// corresponding member initializers are not visible, so the name-for-name
// mapping is presumed. fails counts failed unitarizations. The remaining
// arguments are the unitarization parameters; allow_svd and svd_only are
// taken as int here although the file-scope statics are bool.
// threads is set to u.VolumeCB().
51 UnitarizeForceArg(GaugeField &force, const GaugeField &force_old, const GaugeField &u, int *fails,
52 double unitarize_eps, double force_filter, double max_det_error, int allow_svd,
53 int svd_only, double svd_rel_error, double svd_abs_error) :
58 unitarize_eps(unitarize_eps),
59 force_filter(force_filter),
60 max_det_error(max_det_error),
63 svd_rel_error(svd_rel_error),
64 svd_abs_error(svd_abs_error),
65 threads(u.VolumeCB()) { }
// Store the unitarization parameters in the file-scope statics. Those statics
// are later used to construct UnitarizeForceArg.
//   unitarize_eps_ : threshold compared against |s| in reciprocalRoot
//   force_filter_  : if the smallest |eigenvalue| of q is below this, it is
//                    added to every eigenvalue and to the diagonal of q
//   max_det_error_ : tolerance on |prod(eigenvalues) - det(q)| after the SVD;
//                    exceeding it prints a warning and bumps the fail counter
//   allow_svd_     : presumably enables the SVD fallback in reciprocalRoot
//   svd_only_      : NOTE(review): its assignment is not visible in this excerpt
//   svd_rel_error_ : relative tolerance between prod(eigenvalues) and det(q)
//                    under which the SVD is skipped
//   svd_abs_error_ : if |det(q)| is below this, the relative test is skipped
//                    and the SVD is performed
68 void setUnitarizeForceConstants(double unitarize_eps_, double force_filter_,
69 double max_det_error_, bool allow_svd_, bool svd_only_,
70 double svd_rel_error_, double svd_abs_error_)
72 unitarize_eps = unitarize_eps_;
73 force_filter = force_filter_;
74 max_det_error = max_det_error_;
75 allow_svd = allow_svd_;
77 svd_rel_error = svd_rel_error_;
78 svd_abs_error = svd_abs_error_;
// Coefficients B_ij, stored as b[0..5] = B00, B01, B02, B11, B12, B22.
// getUnitarizeForceSite reads them through the getters below.
// reciprocalRoot calls set(u, v, w) with u, v, w the elementary symmetric
// polynomials of the square roots of q's eigenvalues:
// u = sum, v = sum of pairwise products, w = product.
// NOTE(review): the declaration of the storage array b, any access
// specifiers and the closing brace are in lines not visible in this excerpt.
81 template <class Real> class DerivativeCoefficients {
// Numerator polynomials C_ij(u, v, w) of the coefficients.
83 constexpr Real computeC00(const Real &u, const Real &v, const Real &w)
85 return -pow(w,3) * pow(u,6) + 3*v*pow(w,3)*pow(u,4) + 3*pow(v,4)*w*pow(u,4)
86 - pow(v,6)*pow(u,3) - 4*pow(w,4)*pow(u,3) - 12*pow(v,3)*pow(w,2)*pow(u,3)
87 + 16*pow(v,2)*pow(w,3)*pow(u,2) + 3*pow(v,5)*w*pow(u,2) - 8*v*pow(w,4)*u
88 - 3*pow(v,4)*pow(w,2)*u + pow(w,5) + pow(v,3)*pow(w,3);
91 constexpr Real computeC01(const Real & u, const Real & v, const Real & w)
93 return -pow(w,2)*pow(u,7) - pow(v,2)*w*pow(u,6) + pow(v,4)*pow(u,5) + 6*v*pow(w,2)*pow(u,5)
94 - 5*pow(w,3)*pow(u,4) - pow(v,3)*w*pow(u,4)- 2*pow(v,5)*pow(u,3) - 6*pow(v,2)*pow(w,2)*pow(u,3)
95 + 10*v*pow(w,3)*pow(u,2) + 6*pow(v,4)*w*pow(u,2) - 3*pow(w,4)*u - 6*pow(v,3)*pow(w,2)*u + 2*pow(v,2)*pow(w,3);
98 constexpr Real computeC02(const Real & u, const Real & v, const Real & w)
100 return pow(w,2)*pow(u,5) + pow(v,2)*w*pow(u,4)- pow(v,4)*pow(u,3)- 4*v*pow(w,2)*pow(u,3)
101 + 4*pow(w,3)*pow(u,2) + 3*pow(v,3)*w*pow(u,2) - 3*pow(v,2)*pow(w,2)*u + v*pow(w,3);
104 constexpr Real computeC11(const Real & u, const Real & v, const Real & w)
106 return -w*pow(u,8) - pow(v,2)*pow(u,7) + 7*v*w*pow(u,6) + 4*pow(v,3)*pow(u,5)
107 - 5*pow(w,2)*pow(u,5) - 16*pow(v,2)*w*pow(u,4) - 4*pow(v,4)*pow(u,3) + 16*v*pow(w,2)*pow(u,3)
108 - 3*pow(w,3)*pow(u,2) + 12*pow(v,3)*w*pow(u,2) - 12*pow(v,2)*pow(w,2)*u + 3*v*pow(w,3);
111 constexpr Real computeC12(const Real &u, const Real &v, const Real &w)
113 return w*pow(u,6) + pow(v,2)*pow(u,5) - 5*v*w*pow(u,4) - 2*pow(v,3)*pow(u,3)
114 + 4*pow(w,2)*pow(u,3) + 6*pow(v,2)*w*pow(u,2) - 6*v*pow(w,2)*u + pow(w,3);
117 constexpr Real computeC22(const Real &u, const Real &v, const Real &w)
119 return -w*pow(u,4) - pow(v,2)*pow(u,3) + 3*v*w*pow(u,2) - 3*pow(w,2)*u;
// b[k] = C_k(u, v, w) / (2 * (w*(u*v - w))^3). The common divisor vanishes
// when w == 0 or u*v == w, which makes every coefficient non-finite.
123 constexpr void set(const Real &u, const Real &v, const Real &w)
125 const Real denominator = 1.0 / (2.0*pow(w*(u*v-w),3));
126 b[0] = computeC00(u,v,w) * denominator;
127 b[1] = computeC01(u,v,w) * denominator;
128 b[2] = computeC02(u,v,w) * denominator;
129 b[3] = computeC11(u,v,w) * denominator;
130 b[4] = computeC12(u,v,w) * denominator;
131 b[5] = computeC22(u,v,w) * denominator;
// Read-only accessors for the coefficients computed by set().
134 constexpr Real getB00() const { return b[0]; }
135 constexpr Real getB01() const { return b[1]; }
136 constexpr Real getB02() const { return b[2]; }
137 constexpr Real getB11() const { return b[3]; }
138 constexpr Real getB12() const { return b[4]; }
139 constexpr Real getB22() const { return b[5]; }
// Adds 2 * Re(Tr(left * outer_prod)) times `right` to `result`, entry by
// entry. The loop bounds are hard-coded to a 3x3 block.
142 template <typename mat>
143 __device__ __host__ void accumBothDerivatives(mat &result, const mat &left, const mat &right, const mat &outer_prod)
145 auto temp = (2.0*getTrace(left*outer_prod)).real();
146 for (int k=0; k<3; ++k) {
147 for (int l=0; l<3; ++l) {
148 result(k,l) += temp*right(k,l);
// Sets result = Tr(left * outer_prod) * right, entry by entry over a 3x3 block.
// NOTE(review): despite its name this overwrites `result` (=) rather than
// accumulating (+=). It also uses the full complex trace, whereas
// accumBothDerivatives takes twice the real part. It is not called in the
// visible code, and its template header for `mat` is in a line not shown here.
154 __device__ __host__ void accumDerivatives(mat &result, const mat &left, const mat &right, const mat &outer_prod)
156 auto temp = getTrace(left*outer_prod);
157 for(int k=0; k<3; ++k){
158 for(int l=0; l<3; ++l){
159 result(k,l) = temp*right(k,l);
// Smallest absolute value among array[0] .. array[size-1].
// array[0] is read unconditionally, so size must be at least 1.
// NOTE(review): the return statement is in lines not visible in this excerpt.
164 template<class T> constexpr T getAbsMin(const T* const array, int size)
166 T min = fabs(array[0]);
167 for (int i=1; i<size; ++i) {
168 T abs_val = fabs(array[i]);
169 if ((abs_val) < min){ min = abs_val; }
// Absolute-tolerance test: true when a and b differ by strictly less than
// epsilon. A NaN difference compares false.
template <class Real> constexpr bool checkAbsoluteError(Real a, Real b, Real epsilon)
{
  const auto deviation = fabs(a - b);
  return deviation < epsilon;
}
// Relative-tolerance test: true when |a - b| < epsilon * |b|, i.e. when a
// matches the reference value b to relative precision epsilon.
// The test is written multiplicatively instead of as |(a - b) / b| < epsilon.
// This means no division is performed. With b == 0 the old form produced an
// inf/NaN intermediate whose comparison is only well defined under strict
// IEEE semantics (finite-math/fast-math host builds may assume it never
// happens). Division is also slower on the GPU. For b == 0 this returns
// false, as the IEEE comparison against inf/NaN did before.
template<class Real> constexpr bool checkRelativeError(Real a, Real b, Real epsilon) { return fabs(a - b) < epsilon * fabs(b); }
177 // Compute the reciprocal square root of the matrix q
178 // Also modify q if the eigenvalues are dangerously small.
//
// Outline of the visible steps:
//  1. Form the invariants c[0] = Tr(q), c[1] = Tr(qsq)/2 and c[2] = Tr(tempq)/3.
//     qsq and tempq are assigned in lines not visible here; presumably
//     qsq = q*q and tempq = q*qsq (confirm).
//  2. Compute the three eigenvalues g[0..2] of q in closed form, using the
//     trigonometric solution of the characteristic cubic.
//  3. If |det(q)| < arg.svd_abs_error, or prod(g) fails the relative check
//     against det(q), recompute g with computeSVD. This is presumably guarded
//     by arg.allow_svd, per the REUNIT_ALLOW_SVD marker.
//  4. If the smallest |g[i]| is below arg.force_filter, shift every g[i] and
//     the diagonal of q by arg.force_filter. Note that q is modified in place.
//  5. With u, v, w the elementary symmetric polynomials of sqrt(g[i]), fill
//     *deriv_coeffs and build tempq = c[0]*I + c[1]*q + c[2]*q^2.
// NOTE(review): the transfer of tempq into *res and the filling of f[] happen
// in lines not visible here; getUnitarizeForceSite consumes both. This
// function is also reached from the host path (unitarizeForceCPU), so confirm
// that atomicAdd below resolves on the host.
179 template<class Float, typename Arg>
180 __device__ __host__ void reciprocalRoot(Matrix<complex<Float>,3>* res, DerivativeCoefficients<Float>* deriv_coeffs,
181 Float f[3], Matrix<complex<Float>,3> & q, Arg &arg)
183 Matrix<complex<Float>,3> qsq, tempq;
// Invariants of q used by the characteristic polynomial.
192 c[0] = getTrace(q).x;
193 c[1] = getTrace(qsq).x/2.0;
194 c[2] = getTrace(tempq).x/3.0;
// Start every eigenvalue at the mean c[0]/3, then form the shifted-cubic
// quantities s and r.
196 g[0] = g[1] = g[2] = c[0]/3.;
198 s = c[1]/3. - c[0]*c[0]/18;
199 r = c[2]/2. - (c[0]/3.)*(c[1] - c[0]*c[0]/9.);
// cosTheta = r / s^(3/2). NOTE(review): this is evaluated before the
// |s| < unitarize_eps test, so it is non-finite when s == 0 (and NaN if
// rounding makes s < 0). The body of the small-|s| branch is not visible.
201 Float cosTheta = r*rsqrt(s*s*s);
202 if (fabs(s) < arg.unitarize_eps) {
// Rounding can push |cosTheta| above 1, where acos is undefined. In that case
// pick theta = 0 if r > 0, otherwise pi/3.
206 if(fabs(cosTheta)>1.0){ r>0 ? theta=0.0 : theta=HISQ_UNITARIZE_PI/3.0; }
207 else{ theta = acos(cosTheta)/3.0; }
// g[i] = c[0]/3 + s*cos(theta + (i-1)*2*pi/3).
// NOTE(review): the standard trigonometric root uses 2*sqrt(s) as amplitude;
// presumably s is rescaled in lines not visible here (confirm).
210 for (int i=0; i<3; ++i) {
211 g[i] += s*cos(theta + (i-1)*HISQ_UNITARIZE_PI23);
214 } // !REUNIT_SVD_ONLY?
217 // Compare the product of the eigenvalues computed thus far to the
218 // absolute value of the determinant.
219 // If the determinant is very small or the relative error is greater than some predefined value
220 // then recompute the eigenvalues using a singular-value decomposition.
221 // Note that this particular calculation contains multiple branches,
222 // so it doesn't appear to be particularly well-suited to the GPU
223 // programming model. However, the analytic calculation of the
224 // unitarization is extremely fast, and if the SVD routine is not called
225 // too often, we expect pretty good performance.
// perform_svd stays true unless |det| >= svd_abs_error AND the product of the
// analytic eigenvalues agrees with det to within svd_rel_error.
229 bool perform_svd = true;
231 const Float det = getDeterminant(q).x;
232 if( fabs(det) >= arg.svd_abs_error) {
233 if( checkRelativeError(g[0]*g[1]*g[2],det,arg.svd_rel_error) ) perform_svd = false;
238 Matrix<complex<Float>,3> tmp2;
239 // compute the eigenvalues using the singular value decomposition
240 computeSVD<Float>(q,tempq,tmp2,g);
241 // The array g contains the eigenvalues of the matrix q
242 // The determinant is the product of the eigenvalues, and I can use this
244 const Float determinant = getDeterminant(q).x;
245 const Float gprod = g[0]*g[1]*g[2];
246 // Check the svd result for errors
// On a mismatch larger than arg.max_det_error, print a warning and increment
// the shared failure counter. Some lines in between are not visible here.
247 if (fabs(gprod - determinant) > arg.max_det_error) {
248 printf("Warning: Error in determinant computed by SVD : %g > %g\n", fabs(gprod-determinant), arg.max_det_error);
252 atomicAdd(arg.fails, 1);
259 } // REUNIT_ALLOW_SVD?
// Force filter: if the smallest |eigenvalue| is below arg.force_filter, shift
// all eigenvalues and diag(q) by force_filter, then recompute qsq from the
// modified q.
261 Float delta = getAbsMin(g,3);
262 if (delta < arg.force_filter) {
263 for (int i=0; i<3; ++i) {
264 g[i] += arg.force_filter;
265 q(i,i).x += arg.force_filter;
267 qsq = q*q; // recalculate Q^2
271 // At this point we have finished with the c's
272 // use these to store sqrt(g)
273 for (int i=0; i<3; ++i) c[i] = sqrt(g[i]);
275 // done with the g's, use these to store u, v, w
276 g[0] = c[0]+c[1]+c[2];
277 g[1] = c[0]*c[1] + c[0]*c[2] + c[1]*c[2];
278 g[2] = c[0]*c[1]*c[2];
280 // set the derivative coefficients!
281 deriv_coeffs->set(g[0], g[1], g[2]);
// Const reference bound to a temporary; lifetime extension keeps it valid.
// The value w*(u*v - w) is zero when w == 0 or u*v == w, which would make the
// divisions below non-finite.
283 const Float& denominator = g[2]*(g[0]*g[1]-g[2]);
// Coefficients of the polynomial form c[0]*I + c[1]*q + c[2]*q^2 of the
// reciprocal square root named in the header comment.
284 c[0] = (g[0]*g[1]*g[1] - g[2]*(g[0]*g[0]+g[1]))/denominator;
285 c[1] = (-g[0]*g[0]*g[0] - g[2] + 2.*g[0]*g[1])/denominator;
286 c[2] = g[0]/denominator;
288 tempq = c[1]*q + c[2]*qsq;
290 tempq(0,0).x += c[0];
291 tempq(1,1).x += c[0];
292 tempq(2,2).x += c[0];
301 // "v" denotes a "fattened" link variable
// Computes one link's force contribution through the reciprocal square root
// of q = v_dagger * v. Visible steps:
//   - reciprocalRoot fills rsqrt_q, f[] and the derivative coefficients,
//     copied to b[0..5] = B00, B01, B02, B11, B12, B22;
//   - result starts as rsqrt_q * outer_prod and accumulates terms built from
//     f[1], f[2], the b coefficients, and products of v, v_dagger and q.
// NOTE(review): the declarations of rsqrt_q, f, b, vq and vqsq, and the use of
// the first `temp` value, are in lines not visible in this excerpt.
// conj() is applied to matrices where a dagger is meant (v_dagger = conj(v));
// presumably Matrix's conj is the conjugate transpose (confirm in
// quda_matrix.h).
302 template <class Float, typename Arg>
303 __device__ __host__ void getUnitarizeForceSite(Matrix<complex<Float>,3>& result, const Matrix<complex<Float>,3> & v,
304 const Matrix<complex<Float>,3> & outer_prod, Arg &arg)
306 typedef Matrix<complex<Float>,3> Link;
310 Link v_dagger = conj(v); // okay!
311 Link q = v_dagger*v; // okay!
314 DerivativeCoefficients<Float> deriv_coeffs;
316 reciprocalRoot<Float>(&rsqrt_q, &deriv_coeffs, f, q, arg); // approx 529 flops (assumes no SVD)
319 b[0] = deriv_coeffs.getB00();
320 b[1] = deriv_coeffs.getB01();
321 b[2] = deriv_coeffs.getB02();
322 b[3] = deriv_coeffs.getB11();
323 b[4] = deriv_coeffs.getB12();
324 b[5] = deriv_coeffs.getB22();
326 result = rsqrt_q*outer_prod;
328 // We are now finished with rsqrt_q
// Reused products of v, v_dagger and q.
329 Link qv_dagger = q*v_dagger;
330 Link vv_dagger = v*v_dagger;
331 Link vqv_dagger = v*qv_dagger;
332 Link temp = f[1]*vv_dagger + f[2]*vqv_dagger;
334 temp = f[1]*v_dagger + f[2]*qv_dagger;
335 Link conj_outer_prod = conj(outer_prod);
337 temp = f[1]*v + f[2]*v*q;
338 result = result + outer_prod*temp*v_dagger + f[2]*q*outer_prod*vv_dagger;
339 result = result + v_dagger*conj_outer_prod*conj(temp) + f[2]*qv_dagger*conj_outer_prod*v_dagger;
// Derivative-coefficient terms: each accumBothDerivatives call adds
// 2*Re(Tr(left*outer_prod)) * right to result.
341 Link qsqv_dagger = q*qv_dagger;
342 Link pv_dagger = b[0]*v_dagger + b[1]*qv_dagger + b[2]*qsqv_dagger;
343 accumBothDerivatives(result, v, pv_dagger, outer_prod); // 41 flops
345 Link rv_dagger = b[1]*v_dagger + b[3]*qv_dagger + b[4]*qsqv_dagger;
347 accumBothDerivatives(result, vq, rv_dagger, outer_prod); // 41 flops
349 Link sv_dagger = b[2]*v_dagger + b[4]*qv_dagger + b[5]*qsqv_dagger;
351 accumBothDerivatives(result, vqsq, sv_dagger, outer_prod); // 41 flops
353 // 4528 flops - 17 matrix multiplies (198 flops each) + reciprocal root (approx 529 flops) + accumBothDerivatives (41 each) + miscellaneous
354 } // get unit force term
// Kernel: one thread per (checkerboard site x_cb, parity).
// x_cb comes from the x dimension and is bounds-checked against arg.threads;
// parity comes from the y dimension. For each of the 4 directions, the link u
// and the old force are read in Arg::Float precision, processed in double
// precision by getUnitarizeForceSite, and the result is stored to arg.force.
// NOTE(review): parity is not bounds-checked in the visible lines; presumably
// the launch uses exactly 2 threads in y (confirm). The conversions between
// the *_tmp matrices and the double-precision matrices are in lines not shown
// here.
356 template <typename Arg> __global__ void getUnitarizeForceField(Arg arg)
358 using real = typename Arg::Float;
359 int x_cb = blockIdx.x*blockDim.x + threadIdx.x;
360 if (x_cb >= arg.threads) return;
361 int parity = blockIdx.y*blockDim.y + threadIdx.y;
363 // This part of the calculation is always done in double precision
364 Matrix<complex<double>,3> v, result, oprod;
365 Matrix<complex<real>,3> v_tmp, result_tmp, oprod_tmp;
367 for (int dir=0; dir<4; ++dir) {
368 oprod_tmp = arg.force_old(dir, x_cb, parity);
369 v_tmp = arg.u(dir, x_cb, parity);
373 getUnitarizeForceSite<double>(result, v, oprod, arg);
376 arg.force(dir, x_cb, parity) = result_tmp;
377 } // 4*4528 flops per site
378 } // getUnitarizeForceField
// Tunable launcher for getUnitarizeForceField over the fields given to the
// constructor. The grid dimension is not tuned, and minThreads() is
// arg.threads (u.VolumeCB()). The file-scope unitarization parameters are
// copied into arg at construction. The constructor appears to end with
// qudaDeviceSynchronize so the kernel's failure-counter writes are complete.
// NOTE(review): some members and the launch inside the constructor are in
// lines not visible in this excerpt.
380 template <typename Float, int nColor, QudaReconstructType recon> class UnitarizeForce : public TunableVectorY {
381 UnitarizeForceArg<Float, nColor, recon> arg;
382 const GaugeField &meta;
384 // don't tune the grid dimension
385 bool tuneGridDim() const { return false; }
386 unsigned int minThreads() const { return arg.threads; }
389 UnitarizeForce(GaugeField &newForce, const GaugeField &oldForce, const GaugeField &u, int* fails) :
391 arg(newForce, oldForce, u, fails, unitarize_eps, force_filter,
392 max_det_error, allow_svd, svd_only, svd_rel_error, svd_abs_error),
396 qudaDeviceSynchronize(); // need to synchronize to ensure failure write has completed
// Launch the kernel with the tuned launch parameters on the given stream.
399 void apply(const qudaStream_t &stream) {
400 TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
401 qudaLaunchKernel(getUnitarizeForceField<decltype(arg)>, tp, stream, arg);
// After tuning, zero the failure counter so failures counted during tuning
// launches are discarded.
405 void postTune() { qudaMemset(arg.fails, 0, sizeof(int)); } // reset fails counter
// 4 directions x 4528 flops per link (estimate from getUnitarizeForceSite).
407 long long flops() const { return 4ll*4528*meta.Volume(); }
// 4 directions x volume x (force + force_old + u) accessor Bytes(); presumably
// Bytes() is the per-link size (confirm).
408 long long bytes() const { return 4ll * meta.Volume() * (arg.force.Bytes() + arg.force_old.Bytes() + arg.u.Bytes()); }
410 TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), meta.AuxString()); }
// Device entry point. It requires u, oldForce and newForce to agree in
// reconstruction and precision and to be in native order, then dispatches
// UnitarizeForce through instantiate<..., ReconstructNone>. Without
// GPU_HISQ_FORCE it raises errorQuda instead (the #else/#endif lines are not
// visible here). NOTE(review): the `fails` parameter is declared in a line not
// shown in this excerpt.
413 void unitarizeForce(GaugeField &newForce, const GaugeField &oldForce, const GaugeField &u,
416 #ifdef GPU_HISQ_FORCE
417 checkReconstruct(u, oldForce, newForce);
418 checkPrecision(u, oldForce, newForce);
420 if (!u.isNative() || !oldForce.isNative() || !newForce.isNative())
421 errorQuda("Only native order supported");
423 instantiate<UnitarizeForce,ReconstructNone>(newForce, oldForce, u, fails);
425 errorQuda("HISQ force has not been built");
// Serial host version of getUnitarizeForceField. It loops over both parities,
// all arg.threads checkerboard sites and the 4 directions, applying
// getUnitarizeForceSite in double precision. The code is compiled only when
// GPU_HISQ_FORCE is defined; otherwise errorQuda is raised (presumably from
// the #else branch, whose marker is not visible). NOTE(review): the
// Float <-> double conversions of v, oprod and result are in lines not shown.
429 template <typename Float, typename Arg> void unitarizeForceCPU(Arg &arg)
431 #ifdef GPU_HISQ_FORCE
432 Matrix<complex<double>, 3> v, result, oprod;
433 Matrix<complex<Float>, 3> v_tmp, result_tmp, oprod_tmp;
435 for (int parity = 0; parity < 2; parity++) {
436 for (int i = 0; i < arg.threads; i++) {
437 for (int dir = 0; dir < 4; dir++) {
438 oprod_tmp = arg.force_old(dir, i, parity);
439 v_tmp = arg.u(dir, i, parity);
443 getUnitarizeForceSite<double>(result, v, oprod, arg);
446 arg.force(dir, i, parity) = result_tmp;
451 errorQuda("HISQ force has not been built");
// Host entry point. It requires all three fields on the CPU, then dispatches
// on u's order (MILC or QDP) and precision (double or single) to
// unitarizeForceCPU<Float>. Failures are counted in a local counter, and
// errorQuda is raised if any occurred.
// NOTE(review): only u's order and precision are inspected; newForce and
// oldForce are assumed to match. The locals old_force, new_force and v are
// unused. The `} else {` lines before each "not supported" error are not
// visible in this excerpt.
455 void unitarizeForceCPU(GaugeField &newForce, const GaugeField &oldForce, const GaugeField &u)
457 if (checkLocation(newForce, oldForce, u) != QUDA_CPU_FIELD_LOCATION) errorQuda("Location must be CPU");
458 int num_failures = 0;
459 constexpr int nColor = 3;
460 Matrix<complex<double>, nColor> old_force, new_force, v;
461 if (u.Order() == QUDA_MILC_GAUGE_ORDER) {
462 if (u.Precision() == QUDA_DOUBLE_PRECISION) {
463 UnitarizeForceArg<double, nColor, QUDA_RECONSTRUCT_NO, QUDA_MILC_GAUGE_ORDER> arg(
464 newForce, oldForce, u, &num_failures, unitarize_eps, force_filter, max_det_error, allow_svd, svd_only,
465 svd_rel_error, svd_abs_error);
466 unitarizeForceCPU<double>(arg);
467 } else if (u.Precision() == QUDA_SINGLE_PRECISION) {
468 UnitarizeForceArg<float, nColor, QUDA_RECONSTRUCT_NO, QUDA_MILC_GAUGE_ORDER> arg(
469 newForce, oldForce, u, &num_failures, unitarize_eps, force_filter, max_det_error, allow_svd, svd_only,
470 svd_rel_error, svd_abs_error);
471 unitarizeForceCPU<float>(arg);
473 errorQuda("Precision = %d not supported", u.Precision());
475 } else if (u.Order() == QUDA_QDP_GAUGE_ORDER) {
476 if (u.Precision() == QUDA_DOUBLE_PRECISION) {
477 UnitarizeForceArg<double, nColor, QUDA_RECONSTRUCT_NO, QUDA_QDP_GAUGE_ORDER> arg(
478 newForce, oldForce, u, &num_failures, unitarize_eps, force_filter, max_det_error, allow_svd, svd_only,
479 svd_rel_error, svd_abs_error);
480 unitarizeForceCPU<double>(arg);
481 } else if (u.Precision() == QUDA_SINGLE_PRECISION) {
482 UnitarizeForceArg<float, nColor, QUDA_RECONSTRUCT_NO, QUDA_QDP_GAUGE_ORDER> arg(
483 newForce, oldForce, u, &num_failures, unitarize_eps, force_filter, max_det_error, allow_svd, svd_only,
484 svd_rel_error, svd_abs_error);
485 unitarizeForceCPU<float>(arg);
487 errorQuda("Precision = %d not supported", u.Precision());
490 errorQuda("Only MILC and QDP gauge orders supported\n");
// Report any unitarization failures counted during the host loop.
492 if (num_failures) errorQuda("Unitarization failed, failures = %d", num_failures);
493 } // unitarize_force_cpu
495 } // namespace fermion_force