|
QUDA v0.3.2
A library for QCD on GPUs
|
#include <stdio.h>
#include <stdlib.h>
#include <math.h>

#include <complex>

#include <quda_internal.h>
#include <blas_quda.h>
#include <dslash_quda.h>
#include <invert_quda.h>
#include <util_quda.h>

#include <color_spinor_field.h>

/**
 * Solve mat * x = b with BiCGstab, iterating in the sloppy precision
 * (invert_param->cuda_prec_sloppy) and periodically recomputing the residual
 * with the full-precision operator ("reliable updates").
 *
 * @param mat          Operator used for reliable updates and the final
 *                     true-residual check (fields in x's precision).
 * @param matSloppy    Operator used inside the iteration (sloppy precision).
 * @param x            Output: the solution. Its incoming contents are not used
 *                     as an initial guess; the iteration starts from zero.
 * @param b            Right-hand side. It is only read here. When the sloppy
 *                     and full precisions coincide, it also serves directly as
 *                     the shadow residual r0.
 * @param invert_param Reads tol, maxiter, reliable_delta, cuda_prec_sloppy and
 *                     verbosity. Increments secs, gflops and iter.
 */
void invertBiCGstabCuda(const DiracMatrix &mat, const DiracMatrix &matSloppy, cudaColorSpinorField &x,
                        cudaColorSpinorField &b, QudaInvertParam *invert_param)
{
  typedef std::complex<double> Complex;

  // Full-precision work fields.
  // y accumulates the solution across reliable updates.
  // r holds the recomputed full-precision residual.
  ColorSpinorParam param;
  param.create = QUDA_ZERO_FIELD_CREATE;
  cudaColorSpinorField y(x, param);
  cudaColorSpinorField r(x, param);

  // Sloppy-precision work fields used by the inner iteration.
  param.precision = invert_param->cuda_prec_sloppy;
  cudaColorSpinorField p(x, param);
  cudaColorSpinorField v(x, param);
  cudaColorSpinorField tmp(x, param);
  cudaColorSpinorField t(x, param);

  // If the sloppy precision equals x's precision, alias the sloppy
  // iterate/residual/shadow residual to existing fields rather than allocating
  // copies. The same flag decides what must be freed at the end.
  const bool aliased = (invert_param->cuda_prec_sloppy == x.Precision());

  cudaColorSpinorField *x_sloppy, *r_sloppy, *r_0;
  if (aliased) {
    x_sloppy = &x;
    r_sloppy = &r;
    r_0 = &b;
    zeroCuda(*x_sloppy);
    copyCuda(*r_sloppy, b);
  } else {
    // param.create is still QUDA_ZERO_FIELD_CREATE for x_sloppy.
    x_sloppy = new cudaColorSpinorField(x, param);
    param.create = QUDA_COPY_FIELD_CREATE;
    r_sloppy = new cudaColorSpinorField(b, param);
    r_0 = new cudaColorSpinorField(b, param);
  }

  // Syntactic sugar
  cudaColorSpinorField &rSloppy = *r_sloppy;
  cudaColorSpinorField &xSloppy = *x_sloppy;
  cudaColorSpinorField &r0 = *r_0;

  // r2 and b2 are treated as squared norms throughout (rNorm = sqrt(r2)).
  const double b2 = normCuda(b);

  double r2 = b2;
  const double stop = b2*invert_param->tol*invert_param->tol; // stopping condition of solver
  const double delta = invert_param->reliable_delta;

  int k = 0;
  int rUpdate = 0;

  Complex rho(1.0, 0.0);
  Complex rho0 = rho;
  Complex alpha(1.0, 0.0);
  Complex omega(1.0, 0.0);
  Complex beta;

  double3 rho_r2;   // x,y: rho = (r0, r);  z: |r|^2
  double3 omega_t2; // x,y: (t, r);         z: |t|^2

  double rNorm = sqrt(r2);
  double r0Norm = rNorm;
  double maxrr = rNorm;
  double maxrx = rNorm;

  if (invert_param->verbosity >= QUDA_VERBOSE) printfQuda("BiCGstab: %d iterations, r2 = %e\n", k, r2);

  blas_quda_flops = 0;

  stopwatchStart();

  while (r2 > stop && k < invert_param->maxiter) {

    if (k == 0) {
      // r0 and r both start as b, so (r0, r) == |b|^2 == r2.
      rho = r2; // cDotProductCuda(r0, r_sloppy); // BiCRstab
      copyCuda(p, rSloppy);
    } else {
      if (abs(rho*alpha) == 0.0) beta = 0.0;
      else beta = (rho/rho0) * (alpha/omega);

      // p = r + beta*(p - omega*v)
      cxpaypbzCuda(rSloppy, -beta*omega, v, beta, p);
    }

    matSloppy(v, p, tmp);

    if (abs(rho) == 0.0) alpha = 0.0;
    else alpha = rho / cDotProductCuda(r0, v);

    // r -= alpha*v
    caxpyCuda(-alpha, v, rSloppy);

    matSloppy(t, rSloppy, tmp);

    // omega = (t, r) / (t, t)
    omega_t2 = cDotProductNormACuda(t, rSloppy);
    // If |t|^2 == 0 the division would be 0/0 and inject NaNs into x and r.
    // With omega = 0 the update below still applies x += alpha*p and leaves r unchanged.
    if (omega_t2.z == 0.0) omega = 0.0;
    else omega = Complex(omega_t2.x / omega_t2.z, omega_t2.y / omega_t2.z);

    //x += alpha*p + omega*r, r -= omega*t, r2 = (r,r), rho = (r0, r)
    rho_r2 = caxpbypzYmbwcDotProductWYNormYCuda(alpha, p, omega, rSloppy, xSloppy, t, r0);

    rho0 = rho;
    rho = Complex(rho_r2.x, rho_r2.y);
    r2 = rho_r2.z;

    // reliable updates
    rNorm = sqrt(r2);
    if (rNorm > maxrx) maxrx = rNorm;
    if (rNorm > maxrr) maxrr = rNorm;
    // Alternative (unused) update criteria:
    //int updateR = (rNorm < delta*maxrr && r0Norm <= maxrr) ? 1 : 0;
    //int updateX = (rNorm < delta*r0Norm && r0Norm <= maxrx) ? 1 : 0;

    int updateR = (rNorm < delta*maxrr) ? 1 : 0;

    if (updateR) {
      // Fold the sloppy iterate into the full-precision accumulator y.
      if (x.Precision() != xSloppy.Precision()) copyCuda(x, xSloppy);

      xpyCuda(x, y); // swap these around?
      // Recompute r = b - mat*y in full precision. x is passed as the third
      // argument; its contents are not needed afterwards (presumably used as
      // operator workspace -- confirm against DiracMatrix).
      mat(r, y, x);
      r2 = xmyNormCuda(b, r);

      if (x.Precision() != rSloppy.Precision()) copyCuda(rSloppy, r);
      zeroCuda(xSloppy);

      rNorm = sqrt(r2);
      maxrr = rNorm;
      maxrx = rNorm;
      r0Norm = rNorm;
      rUpdate++;
    }

    k++;
    if (invert_param->verbosity >= QUDA_VERBOSE)
      printfQuda("BiCGstab: %d iterations, r2 = %e\n", k, r2);
  }

  // x = xSloppy + y
  if (x.Precision() != xSloppy.Precision()) copyCuda(x, xSloppy);
  xpyCuda(y, x);

  // Only warn if the loop stopped without meeting the tolerance; converging on
  // exactly the last allowed iteration is not a failure.
  if (k >= invert_param->maxiter && r2 > stop)
    warningQuda("Exceeded maximum iterations %d", invert_param->maxiter);

  if (invert_param->verbosity >= QUDA_VERBOSE) printfQuda("BiCGstab: Reliable updates = %d\n", rUpdate);

  invert_param->secs += stopwatchReadSeconds();

  float gflops = (blas_quda_flops + mat.flops() + matSloppy.flops())*1e-9;
  invert_param->gflops += gflops;
  invert_param->iter += k;

  // Calculate the true residual
  mat(r, x);
  double true_res = xmyNormCuda(b, r);

  if (invert_param->verbosity >= QUDA_SUMMARIZE) {
    // Residuals are reported relative to |b|. For a zero source (b2 == 0),
    // divide by 1 instead of printing 0/0 = NaN.
    const double b2_div = (b2 > 0.0) ? b2 : 1.0;
    printfQuda("BiCGstab: Converged after %d iterations, r2 = %e, true_r2 = %e\n",
               k, sqrt(r2/b2_div), sqrt(true_res / b2_div));
  }

  if (!aliased) {
    delete r_0;
    delete r_sloppy;
    delete x_sloppy;
  }

  return;
}
1.7.3