16 MultiSrcCG::MultiSrcCG(DiracMatrix &
mat, DiracMatrix &matSloppy, SolverParam &
param, TimeProfile &profile) :
17 MultiSrcSolver(
param, profile),
mat(
mat), matSloppy(matSloppy)
22 MultiSrcCG::~MultiSrcCG() {
26 void MultiSrcCG::operator()(std::vector<ColorSpinorField*> out, std::vector<ColorSpinorField*> in)
42 printfQuda(
"Warning: inverting on zero-field source\n");
45 param.true_res_hq = 0.0;
49 cudaColorSpinorField r(b);
53 cudaColorSpinorField y(b,
csParam);
59 cudaColorSpinorField Ap(x,
csParam);
63 cudaColorSpinorField *tmp2_p = !
mat.isStaggered() ?
65 cudaColorSpinorField &tmp2 = *tmp2_p;
67 cudaColorSpinorField *r_sloppy;
68 if (
param.precision_sloppy == x.Precision()) {
72 r_sloppy =
new cudaColorSpinorField(r,
csParam);
75 cudaColorSpinorField *x_sloppy;
76 if (
param.precision_sloppy == x.Precision() ||
77 !
param.use_sloppy_partial_accumulator) {
78 x_sloppy = &
static_cast<cudaColorSpinorField&
>(x);
81 x_sloppy =
new cudaColorSpinorField(x,
csParam);
86 cudaColorSpinorField *tmp3_p =
87 (
param.precision !=
param.precision_sloppy && !
mat.isStaggered()) ?
89 cudaColorSpinorField &tmp3 = *tmp3_p;
91 ColorSpinorField &xSloppy = *x_sloppy;
92 ColorSpinorField &rSloppy = *r_sloppy;
94 cudaColorSpinorField p(rSloppy);
103 const bool use_heavy_quark_res =
105 bool heavy_quark_restart =
false;
114 double heavy_quark_res = 0.0;
115 double heavy_quark_res_old = 0.0;
117 if (use_heavy_quark_res) {
118 heavy_quark_res =
sqrt(blas::HeavyQuarkResidualNorm(x, r).z);
119 heavy_quark_res_old = heavy_quark_res;
121 const int heavy_quark_check =
param.heavy_quark_check;
123 double alpha=0.0, beta=0.0;
127 double rNorm =
sqrt(r2);
128 double r0Norm = rNorm;
129 double maxrx = rNorm;
130 double maxrr = rNorm;
131 double delta =
param.delta;
136 const int maxResIncrease = (use_heavy_quark_res ? 0 :
param.max_res_increase);
137 const int maxResIncreaseTotal =
param.max_res_increase_total;
140 const int hqmaxresIncrease = maxResIncrease + 1;
143 int resIncreaseTotal = 0;
144 int hqresIncrease = 0;
148 bool L2breakdown =
false;
156 PrintStats(
"CG", k, r2, b2, heavy_quark_res);
158 int steps_since_reliable = 1;
159 bool converged = convergence(r2, heavy_quark_res,
stop,
param.tol_hq);
161 while ( !converged && k <
param.maxiter) {
162 matSloppy(Ap, p,
tmp, tmp2);
166 bool breakdown =
false;
167 if (
param.pipeline) {
168 double3 triplet = blas::tripleCGReduction(rSloppy, Ap, p);
169 r2 = triplet.x;
double Ap2 = triplet.y; pAp = triplet.z;
173 sigma = alpha*(alpha * Ap2 - pAp);
174 if (sigma < 0.0 || steps_since_reliable==0) {
183 pAp = blas::reDotProduct(p, Ap);
187 Complex cg_norm = blas::axpyCGNorm(-alpha, Ap, rSloppy);
189 sigma = imag(cg_norm) >= 0.0 ? imag(cg_norm) : r2;
194 if (rNorm > maxrx) maxrx = rNorm;
195 if (rNorm > maxrr) maxrr = rNorm;
196 int updateX = (rNorm < delta*r0Norm && r0Norm <= maxrx) ? 1 : 0;
197 int updateR = ((rNorm < delta*maxrr && r0Norm <= maxrr) || updateX) ? 1 : 0;
200 if ( convergence(r2, heavy_quark_res,
stop,
param.tol_hq) &&
param.delta >=
param.tol) updateX = 1;
203 if (use_heavy_quark_res and L2breakdown and convergenceHQ(r2, heavy_quark_res,
stop,
param.tol_hq) and
param.delta >=
param.tol) {
209 beta = sigma / r2_old;
211 if (
param.pipeline && !breakdown) blas::tripleCGUpdate(alpha, beta, Ap, rSloppy, xSloppy, p);
212 else blas::axpyZpbx(alpha, p, xSloppy, rSloppy, beta);
215 if (use_heavy_quark_res && k%heavy_quark_check==0) {
216 if (&x != &xSloppy) {
218 heavy_quark_res =
sqrt(blas::xpyHeavyQuarkResidualNorm(xSloppy,
tmp, rSloppy).z);
221 heavy_quark_res =
sqrt(blas::xpyHeavyQuarkResidualNorm(x, y, r).z);
225 steps_since_reliable++;
238 if (use_heavy_quark_res) heavy_quark_res =
sqrt(blas::HeavyQuarkResidualNorm(y, r).z);
241 if (
sqrt(r2) > r0Norm && updateX) {
244 warningQuda(
"CG: new reliable residual norm %e is greater than previous reliable residual norm %e (total #inc %i)",
245 sqrt(r2), r0Norm, resIncreaseTotal);
246 if ( resIncrease > maxResIncrease or resIncreaseTotal > maxResIncreaseTotal) {
247 if (use_heavy_quark_res) {
250 warningQuda(
"CG: solver exiting due to too many true residual norm increases");
258 if (use_heavy_quark_res and L2breakdown) {
260 warningQuda(
"CG: Restarting without reliable updates for heavy-quark residual");
261 heavy_quark_restart =
true;
262 if (heavy_quark_res > heavy_quark_res_old) {
264 warningQuda(
"CG: new reliable HQ residual norm %e is greater than previous reliable residual norm %e", heavy_quark_res, heavy_quark_res_old);
266 if (hqresIncrease > hqmaxresIncrease) {
267 warningQuda(
"CG: solver exiting due to too many heavy quark residual norm increases");
279 if (use_heavy_quark_res and heavy_quark_restart) {
282 heavy_quark_restart =
false;
285 double rp = blas::reDotProduct(rSloppy, p) / (r2);
293 steps_since_reliable = 0;
294 heavy_quark_res_old = heavy_quark_res;
300 PrintStats(
"CG", k, r2, b2, heavy_quark_res);
302 converged = convergence(r2, heavy_quark_res,
stop,
param.tol_hq);
305 if (use_heavy_quark_res) {
307 bool L2done = L2breakdown or convergenceL2(r2, heavy_quark_res,
stop,
param.tol_hq);
309 bool HQdone = (steps_since_reliable == 0 and
param.delta > 0) and convergenceHQ(r2, heavy_quark_res,
stop,
param.tol_hq);
310 converged = L2done and HQdone;
322 double gflops = (
blas::flops +
mat.flops() + matSloppy.flops())*1e-9;
324 param.gflops = gflops;
327 if (k==
param.maxiter)
331 printfQuda(
"CG: Reliable updates = %d\n", rUpdate);
336 param.true_res_hq =
sqrt(blas::HeavyQuarkResidualNorm(x,r).z);
338 PrintSummary(
"CG", k, r2, b2,
stop, inv.tol_hq);
348 if (&tmp3 != &
tmp)
delete tmp3_p;
349 if (&tmp2 != &
tmp)
delete tmp2_p;
351 if (rSloppy.Precision() != r.Precision())
delete r_sloppy;
352 if (xSloppy.Precision() != x.Precision())
delete x_sloppy;
void reduceDouble(double &)
void mat(void *out, void **link, void *in, int dagger_bit, int mu, QudaPrecision sPrecision, QudaPrecision gPrecision)
cudaColorSpinorField * tmp
@ QUDA_CUDA_FIELD_LOCATION
@ QUDA_HEAVY_QUARK_RESIDUAL
double xmyNorm(ColorSpinorField &x, ColorSpinorField &y)
void xpay(ColorSpinorField &x, double a, ColorSpinorField &y)
void zero(ColorSpinorField &a)
double norm2(const ColorSpinorField &a)
double axpyNorm(double a, ColorSpinorField &x, ColorSpinorField &y)
void axpy(double a, ColorSpinorField &x, ColorSpinorField &y)
void xpy(ColorSpinorField &x, ColorSpinorField &y)
void copy(ColorSpinorField &dst, const ColorSpinorField &src)
void stop()
Stop profiling.
std::complex< double > Complex
__host__ __device__ ValueType sqrt(ValueType x)
void updateR()
update the radius for halos.
QudaVerbosity getVerbosity()