QUDA v0.4.0
A library for QCD on GPUs
quda/lib/dirac_clover.cpp
Go to the documentation of this file.
00001 #include <iostream>
00002 #include <dirac_quda.h>
00003 #include <blas_quda.h>
00004 
00005 DiracClover::DiracClover(const DiracParam &param)
00006   : DiracWilson(param), clover(*(param.clover)) { }
00007 
00008 DiracClover::DiracClover(const DiracClover &dirac) 
00009   : DiracWilson(dirac), clover(dirac.clover) { }
00010 
00011 DiracClover::~DiracClover()
00012 {
00013 
00014 }
00015 
00016 DiracClover& DiracClover::operator=(const DiracClover &dirac)
00017 {
00018   if (&dirac != this) {
00019     DiracWilson::operator=(dirac);
00020     clover = dirac.clover;
00021   }
00022   return *this;
00023 }
00024 
00025 void DiracClover::checkParitySpinor(const cudaColorSpinorField &out, const cudaColorSpinorField &in) const
00026 {
00027   Dirac::checkParitySpinor(out, in);
00028 
00029   if (out.Volume() != clover.VolumeCB()) {
00030     errorQuda("Parity spinor volume %d doesn't match clover checkboard volume %d",
00031               out.Volume(), clover.VolumeCB());
00032   }
00033 }
00034 
00035 // Public method to apply the clover term only
00036 void DiracClover::Clover(cudaColorSpinorField &out, const cudaColorSpinorField &in, const QudaParity parity) const
00037 {
00038   if (!initDslash) initDslashConstants(gauge, in.Stride());
00039   if (!initClover) initCloverConstants(clover.Stride());
00040   checkParitySpinor(in, out);
00041 
00042   // regular clover term
00043   FullClover cs;
00044   cs.even = clover.even; cs.odd = clover.odd; cs.evenNorm = clover.evenNorm; cs.oddNorm = clover.oddNorm;
00045   cs.precision = clover.precision; cs.bytes = clover.bytes, cs.norm_bytes = clover.norm_bytes;
00046   cloverCuda(&out, gauge, cs, &in, parity);
00047 
00048   flops += 504*in.Volume();
00049 }
00050 
00051 // FIXME: create kernel to eliminate tmp
00052 void DiracClover::M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
00053 {
00054   checkFullSpinor(out, in);
00055   cudaColorSpinorField *tmp=0; // this hack allows for tmp2 to be full or parity field
00056   if (tmp2) {
00057     if (tmp2->SiteSubset() == QUDA_FULL_SITE_SUBSET) tmp = &(tmp2->Even());
00058     else tmp = tmp2;
00059   }
00060   bool reset = newTmp(&tmp, in.Even());
00061 
00062   Clover(*tmp, in.Odd(), QUDA_ODD_PARITY);
00063   DslashXpay(out.Odd(), in.Even(), QUDA_ODD_PARITY, *tmp, -kappa);
00064   Clover(*tmp, in.Even(), QUDA_EVEN_PARITY);
00065   DslashXpay(out.Even(), in.Odd(), QUDA_EVEN_PARITY, *tmp, -kappa);
00066 
00067   deleteTmp(&tmp, reset);
00068 }
00069 
00070 void DiracClover::MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
00071 {
00072   checkFullSpinor(out, in);
00073 
00074   bool reset = newTmp(&tmp1, in);
00075   checkFullSpinor(*tmp1, in);
00076 
00077   M(*tmp1, in);
00078   Mdag(out, *tmp1);
00079 
00080   deleteTmp(&tmp1, reset);
00081 }
00082 
00083 void DiracClover::prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol,
00084                           cudaColorSpinorField &x, cudaColorSpinorField &b, 
00085                           const QudaSolutionType solType) const
00086 {
00087   if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
00088     errorQuda("Preconditioned solution requires a preconditioned solve_type");
00089   }
00090 
00091   src = &b;
00092   sol = &x;
00093 }
00094 
00095 void DiracClover::reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b,
00096                               const QudaSolutionType solType) const
00097 {
00098   // do nothing
00099 }
00100 
00101 DiracCloverPC::DiracCloverPC(const DiracParam &param) : 
00102   DiracClover(param)
00103 {
00104   // For the preconditioned operator, we need to check that the inverse of the clover term is present
00105   if (!clover.cloverInv) errorQuda("Clover inverse required for DiracCloverPC");
00106 }
00107 
00108 DiracCloverPC::DiracCloverPC(const DiracCloverPC &dirac) : DiracClover(dirac) { }
00109 
00110 DiracCloverPC::~DiracCloverPC() { }
00111 
00112 DiracCloverPC& DiracCloverPC::operator=(const DiracCloverPC &dirac)
00113 {
00114   if (&dirac != this) {
00115     DiracClover::operator=(dirac);
00116   }
00117   return *this;
00118 }
00119 
00120 // Public method
00121 void DiracCloverPC::CloverInv(cudaColorSpinorField &out, const cudaColorSpinorField &in, 
00122                               const QudaParity parity) const
00123 {
00124   if (!initDslash) initDslashConstants(gauge, in.Stride());
00125   if (!initClover) initCloverConstants(clover.Stride());
00126   checkParitySpinor(in, out);
00127 
00128   // needs to be cloverinv
00129   FullClover cs;
00130   cs.even = clover.evenInv; cs.odd = clover.oddInv; cs.evenNorm = clover.evenInvNorm; cs.oddNorm = clover.oddInvNorm;
00131   cs.precision = clover.precision; cs.bytes = clover.bytes, cs.norm_bytes = clover.norm_bytes;
00132   cloverCuda(&out, gauge, cs, &in, parity);
00133 
00134   flops += 504*in.Volume();
00135 }
00136 
00137 // apply hopping term, then clover: (A_ee^-1 D_eo) or (A_oo^-1 D_oe),
00138 // and likewise for dagger: (A_ee^-1 D^dagger_eo) or (A_oo^-1 D^dagger_oe)
00139 // NOTE - this isn't Dslash dagger since order should be reversed!
00140 void DiracCloverPC::Dslash(cudaColorSpinorField &out, const cudaColorSpinorField &in, 
00141                            const QudaParity parity) const
00142 {
00143   if (!initDslash) initDslashConstants(gauge, in.Stride());
00144   if (!initClover) initCloverConstants(clover.Stride());
00145   checkParitySpinor(in, out);
00146   checkSpinorAlias(in, out);
00147 
00148   setFace(face); // FIXME: temporary hack maintain C linkage for dslashCuda
00149 
00150   FullClover cs;
00151   cs.even = clover.evenInv; cs.odd = clover.oddInv; cs.evenNorm = clover.evenInvNorm; cs.oddNorm = clover.oddInvNorm;
00152   cs.precision = clover.precision; cs.bytes = clover.bytes, cs.norm_bytes = clover.norm_bytes;
00153   cloverDslashCuda(&out, gauge, cs, &in, parity, dagger, 0, 0.0, commDim);
00154 
00155   flops += (1320+504)*in.Volume();
00156 }
00157 
00158 // xpay version of the above
00159 void DiracCloverPC::DslashXpay(cudaColorSpinorField &out, const cudaColorSpinorField &in, 
00160                                const QudaParity parity, const cudaColorSpinorField &x,
00161                                const double &k) const
00162 {
00163   if (!initDslash) initDslashConstants(gauge, in.Stride());
00164   if (!initClover) initCloverConstants(clover.Stride());
00165   checkParitySpinor(in, out);
00166   checkSpinorAlias(in, out);
00167 
00168   setFace(face); // FIXME: temporary hack maintain C linkage for dslashCuda
00169 
00170   FullClover cs;
00171   cs.even = clover.evenInv; cs.odd = clover.oddInv; cs.evenNorm = clover.evenInvNorm; cs.oddNorm = clover.oddInvNorm;
00172   cs.precision = clover.precision; cs.bytes = clover.bytes, cs.norm_bytes = clover.norm_bytes;
00173   cloverDslashCuda(&out, gauge, cs, &in, parity, dagger, &x, k, commDim);
00174 
00175   flops += (1320+504+48)*in.Volume();
00176 }
00177 
00178 // Apply the even-odd preconditioned clover-improved Dirac operator
00179 void DiracCloverPC::M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
00180 {
00181   double kappa2 = -kappa*kappa;
00182 
00183   // FIXME: For asymmetric, a "DslashCxpay" kernel would improve performance.
00184   bool reset1 = newTmp(&tmp1, in);
00185 
00186   if (matpcType == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC) {
00187     bool reset2 = newTmp(&tmp2, in);
00188     // DiracCloverPC::Dslash applies A^{-1}Dslash
00189     Dslash(*tmp1, in, QUDA_ODD_PARITY);
00190     Clover(*tmp2, in, QUDA_EVEN_PARITY);
00191 
00192     // DiracWilson::Dslash applies only Dslash
00193     DiracWilson::DslashXpay(out, *tmp1, QUDA_EVEN_PARITY, *tmp2, kappa2); 
00194 
00195     deleteTmp(&tmp2, reset2);
00196   } else if (matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC) {
00197 
00198     // FIXME: It would be nice if I could do something like: cudaColorSpinorField tmp3( in.param() );
00199     // to save copying the data from 'in'
00200     bool reset2 = newTmp(&tmp2, in);
00201 
00202     // DiracCloverPC::Dslash applies A^{-1}Dslash
00203     Dslash(*tmp1, in, QUDA_EVEN_PARITY);
00204     Clover(*tmp2, in, QUDA_ODD_PARITY);
00205 
00206     // DiracWilson::Dslash applies only Dslash
00207     DiracWilson::DslashXpay(out, *tmp1, QUDA_ODD_PARITY, *tmp2, kappa2);
00208 
00209     deleteTmp(&tmp2, reset2);
00210   } else if (!dagger) { // symmetric preconditioning
00211     if (matpcType == QUDA_MATPC_EVEN_EVEN) {
00212       Dslash(*tmp1, in, QUDA_ODD_PARITY);
00213       DslashXpay(out, *tmp1, QUDA_EVEN_PARITY, in, kappa2); 
00214     } else if (matpcType == QUDA_MATPC_ODD_ODD) {
00215       Dslash(*tmp1, in, QUDA_EVEN_PARITY);
00216       DslashXpay(out, *tmp1, QUDA_ODD_PARITY, in, kappa2); 
00217     } else {
00218       errorQuda("Invalid matpcType");
00219     }
00220   } else { // symmetric preconditioning, dagger
00221     if (matpcType == QUDA_MATPC_EVEN_EVEN) {
00222       CloverInv(out, in, QUDA_EVEN_PARITY); 
00223       Dslash(*tmp1, out, QUDA_ODD_PARITY);
00224       DiracWilson::DslashXpay(out, *tmp1, QUDA_EVEN_PARITY, in, kappa2); 
00225     } else if (matpcType == QUDA_MATPC_ODD_ODD) {
00226       CloverInv(out, in, QUDA_ODD_PARITY); 
00227       Dslash(*tmp1, out, QUDA_EVEN_PARITY);
00228       DiracWilson::DslashXpay(out, *tmp1, QUDA_ODD_PARITY, in, kappa2); 
00229     } else {
00230       errorQuda("MatPCType %d not valid for DiracCloverPC", matpcType);
00231     }
00232   }
00233   
00234   deleteTmp(&tmp1, reset1);
00235 }
00236 
00237 void DiracCloverPC::MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
00238 {
00239   // need extra temporary because of symmetric preconditioning dagger
00240   bool reset = newTmp(&tmp2, in);
00241 
00242   M(*tmp2, in);
00243   Mdag(out, *tmp2);
00244 
00245   deleteTmp(&tmp2, reset);
00246 }
00247 
00248 void DiracCloverPC::prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol, 
00249                             cudaColorSpinorField &x, cudaColorSpinorField &b, 
00250                             const QudaSolutionType solType) const
00251 {
00252   // we desire solution to preconditioned system
00253   if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
00254     src = &b;
00255     sol = &x;
00256     return;
00257   }
00258 
00259   bool reset = newTmp(&tmp1, b.Even());
00260   
00261   // we desire solution to full system
00262   if (matpcType == QUDA_MATPC_EVEN_EVEN) {
00263     // src = A_ee^-1 (b_e + k D_eo A_oo^-1 b_o)
00264     src = &(x.Odd());
00265     CloverInv(*src, b.Odd(), QUDA_ODD_PARITY);
00266     DiracWilson::DslashXpay(*tmp1, *src, QUDA_EVEN_PARITY, b.Even(), kappa);
00267     CloverInv(*src, *tmp1, QUDA_EVEN_PARITY);
00268     sol = &(x.Even());
00269   } else if (matpcType == QUDA_MATPC_ODD_ODD) {
00270     // src = A_oo^-1 (b_o + k D_oe A_ee^-1 b_e)
00271     src = &(x.Even());
00272     CloverInv(*src, b.Even(), QUDA_EVEN_PARITY);
00273     DiracWilson::DslashXpay(*tmp1, *src, QUDA_ODD_PARITY, b.Odd(), kappa);
00274     CloverInv(*src, *tmp1, QUDA_ODD_PARITY);
00275     sol = &(x.Odd());
00276   } else if (matpcType == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC) {
00277     // src = b_e + k D_eo A_oo^-1 b_o
00278     src = &(x.Odd());
00279     CloverInv(*tmp1, b.Odd(), QUDA_ODD_PARITY); // safe even when *tmp1 = b.odd
00280     DiracWilson::DslashXpay(*src, *tmp1, QUDA_EVEN_PARITY, b.Even(), kappa);
00281     sol = &(x.Even());
00282   } else if (matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC) {
00283     // src = b_o + k D_oe A_ee^-1 b_e
00284     src = &(x.Even());
00285     CloverInv(*tmp1, b.Even(), QUDA_EVEN_PARITY); // safe even when *tmp1 = b.even
00286     DiracWilson::DslashXpay(*src, *tmp1, QUDA_ODD_PARITY, b.Odd(), kappa);
00287     sol = &(x.Odd());
00288   } else {
00289     errorQuda("MatPCType %d not valid for DiracCloverPC", matpcType);
00290   }
00291 
00292   // here we use final solution to store parity solution and parity source
00293   // b is now up for grabs if we want
00294 
00295   deleteTmp(&tmp1, reset);
00296 
00297 }
00298 
00299 void DiracCloverPC::reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b,
00300                                 const QudaSolutionType solType) const
00301 {
00302   if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
00303     return;
00304   }
00305 
00306   checkFullSpinor(x, b);
00307 
00308   bool reset = newTmp(&tmp1, b.Even());
00309 
00310   // create full solution
00311 
00312   if (matpcType == QUDA_MATPC_EVEN_EVEN ||
00313       matpcType == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC) {
00314     // x_o = A_oo^-1 (b_o + k D_oe x_e)
00315     DiracWilson::DslashXpay(*tmp1, x.Even(), QUDA_ODD_PARITY, b.Odd(), kappa);
00316     CloverInv(x.Odd(), *tmp1, QUDA_ODD_PARITY);
00317   } else if (matpcType == QUDA_MATPC_ODD_ODD ||
00318       matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC) {
00319     // x_e = A_ee^-1 (b_e + k D_eo x_o)
00320     DiracWilson::DslashXpay(*tmp1, x.Odd(), QUDA_EVEN_PARITY, b.Even(), kappa);
00321     CloverInv(x.Even(), *tmp1, QUDA_EVEN_PARITY);
00322   } else {
00323     errorQuda("MatPCType %d not valid for DiracCloverPC", matpcType);
00324   }
00325 
00326   deleteTmp(&tmp1, reset);
00327 
00328 }
00329 
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines