QUDA  1.0.0
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
inv_cg3ne_quda.cpp
Go to the documentation of this file.
1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <math.h>
4 
5 #include <complex>
6 
7 #include <quda_internal.h>
8 #include <blas_quda.h>
9 #include <dslash_quda.h>
10 #include <invert_quda.h>
11 #include <util_quda.h>
12 #include <color_spinor_field.h>
13 
14 namespace quda {
15 
17  Solver(param, profile), mat(mat), matSloppy(matSloppy), matDagSloppy(matSloppy), init(false) // matDagSloppy is built from matSloppy; no work fields yet
18  { // work fields are allocated lazily on the first solve (see the !init branch in operator())
19  }
20 
22  if ( init ) { // work fields exist only once operator() has allocated them
23  delete rp;
24  delete yp;
25  delete AdagrSp;
26  delete AAdagrSp;
27  delete tmpSp;
29  delete rSp;
30  delete xSp;
31  delete xS_oldp;
32  delete rS_oldp;
33  }
34 
35  init = false; // fields released; a later solve would reallocate them
36  }
37  }
37  }
38 
40  { // Solve body: three-term CG iteration driven by AdagrS = matDagSloppy(rS) and AAdagrS = matSloppy(AdagrS); CG3NE vs CG3NR chosen via param.inv_type (is_cg3ne)
42  errorQuda("Not supported");
43  if (x.Precision() != param.precision || b.Precision() != param.precision)
44  errorQuda("Precision mismatch");
45 
46  profile.TPSTART(QUDA_PROFILE_INIT);
47 
48  // Check to see that we're not trying to invert on a zero-field source
49  double b2 = blas::norm2(b);
50  if(b2 == 0 &&
52  profile.TPSTOP(QUDA_PROFILE_INIT);
53  printfQuda("Warning: inverting on zero-field source\n");
54  x = b;
55  param.true_res = 0.0;
56  param.true_res_hq = 0.0;
57  return;
58  }
59 
60  const bool is_cg3ne = (param.inv_type == QUDA_CG3NE_INVERTER) ? true:false;
61  char name[6]; // "CG3NE"/"CG3NR": 5 characters plus the terminating NUL
62  strcpy(name, is_cg3ne ? "CG3NE" : "CG3NR");
63 
64  const bool mixed_precision = (param.precision != param.precision_sloppy);
65  if (!init) { // allocate work fields once; the destructor frees them when init is set
68  rp = ColorSpinorField::Create(csParam);
69  yp = ColorSpinorField::Create(csParam);
70 
71  // Sloppy fields
77  if(mixed_precision) {
78  rSp = ColorSpinorField::Create(csParam);
79  xSp = ColorSpinorField::Create(csParam);
81  } else {
82  xS_oldp = yp; // non-mixed: xS_old shares y's storage
83  }
84 
85  init = true;
86  }
87 
88  ColorSpinorField &r = *rp;
89  ColorSpinorField &y = *yp;
90  ColorSpinorField &rS = mixed_precision ? *rSp : r;
91  ColorSpinorField &xS = mixed_precision ? *xSp : x;
92  ColorSpinorField &AdagrS = *AdagrSp;
93  ColorSpinorField &AAdagrS = *AAdagrSp;
94  ColorSpinorField &rS_old = *rS_oldp;
95  ColorSpinorField &xS_old = *xS_oldp;
96  ColorSpinorField &tmpS = *tmpSp;
97 
98  double stop = stopping(param.tol, b2, param.residual_type); // stopping condition of solver
99 
100  const bool use_heavy_quark_res =
101  (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) ? true : false;
102 
103  // this parameter determines how many consecutive reliable update
104  // residual increases we tolerate before terminating the solver,
105  // i.e., how long do we want to keep trying to converge
106  const int maxResIncrease = param.max_res_increase; // check if we reached the limit of our tolerance
107  const int maxResIncreaseTotal = param.max_res_increase_total;
108  int resIncrease = 0;
109  int resIncreaseTotal = 0;
110 
111  // these are only used if we use the heavy_quark_res
112  const int hqmaxresIncrease = maxResIncrease + 1;
113  int heavy_quark_check = param.heavy_quark_check; // how often to check the heavy quark residual
114  double heavy_quark_res = 0.0; // heavy quark residual
115  double heavy_quark_res_old = 0.0; // previous heavy quark residual
116  int hqresIncrease = 0;
117  bool L2breakdown = false;
118 
119  int pipeline = param.pipeline; // nonzero: use the fused doubleCG3* kernels instead of separate copy/axpy calls
120 
121  profile.TPSTOP(QUDA_PROFILE_INIT);
123 
124  blas::flops = 0;
125 
126  // compute initial residual depending on whether we have an initial guess or not
127  double r2;
129  mat(r, x, y);
130  r2 = blas::xmyNorm(b, r);
131  if(b2==0) b2 = r2;
132  if (mixed_precision) {
133  blas::copy(y, x);
134  blas::zero(xS);
135  }
136  } else {
137  blas::copy(r, b);
138  r2 = b2;
139  blas::zero(x);
140  if (mixed_precision) {
141  blas::zero(y);
142  blas::zero(xS);
143  }
144  }
145  blas::copy(rS, r);
146 
147  if (use_heavy_quark_res) {
148  heavy_quark_res = sqrt(blas::HeavyQuarkResidualNorm(x, r).z);
149  heavy_quark_res_old = heavy_quark_res;
150  }
151 
153  if(convergence(r2, heavy_quark_res, stop, param.tol_hq)) {
155  blas::copy(b, r);
156  }
157  return;
158  }
159  profile.TPSTART(QUDA_PROFILE_COMPUTE);
160 
161  double r2_old = r2;
162  double rNorm = sqrt(r2);
163  double r0Norm = rNorm;
164  double maxrx = rNorm;
165  double maxrr = rNorm;
166  double delta = param.delta;
167  bool restart = false;
168 
169  int k = 0;
170  double rho = 1.0, gamma = 1.0, Ar2 = 1.0;
171  while ( !convergence(r2, heavy_quark_res, stop, param.tol_hq) && k < param.maxiter) {
172 
173  matDagSloppy(AdagrS, rS, tmpS); // AdagrS = matDagSloppy applied to rS
174  matSloppy(AAdagrS, AdagrS, tmpS); // AAdagrS = matSloppy applied to AdagrS
175  double gamma_old = gamma;
176  double Ar2_old = Ar2;
177  Ar2 = blas::norm2(AdagrS);
178  if(is_cg3ne) {
179  gamma = r2/Ar2;
180  } else {
181  double AAr2 = blas::norm2(AAdagrS);
182  gamma = Ar2/AAr2;
183  }
184  // CG3 step (shared by CG3NE and CG3NR; only gamma and rho differ)
185  if(k==0 || restart) { // First iteration
186  if(pipeline) {
187  blas::doubleCG3Init(gamma, xS, xS_old, AdagrS);
188  r2 = blas::doubleCG3InitNorm(gamma, rS, rS_old, AAdagrS);
189  } else {
190  blas::copy(xS_old, xS);
191  blas::copy(rS_old, rS);
192 
193  blas::axpy(gamma, AdagrS, xS); // x += gamma*Adag*r
194  r2 = blas::axpyNorm(-gamma, AAdagrS, rS); // r -= gamma*w
195  }
196  restart = false;
197  } else {
198  if(is_cg3ne) {
199  rho = rho/(rho-(gamma/gamma_old)*(r2/r2_old));
200  } else {
201  rho = rho/(rho-(gamma/gamma_old)*(Ar2/Ar2_old));
202  }
203  r2_old = r2;
204 
205  if(pipeline) {
206  blas::doubleCG3Update(gamma, rho, xS, xS_old, AdagrS);
207  r2 = blas::doubleCG3UpdateNorm(gamma, rho, rS, rS_old, AAdagrS);
208  } else {
209  blas::copy(tmpS, xS);
210  blas::axpby(gamma*rho, AdagrS, rho, xS);
211  blas::axpy(1.-rho, xS_old, xS);
212  blas::copy(xS_old, tmpS);
213 
214  blas::copy(tmpS, rS);
215  blas::axpby(-gamma*rho, AAdagrS, rho, rS);
216  r2 = blas::axpyNorm(1.-rho, rS_old, rS);
217  blas::copy(rS_old, tmpS);
218  }
219  }
220 
221  k++;
222 
223  if (use_heavy_quark_res && k%heavy_quark_check==0) {
224  heavy_quark_res_old = heavy_quark_res;
225  if (mixed_precision) {
226  blas::copy(tmpS,y);
227  heavy_quark_res = sqrt(blas::xpyHeavyQuarkResidualNorm(xS, tmpS, rS).z);
228  } else {
229  heavy_quark_res = sqrt(blas::HeavyQuarkResidualNorm(xS, rS).z);
230  }
231  }
232 
233  // reliable update conditions
234  if (mixed_precision) {
235  rNorm = sqrt(r2);
236 
237  if (rNorm > maxrx) maxrx = rNorm;
238  if (rNorm > maxrr) maxrr = rNorm;
239  bool update = (rNorm < delta*r0Norm && r0Norm <= maxrx); // condition for x
240  update = ( update || (rNorm < delta*maxrr && r0Norm <= maxrr)); // condition for r
241 
242  // force a reliable update if we are within target tolerance (only if doing reliable updates)
243  if ( convergence(r2, heavy_quark_res, stop, param.tol_hq) && param.delta >= param.tol ) update = true;
244 
245  // For heavy-quark inversion force a reliable update if we continue after an L2 breakdown
246  if ( use_heavy_quark_res and L2breakdown and convergenceHQ(r2, heavy_quark_res, stop, param.tol_hq) and param.delta >= param.tol ) {
247  update = true;
248  }
249 
250  if (update) {
251  // updating the "new" vectors
252  blas::copy(x, xS);
253  blas::xpy(x, y);
254  mat(r, y, x); // here we can use x as tmp
255  r2 = blas::xmyNorm(b, r);
256  param.true_res = sqrt(r2 / b2);
257  if (use_heavy_quark_res) {
258  heavy_quark_res = sqrt(blas::HeavyQuarkResidualNorm(y, r).z);
259  param.true_res_hq = heavy_quark_res;
260  }
261  // we update sloppy and old fields
262  if (!convergence(r2, heavy_quark_res, stop, param.tol_hq)) {
263  blas::copy(rS, r);
264  blas::axpy(-1., xS, xS_old);
265  if(is_cg3ne) {
266  // we preserve the orthogonality between the previous residual and the new one
267  // This is possible only for cg3ne. For cg3nr ArS and ArS_old should be orthogonal
268  Complex rr_old = blas::cDotProduct(rS, rS_old);
269  r2_old = blas::caxpyNorm(-rr_old/r2, rS, rS_old);
270  blas::zero(xS); // NOTE(review): y already absorbed xS above (xpy(x, y)), yet xS is reset only for CG3NE; for CG3NR a later reliable update would add xS into y again - confirm whether this zero belongs outside the is_cg3ne branch
271  }
272  }
273  }
274 
275  // break-out check if we have reached the limit of the precision
276  if (r2 > r2_old) {
277  resIncrease++;
278  resIncreaseTotal++;
279  warningQuda("%s: new reliable residual norm %e is greater than previous reliable residual norm %e (total #inc %i)",
280  name, sqrt(r2), sqrt(r2_old), resIncreaseTotal);
281  if (resIncrease > maxResIncrease or resIncreaseTotal > maxResIncreaseTotal) {
282  if (use_heavy_quark_res) {
283  L2breakdown = true;
284  } else {
285  warningQuda("%s: solver exiting due to too many true residual norm increases", name);
286  break;
287  }
288  }
289  } else {
290  resIncrease = 0;
291  }
292 
293  // if L2 broke down we turn off reliable updates and restart the CG
294  if (use_heavy_quark_res and L2breakdown) {
295  delta = 0;
296  heavy_quark_check = 1;
297  warningQuda("%s: Restarting without reliable updates for heavy-quark residual", name);
298  restart = true;
299  L2breakdown = false;
300  if (heavy_quark_res > heavy_quark_res_old) {
301  hqresIncrease++;
302  warningQuda("%s: new reliable HQ residual norm %e is greater than previous reliable residual norm %e", name, heavy_quark_res, heavy_quark_res_old);
303  // break out if we do not improve here anymore
304  if (hqresIncrease > hqmaxresIncrease) {
305  warningQuda("%s: solver exiting due to too many heavy quark residual norm increases", name);
306  break;
307  }
308  }
309  }
310  } else {
311  if (convergence(r2, heavy_quark_res, stop, param.tol_hq)) {
312  mat(r, x, tmpS); // non-mixed: rS aliases r, so the iterated residual is replaced by the true one
313  r2 = blas::xmyNorm(b, r);
314  // we update sloppy and old fields
315  if (!convergence(r2, heavy_quark_res, stop, param.tol_hq) and is_cg3ne) {
316  // we preserve the orthogonality between the previous residual and the new one
317  // This is possible only for cg3ne. For cg3nr ArS and ArS_old should be orthogonal
318  Complex rr_old = blas::cDotProduct(rS, rS_old);
319  r2_old = blas::caxpyNorm(-rr_old/r2, rS, rS_old);
320  }
321  }
322 
323  // break-out check if we have reached the limit of the precision (NOTE(review): messages below hard-code "CG3:" whereas the mixed-precision branch prints name)
324  if (r2 > r2_old) {
325  resIncrease++;
326  resIncreaseTotal++;
327  warningQuda("CG3: new reliable residual norm %e is greater than previous reliable residual norm %e (total #inc %i)",
328  sqrt(r2), sqrt(r2_old), resIncreaseTotal);
329  if (resIncrease > maxResIncrease or resIncreaseTotal > maxResIncreaseTotal) {
330  warningQuda("CG3: solver exiting due to too many true residual norm increases");
331  break;
332  }
333  }
334  }
335 
336  PrintStats(name, k, r2, b2, heavy_quark_res);
337  }
338 
341 
343  double gflops = (blas::flops + mat.flops() + matSloppy.flops())*1e-9;
344  param.gflops = gflops;
345  param.iter += k;
346 
347  if (k == param.maxiter)
348  warningQuda("Exceeded maximum iterations %d", param.maxiter);
349 
350 
351  // compute the true residuals
352  if (!mixed_precision && param.compute_true_res) {
353  mat(r, x, y);
354  param.true_res = sqrt(blas::xmyNorm(b, r) / b2);
355  if (use_heavy_quark_res) param.true_res_hq = sqrt(blas::HeavyQuarkResidualNorm(x, r).z);
356  }
357 
359  blas::copy(b, r);
360  }
361 
362  PrintSummary(name, k, r2, b2, stop, param.tol_hq);
363 
364  // reset the flops counters
365  blas::flops = 0;
366  mat.flops();
367  matSloppy.flops();
368 
370 
371  return;
372  }
373 
374 } // namespace quda
void setPrecision(QudaPrecision precision, QudaPrecision ghost_precision=QUDA_INVALID_PRECISION, bool force_native=false)
int pipeline
Definition: test_util.cpp:1634
bool convergenceHQ(double r2, double hq2, double r2_tol, double hq_tol)
Test for HQ solver convergence – ignore L2 residual.
Definition: solver.cpp:237
double caxpyNorm(const Complex &a, ColorSpinorField &x, ColorSpinorField &y)
Definition: reduce_quda.cu:746
ColorSpinorField * rp
Definition: invert_quda.h:632
QudaInverterType inv_type
Definition: invert_quda.h:22
#define errorQuda(...)
Definition: util_quda.h:121
double norm2(const ColorSpinorField &a)
Definition: reduce_quda.cu:721
__host__ __device__ ValueType sqrt(ValueType x)
Definition: complex_quda.h:120
Complex cDotProduct(ColorSpinorField &, ColorSpinorField &)
Definition: reduce_quda.cu:764
ColorSpinorField * yp
Definition: invert_quda.h:632
void PrintStats(const char *name, int k, double r2, double b2, double hq2)
Prints out the running statistics of the solver (requires a verbosity of QUDA_VERBOSE) ...
Definition: solver.cpp:256
double3 xpyHeavyQuarkResidualNorm(ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &r)
Definition: reduce_quda.cu:818
static ColorSpinorField * Create(const ColorSpinorParam &param)
bool convergence(double r2, double hq2, double r2_tol, double hq_tol)
Definition: solver.cpp:223
TimeProfile & profile
Definition: invert_quda.h:464
void copy(ColorSpinorField &dst, const ColorSpinorField &src)
Definition: copy_quda.cu:355
double xmyNorm(ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.h:75
ColorSpinorField * rS_oldp
Definition: invert_quda.h:632
QudaPreserveSource preserve_source
Definition: invert_quda.h:154
int max_res_increase_total
Definition: invert_quda.h:96
const DiracMatrix & matSloppy
Definition: invert_quda.h:629
void doubleCG3Update(double a, double b, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
Definition: blas_quda.cu:631
void operator()(ColorSpinorField &out, ColorSpinorField &in)
CG3NE(DiracMatrix &mat, DiracMatrix &matSloppy, SolverParam &param, TimeProfile &profile)
QudaGaugeParam param
Definition: pack_test.cpp:17
QudaComputeNullVector compute_null_vector
Definition: invert_quda.h:67
double Last(QudaProfileType idx)
Definition: timer.h:251
QudaResidualType residual_type
Definition: invert_quda.h:49
static double stopping(double tol, double b2, QudaResidualType residual_type)
Set the solver L2 stopping condition.
Definition: solver.cpp:206
ColorSpinorParam csParam
Definition: pack_test.cpp:24
void axpy(double a, ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.h:35
ColorSpinorField * rSp
Definition: invert_quda.h:632
#define warningQuda(...)
Definition: util_quda.h:133
#define checkLocation(...)
DiracDagger matDagSloppy
Definition: invert_quda.h:630
double3 HeavyQuarkResidualNorm(ColorSpinorField &x, ColorSpinorField &r)
Definition: reduce_quda.cu:809
std::complex< double > Complex
Definition: quda_internal.h:46
void init()
Create the CUBLAS context.
Definition: blas_cublas.cu:31
void zero(ColorSpinorField &a)
Definition: blas_quda.cu:472
void doubleCG3Init(double a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
Definition: blas_quda.cu:626
virtual ~CG3NE()
QudaPrecision precision
Definition: invert_quda.h:142
ColorSpinorField * xSp
Definition: invert_quda.h:632
SolverParam & param
Definition: invert_quda.h:463
ColorSpinorField * xS_oldp
Definition: invert_quda.h:632
ColorSpinorField * AAdagrSp
Definition: invert_quda.h:632
unsigned long long flops() const
Definition: dirac_quda.h:1119
#define printfQuda(...)
Definition: util_quda.h:115
double doubleCG3UpdateNorm(double a, double b, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
Definition: reduce_quda.cu:853
const DiracMatrix & mat
Definition: invert_quda.h:628
unsigned long long flops
Definition: blas_quda.cu:22
void xpy(ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.h:33
void axpby(double a, ColorSpinorField &x, double b, ColorSpinorField &y)
Definition: blas_quda.h:36
double axpyNorm(double a, ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.h:74
QudaUseInitGuess use_init_guess
Definition: invert_quda.h:64
ColorSpinorField * AdagrSp
Definition: invert_quda.h:632
QudaPrecision precision_sloppy
Definition: invert_quda.h:145
void PrintSummary(const char *name, int k, double r2, double b2, double r2_tol, double hq_tol)
Prints out the summary of the solver convergence (requires a verbosity of QUDA_SUMMARIZE). Assumes SolverParam.true_res and SolverParam.true_res_hq have been set.
Definition: solver.cpp:270
double doubleCG3InitNorm(double a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
Definition: reduce_quda.cu:848
void mat(void *out, void **link, void *in, int dagger_bit, int mu, QudaPrecision sPrecision, QudaPrecision gPrecision)
QudaPrecision Precision() const
ColorSpinorField * tmpSp
Definition: invert_quda.h:632