QUDA  v0.5.0
A library for QCD on GPUs
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
inv_bicgstab_quda.cpp
Go to the documentation of this file.
1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <math.h>
4 
5 #include <complex>
6 
7 #include <quda_internal.h>
8 #include <blas_quda.h>
9 #include <dslash_quda.h>
10 #include <invert_quda.h>
11 #include <util_quda.h>
12 
13 #include<face_quda.h>
14 
15 #include <color_spinor_field.h>
16 
17 namespace quda {
18 
19  // set the required parameters for the inner solver
20  void fillInnerInvertParam(QudaInvertParam &inner, const QudaInvertParam &outer);
21 
24  mat(r, x);
25  return xmyNormCuda(b, r);
26  }
27 
28 
// Construct a BiCGstab solver.
//   mat, matSloppy, matPrecon - operators bound to the same-named members
//   invParam, profile         - forwarded to the Solver base class
// Nothing is allocated here: init starts out false, and the work fields used
// elsewhere in this file are created later under an `if (!init)` guard.
29  BiCGstab::BiCGstab(DiracMatrix &mat, DiracMatrix &matSloppy, DiracMatrix &matPrecon, QudaInvertParam &invParam, TimeProfile &profile) :
30  Solver(invParam, profile), mat(mat), matSloppy(matSloppy), matPrecon(matPrecon), init(false) {
31 
32  }
33 
35  profile[QUDA_PROFILE_FREE].Start();
36 
37  if(init) {
38  if (wp && wp != pp) delete wp;
39  if (zp && zp != pp) delete zp;
40  delete yp;
41  delete rp;
42  delete pp;
43  delete vp;
44  delete tmpp;
45  delete tp;
46  }
47 
48  profile[QUDA_PROFILE_FREE].Stop();
49  }
50 
52  {
54 
55  if (!init) {
58  yp = new cudaColorSpinorField(x, csParam);
59  rp = new cudaColorSpinorField(x, csParam);
61  pp = new cudaColorSpinorField(x, csParam);
62  vp = new cudaColorSpinorField(x, csParam);
63  tmpp = new cudaColorSpinorField(x, csParam);
64  tp = new cudaColorSpinorField(x, csParam);
65 
66  // MR preconditioner - we need extra vectors
68  wp = new cudaColorSpinorField(x, csParam);
69  zp = new cudaColorSpinorField(x, csParam);
70  } else { // dummy assignments
71  wp = pp;
72  zp = pp;
73  }
74 
75  init = true;
76  }
77 
78  cudaColorSpinorField &y = *yp;
79  cudaColorSpinorField &r = *rp;
80  cudaColorSpinorField &p = *pp;
81  cudaColorSpinorField &v = *vp;
82  cudaColorSpinorField &tmp = *tmpp;
83  cudaColorSpinorField &t = *tp;
84  cudaColorSpinorField &w = *wp;
85  cudaColorSpinorField &z = *zp;
86 
87  cudaColorSpinorField *x_sloppy, *r_sloppy, *r_0;
88 
89  double b2; // norm sq of source
90  double r2; // norm sq of residual
91 
92  // compute initial residual depending on whether we have an initial guess or not
94  mat(r, x, y);
95  r2 = xmyNormCuda(b, r);
96  b2 = normCuda(b);
97  copyCuda(y, x);
98  } else {
99  copyCuda(r, b);
100  r2 = normCuda(b);
101  b2 = r2;
102  }
103 
104  // Check to see that we're not trying to invert on a zero-field source
105  if(b2 == 0){
106  profile[QUDA_PROFILE_INIT].Stop();
107  printfQuda("Warning: inverting on zero-field source\n");
108  x = b;
109  invParam.true_res = 0.0;
110  invParam.true_res_hq = 0.0;
111  return;
112  }
113 
114  // set field aliasing according to whether we are doing mixed precision or not
115  if (invParam.cuda_prec_sloppy == x.Precision()) {
116  x_sloppy = &x;
117  r_sloppy = &r;
118  r_0 = &b;
119  zeroCuda(*x_sloppy);
120  } else {
122  csParam.create = QUDA_ZERO_FIELD_CREATE;
124  x_sloppy = new cudaColorSpinorField(x, csParam);
125  csParam.create = QUDA_COPY_FIELD_CREATE;
126  r_sloppy = new cudaColorSpinorField(r, csParam);
127  r_0 = new cudaColorSpinorField(b, csParam);
128  }
129 
130  // Syntactic sugar
131  cudaColorSpinorField &rSloppy = *r_sloppy;
132  cudaColorSpinorField &xSloppy = *x_sloppy;
133  cudaColorSpinorField &r0 = *r_0;
134 
135  QudaInvertParam invert_param_inner = newQudaInvertParam();
136  fillInnerInvertParam(invert_param_inner, invParam);
137 
138  double stop = b2*invParam.tol*invParam.tol; // stopping condition of solver
139 
140  const bool use_heavy_quark_res =
142  double heavy_quark_res = use_heavy_quark_res ? sqrt(HeavyQuarkResidualNormCuda(x,r).z) : 0.0;
143  int heavy_quark_check = 10; // how often to check the heavy quark residual
144 
145  double delta = invParam.reliable_delta;
146 
147  int k = 0;
148  int rUpdate = 0;
149 
150  Complex rho(1.0, 0.0);
151  Complex rho0 = rho;
152  Complex alpha(1.0, 0.0);
153  Complex omega(1.0, 0.0);
154  Complex beta;
155 
156  double3 rho_r2;
157  double3 omega_t2;
158 
159  double rNorm = sqrt(r2);
160  //double r0Norm = rNorm;
161  double maxrr = rNorm;
162  double maxrx = rNorm;
163 
164  PrintStats("BiCGstab", k, r2, b2, heavy_quark_res);
165 
166  if (invParam.inv_type_precondition != QUDA_GCR_INVERTER) { // do not do the below if this is an inner solver
167  quda::blas_flops = 0;
168  }
169 
171  profile[QUDA_PROFILE_COMPUTE].Start();
172 
173  while ( !convergence(r2, heavy_quark_res, stop, invParam.tol_hq) &&
174  k < invParam.maxiter) {
175 
176  if (k==0) {
177  rho = r2; // cDotProductCuda(r0, r_sloppy); // BiCRstab
178  copyCuda(p, rSloppy);
179  } else {
180  if (abs(rho*alpha) == 0.0) beta = 0.0;
181  else beta = (rho/rho0) * (alpha/omega);
182 
183  cxpaypbzCuda(rSloppy, -beta*omega, v, beta, p);
184  }
185 
187  errorQuda("Temporary disabled");
188  //invertMRCuda(*matPrecon, w, p, &invert_param_inner);
189  matSloppy(v, w, tmp);
190  } else {
191  matSloppy(v, p, tmp);
192  }
193 
194  if (abs(rho) == 0.0) alpha = 0.0;
195  else alpha = rho / cDotProductCuda(r0, v);
196 
197  // r -= alpha*v
198  caxpyCuda(-alpha, v, rSloppy);
199 
201  errorQuda("Temporary disabled");
202  //invertMRCuda(*matPrecon, z, rSloppy, &invert_param_inner);
203  matSloppy(t, z, tmp);
204  } else {
205  matSloppy(t, rSloppy, tmp);
206  }
207 
208  // omega = (t, r) / (t, t)
209  omega_t2 = cDotProductNormACuda(t, rSloppy);
210  omega = quda::Complex(omega_t2.x / omega_t2.z, omega_t2.y / omega_t2.z);
211 
213  //x += alpha*w + omega*z, r -= omega*t, r2 = (r,r), rho = (r0, r)
214  caxpyCuda(alpha, w, xSloppy);
215  caxpyCuda(omega, z, xSloppy);
216  caxpyCuda(-omega, t, rSloppy);
217  rho_r2 = cDotProductNormBCuda(r0, rSloppy);
218  } else {
219  //x += alpha*p + omega*r, r -= omega*t, r2 = (r,r), rho = (r0, r)
220  rho_r2 = caxpbypzYmbwcDotProductUYNormYCuda(alpha, p, omega, rSloppy, xSloppy, t, r0);
221  }
222 
223  rho0 = rho;
224  rho = quda::Complex(rho_r2.x, rho_r2.y);
225  r2 = rho_r2.z;
226 
227  if (use_heavy_quark_res && k%heavy_quark_check==0) {
228  copyCuda(tmp,y);
229  heavy_quark_res = sqrt(xpyHeavyQuarkResidualNormCuda(xSloppy, tmp, rSloppy).z);
230  }
231 
232  // reliable updates
233  rNorm = sqrt(r2);
234  if (rNorm > maxrx) maxrx = rNorm;
235  if (rNorm > maxrr) maxrr = rNorm;
236  //int updateR = (rNorm < delta*maxrr && r0Norm <= maxrr) ? 1 : 0;
237  //int updateX = (rNorm < delta*r0Norm && r0Norm <= maxrx) ? 1 : 0;
238 
239  int updateR = (rNorm < delta*maxrr) ? 1 : 0;
240 
241  if (updateR) {
242  if (x.Precision() != xSloppy.Precision()) copyCuda(x, xSloppy);
243 
244  xpyCuda(x, y); // swap these around?
245 
246  mat(r, y, x);
247  r2 = xmyNormCuda(b, r);
248 
249  if (x.Precision() != rSloppy.Precision()) copyCuda(rSloppy, r);
250  zeroCuda(xSloppy);
251 
252  rNorm = sqrt(r2);
253  maxrr = rNorm;
254  maxrx = rNorm;
255  //r0Norm = rNorm;
256  rUpdate++;
257  }
258 
259  k++;
260 
261  PrintStats("BiCGstab", k, r2, b2, heavy_quark_res);
262  }
263 
264  if (x.Precision() != xSloppy.Precision()) copyCuda(x, xSloppy);
265  xpyCuda(y, x);
266 
269 
271  double gflops = (quda::blas_flops + mat.flops() + matSloppy.flops() + matPrecon.flops())*1e-9;
272  reduceDouble(gflops);
273 
274  invParam.gflops += gflops;
275  invParam.iter += k;
276 
277  if (k==invParam.maxiter) warningQuda("Exceeded maximum iterations %d", invParam.maxiter);
278 
279  if (invParam.verbosity >= QUDA_VERBOSE) printfQuda("BiCGstab: Reliable updates = %d\n", rUpdate);
280 
281  if (invParam.inv_type_precondition != QUDA_GCR_INVERTER) { // do not do the below if this is an inner solver
282  // Calculate the true residual
283  mat(r, x);
284  invParam.true_res = sqrt(xmyNormCuda(b, r) / b2);
285 #if (__COMPUTE_CAPABILITY__ >= 200)
287 #else
288  invParam.true_res_hq = 0.0;
289 #endif
290 
291  PrintSummary("BiCGstab", k, r2, b2);
292  }
293 
294  // reset the flops counters
295  quda::blas_flops = 0;
296  mat.flops();
297  matSloppy.flops();
298  matPrecon.flops();
299 
301 
302  profile[QUDA_PROFILE_FREE].Start();
303  if (invParam.cuda_prec_sloppy != x.Precision()) {
304  delete r_0;
305  delete r_sloppy;
306  delete x_sloppy;
307  }
308  profile[QUDA_PROFILE_FREE].Stop();
309 
310  return;
311  }
312 
313 } // namespace quda