QUDA  0.9.0
inv_cg_quda.cpp
Go to the documentation of this file.
1 #include <cstdio>
2 #include <cstdlib>
3 #include <cmath>
4 #include <limits>
5 #include <memory>
6 #include <iostream>
7 
8 #ifdef BLOCKSOLVER
9 #include <Eigen/Dense>
10 #endif
11 
12 #include <quda_internal.h>
13 #include <color_spinor_field.h>
14 #include <blas_quda.h>
15 #include <dslash_quda.h>
16 #include <invert_quda.h>
17 #include <util_quda.h>
18 
19 namespace quda {
20 
22  Solver(param, profile), mat(mat), matSloppy(matSloppy), init(false) {
23  }
24 
25  CG::~CG() {
26  if ( init ) {
27  for (auto pi : p) delete pi;
28  delete rp;
29  delete yp;
30  delete App;
31  delete tmpp;
32  init = false;
33  }
34  }
35 
37  CG(mmdag, mmdagSloppy, param, profile), mmdag(mat.Expose()), mmdagSloppy(mat.Expose()), init(false) {
38  }
39 
41  if ( init ) {
42  delete xp;
43  init = false;
44  }
45  }
46 
47  // CGNE: M Mdag y = b is solved; x = Mdag y is returned as solution.
49 
50  if (!init) {
54 
55  init = true;
56 
58  warningQuda("Initial guess may not work as expected with CGNE\n");
59  *xp = x;
60  }
61 
63 
64  mmdag.Expose()->Mdag(x,*xp);
65  }
66 
68  CG(mdagm, mdagmSloppy, param, profile), mdagm(mat.Expose()), mdagmSloppy(mat.Expose()), init(false) {
69  }
70 
72  if ( init ) {
73  delete bp;
74  init = false;
75  }
76  }
77 
78  // CGNR: Mdag M x = Mdag b is solved.
80  const int iter0 = param.iter;
81 
82  if (!init) {
86 
87  init = true;
88 
89  }
90 
91  mdagm.Expose()->Mdag(*bp,b);
93 
94  if (param.compute_true_res) {
95  // compute the true residuals
96  const double b2 = blas::norm2(b);
97  mdagm.Expose()->M(*bp, x);
98  const double r2 = blas::xmyNorm(b, *bp) / b2;
99  param.true_res = sqrt(r2);
101 
102  PrintSummary("CGNR", param.iter - iter0, r2, b2);
103  }
104  }
105 
108  errorQuda("Not supported");
109 
111  if (Np < 0 || Np > 16) errorQuda("Invalid value %d for solution_accumulator_pipeline\n", Np);
112 
113 #ifdef ALTRELIABLE
114  // hack to select alternative reliable updates
115  constexpr bool alternative_reliable = true;
116  warningQuda("Using alternative reliable updates. This feature is mostly ok but needs a little more testing in the real world.\n");
117 #else
118  constexpr bool alternative_reliable = false;
119 #endif
120  profile.TPSTART(QUDA_PROFILE_INIT);
121 
122  // Norm of the source: needed for the zero-source check and the stopping condition below
123  double b2 = blas::norm2(b);
124 
125  // Check to see that we're not trying to invert on a zero-field source
127  profile.TPSTOP(QUDA_PROFILE_INIT);
128  printfQuda("Warning: inverting on zero-field source\n");
129  x = b;
130  param.true_res = 0.0;
131  param.true_res_hq = 0.0;
132  return;
133  }
134 
136  if (!init) {
141  // sloppy fields
142  csParam.setPrecision(param.precision_sloppy);
145 
146  init = true;
147 
148  }
149  ColorSpinorField &r = *rp;
150  ColorSpinorField &y = *yp;
151  ColorSpinorField &Ap = *App;
153 
154  csParam.setPrecision(param.precision_sloppy);
156 
157  // tmp2 only needed for multi-gpu Wilson-like kernels
159  ColorSpinorField &tmp2 = *tmp2_p;
160 
161  // additional high-precision temporary if Wilson and mixed-precision
162  csParam.setPrecision(param.precision);
163  ColorSpinorField *tmp3_p = (x.Precision() != param.precision_sloppy && !mat.isStaggered()) ?
165  ColorSpinorField &tmp3 = *tmp3_p;
166 
167  // alternative reliable updates
168  // alternative reliable updates - set precision - does not hurt performance here
169 
170  const double u= param.precision_sloppy == 8 ? std::numeric_limits<double>::epsilon()/2. : ((param.precision_sloppy == 4) ? std::numeric_limits<float>::epsilon()/2. : pow(2.,-13));
171  const double uhigh= param.precision == 8 ? std::numeric_limits<double>::epsilon()/2. : ((param.precision == 4) ? std::numeric_limits<float>::epsilon()/2. : pow(2.,-13));
172  const double deps=sqrt(u);
173  constexpr double dfac = 1.1;
174  double d_new =0 ;
175  double d =0 ;
176  double dinit =0;
177  double xNorm = 0;
178  double xnorm = 0;
179  double pnorm = 0;
180  double ppnorm = 0;
181  double Anorm = 0;
182 
183  // for alternative reliable updates
184  if(alternative_reliable){
185  // estimate norm for reliable updates
186  mat(r, b, y, tmp3);
187  Anorm = sqrt(blas::norm2(r)/b2);
188  }
189 
190 
191  // compute initial residual
192  mat(r, x, y, tmp3);
193  double r2 = blas::xmyNorm(b, r);
194  if (b2 == 0) {
195  b2 = r2;
196  }
197 
198  csParam.setPrecision(param.precision_sloppy);
199  ColorSpinorField *r_sloppy;
200  if (param.precision_sloppy == x.Precision()) {
201  r_sloppy = &r;
202  } else {
204  r_sloppy = ColorSpinorField::Create(r, csParam);
205  }
206 
207  ColorSpinorField *x_sloppy;
208  if (param.precision_sloppy == x.Precision() ||
210  x_sloppy = &x;
211  } else {
213  x_sloppy = ColorSpinorField::Create(x, csParam);
214  }
215 
216  ColorSpinorField &xSloppy = *x_sloppy;
217  ColorSpinorField &rSloppy = *r_sloppy;
218 
220  csParam.setPrecision(param.precision_sloppy);
221 
222  if (Np != (int)p.size()) {
223  for (auto &pi : p) delete pi;
224  p.resize(Np);
225  for (auto &pi : p) pi = ColorSpinorField::Create(rSloppy, csParam);
226  } else {
227  for (auto &pi : p) *pi = rSloppy;
228  }
229 
230  if (&x != &xSloppy) {
231  blas::copy(y, x);
232  blas::zero(xSloppy);
233  } else {
234  blas::zero(y);
235  }
236 
237  const bool use_heavy_quark_res =
238  (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) ? true : false;
239  bool heavy_quark_restart = false;
240 
241  profile.TPSTOP(QUDA_PROFILE_INIT);
243 
244  double r2_old;
245 
246  double stop = stopping(param.tol, b2, param.residual_type); // stopping condition of solver
247 
248  double heavy_quark_res = 0.0; // heavy quark res idual
249  double heavy_quark_res_old = 0.0; // heavy quark residual
250 
251  if (use_heavy_quark_res) {
252  heavy_quark_res = sqrt(blas::HeavyQuarkResidualNorm(x, r).z);
253  heavy_quark_res_old = heavy_quark_res; // heavy quark residual
254  }
255  const int heavy_quark_check = param.heavy_quark_check; // how often to check the heavy quark residual
256 
257  double alpha[Np];
258  double beta = 0.0;
259  double pAp;
260  int rUpdate = 0;
261 
262  double rNorm = sqrt(r2);
263  double r0Norm = rNorm;
264  double maxrx = rNorm;
265  double maxrr = rNorm;
266  double delta = param.delta;
267 
268 
269  // this parameter determines how many consecutive reliable update
270  // residual increases we tolerate before terminating the solver,
271  // i.e., how long do we want to keep trying to converge
272  const int maxResIncrease = (use_heavy_quark_res ? 0 : param.max_res_increase); // check if we reached the limit of our tolerance
273  const int maxResIncreaseTotal = param.max_res_increase_total;
274  // 0 means we have no tolerance
275  // maybe we should expose this as a parameter
276  const int hqmaxresIncrease = maxResIncrease + 1;
277 
278  int resIncrease = 0;
279  int resIncreaseTotal = 0;
280  int hqresIncrease = 0;
281 
282  // set this to true if maxResIncrease has been exceeded but when we use heavy quark residual we still want to continue the CG
283  // only used if we use the heavy_quark_res
284  bool L2breakdown = false;
285 
287  profile.TPSTART(QUDA_PROFILE_COMPUTE);
288  blas::flops = 0;
289 
290  int k = 0;
291  int j = 0;
292 
293  PrintStats("CG", k, r2, b2, heavy_quark_res);
294 
295  int steps_since_reliable = 1;
296  bool converged = convergence(r2, heavy_quark_res, stop, param.tol_hq);
297 
298  // alternative reliable updates
299  if(alternative_reliable){
300  dinit = uhigh * (rNorm + Anorm * xNorm);
301  d = dinit;
302  }
303 
304  while ( !converged && k < param.maxiter ) {
305  matSloppy(Ap, *p[j], tmp, tmp2); // tmp as tmp
306  double sigma;
307 
308  bool breakdown = false;
309  if (param.pipeline) {
310  double Ap2;
311  //TODO: alternative reliable updates - need r2, Ap2, pAp, p norm
312  if(alternative_reliable){
313  double4 quadruple = blas::quadrupleCGReduction(rSloppy, Ap, *p[j]);
314  r2 = quadruple.x; Ap2 = quadruple.y; pAp = quadruple.z; ppnorm= quadruple.w;
315  }
316  else{
317  double3 triplet = blas::tripleCGReduction(rSloppy, Ap, *p[j]);
318  r2 = triplet.x; Ap2 = triplet.y; pAp = triplet.z;
319  }
320  r2_old = r2;
321  alpha[j] = r2 / pAp;
322  sigma = alpha[j]*(alpha[j] * Ap2 - pAp);
323  if (sigma < 0.0 || steps_since_reliable == 0) { // sigma condition has broken down
324  r2 = blas::axpyNorm(-alpha[j], Ap, rSloppy);
325  sigma = r2;
326  breakdown = true;
327  }
328 
329  r2 = sigma;
330  } else {
331  r2_old = r2;
332 
333  // alternative reliable updates,
334  if(alternative_reliable){
335  double3 pAppp = blas::cDotProductNormA(*p[j],Ap);
336  pAp = pAppp.x;
337  ppnorm = pAppp.z;
338  }
339  else{
340  pAp = blas::reDotProduct(*p[j], Ap);
341  }
342 
343  alpha[j] = r2 / pAp;
344 
345  // here we are deploying the alternative beta computation
346  Complex cg_norm = blas::axpyCGNorm(-alpha[j], Ap, rSloppy);
347  r2 = real(cg_norm); // (r_new, r_new)
348  sigma = imag(cg_norm) >= 0.0 ? imag(cg_norm) : r2; // use r2 if (r_k+1, r_k+1-r_k) breaks
349  }
350 
351  // reliable update conditions
352  rNorm = sqrt(r2);
353  int updateX;
354  int updateR;
355 
356  if(alternative_reliable){
357  // alternative reliable updates
358  updateX = ( (d <= deps*sqrt(r2_old)) or (dfac * dinit > deps * r0Norm) ) and (d_new > deps*rNorm) and (d_new > dfac * dinit);
359  updateR = 0;
360  // if(updateX)
361  // printfQuda("new reliable update conditions (%i) d_n-1 < eps r2_old %e %e;\t dn > eps r_n %e %e;\t (dnew > 1.1 dinit %e %e)\n",
362  // updateX,d,deps*sqrt(r2_old),d_new,deps*rNorm,d_new,dinit);
363  }
364  else{
365  if (rNorm > maxrx) maxrx = rNorm;
366  if (rNorm > maxrr) maxrr = rNorm;
367  updateX = (rNorm < delta*r0Norm && r0Norm <= maxrx) ? 1 : 0;
368  updateR = ((rNorm < delta*maxrr && r0Norm <= maxrr) || updateX) ? 1 : 0;
369  }
370 
371  // force a reliable update if we are within target tolerance (only if doing reliable updates)
372  if ( convergence(r2, heavy_quark_res, stop, param.tol_hq) && param.delta >= param.tol ) updateX = 1;
373 
374  // For heavy-quark inversion force a reliable update if we continue after an L2 breakdown
375  if ( use_heavy_quark_res and L2breakdown and convergenceHQ(r2, heavy_quark_res, stop, param.tol_hq) and param.delta >= param.tol ) {
376  updateX = 1;
377  }
378 
379  if ( !(updateR || updateX )) {
380  beta = sigma / r2_old; // use the alternative beta computation
381 
382 
383  if (param.pipeline && !breakdown) {
384 
385  if (Np == 1) {
386  blas::tripleCGUpdate(alpha[j], beta, Ap, xSloppy, rSloppy, *p[j]);
387  } else {
388  errorQuda("Not implemented pipelined CG with Np > 1");
389  }
390  } else {
391  if (Np == 1) {
392  // with Np=1 we just run regular fusion between x and p updates
393  blas::axpyZpbx(alpha[k%Np], *p[k%Np], xSloppy, rSloppy, beta);
394  } else {
395 
396  if ( (j+1)%Np == 0 ) {
397  const auto alpha_ = std::unique_ptr<Complex[]>(new Complex[Np]);
398  for (int i=0; i<Np; i++) alpha_[i] = alpha[i];
399  std::vector<ColorSpinorField*> x_;
400  x_.push_back(&xSloppy);
401  blas::caxpy(alpha_.get(), p, x_);
402  blas::flops -= 4*j*xSloppy.RealLength(); // correct for over flop count since using caxpy
403  }
404 
405  // p[(j+1)%Np] = r + beta * p[j]
406  blas::xpayz(rSloppy, beta, *p[j], *p[(j+1)%Np]);
407  }
408  }
409 
410  if (use_heavy_quark_res && k%heavy_quark_check==0) {
411  if (&x != &xSloppy) {
412  blas::copy(tmp,y);
413  heavy_quark_res = sqrt(blas::xpyHeavyQuarkResidualNorm(xSloppy, tmp, rSloppy).z);
414  } else {
415  blas::copy(r, rSloppy);
416  heavy_quark_res = sqrt(blas::xpyHeavyQuarkResidualNorm(x, y, r).z);
417  }
418  }
419 
420  // alternative reliable updates
421  if(alternative_reliable){
422  d = d_new;
423  pnorm = pnorm + alpha[j] * alpha[j]* (ppnorm);
424  xnorm = sqrt(pnorm);
425  d_new = d + u*rNorm + uhigh*Anorm * xnorm;
426  if(steps_since_reliable==0)
427  printfQuda("New dnew: %e (r %e , y %e)\n",d_new,u*rNorm,uhigh*Anorm * sqrt(blas::norm2(y)) );
428  }
429  steps_since_reliable++;
430 
431  } else {
432 
433  {
434  const auto alpha_ = std::unique_ptr<Complex[]>(new Complex[Np]);
435  for (int i=0; i<=j; i++) alpha_[i] = alpha[i];
436  std::vector<ColorSpinorField*> x_;
437  x_.push_back(&xSloppy);
438  std::vector<ColorSpinorField*> p_;
439  for (int i=0; i<=j; i++) p_.push_back(p[i]);
440  blas::caxpy(alpha_.get(), p_, x_);
441  blas::flops -= 4*j*xSloppy.RealLength(); // correct for over flop count since using caxpy
442  }
443 
444  blas::copy(x, xSloppy); // nop when these pointers alias
445 
446  blas::xpy(x, y); // swap these around?
447  mat(r, y, x, tmp3); // here we can use x as tmp
448  r2 = blas::xmyNorm(b, r);
449 
450  blas::copy(rSloppy, r); //nop when these pointers alias
451  blas::zero(xSloppy);
452 
453  // alternative reliable updates
454  if(alternative_reliable){
455  dinit = uhigh*(sqrt(r2) + Anorm * sqrt(blas::norm2(y)));
456  d = d_new;
457  xnorm = 0;//sqrt(norm2(x));
458  pnorm = 0;//pnorm + alpha * sqrt(norm2(p));
459  printfQuda("New dinit: %e (r %e , y %e)\n",dinit,uhigh*sqrt(r2),uhigh*Anorm*sqrt(blas::norm2(y)));
460  d_new = dinit;
461  }
462  else{
463  rNorm = sqrt(r2);
464  maxrr = rNorm;
465  maxrx = rNorm;
466  }
467 
468 
469  // calculate new reliable HQ residual
470  if (use_heavy_quark_res) heavy_quark_res = sqrt(blas::HeavyQuarkResidualNorm(y, r).z);
471 
472  // break-out check if we have reached the limit of the precision
473  if (sqrt(r2) > r0Norm && updateX) { // reuse r0Norm for this
474  resIncrease++;
475  resIncreaseTotal++;
476  warningQuda("CG: new reliable residual norm %e is greater than previous reliable residual norm %e (total #inc %i)",
477  sqrt(r2), r0Norm, resIncreaseTotal);
478  if ( resIncrease > maxResIncrease or resIncreaseTotal > maxResIncreaseTotal) {
479  if (use_heavy_quark_res) {
480  L2breakdown = true;
481  } else {
482  warningQuda("CG: solver exiting due to too many true residual norm increases");
483  break;
484  }
485  }
486  } else {
487  resIncrease = 0;
488  }
489  // if L2 broke down already we turn off reliable updates and restart the CG
490  if (use_heavy_quark_res and L2breakdown) {
491  delta = 0;
492  warningQuda("CG: Restarting without reliable updates for heavy-quark residual");
493  heavy_quark_restart = true;
494  if (heavy_quark_res > heavy_quark_res_old) {
495  hqresIncrease++;
496  warningQuda("CG: new reliable HQ residual norm %e is greater than previous reliable residual norm %e", heavy_quark_res, heavy_quark_res_old);
497  // break out if we do not improve here anymore
498  if (hqresIncrease > hqmaxresIncrease) {
499  warningQuda("CG: solver exiting due to too many heavy quark residual norm increases");
500  break;
501  }
502  }
503  }
504 
505  if (use_heavy_quark_res and heavy_quark_restart) {
506  // perform a restart
507  blas::copy(*p[0], rSloppy);
508  heavy_quark_restart = false;
509  } else {
510  // explicitly restore the orthogonality of the gradient vector
511  Complex rp = blas::cDotProduct(rSloppy, *p[j]) / (r2);
512  blas::caxpy(-rp, rSloppy, *p[j]);
513 
514  beta = r2 / r2_old;
515  blas::xpayz(rSloppy, beta, *p[j], *p[0]);
516  }
517 
518  steps_since_reliable = 0;
519  r0Norm = sqrt(r2);
520  rUpdate++;
521 
522  heavy_quark_res_old = heavy_quark_res;
523  }
524 
525  breakdown = false;
526  k++;
527 
528  PrintStats("CG", k, r2, b2, heavy_quark_res);
529  // check convergence, if convergence is satisfied we only need to check that we had a reliable update for the heavy quarks recently
530  converged = convergence(r2, heavy_quark_res, stop, param.tol_hq);
531 
532  // check for recent enough reliable updates of the HQ residual if we use it
533  if (use_heavy_quark_res) {
534  // L2 is converged or precision maxed out for L2
535  bool L2done = L2breakdown or convergenceL2(r2, heavy_quark_res, stop, param.tol_hq);
536  // HQ is converged and if we do reliable update the HQ residual has been calculated using a reliable update
537  bool HQdone = (steps_since_reliable == 0 and param.delta > 0) and convergenceHQ(r2, heavy_quark_res, stop, param.tol_hq);
538  converged = L2done and HQdone;
539  }
540 
541  // if we have converged and need to update any trailing solutions
542  if (converged && steps_since_reliable > 0 && (j+1)%Np != 0 ) {
543  const auto alpha_ = std::unique_ptr<Complex[]>(new Complex[Np]);
544  for (int i=0; i<=j; i++) alpha_[i] = alpha[i];
545  std::vector<ColorSpinorField*> x_;
546  x_.push_back(&xSloppy);
547  std::vector<ColorSpinorField*> p_;
548  for (int i=0; i<=j; i++) p_.push_back(p[i]);
549  blas::caxpy(alpha_.get(), p_, x_);
550  blas::flops -= 4*j*xSloppy.RealLength(); // correct for over flop count since using caxpy
551  }
552 
553  j = steps_since_reliable == 0 ? 0 : (j+1)%Np; // if just done a reliable update then reset j
554  }
555 
556  blas::copy(x, xSloppy);
557  blas::xpy(y, x);
558 
561 
563  double gflops = (blas::flops + mat.flops() + matSloppy.flops())*1e-9;
564  param.gflops = gflops;
565  param.iter += k;
566 
567  if (k == param.maxiter)
568  warningQuda("Exceeded maximum iterations %d", param.maxiter);
569 
570  if (getVerbosity() >= QUDA_VERBOSE)
571  printfQuda("CG: Reliable updates = %d\n", rUpdate);
572 
573  if (param.compute_true_res) {
574  // compute the true residuals
575  mat(r, x, y, tmp3);
576  param.true_res = sqrt(blas::xmyNorm(b, r) / b2);
578  }
579 
580  PrintSummary("CG", k, r2, b2);
581 
582  // reset the flops counters
583  blas::flops = 0;
584  mat.flops();
585  matSloppy.flops();
586 
588  profile.TPSTART(QUDA_PROFILE_FREE);
589 
590  if (&tmp3 != &tmp) delete tmp3_p;
591  if (&tmp2 != &tmp) delete tmp2_p;
592 
593  if (&rSloppy != &r) delete r_sloppy;
594  if (&xSloppy != &x) delete x_sloppy;
595 
596  profile.TPSTOP(QUDA_PROFILE_FREE);
597 
598  return;
599  }
600 
601 
602 // use BlockCGrQ algorithm or BlockCG (with / without GS, see BLOCKCG_GS option)
603 #define BCGRQ 1
604 #if BCGRQ
606  #ifndef BLOCKSOLVER
607  errorQuda("QUDA_BLOCKSOLVER not built.");
608  #else
609 
611  errorQuda("Not supported");
612 
613  profile.TPSTART(QUDA_PROFILE_INIT);
614 
615  using Eigen::MatrixXcd;
616 
617  // Check to see that we're not trying to invert on a zero-field source
618  //MW: it might be useful to check what to do here.
619  double b2[QUDA_MAX_MULTI_SHIFT];
620  double b2avg=0;
621  for(int i=0; i< param.num_src; i++){
622  b2[i]=blas::norm2(b.Component(i));
623  b2avg += b2[i];
624  if(b2[i] == 0){
625  profile.TPSTOP(QUDA_PROFILE_INIT);
626  errorQuda("Warning: inverting on zero-field source - undefined for block solver\n");
627  x=b;
628  param.true_res = 0.0;
629  param.true_res_hq = 0.0;
630  return;
631  }
632  }
633 
634  b2avg = b2avg / param.num_src;
635 
637  if (!init) {
642  // sloppy fields
643  csParam.setPrecision(param.precision_sloppy);
646  init = true;
647 
648  }
649  ColorSpinorField &r = *rp;
650  ColorSpinorField &y = *yp;
651  ColorSpinorField &Ap = *App;
653 
654  // calculate residuals for all vectors
655  for(int i=0; i<param.num_src; i++){
656  mat(r.Component(i), x.Component(i), y.Component(i));
657  blas::xmyNorm(b.Component(i), r.Component(i));
658  }
659 
660  // initialize r2 matrix
661  double r2avg=0;
662  MatrixXcd r2(param.num_src, param.num_src);
663  for(int i=0; i<param.num_src; i++){
664  for(int j=i; j < param.num_src; j++){
665  r2(i,j) = blas::cDotProduct(r.Component(i),r.Component(j));
666  if (i!=j) r2(j,i) = std::conj(r2(i,j));
667  if (i==j) {
668  r2avg += r2(i,i).real();
669  printfQuda("r2[%i] %e\n", i, r2(i,i).real());
670  }
671  }
672  }
673 
674 
675  csParam.setPrecision(param.precision_sloppy);
676  // tmp2 only needed for multi-gpu Wilson-like kernels
677  ColorSpinorField *tmp2_p = !mat.isStaggered() ?
679  ColorSpinorField &tmp2 = *tmp2_p;
680 
681  ColorSpinorField *r_sloppy;
682  if (param.precision_sloppy == x.Precision()) {
683  r_sloppy = &r;
684  } else {
685  // will that work ?
687  r_sloppy = ColorSpinorField::Create(r, csParam);
688  for(int i=0; i<param.num_src; i++){
689  blas::copy(r_sloppy->Component(i), r.Component(i)); //nop when these pointers alias
690  }
691  }
692 
693 
694  ColorSpinorField *x_sloppy;
696  x_sloppy = &x;
697  } else {
699  x_sloppy = ColorSpinorField::Create(x, csParam);
700  }
701 
702  // additional high-precision temporary if Wilson and mixed-precision
703  csParam.setPrecision(param.precision);
704  ColorSpinorField *tmp3_p =
707  ColorSpinorField &tmp3 = *tmp3_p;
708 
709  ColorSpinorField &xSloppy = *x_sloppy;
710  ColorSpinorField &rSloppy = *r_sloppy;
711 
713  csParam.setPrecision(param.precision_sloppy);
715  ColorSpinorField &p = *pp;
717  ColorSpinorField &rnew = *rpnew;
718 
719  if (&x != &xSloppy) {
720  blas::copy(y, x);
721  blas::zero(xSloppy);
722  } else {
723  blas::zero(y);
724  }
725 
726  const bool use_heavy_quark_res =
727  (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) ? true : false;
728  if(use_heavy_quark_res) errorQuda("ERROR: heavy quark residual not supported in block solver");
729 
730  profile.TPSTOP(QUDA_PROFILE_INIT);
732 
733  double stop[QUDA_MAX_MULTI_SHIFT];
734 
735  for(int i = 0; i < param.num_src; i++){
736  stop[i] = stopping(param.tol, b2[i], param.residual_type); // stopping condition of solver
737  }
738 
739  // Eigen Matrices instead of scalars
740  MatrixXcd alpha = MatrixXcd::Zero(param.num_src,param.num_src);
741  MatrixXcd beta = MatrixXcd::Zero(param.num_src,param.num_src);
742  MatrixXcd C = MatrixXcd::Zero(param.num_src,param.num_src);
743  MatrixXcd S = MatrixXcd::Identity(param.num_src,param.num_src);
744  MatrixXcd pAp = MatrixXcd::Identity(param.num_src,param.num_src);
746 
747  #ifdef MWVERBOSE
748  MatrixXcd pTp = MatrixXcd::Identity(param.num_src,param.num_src);
749  #endif
750 
751 
752 
753 
754  //FIXME:reliable updates currently not implemented
755  /*
756  double rNorm[QUDA_MAX_MULTI_SHIFT];
757  double r0Norm[QUDA_MAX_MULTI_SHIFT];
758  double maxrx[QUDA_MAX_MULTI_SHIFT];
759  double maxrr[QUDA_MAX_MULTI_SHIFT];
760 
761  for(int i = 0; i < param.num_src; i++){
762  rNorm[i] = sqrt(r2(i,i).real());
763  r0Norm[i] = rNorm[i];
764  maxrx[i] = rNorm[i];
765  maxrr[i] = rNorm[i];
766  }
767  bool L2breakdown = false;
768  int rUpdate = 0;
769  int steps_since_reliable = 1;
770  */
771 
773  profile.TPSTART(QUDA_PROFILE_COMPUTE);
774  blas::flops = 0;
775 
776  int k = 0;
777 
778  PrintStats("CG", k, r2avg / param.num_src, b2avg, 0.);
779  bool allconverged = true;
780  bool converged[QUDA_MAX_MULTI_SHIFT];
781  for(int i=0; i<param.num_src; i++){
782  converged[i] = convergence(r2(i,i).real(), 0., stop[i], param.tol_hq);
783  allconverged = allconverged && converged[i];
784  }
785 
786  // Cholesky decomposition
787  MatrixXcd L = r2.llt().matrixL();
788  C = L.adjoint();
789  MatrixXcd Linv = C.inverse();
790 
791  #ifdef MWVERBOSE
792  std::cout << "r2\n " << r2 << std::endl;
793  std::cout << "L\n " << L.adjoint() << std::endl;
794  #endif
795 
796  // set p to QR decomposition of r
797  // temporary hack - use AC to pass matrix arguments to multiblas
798  for(int i=0; i<param.num_src; i++){
799  blas::zero(p.Component(i));
800  for(int j=0;j<param.num_src; j++){
801  AC[i*param.num_src + j] = Linv(i,j);
802  }
803  }
804  blas::caxpy(AC,r,p);
805 
806  // set rSloppy to the QR decomposition of r (p)
807  for(int i=0; i< param.num_src; i++){
808  blas::copy(rSloppy.Component(i), p.Component(i));
809  }
810 
811  #ifdef MWVERBOSE
812  for(int i=0; i<param.num_src; i++){
813  for(int j=0; j<param.num_src; j++){
814  pTp(i,j) = blas::cDotProduct(p.Component(i), p.Component(j));
815  }
816  }
817  std::cout << " pTp " << std::endl << pTp << std::endl;
818  std::cout << " L " << std::endl << L.adjoint() << std::endl;
819  std::cout << " C " << std::endl << C << std::endl;
820  #endif
821 
822  while ( !allconverged && k < param.maxiter ) {
823  // apply matrix
824  for(int i=0; i<param.num_src; i++){
825  matSloppy(Ap.Component(i), p.Component(i), tmp.Component(i), tmp2.Component(i)); // tmp as tmp
826  }
827 
828  // calculate pAp
829  for(int i=0; i<param.num_src; i++){
830  for(int j=i; j < param.num_src; j++){
831  pAp(i,j) = blas::cDotProduct(p.Component(i), Ap.Component(j));
832  if (i!=j) pAp(j,i) = std::conj(pAp(i,j));
833  }
834  }
835 
836  // update Xsloppy
837  alpha = pAp.inverse() * C;
838  // temporary hack using AC
839  for(int i=0; i<param.num_src; i++){
840  for(int j=0;j<param.num_src; j++){
841  AC[i*param.num_src + j] = alpha(i,j);
842  }
843  }
844  blas::caxpy(AC,p,xSloppy);
845 
846  // update rSloppy
847  beta = pAp.inverse();
848  // temporary hack
849  for(int i=0; i<param.num_src; i++){
850  for(int j=0;j<param.num_src; j++){
851  AC[i*param.num_src + j] = -beta(i,j);
852  }
853  }
854  blas::caxpy(AC,Ap,rSloppy);
855 
856  // orthogonalize R
857  // copy rSloppy to rnew as temporary
858  for(int i=0; i< param.num_src; i++){
859  blas::copy(rnew.Component(i), rSloppy.Component(i));
860  }
861  for(int i=0; i<param.num_src; i++){
862  for(int j=i; j < param.num_src; j++){
863  r2(i,j) = blas::cDotProduct(r.Component(i),r.Component(j));
864  if (i!=j) r2(j,i) = std::conj(r2(i,j));
865  }
866  }
867  // Cholesky decomposition
868  L = r2.llt().matrixL();// retrieve factor L in the decomposition
869  S = L.adjoint();
870  Linv = S.inverse();
871  // temporary hack
872  for(int i=0; i<param.num_src; i++){
873  blas::zero(rSloppy.Component(i));
874  for(int j=0;j<param.num_src; j++){
875  AC[i*param.num_src + j] = Linv(i,j);
876  }
877  }
878  blas::caxpy(AC,rnew,rSloppy);
879 
880  #ifdef MWVERBOSE
881  for(int i=0; i<param.num_src; i++){
882  for(int j=0; j<param.num_src; j++){
883  pTp(i,j) = blas::cDotProduct(rSloppy.Component(i), rSloppy.Component(j));
884  }
885  }
886  std::cout << " rTr " << std::endl << pTp << std::endl;
887  std::cout << "QR" << S<< std::endl << "QP " << S.inverse()*S << std::endl;;
888  #endif
889 
890  // update p
891  // use rnew as temporary again for summing up
892  for(int i=0; i<param.num_src; i++){
893  blas::copy(rnew.Component(i),rSloppy.Component(i));
894  }
895  // temporary hack
896  for(int i=0; i<param.num_src; i++){
897  for(int j=0;j<param.num_src; j++){
898  AC[i*param.num_src + j] = std::conj(S(j,i));
899  }
900  }
901  blas::caxpy(AC,p,rnew);
902  // set p = rnew
903  for(int i=0; i < param.num_src; i++){
904  blas::copy(p.Component(i),rnew.Component(i));
905  }
906 
907  // update C
908  C = S * C;
909 
910  #ifdef MWVERBOSE
911  for(int i=0; i<param.num_src; i++){
912  for(int j=0; j<param.num_src; j++){
913  pTp(i,j) = blas::cDotProduct(p.Component(i), p.Component(j));
914  }
915  }
916  std::cout << " pTp " << std::endl << pTp << std::endl;
917  std::cout << "S " << S<< std::endl << "C " << C << std::endl;
918  #endif
919 
920  // calculate the residuals for all sources
921  r2avg=0;
922  for (int j=0; j<param.num_src; j++ ){
923  r2(j,j) = C(0,j)*conj(C(0,j));
924  for(int i=1; i < param.num_src; i++)
925  r2(j,j) += C(i,j) * conj(C(i,j));
926  r2avg += r2(j,j).real();
927  }
928 
929  k++;
930  PrintStats("CG", k, r2avg / param.num_src, b2avg, 0);
931  // check convergence
932  allconverged = true;
933  for(int i=0; i<param.num_src; i++){
934  converged[i] = convergence(r2(i,i).real(), 0, stop[i], param.tol_hq);
935  allconverged = allconverged && converged[i];
936  }
937 
938 
939  }
940 
941  for(int i=0; i<param.num_src; i++){
942  blas::xpy(y.Component(i), xSloppy.Component(i));
943  }
944 
947 
949  double gflops = (blas::flops + mat.flops() + matSloppy.flops())*1e-9;
950  param.gflops = gflops;
951  param.iter += k;
952 
953  if (k == param.maxiter)
954  warningQuda("Exceeded maximum iterations %d", param.maxiter);
955 
956  // if (getVerbosity() >= QUDA_VERBOSE)
957  // printfQuda("CG: Reliable updates = %d\n", rUpdate);
958 
959  // compute the true residuals
960  for(int i=0; i<param.num_src; i++){
961  mat(r.Component(i), x.Component(i), y.Component(i), tmp3.Component(i));
962  param.true_res = sqrt(blas::xmyNorm(b.Component(i), r.Component(i)) / b2[i]);
966 
967  PrintSummary("CG", k, r2(i,i).real(), b2[i]);
968  }
969 
970  // reset the flops counters
971  blas::flops = 0;
972  mat.flops();
973  matSloppy.flops();
974 
976  profile.TPSTART(QUDA_PROFILE_FREE);
977 
978  if (&tmp3 != &tmp) delete tmp3_p;
979  if (&tmp2 != &tmp) delete tmp2_p;
980 
981  if (rSloppy.Precision() != r.Precision()) delete r_sloppy;
982  if (xSloppy.Precision() != x.Precision()) delete x_sloppy;
983 
984  delete rpnew;
985  delete pp;
986  delete[] AC;
987  profile.TPSTOP(QUDA_PROFILE_FREE);
988 
989  return;
990 
991  #endif
992 }
993 
994 #else
995 
996 // use Gram Schmidt in Block CG ?
997 #define BLOCKCG_GS 1
999  #ifndef BLOCKSOLVER
1000  errorQuda("QUDA_BLOCKSOLVER not built.");
1001  #else
1002  #ifdef BLOCKCG_GS
1003  printfQuda("BCGdQ Solver\n");
1004  #else
1005  printfQuda("BCQ Solver\n");
1006  #endif
1007  const bool use_block = true;
1009  errorQuda("Not supported");
1010 
1011  profile.TPSTART(QUDA_PROFILE_INIT);
1012 
1013  using Eigen::MatrixXcd;
1014  MatrixXcd mPAP(param.num_src,param.num_src);
1015  MatrixXcd mRR(param.num_src,param.num_src);
1016 
1017 
1018  // Check to see that we're not trying to invert on a zero-field source
1019  //MW: it might be useful to check what to do here.
1020  double b2[QUDA_MAX_MULTI_SHIFT];
1021  double b2avg=0;
1022  double r2avg=0;
1023  for(int i=0; i< param.num_src; i++){
1024  b2[i]=blas::norm2(b.Component(i));
1025  b2avg += b2[i];
1026  if(b2[i] == 0){
1027  profile.TPSTOP(QUDA_PROFILE_INIT);
1028  errorQuda("Warning: inverting on zero-field source\n");
1029  x=b;
1030  param.true_res = 0.0;
1031  param.true_res_hq = 0.0;
1032  return;
1033  }
1034  }
1035 
1036  #ifdef MWVERBOSE
1037  MatrixXcd b2m(param.num_src,param.num_src);
1038  // just to check details of b
1039  for(int i=0; i<param.num_src; i++){
1040  for(int j=0; j<param.num_src; j++){
1041  b2m(i,j) = blas::cDotProduct(b.Component(i), b.Component(j));
1042  }
1043  }
1044  std::cout << "b2m\n" << b2m << std::endl;
1045  #endif
1046 
1047  ColorSpinorParam csParam(x);
1048  if (!init) {
1053  // sloppy fields
1054  csParam.setPrecision(param.precision_sloppy);
1057  init = true;
1058 
1059  }
1060  ColorSpinorField &r = *rp;
1061  ColorSpinorField &y = *yp;
1062  ColorSpinorField &Ap = *App;
1063  ColorSpinorField &tmp = *tmpp;
1064 
1065 
1066  // const int i = 0; // MW: hack to be able to write Component(i) instead and try with i=0 for now
1067 
1068  for(int i=0; i<param.num_src; i++){
1069  mat(r.Component(i), x.Component(i), y.Component(i));
1070  }
1071 
1072  // double r2[QUDA_MAX_MULTI_SHIFT];
1073  MatrixXcd r2(param.num_src,param.num_src);
1074  for(int i=0; i<param.num_src; i++){
1075  r2(i,i) = blas::xmyNorm(b.Component(i), r.Component(i));
1076  printfQuda("r2[%i] %e\n", i, r2(i,i).real());
1077  }
1078  if(use_block){
1079  // MW need to initialize the full r2 matrix here
1080  for(int i=0; i<param.num_src; i++){
1081  for(int j=i+1; j<param.num_src; j++){
1082  r2(i,j) = blas::cDotProduct(r.Component(i), r.Component(j));
1083  r2(j,i) = std::conj(r2(i,j));
1084  }
1085  }
1086  }
1087 
1088  csParam.setPrecision(param.precision_sloppy);
1089  // tmp2 only needed for multi-gpu Wilson-like kernels
1090  ColorSpinorField *tmp2_p = !mat.isStaggered() ?
1092  ColorSpinorField &tmp2 = *tmp2_p;
1093 
1094  ColorSpinorField *r_sloppy;
1095  if (param.precision_sloppy == x.Precision()) {
1096  r_sloppy = &r;
1097  } else {
1098  // will that work ?
1100  r_sloppy = ColorSpinorField::Create(r, csParam);
1101  for(int i=0; i<param.num_src; i++){
1102  blas::copy(r_sloppy->Component(i), r.Component(i)); //nop when these pointers alias
1103  }
1104  }
1105 
1106 
1107  ColorSpinorField *x_sloppy;
1108  if (param.precision_sloppy == x.Precision() ||
1110  x_sloppy = &x;
1111  } else {
1113  x_sloppy = ColorSpinorField::Create(x, csParam);
1114  }
1115 
1116  // additional high-precision temporary if Wilson and mixed-precision
1117  csParam.setPrecision(param.precision);
1118  ColorSpinorField *tmp3_p =
1121  ColorSpinorField &tmp3 = *tmp3_p;
1122 
1123  ColorSpinorField &xSloppy = *x_sloppy;
1124  ColorSpinorField &rSloppy = *r_sloppy;
1125 
1127  csParam.setPrecision(param.precision_sloppy);
1128  ColorSpinorField* pp = ColorSpinorField::Create(rSloppy, csParam);
1129  ColorSpinorField &p = *pp;
1130  ColorSpinorField* ppnew = ColorSpinorField::Create(rSloppy, csParam);
1131  ColorSpinorField &pnew = *ppnew;
1132 
1133  if (&x != &xSloppy) {
1134  blas::copy(y, x);
1135  blas::zero(xSloppy);
1136  } else {
1137  blas::zero(y);
1138  }
1139 
1140  const bool use_heavy_quark_res =
1141  (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) ? true : false;
1142  bool heavy_quark_restart = false;
1143 
1144  profile.TPSTOP(QUDA_PROFILE_INIT);
1145  profile.TPSTART(QUDA_PROFILE_PREAMBLE);
1146 
1147  MatrixXcd r2_old(param.num_src, param.num_src);
1148  double heavy_quark_res[QUDA_MAX_MULTI_SHIFT] = {0.0}; // heavy quark res idual
1149  double heavy_quark_res_old[QUDA_MAX_MULTI_SHIFT] = {0.0}; // heavy quark residual
1150  double stop[QUDA_MAX_MULTI_SHIFT];
1151 
1152  for(int i = 0; i < param.num_src; i++){
1153  stop[i] = stopping(param.tol, b2[i], param.residual_type); // stopping condition of solver
1154  if (use_heavy_quark_res) {
1155  heavy_quark_res[i] = sqrt(blas::HeavyQuarkResidualNorm(x.Component(i), r.Component(i)).z);
1156  heavy_quark_res_old[i] = heavy_quark_res[i]; // heavy quark residual
1157  }
1158  }
1159  const int heavy_quark_check = param.heavy_quark_check; // how often to check the heavy quark residual
1160 
1161  MatrixXcd alpha = MatrixXcd::Zero(param.num_src,param.num_src);
1162  MatrixXcd beta = MatrixXcd::Zero(param.num_src,param.num_src);
1163  MatrixXcd gamma = MatrixXcd::Identity(param.num_src,param.num_src);
1164  // gamma = gamma * 2.0;
1165 
1166  MatrixXcd pAp(param.num_src, param.num_src);
1167  MatrixXcd pTp(param.num_src, param.num_src);
1168  int rUpdate = 0;
1169 
1170  double rNorm[QUDA_MAX_MULTI_SHIFT];
1171  double r0Norm[QUDA_MAX_MULTI_SHIFT];
1172  double maxrx[QUDA_MAX_MULTI_SHIFT];
1173  double maxrr[QUDA_MAX_MULTI_SHIFT];
1174 
1175  for(int i = 0; i < param.num_src; i++){
1176  rNorm[i] = sqrt(r2(i,i).real());
1177  r0Norm[i] = rNorm[i];
1178  maxrx[i] = rNorm[i];
1179  maxrr[i] = rNorm[i];
1180  }
1181 
1182  double delta = param.delta;//MW: hack no reliable updates param.delta;
1183 
1184 // this parameter determines how many consecutive reliable update
1185 // residual increases we tolerate before terminating the solver,
1186 // i.e., how long do we want to keep trying to converge
1187  const int maxResIncrease = (use_heavy_quark_res ? 0 : param.max_res_increase); // check if we reached the limit of our tolerance
1188  const int maxResIncreaseTotal = param.max_res_increase_total;
1189  // 0 means we have no tolerance
1190  // maybe we should expose this as a parameter
1191  const int hqmaxresIncrease = maxResIncrease + 1;
1192 
1193  int resIncrease = 0;
1194  int resIncreaseTotal = 0;
1195  int hqresIncrease = 0;
1196 
1197  // set this to true if maxResIncrease has been exceeded but when we use heavy quark residual we still want to continue the CG
1198  // only used if we use the heavy_quark_res
1199  bool L2breakdown = false;
1200 
1202  profile.TPSTART(QUDA_PROFILE_COMPUTE);
1203  blas::flops = 0;
1204 
1205  int k = 0;
1206 
1207  for(int i=0; i<param.num_src; i++){
1208  r2avg+=r2(i,i).real();
1209  }
1210  PrintStats("CG", k, r2avg, b2avg, heavy_quark_res[0]);
1211  int steps_since_reliable = 1;
1212  bool allconverged = true;
1213  bool converged[QUDA_MAX_MULTI_SHIFT];
1214  for(int i=0; i<param.num_src; i++){
1215  converged[i] = convergence(r2(i,i).real(), heavy_quark_res[i], stop[i], param.tol_hq);
1216  allconverged = allconverged && converged[i];
1217  }
1218  MatrixXcd sigma(param.num_src,param.num_src);
1219 
1220  #ifdef BLOCKCG_GS
1221  // begin ignore Gram-Schmidt for now
1222 
1223  for(int i=0; i < param.num_src; i++){
1224  double n = blas::norm2(p.Component(i));
1225  blas::ax(1/sqrt(n),p.Component(i));
1226  for(int j=i+1; j < param.num_src; j++) {
1227  std::complex<double> ri=blas::cDotProduct(p.Component(i),p.Component(j));
1228  blas::caxpy(-ri,p.Component(i),p.Component(j));
1229  }
1230  }
1231 
1232  gamma = MatrixXcd::Zero(param.num_src,param.num_src);
1233  for ( int i = 0; i < param.num_src; i++){
1234  for (int j=i; j < param.num_src; j++){
1235  gamma(i,j) = blas::cDotProduct(p.Component(i),pnew.Component(j));
1236  }
1237  }
1238  #endif
1239  // end ignore Gram-Schmidt for now
1240 
1241  #ifdef MWVERBOSE
1242  for(int i=0; i<param.num_src; i++){
1243  for(int j=0; j<param.num_src; j++){
1244  pTp(i,j) = blas::cDotProduct(p.Component(i), p.Component(j));
1245  }
1246  }
1247 
1248  std::cout << " pTp " << std::endl << pTp << std::endl;
1249  std::cout << "QR" << gamma<< std::endl << "QP " << gamma.inverse()*gamma << std::endl;;
1250  #endif
1251  while ( !allconverged && k < param.maxiter ) {
1252  for(int i=0; i<param.num_src; i++){
1253  matSloppy(Ap.Component(i), p.Component(i), tmp.Component(i), tmp2.Component(i)); // tmp as tmp
1254  }
1255 
1256 
1257  bool breakdown = false;
1258  // FIXME: need to check breakdown
1259  // current implementation sets breakdown to true for pipelined CG if one rhs triggers breakdown
1260  // this is probably ok
1261 
1262 
1263  if (param.pipeline) {
1264  errorQuda("pipeline not implemented");
1265  } else {
1266  r2_old = r2;
1267  for(int i=0; i<param.num_src; i++){
1268  for(int j=0; j < param.num_src; j++){
1269  if(use_block or i==j)
1270  pAp(i,j) = blas::cDotProduct(p.Component(i), Ap.Component(j));
1271  else
1272  pAp(i,j) = 0.;
1273  }
1274  }
1275 
1276  alpha = pAp.inverse() * gamma.adjoint().inverse() * r2;
1277  #ifdef MWVERBOSE
1278  std::cout << "alpha\n" << alpha << std::endl;
1279 
1280  if(k==1){
1281  std::cout << "pAp " << std::endl <<pAp << std::endl;
1282  std::cout << "pAp^-1 " << std::endl <<pAp.inverse() << std::endl;
1283  std::cout << "r2 " << std::endl <<r2 << std::endl;
1284  std::cout << "alpha " << std::endl <<alpha << std::endl;
1285  std::cout << "pAp^-1r2" << std::endl << pAp.inverse()*r2 << std::endl;
1286  }
1287  #endif
1288  // here we are deploying the alternative beta computation
1289  for(int i=0; i<param.num_src; i++){
1290  for(int j=0; j < param.num_src; j++){
1291 
1292  blas::caxpy(-alpha(j,i), Ap.Component(j), rSloppy.Component(i));
1293  }
1294  }
1295  // MW need to calculate the full r2 matrix here, after update. Not sure how to do alternative sigma yet ...
1296  for(int i=0; i<param.num_src; i++){
1297  for(int j=0; j<param.num_src; j++){
1298  if(use_block or i==j)
1299  r2(i,j) = blas::cDotProduct(r.Component(i), r.Component(j));
1300  else
1301  r2(i,j) = 0.;
1302  }
1303  }
1304  sigma = r2;
1305  }
1306 
1307 
1308  bool updateX=false;
1309  bool updateR=false;
1310  // int updateX = (rNorm < delta*r0Norm && r0Norm <= maxrx) ? true : false;
1311  // int updateR = ((rNorm < delta*maxrr && r0Norm <= maxrr) || updateX) ? true : false;
1312  //
1313  // printfQuda("Checking reliable update %i %i\n",updateX,updateR);
1314  // reliable update conditions
1315  for(int i=0; i<param.num_src; i++){
1316  rNorm[i] = sqrt(r2(i,i).real());
1317  if (rNorm[i] > maxrx[i]) maxrx[i] = rNorm[i];
1318  if (rNorm[i] > maxrr[i]) maxrr[i] = rNorm[i];
1319  updateX = (rNorm[i] < delta * r0Norm[i] && r0Norm[i] <= maxrx[i]) ? true : false;
1320  updateR = ((rNorm[i] < delta * maxrr[i] && r0Norm[i] <= maxrr[i]) || updateX) ? true : false;
1321  }
1322  if ( (updateR || updateX )) {
1323  // printfQuda("Suppressing reliable update %i %i\n",updateX,updateR);
1324  updateX=false;
1325  updateR=false;
1326  // printfQuda("Suppressing reliable update %i %i\n",updateX,updateR);
1327  }
1328 
1329  if ( !(updateR || updateX )) {
1330 
1331  beta = gamma * r2_old.inverse() * sigma;
1332  #ifdef MWVERBOSE
1333  std::cout << "beta\n" << beta << std::endl;
1334  #endif
1335  if (param.pipeline && !breakdown)
1336  errorQuda("pipeline not implemented");
1337 
1338  else{
1339  for(int i=0; i<param.num_src; i++){
1340  for(int j=0; j<param.num_src; j++){
1341  blas::caxpy(alpha(j,i),p.Component(j),xSloppy.Component(i));
1342  }
1343  }
1344 
1345  // set to zero
1346  for(int i=0; i < param.num_src; i++){
1347  blas::ax(0,pnew.Component(i)); // do we need components here?
1348  }
1349  // add r
1350  for(int i=0; i<param.num_src; i++){
1351  // for(int j=0;j<param.num_src; j++){
1352  // order of updating p might be relevant here
1353  blas::axpy(1.0,r.Component(i),pnew.Component(i));
1354  // blas::axpby(rcoeff,rSloppy.Component(i),beta(i,j),p.Component(j));
1355  // }
1356  }
1357  // beta = beta * gamma.inverse();
1358  for(int i=0; i<param.num_src; i++){
1359  for(int j=0;j<param.num_src; j++){
1360  double rcoeff= (j==0?1.0:0.0);
1361 // order of updating p might be relevant here
1362  blas::caxpy(beta(j,i),p.Component(j),pnew.Component(i));
1363  // blas::axpby(rcoeff,rSloppy.Component(i),beta(i,j),p.Component(j));
1364  }
1365  }
1366  // now need to do something with the p's
1367 
1368  for(int i=0; i< param.num_src; i++){
1369  blas::copy(p.Component(i), pnew.Component(i));
1370  }
1371 
1372 
1373  #ifdef BLOCKCG_GS
1374  for(int i=0; i < param.num_src; i++){
1375  double n = blas::norm2(p.Component(i));
1376  blas::ax(1/sqrt(n),p.Component(i));
1377  for(int j=i+1; j < param.num_src; j++) {
1378  std::complex<double> ri=blas::cDotProduct(p.Component(i),p.Component(j));
1379  blas::caxpy(-ri,p.Component(i),p.Component(j));
1380 
1381  }
1382  }
1383 
1384 
1385  gamma = MatrixXcd::Zero(param.num_src,param.num_src);
1386  for ( int i = 0; i < param.num_src; i++){
1387  for (int j=i; j < param.num_src; j++){
1388  gamma(i,j) = blas::cDotProduct(p.Component(i),pnew.Component(j));
1389  }
1390  }
1391  #endif
1392 
1393  #ifdef MWVERBOSE
1394  for(int i=0; i<param.num_src; i++){
1395  for(int j=0; j<param.num_src; j++){
1396  pTp(i,j) = blas::cDotProduct(p.Component(i), p.Component(j));
1397  }
1398  }
1399  std::cout << " pTp " << std::endl << pTp << std::endl;
1400  std::cout << "QR" << gamma<< std::endl << "QP " << gamma.inverse()*gamma << std::endl;;
1401  #endif
1402  }
1403 
1404 
1405  if (use_heavy_quark_res && (k % heavy_quark_check) == 0) {
1406  if (&x != &xSloppy) {
1407  blas::copy(tmp, y); // FIXME: check whether copy works here
1408  for(int i=0; i<param.num_src; i++){
1409  heavy_quark_res[i] = sqrt(blas::xpyHeavyQuarkResidualNorm(xSloppy.Component(i), tmp.Component(i), rSloppy.Component(i)).z);
1410  }
1411  } else {
1412  blas::copy(r, rSloppy); // FIXME: check whether copy works here
1413  for(int i=0; i<param.num_src; i++){
1414  heavy_quark_res[i] = sqrt(blas::xpyHeavyQuarkResidualNorm(x.Component(i), y.Component(i), r.Component(i)).z);
1415  }
1416  }
1417  }
1418 
1419  steps_since_reliable++;
1420  } else {
1421  printfQuda("reliable update\n");
1422  for(int i=0; i<param.num_src; i++){
1423  blas::axpy(alpha(i,i).real(), p.Component(i), xSloppy.Component(i));
1424  }
1425  blas::copy(x, xSloppy); // nop when these pointers alias
1426 
1427  for(int i=0; i<param.num_src; i++){
1428  blas::xpy(x.Component(i), y.Component(i)); // swap these around?
1429  }
1430  for(int i=0; i<param.num_src; i++){
1431  mat(r.Component(i), y.Component(i), x.Component(i), tmp3.Component(i)); // here we can use x as tmp
1432  }
1433  for(int i=0; i<param.num_src; i++){
1434  r2(i,i) = blas::xmyNorm(b.Component(i), r.Component(i));
1435  }
1436 
1437  for(int i=0; i<param.num_src; i++){
1438  blas::copy(rSloppy.Component(i), r.Component(i)); //nop when these pointers alias
1439  blas::zero(xSloppy.Component(i));
1440  }
1441 
1442 // calculate new reliable HQ residual
1443  if (use_heavy_quark_res){
1444  for(int i=0; i<param.num_src; i++){
1445  heavy_quark_res[i] = sqrt(blas::HeavyQuarkResidualNorm(y.Component(i), r.Component(i)).z);
1446  }
1447  }
1448 
1449  // MW: FIXME as this probably goes terribly wrong right now
1450  for(int i = 0; i<param.num_src; i++){
1451  // break-out check if we have reached the limit of the precision
1452  if (sqrt(r2(i,i).real()) > r0Norm[i] && updateX) { // reuse r0Norm for this
1453  resIncrease++;
1454  resIncreaseTotal++;
1455  warningQuda("CG: new reliable residual norm %e is greater than previous reliable residual norm %e (total #inc %i)",
1456  sqrt(r2(i,i).real()), r0Norm[i], resIncreaseTotal);
1457  if ( resIncrease > maxResIncrease or resIncreaseTotal > maxResIncreaseTotal) {
1458  if (use_heavy_quark_res) {
1459  L2breakdown = true;
1460  } else {
1461  warningQuda("CG: solver exiting due to too many true residual norm increases");
1462  break;
1463  }
1464  }
1465  } else {
1466  resIncrease = 0;
1467  }
1468  }
1469  // if L2 broke down already we turn off reliable updates and restart the CG
1470  for(int i = 0; i<param.num_src; i++){
1471  if (use_heavy_quark_res and L2breakdown) {
1472  delta = 0;
1473  warningQuda("CG: Restarting without reliable updates for heavy-quark residual");
1474  heavy_quark_restart = true;
1475  if (heavy_quark_res[i] > heavy_quark_res_old[i]) {
1476  hqresIncrease++;
1477  warningQuda("CG: new reliable HQ residual norm %e is greater than previous reliable residual norm %e", heavy_quark_res[i], heavy_quark_res_old[i]);
1478  // break out if we do not improve here anymore
1479  if (hqresIncrease > hqmaxresIncrease) {
1480  warningQuda("CG: solver exiting due to too many heavy quark residual norm increases");
1481  break;
1482  }
1483  }
1484  }
1485  }
1486 
1487  for(int i=0; i<param.num_src; i++){
1488  rNorm[i] = sqrt(r2(i,i).real());
1489  maxrr[i] = rNorm[i];
1490  maxrx[i] = rNorm[i];
1491  r0Norm[i] = rNorm[i];
1492  heavy_quark_res_old[i] = heavy_quark_res[i];
1493  }
1494  rUpdate++;
1495 
1496  if (use_heavy_quark_res and heavy_quark_restart) {
1497  // perform a restart
1498  blas::copy(p, rSloppy);
1499  heavy_quark_restart = false;
1500  } else {
1501  // explicitly restore the orthogonality of the gradient vector
1502  for(int i=0; i<param.num_src; i++){
1503  double rp = blas::reDotProduct(rSloppy.Component(i), p.Component(i)) / (r2(i,i).real());
1504  blas::axpy(-rp, rSloppy.Component(i), p.Component(i));
1505 
1506  beta(i,i) = r2(i,i) / r2_old(i,i);
1507  blas::xpay(rSloppy.Component(i), beta(i,i).real(), p.Component(i));
1508  }
1509  }
1510 
1511  steps_since_reliable = 0;
1512  }
1513 
1514  breakdown = false;
1515  k++;
1516 
1517  allconverged = true;
1518  r2avg=0;
1519  for(int i=0; i<param.num_src; i++){
1520  r2avg+= r2(i,i).real();
1521  // check convergence, if convergence is satisfied we only need to check that we had a reliable update for the heavy quarks recently
1522  converged[i] = convergence(r2(i,i).real(), heavy_quark_res[i], stop[i], param.tol_hq);
1523  allconverged = allconverged && converged[i];
1524  }
1525  PrintStats("CG", k, r2avg, b2avg, heavy_quark_res[0]);
1526 
1527  // check for recent enough reliable updates of the HQ residual if we use it
1528  if (use_heavy_quark_res) {
1529  for(int i=0; i<param.num_src; i++){
1530 // L2 is converged or precision maxed out for L2
1531  bool L2done = L2breakdown or convergenceL2(r2(i,i).real(), heavy_quark_res[i], stop[i], param.tol_hq);
1532  // HQ is converged and if we do reliable update the HQ residual has been calculated using a reliable update
1533  bool HQdone = (steps_since_reliable == 0 and param.delta > 0) and convergenceHQ(r2(i,i).real(), heavy_quark_res[i], stop[i], param.tol_hq);
1534  converged[i] = L2done and HQdone;
1535  }
1536  }
1537 
1538  }
1539 
1540  blas::copy(x, xSloppy);
1541  for(int i=0; i<param.num_src; i++){
1542  blas::xpy(y.Component(i), x.Component(i));
1543  }
1544 
1545  profile.TPSTOP(QUDA_PROFILE_COMPUTE);
1546  profile.TPSTART(QUDA_PROFILE_EPILOGUE);
1547 
1549  double gflops = (blas::flops + mat.flops() + matSloppy.flops())*1e-9;
1550  param.gflops = gflops;
1551  param.iter += k;
1552 
1553  if (k == param.maxiter)
1554  warningQuda("Exceeded maximum iterations %d", param.maxiter);
1555 
1556  if (getVerbosity() >= QUDA_VERBOSE)
1557  printfQuda("CG: Reliable updates = %d\n", rUpdate);
1558 
1559  // compute the true residuals
1560  for(int i=0; i<param.num_src; i++){
1561  mat(r.Component(i), x.Component(i), y.Component(i), tmp3.Component(i));
1562  param.true_res = sqrt(blas::xmyNorm(b.Component(i), r.Component(i)) / b2[i]);
1563  param.true_res_hq = sqrt(blas::HeavyQuarkResidualNorm(x.Component(i), r.Component(i)).z);
1566 
1567  PrintSummary("CG", k, r2(i,i).real(), b2[i]);
1568  }
1569 
1570  // reset the flops counters
1571  blas::flops = 0;
1572  mat.flops();
1573  matSloppy.flops();
1574 
1576  profile.TPSTART(QUDA_PROFILE_FREE);
1577 
1578  if (&tmp3 != &tmp) delete tmp3_p;
1579  if (&tmp2 != &tmp) delete tmp2_p;
1580 
1581  if (rSloppy.Precision() != r.Precision()) delete r_sloppy;
1582  if (xSloppy.Precision() != x.Precision()) delete x_sloppy;
1583 
1584  delete pp;
1585 
1586  profile.TPSTOP(QUDA_PROFILE_FREE);
1587 
1588  return;
1589 
1590  #endif
1591 
1592 }
1593 #endif
1594 
1595 
1596 } // namespace quda
bool convergence(const double &r2, const double &hq2, const double &r2_tol, const double &hq_tol)
Definition: solver.cpp:139
void xpay(ColorSpinorField &x, const double &a, ColorSpinorField &y)
Definition: blas_quda.cu:173
const Dirac * Expose()
Definition: dirac_quda.h:1011
static double stopping(const double &tol, const double &b2, QudaResidualType residual_type)
Definition: solver.cpp:122
double3 cDotProductNormA(ColorSpinorField &a, ColorSpinorField &b)
Definition: reduce_quda.cu:572
virtual ~CGNR()
Definition: inv_cg_quda.cpp:71
DiracMMdag mmdag
Definition: invert_quda.h:423
DiracMdagM mdagm
Definition: invert_quda.h:438
#define QUDA_MAX_MULTI_SHIFT
Maximum number of shifts supported by the multi-shift solver. This number may be changed if need be...
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:20
Matrix< N, std::complex< T > > conj(const Matrix< N, std::complex< T > > &mat)
#define errorQuda(...)
Definition: util_quda.h:90
double norm2(const ColorSpinorField &a)
Definition: reduce_quda.cu:241
void init()
Definition: blas_quda.cu:64
__host__ __device__ ValueType sqrt(ValueType x)
Definition: complex_quda.h:105
virtual ~CGNE()
Definition: inv_cg_quda.cpp:40
Complex cDotProduct(ColorSpinorField &, ColorSpinorField &)
Definition: reduce_quda.cu:500
const DiracMatrix & matSloppy
Definition: invert_quda.h:406
std::complex< double > Complex
Definition: eig_variables.h:13
bool convergenceL2(const double &r2, const double &hq2, const double &r2_tol, const double &hq_tol)
Definition: solver.cpp:167
void xpayz(ColorSpinorField &x, const double &a, ColorSpinorField &y, ColorSpinorField &z)
Definition: blas_quda.cu:177
cudaColorSpinorField * tmp
Definition: covdev_test.cpp:44
double3 xpyHeavyQuarkResidualNorm(ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &r)
Definition: reduce_quda.cu:742
double axpyNorm(const double &a, ColorSpinorField &x, ColorSpinorField &y)
Definition: reduce_quda.cu:325
static ColorSpinorField * Create(const ColorSpinorParam &param)
double reDotProduct(ColorSpinorField &x, ColorSpinorField &y)
Definition: reduce_quda.cu:277
TimeProfile & profile
Definition: invert_quda.h:329
const DiracMatrix & mat
Definition: invert_quda.h:405
void copy(ColorSpinorField &dst, const ColorSpinorField &src)
Definition: copy_quda.cu:263
void ax(const double &a, ColorSpinorField &x)
Definition: blas_quda.cu:209
ColorSpinorField & Component(const int idx) const
double xmyNorm(ColorSpinorField &x, ColorSpinorField &y)
Definition: reduce_quda.cu:364
int max_res_increase_total
Definition: invert_quda.h:79
ColorSpinorField * xp
Definition: invert_quda.h:425
ColorSpinorField * bp
Definition: invert_quda.h:440
This is just a dummy structure we use for trove to define the required structure size.
CGNE(DiracMatrix &mat, DiracMatrix &matSloppy, SolverParam &param, TimeProfile &profile)
Definition: inv_cg_quda.cpp:36
QudaGaugeParam param
Definition: pack_test.cpp:17
#define b
bool init
Definition: invert_quda.h:410
QudaComputeNullVector compute_null_vector
Definition: invert_quda.h:53
double Last(QudaProfileType idx)
void PrintSummary(const char *name, int k, const double &r2, const double &b2)
Definition: solver.cpp:194
static unsigned int delta
QudaResidualType residual_type
Definition: invert_quda.h:47
void axpyZpbx(const double &a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, const double &b)
Definition: blas_quda.cu:384
CG(DiracMatrix &mat, DiracMatrix &matSloppy, SolverParam &param, TimeProfile &profile)
Definition: inv_cg_quda.cpp:21
double true_res_hq_offset[QUDA_MAX_MULTI_SHIFT]
Definition: invert_quda.h:151
#define tmp2
Definition: tmc_core.h:16
ColorSpinorParam csParam
Definition: pack_test.cpp:24
Complex axpyCGNorm(const double &a, ColorSpinorField &x, ColorSpinorField &y)
Definition: reduce_quda.cu:654
void tripleCGUpdate(const double &alpha, const double &beta, ColorSpinorField &q, ColorSpinorField &r, ColorSpinorField &x, ColorSpinorField &p)
Definition: blas_quda.cu:610
double4 quadrupleCGReduction(ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
#define warningQuda(...)
Definition: util_quda.h:101
#define checkLocation(...)
__host__ __device__ ValueType pow(ValueType x, ExponentType e)
Definition: complex_quda.h:100
ColorSpinorField * tmpp
Definition: invert_quda.h:408
double3 HeavyQuarkResidualNorm(ColorSpinorField &x, ColorSpinorField &r)
Definition: reduce_quda.cu:703
double true_res_offset[QUDA_MAX_MULTI_SHIFT]
Definition: invert_quda.h:145
CGNR(DiracMatrix &mat, DiracMatrix &matSloppy, SolverParam &param, TimeProfile &profile)
Definition: inv_cg_quda.cpp:67
double gamma(double) __attribute__((availability(macosx
void caxpy(const Complex &a, ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.cu:246
ColorSpinorField * rp
Definition: invert_quda.h:408
void zero(ColorSpinorField &a)
Definition: blas_quda.cu:45
std::vector< ColorSpinorField * > p
Definition: invert_quda.h:409
void axpy(const double &a, ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.cu:150
void Mdag(ColorSpinorField &out, const ColorSpinorField &in) const
Definition: dirac.cpp:73
QudaPrecision precision
Definition: invert_quda.h:112
virtual void M(ColorSpinorField &out, const ColorSpinorField &in) const =0
SolverParam & param
Definition: invert_quda.h:328
void operator()(ColorSpinorField &out, ColorSpinorField &in)
Definition: inv_cg_quda.cpp:79
ColorSpinorField * yp
Definition: invert_quda.h:408
unsigned long long flops() const
Definition: dirac_quda.h:995
void PrintStats(const char *, int k, const double &r2, const double &b2, const double &hq2)
Definition: solver.cpp:179
#define printfQuda(...)
Definition: util_quda.h:84
void operator()(ColorSpinorField &out, ColorSpinorField &in)
unsigned long long flops
Definition: blas_quda.cu:42
void xpy(ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.cu:128
ColorSpinorField * App
Definition: invert_quda.h:408
QudaUseInitGuess use_init_guess
Definition: invert_quda.h:50
void operator()(ColorSpinorField &out, ColorSpinorField &in)
Definition: inv_cg_quda.cpp:48
void solve(ColorSpinorField &out, ColorSpinorField &in)
QudaPrecision precision_sloppy
Definition: invert_quda.h:115
bool use_sloppy_partial_accumulator
Definition: invert_quda.h:59
bool convergenceHQ(const double &r2, const double &hq2, const double &r2_tol, const double &hq_tol)
Definition: solver.cpp:156
virtual ~CG()
Definition: inv_cg_quda.cpp:25
__host__ __device__ ValueType conj(ValueType x)
Definition: complex_quda.h:115
int solution_accumulator_pipeline
Definition: invert_quda.h:69
void mat(void *out, void **link, void *in, int dagger_bit, int mu, QudaPrecision sPrecision, QudaPrecision gPrecision)
static __inline__ size_t size_t d
QudaPrecision Precision() const
bool isStaggered() const
Definition: dirac_quda.h:1004
double3 tripleCGReduction(ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
Definition: reduce_quda.cu:767
void updateR()
update the radius for halos.
#define tmp3
Definition: tmc_core.h:17