QUDA  v0.7.0
A library for QCD on GPUs
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
inv_cg_quda.cpp
Go to the documentation of this file.
1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <math.h>
4 
5 #include <quda_internal.h>
6 #include <color_spinor_field.h>
7 #include <blas_quda.h>
8 #include <dslash_quda.h>
9 #include <invert_quda.h>
10 #include <util_quda.h>
11 #include <sys/time.h>
12 
13 #include <face_quda.h>
14 
15 #include <iostream>
16 
17 namespace quda {
18 
20  Solver(param, profile), mat(mat), matSloppy(matSloppy)
21  {
22 
23  }
24 
25  CG::~CG() {
26  // Empty body: no explicit cleanup is performed in this destructor.
27  }
28 
30  {
32 
33  // Check to see that we're not trying to invert on a zero-field source
34  const double b2 = norm2(b);
35  if(b2 == 0){
37  printfQuda("Warning: inverting on zero-field source\n");
38  x=b;
39  param.true_res = 0.0;
40  param.true_res_hq = 0.0;
41  return;
42  }
43 
44 
46 
49  cudaColorSpinorField y(b, csParam);
50 
51  mat(r, x, y);
52 
53  double r2 = xmyNormCuda(b, r);
54 
56  cudaColorSpinorField Ap(x, csParam);
57  cudaColorSpinorField tmp(x, csParam);
58 
59  // tmp2 only needed for multi-gpu Wilson-like kernels
60  cudaColorSpinorField *tmp2_p = !mat.isStaggered() ?
61  new cudaColorSpinorField(x, csParam) : &tmp;
62  cudaColorSpinorField &tmp2 = *tmp2_p;
63 
64  cudaColorSpinorField *r_sloppy;
65  if (param.precision_sloppy == x.Precision()) {
66  r_sloppy = &r;
67  } else {
69  r_sloppy = new cudaColorSpinorField(r, csParam);
70  }
71 
72  cudaColorSpinorField *x_sloppy;
73  if (param.precision_sloppy == x.Precision() ||
75  x_sloppy = &x;
76  } else {
78  x_sloppy = new cudaColorSpinorField(x, csParam);
79  }
80 
81  // additional high-precision temporary if Wilson and mixed-precision
82  csParam.setPrecision(param.precision);
83  cudaColorSpinorField *tmp3_p =
85  new cudaColorSpinorField(x, csParam) : &tmp;
86  cudaColorSpinorField &tmp3 = *tmp3_p;
87 
88  cudaColorSpinorField &xSloppy = *x_sloppy;
89  cudaColorSpinorField &rSloppy = *r_sloppy;
90  cudaColorSpinorField p(rSloppy);
91 
92  if(&x != &xSloppy){
93  copyCuda(y,x);
94  zeroCuda(xSloppy);
95  } else {
96  zeroCuda(y);
97  }
98 
99  const bool use_heavy_quark_res =
100  (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) ? true : false;
101  bool heavy_quark_restart = false;
102 
105 
106  double r2_old;
107 
108  double stop = stopping(param.tol, b2, param.residual_type); // stopping condition of solver
109 
110  double heavy_quark_res = 0.0; // heavy quark residual
111  double heavy_quark_res_old = 0.0; // heavy quark residual
112 
113  if (use_heavy_quark_res) {
114  heavy_quark_res = sqrt(HeavyQuarkResidualNormCuda(x, r).z);
115  heavy_quark_res_old = heavy_quark_res; // heavy quark residual
116  }
117  const int heavy_quark_check = 1; // how often to check the heavy quark residual
118 
119  double alpha=0.0, beta=0.0;
120  double pAp;
121  int rUpdate = 0;
122 
123  double rNorm = sqrt(r2);
124  double r0Norm = rNorm;
125  double maxrx = rNorm;
126  double maxrr = rNorm;
127  double delta = param.delta;
128 
129  // this parameter determines how many consecutive reliable update
130  // residual increases we tolerate before terminating the solver,
131  // i.e., how long do we want to keep trying to converge
132  const int maxResIncrease = (use_heavy_quark_res ? 0 : param.max_res_increase); // check if we reached the limit of our tolerance
133  const int maxResIncreaseTotal = param.max_res_increase_total;
134  // 0 means we have no tolerance
135  // maybe we should expose this as a parameter
136  const int hqmaxresIncrease = maxResIncrease + 1;
137 
138  int resIncrease = 0;
139  int resIncreaseTotal = 0;
140  int hqresIncrease = 0;
141 
142  // set this to true if maxResIncrease has been exceeded but when we use heavy quark residual we still want to continue the CG
143  // only used if we use the heavy_quark_res
144  bool L2breakdown =false;
145 
148  blas_flops = 0;
149 
150  int k=0;
151 
152  PrintStats("CG", k, r2, b2, heavy_quark_res);
153 
154  int steps_since_reliable = 1;
155  bool converged = convergence(r2, heavy_quark_res, stop, param.tol_hq);
156 
157  while ( !converged && k < param.maxiter) {
158  matSloppy(Ap, p, tmp, tmp2); // tmp as tmp
159 
160  double sigma;
161 
162  bool breakdown = false;
163 
164  if (param.pipeline) {
165  double3 triplet = tripleCGReductionCuda(rSloppy, Ap, p);
166  r2 = triplet.x; double Ap2 = triplet.y; pAp = triplet.z;
167  r2_old = r2;
168 
169  alpha = r2 / pAp;
170  sigma = alpha*(alpha * Ap2 - pAp);
171  if (sigma < 0.0 || steps_since_reliable==0) { // sigma condition has broken down
172  r2 = axpyNormCuda(-alpha, Ap, rSloppy);
173  sigma = r2;
174  breakdown = true;
175  }
176 
177  r2 = sigma;
178  } else {
179  r2_old = r2;
180  pAp = reDotProductCuda(p, Ap);
181  alpha = r2 / pAp;
182 
183  // here we are deploying the alternative beta computation
184  Complex cg_norm = axpyCGNormCuda(-alpha, Ap, rSloppy);
185  r2 = real(cg_norm); // (r_new, r_new)
186  sigma = imag(cg_norm) >= 0.0 ? imag(cg_norm) : r2; // use r2 if (r_k+1, r_k+1-r_k) breaks
187  }
188 
189  // reliable update conditions
190  rNorm = sqrt(r2);
191  if (rNorm > maxrx) maxrx = rNorm;
192  if (rNorm > maxrr) maxrr = rNorm;
193  int updateX = (rNorm < delta*r0Norm && r0Norm <= maxrx) ? 1 : 0;
194  int updateR = ((rNorm < delta*maxrr && r0Norm <= maxrr) || updateX) ? 1 : 0;
195 
196  // force a reliable update if we are within target tolerance (only if doing reliable updates)
197  if ( convergence(r2, heavy_quark_res, stop, param.tol_hq) && param.delta >= param.tol) updateX = 1;
198 
199  // For heavy-quark inversion, force a reliable update if we continue after an L2 breakdown and the heavy-quark convergence check passes
200  if (use_heavy_quark_res and L2breakdown and convergenceHQ(r2, heavy_quark_res, stop, param.tol_hq) and param.delta >= param.tol) {
201  updateX = 1;
202  }
203 
204  if ( !(updateR || updateX)) {
205  //beta = r2 / r2_old;
206  beta = sigma / r2_old; // use the alternative beta computation
207 
208  if (param.pipeline && !breakdown) tripleCGUpdateCuda(alpha, beta, Ap, xSloppy, rSloppy, p);
209  else axpyZpbxCuda(alpha, p, xSloppy, rSloppy, beta);
210 
211 
212  if (use_heavy_quark_res && k%heavy_quark_check==0) {
213  if (&x != &xSloppy) {
214  copyCuda(tmp,y);
215  heavy_quark_res = sqrt(xpyHeavyQuarkResidualNormCuda(xSloppy, tmp, rSloppy).z);
216  } else {
217  copyCuda(r, rSloppy);
218  heavy_quark_res = sqrt(xpyHeavyQuarkResidualNormCuda(x, y, r).z);
219  }
220  }
221 
222  steps_since_reliable++;
223  } else {
224 
225  axpyCuda(alpha, p, xSloppy);
226  copyCuda(x, xSloppy); // nop when these pointers alias
227 
228  xpyCuda(x, y); // swap these around?
229  mat(r, y, x, tmp3); // here we can use x as tmp
230  r2 = xmyNormCuda(b, r);
231 
232  copyCuda(rSloppy, r); //nop when these pointers alias
233  zeroCuda(xSloppy);
234 
235  // calculate new reliable HQ residual
236  if (use_heavy_quark_res) heavy_quark_res = sqrt(HeavyQuarkResidualNormCuda(y, r).z);
237 
238  // break-out check if we have reached the limit of the precision
239  if (sqrt(r2) > r0Norm && updateX) { // reuse r0Norm for this
240  resIncrease++;
241  resIncreaseTotal++;
242  warningQuda("CG: new reliable residual norm %e is greater than previous reliable residual norm %e (total #inc %i)",
243  sqrt(r2), r0Norm, resIncreaseTotal);
244  if ( resIncrease > maxResIncrease or resIncreaseTotal > maxResIncreaseTotal) {
245  if (use_heavy_quark_res) L2breakdown = true;
246  else break;
247  }
248  } else {
249  resIncrease = 0;
250  }
251  // if L2 broke down already we turn off reliable updates and restart the CG
252  if (use_heavy_quark_res and L2breakdown) {
253  delta = 0;
254  warningQuda("CG: Restarting without reliable updates for heavy-quark residual");
255  heavy_quark_restart = true;
256  if (heavy_quark_res > heavy_quark_res_old) {
257  hqresIncrease++;
258  warningQuda("CG: new reliable HQ residual norm %e is greater than previous reliable residual norm %e", heavy_quark_res, heavy_quark_res_old);
259  // break out if we do not improve here anymore
260  if (hqresIncrease > hqmaxresIncrease) break;
261  }
262  }
263 
264  rNorm = sqrt(r2);
265  maxrr = rNorm;
266  maxrx = rNorm;
267  r0Norm = rNorm;
268  rUpdate++;
269 
270  if (use_heavy_quark_res and heavy_quark_restart) {
271  // perform a restart
272  copyCuda(p, rSloppy);
273  heavy_quark_restart = false;
274  }
275  else {
276  // explicitly restore the orthogonality of the gradient vector
277  double rp = reDotProductCuda(rSloppy, p) / (r2);
278  axpyCuda(-rp, rSloppy, p);
279 
280  beta = r2 / r2_old;
281  xpayCuda(rSloppy, beta, p);
282  }
283 
284 
285  steps_since_reliable = 0;
286  heavy_quark_res_old = heavy_quark_res;
287  }
288 
289  breakdown = false;
290  k++;
291 
292  PrintStats("CG", k, r2, b2, heavy_quark_res);
293  // check convergence, if convergence is satisfied we only need to check that we had a reliable update for the heavy quarks recently
294  converged = convergence(r2, heavy_quark_res, stop, param.tol_hq);
295 
296  // check for recent enough reliable updates of the HQ residual if we use it
297  if (use_heavy_quark_res) {
298  // L2 is converged or precision maxed out for L2
299  bool L2done = L2breakdown or convergenceL2(r2, heavy_quark_res, stop, param.tol_hq);
300  // HQ is converged and, if we do reliable updates, the HQ residual has been calculated using a reliable update
301  bool HQdone = (steps_since_reliable == 0 and param.delta > 0) and convergenceHQ(r2, heavy_quark_res, stop, param.tol_hq);
302  converged = L2done and HQdone;
303  }
304 
305  }
306 
307  copyCuda(x, xSloppy); // nop when these pointers alias
308  xpyCuda(y, x);
309 
312 
314  double gflops = (quda::blas_flops + mat.flops() + matSloppy.flops())*1e-9;
315  reduceDouble(gflops);
316  param.gflops = gflops;
317  param.iter += k;
318 
319  if (k==param.maxiter)
320  warningQuda("Exceeded maximum iterations %d", param.maxiter);
321 
322  if (getVerbosity() >= QUDA_VERBOSE)
323  printfQuda("CG: Reliable updates = %d\n", rUpdate);
324 
325  // compute the true residuals
326  mat(r, x, y);
327  param.true_res = sqrt(xmyNormCuda(b, r) / b2);
328 #if (__COMPUTE_CAPABILITY__ >= 200)
330 #else
331  param.true_res_hq = 0.0;
332 #endif
333 
334  PrintSummary("CG", k, r2, b2);
335 
336  // reset the flops counters
337  quda::blas_flops = 0;
338  mat.flops();
339  matSloppy.flops();
340 
343 
344  if (&tmp3 != &tmp) delete tmp3_p;
345  if (&tmp2 != &tmp) delete tmp2_p;
346 
347  if (rSloppy.Precision() != r.Precision()) delete r_sloppy;
348  if (xSloppy.Precision() != x.Precision()) delete x_sloppy;
349 
351 
352  return;
353  }
354 
355 } // namespace quda
bool convergence(const double &r2, const double &hq2, const double &r2_tol, const double &hq_tol)
Definition: solver.cpp:82
void setPrecision(QudaPrecision precision)
double3 tripleCGReductionCuda(cudaColorSpinorField &x, cudaColorSpinorField &y, cudaColorSpinorField &z)
Definition: reduce_quda.cu:811
static double stopping(const double &tol, const double &b2, QudaResidualType residual_type)
Definition: solver.cpp:65
int y[4]
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:20
__host__ __device__ ValueType sqrt(ValueType x)
Definition: complex_quda.h:105
double axpyNormCuda(const double &a, cudaColorSpinorField &x, cudaColorSpinorField &y)
Definition: reduce_quda.cu:321
std::complex< double > Complex
Definition: eig_variables.h:13
bool convergenceL2(const double &r2, const double &hq2, const double &r2_tol, const double &hq_tol)
Definition: solver.cpp:110
void axpyZpbxCuda(const double &a, cudaColorSpinorField &x, cudaColorSpinorField &y, cudaColorSpinorField &z, const double &b)
Definition: blas_quda.cu:338
void mat(void *out, void **fatlink, void **longlink, void *in, double kappa, int dagger_bit, QudaPrecision sPrecision, QudaPrecision gPrecision)
TimeProfile & profile
Definition: invert_quda.h:224
int max_res_increase_total
Definition: invert_quda.h:54
Complex axpyCGNormCuda(const double &a, cudaColorSpinorField &x, cudaColorSpinorField &y)
Definition: reduce_quda.cu:682
unsigned long long flops() const
Definition: dirac_quda.h:587
QudaGaugeParam param
Definition: pack_test.cpp:17
cudaColorSpinorField * tmp2
Definition: dslash_test.cpp:41
cudaColorSpinorField * tmp
void PrintSummary(const char *name, int k, const double &r2, const double &b2)
Definition: solver.cpp:137
QudaResidualType residual_type
Definition: invert_quda.h:35
CG(DiracMatrix &mat, DiracMatrix &matSloppy, SolverParam &param, TimeProfile &profile)
Definition: inv_cg_quda.cpp:19
ColorSpinorParam csParam
Definition: pack_test.cpp:24
#define warningQuda(...)
Definition: util_quda.h:84
void copyCuda(cudaColorSpinorField &dst, const cudaColorSpinorField &src)
Definition: copy_quda.cu:235
void operator()(cudaColorSpinorField &out, cudaColorSpinorField &in)
Definition: inv_cg_quda.cpp:29
void axpyCuda(const double &a, cudaColorSpinorField &x, cudaColorSpinorField &y)
Definition: blas_quda.cu:115
int x[4]
unsigned long long blas_flops
Definition: blas_quda.cu:37
double3 xpyHeavyQuarkResidualNormCuda(cudaColorSpinorField &x, cudaColorSpinorField &y, cudaColorSpinorField &r)
Definition: reduce_quda.cu:782
QudaPrecision precision
Definition: invert_quda.h:81
SolverParam & param
Definition: invert_quda.h:223
void xpyCuda(cudaColorSpinorField &x, cudaColorSpinorField &y)
Definition: blas_quda.cu:98
double reDotProductCuda(cudaColorSpinorField &a, cudaColorSpinorField &b)
Definition: reduce_quda.cu:170
void Stop(QudaProfileType idx)
QudaPrecision Precision() const
void PrintStats(const char *, int k, const double &r2, const double &b2, const double &hq2)
Definition: solver.cpp:122
double Last(QudaProfileType idx)
void reduceDouble(double &)
#define printfQuda(...)
Definition: util_quda.h:67
void zeroCuda(cudaColorSpinorField &a)
Definition: blas_quda.cu:40
void Start(QudaProfileType idx)
bool isStaggered() const
Definition: dirac_quda.h:594
void tripleCGUpdateCuda(const double &alpha, const double &beta, cudaColorSpinorField &q, cudaColorSpinorField &r, cudaColorSpinorField &x, cudaColorSpinorField &p)
Definition: blas_quda.cu:480
QudaPrecision precision_sloppy
Definition: invert_quda.h:84
bool use_sloppy_partial_accumulator
Definition: invert_quda.h:44
bool convergenceHQ(const double &r2, const double &hq2, const double &r2_tol, const double &hq_tol)
Definition: solver.cpp:99
virtual ~CG()
Definition: inv_cg_quda.cpp:25
void xpayCuda(cudaColorSpinorField &x, const double &a, cudaColorSpinorField &y)
Definition: blas_quda.cu:138
double3 HeavyQuarkResidualNormCuda(cudaColorSpinorField &x, cudaColorSpinorField &r)
Definition: reduce_quda.cu:777
double norm2(const ColorSpinorField &)
double xmyNormCuda(cudaColorSpinorField &a, cudaColorSpinorField &b)
Definition: reduce_quda.cu:343