|
QUDA v0.3.2
A library for QCD on GPUs
|
00001 #include <dirac_quda.h> 00002 #include <blas_quda.h> 00003 #include <iostream> 00004 00005 DiracTwistedMass::DiracTwistedMass(const DiracParam ¶m) 00006 : DiracWilson(param), mu(param.mu) 00007 { 00008 00009 } 00010 00011 DiracTwistedMass::DiracTwistedMass(const DiracTwistedMass &dirac) 00012 : DiracWilson(dirac), mu(dirac.mu) 00013 { 00014 00015 } 00016 00017 DiracTwistedMass::~DiracTwistedMass() 00018 { 00019 00020 } 00021 00022 DiracTwistedMass& DiracTwistedMass::operator=(const DiracTwistedMass &dirac) 00023 { 00024 if (&dirac != this) { 00025 DiracWilson::operator=(dirac); 00026 } 00027 return *this; 00028 } 00029 00030 // Protected method for applying twist 00031 void DiracTwistedMass::twistedApply(cudaColorSpinorField &out, const cudaColorSpinorField &in, 00032 const QudaTwistGamma5Type twistType) const 00033 { 00034 checkParitySpinor(out, in); 00035 00036 if (!initDslash) initDslashConstants(gauge, in.stride, 0); 00037 00038 if (in.twistFlavor == QUDA_TWIST_NO || in.twistFlavor == QUDA_TWIST_INVALID) 00039 errorQuda("Twist flavor not set %d\n", in.twistFlavor); 00040 00041 double flavor_mu = in.twistFlavor * mu; 00042 00043 twistGamma5Cuda(out.v, out.norm, in.v, in.norm, dagger, kappa, flavor_mu, 00044 in.volume, in.length, in.precision, twistType); 00045 } 00046 00047 00048 // Public method to apply the twist 00049 void DiracTwistedMass::Twist(cudaColorSpinorField &out, const cudaColorSpinorField &in) const 00050 { 00051 twistedApply(out, in, QUDA_TWIST_GAMMA5_DIRECT); 00052 } 00053 00054 void DiracTwistedMass::M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const 00055 { 00056 checkFullSpinor(out, in); 00057 if (in.twistFlavor != out.twistFlavor) 00058 errorQuda("Twist flavors %d %d don't match", in.twistFlavor, out.twistFlavor); 00059 00060 if (in.twistFlavor == QUDA_TWIST_NO || in.twistFlavor == QUDA_TWIST_INVALID) { 00061 errorQuda("Twist flavor not set %d\n", in.twistFlavor); 00062 } 00063 00064 // We can eliminate this temporary at the expense of more kernels (like clover) 00065 cudaColorSpinorField *tmp=0; // this hack allows for tmp2 to be full or parity field 00066 if (tmp2) { 00067 if (tmp2->SiteSubset() == QUDA_FULL_SITE_SUBSET) tmp = &(tmp2->Even()); 00068 else tmp = tmp2; 00069 } 00070 bool reset = newTmp(&tmp, in.Even()); 00071 00072 Twist(*tmp, in.Odd()); 00073 DslashXpay(out.Odd(), in.Even(), QUDA_ODD_PARITY, *tmp, -kappa); 00074 Twist(*tmp, in.Even()); 00075 DslashXpay(out.Even(), in.Odd(), QUDA_EVEN_PARITY, *tmp, -kappa); 00076 00077 deleteTmp(&tmp, reset); 00078 00079 } 00080 00081 void DiracTwistedMass::MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const 00082 { 00083 checkFullSpinor(out, in); 00084 bool reset = newTmp(&tmp1, in); 00085 00086 M(*tmp1, in); 00087 Mdag(out, *tmp1); 00088 00089 deleteTmp(&tmp1, reset); 00090 } 00091 00092 void DiracTwistedMass::prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol, 00093 cudaColorSpinorField &x, cudaColorSpinorField &b, 00094 const QudaSolutionType solType) const 00095 { 00096 if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) { 00097 errorQuda("Preconditioned solution requires a preconditioned solve_type"); 00098 } 00099 00100 src = &b; 00101 sol = &x; 00102 } 00103 00104 void DiracTwistedMass::reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b, 00105 const QudaSolutionType solType) const 00106 { 00107 // do nothing 00108 } 00109 00110 DiracTwistedMassPC::DiracTwistedMassPC(const DiracParam ¶m) 00111 : DiracTwistedMass(param) 00112 { 00113 00114 } 00115 00116 DiracTwistedMassPC::DiracTwistedMassPC(const DiracTwistedMassPC &dirac) 00117 : DiracTwistedMass(dirac) 00118 { 00119 00120 } 00121 00122 DiracTwistedMassPC::~DiracTwistedMassPC() 00123 { 00124 00125 } 00126 00127 DiracTwistedMassPC& DiracTwistedMassPC::operator=(const DiracTwistedMassPC &dirac) 00128 { 00129 if (&dirac != this) { 00130 DiracTwistedMass::operator=(dirac); 00131 } 00132 return *this; 00133 } 00134 00135 // Public method to apply the inverse twist 00136 void DiracTwistedMassPC::TwistInv(cudaColorSpinorField &out, const cudaColorSpinorField &in) const 00137 { 00138 twistedApply(out, in, QUDA_TWIST_GAMMA5_INVERSE); 00139 } 00140 00141 // apply hopping term, then inverse twist: (A_ee^-1 D_eo) or (A_oo^-1 D_oe), 00142 // and likewise for dagger: (D^dagger_eo D_ee^-1) or (D^dagger_oe A_oo^-1) 00143 void DiracTwistedMassPC::Dslash(cudaColorSpinorField &out, const cudaColorSpinorField &in, 00144 const QudaParity parity) const 00145 { 00146 if (!initDslash) initDslashConstants(gauge, in.stride, 0); 00147 checkParitySpinor(in, out); 00148 checkSpinorAlias(in, out); 00149 00150 if (in.twistFlavor != out.twistFlavor) 00151 errorQuda("Twist flavors %d %d don't match", in.twistFlavor, out.twistFlavor); 00152 if (in.twistFlavor == QUDA_TWIST_NO || in.twistFlavor == QUDA_TWIST_INVALID) 00153 errorQuda("Twist flavor not set %d\n", in.twistFlavor); 00154 00155 if (!dagger || matpcType == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC || matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC) { 00156 double flavor_mu = in.twistFlavor * mu; 00157 twistedMassDslashCuda(out.v, out.norm, gauge, in.v, in.norm, parity, dagger, 00158 0, 0, kappa, flavor_mu, 0.0, out.volume, out.length, in.Precision()); 00159 flops += (1320+72)*in.volume; 00160 } else { // safe to use tmp2 here which may alias in 00161 bool reset = newTmp(&tmp2, in); 00162 00163 TwistInv(*tmp2, in); 00164 DiracWilson::Dslash(out, *tmp2, parity); 00165 00166 flops += 72*in.volume; 00167 00168 // if the pointers alias, undo the twist 00169 if (tmp2->v == in.v) Twist(*tmp2, *tmp2); 00170 00171 deleteTmp(&tmp2, reset); 00172 } 00173 00174 } 00175 00176 // xpay version of the above 00177 void DiracTwistedMassPC::DslashXpay(cudaColorSpinorField &out, const cudaColorSpinorField &in, 00178 const QudaParity parity, const cudaColorSpinorField &x, 00179 const double &k) const 00180 { 00181 if (!initDslash) initDslashConstants(gauge, in.stride, 0); 00182 checkParitySpinor(in, out); 00183 checkSpinorAlias(in, out); 00184 00185 if (in.twistFlavor != out.twistFlavor) 00186 errorQuda("Twist flavors %d %d don't match", in.twistFlavor, out.twistFlavor); 00187 if (in.twistFlavor == QUDA_TWIST_NO || in.twistFlavor == QUDA_TWIST_INVALID) 00188 errorQuda("Twist flavor not set %d\n", in.twistFlavor); 00189 00190 if (!dagger) { 00191 double flavor_mu = in.twistFlavor * mu; 00192 twistedMassDslashCuda(out.v, out.norm, gauge, in.v, in.norm, parity, dagger, 00193 x.v, x.norm, kappa, flavor_mu, k, out.volume, out.length, in.Precision()); 00194 flops += (1320+96)*in.volume; 00195 } else { // tmp1 can alias in, but tmp2 can alias x so must not use this 00196 bool reset = newTmp(&tmp1, in); 00197 00198 TwistInv(*tmp1, in); 00199 DiracWilson::Dslash(out, *tmp1, parity); 00200 xpayCuda(x, k, out); 00201 flops += 96*in.volume; 00202 00203 // if the pointers alias, undo the twist 00204 if (tmp1->v == in.v) Twist(*tmp1, *tmp1); 00205 00206 deleteTmp(&tmp1, reset); 00207 } 00208 00209 } 00210 00211 void DiracTwistedMassPC::M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const 00212 { 00213 double kappa2 = -kappa*kappa; 00214 00215 bool reset = newTmp(&tmp1, in); 00216 00217 if (matpcType == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC) { 00218 Dslash(*tmp1, in, QUDA_ODD_PARITY); // fused kernel 00219 Twist(out, in); 00220 DiracWilson::DslashXpay(out, *tmp1, QUDA_EVEN_PARITY, out, kappa2); // safe since out is not read after writing 00221 } else if (matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC) { 00222 Dslash(*tmp1, in, QUDA_EVEN_PARITY); // fused kernel 00223 Twist(out, in); 00224 DiracWilson::DslashXpay(out, *tmp1, QUDA_ODD_PARITY, out, kappa2); 00225 } else { // symmetric preconditioning 00226 if (matpcType == QUDA_MATPC_EVEN_EVEN) { 00227 Dslash(*tmp1, in, QUDA_ODD_PARITY); 00228 DslashXpay(out, *tmp1, QUDA_EVEN_PARITY, in, kappa2); 00229 } else if (matpcType == QUDA_MATPC_ODD_ODD) { 00230 Dslash(*tmp1, in, QUDA_EVEN_PARITY); 00231 DslashXpay(out, *tmp1, QUDA_ODD_PARITY, in, kappa2); 00232 } else { 00233 errorQuda("Invalid matpcType"); 00234 } 00235 } 00236 00237 deleteTmp(&tmp1, reset); 00238 00239 } 00240 00241 void DiracTwistedMassPC::MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const 00242 { 00243 // need extra temporary because of symmetric preconditioning dagger 00244 bool reset = newTmp(&tmp2, in); 00245 00246 M(*tmp2, in); 00247 Mdag(out, *tmp2); 00248 00249 deleteTmp(&tmp2, reset); 00250 } 00251 00252 void DiracTwistedMassPC::prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol, 00253 cudaColorSpinorField &x, cudaColorSpinorField &b, 00254 const QudaSolutionType solType) const 00255 { 00256 // we desire solution to preconditioned system 00257 if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) { 00258 src = &b; 00259 sol = &x; 00260 return; 00261 } 00262 00263 bool reset = newTmp(&tmp1, b.Even()); 00264 00265 // we desire solution to full system 00266 if (matpcType == QUDA_MATPC_EVEN_EVEN) { 00267 // src = A_ee^-1 (b_e + k D_eo A_oo^-1 b_o) 00268 src = &(x.Odd()); 00269 TwistInv(*src, b.Odd()); 00270 DiracWilson::DslashXpay(*tmp1, *src, QUDA_EVEN_PARITY, b.Even(), kappa); 00271 TwistInv(*src, *tmp1); 00272 sol = &(x.Even()); 00273 } else if (matpcType == QUDA_MATPC_ODD_ODD) { 00274 // src = A_oo^-1 (b_o + k D_oe A_ee^-1 b_e) 00275 src = &(x.Even()); 00276 TwistInv(*src, b.Even()); 00277 DiracWilson::DslashXpay(*tmp1, *src, QUDA_ODD_PARITY, b.Odd(), kappa); 00278 TwistInv(*src, *tmp1); 00279 sol = &(x.Odd()); 00280 } else if (matpcType == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC) { 00281 // src = b_e + k D_eo A_oo^-1 b_o 00282 src = &(x.Odd()); 00283 TwistInv(*tmp1, b.Odd()); // safe even when *tmp1 = b.odd 00284 DiracWilson::DslashXpay(*src, *tmp1, QUDA_EVEN_PARITY, b.Even(), kappa); 00285 sol = &(x.Even()); 00286 } else if (matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC) { 00287 // src = b_o + k D_oe A_ee^-1 b_e 00288 src = &(x.Even()); 00289 TwistInv(*tmp1, b.Even()); // safe even when *tmp1 = b.even 00290 DiracWilson::DslashXpay(*src, *tmp1, QUDA_ODD_PARITY, b.Odd(), kappa); 00291 sol = &(x.Odd()); 00292 } else { 00293 errorQuda("MatPCType %d not valid for DiracTwistedMassPC", matpcType); 00294 } 00295 00296 // here we use final solution to store parity solution and parity source 00297 // b is now up for grabs if we want 00298 00299 deleteTmp(&tmp1, reset); 00300 } 00301 00302 void DiracTwistedMassPC::reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b, 00303 const QudaSolutionType solType) const 00304 { 00305 if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) { 00306 return; 00307 } 00308 00309 checkFullSpinor(x, b); 00310 bool reset = newTmp(&tmp1, b.Even()); 00311 00312 // create full solution 00313 00314 if (matpcType == QUDA_MATPC_EVEN_EVEN || 00315 matpcType == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC) { 00316 // x_o = A_oo^-1 (b_o + k D_oe x_e) 00317 DiracWilson::DslashXpay(*tmp1, x.Even(), QUDA_ODD_PARITY, b.Odd(), kappa); 00318 TwistInv(x.Odd(), *tmp1); 00319 } else if (matpcType == QUDA_MATPC_ODD_ODD || 00320 matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC) { 00321 // x_e = A_ee^-1 (b_e + k D_eo x_o) 00322 DiracWilson::DslashXpay(*tmp1, x.Odd(), QUDA_EVEN_PARITY, b.Even(), kappa); 00323 TwistInv(x.Even(), *tmp1); 00324 } else { 00325 errorQuda("MatPCType %d not valid for DiracTwistedMassPC", matpcType); 00326 } 00327 00328 deleteTmp(&tmp1, reset); 00329 }
1.7.3