QUDA  v1.1.0
A library for QCD on GPUs
inv_eigcg_quda.cpp
Go to the documentation of this file.
1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <math.h>
4 #include <memory>
5 #include <iostream>
6 
7 #include <quda_internal.h>
8 #include <color_spinor_field.h>
9 #include <blas_quda.h>
10 #include <dslash_quda.h>
11 #include <invert_quda.h>
12 #include <util_quda.h>
13 #include <string.h>
14 
15 #ifdef MAGMA_LIB
16 #include <blas_magma.h>
17 #endif
18 
19 #include <eigen_helper.h>
20 #include <deflation.h>
21 
22 /*
23 Based on eigCG(n_ev, m) algorithm:
24 A. Stathopoulos and K. Orginos, arXiv:0707.0131
25 */
26 
27 namespace quda {
28 
29  using namespace blas;
30 
31  using DynamicStride = Stride<Dynamic, Dynamic>;
32  using DenseMatrix = MatrixXcd;
33  using VectorSet = MatrixXcd;
34  using Vector = VectorXcd;
35  using RealVector = VectorXd;
36 
37 //special types needed for compatibility with QUDA blas:
39 
40  static int max_eigcg_cycles = 4;//how many eigcg cycles do we allow?
41 
43 
44  class EigCGArgs
45  {
46 
47  public:
48  // host Lanczos matrice, and its eigenvalue/vector arrays:
49  DenseMatrix Tm; // VH A V,
50  // eigenvectors:
51  VectorSet ritzVecs; // array of (m) ritz and of m length
52  // eigenvalues of both T[m, m ] and T[m-1, m-1] (re-used)
53  RealVector Tmvals; // eigenvalues of T[m, m ] and T[m-1, m-1] (re-used)
54  // Aux matrix for computing 2k Ritz vectors:
56 
57  int m;
58  int k;
59  int id; // cuurent search spase index
60 
61  int restarts;
62  double global_stop;
63 
64  bool run_residual_correction; // used in mixed precision cycles
65 
67  *V2k; // eigCG accumulation vectors needed to update Tm (spinor matrix of size eigen_vector_length x (2*k))
68 
69  EigCGArgs(int m, int k) :
70  Tm(DenseMatrix::Zero(m, m)),
71  ritzVecs(VectorSet::Zero(m, m)),
72  Tmvals(m),
73  H2k(2 * k, 2 * k),
74  m(m),
75  k(k),
76  id(0),
77  restarts(0),
78  global_stop(0.0),
79  run_residual_correction(false),
80  V2k(nullptr)
81  {
82  }
83 
85  {
86  if (V2k) delete V2k;
87  }
88 
89  // method for constructing Lanczos matrix :
90  inline void SetLanczos(Complex diag_val, Complex offdiag_val)
91  {
92  if (run_residual_correction) return;
93 
94  Tm.diagonal<0>()[id] = diag_val;
95 
96  if (id < (m - 1)) { // Load Lanczos off-diagonals:
97  Tm.diagonal<+1>()[id] = offdiag_val;
98  Tm.diagonal<-1>()[id] = offdiag_val;
99  }
100 
101  id += 1;
102 
103  return;
104  }
105 
106  inline void ResetArgs()
107  {
108  id = 0;
109  Tm.setZero();
110  Tmvals.setZero();
111  ritzVecs.setZero();
112 
113  if (V2k) delete V2k;
114  V2k = nullptr;
115  }
116 
117  inline void ResetSearchIdx()
118  {
119  id = 2 * k;
120  restarts += 1;
121  }
122 
123  void RestartLanczos(ColorSpinorField *w, ColorSpinorFieldSet *v, const double inv_sqrt_r2)
124  {
125  Tm.setZero();
126 
127  auto s = std::make_unique<Complex[]>(2 * k);
128 
129  for (int i = 0; i < 2 * k; i++) Tm(i, i) = Tmvals(i); //??
130 
131  std::vector<ColorSpinorField *> w_;
132  w_.push_back(w);
133 
134  std::vector<ColorSpinorField *> v_(v->Components().begin(), v->Components().begin() + 2 * k);
135 
136  blas::cDotProduct(s.get(), w_, v_);
137 
138  Map<VectorXcd, Unaligned> s_(s.get(), 2 * k);
139  s_ *= inv_sqrt_r2;
140 
141  Tm.col(2 * k).segment(0, 2 * k) = s_;
142  Tm.row(2 * k).segment(0, 2 * k) = s_.adjoint();
143  }
144  };
145 
146  //Rayleigh Ritz procedure:
147  template <libtype which_lib> void ComputeRitz(EigCGArgs &args) { errorQuda("\nUnknown library type."); }
148 
149  //pure eigen version:
150  template <> void ComputeRitz<libtype::eigen_lib>(EigCGArgs &args)
151  {
152  const int m = args.m;
153  const int k = args.k;
154  //Solve m dim eigenproblem:
155  SelfAdjointEigenSolver<MatrixXcd> es_tm(args.Tm);
156  args.ritzVecs.leftCols(k) = es_tm.eigenvectors().leftCols(k);
157  //Solve m-1 dim eigenproblem:
158  SelfAdjointEigenSolver<MatrixXcd> es_tm1(Map<MatrixXcd, Unaligned, DynamicStride >(args.Tm.data(), (m-1), (m-1), DynamicStride(m, 1)));
159  Block<MatrixXcd>(args.ritzVecs.derived(), 0, k, m-1, k) = es_tm1.eigenvectors().leftCols(k);
160  args.ritzVecs.block(m-1, k, 1, k).setZero();
161 
162  MatrixXcd Q2k(MatrixXcd::Identity(m, 2*k));
163  HouseholderQR<MatrixXcd> ritzVecs2k_qr( Map<MatrixXcd, Unaligned >(args.ritzVecs.data(), m, 2*k) );
164  Q2k.applyOnTheLeft( ritzVecs2k_qr.householderQ() );
165 
166  //2. Construct H = QH*Tm*Q :
167  args.H2k = Q2k.adjoint()*args.Tm*Q2k;
168 
169  /* solve the small evecm1 2n_ev x 2n_ev eigenproblem */
170  SelfAdjointEigenSolver<MatrixXcd> es_h2k(args.H2k);
171  Block<MatrixXcd>(args.ritzVecs.derived(), 0, 0, m, 2*k) = Q2k * es_h2k.eigenvectors();
172  args.Tmvals.segment(0,2*k) = es_h2k.eigenvalues();//this is ok
173 
174  return;
175  }
176 
177  //(supposed to be a pure) magma version:
178  template <> void ComputeRitz<libtype::magma_lib>(EigCGArgs &args)
179  {
180 #ifdef MAGMA_LIB
181  const int m = args.m;
182  const int k = args.k;
183  //Solve m dim eigenproblem:
184  args.ritzVecs = args.Tm;
185  Complex *evecm = static_cast<Complex*>( args.ritzVecs.data());
186  double *evalm = static_cast<double *>( args.Tmvals.data());
187 
188  cudaHostRegister(static_cast<void *>(evecm), m*m*sizeof(Complex), cudaHostRegisterDefault);
189  magma_Xheev(evecm, m, m, evalm, sizeof(Complex));
190  //Solve m-1 dim eigenproblem:
191  DenseMatrix ritzVecsm1(args.Tm);
192  Complex *evecm1 = static_cast<Complex*>( ritzVecsm1.data());
193 
194  cudaHostRegister(static_cast<void *>(evecm1), m*m*sizeof(Complex), cudaHostRegisterDefault);
195  magma_Xheev(evecm1, (m-1), m, evalm, sizeof(Complex));
196  // fill 0s in mth element of old evecs:
197  for(int l = 1; l <= m ; l++) evecm1[l*m-1] = 0.0 ;
198  // Attach the first n_ev old evecs at the end of the n_ev latest ones:
199  memcpy(&evecm[k*m], evecm1, k*m*sizeof(Complex));
200  //?
201  // Orthogonalize the 2*n_ev (new+old) vectors evecm=QR:
202 
203  MatrixXcd Q2k(MatrixXcd::Identity(m, 2*k));
204  HouseholderQR<MatrixXcd> ritzVecs2k_qr( Map<MatrixXcd, Unaligned >(args.ritzVecs.data(), m, 2*k) );
205  Q2k.applyOnTheLeft( ritzVecs2k_qr.householderQ() );
206 
207  //2. Construct H = QH*Tm*Q :
208  args.H2k = Q2k.adjoint()*args.Tm*Q2k;
209 
210  /* solve the small evecm1 2n_ev x 2n_ev eigenproblem */
211  SelfAdjointEigenSolver<MatrixXcd> es_h2k(args.H2k);
212  Block<MatrixXcd>(args.ritzVecs.derived(), 0, 0, m, 2*k) = Q2k * es_h2k.eigenvectors();
213  args.Tmvals.segment(0,2*k) = es_h2k.eigenvalues();//this is ok
214 //?
215  cudaHostUnregister(evecm);
216  cudaHostUnregister(evecm1);
217 #else
218  errorQuda("Magma library was not built.");
219 #endif
220  return;
221  }
222 
223  // set the required parameters for the inner solver
224  static void fillEigCGInnerSolverParam(SolverParam &inner, const SolverParam &outer, bool use_sloppy_partial_accumulator = true)
225  {
226  inner.tol = outer.tol_precondition;
227  inner.maxiter = outer.maxiter_precondition;
228  inner.delta = 1e-20; // no reliable updates within the inner solver
229  inner.precision = outer.precision_precondition; // preconditioners are uni-precision solvers
230  inner.precision_sloppy = outer.precision_precondition;
231 
232  inner.iter = 0;
233  inner.gflops = 0;
234  inner.secs = 0;
235 
236  inner.inv_type_precondition = QUDA_INVALID_INVERTER;
237  inner.is_preconditioner = true; // used to tell the inner solver it is an inner solver
238 
239  inner.use_sloppy_partial_accumulator= use_sloppy_partial_accumulator;
240 
241  if(outer.inv_type == QUDA_EIGCG_INVERTER && outer.precision_sloppy != outer.precision_precondition)
242  inner.preserve_source = QUDA_PRESERVE_SOURCE_NO;
243  else inner.preserve_source = QUDA_PRESERVE_SOURCE_YES;
244  }
245 
246  // set the required parameters for the initCG solver
247  static void fillInitCGSolverParam(SolverParam &inner, const SolverParam &outer) {
248  inner.iter = 0;
249  inner.gflops = 0;
250  inner.secs = 0;
251 
252  inner.tol = outer.tol;
253  inner.tol_restart = outer.tol_restart;
254  inner.maxiter = outer.maxiter;
255  inner.delta = outer.delta;
256  inner.precision = outer.precision; // preconditioners are uni-precision solvers
257  inner.precision_sloppy = outer.precision_precondition;
258 
259  inner.inv_type = QUDA_CG_INVERTER; // use CG solver
260  inner.use_init_guess = QUDA_USE_INIT_GUESS_YES;// use deflated initial guess...
261 
262  inner.use_sloppy_partial_accumulator= false;//outer.use_sloppy_partial_accumulator;
263  }
264 
265  IncEigCG::IncEigCG(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon,
266  SolverParam &param, TimeProfile &profile) :
267  Solver(mat, matSloppy, matPrecon, matPrecon, param, profile),
268  K(nullptr),
269  Kparam(param),
270  Vm(nullptr),
271  r_pre(nullptr),
272  p_pre(nullptr),
273  eigcg_args(nullptr),
274  profile(profile),
275  init(false)
276  {
277 
278  if (2 * param.n_ev >= param.m)
279  errorQuda(
280  "Incorrect number of the requested low modes: m= %d while n_ev=%d (note that 2*n_ev must be less then m).",
281  param.m, param.n_ev);
282 
284  printfQuda("\nInitialize eigCG(m=%d, n_ev=%d) solver.", param.m, param.n_ev);
285  else {
286  printfQuda("\nDeflation space is complete, running initCG solver.");
287  fillInitCGSolverParam(Kparam, param);
288  //K = new CG(mat, matPrecon, Kparam, profile);//Preconditioned Mat has comms flag on
289  return;
290  }
291 
293  fillEigCGInnerSolverParam(Kparam, param);
294  } else if ( param.inv_type == QUDA_INC_EIGCG_INVERTER ) {
296  errorQuda("preconditioning is not supported for the incremental solver.");
297  fillInitCGSolverParam(Kparam, param);
298  }
299 
301  K = new CG(matPrecon, matPrecon, matPrecon, matPrecon, Kparam, profile);
303  K = new MR(matPrecon, matPrecon, Kparam, profile);
305  K = new SD(matPrecon, Kparam, profile);
306  }else if(param.inv_type_precondition != QUDA_INVALID_INVERTER){ // unknown preconditioner
307  errorQuda("Unknown inner solver %d", param.inv_type_precondition);
308  }
309  return;
310  }
311 
313 
314  if(init)
315  {
316  if(Vm) delete Vm;
317 
318  delete tmpp;
319  delete rp;
320  delete yp;
321  delete Ap;
322  delete p;
323 
324  if(Az) delete Az;
325 
326  if(K) {
327  delete r_pre;
328  delete p_pre;
329 
330  delete K;
331  }
332  delete eigcg_args;
333  } else if (K) {
334  //delete K; //hack for the init CG solver
335  }
336  }
337 
338  void IncEigCG::RestartVT(const double beta, const double rho)
339  {
340  EigCGArgs &args = *eigcg_args;
341 
343  ComputeRitz<libtype::magma_lib>(args);
344  } else if( param.extlib_type == QUDA_EIGEN_EXTLIB ) {
345  ComputeRitz<libtype::eigen_lib>(args);//if args.m > 128, one may better use libtype::magma_lib
346  } else {
347  errorQuda("Library type %d is currently not supported.", param.extlib_type);
348  }
349 
350  //Restart V:
351 
352  blas::zero(*args.V2k);
353 
354  std::vector<ColorSpinorField*> vm (Vm->Components());
355  std::vector<ColorSpinorField*> v2k(args.V2k->Components());
356 
357  RowMajorDenseMatrix Alpha(args.ritzVecs.topLeftCorner(args.m, 2*args.k));
358  blas::caxpy( static_cast<Complex*>(Alpha.data()), vm , v2k);
359 
360  for(int i = 0; i < 2*args.k; i++) blas::copy(Vm->Component(i), args.V2k->Component(i));
361 
362  //Restart T:
363  ColorSpinorField *omega = nullptr;
364 
365  //Compute Az = Ap - beta*Ap_old(=Az):
366  blas::xpay(*Ap, -beta, *Az);
367 
368  if(Vm->Precision() != Az->Precision())//we may not need this if multiprec blas is used
369  {
370  Vm->Component(args.m-1) = *Az;//use the last vector as a temporary
371  omega = &Vm->Component(args.m-1);
372  }
373  else omega = Az;
374 
375  args.RestartLanczos(omega, Vm, 1.0 / rho);
376  return;
377  }
378 
379  void IncEigCG::UpdateVm(ColorSpinorField &res, double beta, double sqrtr2)
380  {
381  EigCGArgs &args = *eigcg_args;
382 
383  if(args.run_residual_correction) return;
384 
385  if (args.id == param.m){//Begin Rayleigh-Ritz block:
386  //
387  RestartVT(beta, sqrtr2);
388  args.ResetSearchIdx();
389  } else if (args.id == (param.m-1)) {
390  blas::copy(*Az, *Ap);//save current mat-vec result if ready for the restart in the next cycle
391  }
392 
393  //load Lanczos basis vector:
394  blas::copy(Vm->Component(args.id), res);//convert arrays
395  //rescale the vector
396  blas::ax(1.0 / sqrtr2, Vm->Component(args.id));
397  return;
398  }
399 
400 /*
401  * This is a solo precision solver.
402 */
404 
405  int k=0;
406 
407  if (checkLocation(x, b) != QUDA_CUDA_FIELD_LOCATION) errorQuda("Not supported");
408 
409  profile.TPSTART(QUDA_PROFILE_INIT);
410 
411  // Check to see that we're not trying to invert on a zero-field source
412  const double b2 = blas::norm2(b);
413  if (b2 == 0) {
414  profile.TPSTOP(QUDA_PROFILE_INIT);
415  printfQuda("Warning: inverting on zero-field source\n");
416  x = b;
417  param.true_res = 0.0;
418  param.true_res_hq = 0.0;
419  return 0;
420  }
421 
423 
424  if (!init) {
425  eigcg_args = new EigCGArgs(param.m, param.n_ev); // need only deflation meta structure
426 
431 
434 
436 
438 
440  csParam.setPrecision(param.precision_precondition);
443  }
444 
445  //Create a search vector set:
446  csParam.setPrecision(param.precision_ritz);//eigCG internal search space precision may not coincide with the solver precision!
447  csParam.is_composite = true;
448  csParam.composite_dim = param.m;
449 
450  Vm = ColorSpinorFieldSet::Create(csParam); //search space for Ritz vectors
451 
452  eigcg_args->global_stop = stopping(param.tol, b2, param.residual_type); // stopping condition of solver
453 
454  init = true;
455  }
456 
457  double local_stop = x.Precision() == QUDA_DOUBLE_PRECISION ? b2*param.tol*param.tol : b2*1e-11;
458 
459  EigCGArgs &args = *eigcg_args;
460 
462  profile.TPSTOP(QUDA_PROFILE_INIT);
463  (*K)(x, b);
464  return Kparam.iter;
465  }
466 
468  csParam.setPrecision(QUDA_DOUBLE_PRECISION);
469 
471  csParam.is_composite = true;
472  csParam.composite_dim = (2*args.k);
473 
474  args.V2k = ColorSpinorFieldSet::Create(csParam); //search space for Ritz vectors
476  ColorSpinorField &r = *rp;
477  ColorSpinorField &y = *yp;
478  ColorSpinorField &tmp = *tmpp;
479 
480  csParam.setPrecision(param.precision_sloppy);
481  csParam.is_composite = false;
482 
483  // compute initial residual
484  matSloppy(r, x, y);
485  double r2 = blas::xmyNorm(b, r);
486 
487  ColorSpinorField *z = (K != nullptr) ? ColorSpinorField::Create(csParam) : rp;//
488 
489  if( K ) {//apply preconditioner
490  if (param.precision_precondition == param.precision_sloppy) { r_pre = rp; p_pre = z; }
491 
492  ColorSpinorField &rPre = *r_pre;
493  ColorSpinorField &pPre = *p_pre;
494 
495  blas::copy(rPre, r);
496  commGlobalReductionSet(false);
497  (*K)(pPre, rPre);
499  blas::copy(*z, pPre);
500  }
501 
502  *p = *z;
503  blas::zero(y);
504 
505  const bool use_heavy_quark_res =
506  (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) ? true : false;
507 
508  profile.TPSTOP(QUDA_PROFILE_INIT);
509  profile.TPSTART(QUDA_PROFILE_PREAMBLE);
510 
511  double heavy_quark_res = 0.0; // heavy quark res idual
512 
513  if (use_heavy_quark_res) heavy_quark_res = sqrt(blas::HeavyQuarkResidualNorm(x, r).z);
514 
515  double pAp;
516  double alpha=1.0, alpha_inv=1.0, beta=0.0, alpha_old_inv = 1.0;
517 
518  double lanczos_diag, lanczos_offdiag;
519 
520  profile.TPSTOP(QUDA_PROFILE_PREAMBLE);
521  profile.TPSTART(QUDA_PROFILE_COMPUTE);
522  blas::flops = 0;
523 
524  double rMinvr = blas::reDotProduct(r,*z);
525  //Begin EigCG iterations:
526  args.restarts = 0;
527 
528  PrintStats("eigCG", k, r2, b2, heavy_quark_res);
529 
530  bool converged = convergence(r2, heavy_quark_res, args.global_stop, param.tol_hq);
531 
532  while ( !converged && k < param.maxiter ) {
533  matSloppy(*Ap, *p, tmp); // tmp as tmp
534 
535  pAp = blas::reDotProduct(*p, *Ap);
536  alpha_old_inv = alpha_inv;
537  alpha = rMinvr / pAp;
538  alpha_inv = 1.0 / alpha;
539 
540  lanczos_diag = (alpha_inv + beta*alpha_old_inv);
541 
542  UpdateVm(*z, beta, sqrt(r2));
543 
544  r2 = blas::axpyNorm(-alpha, *Ap, r);
545  if( K ) {//apply preconditioner
546  ColorSpinorField &rPre = *r_pre;
547  ColorSpinorField &pPre = *p_pre;
548 
549  blas::copy(rPre, r);
550  commGlobalReductionSet(false);
551  (*K)(pPre, rPre);
553  blas::copy(*z, pPre);
554  }
555  //
556  double rMinvr_old = rMinvr;
557  rMinvr = K ? blas::reDotProduct(r,*z) : r2;
558  beta = rMinvr / rMinvr_old;
559  blas::axpyZpbx(alpha, *p, y, *z, beta);
560 
561  //
562  lanczos_offdiag = (-sqrt(beta)*alpha_inv);
563  args.SetLanczos(lanczos_diag, lanczos_offdiag);
564 
565  k++;
566 
567  PrintStats("eigCG", k, r2, b2, heavy_quark_res);
568  // check convergence, if convergence is satisfied we only need to check that we had a reliable update for the heavy quarks recently
569  converged = convergence(r2, heavy_quark_res, args.global_stop, param.tol_hq) or convergence(r2, heavy_quark_res, local_stop, param.tol_hq);
570  }
571 
572  args.ResetArgs();//eigCG cycle finished, this cleans V2k as well
573 
574  blas::xpy(y, x);
575 
576  profile.TPSTOP(QUDA_PROFILE_COMPUTE);
577  profile.TPSTART(QUDA_PROFILE_EPILOGUE);
578 
579  param.secs = profile.Last(QUDA_PROFILE_COMPUTE);
580  double gflops = (blas::flops + matSloppy.flops())*1e-9;
581  param.gflops = gflops;
582  param.iter += k;
583 
584  if (k == param.maxiter)
585  warningQuda("Exceeded maximum iterations %d", param.maxiter);
586 
587  // compute the true residuals
588  matSloppy(r, x, y);
589  param.true_res = sqrt(blas::xmyNorm(b, r) / b2);
591 
592  PrintSummary("eigCG", k, r2, b2, args.global_stop, param.tol_hq);
593 
594  // reset the flops counters
595  blas::flops = 0;
596  matSloppy.flops();
597 
598  profile.TPSTOP(QUDA_PROFILE_EPILOGUE);
599  profile.TPSTART(QUDA_PROFILE_FREE);
600 
601  profile.TPSTOP(QUDA_PROFILE_FREE);
602  return k;
603  }
604 
606  int k = 0;
607  //Start init CG iterations:
608  deflated_solver *defl_p = static_cast<deflated_solver*>(param.deflation_op);
609  Deflation &defl = *(defl_p->defl);
610 
611  const double full_tol = Kparam.tol;
612  Kparam.tol = Kparam.tol_restart;
613 
615 
617 
618  ColorSpinorField *tmpp2 = ColorSpinorField::Create(csParam);//full precision accumulator
619  ColorSpinorField &tmp2 = *tmpp2;
620  ColorSpinorField *rp = ColorSpinorField::Create(csParam);//full precision residual
621  ColorSpinorField &r = *rp;
622 
623  csParam.setPrecision(param.precision_ritz);
624 
626  ColorSpinorField &xProj = *xp_proj;
627 
629  ColorSpinorField &rProj = *rp_proj;
630 
631  int restart_idx = 0;
632 
633  xProj = x;
634  rProj = b;
635  //launch initCG:
636  while((Kparam.tol >= full_tol) && (restart_idx < param.max_restart_num)) {
637  restart_idx += 1;
638 
639  defl(xProj, rProj);
640  x = xProj;
641 
642  K = new CG(mat, matPrecon, matPrecon, matPrecon, Kparam, profile);
643  (*K)(x, b);
644  delete K;
645 
646  mat(r, x, tmp2);
647  blas::xpay(b, -1.0, r);
648 
649  xProj = x;
650  rProj = r;
651 
652  if(getVerbosity() >= QUDA_VERBOSE) printfQuda("\ninitCG stat: %i iter / %g secs = %g Gflops. \n", Kparam.iter, Kparam.secs, Kparam.gflops);
653 
654  Kparam.tol *= param.inc_tol;
655 
656  if(restart_idx == (param.max_restart_num-1)) Kparam.tol = full_tol;//do the last solve in the next cycle to full tolerance
657 
658  param.secs += Kparam.secs;
659  }
660 
661  if(getVerbosity() >= QUDA_VERBOSE) printfQuda("\ninitCG stat: %i iter / %g secs = %g Gflops. \n", Kparam.iter, Kparam.secs, Kparam.gflops);
662  //
663  param.secs += Kparam.secs;
664  param.gflops += Kparam.gflops;
665 
666  k += Kparam.iter;
667 
668  delete rp;
669  delete tmpp2;
670 
672  delete xp_proj;
673  delete rp_proj;
674  }
675  return k;
676  }
677 
679  {
680  if(param.rhs_idx == 0) max_eigcg_cycles = param.eigcg_max_restarts;
681 
682  const bool mixed_prec = (param.precision != param.precision_sloppy);
683  const double b2 = norm2(in);
684 
685  deflated_solver *defl_p = static_cast<deflated_solver*>(param.deflation_op);
686  Deflation &defl = *(defl_p->defl);
687 
688  //If deflation space is complete: use initCG solver
689  if( defl.is_complete() ) {
690 
691  if (K) errorQuda("\nInitCG does not (yet) support preconditioning.");
692 
693  int iters = initCGsolve(out, in);
694  param.iter += iters;
695 
696  return;
697  }
698 
699  //Start (incremental) eigCG solver:
702 
703  ColorSpinorField *ep = ColorSpinorField::Create(csParam);//full precision accumulator
704  ColorSpinorField &e = *ep;
705  ColorSpinorField *rp = ColorSpinorField::Create(csParam);//full precision residual
706  ColorSpinorField &r = *rp;
707 
708  //deflate initial guess ('out'-field):
709  mat(r, out, e);
710  //
711  double r2 = xmyNorm(in, r);
712 
713  csParam.setPrecision(param.precision_sloppy);
714 
715  ColorSpinorField *ep_sloppy = ( mixed_prec ) ? ColorSpinorField::Create(csParam) : ep;
716  ColorSpinorField &eSloppy = *ep_sloppy;
717  ColorSpinorField *rp_sloppy = ( mixed_prec ) ? ColorSpinorField::Create(csParam) : rp;
718  ColorSpinorField &rSloppy = *rp_sloppy;
719 
720  const double stop = b2*param.tol*param.tol;
721  //start iterative refinement cycles (or just one eigcg call for full (solo) precision solver):
722  int logical_rhs_id = 0;
723  bool dcg_cycle = false;
724  do {
725  blas::zero(e);
726  defl(e, r);
727  //
728  eSloppy = e, rSloppy = r;
729 
730  if( dcg_cycle ) { //run DCG instead
731  if(!K) {
733  Kparam.tol = 5*param.inc_tol;//former cg_iterref_tol param
734  K = new CG(matSloppy, matPrecon, matPrecon, matPrecon, Kparam, profile);
735  }
736 
737  eigcg_args->run_residual_correction = true;
738  printfQuda("Running DCG correction cycle.\n");
739  }
740 
741  int iters = eigCGsolve(eSloppy, rSloppy);
742 
743  bool update_ritz = !dcg_cycle && (eigcg_args->restarts > 1) && !defl.is_complete(); //too uglyyy
744 
745  if (update_ritz) {
746 
747  defl.increment(*Vm, param.n_ev);
748  logical_rhs_id += 1;
749 
750  dcg_cycle = (logical_rhs_id >= max_eigcg_cycles);
751 
752  } else { // run DCG instead
753  dcg_cycle = true;
754  }
755 
756  // use mixed blas ??
757  e = eSloppy;
758  blas::xpy(e, out);
759  // compute the true residuals
760  blas::zero(e);
761  mat(r, out, e);
762  //
763  r2 = blas::xmyNorm(in, r);
764 
765  param.true_res = sqrt(r2 / b2);
767  PrintSummary( !dcg_cycle ? "EigCG:" : "DCG (correction cycle):", iters, r2, b2, stop, param.tol_hq);
768 
769  if( getVerbosity() >= QUDA_VERBOSE ) {
770  if( !dcg_cycle && (eigcg_args->restarts > 1) && !defl.is_complete() ) defl.verify();
771  }
772  } while ((r2 > stop) && mixed_prec);
773 
774  delete ep;
775  delete rp;
776 
777  if(mixed_prec){
778  delete ep_sloppy;
779  delete rp_sloppy;
780  }
781 
782  if (mixed_prec && max_eigcg_cycles > logical_rhs_id) {
783  printfQuda("Reset maximum eigcg cycles to %d (was %d)\n", logical_rhs_id, max_eigcg_cycles);
784  max_eigcg_cycles = logical_rhs_id;//adjust maximum allowed cycles based on the actual information
785  }
786 
787  param.rhs_idx += logical_rhs_id;
788 
789  if(defl.is_complete()) {
790  if(param.rhs_idx != param.deflation_grid) warningQuda("\nTotal rhs number (%d) does not match the deflation grid size (%d).\n", param.rhs_idx, param.deflation_grid);
791  if(Vm) delete Vm;//safe some space
792  Vm = nullptr;
793 
794  const int max_n_ev = defl.size(); // param.m;
795  printfQuda("\nRequested to reserve %d eigenvectors with max tol %le.\n", max_n_ev, param.eigenval_tol);
796  defl.reduce(param.eigenval_tol, max_n_ev);
797  }
798  return;
799  }
800 
801 } // namespace quda
void magma_Xheev(void *Mat, const int n, const int ldm, void *evalues, const int prec)
Conjugate-Gradient Solver.
Definition: invert_quda.h:639
CompositeColorSpinorField & Components()
static ColorSpinorField * Create(const ColorSpinorParam &param)
ColorSpinorField & Component(const int idx) const
bool is_complete()
Test whether the deflation space is complete and therefore cannot be further extended
Definition: deflation.h:159
int size()
return deflation space size
Definition: deflation.h:164
void increment(ColorSpinorField &V, int n_ev)
Definition: deflation.cpp:180
void reduce(double tol, int max_n_ev)
Definition: deflation.cpp:258
unsigned long long flops() const
Definition: dirac_quda.h:1909
void SetLanczos(Complex diag_val, Complex offdiag_val)
void RestartLanczos(ColorSpinorField *w, ColorSpinorFieldSet *v, const double inv_sqrt_r2)
ColorSpinorFieldSet * V2k
EigCGArgs(int m, int k)
void operator()(ColorSpinorField &out, ColorSpinorField &in)
int eigCGsolve(ColorSpinorField &out, ColorSpinorField &in)
void UpdateVm(ColorSpinorField &res, double beta, double sqrtr2)
int initCGsolve(ColorSpinorField &out, ColorSpinorField &in)
IncEigCG(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, SolverParam &param, TimeProfile &profile)
void RestartVT(const double beta, const double rho)
QudaPrecision Precision() const
T data[N *N]
Definition: quda_matrix.h:70
const DiracMatrix & mat
Definition: invert_quda.h:465
bool convergence(double r2, double hq2, double r2_tol, double hq_tol)
Definition: solver.cpp:328
void PrintSummary(const char *name, int k, double r2, double b2, double r2_tol, double hq_tol)
Prints out the summary of the solver convergence (requires a verbosity of QUDA_SUMMARIZE)....
Definition: solver.cpp:386
SolverParam & param
Definition: invert_quda.h:470
static double stopping(double tol, double b2, QudaResidualType residual_type)
Set the solver L2 stopping condition.
Definition: solver.cpp:311
void PrintStats(const char *name, int k, double r2, double b2, double hq2)
Prints out the running statistics of the solver (requires a verbosity of QUDA_VERBOSE)
Definition: solver.cpp:373
const DiracMatrix & matPrecon
Definition: invert_quda.h:467
const DiracMatrix & matSloppy
Definition: invert_quda.h:466
double Last(QudaProfileType idx)
Definition: timer.h:254
void commGlobalReductionSet(bool global_reduce)
double omega
void mat(void *out, void **link, void *in, int dagger_bit, int mu, QudaPrecision sPrecision, QudaPrecision gPrecision)
cudaColorSpinorField * tmp
Definition: covdev_test.cpp:34
@ QUDA_CUDA_FIELD_LOCATION
Definition: enum_quda.h:326
@ QUDA_USE_INIT_GUESS_YES
Definition: enum_quda.h:430
@ QUDA_VERBOSE
Definition: enum_quda.h:267
@ QUDA_HEAVY_QUARK_RESIDUAL
Definition: enum_quda.h:195
@ QUDA_INC_EIGCG_INVERTER
Definition: enum_quda.h:117
@ QUDA_MR_INVERTER
Definition: enum_quda.h:110
@ QUDA_SD_INVERTER
Definition: enum_quda.h:112
@ QUDA_CG_INVERTER
Definition: enum_quda.h:107
@ QUDA_INVALID_INVERTER
Definition: enum_quda.h:133
@ QUDA_EIGCG_INVERTER
Definition: enum_quda.h:116
@ QUDA_PRESERVE_SOURCE_NO
Definition: enum_quda.h:238
@ QUDA_PRESERVE_SOURCE_YES
Definition: enum_quda.h:239
@ QUDA_DOUBLE_PRECISION
Definition: enum_quda.h:65
@ QUDA_EIGEN_EXTLIB
Definition: enum_quda.h:557
@ QUDA_MAGMA_EXTLIB
Definition: enum_quda.h:558
@ QUDA_ZERO_FIELD_CREATE
Definition: enum_quda.h:361
@ QUDA_COPY_FIELD_CREATE
Definition: enum_quda.h:362
#define checkLocation(...)
void axpyZpbx(double a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, double b)
void init()
double3 HeavyQuarkResidualNorm(ColorSpinorField &x, ColorSpinorField &r)
double xmyNorm(ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.h:79
unsigned long long flops
void xpay(ColorSpinorField &x, double a, ColorSpinorField &y)
Definition: blas_quda.h:45
void ax(double a, ColorSpinorField &x)
void zero(ColorSpinorField &a)
double norm2(const ColorSpinorField &a)
double reDotProduct(ColorSpinorField &x, ColorSpinorField &y)
double axpyNorm(double a, ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.h:78
void caxpy(const Complex &a, ColorSpinorField &x, ColorSpinorField &y)
void xpy(ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.h:41
void copy(ColorSpinorField &dst, const ColorSpinorField &src)
Definition: blas_quda.h:24
Complex cDotProduct(ColorSpinorField &, ColorSpinorField &)
void stop()
Stop profiling.
Definition: device.cpp:228
VectorXd RealVector
MatrixXcd DenseMatrix
void ComputeRitz(EigCGArgs &args)
MatrixXcd VectorSet
double norm2(const CloverField &a, bool inverse=false)
VectorXcd Vector
std::complex< double > Complex
Definition: quda_internal.h:86
__host__ __device__ ValueType sqrt(ValueType x)
Definition: complex_quda.h:120
@ QUDA_PROFILE_INIT
Definition: timer.h:106
@ QUDA_PROFILE_EPILOGUE
Definition: timer.h:110
@ QUDA_PROFILE_COMPUTE
Definition: timer.h:108
@ QUDA_PROFILE_FREE
Definition: timer.h:111
@ QUDA_PROFILE_PREAMBLE
Definition: timer.h:107
Stride< Dynamic, Dynamic > DynamicStride
Definition: deflation.cpp:17
ColorSpinorParam csParam
Definition: pack_test.cpp:25
QudaGaugeParam param
Definition: pack_test.cpp:18
QudaPrecision precision
Definition: invert_quda.h:136
QudaResidualType residual_type
Definition: invert_quda.h:49
QudaExtLibType extlib_type
Definition: invert_quda.h:248
QudaPrecision precision_precondition
Definition: invert_quda.h:145
QudaPrecision precision_sloppy
Definition: invert_quda.h:139
QudaInverterType inv_type
Definition: invert_quda.h:22
QudaPrecision precision_ritz
Definition: invert_quda.h:224
QudaInverterType inv_type_precondition
Definition: invert_quda.h:28
Deflation * defl
Definition: deflation.h:189
#define printfQuda(...)
Definition: util_quda.h:114
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:21
#define warningQuda(...)
Definition: util_quda.h:132
#define errorQuda(...)
Definition: util_quda.h:120