QUDA v0.4.0
A library for QCD on GPUs
quda/include/dirac_quda.h
Go to the documentation of this file.
00001 #ifndef _DIRAC_QUDA_H
00002 #define _DIRAC_QUDA_H
00003 
00004 #include <quda_internal.h>
00005 #include <color_spinor_field.h>
00006 #include <gauge_field.h>
00007 #include <clover_field.h>
00008 #include <dslash_quda.h>
00009 
00010 #include <face_quda.h>
00011 
00012 // Params for Dirac operator
00013 class DiracParam {
00014 
00015  public:
00016   QudaDiracType type;
00017   double kappa;
00018   double mass;
00019   double m5; // used by domain wall only
00020   MatPCType matpcType;
00021   DagType dagger;
00022   cudaGaugeField *gauge;
00023   cudaGaugeField *fatGauge;  // used by staggered only
00024   cudaGaugeField *longGauge; // used by staggered only
00025   cudaCloverField *clover;
00026   
00027   double mu; // used by twisted mass only
00028 
00029   cudaColorSpinorField *tmp1;
00030   cudaColorSpinorField *tmp2; // used only by Clover and TM
00031 
00032   QudaVerbosity verbose;
00033 
00034   int commDim[QUDA_MAX_DIM]; // whether to do comms or not
00035 
00036   DiracParam() 
00037     : type(QUDA_INVALID_DIRAC), kappa(0.0), m5(0.0), matpcType(QUDA_MATPC_INVALID),
00038     dagger(QUDA_DAG_INVALID), gauge(0), clover(0), mu(0.0), 
00039     tmp1(0), tmp2(0), verbose(QUDA_SILENT)
00040   {
00041 
00042   }
00043 
00044 };
00045 
00046 void setDiracParam(DiracParam &diracParam, QudaInvertParam *inv_param, bool pc);
00047 void setDiracSloppyParam(DiracParam &diracParam, QudaInvertParam *inv_param, bool pc);
00048 
00049 // forward declarations
00050 class DiracMatrix;
00051 class DiracM;
00052 class DiracMdagM;
00053 class DiracMdag;
00054 
00055 // Abstract base class
00056 class Dirac {
00057 
00058   friend class DiracMatrix;
00059   friend class DiracM;
00060   friend class DiracMdagM;
00061   friend class DiracMdag;
00062 
00063   protected:
00064   cudaGaugeField &gauge;
00065   double kappa;
00066   double mass;
00067   MatPCType matpcType;
00068   mutable DagType dagger; // mutable to simplify implementation of Mdag
00069   mutable unsigned long long flops;
00070   mutable cudaColorSpinorField *tmp1; // temporary hack
00071   mutable cudaColorSpinorField *tmp2; // temporary hack
00072 
00073   bool newTmp(cudaColorSpinorField **, const cudaColorSpinorField &) const;
00074   void deleteTmp(cudaColorSpinorField **, const bool &reset) const;
00075 
00076   QudaTune tune;
00077   QudaVerbosity verbose;  
00078 
00079   int commDim[QUDA_MAX_DIM]; // whether do comms or not
00080 
00081  public:
00082   Dirac(const DiracParam &param);
00083   Dirac(const Dirac &dirac);
00084   virtual ~Dirac();
00085   Dirac& operator=(const Dirac &dirac);
00086 
00087   virtual void checkParitySpinor(const cudaColorSpinorField &, const cudaColorSpinorField &) const;
00088   virtual void checkFullSpinor(const cudaColorSpinorField &, const cudaColorSpinorField &) const;
00089   void checkSpinorAlias(const cudaColorSpinorField &, const cudaColorSpinorField &) const;
00090 
00091   virtual void Dslash(cudaColorSpinorField &out, const cudaColorSpinorField &in, 
00092                       const QudaParity parity) const = 0;
00093   virtual void DslashXpay(cudaColorSpinorField &out, const cudaColorSpinorField &in, 
00094                           const QudaParity parity, const cudaColorSpinorField &x,
00095                           const double &k) const = 0;
00096   virtual void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const = 0;
00097   virtual void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const = 0;
00098   void Mdag(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
00099 
00100   // required methods to use e-o preconditioning for solving full system
00101   virtual void prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol,
00102                        cudaColorSpinorField &x, cudaColorSpinorField &b, 
00103                        const QudaSolutionType) const = 0;
00104   virtual void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b,
00105                            const QudaSolutionType) const = 0;
00106   void setMass(double mass){ this->mass = mass;}
00107   // Dirac operator factory
00108   static Dirac* create(const DiracParam &param);
00109 
00110   unsigned long long Flops() const { unsigned long long rtn = flops; flops = 0; return rtn; }
00111   QudaVerbosity Verbose() const { return verbose; }
00112 };
00113 
00114 // Full Wilson
00115 class DiracWilson : public Dirac {
00116 
00117  protected:
00118   FaceBuffer face; // multi-gpu communication buffers
00119 
00120  public:
00121   DiracWilson(const DiracParam &param);
00122   DiracWilson(const DiracWilson &dirac);
00123   virtual ~DiracWilson();
00124   DiracWilson& operator=(const DiracWilson &dirac);
00125 
00126   virtual void Dslash(cudaColorSpinorField &out, const cudaColorSpinorField &in, 
00127                       const QudaParity parity) const;
00128   virtual void DslashXpay(cudaColorSpinorField &out, const cudaColorSpinorField &in, 
00129                           const QudaParity parity, const cudaColorSpinorField &x, const double &k) const;
00130   virtual void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
00131   virtual void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
00132 
00133   virtual void prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol,
00134                        cudaColorSpinorField &x, cudaColorSpinorField &b, 
00135                        const QudaSolutionType) const;
00136   virtual void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b,
00137                            const QudaSolutionType) const;
00138 };
00139 
00140 // Even-odd preconditioned Wilson
00141 class DiracWilsonPC : public DiracWilson {
00142 
00143  private:
00144 
00145  public:
00146   DiracWilsonPC(const DiracParam &param);
00147   DiracWilsonPC(const DiracWilsonPC &dirac);
00148   virtual ~DiracWilsonPC();
00149   DiracWilsonPC& operator=(const DiracWilsonPC &dirac);
00150 
00151   void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
00152   void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
00153 
00154   void prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol,
00155                cudaColorSpinorField &x, cudaColorSpinorField &b, 
00156                const QudaSolutionType) const;
00157   void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b,
00158                    const QudaSolutionType) const;
00159 };
00160 
00161 // Full clover
00162 class DiracClover : public DiracWilson {
00163 
00164  protected:
00165   cudaCloverField &clover;
00166   void checkParitySpinor(const cudaColorSpinorField &, const cudaColorSpinorField &) const;
00167 
00168  public:
00169   DiracClover(const DiracParam &param);
00170   DiracClover(const DiracClover &dirac);
00171   virtual ~DiracClover();
00172   DiracClover& operator=(const DiracClover &dirac);
00173 
00174   void Clover(cudaColorSpinorField &out, const cudaColorSpinorField &in, const QudaParity parity) const;
00175   virtual void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
00176   virtual void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
00177 
00178   virtual void prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol,
00179                        cudaColorSpinorField &x, cudaColorSpinorField &b, 
00180                        const QudaSolutionType) const;
00181   virtual void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b,
00182                            const QudaSolutionType) const;
00183 };
00184 
00185 // Even-odd preconditioned clover
00186 class DiracCloverPC : public DiracClover {
00187 
00188  public:
00189   DiracCloverPC(const DiracParam &param);
00190   DiracCloverPC(const DiracCloverPC &dirac);
00191   virtual ~DiracCloverPC();
00192   DiracCloverPC& operator=(const DiracCloverPC &dirac);
00193 
00194   void CloverInv(cudaColorSpinorField &out, const cudaColorSpinorField &in, const QudaParity parity) const;
00195   void Dslash(cudaColorSpinorField &out, const cudaColorSpinorField &in, 
00196               const QudaParity parity) const;
00197   void DslashXpay(cudaColorSpinorField &out, const cudaColorSpinorField &in, 
00198                   const QudaParity parity, const cudaColorSpinorField &x, const double &k) const;
00199 
00200   void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
00201   void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
00202 
00203   void prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol,
00204                cudaColorSpinorField &x, cudaColorSpinorField &b, 
00205                const QudaSolutionType) const;
00206   void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b,
00207                    const QudaSolutionType) const;
00208 };
00209 
00210 
00211 
00212 // Full domain wall 
00213 class DiracDomainWall : public DiracWilson {
00214 
00215  protected:
00216   double m5;
00217   double kappa5;
00218 
00219  public:
00220   DiracDomainWall(const DiracParam &param);
00221   DiracDomainWall(const DiracDomainWall &dirac);
00222   virtual ~DiracDomainWall();
00223   DiracDomainWall& operator=(const DiracDomainWall &dirac);
00224 
00225   void Dslash(cudaColorSpinorField &out, const cudaColorSpinorField &in, 
00226               const QudaParity parity) const;
00227   void DslashXpay(cudaColorSpinorField &out, const cudaColorSpinorField &in, 
00228                   const QudaParity parity, const cudaColorSpinorField &x, const double &k) const;
00229 
00230   virtual void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
00231   virtual void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
00232 
00233   virtual void prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol,
00234                        cudaColorSpinorField &x, cudaColorSpinorField &b, 
00235                        const QudaSolutionType) const;
00236   virtual void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b,
00237                            const QudaSolutionType) const;
00238 };
00239 
00240 // 5d Even-odd preconditioned domain wall
00241 class DiracDomainWallPC : public DiracDomainWall {
00242 
00243  private:
00244 
00245  public:
00246   DiracDomainWallPC(const DiracParam &param);
00247   DiracDomainWallPC(const DiracDomainWallPC &dirac);
00248   virtual ~DiracDomainWallPC();
00249   DiracDomainWallPC& operator=(const DiracDomainWallPC &dirac);
00250 
00251   void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
00252   void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
00253 
00254   void prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol,
00255                cudaColorSpinorField &x, cudaColorSpinorField &b, 
00256                const QudaSolutionType) const;
00257   void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b,
00258                    const QudaSolutionType) const;
00259 };
00260 
00261 // Full twisted mass
00262 class DiracTwistedMass : public DiracWilson {
00263 
00264  protected:
00265   double mu;
00266   void twistedApply(cudaColorSpinorField &out, const cudaColorSpinorField &in, 
00267                     const QudaTwistGamma5Type twistType) const;
00268 
00269  public:
00270   DiracTwistedMass(const DiracParam &param);
00271   DiracTwistedMass(const DiracTwistedMass &dirac);
00272   virtual ~DiracTwistedMass();
00273   DiracTwistedMass& operator=(const DiracTwistedMass &dirac);
00274 
00275   void Twist(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
00276 
00277   virtual void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
00278   virtual void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
00279 
00280   virtual void prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol,
00281                        cudaColorSpinorField &x, cudaColorSpinorField &b, 
00282                        const QudaSolutionType) const;
00283   virtual void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b,
00284                            const QudaSolutionType) const;
00285 };
00286 
00287 // Even-odd preconditioned twisted mass
00288 class DiracTwistedMassPC : public DiracTwistedMass {
00289 
00290  public:
00291   DiracTwistedMassPC(const DiracParam &param);
00292   DiracTwistedMassPC(const DiracTwistedMassPC &dirac);
00293   virtual ~DiracTwistedMassPC();
00294   DiracTwistedMassPC& operator=(const DiracTwistedMassPC &dirac);
00295 
00296   void TwistInv(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
00297 
00298   virtual void Dslash(cudaColorSpinorField &out, const cudaColorSpinorField &in, 
00299                       const QudaParity parity) const;
00300   virtual void DslashXpay(cudaColorSpinorField &out, const cudaColorSpinorField &in, 
00301                           const QudaParity parity, const cudaColorSpinorField &x, const double &k) const;
00302   void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
00303   void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
00304 
00305   void prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol,
00306                cudaColorSpinorField &x, cudaColorSpinorField &b, 
00307                const QudaSolutionType) const;
00308   void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b,
00309                    const QudaSolutionType) const;
00310 };
00311 
00312 // Full staggered
00313 class DiracStaggered : public Dirac {
00314 
00315  protected:
00316   cudaGaugeField *fatGauge;
00317   cudaGaugeField *longGauge;
00318   FaceBuffer face; // multi-gpu communication buffers
00319 
00320  public:
00321   DiracStaggered(const DiracParam &param);
00322   DiracStaggered(const DiracStaggered &dirac);
00323   virtual ~DiracStaggered();
00324   DiracStaggered& operator=(const DiracStaggered &dirac);
00325 
00326   virtual void checkParitySpinor(const cudaColorSpinorField &, const cudaColorSpinorField &) const;
00327   
00328   virtual void Dslash(cudaColorSpinorField &out, const cudaColorSpinorField &in, 
00329                       const QudaParity parity) const;
00330   virtual void DslashXpay(cudaColorSpinorField &out, const cudaColorSpinorField &in, 
00331                           const QudaParity parity, const cudaColorSpinorField &x, const double &k) const;
00332   virtual void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
00333   virtual void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
00334 
00335   virtual void prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol,
00336                        cudaColorSpinorField &x, cudaColorSpinorField &b, 
00337                        const QudaSolutionType) const;
00338   virtual void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b,
00339                            const QudaSolutionType) const;
00340 };
00341 
00342 // Even-odd preconditioned staggered
00343 class DiracStaggeredPC : public DiracStaggered {
00344 
00345  protected:
00346 
00347  public:
00348   DiracStaggeredPC(const DiracParam &param);
00349   DiracStaggeredPC(const DiracStaggeredPC &dirac);
00350   virtual ~DiracStaggeredPC();
00351   DiracStaggeredPC& operator=(const DiracStaggeredPC &dirac);
00352 
00353   virtual void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
00354   virtual void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
00355 
00356   virtual void prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol,
00357                        cudaColorSpinorField &x, cudaColorSpinorField &b, 
00358                        const QudaSolutionType) const;
00359   virtual void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b,
00360                            const QudaSolutionType) const;
00361 };
00362 
00363 // Functor base class for applying a given Dirac matrix (M, MdagM, etc.)
00364 class DiracMatrix {
00365 
00366  protected:
00367   const Dirac *dirac;
00368 
00369  public:
00370   DiracMatrix(const Dirac &d) : dirac(&d) { }
00371   DiracMatrix(const Dirac *d) : dirac(d) { }
00372   virtual ~DiracMatrix() = 0;
00373 
00374   virtual void operator()(cudaColorSpinorField &out, const cudaColorSpinorField &in) const = 0;
00375   virtual void operator()(cudaColorSpinorField &out, const cudaColorSpinorField &in,
00376                           cudaColorSpinorField &tmp) const = 0;
00377   virtual void operator()(cudaColorSpinorField &out, const cudaColorSpinorField &in,
00378                           cudaColorSpinorField &Tmp1, cudaColorSpinorField &Tmp2) const = 0;
00379 
00380   unsigned long long flops() const { return dirac->Flops(); }
00381 };
00382 
00383 inline DiracMatrix::~DiracMatrix()
00384 {
00385 
00386 }
00387 
00388 class DiracM : public DiracMatrix {
00389 
00390  public:
00391  DiracM(const Dirac &d) : DiracMatrix(d) { }
00392  DiracM(const Dirac *d) : DiracMatrix(d) { }
00393 
00394   void operator()(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
00395   {
00396     dirac->M(out, in);
00397   }
00398 
00399   void operator()(cudaColorSpinorField &out, const cudaColorSpinorField &in, cudaColorSpinorField &tmp) const
00400   {
00401     dirac->tmp1 = &tmp;
00402     dirac->M(out, in);
00403     dirac->tmp1 = NULL;
00404   }
00405 
00406   void operator()(cudaColorSpinorField &out, const cudaColorSpinorField &in, 
00407                   cudaColorSpinorField &Tmp1, cudaColorSpinorField &Tmp2) const
00408   {
00409     dirac->tmp1 = &Tmp1;
00410     dirac->tmp2 = &Tmp2;
00411     dirac->M(out, in);
00412     dirac->tmp2 = NULL;
00413     dirac->tmp1 = NULL;
00414   }
00415 };
00416 
00417 class DiracMdagM : public DiracMatrix {
00418 
00419  public:
00420   DiracMdagM(const Dirac &d) : DiracMatrix(d) { }
00421   DiracMdagM(const Dirac *d) : DiracMatrix(d) { }
00422 
00423   void operator()(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
00424   {
00425     dirac->MdagM(out, in);
00426   }
00427 
00428   void operator()(cudaColorSpinorField &out, const cudaColorSpinorField &in, cudaColorSpinorField &tmp) const
00429   {
00430     dirac->tmp1 = &tmp;
00431     dirac->MdagM(out, in);
00432     dirac->tmp1 = NULL;
00433   }
00434 
00435   void operator()(cudaColorSpinorField &out, const cudaColorSpinorField &in, 
00436                   cudaColorSpinorField &Tmp1, cudaColorSpinorField &Tmp2) const
00437   {
00438     dirac->tmp1 = &Tmp1;
00439     dirac->tmp2 = &Tmp2;
00440     dirac->MdagM(out, in);
00441     dirac->tmp2 = NULL;
00442     dirac->tmp1 = NULL;
00443   }
00444 };
00445 
00446 class DiracMdag : public DiracMatrix {
00447 
00448  public:
00449   DiracMdag(const Dirac &d) : DiracMatrix(d) { }
00450   DiracMdag(const Dirac *d) : DiracMatrix(d) { }
00451 
00452   void operator()(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
00453   {
00454     dirac->Mdag(out, in);
00455   }
00456 
00457   void operator()(cudaColorSpinorField &out, const cudaColorSpinorField &in, cudaColorSpinorField &tmp) const
00458   {
00459     dirac->tmp1 = &tmp;
00460     dirac->Mdag(out, in);
00461     dirac->tmp1 = NULL;
00462   }
00463 
00464   void operator()(cudaColorSpinorField &out, const cudaColorSpinorField &in, 
00465                   cudaColorSpinorField &Tmp1, cudaColorSpinorField &Tmp2) const
00466   {
00467     dirac->tmp1 = &Tmp1;
00468     dirac->tmp2 = &Tmp2;
00469     dirac->Mdag(out, in);
00470     dirac->tmp2 = NULL;
00471     dirac->tmp1 = NULL;
00472   }
00473 };
00474 
00475 #endif // _DIRAC_QUDA_H
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines