6 #define DEVICEHOST __device__ __host__ 8 #define LOG2 0.69314718055994530942 9 #define INVALID_DOUBLE (-DBL_MAX) 17 typename std::remove_reference<decltype(Cmplx::x)>::type
cabs(
const Cmplx &
z)
19 typedef typename std::remove_reference<decltype(Cmplx::x)>::type real;
20 real max, ratio, square;
21 if(
fabs(
z.x) >
fabs(
z.y)){ max =
z.x; ratio =
z.y/max; }
else{ max=
z.y; ratio =
z.x/max; }
22 square = (max != 0.0) ? max*max*(1.0 + ratio*ratio) : 0.0;
27 template<
class T,
class U>
29 typename PromoteTypeId<T,U>::Type ratio, square, max;
30 if (
fabs(
a) >
fabs(
b)) { max =
a; ratio =
b/
a; }
else { max=
b; ratio =
a/
b; }
31 square = (max != 0.0) ? max*max*(1.0 + ratio*ratio) : 0.0;
38 inline float getNorm(
const Array<complex<float>,3>&
a){
39 float temp1, temp2, temp3;
49 inline double getNorm(
const Array<complex<double>,3>&
a){
50 double temp1, temp2, temp3;
66 temp2 =
conj(tau)*temp1;
79 Real m11 =
b(1,1)*
b(1,1) +
b(0,1)*
b(0,1);
80 Real m22 =
b(2,2)*
b(2,2) +
b(1,2)*
b(1,2);
81 Real m12 =
b(1,1)*
b(1,2);
82 Real dm = (m11 - m22) * 0.5;
86 lambda_max = m22 - (m12*m12)/(
norm1 + dm);
88 lambda_max = m22 + (m12*m12)/(
norm1 - dm);
101 }
else if(
fabs(beta) >
fabs(alpha) ){
103 s = rsqrt(1.0 + ratio*ratio);
107 c = rsqrt(1.0 + ratio*ratio);
116 int index_p1 =
index+1;
118 for(
int i=0;
i<3; ++
i){
120 beta = m(
i,index_p1);
122 m(
i,index_p1) = alpha*
s + beta*
c;
159 m(0,0) = m(0,1)*
c - m(1,1)*
s;
172 }
else if( m(1,1) == 0.0 ){
176 m(0,0) = m(0,0)*
c - m(0,1)*
s;
182 }
else if( m(0,1) != 0.0 ){
185 Real abs01 =
fabs(m(0,1));
186 Real abs11 =
fabs(m(1,1));
188 if( abs01 > abs11 ){ min = abs11; max = abs01; }
189 else { min = abs01; max = abs11; }
192 Real ratio = min/max;
193 Real alpha = 2.0*
log(max) +
log(1.0 + ratio*ratio);
195 Real abs00 =
fabs(m(0,0));
196 Real beta = 2.0*
log(abs00);
203 temp = alpha +
log(1.0 -
exp(beta-alpha));
206 temp = beta +
log(1.0 -
exp(alpha-beta));
211 if( m(0,0) < 0.0 ){ temp *= -1.0; }
212 if( m(0,1) < 0.0 ){ temp *= -1.0; }
215 temp = 1.0/(temp + beta);
225 p(0,0) =
c*m(0,0) -
s*m(0,1);
227 p(0,1) =
s*m(0,0) +
c*m(0,1);
246 m(0,0) =
p(0,0)*
c -
s*
p(1,0);
247 m(1,1) =
p(0,1)*
s +
c*
p(1,1);
258 template<
class Float>
262 typedef complex<Float>
Complex;
265 Array<Complex,3> vec;
269 const Complex COMPLEX_UNITY(1.0,0.0);
270 const Complex COMPLEX_ZERO = 0.0;
289 Array<Complex,3> temp_vec;
294 w.x =
mat(0,0).x - beta;
296 Float norm1_inv = 1.0/
cabs(
w);
300 vec[0] = COMPLEX_UNITY;
301 vec[1] =
mat(1,0)*
w*norm1_inv;
302 vec[2] =
mat(2,0)*
w*norm1_inv;
305 tau.x = (beta -
mat(0,0).x)/beta;
306 tau.y =
mat(0,0).y/beta;
315 if(
norm2 != 0.0 ||
p(0,1).
y != 0.0){
318 vec[0] = COMPLEX_ZERO;
319 vec[1] = COMPLEX_UNITY;
323 Float norm1_inv = 1.0/
cabs(
w);
325 z =
conj(
p(0,2))*norm1_inv;
328 tau.x = (beta -
p(0,1).x)/beta;
329 tau.y = (-
p(0,1).y)/beta;
338 if(
norm2 != 0.0 ||
p(1,1).
y != 0.0){
343 vec[0] = COMPLEX_ZERO;
344 vec[1] = COMPLEX_UNITY;
346 w.x =
p(1,1).x - beta;
348 Float norm1_inv = 1.0/
cabs(
w);
350 z =
p(2,1)*norm1_inv;
353 tau.x = (beta -
p(1,1).x)/beta;
354 tau.y =
p(1,1).y/beta;
366 if(
p(1,2).
y != 0.0 ){
367 beta =
p(1,2).x > 0.0 ? -
cabs(
p(1,2)) :
cabs(
p(1,2));
368 temp(2,2) =
conj(
p(1,2))/beta;
369 p(2,2) =
p(2,2)*temp(2,2);
374 if(
p(2,2).
y != 0.0 ){
375 beta =
p(2,2).x > 0.0 ? -
cabs(
p(2,2)) :
cabs(
p(2,2));
376 temp(2,2) =
p(2,2)/beta;
409 if(
b(0,1) != 0.0 &&
b(1,2) != 0.0 ){
414 for(
int i=0;
i<3; ++
i){
417 u(
i,0) = alpha*
c - beta*
s;
418 u(
i,1) = alpha*
s + beta*
c;
421 b(1,1) =
b(0,1)*
s +
b(1,1)*
c;
429 for(
int i=0;
i<3; ++
i){
432 u(
i,0) = alpha*
c - beta*
s;
433 u(
i,2) = alpha*
s + beta*
c;
435 b(2,2) =
b(0,2)*
s +
b(2,2)*
c;
438 }
else if(
b(1,1) == 0.0 ){
441 for(
int i=0;
i<3; ++
i){
444 u(
i,1) = alpha*
c - beta*
s;
445 u(
i,2) = alpha*
s + beta*
c;
447 b(2,2) =
b(1,2)*
s +
b(2,2)*
c;
450 }
else if(
b(2,2) == 0.0 ){
453 for(
int i=0;
i<3; ++
i){
456 v(
i,1) = alpha*
c + beta*
s;
457 v(
i,2) = -alpha*
s + beta*
c;
459 b(1,1) =
b(1,1)*
c +
b(1,2)*
s;
467 for(
int i=0;
i<3; ++
i){
470 v(
i,0) = alpha*
c + beta*
s;
471 v(
i,2) = -alpha*
s + beta*
c;
473 b(0,0) =
b(0,0)*
c +
b(0,2)*
s;
481 alpha =
b(0,0)*
b(0,0) - lambda_max;
482 beta =
b(0,0)*
b(0,1);
492 b(0,0) = alpha*
c - beta*
s;
493 b(0,1) = alpha*
s + beta*
c;
503 b(0,0) =
b(0,0)*
c -
b(1,0)*
s;
506 b(0,1) = alpha*
c - beta*
s;
507 b(1,1) = alpha*
s + beta*
c;
519 b(0,1) =
b(0,1)*
c -
b(0,2)*
s;
523 b(1,1) = alpha*
c - beta*
s;
524 b(1,2) = alpha*
s + beta*
c;
535 b(1,1) =
b(1,1)*
c -
b(2,1)*
s;
538 b(1,2) = alpha*
c - beta*
s;
539 b(2,2) = alpha*
s + beta*
c;
545 }
else if(
b(0,1) == 0.0 ){
548 m_small(0,0) =
b(1,1);
549 m_small(0,1) =
b(1,2);
550 m_small(1,0) =
b(2,1);
551 m_small(1,1) =
b(2,2);
553 smallSVD(u_small, v_small, m_small);
555 b(1,1) = m_small(0,0);
556 b(1,2) = m_small(0,1);
557 b(2,1) = m_small(1,0);
558 b(2,2) = m_small(1,1);
561 for(
int i=0;
i<3; ++
i){
564 u(
i,1) = alpha*u_small(0,0) + beta*u_small(1,0);
565 u(
i,2) = alpha*u_small(0,1) + beta*u_small(1,1);
569 v(
i,1) = alpha*v_small(0,0) + beta*v_small(1,0);
570 v(
i,2) = alpha*v_small(0,1) + beta*v_small(1,1);
574 }
else if(
b(1,2) == 0.0 ){
577 m_small(0,0) =
b(0,0);
578 m_small(0,1) =
b(0,1);
579 m_small(1,0) =
b(1,0);
580 m_small(1,1) =
b(1,1);
582 smallSVD(u_small, v_small, m_small);
584 b(0,0) = m_small(0,0);
585 b(0,1) = m_small(0,1);
586 b(1,0) = m_small(1,0);
587 b(1,1) = m_small(1,1);
589 for(
int i=0;
i<3; ++
i){
592 u(
i,0) = alpha*u_small(0,0) + beta*u_small(1,0);
593 u(
i,1) = alpha*u_small(0,1) + beta*u_small(1,1);
597 v(
i,0) = alpha*v_small(0,0) + beta*v_small(1,0);
598 v(
i,1) = alpha*v_small(0,1) + beta*v_small(1,1);
603 }
while( (
b(0,1) != 0.0 ||
b(1,2) != 0.0) &&
it < max_it);
605 for(
int i=0;
i<3; ++
i){
608 for(
int j=0; j<3; ++j){
618 template<
class Float>
621 Matrix<complex<Float>,3>& v, Float singular_values[3])
623 getRealBidiagMatrix<Float>(m, u, v);
627 for(
int i=0;
i<3; ++
i){
628 for(
int j=0; j<3; ++j){
629 bd(
i,j) = (
conj(u)*m*(v))(
i,j).real();
633 bdSVD(u_real, v_real, bd, 500);
634 for(
int i=0;
i<3; ++
i){
635 singular_values[
i] = bd(
i,
i);
646 #undef INVALID_DOUBLE 648 #endif // _SVD_QUDA_H
Matrix< N, std::complex< T > > conj(const Matrix< N, std::complex< T > > &mat)
std::complex< double > Complex
DEVICEHOST void getGivensRotation(const Real &alpha, const Real &beta, Real &c, Real &s)
DEVICEHOST void getLambdaMax(const Matrix< Real, 3 > &b, Real &lambda_max)
DEVICEHOST float getNorm(const Array< complex< float >, 3 > &a)
char * index(const char *, int)
__device__ __host__ void outerProd(const Array< T, N > &a, const Array< T, N > &b, Matrix< T, N > *m)
DEVICEHOST void computeSVD(const Matrix< complex< Float >, 3 > &m, Matrix< complex< Float >, 3 > &u, Matrix< complex< Float >, 3 > &v, Float singular_values[3])
DEVICEHOST void smallSVD(Matrix< Real, 2 > &u, Matrix< Real, 2 > &v, Matrix< Real, 2 > &m)
DEVICEHOST std::remove_reference< decltype(Cmplx::x)>::type cabs(const Cmplx &z)
static __inline__ size_t p
__device__ __host__ void copyColumn(const Matrix< T, N > &m, int c, Array< T, N > *a)
double norm2(Float *v, int len)
__device__ __host__ void setIdentity(Matrix< T, N > *m)
DEVICEHOST void getRealBidiagMatrix(const Matrix< complex< Float >, 3 > &mat, Matrix< complex< Float >, 3 > &u, Matrix< complex< Float >, 3 > &v)
DEVICEHOST PromoteTypeId< T, U >::Type quadSum(const T &a, const U &b)
double norm1(const ColorSpinorField &b)
DEVICEHOST void swap(Real &a, Real &b)
DEVICEHOST void assignGivensRotation(const Real &c, const Real &s, Matrix< Real, 2 > &m)
DEVICEHOST void constructHHMat(const T &tau, const Array< T, 3 > &v, Matrix< T, 3 > &hh)
DEVICEHOST void bdSVD(Matrix< Real, 3 > &u, Matrix< Real, 3 > &v, Matrix< Real, 3 > &b, int max_it)
void mat(void *out, void **link, void *in, int dagger_bit, int mu, QudaPrecision sPrecision, QudaPrecision gPrecision)
DEVICEHOST void accumGivensRotation(int index, const Real &c, const Real &s, Matrix< Real, 3 > &m)