16 MultiSrcCG::MultiSrcCG(DiracMatrix &
mat, DiracMatrix &matSloppy, SolverParam &
param, TimeProfile &profile) :
17 MultiSrcSolver(param, profile), mat(mat), matSloppy(matSloppy)
22 MultiSrcCG::~MultiSrcCG() {
26 void MultiSrcCG::operator()(std::vector<ColorSpinorField*>
out, std::vector<ColorSpinorField*>
in)
42 printfQuda(
"Warning: inverting on zero-field source\n");
45 param.true_res_hq = 0.0;
49 cudaColorSpinorField r(b);
53 cudaColorSpinorField y(b,
csParam);
59 cudaColorSpinorField Ap(x,
csParam);
65 cudaColorSpinorField &
tmp2 = *tmp2_p;
67 cudaColorSpinorField *r_sloppy;
68 if (
param.precision_sloppy == x.Precision()) {
72 r_sloppy =
new cudaColorSpinorField(r,
csParam);
75 cudaColorSpinorField *x_sloppy;
76 if (
param.precision_sloppy == x.Precision() ||
77 !
param.use_sloppy_partial_accumulator) {
78 x_sloppy = &
static_cast<cudaColorSpinorField&
>(x);
81 x_sloppy =
new cudaColorSpinorField(x,
csParam);
86 cudaColorSpinorField *tmp3_p =
89 cudaColorSpinorField &tmp3 = *tmp3_p;
91 ColorSpinorField &xSloppy = *x_sloppy;
92 ColorSpinorField &rSloppy = *r_sloppy;
94 cudaColorSpinorField p(rSloppy);
103 const bool use_heavy_quark_res =
105 bool heavy_quark_restart =
false;
112 double stop = stopping(
param.tol, b2,
param.residual_type);
114 double heavy_quark_res = 0.0;
115 double heavy_quark_res_old = 0.0;
117 if (use_heavy_quark_res) {
119 heavy_quark_res_old = heavy_quark_res;
121 const int heavy_quark_check =
param.heavy_quark_check;
123 double alpha=0.0, beta=0.0;
127 double rNorm =
sqrt(r2);
128 double r0Norm = rNorm;
129 double maxrx = rNorm;
130 double maxrr = rNorm;
131 double delta =
param.delta;
136 const int maxResIncrease = (use_heavy_quark_res ? 0 :
param.max_res_increase);
137 const int maxResIncreaseTotal =
param.max_res_increase_total;
140 const int hqmaxresIncrease = maxResIncrease + 1;
143 int resIncreaseTotal = 0;
144 int hqresIncrease = 0;
148 bool L2breakdown =
false;
156 PrintStats(
"CG", k, r2, b2, heavy_quark_res);
158 int steps_since_reliable = 1;
159 bool converged = convergence(r2, heavy_quark_res, stop,
param.tol_hq);
161 while ( !converged && k <
param.maxiter) {
162 matSloppy(Ap, p,
tmp, tmp2);
166 bool breakdown =
false;
167 if (
param.pipeline) {
169 r2 = triplet.x;
double Ap2 = triplet.y; pAp = triplet.z;
173 sigma = alpha*(alpha * Ap2 - pAp);
174 if (sigma < 0.0 || steps_since_reliable==0) {
189 sigma = imag(cg_norm) >= 0.0 ? imag(cg_norm) : r2;
194 if (rNorm > maxrx) maxrx = rNorm;
195 if (rNorm > maxrr) maxrr = rNorm;
196 int updateX = (rNorm < delta*r0Norm && r0Norm <= maxrx) ? 1 : 0;
197 int updateR = ((rNorm < delta*maxrr && r0Norm <= maxrr) || updateX) ? 1 : 0;
200 if ( convergence(r2, heavy_quark_res, stop,
param.tol_hq) &&
param.delta >=
param.tol) updateX = 1;
203 if (use_heavy_quark_res and L2breakdown and convergenceHQ(r2, heavy_quark_res, stop,
param.tol_hq) and
param.delta >=
param.tol) {
207 if ( !(updateR || updateX)) {
209 beta = sigma / r2_old;
215 if (use_heavy_quark_res && k%heavy_quark_check==0) {
216 if (&x != &xSloppy) {
225 steps_since_reliable++;
241 if (
sqrt(r2) > r0Norm && updateX) {
244 warningQuda(
"CG: new reliable residual norm %e is greater than previous reliable residual norm %e (total #inc %i)",
245 sqrt(r2), r0Norm, resIncreaseTotal);
246 if ( resIncrease > maxResIncrease or resIncreaseTotal > maxResIncreaseTotal) {
247 if (use_heavy_quark_res) {
250 warningQuda(
"CG: solver exiting due to too many true residual norm increases");
258 if (use_heavy_quark_res and L2breakdown) {
260 warningQuda(
"CG: Restarting without reliable updates for heavy-quark residual");
261 heavy_quark_restart =
true;
262 if (heavy_quark_res > heavy_quark_res_old) {
264 warningQuda(
"CG: new reliable HQ residual norm %e is greater than previous reliable residual norm %e", heavy_quark_res, heavy_quark_res_old);
266 if (hqresIncrease > hqmaxresIncrease) {
267 warningQuda(
"CG: solver exiting due to too many heavy quark residual norm increases");
279 if (use_heavy_quark_res and heavy_quark_restart) {
282 heavy_quark_restart =
false;
293 steps_since_reliable = 0;
294 heavy_quark_res_old = heavy_quark_res;
300 PrintStats(
"CG", k, r2, b2, heavy_quark_res);
302 converged = convergence(r2, heavy_quark_res, stop,
param.tol_hq);
305 if (use_heavy_quark_res) {
307 bool L2done = L2breakdown or convergenceL2(r2, heavy_quark_res, stop,
param.tol_hq);
309 bool HQdone = (steps_since_reliable == 0 and
param.delta > 0) and convergenceHQ(r2, heavy_quark_res, stop,
param.tol_hq);
310 converged = L2done and HQdone;
324 param.gflops = gflops;
327 if (k==
param.maxiter)
331 printfQuda(
"CG: Reliable updates = %d\n", rUpdate);
338 PrintSummary(
"CG", k, r2, b2, stop, inv.tol_hq);
348 if (&tmp3 != &
tmp)
delete tmp3_p;
349 if (&tmp2 != &
tmp)
delete tmp2_p;
351 if (rSloppy.Precision() != r.Precision())
delete r_sloppy;
352 if (xSloppy.Precision() != x.Precision())
delete x_sloppy;
cudaColorSpinorField * tmp2
void axpyZpbx(double a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, double b)
QudaVerbosity getVerbosity()
double norm2(const ColorSpinorField &a)
__host__ __device__ ValueType sqrt(ValueType x)
cudaColorSpinorField * tmp
double3 xpyHeavyQuarkResidualNorm(ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &r)
double reDotProduct(ColorSpinorField &x, ColorSpinorField &y)
Complex axpyCGNorm(double a, ColorSpinorField &x, ColorSpinorField &y)
void copy(ColorSpinorField &dst, const ColorSpinorField &src)
double xmyNorm(ColorSpinorField &x, ColorSpinorField &y)
void xpay(ColorSpinorField &x, double a, ColorSpinorField &y)
TimeProfile & profile
whether A is hermitian or not
double Last(QudaProfileType idx)
void axpy(double a, ColorSpinorField &x, ColorSpinorField &y)
double3 HeavyQuarkResidualNorm(ColorSpinorField &x, ColorSpinorField &r)
std::complex< double > Complex
void tripleCGUpdate(double alpha, double beta, ColorSpinorField &q, ColorSpinorField &r, ColorSpinorField &x, ColorSpinorField &p)
void zero(ColorSpinorField &a)
cpuColorSpinorField * out
unsigned long long flops() const
void xpy(ColorSpinorField &x, ColorSpinorField &y)
double axpyNorm(double a, ColorSpinorField &x, ColorSpinorField &y)
void reduceDouble(double &)
void mat(void *out, void **link, void *in, int dagger_bit, int mu, QudaPrecision sPrecision, QudaPrecision gPrecision)
double3 tripleCGReduction(ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
void updateR()
update the radius for halos.