QUDA v1.1.0
A library for QCD on GPUs
domain_wall_dslash_reference.cpp
1 #include <iostream>
2 #include <stdio.h>
3 #include <stdlib.h>
4 #include <string.h>
5 #include <math.h>
6 #include <complex.h>
7 
8 #include <quda.h>
9 #include <host_utils.h>
10 #include <dslash_reference.h>
11 #include <domain_wall_dslash_reference.h>
12 
13 #include <gauge_field.h>
14 #include <color_spinor_field.h>
15 
16 using namespace quda;
17 
18 // i represents a "half index" into an even or odd "half lattice".
19 // when oddBit={0,1} the half lattice is {even,odd}.
20 //
21 // the displacements, such as dx, refer to the full lattice coordinates.
22 //
23 // neighborIndex() takes a "half index", displaces it, and returns the
24 // new "half index", which can be an index into either the even or odd lattices.
25 // displacements of magnitude one always interchange odd and even lattices.
26 //
27 //
28 int neighborIndex_4d(int i, int oddBit, int dx4, int dx3, int dx2, int dx1) {
29  // On input i should be in the range [0 , ... , Z[0]*Z[1]*Z[2]*Z[3]/2-1].
30  if (i < 0 || i >= (Z[0]*Z[1]*Z[2]*Z[3]/2))
31  { printf("i out of range in neighborIndex_4d\n"); exit(-1); }
32  // Compute the linear index. Then dissect.
33  // fullLatticeIndex_4d is in util_quda.cpp.
34  // The gauge fields live on a 4d sublattice.
35  int X = fullLatticeIndex_4d(i, oddBit);
36  int x4 = X/(Z[2]*Z[1]*Z[0]);
37  int x3 = (X/(Z[1]*Z[0])) % Z[2];
38  int x2 = (X/Z[0]) % Z[1];
39  int x1 = X % Z[0];
40 
41  x4 = (x4+dx4+Z[3]) % Z[3];
42  x3 = (x3+dx3+Z[2]) % Z[2];
43  x2 = (x2+dx2+Z[1]) % Z[1];
44  x1 = (x1+dx1+Z[0]) % Z[0];
45 
46  return (x4*(Z[2]*Z[1]*Z[0]) + x3*(Z[1]*Z[0]) + x2*(Z[0]) + x1) / 2;
47 }
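// A minimal self-test sketch (not part of the reference implementation; the
// DSLASH_REFERENCE_EXAMPLES guard is hypothetical). It illustrates the
// half-index convention above: a one-site forward hop lands on the
// opposite-parity half lattice, and hopping back returns the original half
// index. Assumes the global dimensions Z[] have been set by the test harness.
#ifdef DSLASH_REFERENCE_EXAMPLES
#include <cassert>
static void example_neighborIndex_4d_roundtrip()
{
  for (int i = 0; i < Z[0] * Z[1] * Z[2] * Z[3] / 2; i++) {
    int fwd = neighborIndex_4d(i, 0, 0, 0, 0, +1);    // even half lattice -> odd
    int back = neighborIndex_4d(fwd, 1, 0, 0, 0, -1); // odd half lattice -> back to even
    assert(back == i);
  }
}
#endif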
48 
49 
50 
51 //#ifndef MULTI_GPU
52 // This is just a copy of gaugeLink() from the quda code, except
53 // that neighborIndex() is replaced by the renamed version
54 // neighborIndex_4d().
55 //ok
56 template <typename Float>
57 Float *gaugeLink_sgpu(int i, int dir, int oddBit, Float **gaugeEven,
58  Float **gaugeOdd) {
59  Float **gaugeField;
60  int j;
61 
62  // If going forward, just grab link at site, U_\mu(x).
63  if (dir % 2 == 0) {
64  j = i;
65  // j will get used in the return statement below.
66  gaugeField = (oddBit ? gaugeOdd : gaugeEven);
67  } else {
68  // If going backward, a shift must occur, U_\mu(x-\muhat)^\dagger;
69  // dagger happens elsewhere, here we're just doing index gymnastics.
70  switch (dir) {
71  case 1: j = neighborIndex_4d(i, oddBit, 0, 0, 0, -1); break;
72  case 3: j = neighborIndex_4d(i, oddBit, 0, 0, -1, 0); break;
73  case 5: j = neighborIndex_4d(i, oddBit, 0, -1, 0, 0); break;
74  case 7: j = neighborIndex_4d(i, oddBit, -1, 0, 0, 0); break;
75  default: j = -1; break;
76  }
77  gaugeField = (oddBit ? gaugeEven : gaugeOdd);
78  }
79 
80  return &gaugeField[dir/2][j*(3*3*2)];
81 }
82 
83 
84 //#else
85 
86 //Standard 4d version (nothing to change)
87 template <typename Float>
88 Float *gaugeLink_mgpu(int i, int dir, int oddBit, Float **gaugeEven, Float **gaugeOdd, Float** ghostGaugeEven, Float** ghostGaugeOdd, int n_ghost_faces, int nbr_distance) {
89  Float **gaugeField;
90  int j;
91  int d = nbr_distance;
92  if (dir % 2 == 0) {
93  j = i;
94  gaugeField = (oddBit ? gaugeOdd : gaugeEven);
95  }
96  else {
97 
98  int Y = fullLatticeIndex(i, oddBit);
99  int x4 = Y/(Z[2]*Z[1]*Z[0]);
100  int x3 = (Y/(Z[1]*Z[0])) % Z[2];
101  int x2 = (Y/Z[0]) % Z[1];
102  int x1 = Y % Z[0];
103  int X1= Z[0];
104  int X2= Z[1];
105  int X3= Z[2];
106  int X4= Z[3];
107  Float* ghostGaugeField;
108 
109  switch (dir) {
110  case 1:
111  { //-X direction
112  int new_x1 = (x1 - d + X1 )% X1;
113  if (x1 -d < 0 && comm_dim_partitioned(0)){
114  ghostGaugeField = (oddBit?ghostGaugeEven[0]: ghostGaugeOdd[0]);
115  int offset = (n_ghost_faces + x1 -d)*X4*X3*X2/2 + (x4*X3*X2 + x3*X2+x2)/2;
116  return &ghostGaugeField[offset*(3*3*2)];
117  }
118  j = (x4*X3*X2*X1 + x3*X2*X1 + x2*X1 + new_x1) / 2;
119  break;
120  }
121  case 3:
122  { //-Y direction
123  int new_x2 = (x2 - d + X2 )% X2;
124  if (x2 -d < 0 && comm_dim_partitioned(1)){
125  ghostGaugeField = (oddBit?ghostGaugeEven[1]: ghostGaugeOdd[1]);
126  int offset = (n_ghost_faces + x2 -d)*X4*X3*X1/2 + (x4*X3*X1 + x3*X1+x1)/2;
127  return &ghostGaugeField[offset*(3*3*2)];
128  }
129  j = (x4*X3*X2*X1 + x3*X2*X1 + new_x2*X1 + x1) / 2;
130  break;
131 
132  }
133  case 5:
134  { //-Z direction
135  int new_x3 = (x3 - d + X3 )% X3;
136  if (x3 -d < 0 && comm_dim_partitioned(2)){
137  ghostGaugeField = (oddBit?ghostGaugeEven[2]: ghostGaugeOdd[2]);
138  int offset = (n_ghost_faces + x3 -d)*X4*X2*X1/2 + (x4*X2*X1 + x2*X1+x1)/2;
139  return &ghostGaugeField[offset*(3*3*2)];
140  }
141  j = (x4*X3*X2*X1 + new_x3*X2*X1 + x2*X1 + x1) / 2;
142  break;
143  }
144  case 7:
145  { //-T direction
146  int new_x4 = (x4 - d + X4)% X4;
147  if (x4 -d < 0 && comm_dim_partitioned(3)){
148  ghostGaugeField = (oddBit?ghostGaugeEven[3]: ghostGaugeOdd[3]);
149  int offset = (n_ghost_faces + x4 -d)*X1*X2*X3/2 + (x3*X2*X1 + x2*X1+x1)/2;
150  return &ghostGaugeField[offset*(3*3*2)];
151  }
152  j = (new_x4*(X3*X2*X1) + x3*(X2*X1) + x2*(X1) + x1) / 2;
153  break;
154  }//7
155 
156  default: j = -1; printf("ERROR: wrong dir \n"); exit(1);
157  }
158  gaugeField = (oddBit ? gaugeEven : gaugeOdd);
159 
160  }
161 
162  return &gaugeField[dir/2][j*(3*3*2)];
163 }
164 
165 
166 //J Directions 0..7 were used in the 4d code.
167 //J Directions 8,9 will be for P_- and P_+, chiral
168 //J projectors.
169 const double projector[10][4][4][2] = {
170  {
171  {{1,0}, {0,0}, {0,0}, {0,-1}},
172  {{0,0}, {1,0}, {0,-1}, {0,0}},
173  {{0,0}, {0,1}, {1,0}, {0,0}},
174  {{0,1}, {0,0}, {0,0}, {1,0}}
175  },
176  {
177  {{1,0}, {0,0}, {0,0}, {0,1}},
178  {{0,0}, {1,0}, {0,1}, {0,0}},
179  {{0,0}, {0,-1}, {1,0}, {0,0}},
180  {{0,-1}, {0,0}, {0,0}, {1,0}}
181  },
182  {
183  {{1,0}, {0,0}, {0,0}, {1,0}},
184  {{0,0}, {1,0}, {-1,0}, {0,0}},
185  {{0,0}, {-1,0}, {1,0}, {0,0}},
186  {{1,0}, {0,0}, {0,0}, {1,0}}
187  },
188  {
189  {{1,0}, {0,0}, {0,0}, {-1,0}},
190  {{0,0}, {1,0}, {1,0}, {0,0}},
191  {{0,0}, {1,0}, {1,0}, {0,0}},
192  {{-1,0}, {0,0}, {0,0}, {1,0}}
193  },
194  {
195  {{1,0}, {0,0}, {0,-1}, {0,0}},
196  {{0,0}, {1,0}, {0,0}, {0,1}},
197  {{0,1}, {0,0}, {1,0}, {0,0}},
198  {{0,0}, {0,-1}, {0,0}, {1,0}}
199  },
200  {
201  {{1,0}, {0,0}, {0,1}, {0,0}},
202  {{0,0}, {1,0}, {0,0}, {0,-1}},
203  {{0,-1}, {0,0}, {1,0}, {0,0}},
204  {{0,0}, {0,1}, {0,0}, {1,0}}
205  },
206  {
207  {{1,0}, {0,0}, {-1,0}, {0,0}},
208  {{0,0}, {1,0}, {0,0}, {-1,0}},
209  {{-1,0}, {0,0}, {1,0}, {0,0}},
210  {{0,0}, {-1,0}, {0,0}, {1,0}}
211  },
212  {
213  {{1,0}, {0,0}, {1,0}, {0,0}},
214  {{0,0}, {1,0}, {0,0}, {1,0}},
215  {{1,0}, {0,0}, {1,0}, {0,0}},
216  {{0,0}, {1,0}, {0,0}, {1,0}}
217  },
218  // P_+ = P_R
219  {
220  {{0,0}, {0,0}, {0,0}, {0,0}},
221  {{0,0}, {0,0}, {0,0}, {0,0}},
222  {{0,0}, {0,0}, {2,0}, {0,0}},
223  {{0,0}, {0,0}, {0,0}, {2,0}}
224  },
225  // P_- = P_L
226  {
227  {{2,0}, {0,0}, {0,0}, {0,0}},
228  {{0,0}, {2,0}, {0,0}, {0,0}},
229  {{0,0}, {0,0}, {0,0}, {0,0}},
230  {{0,0}, {0,0}, {0,0}, {0,0}}
231  }
232 };
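// A numerical sanity check (illustrative sketch, not part of the reference
// implementation; the guard macro is hypothetical): entries 0..7 above have
// the form 1 +/- gamma_mu and entries 8,9 are 2*P_chiral, so every matrix M
// in the table should satisfy M*M == 2*M.
#ifdef DSLASH_REFERENCE_EXAMPLES
static bool example_projector_identity()
{
  for (int p = 0; p < 10; p++)
    for (int s = 0; s < 4; s++)
      for (int t = 0; t < 4; t++) {
        double re = 0.0, im = 0.0; // complex (M*M)[s][t]
        for (int u = 0; u < 4; u++) {
          re += projector[p][s][u][0] * projector[p][u][t][0]
              - projector[p][s][u][1] * projector[p][u][t][1];
          im += projector[p][s][u][0] * projector[p][u][t][1]
              + projector[p][s][u][1] * projector[p][u][t][0];
        }
        if (fabs(re - 2 * projector[p][s][t][0]) > 1e-12 || fabs(im - 2 * projector[p][s][t][1]) > 1e-12)
          return false;
      }
  return true;
}
#endif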
233 
234 
235 // todo pass projector
236 template <typename Float>
237 void multiplySpinorByDiracProjector5(Float *res, int projIdx, Float *spinorIn) {
238  for (int i=0; i<4*3*2; i++) res[i] = 0.0;
239 
240  for (int s = 0; s < 4; s++) {
241  for (int t = 0; t < 4; t++) {
242  Float projRe = projector[projIdx][s][t][0];
243  Float projIm = projector[projIdx][s][t][1];
244 
245  for (int m = 0; m < 3; m++) {
246  Float spinorRe = spinorIn[t*(3*2) + m*(2) + 0];
247  Float spinorIm = spinorIn[t*(3*2) + m*(2) + 1];
248  res[s*(3*2) + m*(2) + 0] += projRe*spinorRe - projIm*spinorIm;
249  res[s*(3*2) + m*(2) + 1] += projRe*spinorIm + projIm*spinorRe;
250  }
251  }
252  }
253 }
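// How the dslash loops below select a projector: with
//   projIdx = 2*(dir/2) + (dir + daggerBit) % 2,
// each direction pair (mu = dir/2) owns two adjacent entries of projector[],
// forward hops take one, backward hops take the other, and daggerBit swaps
// them. Worked out from the formula:
//
//   dir  daggerBit=0  daggerBit=1
//    0        0            1
//    1        1            0
//    2        2            3
//    3        3            2
//   ...and so on for dir = 4..7.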
254 
255 
256 //#ifndef MULTI_GPU
257 // dslashReference_4d()
258 //J This is just the 4d wilson dslash of quda code, with a
259 //J few small changes to take into account that the spinors
260 //J are 5d and the gauge fields are 4d.
261 //
262 // if oddBit is zero: calculate odd parity spinor elements (using even parity spinor)
263 // if oddBit is one: calculate even parity spinor elements
264 //
265 // if daggerBit is zero: perform ordinary dslash operator
266 // if daggerBit is one: perform hermitian conjugate of dslash
267 //
268 //An "ok" will only be granted once check2.tex is deemed complete,
269 //since the logic in this function is important and nontrivial.
270 template <QudaPCType type, typename sFloat, typename gFloat>
271 void dslashReference_4d_sgpu(sFloat *res, gFloat **gaugeFull, sFloat *spinorField, int oddBit, int daggerBit)
272 {
273 
274  // Initialize the return half-spinor to zero. Note that it is a
275  // 5d spinor, hence the use of V5h.
276  for (int i=0; i<V5h*4*3*2; i++) res[i] = 0.0;
277 
278  // Some pointers that we use to march through arrays.
279  gFloat *gaugeEven[4], *gaugeOdd[4];
280  // Initialize to beginning of even and odd parts of
281  // gauge array.
282  for (int dir = 0; dir < 4; dir++) {
283  gaugeEven[dir] = gaugeFull[dir];
284  // Note the use of Vh here, since the gauge fields
285  // are 4-dim'l.
286  gaugeOdd[dir] = gaugeFull[dir] + Vh * gauge_site_size;
287  }
288  int sp_idx,gaugeOddBit;
289  for (int xs=0;xs<Ls;xs++) {
290  for (int gge_idx = 0; gge_idx < Vh; gge_idx++) {
291  for (int dir = 0; dir < 8; dir++) {
292  sp_idx=gge_idx+Vh*xs;
 293  // Here is a function call to study: gaugeLink_sgpu(), defined
 294  // near line 57 of this file.
295  // Here we have to switch oddBit depending on the value of xs. E.g., suppose
296  // xs=1. Then the odd spinor site x1=x2=x3=x4=0 wants the even gauge array
297  // element 0, so that we get U_\mu(0).
298  gaugeOddBit = (xs%2 == 0 || type == QUDA_4D_PC) ? oddBit : (oddBit+1) % 2;
299  gFloat *gauge = gaugeLink_sgpu(gge_idx, dir, gaugeOddBit, gaugeEven, gaugeOdd);
300 
301  // Even though we're doing the 4d part of the dslash, we need
302  // to use a 5d neighbor function, to get the offsets right.
303  sFloat *spinor = spinorNeighbor_5d<type>(sp_idx, dir, oddBit, spinorField);
304  sFloat projectedSpinor[4*3*2], gaugedSpinor[4*3*2];
305  int projIdx = 2*(dir/2)+(dir+daggerBit)%2;
306  multiplySpinorByDiracProjector5(projectedSpinor, projIdx, spinor);
307 
308  for (int s = 0; s < 4; s++) {
309  if (dir % 2 == 0) {
310  su3Mul(&gaugedSpinor[s*(3*2)], gauge, &projectedSpinor[s*(3*2)]);
311 #ifdef DBUG_VERBOSE
312  std::cout << "spinor:" << std::endl;
313  printSpinorElement(&projectedSpinor[s*(3*2)],0,QUDA_DOUBLE_PRECISION);
314  std::cout << "gauge:" << std::endl;
315 #endif
316  } else {
317  su3Tmul(&gaugedSpinor[s*(3*2)], gauge, &projectedSpinor[s*(3*2)]);
318  }
319  }
320 
321  sum(&res[sp_idx*(4*3*2)], &res[sp_idx*(4*3*2)], gaugedSpinor, 4*3*2);
322  }
323  }
324  }
325 }
326 
327 #ifdef MULTI_GPU
328 template <QudaPCType type, typename sFloat, typename gFloat>
329 void dslashReference_4d_mgpu(sFloat *res, gFloat **gaugeFull, gFloat **ghostGauge, sFloat *spinorField,
330  sFloat **fwdSpinor, sFloat **backSpinor, int oddBit, int daggerBit)
331 {
332  for (int i = 0; i < V5h * spinor_site_size; i++) res[i] = 0.0;
333 
334  gFloat *gaugeEven[4], *gaugeOdd[4];
335  gFloat *ghostGaugeEven[4], *ghostGaugeOdd[4];
336 
337  for (int dir = 0; dir < 4; dir++)
338  {
339  gaugeEven[dir] = gaugeFull[dir];
340  gaugeOdd[dir] = gaugeFull[dir] + Vh * gauge_site_size;
341 
342  ghostGaugeEven[dir] = ghostGauge[dir];
343  ghostGaugeOdd[dir] = ghostGauge[dir] + (faceVolume[dir] / 2) * gauge_site_size;
344  }
345  for (int xs=0;xs<Ls;xs++)
346  {
347  int sp_idx;
348  for (int i = 0; i < Vh; i++)
349  {
350  sp_idx = i + Vh*xs;
351  for (int dir = 0; dir < 8; dir++)
352  {
353  int gaugeOddBit = (xs%2 == 0 || type == QUDA_4D_PC) ? oddBit : (oddBit + 1) % 2;
354 
 355  gFloat *gauge = gaugeLink_mgpu(i, dir, gaugeOddBit, gaugeEven, gaugeOdd, ghostGaugeEven, ghostGaugeOdd, 1, 1); // unchanged from the MPI version
356  sFloat *spinor = spinorNeighbor_5d_mgpu<type>(sp_idx, dir, oddBit, spinorField, fwdSpinor, backSpinor, 1, 1);
357 
358  sFloat projectedSpinor[spinor_site_size], gaugedSpinor[spinor_site_size];
359  int projIdx = 2 * (dir / 2) + (dir + daggerBit) % 2;
360  multiplySpinorByDiracProjector5(projectedSpinor, projIdx, spinor);
361 
362  for (int s = 0; s < 4; s++) {
363  if (dir % 2 == 0)
364  su3Mul(&gaugedSpinor[s * (3 * 2)], gauge, &projectedSpinor[s * (3 * 2)]);
365  else
366  su3Tmul(&gaugedSpinor[s * (3 * 2)], gauge, &projectedSpinor[s * (3 * 2)]);
367  }
368  sum(&res[sp_idx * (4 * 3 * 2)], &res[sp_idx * (4 * 3 * 2)], gaugedSpinor, 4 * 3 * 2);
369  }
370  }
371  }
372 }
373 #endif
374 
375 template <bool plus, class sFloat> // plus = true -> gamma_+; plus = false -> gamma_-
376 void axpby_ssp_project(sFloat *z, sFloat a, sFloat *x, sFloat b, sFloat *y, int idx_cb_4d, int s, int sp)
377 {
378  // z_s = a*x_s + b*\gamma_+/-*y_sp
379  // Will use the DeGrand-Rossi/CPS basis, where gamma5 is diagonal:
380  // +1 0
381  // 0 -1
382  for (int spin = (plus ? 0 : 2); spin < (plus ? 2 : 4); spin++) {
383  for (int color_comp = 0; color_comp < 6; color_comp++) {
384  z[(s * Vh + idx_cb_4d) * 24 + spin * 6 + color_comp] = a * x[(s * Vh + idx_cb_4d) * 24 + spin * 6 + color_comp]
385  + b * y[(sp * Vh + idx_cb_4d) * 24 + spin * 6 + color_comp];
386  }
387  }
388  for (int spin = (plus ? 2 : 0); spin < (plus ? 4 : 2); spin++) {
389  for (int color_comp = 0; color_comp < 6; color_comp++) {
390  z[(s * Vh + idx_cb_4d) * 24 + spin * 6 + color_comp] = a * x[(s * Vh + idx_cb_4d) * 24 + spin * 6 + color_comp];
391  }
392  }
393 }
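// A usage sketch (illustrative; not a call made by the reference code): in
// this basis gamma_+/- act as spin-half projectors, so with a = 0 and b = 1,
//
//   axpby_ssp_project<true>(z, (sFloat)0., z, (sFloat)1., y, idx_cb_4d, s, sp);
//
// writes z_s = gamma_+ y_sp at one 4d checkerboard site: spin components 0,1
// are copied from slice sp of y and components 2,3 are zeroed.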
394 
395 template <typename sFloat>
396 void mdw_eofa_m5_ref(sFloat *res, sFloat *spinorField, int oddBit, int daggerBit, sFloat mferm, sFloat m5, sFloat b,
397  sFloat c, sFloat mq1, sFloat mq2, sFloat mq3, int eofa_pm, sFloat eofa_shift)
398 {
399  // res: the output spinor field
400  // spinorField: the input spinor field
401  // oddBit: even-odd bit
402  // daggerBit: dagger or not
403  // mferm: m_f
404 
405  sFloat alpha = b + c;
406  sFloat eofa_norm = alpha * (mq3 - mq2) * std::pow(alpha + 1., 2 * Ls)
407  / (std::pow(alpha + 1., Ls) + mq2 * std::pow(alpha - 1., Ls))
408  / (std::pow(alpha + 1., Ls) + mq3 * std::pow(alpha - 1., Ls));
409 
410  sFloat kappa = 0.5 * (c * (4. + m5) - 1.) / (b * (4. + m5) + 1.);
411 
412  constexpr int spinor_size = 4 * 3 * 2;
413  for (int i = 0; i < V5h; i++) {
414  for (int one_site = 0; one_site < 24; one_site++) { res[i * spinor_size + one_site] = 0.; }
415  for (int dir = 8; dir < 10; dir++) {
416  // Calls for an extension of the original function.
417  // 8 is forward hop, which wants P_+, 9 is backward hop,
418  // which wants P_-. Dagger reverses these.
419  sFloat *spinor = spinorNeighbor_5d<QUDA_4D_PC>(i, dir, oddBit, spinorField);
420  sFloat projectedSpinor[spinor_size];
421  int projIdx = 2 * (dir / 2) + (dir + daggerBit) % 2;
422  multiplySpinorByDiracProjector5(projectedSpinor, projIdx, spinor);
423  // J Need a conditional here for s=0 and s=Ls-1.
424  int X = fullLatticeIndex_5d_4dpc(i, oddBit);
425  int xs = X / (Z[3] * Z[2] * Z[1] * Z[0]);
426 
427  if ((xs == 0 && dir == 9) || (xs == Ls - 1 && dir == 8)) {
428  ax(projectedSpinor, -mferm, projectedSpinor, spinor_size);
429  }
430  sum(&res[i * spinor_size], &res[i * spinor_size], projectedSpinor, spinor_size);
431  }
432  // 1 + kappa*D5
433  axpby((sFloat)1., &spinorField[i * spinor_size], kappa, &res[i * spinor_size], spinor_size);
434  }
435 
436  // Initialize
437  std::vector<sFloat> shift_coeffs(Ls);
438 
439  // Construct Mooee_shift
440  sFloat N = (eofa_pm ? 1.0 : -1.0) * (2.0 * eofa_shift * eofa_norm)
441  * (std::pow(alpha + 1.0, Ls) + mq1 * std::pow(alpha - 1.0, Ls));
442 
443  // For the kappa preconditioning
444  int idx = 0;
445  N *= 1. / (b * (m5 + 4.) + 1.);
446  for (int s = 0; s < Ls; s++) {
447  idx = eofa_pm ? (s) : (Ls - 1 - s);
448  shift_coeffs[idx] = N * std::pow(-1.0, s) * std::pow(alpha - 1.0, s) / std::pow(alpha + 1.0, Ls + s + 1);
449  }
450 
451  // The eofa part.
452  for (int idx_cb_4d = 0; idx_cb_4d < Vh; idx_cb_4d++) {
453  for (int s = 0; s < Ls; s++) {
454  if (daggerBit == 0) {
455  if (eofa_pm) {
456  axpby_ssp_project<true>(res, (sFloat)1., res, shift_coeffs[s], spinorField, idx_cb_4d, s, Ls - 1);
457  } else {
458  axpby_ssp_project<false>(res, (sFloat)1., res, shift_coeffs[s], spinorField, idx_cb_4d, s, 0);
459  }
460  } else {
461  if (eofa_pm) {
462  axpby_ssp_project<true>(res, (sFloat)1., res, shift_coeffs[s], spinorField, idx_cb_4d, Ls - 1, s);
463  } else {
464  axpby_ssp_project<false>(res, (sFloat)1., res, shift_coeffs[s], spinorField, idx_cb_4d, 0, s);
465  }
466  }
467  }
468  }
469 }
470 
471 void mdw_eofa_m5(void *res, void *spinorField, int oddBit, int daggerBit, double mferm, double m5, double b, double c,
472  double mq1, double mq2, double mq3, int eofa_pm, double eofa_shift, QudaPrecision precision)
473 {
474  if (precision == QUDA_DOUBLE_PRECISION) {
 475  mdw_eofa_m5_ref<double>((double *)res, (double *)spinorField, oddBit, daggerBit, mferm, m5, b, c, mq1, mq2, mq3,
 476  eofa_pm, eofa_shift);
 477  } else {
 478  mdw_eofa_m5_ref<float>((float *)res, (float *)spinorField, oddBit, daggerBit, mferm, m5, b, c, mq1, mq2, mq3,
 479  eofa_pm, eofa_shift);
 480  }
481  return;
482 }
483 
484 //Currently we consider only spacetime decomposition (not in 5th dim), so this operator is local
485 template <QudaPCType type, bool zero_initialize = false, typename sFloat>
486 void dslashReference_5th(sFloat *res, sFloat *spinorField, int oddBit, int daggerBit, sFloat mferm)
487 {
488  for (int i = 0; i < V5h; i++) {
489  if (zero_initialize) for(int one_site = 0 ; one_site < 24 ; one_site++)
490  res[i*(4*3*2)+one_site] = 0.0;
491  for (int dir = 8; dir < 10; dir++) {
492  // Calls for an extension of the original function.
493  // 8 is forward hop, which wants P_+, 9 is backward hop,
494  // which wants P_-. Dagger reverses these.
495  sFloat *spinor = spinorNeighbor_5d<type>(i, dir, oddBit, spinorField);
496  sFloat projectedSpinor[4*3*2];
497  int projIdx = 2*(dir/2)+(dir+daggerBit)%2;
498  multiplySpinorByDiracProjector5(projectedSpinor, projIdx, spinor);
499  //J Need a conditional here for s=0 and s=Ls-1.
500  int X = (type == QUDA_5D_PC) ? fullLatticeIndex_5d(i, oddBit) : fullLatticeIndex_5d_4dpc(i, oddBit);
501  int xs = X/(Z[3]*Z[2]*Z[1]*Z[0]);
502 
503  if ( (xs == 0 && dir == 9) || (xs == Ls-1 && dir == 8) ) {
504  ax(projectedSpinor,(sFloat)(-mferm),projectedSpinor,4*3*2);
505  }
506  sum(&res[i*(4*3*2)], &res[i*(4*3*2)], projectedSpinor, 4*3*2);
507  }
508  }
509 }
510 
511 //Currently we consider only spacetime decomposition (not in 5th dim), so this operator is local
512 template <typename sFloat>
513 void dslashReference_5th_inv(sFloat *res, sFloat *spinorField, int oddBit, int daggerBit, sFloat mferm, double *kappa)
514 {
 515  double *inv_Ftr = (double *)malloc(Ls * sizeof(double));
 516  double *Ftr = (double *)malloc(Ls * sizeof(double));
517  for(int xs = 0 ; xs < Ls ; xs++)
518  {
519  inv_Ftr[xs] = 1.0/(1.0+pow(2.0*kappa[xs], Ls)*mferm);
520  Ftr[xs] = -2.0*kappa[xs]*mferm*inv_Ftr[xs];
521  for (int i = 0; i < Vh; i++) {
522  memcpy(&res[24*(i+Vh*xs)], &spinorField[24*(i+Vh*xs)], 24*sizeof(sFloat));
523  }
524  }
525  if(daggerBit == 0)
526  {
527  // s = 0
528  for (int i = 0; i < Vh; i++) {
529  ax(&res[12+24*(i+Vh*(Ls-1))],(sFloat)(inv_Ftr[0]), &spinorField[12+24*(i+Vh*(Ls-1))], 12);
530  }
531 
532  // s = 1 ... ls-2
533  for(int xs = 0 ; xs <= Ls-2 ; ++xs)
534  {
535  for (int i = 0; i < Vh; i++) {
536  axpy((sFloat)(2.0*kappa[xs]), &res[24*(i+Vh*xs)], &res[24*(i+Vh*(xs+1))], 12);
537  axpy((sFloat)Ftr[xs], &res[12+24*(i+Vh*xs)], &res[12+24*(i+Vh*(Ls-1))], 12);
538  }
539  for (int tmp_s = 0 ; tmp_s < Ls ; tmp_s++)
540  Ftr[tmp_s] *= 2.0*kappa[tmp_s];
541  }
542  for(int xs = 0 ; xs < Ls ; xs++)
543  {
544  Ftr[xs] = -pow(2.0*kappa[xs],Ls-1)*mferm*inv_Ftr[xs];
545  }
546  // s = ls-2 ... 0
547  for(int xs = Ls-2 ; xs >=0 ; --xs)
548  {
549  for (int i = 0; i < Vh; i++) {
550  axpy((sFloat)Ftr[xs], &res[24*(i+Vh*(Ls-1))], &res[24*(i+Vh*xs)], 12);
551  axpy((sFloat)(2.0*kappa[xs]), &res[12+24*(i+Vh*(xs+1))], &res[12+24*(i+Vh*xs)], 12);
552  }
553  for (int tmp_s = 0 ; tmp_s < Ls ; tmp_s++)
554  Ftr[tmp_s] /= 2.0*kappa[tmp_s];
555  }
556  // s = ls -1
557  for (int i = 0; i < Vh; i++) {
558  ax(&res[24*(i+Vh*(Ls-1))], (sFloat)(inv_Ftr[Ls-1]), &res[24*(i+Vh*(Ls-1))], 12);
559  }
560  }
561  else
562  {
563  // s = 0
564  for (int i = 0; i < Vh; i++) {
565  ax(&res[24*(i+Vh*(Ls-1))],(sFloat)(inv_Ftr[0]), &spinorField[24*(i+Vh*(Ls-1))], 12);
566  }
567 
568  // s = 1 ... ls-2
569  for(int xs = 0 ; xs <= Ls-2 ; ++xs)
570  {
571  for (int i = 0; i < Vh; i++) {
572  axpy((sFloat)Ftr[xs], &res[24*(i+Vh*xs)], &res[24*(i+Vh*(Ls-1))], 12);
573  axpy((sFloat)(2.0*kappa[xs]), &res[12+24*(i+Vh*xs)], &res[12+24*(i+Vh*(xs+1))], 12);
574  }
575  for (int tmp_s = 0 ; tmp_s < Ls ; tmp_s++)
576  Ftr[tmp_s] *= 2.0*kappa[tmp_s];
577  }
578  for(int xs = 0 ; xs < Ls ; xs++)
579  {
580  Ftr[xs] = -pow(2.0*kappa[xs],Ls-1)*mferm*inv_Ftr[xs];
581  }
582  // s = ls-2 ... 0
583  for(int xs = Ls-2 ; xs >=0 ; --xs)
584  {
585  for (int i = 0; i < Vh; i++) {
586  axpy((sFloat)(2.0*kappa[xs]), &res[24*(i+Vh*(xs+1))], &res[24*(i+Vh*xs)], 12);
587  axpy((sFloat)Ftr[xs], &res[12+24*(i+Vh*(Ls-1))], &res[12+24*(i+Vh*xs)], 12);
588  }
589  for (int tmp_s = 0 ; tmp_s < Ls ; tmp_s++)
590  Ftr[tmp_s] /= 2.0*kappa[tmp_s];
591  }
592  // s = ls -1
593  for (int i = 0; i < Vh; i++) {
594  ax(&res[12+24*(i+Vh*(Ls-1))], (sFloat)(inv_Ftr[Ls-1]), &res[12+24*(i+Vh*(Ls-1))], 12);
595  }
596  }
597  free(inv_Ftr);
598  free(Ftr);
599 }
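// Note on the structure above: per 4d site and chirality, the 5th-dimension
// operator being inverted is bidiagonal (unit diagonal, a 2*kappa[s] hop) plus
// a corner term proportional to mferm from the fermionic boundary condition;
// inv_Ftr[s] = 1/(1 + (2*kappa[s])^Ls * mferm) normalizes that corner. The
// three stages are a forward-elimination sweep, a corner correction using the
// running Ftr factors, and a back-substitution sweep; the dagger branch runs
// the same stages with the roles of the two chiral halves exchanged.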
600 
601 template <typename sComplex> sComplex cpow(const sComplex &x, int y)
602 {
603  static_assert(sizeof(sComplex) == sizeof(Complex), "C and C++ complex type sizes do not match");
 604  // note that the C++ standard explicitly allows casting between C and C++ complex types
605  const Complex x_ = reinterpret_cast<const Complex &>(x);
606  Complex z_ = std::pow(x_, y);
607  sComplex z = reinterpret_cast<sComplex &>(z_);
608  return z;
609 }
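// A usage sketch for cpow() (illustrative; the guard macro is hypothetical).
// It lets the C99 double _Complex kappas used by the Mobius routines below
// reuse std::pow through the C/C++ complex layout compatibility noted above.
#ifdef DSLASH_REFERENCE_EXAMPLES
static void example_cpow_usage()
{
  double _Complex k = 0.25 + 0.1 * _Complex_I;
  double _Complex kLs = cpow(2.0 * k, 8); // (2k)^8, the pattern used in inv_Ftr below
  (void)kLs; // silence unused-variable warnings in this sketch
}
#endif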
610 
611 // Currently we consider only spacetime decomposition (not in 5th dim), so this operator is local
612 template <typename sFloat, typename sComplex>
613 void mdslashReference_5th_inv(sFloat *res, sFloat *spinorField, int oddBit, int daggerBit, sFloat mferm, sComplex *kappa)
614 {
615  sComplex *inv_Ftr = (sComplex *)malloc(Ls * sizeof(sComplex));
616  sComplex *Ftr = (sComplex *)malloc(Ls * sizeof(sComplex));
617  for (int xs = 0; xs < Ls; xs++) {
618  inv_Ftr[xs] = 1.0 / (1.0 + cpow(2.0 * kappa[xs], Ls) * mferm);
619  Ftr[xs] = -2.0 * kappa[xs] * mferm * inv_Ftr[xs];
620  for (int i = 0; i < Vh; i++) {
621  memcpy(&res[24 * (i + Vh * xs)], &spinorField[24 * (i + Vh * xs)], 24 * sizeof(sFloat));
622  }
623  }
624  if (daggerBit == 0) {
625  // s = 0
626  for (int i = 0; i < Vh; i++) {
627  ax((sComplex *)&res[12 + 24 * (i + Vh * (Ls - 1))], inv_Ftr[0],
628  (sComplex *)&spinorField[12 + 24 * (i + Vh * (Ls - 1))], 6);
629  }
630 
631  // s = 1 ... ls-2
632  for (int xs = 0; xs <= Ls - 2; ++xs) {
633  for (int i = 0; i < Vh; i++) {
634  axpy((2.0 * kappa[xs]), (sComplex *)&res[24 * (i + Vh * xs)], (sComplex *)&res[24 * (i + Vh * (xs + 1))], 6);
635  axpy(Ftr[xs], (sComplex *)&res[12 + 24 * (i + Vh * xs)], (sComplex *)&res[12 + 24 * (i + Vh * (Ls - 1))], 6);
636  }
637  for (int tmp_s = 0; tmp_s < Ls; tmp_s++) Ftr[tmp_s] *= 2.0 * kappa[tmp_s];
638  }
639  for (int xs = 0; xs < Ls; xs++) Ftr[xs] = -cpow(2.0 * kappa[xs], Ls - 1) * mferm * inv_Ftr[xs];
640 
641  // s = ls-2 ... 0
642  for (int xs = Ls - 2; xs >= 0; --xs) {
643  for (int i = 0; i < Vh; i++) {
644  axpy(Ftr[xs], (sComplex *)&res[24 * (i + Vh * (Ls - 1))], (sComplex *)&res[24 * (i + Vh * xs)], 6);
645  axpy((2.0 * kappa[xs]), (sComplex *)&res[12 + 24 * (i + Vh * (xs + 1))],
646  (sComplex *)&res[12 + 24 * (i + Vh * xs)], 6);
647  }
648  for (int tmp_s = 0; tmp_s < Ls; tmp_s++) Ftr[tmp_s] /= 2.0 * kappa[tmp_s];
649  }
650  // s = ls -1
651  for (int i = 0; i < Vh; i++) {
652  ax((sComplex *)&res[24 * (i + Vh * (Ls - 1))], inv_Ftr[Ls - 1], (sComplex *)&res[24 * (i + Vh * (Ls - 1))], 6);
653  }
654  } else {
655  // s = 0
656  for (int i = 0; i < Vh; i++) {
657  ax((sComplex *)&res[24 * (i + Vh * (Ls - 1))], inv_Ftr[0], (sComplex *)&spinorField[24 * (i + Vh * (Ls - 1))], 6);
658  }
659 
660  // s = 1 ... ls-2
661  for (int xs = 0; xs <= Ls - 2; ++xs) {
662  for (int i = 0; i < Vh; i++) {
663  axpy(Ftr[xs], (sComplex *)&res[24 * (i + Vh * xs)], (sComplex *)&res[24 * (i + Vh * (Ls - 1))], 6);
664  axpy((2.0 * kappa[xs]), (sComplex *)&res[12 + 24 * (i + Vh * xs)],
665  (sComplex *)&res[12 + 24 * (i + Vh * (xs + 1))], 6);
666  }
667  for (int tmp_s = 0; tmp_s < Ls; tmp_s++) Ftr[tmp_s] *= 2.0 * kappa[tmp_s];
668  }
669  for (int xs = 0; xs < Ls; xs++) Ftr[xs] = -cpow(2.0 * kappa[xs], Ls - 1) * mferm * inv_Ftr[xs];
670 
671  // s = ls-2 ... 0
672  for (int xs = Ls - 2; xs >= 0; --xs) {
673  for (int i = 0; i < Vh; i++) {
674  axpy((2.0 * kappa[xs]), (sComplex *)&res[24 * (i + Vh * (xs + 1))], (sComplex *)&res[24 * (i + Vh * xs)], 6);
675  axpy(Ftr[xs], (sComplex *)&res[12 + 24 * (i + Vh * (Ls - 1))], (sComplex *)&res[12 + 24 * (i + Vh * xs)], 6);
676  }
677  for (int tmp_s = 0; tmp_s < Ls; tmp_s++) Ftr[tmp_s] /= 2.0 * kappa[tmp_s];
678  }
679  // s = ls -1
680  for (int i = 0; i < Vh; i++) {
681  ax((sComplex *)&res[12 + 24 * (i + Vh * (Ls - 1))], inv_Ftr[Ls - 1],
682  (sComplex *)&res[12 + 24 * (i + Vh * (Ls - 1))], 6);
683  }
684  }
685  free(inv_Ftr);
686  free(Ftr);
687 }
688 
689 template <typename sFloat>
690 void mdw_eofa_m5inv_ref(sFloat *res, sFloat *spinorField, int oddBit, int daggerBit, sFloat mferm, sFloat m5, sFloat b,
691  sFloat c, sFloat mq1, sFloat mq2, sFloat mq3, int eofa_pm, sFloat eofa_shift)
692 {
693  // res: the output spinor field
694  // spinorField: the input spinor field
695  // oddBit: even-odd bit
696  // daggerBit: dagger or not
697  // mferm: m_f
698 
699  sFloat alpha = b + c;
700  sFloat eofa_norm = alpha * (mq3 - mq2) * std::pow(alpha + 1., 2 * Ls)
701  / (std::pow(alpha + 1., Ls) + mq2 * std::pow(alpha - 1., Ls))
702  / (std::pow(alpha + 1., Ls) + mq3 * std::pow(alpha - 1., Ls));
703  sFloat kappa5 = (c * (4. + m5) - 1.) / (b * (4. + m5) + 1.); // alpha = b+c
704 
705  using sComplex = double _Complex;
706 
707  std::vector<sComplex> kappa_array(Ls, -0.5 * kappa5);
708  std::vector<sFloat> eofa_u(Ls);
709  std::vector<sFloat> eofa_x(Ls);
710  std::vector<sFloat> eofa_y(Ls);
711 
712  mdslashReference_5th_inv(res, spinorField, oddBit, daggerBit, mferm, kappa_array.data());
713 
714  sFloat N = (eofa_pm ? +1. : -1.) * (2. * eofa_shift * eofa_norm)
715  * (std::pow(alpha + 1., Ls) + mq1 * std::pow(alpha - 1., Ls)) / (b * (m5 + 4.) + 1.);
716 
717  // Here the signs are somewhat mixed:
 718  // There is one -1 from N for eofa_pm = minus, so the u_- here is actually -u_- in the document.
 719  // It turns out this actually simplifies things.
720  for (int s = 0; s < Ls; s++) {
721  eofa_u[eofa_pm ? s : Ls - 1 - s] = N * std::pow(-1., s) * std::pow(alpha - 1., s) / std::pow(alpha + 1., Ls + s + 1);
722  }
723 
724  sFloat sherman_morrison_fac;
725 
726  sFloat factor = -kappa5 * mferm;
727  if (eofa_pm) {
728  // eofa_pm = plus
729  // Computing x
730  eofa_x[0] = eofa_u[0];
731  for (int s = Ls - 1; s > 0; s--) {
732  eofa_x[0] -= factor * eofa_u[s];
733  factor *= -kappa5;
734  }
735  eofa_x[0] /= 1. + factor;
736  for (int s = 1; s < Ls; s++) { eofa_x[s] = eofa_x[s - 1] * (-kappa5) + eofa_u[s]; }
737  // Computing y
738  eofa_y[Ls - 1] = 1. / (1. + factor);
739  sherman_morrison_fac = eofa_x[Ls - 1];
740  for (int s = Ls - 1; s > 0; s--) { eofa_y[s - 1] = eofa_y[s] * (-kappa5); }
741  } else {
742  // eofa_pm = minus
743  // Computing x
744  eofa_x[Ls - 1] = eofa_u[Ls - 1];
745  for (int s = 0; s < Ls - 1; s++) {
746  eofa_x[Ls - 1] -= factor * eofa_u[s];
747  factor *= -kappa5;
748  }
749  eofa_x[Ls - 1] /= 1. + factor;
750  for (int s = Ls - 1; s > 0; s--) { eofa_x[s - 1] = eofa_x[s] * (-kappa5) + eofa_u[s - 1]; }
751  // Computing y
752  eofa_y[0] = 1. / (1. + factor);
753  sherman_morrison_fac = eofa_x[0];
754  for (int s = 1; s < Ls; s++) { eofa_y[s] = eofa_y[s - 1] * (-kappa5); }
755  }
756  sherman_morrison_fac = -0.5 / (1. + sherman_morrison_fac); // 0.5 for the spin project factor
757 
758  // The EOFA stuff
759  for (int idx_cb_4d = 0; idx_cb_4d < Vh; idx_cb_4d++) {
760  for (int s = 0; s < Ls; s++) {
761  for (int sp = 0; sp < Ls; sp++) {
762  sFloat t = 2.0 * sherman_morrison_fac;
763  if (daggerBit == 0) {
764  t *= eofa_x[s] * eofa_y[sp];
765  if (eofa_pm) {
766  axpby_ssp_project<true>(res, (sFloat)1., res, t, spinorField, idx_cb_4d, s, sp);
767  } else {
768  axpby_ssp_project<false>(res, (sFloat)1., res, t, spinorField, idx_cb_4d, s, sp);
769  }
770  } else {
771  t *= eofa_y[s] * eofa_x[sp];
772  if (eofa_pm) {
773  axpby_ssp_project<true>(res, (sFloat)1., res, t, spinorField, idx_cb_4d, s, sp);
774  } else {
775  axpby_ssp_project<false>(res, (sFloat)1., res, t, spinorField, idx_cb_4d, s, sp);
776  }
777  }
778  }
779  }
780  }
781 }
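// Note on the routine above: the EOFA M5 inverse is applied as the plain
// Mobius M5 inverse (mdslashReference_5th_inv) plus a rank-one correction in
// the 5th dimension. eofa_x and eofa_y are the two vectors of that rank-one
// update, and sherman_morrison_fac carries the Sherman-Morrison denominator,
// with the extra -0.5 absorbing the chiral projector normalization (see the
// "spin project factor" comment in the code).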
782 
783 void mdw_eofa_m5inv(void *res, void *spinorField, int oddBit, int daggerBit, double mferm, double m5, double b, double c,
784  double mq1, double mq2, double mq3, int eofa_pm, double eofa_shift, QudaPrecision precision)
785 {
786  if (precision == QUDA_DOUBLE_PRECISION) {
 787  mdw_eofa_m5inv_ref<double>((double *)res, (double *)spinorField, oddBit, daggerBit, mferm, m5, b, c, mq1, mq2, mq3,
 788  eofa_pm, eofa_shift);
 789  } else {
 790  mdw_eofa_m5inv_ref<float>((float *)res, (float *)spinorField, oddBit, daggerBit, mferm, m5, b, c, mq1, mq2, mq3,
 791  eofa_pm, eofa_shift);
 792  }
793  return;
794 }
795 
796 // this actually applies the preconditioned dslash, e.g., D_ee^{-1} D_eo or D_oo^{-1} D_oe
797 void dw_dslash(void *out, void **gauge, void *in, int oddBit, int daggerBit, QudaPrecision precision,
798  QudaGaugeParam &gauge_param, double mferm)
799 {
800 #ifndef MULTI_GPU
801  if (precision == QUDA_DOUBLE_PRECISION) {
802  dslashReference_4d_sgpu<QUDA_5D_PC>((double*)out, (double**)gauge, (double*)in, oddBit, daggerBit);
803  dslashReference_5th<QUDA_5D_PC>((double*)out, (double*)in, oddBit, daggerBit, mferm);
804  } else {
805  dslashReference_4d_sgpu<QUDA_5D_PC>((float*)out, (float**)gauge, (float*)in, oddBit, daggerBit);
806  dslashReference_5th<QUDA_5D_PC>((float*)out, (float*)in, oddBit, daggerBit, (float)mferm);
807  }
808 #else
809 
810  GaugeFieldParam gauge_field_param(gauge, gauge_param);
811  gauge_field_param.ghostExchange = QUDA_GHOST_EXCHANGE_PAD;
812  cpuGaugeField cpu(gauge_field_param);
813  void **ghostGauge = (void**)cpu.Ghost();
814 
815  // Get spinor ghost fields
816  // First wrap the input spinor into a ColorSpinorField
 817  ColorSpinorParam csParam;
 818  csParam.v = in;
819  csParam.nColor = 3;
820  csParam.nSpin = 4;
821  csParam.nDim = 5; //for DW dslash
822  for (int d=0; d<4; d++) csParam.x[d] = Z[d];
 823  csParam.x[4] = Ls; // 5th dimension
824  csParam.setPrecision(precision);
 825  csParam.pad = 0;
 826  csParam.siteSubset = QUDA_PARITY_SITE_SUBSET;
 827  csParam.x[0] /= 2;
 828  csParam.siteOrder = QUDA_EVEN_ODD_SITE_ORDER;
 829  csParam.fieldOrder = QUDA_SPACE_SPIN_COLOR_FIELD_ORDER;
 830  csParam.gammaBasis = QUDA_DEGRAND_ROSSI_GAMMA_BASIS;
 831  csParam.create = QUDA_REFERENCE_FIELD_CREATE;
 832  csParam.pc_type = QUDA_5D_PC;
 833 
834  cpuColorSpinorField inField(csParam);
835 
836  { // Now do the exchange
837  QudaParity otherParity = QUDA_INVALID_PARITY;
838  if (oddBit == QUDA_EVEN_PARITY) otherParity = QUDA_ODD_PARITY;
839  else if (oddBit == QUDA_ODD_PARITY) otherParity = QUDA_EVEN_PARITY;
840  else errorQuda("ERROR: full parity not supported in function %s", __FUNCTION__);
841  const int nFace = 1;
842 
843  inField.exchangeGhost(otherParity, nFace, daggerBit);
844  }
845  void** fwd_nbr_spinor = inField.fwdGhostFaceBuffer;
846  void** back_nbr_spinor = inField.backGhostFaceBuffer;
847  //NOTE: hopping in 5th dimension does not use MPI.
848  if (precision == QUDA_DOUBLE_PRECISION) {
849  dslashReference_4d_mgpu<QUDA_5D_PC>((double*)out, (double**)gauge, (double**)ghostGauge, (double*)in,(double**)fwd_nbr_spinor, (double**)back_nbr_spinor, oddBit, daggerBit);
850  //dslashReference_4d_sgpu<QUDA_5D_PC>((double*)out, (double**)gauge, (double*)in, oddBit, daggerBit);
851  dslashReference_5th<QUDA_5D_PC>((double*)out, (double*)in, oddBit, daggerBit, mferm);
852  } else {
853  dslashReference_4d_mgpu<QUDA_5D_PC>((float*)out, (float**)gauge, (float**)ghostGauge, (float*)in,
854  (float**)fwd_nbr_spinor, (float**)back_nbr_spinor, oddBit, daggerBit);
855  dslashReference_5th<QUDA_5D_PC>((float*)out, (float*)in, oddBit, daggerBit, (float)mferm);
856  }
857 
858 #endif
859 
860 }
861 
862 void dslash_4_4d(void *out, void **gauge, void *in, int oddBit, int daggerBit, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm)
863 {
864 #ifndef MULTI_GPU
865  if (precision == QUDA_DOUBLE_PRECISION) {
866  dslashReference_4d_sgpu<QUDA_4D_PC>((double*)out, (double**)gauge, (double*)in, oddBit, daggerBit);
867  } else {
868  dslashReference_4d_sgpu<QUDA_4D_PC>((float*)out, (float**)gauge, (float*)in, oddBit, daggerBit);
869  }
870 #else
871 
872  GaugeFieldParam gauge_field_param(gauge, gauge_param);
873  gauge_field_param.ghostExchange = QUDA_GHOST_EXCHANGE_PAD;
874  cpuGaugeField cpu(gauge_field_param);
875  void **ghostGauge = (void**)cpu.Ghost();
876 
877  // Get spinor ghost fields
878  // First wrap the input spinor into a ColorSpinorField
 879  ColorSpinorParam csParam;
 880  csParam.v = in;
881  csParam.nColor = 3;
882  csParam.nSpin = 4;
883  csParam.nDim = 5; //for DW dslash
884  for (int d=0; d<4; d++) csParam.x[d] = Z[d];
 885  csParam.x[4] = Ls; // 5th dimension
886  csParam.setPrecision(precision);
 887  csParam.pad = 0;
 888  csParam.siteSubset = QUDA_PARITY_SITE_SUBSET;
 889  csParam.x[0] /= 2;
 890  csParam.siteOrder = QUDA_EVEN_ODD_SITE_ORDER;
 891  csParam.fieldOrder = QUDA_SPACE_SPIN_COLOR_FIELD_ORDER;
 892  csParam.gammaBasis = QUDA_DEGRAND_ROSSI_GAMMA_BASIS;
 893  csParam.create = QUDA_REFERENCE_FIELD_CREATE;
 894  csParam.pc_type = QUDA_4D_PC;
 895 
896  cpuColorSpinorField inField(csParam);
897 
898  { // Now do the exchange
899  QudaParity otherParity = QUDA_INVALID_PARITY;
900  if (oddBit == QUDA_EVEN_PARITY) otherParity = QUDA_ODD_PARITY;
901  else if (oddBit == QUDA_ODD_PARITY) otherParity = QUDA_EVEN_PARITY;
902  else errorQuda("ERROR: full parity not supported in function %s", __FUNCTION__);
903  const int nFace = 1;
904 
905  inField.exchangeGhost(otherParity, nFace, daggerBit);
906  }
907  void** fwd_nbr_spinor = inField.fwdGhostFaceBuffer;
908  void** back_nbr_spinor = inField.backGhostFaceBuffer;
909  if (precision == QUDA_DOUBLE_PRECISION) {
910  dslashReference_4d_mgpu<QUDA_4D_PC>((double*)out, (double**)gauge, (double**)ghostGauge, (double*)in,(double**)fwd_nbr_spinor, (double**)back_nbr_spinor, oddBit, daggerBit);
911  } else {
912  dslashReference_4d_mgpu<QUDA_4D_PC>((float*)out, (float**)gauge, (float**)ghostGauge, (float*)in,
913  (float**)fwd_nbr_spinor, (float**)back_nbr_spinor, oddBit, daggerBit);
914  }
915 
916 #endif
917 
918 }
919 
920 void dw_dslash_5_4d(void *out, void **gauge, void *in, int oddBit, int daggerBit, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm, bool zero_initialize)
921 {
922  if (precision == QUDA_DOUBLE_PRECISION) {
923  if (zero_initialize) dslashReference_5th<QUDA_4D_PC, true>((double*)out, (double*)in, oddBit, daggerBit, mferm);
924  else dslashReference_5th<QUDA_4D_PC, false>((double*)out, (double*)in, oddBit, daggerBit, mferm);
925  } else {
926  if (zero_initialize) dslashReference_5th<QUDA_4D_PC, true>((float*)out, (float*)in, oddBit, daggerBit, (float)mferm);
927  else dslashReference_5th<QUDA_4D_PC, false>((float*)out, (float*)in, oddBit, daggerBit, (float)mferm);
928  }
929 }
930 
931 void dslash_5_inv(void *out, void **gauge, void *in, int oddBit, int daggerBit, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm, double *kappa)
932 {
933  if (precision == QUDA_DOUBLE_PRECISION) {
934  dslashReference_5th_inv((double*)out, (double*)in, oddBit, daggerBit, mferm, kappa);
935  } else {
936  dslashReference_5th_inv((float*)out, (float*)in, oddBit, daggerBit, (float)mferm, kappa);
937  }
938 }
939 
940 void mdw_dslash_5_inv(void *out, void **gauge, void *in, int oddBit, int daggerBit, QudaPrecision precision,
941  QudaGaugeParam &gauge_param, double mferm, double _Complex *kappa)
942 {
943  if (precision == QUDA_DOUBLE_PRECISION) {
944  mdslashReference_5th_inv((double *)out, (double *)in, oddBit, daggerBit, mferm, kappa);
945  } else {
946  mdslashReference_5th_inv((float *)out, (float *)in, oddBit, daggerBit, (float)mferm, kappa);
947  }
948 }
949 
950 void mdw_dslash_5(void *out, void **gauge, void *in, int oddBit, int daggerBit, QudaPrecision precision,
951  QudaGaugeParam &gauge_param, double mferm, double _Complex *kappa, bool zero_initialize)
952 {
953  if (precision == QUDA_DOUBLE_PRECISION) {
954  if (zero_initialize) dslashReference_5th<QUDA_4D_PC,true>((double*)out, (double*)in, oddBit, daggerBit, mferm);
955  else dslashReference_5th<QUDA_4D_PC,false>((double*)out, (double*)in, oddBit, daggerBit, mferm);
956  } else {
957  if (zero_initialize) dslashReference_5th<QUDA_4D_PC,true>((float*)out, (float*)in, oddBit, daggerBit, (float)mferm);
958  else dslashReference_5th<QUDA_4D_PC,false>((float*)out, (float*)in, oddBit, daggerBit, (float)mferm);
959  }
960  for(int xs = 0 ; xs < Ls ; xs++) {
961  cxpay((char *)in + precision * Vh * spinor_site_size * xs, kappa[xs],
962  (char *)out + precision * Vh * spinor_site_size * xs, Vh * spinor_site_size, precision);
963  }
964 }
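// Note on the byte arithmetic above: the spinor field is laid out as Ls
// contiguous 5th-dimension slices of Vh checkerboard sites, each site holding
// spinor_site_size (= 4 spins x 3 colors x 2 reals) values, so slice xs
// starts at byte offset precision * Vh * spinor_site_size * xs.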
965 
966 void mdw_dslash_4_pre(void *out, void **gauge, void *in, int oddBit, int daggerBit, QudaPrecision precision,
967  QudaGaugeParam &gauge_param, double mferm, double _Complex *b5, double _Complex *c5, bool zero_initialize)
968 {
969  if (precision == QUDA_DOUBLE_PRECISION) {
970  if (zero_initialize) dslashReference_5th<QUDA_4D_PC, true>((double*)out, (double*)in, oddBit, daggerBit, mferm);
971  else dslashReference_5th<QUDA_4D_PC, false>((double*)out, (double*)in, oddBit, daggerBit, mferm);
972  for(int xs = 0 ; xs < Ls ; xs++)
973  {
974  axpby(b5[xs], (double _Complex *)in + Vh * spinor_site_size / 2 * xs, 0.5 * c5[xs],
975  (double _Complex *)out + Vh * spinor_site_size / 2 * xs, Vh * spinor_site_size / 2);
976  }
977  } else {
978  if (zero_initialize)
979  dslashReference_5th<QUDA_4D_PC, true>((float *)out, (float *)in, oddBit, daggerBit, (float)mferm);
980  else dslashReference_5th<QUDA_4D_PC,false>((float*)out, (float*)in, oddBit, daggerBit, (float)mferm);
981  for(int xs = 0 ; xs < Ls ; xs++)
982  {
983  axpby((float _Complex)(b5[xs]), (float _Complex *)in + Vh * (spinor_site_size / 2) * xs,
984  (float _Complex)(0.5 * c5[xs]), (float _Complex *)out + Vh * (spinor_site_size / 2) * xs,
985  Vh * spinor_site_size / 2);
986  }
987  }
988 
989 }
990 
991 void dw_mat(void *out, void **gauge, void *in, double kappa, int dagger_bit, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm) {
992 
993  void *inEven = in;
994  void *inOdd = (char *)in + V5h * spinor_site_size * precision;
995  void *outEven = out;
996  void *outOdd = (char *)out + V5h * spinor_site_size * precision;
997 
998  dw_dslash(outOdd, gauge, inEven, 1, dagger_bit, precision, gauge_param, mferm);
999  dw_dslash(outEven, gauge, inOdd, 0, dagger_bit, precision, gauge_param, mferm);
1000 
1001  // lastly apply the kappa term
1002  xpay(in, -kappa, out, V5 * spinor_site_size, precision);
1003 }
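// Note: with xpay(x, a, y) computing y = x + a*y, the two dw_dslash() calls
// plus the final xpay assemble the full unpreconditioned operator
//   out = in - kappa * D * in,  i.e.  M = 1 - kappa * D,
// where D combines the 4d hopping term and the 5th-dimension hopping term.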
1004 
1005 void dw_4d_mat(void *out, void **gauge, void *in, double kappa, int dagger_bit, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm) {
1006 
1007  void *inEven = in;
1008  void *inOdd = (char *)in + V5h * spinor_site_size * precision;
1009  void *outEven = out;
1010  void *outOdd = (char *)out + V5h * spinor_site_size * precision;
1011 
1012  dslash_4_4d(outOdd, gauge, inEven, 1, dagger_bit, precision, gauge_param, mferm);
1013  dw_dslash_5_4d(outOdd, gauge, inOdd, 1, dagger_bit, precision, gauge_param, mferm, false);
1014 
1015  dslash_4_4d(outEven, gauge, inOdd, 0, dagger_bit, precision, gauge_param, mferm);
1016  dw_dslash_5_4d(outEven, gauge, inEven, 0, dagger_bit, precision, gauge_param, mferm, false);
1017 
1018  // lastly apply the kappa term
1019  xpay(in, -kappa, out, V5 * spinor_site_size, precision);
1020 }
1021 
1022 void mdw_mat(void *out, void **gauge, void *in, double _Complex *kappa_b, double _Complex *kappa_c, int dagger,
1023  QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm, double _Complex *b5, double _Complex *c5)
1024 {
1025  void *tmp = malloc(V5h * spinor_site_size * precision);
1026  double _Complex *kappa5 = (double _Complex *)malloc(Ls * sizeof(double _Complex));
1027 
1028  for(int xs = 0; xs < Ls ; xs++) kappa5[xs] = 0.5*kappa_b[xs]/kappa_c[xs];
1029 
1030  void *inEven = in;
1031  void *inOdd = (char *)in + V5h * spinor_site_size * precision;
1032  void *outEven = out;
1033  void *outOdd = (char *)out + V5h * spinor_site_size * precision;
1034 
1035  if (!dagger) {
1036  mdw_dslash_4_pre(tmp, gauge, inEven, 0, dagger, precision, gauge_param, mferm, b5, c5, true);
1037  dslash_4_4d(outOdd, gauge, tmp, 1, dagger, precision, gauge_param, mferm);
1038  mdw_dslash_5(tmp, gauge, inOdd, 1, dagger, precision, gauge_param, mferm, kappa5, true);
1039  } else {
1040  dslash_4_4d(tmp, gauge, inEven, 1, dagger, precision, gauge_param, mferm);
1041  mdw_dslash_4_pre(outOdd, gauge, tmp, 0, dagger, precision, gauge_param, mferm, b5, c5, true);
1042  mdw_dslash_5(tmp, gauge, inOdd, 1, dagger, precision, gauge_param, mferm, kappa5, true);
1043  }
1044 
1045  for(int xs = 0 ; xs < Ls ; xs++) {
1046  cxpay((char *)tmp + precision * Vh * spinor_site_size * xs, -kappa_b[xs],
1047  (char *)outOdd + precision * Vh * spinor_site_size * xs, Vh * spinor_site_size, precision);
1048  }
1049 
1050  if (!dagger) {
1051  mdw_dslash_4_pre(tmp, gauge, inOdd, 1, dagger, precision, gauge_param, mferm, b5, c5, true);
1052  dslash_4_4d(outEven, gauge, tmp, 0, dagger, precision, gauge_param, mferm);
1053  mdw_dslash_5(tmp, gauge, inEven, 0, dagger, precision, gauge_param, mferm, kappa5, true);
1054  } else {
1055  dslash_4_4d(tmp, gauge, inOdd, 0, dagger, precision, gauge_param, mferm);
1056  mdw_dslash_4_pre(outEven, gauge, tmp, 1, dagger, precision, gauge_param, mferm, b5, c5, true);
1057  mdw_dslash_5(tmp, gauge, inEven, 0, dagger, precision, gauge_param, mferm, kappa5, true);
1058  }
1059 
1060  for(int xs = 0 ; xs < Ls ; xs++) {
1061  cxpay((char *)tmp + precision * Vh * spinor_site_size * xs, -kappa_b[xs],
1062  (char *)outEven + precision * Vh * spinor_site_size * xs, Vh * spinor_site_size, precision);
1063  }
1064 
1065  free(kappa5);
1066  free(tmp);
1067 }
1068 
1069 void mdw_eofa_mat(void *out, void **gauge, void *in, int dagger, QudaPrecision precision, QudaGaugeParam &gauge_param,
1070  double mferm, double m5, double b, double c, double mq1, double mq2, double mq3, int eofa_pm,
1071  double eofa_shift)
1072 {
1073  void *tmp = malloc(V5h * spinor_site_size * precision);
1074 
1075  using sComplex = double _Complex;
1076 
1077  std::vector<sComplex> b_array(Ls, b);
1078  std::vector<sComplex> c_array(Ls, c);
1079 
1080  auto b5 = b_array.data();
1081  auto c5 = c_array.data();
1082 
1083  auto kappa_b = 0.5 / (b * (4. + m5) + 1.);
1084 
1085  void *inEven = in;
1086  void *inOdd = (char *)in + V5h * spinor_site_size * precision;
1087  void *outEven = out;
1088  void *outOdd = (char *)out + V5h * spinor_site_size * precision;
1089 
1090  if (!dagger) {
1091  mdw_dslash_4_pre(tmp, gauge, inEven, 0, dagger, precision, gauge_param, mferm, b5, c5, true);
1092  dslash_4_4d(outOdd, gauge, tmp, 1, dagger, precision, gauge_param, mferm);
1093  mdw_eofa_m5(tmp, inOdd, 1, dagger, mferm, m5, b, c, mq1, mq2, mq3, eofa_pm, eofa_shift, precision);
1094  } else {
1095  dslash_4_4d(tmp, gauge, inEven, 1, dagger, precision, gauge_param, mferm);
1096  mdw_dslash_4_pre(outOdd, gauge, tmp, 0, dagger, precision, gauge_param, mferm, b5, c5, true);
1097  mdw_eofa_m5(tmp, inOdd, 1, dagger, mferm, m5, b, c, mq1, mq2, mq3, eofa_pm, eofa_shift, precision);
1098  }
1099 
1100  for (int xs = 0; xs < Ls; xs++) {
1101  cxpay((char *)tmp + precision * Vh * spinor_site_size * xs, -kappa_b,
1102  (char *)outOdd + precision * Vh * spinor_site_size * xs, Vh * spinor_site_size, precision);
1103  }
1104 
1105  if (!dagger) {
1106  mdw_dslash_4_pre(tmp, gauge, inOdd, 1, dagger, precision, gauge_param, mferm, b5, c5, true);
1107  dslash_4_4d(outEven, gauge, tmp, 0, dagger, precision, gauge_param, mferm);
1108  mdw_eofa_m5(tmp, inEven, 0, dagger, mferm, m5, b, c, mq1, mq2, mq3, eofa_pm, eofa_shift, precision);
1109  } else {
1110  dslash_4_4d(tmp, gauge, inOdd, 0, dagger, precision, gauge_param, mferm);
1111  mdw_dslash_4_pre(outEven, gauge, tmp, 1, dagger, precision, gauge_param, mferm, b5, c5, true);
1112  mdw_eofa_m5(tmp, inEven, 0, dagger, mferm, m5, b, c, mq1, mq2, mq3, eofa_pm, eofa_shift, precision);
1113  }
1114 
1115  for (int xs = 0; xs < Ls; xs++) {
1116  cxpay((char *)tmp + precision * Vh * spinor_site_size * xs, -kappa_b,
1117  (char *)outEven + precision * Vh * spinor_site_size * xs, Vh * spinor_site_size, precision);
1118  }
1119 
1120  free(tmp);
1121 }
1122 //
1123 void dw_matdagmat(void *out, void **gauge, void *in, double kappa, int dagger_bit, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm)
1124 {
1125 
1126  void *tmp = malloc(V5 * spinor_site_size * precision);
1127  dw_mat(tmp, gauge, in, kappa, dagger_bit, precision, gauge_param, mferm);
1128  dagger_bit = (dagger_bit == 1) ? 0 : 1;
1129  dw_mat(out, gauge, tmp, kappa, dagger_bit, precision, gauge_param, mferm);
1130 
1131  free(tmp);
1132 }
1133 
1134 void dw_matpc(void *out, void **gauge, void *in, double kappa, QudaMatPCType matpc_type, int dagger_bit, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm)
1135 {
1136  void *tmp = malloc(V5h * spinor_site_size * precision);
 1137 
 1138  if (matpc_type == QUDA_MATPC_EVEN_EVEN || matpc_type == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC) {
 1139  dw_dslash(tmp, gauge, in, 1, dagger_bit, precision, gauge_param, mferm);
1140  dw_dslash(out, gauge, tmp, 0, dagger_bit, precision, gauge_param, mferm);
1141  } else {
1142  dw_dslash(tmp, gauge, in, 0, dagger_bit, precision, gauge_param, mferm);
1143  dw_dslash(out, gauge, tmp, 1, dagger_bit, precision, gauge_param, mferm);
1144  }
1145 
1146  // lastly apply the kappa term
1147  double kappa2 = -kappa*kappa;
1148  xpay(in, kappa2, out, V5h * spinor_site_size, precision);
1149 
1150  free(tmp);
1151 }
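// Note: the two dw_dslash() calls above compose opposite-parity hops, so with
// kappa2 = -kappa*kappa the final xpay assembles the even/odd preconditioned
// operator M_pc = 1 - kappa^2 * D_{p,1-p} * D_{1-p,p}.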
1152 
1153 
1154 void dw_4d_matpc(void *out, void **gauge, void *in, double kappa, QudaMatPCType matpc_type, int dagger_bit, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm)
1155 {
1156  double kappa2 = -kappa*kappa;
1157  double *kappa5 = (double*)malloc(Ls*sizeof(double));
1158  for(int xs = 0; xs < Ls ; xs++)
1159  kappa5[xs] = kappa;
1160  void *tmp = malloc(V5h * spinor_site_size * precision);
1161  //------------------------------------------
 1162  // zero the output field; a byte-wise memset is correct for either precision
 1163  memset(out, 0, V5h * spinor_site_size * precision);
1164  //------------------------------------------
1165 
1166  int odd_bit = (matpc_type == QUDA_MATPC_ODD_ODD || matpc_type == QUDA_MATPC_ODD_ODD_ASYMMETRIC) ? 1 : 0;
1167  bool symmetric =(matpc_type == QUDA_MATPC_EVEN_EVEN || matpc_type == QUDA_MATPC_ODD_ODD) ? true : false;
1168  QudaParity parity[2] = {static_cast<QudaParity>((1 + odd_bit) % 2), static_cast<QudaParity>((0 + odd_bit) % 2)};
1169 
1170  if (symmetric && !dagger_bit) {
1171  dslash_4_4d(tmp, gauge, in, parity[0], dagger_bit, precision, gauge_param, mferm);
1172  dslash_5_inv(out, gauge, tmp, parity[0], dagger_bit, precision, gauge_param, mferm, kappa5);
1173  dslash_4_4d(tmp, gauge, out, parity[1], dagger_bit, precision, gauge_param, mferm);
1174  dslash_5_inv(out, gauge, tmp, parity[1], dagger_bit, precision, gauge_param, mferm, kappa5);
1175  xpay(in, kappa2, out, V5h * spinor_site_size, precision);
1176  } else if (symmetric && dagger_bit) {
1177  dslash_5_inv(tmp, gauge, in, parity[1], dagger_bit, precision, gauge_param, mferm, kappa5);
1178  dslash_4_4d(out, gauge, tmp, parity[0], dagger_bit, precision, gauge_param, mferm);
1179  dslash_5_inv(tmp, gauge, out, parity[0], dagger_bit, precision, gauge_param, mferm, kappa5);
1180  dslash_4_4d(out, gauge, tmp, parity[1], dagger_bit, precision, gauge_param, mferm);
1181  xpay(in, kappa2, out, V5h * spinor_site_size, precision);
1182  } else {
1183  dslash_4_4d(tmp, gauge, in, parity[0], dagger_bit, precision, gauge_param, mferm);
1184  dslash_5_inv(out, gauge, tmp, parity[0], dagger_bit, precision, gauge_param, mferm, kappa5);
1185  dslash_4_4d(tmp, gauge, out, parity[1], dagger_bit, precision, gauge_param, mferm);
1186  xpay(in, kappa2, tmp, V5h * spinor_site_size, precision);
1187  dw_dslash_5_4d(out, gauge, in, parity[1], dagger_bit, precision, gauge_param, mferm, true);
1188  xpay(tmp, -kappa, out, V5h * spinor_site_size, precision);
1189  }
1190  free(tmp);
1191  free(kappa5);
1192 }
1193 
1194 void mdw_matpc(void *out, void **gauge, void *in, double _Complex *kappa_b, double _Complex *kappa_c,
1195  QudaMatPCType matpc_type, int dagger, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm,
1196  double _Complex *b5, double _Complex *c5)
1197 {
1198  void *tmp = malloc(V5h * spinor_site_size * precision);
1199  double _Complex *kappa5 = (double _Complex *)malloc(Ls * sizeof(double _Complex));
1200  double _Complex *kappa2 = (double _Complex *)malloc(Ls * sizeof(double _Complex));
1201  double _Complex *kappa_mdwf = (double _Complex *)malloc(Ls * sizeof(double _Complex));
1202  for(int xs = 0; xs < Ls ; xs++)
1203  {
1204  kappa5[xs] = 0.5*kappa_b[xs]/kappa_c[xs];
1205  kappa2[xs] = -kappa_b[xs]*kappa_b[xs];
1206  kappa_mdwf[xs] = -kappa5[xs];
1207  }
1208 
1209  int odd_bit = (matpc_type == QUDA_MATPC_ODD_ODD || matpc_type == QUDA_MATPC_ODD_ODD_ASYMMETRIC) ? 1 : 0;
1210  bool symmetric =(matpc_type == QUDA_MATPC_EVEN_EVEN || matpc_type == QUDA_MATPC_ODD_ODD) ? true : false;
1211  QudaParity parity[2] = {static_cast<QudaParity>((1 + odd_bit) % 2), static_cast<QudaParity>((0 + odd_bit) % 2)};
1212 
1213  if (symmetric && !dagger) {
1214  mdw_dslash_4_pre(tmp, gauge, in, parity[1], dagger, precision, gauge_param, mferm, b5, c5, true);
1215  dslash_4_4d(out, gauge, tmp, parity[0], dagger, precision, gauge_param, mferm);
1216  mdw_dslash_5_inv(tmp, gauge, out, parity[1], dagger, precision, gauge_param, mferm, kappa_mdwf);
1217  mdw_dslash_4_pre(out, gauge, tmp, parity[0], dagger, precision, gauge_param, mferm, b5, c5, true);
1218  dslash_4_4d(tmp, gauge, out, parity[1], dagger, precision, gauge_param, mferm);
1219  mdw_dslash_5_inv(out, gauge, tmp, parity[0], dagger, precision, gauge_param, mferm, kappa_mdwf);
1220  for(int xs = 0 ; xs < Ls ; xs++) {
1221  cxpay((char *)in + precision * Vh * spinor_site_size * xs, kappa2[xs],
1222  (char *)out + precision * Vh * spinor_site_size * xs, Vh * spinor_site_size, precision);
1223  }
1224  } else if (symmetric && dagger) {
1225  mdw_dslash_5_inv(tmp, gauge, in, parity[1], dagger, precision, gauge_param, mferm, kappa_mdwf);
1226  dslash_4_4d(out, gauge, tmp, parity[0], dagger, precision, gauge_param, mferm);
1227  mdw_dslash_4_pre(tmp, gauge, out, parity[0], dagger, precision, gauge_param, mferm, b5, c5, true);
1228  mdw_dslash_5_inv(out, gauge, tmp, parity[0], dagger, precision, gauge_param, mferm, kappa_mdwf);
1229  dslash_4_4d(tmp, gauge, out, parity[1], dagger, precision, gauge_param, mferm);
1230  mdw_dslash_4_pre(out, gauge, tmp, parity[1], dagger, precision, gauge_param, mferm, b5, c5, true);
1231  for(int xs = 0 ; xs < Ls ; xs++) {
1232  cxpay((char *)in + precision * Vh * spinor_site_size * xs, kappa2[xs],
1233  (char *)out + precision * Vh * spinor_site_size * xs, Vh * spinor_site_size, precision);
1234  }
1235  } else if (!symmetric && !dagger) {
1236  mdw_dslash_4_pre(out, gauge, in, parity[1], dagger, precision, gauge_param, mferm, b5, c5, true);
1237  dslash_4_4d(tmp, gauge, out, parity[0], dagger, precision, gauge_param, mferm);
1238  mdw_dslash_5_inv(out, gauge, tmp, parity[1], dagger, precision, gauge_param, mferm, kappa_mdwf);
1239  mdw_dslash_4_pre(tmp, gauge, out, parity[0], dagger, precision, gauge_param, mferm, b5, c5, true);
1240  dslash_4_4d(out, gauge, tmp, parity[1], dagger, precision, gauge_param, mferm);
1241  mdw_dslash_5(tmp, gauge, in, parity[0], dagger, precision, gauge_param, mferm, kappa5, true);
1242  for(int xs = 0 ; xs < Ls ; xs++) {
1243  cxpay((char *)tmp + precision * Vh * spinor_site_size * xs, kappa2[xs],
1244  (char *)out + precision * Vh * spinor_site_size * xs, Vh * spinor_site_size, precision);
1245  }
1246  } else if (!symmetric && dagger) {
1247  dslash_4_4d(out, gauge, in, parity[0], dagger, precision, gauge_param, mferm);
1248  mdw_dslash_4_pre(tmp, gauge, out, parity[1], dagger, precision, gauge_param, mferm, b5, c5, true);
1249  mdw_dslash_5_inv(out, gauge, tmp, parity[0], dagger, precision, gauge_param, mferm, kappa_mdwf);
1250  dslash_4_4d(tmp, gauge, out, parity[1], dagger, precision, gauge_param, mferm);
1251  mdw_dslash_4_pre(out, gauge, tmp, parity[0], dagger, precision, gauge_param, mferm, b5, c5, true);
1252  mdw_dslash_5(tmp, gauge, in, parity[0], dagger, precision, gauge_param, mferm, kappa5, true);
1253  for(int xs = 0 ; xs < Ls ; xs++) {
1254  cxpay((char *)tmp + precision * Vh * spinor_site_size * xs, kappa2[xs],
1255  (char *)out + precision * Vh * spinor_site_size * xs, Vh * spinor_site_size, precision);
1256  }
1257  } else {
1258  errorQuda("Unsupported matpc_type=%d dagger=%d", matpc_type, dagger);
1259  }
1260 
1261  free(tmp);
1262  free(kappa5);
1263  free(kappa2);
1264  free(kappa_mdwf);
1265 }
1266 
1267 void mdw_eofa_matpc(void *out, void **gauge, void *in, QudaMatPCType matpc_type, int dagger, QudaPrecision precision,
1268  QudaGaugeParam &gauge_param, double mferm, double m5, double b, double c, double mq1, double mq2,
1269  double mq3, int eofa_pm, double eofa_shift)
1270 {
1271  void *tmp = malloc(V5h * spinor_site_size * precision);
1272 
1273  using sComplex = double _Complex;
1274 
1275  std::vector<sComplex> kappa2_array(Ls, -0.25 / (b * (4. + m5) + 1.) / (b * (4. + m5) + 1.));
1276  std::vector<sComplex> b_array(Ls, b);
1277  std::vector<sComplex> c_array(Ls, c);
1278 
1279  auto kappa2 = kappa2_array.data();
1280  auto b5 = b_array.data();
1281  auto c5 = c_array.data();
1282 
1283  int odd_bit = (matpc_type == QUDA_MATPC_ODD_ODD || matpc_type == QUDA_MATPC_ODD_ODD_ASYMMETRIC) ? 1 : 0;
1284  bool symmetric = (matpc_type == QUDA_MATPC_EVEN_EVEN || matpc_type == QUDA_MATPC_ODD_ODD) ? true : false;
1285  QudaParity parity[2] = {static_cast<QudaParity>((1 + odd_bit) % 2), static_cast<QudaParity>((0 + odd_bit) % 2)};
1286 
1287  if (symmetric && !dagger) {
1288  mdw_dslash_4_pre(tmp, gauge, in, parity[1], dagger, precision, gauge_param, mferm, b5, c5, true);
1289  dslash_4_4d(out, gauge, tmp, parity[0], dagger, precision, gauge_param, mferm);
1290  mdw_eofa_m5inv(tmp, out, parity[1], dagger, mferm, m5, b, c, mq1, mq2, mq3, eofa_pm, eofa_shift, precision);
1291  mdw_dslash_4_pre(out, gauge, tmp, parity[0], dagger, precision, gauge_param, mferm, b5, c5, true);
1292  dslash_4_4d(tmp, gauge, out, parity[1], dagger, precision, gauge_param, mferm);
1293  mdw_eofa_m5inv(out, tmp, parity[0], dagger, mferm, m5, b, c, mq1, mq2, mq3, eofa_pm, eofa_shift, precision);
1294  for (int xs = 0; xs < Ls; xs++) {
1295  cxpay((char *)in + precision * Vh * spinor_site_size * xs, kappa2[xs],
1296  (char *)out + precision * Vh * spinor_site_size * xs, Vh * spinor_site_size, precision);
1297  }
1298  } else if (symmetric && dagger) {
1299  mdw_eofa_m5inv(tmp, in, parity[1], dagger, mferm, m5, b, c, mq1, mq2, mq3, eofa_pm, eofa_shift, precision);
1300  dslash_4_4d(out, gauge, tmp, parity[0], dagger, precision, gauge_param, mferm);
1301  mdw_dslash_4_pre(tmp, gauge, out, parity[0], dagger, precision, gauge_param, mferm, b5, c5, true);
1302  mdw_eofa_m5inv(out, tmp, parity[0], dagger, mferm, m5, b, c, mq1, mq2, mq3, eofa_pm, eofa_shift, precision);
1303  dslash_4_4d(tmp, gauge, out, parity[1], dagger, precision, gauge_param, mferm);
1304  mdw_dslash_4_pre(out, gauge, tmp, parity[1], dagger, precision, gauge_param, mferm, b5, c5, true);
1305  for (int xs = 0; xs < Ls; xs++) {
1306  cxpay((char *)in + precision * Vh * spinor_site_size * xs, kappa2[xs],
1307  (char *)out + precision * Vh * spinor_site_size * xs, Vh * spinor_site_size, precision);
1308  }
1309  } else if (!symmetric && !dagger) {
1310  mdw_dslash_4_pre(out, gauge, in, parity[1], dagger, precision, gauge_param, mferm, b5, c5, true);
1311  dslash_4_4d(tmp, gauge, out, parity[0], dagger, precision, gauge_param, mferm);
1312  mdw_eofa_m5inv(out, tmp, parity[1], dagger, mferm, m5, b, c, mq1, mq2, mq3, eofa_pm, eofa_shift, precision);
1313  mdw_dslash_4_pre(tmp, gauge, out, parity[0], dagger, precision, gauge_param, mferm, b5, c5, true);
1314  dslash_4_4d(out, gauge, tmp, parity[1], dagger, precision, gauge_param, mferm);
1315  mdw_eofa_m5(tmp, in, parity[0], dagger, mferm, m5, b, c, mq1, mq2, mq3, eofa_pm, eofa_shift, precision);
1316  for (int xs = 0; xs < Ls; xs++) {
1317  cxpay((char *)tmp + precision * Vh * spinor_site_size * xs, kappa2[xs],
1318  (char *)out + precision * Vh * spinor_site_size * xs, Vh * spinor_site_size, precision);
1319  }
1320  } else if (!symmetric && dagger) {
1321  dslash_4_4d(out, gauge, in, parity[0], dagger, precision, gauge_param, mferm);
1322  mdw_dslash_4_pre(tmp, gauge, out, parity[1], dagger, precision, gauge_param, mferm, b5, c5, true);
1323  mdw_eofa_m5inv(out, tmp, parity[0], dagger, mferm, m5, b, c, mq1, mq2, mq3, eofa_pm, eofa_shift, precision);
1324  dslash_4_4d(tmp, gauge, out, parity[1], dagger, precision, gauge_param, mferm);
1325  mdw_dslash_4_pre(out, gauge, tmp, parity[0], dagger, precision, gauge_param, mferm, b5, c5, true);
1326  mdw_eofa_m5(tmp, in, parity[0], dagger, mferm, m5, b, c, mq1, mq2, mq3, eofa_pm, eofa_shift, precision);
1327  for (int xs = 0; xs < Ls; xs++) {
1328  cxpay((char *)tmp + precision * Vh * spinor_site_size * xs, kappa2[xs],
1329  (char *)out + precision * Vh * spinor_site_size * xs, Vh * spinor_site_size, precision);
1330  }
1331  } else {
1332  errorQuda("Unsupported matpc_type=%d dagger=%d", matpc_type, dagger);
1333  }
1334 
1335  free(tmp);
1336 }
1337 
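// Apply M^dagger M of the even-odd preconditioned Möbius operator using only locally
// available data: the gauge field is extended by R = 2 sites in each partitioned
// direction, the input is scattered into a zero-initialized padded 5d field,
// mdw_matpc is applied twice (dagger = 0, then dagger = 1) on the padded lattice,
// and the interior is gathered back. Since the halo stays zero, this realizes the
// local (zero-exterior) approximation of M^dagger M, as used e.g. for MSPCG.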
1338 void mdw_mdagm_local(void *out, void **gauge, void *in, double _Complex *kappa_b, double _Complex *kappa_c,
1339  QudaMatPCType matpc_type, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm,
1340  double _Complex *b5, double _Complex *c5)
1341 {
1342 
1343  int R[4];
1344 
1345  for (int d = 0; d < 4; d++) { R[d] = comm_dim_partitioned(d) ? 2 : 0; }
1346 
1347  cpuGaugeField *padded_gauge = createExtendedGauge(gauge, gauge_param, R);
1348 
1349  int padded_V = 1;
1350  int W[4];
1351  for (int d = 0; d < 4; d++) {
1352  W[d] = Z[d] + 2 * R[d];
1353  padded_V *= W[d];
1354  }
1355  int padded_V5 = padded_V * Ls;
1356  int padded_Vh = padded_V / 2;
1357  int padded_V5h = padded_Vh * Ls;
1358 
1359  static_assert(sizeof(char) == 1, "This code assumes sizeof(char) == 1.");
1360 
1361  char *padded_in = (char *)malloc(padded_V5h * spinor_site_size * precision);
1362  memset(padded_in, 0, padded_V5h * spinor_site_size * precision);
1363  char *padded_out = (char *)malloc(padded_V5h * spinor_site_size * precision);
1364  memset(padded_out, 0, padded_V5h * spinor_site_size * precision);
1365  char *padded_tmp = (char *)malloc(padded_V5h * spinor_site_size * precision);
1366  memset(padded_tmp, 0, padded_V5h * spinor_site_size * precision);
1367 
1368  char *in_alias = (char *)in;
1369  char *out_alias = (char *)out;
1370 
1371  for (int s = 0; s < Ls; s++) {
1372  for (int index_cb_4d = 0; index_cb_4d < Vh; index_cb_4d++) {
1373  // calculate padded_index_cb_4d
1374  int x[4];
1375  coordinate_from_shrinked_index(x, index_cb_4d, Z, R, 0); // parity = 0
1376  int padded_index_cb_4d = index_4d_cb_from_coordinate_4d(x, W);
1377  // copy data
1378  memcpy(&padded_in[spinor_site_size * precision * (s * padded_Vh + padded_index_cb_4d)],
1379  &in_alias[spinor_site_size * precision * (s * Vh + index_cb_4d)], spinor_site_size * precision);
1380  }
1381  }
1382 
1383  QudaGaugeParam padded_gauge_param(gauge_param);
1384  for (int d = 0; d < 4; d++) { padded_gauge_param.X[d] += 2 * R[d]; }
1385 
1386  void **padded_gauge_p = (void **)(padded_gauge->Gauge_p());
1387 
1388  // Temporarily point the global lattice extents at the padded lattice; restored below
1389  int V5_old = V5;
1390  V5 = padded_V5;
1391  int Vh_old = Vh;
1392  Vh = padded_Vh;
1393  int V5h_old = V5h;
1394  V5h = padded_V5h;
1395  int Z_old[4];
1396  for (int d = 0; d < 4; d++) {
1397  Z_old[d] = Z[d];
1398  Z[d] = W[d];
1399  }
1400 
1401  // dagger = 0
1402  mdw_matpc(padded_tmp, padded_gauge_p, padded_in, kappa_b, kappa_c, matpc_type, 0, precision, padded_gauge_param,
1403  mferm, b5, c5);
1404  // dagger = 1
1405  mdw_matpc(padded_out, padded_gauge_p, padded_tmp, kappa_b, kappa_c, matpc_type, 1, precision, padded_gauge_param,
1406  mferm, b5, c5);
1407 
1408  // Restore them
1409  V5 = V5_old;
1410  Vh = Vh_old;
1411  V5h = V5h_old;
1412  for (int d = 0; d < 4; d++) { Z[d] = Z_old[d]; }
1413 
1414  for (int s = 0; s < Ls; s++) {
1415  for (int index_cb_4d = 0; index_cb_4d < Vh; index_cb_4d++) {
1416  // calculate padded_index_cb_4d
1417  int x[4];
1418  coordinate_from_shrinked_index(x, index_cb_4d, Z, R, 0); // parity = 0
1419  int padded_index_cb_4d = index_4d_cb_from_coordinate_4d(x, W);
1420  // copy data
1421  memcpy(&out_alias[spinor_site_size * precision * (s * Vh + index_cb_4d)],
1422  &padded_out[spinor_site_size * precision * (s * padded_Vh + padded_index_cb_4d)],
1423  spinor_site_size * precision);
1424  }
1425  }
1426 
1427  free(padded_in);
1428  free(padded_out);
1429  free(padded_tmp);
1430 
1431  delete padded_gauge;
1432 }
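// A minimal standalone sketch (not part of the original file) of the interior index
// mapping used by the scatter/gather loops above, assuming the checkerboarded
// helpers declared in host_utils.h; the helper name is illustrative only.
//
// static int padded_index_cb(int index_cb_4d, const int Z[4], const int R[4])
// {
//   int W[4], x[4];
//   for (int d = 0; d < 4; d++) W[d] = Z[d] + 2 * R[d];      // padded extents
//   coordinate_from_shrinked_index(x, index_cb_4d, Z, R, 0); // interior coords, parity 0
//   return index_4d_cb_from_coordinate_4d(x, W);             // checkerboard index on padded lattice
// }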
1433 
1434 /*
1435 // Apply the even-odd preconditioned Dirac operator
1436 template <typename sFloat, typename gFloat>
1437 void MatPC(sFloat *outEven, gFloat **gauge, sFloat *inEven, sFloat kappa,
1438  QudaMatPCType matpc_type, sFloat mferm) {
1439 
1440  sFloat *tmp = (sFloat*)malloc(V5h*spinor_site_size*sizeof(sFloat));
1441 
1442  // full dslash operator
1443  if (matpc_type == QUDA_MATPC_EVEN_EVEN) {
1444  dslashReference_4d(tmp, gauge, inEven, 1, 0);
1445  dslashReference_5th(tmp, inEven, 1, 0, mferm);
1446  dslashReference_4d(outEven, gauge, tmp, 0, 0);
1447  dslashReference_5th(outEven, tmp, 0, 0, mferm);
1448  } else {
1449  dslashReference_4d(tmp, gauge, inEven, 0, 0);
1450  dslashReference_5th(tmp, inEven, 0, 0, mferm);
1451  dslashReference_4d(outEven, gauge, tmp, 1, 0);
1452  dslashReference_5th(outEven, tmp, 1, 0, mferm);
1453  }
1454 
1455  // lastly apply the kappa term
1456  sFloat kappa2 = -kappa*kappa;
1457  xpay(inEven, kappa2, outEven, V5h*spinor_site_size);
1458  free(tmp);
1459 }
1460 
1461 // Apply the even-odd preconditioned Dirac operator
1462 template <typename sFloat, typename gFloat>
1463 void MatPCDag(sFloat *outEven, gFloat **gauge, sFloat *inEven, sFloat kappa,
1464  QudaMatPCType matpc_type, sFloat mferm) {
1465 
1466  sFloat *tmp = (sFloat*)malloc(V5h*spinor_site_size*sizeof(sFloat));
1467 
1468  // full dslash operator
1469  if (matpc_type == QUDA_MATPC_EVEN_EVEN) {
1470  dslashReference_4d(tmp, gauge, inEven, 1, 1);
1471  dslashReference_5th(tmp, inEven, 1, 1, mferm);
1472  dslashReference_4d(outEven, gauge, tmp, 0, 1);
1473  dslashReference_5th(outEven, tmp, 0, 1, mferm);
1474  } else {
1475  dslashReference_4d(tmp, gauge, inEven, 0, 1);
1476  dslashReference_5th(tmp, inEven, 0, 1, mferm);
1477  dslashReference_4d(outEven, gauge, tmp, 1, 1);
1478  dslashReference_5th(outEven, tmp, 1, 1, mferm);
1479  }
1480 
1481  sFloat kappa2 = -kappa*kappa;
1482  xpay(inEven, kappa2, outEven, V5h*spinor_site_size);
1483  free(tmp);
1484 }
1485 */
1486 
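// Precision-dispatching wrapper; the MatPC/MatPCDag templates it would call are
// commented out above, so this is currently a no-op stub.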
1487 void matpc(void *outEven, void **gauge, void *inEven, double kappa,
1488  QudaMatPCType matpc_type, int dagger_bit, QudaPrecision sPrecision, QudaPrecision gPrecision,
1489  double mferm) {
1490 /*
1491  if (!dagger_bit) {
1492  if (sPrecision == QUDA_DOUBLE_PRECISION)
1493  if (gPrecision == QUDA_DOUBLE_PRECISION)
1494  MatPC((double*)outEven, (double**)gauge, (double*)inEven, (double)kappa, matpc_type, (double)mferm);
1495  else
1496  MatPC((double*)outEven, (float**)gauge, (double*)inEven, (double)kappa, matpc_type, (double)mferm);
1497  else
1498  if (gPrecision == QUDA_DOUBLE_PRECISION)
1499  MatPC((float*)outEven, (double**)gauge, (float*)inEven, (float)kappa, matpc_type, (float)mferm);
1500  else
1501  MatPC((float*)outEven, (float**)gauge, (float*)inEven, (float)kappa, matpc_type, (float)mferm);
1502  } else {
1503  if (sPrecision == QUDA_DOUBLE_PRECISION)
1504  if (gPrecision == QUDA_DOUBLE_PRECISION)
1505  MatPCDag((double*)outEven, (double**)gauge, (double*)inEven, (double)kappa, matpc_type, (double)mferm);
1506  else
1507  MatPCDag((double*)outEven, (float**)gauge, (double*)inEven, (double)kappa, matpc_type, (double)mferm);
1508  else
1509  if (gPrecision == QUDA_DOUBLE_PRECISION)
1510  MatPCDag((float*)outEven, (double**)gauge, (float*)inEven, (float)kappa, matpc_type, (float)mferm);
1511  else
1512  MatPCDag((float*)outEven, (float**)gauge, (float*)inEven, (float)kappa, matpc_type, (float)mferm);
1513  }
1514 */
1515 }
1516 
1517 /*
1518 template <typename sFloat, typename gFloat>
1519 void MatDagMat(sFloat *out, gFloat **gauge, sFloat *in, sFloat kappa, sFloat mferm)
1520 {
1521  // Allocate a full spinor.
1522  sFloat *tmp = (sFloat*)malloc(V5*spinor_site_size*sizeof(sFloat));
1523  // Call templates above.
1524  Mat(tmp, gauge, in, kappa, mferm);
1525  MatDag(out, gauge, tmp, kappa, mferm);
1526  free(tmp);
1527 }
1528 
1529 template <typename sFloat, typename gFloat>
1530 void MatPCDagMatPC(sFloat *out, gFloat **gauge, sFloat *in, sFloat kappa,
1531  QudaMatPCType matpc_type, sFloat mferm)
1532 {
1533 
1534  // Allocate half spinor
1535  sFloat *tmp = (sFloat*)malloc(V5h*spinor_site_size*sizeof(sFloat));
1536  // Apply the PC templates above
1537  MatPC(tmp, gauge, in, kappa, matpc_type, mferm);
1538  MatPCDag(out, gauge, tmp, kappa, matpc_type, mferm);
1539  free(tmp);
1540 }
1541 */
1542 // Precision-dispatching wrapper; MatDagMat is commented out above, so this is currently a no-op.
1543 void matdagmat(void *out, void **gauge, void *in, double kappa,
1544  QudaPrecision sPrecision, QudaPrecision gPrecision, double mferm)
1545 {
1546 /*
1547  if (sPrecision == QUDA_DOUBLE_PRECISION) {
1548  if (gPrecision == QUDA_DOUBLE_PRECISION)
1549  MatDagMat((double*)out, (double**)gauge, (double*)in, (double)kappa,
1550  (double)mferm);
1551  else
1552  MatDagMat((double*)out, (float**)gauge, (double*)in, (double)kappa, (double)mferm);
1553  } else {
1554  if (gPrecision == QUDA_DOUBLE_PRECISION)
1555  MatDagMat((float*)out, (double**)gauge, (float*)in, (float)kappa,
1556  (float)mferm);
1557  else
1558  MatDagMat((float*)out, (float**)gauge, (float*)in, (float)kappa, (float)mferm);
1559  }
1560 */
1561 }
1562 
1563 // Precision-dispatching wrapper; MatPCDagMatPC is commented out above, so this is currently a no-op.
1564 void matpcdagmatpc(void *out, void **gauge, void *in, double kappa,
1565  QudaPrecision sPrecision, QudaPrecision gPrecision, double mferm, QudaMatPCType matpc_type)
1566 {
1567 /*
1568  if (sPrecision == QUDA_DOUBLE_PRECISION) {
1569  if (gPrecision == QUDA_DOUBLE_PRECISION)
1570  MatPCDagMatPC((double*)out, (double**)gauge, (double*)in, (double)kappa,
1571  matpc_type, (double)mferm);
1572  else
1573  MatPCDagMatPC((double*)out, (float**)gauge, (double*)in, (double)kappa,
1574  matpc_type, (double)mferm);
1575  } else {
1576  if (gPrecision == QUDA_DOUBLE_PRECISION)
1577  MatPCDagMatPC((float*)out, (double**)gauge, (float*)in, (float)kappa,
1578  matpc_type, (float)mferm);
1579  else
1580  MatPCDagMatPC((float*)out, (float**)gauge, (float*)in, (float)kappa,
1581  matpc_type, (float)mferm);
1582  }
1583 */
1584 }
1585 
1586 