QUDA  1.0.0
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
inv_cg3_quda.cpp
Go to the documentation of this file.
1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <math.h>
4 #include <iostream>
5 
6 #include <complex>
7 
8 #include <quda_internal.h>
9 #include <blas_quda.h>
10 #include <dslash_quda.h>
11 #include <invert_quda.h>
12 #include <util_quda.h>
13 
14 namespace quda {
15 
// CG3 constructor (its signature line is not present in this listing).
// The member-initializer list passes param and profile to the Solver base,
// initializes the mat and matSloppy members from the arguments, and starts
// with init == false so that the work fields are only allocated later, on
// the first solve. The constructor body itself is empty.
17  Solver(param, profile), mat(mat), matSloppy(matSloppy), init(false)
18  {
19  }
20 
// CG3 destructor body (its signature line is not present in this listing).
// Work fields are released only when init is set, i.e. after the solve
// routine has allocated them.
// NOTE(review): listing line 27 is absent, so the brace nesting shown below
// is incomplete; it presumably opened a conditional around some of the
// deletes that follow it - confirm against the real file.
22  if ( init ) {
23  delete rp;
24  delete yp;
25  delete tmpp;
26  delete ArSp;
28  delete rSp;
29  delete xSp;
30  delete xS_oldp;
31  delete tmpSp;
32  delete rS_oldp;
33  }
// tmp2Sp is only deleted for non-staggered operators: in the staggered case
// the solve routine assigns tmp2Sp = tmpSp, so it is an alias, not a
// separate field.
34  if(!mat.isStaggered()) delete tmp2Sp;
35 
// Clear the flag so the fields are not treated as allocated any more.
36  init = false;
37  }
38  }
39 
// Body of the CG3 solve routine (its signature line is not present in this
// listing). Working on the solution field x and the source field b, it:
//   1. checks precisions and returns early for a zero source,
//   2. allocates the work fields on the first call (guarded by init),
//   3. forms the initial residual r and its sloppy copy rS,
//   4. runs the CG3 recurrence until convergence() holds or param.maxiter
//      iterations have been done, using reliable updates when
//      param.precision != param.precision_sloppy,
//   5. stores gflops, iteration count and true residuals in param and
//      resets the flop counters.
// NOTE(review): a number of listing lines (e.g. 42, 52, 62, 64, 70, 72, 83,
// 132, 156, 158, 329-332, 348, 359) are absent, so some conditions and
// braces below appear without their opening statement.
41  {
// NOTE(review): the condition guarding this error (listing line 42) is missing.
43  errorQuda("Not supported");
// Both fields must have the solver's full precision.
44  if (x.Precision() != param.precision || b.Precision() != param.precision)
45  errorQuda("Precision mismatch");
46 
47  profile.TPSTART(QUDA_PROFILE_INIT);
48 
49  // Check to see that we're not trying to invert on a zero-field source
50  double b2 = blas::norm2(b);
51  if(b2 == 0 &&
// NOTE(review): the remainder of this condition (listing line 52) is missing.
// In this branch x is set to b and both true residuals are reported as 0.
53  profile.TPSTOP(QUDA_PROFILE_INIT);
54  printfQuda("Warning: inverting on zero-field source\n");
55  x = b;
56  param.true_res = 0.0;
57  param.true_res_hq = 0.0;
58  return;
59  }
60 
// Mixed precision: the sloppy precision differs from the full precision.
61  const bool mixed_precision = (param.precision != param.precision_sloppy);
// First call only: allocate the work fields; init is set to true at the end
// of this section so later calls reuse them.
// NOTE(review): the lines that set up csParam (62, 64, 70, 72) and some
// allocations (76, 77, 83) are missing from this listing.
63  if (!init) {
65  rp = ColorSpinorField::Create(csParam);
66  tmpp = ColorSpinorField::Create(csParam);
67  yp = ColorSpinorField::Create(csParam);
68 
69  // Sloppy fields
71  ArSp = ColorSpinorField::Create(csParam);
73  if(mixed_precision) {
74  rSp = ColorSpinorField::Create(csParam);
75  xSp = ColorSpinorField::Create(csParam);
78  } else {
// Uniform precision: these sloppy pointers alias full-precision fields
// instead of owning separate ones.
79  xS_oldp = yp;
80  tmpSp = tmpp;
81  }
82  if(!mat.isStaggered()) {
84  } else {
// Staggered operator: tmp2Sp is an alias of tmpSp.
85  tmp2Sp = tmpSp;
86  }
87 
88  init = true;
89  }
90 
// Reference aliases for the work fields. In uniform precision rS and xS
// refer directly to r and x.
91  ColorSpinorField &r = *rp;
92  ColorSpinorField &y = *yp;
93  ColorSpinorField &rS = mixed_precision ? *rSp : r;
94  ColorSpinorField &xS = mixed_precision ? *xSp : x;
95  ColorSpinorField &ArS = *ArSp;
96  ColorSpinorField &rS_old = *rS_oldp;
97  ColorSpinorField &xS_old = *xS_oldp;
99  ColorSpinorField &tmpS = *tmpSp;
100  ColorSpinorField &tmp2S = *tmp2Sp;
101 
102  double stop = stopping(param.tol, b2, param.residual_type); // stopping condition of solver
103 
104  const bool use_heavy_quark_res =
105  (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) ? true : false;
106 
107  // this parameter determines how many consecutive reliable update
108  // residual increases we tolerate before terminating the solver,
109  // i.e., how long do we want to keep trying to converge
110  const int maxResIncrease = param.max_res_increase; // check if we reached the limit of our tolerance
111  const int maxResIncreaseTotal = param.max_res_increase_total;
112  int resIncrease = 0;
113  int resIncreaseTotal = 0;
114 
115  // these are only used if we use the heavy_quark_res
116  const int hqmaxresIncrease = maxResIncrease + 1;
117  int heavy_quark_check = param.heavy_quark_check; // how often to check the heavy quark residual
118  double heavy_quark_res = 0.0; // heavy quark residual
119  double heavy_quark_res_old = 0.0; // heavy quark residual
120  int hqresIncrease = 0;
121  bool L2breakdown = false;
122 
// Non-zero selects the fused quadrupleCG3*Norm kernels in the loop below.
123  int pipeline = param.pipeline;
124 
125  profile.TPSTOP(QUDA_PROFILE_INIT);
127 
128  blas::flops = 0;
129 
130  // compute initial residual depending on whether we have an initial guess or not
131  double r2;
// NOTE(review): the test opening this if/else (listing line 132) is missing.
// First branch: apply mat to x and form the residual against b with
// xmyNorm (presumably r = b - r, returning its squared norm - confirm).
133  mat(r, x, y, tmp);
134  r2 = blas::xmyNorm(b, r);
135  if(b2==0) b2 = r2;
136  if (mixed_precision) {
137  blas::copy(y, x);
138  blas::zero(xS);
139  }
140  } else {
// Second branch: the residual is the source itself and x is zeroed.
141  blas::copy(r, b);
142  r2 = b2;
143  blas::zero(x);
144  if (mixed_precision) {
145  blas::zero(y);
146  blas::zero(xS);
147  }
148  }
149  blas::copy(rS, r);
150 
151  if (use_heavy_quark_res) {
152  heavy_quark_res = sqrt(blas::HeavyQuarkResidualNorm(x, r).z);
153  heavy_quark_res_old = heavy_quark_res;
154  }
155 
// Already converged before iterating: return without entering the loop.
// NOTE(review): listing lines 156 and 158 are missing; line 158 presumably
// held the condition under which b is overwritten with r.
157  if(convergence(r2, heavy_quark_res, stop, param.tol_hq)) {
159  blas::copy(b, r);
160  }
161  return;
162  }
163  profile.TPSTART(QUDA_PROFILE_COMPUTE);
164 
// State for the reliable-update tests (only used in mixed precision).
165  double r2_old = r2;
166  double rNorm = sqrt(r2);
167  double r0Norm = rNorm;
168  double maxrx = rNorm;
169  double maxrr = rNorm;
170  double delta = param.delta;
171  bool restart = false;
172 
173  int k = 0;
174  double rho = 1.0, gamma = 1.0;
175  while ( !convergence(r2, heavy_quark_res, stop, param.tol_hq) && k < param.maxiter) {
176 
// ArS receives the sloppy operator applied to rS; gamma = r2 / <rS, ArS>.
177  matSloppy(ArS, rS, tmpS, tmp2S);
178  double gamma_old = gamma;
179  double rAr = blas::reDotProduct(rS,ArS);
180  gamma = r2/rAr;
181 
182  // CG3 step
183  if(k==0 || restart) { // First iteration
184  if(pipeline) {
185  r2 = blas::quadrupleCG3InitNorm(gamma, xS, rS, xS_old, rS_old, ArS);
186  } else {
// Save the current x and r as the "old" fields, then take the first step.
187  blas::copy(xS_old, xS);
188  blas::copy(rS_old, rS);
189 
190  blas::axpy(gamma, rS, xS); // x += gamma*r
191  r2 = blas::axpyNorm(-gamma, ArS, rS); // r -= gamma*w
192  }
193  restart = false;
194  } else {
// Update rho from the current and previous gamma and residual norms.
195  rho = rho/(rho-(gamma/gamma_old)*(r2/r2_old));
196  r2_old = r2;
197 
198  if(pipeline) {
199  r2 = blas::quadrupleCG3UpdateNorm(gamma, rho, xS, rS, xS_old, rS_old, ArS);
200  } else {
// tmpS / tmp2S keep the pre-update x and r so they can become the "old"
// fields once the new x and r have been formed from the old ones.
201  blas::copy(tmpS, xS);
202  blas::copy(tmp2S, rS);
203 
204  blas::axpby(gamma*rho, rS, rho, xS);
205  blas::axpby(-gamma*rho, ArS, rho, rS);
206 
207  blas::axpy(1.-rho, xS_old, xS);
208  r2 = blas::axpyNorm(1.-rho, rS_old, rS);
209 
210  blas::copy(xS_old, tmpS);
211  blas::copy(rS_old, tmp2S);
212  }
213  }
214 
215  k++;
216 
// Refresh the heavy-quark residual every heavy_quark_check iterations.
217  if (use_heavy_quark_res && k%heavy_quark_check==0) {
218  heavy_quark_res_old = heavy_quark_res;
219  if (mixed_precision) {
220  blas::copy(tmpS,y);
221  heavy_quark_res = sqrt(blas::xpyHeavyQuarkResidualNorm(xS, tmpS, rS).z);
222  } else {
223  heavy_quark_res = sqrt(blas::HeavyQuarkResidualNorm(xS, rS).z);
224  }
225  }
226 
227  // reliable update conditions
228  if (mixed_precision) {
229  rNorm = sqrt(r2);
230 
231  if (rNorm > maxrx) maxrx = rNorm;
232  if (rNorm > maxrr) maxrr = rNorm;
233  bool update = (rNorm < delta*r0Norm && r0Norm <= maxrx); // condition for x
234  update = ( update || (rNorm < delta*maxrr && r0Norm <= maxrr)); // condition for r
235 
236  // force a reliable update if we are within target tolerance (only if doing reliable updates)
237  if ( convergence(r2, heavy_quark_res, stop, param.tol_hq) && param.delta >= param.tol ) update = true;
238 
239  // For heavy-quark inversion force a reliable update if we continue after
240  if ( use_heavy_quark_res and L2breakdown and convergenceHQ(r2, heavy_quark_res, stop, param.tol_hq) and param.delta >= param.tol ) {
241  update = true;
242  }
243 
244  if (update) {
245  // updating the "new" vectors
// The sloppy solution is copied to x and combined into y (presumably
// y += x - confirm), then the residual is recomputed from y with the
// full-precision operator.
246  blas::copy(x, xS);
247  blas::xpy(x, y);
248  mat(r, y, x, tmp); // here we can use x as tmp
249  r2 = blas::xmyNorm(b, r);
250  param.true_res = sqrt(r2 / b2);
251  if (use_heavy_quark_res) {
252  heavy_quark_res = sqrt(blas::HeavyQuarkResidualNorm(y, r).z);
253  param.true_res_hq = heavy_quark_res;
254  }
255  // we update sloppy and old fields
256  if (!convergence(r2, heavy_quark_res, stop, param.tol_hq)) {
257  blas::copy(rS, r);
258  blas::axpy(-1., xS, xS_old);
259  // we preserve the orthogonality between the previous residual and the new
260  Complex rr_old = blas::cDotProduct(rS, rS_old);
261  r2_old = blas::caxpyNorm(-rr_old/r2, rS, rS_old);
262  blas::zero(xS);
263  }
264  }
265 
266  // break-out check if we have reached the limit of the precision
267  if (r2 > r2_old) {
268  resIncrease++;
269  resIncreaseTotal++;
270  warningQuda("CG3: new reliable residual norm %e is greater than previous reliable residual norm %e (total #inc %i)",
271  sqrt(r2), sqrt(r2_old), resIncreaseTotal);
272  if (resIncrease > maxResIncrease or resIncreaseTotal > maxResIncreaseTotal) {
// With the heavy-quark residual enabled the solver does not exit here; it
// flags an L2 breakdown, which is handled just below.
273  if (use_heavy_quark_res) {
274  L2breakdown = true;
275  } else {
276  warningQuda("CG3: solver exiting due to too many true residual norm increases");
277  break;
278  }
279  }
280  } else {
281  resIncrease = 0;
282  }
283 
284  // if L2 broke down we turn off reliable updates and restart the CG
285  if (use_heavy_quark_res and L2breakdown) {
286  delta = 0;
287  heavy_quark_check = 1;
288  warningQuda("CG3: Restarting without reliable updates for heavy-quark residual");
289  restart = true;
290  L2breakdown = false;
291  if (heavy_quark_res > heavy_quark_res_old) {
292  hqresIncrease++;
293  warningQuda("CG3: new reliable HQ residual norm %e is greater than previous reliable residual norm %e", heavy_quark_res, heavy_quark_res_old);
294  // break out if we do not improve here anymore
295  if (hqresIncrease > hqmaxresIncrease) {
296  warningQuda("CG3: solver exiting due to too many heavy quark residual norm increases");
297  break;
298  }
299  }
300  }
301  } else {
// Uniform precision: once the iterated residual signals convergence, the
// residual is recomputed from x with mat and the test is repeated on it.
302  if (convergence(r2, heavy_quark_res, stop, param.tol_hq)) {
303  mat(r, x, tmp, tmp2S);
304  r2 = blas::xmyNorm(b, r);
305  // we update sloppy and old fields
306  if (!convergence(r2, heavy_quark_res, stop, param.tol_hq)) {
307  // we preserve the orthogonality between the previous residual and the new
308  Complex rr_old = blas::cDotProduct(rS, rS_old);
309  r2_old = blas::caxpyNorm(-rr_old/r2, rS, rS_old);
310  }
311  }
312 
313  // break-out check if we have reached the limit of the precision
314  if (r2 > r2_old) {
315  resIncrease++;
316  resIncreaseTotal++;
317  warningQuda("CG3: new reliable residual norm %e is greater than previous reliable residual norm %e (total #inc %i)",
318  sqrt(r2), sqrt(r2_old), resIncreaseTotal);
319  if (resIncrease > maxResIncrease or resIncreaseTotal > maxResIncreaseTotal) {
320  warningQuda("CG3: solver exiting due to too many true residual norm increases");
321  break;
322  }
323  }
324  }
325 
326  PrintStats("CG3", k, r2, b2, heavy_quark_res);
327  }
328 
// NOTE(review): listing lines 329, 330 and 332 are missing here.
331 
// Report the accumulated flop count (in units of 1e9) and iteration count.
333  double gflops = (blas::flops + mat.flops() + matSloppy.flops())*1e-9;
334  param.gflops = gflops;
335  param.iter += k;
336 
337  if (k == param.maxiter)
338  warningQuda("Exceeded maximum iterations %d", param.maxiter);
339 
340 
341  // compute the true residuals
// Only done here in uniform precision; in mixed precision param.true_res is
// set inside the reliable-update block of the loop.
342  if (!mixed_precision && param.compute_true_res) {
343  mat(r, x, y, tmp);
344  param.true_res = sqrt(blas::xmyNorm(b, r) / b2);
345  if (use_heavy_quark_res) param.true_res_hq = sqrt(blas::HeavyQuarkResidualNorm(x, r).z);
346  }
347 
// NOTE(review): the condition opening this block (listing line 348) is missing.
349  blas::copy(b, r);
350  }
351 
352  PrintSummary("CG3", k, r2, b2, stop, param.tol_hq);
353 
354  // reset the flops counters
355  blas::flops = 0;
356  mat.flops();
357  matSloppy.flops();
358 
360 
361  return;
362  }
363 
364 } // namespace quda
const DiracMatrix & mat
Definition: invert_quda.h:610
void setPrecision(QudaPrecision precision, QudaPrecision ghost_precision=QUDA_INVALID_PRECISION, bool force_native=false)
int pipeline
Definition: test_util.cpp:1634
bool convergenceHQ(double r2, double hq2, double r2_tol, double hq_tol)
Test for HQ solver convergence – ignore L2 residual.
Definition: solver.cpp:237
double caxpyNorm(const Complex &a, ColorSpinorField &x, ColorSpinorField &y)
Definition: reduce_quda.cu:746
double quadrupleCG3InitNorm(double a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, ColorSpinorField &w, ColorSpinorField &v)
Definition: reduce_quda.cu:838
#define errorQuda(...)
Definition: util_quda.h:121
double norm2(const ColorSpinorField &a)
Definition: reduce_quda.cu:721
__host__ __device__ ValueType sqrt(ValueType x)
Definition: complex_quda.h:120
Complex cDotProduct(ColorSpinorField &, ColorSpinorField &)
Definition: reduce_quda.cu:764
void PrintStats(const char *name, int k, double r2, double b2, double hq2)
Prints out the running statistics of the solver (requires a verbosity of QUDA_VERBOSE) ...
Definition: solver.cpp:256
cudaColorSpinorField * tmp
Definition: covdev_test.cpp:44
double3 xpyHeavyQuarkResidualNorm(ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &r)
Definition: reduce_quda.cu:818
virtual ~CG3()
static ColorSpinorField * Create(const ColorSpinorParam &param)
bool convergence(double r2, double hq2, double r2_tol, double hq_tol)
Definition: solver.cpp:223
double reDotProduct(ColorSpinorField &x, ColorSpinorField &y)
Definition: reduce_quda.cu:728
TimeProfile & profile
Definition: invert_quda.h:464
void copy(ColorSpinorField &dst, const ColorSpinorField &src)
Definition: copy_quda.cu:355
ColorSpinorField * tmp2Sp
Definition: invert_quda.h:613
ColorSpinorField * rp
Definition: invert_quda.h:613
double xmyNorm(ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.h:75
QudaPreserveSource preserve_source
Definition: invert_quda.h:154
int max_res_increase_total
Definition: invert_quda.h:96
ColorSpinorField * xSp
Definition: invert_quda.h:613
ColorSpinorField * rS_oldp
Definition: invert_quda.h:613
void operator()(ColorSpinorField &out, ColorSpinorField &in)
ColorSpinorField * tmpp
Definition: invert_quda.h:613
QudaGaugeParam param
Definition: pack_test.cpp:17
QudaComputeNullVector compute_null_vector
Definition: invert_quda.h:67
double Last(QudaProfileType idx)
Definition: timer.h:251
CG3(DiracMatrix &mat, DiracMatrix &matSloppy, SolverParam &param, TimeProfile &profile)
QudaResidualType residual_type
Definition: invert_quda.h:49
static double stopping(double tol, double b2, QudaResidualType residual_type)
Set the solver L2 stopping condition.
Definition: solver.cpp:206
ColorSpinorParam csParam
Definition: pack_test.cpp:24
void axpy(double a, ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.h:35
#define warningQuda(...)
Definition: util_quda.h:133
#define checkLocation(...)
double3 HeavyQuarkResidualNorm(ColorSpinorField &x, ColorSpinorField &r)
Definition: reduce_quda.cu:809
std::complex< double > Complex
Definition: quda_internal.h:46
void init()
Create the CUBLAS context.
Definition: blas_cublas.cu:31
void zero(ColorSpinorField &a)
Definition: blas_quda.cu:472
ColorSpinorField * rSp
Definition: invert_quda.h:613
QudaPrecision precision
Definition: invert_quda.h:142
SolverParam & param
Definition: invert_quda.h:463
unsigned long long flops() const
Definition: dirac_quda.h:1119
ColorSpinorField * xS_oldp
Definition: invert_quda.h:613
#define printfQuda(...)
Definition: util_quda.h:115
unsigned long long flops
Definition: blas_quda.cu:22
ColorSpinorField * tmpSp
Definition: invert_quda.h:613
void xpy(ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.h:33
void axpby(double a, ColorSpinorField &x, double b, ColorSpinorField &y)
Definition: blas_quda.h:36
double axpyNorm(double a, ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.h:74
QudaUseInitGuess use_init_guess
Definition: invert_quda.h:64
QudaPrecision precision_sloppy
Definition: invert_quda.h:145
void PrintSummary(const char *name, int k, double r2, double b2, double r2_tol, double hq_tol)
Prints out the summary of the solver convergence (requires a verbosity of QUDA_SUMMARIZE). Assumes SolverParam.true_res and SolverParam.true_res_hq have been set.
Definition: solver.cpp:270
ColorSpinorField * yp
Definition: invert_quda.h:613
const DiracMatrix & matSloppy
Definition: invert_quda.h:611
void mat(void *out, void **link, void *in, int dagger_bit, int mu, QudaPrecision sPrecision, QudaPrecision gPrecision)
QudaPrecision Precision() const
bool isStaggered() const
Definition: dirac_quda.h:1128
double quadrupleCG3UpdateNorm(double a, double b, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, ColorSpinorField &w, ColorSpinorField &v)
Definition: reduce_quda.cu:843
ColorSpinorField * ArSp
Definition: invert_quda.h:613