// QUDA v0.3.2
// A library for QCD on GPUs
00001 #include <iostream> 00002 #include <dirac_quda.h> 00003 #include <blas_quda.h> 00004 00005 DiracClover::DiracClover(const DiracParam ¶m) 00006 : DiracWilson(param), clover(*(param.clover)) 00007 { 00008 00009 } 00010 00011 DiracClover::DiracClover(const DiracClover &dirac) 00012 : DiracWilson(dirac), clover(dirac.clover) 00013 { 00014 00015 } 00016 00017 DiracClover::~DiracClover() 00018 { 00019 00020 } 00021 00022 DiracClover& DiracClover::operator=(const DiracClover &dirac) 00023 { 00024 00025 if (&dirac != this) { 00026 DiracWilson::operator=(dirac); 00027 clover = dirac.clover; 00028 } 00029 00030 return *this; 00031 } 00032 00033 void DiracClover::checkParitySpinor(const cudaColorSpinorField &out, const cudaColorSpinorField &in, 00034 const FullClover &clover) const 00035 { 00036 Dirac::checkParitySpinor(out, in); 00037 00038 if (out.Volume() != clover.even.volume) { 00039 errorQuda("Spinor volume %d doesn't match even clover volume %d", 00040 out.Volume(), clover.even.volume); 00041 } 00042 if (out.Volume() != clover.odd.volume) { 00043 errorQuda("Spinor volume %d doesn't match odd clover volume %d", 00044 out.Volume(), clover.odd.volume); 00045 } 00046 00047 } 00048 00049 // Protected method, also used for applying cloverInv 00050 void DiracClover::cloverApply(cudaColorSpinorField &out, const FullClover &clover, 00051 const cudaColorSpinorField &in, const QudaParity parity) const 00052 { 00053 if (!initDslash) initDslashConstants(gauge, in.stride, clover.even.stride); 00054 checkParitySpinor(in, out, clover); 00055 00056 cloverCuda(out.v, out.norm, gauge, clover, in.v, in.norm, parity, 00057 in.volume, in.length, in.Precision()); 00058 00059 flops += 504*in.volume; 00060 } 00061 00062 // Public method to apply the clover term only 00063 void DiracClover::Clover(cudaColorSpinorField &out, const cudaColorSpinorField &in, const QudaParity parity) const 00064 { 00065 cloverApply(out, clover, in, parity); 00066 } 00067 00068 // FIXME: create kernel to 
eliminate tmp 00069 void DiracClover::M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const 00070 { 00071 checkFullSpinor(out, in); 00072 cudaColorSpinorField *tmp=0; // this hack allows for tmp2 to be full or parity field 00073 if (tmp2) { 00074 if (tmp2->SiteSubset() == QUDA_FULL_SITE_SUBSET) tmp = &(tmp2->Even()); 00075 else tmp = tmp2; 00076 } 00077 bool reset = newTmp(&tmp, in.Even()); 00078 00079 Clover(*tmp, in.Odd(), QUDA_ODD_PARITY); 00080 DslashXpay(out.Odd(), in.Even(), QUDA_ODD_PARITY, *tmp, -kappa); 00081 Clover(*tmp, in.Even(), QUDA_EVEN_PARITY); 00082 DslashXpay(out.Even(), in.Odd(), QUDA_EVEN_PARITY, *tmp, -kappa); 00083 00084 deleteTmp(&tmp, reset); 00085 } 00086 00087 void DiracClover::MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const 00088 { 00089 checkFullSpinor(out, in); 00090 00091 bool reset = newTmp(&tmp1, in); 00092 checkFullSpinor(*tmp1, in); 00093 00094 M(*tmp1, in); 00095 Mdag(out, *tmp1); 00096 00097 deleteTmp(&tmp1, reset); 00098 } 00099 00100 void DiracClover::prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol, 00101 cudaColorSpinorField &x, cudaColorSpinorField &b, 00102 const QudaSolutionType solType) const 00103 { 00104 if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) { 00105 errorQuda("Preconditioned solution requires a preconditioned solve_type"); 00106 } 00107 00108 src = &b; 00109 sol = &x; 00110 } 00111 00112 void DiracClover::reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b, 00113 const QudaSolutionType solType) const 00114 { 00115 // do nothing 00116 } 00117 00118 DiracCloverPC::DiracCloverPC(const DiracParam ¶m) 00119 : DiracClover(param), cloverInv(*(param.cloverInv)) 00120 { 00121 00122 } 00123 00124 DiracCloverPC::DiracCloverPC(const DiracCloverPC &dirac) 00125 : DiracClover(dirac), cloverInv(dirac.clover) 00126 { 00127 00128 } 00129 00130 DiracCloverPC::~DiracCloverPC() 00131 { 00132 00133 } 00134 00135 DiracCloverPC& 
DiracCloverPC::operator=(const DiracCloverPC &dirac) 00136 { 00137 if (&dirac != this) { 00138 DiracClover::operator=(dirac); 00139 cloverInv = dirac.cloverInv; 00140 } 00141 00142 return *this; 00143 } 00144 00145 // Public method 00146 void DiracCloverPC::CloverInv(cudaColorSpinorField &out, const cudaColorSpinorField &in, 00147 const QudaParity parity) const 00148 { 00149 cloverApply(out, cloverInv, in, parity); 00150 } 00151 00152 // apply hopping term, then clover: (A_ee^-1 D_eo) or (A_oo^-1 D_oe), 00153 // and likewise for dagger: (A_ee^-1 D^dagger_eo) or (A_oo^-1 D^dagger_oe) 00154 // NOTE - this isn't Dslash dagger since order should be reversed! 00155 void DiracCloverPC::Dslash(cudaColorSpinorField &out, const cudaColorSpinorField &in, 00156 const QudaParity parity) const 00157 { 00158 if (!initDslash) initDslashConstants(gauge, in.stride, cloverInv.even.stride); 00159 checkParitySpinor(in, out, cloverInv); 00160 checkSpinorAlias(in, out); 00161 00162 cloverDslashCuda(out.v, out.norm, gauge, cloverInv, in.v, in.norm, parity, dagger, 00163 0, 0, 0.0, out.volume, out.length, in.Precision()); 00164 00165 flops += (1320+504)*in.volume; 00166 } 00167 00168 // xpay version of the above 00169 void DiracCloverPC::DslashXpay(cudaColorSpinorField &out, const cudaColorSpinorField &in, 00170 const QudaParity parity, const cudaColorSpinorField &x, 00171 const double &k) const 00172 { 00173 if (!initDslash) initDslashConstants(gauge, in.stride, cloverInv.even.stride); 00174 checkParitySpinor(in, out, cloverInv); 00175 checkSpinorAlias(in, out); 00176 00177 cloverDslashCuda(out.v, out.norm, gauge, cloverInv, in.v, in.norm, parity, dagger, 00178 x.v, x.norm, k, out.volume, out.length, in.Precision()); 00179 00180 flops += (1320+504+48)*in.volume; 00181 } 00182 00183 // Apply the even-odd preconditioned clover-improved Dirac operator 00184 void DiracCloverPC::M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const 00185 { 00186 double kappa2 = -kappa*kappa; 
00187 00188 // FIXME: For asymmetric, a "DslashCxpay" kernel would improve performance. 00189 bool reset = newTmp(&tmp1, in); 00190 00191 if (matpcType == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC) { 00192 Dslash(*tmp1, in, QUDA_ODD_PARITY); 00193 Clover(out, in, QUDA_EVEN_PARITY); 00194 DiracWilson::DslashXpay(out, *tmp1, QUDA_EVEN_PARITY, out, kappa2); // safe since out is not read after writing 00195 } else if (matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC) { 00196 Dslash(*tmp1, in, QUDA_EVEN_PARITY); 00197 Clover(out, in, QUDA_ODD_PARITY); 00198 DiracWilson::DslashXpay(out, *tmp1, QUDA_ODD_PARITY, out, kappa2); 00199 } else if (!dagger) { // symmetric preconditioning 00200 if (matpcType == QUDA_MATPC_EVEN_EVEN) { 00201 Dslash(*tmp1, in, QUDA_ODD_PARITY); 00202 DslashXpay(out, *tmp1, QUDA_EVEN_PARITY, in, kappa2); 00203 } else if (matpcType == QUDA_MATPC_ODD_ODD) { 00204 Dslash(*tmp1, in, QUDA_EVEN_PARITY); 00205 DslashXpay(out, *tmp1, QUDA_ODD_PARITY, in, kappa2); 00206 } else { 00207 errorQuda("Invalid matpcType"); 00208 } 00209 } else { // symmetric preconditioning, dagger 00210 if (matpcType == QUDA_MATPC_EVEN_EVEN) { 00211 CloverInv(out, in, QUDA_EVEN_PARITY); 00212 Dslash(*tmp1, out, QUDA_ODD_PARITY); 00213 DiracWilson::DslashXpay(out, *tmp1, QUDA_EVEN_PARITY, in, kappa2); 00214 } else if (matpcType == QUDA_MATPC_ODD_ODD) { 00215 CloverInv(out, in, QUDA_ODD_PARITY); 00216 Dslash(*tmp1, out, QUDA_EVEN_PARITY); 00217 DiracWilson::DslashXpay(out, *tmp1, QUDA_ODD_PARITY, in, kappa2); 00218 } else { 00219 errorQuda("MatPCType %d not valid for DiracCloverPC", matpcType); 00220 } 00221 } 00222 00223 deleteTmp(&tmp1, reset); 00224 } 00225 00226 void DiracCloverPC::MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const 00227 { 00228 // need extra temporary because of symmetric preconditioning dagger 00229 bool reset = newTmp(&tmp2, in); 00230 00231 M(*tmp2, in); 00232 Mdag(out, *tmp2); 00233 00234 deleteTmp(&tmp2, reset); 00235 } 00236 00237 void 
DiracCloverPC::prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol, 00238 cudaColorSpinorField &x, cudaColorSpinorField &b, 00239 const QudaSolutionType solType) const 00240 { 00241 // we desire solution to preconditioned system 00242 if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) { 00243 src = &b; 00244 sol = &x; 00245 return; 00246 } 00247 00248 bool reset = newTmp(&tmp1, b.Even()); 00249 00250 // we desire solution to full system 00251 if (matpcType == QUDA_MATPC_EVEN_EVEN) { 00252 // src = A_ee^-1 (b_e + k D_eo A_oo^-1 b_o) 00253 src = &(x.Odd()); 00254 CloverInv(*src, b.Odd(), QUDA_ODD_PARITY); 00255 DiracWilson::DslashXpay(*tmp1, *src, QUDA_EVEN_PARITY, b.Even(), kappa); 00256 CloverInv(*src, *tmp1, QUDA_EVEN_PARITY); 00257 sol = &(x.Even()); 00258 } else if (matpcType == QUDA_MATPC_ODD_ODD) { 00259 // src = A_oo^-1 (b_o + k D_oe A_ee^-1 b_e) 00260 src = &(x.Even()); 00261 CloverInv(*src, b.Even(), QUDA_EVEN_PARITY); 00262 DiracWilson::DslashXpay(*tmp1, *src, QUDA_ODD_PARITY, b.Odd(), kappa); 00263 CloverInv(*src, *tmp1, QUDA_ODD_PARITY); 00264 sol = &(x.Odd()); 00265 } else if (matpcType == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC) { 00266 // src = b_e + k D_eo A_oo^-1 b_o 00267 src = &(x.Odd()); 00268 CloverInv(*tmp1, b.Odd(), QUDA_ODD_PARITY); // safe even when *tmp1 = b.odd 00269 DiracWilson::DslashXpay(*src, *tmp1, QUDA_EVEN_PARITY, b.Even(), kappa); 00270 sol = &(x.Even()); 00271 } else if (matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC) { 00272 // src = b_o + k D_oe A_ee^-1 b_e 00273 src = &(x.Even()); 00274 CloverInv(*tmp1, b.Even(), QUDA_EVEN_PARITY); // safe even when *tmp1 = b.even 00275 DiracWilson::DslashXpay(*src, *tmp1, QUDA_ODD_PARITY, b.Odd(), kappa); 00276 sol = &(x.Odd()); 00277 } else { 00278 errorQuda("MatPCType %d not valid for DiracCloverPC", matpcType); 00279 } 00280 00281 // here we use final solution to store parity solution and parity source 00282 // b is now up for grabs if we want 00283 00284 
deleteTmp(&tmp1, reset); 00285 00286 } 00287 00288 void DiracCloverPC::reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b, 00289 const QudaSolutionType solType) const 00290 { 00291 if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) { 00292 return; 00293 } 00294 00295 checkFullSpinor(x, b); 00296 00297 bool reset = newTmp(&tmp1, b.Even()); 00298 00299 // create full solution 00300 00301 if (matpcType == QUDA_MATPC_EVEN_EVEN || 00302 matpcType == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC) { 00303 // x_o = A_oo^-1 (b_o + k D_oe x_e) 00304 DiracWilson::DslashXpay(*tmp1, x.Even(), QUDA_ODD_PARITY, b.Odd(), kappa); 00305 CloverInv(x.Odd(), *tmp1, QUDA_ODD_PARITY); 00306 } else if (matpcType == QUDA_MATPC_ODD_ODD || 00307 matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC) { 00308 // x_e = A_ee^-1 (b_e + k D_eo x_o) 00309 DiracWilson::DslashXpay(*tmp1, x.Odd(), QUDA_EVEN_PARITY, b.Even(), kappa); 00310 CloverInv(x.Even(), *tmp1, QUDA_EVEN_PARITY); 00311 } else { 00312 errorQuda("MatPCType %d not valid for DiracCloverPC", matpcType); 00313 } 00314 00315 deleteTmp(&tmp1, reset); 00316 00317 } 00318
// Listing generated by Doxygen 1.7.3