51 std::vector<ColorSpinorField*>
p;
52 std::vector<ColorSpinorField*>
x;
96 std::vector<ColorSpinorField*> P,
X;
128 const double *r2,
const double *beta,
const double pAp,
129 const double *
offset,
const int nShift,
const int j_low) {
131 for (
int j=0; j<nShift; j++) alpha_old[j] = alpha[j];
133 alpha[0] = r2[0] / pAp;
135 for (
int j=1; j<nShift; j++) {
136 double c0 = zeta[j] * zeta_old[j] * alpha_old[j_low];
137 double c1 = alpha[j_low] * beta[j_low] * (zeta_old[j]-zeta[j]);
138 double c2 = zeta_old[j] * alpha_old[j_low] * (1.0+(
offset[j]-
offset[0])*alpha[j_low]);
140 zeta_old[j] = zeta[j];
142 zeta[j] = c0 / (c1 + c2);
148 alpha[j] = alpha[j_low] * zeta[j] / zeta_old[j];
166 if (num_offset == 0)
return;
172 printfQuda(
"Warning: inverting on zero-field source\n");
173 for(
int i=0;
i<num_offset; ++
i){
182 const double prec_tol =
pow(10.,(-2*(
int)
b.Precision()+1));
184 double *zeta =
new double[num_offset];
185 double *zeta_old =
new double[num_offset];
186 double *alpha =
new double[num_offset];
187 double *beta =
new double[num_offset];
190 int num_offset_now = num_offset;
191 for (
int i=0;
i<num_offset;
i++) {
192 zeta[
i] = zeta_old[
i] = 1.0;
199 for (
int j=0; j<num_offset; j++)
204 std::vector<ColorSpinorField*> x_sloppy;
205 x_sloppy.resize(num_offset);
206 std::vector<ColorSpinorField*>
y;
212 y.resize(num_offset);
228 for (
int i=0;
i<num_offset;
i++){
234 for (
int i=0;
i<num_offset;
i++)
238 std::vector<ColorSpinorField*>
p;
239 p.resize(num_offset);
266 for (
int i=0;
i<num_offset;
i++) {
272 iter[num_offset] = 1;
281 for (
int i=0;
i<num_offset;
i++) {
283 r0Norm[
i] = rNorm[
i];
297 for (
int i=0;
i<num_offset;
i++) {
298 resIncreaseTotal[
i]=0;
305 bool aux_update =
false;
308 ShiftUpdate shift_update(r_sloppy,
p, x_sloppy, alpha, beta, zeta, zeta_old, j_low, num_offset_now);
314 printfQuda(
"MultiShift CG: %d iterations, <r,r> = %e, |r|/|b| = %e\n", k, r2[0],
sqrt(r2[0]/b2));
336 r2[0] = real(cg_norm);
337 double zn = imag(cg_norm);
340 rNorm[0] =
sqrt(r2[0]);
341 for (
int j=1; j<num_offset_now; j++) rNorm[j] = rNorm[0] * zeta[j];
346 int reliable_shift = -1;
347 for (
int j=0; j>=0; j--) {
348 if (rNorm[j] > maxrx[j]) maxrx[j] = rNorm[j];
349 if (rNorm[j] > maxrr[j]) maxrr[j] = rNorm[j];
350 updateX = (rNorm[j] <
delta*r0Norm[j] && r0Norm[j] <= maxrx[j]) ? 1 : updateX;
352 if ((updateX ||
updateR) && reliable_shift == -1) reliable_shift = j;
357 beta[0] = zn / r2_old;
371 for (
int j=0; j<num_offset_now; j++) {
380 for (
int j=1; j<num_offset_now; j++) r2[j] = zeta[j] * zeta[j] * r2[0];
381 for (
int j=0; j<num_offset_now; j++)
blas::zero(*x_sloppy[j]);
386 if (
sqrt(r2[reliable_shift]) > r0Norm[reliable_shift]) {
388 resIncreaseTotal[reliable_shift]++;
389 warningQuda(
"MultiShiftCG: Shift %d, updated residual %e is greater than previous residual %e (total #inc %i)",
390 reliable_shift,
sqrt(r2[reliable_shift]), r0Norm[reliable_shift], resIncreaseTotal[reliable_shift]);
392 if (resIncrease > maxResIncrease or resIncreaseTotal[reliable_shift] > maxResIncreaseTotal) {
393 warningQuda(
"MultiShiftCG: solver exiting due to too many true residual norm increases");
401 for (
int j=0; j<num_offset_now; j++) {
407 beta[0] = r2[0] / r2_old;
409 for (
int j=1; j<num_offset_now; j++) {
410 beta[j] = beta[j_low] * zeta[j] * alpha[j] / (zeta_old[j] * alpha[j_low]);
415 int m = reliable_shift;
416 rNorm[m] =
sqrt(r2[0]) * zeta[m];
419 r0Norm[m] = rNorm[m];
425 for (
int j=num_offset_now-1; j>=1; j--) {
426 if (zeta[j] == 0.0 && r2[j+1] < stop[j+1]) {
429 printfQuda(
"MultiShift CG: Shift %d converged after %d iterations\n", j, k+1);
431 r2[j] = zeta[j] * zeta[j] * r2[0];
433 if ((r2[j] < stop[j] ||
sqrt(r2[j] / b2) < prec_tol) && iter[j+1] ) {
437 printfQuda(
"MultiShift CG: Shift %d converged after %d iterations\n", j, k+1);
441 num_offset_now -= converged;
447 printfQuda(
"Convergence of unshifted system so trigger shiftUpdate\n");
451 shift_update.
apply(0);
453 for (
int j=0; j<num_offset_now; j++) iter[j] = k+1;
459 printfQuda(
"MultiShift CG: %d iterations, <r,r> = %e, |r|/|b| = %e\n", k, r2[0],
sqrt(r2[0]/b2));
462 for (
int i=0;
i<num_offset;
i++) {
463 if (iter[
i] == 0) iter[
i] = k;
472 printfQuda(
"MultiShift CG: Reliable updates = %d\n", rUpdate);
488 for(
int i=0;
i < num_offset;
i++) {
489 mat(*r, *
x[
i], *tmp4_p, *tmp5_p);
502 printfQuda(
"MultiShift CG: Converged after %d iterations\n", k);
503 for(
int i=0;
i < num_offset;
i++) {
504 printfQuda(
" shift=%d, %d iterations, relative residual: iterated = %e, true = %e\n",
509 if (tmp5_p != tmp4_p && tmp5_p != tmp2_p && (
reliable ? tmp5_p !=
y[1] : 1))
delete tmp5_p;
510 if (tmp4_p != &
tmp1 && (
reliable ? tmp4_p !=
y[0] : 1))
delete tmp4_p;
513 printfQuda(
"MultiShift CG: Converged after %d iterations\n", k);
514 for(
int i=0;
i < num_offset;
i++) {
516 printfQuda(
" shift=%d, %d iterations, relative residual: iterated = %e\n",
534 for (
int i=0;
i<num_offset;
i++)
535 if (x_sloppy[
i]->Precision() !=
x[
i]->Precision())
delete x_sloppy[
i];
538 for (
int i=0;
i<num_offset;
i++)
delete p[
i];
double iter_res_offset[QUDA_MAX_MULTI_SHIFT]
void updateNshift(int new_n_shift)
void xpay(ColorSpinorField &x, const double &a, ColorSpinorField &y)
std::vector< ColorSpinorField * > p
static double stopping(const double &tol, const double &b2, QudaResidualType residual_type)
#define QUDA_MAX_MULTI_SHIFT
Maximum number of shifts supported by the multi-shift solver. This number may be changed if need be...
QudaVerbosity getVerbosity()
double norm2(const ColorSpinorField &a)
__host__ __device__ ValueType sqrt(ValueType x)
Complex cDotProduct(ColorSpinorField &, ColorSpinorField &)
std::complex< double > Complex
static ColorSpinorField * Create(const ColorSpinorParam ¶m)
void updateAlphaZeta(double *alpha, double *zeta, double *zeta_old, const double *r2, const double *beta, const double pAp, const double *offset, const int nShift, const int j_low)
double reDotProduct(ColorSpinorField &x, ColorSpinorField &y)
double tol_offset[QUDA_MAX_MULTI_SHIFT]
double offset[QUDA_MAX_MULTI_SHIFT]
void copy(ColorSpinorField &dst, const ColorSpinorField &src)
double xmyNorm(ColorSpinorField &x, ColorSpinorField &y)
int max_res_increase_total
bool convergence(const double *r2, const double *r2_tol, int n) const
double Last(QudaProfileType idx)
static unsigned int delta
void updateNupdate(int new_n_update)
QudaResidualType residual_type
void axpyZpbx(const double &a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, const double &b)
const DiracMatrix & matSloppy
double true_res_hq_offset[QUDA_MAX_MULTI_SHIFT]
void apply(const cudaStream_t &stream)
static __inline__ size_t p
Complex axpyCGNorm(const double &a, ColorSpinorField &x, ColorSpinorField &y)
#define checkLocation(...)
__host__ __device__ ValueType pow(ValueType x, ExponentType e)
std::vector< ColorSpinorField * > x
double3 HeavyQuarkResidualNorm(ColorSpinorField &x, ColorSpinorField &r)
double true_res_offset[QUDA_MAX_MULTI_SHIFT]
void caxpy(const Complex &a, ColorSpinorField &x, ColorSpinorField &y)
void zero(ColorSpinorField &a)
void axpy(const double &a, ColorSpinorField &x, ColorSpinorField &y)
void axpby(const double &a, ColorSpinorField &x, const double &b, ColorSpinorField &y)
void operator()(std::vector< ColorSpinorField *> out, ColorSpinorField &in)
double axpyReDot(const double &a, ColorSpinorField &x, ColorSpinorField &y)
void axpyBzpcx(const double &a, ColorSpinorField &x, ColorSpinorField &y, const double &b, ColorSpinorField &z, const double &c)
unsigned long long flops() const
MultiShiftCG(DiracMatrix &mat, DiracMatrix &matSloppy, SolverParam ¶m, TimeProfile &profile)
int reliable(double &rNorm, double &maxrx, double &maxrr, const double &r2, const double &delta)
void xpy(ColorSpinorField &x, ColorSpinorField &y)
QudaPrecision precision_sloppy
bool use_sloppy_partial_accumulator
void mat(void *out, void **link, void *in, int dagger_bit, int mu, QudaPrecision sPrecision, QudaPrecision gPrecision)
QudaPrecision Precision() const
__device__ unsigned int count[QUDA_MAX_MULTI_REDUCE]
__device__ __host__ void zero(vector_type< scalar, n > &v)
ShiftUpdate(ColorSpinorField *r, std::vector< ColorSpinorField *> p, std::vector< ColorSpinorField *> x, double *alpha, double *beta, double *zeta, double *zeta_old, int j_low, int n_shift)
void updateR()
update the radius for halos.