QUDA  1.0.0
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
inv_cg_quda.cpp
Go to the documentation of this file.
1 #include <cstdio>
2 #include <cstdlib>
3 #include <cmath>
4 #include <limits>
5 #include <memory>
6 #include <iostream>
7 
8 #ifdef BLOCKSOLVER
9 #include <Eigen/Dense>
10 #endif
11 
12 #include <quda_internal.h>
13 #include <color_spinor_field.h>
14 #include <blas_quda.h>
15 #include <dslash_quda.h>
16 #include <invert_quda.h>
17 #include <util_quda.h>
18 #include <eigensolve_quda.h>
19 
20 namespace quda {
21 
23  Solver(param, profile), mat(mat), matSloppy(matSloppy), yp(nullptr), rp(nullptr),
24  rnewp(nullptr), pp(nullptr), App(nullptr), tmpp(nullptr), tmp2p(nullptr), tmp3p(nullptr),
25  rSloppyp(nullptr), xSloppyp(nullptr), init(false)
26  {
27  }
28 
30  {
31  profile.TPSTART(QUDA_PROFILE_FREE);
32  if ( init ) {
33  for (auto pi : p) if (pi) delete pi;
34  if (rp) delete rp;
35  if (pp) delete pp;
36  if (yp) delete yp;
37  if (App) delete App;
39  if (rSloppyp) delete rSloppyp;
40  if (xSloppyp) delete xSloppyp;
41  }
42  if (tmpp) delete tmpp;
43  if (!mat.isStaggered()) {
44  if (tmp2p && tmpp != tmp2p) delete tmp2p;
45  if (tmp3p && tmpp != tmp3p && param.precision != param.precision_sloppy) delete tmp3p;
46  }
47  if (rnewp) delete rnewp;
48  init = false;
49 
50  if (deflate_init) {
51  for (auto veci : param.evecs)
52  if (veci) delete veci;
53  delete defl_tmp1[0];
54  delete defl_tmp2[0];
55  }
56  }
57  profile.TPSTOP(QUDA_PROFILE_FREE);
58  }
59 
61  CG(mmdag, mmdagSloppy, param, profile), mmdag(mat.Expose()), mmdagSloppy(matSloppy.Expose()),
62  xp(nullptr), yp(nullptr), init(false) {
63  }
64 
66  if ( init ) {
67  if (xp) delete xp;
68  if (yp) delete yp;
69  init = false;
70  }
71  }
72 
73  // CGNE: M Mdag y = b is solved; x = Mdag y is returned as solution.
75  if (param.maxiter == 0 || param.Nsteps == 0) {
77  return;
78  }
79 
80  const int iter0 = param.iter;
81 
82  if (!init) {
85  xp = ColorSpinorField::Create(x, csParam);
87  yp = ColorSpinorField::Create(x, csParam);
88  init = true;
89  }
90 
91  double b2 = blas::norm2(b);
92 
94 
95  // compute initial residual
96  mmdag.Expose()->M(*xp,x);
97  double r2 = blas::xmyNorm(b,*xp);
98  if (b2 == 0.0) b2 = r2;
99 
100  // compute solution to residual equation
101  CG::operator()(*yp,*xp);
102 
103  mmdag.Expose()->Mdag(*xp,*yp);
104 
105  // compute full solution
106  blas::xpy(*xp, x);
107 
108  } else {
109 
110  CG::operator()(*yp,b);
111  mmdag.Expose()->Mdag(x,*yp);
112 
113  }
114 
115  // future optimization: with preserve_source == QUDA_PRESERVE_SOURCE_NO; b is already
116  // expected to be the CG residual which matches the CGNE residual
117  // (but only with zero initial guess). at the moment, CG does not respect this convention
119 
120  // compute the true residual
121  mmdag.Expose()->M(*xp, x);
122 
125  blas::axpby(-1.0, A, 1.0, B);
126 
127  double r2;
129  double3 h3 = blas::HeavyQuarkResidualNorm(x, B);
130  r2 = h3.y;
131  param.true_res_hq = sqrt(h3.z);
132  } else {
133  r2 = blas::norm2(B);
134  }
135  param.true_res = sqrt(r2 / b2);
136 
137  PrintSummary("CGNE", param.iter - iter0, r2, b2, stopping(param.tol, b2, param.residual_type), param.tol_hq);
138  }
139 
140  }
141 
143  CG(mdagm, mdagmSloppy, param, profile), mdagm(mat.Expose()), mdagmSloppy(matSloppy.Expose()),
144  bp(nullptr), init(false) {
145  }
146 
148  if ( init ) {
149  if (bp) delete bp;
150  init = false;
151  }
152  }
153 
154  // CGNR: Mdag M x = Mdag b is solved.
156  if (param.maxiter == 0 || param.Nsteps == 0) {
158  return;
159  }
160 
161  const int iter0 = param.iter;
162 
163  if (!init) {
165  csParam.create = QUDA_ZERO_FIELD_CREATE;
166  bp = ColorSpinorField::Create(csParam);
167  init = true;
168  }
169 
170  double b2 = blas::norm2(b);
171  if (b2 == 0.0) { // compute initial residual vector
172  mdagm.Expose()->M(*bp,x);
173  b2 = blas::norm2(*bp);
174  }
175 
176  mdagm.Expose()->Mdag(*bp,b);
177  CG::operator()(x,*bp);
178 
180 
181  // compute the true residual
182  mdagm.Expose()->M(*bp, x);
183 
186  blas::axpby(-1.0, A, 1.0, B);
187 
188  double r2;
190  double3 h3 = blas::HeavyQuarkResidualNorm(x, B);
191  r2 = h3.y;
192  param.true_res_hq = sqrt(h3.z);
193  } else {
194  r2 = blas::norm2(B);
195  }
196  param.true_res = sqrt(r2 / b2);
197  PrintSummary("CGNR", param.iter - iter0, r2, b2, stopping(param.tol, b2, param.residual_type), param.tol_hq);
198 
200  mdagm.Expose()->M(*bp, x);
201  blas::axpby(-1.0, *bp, 1.0, b);
202  }
203 
204  }
205 
206  void CG::operator()(ColorSpinorField &x, ColorSpinorField &b, ColorSpinorField *p_init, double r2_old_init)
207  {
209  errorQuda("Not supported");
210  if (checkPrecision(x, b) != param.precision)
211  errorQuda("Precision mismatch: expected=%d, received=%d", param.precision, x.Precision());
212 
213  if (param.maxiter == 0 || param.Nsteps == 0) {
215  return;
216  }
217 
219  if (Np < 0 || Np > 16) errorQuda("Invalid value %d for solution_accumulator_pipeline\n", Np);
220 
221  // whether to select alternative reliable updates
223 
224  profile.TPSTART(QUDA_PROFILE_INIT);
225 
226  // Check to see that we're not trying to invert on a zero-field source
227  double b2 = blas::norm2(b);
228 
229  // Check to see that we're not trying to invert on a zero-field source
231  profile.TPSTOP(QUDA_PROFILE_INIT);
232  printfQuda("Warning: inverting on zero-field source\n");
233  x = b;
234  param.true_res = 0.0;
235  param.true_res_hq = 0.0;
236  return;
237  }
238 
239  if (!init) {
241  csParam.create = QUDA_NULL_FIELD_CREATE;
242  rp = ColorSpinorField::Create(csParam);
243  yp = ColorSpinorField::Create(csParam);
244 
245  // sloppy fields
247  App = ColorSpinorField::Create(csParam);
251  } else {
252  rSloppyp = rp;
254  }
255 
256  // temporary fields
257  tmpp = ColorSpinorField::Create(csParam);
258  if(!mat.isStaggered()) {
259  // tmp2 only needed for multi-gpu Wilson-like kernels
260  tmp2p = ColorSpinorField::Create(csParam);
261  // additional high-precision temporary if Wilson and mixed-precision
262  csParam.setPrecision(param.precision);
264  ColorSpinorField::Create(csParam) : tmpp;
265  } else {
266  tmp3p = tmp2p = tmpp;
267  }
268 
269  init = true;
270  }
271 
272  // Once the CG operator is called, we are able to construct an appropriate
273  // Krylov space for deflation
274  if (param.deflate && !deflate_init) { constructDeflationSpace(b, mat, false); }
275 
276  ColorSpinorField &r = *rp;
277  ColorSpinorField &y = *yp;
278  ColorSpinorField &Ap = *App;
281  ColorSpinorField &tmp3 = *tmp3p;
282  ColorSpinorField &rSloppy = *rSloppyp;
284 
285  {
287  csParam.create = QUDA_NULL_FIELD_CREATE;
289 
290  if (Np != (int)p.size()) {
291  for (auto &pi : p) delete pi;
292  p.resize(Np);
293  for (auto &pi : p) pi = ColorSpinorField::Create(csParam);
294  }
295  }
296 
297  // alternative reliable updates
298  // alternative reliable updates - set precision - does not hurt performance here
299 
301  const double uhigh= param.precision == 8 ? std::numeric_limits<double>::epsilon()/2. : ((param.precision == 4) ? std::numeric_limits<float>::epsilon()/2. : pow(2.,-13));
302  const double deps=sqrt(u);
303  constexpr double dfac = 1.1;
304  double d_new = 0;
305  double d = 0;
306  double dinit = 0;
307  double xNorm = 0;
308  double xnorm = 0;
309  double pnorm = 0;
310  double ppnorm = 0;
311  double Anorm = 0;
312  double beta = 0.0;
313 
314  // for alternative reliable updates
315  if(alternative_reliable){
316  // estimate norm for reliable updates
317  mat(r, b, y, tmp3);
318  Anorm = sqrt(blas::norm2(r)/b2);
319  }
320 
321  // compute initial residual
322  double r2 = 0.0;
324  // Compute r = b - A * x
325  mat(r, x, y, tmp3);
326  r2 = blas::xmyNorm(b, r);
327  if (b2 == 0) b2 = r2;
328  // y contains the original guess.
329  blas::copy(y, x);
330  } else {
331  if (&r != &b) blas::copy(r, b);
332  r2 = b2;
333  blas::zero(y);
334  }
335 
336  if (param.deflate == true) {
337  std::vector<ColorSpinorField *> rhs;
338  // Use residual from supplied guess r, or original
339  // rhs b. use `x` as a temp.
340  blas::copy(x, r);
341  rhs.push_back(&x);
342 
343  // Deflate
345 
346  // Compute r_defl = RHS - A * LHS
347  mat(r, *defl_tmp1[0], tmp2, tmp3);
348  r2 = blas::xmyNorm(*rhs[0], r);
349 
351  // defl_tmp1 and y must be added to the solution at the end
352  blas::axpy(1.0, *defl_tmp1[0], y);
353  } else {
354  // Just add defl_tmp1 to y, which has been zeroed out
355  blas::copy(y, *defl_tmp1[0]);
356  }
357  }
358 
359  blas::zero(x);
360  if (&x != &xSloppy) blas::zero(xSloppy);
361  blas::copy(rSloppy,r);
362 
363  if (Np != (int)p.size()) {
364  for (auto &pi : p) delete pi;
365  p.resize(Np);
366  ColorSpinorParam csParam(rSloppy);
367  csParam.create = QUDA_COPY_FIELD_CREATE;
368  for (auto &pi : p)
369  pi = p_init ? ColorSpinorField::Create(*p_init, csParam) : ColorSpinorField::Create(rSloppy, csParam);
370  } else {
371  for (auto &p_i : p) *p_i = p_init ? *p_init : rSloppy;
372  }
373 
374  double r2_old=0.0;
375  if (r2_old_init != 0.0 and p_init) {
376  r2_old = r2_old_init;
377  Complex rp = blas::cDotProduct(rSloppy, *p[0]) / (r2);
378  blas::caxpy(-rp, rSloppy, *p[0]);
379  beta = r2 / r2_old;
380  blas::xpayz(rSloppy, beta, *p[0], *p[0]);
381  }
382 
383  const bool use_heavy_quark_res =
384  (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) ? true : false;
385  bool heavy_quark_restart = false;
386 
387  profile.TPSTOP(QUDA_PROFILE_INIT);
389 
390  double stop = stopping(param.tol, b2, param.residual_type); // stopping condition of solver
391 
392  double heavy_quark_res = 0.0; // heavy quark res idual
393  double heavy_quark_res_old = 0.0; // heavy quark residual
394 
395  if (use_heavy_quark_res) {
396  heavy_quark_res = sqrt(blas::HeavyQuarkResidualNorm(x, r).z);
397  heavy_quark_res_old = heavy_quark_res; // heavy quark residual
398  }
399  const int heavy_quark_check = param.heavy_quark_check; // how often to check the heavy quark residual
400 
401  double alpha[Np];
402  double pAp;
403  int rUpdate = 0;
404 
405  double rNorm = sqrt(r2);
406  double r0Norm = rNorm;
407  double maxrx = rNorm;
408  double maxrr = rNorm;
409  double delta = param.delta;
410 
411  // this parameter determines how many consective reliable update
412  // residual increases we tolerate before terminating the solver,
413  // i.e., how long do we want to keep trying to converge
414  const int maxResIncrease = param.max_res_increase; // check if we reached the limit of our tolerance
415  const int maxResIncreaseTotal = param.max_res_increase_total;
416 
417  // this means when using heavy quarks we will switch to simple hq restarts as soon as the reliable strategy fails
418  const int hqmaxresIncrease = param.max_hq_res_increase;
419  const int hqmaxresRestartTotal
420  = param.max_hq_res_restart_total; // this limits the number of heavy quark restarts we can do
421 
422  int resIncrease = 0;
423  int resIncreaseTotal = 0;
424  int hqresIncrease = 0;
425  int hqresRestartTotal = 0;
426 
427  // set this to true if maxResIncrease has been exceeded but when we use heavy quark residual we still want to continue the CG
428  // only used if we use the heavy_quark_res
429  bool L2breakdown = false;
430  const double L2breakdown_eps = 100. * uhigh;
431 
433  profile.TPSTART(QUDA_PROFILE_COMPUTE);
434  blas::flops = 0;
435 
436  int k = 0;
437  int j = 0;
438 
439  PrintStats("CG", k, r2, b2, heavy_quark_res);
440 
441  int steps_since_reliable = 1;
442  bool converged = convergence(r2, heavy_quark_res, stop, param.tol_hq);
443 
444  // alternative reliable updates
445  if(alternative_reliable){
446  dinit = uhigh * (rNorm + Anorm * xNorm);
447  d = dinit;
448  }
449 
450  while ( !converged && k < param.maxiter ) {
451  matSloppy(Ap, *p[j], tmp, tmp2); // tmp as tmp
452  double sigma;
453 
454  bool breakdown = false;
455  if (param.pipeline) {
456  double Ap2;
457  //TODO: alternative reliable updates - need r2, Ap2, pAp, p norm
458  if(alternative_reliable){
459  double4 quadruple = blas::quadrupleCGReduction(rSloppy, Ap, *p[j]);
460  r2 = quadruple.x; Ap2 = quadruple.y; pAp = quadruple.z; ppnorm= quadruple.w;
461  }
462  else{
463  double3 triplet = blas::tripleCGReduction(rSloppy, Ap, *p[j]);
464  r2 = triplet.x; Ap2 = triplet.y; pAp = triplet.z;
465  }
466  r2_old = r2;
467  alpha[j] = r2 / pAp;
468  sigma = alpha[j]*(alpha[j] * Ap2 - pAp);
469  if (sigma < 0.0 || steps_since_reliable == 0) { // sigma condition has broken down
470  r2 = blas::axpyNorm(-alpha[j], Ap, rSloppy);
471  sigma = r2;
472  breakdown = true;
473  }
474 
475  r2 = sigma;
476  } else {
477  r2_old = r2;
478 
479  // alternative reliable updates,
480  if (alternative_reliable) {
481  double3 pAppp = blas::cDotProductNormA(*p[j],Ap);
482  pAp = pAppp.x;
483  ppnorm = pAppp.z;
484  } else {
485  pAp = blas::reDotProduct(*p[j], Ap);
486  }
487 
488  alpha[j] = r2 / pAp;
489 
490  // here we are deploying the alternative beta computation
491  Complex cg_norm = blas::axpyCGNorm(-alpha[j], Ap, rSloppy);
492  r2 = real(cg_norm); // (r_new, r_new)
493  sigma = imag(cg_norm) >= 0.0 ? imag(cg_norm) : r2; // use r2 if (r_k+1, r_k+1-r_k) breaks
494  }
495 
496  // reliable update conditions
497  rNorm = sqrt(r2);
498  int updateX;
499  int updateR;
500 
501  if (alternative_reliable) {
502  // alternative reliable updates
503  updateX = ( (d <= deps*sqrt(r2_old)) or (dfac * dinit > deps * r0Norm) ) and (d_new > deps*rNorm) and (d_new > dfac * dinit);
504  updateR = 0;
505  } else {
506  if (rNorm > maxrx) maxrx = rNorm;
507  if (rNorm > maxrr) maxrr = rNorm;
508  updateX = (rNorm < delta * r0Norm && r0Norm <= maxrx) ? 1 : 0;
509  updateR = ((rNorm < delta * maxrr && r0Norm <= maxrr) || updateX) ? 1 : 0;
510  }
511 
512  // force a reliable update if we are within target tolerance (only if doing reliable updates)
513  if ( convergence(r2, heavy_quark_res, stop, param.tol_hq) && param.delta >= param.tol ) updateX = 1;
514 
515  // For heavy-quark inversion force a reliable update if we continue after
516  if ( use_heavy_quark_res and L2breakdown and convergenceHQ(r2, heavy_quark_res, stop, param.tol_hq) and param.delta >= param.tol ) {
517  updateX = 1;
518  }
519 
520  if ( !(updateR || updateX )) {
521  beta = sigma / r2_old; // use the alternative beta computation
522 
523 
524  if (param.pipeline && !breakdown) {
525 
526  if (Np == 1) {
527  blas::tripleCGUpdate(alpha[j], beta, Ap, xSloppy, rSloppy, *p[j]);
528  } else {
529  errorQuda("Not implemented pipelined CG with Np > 1");
530  }
531  } else {
532  if (Np == 1) {
533  // with Np=1 we just run regular fusion between x and p updates
534  blas::axpyZpbx(alpha[k%Np], *p[k%Np], xSloppy, rSloppy, beta);
535  } else {
536 
537  if ( (j+1)%Np == 0 ) {
538  const auto alpha_ = std::unique_ptr<Complex[]>(new Complex[Np]);
539  for (int i=0; i<Np; i++) alpha_[i] = alpha[i];
540  std::vector<ColorSpinorField*> x_;
541  x_.push_back(&xSloppy);
542  blas::caxpy(alpha_.get(), p, x_);
543  blas::flops -= 4*j*xSloppy.RealLength(); // correct for over flop count since using caxpy
544  }
545 
546  //p[(k+1)%Np] = r + beta * p[k%Np]
547  blas::xpayz(rSloppy, beta, *p[j], *p[(j+1)%Np]);
548  }
549  }
550 
551  if (use_heavy_quark_res && k%heavy_quark_check==0) {
552  if (&x != &xSloppy) {
553  blas::copy(tmp,y);
554  heavy_quark_res = sqrt(blas::xpyHeavyQuarkResidualNorm(xSloppy, tmp, rSloppy).z);
555  } else {
556  blas::copy(r, rSloppy);
557  heavy_quark_res = sqrt(blas::xpyHeavyQuarkResidualNorm(x, y, r).z);
558  }
559  }
560 
561  // alternative reliable updates
562  if (alternative_reliable) {
563  d = d_new;
564  pnorm = pnorm + alpha[j] * alpha[j]* (ppnorm);
565  xnorm = sqrt(pnorm);
566  d_new = d + u*rNorm + uhigh*Anorm * xnorm;
567  if (steps_since_reliable==0 && getVerbosity() >= QUDA_DEBUG_VERBOSE)
568  printfQuda("New dnew: %e (r %e , y %e)\n",d_new,u*rNorm,uhigh*Anorm * sqrt(blas::norm2(y)) );
569  }
570  steps_since_reliable++;
571 
572  } else {
573 
574  {
575  const auto alpha_ = std::unique_ptr<Complex[]>(new Complex[Np]);
576  for (int i=0; i<=j; i++) alpha_[i] = alpha[i];
577  std::vector<ColorSpinorField*> x_;
578  x_.push_back(&xSloppy);
579  std::vector<ColorSpinorField*> p_;
580  for (int i=0; i<=j; i++) p_.push_back(p[i]);
581  blas::caxpy(alpha_.get(), p_, x_);
582  blas::flops -= 4*j*xSloppy.RealLength(); // correct for over flop count since using caxpy
583  }
584 
585  blas::copy(x, xSloppy); // nop when these pointers alias
586 
587  blas::xpy(x, y); // swap these around?
588  mat(r, y, x, tmp3); // here we can use x as tmp
589  r2 = blas::xmyNorm(b, r);
590 
591  blas::copy(rSloppy, r); //nop when these pointers alias
592  blas::zero(xSloppy);
593 
594  // alternative reliable updates
595  if (alternative_reliable) {
596  dinit = uhigh*(sqrt(r2) + Anorm * sqrt(blas::norm2(y)));
597  d = d_new;
598  xnorm = 0;//sqrt(norm2(x));
599  pnorm = 0;//pnorm + alpha * sqrt(norm2(p));
600  if (getVerbosity() >= QUDA_DEBUG_VERBOSE) printfQuda("New dinit: %e (r %e , y %e)\n",dinit,uhigh*sqrt(r2),uhigh*Anorm*sqrt(blas::norm2(y)));
601  d_new = dinit;
602  } else {
603  rNorm = sqrt(r2);
604  maxrr = rNorm;
605  maxrx = rNorm;
606  }
607 
608  // calculate new reliable HQ resididual
609  if (use_heavy_quark_res) heavy_quark_res = sqrt(blas::HeavyQuarkResidualNorm(y, r).z);
610 
611  // break-out check if we have reached the limit of the precision
612  if (sqrt(r2) > r0Norm && updateX and not L2breakdown) { // reuse r0Norm for this
613  resIncrease++;
614  resIncreaseTotal++;
615  warningQuda(
616  "CG: new reliable residual norm %e is greater than previous reliable residual norm %e (total #inc %i)",
617  sqrt(r2), r0Norm, resIncreaseTotal);
618 
619  if ((use_heavy_quark_res and sqrt(r2) < L2breakdown_eps) or resIncrease > maxResIncrease
620  or resIncreaseTotal > maxResIncreaseTotal or r2 < stop) {
621  if (use_heavy_quark_res) {
622  L2breakdown = true;
623  warningQuda("CG: L2 breakdown %e, %e", sqrt(r2), L2breakdown_eps);
624  } else {
625  if (resIncrease > maxResIncrease or resIncreaseTotal > maxResIncreaseTotal or r2 < stop) {
626  warningQuda("CG: solver exiting due to too many true residual norm increases");
627  break;
628  }
629  }
630  }
631  } else {
632  resIncrease = 0;
633  }
634 
635  // if L2 broke down already we turn off reliable updates and restart the CG
636  if (use_heavy_quark_res and L2breakdown) {
637  hqresRestartTotal++; // count the number of heavy quark restarts we've done
638  delta = 0;
639  warningQuda("CG: Restarting without reliable updates for heavy-quark residual (total #inc %i)",
640  hqresRestartTotal);
641  heavy_quark_restart = true;
642 
643  if (heavy_quark_res > heavy_quark_res_old) { // check if new hq residual is greater than previous
644  hqresIncrease++; // count the number of consecutive increases
645  warningQuda("CG: new reliable HQ residual norm %e is greater than previous reliable residual norm %e",
646  heavy_quark_res, heavy_quark_res_old);
647  // break out if we do not improve here anymore
648  if (hqresIncrease > hqmaxresIncrease) {
649  warningQuda("CG: solver exiting due to too many heavy quark residual norm increases (%i/%i)",
650  hqresIncrease, hqmaxresIncrease);
651  break;
652  }
653  } else {
654  hqresIncrease = 0;
655  }
656 
657  if (hqresRestartTotal > hqmaxresRestartTotal) {
658  warningQuda("CG: solver exiting due to too many heavy quark residual restarts (%i/%i)", hqresRestartTotal,
659  hqmaxresRestartTotal);
660  break;
661  }
662  }
663 
664  if (use_heavy_quark_res and heavy_quark_restart) {
665  // perform a restart
666  blas::copy(*p[0], rSloppy);
667  heavy_quark_restart = false;
668  } else {
669  // explicitly restore the orthogonality of the gradient vector
670  Complex rp = blas::cDotProduct(rSloppy, *p[j]) / (r2);
671  blas::caxpy(-rp, rSloppy, *p[j]);
672 
673  beta = r2 / r2_old;
674  blas::xpayz(rSloppy, beta, *p[j], *p[0]);
675  }
676 
677  steps_since_reliable = 0;
678  r0Norm = sqrt(r2);
679  rUpdate++;
680 
681  heavy_quark_res_old = heavy_quark_res;
682  }
683 
684  breakdown = false;
685  k++;
686 
687  PrintStats("CG", k, r2, b2, heavy_quark_res);
688  // check convergence, if convergence is satisfied we only need to check that we had a reliable update for the heavy quarks recently
689  converged = convergence(r2, heavy_quark_res, stop, param.tol_hq);
690 
691  // check for recent enough reliable updates of the HQ residual if we use it
692  if (use_heavy_quark_res) {
693  // L2 is converged or precision maxed out for L2
694  bool L2done = L2breakdown or convergenceL2(r2, heavy_quark_res, stop, param.tol_hq);
695  // HQ is converged and if we do reliable update the HQ residual has been calculated using a reliable update
696  bool HQdone = (steps_since_reliable == 0 and param.delta > 0) and convergenceHQ(r2, heavy_quark_res, stop, param.tol_hq);
697  converged = L2done and HQdone;
698  }
699 
700  // if we have converged and need to update any trailing solutions
701  if (converged && steps_since_reliable > 0 && (j+1)%Np != 0 ) {
702  const auto alpha_ = std::unique_ptr<Complex[]>(new Complex[Np]);
703  for (int i=0; i<=j; i++) alpha_[i] = alpha[i];
704  std::vector<ColorSpinorField*> x_;
705  x_.push_back(&xSloppy);
706  std::vector<ColorSpinorField*> p_;
707  for (int i=0; i<=j; i++) p_.push_back(p[i]);
708  blas::caxpy(alpha_.get(), p_, x_);
709  blas::flops -= 4*j*xSloppy.RealLength(); // correct for over flop count since using caxpy
710  }
711 
712  j = steps_since_reliable == 0 ? 0 : (j+1)%Np; // if just done a reliable update then reset j
713  }
714 
715  blas::copy(x, xSloppy);
716  blas::xpy(y, x);
717 
720 
722  double gflops = (blas::flops + mat.flops() + matSloppy.flops())*1e-9;
723  param.gflops = gflops;
724  param.iter += k;
725 
726  if (k == param.maxiter)
727  warningQuda("Exceeded maximum iterations %d", param.maxiter);
728 
729  if (getVerbosity() >= QUDA_VERBOSE)
730  printfQuda("CG: Reliable updates = %d\n", rUpdate);
731 
732  if (param.compute_true_res) {
733  // compute the true residuals
734  mat(r, x, y, tmp3);
735  param.true_res = sqrt(blas::xmyNorm(b, r) / b2);
737  }
738 
739  PrintSummary("CG", k, r2, b2, stop, param.tol_hq);
740 
741  // reset the flops counters
742  blas::flops = 0;
743  mat.flops();
744  matSloppy.flops();
745 
747 
748  return;
749  }
750 
751 // use the BlockCGrQ algorithm, or plain BlockCG (with or without Gram-Schmidt; see the BLOCKCG_GS option)
752 #define BCGRQ 1
753 #if BCGRQ
755  #ifndef BLOCKSOLVER
756  errorQuda("QUDA_BLOCKSOLVER not built.");
757  #else
758 
760  errorQuda("Not supported");
761 
762  profile.TPSTART(QUDA_PROFILE_INIT);
763 
764  using Eigen::MatrixXcd;
765 
766  // Check to see that we're not trying to invert on a zero-field source
767  //MW: it might be useful to check what to do here.
768  double b2[QUDA_MAX_MULTI_SHIFT];
769  double b2avg=0;
770  for(int i=0; i< param.num_src; i++){
771  b2[i]=blas::norm2(b.Component(i));
772  b2avg += b2[i];
773  if(b2[i] == 0){
774  profile.TPSTOP(QUDA_PROFILE_INIT);
775  errorQuda("Warning: inverting on zero-field source - undefined for block solver\n");
776  x=b;
777  param.true_res = 0.0;
778  param.true_res_hq = 0.0;
779  return;
780  }
781  }
782 
783  b2avg = b2avg / param.num_src;
784 
786  if (!init) {
787  csParam.setPrecision(param.precision);
788  csParam.create = QUDA_ZERO_FIELD_CREATE;
789  rp = ColorSpinorField::Create(csParam);
790  yp = ColorSpinorField::Create(csParam);
791 
792  // sloppy fields
794  pp = ColorSpinorField::Create(csParam);
795  App = ColorSpinorField::Create(csParam);
799  } else {
800  rSloppyp = rp;
802  }
803 
804  // temporary fields
805  tmpp = ColorSpinorField::Create(csParam);
806  if(!mat.isStaggered()) {
807  // tmp2 only needed for multi-gpu Wilson-like kernels
808  tmp2p = ColorSpinorField::Create(csParam);
809  // additional high-precision temporary if Wilson and mixed-precision
810  csParam.setPrecision(param.precision);
812  ColorSpinorField::Create(csParam) : tmpp;
813  } else {
814  tmp3p = tmp2p = tmpp;
815  }
816 
817  init = true;
818  }
819 
820  if(!rnewp) {
821  csParam.create = QUDA_ZERO_FIELD_CREATE;
823  // ColorSpinorField *rpnew = ColorSpinorField::Create(csParam);
824  }
825 
826  ColorSpinorField &r = *rp;
827  ColorSpinorField &y = *yp;
828  ColorSpinorField &p = *pp;
829  ColorSpinorField &Ap = *App;
830  ColorSpinorField &rnew = *rnewp;
833  ColorSpinorField &tmp3 = *tmp3p;
834  ColorSpinorField &rSloppy = *rSloppyp;
836 
837  // calculate residuals for all vectors
838  // and initialize r2 matrix
839  double r2avg=0;
840  MatrixXcd r2(param.num_src, param.num_src);
841  for(int i=0; i<param.num_src; i++){
842  mat(r.Component(i), x.Component(i), y.Component(i));
843  r2(i,i) = blas::xmyNorm(b.Component(i), r.Component(i));
844  r2avg += r2(i,i).real();
845  printfQuda("r2[%i] %e\n", i, r2(i,i).real());
846  }
847  for(int i=0; i<param.num_src; i++){
848  for(int j=i+1; j < param.num_src; j++){
849  r2(i,j) = blas::cDotProduct(r.Component(i),r.Component(j));
850  r2(j,i) = std::conj(r2(i,j));
851  }
852  }
853 
854  blas::copy(rSloppy, r);
855  blas::copy(p, rSloppy);
856  blas::copy(rnew, rSloppy);
857 
858  if (&x != &xSloppy) {
859  blas::copy(y, x);
860  blas::zero(xSloppy);
861  } else {
862  blas::zero(y);
863  }
864 
865  const bool use_heavy_quark_res =
866  (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) ? true : false;
867  if(use_heavy_quark_res) errorQuda("ERROR: heavy quark residual not supported in block solver");
868 
869  profile.TPSTOP(QUDA_PROFILE_INIT);
871 
872  double stop[QUDA_MAX_MULTI_SHIFT];
873 
874  for(int i = 0; i < param.num_src; i++){
875  stop[i] = stopping(param.tol, b2[i], param.residual_type); // stopping condition of solver
876  }
877 
878  // Eigen Matrices instead of scalars
879  MatrixXcd alpha = MatrixXcd::Zero(param.num_src,param.num_src);
880  MatrixXcd beta = MatrixXcd::Zero(param.num_src,param.num_src);
881  MatrixXcd C = MatrixXcd::Zero(param.num_src,param.num_src);
882  MatrixXcd S = MatrixXcd::Identity(param.num_src,param.num_src);
883  MatrixXcd pAp = MatrixXcd::Identity(param.num_src,param.num_src);
885 
886  #ifdef MWVERBOSE
887  MatrixXcd pTp = MatrixXcd::Identity(param.num_src,param.num_src);
888  #endif
889 
890 
891 
892 
893  //FIXME:reliable updates currently not implemented
894  /*
895  double rNorm[QUDA_MAX_MULTI_SHIFT];
896  double r0Norm[QUDA_MAX_MULTI_SHIFT];
897  double maxrx[QUDA_MAX_MULTI_SHIFT];
898  double maxrr[QUDA_MAX_MULTI_SHIFT];
899 
900  for(int i = 0; i < param.num_src; i++){
901  rNorm[i] = sqrt(r2(i,i).real());
902  r0Norm[i] = rNorm[i];
903  maxrx[i] = rNorm[i];
904  maxrr[i] = rNorm[i];
905  }
906  bool L2breakdown = false;
907  int rUpdate = 0;
908  nt steps_since_reliable = 1;
909  */
910 
912  profile.TPSTART(QUDA_PROFILE_COMPUTE);
913  blas::flops = 0;
914 
915  int k = 0;
916 
917  PrintStats("CG", k, r2avg / param.num_src, b2avg, 0.);
918  bool allconverged = true;
919  bool converged[QUDA_MAX_MULTI_SHIFT];
920  for(int i=0; i<param.num_src; i++){
921  converged[i] = convergence(r2(i,i).real(), 0., stop[i], param.tol_hq);
922  allconverged = allconverged && converged[i];
923  }
924 
925  // CHolesky decomposition
926  MatrixXcd L = r2.llt().matrixL();
927  C = L.adjoint();
928  MatrixXcd Linv = C.inverse();
929 
930  #ifdef MWVERBOSE
931  std::cout << "r2\n " << r2 << std::endl;
932  std::cout << "L\n " << L.adjoint() << std::endl;
933  #endif
934 
935  // set p to QR decompsition of r
936  // temporary hack - use AC to pass matrix arguments to multiblas
937  for(int i=0; i<param.num_src; i++){
938  blas::zero(p.Component(i));
939  for(int j=0;j<param.num_src; j++){
940  AC[i*param.num_src + j] = Linv(i,j);
941  }
942  }
943  blas::caxpy(AC,r,p);
944 
945  // set rsloppy to to QR decompoistion of r (p)
946  for(int i=0; i< param.num_src; i++){
947  blas::copy(rSloppy.Component(i), p.Component(i));
948  }
949 
950  #ifdef MWVERBOSE
951  for(int i=0; i<param.num_src; i++){
952  for(int j=0; j<param.num_src; j++){
953  pTp(i,j) = blas::cDotProduct(p.Component(i), p.Component(j));
954  }
955  }
956  std::cout << " pTp " << std::endl << pTp << std::endl;
957  std::cout << " L " << std::endl << L.adjoint() << std::endl;
958  std::cout << " C " << std::endl << C << std::endl;
959  #endif
960 
961  while ( !allconverged && k < param.maxiter ) {
962  // apply matrix
963  for(int i=0; i<param.num_src; i++){
964  matSloppy(Ap.Component(i), p.Component(i), tmp.Component(i), tmp2.Component(i)); // tmp as tmp
965  }
966 
967  // calculate pAp
968  for(int i=0; i<param.num_src; i++){
969  for(int j=i; j < param.num_src; j++){
970  pAp(i,j) = blas::cDotProduct(p.Component(i), Ap.Component(j));
971  if (i!=j) pAp(j,i) = std::conj(pAp(i,j));
972  }
973  }
974 
975  // update Xsloppy
976  alpha = pAp.inverse() * C;
977  // temporary hack using AC
978  for(int i=0; i<param.num_src; i++){
979  for(int j=0;j<param.num_src; j++){
980  AC[i*param.num_src + j] = alpha(i,j);
981  }
982  }
983  blas::caxpy(AC,p,xSloppy);
984 
985  // update rSloppy
986  beta = pAp.inverse();
987  // temporary hack
988  for(int i=0; i<param.num_src; i++){
989  for(int j=0;j<param.num_src; j++){
990  AC[i*param.num_src + j] = -beta(i,j);
991  }
992  }
993  blas::caxpy(AC,Ap,rSloppy);
994 
995  // orthorgonalize R
996  // copy rSloppy to rnew as temporary
997  for(int i=0; i< param.num_src; i++){
998  blas::copy(rnew.Component(i), rSloppy.Component(i));
999  }
1000  for(int i=0; i<param.num_src; i++){
1001  for(int j=i; j < param.num_src; j++){
1002  r2(i,j) = blas::cDotProduct(r.Component(i),r.Component(j));
1003  if (i!=j) r2(j,i) = std::conj(r2(i,j));
1004  }
1005  }
1006  // Cholesky decomposition
1007  L = r2.llt().matrixL();// retrieve factor L in the decomposition
1008  S = L.adjoint();
1009  Linv = S.inverse();
1010  // temporary hack
1011  for(int i=0; i<param.num_src; i++){
1012  blas::zero(rSloppy.Component(i));
1013  for(int j=0;j<param.num_src; j++){
1014  AC[i*param.num_src + j] = Linv(i,j);
1015  }
1016  }
1017  blas::caxpy(AC,rnew,rSloppy);
1018 
1019  #ifdef MWVERBOSE
1020  for(int i=0; i<param.num_src; i++){
1021  for(int j=0; j<param.num_src; j++){
1022  pTp(i,j) = blas::cDotProduct(rSloppy.Component(i), rSloppy.Component(j));
1023  }
1024  }
1025  std::cout << " rTr " << std::endl << pTp << std::endl;
1026  std::cout << "QR" << S<< std::endl << "QP " << S.inverse()*S << std::endl;;
1027  #endif
1028 
1029  // update p
1030  // use rnew as temporary again for summing up
1031  for(int i=0; i<param.num_src; i++){
1032  blas::copy(rnew.Component(i),rSloppy.Component(i));
1033  }
1034  // temporary hack
1035  for(int i=0; i<param.num_src; i++){
1036  for(int j=0;j<param.num_src; j++){
1037  AC[i*param.num_src + j] = std::conj(S(j,i));
1038  }
1039  }
1040  blas::caxpy(AC,p,rnew);
1041  // set p = rnew
1042  for(int i=0; i < param.num_src; i++){
1043  blas::copy(p.Component(i),rnew.Component(i));
1044  }
1045 
1046  // update C
1047  C = S * C;
1048 
1049  #ifdef MWVERBOSE
1050  for(int i=0; i<param.num_src; i++){
1051  for(int j=0; j<param.num_src; j++){
1052  pTp(i,j) = blas::cDotProduct(p.Component(i), p.Component(j));
1053  }
1054  }
1055  std::cout << " pTp " << std::endl << pTp << std::endl;
1056  std::cout << "S " << S<< std::endl << "C " << C << std::endl;
1057  #endif
1058 
1059  // calculate the residuals for all shifts
1060  r2avg=0;
1061  for (int j=0; j<param.num_src; j++ ){
1062  r2(j,j) = C(0,j)*conj(C(0,j));
1063  for(int i=1; i < param.num_src; i++)
1064  r2(j,j) += C(i,j) * conj(C(i,j));
1065  r2avg += r2(j,j).real();
1066  }
1067 
1068  k++;
1069  PrintStats("CG", k, r2avg / param.num_src, b2avg, 0);
1070  // check convergence
1071  allconverged = true;
1072  for(int i=0; i<param.num_src; i++){
1073  converged[i] = convergence(r2(i,i).real(), 0, stop[i], param.tol_hq);
1074  allconverged = allconverged && converged[i];
1075  }
1076 
1077 
1078  }
1079 
1080  for(int i=0; i<param.num_src; i++){
1081  blas::xpy(y.Component(i), xSloppy.Component(i));
1082  }
1083 
1084  profile.TPSTOP(QUDA_PROFILE_COMPUTE);
1085  profile.TPSTART(QUDA_PROFILE_EPILOGUE);
1086 
1088  double gflops = (blas::flops + mat.flops() + matSloppy.flops())*1e-9;
1089  param.gflops = gflops;
1090  param.iter += k;
1091 
1092  if (k == param.maxiter)
1093  warningQuda("Exceeded maximum iterations %d", param.maxiter);
1094 
1095  // if (getVerbosity() >= QUDA_VERBOSE)
1096  // printfQuda("CG: Reliable updates = %d\n", rUpdate);
1097 
1098  // compute the true residuals
1099  for(int i=0; i<param.num_src; i++){
1100  mat(r.Component(i), x.Component(i), y.Component(i), tmp3.Component(i));
1101  param.true_res = sqrt(blas::xmyNorm(b.Component(i), r.Component(i)) / b2[i]);
1105 
1106  PrintSummary("CG", k, r2(i,i).real(), b2[i], stop[i], 0.0);
1107  }
1108 
1109  // reset the flops counters
1110  blas::flops = 0;
1111  mat.flops();
1112  matSloppy.flops();
1113 
1115  profile.TPSTART(QUDA_PROFILE_FREE);
1116 
1117  delete[] AC;
1118  profile.TPSTOP(QUDA_PROFILE_FREE);
1119 
1120  return;
1121 
1122  #endif
1123 }
1124 
1125 #else
1126 
// ===========================================================================
// Block conjugate gradient, CG::solve(x, b) -- BLOCKSOLVER builds only.
//
// Solves param.num_src systems at once: right-hand side i is b.Component(i)
// and its solution (the initial guess on entry) is x.Component(i). The
// num_src x num_src coefficient matrices (alpha, beta, gamma, pAp, r2, ...)
// are Eigen::MatrixXcd objects.
//
// One iteration on the path actually taken in this version:
//   Ap_i        = matSloppy(p_i)
//   pAp(i,j)    = <p_i, Ap_j>
//   alpha       = pAp^-1 * (gamma^H)^-1 * r2
//   rSloppy_i  -= sum_j alpha(j,i) * Ap_j
//   r2(i,j)     = <r_i, r_j>,  sigma = r2
//   beta        = gamma * r2_old^-1 * sigma
//   xSloppy_i  += sum_j alpha(j,i) * p_j
//   p_i         = r_i + sum_j beta(j,i) * p_j
//   with BLOCKCG_GS: Gram-Schmidt-orthonormalize p and set the upper
//   triangular gamma(i,j) = <p_i, pnew_j> (j >= i), pnew being p before
//   orthonormalization.
// Iteration stops when convergence() holds for every source or k reaches
// param.maxiter. The returned solution is x = xSloppy + y.
//
// NOTE(review): this text comes from a rendered source listing in which some
// lines are absent (numbering gaps at 1139, 1178, 1186, 1189-1191, 1194,
// 1204, 1215, 1224-1225, 1228, 1324, 1671, 1686-1688, 1698). csParam, tmp,
// tmp2 and xSloppy are used below without a visible declaration and were
// presumably bound on those lines; restore them from upstream before building.
// ===========================================================================
1127 // Define BLOCKCG_GS to orthonormalize the block search directions with Gram-Schmidt (printed as the BCGdQ variant).
1128 #define BLOCKCG_GS 1
1129 void CG::solve(ColorSpinorField& x, ColorSpinorField& b) {
1130  #ifndef BLOCKSOLVER
1131  errorQuda("QUDA_BLOCKSOLVER not built.");
1132  #else
1133  #ifdef BLOCKCG_GS
1134  printfQuda("BCGdQ Solver\n");
1135  #else
1136  printfQuda("BCQ Solver\n");
1137  #endif
      // Always true in this version: every "use_block or i==j" test below takes
      // the full-matrix branch.
1138  const bool use_block = true;
      // NOTE(review): listing line 1139 is missing. As shown, the errorQuda call
      // below is unconditional; if errorQuda aborts (presumably), nothing after
      // it runs. A guarding condition was most likely lost -- restore it.
1140  errorQuda("Not supported");
1141
1142  profile.TPSTART(QUDA_PROFILE_INIT);
1143
1144  using Eigen::MatrixXcd;
      // NOTE(review): mPAP and mRR are never used in this function.
1145  MatrixXcd mPAP(param.num_src,param.num_src);
1146  MatrixXcd mRR(param.num_src,param.num_src);
1147
1148
1149  // Check to see that we're not trying to invert on a zero-field source
1150  //MW: it might be useful to check what to do here.
      // b2[i] = |b_i|^2; b2avg accumulates the sum over sources (not an average).
1151  double b2[QUDA_MAX_MULTI_SHIFT];
1152  double b2avg=0;
1153  double r2avg=0;
1154  for(int i=0; i< param.num_src; i++){
1155  b2[i]=blas::norm2(b.Component(i));
1156  b2avg += b2[i];
1157  if(b2[i] == 0){
1158  profile.TPSTOP(QUDA_PROFILE_INIT);
1159  errorQuda("Warning: inverting on zero-field source\n");
      // NOTE(review): the message says "Warning" but errorQuda is used; if
      // errorQuda aborts, the fallback below (x=b, zero residuals, return) is
      // unreachable.
1160  x=b;
1161  param.true_res = 0.0;
1162  param.true_res_hq = 0.0;
1163  return;
1164  }
1165  }
1166
1167  #ifdef MWVERBOSE
1168  MatrixXcd b2m(param.num_src,param.num_src);
1169  // just to check details of b
1170  for(int i=0; i<param.num_src; i++){
1171  for(int j=0; j<param.num_src; j++){
1172  b2m(i,j) = blas::cDotProduct(b.Component(i), b.Component(j));
1173  }
1174  }
1175  std::cout << "b2m\n" << b2m << std::endl;
1176  #endif
1177
      // Allocate the persistent work fields on first use (init is set to true
      // at the end of this block). csParam is not declared in the visible lines.
1179  if (!init) {
1180  csParam.setPrecision(param.precision);
1181  csParam.create = QUDA_ZERO_FIELD_CREATE;
1182  rp = ColorSpinorField::Create(csParam);
1183  yp = ColorSpinorField::Create(csParam);
1184
1185  // sloppy fields
1187  pp = ColorSpinorField::Create(csParam);
1188  App = ColorSpinorField::Create(csParam);
      // NOTE(review): the "if" that opens the branch closed below is on a
      // missing listing line. Only its else-branch is visible, and there
      // rSloppy aliases r.
1192  } else {
1193  rSloppyp = rp;
1195  }
1196
1197  // temporary fields
1198  tmpp = ColorSpinorField::Create(csParam);
1199  if(!mat.isStaggered()) {
1200  // tmp2 only needed for multi-gpu Wilson-like kernels
1201  tmp2p = ColorSpinorField::Create(csParam);
1202  // additional high-precision temporary if Wilson and mixed-precision
1203  csParam.setPrecision(param.precision);
      // NOTE(review): the start of this conditional expression (listing line
      // 1204, presumably "tmp3p = ... ?") is missing.
1205  ColorSpinorField::Create(csParam) : tmpp;
1206  } else {
1207  tmp3p = tmp2p = tmpp;
1208  }
1209
1210  init = true;
1211  }
1212
      // NOTE(review): the statement that creates rnewp (listing line 1215) is
      // missing, yet pnew dereferences *rnewp below.
1213  if(!rnewp) {
1214  csParam.create = QUDA_ZERO_FIELD_CREATE;
1216  // ColorSpinorField *rpnew = ColorSpinorField::Create(csParam);
1217  }
1218
1219  ColorSpinorField &r = *rp;
1220  ColorSpinorField &y = *yp;
1221  ColorSpinorField &p = *pp;
1222  ColorSpinorField &pnew = *rnewp;
1223  ColorSpinorField &Ap = *App;
1226  ColorSpinorField &tmp3 = *tmp3p;
1227  ColorSpinorField &rSloppy = *rSloppyp;
1229
1230  // const int i = 0; // MW: hack to be able to write Component(i) instead and try with i=0 for now
1231
      // r_i = mat(x_i) for the initial guess (y is passed as the third argument,
      // presumably as scratch).
1232  for(int i=0; i<param.num_src; i++){
1233  mat(r.Component(i), x.Component(i), y.Component(i));
1234  }
1235
1236  // double r2[QUDA_MAX_MULTI_SHIFT];
      // Diagonal of r2: r_i <- b_i - r_i and r2(i,i) = |r_i|^2 (assuming the
      // usual xmyNorm semantics "y = x - y, return norm2(y)" -- TODO confirm).
1237  MatrixXcd r2(param.num_src,param.num_src);
1238  for(int i=0; i<param.num_src; i++){
1239  r2(i,i) = blas::xmyNorm(b.Component(i), r.Component(i));
1240  printfQuda("r2[%i] %e\n", i, r2(i,i).real());
1241  }
1242  if(use_block){
1243  // MW need to initialize the full r2 matrix here
      // Off-diagonal Gram entries r2(i,j) = <r_i, r_j>, mirrored so r2 is Hermitian.
1244  for(int i=0; i<param.num_src; i++){
1245  for(int j=i+1; j<param.num_src; j++){
1246  r2(i,j) = blas::cDotProduct(r.Component(i), r.Component(j));
1247  r2(j,i) = std::conj(r2(i,j));
1248  }
1249  }
1250  }
1251
      // Start with rSloppy = p = pnew = r.
1252  blas::copy(rSloppy, r);
1253  blas::copy(p, rSloppy);
1254  blas::copy(pnew, rSloppy);
1255
      // y holds the solution accumulated so far; xSloppy holds the correction
      // being built by the iteration. The result is x = xSloppy + y.
1256  if (&x != &xSloppy) {
1257  blas::copy(y, x);
1258  blas::zero(xSloppy);
1259  } else {
1260  blas::zero(y);
1261  }
1262
1263  const bool use_heavy_quark_res =
1264  (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) ? true : false;
1265  bool heavy_quark_restart = false;
1266
1267  profile.TPSTOP(QUDA_PROFILE_INIT);
1268  profile.TPSTART(QUDA_PROFILE_PREAMBLE);
1269
1270  MatrixXcd r2_old(param.num_src, param.num_src);
1271  double heavy_quark_res[QUDA_MAX_MULTI_SHIFT] = {0.0}; // heavy quark residual
1272  double heavy_quark_res_old[QUDA_MAX_MULTI_SHIFT] = {0.0}; // heavy quark residual
1273  double stop[QUDA_MAX_MULTI_SHIFT];
1274
1275  for(int i = 0; i < param.num_src; i++){
1276  stop[i] = stopping(param.tol, b2[i], param.residual_type); // stopping condition of solver
1277  if (use_heavy_quark_res) {
1278  heavy_quark_res[i] = sqrt(blas::HeavyQuarkResidualNorm(x.Component(i), r.Component(i)).z);
1279  heavy_quark_res_old[i] = heavy_quark_res[i]; // heavy quark residual
1280  }
1281  }
1282  const int heavy_quark_check = param.heavy_quark_check; // how often to check the heavy quark residual
1283
      // Block coefficient matrices. gamma stays the identity unless BLOCKCG_GS
      // rebuilds it from the Gram-Schmidt step.
1284  MatrixXcd alpha = MatrixXcd::Zero(param.num_src,param.num_src);
1285  MatrixXcd beta = MatrixXcd::Zero(param.num_src,param.num_src);
1286  MatrixXcd gamma = MatrixXcd::Identity(param.num_src,param.num_src);
1287  // gamma = gamma * 2.0;
1288
1289  MatrixXcd pAp(param.num_src, param.num_src);
1290  MatrixXcd pTp(param.num_src, param.num_src);
1291  int rUpdate = 0;
1292
      // Per-source bookkeeping for the reliable-update criteria.
1293  double rNorm[QUDA_MAX_MULTI_SHIFT];
1294  double r0Norm[QUDA_MAX_MULTI_SHIFT];
1295  double maxrx[QUDA_MAX_MULTI_SHIFT];
1296  double maxrr[QUDA_MAX_MULTI_SHIFT];
1297
1298  for(int i = 0; i < param.num_src; i++){
1299  rNorm[i] = sqrt(r2(i,i).real());
1300  r0Norm[i] = rNorm[i];
1301  maxrx[i] = rNorm[i];
1302  maxrr[i] = rNorm[i];
1303  }
1304
1305  double delta = param.delta;//MW: hack no reliable updates param.delta;
1306
1307  // this parameter determines how many consecutive reliable update
1308  // residual increases we tolerate before terminating the solver,
1309  // i.e., how long do we want to keep trying to converge
1310  const int maxResIncrease = (use_heavy_quark_res ? 0 : param.max_res_increase); // check if we reached the limit of our tolerance
1311  const int maxResIncreaseTotal = param.max_res_increase_total;
1312  // 0 means we have no tolerance
1313  // maybe we should expose this as a parameter
1314  const int hqmaxresIncrease = maxResIncrease + 1;
1315
1316  int resIncrease = 0;
1317  int resIncreaseTotal = 0;
1318  int hqresIncrease = 0;
1319
1320  // set this to true if maxResIncrease has been exceeded but when we use heavy quark residual we still want to continue the CG
1321  // only used if we use the heavy_quark_res
1322  bool L2breakdown = false;
1323
1325  profile.TPSTART(QUDA_PROFILE_COMPUTE);
1326  blas::flops = 0;
1327
1328  int k = 0;
1329
      // NOTE(review): r2avg and b2avg are sums over sources (not averages), and
      // only source 0's heavy-quark residual is reported.
1330  for(int i=0; i<param.num_src; i++){
1331  r2avg+=r2(i,i).real();
1332  }
1333  PrintStats("CG", k, r2avg, b2avg, heavy_quark_res[0]);
1334  int steps_since_reliable = 1;
1335  bool allconverged = true;
1336  bool converged[QUDA_MAX_MULTI_SHIFT];
1337  for(int i=0; i<param.num_src; i++){
1338  converged[i] = convergence(r2(i,i).real(), heavy_quark_res[i], stop[i], param.tol_hq);
1339  allconverged = allconverged && converged[i];
1340  }
1341  MatrixXcd sigma(param.num_src,param.num_src);
1342
1343  #ifdef BLOCKCG_GS
1344  // Orthonormalize the initial directions p with Gram-Schmidt: normalize p_i, then remove its component from every later p_j
      // NOTE(review): there is no guard against norm2(p_i) == 0, which would
      // scale by 1/0.
1345
1346  for(int i=0; i < param.num_src; i++){
1347  double n = blas::norm2(p.Component(i));
1348  blas::ax(1/sqrt(n),p.Component(i));
1349  for(int j=i+1; j < param.num_src; j++) {
1350  std::complex<double> ri=blas::cDotProduct(p.Component(i),p.Component(j));
1351  blas::caxpy(-ri,p.Component(i),p.Component(j));
1352  }
1353  }
1354
      // gamma(i,j) = <p_i, pnew_j> for j >= i (upper triangular). pnew still holds
      // the directions from before orthonormalization.
1355  gamma = MatrixXcd::Zero(param.num_src,param.num_src);
1356  for ( int i = 0; i < param.num_src; i++){
1357  for (int j=i; j < param.num_src; j++){
1358  gamma(i,j) = blas::cDotProduct(p.Component(i),pnew.Component(j));
1359  }
1360  }
1361  #endif
1362  // end of initial Gram-Schmidt block
1363
1364  #ifdef MWVERBOSE
1365  for(int i=0; i<param.num_src; i++){
1366  for(int j=0; j<param.num_src; j++){
1367  pTp(i,j) = blas::cDotProduct(p.Component(i), p.Component(j));
1368  }
1369  }
1370
1371  std::cout << " pTp " << std::endl << pTp << std::endl;
1372  std::cout << "QR" << gamma<< std::endl << "QP " << gamma.inverse()*gamma << std::endl;;
1373  #endif
      // Main iteration: run until every source has converged or maxiter is reached.
1374  while ( !allconverged && k < param.maxiter ) {
      // Ap_i = matSloppy(p_i)
1375  for(int i=0; i<param.num_src; i++){
1376  matSloppy(Ap.Component(i), p.Component(i), tmp.Component(i), tmp2.Component(i)); // tmp as tmp
1377  }
1378
1379
1380  bool breakdown = false;
1381  // FIXME: need to check breakdown
1382  // current implementation sets breakdown to true for pipelined CG if one rhs triggers breakdown
1383  // this is probably ok
1384
1385
1386  if (param.pipeline) {
1387  errorQuda("pipeline not implemented");
1388  } else {
1389  r2_old = r2;
      // pAp(i,j) = <p_i, Ap_j>; use_block is always true, so the full matrix is formed.
1390  for(int i=0; i<param.num_src; i++){
1391  for(int j=0; j < param.num_src; j++){
1392  if(use_block or i==j)
1393  pAp(i,j) = blas::cDotProduct(p.Component(i), Ap.Component(j));
1394  else
1395  pAp(i,j) = 0.;
1396  }
1397  }
1398
      // alpha = pAp^-1 * (gamma^H)^-1 * r2, using explicit Eigen inverses of
      // num_src x num_src matrices.
1399  alpha = pAp.inverse() * gamma.adjoint().inverse() * r2;
1400  #ifdef MWVERBOSE
1401  std::cout << "alpha\n" << alpha << std::endl;
1402
1403  if(k==1){
1404  std::cout << "pAp " << std::endl <<pAp << std::endl;
1405  std::cout << "pAp^-1 " << std::endl <<pAp.inverse() << std::endl;
1406  std::cout << "r2 " << std::endl <<r2 << std::endl;
1407  std::cout << "alpha " << std::endl <<alpha << std::endl;
1408  std::cout << "pAp^-1r2" << std::endl << pAp.inverse()*r2 << std::endl;
1409  }
1410  #endif
1411  // here we are deploying the alternative beta computation
      // rSloppy_i -= sum_j alpha(j,i) * Ap_j
1412  for(int i=0; i<param.num_src; i++){
1413  for(int j=0; j < param.num_src; j++){
1414
1415  blas::caxpy(-alpha(j,i), Ap.Component(j), rSloppy.Component(i));
1416  }
1417  }
1418  // MW need to calculate the full r2 matrix here, after update. Not sure how to do alternative sigma yet ...
      // NOTE(review): rSloppy was just updated, but r2 is formed from r. The two
      // are the same field only when rSloppyp = rp (the visible else-branch of
      // the init block); otherwise r is stale here.
1419  for(int i=0; i<param.num_src; i++){
1420  for(int j=0; j<param.num_src; j++){
1421  if(use_block or i==j)
1422  r2(i,j) = blas::cDotProduct(r.Component(i), r.Component(j));
1423  else
1424  r2(i,j) = 0.;
1425  }
1426  }
1427  sigma = r2;
1428  }
1429
1430
1431  bool updateX=false;
1432  bool updateR=false;
1433  // int updateX = (rNorm < delta*r0Norm && r0Norm <= maxrx) ? true : false;
1434  // int updateR = ((rNorm < delta*maxrr && r0Norm <= maxrr) || updateX) ? true : false;
1435  //
1436  // printfQuda("Checking reliable update %i %i\n",updateX,updateR);
1437  // reliable update conditions
      // NOTE(review): updateX/updateR are overwritten for each source, so only
      // the last source decides. In practice any request is cleared right below,
      // so reliable updates never happen and the reliable-update else-branch
      // further down is dead code in this version.
1438  for(int i=0; i<param.num_src; i++){
1439  rNorm[i] = sqrt(r2(i,i).real());
1440  if (rNorm[i] > maxrx[i]) maxrx[i] = rNorm[i];
1441  if (rNorm[i] > maxrr[i]) maxrr[i] = rNorm[i];
1442  updateX = (rNorm[i] < delta * r0Norm[i] && r0Norm[i] <= maxrx[i]) ? true : false;
1443  updateR = ((rNorm[i] < delta * maxrr[i] && r0Norm[i] <= maxrr[i]) || updateX) ? true : false;
1444  }
1445  if ( (updateR || updateX )) {
1446  // printfQuda("Suppressing reliable update %i %i\n",updateX,updateR);
1447  updateX=false;
1448  updateR=false;
1449  // printfQuda("Suppressing reliable update %i %i\n",updateX,updateR);
1450  }
1451
1452  if ( !(updateR || updateX )) {
1453
      // beta = gamma * r2_old^-1 * sigma
1454  beta = gamma * r2_old.inverse() * sigma;
1455  #ifdef MWVERBOSE
1456  std::cout << "beta\n" << beta << std::endl;
1457  #endif
1458  if (param.pipeline && !breakdown)
1459  errorQuda("pipeline not implemented");
1460
1461  else{
      // xSloppy_i += sum_j alpha(j,i) * p_j (uses p before it is replaced below).
1462  for(int i=0; i<param.num_src; i++){
1463  for(int j=0; j<param.num_src; j++){
1464  blas::caxpy(alpha(j,i),p.Component(j),xSloppy.Component(i));
1465  }
1466  }
1467
1468  // set to zero
1469  for(int i=0; i < param.num_src; i++){
1470  blas::ax(0,pnew.Component(i)); // do we need components here?
1471  }
1472  // add r
      // NOTE(review): uses r rather than rSloppy; see the note on the r2 update above.
1473  for(int i=0; i<param.num_src; i++){
1474  // for(int j=0;j<param.num_src; j++){
1475  // order of updating p might be relevant here
1476  blas::axpy(1.0,r.Component(i),pnew.Component(i));
1477  // blas::axpby(rcoeff,rSloppy.Component(i),beta(i,j),p.Component(j));
1478  // }
1479  }
1480  // beta = beta * gamma.inverse();
      // pnew_i += sum_j beta(j,i) * p_j
1481  for(int i=0; i<param.num_src; i++){
1482  for(int j=0;j<param.num_src; j++){
      // NOTE(review): rcoeff is unused (only the commented-out axpby referenced it).
1483  double rcoeff= (j==0?1.0:0.0);
1484  // order of updating p might be relevant here
1485  blas::caxpy(beta(j,i),p.Component(j),pnew.Component(i));
1486  // blas::axpby(rcoeff,rSloppy.Component(i),beta(i,j),p.Component(j));
1487  }
1488  }
1489  // now need to do something with the p's
1490
      // p <- pnew
1491  for(int i=0; i< param.num_src; i++){
1492  blas::copy(p.Component(i), pnew.Component(i));
1493  }
1494
1495
      // Re-orthonormalize p and rebuild gamma the same way as before the loop.
1496  #ifdef BLOCKCG_GS
1497  for(int i=0; i < param.num_src; i++){
1498  double n = blas::norm2(p.Component(i));
1499  blas::ax(1/sqrt(n),p.Component(i));
1500  for(int j=i+1; j < param.num_src; j++) {
1501  std::complex<double> ri=blas::cDotProduct(p.Component(i),p.Component(j));
1502  blas::caxpy(-ri,p.Component(i),p.Component(j));
1503
1504  }
1505  }
1506
1507
1508  gamma = MatrixXcd::Zero(param.num_src,param.num_src);
1509  for ( int i = 0; i < param.num_src; i++){
1510  for (int j=i; j < param.num_src; j++){
1511  gamma(i,j) = blas::cDotProduct(p.Component(i),pnew.Component(j));
1512  }
1513  }
1514  #endif
1515
1516  #ifdef MWVERBOSE
1517  for(int i=0; i<param.num_src; i++){
1518  for(int j=0; j<param.num_src; j++){
1519  pTp(i,j) = blas::cDotProduct(p.Component(i), p.Component(j));
1520  }
1521  }
1522  std::cout << " pTp " << std::endl << pTp << std::endl;
1523  std::cout << "QR" << gamma<< std::endl << "QP " << gamma.inverse()*gamma << std::endl;;
1524  #endif
1525  }
1526
1527
      // NOTE(review): k % heavy_quark_check assumes param.heavy_quark_check > 0
      // whenever the heavy-quark residual is in use.
1528  if (use_heavy_quark_res && (k % heavy_quark_check) == 0) {
1529  if (&x != &xSloppy) {
1530  blas::copy(tmp, y); // FIXME: check whether copy works here
1531  for(int i=0; i<param.num_src; i++){
1532  heavy_quark_res[i] = sqrt(blas::xpyHeavyQuarkResidualNorm(xSloppy.Component(i), tmp.Component(i), rSloppy.Component(i)).z);
1533  }
1534  } else {
1535  blas::copy(r, rSloppy); // FIXME: check whether copy works here
1536  for(int i=0; i<param.num_src; i++){
1537  heavy_quark_res[i] = sqrt(blas::xpyHeavyQuarkResidualNorm(x.Component(i), y.Component(i), r.Component(i)).z);
1538  }
1539  }
1540  }
1541
1542  steps_since_reliable++;
1543  } else {
      // Reliable update: unreachable in this version (requests are cleared above).
1544  printfQuda("reliable update\n");
      // NOTE(review): only the diagonal alpha(i,i).real() is applied here, unlike
      // the full block update on the regular path.
1545  for(int i=0; i<param.num_src; i++){
1546  blas::axpy(alpha(i,i).real(), p.Component(i), xSloppy.Component(i));
1547  }
1548  blas::copy(x, xSloppy); // nop when these pointers alias
1549
1550  for(int i=0; i<param.num_src; i++){
1551  blas::xpy(x.Component(i), y.Component(i)); // swap these around?
1552  }
1553  for(int i=0; i<param.num_src; i++){
1554  mat(r.Component(i), y.Component(i), x.Component(i), tmp3.Component(i)); // here we can use x as tmp
1555  }
1556  for(int i=0; i<param.num_src; i++){
1557  r2(i,i) = blas::xmyNorm(b.Component(i), r.Component(i));
1558  }
1559
1560  for(int i=0; i<param.num_src; i++){
1561  blas::copy(rSloppy.Component(i), r.Component(i)); //nop when these pointers alias
1562  blas::zero(xSloppy.Component(i));
1563  }
1564
1565  // calculate new reliable HQ residual
1566  if (use_heavy_quark_res){
1567  for(int i=0; i<param.num_src; i++){
1568  heavy_quark_res[i] = sqrt(blas::HeavyQuarkResidualNorm(y.Component(i), r.Component(i)).z);
1569  }
1570  }
1571
1572  // MW: FIXME as this probably goes terribly wrong right now
1573  for(int i = 0; i<param.num_src; i++){
1574  // break-out check if we have reached the limit of the precision
1575  if (sqrt(r2(i,i).real()) > r0Norm[i] && updateX) { // reuse r0Norm for this
1576  resIncrease++;
1577  resIncreaseTotal++;
1578  warningQuda("CG: new reliable residual norm %e is greater than previous reliable residual norm %e (total #inc %i)",
1579  sqrt(r2(i,i).real()), r0Norm[i], resIncreaseTotal);
1580  if ( resIncrease > maxResIncrease or resIncreaseTotal > maxResIncreaseTotal) {
1581  if (use_heavy_quark_res) {
1582  L2breakdown = true;
1583  } else {
1584  warningQuda("CG: solver exiting due to too many true residual norm increases");
      // NOTE(review): this break only leaves the for-loop over sources; the
      // solver's while-loop keeps running.
1585  break;
1586  }
1587  }
1588  } else {
1589  resIncrease = 0;
1590  }
1591  }
1592  // if L2 broke down already we turn off reliable updates and restart the CG
1593  for(int i = 0; i<param.num_src; i++){
1594  if (use_heavy_quark_res and L2breakdown) {
1595  delta = 0;
1596  warningQuda("CG: Restarting without reliable updates for heavy-quark residual");
1597  heavy_quark_restart = true;
1598  if (heavy_quark_res[i] > heavy_quark_res_old[i]) {
1599  hqresIncrease++;
1600  warningQuda("CG: new reliable HQ residual norm %e is greater than previous reliable residual norm %e", heavy_quark_res[i], heavy_quark_res_old[i]);
1601  // break out if we do not improve here anymore
1602  if (hqresIncrease > hqmaxresIncrease) {
1603  warningQuda("CG: solver exiting due to too many heavy quark residual norm increases");
      // NOTE(review): as above, this break does not exit the solver loop.
1604  break;
1605  }
1606  }
1607  }
1608  }
1609
      // Reset the reliable-update reference norms to the freshly computed residuals.
1610  for(int i=0; i<param.num_src; i++){
1611  rNorm[i] = sqrt(r2(i,i).real());
1612  maxrr[i] = rNorm[i];
1613  maxrx[i] = rNorm[i];
1614  r0Norm[i] = rNorm[i];
1615  heavy_quark_res_old[i] = heavy_quark_res[i];
1616  }
1617  rUpdate++;
1618
1619  if (use_heavy_quark_res and heavy_quark_restart) {
1620  // perform a restart
1621  blas::copy(p, rSloppy);
1622  heavy_quark_restart = false;
1623  } else {
1624  // explicitly restore the orthogonality of the gradient vector
1625  for(int i=0; i<param.num_src; i++){
1626  double rp = blas::reDotProduct(rSloppy.Component(i), p.Component(i)) / (r2(i,i).real());
1627  blas::axpy(-rp, rSloppy.Component(i), p.Component(i));
1628
1629  beta(i,i) = r2(i,i) / r2_old(i,i);
1630  blas::xpay(rSloppy.Component(i), beta(i,i).real(), p.Component(i));
1631  }
1632  }
1633
1634  steps_since_reliable = 0;
1635  }
1636
1637  breakdown = false;
1638  k++;
1639
      // Per-source convergence test; the loop continues while any source is unconverged.
1640  allconverged = true;
1641  r2avg=0;
1642  for(int i=0; i<param.num_src; i++){
1643  r2avg+= r2(i,i).real();
1644  // check convergence, if convergence is satisfied we only need to check that we had a reliable update for the heavy quarks recently
1645  converged[i] = convergence(r2(i,i).real(), heavy_quark_res[i], stop[i], param.tol_hq);
1646  allconverged = allconverged && converged[i];
1647  }
1648  PrintStats("CG", k, r2avg, b2avg, heavy_quark_res[0]);
1649
1650  // check for recent enough reliable updates of the HQ residual if we use it
1651  if (use_heavy_quark_res) {
1652  for(int i=0; i<param.num_src; i++){
1653  // L2 is converged or precision maxed out for L2
1654  bool L2done = L2breakdown or convergenceL2(r2(i,i).real(), heavy_quark_res[i], stop[i], param.tol_hq);
1655  // HQ is converged and if we do reliable update the HQ residual has been calculated using a reliable update
1656  bool HQdone = (steps_since_reliable == 0 and param.delta > 0) and convergenceHQ(r2(i,i).real(), heavy_quark_res[i], stop[i], param.tol_hq);
1657  converged[i] = L2done and HQdone;
      // NOTE(review): allconverged is not recomputed after this, so the
      // heavy-quark refinement does not affect the loop condition.
1658  }
1659  }
1660
1661  }
1662
      // Reassemble the solution: x = xSloppy + y.
1663  blas::copy(x, xSloppy);
1664  for(int i=0; i<param.num_src; i++){
1665  blas::xpy(y.Component(i), x.Component(i));
1666  }
1667
1668  profile.TPSTOP(QUDA_PROFILE_COMPUTE);
1669  profile.TPSTART(QUDA_PROFILE_EPILOGUE);
1670
1672  double gflops = (blas::flops + mat.flops() + matSloppy.flops())*1e-9;
1673  param.gflops = gflops;
1674  param.iter += k;
1675
1676  if (k == param.maxiter)
1677  warningQuda("Exceeded maximum iterations %d", param.maxiter);
1678
1679  if (getVerbosity() >= QUDA_VERBOSE)
1680  printfQuda("CG: Reliable updates = %d\n", rUpdate);
1681
1682  // compute the true residuals
      // NOTE(review): listing lines 1686-1688 are missing. As shown,
      // param.true_res is overwritten for each source, so only the last value is
      // kept.
1683  for(int i=0; i<param.num_src; i++){
1684  mat(r.Component(i), x.Component(i), y.Component(i), tmp3.Component(i));
1685  param.true_res = sqrt(blas::xmyNorm(b.Component(i), r.Component(i)) / b2[i]);
1689
1690  PrintSummary("CG", k, r2(i,i).real(), b2[i], stop[i], 0.0);
1691  }
1692
1693  // reset the flops counters
1694  blas::flops = 0;
1695  mat.flops();
1696  matSloppy.flops();
1697
      // NOTE(review): listing line 1698 is missing; as shown, nothing is released
      // between the QUDA_PROFILE_FREE start and stop.
1699  profile.TPSTART(QUDA_PROFILE_FREE);
1700
1701  profile.TPSTOP(QUDA_PROFILE_FREE);
1702
1703  return;
1704
1705  #endif
1706
1707 }
1708 #endif
1709 
1710 
1711 } // namespace quda
cudaColorSpinorField * tmp2
void blocksolve(ColorSpinorField &out, ColorSpinorField &in)
void ax(double a, ColorSpinorField &x)
Definition: blas_quda.cu:508
void setPrecision(QudaPrecision precision, QudaPrecision ghost_precision=QUDA_INVALID_PRECISION, bool force_native=false)
ColorSpinorField * tmp3p
Definition: invert_quda.h:576
double3 cDotProductNormA(ColorSpinorField &a, ColorSpinorField &b)
Definition: reduce_quda.cu:778
bool convergenceHQ(double r2, double hq2, double r2_tol, double hq_tol)
Test for HQ solver convergence – ignore L2 residual.
Definition: solver.cpp:237
virtual ~CGNR()
void axpyZpbx(double a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, double b)
Definition: blas_quda.cu:552
DiracMMdag mmdag
Definition: invert_quda.h:645
DiracMdagM mdagm
Definition: invert_quda.h:661
#define QUDA_MAX_MULTI_SHIFT
Maximum number of shifts supported by the multi-shift solver. This number may be changed if need be...
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:21
Matrix< N, std::complex< T > > conj(const Matrix< N, std::complex< T > > &mat)
#define checkPrecision(...)
#define errorQuda(...)
Definition: util_quda.h:121
double norm2(const ColorSpinorField &a)
Definition: reduce_quda.cu:721
__host__ __device__ ValueType sqrt(ValueType x)
Definition: complex_quda.h:120
ColorSpinorField * rnewp
Definition: invert_quda.h:576
virtual ~CGNE()
Definition: inv_cg_quda.cpp:65
Complex cDotProduct(ColorSpinorField &, ColorSpinorField &)
Definition: reduce_quda.cu:764
double epsilon
Definition: test_util.cpp:1649
const DiracMatrix & matSloppy
Definition: invert_quda.h:574
void deflate(std::vector< ColorSpinorField *> vec_defl, std::vector< ColorSpinorField *> vec, std::vector< ColorSpinorField *> evecs, std::vector< Complex > evals)
Deflate vector with Eigenvectors.
void PrintStats(const char *name, int k, double r2, double b2, double hq2)
Prints out the running statistics of the solver (requires a verbosity of QUDA_VERBOSE) ...
Definition: solver.cpp:256
cudaColorSpinorField * tmp
Definition: covdev_test.cpp:44
ColorSpinorField * yp
Definition: invert_quda.h:648
double3 xpyHeavyQuarkResidualNorm(ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &r)
Definition: reduce_quda.cu:818
static ColorSpinorField * Create(const ColorSpinorParam &param)
bool convergence(double r2, double hq2, double r2_tol, double hq_tol)
Definition: solver.cpp:223
double reDotProduct(ColorSpinorField &x, ColorSpinorField &y)
Definition: reduce_quda.cu:728
Complex axpyCGNorm(double a, ColorSpinorField &x, ColorSpinorField &y)
Definition: reduce_quda.cu:796
TimeProfile & profile
Definition: invert_quda.h:464
const DiracMatrix & mat
Definition: invert_quda.h:573
void copy(ColorSpinorField &dst, const ColorSpinorField &src)
Definition: copy_quda.cu:355
ColorSpinorField & Component(const int idx) const
double xmyNorm(ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.h:75
QudaPreserveSource preserve_source
Definition: invert_quda.h:154
int max_res_increase_total
Definition: invert_quda.h:96
std::vector< ColorSpinorField * > defl_tmp1
Definition: invert_quda.h:547
ColorSpinorField * xp
Definition: invert_quda.h:647
ColorSpinorField * bp
Definition: invert_quda.h:663
ColorSpinorField * pp
Definition: invert_quda.h:576
bool alternative_reliable
Definition: test_util.cpp:1659
This is just a dummy structure we use for trove to define the required structure size.
void xpay(ColorSpinorField &x, double a, ColorSpinorField &y)
Definition: blas_quda.h:37
CGNE(DiracMatrix &mat, DiracMatrix &matSloppy, SolverParam &param, TimeProfile &profile)
Definition: inv_cg_quda.cpp:60
QudaGaugeParam param
Definition: pack_test.cpp:17
bool init
Definition: invert_quda.h:578
bool convergenceL2(double r2, double hq2, double r2_tol, double hq_tol)
Test for L2 solver convergence – ignore HQ residual.
Definition: solver.cpp:246
std::vector< ColorSpinorField * > defl_tmp2
Definition: invert_quda.h:548
QudaComputeNullVector compute_null_vector
Definition: invert_quda.h:67
double Last(QudaProfileType idx)
Definition: timer.h:251
QudaResidualType residual_type
Definition: invert_quda.h:49
CG(DiracMatrix &mat, DiracMatrix &matSloppy, SolverParam &param, TimeProfile &profile)
Definition: inv_cg_quda.cpp:22
double true_res_hq_offset[QUDA_MAX_MULTI_SHIFT]
Definition: invert_quda.h:187
int max_hq_res_restart_total
Definition: invert_quda.h:106
static double stopping(double tol, double b2, QudaResidualType residual_type)
Set the solver L2 stopping condition.
Definition: solver.cpp:206
ColorSpinorParam csParam
Definition: pack_test.cpp:24
void axpy(double a, ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.h:35
double4 quadrupleCGReduction(ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
Definition: reduce_quda.cu:833
#define warningQuda(...)
Definition: util_quda.h:133
#define checkLocation(...)
__host__ __device__ ValueType pow(ValueType x, ExponentType e)
Definition: complex_quda.h:111
ColorSpinorField * tmpp
Definition: invert_quda.h:576
void constructDeflationSpace(const ColorSpinorField &meta, const DiracMatrix &mat, bool svd)
Constructs the deflation space.
Definition: solver.cpp:159
double3 HeavyQuarkResidualNorm(ColorSpinorField &x, ColorSpinorField &r)
Definition: reduce_quda.cu:809
double true_res_offset[QUDA_MAX_MULTI_SHIFT]
Definition: invert_quda.h:181
CGNR(DiracMatrix &mat, DiracMatrix &matSloppy, SolverParam &param, TimeProfile &profile)
std::complex< double > Complex
Definition: quda_internal.h:46
EigenSolver * eig_solve
Definition: invert_quda.h:545
void tripleCGUpdate(double alpha, double beta, ColorSpinorField &q, ColorSpinorField &r, ColorSpinorField &x, ColorSpinorField &p)
Definition: blas_quda.cu:614
std::vector< Complex > evals
Definition: invert_quda.h:61
void init()
Create the CUBLAS context.
Definition: blas_cublas.cu:31
void caxpy(const Complex &a, ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.cu:512
ColorSpinorField * rp
Definition: invert_quda.h:576
void zero(ColorSpinorField &a)
Definition: blas_quda.cu:472
std::vector< ColorSpinorField * > p
Definition: invert_quda.h:577
void Mdag(ColorSpinorField &out, const ColorSpinorField &in) const
Definition: dirac.cpp:90
std::vector< ColorSpinorField * > evecs
Definition: invert_quda.h:58
QudaPrecision precision
Definition: invert_quda.h:142
virtual void M(ColorSpinorField &out, const ColorSpinorField &in) const =0
ColorSpinorField * rSloppyp
Definition: invert_quda.h:576
SolverParam & param
Definition: invert_quda.h:463
void operator()(ColorSpinorField &out, ColorSpinorField &in)
Run CG.
ColorSpinorField * yp
Definition: invert_quda.h:576
Conjugate-Gradient Solver.
Definition: invert_quda.h:570
unsigned long long flops() const
Definition: dirac_quda.h:1119
ColorSpinorField * xSloppyp
Definition: invert_quda.h:576
#define printfQuda(...)
Definition: util_quda.h:115
void operator()(ColorSpinorField &out, ColorSpinorField &in)
Run CG.
Definition: invert_quda.h:588
unsigned long long flops
Definition: blas_quda.cu:22
void xpy(ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.h:33
void axpby(double a, ColorSpinorField &x, double b, ColorSpinorField &y)
Definition: blas_quda.h:36
double axpyNorm(double a, ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.h:74
bool use_alternative_reliable
Definition: invert_quda.h:73
ColorSpinorField * tmp2p
Definition: invert_quda.h:576
ColorSpinorField * App
Definition: invert_quda.h:576
QudaUseInitGuess use_init_guess
Definition: invert_quda.h:64
void operator()(ColorSpinorField &out, ColorSpinorField &in)
Run CG.
Definition: inv_cg_quda.cpp:74
QudaPrecision precision_sloppy
Definition: invert_quda.h:145
bool use_sloppy_partial_accumulator
Definition: invert_quda.h:76
void PrintSummary(const char *name, int k, double r2, double b2, double r2_tol, double hq_tol)
Prints out the summary of the solver convergence (requires a verbosity of QUDA_SUMMARIZE). Assumes SolverParam.true_res and SolverParam.true_res_hq has been set.
Definition: solver.cpp:270
virtual ~CG()
Definition: inv_cg_quda.cpp:29
__host__ __device__ ValueType conj(ValueType x)
Definition: complex_quda.h:130
int solution_accumulator_pipeline
Definition: invert_quda.h:86
void mat(void *out, void **link, void *in, int dagger_bit, int mu, QudaPrecision sPrecision, QudaPrecision gPrecision)
QudaPrecision Precision() const
bool deflate_init
Definition: invert_quda.h:546
void xpayz(ColorSpinorField &x, double a, ColorSpinorField &y, ColorSpinorField &z)
Definition: blas_quda.h:38
const Dirac * Expose() const
Definition: dirac_quda.h:1135
bool isStaggered() const
Definition: dirac_quda.h:1128
double3 tripleCGReduction(ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
Definition: reduce_quda.cu:828
void updateR()
update the radius for halos.