QUDA  0.9.0
dirac_staggered.cpp
Go to the documentation of this file.
1 #include <dirac_quda.h>
2 #include <blas_quda.h>
3 
4 namespace quda {
5 
7 
9 
11 
13  {
14  if (&dirac != this) {
16  }
17  return *this;
18  }
19 
21  {
22  if (in.Ndim() != 5 || out.Ndim() != 5) {
23  errorQuda("Staggered dslash requires 5-d fermion fields");
24  }
25 
26  if (in.Precision() != out.Precision()) {
27  errorQuda("Input and output spinor precisions don't match in dslash_quda");
28  }
29 
30  if (in.Stride() != out.Stride()) {
31  errorQuda("Input %d and output %d spinor strides don't match in dslash_quda", in.Stride(), out.Stride());
32  }
33 
34  if (in.SiteSubset() != QUDA_PARITY_SITE_SUBSET || out.SiteSubset() != QUDA_PARITY_SITE_SUBSET) {
35  errorQuda("ColorSpinorFields are not single parity, in = %d, out = %d",
36  in.SiteSubset(), out.SiteSubset());
37  }
38 
39  if ((out.Volume()/out.X(4) != 2*gauge->VolumeCB() && out.SiteSubset() == QUDA_FULL_SITE_SUBSET) ||
40  (out.Volume()/out.X(4) != gauge->VolumeCB() && out.SiteSubset() == QUDA_PARITY_SITE_SUBSET) ) {
41  errorQuda("Spinor volume %d doesn't match gauge volume %d", out.Volume(), gauge->VolumeCB());
42  }
43  }
44 
45 
47  const QudaParity parity) const
48  {
50 
52  staggeredDslashCuda(&static_cast<cudaColorSpinorField&>(out),
53  *gauge, &static_cast<const cudaColorSpinorField&>(in), parity,
54  dagger, 0, 0, commDim, profile);
55  } else {
56  errorQuda("Not supported");
57  }
58 
59  flops += 570ll*in.Volume();
60  }
61 
63  const QudaParity parity, const ColorSpinorField &x,
64  const double &k) const
65  {
67 
69  staggeredDslashCuda(&static_cast<cudaColorSpinorField&>(out), *gauge,
70  &static_cast<const cudaColorSpinorField&>(in), parity, dagger,
71  &static_cast<const cudaColorSpinorField&>(x), k, commDim, profile);
72  } else {
73  errorQuda("Not supported");
74  }
75 
76  flops += 582ll*in.Volume();
77  }
78 
79  // Full staggered operator
81  {
82  DslashXpay(out.Even(), in.Odd(), QUDA_EVEN_PARITY, in.Even(), 2*mass);
83  DslashXpay(out.Odd(), in.Even(), QUDA_ODD_PARITY, in.Odd(), 2*mass);
84  }
85 
87  {
88  bool reset = newTmp(&tmp1, in);
89 
90  //even
91  Dslash(tmp1->Even(), in.Even(), QUDA_ODD_PARITY);
92  DslashXpay(out.Even(), tmp1->Even(), QUDA_EVEN_PARITY, in.Even(), 4*mass*mass);
93 
94  //odd
95  Dslash(tmp1->Even(), in.Odd(), QUDA_EVEN_PARITY);
96  DslashXpay(out.Odd(), tmp1->Even(), QUDA_ODD_PARITY, in.Odd(), 4*mass*mass);
97 
98  deleteTmp(&tmp1, reset);
99  }
100 
103  const QudaSolutionType solType) const
104  {
105  if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
106  errorQuda("Preconditioned solution requires a preconditioned solve_type");
107  }
108 
109  src = &b;
110  sol = &x;
111  }
112 
114  const QudaSolutionType solType) const
115  {
116  // do nothing
117  }
118 
119 
122  {
123 
124  }
125 
128  {
129 
130  }
131 
133  {
134 
135  }
136 
138  {
139  if (&dirac != this) {
141  }
142 
143  return *this;
144  }
145 
147  {
148  errorQuda("DiracStaggeredPC::M() is not implemented\n");
149  }
150 
152  {
153  bool reset = newTmp(&tmp1, in);
154 
156  QudaParity other_parity = QUDA_INVALID_PARITY;
159  other_parity = QUDA_ODD_PARITY;
160  } else if (matpcType == QUDA_MATPC_ODD_ODD) {
162  other_parity = QUDA_EVEN_PARITY;
163  } else {
164  errorQuda("Invalid matpcType(%d) in function\n", matpcType);
165  }
166  Dslash(*tmp1, in, other_parity);
167  DslashXpay(out, *tmp1, parity, in, 4*mass*mass);
168 
169  deleteTmp(&tmp1, reset);
170  }
171 
174  const QudaSolutionType solType) const
175  {
176  src = &b;
177  sol = &x;
178  }
179 
181  const QudaSolutionType solType) const
182  {
183  // do nothing
184  }
185 
186 } // namespace quda
unsigned long long flops
Definition: dirac_quda.h:100
virtual void prepare(ColorSpinorField *&src, ColorSpinorField *&sol, ColorSpinorField &x, ColorSpinorField &b, const QudaSolutionType) const
const void * src
#define errorQuda(...)
Definition: util_quda.h:90
cudaGaugeField * gauge
Definition: dirac_quda.h:95
DiracStaggeredPC(const DiracParam &param)
const ColorSpinorField & Even() const
void deleteTmp(ColorSpinorField **, const bool &reset) const
Definition: dirac.cpp:64
TimeProfile profile
Definition: dirac_quda.h:112
virtual void MdagM(ColorSpinorField &out, const ColorSpinorField &in) const
virtual void reconstruct(ColorSpinorField &x, const ColorSpinorField &b, const QudaSolutionType) const
QudaGaugeParam param
Definition: pack_test.cpp:17
bool newTmp(ColorSpinorField **, const ColorSpinorField &) const
Definition: dirac.cpp:53
#define b
DiracStaggered & operator=(const DiracStaggered &dirac)
int commDim[QUDA_MAX_DIM]
Definition: dirac_quda.h:110
DiracStaggeredPC & operator=(const DiracStaggeredPC &dirac)
void staggeredDslashCuda(cudaColorSpinorField *out, const cudaGaugeField &gauge, const cudaColorSpinorField *in, const int parity, const int dagger, const cudaColorSpinorField *x, const double &k, const int *commDim, TimeProfile &profile)
virtual void prepare(ColorSpinorField *&src, ColorSpinorField *&sol, ColorSpinorField &x, ColorSpinorField &b, const QudaSolutionType) const
cpuColorSpinorField * in
double mass
Definition: dirac_quda.h:97
virtual void reconstruct(ColorSpinorField &x, const ColorSpinorField &b, const QudaSolutionType) const
#define checkLocation(...)
enum QudaSolutionType_s QudaSolutionType
virtual void checkParitySpinor(const ColorSpinorField &, const ColorSpinorField &) const
QudaDagType dagger
Definition: dirac_quda.h:99
enum QudaParity_s QudaParity
QudaMatPCType matpcType
Definition: dirac_quda.h:98
virtual void M(ColorSpinorField &out, const ColorSpinorField &in) const
Dirac & operator=(const Dirac &dirac)
Definition: dirac.cpp:32
GaugeCovDev * dirac
Definition: covdev_test.cpp:75
cpuColorSpinorField * out
virtual void Dslash(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity) const
int VolumeCB() const
virtual void MdagM(ColorSpinorField &out, const ColorSpinorField &in) const
DiracStaggered(const DiracParam &param)
QudaParity parity
Definition: covdev_test.cpp:53
virtual void DslashXpay(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity, const ColorSpinorField &x, const double &k) const
virtual void M(ColorSpinorField &out, const ColorSpinorField &in) const
ColorSpinorField * tmp1
Definition: dirac_quda.h:101