QUDA v0.3.2
A library for QCD on GPUs

quda/lib/dirac_wilson.cpp

Go to the documentation of this file.
00001 #include <dirac_quda.h>
00002 #include <blas_quda.h>
00003 #include <iostream>
00004 
00005 DiracWilson::DiracWilson(const DiracParam &param)
00006   : Dirac(param)
00007 {
00008 
00009 }
00010 
00011 DiracWilson::DiracWilson(const DiracWilson &dirac) 
00012   : Dirac(dirac)
00013 {
00014 
00015 }
00016 
00017 DiracWilson::~DiracWilson()
00018 {
00019 
00020 }
00021 
00022 DiracWilson& DiracWilson::operator=(const DiracWilson &dirac)
00023 {
00024   if (&dirac != this) {
00025     Dirac::operator=(dirac);
00026   }
00027   return *this;
00028 }
00029 
00030 void DiracWilson::Dslash(cudaColorSpinorField &out, const cudaColorSpinorField &in, 
00031                          const QudaParity parity) const
00032 {
00033   if (!initDslash) initDslashConstants(gauge, in.Stride(), 0);
00034   checkParitySpinor(in, out);
00035   checkSpinorAlias(in, out);
00036 
00037   dslashCuda(out.v, out.norm, gauge, in.v, in.norm, parity, dagger, 
00038              0, 0, 0, out.volume, out.length, in.Precision());
00039 
00040   flops += 1320*in.volume;
00041 }
00042 
00043 void DiracWilson::DslashXpay(cudaColorSpinorField &out, const cudaColorSpinorField &in, 
00044                              const QudaParity parity, const cudaColorSpinorField &x,
00045                              const double &k) const
00046 {
00047   if (!initDslash) initDslashConstants(gauge, in.Stride(), 0);
00048   checkParitySpinor(in, out);
00049   checkSpinorAlias(in, out);
00050 
00051   dslashCuda(out.v, out.norm, gauge, in.v, in.norm, parity, dagger, x.v, x.norm, k, 
00052              out.volume, out.length, in.Precision());
00053 
00054   flops += (1320+48)*in.volume;
00055 }
00056 
00057 void DiracWilson::M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
00058 {
00059   checkFullSpinor(out, in);
00060   DslashXpay(out.Odd(), in.Even(), QUDA_ODD_PARITY, in.Odd(), -kappa);
00061   DslashXpay(out.Even(), in.Odd(), QUDA_EVEN_PARITY, in.Even(), -kappa);
00062 }
00063 
00064 void DiracWilson::MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
00065 {
00066   checkFullSpinor(out, in);
00067 
00068   bool reset = newTmp(&tmp1, in);
00069   checkFullSpinor(*tmp1, in);
00070 
00071   M(*tmp1, in);
00072   Mdag(out, *tmp1);
00073 
00074   deleteTmp(&tmp1, reset);
00075 }
00076 
00077 void DiracWilson::prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol,
00078                           cudaColorSpinorField &x, cudaColorSpinorField &b, 
00079                           const QudaSolutionType solType) const
00080 {
00081   if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
00082     errorQuda("Preconditioned solution requires a preconditioned solve_type");
00083   }
00084 
00085   src = &b;
00086   sol = &x;
00087 }
00088 
00089 void DiracWilson::reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b,
00090                               const QudaSolutionType solType) const
00091 {
00092   // do nothing
00093 }
00094 
00095 DiracWilsonPC::DiracWilsonPC(const DiracParam &param)
00096   : DiracWilson(param)
00097 {
00098 
00099 }
00100 
00101 DiracWilsonPC::DiracWilsonPC(const DiracWilsonPC &dirac) 
00102   : DiracWilson(dirac)
00103 {
00104 
00105 }
00106 
00107 DiracWilsonPC::~DiracWilsonPC()
00108 {
00109 
00110 }
00111 
00112 DiracWilsonPC& DiracWilsonPC::operator=(const DiracWilsonPC &dirac)
00113 {
00114   if (&dirac != this) {
00115     DiracWilson::operator=(dirac);
00116   }
00117   return *this;
00118 }
00119 
00120 void DiracWilsonPC::M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
00121 {
00122   double kappa2 = -kappa*kappa;
00123 
00124   bool reset = newTmp(&tmp1, in);
00125 
00126   if (matpcType == QUDA_MATPC_EVEN_EVEN) {
00127     Dslash(*tmp1, in, QUDA_ODD_PARITY);
00128     DslashXpay(out, *tmp1, QUDA_EVEN_PARITY, in, kappa2); 
00129   } else if (matpcType == QUDA_MATPC_ODD_ODD) {
00130     Dslash(*tmp1, in, QUDA_EVEN_PARITY);
00131     DslashXpay(out, *tmp1, QUDA_ODD_PARITY, in, kappa2); 
00132   } else {
00133     errorQuda("MatPCType %d not valid for DiracWilsonPC", matpcType);
00134   }
00135 
00136   deleteTmp(&tmp1, reset);
00137 }
00138 
00139 void DiracWilsonPC::MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
00140 {
00141   M(out, in);
00142   Mdag(out, out);
00143 }
00144 
00145 void DiracWilsonPC::prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol,
00146                             cudaColorSpinorField &x, cudaColorSpinorField &b, 
00147                             const QudaSolutionType solType) const
00148 {
00149   // we desire solution to preconditioned system
00150   if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
00151     src = &b;
00152     sol = &x;
00153   } else {
00154     // we desire solution to full system
00155     if (matpcType == QUDA_MATPC_EVEN_EVEN) {
00156       // src = b_e + k D_eo b_o
00157       DslashXpay(x.Odd(), b.Odd(), QUDA_EVEN_PARITY, b.Even(), kappa);
00158       src = &(x.Odd());
00159       sol = &(x.Even());
00160     } else if (matpcType == QUDA_MATPC_ODD_ODD) {
00161       // src = b_o + k D_oe b_e
00162       DslashXpay(x.Even(), b.Even(), QUDA_ODD_PARITY, b.Odd(), kappa);
00163       src = &(x.Even());
00164       sol = &(x.Odd());
00165     } else {
00166       errorQuda("MatPCType %d not valid for DiracWilsonPC", matpcType);
00167     }
00168     // here we use final solution to store parity solution and parity source
00169     // b is now up for grabs if we want
00170   }
00171 
00172 }
00173 
00174 void DiracWilsonPC::reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b,
00175                                 const QudaSolutionType solType) const
00176 {
00177   if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
00178     return;
00179   }                             
00180 
00181   // create full solution
00182 
00183   checkFullSpinor(x, b);
00184   if (matpcType == QUDA_MATPC_EVEN_EVEN) {
00185     // x_o = b_o + k D_oe x_e
00186     DslashXpay(x.Odd(), x.Even(), QUDA_ODD_PARITY, b.Odd(), kappa);
00187   } else if (matpcType == QUDA_MATPC_ODD_ODD) {
00188     // x_e = b_e + k D_eo x_o
00189     DslashXpay(x.Even(), x.Odd(), QUDA_EVEN_PARITY, b.Even(), kappa);
00190   } else {
00191     errorQuda("MatPCType %d not valid for DiracWilsonPC", matpcType);
00192   }
00193 }
 All Classes Files Functions Variables Typedefs Enumerations Enumerator Friends Defines