|
QUDA v0.3.2
A library for QCD on GPUs
|
00001 #ifndef _DIRAC_QUDA_H 00002 #define _DIRAC_QUDA_H 00003 00004 #include <quda_internal.h> 00005 #include <color_spinor_field.h> 00006 #include <dslash_quda.h> 00007 00008 // Params for Dirac operator 00009 class DiracParam { 00010 00011 public: 00012 QudaDiracType type; 00013 double kappa; 00014 double mass; 00015 double m5; // used by domain wall only 00016 MatPCType matpcType; 00017 DagType dagger; 00018 FullGauge *gauge; 00019 FullGauge *fatGauge; // used by staggered only 00020 FullGauge *longGauge; // used by staggered only 00021 FullClover *clover; 00022 FullClover *cloverInv; 00023 00024 double mu; // used by twisted mass only 00025 00026 cudaColorSpinorField *tmp1; 00027 cudaColorSpinorField *tmp2; // used only by Clover and TM 00028 00029 QudaVerbosity verbose; 00030 00031 DiracParam() 00032 : type(QUDA_INVALID_DIRAC), kappa(0.0), m5(0.0), matpcType(QUDA_MATPC_INVALID), 00033 dagger(QUDA_DAG_INVALID), gauge(0), clover(0), cloverInv(0), mu(0.0), 00034 tmp1(0), tmp2(0), verbose(QUDA_SILENT) 00035 { 00036 00037 } 00038 00039 }; 00040 00041 void setDiracParam(DiracParam &diracParam, QudaInvertParam *inv_param, bool pc); 00042 void setDiracSloppyParam(DiracParam &diracParam, QudaInvertParam *inv_param, bool pc); 00043 00044 // forward declarations 00045 class DiracM; 00046 class DiracMdagM; 00047 class DiracMdag; 00048 00049 // Abstract base class 00050 class Dirac { 00051 00052 friend class DiracM; 00053 friend class DiracMdagM; 00054 friend class DiracMdag; 00055 00056 protected: 00057 FullGauge &gauge; 00058 double kappa; 00059 double mass; 00060 MatPCType matpcType; 00061 mutable DagType dagger; // mutable to simplify implementation of Mdag 00062 mutable unsigned long long flops; 00063 mutable cudaColorSpinorField *tmp1; // temporary hack 00064 mutable cudaColorSpinorField *tmp2; // temporary hack 00065 00066 bool newTmp(cudaColorSpinorField **, const cudaColorSpinorField &) const; 00067 void deleteTmp(cudaColorSpinorField **, const bool &reset) const; 00068 00069 public: 00070 Dirac(const DiracParam ¶m); 00071 Dirac(const Dirac &dirac); 00072 virtual ~Dirac(); 00073 Dirac& operator=(const Dirac &dirac); 00074 00075 virtual void checkParitySpinor(const cudaColorSpinorField &, const cudaColorSpinorField &) const; 00076 virtual void checkFullSpinor(const cudaColorSpinorField &, const cudaColorSpinorField &) const; 00077 void checkSpinorAlias(const cudaColorSpinorField &, const cudaColorSpinorField &) const; 00078 00079 virtual void Dslash(cudaColorSpinorField &out, const cudaColorSpinorField &in, 00080 const QudaParity parity) const = 0; 00081 virtual void DslashXpay(cudaColorSpinorField &out, const cudaColorSpinorField &in, 00082 const QudaParity parity, const cudaColorSpinorField &x, 00083 const double &k) const = 0; 00084 virtual void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const = 0; 00085 virtual void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const = 0; 00086 void Mdag(cudaColorSpinorField &out, const cudaColorSpinorField &in) const; 00087 00088 // required methods to use e-o preconditioning for solving full system 00089 virtual void prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol, 00090 cudaColorSpinorField &x, cudaColorSpinorField &b, 00091 const QudaSolutionType) const = 0; 00092 virtual void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b, 00093 const QudaSolutionType) const = 0; 00094 00095 // Dirac operator factory 00096 static Dirac* create(const DiracParam ¶m); 00097 00098 unsigned long long Flops() const { unsigned long long rtn = flops; flops = 0; return rtn; } 00099 }; 00100 00101 // Full Wilson 00102 class DiracWilson : public Dirac { 00103 00104 protected: 00105 00106 public: 00107 DiracWilson(const DiracParam ¶m); 00108 DiracWilson(const DiracWilson &dirac); 00109 virtual ~DiracWilson(); 00110 DiracWilson& operator=(const DiracWilson &dirac); 00111 00112 virtual void Dslash(cudaColorSpinorField &out, const cudaColorSpinorField &in, 00113 const QudaParity parity) const; 00114 virtual void DslashXpay(cudaColorSpinorField &out, const cudaColorSpinorField &in, 00115 const QudaParity parity, const cudaColorSpinorField &x, const double &k) const; 00116 virtual void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const; 00117 virtual void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const; 00118 00119 virtual void prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol, 00120 cudaColorSpinorField &x, cudaColorSpinorField &b, 00121 const QudaSolutionType) const; 00122 virtual void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b, 00123 const QudaSolutionType) const; 00124 }; 00125 00126 // Even-odd preconditioned Wilson 00127 class DiracWilsonPC : public DiracWilson { 00128 00129 private: 00130 00131 public: 00132 DiracWilsonPC(const DiracParam ¶m); 00133 DiracWilsonPC(const DiracWilsonPC &dirac); 00134 virtual ~DiracWilsonPC(); 00135 DiracWilsonPC& operator=(const DiracWilsonPC &dirac); 00136 00137 void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const; 00138 void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const; 00139 00140 void prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol, 00141 cudaColorSpinorField &x, cudaColorSpinorField &b, 00142 const QudaSolutionType) const; 00143 void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b, 00144 const QudaSolutionType) const; 00145 }; 00146 00147 // Full clover 00148 class DiracClover : public DiracWilson { 00149 00150 protected: 00151 FullClover &clover; 00152 void checkParitySpinor(const cudaColorSpinorField &, const cudaColorSpinorField &, 00153 const FullClover &) const; 00154 void cloverApply(cudaColorSpinorField &out, const FullClover &clover, const cudaColorSpinorField &in, 00155 const QudaParity parity) const; 00156 00157 public: 00158 DiracClover(const DiracParam ¶m); 00159 DiracClover(const DiracClover &dirac); 00160 virtual ~DiracClover(); 00161 DiracClover& operator=(const DiracClover &dirac); 00162 00163 void Clover(cudaColorSpinorField &out, const cudaColorSpinorField &in, const QudaParity parity) const; 00164 virtual void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const; 00165 virtual void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const; 00166 00167 virtual void prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol, 00168 cudaColorSpinorField &x, cudaColorSpinorField &b, 00169 const QudaSolutionType) const; 00170 virtual void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b, 00171 const QudaSolutionType) const; 00172 }; 00173 00174 // Even-odd preconditioned clover 00175 class DiracCloverPC : public DiracClover { 00176 00177 private: 00178 FullClover &cloverInv; 00179 00180 public: 00181 DiracCloverPC(const DiracParam ¶m); 00182 DiracCloverPC(const DiracCloverPC &dirac); 00183 virtual ~DiracCloverPC(); 00184 DiracCloverPC& operator=(const DiracCloverPC &dirac); 00185 00186 void CloverInv(cudaColorSpinorField &out, const cudaColorSpinorField &in, const QudaParity parity) const; 00187 void Dslash(cudaColorSpinorField &out, const cudaColorSpinorField &in, 00188 const QudaParity parity) const; 00189 void DslashXpay(cudaColorSpinorField &out, const cudaColorSpinorField &in, 00190 const QudaParity parity, const cudaColorSpinorField &x, const double &k) const; 00191 00192 void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const; 00193 void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const; 00194 00195 void prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol, 00196 cudaColorSpinorField &x, cudaColorSpinorField &b, 00197 const QudaSolutionType) const; 00198 void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b, 00199 const QudaSolutionType) const; 00200 }; 00201 00202 00203 00204 // Full domain wall 00205 class DiracDomainWall : public DiracWilson { 00206 00207 protected: 00208 double m5; 00209 double kappa5; 00210 00211 public: 00212 DiracDomainWall(const DiracParam ¶m); 00213 DiracDomainWall(const DiracDomainWall &dirac); 00214 virtual ~DiracDomainWall(); 00215 DiracDomainWall& operator=(const DiracDomainWall &dirac); 00216 00217 void Dslash(cudaColorSpinorField &out, const cudaColorSpinorField &in, 00218 const QudaParity parity) const; 00219 void DslashXpay(cudaColorSpinorField &out, const cudaColorSpinorField &in, 00220 const QudaParity parity, const cudaColorSpinorField &x, const double &k) const; 00221 00222 virtual void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const; 00223 virtual void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const; 00224 00225 virtual void prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol, 00226 cudaColorSpinorField &x, cudaColorSpinorField &b, 00227 const QudaSolutionType) const; 00228 virtual void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b, 00229 const QudaSolutionType) const; 00230 }; 00231 00232 // 5d Even-odd preconditioned domain wall 00233 class DiracDomainWallPC : public DiracDomainWall { 00234 00235 private: 00236 00237 public: 00238 DiracDomainWallPC(const DiracParam ¶m); 00239 DiracDomainWallPC(const DiracDomainWallPC &dirac); 00240 virtual ~DiracDomainWallPC(); 00241 DiracDomainWallPC& operator=(const DiracDomainWallPC &dirac); 00242 00243 void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const; 00244 void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const; 00245 00246 void prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol, 00247 cudaColorSpinorField &x, cudaColorSpinorField &b, 00248 const QudaSolutionType) const; 00249 void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b, 00250 const QudaSolutionType) const; 00251 }; 00252 00253 // Full staggered 00254 class DiracStaggered : public Dirac { 00255 00256 protected: 00257 FullGauge *fatGauge; 00258 FullGauge *longGauge; 00259 00260 public: 00261 DiracStaggered(const DiracParam ¶m); 00262 DiracStaggered(const DiracStaggered &dirac); 00263 virtual ~DiracStaggered(); 00264 DiracStaggered& operator=(const DiracStaggered &dirac); 00265 00266 virtual void checkParitySpinor(const cudaColorSpinorField &, const cudaColorSpinorField &) const; 00267 00268 virtual void Dslash(cudaColorSpinorField &out, const cudaColorSpinorField &in, 00269 const QudaParity parity) const; 00270 virtual void DslashXpay(cudaColorSpinorField &out, const cudaColorSpinorField &in, 00271 const QudaParity parity, const cudaColorSpinorField &x, const double &k) const; 00272 virtual void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const; 00273 virtual void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const; 00274 00275 virtual void prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol, 00276 cudaColorSpinorField &x, cudaColorSpinorField &b, 00277 const QudaSolutionType) const; 00278 virtual void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b, 00279 const QudaSolutionType) const; 00280 }; 00281 00282 // Even-odd preconditioned staggered 00283 class DiracStaggeredPC : public Dirac { 00284 00285 protected: 00286 FullGauge *fatGauge; 00287 FullGauge *longGauge; 00288 00289 public: 00290 DiracStaggeredPC(const DiracParam ¶m); 00291 DiracStaggeredPC(const DiracStaggeredPC &dirac); 00292 virtual ~DiracStaggeredPC(); 00293 DiracStaggeredPC& operator=(const DiracStaggeredPC &dirac); 00294 00295 virtual void checkParitySpinor(const cudaColorSpinorField &, const cudaColorSpinorField &) const; 00296 00297 virtual void Dslash(cudaColorSpinorField &out, const cudaColorSpinorField &in, 00298 const QudaParity parity) const; 00299 virtual void DslashXpay(cudaColorSpinorField &out, const cudaColorSpinorField &in, 00300 const QudaParity parity, const cudaColorSpinorField &x, const double &k) const; 00301 virtual void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const; 00302 virtual void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const; 00303 00304 virtual void prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol, 00305 cudaColorSpinorField &x, cudaColorSpinorField &b, 00306 const QudaSolutionType) const; 00307 virtual void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b, 00308 const QudaSolutionType) const; 00309 }; 00310 00311 // Full twisted mass 00312 class DiracTwistedMass : public DiracWilson { 00313 00314 protected: 00315 double mu; 00316 void twistedApply(cudaColorSpinorField &out, const cudaColorSpinorField &in, 00317 const QudaTwistGamma5Type twistType) const; 00318 00319 public: 00320 DiracTwistedMass(const DiracParam ¶m); 00321 DiracTwistedMass(const DiracTwistedMass &dirac); 00322 virtual ~DiracTwistedMass(); 00323 DiracTwistedMass& operator=(const DiracTwistedMass &dirac); 00324 00325 void Twist(cudaColorSpinorField &out, const cudaColorSpinorField &in) const; 00326 00327 virtual void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const; 00328 virtual void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const; 00329 00330 virtual void prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol, 00331 cudaColorSpinorField &x, cudaColorSpinorField &b, 00332 const QudaSolutionType) const; 00333 virtual void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b, 00334 const QudaSolutionType) const; 00335 }; 00336 00337 // Even-odd preconditioned twisted mass 00338 class DiracTwistedMassPC : public DiracTwistedMass { 00339 00340 private: 00341 00342 public: 00343 DiracTwistedMassPC(const DiracParam ¶m); 00344 DiracTwistedMassPC(const DiracTwistedMassPC &dirac); 00345 virtual ~DiracTwistedMassPC(); 00346 DiracTwistedMassPC& operator=(const DiracTwistedMassPC &dirac); 00347 00348 void TwistInv(cudaColorSpinorField &out, const cudaColorSpinorField &in) const; 00349 00350 virtual void Dslash(cudaColorSpinorField &out, const cudaColorSpinorField &in, 00351 const QudaParity parity) const; 00352 virtual void DslashXpay(cudaColorSpinorField &out, const cudaColorSpinorField &in, 00353 const QudaParity parity, const cudaColorSpinorField &x, const double &k) const; 00354 void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const; 00355 void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const; 00356 00357 void prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol, 00358 cudaColorSpinorField &x, cudaColorSpinorField &b, 00359 const QudaSolutionType) const; 00360 void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b, 00361 const QudaSolutionType) const; 00362 }; 00363 00364 // Functor base class for applying a given Dirac matrix (M, MdagM, etc.) 00365 class DiracMatrix { 00366 00367 protected: 00368 const Dirac *dirac; 00369 00370 public: 00371 DiracMatrix(const Dirac &d) : dirac(&d) { } 00372 DiracMatrix(const Dirac *d) : dirac(d) { } 00373 virtual ~DiracMatrix() = 0; 00374 00375 virtual void operator()(cudaColorSpinorField &out, const cudaColorSpinorField &in) const = 0; 00376 virtual void operator()(cudaColorSpinorField &out, const cudaColorSpinorField &in, 00377 cudaColorSpinorField &tmp) const = 0; 00378 virtual void operator()(cudaColorSpinorField &out, const cudaColorSpinorField &in, 00379 cudaColorSpinorField &Tmp1, cudaColorSpinorField &Tmp2) const = 0; 00380 00381 unsigned long long flops() const { return dirac->Flops(); } 00382 }; 00383 00384 inline DiracMatrix::~DiracMatrix() 00385 { 00386 00387 } 00388 00389 class DiracM : public DiracMatrix { 00390 00391 public: 00392 DiracM(const Dirac &d) : DiracMatrix(d) { } 00393 DiracM(const Dirac *d) : DiracMatrix(d) { } 00394 00395 void operator()(cudaColorSpinorField &out, const cudaColorSpinorField &in) const 00396 { 00397 dirac->M(out, in); 00398 } 00399 00400 void operator()(cudaColorSpinorField &out, const cudaColorSpinorField &in, cudaColorSpinorField &tmp) const 00401 { 00402 dirac->tmp1 = &tmp; 00403 dirac->M(out, in); 00404 dirac->tmp1 = NULL; 00405 } 00406 00407 void operator()(cudaColorSpinorField &out, const cudaColorSpinorField &in, 00408 cudaColorSpinorField &Tmp1, cudaColorSpinorField &Tmp2) const 00409 { 00410 dirac->tmp1 = &Tmp1; 00411 dirac->tmp2 = &Tmp2; 00412 dirac->M(out, in); 00413 dirac->tmp2 = NULL; 00414 dirac->tmp1 = NULL; 00415 } 00416 }; 00417 00418 class DiracMdagM : public DiracMatrix { 00419 00420 public: 00421 DiracMdagM(const Dirac &d) : DiracMatrix(d) { } 00422 DiracMdagM(const Dirac *d) : DiracMatrix(d) { } 00423 00424 void operator()(cudaColorSpinorField &out, const cudaColorSpinorField &in) const 00425 { 00426 dirac->MdagM(out, in); 00427 } 00428 00429 void operator()(cudaColorSpinorField &out, const cudaColorSpinorField &in, cudaColorSpinorField &tmp) const 00430 { 00431 dirac->tmp1 = &tmp; 00432 dirac->MdagM(out, in); 00433 dirac->tmp1 = NULL; 00434 } 00435 00436 void operator()(cudaColorSpinorField &out, const cudaColorSpinorField &in, 00437 cudaColorSpinorField &Tmp1, cudaColorSpinorField &Tmp2) const 00438 { 00439 dirac->tmp1 = &Tmp1; 00440 dirac->tmp2 = &Tmp2; 00441 dirac->MdagM(out, in); 00442 dirac->tmp2 = NULL; 00443 dirac->tmp1 = NULL; 00444 } 00445 }; 00446 00447 class DiracMdag : public DiracMatrix { 00448 00449 public: 00450 DiracMdag(const Dirac &d) : DiracMatrix(d) { } 00451 DiracMdag(const Dirac *d) : DiracMatrix(d) { } 00452 00453 void operator()(cudaColorSpinorField &out, const cudaColorSpinorField &in) const 00454 { 00455 dirac->Mdag(out, in); 00456 } 00457 00458 void operator()(cudaColorSpinorField &out, const cudaColorSpinorField &in, cudaColorSpinorField &tmp) const 00459 { 00460 dirac->tmp1 = &tmp; 00461 dirac->Mdag(out, in); 00462 dirac->tmp1 = NULL; 00463 } 00464 00465 void operator()(cudaColorSpinorField &out, const cudaColorSpinorField &in, 00466 cudaColorSpinorField &Tmp1, cudaColorSpinorField &Tmp2) const 00467 { 00468 dirac->tmp1 = &Tmp1; 00469 dirac->tmp2 = &Tmp2; 00470 dirac->Mdag(out, in); 00471 dirac->tmp2 = NULL; 00472 dirac->tmp1 = NULL; 00473 } 00474 }; 00475 00476 #endif // _DIRAC_QUDA_H
1.7.3