QUDA  v0.7.0
A library for QCD on GPUs
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
inv_bicgstab_quda.cpp
Go to the documentation of this file.
1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <math.h>
4 
5 #include <quda_internal.h>
6 #include <blas_quda.h>
7 #include <dslash_quda.h>
8 #include <invert_quda.h>
9 #include <util_quda.h>
10 
11 #include<face_quda.h>
12 
13 #include <color_spinor_field.h>
14 
15 namespace quda {
16 
17  // set the required parameters for the inner solver
18  void fillInnerSolveParam(SolverParam &inner, const SolverParam &outer);
19 
22  mat(r, x);
23  return xmyNormCuda(b, r);
24  }
25 
26 
28  Solver(param, profile), mat(mat), matSloppy(matSloppy), matPrecon(matPrecon), init(false) {
29 
30  }
31 
34 
35  if(init) {
36  delete yp;
37  delete rp;
38  delete pp;
39  delete vp;
40  delete tmpp;
41  delete tp;
42  }
43 
45  }
46 
/**
 * Reliable-update test used by the BiCGstab iteration.
 *
 * Refreshes the residual norm from its square, raises the two running
 * maxima if the new norm exceeds them, and reports whether the norm has
 * dropped below delta times the largest residual norm recorded in maxrr.
 *
 * @param rNorm  [out]    set to sqrt(r2)
 * @param maxrx  [in,out] raised to rNorm when rNorm is larger; it does not
 *                        take part in the returned decision
 * @param maxrr  [in,out] raised to rNorm when rNorm is larger
 * @param r2     [in]     squared residual norm
 * @param delta  [in]     factor applied to maxrr in the test
 * @return 1 when rNorm < delta * maxrr, otherwise 0
 */
int reliable(double &rNorm, double &maxrx, double &maxrr, const double &r2, const double &delta) {
  // Current residual norm.
  rNorm = sqrt(r2);

  // Keep both running maxima up to date. Plain comparisons (rather than
  // fmax) are used so that a NaN operand leaves the stored value untouched.
  maxrx = (rNorm > maxrx) ? rNorm : maxrx;
  maxrr = (rNorm > maxrr) ? rNorm : maxrr;

  // A reliable update is requested once the residual has fallen by the
  // factor delta relative to the recorded maximum.
  const bool belowThreshold = rNorm < delta * maxrr;
  return belowThreshold ? 1 : 0;
}
60 
62  {
64 
65  if (!init) {
68  yp = new cudaColorSpinorField(x, csParam);
69  rp = new cudaColorSpinorField(x, csParam);
71  pp = new cudaColorSpinorField(x, csParam);
72  vp = new cudaColorSpinorField(x, csParam);
73  tmpp = new cudaColorSpinorField(x, csParam);
74  tp = new cudaColorSpinorField(x, csParam);
75 
76  init = true;
77  }
78 
79  cudaColorSpinorField &y = *yp;
80  cudaColorSpinorField &r = *rp;
81  cudaColorSpinorField &p = *pp;
82  cudaColorSpinorField &v = *vp;
83  cudaColorSpinorField &tmp = *tmpp;
84  cudaColorSpinorField &t = *tp;
85 
86  cudaColorSpinorField *x_sloppy, *r_sloppy, *r_0;
87 
88  double b2 = normCuda(b); // norm sq of source
89  double r2; // norm sq of residual
90 
91  // compute initial residual depending on whether we have an initial guess or not
93  mat(r, x, y);
94  r2 = xmyNormCuda(b, r);
95  copyCuda(y, x);
96  } else {
97  copyCuda(r, b);
98  r2 = b2;
99  zeroCuda(x); // defensive measure in case solution isn't already zero
100  }
101 
102  // Check to see that we're not trying to invert on a zero-field source
103  if (b2 == 0) {
105  warningQuda("inverting on zero-field source\n");
106  x = b;
107  param.true_res = 0.0;
108  param.true_res_hq = 0.0;
109  return;
110  }
111 
112  // set field aliasing according to whether we are doing mixed precision or not
113  if (param.precision_sloppy == x.Precision()) {
114  r_sloppy = &r;
115  r_0 = &b;
116  } else {
119  csParam.create = QUDA_COPY_FIELD_CREATE;
120  r_sloppy = new cudaColorSpinorField(r, csParam);
121  r_0 = new cudaColorSpinorField(b, csParam);
122  }
123 
124  // set field aliasing according to whether we are doing mixed precision or not
125  if (param.precision_sloppy == x.Precision() ||
127  x_sloppy = &x;
128  zeroCuda(*x_sloppy);
129  } else {
131  csParam.create = QUDA_ZERO_FIELD_CREATE;
133  x_sloppy = new cudaColorSpinorField(x, csParam);
134  }
135 
136  // Syntactic sugar
137  cudaColorSpinorField &rSloppy = *r_sloppy;
138  cudaColorSpinorField &xSloppy = *x_sloppy;
139  cudaColorSpinorField &r0 = *r_0;
140 
141  SolverParam solve_param_inner(param);
142  fillInnerSolveParam(solve_param_inner, param);
143 
144  double stop = stopping(param.tol, b2, param.residual_type); // stopping condition of solver
145 
146  const bool use_heavy_quark_res =
147  (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) ? true : false;
148  double heavy_quark_res = use_heavy_quark_res ? sqrt(HeavyQuarkResidualNormCuda(x,r).z) : 0.0;
149  int heavy_quark_check = 10; // how often to check the heavy quark residual
150 
151  double delta = param.delta;
152 
153  int k = 0;
154  int rUpdate = 0;
155 
156  Complex rho(1.0, 0.0);
157  Complex rho0 = rho;
158  Complex alpha(1.0, 0.0);
159  Complex omega(1.0, 0.0);
160  Complex beta;
161 
162  double3 rho_r2;
163  double3 omega_t2;
164 
165  double rNorm = sqrt(r2);
166  //double r0Norm = rNorm;
167  double maxrr = rNorm;
168  double maxrx = rNorm;
169 
170  PrintStats("BiCGstab", k, r2, b2, heavy_quark_res);
171 
172  if (param.inv_type_precondition != QUDA_GCR_INVERTER) { // do not do the below if this is an inner solver
173  quda::blas_flops = 0;
174  }
175 
178 
179  rho = r2; // cDotProductCuda(r0, r_sloppy); // BiCRstab
180  copyCuda(p, rSloppy);
181 
183  printfQuda("BiCGstab debug: x2=%e, r2=%e, v2=%e, p2=%e, tmp2=%e r0=%e t2=%e\n",
184  norm2(x), norm2(rSloppy), norm2(v), norm2(p), norm2(tmp), norm2(r0), norm2(t));
185 
186  while ( !convergence(r2, heavy_quark_res, stop, param.tol_hq) &&
187  k < param.maxiter) {
188 
189  matSloppy(v, p, tmp);
190 
191  Complex r0v;
192  if (param.pipeline) {
193  r0v = cDotProductCuda(r0, v);
194  if (k>0) rho = cDotProductCuda(r0, r);
195  } else {
196  r0v = cDotProductCuda(r0, v);
197  }
198  if (abs(rho) == 0.0) alpha = 0.0;
199  else alpha = rho / r0v;
200 
201  // r -= alpha*v
202  caxpyCuda(-alpha, v, rSloppy);
203 
204  matSloppy(t, rSloppy, tmp);
205 
206  int updateR = 0;
207  if (param.pipeline) {
208  // omega = (t, r) / (t, t)
209  omega_t2 = cDotProductNormACuda(t, rSloppy);
210  Complex tr = Complex(omega_t2.x, omega_t2.y);
211  double t2 = omega_t2.z;
212  omega = tr / t2;
213  double s2 = norm2(rSloppy);
214  Complex r0t = cDotProductCuda(r0, t);
215  beta = -r0t / r0v;
216  r2 = s2 - real(omega * conj(tr)) ;
217 
218  // now we can work out if we need to do a reliable update
219  updateR = reliable(rNorm, maxrx, maxrr, r2, delta);
220  } else {
221  // omega = (t, r) / (t, t)
222  omega_t2 = cDotProductNormACuda(t, rSloppy);
223  omega = Complex(omega_t2.x / omega_t2.z, omega_t2.y / omega_t2.z);
224  }
225 
226  if (param.pipeline && !updateR) {
227  //x += alpha*p + omega*r, r -= omega*t, p = r - beta*omega*v + beta*p
228  caxpbypzYmbwCuda(alpha, p, omega, rSloppy, xSloppy, t);
229  cxpaypbzCuda(rSloppy, -beta*omega, v, beta, p);
230  //tripleBiCGstabUpdate(alpha, p, omega, rSloppy, xSloppy, t, -beta*omega, v, beta, p
231  } else {
232  //x += alpha*p + omega*r, r -= omega*t, r2 = (r,r), rho = (r0, r)
233  rho_r2 = caxpbypzYmbwcDotProductUYNormYCuda(alpha, p, omega, rSloppy, xSloppy, t, r0);
234 
235  rho0 = rho;
236  rho = Complex(rho_r2.x, rho_r2.y);
237  r2 = rho_r2.z;
238  }
239 
240  if (use_heavy_quark_res && k%heavy_quark_check==0) {
241  if (&x != &xSloppy) {
242  copyCuda(tmp,y);
243  heavy_quark_res = sqrt(xpyHeavyQuarkResidualNormCuda(xSloppy, tmp, rSloppy).z);
244  } else {
245  copyCuda(r, rSloppy);
246  heavy_quark_res = sqrt(xpyHeavyQuarkResidualNormCuda(x, y, r).z);
247  }
248  }
249 
250  if (!param.pipeline) updateR = reliable(rNorm, maxrx, maxrr, r2, delta);
251 
252  if (updateR) {
253  if (x.Precision() != xSloppy.Precision()) copyCuda(x, xSloppy);
254 
255  xpyCuda(x, y); // swap these around?
256 
257  mat(r, y, x);
258  r2 = xmyNormCuda(b, r);
259 
260  if (x.Precision() != rSloppy.Precision()) copyCuda(rSloppy, r);
261  zeroCuda(xSloppy);
262 
263  rNorm = sqrt(r2);
264  maxrr = rNorm;
265  maxrx = rNorm;
266  //r0Norm = rNorm;
267  rUpdate++;
268  }
269 
270  k++;
271 
272  PrintStats("BiCGstab", k, r2, b2, heavy_quark_res);
274  printfQuda("BiCGstab debug: x2=%e, r2=%e, v2=%e, p2=%e, tmp2=%e r0=%e t2=%e\n",
275  norm2(x), norm2(rSloppy), norm2(v), norm2(p), norm2(tmp), norm2(r0), norm2(t));
276 
277  // update p
278  if (!param.pipeline || updateR) {// need to update if not pipeline or did a reliable update
279  if (abs(rho*alpha) == 0.0) beta = 0.0;
280  else beta = (rho/rho0) * (alpha/omega);
281  cxpaypbzCuda(rSloppy, -beta*omega, v, beta, p);
282  }
283 
284  }
285 
286  if (x.Precision() != xSloppy.Precision()) copyCuda(x, xSloppy);
287  xpyCuda(y, x);
288 
291 
293  double gflops = (quda::blas_flops + mat.flops() + matSloppy.flops() + matPrecon.flops())*1e-9;
294  reduceDouble(gflops);
295 
296  param.gflops += gflops;
297  param.iter += k;
298 
299  if (k==param.maxiter) warningQuda("Exceeded maximum iterations %d", param.maxiter);
300 
301  if (getVerbosity() >= QUDA_VERBOSE) printfQuda("BiCGstab: Reliable updates = %d\n", rUpdate);
302 
303  if (param.inv_type_precondition != QUDA_GCR_INVERTER) { // do not do the below if this is an inner solver
304  // Calculate the true residual
305  mat(r, x);
306  param.true_res = sqrt(xmyNormCuda(b, r) / b2);
307 #if (__COMPUTE_CAPABILITY__ >= 200)
309 #else
310  param.true_res_hq = 0.0;
311 #endif
312 
313  PrintSummary("BiCGstab", k, r2, b2);
314  }
315 
316  // reset the flops counters
317  quda::blas_flops = 0;
318  mat.flops();
319  matSloppy.flops();
320  matPrecon.flops();
321 
323 
325  if (param.precision_sloppy != x.Precision()) {
326  delete r_0;
327  delete r_sloppy;
328  }
329 
330  if (&x != &xSloppy) delete x_sloppy;
331 
333 
334  return;
335  }
336 
337 } // namespace quda
bool convergence(const double &r2, const double &hq2, const double &r2_tol, const double &hq_tol)
Definition: solver.cpp:82
void setPrecision(QudaPrecision precision)
void caxpyCuda(const Complex &a, cudaColorSpinorField &x, cudaColorSpinorField &y)
Definition: blas_quda.cu:207
static double stopping(const double &tol, const double &b2, QudaResidualType residual_type)
Definition: solver.cpp:65
int y[4]
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:20
__host__ __device__ ValueType sqrt(ValueType x)
Definition: complex_quda.h:105
std::complex< double > Complex
Definition: eig_variables.h:13
void mat(void *out, void **fatlink, void **longlink, void *in, double kappa, int dagger_bit, QudaPrecision sPrecision, QudaPrecision gPrecision)
void operator()(cudaColorSpinorField &out, cudaColorSpinorField &in)
double resNorm(const DiracMatrix &mat, cudaColorSpinorField &b, cudaColorSpinorField &x)
TimeProfile & profile
Definition: invert_quda.h:224
BiCGstab(DiracMatrix &mat, DiracMatrix &matSloppy, DiracMatrix &matPrecon, SolverParam &param, TimeProfile &profile)
QudaInverterType inv_type_precondition
Definition: invert_quda.h:24
void fillInnerSolveParam(SolverParam &inner, const SolverParam &outer)
unsigned long long flops() const
Definition: dirac_quda.h:587
QudaGaugeParam param
Definition: pack_test.cpp:17
cudaColorSpinorField * tmp
void PrintSummary(const char *name, int k, const double &r2, const double &b2)
Definition: solver.cpp:137
double3 caxpbypzYmbwcDotProductUYNormYCuda(const Complex &a, cudaColorSpinorField &x, const Complex &b, cudaColorSpinorField &y, cudaColorSpinorField &z, cudaColorSpinorField &w, cudaColorSpinorField &u)
Definition: reduce_quda.cu:643
QudaResidualType residual_type
Definition: invert_quda.h:35
Complex cDotProductCuda(cudaColorSpinorField &, cudaColorSpinorField &)
Definition: reduce_quda.cu:468
ColorSpinorParam csParam
Definition: pack_test.cpp:24
#define warningQuda(...)
Definition: util_quda.h:84
void copyCuda(cudaColorSpinorField &dst, const cudaColorSpinorField &src)
Definition: copy_quda.cu:235
double normCuda(const cudaColorSpinorField &b)
Definition: reduce_quda.cu:145
int x[4]
unsigned long long blas_flops
Definition: blas_quda.cu:37
double3 xpyHeavyQuarkResidualNormCuda(cudaColorSpinorField &x, cudaColorSpinorField &y, cudaColorSpinorField &r)
Definition: reduce_quda.cu:782
SolverParam & param
Definition: invert_quda.h:223
void xpyCuda(cudaColorSpinorField &x, cudaColorSpinorField &y)
Definition: blas_quda.cu:98
void Stop(QudaProfileType idx)
QudaPrecision Precision() const
void PrintStats(const char *, int k, const double &r2, const double &b2, const double &hq2)
Definition: solver.cpp:122
double Last(QudaProfileType idx)
void reduceDouble(double &)
void cxpaypbzCuda(cudaColorSpinorField &, const Complex &b, cudaColorSpinorField &y, const Complex &c, cudaColorSpinorField &z)
Definition: blas_quda.cu:290
#define printfQuda(...)
Definition: util_quda.h:67
int reliable(double &rNorm, double &maxrx, double &maxrr, const double &r2, const double &delta)
void zeroCuda(cudaColorSpinorField &a)
Definition: blas_quda.cu:40
double3 cDotProductNormACuda(cudaColorSpinorField &a, cudaColorSpinorField &b)
Definition: reduce_quda.cu:591
void Start(QudaProfileType idx)
QudaUseInitGuess use_init_guess
Definition: invert_quda.h:38
void init(int argc, char **argv)
Definition: dslash_test.cpp:79
__host__ __device__ ValueType abs(ValueType x)
Definition: complex_quda.h:110
QudaPrecision precision_sloppy
Definition: invert_quda.h:84
bool use_sloppy_partial_accumulator
Definition: invert_quda.h:44
__host__ __device__ ValueType conj(ValueType x)
Definition: complex_quda.h:115
double3 HeavyQuarkResidualNormCuda(cudaColorSpinorField &x, cudaColorSpinorField &r)
Definition: reduce_quda.cu:777
double norm2(const ColorSpinorField &)
double xmyNormCuda(cudaColorSpinorField &a, cudaColorSpinorField &b)
Definition: reduce_quda.cu:343
void caxpbypzYmbwCuda(const Complex &, cudaColorSpinorField &, const Complex &, cudaColorSpinorField &, cudaColorSpinorField &, cudaColorSpinorField &)
Definition: blas_quda.cu:366