QUDA  1.0.0
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
inv_multi_cg_quda.cpp
Go to the documentation of this file.
1 #include <cstdio>
2 #include <cstdlib>
3 #include <cmath>
4 #include <limits>
5 
6 #include <quda_internal.h>
7 #include <color_spinor_field.h>
8 #include <blas_quda.h>
9 #include <dslash_quda.h>
10 #include <invert_quda.h>
11 #include <util_quda.h>
12 
23 #include <worker.h>
24 
25 namespace quda {
26 
49  class ShiftUpdate : public Worker {
50 
52  std::vector<ColorSpinorField*> p;
53  std::vector<ColorSpinorField*> x;
54 
55  double *alpha;
56  double *beta;
57  double *zeta;
58  double *zeta_old;
59 
60  const int j_low;
61  int n_shift;
62 
69  int n_update;
70 
71  public:
72  ShiftUpdate(ColorSpinorField *r, std::vector<ColorSpinorField*> p, std::vector<ColorSpinorField*> x,
73  double *alpha, double *beta, double *zeta, double *zeta_old, int j_low, int n_shift) :
74  r(r), p(p), x(x), alpha(alpha), beta(beta), zeta(zeta), zeta_old(zeta_old), j_low(j_low),
75  n_shift(n_shift), n_update( (r->Nspin()==4) ? 4 : 2 ) {
76 
77  }
78  virtual ~ShiftUpdate() { }
79 
80  void updateNshift(int new_n_shift) { n_shift = new_n_shift; }
81  void updateNupdate(int new_n_update) { n_update = new_n_update; }
82 
83  // note that we can't set the stream parameter here so it is
84  // ignored. This is more of a future design direction to consider
85  void apply(const cudaStream_t &stream) {
86  static int count = 0;
87 
88 #if 0
89  // on the first call do the first half of the update
90  for (int j= (count*n_shift)/n_update+1; j<=((count+1)*n_shift)/n_update && j<n_shift; j++) {
91  beta[j] = beta[j_low] * zeta[j] * alpha[j] / ( zeta_old[j] * alpha[j_low] );
92  // update p[i] and x[i]
93  blas::axpyBzpcx(alpha[j], *(p[j]), *(x[j]), zeta[j], *r, beta[j]);
94  }
95 #else
96  int zero = (count*n_shift)/n_update+1;
97  std::vector<ColorSpinorField*> P, X;
98  for (int j= (count*n_shift)/n_update+1; j<=((count+1)*n_shift)/n_update && j<n_shift; j++) {
99  beta[j] = beta[j_low] * zeta[j] * alpha[j] / ( zeta_old[j] * alpha[j_low] );
100  P.push_back(p[j]);
101  X.push_back(x[j]);
102  }
103  if (P.size()) blas::axpyBzpcx(&alpha[zero], P, X, &zeta[zero], *r, &beta[zero]);
104 #endif
105  if (++count == n_update) count = 0;
106  }
107 
108  };
109 
110  // Worker hook consulted by the dslash (defined in dslash_quda.cu). MultiShiftCG::operator()
// points it at its ShiftUpdate immediately before a sloppy matrix application and resets it to
// nullptr immediately afterwards, so the deferred shifted-system updates are launched from
// within that dslash call.
111  namespace dslash {
112  extern Worker* aux_worker;
113  }
114 
116  TimeProfile &profile)
117  : MultiShiftSolver(param, profile), mat(mat), matSloppy(matSloppy) {
118 
119  }
120 
122 
123  }
124 
/**
   @brief Advance the multi-shift CG step lengths (alpha) and the shift
   scale factors (zeta) by one iteration.

   The unshifted step is alpha[0] = r2[0] / pAp, with zeta[0] = 1. For each
   shifted system j = 1 .. nShift-1, zeta[j] is obtained from the three-term
   recurrence built from the current zeta[j], the previous zeta_old[j], and
   the reference shift's previous/current alpha and beta. Then
   alpha[j] = alpha[j_low] * zeta[j] / zeta_old[j]. A vanishing denominator
   or a vanishing zeta[j] sets the corresponding value to 0 instead of
   dividing.

   @param alpha     In: previous step lengths. Out: new step lengths
                    (entries 0 .. nShift-1).
   @param zeta      In: current scale factors. Out: next scale factors.
   @param zeta_old  In: previous scale factors. Out: the input zeta values
                    for j >= 1.
   @param r2        Squared residual norms; only r2[0] is read.
   @param beta      Previous beta coefficients; only beta[j_low] is read.
   @param pAp       <p, A p> of the unshifted system for this iteration.
   @param offset    Shift values; offset[j] - offset[0] enters the recurrence.
   @param nShift    Number of active shifts.
   @param j_low     Reference shift (the caller in this file always passes 0).
*/
void updateAlphaZeta(double *alpha, double *zeta, double *zeta_old,
                     const double *r2, const double *beta, const double pAp,
                     const double *offset, const int nShift, const int j_low) {
  // Only the reference shift's previous alpha is needed by the recurrence.
  // Saving that single value avoids a fixed-size QUDA_MAX_MULTI_SHIFT
  // scratch array that would overflow for nShift > QUDA_MAX_MULTI_SHIFT.
  const double alpha_old_low = alpha[j_low];

  alpha[0] = r2[0] / pAp;
  zeta[0] = 1.0;
  for (int j = 1; j < nShift; j++) {
    const double c0 = zeta[j] * zeta_old[j] * alpha_old_low;
    const double c1 = alpha[j_low] * beta[j_low] * (zeta_old[j] - zeta[j]);
    const double c2 = zeta_old[j] * alpha_old_low * (1.0 + (offset[j] - offset[0]) * alpha[j_low]);
    const double denom = c1 + c2;

    // shift the history: the current zeta becomes the previous one
    zeta_old[j] = zeta[j];
    zeta[j] = (denom != 0.0) ? c0 / denom : 0.0;
    alpha[j] = (zeta[j] != 0.0) ? alpha[j_low] * zeta[j] / zeta_old[j] : 0.0;
  }
}
156 
/**
   @brief Multi-shift CG: solves the shifted systems for all param.num_offset
   shifts (param.offset[i]) from one shared Krylov space. It supports optional
   reliable updates (when some tol_offset < param.delta) and a sloppy
   precision for the iteration.

   For Nspin == 4 fields, offset[i] is added to mat explicitly (pAp, the
   reliable-update residual and the true residual). For other fields, only
   offset[i] - offset[0] is added when forming the true residual, i.e.
   offset[0] is treated as already contained in mat.

   Shifts j >= 1 are updated lazily via ShiftUpdate, launched from inside the
   next sloppy dslash (dslash::aux_worker) or flushed explicitly when the
   solve stops. Shifts are retired from the top down as they converge. If
   mixed/refinement precision is in use and compute_true_res is set, the loop
   exits early once only shift 0 remains, leaving shift 0 to be refined
   elsewhere.

   @param x            Per-shift solution fields (zeroed, then overwritten with the result).
   @param b            Source field.
   @param p            Output: resized to num_offset and filled with newly allocated
                       search-direction fields. They are not freed here, presumably the
                       caller owns them -- confirm.
   @param r2_old_array Output: each iteration, [0] receives the previous |r|^2 of shift 0
                       and [j] receives zeta[j]^2 times that value. It is assumed to hold at
                       least num_offset entries -- confirm with callers.

   Results are reported through param (iter, gflops, true_res_offset,
   iter_res_offset, true_res_hq_offset).
*/
157  void MultiShiftCG::operator()(std::vector<ColorSpinorField*>x, ColorSpinorField &b, std::vector<ColorSpinorField*> &p, double* r2_old_array )
158 {
159  if (checkLocation(*(x[0]), b) != QUDA_CUDA_FIELD_LOCATION)
160  errorQuda("Not supported");
161 
162  profile.TPSTART(QUDA_PROFILE_INIT);
163 
164  int num_offset = param.num_offset;
165  double *offset = param.offset;
166 
    // NOTE(review): this early return leaves the QUDA_PROFILE_INIT timer started above running
    // (the zero-source path below stops it before returning).
167  if (num_offset == 0) return;
168 
169  const double b2 = blas::norm2(b);
170  // Check to see that we're not trying to invert on a zero-field source
171  if(b2 == 0){
172  profile.TPSTOP(QUDA_PROFILE_INIT);
173  printfQuda("Warning: inverting on zero-field source\n");
174  for(int i=0; i<num_offset; ++i){
175  *(x[i]) = b;
176  param.true_res_offset[i] = 0.0;
177  param.true_res_hq_offset[i] = 0.0;
178  }
179  return;
180  }
181 
182  bool exit_early = false;
183  bool mixed = param.precision_sloppy != param.precision;
184  // whether we will switch to refinement on unshifted system after other shifts have converged
185  bool zero_refinement = param.precision_refinement_sloppy != param.precision;
186 
187  // this is the limit of precision possible
188  const double sloppy_tol= param.precision_sloppy == 8 ? std::numeric_limits<double>::epsilon() :
190  const double fine_tol = pow(10.,(-2*(int)b.Precision()+1));
191  std::unique_ptr<double[]> prec_tol(new double[num_offset]);
192 
    // per-shift floor on the achievable relative residual (used as an extra convergence test below)
193  prec_tol[0] = mixed ? sloppy_tol : fine_tol;
194  for (int i=1; i<num_offset; i++) {
195  prec_tol[i] = std::min(sloppy_tol,std::max(fine_tol,sqrt(param.tol_offset[i]*sloppy_tol)));
196  }
197 
198  double zeta[QUDA_MAX_MULTI_SHIFT];
199  double zeta_old[QUDA_MAX_MULTI_SHIFT];
200  double alpha[QUDA_MAX_MULTI_SHIFT];
201  double beta[QUDA_MAX_MULTI_SHIFT];
202 
203  int j_low = 0;
204  int num_offset_now = num_offset;
205  for (int i=0; i<num_offset; i++) {
206  zeta[i] = zeta_old[i] = 1.0;
207  beta[i] = 0.0;
208  alpha[i] = 1.0;
209  }
210 
211  // flag whether we will be using reliable updates or not
212  bool reliable = false;
213  for (int j=0; j<num_offset; j++)
214  if (param.tol_offset[j] < param.delta) reliable = true;
215 
216 
217  auto *r = new cudaColorSpinorField(b);
218  std::vector<ColorSpinorField*> x_sloppy;
219  x_sloppy.resize(num_offset);
220  std::vector<ColorSpinorField*> y;
221 
223  csParam.create = QUDA_ZERO_FIELD_CREATE;
224 
    // y[i] accumulates the high-precision solution across reliable updates
225  if (reliable) {
226  y.resize(num_offset);
227  for (int i=0; i<num_offset; i++) y[i] = new cudaColorSpinorField(*r, csParam);
228  }
229 
231 
232  cudaColorSpinorField *r_sloppy;
233  if (param.precision_sloppy == x[0]->Precision()) {
234  r_sloppy = r;
235  } else {
236  csParam.create = QUDA_COPY_FIELD_CREATE;
237  r_sloppy = new cudaColorSpinorField(*r, csParam);
238  }
239 
240  if (param.precision_sloppy == x[0]->Precision() ||
242  for (int i=0; i<num_offset; i++){
243  x_sloppy[i] = x[i];
244  blas::zero(*x_sloppy[i]);
245  }
246  } else {
247  csParam.create = QUDA_ZERO_FIELD_CREATE;
248  for (int i=0; i<num_offset; i++)
249  x_sloppy[i] = new cudaColorSpinorField(*x[i], csParam);
250  }
251 
    // search directions start as copies of r; they are returned through p and not freed here
252  p.resize(num_offset);
253  for (int i=0; i<num_offset; i++) p[i] = new cudaColorSpinorField(*r_sloppy);
254 
255  csParam.create = QUDA_ZERO_FIELD_CREATE;
256  auto* Ap = new cudaColorSpinorField(*r_sloppy, csParam);
257 
258  cudaColorSpinorField tmp1(*Ap, csParam);
259 
260  // tmp2 only needed for multi-gpu Wilson-like kernels
261  cudaColorSpinorField *tmp2_p = !mat.isStaggered() ?
262  new cudaColorSpinorField(*Ap, csParam) : &tmp1;
263  cudaColorSpinorField &tmp2 = *tmp2_p;
264 
265  // additional high-precision temporary if Wilson and mixed-precision
266  csParam.setPrecision(param.precision);
267  cudaColorSpinorField *tmp3_p =
269  new cudaColorSpinorField(*r, csParam) : &tmp1;
270  cudaColorSpinorField &tmp3 = *tmp3_p;
271 
272  profile.TPSTOP(QUDA_PROFILE_INIT);
274 
275  // stopping condition of each shift
276  double stop[QUDA_MAX_MULTI_SHIFT];
277  double r2[QUDA_MAX_MULTI_SHIFT];
278  int iter[QUDA_MAX_MULTI_SHIFT+1]; // record how many iterations for each shift
279  for (int i=0; i<num_offset; i++) {
280  r2[i] = b2;
282  iter[i] = 0;
283  }
284  // this initial condition ensures that the heaviest shift can be removed
285  iter[num_offset] = 1;
286 
287  double r2_old;
288  double pAp;
289 
290  double rNorm[QUDA_MAX_MULTI_SHIFT];
291  double r0Norm[QUDA_MAX_MULTI_SHIFT];
292  double maxrx[QUDA_MAX_MULTI_SHIFT];
293  double maxrr[QUDA_MAX_MULTI_SHIFT];
294  for (int i=0; i<num_offset; i++) {
295  rNorm[i] = sqrt(r2[i]);
296  r0Norm[i] = rNorm[i];
297  maxrx[i] = rNorm[i];
298  maxrr[i] = rNorm[i];
299  }
300  double delta = param.delta;
301 
302  // this parameter determines how many consecutive reliable update
303  // residual increases we tolerate before terminating the solver,
304  // i.e., how long do we want to keep trying to converge
305  const int maxResIncrease = param.max_res_increase; // check if we reached the limit of our tolerance
306  const int maxResIncreaseTotal = param.max_res_increase_total;
307 
308  int resIncrease = 0;
309  int resIncreaseTotal[QUDA_MAX_MULTI_SHIFT];
310  for (int i=0; i<num_offset; i++) {
311  resIncreaseTotal[i]=0;
312  }
313 
314  int k = 0;
315  int rUpdate = 0;
316  blas::flops = 0;
317 
318  bool aux_update = false;
319 
320  // now create the worker class for updating the shifted solutions and gradient vectors
321  ShiftUpdate shift_update(r_sloppy, p, x_sloppy, alpha, beta, zeta, zeta_old, j_low, num_offset_now);
322 
324  profile.TPSTART(QUDA_PROFILE_COMPUTE);
325 
326  if (getVerbosity() >= QUDA_VERBOSE)
327  printfQuda("MultiShift CG: %d iterations, <r,r> = %e, |r|/|b| = %e\n", k, r2[0], sqrt(r2[0]/b2));
328 
329  while ( !convergence(r2, stop, num_offset_now) && !exit_early && k < param.maxiter) {
330 
    // the shift updates deferred from the previous iteration run inside this sloppy dslash
331  if (aux_update) dslash::aux_worker = &shift_update;
332  matSloppy(*Ap, *p[0], tmp1, tmp2);
333  dslash::aux_worker = nullptr;
334  aux_update = false;
335 
336  // update number of shifts now instead of end of previous
337  // iteration so that all shifts are updated during the dslash
338  shift_update.updateNshift(num_offset_now);
339 
340  // at some point we should curry these into the Dirac operator
341  if (r->Nspin()==4) pAp = blas::axpyReDot(offset[0], *p[0], *Ap);
342  else pAp = blas::reDotProduct(*p[0], *Ap);
343 
344  // compute zeta and alpha
345  for (int j=1; j<num_offset_now; j++) r2_old_array[j] = zeta[j] * zeta[j] * r2[0];
346  updateAlphaZeta(alpha, zeta, zeta_old, r2, beta, pAp, offset, num_offset_now, j_low);
347 
348  r2_old = r2[0];
349  r2_old_array[0] = r2_old;
350 
    // real part: new |r|^2; imaginary part (zn) is used as the numerator for beta[0] below
351  Complex cg_norm = blas::axpyCGNorm(-alpha[j_low], *Ap, *r_sloppy);
352  r2[0] = real(cg_norm);
353  double zn = imag(cg_norm);
354 
355  // reliable update conditions
356  rNorm[0] = sqrt(r2[0]);
357  for (int j=1; j<num_offset_now; j++) rNorm[j] = rNorm[0] * zeta[j];
358 
359  int updateX=0, updateR=0;
360  //fixme: with the current implementation of the reliable update it is sufficient to trigger it only for shift 0
361  //fixme: The loop below is unnecessary but I don't want to delete it as we still might find a better reliable update
362  int reliable_shift = -1; // the shift that triggered the reliable update (-1 if none); only shift 0 is examined below
363  for (int j=0; j>=0; j--) {
364  if (rNorm[j] > maxrx[j]) maxrx[j] = rNorm[j];
365  if (rNorm[j] > maxrr[j]) maxrr[j] = rNorm[j];
366  updateX = (rNorm[j] < delta*r0Norm[j] && r0Norm[j] <= maxrx[j]) ? 1 : updateX;
367  updateR = ((rNorm[j] < delta*maxrr[j] && r0Norm[j] <= maxrr[j]) || updateX) ? 1 : updateR;
368  if ((updateX || updateR) && reliable_shift == -1) reliable_shift = j;
369  }
370 
371  if ( !(updateR || updateX) || !reliable) {
372  //beta[0] = r2[0] / r2_old;
373  beta[0] = zn / r2_old;
374  // update p[0] and x[0]
375  blas::axpyZpbx(alpha[0], *p[0], *x_sloppy[0], *r_sloppy, beta[0]);
376 
377  // this should trigger the shift update in the subsequent sloppy dslash
378  aux_update = true;
379  /*
380  for (int j=1; j<num_offset_now; j++) {
381  beta[j] = beta[j_low] * zeta[j] * alpha[j] / (zeta_old[j] * alpha[j_low]);
382  // update p[i] and x[i]
383  blas::axpyBzpcx(alpha[j], *p[j], *x_sloppy[j], zeta[j], *r_sloppy, beta[j]);
384  }
385  */
386  } else {
    // reliable update: fold the sloppy solutions into y, recompute the true residual
    // of shift 0 in full precision and restart the sloppy accumulators from zero
387  for (int j=0; j<num_offset_now; j++) {
388  blas::axpy(alpha[j], *p[j], *x_sloppy[j]);
389  blas::xpy(*x_sloppy[j], *y[j]);
390  }
391 
392  mat(*r, *y[0], *x[0], tmp3); // here we can use x as tmp
393  if (r->Nspin()==4) blas::axpy(offset[0], *y[0], *r);
394 
395  r2[0] = blas::xmyNorm(b, *r);
396  for (int j=1; j<num_offset_now; j++) r2[j] = zeta[j] * zeta[j] * r2[0];
397  for (int j=0; j<num_offset_now; j++) blas::zero(*x_sloppy[j]);
398 
399  blas::copy(*r_sloppy, *r);
400 
401  // break-out check if we have reached the limit of the precision
402  if (sqrt(r2[reliable_shift]) > r0Norm[reliable_shift]) { // reuse r0Norm for this
403  resIncrease++;
404  resIncreaseTotal[reliable_shift]++;
405  warningQuda("MultiShiftCG: Shift %d, updated residual %e is greater than previous residual %e (total #inc %i)",
406  reliable_shift, sqrt(r2[reliable_shift]), r0Norm[reliable_shift], resIncreaseTotal[reliable_shift]);
407 
408  if (resIncrease > maxResIncrease or resIncreaseTotal[reliable_shift] > maxResIncreaseTotal) {
409  warningQuda("MultiShiftCG: solver exiting due to too many true residual norm increases");
410  break;
411  }
412  } else {
413  resIncrease = 0;
414  }
415 
416  // explicitly restore the orthogonality of the gradient vector
417  for (int j=0; j<num_offset_now; j++) {
418  Complex rp = blas::cDotProduct(*r_sloppy, *p[j]) / (r2[0]);
419  blas::caxpy(-rp, *r_sloppy, *p[j]);
420  }
421 
422  // update beta and p
423  beta[0] = r2[0] / r2_old;
424  blas::xpay(*r_sloppy, beta[0], *p[0]);
425  for (int j=1; j<num_offset_now; j++) {
426  beta[j] = beta[j_low] * zeta[j] * alpha[j] / (zeta_old[j] * alpha[j_low]);
427  blas::axpby(zeta[j], *r_sloppy, beta[j], *p[j]);
428  }
429 
430  // update reliable update parameters for the system that triggered the update
431  int m = reliable_shift;
432  rNorm[m] = sqrt(r2[0]) * zeta[m];
433  maxrr[m] = rNorm[m];
434  maxrx[m] = rNorm[m];
435  r0Norm[m] = rNorm[m];
436  rUpdate++;
437  }
438 
439  // now we can check if any of the shifts have converged and remove them
440  int converged = 0;
441  for (int j=num_offset_now-1; j>=1; j--) {
    // NOTE(review): for j == num_offset_now-1 the condition below reads r2[j+1] and stop[j+1].
    // On the first pass that index is num_offset. The setup loop above only initializes r2
    // for i < num_offset, and the index is past the end of both arrays if
    // num_offset == QUDA_MAX_MULTI_SHIFT. It is only evaluated when zeta[j] == 0.0.
442  if (zeta[j] == 0.0 && r2[j+1] < stop[j+1]) {
443  converged++;
444  if (getVerbosity() >= QUDA_VERBOSE)
445  printfQuda("MultiShift CG: Shift %d converged after %d iterations\n", j, k+1);
446  } else {
447  r2[j] = zeta[j] * zeta[j] * r2[0];
448  // only remove if shift above has converged
449  if ((r2[j] < stop[j] || sqrt(r2[j] / b2) < prec_tol[j]) && iter[j+1] ) {
450  converged++;
451  iter[j] = k+1;
452  if (getVerbosity() >= QUDA_VERBOSE)
453  printfQuda("MultiShift CG: Shift %d converged after %d iterations\n", j, k+1);
454  }
455  }
456  }
457  num_offset_now -= converged;
458 
459  // exit early so that we can finish off shift 0 using CG and allow for mixed-precision refinement
460  if ( (mixed || zero_refinement) and param.compute_true_res and num_offset_now==1) {
461  exit_early=true;
462  num_offset_now--;
463  }
464 
465  // this ensures we do the update on any shifted systems that
466  // happen to converge when the un-shifted system converges
    // NOTE(review): the loop guard requires k < param.maxiter and k is only incremented after
    // this check, so `k == param.maxiter` can never be true here. When the loop stops on
    // maxiter with aux_update set, the deferred update of shifts j >= 1 from the last
    // iteration is never applied.
467  if ( (convergence(r2, stop, num_offset_now) || exit_early || k == param.maxiter) && aux_update == true) {
468  if (getVerbosity() >= QUDA_VERBOSE)
469  printfQuda("Convergence of unshifted system so trigger shiftUpdate\n");
470 
471  // set worker to do all updates at once
472  shift_update.updateNupdate(1);
473  shift_update.apply(0);
474 
475  for (int j=0; j<num_offset_now; j++) iter[j] = k+1;
476  }
477 
478  k++;
479 
480  if (getVerbosity() >= QUDA_VERBOSE)
481  printfQuda("MultiShift CG: %d iterations, <r,r> = %e, |r|/|b| = %e\n", k, r2[0], sqrt(r2[0]/b2));
482  }
483 
    // copy the sloppy solutions back and add the reliable-update accumulators
484  for (int i=0; i<num_offset; i++) {
485  if (iter[i] == 0) iter[i] = k;
486  blas::copy(*x[i], *x_sloppy[i]);
487  if (reliable) blas::xpy(*y[i], *x[i]);
488  }
489 
492 
493  if (getVerbosity() >= QUDA_VERBOSE)
494  printfQuda("MultiShift CG: Reliable updates = %d\n", rUpdate);
495 
496  if (k==param.maxiter) warningQuda("Exceeded maximum iterations %d\n", param.maxiter);
497 
499  double gflops = (blas::flops + mat.flops() + matSloppy.flops())*1e-9;
500  param.gflops = gflops;
501  param.iter += k;
502 
503  if (param.compute_true_res){
504  // only allocate temporaries if necessary
505  csParam.setPrecision(param.precision);
    // NOTE(review): when reliable is true, y[1] below (and in the cleanup further down) indexes
    // past the end of y if num_offset == 1, since y was resized to num_offset.
506  ColorSpinorField *tmp4_p = reliable ? y[0] : tmp1.Precision() == x[0]->Precision() ? &tmp1 : ColorSpinorField::Create(csParam);
507  ColorSpinorField *tmp5_p = mat.isStaggered() ? tmp4_p :
508  reliable ? y[1] : (tmp2.Precision() == x[0]->Precision() && &tmp1 != tmp2_p) ? tmp2_p : ColorSpinorField::Create(csParam);
509 
510  for (int i = 0; i < num_offset; i++) {
511  // only calculate true residual if we need to:
512  // 1.) For higher shifts if we did not use mixed precision
513  // 2.) For shift 0 if we did not exit early (we went to the full solution)
514  if ( (i > 0 and not mixed) or (i == 0 and not exit_early) ) {
515  mat(*r, *x[i], *tmp4_p, *tmp5_p);
516  if (r->Nspin() == 4) {
517  blas::axpy(offset[i], *x[i], *r); // Offset it.
518  } else if (i != 0) {
519  blas::axpy(offset[i] - offset[0], *x[i], *r); // Offset it.
520  }
521  double true_res = blas::xmyNorm(b, *r);
522  param.true_res_offset[i] = sqrt(true_res / b2);
523  param.iter_res_offset[i] = sqrt(r2[i] / b2);
525  } else {
526  param.iter_res_offset[i] = sqrt(r2[i] / b2);
527  param.true_res_offset[i] = std::numeric_limits<double>::infinity();
528  param.true_res_hq_offset[i] = std::numeric_limits<double>::infinity();
529  }
530  }
531 
532  if (getVerbosity() >= QUDA_SUMMARIZE) {
533  printfQuda("MultiShift CG: Converged after %d iterations\n", k);
534  for (int i = 0; i < num_offset; i++) {
535  if(std::isinf(param.true_res_offset[i])){
536  printfQuda(" shift=%d, %d iterations, relative residual: iterated = %e\n",
537  i, iter[i], param.iter_res_offset[i]);
538  } else {
539  printfQuda(" shift=%d, %d iterations, relative residual: iterated = %e, true = %e\n",
540  i, iter[i], param.iter_res_offset[i], param.true_res_offset[i]);
541  }
542  }
543  }
544 
    // free only the temporaries that were created via ColorSpinorField::Create above
545  if (tmp5_p != tmp4_p && tmp5_p != tmp2_p && (reliable ? tmp5_p != y[1] : 1)) delete tmp5_p;
546  if (tmp4_p != &tmp1 && (reliable ? tmp4_p != y[0] : 1)) delete tmp4_p;
547  } else {
    // NOTE(review): on this path param.iter_res_offset is only written inside the verbosity
    // check, so it is left unset when verbosity is below QUDA_SUMMARIZE.
548  if (getVerbosity() >= QUDA_SUMMARIZE)
549  {
550  printfQuda("MultiShift CG: Converged after %d iterations\n", k);
551  for (int i = 0; i < num_offset; i++) {
552  param.iter_res_offset[i] = sqrt(r2[i] / b2);
553  printfQuda(" shift=%d, %d iterations, relative residual: iterated = %e\n",
554  i, iter[i], param.iter_res_offset[i]);
555  }
556  }
557  }
558 
559  // reset the flops counters
560  blas::flops = 0;
561  mat.flops();
562  matSloppy.flops();
563 
565  profile.TPSTART(QUDA_PROFILE_FREE);
566 
567  if (&tmp3 != &tmp1) delete tmp3_p;
568  if (&tmp2 != &tmp1) delete tmp2_p;
569 
570  if (r_sloppy->Precision() != r->Precision()) delete r_sloppy;
571  for (int i=0; i<num_offset; i++)
572  if (x_sloppy[i]->Precision() != x[i]->Precision()) delete x_sloppy[i];
573 
574  delete r;
575 
576  if (reliable) for (int i=0; i<num_offset; i++) delete y[i];
577 
578  delete Ap;
579 
580  profile.TPSTOP(QUDA_PROFILE_FREE);
581 
582  return;
583  }
584 
585 } // namespace quda
cudaColorSpinorField * tmp2
double iter_res_offset[QUDA_MAX_MULTI_SHIFT]
Definition: invert_quda.h:184
void updateNshift(int new_n_shift)
void setPrecision(QudaPrecision precision, QudaPrecision ghost_precision=QUDA_INVALID_PRECISION, bool force_native=false)
cudaColorSpinorField * tmp1
std::vector< ColorSpinorField * > p
void axpyZpbx(double a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, double b)
Definition: blas_quda.cu:552
#define QUDA_MAX_MULTI_SHIFT
Maximum number of shifts supported by the multi-shift solver. This number may be changed if need be...
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:21
#define errorQuda(...)
Definition: util_quda.h:121
double norm2(const ColorSpinorField &a)
Definition: reduce_quda.cu:721
__host__ __device__ ValueType sqrt(ValueType x)
Definition: complex_quda.h:120
Complex cDotProduct(ColorSpinorField &, ColorSpinorField &)
Definition: reduce_quda.cu:764
double epsilon
Definition: test_util.cpp:1649
const DiracMatrix & mat
Definition: invert_quda.h:1124
cudaStream_t * stream
QudaPrecision precision_refinement_sloppy
Definition: invert_quda.h:148
static ColorSpinorField * Create(const ColorSpinorParam &param)
void updateAlphaZeta(double *alpha, double *zeta, double *zeta_old, const double *r2, const double *beta, const double pAp, const double *offset, const int nShift, const int j_low)
double reDotProduct(ColorSpinorField &x, ColorSpinorField &y)
Definition: reduce_quda.cu:728
Complex axpyCGNorm(double a, ColorSpinorField &x, ColorSpinorField &y)
Definition: reduce_quda.cu:796
double tol_offset[QUDA_MAX_MULTI_SHIFT]
Definition: invert_quda.h:175
double offset[QUDA_MAX_MULTI_SHIFT]
Definition: invert_quda.h:172
void copy(ColorSpinorField &dst, const ColorSpinorField &src)
Definition: copy_quda.cu:355
TimeProfile & profile
Definition: invert_quda.h:1107
double xmyNorm(ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.h:75
int max_res_increase_total
Definition: invert_quda.h:96
int Nspin
Definition: blas_test.cu:45
void xpay(ColorSpinorField &x, double a, ColorSpinorField &y)
Definition: blas_quda.h:37
QudaGaugeParam param
Definition: pack_test.cpp:17
bool convergence(const double *r2, const double *r2_tol, int n) const
Definition: solver.cpp:295
double Last(QudaProfileType idx)
Definition: timer.h:251
void updateNupdate(int new_n_update)
QudaResidualType residual_type
Definition: invert_quda.h:49
const DiracMatrix & matSloppy
Definition: invert_quda.h:1125
double true_res_hq_offset[QUDA_MAX_MULTI_SHIFT]
Definition: invert_quda.h:187
Worker * aux_worker
Definition: dslash_quda.cu:87
void apply(const cudaStream_t &stream)
static double stopping(double tol, double b2, QudaResidualType residual_type)
Set the solver L2 stopping condition.
Definition: solver.cpp:206
ColorSpinorParam csParam
Definition: pack_test.cpp:24
void axpy(double a, ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.h:35
#define warningQuda(...)
Definition: util_quda.h:133
#define checkLocation(...)
__host__ __device__ ValueType pow(ValueType x, ExponentType e)
Definition: complex_quda.h:111
void axpyBzpcx(double a, ColorSpinorField &x, ColorSpinorField &y, double b, ColorSpinorField &z, double c)
Definition: blas_quda.cu:541
std::vector< ColorSpinorField * > x
void operator()(std::vector< ColorSpinorField *>x, ColorSpinorField &b, std::vector< ColorSpinorField *> &p, double *r2_old_array)
Run multi-shift and return the Krylov space at the end of the solve in p and r2_old_array.
double3 HeavyQuarkResidualNorm(ColorSpinorField &x, ColorSpinorField &r)
Definition: reduce_quda.cu:809
double true_res_offset[QUDA_MAX_MULTI_SHIFT]
Definition: invert_quda.h:181
int X[4]
Definition: covdev_test.cpp:70
std::complex< double > Complex
Definition: quda_internal.h:46
double axpyReDot(double a, ColorSpinorField &x, ColorSpinorField &y)
Definition: reduce_quda.cu:740
ColorSpinorField * r
void caxpy(const Complex &a, ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.cu:512
void zero(ColorSpinorField &a)
Definition: blas_quda.cu:472
QudaPrecision precision
Definition: invert_quda.h:142
unsigned long long flops() const
Definition: dirac_quda.h:1119
MultiShiftCG(DiracMatrix &mat, DiracMatrix &matSloppy, SolverParam &param, TimeProfile &profile)
#define printfQuda(...)
Definition: util_quda.h:115
int reliable(double &rNorm, double &maxrx, double &maxrr, const double &r2, const double &delta)
unsigned long long flops
Definition: blas_quda.cu:22
void xpy(ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.h:33
void axpby(double a, ColorSpinorField &x, double b, ColorSpinorField &y)
Definition: blas_quda.h:36
QudaPrecision precision_sloppy
Definition: invert_quda.h:145
bool use_sloppy_partial_accumulator
Definition: invert_quda.h:76
void mat(void *out, void **link, void *in, int dagger_bit, int mu, QudaPrecision sPrecision, QudaPrecision gPrecision)
QudaPrecision Precision() const
__device__ unsigned int count[QUDA_MAX_MULTI_REDUCE]
Definition: cub_helper.cuh:90
__device__ __host__ void zero(vector_type< scalar, n > &v)
Definition: cub_helper.cuh:54
bool isStaggered() const
Definition: dirac_quda.h:1128
ShiftUpdate(ColorSpinorField *r, std::vector< ColorSpinorField *> p, std::vector< ColorSpinorField *> x, double *alpha, double *beta, double *zeta, double *zeta_old, int j_low, int n_shift)
void updateR()
update the radius for halos.