QUDA v0.4.0
A library for QCD on GPUs
quda/lib/inv_bicgstab_quda.cpp
Go to the documentation of this file.
00001 #include <stdio.h>
00002 #include <stdlib.h>
00003 #include <math.h>
00004 
00005 #include <complex>
00006 
00007 #include <quda_internal.h>
00008 #include <blas_quda.h>
00009 #include <dslash_quda.h>
00010 #include <invert_quda.h>
00011 #include <util_quda.h>
00012 
00013 #include<face_quda.h>
00014 
00015 #include <color_spinor_field.h>
00016 
00017 // set the required parameters for the inner solver
00018 void fillInnerInvertParam(QudaInvertParam &inner, const QudaInvertParam &outer);
00019 
00020 double resNorm(const DiracMatrix &mat, cudaColorSpinorField &b, cudaColorSpinorField &x) {  
00021   cudaColorSpinorField r(b);
00022   mat(r, x);
00023   return xmyNormCuda(b, r);
00024 }
00025 
00026 
00027 BiCGstab::BiCGstab(DiracMatrix &mat, DiracMatrix &matSloppy, DiracMatrix &matPrecon, QudaInvertParam &invParam) :
00028   Solver(invParam), mat(mat), matSloppy(matSloppy), matPrecon(matPrecon), init(false) {
00029 
00030 }
00031 
00032 BiCGstab::~BiCGstab() {
00033   if(init) {
00034     if (wp && wp != pp) delete wp;
00035     if (zp && zp != pp) delete zp;
00036     delete yp;
00037     delete rp;
00038     delete pp;
00039     delete vp;
00040     delete tmpp;
00041     delete tp;
00042   }
00043 }
00044 
00045 void BiCGstab::operator()(cudaColorSpinorField &x, cudaColorSpinorField &b) 
00046 {
00047   if (!init) {
00048     ColorSpinorParam csParam(x);
00049     csParam.create = QUDA_ZERO_FIELD_CREATE;
00050     yp = new cudaColorSpinorField(x, csParam);
00051     rp = new cudaColorSpinorField(x, csParam); 
00052     csParam.precision = invParam.cuda_prec_sloppy;
00053     pp = new cudaColorSpinorField(x, csParam);
00054     vp = new cudaColorSpinorField(x, csParam);
00055     tmpp = new cudaColorSpinorField(x, csParam);
00056     tp = new cudaColorSpinorField(x, csParam);
00057 
00058     // MR preconditioner - we need extra vectors
00059     if (invParam.inv_type_precondition == QUDA_MR_INVERTER) {
00060       wp = new cudaColorSpinorField(x, csParam);
00061       zp = new cudaColorSpinorField(x, csParam);
00062     } else { // dummy assignments
00063       wp = pp;
00064       zp = pp;
00065     }
00066 
00067     init = true;
00068   }
00069 
00070   cudaColorSpinorField &y = *yp;
00071   cudaColorSpinorField &r = *rp; 
00072   cudaColorSpinorField &p = *pp;
00073   cudaColorSpinorField &v = *vp;
00074   cudaColorSpinorField &tmp = *tmpp;
00075   cudaColorSpinorField &t = *tp;
00076 
00077   cudaColorSpinorField &w = *wp;
00078   cudaColorSpinorField &z = *zp;
00079 
00080   cudaColorSpinorField *x_sloppy, *r_sloppy, *r_0;
00081 
00082   if (invParam.cuda_prec_sloppy == x.Precision()) {
00083     x_sloppy = &x;
00084     r_sloppy = &r;
00085     r_0 = &b;
00086     zeroCuda(*x_sloppy);
00087     copyCuda(*r_sloppy, b);
00088   } else {
00089     ColorSpinorParam csParam(x);
00090     csParam.create = QUDA_ZERO_FIELD_CREATE;
00091     csParam.precision = invParam.cuda_prec_sloppy;
00092     x_sloppy = new cudaColorSpinorField(x, csParam);
00093     csParam.create = QUDA_COPY_FIELD_CREATE;
00094     r_sloppy = new cudaColorSpinorField(b, csParam);
00095     r_0 = new cudaColorSpinorField(b, csParam);
00096   }
00097 
00098   // Syntatic sugar
00099   cudaColorSpinorField &rSloppy = *r_sloppy;
00100   cudaColorSpinorField &xSloppy = *x_sloppy;
00101   cudaColorSpinorField &r0 = *r_0;
00102 
00103   QudaInvertParam invert_param_inner = newQudaInvertParam();
00104   fillInnerInvertParam(invert_param_inner, invParam);
00105 
00106   double b2 = normCuda(b);
00107 
00108   double r2 = b2;
00109   double stop = b2*invParam.tol*invParam.tol; // stopping condition of solver
00110   double delta = invParam.reliable_delta;
00111 
00112   int k = 0;
00113   int rUpdate = 0;
00114   
00115   quda::Complex rho(1.0, 0.0);
00116   quda::Complex rho0 = rho;
00117   quda::Complex alpha(1.0, 0.0);
00118   quda::Complex omega(1.0, 0.0);
00119   quda::Complex beta;
00120 
00121   double3 rho_r2;
00122   double3 omega_t2;
00123   
00124   double rNorm = sqrt(r2);
00125   //double r0Norm = rNorm;
00126   double maxrr = rNorm;
00127   double maxrx = rNorm;
00128 
00129   if (invParam.verbosity >= QUDA_VERBOSE) printfQuda("BiCGstab: %d iterations, r2 = %e\n", k, r2);
00130 
00131   if (invParam.inv_type_precondition != QUDA_GCR_INVERTER) { // do not do the below if we this is an inner solver
00132     quda::blas_flops = 0;    
00133     stopwatchStart();
00134   }
00135 
00136   while (r2 > stop && k<invParam.maxiter) {
00137     
00138     if (k==0) {
00139       rho = r2; // cDotProductCuda(r0, r_sloppy); // BiCRstab
00140       copyCuda(p, rSloppy);
00141     } else {
00142       if (abs(rho*alpha) == 0.0) beta = 0.0;
00143       else beta = (rho/rho0) * (alpha/omega);
00144 
00145       cxpaypbzCuda(rSloppy, -beta*omega, v, beta, p);
00146     }
00147     
00148     if (invParam.inv_type_precondition == QUDA_MR_INVERTER) {
00149       errorQuda("Temporary disabled");
00150       //invertMRCuda(*matPrecon, w, p, &invert_param_inner);
00151       matSloppy(v, w, tmp);
00152     } else {
00153       matSloppy(v, p, tmp);
00154     }
00155 
00156     if (abs(rho) == 0.0) alpha = 0.0;
00157     else alpha = rho / cDotProductCuda(r0, v);
00158 
00159     // r -= alpha*v
00160     caxpyCuda(-alpha, v, rSloppy);
00161 
00162     if (invParam.inv_type_precondition == QUDA_MR_INVERTER) {
00163       errorQuda("Temporary disabled");
00164       //invertMRCuda(*matPrecon, z, rSloppy, &invert_param_inner);
00165       matSloppy(t, z, tmp);
00166     } else {
00167       matSloppy(t, rSloppy, tmp);
00168     }
00169     
00170     // omega = (t, r) / (t, t)
00171     omega_t2 = cDotProductNormACuda(t, rSloppy);
00172     omega = quda::Complex(omega_t2.x / omega_t2.z, omega_t2.y / omega_t2.z);
00173 
00174     if (invParam.inv_type_precondition == QUDA_MR_INVERTER) {
00175       //x += alpha*w + omega*z, r -= omega*t, r2 = (r,r), rho = (r0, r)
00176       caxpyCuda(alpha, w, xSloppy);
00177       caxpyCuda(omega, z, xSloppy);
00178       caxpyCuda(-omega, t, rSloppy);
00179       rho_r2 = cDotProductNormBCuda(r0, rSloppy);
00180     } else {
00181       //x += alpha*p + omega*r, r -= omega*t, r2 = (r,r), rho = (r0, r)
00182       rho_r2 = caxpbypzYmbwcDotProductUYNormYCuda(alpha, p, omega, rSloppy, xSloppy, t, r0);
00183     }
00184 
00185     rho0 = rho;
00186     rho = quda::Complex(rho_r2.x, rho_r2.y);
00187     r2 = rho_r2.z;
00188 
00189     if (invParam.verbosity == QUDA_DEBUG_VERBOSE)
00190       printfQuda("DEBUG: %d iterated residual norm = %e, true residual norm = %e\n",
00191                  k, norm2(rSloppy), resNorm(matSloppy, b, xSloppy));
00192 
00193     // reliable updates
00194     rNorm = sqrt(r2);
00195     if (rNorm > maxrx) maxrx = rNorm;
00196     if (rNorm > maxrr) maxrr = rNorm;
00197     //int updateR = (rNorm < delta*maxrr && r0Norm <= maxrr) ? 1 : 0;
00198     //int updateX = (rNorm < delta*r0Norm && r0Norm <= maxrx) ? 1 : 0;
00199     
00200     int updateR = (rNorm < delta*maxrr) ? 1 : 0;
00201 
00202     if (updateR) {
00203       if (x.Precision() != xSloppy.Precision()) copyCuda(x, xSloppy);
00204       
00205       xpyCuda(x, y); // swap these around?
00206       mat(r, y, x);
00207       r2 = xmyNormCuda(b, r);
00208 
00209       if (x.Precision() != rSloppy.Precision()) copyCuda(rSloppy, r);            
00210       zeroCuda(xSloppy);
00211 
00212       rNorm = sqrt(r2);
00213       maxrr = rNorm;
00214       maxrx = rNorm;
00215       //r0Norm = rNorm;      
00216       rUpdate++;
00217     }
00218     
00219     k++;
00220     if (invParam.verbosity >= QUDA_VERBOSE) 
00221       printfQuda("BiCGstab: %d iterations, r2 = %e\n", k, r2);
00222   }
00223   
00224   if (x.Precision() != xSloppy.Precision()) copyCuda(x, xSloppy);
00225   xpyCuda(y, x);
00226     
00227   if (k==invParam.maxiter) warningQuda("Exceeded maximum iterations %d", invParam.maxiter);
00228 
00229   if (invParam.verbosity >= QUDA_VERBOSE) printfQuda("BiCGstab: Reliable updates = %d\n", rUpdate);
00230   
00231   if (invParam.inv_type_precondition != QUDA_GCR_INVERTER) { // do not do the below if we this is an inner solver
00232     invParam.secs += stopwatchReadSeconds();
00233 
00234     double gflops = (quda::blas_flops + mat.flops() + matSloppy.flops() + matPrecon.flops())*1e-9;
00235     reduceDouble(gflops);
00236 
00237     //  printfQuda("%f gflops\n", gflops / stopwatchReadSeconds());
00238     invParam.gflops += gflops;
00239     invParam.iter += k;
00240     
00241     if (invParam.verbosity >= QUDA_SUMMARIZE) {
00242       // Calculate the true residual
00243       mat(r, x);
00244       double true_res = xmyNormCuda(b, r);
00245       
00246       printfQuda("BiCGstab: Converged after %d iterations, relative residua: iterated = %e, true = %e\n", 
00247                  k, sqrt(r2/b2), sqrt(true_res / b2));    
00248     }
00249   }
00250 
00251   if (invParam.cuda_prec_sloppy != x.Precision()) {
00252     delete r_0;
00253     delete r_sloppy;
00254     delete x_sloppy;
00255   }
00256 
00257   return;
00258 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines