// Function qualifier shorthand: expands to `__device__ __host__`, so a function
// declared with it is compiled for both the GPU and the CPU.
5 #define DEVICEHOST __device__ __host__
// Numeric value of the natural logarithm of 2 (ln 2), as a double literal.
7 #define LOG2 0.69314718055994530942
// Sentinel value defined as the most negative finite double (-DBL_MAX).
// NOTE(review): DBL_MAX is presumably provided by <cfloat>/<float.h> — confirm
// the include is present. This macro is #undef'd again later in the file.
8 #define INVALID_DOUBLE (-DBL_MAX)
15 typename std::remove_reference<decltype(Cmplx::x)>::type
cabs(
const Cmplx & z)
17 typedef typename std::remove_reference<decltype(Cmplx::x)>::type real;
18 real max, ratio, square;
19 if(fabs(z.x) > fabs(z.y)){ max = z.x; ratio = z.y/max; }
else{ max=z.y; ratio = z.x/max; }
20 square = (max != 0.0) ? max*max*(1.0 + ratio*ratio) : 0.0;
24 template <
class T,
class U>
inline DEVICEHOST typename PromoteTypeId<T, U>::type
quadSum(
const T &a,
const U &b)
26 typename PromoteTypeId<T, U>::type ratio, square, max;
27 if (fabs(a) > fabs(b)) { max = a; ratio = b/a; }
else { max=b; ratio = a/b; }
28 square = (max != 0.0) ? max*max*(1.0 + ratio*ratio) : 0.0;
34 inline float getNorm(
const Array<complex<float>,3>& a){
35 float temp1, temp2, temp3;
45 inline double getNorm(
const Array<complex<double>,3>& a){
46 double temp1, temp2, temp3;
58 ColorSpinor<decltype(tau.real()), 3, 1> v_;
59 for (
int i = 0; i < 3; i++) v_(0, i) = v[i];
62 temp2 =
conj(tau)*temp1;
75 Real m11 = b(1,1)*b(1,1) + b(0,1)*b(0,1);
76 Real m22 = b(2,2)*b(2,2) + b(1,2)*b(1,2);
77 Real m12 = b(1,1)*b(1,2);
78 Real dm = (m11 - m22) * 0.5;
82 lambda_max = m22 - (m12*m12)/(
norm1 + dm);
84 lambda_max = m22 + (m12*m12)/(
norm1 - dm);
96 }
else if( fabs(beta) > fabs(alpha) ){
98 s = rsqrt(1.0 + ratio*ratio);
102 c = rsqrt(1.0 + ratio*ratio);
111 int index_p1 = index+1;
113 for(
int i=0; i<3; ++i){
115 beta = m(i,index_p1);
116 m(i,index) = alpha*c - beta*s;
117 m(i,index_p1) = alpha*s + beta*c;
154 m(0,0) = m(0,1)*c - m(1,1)*s;
167 }
else if( m(1,1) == 0.0 ){
171 m(0,0) = m(0,0)*c - m(0,1)*s;
177 }
else if( m(0,1) != 0.0 ){
180 Real abs01 = fabs(m(0,1));
181 Real abs11 = fabs(m(1,1));
183 if( abs01 > abs11 ){ min = abs11; max = abs01; }
184 else { min = abs01; max = abs11; }
187 Real ratio = min/max;
188 Real alpha = 2.0*
log(max) +
log(1.0 + ratio*ratio);
190 Real abs00 = fabs(m(0,0));
191 Real beta = 2.0*
log(abs00);
198 temp = alpha +
log(1.0 -
exp(beta-alpha));
201 temp = beta +
log(1.0 -
exp(alpha-beta));
204 temp = sign*
exp(temp);
206 if( m(0,0) < 0.0 ){ temp *= -1.0; }
207 if( m(0,1) < 0.0 ){ temp *= -1.0; }
210 temp = 1.0/(temp + beta);
220 p(0,0) = c*m(0,0) - s*m(0,1);
222 p(0,1) = s*m(0,0) + c*m(0,1);
228 alpha =
quadSum(p(0,0),p(1,0));
241 m(0,0) = p(0,0)*c - s*p(1,0);
242 m(1,1) = p(0,1)*s + c*p(1,1);
253 template<
class Float>
257 typedef complex<Float>
Complex;
260 Array<Complex,3> vec;
262 const Complex COMPLEX_UNITY(1.0,0.0);
263 const Complex COMPLEX_ZERO = 0.0;
282 Array<Complex,3> temp_vec;
287 w.x =
mat(0,0).x - beta;
290 w =
conj(w)*norm1_inv;
293 vec[0] = COMPLEX_UNITY;
294 vec[1] =
mat(1,0)*w*norm1_inv;
295 vec[2] =
mat(2,0)*w*norm1_inv;
298 tau.x = (beta -
mat(0,0).x)/beta;
299 tau.y =
mat(0,0).y/beta;
308 if(
norm2 != 0.0 || p(0,1).y != 0.0){
311 vec[0] = COMPLEX_ZERO;
312 vec[1] = COMPLEX_UNITY;
317 w =
conj(w)*norm1_inv;
318 z =
conj(p(0,2))*norm1_inv;
321 tau.x = (beta - p(0,1).x)/beta;
322 tau.y = (- p(0,1).y)/beta;
331 if(
norm2 != 0.0 || p(1,1).y != 0.0){
336 vec[0] = COMPLEX_ZERO;
337 vec[1] = COMPLEX_UNITY;
339 w.x = p(1,1).x - beta;
342 w =
conj(w)*norm1_inv;
343 z = p(2,1)*norm1_inv;
346 tau.x = (beta - p(1,1).x)/beta;
347 tau.y = p(1,1).y/beta;
359 if( p(1,2).y != 0.0 ){
360 beta = p(1,2).x > 0.0 ? -
cabs(p(1,2)) :
cabs(p(1,2));
361 temp(2,2) =
conj(p(1,2))/beta;
362 p(2,2) = p(2,2)*temp(2,2);
367 if( p(2,2).y != 0.0 ){
368 beta = p(2,2).x > 0.0 ? -
cabs(p(2,2)) :
cabs(p(2,2));
369 temp(2,2) = p(2,2)/beta;
399 if( fabs(b(0,1)) <
SVDPREC*( fabs(b(0,0)) + fabs(b(1,1)) ) ){ b(0,1) = 0.0; }
400 if( fabs(b(1,2)) <
SVDPREC*( fabs(b(0,0)) + fabs(b(2,2)) ) ){ b(1,2) = 0.0; }
402 if( b(0,1) != 0.0 && b(1,2) != 0.0 ){
407 for(
int i=0; i<3; ++i){
410 u(i,0) = alpha*c - beta*s;
411 u(i,1) = alpha*s + beta*c;
414 b(1,1) = b(0,1)*s + b(1,1)*c;
422 for(
int i=0; i<3; ++i){
425 u(i,0) = alpha*c - beta*s;
426 u(i,2) = alpha*s + beta*c;
428 b(2,2) = b(0,2)*s + b(2,2)*c;
431 }
else if( b(1,1) == 0.0 ){
434 for(
int i=0; i<3; ++i){
437 u(i,1) = alpha*c - beta*s;
438 u(i,2) = alpha*s + beta*c;
440 b(2,2) = b(1,2)*s + b(2,2)*c;
443 }
else if( b(2,2) == 0.0 ){
446 for(
int i=0; i<3; ++i){
449 v(i,1) = alpha*c + beta*s;
450 v(i,2) = -alpha*s + beta*c;
452 b(1,1) = b(1,1)*c + b(1,2)*s;
460 for(
int i=0; i<3; ++i){
463 v(i,0) = alpha*c + beta*s;
464 v(i,2) = -alpha*s + beta*c;
466 b(0,0) = b(0,0)*c + b(0,2)*s;
474 alpha = b(0,0)*b(0,0) - lambda_max;
475 beta = b(0,0)*b(0,1);
485 b(0,0) = alpha*c - beta*s;
486 b(0,1) = alpha*s + beta*c;
496 b(0,0) = b(0,0)*c - b(1,0)*s;
499 b(0,1) = alpha*c - beta*s;
500 b(1,1) = alpha*s + beta*c;
512 b(0,1) = b(0,1)*c - b(0,2)*s;
516 b(1,1) = alpha*c - beta*s;
517 b(1,2) = alpha*s + beta*c;
528 b(1,1) = b(1,1)*c - b(2,1)*s;
531 b(1,2) = alpha*c - beta*s;
532 b(2,2) = alpha*s + beta*c;
538 }
else if( b(0,1) == 0.0 ){
541 m_small(0,0) = b(1,1);
542 m_small(0,1) = b(1,2);
543 m_small(1,0) = b(2,1);
544 m_small(1,1) = b(2,2);
546 smallSVD(u_small, v_small, m_small);
548 b(1,1) = m_small(0,0);
549 b(1,2) = m_small(0,1);
550 b(2,1) = m_small(1,0);
551 b(2,2) = m_small(1,1);
554 for(
int i=0; i<3; ++i){
557 u(i,1) = alpha*u_small(0,0) + beta*u_small(1,0);
558 u(i,2) = alpha*u_small(0,1) + beta*u_small(1,1);
562 v(i,1) = alpha*v_small(0,0) + beta*v_small(1,0);
563 v(i,2) = alpha*v_small(0,1) + beta*v_small(1,1);
567 }
else if( b(1,2) == 0.0 ){
570 m_small(0,0) = b(0,0);
571 m_small(0,1) = b(0,1);
572 m_small(1,0) = b(1,0);
573 m_small(1,1) = b(1,1);
575 smallSVD(u_small, v_small, m_small);
577 b(0,0) = m_small(0,0);
578 b(0,1) = m_small(0,1);
579 b(1,0) = m_small(1,0);
580 b(1,1) = m_small(1,1);
582 for(
int i=0; i<3; ++i){
585 u(i,0) = alpha*u_small(0,0) + beta*u_small(1,0);
586 u(i,1) = alpha*u_small(0,1) + beta*u_small(1,1);
590 v(i,0) = alpha*v_small(0,0) + beta*v_small(1,0);
591 v(i,1) = alpha*v_small(0,1) + beta*v_small(1,1);
596 }
while( (b(0,1) != 0.0 || b(1,2) != 0.0) && it < max_it);
598 for(
int i=0; i<3; ++i){
601 for(
int j=0; j<3; ++j){
609 template <
class Float>
611 Matrix<complex<Float>, 3> &v,
Float singular_values[3])
613 getRealBidiagMatrix<Float>(m, u, v);
617 for(
int i=0; i<3; ++i){
618 for(
int j=0; j<3; ++j){
619 bd(i,j) = (
conj(u)*m*(v))(i,j).real();
623 bdSVD(u_real, v_real, bd, 500);
624 for(
int i=0; i<3; ++i){
625 singular_values[i] = bd(i,i);
634 #undef INVALID_DOUBLE
void mat(void *out, void **link, void *in, int dagger_bit, int mu, QudaPrecision sPrecision, QudaPrecision gPrecision)
Matrix< N, std::complex< T > > conj(const Matrix< N, std::complex< T > > &mat)
double norm1(const ColorSpinorField &b)
double norm2(const ColorSpinorField &a)
__device__ __host__ Matrix< complex< Float >, Nc > outerProduct(const ColorSpinor< Float, Nc, 1 > &a, const ColorSpinor< Float, Nc, 1 > &b)
__host__ __device__ ValueType log(ValueType x)
std::complex< double > Complex
__host__ __device__ ValueType sqrt(ValueType x)
__device__ __host__ void setIdentity(Matrix< T, N > *m)
__host__ __device__ ValueType exp(ValueType x)
__device__ __host__ void copyColumn(const Matrix< T, N > &m, int c, Array< T, N > *a)
FloatingPoint< float > Float
DEVICEHOST void getLambdaMax(const Matrix< Real, 3 > &b, Real &lambda_max)
DEVICEHOST float getNorm(const Array< complex< float >, 3 > &a)
DEVICEHOST void computeSVD(const Matrix< complex< Float >, 3 > &m, Matrix< complex< Float >, 3 > &u, Matrix< complex< Float >, 3 > &v, Float singular_values[3])
DEVICEHOST void getGivensRotation(const Real &alpha, const Real &beta, Real &c, Real &s)
DEVICEHOST PromoteTypeId< T, U >::type quadSum(const T &a, const U &b)
DEVICEHOST void getRealBidiagMatrix(const Matrix< complex< Float >, 3 > &mat, Matrix< complex< Float >, 3 > &u, Matrix< complex< Float >, 3 > &v)
DEVICEHOST std::remove_reference< decltype(Cmplx::x)>::type cabs(const Cmplx &z)
DEVICEHOST void smallSVD(Matrix< Real, 2 > &u, Matrix< Real, 2 > &v, Matrix< Real, 2 > &m)
DEVICEHOST void swap(Real &a, Real &b)
DEVICEHOST void assignGivensRotation(const Real &c, const Real &s, Matrix< Real, 2 > &m)
DEVICEHOST void accumGivensRotation(int index, const Real &c, const Real &s, Matrix< Real, 3 > &m)
DEVICEHOST void bdSVD(Matrix< Real, 3 > &u, Matrix< Real, 3 > &v, Matrix< Real, 3 > &b, int max_it)
DEVICEHOST auto constructHHMat(const T &tau, const V &v)