QUDA v0.3.2
A library for QCD on GPUs

quda/lib/dirac_clover.cpp

Go to the documentation of this file.
00001 #include <iostream>
00002 #include <dirac_quda.h>
00003 #include <blas_quda.h>
00004 
00005 DiracClover::DiracClover(const DiracParam &param)
00006   : DiracWilson(param), clover(*(param.clover))
00007 {
00008 
00009 }
00010 
00011 DiracClover::DiracClover(const DiracClover &dirac) 
00012   : DiracWilson(dirac), clover(dirac.clover)
00013 {
00014 
00015 }
00016 
00017 DiracClover::~DiracClover()
00018 {
00019 
00020 }
00021 
00022 DiracClover& DiracClover::operator=(const DiracClover &dirac)
00023 {
00024 
00025   if (&dirac != this) {
00026     DiracWilson::operator=(dirac);
00027     clover = dirac.clover;
00028   }
00029 
00030   return *this;
00031 }
00032 
00033 void DiracClover::checkParitySpinor(const cudaColorSpinorField &out, const cudaColorSpinorField &in,
00034                                     const FullClover &clover) const
00035 {
00036   Dirac::checkParitySpinor(out, in);
00037 
00038   if (out.Volume() != clover.even.volume) {
00039     errorQuda("Spinor volume %d doesn't match even clover volume %d",
00040               out.Volume(), clover.even.volume);
00041   }
00042   if (out.Volume() != clover.odd.volume) {
00043     errorQuda("Spinor volume %d doesn't match odd clover volume %d",
00044               out.Volume(), clover.odd.volume);
00045   }
00046 
00047 }
00048 
00049 // Protected method, also used for applying cloverInv
00050 void DiracClover::cloverApply(cudaColorSpinorField &out, const FullClover &clover, 
00051                               const cudaColorSpinorField &in, const QudaParity parity) const
00052 {
00053   if (!initDslash) initDslashConstants(gauge, in.stride, clover.even.stride);
00054   checkParitySpinor(in, out, clover);
00055 
00056   cloverCuda(out.v, out.norm, gauge, clover, in.v, in.norm, parity, 
00057              in.volume, in.length, in.Precision());
00058 
00059   flops += 504*in.volume;
00060 }
00061 
00062 // Public method to apply the clover term only
00063 void DiracClover::Clover(cudaColorSpinorField &out, const cudaColorSpinorField &in, const QudaParity parity) const
00064 {
00065   cloverApply(out, clover, in, parity);
00066 }
00067 
00068 // FIXME: create kernel to eliminate tmp
00069 void DiracClover::M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
00070 {
00071   checkFullSpinor(out, in);
00072   cudaColorSpinorField *tmp=0; // this hack allows for tmp2 to be full or parity field
00073   if (tmp2) {
00074     if (tmp2->SiteSubset() == QUDA_FULL_SITE_SUBSET) tmp = &(tmp2->Even());
00075     else tmp = tmp2;
00076   }
00077   bool reset = newTmp(&tmp, in.Even());
00078 
00079   Clover(*tmp, in.Odd(), QUDA_ODD_PARITY);
00080   DslashXpay(out.Odd(), in.Even(), QUDA_ODD_PARITY, *tmp, -kappa);
00081   Clover(*tmp, in.Even(), QUDA_EVEN_PARITY);
00082   DslashXpay(out.Even(), in.Odd(), QUDA_EVEN_PARITY, *tmp, -kappa);
00083 
00084   deleteTmp(&tmp, reset);
00085 }
00086 
00087 void DiracClover::MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
00088 {
00089   checkFullSpinor(out, in);
00090 
00091   bool reset = newTmp(&tmp1, in);
00092   checkFullSpinor(*tmp1, in);
00093 
00094   M(*tmp1, in);
00095   Mdag(out, *tmp1);
00096 
00097   deleteTmp(&tmp1, reset);
00098 }
00099 
00100 void DiracClover::prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol,
00101                           cudaColorSpinorField &x, cudaColorSpinorField &b, 
00102                           const QudaSolutionType solType) const
00103 {
00104   if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
00105     errorQuda("Preconditioned solution requires a preconditioned solve_type");
00106   }
00107 
00108   src = &b;
00109   sol = &x;
00110 }
00111 
00112 void DiracClover::reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b,
00113                               const QudaSolutionType solType) const
00114 {
00115   // do nothing
00116 }
00117 
00118 DiracCloverPC::DiracCloverPC(const DiracParam &param)
00119   : DiracClover(param), cloverInv(*(param.cloverInv))
00120 {
00121 
00122 }
00123 
00124 DiracCloverPC::DiracCloverPC(const DiracCloverPC &dirac) 
00125   : DiracClover(dirac), cloverInv(dirac.clover)
00126 {
00127 
00128 }
00129 
00130 DiracCloverPC::~DiracCloverPC()
00131 {
00132 
00133 }
00134 
00135 DiracCloverPC& DiracCloverPC::operator=(const DiracCloverPC &dirac)
00136 {
00137   if (&dirac != this) {
00138     DiracClover::operator=(dirac);
00139     cloverInv = dirac.cloverInv;
00140   }
00141 
00142   return *this;
00143 }
00144 
00145 // Public method
00146 void DiracCloverPC::CloverInv(cudaColorSpinorField &out, const cudaColorSpinorField &in, 
00147                               const QudaParity parity) const
00148 {
00149   cloverApply(out, cloverInv, in, parity);
00150 }
00151 
00152 // apply hopping term, then clover: (A_ee^-1 D_eo) or (A_oo^-1 D_oe),
00153 // and likewise for dagger: (A_ee^-1 D^dagger_eo) or (A_oo^-1 D^dagger_oe)
00154 // NOTE - this isn't Dslash dagger since order should be reversed!
00155 void DiracCloverPC::Dslash(cudaColorSpinorField &out, const cudaColorSpinorField &in, 
00156                            const QudaParity parity) const
00157 {
00158   if (!initDslash) initDslashConstants(gauge, in.stride, cloverInv.even.stride);
00159   checkParitySpinor(in, out, cloverInv);
00160   checkSpinorAlias(in, out);
00161 
00162   cloverDslashCuda(out.v, out.norm, gauge, cloverInv, in.v, in.norm, parity, dagger, 
00163                    0, 0, 0.0, out.volume, out.length, in.Precision());
00164 
00165   flops += (1320+504)*in.volume;
00166 }
00167 
00168 // xpay version of the above
00169 void DiracCloverPC::DslashXpay(cudaColorSpinorField &out, const cudaColorSpinorField &in, 
00170                                const QudaParity parity, const cudaColorSpinorField &x,
00171                                const double &k) const
00172 {
00173   if (!initDslash) initDslashConstants(gauge, in.stride, cloverInv.even.stride);
00174   checkParitySpinor(in, out, cloverInv);
00175   checkSpinorAlias(in, out);
00176 
00177   cloverDslashCuda(out.v, out.norm, gauge, cloverInv, in.v, in.norm, parity, dagger, 
00178                    x.v, x.norm, k, out.volume, out.length, in.Precision());
00179 
00180   flops += (1320+504+48)*in.volume;
00181 }
00182 
00183 // Apply the even-odd preconditioned clover-improved Dirac operator
00184 void DiracCloverPC::M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
00185 {
00186   double kappa2 = -kappa*kappa;
00187 
00188   // FIXME: For asymmetric, a "DslashCxpay" kernel would improve performance.
00189   bool reset = newTmp(&tmp1, in);
00190 
00191   if (matpcType == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC) {
00192     Dslash(*tmp1, in, QUDA_ODD_PARITY);
00193     Clover(out, in, QUDA_EVEN_PARITY);
00194     DiracWilson::DslashXpay(out, *tmp1, QUDA_EVEN_PARITY, out, kappa2); // safe since out is not read after writing
00195   } else if (matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC) {
00196     Dslash(*tmp1, in, QUDA_EVEN_PARITY);
00197     Clover(out, in, QUDA_ODD_PARITY);
00198     DiracWilson::DslashXpay(out, *tmp1, QUDA_ODD_PARITY, out, kappa2);
00199   } else if (!dagger) { // symmetric preconditioning
00200     if (matpcType == QUDA_MATPC_EVEN_EVEN) {
00201       Dslash(*tmp1, in, QUDA_ODD_PARITY);
00202       DslashXpay(out, *tmp1, QUDA_EVEN_PARITY, in, kappa2); 
00203     } else if (matpcType == QUDA_MATPC_ODD_ODD) {
00204       Dslash(*tmp1, in, QUDA_EVEN_PARITY);
00205       DslashXpay(out, *tmp1, QUDA_ODD_PARITY, in, kappa2); 
00206     } else {
00207       errorQuda("Invalid matpcType");
00208     }
00209   } else { // symmetric preconditioning, dagger
00210     if (matpcType == QUDA_MATPC_EVEN_EVEN) {
00211       CloverInv(out, in, QUDA_EVEN_PARITY); 
00212       Dslash(*tmp1, out, QUDA_ODD_PARITY);
00213       DiracWilson::DslashXpay(out, *tmp1, QUDA_EVEN_PARITY, in, kappa2); 
00214     } else if (matpcType == QUDA_MATPC_ODD_ODD) {
00215       CloverInv(out, in, QUDA_ODD_PARITY); 
00216       Dslash(*tmp1, out, QUDA_EVEN_PARITY);
00217       DiracWilson::DslashXpay(out, *tmp1, QUDA_ODD_PARITY, in, kappa2); 
00218     } else {
00219       errorQuda("MatPCType %d not valid for DiracCloverPC", matpcType);
00220     }
00221   }
00222   
00223   deleteTmp(&tmp1, reset);
00224 }
00225 
00226 void DiracCloverPC::MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
00227 {
00228   // need extra temporary because of symmetric preconditioning dagger
00229   bool reset = newTmp(&tmp2, in);
00230 
00231   M(*tmp2, in);
00232   Mdag(out, *tmp2);
00233 
00234   deleteTmp(&tmp2, reset);
00235 }
00236 
00237 void DiracCloverPC::prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol, 
00238                             cudaColorSpinorField &x, cudaColorSpinorField &b, 
00239                             const QudaSolutionType solType) const
00240 {
00241   // we desire solution to preconditioned system
00242   if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
00243     src = &b;
00244     sol = &x;
00245     return;
00246   }
00247 
00248   bool reset = newTmp(&tmp1, b.Even());
00249   
00250   // we desire solution to full system
00251   if (matpcType == QUDA_MATPC_EVEN_EVEN) {
00252     // src = A_ee^-1 (b_e + k D_eo A_oo^-1 b_o)
00253     src = &(x.Odd());
00254     CloverInv(*src, b.Odd(), QUDA_ODD_PARITY);
00255     DiracWilson::DslashXpay(*tmp1, *src, QUDA_EVEN_PARITY, b.Even(), kappa);
00256     CloverInv(*src, *tmp1, QUDA_EVEN_PARITY);
00257     sol = &(x.Even());
00258   } else if (matpcType == QUDA_MATPC_ODD_ODD) {
00259     // src = A_oo^-1 (b_o + k D_oe A_ee^-1 b_e)
00260     src = &(x.Even());
00261     CloverInv(*src, b.Even(), QUDA_EVEN_PARITY);
00262     DiracWilson::DslashXpay(*tmp1, *src, QUDA_ODD_PARITY, b.Odd(), kappa);
00263     CloverInv(*src, *tmp1, QUDA_ODD_PARITY);
00264     sol = &(x.Odd());
00265   } else if (matpcType == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC) {
00266     // src = b_e + k D_eo A_oo^-1 b_o
00267     src = &(x.Odd());
00268     CloverInv(*tmp1, b.Odd(), QUDA_ODD_PARITY); // safe even when *tmp1 = b.odd
00269     DiracWilson::DslashXpay(*src, *tmp1, QUDA_EVEN_PARITY, b.Even(), kappa);
00270     sol = &(x.Even());
00271   } else if (matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC) {
00272     // src = b_o + k D_oe A_ee^-1 b_e
00273     src = &(x.Even());
00274     CloverInv(*tmp1, b.Even(), QUDA_EVEN_PARITY); // safe even when *tmp1 = b.even
00275     DiracWilson::DslashXpay(*src, *tmp1, QUDA_ODD_PARITY, b.Odd(), kappa);
00276     sol = &(x.Odd());
00277   } else {
00278     errorQuda("MatPCType %d not valid for DiracCloverPC", matpcType);
00279   }
00280 
00281   // here we use final solution to store parity solution and parity source
00282   // b is now up for grabs if we want
00283 
00284   deleteTmp(&tmp1, reset);
00285 
00286 }
00287 
00288 void DiracCloverPC::reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b,
00289                                 const QudaSolutionType solType) const
00290 {
00291   if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
00292     return;
00293   }
00294 
00295   checkFullSpinor(x, b);
00296 
00297   bool reset = newTmp(&tmp1, b.Even());
00298 
00299   // create full solution
00300 
00301   if (matpcType == QUDA_MATPC_EVEN_EVEN ||
00302       matpcType == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC) {
00303     // x_o = A_oo^-1 (b_o + k D_oe x_e)
00304     DiracWilson::DslashXpay(*tmp1, x.Even(), QUDA_ODD_PARITY, b.Odd(), kappa);
00305     CloverInv(x.Odd(), *tmp1, QUDA_ODD_PARITY);
00306   } else if (matpcType == QUDA_MATPC_ODD_ODD ||
00307       matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC) {
00308     // x_e = A_ee^-1 (b_e + k D_eo x_o)
00309     DiracWilson::DslashXpay(*tmp1, x.Odd(), QUDA_EVEN_PARITY, b.Even(), kappa);
00310     CloverInv(x.Even(), *tmp1, QUDA_EVEN_PARITY);
00311   } else {
00312     errorQuda("MatPCType %d not valid for DiracCloverPC", matpcType);
00313   }
00314 
00315   deleteTmp(&tmp1, reset);
00316 
00317 }
00318 
 All Classes Files Functions Variables Typedefs Enumerations Enumerator Friends Defines