QUDA v0.3.2
A library for QCD on GPUs

quda/lib/dirac_twisted_mass.cpp

Go to the documentation of this file.
00001 #include <dirac_quda.h>
00002 #include <blas_quda.h>
00003 #include <iostream>
00004 
00005 DiracTwistedMass::DiracTwistedMass(const DiracParam &param)
00006   : DiracWilson(param), mu(param.mu)
00007 {
00008 
00009 }
00010 
00011 DiracTwistedMass::DiracTwistedMass(const DiracTwistedMass &dirac) 
00012   : DiracWilson(dirac), mu(dirac.mu)
00013 {
00014 
00015 }
00016 
00017 DiracTwistedMass::~DiracTwistedMass()
00018 {
00019 
00020 }
00021 
00022 DiracTwistedMass& DiracTwistedMass::operator=(const DiracTwistedMass &dirac)
00023 {
00024   if (&dirac != this) {
00025     DiracWilson::operator=(dirac);
00026   }
00027   return *this;
00028 }
00029 
00030 // Protected method for applying twist
00031 void DiracTwistedMass::twistedApply(cudaColorSpinorField &out, const cudaColorSpinorField &in,
00032                                     const QudaTwistGamma5Type twistType) const
00033 {
00034   checkParitySpinor(out, in);
00035   
00036   if (!initDslash) initDslashConstants(gauge, in.stride, 0);
00037 
00038   if (in.twistFlavor == QUDA_TWIST_NO || in.twistFlavor == QUDA_TWIST_INVALID)
00039     errorQuda("Twist flavor not set %d\n", in.twistFlavor);
00040 
00041   double flavor_mu = in.twistFlavor * mu;
00042   
00043   twistGamma5Cuda(out.v, out.norm, in.v, in.norm, dagger, kappa, flavor_mu, 
00044                   in.volume, in.length, in.precision, twistType);
00045 }
00046 
00047 
00048 // Public method to apply the twist
00049 void DiracTwistedMass::Twist(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
00050 {
00051   twistedApply(out, in, QUDA_TWIST_GAMMA5_DIRECT);
00052 }
00053 
00054 void DiracTwistedMass::M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
00055 {
00056   checkFullSpinor(out, in);
00057   if (in.twistFlavor != out.twistFlavor) 
00058     errorQuda("Twist flavors %d %d don't match", in.twistFlavor, out.twistFlavor);
00059 
00060   if (in.twistFlavor == QUDA_TWIST_NO || in.twistFlavor == QUDA_TWIST_INVALID) {
00061     errorQuda("Twist flavor not set %d\n", in.twistFlavor);
00062   }
00063 
00064   // We can eliminate this temporary at the expense of more kernels (like clover)
00065   cudaColorSpinorField *tmp=0; // this hack allows for tmp2 to be full or parity field
00066   if (tmp2) {
00067     if (tmp2->SiteSubset() == QUDA_FULL_SITE_SUBSET) tmp = &(tmp2->Even());
00068     else tmp = tmp2;
00069   }
00070   bool reset = newTmp(&tmp, in.Even());
00071 
00072   Twist(*tmp, in.Odd());
00073   DslashXpay(out.Odd(), in.Even(), QUDA_ODD_PARITY, *tmp, -kappa);
00074   Twist(*tmp, in.Even());
00075   DslashXpay(out.Even(), in.Odd(), QUDA_EVEN_PARITY, *tmp, -kappa);
00076 
00077   deleteTmp(&tmp, reset);
00078 
00079 }
00080 
00081 void DiracTwistedMass::MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
00082 {
00083   checkFullSpinor(out, in);
00084   bool reset = newTmp(&tmp1, in);
00085 
00086   M(*tmp1, in);
00087   Mdag(out, *tmp1);
00088 
00089   deleteTmp(&tmp1, reset);
00090 }
00091 
00092 void DiracTwistedMass::prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol,
00093                           cudaColorSpinorField &x, cudaColorSpinorField &b, 
00094                           const QudaSolutionType solType) const
00095 {
00096   if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
00097     errorQuda("Preconditioned solution requires a preconditioned solve_type");
00098   }
00099 
00100   src = &b;
00101   sol = &x;
00102 }
00103 
00104 void DiracTwistedMass::reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b,
00105                               const QudaSolutionType solType) const
00106 {
00107   // do nothing
00108 }
00109 
00110 DiracTwistedMassPC::DiracTwistedMassPC(const DiracParam &param)
00111   : DiracTwistedMass(param)
00112 {
00113 
00114 }
00115 
00116 DiracTwistedMassPC::DiracTwistedMassPC(const DiracTwistedMassPC &dirac) 
00117   : DiracTwistedMass(dirac)
00118 {
00119 
00120 }
00121 
00122 DiracTwistedMassPC::~DiracTwistedMassPC()
00123 {
00124 
00125 }
00126 
00127 DiracTwistedMassPC& DiracTwistedMassPC::operator=(const DiracTwistedMassPC &dirac)
00128 {
00129   if (&dirac != this) {
00130     DiracTwistedMass::operator=(dirac);
00131   }
00132   return *this;
00133 }
00134 
00135 // Public method to apply the inverse twist
00136 void DiracTwistedMassPC::TwistInv(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
00137 {
00138   twistedApply(out, in, QUDA_TWIST_GAMMA5_INVERSE);
00139 }
00140 
00141 // apply hopping term, then inverse twist: (A_ee^-1 D_eo) or (A_oo^-1 D_oe),
00142 // and likewise for dagger: (D^dagger_eo D_ee^-1) or (D^dagger_oe A_oo^-1)
00143 void DiracTwistedMassPC::Dslash(cudaColorSpinorField &out, const cudaColorSpinorField &in, 
00144                                 const QudaParity parity) const
00145 {
00146   if (!initDslash) initDslashConstants(gauge, in.stride, 0);
00147   checkParitySpinor(in, out);
00148   checkSpinorAlias(in, out);
00149 
00150   if (in.twistFlavor != out.twistFlavor) 
00151     errorQuda("Twist flavors %d %d don't match", in.twistFlavor, out.twistFlavor);
00152   if (in.twistFlavor == QUDA_TWIST_NO || in.twistFlavor == QUDA_TWIST_INVALID)
00153     errorQuda("Twist flavor not set %d\n", in.twistFlavor);
00154 
00155   if (!dagger || matpcType == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC || matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC) {
00156     double flavor_mu = in.twistFlavor * mu;
00157     twistedMassDslashCuda(out.v, out.norm, gauge, in.v, in.norm, parity, dagger, 
00158                           0, 0, kappa, flavor_mu, 0.0, out.volume, out.length, in.Precision());
00159     flops += (1320+72)*in.volume;
00160   } else { // safe to use tmp2 here which may alias in
00161     bool reset = newTmp(&tmp2, in);
00162 
00163     TwistInv(*tmp2, in);
00164     DiracWilson::Dslash(out, *tmp2, parity);
00165 
00166     flops += 72*in.volume;
00167 
00168     // if the pointers alias, undo the twist
00169     if (tmp2->v == in.v) Twist(*tmp2, *tmp2); 
00170 
00171     deleteTmp(&tmp2, reset);
00172   }
00173 
00174 }
00175 
00176 // xpay version of the above
00177 void DiracTwistedMassPC::DslashXpay(cudaColorSpinorField &out, const cudaColorSpinorField &in, 
00178                                     const QudaParity parity, const cudaColorSpinorField &x,
00179                                     const double &k) const
00180 {
00181   if (!initDslash) initDslashConstants(gauge, in.stride, 0);
00182   checkParitySpinor(in, out);
00183   checkSpinorAlias(in, out);
00184 
00185   if (in.twistFlavor != out.twistFlavor) 
00186     errorQuda("Twist flavors %d %d don't match", in.twistFlavor, out.twistFlavor);
00187   if (in.twistFlavor == QUDA_TWIST_NO || in.twistFlavor == QUDA_TWIST_INVALID)
00188     errorQuda("Twist flavor not set %d\n", in.twistFlavor);  
00189 
00190   if (!dagger) {
00191     double flavor_mu = in.twistFlavor * mu;
00192     twistedMassDslashCuda(out.v, out.norm, gauge, in.v, in.norm, parity, dagger, 
00193                           x.v, x.norm, kappa, flavor_mu, k, out.volume, out.length, in.Precision());
00194     flops += (1320+96)*in.volume;
00195   } else { // tmp1 can alias in, but tmp2 can alias x so must not use this
00196     bool reset = newTmp(&tmp1, in);
00197 
00198     TwistInv(*tmp1, in);
00199     DiracWilson::Dslash(out, *tmp1, parity);
00200     xpayCuda(x, k, out);
00201     flops += 96*in.volume;
00202 
00203     // if the pointers alias, undo the twist
00204     if (tmp1->v == in.v) Twist(*tmp1, *tmp1); 
00205 
00206     deleteTmp(&tmp1, reset);
00207   }
00208 
00209 }
00210 
00211 void DiracTwistedMassPC::M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
00212 {
00213   double kappa2 = -kappa*kappa;
00214 
00215   bool reset = newTmp(&tmp1, in);
00216 
00217   if (matpcType == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC) {
00218     Dslash(*tmp1, in, QUDA_ODD_PARITY); // fused kernel
00219     Twist(out, in);
00220     DiracWilson::DslashXpay(out, *tmp1, QUDA_EVEN_PARITY, out, kappa2); // safe since out is not read after writing
00221   } else if (matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC) {
00222     Dslash(*tmp1, in, QUDA_EVEN_PARITY); // fused kernel
00223     Twist(out, in);
00224     DiracWilson::DslashXpay(out, *tmp1, QUDA_ODD_PARITY, out, kappa2);
00225   } else { // symmetric preconditioning
00226     if (matpcType == QUDA_MATPC_EVEN_EVEN) {
00227       Dslash(*tmp1, in, QUDA_ODD_PARITY);
00228       DslashXpay(out, *tmp1, QUDA_EVEN_PARITY, in, kappa2); 
00229     } else if (matpcType == QUDA_MATPC_ODD_ODD) {
00230       Dslash(*tmp1, in, QUDA_EVEN_PARITY);
00231       DslashXpay(out, *tmp1, QUDA_ODD_PARITY, in, kappa2); 
00232     } else {
00233       errorQuda("Invalid matpcType");
00234     }
00235   }
00236 
00237   deleteTmp(&tmp1, reset);
00238 
00239 }
00240 
00241 void DiracTwistedMassPC::MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
00242 {
00243   // need extra temporary because of symmetric preconditioning dagger
00244   bool reset = newTmp(&tmp2, in);
00245 
00246   M(*tmp2, in);
00247   Mdag(out, *tmp2);
00248 
00249   deleteTmp(&tmp2, reset);
00250 }
00251 
00252 void DiracTwistedMassPC::prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol,
00253                             cudaColorSpinorField &x, cudaColorSpinorField &b, 
00254                             const QudaSolutionType solType) const
00255 {
00256   // we desire solution to preconditioned system
00257   if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
00258     src = &b;
00259     sol = &x;
00260     return;
00261   }
00262 
00263   bool reset = newTmp(&tmp1, b.Even());
00264   
00265   // we desire solution to full system
00266   if (matpcType == QUDA_MATPC_EVEN_EVEN) {
00267     // src = A_ee^-1 (b_e + k D_eo A_oo^-1 b_o)
00268     src = &(x.Odd());
00269     TwistInv(*src, b.Odd());
00270     DiracWilson::DslashXpay(*tmp1, *src, QUDA_EVEN_PARITY, b.Even(), kappa);
00271     TwistInv(*src, *tmp1);
00272     sol = &(x.Even());
00273   } else if (matpcType == QUDA_MATPC_ODD_ODD) {
00274     // src = A_oo^-1 (b_o + k D_oe A_ee^-1 b_e)
00275     src = &(x.Even());
00276     TwistInv(*src, b.Even());
00277     DiracWilson::DslashXpay(*tmp1, *src, QUDA_ODD_PARITY, b.Odd(), kappa);
00278     TwistInv(*src, *tmp1);
00279     sol = &(x.Odd());
00280   } else if (matpcType == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC) {
00281     // src = b_e + k D_eo A_oo^-1 b_o
00282     src = &(x.Odd());
00283     TwistInv(*tmp1, b.Odd()); // safe even when *tmp1 = b.odd
00284     DiracWilson::DslashXpay(*src, *tmp1, QUDA_EVEN_PARITY, b.Even(), kappa);
00285     sol = &(x.Even());
00286   } else if (matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC) {
00287     // src = b_o + k D_oe A_ee^-1 b_e
00288     src = &(x.Even());
00289     TwistInv(*tmp1, b.Even()); // safe even when *tmp1 = b.even
00290     DiracWilson::DslashXpay(*src, *tmp1, QUDA_ODD_PARITY, b.Odd(), kappa);
00291     sol = &(x.Odd());
00292   } else {
00293     errorQuda("MatPCType %d not valid for DiracTwistedMassPC", matpcType);
00294   }
00295 
00296   // here we use final solution to store parity solution and parity source
00297   // b is now up for grabs if we want
00298 
00299   deleteTmp(&tmp1, reset);
00300 }
00301 
00302 void DiracTwistedMassPC::reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b,
00303                                 const QudaSolutionType solType) const
00304 {
00305   if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
00306     return;
00307   }                             
00308 
00309   checkFullSpinor(x, b);
00310   bool reset = newTmp(&tmp1, b.Even());
00311 
00312   // create full solution
00313   
00314   if (matpcType == QUDA_MATPC_EVEN_EVEN ||
00315       matpcType == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC) {
00316     // x_o = A_oo^-1 (b_o + k D_oe x_e)
00317     DiracWilson::DslashXpay(*tmp1, x.Even(), QUDA_ODD_PARITY, b.Odd(), kappa);
00318     TwistInv(x.Odd(), *tmp1);
00319   } else if (matpcType == QUDA_MATPC_ODD_ODD ||
00320              matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC) {
00321     // x_e = A_ee^-1 (b_e + k D_eo x_o)
00322     DiracWilson::DslashXpay(*tmp1, x.Odd(), QUDA_EVEN_PARITY, b.Even(), kappa);
00323     TwistInv(x.Even(), *tmp1);
00324   } else {
00325     errorQuda("MatPCType %d not valid for DiracTwistedMassPC", matpcType);
00326   }
00327   
00328   deleteTmp(&tmp1, reset);
00329 }
 All Classes Files Functions Variables Typedefs Enumerations Enumerator Friends Defines