QUDA v0.4.0
A library for QCD on GPUs
quda/lib/inv_cg_quda.cpp
Go to the documentation of this file.
00001 #include <stdio.h>
00002 #include <stdlib.h>
00003 #include <math.h>
00004 
00005 #include <quda_internal.h>
00006 #include <color_spinor_field.h>
00007 #include <blas_quda.h>
00008 #include <dslash_quda.h>
00009 #include <invert_quda.h>
00010 #include <util_quda.h>
00011 #include <sys/time.h>
00012 
00013 #include <face_quda.h>
00014 
00015 #include <iostream>
00016 
00017 CG::CG(DiracMatrix &mat, DiracMatrix &matSloppy, QudaInvertParam &invParam) :
00018   Solver(invParam), mat(mat), matSloppy(matSloppy)
00019 {
00020 
00021 }
00022 
00023 CG::~CG() {
00024 
00025 }
00026 
00027 void CG::operator()(cudaColorSpinorField &x, cudaColorSpinorField &b) 
00028 {
00029   int k=0;
00030   int rUpdate = 0;
00031     
00032   cudaColorSpinorField r(b);
00033 
00034   ColorSpinorParam param(x);
00035   param.create = QUDA_ZERO_FIELD_CREATE;
00036   cudaColorSpinorField y(b, param); 
00037   
00038   mat(r, x, y);
00039   zeroCuda(y);
00040 
00041   double r2 = xmyNormCuda(b, r);
00042   rUpdate ++;
00043   
00044   param.precision = invParam.cuda_prec_sloppy;
00045   cudaColorSpinorField Ap(x, param);
00046   cudaColorSpinorField tmp(x, param);
00047   cudaColorSpinorField tmp2(x, param); // only needed for clover and twisted mass
00048 
00049   cudaColorSpinorField *x_sloppy, *r_sloppy;
00050   if (invParam.cuda_prec_sloppy == x.Precision()) {
00051     param.create = QUDA_REFERENCE_FIELD_CREATE;
00052     x_sloppy = &x;
00053     r_sloppy = &r;
00054   } else {
00055     param.create = QUDA_COPY_FIELD_CREATE;
00056     x_sloppy = new cudaColorSpinorField(x, param);
00057     r_sloppy = new cudaColorSpinorField(r, param);
00058   }
00059 
00060   cudaColorSpinorField &xSloppy = *x_sloppy;
00061   cudaColorSpinorField &rSloppy = *r_sloppy;
00062 
00063   cudaColorSpinorField p(rSloppy);
00064 
00065   double r2_old;
00066   double src_norm = norm2(b);
00067   double stop = src_norm*invParam.tol*invParam.tol; // stopping condition of solver
00068 
00069   double alpha=0.0, beta=0.0;
00070   double pAp;
00071 
00072   double rNorm = sqrt(r2);
00073   double r0Norm = rNorm;
00074   double maxrx = rNorm;
00075   double maxrr = rNorm;
00076   double delta = invParam.reliable_delta;
00077 
00078   if (invParam.verbosity == QUDA_DEBUG_VERBOSE) {
00079     double x2 = norm2(x);
00080     double p2 = norm2(p);
00081     printf("CG: %d iterations, r2 = %e, x2 = %e, p2 = %e, alpha = %e, beta = %e\n", 
00082            k, r2, x2, p2, alpha, beta);
00083   } else if (invParam.verbosity >= QUDA_VERBOSE) {
00084     printfQuda("CG: %d iterations, r2 = %e\n", k, r2);
00085   }
00086 
00087   quda::blas_flops = 0;
00088 
00089   stopwatchStart();
00090   while (r2 > stop && k<invParam.maxiter) {
00091 
00092     matSloppy(Ap, p, tmp, tmp2); // tmp as tmp
00093     
00094     pAp = reDotProductCuda(p, Ap);
00095     alpha = r2 / pAp;        
00096     r2_old = r2;
00097     r2 = axpyNormCuda(-alpha, Ap, rSloppy);
00098 
00099     // reliable update conditions
00100     rNorm = sqrt(r2);
00101     if (rNorm > maxrx) maxrx = rNorm;
00102     if (rNorm > maxrr) maxrr = rNorm;
00103     int updateX = (rNorm < delta*r0Norm && r0Norm <= maxrx) ? 1 : 0;
00104     int updateR = ((rNorm < delta*maxrr && r0Norm <= maxrr) || updateX) ? 1 : 0;
00105     
00106     if ( !(updateR || updateX)) {
00107       beta = r2 / r2_old;
00108       axpyZpbxCuda(alpha, p, xSloppy, rSloppy, beta);
00109     } else {
00110       axpyCuda(alpha, p, xSloppy);
00111       if (x.Precision() != xSloppy.Precision()) copyCuda(x, xSloppy);
00112       
00113       xpyCuda(x, y); // swap these around?
00114       mat(r, y, x); // here we can use x as tmp
00115       r2 = xmyNormCuda(b, r);
00116       if (x.Precision() != rSloppy.Precision()) copyCuda(rSloppy, r);            
00117       zeroCuda(xSloppy);
00118 
00119       rNorm = sqrt(r2);
00120       maxrr = rNorm;
00121       maxrx = rNorm;
00122       r0Norm = rNorm;      
00123       rUpdate++;
00124 
00125       beta = r2 / r2_old; 
00126       xpayCuda(rSloppy, beta, p);
00127     }
00128 
00129     k++;
00130     if (invParam.verbosity == QUDA_DEBUG_VERBOSE) {
00131       double x2 = norm2(x);
00132       double p2 = norm2(p);
00133       printf("CG: %d iterations, r2 = %e, x2 = %e, p2 = %e, alpha = %e, beta = %e\n", 
00134              k, r2, x2, p2, alpha, beta);
00135     } else if (invParam.verbosity >= QUDA_VERBOSE) {
00136       printfQuda("CG: %d iterations, r2 = %e\n", k, r2);
00137     }
00138   }
00139 
00140   if (x.Precision() != xSloppy.Precision()) copyCuda(x, xSloppy);
00141   xpyCuda(y, x);
00142 
00143   invParam.secs = stopwatchReadSeconds();
00144 
00145   
00146   if (k==invParam.maxiter) 
00147     warningQuda("Exceeded maximum iterations %d", invParam.maxiter);
00148 
00149   if (invParam.verbosity >= QUDA_SUMMARIZE)
00150     printfQuda("CG: Reliable updates = %d\n", rUpdate);
00151 
00152   double gflops = (quda::blas_flops + mat.flops() + matSloppy.flops())*1e-9;
00153   reduceDouble(gflops);
00154 
00155   //  printfQuda("%f gflops\n", gflops / stopwatchReadSeconds());
00156   invParam.gflops = gflops;
00157   invParam.iter = k;
00158 
00159   quda::blas_flops = 0;
00160 
00161   if (invParam.verbosity >= QUDA_SUMMARIZE){
00162     mat(r, x, y);
00163     double true_res = xmyNormCuda(b, r);
00164     printfQuda("CG: Converged after %d iterations, relative residua: iterated = %e, true = %e\n", 
00165                k, sqrt(r2/src_norm), sqrt(true_res / src_norm));    
00166   }
00167 
00168   if (invParam.cuda_prec_sloppy != x.Precision()) {
00169     delete r_sloppy;
00170     delete x_sloppy;
00171   }
00172 
00173   return;
00174 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines