QUDA v0.4.0
A library for QCD on GPUs
quda/lib/dirac_twisted_mass.cpp
Go to the documentation of this file.
00001 #include <dirac_quda.h>
00002 #include <blas_quda.h>
00003 #include <iostream>
00004 
00005 DiracTwistedMass::DiracTwistedMass(const DiracParam &param) : DiracWilson(param), mu(param.mu) { }
00006 
00007 DiracTwistedMass::DiracTwistedMass(const DiracTwistedMass &dirac) : DiracWilson(dirac), mu(dirac.mu) { }
00008 
00009 DiracTwistedMass::~DiracTwistedMass() { }
00010 
00011 DiracTwistedMass& DiracTwistedMass::operator=(const DiracTwistedMass &dirac)
00012 {
00013   if (&dirac != this) {
00014     DiracWilson::operator=(dirac);
00015   }
00016   return *this;
00017 }
00018 
00019 // Protected method for applying twist
00020 void DiracTwistedMass::twistedApply(cudaColorSpinorField &out, const cudaColorSpinorField &in,
00021                                     const QudaTwistGamma5Type twistType) const
00022 {
00023   checkParitySpinor(out, in);
00024   
00025   if (!initDslash) initDslashConstants(gauge, in.Stride());
00026 
00027   if (in.TwistFlavor() == QUDA_TWIST_NO || in.TwistFlavor() == QUDA_TWIST_INVALID)
00028     errorQuda("Twist flavor not set %d\n", in.TwistFlavor());
00029 
00030   double flavor_mu = in.TwistFlavor() * mu;
00031   
00032   twistGamma5Cuda(&out, &in, dagger, kappa, flavor_mu, twistType);
00033 
00034   flops += 24*in.Volume();
00035 }
00036 
00037 
00038 // Public method to apply the twist
00039 void DiracTwistedMass::Twist(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
00040 {
00041   twistedApply(out, in, QUDA_TWIST_GAMMA5_DIRECT);
00042 }
00043 
00044 void DiracTwistedMass::M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
00045 {
00046   checkFullSpinor(out, in);
00047   if (in.TwistFlavor() != out.TwistFlavor()) 
00048     errorQuda("Twist flavors %d %d don't match", in.TwistFlavor(), out.TwistFlavor());
00049 
00050   if (in.TwistFlavor() == QUDA_TWIST_NO || in.TwistFlavor() == QUDA_TWIST_INVALID) {
00051     errorQuda("Twist flavor not set %d\n", in.TwistFlavor());
00052   }
00053 
00054   // We can eliminate this temporary at the expense of more kernels (like clover)
00055   cudaColorSpinorField *tmp=0; // this hack allows for tmp2 to be full or parity field
00056   if (tmp2) {
00057     if (tmp2->SiteSubset() == QUDA_FULL_SITE_SUBSET) tmp = &(tmp2->Even());
00058     else tmp = tmp2;
00059   }
00060   bool reset = newTmp(&tmp, in.Even());
00061 
00062   Twist(*tmp, in.Odd());
00063   DslashXpay(out.Odd(), in.Even(), QUDA_ODD_PARITY, *tmp, -kappa);
00064   Twist(*tmp, in.Even());
00065   DslashXpay(out.Even(), in.Odd(), QUDA_EVEN_PARITY, *tmp, -kappa);
00066 
00067   deleteTmp(&tmp, reset);
00068 
00069 }
00070 
00071 void DiracTwistedMass::MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
00072 {
00073   checkFullSpinor(out, in);
00074   bool reset = newTmp(&tmp1, in);
00075 
00076   M(*tmp1, in);
00077   Mdag(out, *tmp1);
00078 
00079   deleteTmp(&tmp1, reset);
00080 }
00081 
00082 void DiracTwistedMass::prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol,
00083                           cudaColorSpinorField &x, cudaColorSpinorField &b, 
00084                           const QudaSolutionType solType) const
00085 {
00086   if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
00087     errorQuda("Preconditioned solution requires a preconditioned solve_type");
00088   }
00089 
00090   src = &b;
00091   sol = &x;
00092 }
00093 
00094 void DiracTwistedMass::reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b,
00095                               const QudaSolutionType solType) const
00096 {
00097   // do nothing
00098 }
00099 
00100 DiracTwistedMassPC::DiracTwistedMassPC(const DiracParam &param) : DiracTwistedMass(param)
00101 {
00102 
00103 }
00104 
00105 DiracTwistedMassPC::DiracTwistedMassPC(const DiracTwistedMassPC &dirac) : DiracTwistedMass(dirac) { }
00106 
00107 DiracTwistedMassPC::~DiracTwistedMassPC()
00108 {
00109 
00110 }
00111 
00112 DiracTwistedMassPC& DiracTwistedMassPC::operator=(const DiracTwistedMassPC &dirac)
00113 {
00114   if (&dirac != this) {
00115     DiracTwistedMass::operator=(dirac);
00116   }
00117   return *this;
00118 }
00119 
00120 // Public method to apply the inverse twist
00121 void DiracTwistedMassPC::TwistInv(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
00122 {
00123   twistedApply(out, in, QUDA_TWIST_GAMMA5_INVERSE);
00124 }
00125 
00126 // apply hopping term, then inverse twist: (A_ee^-1 D_eo) or (A_oo^-1 D_oe),
00127 // and likewise for dagger: (D^dagger_eo D_ee^-1) or (D^dagger_oe A_oo^-1)
00128 void DiracTwistedMassPC::Dslash(cudaColorSpinorField &out, const cudaColorSpinorField &in, 
00129                                 const QudaParity parity) const
00130 {
00131   if (!initDslash) initDslashConstants(gauge, in.Stride());
00132   checkParitySpinor(in, out);
00133   checkSpinorAlias(in, out);
00134 
00135   if (in.TwistFlavor() != out.TwistFlavor()) 
00136     errorQuda("Twist flavors %d %d don't match", in.TwistFlavor(), out.TwistFlavor());
00137   if (in.TwistFlavor() == QUDA_TWIST_NO || in.TwistFlavor() == QUDA_TWIST_INVALID)
00138     errorQuda("Twist flavor not set %d\n", in.TwistFlavor());
00139 
00140   if (!dagger || matpcType == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC || matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC) {
00141     double flavor_mu = in.TwistFlavor() * mu;
00142     setFace(face); // FIXME: temporary hack maintain C linkage for dslashCuda
00143     twistedMassDslashCuda(&out, gauge, &in, parity, dagger, 0, kappa, flavor_mu, 0.0, commDim);
00144     flops += (1320+72)*in.Volume();
00145   } else { // safe to use tmp2 here which may alias in
00146     bool reset = newTmp(&tmp2, in);
00147 
00148     TwistInv(*tmp2, in);
00149     DiracWilson::Dslash(out, *tmp2, parity);
00150 
00151     flops += 72*in.Volume();
00152 
00153     // if the pointers alias, undo the twist
00154     if (tmp2->V() == in.V()) Twist(*tmp2, *tmp2); 
00155 
00156     deleteTmp(&tmp2, reset);
00157   }
00158 
00159 }
00160 
00161 // xpay version of the above
00162 void DiracTwistedMassPC::DslashXpay(cudaColorSpinorField &out, const cudaColorSpinorField &in, 
00163                                     const QudaParity parity, const cudaColorSpinorField &x,
00164                                     const double &k) const
00165 {
00166   if (!initDslash) initDslashConstants(gauge, in.Stride());
00167   checkParitySpinor(in, out);
00168   checkSpinorAlias(in, out);
00169 
00170   if (in.TwistFlavor() != out.TwistFlavor()) 
00171     errorQuda("Twist flavors %d %d don't match", in.TwistFlavor(), out.TwistFlavor());
00172   if (in.TwistFlavor() == QUDA_TWIST_NO || in.TwistFlavor() == QUDA_TWIST_INVALID)
00173     errorQuda("Twist flavor not set %d\n", in.TwistFlavor());  
00174 
00175   if (!dagger) {
00176     double flavor_mu = in.TwistFlavor() * mu;
00177     setFace(face); // FIXME: temporary hack maintain C linkage for dslashCuda
00178     twistedMassDslashCuda(&out, gauge, &in, parity, dagger, &x, kappa, 
00179                           flavor_mu, k, commDim);
00180     flops += (1320+96)*in.Volume();
00181   } else { // tmp1 can alias in, but tmp2 can alias x so must not use this
00182     bool reset = newTmp(&tmp1, in);
00183 
00184     TwistInv(*tmp1, in);
00185     DiracWilson::Dslash(out, *tmp1, parity);
00186     xpayCuda((cudaColorSpinorField&)x, k, out);
00187     flops += 96*in.Volume();
00188 
00189     // if the pointers alias, undo the twist
00190     if (tmp1->V() == in.V()) Twist(*tmp1, *tmp1); 
00191 
00192     deleteTmp(&tmp1, reset);
00193   }
00194 
00195 }
00196 
00197 void DiracTwistedMassPC::M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
00198 {
00199   double kappa2 = -kappa*kappa;
00200 
00201   bool reset = newTmp(&tmp1, in);
00202 
00203   if (matpcType == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC) {
00204     Dslash(*tmp1, in, QUDA_ODD_PARITY); // fused kernel
00205     Twist(out, in);
00206     DiracWilson::DslashXpay(out, *tmp1, QUDA_EVEN_PARITY, out, kappa2); // safe since out is not read after writing
00207   } else if (matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC) {
00208     Dslash(*tmp1, in, QUDA_EVEN_PARITY); // fused kernel
00209     Twist(out, in);
00210     DiracWilson::DslashXpay(out, *tmp1, QUDA_ODD_PARITY, out, kappa2);
00211   } else { // symmetric preconditioning
00212     if (matpcType == QUDA_MATPC_EVEN_EVEN) {
00213       Dslash(*tmp1, in, QUDA_ODD_PARITY);
00214       DslashXpay(out, *tmp1, QUDA_EVEN_PARITY, in, kappa2); 
00215     } else if (matpcType == QUDA_MATPC_ODD_ODD) {
00216       Dslash(*tmp1, in, QUDA_EVEN_PARITY);
00217       DslashXpay(out, *tmp1, QUDA_ODD_PARITY, in, kappa2); 
00218     } else {
00219       errorQuda("Invalid matpcType");
00220     }
00221   }
00222 
00223   deleteTmp(&tmp1, reset);
00224 
00225 }
00226 
00227 void DiracTwistedMassPC::MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
00228 {
00229   // need extra temporary because of symmetric preconditioning dagger
00230   bool reset = newTmp(&tmp2, in);
00231 
00232   M(*tmp2, in);
00233   Mdag(out, *tmp2);
00234 
00235   deleteTmp(&tmp2, reset);
00236 }
00237 
00238 void DiracTwistedMassPC::prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol,
00239                             cudaColorSpinorField &x, cudaColorSpinorField &b, 
00240                             const QudaSolutionType solType) const
00241 {
00242   // we desire solution to preconditioned system
00243   if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
00244     src = &b;
00245     sol = &x;
00246     return;
00247   }
00248 
00249   bool reset = newTmp(&tmp1, b.Even());
00250   
00251   // we desire solution to full system
00252   if (matpcType == QUDA_MATPC_EVEN_EVEN) {
00253     // src = A_ee^-1 (b_e + k D_eo A_oo^-1 b_o)
00254     src = &(x.Odd());
00255     TwistInv(*src, b.Odd());
00256     DiracWilson::DslashXpay(*tmp1, *src, QUDA_EVEN_PARITY, b.Even(), kappa);
00257     TwistInv(*src, *tmp1);
00258     sol = &(x.Even());
00259   } else if (matpcType == QUDA_MATPC_ODD_ODD) {
00260     // src = A_oo^-1 (b_o + k D_oe A_ee^-1 b_e)
00261     src = &(x.Even());
00262     TwistInv(*src, b.Even());
00263     DiracWilson::DslashXpay(*tmp1, *src, QUDA_ODD_PARITY, b.Odd(), kappa);
00264     TwistInv(*src, *tmp1);
00265     sol = &(x.Odd());
00266   } else if (matpcType == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC) {
00267     // src = b_e + k D_eo A_oo^-1 b_o
00268     src = &(x.Odd());
00269     TwistInv(*tmp1, b.Odd()); // safe even when *tmp1 = b.odd
00270     DiracWilson::DslashXpay(*src, *tmp1, QUDA_EVEN_PARITY, b.Even(), kappa);
00271     sol = &(x.Even());
00272   } else if (matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC) {
00273     // src = b_o + k D_oe A_ee^-1 b_e
00274     src = &(x.Even());
00275     TwistInv(*tmp1, b.Even()); // safe even when *tmp1 = b.even
00276     DiracWilson::DslashXpay(*src, *tmp1, QUDA_ODD_PARITY, b.Odd(), kappa);
00277     sol = &(x.Odd());
00278   } else {
00279     errorQuda("MatPCType %d not valid for DiracTwistedMassPC", matpcType);
00280   }
00281 
00282   // here we use final solution to store parity solution and parity source
00283   // b is now up for grabs if we want
00284 
00285   deleteTmp(&tmp1, reset);
00286 }
00287 
00288 void DiracTwistedMassPC::reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b,
00289                                 const QudaSolutionType solType) const
00290 {
00291   if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
00292     return;
00293   }                             
00294 
00295   checkFullSpinor(x, b);
00296   bool reset = newTmp(&tmp1, b.Even());
00297 
00298   // create full solution
00299   
00300   if (matpcType == QUDA_MATPC_EVEN_EVEN ||
00301       matpcType == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC) {
00302     // x_o = A_oo^-1 (b_o + k D_oe x_e)
00303     DiracWilson::DslashXpay(*tmp1, x.Even(), QUDA_ODD_PARITY, b.Odd(), kappa);
00304     TwistInv(x.Odd(), *tmp1);
00305   } else if (matpcType == QUDA_MATPC_ODD_ODD ||
00306              matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC) {
00307     // x_e = A_ee^-1 (b_e + k D_eo x_o)
00308     DiracWilson::DslashXpay(*tmp1, x.Odd(), QUDA_EVEN_PARITY, b.Even(), kappa);
00309     TwistInv(x.Even(), *tmp1);
00310   } else {
00311     errorQuda("MatPCType %d not valid for DiracTwistedMassPC", matpcType);
00312   }
00313   
00314   deleteTmp(&tmp1, reset);
00315 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines