4 #include <gauge_field.h>
5 #include <gauge_field_order.h>
7 #include <quda_matrix.h>
8 #include <unitarization_links.h>
9 #include <su3_project.cuh>
10 #include <index_helper.cuh>
11 #include <instantiate.h>
12 #include <color_spinor.h>
// Double-precision literal for pi; only defined when nothing earlier has supplied one.
20 #ifndef FL_UNITARIZE_PI
21 #define FL_UNITARIZE_PI 3.14159265358979323846
// 2*pi/3, written as pi * (2/3). reciprocalRoot uses it as the phase offset between
// the three cosine terms of its eigenvalue formula.
// NOTE(review): the replacement text is not parenthesized, so an expression such as
// `x / FL_UNITARIZE_PI23` would expand to `x / pi * 0.666...`. The uses visible in this
// file only add it or multiply it, which is safe, but wrapping the expansion in
// parentheses would make the macro robust against future uses.
23 #ifndef FL_UNITARIZE_PI23
24 #define FL_UNITARIZE_PI23 FL_UNITARIZE_PI*0.66666666666666666666
// File-scope defaults for the unitarization routines in this file.
28 // suppress compiler warnings about unused variables when GPU_UNITARIZE is not set
29 // when we switch to C++17 consider [[maybe_unused]]
// Iteration count that unitarizeLinksCPU hands to unitarizeLinkNewton.
30 __attribute__((unused)) static const int max_iter_newton = 20;
// Forwarded by UnitarizeLinks to the UnitarizeLinksArg constructor.
31 __attribute__((unused))static const int max_iter = 20;
// The six values below are mutable: setUnitarizeLinksConstants overwrites all of them.
// reciprocalRoot only runs its trigonometric eigenvalue solve when |s| >= unitarize_eps.
33 __attribute__((unused)) static double unitarize_eps = 1e-14;
// Tolerance used by the isUnitary post-check inside DoUnitarizedLink.
34 __attribute__((unused)) static double max_error = 1e-10;
// Zero makes unitarizeLinkMILC return false instead of falling back to computeSVD.
35 __attribute__((unused)) static int reunit_allow_svd = 1;
// Non-zero makes unitarizeLinkMILC skip the reciprocalRoot attempt.
36 __attribute__((unused)) static int reunit_svd_only = 0;
// Relative tolerance for comparing the eigenvalue product with det(q) in reciprocalRoot.
37 __attribute__((unused)) static double svd_rel_error = 1e-6;
// reciprocalRoot returns false when |det(q)| is below this value.
38 __attribute__((unused)) static double svd_abs_error = 1e-6;
// Argument bundle that UnitarizeLinks::apply passes by value to the DoUnitarizedLink kernel.
// It carries the launch extent, the lattice dimensions and a snapshot of the unitarization
// tolerances/switches taken at construction time.
40 template <typename Float_, int nColor_, QudaReconstructType recon_>
41 struct UnitarizeLinksArg {
43 static constexpr int nColor = nColor_;
44 static constexpr QudaReconstructType recon = recon_;
// Type selected by gauge_mapper for this precision / reconstruction pair.
45 typedef typename gauge_mapper<Float,recon>::type Gauge;
// Initialised from in.VolumeCB(); the kernel discards threads whose x index is >= threads.
49 int threads; // number of active threads required
// Copied element-by-element from in.X() in the constructor.
50 int X[4]; // grid dimensions
// Per-launch copies of the tuning values; see the constructor parameters of the same name.
53 const double unitarize_eps;
54 const double max_error;
55 const int reunit_allow_svd;
56 const int reunit_svd_only;
57 const double svd_rel_error;
58 const double svd_abs_error;
// Read by the kernel: when true, every result is tested with isUnitary(max_error).
59 const static bool check_unitarization = true;
// Parameters share their names with the members they initialise.
// NOTE(review): no member initialised from `max_iter` is visible in this struct — confirm
// whether the parameter is still needed.
61 UnitarizeLinksArg(GaugeField &out, const GaugeField &in, int* fails, int max_iter,
62 double unitarize_eps, double max_error, int reunit_allow_svd,
63 int reunit_svd_only, double svd_rel_error, double svd_abs_error) :
66 threads(in.VolumeCB()),
68 unitarize_eps(unitarize_eps),
71 reunit_allow_svd(reunit_allow_svd),
72 reunit_svd_only(reunit_svd_only),
73 svd_rel_error(svd_rel_error),
74 svd_abs_error(svd_abs_error)
76 for (int dir=0; dir<4; ++dir) X[dir] = in.X()[dir];
// Overwrite the six mutable file-scope unitarization settings.
// The new values are picked up by UnitarizeLinks objects constructed afterwards, because
// that constructor copies the statics into its UnitarizeLinksArg.
// The two bool parameters are stored in int statics (true -> 1, false -> 0).
80 void setUnitarizeLinksConstants(double unitarize_eps_, double max_error_,
81 bool reunit_allow_svd_, bool reunit_svd_only_,
82 double svd_rel_error_, double svd_abs_error_) {
83 unitarize_eps = unitarize_eps_;
84 max_error = max_error_;
85 reunit_allow_svd = reunit_allow_svd_;
86 reunit_svd_only = reunit_svd_only_;
87 svd_rel_error = svd_rel_error_;
88 svd_abs_error = svd_abs_error_;
// Consistency test between a link and its unitarized counterpart.
// Forms T = conj(initial)*unitary, then T*T - conj(initial)*initial, and compares the real
// (.x) and imaginary (.y) part of every element against max_error.
// Assuming conj() on a matrix is the Hermitian conjugate (TODO confirm): for
// unitary = initial * (initial^dag initial)^(-1/2) the product T is the square root of
// initial^dag initial, so the difference should vanish up to rounding.
// unitarizeLinkNewton treats a `false` result as "not consistent".
91 template <typename mat>
92 __device__ __host__ bool isUnitarizedLinkConsistent(const mat &initial_matrix,
93 const mat &unitary_matrix, double max_error)
95 auto n = initial_matrix.size();
96 mat temporary = conj(initial_matrix)*unitary_matrix;
97 temporary = temporary*temporary - conj(initial_matrix)*initial_matrix;
// Element-wise scan; the test below fires on the first component that exceeds max_error.
99 for (int i=0; i<n; ++i) {
100 for (int j=0; j<n; ++j) {
101 if (fabs(temporary(i,j).x) > max_error || fabs(temporary(i,j).y) > max_error) {
// Track the smallest absolute value among array[0 .. size-1].
// array[0] is read unconditionally, so the caller must pass at least one element.
// Presumably the tracked minimum is the return value — confirm against the full body.
110 template <class T> constexpr T getAbsMin(const T* const array, int size)
112 T min = fabs(array[0]);
113 for(int i=1; i<size; ++i){
114 T abs_val = fabs(array[i]);
115 if((abs_val) < min){ min = abs_val; }
// True when a and b differ by strictly less than epsilon, i.e. |a - b| < epsilon.
120 template <class Real> constexpr bool checkAbsoluteError(Real a, Real b, Real epsilon) { return fabs(a-b) < epsilon; }
// True when the difference relative to b is strictly below epsilon, i.e. |(a - b) / b| < epsilon.
// The quotient divides by b directly; with IEEE floating point a zero b produces inf or NaN,
// for which the comparison is false. reciprocalRoot rejects small determinants before calling this.
122 template <class Real> constexpr bool checkRelativeError(Real a, Real b, Real epsilon) { return fabs((a-b)/b) < epsilon; }
124 // Compute the reciprocal square root of the matrix q
125 // Also modify q if the eigenvalues are dangerously small.
// NOTE(review): q is taken by const reference below, so the "modify q" remark above looks
// stale — confirm.
//
// Outline of the visible steps:
//   1. c[] <- scaled traces of q, qsq and tempq (presumably q^2 and q^3 — TODO confirm).
//   2. g[] <- the three eigenvalues of q from a closed-form cosine formula.
//   3. Sanity check of the eigenvalues against det(q); on failure return false.
//   4. c[] <- coefficients such that c[0]*I + c[1]*q + c[2]*q^2 is the result matrix.
// `arg` must provide unitarize_eps, svd_abs_error and svd_rel_error.
// A false return is what unitarizeLinkMILC uses as the cue to fall through to its SVD path.
126 template <typename real, typename mat, typename Arg>
127 __device__ __host__ bool reciprocalRoot(mat &res, const mat& q, Arg &arg)
134 const real one_third = 0.333333333333333333333;
135 const real one_ninth = 0.111111111111111111111;
136 const real one_eighteenth = 0.055555555555555555555;
// Only the real part (.x) of each trace is used.
141 c[0] = getTrace(q).x;
142 c[1] = getTrace(qsq).x * 0.5;
143 c[2] = getTrace(tempq).x * one_third;;
// Default: all three eigenvalues equal tr(q)/3. This is what remains when the branch
// below is skipped because |s| is under the threshold.
145 g[0] = g[1] = g[2] = c[0] * one_third;
147 s = c[1]*one_third - c[0]*c[0]*one_eighteenth;
// NOTE(review): the guard is on |s|, so a negative s with magnitude >= unitarize_eps would
// still reach rsqrt(s) — confirm that callers guarantee s >= 0.
150 if (fabs(s) >= arg.unitarize_eps) { // faster when this conditional is removed?
151 const real rsqrt_s = rsqrt(s);
152 r = c[2]*0.5 - (c[0]*one_third)*(c[1] - c[0]*c[0]*one_ninth);
// cosTheta = r / s^(3/2)
153 cosTheta = r*rsqrt_s*rsqrt_s*rsqrt_s;
// Clamp: keep acos away from arguments at or beyond +-1 by choosing the limiting angle
// from the sign of r.
155 if(fabs(cosTheta) >= 1.0){
156 theta = (r > 0) ? 0.0 : FL_UNITARIZE_PI;
158 theta = acos(cosTheta); // this is the primary performance limiter
// sqrt(s) recovered from the reciprocal root: s * s^(-1/2).
161 const real sqrt_s = s*rsqrt_s;
// Compiled-out variant that evaluates one sincos and expands the shifted cosines by hand
// (0.866... is sqrt(3)/2).
163 #if 0 // experimental version
165 sincos( theta*one_third, &as, &ac );
166 g[0] = c[0]*one_third + 2*sqrt_s*ac;
167 //g[1] = c[0]*one_third + 2*sqrt_s*(ac*cos(1*FL_UNITARIZE_PI23) - as*sin(1*FL_UNITARIZE_PI23));
168 g[1] = c[0]*one_third - 2*sqrt_s*(0.5*ac + as*0.8660254037844386467637);
169 //g[2] = c[0]*one_third + 2*sqrt_s*(ac*cos(2*FL_UNITARIZE_PI23) - as*sin(2*FL_UNITARIZE_PI23));
170 g[2] = c[0]*one_third + 2*sqrt_s*(-0.5*ac + as*0.8660254037844386467637);
// g[k] = tr(q)/3 + 2*sqrt(s)*cos(theta/3 + k*2*pi/3), k = 0, 1, 2
172 g[0] = c[0]*one_third + 2*sqrt_s*cos( theta*one_third );
173 g[1] = c[0]*one_third + 2*sqrt_s*cos( theta*one_third + FL_UNITARIZE_PI23 );
174 g[2] = c[0]*one_third + 2*sqrt_s*cos( theta*one_third + 2*FL_UNITARIZE_PI23 );
178 // Check the eigenvalues, if the determinant does not match the product of the eigenvalues
179 // return false. Then call SVD instead.
180 real det = getDeterminant(q).x;
181 if (fabs(det) < arg.svd_abs_error) return false;
// NOTE(review): the comparison is instantiated for double even when `real` is another type.
182 if (!checkRelativeError<double>(g[0]*g[1]*g[2], det, arg.svd_rel_error)) return false;
184 // At this point we have finished with the c's
185 // use these to store sqrt(g)
186 for(int i=0; i<3; ++i) c[i] = sqrt(g[i]);
188 // done with the g's, use these to store u, v, w
// u = sum, v = sum of pairwise products, w = product of the three square roots.
189 g[0] = c[0]+c[1]+c[2];
190 g[1] = c[0]*c[1] + c[0]*c[2] + c[1]*c[2];
191 g[2] = c[0]*c[1]*c[2];
// denominator = 1 / (w*(u*v - w)); the c's are reused once more for the polynomial coefficients.
193 const real denominator = 1.0 / ( g[2]*(g[0]*g[1]-g[2]) );
194 c[0] = (g[0]*g[1]*g[1] - g[2]*(g[0]*g[0]+g[1])) * denominator;
195 c[1] = (-g[0]*g[0]*g[0] - g[2] + 2.*g[0]*g[1]) * denominator;
196 c[2] = g[0] * denominator;
// tempq = c[1]*q + c[2]*q^2, then c[0] is added to the real part of each diagonal entry,
// giving c[0]*I + c[1]*q + c[2]*q^2.
198 tempq = c[1]*q + c[2]*qsq;
200 tempq(0,0).x += c[0];
201 tempq(1,1).x += c[0];
202 tempq(2,2).x += c[0];
// Unitarize `in` into `out`, trying the analytic reciprocal-root route first and an SVD second.
//   - arg.reunit_svd_only != 0 : skip the reciprocalRoot attempt.
//   - arg.reunit_allow_svd == 0: return false once the reciprocalRoot attempt has not succeeded.
// reciprocalRoot is given conj(in)*in, so `u` receives that product's reciprocal square root.
// Returns false on failure; the visible kernel caller ignores the return value and relies on
// a separate isUnitary check instead.
209 template <typename real, typename mat, typename Arg>
210 __host__ __device__ bool unitarizeLinkMILC(mat &out, const mat &in, Arg &arg)
213 if (!arg.reunit_svd_only) {
214 if (reciprocalRoot<real>(u, conj(in)*in, arg) ) {
220 // If we've got this far, then the Cayley-Hamilton unitarization
221 // has failed. If SVD is not allowed, the unitarization has failed.
222 if (!arg.reunit_allow_svd) return false;
// SVD fallback: computeSVD fills u, v and the three singular values from `in`.
225 real singular_values[3];
226 computeSVD<real>(in, u, v, singular_values);
// Iterative unitarization: each of the max_iter passes replaces u by the average of u and
// conj(inverse(u)). The loop always runs the full max_iter passes; no early exit is visible.
// The starting value of u is set outside the lines shown here (presumably `in` — TODO confirm).
// After the loop the result is validated against `in` with isUnitarizedLinkConsistent and an
// error line is printed when the check fails.
// unitarizeLinksCPU counts a `false` return as a failure.
231 template <typename mat>
232 __host__ __device__ bool unitarizeLinkNewton(mat &out, const mat& in, int max_iter)
236 for (int i=0; i<max_iter; ++i) {
237 mat uinv = inverse(u);
238 u = 0.5*(u + conj(uinv));
// NOTE(review): the tolerance 1e-7 is hard-coded here rather than taken from the
// file-scope max_error setting — confirm this is intentional.
241 if (isUnitarizedLinkConsistent(in,u,0.0000001)==false) {
242 printf("ERROR: Unitarized link is not consistent with incoming link\n");
// Host reference path: unitarize every link of `infield` into `outfield` with
// unitarizeLinkNewton (max_iter_newton passes per link).
// Both fields must be CPU-resident; checkPrecision presumably requires matching precisions.
// Links are read and written as 18 consecutive reals at element offset (site*4 + dir)*18,
// i.e. the four directions of a site are stored back to back.
// The arithmetic itself is done in complex<double> regardless of the storage precision.
250 void unitarizeLinksCPU(GaugeField &outfield, const GaugeField& infield)
253 if (checkLocation(outfield, infield) != QUDA_CPU_FIELD_LOCATION) errorQuda("Location must be CPU");
254 checkPrecision(outfield, infield);
// Incremented once per link for which unitarizeLinkNewton returns false.
256 int num_failures = 0;
257 Matrix<complex<double>,3> inlink, outlink;
// NOTE(review): `i` is unsigned int, so (i*4 + dir)*18 is evaluated in 32-bit unsigned
// arithmetic and wraps once it reaches 2^32 (site index around 59.6 million). Widening the
// offset to size_t would avoid that on large lattices — confirm the intended maximum volume.
259 for (unsigned int i = 0; i < infield.Volume(); ++i) {
260 for (int dir=0; dir<4; ++dir){
261 if (infield.Precision() == QUDA_SINGLE_PRECISION) {
262 copyArrayToLink(&inlink, ((float*)(infield.Gauge_p()) + (i*4 + dir)*18)); // order of arguments?
263 if (unitarizeLinkNewton(outlink, inlink, max_iter_newton) == false ) num_failures++;
264 copyLinkToArray(((float*)(outfield.Gauge_p()) + (i*4 + dir)*18), outlink);
265 } else if (infield.Precision() == QUDA_DOUBLE_PRECISION) {
266 copyArrayToLink(&inlink, ((double*)(infield.Gauge_p()) + (i*4 + dir)*18)); // order of arguments?
267 if (unitarizeLinkNewton(outlink, inlink, max_iter_newton) == false ) num_failures++;
268 copyLinkToArray(((double*)(outfield.Gauge_p()) + (i*4 + dir)*18), outlink);
271 } // loop over volume
// Presumably the branch taken when unitarization support is compiled out — TODO confirm
// against the surrounding preprocessor guard.
273 errorQuda("Unitarization has not been built");
277 // CPU function which checks that the gauge field is unitary
// Every link (site x 4 directions) is loaded into a complex<double> 3x3 matrix and tested
// with Matrix::isUnitary(max_error). Only single and double precision storage is handled;
// any other precision raises errorQuda. A failing link is reported with its site index
// and direction.
278 bool isUnitary(const GaugeField& field, double max_error)
281 if (field.Location() != QUDA_CPU_FIELD_LOCATION) errorQuda("Location must be CPU");
282 Matrix<complex<double>,3> link, identity;
// NOTE(review): as `i` is unsigned int, the element offset (i*4 + dir)*18 is computed in
// 32-bit unsigned arithmetic and can wrap for very large volumes.
284 for (unsigned int i = 0; i < field.Volume(); ++i) {
285 for (int dir=0; dir<4; ++dir) {
286 if (field.Precision() == QUDA_SINGLE_PRECISION) {
287 copyArrayToLink(&link, ((float*)(field.Gauge_p()) + (i*4 + dir)*18)); // order of arguments?
288 } else if (field.Precision() == QUDA_DOUBLE_PRECISION) {
289 copyArrayToLink(&link, ((double*)(field.Gauge_p()) + (i*4 + dir)*18)); // order of arguments?
291 errorQuda("Unsupported precision\n");
293 if (link.isUnitary(max_error) == false) {
294 printf("Unitarity failure\n");
295 printf("site index = %u,\t direction = %d\n", i, dir);
// link^dag * link, which would be the identity for a unitary link (presumably computed
// for diagnostic output — TODO confirm).
297 identity = conj(link)*link;
// Presumably the branch taken when unitarization support is compiled out — TODO confirm.
305 errorQuda("Unitarization has not been built");
// Kernel: unitarize one gauge link per thread.
// Thread mapping: x -> checkerboard site index, y -> parity, z -> direction mu.
// Launched by UnitarizeLinks::apply, whose base is constructed as TunableVectorYZ(2,4)
// (presumably y-extent 2 for parity and z-extent 4 for mu — TODO confirm).
// The link is loaded from arg.in in the storage precision, unitarized in double precision
// with unitarizeLinkMILC, optionally checked with isUnitary, and written to arg.out.
// Each link failing the check adds 1 to *arg.fails with atomicAdd, so arg.fails must be
// device-accessible and zeroed before launch.
311 template <typename Arg> __global__ void DoUnitarizedLink(Arg arg)
313 int idx = threadIdx.x + blockIdx.x*blockDim.x;
314 int parity = threadIdx.y + blockIdx.y*blockDim.y;
315 int mu = threadIdx.z + blockIdx.z*blockDim.z;
// NOTE(review): only the x-index guard is visible here; confirm that parity and mu are
// bounded as well when the block shape does not divide the y/z extents.
316 if (idx >= arg.threads) return;
319 // result is always in double precision
320 Matrix<complex<double>,Arg::nColor> v, result;
321 Matrix<complex<typename Arg::Float>,Arg::nColor> tmp = arg.in(mu, idx, parity);
// The bool returned by unitarizeLinkMILC is not inspected; failures are detected by the
// isUnitary test that follows.
324 unitarizeLinkMILC<double>(result, v, arg);
325 if (arg.check_unitarization) {
326 if (result.isUnitary(arg.max_error) == false) atomicAdd(arg.fails, 1);
330 arg.out(mu, idx, parity) = tmp;
// Autotuned launcher for the DoUnitarizedLink kernel.
// Inherits privately from TunableVectorYZ, which is constructed with (2,4).
// The kernel argument is built from the current file-scope unitarization settings.
333 template <typename Float, int nColor, QudaReconstructType recon>
334 class UnitarizeLinks : TunableVectorYZ {
// Kernel argument bundle, passed by value at launch.
335 UnitarizeLinksArg<Float, nColor, recon> arg;
// Field whose VolString()/AuxString() form the tuning key.
336 const GaugeField &meta;
// The grid dimension is not tuned; the x extent must cover at least arg.threads threads.
338 bool tuneGridDim() const { return false; }
339 unsigned int minThreads() const { return arg.threads; }
// out / in may refer to the same field (see the in-place unitarizeLinks overload);
// fails is forwarded to the kernel argument.
342 UnitarizeLinks(GaugeField &out, const GaugeField &in, int* fails) :
343 TunableVectorYZ(2,4),
344 arg(out, in, fails, max_iter, unitarize_eps, max_error, reunit_allow_svd,
345 reunit_svd_only, svd_rel_error, svd_abs_error),
349 qudaDeviceSynchronize(); // need to synchronize to ensure failure write has completed
// Look up the launch parameters for this key and launch the kernel on `stream`.
352 void apply(const qudaStream_t &stream) {
353 TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
354 qudaLaunchKernel(DoUnitarizedLink<decltype(arg)>, tp, stream, arg);
// In-place case (input and output share storage): back up the field before the tuning
// launches overwrite it.
357 void preTune() { if (arg.in.gauge == arg.out.gauge) arg.out.save(); }
// Counterpart of preTune: restore the saved field and clear the failure count that the
// tuning launches accumulated.
359 if (arg.in.gauge == arg.out.gauge) arg.out.load();
360 qudaMemset(arg.fails, 0, sizeof(int)); // reset fails counter
// 4 * 2 * arg.threads links (presumably directions x parities x sites per parity), 1147 flops each.
363 long long flops() const {
364 // Accounted only the minimum flops for the case reunitarize_svd_only=0
365 return 4ll * 2 * arg.threads * 1147;
// One load through arg.in and one store through arg.out per link.
367 long long bytes() const { return 4ll * 2 * arg.threads * (arg.in.Bytes() + arg.out.Bytes()); }
369 TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), meta.AuxString()); }
// Unitarize `in` into `out` via the UnitarizeLinks launcher.
// `fails` is forwarded to the kernel, which atomically increments it once per link that
// fails the unitarity check, so it must point to device-accessible memory.
// instantiate<> dispatches to the UnitarizeLinks specialisation for the fields' precision
// and reconstruction (ReconstructWilson set).
372 void unitarizeLinks(GaugeField& out, const GaugeField &in, int* fails)
375 checkPrecision(out, in);
376 instantiate<UnitarizeLinks, ReconstructWilson>(out, in, fails);
// Presumably the branch taken when unitarization support is compiled out — TODO confirm.
378 errorQuda("Unitarization has not been built");
// In-place convenience overload: `links` serves as both input and output.
382 void unitarizeLinks(GaugeField &links, int* fails) { unitarizeLinks(links, links, fails); }
// Argument bundle passed by value to ProjectSU3kernel.
// The kernel reads and writes the field through arg.u, uses arg.tol as the projection and
// unitarity tolerance, and counts failures in *arg.fails.
384 template <typename Float_, int nColor_, QudaReconstructType recon_>
385 struct ProjectSU3Arg {
// Storage precision; also the precision the kernel computes in.
386 using Float = Float_;
387 static constexpr int nColor = nColor_;
388 static constexpr QudaReconstructType recon = recon_;
// Type selected by gauge_mapper for this precision / reconstruction pair.
389 typedef typename gauge_mapper<Float,recon>::type Gauge;
// Initialised from u.VolumeCB(); the kernel discards threads whose x index is >= threads.
392 int threads; // number of active threads required
395 ProjectSU3Arg(GaugeField &u, Float tol, int *fails) :
396 threads(u.VolumeCB()),
// Kernel: project one gauge link per thread with polarSu3, in place.
// Thread mapping: x -> checkerboard site index, y -> parity, z -> direction mu.
// Launched by ProjectSU3::apply, whose base is constructed as TunableVectorYZ(2, 4).
// Each link that is still not unitary within arg.tol after projection adds 1 to *arg.fails
// via atomicAdd, so arg.fails must be device-accessible and zeroed before launch.
// The link is written back whether or not the check passes.
402 template<typename Arg>
403 __global__ void ProjectSU3kernel(Arg arg){
404 using real = typename Arg::Float;
405 int idx = threadIdx.x + blockIdx.x*blockDim.x;
406 int parity = threadIdx.y + blockIdx.y*blockDim.y;
407 int mu = threadIdx.z + blockIdx.z*blockDim.z;
// NOTE(review): only the x-index guard is visible here; confirm that parity and mu are
// bounded as well when the block shape does not divide the y/z extents.
408 if (idx >= arg.threads) return;
411 Matrix<complex<real>, Arg::nColor> u = arg.u(mu, idx, parity);
// Presumably updates u in place — the same object is checked and stored below.
413 polarSu3<real>(u, arg.tol);
415 // count number of failures
416 if (u.isUnitary(arg.tol) == false) {
417 atomicAdd(arg.fails, 1);
420 arg.u(mu, idx, parity) = u;
// Autotuned launcher for ProjectSU3kernel. Inherits privately from TunableVectorYZ.
423 template <typename Float, int nColor, QudaReconstructType recon>
424 class ProjectSU3 : TunableVectorYZ {
// Kernel argument bundle, passed by value at launch.
425 ProjectSU3Arg<Float, nColor, recon> arg;
// Field whose VolString()/AuxString() form the tuning key.
426 const GaugeField &meta;
// The grid dimension is not tuned; the x extent must cover at least arg.threads threads.
428 bool tuneGridDim() const { return false; }
429 unsigned int minThreads() const { return arg.threads; }
// tol is narrowed to the storage precision Float before it reaches the kernel.
// NOTE(review): the initializer list names `arg` before the base class, but C++ always
// initialises base classes before members, so TunableVectorYZ(2, 4) still runs first.
// Listing the base first (as UnitarizeLinks does) would match the real order and avoid a
// -Wreorder style warning.
432 ProjectSU3(GaugeField &u, double tol, int *fails) :
433 arg(u, static_cast<Float>(tol), fails),
434 TunableVectorYZ(2, 4),
// Presumably waits for the kernel so the failure count is complete before it is read —
// TODO confirm.
438 qudaDeviceSynchronize();
// Look up the launch parameters for this key and launch the kernel on `stream`.
441 void apply(const qudaStream_t &stream) {
442 TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
443 qudaLaunchKernel(ProjectSU3kernel<decltype(arg)>, tp, stream, arg);
// The kernel always works in place, so the field is backed up before tuning launches.
446 void preTune() { arg.u.save(); }
449 qudaMemset(arg.fails, 0, sizeof(int)); // reset fails counter
452 long long flops() const { return 0; } // depends on number of iterations
// 4 * 2 * arg.threads links, each read once and written once through arg.u.
453 long long bytes() const { return 4ll * 2 * arg.threads * 2 * arg.u.Bytes(); }
454 TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), meta.AuxString()); }
// Project every link of `u` in place via the ProjectSU3 launcher.
//   u     - gauge field to project; must not have staggered phases applied.
//   tol   - tolerance forwarded to the kernel (cast to the field's storage precision there).
//   fails - counter incremented atomically by the kernel once per link that is not unitary
//           within tol after projection; must point to device-accessible memory.
// Raises errorQuda when built without GPU_GAUGE_TOOLS.
457 void projectSU3(GaugeField &u, double tol, int *fails) {
458 #ifdef GPU_GAUGE_TOOLS
459 // check that the field doesn't have staggered phases applied
460 if (u.StaggeredPhaseApplied())
461 errorQuda("Cannot project gauge field with staggered phases applied");
463 instantiate<ProjectSU3, ReconstructWilson>(u, tol, fails);
465 errorQuda("Gauge tools have not been built");