QUDA  v1.1.0
A library for QCD on GPUs
dirac_mobius.cpp
Go to the documentation of this file.
1 #include <iostream>
2 #include <dirac_quda.h>
3 #include <blas_quda.h>
4 
5 namespace quda {
6 
8  {
9  memcpy(b_5, param.b_5, sizeof(Complex) * param.Ls);
10  memcpy(c_5, param.c_5, sizeof(Complex) * param.Ls);
11 
12  double b = b_5[0].real();
13  double c = c_5[0].real();
14  mobius_kappa_b = 0.5 / (b * (m5 + 4.) + 1.);
15  mobius_kappa_c = 0.5 / (c * (m5 + 4.) - 1.);
16 
18 
19  // check if doing zMobius
20  for (int i = 0; i < Ls; i++) {
21  if (b_5[i].imag() != 0.0 || c_5[i].imag() != 0.0 || (i < Ls - 1 && (b_5[i] != b_5[i + 1] || c_5[i] != c_5[i + 1]))) {
22  zMobius = true;
23  }
24  }
25 
26  if (getVerbosity() > QUDA_VERBOSE) {
27  if (zMobius) {
28  printfQuda("%s: Detected variable or complex cofficients: using zMobius\n", __func__);
29  } else {
30  printfQuda("%s: Detected fixed real cofficients: using regular Mobius\n", __func__);
31  }
32  }
33 
34  if (zMobius) { errorQuda("zMobius has NOT been fully tested in QUDA.\n"); }
35  }
36 
37  // Modification for the 4D preconditioned Mobius domain wall operator
39  {
40  checkDWF(in, out);
41  checkParitySpinor(in, out);
42  checkSpinorAlias(in, out);
43 
44  ApplyDomainWall4D(out, in, *gauge, 0.0, 0.0, nullptr, nullptr, in, parity, dagger, commDim, profile);
45 
46  flops += 1320LL * (long long)in.Volume();
47  }
48 
50  {
51  checkDWF(in, out);
52  checkParitySpinor(in, out);
53  checkSpinorAlias(in, out);
54 
55  ApplyDslash5(out, in, in, mass, m5, b_5, c_5, 0.0, dagger, DSLASH5_MOBIUS_PRE);
56 
57  long long Ls = in.X(4);
58  long long bulk = (Ls - 2) * (in.Volume() / Ls);
59  long long wall = 2 * in.Volume() / Ls;
60  flops += 72LL * (long long)in.Volume() + 96LL * bulk + 120LL * wall;
61  }
62 
63  // Unlike DWF-4d, the Mobius variant here applies the full M5 operator and not just D5
65  {
66  checkDWF(in, out);
67  checkParitySpinor(in, out);
68  checkSpinorAlias(in, out);
69 
70  ApplyDslash5(out, in, in, mass, m5, b_5, c_5, 0.0, dagger, DSLASH5_MOBIUS);
71 
72  long long Ls = in.X(4);
73  long long bulk = (Ls - 2) * (in.Volume() / Ls);
74  long long wall = 2 * in.Volume() / Ls;
75  flops += 48LL * (long long)in.Volume() + 96LL * bulk + 120LL * wall;
76  }
77 
78  // Modification for the 4D preconditioned Mobius domain wall operator
80  const ColorSpinorField &x, const double &k) const
81  {
82  checkDWF(in, out);
83  checkParitySpinor(in, out);
84  checkSpinorAlias(in, out);
85 
86  ApplyDomainWall4D(out, in, *gauge, k, m5, b_5, c_5, x, parity, dagger, commDim, profile);
87 
88  flops += 1320LL * (long long)in.Volume();
89  }
90 
92  const ColorSpinorField &x, const double &k) const
93  {
94  checkDWF(in, out);
95  checkParitySpinor(in, out);
96  checkSpinorAlias(in, out);
97 
98  ApplyDslash5(out, in, x, mass, m5, b_5, c_5, k, dagger, DSLASH5_MOBIUS_PRE);
99 
100  long long Ls = in.X(4);
101  long long bulk = (Ls - 2) * (in.Volume() / Ls);
102  long long wall = 2 * in.Volume() / Ls;
103 
104  flops += 72LL * (long long)in.Volume() + 96LL * bulk + 120LL * wall;
105  }
106 
107  // The xpay operator bakes in a factor of kappa_b^2
109  const ColorSpinorField &x, const double &k) const
110  {
111  checkDWF(in, out);
112  checkParitySpinor(in, out);
113  checkSpinorAlias(in, out);
114 
115  ApplyDslash5(out, in, x, mass, m5, b_5, c_5, k, dagger, DSLASH5_MOBIUS);
116 
117  long long Ls = in.X(4);
118  long long bulk = (Ls - 2) * (in.Volume() / Ls);
119  long long wall = 2 * in.Volume() / Ls;
120  flops += 96LL * (long long)in.Volume() + 96LL * bulk + 120LL * wall;
121  }
122 
124  {
125  checkFullSpinor(out, in);
126 
127  // zMobius breaks the following code. Refer to the zMobius check in DiracMobius::DiracMobius(param)
128  double mobius_kappa_b = 0.5 / (b_5[0].real() * (4.0 + m5) + 1.0);
129 
130  // cannot use Xpay variants since it will scale incorrectly for this operator
131 
132  ColorSpinorField *tmp = nullptr;
134  bool reset = newTmp(&tmp, in);
135 
136  if (dagger == QUDA_DAG_NO) {
137  ApplyDslash5(out, in, in, mass, m5, b_5, c_5, 0.0, dagger, DSLASH5_MOBIUS_PRE);
139  ApplyDslash5(out, in, in, mass, m5, b_5, c_5, 0.0, dagger, DSLASH5_MOBIUS);
140  } else {
141  // the third term is added, not multiplied, so we only need to swap the first two in the dagger
143  ApplyDslash5(*tmp, out, in, mass, m5, b_5, c_5, 0.0, dagger, DSLASH5_MOBIUS_PRE);
144  ApplyDslash5(out, in, in, mass, m5, b_5, c_5, 0.0, dagger, DSLASH5_MOBIUS);
145  }
146  blas::axpy(-mobius_kappa_b, *tmp, out);
147 
148  long long Ls = in.X(4);
149  long long bulk = (Ls - 2) * (in.Volume() / Ls);
150  long long wall = 2 * in.Volume() / Ls;
151  flops += 72LL * (long long)in.Volume() + 96LL * bulk + 120LL * wall; // pre
152  flops += 1320LL * (long long)in.Volume(); // dslash4
153  flops += 48LL * (long long)in.Volume() + 96LL * bulk + 120LL * wall; // dslash5
154 
155  deleteTmp(&tmp, reset);
156  }
157 
159  {
160  checkFullSpinor(out, in);
161 
162  bool reset = newTmp(&tmp1, in);
163 
164  M(*tmp1, in);
165  Mdag(out, *tmp1);
166 
167  deleteTmp(&tmp1, reset);
168  }
169 
171  const QudaSolutionType solType) const
172  {
173  if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
174  errorQuda("Preconditioned solution requires a preconditioned solve_type");
175  }
176 
177  src = &b;
178  sol = &x;
179  }
180 
182  {
183  // do nothing
184  }
185 
// Construct a DiracMobiusPC from param: the DiracMobius base sub-object is built from
// the same param and extended_gauge starts out as nullptr.
186  DiracMobiusPC::DiracMobiusPC(const DiracParam &param) : DiracMobius(param), extended_gauge(nullptr)
187  {
188  // nothing to do here: all initialization happens in the member-initializer list
189  }
190 
192  {
193  // do nothing
194  }
195 
197  {
198  if (extended_gauge) delete extended_gauge;
199  }
200 
202  {
203  if (&dirac != this) {
205  extended_gauge = nullptr;
206  }
207 
208  return *this;
209  }
210 
212  {
 // Apply the M5-inverse kernel to `in` and store the result in `out`.
 // NOTE(review): the signature line is not visible in this listing; the body matches
 // Dslash5inv(out, in, parity) -- confirm against the header.
213  checkDWF(in, out);
214  checkParitySpinor(in, out);
215  checkSpinorAlias(in, out);
216 
 // x is passed as `in` with a = 0.0; the zMobius flag selects which
 // M5-inverse kernel type is requested.
217  ApplyDslash5(out, in, in, mass, m5, b_5, c_5, 0.0, dagger, zMobius ? M5_INV_ZMOBIUS : M5_INV_MOBIUS);
218 
 // Flop accounting for this call (in.X(4) is used as Ls).
228  long long Ls = in.X(4);
229  flops += 144LL * (long long)in.Volume() * Ls + 3LL * Ls * (Ls - 1LL);
230  }
231 
232  // The xpay operator bakes in a factor of kappa_b^2
234  const ColorSpinorField &x, const double &k) const
235  {
236  checkDWF(in, out);
237  checkParitySpinor(in, out);
238  checkSpinorAlias(in, out);
239 
241 
242  long long Ls = in.X(4);
243  flops += (144LL * Ls + 48LL) * (long long)in.Volume() + 3LL * Ls * (Ls - 1LL);
244  }
245 
246  // Apply the even-odd preconditioned mobius DWF operator
247  // Note: Dslash5 applies the full M5 operator, where M5 = 1 + 0.5*kappa_b/kappa_c * D5
249  {
250  bool reset1 = newTmp(&tmp1, in);
251 
252  int odd_bit = (matpcType == QUDA_MATPC_ODD_ODD || matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC) ? 1 : 0;
253  bool symmetric = (matpcType == QUDA_MATPC_EVEN_EVEN || matpcType == QUDA_MATPC_ODD_ODD) ? true : false;
254  QudaParity parity[2] = {static_cast<QudaParity>((1 + odd_bit) % 2), static_cast<QudaParity>((0 + odd_bit) % 2)};
255 
256  // QUDA_MATPC_EVEN_EVEN_ASYMMETRIC : M5 - kappa_b^2 * D4_{eo}D4pre_{oe}D5inv_{ee}D4_{eo}D4pre_{oe}
257  // QUDA_MATPC_ODD_ODD_ASYMMETRIC : M5 - kappa_b^2 * D4_{oe}D4pre_{eo}D5inv_{oo}D4_{oe}D4pre_{eo}
258  if (symmetric && !dagger) {
259  Dslash4pre(*tmp1, in, parity[1]);
260  Dslash4(out, *tmp1, parity[0]);
261  Dslash5inv(*tmp1, out, parity[0]);
262  Dslash4pre(out, *tmp1, parity[0]);
263  Dslash4(*tmp1, out, parity[1]);
264  Dslash5invXpay(out, *tmp1, parity[1], in, -1.0);
265  } else if (symmetric && dagger) {
266  Dslash5inv(*tmp1, in, parity[1]);
267  Dslash4(out, *tmp1, parity[0]);
268  Dslash4pre(*tmp1, out, parity[0]);
269  Dslash5inv(out, *tmp1, parity[0]);
270  Dslash4(*tmp1, out, parity[1]);
271  Dslash4preXpay(out, *tmp1, parity[1], in, -1.0);
272  } else if (!symmetric && !dagger) {
273  Dslash4pre(*tmp1, in, parity[1]);
274  Dslash4(out, *tmp1, parity[0]);
275  Dslash5inv(*tmp1, out, parity[0]);
276  Dslash4pre(out, *tmp1, parity[0]);
277  Dslash4(*tmp1, out, parity[1]);
278  Dslash5Xpay(out, in, parity[1], *tmp1, -1.0);
279  } else if (!symmetric && dagger) {
280  Dslash4(*tmp1, in, parity[0]);
281  Dslash4pre(out, *tmp1, parity[0]);
282  Dslash5inv(*tmp1, out, parity[0]);
283  Dslash4(out, *tmp1, parity[1]);
284  Dslash4pre(*tmp1, out, parity[1]);
285  Dslash5Xpay(out, in, parity[1], *tmp1, -1.0);
286  }
287 
288  deleteTmp(&tmp1, reset1);
289  }
290 
292  {
293  bool reset = newTmp(&tmp2, in);
294  M(*tmp2, in);
295  Mdag(out, *tmp2);
296  deleteTmp(&tmp2, reset);
297  }
298 
300  const QudaSolutionType solType) const
301  {
302  // we desire solution to preconditioned system
303  if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
304  src = &b;
305  sol = &x;
306  } else { // we desire solution to full system
307  // prepare function in MDWF is not tested yet.
308  bool reset = newTmp(&tmp1, b.Even());
309 
311  // src = D5^-1 (b_e + k D4_eo * D4pre * D5^-1 b_o)
312  src = &(x.Odd());
315  Dslash4Xpay(*tmp1, *src, QUDA_EVEN_PARITY, b.Even(), 1.0);
317  sol = &(x.Even());
318  } else if (matpcType == QUDA_MATPC_ODD_ODD) {
319  // src = b_o + k D4_oe * D4pre * D5inv b_e
320  src = &(x.Even());
323  Dslash4Xpay(*tmp1, *src, QUDA_ODD_PARITY, b.Odd(), 1.0);
325  sol = &(x.Odd());
327  // src = b_e + k D4_eo * D4pre * D5inv b_o
328  src = &(x.Odd());
329  Dslash5inv(*src, b.Odd(), QUDA_ODD_PARITY);
331  Dslash4Xpay(*src, *tmp1, QUDA_EVEN_PARITY, b.Even(), 1.0);
332  sol = &(x.Even());
334  // src = b_o + k D4_oe * D4pre * D5inv b_e
335  src = &(x.Even());
336  Dslash5inv(*src, b.Even(), QUDA_EVEN_PARITY);
338  Dslash4Xpay(*src, *tmp1, QUDA_ODD_PARITY, b.Odd(), 1.0);
339  sol = &(x.Odd());
340  } else {
341  errorQuda("MatPCType %d not valid for DiracMobiusPC", matpcType);
342  }
343  // here we use final solution to store parity solution and parity source
344  // b is now up for grabs if we want
345 
346  deleteTmp(&tmp1, reset);
347  }
348  }
349 
351  {
352  if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) { return; }
353 
354  bool reset1 = newTmp(&tmp1, x.Even());
355 
356  // create full solution
357  checkFullSpinor(x, b);
359  // psi_o = M5^-1 (b_o + k_b D4_oe D4pre x_e)
360  Dslash4pre(x.Odd(), x.Even(), QUDA_EVEN_PARITY);
361  Dslash4Xpay(*tmp1, x.Odd(), QUDA_ODD_PARITY, b.Odd(), 1.0);
364  // psi_e = M5^-1 (b_e + k_b D4_eo D4pre x_o)
365  Dslash4pre(x.Even(), x.Odd(), QUDA_ODD_PARITY);
366  Dslash4Xpay(*tmp1, x.Even(), QUDA_EVEN_PARITY, b.Even(), 1.0);
368  } else {
369  errorQuda("MatPCType %d not valid for DiracMobiusPC", matpcType);
370  }
371 
372  deleteTmp(&tmp1, reset1);
373  }
374 
376  {
377  if (zMobius) { errorQuda("DiracMobiusPC::MdagMLocal doesn't currently support zMobius.\n"); }
378 
379  int shift0[4] = {0, 0, 0, 0};
380  int shift1[4];
381  int shift2[4];
382 
383  for (int d = 0; d < 4; d++) {
384  shift1[d] = comm_dim_partitioned(d) ? 1 : 0;
385  shift2[d] = comm_dim_partitioned(d) ? 2 : 0;
386  }
387 
388  if (extended_gauge == nullptr) { extended_gauge = createExtendedGauge(*gauge, shift2, profile, true); }
389 
390  checkDWF(in, out);
391  // checkParitySpinor(in, out);
392  checkSpinorAlias(in, out);
393 
396 
399 
400  csParam.x[0] += shift2[0]; // x direction is checkerboarded
401  for (int d = 1; d < 4; ++d) { csParam.x[d] += shift2[d] * 2; }
404 
405  int odd_bit = (getMatPCType() == QUDA_MATPC_ODD_ODD) ? 1 : 0;
406  QudaParity parity[2] = {static_cast<QudaParity>((1 + odd_bit) % 2), static_cast<QudaParity>((0 + odd_bit) % 2)};
408  mobius_tensor_core::apply_fused_dslash(*unextended_tmp2, in, *extended_gauge, *unextended_tmp2, in, mass, m5, b_5,
409  c_5, dagger, parity[1], shift0, shift0, MdwfFusedDslashType::D5PRE);
410 
411  mobius_tensor_core::apply_fused_dslash(*extended_tmp2, *unextended_tmp2, *extended_gauge, *extended_tmp2,
412  *unextended_tmp2, mass, m5, b_5, c_5, dagger, parity[0], shift1, shift2,
414 
415  mobius_tensor_core::apply_fused_dslash(*extended_tmp1, *extended_tmp2, *extended_gauge, *unextended_tmp1, in,
416  mass, m5, b_5, c_5, dagger, parity[1], shift0, shift1,
418 
419  mobius_tensor_core::apply_fused_dslash(*extended_tmp2, *extended_tmp1, *extended_gauge, *extended_tmp2,
420  *extended_tmp1, mass, m5, b_5, c_5, dagger, parity[0], shift1, shift1,
422 
423  mobius_tensor_core::apply_fused_dslash(out, *extended_tmp2, *extended_gauge, out, *unextended_tmp1, mass, m5, b_5,
424  c_5, dagger, parity[1], shift2, shift2, MdwfFusedDslashType::D4DAG_D5PREDAG);
425 
426  const long long Ls = in.X(4);
427  const long long mat = 2ll * 4ll * Ls - 1ll; // (multiplicaiton-add) * (spin) * Ls - 1
428  const long long hop = 7ll * 8ll; // 8 for eight directions
429 
430  long long vol;
431  long long halo_vol;
432 
433  vol = (2 * in.X(0)) * in.X(1) * in.X(2) * in.X(3) * Ls / 2ll;
434  flops += vol * 24ll * mat;
435 
436  vol = (2 * in.X(0) + 2 * 1) * (in.X(1) + 2 * 1) * (in.X(2) + 2 * 1) * (in.X(3) + 2 * 1) * Ls / 2ll;
437  halo_vol = (2 * in.X(0)) * in.X(1) * in.X(2) * in.X(3) * Ls / 2ll;
438  flops += halo_vol * 24ll * hop + vol * 24ll * mat;
439 
440  vol = (2 * in.X(0) + 2 * 2) * (in.X(1) + 2 * 2) * (in.X(2) + 2 * 2) * (in.X(3) + 2 * 2) * Ls / 2ll;
441  halo_vol = (2 * in.X(0) + 2 * 1) * (in.X(1) + 2 * 1) * (in.X(2) + 2 * 1) * (in.X(3) + 2 * 1) * Ls / 2ll;
442  flops += halo_vol * 24ll * hop + vol * 24ll * mat * 2ll;
443 
444  vol = (2 * in.X(0) + 2 * 1) * (in.X(1) + 2 * 1) * (in.X(2) + 2 * 1) * (in.X(3) + 2 * 1) * Ls / 2ll;
445  flops += vol * 24ll * (hop + mat);
446 
447  vol = (2 * in.X(0)) * in.X(1) * in.X(2) * in.X(3) * Ls / 2ll;
448  flops += vol * 24ll * (hop + mat);
449 
450  delete extended_tmp2;
451  delete extended_tmp1;
452 
453  delete unextended_tmp1;
454  delete unextended_tmp2;
455 
456  } else {
457  errorQuda("DiracMobiusPC::MdagMLocal(...) only supports half and quarter precision");
458  }
459  }
460 
461  // Copy the EOFA specific parameters
466  mq1(param.mq1),
467  mq2(param.mq2),
468  mq3(param.mq3)
469  {
470  // Initialize the EOFA parameters here: u, x, y
471 
472  if (zMobius) { errorQuda("DiracMobiusEofa doesn't currently support zMobius.\n"); }
473 
474  double b = b_5[0].real();
475  double c = c_5[0].real();
476 
477  double alpha = b + c;
478 
479  double eofa_norm = alpha * (mq3 - mq2) * std::pow(alpha + 1., 2. * Ls)
480  / (std::pow(alpha + 1., Ls) + mq2 * std::pow(alpha - 1., Ls))
481  / (std::pow(alpha + 1., Ls) + mq3 * std::pow(alpha - 1., Ls));
482 
483  double N = (eofa_pm ? +1. : -1.) * (2. * this->eofa_shift * eofa_norm)
484  * (std::pow(alpha + 1., Ls) + this->mq1 * std::pow(alpha - 1., Ls)) / (b * (m5 + 4.) + 1.);
485 
486  // Here the signs are somewhat mixed:
487  // There is one -1 from N for eofa_pm = minus, thus the u_- here is actually -u_- in the document
488  // It turns out this actually simplifies things.
489  for (int s = 0; s < Ls; s++) {
490  eofa_u[eofa_pm ? s : Ls - 1 - s]
491  = N * std::pow(-1., s) * std::pow(alpha - 1., s) / std::pow(alpha + 1., Ls + s + 1);
492  }
493 
494  double factor = -mobius_kappa * mass;
495  if (eofa_pm) {
496  // eofa_pm = plus
497  // Computing x
498  eofa_x[0] = eofa_u[0];
499  for (int s = Ls - 1; s > 0; s--) {
500  eofa_x[0] -= factor * eofa_u[s];
501  factor *= -mobius_kappa;
502  }
503  eofa_x[0] /= 1. + factor;
504  for (int s = 1; s < Ls; s++) { eofa_x[s] = eofa_x[s - 1] * (-mobius_kappa) + eofa_u[s]; }
505  // Computing y
506  eofa_y[Ls - 1] = 1. / (1. + factor);
508  for (int s = Ls - 1; s > 0; s--) { eofa_y[s - 1] = eofa_y[s] * (-mobius_kappa); }
509  } else {
510  // eofa_pm = minus
511  // Computing x
512  eofa_x[Ls - 1] = eofa_u[Ls - 1];
513  for (int s = 0; s < Ls - 1; s++) {
514  eofa_x[Ls - 1] -= factor * eofa_u[s];
515  factor *= -mobius_kappa;
516  }
517  eofa_x[Ls - 1] /= 1. + factor;
518  for (int s = Ls - 1; s > 0; s--) { eofa_x[s - 1] = eofa_x[s] * (-mobius_kappa) + eofa_u[s - 1]; }
519  // Computing y
520  eofa_y[0] = 1. / (1. + factor);
522  for (int s = 1; s < Ls; s++) { eofa_y[s] = eofa_y[s - 1] * (-mobius_kappa); }
523  }
524  m5inv_fac = 0.5 / (1. + factor); // 0.5 for the spin project factor
525  sherman_morrison_fac = -0.5 / (1. + sherman_morrison_fac); // 0.5 for the spin project factor
526  }
527 
529  {
530  if (in.Ndim() != 5 || out.Ndim() != 5) errorQuda("Wrong number of dimensions\n");
531 
532  checkDWF(in, out);
533  checkSpinorAlias(in, out);
534 
537 
538  long long Ls = in.X(4);
539  long long bulk = (Ls - 2) * (in.Volume() / Ls);
540  long long wall = 2 * in.Volume() / Ls;
541 
542  // 96 = 48 + 48, the second 48 from EOFA
543  flops += 96LL * (long long)in.Volume() + 96LL * bulk + 120LL * wall;
544  }
545 
547  double a) const
548  {
549  if (in.Ndim() != 5 || out.Ndim() != 5) errorQuda("Wrong number of dimensions\n");
550 
551  checkDWF(in, out);
552  checkSpinorAlias(in, out);
553 
554  a *= mobius_kappa_b * mobius_kappa_b; // a = a * kappa_b^2
555  // The kernel will actually do (m5 * in - kappa_b^2 * x)
558 
559  long long Ls = in.X(4);
560  long long bulk = (Ls - 2) * (in.Volume() / Ls);
561  long long wall = 2 * in.Volume() / Ls;
562 
563  // 144 = 96 + 48, the 48 from EOFA
564  flops += 144LL * (long long)in.Volume() + 96LL * bulk + 120LL * wall;
565  }
566 
568  {
569  checkFullSpinor(out, in);
570 
571  // FIXME broken for variable coefficients
572  double mobius_kappa_b = 0.5 / (b_5[0].real() * (4.0 + m5) + 1.0);
573 
574  // cannot use Xpay variants since it will scale incorrectly for this operator
575 
576  ColorSpinorField *tmp = nullptr;
578  bool reset = newTmp(&tmp, in);
579 
580  if (dagger == QUDA_DAG_NO) {
581  ApplyDslash5(out, in, in, mass, m5, b_5, c_5, 0.0, dagger, DSLASH5_MOBIUS_PRE);
585  } else {
587  ApplyDslash5(*tmp, out, out, mass, m5, b_5, c_5, 0.0, dagger, DSLASH5_MOBIUS_PRE);
590  }
591  blas::axpy(-mobius_kappa_b, *tmp, out);
592 
593  long long Ls = in.X(4);
594  long long bulk = (Ls - 2) * (in.Volume() / Ls);
595  long long wall = 2 * in.Volume() / Ls;
596  flops += 72LL * (long long)in.Volume() + 96LL * bulk + 120LL * wall; // pre
597  flops += 1320LL * (long long)in.Volume(); // dslash4
598 
599  // 96 = 48 + 48, the second 48 from EOFA
600  flops += 96LL * (long long)in.Volume() + 96LL * bulk + 120LL * wall; // dslash5
601 
602  deleteTmp(&tmp, reset);
603  }
604 
606  {
607  checkFullSpinor(out, in);
608 
609  bool reset = newTmp(&tmp1, in);
610 
611  M(*tmp1, in);
612  Mdag(out, *tmp1);
613 
614  deleteTmp(&tmp1, reset);
615  }
616 
618  ColorSpinorField &b, const QudaSolutionType solType) const
619  {
620  if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
621  errorQuda("Preconditioned solution requires a preconditioned solve_type");
622  }
623 
624  src = &b;
625  sol = &x;
626  }
627 
629  {
630  // do nothing
631  }
632 
634 
636  {
637  if (in.Ndim() != 5 || out.Ndim() != 5) errorQuda("Wrong number of dimensions\n");
638 
639  checkDWF(in, out);
640  // checkParitySpinor(in, out);
641  checkSpinorAlias(in, out);
642 
645 
646  long long Ls = in.X(4);
647  flops += (192LL * Ls + 96LL) * (long long)in.Volume() + 3LL * Ls * (Ls - 1LL);
648  }
649 
651  double a) const
652  {
653  if (in.Ndim() != 5 || out.Ndim() != 5) errorQuda("Wrong number of dimensions\n");
654 
655  checkDWF(in, out);
656  checkParitySpinor(in, out);
657  checkSpinorAlias(in, out);
658 
659  a *= mobius_kappa_b * mobius_kappa_b; // a = a * kappa_b^2
660  // The kernel will actually do (x - kappa_b^2 * m5inv * in)
663 
664  long long Ls = in.X(4);
665  flops += (192LL * Ls + 48LL + 96LL) * (long long)in.Volume() + 3LL * Ls * (Ls - 1LL);
666  }
667 
668  // Apply the even-odd preconditioned mobius DWF EOFA operator
670  {
671  bool reset1 = newTmp(&tmp1, in);
672 
673  int odd_bit = (matpcType == QUDA_MATPC_ODD_ODD || matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC) ? 1 : 0;
674  bool symmetric = (matpcType == QUDA_MATPC_EVEN_EVEN || matpcType == QUDA_MATPC_ODD_ODD) ? true : false;
675  QudaParity parity[2] = {static_cast<QudaParity>((1 + odd_bit) % 2), static_cast<QudaParity>((0 + odd_bit) % 2)};
676 
677  // QUDA_MATPC_EVEN_EVEN_ASYMMETRIC : M5 - kappa_b^2 * D4_{eo}D4pre_{oe}D5inv_{ee}D4_{eo}D4pre_{oe}
678  // QUDA_MATPC_ODD_ODD_ASYMMETRIC : M5 - kappa_b^2 * D4_{oe}D4pre_{eo}D5inv_{oo}D4_{oe}D4pre_{eo}
679  if (symmetric && !dagger) {
680  Dslash4pre(*tmp1, in, parity[1]);
681  Dslash4(out, *tmp1, parity[0]);
682  m5inv_eofa(*tmp1, out);
683  Dslash4pre(out, *tmp1, parity[0]);
684  Dslash4(*tmp1, out, parity[1]);
685  m5inv_eofa_xpay(out, *tmp1, in, -1.);
686  } else if (symmetric && dagger) {
687  m5inv_eofa(*tmp1, in);
688  Dslash4(out, *tmp1, parity[0]);
689  Dslash4pre(*tmp1, out, parity[0]);
690  m5inv_eofa(out, *tmp1);
691  Dslash4(*tmp1, out, parity[1]);
692  Dslash4preXpay(out, *tmp1, parity[1], in, -1.);
693  } else if (!symmetric && !dagger) {
694  Dslash4pre(*tmp1, in, parity[1]);
695  Dslash4(out, *tmp1, parity[0]);
696  m5inv_eofa(*tmp1, out);
697  Dslash4pre(out, *tmp1, parity[0]);
698  Dslash4(*tmp1, out, parity[1]);
699  m5_eofa_xpay(out, in, *tmp1, -1.);
700  } else if (!symmetric && dagger) {
701  Dslash4(*tmp1, in, parity[0]);
702  Dslash4pre(out, *tmp1, parity[0]);
703  m5inv_eofa(*tmp1, out);
704  Dslash4(out, *tmp1, parity[1]);
705  Dslash4pre(*tmp1, out, parity[1]);
706  m5_eofa_xpay(out, in, *tmp1, -1.);
707  }
708 
709  deleteTmp(&tmp1, reset1);
710  }
711 
713  ColorSpinorField &b, const QudaSolutionType solType) const
714  {
715  // we desire solution to preconditioned system
716  if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
717  src = &b;
718  sol = &x;
719  } else {
720  // we desire solution to full system
721  bool reset = newTmp(&tmp1, b.Even());
723  // src = D5^-1 (b_e + k D4_eo * D4pre * D5^-1 b_o)
724  src = &(x.Odd());
725  m5inv_eofa(*tmp1, b.Odd());
727  Dslash4Xpay(*tmp1, *src, QUDA_EVEN_PARITY, b.Even(), 1.0);
728  m5inv_eofa(*src, *tmp1);
729  sol = &(x.Even());
730  } else if (matpcType == QUDA_MATPC_ODD_ODD) {
731  // src = b_o + k D4_oe * D4pre * D5inv b_e
732  src = &(x.Even());
733  m5inv_eofa(*tmp1, b.Even());
735  Dslash4Xpay(*tmp1, *src, QUDA_ODD_PARITY, b.Odd(), 1.0);
736  m5inv_eofa(*src, *tmp1);
737  sol = &(x.Odd());
739  // src = b_e + k D4_eo * D4pre * D5inv b_o
740  src = &(x.Odd());
741  m5inv_eofa(*src, b.Odd());
743  Dslash4Xpay(*src, *tmp1, QUDA_EVEN_PARITY, b.Even(), 1.0);
744  sol = &(x.Even());
746  // src = b_o + k D4_oe * D4pre * D5inv b_e
747  src = &(x.Even());
748  m5inv_eofa(*src, b.Even());
750  Dslash4Xpay(*src, *tmp1, QUDA_ODD_PARITY, b.Odd(), 1.0);
751  sol = &(x.Odd());
752  } else {
753  errorQuda("MatPCType %d not valid for DiracMobiusEofaPC", matpcType);
754  }
755  // here we use final solution to store parity solution and parity source
756  // b is now up for grabs if we want
757  deleteTmp(&tmp1, reset);
758  }
759  }
760 
762  {
 // Rebuild the full solution in x from the preconditioned solve, using b as the source.
 // Nothing is done when a preconditioned solution type was requested.
 // NOTE(review): the signature and the matpcType branch conditions (listing lines
 // 761, 769 and 774) are not visible in this listing and are left untouched.
763  if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) { return; }
764 
 // tmp1 is scratch storage for the intermediate D4 result
765  bool reset1 = newTmp(&tmp1, x.Even());
766 
767  // create full solution
768  checkFullSpinor(x, b);
770  // psi_o = M5^-1 (b_o + k_b D4_oe D4pre x_e)
771  Dslash4pre(x.Odd(), x.Even(), QUDA_EVEN_PARITY);
772  Dslash4Xpay(*tmp1, x.Odd(), QUDA_ODD_PARITY, b.Odd(), 1.0);
773  m5inv_eofa(x.Odd(), *tmp1);
775  // psi_e = M5^-1 (b_e + k_b D4_eo D4pre x_o)
776  Dslash4pre(x.Even(), x.Odd(), QUDA_ODD_PARITY);
777  Dslash4Xpay(*tmp1, x.Even(), QUDA_EVEN_PARITY, b.Even(), 1.0);
778  m5inv_eofa(x.Even(), *tmp1);
779  } else {
 // report the class this method belongs to (matches the message used by prepare)
780  errorQuda("MatPCType %d not valid for DiracMobiusEofaPC", matpcType);
781  }
782 
783  deleteTmp(&tmp1, reset1);
784  }
785 
787  {
788  bool reset = newTmp(&tmp2, in);
789  M(*tmp2, in);
790  Mdag(out, *tmp2);
791  deleteTmp(&tmp2, reset);
792  }
793 
794  void
796  const ColorSpinorField &in) const // ye = Mee * xe + Meo * xo, yo = Moo * xo + Moe * xe
797  {
798  checkFullSpinor(out, in);
799  bool reset1 = newTmp(&tmp1, in.Odd());
800  bool reset2 = newTmp(&tmp2, in.Odd());
801  if (!dagger) {
802  // Even
803  m5_eofa(*tmp1, in.Even());
804  Dslash4pre(*tmp2, in.Odd(), QUDA_ODD_PARITY);
805  Dslash4Xpay(out.Even(), *tmp2, QUDA_EVEN_PARITY, *tmp1, -1.);
806  // Odd
807  m5_eofa(*tmp1, in.Odd());
808  Dslash4pre(*tmp2, in.Even(), QUDA_EVEN_PARITY);
809  Dslash4Xpay(out.Odd(), *tmp2, QUDA_ODD_PARITY, *tmp1, -1.);
810  } else {
811  printfQuda("Quda EOFA full dslash dagger=yes\n");
812  // Even
813  m5_eofa(*tmp1, in.Even());
814  Dslash4(*tmp2, in.Odd(), QUDA_EVEN_PARITY);
815  Dslash4preXpay(out.Even(), *tmp2, QUDA_EVEN_PARITY, *tmp1, -1. / mobius_kappa_b);
816  // Odd
817  m5_eofa(*tmp1, in.Odd());
818  Dslash4(*tmp2, in.Even(), QUDA_ODD_PARITY);
819  Dslash4preXpay(out.Odd(), *tmp2, QUDA_ODD_PARITY, *tmp1, -1. / mobius_kappa_b);
820  }
821  deleteTmp(&tmp1, reset1);
822  deleteTmp(&tmp2, reset2);
823  }
824 } // namespace quda
825 #include <iostream>
826 #include <dirac_quda.h>
827 #include <blas_quda.h>
const ColorSpinorField & Odd() const
QudaSiteSubset SiteSubset() const
static ColorSpinorField * Create(const ColorSpinorParam &param)
const ColorSpinorField & Even() const
const int * X() const
DiracDomainWall & operator=(const DiracDomainWall &dirac)
void checkDWF(const ColorSpinorField &out, const ColorSpinorField &in) const
unsigned long long flops
Definition: dirac_quda.h:150
bool newTmp(ColorSpinorField **, const ColorSpinorField &) const
Definition: dirac.cpp:72
QudaMatPCType matpcType
Definition: dirac_quda.h:148
cudaGaugeField * gauge
Definition: dirac_quda.h:144
double mass
Definition: dirac_quda.h:146
void deleteTmp(ColorSpinorField **, const bool &reset) const
Definition: dirac.cpp:83
virtual void checkParitySpinor(const ColorSpinorField &, const ColorSpinorField &) const
Check parity spinors are usable (check geometry ?)
Definition: dirac.cpp:108
ColorSpinorField * tmp1
Definition: dirac_quda.h:151
QudaMatPCType getMatPCType() const
returns preconditioning type
Definition: dirac_quda.h:323
TimeProfile profile
Definition: dirac_quda.h:161
QudaDagType dagger
Definition: dirac_quda.h:149
void checkSpinorAlias(const ColorSpinorField &, const ColorSpinorField &) const
check spinors do not alias
Definition: dirac.cpp:146
virtual void checkFullSpinor(const ColorSpinorField &, const ColorSpinorField &) const
check full spinors are compatible (check geometry ?)
Definition: dirac.cpp:138
ColorSpinorField * tmp2
Definition: dirac_quda.h:152
int commDim[QUDA_MAX_DIM]
Definition: dirac_quda.h:159
void Mdag(ColorSpinorField &out, const ColorSpinorField &in) const
Apply Mdag (the daggered operator of M).
Definition: dirac.cpp:92
double eofa_y[QUDA_MAX_DWF_LS]
Definition: dirac_quda.h:903
void m5_eofa_xpay(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &x, double a=-1.) const
virtual void prepare(ColorSpinorField *&src, ColorSpinorField *&sol, ColorSpinorField &x, ColorSpinorField &b, const QudaSolutionType) const
void m5_eofa(ColorSpinorField &out, const ColorSpinorField &in) const
virtual void MdagM(ColorSpinorField &out, const ColorSpinorField &in) const
Apply MdagM operator which may be optimized.
double eofa_x[QUDA_MAX_DWF_LS]
Definition: dirac_quda.h:902
DiracMobiusEofa(const DiracParam &param)
virtual void M(ColorSpinorField &out, const ColorSpinorField &in) const
Apply M for the dirac op. E.g. the Schur Complement operator.
double eofa_u[QUDA_MAX_DWF_LS]
Definition: dirac_quda.h:901
virtual void reconstruct(ColorSpinorField &x, const ColorSpinorField &b, const QudaSolutionType) const
void full_dslash(ColorSpinorField &out, const ColorSpinorField &in) const
void M(ColorSpinorField &out, const ColorSpinorField &in) const
Apply M for the dirac op. E.g. the Schur Complement operator.
void prepare(ColorSpinorField *&src, ColorSpinorField *&sol, ColorSpinorField &x, ColorSpinorField &b, const QudaSolutionType) const
void reconstruct(ColorSpinorField &x, const ColorSpinorField &b, const QudaSolutionType) const
void MdagM(ColorSpinorField &out, const ColorSpinorField &in) const
Apply MdagM operator which may be optimized.
void m5inv_eofa_xpay(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &x, double a=-1.) const
void m5inv_eofa(ColorSpinorField &out, const ColorSpinorField &in) const
DiracMobiusEofaPC(const DiracParam &param)
Complex c_5[QUDA_MAX_DWF_LS]
Definition: dirac_quda.h:820
void Dslash4(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity) const
Apply the local MdagM operator: equivalent to applying zero Dirichlet boundary condition to MdagM on ...
double mobius_kappa_b
Definition: dirac_quda.h:829
virtual void M(ColorSpinorField &out, const ColorSpinorField &in) const
Apply M for the dirac op. E.g. the Schur Complement operator.
virtual void reconstruct(ColorSpinorField &x, const ColorSpinorField &b, const QudaSolutionType) const
void Dslash5(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity) const
void Dslash5Xpay(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity, const ColorSpinorField &x, const double &k) const
void Dslash4Xpay(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity, const ColorSpinorField &x, const double &k) const
void Dslash4pre(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity) const
Complex b_5[QUDA_MAX_DWF_LS]
Definition: dirac_quda.h:819
void Dslash4preXpay(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity, const ColorSpinorField &x, const double &k) const
virtual void MdagM(ColorSpinorField &out, const ColorSpinorField &in) const
Apply MdagM operator which may be optimized.
virtual void prepare(ColorSpinorField *&src, ColorSpinorField *&sol, ColorSpinorField &x, ColorSpinorField &b, const QudaSolutionType) const
DiracMobius(const DiracParam &param)
Definition: dirac_mobius.cpp:7
double mobius_kappa_c
Definition: dirac_quda.h:830
cudaGaugeField * extended_gauge
Definition: dirac_quda.h:864
void Dslash5invXpay(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity, const ColorSpinorField &x, const double &k) const
void MdagM(ColorSpinorField &out, const ColorSpinorField &in) const
Apply MdagM operator which may be optimized.
void Dslash5inv(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity) const
void prepare(ColorSpinorField *&src, ColorSpinorField *&sol, ColorSpinorField &x, ColorSpinorField &b, const QudaSolutionType) const
void M(ColorSpinorField &out, const ColorSpinorField &in) const
Apply M for the dirac op. E.g. the Schur Complement operator.
void reconstruct(ColorSpinorField &x, const ColorSpinorField &b, const QudaSolutionType) const
DiracMobiusPC(const DiracParam &param)
void MdagMLocal(ColorSpinorField &out, const ColorSpinorField &in) const
Apply the local MdagM operator: equivalent to applying zero Dirichlet boundary condition to MdagM on ...
DiracMobiusPC & operator=(const DiracMobiusPC &dirac)
QudaPrecision Precision() const
int comm_dim_partitioned(int dim)
double eofa_shift
int eofa_pm
bool dagger
void mat(void *out, void **link, void *in, int dagger_bit, int mu, QudaPrecision sPrecision, QudaPrecision gPrecision)
GaugeCovDev * dirac
Definition: covdev_test.cpp:42
QudaParity parity
Definition: covdev_test.cpp:40
cudaColorSpinorField * tmp
Definition: covdev_test.cpp:34
@ QUDA_DAG_NO
Definition: enum_quda.h:223
@ QUDA_VERBOSE
Definition: enum_quda.h:267
@ QUDA_FULL_SITE_SUBSET
Definition: enum_quda.h:333
@ QUDA_EVEN_PARITY
Definition: enum_quda.h:284
@ QUDA_ODD_PARITY
Definition: enum_quda.h:284
@ QUDA_INVALID_PARITY
Definition: enum_quda.h:284
enum QudaSolutionType_s QudaSolutionType
@ QUDA_MATPC_ODD_ODD_ASYMMETRIC
Definition: enum_quda.h:219
@ QUDA_MATPC_EVEN_EVEN_ASYMMETRIC
Definition: enum_quda.h:218
@ QUDA_MATPC_ODD_ODD
Definition: enum_quda.h:217
@ QUDA_MATPC_EVEN_EVEN
Definition: enum_quda.h:216
@ QUDA_MATPC_SOLUTION
Definition: enum_quda.h:159
@ QUDA_MATPCDAG_MATPC_SOLUTION
Definition: enum_quda.h:161
@ QUDA_QUARTER_PRECISION
Definition: enum_quda.h:62
@ QUDA_HALF_PRECISION
Definition: enum_quda.h:63
@ QUDA_NULL_FIELD_CREATE
Definition: enum_quda.h:360
enum QudaParity_s QudaParity
void axpy(double a, ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.h:43
void apply_dslash5(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &x, double m_f, double m_5, const Complex *b_5, const Complex *c_5, double a, int eofa_pm, double inv, double kappa, const double *eofa_u, const double *eofa_x, const double *eofa_y, double sherman_morrison, bool dagger, Dslash5Type type)
void apply_fused_dslash(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, ColorSpinorField &y, const ColorSpinorField &x, double m_f, double m_5, const Complex *b_5, const Complex *c_5, bool dagger, int parity, int shift[4], int halo_shift[4], MdwfFusedDslashType type)
double norm2(const CloverField &a, bool inverse=false)
std::complex< double > Complex
Definition: quda_internal.h:86
cudaGaugeField * createExtendedGauge(cudaGaugeField &in, const int *R, TimeProfile &profile, bool redundant_comms=false, QudaReconstructType recon=QUDA_RECONSTRUCT_INVALID)
__host__ __device__ ValueType pow(ValueType x, ExponentType e)
Definition: complex_quda.h:111
void ApplyDomainWall4D(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a, double m_5, const Complex *b_5, const Complex *c_5, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile)
Driver for applying the batched Wilson 4-d stencil to a 5-d vector with 4-d preconditioned data order...
@ DSLASH5_MOBIUS_PRE
Definition: dslash_quda.h:559
@ DSLASH5_MOBIUS
Definition: dslash_quda.h:560
@ M5_INV_ZMOBIUS
Definition: dslash_quda.h:563
@ M5INV_EOFA
Definition: dslash_quda.h:565
@ M5_EOFA
Definition: dslash_quda.h:564
@ M5_INV_MOBIUS
Definition: dslash_quda.h:562
void ApplyDslash5(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &x, double m_f, double m_5, const Complex *b_5, const Complex *c_5, double a, bool dagger, Dslash5Type type)
Apply either the domain-wall / mobius Dslash5 operator or the M5 inverse operator....
ColorSpinorParam csParam
Definition: pack_test.cpp:25
QudaGaugeParam param
Definition: pack_test.cpp:18
#define printfQuda(...)
Definition: util_quda.h:114
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:21
#define errorQuda(...)
Definition: util_quda.h:120