QUDA  v1.1.0
A library for QCD on GPUs
inv_msrc_cg_quda.cpp
Go to the documentation of this file.
1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <math.h>
4 
5 #include <quda_internal.h>
6 #include <color_spinor_field.h>
7 #include <blas_quda.h>
8 #include <dslash_quda.h>
9 #include <invert_quda.h>
10 #include <util_quda.h>
11 #include <sys/time.h>
12 #include <iostream>
13 
14 namespace quda {
15 
16  MultiSrcCG::MultiSrcCG(DiracMatrix &mat, DiracMatrix &matSloppy, SolverParam &param, TimeProfile &profile) :
17  MultiSrcSolver(param, profile), mat(mat), matSloppy(matSloppy)
18  {
19 
20  }
21 
22  MultiSrcCG::~MultiSrcCG() {
23 
24  }
25 
26  void MultiSrcCG::operator()(std::vector<ColorSpinorField*> out, std::vector<ColorSpinorField*> in)
27  {
28 
29 
30 
31 #if 0
32 
33  if (Location(x, b) != QUDA_CUDA_FIELD_LOCATION)
34  errorQuda("Not supported");
35 
36  profile.TPSTART(QUDA_PROFILE_INIT);
37 
38  // Check to see that we're not trying to invert on a zero-field source
39  const double b2 = blas::norm2(b);
40  if(b2 == 0){
41  profile.TPSTOP(QUDA_PROFILE_INIT);
42  printfQuda("Warning: inverting on zero-field source\n");
43  x=b;
44  param.true_res = 0.0;
45  param.true_res_hq = 0.0;
46  return;
47  }
48 
49  cudaColorSpinorField r(b);
50 
51  ColorSpinorParam csParam(x);
53  cudaColorSpinorField y(b, csParam);
54 
55  mat(r, x, y);
56  double r2 = blas::xmyNorm(b, r);
57 
58  csParam.setPrecision(param.precision_sloppy);
59  cudaColorSpinorField Ap(x, csParam);
60  cudaColorSpinorField tmp(x, csParam);
61 
62  // tmp2 only needed for multi-gpu Wilson-like kernels
63  cudaColorSpinorField *tmp2_p = !mat.isStaggered() ?
64  new cudaColorSpinorField(x, csParam) : &tmp;
65  cudaColorSpinorField &tmp2 = *tmp2_p;
66 
67  cudaColorSpinorField *r_sloppy;
68  if (param.precision_sloppy == x.Precision()) {
69  r_sloppy = &r;
70  } else {
72  r_sloppy = new cudaColorSpinorField(r, csParam);
73  }
74 
75  cudaColorSpinorField *x_sloppy;
76  if (param.precision_sloppy == x.Precision() ||
77  !param.use_sloppy_partial_accumulator) {
78  x_sloppy = &static_cast<cudaColorSpinorField&>(x);
79  } else {
81  x_sloppy = new cudaColorSpinorField(x, csParam);
82  }
83 
84  // additional high-precision temporary if Wilson and mixed-precision
85  csParam.setPrecision(param.precision);
86  cudaColorSpinorField *tmp3_p =
87  (param.precision != param.precision_sloppy && !mat.isStaggered()) ?
88  new cudaColorSpinorField(x, csParam) : &tmp;
89  cudaColorSpinorField &tmp3 = *tmp3_p;
90 
91  ColorSpinorField &xSloppy = *x_sloppy;
92  ColorSpinorField &rSloppy = *r_sloppy;
93 
94  cudaColorSpinorField p(rSloppy);
95 
96  if(&x != &xSloppy){
97  blas::copy(y,x);
98  blas::zero(xSloppy);
99  }else{
100  blas::zero(y);
101  }
102 
103  const bool use_heavy_quark_res =
104  (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) ? true : false;
105  bool heavy_quark_restart = false;
106 
107  profile.TPSTOP(QUDA_PROFILE_INIT);
108  profile.TPSTART(QUDA_PROFILE_PREAMBLE);
109 
110  double r2_old;
111 
112  double stop = stopping(param.tol, b2, param.residual_type); // stopping condition of solver
113 
114  double heavy_quark_res = 0.0; // heavy quark residual
115  double heavy_quark_res_old = 0.0; // heavy quark residual
116 
117  if (use_heavy_quark_res) {
118  heavy_quark_res = sqrt(blas::HeavyQuarkResidualNorm(x, r).z);
119  heavy_quark_res_old = heavy_quark_res; // heavy quark residual
120  }
121  const int heavy_quark_check = param.heavy_quark_check; // how often to check the heavy quark residual
122 
123  double alpha=0.0, beta=0.0;
124  double pAp;
125  int rUpdate = 0;
126 
127  double rNorm = sqrt(r2);
128  double r0Norm = rNorm;
129  double maxrx = rNorm;
130  double maxrr = rNorm;
131  double delta = param.delta;
132 
133  // this parameter determines how many consective reliable update
134  // reisudal increases we tolerate before terminating the solver,
135  // i.e., how long do we want to keep trying to converge
136  const int maxResIncrease = (use_heavy_quark_res ? 0 : param.max_res_increase); // check if we reached the limit of our tolerance
137  const int maxResIncreaseTotal = param.max_res_increase_total;
138  // 0 means we have no tolerance
139  // maybe we should expose this as a parameter
140  const int hqmaxresIncrease = maxResIncrease + 1;
141 
142  int resIncrease = 0;
143  int resIncreaseTotal = 0;
144  int hqresIncrease = 0;
145 
146  // set this to true if maxResIncrease has been exceeded but when we use heavy quark residual we still want to continue the CG
147  // only used if we use the heavy_quark_res
148  bool L2breakdown =false;
149 
150  profile.TPSTOP(QUDA_PROFILE_PREAMBLE);
151  profile.TPSTART(QUDA_PROFILE_COMPUTE);
152  blas::flops = 0;
153 
154  int k=0;
155 
156  PrintStats("CG", k, r2, b2, heavy_quark_res);
157 
158  int steps_since_reliable = 1;
159  bool converged = convergence(r2, heavy_quark_res, stop, param.tol_hq);
160 
161  while ( !converged && k < param.maxiter) {
162  matSloppy(Ap, p, tmp, tmp2); // tmp as tmp
163 
164  double sigma;
165 
166  bool breakdown = false;
167  if (param.pipeline) {
168  double3 triplet = blas::tripleCGReduction(rSloppy, Ap, p);
169  r2 = triplet.x; double Ap2 = triplet.y; pAp = triplet.z;
170  r2_old = r2;
171 
172  alpha = r2 / pAp;
173  sigma = alpha*(alpha * Ap2 - pAp);
174  if (sigma < 0.0 || steps_since_reliable==0) { // sigma condition has broken down
175  r2 = blas::axpyNorm(-alpha, Ap, rSloppy);
176  sigma = r2;
177  breakdown = true;
178  }
179 
180  r2 = sigma;
181  } else {
182  r2_old = r2;
183  pAp = blas::reDotProduct(p, Ap);
184  alpha = r2 / pAp;
185 
186  // here we are deploying the alternative beta computation
187  Complex cg_norm = blas::axpyCGNorm(-alpha, Ap, rSloppy);
188  r2 = real(cg_norm); // (r_new, r_new)
189  sigma = imag(cg_norm) >= 0.0 ? imag(cg_norm) : r2; // use r2 if (r_k+1, r_k+1-r_k) breaks
190  }
191 
192  // reliable update conditions
193  rNorm = sqrt(r2);
194  if (rNorm > maxrx) maxrx = rNorm;
195  if (rNorm > maxrr) maxrr = rNorm;
196  int updateX = (rNorm < delta*r0Norm && r0Norm <= maxrx) ? 1 : 0;
197  int updateR = ((rNorm < delta*maxrr && r0Norm <= maxrr) || updateX) ? 1 : 0;
198 
199  // force a reliable update if we are within target tolerance (only if doing reliable updates)
200  if ( convergence(r2, heavy_quark_res, stop, param.tol_hq) && param.delta >= param.tol) updateX = 1;
201 
202  // For heavy-quark inversion force a reliable update if we continue after
203  if (use_heavy_quark_res and L2breakdown and convergenceHQ(r2, heavy_quark_res, stop, param.tol_hq) and param.delta >= param.tol) {
204  updateX = 1;
205  }
206 
207  if ( !(updateR || updateX)) {
208  //beta = r2 / r2_old;
209  beta = sigma / r2_old; // use the alternative beta computation
210 
211  if (param.pipeline && !breakdown) blas::tripleCGUpdate(alpha, beta, Ap, rSloppy, xSloppy, p);
212  else blas::axpyZpbx(alpha, p, xSloppy, rSloppy, beta);
213 
214 
215  if (use_heavy_quark_res && k%heavy_quark_check==0) {
216  if (&x != &xSloppy) {
217  blas::copy(tmp,y);
218  heavy_quark_res = sqrt(blas::xpyHeavyQuarkResidualNorm(xSloppy, tmp, rSloppy).z);
219  } else {
220  blas::copy(r, rSloppy);
221  heavy_quark_res = sqrt(blas::xpyHeavyQuarkResidualNorm(x, y, r).z);
222  }
223  }
224 
225  steps_since_reliable++;
226  } else {
227  blas::axpy(alpha, p, xSloppy);
228  blas::copy(x, xSloppy); // nop when these pointers alias
229 
230  blas::xpy(x, y); // swap these around?
231  mat(r, y, x, tmp3); // here we can use x as tmp
232  r2 = blas::xmyNorm(b, r);
233 
234  blas::copy(rSloppy, r); //nop when these pointers alias
235  blas::zero(xSloppy);
236 
237  // calculate new reliable HQ resididual
238  if (use_heavy_quark_res) heavy_quark_res = sqrt(blas::HeavyQuarkResidualNorm(y, r).z);
239 
240  // break-out check if we have reached the limit of the precision
241  if (sqrt(r2) > r0Norm && updateX) { // reuse r0Norm for this
242  resIncrease++;
243  resIncreaseTotal++;
244  warningQuda("CG: new reliable residual norm %e is greater than previous reliable residual norm %e (total #inc %i)",
245  sqrt(r2), r0Norm, resIncreaseTotal);
246  if ( resIncrease > maxResIncrease or resIncreaseTotal > maxResIncreaseTotal) {
247  if (use_heavy_quark_res) {
248  L2breakdown = true;
249  } else {
250  warningQuda("CG: solver exiting due to too many true residual norm increases");
251  break;
252  }
253  }
254  } else {
255  resIncrease = 0;
256  }
257  // if L2 broke down already we turn off reliable updates and restart the CG
258  if (use_heavy_quark_res and L2breakdown) {
259  delta = 0;
260  warningQuda("CG: Restarting without reliable updates for heavy-quark residual");
261  heavy_quark_restart = true;
262  if (heavy_quark_res > heavy_quark_res_old) {
263  hqresIncrease++;
264  warningQuda("CG: new reliable HQ residual norm %e is greater than previous reliable residual norm %e", heavy_quark_res, heavy_quark_res_old);
265  // break out if we do not improve here anymore
266  if (hqresIncrease > hqmaxresIncrease) {
267  warningQuda("CG: solver exiting due to too many heavy quark residual norm increases");
268  break;
269  }
270  }
271  }
272 
273  rNorm = sqrt(r2);
274  maxrr = rNorm;
275  maxrx = rNorm;
276  r0Norm = rNorm;
277  rUpdate++;
278 
279  if (use_heavy_quark_res and heavy_quark_restart) {
280  // perform a restart
281  blas::copy(p, rSloppy);
282  heavy_quark_restart = false;
283  } else {
284  // explicitly restore the orthogonality of the gradient vector
285  double rp = blas::reDotProduct(rSloppy, p) / (r2);
286  blas::axpy(-rp, rSloppy, p);
287 
288  beta = r2 / r2_old;
289  blas::xpay(rSloppy, beta, p);
290  }
291 
292 
293  steps_since_reliable = 0;
294  heavy_quark_res_old = heavy_quark_res;
295  }
296 
297  breakdown = false;
298  k++;
299 
300  PrintStats("CG", k, r2, b2, heavy_quark_res);
301  // check convergence, if convergence is satisfied we only need to check that we had a reliable update for the heavy quarks recently
302  converged = convergence(r2, heavy_quark_res, stop, param.tol_hq);
303 
304  // check for recent enough reliable updates of the HQ residual if we use it
305  if (use_heavy_quark_res) {
306  // L2 is concverged or precision maxed out for L2
307  bool L2done = L2breakdown or convergenceL2(r2, heavy_quark_res, stop, param.tol_hq);
308  // HQ is converged and if we do reliable update the HQ residual has been calculated using a reliable update
309  bool HQdone = (steps_since_reliable == 0 and param.delta > 0) and convergenceHQ(r2, heavy_quark_res, stop, param.tol_hq);
310  converged = L2done and HQdone;
311  }
312 
313  }
314 
315  blas::copy(x, xSloppy);
316  blas::xpy(y, x);
317 
318  profile.TPSTOP(QUDA_PROFILE_COMPUTE);
319  profile.TPSTART(QUDA_PROFILE_EPILOGUE);
320 
321  param.secs = profile.Last(QUDA_PROFILE_COMPUTE);
322  double gflops = (blas::flops + mat.flops() + matSloppy.flops())*1e-9;
323  reduceDouble(gflops);
324  param.gflops = gflops;
325  param.iter += k;
326 
327  if (k==param.maxiter)
328  warningQuda("Exceeded maximum iterations %d", param.maxiter);
329 
330  if (getVerbosity() >= QUDA_VERBOSE)
331  printfQuda("CG: Reliable updates = %d\n", rUpdate);
332 
333  // compute the true residuals
334  mat(r, x, y, tmp3);
335  param.true_res = sqrt(blas::xmyNorm(b, r) / b2);
336  param.true_res_hq = sqrt(blas::HeavyQuarkResidualNorm(x,r).z);
337 
338  PrintSummary("CG", k, r2, b2, stop, inv.tol_hq);
339 
340  // reset the flops counters
341  blas::flops = 0;
342  mat.flops();
343  matSloppy.flops();
344 
345  profile.TPSTOP(QUDA_PROFILE_EPILOGUE);
346  profile.TPSTART(QUDA_PROFILE_FREE);
347 
348  if (&tmp3 != &tmp) delete tmp3_p;
349  if (&tmp2 != &tmp) delete tmp2_p;
350 
351  if (rSloppy.Precision() != r.Precision()) delete r_sloppy;
352  if (xSloppy.Precision() != x.Precision()) delete x_sloppy;
353 
354  profile.TPSTOP(QUDA_PROFILE_FREE);
355 #endif
356  return;
357  }
358 
359 } // namespace quda
void reduceDouble(double &)
void mat(void *out, void **link, void *in, int dagger_bit, int mu, QudaPrecision sPrecision, QudaPrecision gPrecision)
cudaColorSpinorField * tmp
Definition: covdev_test.cpp:34
@ QUDA_CUDA_FIELD_LOCATION
Definition: enum_quda.h:326
@ QUDA_VERBOSE
Definition: enum_quda.h:267
@ QUDA_HEAVY_QUARK_RESIDUAL
Definition: enum_quda.h:195
@ QUDA_ZERO_FIELD_CREATE
Definition: enum_quda.h:361
@ QUDA_COPY_FIELD_CREATE
Definition: enum_quda.h:362
double xmyNorm(ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.h:79
unsigned long long flops
void xpay(ColorSpinorField &x, double a, ColorSpinorField &y)
Definition: blas_quda.h:45
void zero(ColorSpinorField &a)
double norm2(const ColorSpinorField &a)
double axpyNorm(double a, ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.h:78
void axpy(double a, ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.h:43
void xpy(ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.h:41
void copy(ColorSpinorField &dst, const ColorSpinorField &src)
Definition: blas_quda.h:24
void stop()
Stop profiling.
Definition: device.cpp:228
std::complex< double > Complex
Definition: quda_internal.h:86
__host__ __device__ ValueType sqrt(ValueType x)
Definition: complex_quda.h:120
@ QUDA_PROFILE_INIT
Definition: timer.h:106
@ QUDA_PROFILE_EPILOGUE
Definition: timer.h:110
@ QUDA_PROFILE_COMPUTE
Definition: timer.h:108
@ QUDA_PROFILE_FREE
Definition: timer.h:111
@ QUDA_PROFILE_PREAMBLE
Definition: timer.h:107
ColorSpinorParam csParam
Definition: pack_test.cpp:25
QudaGaugeParam param
Definition: pack_test.cpp:18
void updateR()
update the radius for halos.
#define printfQuda(...)
Definition: util_quda.h:114
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:21
#define warningQuda(...)
Definition: util_quda.h:132
#define errorQuda(...)
Definition: util_quda.h:120