QUDA  v0.5.0
A library for QCD on GPUs
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
dirac_quda.h
Go to the documentation of this file.
1 #ifndef _DIRAC_QUDA_H
2 #define _DIRAC_QUDA_H
3 
4 #include <quda_internal.h>
5 #include <color_spinor_field.h>
6 #include <gauge_field.h>
7 #include <clover_field.h>
8 #include <dslash_quda.h>
9 
10 #include <face_quda.h>
11 #include <blas_quda.h>
12 
13 #include <typeinfo>
14 
15 namespace quda {
16 
17  // Params for Dirac operator
18  class DiracParam { // Bundle of settings taken by const reference by the Dirac constructors and by Dirac::create.
19 
20  public:
22  double kappa;
23  double mass;
24  double m5; // used by domain wall only
25  int Ls;
29  cudaGaugeField *fatGauge; // used by staggered only
30  cudaGaugeField *longGauge; // used by staggered only
32 
33  double mu; // used by twisted mass only
34  double epsilon; // second twisted-mass parameter (used by twisted mass only)
35 
37  cudaColorSpinorField *tmp2; // used by Wilson-like kernels only
38 
40 
41  int commDim[QUDA_MAX_DIM]; // whether to do comms or not
42 
45  dagger(QUDA_DAG_INVALID), gauge(0), clover(0), mu(0.0), epsilon(0.0), // NOTE(review): tail of a constructor initializer list; its opening lines (listing 43-44) are absent from this excerpt.
46  tmp1(0), tmp2(0), verbose(QUDA_SILENT) // dagger, gauge, clover, tmp1 and verbose are initialized here but their declarations are not on any visible line.
47  {
48 
49  }
50 
51  void print() { // Writes type, kappa, mass, m5, Ls, matpcType, dagger, mu, epsilon and every commDim[] entry through printfQuda.
52  printfQuda("Printing DslashParam\n"); // NOTE(review): the banner says DslashParam while the enclosing class is DiracParam - confirm this is intended.
53  printfQuda("type = %d\n", type); // type is declared on a line missing from this excerpt
54  printfQuda("kappa = %g\n", kappa);
55  printfQuda("mass = %g\n", mass);
56  printfQuda("m5 = %g\n", m5);
57  printfQuda("Ls = %d\n", Ls);
58  printfQuda("matpcType = %d\n", matpcType); // matpcType is declared on a line missing from this excerpt
59  printfQuda("dagger = %d\n", dagger);
60  printfQuda("mu = %g\n", mu);
61  printfQuda("epsilon = %g\n", epsilon);
62  for (int i=0; i<QUDA_MAX_DIM; i++) printfQuda("commDim[%d] = %d\n", i, commDim[i]); // one output line per commDim entry
63  }
64  };
65 
66  void setDiracParam(DiracParam &diracParam, QudaInvertParam *inv_param, bool pc); // Declaration only. Presumably fills diracParam from inv_param, with pc selecting the preconditioned form - confirm against the definition.
67  void setDiracSloppyParam(DiracParam &diracParam, QudaInvertParam *inv_param, bool pc); // Same signature as setDiracParam. Presumably the sloppy-precision counterpart - confirm against the definition.
68 
69  // forward declarations (these four classes are named as friends inside Dirac and are defined near the end of this header)
70  class DiracMatrix;
71  class DiracM;
72  class DiracMdagM;
73  class DiracMdag;
74 
75  // Abstract base class
76  class Dirac { // Abstract interface: Dslash, M, MdagM, prepare and reconstruct are declared pure virtual (= 0) below.
77 
78  friend class DiracMatrix; // the functor classes are friends; their operator() bodies in this header assign the protected tmp1/tmp2 pointers
79  friend class DiracM;
80  friend class DiracMdagM;
81  friend class DiracMdag;
82 
83  protected:
85  double kappa;
86  double mass;
88  mutable DagType dagger; // mutable to simplify implementation of Mdag
89  mutable unsigned long long flops; // mutable so that the const accessor Flops() can reset it
90  mutable cudaColorSpinorField *tmp1; // temporary hack
91  mutable cudaColorSpinorField *tmp2; // temporary hack
92 
93  bool newTmp(cudaColorSpinorField **, const cudaColorSpinorField &) const;
94  void deleteTmp(cudaColorSpinorField **, const bool &reset) const;
95 
98  // NOTE(review): listing lines 96-97 are absent from this excerpt; Verbose() below returns a member named verbose that is not declared on any visible line.
99  int commDim[QUDA_MAX_DIM]; // whether to do comms or not
100 
101  public:
102  Dirac(const DiracParam &param);
103  Dirac(const Dirac &dirac);
104  virtual ~Dirac();
105  Dirac& operator=(const Dirac &dirac);
106 
107  virtual void checkParitySpinor(const cudaColorSpinorField &, const cudaColorSpinorField &) const;
108  virtual void checkFullSpinor(const cudaColorSpinorField &, const cudaColorSpinorField &) const;
109  void checkSpinorAlias(const cudaColorSpinorField &, const cudaColorSpinorField &) const;
110 
111  virtual void Dslash(cudaColorSpinorField &out, const cudaColorSpinorField &in,
112  const QudaParity parity) const = 0;
114  const QudaParity parity, const cudaColorSpinorField &x, // NOTE(review): continuation of a declaration ending in = 0 whose first line (listing 113) is absent from this excerpt.
115  const double &k) const = 0;
116  virtual void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const = 0;
117  virtual void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const = 0;
118  void Mdag(cudaColorSpinorField &out, const cudaColorSpinorField &in) const; // non-virtual, unlike M and MdagM
119 
120  // required methods to use even-odd (e-o) preconditioning for solving the full system
121  virtual void prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol,
123  const QudaSolutionType) const = 0; // NOTE(review): listing line 122 (the middle of this parameter list) is absent from this excerpt.
124  virtual void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b,
125  const QudaSolutionType) const = 0;
126  void setMass(double mass){ this->mass = mass;} // overwrites the stored mass; this-> is needed because the parameter shadows the member
127  // Dirac operator factory
128  static Dirac* create(const DiracParam &param);
129 
130  unsigned long long Flops() const { unsigned long long rtn = flops; flops = 0; return rtn; } // returns the current flops counter and resets it to zero
131  QudaVerbosity Verbose() const { return verbose; }
132  };
133 
134  // Full Wilson
135  class DiracWilson : public Dirac { // Subclass of Dirac declaring non-pure Dslash, M, MdagM, prepare and reconstruct.
136 
137  protected:
138  FaceBuffer face; // multi-gpu communication buffers
139 
140  public:
141  DiracWilson(const DiracParam &param);
142  DiracWilson(const DiracWilson &dirac);
143  DiracWilson(const DiracParam &param, const int nDims); // to correctly adjust face for DW and non-deg twisted mass
144 
145  virtual ~DiracWilson();
147  // NOTE(review): listing line 146 is absent from this excerpt.
148  virtual void Dslash(cudaColorSpinorField &out, const cudaColorSpinorField &in,
149  const QudaParity parity) const;
151  const QudaParity parity, const cudaColorSpinorField &x, const double &k) const; // NOTE(review): continuation of a declaration whose first line (listing 150) is absent from this excerpt.
152  virtual void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
153  virtual void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
154 
155  virtual void prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol,
157  const QudaSolutionType) const; // NOTE(review): listing line 156 (the middle of this parameter list) is absent from this excerpt.
158  virtual void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b,
159  const QudaSolutionType) const;
160  };
161 
162  // Even-odd preconditioned Wilson
163  class DiracWilsonPC : public DiracWilson { // Subclass of DiracWilson that redeclares M and MdagM with the same signatures as the base class.
164 
165  private:
166 
167  public:
170  virtual ~DiracWilsonPC(); // NOTE(review): listing lines 168-169 are absent from this excerpt, so no constructor is visible.
171  DiracWilsonPC& operator=(const DiracWilsonPC &dirac);
172 
173  void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
174  void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
175 
178  const QudaSolutionType) const; // NOTE(review): tail of a declaration whose opening lines (listing 176-177) are absent from this excerpt.
180  const QudaSolutionType) const; // NOTE(review): tail of a declaration whose opening line (listing 179) is absent from this excerpt.
181  };
182 
183  // Full clover
184  class DiracClover : public DiracWilson { // Subclass of DiracWilson declaring its own checkParitySpinor, M, MdagM, prepare and reconstruct.
185 
186  protected:
188  void checkParitySpinor(const cudaColorSpinorField &, const cudaColorSpinorField &) const; // NOTE(review): listing line 187, just above, is absent from this excerpt.
189 
190  public:
191  DiracClover(const DiracParam &param);
192  DiracClover(const DiracClover &dirac);
193  virtual ~DiracClover();
195  // NOTE(review): listing lines 194 and 196-197 are absent from this excerpt.
198  const cudaColorSpinorField &x, const double &k) const; // tail of a declaration whose opening lines are among the missing ones
199  virtual void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
200  virtual void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
201 
202  virtual void prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol,
204  const QudaSolutionType) const; // NOTE(review): listing line 203 (the middle of this parameter list) is absent from this excerpt.
205  virtual void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b,
206  const QudaSolutionType) const;
207  };
208 
209  // Even-odd preconditioned clover
210  class DiracCloverPC : public DiracClover { // Subclass of DiracClover that redeclares Dslash, M and MdagM.
211 
212  public:
215  virtual ~DiracCloverPC(); // NOTE(review): listing lines 213-214 are absent from this excerpt, so no constructor is visible.
216  DiracCloverPC& operator=(const DiracCloverPC &dirac);
217 
219  void Dslash(cudaColorSpinorField &out, const cudaColorSpinorField &in,
220  const QudaParity parity) const;
222  const QudaParity parity, const cudaColorSpinorField &x, const double &k) const; // NOTE(review): continuation of a declaration whose first line (listing 221) is absent from this excerpt.
223 
224  void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
225  void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
226 
229  const QudaSolutionType) const; // NOTE(review): tail of a declaration whose opening lines (listing 227-228) are absent from this excerpt.
231  const QudaSolutionType) const; // NOTE(review): tail of a declaration whose opening line (listing 230) is absent from this excerpt.
232  };
233 
234 
235 
236  // Full domain wall
237  class DiracDomainWall : public DiracWilson { // Subclass of DiracWilson adding the members m5 and kappa5.
238 
239  protected:
240  double m5;
241  double kappa5;
242 
243  public:
246  virtual ~DiracDomainWall(); // NOTE(review): listing lines 244-245 are absent from this excerpt, so no constructor is visible.
248 
250  const QudaParity parity) const; // NOTE(review): tail of a declaration whose opening line (listing 249) is absent from this excerpt.
252  const QudaParity parity, const cudaColorSpinorField &x, const double &k) const; // NOTE(review): tail of a declaration whose opening line (listing 251) is absent from this excerpt.
253 
254  virtual void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
255  virtual void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
256 
257  virtual void prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol,
259  const QudaSolutionType) const; // NOTE(review): listing line 258 (the middle of this parameter list) is absent from this excerpt.
260  virtual void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b,
261  const QudaSolutionType) const;
262  };
263 
264  // 5d Even-odd preconditioned domain wall. NOTE(review): the class header (listing line 265) is absent from this excerpt; the destructor below names the class DiracDomainWallPC.
266 
267  private:
268 
269  public:
272  virtual ~DiracDomainWallPC(); // NOTE(review): listing lines 270-271 are absent from this excerpt, so no constructor is visible.
274 
275  void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
276  void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
277 
280  const QudaSolutionType) const; // NOTE(review): tail of a declaration whose opening lines (listing 278-279) are absent from this excerpt.
282  const QudaSolutionType) const; // NOTE(review): tail of a declaration whose opening line (listing 281) is absent from this excerpt.
283  };
284 
285  // Full twisted mass
286  class DiracTwistedMass : public DiracWilson { // Subclass of DiracWilson adding the members mu and epsilon and a static flag initTMFlag.
287 
288  protected:
289  double mu;
290  double epsilon;
292  const QudaTwistGamma5Type twistType) const; // NOTE(review): tail of a declaration whose opening line (listing 291) is absent from this excerpt.
293 
294  static int initTMFlag; // static data member, shared by all instances; its purpose is not visible in this header
295 
296  public:
298  DiracTwistedMass(const DiracParam &param, const int nDim); // NOTE(review): listing line 297, just above, is absent from this excerpt.
299  virtual ~DiracTwistedMass();
301 
303  // NOTE(review): listing lines 300 and 302 are absent from this excerpt.
304  virtual void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
305  virtual void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
306 
307  virtual void prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol,
309  const QudaSolutionType) const; // NOTE(review): listing line 308 (the middle of this parameter list) is absent from this excerpt.
310  virtual void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b,
311  const QudaSolutionType) const;
312  };
313 
314  // Even-odd preconditioned twisted mass. NOTE(review): the class header (listing line 315) is absent from this excerpt; the constructor and destructor below name the class DiracTwistedMassPC.
316 
317  public:
319  DiracTwistedMassPC(const DiracParam &param, const int nDim); // NOTE(review): listing line 318, just above, is absent from this excerpt.
320 
321  virtual ~DiracTwistedMassPC();
323 
325  // NOTE(review): listing lines 322 and 324 are absent from this excerpt.
326  virtual void Dslash(cudaColorSpinorField &out, const cudaColorSpinorField &in,
327  const QudaParity parity) const;
328  virtual void DslashXpay(cudaColorSpinorField &out, const cudaColorSpinorField &in,
329  const QudaParity parity, const cudaColorSpinorField &x, const double &k) const;
330  void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
331  void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
332 
335  const QudaSolutionType) const; // NOTE(review): tail of a declaration whose opening lines (listing 333-334) are absent from this excerpt.
337  const QudaSolutionType) const; // NOTE(review): tail of a declaration whose opening line (listing 336) is absent from this excerpt.
338  };
339 
340  // Full staggered
341  class DiracStaggered : public Dirac { // Direct subclass of Dirac (not of DiracWilson) that holds its own FaceBuffer.
342 
343  protected:
346  FaceBuffer face; // multi-gpu communication buffers (NOTE(review): listing lines 344-345, just above, are absent from this excerpt)
347 
348  public:
351  virtual ~DiracStaggered(); // NOTE(review): listing lines 349-350 are absent from this excerpt, so no constructor is visible.
353 
354  virtual void checkParitySpinor(const cudaColorSpinorField &, const cudaColorSpinorField &) const;
355 
356  virtual void Dslash(cudaColorSpinorField &out, const cudaColorSpinorField &in,
357  const QudaParity parity) const;
359  const QudaParity parity, const cudaColorSpinorField &x, const double &k) const; // NOTE(review): continuation of a declaration whose first line (listing 358) is absent from this excerpt.
360  virtual void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
361  virtual void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
362 
363  virtual void prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol,
365  const QudaSolutionType) const; // NOTE(review): listing line 364 (the middle of this parameter list) is absent from this excerpt.
366  virtual void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b,
367  const QudaSolutionType) const;
368  };
369 
370  // Even-odd preconditioned staggered. NOTE(review): the class header (listing line 371) is absent from this excerpt; the destructor below names the class DiracStaggeredPC.
372 
373  protected:
374 
375  public:
378  virtual ~DiracStaggeredPC(); // NOTE(review): listing lines 376-377 are absent from this excerpt, so no constructor is visible.
380 
381  virtual void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
382  virtual void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const;
383 
384  virtual void prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol,
386  const QudaSolutionType) const; // NOTE(review): listing line 385 (the middle of this parameter list) is absent from this excerpt.
387  virtual void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b,
388  const QudaSolutionType) const;
389  };
390 
391  // Functor base class for applying a given Dirac matrix (M, MdagM, etc.)
392  class DiracMatrix { // Abstract functor around a pointer to a Dirac; the subclasses in this header pick which operator (M, MdagM or Mdag) operator() applies.
393 
394  protected:
395  const Dirac *dirac; // stored exactly as given to the constructor; nothing visible in this header deletes it
396 
397  public:
398  DiracMatrix(const Dirac &d) : dirac(&d) { } // keeps the address of d, so d must outlive this functor
399  DiracMatrix(const Dirac *d) : dirac(d) { }
400  virtual ~DiracMatrix() = 0; // pure virtual destructor, which makes the class abstract
401 
402  virtual void operator()(cudaColorSpinorField &out, const cudaColorSpinorField &in) const = 0;
404  cudaColorSpinorField &tmp) const = 0; // NOTE(review): tail of a pure virtual declaration whose first line (listing 403) is absent from this excerpt.
406  cudaColorSpinorField &Tmp1, cudaColorSpinorField &Tmp2) const = 0; // NOTE(review): tail of a pure virtual declaration whose first line (listing 405) is absent from this excerpt.
407 
408  unsigned long long flops() const { return dirac->Flops(); } // forwards to Dirac::Flops(), which per its definition in this header also resets the counter
409 
410  std::string Type() const { return typeid(*dirac).name(); } // implementation-defined type name of the wrapped object; NOTE(review): std::string is used but only <typeinfo> is included directly here - presumably <string> arrives via another header, confirm
411  };
412 
414  { // NOTE(review): empty body whose declarator line (listing 413) is absent from this excerpt; presumably the out-of-class definition of the pure virtual destructor - confirm
415 
416  }
417 
418  class DiracM : public DiracMatrix { // Functor whose visible bodies all call dirac->M(out, in).
419 
420  public:
421  DiracM(const Dirac &d) : DiracMatrix(d) { }
422  DiracM(const Dirac *d) : DiracMatrix(d) { }
423 
425  { // NOTE(review): the signature line for this body (listing 424) is absent from this excerpt
426  dirac->M(out, in);
427  }
428 
430  { // NOTE(review): the signature line for this body (listing 429) is absent from this excerpt; it uses out, in and tmp
431  dirac->tmp1 = &tmp; // point the operator at the supplied temporary for the duration of the call (tmp1 is mutable, DiracM is a friend)
432  dirac->M(out, in);
433  dirac->tmp1 = NULL; // clear the pointer again so the Dirac object does not keep referring to tmp
434  }
435 
437  cudaColorSpinorField &Tmp1, cudaColorSpinorField &Tmp2) const // NOTE(review): the first line of this signature (listing 436) is absent from this excerpt
438  {
439  dirac->tmp1 = &Tmp1; // same pattern as above, with two supplied temporaries
440  dirac->tmp2 = &Tmp2;
441  dirac->M(out, in);
442  dirac->tmp2 = NULL;
443  dirac->tmp1 = NULL;
444  }
445  };
446 
447  class DiracMdagM : public DiracMatrix { // Functor whose visible bodies call dirac->MdagM(out, in) and then, when shift is non-zero, call axpyCuda(shift, in, out).
448 
449  public:
450  DiracMdagM(const Dirac &d) : DiracMatrix(d), shift(0.0) { } // shift starts at zero, so no axpyCuda call is made unless it is changed
451  DiracMdagM(const Dirac *d) : DiracMatrix(d), shift(0.0) { }
452 
454  double shift; // public and writable by callers; NOTE(review): listing line 453, just above, is absent from this excerpt
455 
457  { // NOTE(review): the signature line for this body (listing 456) is absent from this excerpt
458  dirac->MdagM(out, in);
459  if (shift != 0.0) axpyCuda(shift, const_cast<cudaColorSpinorField&>(in), out); // presumably out += shift * in - confirm axpyCuda semantics; the const_cast drops constness from in for that call
460  }
461 
463  { // NOTE(review): the signature line for this body (listing 462) is absent from this excerpt; it uses out, in and tmp
464  dirac->tmp1 = &tmp; // point the operator at the supplied temporary for the duration of the call
465  dirac->MdagM(out, in);
466  if (shift != 0.0) axpyCuda(shift, const_cast<cudaColorSpinorField&>(in), out);
467  dirac->tmp1 = NULL; // clear the pointer again so the Dirac object does not keep referring to tmp
468  }
469 
471  cudaColorSpinorField &Tmp1, cudaColorSpinorField &Tmp2) const // NOTE(review): the first line of this signature (listing 470) is absent from this excerpt
472  {
473  dirac->tmp1 = &Tmp1; // same pattern as above, with two supplied temporaries
474  dirac->tmp2 = &Tmp2;
475  dirac->MdagM(out, in);
476  if (shift != 0.0) axpyCuda(shift, const_cast<cudaColorSpinorField&>(in), out);
477  dirac->tmp2 = NULL;
478  dirac->tmp1 = NULL;
479  }
480  };
481 
482  class DiracMdag : public DiracMatrix { // Functor whose visible bodies all call dirac->Mdag(out, in).
483 
484  public:
485  DiracMdag(const Dirac &d) : DiracMatrix(d) { }
486  DiracMdag(const Dirac *d) : DiracMatrix(d) { }
487 
489  { // NOTE(review): the signature line for this body (listing 488) is absent from this excerpt
490  dirac->Mdag(out, in);
491  }
492 
494  { // NOTE(review): the signature line for this body (listing 493) is absent from this excerpt; it uses out, in and tmp
495  dirac->tmp1 = &tmp; // point the operator at the supplied temporary for the duration of the call
496  dirac->Mdag(out, in);
497  dirac->tmp1 = NULL; // clear the pointer again so the Dirac object does not keep referring to tmp
498  }
499 
501  cudaColorSpinorField &Tmp1, cudaColorSpinorField &Tmp2) const // NOTE(review): the first line of this signature (listing 500) is absent from this excerpt
502  {
503  dirac->tmp1 = &Tmp1; // same pattern as above, with two supplied temporaries
504  dirac->tmp2 = &Tmp2;
505  dirac->Mdag(out, in);
506  dirac->tmp2 = NULL;
507  dirac->tmp1 = NULL;
508  }
509  };
510 
511 } // namespace quda
512 
513 #endif // _DIRAC_QUDA_H