QUDA v0.3.2
A library for QCD on GPUs

quda/lib/inv_cg_quda.cpp

Go to the documentation of this file.
00001 #include <stdio.h>
00002 #include <stdlib.h>
00003 #include <math.h>
00004 
00005 #include <quda_internal.h>
00006 #include <color_spinor_field.h>
00007 #include <blas_quda.h>
00008 #include <dslash_quda.h>
00009 #include <invert_quda.h>
00010 #include <util_quda.h>
00011 #include <sys/time.h>
00012 
00013 #include <iostream>
00014 
00015 void invertCgCuda(const DiracMatrix &mat, const DiracMatrix &matSloppy, cudaColorSpinorField &x,
00016                   cudaColorSpinorField &b, QudaInvertParam *invert_param)
00017 {
00018   int k=0;
00019   int rUpdate = 0;
00020     
00021   cudaColorSpinorField r(b);
00022 
00023   ColorSpinorParam param;
00024   param.create = QUDA_ZERO_FIELD_CREATE;
00025   cudaColorSpinorField y(b, param); 
00026   
00027   mat(r, x, y);
00028   zeroCuda(y);
00029 
00030   double r2 = xmyNormCuda(b, r);
00031   rUpdate ++;
00032   
00033   param.precision = invert_param->cuda_prec_sloppy;
00034   cudaColorSpinorField Ap(x, param);
00035   cudaColorSpinorField tmp(x, param);
00036   cudaColorSpinorField tmp2(x, param); // only needed for clover and twisted mass
00037 
00038   cudaColorSpinorField *x_sloppy, *r_sloppy;
00039   if (invert_param->cuda_prec_sloppy == x.Precision()) {
00040     param.create = QUDA_REFERENCE_FIELD_CREATE;
00041     x_sloppy = &x;
00042     r_sloppy = &r;
00043   } else {
00044     param.create = QUDA_COPY_FIELD_CREATE;
00045     x_sloppy = new cudaColorSpinorField(x, param);
00046     r_sloppy = new cudaColorSpinorField(r, param);
00047   }
00048 
00049   cudaColorSpinorField &xSloppy = *x_sloppy;
00050   cudaColorSpinorField &rSloppy = *r_sloppy;
00051 
00052   cudaColorSpinorField p(rSloppy);
00053 
00054   double r2_old;
00055   double src_norm = norm2(b);
00056   double stop = src_norm*invert_param->tol*invert_param->tol; // stopping condition of solver
00057 
00058   double alpha, beta;
00059   double pAp;
00060 
00061   double rNorm = sqrt(r2);
00062   double r0Norm = rNorm;
00063   double maxrx = rNorm;
00064   double maxrr = rNorm;
00065   double delta = invert_param->reliable_delta;
00066 
00067   if (invert_param->verbosity >= QUDA_VERBOSE) printfQuda("CG: %d iterations, r2 = %e\n", k, r2);
00068 
00069   blas_quda_flops = 0;
00070 
00071   stopwatchStart();
00072   while (r2 > stop && k<invert_param->maxiter) {
00073 
00074     matSloppy(Ap, p, tmp, tmp2); // tmp as tmp
00075     
00076     pAp = reDotProductCuda(p, Ap);
00077     alpha = r2 / pAp;        
00078     r2_old = r2;
00079     r2 = axpyNormCuda(-alpha, Ap, rSloppy);
00080 
00081     // reliable update conditions
00082     rNorm = sqrt(r2);
00083     if (rNorm > maxrx) maxrx = rNorm;
00084     if (rNorm > maxrr) maxrr = rNorm;
00085     int updateX = (rNorm < delta*r0Norm && r0Norm <= maxrx) ? 1 : 0;
00086     int updateR = ((rNorm < delta*maxrr && r0Norm <= maxrr) || updateX) ? 1 : 0;
00087     
00088     if (!(updateR || updateX)) {
00089       beta = r2 / r2_old;
00090       axpyZpbxCuda(alpha, p, xSloppy, rSloppy, beta);
00091     } else {
00092       axpyCuda(alpha, p, xSloppy);
00093       if (x.Precision() != xSloppy.Precision()) copyCuda(x, xSloppy);
00094       
00095       xpyCuda(x, y); // swap these around?
00096       mat(r, y, x); // here we can use x as tmp
00097       r2 = xmyNormCuda(b, r);
00098       if (x.Precision() != rSloppy.Precision()) copyCuda(rSloppy, r);            
00099       zeroCuda(xSloppy);
00100 
00101       rNorm = sqrt(r2);
00102       maxrr = rNorm;
00103       maxrx = rNorm;
00104       r0Norm = rNorm;      
00105       rUpdate++;
00106 
00107       beta = r2 / r2_old;
00108       xpayCuda(rSloppy, beta, p);
00109     }
00110 
00111     k++;
00112     if (invert_param->verbosity >= QUDA_VERBOSE)
00113       printfQuda("CG: %d iterations, r2 = %e\n", k, r2);
00114   }
00115 
00116   if (x.Precision() != xSloppy.Precision()) copyCuda(x, xSloppy);
00117   xpyCuda(y, x);
00118 
00119   invert_param->secs = stopwatchReadSeconds();
00120 
00121   
00122   if (k==invert_param->maxiter) 
00123     warningQuda("Exceeded maximum iterations %d", invert_param->maxiter);
00124 
00125   if (invert_param->verbosity >= QUDA_SUMMARIZE)
00126     printfQuda("CG: Reliable updates = %d\n", rUpdate);
00127 
00128   float gflops = (blas_quda_flops + mat.flops() + matSloppy.flops())*1e-9;
00129   //  printfQuda("%f gflops\n", gflops / stopwatchReadSeconds());
00130   invert_param->gflops = gflops;
00131   invert_param->iter = k;
00132 
00133   blas_quda_flops = 0;
00134 
00135   //#if 0
00136   // Calculate the true residual
00137   mat(r, x, y);
00138   double true_res = xmyNormCuda(b, r);
00139   if (invert_param->verbosity >= QUDA_SUMMARIZE){
00140     printfQuda("Converged after %d iterations, r2 = %e, relative true_r2 = %e\n", 
00141                k, r2, true_res / src_norm);
00142   }
00143   //#endif
00144 
00145   if (invert_param->cuda_prec_sloppy != x.Precision()) {
00146     delete r_sloppy;
00147     delete x_sloppy;
00148   }
00149 
00150   return;
00151 }
 All Classes Files Functions Variables Typedefs Enumerations Enumerator Friends Defines