QUDA  v0.5.0
A library for QCD on GPUs
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
inv_mr_quda.cpp
Go to the documentation of this file.
1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <math.h>
4 
5 #include <complex>
6 
7 #include <quda_internal.h>
8 #include <blas_quda.h>
9 #include <dslash_quda.h>
10 #include <invert_quda.h>
11 #include <util_quda.h>
12 
13 #include<face_quda.h>
14 
15 #include <color_spinor_field.h>
16 
17 namespace quda {
18 
20  Solver(invParam, profile), mat(mat), init(false), allocate_r(false)
21  {
    // MR constructor (its signature line is not part of this listing).
    // Forwards invParam/profile to the Solver base class and stores the
    // operator. No work fields are allocated here: init starts false and the
    // fields are created later by the solve routine the first time it runs.
22 
23  }
24 
// MR destructor: releases the work fields created by the first solve call.
// Nothing was allocated if init is still false. rp is deleted only when this
// object allocated it itself (allocate_r). Arp and tmpp are always created
// together with init being set, so they are deleted unconditionally here.
25  MR::~MR() {
27  if (init) {
28  if (allocate_r) delete rp;
29  delete Arp;
30  delete tmpp;
31  }
33  }
34 
36  {
    // MR solve routine (its signature line is not part of this listing).
    // Solves mat * x = b by a minimal-residual iteration:
    //   - x starts at zero and r starts as b scaled to unit norm;
    //   - each sweep computes Ar = mat(r) and a complex step alpha from
    //     cDotProductNormACuda(Ar, r) (presumably <Ar,r>/|Ar|^2 -- confirm);
    //   - it then applies x += omega*alpha*r and r -= omega*alpha*Ar.
    // x is rescaled by |b| at the end. Reductions are kept local for the
    // duration of the call and re-enabled globally before returning.
37 
38  globalReduce = false; // use local reductions for DD solver
39 
    // Lazily allocate the work fields on the first call only.
40  if (!init) {
44  rp = new cudaColorSpinorField(x, param);
45  allocate_r = true;
46  }
47  Arp = new cudaColorSpinorField(x);
48  tmpp = new cudaColorSpinorField(x, param); //temporary for mat-vec
49 
50  init = true;
51  }
54  cudaColorSpinorField &Ar = *Arp;
55  cudaColorSpinorField &tmp = *tmpp;
56 
57  // set initial guess to zero and thus the residual is just the source
58  zeroCuda(x); // can get rid of this for a special first update kernel
59  double b2 = normCuda(b);
60  if (&r != &b) copyCuda(r, b);
61 
62  // domain-wise normalization of the initial residual to prevent underflow
63  double r2=0.0; // if zero source then we will exit immediately doing no work
64  if (b2 > 0.0) {
65  axCuda(1/sqrt(b2), r); // can merge this with the prior copy
66  r2 = 1.0; // by definition this is now true
67  }
    // NOTE(review): r was just scaled to unit norm, so r2 is a *relative*
    // squared residual. stop, however, is b2*tol^2 in unscaled units, so
    // the two are on different scales whenever b2 != 1. Is tol*tol intended?
68  double stop = b2*invParam.tol*invParam.tol; // stopping condition of solver
69 
71  quda::blas_flops = 0;
73  }
74 
75  double omega = 1.0;
76 
77  int k = 0;
    // NOTE(review): mat() has not yet written Ar in this call, so this
    // diagnostic reads whatever Ar already held. The arguments also do not
    // line up with the labels: "r2 = %e" receives Ar3.x. If Ar3.z holds
    // |r|^2 (as the NormB name suggests -- confirm), it should be passed first.
79  double x2 = norm2(x);
80  double3 Ar3 = cDotProductNormBCuda(Ar, r);
81  printfQuda("MR: %d iterations, r2 = %e, <r|A|r> = (%e, %e), x2 = %e\n",
82  k, Ar3.x, Ar3.y, Ar3.z, x2);
83  }
84 
    // NOTE(review): r2 is never reassigned inside this loop. The updating
    // caxpyXmazNormXCuda call is commented out, and the r2 in the debug
    // branch is a new shadowing local. For b2 > 0 the loop therefore always
    // runs invParam.maxiter sweeps. Confirm a fixed sweep count is intended.
85  while (r2 > stop && k < invParam.maxiter) {
86 
87  mat(Ar, r, tmp);
88 
89  double3 Ar3 = cDotProductNormACuda(Ar, r);
90  Complex alpha = Complex(Ar3.x, Ar3.y) / Ar3.z;
91 
92  // x += omega*alpha*r, r -= omega*alpha*Ar, r2 = norm2(r)
93  //r2 = caxpyXmazNormXCuda(omega*alpha, r, x, Ar);
94  caxpyXmazCuda(omega*alpha, r, x, Ar);
95 
97  double x2 = norm2(x);
    // Local r2 shadows the outer residual norm; it is used only for this print.
98  double r2 = norm2(r);
99  printfQuda("MR: %d iterations, r2 = %e, <r|A|r> = (%e,%e) x2 = %e\n",
100  k+1, r2, Ar3.x, Ar3.y, x2);
101  } else if (invParam.verbosity >= QUDA_VERBOSE) {
    // NOTE(review): this branch reports k, while the branch above reports k+1
    // for the same sweep (both run before k++).
102  printfQuda("MR: %d iterations, <r|A|r> = (%e, %e)\n", k, Ar3.x, Ar3.y);
103  }
104 
105  k++;
106  }
107 
109  mat(Ar, r, tmp);
110  Complex Ar2 = cDotProductCuda(Ar, r);
111  printfQuda("MR: %d iterations, <r|A|r> = (%e, %e)\n", k, real(Ar2), imag(Ar2));
112  }
113 
114  // Obtain global solution by rescaling
115  if (b2 > 0.0) axCuda(sqrt(b2), x);
116 
118  warningQuda("Exceeded maximum iterations %d", invParam.maxiter);
119 
124 
125  double gflops = (quda::blas_flops + mat.flops())*1e-9;
126  reduceDouble(gflops);
127 
128  invParam.gflops += gflops;
129  invParam.iter += k;
130 
131  // Calculate the true residual
    // NOTE(review): as far as this listing shows, only x was rescaled back and
    // r is still in normalized units. r2 is therefore already relative to
    // |b|^2, and the sqrt(r2/b2) printed below divides by b2 a second time.
132  r2 = norm2(r);
133  mat(r, x);
134  double true_res = xmyNormCuda(b, r);
    // NOTE(review): b2 is 0 for a zero source (the case the r2 initialization
    // above says should do no work). This then evaluates 0/0 and true_res
    // becomes NaN.
135  invParam.true_res = sqrt(true_res / b2);
136 
138  printfQuda("MR: Converged after %d iterations, relative residua: iterated = %e, true = %e\n",
139  k, sqrt(r2/b2), invParam.true_res);
140  }
141 
142  // reset the flops counters
143  quda::blas_flops = 0;
    // Return value discarded; presumably flops() also resets the operator's
    // counter, per the comment above -- confirm.
144  mat.flops();
146  }
147 
148  globalReduce = true; // re-enable global reductions for outer solver
149 
150  return;
151  }
152 
153 } // namespace quda