QUDA v0.4.0
A library for QCD on GPUs
quda/lib/dirac_wilson.cpp
Go to the documentation of this file.
00001 #include <dirac_quda.h>
00002 #include <blas_quda.h>
00003 #include <iostream>
00004 
00005 DiracWilson::DiracWilson(const DiracParam &param) : 
00006   Dirac(param), face(param.gauge->X(), 4, 12, 1, param.gauge->Precision()) { }
00007 
00008 DiracWilson::DiracWilson(const DiracWilson &dirac) : 
00009   Dirac(dirac), face(dirac.face) { }
00010 
00011 DiracWilson::~DiracWilson() { }
00012 
00013 DiracWilson& DiracWilson::operator=(const DiracWilson &dirac)
00014 {
00015   if (&dirac != this) {
00016     Dirac::operator=(dirac);
00017     face = dirac.face;
00018   }
00019   return *this;
00020 }
00021 
00022 void DiracWilson::Dslash(cudaColorSpinorField &out, const cudaColorSpinorField &in, 
00023                          const QudaParity parity) const
00024 {
00025   if (!initDslash) initDslashConstants(gauge, in.Stride());
00026   checkParitySpinor(in, out);
00027   checkSpinorAlias(in, out);
00028 
00029   setFace(face); // FIXME: temporary hack maintain C linkage for dslashCuda
00030 
00031   wilsonDslashCuda(&out, gauge, &in, parity, dagger, 0, 0.0, commDim);
00032 
00033   flops += 1320ll*in.Volume();
00034 }
00035 
00036 void DiracWilson::DslashXpay(cudaColorSpinorField &out, const cudaColorSpinorField &in, 
00037                              const QudaParity parity, const cudaColorSpinorField &x,
00038                              const double &k) const
00039 {
00040   if (!initDslash) initDslashConstants(gauge, in.Stride());
00041   checkParitySpinor(in, out);
00042   checkSpinorAlias(in, out);
00043 
00044   setFace(face); // FIXME: temporary hack maintain C linkage for dslashCuda
00045 
00046   wilsonDslashCuda(&out, gauge, &in, parity, dagger, &x, k, commDim);
00047 
00048   flops += 1368ll*in.Volume();
00049 }
00050 
00051 void DiracWilson::M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
00052 {
00053   checkFullSpinor(out, in);
00054   DslashXpay(out.Odd(), in.Even(), QUDA_ODD_PARITY, in.Odd(), -kappa);
00055   DslashXpay(out.Even(), in.Odd(), QUDA_EVEN_PARITY, in.Even(), -kappa);
00056 }
00057 
00058 void DiracWilson::MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
00059 {
00060   checkFullSpinor(out, in);
00061 
00062   bool reset = newTmp(&tmp1, in);
00063   checkFullSpinor(*tmp1, in);
00064 
00065   M(*tmp1, in);
00066   Mdag(out, *tmp1);
00067 
00068   deleteTmp(&tmp1, reset);
00069 }
00070 
00071 void DiracWilson::prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol,
00072                           cudaColorSpinorField &x, cudaColorSpinorField &b, 
00073                           const QudaSolutionType solType) const
00074 {
00075   if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
00076     errorQuda("Preconditioned solution requires a preconditioned solve_type");
00077   }
00078 
00079   src = &b;
00080   sol = &x;
00081 }
00082 
00083 void DiracWilson::reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b,
00084                               const QudaSolutionType solType) const
00085 {
00086   // do nothing
00087 }
00088 
00089 DiracWilsonPC::DiracWilsonPC(const DiracParam &param)
00090   : DiracWilson(param)
00091 {
00092 
00093 }
00094 
00095 DiracWilsonPC::DiracWilsonPC(const DiracWilsonPC &dirac) 
00096   : DiracWilson(dirac)
00097 {
00098 
00099 }
00100 
00101 DiracWilsonPC::~DiracWilsonPC()
00102 {
00103 
00104 }
00105 
00106 DiracWilsonPC& DiracWilsonPC::operator=(const DiracWilsonPC &dirac)
00107 {
00108   if (&dirac != this) {
00109     DiracWilson::operator=(dirac);
00110   }
00111   return *this;
00112 }
00113 
00114 void DiracWilsonPC::M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
00115 {
00116   double kappa2 = -kappa*kappa;
00117 
00118   bool reset = newTmp(&tmp1, in);
00119 
00120   if (matpcType == QUDA_MATPC_EVEN_EVEN) {
00121     Dslash(*tmp1, in, QUDA_ODD_PARITY);
00122     DslashXpay(out, *tmp1, QUDA_EVEN_PARITY, in, kappa2); 
00123   } else if (matpcType == QUDA_MATPC_ODD_ODD) {
00124     Dslash(*tmp1, in, QUDA_EVEN_PARITY);
00125     DslashXpay(out, *tmp1, QUDA_ODD_PARITY, in, kappa2); 
00126   } else {
00127     errorQuda("MatPCType %d not valid for DiracWilsonPC", matpcType);
00128   }
00129 
00130   deleteTmp(&tmp1, reset);
00131 }
00132 
00133 void DiracWilsonPC::MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
00134 {
00135   M(out, in);
00136   Mdag(out, out);
00137 }
00138 
00139 void DiracWilsonPC::prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol,
00140                             cudaColorSpinorField &x, cudaColorSpinorField &b, 
00141                             const QudaSolutionType solType) const
00142 {
00143   // we desire solution to preconditioned system
00144   if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
00145     src = &b;
00146     sol = &x;
00147   } else {
00148     // we desire solution to full system
00149     if (matpcType == QUDA_MATPC_EVEN_EVEN) {
00150       // src = b_e + k D_eo b_o
00151       DslashXpay(x.Odd(), b.Odd(), QUDA_EVEN_PARITY, b.Even(), kappa);
00152       src = &(x.Odd());
00153       sol = &(x.Even());
00154     } else if (matpcType == QUDA_MATPC_ODD_ODD) {
00155       // src = b_o + k D_oe b_e
00156       DslashXpay(x.Even(), b.Even(), QUDA_ODD_PARITY, b.Odd(), kappa);
00157       src = &(x.Even());
00158       sol = &(x.Odd());
00159     } else {
00160       errorQuda("MatPCType %d not valid for DiracWilsonPC", matpcType);
00161     }
00162     // here we use final solution to store parity solution and parity source
00163     // b is now up for grabs if we want
00164   }
00165 
00166 }
00167 
00168 void DiracWilsonPC::reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b,
00169                                 const QudaSolutionType solType) const
00170 {
00171   if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
00172     return;
00173   }                             
00174 
00175   // create full solution
00176 
00177   checkFullSpinor(x, b);
00178   if (matpcType == QUDA_MATPC_EVEN_EVEN) {
00179     // x_o = b_o + k D_oe x_e
00180     DslashXpay(x.Odd(), x.Even(), QUDA_ODD_PARITY, b.Odd(), kappa);
00181   } else if (matpcType == QUDA_MATPC_ODD_ODD) {
00182     // x_e = b_e + k D_eo x_o
00183     DslashXpay(x.Even(), x.Odd(), QUDA_EVEN_PARITY, b.Even(), kappa);
00184   } else {
00185     errorQuda("MatPCType %d not valid for DiracWilsonPC", matpcType);
00186   }
00187 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines