QUDA  0.9.0
dirac_mobius.cpp
Go to the documentation of this file.
1 #include <iostream>
2 #include <dirac_quda.h>
3 #include <blas_quda.h>
4 
5 namespace quda {
6 
// Body of a constructor whose signature line is not present in this listing
// (presumably DiracMobius(const DiracParam &param) -- TODO confirm).
// Copies param.Ls doubles from each of param.b_5 and param.c_5 into the
// members b_5 and c_5; memcpy's third argument is a byte count, hence sizeof(double)*param.Ls.
8  memcpy(b_5, param.b_5, sizeof(double)*param.Ls);
9  memcpy(c_5, param.c_5, sizeof(double)*param.Ls);
10  }
11 
// Body of a copy constructor whose signature line is not present in this listing
// (presumably DiracMobius(const DiracMobius &dirac) -- TODO confirm).
// Copies the Ls b_5 and c_5 coefficients from `dirac` into this object.
// memcpy takes a byte count, so the element count Ls is scaled by sizeof(double);
// passing Ls alone copied only the first Ls bytes of each double array.
13  memcpy(b_5, dirac.b_5, sizeof(double)*Ls);
14  memcpy(c_5, dirac.c_5, sizeof(double)*Ls);
15  }
16 
18 
// Body of an assignment operator whose signature line is not present in this listing
// (presumably DiracMobius::operator=(const DiracMobius &dirac) -- TODO confirm).
// On non-self assignment, copies the Ls b_5 and c_5 coefficients from `dirac`;
// always returns *this. Listing line 22 is also absent from this view.
20  {
21  if (&dirac != this) {
// memcpy takes a byte count: scale the element count Ls by sizeof(double) so that all
// Ls coefficients are copied, matching the sizeof(double)*param.Ls form used where the
// arrays are filled from the DiracParam object.
23  memcpy(b_5, dirac.b_5, sizeof(double)*Ls);
24  memcpy(c_5, dirac.c_5, sizeof(double)*Ls);
25  }
26 
27  return *this;
28  }
29 
// Tail of a const member function whose first signature line is missing from this listing
// (presumably DiracMobius::Dslash4(out, in, parity) -- TODO confirm).
// Raises an error unless both fields are 5-dimensional, then calls MDWFDslashCuda with
// a null x field, k = 0 and DS_type = 0, and adds 1320 * in.Volume() to the flops counter.
// Listing lines 35-36 are absent from this view.
30 // Modification for the 4D preconditioned Mobius domain wall operator
32  const QudaParity parity) const
33  {
34  if ( in.Ndim() != 5 || out.Ndim() != 5) errorQuda("Wrong number of dimensions\n");
37 
38  MDWFDslashCuda(&static_cast<cudaColorSpinorField&>(out), *gauge,
39  &static_cast<const cudaColorSpinorField&>(in),
40  parity, dagger, 0, mass, 0, b_5, c_5, m5, commDim, 0, profile);
41 
42  flops += 1320LL*(long long)in.Volume();
43  }
44 
// Body of a const member function whose signature line is missing from this listing
// (presumably DiracMobius::Dslash4pre(out, in, parity) -- TODO confirm).
// Raises an error unless both fields are 5-dimensional, then calls MDWFDslashCuda with
// a null x field, k = 0 and DS_type = 1. Listing lines 48-49 are absent from this view.
46  {
47  if ( in.Ndim() != 5 || out.Ndim() != 5) errorQuda("Wrong number of dimensions\n");
50 
51  MDWFDslashCuda(&static_cast<cudaColorSpinorField&>(out), *gauge,
52  &static_cast<const cudaColorSpinorField&>(in),
53  parity, dagger, 0, mass, 0, b_5, c_5, m5, commDim, 1, profile);
54 
// Flop estimate: Ls is the extent of dimension 4 of `in`; `bulk` counts sites in the
// Ls-2 interior slices and `wall` the sites in the remaining two slices.
55  long long Ls = in.X(4);
56  long long bulk = (Ls-2)*(in.Volume()/Ls);
57  long long wall = 2*in.Volume()/Ls;
58  flops += 72LL*(long long)in.Volume() + 96LL*bulk + 120LL*wall;
59  }
60 
// Body of a const member function whose signature line is missing from this listing
// (presumably DiracMobius::Dslash5(out, in, parity) -- TODO confirm).
// Raises an error unless both fields are 5-dimensional, then calls MDWFDslashCuda with
// a null x field, k = 0 and DS_type = 2. Listing lines 64-65 are absent from this view.
62  {
63  if ( in.Ndim() != 5 || out.Ndim() != 5) errorQuda("Wrong number of dimensions\n");
66 
67  MDWFDslashCuda(&static_cast<cudaColorSpinorField&>(out), *gauge,
68  &static_cast<const cudaColorSpinorField&>(in),
69  parity, dagger, 0, mass, 0, b_5, c_5, m5, commDim, 2, profile);
70 
// Flop estimate: Ls is the extent of dimension 4 of `in`; `bulk` counts sites in the
// Ls-2 interior slices and `wall` the sites in the remaining two slices.
71  long long Ls = in.X(4);
72  long long bulk = (Ls-2)*(in.Volume()/Ls);
73  long long wall = 2*in.Volume()/Ls;
74  flops += 48LL*(long long)in.Volume() + 96LL*bulk + 120LL*wall;
75  }
76 
// Tail of a const member function whose first signature line is missing from this listing
// (presumably DiracMobius::Dslash4Xpay(out, in, parity, x, k) -- TODO confirm).
// Same call as the DS_type = 0 operator above, except that the field x and the scalar k
// are forwarded to MDWFDslashCuda; the flop estimate gains 48 * in.Volume().
// Listing lines 83-84 are absent from this view.
77  // Modification for the 4D preconditioned Mobius domain wall operator
79  const QudaParity parity, const ColorSpinorField &x, const double &k) const
80  {
81  if ( in.Ndim() != 5 || out.Ndim() != 5) errorQuda("Wrong number of dimensions\n");
82 
85 
86  MDWFDslashCuda(&static_cast<cudaColorSpinorField&>(out), *gauge,
87  &static_cast<const cudaColorSpinorField&>(in),
88  parity, dagger, &static_cast<const cudaColorSpinorField&>(x),
89  mass, k, b_5, c_5, m5, commDim, 0, profile);
90 
91  flops += (1320LL+48LL)*(long long)in.Volume();
92  }
93 
// Tail of a const member function whose first signature line is missing from this listing
// (presumably DiracMobius::Dslash4preXpay(out, in, parity, x, k) -- TODO confirm).
// Raises an error unless both fields are 5-dimensional, then calls MDWFDslashCuda with
// DS_type = 1, forwarding the field x and the scalar k.
// Listing lines 99-100 are absent from this view.
95  const QudaParity parity, const ColorSpinorField &x, const double &k) const
96  {
97  if ( in.Ndim() != 5 || out.Ndim() != 5) errorQuda("Wrong number of dimensions\n");
98 
101 
102  MDWFDslashCuda(&static_cast<cudaColorSpinorField&>(out), *gauge,
103  &static_cast<const cudaColorSpinorField&>(in),
104  parity, dagger, &static_cast<const cudaColorSpinorField&>(x),
105  mass, k, b_5, c_5, m5, commDim, 1, profile);
106 
// Flop estimate: as for the non-xpay DS_type = 1 call plus 48 * in.Volume().
107  long long Ls = in.X(4);
108  long long bulk = (Ls-2)*(in.Volume()/Ls);
109  long long wall = 2*in.Volume()/Ls;
110  flops += (72LL+48LL)*(long long)in.Volume() + 96LL*bulk + 120LL*wall;
111  }
112 
// Tail of a const member function whose first signature line is missing from this listing
// (presumably DiracMobius::Dslash5Xpay(out, in, parity, x, k) -- TODO confirm).
// Raises an error unless both fields are 5-dimensional, then calls MDWFDslashCuda with
// DS_type = 2, forwarding the field x and the scalar k.
// Listing lines 118-119 are absent from this view.
113  // The xpay operator bakes in a factor of kappa_b^2
115  const QudaParity parity, const ColorSpinorField &x, const double &k) const
116  {
117  if ( in.Ndim() != 5 || out.Ndim() != 5) errorQuda("Wrong number of dimensions\n");
120 
121  MDWFDslashCuda(&static_cast<cudaColorSpinorField&>(out), *gauge,
122  &static_cast<const cudaColorSpinorField&>(in),
123  parity, dagger, &static_cast<const cudaColorSpinorField&>(x),
124  mass, k, b_5, c_5, m5, commDim, 2, profile);
125 
// Flop estimate: 96 * in.Volume() plus the bulk/wall terms built from Ls = in.X(4).
126  long long Ls = in.X(4);
127  long long bulk = (Ls-2)*(in.Volume()/Ls);
128  long long wall = 2*in.Volume()/Ls;
129  flops += (96LL)*(long long)in.Volume() + 96LL*bulk + 120LL*wall;
130  }
131 
// Body of a const member function whose signature line is missing from this listing
// (presumably DiracMobius::M(out, in) -- TODO confirm).
// Builds `out` one parity at a time from the Dslash4pre, Dslash4 and Dslash5 pieces,
// using tmp1 as scratch. Listing line 137 is absent from this view.
133  {
134  if ( in.Ndim() != 5 || out.Ndim() != 5) errorQuda("Wrong number of dimensions\n");
135 
// newTmp's return value is handed back to deleteTmp at the end of the function.
136  bool reset = newTmp(&tmp1, in);
138 
139  // FIXME broken for variable coefficients
// Only b_5[0] enters kappa_b, which is what the FIXME above refers to.
140  double kappa_b = 0.5 / (b_5[0]*(4.0+m5)+1.0);
141 
142  // cannot use Xpay variants since it will scale incorrectly for this operator
143 
// Odd half: out.Odd() is first used as scratch for Dslash4pre(in.Even()), Dslash4 of that
// goes to tmp1->Even(), out.Odd() is then overwritten with Dslash5(in.Odd()), and the axpy
// combines -kappa_b * tmp1->Even() with out.Odd() (presumably y += a*x -- confirm in blas_quda).
144  Dslash4pre(out.Odd(), in.Even(), QUDA_EVEN_PARITY);
145  Dslash4(tmp1->Even(), out.Odd(), QUDA_ODD_PARITY);
146  Dslash5(out.Odd(), in.Odd(), QUDA_ODD_PARITY);
147  blas::axpy(-kappa_b, tmp1->Even(), out.Odd());
148 
// Even half: the same sequence with the parities exchanged.
149  Dslash4pre(out.Even(), in.Odd(), QUDA_ODD_PARITY);
150  Dslash4(tmp1->Odd(), out.Even(), QUDA_EVEN_PARITY);
151  Dslash5(out.Even(), in.Even(), QUDA_EVEN_PARITY);
152  blas::axpy(-kappa_b, tmp1->Odd(), out.Even());
153 
154  deleteTmp(&tmp1, reset);
155  }
156 
// Body of a const member function whose signature line is missing from this listing
// (presumably DiracMobius::MdagM(out, in) -- TODO confirm).
// Applies M to `in`, storing the result in tmp2, then applies Mdag to tmp2 to produce `out`.
// Listing line 159 is absent from this view.
158  {
160 
161  bool reset = newTmp(&tmp2, in);
162 
163  M(*tmp2, in);
164  Mdag(out, *tmp2);
165 
166  deleteTmp(&tmp2, reset);
167  }
168 
// Tail of a const member function whose first signature line is missing from this listing
// (presumably DiracMobius::prepare(src, sol, x, b, solType) -- TODO confirm).
// Rejects the two preconditioned solution types with errorQuda; otherwise points
// src at b and sol at x without modifying either field.
170  const QudaSolutionType solType) const
171  {
172  if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
173  errorQuda("Preconditioned solution requires a preconditioned solve_type");
174  }
175 
176  src = &b;
177  sol = &x;
178  }
179 
// Body of a member function whose signature line is missing from this listing
// (presumably DiracMobius::reconstruct(x, b, solType) -- TODO confirm).
// Intentionally empty.
181  {
182  // do nothing
183  }
184 
185 
187 
189 
191 
// Body of an assignment operator whose signature line is missing from this listing
// (presumably DiracMobiusPC::operator=(const DiracMobiusPC &dirac) -- TODO confirm).
// Guards against self-assignment and returns *this; the statement inside the guard
// (listing line 195) is absent from this view.
193  {
194  if (&dirac != this) {
196  }
197 
198  return *this;
199  }
200 
// Body of a const member function whose signature line is missing from this listing
// (presumably DiracMobiusPC::Dslash5inv(out, in, parity) -- TODO confirm).
// Raises an error unless both fields are 5-dimensional, then calls MDWFDslashCuda with
// a null x field, k = 0 and DS_type = 3. Listing lines 205-206 are absent from this view.
202  {
203  if ( in.Ndim() != 5 || out.Ndim() != 5) errorQuda("Wrong number of dimensions\n");
204 
207 
208  MDWFDslashCuda(&static_cast<cudaColorSpinorField&>(out), *gauge,
209  &static_cast<const cudaColorSpinorField&>(in),
210  parity, dagger, 0, mass, 0, b_5, c_5, m5, commDim, 3, profile);
211 
// Flop estimate built from Ls = in.X(4) and in.Volume().
212  long long Ls = in.X(4);
213  flops += 144LL*(long long)in.Volume()*Ls + 3LL*Ls*(Ls-1LL);
214  }
215 
// Tail of a const member function whose first signature line is missing from this listing
// (presumably DiracMobiusPC::Dslash5invXpay(out, in, parity, x, k) -- TODO confirm).
// Raises an error unless both fields are 5-dimensional, then calls MDWFDslashCuda with
// DS_type = 3, forwarding the field x and the scalar k.
// Listing lines 222-223 are absent from this view.
216  // The xpay operator bakes in a factor of kappa_b^2
218  const QudaParity parity, const ColorSpinorField &x, const double &k) const
219  {
220  if ( in.Ndim() != 5 || out.Ndim() != 5) errorQuda("Wrong number of dimensions\n");
221 
224 
225  MDWFDslashCuda(&static_cast<cudaColorSpinorField&>(out), *gauge,
226  &static_cast<const cudaColorSpinorField&>(in),
227  parity, dagger, &static_cast<const cudaColorSpinorField&>(x),
228  mass, k, b_5, c_5, m5, commDim, 3, profile);
229 
// Flop estimate: as for the non-xpay DS_type = 3 call plus 48 * in.Volume().
230  long long Ls = in.X(4);
231  flops += (144LL*Ls + 48LL)*(long long)in.Volume() + 3LL*Ls*(Ls-1LL);
232  }
233 
// Body of a const member function whose signature line is missing from this listing
// (presumably DiracMobiusPC::M(out, in) -- TODO confirm).
// Chooses one of four operator sequences from (symmetric, dagger); each sequence ping-pongs
// between `out` and tmp1 and ends with an Xpay call that writes the final result to `out`.
234  // Apply the even-odd preconditioned mobius DWF operator
235  //Actually, Dslash5 will return M5 operation and M5 = 1 + 0.5*kappa_b/kappa_c * D5
237  {
238  if ( in.Ndim() != 5 || out.Ndim() != 5) errorQuda("Wrong number of dimensions\n");
239 
240  bool reset1 = newTmp(&tmp1, in);
241 
// odd_bit is 1 for the two ODD_ODD matpc types; parity[1] has the numeric value odd_bit
// and parity[0] the other parity value.
242  int odd_bit = (matpcType == QUDA_MATPC_ODD_ODD || matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC) ? 1 : 0;
243  bool symmetric =(matpcType == QUDA_MATPC_EVEN_EVEN || matpcType == QUDA_MATPC_ODD_ODD) ? true : false;
244  QudaParity parity[2] = {static_cast<QudaParity>((1 + odd_bit) % 2), static_cast<QudaParity>((0 + odd_bit) % 2)};
245 
246  //QUDA_MATPC_EVEN_EVEN_ASYMMETRIC : M5 - kappa_b^2 * D4_{eo}D4pre_{oe}D5inv_{ee}D4_{eo}D4pre_{oe}
247  //QUDA_MATPC_ODD_ODD_ASYMMETRIC : M5 - kappa_b^2 * D4_{oe}D4pre_{eo}D5inv_{oo}D4_{oe}D4pre_{eo}
248  if (symmetric && !dagger) {
249  Dslash4pre(*tmp1, in, parity[1]);
250  Dslash4(out, *tmp1, parity[0]);
251  Dslash5inv(*tmp1, out, parity[0]);
252  Dslash4pre(out, *tmp1, parity[0]);
253  Dslash4(*tmp1, out, parity[1]);
254  Dslash5invXpay(out, *tmp1, parity[1], in, -1.0);
255  } else if (symmetric && dagger) {
256  Dslash5inv(*tmp1, in, parity[1]);
257  Dslash4(out, *tmp1, parity[0]);
258  Dslash4pre(*tmp1, out, parity[0]);
259  Dslash5inv(out, *tmp1, parity[0]);
260  Dslash4(*tmp1, out, parity[1]);
261  Dslash4preXpay(out, *tmp1, parity[1], in, -1.0);
262  } else if (!symmetric && !dagger) {
263  Dslash4pre(*tmp1, in, parity[1]);
264  Dslash4(out, *tmp1, parity[0]);
265  Dslash5inv(*tmp1, out, parity[0]);
266  Dslash4pre(out, *tmp1, parity[0]);
267  Dslash4(*tmp1, out, parity[1]);
// Note the argument order: here Dslash5 acts on `in` and tmp1 is passed as the x field.
268  Dslash5Xpay(out, in, parity[1], *tmp1, -1.0);
269  } else if (!symmetric && dagger) {
270  Dslash4(*tmp1, in, parity[0]);
271  Dslash4pre(out, *tmp1, parity[0]);
272  Dslash5inv(*tmp1, out, parity[0]);
273  Dslash4(out, *tmp1, parity[1]);
274  Dslash4pre(*tmp1, out, parity[1]);
// As in the previous branch, Dslash5 acts on `in` and tmp1 is passed as the x field.
275  Dslash5Xpay(out, in, parity[1], *tmp1, -1.0);
276  }
277 
278  deleteTmp(&tmp1, reset1);
279  }
280 
// Body of a const member function whose signature line is missing from this listing
// (presumably DiracMobiusPC::MdagM(out, in) -- TODO confirm).
// Applies M to `in`, storing the result in tmp2, then applies Mdag to tmp2 to produce `out`.
282  {
283  bool reset = newTmp(&tmp2, in);
284  M(*tmp2, in);
285  Mdag(out, *tmp2);
286  deleteTmp(&tmp2, reset);
287  }
288 
// Tail of a const member function whose first signature line is missing from this listing
// (presumably DiracMobiusPC::prepare(src, sol, x, b, solType) -- TODO confirm).
// For the two preconditioned solution types, src and sol simply alias b and x.
// Otherwise a source is assembled in one parity half of x from both halves of b, and
// sol is pointed at the other half of x.
// NOTE(review): in the full-system branch most of the if/else-if headers that select on
// matpcType, and some statements between the visible ones (listing lines 301, 305, 307,
// 313, 315, 317, 321, 324, 328), are absent from this view; only the QUDA_MATPC_ODD_ODD
// test and the final else are visible.
291  const QudaSolutionType solType) const
292  {
293  // we desire solution to preconditioned system
294  if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
295  src = &b;
296  sol = &x;
297  } else { // we desire solution to full system
298  // prepare function in MDWF is not tested yet.
299  bool reset = newTmp(&tmp1, b.Even());
300 
302  // src = D5^-1 (b_e + k D4_eo * D4pre * D5^-1 b_o)
303  src = &(x.Odd());
304  Dslash5inv(*tmp1, b.Odd(), QUDA_ODD_PARITY);
306  Dslash4Xpay(*tmp1, *src, QUDA_EVEN_PARITY, b.Even(), 1.0);
308  sol = &(x.Even());
309  } else if (matpcType == QUDA_MATPC_ODD_ODD) {
310  // src = b_o + k D4_oe * D4pre * D5inv b_e
311  src = &(x.Even());
312  Dslash5inv(*tmp1, b.Even(), QUDA_EVEN_PARITY);
314  Dslash4Xpay(*tmp1, *src, QUDA_ODD_PARITY, b.Odd(), 1.0);
316  sol = &(x.Odd());
318  // src = b_e + k D4_eo * D4pre * D5inv b_o
319  src = &(x.Odd());
320  Dslash5inv(*src, b.Odd(), QUDA_ODD_PARITY);
322  Dslash4Xpay(*src, *tmp1, QUDA_EVEN_PARITY, b.Even(), 1.0);
323  sol = &(x.Even());
325  // src = b_o + k D4_oe * D4pre * D5inv b_e
326  src = &(x.Even());
327  Dslash5inv(*src, b.Even(), QUDA_EVEN_PARITY);
329  Dslash4Xpay(*src, *tmp1, QUDA_ODD_PARITY, b.Odd(), 1.0);
330  sol = &(x.Odd());
331  } else {
332  errorQuda("MatPCType %d not valid for DiracMobiusPC", matpcType);
333  }
334  // here we use final solution to store parity solution and parity source
335  // b is now up for grabs if we want
336 
337  deleteTmp(&tmp1, reset);
338  }
339  }
340 
// Tail of a const member function whose first signature line is missing from this listing
// (presumably DiracMobiusPC::reconstruct(x, b, solType) -- TODO confirm).
// Returns immediately for the two preconditioned solution types. Otherwise fills in the
// parity half of x that the solve did not produce, using Dslash4pre, Dslash4Xpay (with the
// matching half of b and k = 1.0) and Dslash5inv, with tmp1 as scratch.
// NOTE(review): the first `if` header and the continuation of the `else if` condition
// (listing lines 352-353 and 359) are absent from this view.
342  const QudaSolutionType solType) const
343  {
344  if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
345  return;
346  }
347 
348  bool reset1 = newTmp(&tmp1, x.Even());
349 
350  // create full solution
351  checkFullSpinor(x, b);
354  // psi_o = M5^-1 (b_o + k_b D4_oe D4pre x_e)
355  Dslash4pre(x.Odd(), x.Even(), QUDA_EVEN_PARITY);
356  Dslash4Xpay(*tmp1, x.Odd(), QUDA_ODD_PARITY, b.Odd(), 1.0);
357  Dslash5inv(x.Odd(), *tmp1, QUDA_ODD_PARITY);
358  } else if (matpcType == QUDA_MATPC_ODD_ODD ||
360  // psi_e = M5^-1 (b_e + k_b D4_eo D4pre x_o)
361  Dslash4pre(x.Even(), x.Odd(), QUDA_ODD_PARITY);
362  Dslash4Xpay(*tmp1, x.Even(), QUDA_EVEN_PARITY, b.Even(), 1.0);
363  Dslash5inv(x.Even(), *tmp1, QUDA_EVEN_PARITY);
364  } else {
365  errorQuda("MatPCType %d not valid for DiracMobiusPC", matpcType);
366  }
367 
368  deleteTmp(&tmp1, reset1);
369  }
370 
371 } // namespace quda
unsigned long long flops
Definition: dirac_quda.h:100
DiracMobiusPC(const DiracParam &param)
void Dslash5(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity) const
const void * src
#define errorQuda(...)
Definition: util_quda.h:90
cudaGaugeField * gauge
Definition: dirac_quda.h:95
virtual void checkFullSpinor(const ColorSpinorField &, const ColorSpinorField &) const
Definition: dirac.cpp:129
double b_5[QUDA_MAX_DWF_LS]
Definition: dirac_quda.h:399
void MDWFDslashCuda(cudaColorSpinorField *out, const cudaGaugeField &gauge, const cudaColorSpinorField *in, const int parity, const int dagger, const cudaColorSpinorField *x, const double &m_f, const double &k, const double *b5, const double *c_5, const double &m5, const int *commDim, const int DS_type, TimeProfile &profile)
const ColorSpinorField & Even() const
void deleteTmp(ColorSpinorField **, const bool &reset) const
Definition: dirac.cpp:64
const ColorSpinorField & Odd() const
TimeProfile profile
Definition: dirac_quda.h:112
void reconstruct(ColorSpinorField &x, const ColorSpinorField &b, const QudaSolutionType) const
void M(ColorSpinorField &out, const ColorSpinorField &in) const
void Dslash5invXpay(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity, const ColorSpinorField &x, const double &k) const
QudaGaugeParam param
Definition: pack_test.cpp:17
bool newTmp(ColorSpinorField **, const ColorSpinorField &) const
Definition: dirac.cpp:53
#define b
virtual void MdagM(ColorSpinorField &out, const ColorSpinorField &in) const
int commDim[QUDA_MAX_DIM]
Definition: dirac_quda.h:110
void Dslash4(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity) const
DiracMobius(const DiracParam &param)
Definition: dirac_mobius.cpp:7
DiracDomainWall & operator=(const DiracDomainWall &dirac)
void checkSpinorAlias(const ColorSpinorField &, const ColorSpinorField &) const
Definition: dirac.cpp:137
cpuColorSpinorField * in
virtual void M(ColorSpinorField &out, const ColorSpinorField &in) const
double mass
Definition: dirac_quda.h:97
virtual void reconstruct(ColorSpinorField &x, const ColorSpinorField &b, const QudaSolutionType) const
enum QudaSolutionType_s QudaSolutionType
void Dslash4pre(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity) const
QudaDagType dagger
Definition: dirac_quda.h:99
void Dslash4preXpay(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity, const ColorSpinorField &x, const double &k) const
enum QudaParity_s QudaParity
void Dslash4Xpay(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity, const ColorSpinorField &x, const double &k) const
void * memcpy(void *__dst, const void *__src, size_t __n)
QudaMatPCType matpcType
Definition: dirac_quda.h:98
void axpy(const double &a, ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.cu:150
void Mdag(ColorSpinorField &out, const ColorSpinorField &in) const
Definition: dirac.cpp:73
void prepare(ColorSpinorField *&src, ColorSpinorField *&sol, ColorSpinorField &x, ColorSpinorField &b, const QudaSolutionType) const
double c_5[QUDA_MAX_DWF_LS]
Definition: dirac_quda.h:400
GaugeCovDev * dirac
Definition: covdev_test.cpp:75
cpuColorSpinorField * out
void MdagM(ColorSpinorField &out, const ColorSpinorField &in) const
void Dslash5Xpay(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity, const ColorSpinorField &x, const double &k) const
virtual ~DiracMobius()
virtual void checkParitySpinor(const ColorSpinorField &, const ColorSpinorField &) const
Definition: dirac.cpp:89
void Dslash5inv(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity) const
DiracMobiusPC & operator=(const DiracMobiusPC &dirac)
QudaParity parity
Definition: covdev_test.cpp:53
ColorSpinorField * tmp2
Definition: dirac_quda.h:102
DiracMobius & operator=(const DiracMobius &dirac)
ColorSpinorField * tmp1
Definition: dirac_quda.h:101
virtual void prepare(ColorSpinorField *&src, ColorSpinorField *&sol, ColorSpinorField &x, ColorSpinorField &b, const QudaSolutionType) const