QUDA  v1.1.0
A library for QCD on GPUs
dirac_staggered.cpp
Go to the documentation of this file.
1 #include <dirac_quda.h>
2 #include <blas_quda.h>
3 #include <multigrid.h>
4 
5 namespace quda {
6 
8 
10 
12 
// Assignment operator (presumably DiracStaggered::operator=, declared in the listing at
// the end of this page): does its work only when not self-assigning, then returns *this.
// NOTE(review): the signature (source line 13) and the statement inside the if (source
// line 16) are missing from this listing; confirm against the original file.
14  {
15  if (&dirac != this) {
17  }
18  return *this;
19  }
20 
// Field-compatibility check, presumably DiracStaggered::checkParitySpinor(in, out)
// (Dslash/DslashXpay call it with that argument order). Raises errorQuda unless:
//   - both fields are 5-dimensional,
//   - both fields have the same precision,
//   - both are single-parity (per the error message; the condition on source
//     line 31 is missing from this listing),
//   - out.Volume()/out.X(4) equals gauge->VolumeCB() for a parity-subset field,
//     or 2*gauge->VolumeCB() for a full-subset field.
22  {
23  if (in.Ndim() != 5 || out.Ndim() != 5) {
24  errorQuda("Staggered dslash requires 5-d fermion fields");
25  }
26 
27  if (in.Precision() != out.Precision()) {
28  errorQuda("Input and output spinor precisions don't match in dslash_quda");
29  }
30 
// NOTE(review): source line 31 (the condition guarding this error) is missing here.
32  errorQuda("ColorSpinorFields are not single parity, in = %d, out = %d",
33  in.SiteSubset(), out.SiteSubset());
34  }
35 
// NOTE(review): if the missing line-31 check rejects every non-parity field, the
// QUDA_FULL_SITE_SUBSET arm below can never fire.
// NOTE(review): the message prints out.Volume(), not the out.Volume()/out.X(4) value
// that is actually compared, so the reported numbers can be misleading. It also uses
// %lu: VolumeCB() is size_t (%zu is the portable specifier), and out.Volume()'s return
// type is not visible here -- confirm both match the format.
36  if ((out.Volume()/out.X(4) != 2*gauge->VolumeCB() && out.SiteSubset() == QUDA_FULL_SITE_SUBSET) ||
37  (out.Volume()/out.X(4) != gauge->VolumeCB() && out.SiteSubset() == QUDA_PARITY_SITE_SUBSET) ) {
38  errorQuda("Spinor volume %lu doesn't match gauge volume %lu", out.Volume(), gauge->VolumeCB());
39  }
40  }
41 
42 
44  const QudaParity parity) const
45  {
// Staggered dslash for the given parity. It validates the fields, then calls
// ApplyStaggered with a = 0. (presumably meaning no xpay term; `in` only fills the
// x slot) and the operator's current dagger setting.
// Flop accounting: 570 per site of in.Volume().
// NOTE(review): the first signature line (source line 43) is missing from this listing.
46  checkParitySpinor(in, out);
47 
48  ApplyStaggered(out, in, *gauge, 0., in, parity, dagger, commDim, profile);
49  flops += 570ll*in.Volume();
50  }
51 
53  const QudaParity parity, const ColorSpinorField &x,
54  const double &k) const
55  {
// Xpay variant of Dslash.
//   - k != 0: ApplyStaggered receives k and x directly (582 flops per site).
//   - k == 0: it is called with a = 0. and the dagger flag inverted relative to
//     `dagger` (570 flops per site); see the sign-convention note below.
// Because -0.0 == 0.0, a caller passing e.g. -2*mass with mass == 0 also takes the
// k == 0 path.
56  checkParitySpinor(in, out);
57 
58  // Need to catch the zero mass case.
59  if (k == 0.0) {
60  // There's a sign convention difference for Dslash vs DslashXpay, which is
61  // triggered by looking for k == 0. We need to hack around this.
62  if (dagger == QUDA_DAG_YES) {
63  ApplyStaggered(out, in, *gauge, 0., x, parity, QUDA_DAG_NO, commDim, profile);
64  } else {
65  ApplyStaggered(out, in, *gauge, 0., x, parity, QUDA_DAG_YES, commDim, profile);
66  }
67  flops += 570ll * in.Volume();
68  } else {
69  ApplyStaggered(out, in, *gauge, k, x, parity, dagger, commDim, profile);
70  flops += 582ll * in.Volume();
71  }
72  }
73 
74  // Full staggered operator
76  {
// NOTE(review): the signature (source line 75) and the operator applications on
// source lines 86, 88 and 92 are missing from this listing. Only the dispatch on
// mass == 0 / dagger and the flop accounting (570 vs. 582 per site) are visible.
77  // Due to the staggered convention, this is applying
78  // ( 2m -D_eo ) (x_e) = (b_e)
79  // ( -D_oe 2m ) (x_o) = (b_o)
80  // ... but under the hood we need to catch the zero mass case.
81 
82  checkFullSpinor(out, in);
83 
84  if (mass == 0.) {
85  if (dagger == QUDA_DAG_YES) {
87  } else {
89  }
90  flops += 570ll * in.Volume();
91  } else {
93  flops += 582ll * in.Volume();
94  }
95  }
96 
// Two-hop operator applied one parity at a time (presumably DiracStaggered::MdagM;
// its signature, source line 97, is missing from this listing). For each parity p:
//   tmp1->Even() <- Dslash(in_p), called with the opposite parity argument;
//   out_p        <- DslashXpay(tmp1->Even(), parity p, x = in_p, k = 4*mass*mass).
// tmp1->Even() is reused as scratch for both halves; newTmp/deleteTmp bracket its use.
// With mass == 0, k is 0 and DslashXpay takes its zero-mass branch.
98  {
99  bool reset = newTmp(&tmp1, in);
100 
101  //even
102  Dslash(tmp1->Even(), in.Even(), QUDA_ODD_PARITY);
103  DslashXpay(out.Even(), tmp1->Even(), QUDA_EVEN_PARITY, in.Even(), 4*mass*mass);
104 
105  //odd
106  Dslash(tmp1->Even(), in.Odd(), QUDA_EVEN_PARITY);
107  DslashXpay(out.Odd(), tmp1->Even(), QUDA_ODD_PARITY, in.Odd(), 4*mass*mass);
108 
109  deleteTmp(&tmp1, reset);
110  }
111 
114  const QudaSolutionType solType) const
115  {
// Preconditioned solution types (MATPC / MATPCDAG_MATPC) are rejected with errorQuda.
// Otherwise src and sol simply alias b and x; neither field is modified.
// NOTE(review): the signature lines (source lines 112-113) are missing from this listing.
116  if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
117  errorQuda("Preconditioned solution requires a preconditioned solve_type");
118  }
119 
120  src = &b;
121  sol = &x;
122  }
123 
125  const QudaSolutionType solType) const
126  {
// Intentionally a no-op: all arguments are ignored.
// NOTE(review): the first signature line (source line 124) is missing from this listing.
127  // do nothing
128  }
129 
131  double kappa, double mass, double mu, double mu_factor) const {
// NOTE(review): source lines 130 (start of the signature), 132 and 137 are missing
// from this listing. As shown, the errorQuda below looks unconditional and XinvKD is
// never used. Presumably line 132 guards the error (e.g. on the transfer type) and
// line 137 hands XinvKD to the coarse-operator builder -- confirm against the source.
133  errorQuda("The optimized Kahler-Dirac operator is not built through createCoarseOp");
134 
135  // nullptr == no Kahler-Dirac Xinv
136  const cudaGaugeField *XinvKD = nullptr;
138  }
139 
140 
// NOTE(review): the signature lines of these three empty-bodied definitions (source
// lines 141-142, 147-148 and 153) are missing from this listing. Judging by their
// position after the DiracStaggered members, they are presumably the DiracStaggeredPC
// constructors/destructor -- confirm against the original file.
143  {
144 
145  }
146 
149  {
150 
151  }
152 
154  {
155 
156  }
157 
// Assignment operator (presumably DiracStaggeredPC::operator=, declared in the listing
// at the end of this page): does its work only when not self-assigning, then returns *this.
// NOTE(review): the signature (source line 158) and the statement inside the if (source
// line 161) are missing from this listing; confirm against the original file.
159  {
160  if (&dirac != this) {
162  }
163 
164  return *this;
165  }
166 
167  // Unlike the clover operator, for example, we don't need a custom Dslash or DslashXpay.
168  // That's because the convention for preconditioned staggered is to
169  // NOT divide out the factor of "2m", i.e., for the even system we invert
170  // (4m^2 - D_eo D_oe), not (1 - (1/(4m^2)) D_eo D_oe).
171 
173  {
// Applies the even-even or odd-odd preconditioned operator, selected by matpcType:
//   tmp1 <- Dslash(in) with other_parity;
//   out  <- DslashXpay(tmp1, parity, x = in, k = 4*mass*mass).
// Any other matpcType raises errorQuda. newTmp/deleteTmp bracket the use of tmp1.
// NOTE(review): the signature (source line 172) and the declaration/assignments of
// `parity` (source lines 176, 178-179 and 182) are missing from this listing.
// NOTE(review): the "in function" error text does not actually name the function.
174  bool reset = newTmp(&tmp1, in);
175 
177  QudaParity other_parity = QUDA_INVALID_PARITY;
180  other_parity = QUDA_ODD_PARITY;
181  } else if (matpcType == QUDA_MATPC_ODD_ODD) {
183  other_parity = QUDA_EVEN_PARITY;
184  } else {
185  errorQuda("Invalid matpcType(%d) in function\n", matpcType);
186  }
187 
188  // Convention note: Dslash applies D_eo, DslashXpay applies 4m^2 - D_oe!
189  // Note the minus sign convention in the Xpay version.
190  // This applies equally for the e <-> o permutation.
191 
192  Dslash(*tmp1, in, other_parity);
193  DslashXpay(out, *tmp1, parity, in, 4*mass*mass);
194 
195  deleteTmp(&tmp1, reset);
196  }
197 
// MdagM is intentionally unsupported for the preconditioned staggered operator: this
// always raises errorQuda. The previous implementation is kept below, commented out.
// NOTE(review): the signature (source line 198) is missing from this listing.
199  {
200  errorQuda("MdagM is no longer defined for DiracStaggeredPC. Use M instead.\n");
201  /*
202  // need extra temporary because for multi-gpu the input
203  // and output fields cannot alias
204  bool reset = newTmp(&tmp2, in);
205  M(*tmp2, in);
206  M(out, *tmp2); // doesn't need to be Mdag b/c M is normal!
207  deleteTmp(&tmp2, reset);
208  */
209  }
210 
213  const QudaSolutionType solType) const
214  {
// Preconditioned solution types: src/sol alias b/x directly.
// Full-system solution: the parity source is built in one half of x, and the other
// half of x is returned as sol.
//   - EVEN_EVEN: src = x.Odd(), sol = x.Even().
//   - ODD_ODD: the reverse.
// b is only read. Any other matpcType raises errorQuda.
// NOTE(review): source lines 211-212 (signature start) and 224 (presumably the
// EVEN_EVEN condition) are missing from this listing.
// NOTE(review): reconstruct() calls checkFullSpinor(x, b) before using Even()/Odd();
// no such check is visible here.
215  // we desire solution to preconditioned system
216  if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
217  src = &b;
218  sol = &x;
219  return;
220  }
221 
222  // we desire solution to full system.
223  // See sign convention comment in DiracStaggeredPC::M().
225  // With the convention given in DiracStaggered::M(),
226  // the source is src = 2m b_e + D_eo b_o
227  // But remember, DslashXpay actually applies
228  // -D_eo. Flip the sign on 2m to compensate, and
229  // then flip the overall sign.
230  src = &(x.Odd());
231  DslashXpay(*src, b.Odd(), QUDA_EVEN_PARITY, b.Even(), -2*mass);
232  blas::ax(-1.0, *src);
233  sol = &(x.Even());
234  } else if (matpcType == QUDA_MATPC_ODD_ODD) {
235  // See above, permute e <-> o
236  src = &(x.Even());
237  DslashXpay(*src, b.Even(), QUDA_ODD_PARITY, b.Odd(), -2*mass);
238  blas::ax(-1.0, *src);
239  sol = &(x.Odd());
240  } else {
241  errorQuda("MatPCType %d not valid for DiracStaggeredPC", matpcType);
242  }
243 
244  // here we use final solution to store parity solution and parity source
245  // b is now up for grabs if we want
246 
247  }
248 
250  const QudaSolutionType solType) const
251  {
// Preconditioned solution types: nothing to do.
// Otherwise x must be a full spinor compatible with b, and the other parity of x is
// rebuilt from the solved parity:
//   x_other <- (-0.5/mass) * DslashXpay(x_solved, other parity, x = b_other, k = -1.0).
// Any other matpcType raises errorQuda.
// NOTE(review): -0.5/mass is non-finite when mass == 0. The full operator M()
// special-cases mass == 0, but nothing visible here does -- confirm this path is
// never reached with zero mass.
// NOTE(review): source lines 249 (signature start) and 261 (presumably the EVEN_EVEN
// condition) are missing from this listing.
252 
253  if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
254  return;
255  }
256 
257  checkFullSpinor(x, b);
258 
259  // create full solution
260  // See sign convention comment in DiracStaggeredPC::M()
262 
263  // With the convention given in DiracStaggered::M(),
264  // the reconstruct is x_o = 1/(2m) (b_o + D_oe x_e)
265  // But remember: DslashXpay actually applies -D_oe,
266  // so just like above we need to flip the sign
267  // on b_o. We then correct this by applying an additional
268  // minus sign when we rescale by 2m.
269  DslashXpay(x.Odd(), x.Even(), QUDA_ODD_PARITY, b.Odd(), -1.0);
270  blas::ax(-0.5/mass, x.Odd());
271  } else if (matpcType == QUDA_MATPC_ODD_ODD) {
272  // See above, permute e <-> o
273  DslashXpay(x.Even(), x.Odd(), QUDA_EVEN_PARITY, b.Even(), -1.0);
274  blas::ax(-0.5/mass, x.Even());
275  } else {
276  errorQuda("MatPCType %d not valid for DiracStaggeredPC", matpcType);
277  }
278 
279  }
280 
281 
282 
283 } // namespace quda
const ColorSpinorField & Odd() const
QudaSiteSubset SiteSubset() const
const ColorSpinorField & Even() const
const int * X() const
unsigned long long flops
Definition: dirac_quda.h:150
bool newTmp(ColorSpinorField **, const ColorSpinorField &) const
Definition: dirac.cpp:72
QudaMatPCType matpcType
Definition: dirac_quda.h:148
cudaGaugeField * gauge
Definition: dirac_quda.h:144
double mass
Definition: dirac_quda.h:146
void deleteTmp(ColorSpinorField **, const bool &reset) const
Definition: dirac.cpp:83
ColorSpinorField * tmp1
Definition: dirac_quda.h:151
TimeProfile profile
Definition: dirac_quda.h:161
QudaDagType dagger
Definition: dirac_quda.h:149
Dirac & operator=(const Dirac &dirac)
Definition: dirac.cpp:51
virtual void checkFullSpinor(const ColorSpinorField &, const ColorSpinorField &) const
check full spinors are compatible (check geometry ?)
Definition: dirac.cpp:138
int commDim[QUDA_MAX_DIM]
Definition: dirac_quda.h:159
virtual void DslashXpay(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity, const ColorSpinorField &x, const double &k) const
Xpay version of Dslash.
virtual void reconstruct(ColorSpinorField &x, const ColorSpinorField &b, const QudaSolutionType) const
DiracStaggered & operator=(const DiracStaggered &dirac)
DiracStaggered(const DiracParam &param)
virtual void MdagM(ColorSpinorField &out, const ColorSpinorField &in) const
Apply MdagM operator which may be optimized.
virtual void checkParitySpinor(const ColorSpinorField &, const ColorSpinorField &) const
Check parity spinors are usable (check geometry ?)
void createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, double kappa, double mass, double mu=0., double mu_factor=0.) const
Create the coarse staggered operator.
virtual void prepare(ColorSpinorField *&src, ColorSpinorField *&sol, ColorSpinorField &x, ColorSpinorField &b, const QudaSolutionType) const
virtual void Dslash(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity) const
apply 'dslash' operator for the DiracOp. This may be e.g. AD
virtual void M(ColorSpinorField &out, const ColorSpinorField &in) const
Apply M for the dirac op. E.g. the Schur Complement operator.
virtual void reconstruct(ColorSpinorField &x, const ColorSpinorField &b, const QudaSolutionType) const
DiracStaggeredPC & operator=(const DiracStaggeredPC &dirac)
virtual void MdagM(ColorSpinorField &out, const ColorSpinorField &in) const
Apply MdagM operator which may be optimized.
virtual void M(ColorSpinorField &out, const ColorSpinorField &in) const
Apply M for the dirac op. E.g. the Schur Complement operator.
DiracStaggeredPC(const DiracParam &param)
virtual void prepare(ColorSpinorField *&src, ColorSpinorField *&sol, ColorSpinorField &x, ColorSpinorField &b, const QudaSolutionType) const
QudaPrecision Precision() const
size_t VolumeCB() const
QudaTransferType getTransferType() const
Definition: transfer.h:240
double kappa
double mass
double mu
quda::mgarray< double > mu_factor
GaugeCovDev * dirac
Definition: covdev_test.cpp:42
QudaParity parity
Definition: covdev_test.cpp:40
@ QUDA_STAGGERED_DIRAC
Definition: enum_quda.h:305
@ QUDA_DAG_NO
Definition: enum_quda.h:223
@ QUDA_DAG_YES
Definition: enum_quda.h:223
@ QUDA_FULL_SITE_SUBSET
Definition: enum_quda.h:333
@ QUDA_PARITY_SITE_SUBSET
Definition: enum_quda.h:332
@ QUDA_EVEN_PARITY
Definition: enum_quda.h:284
@ QUDA_ODD_PARITY
Definition: enum_quda.h:284
@ QUDA_INVALID_PARITY
Definition: enum_quda.h:284
@ QUDA_TRANSFER_OPTIMIZED_KD
Definition: enum_quda.h:455
enum QudaSolutionType_s QudaSolutionType
@ QUDA_MATPC_ODD_ODD
Definition: enum_quda.h:217
@ QUDA_MATPC_EVEN_EVEN
Definition: enum_quda.h:216
@ QUDA_MATPC_INVALID
Definition: enum_quda.h:220
@ QUDA_MATPC_SOLUTION
Definition: enum_quda.h:159
@ QUDA_MATPCDAG_MATPC_SOLUTION
Definition: enum_quda.h:161
enum QudaParity_s QudaParity
void ax(double a, ColorSpinorField &x)
void ApplyStaggered(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile)
Apply the staggered dslash operator to a color-spinor field.
void StaggeredCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, const cudaGaugeField &gauge, const cudaGaugeField *XinvKD, double mass, QudaDiracType dirac, QudaMatPCType matpc)
Coarse operator construction from a fine-grid operator (Staggered)
QudaGaugeParam param
Definition: pack_test.cpp:18
#define errorQuda(...)
Definition: util_quda.h:120