// QUDA v0.4.0
// A library for QCD on GPUs
00001 #ifndef _DIRAC_QUDA_H 00002 #define _DIRAC_QUDA_H 00003 00004 #include <quda_internal.h> 00005 #include <color_spinor_field.h> 00006 #include <gauge_field.h> 00007 #include <clover_field.h> 00008 #include <dslash_quda.h> 00009 00010 #include <face_quda.h> 00011 00012 // Params for Dirac operator 00013 class DiracParam { 00014 00015 public: 00016 QudaDiracType type; 00017 double kappa; 00018 double mass; 00019 double m5; // used by domain wall only 00020 MatPCType matpcType; 00021 DagType dagger; 00022 cudaGaugeField *gauge; 00023 cudaGaugeField *fatGauge; // used by staggered only 00024 cudaGaugeField *longGauge; // used by staggered only 00025 cudaCloverField *clover; 00026 00027 double mu; // used by twisted mass only 00028 00029 cudaColorSpinorField *tmp1; 00030 cudaColorSpinorField *tmp2; // used only by Clover and TM 00031 00032 QudaVerbosity verbose; 00033 00034 int commDim[QUDA_MAX_DIM]; // whether to do comms or not 00035 00036 DiracParam() 00037 : type(QUDA_INVALID_DIRAC), kappa(0.0), m5(0.0), matpcType(QUDA_MATPC_INVALID), 00038 dagger(QUDA_DAG_INVALID), gauge(0), clover(0), mu(0.0), 00039 tmp1(0), tmp2(0), verbose(QUDA_SILENT) 00040 { 00041 00042 } 00043 00044 }; 00045 00046 void setDiracParam(DiracParam &diracParam, QudaInvertParam *inv_param, bool pc); 00047 void setDiracSloppyParam(DiracParam &diracParam, QudaInvertParam *inv_param, bool pc); 00048 00049 // forward declarations 00050 class DiracMatrix; 00051 class DiracM; 00052 class DiracMdagM; 00053 class DiracMdag; 00054 00055 // Abstract base class 00056 class Dirac { 00057 00058 friend class DiracMatrix; 00059 friend class DiracM; 00060 friend class DiracMdagM; 00061 friend class DiracMdag; 00062 00063 protected: 00064 cudaGaugeField &gauge; 00065 double kappa; 00066 double mass; 00067 MatPCType matpcType; 00068 mutable DagType dagger; // mutable to simplify implementation of Mdag 00069 mutable unsigned long long flops; 00070 mutable cudaColorSpinorField *tmp1; // temporary hack 00071 
mutable cudaColorSpinorField *tmp2; // temporary hack 00072 00073 bool newTmp(cudaColorSpinorField **, const cudaColorSpinorField &) const; 00074 void deleteTmp(cudaColorSpinorField **, const bool &reset) const; 00075 00076 QudaTune tune; 00077 QudaVerbosity verbose; 00078 00079 int commDim[QUDA_MAX_DIM]; // whether do comms or not 00080 00081 public: 00082 Dirac(const DiracParam ¶m); 00083 Dirac(const Dirac &dirac); 00084 virtual ~Dirac(); 00085 Dirac& operator=(const Dirac &dirac); 00086 00087 virtual void checkParitySpinor(const cudaColorSpinorField &, const cudaColorSpinorField &) const; 00088 virtual void checkFullSpinor(const cudaColorSpinorField &, const cudaColorSpinorField &) const; 00089 void checkSpinorAlias(const cudaColorSpinorField &, const cudaColorSpinorField &) const; 00090 00091 virtual void Dslash(cudaColorSpinorField &out, const cudaColorSpinorField &in, 00092 const QudaParity parity) const = 0; 00093 virtual void DslashXpay(cudaColorSpinorField &out, const cudaColorSpinorField &in, 00094 const QudaParity parity, const cudaColorSpinorField &x, 00095 const double &k) const = 0; 00096 virtual void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const = 0; 00097 virtual void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const = 0; 00098 void Mdag(cudaColorSpinorField &out, const cudaColorSpinorField &in) const; 00099 00100 // required methods to use e-o preconditioning for solving full system 00101 virtual void prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol, 00102 cudaColorSpinorField &x, cudaColorSpinorField &b, 00103 const QudaSolutionType) const = 0; 00104 virtual void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b, 00105 const QudaSolutionType) const = 0; 00106 void setMass(double mass){ this->mass = mass;} 00107 // Dirac operator factory 00108 static Dirac* create(const DiracParam ¶m); 00109 00110 unsigned long long Flops() const { unsigned long long rtn = flops; flops = 0; 
return rtn; } 00111 QudaVerbosity Verbose() const { return verbose; } 00112 }; 00113 00114 // Full Wilson 00115 class DiracWilson : public Dirac { 00116 00117 protected: 00118 FaceBuffer face; // multi-gpu communication buffers 00119 00120 public: 00121 DiracWilson(const DiracParam ¶m); 00122 DiracWilson(const DiracWilson &dirac); 00123 virtual ~DiracWilson(); 00124 DiracWilson& operator=(const DiracWilson &dirac); 00125 00126 virtual void Dslash(cudaColorSpinorField &out, const cudaColorSpinorField &in, 00127 const QudaParity parity) const; 00128 virtual void DslashXpay(cudaColorSpinorField &out, const cudaColorSpinorField &in, 00129 const QudaParity parity, const cudaColorSpinorField &x, const double &k) const; 00130 virtual void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const; 00131 virtual void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const; 00132 00133 virtual void prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol, 00134 cudaColorSpinorField &x, cudaColorSpinorField &b, 00135 const QudaSolutionType) const; 00136 virtual void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b, 00137 const QudaSolutionType) const; 00138 }; 00139 00140 // Even-odd preconditioned Wilson 00141 class DiracWilsonPC : public DiracWilson { 00142 00143 private: 00144 00145 public: 00146 DiracWilsonPC(const DiracParam ¶m); 00147 DiracWilsonPC(const DiracWilsonPC &dirac); 00148 virtual ~DiracWilsonPC(); 00149 DiracWilsonPC& operator=(const DiracWilsonPC &dirac); 00150 00151 void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const; 00152 void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const; 00153 00154 void prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol, 00155 cudaColorSpinorField &x, cudaColorSpinorField &b, 00156 const QudaSolutionType) const; 00157 void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b, 00158 const QudaSolutionType) const; 00159 
}; 00160 00161 // Full clover 00162 class DiracClover : public DiracWilson { 00163 00164 protected: 00165 cudaCloverField &clover; 00166 void checkParitySpinor(const cudaColorSpinorField &, const cudaColorSpinorField &) const; 00167 00168 public: 00169 DiracClover(const DiracParam ¶m); 00170 DiracClover(const DiracClover &dirac); 00171 virtual ~DiracClover(); 00172 DiracClover& operator=(const DiracClover &dirac); 00173 00174 void Clover(cudaColorSpinorField &out, const cudaColorSpinorField &in, const QudaParity parity) const; 00175 virtual void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const; 00176 virtual void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const; 00177 00178 virtual void prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol, 00179 cudaColorSpinorField &x, cudaColorSpinorField &b, 00180 const QudaSolutionType) const; 00181 virtual void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b, 00182 const QudaSolutionType) const; 00183 }; 00184 00185 // Even-odd preconditioned clover 00186 class DiracCloverPC : public DiracClover { 00187 00188 public: 00189 DiracCloverPC(const DiracParam ¶m); 00190 DiracCloverPC(const DiracCloverPC &dirac); 00191 virtual ~DiracCloverPC(); 00192 DiracCloverPC& operator=(const DiracCloverPC &dirac); 00193 00194 void CloverInv(cudaColorSpinorField &out, const cudaColorSpinorField &in, const QudaParity parity) const; 00195 void Dslash(cudaColorSpinorField &out, const cudaColorSpinorField &in, 00196 const QudaParity parity) const; 00197 void DslashXpay(cudaColorSpinorField &out, const cudaColorSpinorField &in, 00198 const QudaParity parity, const cudaColorSpinorField &x, const double &k) const; 00199 00200 void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const; 00201 void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const; 00202 00203 void prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol, 00204 cudaColorSpinorField 
&x, cudaColorSpinorField &b, 00205 const QudaSolutionType) const; 00206 void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b, 00207 const QudaSolutionType) const; 00208 }; 00209 00210 00211 00212 // Full domain wall 00213 class DiracDomainWall : public DiracWilson { 00214 00215 protected: 00216 double m5; 00217 double kappa5; 00218 00219 public: 00220 DiracDomainWall(const DiracParam ¶m); 00221 DiracDomainWall(const DiracDomainWall &dirac); 00222 virtual ~DiracDomainWall(); 00223 DiracDomainWall& operator=(const DiracDomainWall &dirac); 00224 00225 void Dslash(cudaColorSpinorField &out, const cudaColorSpinorField &in, 00226 const QudaParity parity) const; 00227 void DslashXpay(cudaColorSpinorField &out, const cudaColorSpinorField &in, 00228 const QudaParity parity, const cudaColorSpinorField &x, const double &k) const; 00229 00230 virtual void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const; 00231 virtual void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const; 00232 00233 virtual void prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol, 00234 cudaColorSpinorField &x, cudaColorSpinorField &b, 00235 const QudaSolutionType) const; 00236 virtual void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b, 00237 const QudaSolutionType) const; 00238 }; 00239 00240 // 5d Even-odd preconditioned domain wall 00241 class DiracDomainWallPC : public DiracDomainWall { 00242 00243 private: 00244 00245 public: 00246 DiracDomainWallPC(const DiracParam ¶m); 00247 DiracDomainWallPC(const DiracDomainWallPC &dirac); 00248 virtual ~DiracDomainWallPC(); 00249 DiracDomainWallPC& operator=(const DiracDomainWallPC &dirac); 00250 00251 void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const; 00252 void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const; 00253 00254 void prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol, 00255 cudaColorSpinorField &x, 
cudaColorSpinorField &b, 00256 const QudaSolutionType) const; 00257 void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b, 00258 const QudaSolutionType) const; 00259 }; 00260 00261 // Full twisted mass 00262 class DiracTwistedMass : public DiracWilson { 00263 00264 protected: 00265 double mu; 00266 void twistedApply(cudaColorSpinorField &out, const cudaColorSpinorField &in, 00267 const QudaTwistGamma5Type twistType) const; 00268 00269 public: 00270 DiracTwistedMass(const DiracParam ¶m); 00271 DiracTwistedMass(const DiracTwistedMass &dirac); 00272 virtual ~DiracTwistedMass(); 00273 DiracTwistedMass& operator=(const DiracTwistedMass &dirac); 00274 00275 void Twist(cudaColorSpinorField &out, const cudaColorSpinorField &in) const; 00276 00277 virtual void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const; 00278 virtual void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const; 00279 00280 virtual void prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol, 00281 cudaColorSpinorField &x, cudaColorSpinorField &b, 00282 const QudaSolutionType) const; 00283 virtual void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b, 00284 const QudaSolutionType) const; 00285 }; 00286 00287 // Even-odd preconditioned twisted mass 00288 class DiracTwistedMassPC : public DiracTwistedMass { 00289 00290 public: 00291 DiracTwistedMassPC(const DiracParam ¶m); 00292 DiracTwistedMassPC(const DiracTwistedMassPC &dirac); 00293 virtual ~DiracTwistedMassPC(); 00294 DiracTwistedMassPC& operator=(const DiracTwistedMassPC &dirac); 00295 00296 void TwistInv(cudaColorSpinorField &out, const cudaColorSpinorField &in) const; 00297 00298 virtual void Dslash(cudaColorSpinorField &out, const cudaColorSpinorField &in, 00299 const QudaParity parity) const; 00300 virtual void DslashXpay(cudaColorSpinorField &out, const cudaColorSpinorField &in, 00301 const QudaParity parity, const cudaColorSpinorField &x, const double &k) const; 00302 
void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const; 00303 void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const; 00304 00305 void prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol, 00306 cudaColorSpinorField &x, cudaColorSpinorField &b, 00307 const QudaSolutionType) const; 00308 void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b, 00309 const QudaSolutionType) const; 00310 }; 00311 00312 // Full staggered 00313 class DiracStaggered : public Dirac { 00314 00315 protected: 00316 cudaGaugeField *fatGauge; 00317 cudaGaugeField *longGauge; 00318 FaceBuffer face; // multi-gpu communication buffers 00319 00320 public: 00321 DiracStaggered(const DiracParam ¶m); 00322 DiracStaggered(const DiracStaggered &dirac); 00323 virtual ~DiracStaggered(); 00324 DiracStaggered& operator=(const DiracStaggered &dirac); 00325 00326 virtual void checkParitySpinor(const cudaColorSpinorField &, const cudaColorSpinorField &) const; 00327 00328 virtual void Dslash(cudaColorSpinorField &out, const cudaColorSpinorField &in, 00329 const QudaParity parity) const; 00330 virtual void DslashXpay(cudaColorSpinorField &out, const cudaColorSpinorField &in, 00331 const QudaParity parity, const cudaColorSpinorField &x, const double &k) const; 00332 virtual void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const; 00333 virtual void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const; 00334 00335 virtual void prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol, 00336 cudaColorSpinorField &x, cudaColorSpinorField &b, 00337 const QudaSolutionType) const; 00338 virtual void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b, 00339 const QudaSolutionType) const; 00340 }; 00341 00342 // Even-odd preconditioned staggered 00343 class DiracStaggeredPC : public DiracStaggered { 00344 00345 protected: 00346 00347 public: 00348 DiracStaggeredPC(const DiracParam ¶m); 00349 
DiracStaggeredPC(const DiracStaggeredPC &dirac); 00350 virtual ~DiracStaggeredPC(); 00351 DiracStaggeredPC& operator=(const DiracStaggeredPC &dirac); 00352 00353 virtual void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const; 00354 virtual void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const; 00355 00356 virtual void prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol, 00357 cudaColorSpinorField &x, cudaColorSpinorField &b, 00358 const QudaSolutionType) const; 00359 virtual void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b, 00360 const QudaSolutionType) const; 00361 }; 00362 00363 // Functor base class for applying a given Dirac matrix (M, MdagM, etc.) 00364 class DiracMatrix { 00365 00366 protected: 00367 const Dirac *dirac; 00368 00369 public: 00370 DiracMatrix(const Dirac &d) : dirac(&d) { } 00371 DiracMatrix(const Dirac *d) : dirac(d) { } 00372 virtual ~DiracMatrix() = 0; 00373 00374 virtual void operator()(cudaColorSpinorField &out, const cudaColorSpinorField &in) const = 0; 00375 virtual void operator()(cudaColorSpinorField &out, const cudaColorSpinorField &in, 00376 cudaColorSpinorField &tmp) const = 0; 00377 virtual void operator()(cudaColorSpinorField &out, const cudaColorSpinorField &in, 00378 cudaColorSpinorField &Tmp1, cudaColorSpinorField &Tmp2) const = 0; 00379 00380 unsigned long long flops() const { return dirac->Flops(); } 00381 }; 00382 00383 inline DiracMatrix::~DiracMatrix() 00384 { 00385 00386 } 00387 00388 class DiracM : public DiracMatrix { 00389 00390 public: 00391 DiracM(const Dirac &d) : DiracMatrix(d) { } 00392 DiracM(const Dirac *d) : DiracMatrix(d) { } 00393 00394 void operator()(cudaColorSpinorField &out, const cudaColorSpinorField &in) const 00395 { 00396 dirac->M(out, in); 00397 } 00398 00399 void operator()(cudaColorSpinorField &out, const cudaColorSpinorField &in, cudaColorSpinorField &tmp) const 00400 { 00401 dirac->tmp1 = &tmp; 00402 dirac->M(out, in); 00403 
dirac->tmp1 = NULL; 00404 } 00405 00406 void operator()(cudaColorSpinorField &out, const cudaColorSpinorField &in, 00407 cudaColorSpinorField &Tmp1, cudaColorSpinorField &Tmp2) const 00408 { 00409 dirac->tmp1 = &Tmp1; 00410 dirac->tmp2 = &Tmp2; 00411 dirac->M(out, in); 00412 dirac->tmp2 = NULL; 00413 dirac->tmp1 = NULL; 00414 } 00415 }; 00416 00417 class DiracMdagM : public DiracMatrix { 00418 00419 public: 00420 DiracMdagM(const Dirac &d) : DiracMatrix(d) { } 00421 DiracMdagM(const Dirac *d) : DiracMatrix(d) { } 00422 00423 void operator()(cudaColorSpinorField &out, const cudaColorSpinorField &in) const 00424 { 00425 dirac->MdagM(out, in); 00426 } 00427 00428 void operator()(cudaColorSpinorField &out, const cudaColorSpinorField &in, cudaColorSpinorField &tmp) const 00429 { 00430 dirac->tmp1 = &tmp; 00431 dirac->MdagM(out, in); 00432 dirac->tmp1 = NULL; 00433 } 00434 00435 void operator()(cudaColorSpinorField &out, const cudaColorSpinorField &in, 00436 cudaColorSpinorField &Tmp1, cudaColorSpinorField &Tmp2) const 00437 { 00438 dirac->tmp1 = &Tmp1; 00439 dirac->tmp2 = &Tmp2; 00440 dirac->MdagM(out, in); 00441 dirac->tmp2 = NULL; 00442 dirac->tmp1 = NULL; 00443 } 00444 }; 00445 00446 class DiracMdag : public DiracMatrix { 00447 00448 public: 00449 DiracMdag(const Dirac &d) : DiracMatrix(d) { } 00450 DiracMdag(const Dirac *d) : DiracMatrix(d) { } 00451 00452 void operator()(cudaColorSpinorField &out, const cudaColorSpinorField &in) const 00453 { 00454 dirac->Mdag(out, in); 00455 } 00456 00457 void operator()(cudaColorSpinorField &out, const cudaColorSpinorField &in, cudaColorSpinorField &tmp) const 00458 { 00459 dirac->tmp1 = &tmp; 00460 dirac->Mdag(out, in); 00461 dirac->tmp1 = NULL; 00462 } 00463 00464 void operator()(cudaColorSpinorField &out, const cudaColorSpinorField &in, 00465 cudaColorSpinorField &Tmp1, cudaColorSpinorField &Tmp2) const 00466 { 00467 dirac->tmp1 = &Tmp1; 00468 dirac->tmp2 = &Tmp2; 00469 dirac->Mdag(out, in); 00470 dirac->tmp2 = NULL; 00471 
dirac->tmp1 = NULL; 00472 } 00473 }; 00474 00475 #endif // _DIRAC_QUDA_H