QUDA  0.9.0
inv_cg3_quda.cpp
Go to the documentation of this file.
1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <math.h>
4 #include <iostream>
5 
6 #include <quda_internal.h>
7 #include <color_spinor_field.h>
8 #include <blas_quda.h>
9 #include <dslash_quda.h>
10 #include <invert_quda.h>
11 #include <util_quda.h>
12 
13 namespace quda {
14 
15  CG3::CG3(DiracMatrix &mat, SolverParam &param, TimeProfile &profile) :
16  Solver(param, profile), mat(mat)
17  {
18  }
19 
20  CG3::~CG3() {
21  }
22 
23  void CG3::operator()(cudaColorSpinorField &x, cudaColorSpinorField &b)
24  {
25 
26  // Check to see that we're not trying to invert on a zero-field source
27  const double b2 = norm2(b);
28  if(b2 == 0){
29  profile.TPSTOP(QUDA_PROFILE_INIT);
30  printfQuda("Warning: inverting on zero-field source\n");
31  x=b;
32  param.true_res = 0.0;
33  param.true_res_hq = 0.0;
34  return;
35  }
36 
37  ColorSpinorParam csParam(x);
39 
40 
41  cudaColorSpinorField x_prev(b, csParam);
42  cudaColorSpinorField r_prev(b, csParam);
43  cudaColorSpinorField temp(b, csParam);
44 
45  cudaColorSpinorField r(b);
46  cudaColorSpinorField w(b);
47 
48 
49  mat(r, x, temp); // r = Mx
50  double r2 = xmyNormCuda(b,r); // r = b - Mx
51  PrintStats("CG3", 0, r2, b2, 0.0);
52 
53 
54  double stop = stopping(param.tol, b2, param.residual_type);
55  if(convergence(r2, 0.0, stop, 0.0)) return;
56  // First iteration
57  mat(w, r, temp);
58  double rAr = reDotProductCuda(r,w);
59  double rho = 1.0;
60  double gamma_prev = 0.0;
61  double gamma = r2/rAr;
62 
63 
64  cudaColorSpinorField x_new(x);
65  cudaColorSpinorField r_new(r);
66  axpyCuda(gamma, r, x_new); // x_new += gamma*r
67  axpyCuda(-gamma, w, r_new); // r_new -= gamma*w
68  // end of first iteration
69 
70  // axpbyCuda(a,b,x,y) => y = a*x + b*y
71 
72  int k = 1; // First iteration performed above
73 
74  double r2_prev;
75  while(!convergence(r2, 0.0, stop, 0.0) && k<param.maxiter){
76  x_prev = x; x = x_new;
77  r_prev = r; r = r_new;
78  mat(w, r, temp);
79  rAr = reDotProductCuda(r,w);
80  r2_prev = r2;
81  r2 = norm2(r);
82 
83  // Need to rearrange this!
84  PrintStats("CG3", k, r2, b2, 0.0);
85 
86  gamma_prev = gamma;
87  gamma = r2/rAr;
88  rho = 1.0/(1. - (gamma/gamma_prev)*(r2/r2_prev)*(1.0/rho));
89 
90  x_new = x;
91  axCuda(rho,x_new);
92  axpyCuda(rho*gamma,r,x_new);
93  axpyCuda((1. - rho),x_prev,x_new);
94 
95  r_new = r;
96  axCuda(rho,r_new);
97  axpyCuda(-rho*gamma,w,r_new);
98  axpyCuda((1.-rho),r_prev,r_new);
99 
100 
101  double rr_old = reDotProductCuda(r_new,r);
102  printfQuda("rr_old = %1.14lf\n", rr_old);
103 
104 
105 
106  k++;
107  }
108 
109 
110  if(k == param.maxiter)
111  warningQuda("Exceeded maximum iterations %d", param.maxiter);
112 
113  // compute the true residual
114  mat(r, x, temp);
115  param.true_res = sqrt(xmyNormCuda(b, r)/b2);
116 
117  PrintSummary("CG3", k, r2, b2);
118 
119  return;
120  }
121 
122 } // namespace quda
__host__ __device__ ValueType sqrt(ValueType x)
Definition: complex_quda.h:105
double norm2(const CloverField &a, bool inverse=false)
QudaGaugeParam param
Definition: pack_test.cpp:17
#define b
ColorSpinorParam csParam
Definition: pack_test.cpp:24
int int int w
#define warningQuda(...)
Definition: util_quda.h:101
double gamma(double) __attribute__((availability(macosx
#define printfQuda(...)
Definition: util_quda.h:84
void mat(void *out, void **link, void *in, int dagger_bit, int mu, QudaPrecision sPrecision, QudaPrecision gPrecision)