QUDA  1.0.0
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
dirac_staggered.cpp
Go to the documentation of this file.
1 #include <dirac_quda.h>
2 #include <blas_quda.h>
3 
4 namespace quda {
5 
7 
9 
11 
13  {
14  if (&dirac != this) {
15  Dirac::operator=(dirac);
16  }
17  return *this;
18  }
19 
21  {
22  if (in.Ndim() != 5 || out.Ndim() != 5) {
23  errorQuda("Staggered dslash requires 5-d fermion fields");
24  }
25 
26  if (in.Precision() != out.Precision()) {
27  errorQuda("Input and output spinor precisions don't match in dslash_quda");
28  }
29 
30  if (in.Stride() != out.Stride()) {
31  errorQuda("Input %d and output %d spinor strides don't match in dslash_quda", in.Stride(), out.Stride());
32  }
33 
35  errorQuda("ColorSpinorFields are not single parity, in = %d, out = %d",
36  in.SiteSubset(), out.SiteSubset());
37  }
38 
39  if ((out.Volume()/out.X(4) != 2*gauge->VolumeCB() && out.SiteSubset() == QUDA_FULL_SITE_SUBSET) ||
40  (out.Volume()/out.X(4) != gauge->VolumeCB() && out.SiteSubset() == QUDA_PARITY_SITE_SUBSET) ) {
41  errorQuda("Spinor volume %d doesn't match gauge volume %d", out.Volume(), gauge->VolumeCB());
42  }
43  }
44 
45 
47  const QudaParity parity) const
48  {
49  checkParitySpinor(in, out);
50 
51  ApplyStaggered(out, in, *gauge, 0., in, parity, dagger, commDim, profile);
52  flops += 570ll*in.Volume();
53  }
54 
56  const QudaParity parity, const ColorSpinorField &x,
57  const double &k) const
58  {
59  checkParitySpinor(in, out);
60 
61  // Need to catch the zero mass case.
62  if (k == 0.0) {
63  // There's a sign convention difference for Dslash vs DslashXpay, which is
64  // triggered by looking for k == 0. We need to hack around this.
65  if (dagger == QUDA_DAG_YES) {
66  ApplyStaggered(out, in, *gauge, 0., x, parity, QUDA_DAG_NO, commDim, profile);
67  } else {
68  ApplyStaggered(out, in, *gauge, 0., x, parity, QUDA_DAG_YES, commDim, profile);
69  }
70  flops += 570ll * in.Volume();
71  } else {
72  ApplyStaggered(out, in, *gauge, k, x, parity, dagger, commDim, profile);
73  flops += 582ll * in.Volume();
74  }
75  }
76 
77  // Full staggered operator
79  {
80  // Due to the staggered convention, this is applying
81  // ( 2m -D_eo ) (x_e) = (b_e)
82  // ( -D_oe 2m ) (x_o) = (b_o)
83  // ... but under the hood we need to catch the zero mass case.
84 
85  checkFullSpinor(out, in);
86 
87  if (mass == 0.) {
88  if (dagger == QUDA_DAG_YES) {
90  } else {
92  }
93  flops += 570ll * in.Volume();
94  } else {
96  flops += 582ll * in.Volume();
97  }
98  }
99 
101  {
102  bool reset = newTmp(&tmp1, in);
103 
104  //even
105  Dslash(tmp1->Even(), in.Even(), QUDA_ODD_PARITY);
106  DslashXpay(out.Even(), tmp1->Even(), QUDA_EVEN_PARITY, in.Even(), 4*mass*mass);
107 
108  //odd
109  Dslash(tmp1->Even(), in.Odd(), QUDA_EVEN_PARITY);
110  DslashXpay(out.Odd(), tmp1->Even(), QUDA_ODD_PARITY, in.Odd(), 4*mass*mass);
111 
112  deleteTmp(&tmp1, reset);
113  }
114 
117  const QudaSolutionType solType) const
118  {
119  if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
120  errorQuda("Preconditioned solution requires a preconditioned solve_type");
121  }
122 
123  src = &b;
124  sol = &x;
125  }
126 
128  const QudaSolutionType solType) const
129  {
130  // do nothing
131  }
132 
134  double kappa, double mass, double mu, double mu_factor) const {
135  errorQuda("Cannot coarsen a staggered operator (yet!), we're just getting the functions in place.");
136  //CoarseStaggeredOp(Y, X, T, *gauge, mass, QUDA_STAGGERED_DIRAC, QUDA_MATPC_INVALID);
137  }
138 
139 
141  : DiracStaggered(param)
142  {
143 
144  }
145 
147  : DiracStaggered(dirac)
148  {
149 
150  }
151 
153  {
154 
155  }
156 
158  {
159  if (&dirac != this) {
161  }
162 
163  return *this;
164  }
165 
166  // Unlike with clover, for example, we don't need a custom Dslash or DslashXpay.
167  // That's because the convention for preconditioned staggered is to
168  // NOT divide out the factor of "2m", i.e., for the even system we invert
169  // (4m^2 - D_eo D_oe), not (1 - (1/(4m^2)) D_eo D_oe).
170 
172  {
173  bool reset = newTmp(&tmp1, in);
174 
176  QudaParity other_parity = QUDA_INVALID_PARITY;
178  parity = QUDA_EVEN_PARITY;
179  other_parity = QUDA_ODD_PARITY;
180  } else if (matpcType == QUDA_MATPC_ODD_ODD) {
181  parity = QUDA_ODD_PARITY;
182  other_parity = QUDA_EVEN_PARITY;
183  } else {
184  errorQuda("Invalid matpcType(%d) in function\n", matpcType);
185  }
186 
187  // Convention note: Dslash applies D_eo, DslashXpay applies 4m^2 - D_oe!
188  // Note the minus sign convention in the Xpay version.
189  // This applies equally for the e <-> o permutation.
190 
191  Dslash(*tmp1, in, other_parity);
192  DslashXpay(out, *tmp1, parity, in, 4*mass*mass);
193 
194  deleteTmp(&tmp1, reset);
195  }
196 
198  {
199  errorQuda("MdagM is no longer defined for DiracStaggeredPC. Use M instead.\n");
200  /*
201  // need extra temporary because for multi-gpu the input
202  // and output fields cannot alias
203  bool reset = newTmp(&tmp2, in);
204  M(*tmp2, in);
205  M(out, *tmp2); // doesn't need to be Mdag b/c M is normal!
206  deleteTmp(&tmp2, reset);
207  */
208  }
209 
212  const QudaSolutionType solType) const
213  {
214  // we desire solution to preconditioned system
215  if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
216  src = &b;
217  sol = &x;
218  return;
219  }
220 
221  // we desire solution to full system.
222  // See sign convention comment in DiracStaggeredPC::M().
224  // With the convention given in DiracStaggered::M(),
225  // the source is src = 2m b_e + D_eo b_o
226  // But remember, DslashXpay actually applies
227  // -D_eo. Flip the sign on 2m to compensate, and
228  // then flip the overall sign.
229  src = &(x.Odd());
230  DslashXpay(*src, b.Odd(), QUDA_EVEN_PARITY, b.Even(), -2*mass);
231  blas::ax(-1.0, *src);
232  sol = &(x.Even());
233  } else if (matpcType == QUDA_MATPC_ODD_ODD) {
234  // See above, permute e <-> o
235  src = &(x.Even());
236  DslashXpay(*src, b.Even(), QUDA_ODD_PARITY, b.Odd(), -2*mass);
237  blas::ax(-1.0, *src);
238  sol = &(x.Odd());
239  } else {
240  errorQuda("MatPCType %d not valid for DiracStaggeredPC", matpcType);
241  }
242 
243  // here we use final solution to store parity solution and parity source
244  // b is now up for grabs if we want
245 
246  }
247 
249  const QudaSolutionType solType) const
250  {
251 
252  if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
253  return;
254  }
255 
256  checkFullSpinor(x, b);
257 
258  // create full solution
259  // See sign convention comment in DiracStaggeredPC::M()
261 
262  // With the convention given in DiracStaggered::M(),
263  // the reconstruct is x_o = 1/(2m) (b_o + D_oe x_e)
264  // But remember: DslashXpay actually applies -D_oe,
265  // so just like above we need to flip the sign
266  // on b_o. We then correct this by applying an additional
267  // minus sign when we rescale by 2m.
268  DslashXpay(x.Odd(), x.Even(), QUDA_ODD_PARITY, b.Odd(), -1.0);
269  blas::ax(-0.5/mass, x.Odd());
270  } else if (matpcType == QUDA_MATPC_ODD_ODD) {
271  // See above, permute e <-> o
272  DslashXpay(x.Even(), x.Odd(), QUDA_EVEN_PARITY, b.Even(), -1.0);
273  blas::ax(-0.5/mass, x.Even());
274  } else {
275  errorQuda("MatPCType %d not valid for DiracStaggeredPC", matpcType);
276  }
277 
278  }
279 
280 
281 
282 } // namespace quda
void ax(double a, ColorSpinorField &x)
Definition: blas_quda.cu:508
unsigned long long flops
Definition: dirac_quda.h:121
double mu
Definition: test_util.cpp:1648
virtual void prepare(ColorSpinorField *&src, ColorSpinorField *&sol, ColorSpinorField &x, ColorSpinorField &b, const QudaSolutionType) const
#define errorQuda(...)
Definition: util_quda.h:121
cudaGaugeField * gauge
Definition: dirac_quda.h:115
DiracStaggeredPC(const DiracParam &param)
virtual void checkFullSpinor(const ColorSpinorField &, const ColorSpinorField &) const
Definition: dirac.cpp:146
const ColorSpinorField & Even() const
void deleteTmp(ColorSpinorField **, const bool &reset) const
Definition: dirac.cpp:81
const ColorSpinorField & Odd() const
TimeProfile profile
Definition: dirac_quda.h:132
virtual void MdagM(ColorSpinorField &out, const ColorSpinorField &in) const
virtual void reconstruct(ColorSpinorField &x, const ColorSpinorField &b, const QudaSolutionType) const
QudaGaugeParam param
Definition: pack_test.cpp:17
bool newTmp(ColorSpinorField **, const ColorSpinorField &) const
Definition: dirac.cpp:70
DiracStaggered & operator=(const DiracStaggered &dirac)
int commDim[QUDA_MAX_DIM]
Definition: dirac_quda.h:130
DiracStaggeredPC & operator=(const DiracStaggeredPC &dirac)
virtual void prepare(ColorSpinorField *&src, ColorSpinorField *&sol, ColorSpinorField &x, ColorSpinorField &b, const QudaSolutionType) const
cpuColorSpinorField * in
QudaSiteSubset SiteSubset() const
double mass
Definition: dirac_quda.h:117
virtual void reconstruct(ColorSpinorField &x, const ColorSpinorField &b, const QudaSolutionType) const
enum QudaSolutionType_s QudaSolutionType
virtual void checkParitySpinor(const ColorSpinorField &, const ColorSpinorField &) const
QudaDagType dagger
Definition: dirac_quda.h:120
int X[4]
Definition: covdev_test.cpp:70
enum QudaParity_s QudaParity
double kappa
Definition: dirac_quda.h:116
QudaMatPCType matpcType
Definition: dirac_quda.h:119
virtual void M(ColorSpinorField &out, const ColorSpinorField &in) const
Dirac & operator=(const Dirac &dirac)
Definition: dirac.cpp:49
GaugeCovDev * dirac
Definition: covdev_test.cpp:73
cpuColorSpinorField * out
void createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, double kappa, double mass, double mu=0., double mu_factor=0.) const
Create the coarse staggered operator. Unlike the Wilson operator, we assume a mass normalization...
virtual void Dslash(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity) const
int VolumeCB() const
virtual void MdagM(ColorSpinorField &out, const ColorSpinorField &in) const
void ApplyStaggered(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile)
Apply the staggered dslash operator to a color-spinor field.
double mu_factor[QUDA_MAX_MG_LEVEL]
Definition: test_util.cpp:1674
const int * X() const
DiracStaggered(const DiracParam &param)
QudaPrecision Precision() const
QudaParity parity
Definition: covdev_test.cpp:54
virtual void DslashXpay(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity, const ColorSpinorField &x, const double &k) const
virtual void M(ColorSpinorField &out, const ColorSpinorField &in) const
ColorSpinorField * tmp1
Definition: dirac_quda.h:122