QUDA v0.3.2
A library for QCD on GPUs

quda/include/dirac_quda.h

Go to the documentation of this file.
00001 #ifndef _DIRAC_QUDA_H
00002 #define _DIRAC_QUDA_H
00003 
00004 #include <quda_internal.h>
00005 #include <color_spinor_field.h>
00006 #include <dslash_quda.h>
00007 
00008 // Params for Dirac operator
00009 class DiracParam {
00010 
00011  public:
00012   QudaDiracType type;
00013   double kappa;
00014   double mass;
00015   double m5; // used by domain wall only
00016   MatPCType matpcType;
00017   DagType dagger;
00018   FullGauge *gauge;
00019   FullGauge *fatGauge;  // used by staggered only
00020   FullGauge *longGauge; // used by staggered only
00021   FullClover *clover;
00022   FullClover *cloverInv;
00023   
00024   double mu; // used by twisted mass only
00025 
00026   cudaColorSpinorField *tmp1;
00027   cudaColorSpinorField *tmp2; // used only by Clover and TM
00028 
00029   QudaVerbosity verbose;
00030 
00031   DiracParam() 
00032     : type(QUDA_INVALID_DIRAC), kappa(0.0), m5(0.0), matpcType(QUDA_MATPC_INVALID),
00033     dagger(QUDA_DAG_INVALID), gauge(0), clover(0), cloverInv(0), mu(0.0), 
00034     tmp1(0), tmp2(0), verbose(QUDA_SILENT)
00035   {
00036 
00037   }
00038 
00039 };
00040 
00041 void setDiracParam(DiracParam &diracParam, QudaInvertParam *inv_param, bool pc);
00042 void setDiracSloppyParam(DiracParam &diracParam, QudaInvertParam *inv_param, bool pc);
00043 
00044 // forward declarations
00045 class DiracM;
00046 class DiracMdagM;
00047 class DiracMdag;
00048 
00049 // Abstract base class
00050 class Dirac {
00051 
00052   friend class DiracM;
00053   friend class DiracMdagM;
00054   friend class DiracMdag;
00055 
00056  protected:
00057   FullGauge &gauge;
00058   double kappa;
00059   double mass;
00060   MatPCType matpcType;
00061   mutable DagType dagger; // mutable to simplify implementation of Mdag
00062   mutable unsigned long long flops;
00063   mutable cudaColorSpinorField *tmp1; // temporary hack
00064   mutable cudaColorSpinorField *tmp2; // temporary hack
00065 
00066   bool newTmp(cudaColorSpinorField **, const cudaColorSpinorField &) const;
00067   void deleteTmp(cudaColorSpinorField **, const bool &reset) const;
00068 
00069  public:
00070   Dirac(const DiracParam &param);
00071   Dirac(const Dirac &dirac);
00072   virtual ~Dirac();
00073   Dirac& operator=(const Dirac &dirac);
00074 
00075   virtual void checkParitySpinor(const cudaColorSpinorField &, const cudaColorSpinorField &) const;
00076   virtual void checkFullSpinor(const cudaColorSpinorField &, const cudaColorSpinorField &) const;
00077   void checkSpinorAlias(const cudaColorSpinorField &, const cudaColorSpinorField &) const;
00078 
00079   virtual void Dslash(cudaColorSpinorField &out, const cudaColorSpinorField &in, 
00080                       const QudaParity parity) const = 0;
00081   virtual void DslashXpay(cudaColorSpinorField &out, const cudaColorSpinorField &in, 
00082                           const QudaParity parity, const cudaColorSpinorField &x,
00083                           const double &k) const = 0;
00084   virtual void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const = 0;
00085   virtual void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const = 0;
00086   void Mdag(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
00087 
00088   // required methods to use e-o preconditioning for solving full system
00089   virtual void prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol,
00090                        cudaColorSpinorField &x, cudaColorSpinorField &b, 
00091                        const QudaSolutionType) const = 0;
00092   virtual void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b,
00093                            const QudaSolutionType) const = 0;
00094 
00095   // Dirac operator factory
00096   static Dirac* create(const DiracParam &param);
00097 
00098   unsigned long long Flops() const { unsigned long long rtn = flops; flops = 0; return rtn; }
00099 };
00100 
00101 // Full Wilson
00102 class DiracWilson : public Dirac {
00103 
00104  protected:
00105 
00106  public:
00107   DiracWilson(const DiracParam &param);
00108   DiracWilson(const DiracWilson &dirac);
00109   virtual ~DiracWilson();
00110   DiracWilson& operator=(const DiracWilson &dirac);
00111 
00112   virtual void Dslash(cudaColorSpinorField &out, const cudaColorSpinorField &in, 
00113                       const QudaParity parity) const;
00114   virtual void DslashXpay(cudaColorSpinorField &out, const cudaColorSpinorField &in, 
00115                           const QudaParity parity, const cudaColorSpinorField &x, const double &k) const;
00116   virtual void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
00117   virtual void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
00118 
00119   virtual void prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol,
00120                        cudaColorSpinorField &x, cudaColorSpinorField &b, 
00121                        const QudaSolutionType) const;
00122   virtual void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b,
00123                            const QudaSolutionType) const;
00124 };
00125 
00126 // Even-odd preconditioned Wilson
00127 class DiracWilsonPC : public DiracWilson {
00128 
00129  private:
00130 
00131  public:
00132   DiracWilsonPC(const DiracParam &param);
00133   DiracWilsonPC(const DiracWilsonPC &dirac);
00134   virtual ~DiracWilsonPC();
00135   DiracWilsonPC& operator=(const DiracWilsonPC &dirac);
00136 
00137   void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
00138   void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
00139 
00140   void prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol,
00141                cudaColorSpinorField &x, cudaColorSpinorField &b, 
00142                const QudaSolutionType) const;
00143   void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b,
00144                    const QudaSolutionType) const;
00145 };
00146 
00147 // Full clover
00148 class DiracClover : public DiracWilson {
00149 
00150  protected:
00151   FullClover &clover;
00152   void checkParitySpinor(const cudaColorSpinorField &, const cudaColorSpinorField &, 
00153                          const FullClover &) const;
00154   void cloverApply(cudaColorSpinorField &out, const FullClover &clover, const cudaColorSpinorField &in, 
00155                    const QudaParity parity) const;
00156 
00157  public:
00158   DiracClover(const DiracParam &param);
00159   DiracClover(const DiracClover &dirac);
00160   virtual ~DiracClover();
00161   DiracClover& operator=(const DiracClover &dirac);
00162 
00163   void Clover(cudaColorSpinorField &out, const cudaColorSpinorField &in, const QudaParity parity) const;
00164   virtual void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
00165   virtual void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
00166 
00167   virtual void prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol,
00168                        cudaColorSpinorField &x, cudaColorSpinorField &b, 
00169                        const QudaSolutionType) const;
00170   virtual void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b,
00171                            const QudaSolutionType) const;
00172 };
00173 
00174 // Even-odd preconditioned clover
00175 class DiracCloverPC : public DiracClover {
00176 
00177  private:
00178   FullClover &cloverInv;
00179 
00180  public:
00181   DiracCloverPC(const DiracParam &param);
00182   DiracCloverPC(const DiracCloverPC &dirac);
00183   virtual ~DiracCloverPC();
00184   DiracCloverPC& operator=(const DiracCloverPC &dirac);
00185 
00186   void CloverInv(cudaColorSpinorField &out, const cudaColorSpinorField &in, const QudaParity parity) const;
00187   void Dslash(cudaColorSpinorField &out, const cudaColorSpinorField &in, 
00188               const QudaParity parity) const;
00189   void DslashXpay(cudaColorSpinorField &out, const cudaColorSpinorField &in, 
00190                   const QudaParity parity, const cudaColorSpinorField &x, const double &k) const;
00191 
00192   void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
00193   void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
00194 
00195   void prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol,
00196                cudaColorSpinorField &x, cudaColorSpinorField &b, 
00197                const QudaSolutionType) const;
00198   void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b,
00199                    const QudaSolutionType) const;
00200 };
00201 
00202 
00203 
00204 // Full domain wall 
00205 class DiracDomainWall : public DiracWilson {
00206 
00207  protected:
00208   double m5;
00209   double kappa5;
00210 
00211  public:
00212   DiracDomainWall(const DiracParam &param);
00213   DiracDomainWall(const DiracDomainWall &dirac);
00214   virtual ~DiracDomainWall();
00215   DiracDomainWall& operator=(const DiracDomainWall &dirac);
00216 
00217   void Dslash(cudaColorSpinorField &out, const cudaColorSpinorField &in, 
00218               const QudaParity parity) const;
00219   void DslashXpay(cudaColorSpinorField &out, const cudaColorSpinorField &in, 
00220                   const QudaParity parity, const cudaColorSpinorField &x, const double &k) const;
00221 
00222   virtual void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
00223   virtual void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
00224 
00225   virtual void prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol,
00226                        cudaColorSpinorField &x, cudaColorSpinorField &b, 
00227                        const QudaSolutionType) const;
00228   virtual void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b,
00229                            const QudaSolutionType) const;
00230 };
00231 
00232 // 5d Even-odd preconditioned domain wall
00233 class DiracDomainWallPC : public DiracDomainWall {
00234 
00235  private:
00236 
00237  public:
00238   DiracDomainWallPC(const DiracParam &param);
00239   DiracDomainWallPC(const DiracDomainWallPC &dirac);
00240   virtual ~DiracDomainWallPC();
00241   DiracDomainWallPC& operator=(const DiracDomainWallPC &dirac);
00242 
00243   void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
00244   void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
00245 
00246   void prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol,
00247                cudaColorSpinorField &x, cudaColorSpinorField &b, 
00248                const QudaSolutionType) const;
00249   void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b,
00250                    const QudaSolutionType) const;
00251 };
00252 
00253 // Full staggered
00254 class DiracStaggered : public Dirac {
00255 
00256  protected:
00257     FullGauge *fatGauge;
00258     FullGauge *longGauge;
00259 
00260  public:
00261   DiracStaggered(const DiracParam &param);
00262   DiracStaggered(const DiracStaggered &dirac);
00263   virtual ~DiracStaggered();
00264   DiracStaggered& operator=(const DiracStaggered &dirac);
00265 
00266   virtual void checkParitySpinor(const cudaColorSpinorField &, const cudaColorSpinorField &) const;
00267   
00268   virtual void Dslash(cudaColorSpinorField &out, const cudaColorSpinorField &in, 
00269                       const QudaParity parity) const;
00270   virtual void DslashXpay(cudaColorSpinorField &out, const cudaColorSpinorField &in, 
00271                           const QudaParity parity, const cudaColorSpinorField &x, const double &k) const;
00272   virtual void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
00273   virtual void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
00274 
00275   virtual void prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol,
00276                        cudaColorSpinorField &x, cudaColorSpinorField &b, 
00277                        const QudaSolutionType) const;
00278   virtual void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b,
00279                            const QudaSolutionType) const;
00280 };
00281 
00282 // Even-odd preconditioned staggered
00283 class DiracStaggeredPC : public Dirac {
00284 
00285  protected:
00286   FullGauge *fatGauge;
00287   FullGauge *longGauge;
00288 
00289  public:
00290   DiracStaggeredPC(const DiracParam &param);
00291   DiracStaggeredPC(const DiracStaggeredPC &dirac);
00292   virtual ~DiracStaggeredPC();
00293   DiracStaggeredPC& operator=(const DiracStaggeredPC &dirac);
00294 
00295   virtual void checkParitySpinor(const cudaColorSpinorField &, const cudaColorSpinorField &) const;
00296   
00297   virtual void Dslash(cudaColorSpinorField &out, const cudaColorSpinorField &in, 
00298                       const QudaParity parity) const;
00299   virtual void DslashXpay(cudaColorSpinorField &out, const cudaColorSpinorField &in, 
00300                           const QudaParity parity, const cudaColorSpinorField &x, const double &k) const;
00301   virtual void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
00302   virtual void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
00303 
00304   virtual void prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol,
00305                        cudaColorSpinorField &x, cudaColorSpinorField &b, 
00306                        const QudaSolutionType) const;
00307   virtual void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b,
00308                            const QudaSolutionType) const;
00309 };
00310 
00311 // Full twisted mass
00312 class DiracTwistedMass : public DiracWilson {
00313 
00314  protected:
00315   double mu;
00316   void twistedApply(cudaColorSpinorField &out, const cudaColorSpinorField &in, 
00317                     const QudaTwistGamma5Type twistType) const;
00318 
00319  public:
00320   DiracTwistedMass(const DiracParam &param);
00321   DiracTwistedMass(const DiracTwistedMass &dirac);
00322   virtual ~DiracTwistedMass();
00323   DiracTwistedMass& operator=(const DiracTwistedMass &dirac);
00324 
00325   void Twist(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
00326 
00327   virtual void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
00328   virtual void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
00329 
00330   virtual void prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol,
00331                        cudaColorSpinorField &x, cudaColorSpinorField &b, 
00332                        const QudaSolutionType) const;
00333   virtual void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b,
00334                            const QudaSolutionType) const;
00335 };
00336 
00337 // Even-odd preconditioned twisted mass
00338 class DiracTwistedMassPC : public DiracTwistedMass {
00339 
00340  private:
00341 
00342  public:
00343   DiracTwistedMassPC(const DiracParam &param);
00344   DiracTwistedMassPC(const DiracTwistedMassPC &dirac);
00345   virtual ~DiracTwistedMassPC();
00346   DiracTwistedMassPC& operator=(const DiracTwistedMassPC &dirac);
00347 
00348   void TwistInv(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
00349 
00350   virtual void Dslash(cudaColorSpinorField &out, const cudaColorSpinorField &in, 
00351                       const QudaParity parity) const;
00352   virtual void DslashXpay(cudaColorSpinorField &out, const cudaColorSpinorField &in, 
00353                           const QudaParity parity, const cudaColorSpinorField &x, const double &k) const;
00354   void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
00355   void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
00356 
00357   void prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol,
00358                cudaColorSpinorField &x, cudaColorSpinorField &b, 
00359                const QudaSolutionType) const;
00360   void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b,
00361                    const QudaSolutionType) const;
00362 };
00363 
00364 // Functor base class for applying a given Dirac matrix (M, MdagM, etc.)
00365 class DiracMatrix {
00366 
00367  protected:
00368   const Dirac *dirac;
00369 
00370  public:
00371   DiracMatrix(const Dirac &d) : dirac(&d) { }
00372   DiracMatrix(const Dirac *d) : dirac(d) { }
00373   virtual ~DiracMatrix() = 0;
00374 
00375   virtual void operator()(cudaColorSpinorField &out, const cudaColorSpinorField &in) const = 0;
00376   virtual void operator()(cudaColorSpinorField &out, const cudaColorSpinorField &in,
00377                           cudaColorSpinorField &tmp) const = 0;
00378   virtual void operator()(cudaColorSpinorField &out, const cudaColorSpinorField &in,
00379                           cudaColorSpinorField &Tmp1, cudaColorSpinorField &Tmp2) const = 0;
00380 
00381   unsigned long long flops() const { return dirac->Flops(); }
00382 };
00383 
00384 inline DiracMatrix::~DiracMatrix()
00385 {
00386 
00387 }
00388 
00389 class DiracM : public DiracMatrix {
00390 
00391  public:
00392   DiracM(const Dirac &d) : DiracMatrix(d) { }
00393   DiracM(const Dirac *d) : DiracMatrix(d) { }
00394 
00395   void operator()(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
00396   {
00397     dirac->M(out, in);
00398   }
00399 
00400   void operator()(cudaColorSpinorField &out, const cudaColorSpinorField &in, cudaColorSpinorField &tmp) const
00401   {
00402     dirac->tmp1 = &tmp;
00403     dirac->M(out, in);
00404     dirac->tmp1 = NULL;
00405   }
00406 
00407   void operator()(cudaColorSpinorField &out, const cudaColorSpinorField &in, 
00408                   cudaColorSpinorField &Tmp1, cudaColorSpinorField &Tmp2) const
00409   {
00410     dirac->tmp1 = &Tmp1;
00411     dirac->tmp2 = &Tmp2;
00412     dirac->M(out, in);
00413     dirac->tmp2 = NULL;
00414     dirac->tmp1 = NULL;
00415   }
00416 };
00417 
00418 class DiracMdagM : public DiracMatrix {
00419 
00420  public:
00421   DiracMdagM(const Dirac &d) : DiracMatrix(d) { }
00422   DiracMdagM(const Dirac *d) : DiracMatrix(d) { }
00423 
00424   void operator()(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
00425   {
00426     dirac->MdagM(out, in);
00427   }
00428 
00429   void operator()(cudaColorSpinorField &out, const cudaColorSpinorField &in, cudaColorSpinorField &tmp) const
00430   {
00431     dirac->tmp1 = &tmp;
00432     dirac->MdagM(out, in);
00433     dirac->tmp1 = NULL;
00434   }
00435 
00436   void operator()(cudaColorSpinorField &out, const cudaColorSpinorField &in, 
00437                   cudaColorSpinorField &Tmp1, cudaColorSpinorField &Tmp2) const
00438   {
00439     dirac->tmp1 = &Tmp1;
00440     dirac->tmp2 = &Tmp2;
00441     dirac->MdagM(out, in);
00442     dirac->tmp2 = NULL;
00443     dirac->tmp1 = NULL;
00444   }
00445 };
00446 
00447 class DiracMdag : public DiracMatrix {
00448 
00449  public:
00450   DiracMdag(const Dirac &d) : DiracMatrix(d) { }
00451   DiracMdag(const Dirac *d) : DiracMatrix(d) { }
00452 
00453   void operator()(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
00454   {
00455     dirac->Mdag(out, in);
00456   }
00457 
00458   void operator()(cudaColorSpinorField &out, const cudaColorSpinorField &in, cudaColorSpinorField &tmp) const
00459   {
00460     dirac->tmp1 = &tmp;
00461     dirac->Mdag(out, in);
00462     dirac->tmp1 = NULL;
00463   }
00464 
00465   void operator()(cudaColorSpinorField &out, const cudaColorSpinorField &in, 
00466                   cudaColorSpinorField &Tmp1, cudaColorSpinorField &Tmp2) const
00467   {
00468     dirac->tmp1 = &Tmp1;
00469     dirac->tmp2 = &Tmp2;
00470     dirac->Mdag(out, in);
00471     dirac->tmp2 = NULL;
00472     dirac->tmp1 = NULL;
00473   }
00474 };
00475 
00476 #endif // _DIRAC_QUDA_H
 All Classes Files Functions Variables Typedefs Enumerations Enumerator Friends Defines