QUDA v0.4.0
A library for QCD on GPUs
|
00001 #include <iostream> 00002 #include <dirac_quda.h> 00003 #include <blas_quda.h> 00004 00005 DiracClover::DiracClover(const DiracParam ¶m) 00006 : DiracWilson(param), clover(*(param.clover)) { } 00007 00008 DiracClover::DiracClover(const DiracClover &dirac) 00009 : DiracWilson(dirac), clover(dirac.clover) { } 00010 00011 DiracClover::~DiracClover() 00012 { 00013 00014 } 00015 00016 DiracClover& DiracClover::operator=(const DiracClover &dirac) 00017 { 00018 if (&dirac != this) { 00019 DiracWilson::operator=(dirac); 00020 clover = dirac.clover; 00021 } 00022 return *this; 00023 } 00024 00025 void DiracClover::checkParitySpinor(const cudaColorSpinorField &out, const cudaColorSpinorField &in) const 00026 { 00027 Dirac::checkParitySpinor(out, in); 00028 00029 if (out.Volume() != clover.VolumeCB()) { 00030 errorQuda("Parity spinor volume %d doesn't match clover checkboard volume %d", 00031 out.Volume(), clover.VolumeCB()); 00032 } 00033 } 00034 00035 // Public method to apply the clover term only 00036 void DiracClover::Clover(cudaColorSpinorField &out, const cudaColorSpinorField &in, const QudaParity parity) const 00037 { 00038 if (!initDslash) initDslashConstants(gauge, in.Stride()); 00039 if (!initClover) initCloverConstants(clover.Stride()); 00040 checkParitySpinor(in, out); 00041 00042 // regular clover term 00043 FullClover cs; 00044 cs.even = clover.even; cs.odd = clover.odd; cs.evenNorm = clover.evenNorm; cs.oddNorm = clover.oddNorm; 00045 cs.precision = clover.precision; cs.bytes = clover.bytes, cs.norm_bytes = clover.norm_bytes; 00046 cloverCuda(&out, gauge, cs, &in, parity); 00047 00048 flops += 504*in.Volume(); 00049 } 00050 00051 // FIXME: create kernel to eliminate tmp 00052 void DiracClover::M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const 00053 { 00054 checkFullSpinor(out, in); 00055 cudaColorSpinorField *tmp=0; // this hack allows for tmp2 to be full or parity field 00056 if (tmp2) { 00057 if (tmp2->SiteSubset() == QUDA_FULL_SITE_SUBSET) tmp = &(tmp2->Even()); 00058 else tmp = tmp2; 00059 } 00060 bool reset = newTmp(&tmp, in.Even()); 00061 00062 Clover(*tmp, in.Odd(), QUDA_ODD_PARITY); 00063 DslashXpay(out.Odd(), in.Even(), QUDA_ODD_PARITY, *tmp, -kappa); 00064 Clover(*tmp, in.Even(), QUDA_EVEN_PARITY); 00065 DslashXpay(out.Even(), in.Odd(), QUDA_EVEN_PARITY, *tmp, -kappa); 00066 00067 deleteTmp(&tmp, reset); 00068 } 00069 00070 void DiracClover::MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const 00071 { 00072 checkFullSpinor(out, in); 00073 00074 bool reset = newTmp(&tmp1, in); 00075 checkFullSpinor(*tmp1, in); 00076 00077 M(*tmp1, in); 00078 Mdag(out, *tmp1); 00079 00080 deleteTmp(&tmp1, reset); 00081 } 00082 00083 void DiracClover::prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol, 00084 cudaColorSpinorField &x, cudaColorSpinorField &b, 00085 const QudaSolutionType solType) const 00086 { 00087 if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) { 00088 errorQuda("Preconditioned solution requires a preconditioned solve_type"); 00089 } 00090 00091 src = &b; 00092 sol = &x; 00093 } 00094 00095 void DiracClover::reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b, 00096 const QudaSolutionType solType) const 00097 { 00098 // do nothing 00099 } 00100 00101 DiracCloverPC::DiracCloverPC(const DiracParam ¶m) : 00102 DiracClover(param) 00103 { 00104 // For the preconditioned operator, we need to check that the inverse of the clover term is present 00105 if (!clover.cloverInv) errorQuda("Clover inverse required for DiracCloverPC"); 00106 } 00107 00108 DiracCloverPC::DiracCloverPC(const DiracCloverPC &dirac) : DiracClover(dirac) { } 00109 00110 DiracCloverPC::~DiracCloverPC() { } 00111 00112 DiracCloverPC& DiracCloverPC::operator=(const DiracCloverPC &dirac) 00113 { 00114 if (&dirac != this) { 00115 DiracClover::operator=(dirac); 00116 } 00117 return *this; 00118 } 00119 00120 // Public method 00121 void DiracCloverPC::CloverInv(cudaColorSpinorField &out, const cudaColorSpinorField &in, 00122 const QudaParity parity) const 00123 { 00124 if (!initDslash) initDslashConstants(gauge, in.Stride()); 00125 if (!initClover) initCloverConstants(clover.Stride()); 00126 checkParitySpinor(in, out); 00127 00128 // needs to be cloverinv 00129 FullClover cs; 00130 cs.even = clover.evenInv; cs.odd = clover.oddInv; cs.evenNorm = clover.evenInvNorm; cs.oddNorm = clover.oddInvNorm; 00131 cs.precision = clover.precision; cs.bytes = clover.bytes, cs.norm_bytes = clover.norm_bytes; 00132 cloverCuda(&out, gauge, cs, &in, parity); 00133 00134 flops += 504*in.Volume(); 00135 } 00136 00137 // apply hopping term, then clover: (A_ee^-1 D_eo) or (A_oo^-1 D_oe), 00138 // and likewise for dagger: (A_ee^-1 D^dagger_eo) or (A_oo^-1 D^dagger_oe) 00139 // NOTE - this isn't Dslash dagger since order should be reversed! 00140 void DiracCloverPC::Dslash(cudaColorSpinorField &out, const cudaColorSpinorField &in, 00141 const QudaParity parity) const 00142 { 00143 if (!initDslash) initDslashConstants(gauge, in.Stride()); 00144 if (!initClover) initCloverConstants(clover.Stride()); 00145 checkParitySpinor(in, out); 00146 checkSpinorAlias(in, out); 00147 00148 setFace(face); // FIXME: temporary hack maintain C linkage for dslashCuda 00149 00150 FullClover cs; 00151 cs.even = clover.evenInv; cs.odd = clover.oddInv; cs.evenNorm = clover.evenInvNorm; cs.oddNorm = clover.oddInvNorm; 00152 cs.precision = clover.precision; cs.bytes = clover.bytes, cs.norm_bytes = clover.norm_bytes; 00153 cloverDslashCuda(&out, gauge, cs, &in, parity, dagger, 0, 0.0, commDim); 00154 00155 flops += (1320+504)*in.Volume(); 00156 } 00157 00158 // xpay version of the above 00159 void DiracCloverPC::DslashXpay(cudaColorSpinorField &out, const cudaColorSpinorField &in, 00160 const QudaParity parity, const cudaColorSpinorField &x, 00161 const double &k) const 00162 { 00163 if (!initDslash) initDslashConstants(gauge, in.Stride()); 00164 if (!initClover) initCloverConstants(clover.Stride()); 00165 checkParitySpinor(in, out); 00166 checkSpinorAlias(in, out); 00167 00168 setFace(face); // FIXME: temporary hack maintain C linkage for dslashCuda 00169 00170 FullClover cs; 00171 cs.even = clover.evenInv; cs.odd = clover.oddInv; cs.evenNorm = clover.evenInvNorm; cs.oddNorm = clover.oddInvNorm; 00172 cs.precision = clover.precision; cs.bytes = clover.bytes, cs.norm_bytes = clover.norm_bytes; 00173 cloverDslashCuda(&out, gauge, cs, &in, parity, dagger, &x, k, commDim); 00174 00175 flops += (1320+504+48)*in.Volume(); 00176 } 00177 00178 // Apply the even-odd preconditioned clover-improved Dirac operator 00179 void DiracCloverPC::M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const 00180 { 00181 double kappa2 = -kappa*kappa; 00182 00183 // FIXME: For asymmetric, a "DslashCxpay" kernel would improve performance. 00184 bool reset1 = newTmp(&tmp1, in); 00185 00186 if (matpcType == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC) { 00187 bool reset2 = newTmp(&tmp2, in); 00188 // DiracCloverPC::Dslash applies A^{-1}Dslash 00189 Dslash(*tmp1, in, QUDA_ODD_PARITY); 00190 Clover(*tmp2, in, QUDA_EVEN_PARITY); 00191 00192 // DiracWilson::Dslash applies only Dslash 00193 DiracWilson::DslashXpay(out, *tmp1, QUDA_EVEN_PARITY, *tmp2, kappa2); 00194 00195 deleteTmp(&tmp2, reset2); 00196 } else if (matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC) { 00197 00198 // FIXME: It would be nice if I could do something like: cudaColorSpinorField tmp3( in.param() ); 00199 // to save copying the data from 'in' 00200 bool reset2 = newTmp(&tmp2, in); 00201 00202 // DiracCloverPC::Dslash applies A^{-1}Dslash 00203 Dslash(*tmp1, in, QUDA_EVEN_PARITY); 00204 Clover(*tmp2, in, QUDA_ODD_PARITY); 00205 00206 // DiracWilson::Dslash applies only Dslash 00207 DiracWilson::DslashXpay(out, *tmp1, QUDA_ODD_PARITY, *tmp2, kappa2); 00208 00209 deleteTmp(&tmp2, reset2); 00210 } else if (!dagger) { // symmetric preconditioning 00211 if (matpcType == QUDA_MATPC_EVEN_EVEN) { 00212 Dslash(*tmp1, in, QUDA_ODD_PARITY); 00213 DslashXpay(out, *tmp1, QUDA_EVEN_PARITY, in, kappa2); 00214 } else if (matpcType == QUDA_MATPC_ODD_ODD) { 00215 Dslash(*tmp1, in, QUDA_EVEN_PARITY); 00216 DslashXpay(out, *tmp1, QUDA_ODD_PARITY, in, kappa2); 00217 } else { 00218 errorQuda("Invalid matpcType"); 00219 } 00220 } else { // symmetric preconditioning, dagger 00221 if (matpcType == QUDA_MATPC_EVEN_EVEN) { 00222 CloverInv(out, in, QUDA_EVEN_PARITY); 00223 Dslash(*tmp1, out, QUDA_ODD_PARITY); 00224 DiracWilson::DslashXpay(out, *tmp1, QUDA_EVEN_PARITY, in, kappa2); 00225 } else if (matpcType == QUDA_MATPC_ODD_ODD) { 00226 CloverInv(out, in, QUDA_ODD_PARITY); 00227 Dslash(*tmp1, out, QUDA_EVEN_PARITY); 00228 DiracWilson::DslashXpay(out, *tmp1, QUDA_ODD_PARITY, in, kappa2); 00229 } else { 00230 errorQuda("MatPCType %d not valid for DiracCloverPC", matpcType); 00231 } 00232 } 00233 00234 deleteTmp(&tmp1, reset1); 00235 } 00236 00237 void DiracCloverPC::MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const 00238 { 00239 // need extra temporary because of symmetric preconditioning dagger 00240 bool reset = newTmp(&tmp2, in); 00241 00242 M(*tmp2, in); 00243 Mdag(out, *tmp2); 00244 00245 deleteTmp(&tmp2, reset); 00246 } 00247 00248 void DiracCloverPC::prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol, 00249 cudaColorSpinorField &x, cudaColorSpinorField &b, 00250 const QudaSolutionType solType) const 00251 { 00252 // we desire solution to preconditioned system 00253 if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) { 00254 src = &b; 00255 sol = &x; 00256 return; 00257 } 00258 00259 bool reset = newTmp(&tmp1, b.Even()); 00260 00261 // we desire solution to full system 00262 if (matpcType == QUDA_MATPC_EVEN_EVEN) { 00263 // src = A_ee^-1 (b_e + k D_eo A_oo^-1 b_o) 00264 src = &(x.Odd()); 00265 CloverInv(*src, b.Odd(), QUDA_ODD_PARITY); 00266 DiracWilson::DslashXpay(*tmp1, *src, QUDA_EVEN_PARITY, b.Even(), kappa); 00267 CloverInv(*src, *tmp1, QUDA_EVEN_PARITY); 00268 sol = &(x.Even()); 00269 } else if (matpcType == QUDA_MATPC_ODD_ODD) { 00270 // src = A_oo^-1 (b_o + k D_oe A_ee^-1 b_e) 00271 src = &(x.Even()); 00272 CloverInv(*src, b.Even(), QUDA_EVEN_PARITY); 00273 DiracWilson::DslashXpay(*tmp1, *src, QUDA_ODD_PARITY, b.Odd(), kappa); 00274 CloverInv(*src, *tmp1, QUDA_ODD_PARITY); 00275 sol = &(x.Odd()); 00276 } else if (matpcType == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC) { 00277 // src = b_e + k D_eo A_oo^-1 b_o 00278 src = &(x.Odd()); 00279 CloverInv(*tmp1, b.Odd(), QUDA_ODD_PARITY); // safe even when *tmp1 = b.odd 00280 DiracWilson::DslashXpay(*src, *tmp1, QUDA_EVEN_PARITY, b.Even(), kappa); 00281 sol = &(x.Even()); 00282 } else if (matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC) { 00283 // src = b_o + k D_oe A_ee^-1 b_e 00284 src = &(x.Even()); 00285 CloverInv(*tmp1, b.Even(), QUDA_EVEN_PARITY); // safe even when *tmp1 = b.even 00286 DiracWilson::DslashXpay(*src, *tmp1, QUDA_ODD_PARITY, b.Odd(), kappa); 00287 sol = &(x.Odd()); 00288 } else { 00289 errorQuda("MatPCType %d not valid for DiracCloverPC", matpcType); 00290 } 00291 00292 // here we use final solution to store parity solution and parity source 00293 // b is now up for grabs if we want 00294 00295 deleteTmp(&tmp1, reset); 00296 00297 } 00298 00299 void DiracCloverPC::reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b, 00300 const QudaSolutionType solType) const 00301 { 00302 if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) { 00303 return; 00304 } 00305 00306 checkFullSpinor(x, b); 00307 00308 bool reset = newTmp(&tmp1, b.Even()); 00309 00310 // create full solution 00311 00312 if (matpcType == QUDA_MATPC_EVEN_EVEN || 00313 matpcType == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC) { 00314 // x_o = A_oo^-1 (b_o + k D_oe x_e) 00315 DiracWilson::DslashXpay(*tmp1, x.Even(), QUDA_ODD_PARITY, b.Odd(), kappa); 00316 CloverInv(x.Odd(), *tmp1, QUDA_ODD_PARITY); 00317 } else if (matpcType == QUDA_MATPC_ODD_ODD || 00318 matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC) { 00319 // x_e = A_ee^-1 (b_e + k D_eo x_o) 00320 DiracWilson::DslashXpay(*tmp1, x.Odd(), QUDA_EVEN_PARITY, b.Even(), kappa); 00321 CloverInv(x.Even(), *tmp1, QUDA_EVEN_PARITY); 00322 } else { 00323 errorQuda("MatPCType %d not valid for DiracCloverPC", matpcType); 00324 } 00325 00326 deleteTmp(&tmp1, reset); 00327 00328 } 00329