QUDA  v1.1.0
A library for QCD on GPUs
inv_mpbicgstab_quda.cpp
Go to the documentation of this file.
1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <math.h>
4 #include <iostream>
5 
6 #include <quda_internal.h>
7 #include <color_spinor_field.h>
8 #include <blas_quda.h>
9 #include <dslash_quda.h>
10 #include <invert_quda.h>
11 #include <util_quda.h>
12 
13 namespace quda {
14 
16  Solver(mat, mat, mat, mat, param, profile)
17  {
18  }
19 
21  }
22 
23 
24  void MPBiCGstab::computeMatrixPowers(std::vector<cudaColorSpinorField>& pr, cudaColorSpinorField& p, cudaColorSpinorField& r, int nsteps){
25  cudaColorSpinorField temp(p);
26  pr[0] = p;
27  for(int i=1; i<=(2*nsteps); ++i){
28  mat(pr[i], pr[i-1], temp);
29  }
30 
31  pr[(2*nsteps)+1] = r;
32  // for(int i=(2*nsteps+2); i<(4*nsteps+2); ++i){
33  for(int i=(2*nsteps+2); i<(4*nsteps+1); ++i){
34  mat(pr[i], pr[i-1], temp);
35  }
36  }
37 
38 #ifdef SSTEP
/**
   @brief Write n doubles to std::cout, each followed by a single space,
   then end the line with std::endl.
   @param[in] d Values to print
   @param[in] n Number of leading entries of d to print
*/
static void print(const double d[], int n)
{
  int idx = 0;
  while (idx < n) {
    std::cout << d[idx] << " ";
    ++idx;
  }
  std::cout << std::endl;
}
45 
46  static void print(const Complex d[], int n){
47  for(int i=0; i<n; ++i){
48  std::cout << "(" << real(d[i]) << "," << imag(d[i]) << ") ";
49  }
50  std::cout << std::endl;
51  }
52 
53 
/**
   @brief Overwrite the first N entries of d with static_cast<T>(0).
   @param[out] d Array to clear
   @param[in]  N Number of leading entries to clear; nothing is written
                 when N <= 0
*/
template <typename T>
static void zero(T d[], int N)
{
  // d is a local copy of the caller's pointer, so it can serve as the cursor.
  for (int remaining = N; remaining > 0; --remaining, ++d) {
    *d = static_cast<T>(0);
  }
}
58 
59 
60  static void computeGramMatrix(Complex** G, std::vector<cudaColorSpinorField>& v){
61 
62  const int dim = v.size();
63 
64  for(int i=0; i<dim; ++i){
65  for(int j=0; j<dim; ++j){
66  G[i][j] = blas::cDotProduct(v[i],v[j]);
67  }
68  }
69  return;
70  }
71 
72  static void computeGramVector(Complex* g, cudaColorSpinorField& r0, std::vector<cudaColorSpinorField>& pr){
73 
74  const int dim = pr.size();
75 
76  for(int i=0; i<dim; ++i){
77  g[i] = blas::cDotProduct(r0,pr[i]);
78  }
79  }
80 
81 /*
82  // Here, B is an (s+1)x(s+1) matrix with 1s on the subdiagonal
83  template<class T>
84  static void getBColumn(T *col, int col_index, int nsteps){
85  zero(col,nsteps+1);
86  col[col_index] = static_cast<T>(1);
87  }
88 
89  template<typename T>
90  static void init_c_vector(T *c, int index, int nsteps){
91  zero(c,2*nsteps+2);
92  getBColumn(c+(nsteps+1),index,nsteps);
93  }
94 
95  template<typename T>
96  static void init_a_vector(T *a, int index, int nsteps){
97  zero(a,2*nsteps+2);
98  getBColumn(a,index,nsteps);
99  }
100 
101  template<typename T>
102  static void init_e_vector(T *e, int nsteps){
103  zero(e,2*nsteps+2);
104  e[2*nsteps+2] = static_cast<T>(1);
105  }
106 */
107 
/**
   @brief Sum of element-wise products, a[0]*b[0] + ... + a[dim-1]*b[dim-1].

   No complex conjugation is applied to either argument.

   @param[in] a   First array
   @param[in] b   Second array
   @param[in] dim Number of leading entries to combine
   @return The accumulated sum; the initial value 0.0 when dim <= 0
*/
template <typename T>
static T zip(T a[], T b[], int dim)
{
  T acc = 0.0;
  int k = 0;
  while (k < dim) {
    acc += a[k] * b[k];
    ++k;
  }
  return acc;
}
116 
117  static Complex computeUdaggerMV(Complex* u, Complex** M, Complex* v, int dim)
118  {
119  Complex result(0,0);
120 
121  for(int i=0; i<dim; ++i){
122  for(int j=0; j<dim; ++j){
123  result += conj(u[i])*v[j]*M[i][j];
124  }
125  }
126  return result;
127  }
128 #endif
129 
131  {
     // Body of the MPBiCGstab solve.
     // NOTE(review): the signature line and the declarations of csParam, r and p
     // are not present in this listing; x and b are presumably the solution and
     // source arguments -- confirm against invert_quda.h.  The comments below
     // describe only what the visible statements do.
132 #ifndef SSTEP
     // Without SSTEP the call does nothing except raise this error.
133  errorQuda("S-step solvers not built\n");
134 #else
135  // Check to see that we're not trying to invert on a zero-field source
136  const double b2 = blas::norm2(b);
137  if(b2 == 0){
     // NOTE(review): no matching TPSTART(QUDA_PROFILE_INIT) is visible in this
     // listing -- confirm the profile is actually running at this point.
138  profile.TPSTOP(QUDA_PROFILE_INIT);
139  printfQuda("Warning: inverting on zero-field source\n");
     // Zero source: copy b into x and report zero residuals.
140  x=b;
141  param.true_res = 0.0;
142  param.true_res_hq = 0.0;
143  return;
144  }
145 
148 
149  cudaColorSpinorField temp(b, csParam);
150 
152 
153 
154 
     // Initial residual: r is first set to A*x, then xmyNorm turns it into
     // b - A*x (per the comment below) and its return value is kept as r2.
155  mat(r, x, temp); // r = Ax
156  double r2 = blas::xmyNorm(b,r); // r = b - Ax
157 
158 
159 
     // r0 is a copy of the initial residual; it is only read afterwards
     // (in computeGramVector).  Ap is not referenced again in the visible code.
160  cudaColorSpinorField r0(r);
162  cudaColorSpinorField Ap(r);
163 
164 
     // Hard-coded step count: each outer iteration performs at most s inner steps.
165  const int s = 3;
166 
167  // Vector of matrix powers
     // NOTE(review): computeMatrixPowers writes indices 0 .. 4*s only, so
     // PR[4*s+1] keeps whatever this constructor gave it, yet it still enters
     // the Gram products and the caxpy loops below -- confirm csParam
     // zero-initialises the field.
168  std::vector<cudaColorSpinorField> PR(4*s+2,cudaColorSpinorField(b,csParam));
169 
170 
     // r0r is declared but not referenced again in the visible code.
171  Complex r0r;
172  Complex alpha;
173  Complex omega;
174  Complex beta;
175 
     // G: (4*s+2) x (4*s+2) matrix of inner products between the PR fields.
     // g: inner products of r0 with each PR field.
176  Complex** G = new Complex*[4*s+2];
177  for(int i=0; i<(4*s+2); ++i){
178  G[i] = new Complex[4*s+2];
179  }
180  Complex* g = new Complex[4*s+2];
181 
     // a, c (and their scratch copies a_new, c_new): 2*s+1 coefficient vectors
     // of length 4*s+2, each expressing a field as a combination of the PR fields.
182  Complex** a = new Complex*[2*s+1];
183  Complex** c = new Complex*[2*s+1];
184  Complex** a_new = new Complex*[2*s+1];
185  Complex** c_new = new Complex*[2*s+1];
186 
187  for(int i=0; i<(2*s+1); ++i){
188  a[i] = new Complex[4*s+2];
189  c[i] = new Complex[4*s+2];
190  a_new[i] = new Complex[4*s+2];
191  c_new[i] = new Complex[4*s+2];
192  }
193 
194 
     // e: coefficients (over PR) of the update accumulated for x.
195  Complex* e = new Complex[4*s+2];
196 
197 
198 
199 
200  double stop = stopping(param.tol, b2, param.residual_type);
     // it counts inner steps; m counts outer iterations (m is only incremented,
     // never read, in the visible code).
201  int it=0;
202  int m=0;
     // Outer loop: rebuild the PR basis and Gram data, then run up to s inner
     // steps that work on coefficient vectors only.
203  while(!convergence(r2, 0.0, stop, 0.0 ) && it<param.maxiter){
204 
205  computeMatrixPowers(PR, p, r, s);
206  computeGramVector(g, r0, PR);
207  computeGramMatrix(G, PR);
208 
209  // initialize coefficient vectors
     // a[i] becomes the unit vector selecting PR[i]; c[i] the unit vector
     // selecting PR[i + 2*s+1].
210  for(int i=0; i<(2*s+1); ++i){
211  zero(a[i],(4*s+2));
212  zero(c[i],(4*s+2));
213  a[i][i] = static_cast<Complex>(1);
214  c[i][i + (2*s+1)] = static_cast<Complex>(1);
215  }
216 
217 
218  zero(e,(4*s+2));
219  int j=0;
220  while(!convergence(r2,0.0,stop,0.0) && j<s){
221  PrintStats("MPBiCGstab", it, r2, b2, 0.0);
222 
     // alpha is a ratio of two unconjugated g-weighted sums (zip applies no conj).
223  alpha = zip(g,c[0],4*s+2)/zip(g,a[1],4*s+2);
224 
     // omega = omega_num / omega_denom, both assembled from u^dagger G v forms
     // of the coefficient vectors; no field-level operations are involved here.
225  Complex omega_num = computeUdaggerMV(c[0],G,c[1],(4*s+2))
226  - alpha*computeUdaggerMV(c[0],G,a[2],(4*s+2))
227  - conj(alpha)*computeUdaggerMV(a[1],G,c[1],(4*s+2))
228  + conj(alpha)*alpha*computeUdaggerMV(a[1],G,a[2],(4*s+2));
229 
230  Complex omega_denom = computeUdaggerMV(c[1],G,c[1],(4*s+2))
231  - alpha*computeUdaggerMV(c[1],G,a[2],(4*s+2))
232  - conj(alpha)*computeUdaggerMV(a[2],G,c[1],(4*s+2))
233  + conj(alpha)*alpha*computeUdaggerMV(a[2],G,a[2],(4*s+2));
234 
235  omega = omega_num/omega_denom;
236  // Update candidate solution
237  for(int i=0; i<(4*s+2); ++i){
238  e[i] += alpha*a[0][i] + omega*c[0][i] - alpha*omega*a[1][i];
239  }
240 
241  // Update residual
     // The range of k shrinks by two with every inner step j; the highest
     // indices read here are c[k+1] and a[k+2], i.e. at most a[2*s] when j == 0.
242  for(int k=0; k<=(2*(s - j - 1)); ++k){
243  for(int i=0; i<(4*s+2); ++i){
244  c_new[k][i] = c[k][i] - alpha*a[k+1][i] - omega*c[k+1][i] + alpha*omega*a[k+2][i];
245  }
246  }
247 
248  // update search direction
     // beta uses the new c_new[0] in the numerator and the not-yet-overwritten
     // c[0] in the denominator.
249  beta = (zip(g,c_new[0],(4*s+2))/zip(g,c[0],(4*s+2)))*(alpha/omega);
250 
251  for(int k=0; k<=(2*(s - j - 1)); ++k){
252  for(int i=0; i<(4*s+2); ++i){
253  a_new[k][i] = c_new[k][i] + beta*a[k][i] - beta*omega*a[k+1][i];
254  }
255 
     // Commit row k only after a_new[k] has been formed; a[k+1] is still the
     // old value when the next k reads it as a[k].
256  for(int i=0; i<(4*s+2); ++i){
257  a[k][i] = a_new[k][i];
258  c[k][i] = c_new[k][i];
259  }
260  }
     // Rebuild the residual field from its coefficients, r = sum_i c[0][i]*PR[i],
     // so that r2 can be measured for the convergence test.
261  blas::zero(r);
262  for(int i=0; i<(4*s+2); ++i){
263  blas::caxpy(c[0][i], PR[i], r);
264  }
265  r2 = blas::norm2(r);
266  j++;
267  it++;
268  } // j
269 
     // After the inner steps: rebuild p from a[0] and add the accumulated
     // update sum_i e[i]*PR[i] to x.
270  blas::zero(p);
271  for(int i=0; i<(4*s+2); ++i){
272  blas::caxpy(a[0][i], PR[i], p);
273  blas::caxpy(e[i], PR[i], x);
274  }
275 
276  m++;
277  }
278 
279  if(it >= param.maxiter)
280  warningQuda("Exceeded maximum iterations %d", param.maxiter);
281 
282  // compute the true residual
     // NOTE(review): only param.true_res is set here; param.true_res_hq is
     // written solely on the zero-source path above.
283  mat(r, x, temp);
284  param.true_res = sqrt(blas::xmyNorm(b, r)/b2);
285 
286  PrintSummary("MPBiCGstab", it, r2, b2, stop, param.tol_hq);
287 
288 
289 
     // Release the host-side coefficient storage allocated above.
290  for(int i=0; i<(4*s+2); ++i){
291  delete[] G[i];
292  }
293  delete[] G;
294 
295 
296  delete[] g;
297 
298  // Is 2*s + 3 really correct?
     // NOTE(review): the question above looks stale -- the bound used here,
     // 2*s+1, matches the row count of the allocation loop for a, c, a_new, c_new.
299  for(int i=0; i<(2*s+1); ++i){
300  delete[] a[i];
301  delete[] a_new[i];
302  delete[] c[i];
303  delete[] c_new[i];
304  }
305  delete[] a;
306  delete[] a_new;
307  delete[] c;
308  delete[] c_new;
309  delete[] e;
310 #endif
311  return;
312  }
313 
314 } // namespace quda
MPBiCGstab(const DiracMatrix &mat, SolverParam &param, TimeProfile &profile)
void operator()(ColorSpinorField &out, ColorSpinorField &in)
TimeProfile & profile
Definition: invert_quda.h:471
const DiracMatrix & mat
Definition: invert_quda.h:465
bool convergence(double r2, double hq2, double r2_tol, double hq_tol)
Definition: solver.cpp:328
void PrintSummary(const char *name, int k, double r2, double b2, double r2_tol, double hq_tol)
Prints out the summary of the solver convergence (requires a verbosity of QUDA_SUMMARIZE)....
Definition: solver.cpp:386
SolverParam & param
Definition: invert_quda.h:470
static double stopping(double tol, double b2, QudaResidualType residual_type)
Set the solver L2 stopping condition.
Definition: solver.cpp:311
void PrintStats(const char *name, int k, double r2, double b2, double hq2)
Prints out the running statistics of the solver (requires a verbosity of QUDA_VERBOSE)
Definition: solver.cpp:373
std::array< int, 4 > dim
double omega
void mat(void *out, void **link, void *in, int dagger_bit, int mu, QudaPrecision sPrecision, QudaPrecision gPrecision)
@ QUDA_ZERO_FIELD_CREATE
Definition: enum_quda.h:361
double xmyNorm(ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.h:79
void zero(ColorSpinorField &a)
double norm2(const ColorSpinorField &a)
void caxpy(const Complex &a, ColorSpinorField &x, ColorSpinorField &y)
Complex cDotProduct(ColorSpinorField &, ColorSpinorField &)
void stop()
Stop profiling.
Definition: device.cpp:228
__host__ __device__ ValueType conj(ValueType x)
Definition: complex_quda.h:130
__device__ __host__ void zero(double &a)
Definition: float_vector.h:318
std::complex< double > Complex
Definition: quda_internal.h:86
__host__ __device__ ValueType sqrt(ValueType x)
Definition: complex_quda.h:120
@ QUDA_PROFILE_INIT
Definition: timer.h:106
void print(const double d[], int n)
ColorSpinorParam csParam
Definition: pack_test.cpp:25
QudaGaugeParam param
Definition: pack_test.cpp:18
QudaResidualType residual_type
Definition: invert_quda.h:49
#define printfQuda(...)
Definition: util_quda.h:114
#define warningQuda(...)
Definition: util_quda.h:132
#define errorQuda(...)
Definition: util_quda.h:120