QUDA  v1.1.0
A library for QCD on GPUs
inv_sbicgstab_quda.cpp
Go to the documentation of this file.
1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <math.h>
4 #include <iostream>
5 
6 #include <quda_internal.h>
7 #include <color_spinor_field.h>
8 #include <blas_quda.h>
9 #include <dslash_quda.h>
10 #include <invert_quda.h>
11 #include <util_quda.h>
12 
13 namespace quda {
14 
16  Solver(param, profile), mat(mat)
17  {
    // Intentionally empty: construction only forwards param and profile to the
    // Solver base and initializes the mat member from the mat argument.
    // NOTE(review): the constructor's signature line is not present in this listing.
18  }
19 
21  }
22 
24  {
    // Iterative solve that updates x for the source b, driving the residual
    // r = b - Ax towards the tolerance derived from param.tol. Progress is
    // reported under the name "SimpleBiCGstab".
    // NOTE(review): this listing skips several original lines (the embedded
    // numbering jumps 36->39->41->43 and 48->52->54), so the function signature
    // and the declarations of temp, r, r0, p, Ap and Ar are not visible here.
25 
26  // Check to see that we're not trying to invert on a zero-field source
27  const double b2 = norm2(b);
28  if(b2 == 0){
    // Zero source: copy b into x, report zero true residuals and return early.
    // NOTE(review): TPSTOP is called here but no matching TPSTART for
    // QUDA_PROFILE_INIT is visible in this listing -- confirm it exists.
29  profile.TPSTOP(QUDA_PROFILE_INIT);
30  printfQuda("Warning: inverting on zero-field source\n");
31  x=b;
32  param.true_res = 0.0;
33  param.true_res_hq = 0.0;
34  return;
35  }
36 
39 
41 
43 
44 
45 
    // Initial residual, built in two steps: apply the operator, then subtract from b.
    // r2 receives the value returned by xmyNormCuda (presumably |r|^2 -- confirm).
46  mat(r, x, temp); // r = Ax
47  double r2 = xmyNormCuda(b,r); // r = b - Ax
48 
    // Work fields, each copy-constructed from the current residual r.
52  cudaColorSpinorField A2p(r);
54  cudaColorSpinorField r_new(r);
55  cudaColorSpinorField p_new(r);
    // Scalar coefficients of the iteration; r0r caches the r0/r inner product.
56  Complex r0r;
57  Complex alpha;
58  Complex omega;
59  Complex beta;
60 
61 
    // NOTE(review): p2 is assigned here and at the end of every iteration but is
    // never read anywhere in the visible code.
62  double p2 = norm2(p);
    // Stopping threshold for r2, computed from the requested tolerance, the
    // source norm and the residual type.
63  double stop = stopping(param.tol, b2, param.residual_type);
64  int k=0;
    // Iterate until convergence() accepts r2 against stop or k reaches
    // param.maxiter. The two 0.0 arguments occupy the hq2 and hq_tol slots of
    // convergence().
65  while(!convergence(r2, 0.0, stop, 0.0) && k<param.maxiter){
66 
67  PrintStats("SimpleBiCGstab", k, r2, b2, 0.0);
68 
    // Three operator applications per iteration: Ap = A p, A2p = A (A p), Ar = A r.
69  mat(Ap,p,temp);
70  mat(A2p,Ap,temp);
71  mat(Ar,r,temp);
72 
73 
    // alpha = (r0,r) / (r0,Ap).
    // NOTE(review): no guard against a zero denominator (breakdown).
    // Which argument cDotProductCuda conjugates is not shown here -- confirm.
74  r0r = cDotProductCuda(r0,r);
75  alpha = r0r/cDotProductCuda(r0,Ap);
76 
77 
    // omega is assembled from inner products of r, Ap, Ar and A2p only.
    // Presumably these are the expansions of (s, As) and (As, As) with
    // s = r - alpha*Ap (the combination used in the x update below) and
    // As = Ar - alpha*A2p, so s and As are never formed explicitly -- confirm
    // against the inner-product convention of cDotProductCuda.
78  Complex omega_num = cDotProductCuda(r,Ar)
79  - alpha*cDotProductCuda(r,A2p)
80  - conj(alpha)*cDotProductCuda(Ap,Ar)
81  + conj(alpha)*alpha*cDotProductCuda(Ap,A2p);
82 
83 
84  Complex omega_denom = cDotProductCuda(Ar,Ar)
85  - alpha*cDotProductCuda(Ar,A2p)
86  - conj(alpha)*cDotProductCuda(A2p,Ar)
87  + conj(alpha)*alpha*cDotProductCuda(A2p,A2p);
88 
89 
    // NOTE(review): omega_denom is not checked for zero before dividing.
90  omega = omega_num/omega_denom;
91 
92 
93 
94  // x ---> x + alpha p + omega s
    // Applied as three accumulations into x: alpha*p, omega*r and -alpha*omega*Ap
    // (i.e. omega*s with s = r - alpha*Ap expanded).
95  caxpyCuda(alpha,p,x);
96  caxpyCuda(omega,r,x);
97  caxpyCuda(-alpha*omega,Ap,x);
98 
99 
100  // r_new = r - omega*Ar - alpha*Ap + alpha*omega*A2p
101  r_new = r;
102  caxpyCuda(-omega,Ar,r_new);
103  caxpyCuda(-alpha,Ap,r_new);
104  caxpyCuda(alpha*omega,A2p,r_new);
105 
    // beta = ((r0,r_new) / (r0,r)) * (alpha / omega), reusing the cached r0r.
    // NOTE(review): divides by r0r and omega without a zero check.
106  beta = (cDotProductCuda(r0,r_new)/r0r)*(alpha/omega);
107 
108 
109  // p = r_new + beta p - omega*beta Ap
110  p_new = r_new;
111  caxpyCuda(beta, p, p_new);
112  caxpyCuda(-beta*omega, Ap, p_new);
113 
    // Commit the new search direction and residual, then refresh the norm used
    // by the loop condition.
114  p = p_new;
115  r = r_new;
116  r2 = norm2(r);
117  p2 = norm2(p);
118  k++;
119  }
120 
121 
    // NOTE(review): this warning also fires when convergence was reached exactly
    // on the last allowed iteration, because only k is tested here.
122  if(k == param.maxiter)
123  warningQuda("Exceeded maximum iterations %d", param.maxiter);
124 
125  // compute the true residual
    // r is recomputed from the final x with the same mat + xmyNormCuda pair used
    // for the initial residual; true_res is the square root of that value over b2.
    // NOTE(review): param.true_res_hq is only assigned on the zero-source path
    // above, not here.
126  mat(r, x, temp);
127  param.true_res = sqrt(xmyNormCuda(b, r)/b2);
128 
    // r2 passed here is the iterated residual norm from the loop, not the true
    // residual just computed.
129  PrintSummary("SimpleBiCGstab", k, r2, b2, stop, param.tol_hq);
130 
131  return;
132  }
133 
134 } // namespace quda
void operator()(ColorSpinorField &out, ColorSpinorField &in)
SimpleBiCGstab(const DiracMatrix &mat, SolverParam &param, TimeProfile &profile)
TimeProfile & profile
Definition: invert_quda.h:471
const DiracMatrix & mat
Definition: invert_quda.h:465
bool convergence(double r2, double hq2, double r2_tol, double hq_tol)
Definition: solver.cpp:328
void PrintSummary(const char *name, int k, double r2, double b2, double r2_tol, double hq_tol)
Prints out the summary of the solver convergence (requires a verbosity of QUDA_SUMMARIZE)....
Definition: solver.cpp:386
SolverParam & param
Definition: invert_quda.h:470
static double stopping(double tol, double b2, QudaResidualType residual_type)
Set the solver L2 stopping condition.
Definition: solver.cpp:311
void PrintStats(const char *name, int k, double r2, double b2, double hq2)
Prints out the running statistics of the solver (requires a verbosity of QUDA_VERBOSE)
Definition: solver.cpp:373
double omega
void mat(void *out, void **link, void *in, int dagger_bit, int mu, QudaPrecision sPrecision, QudaPrecision gPrecision)
@ QUDA_ZERO_FIELD_CREATE
Definition: enum_quda.h:361
void stop()
Stop profiling.
Definition: device.cpp:228
__host__ __device__ ValueType conj(ValueType x)
Definition: complex_quda.h:130
double norm2(const CloverField &a, bool inverse=false)
std::complex< double > Complex
Definition: quda_internal.h:86
__host__ __device__ ValueType sqrt(ValueType x)
Definition: complex_quda.h:120
@ QUDA_PROFILE_INIT
Definition: timer.h:106
ColorSpinorParam csParam
Definition: pack_test.cpp:25
QudaGaugeParam param
Definition: pack_test.cpp:18
QudaResidualType residual_type
Definition: invert_quda.h:49
#define printfQuda(...)
Definition: util_quda.h:114
#define warningQuda(...)
Definition: util_quda.h:132