QUDA v0.3.2
A library for QCD on GPUs

quda/lib/dirac_staggered.cpp

Go to the documentation of this file.
00001 #include <dirac_quda.h>
00002 #include <blas_quda.h>
00003 
00004 DiracStaggeredPC::DiracStaggeredPC(const DiracParam &param)
00005   : Dirac(param), fatGauge(param.fatGauge), longGauge(param.longGauge)
00006 {
00007 
00008 }
00009 
00010 DiracStaggeredPC::DiracStaggeredPC(const DiracStaggeredPC &dirac) 
00011   : Dirac(dirac)
00012 {
00013 
00014 }
00015 
00016 DiracStaggeredPC::~DiracStaggeredPC()
00017 {
00018 
00019 }
00020 
00021 DiracStaggeredPC& DiracStaggeredPC::operator=(const DiracStaggeredPC &dirac)
00022 {
00023   if (&dirac != this) {
00024     Dirac::operator=(dirac);
00025   }
00026  
00027   return *this;
00028 }
00029 
00030 
00031 
00032 void DiracStaggeredPC::checkParitySpinor(const cudaColorSpinorField &in, const cudaColorSpinorField &out) const
00033 {
00034 
00035   if (in.Precision() != out.Precision()) {
00036     errorQuda("Input and output spinor precisions don't match in dslash_quda");
00037   }
00038 
00039   if (in.Stride() != out.Stride()) {
00040     errorQuda("Input %d and output %d spinor strides don't match in dslash_quda", in.Stride(), out.Stride());
00041   }
00042 
00043   if (in.SiteSubset() != QUDA_PARITY_SITE_SUBSET || out.SiteSubset() != QUDA_PARITY_SITE_SUBSET) {
00044     errorQuda("ColorSpinorFields are not single parity, in = %d, out = %d", 
00045               in.SiteSubset(), out.SiteSubset());
00046   }
00047 
00048   if ((out.Volume() != 2*fatGauge->volume && out.SiteSubset() == QUDA_FULL_SITE_SUBSET) ||
00049       (out.Volume() != fatGauge->volume && out.SiteSubset() == QUDA_PARITY_SITE_SUBSET) ) {
00050       errorQuda("Spinor volume %d doesn't match gauge volume %d", out.Volume(), fatGauge->volume);
00051   }
00052 
00053 }
00054 
00055 
00056 void DiracStaggeredPC::Dslash(cudaColorSpinorField &out, const cudaColorSpinorField &in, 
00057                               const QudaParity parity) const
00058 {
00059   if (!initDslash) {
00060     initDslashConstants(*fatGauge, in.Stride(), 0);
00061   }
00062   checkParitySpinor(in, out);
00063     
00064   staggeredDslashCuda(out.v, out.norm, *fatGauge, *longGauge, in.v, in.norm, parity, dagger, 
00065                       0, 0, 0, out.volume, out.length, in.Precision());
00066     
00067   flops += 1146*in.volume;
00068 }
00069 
00070 void DiracStaggeredPC::DslashXpay(cudaColorSpinorField &out, const cudaColorSpinorField &in, 
00071                                   const QudaParity parity, const cudaColorSpinorField &x,
00072                                   const double &k) const
00073 {    
00074   if (!initDslash){
00075     initDslashConstants(*fatGauge, in.Stride(), 0);
00076   }
00077   checkParitySpinor(in, out);
00078   
00079   staggeredDslashCuda(out.v, out.norm, *fatGauge, *longGauge, in.v, in.norm, parity, dagger, x.v, x.norm, k, 
00080                       out.volume, out.length, in.Precision());
00081     
00082   flops += (1146+12)*in.volume;
00083 }
00084 
00085 void DiracStaggeredPC::M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
00086 {
00087   errorQuda("DiracStaggeredPC::M() is not implemented\n");
00088 }
00089 
00090 void DiracStaggeredPC::MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
00091 {
00092   if (!initDslash){
00093     initDslashConstants(*fatGauge, in.Stride(), 0);
00094   }
00095   
00096   bool reset = newTmp(&tmp1, in);
00097   
00098   QudaParity parity;
00099   QudaParity other_parity;
00100   if (matpcType == QUDA_MATPC_EVEN_EVEN) {
00101     parity = QUDA_EVEN_PARITY;
00102     other_parity = QUDA_ODD_PARITY;
00103   } else if (matpcType == QUDA_MATPC_ODD_ODD) {
00104     parity = QUDA_ODD_PARITY;
00105     other_parity = QUDA_EVEN_PARITY;
00106   } else {
00107     errorQuda("Invalid matpcType(%d) in function\n", matpcType);    
00108   }
00109   
00110   Dslash(*tmp1, in, other_parity);  
00111   DslashXpay(out, *tmp1, parity, in, 4*mass*mass);
00112 
00113   deleteTmp(&tmp1, reset);
00114 }
00115 
00116 void DiracStaggeredPC::prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol,
00117                                cudaColorSpinorField &x, cudaColorSpinorField &b, 
00118                                const QudaSolutionType solType) const
00119 {
00120   src = &b;
00121   sol = &x;  
00122 }
00123 
00124 void DiracStaggeredPC::reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b,
00125                                    const QudaSolutionType solType) const
00126 {
00127   // do nothing
00128 }
00129 
00130 
00131 
00132 
00133 DiracStaggered::DiracStaggered(const DiracParam &param)
00134   : Dirac(param), fatGauge(param.fatGauge), longGauge(param.longGauge)
00135 {
00136 
00137 }
00138 
00139 DiracStaggered::DiracStaggered(const DiracStaggered &dirac) 
00140   : Dirac(dirac)
00141 {
00142 
00143 }
00144 
00145 DiracStaggered::~DiracStaggered()
00146 {
00147 
00148 }
00149 
00150 DiracStaggered& DiracStaggered::operator=(const DiracStaggered &dirac)
00151 {
00152   if (&dirac != this) {
00153     Dirac::operator=(dirac);
00154   }
00155  
00156   return *this;
00157 }
00158 
00159 
00160 
00161 void DiracStaggered::checkParitySpinor(const cudaColorSpinorField &in, const cudaColorSpinorField &out) const
00162 {
00163   if (in.Precision() != out.Precision()) {
00164     errorQuda("Input and output spinor precisions don't match in dslash_quda");
00165   }
00166 
00167   if (in.Stride() != out.Stride()) {
00168     errorQuda("Input %d and output %d spinor strides don't match in dslash_quda", in.Stride(), out.Stride());
00169   }
00170 
00171   if (in.SiteSubset() != QUDA_PARITY_SITE_SUBSET || out.SiteSubset() != QUDA_PARITY_SITE_SUBSET) {
00172     errorQuda("ColorSpinorFields are not single parity, in = %d, out = %d", 
00173               in.SiteSubset(), out.SiteSubset());
00174   }
00175 
00176   if ((out.Volume() != 2*fatGauge->volume && out.SiteSubset() == QUDA_FULL_SITE_SUBSET) ||
00177       (out.Volume() != fatGauge->volume && out.SiteSubset() == QUDA_PARITY_SITE_SUBSET) ) {
00178       errorQuda("Spinor volume %d doesn't match gauge volume %d", out.Volume(), fatGauge->volume);
00179   }
00180 }
00181 
00182 
00183 void DiracStaggered::Dslash(cudaColorSpinorField &out, const cudaColorSpinorField &in, 
00184                          const QudaParity parity) const
00185 {
00186   if (!initDslash) {
00187     initDslashConstants(*fatGauge, in.Stride(), 0);
00188   }
00189   checkParitySpinor(in, out);
00190     
00191   staggeredDslashCuda(out.v, out.norm, *fatGauge, *longGauge, in.v, in.norm, parity, dagger, 
00192                       0, 0, 0, out.volume, out.length, in.Precision());
00193     
00194   flops += 1187*in.volume;
00195 }
00196 
00197 void DiracStaggered::DslashXpay(cudaColorSpinorField &out, const cudaColorSpinorField &in, 
00198                                 const QudaParity parity, const cudaColorSpinorField &x,
00199                                 const double &k) const
00200 {    
00201   if (!initDslash){
00202     initDslashConstants(*fatGauge, in.Stride(), 0);
00203   }
00204   checkParitySpinor(in, out);
00205   
00206   staggeredDslashCuda(out.v, out.norm, *fatGauge, *longGauge, in.v, in.norm, parity, dagger, x.v, x.norm, k, 
00207                       out.volume, out.length, in.Precision());
00208   
00209   flops += (1187+12)*in.volume;
00210 }
00211 
00212 void DiracStaggered::M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
00213 {
00214   errorQuda("DiracStaggered::M() is not implemented");  
00215 }
00216 
00217 void DiracStaggered::MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
00218 {
00219   
00220   if (!initDslash){
00221     initDslashConstants(*fatGauge, in.Stride(), 0);
00222   }
00223   
00224   bool reset = newTmp(&tmp1, in);
00225   
00226   cudaColorSpinorField* mytmp = dynamic_cast<cudaColorSpinorField*>(tmp1->even);
00227   cudaColorSpinorField* ineven = dynamic_cast<cudaColorSpinorField*>(in.even);
00228   cudaColorSpinorField* inodd = dynamic_cast<cudaColorSpinorField*>(in.odd);
00229   cudaColorSpinorField* outeven = dynamic_cast<cudaColorSpinorField*>(out.even);
00230   cudaColorSpinorField* outodd = dynamic_cast<cudaColorSpinorField*>(out.odd);
00231   
00232   //even
00233   Dslash(*mytmp, *ineven, QUDA_ODD_PARITY);  
00234   DslashXpay(*outeven, *mytmp, QUDA_EVEN_PARITY, *ineven, 4*mass*mass);
00235   
00236   //odd
00237   Dslash(*mytmp, *inodd, QUDA_EVEN_PARITY);  
00238   DslashXpay(*outodd, *mytmp, QUDA_ODD_PARITY, *inodd, 4*mass*mass);    
00239 
00240   deleteTmp(&tmp1, reset);
00241 }
00242 
00243 void DiracStaggered::prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol,
00244                              cudaColorSpinorField &x, cudaColorSpinorField &b, 
00245                              const QudaSolutionType solType) const
00246 {
00247   if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
00248     errorQuda("Preconditioned solution requires a preconditioned solve_type");
00249   }
00250 
00251   src = &b;
00252   sol = &x;  
00253 }
00254 
00255 void DiracStaggered::reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b,
00256                                  const QudaSolutionType solType) const
00257 {
00258   // do nothing
00259 }
 All Classes Files Functions Variables Typedefs Enumerations Enumerator Friends Defines