17 Solver(param, profile), mat(mat), matSloppy(matSloppy),
init(false)
54 printfQuda(
"Warning: inverting on zero-field source\n");
104 const bool use_heavy_quark_res =
113 int resIncreaseTotal = 0;
116 const int hqmaxresIncrease = maxResIncrease + 1;
118 double heavy_quark_res = 0.0;
119 double heavy_quark_res_old = 0.0;
120 int hqresIncrease = 0;
121 bool L2breakdown =
false;
136 if (mixed_precision) {
144 if (mixed_precision) {
151 if (use_heavy_quark_res) {
153 heavy_quark_res_old = heavy_quark_res;
166 double rNorm =
sqrt(r2);
167 double r0Norm = rNorm;
168 double maxrx = rNorm;
169 double maxrr = rNorm;
171 bool restart =
false;
174 double rho = 1.0, gamma = 1.0;
178 double gamma_old = gamma;
183 if(k==0 || restart) {
195 rho = rho/(rho-(gamma/gamma_old)*(r2/r2_old));
217 if (use_heavy_quark_res && k%heavy_quark_check==0) {
218 heavy_quark_res_old = heavy_quark_res;
219 if (mixed_precision) {
228 if (mixed_precision) {
231 if (rNorm > maxrx) maxrx = rNorm;
232 if (rNorm > maxrr) maxrr = rNorm;
233 bool update = (rNorm < delta*r0Norm && r0Norm <= maxrx);
234 update = ( update || (rNorm < delta*maxrr && r0Norm <= maxrr));
251 if (use_heavy_quark_res) {
270 warningQuda(
"CG3: new reliable residual norm %e is greater than previous reliable residual norm %e (total #inc %i)",
271 sqrt(r2),
sqrt(r2_old), resIncreaseTotal);
272 if (resIncrease > maxResIncrease or resIncreaseTotal > maxResIncreaseTotal) {
273 if (use_heavy_quark_res) {
276 warningQuda(
"CG3: solver exiting due to too many true residual norm increases");
285 if (use_heavy_quark_res and L2breakdown) {
287 heavy_quark_check = 1;
288 warningQuda(
"CG3: Restarting without reliable updates for heavy-quark residual");
291 if (heavy_quark_res > heavy_quark_res_old) {
293 warningQuda(
"CG3: new reliable HQ residual norm %e is greater than previous reliable residual norm %e", heavy_quark_res, heavy_quark_res_old);
295 if (hqresIncrease > hqmaxresIncrease) {
296 warningQuda(
"CG3: solver exiting due to too many heavy quark residual norm increases");
303 mat(r, x, tmp, tmp2S);
317 warningQuda(
"CG3: new reliable residual norm %e is greater than previous reliable residual norm %e (total #inc %i)",
318 sqrt(r2),
sqrt(r2_old), resIncreaseTotal);
319 if (resIncrease > maxResIncrease or resIncreaseTotal > maxResIncreaseTotal) {
320 warningQuda(
"CG3: solver exiting due to too many true residual norm increases");
326 PrintStats(
"CG3", k, r2, b2, heavy_quark_res);
void setPrecision(QudaPrecision precision, QudaPrecision ghost_precision=QUDA_INVALID_PRECISION, bool force_native=false)
bool convergenceHQ(double r2, double hq2, double r2_tol, double hq_tol)
Test for HQ solver convergence – ignore L2 residual.
double caxpyNorm(const Complex &a, ColorSpinorField &x, ColorSpinorField &y)
double quadrupleCG3InitNorm(double a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, ColorSpinorField &w, ColorSpinorField &v)
double norm2(const ColorSpinorField &a)
__host__ __device__ ValueType sqrt(ValueType x)
Complex cDotProduct(ColorSpinorField &, ColorSpinorField &)
void PrintStats(const char *name, int k, double r2, double b2, double hq2)
Prints out the running statistics of the solver (requires a verbosity of QUDA_VERBOSE) ...
cudaColorSpinorField * tmp
double3 xpyHeavyQuarkResidualNorm(ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &r)
static ColorSpinorField * Create(const ColorSpinorParam &param)
bool convergence(double r2, double hq2, double r2_tol, double hq_tol)
double reDotProduct(ColorSpinorField &x, ColorSpinorField &y)
void copy(ColorSpinorField &dst, const ColorSpinorField &src)
ColorSpinorField * tmp2Sp
double xmyNorm(ColorSpinorField &x, ColorSpinorField &y)
QudaPreserveSource preserve_source
int max_res_increase_total
ColorSpinorField * rS_oldp
void operator()(ColorSpinorField &out, ColorSpinorField &in)
QudaComputeNullVector compute_null_vector
double Last(QudaProfileType idx)
CG3(DiracMatrix &mat, DiracMatrix &matSloppy, SolverParam &param, TimeProfile &profile)
QudaResidualType residual_type
static double stopping(double tol, double b2, QudaResidualType residual_type)
Set the solver L2 stopping condition.
void axpy(double a, ColorSpinorField &x, ColorSpinorField &y)
#define checkLocation(...)
double3 HeavyQuarkResidualNorm(ColorSpinorField &x, ColorSpinorField &r)
std::complex< double > Complex
void init()
Create the CUBLAS context.
void zero(ColorSpinorField &a)
unsigned long long flops() const
ColorSpinorField * xS_oldp
void xpy(ColorSpinorField &x, ColorSpinorField &y)
void axpby(double a, ColorSpinorField &x, double b, ColorSpinorField &y)
double axpyNorm(double a, ColorSpinorField &x, ColorSpinorField &y)
QudaUseInitGuess use_init_guess
QudaPrecision precision_sloppy
void PrintSummary(const char *name, int k, double r2, double b2, double r2_tol, double hq_tol)
Prints out the summary of the solver convergence (requires a verbosity of QUDA_SUMMARIZE). Assumes SolverParam.true_res and SolverParam.true_res_hq have been set.
const DiracMatrix & matSloppy
void mat(void *out, void **link, void *in, int dagger_bit, int mu, QudaPrecision sPrecision, QudaPrecision gPrecision)
QudaPrecision Precision() const
double quadrupleCG3UpdateNorm(double a, double b, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, ColorSpinorField &w, ColorSpinorField &v)