QUDA  1.0.0
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
dirac_improved_staggered.cpp
Go to the documentation of this file.
1 #include <dirac_quda.h>
2 #include <blas_quda.h>
3 
4 namespace quda {
5 
7  : Dirac(param), fatGauge(*(param.fatGauge)), longGauge(*(param.longGauge)) { }
8 
10  : Dirac(dirac), fatGauge(dirac.fatGauge), longGauge(dirac.longGauge) { }
11 
13 
15  {
16  if (&dirac != this) {
17  Dirac::operator=(dirac);
18  fatGauge = dirac.fatGauge;
19  longGauge = dirac.longGauge;
20  }
21  return *this;
22  }
23 
25  {
26  if (in.Ndim() != 5 || out.Ndim() != 5) {
27  errorQuda("Staggered dslash requires 5-d fermion fields");
28  }
29 
30  if (in.Precision() != out.Precision()) {
31  errorQuda("Input and output spinor precisions don't match in dslash_quda");
32  }
33 
34  if (in.Stride() != out.Stride()) {
35  errorQuda("Input %d and output %d spinor strides don't match in dslash_quda", in.Stride(), out.Stride());
36  }
37 
39  errorQuda("ColorSpinorFields are not single parity, in = %d, out = %d",
40  in.SiteSubset(), out.SiteSubset());
41  }
42 
43  if ((out.Volume()/out.X(4) != 2*fatGauge.VolumeCB() && out.SiteSubset() == QUDA_FULL_SITE_SUBSET) ||
44  (out.Volume()/out.X(4) != fatGauge.VolumeCB() && out.SiteSubset() == QUDA_PARITY_SITE_SUBSET) ) {
45  errorQuda("Spinor volume %d doesn't match gauge volume %d", out.Volume(), fatGauge.VolumeCB());
46  }
47  }
48 
50  {
51  checkParitySpinor(in, out);
52 
53  ApplyImprovedStaggered(out, in, fatGauge, longGauge, 0., in, parity, dagger, commDim, profile);
54  flops += 1146ll*in.Volume();
55  }
56 
58  const ColorSpinorField &x, const double &k) const
59  {
60  checkParitySpinor(in, out);
61 
62  // Need to catch the zero mass case.
63  if (k == 0.0) {
64  // There's a sign convention difference for Dslash vs DslashXpay, which is
65  // triggered by looking for k == 0. We need to hack around this.
66  if (dagger == QUDA_DAG_YES) {
68  } else {
70  }
71  flops += 1146ll * in.Volume();
72  } else {
73  ApplyImprovedStaggered(out, in, fatGauge, longGauge, k, x, parity, dagger, commDim, profile);
74  flops += 1158ll * in.Volume();
75  }
76  }
77 
78  // Full staggered operator
80  {
81  checkFullSpinor(out, in);
82  // Need to flip sign via dagger convention if mass == 0.
83  if (mass == 0.0) {
84  if (dagger == QUDA_DAG_YES) {
86  } else {
88  }
89  flops += 1146ll * in.Volume();
90  } else {
92  flops += 1158ll * in.Volume();
93  }
94  }
95 
97  {
98  bool reset = newTmp(&tmp1, in);
99 
100  //even
101  Dslash(tmp1->Even(), in.Even(), QUDA_ODD_PARITY);
102  DslashXpay(out.Even(), tmp1->Even(), QUDA_EVEN_PARITY, in.Even(), 4*mass*mass);
103 
104  //odd
105  Dslash(tmp1->Even(), in.Odd(), QUDA_EVEN_PARITY);
106  DslashXpay(out.Odd(), tmp1->Even(), QUDA_ODD_PARITY, in.Odd(), 4*mass*mass);
107 
108  deleteTmp(&tmp1, reset);
109  }
110 
113  const QudaSolutionType solType) const
114  {
115  if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
116  errorQuda("Preconditioned solution requires a preconditioned solve_type");
117  }
118 
119  src = &b;
120  sol = &x;
121  }
122 
124  const QudaSolutionType solType) const
125  {
126  // do nothing
127  }
128 
129 
131  : DiracImprovedStaggered(param)
132  {
133 
134  }
135 
137  : DiracImprovedStaggered(dirac)
138  {
139 
140  }
141 
143  {
144 
145  }
146 
148  {
149  if (&dirac != this) {
151  }
152 
153  return *this;
154  }
155 
156  // Unlike with clover, for example, we don't need a custom Dslash or DslashXpay.
157  // That's because the convention for preconditioned staggered is to
158  // NOT divide out the factor of "2m", i.e., for the even system we invert
159  // (4m^2 - D_eo D_oe), not (1 - (1/(4m^2)) D_eo D_oe).
160 
162  {
163  bool reset = newTmp(&tmp1, in);
164 
166  QudaParity other_parity = QUDA_INVALID_PARITY;
168  parity = QUDA_EVEN_PARITY;
169  other_parity = QUDA_ODD_PARITY;
170  } else if (matpcType == QUDA_MATPC_ODD_ODD) {
171  parity = QUDA_ODD_PARITY;
172  other_parity = QUDA_EVEN_PARITY;
173  } else {
174  errorQuda("Invalid matpcType(%d) in function\n", matpcType);
175  }
176 
177  // Convention note: Dslash applies D_eo, DslashXpay applies 4m^2 - D_oe!
178  // Note the minus sign convention in the Xpay version.
179  // This applies equally for the e <-> o permutation.
180 
181  Dslash(*tmp1, in, other_parity);
182  DslashXpay(out, *tmp1, parity, in, 4*mass*mass);
183 
184  deleteTmp(&tmp1, reset);
185  }
186 
188  {
189  errorQuda("MdagM is no longer defined for DiracImprovedStaggeredPC. Use M instead.\n");
190  /*
191  // need extra temporary because for multi-gpu the input
192  // and output fields cannot alias
193  bool reset = newTmp(&tmp2, in);
194  M(*tmp2, in);
195  M(out, *tmp2); // doesn't need to be Mdag b/c M is normal!
196  deleteTmp(&tmp2, reset);
197  */
198  }
199 
202  const QudaSolutionType solType) const
203  {
204  // we desire solution to preconditioned system
205  if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
206  src = &b;
207  sol = &x;
208  return;
209  }
210 
211  // we desire solution to full system.
212  // See sign convention comment in DiracImprovedStaggeredPC::M().
214  printfQuda("Prepare for even-even.\n");
215  // With the convention given in DiracStaggered::M(),
216  // the source is src = 2m b_e + D_eo b_o
217  // But remember, DslashXpay actually applies
218  // -D_eo. Flip the sign on 2m to compensate, and
219  // then flip the overall sign.
220  src = &(x.Odd());
221  DslashXpay(*src, b.Odd(), QUDA_EVEN_PARITY, b.Even(), -2*mass);
222  blas::ax(-1.0, *src);
223  sol = &(x.Even());
224  } else if (matpcType == QUDA_MATPC_ODD_ODD) {
225  // See above, permute e <-> o
226  src = &(x.Even());
227  DslashXpay(*src, b.Even(), QUDA_ODD_PARITY, b.Odd(), -2*mass);
228  blas::ax(-1.0, *src);
229  sol = &(x.Odd());
230  } else {
231  errorQuda("MatPCType %d not valid for DiracStaggeredPC", matpcType);
232  }
233 
234  // here we use final solution to store parity solution and parity source
235  // b is now up for grabs if we want
236  }
237 
239  const QudaSolutionType solType) const
240  {
241  if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
242  return;
243  }
244 
245  checkFullSpinor(x, b);
246 
247  // create full solution
248  // See sign convention comment in DiracImprovedStaggeredPC::M()
250  printfQuda("Reconstruct even-even.\n");
251 
252  // With the convention given in DiracStaggered::M(),
253  // the reconstruct is x_o = 1/(2m) (b_o + D_oe x_e)
254  // But remember: DslashXpay actually applies -D_oe,
255  // so just like above we need to flip the sign
256  // on b_o. We then correct this by applying an additional
257  // minus sign when we rescale by 2m.
258  DslashXpay(x.Odd(), x.Even(), QUDA_ODD_PARITY, b.Odd(), -1.0);
259  blas::ax(-0.5/mass, x.Odd());
260  } else if (matpcType == QUDA_MATPC_ODD_ODD) {
261  // See above, permute e <-> o
262  DslashXpay(x.Even(), x.Odd(), QUDA_EVEN_PARITY, b.Even(), -1.0);
263  blas::ax(-0.5/mass, x.Even());
264  } else {
265  errorQuda("MatPCType %d not valid for DiracStaggeredPC", matpcType);
266  }
267  }
268 
269 } // namespace quda
DiracImprovedStaggeredPC & operator=(const DiracImprovedStaggeredPC &dirac)
void ax(double a, ColorSpinorField &x)
Definition: blas_quda.cu:508
DiracImprovedStaggeredPC(const DiracParam &param)
unsigned long long flops
Definition: dirac_quda.h:121
virtual void DslashXpay(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity, const ColorSpinorField &x, const double &k) const
virtual void prepare(ColorSpinorField *&src, ColorSpinorField *&sol, ColorSpinorField &x, ColorSpinorField &b, const QudaSolutionType) const
#define errorQuda(...)
Definition: util_quda.h:121
virtual void checkFullSpinor(const ColorSpinorField &, const ColorSpinorField &) const
Definition: dirac.cpp:146
const ColorSpinorField & Even() const
void deleteTmp(ColorSpinorField **, const bool &reset) const
Definition: dirac.cpp:81
const ColorSpinorField & Odd() const
TimeProfile profile
Definition: dirac_quda.h:132
virtual void reconstruct(ColorSpinorField &x, const ColorSpinorField &b, const QudaSolutionType) const
cudaGaugeField & fatGauge
Definition: dirac_quda.h:759
QudaGaugeParam param
Definition: pack_test.cpp:17
bool newTmp(ColorSpinorField **, const ColorSpinorField &) const
Definition: dirac.cpp:70
int commDim[QUDA_MAX_DIM]
Definition: dirac_quda.h:130
virtual void M(ColorSpinorField &out, const ColorSpinorField &in) const
cpuColorSpinorField * in
QudaSiteSubset SiteSubset() const
double mass
Definition: dirac_quda.h:117
virtual void MdagM(ColorSpinorField &out, const ColorSpinorField &in) const
enum QudaSolutionType_s QudaSolutionType
QudaDagType dagger
Definition: dirac_quda.h:120
enum QudaParity_s QudaParity
QudaMatPCType matpcType
Definition: dirac_quda.h:119
virtual void MdagM(ColorSpinorField &out, const ColorSpinorField &in) const
Dirac & operator=(const Dirac &dirac)
Definition: dirac.cpp:49
GaugeCovDev * dirac
Definition: covdev_test.cpp:73
cpuColorSpinorField * out
virtual void prepare(ColorSpinorField *&src, ColorSpinorField *&sol, ColorSpinorField &x, ColorSpinorField &b, const QudaSolutionType) const
#define printfQuda(...)
Definition: util_quda.h:115
int VolumeCB() const
cudaGaugeField & longGauge
Definition: dirac_quda.h:760
void ApplyImprovedStaggered(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, const GaugeField &L, double a, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile)
Apply the improved staggered dslash operator to a color-spinor field.
const int * X() const
virtual void checkParitySpinor(const ColorSpinorField &, const ColorSpinorField &) const
virtual void reconstruct(ColorSpinorField &x, const ColorSpinorField &b, const QudaSolutionType) const
virtual void Dslash(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity) const
QudaPrecision Precision() const
QudaParity parity
Definition: covdev_test.cpp:54
DiracImprovedStaggered(const DiracParam &param)
virtual void M(ColorSpinorField &out, const ColorSpinorField &in) const
DiracImprovedStaggered & operator=(const DiracImprovedStaggered &dirac)
ColorSpinorField * tmp1
Definition: dirac_quda.h:122