// Qualifier shorthand: expands to both `__device__` and `__host__`.
11 #define DEVICEHOST __device__ __host__
// Numeric value of the natural logarithm of 2; subtracted in log-space arithmetic further down.
13 #define LOG2 0.69314718055994530942
// NOTE(review): fragment; the enclosing signature is not visible in this view
// (presumably the complex-magnitude helper, since `cabs` is called elsewhere; confirm).
// The component of z with the larger magnitude becomes `max`; `ratio` is the other component over max.
// NOTE(review): if z.x and z.y are both zero the else branch evaluates z.x/max with max == 0
// (0/0, i.e. NaN if the components are floating point); confirm whether zero input can reach here.
23 if(fabs(z.x) > fabs(z.y)){ max = z.x; ratio = z.y/max; }
else{ max=z.y; ratio = z.x/max; }
// max^2 (1 + ratio^2) is algebraically z.x^2 + z.y^2.
24 square = max*max*(1.0 + ratio*ratio);
// NOTE(review): fragment; the function name and return statement are not visible in this view
// (presumably the two-argument `quadSum` used further down; confirm).
38 template<
class T,
class U>
// Same scaling scheme as above: the larger-magnitude argument becomes `max`, ratio is smaller over larger.
// NOTE(review): the else branch is also taken when |a| == |b|, so a == b == 0 evaluates a/b = 0/0
// (NaN if T and U are floating point); confirm whether zero input can reach here.
41 if(fabs(a) > fabs(b)){ max = a; ratio = b/a; }
else{ max=b; ratio = a/b; }
// max^2 (1 + ratio^2) is algebraically a^2 + b^2.
42 square = max*max*(1.0 + ratio*ratio);
// NOTE(review): fragments of two separate definitions whose signatures are not visible in this view.
// Three float scratch variables.
60 float temp1, temp2, temp3;
// Three double scratch variables (same names, different precision).
71 double temp1, temp2, temp3;
// NOTE(review): fragment; the enclosing definition is not visible in this view.
// temp1 is multiplied by the complex conjugate of tau.
87 temp2 =
conj(tau)*temp1;
// NOTE(review): fragment; the enclosing signature, the definition of norm1 and the condition
// selecting between the two lambda_max forms are not visible in this view.
// m11, m22 and m12 are sums and products of entries from rows 0..2, columns 1..2 of b.
// NOTE(review): if b is upper bidiagonal these equal the trailing 2x2 block of (b^T b); confirm.
100 Real m11 = b(1,1)*b(1,1) + b(0,1)*b(0,1);
101 Real m22 = b(2,2)*b(2,2) + b(1,2)*b(1,2);
102 Real m12 = b(1,1)*b(1,2);
// Half the difference of the two diagonal terms.
103 Real dm = (m11 - m22)/2.0;
// Two alternative forms; they differ only in the sign in front of the quotient and of dm.
107 lambda_max = m22 - (m12*m12)/(norm1 + dm);
109 lambda_max = m22 + (m12*m12)/(norm1 - dm);
// NOTE(review): fragment; the enclosing definition, the preceding branch and the definition of
// `ratio` are not visible in this view. The branch below is entered when |beta| > |alpha|.
122 }
else if( fabs(beta) > fabs(alpha) ){
// s and c are each assigned rsqrt(1 + ratio^2) on separate lines; the branch structure between
// them is not visible. rsqrt is presumably the reciprocal square root; confirm.
124 s = rsqrt(1.0 + ratio*ratio);
128 c = rsqrt(1.0 + ratio*ratio);
// NOTE(review): fragment; the enclosing signature and the assignment of alpha inside the loop
// are not visible in this view.
// Index of the column immediately after `index`.
137 int index_p1 = index+1;
// For each of the 3 rows, columns `index` and `index_p1` of m are replaced by combinations
// of alpha and beta weighted by c and s.
139 for(
int i=0; i<3; ++i){
141 beta = m(i,index_p1);
142 m(i,index) = alpha*c - beta*
s;
143 m(i,index_p1) = alpha*s + beta*c;
// NOTE(review): fragment; the enclosing signature, the first branch condition and the assignments
// of c, s and `sign` are not visible in this view (presumably the 2x2 routine called as `smallSVD`
// further down; confirm).
// m(0,0) is replaced by a combination of m(0,1) and m(1,1) weighted by c and s.
180 m(0,0) = m(0,1)*c - m(1,1)*
s;
193 }
else if( m(1,1) == 0.0 ){
// Case m(1,1) == 0: m(0,0) is replaced by a combination of m(0,0) and m(0,1) weighted by c and s.
197 m(0,0) = m(0,0)*c - m(0,1)*
s;
203 }
else if( m(0,1) != 0.0 ){
// Magnitudes of m(0,1) and m(1,1), ordered into (min, max).
206 Real abs01 = fabs(m(0,1));
207 Real abs11 = fabs(m(1,1));
209 if( abs01 > abs11 ){ min = abs11; max = abs01; }
210 else { min = abs01; max = abs11; }
// alpha = log( m(0,1)^2 + m(1,1)^2 ), formed as 2 log(max) + log(1 + (min/max)^2).
213 Real ratio = min/max;
214 Real alpha = 2.0*log(max) + log(1.0 + ratio*ratio);
// beta = log( m(0,0)^2 ).
216 Real abs00 = fabs(m(0,0));
217 Real beta = 2.0*log(abs00);
// Log-space difference: alpha + log(1 - exp(beta-alpha)) equals log( exp(alpha) - exp(beta) ).
224 temp = alpha + log(1.0 - exp(beta-alpha));
// Mirror form with alpha and beta swapped: log( exp(beta) - exp(alpha) ).
227 temp = beta + log(1.0 - exp(alpha-beta));
// Subtracting LOG2 + log|m(0,0)| + log|m(0,1)| divides by 2 |m(0,0)| |m(0,1)| in log space.
229 temp -=
LOG2 + log(abs00) + log(abs01);
// Back to linear scale, with the sign carried separately in `sign`.
230 temp = sign*exp(temp);
// One sign flip for each negative factor among m(0,0) and m(0,1).
232 if( m(0,0) < 0.0 ){ temp *= -1.0; }
233 if( m(0,1) < 0.0 ){ temp *= -1.0; }
// NOTE(review): two alternative reciprocals; beta has presumably been reassigned on lines not
// visible between 233 and 238; confirm.
238 temp = 1.0/(temp + beta);
240 temp = 1.0/(temp - beta);
// First row of m combined with weights (c, s) and stored in p.
251 p(0,0) = c*m(0,0) - s*m(0,1);
253 p(0,1) = s*m(0,0) + c*m(0,1);
// alpha is recomputed from the first column of p via quadSum.
259 alpha =
quadSum(p(0,0),p(1,0));
// Diagonal of m rebuilt from p with the current (c, s).
272 m(0,0) = p(0,0)*c - s*p(1,0);
273 m(1,1) = p(0,1)*s + c*p(1,1);
// NOTE(review): fragment; the function name, signature and the assignments of norm1, norm2, beta,
// w and z are not visible in this view (presumably `getRealBidiagMatrix`, which is called as a
// template on Cmplx further down; confirm).
284 template<
class Cmplx>
// Complex constants built with makeComplex<Cmplx>: (1,0) and (0,0).
295 const Cmplx COMPLEX_UNITY = makeComplex<Cmplx>(1,0);
296 const Cmplx COMPLEX_ZERO = makeComplex<Cmplx>(0,0);
// Condition: norm1 is zero and the .y member of mat(0,0) is zero.
306 if(norm1 == 0 &&
mat(0,0).y == 0){
// beta is negated when the .x member of mat(0,0) is positive.
314 if(
mat(0,0).x > 0.0){ beta = -beta; }
// vec[0] is set to unity; vec[1] and vec[2] are mat(1,0) and mat(2,0) scaled by w/norm1.
321 vec[0] = COMPLEX_UNITY;
322 vec[1] =
mat(1,0)*w/norm1;
323 vec[2] =
mat(2,0)*w/norm1;
// tau is assembled member by member from beta and mat(0,0).
326 tau.x = (beta -
mat(0,0).x)/beta;
327 tau.y =
mat(0,0).y/beta;
// Taken when norm2 is non-zero or p(0,1) has a non-zero .y member.
336 if(norm2 != 0.0 || p(0,1).y != 0.0){
337 norm1 =
cabs(p(0,1));
// Here the leading entry of vec is zero and the second entry is unity.
339 vec[0] = COMPLEX_ZERO;
340 vec[1] = COMPLEX_UNITY;
// beta is negated when the .x member of p(0,1) is positive.
342 if( p(0,1).x > 0.0 ){ beta = -beta; }
346 z =
conj(p(0,2))/norm1;
// tau computed with complex arithmetic on p(0,1) rather than member by member.
349 tau = (beta - p(0,1))/beta;
357 norm2 =
cabs(p(2,1));
// Taken when norm2 is non-zero or p(1,1) has a non-zero .y member.
358 if(norm2 != 0.0 || p(1,1).y != 0.0){
359 norm1 =
cabs(p(1,1));
363 vec[0] = COMPLEX_ZERO;
364 vec[1] = COMPLEX_UNITY;
// beta is negated when the .x member of p(1,1) is positive.
366 if( p(1,1).x > 0 ){ beta = -beta; }
373 tau.x = (beta - p(1,1).x)/beta;
374 tau.y = p(1,1).y/beta;
// Taken only when p(1,2) has a non-zero .y member: p(2,2) is multiplied by conj(p(1,2))/beta.
386 if( p(1,2).y != 0.0 ){
388 if( p(1,2).x > 0.0 ){ beta = -beta; }
389 temp(2,2) =
conj(p(1,2))/beta;
390 p(2,2) = p(2,2)*temp(2,2);
// Taken only when p(2,2) has a non-zero .y member: temp(2,2) becomes p(2,2)/beta.
396 if( p(2,2).y != 0.0 ){
398 if( p(2,2).x > 0.0 ){ beta = -beta; }
399 temp(2,2) = p(2,2)/beta;
// NOTE(review): fragment; the enclosing signature, the loop head, and the assignments of c, s,
// alpha and beta inside the loops are not visible in this view (presumably `bdSVD`, which is called
// with four arguments further down; confirm).
// b(0,1) is set to zero when its magnitude is below SVDPREC times (|b(0,0)| + |b(1,1)|).
423 if( fabs(b(0,1)) <
SVDPREC*( fabs(b(0,0)) + fabs(b(1,1)) ) ){ b(0,1) = 0.0; }
// Same test for b(1,2).
// NOTE(review): the scale here is |b(0,0)| + |b(2,2)|, whereas the test above uses the two diagonal
// entries adjacent to the tested element; the adjacent pair here would be b(1,1) and b(2,2).
// Possible index slip; confirm against the intended algorithm before changing.
424 if( fabs(b(1,2)) <
SVDPREC*( fabs(b(0,0)) + fabs(b(2,2)) ) ){ b(1,2) = 0.0; }
// Case: both superdiagonal entries are non-zero.
426 if( b(0,1) != 0.0 && b(1,2) != 0.0 ){
// Columns 0 and 1 of u updated row by row with weights (c, s).
431 for(
int i=0; i<3; ++i){
434 u(i,0) = alpha*c - beta*
s;
435 u(i,1) = alpha*s + beta*c;
438 b(1,1) = b(0,1)*s + b(1,1)*c;
// Columns 0 and 2 of u updated row by row with weights (c, s).
446 for(
int i=0; i<3; ++i){
449 u(i,0) = alpha*c - beta*
s;
450 u(i,2) = alpha*s + beta*c;
452 b(2,2) = b(0,2)*s + b(2,2)*c;
455 }
else if( b(1,1) == 0.0 ){
// Case b(1,1) == 0: columns 1 and 2 of u updated.
458 for(
int i=0; i<3; ++i){
461 u(i,1) = alpha*c - beta*
s;
462 u(i,2) = alpha*s + beta*c;
464 b(2,2) = b(1,2)*s + b(2,2)*c;
467 }
else if( b(2,2) == 0.0 ){
// Case b(2,2) == 0: columns 1 and 2 of v updated; note the sign of the s terms is opposite
// to the u updates above.
470 for(
int i=0; i<3; ++i){
473 v(i,1) = alpha*c + beta*
s;
474 v(i,2) = -alpha*s + beta*c;
476 b(1,1) = b(1,1)*c + b(1,2)*
s;
// Columns 0 and 2 of v updated with the same sign pattern.
484 for(
int i=0; i<3; ++i){
487 v(i,0) = alpha*c + beta*
s;
488 v(i,2) = -alpha*s + beta*c;
490 b(0,0) = b(0,0)*c + b(0,2)*
s;
// alpha and beta are seeded from b(0,0)^2 - lambda_max and b(0,0) b(0,1).
498 alpha = b(0,0)*b(0,0) - lambda_max;
499 beta = b(0,0)*b(0,1);
// Sequence of (c, s)-weighted updates on entries of b; the lines that recompute c, s, alpha and
// beta between these updates are not visible.
509 b(0,0) = alpha*c - beta*
s;
510 b(0,1) = alpha*s + beta*c;
520 b(0,0) = b(0,0)*c - b(1,0)*
s;
523 b(0,1) = alpha*c - beta*
s;
524 b(1,1) = alpha*s + beta*c;
536 b(0,1) = b(0,1)*c - b(0,2)*
s;
540 b(1,1) = alpha*c - beta*
s;
541 b(1,2) = alpha*s + beta*c;
552 b(1,1) = b(1,1)*c - b(2,1)*
s;
555 b(1,2) = alpha*c - beta*
s;
556 b(2,2) = alpha*s + beta*c;
562 }
else if( b(0,1) == 0.0 ){
// Case b(0,1) == 0: rows and columns 1..2 of b are copied into m_small, passed to smallSVD, and copied back.
565 m_small(0,0) = b(1,1);
566 m_small(0,1) = b(1,2);
567 m_small(1,0) = b(2,1);
568 m_small(1,1) = b(2,2);
570 smallSVD(u_small, v_small, m_small);
572 b(1,1) = m_small(0,0);
573 b(1,2) = m_small(0,1);
574 b(2,1) = m_small(1,0);
575 b(2,2) = m_small(1,1);
// Columns 1 and 2 of u and v are multiplied on the right by u_small and v_small respectively.
578 for(
int i=0; i<3; ++i){
581 u(i,1) = alpha*u_small(0,0) + beta*u_small(1,0);
582 u(i,2) = alpha*u_small(0,1) + beta*u_small(1,1);
586 v(i,1) = alpha*v_small(0,0) + beta*v_small(1,0);
587 v(i,2) = alpha*v_small(0,1) + beta*v_small(1,1);
591 }
else if( b(1,2) == 0.0 ){
// Case b(1,2) == 0: the same procedure on rows and columns 0..1 of b.
594 m_small(0,0) = b(0,0);
595 m_small(0,1) = b(0,1);
596 m_small(1,0) = b(1,0);
597 m_small(1,1) = b(1,1);
599 smallSVD(u_small, v_small, m_small);
601 b(0,0) = m_small(0,0);
602 b(0,1) = m_small(0,1);
603 b(1,0) = m_small(1,0);
604 b(1,1) = m_small(1,1);
// Columns 0 and 1 of u and v are multiplied on the right by u_small and v_small respectively.
606 for(
int i=0; i<3; ++i){
609 u(i,0) = alpha*u_small(0,0) + beta*u_small(1,0);
610 u(i,1) = alpha*u_small(0,1) + beta*u_small(1,1);
614 v(i,0) = alpha*v_small(0,0) + beta*v_small(1,0);
615 v(i,1) = alpha*v_small(0,1) + beta*v_small(1,1);
// The loop repeats while either superdiagonal entry of b is non-zero and `it` is below max_it.
619 }
while( (b(0,1) != 0.0 || b(1,2) != 0.0) && it < max_it);
// Nested loops over all 3x3 index pairs; the loop body is not visible in this view.
622 for(
int i=0; i<3; ++i){
625 for(
int j=0; j<3; ++j){
// NOTE(review): fragment; the function name and signature are not visible in this view.
635 template<
class Cmplx>
// First stage: getRealBidiagMatrix is called with m, u and v.
643 getRealBidiagMatrix<Cmplx>(m, u, v);
// bd(i,j) takes the .x member of the product conj(u)*m*v (presumably the real part; confirm).
// The full product expression appears inside the inner loop, once per (i,j).
646 for(
int i=0; i<3; ++i){
647 for(
int j=0; j<3; ++j){
648 bd(i,j) = (
conj(u)*m*(v))(i,j).
x;
// Second stage: bdSVD on bd. NOTE(review): the literal 40 is presumably the iteration cap; confirm.
652 bdSVD(u_real, v_real, bd, 40);
// The diagonal of bd is copied into singular_values.
653 for(
int i=0; i<3; ++i){
654 singular_values[i] = bd(i,i);
666 #endif // _SVD_QUDA_H