QUDA  1.0.0
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
inv_bicgstab_quda.cpp
Go to the documentation of this file.
1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <math.h>
4 
5 #include <quda_internal.h>
6 #include <blas_quda.h>
7 #include <dslash_quda.h>
8 #include <invert_quda.h>
9 #include <util_quda.h>
10 #include <color_spinor_field.h>
11 
12 namespace quda {
13 
14  // set the required parameters for the inner solver
15  void fillInnerSolveParam(SolverParam &inner, const SolverParam &outer);
16 
18  Solver(param, profile), mat(mat), matSloppy(matSloppy), matPrecon(matPrecon), init(false) {
19 
20  }
21 
23  profile.TPSTART(QUDA_PROFILE_FREE);
24 
25  if(init) {
26  delete yp;
27  delete rp;
28  delete pp;
29  delete vp;
30  delete tmpp;
31  delete tp;
32  }
33 
34  profile.TPSTOP(QUDA_PROFILE_FREE);
35  }
36 
/**
   @brief Reliable-update test for the BiCGstab residual recurrence.

   Sets rNorm = sqrt(r2) and raises the running maxima maxrx and maxrr to
   rNorm wherever rNorm exceeds them. It requests a reliable update once the
   residual has dropped below delta times the largest residual seen so far
   (rNorm < delta * maxrr).

   A negative or NaN r2 also requests an update. The pipelined iteration
   forms r2 by subtraction (s2 - Re(omega * conj(tr))), which can cancel to
   a negative value. sqrt of that is NaN, which makes every comparison below
   false, so the bad value would otherwise be silently accepted. Requesting
   an update makes the caller recompute the true residual instead. In that
   case maxrx and maxrr are left unchanged, exactly as before.

   @param[out]    rNorm  Set to sqrt(r2) (NaN when r2 is negative or NaN)
   @param[in,out] maxrx  Running maximum of rNorm (not used in the decision)
   @param[in,out] maxrr  Running maximum of rNorm, used as the update reference
   @param[in]     r2     Squared norm of the iterated residual
   @param[in]     delta  Relative reduction factor that triggers an update
   @return 1 if a reliable update should be performed, 0 otherwise
*/
int reliable(double &rNorm, double &maxrx, double &maxrr, const double &r2, const double &delta) {
  rNorm = sqrt(r2);

  // A negative value (cancellation) or NaN (breakdown) means the recurred
  // residual can no longer be trusted: always ask for a true-residual update.
  // Written as !(r2 >= 0.0) so that NaN also takes this branch.
  if (!(r2 >= 0.0)) return 1;

  if (rNorm > maxrx) maxrx = rNorm;
  if (rNorm > maxrr) maxrr = rNorm;

  // Update once the residual has fallen by a factor delta relative to the
  // largest residual seen since maxrr was last reset.
  return (rNorm < delta * maxrr) ? 1 : 0;
}
50 
52  {
54 
55  if (!init) {
58  yp = ColorSpinorField::Create(csParam);
59  rp = ColorSpinorField::Create(csParam);
61  pp = ColorSpinorField::Create(csParam);
62  vp = ColorSpinorField::Create(csParam);
63  tmpp = ColorSpinorField::Create(csParam);
64  tp = ColorSpinorField::Create(csParam);
65 
66  init = true;
67  }
68 
69  ColorSpinorField &y = *yp;
70  ColorSpinorField &r = *rp;
71  ColorSpinorField &p = *pp;
72  ColorSpinorField &v = *vp;
74  ColorSpinorField &t = *tp;
75 
76  ColorSpinorField *x_sloppy, *r_sloppy, *r_0;
77 
78  double b2 = blas::norm2(b); // norm sq of source
79  double r2; // norm sq of residual
80 
81  // compute initial residual depending on whether we have an initial guess or not
83  mat(r, x, y);
84  r2 = blas::xmyNorm(b, r);
85  blas::copy(y, x);
86  } else {
87  blas::copy(r, b);
88  r2 = b2;
89  blas::zero(x);
90  }
91 
92  // Check to see that we're not trying to invert on a zero-field source
93  if (b2 == 0) {
95  warningQuda("inverting on zero-field source");
96  x = b;
97  param.true_res = 0.0;
98  param.true_res_hq = 0.0;
100  return;
102  b2 = r2;
103  } else {
104  errorQuda("Null vector computing requires non-zero guess!");
105  }
106  }
107 
108  // set field aliasing according to whether we are doing mixed precision or not
109  if (param.precision_sloppy == x.Precision()) {
110  r_sloppy = &r;
111 
113  {
114  r_0 = &b;
115  }
116  else
117  {
119  csParam.create = QUDA_ZERO_FIELD_CREATE;
120  r_0 = ColorSpinorField::Create(csParam);//remember to delete this pointer.
121  *r_0 = r;
122  }
123  } else {
126  csParam.create = QUDA_NULL_FIELD_CREATE;
127  r_sloppy = ColorSpinorField::Create(csParam);
128  *r_sloppy = r;
129  r_0 = ColorSpinorField::Create(csParam);
130  *r_0 = r;
131  }
132 
134  {
135  x_sloppy = &x;
136  blas::zero(*x_sloppy);
137  }
138  else
139  {
141  csParam.create = QUDA_ZERO_FIELD_CREATE;
143  x_sloppy = ColorSpinorField::Create(csParam);
144  }
145 
146  // Syntatic sugar
147  ColorSpinorField &rSloppy = *r_sloppy;
148  ColorSpinorField &xSloppy = *x_sloppy;
149  ColorSpinorField &r0 = *r_0;
150 
151  SolverParam solve_param_inner(param);
152  fillInnerSolveParam(solve_param_inner, param);
153 
154  double stop = stopping(param.tol, b2, param.residual_type); // stopping condition of solver
155 
156  const bool use_heavy_quark_res =
157  (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) ? true : false;
158  double heavy_quark_res = use_heavy_quark_res ? sqrt(blas::HeavyQuarkResidualNorm(x,r).z) : 0.0;
159  const int heavy_quark_check = param.heavy_quark_check; // how often to check the heavy quark residual
160 
161  double delta = param.delta;
162 
163  int k = 0;
164  int rUpdate = 0;
165 
166  Complex rho(1.0, 0.0);
167  Complex rho0 = rho;
168  Complex alpha(1.0, 0.0);
169  Complex omega(1.0, 0.0);
170  Complex beta;
171 
172  double3 rho_r2;
173  double3 omega_t2;
174 
175  double rNorm = sqrt(r2);
176  //double r0Norm = rNorm;
177  double maxrr = rNorm;
178  double maxrx = rNorm;
179 
180  PrintStats("BiCGstab", k, r2, b2, heavy_quark_res);
181 
182  if (!param.is_preconditioner) { // do not do the below if we this is an inner solver
183  blas::flops = 0;
184  }
185 
187  profile.TPSTART(QUDA_PROFILE_COMPUTE);
188 
189  rho = r2; // cDotProductCuda(r0, r_sloppy); // BiCRstab
190  blas::copy(p, rSloppy);
191 
193  printfQuda("BiCGstab debug: x2=%e, r2=%e, v2=%e, p2=%e, tmp2=%e r0=%e t2=%e\n",
194  blas::norm2(x), blas::norm2(rSloppy), blas::norm2(v), blas::norm2(p),
195  blas::norm2(tmp), blas::norm2(r0), blas::norm2(t));
196 
197  while ( !convergence(r2, heavy_quark_res, stop, param.tol_hq) &&
198  k < param.maxiter) {
199 
200  matSloppy(v, p, tmp);
201 
202  Complex r0v;
203  if (param.pipeline) {
204  r0v = blas::cDotProduct(r0, v);
205  if (k>0) rho = blas::cDotProduct(r0, r);
206  } else {
207  r0v = blas::cDotProduct(r0, v);
208  }
209  if (abs(rho) == 0.0) alpha = 0.0;
210  else alpha = rho / r0v;
211 
212  // r -= alpha*v
213  blas::caxpy(-alpha, v, rSloppy);
214 
215  matSloppy(t, rSloppy, tmp);
216 
217  int updateR = 0;
218  if (param.pipeline) {
219  // omega = (t, r) / (t, t)
220  omega_t2 = blas::cDotProductNormA(t, rSloppy);
221  Complex tr = Complex(omega_t2.x, omega_t2.y);
222  double t2 = omega_t2.z;
223  omega = tr / t2;
224  double s2 = blas::norm2(rSloppy);
225  Complex r0t = blas::cDotProduct(r0, t);
226  beta = -r0t / r0v;
227  r2 = s2 - real(omega * conj(tr)) ;
228 
229  // now we can work out if we need to do a reliable update
230  updateR = reliable(rNorm, maxrx, maxrr, r2, delta);
231  } else {
232  // omega = (t, r) / (t, t)
233  omega_t2 = blas::cDotProductNormA(t, rSloppy);
234  omega = Complex(omega_t2.x / omega_t2.z, omega_t2.y / omega_t2.z);
235  }
236 
237  if (param.pipeline && !updateR) {
238  //x += alpha*p + omega*r, r -= omega*t, p = r - beta*omega*v + beta*p
239  blas::caxpbypzYmbw(alpha, p, omega, rSloppy, xSloppy, t);
240  blas::cxpaypbz(rSloppy, -beta*omega, v, beta, p);
241  //tripleBiCGstabUpdate(alpha, p, omega, rSloppy, xSloppy, t, -beta*omega, v, beta, p
242  } else {
243  //x += alpha*p + omega*r, r -= omega*t, r2 = (r,r), rho = (r0, r)
244  rho_r2 = blas::caxpbypzYmbwcDotProductUYNormY(alpha, p, omega, rSloppy, xSloppy, t, r0);
245  rho0 = rho;
246  rho = Complex(rho_r2.x, rho_r2.y);
247  r2 = rho_r2.z;
248  }
249 
250  if (use_heavy_quark_res && k%heavy_quark_check==0) {
251  if (&x != &xSloppy) {
252  blas::copy(tmp,y);
253  heavy_quark_res = sqrt(blas::xpyHeavyQuarkResidualNorm(xSloppy, tmp, rSloppy).z);
254  } else {
255  blas::copy(r, rSloppy);
256  heavy_quark_res = sqrt(blas::xpyHeavyQuarkResidualNorm(x, y, r).z);
257  }
258  }
259 
260  if (!param.pipeline) updateR = reliable(rNorm, maxrx, maxrr, r2, delta);
261 
262  if (updateR) {
263  if (x.Precision() != xSloppy.Precision()) blas::copy(x, xSloppy);
264 
265  blas::xpy(x, y); // swap these around?
266 
267  mat(r, y, x);
268  r2 = blas::xmyNorm(b, r);
269 
270  if (x.Precision() != rSloppy.Precision()) blas::copy(rSloppy, r);
271  blas::zero(xSloppy);
272 
273  rNorm = sqrt(r2);
274  maxrr = rNorm;
275  maxrx = rNorm;
276  //r0Norm = rNorm;
277  rUpdate++;
278  }
279 
280  k++;
281 
282  PrintStats("BiCGstab", k, r2, b2, heavy_quark_res);
284  printfQuda("BiCGstab debug: x2=%e, r2=%e, v2=%e, p2=%e, tmp2=%e r0=%e t2=%e\n",
285  blas::norm2(x), blas::norm2(rSloppy), blas::norm2(v), blas::norm2(p),
286  blas::norm2(tmp), blas::norm2(r0), blas::norm2(t));
287 
288  // update p
289  if (!param.pipeline || updateR) {// need to update if not pipeline or did a reliable update
290  if (abs(rho*alpha) == 0.0) beta = 0.0;
291  else beta = (rho/rho0) * (alpha/omega);
292  blas::cxpaypbz(rSloppy, -beta*omega, v, beta, p);
293  }
294 
295  }
296 
297  if (x.Precision() != xSloppy.Precision()) blas::copy(x, xSloppy);
298  blas::xpy(y, x);
299 
302 
304  double gflops = (blas::flops + mat.flops() + matSloppy.flops() + matPrecon.flops())*1e-9;
305 
306  param.gflops += gflops;
307  param.iter += k;
308 
309  if (k==param.maxiter) warningQuda("Exceeded maximum iterations %d", param.maxiter);
310 
311  if (getVerbosity() >= QUDA_VERBOSE) printfQuda("BiCGstab: Reliable updates = %d\n", rUpdate);
312 
313  if (!param.is_preconditioner) { // do not do the below if we this is an inner solver
314  // Calculate the true residual
315  mat(r, x);
316  param.true_res = sqrt(blas::xmyNorm(b, r) / b2);
317  param.true_res_hq = use_heavy_quark_res ? sqrt(blas::HeavyQuarkResidualNorm(x,r).z) : 0.0;
318 
319  PrintSummary("BiCGstab", k, r2, b2, stop, param.tol_hq);
320  }
321 
322  // reset the flops counters
323  blas::flops = 0;
324  mat.flops();
325  matSloppy.flops();
326  matPrecon.flops();
327 
328  // copy the residual to b so we can use it outside of the solver
330 
332 
333  profile.TPSTART(QUDA_PROFILE_FREE);
334  if (param.precision_sloppy != x.Precision()) {
335  delete r_0;
336  delete r_sloppy;
337  }
339  {
340  delete r_0;
341  }
342 
343  if (&x != &xSloppy) delete x_sloppy;
344 
345  profile.TPSTOP(QUDA_PROFILE_FREE);
346 
347  return;
348  }
349 
350 } // namespace quda
void setPrecision(QudaPrecision precision, QudaPrecision ghost_precision=QUDA_INVALID_PRECISION, bool force_native=false)
ColorSpinorField * vp
Definition: invert_quda.h:715
double3 cDotProductNormA(ColorSpinorField &a, ColorSpinorField &b)
Definition: reduce_quda.cu:778
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:21
#define errorQuda(...)
Definition: util_quda.h:121
double norm2(const ColorSpinorField &a)
Definition: reduce_quda.cu:721
__host__ __device__ ValueType sqrt(ValueType x)
Definition: complex_quda.h:120
Complex cDotProduct(ColorSpinorField &, ColorSpinorField &)
Definition: reduce_quda.cu:764
void PrintStats(const char *name, int k, double r2, double b2, double hq2)
Prints out the running statistics of the solver (requires a verbosity of QUDA_VERBOSE) ...
Definition: solver.cpp:256
ColorSpinorField * tp
Definition: invert_quda.h:715
cudaColorSpinorField * tmp
Definition: covdev_test.cpp:44
double3 xpyHeavyQuarkResidualNorm(ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &r)
Definition: reduce_quda.cu:818
static ColorSpinorField * Create(const ColorSpinorParam &param)
bool convergence(double r2, double hq2, double r2_tol, double hq_tol)
Definition: solver.cpp:223
TimeProfile & profile
Definition: invert_quda.h:464
void copy(ColorSpinorField &dst, const ColorSpinorField &src)
Definition: copy_quda.cu:355
BiCGstab(DiracMatrix &mat, DiracMatrix &matSloppy, DiracMatrix &matPrecon, SolverParam &param, TimeProfile &profile)
double xmyNorm(ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.h:75
QudaPreserveSource preserve_source
Definition: invert_quda.h:154
void fillInnerSolveParam(SolverParam &inner, const SolverParam &outer)
QudaGaugeParam param
Definition: pack_test.cpp:17
QudaComputeNullVector compute_null_vector
Definition: invert_quda.h:67
double Last(QudaProfileType idx)
Definition: timer.h:251
QudaResidualType residual_type
Definition: invert_quda.h:49
static double stopping(double tol, double b2, QudaResidualType residual_type)
Set the solver L2 stopping condition.
Definition: solver.cpp:206
const DiracMatrix & matSloppy
Definition: invert_quda.h:711
ColorSpinorParam csParam
Definition: pack_test.cpp:24
void caxpbypzYmbw(const Complex &, ColorSpinorField &, const Complex &, ColorSpinorField &, ColorSpinorField &, ColorSpinorField &)
Definition: blas_quda.cu:585
void operator()(ColorSpinorField &out, ColorSpinorField &in)
double omega
Definition: test_util.cpp:1690
#define warningQuda(...)
Definition: util_quda.h:133
bool is_preconditioner
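whether this solver is being run as a preconditioner (inner solver); if set, the solver skips resetting the flop counter and computing the true residual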
Definition: invert_quda.h:241
double3 HeavyQuarkResidualNorm(ColorSpinorField &x, ColorSpinorField &r)
Definition: reduce_quda.cu:809
ColorSpinorField * rp
Definition: invert_quda.h:715
std::complex< double > Complex
Definition: quda_internal.h:46
void init()
Create the CUBLAS context.
Definition: blas_cublas.cu:31
void caxpy(const Complex &a, ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.cu:512
void zero(ColorSpinorField &a)
Definition: blas_quda.cu:472
double3 caxpbypzYmbwcDotProductUYNormY(const Complex &a, ColorSpinorField &x, const Complex &b, ColorSpinorField &y, ColorSpinorField &z, ColorSpinorField &w, ColorSpinorField &u)
Definition: reduce_quda.cu:783
SolverParam & param
Definition: invert_quda.h:463
unsigned long long flops() const
Definition: dirac_quda.h:1119
#define printfQuda(...)
Definition: util_quda.h:115
int reliable(double &rNorm, double &maxrx, double &maxrr, const double &r2, const double &delta)
unsigned long long flops
Definition: blas_quda.cu:22
ColorSpinorField * tmpp
Definition: invert_quda.h:715
void xpy(ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.h:33
const DiracMatrix & matPrecon
Definition: invert_quda.h:712
QudaUseInitGuess use_init_guess
Definition: invert_quda.h:64
__host__ __device__ ValueType abs(ValueType x)
Definition: complex_quda.h:125
void cxpaypbz(ColorSpinorField &, const Complex &b, ColorSpinorField &y, const Complex &c, ColorSpinorField &z)
Definition: blas_quda.cu:535
QudaPrecision precision_sloppy
Definition: invert_quda.h:145
bool use_sloppy_partial_accumulator
Definition: invert_quda.h:76
void PrintSummary(const char *name, int k, double r2, double b2, double r2_tol, double hq_tol)
Prints out the summary of the solver convergence (requires a verbosity of QUDA_SUMMARIZE). Assumes SolverParam.true_res and SolverParam.true_res_hq has been set.
Definition: solver.cpp:270
__host__ __device__ ValueType conj(ValueType x)
Definition: complex_quda.h:130
void mat(void *out, void **link, void *in, int dagger_bit, int mu, QudaPrecision sPrecision, QudaPrecision gPrecision)
QudaPrecision Precision() const
ColorSpinorField * pp
Definition: invert_quda.h:715
DiracMatrix & mat
Definition: invert_quda.h:710
void updateR()
update the radius for halos.
ColorSpinorField * yp
Definition: invert_quda.h:715