QUDA  0.9.0
inv_multi_cg_quda.cpp
Go to the documentation of this file.
1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <math.h>
4 
5 #include <quda_internal.h>
6 #include <color_spinor_field.h>
7 #include <blas_quda.h>
8 #include <dslash_quda.h>
9 #include <invert_quda.h>
10 #include <util_quda.h>
11 
22 #include <worker.h>
23 
24 namespace quda {
25 
// ShiftUpdate: Worker that performs the per-shift updates of beta[j], p[j]
// and x[j] for the shifted systems (j >= 1), handling one slice of the
// shifts per apply() call.
// NOTE(review): going by the embedded numbering, this listing skips original
// lines 50 and 73; presumably they declare the member `r` and begin the
// constructor's initializer list - confirm against the real file.
48  class ShiftUpdate : public Worker {
49 
    // direction vectors (p) and solution vectors (x), indexed by shift
51  std::vector<ColorSpinorField*> p;
52  std::vector<ColorSpinorField*> x;
53 
    // per-shift coefficient arrays, held as raw pointers (presumably aliases
    // of the arrays passed to the constructor - its initializer list is not
    // fully visible here)
54  double *alpha;
55  double *beta;
56  double *zeta;
57  double *zeta_old;
58 
    // index of the reference shift whose alpha/beta enter the beta[j]
    // recurrence in apply()
59  const int j_low;
    // number of shifts currently being updated; see updateNshift()
60  int n_shift;
61 
    // number of apply() calls over which one sweep of the shifts is split;
    // initialised to 4 when r->Nspin()==4 and to 2 otherwise
68  int n_update;
69 
70  public:
71  ShiftUpdate(ColorSpinorField *r, std::vector<ColorSpinorField*> p, std::vector<ColorSpinorField*> x,
72  double *alpha, double *beta, double *zeta, double *zeta_old, int j_low, int n_shift) :
74  n_shift(n_shift), n_update( (r->Nspin()==4) ? 4 : 2 ) {
75 
76  }
77  virtual ~ShiftUpdate() { }
78 
    // setters used by the solver to change the number of active shifts and
    // the number of slices per sweep
79  void updateNshift(int new_n_shift) { n_shift = new_n_shift; }
80  void updateNupdate(int new_n_update) { n_update = new_n_update; }
81 
82  // note that we can't set the stream parameter here so it is
83  // ignored. This is more of a future design direction to consider
    //
    // Each call handles the shifts j with
    //   (count*n_shift)/n_update < j <= ((count+1)*n_shift)/n_update, j < n_shift,
    // so shift 0 is never touched here. count advances by one per call and
    // wraps to 0 once it reaches n_update.
    // NOTE(review): count is a function-local static, so it is shared by every
    // ShiftUpdate instance and keeps its value between solves. If n_update is
    // changed (updateNupdate) while count != 0, the slice bounds are out of
    // phase and the wrap test may not fire - confirm that callers only change
    // n_update at a sweep boundary.
84  void apply(const cudaStream_t &stream) {
85  static int count = 0;
86 
87 #if 0
88  // on the first call do the first half of the update
89  for (int j= (count*n_shift)/n_update+1; j<=((count+1)*n_shift)/n_update && j<n_shift; j++) {
90  beta[j] = beta[j_low] * zeta[j] * alpha[j] / ( zeta_old[j] * alpha[j_low] );
91  // update p[i] and x[i]
92  blas::axpyBzpcx(alpha[j], *(p[j]), *(x[j]), zeta[j], *r, beta[j]);
93  }
94 #else
    // active path: compute beta[j] for the slice, gather the slice's p/x
    // fields and hand them to a single multi-field axpyBzpcx call
    // first shift index of this slice; also the offset into alpha/zeta/beta
    // passed to the batched call
95  int zero = (count*n_shift)/n_update+1;
96  std::vector<ColorSpinorField*> P, X;
97  for (int j= (count*n_shift)/n_update+1; j<=((count+1)*n_shift)/n_update && j<n_shift; j++) {
98  beta[j] = beta[j_low] * zeta[j] * alpha[j] / ( zeta_old[j] * alpha[j_low] );
99  P.push_back(p[j]);
100  X.push_back(x[j]);
101  }
    // nothing to launch when the slice is empty
102  if (P.size()) blas::axpyBzpcx(&alpha[zero], P, X, &zeta[zero], *r, &beta[zero]);
103 #endif
104  if (++count == n_update) count = 0;
105  }
106 
107  };
108 
109  // this is the Worker pointer that the dslash uses to launch the shifted updates
110  namespace dslash {
111  extern Worker* aux_worker;
112  }
113 
115  TimeProfile &profile)
116  : MultiShiftSolver(param, profile), mat(mat), matSloppy(matSloppy) {
117 
118  }
119 
121 
122  }
123 
/**
 * Advance the multi-shift CG coefficients alpha and zeta by one iteration.
 *
 * Sets alpha[0] = r2[0] / pAp and zeta[0] = 1. Then, for every shift j in
 * [1, nShift), stores the current zeta[j] in zeta_old[j] and recomputes
 * zeta[j] and alpha[j] from the recurrence relative to the shift j_low.
 * If the recurrence denominator is exactly zero, zeta[j] is set to 0, and a
 * zero zeta[j] forces alpha[j] to 0 so no division by zero is performed for
 * that shift.
 *
 * @param alpha     in/out, per-shift alpha; alpha[j_low] is read both before
 *                  and after the update of alpha[0]
 * @param zeta      in/out, per-shift zeta
 * @param zeta_old  in/out, previous zeta; overwritten with the incoming zeta
 * @param r2        residual norms; only r2[0] is read
 * @param beta      per-shift beta; only beta[j_low] is read
 * @param pAp       denominator of alpha[0] (not checked for zero)
 * @param offset    per-shift offsets; only differences to offset[0] are used
 * @param nShift    number of shifts to process
 * @param j_low     index of the reference shift
 *
 * Only the previous alpha of the reference shift enters the recurrence, so it
 * is saved as a single scalar rather than copying the whole alpha array into
 * a fixed-size stack buffer; this removes any limit on nShift here.
 */
void updateAlphaZeta(double *alpha, double *zeta, double *zeta_old,
                     const double *r2, const double *beta, const double pAp,
                     const double *offset, const int nShift, const int j_low) {
  // must be captured before alpha[0] is overwritten (j_low may be 0); only
  // read when the loop below actually runs
  const double alpha_low_old = (nShift > 1) ? alpha[j_low] : 0.0;

  alpha[0] = r2[0] / pAp;
  zeta[0] = 1.0;
  for (int j = 1; j < nShift; j++) {
    const double c0 = zeta[j] * zeta_old[j] * alpha_low_old;
    const double c1 = alpha[j_low] * beta[j_low] * (zeta_old[j] - zeta[j]);
    const double c2 = zeta_old[j] * alpha_low_old * (1.0 + (offset[j] - offset[0]) * alpha[j_low]);

    zeta_old[j] = zeta[j];
    // guard against an exactly vanishing denominator
    zeta[j] = (c1 + c2 != 0.0) ? c0 / (c1 + c2) : 0.0;
    // zeta_old[j] now holds the previous zeta[j]; skip the division when the
    // new zeta is zero
    alpha[j] = (zeta[j] != 0.0) ? alpha[j_low] * zeta[j] / zeta_old[j] : 0.0;
  }
}
155 
// Multi-shift CG solve. Applies the sloppy operator to p[0] each iteration,
// updates shift 0 directly, and derives the coefficients of the remaining
// shifts from the zeta/alpha/beta recurrences (updateAlphaZeta, ShiftUpdate).
//   x : one solution field per shift (param.num_offset entries), overwritten
//   b : source field
// Per-shift results are written to param (iter_res_offset, true_res_offset,
// gflops, iter).
// NOTE(review): this listing skips a number of original source lines (visible
// as gaps in the embedded numbering), including the declarations of r,
// csParam and tmp1, the initialisation of stop[], and parts of some
// conditions. Those names are therefore used below without a visible
// definition.
156  void MultiShiftCG::operator()(std::vector<ColorSpinorField*>x, ColorSpinorField &b)
157  {
158  if (checkLocation(*(x[0]), b) != QUDA_CUDA_FIELD_LOCATION)
159  errorQuda("Not supported");
160 
161  profile.TPSTART(QUDA_PROFILE_INIT);
162 
163  int num_offset = param.num_offset;
164  double *offset = param.offset;
165 
     // NOTE(review): this return happens after TPSTART(QUDA_PROFILE_INIT) with
     // no matching TPSTOP, unlike the zero-source path just below.
166  if (num_offset == 0) return;
167 
168  const double b2 = blas::norm2(b);
169  // Check to see that we're not trying to invert on a zero-field source
170  if(b2 == 0){
171  profile.TPSTOP(QUDA_PROFILE_INIT);
172  printfQuda("Warning: inverting on zero-field source\n");
     // every solution is set to the (zero) source and the residuals to zero
173  for(int i=0; i<num_offset; ++i){
174  *(x[i]) = b;
175  param.true_res_offset[i] = 0.0;
176  param.true_res_hq_offset[i] = 0.0;
177  }
178  return;
179  }
180 
181  // this is the limit of precision possible
182  const double prec_tol = pow(10.,(-2*(int)b.Precision()+1));
183 
     // per-shift recurrence coefficients; freed at the end of this function
184  double *zeta = new double[num_offset];
185  double *zeta_old = new double[num_offset];
186  double *alpha = new double[num_offset];
187  double *beta = new double[num_offset];
188 
     // j_low is the reference shift for the recurrences; it is not reassigned
     // in the visible lines. num_offset_now counts the shifts still iterating.
189  int j_low = 0;
190  int num_offset_now = num_offset;
191  for (int i=0; i<num_offset; i++) {
192  zeta[i] = zeta_old[i] = 1.0;
193  beta[i] = 0.0;
194  alpha[i] = 1.0;
195  }
196 
197  // flag whether we will be using reliable updates or not
     // (enabled as soon as any shift's tolerance is below param.delta)
198  bool reliable = false;
199  for (int j=0; j<num_offset; j++)
200  if (param.tol_offset[j] < param.delta) reliable = true;
201 
202 
204  std::vector<ColorSpinorField*> x_sloppy;
205  x_sloppy.resize(num_offset);
206  std::vector<ColorSpinorField*> y;
207 
210 
     // y[i] is only allocated when reliable updates are in use
211  if (reliable) {
212  y.resize(num_offset);
213  for (int i=0; i<num_offset; i++) y[i] = new cudaColorSpinorField(*r, csParam);
214  }
215 
216  csParam.setPrecision(param.precision_sloppy);
217 
     // r_sloppy aliases r when the sloppy precision equals that of x[0];
     // otherwise a separate field is allocated
218  cudaColorSpinorField *r_sloppy;
219  if (param.precision_sloppy == x[0]->Precision()) {
220  r_sloppy = r;
221  } else {
223  r_sloppy = new cudaColorSpinorField(*r, csParam);
224  }
225 
     // NOTE(review): the second operand of this condition (original line 227)
     // is missing from the listing.
     // First branch: x_sloppy[i] aliases x[i] and is zeroed; second branch:
     // separate fields are allocated from csParam.
226  if (param.precision_sloppy == x[0]->Precision() ||
228  for (int i=0; i<num_offset; i++){
229  x_sloppy[i] = x[i];
230  blas::zero(*x_sloppy[i]);
231  }
232  } else {
234  for (int i=0; i<num_offset; i++)
235  x_sloppy[i] = new cudaColorSpinorField(*x[i], csParam);
236  }
237 
     // one direction vector per shift, each constructed from r_sloppy
238  std::vector<ColorSpinorField*> p;
239  p.resize(num_offset);
240  for (int i=0; i<num_offset; i++) p[i] = new cudaColorSpinorField(*r_sloppy);
241 
243  cudaColorSpinorField* Ap = new cudaColorSpinorField(*r_sloppy, csParam);
244 
246 
247  // tmp2 only needed for multi-gpu Wilson-like kernels
     // (aliases tmp1 when the operator is staggered)
248  cudaColorSpinorField *tmp2_p = !mat.isStaggered() ?
249  new cudaColorSpinorField(*Ap, csParam) : &tmp1;
250  cudaColorSpinorField &tmp2 = *tmp2_p;
251 
252  // additional high-precision temporary if Wilson and mixed-precision
     // NOTE(review): the condition selecting between the two alternatives
     // (original line 255) is missing from the listing.
253  csParam.setPrecision(param.precision);
254  cudaColorSpinorField *tmp3_p =
256  new cudaColorSpinorField(*r, csParam) : &tmp1;
257  cudaColorSpinorField &tmp3 = *tmp3_p;
258 
259  profile.TPSTOP(QUDA_PROFILE_INIT);
261 
262  // stopping condition of each shift
263  double stop[QUDA_MAX_MULTI_SHIFT];
264  double r2[QUDA_MAX_MULTI_SHIFT];
265  int iter[QUDA_MAX_MULTI_SHIFT+1]; // record how many iterations for each shift
     // NOTE(review): the line that sets stop[i] (original line 268) is missing
     // from the listing.
266  for (int i=0; i<num_offset; i++) {
267  r2[i] = b2;
269  iter[i] = 0;
270  }
271  // this initial condition ensures that the heaviest shift can be removed
272  iter[num_offset] = 1;
273 
274  double r2_old;
275  double pAp;
276 
     // bookkeeping for the reliable-update test, all seeded with |r| = sqrt(b2)
277  double rNorm[QUDA_MAX_MULTI_SHIFT];
278  double r0Norm[QUDA_MAX_MULTI_SHIFT];
279  double maxrx[QUDA_MAX_MULTI_SHIFT];
280  double maxrr[QUDA_MAX_MULTI_SHIFT];
281  for (int i=0; i<num_offset; i++) {
282  rNorm[i] = sqrt(r2[i]);
283  r0Norm[i] = rNorm[i];
284  maxrx[i] = rNorm[i];
285  maxrr[i] = rNorm[i];
286  }
287  double delta = param.delta;
288 
289  // this parameter determines how many consecutive reliable update
290  // residual increases we tolerate before terminating the solver,
291  // i.e., how long do we want to keep trying to converge
292  const int maxResIncrease = param.max_res_increase; // check if we reached the limit of our tolerance
293  const int maxResIncreaseTotal = param.max_res_increase_total;
294 
295  int resIncrease = 0;
296  int resIncreaseTotal[QUDA_MAX_MULTI_SHIFT];
297  for (int i=0; i<num_offset; i++) {
298  resIncreaseTotal[i]=0;
299  }
300 
     // k counts iterations, rUpdate counts reliable updates
301  int k = 0;
302  int rUpdate = 0;
303  blas::flops = 0;
304 
     // set when the shifted updates are deferred to the next sloppy dslash
305  bool aux_update = false;
306 
307  // now create the worker class for updating the shifted solutions and gradient vectors
308  ShiftUpdate shift_update(r_sloppy, p, x_sloppy, alpha, beta, zeta, zeta_old, j_low, num_offset_now);
309 
311  profile.TPSTART(QUDA_PROFILE_COMPUTE);
312 
313  if (getVerbosity() >= QUDA_VERBOSE)
314  printfQuda("MultiShift CG: %d iterations, <r,r> = %e, |r|/|b| = %e\n", k, r2[0], sqrt(r2[0]/b2));
315 
316  while ( !convergence(r2, stop, num_offset_now) && k < param.maxiter) {
317 
     // the worker is installed only for the duration of this operator
     // application and only if the previous iteration deferred the shifted
     // updates
318  if (aux_update) dslash::aux_worker = &shift_update;
319  matSloppy(*Ap, *p[0], tmp1, tmp2);
320  dslash::aux_worker = NULL;
321  aux_update = false;
322 
323  // update number of shifts now instead of end of previous
324  // iteration so that all shifts are updated during the dslash
325  shift_update.updateNshift(num_offset_now);
326 
327  // at some point we should curry these into the Dirac operator
328  if (r->Nspin()==4) pAp = blas::axpyReDot(offset[0], *p[0], *Ap);
329  else pAp = blas::reDotProduct(*p[0], *Ap);
330 
331  // compute zeta and alpha
332  updateAlphaZeta(alpha, zeta, zeta_old, r2, beta, pAp, offset, num_offset_now, j_low);
333 
     // the real part of the returned value becomes the new r2[0]; the
     // imaginary part (zn) is the numerator of beta[0] in the branch below
334  r2_old = r2[0];
335  Complex cg_norm = blas::axpyCGNorm(-alpha[j_low], *Ap, *r_sloppy);
336  r2[0] = real(cg_norm);
337  double zn = imag(cg_norm);
338 
339  // reliable update conditions
340  rNorm[0] = sqrt(r2[0]);
341  for (int j=1; j<num_offset_now; j++) rNorm[j] = rNorm[0] * zeta[j];
342 
343  int updateX=0, updateR=0;
344  //fixme: with the current implementation of the reliable update it is sufficient to trigger it only for shift 0
345  //fixme: The loop below is unnecessary but I don't want to delete it as we still might find a better reliable update
346  int reliable_shift = -1; // this is the shift that sets the reliable_shift
     // as written this loop body runs exactly once, with j == 0
347  for (int j=0; j>=0; j--) {
348  if (rNorm[j] > maxrx[j]) maxrx[j] = rNorm[j];
349  if (rNorm[j] > maxrr[j]) maxrr[j] = rNorm[j];
350  updateX = (rNorm[j] < delta*r0Norm[j] && r0Norm[j] <= maxrx[j]) ? 1 : updateX;
351  updateR = ((rNorm[j] < delta*maxrr[j] && r0Norm[j] <= maxrr[j]) || updateX) ? 1 : updateR;
352  if ((updateX || updateR) && reliable_shift == -1) reliable_shift = j;
353  }
354 
     // ordinary iteration: taken when no reliable update was triggered or
     // reliable updates are disabled
355  if ( !(updateR || updateX) || !reliable) {
356  //beta[0] = r2[0] / r2_old;
357  beta[0] = zn / r2_old;
358  // update p[0] and x[0]
359  blas::axpyZpbx(alpha[0], *p[0], *x_sloppy[0], *r_sloppy, beta[0]);
360 
361  // this should trigger the shift update in the subsequent sloppy dslash
362  aux_update = true;
363  /*
364  for (int j=1; j<num_offset_now; j++) {
365  beta[j] = beta[j_low] * zeta[j] * alpha[j] / (zeta_old[j] * alpha[j_low]);
366  // update p[i] and x[i]
367  blas::axpyBzpcx(alpha[j], *p[j], *x_sloppy[j], zeta[j], *r_sloppy, beta[j]);
368  }
369  */
370  } else {
     // reliable update: fold the sloppy solutions into y, recompute the
     // residual of shift 0 with the full operator, then zero x_sloppy
371  for (int j=0; j<num_offset_now; j++) {
372  blas::axpy(alpha[j], *p[j], *x_sloppy[j]);
373  blas::xpy(*x_sloppy[j], *y[j]);
374  }
375 
376  mat(*r, *y[0], *x[0], tmp3); // here we can use x as tmp
377  if (r->Nspin()==4) blas::axpy(offset[0], *y[0], *r);
378 
379  r2[0] = blas::xmyNorm(b, *r);
380  for (int j=1; j<num_offset_now; j++) r2[j] = zeta[j] * zeta[j] * r2[0];
381  for (int j=0; j<num_offset_now; j++) blas::zero(*x_sloppy[j]);
382 
383  blas::copy(*r_sloppy, *r);
384 
385  // break-out check if we have reached the limit of the precision
386  if (sqrt(r2[reliable_shift]) > r0Norm[reliable_shift]) { // reuse r0Norm for this
387  resIncrease++;
388  resIncreaseTotal[reliable_shift]++;
389  warningQuda("MultiShiftCG: Shift %d, updated residual %e is greater than previous residual %e (total #inc %i)",
390  reliable_shift, sqrt(r2[reliable_shift]), r0Norm[reliable_shift], resIncreaseTotal[reliable_shift]);
391 
392  if (resIncrease > maxResIncrease or resIncreaseTotal[reliable_shift] > maxResIncreaseTotal) {
393  warningQuda("MultiShiftCG: solver exiting due to too many true residual norm increases");
394  break;
395  }
396  } else {
397  resIncrease = 0;
398  }
399 
400  // explicitly restore the orthogonality of the gradient vector
401  for (int j=0; j<num_offset_now; j++) {
402  Complex rp = blas::cDotProduct(*r_sloppy, *p[j]) / (r2[0]);
403  blas::caxpy(-rp, *r_sloppy, *p[j]);
404  }
405 
406  // update beta and p
407  beta[0] = r2[0] / r2_old;
408  blas::xpay(*r_sloppy, beta[0], *p[0]);
409  for (int j=1; j<num_offset_now; j++) {
410  beta[j] = beta[j_low] * zeta[j] * alpha[j] / (zeta_old[j] * alpha[j_low]);
411  blas::axpby(zeta[j], *r_sloppy, beta[j], *p[j]);
412  }
413 
414  // update reliable update parameters for the system that triggered the update
415  int m = reliable_shift;
416  rNorm[m] = sqrt(r2[0]) * zeta[m];
417  maxrr[m] = rNorm[m];
418  maxrx[m] = rNorm[m];
419  r0Norm[m] = rNorm[m];
420  rUpdate++;
421  }
422 
423  // now we can check if any of the shifts have converged and remove them
     // (scans from the last active shift downwards; shift 0 is never removed)
424  int converged = 0;
425  for (int j=num_offset_now-1; j>=1; j--) {
     // NOTE(review): when zeta[j] == 0.0 this reads r2[j+1] and stop[j+1].
     // For j == num_offset_now-1 that is index num_offset_now, which is
     // num_offset while no shift has been removed; r2 is only initialised
     // for indices below num_offset - confirm the intended index.
426  if (zeta[j] == 0.0 && r2[j+1] < stop[j+1]) {
427  converged++;
428  if (getVerbosity() >= QUDA_VERBOSE)
429  printfQuda("MultiShift CG: Shift %d converged after %d iterations\n", j, k+1);
430  } else {
431  r2[j] = zeta[j] * zeta[j] * r2[0];
432  // only remove if shift above has converged
433  if ((r2[j] < stop[j] || sqrt(r2[j] / b2) < prec_tol) && iter[j+1] ) {
434  converged++;
435  iter[j] = k+1;
436  if (getVerbosity() >= QUDA_VERBOSE)
437  printfQuda("MultiShift CG: Shift %d converged after %d iterations\n", j, k+1);
438  }
439  }
440  }
441  num_offset_now -= converged;
442 
443  // this ensures we do the update on any shifted systems that
444  // happen to converge when the un-shifted system converges
     // NOTE(review): the loop condition guarantees k < param.maxiter here and
     // k is only incremented after this test, so `k == param.maxiter` cannot
     // be true at this point; the pending shifted updates are therefore not
     // flushed when the loop ends on the iteration limit - confirm intent.
445  if ( (convergence(r2, stop, num_offset_now) || k == param.maxiter) && aux_update == true) {
446  if (getVerbosity() >= QUDA_VERBOSE)
447  printfQuda("Convergence of unshifted system so trigger shiftUpdate\n");
448 
449  // set worker to do all updates at once
450  shift_update.updateNupdate(1);
451  shift_update.apply(0);
452 
453  for (int j=0; j<num_offset_now; j++) iter[j] = k+1;
454  }
455 
456  k++;
457 
458  if (getVerbosity() >= QUDA_VERBOSE)
459  printfQuda("MultiShift CG: %d iterations, <r,r> = %e, |r|/|b| = %e\n", k, r2[0], sqrt(r2[0]/b2));
460  }
461 
     // assemble the final solutions: x[i] receives x_sloppy[i], plus y[i] when
     // reliable updates were used (presumably xpy adds its first argument into
     // its second). When x_sloppy[i] aliases x[i] the copy is a self-copy.
     // Shifts that never recorded an iteration count get the final k.
462  for (int i=0; i<num_offset; i++) {
463  if (iter[i] == 0) iter[i] = k;
464  blas::copy(*x[i], *x_sloppy[i]);
465  if (reliable) blas::xpy(*y[i], *x[i]);
466  }
467 
470 
471  if (getVerbosity() >= QUDA_VERBOSE)
472  printfQuda("MultiShift CG: Reliable updates = %d\n", rUpdate);
473 
474  if (k==param.maxiter) warningQuda("Exceeded maximum iterations %d\n", param.maxiter);
475 
477  double gflops = (blas::flops + mat.flops() + matSloppy.flops())*1e-9;
478  param.gflops = gflops;
479  param.iter += k;
480 
481  if (param.compute_true_res) {
482  // only allocate temporaries if necessary
     // tmp4/tmp5 reuse y[0]/y[1], tmp1 or tmp2 where possible and are only
     // created (and later deleted) when none of those can be reused
     // NOTE(review): y[1] is used whenever reliable is set and the operator is
     // not staggered; that requires num_offset >= 2 - confirm.
483  csParam.setPrecision(param.precision);
484  ColorSpinorField *tmp4_p = reliable ? y[0] : tmp1.Precision() == x[0]->Precision() ? &tmp1 : ColorSpinorField::Create(csParam);
485  ColorSpinorField *tmp5_p = mat.isStaggered() ? tmp4_p :
486  reliable ? y[1] : (tmp2.Precision() == x[0]->Precision() && &tmp1 != tmp2_p) ? tmp2_p : ColorSpinorField::Create(csParam);
487 
     // recompute each shift's residual with the full operator and record the
     // true and iterated relative residuals in param
488  for(int i=0; i < num_offset; i++) {
489  mat(*r, *x[i], *tmp4_p, *tmp5_p);
490  if (r->Nspin()==4) {
491  blas::axpy(offset[i], *x[i], *r); // Offset it.
492  } else if (i!=0) {
493  blas::axpy(offset[i]-offset[0], *x[i], *r); // Offset it.
494  }
495  double true_res = blas::xmyNorm(b, *r);
496  param.true_res_offset[i] = sqrt(true_res/b2);
497  param.iter_res_offset[i] = sqrt(r2[i]/b2);
499  }
500 
501  if (getVerbosity() >= QUDA_SUMMARIZE){
502  printfQuda("MultiShift CG: Converged after %d iterations\n", k);
503  for(int i=0; i < num_offset; i++) {
     // NOTE(review): the argument list of this printfQuda (original line 505)
     // is missing from the listing.
504  printfQuda(" shift=%d, %d iterations, relative residual: iterated = %e, true = %e\n",
506  }
507  }
508 
     // delete only the temporaries that were created above
509  if (tmp5_p != tmp4_p && tmp5_p != tmp2_p && (reliable ? tmp5_p != y[1] : 1)) delete tmp5_p;
510  if (tmp4_p != &tmp1 && (reliable ? tmp4_p != y[0] : 1)) delete tmp4_p;
511  } else {
     // NOTE(review): in this branch param.iter_res_offset[i] is only written
     // when the verbosity is at least QUDA_SUMMARIZE.
512  if (getVerbosity() >= QUDA_SUMMARIZE) {
513  printfQuda("MultiShift CG: Converged after %d iterations\n", k);
514  for(int i=0; i < num_offset; i++) {
515  param.iter_res_offset[i] = sqrt(r2[i]/b2);
516  printfQuda(" shift=%d, %d iterations, relative residual: iterated = %e\n",
517  i, iter[i], param.iter_res_offset[i]);
518  }
519  }
520  }
521 
522  // reset the flops counters
523  blas::flops = 0;
524  mat.flops();
525  matSloppy.flops();
526 
528  profile.TPSTART(QUDA_PROFILE_FREE);
529 
     // release the temporaries; fields that alias tmp1 are skipped, and
     // r_sloppy / x_sloppy[i] are deleted only when their precision differs
     // from that of r / x[i]
530  if (&tmp3 != &tmp1) delete tmp3_p;
531  if (&tmp2 != &tmp1) delete tmp2_p;
532 
533  if (r_sloppy->Precision() != r->Precision()) delete r_sloppy;
534  for (int i=0; i<num_offset; i++)
535  if (x_sloppy[i]->Precision() != x[i]->Precision()) delete x_sloppy[i];
536 
537  delete r;
538  for (int i=0; i<num_offset; i++) delete p[i];
539 
540  if (reliable) for (int i=0; i<num_offset; i++) delete y[i];
541 
542  delete Ap;
543 
544  delete []zeta_old;
545  delete []zeta;
546  delete []alpha;
547  delete []beta;
548 
549  profile.TPSTOP(QUDA_PROFILE_FREE);
550 
551  return;
552  }
553 
554 } // namespace quda
double iter_res_offset[QUDA_MAX_MULTI_SHIFT]
Definition: invert_quda.h:148
void updateNshift(int new_n_shift)
void xpay(ColorSpinorField &x, const double &a, ColorSpinorField &y)
Definition: blas_quda.cu:173
std::vector< ColorSpinorField * > p
static double stopping(const double &tol, const double &b2, QudaResidualType residual_type)
Definition: solver.cpp:122
#define QUDA_MAX_MULTI_SHIFT
Maximum number of shifts supported by the multi-shift solver. This number may be changed if need be...
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:20
SolverParam & param
Definition: invert_quda.h:732
#define errorQuda(...)
Definition: util_quda.h:90
double norm2(const ColorSpinorField &a)
Definition: reduce_quda.cu:241
__host__ __device__ ValueType sqrt(ValueType x)
Definition: complex_quda.h:105
Complex cDotProduct(ColorSpinorField &, ColorSpinorField &)
Definition: reduce_quda.cu:500
std::complex< double > Complex
Definition: eig_variables.h:13
const DiracMatrix & mat
Definition: invert_quda.h:747
cudaStream_t * stream
static ColorSpinorField * Create(const ColorSpinorParam &param)
void updateAlphaZeta(double *alpha, double *zeta, double *zeta_old, const double *r2, const double *beta, const double pAp, const double *offset, const int nShift, const int j_low)
double reDotProduct(ColorSpinorField &x, ColorSpinorField &y)
Definition: reduce_quda.cu:277
double tol_offset[QUDA_MAX_MULTI_SHIFT]
Definition: invert_quda.h:139
double offset[QUDA_MAX_MULTI_SHIFT]
Definition: invert_quda.h:136
void copy(ColorSpinorField &dst, const ColorSpinorField &src)
Definition: copy_quda.cu:263
TimeProfile & profile
Definition: invert_quda.h:733
double xmyNorm(ColorSpinorField &x, ColorSpinorField &y)
Definition: reduce_quda.cu:364
int max_res_increase_total
Definition: invert_quda.h:79
int Nspin
Definition: blas_test.cu:45
size_t size_t offset
QudaGaugeParam param
Definition: pack_test.cpp:17
#define b
bool convergence(const double *r2, const double *r2_tol, int n) const
Definition: solver.cpp:216
double Last(QudaProfileType idx)
static unsigned int delta
void updateNupdate(int new_n_update)
QudaResidualType residual_type
Definition: invert_quda.h:47
void axpyZpbx(const double &a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, const double &b)
Definition: blas_quda.cu:384
const DiracMatrix & matSloppy
Definition: invert_quda.h:748
double true_res_hq_offset[QUDA_MAX_MULTI_SHIFT]
Definition: invert_quda.h:151
Worker * aux_worker
Definition: dslash_quda.cu:78
void apply(const cudaStream_t &stream)
#define tmp2
Definition: tmc_core.h:16
ColorSpinorParam csParam
Definition: pack_test.cpp:24
static __inline__ size_t p
Complex axpyCGNorm(const double &a, ColorSpinorField &x, ColorSpinorField &y)
Definition: reduce_quda.cu:654
#define warningQuda(...)
Definition: util_quda.h:101
#define checkLocation(...)
__host__ __device__ ValueType pow(ValueType x, ExponentType e)
Definition: complex_quda.h:100
std::vector< ColorSpinorField * > x
double3 HeavyQuarkResidualNorm(ColorSpinorField &x, ColorSpinorField &r)
Definition: reduce_quda.cu:703
#define tmp1
Definition: tmc_core.h:15
double true_res_offset[QUDA_MAX_MULTI_SHIFT]
Definition: invert_quda.h:145
ColorSpinorField * r
void caxpy(const Complex &a, ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.cu:246
void zero(ColorSpinorField &a)
Definition: blas_quda.cu:45
void axpy(const double &a, ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.cu:150
QudaPrecision precision
Definition: invert_quda.h:112
void axpby(const double &a, ColorSpinorField &x, const double &b, ColorSpinorField &y)
Definition: blas_quda.cu:106
void operator()(std::vector< ColorSpinorField *> out, ColorSpinorField &in)
double axpyReDot(const double &a, ColorSpinorField &x, ColorSpinorField &y)
Definition: reduce_quda.cu:345
void axpyBzpcx(const double &a, ColorSpinorField &x, ColorSpinorField &y, const double &b, ColorSpinorField &z, const double &c)
Definition: blas_quda.cu:356
unsigned long long flops() const
Definition: dirac_quda.h:995
MultiShiftCG(DiracMatrix &mat, DiracMatrix &matSloppy, SolverParam &param, TimeProfile &profile)
#define printfQuda(...)
Definition: util_quda.h:84
int reliable(double &rNorm, double &maxrx, double &maxrr, const double &r2, const double &delta)
unsigned long long flops
Definition: blas_quda.cu:42
void xpy(ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.cu:128
QudaPrecision precision_sloppy
Definition: invert_quda.h:115
bool use_sloppy_partial_accumulator
Definition: invert_quda.h:59
void mat(void *out, void **link, void *in, int dagger_bit, int mu, QudaPrecision sPrecision, QudaPrecision gPrecision)
QudaPrecision Precision() const
__device__ unsigned int count[QUDA_MAX_MULTI_REDUCE]
Definition: cub_helper.cuh:118
__device__ __host__ void zero(vector_type< scalar, n > &v)
Definition: cub_helper.cuh:82
bool isStaggered() const
Definition: dirac_quda.h:1004
ShiftUpdate(ColorSpinorField *r, std::vector< ColorSpinorField *> p, std::vector< ColorSpinorField *> x, double *alpha, double *beta, double *zeta, double *zeta_old, int j_low, int n_shift)
void updateR()
update the radius for halos.
#define tmp3
Definition: tmc_core.h:17