QUDA  1.0.0
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
linalg.cuh
Go to the documentation of this file.
1 #pragma once
2 #include <color_spinor.h> // vector container
3 
15 namespace quda {
16 
17  namespace linalg {
18 
/**
   @brief Compute the Cholesky decomposition A = L L^dagger of an
   N x N matrix A and provide triangular solves and inversion built
   on it.

   With fast=false the factor is stored as-is (sqrt and division are
   used).  With fast=true (the default) the diagonal of the stored
   factor holds 1/L(i,i) (computed with rsqrt), so that the forward /
   backward substitutions only need multiplications.

   The constructor only writes the lower triangle (j <= i) of L_,
   while backward() and invert() read elements with j > i.
   NOTE(review): this presumably relies on Mat being a Hermitian
   container whose const accessor returns conj(L(j,i)) for j > i -
   confirm against the Mat type used by callers.

   @tparam Mat Matrix container template, indexed as M(i,j)
   @tparam T Real scalar type
   @tparam N Matrix dimension
   @tparam fast Whether to store the inverse diagonal (rsqrt variant)
*/
template <template<typename,int> class Mat, typename T, int N, bool fast=true>
class Cholesky {

  //! The Cholesky factorization (diagonal is inverted when fast=true)
  Mat<T,N> L_;

public:
  /**
     @brief Constructor that computes the Cholesky decomposition
     @param[in] A Input matrix to decompose; only elements A(i,j)
     with j <= i are read
  */
  __device__ __host__ inline Cholesky(const Mat<T,N> &A) {
    const Mat<T,N> &L = L_; // const alias so that reads use the const accessor

    // All loops run over the full compile-time range [0,N) and are
    // restricted with a guard, so the trip counts seen by
    // "#pragma unroll" are constant.
#pragma unroll
    for (int i=0; i<N; i++) {
#pragma unroll
      for (int j=0; j<N; j++) if (j<i+1) {
        // s = sum_{k<j} L(i,k) * conj(L(j,k)).  Only entries that have
        // already been computed are read, so the result does not depend
        // on how Mat's default constructor initializes L_.
        complex<T> s = 0;
#pragma unroll
        for (int k=0; k<N; k++) if (k<j) {
          s.x += L(i,k).real()*L(j,k).real();
          s.x += L(i,k).imag()*L(j,k).imag();
          s.y += L(i,k).imag()*L(j,k).real();
          s.y -= L(i,k).real()*L(j,k).imag();
        }
        if (!fast) { // traditional Cholesky with sqrt and division
          L_(i,j) = (i == j) ? sqrt((A(i,i)-s).real()) : (A(i,j) - s) / L(j,j).real();
        } else { // optimized - since fwd/back substitution only need inverse diagonal elements, avoid division and use rsqrt
          L_(i,j) = (i == j) ? rsqrt((A(i,i)-s).real()) : (A(i,j)-s) * L(j,j).real();
        }
      }
    }
  }

  /**
     @brief Return the diagonal element of the Cholesky decomposition L(i,i)
     @param[in] i Index of the diagonal element
     @return L(i,i); when fast=true the stored inverse is inverted
     back, which costs a division
  */
  __device__ __host__ inline const T D(int i) const {
    const auto &L = L_;
    if (!fast) return L(i,i).real();
    else return static_cast<T>(1.0) / L(i,i).real();
  }

  /**
     @brief Forward substitution to solve Lx = b
     @tparam Vector Vector type, indexed as x(i) with complex elements
     @param[in] b Source vector
     @return Solution vector x
  */
  template <class Vector>
  __device__ __host__ inline Vector forward(const Vector &b) {
    const Mat<T,N> &L = L_;
    Vector x;
#pragma unroll
    for (int i=0; i<N; i++) {
      x(i) = b(i);
      // x(i) -= sum_{j<i} L(i,j) * x(j), written out in real arithmetic
#pragma unroll
      for (int j=0; j<N; j++) if (j<i) {
        x(i).x -= L(i,j).real()*x(j).real();
        x(i).x += L(i,j).imag()*x(j).imag();
        x(i).y -= L(i,j).real()*x(j).imag();
        x(i).y -= L(i,j).imag()*x(j).real();
      }
      if (!fast) x(i) /= L(i,i).real(); // traditional
      else x(i) *= L(i,i).real();       // optimized: diagonal already holds the inverse
    }
    return x;
  }

  /**
     @brief Backward substitution to solve L^dagger x = b
     @tparam Vector Vector type, indexed as x(i) with complex elements
     @param[in] b Source vector
     @return Solution vector x
  */
  template <class Vector>
  __device__ __host__ inline Vector backward(const Vector &b) {
    const Mat<T,N> &L = L_;
    Vector x;
#pragma unroll
    for (int i=N-1; i>=0; i--) {
      x(i) = b(i);
      // x(i) -= sum_{j>i} L(i,j) * x(j); the elements with j > i are
      // read through Mat's accessor (see the class-level note)
#pragma unroll
      for (int j=0; j<N; j++) if (j>=i+1) {
        x(i).x -= L(i,j).real()*x(j).real();
        x(i).x += L(i,j).imag()*x(j).imag();
        x(i).y -= L(i,j).real()*x(j).imag();
        x(i).y -= L(i,j).imag()*x(j).real();
      }
      if (!fast) x(i) /= L(i,i).real(); // traditional
      else x(i) *= L(i,i).real();       // optimized: diagonal already holds the inverse
    }
    return x;
  }

  /**
     @brief Compute the inverse of A (the matrix used to construct
     the Cholesky decomposition).

     Column k of the inverse is obtained by a forward followed by a
     backward substitution applied to the k-th unit vector.  Only the
     elements Ainv(i,k) with i >= k are written.
     NOTE(review): the remaining elements are presumably implied by
     Mat's Hermitian storage - confirm.

     @return Matrix inverse
  */
  __device__ __host__ inline Mat<T,N> invert() {
    const Mat<T,N> &L = L_;
    Mat<T,N> Ainv;
    // Work vector for the current column.  This declaration was missing
    // from the listing (source line 149).
    // NOTE(review): restored as the vector container from color_spinor.h;
    // confirm the exact type against the upstream header.
    ColorSpinor<T,1,N> v;

#pragma unroll
    for (int k=0;k<N;k++) {

      // forward substitute: the right-hand side is the k-th unit vector,
      // so the first non-zero entry is simply 1/L(k,k)
      if (!fast) v(k) = complex<T>(static_cast<T>(1.0)/L(k,k).real());
      else v(k) = L(k,k).real();

#pragma unroll
      for (int i=0; i<N; i++) if (i>k) {
        v(i) = complex<T>(0.0);
#pragma unroll
        for (int j=0; j<N; j++) if (j>=k && j<i) {
          v(i).x -= L(i,j).real() * v(j).real();
          v(i).x += L(i,j).imag() * v(j).imag();
          v(i).y -= L(i,j).real() * v(j).imag();
          v(i).y -= L(i,j).imag() * v(j).real();
        }
        // scale by the real diagonal, as forward() / backward() do
        if (!fast) v(i) *= static_cast<T>(1.0) / L(i,i).real();
        else v(i) *= L(i,i).real();
      }

      // backward substitute
      if (!fast) v(N-1) *= static_cast<T>(1.0) / L(N-1,N-1).real();
      else v(N-1) *= L(N-1,N-1).real();

#pragma unroll
      for (int i=N-2; i>=0; i--) if (i>=k) {
#pragma unroll
        for (int j=0; j<N; j++) if (j>i) {
          v(i).x -= L(i,j).real() * v(j).real();
          v(i).x += L(i,j).imag() * v(j).imag();
          v(i).y -= L(i,j).real() * v(j).imag();
          v(i).y -= L(i,j).imag() * v(j).real();
        }
        if (!fast) v(i) *= static_cast<T>(1.0) / L(i,i).real();
        else v(i) *= L(i,i).real();
      }

      // Overwrite column k
      Ainv(k,k) = v(k);

#pragma unroll
      for(int i=0;i<N;i++) if (i>k) Ainv(i,k) = v(i);
    }

    return Ainv;
  }

};
200 
201  } // namespace linalg
202 
203 } // namespace quda
__device__ __host__ const T D(int i) const
Return the diagonal element of the Cholesky decomposition L(i,i)
Definition: linalg.cuh:84
__device__ __host__ Vector backward(const Vector &b)
Backward substitution to solve L^dagger x = b.
Definition: linalg.cuh:123
__device__ __host__ Cholesky(const Mat< T, N > &A)
Constructor that computes the Cholesky decomposition.
Definition: linalg.cuh:48
__host__ __device__ ValueType sqrt(ValueType x)
Definition: complex_quda.h:120
Compute Cholesky decomposition of A. By default, we use a modified Cholesky which avoids the division...
Definition: linalg.cuh:38
void Mat(sFloat *out, gFloat **link, sFloat *in, int daggerBit, int mu)
__device__ __host__ Vector forward(const Vector &b)
Forward substitution to solve Lx = b.
Definition: linalg.cuh:97
__shared__ float s[]
__device__ __host__ Mat< T, N > invert()
Compute the inverse of A (the matrix used to construct the Cholesky decomposition).
Definition: linalg.cuh:146
VectorXcd Vector
Mat< T, N > L_
The Cholesky factorization.
Definition: linalg.cuh:41