QUDA v1.1.0
A library for QCD on GPUs
inv_cg_quda.cpp
Go to the documentation of this file.
1 #include <cstdio>
2 #include <cstdlib>
3 #include <cmath>
4 #include <limits>
5 #include <memory>
6 #include <iostream>
7 
8 #include <quda_internal.h>
9 #include <color_spinor_field.h>
10 #include <blas_quda.h>
11 #include <dslash_quda.h>
12 #include <invert_quda.h>
13 #include <util_quda.h>
14 #include <eigensolve_quda.h>
15 #include <eigen_helper.h>
16 
17 namespace quda {
18 
19  CG::CG(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, const DiracMatrix &matEig,
20  SolverParam &param, TimeProfile &profile) :
21  Solver(mat, matSloppy, matPrecon, matEig, param, profile),
22  yp(nullptr),
23  rp(nullptr),
24  rnewp(nullptr),
25  pp(nullptr),
26  App(nullptr),
27  tmpp(nullptr),
28  tmp2p(nullptr),
29  tmp3p(nullptr),
30  rSloppyp(nullptr),
31  xSloppyp(nullptr),
32  init(false)
33  {
34  }
35 
// CG destructor: releases the work fields owned by the solver.
// NOTE(review): this listing omits the function header (listing line 36) and
// lines 38, 45, 57 and 59; the comments below describe only the visible lines.
37  {
// The fields in this block are created by the `if (!init)` set-up of the solve.
39  if ( init ) {
40  for (auto pi : p) if (pi) delete pi;
41  if (rp) delete rp;
42  if (pp) delete pp;
43  if (yp) delete yp;
44  if (App) delete App;
// NOTE(review): the solve set-up assigns rSloppyp = rp in one branch, so
// deleting rSloppyp unconditionally would double-free rp. The guard is
// presumably on the missing line 45 — confirm.
46  if (rSloppyp) delete rSloppyp;
47  if (xSloppyp) delete xSloppyp;
48  }
49  if (tmpp) delete tmpp;
// For staggered operators the set-up does tmp3p = tmp2p = tmpp, so the extra
// temporaries are only freed for non-staggered operators and only when they do
// not alias tmpp.
50  if (!mat.isStaggered()) {
51  if (tmp2p && tmpp != tmp2p) delete tmp2p;
// tmp3p is chosen by a precision ternary in the set-up; it is treated as a
// separate allocation only in mixed precision.
52  if (tmp3p && tmpp != tmp3p && param.precision != param.precision_sloppy) delete tmp3p;
53  }
// rnewp is managed outside the init guard (the block solver checks it with if(!rnewp)).
54  if (rnewp) delete rnewp;
55  init = false;
56 
58  }
60  }
61 
62  CGNE::CGNE(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon,
63  const DiracMatrix &matEig, SolverParam &param, TimeProfile &profile) :
64  CG(mmdag, mmdagSloppy, mmdagPrecon, mmdagEig, param, profile),
65  mmdag(mat.Expose()),
66  mmdagSloppy(matSloppy.Expose()),
67  mmdagPrecon(matPrecon.Expose()),
68  mmdagEig(matEig.Expose()),
69  xp(nullptr),
70  yp(nullptr),
71  init(false)
72  {
73  }
74 
// CGNE destructor: frees the xp/yp work fields once a solve has set init.
// NOTE(review): the function header (listing line 75) is missing from this listing.
76  if ( init ) {
77  if (xp) delete xp;
78  if (yp) delete yp;
79  init = false;
80  }
81  }
82 
// CGNE solve: solves M Mdag y = b with the base-class CG and returns x = Mdag y.
// NOTE(review): the function header (listing line 84) and lines 86, 93-97, 103,
// 128, 133-134 and 138 are missing from this listing.
83  // CGNE: M Mdag y = b is solved; x = Mdag y is returned as solution.
85  if (param.maxiter == 0 || param.Nsteps == 0) {
87  return;
88  }
89 
// Iteration count on entry, so PrintSummary reports only this solve's iterations.
90  const int iter0 = param.iter;
91 
// xp/yp are allocated once; the allocation lines (93-97) are not visible here.
92  if (!init) {
98  init = true;
99  }
100 
101  double b2 = blas::norm2(b);
102 
// The branch opened on the missing line 103 uses the incoming x, so it appears
// to be the initial-guess path; the else-branch below starts from b alone.
104 
105  // compute initial residual
// xp <- M x, then xp <- b - xp, giving r2 = |b - M x|^2.
106  mmdag.Expose()->M(*xp,x);
107  double r2 = blas::xmyNorm(b,*xp);
// Zero source: normalise by the initial residual instead.
108  if (b2 == 0.0) b2 = r2;
109 
110  // compute solution to residual equation
// Solve M Mdag y = (b - M x) for y.
111  CG::operator()(*yp,*xp);
112 
// xp <- Mdag y, the correction to x.
113  mmdag.Expose()->Mdag(*xp,*yp);
114 
115  // compute full solution
116  blas::xpy(*xp, x);
117 
118  } else {
119 
// Zero-guess path: solve M Mdag y = b, then x = Mdag y.
120  CG::operator()(*yp,b);
121  mmdag.Expose()->Mdag(x,*yp);
122 
123  }
124 
125  // future optimization: with preserve_source == QUDA_PRESERVE_SOURCE_NO; b is already
126  // expected to be the CG residual which matches the CGNE residual
127  // (but only with zero initial guess). at the moment, CG does not respect this convention
129 
130  // compute the true residual
131  mmdag.Expose()->M(*xp, x);
132 
// A and B are declared on the missing lines 133-134. B <- B - A, assuming
// axpby(a, x, b, y) computes y = a*x + b*y (TODO confirm).
135  blas::axpby(-1.0, A, 1.0, B);
136 
137  double r2;
// Heavy-quark branch: its condition (line 138) is missing here. h3.y is used as
// the L2 residual and sqrt(h3.z) as the heavy-quark residual.
139  double3 h3 = blas::HeavyQuarkResidualNorm(x, B);
140  r2 = h3.y;
141  param.true_res_hq = sqrt(h3.z);
142  } else {
143  r2 = blas::norm2(B);
144  }
145  param.true_res = sqrt(r2 / b2);
146 
147  PrintSummary("CGNE", param.iter - iter0, r2, b2, stopping(param.tol, b2, param.residual_type), param.tol_hq);
148  }
149 
150  }
151 
152  CGNR::CGNR(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon,
153  const DiracMatrix &matEig, SolverParam &param, TimeProfile &profile) :
154  CG(mdagm, mdagmSloppy, mdagmPrecon, mdagmEig, param, profile),
155  mdagm(mat.Expose()),
156  mdagmSloppy(matSloppy.Expose()),
157  mdagmPrecon(matPrecon.Expose()),
158  mdagmEig(matEig.Expose()),
159  bp(nullptr),
160  init(false)
161  {
162  }
163 
// CGNR destructor: frees the bp work field once a solve has set init.
// NOTE(review): the function header (listing line 164) is missing from this listing.
165  if ( init ) {
166  if (bp) delete bp;
167  init = false;
168  }
169  }
170 
// CGNR solve: solves the normal equations Mdag M x = Mdag b with the base-class
// CG, then reports the true residual of M x = b.
// NOTE(review): the function header (listing line 172) and lines 174, 181-183,
// 196, 201-202, 206 and 216 are missing from this listing.
171  // CGNR: Mdag M x = Mdag b is solved.
173  if (param.maxiter == 0 || param.Nsteps == 0) {
175  return;
176  }
177 
// Iteration count on entry, so PrintSummary reports only this solve's iterations.
178  const int iter0 = param.iter;
179 
// bp is allocated once; the allocation lines (181-183) are not visible here.
180  if (!init) {
184  init = true;
185  }
186 
187  double b2 = blas::norm2(b);
// Zero source: normalise by |M x|^2 of the incoming x instead.
188  if (b2 == 0.0) { // compute initial residual vector
189  mdagm.Expose()->M(*bp,x);
190  b2 = blas::norm2(*bp);
191  }
192 
// Form the normal-equation source bp = Mdag b and solve Mdag M x = bp.
193  mdagm.Expose()->Mdag(*bp,b);
194  CG::operator()(x,*bp);
195 
197 
198  // compute the true residual
// bp <- M x. A and B are declared on the missing lines 201-202.
199  mdagm.Expose()->M(*bp, x);
200 
// B <- B - A, assuming axpby(a, x, b, y) computes y = a*x + b*y (TODO confirm).
203  blas::axpby(-1.0, A, 1.0, B);
204 
205  double r2;
// Heavy-quark branch: its condition (line 206) is missing here.
207  double3 h3 = blas::HeavyQuarkResidualNorm(x, B);
208  r2 = h3.y;
209  param.true_res_hq = sqrt(h3.z);
210  } else {
211  r2 = blas::norm2(B);
212  }
213  param.true_res = sqrt(r2 / b2);
214  PrintSummary("CGNR", param.iter - iter0, r2, b2, stopping(param.tol, b2, param.residual_type), param.tol_hq);
215 
// b <- b - M x: overwrites the caller's source with the residual of M x = b.
// The guarding condition (line 216) is missing; presumably it is tied to
// param.preserve_source — TODO confirm.
217  mdagm.Expose()->M(*bp, x);
218  blas::axpby(-1.0, *bp, 1.0, b);
219  }
220 
221  }
222 
// Mixed-precision conjugate-gradient solve of mat * x = b with reliable updates.
//   x           - on exit, the solution; on entry it is used as the initial guess
//                 on the branch whose condition (listing line 364) is missing here.
//   b           - source; only read in the lines visible in this listing.
//   p_init      - optional seed for the search directions p[] (otherwise rSloppy).
//   r2_old_init - when non-zero together with p_init, p[0] is re-orthogonalised
//                 against the initial residual and combined as r + (r2/r2_old_init) p.
// Working vectors:
//   y       - high-precision solution accumulated up to the last reliable update.
//   xSloppy - correction accumulated since the last reliable update
//             (aliases x unless param.use_sloppy_partial_accumulator).
//   r       - high-precision residual, recomputed at reliable updates.
//   rSloppy - iterated residual.
//   p[0..Np-1], alpha[0..Np-1] - search directions and step lengths, so that
//             xSloppy can be updated in batches of Np (solution_accumulator_pipeline).
// NOTE(review): many lines of this function are missing from this listing
// (e.g. 225, 227, 233, 237, 241, 244, 246, 251-252, 261-264, 268-269, 274, 278,
// 281, 285, 295, 298, 300, 304, 319-320, 364, 392, 414, 461, 475, 488, 753-756,
// 771, 783, 786); the comments below describe only the visible lines.
223  void CG::operator()(ColorSpinorField &x, ColorSpinorField &b, ColorSpinorField *p_init, double r2_old_init)
224  {
226 
// The condition guarding this error (line 227) is missing from this listing.
228  errorQuda("Not supported");
229  if (checkPrecision(x, b) != param.precision)
230  errorQuda("Precision mismatch: expected=%d, received=%d", param.precision, x.Precision());
231 
232  if (param.maxiter == 0 || param.Nsteps == 0) {
234  return;
235  }
236 
// Np (declared on a missing line) is the number of search directions kept for
// batched solution updates; values outside [0, 16] are rejected.
238  if (Np < 0 || Np > 16) errorQuda("Invalid value %d for solution_accumulator_pipeline\n", Np);
239 
240  // Detect whether this is a pure double solve or not; informs the necessity of some stability checks
242 
243  // whether to select alternative reliable updates
245 
247 
248  double b2 = blas::norm2(b);
249 
250  // Check to see that we're not trying to invert on a zero-field source
// The guarding condition (lines 251-252) is missing; presumably b2 == 0.
253  printfQuda("Warning: inverting on zero-field source\n");
254  x = b;
255  param.true_res = 0.0;
256  param.true_res_hq = 0.0;
257  return;
258  }
259 
// One-time allocation of the work fields (several allocation lines are missing).
260  if (!init) {
265 
266  // sloppy fields
267  csParam.setPrecision(param.precision_sloppy);
270  rSloppyp = ColorSpinorField::Create(csParam);
271  xSloppyp = ColorSpinorField::Create(csParam);
272  } else {
// rSloppy aliases r in this branch (its condition is on a missing line).
273  rSloppyp = rp;
275  }
276 
277  // temporary fields
279  if(!mat.isStaggered()) {
280  // tmp2 only needed for multi-gpu Wilson-like kernels
282  // additional high-precision temporary if Wilson and mixed-precision
283  csParam.setPrecision(param.precision);
284  tmp3p = (param.precision != param.precision_sloppy) ?
286  } else {
287  tmp3p = tmp2p = tmpp;
288  }
289 
290  init = true;
291  }
292 
// Optional deflation: build the eigenvector space once, and refresh the
// eigenvalues when requested.
293  if (param.deflate) {
294  // Construct the eigensolver and deflation space if requested.
296  if (deflate_compute) {
297  // compute the deflation space.
299  (*eig_solve)(evecs, evals);
301  deflate_compute = false;
302  }
303  if (recompute_evals) {
305  recompute_evals = false;
306  }
307  }
308 
309  ColorSpinorField &r = *rp;
310  ColorSpinorField &y = *yp;
311  ColorSpinorField &Ap = *App;
312  ColorSpinorField &tmp = *tmpp;
313  ColorSpinorField &tmp2 = *tmp2p;
314  ColorSpinorField &tmp3 = *tmp3p;
315  ColorSpinorField &rSloppy = *rSloppyp;
316  ColorSpinorField &xSloppy = param.use_sloppy_partial_accumulator ? *xSloppyp : x;
317 
// (Re)allocate the Np sloppy-precision search directions if Np changed.
318  {
321  csParam.setPrecision(param.precision_sloppy);
322 
323  if (Np != (int)p.size()) {
324  for (auto &pi : p) delete pi;
325  p.resize(Np);
326  for (auto &pi : p) pi = ColorSpinorField::Create(csParam);
327  }
328  }
329 
330  // alternative reliable updates
331  // alternative reliable updates - set precision - does not hurt performance here
332 
// u: sloppy-precision epsilon; uhigh: solver-precision epsilon.
333  const double u = precisionEpsilon(param.precision_sloppy);
334  const double uhigh = precisionEpsilon(); // solver precision
335 
336  const double deps=sqrt(u);
337  constexpr double dfac = 1.1;
338  double d_new = 0;
339  double d = 0;
340  double dinit = 0;
// NOTE(review): xNorm is never assigned after this, so it stays 0 wherever it is used.
341  double xNorm = 0;
342  double xnorm = 0;
343  double pnorm = 0;
344  double ppnorm = 0;
345  double Anorm = 0;
346  double beta = 0.0;
347 
348  // for alternative reliable updates
349  if (alternative_reliable) {
350  // estimate norm for reliable updates
// Anorm ~ |mat b| / |b|, an estimate of the operator norm (r and y used as scratch).
351  mat(r, b, y, tmp3);
352  Anorm = sqrt(blas::norm2(r)/b2);
353  }
354 
355  // for detecting HQ residual stalls
356  // let |r2/b2| drop to epsilon tolerance * 1e-30, semi-arbitrarily, but
357  // with the intent of letting the solve grind as long as possible before
358  // triggering a `NaN`. Ignored for pure double solves because if
359  // pure double has stability issues, bigger problems are at hand.
// (r2/b2 is a squared ratio, hence uhigh^2 * 1e-60 == (uhigh * 1e-30)^2.)
360  const double hq_res_stall_check = is_pure_double ? 0. : uhigh * uhigh * 1e-60;
361 
362  // compute initial residual
363  double r2 = 0.0;
// Initial-guess branch: its condition (line 364) is missing from this listing.
365  // Compute r = b - A * x
366  mat(r, x, y, tmp3);
367  r2 = blas::xmyNorm(b, r);
368  if (b2 == 0) b2 = r2;
369  // y contains the original guess.
370  blas::copy(y, x);
371  } else {
// Zero-guess path: r = b, y = 0.
372  if (&r != &b) blas::copy(r, b);
373  r2 = b2;
374  blas::zero(y);
375  }
376 
377  if (param.deflate && param.maxiter > 1) {
378  // Deflate and accumulate to solution vector
379  eig_solve->deflate(y, r, evecs, evals, true);
380  mat(r, y, x, tmp3);
381  r2 = blas::xmyNorm(b, r);
382  }
383 
// From here on x/xSloppy hold only the correction relative to y.
384  blas::zero(x);
385  if (&x != &xSloppy) blas::zero(xSloppy);
386  blas::copy(rSloppy,r);
387 
// NOTE(review): p was already resized to Np above, so this first branch appears
// unreachable; in practice the else-branch seeds every p[i].
388  if (Np != (int)p.size()) {
389  for (auto &pi : p) delete pi;
390  p.resize(Np);
391  ColorSpinorParam csParam(rSloppy);
393  for (auto &pi : p)
394  pi = p_init ? ColorSpinorField::Create(*p_init, csParam) : ColorSpinorField::Create(rSloppy, csParam);
395  } else {
396  for (auto &p_i : p) *p_i = p_init ? *p_init : rSloppy;
397  }
398 
399  double r2_old=0.0;
// Continue from a supplied direction: remove its component along r, then
// p[0] = r + (r2 / r2_old_init) * p[0].
// NOTE(review): the local `rp` shadows the member field pointer rp.
400  if (r2_old_init != 0.0 and p_init) {
401  r2_old = r2_old_init;
402  Complex rp = blas::cDotProduct(rSloppy, *p[0]) / (r2);
403  blas::caxpy(-rp, rSloppy, *p[0]);
404  beta = r2 / r2_old;
405  blas::xpayz(rSloppy, beta, *p[0], *p[0]);
406  }
407 
408  const bool use_heavy_quark_res =
409  (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) ? true : false;
410  bool heavy_quark_restart = false;
411 
412  if (!param.is_preconditioner) {
413  profile.TPSTOP(QUDA_PROFILE_INIT);
415  }
416 
417  double stop = stopping(param.tol, b2, param.residual_type); // stopping condition of solver
418 
419  double heavy_quark_res = 0.0; // heavy quark residual
420  double heavy_quark_res_old = 0.0; // heavy quark residual
421 
422  if (use_heavy_quark_res) {
423  heavy_quark_res = sqrt(blas::HeavyQuarkResidualNorm(x, r).z);
424  heavy_quark_res_old = heavy_quark_res; // heavy quark residual
425  }
426  const int heavy_quark_check = param.heavy_quark_check; // how often to check the heavy quark residual
427 
// NOTE(review): runtime-sized array (VLA) — a compiler extension, not standard
// C++; consider std::vector<double> if blas::axpy accepts a pointer to its data.
428  double alpha[Np];
429  double pAp;
430  int rUpdate = 0;
431 
// Reliable-update bookkeeping: r0Norm is the residual at the last reliable
// update; maxrx/maxrr are the largest residuals seen since then.
432  double rNorm = sqrt(r2);
433  double r0Norm = rNorm;
434  double maxrx = rNorm;
435  double maxrr = rNorm;
436  double maxr_deflate = rNorm; // The maximum residual since the last deflation
437  double delta = param.delta;
438 
439  // this parameter determines how many consecutive reliable update
440  // residual increases we tolerate before terminating the solver,
441  // i.e., how long do we want to keep trying to converge
442  const int maxResIncrease = param.max_res_increase; // check if we reached the limit of our tolerance
443  const int maxResIncreaseTotal = param.max_res_increase_total;
444 
445  // this means when using heavy quarks we will switch to simple hq restarts as soon as the reliable strategy fails
446  const int hqmaxresIncrease = param.max_hq_res_increase;
447  const int hqmaxresRestartTotal
448  = param.max_hq_res_restart_total; // this limits the number of heavy quark restarts we can do
449 
450  int resIncrease = 0;
451  int resIncreaseTotal = 0;
452  int hqresIncrease = 0;
453  int hqresRestartTotal = 0;
454 
455  // set this to true if maxResIncrease has been exceeded but when we use heavy quark residual we still want to continue the CG
456  // only used if we use the heavy_quark_res
457  bool L2breakdown = false;
458  const double L2breakdown_eps = 100. * uhigh;
459 
460  if (!param.is_preconditioner) {
462  profile.TPSTART(QUDA_PROFILE_COMPUTE);
463  blas::flops = 0;
464  }
465 
// k: total iteration count; j: index into p[]/alpha[] for the current direction.
466  int k = 0;
467  int j = 0;
468 
469  PrintStats("CG", k, r2, b2, heavy_quark_res);
470 
471  int steps_since_reliable = 1;
472  bool converged = convergence(r2, heavy_quark_res, stop, param.tol_hq);
473 
474  // alternative reliable updates
// Block opened on the missing line 475 (presumably `if (alternative_reliable) {`).
476  dinit = uhigh * (rNorm + Anorm * xNorm);
477  d = dinit;
478  }
479 
480  while ( !converged && k < param.maxiter ) {
481  matSloppy(Ap, *p[j], tmp, tmp2); // tmp as tmp
482  double sigma;
483 
484  bool breakdown = false;
485  if (param.pipeline) {
// Pipelined variant: one fused reduction gives |r|^2, |Ap|^2 and (p, Ap).
// sigma = alpha*(alpha*|Ap|^2 - pAp) equals |r - alpha*Ap|^2 in exact arithmetic.
// The residual is updated explicitly (axpyNorm) when that estimate goes negative
// or just after a reliable update.
486  double Ap2;
487  //TODO: alternative reliable updates - need r2, Ap2, pAp, p norm
489  double4 quadruple = blas::quadrupleCGReduction(rSloppy, Ap, *p[j]);
490  r2 = quadruple.x; Ap2 = quadruple.y; pAp = quadruple.z; ppnorm= quadruple.w;
491  }
492  else{
493  double3 triplet = blas::tripleCGReduction(rSloppy, Ap, *p[j]);
494  r2 = triplet.x; Ap2 = triplet.y; pAp = triplet.z;
495  }
496  r2_old = r2;
497  alpha[j] = r2 / pAp;
498  sigma = alpha[j]*(alpha[j] * Ap2 - pAp);
499  if (sigma < 0.0 || steps_since_reliable == 0) { // sigma condition has broken down
500  r2 = blas::axpyNorm(-alpha[j], Ap, rSloppy);
501  sigma = r2;
502  breakdown = true;
503  }
504 
505  r2 = sigma;
506  } else {
507  r2_old = r2;
508 
509  // alternative reliable updates,
510  if (alternative_reliable) {
511  double3 pAppp = blas::cDotProductNormA(*p[j],Ap);
512  pAp = pAppp.x;
513  ppnorm = pAppp.z;
514  } else {
515  pAp = blas::reDotProduct(*p[j], Ap);
516  }
517 
518  alpha[j] = r2 / pAp;
519 
520  // here we are deploying the alternative beta computation
// rSloppy -= alpha*Ap; returns (|r_new|^2, (r_new, r_new - r_old)), the latter
// being the numerator of a Polak-Ribiere-style beta.
521  Complex cg_norm = blas::axpyCGNorm(-alpha[j], Ap, rSloppy);
522  r2 = real(cg_norm); // (r_new, r_new)
523  sigma = imag(cg_norm) >= 0.0 ? imag(cg_norm) : r2; // use r2 if (r_k+1, r_k+1-r_k) breaks
524  }
525 
526  // reliable update conditions
527  rNorm = sqrt(r2);
528  int updateX;
529  int updateR;
530 
531  if (alternative_reliable) {
532  // alternative reliable updates
533  updateX = ( (d <= deps*sqrt(r2_old)) or (dfac * dinit > deps * r0Norm) ) and (d_new > deps*rNorm) and (d_new > dfac * dinit);
534  updateR = 0;
535  } else {
// Standard criterion: update once the residual has dropped by a factor delta
// relative to r0Norm (for updateX) or to the largest residual seen (for updateR).
536  if (rNorm > maxrx) maxrx = rNorm;
537  if (rNorm > maxrr) maxrr = rNorm;
538  updateX = (rNorm < delta * r0Norm && r0Norm <= maxrx) ? 1 : 0;
539  updateR = ((rNorm < delta * maxrr && r0Norm <= maxrr) || updateX) ? 1 : 0;
540  }
541 
542  // force a reliable update if we are within target tolerance (only if doing reliable updates)
543  if ( convergence(r2, heavy_quark_res, stop, param.tol_hq) && param.delta >= param.tol ) updateX = 1;
544 
545  // For heavy-quark inversion force a reliable update if we continue after,
546  // or if r2/b2 has fictitiously dropped too far below precision epsilon
547  if (use_heavy_quark_res and L2breakdown
548  and (convergenceHQ(r2, heavy_quark_res, stop, param.tol_hq) or (r2 / b2) < hq_res_stall_check)
549  and param.delta >= param.tol) {
550  updateX = 1;
551  }
552 
// Ordinary iteration (no reliable update).
553  if ( !(updateR || updateX )) {
554  beta = sigma / r2_old; // use the alternative beta computation
555 
556  if (param.pipeline && !breakdown) {
557 
558  if (Np == 1) {
559  blas::tripleCGUpdate(alpha[j], beta, Ap, xSloppy, rSloppy, *p[j]);
560  } else {
561  errorQuda("Not implemented pipelined CG with Np > 1");
562  }
563  } else {
564  if (Np == 1) {
565  // with Np=1 we just run regular fusion between x and p updates
566  blas::axpyZpbx(alpha[k%Np], *p[k%Np], xSloppy, rSloppy, beta);
567  } else {
568 
// Flush the accumulated alpha*p contributions into xSloppy every Np steps,
// before p[0] is overwritten below.
569  if ( (j+1)%Np == 0 ) {
570  std::vector<ColorSpinorField*> x_;
571  x_.push_back(&xSloppy);
572  blas::axpy(alpha, p, x_);
573  }
574 
575  // p[(k+1)%Np] = r + beta * p[k%Np]
576  blas::xpayz(rSloppy, beta, *p[j], *p[(j + 1) % Np]);
577  }
578  }
579 
580  if (use_heavy_quark_res && k % heavy_quark_check == 0) {
581  if (&x != &xSloppy) {
582  blas::copy(tmp,y);
583  heavy_quark_res = sqrt(blas::xpyHeavyQuarkResidualNorm(xSloppy, tmp, rSloppy).z);
584  } else {
585  blas::copy(r, rSloppy);
586  heavy_quark_res = sqrt(blas::xpyHeavyQuarkResidualNorm(x, y, r).z);
587  }
588  }
589 
590  // alternative reliable updates
591  if (alternative_reliable) {
592  d = d_new;
593  pnorm = pnorm + alpha[j] * alpha[j]* (ppnorm);
594  xnorm = sqrt(pnorm);
595  d_new = d + u*rNorm + uhigh*Anorm * xnorm;
596  if (steps_since_reliable==0 && getVerbosity() >= QUDA_DEBUG_VERBOSE)
597  printfQuda("New dnew: %e (r %e , y %e)\n",d_new,u*rNorm,uhigh*Anorm * sqrt(blas::norm2(y)) );
598  }
599  steps_since_reliable++;
600 
// Reliable update: fold the correction into y and recompute the true residual.
601  } else {
602 
// Flush pending alpha*p contributions p[0..j] into xSloppy.
603  {
604  std::vector<ColorSpinorField*> x_;
605  x_.push_back(&xSloppy);
606  std::vector<ColorSpinorField*> p_;
607  for (int i=0; i<=j; i++) p_.push_back(p[i]);
608  blas::axpy(alpha, p_, x_);
609  }
610 
611  blas::copy(x, xSloppy); // nop when these pointers alias
612 
// y += x, then r = b - mat*y in solver precision.
613  blas::xpy(x, y); // swap these around?
614  mat(r, y, x, tmp3); // here we can use x as tmp
615  r2 = blas::xmyNorm(b, r);
616 
617  if (param.deflate && sqrt(r2) < maxr_deflate * param.tol_restart) {
618  // Deflate and accumulate to solution vector
619  eig_solve->deflate(y, r, evecs, evals, true);
620 
621  // Compute r_defl = RHS - A * LHS
622  mat(r, y, x, tmp3);
623  r2 = blas::xmyNorm(b, r);
624 
625  maxr_deflate = sqrt(r2);
626  }
627 
628  blas::copy(rSloppy, r); //nop when these pointers alias
629  blas::zero(xSloppy);
630 
631  // alternative reliable updates
632  if (alternative_reliable) {
633  dinit = uhigh*(sqrt(r2) + Anorm * sqrt(blas::norm2(y)));
634  d = d_new;
635  xnorm = 0;//sqrt(norm2(x));
636  pnorm = 0;//pnorm + alpha * sqrt(norm2(p));
637  if (getVerbosity() >= QUDA_DEBUG_VERBOSE) printfQuda("New dinit: %e (r %e , y %e)\n",dinit,uhigh*sqrt(r2),uhigh*Anorm*sqrt(blas::norm2(y)));
638  d_new = dinit;
639  } else {
640  rNorm = sqrt(r2);
641  maxrr = rNorm;
642  maxrx = rNorm;
643  }
644 
645  // calculate new reliable HQ residual
646  if (use_heavy_quark_res) heavy_quark_res = sqrt(blas::HeavyQuarkResidualNorm(y, r).z);
647 
648  // break-out check if we have reached the limit of the precision
649  if (sqrt(r2) > r0Norm && updateX and not L2breakdown) { // reuse r0Norm for this
650  resIncrease++;
651  resIncreaseTotal++;
652  warningQuda(
653  "CG: new reliable residual norm %e is greater than previous reliable residual norm %e (total #inc %i)",
654  sqrt(r2), r0Norm, resIncreaseTotal);
655 
656  if ((use_heavy_quark_res and sqrt(r2) < L2breakdown_eps) or resIncrease > maxResIncrease
657  or resIncreaseTotal > maxResIncreaseTotal or r2 < stop) {
658  if (use_heavy_quark_res) {
659  L2breakdown = true;
660  warningQuda("CG: L2 breakdown %e, %e", sqrt(r2), L2breakdown_eps);
661  } else {
662  if (resIncrease > maxResIncrease or resIncreaseTotal > maxResIncreaseTotal or r2 < stop) {
663  warningQuda("CG: solver exiting due to too many true residual norm increases");
664  break;
665  }
666  }
667  }
668  } else {
669  resIncrease = 0;
670  }
671 
672  // if L2 broke down already we turn off reliable updates and restart the CG
673  if (use_heavy_quark_res and L2breakdown) {
674  hqresRestartTotal++; // count the number of heavy quark restarts we've done
675  delta = 0;
676  warningQuda("CG: Restarting without reliable updates for heavy-quark residual (total #inc %i)",
677  hqresRestartTotal);
678  heavy_quark_restart = true;
679 
680  if (heavy_quark_res > heavy_quark_res_old) { // check if new hq residual is greater than previous
681  hqresIncrease++; // count the number of consecutive increases
682  warningQuda("CG: new reliable HQ residual norm %e is greater than previous reliable residual norm %e",
683  heavy_quark_res, heavy_quark_res_old);
684  // break out if we do not improve here anymore
685  if (hqresIncrease > hqmaxresIncrease) {
686  warningQuda("CG: solver exiting due to too many heavy quark residual norm increases (%i/%i)",
687  hqresIncrease, hqmaxresIncrease);
688  break;
689  }
690  } else {
691  hqresIncrease = 0;
692  }
693 
694  if (hqresRestartTotal > hqmaxresRestartTotal) {
695  warningQuda("CG: solver exiting due to too many heavy quark residual restarts (%i/%i)", hqresRestartTotal,
696  hqmaxresRestartTotal);
697  break;
698  }
699  }
700 
701  if (use_heavy_quark_res and heavy_quark_restart) {
702  // perform a restart
703  blas::copy(*p[0], rSloppy);
704  heavy_quark_restart = false;
705  } else {
706  // explicitly restore the orthogonality of the gradient vector
// NOTE(review): this local `rp` shadows the member field pointer rp.
707  Complex rp = blas::cDotProduct(rSloppy, *p[j]) / (r2);
708  blas::caxpy(-rp, rSloppy, *p[j]);
709 
710  beta = r2 / r2_old;
711  blas::xpayz(rSloppy, beta, *p[j], *p[0]);
712  }
713 
714  steps_since_reliable = 0;
715  r0Norm = sqrt(r2);
716  rUpdate++;
717 
718  heavy_quark_res_old = heavy_quark_res;
719  }
720 
721  breakdown = false;
722  k++;
723 
724  PrintStats("CG", k, r2, b2, heavy_quark_res);
725  // check convergence, if convergence is satisfied we only need to check that we had a reliable update for the heavy quarks recently
726  converged = convergence(r2, heavy_quark_res, stop, param.tol_hq);
727 
728  // check for recent enough reliable updates of the HQ residual if we use it
729  if (use_heavy_quark_res) {
730  // L2 is converged or precision maxed out for L2
731  bool L2done = L2breakdown or convergenceL2(r2, heavy_quark_res, stop, param.tol_hq);
732  // HQ is converged and if we do reliable update the HQ residual has been calculated using a reliable update
733  bool HQdone = (steps_since_reliable == 0 and param.delta > 0) and convergenceHQ(r2, heavy_quark_res, stop, param.tol_hq);
734  converged = L2done and HQdone;
735  }
736 
737  // if we have converged and need to update any trailing solutions
// NOTE(review): this flush only happens on convergence. If the loop exits
// because k reaches param.maxiter with Np > 1 and (j+1)%Np != 0, the pending
// alpha*p contributions are never added to xSloppy.
738  if (converged && steps_since_reliable > 0 && (j+1)%Np != 0 ) {
739  std::vector<ColorSpinorField*> x_;
740  x_.push_back(&xSloppy);
741  std::vector<ColorSpinorField*> p_;
742  for (int i=0; i<=j; i++) p_.push_back(p[i]);
743  blas::axpy(alpha, p_, x_);
744  }
745 
746  j = steps_since_reliable == 0 ? 0 : (j+1)%Np; // if just done a reliable update then reset j
747  }
748 
// Final solution: x = xSloppy + y.
749  blas::copy(x, xSloppy);
750  blas::xpy(y, x);
751 
752  if (!param.is_preconditioner) {
755 
757  double gflops = (blas::flops + mat.flops() + matSloppy.flops() + matPrecon.flops() + matEig.flops()) * 1e-9;
758  param.gflops = gflops;
759  param.iter += k;
760 
761  if (k == param.maxiter) warningQuda("Exceeded maximum iterations %d", param.maxiter);
762  }
763 
764  if (getVerbosity() >= QUDA_VERBOSE)
765  printfQuda("CG: Reliable updates = %d\n", rUpdate);
766 
767  if (param.compute_true_res) {
768  // compute the true residuals
769  mat(r, x, y, tmp3);
770  param.true_res = sqrt(blas::xmyNorm(b, r) / b2);
772  }
773 
774  PrintSummary("CG", k, r2, b2, stop, param.tol_hq);
775 
776  if (!param.is_preconditioner) {
777  // reset the flops counters
// Return values are discarded; per the comment above, these calls reset the counters.
778  blas::flops = 0;
779  mat.flops();
780  matSloppy.flops();
781  matPrecon.flops();
782 
784  }
785 
787  }
788 
789 // use BlockCGrQ algorithm or BlockCG (with / without GS, see BLOCKCG_GS option)
790 #define BCGRQ 1
791 #if BCGRQ
// Block CG with QR-orthogonalised residuals (BCGrQ) over param.num_src right-hand
// sides stored as components of x and b. The header (listing line 792) is missing;
// the alternative branch at listing line 1167 declares
// `void CG::solve(ColorSpinorField& x, ColorSpinorField& b)`.
// NOTE(review): b2[], stop[] and converged[] are sized QUDA_MAX_MULTI_SHIFT but
// indexed up to param.num_src, which is not checked against that bound here.
793  #ifndef BLOCKSOLVER
794  errorQuda("QUDA_BLOCKSOLVER not built.");
795  #else
796 
// The condition guarding this error (line 797) is missing from this listing.
798  errorQuda("Not supported");
799 
800  profile.TPSTART(QUDA_PROFILE_INIT);
801 
802  using Eigen::MatrixXcd;
803 
804  // Check to see that we're not trying to invert on a zero-field source
805  //MW: it might be useful to check what to do here.
806  double b2[QUDA_MAX_MULTI_SHIFT];
807  double b2avg=0;
808  for(int i=0; i< param.num_src; i++){
809  b2[i]=blas::norm2(b.Component(i));
810  b2avg += b2[i];
811  if(b2[i] == 0){
812  profile.TPSTOP(QUDA_PROFILE_INIT);
// NOTE(review): a "Warning" text is passed to errorQuda; if errorQuda aborts,
// the assignments and return below are unreachable — confirm intent.
813  errorQuda("Warning: inverting on zero-field source - undefined for block solver\n");
814  x=b;
815  param.true_res = 0.0;
816  param.true_res_hq = 0.0;
817  return;
818  }
819  }
820 
821  b2avg = b2avg / param.num_src;
822 
// One-time allocation of the work fields (several allocation lines are missing).
824  if (!init) {
825  csParam.setPrecision(param.precision);
829 
830  // sloppy fields
831  csParam.setPrecision(param.precision_sloppy);
835  rSloppyp = ColorSpinorField::Create(csParam);
836  xSloppyp = ColorSpinorField::Create(csParam);
837  } else {
// rSloppy aliases r in this branch (its condition is on a missing line).
838  rSloppyp = rp;
840  }
841 
842  // temporary fields
844  if(!mat.isStaggered()) {
845  // tmp2 only needed for multi-gpu Wilson-like kernels
847  // additional high-precision temporary if Wilson and mixed-precision
848  csParam.setPrecision(param.precision);
849  tmp3p = (param.precision != param.precision_sloppy) ?
851  } else {
852  tmp3p = tmp2p = tmpp;
853  }
854 
855  init = true;
856  }
857 
// NOTE(review): no assignment to rnewp is visible in this block (the Create call
// is commented out). If rnewp is still null here, `*rnewp` below is undefined
// behaviour — confirm against the missing line 859.
858  if(!rnewp) {
860  csParam.setPrecision(param.precision_sloppy);
861  // ColorSpinorField *rpnew = ColorSpinorField::Create(csParam);
862  }
863 
864  ColorSpinorField &r = *rp;
865  ColorSpinorField &y = *yp;
866  ColorSpinorField &p = *pp;
867  ColorSpinorField &Ap = *App;
868  ColorSpinorField &rnew = *rnewp;
869  ColorSpinorField &tmp = *tmpp;
870  ColorSpinorField &tmp2 = *tmp2p;
871  ColorSpinorField &tmp3 = *tmp3p;
872  ColorSpinorField &rSloppy = *rSloppyp;
873  ColorSpinorField &xSloppy = param.use_sloppy_partial_accumulator ? *xSloppyp : x;
874 
875  // calculate residuals for all vectors
876  // and initialize r2 matrix
// r_i = b_i - mat x_i; r2 becomes the Hermitian Gram matrix (r_i, r_j).
877  double r2avg=0;
878  MatrixXcd r2(param.num_src, param.num_src);
879  for(int i=0; i<param.num_src; i++){
880  mat(r.Component(i), x.Component(i), y.Component(i));
881  r2(i,i) = blas::xmyNorm(b.Component(i), r.Component(i));
882  r2avg += r2(i,i).real();
883  printfQuda("r2[%i] %e\n", i, r2(i,i).real());
884  }
885  for(int i=0; i<param.num_src; i++){
886  for(int j=i+1; j < param.num_src; j++){
887  r2(i,j) = blas::cDotProduct(r.Component(i),r.Component(j));
888  r2(j,i) = std::conj(r2(i,j));
889  }
890  }
891 
892  blas::copy(rSloppy, r);
893  blas::copy(p, rSloppy);
894  blas::copy(rnew, rSloppy);
895 
// y keeps the initial guess when xSloppy is a separate accumulator; otherwise
// x itself accumulates and y starts at zero.
896  if (&x != &xSloppy) {
897  blas::copy(y, x);
898  blas::zero(xSloppy);
899  } else {
900  blas::zero(y);
901  }
902 
903  const bool use_heavy_quark_res =
904  (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) ? true : false;
905  if(use_heavy_quark_res) errorQuda("ERROR: heavy quark residual not supported in block solver");
906 
907  profile.TPSTOP(QUDA_PROFILE_INIT);
909 
910  double stop[QUDA_MAX_MULTI_SHIFT];
911 
912  for(int i = 0; i < param.num_src; i++){
913  stop[i] = stopping(param.tol, b2[i], param.residual_type); // stopping condition of solver
914  }
915 
916  // Eigen Matrices instead of scalars
917  MatrixXcd alpha = MatrixXcd::Zero(param.num_src,param.num_src);
918  MatrixXcd beta = MatrixXcd::Zero(param.num_src,param.num_src);
919  MatrixXcd C = MatrixXcd::Zero(param.num_src,param.num_src);
920  MatrixXcd S = MatrixXcd::Identity(param.num_src,param.num_src);
921  MatrixXcd pAp = MatrixXcd::Identity(param.num_src,param.num_src);
923 
924  #ifdef MWVERBOSE
925  MatrixXcd pTp = MatrixXcd::Identity(param.num_src,param.num_src);
926  #endif
927 
928 
929 
930 
931  //FIXME:reliable updates currently not implemented
932  /*
933  double rNorm[QUDA_MAX_MULTI_SHIFT];
934  double r0Norm[QUDA_MAX_MULTI_SHIFT];
935  double maxrx[QUDA_MAX_MULTI_SHIFT];
936  double maxrr[QUDA_MAX_MULTI_SHIFT];
937 
938  for(int i = 0; i < param.num_src; i++){
939  rNorm[i] = sqrt(r2(i,i).real());
940  r0Norm[i] = rNorm[i];
941  maxrx[i] = rNorm[i];
942  maxrr[i] = rNorm[i];
943  }
944  bool L2breakdown = false;
945  int rUpdate = 0;
946  int steps_since_reliable = 1;
947  */
948 
950  profile.TPSTART(QUDA_PROFILE_COMPUTE);
951  blas::flops = 0;
952 
953  int k = 0;
954 
955  PrintStats("CG", k, r2avg / param.num_src, b2avg, 0.);
956  bool allconverged = true;
957  bool converged[QUDA_MAX_MULTI_SHIFT];
958  for(int i=0; i<param.num_src; i++){
959  converged[i] = convergence(r2(i,i).real(), 0., stop[i], param.tol_hq);
960  allconverged = allconverged && converged[i];
961  }
962 
963  // Cholesky decomposition
// r2 = L L^H; C = L^H and Linv = C^{-1}.
964  MatrixXcd L = r2.llt().matrixL();
965  C = L.adjoint();
966  MatrixXcd Linv = C.inverse();
967 
968  #ifdef MWVERBOSE
969  std::cout << "r2\n " << r2 << std::endl;
970  std::cout << "L\n " << L.adjoint() << std::endl;
971  #endif
972 
973  // set p to QR decomposition of r
974  // temporary hack - use AC to pass matrix arguments to multiblas
// Intended: p = r * Linv, so that r = p * C with orthonormal p. This assumes the
// multi-caxpy convention y_j += sum_i AC[i*n+j] x_i — TODO confirm.
// AC is declared on a missing line and freed with delete[] at the end.
975  for(int i=0; i<param.num_src; i++){
976  blas::zero(p.Component(i));
977  for(int j=0;j<param.num_src; j++){
978  AC[i*param.num_src + j] = Linv(i,j);
979  }
980  }
981  blas::caxpy(AC,r,p);
982 
983  // set rsloppy to QR decomposition of r (p)
984  for(int i=0; i< param.num_src; i++){
985  blas::copy(rSloppy.Component(i), p.Component(i));
986  }
987 
988  #ifdef MWVERBOSE
989  for(int i=0; i<param.num_src; i++){
990  for(int j=0; j<param.num_src; j++){
991  pTp(i,j) = blas::cDotProduct(p.Component(i), p.Component(j));
992  }
993  }
994  std::cout << " pTp " << std::endl << pTp << std::endl;
995  std::cout << " L " << std::endl << L.adjoint() << std::endl;
996  std::cout << " C " << std::endl << C << std::endl;
997  #endif
998 
999  while ( !allconverged && k < param.maxiter ) {
1000  // apply matrix
1001  for(int i=0; i<param.num_src; i++){
1002  matSloppy(Ap.Component(i), p.Component(i), tmp.Component(i), tmp2.Component(i)); // tmp as tmp
1003  }
1004 
1005  // calculate pAp
// Hermitian block matrix pAp(i,j) = (p_i, A p_j); only the upper triangle is computed.
1006  for(int i=0; i<param.num_src; i++){
1007  for(int j=i; j < param.num_src; j++){
1008  pAp(i,j) = blas::cDotProduct(p.Component(i), Ap.Component(j));
1009  if (i!=j) pAp(j,i) = std::conj(pAp(i,j));
1010  }
1011  }
1012 
1013  // update Xsloppy
// alpha = pAp^{-1} C; xSloppy += p * alpha.
1014  alpha = pAp.inverse() * C;
1015  // temporary hack using AC
1016  for(int i=0; i<param.num_src; i++){
1017  for(int j=0;j<param.num_src; j++){
1018  AC[i*param.num_src + j] = alpha(i,j);
1019  }
1020  }
1021  blas::caxpy(AC,p,xSloppy);
1022 
1023  // update rSloppy
// rSloppy -= Ap * pAp^{-1} (rSloppy holds the orthonormal factor, so C is not applied here).
1024  beta = pAp.inverse();
1025  // temporary hack
1026  for(int i=0; i<param.num_src; i++){
1027  for(int j=0;j<param.num_src; j++){
1028  AC[i*param.num_src + j] = -beta(i,j);
1029  }
1030  }
1031  blas::caxpy(AC,Ap,rSloppy);
1032 
1033  // orthogonalize R
1034  // copy rSloppy to rnew as temporary
1035  for(int i=0; i< param.num_src; i++){
1036  blas::copy(rnew.Component(i), rSloppy.Component(i));
1037  }
// NOTE(review): the Gram matrix is built from r, which is only set before the
// loop, rather than from the freshly updated rnew/rSloppy. This is only correct
// when rSloppy aliases r (the rSloppyp = rp branch); in mixed precision it looks
// like a bug — confirm.
1038  for(int i=0; i<param.num_src; i++){
1039  for(int j=i; j < param.num_src; j++){
1040  r2(i,j) = blas::cDotProduct(r.Component(i),r.Component(j));
1041  if (i!=j) r2(j,i) = std::conj(r2(i,j));
1042  }
1043  }
1044  // Cholesky decomposition
1045  L = r2.llt().matrixL();// retrieve factor L in the decomposition
1046  S = L.adjoint();
1047  Linv = S.inverse();
1048  // temporary hack
// rSloppy = rnew * S^{-1}: re-orthonormalise the residual block.
1049  for(int i=0; i<param.num_src; i++){
1050  blas::zero(rSloppy.Component(i));
1051  for(int j=0;j<param.num_src; j++){
1052  AC[i*param.num_src + j] = Linv(i,j);
1053  }
1054  }
1055  blas::caxpy(AC,rnew,rSloppy);
1056 
1057  #ifdef MWVERBOSE
1058  for(int i=0; i<param.num_src; i++){
1059  for(int j=0; j<param.num_src; j++){
1060  pTp(i,j) = blas::cDotProduct(rSloppy.Component(i), rSloppy.Component(j));
1061  }
1062  }
1063  std::cout << " rTr " << std::endl << pTp << std::endl;
1064  std::cout << "QR" << S<< std::endl << "QP " << S.inverse()*S << std::endl;;
1065  #endif
1066 
1067  // update p
1068  // use rnew as temporary again for summing up
// p = rSloppy + p * S^H.
1069  for(int i=0; i<param.num_src; i++){
1070  blas::copy(rnew.Component(i),rSloppy.Component(i));
1071  }
1072  // temporary hack
1073  for(int i=0; i<param.num_src; i++){
1074  for(int j=0;j<param.num_src; j++){
1075  AC[i*param.num_src + j] = std::conj(S(j,i));
1076  }
1077  }
1078  blas::caxpy(AC,p,rnew);
1079  // set p = rnew
1080  for(int i=0; i < param.num_src; i++){
1081  blas::copy(p.Component(i),rnew.Component(i));
1082  }
1083 
1084  // update C
1085  C = S * C;
1086 
1087  #ifdef MWVERBOSE
1088  for(int i=0; i<param.num_src; i++){
1089  for(int j=0; j<param.num_src; j++){
1090  pTp(i,j) = blas::cDotProduct(p.Component(i), p.Component(j));
1091  }
1092  }
1093  std::cout << " pTp " << std::endl << pTp << std::endl;
1094  std::cout << "S " << S<< std::endl << "C " << C << std::endl;
1095  #endif
1096 
1097  // calculate the residuals for all sources
// Estimated |r_j|^2 is the squared norm of column j of C (uses the orthonormal residual basis).
1098  r2avg=0;
1099  for (int j=0; j<param.num_src; j++ ){
1100  r2(j,j) = C(0,j)*conj(C(0,j));
1101  for(int i=1; i < param.num_src; i++)
1102  r2(j,j) += C(i,j) * conj(C(i,j));
1103  r2avg += r2(j,j).real();
1104  }
1105 
1106  k++;
1107  PrintStats("CG", k, r2avg / param.num_src, b2avg, 0);
1108  // check convergence
1109  allconverged = true;
1110  for(int i=0; i<param.num_src; i++){
1111  converged[i] = convergence(r2(i,i).real(), 0, stop[i], param.tol_hq);
1112  allconverged = allconverged && converged[i];
1113  }
1114 
1115 
1116  }
1117 
// xSloppy += y.
// NOTE(review): when xSloppy is a separate accumulator (&x != &xSloppy), no copy
// of xSloppy back into x is visible here, so x keeps the initial guess and the
// true residual below is computed from it — compare CG::operator(), which does
// copy(x, xSloppy) before adding y.
1118  for(int i=0; i<param.num_src; i++){
1119  blas::xpy(y.Component(i), xSloppy.Component(i));
1120  }
1121 
1122  profile.TPSTOP(QUDA_PROFILE_COMPUTE);
1123  profile.TPSTART(QUDA_PROFILE_EPILOGUE);
1124 
1126  double gflops = (blas::flops + mat.flops() + matSloppy.flops())*1e-9;
1127  param.gflops = gflops;
1128  param.iter += k;
1129 
1130  if (k == param.maxiter)
1131  warningQuda("Exceeded maximum iterations %d", param.maxiter);
1132 
1133  // if (getVerbosity() >= QUDA_VERBOSE)
1134  // printfQuda("CG: Reliable updates = %d\n", rUpdate);
1135 
1136  // compute the true residuals
// NOTE(review): param.true_res is overwritten for each source, so only the last
// source's value survives (lines 1140-1142 are missing from this listing).
1137  for(int i=0; i<param.num_src; i++){
1138  mat(r.Component(i), x.Component(i), y.Component(i), tmp3.Component(i));
1139  param.true_res = sqrt(blas::xmyNorm(b.Component(i), r.Component(i)) / b2[i]);
1143 
1144  PrintSummary("CG", k, r2(i,i).real(), b2[i], stop[i], 0.0);
1145  }
1146 
1147  // reset the flops counters
1148  blas::flops = 0;
1149  mat.flops();
1150  matSloppy.flops();
1151 
1153  profile.TPSTART(QUDA_PROFILE_FREE);
1154 
1155  delete[] AC;
1156  profile.TPSTOP(QUDA_PROFILE_FREE);
1157 
1158  return;
1159 
1160  #endif
1161 }
1162 
1163 #else
1164 
1165 // Use Gram-Schmidt (QR) re-orthonormalization of the block search directions in block CG
 // (reported as "BCGdQ" below); without this define gamma stays the identity ("BCQ").
1166 #define BLOCKCG_GS 1
 // Block CG solver: solves mat * x_i = b_i for the param.num_src right-hand sides stored as
 // the components b.Component(i), iterating all of them together. The small
 // num_src x num_src coefficient matrices (r2, pAp, alpha, beta, gamma, sigma) are
 // Eigen::MatrixXcd; with use_block set, r2(i,j) holds <r_i, r_j>.
 // @param x  On entry its components are the initial guesses (the initial residual is formed
 //           from them); on exit they hold xSloppy + y, i.e. the accumulated solutions.
 // @param b  The right-hand sides; a zero component triggers errorQuda.
 // Side effects: param.iter is incremented by the iteration count, param.gflops is set, and
 // param.true_res is set inside a per-source loop, so it ends up holding the last source's value.
 // NOTE(review): this text is a Doxygen rendering in which some source lines were dropped (the
 // embedded line numbers skip); notes below mark where the visible code is incomplete.
 // Confirm against the real source file before relying on those spots.
1167 void CG::solve(ColorSpinorField& x, ColorSpinorField& b) {
1168  #ifndef BLOCKSOLVER
1169  errorQuda("QUDA_BLOCKSOLVER not built.");
1170  #else
1171  #ifdef BLOCKCG_GS
1172  printfQuda("BCGdQ Solver\n");
1173  #else
1174  printfQuda("BCQ Solver\n");
1175  #endif
 // use_block = true computes the full (off-diagonal) r2 and pAp matrices; when false their
 // off-diagonal entries are skipped/zeroed (see the use_block checks below).
1176  const bool use_block = true;
 // NOTE(review): the condition guarding this error is not visible in this listing (a source
 // line appears to be missing); as shown here the call is unconditional.
1178  errorQuda("Not supported");
1179
1180  profile.TPSTART(QUDA_PROFILE_INIT);
1181
1182  using Eigen::MatrixXcd;
 // NOTE(review): mPAP and mRR are never used in this function.
1183  MatrixXcd mPAP(param.num_src,param.num_src);
1184  MatrixXcd mRR(param.num_src,param.num_src);
1185
1186
1187  // Check to see that we're not trying to invert on a zero-field source
1188  //MW: it might be useful to check what to do here.
 // b2[i] = |b_i|^2 is reused for the stopping condition and for residual reporting.
 // NOTE(review): assumes param.num_src <= QUDA_MAX_MULTI_SHIFT; nothing here checks it.
1189  double b2[QUDA_MAX_MULTI_SHIFT];
1190  double b2avg=0;
1191  double r2avg=0;
1192  for(int i=0; i< param.num_src; i++){
1193  b2[i]=blas::norm2(b.Component(i));
1194  b2avg += b2[i];
1195  if(b2[i] == 0){
1196  profile.TPSTOP(QUDA_PROFILE_INIT);
 // NOTE(review): the message says "Warning" but errorQuda is used. If errorQuda aborts
 // (presumably), the fallback below (x=b, return) is unreachable -- confirm whether
 // warningQuda was intended.
1197  errorQuda("Warning: inverting on zero-field source\n");
1198  x=b;
1199  param.true_res = 0.0;
1200  param.true_res_hq = 0.0;
1201  return;
1202  }
1203  }
1204
1205  #ifdef MWVERBOSE
1206  MatrixXcd b2m(param.num_src,param.num_src);
1207  // just to check details of b
1208  for(int i=0; i<param.num_src; i++){
1209  for(int j=0; j<param.num_src; j++){
1210  b2m(i,j) = blas::cDotProduct(b.Component(i), b.Component(j));
1211  }
1212  }
1213  std::cout << "b2m\n" << b2m << std::endl;
1214  #endif
1215
1216  ColorSpinorParam csParam(x);
 // Work fields are allocated on the first call only (guarded by init).
 // NOTE(review): the creation of rp, yp, pp and App is not visible in this listing (lines appear
 // to be missing), although all of them are dereferenced below.
1217  if (!init) {
1218  csParam.setPrecision(param.precision);
1222
1223  // sloppy fields
1224  csParam.setPrecision(param.precision_sloppy);
1228  rSloppyp = ColorSpinorField::Create(csParam);
1229  xSloppyp = ColorSpinorField::Create(csParam);
1230  } else {
 // In this branch rSloppy aliases r (same field object).
1231  rSloppyp = rp;
1233  }
1234
1235  // temporary fields
1237  if(!mat.isStaggered()) {
1238  // tmp2 only needed for multi-gpu Wilson-like kernels
1240  // additional high-precision temporary if Wilson and mixed-precision
1241  csParam.setPrecision(param.precision);
 // NOTE(review): this conditional expression is truncated in this listing (its operands were dropped).
1242  tmp3p = (param.precision != param.precision_sloppy) ?
1244  } else {
 // Staggered operators: tmp2 and tmp3 alias tmp.
1245  tmp3p = tmp2p = tmpp;
1246  }
1247
1248  init = true;
1249  }
1250
 // NOTE(review): the statement allocating rnewp is not visible here. Only the precision is set and
 // the Create call is commented out, yet *rnewp is dereferenced below as pnew. Confirm that a
 // line was dropped from this listing.
1251  if(!rnewp) {
1253  csParam.setPrecision(param.precision_sloppy);
1254  // ColorSpinorField *rpnew = ColorSpinorField::Create(csParam);
1255  }
1256
 // Local references to the work fields; xSloppy aliases x unless use_sloppy_partial_accumulator is set.
1257  ColorSpinorField &r = *rp;
1258  ColorSpinorField &y = *yp;
1259  ColorSpinorField &p = *pp;
1260  ColorSpinorField &pnew = *rnewp;
1261  ColorSpinorField &Ap = *App;
1262  ColorSpinorField &tmp = *tmpp;
1263  ColorSpinorField &tmp2 = *tmp2p;
1264  ColorSpinorField &tmp3 = *tmp3p;
1265  ColorSpinorField &rSloppy = *rSloppyp;
1266  ColorSpinorField &xSloppy = param.use_sloppy_partial_accumulator ? *xSloppyp : x;
1267
1268  // const int i = 0; // MW: hack to be able to write Component(i) instead and try with i=0 for now
1269
 // Initial residual, step 1: r_i <- mat * x_i (y is passed as the operator's temporary here).
1270  for(int i=0; i<param.num_src; i++){
1271  mat(r.Component(i), x.Component(i), y.Component(i));
1272  }
1273
1274  // double r2[QUDA_MAX_MULTI_SHIFT];
 // Step 2: blas::xmyNorm(b_i, r_i) presumably sets r_i <- b_i - r_i and returns |r_i|^2,
 // which becomes the diagonal of r2.
1275  MatrixXcd r2(param.num_src,param.num_src);
1276  for(int i=0; i<param.num_src; i++){
1277  r2(i,i) = blas::xmyNorm(b.Component(i), r.Component(i));
1278  printfQuda("r2[%i] %e\n", i, r2(i,i).real());
1279  }
1280  if(use_block){
1281  // MW need to initialize the full r2 matrix here
 // Hermitian off-diagonal part: r2(i,j) = <r_i, r_j>, r2(j,i) = conj(r2(i,j)).
1282  for(int i=0; i<param.num_src; i++){
1283  for(int j=i+1; j<param.num_src; j++){
1284  r2(i,j) = blas::cDotProduct(r.Component(i), r.Component(j));
1285  r2(j,i) = std::conj(r2(i,j));
1286  }
1287  }
1288  }
1289
 // rSloppy <- r; the initial search directions p (and the copy pnew) are the residuals.
1290  blas::copy(rSloppy, r);
1291  blas::copy(p, rSloppy);
1292  blas::copy(pnew, rSloppy);
1293
 // y holds the part of the solution not kept in xSloppy. When xSloppy is a separate field,
 // y <- x and xSloppy restarts from zero; otherwise x itself accumulates and y starts at 0.
 // The final answer is formed as xSloppy + y after the loop.
1294  if (&x != &xSloppy) {
1295  blas::copy(y, x);
1296  blas::zero(xSloppy);
1297  } else {
1298  blas::zero(y);
1299  }
1300
1301  const bool use_heavy_quark_res =
1302  (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) ? true : false;
1303  bool heavy_quark_restart = false;
1304
1305  profile.TPSTOP(QUDA_PROFILE_INIT);
1306  profile.TPSTART(QUDA_PROFILE_PREAMBLE);
1307
1308  MatrixXcd r2_old(param.num_src, param.num_src);
1309  double heavy_quark_res[QUDA_MAX_MULTI_SHIFT] = {0.0}; // heavy quark residual
1310  double heavy_quark_res_old[QUDA_MAX_MULTI_SHIFT] = {0.0}; // heavy quark residual
1311  double stop[QUDA_MAX_MULTI_SHIFT];
1312
1313  for(int i = 0; i < param.num_src; i++){
1314  stop[i] = stopping(param.tol, b2[i], param.residual_type); // stopping condition of solver
1315  if (use_heavy_quark_res) {
1316  heavy_quark_res[i] = sqrt(blas::HeavyQuarkResidualNorm(x.Component(i), r.Component(i)).z);
1317  heavy_quark_res_old[i] = heavy_quark_res[i]; // heavy quark residual
1318  }
1319  }
1320  const int heavy_quark_check = param.heavy_quark_check; // how often to check the heavy quark residual
1321
 // Block coefficient matrices; gamma is only changed from the identity under BLOCKCG_GS.
1322  MatrixXcd alpha = MatrixXcd::Zero(param.num_src,param.num_src);
1323  MatrixXcd beta = MatrixXcd::Zero(param.num_src,param.num_src);
1324  MatrixXcd gamma = MatrixXcd::Identity(param.num_src,param.num_src);
1325  // gamma = gamma * 2.0;
1326
1327  MatrixXcd pAp(param.num_src, param.num_src);
1328  MatrixXcd pTp(param.num_src, param.num_src);
1329  int rUpdate = 0;
1330
 // Per-source residual norms tracked by the reliable-update logic.
1331  double rNorm[QUDA_MAX_MULTI_SHIFT];
1332  double r0Norm[QUDA_MAX_MULTI_SHIFT];
1333  double maxrx[QUDA_MAX_MULTI_SHIFT];
1334  double maxrr[QUDA_MAX_MULTI_SHIFT];
1335
1336  for(int i = 0; i < param.num_src; i++){
1337  rNorm[i] = sqrt(r2(i,i).real());
1338  r0Norm[i] = rNorm[i];
1339  maxrx[i] = rNorm[i];
1340  maxrr[i] = rNorm[i];
1341  }
1342
1343  double delta = param.delta;//MW: hack no reliable updates param.delta;
1344
1345  // this parameter determines how many consecutive reliable update
1346  // residual increases we tolerate before terminating the solver,
1347  // i.e., how long do we want to keep trying to converge
1348  const int maxResIncrease = (use_heavy_quark_res ? 0 : param.max_res_increase); // check if we reached the limit of our tolerance
1349  const int maxResIncreaseTotal = param.max_res_increase_total;
1350  // 0 means we have no tolerance
1351  // maybe we should expose this as a parameter
1352  const int hqmaxresIncrease = maxResIncrease + 1;
1353
1354  int resIncrease = 0;
1355  int resIncreaseTotal = 0;
1356  int hqresIncrease = 0;
1357
1358  // set this to true if maxResIncrease has been exceeded but when we use heavy quark residual we still want to continue the CG
1359  // only used if we use the heavy_quark_res
1360  bool L2breakdown = false;
1361
1363  profile.TPSTART(QUDA_PROFILE_COMPUTE);
1364  blas::flops = 0;
1365
1366  int k = 0;
1367
 // Initial report; r2avg is the sum (not the mean) of |r_i|^2 over sources.
1368  for(int i=0; i<param.num_src; i++){
1369  r2avg+=r2(i,i).real();
1370  }
1371  PrintStats("CG", k, r2avg, b2avg, heavy_quark_res[0]);
1372  int steps_since_reliable = 1;
1373  bool allconverged = true;
1374  bool converged[QUDA_MAX_MULTI_SHIFT];
1375  for(int i=0; i<param.num_src; i++){
1376  converged[i] = convergence(r2(i,i).real(), heavy_quark_res[i], stop[i], param.tol_hq);
1377  allconverged = allconverged && converged[i];
1378  }
 // sigma: Gram matrix of the updated residuals, filled inside the loop and used for beta.
1379  MatrixXcd sigma(param.num_src,param.num_src);
1380
1381  #ifdef BLOCKCG_GS
1382  // Gram-Schmidt: orthonormalize the initial search directions p (active when BLOCKCG_GS is defined).
 // Modified Gram-Schmidt: normalize p_i, then remove its component from every later p_j.
 // NOTE(review): no guard against n == 0 (e.g. linearly dependent right-hand sides).
1383
1384  for(int i=0; i < param.num_src; i++){
1385  double n = blas::norm2(p.Component(i));
1386  blas::ax(1/sqrt(n),p.Component(i));
1387  for(int j=i+1; j < param.num_src; j++) {
1388  std::complex<double> ri=blas::cDotProduct(p.Component(i),p.Component(j));
1389  blas::caxpy(-ri,p.Component(i),p.Component(j));
1390  }
1391  }
1392
 // gamma(i,j) = <p_i, pnew_j> for j >= i: the upper-triangular factor such that, in exact
 // arithmetic, pnew_j = sum_{i<=j} gamma(i,j) p_i (a QR factorization of pnew).
1393  gamma = MatrixXcd::Zero(param.num_src,param.num_src);
1394  for ( int i = 0; i < param.num_src; i++){
1395  for (int j=i; j < param.num_src; j++){
1396  gamma(i,j) = blas::cDotProduct(p.Component(i),pnew.Component(j));
1397  }
1398  }
1399  #endif
1400  // end of Gram-Schmidt section
1401
1402  #ifdef MWVERBOSE
1403  for(int i=0; i<param.num_src; i++){
1404  for(int j=0; j<param.num_src; j++){
1405  pTp(i,j) = blas::cDotProduct(p.Component(i), p.Component(j));
1406  }
1407  }
1408
1409  std::cout << " pTp " << std::endl << pTp << std::endl;
1410  std::cout << "QR" << gamma<< std::endl << "QP " << gamma.inverse()*gamma << std::endl;;
1411  #endif
 // Main iteration: runs until every source has converged or param.maxiter is reached.
1412  while ( !allconverged && k < param.maxiter ) {
 // Ap_i = matSloppy * p_i for every source.
1413  for(int i=0; i<param.num_src; i++){
1414  matSloppy(Ap.Component(i), p.Component(i), tmp.Component(i), tmp2.Component(i)); // tmp as tmp
1415  }
1416
1417
1418  bool breakdown = false;
1419  // FIXME: need to check breakdown
1420  // current implementation sets breakdown to true for pipelined CG if one rhs triggers breakdown
1421  // this is probably ok
1422
1423
1424  if (param.pipeline) {
1425  errorQuda("pipeline not implemented");
1426  } else {
 // Keep the previous residual Gram matrix (used for beta), then form pAp(i,j) = <p_i, A p_j>.
1427  r2_old = r2;
1428  for(int i=0; i<param.num_src; i++){
1429  for(int j=0; j < param.num_src; j++){
1430  if(use_block or i==j)
1431  pAp(i,j) = blas::cDotProduct(p.Component(i), Ap.Component(j));
1432  else
1433  pAp(i,j) = 0.;
1434  }
1435  }
1436
 // alpha = pAp^{-1} * (gamma^H)^{-1} * r2 (gamma is the identity without BLOCKCG_GS).
 // NOTE(review): uses explicit inverses; solving the systems (e.g. via an LU decomposition)
 // is usually numerically preferable.
1437  alpha = pAp.inverse() * gamma.adjoint().inverse() * r2;
1438  #ifdef MWVERBOSE
1439  std::cout << "alpha\n" << alpha << std::endl;
1440
1441  if(k==1){
1442  std::cout << "pAp " << std::endl <<pAp << std::endl;
1443  std::cout << "pAp^-1 " << std::endl <<pAp.inverse() << std::endl;
1444  std::cout << "r2 " << std::endl <<r2 << std::endl;
1445  std::cout << "alpha " << std::endl <<alpha << std::endl;
1446  std::cout << "pAp^-1r2" << std::endl << pAp.inverse()*r2 << std::endl;
1447  }
1448  #endif
1449  // here we are deploying the alternative beta computation
 // Residual update: rSloppy_i <- rSloppy_i - sum_j alpha(j,i) Ap_j.
1450  for(int i=0; i<param.num_src; i++){
1451  for(int j=0; j < param.num_src; j++){
1452
1453  blas::caxpy(-alpha(j,i), Ap.Component(j), rSloppy.Component(i));
1454  }
1455  }
1456  // MW need to calculate the full r2 matrix here, after update. Not sure how to do alternative sigma yet ...
 // NOTE(review): r2 is formed from r, not from rSloppy, which was just updated. The two are the
 // same field only when rSloppyp aliases rp (the else-branch of the allocation above); otherwise r
 // is stale here. Presumably rSloppy was intended.
1457  for(int i=0; i<param.num_src; i++){
1458  for(int j=0; j<param.num_src; j++){
1459  if(use_block or i==j)
1460  r2(i,j) = blas::cDotProduct(r.Component(i), r.Component(j));
1461  else
1462  r2(i,j) = 0.;
1463  }
1464  }
1465  sigma = r2;
1466  }
1467
1468
1469  bool updateX=false;
1470  bool updateR=false;
1471  // int updateX = (rNorm < delta*r0Norm && r0Norm <= maxrx) ? true : false;
1472  // int updateR = ((rNorm < delta*maxrr && r0Norm <= maxrr) || updateX) ? true : false;
1473  //
1474  // printfQuda("Checking reliable update %i %i\n",updateX,updateR);
1475  // reliable update conditions
 // NOTE(review): updateX/updateR are overwritten for every i, so only the last source decides.
1476  for(int i=0; i<param.num_src; i++){
1477  rNorm[i] = sqrt(r2(i,i).real());
1478  if (rNorm[i] > maxrx[i]) maxrx[i] = rNorm[i];
1479  if (rNorm[i] > maxrr[i]) maxrr[i] = rNorm[i];
1480  updateX = (rNorm[i] < delta * r0Norm[i] && r0Norm[i] <= maxrx[i]) ? true : false;
1481  updateR = ((rNorm[i] < delta * maxrr[i] && r0Norm[i] <= maxrr[i]) || updateX) ? true : false;
1482  }
 // Reliable updates are suppressed: both flags are forced to false whenever either is set, so the
 // reliable-update `else` branch further down is never taken in this version.
1483  if ( (updateR || updateX )) {
1484  // printfQuda("Suppressing reliable update %i %i\n",updateX,updateR);
1485  updateX=false;
1486  updateR=false;
1487  // printfQuda("Suppressing reliable update %i %i\n",updateX,updateR);
1488  }
1489
1490  if ( !(updateR || updateX )) {
1491
 // beta = gamma * r2_old^{-1} * sigma.
1492  beta = gamma * r2_old.inverse() * sigma;
1493  #ifdef MWVERBOSE
1494  std::cout << "beta\n" << beta << std::endl;
1495  #endif
1496  if (param.pipeline && !breakdown)
1497  errorQuda("pipeline not implemented");
1498
1499  else{
 // Solution update: xSloppy_i <- xSloppy_i + sum_j alpha(j,i) p_j.
1500  for(int i=0; i<param.num_src; i++){
1501  for(int j=0; j<param.num_src; j++){
1502  blas::caxpy(alpha(j,i),p.Component(j),xSloppy.Component(i));
1503  }
1504  }
1505
1506  // set to zero
1507  for(int i=0; i < param.num_src; i++){
1508  blas::ax(0,pnew.Component(i)); // do we need components here?
1509  }
1510  // add r
 // NOTE(review): pnew_i += r_i reads r rather than rSloppy; it has the same aliasing caveat as
 // the r2 computation above.
1511  for(int i=0; i<param.num_src; i++){
1512  // for(int j=0;j<param.num_src; j++){
1513  // order of updating p might be relevant here
1514  blas::axpy(1.0,r.Component(i),pnew.Component(i));
1515  // blas::axpby(rcoeff,rSloppy.Component(i),beta(i,j),p.Component(j));
1516  // }
1517  }
1518  // beta = beta * gamma.inverse();
 // pnew_i += sum_j beta(j,i) p_j. NOTE(review): rcoeff is unused (only commented-out code refers to it).
1519  for(int i=0; i<param.num_src; i++){
1520  for(int j=0;j<param.num_src; j++){
1521  double rcoeff= (j==0?1.0:0.0);
1522  // order of updating p might be relevant here
1523  blas::caxpy(beta(j,i),p.Component(j),pnew.Component(i));
1524  // blas::axpby(rcoeff,rSloppy.Component(i),beta(i,j),p.Component(j));
1525  }
1526  }
1527  // now need to do something with the p's
 // p <- pnew (pnew keeps the un-orthonormalized directions used for gamma below).
1528
1529  for(int i=0; i< param.num_src; i++){
1530  blas::copy(p.Component(i), pnew.Component(i));
1531  }
1532
1533
 // Re-orthonormalize p (modified Gram-Schmidt) and recompute the triangular factor gamma from pnew,
 // as in the setup before the loop. NOTE(review): no guard against n == 0.
1534  #ifdef BLOCKCG_GS
1535  for(int i=0; i < param.num_src; i++){
1536  double n = blas::norm2(p.Component(i));
1537  blas::ax(1/sqrt(n),p.Component(i));
1538  for(int j=i+1; j < param.num_src; j++) {
1539  std::complex<double> ri=blas::cDotProduct(p.Component(i),p.Component(j));
1540  blas::caxpy(-ri,p.Component(i),p.Component(j));
1541
1542  }
1543  }
1544
1545
1546  gamma = MatrixXcd::Zero(param.num_src,param.num_src);
1547  for ( int i = 0; i < param.num_src; i++){
1548  for (int j=i; j < param.num_src; j++){
1549  gamma(i,j) = blas::cDotProduct(p.Component(i),pnew.Component(j));
1550  }
1551  }
1552  #endif
1553
1554  #ifdef MWVERBOSE
1555  for(int i=0; i<param.num_src; i++){
1556  for(int j=0; j<param.num_src; j++){
1557  pTp(i,j) = blas::cDotProduct(p.Component(i), p.Component(j));
1558  }
1559  }
1560  std::cout << " pTp " << std::endl << pTp << std::endl;
1561  std::cout << "QR" << gamma<< std::endl << "QP " << gamma.inverse()*gamma << std::endl;;
1562  #endif
1563  }
1564
1565
 // Periodically recompute the heavy-quark residual.
 // NOTE(review): assumes param.heavy_quark_check > 0 whenever the heavy-quark residual is used.
1566  if (use_heavy_quark_res && (k % heavy_quark_check) == 0) {
1567  if (&x != &xSloppy) {
1568  blas::copy(tmp, y); // FIXME: check whether copy works here
1569  for(int i=0; i<param.num_src; i++){
1570  heavy_quark_res[i] = sqrt(blas::xpyHeavyQuarkResidualNorm(xSloppy.Component(i), tmp.Component(i), rSloppy.Component(i)).z);
1571  }
1572  } else {
1573  blas::copy(r, rSloppy); // FIXME: check whether copy works here
1574  for(int i=0; i<param.num_src; i++){
1575  heavy_quark_res[i] = sqrt(blas::xpyHeavyQuarkResidualNorm(x.Component(i), y.Component(i), r.Component(i)).z);
1576  }
1577  }
1578  }
1579
1580  steps_since_reliable++;
1581  } else {
 // Reliable-update branch. NOTE(review): unreachable while updateX/updateR are forced to false
 // above. It also uses only the diagonal of alpha and recomputes only the diagonal of r2.
1582  printfQuda("reliable update\n");
1583  for(int i=0; i<param.num_src; i++){
1584  blas::axpy(alpha(i,i).real(), p.Component(i), xSloppy.Component(i));
1585  }
1586  blas::copy(x, xSloppy); // nop when these pointers alias
1587
1588  for(int i=0; i<param.num_src; i++){
1589  blas::xpy(x.Component(i), y.Component(i)); // swap these around?
1590  }
1591  for(int i=0; i<param.num_src; i++){
1592  mat(r.Component(i), y.Component(i), x.Component(i), tmp3.Component(i)); // here we can use x as tmp
1593  }
1594  for(int i=0; i<param.num_src; i++){
1595  r2(i,i) = blas::xmyNorm(b.Component(i), r.Component(i));
1596  }
1597
1598  for(int i=0; i<param.num_src; i++){
1599  blas::copy(rSloppy.Component(i), r.Component(i)); //nop when these pointers alias
1600  blas::zero(xSloppy.Component(i));
1601  }
1602
1603  // calculate new reliable HQ residual
1604  if (use_heavy_quark_res){
1605  for(int i=0; i<param.num_src; i++){
1606  heavy_quark_res[i] = sqrt(blas::HeavyQuarkResidualNorm(y.Component(i), r.Component(i)).z);
1607  }
1608  }
1609
1610  // MW: FIXME as this probably goes terribly wrong right now
 // NOTE(review): the `break` below exits only this for loop, not the solver's while loop, so the
 // "solver exiting" warning does not actually terminate the iteration.
1611  for(int i = 0; i<param.num_src; i++){
1612  // break-out check if we have reached the limit of the precision
1613  if (sqrt(r2(i,i).real()) > r0Norm[i] && updateX) { // reuse r0Norm for this
1614  resIncrease++;
1615  resIncreaseTotal++;
1616  warningQuda("CG: new reliable residual norm %e is greater than previous reliable residual norm %e (total #inc %i)",
1617  sqrt(r2(i,i).real()), r0Norm[i], resIncreaseTotal);
1618  if ( resIncrease > maxResIncrease or resIncreaseTotal > maxResIncreaseTotal) {
1619  if (use_heavy_quark_res) {
1620  L2breakdown = true;
1621  } else {
1622  warningQuda("CG: solver exiting due to too many true residual norm increases");
1623  break;
1624  }
1625  }
1626  } else {
1627  resIncrease = 0;
1628  }
1629  }
1630  // if L2 broke down already we turn off reliable updates and restart the CG
 // NOTE(review): as above, this `break` leaves only the for loop, not the solver.
1631  for(int i = 0; i<param.num_src; i++){
1632  if (use_heavy_quark_res and L2breakdown) {
1633  delta = 0;
1634  warningQuda("CG: Restarting without reliable updates for heavy-quark residual");
1635  heavy_quark_restart = true;
1636  if (heavy_quark_res[i] > heavy_quark_res_old[i]) {
1637  hqresIncrease++;
1638  warningQuda("CG: new reliable HQ residual norm %e is greater than previous reliable residual norm %e", heavy_quark_res[i], heavy_quark_res_old[i]);
1639  // break out if we do not improve here anymore
1640  if (hqresIncrease > hqmaxresIncrease) {
1641  warningQuda("CG: solver exiting due to too many heavy quark residual norm increases");
1642  break;
1643  }
1644  }
1645  }
1646  }
1647
1648  for(int i=0; i<param.num_src; i++){
1649  rNorm[i] = sqrt(r2(i,i).real());
1650  maxrr[i] = rNorm[i];
1651  maxrx[i] = rNorm[i];
1652  r0Norm[i] = rNorm[i];
1653  heavy_quark_res_old[i] = heavy_quark_res[i];
1654  }
1655  rUpdate++;
1656
1657  if (use_heavy_quark_res and heavy_quark_restart) {
1658  // perform a restart
1659  blas::copy(p, rSloppy);
1660  heavy_quark_restart = false;
1661  } else {
1662  // explicitly restore the orthogonality of the gradient vector
1663  for(int i=0; i<param.num_src; i++){
1664  double rp = blas::reDotProduct(rSloppy.Component(i), p.Component(i)) / (r2(i,i).real());
1665  blas::axpy(-rp, rSloppy.Component(i), p.Component(i));
1666
1667  beta(i,i) = r2(i,i) / r2_old(i,i);
1668  blas::xpay(rSloppy.Component(i), beta(i,i).real(), p.Component(i));
1669  }
1670  }
1671
1672  steps_since_reliable = 0;
1673  }
1674
1675  breakdown = false;
1676  k++;
1677
 // Per-source convergence check and progress report; r2avg is again the sum over sources.
1678  allconverged = true;
1679  r2avg=0;
1680  for(int i=0; i<param.num_src; i++){
1681  r2avg+= r2(i,i).real();
1682  // check convergence, if convergence is satisfied we only need to check that we had a reliable update for the heavy quarks recently
1683  converged[i] = convergence(r2(i,i).real(), heavy_quark_res[i], stop[i], param.tol_hq);
1684  allconverged = allconverged && converged[i];
1685  }
1686  PrintStats("CG", k, r2avg, b2avg, heavy_quark_res[0]);
1687
1688  // check for recent enough reliable updates of the HQ residual if we use it
 // NOTE(review): converged[i] is recomputed below, but allconverged (the loop condition) is not,
 // so this heavy-quark check has no effect on termination.
1689  if (use_heavy_quark_res) {
1690  for(int i=0; i<param.num_src; i++){
1691  // L2 is converged or precision maxed out for L2
1692  bool L2done = L2breakdown or convergenceL2(r2(i,i).real(), heavy_quark_res[i], stop[i], param.tol_hq);
1693  // HQ is converged and if we do reliable update the HQ residual has been calculated using a reliable update
1694  bool HQdone = (steps_since_reliable == 0 and param.delta > 0) and convergenceHQ(r2(i,i).real(), heavy_quark_res[i], stop[i], param.tol_hq);
1695  converged[i] = L2done and HQdone;
1696  }
1697  }
1698
1699  }
1700
 // Final solution: x <- xSloppy, then x_i += y_i.
1701  blas::copy(x, xSloppy);
1702  for(int i=0; i<param.num_src; i++){
1703  blas::xpy(y.Component(i), x.Component(i));
1704  }
1705
1706  profile.TPSTOP(QUDA_PROFILE_COMPUTE);
1707  profile.TPSTART(QUDA_PROFILE_EPILOGUE);
1708
1710  double gflops = (blas::flops + mat.flops() + matSloppy.flops())*1e-9;
1711  param.gflops = gflops;
1712  param.iter += k;
1713
1714  if (k == param.maxiter)
1715  warningQuda("Exceeded maximum iterations %d", param.maxiter);
1716
1717  if (getVerbosity() >= QUDA_VERBOSE)
1718  printfQuda("CG: Reliable updates = %d\n", rUpdate);
1719
1720  // compute the true residuals
 // NOTE(review): param.true_res is overwritten for every source, so only the last source's
 // relative true residual survives.
1721  for(int i=0; i<param.num_src; i++){
1722  mat(r.Component(i), x.Component(i), y.Component(i), tmp3.Component(i));
1723  param.true_res = sqrt(blas::xmyNorm(b.Component(i), r.Component(i)) / b2[i]);
1727
1728  PrintSummary("CG", k, r2(i,i).real(), b2[i], stop[i], 0.0);
1729  }
1730
1731  // reset the flops counters
1732  blas::flops = 0;
1733  mat.flops();
1734  matSloppy.flops();
1735
1737  profile.TPSTART(QUDA_PROFILE_FREE);
1738
1739  profile.TPSTOP(QUDA_PROFILE_FREE);
1740
1741  return;
1742
1743  #endif
1744
1745 }
1746 #endif
1747 
1748 
1749 } // namespace quda
Conjugate-Gradient Solver.
Definition: invert_quda.h:639
void operator()(ColorSpinorField &out, ColorSpinorField &in)
Run CG.
Definition: invert_quda.h:656
void blocksolve(ColorSpinorField &out, ColorSpinorField &in)
CG(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, const DiracMatrix &matEig, SolverParam &param, TimeProfile &profile)
Definition: inv_cg_quda.cpp:19
virtual ~CG()
Definition: inv_cg_quda.cpp:36
void operator()(ColorSpinorField &out, ColorSpinorField &in)
Run CG.
Definition: inv_cg_quda.cpp:84
CGNE(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, const DiracMatrix &matEig, SolverParam &param, TimeProfile &profile)
Definition: inv_cg_quda.cpp:62
virtual ~CGNE()
Definition: inv_cg_quda.cpp:75
void operator()(ColorSpinorField &out, ColorSpinorField &in)
Run CG.
virtual ~CGNR()
CGNR(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, const DiracMatrix &matEig, SolverParam &param, TimeProfile &profile)
static ColorSpinorField * Create(const ColorSpinorParam &param)
ColorSpinorField & Component(const int idx) const
virtual void M(ColorSpinorField &out, const ColorSpinorField &in) const =0
Apply M for the dirac op. E.g. the Schur Complement operator.
void Mdag(ColorSpinorField &out, const ColorSpinorField &in) const
Apply Mdag (the daggered operator of M).
Definition: dirac.cpp:92
const Dirac * Expose() const
Definition: dirac_quda.h:1964
bool isStaggered() const
return if the operator is a staggered operator
Definition: dirac_quda.h:1935
unsigned long long flops() const
Definition: dirac_quda.h:1909
void deflate(std::vector< ColorSpinorField * > &sol, const std::vector< ColorSpinorField * > &src, const std::vector< ColorSpinorField * > &evecs, const std::vector< Complex > &evals, bool accumulate=false) const
Deflate a set of source vectors with a given eigenspace.
void computeEvals(const DiracMatrix &mat, std::vector< ColorSpinorField * > &evecs, std::vector< Complex > &evals, int size)
Compute eigenvalues and their residuals.
QudaPrecision Precision() const
double precisionEpsilon(QudaPrecision prec=QUDA_INVALID_PRECISION) const
Returns the epsilon tolerance for a given precision, by default returns the solver precision.
Definition: solver.cpp:412
bool deflate_compute
Definition: invert_quda.h:475
TimeProfile & profile
Definition: invert_quda.h:471
bool convergenceL2(double r2, double hq2, double r2_tol, double hq_tol)
Test for L2 solver convergence – ignore HQ residual.
Definition: solver.cpp:361
const DiracMatrix & mat
Definition: invert_quda.h:465
bool convergence(double r2, double hq2, double r2_tol, double hq_tol)
Definition: solver.cpp:328
bool recompute_evals
Definition: invert_quda.h:476
std::vector< ColorSpinorField * > evecs
Definition: invert_quda.h:477
bool convergenceHQ(double r2, double hq2, double r2_tol, double hq_tol)
Test for HQ solver convergence – ignore L2 residual.
Definition: solver.cpp:348
void destroyDeflationSpace()
Destroy the allocated deflation space.
Definition: solver.cpp:229
void PrintSummary(const char *name, int k, double r2, double b2, double r2_tol, double hq_tol)
Prints out the summary of the solver convergence (requires a verbosity of QUDA_SUMMARIZE)....
Definition: solver.cpp:386
const DiracMatrix & matEig
Definition: invert_quda.h:468
SolverParam & param
Definition: invert_quda.h:470
static double stopping(double tol, double b2, QudaResidualType residual_type)
Set the solver L2 stopping condition.
Definition: solver.cpp:311
std::vector< Complex > evals
Definition: invert_quda.h:478
EigenSolver * eig_solve
Definition: invert_quda.h:473
void PrintStats(const char *name, int k, double r2, double b2, double hq2)
Prints out the running statistics of the solver (requires a verbosity of QUDA_VERBOSE)
Definition: solver.cpp:373
void constructDeflationSpace(const ColorSpinorField &meta, const DiracMatrix &mat)
Constructs the deflation space and eigensolver.
Definition: solver.cpp:168
const DiracMatrix & matPrecon
Definition: invert_quda.h:467
const DiracMatrix & matSloppy
Definition: invert_quda.h:466
double Last(QudaProfileType idx)
Definition: timer.h:254
void commGlobalReductionSet(bool global_reduce)
bool alternative_reliable
void mat(void *out, void **link, void *in, int dagger_bit, int mu, QudaPrecision sPrecision, QudaPrecision gPrecision)
cudaColorSpinorField * tmp
Definition: covdev_test.cpp:34
@ QUDA_CUDA_FIELD_LOCATION
Definition: enum_quda.h:326
@ QUDA_USE_INIT_GUESS_NO
Definition: enum_quda.h:429
@ QUDA_USE_INIT_GUESS_YES
Definition: enum_quda.h:430
@ QUDA_DEBUG_VERBOSE
Definition: enum_quda.h:268
@ QUDA_VERBOSE
Definition: enum_quda.h:267
@ QUDA_HEAVY_QUARK_RESIDUAL
Definition: enum_quda.h:195
@ QUDA_PRESERVE_SOURCE_NO
Definition: enum_quda.h:238
@ QUDA_PRESERVE_SOURCE_YES
Definition: enum_quda.h:239
@ QUDA_DOUBLE_PRECISION
Definition: enum_quda.h:65
@ QUDA_ZERO_FIELD_CREATE
Definition: enum_quda.h:361
@ QUDA_COPY_FIELD_CREATE
Definition: enum_quda.h:362
@ QUDA_NULL_FIELD_CREATE
Definition: enum_quda.h:360
@ QUDA_COMPUTE_NULL_VECTOR_NO
Definition: enum_quda.h:441
Matrix< N, std::complex< T > > conj(const Matrix< N, std::complex< T > > &mat)
#define checkPrecision(...)
#define checkLocation(...)
void init()
Create the BLAS context.
double4 quadrupleCGReduction(ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
Complex axpyCGNorm(double a, ColorSpinorField &x, ColorSpinorField &y)
void axpyZpbx(double a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, double b)
void xpayz(ColorSpinorField &x, double a, ColorSpinorField &y, ColorSpinorField &z)
Definition: blas_quda.h:46
double3 HeavyQuarkResidualNorm(ColorSpinorField &x, ColorSpinorField &r)
double xmyNorm(ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.h:79
double3 tripleCGReduction(ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
unsigned long long flops
void xpay(ColorSpinorField &x, double a, ColorSpinorField &y)
Definition: blas_quda.h:45
double3 xpyHeavyQuarkResidualNorm(ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &r)
void ax(double a, ColorSpinorField &x)
void zero(ColorSpinorField &a)
double norm2(const ColorSpinorField &a)
void tripleCGUpdate(double alpha, double beta, ColorSpinorField &q, ColorSpinorField &r, ColorSpinorField &x, ColorSpinorField &p)
double reDotProduct(ColorSpinorField &x, ColorSpinorField &y)
double axpyNorm(double a, ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.h:78
void axpy(double a, ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.h:43
double3 cDotProductNormA(ColorSpinorField &a, ColorSpinorField &b)
void caxpy(const Complex &a, ColorSpinorField &x, ColorSpinorField &y)
void xpy(ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.h:41
void copy(ColorSpinorField &dst, const ColorSpinorField &src)
Definition: blas_quda.h:24
Complex cDotProduct(ColorSpinorField &, ColorSpinorField &)
void axpby(double a, ColorSpinorField &x, double b, ColorSpinorField &y)
Definition: blas_quda.h:44
void stop()
Stop profiling.
Definition: device.cpp:228
__host__ __device__ ValueType conj(ValueType x)
Definition: complex_quda.h:130
std::complex< double > Complex
Definition: quda_internal.h:86
__host__ __device__ ValueType sqrt(ValueType x)
Definition: complex_quda.h:120
@ QUDA_PROFILE_INIT
Definition: timer.h:106
@ QUDA_PROFILE_EPILOGUE
Definition: timer.h:110
@ QUDA_PROFILE_COMPUTE
Definition: timer.h:108
@ QUDA_PROFILE_FREE
Definition: timer.h:111
@ QUDA_PROFILE_PREAMBLE
Definition: timer.h:107
ColorSpinorParam csParam
Definition: pack_test.cpp:25
QudaGaugeParam param
Definition: pack_test.cpp:18
void updateR()
update the radius for halos.
#define QUDA_MAX_MULTI_SHIFT
Maximum number of shifts supported by the multi-shift solver. This number may be changed if need be.
QudaPreserveSource preserve_source
Definition: invert_quda.h:151
QudaPrecision precision
Definition: invert_quda.h:136
QudaComputeNullVector compute_null_vector
Definition: invert_quda.h:61
bool is_preconditioner
verbosity to use for preconditioner
Definition: invert_quda.h:238
bool use_sloppy_partial_accumulator
Definition: invert_quda.h:70
int max_res_increase_total
Definition: invert_quda.h:90
QudaResidualType residual_type
Definition: invert_quda.h:49
bool use_alternative_reliable
Definition: invert_quda.h:67
QudaPrecision precision_sloppy
Definition: invert_quda.h:139
double true_res_offset[QUDA_MAX_MULTI_SHIFT]
Definition: invert_quda.h:178
int solution_accumulator_pipeline
Definition: invert_quda.h:80
double true_res_hq_offset[QUDA_MAX_MULTI_SHIFT]
Definition: invert_quda.h:184
QudaUseInitGuess use_init_guess
Definition: invert_quda.h:58
int max_hq_res_restart_total
Definition: invert_quda.h:100
bool global_reduction
whether the solver acting as a preconditioner for another solver
Definition: invert_quda.h:240
#define printfQuda(...)
Definition: util_quda.h:114
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:21
#define warningQuda(...)
Definition: util_quda.h:132
#define errorQuda(...)
Definition: util_quda.h:120