QUDA  v0.5.0
A library for QCD on GPUs
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
inv_cg_quda.cpp
Go to the documentation of this file.
1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <math.h>
4 
5 #include <quda_internal.h>
6 #include <color_spinor_field.h>
7 #include <blas_quda.h>
8 #include <dslash_quda.h>
9 #include <invert_quda.h>
10 #include <util_quda.h>
11 #include <sys/time.h>
12 
13 #include <face_quda.h>
14 
15 #include <iostream>
16 
17 namespace quda {
18 
// Constructor: passes invParam and profile on to the Solver base class and
// initialises the mat and matSloppy members from the corresponding arguments.
// The body itself is empty.
19  CG::CG(DiracMatrix &mat, DiracMatrix &matSloppy, QudaInvertParam &invParam, TimeProfile &profile) :
20  Solver(invParam, profile), mat(mat), matSloppy(matSloppy)
21  {
22 
23  }
24 
// Destructor: empty body, nothing is released explicitly here.
25  CG::~CG() {
26 
27  }
28 
// Body of the CG solve routine (it labels its own log output "CG").
// NOTE(review): listing line 29 is absent from this extract -- presumably the
// function signature that introduces x and b; confirm against the original file.
//
// As far as the visible lines show, the routine:
//   1. returns early when norm2(b) is zero,
//   2. forms an initial residual with mat and iterates with matSloppy until
//      convergence() succeeds or invParam.maxiter iterations have run,
//   3. periodically recomputes the residual with mat (a "reliable update"),
//   4. writes gflops, iter, true_res (and true_res_hq) into invParam.
30  {
31  profile[QUDA_PROFILE_INIT].Start();
32 
33  // Check to see that we're not trying to invert on a zero-field source
34  const double b2 = norm2(b);
35  if(b2 == 0){
// Zero source: stop the INIT timer, set x to b, report both true residuals
// as 0 and return without iterating.
36  profile[QUDA_PROFILE_INIT].Stop();
37  printfQuda("Warning: inverting on zero-field source\n");
38  x=b;
39  invParam.true_res = 0.0;
40  invParam.true_res_hq = 0.0;
41  return;
42  }
43 
44 
// NOTE(review): listing lines 45, 47 and 48 are absent from this extract; `r`
// and `param`, both used below, are presumably declared there -- confirm.
46 
49  cudaColorSpinorField y(b, param);
50 
// Apply mat with x as input and r as output (y is passed as the third
// argument), then zero y; y is later updated by xpyCuda(x, y) at every
// reliable update.
51  mat(r, x, y);
52  zeroCuda(y);
53 
// r2 is whatever xmyNormCuda(b, r) returns -- presumably the squared norm of
// (b - r) with r overwritten by that difference; confirm against blas_quda.h.
54  double r2 = xmyNormCuda(b, r);
55 
// NOTE(review): listing line 56 is absent from this extract.
57  cudaColorSpinorField Ap(x, param);
58  cudaColorSpinorField tmp(x, param);
59 
// tmp2 starts out aliasing tmp; for any operator type other than the two
// staggered ones a separate field is heap-allocated (deleted at the end of
// this function).
60  cudaColorSpinorField *tmp2_p = &tmp;
61  // tmp only needed for multi-gpu Wilson-like kernels
62  if (mat.Type() != typeid(DiracStaggeredPC).name() &&
63  mat.Type() != typeid(DiracStaggered).name()) {
64  tmp2_p = new cudaColorSpinorField(x, param);
65  }
66  cudaColorSpinorField &tmp2 = *tmp2_p;
67 
// If the sloppy precision equals x's precision the sloppy fields simply alias
// x and r; otherwise separate fields are allocated (deleted at the end of
// this function).
// NOTE(review): listing lines 70 and 74 are absent from this extract; if they
// modify `param`, that affects the two allocations below -- confirm.
68  cudaColorSpinorField *x_sloppy, *r_sloppy;
69  if (invParam.cuda_prec_sloppy == x.Precision()) {
71  x_sloppy = &x;
72  r_sloppy = &r;
73  } else {
75  x_sloppy = new cudaColorSpinorField(x, param);
76  r_sloppy = new cudaColorSpinorField(r, param);
77  }
78 
79  cudaColorSpinorField &xSloppy = *x_sloppy;
80  cudaColorSpinorField &rSloppy = *r_sloppy;
// p is constructed from rSloppy (the initial search direction is taken from
// the initial residual).
81  cudaColorSpinorField p(rSloppy);
82 
// NOTE(review): the initializer of use_heavy_quark_res (listing line 84) is
// absent from this extract, as is listing line 87.
83  const bool use_heavy_quark_res =
85 
86  profile[QUDA_PROFILE_INIT].Stop();
88 
89  double r2_old;
90  double stop = b2*invParam.tol*invParam.tol; // stopping condition of solver
91 
92  double heavy_quark_res = 0.0; // heavy quark residual
93  if(use_heavy_quark_res) heavy_quark_res = sqrt(HeavyQuarkResidualNormCuda(x,r).z);
94  int heavy_quark_check = 10; // how often to check the heavy quark residual
95 
96  double alpha=0.0, beta=0.0;
97  double pAp;
// rUpdate counts reliable updates; it is printed after the loop.
98  int rUpdate = 0;
99 
// Bookkeeping for the reliable-update test below: r0Norm holds the residual
// norm from the last reliable update, maxrx / maxrr track the largest
// residual norm seen since then, and delta is invParam.reliable_delta.
100  double rNorm = sqrt(r2);
101  double r0Norm = rNorm;
102  double maxrx = rNorm;
103  double maxrr = rNorm;
104  double delta = invParam.reliable_delta;
105 
// NOTE(review): listing line 106 is absent from this extract.
107  profile[QUDA_PROFILE_COMPUTE].Start();
108  blas_flops = 0;
109 
// k is the iteration counter for this call.
110  int k=0;
111 
112  PrintStats("CG", k, r2, b2, heavy_quark_res);
113 
// Set to 0 after a reliable update and incremented on every ordinary step.
114  int steps_since_reliable = 1;
115 
// Main iteration: runs until convergence() reports success or k reaches
// invParam.maxiter.
116  while ( !convergence(r2, heavy_quark_res, stop, invParam.tol_hq) &&
117  k < invParam.maxiter) {
118  matSloppy(Ap, p, tmp, tmp2); // tmp as tmp
119 
120  double sigma;
121 
122  bool breakdown = false;
// pipeline is hard-coded to 0, so the fused-reduction branch directly below
// is never executed as written; only the else branch runs.
123  int pipeline = 0;
124  if (pipeline) {
125  double3 triplet = tripleCGReductionCuda(rSloppy, Ap, p);
126  r2 = triplet.x; double Ap2 = triplet.y; pAp = triplet.z;
127  r2_old = r2;
128 
129  alpha = r2 / pAp;
130  sigma = alpha*(alpha * Ap2 - pAp);
131  if (sigma < 0.0 || steps_since_reliable==0) { // sigma condition has broken down
132  r2 = axpyNormCuda(-alpha, Ap, rSloppy);
133  sigma = r2;
134  breakdown = true;
135  }
136 
137  r2 = sigma;
138  } else {
// Standard path: keep the previous r2, compute the step length alpha from
// r2 and the p/Ap dot product, then update rSloppy and obtain the new norm.
139  r2_old = r2;
140  pAp = reDotProductCuda(p, Ap);
141  alpha = r2 / pAp;
142 
143  // here we are deploying the alternative beta computation
144  Complex cg_norm = axpyCGNormCuda(-alpha, Ap, rSloppy);
145  r2 = real(cg_norm); // (r_new, r_new)
146  sigma = imag(cg_norm) >= 0.0 ? imag(cg_norm) : r2; // use r2 if (r_k+1, r_k+1-r_k) breaks
147  }
148 
149  // reliable update conditions
150  rNorm = sqrt(r2);
151  if (rNorm > maxrx) maxrx = rNorm;
152  if (rNorm > maxrr) maxrr = rNorm;
153  int updateX = (rNorm < delta*r0Norm && r0Norm <= maxrx) ? 1 : 0;
154  int updateR = ((rNorm < delta*maxrr && r0Norm <= maxrr) || updateX) ? 1 : 0;
155 
156  // force a reliable update if we are within target tolerance (only if doing reliable updates)
157  if ( convergence(r2, heavy_quark_res, stop, invParam.tol_hq) && delta >= invParam.tol) updateX = 1;
158 
159  if ( !(updateR || updateX)) {
// Ordinary step (no reliable update): compute beta and update xSloppy and p
// with a single fused BLAS call.
160  //beta = r2 / r2_old;
161  beta = sigma / r2_old; // use the alternative beta computation
162 
163  if (pipeline && !breakdown) tripleCGUpdateCuda(alpha, beta, Ap, rSloppy, xSloppy, p);
164  else axpyZpbxCuda(alpha, p, xSloppy, rSloppy, beta);
165 
// The heavy-quark residual is only refreshed every heavy_quark_check
// iterations; tmp receives a copy of y for that evaluation.
166  if (use_heavy_quark_res && k%heavy_quark_check==0) {
167  copyCuda(tmp,y);
168  heavy_quark_res = sqrt(xpyHeavyQuarkResidualNormCuda(xSloppy, tmp, rSloppy).z);
169  }
170 
171  steps_since_reliable++;
172  } else {
// Reliable update: apply the pending alpha*p step to xSloppy, copy it to x
// when the precisions differ, combine x into y, recompute r by applying mat
// to y, and refresh r2 from b and r.
// (Exact operand roles of the BLAS helpers are presumed from their names and
// usage -- confirm against blas_quda.h.)
173  axpyCuda(alpha, p, xSloppy);
174  if (x.Precision() != xSloppy.Precision()) copyCuda(x, xSloppy);
175 
176  xpyCuda(x, y); // swap these around?
177  mat(r, y, x); // here we can use x as tmp
178  r2 = xmyNormCuda(b, r);
179 
// Push the recomputed residual back into the sloppy field (only needed when
// it is a separate field) and restart the sloppy solution accumulator.
180  if (x.Precision() != rSloppy.Precision()) copyCuda(rSloppy, r);
181  zeroCuda(xSloppy);
182 
183  // break-out check if we have reached the limit of the precision
// NOTE(review): resIncrease is a function-local static, so its value persists
// across loop iterations AND across separate calls of this routine; confirm
// that carrying the count over between solves is intended.
184  static int resIncrease = 0;
185  if (sqrt(r2) > r0Norm && updateX) { // reuse r0Norm for this
186  warningQuda("CG: new reliable residual norm %e is greater than previous reliable residual norm %e", sqrt(r2), r0Norm);
// k and rUpdate are incremented here because the break below skips the
// increments further down.
// NOTE(review): when the break is NOT taken, both counters are incremented
// again later in this iteration, so it is counted twice -- confirm intended.
187  k++;
188  rUpdate++;
189  if (++resIncrease > 2) break; // only allowed two consecutive residual increases
190  } else {
191  resIncrease = 0;
192  }
193 
// Reset the reliable-update bookkeeping to the freshly computed residual norm.
194  rNorm = sqrt(r2);
195  maxrr = rNorm;
196  maxrx = rNorm;
197  r0Norm = rNorm;
198  rUpdate++;
199 
200  // explicitly restore the orthogonality of the gradient vector
201  double rp = reDotProductCuda(rSloppy, p) / (r2);
202  axpyCuda(-rp, rSloppy, p);
203 
// After a reliable update beta uses the plain r2 / r2_old ratio rather than
// sigma.
204  beta = r2 / r2_old;
205  xpayCuda(rSloppy, beta, p);
206 
207  if(use_heavy_quark_res) heavy_quark_res = sqrt(HeavyQuarkResidualNormCuda(y,r).z);
208 
209  steps_since_reliable = 0;
210  }
211 
// breakdown is re-declared as false at the top of every iteration, so this
// assignment has no visible effect.
212  breakdown = false;
213  k++;
214 
215  PrintStats("CG", k, r2, b2, heavy_quark_res);
216  }
217 
// After the loop: bring x up to date from xSloppy when they are separate
// fields, then combine y with x.
218  if (x.Precision() != xSloppy.Precision()) copyCuda(x, xSloppy);
219  xpyCuda(y, x);
220 
// NOTE(review): listing lines 220-222 and 224 are absent from this extract.
223 
// Total flop count (BLAS plus both operators) scaled by 1e-9, reduced via
// reduceDouble(), and stored in invParam; k is added to invParam.iter.
225  double gflops = (quda::blas_flops + mat.flops() + matSloppy.flops())*1e-9;
226  reduceDouble(gflops);
227  invParam.gflops = gflops;
228  invParam.iter += k;
229 
230  if (k==invParam.maxiter)
231  warningQuda("Exceeded maximum iterations %d", invParam.maxiter);
232 
// NOTE(review): listing line 233 is absent from this extract.
234  printfQuda("CG: Reliable updates = %d\n", rUpdate);
235 
236  // compute the true residuals
237  mat(r, x, y);
238  invParam.true_res = sqrt(xmyNormCuda(b, r) / b2);
// NOTE(review): the statement inside the #if branch (listing line 240) is
// absent from this extract; only the #else branch, which sets true_res_hq to
// 0.0, is visible.
239 #if (__COMPUTE_CAPABILITY__ >= 200)
241 #else
242  invParam.true_res_hq = 0.0;
243 #endif
244 
245  PrintSummary("CG", k, r2, b2);
246 
247  // reset the flops counters
// The return values of the two flops() calls are discarded; presumably the
// calls themselves reset the operators' counters -- confirm.
248  quda::blas_flops = 0;
249  mat.flops();
250  matSloppy.flops();
251 
// NOTE(review): listing line 252 is absent from this extract.
253  profile[QUDA_PROFILE_FREE].Start();
254 
// Release the fields that were heap-allocated above; the conditions mirror
// the ones under which they were allocated.
255  if (&tmp2 != &tmp) delete tmp2_p;
256 
257  if (invParam.cuda_prec_sloppy != x.Precision()) {
258  delete r_sloppy;
259  delete x_sloppy;
260  }
261 
262  profile[QUDA_PROFILE_FREE].Stop();
263 
264  return;
265  }
266 
267 } // namespace quda