45 CG3(mmdag, mmdagSloppy, mmdagPrecon,
param, profile),
47 mmdagSloppy(matSloppy.Expose()),
48 mmdagPrecon(matPrecon.Expose()),
90 if (b2 == 0.0) b2 = r2;
134 CG3(mdagm, mdagmSloppy, mdagmPrecon,
param, profile),
136 mdagmSloppy(matSloppy.Expose()),
137 mdagmPrecon(matPrecon.Expose()),
217 printfQuda(
"Warning: inverting on zero-field source\n");
236 if (mixed_precision) {
267 const bool use_heavy_quark_res =
276 int resIncreaseTotal = 0;
279 const int hqmaxresIncrease = maxResIncrease + 1;
281 double heavy_quark_res = 0.0;
282 double heavy_quark_res_old = 0.0;
283 int hqresIncrease = 0;
284 bool L2breakdown =
false;
299 if (mixed_precision) {
307 if (mixed_precision) {
314 if (use_heavy_quark_res) {
316 heavy_quark_res_old = heavy_quark_res;
329 double rNorm =
sqrt(r2);
330 double r0Norm = rNorm;
331 double maxrx = rNorm;
332 double maxrr = rNorm;
334 bool restart =
false;
337 PrintStats(
"CG3", k, r2, b2, heavy_quark_res);
338 double rho = 1.0, gamma = 1.0;
343 double gamma_old = gamma;
348 if (k == 0 || restart) {
360 rho = rho/(rho-(gamma/gamma_old)*(r2/r2_old));
382 if (use_heavy_quark_res && k%heavy_quark_check==0) {
383 heavy_quark_res_old = heavy_quark_res;
384 if (mixed_precision) {
393 if (mixed_precision) {
395 if (rNorm > maxrx) maxrx = rNorm;
396 if (rNorm > maxrr) maxrr = rNorm;
397 bool update = (rNorm < delta*r0Norm && r0Norm <= maxrx);
398 update = ( update || (rNorm < delta*maxrr && r0Norm <= maxrr));
415 if (use_heavy_quark_res) {
435 if (
sqrt(r2) > r0Norm) {
439 "CG3: new reliable residual norm %e is greater than previous reliable residual norm %e (total #inc %i)",
440 sqrt(r2), r0Norm, resIncreaseTotal);
441 if (resIncrease > maxResIncrease or resIncreaseTotal > maxResIncreaseTotal) {
442 if (use_heavy_quark_res) {
445 warningQuda(
"CG3: solver exiting due to too many true residual norm increases");
454 if (use_heavy_quark_res and L2breakdown) {
456 heavy_quark_check = 1;
457 warningQuda(
"CG3: Restarting without reliable updates for heavy-quark residual");
460 if (heavy_quark_res > heavy_quark_res_old) {
462 warningQuda(
"CG3: new reliable HQ residual norm %e is greater than previous reliable residual norm %e", heavy_quark_res, heavy_quark_res_old);
464 if (hqresIncrease > hqmaxresIncrease) {
465 warningQuda(
"CG3: solver exiting due to too many heavy quark residual norm increases");
484 if (
sqrt(r2) > r0Norm) {
488 "CG3: new reliable residual norm %e is greater than previous reliable residual norm %e (total #inc %i)",
489 sqrt(r2), r0Norm, resIncreaseTotal);
490 if (resIncrease > maxResIncrease or resIncreaseTotal > maxResIncreaseTotal) {
491 warningQuda(
"CG3: solver exiting due to too many true residual norm increases");
497 PrintStats(
"CG3", k, r2, b2, heavy_quark_res);
CG3(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, SolverParam &param, TimeProfile &profile)
void operator()(ColorSpinorField &out, ColorSpinorField &in)
CG3NE(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, SolverParam &param, TimeProfile &profile)
void operator()(ColorSpinorField &out, ColorSpinorField &in)
CG3NR(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, SolverParam &param, TimeProfile &profile)
void operator()(ColorSpinorField &out, ColorSpinorField &in)
static ColorSpinorField * Create(const ColorSpinorParam &param)
virtual void M(ColorSpinorField &out, const ColorSpinorField &in) const =0
Apply M for the dirac op. E.g. the Schur Complement operator.
void Mdag(ColorSpinorField &out, const ColorSpinorField &in) const
Apply Mdag (the daggered operator of M).
const Dirac * Expose() const
bool isStaggered() const
Returns whether the operator is a staggered operator.
unsigned long long flops() const
QudaPrecision Precision() const
bool convergence(double r2, double hq2, double r2_tol, double hq_tol)
bool convergenceHQ(double r2, double hq2, double r2_tol, double hq_tol)
Test for HQ solver convergence – ignore L2 residual.
void PrintSummary(const char *name, int k, double r2, double b2, double r2_tol, double hq_tol)
Prints out the summary of the solver convergence (requires a verbosity of QUDA_SUMMARIZE).
static double stopping(double tol, double b2, QudaResidualType residual_type)
Set the solver L2 stopping condition.
void PrintStats(const char *name, int k, double r2, double b2, double hq2)
Prints out the running statistics of the solver (requires a verbosity of QUDA_VERBOSE)
const DiracMatrix & matSloppy
double Last(QudaProfileType idx)
void mat(void *out, void **link, void *in, int dagger_bit, int mu, QudaPrecision sPrecision, QudaPrecision gPrecision)
cudaColorSpinorField * tmp
@ QUDA_CUDA_FIELD_LOCATION
@ QUDA_USE_INIT_GUESS_YES
@ QUDA_HEAVY_QUARK_RESIDUAL
@ QUDA_PRESERVE_SOURCE_NO
@ QUDA_PRESERVE_SOURCE_YES
@ QUDA_COMPUTE_NULL_VECTOR_NO
#define checkLocation(...)
void init()
Create the BLAS context.
double quadrupleCG3UpdateNorm(double a, double b, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, ColorSpinorField &w, ColorSpinorField &v)
double3 HeavyQuarkResidualNorm(ColorSpinorField &x, ColorSpinorField &r)
double xmyNorm(ColorSpinorField &x, ColorSpinorField &y)
double quadrupleCG3InitNorm(double a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, ColorSpinorField &w, ColorSpinorField &v)
double3 xpyHeavyQuarkResidualNorm(ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &r)
double caxpyNorm(const Complex &a, ColorSpinorField &x, ColorSpinorField &y)
void zero(ColorSpinorField &a)
double norm2(const ColorSpinorField &a)
double reDotProduct(ColorSpinorField &x, ColorSpinorField &y)
double axpyNorm(double a, ColorSpinorField &x, ColorSpinorField &y)
void axpy(double a, ColorSpinorField &x, ColorSpinorField &y)
void xpy(ColorSpinorField &x, ColorSpinorField &y)
void copy(ColorSpinorField &dst, const ColorSpinorField &src)
Complex cDotProduct(ColorSpinorField &, ColorSpinorField &)
void axpby(double a, ColorSpinorField &x, double b, ColorSpinorField &y)
void stop()
Stop profiling.
std::complex< double > Complex
__host__ __device__ ValueType sqrt(ValueType x)
QudaPreserveSource preserve_source
QudaComputeNullVector compute_null_vector
int max_res_increase_total
QudaResidualType residual_type
QudaPrecision precision_sloppy
QudaUseInitGuess use_init_guess