QUDA  v1.1.0
A library for QCD on GPUs
dirac_improved_staggered.cpp
Go to the documentation of this file.
1 #include <dirac_quda.h>
2 #include <blas_quda.h>
3 #include <multigrid.h>
4 
5 namespace quda {
6 
8  Dirac(param),
9  fatGauge(param.fatGauge),
10  longGauge(param.longGauge)
11  {
12  }
13 
15  : Dirac(dirac), fatGauge(dirac.fatGauge), longGauge(dirac.longGauge) { }
16 
18 
20  {
21  if (&dirac != this) {
23  fatGauge = dirac.fatGauge;
24  longGauge = dirac.longGauge;
25  }
26  return *this;
27  }
28 
30  {
31  if (in.Ndim() != 5 || out.Ndim() != 5) {
32  errorQuda("Staggered dslash requires 5-d fermion fields");
33  }
34 
35  if (in.Precision() != out.Precision()) {
36  errorQuda("Input and output spinor precisions don't match in dslash_quda");
37  }
38 
40  errorQuda("ColorSpinorFields are not single parity, in = %d, out = %d",
41  in.SiteSubset(), out.SiteSubset());
42  }
43 
44  if ((out.Volume() / out.X(4) != 2 * fatGauge->VolumeCB() && out.SiteSubset() == QUDA_FULL_SITE_SUBSET)
45  || (out.Volume() / out.X(4) != fatGauge->VolumeCB() && out.SiteSubset() == QUDA_PARITY_SITE_SUBSET)) {
46  errorQuda("Spinor volume %lu doesn't match gauge volume %lu", out.Volume(), fatGauge->VolumeCB());
47  }
48  }
49 
51  {
52  checkParitySpinor(in, out);
53 
55  flops += 1146ll*in.Volume();
56  }
57 
59  const ColorSpinorField &x, const double &k) const
60  {
61  checkParitySpinor(in, out);
62 
63  // Need to catch the zero mass case.
64  if (k == 0.0) {
65  // There's a sign convention difference for Dslash vs DslashXpay, which is
66  // triggered by looking for k == 0. We need to hack around this.
67  if (dagger == QUDA_DAG_YES) {
69  } else {
71  }
72  flops += 1146ll * in.Volume();
73  } else {
75  flops += 1158ll * in.Volume();
76  }
77  }
78 
79  // Full staggered operator
81  {
82  checkFullSpinor(out, in);
83  // Need to flip sign via dagger convention if mass == 0.
84  if (mass == 0.0) {
85  if (dagger == QUDA_DAG_YES) {
87  profile);
88  } else {
90  profile);
91  }
92  flops += 1146ll * in.Volume();
93  } else {
95  profile);
96  flops += 1158ll * in.Volume();
97  }
98  }
99 
101  { // Applies out = (4m^2 - D D) in one parity at a time (NOTE(review): presumably MdagM of the full operator; the signature line is missing from this listing, so confirm).
102  bool reset = newTmp(&tmp1, in); // reset is handed back to deleteTmp below; presumably newTmp allocates tmp1 only when it is not already present.
103 
104  //even: tmp1->Even() receives D applied to in.Even() (odd-parity target), then out.Even() = DslashXpay(..., k = 4m^2), i.e. 4m^2 in_e - D_eo D_oe in_e assuming the "Xpay applies k*x - D" sign convention.
105  Dslash(tmp1->Even(), in.Even(), QUDA_ODD_PARITY);
106  DslashXpay(out.Even(), tmp1->Even(), QUDA_EVEN_PARITY, in.Even(), 4*mass*mass);
107 
108  //odd: same pattern with parities swapped; tmp1->Even() is reused purely as parity-sized scratch.
109  Dslash(tmp1->Even(), in.Odd(), QUDA_EVEN_PARITY);
110  DslashXpay(out.Odd(), tmp1->Even(), QUDA_ODD_PARITY, in.Odd(), 4*mass*mass);
111 
112  deleteTmp(&tmp1, reset);
113  }
114 
117  const QudaSolutionType solType) const
118  {
119  if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
120  errorQuda("Preconditioned solution requires a preconditioned solve_type");
121  }
122 
123  src = &b;
124  sol = &x;
125  }
126 
128  const QudaSolutionType solType) const
129  {
130  // do nothing
131  }
132 
134  double mass, double mu, double mu_factor) const
135  {
137  errorQuda("The optimized improved Kahler-Dirac operator is not built through createCoarseOp");
138 
139  // nullptr == no Kahler-Dirac Xinv
140  const cudaGaugeField *XinvKD = nullptr;
142  }
143 
145  { // Prefetch the base-class (Dirac) memory, then the fat and long link fields, to mem_space on the given stream; per the declaration's doc this only matters when managed memory and prefetching are enabled.
146  Dirac::prefetch(mem_space, stream);
147  fatGauge->prefetch(mem_space, stream);
148  longGauge->prefetch(mem_space, stream);
149  }
150 
153  {
154 
155  }
156 
159  {
160 
161  }
162 
164  {
165 
166  }
167 
169  {
170  if (&dirac != this) {
172  }
173 
174  return *this;
175  }
176 
177  // Unlike with clover, for ex, we don't need a custom Dslash or DslashXpay.
178  // That's because the convention for preconditioned staggered is to
179  // NOT divide out the factor of "2m", i.e., for the even system we invert
180  // (4m^2 - D_eo D_oe), not (1 - (1/(4m^2)) D_eo D_oe).
181 
183  {
184  bool reset = newTmp(&tmp1, in);
185 
187  QudaParity other_parity = QUDA_INVALID_PARITY;
190  other_parity = QUDA_ODD_PARITY;
191  } else if (matpcType == QUDA_MATPC_ODD_ODD) {
193  other_parity = QUDA_EVEN_PARITY;
194  } else {
195  errorQuda("Invalid matpcType(%d) in function\n", matpcType);
196  }
197 
198  // Convention note: Dslash applies D_eo, DslashXpay applies 4m^2 - D_oe!
199  // Note the minus sign convention in the Xpay version.
200  // This applies equally for the e <-> o permutation.
201 
202  Dslash(*tmp1, in, other_parity);
203  DslashXpay(out, *tmp1, parity, in, 4*mass*mass);
204 
205  deleteTmp(&tmp1, reset);
206  }
207 
209  { // Deliberately unsupported: always reports an error via errorQuda. The preconditioned M() already applies (4m^2 - D_eo D_oe), so callers should use M() directly.
210  errorQuda("MdagM is no longer defined for DiracImprovedStaggeredPC. Use M instead.\n");
211  /* Former implementation, kept for reference only (applied M twice, relying on M being normal):
212  // need extra temporary because for multi-gpu the input
213  // and output fields cannot alias
214  bool reset = newTmp(&tmp2, in);
215  M(*tmp2, in);
216  M(out, *tmp2); // doesn't need to be Mdag b/c M is normal!
217  deleteTmp(&tmp2, reset);
218  */
219  }
220 
223  const QudaSolutionType solType) const
224  {
225 
226  // we desire solution to preconditioned system
227  if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
228  src = &b;
229  sol = &x;
230  return;
231  }
232 
233  // we desire solution to full system.
234  // See sign convention comment in DiracStaggeredPC::M().
236 
237  // With the convention given in DiracStaggered::M(),
238  // the source is src = 2m b_e + D_eo b_o
239  // But remember, DslashXpay actually applies
240  // -D_eo. Flip the sign on 2m to compensate, and
241  // then flip the overall sign.
242  src = &(x.Odd());
243  DslashXpay(*src, b.Odd(), QUDA_EVEN_PARITY, b.Even(), -2*mass);
244  blas::ax(-1.0, *src);
245  sol = &(x.Even());
246  } else if (matpcType == QUDA_MATPC_ODD_ODD) {
247  // See above, permute e <-> o
248  src = &(x.Even());
249  DslashXpay(*src, b.Even(), QUDA_ODD_PARITY, b.Odd(), -2*mass);
250  blas::ax(-1.0, *src);
251  sol = &(x.Odd());
252  } else {
253  errorQuda("MatPCType %d not valid for DiracStaggeredPC", matpcType);
254  }
255 
256  // here we use final solution to store parity solution and parity source
257  // b is now up for grabs if we want
258  }
259 
261  const QudaSolutionType solType) const
262  {
263 
264  if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
265  return;
266  }
267 
268  checkFullSpinor(x, b);
269 
270  // create full solution
271  // See sign convention comment in DiracStaggeredPC::M()
273 
274  // With the convention given in DiracStaggered::M(),
275  // the reconstruct is x_o = 1/(2m) (b_o + D_oe x_e)
276  // But remember: DslashXpay actually applies -D_oe,
277  // so just like above we need to flip the sign
278  // on b_o. We then correct this by applying an additional
279  // minus sign when we rescale by 2m.
280  DslashXpay(x.Odd(), x.Even(), QUDA_ODD_PARITY, b.Odd(), -1.0);
281  blas::ax(-0.5/mass, x.Odd());
282  } else if (matpcType == QUDA_MATPC_ODD_ODD) {
283  // See above, permute e <-> o
284  DslashXpay(x.Even(), x.Odd(), QUDA_EVEN_PARITY, b.Even(), -1.0);
285  blas::ax(-0.5/mass, x.Even());
286  } else {
287  errorQuda("MatPCType %d not valid for DiracStaggeredPC", matpcType);
288  }
289  }
290 
291 } // namespace quda
const ColorSpinorField & Odd() const
QudaSiteSubset SiteSubset() const
const ColorSpinorField & Even() const
const int * X() const
unsigned long long flops
Definition: dirac_quda.h:150
bool newTmp(ColorSpinorField **, const ColorSpinorField &) const
Definition: dirac.cpp:72
virtual void prefetch(QudaFieldLocation mem_space, qudaStream_t stream=0) const
If managed memory and prefetching are enabled, prefetch the gauge field and temporary spinors to the CPU ...
Definition: dirac.cpp:305
QudaMatPCType matpcType
Definition: dirac_quda.h:148
double mass
Definition: dirac_quda.h:146
void deleteTmp(ColorSpinorField **, const bool &reset) const
Definition: dirac.cpp:83
ColorSpinorField * tmp1
Definition: dirac_quda.h:151
TimeProfile profile
Definition: dirac_quda.h:161
QudaDagType dagger
Definition: dirac_quda.h:149
Dirac & operator=(const Dirac &dirac)
Definition: dirac.cpp:51
virtual void checkFullSpinor(const ColorSpinorField &, const ColorSpinorField &) const
Check that full spinors are compatible (should geometry also be checked?)
Definition: dirac.cpp:138
int commDim[QUDA_MAX_DIM]
Definition: dirac_quda.h:159
virtual void reconstruct(ColorSpinorField &x, const ColorSpinorField &b, const QudaSolutionType) const
DiracImprovedStaggered & operator=(const DiracImprovedStaggered &dirac)
cudaGaugeField * fatGauge
Definition: dirac_quda.h:1356
virtual void prefetch(QudaFieldLocation mem_space, qudaStream_t stream=0) const
If managed memory and prefetching are enabled, prefetch all relevant memory fields (fat+long links,...
cudaGaugeField * longGauge
Definition: dirac_quda.h:1357
void createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, double kappa, double mass, double mu=0., double mu_factor=0.) const
Create the coarse staggered operator.
virtual void M(ColorSpinorField &out, const ColorSpinorField &in) const
Apply M for the Dirac operator, e.g., the Schur complement operator.
virtual void Dslash(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity) const
Apply the 'dslash' operator for the DiracOp. This may be, e.g., AD
virtual void checkParitySpinor(const ColorSpinorField &, const ColorSpinorField &) const
Check that parity spinors are usable (should geometry also be checked?)
DiracImprovedStaggered(const DiracParam &param)
virtual void DslashXpay(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity, const ColorSpinorField &x, const double &k) const
Xpay version of Dslash.
virtual void prepare(ColorSpinorField *&src, ColorSpinorField *&sol, ColorSpinorField &x, ColorSpinorField &b, const QudaSolutionType) const
virtual void MdagM(ColorSpinorField &out, const ColorSpinorField &in) const
Apply the MdagM operator, which may be optimized.
virtual void reconstruct(ColorSpinorField &x, const ColorSpinorField &b, const QudaSolutionType) const
DiracImprovedStaggeredPC & operator=(const DiracImprovedStaggeredPC &dirac)
virtual void M(ColorSpinorField &out, const ColorSpinorField &in) const
Apply M for the Dirac operator, e.g., the Schur complement operator.
virtual void MdagM(ColorSpinorField &out, const ColorSpinorField &in) const
Apply the MdagM operator, which may be optimized.
virtual void prepare(ColorSpinorField *&src, ColorSpinorField *&sol, ColorSpinorField &x, ColorSpinorField &b, const QudaSolutionType) const
DiracImprovedStaggeredPC(const DiracParam &param)
QudaPrecision Precision() const
size_t VolumeCB() const
QudaTransferType getTransferType() const
Definition: transfer.h:240
void prefetch(QudaFieldLocation mem_space, qudaStream_t stream=0) const
If managed memory and prefetching are enabled, prefetch the gauge field and buffers to the CPU or the GPU...
double kappa
double mass
double mu
quda::mgarray< double > mu_factor
GaugeCovDev * dirac
Definition: covdev_test.cpp:42
QudaParity parity
Definition: covdev_test.cpp:40
@ QUDA_ASQTAD_DIRAC
Definition: enum_quda.h:308
@ QUDA_DAG_NO
Definition: enum_quda.h:223
@ QUDA_DAG_YES
Definition: enum_quda.h:223
@ QUDA_FULL_SITE_SUBSET
Definition: enum_quda.h:333
@ QUDA_PARITY_SITE_SUBSET
Definition: enum_quda.h:332
@ QUDA_EVEN_PARITY
Definition: enum_quda.h:284
@ QUDA_ODD_PARITY
Definition: enum_quda.h:284
@ QUDA_INVALID_PARITY
Definition: enum_quda.h:284
@ QUDA_TRANSFER_OPTIMIZED_KD
Definition: enum_quda.h:455
enum QudaSolutionType_s QudaSolutionType
enum QudaFieldLocation_s QudaFieldLocation
@ QUDA_MATPC_ODD_ODD
Definition: enum_quda.h:217
@ QUDA_MATPC_EVEN_EVEN
Definition: enum_quda.h:216
@ QUDA_MATPC_INVALID
Definition: enum_quda.h:220
@ QUDA_MATPC_SOLUTION
Definition: enum_quda.h:159
@ QUDA_MATPCDAG_MATPC_SOLUTION
Definition: enum_quda.h:161
enum QudaParity_s QudaParity
void ax(double a, ColorSpinorField &x)
qudaStream_t * stream
void StaggeredCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, const cudaGaugeField &gauge, const cudaGaugeField *XinvKD, double mass, QudaDiracType dirac, QudaMatPCType matpc)
Coarse operator construction from a fine-grid operator (Staggered)
void ApplyImprovedStaggered(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, const GaugeField &L, double a, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile)
Apply the improved staggered dslash operator to a color-spinor field.
QudaGaugeParam param
Definition: pack_test.cpp:18
cudaStream_t qudaStream_t
Definition: quda_api.h:9
#define errorQuda(...)
Definition: util_quda.h:120