QUDA  0.9.0
inv_mr_quda.cpp
Go to the documentation of this file.
1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <math.h>
4 
5 #include <complex>
6 
7 #include <quda_internal.h>
8 #include <blas_quda.h>
9 #include <dslash_quda.h>
10 #include <invert_quda.h>
11 #include <util_quda.h>
12 #include <color_spinor_field.h>
13 
14 namespace quda {
15 
// MR constructor: member-initializer list and (empty) body.
// The signature line (listing line 16) is missing from this extract; the cross-reference
// index of this page gives it as
//   MR(DiracMatrix &mat, DiracMatrix &matSloppy, SolverParam &param, TimeProfile &profile).
// Forwards param/profile to the Solver base, stores the two operators, and clears the
// init / allocate_r / allocate_y flags so that no work field is treated as allocated yet.
17  Solver(param, profile), mat(mat), matSloppy(matSloppy), init(false), allocate_r(false), allocate_y(false)
18  {
19 
20  }
21 
// Destructor: releases the work fields owned by this solver.
// Nothing is deleted unless init is set; inside that guard Arp and tmpp are always
// deleted, while rp and yp are deleted only if their allocate_r / allocate_y flag is set.
// NOTE(review): listing lines 23 and 31 are absent from this extract (numbering jumps
// 22 -> 24 and 30 -> 32); their content is not visible here.
22  MR::~MR() {
24  if (init) {
25  if (allocate_r) delete rp; // rp is only owned when allocate_r was set
26  delete Arp;
27  delete tmpp;
28  if (allocate_y) delete yp; // yp is only owned when allocate_y was set
29 
30  }
32  }
33 
// Body of the MR solve operator. The signature line (listing line 34) is missing from
// this extract; the body refers to the solution field as x and to the source field as b.
// Outline of the visible code:
//   1. on first use create the work fields (guarded by init / allocate_r / allocate_y),
//   2. form the initial residual r and its squared norm r2, remember it in c2,
//   3. scale r by 1/sqrt(c2),
//   4. run the update loop, accumulating the solution in y,
//   5. scale y by sqrt(c2) into x, restore the scale of r, and report residuals.
// NOTE(review): many listing lines are absent (gaps in the embedded numbering), so names
// such as csParam and tmp, and several `if` headers, are introduced on lines that are not
// visible here.
35  {
36  commGlobalReductionSet(param.global_reduction); // use local reductions for DD solver
37 
38  if (!init) {
 // one-time setup; listing lines 39-40 and 42 are missing from this extract
41  csParam.precision = param.precision_sloppy;
43  tmpp = ColorSpinorField::Create(csParam); //temporary for mat-vec
44  init = true;
45  }
46 
47  //Source needs to be preserved if initial guess is used or if different precision is requested
48  if(!allocate_r &&
 // remainder of this condition and the allocation (listing lines 49-51, 53) are missing
52  csParam.precision = param.precision_sloppy;
54  allocate_r = true;
55  }
56 
57  // y is the (sloppy) iterated solution vector
58  if (!allocate_y) {
 // listing lines 59-60 and 62 are missing from this extract
61  csParam.precision = param.precision_sloppy;
63  allocate_y = true;
64  }
65 
66  ColorSpinorField &r = allocate_r ? *rp : b; // r aliases b when no separate residual field was allocated
67  ColorSpinorField &Ar = *Arp;
 // listing line 68 is missing; presumably it binds tmp (used below) - TODO confirm
69  ColorSpinorField &y = *yp;
70 
71  double r2=0.0; // if zero source then we will exit immediately doing no work
 // listing line 72 (the `if` matching the `} else {` below) is missing;
 // presumably it selects the initial-guess path - TODO confirm
73  blas::copy(tmp, x);
74  matSloppy(r, tmp, Ar);
75  blas::copy(y, b);
76  r2 = blas::xmyNorm(y, r); //r = b - Ax0
77  } else {
 // other path: the residual is just the source, and x is zeroed
78  if (&r != &b) blas::copy(r, b); // skip the copy when r already aliases b
79  r2 = blas::norm2(r);
80  blas::zero(x);
81  }
82 
83  // set initial guess to zero and thus the residual is just the source
84  blas::zero(y); // can get rid of this for a special first update kernel
85  double b2 = blas::norm2(b); //Save norm of b
86  double c2 = r2; //c2 holds the initial r2 after (possible) subtraction of initial guess
87 
88  // domain-wise normalization of the initial residual to prevent underflow
89  if (c2 > 0.0) {
90  blas::ax(1/sqrt(c2), r); // can merge this with the prior copy
91  r2 = 1.0; // by definition this is now true
92  }
93 
94  if (!param.is_preconditioner) {
95  blas::flops = 0;
 // listing line 96 is missing from this extract
97  }
98 
99  double omega = param.omega;
100 
101  int k = 0; // iteration counter
102  if (getVerbosity() >= QUDA_DEBUG_VERBOSE) {
103  double x2 = blas::norm2(y);
 // NOTE(review): before this point the visible lines only pass Ar as the third argument
 // of matSloppy (listing line 74); whether Ar holds a meaningful value here is not
 // established by the visible code - confirm.
104  double3 Ar3 = blas::cDotProductNormB(Ar, r);
105  printfQuda("MR: %d iterations, r2 = %e, <r|A|r> = (%e, %e), x2 = %e\n",
106  k, Ar3.z, Ar3.x, Ar3.y, x2);
107  } else if (getVerbosity() >= QUDA_VERBOSE) {
108  printfQuda("MR: %d iterations, r2 = %e\n", k, r2);
109  }
110 
111  double3 Ar3;
 // NOTE(review): the outer r2 is not assigned anywhere inside this loop (the r2 on
 // listing line 136 is a new local), so once entered the loop runs until k reaches
 // param.maxiter.
112  while (k < param.maxiter && r2 > 0.0) {
113 
114  matSloppy(Ar, r, tmp);
115 
116  if (param.global_reduction) {
117  Ar3 = blas::cDotProductNormA(Ar, r);
118  Complex alpha = Complex(Ar3.x, Ar3.y) / Ar3.z; // complex (x,y) divided by the norm component z
119 
120  // x += omega*alpha*r, r -= omega*alpha*Ar, r2 = blas::norm2(r)
121  //r2 = blas::caxpyXmazNormX(omega*alpha, r, x, Ar);
122  blas::caxpyXmaz(omega*alpha, r, y, Ar);
123  } else {
124  // doing local reductions so can make it asynchronous
125  commAsyncReductionSet(true);
126  Ar3 = blas::cDotProductNormA(Ar, r);
127 
128  // omega*alpha is done in the kernel
129  blas::caxpyXmazMR(omega, r, y, Ar);
130  commAsyncReductionSet(false);
131  }
132  k++;
133 
134  if (getVerbosity() >= QUDA_DEBUG_VERBOSE) {
135  double x2 = blas::norm2(y);
136  double r2 = blas::norm2(r); // local to this branch; shadows the outer r2
 // NOTE(review): k was already incremented on listing line 132, so this branch prints
 // k+1 while the QUDA_VERBOSE branch below prints k - confirm which is intended.
137  printfQuda("MR: %d iterations, r2 = %e, <r|A|r> = (%e,%e) x2 = %e\n",
138  k+1, r2, Ar3.x, Ar3.y, x2);
139  } else if (getVerbosity() >= QUDA_VERBOSE) {
140  printfQuda("MR: %d iterations, <r|A|r> = (%e, %e)\n", k, Ar3.x, Ar3.y);
141  }
142  }
143 
144  //Add back initial guess (if appropriate) and scale if necessary
 // listing line 145 (the `if` matching the `} else {` below) is missing from this extract
146  double scale = c2 > 0.0 ? sqrt(c2) : 1.0; // undo the 1/sqrt(c2) normalization applied above
147  blas::axpy(scale,y,x);
148  } else {
149  if (c2 > 0.0) blas::axpby(sqrt(c2), y, 0.0, x); // FIXME: if x contains a NaN then this will fail: hence zero of x above
150  }
151  // if not preserving source then override source with residual
152  if (param.preserve_source == QUDA_PRESERVE_SOURCE_NO && &r != &b) {
153  if (c2 > 0.0) blas::axpby(sqrt(c2), r, 0.0, b);
154  } else {
155  if (c2 > 0.0) blas::ax(sqrt(c2), r); // rescale r in place
156  }
157 
158  if (!param.is_preconditioner) {
 // listing lines 159-161 are missing from this extract
162 
163  double gflops = (blas::flops + mat.flops() + matSloppy.flops())*1e-9;
164 
165  param.gflops += gflops;
166  param.iter += k;
167 
168  // compute the iterated relative residual
169  if (getVerbosity() >= QUDA_SUMMARIZE) r2 = blas::norm2(r) / b2;
170 
171  // calculate the true sloppy residual
 // listing line 172 (presumably the `if` matching the `} else {` on listing line 181) is missing
173  mat(r, x, tmp);
174  double true_res = blas::xmyNorm(b, r);
175  param.true_res = sqrt(true_res / b2);
176 
177  if (getVerbosity() >= QUDA_SUMMARIZE) {
178  printfQuda("MR: Converged after %d iterations, relative residual: iterated = %e, true = %e\n",
179  k, sqrt(r2), param.true_res);
180  }
181  } else {
182  if (getVerbosity() >= QUDA_SUMMARIZE) {
183  printfQuda("MR: Converged after %d iterations, relative residual: iterated = %e\n", k, sqrt(r2));
184  }
185  }
186 
187  // reset the flops counters
188  blas::flops = 0;
189  mat.flops(); // return value discarded; presumably reading the counter also resets it - TODO confirm
 // listing line 190 is missing from this extract
191  }
192 
193  commGlobalReductionSet(true); // re-enable global reductions for outer solver
194  return;
195  }
196 
197 } // namespace quda
bool global_reduction
whether the solver is acting as a preconditioner for another solver
Definition: invert_quda.h:201
void caxpyXmazMR(const Complex &a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
Definition: blas_quda.cu:583
double3 cDotProductNormA(ColorSpinorField &a, ColorSpinorField &b)
Definition: reduce_quda.cu:572
MR(DiracMatrix &mat, DiracMatrix &matSloppy, SolverParam &param, TimeProfile &profile)
Definition: inv_mr_quda.cpp:16
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:20
ColorSpinorField * yp
Definition: invert_quda.h:653
double norm2(const ColorSpinorField &a)
Definition: reduce_quda.cu:241
void init()
Definition: blas_quda.cu:64
__host__ __device__ ValueType sqrt(ValueType x)
Definition: complex_quda.h:105
std::complex< double > Complex
Definition: eig_variables.h:13
cudaColorSpinorField * tmp
Definition: covdev_test.cpp:44
static ColorSpinorField * Create(const ColorSpinorParam &param)
TimeProfile & profile
Definition: invert_quda.h:329
void copy(ColorSpinorField &dst, const ColorSpinorField &src)
Definition: copy_quda.cu:263
void ax(const double &a, ColorSpinorField &x)
Definition: blas_quda.cu:209
double xmyNorm(ColorSpinorField &x, ColorSpinorField &y)
Definition: reduce_quda.cu:364
QudaPreserveSource preserve_source
Definition: invert_quda.h:121
bool allocate_y
Definition: invert_quda.h:656
QudaGaugeParam param
Definition: pack_test.cpp:17
#define b
double Last(QudaProfileType idx)
const DiracMatrix & matSloppy
Definition: invert_quda.h:649
ColorSpinorParam csParam
Definition: pack_test.cpp:24
ColorSpinorField * Arp
Definition: invert_quda.h:651
bool is_preconditioner
verbosity to use for preconditioner
Definition: invert_quda.h:199
ColorSpinorField * rp
Definition: invert_quda.h:650
bool init
Definition: invert_quda.h:654
void zero(ColorSpinorField &a)
Definition: blas_quda.cu:45
void commAsyncReductionSet(bool global_reduce)
void axpy(const double &a, ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.cu:150
void axpby(const double &a, ColorSpinorField &x, const double &b, ColorSpinorField &y)
Definition: blas_quda.cu:106
SolverParam & param
Definition: invert_quda.h:328
bool allocate_r
Definition: invert_quda.h:655
void operator()(ColorSpinorField &out, ColorSpinorField &in)
Definition: inv_mr_quda.cpp:34
void caxpyXmaz(const Complex &a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
Definition: blas_quda.cu:549
unsigned long long flops() const
Definition: dirac_quda.h:995
#define printfQuda(...)
Definition: util_quda.h:84
ColorSpinorField * tmpp
Definition: invert_quda.h:652
unsigned long long flops
Definition: blas_quda.cu:42
virtual ~MR()
Definition: inv_mr_quda.cpp:22
const DiracMatrix & mat
Definition: invert_quda.h:648
QudaUseInitGuess use_init_guess
Definition: invert_quda.h:50
QudaPrecision precision_sloppy
Definition: invert_quda.h:115
double omega
Definition: test_util.cpp:1663
double3 cDotProductNormB(ColorSpinorField &a, ColorSpinorField &b)
Definition: reduce_quda.cu:599
void mat(void *out, void **link, void *in, int dagger_bit, int mu, QudaPrecision sPrecision, QudaPrecision gPrecision)
void commGlobalReductionSet(bool global_reduce)