QUDA  1.0.0
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
inv_mr_quda.cpp
Go to the documentation of this file.
1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <math.h>
4 
5 #include <complex>
6 
7 #include <quda_internal.h>
8 #include <blas_quda.h>
9 #include <dslash_quda.h>
10 #include <invert_quda.h>
11 #include <util_quda.h>
12 #include <color_spinor_field.h>
13 
14 namespace quda {
15 
17  Solver(param, profile), mat(mat), matSloppy(matSloppy), rp(nullptr), r_sloppy(nullptr),
18  Arp(nullptr), tmpp(nullptr), tmp_sloppy(nullptr), x_sloppy(nullptr), init(false)
19  {
20  if (param.schwarz_type == QUDA_MULTIPLICATIVE_SCHWARZ && param.Nsteps % 2 == 1) {
21  errorQuda("For multiplicative Schwarz, number of solver steps %d must be even", param.Nsteps);
22  }
23  }
24 
// Destructor: releases the work fields held through the solver's pointers.
// Nothing is freed unless `init` is set; the solve routine sets `init` only
// after it has run its allocation section, and the constructor starts every
// one of these pointers at nullptr.
// NOTE(review): the embedded listing numbers jump over 26 and 35, so this view
// may be missing statements from the original file - confirm against the repository.
25  MR::~MR() {
27  if (init) {
// Each pointer is tested before deletion: the allocation section creates
// several of these fields conditionally (e.g. r_sloppy and tmp_sloppy are
// assigned nullptr for some precision settings), so any of them may be null here.
28  if (x_sloppy) delete x_sloppy;
29  if (tmp_sloppy) delete tmp_sloppy;
30  if (tmpp) delete tmpp;
31  if (Arp) delete Arp;
32  if (r_sloppy) delete r_sloppy;
33  if (rp) delete rp;
34  }
36  }
37 
39  {
40  if (checkPrecision(x,b) != param.precision) errorQuda("Precision mismatch %d %d", checkPrecision(x,b), param.precision);
41 
42  if (param.maxiter == 0 || param.Nsteps == 0) {
44  return;
45  }
46 
47  if (!init) {
48  bool mixed = param.precision != param.precision_sloppy;
49 
52 
53  // Source needs to be preserved if we're computing the true residual
55  || param.Nsteps > 1 || param.compute_true_res == 1) ?
56  ColorSpinorField::Create(csParam) : nullptr;
57 
59  ColorSpinorField::Create(csParam) : nullptr;
60 
61  // now allocate sloppy fields
63 
64  r_sloppy = mixed ? ColorSpinorField::Create(csParam) : nullptr; // we need a separate sloppy residual vector
65  Arp = ColorSpinorField::Create(csParam);
66 
67  //sloppy temporary for mat-vec
68  tmp_sloppy = (!tmpp || mixed) ? ColorSpinorField::Create(csParam) : nullptr;
69 
70  // iterated sloppy solution vector
72 
73  init = true;
74  } // init
75 
76  ColorSpinorField &r = rp ? *rp : b;
77  ColorSpinorField &rSloppy = r_sloppy ? *r_sloppy : r;
78  ColorSpinorField &Ar = *Arp;
79  ColorSpinorField &tmp = tmpp ? *tmpp : b;
80  ColorSpinorField &tmpSloppy = tmp_sloppy ? *tmp_sloppy : tmp;
81  ColorSpinorField &xSloppy = *x_sloppy;
82 
83  if (!param.is_preconditioner) {
84  blas::flops = 0;
86  }
87 
88  double b2 = blas::norm2(b); //Save norm of b
89  double r2 = 0.0; // if zero source then we will exit immediately doing no work
91  mat(r, x, tmp);
92  r2 = blas::xmyNorm(b, r); //r = b - Ax0
93  } else {
94  r2 = b2;
95  blas::copy(r, b);
96  blas::zero(x); // needed?
97  }
98  blas::copy(rSloppy, r);
99 
100  // if invalid residual then convergence is set by iteration count only
101  double stop = param.residual_type == QUDA_INVALID_RESIDUAL ? 0.0 : b2*param.tol*param.tol;
102  int step = 0;
103 
104  if (getVerbosity() >= QUDA_VERBOSE) printfQuda("MR: Initial residual = %e\n", sqrt(r2));
105 
106  bool converged = false;
107  while (!converged) {
108 
109  double scale = 1.0;
110  if ((node_parity+step)%2 == 0 && param.schwarz_type == QUDA_MULTIPLICATIVE_SCHWARZ) {
111  // for multiplicative Schwarz we alternate updates depending on node parity
112  } else {
113 
114  commGlobalReductionSet(param.global_reduction); // use local reductions for DD solver
115 
116  blas::zero(xSloppy); // can get rid of this for a special first update kernel
117  double c2 = param.global_reduction == QUDA_BOOLEAN_TRUE ? r2 : blas::norm2(r); // c2 holds the initial r2
118  scale = c2 > 0.0 ? sqrt(c2) : 1.0;
119 
120  // domain-wise normalization of the initial residual to prevent underflow
121  if (c2 > 0.0) {
122  blas::ax(1/scale, rSloppy); // can merge this with the prior copy
123  r2 = 1.0; // by definition this is now true
124  }
125 
126  int k = 0;
127  if (getVerbosity() >= QUDA_VERBOSE) printfQuda("MR: %d cycle, %d iterations, r2 = %e\n", step, k, r2);
128 
129  double3 Ar3;
130  while (k < param.maxiter && r2 > 0.0) {
131 
132  matSloppy(Ar, rSloppy, tmpSloppy);
133 
134  if (param.global_reduction) {
135  Ar3 = blas::cDotProductNormA(Ar, rSloppy);
136  Complex alpha = Complex(Ar3.x, Ar3.y) / Ar3.z;
137 
138  // x += omega*alpha*r, r -= omega*alpha*Ar, r2 = blas::norm2(r)
139  //r2 = blas::caxpyXmazNormX(omega*alpha, r, x, Ar);
140  blas::caxpyXmaz(param.omega*alpha, rSloppy, xSloppy, Ar);
141 
142  if (getVerbosity() >= QUDA_VERBOSE)
143  printfQuda("MR: %d cycle, %d iterations, <r|A|r> = (%e, %e)\n", step, k+1, Ar3.x, Ar3.y);
144  } else {
145  // doing local reductions so can make it asynchronous
146  commAsyncReductionSet(true);
147  Ar3 = blas::cDotProductNormA(Ar, rSloppy);
148 
149  // omega*alpha is done in the kernel
150  blas::caxpyXmazMR(param.omega, rSloppy, xSloppy, Ar);
151  commAsyncReductionSet(false);
152  }
153  k++;
154 
155  }
156 
157  // Scale and sum to accumulator
158  blas::axpy(scale,xSloppy,x);
159 
160  commGlobalReductionSet(true); // re-enable global reductions for outer solver
161 
162  }
163  step++;
164 
165  // FIXME - add over/under relaxation in outer loop
166  if (param.compute_true_res || param.Nsteps > 1) {
167  mat(r, x, tmp);
168  r2 = blas::xmyNorm(b, r);
169  param.true_res = sqrt(r2 / b2);
170 
171  converged = (step < param.Nsteps && r2 > stop) ? false : true;
172 
173  // if not preserving source and finished then override source with residual
174  if (param.preserve_source == QUDA_PRESERVE_SOURCE_NO && converged) blas::copy(b, r);
175  else blas::copy(rSloppy, r);
176 
177  if (getVerbosity() >= QUDA_SUMMARIZE) {
178  printfQuda("MR: %d cycle, Converged after %d iterations, relative residual: true = %e\n",
179  step, param.maxiter, sqrt(r2));
180  }
181  } else {
182 
183  blas::ax(scale, rSloppy);
184  r2 = blas::norm2(rSloppy);
185 
186  converged = (step < param.Nsteps) ? false : true;
187 
188  // if not preserving source and finished then override source with residual
189  if (param.preserve_source == QUDA_PRESERVE_SOURCE_NO && converged) blas::copy(b, rSloppy);
190  else blas::copy(r, rSloppy);
191 
192  if (getVerbosity() >= QUDA_SUMMARIZE) {
193  printfQuda("MR: %d cycle, Converged after %d iterations, relative residual: iterated = %e\n",
194  step, param.maxiter, sqrt(r2));
195  }
196  }
197 
198  }
199 
200  if (!param.is_preconditioner) {
204 
205  // store flops and reset counters
206  double gflops = (blas::flops + mat.flops() + matSloppy.flops())*1e-9;
207 
208  param.gflops += gflops;
210  blas::flops = 0;
211 
213  }
214 
215  return;
216  }
217 
218 } // namespace quda
void ax(double a, ColorSpinorField &x)
Definition: blas_quda.cu:508
bool global_reduction
whether the solver uses global reductions (true) or node-local reductions (false)
Definition: invert_quda.h:243
QudaSchwarzType schwarz_type
Definition: invert_quda.h:217
void setPrecision(QudaPrecision precision, QudaPrecision ghost_precision=QUDA_INVALID_PRECISION, bool force_native=false)
void caxpyXmazMR(const Complex &a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
Definition: blas_quda.cu:603
double3 cDotProductNormA(ColorSpinorField &a, ColorSpinorField &b)
Definition: reduce_quda.cu:778
MR(DiracMatrix &mat, DiracMatrix &matSloppy, SolverParam &param, TimeProfile &profile)
Definition: inv_mr_quda.cpp:16
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:21
#define checkPrecision(...)
#define errorQuda(...)
Definition: util_quda.h:121
double norm2(const ColorSpinorField &a)
Definition: reduce_quda.cu:721
__host__ __device__ ValueType sqrt(ValueType x)
Definition: complex_quda.h:120
cudaColorSpinorField * tmp
Definition: covdev_test.cpp:44
static ColorSpinorField * Create(const ColorSpinorParam &param)
TimeProfile & profile
Definition: invert_quda.h:464
void copy(ColorSpinorField &dst, const ColorSpinorField &src)
Definition: copy_quda.cu:355
double xmyNorm(ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.h:75
QudaPreserveSource preserve_source
Definition: invert_quda.h:154
QudaGaugeParam param
Definition: pack_test.cpp:17
double Last(QudaProfileType idx)
Definition: timer.h:251
const DiracMatrix & matSloppy
Definition: invert_quda.h:867
QudaResidualType residual_type
Definition: invert_quda.h:49
ColorSpinorParam csParam
Definition: pack_test.cpp:24
void axpy(double a, ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.h:35
ColorSpinorField * Arp
Definition: invert_quda.h:870
bool is_preconditioner
whether the solver is acting as a preconditioner for another solver
Definition: invert_quda.h:241
std::complex< double > Complex
Definition: quda_internal.h:46
ColorSpinorField * rp
Definition: invert_quda.h:868
bool init
Definition: invert_quda.h:874
void init()
Create the CUBLAS context.
Definition: blas_cublas.cu:31
void zero(ColorSpinorField &a)
Definition: blas_quda.cu:472
void commAsyncReductionSet(bool global_reduce)
QudaPrecision precision
Definition: invert_quda.h:142
ColorSpinorField * tmp_sloppy
Definition: invert_quda.h:872
SolverParam & param
Definition: invert_quda.h:463
void operator()(ColorSpinorField &out, ColorSpinorField &in)
Definition: inv_mr_quda.cpp:38
void caxpyXmaz(const Complex &a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
Definition: blas_quda.cu:597
unsigned long long flops() const
Definition: dirac_quda.h:1119
ColorSpinorField * r_sloppy
Definition: invert_quda.h:869
#define printfQuda(...)
Definition: util_quda.h:115
ColorSpinorField * tmpp
Definition: invert_quda.h:871
unsigned long long flops
Definition: blas_quda.cu:22
virtual ~MR()
Definition: inv_mr_quda.cpp:25
const DiracMatrix & mat
Definition: invert_quda.h:866
ColorSpinorField * x_sloppy
Definition: invert_quda.h:873
QudaUseInitGuess use_init_guess
Definition: invert_quda.h:64
QudaPrecision precision_sloppy
Definition: invert_quda.h:145
void mat(void *out, void **link, void *in, int dagger_bit, int mu, QudaPrecision sPrecision, QudaPrecision gPrecision)
void commGlobalReductionSet(bool global_reduce)