QUDA  v0.7.0
A library for QCD on GPUs
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
solver.cpp
Go to the documentation of this file.
1 #include <quda_internal.h>
2 #include <invert_quda.h>
3 #include <cmath>
4 
5 namespace quda {
6 
7  static void report(const char *type) {
8  if (getVerbosity() >= QUDA_VERBOSE) printfQuda("Creating a %s solver\n", type);
9  }
10 
11  // solver factory
// Solver::create -- build the concrete Solver selected by param.inv_type.
// Each branch announces its choice through report() (printed only at
// QUDA_VERBOSE or higher) and allocates the solver with new; ownership of the
// returned pointer presumably passes to the caller -- TODO confirm.
// Operator usage differs by solver: CG takes mat and matSloppy; BiCGstab,
// GCR and PreconCG also take matPrecon; MR, SD, XSD, MPCG and MPBiCGstab take
// only mat. An unrecognised inv_type reaches errorQuda("Invalid solver type").
// NOTE(review): this listing drops source line 12; the Doxygen entry for
// solver.cpp:12 gives the signature as
//   static Solver* create(SolverParam &param, DiracMatrix &mat,
//       DiracMatrix &matSloppy, DiracMatrix &matPrecon, TimeProfile &profile)
// Source lines 22 and 54 are also missing -- presumably the case labels of the
// BiCGstab and MPBiCGstab branches; confirm against the original file.
13  DiracMatrix &matPrecon, TimeProfile &profile)
14  {
15  Solver *solver=0;
16 
17  switch (param.inv_type) {
18  case QUDA_CG_INVERTER:
19  report("CG");
20  solver = new CG(mat, matSloppy, param, profile);
21  break;
23  report("BiCGstab");
24  solver = new BiCGstab(mat, matSloppy, matPrecon, param, profile);
25  break;
26  case QUDA_GCR_INVERTER:
27  report("GCR");
28  solver = new GCR(mat, matSloppy, matPrecon, param, profile);
29  break;
30  case QUDA_MR_INVERTER:
31  report("MR");
32  solver = new MR(mat, param, profile);
33  break;
34  case QUDA_SD_INVERTER:
35  report("SD");
36  solver = new SD(mat, param, profile);
37  break;
// XSD is only built when MULTI_GPU is defined; otherwise errorQuda is raised
// and solver stays 0 (if errorQuda returns at all -- presumably it aborts).
38  case QUDA_XSD_INVERTER:
39 #ifdef MULTI_GPU
40  report("XSD");
41  solver = new XSD(mat, param, profile);
42 #else
43  errorQuda("Extended Steepest Descent is multi-gpu only");
44 #endif
45  break;
46  case QUDA_PCG_INVERTER:
47  report("PCG");
48  solver = new PreconCG(mat, matSloppy, matPrecon, param, profile);
49  break;
50  case QUDA_MPCG_INVERTER:
51  report("MPCG");
52  solver = new MPCG(mat, param, profile);
53  break;
55  report("MPBICGSTAB");
56  solver = new MPBiCGstab(mat, param, profile);
57  break;
58  default:
59  errorQuda("Invalid solver type");
60  }
61 
// NOTE(review): 0 is returned here if errorQuda returns on an error path.
62  return solver;
63  }
64 
65  double Solver::stopping(const double &tol, const double &b2, QudaResidualType residual_type) {
66 
67  double stop=0.0;
68  if ( (residual_type & QUDA_L2_ABSOLUTE_RESIDUAL) &&
69  (residual_type & QUDA_L2_RELATIVE_RESIDUAL) ) {
70  // use the most stringent stopping condition
71  double lowest = (b2 < 1.0) ? b2 : 1.0;
72  stop = lowest*tol*tol;
73  } else if (residual_type & QUDA_L2_ABSOLUTE_RESIDUAL) {
74  stop = tol*tol;
75  } else {
76  stop = b2*tol*tol;
77  }
78 
79  return stop;
80  }
81 
// Solver::convergence -- returns false while a requested residual criterion is
// unmet, true otherwise. The heavy-quark test applies when
// QUDA_HEAVY_QUARK_RESIDUAL is set in param.residual_type and compares hq2 with
// hq_tol. The L2 test compares r2 with r2_tol; r2_tol is presumably the squared
// threshold produced by Solver::stopping -- confirm at call sites.
// NOTE(review): this listing drops source line 91, which opens the L2
// condition; its visible tail (line 92) tests QUDA_L2_ABSOLUTE_RESIDUAL, and the
// head presumably tests QUDA_L2_RELATIVE_RESIDUAL -- confirm against the file.
82  bool Solver::convergence(const double &r2, const double &hq2, const double &r2_tol,
83  const double &hq_tol) {
84  //printf("converge: L2 %e / %e and HQ %e / %e\n", r2, r2_tol, hq2, hq_tol);
85 
86  // check the heavy quark residual norm if necessary
87  if ( (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) && (hq2 > hq_tol) )
88  return false;
89 
90  // check the L2 (relative or absolute) residual norm if necessary
92  (param.residual_type & QUDA_L2_ABSOLUTE_RESIDUAL)) && (r2 > r2_tol) )
93  return false;
94 
95  return true;
96  }
97 
98 //
99  bool Solver::convergenceHQ(const double &r2, const double &hq2, const double &r2_tol,
100  const double &hq_tol) {
101  //printf("converge: L2 %e / %e and HQ %e / %e\n", r2, r2_tol, hq2, hq_tol);
102 
103  // check the heavy quark residual norm if necessary
104  if ( (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) && (hq2 > hq_tol) )
105  return false;
106 
107  return true;
108  }
109 
// Solver::convergenceL2 -- L2-only convergence test: returns false while a
// requested L2 criterion is unmet (r2 > r2_tol), true otherwise. In the visible
// lines hq2 and hq_tol are accepted but not consulted.
// NOTE(review): this listing drops source line 115, which opens the condition
// whose tail (line 116) tests QUDA_L2_ABSOLUTE_RESIDUAL; the head presumably
// tests QUDA_L2_RELATIVE_RESIDUAL -- confirm against the original file.
110  bool Solver::convergenceL2(const double &r2, const double &hq2, const double &r2_tol,
111  const double &hq_tol) {
112  //printf("converge: L2 %e / %e and HQ %e / %e\n", r2, r2_tol, hq2, hq_tol);
113 
114  // check the L2 (relative or absolute) residual norm if necessary
116  (param.residual_type & QUDA_L2_ABSOLUTE_RESIDUAL)) && (r2 > r2_tol) )
117  return false;
118 
119  return true;
120  }
121 
// Solver::PrintStats -- per-iteration progress line plus a divergence check.
// At QUDA_VERBOSE or higher, prints the solver name, iteration count k,
// r2 (<r,r>) and sqrt(r2/b2) (|r|/|b|); one variant also prints hq2.
// Regardless of verbosity, a NaN r2 triggers
// errorQuda("Solver appears to have diverged").
// NOTE(review): source line 125 is missing from this listing; from the brace
// structure it opens the branch that prints the heavy-quark variant, presumably
// `if (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) {` -- confirm.
122  void Solver::PrintStats(const char* name, int k, const double &r2,
123  const double &b2, const double &hq2) {
124  if (getVerbosity() >= QUDA_VERBOSE) {
126  printfQuda("%s: %d iterations, <r,r> = %e, |r|/|b| = %e, heavy-quark residual = %e\n",
127  name, k, r2, sqrt(r2/b2), hq2);
128  } else {
129  printfQuda("%s: %d iterations, <r,r> = %e, |r|/|b| = %e\n",
130  name, k, r2, sqrt(r2/b2));
131  }
132  }
133 
// The NaN check runs outside the verbosity guard, so divergence is always caught.
134  if (std::isnan(r2)) errorQuda("Solver appears to have diverged");
135  }
136 
// Solver::PrintSummary -- final convergence report, printed at QUDA_SUMMARIZE
// or higher. It reports the iteration count k, the iterated relative residual
// sqrt(r2/b2) and param.true_res; one variant also prints param.true_res_hq.
// NOTE(review): source line 139 is missing from this listing; it presumably
// opens the heavy-quark branch, e.g.
// `if (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) {` -- confirm.
137  void Solver::PrintSummary(const char *name, int k, const double &r2, const double &b2) {
138  if (getVerbosity() >= QUDA_SUMMARIZE) {
140  printfQuda("%s: Convergence at %d iterations, L2 relative residual: iterated = %e, true = %e, heavy-quark residual = %e\n", name, k, sqrt(r2/b2), param.true_res, param.true_res_hq);
141  } else {
142  printfQuda("%s: Convergence at %d iterations, L2 relative residual: iterated = %e, true = %e\n",
143  name, k, sqrt(r2/b2), param.true_res);
144  }
145 
146  }
147  }
148 
149  // Deflated solver factory
// DeflatedSolver::create -- build the deflated solver requested in param.
// The only solver constructed here is IncEigCG (reported as
// "Incremental EIGCG"), which receives all four operators (mat, matSloppy,
// matCGSloppy, matDeflate). Any other request reaches
// errorQuda("Invalid solver type") and, if that returns, yields 0.
// NOTE(review): this listing drops source line 150; the Doxygen entry for
// solver.cpp:150 gives the signature as
//   static DeflatedSolver* create(SolverParam &param, DiracMatrix &mat,
//       DiracMatrix &matSloppy, DiracMatrix &matCGSloppy,
//       DiracMatrix &matDeflate, TimeProfile &profile)
// Source line 154 (the condition selecting IncEigCG, presumably a test on
// param.inv_type) is also missing -- confirm against the original file.
151  {
152  DeflatedSolver* solver=0;
153 
155  report("Incremental EIGCG");
156  solver = new IncEigCG(mat, matSloppy, matCGSloppy, matDeflate, param, profile);
157  }else{
158  errorQuda("Invalid solver type");
159  }
160 
161  return solver;
162  }
163 
164  bool DeflatedSolver::convergence(const double &r2, const double &hq2, const double &r2_tol,
165  const double &hq_tol) {
166  //printf("converge: L2 %e / %e and HQ %e / %e\n", r2, r2_tol, hq2, hq_tol);
167 
168  // check the heavy quark residual norm if necessary
169  if ( (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) && (hq2 > hq_tol) )
170  return false;
171 
172  // check the L2 relative residual norm if necessary
173  if ( (param.residual_type & QUDA_L2_RELATIVE_RESIDUAL) && (r2 > r2_tol) )
174  return false;
175 
176  return true;
177  }
178 
// DeflatedSolver::PrintStats -- per-iteration progress line plus a divergence
// check; the visible body mirrors Solver::PrintStats. At QUDA_VERBOSE or higher
// it prints the solver name, iteration count k, r2 (<r,r>) and sqrt(r2/b2)
// (|r|/|b|); one variant also prints hq2. Regardless of verbosity, a NaN r2
// triggers errorQuda("Solver appears to have diverged").
// NOTE(review): source line 182 is missing from this listing; from the brace
// structure it opens the heavy-quark branch, presumably
// `if (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) {` -- confirm.
179  void DeflatedSolver::PrintStats(const char* name, int k, const double &r2,
180  const double &b2, const double &hq2) {
181  if (getVerbosity() >= QUDA_VERBOSE) {
183  printfQuda("%s: %d iterations, <r,r> = %e, |r|/|b| = %e, heavy-quark residual = %e\n",
184  name, k, r2, sqrt(r2/b2), hq2);
185  } else {
186  printfQuda("%s: %d iterations, <r,r> = %e, |r|/|b| = %e\n",
187  name, k, r2, sqrt(r2/b2));
188  }
189  }
190 
// The NaN check runs outside the verbosity guard, so divergence is always caught.
191  if (std::isnan(r2)) errorQuda("Solver appears to have diverged");
192  }
193 
// DeflatedSolver::PrintSummary -- final convergence report, printed at
// QUDA_SUMMARIZE or higher; the visible body mirrors Solver::PrintSummary.
// It reports the iteration count k, the iterated relative residual sqrt(r2/b2)
// and param.true_res; one variant also prints param.true_res_hq.
// NOTE(review): source line 196 is missing from this listing; it presumably
// opens the heavy-quark branch, e.g.
// `if (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) {` -- confirm.
194  void DeflatedSolver::PrintSummary(const char *name, int k, const double &r2, const double &b2) {
195  if (getVerbosity() >= QUDA_SUMMARIZE) {
197  printfQuda("%s: Convergence at %d iterations, L2 relative residual: iterated = %e, true = %e, heavy-quark residual = %e\n", name, k, sqrt(r2/b2), param.true_res, param.true_res_hq);
198  } else {
199  printfQuda("%s: Convergence at %d iterations, L2 relative residual: iterated = %e, true = %e\n",
200  name, k, sqrt(r2/b2), param.true_res);
201  }
202 
203  }
204  }
205 
206 
207 } // namespace quda
bool convergence(const double &r2, const double &hq2, const double &r2_tol, const double &hq_tol)
Definition: solver.cpp:82
static double stopping(const double &tol, const double &b2, QudaResidualType residual_type)
Definition: solver.cpp:65
enum QudaResidualType_s QudaResidualType
QudaInverterType inv_type
Definition: invert_quda.h:18
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:20
#define errorQuda(...)
Definition: util_quda.h:73
__host__ __device__ ValueType sqrt(ValueType x)
Definition: complex_quda.h:105
bool convergenceL2(const double &r2, const double &hq2, const double &r2_tol, const double &hq_tol)
Definition: solver.cpp:110
void mat(void *out, void **fatlink, void **longlink, void *in, double kappa, int dagger_bit, QudaPrecision sPrecision, QudaPrecision gPrecision)
QudaGaugeParam param
Definition: pack_test.cpp:17
void PrintSummary(const char *name, int k, const double &r2, const double &b2)
Definition: solver.cpp:137
QudaResidualType residual_type
Definition: invert_quda.h:35
static Solver * create(SolverParam &param, DiracMatrix &mat, DiracMatrix &matSloppy, DiracMatrix &matPrecon, TimeProfile &profile)
Definition: solver.cpp:12
bool convergence(const double &r2, const double &hq2, const double &r2_tol, const double &hq_tol)
Definition: solver.cpp:164
void PrintSummary(const char *name, int k, const double &r2, const double &b2)
Definition: solver.cpp:194
SolverParam & param
Definition: invert_quda.h:223
static DeflatedSolver * create(SolverParam &param, DiracMatrix &mat, DiracMatrix &matSloppy, DiracMatrix &matCGSloppy, DiracMatrix &matDeflate, TimeProfile &profile)
Definition: solver.cpp:150
void PrintStats(const char *, int k, const double &r2, const double &b2, const double &hq2)
Definition: solver.cpp:122
SolverParam & param
Definition: invert_quda.h:533
#define printfQuda(...)
Definition: util_quda.h:67
void PrintStats(const char *, int k, const double &r2, const double &b2, const double &hq2)
Definition: solver.cpp:179
bool convergenceHQ(const double &r2, const double &hq2, const double &r2_tol, const double &hq_tol)
Definition: solver.cpp:99