QUDA v0.4.0
A library for QCD on GPUs
quda/lib/inv_mr_quda.cpp
Go to the documentation of this file.
00001 #include <stdio.h>
00002 #include <stdlib.h>
00003 #include <math.h>
00004 
00005 #include <complex>
00006 
00007 #include <quda_internal.h>
00008 #include <blas_quda.h>
00009 #include <dslash_quda.h>
00010 #include <invert_quda.h>
00011 #include <util_quda.h>
00012 
00013 #include<face_quda.h>
00014 
00015 #include <color_spinor_field.h>
00016 
00017 MR::MR(DiracMatrix &mat, QudaInvertParam &invParam) :
00018   Solver(invParam), mat(mat), init(false)
00019 {
00020  
00021 }
00022 
00023 MR::~MR() {
00024   if (init) {
00025     if (rp) delete rp;
00026     delete Arp;
00027     delete tmpp;
00028   }
00029 }
00030 
00031 void MR::operator()(cudaColorSpinorField &x, cudaColorSpinorField &b)
00032 {
00033 
00034   globalReduce = false; // use local reductions for DD solver
00035 
00036   if (!init) {
00037     ColorSpinorParam param(x);
00038     param.create = QUDA_ZERO_FIELD_CREATE;
00039     if (invParam.preserve_source == QUDA_PRESERVE_SOURCE_YES)
00040       rp = new cudaColorSpinorField(x, param); 
00041     Arp = new cudaColorSpinorField(x);
00042     tmpp = new cudaColorSpinorField(x, param); //temporary for mat-vec
00043 
00044     init = true;
00045   }
00046   cudaColorSpinorField &r = 
00047     (invParam.preserve_source == QUDA_PRESERVE_SOURCE_YES) ? *rp : b;
00048   cudaColorSpinorField &Ar = *Arp;
00049   cudaColorSpinorField &tmp = *tmpp;
00050 
00051   // set initial guess to zero and thus the residual is just the source
00052   zeroCuda(x);  // can get rid of this for a special first update kernel  
00053   double b2 = normCuda(b);
00054   if (&r != &b) copyCuda(r, b);
00055 
00056   // domain-wise normalization of the initial residual to prevent underflow
00057   double r2=0.0; // if zero source then we will exit immediately doing no work
00058   if (b2 > 0.0) {
00059     axCuda(1/sqrt(b2), r); // can merge this with the prior copy
00060     r2 = 1.0; // by definition by this is now true
00061   }
00062   double stop = b2*invParam.tol*invParam.tol; // stopping condition of solver
00063 
00064   if (invParam.inv_type_precondition != QUDA_GCR_INVERTER) {
00065     quda::blas_flops = 0;
00066     stopwatchStart();
00067   }
00068 
00069   double omega = 1.0;
00070 
00071   int k = 0;
00072   if (invParam.verbosity >= QUDA_DEBUG_VERBOSE) {
00073     double x2 = norm2(x);
00074     double3 Ar3 = cDotProductNormBCuda(Ar, r);
00075     printfQuda("MR: %d iterations, r2 = %e, <r|A|r> = (%e, %e), x2 = %e\n", 
00076                k, Ar3.x, Ar3.y, Ar3.z, x2);
00077   }
00078 
00079   while (r2 > stop && k < invParam.maxiter) {
00080     
00081     mat(Ar, r, tmp);
00082     
00083     double3 Ar3 = cDotProductNormACuda(Ar, r);
00084     quda::Complex alpha = quda::Complex(Ar3.x, Ar3.y) / Ar3.z;
00085 
00086     // x += omega*alpha*r, r -= omega*alpha*Ar, r2 = norm2(r)
00087     //r2 = caxpyXmazNormXCuda(omega*alpha, r, x, Ar);
00088     caxpyXmazCuda(omega*alpha, r, x, Ar);
00089 
00090     if (invParam.verbosity >= QUDA_DEBUG_VERBOSE) {
00091       double x2 = norm2(x);
00092       double r2 = norm2(r);
00093       printfQuda("MR: %d iterations, r2 = %e, <r|A|r> = (%e,%e) x2 = %e\n", 
00094                  k+1, r2, Ar3.x, Ar3.y, x2);
00095     } else if (invParam.verbosity >= QUDA_VERBOSE) {
00096       printfQuda("MR: %d iterations, <r|A|r> = (%e, %e)\n", k, Ar3.x, Ar3.y);
00097     }
00098 
00099     k++;
00100   }
00101   
00102   if (invParam.verbosity >= QUDA_VERBOSE) {
00103     mat(Ar, r, tmp);    
00104     quda::Complex Ar2 = cDotProductCuda(Ar, r);
00105     printfQuda("MR: %d iterations, <r|A|r> = (%e, %e)\n", k, real(Ar2), imag(Ar2));
00106   }
00107 
00108   // Obtain global solution by rescaling
00109   if (b2 > 0.0) axCuda(sqrt(b2), x);
00110 
00111   if (k>=invParam.maxiter && invParam.verbosity >= QUDA_SUMMARIZE) 
00112     warningQuda("Exceeded maximum iterations %d", invParam.maxiter);
00113   
00114   if (invParam.inv_type_precondition != QUDA_GCR_INVERTER) {
00115     invParam.secs += stopwatchReadSeconds();
00116   
00117     double gflops = (quda::blas_flops + mat.flops())*1e-9;
00118     reduceDouble(gflops);
00119 
00120     invParam.gflops += gflops;
00121     invParam.iter += k;
00122     
00123     if (invParam.verbosity >= QUDA_SUMMARIZE) {
00124       // Calculate the true residual
00125       r2 = norm2(r);
00126       mat(r, x);
00127       double true_res = xmyNormCuda(b, r);
00128       
00129       printfQuda("MR: Converged after %d iterations, relative residua: iterated = %e, true = %e\n", 
00130                  k, sqrt(r2/b2), sqrt(true_res / b2));    
00131     }
00132   }
00133 
00134   globalReduce = true; // renable global reductions for outer solver
00135 
00136   return;
00137 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines