6 #define DEVICEHOST __device__ __host__ 8 #define LOG2 0.69314718055994530942 9 #define INVALID_DOUBLE (-DBL_MAX) 17 typename std::remove_reference<decltype(Cmplx::x)>::type
cabs(
const Cmplx & z)
19 typedef typename std::remove_reference<decltype(Cmplx::x)>::type real;
20 real max, ratio, square;
21 if(fabs(z.x) > fabs(z.y)){ max = z.x; ratio = z.y/max; }
else{ max=z.y; ratio = z.x/max; }
22 square = (max != 0.0) ? max*max*(1.0 + ratio*ratio) : 0.0;
27 template<
class T,
class U>
29 typename PromoteTypeId<T,U>::Type ratio, square, max;
30 if (fabs(a) > fabs(b)) { max = a; ratio = b/a; }
else { max=b; ratio = a/b; }
31 square = (max != 0.0) ? max*max*(1.0 + ratio*ratio) : 0.0;
38 inline float getNorm(
const Array<complex<float>,3>& a){
39 float temp1, temp2, temp3;
49 inline double getNorm(
const Array<complex<double>,3>& a){
50 double temp1, temp2, temp3;
66 temp2 =
conj(tau)*temp1;
79 Real m11 = b(1,1)*b(1,1) + b(0,1)*b(0,1);
80 Real m22 = b(2,2)*b(2,2) + b(1,2)*b(1,2);
81 Real m12 = b(1,1)*b(1,2);
82 Real dm = (m11 - m22) * 0.5;
86 lambda_max = m22 - (m12*m12)/(norm1 + dm);
88 lambda_max = m22 + (m12*m12)/(norm1 - dm);
101 }
else if( fabs(beta) > fabs(alpha) ){
103 s = rsqrt(1.0 + ratio*ratio);
107 c = rsqrt(1.0 + ratio*ratio);
116 int index_p1 = index+1;
118 for(
int i=0; i<3; ++i){
120 beta = m(i,index_p1);
121 m(i,index) = alpha*c - beta*
s;
122 m(i,index_p1) = alpha*s + beta*c;
159 m(0,0) = m(0,1)*c - m(1,1)*
s;
172 }
else if( m(1,1) == 0.0 ){
176 m(0,0) = m(0,0)*c - m(0,1)*
s;
182 }
else if( m(0,1) != 0.0 ){
185 Real abs01 = fabs(m(0,1));
186 Real abs11 = fabs(m(1,1));
188 if( abs01 > abs11 ){ min = abs11; max = abs01; }
189 else { min = abs01; max = abs11; }
192 Real ratio = min/max;
193 Real alpha = 2.0*
log(max) +
log(1.0 + ratio*ratio);
195 Real abs00 = fabs(m(0,0));
196 Real beta = 2.0*
log(abs00);
203 temp = alpha +
log(1.0 -
exp(beta-alpha));
206 temp = beta +
log(1.0 -
exp(alpha-beta));
209 temp = sign*
exp(temp);
211 if( m(0,0) < 0.0 ){ temp *= -1.0; }
212 if( m(0,1) < 0.0 ){ temp *= -1.0; }
215 temp = 1.0/(temp + beta);
225 p(0,0) = c*m(0,0) - s*m(0,1);
227 p(0,1) = s*m(0,0) + c*m(0,1);
233 alpha =
quadSum(p(0,0),p(1,0));
246 m(0,0) = p(0,0)*c - s*p(1,0);
247 m(1,1) = p(0,1)*s + c*p(1,1);
258 template<
class Float>
262 typedef complex<Float>
Complex;
265 Array<Complex,3> vec;
269 const Complex COMPLEX_UNITY(1.0,0.0);
270 const Complex COMPLEX_ZERO = 0.0;
286 if(norm1 == 0 &&
mat(0,0).y == 0){
289 Array<Complex,3> temp_vec;
294 w.x =
mat(0,0).x - beta;
296 Float norm1_inv = 1.0/
cabs(w);
297 w =
conj(w)*norm1_inv;
300 vec[0] = COMPLEX_UNITY;
301 vec[1] =
mat(1,0)*w*norm1_inv;
302 vec[2] =
mat(2,0)*w*norm1_inv;
305 tau.x = (beta -
mat(0,0).x)/beta;
306 tau.y =
mat(0,0).y/beta;
315 if(norm2 != 0.0 || p(0,1).y != 0.0){
316 norm1 =
cabs(p(0,1));
317 beta = (p(0,1).x > 0.0) ? -
quadSum(norm1,norm2) :
quadSum(norm1,norm2);
318 vec[0] = COMPLEX_ZERO;
319 vec[1] = COMPLEX_UNITY;
323 Float norm1_inv = 1.0/
cabs(w);
324 w =
conj(w)*norm1_inv;
325 z =
conj(p(0,2))*norm1_inv;
328 tau.x = (beta - p(0,1).x)/beta;
329 tau.y = (- p(0,1).y)/beta;
337 norm2 =
cabs(p(2,1));
338 if(norm2 != 0.0 || p(1,1).y != 0.0){
339 norm1 =
cabs(p(1,1));
340 beta = (p(1,1).x > 0) ? -
quadSum(norm1,norm2) :
quadSum(norm1,norm2);
343 vec[0] = COMPLEX_ZERO;
344 vec[1] = COMPLEX_UNITY;
346 w.x = p(1,1).x - beta;
348 Float norm1_inv = 1.0/
cabs(w);
349 w =
conj(w)*norm1_inv;
350 z = p(2,1)*norm1_inv;
353 tau.x = (beta - p(1,1).x)/beta;
354 tau.y = p(1,1).y/beta;
366 if( p(1,2).y != 0.0 ){
367 beta = p(1,2).x > 0.0 ? -
cabs(p(1,2)) :
cabs(p(1,2));
368 temp(2,2) =
conj(p(1,2))/beta;
369 p(2,2) = p(2,2)*temp(2,2);
374 if( p(2,2).y != 0.0 ){
375 beta = p(2,2).x > 0.0 ? -
cabs(p(2,2)) :
cabs(p(2,2));
376 temp(2,2) = p(2,2)/beta;
406 if( fabs(b(0,1)) <
SVDPREC*( fabs(b(0,0)) + fabs(b(1,1)) ) ){ b(0,1) = 0.0; }
407 if( fabs(b(1,2)) <
SVDPREC*( fabs(b(0,0)) + fabs(b(2,2)) ) ){ b(1,2) = 0.0; }
409 if( b(0,1) != 0.0 && b(1,2) != 0.0 ){
414 for(
int i=0; i<3; ++i){
417 u(i,0) = alpha*c - beta*
s;
418 u(i,1) = alpha*s + beta*c;
421 b(1,1) = b(0,1)*s + b(1,1)*c;
429 for(
int i=0; i<3; ++i){
432 u(i,0) = alpha*c - beta*
s;
433 u(i,2) = alpha*s + beta*c;
435 b(2,2) = b(0,2)*s + b(2,2)*c;
438 }
else if( b(1,1) == 0.0 ){
441 for(
int i=0; i<3; ++i){
444 u(i,1) = alpha*c - beta*
s;
445 u(i,2) = alpha*s + beta*c;
447 b(2,2) = b(1,2)*s + b(2,2)*c;
450 }
else if( b(2,2) == 0.0 ){
453 for(
int i=0; i<3; ++i){
456 v(i,1) = alpha*c + beta*
s;
457 v(i,2) = -alpha*s + beta*c;
459 b(1,1) = b(1,1)*c + b(1,2)*
s;
467 for(
int i=0; i<3; ++i){
470 v(i,0) = alpha*c + beta*
s;
471 v(i,2) = -alpha*s + beta*c;
473 b(0,0) = b(0,0)*c + b(0,2)*
s;
481 alpha = b(0,0)*b(0,0) - lambda_max;
482 beta = b(0,0)*b(0,1);
492 b(0,0) = alpha*c - beta*
s;
493 b(0,1) = alpha*s + beta*c;
503 b(0,0) = b(0,0)*c - b(1,0)*
s;
506 b(0,1) = alpha*c - beta*
s;
507 b(1,1) = alpha*s + beta*c;
519 b(0,1) = b(0,1)*c - b(0,2)*
s;
523 b(1,1) = alpha*c - beta*
s;
524 b(1,2) = alpha*s + beta*c;
535 b(1,1) = b(1,1)*c - b(2,1)*
s;
538 b(1,2) = alpha*c - beta*
s;
539 b(2,2) = alpha*s + beta*c;
545 }
else if( b(0,1) == 0.0 ){
548 m_small(0,0) = b(1,1);
549 m_small(0,1) = b(1,2);
550 m_small(1,0) = b(2,1);
551 m_small(1,1) = b(2,2);
553 smallSVD(u_small, v_small, m_small);
555 b(1,1) = m_small(0,0);
556 b(1,2) = m_small(0,1);
557 b(2,1) = m_small(1,0);
558 b(2,2) = m_small(1,1);
561 for(
int i=0; i<3; ++i){
564 u(i,1) = alpha*u_small(0,0) + beta*u_small(1,0);
565 u(i,2) = alpha*u_small(0,1) + beta*u_small(1,1);
569 v(i,1) = alpha*v_small(0,0) + beta*v_small(1,0);
570 v(i,2) = alpha*v_small(0,1) + beta*v_small(1,1);
574 }
else if( b(1,2) == 0.0 ){
577 m_small(0,0) = b(0,0);
578 m_small(0,1) = b(0,1);
579 m_small(1,0) = b(1,0);
580 m_small(1,1) = b(1,1);
582 smallSVD(u_small, v_small, m_small);
584 b(0,0) = m_small(0,0);
585 b(0,1) = m_small(0,1);
586 b(1,0) = m_small(1,0);
587 b(1,1) = m_small(1,1);
589 for(
int i=0; i<3; ++i){
592 u(i,0) = alpha*u_small(0,0) + beta*u_small(1,0);
593 u(i,1) = alpha*u_small(0,1) + beta*u_small(1,1);
597 v(i,0) = alpha*v_small(0,0) + beta*v_small(1,0);
598 v(i,1) = alpha*v_small(0,1) + beta*v_small(1,1);
603 }
while( (b(0,1) != 0.0 || b(1,2) != 0.0) && it < max_it);
605 for(
int i=0; i<3; ++i){
608 for(
int j=0; j<3; ++j){
618 template<
class Float>
621 Matrix<complex<Float>,3>& v, Float singular_values[3])
623 getRealBidiagMatrix<Float>(m, u, v);
627 for(
int i=0; i<3; ++i){
628 for(
int j=0; j<3; ++j){
629 bd(i,j) = (
conj(u)*m*(v))(i,j).real();
633 bdSVD(u_real, v_real, bd, 500);
634 for(
int i=0; i<3; ++i){
635 singular_values[i] = bd(i,i);
646 #undef INVALID_DOUBLE 648 #endif // _SVD_QUDA_H Matrix< N, std::complex< T > > conj(const Matrix< N, std::complex< T > > &mat)
__host__ __device__ ValueType exp(ValueType x)
__host__ __device__ ValueType sqrt(ValueType x)
DEVICEHOST void getGivensRotation(const Real &alpha, const Real &beta, Real &c, Real &s)
DEVICEHOST void getLambdaMax(const Matrix< Real, 3 > &b, Real &lambda_max)
DEVICEHOST float getNorm(const Array< complex< float >, 3 > &a)
__device__ __host__ void outerProd(const Array< T, N > &a, const Array< T, N > &b, Matrix< T, N > *m)
DEVICEHOST void computeSVD(const Matrix< complex< Float >, 3 > &m, Matrix< complex< Float >, 3 > &u, Matrix< complex< Float >, 3 > &v, Float singular_values[3])
DEVICEHOST void smallSVD(Matrix< Real, 2 > &u, Matrix< Real, 2 > &v, Matrix< Real, 2 > &m)
DEVICEHOST std::remove_reference< decltype(Cmplx::x)>::type cabs(const Cmplx &z)
__device__ __host__ void copyColumn(const Matrix< T, N > &m, int c, Array< T, N > *a)
double norm2(Float *v, int len)
std::complex< double > Complex
__device__ __host__ void setIdentity(Matrix< T, N > *m)
DEVICEHOST void getRealBidiagMatrix(const Matrix< complex< Float >, 3 > &mat, Matrix< complex< Float >, 3 > &u, Matrix< complex< Float >, 3 > &v)
DEVICEHOST PromoteTypeId< T, U >::Type quadSum(const T &a, const U &b)
__host__ __device__ ValueType log(ValueType x)
static int index(int ndim, const int *dims, const int *x)
double norm1(const ColorSpinorField &b)
DEVICEHOST void swap(Real &a, Real &b)
DEVICEHOST void assignGivensRotation(const Real &c, const Real &s, Matrix< Real, 2 > &m)
DEVICEHOST void constructHHMat(const T &tau, const Array< T, 3 > &v, Matrix< T, 3 > &hh)
DEVICEHOST void bdSVD(Matrix< Real, 3 > &u, Matrix< Real, 3 > &v, Matrix< Real, 3 > &b, int max_it)
void mat(void *out, void **link, void *in, int dagger_bit, int mu, QudaPrecision sPrecision, QudaPrecision gPrecision)
DEVICEHOST void accumGivensRotation(int index, const Real &c, const Real &s, Matrix< Real, 3 > &m)