20 Solver(param, profile), mat(mat), matSloppy(matSloppy)
34 const double b2 =
norm2(b);
37 printfQuda(
"Warning: inverting on zero-field source\n");
99 const bool use_heavy_quark_res =
101 bool heavy_quark_restart =
false;
110 double heavy_quark_res = 0.0;
111 double heavy_quark_res_old = 0.0;
113 if (use_heavy_quark_res) {
115 heavy_quark_res_old = heavy_quark_res;
117 const int heavy_quark_check = 1;
119 double alpha=0.0, beta=0.0;
123 double rNorm =
sqrt(r2);
124 double r0Norm = rNorm;
125 double maxrx = rNorm;
126 double maxrr = rNorm;
136 const int hqmaxresIncrease = maxResIncrease + 1;
139 int resIncreaseTotal = 0;
140 int hqresIncrease = 0;
144 bool L2breakdown =
false;
154 int steps_since_reliable = 1;
158 matSloppy(Ap, p, tmp, tmp2);
162 bool breakdown =
false;
166 r2 = triplet.x;
double Ap2 = triplet.y; pAp = triplet.z;
170 sigma = alpha*(alpha * Ap2 - pAp);
171 if (sigma < 0.0 || steps_since_reliable==0) {
186 sigma = imag(cg_norm) >= 0.0 ? imag(cg_norm) : r2;
191 if (rNorm > maxrx) maxrx = rNorm;
192 if (rNorm > maxrr) maxrr = rNorm;
193 int updateX = (rNorm < delta*r0Norm && r0Norm <= maxrx) ? 1 : 0;
194 int updateR = ((rNorm < delta*maxrr && r0Norm <= maxrr) || updateX) ? 1 : 0;
204 if ( !(updateR || updateX)) {
206 beta = sigma / r2_old;
212 if (use_heavy_quark_res && k%heavy_quark_check==0) {
213 if (&x != &xSloppy) {
222 steps_since_reliable++;
239 if (
sqrt(r2) > r0Norm && updateX) {
242 warningQuda(
"CG: new reliable residual norm %e is greater than previous reliable residual norm %e (total #inc %i)",
243 sqrt(r2), r0Norm, resIncreaseTotal);
244 if ( resIncrease > maxResIncrease or resIncreaseTotal > maxResIncreaseTotal) {
245 if (use_heavy_quark_res) L2breakdown =
true;
252 if (use_heavy_quark_res and L2breakdown) {
254 warningQuda(
"CG: Restarting without reliable updates for heavy-quark residual");
255 heavy_quark_restart =
true;
256 if (heavy_quark_res > heavy_quark_res_old) {
258 warningQuda(
"CG: new reliable HQ residual norm %e is greater than previous reliable residual norm %e", heavy_quark_res, heavy_quark_res_old);
260 if (hqresIncrease > hqmaxresIncrease)
break;
270 if (use_heavy_quark_res and heavy_quark_restart) {
273 heavy_quark_restart =
false;
285 steps_since_reliable = 0;
286 heavy_quark_res_old = heavy_quark_res;
297 if (use_heavy_quark_res) {
302 converged = L2done and HQdone;
323 printfQuda(
"CG: Reliable updates = %d\n", rUpdate);
328 #if (__COMPUTE_CAPABILITY__ >= 200)
344 if (&tmp3 != &tmp)
delete tmp3_p;
345 if (&tmp2 != &tmp)
delete tmp2_p;
bool convergence(const double &r2, const double &hq2, const double &r2_tol, const double &hq_tol)
void setPrecision(QudaPrecision precision)
double3 tripleCGReductionCuda(cudaColorSpinorField &x, cudaColorSpinorField &y, cudaColorSpinorField &z)
static double stopping(const double &tol, const double &b2, QudaResidualType residual_type)
QudaVerbosity getVerbosity()
__host__ __device__ ValueType sqrt(ValueType x)
double axpyNormCuda(const double &a, cudaColorSpinorField &x, cudaColorSpinorField &y)
std::complex< double > Complex
bool convergenceL2(const double &r2, const double &hq2, const double &r2_tol, const double &hq_tol)
void axpyZpbxCuda(const double &a, cudaColorSpinorField &x, cudaColorSpinorField &y, cudaColorSpinorField &z, const double &b)
void mat(void *out, void **fatlink, void **longlink, void *in, double kappa, int dagger_bit, QudaPrecision sPrecision, QudaPrecision gPrecision)
int max_res_increase_total
Complex axpyCGNormCuda(const double &a, cudaColorSpinorField &x, cudaColorSpinorField &y)
unsigned long long flops() const
cudaColorSpinorField * tmp2
cudaColorSpinorField * tmp
void PrintSummary(const char *name, int k, const double &r2, const double &b2)
QudaResidualType residual_type
CG(DiracMatrix &mat, DiracMatrix &matSloppy, SolverParam ¶m, TimeProfile &profile)
void copyCuda(cudaColorSpinorField &dst, const cudaColorSpinorField &src)
void operator()(cudaColorSpinorField &out, cudaColorSpinorField &in)
void axpyCuda(const double &a, cudaColorSpinorField &x, cudaColorSpinorField &y)
unsigned long long blas_flops
double3 xpyHeavyQuarkResidualNormCuda(cudaColorSpinorField &x, cudaColorSpinorField &y, cudaColorSpinorField &r)
void xpyCuda(cudaColorSpinorField &x, cudaColorSpinorField &y)
double reDotProductCuda(cudaColorSpinorField &a, cudaColorSpinorField &b)
void Stop(QudaProfileType idx)
QudaPrecision Precision() const
void PrintStats(const char *, int k, const double &r2, const double &b2, const double &hq2)
double Last(QudaProfileType idx)
void reduceDouble(double &)
void zeroCuda(cudaColorSpinorField &a)
void Start(QudaProfileType idx)
void tripleCGUpdateCuda(const double &alpha, const double &beta, cudaColorSpinorField &q, cudaColorSpinorField &r, cudaColorSpinorField &x, cudaColorSpinorField &p)
QudaPrecision precision_sloppy
bool use_sloppy_partial_accumulator
bool convergenceHQ(const double &r2, const double &hq2, const double &r2_tol, const double &hq_tol)
void xpayCuda(cudaColorSpinorField &x, const double &a, cudaColorSpinorField &y)
double3 HeavyQuarkResidualNormCuda(cudaColorSpinorField &x, cudaColorSpinorField &r)
double norm2(const ColorSpinorField &)
double xmyNormCuda(cudaColorSpinorField &a, cudaColorSpinorField &b)