52 std::vector<ColorSpinorField*> p;
53 std::vector<ColorSpinorField*> x;
73 double *alpha,
double *beta,
double *zeta,
double *zeta_old,
int j_low,
int n_shift) :
74 r(r), p(p), x(x), alpha(alpha), beta(beta), zeta(zeta), zeta_old(zeta_old), j_low(j_low),
75 n_shift(n_shift), n_update( (r->
Nspin()==4) ? 4 : 2 ) {
91 for (
int j= (count*n_shift)/n_update+1; j<=((count+1)*n_shift)/n_update && j<n_shift; j++) {
92 beta[j] = beta[j_low] * zeta[j] * alpha[j] / ( zeta_old[j] * alpha[j_low] );
97 int zero = (count*n_shift)/n_update+1;
98 std::vector<ColorSpinorField*> P, X;
99 for (
int j= (count*n_shift)/n_update+1; j<=((count+1)*n_shift)/n_update && j<n_shift; j++) {
100 beta[j] = beta[j_low] * zeta[j] * alpha[j] / ( zeta_old[j] * alpha[j_low] );
106 if (++count == n_update) count = 0;
127 const double *r2,
const double *beta,
const double pAp,
128 const double *offset,
const int nShift,
const int j_low) {
130 for (
int j=0; j<nShift; j++) alpha_old[j] = alpha[j];
132 alpha[0] = r2[0] / pAp;
134 for (
int j=1; j<nShift; j++) {
135 double c0 = zeta[j] * zeta_old[j] * alpha_old[j_low];
136 double c1 = alpha[j_low] * beta[j_low] * (zeta_old[j]-zeta[j]);
137 double c2 = zeta_old[j] * alpha_old[j_low] * (1.0+(offset[j]-offset[0])*alpha[j_low]);
139 zeta_old[j] = zeta[j];
141 zeta[j] = c0 / (c1 + c2);
147 alpha[j] = alpha[j_low] * zeta[j] / zeta_old[j];
165 if (num_offset == 0)
return;
171 printfQuda(
"Warning: inverting on zero-field source\n");
172 for(
int i=0; i<num_offset; ++i){
180 bool exit_early =
false;
188 const double fine_tol =
pow(10.,(-2*(
int)b.
Precision()+1));
189 std::unique_ptr<double[]> prec_tol(
new double[num_offset]);
191 prec_tol[0] = mixed ? sloppy_tol : fine_tol;
192 for (
int i=1; i<num_offset; i++) {
193 prec_tol[i] = std::min(sloppy_tol,std::max(fine_tol,
sqrt(
param.
tol_offset[i]*sloppy_tol)));
202 int num_offset_now = num_offset;
203 for (
int i=0; i<num_offset; i++) {
204 zeta[i] = zeta_old[i] = 1.0;
211 for (
int j=0; j<num_offset; j++)
216 std::vector<ColorSpinorField*> x_sloppy;
217 x_sloppy.resize(num_offset);
218 std::vector<ColorSpinorField*> y;
224 y.resize(num_offset);
240 for (
int i=0; i<num_offset; i++){
246 for (
int i=0; i<num_offset; i++)
250 p.resize(num_offset);
277 for (
int i=0; i<num_offset; i++) {
283 iter[num_offset] = 1;
292 for (
int i=0; i<num_offset; i++) {
293 rNorm[i] =
sqrt(r2[i]);
294 r0Norm[i] = rNorm[i];
308 for (
int i=0; i<num_offset; i++) {
309 resIncreaseTotal[i]=0;
316 bool aux_update =
false;
319 ShiftUpdate shift_update(r_sloppy, p, x_sloppy, alpha, beta, zeta, zeta_old, j_low, num_offset_now);
325 printfQuda(
"MultiShift CG: %d iterations, <r,r> = %e, |r|/|b| = %e\n", k, r2[0],
sqrt(r2[0]/b2));
343 for (
int j=1; j<num_offset_now; j++) r2_old_array[j] = zeta[j] * zeta[j] * r2[0];
344 updateAlphaZeta(alpha, zeta, zeta_old, r2, beta, pAp, offset, num_offset_now, j_low);
347 r2_old_array[0] = r2_old;
350 r2[0] = real(cg_norm);
351 double zn = imag(cg_norm);
354 rNorm[0] =
sqrt(r2[0]);
355 for (
int j=1; j<num_offset_now; j++) rNorm[j] = rNorm[0] * zeta[j];
360 int reliable_shift = -1;
361 for (
int j=0; j>=0; j--) {
362 if (rNorm[j] > maxrx[j]) maxrx[j] = rNorm[j];
363 if (rNorm[j] > maxrr[j]) maxrr[j] = rNorm[j];
364 updateX = (rNorm[j] < delta*r0Norm[j] && r0Norm[j] <= maxrx[j]) ? 1 : updateX;
365 updateR = ((rNorm[j] < delta*maxrr[j] && r0Norm[j] <= maxrr[j]) || updateX) ? 1 :
updateR;
366 if ((updateX ||
updateR) && reliable_shift == -1) reliable_shift = j;
371 beta[0] = zn / r2_old;
373 blas::axpyZpbx(alpha[0], *p[0], *x_sloppy[0], *r_sloppy, beta[0]);
385 for (
int j=0; j<num_offset_now; j++) {
390 mat(*r, *y[0], *x[0], tmp3);
391 if (r->Nspin()==4)
blas::axpy(offset[0], *y[0], *r);
394 for (
int j=1; j<num_offset_now; j++) r2[j] = zeta[j] * zeta[j] * r2[0];
395 for (
int j=0; j<num_offset_now; j++)
blas::zero(*x_sloppy[j]);
400 if (
sqrt(r2[reliable_shift]) > r0Norm[reliable_shift]) {
402 resIncreaseTotal[reliable_shift]++;
403 warningQuda(
"MultiShiftCG: Shift %d, updated residual %e is greater than previous residual %e (total #inc %i)",
404 reliable_shift,
sqrt(r2[reliable_shift]), r0Norm[reliable_shift], resIncreaseTotal[reliable_shift]);
406 if (resIncrease > maxResIncrease or resIncreaseTotal[reliable_shift] > maxResIncreaseTotal) {
407 warningQuda(
"MultiShiftCG: solver exiting due to too many true residual norm increases");
415 for (
int j=0; j<num_offset_now; j++) {
421 beta[0] = r2[0] / r2_old;
423 for (
int j=1; j<num_offset_now; j++) {
424 beta[j] = beta[j_low] * zeta[j] * alpha[j] / (zeta_old[j] * alpha[j_low]);
429 int m = reliable_shift;
430 rNorm[m] =
sqrt(r2[0]) * zeta[m];
433 r0Norm[m] = rNorm[m];
439 for (
int j=num_offset_now-1; j>=1; j--) {
440 if (zeta[j] == 0.0 && r2[j+1] <
stop[j+1]) {
443 printfQuda(
"MultiShift CG: Shift %d converged after %d iterations\n", j, k+1);
445 r2[j] = zeta[j] * zeta[j] * r2[0];
447 if ((r2[j] <
stop[j] ||
sqrt(r2[j] / b2) < prec_tol[j]) && iter[j+1] ) {
451 printfQuda(
"MultiShift CG: Shift %d converged after %d iterations\n", j, k+1);
455 num_offset_now -= converged;
467 printfQuda(
"Convergence of unshifted system so trigger shiftUpdate\n");
471 shift_update.
apply(0);
473 for (
int j=0; j<num_offset_now; j++) iter[j] = k+1;
479 printfQuda(
"MultiShift CG: %d iterations, <r,r> = %e, |r|/|b| = %e\n", k, r2[0],
sqrt(r2[0]/b2));
482 for (
int i=0; i<num_offset; i++) {
483 if (iter[i] == 0) iter[i] = k;
492 printfQuda(
"MultiShift CG: Reliable updates = %d\n", rUpdate);
508 for (
int i = 0; i < num_offset; i++) {
512 if ( (i > 0 and not mixed) or (i == 0 and not exit_early) ) {
513 mat(*r, *x[i], *tmp4_p, *tmp5_p);
514 if (r->Nspin() == 4) {
531 printfQuda(
"MultiShift CG: Converged after %d iterations\n", k);
532 for (
int i = 0; i < num_offset; i++) {
534 printfQuda(
" shift=%d, %d iterations, relative residual: iterated = %e\n",
537 printfQuda(
" shift=%d, %d iterations, relative residual: iterated = %e, true = %e\n",
543 if (tmp5_p != tmp4_p && tmp5_p != tmp2_p && (
reliable ? tmp5_p != y[1] : 1))
delete tmp5_p;
544 if (tmp4_p != &tmp1 && (
reliable ? tmp4_p != y[0] : 1))
delete tmp4_p;
548 printfQuda(
"MultiShift CG: Converged after %d iterations\n", k);
549 for (
int i = 0; i < num_offset; i++) {
551 printfQuda(
" shift=%d, %d iterations, relative residual: iterated = %e\n",
565 if (&tmp3 != &tmp1)
delete tmp3_p;
566 if (&tmp2 != &tmp1)
delete tmp2_p;
568 if (r_sloppy->
Precision() != r->Precision())
delete r_sloppy;
569 for (
int i=0; i<num_offset; i++)
570 if (x_sloppy[i]->Precision() != x[i]->Precision())
delete x_sloppy[i];
574 if (
reliable)
for (
int i=0; i<num_offset; i++)
delete y[i];
static ColorSpinorField * Create(const ColorSpinorParam &param)
bool isStaggered() const
Return whether the operator is a staggered operator.
unsigned long long flops() const
QudaPrecision Precision() const
MultiShiftCG(const DiracMatrix &mat, const DiracMatrix &matSloppy, SolverParam &param, TimeProfile &profile)
void operator()(std::vector< ColorSpinorField * >x, ColorSpinorField &b, std::vector< ColorSpinorField * > &p, double *r2_old_array)
Run the multi-shift solver and return the Krylov space at the end of the solve in p and r2_old_array.
const DiracMatrix & matSloppy
bool convergence(const double *r2, const double *r2_tol, int n) const
void updateNshift(int new_n_shift)
void updateNupdate(int new_n_update)
void apply(const qudaStream_t &stream)
ShiftUpdate(ColorSpinorField *r, std::vector< ColorSpinorField * > p, std::vector< ColorSpinorField * > x, double *alpha, double *beta, double *zeta, double *zeta_old, int j_low, int n_shift)
static double stopping(double tol, double b2, QudaResidualType residual_type)
Set the solver L2 stopping condition.
double Last(QudaProfileType idx)
void mat(void *out, void **link, void *in, int dagger_bit, int mu, QudaPrecision sPrecision, QudaPrecision gPrecision)
@ QUDA_CUDA_FIELD_LOCATION
#define checkLocation(...)
double axpyReDot(double a, ColorSpinorField &x, ColorSpinorField &y)
Complex axpyCGNorm(double a, ColorSpinorField &x, ColorSpinorField &y)
void axpyZpbx(double a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, double b)
double3 HeavyQuarkResidualNorm(ColorSpinorField &x, ColorSpinorField &r)
double xmyNorm(ColorSpinorField &x, ColorSpinorField &y)
void axpyBzpcx(double a, ColorSpinorField &x, ColorSpinorField &y, double b, ColorSpinorField &z, double c)
void xpay(ColorSpinorField &x, double a, ColorSpinorField &y)
void zero(ColorSpinorField &a)
double norm2(const ColorSpinorField &a)
double reDotProduct(ColorSpinorField &x, ColorSpinorField &y)
void axpy(double a, ColorSpinorField &x, ColorSpinorField &y)
void caxpy(const Complex &a, ColorSpinorField &x, ColorSpinorField &y)
void xpy(ColorSpinorField &x, ColorSpinorField &y)
void copy(ColorSpinorField &dst, const ColorSpinorField &src)
Complex cDotProduct(ColorSpinorField &, ColorSpinorField &)
void axpby(double a, ColorSpinorField &x, double b, ColorSpinorField &y)
void stop()
Stop profiling.
int reliable(double &rNorm, double &maxrx, double &maxrr, const double &r2, const double &delta)
__device__ __host__ void zero(double &a)
void updateAlphaZeta(double *alpha, double *zeta, double *zeta_old, const double *r2, const double *beta, const double pAp, const double *offset, const int nShift, const int j_low)
std::complex< double > Complex
__host__ __device__ ValueType sqrt(ValueType x)
__host__ __device__ ValueType pow(ValueType x, ExponentType e)
void updateR()
Update the radius for halos.
cudaStream_t qudaStream_t
#define QUDA_MAX_MULTI_SHIFT
Maximum number of shifts supported by the multi-shift solver. This number may be changed if need be.
double iter_res_offset[QUDA_MAX_MULTI_SHIFT]
bool use_sloppy_partial_accumulator
int max_res_increase_total
QudaResidualType residual_type
QudaPrecision precision_refinement_sloppy
double offset[QUDA_MAX_MULTI_SHIFT]
QudaPrecision precision_sloppy
double tol_offset[QUDA_MAX_MULTI_SHIFT]
double true_res_offset[QUDA_MAX_MULTI_SHIFT]
double true_res_hq_offset[QUDA_MAX_MULTI_SHIFT]
QudaVerbosity getVerbosity()