QUDA  v0.7.0
A library for QCD on GPUs
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
inv_sbicgstab_quda.cpp
Go to the documentation of this file.
1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <math.h>
4 
5 #include <quda_internal.h>
6 #include <color_spinor_field.h>
7 #include <blas_quda.h>
8 #include <dslash_quda.h>
9 #include <invert_quda.h>
10 #include <util_quda.h>
11 #include <sys/time.h>
12 
13 #include <face_quda.h>
14 
15 #include <iostream>
16 
17 namespace quda {
18 
// ---------------------------------------------------------------------------
// NOTE(review): this is a Doxygen source-listing scrape. Original line numbers
// are fused into each line, and several original lines are absent
// (19, 24, 27, 33, 41-42, 46, 53-55, 57). Comments below describe only what
// the surviving lines show.
// ---------------------------------------------------------------------------
//
// SimpleBiCGstab constructor. The signature line (19) is missing here; per
// the tooltip residue it is
//   SimpleBiCGstab(DiracMatrix &mat, SolverParam &param, TimeProfile &profile)
// It forwards param/profile to the Solver base class and stores the operator
// in the `mat` member.
20  Solver(param, profile), mat(mat)
21  {
22  }
23 
// Only the closing brace of original line 25 survives. Line 24 is missing
// (presumably the destructor header -- TODO confirm against the real source).
25  }
26 
// Solve entry point. Its signature (original line 27) is missing from this
// scrape. From the visible body it takes a solution field `x` and a source
// `b` and runs a BiCGstab iteration on `mat`:
//   - No preconditioner is applied anywhere in the visible code.
//   - x is written in place.
//   - param.true_res receives sqrt(|b - A x|^2 / |b|^2) at the end.
28  {
29 
30  // Check to see that we're not trying to invert on a zero-field source
// With b == 0 the solution is x = 0 (copied from b) and the relative
// residual would be 0/0, so both residual fields are set to 0 and the
// solver returns early.
31  const double b2 = norm2(b);
32  if(b2 == 0){
// Original line 33 is missing here (it may be a guard around the print).
34  printfQuda("Warning: inverting on zero-field source\n");
35  x=b;
36  param.true_res = 0.0;
37  param.true_res_hq = 0.0;
38  return;
39  }
40 
// Lines 41-42, 46, 53-55 and 57 are not shown. They presumably declare
// csParam and the fields r, r0, p, Ap and Ar used below.
// NOTE(review): verify against the original file. In particular, check
// whether x is zeroed there, or whether its incoming value is used as the
// initial guess by the mat(r, x, temp) call below.
43 
// Scratch field passed as the third argument of every mat() call.
44  cudaColorSpinorField temp(b, csParam);
45 
47 
48 
49 
// Initial residual r = b - A x and its squared norm r2.
50  mat(r, x, temp); // r = Ax
51  double r2 = xmyNormCuda(b,r); // r = b - Ax
52 
// Work vectors, all copy-constructed from r. A2p holds A(Ap); r_new and
// p_new hold the next residual and search direction before they are copied
// back into r and p.
56  cudaColorSpinorField A2p(r);
58  cudaColorSpinorField r_new(r);
59  cudaColorSpinorField p_new(r);
60  Complex r0r;
61  Complex alpha;
62  Complex omega;
63  Complex beta;
64 
65 
// NOTE(review): p2 is assigned here and at the end of each iteration, but no
// visible line reads it.
66  double p2 = norm2(p);
// Convergence threshold from the requested tolerance and residual type.
// The hq2/hq_tol arguments of convergence() are passed as 0.0, so only the
// L2 residual r2 is tested.
67  double stop = stopping(param.tol, b2, param.residual_type);
68  int k=0;
69  while(!convergence(r2, 0.0, stop, 0.0) && k<param.maxiter){
70 
71  PrintStats("SimpleBiCGstab", k, r2, b2, 0.0);
72 
// Three operator applications per iteration: Ap = A p, A2p = A(Ap), Ar = A r.
// The intermediate s = r - alpha*Ap is never formed. Its products are
// expanded by linearity instead (As = Ar - alpha*A2p).
73  mat(Ap,p,temp);
74  mat(A2p,Ap,temp);
75  mat(Ar,r,temp);
76 
77 
// r0 is never written inside the visible loop; it serves as the fixed
// shadow residual in every dot product.
// NOTE(review): nothing guards against breakdown when r0r or <r0, Ap> is zero.
78  r0r = cDotProductCuda(r0,r);
79  alpha = r0r/cDotProductCuda(r0,Ap);
80 
81 
// Numerator: the expansion of <s, As> = <r - alpha Ap, Ar - alpha A2p>.
// NOTE(review): if cDotProductCuda(a, b) computes a^H b, this is the complex
// conjugate of the textbook BiCGstab numerator <As, s> = (As)^H s.
// omega would then be conj(omega_textbook) whenever it is not real.
// Verify the dot-product convention and the intended omega.
82  Complex omega_num = cDotProductCuda(r,Ar)
83  - alpha*cDotProductCuda(r,A2p)
84  - conj(alpha)*cDotProductCuda(Ap,Ar)
85  + conj(alpha)*alpha*cDotProductCuda(Ap,A2p);
86 
87 
// Denominator: the expansion of <As, As> = |Ar - alpha A2p|^2.
88  Complex omega_denom = cDotProductCuda(Ar,Ar)
89  - alpha*cDotProductCuda(Ar,A2p)
90  - conj(alpha)*cDotProductCuda(A2p,Ar)
91  + conj(alpha)*alpha*cDotProductCuda(A2p,A2p);
92 
93 
94  omega = omega_num/omega_denom;
95 
96 
97 
// Here s denotes r - alpha*Ap. The three caxpy calls add
// alpha*p + omega*r - alpha*omega*Ap, which equals alpha*p + omega*s.
98  // x ---> x + alpha p + omega s
99  caxpyCuda(alpha,p,x);
100  caxpyCuda(omega,r,x);
101  caxpyCuda(-alpha*omega,Ap,x);
102 
103 
// Equivalent to r_new = s - omega*As, written out with Ar and A2p.
104  // r_new = r - omega*Ar - alpha*Ap + alpha*omega*A2p
105  r_new = r;
106  caxpyCuda(-omega,Ar,r_new);
107  caxpyCuda(-alpha,Ap,r_new);
108  caxpyCuda(alpha*omega,A2p,r_new);
109 
// NOTE(review): divides by omega and r0r with no zero check (breakdown).
110  beta = (cDotProductCuda(r0,r_new)/r0r)*(alpha/omega);
111 
112 
113  // p = r_new + beta p - omega*beta Ap
114  p_new = r_new;
115  caxpyCuda(beta, p, p_new);
116  caxpyCuda(-beta*omega, Ap, p_new);
117 
// Commit the new direction and residual. r2 drives the loop's convergence test.
118  p = p_new;
119  r = r_new;
120  r2 = norm2(r);
121  p2 = norm2(p);
122  k++;
123  }
124 
125 
// NOTE(review): this warning also fires when convergence is reached exactly on
// the last allowed iteration, because k == maxiter in that case too.
126  if(k == param.maxiter)
127  warningQuda("Exceeded maximum iterations %d", param.maxiter);
128 
// Recompute r = b - A x from scratch so the reported residual reflects x
// itself rather than the recursively updated r.
// NOTE(review): the visible code does not set param.true_res_hq here.
129  // compute the true residual
130  mat(r, x, temp);
131  param.true_res = sqrt(xmyNormCuda(b, r)/b2);
132 
// Reports the iterated residual r2, not the true residual computed above.
// NOTE(review): the iteration count k is not written back to param anywhere
// in the visible code. Confirm whether callers expect it.
133  PrintSummary("SimpleBiCGstab", k, r2, b2);
134 
135  return;
136  }
137 
138 } // namespace quda
bool convergence(const double &r2, const double &hq2, const double &r2_tol, const double &hq_tol)
Definition: solver.cpp:82
void caxpyCuda(const Complex &a, cudaColorSpinorField &x, cudaColorSpinorField &y)
Definition: blas_quda.cu:207
static double stopping(const double &tol, const double &b2, QudaResidualType residual_type)
Definition: solver.cpp:65
void operator()(cudaColorSpinorField &out, cudaColorSpinorField &in)
__host__ __device__ ValueType sqrt(ValueType x)
Definition: complex_quda.h:105
std::complex< double > Complex
Definition: eig_variables.h:13
void mat(void *out, void **fatlink, void **longlink, void *in, double kappa, int dagger_bit, QudaPrecision sPrecision, QudaPrecision gPrecision)
TimeProfile & profile
Definition: invert_quda.h:224
QudaGaugeParam param
Definition: pack_test.cpp:17
SimpleBiCGstab(DiracMatrix &mat, SolverParam &param, TimeProfile &profile)
void PrintSummary(const char *name, int k, const double &r2, const double &b2)
Definition: solver.cpp:137
QudaResidualType residual_type
Definition: invert_quda.h:35
Complex cDotProductCuda(cudaColorSpinorField &, cudaColorSpinorField &)
Definition: reduce_quda.cu:468
ColorSpinorParam csParam
Definition: pack_test.cpp:24
#define warningQuda(...)
Definition: util_quda.h:84
int x[4]
SolverParam & param
Definition: invert_quda.h:223
void Stop(QudaProfileType idx)
void PrintStats(const char *, int k, const double &r2, const double &b2, const double &hq2)
Definition: solver.cpp:122
#define printfQuda(...)
Definition: util_quda.h:67
__host__ __device__ ValueType conj(ValueType x)
Definition: complex_quda.h:115
double norm2(const ColorSpinorField &)
double xmyNormCuda(cudaColorSpinorField &a, cudaColorSpinorField &b)
Definition: reduce_quda.cu:343