40 Solver(param, profile), mat(mat), matSloppy(matSloppy), matPrecon(matPrecon), K(0), Kparam(param)
46 K =
new CG(matPrecon, matPrecon,
Kparam, profile);
48 K =
new MR(matPrecon, matPrecon,
Kparam, profile);
70 const double b2 =
norm2(b);
73 printfQuda(
"Warning: inverting on zero-field source\n");
144 (*K)(*minvrPre, *rPre);
146 *minvrSloppy = *minvrPre;
161 double heavy_quark_res = 0.0;
164 double alpha = 0.0, beta=0.0;
167 double rMinvr_old = 0.0;
168 double r_new_Minvr_old = 0.0;
172 double rNorm =
sqrt(r2);
173 double r0Norm = rNorm;
174 double maxrx = rNorm;
175 double maxrr = rNorm;
191 int resIncreaseTotal = 0;
200 alpha = (
K) ? rMinvr/pAp : r2/pAp;
206 sigma = imag(cg_norm) >= 0.0 ? imag(cg_norm) : r2;
208 if(
K) rMinvr_old = rMinvr;
211 if(rNorm > maxrx) maxrx = rNorm;
212 if(rNorm > maxrr) maxrr = rNorm;
215 int updateX = (rNorm < delta*r0Norm && r0Norm <= maxrx) ? 1 : 0;
216 int updateR = ((rNorm < delta*maxrr && r0Norm <= maxrr) || updateX) ? 1 : 0;
223 if( !(updateR || updateX) ){
229 (*K)(*minvrPre, *rPre);
233 *minvrSloppy = *minvrPre;
236 beta = (rMinvr - r_new_Minvr_old)/rMinvr_old;
237 axpyZpbx(alpha, *p, xSloppy, *minvrSloppy, beta);
240 axpyZpbx(alpha, *p, xSloppy, rSloppy, beta);
244 axpy(alpha, *p, xSloppy);
255 if(
sqrt(r2) > r0Norm && updateX) {
259 warningQuda(
"PCG: new reliable residual norm %e is greater than previous reliable residual norm %e (total #inc %i)",
sqrt(r2), r0Norm, resIncreaseTotal);
261 if (resIncrease > maxResIncrease or resIncreaseTotal > maxResIncreaseTotal)
break;
276 (*K)(*minvrPre, *rPre);
279 *minvrSloppy = *minvrPre;
282 beta = rMinvr/rMinvr_old;
284 xpay(*minvrSloppy, beta, *p);
289 axpy(-rp, rSloppy, *p);
292 xpay(rSloppy, beta, *p);
296 PrintStats(
"PCG", k, r2, b2, heavy_quark_res);
317 printfQuda(
"CG: Reliable updates = %d\n", rUpdate);
325 double true_res =
xmyNorm(b, r);
void setPrecision(QudaPrecision precision, QudaPrecision ghost_precision=QUDA_INVALID_PRECISION, bool force_native=false)
void operator()(ColorSpinorField &out, ColorSpinorField &in)
void axpyZpbx(double a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, double b)
QudaInverterType inv_type
QudaVerbosity getVerbosity()
__host__ __device__ ValueType sqrt(ValueType x)
void PrintStats(const char *name, int k, double r2, double b2, double hq2)
Prints out the running statistics of the solver (requires a verbosity of QUDA_VERBOSE) ...
bool convergence(double r2, double hq2, double r2_tol, double hq_tol)
double reDotProduct(ColorSpinorField &x, ColorSpinorField &y)
Complex axpyCGNorm(double a, ColorSpinorField &x, ColorSpinorField &y)
__host__ __device__ void copy(T1 &a, const T2 &b)
double xmyNorm(ColorSpinorField &x, ColorSpinorField &y)
QudaInverterType inv_type_precondition
QudaPreserveSource preserve_source
int max_res_increase_total
void xpay(ColorSpinorField &x, double a, ColorSpinorField &y)
double norm2(const CloverField &a, bool inverse=false)
double Last(QudaProfileType idx)
const DiracMatrix & matPrecon
QudaResidualType residual_type
static void fillInnerSolverParam(SolverParam &inner, const SolverParam &outer)
static double stopping(double tol, double b2, QudaResidualType residual_type)
Set the solver L2 stopping condition.
bool is_preconditioner
Verbosity to use for the preconditioner.
double3 HeavyQuarkResidualNorm(ColorSpinorField &x, ColorSpinorField &r)
std::complex< double > Complex
QudaPrecision precision_precondition
Conjugate-Gradient Solver.
unsigned long long flops() const
void xpy(ColorSpinorField &x, ColorSpinorField &y)
QudaPrecision precision_sloppy
__device__ void axpy(real a, const real *x, Link &y)
bool use_sloppy_partial_accumulator
PreconCG(DiracMatrix &mat, DiracMatrix &matSloppy, DiracMatrix &matPrecon, SolverParam &param, TimeProfile &profile)
void mat(void *out, void **link, void *in, int dagger_bit, int mu, QudaPrecision sPrecision, QudaPrecision gPrecision)
QudaPrecision Precision() const
__device__ __host__ void zero(vector_type< scalar, n > &v)
const DiracMatrix & matSloppy
void updateR()
Update the radius for halos.
void commGlobalReductionSet(bool global_reduce)