QUDA  v1.1.0
A library for QCD on GPUs
inv_bicgstabl_quda.cpp
Go to the documentation of this file.
1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <math.h>
4 #include <iostream>
5 #include <sstream>
6 #include <complex>
7 
8 #include <quda_internal.h>
9 #include <color_spinor_field.h>
10 #include <blas_quda.h>
11 #include <dslash_quda.h>
12 #include <invert_quda.h>
13 #include <util_quda.h>
14 
15 namespace quda {
16 
17  // Utility functions for Gram-Schmidt. Based on GCR functions.
18  // Big change is we need to go from 1 to n_krylov, not 0 to n_krylov-1.
19 
20  void BiCGstabL::computeTau(Complex **tau, double* sigma, std::vector<ColorSpinorField*> r, int begin, int size, int j)
21  {
22  Complex *Tau = new Complex[size];
23  std::vector<ColorSpinorField*> a(size), b(1);
24  for (int k=0; k<size; k++)
25  {
26  a[k] = r[begin+k];
27  Tau[k] = 0;
28  }
29  b[0] = r[j];
30  blas::cDotProduct(Tau, a, b); // vectorized dot product
31 
32  for (int k=0; k<size; k++)
33  {
34  tau[begin+k][j] = Tau[k]/sigma[begin+k];
35  }
36  delete []Tau;
37  }
38 
39  void BiCGstabL::updateR(Complex **tau, std::vector<ColorSpinorField*> r, int begin, int size, int j)
40  {
41 
42  Complex *tau_ = new Complex[size];
43  for (int i=0; i<size; i++)
44  {
45  tau_[i] = -tau[i+begin][j];
46  }
47 
48  std::vector<ColorSpinorField*> r_(r.begin() + begin, r.begin() + begin + size);
49  std::vector<ColorSpinorField*> rj(r.begin() + j, r.begin() + j + 1);
50 
51  blas::caxpy(tau_, r_, rj);
52 
53  delete[] tau_;
54  }
55 
56  void BiCGstabL::orthoDir(Complex **tau, double* sigma, std::vector<ColorSpinorField*> r, int j, int pipeline)
57  {
58 
59  switch (pipeline)
60  {
61  case 0: // no kernel fusion
62  for (int i=1; i<j; i++) // 5 (j-2) memory transactions here. Start at 1 b/c bicgstabl convention.
63  {
64  tau[i][j] = blas::cDotProduct(*r[i], *r[j])/sigma[i];
65  blas::caxpy(-tau[i][j], *r[i], *r[j]);
66  }
67  break;
68  case 1: // basic kernel fusion
69  if (j==1) // start at 1.
70  {
71  break;
72  }
73  tau[1][j] = blas::cDotProduct(*r[1], *r[j])/sigma[1];
74  for (int i=1; i<j-1; i++) // 4 (j-2) memory transactions here. start at 1.
75  {
76  tau[i+1][j] = blas::caxpyDotzy(-tau[i][j], *r[i], *r[j], *r[i+1])/sigma[i+1];
77  }
78  blas::caxpy(-tau[j-1][j], *r[j-1], *r[j]);
79  break;
80  default:
81  {
82  const int N = pipeline;
83  // We're orthogonalizing r[j] against r[1], ..., r[j-1].
84  // We need to do (j-1)/N updates of length N, at 1,1+N,1+2*N,...
85  // After, we do 1 update of length (j-1)%N.
86 
87  // (j-1)/N updates of length N, at 1,1+N,1+2*N,...
88  int step;
89  for (step = 0; step < (j-1)/N; step++)
90  {
91  computeTau(tau, sigma, r, 1+step*N, N, j);
92  updateR(tau, r, 1+step*N, N, j);
93  }
94 
95  if ((j-1)%N != 0) // need to update the remainder
96  {
97  // 1 update of length (j-1)%N.
98  computeTau(tau, sigma, r, 1+step*N, (j-1)%N, j);
99  updateR(tau, r, 1+step*N, (j-1)%N, j);
100  }
101  }
102  break;
103  }
104 
105  }
106 
107  void BiCGstabL::updateUend(Complex *gamma, std::vector<ColorSpinorField *> u, int n_krylov)
108  {
109  // for (j = 0; j <= n_krylov; j++) { caxpy(-gamma[j], *u[j], *u[0]); }
110  Complex *gamma_ = new Complex[n_krylov];
111  for (int i = 0; i < n_krylov; i++) { gamma_[i] = -gamma[i + 1]; }
112 
113  std::vector<ColorSpinorField*> u_(u.begin() + 1, u.end());
114  std::vector<ColorSpinorField*> u0(u.begin(), u.begin() + 1);
115 
116  blas::caxpy(gamma_, u_, u0);
117 
118  delete[] gamma_;
119  }
120 
121  void BiCGstabL::updateXRend(Complex *gamma, Complex *gamma_prime, Complex *gamma_prime_prime,
122  std::vector<ColorSpinorField *> r, ColorSpinorField &x, int n_krylov)
123  {
124 #if 0
125  blas::caxpy(gamma[1], *r[0], x);
126  blas::caxpy(-gamma_prime[n_krylov], *r[n_krylov], *r[0]);
127  for (int j = 1; j < n_krylov; j++)
128  {
129  blas::caxpy(gamma_prime_prime[j], *r[j], x);
130  blas::caxpy(-gamma_prime[j], *r[j], *r[0]);
131  }
132 #else
133  // This does two "wasted" caxpys (so 2*n_krylov+2 instead of 2*n_kKrylov), but
134  // the alternative way would be un-fusing some calls, which would require
135  // loading and saving x twice. In a solve where the sloppy precision is lower than
136  // the full precision, this can be a killer.
137  Complex *gamma_prime_prime_ = new Complex[n_krylov + 1];
138  Complex *gamma_prime_ = new Complex[n_krylov + 1];
139  gamma_prime_prime_[0] = gamma[1];
140  gamma_prime_prime_[n_krylov] = 0.0; // x never gets updated with r[n_krylov]
141  gamma_prime_[0] = 0.0; // r[0] never gets updated with r[0]... obvs.
142  gamma_prime_[n_krylov] = -gamma_prime[n_krylov];
143  for (int i = 1; i < n_krylov; i++) {
144  gamma_prime_prime_[i] = gamma_prime_prime[i];
145  gamma_prime_[i] = -gamma_prime[i];
146  }
147  blas::caxpyBxpz(gamma_prime_prime_, r, x, gamma_prime_, *r[0]);
148 
149  delete[] gamma_prime_prime_;
150  delete[] gamma_prime_;
151 #endif
152  }
153 
169  {
172  };
173 
174  class BiCGstabLUpdate : public Worker {
175 
176  ColorSpinorField* x;
177  std::vector<ColorSpinorField*> &r;
178  std::vector<ColorSpinorField*> &u;
179 
180  Complex* alpha;
181  Complex* beta;
182 
183  BiCGstabLUpdateType update_type;
184 
189  int j_max;
190 
196  int n_update;
197 
198  public:
199  BiCGstabLUpdate(ColorSpinorField* x, std::vector<ColorSpinorField*>& r, std::vector<ColorSpinorField*>& u,
200  Complex* alpha, Complex* beta, BiCGstabLUpdateType update_type, int j_max, int n_update) :
201  x(x), r(r), u(u), alpha(alpha), beta(beta), j_max(j_max),
202  n_update(n_update)
203  {
204 
205  }
206  virtual ~BiCGstabLUpdate() { }
207 
208  void update_j_max(int new_j_max) { j_max = new_j_max; }
209  void update_update_type(BiCGstabLUpdateType new_update_type) { update_type = new_update_type; }
210 
211  // note that we can't set the stream parameter here so it is
212  // ignored. This is more of a future design direction to consider
213  void apply(const qudaStream_t &stream)
214  {
215  static int count = 0;
216 
217  // on the first call do the first half of the update
218  if (update_type == BICGSTABL_UPDATE_U)
219  {
220  for (int i= (count*j_max)/n_update; i<((count+1)*j_max)/n_update && i<j_max; i++)
221  {
222  blas::caxpby(1.0, *r[i], -*beta, *u[i]);
223  }
224  }
225  else // (update_type == BICGSTABL_UPDATE_R)
226  {
227  if (count == 0)
228  {
229  blas::caxpy(*alpha, *u[0], *x);
230  }
231  if (j_max > 0)
232  {
233  for (int i= (count*j_max)/n_update; i<((count+1)*j_max)/n_update && i<j_max; i++)
234  {
235  blas::caxpy(-*alpha, *u[i+1], *r[i]);
236  }
237  }
238  }
239 
240  if (++count == n_update) count = 0;
241  }
242  };
243 
244  // this is the Worker pointer that the dslash uses to launch the shifted updates
245  namespace dslash {
246  extern Worker* aux_worker;
247  }
248 
250  Solver(mat, matSloppy, matSloppy, matSloppy, param, profile),
251  n_krylov(param.Nkrylov),
252  init(false)
253  {
254  r.resize(n_krylov + 1);
255  u.resize(n_krylov + 1);
256 
257  gamma = new Complex[n_krylov + 1];
258  gamma_prime = new Complex[n_krylov + 1];
259  gamma_prime_prime = new Complex[n_krylov + 1];
260  sigma = new double[n_krylov + 1];
261 
262  tau = new Complex *[n_krylov + 1];
263  for (int i = 0; i < n_krylov + 1; i++) { tau[i] = new Complex[n_krylov + 1]; }
264 
265  std::stringstream ss;
266  ss << "BiCGstab-" << n_krylov;
267  solver_name = ss.str();
268  }
269 
271  profile.TPSTART(QUDA_PROFILE_FREE);
272  delete[] gamma;
273  delete[] gamma_prime;
274  delete[] gamma_prime_prime;
275  delete[] sigma;
276 
277  for (int i = 0; i < n_krylov + 1; i++) { delete[] tau[i]; }
278  delete[] tau;
279 
280  if (init) {
281  delete r_sloppy_saved_p;
282  delete u[0];
283  for (int i = 1; i < n_krylov + 1; i++) {
284  delete r[i];
285  delete u[i];
286  }
287 
288  delete x_sloppy_saved_p;
289  delete r_fullp;
290  delete r0_saved_p;
291  delete yp;
292  delete tempp;
293 
294  init = false;
295  }
296 
297  profile.TPSTOP(QUDA_PROFILE_FREE);
298 
299  }
300 
301  // Code to check for reliable updates, copied from inv_bicgstab_quda.cpp
302  // Technically, there are ways to check both 'x' and 'r' for reliable updates...
303  // the current status in BiCGstab is to just look for reliable updates in 'r'.
304  int BiCGstabL::reliable(double &rNorm, double &maxrx, double &maxrr, const double &r2, const double &delta) {
305  // reliable updates
306  rNorm = sqrt(r2);
307  if (rNorm > maxrx) maxrx = rNorm;
308  if (rNorm > maxrr) maxrr = rNorm;
309  //int updateR = (rNorm < delta*maxrr && r0Norm <= maxrr) ? 1 : 0;
310  //int updateX = (rNorm < delta*r0Norm && r0Norm <= maxrx) ? 1 : 0
311  int updateR = (rNorm < delta*maxrr) ? 1 : 0;
312 
313  //printf("reliable %d %e %e %e %e\n", updateR, rNorm, maxrx, maxrr, r2);
314 
315  return updateR;
316  }
317 
319  {
320  // BiCGstab-l is based on the algorithm outlined in
321  // BICGSTAB(L) FOR LINEAR EQUATIONS INVOLVING UNSYMMETRIC MATRICES WITH COMPLEX SPECTRUM
322  // G. Sleijpen, D. Fokkema, 1993.
323  // My implementation is based on Kate Clark's implementation in CPS, to be found in
324  // src/util/dirac_op/d_op_wilson_types/bicgstab.C
325 
326  // Begin profiling preamble.
328 
329  if (!init) {
330  // Initialize fields.
333 
334  // Full precision variables.
336 
337  // Create temporary.
339 
340  // Sloppy precision variables.
341  csParam.setPrecision(param.precision_sloppy);
342 
343  // Sloppy solution.
344  x_sloppy_saved_p = ColorSpinorField::Create(csParam); // Used depending on precision.
345 
346  // Shadow residual.
347  r0_saved_p = ColorSpinorField::Create(csParam); // Used depending on precision.
348 
349  // Temporary
351 
352  // Residual (+ extra residuals for BiCG steps), Search directions.
353  // Remark: search directions are sloppy in GCR. I wonder if we can
354  // get away with that here.
355  for (int i = 0; i <= n_krylov; i++) {
358  }
359  r_sloppy_saved_p = r[0]; // Used depending on precision.
360 
361  init = true;
362  }
363 
364  double b2 = blas::norm2(b); // norm sq of source.
365  double r2; // norm sq of residual
366 
367  ColorSpinorField &r_full = *r_fullp;
368  ColorSpinorField &y = *yp;
369  ColorSpinorField &temp = *tempp;
370 
371  ColorSpinorField *r0p, *x_sloppyp; // Get assigned below.
372 
373  // Compute initial residual depending on whether we have an initial guess or not.
375  mat(r_full, x, y); // r[0] = Ax
376  r2 = blas::xmyNorm(b, r_full); // r = b - Ax, return norm.
377  blas::copy(y, x);
378  } else {
379  blas::copy(r_full, b); // r[0] = b
380  r2 = b2;
381  blas::zero(x); // defensive measure in case solution isn't already zero
382  blas::zero(y);
383  }
384 
385  // Check to see that we're not trying to invert on a zero-field source
386  if (b2 == 0) {
388  warningQuda("inverting on zero-field source");
389  x = b;
390  param.true_res = 0.0;
391  param.true_res_hq = 0.0;
393  return;
395  b2 = r2;
396  } else {
397  errorQuda("Null vector computing requires non-zero guess!");
398  }
399  }
400 
401 
402 
403  // Set field aliasing according to whether we're doing mixed precision or not.
404  // There probably be bugs and headaches hiding here.
405  if (param.precision_sloppy == x.Precision()) {
406  r[0] = &r_full; // r[0] \equiv r_sloppy points to the same memory location as r.
408  {
409  r0p = &b; // r0, b point to the same vector in memory.
410  }
411  else
412  {
413  r0p = r0_saved_p; // r0p points to the saved r0 memory.
414  *r0p = r_full; // and is set equal to r.
415  }
416  }
417  else
418  {
419  r0p = r0_saved_p; // r0p points to saved r0 memory.
420  r[0] = r_sloppy_saved_p; // r[0] points to saved r_sloppy memory.
421  *r0p = r_full; // and is set equal to r.
422  *r[0] = r_full; // yup.
423  }
424 
426  {
427  x_sloppyp = &x; // x_sloppy and x point to the same vector in memory.
428  blas::zero(*x_sloppyp); // x_sloppy is zeroed out (and, by extension, so is x).
429  }
430  else
431  {
432  x_sloppyp = x_sloppy_saved_p; // x_sloppy point to saved x_sloppy memory.
433  blas::zero(*x_sloppyp); // and is zeroed out.
434  }
435 
436  // Syntatic sugar.
437  ColorSpinorField &r0 = *r0p;
438  ColorSpinorField &x_sloppy = *x_sloppyp;
439 
440  // Zero out the first search direction.
441  blas::zero(*u[0]);
442 
443 
444  // Set some initial values.
445  sigma[0] = blas::norm2(r_full);
446 
447 
448  // Initialize values.
449  for (int i = 1; i <= n_krylov; i++) { blas::zero(*r[i]); }
450 
451  rho0 = 1.0;
452  alpha = 0.0;
453  omega = 1.0;
454 
455  double stop = stopping(param.tol, b2, param.residual_type); // stopping condition of solver.
456 
457  const bool use_heavy_quark_res =
458  (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) ? true : false;
459  double heavy_quark_res = use_heavy_quark_res ? sqrt(blas::HeavyQuarkResidualNorm(x,r_full).z) : 0.0;
460  const int heavy_quark_check = param.heavy_quark_check; // how often to check the heavy quark residual
461 
462  blas::flops = 0;
463  //bool l2_converge = false;
464  //double r2_old = r2;
465 
466  int pipeline = param.pipeline;
467 
468  // Create the worker class for updating non-critical r, u vectors.
469  BiCGstabLUpdate bicgstabl_update(&x_sloppy, r, u, &alpha, &beta, BICGSTABL_UPDATE_U, 0, matSloppy.getStencilSteps() );
470 
471 
472  // done with preamble, begin computing.
474  profile.TPSTART(QUDA_PROFILE_COMPUTE);
475 
476  // count iteration counts
477  int k = 0;
478 
479  // Various variables related to reliable updates.
480  int rUpdate = 0; // count reliable updates.
481  double delta = param.delta; // delta for reliable updates.
482  double rNorm = sqrt(r2); // The current residual norm.
483  double maxrr = rNorm; // The maximum residual norm since the last reliable update.
484  double maxrx = rNorm; // The same. Would be different if we did 'x' reliable updates.
485 
486  PrintStats(solver_name.c_str(), k, r2, b2, heavy_quark_res);
487  while(!convergence(r2, 0.0, stop, 0.0) && k < param.maxiter) {
488 
489  // rho0 = -omega*rho0;
490  rho0 *= -omega;
491 
492  // BiCG part of calculation.
493  for (int j = 0; j < n_krylov; j++) {
494  // rho1 = <r0, r_j>, beta = alpha*rho1/rho0, rho0 = rho1;
495  // Can fuse into updateXRend.
496  rho1 = blas::cDotProduct(r0, *r[j]);
497  beta = alpha*rho1/rho0;
498  rho0 = rho1;
499 
500  // for i = 0 .. j, u[i] = r[i] - beta*u[i]
501  // All but i = j is hidden in Dslash auxillary work (overlapping comms and compute).
502  /*for (int i = 0; i <= j; i++)
503  {
504  blas::caxpby(1.0, *r[i], -beta, *u[i]);
505  }*/
506  blas::caxpby(1.0, *r[j], -beta, *u[j]);
507  if (j > 0)
508  {
509  dslash::aux_worker = &bicgstabl_update;
510  bicgstabl_update.update_j_max(j);
511  bicgstabl_update.update_update_type(BICGSTABL_UPDATE_U);
512  }
513  else
514  {
515  dslash::aux_worker = NULL;
516  }
517 
518  // u[j+1] = A ( u[j] )
519  matSloppy(*u[j+1], *u[j], temp);
520 
521  // alpha = rho0/<r0, u[j+1]>
522  // The machinary isn't there yet, but this could be fused with the matSloppy above.
523  alpha = rho0/blas::cDotProduct(r0, *u[j+1]);
524 
525  // for i = 0 .. j, r[i] = r[i] - alpha u[i+1]
526  // All but i = j is hidden in Dslash auxillary work (overlapping comms and compute).
527  /*for (int i = 0; i <= j; i++)
528  {
529  blas::caxpy(-alpha, *u[i+1], *r[i]);
530  }*/
531  blas::caxpy(-alpha, *u[j+1], *r[j]);
532  // We can always at least update x.
533  dslash::aux_worker = &bicgstabl_update;
534  bicgstabl_update.update_j_max(j);
535  bicgstabl_update.update_update_type(BICGSTABL_UPDATE_R);
536 
537  // r[j+1] = A r[j], x = x + alpha*u[0]
538  matSloppy(*r[j+1], *r[j], temp);
539  dslash::aux_worker = NULL;
540 
541  } // End BiCG part.
542 
543  // MR part. Really just modified Gram-Schmidt.
544  // The algorithm uses the byproducts of the Gram-Schmidt to update x
545  // and other such niceties. One day I'll read the paper more closely.
546  // Can take this from 'orthoDir' in inv_gcr_quda.cpp, hard code pipelining up to l = 8.
547  for (int j = 1; j <= n_krylov; j++) {
548 
549  // This becomes a fused operator below.
550  /*for (int i = 1; i < j; i++)
551  {
552  // tau_ij = <r_i,r_j>/sigma_i.
553  tau[i][j] = blas::cDotProduct(*r[i], *r[j])/sigma[i];
554 
555  // r_j = r_j - tau_ij r_i;
556  blas::caxpy(-tau[i][j], *r[i], *r[j]);
557  }*/
558  orthoDir(tau, sigma, r, j, pipeline);
559 
560  // sigma_j = r_j^2, gamma'_j = <r_0, r_j>/sigma_j
561 
562  // This becomes a fused operator below.
563  //sigma[j] = blas::norm2(*r[j]);
564  //gamma_prime[j] = blas::cDotProduct(*r[j], *r[0])/sigma[j];
565 
566  // rjr.x = Re(<r[j],r[0]), rjr.y = Im(<r[j],r[0]>), rjr.z = <r[j],r[j]>
567  double3 rjr = blas::cDotProductNormA(*r[j], *r[0]);
568  sigma[j] = rjr.z;
569  gamma_prime[j] = Complex(rjr.x, rjr.y)/sigma[j];
570  }
571 
572  // gamma[n_krylov] = gamma'[n_krylov], omega = gamma[n_krylov]
573  gamma[n_krylov] = gamma_prime[n_krylov];
574  omega = gamma[n_krylov];
575 
576  // gamma = T^(-1) gamma_prime. It's in the paper, I promise.
577  for (int j = n_krylov - 1; j > 0; j--) {
578  // Internal def: gamma[j] = gamma'_j - \sum_{i = j+1 to n_krylov} tau_ji gamma_i
579  gamma[j] = gamma_prime[j];
580  for (int i = j + 1; i <= n_krylov; i++) { gamma[j] = gamma[j] - tau[j][i] * gamma[i]; }
581  }
582 
583  // gamma'' = T S gamma. Check paper for defn of S.
584  for (int j = 1; j < n_krylov; j++) {
585  gamma_prime_prime[j] = gamma[j+1];
586  for (int i = j + 1; i < n_krylov; i++) {
587  gamma_prime_prime[j] = gamma_prime_prime[j] + tau[j][i]*gamma[i+1];
588  }
589  }
590 
591  // Update x, r, u.
592  // x = x+ gamma_1 r_0, r_0 = r_0 - gamma'_l r_l, u_0 = u_0 - gamma_l u_l, where l = n_krylov.
593  // for (j = 0; j < n_krylov; j++) { caxpy(-gamma[j], *u[j], *u[0]); }
594  updateUend(gamma, u, n_krylov);
595 
596  // blas::caxpy(gamma[1], *r[0], x_sloppy);
597  // blas::caxpy(-gamma_prime[n_krylov], *r[n_krylov], *r[0]);
598  // for (j = 1; j < n_krylov; j++) {
599  // blas::caxpy(gamma_gamma_prime[j], *r[j], x_sloppy);
600  // blas::caxpy(-gamma_prime[j], *r[j], *r[0]);
601  //}
602  updateXRend(gamma, gamma_prime, gamma_prime_prime, r, x_sloppy, n_krylov);
603 
604  // sigma[0] = r_0^2
605  sigma[0] = blas::norm2(*r[0]);
606  r2 = sigma[0];
607 
608  // Check the heavy quark residual if we need to.
609  if (use_heavy_quark_res && k%heavy_quark_check==0) {
610  if (&x != &x_sloppy)
611  {
612  blas::copy(temp,y);
613  heavy_quark_res = sqrt(blas::xpyHeavyQuarkResidualNorm(x_sloppy, temp, *r[0]).z);
614  }
615  else
616  {
617  blas::copy(r_full, *r[0]);
618  heavy_quark_res = sqrt(blas::xpyHeavyQuarkResidualNorm(x, y, r_full).z);
619  }
620  }
621 
622  // Check if we need to do a reliable update.
623  // In inv_bicgstab_quda.cpp, there's a variable 'updateR' that holds the check.
624  // That variable gets carried about because there are a few different places 'r' can get
625  // updated (depending on if you're using pipelining or not). In BiCGstab-L, there's only
626  // one place (for now) to get the updated residual, so we just do away with 'updateR'.
627  // Further remark: "reliable" updates rNorm, maxrr, maxrx!!
628  if (reliable(rNorm, maxrx, maxrr, r2, delta))
629  {
630  if (x.Precision() != x_sloppy.Precision())
631  {
632  blas::copy(x, x_sloppy);
633  }
634 
635  blas::xpy(x, y); // swap these around? (copied from bicgstab)
636 
637  // Don't do aux work!
638  dslash::aux_worker = NULL;
639 
640  // Explicitly recompute the residual.
641  mat(r_full, y, x); // r[0] = Ax
642 
643  r2 = blas::xmyNorm(b, r_full); // r = b - Ax, return norm.
644 
645  sigma[0] = r2;
646 
647  if (x.Precision() != r[0]->Precision())
648  {
649  blas::copy(*r[0], r_full);
650  }
651  blas::zero(x_sloppy);
652 
653  // Update rNorm, maxrr, maxrx.
654  rNorm = sqrt(r2);
655  maxrr = rNorm;
656  maxrx = rNorm;
657 
658  // Increment the reliable update count.
659  rUpdate++;
660  }
661 
662  // Check convergence.
663  k += n_krylov;
664  PrintStats(solver_name.c_str(), k, r2, b2, heavy_quark_res);
665  } // Done iterating.
666 
667  if (x.Precision() != x_sloppy.Precision())
668  {
669  blas::copy(x, x_sloppy);
670  }
671 
672  blas::xpy(y, x);
673 
674  // Done with compute, begin the epilogue.
677 
679  double gflops = (blas::flops + mat.flops() + matSloppy.flops())*1e-9;
680  param.gflops = gflops;
681  param.iter += k;
682 
683  if (k >= param.maxiter) // >= if n_krylov doesn't divide max iter.
684  warningQuda("Exceeded maximum iterations %d", param.maxiter);
685 
686  // Print number of reliable updates.
687  if (getVerbosity() >= QUDA_VERBOSE) printfQuda("%s: Reliable updates = %d\n", solver_name.c_str(), rUpdate);
688 
689  // compute the true residual
690  // !param.is_preconditioner comes from bicgstab, param.compute_true_res came from gcr.
691  if (!param.is_preconditioner && param.compute_true_res) { // do not do the below if this is an inner solver.
692  mat(r_full, x, y);
693  double true_res = blas::xmyNorm(b, r_full);
694  param.true_res = sqrt(true_res / b2);
695 
696  param.true_res_hq = use_heavy_quark_res ? sqrt(blas::HeavyQuarkResidualNorm(x,*r[0]).z) : 0.0;
697  }
698 
699  // Reset flops counters.
700  blas::flops = 0;
701  mat.flops();
702 
703  // copy the residual to b so we can use it outside of the solver.
705  {
706  blas::copy(b, r_full);
707  }
708 
709  // Done with epilogue, begin free.
710 
712  profile.TPSTART(QUDA_PROFILE_FREE);
713 
714  // ...yup...
715  PrintSummary(solver_name.c_str(), k, r2, b2, stop, param.tol_hq);
716 
717  // Done!
718  profile.TPSTOP(QUDA_PROFILE_FREE);
719  return;
720  }
721 
722 } // namespace quda
BiCGstabL(const DiracMatrix &mat, const DiracMatrix &matSloppy, SolverParam &param, TimeProfile &profile)
void operator()(ColorSpinorField &out, ColorSpinorField &in)
void apply(const qudaStream_t &stream)
void update_j_max(int new_j_max)
BiCGstabLUpdate(ColorSpinorField *x, std::vector< ColorSpinorField * > &r, std::vector< ColorSpinorField * > &u, Complex *alpha, Complex *beta, BiCGstabLUpdateType update_type, int j_max, int n_update)
void update_update_type(BiCGstabLUpdateType new_update_type)
static ColorSpinorField * Create(const ColorSpinorParam &param)
unsigned long long flops() const
Definition: dirac_quda.h:1909
virtual int getStencilSteps() const =0
QudaPrecision Precision() const
TimeProfile & profile
Definition: invert_quda.h:471
const DiracMatrix & mat
Definition: invert_quda.h:465
bool convergence(double r2, double hq2, double r2_tol, double hq_tol)
Definition: solver.cpp:328
void PrintSummary(const char *name, int k, double r2, double b2, double r2_tol, double hq_tol)
Prints out the summary of the solver convergence (requires a verbosity of QUDA_SUMMARIZE)....
Definition: solver.cpp:386
SolverParam & param
Definition: invert_quda.h:470
static double stopping(double tol, double b2, QudaResidualType residual_type)
Set the solver L2 stopping condition.
Definition: solver.cpp:311
void PrintStats(const char *name, int k, double r2, double b2, double hq2)
Prints out the running statistics of the solver (requires a verbosity of QUDA_VERBOSE)
Definition: solver.cpp:373
const DiracMatrix & matSloppy
Definition: invert_quda.h:466
double Last(QudaProfileType idx)
Definition: timer.h:254
int pipeline
void mat(void *out, void **link, void *in, int dagger_bit, int mu, QudaPrecision sPrecision, QudaPrecision gPrecision)
@ QUDA_USE_INIT_GUESS_YES
Definition: enum_quda.h:430
@ QUDA_VERBOSE
Definition: enum_quda.h:267
@ QUDA_HEAVY_QUARK_RESIDUAL
Definition: enum_quda.h:195
@ QUDA_PRESERVE_SOURCE_NO
Definition: enum_quda.h:238
@ QUDA_ZERO_FIELD_CREATE
Definition: enum_quda.h:361
@ QUDA_COMPUTE_NULL_VECTOR_NO
Definition: enum_quda.h:441
void init()
Create the BLAS context.
Complex caxpyDotzy(const Complex &a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
void caxpby(const Complex &a, ColorSpinorField &x, const Complex &b, ColorSpinorField &y)
double3 HeavyQuarkResidualNorm(ColorSpinorField &x, ColorSpinorField &r)
double xmyNorm(ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.h:79
unsigned long long flops
void caxpyBxpz(const Complex &, ColorSpinorField &, ColorSpinorField &, const Complex &, ColorSpinorField &)
double3 xpyHeavyQuarkResidualNorm(ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &r)
void zero(ColorSpinorField &a)
double norm2(const ColorSpinorField &a)
double3 cDotProductNormA(ColorSpinorField &a, ColorSpinorField &b)
void caxpy(const Complex &a, ColorSpinorField &x, ColorSpinorField &y)
void xpy(ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.h:41
void copy(ColorSpinorField &dst, const ColorSpinorField &src)
Definition: blas_quda.h:24
Complex cDotProduct(ColorSpinorField &, ColorSpinorField &)
void stop()
Stop profiling.
Definition: device.cpp:228
std::complex< double > Complex
Definition: quda_internal.h:86
__host__ __device__ ValueType sqrt(ValueType x)
Definition: complex_quda.h:120
qudaStream_t * stream
@ QUDA_PROFILE_EPILOGUE
Definition: timer.h:110
@ QUDA_PROFILE_COMPUTE
Definition: timer.h:108
@ QUDA_PROFILE_FREE
Definition: timer.h:111
@ QUDA_PROFILE_PREAMBLE
Definition: timer.h:107
ColorSpinorParam csParam
Definition: pack_test.cpp:25
QudaGaugeParam param
Definition: pack_test.cpp:18
void updateR()
update the radius for halos.
cudaStream_t qudaStream_t
Definition: quda_api.h:9
QudaPreserveSource preserve_source
Definition: invert_quda.h:151
QudaComputeNullVector compute_null_vector
Definition: invert_quda.h:61
bool is_preconditioner
verbosity to use for preconditioner
Definition: invert_quda.h:238
bool use_sloppy_partial_accumulator
Definition: invert_quda.h:70
QudaResidualType residual_type
Definition: invert_quda.h:49
QudaPrecision precision_sloppy
Definition: invert_quda.h:139
QudaUseInitGuess use_init_guess
Definition: invert_quda.h:58
#define printfQuda(...)
Definition: util_quda.h:114
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:21
#define warningQuda(...)
Definition: util_quda.h:132
#define errorQuda(...)
Definition: util_quda.h:120