QUDA v0.4.0
A library for QCD on GPUs
|
00001 #include <stdio.h> 00002 #include <stdlib.h> 00003 #include <math.h> 00004 00005 #include <complex> 00006 00007 #include <quda_internal.h> 00008 #include <blas_quda.h> 00009 #include <dslash_quda.h> 00010 #include <invert_quda.h> 00011 #include <util_quda.h> 00012 00013 #include<face_quda.h> 00014 00015 #include <color_spinor_field.h> 00016 00017 MR::MR(DiracMatrix &mat, QudaInvertParam &invParam) : 00018 Solver(invParam), mat(mat), init(false) 00019 { 00020 00021 } 00022 00023 MR::~MR() { 00024 if (init) { 00025 if (rp) delete rp; 00026 delete Arp; 00027 delete tmpp; 00028 } 00029 } 00030 00031 void MR::operator()(cudaColorSpinorField &x, cudaColorSpinorField &b) 00032 { 00033 00034 globalReduce = false; // use local reductions for DD solver 00035 00036 if (!init) { 00037 ColorSpinorParam param(x); 00038 param.create = QUDA_ZERO_FIELD_CREATE; 00039 if (invParam.preserve_source == QUDA_PRESERVE_SOURCE_YES) 00040 rp = new cudaColorSpinorField(x, param); 00041 Arp = new cudaColorSpinorField(x); 00042 tmpp = new cudaColorSpinorField(x, param); //temporary for mat-vec 00043 00044 init = true; 00045 } 00046 cudaColorSpinorField &r = 00047 (invParam.preserve_source == QUDA_PRESERVE_SOURCE_YES) ? *rp : b; 00048 cudaColorSpinorField &Ar = *Arp; 00049 cudaColorSpinorField &tmp = *tmpp; 00050 00051 // set initial guess to zero and thus the residual is just the source 00052 zeroCuda(x); // can get rid of this for a special first update kernel 00053 double b2 = normCuda(b); 00054 if (&r != &b) copyCuda(r, b); 00055 00056 // domain-wise normalization of the initial residual to prevent underflow 00057 double r2=0.0; // if zero source then we will exit immediately doing no work 00058 if (b2 > 0.0) { 00059 axCuda(1/sqrt(b2), r); // can merge this with the prior copy 00060 r2 = 1.0; // by definition by this is now true 00061 } 00062 double stop = b2*invParam.tol*invParam.tol; // stopping condition of solver 00063 00064 if (invParam.inv_type_precondition != QUDA_GCR_INVERTER) { 00065 quda::blas_flops = 0; 00066 stopwatchStart(); 00067 } 00068 00069 double omega = 1.0; 00070 00071 int k = 0; 00072 if (invParam.verbosity >= QUDA_DEBUG_VERBOSE) { 00073 double x2 = norm2(x); 00074 double3 Ar3 = cDotProductNormBCuda(Ar, r); 00075 printfQuda("MR: %d iterations, r2 = %e, <r|A|r> = (%e, %e), x2 = %e\n", 00076 k, Ar3.x, Ar3.y, Ar3.z, x2); 00077 } 00078 00079 while (r2 > stop && k < invParam.maxiter) { 00080 00081 mat(Ar, r, tmp); 00082 00083 double3 Ar3 = cDotProductNormACuda(Ar, r); 00084 quda::Complex alpha = quda::Complex(Ar3.x, Ar3.y) / Ar3.z; 00085 00086 // x += omega*alpha*r, r -= omega*alpha*Ar, r2 = norm2(r) 00087 //r2 = caxpyXmazNormXCuda(omega*alpha, r, x, Ar); 00088 caxpyXmazCuda(omega*alpha, r, x, Ar); 00089 00090 if (invParam.verbosity >= QUDA_DEBUG_VERBOSE) { 00091 double x2 = norm2(x); 00092 double r2 = norm2(r); 00093 printfQuda("MR: %d iterations, r2 = %e, <r|A|r> = (%e,%e) x2 = %e\n", 00094 k+1, r2, Ar3.x, Ar3.y, x2); 00095 } else if (invParam.verbosity >= QUDA_VERBOSE) { 00096 printfQuda("MR: %d iterations, <r|A|r> = (%e, %e)\n", k, Ar3.x, Ar3.y); 00097 } 00098 00099 k++; 00100 } 00101 00102 if (invParam.verbosity >= QUDA_VERBOSE) { 00103 mat(Ar, r, tmp); 00104 quda::Complex Ar2 = cDotProductCuda(Ar, r); 00105 printfQuda("MR: %d iterations, <r|A|r> = (%e, %e)\n", k, real(Ar2), imag(Ar2)); 00106 } 00107 00108 // Obtain global solution by rescaling 00109 if (b2 > 0.0) axCuda(sqrt(b2), x); 00110 00111 if (k>=invParam.maxiter && invParam.verbosity >= QUDA_SUMMARIZE) 00112 warningQuda("Exceeded maximum iterations %d", invParam.maxiter); 00113 00114 if (invParam.inv_type_precondition != QUDA_GCR_INVERTER) { 00115 invParam.secs += stopwatchReadSeconds(); 00116 00117 double gflops = (quda::blas_flops + mat.flops())*1e-9; 00118 reduceDouble(gflops); 00119 00120 invParam.gflops += gflops; 00121 invParam.iter += k; 00122 00123 if (invParam.verbosity >= QUDA_SUMMARIZE) { 00124 // Calculate the true residual 00125 r2 = norm2(r); 00126 mat(r, x); 00127 double true_res = xmyNormCuda(b, r); 00128 00129 printfQuda("MR: Converged after %d iterations, relative residua: iterated = %e, true = %e\n", 00130 k, sqrt(r2/b2), sqrt(true_res / b2)); 00131 } 00132 } 00133 00134 globalReduce = true; // renable global reductions for outer solver 00135 00136 return; 00137 }