QUDA  v1.1.0
A library for QCD on GPUs
inv_mr_quda.cpp
Go to the documentation of this file.
1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <math.h>
4 
5 #include <complex>
6 
7 #include <quda_internal.h>
8 #include <blas_quda.h>
9 #include <dslash_quda.h>
10 #include <invert_quda.h>
11 #include <util_quda.h>
12 #include <color_spinor_field.h>
13 
14 namespace quda {
15 
// MR constructor: passes mat to the Solver base and matSloppy for the remaining three
// operator slots. All work-field pointers start as nullptr and init is false; the fields
// are allocated lazily on the first call to operator().
16  MR::MR(const DiracMatrix &mat, const DiracMatrix &matSloppy, SolverParam &param, TimeProfile &profile) :
17  Solver(mat, matSloppy, matSloppy, matSloppy, param, profile),
18  rp(nullptr),
19  r_sloppy(nullptr),
20  Arp(nullptr),
21  tmpp(nullptr),
22  tmp_sloppy(nullptr),
23  x_sloppy(nullptr),
24  init(false)
25  {
// NOTE(review): listing line 26, which opens the block closed at listing line 28, is
// missing from this rendering. Judging from the error text it presumably tests
// param.schwarz_type == QUDA_MULTIPLICATIVE_SCHWARZ together with an odd param.Nsteps
// -- confirm against the real source. The solve alternates which nodes update on each
// step under multiplicative Schwarz, which is why an even step count is required.
27  errorQuda("For multiplicative Schwarz, number of solver steps %d must be even", param.Nsteps);
28  }
29  }
30 
// Destructor: frees the work fields created by the first solve (only if one happened,
// i.e. init is true). Depending on the solve parameters some fields (e.g. r_sloppy when
// precisions match) are never created, hence the per-pointer checks; note that delete on
// a null pointer is already a no-op in C++, so the checks are defensive only.
// NOTE(review): listing lines 32 and 41 are missing from this rendering; the
// cross-references list QUDA_PROFILE_FREE, so they presumably start/stop profiling
// around these frees -- confirm against the real source.
31  MR::~MR() {
33  if (init) {
34  if (x_sloppy) delete x_sloppy;
35  if (tmp_sloppy) delete tmp_sloppy;
36  if (tmpp) delete tmpp;
37  if (Arp) delete Arp;
38  if (r_sloppy) delete r_sloppy;
39  if (rp) delete rp;
40  }
42  }
43 
// MR::operator() -- the signature line (listing line 44) is missing from this rendering.
// Per the cross-reference it is void operator()(ColorSpinorField &, ColorSpinorField &);
// the body names the parameters x (solution, updated in place) and b (source).
//
// Minimal-residual solve of mat * x = b, organised as outer cycles (at most param.Nsteps).
// Each cycle:
//   1. normalises the current residual to unit norm;
//   2. runs param.maxiter MR iterations with matSloppy on a zeroed correction xSloppy;
//   3. rescales the correction and accumulates it into x.
//
// Visible side effects:
//  - errorQuda if checkPrecision(x, b) differs from param.precision;
//  - b is overwritten with the final residual when preserve_source is
//    QUDA_PRESERVE_SOURCE_NO;
//  - param.true_res is set when the true residual is recomputed;
//  - when not a preconditioner, param.gflops is incremented and blas::flops reset;
//  - during a cycle the reduction mode is set from param.global_reduction, and global
//    reductions are re-enabled afterwards.
45  {
46  if (checkPrecision(x,b) != param.precision) errorQuda("Precision mismatch %d %d", checkPrecision(x,b), param.precision);
47
// Nothing to iterate: return immediately.
// NOTE(review): listing line 49 is missing; presumably it zeroes x when
// param.use_init_guess == QUDA_USE_INIT_GUESS_NO -- confirm.
48  if (param.maxiter == 0 || param.Nsteps == 0) {
50  return;
51  }
52
// One-time allocation of work fields; reused by later calls until the destructor runs.
// NOTE(review): listing lines 56-57, 60, 62, 64-65, 71 and 77 are missing. They
// presumably declare csParam (cross-referenced with QUDA_NULL_FIELD_CREATE) and create
// rp, Arp, tmpp and x_sloppy. Listing line 61 is the tail of the condition deciding
// whether a separate residual field rp is allocated (otherwise r aliases b below)
// -- confirm against the real source.
53  if (!init) {
54  bool mixed = param.precision != param.precision_sloppy;
55
58
59  // Source needs to be preserved if we're computing the true residual
61  || param.Nsteps > 1 || param.compute_true_res == 1) ?
63
66
67  // now allocate sloppy fields
68  csParam.setPrecision(param.precision_sloppy);
69
70  r_sloppy = mixed ? ColorSpinorField::Create(csParam) : nullptr; // we need a separate sloppy residual vector
72
73  //sloppy temporary for mat-vec
74  tmp_sloppy = (!tmpp || mixed) ? ColorSpinorField::Create(csParam) : nullptr;
75
76  // iterated sloppy solution vector
78
79  init = true;
80  } // init
81
// Bind work references. Fields that were not allocated fall back to aliases:
//   r -> b, rSloppy -> r, tmp -> b, tmpSloppy -> tmp.
// So without rp the residual is formed in place in b.
82  ColorSpinorField &r = rp ? *rp : b;
83  ColorSpinorField &rSloppy = r_sloppy ? *r_sloppy : r;
84  ColorSpinorField &Ar = *Arp;
85  ColorSpinorField &tmp = tmpp ? *tmpp : b;
86  ColorSpinorField &tmpSloppy = tmp_sloppy ? *tmp_sloppy : tmp;
87  ColorSpinorField &xSloppy = *x_sloppy;
88
// Top-level solve: reset the BLAS flop counter. Listing line 91 is missing (the
// cross-referenced QUDA_PROFILE_COMPUTE suggests a profiling call -- confirm).
89  if (!param.is_preconditioner) {
90  blas::flops = 0;
92  }
93
// Initial residual: r = b - mat*x when starting from x; otherwise r = b and x = 0.
// NOTE(review): the opening `if` (listing line 96) is missing; presumably it tests
// param.use_init_guess == QUDA_USE_INIT_GUESS_YES -- confirm.
94  double b2 = blas::norm2(b); //Save norm of b
95  double r2 = 0.0; // if zero source then we will exit immediately doing no work
97  mat(r, x, tmp);
98  r2 = blas::xmyNorm(b, r); //r = b - Ax0
99  } else {
100  r2 = b2;
101  blas::copy(r, b);
102  blas::zero(x); // needed?
103  }
104  blas::copy(rSloppy, r);
105
106  // if invalid residual then convergence is set by iteration count only
// NOTE(review): listing line 107 is missing. It presumably defines `stop`, the
// squared-residual threshold compared with r2 below: 0 for QUDA_INVALID_RESIDUAL,
// otherwise derived from param.tol and b2 -- confirm.
108  int step = 0;
109
110  if (getVerbosity() >= QUDA_VERBOSE) printfQuda("MR: Initial residual = %e\n", sqrt(r2));
111
// Outer cycle loop; each pass increments step and re-evaluates convergence.
112  bool converged = false;
113  while (!converged) {
114
115  double scale = 1.0;
116  if ((node_parity+step)%2 == 0 && param.schwarz_type == QUDA_MULTIPLICATIVE_SCHWARZ) {
117  // for multiplicative Schwarz we alternate updates depending on node parity
// On these steps no correction is computed and scale stays 1.0.
118  } else {
119
120  commGlobalReductionSet(param.global_reduction); // use local reductions for DD solver
121
// With global reductions, reuse the current r2. Otherwise take the norm of r under the
// reduction mode just set, i.e. node-local when global_reduction is false.
// NOTE(review): in that local case, if c2 == 0 while the global r2 > 0, r2 is not reset
// below and the inner loop still runs on an all-zero local residual. Whether
// caxpyXmazMR guards the resulting 0/0 is not visible here.
122  blas::zero(xSloppy); // can get rid of this for a special first update kernel
123  double c2 = param.global_reduction == QUDA_BOOLEAN_TRUE ? r2 : blas::norm2(r); // c2 holds the initial r2
124  scale = c2 > 0.0 ? sqrt(c2) : 1.0;
125
126  // domain-wise normalization of the initial residual to prevent underflow
127  if (c2 > 0.0) {
128  blas::ax(1/scale, rSloppy); // can merge this with the prior copy
129  r2 = 1.0; // by definition this is now true
130  }
131
132  int k = 0;
133  if (getVerbosity() >= QUDA_VERBOSE) printfQuda("MR: %d cycle, %d iterations, r2 = %e\n", step, k, r2);
134
// Inner MR iterations. r2 is never updated inside this loop, so once r2 > 0 it runs
// exactly param.maxiter iterations; there is no inner convergence test.
135  double3 Ar3;
136  while (k < param.maxiter && r2 > 0.0) {
137
138  matSloppy(Ar, rSloppy, tmpSloppy);
139
140  if (param.global_reduction) {
// alpha = <Ar, r> / |Ar|^2. This assumes cDotProductNormA returns the complex dot
// product in .x/.y and the squared norm of its first argument in .z, as its name
// and the division below suggest.
141  Ar3 = blas::cDotProductNormA(Ar, rSloppy);
142  Complex alpha = Complex(Ar3.x, Ar3.y) / Ar3.z;
143
144  // x += omega*alpha*r, r -= omega*alpha*Ar, r2 = blas::norm2(r)
145  //r2 = blas::caxpyXmazNormX(omega*alpha, r, x, Ar);
146  blas::caxpyXmaz(param.omega*alpha, rSloppy, xSloppy, Ar);
147
148  if (getVerbosity() >= QUDA_VERBOSE)
149  printfQuda("MR: %d cycle, %d iterations, <r|A|r> = (%e, %e)\n", step, k+1, Ar3.x, Ar3.y);
150  } else {
151  // doing local reductions so can make it asynchronous
152  commAsyncReductionSet(true);
153  Ar3 = blas::cDotProductNormA(Ar, rSloppy);
154
155  // omega*alpha is done in the kernel
156  blas::caxpyXmazMR(param.omega, rSloppy, xSloppy, Ar);
157  commAsyncReductionSet(false);
158  }
159  k++;
160
161  }
162
163  // Scale and sum to accumulator: x += scale * xSloppy undoes the residual normalization
164  blas::axpy(scale,xSloppy,x);
165
166  commGlobalReductionSet(true); // re-enable global reductions for outer solver
167
168  }
169  step++;
170
171  // FIXME - add over/under relaxation in outer loop
// True-residual path: recompute r = b - mat*x. Converged once Nsteps cycles are done
// or r2 <= stop.
172  if (param.compute_true_res || param.Nsteps > 1) {
173  mat(r, x, tmp);
174  r2 = blas::xmyNorm(b, r);
175  param.true_res = sqrt(r2 / b2);
176
177  converged = (step < param.Nsteps && r2 > stop) ? false : true;
178
179  // if not preserving source and finished then override source with residual
180  if (param.preserve_source == QUDA_PRESERVE_SOURCE_NO && converged) blas::copy(b, r);
181  else blas::copy(rSloppy, r);
182
// NOTE(review): this message has three problems:
//  - it is printed after every cycle, converged or not;
//  - %d is param.maxiter (a per-cycle count) rather than a cumulative count;
//  - the value is sqrt(r2), an absolute residual, despite the "relative residual" label
//    (param.true_res holds the relative value).
183  if (getVerbosity() >= QUDA_SUMMARIZE) {
184  printfQuda("MR: %d cycle, Converged after %d iterations, relative residual: true = %e\n",
185  step, param.maxiter, sqrt(r2));
186  }
187  } else {
188
// Iterated-residual path: undo the unit normalization of rSloppy and take its norm.
// Convergence is decided by the cycle count only.
189  blas::ax(scale, rSloppy);
190  r2 = blas::norm2(rSloppy);
191
192  converged = (step < param.Nsteps) ? false : true;
193
194  // if not preserving source and finished then override source with residual
195  if (param.preserve_source == QUDA_PRESERVE_SOURCE_NO && converged) blas::copy(b, rSloppy);
196  else blas::copy(r, rSloppy);
197
// NOTE(review): as above, this prints every cycle, reports param.maxiter, and labels an
// absolute residual as relative.
198  if (getVerbosity() >= QUDA_SUMMARIZE) {
199  printfQuda("MR: %d cycle, Converged after %d iterations, relative residual: iterated = %e\n",
200  step, param.maxiter, sqrt(r2));
201  }
202  }
203
204  }
205
// Top-level solve: fold BLAS and operator flop counts into param.gflops.
// NOTE(review): listing lines 207-209, 215 and 218 are missing. Given the
// cross-referenced QUDA_PROFILE_EPILOGUE and TimeProfile::Last, they presumably handle
// profiling/timing -- confirm.
206  if (!param.is_preconditioner) {
210
211  // store flops and reset counters
212  double gflops = (blas::flops + mat.flops() + matSloppy.flops())*1e-9;
213
214  param.gflops += gflops;
216  blas::flops = 0;
217
219  }
220
221  return;
222  }
223 
224 } // namespace quda
static ColorSpinorField * Create(const ColorSpinorParam &param)
unsigned long long flops() const
Definition: dirac_quda.h:1909
MR(const DiracMatrix &mat, const DiracMatrix &matSloppy, SolverParam &param, TimeProfile &profile)
Definition: inv_mr_quda.cpp:16
virtual ~MR()
Definition: inv_mr_quda.cpp:31
void operator()(ColorSpinorField &out, ColorSpinorField &in)
Definition: inv_mr_quda.cpp:44
TimeProfile & profile
Definition: invert_quda.h:471
const DiracMatrix & mat
Definition: invert_quda.h:465
SolverParam & param
Definition: invert_quda.h:470
const DiracMatrix & matSloppy
Definition: invert_quda.h:466
double Last(QudaProfileType idx)
Definition: timer.h:254
void commAsyncReductionSet(bool global_reduce)
void commGlobalReductionSet(bool global_reduce)
void mat(void *out, void **link, void *in, int dagger_bit, int mu, QudaPrecision sPrecision, QudaPrecision gPrecision)
cudaColorSpinorField * tmp
Definition: covdev_test.cpp:34
@ QUDA_USE_INIT_GUESS_NO
Definition: enum_quda.h:429
@ QUDA_USE_INIT_GUESS_YES
Definition: enum_quda.h:430
@ QUDA_SUMMARIZE
Definition: enum_quda.h:266
@ QUDA_VERBOSE
Definition: enum_quda.h:267
@ QUDA_BOOLEAN_TRUE
Definition: enum_quda.h:461
@ QUDA_INVALID_RESIDUAL
Definition: enum_quda.h:196
@ QUDA_PRESERVE_SOURCE_NO
Definition: enum_quda.h:238
@ QUDA_PRESERVE_SOURCE_YES
Definition: enum_quda.h:239
@ QUDA_MULTIPLICATIVE_SCHWARZ
Definition: enum_quda.h:188
@ QUDA_NULL_FIELD_CREATE
Definition: enum_quda.h:360
#define checkPrecision(...)
void init()
Create the BLAS context.
double xmyNorm(ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.h:79
unsigned long long flops
void caxpyXmaz(const Complex &a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
void ax(double a, ColorSpinorField &x)
void caxpyXmazMR(const double &a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
void zero(ColorSpinorField &a)
double norm2(const ColorSpinorField &a)
void axpy(double a, ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.h:43
double3 cDotProductNormA(ColorSpinorField &a, ColorSpinorField &b)
void copy(ColorSpinorField &dst, const ColorSpinorField &src)
Definition: blas_quda.h:24
void stop()
Stop profiling.
Definition: device.cpp:228
std::complex< double > Complex
Definition: quda_internal.h:86
__host__ __device__ ValueType sqrt(ValueType x)
Definition: complex_quda.h:120
@ QUDA_PROFILE_EPILOGUE
Definition: timer.h:110
@ QUDA_PROFILE_COMPUTE
Definition: timer.h:108
@ QUDA_PROFILE_FREE
Definition: timer.h:111
ColorSpinorParam csParam
Definition: pack_test.cpp:25
QudaGaugeParam param
Definition: pack_test.cpp:18
QudaPreserveSource preserve_source
Definition: invert_quda.h:151
QudaPrecision precision
Definition: invert_quda.h:136
bool is_preconditioner
whether the solver is acting as a preconditioner for another solver
Definition: invert_quda.h:238
QudaSchwarzType schwarz_type
Definition: invert_quda.h:214
QudaResidualType residual_type
Definition: invert_quda.h:49
QudaPrecision precision_sloppy
Definition: invert_quda.h:139
QudaUseInitGuess use_init_guess
Definition: invert_quda.h:58
bool global_reduction
whether to use global (rather than node-local) reductions in this solver
Definition: invert_quda.h:240
#define printfQuda(...)
Definition: util_quda.h:114
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:21
#define errorQuda(...)
Definition: util_quda.h:120