QUDA  v1.1.0
A library for QCD on GPUs
inv_bicgstab_quda.cpp
Go to the documentation of this file.
1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <math.h>
4 
5 #include <quda_internal.h>
6 #include <blas_quda.h>
7 #include <dslash_quda.h>
8 #include <invert_quda.h>
9 #include <util_quda.h>
10 #include <color_spinor_field.h>
11 
12 namespace quda {
13 
14  // set the required parameters for the inner solver
15  void fillInnerSolveParam(SolverParam &inner, const SolverParam &outer);
16 
17  BiCGstab::BiCGstab(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon,
18  const DiracMatrix &matEig, SolverParam &param, TimeProfile &profile) :
19  Solver(mat, matSloppy, matPrecon, matEig, param, profile),
20  init(false)
21  {
22  }
23 
25  profile.TPSTART(QUDA_PROFILE_FREE);
26 
27  if(init) {
28  delete yp;
29  delete rp;
30  delete pp;
31  delete vp;
32  delete tmpp;
33  delete tp;
34  }
35 
36  profile.TPSTOP(QUDA_PROFILE_FREE);
37  }
38 
/**
 * @brief Decide whether BiCGstab should perform a reliable update.
 *
 * Recomputes the residual norm from its square and folds it into the two
 * running maxima.  An update is requested once the residual norm has fallen
 * below `delta` times the largest residual norm seen since the last update.
 *
 * @param rNorm  [out]    set to sqrt(r2)
 * @param maxrx  [in,out] running maximum of the residual norm (raised to rNorm if exceeded)
 * @param maxrr  [in,out] running maximum of the residual norm (raised to rNorm if exceeded)
 * @param r2     squared norm of the current (iterated) residual
 * @param delta  reliable-update threshold factor
 * @return 1 if a reliable update should be performed, 0 otherwise
 */
int reliable(double &rNorm, double &maxrx, double &maxrr, const double &r2, const double &delta)
{
  rNorm = sqrt(r2);

  // Keep the running maxima current; maxrr must be updated before the test below.
  maxrx = (rNorm > maxrx) ? rNorm : maxrx;
  maxrr = (rNorm > maxrr) ? rNorm : maxrr;

  // Only the maxrr criterion is active (the maxrx / r0Norm variants were disabled).
  const bool residualDropped = rNorm < delta * maxrr;
  return residualDropped ? 1 : 0;
}
52 
54  {
56 
57  if (!init) {
62  csParam.setPrecision(param.precision_sloppy);
67 
68  init = true;
69  }
70 
71  ColorSpinorField &y = *yp;
72  ColorSpinorField &r = *rp;
73  ColorSpinorField &p = *pp;
74  ColorSpinorField &v = *vp;
75  ColorSpinorField &tmp = *tmpp;
76  ColorSpinorField &t = *tp;
77 
78  ColorSpinorField *x_sloppy, *r_sloppy, *r_0;
79 
80  double b2 = blas::norm2(b); // norm sq of source
81  double r2; // norm sq of residual
82 
83  // compute initial residual depending on whether we have an initial guess or not
85  mat(r, x, y);
86  r2 = blas::xmyNorm(b, r);
87  blas::copy(y, x);
88  } else {
89  blas::copy(r, b);
90  r2 = b2;
91  blas::zero(x);
92  }
93 
94  // Check to see that we're not trying to invert on a zero-field source
95  if (b2 == 0) {
97  warningQuda("inverting on zero-field source");
98  x = b;
99  param.true_res = 0.0;
100  param.true_res_hq = 0.0;
102  return;
104  b2 = r2;
105  } else {
106  errorQuda("Null vector computing requires non-zero guess!");
107  }
108  }
109 
110  // set field aliasing according to whether we are doing mixed precision or not
111  if (param.precision_sloppy == x.Precision()) {
112  r_sloppy = &r;
113 
115  {
116  r_0 = &b;
117  }
118  else
119  {
122  r_0 = ColorSpinorField::Create(csParam);//remember to delete this pointer.
123  *r_0 = r;
124  }
125  } else {
127  csParam.setPrecision(param.precision_sloppy);
129  r_sloppy = ColorSpinorField::Create(csParam);
130  *r_sloppy = r;
132  *r_0 = r;
133  }
134 
136  {
137  x_sloppy = &x;
138  blas::zero(*x_sloppy);
139  }
140  else
141  {
144  csParam.setPrecision(param.precision_sloppy);
145  x_sloppy = ColorSpinorField::Create(csParam);
146  }
147 
148  // Syntatic sugar
149  ColorSpinorField &rSloppy = *r_sloppy;
150  ColorSpinorField &xSloppy = *x_sloppy;
151  ColorSpinorField &r0 = *r_0;
152 
153  SolverParam solve_param_inner(param);
154  fillInnerSolveParam(solve_param_inner, param);
155 
156  double stop = stopping(param.tol, b2, param.residual_type); // stopping condition of solver
157 
158  const bool use_heavy_quark_res =
159  (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) ? true : false;
160  double heavy_quark_res = use_heavy_quark_res ? sqrt(blas::HeavyQuarkResidualNorm(x,r).z) : 0.0;
161  const int heavy_quark_check = param.heavy_quark_check; // how often to check the heavy quark residual
162 
163  double delta = param.delta;
164 
165  int k = 0;
166  int rUpdate = 0;
167 
168  Complex rho(1.0, 0.0);
169  Complex rho0 = rho;
170  Complex alpha(1.0, 0.0);
171  Complex omega(1.0, 0.0);
172  Complex beta;
173 
174  double3 rho_r2;
175  double3 omega_t2;
176 
177  double rNorm = sqrt(r2);
178  //double r0Norm = rNorm;
179  double maxrr = rNorm;
180  double maxrx = rNorm;
181 
182  PrintStats("BiCGstab", k, r2, b2, heavy_quark_res);
183 
184  if (!param.is_preconditioner) { // do not do the below if we this is an inner solver
185  blas::flops = 0;
186  }
187 
189  profile.TPSTART(QUDA_PROFILE_COMPUTE);
190 
191  rho = r2; // cDotProductCuda(r0, r_sloppy); // BiCRstab
192  blas::copy(p, rSloppy);
193 
195  printfQuda("BiCGstab debug: x2=%e, r2=%e, v2=%e, p2=%e, tmp2=%e r0=%e t2=%e\n",
196  blas::norm2(x), blas::norm2(rSloppy), blas::norm2(v), blas::norm2(p),
198 
199  while ( !convergence(r2, heavy_quark_res, stop, param.tol_hq) &&
200  k < param.maxiter) {
201 
202  matSloppy(v, p, tmp);
203 
204  Complex r0v;
205  if (param.pipeline) {
206  r0v = blas::cDotProduct(r0, v);
207  if (k>0) rho = blas::cDotProduct(r0, r);
208  } else {
209  r0v = blas::cDotProduct(r0, v);
210  }
211  if (abs(rho) == 0.0) alpha = 0.0;
212  else alpha = rho / r0v;
213 
214  // r -= alpha*v
215  blas::caxpy(-alpha, v, rSloppy);
216 
217  matSloppy(t, rSloppy, tmp);
218 
219  int updateR = 0;
220  if (param.pipeline) {
221  // omega = (t, r) / (t, t)
222  omega_t2 = blas::cDotProductNormA(t, rSloppy);
223  Complex tr = Complex(omega_t2.x, omega_t2.y);
224  double t2 = omega_t2.z;
225  omega = tr / t2;
226  double s2 = blas::norm2(rSloppy);
227  Complex r0t = blas::cDotProduct(r0, t);
228  beta = -r0t / r0v;
229  r2 = s2 - real(omega * conj(tr)) ;
230 
231  // now we can work out if we need to do a reliable update
232  updateR = reliable(rNorm, maxrx, maxrr, r2, delta);
233  } else {
234  // omega = (t, r) / (t, t)
235  omega_t2 = blas::cDotProductNormA(t, rSloppy);
236  omega = Complex(omega_t2.x / omega_t2.z, omega_t2.y / omega_t2.z);
237  }
238 
239  if (param.pipeline && !updateR) {
240  //x += alpha*p + omega*r, r -= omega*t, p = r - beta*omega*v + beta*p
241  blas::caxpbypzYmbw(alpha, p, omega, rSloppy, xSloppy, t);
242  blas::cxpaypbz(rSloppy, -beta*omega, v, beta, p);
243  //tripleBiCGstabUpdate(alpha, p, omega, rSloppy, xSloppy, t, -beta*omega, v, beta, p
244  } else {
245  //x += alpha*p + omega*r, r -= omega*t, r2 = (r,r), rho = (r0, r)
246  rho_r2 = blas::caxpbypzYmbwcDotProductUYNormY(alpha, p, omega, rSloppy, xSloppy, t, r0);
247  rho0 = rho;
248  rho = Complex(rho_r2.x, rho_r2.y);
249  r2 = rho_r2.z;
250  }
251 
252  if (use_heavy_quark_res && k%heavy_quark_check==0) {
253  if (&x != &xSloppy) {
254  blas::copy(tmp,y);
255  heavy_quark_res = sqrt(blas::xpyHeavyQuarkResidualNorm(xSloppy, tmp, rSloppy).z);
256  } else {
257  blas::copy(r, rSloppy);
258  heavy_quark_res = sqrt(blas::xpyHeavyQuarkResidualNorm(x, y, r).z);
259  }
260  }
261 
262  if (!param.pipeline) updateR = reliable(rNorm, maxrx, maxrr, r2, delta);
263 
264  if (updateR) {
265  if (x.Precision() != xSloppy.Precision()) blas::copy(x, xSloppy);
266 
267  blas::xpy(x, y); // swap these around?
268 
269  mat(r, y, x);
270  r2 = blas::xmyNorm(b, r);
271 
272  if (x.Precision() != rSloppy.Precision()) blas::copy(rSloppy, r);
273  blas::zero(xSloppy);
274 
275  rNorm = sqrt(r2);
276  maxrr = rNorm;
277  maxrx = rNorm;
278  //r0Norm = rNorm;
279  rUpdate++;
280  }
281 
282  k++;
283 
284  PrintStats("BiCGstab", k, r2, b2, heavy_quark_res);
286  printfQuda("BiCGstab debug: x2=%e, r2=%e, v2=%e, p2=%e, tmp2=%e r0=%e t2=%e\n",
287  blas::norm2(x), blas::norm2(rSloppy), blas::norm2(v), blas::norm2(p),
289 
290  // update p
291  if (!param.pipeline || updateR) {// need to update if not pipeline or did a reliable update
292  if (abs(rho*alpha) == 0.0) beta = 0.0;
293  else beta = (rho/rho0) * (alpha/omega);
294  blas::cxpaypbz(rSloppy, -beta*omega, v, beta, p);
295  }
296 
297  }
298 
299  if (x.Precision() != xSloppy.Precision()) blas::copy(x, xSloppy);
300  blas::xpy(y, x);
301 
304 
306  double gflops = (blas::flops + mat.flops() + matSloppy.flops() + matPrecon.flops())*1e-9;
307 
308  param.gflops += gflops;
309  param.iter += k;
310 
311  if (k==param.maxiter) warningQuda("Exceeded maximum iterations %d", param.maxiter);
312 
313  if (getVerbosity() >= QUDA_VERBOSE) printfQuda("BiCGstab: Reliable updates = %d\n", rUpdate);
314 
315  if (!param.is_preconditioner) { // do not do the below if we this is an inner solver
316  // Calculate the true residual
317  mat(r, x);
318  param.true_res = sqrt(blas::xmyNorm(b, r) / b2);
319  param.true_res_hq = use_heavy_quark_res ? sqrt(blas::HeavyQuarkResidualNorm(x,r).z) : 0.0;
320 
321  PrintSummary("BiCGstab", k, r2, b2, stop, param.tol_hq);
322  }
323 
324  // reset the flops counters
325  blas::flops = 0;
326  mat.flops();
327  matSloppy.flops();
328  matPrecon.flops();
329 
330  // copy the residual to b so we can use it outside of the solver
332 
334 
335  profile.TPSTART(QUDA_PROFILE_FREE);
336  if (param.precision_sloppy != x.Precision()) {
337  delete r_0;
338  delete r_sloppy;
339  }
341  {
342  delete r_0;
343  }
344 
345  if (&x != &xSloppy) delete x_sloppy;
346 
347  profile.TPSTOP(QUDA_PROFILE_FREE);
348 
349  return;
350  }
351 
352 } // namespace quda
BiCGstab(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, const DiracMatrix &matEig, SolverParam &param, TimeProfile &profile)
void operator()(ColorSpinorField &out, ColorSpinorField &in)
static ColorSpinorField * Create(const ColorSpinorParam &param)
unsigned long long flops() const
Definition: dirac_quda.h:1909
QudaPrecision Precision() const
TimeProfile & profile
Definition: invert_quda.h:471
const DiracMatrix & mat
Definition: invert_quda.h:465
bool convergence(double r2, double hq2, double r2_tol, double hq_tol)
Definition: solver.cpp:328
void PrintSummary(const char *name, int k, double r2, double b2, double r2_tol, double hq_tol)
Prints out the summary of the solver convergence (requires a verbosity of QUDA_SUMMARIZE)....
Definition: solver.cpp:386
SolverParam & param
Definition: invert_quda.h:470
static double stopping(double tol, double b2, QudaResidualType residual_type)
Set the solver L2 stopping condition.
Definition: solver.cpp:311
void PrintStats(const char *name, int k, double r2, double b2, double hq2)
Prints out the running statistics of the solver (requires a verbosity of QUDA_VERBOSE)
Definition: solver.cpp:373
const DiracMatrix & matPrecon
Definition: invert_quda.h:467
const DiracMatrix & matSloppy
Definition: invert_quda.h:466
double Last(QudaProfileType idx)
Definition: timer.h:254
double omega
void mat(void *out, void **link, void *in, int dagger_bit, int mu, QudaPrecision sPrecision, QudaPrecision gPrecision)
cudaColorSpinorField * tmp
Definition: covdev_test.cpp:34
@ QUDA_USE_INIT_GUESS_YES
Definition: enum_quda.h:430
@ QUDA_DEBUG_VERBOSE
Definition: enum_quda.h:268
@ QUDA_VERBOSE
Definition: enum_quda.h:267
@ QUDA_HEAVY_QUARK_RESIDUAL
Definition: enum_quda.h:195
@ QUDA_PRESERVE_SOURCE_NO
Definition: enum_quda.h:238
@ QUDA_ZERO_FIELD_CREATE
Definition: enum_quda.h:361
@ QUDA_NULL_FIELD_CREATE
Definition: enum_quda.h:360
@ QUDA_COMPUTE_NULL_VECTOR_NO
Definition: enum_quda.h:441
@ QUDA_COMPUTE_NULL_VECTOR_YES
Definition: enum_quda.h:442
void init()
Create the BLAS context.
void caxpbypzYmbw(const Complex &, ColorSpinorField &, const Complex &, ColorSpinorField &, ColorSpinorField &, ColorSpinorField &)
double3 HeavyQuarkResidualNorm(ColorSpinorField &x, ColorSpinorField &r)
double xmyNorm(ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.h:79
unsigned long long flops
double3 caxpbypzYmbwcDotProductUYNormY(const Complex &a, ColorSpinorField &x, const Complex &b, ColorSpinorField &y, ColorSpinorField &z, ColorSpinorField &w, ColorSpinorField &u)
double3 xpyHeavyQuarkResidualNorm(ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &r)
void zero(ColorSpinorField &a)
double norm2(const ColorSpinorField &a)
double3 cDotProductNormA(ColorSpinorField &a, ColorSpinorField &b)
void caxpy(const Complex &a, ColorSpinorField &x, ColorSpinorField &y)
void cxpaypbz(ColorSpinorField &, const Complex &b, ColorSpinorField &y, const Complex &c, ColorSpinorField &z)
void xpy(ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.h:41
void copy(ColorSpinorField &dst, const ColorSpinorField &src)
Definition: blas_quda.h:24
Complex cDotProduct(ColorSpinorField &, ColorSpinorField &)
void stop()
Stop profiling.
Definition: device.cpp:228
__host__ __device__ ValueType conj(ValueType x)
Definition: complex_quda.h:130
int reliable(double &rNorm, double &maxrx, double &maxrr, const double &r2, const double &delta)
std::complex< double > Complex
Definition: quda_internal.h:86
__host__ __device__ ValueType sqrt(ValueType x)
Definition: complex_quda.h:120
@ QUDA_PROFILE_EPILOGUE
Definition: timer.h:110
@ QUDA_PROFILE_COMPUTE
Definition: timer.h:108
@ QUDA_PROFILE_FREE
Definition: timer.h:111
@ QUDA_PROFILE_PREAMBLE
Definition: timer.h:107
__host__ __device__ ValueType abs(ValueType x)
Definition: complex_quda.h:125
void fillInnerSolveParam(SolverParam &inner, const SolverParam &outer)
ColorSpinorParam csParam
Definition: pack_test.cpp:25
QudaGaugeParam param
Definition: pack_test.cpp:18
void updateR()
update the radius for halos.
QudaPreserveSource preserve_source
Definition: invert_quda.h:151
QudaComputeNullVector compute_null_vector
Definition: invert_quda.h:61
bool is_preconditioner
whether this solver is being used as a preconditioner (inner solver) for another solver
Definition: invert_quda.h:238
bool use_sloppy_partial_accumulator
Definition: invert_quda.h:70
QudaResidualType residual_type
Definition: invert_quda.h:49
QudaPrecision precision_sloppy
Definition: invert_quda.h:139
QudaUseInitGuess use_init_guess
Definition: invert_quda.h:58
#define printfQuda(...)
Definition: util_quda.h:114
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:21
#define warningQuda(...)
Definition: util_quda.h:132
#define errorQuda(...)
Definition: util_quda.h:120