QUDA  v1.1.0
A library for QCD on GPUs
inv_cg3_quda.cpp
Go to the documentation of this file.
1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <math.h>
4 #include <iostream>
5 
6 #include <complex>
7 
8 #include <quda_internal.h>
9 #include <blas_quda.h>
10 #include <dslash_quda.h>
11 #include <invert_quda.h>
12 #include <util_quda.h>
13 
14 namespace quda {
15 
// CG3 constructor.
// Forwards the three operators, the solver parameters and the profiler to the
// Solver base; matPrecon is passed for both the third and the fourth operator
// argument. The init flag is cleared: the destructor body and operator() in
// this file test it before deleting / setting up the work-field pointers.
16  CG3::CG3(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, SolverParam &param,
17  TimeProfile &profile) :
18  Solver(mat, matSloppy, matPrecon, matPrecon, param, profile),
19  init(false)
20  {
21  }
22 
// Body of what is presumably CG3::~CG3().
// NOTE(review): the listing jumps from line 22 to 24 and from 30 to 32, so the
// signature and one statement are not visible. The closing brace at line 36
// presumably matches a condition on the skipped line 31 that guards the four
// sloppy-field deletes (32-35) - confirm against the original file.
24  {
// Nothing is released unless init is set (operator() sets it to true).
25  if ( init ) {
26  delete rp;
27  delete yp;
28  delete tmpp;
29  delete ArSp;
30  delete rS_oldp;
32  delete rSp;
33  delete xSp;
34  delete xS_oldp;
35  delete tmpSp;
36  }
// tmp2Sp is deleted only for non-staggered operators: in the staggered branch
// of operator() it is assigned tmpSp (tmp2Sp = tmpSp), i.e. it is an alias
// rather than a separate field.
37  if (!mat.isStaggered()) delete tmp2Sp;
38 
// Clear the flag once the fields have been released.
39  init = false;
40  }
41  }
42 
// CG3NE constructor.
// The CG3 base is constructed from the mmdag, mmdagSloppy and mmdagPrecon
// members, which are themselves built from mat/matSloppy/matPrecon.Expose().
// In C++ the base class is initialised before the members, so the base receives
// references to members that are not constructed yet.
// NOTE(review): presumably the base only stores these references during
// construction - confirm against the Solver/CG3 declarations.
// xp and yp start as nullptr and init as false.
43  CG3NE::CG3NE(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, SolverParam &param,
44  TimeProfile &profile) :
45  CG3(mmdag, mmdagSloppy, mmdagPrecon, param, profile),
46  mmdag(mat.Expose()),
47  mmdagSloppy(matSloppy.Expose()),
48  mmdagPrecon(matPrecon.Expose()),
49  xp(nullptr),
50  yp(nullptr),
51  init(false)
52  {
53  }
54 
// Body of what is presumably CG3NE::~CG3NE().
// NOTE(review): listing line 55 (the signature) is not visible - confirm.
56  {
// xp and yp are released only if init is set; each pointer is additionally
// checked for null before the delete.
57  if (init) {
58  if (xp) delete xp;
59  if (yp) delete yp;
60  init = false;
61  }
62  }
63 
64  // CG3NE: M Mdag y = b is solved; x = Mdag y is returned as solution.
// Body of what is presumably CG3NE::operator()(x, b).
// NOTE(review): listing lines 65, 68, 75-79, 85, 109, 114-115 and 119 are not
// visible. Missing are the signature, the statement before the early return,
// the setup inside `if (!init)`, the conditions that open the blocks closed at
// lines 100/104, 123/125 and 129, and the declarations of A and B used at
// line 116. Confirm all of these against the original file.
66  {
// Nothing to do when no iterations or steps are allowed.
67  if (param.maxiter == 0 || param.Nsteps == 0) {
69  return;
70  }
71 
// Remember the iteration count on entry; the summary below reports
// param.iter - iter0.
72  const int iter0 = param.iter;
73 
// One-time setup; only the flag update is visible in this listing.
74  if (!init) {
80  init = true;
81  }
82 
83  double b2 = blas::norm2(b);
84 
86 
// First branch (its condition is on the skipped line 85): apply M to x into
// *xp and form the residual against b, solve the residual equation with CG3
// into *yp, map back with Mdag and add the correction to x.
87  // compute initial residual
88  mmdag.Expose()->M(*xp, x);
89  double r2 = blas::xmyNorm(b, *xp);
// With a zero source norm, the residual norm is used as reference instead.
90  if (b2 == 0.0) b2 = r2;
91 
92  // compute solution to residual equation
93  CG3::operator()(*yp, *xp);
94 
95  mmdag.Expose()->Mdag(*xp, *yp);
96 
97  // compute full solution
98  blas::xpy(*xp, x);
99 
100  } else {
101 
// Otherwise solve for *yp with b as source and write x = Mdag applied to *yp.
102  CG3::operator()(*yp, b);
103  mmdag.Expose()->Mdag(x, *yp);
104  }
105 
106  // future optimization: with preserve_source == QUDA_PRESERVE_SOURCE_NO; b is already
107  // expected to be the CG residual which matches the CG3NE residual
108  // (but only with zero initial guess). at the moment, CG does not respect this convention
110 
111  // compute the true residual
112  mmdag.Expose()->M(*xp, x);
113 
// A and B are declared on the skipped lines 114-115.
116  blas::axpby(-1.0, A, 1.0, B);
117 
118  double r2;
// The condition choosing between the two residual norms is on the skipped
// line 119; the first branch also records the heavy-quark residual.
120  double3 h3 = blas::HeavyQuarkResidualNorm(x, B);
121  r2 = h3.y;
122  param.true_res_hq = sqrt(h3.z);
123  } else {
124  r2 = blas::norm2(B);
125  }
// True residual stored relative to the source norm.
126  param.true_res = sqrt(r2 / b2);
127 
128  PrintSummary("CG3NE", param.iter - iter0, r2, b2, stopping(param.tol, b2, param.residual_type), param.tol_hq);
129  }
130  }
131 
// CG3NR constructor.
// The CG3 base is constructed from the mdagm, mdagmSloppy and mdagmPrecon
// members, which are themselves built from mat/matSloppy/matPrecon.Expose().
// In C++ the base class is initialised before the members, so the base receives
// references to members that are not constructed yet.
// NOTE(review): presumably the base only stores these references during
// construction - confirm against the Solver/CG3 declarations.
// bp starts as nullptr and init as false.
132  CG3NR::CG3NR(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, SolverParam &param,
133  TimeProfile &profile) :
134  CG3(mdagm, mdagmSloppy, mdagmPrecon, param, profile),
135  mdagm(mat.Expose()),
136  mdagmSloppy(matSloppy.Expose()),
137  mdagmPrecon(matPrecon.Expose()),
138  bp(nullptr),
139  init(false)
140  {
141  }
142 
// Body of what is presumably CG3NR::~CG3NR().
// NOTE(review): listing line 143 (the signature) is not visible - confirm.
144  {
// bp is released only if init is set, and only when it is non-null.
145  if (init) {
146  if (bp) delete bp;
147  init = false;
148  }
149  }
150 
151  // CG3NR: Mdag M x = Mdag b is solved.
// Body of what is presumably CG3NR::operator()(x, b).
// NOTE(review): listing lines 152, 155, 162-164, 177, 182-183, 187 and 197 are
// not visible. Missing are the signature, the statement before the early
// return, the setup inside `if (!init)`, the conditions that open the blocks
// closed at lines 191/193, 200 and 201's partner, and the declarations of A
// and B used at line 184. Confirm all of these against the original file.
153  {
// Nothing to do when no iterations or steps are allowed.
154  if (param.maxiter == 0 || param.Nsteps == 0) {
156  return;
157  }
158 
// Remember the iteration count on entry; the summary below reports
// param.iter - iter0.
159  const int iter0 = param.iter;
160 
// One-time setup; only the flag update is visible in this listing.
161  if (!init) {
165  init = true;
166  }
167 
168  double b2 = blas::norm2(b);
// With a zero source norm, the norm of M applied to x is used as reference.
169  if (b2 == 0.0) { // compute initial residual vector
170  mdagm.Expose()->M(*bp, x);
171  b2 = blas::norm2(*bp);
172  }
173 
// Build the normal-equation source (Mdag applied to b, stored in *bp) and
// solve for x with CG3.
174  mdagm.Expose()->Mdag(*bp, b);
175  CG3::operator()(x, *bp);
176 
178 
179  // compute the true residual
180  mdagm.Expose()->M(*bp, x);
181 
// A and B are declared on the skipped lines 182-183.
184  blas::axpby(-1.0, A, 1.0, B);
185 
186  double r2;
// The condition choosing between the two residual norms is on the skipped
// line 187; the first branch also records the heavy-quark residual.
188  double3 h3 = blas::HeavyQuarkResidualNorm(x, B);
189  r2 = h3.y;
190  param.true_res_hq = sqrt(h3.z);
191  } else {
192  r2 = blas::norm2(B);
193  }
// True residual stored relative to the source norm.
194  param.true_res = sqrt(r2 / b2);
195  PrintSummary("CG3NR", param.iter - iter0, r2, b2, stopping(param.tol, b2, param.residual_type), param.tol_hq);
196 
// The condition guarding this final update is on the skipped line 197.
// M is applied to x into *bp and b is combined with it in place.
// NOTE(review): presumably this leaves b = b - M x - confirm the axpby
// argument convention in blas_quda.h.
198  mdagm.Expose()->M(*bp, x);
199  blas::axpby(-1.0, *bp, 1.0, b);
200  }
201  }
202 
// Body of what is presumably CG3::operator()(x, b): the three-term-recurrence
// CG iteration with optional mixed-precision reliable updates and optional
// heavy-quark residual monitoring.
// NOTE(review): listing lines 203, 205, 215, 225, 227-230, 234-235, 237-240,
// 246, 289, 295, 319, 321, 501-502, 504, 519 and 530 are not visible. Missing
// are the signature, several `if` conditions and the field allocations, so
// braces that look unmatched below close constructs opened on skipped lines.
// Confirm all of these against the original file.
204  {
// NOTE(review): the guard for this error is presumably on the skipped line
// 205; as listed here the call looks unconditional.
206  errorQuda("Not supported");
207  if (x.Precision() != param.precision || b.Precision() != param.precision)
208  errorQuda("Precision mismatch");
209 
210  profile.TPSTART(QUDA_PROFILE_INIT);
211 
212  // Check to see that we're not trying to invert on a zero-field source
213  double b2 = blas::norm2(b);
// The remainder of this condition is on the skipped line 215. When it holds,
// the solution is set to b and both residuals are reported as zero.
214  if(b2 == 0 &&
216  profile.TPSTOP(QUDA_PROFILE_INIT);
217  printfQuda("Warning: inverting on zero-field source\n");
218  x = b;
219  param.true_res = 0.0;
220  param.true_res_hq = 0.0;
221  return;
222  }
223 
// Mixed precision means the sloppy precision differs from the full precision.
224  const bool mixed_precision = (param.precision != param.precision_sloppy);
// One-time setup of the work fields; the allocations themselves are on the
// skipped lines.
226  if (!init) {
231 
232  // Sloppy fields
233  csParam.setPrecision(param.precision_sloppy);
236  if (mixed_precision) {
241  } else {
// Uniform precision: the sloppy "old x" and temporary pointers alias yp/tmpp.
242  xS_oldp = yp;
243  tmpSp = tmpp;
244  }
245  if(!mat.isStaggered()) {
247  } else {
// Staggered operators: tmp2Sp aliases tmpSp (the destructor body in this file
// skips deleting tmp2Sp in that case).
248  tmp2Sp = tmpSp;
249  }
250 
251  init = true;
252  }
253 
// Reference aliases for the work fields. In uniform precision rS and xS are
// r and x themselves.
254  ColorSpinorField &r = *rp;
255  ColorSpinorField &y = *yp;
256  ColorSpinorField &rS = mixed_precision ? *rSp : r;
257  ColorSpinorField &xS = mixed_precision ? *xSp : x;
258  ColorSpinorField &ArS = *ArSp;
259  ColorSpinorField &rS_old = *rS_oldp;
260  ColorSpinorField &xS_old = *xS_oldp;
261  ColorSpinorField &tmp = *tmpp;
262  ColorSpinorField &tmpS = *tmpSp;
263  ColorSpinorField &tmp2S = *tmp2Sp;
264 
265  double stop = stopping(param.tol, b2, param.residual_type); // stopping condition of solver
266 
267  const bool use_heavy_quark_res =
268  (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) ? true : false;
269 
270  // this parameter determines how many consecutive reliable update
271  // residual increases we tolerate before terminating the solver,
272  // i.e., how long do we want to keep trying to converge
273  const int maxResIncrease = param.max_res_increase; // check if we reached the limit of our tolerance
274  const int maxResIncreaseTotal = param.max_res_increase_total;
275  int resIncrease = 0;
276  int resIncreaseTotal = 0;
277 
278  // these are only used if we use the heavy_quark_res
279  const int hqmaxresIncrease = maxResIncrease + 1;
280  int heavy_quark_check = param.heavy_quark_check; // how often to check the heavy quark residual
281  double heavy_quark_res = 0.0; // heavy quark residual
282  double heavy_quark_res_old = 0.0; // heavy quark residual
283  int hqresIncrease = 0;
284  bool L2breakdown = false;
285 
// Non-zero selects the fused quadrupleCG3* kernels inside the loop.
286  int pipeline = param.pipeline;
287 
288  profile.TPSTOP(QUDA_PROFILE_INIT);
290 
291  blas::flops = 0;
292 
293  // compute initial residual depending on whether we have an initial guess or not
294  double r2;
// The condition (skipped line 295) selects between the two branches below.
// First branch: apply mat to x into r and form the residual against b. In
// mixed precision the current x is copied into y and xS is zeroed.
296  mat(r, x, y, tmp);
297  r2 = blas::xmyNorm(b, r);
298  if(b2==0) b2 = r2;
299  if (mixed_precision) {
300  blas::copy(y, x);
301  blas::zero(xS);
302  }
303  } else {
// Second branch: the residual is b itself and x (plus y and xS in mixed
// precision) starts from zero.
304  blas::copy(r, b);
305  r2 = b2;
306  blas::zero(x);
307  if (mixed_precision) {
308  blas::zero(y);
309  blas::zero(xS);
310  }
311  }
312  blas::copy(rS, r);
313 
314  if (use_heavy_quark_res) {
315  heavy_quark_res = sqrt(blas::HeavyQuarkResidualNorm(x, r).z);
316  heavy_quark_res_old = heavy_quark_res;
317  }
318 
// Early exit if the starting residual already satisfies the stopping
// criteria. The condition guarding the copy of r into b is on the skipped
// line 321.
320  if(convergence(r2, heavy_quark_res, stop, param.tol_hq)) {
322  blas::copy(b, r);
323  }
324  return;
325  }
326  profile.TPSTART(QUDA_PROFILE_COMPUTE);
327 
// State for the reliable-update heuristics.
328  double r2_old = r2;
329  double rNorm = sqrt(r2);
330  double r0Norm = rNorm;
331  double maxrx = rNorm;
332  double maxrr = rNorm;
333  double delta = param.delta;
334  bool restart = false;
335 
336  int k = 0;
337  PrintStats("CG3", k, r2, b2, heavy_quark_res);
338  double rho = 1.0, gamma = 1.0;
339 
// Iterate until converged or param.maxiter iterations have been performed.
340  while ( !convergence(r2, heavy_quark_res, stop, param.tol_hq) && k < param.maxiter) {
341 
// Apply the sloppy operator to the residual and compute gamma = r2 / (r, A r).
342  matSloppy(ArS, rS, tmpS, tmp2S);
343  double gamma_old = gamma;
344  double rAr = blas::reDotProduct(rS,ArS);
345  gamma = r2/rAr;
346 
347  // CG3 step
// The first iteration, or the one following a restart, has no usable history:
// save the current x and r as the "old" fields and take a plain step.
348  if (k == 0 || restart) { // First iteration
349  if (pipeline) {
350  r2 = blas::quadrupleCG3InitNorm(gamma, xS, rS, xS_old, rS_old, ArS);
351  } else {
352  blas::copy(xS_old, xS);
353  blas::copy(rS_old, rS);
354 
355  blas::axpy(gamma, rS, xS); // x += gamma*r
356  r2 = blas::axpyNorm(-gamma, ArS, rS); // r -= gamma*w
357  }
358  restart = false;
359  } else {
// Three-term recurrence: update rho from the current and previous gamma and
// r2 values, then combine the current and old x / r fields.
360  rho = rho/(rho-(gamma/gamma_old)*(r2/r2_old));
361  r2_old = r2;
362 
363  if (pipeline) {
364  r2 = blas::quadrupleCG3UpdateNorm(gamma, rho, xS, rS, xS_old, rS_old, ArS);
365  } else {
// Keep the pre-update x and r in the temporaries; they become the "old"
// fields at the end of this branch. xS is updated before rS, so the x update
// uses the residual from before this step.
366  blas::copy(tmpS, xS);
367  blas::copy(tmp2S, rS);
368 
369  blas::axpby(gamma*rho, rS, rho, xS);
370  blas::axpby(-gamma*rho, ArS, rho, rS);
371 
372  blas::axpy(1.-rho, xS_old, xS);
373  r2 = blas::axpyNorm(1.-rho, rS_old, rS);
374 
375  blas::copy(xS_old, tmpS);
376  blas::copy(rS_old, tmp2S);
377  }
378  }
379 
380  k++;
381 
// Refresh the heavy-quark residual every heavy_quark_check iterations. In
// mixed precision a copy of y is passed along with xS to the xpy variant.
382  if (use_heavy_quark_res && k%heavy_quark_check==0) {
383  heavy_quark_res_old = heavy_quark_res;
384  if (mixed_precision) {
385  blas::copy(tmpS,y);
386  heavy_quark_res = sqrt(blas::xpyHeavyQuarkResidualNorm(xS, tmpS, rS).z);
387  } else {
388  heavy_quark_res = sqrt(blas::HeavyQuarkResidualNorm(xS, rS).z);
389  }
390  }
391 
392  // reliable update conditions
393  if (mixed_precision) {
394  rNorm = sqrt(r2);
395  if (rNorm > maxrx) maxrx = rNorm;
396  if (rNorm > maxrr) maxrr = rNorm;
397  bool update = (rNorm < delta*r0Norm && r0Norm <= maxrx); // condition for x
398  update = ( update || (rNorm < delta*maxrr && r0Norm <= maxrr)); // condition for r
399 
400  // force a reliable update if we are within target tolerance (only if doing reliable updates)
401  if ( convergence(r2, heavy_quark_res, stop, param.tol_hq) && param.delta >= param.tol ) update = true;
402 
403  // For heavy-quark inversion force a reliable update if we continue after
404  if ( use_heavy_quark_res and L2breakdown and convergenceHQ(r2, heavy_quark_res, stop, param.tol_hq) and param.delta >= param.tol ) {
405  update = true;
406  }
407 
// Reliable update: fold the sloppy solution into y, recompute the residual
// with the full-precision operator and reset the reference norms.
408  if (update) {
409  // updating the "new" vectors
410  blas::copy(x, xS);
411  blas::xpy(x, y);
412  mat(r, y, x, tmp); // here we can use x as tmp
413  r2 = blas::xmyNorm(b, r);
414  param.true_res = sqrt(r2 / b2);
415  if (use_heavy_quark_res) {
416  heavy_quark_res = sqrt(blas::HeavyQuarkResidualNorm(y, r).z);
417  param.true_res_hq = heavy_quark_res;
418  }
419  rNorm = sqrt(r2);
420  r0Norm = sqrt(r2);
421  maxrr = rNorm;
422  maxrx = rNorm;
423  // we update sloppy and old fields
424  if (!convergence(r2, heavy_quark_res, stop, param.tol_hq)) {
425  blas::copy(rS, r);
426  blas::axpy(-1., xS, xS_old);
427  // we preserve the orthogonality between the previous residual and the new
428  Complex rr_old = blas::cDotProduct(rS, rS_old);
429  r2_old = blas::caxpyNorm(-rr_old/r2, rS, rS_old);
430  blas::zero(xS);
431  }
432  }
433 
434  // break-out check if we have reached the limit of the precision
435  if (sqrt(r2) > r0Norm) {
436  resIncrease++;
437  resIncreaseTotal++;
438  warningQuda(
439  "CG3: new reliable residual norm %e is greater than previous reliable residual norm %e (total #inc %i)",
440  sqrt(r2), r0Norm, resIncreaseTotal);
441  if (resIncrease > maxResIncrease or resIncreaseTotal > maxResIncreaseTotal) {
// With the heavy-quark residual in use, flag an L2 breakdown instead of
// leaving the loop.
442  if (use_heavy_quark_res) {
443  L2breakdown = true;
444  } else {
445  warningQuda("CG3: solver exiting due to too many true residual norm increases");
446  break;
447  }
448  }
449  } else {
450  resIncrease = 0;
451  }
452 
453  // if L2 broke down we turn off reliable updates and restart the CG
454  if (use_heavy_quark_res and L2breakdown) {
455  delta = 0;
456  heavy_quark_check = 1;
457  warningQuda("CG3: Restarting without reliable updates for heavy-quark residual");
458  restart = true;
459  L2breakdown = false;
460  if (heavy_quark_res > heavy_quark_res_old) {
461  hqresIncrease++;
462  warningQuda("CG3: new reliable HQ residual norm %e is greater than previous reliable residual norm %e", heavy_quark_res, heavy_quark_res_old);
463  // break out if we do not improve here anymore
464  if (hqresIncrease > hqmaxresIncrease) {
465  warningQuda("CG3: solver exiting due to too many heavy quark residual norm increases");
466  break;
467  }
468  }
469  }
470  } else {
// Uniform precision: once the iterated residual signals convergence,
// recompute the residual with mat before the loop condition is tested again.
471  if (convergence(r2, heavy_quark_res, stop, param.tol_hq)) {
472  mat(r, x, tmp, tmp2S);
473  r2 = blas::xmyNorm(b, r);
474  r0Norm = sqrt(r2);
475  // we update sloppy and old fields
476  if (!convergence(r2, heavy_quark_res, stop, param.tol_hq)) {
477  // we preserve the orthogonality between the previous residual and the new
478  Complex rr_old = blas::cDotProduct(rS, rS_old);
479  r2_old = blas::caxpyNorm(-rr_old/r2, rS, rS_old);
480  }
481  }
482 
483  // break-out check if we have reached the limit of the precision
484  if (sqrt(r2) > r0Norm) {
485  resIncrease++;
486  resIncreaseTotal++;
487  warningQuda(
488  "CG3: new reliable residual norm %e is greater than previous reliable residual norm %e (total #inc %i)",
489  sqrt(r2), r0Norm, resIncreaseTotal);
490  if (resIncrease > maxResIncrease or resIncreaseTotal > maxResIncreaseTotal) {
491  warningQuda("CG3: solver exiting due to too many true residual norm increases");
492  break;
493  }
494  }
495  }
496 
497  PrintStats("CG3", k, r2, b2, heavy_quark_res);
498  }
499 
// In mixed precision the accumulated solution is held in y; copy it into x.
500  if (mixed_precision) blas::copy(x, y);
503 
// Record the flop count (in units of 1e9) and add this call's iterations to
// param.iter.
505  double gflops = (blas::flops + mat.flops() + matSloppy.flops())*1e-9;
506  param.gflops = gflops;
507  param.iter += k;
508 
509  if (k == param.maxiter)
510  warningQuda("Exceeded maximum iterations %d", param.maxiter);
511 
512  // compute the true residuals
513  if (!mixed_precision && param.compute_true_res) {
514  mat(r, x, y, tmp);
515  param.true_res = sqrt(blas::xmyNorm(b, r) / b2);
516  if (use_heavy_quark_res) param.true_res_hq = sqrt(blas::HeavyQuarkResidualNorm(x, r).z);
517  }
518 
// The condition guarding this copy of r into b is on the skipped line 519.
520  blas::copy(b, r);
521  }
522 
523  PrintSummary("CG3", k, r2, b2, stop, param.tol_hq);
524 
525  // reset the flops counters
// NOTE(review): the return values are discarded, so flops() presumably resets
// its counter as a side effect - confirm in dirac_quda.h.
526  blas::flops = 0;
527  mat.flops();
528  matSloppy.flops();
529 
531  }
532 
533 } // namespace quda
virtual ~CG3()
CG3(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, SolverParam &param, TimeProfile &profile)
void operator()(ColorSpinorField &out, ColorSpinorField &in)
CG3NE(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, SolverParam &param, TimeProfile &profile)
virtual ~CG3NE()
void operator()(ColorSpinorField &out, ColorSpinorField &in)
CG3NR(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, SolverParam &param, TimeProfile &profile)
virtual ~CG3NR()
void operator()(ColorSpinorField &out, ColorSpinorField &in)
static ColorSpinorField * Create(const ColorSpinorParam &param)
virtual void M(ColorSpinorField &out, const ColorSpinorField &in) const =0
Apply M for the dirac op. E.g. the Schur Complement operator.
void Mdag(ColorSpinorField &out, const ColorSpinorField &in) const
Apply Mdag (the daggered operator of M).
Definition: dirac.cpp:92
const Dirac * Expose() const
Definition: dirac_quda.h:1964
bool isStaggered() const
return if the operator is a staggered operator
Definition: dirac_quda.h:1935
unsigned long long flops() const
Definition: dirac_quda.h:1909
QudaPrecision Precision() const
TimeProfile & profile
Definition: invert_quda.h:471
const DiracMatrix & mat
Definition: invert_quda.h:465
bool convergence(double r2, double hq2, double r2_tol, double hq_tol)
Definition: solver.cpp:328
bool convergenceHQ(double r2, double hq2, double r2_tol, double hq_tol)
Test for HQ solver convergence – ignore L2 residual.
Definition: solver.cpp:348
void PrintSummary(const char *name, int k, double r2, double b2, double r2_tol, double hq_tol)
Prints out the summary of the solver convergence (requires a verbosity of QUDA_SUMMARIZE)....
Definition: solver.cpp:386
SolverParam & param
Definition: invert_quda.h:470
static double stopping(double tol, double b2, QudaResidualType residual_type)
Set the solver L2 stopping condition.
Definition: solver.cpp:311
void PrintStats(const char *name, int k, double r2, double b2, double hq2)
Prints out the running statistics of the solver (requires a verbosity of QUDA_VERBOSE)
Definition: solver.cpp:373
const DiracMatrix & matSloppy
Definition: invert_quda.h:466
double Last(QudaProfileType idx)
Definition: timer.h:254
int pipeline
void mat(void *out, void **link, void *in, int dagger_bit, int mu, QudaPrecision sPrecision, QudaPrecision gPrecision)
cudaColorSpinorField * tmp
Definition: covdev_test.cpp:34
@ QUDA_CUDA_FIELD_LOCATION
Definition: enum_quda.h:326
@ QUDA_USE_INIT_GUESS_NO
Definition: enum_quda.h:429
@ QUDA_USE_INIT_GUESS_YES
Definition: enum_quda.h:430
@ QUDA_HEAVY_QUARK_RESIDUAL
Definition: enum_quda.h:195
@ QUDA_PRESERVE_SOURCE_NO
Definition: enum_quda.h:238
@ QUDA_PRESERVE_SOURCE_YES
Definition: enum_quda.h:239
@ QUDA_ZERO_FIELD_CREATE
Definition: enum_quda.h:361
@ QUDA_NULL_FIELD_CREATE
Definition: enum_quda.h:360
@ QUDA_COMPUTE_NULL_VECTOR_NO
Definition: enum_quda.h:441
#define checkLocation(...)
void init()
Create the BLAS context.
double quadrupleCG3UpdateNorm(double a, double b, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, ColorSpinorField &w, ColorSpinorField &v)
double3 HeavyQuarkResidualNorm(ColorSpinorField &x, ColorSpinorField &r)
double xmyNorm(ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.h:79
unsigned long long flops
double quadrupleCG3InitNorm(double a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, ColorSpinorField &w, ColorSpinorField &v)
double3 xpyHeavyQuarkResidualNorm(ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &r)
double caxpyNorm(const Complex &a, ColorSpinorField &x, ColorSpinorField &y)
void zero(ColorSpinorField &a)
double norm2(const ColorSpinorField &a)
double reDotProduct(ColorSpinorField &x, ColorSpinorField &y)
double axpyNorm(double a, ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.h:78
void axpy(double a, ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.h:43
void xpy(ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.h:41
void copy(ColorSpinorField &dst, const ColorSpinorField &src)
Definition: blas_quda.h:24
Complex cDotProduct(ColorSpinorField &, ColorSpinorField &)
void axpby(double a, ColorSpinorField &x, double b, ColorSpinorField &y)
Definition: blas_quda.h:44
void stop()
Stop profiling.
Definition: device.cpp:228
std::complex< double > Complex
Definition: quda_internal.h:86
__host__ __device__ ValueType sqrt(ValueType x)
Definition: complex_quda.h:120
@ QUDA_PROFILE_INIT
Definition: timer.h:106
@ QUDA_PROFILE_EPILOGUE
Definition: timer.h:110
@ QUDA_PROFILE_COMPUTE
Definition: timer.h:108
@ QUDA_PROFILE_PREAMBLE
Definition: timer.h:107
ColorSpinorParam csParam
Definition: pack_test.cpp:25
QudaGaugeParam param
Definition: pack_test.cpp:18
QudaPreserveSource preserve_source
Definition: invert_quda.h:151
QudaPrecision precision
Definition: invert_quda.h:136
QudaComputeNullVector compute_null_vector
Definition: invert_quda.h:61
int max_res_increase_total
Definition: invert_quda.h:90
QudaResidualType residual_type
Definition: invert_quda.h:49
QudaPrecision precision_sloppy
Definition: invert_quda.h:139
QudaUseInitGuess use_init_guess
Definition: invert_quda.h:58
#define printfQuda(...)
Definition: util_quda.h:114
#define warningQuda(...)
Definition: util_quda.h:132
#define errorQuda(...)
Definition: util_quda.h:120