QUDA  1.0.0
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
inv_bicgstabl_quda.cpp
Go to the documentation of this file.
1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <math.h>
4 #include <iostream>
5 #include <sstream>
6 #include <complex>
7 
8 #include <quda_internal.h>
9 #include <color_spinor_field.h>
10 #include <blas_quda.h>
11 #include <dslash_quda.h>
12 #include <invert_quda.h>
13 #include <util_quda.h>
14 
15 namespace quda {
16 
17 
18  // Utility functions for Gram-Schmidt. Based on GCR functions.
19  // Big change is we need to go from 1 to nKrylov, not 0 to nKrylov-1.
20 
21  void BiCGstabL::computeTau(Complex **tau, double* sigma, std::vector<ColorSpinorField*> r, int begin, int size, int j)
22  {
23  Complex *Tau = new Complex[size];
24  std::vector<ColorSpinorField*> a(size), b(1);
25  for (int k=0; k<size; k++)
26  {
27  a[k] = r[begin+k];
28  Tau[k] = 0;
29  }
30  b[0] = r[j];
31  blas::cDotProduct(Tau, a, b); // vectorized dot product
32 
33  for (int k=0; k<size; k++)
34  {
35  tau[begin+k][j] = Tau[k]/sigma[begin+k];
36  }
37  delete []Tau;
38  }
39 
40  void BiCGstabL::updateR(Complex **tau, std::vector<ColorSpinorField*> r, int begin, int size, int j)
41  {
42 
43  Complex *tau_ = new Complex[size];
44  for (int i=0; i<size; i++)
45  {
46  tau_[i] = -tau[i+begin][j];
47  }
48 
49  std::vector<ColorSpinorField*> r_(r.begin() + begin, r.begin() + begin + size);
50  std::vector<ColorSpinorField*> rj(r.begin() + j, r.begin() + j + 1);
51 
52  blas::caxpy(tau_, r_, rj);
53 
54  delete[] tau_;
55  }
56 
57  void BiCGstabL::orthoDir(Complex **tau, double* sigma, std::vector<ColorSpinorField*> r, int j, int pipeline)
58  {
59 
60  switch (pipeline)
61  {
62  case 0: // no kernel fusion
63  for (int i=1; i<j; i++) // 5 (j-2) memory transactions here. Start at 1 b/c bicgstabl convention.
64  {
65  tau[i][j] = blas::cDotProduct(*r[i], *r[j])/sigma[i];
66  blas::caxpy(-tau[i][j], *r[i], *r[j]);
67  }
68  break;
69  case 1: // basic kernel fusion
70  if (j==1) // start at 1.
71  {
72  break;
73  }
74  tau[1][j] = blas::cDotProduct(*r[1], *r[j])/sigma[1];
75  for (int i=1; i<j-1; i++) // 4 (j-2) memory transactions here. start at 1.
76  {
77  tau[i+1][j] = blas::caxpyDotzy(-tau[i][j], *r[i], *r[j], *r[i+1])/sigma[i+1];
78  }
79  blas::caxpy(-tau[j-1][j], *r[j-1], *r[j]);
80  break;
81  default:
82  {
83  const int N = pipeline;
84  // We're orthogonalizing r[j] against r[1], ..., r[j-1].
85  // We need to do (j-1)/N updates of length N, at 1,1+N,1+2*N,...
86  // After, we do 1 update of length (j-1)%N.
87 
88  // (j-1)/N updates of length N, at 1,1+N,1+2*N,...
89  int step;
90  for (step = 0; step < (j-1)/N; step++)
91  {
92  computeTau(tau, sigma, r, 1+step*N, N, j);
93  updateR(tau, r, 1+step*N, N, j);
94  }
95 
96  if ((j-1)%N != 0) // need to update the remainder
97  {
98  // 1 update of length (j-1)%N.
99  computeTau(tau, sigma, r, 1+step*N, (j-1)%N, j);
100  updateR(tau, r, 1+step*N, (j-1)%N, j);
101  }
102  }
103  break;
104  }
105 
106  }
107 
108  void BiCGstabL::updateUend(Complex* gamma, std::vector<ColorSpinorField*> u, int nKrylov)
109  {
110  // for (j = 0; j <= nKrylov; j++) { caxpy(-gamma[j], *u[j], *u[0]); }
111  Complex *gamma_ = new Complex[nKrylov];
112  for (int i = 0; i < nKrylov; i++)
113  {
114  gamma_[i] = -gamma[i+1];
115  }
116 
117  std::vector<ColorSpinorField*> u_(u.begin() + 1, u.end());
118  std::vector<ColorSpinorField*> u0(u.begin(), u.begin() + 1);
119 
120  blas::caxpy(gamma_, u_, u0);
121 
122  delete[] gamma_;
123  }
124 
126  std::vector<ColorSpinorField*> r, ColorSpinorField& x, int nKrylov)
127  {
128  /*
129  blas::caxpy(gamma[1], *r[0], x_sloppy);
130  blas::caxpy(-gamma_prime[nKrylov], *r[nKrylov], *r[0]);
131  for (j = 1; j < nKrylov; j++)
132  {
133  caxpy(gamma_prime_prime[j], *r[j], x);
134  caxpy(-gamma_prime[j], *r[j], *r[0]);
135  }
136  */
137 
138  // This does two "wasted" caxpys (so 2*nKrylov+2 instead of 2*nKrylov), but
139  // the alternative way would be un-fusing some calls, which would require
140  // loading and saving x twice. In a solve where the sloppy precision is lower than
141  // the full precision, this can be a killer.
142  Complex *gamma_prime_prime_ = new Complex[nKrylov+1];
143  Complex *gamma_prime_ = new Complex[nKrylov+1];
144  gamma_prime_prime_[0] = gamma[1];
145  gamma_prime_prime_[nKrylov] = 0.0; // x never gets updated with r[nKrylov]
146  gamma_prime_[0] = 0.0; // r[0] never gets updated with r[0]... obvs.
147  gamma_prime_[nKrylov] = -gamma_prime[nKrylov];
148  for (int i = 1; i < nKrylov; i++)
149  {
150  gamma_prime_prime_[i] = gamma_prime_prime[i];
151  gamma_prime_[i] = -gamma_prime[i];
152  }
153  blas::caxpyBxpz(gamma_prime_prime_, r, x, gamma_prime_, *r[0]);
154 
155  delete[] gamma_prime_prime_;
156  delete[] gamma_prime_;
157  }
158 
174  {
177  };
178 
179  class BiCGstabLUpdate : public Worker {
180 
182  std::vector<ColorSpinorField*> &r;
183  std::vector<ColorSpinorField*> &u;
184 
187 
189 
194  int j_max;
195 
201  int n_update;
202 
203  public:
204  BiCGstabLUpdate(ColorSpinorField* x, std::vector<ColorSpinorField*>& r, std::vector<ColorSpinorField*>& u,
205  Complex* alpha, Complex* beta, BiCGstabLUpdateType update_type, int j_max, int n_update) :
206  x(x), r(r), u(u), alpha(alpha), beta(beta), j_max(j_max),
207  n_update(n_update)
208  {
209 
210  }
211  virtual ~BiCGstabLUpdate() { }
212 
213  void update_j_max(int new_j_max) { j_max = new_j_max; }
214  void update_update_type(BiCGstabLUpdateType new_update_type) { update_type = new_update_type; }
215 
216  // note that we can't set the stream parameter here so it is
217  // ignored. This is more of a future design direction to consider
218  void apply(const cudaStream_t &stream) {
219  static int count = 0;
220 
221  // on the first call do the first half of the update
222  if (update_type == BICGSTABL_UPDATE_U)
223  {
224  for (int i= (count*j_max)/n_update; i<((count+1)*j_max)/n_update && i<j_max; i++)
225  {
226  blas::caxpby(1.0, *r[i], -*beta, *u[i]);
227  }
228  }
229  else // (update_type == BICGSTABL_UPDATE_R)
230  {
231  if (count == 0)
232  {
233  blas::caxpy(*alpha, *u[0], *x);
234  }
235  if (j_max > 0)
236  {
237  for (int i= (count*j_max)/n_update; i<((count+1)*j_max)/n_update && i<j_max; i++)
238  {
239  blas::caxpy(-*alpha, *u[i+1], *r[i]);
240  }
241  }
242  }
243 
244  if (++count == n_update) count = 0;
245 
246  }
247 
248  };
249 
250  // this is the Worker pointer that the dslash uses to launch the shifted updates
251  namespace dslash {
252  extern Worker* aux_worker;
253  }
254 
256  Solver(param, profile), mat(mat), matSloppy(matSloppy), nKrylov(param.Nkrylov), init(false)
257  {
258  r.resize(nKrylov+1);
259  u.resize(nKrylov+1);
260 
261  gamma = new Complex[nKrylov+1];
262  gamma_prime = new Complex[nKrylov+1];
264  sigma = new double[nKrylov+1];
265 
266  tau = new Complex*[nKrylov+1];
267  for (int i = 0; i < nKrylov+1; i++) { tau[i] = new Complex[nKrylov+1]; }
268 
269  std::stringstream ss;
270  ss << "BiCGstab-" << nKrylov;
271  solver_name = ss.str();
272  }
273 
275  profile.TPSTART(QUDA_PROFILE_FREE);
276  delete[] gamma;
277  delete[] gamma_prime;
278  delete[] gamma_prime_prime;
279  delete[] sigma;
280 
281  for (int i = 0; i < nKrylov+1; i++) { delete[] tau[i]; }
282  delete[] tau;
283 
284  if (init) {
285  delete r_sloppy_saved_p;
286  delete u[0];
287  for (int i = 1; i < nKrylov+1; i++) {
288  delete r[i];
289  delete u[i];
290  }
291 
292  delete x_sloppy_saved_p;
293  delete r_fullp;
294  delete r0_saved_p;
295  delete yp;
296  delete tempp;
297 
298  init = false;
299  }
300 
301  profile.TPSTOP(QUDA_PROFILE_FREE);
302 
303  }
304 
305  // Code to check for reliable updates, copied from inv_bicgstab_quda.cpp
306  // Technically, there are ways to check both 'x' and 'r' for reliable updates...
307  // the current status in BiCGstab is to just look for reliable updates in 'r'.
308  int BiCGstabL::reliable(double &rNorm, double &maxrx, double &maxrr, const double &r2, const double &delta) {
309  // reliable updates
310  rNorm = sqrt(r2);
311  if (rNorm > maxrx) maxrx = rNorm;
312  if (rNorm > maxrr) maxrr = rNorm;
313  //int updateR = (rNorm < delta*maxrr && r0Norm <= maxrr) ? 1 : 0;
314  //int updateX = (rNorm < delta*r0Norm && r0Norm <= maxrx) ? 1 : 0
315  int updateR = (rNorm < delta*maxrr) ? 1 : 0;
316 
317  //printf("reliable %d %e %e %e %e\n", updateR, rNorm, maxrx, maxrr, r2);
318 
319  return updateR;
320  }
321 
323  {
324  // BiCGstab-l is based on the algorithm outlined in
325  // BICGSTAB(L) FOR LINEAR EQUATIONS INVOLVING UNSYMMETRIC MATRICES WITH COMPLEX SPECTRUM
326  // G. Sleijpen, D. Fokkema, 1993.
327  // My implementation is based on Kate Clark's implementation in CPS, to be found in
328  // src/util/dirac_op/d_op_wilson_types/bicgstab.C
329 
330  // Begin profiling preamble.
332 
333  if (!init) {
334  // Initialize fields.
336  csParam.create = QUDA_ZERO_FIELD_CREATE;
337 
338  // Full precision variables.
340 
341  // Create temporary.
342  yp = ColorSpinorField::Create(csParam);
343 
344  // Sloppy precision variables.
346 
347  // Sloppy solution.
348  x_sloppy_saved_p = ColorSpinorField::Create(csParam); // Used depending on precision.
349 
350  // Shadow residual.
351  r0_saved_p = ColorSpinorField::Create(csParam); // Used depending on precision.
352 
353  // Temporary
354  tempp = ColorSpinorField::Create(csParam);
355 
356  // Residual (+ extra residuals for BiCG steps), Search directions.
357  // Remark: search directions are sloppy in GCR. I wonder if we can
358  // get away with that here.
359  for (int i = 0; i <= nKrylov; i++) {
360  r[i] = ColorSpinorField::Create(csParam);
361  u[i] = ColorSpinorField::Create(csParam);
362  }
363  r_sloppy_saved_p = r[0]; // Used depending on precision.
364 
365  init = true;
366  }
367 
368  double b2 = blas::norm2(b); // norm sq of source.
369  double r2; // norm sq of residual
370 
371  ColorSpinorField &r_full = *r_fullp;
372  ColorSpinorField &y = *yp;
373  ColorSpinorField &temp = *tempp;
374 
375  ColorSpinorField *r0p, *x_sloppyp; // Get assigned below.
376 
377  // Compute initial residual depending on whether we have an initial guess or not.
379  mat(r_full, x, y); // r[0] = Ax
380  r2 = blas::xmyNorm(b, r_full); // r = b - Ax, return norm.
381  blas::copy(y, x);
382  } else {
383  blas::copy(r_full, b); // r[0] = b
384  r2 = b2;
385  blas::zero(x); // defensive measure in case solution isn't already zero
386  blas::zero(y);
387  }
388 
389  // Check to see that we're not trying to invert on a zero-field source
390  if (b2 == 0) {
392  warningQuda("inverting on zero-field source");
393  x = b;
394  param.true_res = 0.0;
395  param.true_res_hq = 0.0;
397  return;
399  b2 = r2;
400  } else {
401  errorQuda("Null vector computing requires non-zero guess!");
402  }
403  }
404 
405 
406 
407  // Set field aliasing according to whether we're doing mixed precision or not.
408  // There are probably bugs and headaches hiding here.
409  if (param.precision_sloppy == x.Precision()) {
410  r[0] = &r_full; // r[0] \equiv r_sloppy points to the same memory location as r.
412  {
413  r0p = &b; // r0, b point to the same vector in memory.
414  }
415  else
416  {
417  r0p = r0_saved_p; // r0p points to the saved r0 memory.
418  *r0p = r_full; // and is set equal to r.
419  }
420  }
421  else
422  {
423  r0p = r0_saved_p; // r0p points to saved r0 memory.
424  r[0] = r_sloppy_saved_p; // r[0] points to saved r_sloppy memory.
425  *r0p = r_full; // and is set equal to r.
426  *r[0] = r_full; // yup.
427  }
428 
430  {
431  x_sloppyp = &x; // x_sloppy and x point to the same vector in memory.
432  blas::zero(*x_sloppyp); // x_sloppy is zeroed out (and, by extension, so is x).
433  }
434  else
435  {
436  x_sloppyp = x_sloppy_saved_p; // x_sloppy point to saved x_sloppy memory.
437  blas::zero(*x_sloppyp); // and is zeroed out.
438  }
439 
440  // Syntactic sugar.
441  ColorSpinorField &r0 = *r0p;
442  ColorSpinorField &x_sloppy = *x_sloppyp;
443 
444  // Zero out the first search direction.
445  blas::zero(*u[0]);
446 
447 
448  // Set some initial values.
449  sigma[0] = blas::norm2(r_full);
450 
451 
452  // Initialize values.
453  for (int i = 1; i <= nKrylov; i++)
454  {
455  blas::zero(*r[i]);
456  }
457 
458  rho0 = 1.0;
459  alpha = 0.0;
460  omega = 1.0;
461 
462  double stop = stopping(param.tol, b2, param.residual_type); // stopping condition of solver.
463 
464  const bool use_heavy_quark_res =
465  (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) ? true : false;
466  double heavy_quark_res = use_heavy_quark_res ? sqrt(blas::HeavyQuarkResidualNorm(x,r_full).z) : 0.0;
467  const int heavy_quark_check = param.heavy_quark_check; // how often to check the heavy quark residual
468 
469  blas::flops = 0;
470  //bool l2_converge = false;
471  //double r2_old = r2;
472 
473  int pipeline = param.pipeline;
474 
475  // Create the worker class for updating non-critical r, u vectors.
476  BiCGstabLUpdate bicgstabl_update(&x_sloppy, r, u, &alpha, &beta, BICGSTABL_UPDATE_U, 0, matSloppy.getStencilSteps() );
477 
478 
479  // done with preamble, begin computing.
481  profile.TPSTART(QUDA_PROFILE_COMPUTE);
482 
483  // count iteration counts
484  int k = 0;
485 
486  // Various variables related to reliable updates.
487  int rUpdate = 0; // count reliable updates.
488  double delta = param.delta; // delta for reliable updates.
489  double rNorm = sqrt(r2); // The current residual norm.
490  double maxrr = rNorm; // The maximum residual norm since the last reliable update.
491  double maxrx = rNorm; // The same. Would be different if we did 'x' reliable updates.
492 
493  PrintStats(solver_name.c_str(), k, r2, b2, heavy_quark_res);
494  while(!convergence(r2, 0.0, stop, 0.0) && k < param.maxiter) {
495 
496  // rho0 = -omega*rho0;
497  rho0 *= -omega;
498 
499  // BiCG part of calculation.
500  for (int j = 0; j < nKrylov; j++) {
501  // rho1 = <r0, r_j>, beta = alpha*rho1/rho0, rho0 = rho1;
502  // Can fuse into updateXRend.
503  rho1 = blas::cDotProduct(r0, *r[j]);
504  beta = alpha*rho1/rho0;
505  rho0 = rho1;
506 
507  // for i = 0 .. j, u[i] = r[i] - beta*u[i]
508  // All but i = j is hidden in Dslash auxiliary work (overlapping comms and compute).
509  /*for (int i = 0; i <= j; i++)
510  {
511  blas::caxpby(1.0, *r[i], -beta, *u[i]);
512  }*/
513  blas::caxpby(1.0, *r[j], -beta, *u[j]);
514  if (j > 0)
515  {
516  dslash::aux_worker = &bicgstabl_update;
517  bicgstabl_update.update_j_max(j);
518  bicgstabl_update.update_update_type(BICGSTABL_UPDATE_U);
519  }
520  else
521  {
522  dslash::aux_worker = NULL;
523  }
524 
525  // u[j+1] = A ( u[j] )
526  matSloppy(*u[j+1], *u[j], temp);
527 
528  // alpha = rho0/<r0, u[j+1]>
529  // The machinery isn't there yet, but this could be fused with the matSloppy above.
530  alpha = rho0/blas::cDotProduct(r0, *u[j+1]);
531 
532  // for i = 0 .. j, r[i] = r[i] - alpha u[i+1]
533  // All but i = j is hidden in Dslash auxiliary work (overlapping comms and compute).
534  /*for (int i = 0; i <= j; i++)
535  {
536  blas::caxpy(-alpha, *u[i+1], *r[i]);
537  }*/
538  blas::caxpy(-alpha, *u[j+1], *r[j]);
539  // We can always at least update x.
540  dslash::aux_worker = &bicgstabl_update;
541  bicgstabl_update.update_j_max(j);
542  bicgstabl_update.update_update_type(BICGSTABL_UPDATE_R);
543 
544  // r[j+1] = A r[j], x = x + alpha*u[0]
545  matSloppy(*r[j+1], *r[j], temp);
546  dslash::aux_worker = NULL;
547 
548  } // End BiCG part.
549 
550  // MR part. Really just modified Gram-Schmidt.
551  // The algorithm uses the byproducts of the Gram-Schmidt to update x
552  // and other such niceties. One day I'll read the paper more closely.
553  // Can take this from 'orthoDir' in inv_gcr_quda.cpp, hard code pipelining up to l = 8.
554  for (int j = 1; j <= nKrylov; j++)
555  {
556 
557 
558  // This becomes a fused operator below.
559  /*for (int i = 1; i < j; i++)
560  {
561  // tau_ij = <r_i,r_j>/sigma_i.
562  tau[i][j] = blas::cDotProduct(*r[i], *r[j])/sigma[i];
563 
564  // r_j = r_j - tau_ij r_i;
565  blas::caxpy(-tau[i][j], *r[i], *r[j]);
566  }*/
567  orthoDir(tau, sigma, r, j, pipeline);
568 
569  // sigma_j = r_j^2, gamma'_j = <r_0, r_j>/sigma_j
570 
571  // This becomes a fused operator below.
572  //sigma[j] = blas::norm2(*r[j]);
573  //gamma_prime[j] = blas::cDotProduct(*r[j], *r[0])/sigma[j];
574 
575  // rjr.x = Re(<r[j],r[0]>), rjr.y = Im(<r[j],r[0]>), rjr.z = <r[j],r[j]>
576  double3 rjr = blas::cDotProductNormA(*r[j], *r[0]);
577  sigma[j] = rjr.z;
578  gamma_prime[j] = Complex(rjr.x, rjr.y)/sigma[j];
579  }
580 
581  // gamma[nKrylov] = gamma'[nKrylov], omega = gamma[nKrylov]
583  omega = gamma[nKrylov];
584 
585  // gamma = T^(-1) gamma_prime. It's in the paper, I promise.
586  for (int j = nKrylov-1; j > 0; j--)
587  {
588  // Internal def: gamma[j] = gamma'_j - \sum_{i = j+1 to nKrylov} tau_ji gamma_i
589  gamma[j] = gamma_prime[j];
590  for (int i = j+1; i <= nKrylov; i++)
591  {
592  gamma[j] = gamma[j] - tau[j][i]*gamma[i];
593  }
594  }
595 
596  // gamma'' = T S gamma. Check paper for defn of S.
597  for (int j = 1; j < nKrylov; j++)
598  {
599  gamma_prime_prime[j] = gamma[j+1];
600  for (int i = j+1; i < nKrylov; i++)
601  {
602  gamma_prime_prime[j] = gamma_prime_prime[j] + tau[j][i]*gamma[i+1];
603  }
604  }
605 
606  // Update x, r, u.
607  // x = x+ gamma_1 r_0, r_0 = r_0 - gamma'_l r_l, u_0 = u_0 - gamma_l u_l, where l = nKrylov.
608  // for (j = 1; j <= nKrylov; j++) { caxpy(-gamma[j], *u[j], *u[0]); }
609  updateUend(gamma, u, nKrylov);
610 
611  //blas::caxpy(gamma[1], *r[0], x_sloppy);
612  //blas::caxpy(-gamma_prime[nKrylov], *r[nKrylov], *r[0]);
613  //for (j = 1; j < nKrylov; j++) {
614  // blas::caxpy(gamma_prime_prime[j], *r[j], x_sloppy);
615  // blas::caxpy(-gamma_prime[j], *r[j], *r[0]);
616  //}
617  updateXRend(gamma, gamma_prime, gamma_prime_prime, r, x_sloppy, nKrylov);
618 
619  // sigma[0] = r_0^2
620  sigma[0] = blas::norm2(*r[0]);
621  r2 = sigma[0];
622 
623  // Check the heavy quark residual if we need to.
624  if (use_heavy_quark_res && k%heavy_quark_check==0) {
625  if (&x != &x_sloppy)
626  {
627  blas::copy(temp,y);
628  heavy_quark_res = sqrt(blas::xpyHeavyQuarkResidualNorm(x_sloppy, temp, *r[0]).z);
629  }
630  else
631  {
632  blas::copy(r_full, *r[0]);
633  heavy_quark_res = sqrt(blas::xpyHeavyQuarkResidualNorm(x, y, r_full).z);
634  }
635  }
636 
637  // Check if we need to do a reliable update.
638  // In inv_bicgstab_quda.cpp, there's a variable 'updateR' that holds the check.
639  // That variable gets carried about because there are a few different places 'r' can get
640  // updated (depending on if you're using pipelining or not). In BiCGstab-L, there's only
641  // one place (for now) to get the updated residual, so we just do away with 'updateR'.
642  // Further remark: "reliable" updates rNorm, maxrr, maxrx!!
643  if (reliable(rNorm, maxrx, maxrr, r2, delta))
644  {
645  if (x.Precision() != x_sloppy.Precision())
646  {
647  blas::copy(x, x_sloppy);
648  }
649 
650  blas::xpy(x, y); // swap these around? (copied from bicgstab)
651 
652  // Don't do aux work!
653  dslash::aux_worker = NULL;
654 
655  // Explicitly recompute the residual.
656  mat(r_full, y, x); // r[0] = Ax
657 
658  r2 = blas::xmyNorm(b, r_full); // r = b - Ax, return norm.
659 
660  sigma[0] = r2;
661 
662  if (x.Precision() != r[0]->Precision())
663  {
664  blas::copy(*r[0], r_full);
665  }
666  blas::zero(x_sloppy);
667 
668  // Update rNorm, maxrr, maxrx.
669  rNorm = sqrt(r2);
670  maxrr = rNorm;
671  maxrx = rNorm;
672 
673  // Increment the reliable update count.
674  rUpdate++;
675  }
676 
677  // Check convergence.
678  k += nKrylov;
679  PrintStats(solver_name.c_str(), k, r2, b2, heavy_quark_res);
680  } // Done iterating.
681 
682  if (x.Precision() != x_sloppy.Precision())
683  {
684  blas::copy(x, x_sloppy);
685  }
686 
687  blas::xpy(y, x);
688 
689  // Done with compute, begin the epilogue.
692 
694  double gflops = (blas::flops + mat.flops() + matSloppy.flops())*1e-9;
695  param.gflops = gflops;
696  param.iter += k;
697 
698  if (k >= param.maxiter) // >= if nKrylov doesn't divide max iter.
699  warningQuda("Exceeded maximum iterations %d", param.maxiter);
700 
701  // Print number of reliable updates.
702  if (getVerbosity() >= QUDA_VERBOSE) printfQuda("%s: Reliable updates = %d\n", solver_name.c_str(), rUpdate);
703 
704  // compute the true residual
705  // !param.is_preconditioner comes from bicgstab, param.compute_true_res came from gcr.
706  if (!param.is_preconditioner && param.compute_true_res) { // do not do the below if this is an inner solver.
707  mat(r_full, x, y);
708  double true_res = blas::xmyNorm(b, r_full);
709  param.true_res = sqrt(true_res / b2);
710 
711  param.true_res_hq = use_heavy_quark_res ? sqrt(blas::HeavyQuarkResidualNorm(x,*r[0]).z) : 0.0;
712  }
713 
714  // Reset flops counters.
715  blas::flops = 0;
716  mat.flops();
717 
718  // copy the residual to b so we can use it outside of the solver.
720  {
721  blas::copy(b, r_full);
722  }
723 
724  // Done with epilogue, begin free.
725 
727  profile.TPSTART(QUDA_PROFILE_FREE);
728 
729  // ...yup...
730  PrintSummary(solver_name.c_str(), k, r2, b2, stop, param.tol_hq);
731 
732  // Done!
733  profile.TPSTOP(QUDA_PROFILE_FREE);
734  return;
735  }
736 
737 } // namespace quda
ColorSpinorField * x_sloppy_saved_p
Definition: invert_quda.h:782
void update_update_type(BiCGstabLUpdateType new_update_type)
ColorSpinorField * yp
Full precision residual.
Definition: invert_quda.h:775
void setPrecision(QudaPrecision precision, QudaPrecision ghost_precision=QUDA_INVALID_PRECISION, bool force_native=false)
int pipeline
Definition: test_util.cpp:1634
double3 cDotProductNormA(ColorSpinorField &a, ColorSpinorField &b)
Definition: reduce_quda.cu:778
ColorSpinorField * r0_saved_p
Sloppy solution vector.
Definition: invert_quda.h:783
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:21
#define errorQuda(...)
Definition: util_quda.h:121
double norm2(const ColorSpinorField &a)
Definition: reduce_quda.cu:721
__host__ __device__ ValueType sqrt(ValueType x)
Definition: complex_quda.h:120
Complex cDotProduct(ColorSpinorField &, ColorSpinorField &)
Definition: reduce_quda.cu:764
void PrintStats(const char *name, int k, double r2, double b2, double hq2)
Prints out the running statistics of the solver (requires a verbosity of QUDA_VERBOSE) ...
Definition: solver.cpp:256
cudaStream_t * stream
BiCGstabL(DiracMatrix &mat, DiracMatrix &matSloppy, SolverParam &param, TimeProfile &profile)
double3 xpyHeavyQuarkResidualNorm(ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &r)
Definition: reduce_quda.cu:818
static ColorSpinorField * Create(const ColorSpinorParam &param)
void computeTau(Complex **tau, double *sigma, std::vector< ColorSpinorField *> r, int begin, int size, int j)
bool convergence(double r2, double hq2, double r2_tol, double hq_tol)
Definition: solver.cpp:223
TimeProfile & profile
Definition: invert_quda.h:464
void copy(ColorSpinorField &dst, const ColorSpinorField &src)
Definition: copy_quda.cu:355
Complex * gamma
Definition: invert_quda.h:768
double xmyNorm(ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.h:75
std::vector< ColorSpinorField * > u
Definition: invert_quda.h:779
Complex * gamma_prime
Definition: invert_quda.h:768
QudaPreserveSource preserve_source
Definition: invert_quda.h:154
ColorSpinorField * r_fullp
Definition: invert_quda.h:774
std::string solver_name
Definition: invert_quda.h:807
DiracMatrix & mat
Definition: invert_quda.h:758
BiCGstabLUpdate(ColorSpinorField *x, std::vector< ColorSpinorField *> &r, std::vector< ColorSpinorField *> &u, Complex *alpha, Complex *beta, BiCGstabLUpdateType update_type, int j_max, int n_update)
void caxpyBxpz(const Complex &, ColorSpinorField &, ColorSpinorField &, const Complex &, ColorSpinorField &)
Definition: blas_quda.cu:574
QudaComputeNullVector compute_null_vector
Definition: invert_quda.h:67
std::vector< ColorSpinorField * > & u
double Last(QudaProfileType idx)
Definition: timer.h:251
Complex * gamma_prime_prime
Definition: invert_quda.h:768
QudaResidualType residual_type
Definition: invert_quda.h:49
Worker * aux_worker
Definition: dslash_quda.cu:87
static double stopping(double tol, double b2, QudaResidualType residual_type)
Set the solver L2 stopping condition.
Definition: solver.cpp:206
ColorSpinorParam csParam
Definition: pack_test.cpp:24
void apply(const cudaStream_t &stream)
constexpr int size
#define warningQuda(...)
Definition: util_quda.h:133
bool is_preconditioner
verbosity to use for preconditioner
Definition: invert_quda.h:241
double3 HeavyQuarkResidualNorm(ColorSpinorField &x, ColorSpinorField &r)
Definition: reduce_quda.cu:809
std::complex< double > Complex
Definition: quda_internal.h:46
std::vector< ColorSpinorField * > r
Sloppy temporary vector.
Definition: invert_quda.h:778
void orthoDir(Complex **tau, double *sigma, std::vector< ColorSpinorField *> r, int j, int pipeline)
const DiracMatrix & matSloppy
Definition: invert_quda.h:759
void updateUend(Complex *gamma, std::vector< ColorSpinorField *> u, int nKrylov)
Complex caxpyDotzy(const Complex &a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
Definition: reduce_quda.cu:771
void caxpy(const Complex &a, ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.cu:512
void zero(ColorSpinorField &a)
Definition: blas_quda.cu:472
ColorSpinorField * tempp
Full precision temporary.
Definition: invert_quda.h:777
int reliable(double &rNorm, double &maxrx, double &maxrr, const double &r2, const double &delta)
Current residual, in BiCG language.
SolverParam & param
Definition: invert_quda.h:463
void updateR(Complex **tau, std::vector< ColorSpinorField *> r, int begin, int size, int j)
unsigned long long flops() const
Definition: dirac_quda.h:1119
void updateXRend(Complex *gamma, Complex *gamma_prime, Complex *gamma_prime_prime, std::vector< ColorSpinorField *> r, ColorSpinorField &x, int nKrylov)
#define printfQuda(...)
Definition: util_quda.h:115
ColorSpinorField * r_sloppy_saved_p
Shadow residual, in BiCG language.
Definition: invert_quda.h:784
unsigned long long flops
Definition: blas_quda.cu:22
void xpy(ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.h:33
double * sigma
Definition: invert_quda.h:770
virtual int getStencilSteps() const =0
void caxpby(const Complex &a, ColorSpinorField &x, const Complex &b, ColorSpinorField &y)
Definition: blas_quda.cu:523
std::vector< ColorSpinorField * > & r
QudaUseInitGuess use_init_guess
Definition: invert_quda.h:64
QudaPrecision precision_sloppy
Definition: invert_quda.h:145
bool use_sloppy_partial_accumulator
Definition: invert_quda.h:76
void PrintSummary(const char *name, int k, double r2, double b2, double r2_tol, double hq_tol)
Prints out the summary of the solver convergence (requires a verbosity of QUDA_SUMMARIZE). Assumes SolverParam.true_res and SolverParam.true_res_hq has been set.
Definition: solver.cpp:270
void update_j_max(int new_j_max)
QudaPrecision Precision() const
__device__ unsigned int count[QUDA_MAX_MULTI_REDUCE]
Definition: cub_helper.cuh:90
BiCGstabLUpdateType update_type
void operator()(ColorSpinorField &out, ColorSpinorField &in)
Complex ** tau
Definition: invert_quda.h:769