17 Solver(param, profile), mat(mat), matSloppy(matSloppy), matDagSloppy(matSloppy),
init(false)
53 printfQuda(
"Warning: inverting on zero-field source\n");
62 strcpy(name, is_cg3ne ?
"CG3NE" :
"CG3NR");
100 const bool use_heavy_quark_res =
109 int resIncreaseTotal = 0;
112 const int hqmaxresIncrease = maxResIncrease + 1;
114 double heavy_quark_res = 0.0;
115 double heavy_quark_res_old = 0.0;
116 int hqresIncrease = 0;
117 bool L2breakdown =
false;
132 if (mixed_precision) {
140 if (mixed_precision) {
147 if (use_heavy_quark_res) {
149 heavy_quark_res_old = heavy_quark_res;
162 double rNorm =
sqrt(r2);
163 double r0Norm = rNorm;
164 double maxrx = rNorm;
165 double maxrr = rNorm;
167 bool restart =
false;
170 double rho = 1.0, gamma = 1.0, Ar2 = 1.0;
175 double gamma_old = gamma;
176 double Ar2_old = Ar2;
185 if(k==0 || restart) {
199 rho = rho/(rho-(gamma/gamma_old)*(r2/r2_old));
201 rho = rho/(rho-(gamma/gamma_old)*(Ar2/Ar2_old));
223 if (use_heavy_quark_res && k%heavy_quark_check==0) {
224 heavy_quark_res_old = heavy_quark_res;
225 if (mixed_precision) {
234 if (mixed_precision) {
237 if (rNorm > maxrx) maxrx = rNorm;
238 if (rNorm > maxrr) maxrr = rNorm;
239 bool update = (rNorm < delta*r0Norm && r0Norm <= maxrx);
240 update = ( update || (rNorm < delta*maxrr && r0Norm <= maxrr));
257 if (use_heavy_quark_res) {
279 warningQuda(
"%s: new reliable residual norm %e is greater than previous reliable residual norm %e (total #inc %i)",
280 name,
sqrt(r2),
sqrt(r2_old), resIncreaseTotal);
281 if (resIncrease > maxResIncrease or resIncreaseTotal > maxResIncreaseTotal) {
282 if (use_heavy_quark_res) {
285 warningQuda(
"%s: solver exiting due to too many true residual norm increases", name);
294 if (use_heavy_quark_res and L2breakdown) {
296 heavy_quark_check = 1;
297 warningQuda(
"%s: Restarting without reliable updates for heavy-quark residual", name);
300 if (heavy_quark_res > heavy_quark_res_old) {
302 warningQuda(
"%s: new reliable HQ residual norm %e is greater than previous reliable residual norm %e", name, heavy_quark_res, heavy_quark_res_old);
304 if (hqresIncrease > hqmaxresIncrease) {
305 warningQuda(
"%s: solver exiting due to too many heavy quark residual norm increases", name);
327 warningQuda(
"CG3: new reliable residual norm %e is greater than previous reliable residual norm %e (total #inc %i)",
328 sqrt(r2),
sqrt(r2_old), resIncreaseTotal);
329 if (resIncrease > maxResIncrease or resIncreaseTotal > maxResIncreaseTotal) {
330 warningQuda(
"CG3: solver exiting due to too many true residual norm increases");
void setPrecision(QudaPrecision precision, QudaPrecision ghost_precision=QUDA_INVALID_PRECISION, bool force_native=false)
bool convergenceHQ(double r2, double hq2, double r2_tol, double hq_tol)
Test for HQ solver convergence – ignore L2 residual.
double caxpyNorm(const Complex &a, ColorSpinorField &x, ColorSpinorField &y)
QudaInverterType inv_type
double norm2(const ColorSpinorField &a)
__host__ __device__ ValueType sqrt(ValueType x)
Complex cDotProduct(ColorSpinorField &, ColorSpinorField &)
void PrintStats(const char *name, int k, double r2, double b2, double hq2)
Prints out the running statistics of the solver (requires a verbosity of QUDA_VERBOSE) ...
double3 xpyHeavyQuarkResidualNorm(ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &r)
static ColorSpinorField * Create(const ColorSpinorParam &param)
bool convergence(double r2, double hq2, double r2_tol, double hq_tol)
void copy(ColorSpinorField &dst, const ColorSpinorField &src)
double xmyNorm(ColorSpinorField &x, ColorSpinorField &y)
ColorSpinorField * rS_oldp
QudaPreserveSource preserve_source
int max_res_increase_total
const DiracMatrix & matSloppy
void doubleCG3Update(double a, double b, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
void operator()(ColorSpinorField &out, ColorSpinorField &in)
CG3NE(DiracMatrix &mat, DiracMatrix &matSloppy, SolverParam &param, TimeProfile &profile)
QudaComputeNullVector compute_null_vector
double Last(QudaProfileType idx)
QudaResidualType residual_type
static double stopping(double tol, double b2, QudaResidualType residual_type)
Set the solver L2 stopping condition.
void axpy(double a, ColorSpinorField &x, ColorSpinorField &y)
#define checkLocation(...)
double3 HeavyQuarkResidualNorm(ColorSpinorField &x, ColorSpinorField &r)
std::complex< double > Complex
void init()
Create the CUBLAS context.
void zero(ColorSpinorField &a)
void doubleCG3Init(double a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
ColorSpinorField * xS_oldp
ColorSpinorField * AAdagrSp
unsigned long long flops() const
double doubleCG3UpdateNorm(double a, double b, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
void xpy(ColorSpinorField &x, ColorSpinorField &y)
void axpby(double a, ColorSpinorField &x, double b, ColorSpinorField &y)
double axpyNorm(double a, ColorSpinorField &x, ColorSpinorField &y)
QudaUseInitGuess use_init_guess
ColorSpinorField * AdagrSp
QudaPrecision precision_sloppy
void PrintSummary(const char *name, int k, double r2, double b2, double r2_tol, double hq_tol)
Prints out the summary of the solver convergence (requires a verbosity of QUDA_SUMMARIZE). Assumes SolverParam.true_res and SolverParam.true_res_hq have been set.
double doubleCG3InitNorm(double a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
void mat(void *out, void **link, void *in, int dagger_bit, int mu, QudaPrecision sPrecision, QudaPrecision gPrecision)
QudaPrecision Precision() const