40 for (
auto pi : p)
if (pi)
delete pi;
46 if (rSloppyp)
delete rSloppyp;
47 if (xSloppyp)
delete xSloppyp;
49 if (tmpp)
delete tmpp;
51 if (tmp2p && tmpp != tmp2p)
delete tmp2p;
54 if (rnewp)
delete rnewp;
64 CG(mmdag, mmdagSloppy, mmdagPrecon, mmdagEig,
param, profile),
66 mmdagSloppy(matSloppy.Expose()),
67 mmdagPrecon(matPrecon.Expose()),
68 mmdagEig(matEig.Expose()),
108 if (b2 == 0.0) b2 = r2;
154 CG(mdagm, mdagmSloppy, mdagmPrecon, mdagmEig,
param, profile),
156 mdagmSloppy(matSloppy.Expose()),
157 mdagmPrecon(matPrecon.Expose()),
158 mdagmEig(matEig.Expose()),
238 if (Np < 0 || Np > 16)
errorQuda(
"Invalid value %d for solution_accumulator_pipeline\n", Np);
253 printfQuda(
"Warning: inverting on zero-field source\n");
287 tmp3p = tmp2p = tmpp;
323 if (Np != (
int)p.size()) {
324 for (
auto &pi : p)
delete pi;
336 const double deps=
sqrt(u);
337 constexpr
double dfac = 1.1;
360 const double hq_res_stall_check = is_pure_double ? 0. : uhigh * uhigh * 1e-60;
368 if (b2 == 0) b2 = r2;
388 if (Np != (
int)p.size()) {
389 for (
auto &pi : p)
delete pi;
396 for (
auto &p_i : p) *p_i = p_init ? *p_init : rSloppy;
400 if (r2_old_init != 0.0 and p_init) {
401 r2_old = r2_old_init;
408 const bool use_heavy_quark_res =
410 bool heavy_quark_restart =
false;
419 double heavy_quark_res = 0.0;
420 double heavy_quark_res_old = 0.0;
422 if (use_heavy_quark_res) {
424 heavy_quark_res_old = heavy_quark_res;
432 double rNorm =
sqrt(r2);
433 double r0Norm = rNorm;
434 double maxrx = rNorm;
435 double maxrr = rNorm;
436 double maxr_deflate = rNorm;
447 const int hqmaxresRestartTotal
451 int resIncreaseTotal = 0;
452 int hqresIncrease = 0;
453 int hqresRestartTotal = 0;
457 bool L2breakdown =
false;
458 const double L2breakdown_eps = 100. * uhigh;
471 int steps_since_reliable = 1;
476 dinit = uhigh * (rNorm + Anorm * xNorm);
484 bool breakdown =
false;
490 r2 = quadruple.x; Ap2 = quadruple.y; pAp = quadruple.z; ppnorm= quadruple.w;
494 r2 = triplet.x; Ap2 = triplet.y; pAp = triplet.z;
498 sigma = alpha[j]*(alpha[j] * Ap2 - pAp);
499 if (sigma < 0.0 || steps_since_reliable == 0) {
523 sigma = imag(cg_norm) >= 0.0 ? imag(cg_norm) : r2;
533 updateX = ( (d <= deps*
sqrt(r2_old)) or (dfac * dinit > deps * r0Norm) ) and (d_new > deps*rNorm) and (d_new > dfac * dinit);
536 if (rNorm > maxrx) maxrx = rNorm;
537 if (rNorm > maxrr) maxrr = rNorm;
538 updateX = (rNorm < delta * r0Norm && r0Norm <= maxrx) ? 1 : 0;
539 updateR = ((rNorm < delta * maxrr && r0Norm <= maxrr) || updateX) ? 1 : 0;
547 if (use_heavy_quark_res and L2breakdown
554 beta = sigma / r2_old;
561 errorQuda(
"Not implemented pipelined CG with Np > 1");
569 if ( (j+1)%Np == 0 ) {
570 std::vector<ColorSpinorField*> x_;
571 x_.push_back(&xSloppy);
576 blas::xpayz(rSloppy, beta, *p[j], *p[(j + 1) % Np]);
580 if (use_heavy_quark_res && k % heavy_quark_check == 0) {
581 if (&x != &xSloppy) {
593 pnorm = pnorm + alpha[j] * alpha[j]* (ppnorm);
595 d_new = d + u*rNorm + uhigh*Anorm * xnorm;
599 steps_since_reliable++;
604 std::vector<ColorSpinorField*> x_;
605 x_.push_back(&xSloppy);
606 std::vector<ColorSpinorField*> p_;
607 for (
int i=0; i<=j; i++) p_.push_back(p[i]);
625 maxr_deflate =
sqrt(r2);
649 if (
sqrt(r2) > r0Norm && updateX and not L2breakdown) {
653 "CG: new reliable residual norm %e is greater than previous reliable residual norm %e (total #inc %i)",
654 sqrt(r2), r0Norm, resIncreaseTotal);
656 if ((use_heavy_quark_res and
sqrt(r2) < L2breakdown_eps) or resIncrease > maxResIncrease
657 or resIncreaseTotal > maxResIncreaseTotal or r2 <
stop) {
658 if (use_heavy_quark_res) {
662 if (resIncrease > maxResIncrease or resIncreaseTotal > maxResIncreaseTotal or r2 <
stop) {
663 warningQuda(
"CG: solver exiting due to too many true residual norm increases");
673 if (use_heavy_quark_res and L2breakdown) {
676 warningQuda(
"CG: Restarting without reliable updates for heavy-quark residual (total #inc %i)",
678 heavy_quark_restart =
true;
680 if (heavy_quark_res > heavy_quark_res_old) {
682 warningQuda(
"CG: new reliable HQ residual norm %e is greater than previous reliable residual norm %e",
683 heavy_quark_res, heavy_quark_res_old);
685 if (hqresIncrease > hqmaxresIncrease) {
686 warningQuda(
"CG: solver exiting due to too many heavy quark residual norm increases (%i/%i)",
687 hqresIncrease, hqmaxresIncrease);
694 if (hqresRestartTotal > hqmaxresRestartTotal) {
695 warningQuda(
"CG: solver exiting due to too many heavy quark residual restarts (%i/%i)", hqresRestartTotal,
696 hqmaxresRestartTotal);
701 if (use_heavy_quark_res and heavy_quark_restart) {
704 heavy_quark_restart =
false;
714 steps_since_reliable = 0;
718 heavy_quark_res_old = heavy_quark_res;
729 if (use_heavy_quark_res) {
734 converged = L2done and HQdone;
738 if (converged && steps_since_reliable > 0 && (j+1)%Np != 0 ) {
739 std::vector<ColorSpinorField*> x_;
740 x_.push_back(&xSloppy);
741 std::vector<ColorSpinorField*> p_;
742 for (
int i=0; i<=j; i++) p_.push_back(p[i]);
746 j = steps_since_reliable == 0 ? 0 : (j+1)%Np;
765 printfQuda(
"CG: Reliable updates = %d\n", rUpdate);
794 errorQuda(
"QUDA_BLOCKSOLVER not built.");
802 using Eigen::MatrixXcd;
813 errorQuda(
"Warning: inverting on zero-field source - undefined for block solver\n");
852 tmp3p = tmp2p = tmpp;
882 r2avg += r2(i,i).real();
896 if (&x != &xSloppy) {
903 const bool use_heavy_quark_res =
905 if(use_heavy_quark_res)
errorQuda(
"ERROR: heavy quark residual not supported in block solver");
956 bool allconverged =
true;
960 allconverged = allconverged && converged[i];
964 MatrixXcd L = r2.llt().matrixL();
966 MatrixXcd Linv = C.inverse();
969 std::cout <<
"r2\n " << r2 << std::endl;
970 std::cout <<
"L\n " << L.adjoint() << std::endl;
994 std::cout <<
" pTp " << std::endl << pTp << std::endl;
995 std::cout <<
" L " << std::endl << L.adjoint() << std::endl;
996 std::cout <<
" C " << std::endl << C << std::endl;
1009 if (i!=j) pAp(j,i) =
std::conj(pAp(i,j));
1014 alpha = pAp.inverse() * C;
1024 beta = pAp.inverse();
1045 L = r2.llt().matrixL();
1063 std::cout <<
" rTr " << std::endl << pTp << std::endl;
1064 std::cout <<
"QR" << S<< std::endl <<
"QP " << S.inverse()*S << std::endl;;
1093 std::cout <<
" pTp " << std::endl << pTp << std::endl;
1094 std::cout <<
"S " << S<< std::endl <<
"C " << C << std::endl;
1100 r2(j,j) = C(0,j)*
conj(C(0,j));
1102 r2(j,j) += C(i,j) *
conj(C(i,j));
1103 r2avg += r2(j,j).real();
1109 allconverged =
true;
1112 allconverged = allconverged && converged[i];
1166 #define BLOCKCG_GS 1
1169 errorQuda(
"QUDA_BLOCKSOLVER not built.");
1176 const bool use_block =
true;
1182 using Eigen::MatrixXcd;
1197 errorQuda(
"Warning: inverting on zero-field source\n");
1213 std::cout <<
"b2m\n" << b2m << std::endl;
1245 tmp3p = tmp2p = tmpp;
1257 ColorSpinorField &r = *rp;
1258 ColorSpinorField &y = *yp;
1259 ColorSpinorField &p = *pp;
1260 ColorSpinorField &pnew = *rnewp;
1261 ColorSpinorField &Ap = *App;
1262 ColorSpinorField &
tmp = *tmpp;
1263 ColorSpinorField &tmp2 = *tmp2p;
1264 ColorSpinorField &tmp3 = *tmp3p;
1265 ColorSpinorField &rSloppy = *rSloppyp;
1278 printfQuda(
"r2[%i] %e\n", i, r2(i,i).real());
1294 if (&x != &xSloppy) {
1301 const bool use_heavy_quark_res =
1303 bool heavy_quark_restart =
false;
1315 if (use_heavy_quark_res) {
1317 heavy_quark_res_old[i] = heavy_quark_res[i];
1337 rNorm[i] =
sqrt(r2(i,i).real());
1338 r0Norm[i] = rNorm[i];
1339 maxrx[i] = rNorm[i];
1340 maxrr[i] = rNorm[i];
1352 const int hqmaxresIncrease = maxResIncrease + 1;
1354 int resIncrease = 0;
1355 int resIncreaseTotal = 0;
1356 int hqresIncrease = 0;
1360 bool L2breakdown =
false;
1369 r2avg+=r2(i,i).real();
1371 PrintStats(
"CG", k, r2avg, b2avg, heavy_quark_res[0]);
1372 int steps_since_reliable = 1;
1373 bool allconverged =
true;
1377 allconverged = allconverged && converged[i];
1409 std::cout <<
" pTp " << std::endl << pTp << std::endl;
1410 std::cout <<
"QR" << gamma<< std::endl <<
"QP " << gamma.inverse()*gamma << std::endl;;
1414 matSloppy(Ap.Component(i), p.Component(i),
tmp.Component(i), tmp2.Component(i));
1418 bool breakdown =
false;
1430 if(use_block or i==j)
1437 alpha = pAp.inverse() * gamma.adjoint().inverse() * r2;
1439 std::cout <<
"alpha\n" << alpha << std::endl;
1442 std::cout <<
"pAp " << std::endl <<pAp << std::endl;
1443 std::cout <<
"pAp^-1 " << std::endl <<pAp.inverse() << std::endl;
1444 std::cout <<
"r2 " << std::endl <<r2 << std::endl;
1445 std::cout <<
"alpha " << std::endl <<alpha << std::endl;
1446 std::cout <<
"pAp^-1r2" << std::endl << pAp.inverse()*r2 << std::endl;
1453 blas::caxpy(-alpha(j,i), Ap.Component(j), rSloppy.Component(i));
1459 if(use_block or i==j)
1477 rNorm[i] =
sqrt(r2(i,i).real());
1478 if (rNorm[i] > maxrx[i]) maxrx[i] = rNorm[i];
1479 if (rNorm[i] > maxrr[i]) maxrr[i] = rNorm[i];
1480 updateX = (rNorm[i] < delta * r0Norm[i] && r0Norm[i] <= maxrx[i]) ?
true :
false;
1481 updateR = ((rNorm[i] < delta * maxrr[i] && r0Norm[i] <= maxrr[i]) || updateX) ? true :
false;
1490 if ( !(
updateR || updateX )) {
1492 beta = gamma * r2_old.inverse() * sigma;
1494 std::cout <<
"beta\n" << beta << std::endl;
1502 blas::caxpy(alpha(j,i),p.Component(j),xSloppy.Component(i));
1514 blas::axpy(1.0,r.Component(i),pnew.Component(i));
1521 double rcoeff= (j==0?1.0:0.0);
1523 blas::caxpy(beta(j,i),p.Component(j),pnew.Component(i));
1530 blas::copy(p.Component(i), pnew.Component(i));
1560 std::cout <<
" pTp " << std::endl << pTp << std::endl;
1561 std::cout <<
"QR" << gamma<< std::endl <<
"QP " << gamma.inverse()*gamma << std::endl;;
1566 if (use_heavy_quark_res && (k % heavy_quark_check) == 0) {
1567 if (&x != &xSloppy) {
1580 steps_since_reliable++;
1584 blas::axpy(alpha(i,i).real(), p.Component(i), xSloppy.Component(i));
1599 blas::copy(rSloppy.Component(i), r.Component(i));
1604 if (use_heavy_quark_res){
1613 if (
sqrt(r2(i,i).real()) > r0Norm[i] && updateX) {
1616 warningQuda(
"CG: new reliable residual norm %e is greater than previous reliable residual norm %e (total #inc %i)",
1617 sqrt(r2(i,i).real()), r0Norm[i], resIncreaseTotal);
1618 if ( resIncrease > maxResIncrease or resIncreaseTotal > maxResIncreaseTotal) {
1619 if (use_heavy_quark_res) {
1622 warningQuda(
"CG: solver exiting due to too many true residual norm increases");
1632 if (use_heavy_quark_res and L2breakdown) {
1634 warningQuda(
"CG: Restarting without reliable updates for heavy-quark residual");
1635 heavy_quark_restart =
true;
1636 if (heavy_quark_res[i] > heavy_quark_res_old[i]) {
1638 warningQuda(
"CG: new reliable HQ residual norm %e is greater than previous reliable residual norm %e", heavy_quark_res[i], heavy_quark_res_old[i]);
1640 if (hqresIncrease > hqmaxresIncrease) {
1641 warningQuda(
"CG: solver exiting due to too many heavy quark residual norm increases");
1649 rNorm[i] =
sqrt(r2(i,i).real());
1650 maxrr[i] = rNorm[i];
1651 maxrx[i] = rNorm[i];
1652 r0Norm[i] = rNorm[i];
1653 heavy_quark_res_old[i] = heavy_quark_res[i];
1657 if (use_heavy_quark_res and heavy_quark_restart) {
1660 heavy_quark_restart =
false;
1664 double rp =
blas::reDotProduct(rSloppy.Component(i), p.Component(i)) / (r2(i,i).real());
1665 blas::axpy(-rp, rSloppy.Component(i), p.Component(i));
1667 beta(i,i) = r2(i,i) / r2_old(i,i);
1668 blas::xpay(rSloppy.Component(i), beta(i,i).real(), p.Component(i));
1672 steps_since_reliable = 0;
1678 allconverged =
true;
1681 r2avg+= r2(i,i).real();
1684 allconverged = allconverged && converged[i];
1686 PrintStats(
"CG", k, r2avg, b2avg, heavy_quark_res[0]);
1689 if (use_heavy_quark_res) {
1695 converged[i] = L2done and HQdone;
1718 printfQuda(
"CG: Reliable updates = %d\n", rUpdate);
Conjugate-Gradient Solver.
void operator()(ColorSpinorField &out, ColorSpinorField &in)
Run CG.
void blocksolve(ColorSpinorField &out, ColorSpinorField &in)
CG(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, const DiracMatrix &matEig, SolverParam &param, TimeProfile &profile)
void operator()(ColorSpinorField &out, ColorSpinorField &in)
Run CG.
CGNE(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, const DiracMatrix &matEig, SolverParam &param, TimeProfile &profile)
void operator()(ColorSpinorField &out, ColorSpinorField &in)
Run CG.
CGNR(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, const DiracMatrix &matEig, SolverParam &param, TimeProfile &profile)
static ColorSpinorField * Create(const ColorSpinorParam &param)
ColorSpinorField & Component(const int idx) const
virtual void M(ColorSpinorField &out, const ColorSpinorField &in) const =0
Apply M for the Dirac operator, e.g. the Schur complement operator.
void Mdag(ColorSpinorField &out, const ColorSpinorField &in) const
Apply Mdag (the daggered operator of M).
const Dirac * Expose() const
bool isStaggered() const
return if the operator is a staggered operator
unsigned long long flops() const
void deflate(std::vector< ColorSpinorField * > &sol, const std::vector< ColorSpinorField * > &src, const std::vector< ColorSpinorField * > &evecs, const std::vector< Complex > &evals, bool accumulate=false) const
Deflate a set of source vectors with a given eigenspace.
void computeEvals(const DiracMatrix &mat, std::vector< ColorSpinorField * > &evecs, std::vector< Complex > &evals, int size)
Compute eigenvalues and their residua.
QudaPrecision Precision() const
double precisionEpsilon(QudaPrecision prec=QUDA_INVALID_PRECISION) const
Returns the epsilon tolerance for a given precision; by default, the solver precision is used.
bool convergenceL2(double r2, double hq2, double r2_tol, double hq_tol)
Test for L2 solver convergence – ignore HQ residual.
bool convergence(double r2, double hq2, double r2_tol, double hq_tol)
std::vector< ColorSpinorField * > evecs
bool convergenceHQ(double r2, double hq2, double r2_tol, double hq_tol)
Test for HQ solver convergence – ignore L2 residual.
void destroyDeflationSpace()
Destroy the allocated deflation space.
void PrintSummary(const char *name, int k, double r2, double b2, double r2_tol, double hq_tol)
Prints out the summary of the solver convergence (requires a verbosity of QUDA_SUMMARIZE)....
const DiracMatrix & matEig
static double stopping(double tol, double b2, QudaResidualType residual_type)
Set the solver L2 stopping condition.
std::vector< Complex > evals
void PrintStats(const char *name, int k, double r2, double b2, double hq2)
Prints out the running statistics of the solver (requires a verbosity of QUDA_VERBOSE)
void constructDeflationSpace(const ColorSpinorField &meta, const DiracMatrix &mat)
Constructs the deflation space and eigensolver.
const DiracMatrix & matPrecon
const DiracMatrix & matSloppy
double Last(QudaProfileType idx)
void commGlobalReductionSet(bool global_reduce)
bool alternative_reliable
void mat(void *out, void **link, void *in, int dagger_bit, int mu, QudaPrecision sPrecision, QudaPrecision gPrecision)
cudaColorSpinorField * tmp
@ QUDA_CUDA_FIELD_LOCATION
@ QUDA_USE_INIT_GUESS_YES
@ QUDA_HEAVY_QUARK_RESIDUAL
@ QUDA_PRESERVE_SOURCE_NO
@ QUDA_PRESERVE_SOURCE_YES
@ QUDA_COMPUTE_NULL_VECTOR_NO
Matrix< N, std::complex< T > > conj(const Matrix< N, std::complex< T > > &mat)
#define checkPrecision(...)
#define checkLocation(...)
void init()
Create the BLAS context.
double4 quadrupleCGReduction(ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
Complex axpyCGNorm(double a, ColorSpinorField &x, ColorSpinorField &y)
void axpyZpbx(double a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, double b)
void xpayz(ColorSpinorField &x, double a, ColorSpinorField &y, ColorSpinorField &z)
double3 HeavyQuarkResidualNorm(ColorSpinorField &x, ColorSpinorField &r)
double xmyNorm(ColorSpinorField &x, ColorSpinorField &y)
double3 tripleCGReduction(ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
void xpay(ColorSpinorField &x, double a, ColorSpinorField &y)
double3 xpyHeavyQuarkResidualNorm(ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &r)
void ax(double a, ColorSpinorField &x)
void zero(ColorSpinorField &a)
double norm2(const ColorSpinorField &a)
void tripleCGUpdate(double alpha, double beta, ColorSpinorField &q, ColorSpinorField &r, ColorSpinorField &x, ColorSpinorField &p)
double reDotProduct(ColorSpinorField &x, ColorSpinorField &y)
double axpyNorm(double a, ColorSpinorField &x, ColorSpinorField &y)
void axpy(double a, ColorSpinorField &x, ColorSpinorField &y)
double3 cDotProductNormA(ColorSpinorField &a, ColorSpinorField &b)
void caxpy(const Complex &a, ColorSpinorField &x, ColorSpinorField &y)
void xpy(ColorSpinorField &x, ColorSpinorField &y)
void copy(ColorSpinorField &dst, const ColorSpinorField &src)
Complex cDotProduct(ColorSpinorField &, ColorSpinorField &)
void axpby(double a, ColorSpinorField &x, double b, ColorSpinorField &y)
void stop()
Stop profiling.
__host__ __device__ ValueType conj(ValueType x)
std::complex< double > Complex
__host__ __device__ ValueType sqrt(ValueType x)
void updateR()
update the radius for halos.
#define QUDA_MAX_MULTI_SHIFT
Maximum number of shifts supported by the multi-shift solver. This number may be changed if need be.
QudaPreserveSource preserve_source
QudaComputeNullVector compute_null_vector
bool is_preconditioner
Verbosity to use for the preconditioner.
bool use_sloppy_partial_accumulator
int max_res_increase_total
QudaResidualType residual_type
bool use_alternative_reliable
QudaPrecision precision_sloppy
double true_res_offset[QUDA_MAX_MULTI_SHIFT]
int solution_accumulator_pipeline
double true_res_hq_offset[QUDA_MAX_MULTI_SHIFT]
QudaUseInitGuess use_init_guess
int max_hq_res_restart_total
bool global_reduction
Whether the solver is acting as a preconditioner for another solver.
QudaVerbosity getVerbosity()