QUDA  0.9.0
inv_eigcg_quda.cpp
Go to the documentation of this file.
1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <math.h>
4 #include <memory>
5 #include <iostream>
6 
7 #include <quda_internal.h>
8 #include <color_spinor_field.h>
9 #include <blas_quda.h>
10 #include <dslash_quda.h>
11 #include <invert_quda.h>
12 #include <util_quda.h>
13 #include <string.h>
14 
15 #ifdef MAGMA_LIB
16 #include <blas_magma.h>
17 #endif
18 
19 
20 #include <Eigen/Dense>
21 
22 #include <deflation.h>
23 
24 
25 /*
26 Based on eigCG(nev, m) algorithm:
27 A. Stathopoulos and K. Orginos, arXiv:0707.0131
28 */
29 
30 namespace quda {
31 
32  using namespace blas;
33  using namespace Eigen;
34 
35  using DynamicStride = Stride<Dynamic, Dynamic>;
36  using DenseMatrix = MatrixXcd;
37  using VectorSet = MatrixXcd;
38  using Vector = VectorXcd;
39  using RealVector = VectorXd;
40 
41 //special types needed for compatibility with QUDA blas:
43 
44  static int max_eigcg_cycles = 4;//how many eigcg cycles do we allow?
45 
46 
48 
49  class EigCGArgs{
50 
51  public:
52  //host Lanczos matrice, and its eigenvalue/vector arrays:
53  DenseMatrix Tm;//VH A V,
54  //eigenvectors:
55  VectorSet ritzVecs;//array of (m) ritz and of m length
56  //eigenvalues of both T[m, m ] and T[m-1, m-1] (re-used)
57  RealVector Tmvals;//eigenvalues of T[m, m ] and T[m-1, m-1] (re-used)
58  //Aux matrix for computing 2k Ritz vectors:
60 
61  int m;
62  int k;
63  int id;//cuurent search spase index
64 
65  int restarts;
66  double global_stop;
67 
68  bool run_residual_correction;//used in mixed precision cycles
69 
70  ColorSpinorFieldSet *V2k; //eigCG accumulation vectors needed to update Tm (spinor matrix of size eigen_vector_length x (2*k))
71 
72  EigCGArgs(int m, int k) : Tm(DenseMatrix::Zero(m,m)), ritzVecs(VectorSet::Zero(m,m)), Tmvals(m), H2k(2*k, 2*k),
73  m(m), k(k), id(0), restarts(0), global_stop(0.0), run_residual_correction(false), V2k(nullptr) { }
74 
76  if(V2k) delete V2k;
77  }
78 
79  //method for constructing Lanczos matrix :
80  inline void SetLanczos(Complex diag_val, Complex offdiag_val) {
81  if(run_residual_correction) return;
82 
83  Tm.diagonal<0>()[id] = diag_val;
84 
85  if (id < (m-1)){ //Load Lanczos off-diagonals:
86  Tm.diagonal<+1>()[id] = offdiag_val;
87  Tm.diagonal<-1>()[id] = offdiag_val;
88  }
89 
90  id += 1;
91 
92  return;
93  }
94 
95  inline void ResetArgs() {
96  id = 0;
97  Tm.setZero();
98  Tmvals.setZero();
99  ritzVecs.setZero();
100 
101  if(V2k) delete V2k;
102  V2k = nullptr;
103  }
104 
105  inline void ResetSearchIdx() { id = 2*k; restarts += 1; }
106 
107  void RestartLanczos(ColorSpinorField *w, ColorSpinorFieldSet *v, const double inv_sqrt_r2)
108  {
109  Tm.setZero();
110 
111  std::unique_ptr<Complex[] > s(new Complex[2*k]);
112 
113  for(int i = 0; i < 2*k; i++) Tm(i,i) = Tmvals(i);//??
114 
115  std::vector<ColorSpinorField*> w_;
116  w_.push_back(w);
117 
118  std::vector<ColorSpinorField*> v_(v->Components().begin(), v->Components().begin()+2*k);
119 
120  blas::cDotProduct(s.get(), w_, v_);
121 
122  Map<VectorXcd, Unaligned > s_(s.get(), 2*k);
123  s_ *= inv_sqrt_r2;
124 
125  Tm.col(2*k).segment(0, 2*k) = s_;
126  Tm.row(2*k).segment(0, 2*k) = s_.adjoint();
127 
128  return;
129  }
130  };
131 
132  //Rayleigh Ritz procedure:
133  template<libtype which_lib> void ComputeRitz(EigCGArgs &args) {errorQuda("\nUnknown library type.\n");}
134 
135  //pure eigen version:
136  template <> void ComputeRitz<libtype::eigen_lib>(EigCGArgs &args)
137  {
138  const int m = args.m;
139  const int k = args.k;
140  //Solve m dim eigenproblem:
141  SelfAdjointEigenSolver<MatrixXcd> es_tm(args.Tm);
142  args.ritzVecs.leftCols(k) = es_tm.eigenvectors().leftCols(k);
143  //Solve m-1 dim eigenproblem:
144  SelfAdjointEigenSolver<MatrixXcd> es_tm1(Map<MatrixXcd, Unaligned, DynamicStride >(args.Tm.data(), (m-1), (m-1), DynamicStride(m, 1)));
145  Block<MatrixXcd>(args.ritzVecs.derived(), 0, k, m-1, k) = es_tm1.eigenvectors().leftCols(k);
146  args.ritzVecs.block(m-1, k, 1, k).setZero();
147 
148  MatrixXcd Q2k(MatrixXcd::Identity(m, 2*k));
149  HouseholderQR<MatrixXcd> ritzVecs2k_qr( Map<MatrixXcd, Unaligned >(args.ritzVecs.data(), m, 2*k) );
150  Q2k.applyOnTheLeft( ritzVecs2k_qr.householderQ() );
151 
152  //2. Construct H = QH*Tm*Q :
153  args.H2k = Q2k.adjoint()*args.Tm*Q2k;
154 
155  /* solve the small evecm1 2nev x 2nev eigenproblem */
156  SelfAdjointEigenSolver<MatrixXcd> es_h2k(args.H2k);
157  Block<MatrixXcd>(args.ritzVecs.derived(), 0, 0, m, 2*k) = Q2k * es_h2k.eigenvectors();
158  args.Tmvals.segment(0,2*k) = es_h2k.eigenvalues();//this is ok
159 
160  return;
161  }
162 
163  //(supposed to be a pure) magma version:
164  template <> void ComputeRitz<libtype::magma_lib>(EigCGArgs &args)
165  {
166 #ifdef MAGMA_LIB
167  const int m = args.m;
168  const int k = args.k;
169  //Solve m dim eigenproblem:
170  args.ritzVecs = args.Tm;
171  Complex *evecm = static_cast<Complex*>( args.ritzVecs.data());
172  double *evalm = static_cast<double *>( args.Tmvals.data());
173 
174  cudaHostRegister(static_cast<void *>(evecm), m*m*sizeof(Complex), cudaHostRegisterDefault);
175  magma_Xheev(evecm, m, m, evalm, sizeof(Complex));
176  //Solve m-1 dim eigenproblem:
177  DenseMatrix ritzVecsm1(args.Tm);
178  Complex *evecm1 = static_cast<Complex*>( ritzVecsm1.data());
179 
180  cudaHostRegister(static_cast<void *>(evecm1), m*m*sizeof(Complex), cudaHostRegisterDefault);
181  magma_Xheev(evecm1, (m-1), m, evalm, sizeof(Complex));
182  // fill 0s in mth element of old evecs:
183  for(int l = 1; l <= m ; l++) evecm1[l*m-1] = 0.0 ;
184  // Attach the first nev old evecs at the end of the nev latest ones:
185  memcpy(&evecm[k*m], evecm1, k*m*sizeof(Complex));
186 //?
187  // Orthogonalize the 2*nev (new+old) vectors evecm=QR:
188 
189  MatrixXcd Q2k(MatrixXcd::Identity(m, 2*k));
190  HouseholderQR<MatrixXcd> ritzVecs2k_qr( Map<MatrixXcd, Unaligned >(args.ritzVecs.data(), m, 2*k) );
191  Q2k.applyOnTheLeft( ritzVecs2k_qr.householderQ() );
192 
193  //2. Construct H = QH*Tm*Q :
194  args.H2k = Q2k.adjoint()*args.Tm*Q2k;
195 
196  /* solve the small evecm1 2nev x 2nev eigenproblem */
197  SelfAdjointEigenSolver<MatrixXcd> es_h2k(args.H2k);
198  Block<MatrixXcd>(args.ritzVecs.derived(), 0, 0, m, 2*k) = Q2k * es_h2k.eigenvectors();
199  args.Tmvals.segment(0,2*k) = es_h2k.eigenvalues();//this is ok
200 //?
201  cudaHostUnregister(evecm);
202  cudaHostUnregister(evecm1);
203 #else
204  errorQuda("Magma library was not built.\n");
205 #endif
206  return;
207  }
208 
209  // set the required parameters for the inner solver
210  static void fillEigCGInnerSolverParam(SolverParam &inner, const SolverParam &outer, bool use_sloppy_partial_accumulator = true)
211  {
212  inner.tol = outer.tol_precondition;
213  inner.maxiter = outer.maxiter_precondition;
214  inner.delta = 1e-20; // no reliable updates within the inner solver
215  inner.precision = outer.precision_precondition; // preconditioners are uni-precision solvers
217 
218  inner.iter = 0;
219  inner.gflops = 0;
220  inner.secs = 0;
221 
223  inner.is_preconditioner = true; // used to tell the inner solver it is an inner solver
224 
225  inner.use_sloppy_partial_accumulator= use_sloppy_partial_accumulator;
226 
230  }
231 
232  // set the required parameters for the initCG solver
233  static void fillInitCGSolverParam(SolverParam &inner, const SolverParam &outer) {
234  inner.iter = 0;
235  inner.gflops = 0;
236  inner.secs = 0;
237 
238  inner.tol = outer.tol;
239  inner.tol_restart = outer.tol_restart;
240  inner.maxiter = outer.maxiter;
241  inner.delta = outer.delta;
242  inner.precision = outer.precision; // preconditioners are uni-precision solvers
244 
245  inner.inv_type = QUDA_CG_INVERTER; // use CG solver
246  inner.use_init_guess = QUDA_USE_INIT_GUESS_YES;// use deflated initial guess...
247 
248  inner.use_sloppy_partial_accumulator= false;//outer.use_sloppy_partial_accumulator;
249  }
250 
251 
253  Solver(param, profile), mat(mat), matSloppy(matSloppy), matPrecon(matPrecon), K(nullptr), Kparam(param), Vm(nullptr), r_pre(nullptr), p_pre(nullptr), eigcg_args(nullptr), profile(profile), init(false)
254  {
255  if( param.rhs_idx < param.deflation_grid ) printfQuda("\nInitialize eigCG(m=%d, nev=%d) solver.\n", param.m, param.nev);
256  else {
257  printfQuda("\nDeflation space is complete, running initCG solver.\n");
259  //K = new CG(mat, matPrecon, Kparam, profile);//Preconditioned Mat has comms flag on
260  return;
261  }
262 
265  } else if ( param.inv_type == QUDA_INC_EIGCG_INVERTER ) {
266  if(param.inv_type_precondition != QUDA_INVALID_INVERTER) errorQuda("preconditioning is not supported for the incremental solver \n");
268  }
269 
271  K = new CG(matPrecon, matPrecon, Kparam, profile);
273  K = new MR(matPrecon, matPrecon, Kparam, profile);
275  K = new SD(matPrecon, Kparam, profile);
276  }else if(param.inv_type_precondition != QUDA_INVALID_INVERTER){ // unknown preconditioner
277  errorQuda("Unknown inner solver %d", param.inv_type_precondition);
278  }
279  return;
280  }
281 
283 
284  if(init)
285  {
286  if(Vm) delete Vm;
287 
288  delete tmpp;
289  delete rp;
290  delete yp;
291  delete Ap;
292  delete p;
293 
294  if(Az) delete Az;
295 
296  if(K) {
297  delete r_pre;
298  delete p_pre;
299 
300  delete K;
301  }
302  delete eigcg_args;
303  } else if (K) {
304  //delete K; //hack for the init CG solver
305  }
306  }
307 
308  void IncEigCG::RestartVT(const double beta, const double rho)
309  {
311 
313  ComputeRitz<libtype::magma_lib>(args);
314  } else if( param.extlib_type == QUDA_EIGEN_EXTLIB ) {
315  ComputeRitz<libtype::eigen_lib>(args);//if args.m > 128, one may better use libtype::magma_lib
316  } else {
317  errorQuda( "Library type %d is currently not supported.\n",param.extlib_type );
318  }
319 
320  //Restart V:
321 
322  blas::zero(*args.V2k);
323 
324  std::vector<ColorSpinorField*> vm (Vm->Components());
325  std::vector<ColorSpinorField*> v2k(args.V2k->Components());
326 
327  RowMajorDenseMatrix Alpha(args.ritzVecs.topLeftCorner(args.m, 2*args.k));
328  blas::caxpy( static_cast<Complex*>(Alpha.data()), vm , v2k);
329 
330  for(int i = 0; i < 2*args.k; i++) blas::copy(Vm->Component(i), args.V2k->Component(i));
331 
332  //Restart T:
333  ColorSpinorField *omega = nullptr;
334 
335  //Compute Az = Ap - beta*Ap_old(=Az):
336  blas::xpay(*Ap, -beta, *Az);
337 
338  if(Vm->Precision() != Az->Precision())//we may not need this if multiprec blas is used
339  {
340  Vm->Component(args.m-1) = *Az;//use the last vector as a temporary
341  omega = &Vm->Component(args.m-1);
342  }
343  else omega = Az;
344 
345  args.RestartLanczos(omega, Vm, 1.0 / rho);
346  return;
347  }
348 
349  void IncEigCG::UpdateVm(ColorSpinorField &res, double beta, double sqrtr2)
350  {
352 
353  if(args.run_residual_correction) return;
354 
355  if (args.id == param.m){//Begin Rayleigh-Ritz block:
356  //
357  RestartVT(beta, sqrtr2);
358  args.ResetSearchIdx();
359  } else if (args.id == (param.m-1)) {
360  blas::copy(*Az, *Ap);//save current mat-vec result if ready for the restart in the next cycle
361  }
362 
363  //load Lanczos basis vector:
364  blas::copy(Vm->Component(args.id), res);//convert arrays
365  //rescale the vector
366  blas::ax(1.0 / sqrtr2, Vm->Component(args.id));
367  return;
368  }
369 
370 /*
371  * This is a solo precision solver.
372 */
374 
375  int k=0;
376 
377  if (checkLocation(x, b) != QUDA_CUDA_FIELD_LOCATION) errorQuda("Not supported");
378 
379  profile.TPSTART(QUDA_PROFILE_INIT);
380 
381  // Check to see that we're not trying to invert on a zero-field source
382  const double b2 = blas::norm2(b);
383  if (b2 == 0) {
384  profile.TPSTOP(QUDA_PROFILE_INIT);
385  printfQuda("Warning: inverting on zero-field source\n");
386  x = b;
387  param.true_res = 0.0;
388  param.true_res_hq = 0.0;
389  return 0;
390  }
391 
393 
394  if (!init) {
395  eigcg_args = new EigCGArgs(param.m, param.nev);//need only deflation meta structure
396 
401 
404 
406 
408 
410  csParam.setPrecision(param.precision_precondition);
413  }
414 
415  //Create a search vector set:
416  csParam.setPrecision(param.precision_ritz);//eigCG internal search space precision may not coincide with the solver precision!
417  csParam.is_composite = true;
418  csParam.composite_dim = param.m;
419 
420  Vm = ColorSpinorFieldSet::Create(csParam); //search space for Ritz vectors
421 
422  eigcg_args->global_stop = stopping(param.tol, b2, param.residual_type); // stopping condition of solver
423 
424  init = true;
425  }
426 
427  double local_stop = x.Precision() == QUDA_DOUBLE_PRECISION ? b2*param.tol*param.tol : b2*1e-11;
428 
430 
431  if(args.run_residual_correction && param.inv_type == QUDA_INC_EIGCG_INVERTER) {
432  profile.TPSTOP(QUDA_PROFILE_INIT);
433  (*K)(x, b);
434  return Kparam.iter;
435  }
436 
438  csParam.setPrecision(QUDA_DOUBLE_PRECISION);
439 
441  csParam.is_composite = true;
442  csParam.composite_dim = (2*args.k);
443 
444  args.V2k = ColorSpinorFieldSet::Create(csParam); //search space for Ritz vectors
446  ColorSpinorField &r = *rp;
447  ColorSpinorField &y = *yp;
449 
450  csParam.setPrecision(param.precision_sloppy);
451  csParam.is_composite = false;
452 
453  // compute initial residual
454  matSloppy(r, x, y);
455  double r2 = blas::xmyNorm(b, r);
456 
457  ColorSpinorField *z = (K != nullptr) ? ColorSpinorField::Create(csParam) : rp;//
458 
459  if( K ) {//apply preconditioner
461 
462  ColorSpinorField &rPre = *r_pre;
463  ColorSpinorField &pPre = *p_pre;
464 
465  blas::copy(rPre, r);
466  commGlobalReductionSet(false);
467  (*K)(pPre, rPre);
469  blas::copy(*z, pPre);
470  }
471 
472  *p = *z;
473  blas::zero(y);
474 
475  const bool use_heavy_quark_res =
476  (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) ? true : false;
477 
478  profile.TPSTOP(QUDA_PROFILE_INIT);
480 
481  double heavy_quark_res = 0.0; // heavy quark res idual
482 
483  if (use_heavy_quark_res) heavy_quark_res = sqrt(blas::HeavyQuarkResidualNorm(x, r).z);
484 
485  double pAp;
486  double alpha=1.0, alpha_inv=1.0, beta=0.0, alpha_old_inv = 1.0;
487 
488  double lanczos_diag, lanczos_offdiag;
489 
491  profile.TPSTART(QUDA_PROFILE_COMPUTE);
492  blas::flops = 0;
493 
494  double rMinvr = blas::reDotProduct(r,*z);
495  //Begin EigCG iterations:
496  args.restarts = 0;
497 
498  PrintStats("eigCG", k, r2, b2, heavy_quark_res);
499 
500  bool converged = convergence(r2, heavy_quark_res, args.global_stop, param.tol_hq);
501 
502  while ( !converged && k < param.maxiter ) {
503  matSloppy(*Ap, *p, tmp); // tmp as tmp
504 
505  pAp = blas::reDotProduct(*p, *Ap);
506  alpha_old_inv = alpha_inv;
507  alpha = rMinvr / pAp;
508  alpha_inv = 1.0 / alpha;
509 
510  lanczos_diag = (alpha_inv + beta*alpha_old_inv);
511 
512  UpdateVm(*z, beta, sqrt(r2));
513 
514  r2 = blas::axpyNorm(-alpha, *Ap, r);
515  if( K ) {//apply preconditioner
516  ColorSpinorField &rPre = *r_pre;
517  ColorSpinorField &pPre = *p_pre;
518 
519  blas::copy(rPre, r);
520  commGlobalReductionSet(false);
521  (*K)(pPre, rPre);
523  blas::copy(*z, pPre);
524  }
525  //
526  double rMinvr_old = rMinvr;
527  rMinvr = K ? blas::reDotProduct(r,*z) : r2;
528  beta = rMinvr / rMinvr_old;
529  blas::axpyZpbx(alpha, *p, y, *z, beta);
530 
531  //
532  lanczos_offdiag = (-sqrt(beta)*alpha_inv);
533  args.SetLanczos(lanczos_diag, lanczos_offdiag);
534 
535  k++;
536 
537  PrintStats("eigCG", k, r2, b2, heavy_quark_res);
538  // check convergence, if convergence is satisfied we only need to check that we had a reliable update for the heavy quarks recently
539  converged = convergence(r2, heavy_quark_res, args.global_stop, param.tol_hq) or convergence(r2, heavy_quark_res, local_stop, param.tol_hq);
540  }
541 
542  args.ResetArgs();//eigCG cycle finished, this cleans V2k as well
543 
544  blas::xpy(y, x);
545 
548 
550  double gflops = (blas::flops + matSloppy.flops())*1e-9;
551  param.gflops = gflops;
552  param.iter += k;
553 
554  if (k == param.maxiter)
555  warningQuda("Exceeded maximum iterations %d", param.maxiter);
556 
557  // compute the true residuals
558  matSloppy(r, x, y);
559  param.true_res = sqrt(blas::xmyNorm(b, r) / b2);
561 
562  PrintSummary("eigCG", k, r2, b2);
563 
564  // reset the flops counters
565  blas::flops = 0;
566  matSloppy.flops();
567 
569  profile.TPSTART(QUDA_PROFILE_FREE);
570 
571  profile.TPSTOP(QUDA_PROFILE_FREE);
572  return k;
573  }
574 
576  int k = 0;
577  //Start init CG iterations:
578  deflated_solver *defl_p = static_cast<deflated_solver*>(param.deflation_op);
579  Deflation &defl = *(defl_p->defl);
580 
581  const double full_tol = Kparam.tol;
583 
585 
587 
588  ColorSpinorField *tmpp2 = ColorSpinorField::Create(csParam);//full precision accumulator
589  ColorSpinorField &tmp2 = *tmpp2;
590  ColorSpinorField *rp = ColorSpinorField::Create(csParam);//full precision residual
591  ColorSpinorField &r = *rp;
592 
593  csParam.setPrecision(param.precision_ritz);
594 
596  ColorSpinorField &xProj = *xp_proj;
597 
599  ColorSpinorField &rProj = *rp_proj;
600 
601  int restart_idx = 0;
602 
603  xProj = x;
604  rProj = b;
605  //launch initCG:
606  while((Kparam.tol >= full_tol) && (restart_idx < param.max_restart_num)) {
607  restart_idx += 1;
608 
609  defl(xProj, rProj);
610  x = xProj;
611 
612  K = new CG(mat, matPrecon, Kparam, profile);
613  (*K)(x, b);
614  delete K;
615 
616  mat(r, x, tmp2);
617  blas::xpay(b, -1.0, r);
618 
619  xProj = x;
620  rProj = r;
621 
622  if(getVerbosity() >= QUDA_VERBOSE) printfQuda("\ninitCG stat: %i iter / %g secs = %g Gflops. \n", Kparam.iter, Kparam.secs, Kparam.gflops);
623 
624  Kparam.tol *= param.inc_tol;
625 
626  if(restart_idx == (param.max_restart_num-1)) Kparam.tol = full_tol;//do the last solve in the next cycle to full tolerance
627 
628  param.secs += Kparam.secs;
629  }
630 
631  if(getVerbosity() >= QUDA_VERBOSE) printfQuda("\ninitCG stat: %i iter / %g secs = %g Gflops. \n", Kparam.iter, Kparam.secs, Kparam.gflops);
632  //
633  param.secs += Kparam.secs;
635 
636  k += Kparam.iter;
637 
638  delete rp;
639  delete tmpp2;
640 
642  delete xp_proj;
643  delete rp_proj;
644  }
645  return k;
646  }
647 
649  {
651 
652  const bool mixed_prec = (param.precision != param.precision_sloppy);
653  const double b2 = norm2(in);
654 
655  deflated_solver *defl_p = static_cast<deflated_solver*>(param.deflation_op);
656  Deflation &defl = *(defl_p->defl);
657 
658  //If deflation space is complete: use initCG solver
659  if( defl.is_complete() ) {
660 
661  if(K) errorQuda("\nInitCG does not (yet) support preconditioning.\n");
662 
663  int iters = initCGsolve(out, in);
664  param.iter += iters;
665 
666  return;
667  }
668 
669  //Start (incremental) eigCG solver:
672 
673  ColorSpinorField *ep = ColorSpinorField::Create(csParam);//full precision accumulator
674  ColorSpinorField &e = *ep;
675  ColorSpinorField *rp = ColorSpinorField::Create(csParam);//full precision residual
676  ColorSpinorField &r = *rp;
677 
678  //deflate initial guess ('out'-field):
679  mat(r, out, e);
680  //
681  double r2 = xmyNorm(in, r);
682 
683  csParam.setPrecision(param.precision_sloppy);
684 
685  ColorSpinorField *ep_sloppy = ( mixed_prec ) ? ColorSpinorField::Create(csParam) : ep;
686  ColorSpinorField &eSloppy = *ep_sloppy;
687  ColorSpinorField *rp_sloppy = ( mixed_prec ) ? ColorSpinorField::Create(csParam) : rp;
688  ColorSpinorField &rSloppy = *rp_sloppy;
689 
690  const double stop = b2*param.tol*param.tol;
691  //start iterative refinement cycles (or just one eigcg call for full (solo) precision solver):
692  int logical_rhs_id = 0;
693  bool dcg_cycle = false;
694  do {
695  blas::zero(e);
696  defl(e, r);
697  //
698  eSloppy = e, rSloppy = r;
699 
700  if( dcg_cycle ) { //run DCG instead
701  if(!K) {
703  Kparam.tol = 5*param.inc_tol;//former cg_iterref_tol param
704  K = new CG(matSloppy, matPrecon, Kparam, profile);
705  }
706 
708  printfQuda("Running DCG correction cycle.\n");
709  }
710 
711  int iters = eigCGsolve(eSloppy, rSloppy);
712 
713  bool update_ritz = !dcg_cycle && (eigcg_args->restarts > 1) && !defl.is_complete(); //too uglyyy
714 
715  if( update_ritz ) {
716 
717  defl.increment(*Vm, param.nev);
718  logical_rhs_id += 1;
719 
720  dcg_cycle = (logical_rhs_id >= max_eigcg_cycles);
721 
722  } else { //run DCG instead
723  dcg_cycle = true;
724  }
725 
726  // use mixed blas ??
727  e = eSloppy;
728  blas::xpy(e, out);
729  // compute the true residuals
730  blas::zero(e);
731  mat(r, out, e);
732  //
733  r2 = blas::xmyNorm(in, r);
734 
735  param.true_res = sqrt(r2 / b2);
737  PrintSummary( !dcg_cycle ? "EigCG:" : "DCG (correction cycle):", iters, r2, b2);
738 
739  if( getVerbosity() >= QUDA_VERBOSE ) {
740  if( !dcg_cycle && (eigcg_args->restarts > 1) && !defl.is_complete() ) defl.verify();
741  }
742  } while ((r2 > stop) && mixed_prec);
743 
744  delete ep;
745  delete rp;
746 
747  if(mixed_prec){
748  delete ep_sloppy;
749  delete rp_sloppy;
750  }
751 
752  if (mixed_prec && max_eigcg_cycles > logical_rhs_id) {
753  printfQuda("Reset maximum eigcg cycles to %d (was %d)\n", logical_rhs_id, max_eigcg_cycles);
754  max_eigcg_cycles = logical_rhs_id;//adjust maximum allowed cycles based on the actual information
755  }
756 
757  param.rhs_idx += logical_rhs_id;
758 
759  if(defl.is_complete()) {
760  if(param.rhs_idx != param.deflation_grid) warningQuda("\nTotal rhs number (%d) does not match the deflation grid size (%d).\n", param.rhs_idx, param.deflation_grid);
761  if(Vm) delete Vm;//safe some space
762  Vm = nullptr;
763 
764  const int max_nev = defl.size();//param.m;
765  printfQuda("\nRequested to reserve %d eigenvectors with max tol %le.\n", max_nev, param.eigenval_tol);
766  defl.reduce(param.eigenval_tol, max_nev);
767  }
768  return;
769  }
770 
771 } // namespace quda
bool convergence(const double &r2, const double &hq2, const double &r2_tol, const double &hq_tol)
Definition: solver.cpp:139
VectorXd RealVector
void xpay(ColorSpinorField &x, const double &a, ColorSpinorField &y)
Definition: blas_quda.cu:173
ColorSpinorField * Az
previous mat-vec result (A p), saved for the Lanczos restart
Definition: invert_quda.h:830
static double stopping(const double &tol, const double &b2, QudaResidualType residual_type)
Definition: solver.cpp:122
ColorSpinorField * p_pre
preconditioner result
Definition: invert_quda.h:832
DiracMatrix & matSloppy
Definition: invert_quda.h:817
QudaInverterType inv_type
Definition: invert_quda.h:19
ColorSpinorField * tmpp
Definition: invert_quda.h:829
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:20
DiracMatrix & mat
Definition: invert_quda.h:816
#define errorQuda(...)
Definition: util_quda.h:90
double norm2(const ColorSpinorField &a)
Definition: reduce_quda.cu:241
void init()
Definition: blas_quda.cu:64
__host__ __device__ ValueType sqrt(ValueType x)
Definition: complex_quda.h:105
Complex cDotProduct(ColorSpinorField &, ColorSpinorField &)
Definition: reduce_quda.cu:500
std::complex< double > Complex
Definition: eig_variables.h:13
ColorSpinorField * Ap
Definition: invert_quda.h:828
cudaColorSpinorField * tmp
Definition: covdev_test.cpp:44
IncEigCG(DiracMatrix &mat, DiracMatrix &matSloppy, DiracMatrix &matPrecon, SolverParam &param, TimeProfile &profile)
double axpyNorm(const double &a, ColorSpinorField &x, ColorSpinorField &y)
Definition: reduce_quda.cu:325
CompositeColorSpinorField & Components()
static ColorSpinorField * Create(const ColorSpinorParam &param)
int size()
return deflation space size
Definition: deflation.h:164
int initCGsolve(ColorSpinorField &out, ColorSpinorField &in)
double reDotProduct(ColorSpinorField &x, ColorSpinorField &y)
Definition: reduce_quda.cu:277
void copy(ColorSpinorField &dst, const ColorSpinorField &src)
Definition: copy_quda.cu:263
void ax(const double &a, ColorSpinorField &x)
Definition: blas_quda.cu:209
ColorSpinorField & Component(const int idx) const
double xmyNorm(ColorSpinorField &x, ColorSpinorField &y)
Definition: reduce_quda.cu:364
int eigCGsolve(ColorSpinorField &out, ColorSpinorField &in)
QudaPrecision precision_ritz
Definition: invert_quda.h:185
QudaInverterType inv_type_precondition
Definition: invert_quda.h:25
QudaPreserveSource preserve_source
Definition: invert_quda.h:121
void reduce(double tol, int max_nev)
Definition: deflation.cpp:285
double norm2(const CloverField &a, bool inverse=false)
QudaGaugeParam param
Definition: pack_test.cpp:17
#define b
bool is_complete()
Test whether the deflation space is complete and therefore cannot be further extended ...
Definition: deflation.h:159
void ComputeRitz(EigCGArgs &args)
double Last(QudaProfileType idx)
void PrintSummary(const char *name, int k, const double &r2, const double &b2)
Definition: solver.cpp:194
void magma_Xheev(void *Mat, const int n, const int ldm, void *evalues, const int prec)
Definition: blas_magma.cu:299
QudaResidualType residual_type
Definition: invert_quda.h:47
void axpyZpbx(const double &a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, const double &b)
Definition: blas_quda.cu:384
Stride< Dynamic, Dynamic > DynamicStride
Definition: deflation.cpp:22
#define tmp2
Definition: tmc_core.h:16
ColorSpinorParam csParam
Definition: pack_test.cpp:24
cpuColorSpinorField * in
def id
projector matrices ######################################################################## ...
int int int w
TimeProfile & profile
Definition: invert_quda.h:836
void UpdateVm(ColorSpinorField &res, double beta, double sqrtr2)
static void fillInitCGSolverParam(SolverParam &inner, const SolverParam &outer)
#define warningQuda(...)
Definition: util_quda.h:101
#define checkLocation(...)
static void fillEigCGInnerSolverParam(SolverParam &inner, const SolverParam &outer, bool use_sloppy_partial_accumulator=true)
bool is_preconditioner
used to tell the inner solver it is an inner solver
Definition: invert_quda.h:199
static int max_eigcg_cycles
double3 HeavyQuarkResidualNorm(ColorSpinorField &x, ColorSpinorField &r)
Definition: reduce_quda.cu:703
Deflation * defl
Definition: deflation.h:189
ColorSpinorField * rp
residual vector
Definition: invert_quda.h:825
void caxpy(const Complex &a, ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.cu:246
double tol_precondition
Definition: invert_quda.h:164
void * memcpy(void *__dst, const void *__src, size_t __n)
QudaExtLibType extlib_type
external library used for the Rayleigh-Ritz eigensolve (MAGMA or Eigen)
Definition: invert_quda.h:204
void zero(ColorSpinorField &a)
Definition: blas_quda.cu:45
void SetLanczos(Complex diag_val, Complex offdiag_val)
dim3 dim3 void ** args
QudaPrecision precision_precondition
Definition: invert_quda.h:118
QudaPrecision precision
Definition: invert_quda.h:112
SolverParam & param
Definition: invert_quda.h:328
cpuColorSpinorField * out
ColorSpinorField * r_pre
residual passed to preconditioner
Definition: invert_quda.h:831
ColorSpinorField * yp
accumulator for the solution update (y += alpha p)
Definition: invert_quda.h:826
unsigned long long flops() const
Definition: dirac_quda.h:995
void PrintStats(const char *, int k, const double &r2, const double &b2, const double &hq2)
Definition: solver.cpp:179
MatrixXcd DenseMatrix
EigCGArgs * eigcg_args
eigCG working state (Lanczos matrix, Ritz vectors/values, search index, restart count)
Definition: invert_quda.h:834
#define printfQuda(...)
Definition: util_quda.h:84
unsigned long long flops
Definition: blas_quda.cu:42
void xpy(ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.cu:128
DiracMatrix & matPrecon
Definition: invert_quda.h:818
VectorXcd Vector
void increment(ColorSpinorField &V, int nev)
Definition: deflation.cpp:207
ColorSpinorField * p
search direction
Definition: invert_quda.h:827
void RestartLanczos(ColorSpinorField *w, ColorSpinorFieldSet *v, const double inv_sqrt_r2)
QudaUseInitGuess use_init_guess
Definition: invert_quda.h:50
ColorSpinorFieldSet * V2k
QudaPrecision precision_sloppy
Definition: invert_quda.h:115
bool use_sloppy_partial_accumulator
Definition: invert_quda.h:59
SolverParam Kparam
Definition: invert_quda.h:821
double omega
Definition: test_util.cpp:1663
void mat(void *out, void **link, void *in, int dagger_bit, int mu, QudaPrecision sPrecision, QudaPrecision gPrecision)
QudaPrecision Precision() const
void RestartVT(const double beta, const double rho)
MatrixXcd VectorSet
EigCGArgs(int m, int k)
ColorSpinorFieldSet * Vm
Definition: invert_quda.h:823
void commGlobalReductionSet(bool global_reduce)
void operator()(ColorSpinorField &out, ColorSpinorField &in)