QUDA v0.3.2
A library for QCD on GPUs

quda/lib/inv_bicgstab_quda.cpp

Go to the documentation of this file.
00001 #include <stdio.h>
00002 #include <stdlib.h>
00003 #include <math.h>
00004 
00005 #include <complex>
00006 
00007 #include <quda_internal.h>
00008 #include <blas_quda.h>
00009 #include <dslash_quda.h>
00010 #include <invert_quda.h>
00011 #include <util_quda.h>
00012 
00013 #include <color_spinor_field.h>
00014 
00015 void invertBiCGstabCuda(const DiracMatrix &mat, const DiracMatrix &matSloppy, cudaColorSpinorField &x, 
00016                         cudaColorSpinorField &b, QudaInvertParam *invert_param)
00017 {
00018   typedef std::complex<double> Complex;
00019 
00020   ColorSpinorParam param;
00021   param.create = QUDA_ZERO_FIELD_CREATE;
00022   cudaColorSpinorField y(x, param);
00023   cudaColorSpinorField r(x, param); 
00024 
00025   param.precision = invert_param->cuda_prec_sloppy;
00026   cudaColorSpinorField p(x, param);
00027   cudaColorSpinorField v(x, param);
00028   cudaColorSpinorField tmp(x, param);
00029   cudaColorSpinorField t(x, param);
00030 
00031   cudaColorSpinorField *x_sloppy, *r_sloppy, *r_0;
00032   if (invert_param->cuda_prec_sloppy == x.Precision()) {
00033     param.create = QUDA_REFERENCE_FIELD_CREATE;
00034     x_sloppy = &x;
00035     r_sloppy = &r;
00036     r_0 = &b;
00037     zeroCuda(*x_sloppy);
00038     copyCuda(*r_sloppy, b);
00039   } else {
00040     x_sloppy = new cudaColorSpinorField(x, param);
00041     param.create = QUDA_COPY_FIELD_CREATE;
00042     r_sloppy = new cudaColorSpinorField(b, param);
00043     r_0 = new cudaColorSpinorField(b, param);
00044   }
00045 
00046   // Syntatic sugar
00047   cudaColorSpinorField &rSloppy = *r_sloppy;
00048   cudaColorSpinorField &xSloppy = *x_sloppy;
00049   cudaColorSpinorField &r0 = *r_0;
00050 
00051   double b2 = normCuda(b);
00052 
00053   double r2 = b2;
00054   double stop = b2*invert_param->tol*invert_param->tol; // stopping condition of solver
00055   double delta = invert_param->reliable_delta;
00056 
00057   int k = 0;
00058   int rUpdate = 0;
00059   
00060   Complex rho(1.0, 0.0);
00061   Complex rho0 = rho;
00062   Complex alpha(1.0, 0.0);
00063   Complex omega(1.0, 0.0);
00064   Complex beta;
00065 
00066   double3 rho_r2;
00067   double3 omega_t2;
00068   
00069   double rNorm = sqrt(r2);
00070   double r0Norm = rNorm;
00071   double maxrr = rNorm;
00072   double maxrx = rNorm;
00073 
00074   if (invert_param->verbosity >= QUDA_VERBOSE) printfQuda("BiCGstab: %d iterations, r2 = %e\n", k, r2);
00075 
00076   blas_quda_flops = 0;
00077 
00078   stopwatchStart();
00079 
00080   while (r2 > stop && k<invert_param->maxiter) {
00081     
00082     if (k==0) {
00083       rho = r2; // cDotProductCuda(r0, r_sloppy); // BiCRstab
00084       copyCuda(p, rSloppy);
00085     } else {
00086       if (abs(rho*alpha) == 0.0) beta = 0.0;
00087       else beta = (rho/rho0) * (alpha/omega);
00088 
00089       cxpaypbzCuda(rSloppy, -beta*omega, v, beta, p);
00090     }
00091     
00092     matSloppy(v, p, tmp);
00093 
00094     if (abs(rho) == 0.0) alpha = 0.0;
00095     else alpha = rho / cDotProductCuda(r0, v);
00096 
00097     // r -= alpha*v
00098     caxpyCuda(-alpha, v, rSloppy);
00099 
00100     matSloppy(t, rSloppy, tmp);
00101     
00102     // omega = (t, r) / (t, t)
00103     omega_t2 = cDotProductNormACuda(t, rSloppy);
00104     omega = Complex(omega_t2.x / omega_t2.z, omega_t2.y / omega_t2.z);
00105 
00106     //x += alpha*p + omega*r, r -= omega*t, r2 = (r,r), rho = (r0, r)
00107     rho_r2 = caxpbypzYmbwcDotProductWYNormYCuda(alpha, p, omega, rSloppy, xSloppy, t, r0);
00108 
00109     rho0 = rho;
00110     rho = Complex(rho_r2.x, rho_r2.y);
00111     r2 = rho_r2.z;
00112 
00113     // reliable updates
00114     rNorm = sqrt(r2);
00115     if (rNorm > maxrx) maxrx = rNorm;
00116     if (rNorm > maxrr) maxrr = rNorm;
00117     //int updateR = (rNorm < delta*maxrr && r0Norm <= maxrr) ? 1 : 0;
00118     //int updateX = (rNorm < delta*r0Norm && r0Norm <= maxrx) ? 1 : 0;
00119     
00120     int updateR = (rNorm < delta*maxrr) ? 1 : 0;
00121 
00122     if (updateR) {
00123       if (x.Precision() != xSloppy.Precision()) copyCuda(x, xSloppy);
00124       
00125       xpyCuda(x, y); // swap these around?
00126       mat(r, y, x);
00127       r2 = xmyNormCuda(b, r);
00128 
00129       if (x.Precision() != rSloppy.Precision()) copyCuda(rSloppy, r);            
00130       zeroCuda(xSloppy);
00131 
00132       rNorm = sqrt(r2);
00133       maxrr = rNorm;
00134       maxrx = rNorm;
00135       r0Norm = rNorm;      
00136       rUpdate++;
00137     }
00138     
00139     k++;
00140     if (invert_param->verbosity >= QUDA_VERBOSE) 
00141       printfQuda("BiCGstab: %d iterations, r2 = %e\n", k, r2);
00142   }
00143   
00144   if (x.Precision() != xSloppy.Precision()) copyCuda(x, xSloppy);
00145   xpyCuda(y, x);
00146     
00147   if (k==invert_param->maxiter) warningQuda("Exceeded maximum iterations %d", invert_param->maxiter);
00148 
00149   if (invert_param->verbosity >= QUDA_VERBOSE) printfQuda("BiCGstab: Reliable updates = %d\n", rUpdate);
00150   
00151   invert_param->secs += stopwatchReadSeconds();
00152   
00153   float gflops = (blas_quda_flops + mat.flops() + matSloppy.flops())*1e-9;
00154   //  printfQuda("%f gflops\n", gflops / stopwatchReadSeconds());
00155   invert_param->gflops += gflops;
00156   invert_param->iter += k;
00157   
00158   //#if 0
00159   // Calculate the true residual
00160   mat(r, x);
00161   double true_res = xmyNormCuda(b, r);
00162     
00163   if (invert_param->verbosity >= QUDA_SUMMARIZE)
00164     printfQuda("BiCGstab: Converged after %d iterations, r2 = %e, true_r2 = %e\n", k, sqrt(r2/b2), sqrt(true_res / b2));    
00165   //#endif
00166 
00167   if (invert_param->cuda_prec_sloppy != x.Precision()) {
00168     delete r_0;
00169     delete r_sloppy;
00170     delete x_sloppy;
00171   }
00172 
00173   return;
00174 }
 All Classes Files Functions Variables Typedefs Enumerations Enumerator Friends Defines