27 for (
auto pi :
p)
delete pi;
37 CG(mmdag, mmdagSloppy,
param, profile), mmdag(
mat.Expose()), mmdagSloppy(
mat.Expose()),
init(false) {
58 warningQuda(
"Initial guess may not work as expected with CGNE\n");
68 CG(mdagm, mdagmSloppy,
param, profile), mdagm(
mat.Expose()), mdagmSloppy(
mat.Expose()),
init(false) {
111 if (Np < 0 || Np > 16)
errorQuda(
"Invalid value %d for solution_accumulator_pipeline\n", Np);
115 constexpr
bool alternative_reliable =
true;
116 warningQuda(
"Using alternative reliable updates. This feature is mostly ok but needs a little more testing in the real world.\n");
118 constexpr
bool alternative_reliable =
false;
128 printfQuda(
"Warning: inverting on zero-field source\n");
171 const double uhigh=
param.
precision == 8 ? std::numeric_limits<double>::epsilon()/2. : ((
param.
precision == 4) ? std::numeric_limits<float>::epsilon()/2. :
pow(2.,-13));
172 const double deps=
sqrt(u);
173 constexpr
double dfac = 1.1;
184 if(alternative_reliable){
222 if (Np != (
int)
p.size()) {
223 for (
auto &pi :
p)
delete pi;
227 for (
auto &pi :
p) *pi = rSloppy;
230 if (&
x != &xSloppy) {
237 const bool use_heavy_quark_res =
239 bool heavy_quark_restart =
false;
248 double heavy_quark_res = 0.0;
249 double heavy_quark_res_old = 0.0;
251 if (use_heavy_quark_res) {
253 heavy_quark_res_old = heavy_quark_res;
262 double rNorm =
sqrt(r2);
263 double r0Norm = rNorm;
264 double maxrx = rNorm;
265 double maxrr = rNorm;
276 const int hqmaxresIncrease = maxResIncrease + 1;
279 int resIncreaseTotal = 0;
280 int hqresIncrease = 0;
284 bool L2breakdown =
false;
295 int steps_since_reliable = 1;
299 if(alternative_reliable){
300 dinit = uhigh * (rNorm + Anorm * xNorm);
308 bool breakdown =
false;
312 if(alternative_reliable){
314 r2 = quadruple.x; Ap2 = quadruple.y; pAp = quadruple.z; ppnorm= quadruple.w;
318 r2 = triplet.x; Ap2 = triplet.y; pAp = triplet.z;
322 sigma = alpha[j]*(alpha[j] * Ap2 - pAp);
323 if (sigma < 0.0 || steps_since_reliable == 0) {
334 if(alternative_reliable){
348 sigma = imag(cg_norm) >= 0.0 ? imag(cg_norm) : r2;
356 if(alternative_reliable){
358 updateX = ( (d <= deps*sqrt(r2_old)) or (dfac * dinit > deps * r0Norm) ) and (d_new > deps*rNorm) and (d_new > dfac * dinit);
365 if (rNorm > maxrx) maxrx = rNorm;
366 if (rNorm > maxrr) maxrr = rNorm;
367 updateX = (rNorm <
delta*r0Norm && r0Norm <= maxrx) ? 1 : 0;
368 updateR = ((rNorm <
delta*maxrr && r0Norm <= maxrr) || updateX) ? 1 : 0;
380 beta = sigma / r2_old;
388 errorQuda(
"Not implemented pipelined CG with Np > 1");
396 if ( (j+1)%Np == 0 ) {
397 const auto alpha_ = std::unique_ptr<Complex[]>(
new Complex[Np]);
398 for (
int i=0;
i<Np;
i++) alpha_[
i] = alpha[
i];
399 std::vector<ColorSpinorField*> x_;
400 x_.push_back(&xSloppy);
410 if (use_heavy_quark_res && k%heavy_quark_check==0) {
411 if (&
x != &xSloppy) {
421 if(alternative_reliable){
423 pnorm = pnorm + alpha[j] * alpha[j]* (ppnorm);
425 d_new =
d + u*rNorm + uhigh*Anorm * xnorm;
426 if(steps_since_reliable==0)
429 steps_since_reliable++;
434 const auto alpha_ = std::unique_ptr<Complex[]>(
new Complex[Np]);
435 for (
int i=0;
i<=j;
i++) alpha_[
i] = alpha[
i];
436 std::vector<ColorSpinorField*> x_;
437 x_.push_back(&xSloppy);
438 std::vector<ColorSpinorField*> p_;
439 for (
int i=0;
i<=j;
i++) p_.push_back(
p[
i]);
454 if(alternative_reliable){
473 if (
sqrt(r2) > r0Norm && updateX) {
476 warningQuda(
"CG: new reliable residual norm %e is greater than previous reliable residual norm %e (total #inc %i)",
477 sqrt(r2), r0Norm, resIncreaseTotal);
478 if ( resIncrease > maxResIncrease or resIncreaseTotal > maxResIncreaseTotal) {
479 if (use_heavy_quark_res) {
482 warningQuda(
"CG: solver exiting due to too many true residual norm increases");
490 if (use_heavy_quark_res and L2breakdown) {
492 warningQuda(
"CG: Restarting without reliable updates for heavy-quark residual");
493 heavy_quark_restart =
true;
494 if (heavy_quark_res > heavy_quark_res_old) {
496 warningQuda(
"CG: new reliable HQ residual norm %e is greater than previous reliable residual norm %e", heavy_quark_res, heavy_quark_res_old);
498 if (hqresIncrease > hqmaxresIncrease) {
499 warningQuda(
"CG: solver exiting due to too many heavy quark residual norm increases");
505 if (use_heavy_quark_res and heavy_quark_restart) {
508 heavy_quark_restart =
false;
518 steps_since_reliable = 0;
522 heavy_quark_res_old = heavy_quark_res;
533 if (use_heavy_quark_res) {
538 converged = L2done and HQdone;
542 if (converged && steps_since_reliable > 0 && (j+1)%Np != 0 ) {
543 const auto alpha_ = std::unique_ptr<Complex[]>(
new Complex[Np]);
544 for (
int i=0;
i<=j;
i++) alpha_[
i] = alpha[
i];
545 std::vector<ColorSpinorField*> x_;
546 x_.push_back(&xSloppy);
547 std::vector<ColorSpinorField*> p_;
548 for (
int i=0;
i<=j;
i++) p_.push_back(
p[
i]);
553 j = steps_since_reliable == 0 ? 0 : (j+1)%Np;
571 printfQuda(
"CG: Reliable updates = %d\n", rUpdate);
590 if (&
tmp3 != &
tmp)
delete tmp3_p;
591 if (&
tmp2 != &
tmp)
delete tmp2_p;
593 if (&rSloppy != &r)
delete r_sloppy;
594 if (&xSloppy != &
x)
delete x_sloppy;
607 errorQuda(
"QUDA_BLOCKSOLVER not built.");
615 using Eigen::MatrixXcd;
626 errorQuda(
"Warning: inverting on zero-field source - undefined for block solver\n");
668 r2avg += r2(
i,
i).real();
719 if (&
x != &xSloppy) {
726 const bool use_heavy_quark_res =
728 if(use_heavy_quark_res)
errorQuda(
"ERROR: heavy quark residual not supported in block solver");
779 bool allconverged =
true;
783 allconverged = allconverged && converged[
i];
787 MatrixXcd L = r2.llt().matrixL();
789 MatrixXcd Linv = C.inverse();
792 std::cout <<
"r2\n " << r2 << std::endl;
793 std::cout <<
"L\n " << L.adjoint() << std::endl;
817 std::cout <<
" pTp " << std::endl << pTp << std::endl;
818 std::cout <<
" L " << std::endl << L.adjoint() << std::endl;
819 std::cout <<
" C " << std::endl << C << std::endl;
837 alpha = pAp.inverse() * C;
847 beta = pAp.inverse();
868 L = r2.llt().matrixL();
886 std::cout <<
" rTr " << std::endl << pTp << std::endl;
887 std::cout <<
"QR" <<
S<< std::endl <<
"QP " <<
S.inverse()*
S << std::endl;;
916 std::cout <<
" pTp " << std::endl << pTp << std::endl;
917 std::cout <<
"S " <<
S<< std::endl <<
"C " << C << std::endl;
923 r2(j,j) = C(0,j)*
conj(C(0,j));
925 r2(j,j) += C(
i,j) *
conj(C(
i,j));
926 r2avg += r2(j,j).real();
935 allconverged = allconverged && converged[
i];
978 if (&
tmp3 != &
tmp)
delete tmp3_p;
979 if (&
tmp2 != &
tmp)
delete tmp2_p;
982 if (xSloppy.
Precision() !=
x.Precision())
delete x_sloppy;
1000 errorQuda(
"QUDA_BLOCKSOLVER not built.");
1007 const bool use_block =
true;
1013 using Eigen::MatrixXcd;
1028 errorQuda(
"Warning: inverting on zero-field source\n");
1044 std::cout <<
"b2m\n" << b2m << std::endl;
1060 ColorSpinorField &r = *
rp;
1061 ColorSpinorField &
y = *
yp;
1062 ColorSpinorField &Ap = *
App;
1063 ColorSpinorField &
tmp = *
tmpp;
1069 mat(r.Component(
i),
x.Component(
i),
y.Component(
i));
1092 ColorSpinorField &
tmp2 = *tmp2_p;
1094 ColorSpinorField *r_sloppy;
1107 ColorSpinorField *x_sloppy;
1118 ColorSpinorField *tmp3_p =
1121 ColorSpinorField &
tmp3 = *tmp3_p;
1123 ColorSpinorField &xSloppy = *x_sloppy;
1124 ColorSpinorField &rSloppy = *r_sloppy;
1129 ColorSpinorField &
p = *pp;
1131 ColorSpinorField &pnew = *ppnew;
1133 if (&
x != &xSloppy) {
1140 const bool use_heavy_quark_res =
1142 bool heavy_quark_restart =
false;
1154 if (use_heavy_quark_res) {
1156 heavy_quark_res_old[
i] = heavy_quark_res[
i];
1176 rNorm[
i] =
sqrt(r2(
i,
i).real());
1177 r0Norm[
i] = rNorm[
i];
1178 maxrx[
i] = rNorm[
i];
1179 maxrr[
i] = rNorm[
i];
1191 const int hqmaxresIncrease = maxResIncrease + 1;
1193 int resIncrease = 0;
1194 int resIncreaseTotal = 0;
1195 int hqresIncrease = 0;
1199 bool L2breakdown =
false;
1208 r2avg+=r2(
i,
i).real();
1210 PrintStats(
"CG", k, r2avg, b2avg, heavy_quark_res[0]);
1211 int steps_since_reliable = 1;
1212 bool allconverged =
true;
1216 allconverged = allconverged && converged[
i];
1248 std::cout <<
" pTp " << std::endl << pTp << std::endl;
1249 std::cout <<
"QR" <<
gamma<< std::endl <<
"QP " <<
gamma.inverse()*
gamma << std::endl;;
1257 bool breakdown =
false;
1269 if(use_block or
i==j)
1276 alpha = pAp.inverse() *
gamma.adjoint().inverse() * r2;
1278 std::cout <<
"alpha\n" << alpha << std::endl;
1281 std::cout <<
"pAp " << std::endl <<pAp << std::endl;
1282 std::cout <<
"pAp^-1 " << std::endl <<pAp.inverse() << std::endl;
1283 std::cout <<
"r2 " << std::endl <<r2 << std::endl;
1284 std::cout <<
"alpha " << std::endl <<alpha << std::endl;
1285 std::cout <<
"pAp^-1r2" << std::endl << pAp.inverse()*r2 << std::endl;
1292 blas::caxpy(-alpha(j,
i), Ap.Component(j), rSloppy.Component(
i));
1298 if(use_block or
i==j)
1316 rNorm[
i] =
sqrt(r2(
i,
i).real());
1317 if (rNorm[
i] > maxrx[
i]) maxrx[
i] = rNorm[
i];
1318 if (rNorm[
i] > maxrr[
i]) maxrr[
i] = rNorm[
i];
1319 updateX = (rNorm[
i] <
delta * r0Norm[
i] && r0Norm[
i] <= maxrx[
i]) ?
true :
false;
1320 updateR = ((rNorm[
i] <
delta * maxrr[
i] && r0Norm[
i] <= maxrr[
i]) || updateX) ? true :
false;
1329 if ( !(
updateR || updateX )) {
1331 beta =
gamma * r2_old.inverse() * sigma;
1333 std::cout <<
"beta\n" << beta << std::endl;
1360 double rcoeff= (j==0?1.0:0.0);
1399 std::cout <<
" pTp " << std::endl << pTp << std::endl;
1400 std::cout <<
"QR" <<
gamma<< std::endl <<
"QP " <<
gamma.inverse()*
gamma << std::endl;;
1405 if (use_heavy_quark_res && (k % heavy_quark_check) == 0) {
1406 if (&
x != &xSloppy) {
1419 steps_since_reliable++;
1431 mat(r.Component(
i),
y.Component(
i),
x.Component(
i),
tmp3.Component(
i));
1443 if (use_heavy_quark_res){
1452 if (
sqrt(r2(
i,
i).real()) > r0Norm[
i] && updateX) {
1455 warningQuda(
"CG: new reliable residual norm %e is greater than previous reliable residual norm %e (total #inc %i)",
1456 sqrt(r2(
i,
i).real()), r0Norm[
i], resIncreaseTotal);
1457 if ( resIncrease > maxResIncrease or resIncreaseTotal > maxResIncreaseTotal) {
1458 if (use_heavy_quark_res) {
1461 warningQuda(
"CG: solver exiting due to too many true residual norm increases");
1471 if (use_heavy_quark_res and L2breakdown) {
1473 warningQuda(
"CG: Restarting without reliable updates for heavy-quark residual");
1474 heavy_quark_restart =
true;
1475 if (heavy_quark_res[
i] > heavy_quark_res_old[
i]) {
1477 warningQuda(
"CG: new reliable HQ residual norm %e is greater than previous reliable residual norm %e", heavy_quark_res[
i], heavy_quark_res_old[
i]);
1479 if (hqresIncrease > hqmaxresIncrease) {
1480 warningQuda(
"CG: solver exiting due to too many heavy quark residual norm increases");
1488 rNorm[
i] =
sqrt(r2(
i,
i).real());
1489 maxrr[
i] = rNorm[
i];
1490 maxrx[
i] = rNorm[
i];
1491 r0Norm[
i] = rNorm[
i];
1492 heavy_quark_res_old[
i] = heavy_quark_res[
i];
1496 if (use_heavy_quark_res and heavy_quark_restart) {
1499 heavy_quark_restart =
false;
1506 beta(
i,
i) = r2(
i,
i) / r2_old(
i,
i);
1511 steps_since_reliable = 0;
1517 allconverged =
true;
1520 r2avg+= r2(
i,
i).real();
1523 allconverged = allconverged && converged[
i];
1525 PrintStats(
"CG", k, r2avg, b2avg, heavy_quark_res[0]);
1528 if (use_heavy_quark_res) {
1534 converged[
i] = L2done and HQdone;
1557 printfQuda(
"CG: Reliable updates = %d\n", rUpdate);
1561 mat(r.Component(
i),
x.Component(
i),
y.Component(
i),
tmp3.Component(
i));
1578 if (&
tmp3 != &
tmp)
delete tmp3_p;
1579 if (&
tmp2 != &
tmp)
delete tmp2_p;
1581 if (rSloppy.Precision() != r.Precision())
delete r_sloppy;
1582 if (xSloppy.Precision() !=
x.Precision())
delete x_sloppy;
bool convergence(const double &r2, const double &hq2, const double &r2_tol, const double &hq_tol)
void xpay(ColorSpinorField &x, const double &a, ColorSpinorField &y)
static double stopping(const double &tol, const double &b2, QudaResidualType residual_type)
double3 cDotProductNormA(ColorSpinorField &a, ColorSpinorField &b)
#define QUDA_MAX_MULTI_SHIFT
Maximum number of shifts supported by the multi-shift solver. This number may be changed if need be...
QudaVerbosity getVerbosity()
Matrix< N, std::complex< T > > conj(const Matrix< N, std::complex< T > > &mat)
double norm2(const ColorSpinorField &a)
__host__ __device__ ValueType sqrt(ValueType x)
Complex cDotProduct(ColorSpinorField &, ColorSpinorField &)
const DiracMatrix & matSloppy
std::complex< double > Complex
bool convergenceL2(const double &r2, const double &hq2, const double &r2_tol, const double &hq_tol)
void xpayz(ColorSpinorField &x, const double &a, ColorSpinorField &y, ColorSpinorField &z)
cudaColorSpinorField * tmp
double3 xpyHeavyQuarkResidualNorm(ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &r)
double axpyNorm(const double &a, ColorSpinorField &x, ColorSpinorField &y)
static ColorSpinorField * Create(const ColorSpinorParam ¶m)
double reDotProduct(ColorSpinorField &x, ColorSpinorField &y)
void copy(ColorSpinorField &dst, const ColorSpinorField &src)
void ax(const double &a, ColorSpinorField &x)
ColorSpinorField & Component(const int idx) const
double xmyNorm(ColorSpinorField &x, ColorSpinorField &y)
int max_res_increase_total
This is just a dummy structure we use for trove to define the required structure size.
size_t RealLength() const
CGNE(DiracMatrix &mat, DiracMatrix &matSloppy, SolverParam ¶m, TimeProfile &profile)
QudaComputeNullVector compute_null_vector
double Last(QudaProfileType idx)
void PrintSummary(const char *name, int k, const double &r2, const double &b2)
static unsigned int delta
QudaResidualType residual_type
void axpyZpbx(const double &a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, const double &b)
CG(DiracMatrix &mat, DiracMatrix &matSloppy, SolverParam ¶m, TimeProfile &profile)
double true_res_hq_offset[QUDA_MAX_MULTI_SHIFT]
Complex axpyCGNorm(const double &a, ColorSpinorField &x, ColorSpinorField &y)
void tripleCGUpdate(const double &alpha, const double &beta, ColorSpinorField &q, ColorSpinorField &r, ColorSpinorField &x, ColorSpinorField &p)
double4 quadrupleCGReduction(ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
#define checkLocation(...)
__host__ __device__ ValueType pow(ValueType x, ExponentType e)
double3 HeavyQuarkResidualNorm(ColorSpinorField &x, ColorSpinorField &r)
double true_res_offset[QUDA_MAX_MULTI_SHIFT]
CGNR(DiracMatrix &mat, DiracMatrix &matSloppy, SolverParam ¶m, TimeProfile &profile)
double gamma(double) __attribute__((availability(macosx
void caxpy(const Complex &a, ColorSpinorField &x, ColorSpinorField &y)
void zero(ColorSpinorField &a)
std::vector< ColorSpinorField * > p
void axpy(const double &a, ColorSpinorField &x, ColorSpinorField &y)
void Mdag(ColorSpinorField &out, const ColorSpinorField &in) const
virtual void M(ColorSpinorField &out, const ColorSpinorField &in) const =0
void operator()(ColorSpinorField &out, ColorSpinorField &in)
unsigned long long flops() const
void PrintStats(const char *, int k, const double &r2, const double &b2, const double &hq2)
void operator()(ColorSpinorField &out, ColorSpinorField &in)
void xpy(ColorSpinorField &x, ColorSpinorField &y)
QudaUseInitGuess use_init_guess
void operator()(ColorSpinorField &out, ColorSpinorField &in)
void solve(ColorSpinorField &out, ColorSpinorField &in)
QudaPrecision precision_sloppy
bool use_sloppy_partial_accumulator
bool convergenceHQ(const double &r2, const double &hq2, const double &r2_tol, const double &hq_tol)
__host__ __device__ ValueType conj(ValueType x)
int solution_accumulator_pipeline
void mat(void *out, void **link, void *in, int dagger_bit, int mu, QudaPrecision sPrecision, QudaPrecision gPrecision)
static __inline__ size_t size_t d
QudaPrecision Precision() const
double3 tripleCGReduction(ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
void updateR()
update the radius for halos.