QUDA  1.0.0
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
dirac_mobius.cpp
Go to the documentation of this file.
1 #include <iostream>
2 #include <dirac_quda.h>
3 #include <blas_quda.h>
4 
5 namespace quda {
6 
7  DiracMobius::DiracMobius(const DiracParam &param) : DiracDomainWall(param), zMobius(false)
8  {
9  memcpy(b_5, param.b_5, sizeof(Complex) * param.Ls);
10  memcpy(c_5, param.c_5, sizeof(Complex) * param.Ls);
11 
12  // check if doing zMobius
13  for (int i = 0; i < Ls; i++) {
14  if (b_5[i].imag() != 0.0 || c_5[i].imag() != 0.0 || (i < Ls - 1 && (b_5[i] != b_5[i + 1] || c_5[i] != c_5[i + 1]))) {
15  zMobius = true;
16  }
17  }
18 
19  if (getVerbosity() > QUDA_VERBOSE) {
20  if (zMobius)
21  printfQuda("%s: Detected variable or complex cofficients: using zMobius\n", __func__);
22  else
23  printfQuda("%s: Detected fixed real cofficients: using regular Mobius\n", __func__);
24  }
25  }
26 
28  {
29  memcpy(b_5, dirac.b_5, sizeof(Complex) * Ls);
30  memcpy(c_5, dirac.c_5, sizeof(Complex) * Ls);
31  }
32 
34 
36  {
37  if (&dirac != this) {
39  memcpy(b_5, dirac.b_5, sizeof(Complex) * Ls);
40  memcpy(c_5, dirac.c_5, sizeof(Complex) * Ls);
41  zMobius = dirac.zMobius;
42  }
43 
44  return *this;
45  }
46 
47 // Modification for the 4D preconditioned Mobius domain wall operator
49  const QudaParity parity) const
50  {
51  checkDWF(in, out);
52  checkParitySpinor(in, out);
53  checkSpinorAlias(in, out);
54 
55  ApplyDomainWall4D(out, in, *gauge, 0.0, 0.0, nullptr, nullptr, in, parity, dagger, commDim, profile);
56 
57  flops += 1320LL*(long long)in.Volume();
58  }
59 
61  {
62  checkDWF(in, out);
63  checkParitySpinor(in, out);
64  checkSpinorAlias(in, out);
65 
66  ApplyDslash5(out, in, in, mass, m5, b_5, c_5, 0.0, dagger, DSLASH5_MOBIUS_PRE);
67 
68  long long Ls = in.X(4);
69  long long bulk = (Ls-2)*(in.Volume()/Ls);
70  long long wall = 2*in.Volume()/Ls;
71  flops += 72LL*(long long)in.Volume() + 96LL*bulk + 120LL*wall;
72  }
73 
74  // Unlike DWF-4d, the Mobius variant here applies the full M5 operator and not just D5
76  {
77  checkDWF(in, out);
78  checkParitySpinor(in, out);
79  checkSpinorAlias(in, out);
80 
81  ApplyDslash5(out, in, in, mass, m5, b_5, c_5, 0.0, dagger, DSLASH5_MOBIUS);
82 
83  long long Ls = in.X(4);
84  long long bulk = (Ls-2)*(in.Volume()/Ls);
85  long long wall = 2*in.Volume()/Ls;
86  flops += 48LL*(long long)in.Volume() + 96LL*bulk + 120LL*wall;
87  }
88 
89  // Modification for the 4D preconditioned Mobius domain wall operator
91  const QudaParity parity, const ColorSpinorField &x, const double &k) const
92  {
93  checkDWF(in, out);
94  checkParitySpinor(in, out);
95  checkSpinorAlias(in, out);
96 
97  ApplyDomainWall4D(out, in, *gauge, k, m5, b_5, c_5, x, parity, dagger, commDim, profile);
98 
99  flops += (1320LL+48LL)*(long long)in.Volume();
100  }
101 
103  const QudaParity parity, const ColorSpinorField &x, const double &k) const
104  {
105  checkDWF(in, out);
106  checkParitySpinor(in, out);
107  checkSpinorAlias(in, out);
108 
109  ApplyDslash5(out, in, x, mass, m5, b_5, c_5, k, dagger, DSLASH5_MOBIUS_PRE);
110 
111  long long Ls = in.X(4);
112  long long bulk = (Ls-2)*(in.Volume()/Ls);
113  long long wall = 2*in.Volume()/Ls;
114  flops += (72LL+48LL)*(long long)in.Volume() + 96LL*bulk + 120LL*wall;
115  }
116 
117  // The xpay operator bakes in a factor of kappa_b^2
119  const QudaParity parity, const ColorSpinorField &x, const double &k) const
120  {
121  checkDWF(in, out);
122  checkParitySpinor(in, out);
123  checkSpinorAlias(in, out);
124 
125  ApplyDslash5(out, in, x, mass, m5, b_5, c_5, k, dagger, DSLASH5_MOBIUS);
126 
127  long long Ls = in.X(4);
128  long long bulk = (Ls-2)*(in.Volume()/Ls);
129  long long wall = 2*in.Volume()/Ls;
130  flops += (96LL)*(long long)in.Volume() + 96LL*bulk + 120LL*wall;
131  }
132 
134  {
135  checkFullSpinor(out, in);
136 
137  // FIXME broken for variable coefficients
138  double kappa_b = 0.5 / (b_5[0].real() * (4.0 + m5) + 1.0);
139 
140  // cannot use Xpay variants since it will scale incorrectly for this operator
141 
142  ColorSpinorField *tmp = nullptr;
143  if (tmp2 && tmp2->SiteSubset() == QUDA_FULL_SITE_SUBSET) tmp = tmp2;
144  bool reset = newTmp(&tmp, in);
145 
146  ApplyDslash5(out, in, in, mass, m5, b_5, c_5, 0.0, dagger, DSLASH5_MOBIUS_PRE);
148  ApplyDslash5(out, in, in, mass, m5, b_5, c_5, 0.0, dagger, DSLASH5_MOBIUS);
149  blas::axpy(-kappa_b, *tmp, out);
150 
151  long long Ls = in.X(4);
152  long long bulk = (Ls - 2) * (in.Volume() / Ls);
153  long long wall = 2 * in.Volume() / Ls;
154  flops += 72LL * (long long)in.Volume() + 96LL * bulk + 120LL * wall; // pre
155  flops += 1320LL * (long long)in.Volume(); // dslash4
156  flops += 48LL * (long long)in.Volume() + 96LL * bulk + 120LL * wall; // dslash5
157 
158  deleteTmp(&tmp, reset);
159  }
160 
162  {
163  checkFullSpinor(out, in);
164 
165  bool reset = newTmp(&tmp1, in);
166 
167  M(*tmp1, in);
168  Mdag(out, *tmp1);
169 
170  deleteTmp(&tmp1, reset);
171  }
172 
174  const QudaSolutionType solType) const
175  {
176  if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
177  errorQuda("Preconditioned solution requires a preconditioned solve_type");
178  }
179 
180  src = &b;
181  sol = &x;
182  }
183 
185  {
186  // do nothing
187  }
188 
189 
191 
193 
195 
197  {
198  if (&dirac != this) {
199  DiracMobius::operator=(dirac);
200  }
201 
202  return *this;
203  }
204 
206  {
207  checkDWF(in, out);
208  checkParitySpinor(in, out);
209  checkSpinorAlias(in, out);
210 
211  ApplyDslash5(out, in, in, mass, m5, b_5, c_5, 0.0, dagger, zMobius ? M5_INV_ZMOBIUS : M5_INV_MOBIUS);
212 
213  if (0) {
214  // M5 = 1 + 0.5*kappa_b/kappa_c * D5
215  using namespace blas;
216  cudaColorSpinorField A(out);
217  Dslash5(A, out, parity);
218  printfQuda("Dslash5Xpay = %e M5inv = %e in = %e\n", norm2(A), norm2(out), norm2(in));
219  exit(0);
220  }
221 
222  long long Ls = in.X(4);
223  flops += 144LL*(long long)in.Volume()*Ls + 3LL*Ls*(Ls-1LL);
224  }
225 
226  // The xpay operator bakes in a factor of kappa_b^2
228  const ColorSpinorField &x, const double &k) const
229  {
230  checkDWF(in, out);
231  checkParitySpinor(in, out);
232  checkSpinorAlias(in, out);
233 
235 
236  long long Ls = in.X(4);
237  flops += (144LL*Ls + 48LL)*(long long)in.Volume() + 3LL*Ls*(Ls-1LL);
238  }
239 
240  // Apply the even-odd preconditioned mobius DWF operator
241 // Note: Dslash5 applies the full M5 operator, where M5 = 1 + 0.5*kappa_b/kappa_c * D5
243  {
244  bool reset1 = newTmp(&tmp1, in);
245 
246  int odd_bit = (matpcType == QUDA_MATPC_ODD_ODD || matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC) ? 1 : 0;
247  bool symmetric =(matpcType == QUDA_MATPC_EVEN_EVEN || matpcType == QUDA_MATPC_ODD_ODD) ? true : false;
248  QudaParity parity[2] = {static_cast<QudaParity>((1 + odd_bit) % 2), static_cast<QudaParity>((0 + odd_bit) % 2)};
249 
250  //QUDA_MATPC_EVEN_EVEN_ASYMMETRIC : M5 - kappa_b^2 * D4_{eo}D4pre_{oe}D5inv_{ee}D4_{eo}D4pre_{oe}
251  //QUDA_MATPC_ODD_ODD_ASYMMETRIC : M5 - kappa_b^2 * D4_{oe}D4pre_{eo}D5inv_{oo}D4_{oe}D4pre_{eo}
252  if (symmetric && !dagger) {
253  Dslash4pre(*tmp1, in, parity[1]);
254  Dslash4(out, *tmp1, parity[0]);
255  Dslash5inv(*tmp1, out, parity[0]);
256  Dslash4pre(out, *tmp1, parity[0]);
257  Dslash4(*tmp1, out, parity[1]);
258  Dslash5invXpay(out, *tmp1, parity[1], in, -1.0);
259  } else if (symmetric && dagger) {
260  Dslash5inv(*tmp1, in, parity[1]);
261  Dslash4(out, *tmp1, parity[0]);
262  Dslash4pre(*tmp1, out, parity[0]);
263  Dslash5inv(out, *tmp1, parity[0]);
264  Dslash4(*tmp1, out, parity[1]);
265  Dslash4preXpay(out, *tmp1, parity[1], in, -1.0);
266  } else if (!symmetric && !dagger) {
267  Dslash4pre(*tmp1, in, parity[1]);
268  Dslash4(out, *tmp1, parity[0]);
269  Dslash5inv(*tmp1, out, parity[0]);
270  Dslash4pre(out, *tmp1, parity[0]);
271  Dslash4(*tmp1, out, parity[1]);
272  Dslash5Xpay(out, in, parity[1], *tmp1, -1.0);
273  } else if (!symmetric && dagger) {
274  Dslash4(*tmp1, in, parity[0]);
275  Dslash4pre(out, *tmp1, parity[0]);
276  Dslash5inv(*tmp1, out, parity[0]);
277  Dslash4(out, *tmp1, parity[1]);
278  Dslash4pre(*tmp1, out, parity[1]);
279  Dslash5Xpay(out, in, parity[1], *tmp1, -1.0);
280  }
281 
282  deleteTmp(&tmp1, reset1);
283  }
284 
286  {
287  bool reset = newTmp(&tmp2, in);
288  M(*tmp2, in);
289  Mdag(out, *tmp2);
290  deleteTmp(&tmp2, reset);
291  }
292 
295  const QudaSolutionType solType) const
296  {
297  // we desire solution to preconditioned system
298  if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
299  src = &b;
300  sol = &x;
301  } else { // we desire solution to full system
302  // prepare function in MDWF is not tested yet.
303  bool reset = newTmp(&tmp1, b.Even());
304 
306  // src = D5^-1 (b_e + k D4_eo * D4pre * D5^-1 b_o)
307  src = &(x.Odd());
310  Dslash4Xpay(*tmp1, *src, QUDA_EVEN_PARITY, b.Even(), 1.0);
312  sol = &(x.Even());
313  } else if (matpcType == QUDA_MATPC_ODD_ODD) {
314  // src = b_o + k D4_oe * D4pre * D5inv b_e
315  src = &(x.Even());
318  Dslash4Xpay(*tmp1, *src, QUDA_ODD_PARITY, b.Odd(), 1.0);
320  sol = &(x.Odd());
322  // src = b_e + k D4_eo * D4pre * D5inv b_o
323  src = &(x.Odd());
324  Dslash5inv(*src, b.Odd(), QUDA_ODD_PARITY);
326  Dslash4Xpay(*src, *tmp1, QUDA_EVEN_PARITY, b.Even(), 1.0);
327  sol = &(x.Even());
329  // src = b_o + k D4_oe * D4pre * D5inv b_e
330  src = &(x.Even());
331  Dslash5inv(*src, b.Even(), QUDA_EVEN_PARITY);
333  Dslash4Xpay(*src, *tmp1, QUDA_ODD_PARITY, b.Odd(), 1.0);
334  sol = &(x.Odd());
335  } else {
336  errorQuda("MatPCType %d not valid for DiracMobiusPC", matpcType);
337  }
338  // here we use final solution to store parity solution and parity source
339  // b is now up for grabs if we want
340 
341  deleteTmp(&tmp1, reset);
342  }
343  }
344 
346  const QudaSolutionType solType) const
347  {
348  if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
349  return;
350  }
351 
352  bool reset1 = newTmp(&tmp1, x.Even());
353 
354  // create full solution
355  checkFullSpinor(x, b);
358  // psi_o = M5^-1 (b_o + k_b D4_oe D4pre x_e)
359  Dslash4pre(x.Odd(), x.Even(), QUDA_EVEN_PARITY);
360  Dslash4Xpay(*tmp1, x.Odd(), QUDA_ODD_PARITY, b.Odd(), 1.0);
362  } else if (matpcType == QUDA_MATPC_ODD_ODD ||
364  // psi_e = M5^-1 (b_e + k_b D4_eo D4pre x_o)
365  Dslash4pre(x.Even(), x.Odd(), QUDA_ODD_PARITY);
366  Dslash4Xpay(*tmp1, x.Even(), QUDA_EVEN_PARITY, b.Even(), 1.0);
368  } else {
369  errorQuda("MatPCType %d not valid for DiracMobiusPC", matpcType);
370  }
371 
372  deleteTmp(&tmp1, reset1);
373  }
374 
375 } // namespace quda
unsigned long long flops
Definition: dirac_quda.h:121
DiracMobiusPC(const DiracParam &param)
void Dslash5(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity) const
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:21
#define errorQuda(...)
Definition: util_quda.h:121
cudaGaugeField * gauge
Definition: dirac_quda.h:115
virtual void checkFullSpinor(const ColorSpinorField &, const ColorSpinorField &) const
Definition: dirac.cpp:146
Complex b_5[QUDA_MAX_DWF_LS]
Definition: dirac_quda.h:453
cudaColorSpinorField * tmp
Definition: covdev_test.cpp:44
const ColorSpinorField & Even() const
void deleteTmp(ColorSpinorField **, const bool &reset) const
Definition: dirac.cpp:81
const ColorSpinorField & Odd() const
TimeProfile profile
Definition: dirac_quda.h:132
void reconstruct(ColorSpinorField &x, const ColorSpinorField &b, const QudaSolutionType) const
Complex c_5[QUDA_MAX_DWF_LS]
Definition: dirac_quda.h:454
Complex c_5[QUDA_MAX_DWF_LS]
Definition: dirac_quda.h:28
void ApplyDslash5(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &x, double m_f, double m_5, const Complex *b_5, const Complex *c_5, double a, bool dagger, Dslash5Type type)
Apply either the domain-wall / mobius Dslash5 operator or the M5 inverse operator. In the current implementation, it is expected that the color-spinor fields are 4-d preconditioned.
void M(ColorSpinorField &out, const ColorSpinorField &in) const
void Dslash5invXpay(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity, const ColorSpinorField &x, const double &k) const
double norm2(const CloverField &a, bool inverse=false)
QudaGaugeParam param
Definition: pack_test.cpp:17
bool newTmp(ColorSpinorField **, const ColorSpinorField &) const
Definition: dirac.cpp:70
virtual void MdagM(ColorSpinorField &out, const ColorSpinorField &in) const
int commDim[QUDA_MAX_DIM]
Definition: dirac_quda.h:130
void Dslash4(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity) const
Complex b_5[QUDA_MAX_DWF_LS]
Definition: dirac_quda.h:27
DiracMobius(const DiracParam &param)
Definition: dirac_mobius.cpp:7
void checkDWF(const ColorSpinorField &out, const ColorSpinorField &in) const
DiracDomainWall & operator=(const DiracDomainWall &dirac)
void checkSpinorAlias(const ColorSpinorField &, const ColorSpinorField &) const
Definition: dirac.cpp:154
void axpy(double a, ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.h:35
cpuColorSpinorField * in
QudaSiteSubset SiteSubset() const
virtual void M(ColorSpinorField &out, const ColorSpinorField &in) const
double mass
Definition: dirac_quda.h:117
virtual void reconstruct(ColorSpinorField &x, const ColorSpinorField &b, const QudaSolutionType) const
enum QudaSolutionType_s QudaSolutionType
void Dslash4pre(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity) const
QudaDagType dagger
Definition: dirac_quda.h:120
std::complex< double > Complex
Definition: quda_internal.h:46
void Dslash4preXpay(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity, const ColorSpinorField &x, const double &k) const
void ApplyDomainWall4D(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a, double m_5, const Complex *b_5, const Complex *c_5, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile)
Driver for applying the batched Wilson 4-d stencil to a 5-d vector with 4-d preconditioned data order...
enum QudaParity_s QudaParity
void Dslash4Xpay(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity, const ColorSpinorField &x, const double &k) const
QudaMatPCType matpcType
Definition: dirac_quda.h:119
void Mdag(ColorSpinorField &out, const ColorSpinorField &in) const
Definition: dirac.cpp:90
void prepare(ColorSpinorField *&src, ColorSpinorField *&sol, ColorSpinorField &x, ColorSpinorField &b, const QudaSolutionType) const
GaugeCovDev * dirac
Definition: covdev_test.cpp:73
cpuColorSpinorField * out
#define printfQuda(...)
Definition: util_quda.h:115
void MdagM(ColorSpinorField &out, const ColorSpinorField &in) const
const int * X() const
void Dslash5Xpay(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity, const ColorSpinorField &x, const double &k) const
virtual ~DiracMobius()
virtual void checkParitySpinor(const ColorSpinorField &, const ColorSpinorField &) const
Definition: dirac.cpp:106
void Dslash5inv(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity) const
DiracMobiusPC & operator=(const DiracMobiusPC &dirac)
QudaParity parity
Definition: covdev_test.cpp:54
ColorSpinorField * tmp2
Definition: dirac_quda.h:123
DiracMobius & operator=(const DiracMobius &dirac)
ColorSpinorField * tmp1
Definition: dirac_quda.h:122
virtual void prepare(ColorSpinorField *&src, ColorSpinorField *&sol, ColorSpinorField &x, ColorSpinorField &b, const QudaSolutionType) const