4 #define DEVICEHOST __device__ __host__
6 #define LOG2 0.69314718055994530942
13 typename RealTypeId<Cmplx>::Type
cabs(
const Cmplx & z)
15 typename RealTypeId<Cmplx>::Type max, ratio, square;
16 if(fabs(z.x) > fabs(z.y)){ max = z.x; ratio = z.y/max; }
else{ max=z.y; ratio = z.x/max; }
17 square = max*max*(1.0 + ratio*ratio);
31 template<
class T,
class U>
33 typename PromoteTypeId<T,U>::Type ratio, square, max;
34 if(fabs(a) > fabs(b)){ max = a; ratio = b/a; }
else{ max=b; ratio = a/b; }
35 square = max*max*(1.0 + ratio*ratio);
53 float temp1, temp2, temp3;
63 double getNorm(
const Array<double2,3>& a){
64 double temp1, temp2, temp3;
80 temp2 =
Conj(tau)*temp1;
93 Real m11 = b(1,1)*b(1,1) + b(0,1)*b(0,1);
94 Real m22 = b(2,2)*b(2,2) + b(1,2)*b(1,2);
95 Real m12 = b(1,1)*b(1,2);
96 Real dm = (m11 - m22)/2.0;
100 lambda_max = m22 - (m12*m12)/(norm1 + dm);
102 lambda_max = m22 + (m12*m12)/(norm1 - dm);
115 }
else if( fabs(beta) > fabs(alpha) ){
117 s = rsqrt(1.0 + ratio*ratio);
121 c = rsqrt(1.0 + ratio*ratio);
130 int index_p1 = index+1;
132 for(
int i=0; i<3; ++i){
134 beta = m(i,index_p1);
135 m(i,index) = alpha*c - beta*
s;
136 m(i,index_p1) = alpha*s + beta*c;
173 m(0,0) = m(0,1)*c - m(1,1)*
s;
186 }
else if( m(1,1) == 0.0 ){
190 m(0,0) = m(0,0)*c - m(0,1)*
s;
196 }
else if( m(0,1) != 0.0 ){
199 Real abs01 = fabs(m(0,1));
200 Real abs11 = fabs(m(1,1));
202 if( abs01 > abs11 ){ min = abs11; max = abs01; }
203 else { min = abs01; max = abs11; }
206 Real ratio = min/max;
207 Real alpha = 2.0*
log(max) +
log(1.0 + ratio*ratio);
209 Real abs00 = fabs(m(0,0));
210 Real beta = 2.0*
log(abs00);
217 temp = alpha +
log(1.0 -
exp(beta-alpha));
220 temp = beta +
log(1.0 -
exp(alpha-beta));
223 temp = sign*
exp(temp);
225 if( m(0,0) < 0.0 ){ temp *= -1.0; }
226 if( m(0,1) < 0.0 ){ temp *= -1.0; }
231 temp = 1.0/(temp + beta);
233 temp = 1.0/(temp - beta);
244 p(0,0) = c*m(0,0) - s*m(0,1);
246 p(0,1) = s*m(0,0) + c*m(0,1);
252 alpha =
quadSum(p(0,0),p(1,0));
265 m(0,0) = p(0,0)*c - s*p(1,0);
266 m(1,1) = p(0,1)*s + c*p(1,1);
277 template<
class Cmplx>
288 const Cmplx COMPLEX_UNITY = makeComplex<Cmplx>(1,0);
289 const Cmplx COMPLEX_ZERO = makeComplex<Cmplx>(0,0);
293 typename RealTypeId<Cmplx>::Type
x =
cabs(
mat(1,0));
294 typename RealTypeId<Cmplx>::Type
y =
cabs(
mat(2,0));
295 typename RealTypeId<Cmplx>::Type norm1 =
quadSum(x,y);
296 typename RealTypeId<Cmplx>::Type beta;
299 if(norm1 == 0 &&
mat(0,0).y == 0){
302 Array<Cmplx,3> temp_vec;
307 if(
mat(0,0).x > 0.0){ beta = -beta; }
309 w.x =
mat(0,0).x - beta;
315 vec[0] = COMPLEX_UNITY;
316 vec[1] =
mat(1,0)*w/norm1;
317 vec[2] =
mat(2,0)*w/norm1;
320 tau.x = (beta -
mat(0,0).x)/beta;
321 tau.y =
mat(0,0).y/beta;
329 typename RealTypeId<Cmplx>::Type
norm2 =
cabs(p(0,2));
330 if(norm2 != 0.0 || p(0,1).y != 0.0){
331 norm1 =
cabs(p(0,1));
333 vec[0] = COMPLEX_ZERO;
334 vec[1] = COMPLEX_UNITY;
336 if( p(0,1).x > 0.0 ){ beta = -beta; }
341 z =
Conj(p(0,2))/norm1;
344 tau.x = (beta - p(0,1).x)/beta;
345 tau.y = (- p(0,1).y)/beta;
353 norm2 =
cabs(p(2,1));
354 if(norm2 != 0.0 || p(1,1).y != 0.0){
355 norm1 =
cabs(p(1,1));
359 vec[0] = COMPLEX_ZERO;
360 vec[1] = COMPLEX_UNITY;
362 if( p(1,1).x > 0 ){ beta = -beta; }
363 w.x = p(1,1).x - beta;
370 tau.x = (beta - p(1,1).x)/beta;
371 tau.y = p(1,1).y/beta;
383 if( p(1,2).y != 0.0 ){
385 if( p(1,2).x > 0.0 ){ beta = -beta; }
386 temp(2,2) =
Conj(p(1,2))/beta;
387 p(2,2) = p(2,2)*temp(2,2);
393 if( p(2,2).y != 0.0 ){
395 if( p(2,2).x > 0.0 ){ beta = -beta; }
396 temp(2,2) = p(2,2)/beta;
420 if( fabs(b(0,1)) <
SVDPREC*( fabs(b(0,0)) + fabs(b(1,1)) ) ){ b(0,1) = 0.0; }
421 if( fabs(b(1,2)) <
SVDPREC*( fabs(b(0,0)) + fabs(b(2,2)) ) ){ b(1,2) = 0.0; }
423 if( b(0,1) != 0.0 && b(1,2) != 0.0 ){
428 for(
int i=0; i<3; ++i){
431 u(i,0) = alpha*c - beta*
s;
432 u(i,1) = alpha*s + beta*c;
435 b(1,1) = b(0,1)*s + b(1,1)*c;
443 for(
int i=0; i<3; ++i){
446 u(i,0) = alpha*c - beta*
s;
447 u(i,2) = alpha*s + beta*c;
449 b(2,2) = b(0,2)*s + b(2,2)*c;
452 }
else if( b(1,1) == 0.0 ){
455 for(
int i=0; i<3; ++i){
458 u(i,1) = alpha*c - beta*
s;
459 u(i,2) = alpha*s + beta*c;
461 b(2,2) = b(1,2)*s + b(2,2)*c;
464 }
else if( b(2,2) == 0.0 ){
467 for(
int i=0; i<3; ++i){
470 v(i,1) = alpha*c + beta*
s;
471 v(i,2) = -alpha*s + beta*c;
473 b(1,1) = b(1,1)*c + b(1,2)*
s;
481 for(
int i=0; i<3; ++i){
484 v(i,0) = alpha*c + beta*
s;
485 v(i,2) = -alpha*s + beta*c;
487 b(0,0) = b(0,0)*c + b(0,2)*
s;
495 alpha = b(0,0)*b(0,0) - lambda_max;
496 beta = b(0,0)*b(0,1);
506 b(0,0) = alpha*c - beta*
s;
507 b(0,1) = alpha*s + beta*c;
517 b(0,0) = b(0,0)*c - b(1,0)*
s;
520 b(0,1) = alpha*c - beta*
s;
521 b(1,1) = alpha*s + beta*c;
533 b(0,1) = b(0,1)*c - b(0,2)*
s;
537 b(1,1) = alpha*c - beta*
s;
538 b(1,2) = alpha*s + beta*c;
549 b(1,1) = b(1,1)*c - b(2,1)*
s;
552 b(1,2) = alpha*c - beta*
s;
553 b(2,2) = alpha*s + beta*c;
559 }
else if( b(0,1) == 0.0 ){
562 m_small(0,0) = b(1,1);
563 m_small(0,1) = b(1,2);
564 m_small(1,0) = b(2,1);
565 m_small(1,1) = b(2,2);
567 smallSVD(u_small, v_small, m_small);
569 b(1,1) = m_small(0,0);
570 b(1,2) = m_small(0,1);
571 b(2,1) = m_small(1,0);
572 b(2,2) = m_small(1,1);
575 for(
int i=0; i<3; ++i){
578 u(i,1) = alpha*u_small(0,0) + beta*u_small(1,0);
579 u(i,2) = alpha*u_small(0,1) + beta*u_small(1,1);
583 v(i,1) = alpha*v_small(0,0) + beta*v_small(1,0);
584 v(i,2) = alpha*v_small(0,1) + beta*v_small(1,1);
588 }
else if( b(1,2) == 0.0 ){
591 m_small(0,0) = b(0,0);
592 m_small(0,1) = b(0,1);
593 m_small(1,0) = b(1,0);
594 m_small(1,1) = b(1,1);
596 smallSVD(u_small, v_small, m_small);
598 b(0,0) = m_small(0,0);
599 b(0,1) = m_small(0,1);
600 b(1,0) = m_small(1,0);
601 b(1,1) = m_small(1,1);
603 for(
int i=0; i<3; ++i){
606 u(i,0) = alpha*u_small(0,0) + beta*u_small(1,0);
607 u(i,1) = alpha*u_small(0,1) + beta*u_small(1,1);
611 v(i,0) = alpha*v_small(0,0) + beta*v_small(1,0);
612 v(i,1) = alpha*v_small(0,1) + beta*v_small(1,1);
617 }
while( (b(0,1) != 0.0 || b(1,2) != 0.0) && it < max_it);
620 for(
int i=0; i<3; ++i){
623 for(
int j=0; j<3; ++j){
633 template<
class Cmplx>
638 typename RealTypeId<Cmplx>::Type singular_values[3])
641 getRealBidiagMatrix<Cmplx>(m, u, v);
644 for(
int i=0; i<3; ++i){
645 for(
int j=0; j<3; ++j){
646 bd(i,j) = (
conj(u)*m*(v))(i,j).
x;
650 bdSVD(u_real, v_real, bd, 500);
651 for(
int i=0; i<3; ++i){
652 singular_values[i] = bd(i,i);
664 #endif // _SVD_QUDA_H
DEVICEHOST float getNorm(const Array< float2, 3 > &a)
Matrix< N, std::complex< T > > conj(const Matrix< N, std::complex< T > > &mat)
__host__ __device__ ValueType exp(ValueType x)
__host__ __device__ ValueType sqrt(ValueType x)
DEVICEHOST void getGivensRotation(const Real &alpha, const Real &beta, Real &c, Real &s)
DEVICEHOST void getLambdaMax(const Matrix< Real, 3 > &b, Real &lambda_max)
void mat(void *out, void **fatlink, void **longlink, void *in, double kappa, int dagger_bit, QudaPrecision sPrecision, QudaPrecision gPrecision)
__device__ __host__ void outerProd(const Array< T, N > &a, const Array< T, N > &b, Matrix< T, N > *m)
DEVICEHOST void smallSVD(Matrix< Real, 2 > &u, Matrix< Real, 2 > &v, Matrix< Real, 2 > &m)
__device__ __host__ Cmplx Conj(const Cmplx &a)
__device__ __host__ void copyColumn(const Matrix< T, N > &m, int c, Array< T, N > *a)
double norm2(Float *v, int len)
__device__ __host__ void setIdentity(Matrix< T, N > *m)
DEVICEHOST PromoteTypeId< T, U >::Type quadSum(const T &a, const U &b)
__host__ __device__ ValueType log(ValueType x)
DEVICEHOST void swap(Real &a, Real &b)
DEVICEHOST void assignGivensRotation(const Real &c, const Real &s, Matrix< Real, 2 > &m)
DEVICEHOST RealTypeId< Cmplx >::Type cabs(const Cmplx &z)
DEVICEHOST void constructHHMat(const T &tau, const Array< T, 3 > &v, Matrix< T, 3 > &hh)
DEVICEHOST void computeSVD(const Matrix< Cmplx, 3 > &m, Matrix< Cmplx, 3 > &u, Matrix< Cmplx, 3 > &v, typename RealTypeId< Cmplx >::Type singular_values[3])
DEVICEHOST void bdSVD(Matrix< Real, 3 > &u, Matrix< Real, 3 > &v, Matrix< Real, 3 > &b, int max_it)
DEVICEHOST void getRealBidiagMatrix(const Matrix< Cmplx, 3 > &mat, Matrix< Cmplx, 3 > &u, Matrix< Cmplx, 3 > &v)
DEVICEHOST void accumGivensRotation(int index, const Real &c, const Real &s, Matrix< Real, 3 > &m)