QUDA v0.4.0
A library for QCD on GPUs
quda/lib/dirac_staggered.cpp
Go to the documentation of this file.
00001 #include <dirac_quda.h>
00002 #include <blas_quda.h>
00003 
00004 DiracStaggered::DiracStaggered(const DiracParam &param) : 
00005   Dirac(param), fatGauge(param.fatGauge), longGauge(param.longGauge), 
00006   face(param.fatGauge->X(), 4, 6, 3, param.fatGauge->Precision()) 
00007   //FIXME: this may break mixed precision multishift solver since may not have fatGauge initializeed yet
00008 {
00009 
00010 }
00011 
00012 DiracStaggered::DiracStaggered(const DiracStaggered &dirac) : Dirac(dirac),
00013   fatGauge(dirac.fatGauge), longGauge(dirac.longGauge), face(dirac.face) { }
00014 
00015 DiracStaggered::~DiracStaggered()
00016 {
00017 
00018 }
00019 
00020 DiracStaggered& DiracStaggered::operator=(const DiracStaggered &dirac)
00021 {
00022   if (&dirac != this) {
00023     Dirac::operator=(dirac);
00024     fatGauge = dirac.fatGauge;
00025     longGauge = dirac.longGauge;
00026     face = dirac.face;
00027   }
00028  
00029   return *this;
00030 }
00031 
00032 void DiracStaggered::checkParitySpinor(const cudaColorSpinorField &in, const cudaColorSpinorField &out) const
00033 {
00034   if (in.Precision() != out.Precision()) {
00035     errorQuda("Input and output spinor precisions don't match in dslash_quda");
00036   }
00037 
00038   if (in.Stride() != out.Stride()) {
00039     errorQuda("Input %d and output %d spinor strides don't match in dslash_quda", in.Stride(), out.Stride());
00040   }
00041 
00042   if (in.SiteSubset() != QUDA_PARITY_SITE_SUBSET || out.SiteSubset() != QUDA_PARITY_SITE_SUBSET) {
00043     errorQuda("ColorSpinorFields are not single parity, in = %d, out = %d", 
00044               in.SiteSubset(), out.SiteSubset());
00045   }
00046 
00047   if ((out.Volume() != 2*fatGauge->VolumeCB() && out.SiteSubset() == QUDA_FULL_SITE_SUBSET) ||
00048       (out.Volume() != fatGauge->VolumeCB() && out.SiteSubset() == QUDA_PARITY_SITE_SUBSET) ) {
00049     errorQuda("Spinor volume %d doesn't match gauge volume %d", out.Volume(), fatGauge->VolumeCB());
00050   }
00051 }
00052 
00053 
00054 void DiracStaggered::Dslash(cudaColorSpinorField &out, const cudaColorSpinorField &in, 
00055                          const QudaParity parity) const
00056 {
00057   if (!initDslash) {
00058     initDslashConstants(*fatGauge, in.Stride());
00059     initStaggeredConstants(*fatGauge, *longGauge);
00060   }
00061   checkParitySpinor(in, out);
00062 
00063   setFace(face); // FIXME: temporary hack maintain C linkage for dslashCuda
00064   staggeredDslashCuda(&out, *fatGauge, *longGauge, &in, parity, dagger, 0, 0, commDim);
00065   
00066   flops += 1146*in.Volume();
00067 }
00068 
00069 void DiracStaggered::DslashXpay(cudaColorSpinorField &out, const cudaColorSpinorField &in, 
00070                                 const QudaParity parity, const cudaColorSpinorField &x,
00071                                 const double &k) const
00072 {    
00073   if (!initDslash){
00074     initDslashConstants(*fatGauge, in.Stride());
00075     initStaggeredConstants(*fatGauge, *longGauge);
00076   }
00077   checkParitySpinor(in, out);
00078 
00079   setFace(face); // FIXME: temporary hack maintain C linkage for dslashCuda
00080   staggeredDslashCuda(&out, *fatGauge, *longGauge, &in, parity, dagger, &x, k, commDim);
00081   
00082   flops += (1146+12)*in.Volume();
00083 }
00084 
00085 // Full staggered operator
00086 void DiracStaggered::M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
00087 {
00088   if (!initDslash){
00089     initDslashConstants(*fatGauge, in.Stride());
00090     initStaggeredConstants(*fatGauge, *longGauge);
00091   }
00092 
00093   bool reset = newTmp(&tmp1, in.Even());
00094 
00095   DslashXpay(out.Even(), in.Odd(), QUDA_EVEN_PARITY, *tmp1, 2*mass);  
00096   DslashXpay(out.Odd(), in.Even(), QUDA_ODD_PARITY, *tmp1, 2*mass);
00097   
00098   deleteTmp(&tmp1, reset);
00099 }
00100 
00101 void DiracStaggered::MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
00102 {
00103 
00104   if (!initDslash){
00105     initDslashConstants(*fatGauge, in.Stride());
00106     initStaggeredConstants(*fatGauge, *longGauge);
00107   }
00108   
00109   bool reset = newTmp(&tmp1, in);
00110   
00111   cudaColorSpinorField* mytmp = dynamic_cast<cudaColorSpinorField*>(&(tmp1->Even()));
00112   cudaColorSpinorField* ineven = dynamic_cast<cudaColorSpinorField*>(&(in.Even()));
00113   cudaColorSpinorField* inodd = dynamic_cast<cudaColorSpinorField*>(&(in.Odd()));
00114   cudaColorSpinorField* outeven = dynamic_cast<cudaColorSpinorField*>(&(out.Even()));
00115   cudaColorSpinorField* outodd = dynamic_cast<cudaColorSpinorField*>(&(out.Odd()));
00116   
00117   //even
00118   Dslash(*mytmp, *ineven, QUDA_ODD_PARITY);  
00119   DslashXpay(*outeven, *mytmp, QUDA_EVEN_PARITY, *ineven, 4*mass*mass);
00120   
00121   //odd
00122   Dslash(*mytmp, *inodd, QUDA_EVEN_PARITY);  
00123   DslashXpay(*outodd, *mytmp, QUDA_ODD_PARITY, *inodd, 4*mass*mass);    
00124 
00125   deleteTmp(&tmp1, reset);
00126 }
00127 
00128 void DiracStaggered::prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol,
00129                              cudaColorSpinorField &x, cudaColorSpinorField &b, 
00130                              const QudaSolutionType solType) const
00131 {
00132   if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
00133     errorQuda("Preconditioned solution requires a preconditioned solve_type");
00134   }
00135 
00136   src = &b;
00137   sol = &x;  
00138 }
00139 
00140 void DiracStaggered::reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b,
00141                                  const QudaSolutionType solType) const
00142 {
00143   // do nothing
00144 }
00145 
00146 
00147 DiracStaggeredPC::DiracStaggeredPC(const DiracParam &param)
00148   : DiracStaggered(param)
00149 {
00150 
00151 }
00152 
00153 DiracStaggeredPC::DiracStaggeredPC(const DiracStaggeredPC &dirac) 
00154   : DiracStaggered(dirac)
00155 {
00156 
00157 }
00158 
00159 DiracStaggeredPC::~DiracStaggeredPC()
00160 {
00161 
00162 }
00163 
00164 DiracStaggeredPC& DiracStaggeredPC::operator=(const DiracStaggeredPC &dirac)
00165 {
00166   if (&dirac != this) {
00167     DiracStaggered::operator=(dirac);
00168   }
00169  
00170   return *this;
00171 }
00172 
00173 void DiracStaggeredPC::M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
00174 {
00175   errorQuda("DiracStaggeredPC::M() is not implemented\n");
00176 }
00177 
00178 void DiracStaggeredPC::MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
00179 {
00180   if (!initDslash){
00181     initDslashConstants(*fatGauge, in.Stride());
00182     initStaggeredConstants(*fatGauge, *longGauge);
00183   }
00184   
00185   bool reset = newTmp(&tmp1, in);
00186   
00187   QudaParity parity = QUDA_INVALID_PARITY;
00188   QudaParity other_parity = QUDA_INVALID_PARITY;
00189   if (matpcType == QUDA_MATPC_EVEN_EVEN) {
00190     parity = QUDA_EVEN_PARITY;
00191     other_parity = QUDA_ODD_PARITY;
00192   } else if (matpcType == QUDA_MATPC_ODD_ODD) {
00193     parity = QUDA_ODD_PARITY;
00194     other_parity = QUDA_EVEN_PARITY;
00195   } else {
00196     errorQuda("Invalid matpcType(%d) in function\n", matpcType);    
00197   }
00198   Dslash(*tmp1, in, other_parity);  
00199   DslashXpay(out, *tmp1, parity, in, 4*mass*mass);
00200 
00201   deleteTmp(&tmp1, reset);
00202 }
00203 
00204 void DiracStaggeredPC::prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol,
00205                                cudaColorSpinorField &x, cudaColorSpinorField &b, 
00206                                const QudaSolutionType solType) const
00207 {
00208   src = &b;
00209   sol = &x;  
00210 }
00211 
00212 void DiracStaggeredPC::reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b,
00213                                    const QudaSolutionType solType) const
00214 {
00215   // do nothing
00216 }
00217 
00218 
00219 
00220 
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines