QUDA v0.4.0
A library for QCD on GPUs
|
00001 #include <stdio.h> 00002 #include <stdlib.h> 00003 #include <math.h> 00004 00005 #include <complex> 00006 00007 #include <quda_internal.h> 00008 #include <blas_quda.h> 00009 #include <dslash_quda.h> 00010 #include <invert_quda.h> 00011 #include <util_quda.h> 00012 00013 #include<face_quda.h> 00014 00015 #include <color_spinor_field.h> 00016 00017 // set the required parameters for the inner solver 00018 void fillInnerInvertParam(QudaInvertParam &inner, const QudaInvertParam &outer); 00019 00020 double resNorm(const DiracMatrix &mat, cudaColorSpinorField &b, cudaColorSpinorField &x) { 00021 cudaColorSpinorField r(b); 00022 mat(r, x); 00023 return xmyNormCuda(b, r); 00024 } 00025 00026 00027 BiCGstab::BiCGstab(DiracMatrix &mat, DiracMatrix &matSloppy, DiracMatrix &matPrecon, QudaInvertParam &invParam) : 00028 Solver(invParam), mat(mat), matSloppy(matSloppy), matPrecon(matPrecon), init(false) { 00029 00030 } 00031 00032 BiCGstab::~BiCGstab() { 00033 if(init) { 00034 if (wp && wp != pp) delete wp; 00035 if (zp && zp != pp) delete zp; 00036 delete yp; 00037 delete rp; 00038 delete pp; 00039 delete vp; 00040 delete tmpp; 00041 delete tp; 00042 } 00043 } 00044 00045 void BiCGstab::operator()(cudaColorSpinorField &x, cudaColorSpinorField &b) 00046 { 00047 if (!init) { 00048 ColorSpinorParam csParam(x); 00049 csParam.create = QUDA_ZERO_FIELD_CREATE; 00050 yp = new cudaColorSpinorField(x, csParam); 00051 rp = new cudaColorSpinorField(x, csParam); 00052 csParam.precision = invParam.cuda_prec_sloppy; 00053 pp = new cudaColorSpinorField(x, csParam); 00054 vp = new cudaColorSpinorField(x, csParam); 00055 tmpp = new cudaColorSpinorField(x, csParam); 00056 tp = new cudaColorSpinorField(x, csParam); 00057 00058 // MR preconditioner - we need extra vectors 00059 if (invParam.inv_type_precondition == QUDA_MR_INVERTER) { 00060 wp = new cudaColorSpinorField(x, csParam); 00061 zp = new cudaColorSpinorField(x, csParam); 00062 } else { // dummy assignments 00063 wp = pp; 00064 zp = pp; 00065 } 00066 00067 init = true; 00068 } 00069 00070 cudaColorSpinorField &y = *yp; 00071 cudaColorSpinorField &r = *rp; 00072 cudaColorSpinorField &p = *pp; 00073 cudaColorSpinorField &v = *vp; 00074 cudaColorSpinorField &tmp = *tmpp; 00075 cudaColorSpinorField &t = *tp; 00076 00077 cudaColorSpinorField &w = *wp; 00078 cudaColorSpinorField &z = *zp; 00079 00080 cudaColorSpinorField *x_sloppy, *r_sloppy, *r_0; 00081 00082 if (invParam.cuda_prec_sloppy == x.Precision()) { 00083 x_sloppy = &x; 00084 r_sloppy = &r; 00085 r_0 = &b; 00086 zeroCuda(*x_sloppy); 00087 copyCuda(*r_sloppy, b); 00088 } else { 00089 ColorSpinorParam csParam(x); 00090 csParam.create = QUDA_ZERO_FIELD_CREATE; 00091 csParam.precision = invParam.cuda_prec_sloppy; 00092 x_sloppy = new cudaColorSpinorField(x, csParam); 00093 csParam.create = QUDA_COPY_FIELD_CREATE; 00094 r_sloppy = new cudaColorSpinorField(b, csParam); 00095 r_0 = new cudaColorSpinorField(b, csParam); 00096 } 00097 00098 // Syntatic sugar 00099 cudaColorSpinorField &rSloppy = *r_sloppy; 00100 cudaColorSpinorField &xSloppy = *x_sloppy; 00101 cudaColorSpinorField &r0 = *r_0; 00102 00103 QudaInvertParam invert_param_inner = newQudaInvertParam(); 00104 fillInnerInvertParam(invert_param_inner, invParam); 00105 00106 double b2 = normCuda(b); 00107 00108 double r2 = b2; 00109 double stop = b2*invParam.tol*invParam.tol; // stopping condition of solver 00110 double delta = invParam.reliable_delta; 00111 00112 int k = 0; 00113 int rUpdate = 0; 00114 00115 quda::Complex rho(1.0, 0.0); 00116 quda::Complex rho0 = rho; 00117 quda::Complex alpha(1.0, 0.0); 00118 quda::Complex omega(1.0, 0.0); 00119 quda::Complex beta; 00120 00121 double3 rho_r2; 00122 double3 omega_t2; 00123 00124 double rNorm = sqrt(r2); 00125 //double r0Norm = rNorm; 00126 double maxrr = rNorm; 00127 double maxrx = rNorm; 00128 00129 if (invParam.verbosity >= QUDA_VERBOSE) printfQuda("BiCGstab: %d iterations, r2 = %e\n", k, r2); 00130 00131 if (invParam.inv_type_precondition != QUDA_GCR_INVERTER) { // do not do the below if we this is an inner solver 00132 quda::blas_flops = 0; 00133 stopwatchStart(); 00134 } 00135 00136 while (r2 > stop && k<invParam.maxiter) { 00137 00138 if (k==0) { 00139 rho = r2; // cDotProductCuda(r0, r_sloppy); // BiCRstab 00140 copyCuda(p, rSloppy); 00141 } else { 00142 if (abs(rho*alpha) == 0.0) beta = 0.0; 00143 else beta = (rho/rho0) * (alpha/omega); 00144 00145 cxpaypbzCuda(rSloppy, -beta*omega, v, beta, p); 00146 } 00147 00148 if (invParam.inv_type_precondition == QUDA_MR_INVERTER) { 00149 errorQuda("Temporary disabled"); 00150 //invertMRCuda(*matPrecon, w, p, &invert_param_inner); 00151 matSloppy(v, w, tmp); 00152 } else { 00153 matSloppy(v, p, tmp); 00154 } 00155 00156 if (abs(rho) == 0.0) alpha = 0.0; 00157 else alpha = rho / cDotProductCuda(r0, v); 00158 00159 // r -= alpha*v 00160 caxpyCuda(-alpha, v, rSloppy); 00161 00162 if (invParam.inv_type_precondition == QUDA_MR_INVERTER) { 00163 errorQuda("Temporary disabled"); 00164 //invertMRCuda(*matPrecon, z, rSloppy, &invert_param_inner); 00165 matSloppy(t, z, tmp); 00166 } else { 00167 matSloppy(t, rSloppy, tmp); 00168 } 00169 00170 // omega = (t, r) / (t, t) 00171 omega_t2 = cDotProductNormACuda(t, rSloppy); 00172 omega = quda::Complex(omega_t2.x / omega_t2.z, omega_t2.y / omega_t2.z); 00173 00174 if (invParam.inv_type_precondition == QUDA_MR_INVERTER) { 00175 //x += alpha*w + omega*z, r -= omega*t, r2 = (r,r), rho = (r0, r) 00176 caxpyCuda(alpha, w, xSloppy); 00177 caxpyCuda(omega, z, xSloppy); 00178 caxpyCuda(-omega, t, rSloppy); 00179 rho_r2 = cDotProductNormBCuda(r0, rSloppy); 00180 } else { 00181 //x += alpha*p + omega*r, r -= omega*t, r2 = (r,r), rho = (r0, r) 00182 rho_r2 = caxpbypzYmbwcDotProductUYNormYCuda(alpha, p, omega, rSloppy, xSloppy, t, r0); 00183 } 00184 00185 rho0 = rho; 00186 rho = quda::Complex(rho_r2.x, rho_r2.y); 00187 r2 = rho_r2.z; 00188 00189 if (invParam.verbosity == QUDA_DEBUG_VERBOSE) 00190 printfQuda("DEBUG: %d iterated residual norm = %e, true residual norm = %e\n", 00191 k, norm2(rSloppy), resNorm(matSloppy, b, xSloppy)); 00192 00193 // reliable updates 00194 rNorm = sqrt(r2); 00195 if (rNorm > maxrx) maxrx = rNorm; 00196 if (rNorm > maxrr) maxrr = rNorm; 00197 //int updateR = (rNorm < delta*maxrr && r0Norm <= maxrr) ? 1 : 0; 00198 //int updateX = (rNorm < delta*r0Norm && r0Norm <= maxrx) ? 1 : 0; 00199 00200 int updateR = (rNorm < delta*maxrr) ? 1 : 0; 00201 00202 if (updateR) { 00203 if (x.Precision() != xSloppy.Precision()) copyCuda(x, xSloppy); 00204 00205 xpyCuda(x, y); // swap these around? 00206 mat(r, y, x); 00207 r2 = xmyNormCuda(b, r); 00208 00209 if (x.Precision() != rSloppy.Precision()) copyCuda(rSloppy, r); 00210 zeroCuda(xSloppy); 00211 00212 rNorm = sqrt(r2); 00213 maxrr = rNorm; 00214 maxrx = rNorm; 00215 //r0Norm = rNorm; 00216 rUpdate++; 00217 } 00218 00219 k++; 00220 if (invParam.verbosity >= QUDA_VERBOSE) 00221 printfQuda("BiCGstab: %d iterations, r2 = %e\n", k, r2); 00222 } 00223 00224 if (x.Precision() != xSloppy.Precision()) copyCuda(x, xSloppy); 00225 xpyCuda(y, x); 00226 00227 if (k==invParam.maxiter) warningQuda("Exceeded maximum iterations %d", invParam.maxiter); 00228 00229 if (invParam.verbosity >= QUDA_VERBOSE) printfQuda("BiCGstab: Reliable updates = %d\n", rUpdate); 00230 00231 if (invParam.inv_type_precondition != QUDA_GCR_INVERTER) { // do not do the below if we this is an inner solver 00232 invParam.secs += stopwatchReadSeconds(); 00233 00234 double gflops = (quda::blas_flops + mat.flops() + matSloppy.flops() + matPrecon.flops())*1e-9; 00235 reduceDouble(gflops); 00236 00237 // printfQuda("%f gflops\n", gflops / stopwatchReadSeconds()); 00238 invParam.gflops += gflops; 00239 invParam.iter += k; 00240 00241 if (invParam.verbosity >= QUDA_SUMMARIZE) { 00242 // Calculate the true residual 00243 mat(r, x); 00244 double true_res = xmyNormCuda(b, r); 00245 00246 printfQuda("BiCGstab: Converged after %d iterations, relative residua: iterated = %e, true = %e\n", 00247 k, sqrt(r2/b2), sqrt(true_res / b2)); 00248 } 00249 } 00250 00251 if (invParam.cuda_prec_sloppy != x.Precision()) { 00252 delete r_0; 00253 delete r_sloppy; 00254 delete x_sloppy; 00255 } 00256 00257 return; 00258 }