QUDA  1.0.0
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
inv_pcg_quda.cpp
Go to the documentation of this file.
1 #include <cstdio>
2 #include <cstdlib>
3 #include <cmath>
4 #include <iostream>
5 
6 #include <quda_internal.h>
7 #include <color_spinor_field.h>
8 #include <blas_quda.h>
9 #include <dslash_quda.h>
10 #include <invert_quda.h>
11 #include <util_quda.h>
12 
13 namespace quda {
14 
15  using namespace blas;
16 
17  // set the required parameters for the inner solver
18  static void fillInnerSolverParam(SolverParam &inner, const SolverParam &outer)
19  {
    // Configures `inner` (the parameter set handed to the preconditioning
    // solver) from the preconditioner-specific fields of `outer`.
    // NOTE(review): the embedded listing numbers skip 24, 30 and 33-35 in this
    // function, so some assignments may be missing from this extract - confirm
    // against the original file.

    // Tolerance, iteration cap and precision come from the outer solver's
    // *_precondition settings.
20  inner.tol = outer.tol_precondition;
21  inner.maxiter = outer.maxiter_precondition;
22  inner.delta = 1e-20; // no reliable updates within the inner solver
23  inner.precision = outer.precision_precondition; // preconditioners are uni-precision solvers
25 
    // Start the inner solver's bookkeeping (iterations, gflops, seconds) at zero.
26  inner.iter = 0;
27  inner.gflops = 0;
28  inner.secs = 0;
29 
31  inner.is_preconditioner = true; // used to tell the inner solver it is an inner solver
32 
36  }
37 
38 
40  Solver(param, profile), mat(mat), matSloppy(matSloppy), matPrecon(matPrecon), K(0), Kparam(param)
41  {
    // Constructor body: stores the three operators, starts with no inner solver
    // (K == 0) and initializes Kparam from param (see the initializer list
    // above), then picks the inner solver from param.inv_type_precondition.
    // NOTE(review): the signature line and the lines numbered 43 and 45 are
    // missing from this extract; the first branch below is presumably guarded
    // by a check for the CG inverter type - confirm against the original file.
42 
44 
    // Every inner solver visible here is built on matPrecon with Kparam.
46  K = new CG(matPrecon, matPrecon, Kparam, profile);
47  }else if(param.inv_type_precondition == QUDA_MR_INVERTER){
48  K = new MR(matPrecon, matPrecon, Kparam, profile);
49  }else if(param.inv_type_precondition == QUDA_SD_INVERTER){
50  K = new SD(matPrecon, Kparam, profile);
51  }else if(param.inv_type_precondition != QUDA_INVALID_INVERTER){ // unknown preconditioner
52  errorQuda("Unknown inner solver %d", param.inv_type_precondition);
53  }
    // QUDA_INVALID_INVERTER matches none of the visible branches, so K keeps
    // its initial value of 0.
54  }
55 
57  profile.TPSTART(QUDA_PROFILE_FREE);
58 
    // Release the inner solver if one was created; the deletion is bracketed
    // by the QUDA_PROFILE_FREE start/stop calls.
    // NOTE(review): the signature line (embedded number 56) is missing from
    // this extract.
59  if(K) delete K;
60 
61  profile.TPSTOP(QUDA_PROFILE_FREE);
62  }
63 
64 
66  {
67 
68  profile.TPSTART(QUDA_PROFILE_INIT);
69  // Check to see that we're not trying to invert on a zero-field source
70  const double b2 = norm2(b);
71  if(b2 == 0){
72  profile.TPSTOP(QUDA_PROFILE_INIT);
73  printfQuda("Warning: inverting on zero-field source\n");
74  x=b;
75  param.true_res = 0.0;
76  param.true_res_hq = 0.0;
77  }
78 
79  int k=0;
80  int rUpdate=0;
81 
82  cudaColorSpinorField* minvrPre = NULL;
83  cudaColorSpinorField* rPre = NULL;
84  cudaColorSpinorField* minvr = NULL;
85  cudaColorSpinorField* minvrSloppy = NULL;
86  cudaColorSpinorField* p = NULL;
87 
88 
91  if(K) minvr = new cudaColorSpinorField(b);
93  cudaColorSpinorField y(b,csParam);
94 
95  mat(r, x, y); // => r = A*x;
96  double r2 = xmyNorm(b,r);
97 
99  cudaColorSpinorField tmpSloppy(x,csParam);
100  cudaColorSpinorField Ap(x,csParam);
101 
102  cudaColorSpinorField *r_sloppy;
103  if(param.precision_sloppy == x.Precision())
104  {
105  r_sloppy = &r;
106  minvrSloppy = minvr;
107  }else{
108  csParam.create = QUDA_COPY_FIELD_CREATE;
109  r_sloppy = new cudaColorSpinorField(r,csParam);
110  if(K) minvrSloppy = new cudaColorSpinorField(*minvr,csParam);
111  }
112 
113 
114  cudaColorSpinorField *x_sloppy;
115  if(param.precision_sloppy == x.Precision() ||
118  x_sloppy = &static_cast<cudaColorSpinorField&>(x);
119  }else{
120  csParam.create = QUDA_COPY_FIELD_CREATE;
121  x_sloppy = new cudaColorSpinorField(x,csParam);
122  }
123 
124 
125  cudaColorSpinorField &xSloppy = *x_sloppy;
126  cudaColorSpinorField &rSloppy = *r_sloppy;
127 
128  if(&x != &xSloppy){
129  copy(y, x); // copy x to y
130  zero(xSloppy);
131  }else{
132  zero(y); // no reliable updates // NB: check this
133  }
134 
135  const bool use_heavy_quark_res = (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) ? true : false;
136 
137  if(K){
138  csParam.create = QUDA_COPY_FIELD_CREATE;
140  rPre = new cudaColorSpinorField(rSloppy,csParam);
141  // Create minvrPre
142  minvrPre = new cudaColorSpinorField(*rPre);
143  commGlobalReductionSet(false);
144  (*K)(*minvrPre, *rPre);
146  *minvrSloppy = *minvrPre;
147  p = new cudaColorSpinorField(*minvrSloppy);
148  }else{
149  p = new cudaColorSpinorField(rSloppy);
150  }
151 
152 
153  profile.TPSTOP(QUDA_PROFILE_INIT);
154 
155 
157 
158 
159 
160  double stop = stopping(param.tol, b2, param.residual_type); // stopping condition of solver
161  double heavy_quark_res = 0.0; // heavy quark residual
162  if(use_heavy_quark_res) heavy_quark_res = sqrt(HeavyQuarkResidualNorm(x,r).z);
163 
164  double alpha = 0.0, beta=0.0;
165  double pAp;
166  double rMinvr = 0;
167  double rMinvr_old = 0.0;
168  double r_new_Minvr_old = 0.0;
169  double r2_old = 0;
170  r2 = norm2(r);
171 
172  double rNorm = sqrt(r2);
173  double r0Norm = rNorm;
174  double maxrx = rNorm;
175  double maxrr = rNorm;
176  double delta = param.delta;
177 
178 
179  if(K) rMinvr = reDotProduct(rSloppy,*minvrSloppy);
180 
182  profile.TPSTART(QUDA_PROFILE_COMPUTE);
183 
184 
185  blas::flops = 0;
186 
187  const int maxResIncrease = param.max_res_increase; // check if we reached the limit of our tolerance
188  const int maxResIncreaseTotal = param.max_res_increase_total;
189 
190  int resIncrease = 0;
191  int resIncreaseTotal = 0;
192 
193  while(!convergence(r2, heavy_quark_res, stop, param.tol_hq) && k < param.maxiter){
194 
195  matSloppy(Ap, *p, tmpSloppy);
196 
197  double sigma;
198  pAp = reDotProduct(*p,Ap);
199 
200  alpha = (K) ? rMinvr/pAp : r2/pAp;
201  Complex cg_norm = axpyCGNorm(-alpha, Ap, rSloppy);
202  // r --> r - alpha*A*p
203  r2_old = r2;
204  r2 = real(cg_norm);
205 
206  sigma = imag(cg_norm) >= 0.0 ? imag(cg_norm) : r2; // use r2 if (r_k+1, r_k-1 - r_k) breaks
207 
208  if(K) rMinvr_old = rMinvr;
209 
210  rNorm = sqrt(r2);
211  if(rNorm > maxrx) maxrx = rNorm;
212  if(rNorm > maxrr) maxrr = rNorm;
213 
214 
215  int updateX = (rNorm < delta*r0Norm && r0Norm <= maxrx) ? 1 : 0;
216  int updateR = ((rNorm < delta*maxrr && r0Norm <= maxrr) || updateX) ? 1 : 0;
217 
218 
219  // force a reliable update if we are within target tolerance (only if doing reliable updates)
220  if( convergence(r2, heavy_quark_res, stop, param.tol_hq) && delta >= param.tol) updateX = 1;
221 
222 
223  if( !(updateR || updateX) ){
224 
225  if(K){
226  r_new_Minvr_old = reDotProduct(rSloppy,*minvrSloppy);
227  *rPre = rSloppy;
228  commGlobalReductionSet(false);
229  (*K)(*minvrPre, *rPre);
231 
232 
233  *minvrSloppy = *minvrPre;
234 
235  rMinvr = reDotProduct(rSloppy,*minvrSloppy);
236  beta = (rMinvr - r_new_Minvr_old)/rMinvr_old;
237  axpyZpbx(alpha, *p, xSloppy, *minvrSloppy, beta);
238  }else{
239  beta = sigma/r2_old; // use the alternative beta computation
240  axpyZpbx(alpha, *p, xSloppy, rSloppy, beta);
241  }
242  } else { // reliable update
243 
244  axpy(alpha, *p, xSloppy); // xSloppy += alpha*p
245  copy(x, xSloppy);
246  xpy(x, y); // y += x
247  // Now compute r
248  mat(r, y, x); // x is just a temporary here
249  r2 = xmyNorm(b, r);
250  copy(rSloppy, r); // copy r to rSloppy
251  zero(xSloppy);
252 
253 
254  // break-out check if we have reached the limit of the precision
255  if(sqrt(r2) > r0Norm && updateX) {
256  resIncrease++;
257  resIncreaseTotal++;
258  // reuse r0Norm for this
259  warningQuda("PCG: new reliable residual norm %e is greater than previous reliable residual norm %e (total #inc %i)", sqrt(r2), r0Norm, resIncreaseTotal);
260 
261  if (resIncrease > maxResIncrease or resIncreaseTotal > maxResIncreaseTotal) break;
262 
263  } else {
264  resIncrease = 0;
265  }
266 
267  rNorm = sqrt(r2);
268  maxrr = rNorm;
269  maxrx = rNorm;
270  r0Norm = rNorm;
271  ++rUpdate;
272 
273  if(K){
274  *rPre = rSloppy;
275  commGlobalReductionSet(false);
276  (*K)(*minvrPre, *rPre);
278 
279  *minvrSloppy = *minvrPre;
280 
281  rMinvr = reDotProduct(rSloppy,*minvrSloppy);
282  beta = rMinvr/rMinvr_old;
283 
284  xpay(*minvrSloppy, beta, *p); // p = minvrSloppy + beta*p
285  }else{ // standard CG - no preconditioning
286 
287  // explicitly restore the orthogonality of the gradient vector
288  double rp = reDotProduct(rSloppy, *p)/(r2);
289  axpy(-rp, rSloppy, *p);
290 
291  beta = r2/r2_old;
292  xpay(rSloppy, beta, *p);
293  }
294  }
295  ++k;
296  PrintStats("PCG", k, r2, b2, heavy_quark_res);
297  }
298 
299 
301 
303 
304  if(x.Precision() != param.precision_sloppy) copy(x, xSloppy);
305  xpy(y, x); // x += y
306 
307 
309  double gflops = (blas::flops + mat.flops() + matSloppy.flops() + matPrecon.flops())*1e-9;
310  param.gflops = gflops;
311  param.iter += k;
312 
313  if (k==param.maxiter)
314  warningQuda("Exceeded maximum iterations %d", param.maxiter);
315 
316  if (getVerbosity() >= QUDA_VERBOSE)
317  printfQuda("CG: Reliable updates = %d\n", rUpdate);
318 
319 
320 
321 
322 
323  // compute the true residual
324  mat(r, x, y);
325  double true_res = xmyNorm(b, r);
326  param.true_res = sqrt(true_res / b2);
327 
328  // reset the flops counters
329  blas::flops = 0;
330  mat.flops();
331  matSloppy.flops();
332  matPrecon.flops();
333 
335  profile.TPSTART(QUDA_PROFILE_FREE);
336 
337  if(K){ // These are only needed if preconditioning is used
338  delete minvrPre;
339  delete rPre;
340  delete minvr;
341  if(x.Precision() != param.precision_sloppy) delete minvrSloppy;
342  }
343  delete p;
344 
345  if(x.Precision() != param.precision_sloppy){
346  delete x_sloppy;
347  delete r_sloppy;
348  }
349 
350  profile.TPSTOP(QUDA_PROFILE_FREE);
351  return;
352  }
353 
354 
355 } // namespace quda
void setPrecision(QudaPrecision precision, QudaPrecision ghost_precision=QUDA_INVALID_PRECISION, bool force_native=false)
void operator()(ColorSpinorField &out, ColorSpinorField &in)
void axpyZpbx(double a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, double b)
Definition: blas_quda.cu:552
QudaInverterType inv_type
Definition: invert_quda.h:22
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:21
#define errorQuda(...)
Definition: util_quda.h:121
__host__ __device__ ValueType sqrt(ValueType x)
Definition: complex_quda.h:120
void PrintStats(const char *name, int k, double r2, double b2, double hq2)
Prints out the running statistics of the solver (requires a verbosity of QUDA_VERBOSE) ...
Definition: solver.cpp:256
bool convergence(double r2, double hq2, double r2_tol, double hq_tol)
Definition: solver.cpp:223
double reDotProduct(ColorSpinorField &x, ColorSpinorField &y)
Definition: reduce_quda.cu:728
Complex axpyCGNorm(double a, ColorSpinorField &x, ColorSpinorField &y)
Definition: reduce_quda.cu:796
__host__ __device__ void copy(T1 &a, const T2 &b)
TimeProfile & profile
Definition: invert_quda.h:464
double xmyNorm(ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.h:75
QudaInverterType inv_type_precondition
Definition: invert_quda.h:28
QudaPreserveSource preserve_source
Definition: invert_quda.h:154
int max_res_increase_total
Definition: invert_quda.h:96
void xpay(ColorSpinorField &x, double a, ColorSpinorField &y)
Definition: blas_quda.h:37
double norm2(const CloverField &a, bool inverse=false)
QudaGaugeParam param
Definition: pack_test.cpp:17
virtual ~PreconCG()
double Last(QudaProfileType idx)
Definition: timer.h:251
const DiracMatrix & matPrecon
Definition: invert_quda.h:693
QudaResidualType residual_type
Definition: invert_quda.h:49
static void fillInnerSolverParam(SolverParam &inner, const SolverParam &outer)
static double stopping(double tol, double b2, QudaResidualType residual_type)
Set the solver L2 stopping condition.
Definition: solver.cpp:206
ColorSpinorParam csParam
Definition: pack_test.cpp:24
const DiracMatrix & mat
Definition: invert_quda.h:691
#define warningQuda(...)
Definition: util_quda.h:133
bool is_preconditioner
verbosity to use for preconditioner
Definition: invert_quda.h:241
double3 HeavyQuarkResidualNorm(ColorSpinorField &x, ColorSpinorField &r)
Definition: reduce_quda.cu:809
std::complex< double > Complex
Definition: quda_internal.h:46
double tol_precondition
Definition: invert_quda.h:199
QudaPrecision precision_precondition
Definition: invert_quda.h:151
QudaPrecision precision
Definition: invert_quda.h:142
SolverParam & param
Definition: invert_quda.h:463
Conjugate-Gradient Solver.
Definition: invert_quda.h:570
unsigned long long flops() const
Definition: dirac_quda.h:1119
#define printfQuda(...)
Definition: util_quda.h:115
SolverParam Kparam
Definition: invert_quda.h:696
unsigned long long flops
Definition: blas_quda.cu:22
void xpy(ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.h:33
QudaPrecision precision_sloppy
Definition: invert_quda.h:145
__device__ void axpy(real a, const real *x, Link &y)
bool use_sloppy_partial_accumulator
Definition: invert_quda.h:76
PreconCG(DiracMatrix &mat, DiracMatrix &matSloppy, DiracMatrix &matPrecon, SolverParam &param, TimeProfile &profile)
void mat(void *out, void **link, void *in, int dagger_bit, int mu, QudaPrecision sPrecision, QudaPrecision gPrecision)
QudaPrecision Precision() const
__device__ __host__ void zero(vector_type< scalar, n > &v)
Definition: cub_helper.cuh:54
const DiracMatrix & matSloppy
Definition: invert_quda.h:692
void updateR()
update the radius for halos.
void commGlobalReductionSet(bool global_reduce)