52 std::vector<ColorSpinorField*>
p;
53 std::vector<ColorSpinorField*>
x;
73 double *alpha,
double *beta,
double *zeta,
double *zeta_old,
int j_low,
int n_shift) :
74 r(r), p(p), x(x), alpha(alpha), beta(beta), zeta(zeta), zeta_old(zeta_old), j_low(j_low),
75 n_shift(n_shift), n_update( (r->
Nspin()==4) ? 4 : 2 ) {
90 for (
int j= (count*n_shift)/n_update+1; j<=((count+1)*n_shift)/n_update && j<
n_shift; j++) {
91 beta[j] = beta[
j_low] * zeta[j] * alpha[j] / ( zeta_old[j] * alpha[
j_low] );
97 std::vector<ColorSpinorField*> P,
X;
98 for (
int j= (count*n_shift)/n_update+1; j<=((count+1)*n_shift)/n_update && j<
n_shift; j++) {
99 beta[j] = beta[
j_low] * zeta[j] * alpha[j] / ( zeta_old[j] * alpha[
j_low] );
103 if (P.size())
blas::axpyBzpcx(&alpha[zero], P, X, &zeta[zero], *r, &beta[zero]);
105 if (++count == n_update) count = 0;
129 const double *r2,
const double *beta,
const double pAp,
130 const double *offset,
const int nShift,
const int j_low) {
132 for (
int j=0; j<nShift; j++) alpha_old[j] = alpha[j];
134 alpha[0] = r2[0] / pAp;
136 for (
int j=1; j<nShift; j++) {
137 double c0 = zeta[j] * zeta_old[j] * alpha_old[j_low];
138 double c1 = alpha[j_low] * beta[j_low] * (zeta_old[j]-zeta[j]);
139 double c2 = zeta_old[j] * alpha_old[j_low] * (1.0+(offset[j]-offset[0])*alpha[j_low]);
141 zeta_old[j] = zeta[j];
143 zeta[j] = c0 / (c1 + c2);
149 alpha[j] = alpha[j_low] * zeta[j] / zeta_old[j];
167 if (num_offset == 0)
return;
173 printfQuda(
"Warning: inverting on zero-field source\n");
174 for(
int i=0; i<num_offset; ++i){
182 bool exit_early =
false;
190 const double fine_tol =
pow(10.,(-2*(
int)b.
Precision()+1));
191 std::unique_ptr<double[]> prec_tol(
new double[num_offset]);
193 prec_tol[0] = mixed ? sloppy_tol : fine_tol;
194 for (
int i=1; i<num_offset; i++) {
195 prec_tol[i] = std::min(sloppy_tol,std::max(fine_tol,
sqrt(
param.
tol_offset[i]*sloppy_tol)));
204 int num_offset_now = num_offset;
205 for (
int i=0; i<num_offset; i++) {
206 zeta[i] = zeta_old[i] = 1.0;
213 for (
int j=0; j<num_offset; j++)
218 std::vector<ColorSpinorField*> x_sloppy;
219 x_sloppy.resize(num_offset);
220 std::vector<ColorSpinorField*> y;
226 y.resize(num_offset);
242 for (
int i=0; i<num_offset; i++){
248 for (
int i=0; i<num_offset; i++)
252 p.resize(num_offset);
279 for (
int i=0; i<num_offset; i++) {
285 iter[num_offset] = 1;
294 for (
int i=0; i<num_offset; i++) {
295 rNorm[i] =
sqrt(r2[i]);
296 r0Norm[i] = rNorm[i];
310 for (
int i=0; i<num_offset; i++) {
311 resIncreaseTotal[i]=0;
318 bool aux_update =
false;
321 ShiftUpdate shift_update(r_sloppy, p, x_sloppy, alpha, beta, zeta, zeta_old, j_low, num_offset_now);
327 printfQuda(
"MultiShift CG: %d iterations, <r,r> = %e, |r|/|b| = %e\n", k, r2[0],
sqrt(r2[0]/b2));
345 for (
int j=1; j<num_offset_now; j++) r2_old_array[j] = zeta[j] * zeta[j] * r2[0];
346 updateAlphaZeta(alpha, zeta, zeta_old, r2, beta, pAp, offset, num_offset_now, j_low);
349 r2_old_array[0] = r2_old;
352 r2[0] = real(cg_norm);
353 double zn = imag(cg_norm);
356 rNorm[0] =
sqrt(r2[0]);
357 for (
int j=1; j<num_offset_now; j++) rNorm[j] = rNorm[0] * zeta[j];
362 int reliable_shift = -1;
363 for (
int j=0; j>=0; j--) {
364 if (rNorm[j] > maxrx[j]) maxrx[j] = rNorm[j];
365 if (rNorm[j] > maxrr[j]) maxrr[j] = rNorm[j];
366 updateX = (rNorm[j] < delta*r0Norm[j] && r0Norm[j] <= maxrx[j]) ? 1 : updateX;
367 updateR = ((rNorm[j] < delta*maxrr[j] && r0Norm[j] <= maxrr[j]) || updateX) ? 1 :
updateR;
368 if ((updateX ||
updateR) && reliable_shift == -1) reliable_shift = j;
373 beta[0] = zn / r2_old;
375 blas::axpyZpbx(alpha[0], *p[0], *x_sloppy[0], *r_sloppy, beta[0]);
387 for (
int j=0; j<num_offset_now; j++) {
392 mat(*r, *y[0], *x[0], tmp3);
393 if (r->Nspin()==4)
blas::axpy(offset[0], *y[0], *r);
396 for (
int j=1; j<num_offset_now; j++) r2[j] = zeta[j] * zeta[j] * r2[0];
397 for (
int j=0; j<num_offset_now; j++)
blas::zero(*x_sloppy[j]);
402 if (
sqrt(r2[reliable_shift]) > r0Norm[reliable_shift]) {
404 resIncreaseTotal[reliable_shift]++;
405 warningQuda(
"MultiShiftCG: Shift %d, updated residual %e is greater than previous residual %e (total #inc %i)",
406 reliable_shift,
sqrt(r2[reliable_shift]), r0Norm[reliable_shift], resIncreaseTotal[reliable_shift]);
408 if (resIncrease > maxResIncrease or resIncreaseTotal[reliable_shift] > maxResIncreaseTotal) {
409 warningQuda(
"MultiShiftCG: solver exiting due to too many true residual norm increases");
417 for (
int j=0; j<num_offset_now; j++) {
423 beta[0] = r2[0] / r2_old;
425 for (
int j=1; j<num_offset_now; j++) {
426 beta[j] = beta[j_low] * zeta[j] * alpha[j] / (zeta_old[j] * alpha[j_low]);
431 int m = reliable_shift;
432 rNorm[m] =
sqrt(r2[0]) * zeta[m];
435 r0Norm[m] = rNorm[m];
441 for (
int j=num_offset_now-1; j>=1; j--) {
442 if (zeta[j] == 0.0 && r2[j+1] < stop[j+1]) {
445 printfQuda(
"MultiShift CG: Shift %d converged after %d iterations\n", j, k+1);
447 r2[j] = zeta[j] * zeta[j] * r2[0];
449 if ((r2[j] < stop[j] ||
sqrt(r2[j] / b2) < prec_tol[j]) && iter[j+1] ) {
453 printfQuda(
"MultiShift CG: Shift %d converged after %d iterations\n", j, k+1);
457 num_offset_now -= converged;
469 printfQuda(
"Convergence of unshifted system so trigger shiftUpdate\n");
473 shift_update.
apply(0);
475 for (
int j=0; j<num_offset_now; j++) iter[j] = k+1;
481 printfQuda(
"MultiShift CG: %d iterations, <r,r> = %e, |r|/|b| = %e\n", k, r2[0],
sqrt(r2[0]/b2));
484 for (
int i=0; i<num_offset; i++) {
485 if (iter[i] == 0) iter[i] = k;
494 printfQuda(
"MultiShift CG: Reliable updates = %d\n", rUpdate);
510 for (
int i = 0; i < num_offset; i++) {
514 if ( (i > 0 and not mixed) or (i == 0 and not exit_early) ) {
515 mat(*r, *x[i], *tmp4_p, *tmp5_p);
516 if (r->Nspin() == 4) {
533 printfQuda(
"MultiShift CG: Converged after %d iterations\n", k);
534 for (
int i = 0; i < num_offset; i++) {
536 printfQuda(
" shift=%d, %d iterations, relative residual: iterated = %e\n",
539 printfQuda(
" shift=%d, %d iterations, relative residual: iterated = %e, true = %e\n",
545 if (tmp5_p != tmp4_p && tmp5_p != tmp2_p && (reliable ? tmp5_p != y[1] : 1))
delete tmp5_p;
546 if (tmp4_p != &tmp1 && (reliable ? tmp4_p != y[0] : 1))
delete tmp4_p;
550 printfQuda(
"MultiShift CG: Converged after %d iterations\n", k);
551 for (
int i = 0; i < num_offset; i++) {
553 printfQuda(
" shift=%d, %d iterations, relative residual: iterated = %e\n",
567 if (&tmp3 != &tmp1)
delete tmp3_p;
568 if (&tmp2 != &tmp1)
delete tmp2_p;
570 if (r_sloppy->
Precision() != r->Precision())
delete r_sloppy;
571 for (
int i=0; i<num_offset; i++)
572 if (x_sloppy[i]->Precision() != x[i]->Precision())
delete x_sloppy[i];
576 if (reliable)
for (
int i=0; i<num_offset; i++)
delete y[i];
cudaColorSpinorField * tmp2
double iter_res_offset[QUDA_MAX_MULTI_SHIFT]
void updateNshift(int new_n_shift)
void setPrecision(QudaPrecision precision, QudaPrecision ghost_precision=QUDA_INVALID_PRECISION, bool force_native=false)
cudaColorSpinorField * tmp1
std::vector< ColorSpinorField * > p
void axpyZpbx(double a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, double b)
#define QUDA_MAX_MULTI_SHIFT
Maximum number of shifts supported by the multi-shift solver. This number may be changed if need be...
QudaVerbosity getVerbosity()
double norm2(const ColorSpinorField &a)
__host__ __device__ ValueType sqrt(ValueType x)
Complex cDotProduct(ColorSpinorField &, ColorSpinorField &)
QudaPrecision precision_refinement_sloppy
static ColorSpinorField * Create(const ColorSpinorParam ¶m)
void updateAlphaZeta(double *alpha, double *zeta, double *zeta_old, const double *r2, const double *beta, const double pAp, const double *offset, const int nShift, const int j_low)
double reDotProduct(ColorSpinorField &x, ColorSpinorField &y)
Complex axpyCGNorm(double a, ColorSpinorField &x, ColorSpinorField &y)
double tol_offset[QUDA_MAX_MULTI_SHIFT]
double offset[QUDA_MAX_MULTI_SHIFT]
void copy(ColorSpinorField &dst, const ColorSpinorField &src)
double xmyNorm(ColorSpinorField &x, ColorSpinorField &y)
int max_res_increase_total
void xpay(ColorSpinorField &x, double a, ColorSpinorField &y)
bool convergence(const double *r2, const double *r2_tol, int n) const
double Last(QudaProfileType idx)
void updateNupdate(int new_n_update)
QudaResidualType residual_type
const DiracMatrix & matSloppy
double true_res_hq_offset[QUDA_MAX_MULTI_SHIFT]
void apply(const cudaStream_t &stream)
static double stopping(double tol, double b2, QudaResidualType residual_type)
Set the solver L2 stopping condition.
void axpy(double a, ColorSpinorField &x, ColorSpinorField &y)
#define checkLocation(...)
__host__ __device__ ValueType pow(ValueType x, ExponentType e)
void axpyBzpcx(double a, ColorSpinorField &x, ColorSpinorField &y, double b, ColorSpinorField &z, double c)
std::vector< ColorSpinorField * > x
void operator()(std::vector< ColorSpinorField *>x, ColorSpinorField &b, std::vector< ColorSpinorField *> &p, double *r2_old_array)
Run multi-shift and return Krylov-space at the end of the solve in p and r2_old_arry.
double3 HeavyQuarkResidualNorm(ColorSpinorField &x, ColorSpinorField &r)
double true_res_offset[QUDA_MAX_MULTI_SHIFT]
std::complex< double > Complex
double axpyReDot(double a, ColorSpinorField &x, ColorSpinorField &y)
void caxpy(const Complex &a, ColorSpinorField &x, ColorSpinorField &y)
void zero(ColorSpinorField &a)
unsigned long long flops() const
MultiShiftCG(DiracMatrix &mat, DiracMatrix &matSloppy, SolverParam ¶m, TimeProfile &profile)
int reliable(double &rNorm, double &maxrx, double &maxrr, const double &r2, const double &delta)
void xpy(ColorSpinorField &x, ColorSpinorField &y)
void axpby(double a, ColorSpinorField &x, double b, ColorSpinorField &y)
QudaPrecision precision_sloppy
bool use_sloppy_partial_accumulator
void mat(void *out, void **link, void *in, int dagger_bit, int mu, QudaPrecision sPrecision, QudaPrecision gPrecision)
QudaPrecision Precision() const
__device__ unsigned int count[QUDA_MAX_MULTI_REDUCE]
__device__ __host__ void zero(vector_type< scalar, n > &v)
ShiftUpdate(ColorSpinorField *r, std::vector< ColorSpinorField *> p, std::vector< ColorSpinorField *> x, double *alpha, double *beta, double *zeta, double *zeta_old, int j_low, int n_shift)
void updateR()
update the radius for halos.