39 const double *r2,
const double *beta,
const double pAp,
40 const double *offset,
const int nShift,
const int j_low) {
42 for (
int j=0; j<nShift; j++) alpha_old[j] = alpha[j];
44 alpha[0] = r2[0] / pAp;
46 for (
int j=1; j<nShift; j++) {
47 double c0 = zeta[j] * zeta_old[j] * alpha_old[j_low];
48 double c1 = alpha[j_low] * beta[j_low] * (zeta_old[j]-zeta[j]);
49 double c2 = zeta_old[j] * alpha_old[j_low] * (1.0+(offset[j]-offset[0])*alpha[j_low]);
51 zeta_old[j] = zeta[j];
52 zeta[j] = c0 / (c1 + c2);
53 alpha[j] = alpha[j_low] * zeta[j] / zeta_old[j];
64 if (num_offset == 0)
return;
70 printfQuda(
"Warning: inverting on zero-field source\n");
71 for(
int i=0; i<num_offset; ++i){
80 double *zeta =
new double[num_offset];
81 double *zeta_old =
new double[num_offset];
82 double *alpha =
new double[num_offset];
83 double *beta =
new double[num_offset];
86 int num_offset_now = num_offset;
87 for (
int i=0; i<num_offset; i++) {
88 zeta[i] = zeta_old[i] = 1.0;
94 bool reliable =
false;
95 for (
int j=0; j<num_offset; j++)
112 for (
int i=0; i<num_offset; i++){
118 for (
int i=0; i<num_offset; i++)
146 for (
int i=0; i<num_offset; i++) {
158 for (
int i=0; i<num_offset; i++) {
159 rNorm[i] = sqrt(r2[i]);
160 r0Norm[i] = rNorm[i];
174 printfQuda(
"MultiShift CG: %d iterations, <r,r> = %e, |r|/|b| = %e\n", k, r2[0], sqrt(r2[0]/b2));
183 updateAlphaZeta(alpha, zeta, zeta_old, r2, beta, pAp, offset, num_offset_now, j_low);
187 r2[0] = real(cg_norm);
188 double zn = imag(cg_norm);
191 rNorm[0] = sqrt(r2[0]);
192 for (
int j=1; j<num_offset_now; j++) rNorm[j] = rNorm[0] * zeta[j];
194 int updateX=0, updateR=0;
195 int reliable_shift = -1;
196 for (
int j=num_offset_now-1; j>=0; j--) {
197 if (rNorm[j] > maxrx[j]) maxrx[j] = rNorm[j];
198 if (rNorm[j] > maxrr[j]) maxrr[j] = rNorm[j];
199 updateX = (rNorm[j] < delta*r0Norm[j] && r0Norm[j] <= maxrx[j]) ? 1 : updateX;
200 updateR = ((rNorm[j] < delta*maxrr[j] && r0Norm[j] <= maxrr[j]) || updateX) ? 1 : updateR;
201 if ((updateX || updateR) && reliable_shift == -1) reliable_shift = j;
204 if ( !(updateR || updateX) || !reliable) {
206 beta[0] = zn / r2_old;
208 axpyZpbxCuda(alpha[0], *p[0], *x_sloppy[0], *r_sloppy, beta[0]);
210 for (
int j=1; j<num_offset_now; j++) {
211 beta[j] = beta[j_low] * zeta[j] * alpha[j] / (zeta_old[j] * alpha[j_low]);
213 axpyBzpcxCuda(alpha[j], *p[j], *x_sloppy[j], zeta[j], *r_sloppy, beta[j]);
216 for (
int j=0; j<num_offset_now; j++) {
217 axpyCuda(alpha[j], *p[j], *x_sloppy[j]);
222 mat(*r, *y[0], *x[0]);
226 for (
int j=1; j<num_offset_now; j++) r2[j] = zeta[j] * zeta[j] * r2[0];
227 for (
int j=0; j<num_offset_now; j++)
zeroCuda(*x_sloppy[j]);
232 if (sqrt(r2[reliable_shift]) > r0Norm[reliable_shift]) {
233 warningQuda(
"MultiShiftCG: Shift %d, updated residual %e is greater than previous residual %e",
234 reliable_shift, sqrt(r2[reliable_shift]), r0Norm[reliable_shift]);
237 if (reliable_shift == j_low)
break;
241 beta[0] = r2[0] / r2_old;
242 xpayCuda(*r_sloppy, beta[0], *p[0]);
243 for (
int j=1; j<num_offset_now; j++) {
244 beta[j] = beta[j_low] * zeta[j] * alpha[j] / (zeta_old[j] * alpha[j_low]);
245 axpbyCuda(zeta[j], *r_sloppy, beta[j], *p[j]);
249 int m = reliable_shift;
250 rNorm[m] = sqrt(r2[0]) * zeta[m];
253 r0Norm[m] = rNorm[m];
258 for (
int j=1; j<num_offset_now; j++) {
259 r2[j] = zeta[j] * zeta[j] * r2[0];
260 if (r2[j] < stop[j]) {
262 printfQuda(
"MultiShift CG: Shift %d converged after %d iterations\n", j, k+1);
270 printfQuda(
"MultiShift CG: %d iterations, <r,r> = %e, |r|/|b| = %e\n", k, r2[0], sqrt(r2[0]/b2));
274 for (
int i=0; i<num_offset; i++) {
276 if (reliable)
xpyCuda(*y[i], *x[i]);
290 for(
int i=0; i < num_offset; i++) {
295 axpyCuda(offset[i]-offset[0], *x[i], *r);
299 #if (__COMPUTE_CAPABILITY__ >= 200)
307 printfQuda(
"MultiShift CG: Converged after %d iterations\n", k);
308 for(
int i=0; i < num_offset; i++) {
309 printfQuda(
" shift=%d, relative residua: iterated = %e, true = %e\n",
322 if (&tmp2 != &tmp1)
delete tmp2_p;
325 for (
int i=0; i<num_offset; i++)
delete p[i];
329 for (
int i=0; i<num_offset; i++)
delete y[i];
336 for (
int i=0; i<num_offset; i++)
delete x_sloppy[i];