// QUDA v0.4.0
// A library for QCD on GPUs
00001 #include <dirac_quda.h> 00002 #include <blas_quda.h> 00003 #include <iostream> 00004 00005 DiracTwistedMass::DiracTwistedMass(const DiracParam ¶m) : DiracWilson(param), mu(param.mu) { } 00006 00007 DiracTwistedMass::DiracTwistedMass(const DiracTwistedMass &dirac) : DiracWilson(dirac), mu(dirac.mu) { } 00008 00009 DiracTwistedMass::~DiracTwistedMass() { } 00010 00011 DiracTwistedMass& DiracTwistedMass::operator=(const DiracTwistedMass &dirac) 00012 { 00013 if (&dirac != this) { 00014 DiracWilson::operator=(dirac); 00015 } 00016 return *this; 00017 } 00018 00019 // Protected method for applying twist 00020 void DiracTwistedMass::twistedApply(cudaColorSpinorField &out, const cudaColorSpinorField &in, 00021 const QudaTwistGamma5Type twistType) const 00022 { 00023 checkParitySpinor(out, in); 00024 00025 if (!initDslash) initDslashConstants(gauge, in.Stride()); 00026 00027 if (in.TwistFlavor() == QUDA_TWIST_NO || in.TwistFlavor() == QUDA_TWIST_INVALID) 00028 errorQuda("Twist flavor not set %d\n", in.TwistFlavor()); 00029 00030 double flavor_mu = in.TwistFlavor() * mu; 00031 00032 twistGamma5Cuda(&out, &in, dagger, kappa, flavor_mu, twistType); 00033 00034 flops += 24*in.Volume(); 00035 } 00036 00037 00038 // Public method to apply the twist 00039 void DiracTwistedMass::Twist(cudaColorSpinorField &out, const cudaColorSpinorField &in) const 00040 { 00041 twistedApply(out, in, QUDA_TWIST_GAMMA5_DIRECT); 00042 } 00043 00044 void DiracTwistedMass::M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const 00045 { 00046 checkFullSpinor(out, in); 00047 if (in.TwistFlavor() != out.TwistFlavor()) 00048 errorQuda("Twist flavors %d %d don't match", in.TwistFlavor(), out.TwistFlavor()); 00049 00050 if (in.TwistFlavor() == QUDA_TWIST_NO || in.TwistFlavor() == QUDA_TWIST_INVALID) { 00051 errorQuda("Twist flavor not set %d\n", in.TwistFlavor()); 00052 } 00053 00054 // We can eliminate this temporary at the expense of more kernels (like clover) 00055 cudaColorSpinorField 
*tmp=0; // this hack allows for tmp2 to be full or parity field 00056 if (tmp2) { 00057 if (tmp2->SiteSubset() == QUDA_FULL_SITE_SUBSET) tmp = &(tmp2->Even()); 00058 else tmp = tmp2; 00059 } 00060 bool reset = newTmp(&tmp, in.Even()); 00061 00062 Twist(*tmp, in.Odd()); 00063 DslashXpay(out.Odd(), in.Even(), QUDA_ODD_PARITY, *tmp, -kappa); 00064 Twist(*tmp, in.Even()); 00065 DslashXpay(out.Even(), in.Odd(), QUDA_EVEN_PARITY, *tmp, -kappa); 00066 00067 deleteTmp(&tmp, reset); 00068 00069 } 00070 00071 void DiracTwistedMass::MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const 00072 { 00073 checkFullSpinor(out, in); 00074 bool reset = newTmp(&tmp1, in); 00075 00076 M(*tmp1, in); 00077 Mdag(out, *tmp1); 00078 00079 deleteTmp(&tmp1, reset); 00080 } 00081 00082 void DiracTwistedMass::prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol, 00083 cudaColorSpinorField &x, cudaColorSpinorField &b, 00084 const QudaSolutionType solType) const 00085 { 00086 if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) { 00087 errorQuda("Preconditioned solution requires a preconditioned solve_type"); 00088 } 00089 00090 src = &b; 00091 sol = &x; 00092 } 00093 00094 void DiracTwistedMass::reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b, 00095 const QudaSolutionType solType) const 00096 { 00097 // do nothing 00098 } 00099 00100 DiracTwistedMassPC::DiracTwistedMassPC(const DiracParam ¶m) : DiracTwistedMass(param) 00101 { 00102 00103 } 00104 00105 DiracTwistedMassPC::DiracTwistedMassPC(const DiracTwistedMassPC &dirac) : DiracTwistedMass(dirac) { } 00106 00107 DiracTwistedMassPC::~DiracTwistedMassPC() 00108 { 00109 00110 } 00111 00112 DiracTwistedMassPC& DiracTwistedMassPC::operator=(const DiracTwistedMassPC &dirac) 00113 { 00114 if (&dirac != this) { 00115 DiracTwistedMass::operator=(dirac); 00116 } 00117 return *this; 00118 } 00119 00120 // Public method to apply the inverse twist 00121 void 
DiracTwistedMassPC::TwistInv(cudaColorSpinorField &out, const cudaColorSpinorField &in) const 00122 { 00123 twistedApply(out, in, QUDA_TWIST_GAMMA5_INVERSE); 00124 } 00125 00126 // apply hopping term, then inverse twist: (A_ee^-1 D_eo) or (A_oo^-1 D_oe), 00127 // and likewise for dagger: (D^dagger_eo D_ee^-1) or (D^dagger_oe A_oo^-1) 00128 void DiracTwistedMassPC::Dslash(cudaColorSpinorField &out, const cudaColorSpinorField &in, 00129 const QudaParity parity) const 00130 { 00131 if (!initDslash) initDslashConstants(gauge, in.Stride()); 00132 checkParitySpinor(in, out); 00133 checkSpinorAlias(in, out); 00134 00135 if (in.TwistFlavor() != out.TwistFlavor()) 00136 errorQuda("Twist flavors %d %d don't match", in.TwistFlavor(), out.TwistFlavor()); 00137 if (in.TwistFlavor() == QUDA_TWIST_NO || in.TwistFlavor() == QUDA_TWIST_INVALID) 00138 errorQuda("Twist flavor not set %d\n", in.TwistFlavor()); 00139 00140 if (!dagger || matpcType == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC || matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC) { 00141 double flavor_mu = in.TwistFlavor() * mu; 00142 setFace(face); // FIXME: temporary hack maintain C linkage for dslashCuda 00143 twistedMassDslashCuda(&out, gauge, &in, parity, dagger, 0, kappa, flavor_mu, 0.0, commDim); 00144 flops += (1320+72)*in.Volume(); 00145 } else { // safe to use tmp2 here which may alias in 00146 bool reset = newTmp(&tmp2, in); 00147 00148 TwistInv(*tmp2, in); 00149 DiracWilson::Dslash(out, *tmp2, parity); 00150 00151 flops += 72*in.Volume(); 00152 00153 // if the pointers alias, undo the twist 00154 if (tmp2->V() == in.V()) Twist(*tmp2, *tmp2); 00155 00156 deleteTmp(&tmp2, reset); 00157 } 00158 00159 } 00160 00161 // xpay version of the above 00162 void DiracTwistedMassPC::DslashXpay(cudaColorSpinorField &out, const cudaColorSpinorField &in, 00163 const QudaParity parity, const cudaColorSpinorField &x, 00164 const double &k) const 00165 { 00166 if (!initDslash) initDslashConstants(gauge, in.Stride()); 00167 checkParitySpinor(in, 
out); 00168 checkSpinorAlias(in, out); 00169 00170 if (in.TwistFlavor() != out.TwistFlavor()) 00171 errorQuda("Twist flavors %d %d don't match", in.TwistFlavor(), out.TwistFlavor()); 00172 if (in.TwistFlavor() == QUDA_TWIST_NO || in.TwistFlavor() == QUDA_TWIST_INVALID) 00173 errorQuda("Twist flavor not set %d\n", in.TwistFlavor()); 00174 00175 if (!dagger) { 00176 double flavor_mu = in.TwistFlavor() * mu; 00177 setFace(face); // FIXME: temporary hack maintain C linkage for dslashCuda 00178 twistedMassDslashCuda(&out, gauge, &in, parity, dagger, &x, kappa, 00179 flavor_mu, k, commDim); 00180 flops += (1320+96)*in.Volume(); 00181 } else { // tmp1 can alias in, but tmp2 can alias x so must not use this 00182 bool reset = newTmp(&tmp1, in); 00183 00184 TwistInv(*tmp1, in); 00185 DiracWilson::Dslash(out, *tmp1, parity); 00186 xpayCuda((cudaColorSpinorField&)x, k, out); 00187 flops += 96*in.Volume(); 00188 00189 // if the pointers alias, undo the twist 00190 if (tmp1->V() == in.V()) Twist(*tmp1, *tmp1); 00191 00192 deleteTmp(&tmp1, reset); 00193 } 00194 00195 } 00196 00197 void DiracTwistedMassPC::M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const 00198 { 00199 double kappa2 = -kappa*kappa; 00200 00201 bool reset = newTmp(&tmp1, in); 00202 00203 if (matpcType == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC) { 00204 Dslash(*tmp1, in, QUDA_ODD_PARITY); // fused kernel 00205 Twist(out, in); 00206 DiracWilson::DslashXpay(out, *tmp1, QUDA_EVEN_PARITY, out, kappa2); // safe since out is not read after writing 00207 } else if (matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC) { 00208 Dslash(*tmp1, in, QUDA_EVEN_PARITY); // fused kernel 00209 Twist(out, in); 00210 DiracWilson::DslashXpay(out, *tmp1, QUDA_ODD_PARITY, out, kappa2); 00211 } else { // symmetric preconditioning 00212 if (matpcType == QUDA_MATPC_EVEN_EVEN) { 00213 Dslash(*tmp1, in, QUDA_ODD_PARITY); 00214 DslashXpay(out, *tmp1, QUDA_EVEN_PARITY, in, kappa2); 00215 } else if (matpcType == QUDA_MATPC_ODD_ODD) { 00216 
Dslash(*tmp1, in, QUDA_EVEN_PARITY); 00217 DslashXpay(out, *tmp1, QUDA_ODD_PARITY, in, kappa2); 00218 } else { 00219 errorQuda("Invalid matpcType"); 00220 } 00221 } 00222 00223 deleteTmp(&tmp1, reset); 00224 00225 } 00226 00227 void DiracTwistedMassPC::MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const 00228 { 00229 // need extra temporary because of symmetric preconditioning dagger 00230 bool reset = newTmp(&tmp2, in); 00231 00232 M(*tmp2, in); 00233 Mdag(out, *tmp2); 00234 00235 deleteTmp(&tmp2, reset); 00236 } 00237 00238 void DiracTwistedMassPC::prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol, 00239 cudaColorSpinorField &x, cudaColorSpinorField &b, 00240 const QudaSolutionType solType) const 00241 { 00242 // we desire solution to preconditioned system 00243 if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) { 00244 src = &b; 00245 sol = &x; 00246 return; 00247 } 00248 00249 bool reset = newTmp(&tmp1, b.Even()); 00250 00251 // we desire solution to full system 00252 if (matpcType == QUDA_MATPC_EVEN_EVEN) { 00253 // src = A_ee^-1 (b_e + k D_eo A_oo^-1 b_o) 00254 src = &(x.Odd()); 00255 TwistInv(*src, b.Odd()); 00256 DiracWilson::DslashXpay(*tmp1, *src, QUDA_EVEN_PARITY, b.Even(), kappa); 00257 TwistInv(*src, *tmp1); 00258 sol = &(x.Even()); 00259 } else if (matpcType == QUDA_MATPC_ODD_ODD) { 00260 // src = A_oo^-1 (b_o + k D_oe A_ee^-1 b_e) 00261 src = &(x.Even()); 00262 TwistInv(*src, b.Even()); 00263 DiracWilson::DslashXpay(*tmp1, *src, QUDA_ODD_PARITY, b.Odd(), kappa); 00264 TwistInv(*src, *tmp1); 00265 sol = &(x.Odd()); 00266 } else if (matpcType == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC) { 00267 // src = b_e + k D_eo A_oo^-1 b_o 00268 src = &(x.Odd()); 00269 TwistInv(*tmp1, b.Odd()); // safe even when *tmp1 = b.odd 00270 DiracWilson::DslashXpay(*src, *tmp1, QUDA_EVEN_PARITY, b.Even(), kappa); 00271 sol = &(x.Even()); 00272 } else if (matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC) { 00273 // src = b_o + k 
D_oe A_ee^-1 b_e 00274 src = &(x.Even()); 00275 TwistInv(*tmp1, b.Even()); // safe even when *tmp1 = b.even 00276 DiracWilson::DslashXpay(*src, *tmp1, QUDA_ODD_PARITY, b.Odd(), kappa); 00277 sol = &(x.Odd()); 00278 } else { 00279 errorQuda("MatPCType %d not valid for DiracTwistedMassPC", matpcType); 00280 } 00281 00282 // here we use final solution to store parity solution and parity source 00283 // b is now up for grabs if we want 00284 00285 deleteTmp(&tmp1, reset); 00286 } 00287 00288 void DiracTwistedMassPC::reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b, 00289 const QudaSolutionType solType) const 00290 { 00291 if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) { 00292 return; 00293 } 00294 00295 checkFullSpinor(x, b); 00296 bool reset = newTmp(&tmp1, b.Even()); 00297 00298 // create full solution 00299 00300 if (matpcType == QUDA_MATPC_EVEN_EVEN || 00301 matpcType == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC) { 00302 // x_o = A_oo^-1 (b_o + k D_oe x_e) 00303 DiracWilson::DslashXpay(*tmp1, x.Even(), QUDA_ODD_PARITY, b.Odd(), kappa); 00304 TwistInv(x.Odd(), *tmp1); 00305 } else if (matpcType == QUDA_MATPC_ODD_ODD || 00306 matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC) { 00307 // x_e = A_ee^-1 (b_e + k D_eo x_o) 00308 DiracWilson::DslashXpay(*tmp1, x.Odd(), QUDA_EVEN_PARITY, b.Even(), kappa); 00309 TwistInv(x.Even(), *tmp1); 00310 } else { 00311 errorQuda("MatPCType %d not valid for DiracTwistedMassPC", matpcType); 00312 } 00313 00314 deleteTmp(&tmp1, reset); 00315 }