30 if (i < 0 || i >= (
Z[0]*
Z[1]*
Z[2]*
Z[3]/2))
31 { printf(
"i out of range in neighborIndex_4d\n"); exit(-1); }
36 int x4 = X/(
Z[2]*
Z[1]*
Z[0]);
37 int x3 = (X/(
Z[1]*
Z[0])) %
Z[2];
38 int x2 = (X/
Z[0]) %
Z[1];
41 x4 = (x4+dx4+
Z[3]) %
Z[3];
42 x3 = (x3+dx3+
Z[2]) %
Z[2];
43 x2 = (x2+dx2+
Z[1]) %
Z[1];
44 x1 = (x1+dx1+
Z[0]) %
Z[0];
46 return (x4*(
Z[2]*
Z[1]*
Z[0]) + x3*(
Z[1]*
Z[0]) + x2*(
Z[0]) + x1) / 2;
56 template <
typename Float>
66 gaugeField = (oddBit ? gaugeOdd : gaugeEven);
75 default: j = -1;
break;
77 gaugeField = (oddBit ? gaugeEven : gaugeOdd);
80 return &gaugeField[dir/2][j*(3*3*2)];
87 template <
typename Float>
94 gaugeField = (oddBit ? gaugeOdd : gaugeEven);
99 int x4 = Y/(
Z[2]*
Z[1]*
Z[0]);
100 int x3 = (Y/(
Z[1]*
Z[0])) %
Z[2];
101 int x2 = (Y/
Z[0]) %
Z[1];
107 Float* ghostGaugeField;
112 int new_x1 = (x1 - d + X1 )% X1;
114 ghostGaugeField = (oddBit?ghostGaugeEven[0]: ghostGaugeOdd[0]);
115 int offset = (n_ghost_faces + x1 -d)*X4*X3*X2/2 + (x4*X3*X2 + x3*X2+x2)/2;
116 return &ghostGaugeField[offset*(3*3*2)];
118 j = (x4*X3*X2*X1 + x3*X2*X1 + x2*X1 + new_x1) / 2;
123 int new_x2 = (x2 - d + X2 )% X2;
125 ghostGaugeField = (oddBit?ghostGaugeEven[1]: ghostGaugeOdd[1]);
126 int offset = (n_ghost_faces + x2 -d)*X4*X3*X1/2 + (x4*X3*X1 + x3*X1+x1)/2;
127 return &ghostGaugeField[offset*(3*3*2)];
129 j = (x4*X3*X2*X1 + x3*X2*X1 + new_x2*X1 + x1) / 2;
135 int new_x3 = (x3 - d + X3 )% X3;
137 ghostGaugeField = (oddBit?ghostGaugeEven[2]: ghostGaugeOdd[2]);
138 int offset = (n_ghost_faces + x3 -d)*X4*X2*X1/2 + (x4*X2*X1 + x2*X1+x1)/2;
139 return &ghostGaugeField[offset*(3*3*2)];
141 j = (x4*X3*X2*X1 + new_x3*X2*X1 + x2*X1 + x1) / 2;
146 int new_x4 = (x4 - d + X4)% X4;
148 ghostGaugeField = (oddBit?ghostGaugeEven[3]: ghostGaugeOdd[3]);
149 int offset = (n_ghost_faces + x4 -d)*X1*X2*X3/2 + (x3*X2*X1 + x2*X1+x1)/2;
150 return &ghostGaugeField[offset*(3*3*2)];
152 j = (new_x4*(X3*X2*X1) + x3*(X2*X1) + x2*(X1) + x1) / 2;
156 default: j = -1; printf(
"ERROR: wrong dir \n"); exit(1);
158 gaugeField = (oddBit ? gaugeEven : gaugeOdd);
162 return &gaugeField[dir/2][j*(3*3*2)];
171 {{1,0}, {0,0}, {0,0}, {0,-1}},
172 {{0,0}, {1,0}, {0,-1}, {0,0}},
173 {{0,0}, {0,1}, {1,0}, {0,0}},
174 {{0,1}, {0,0}, {0,0}, {1,0}}
177 {{1,0}, {0,0}, {0,0}, {0,1}},
178 {{0,0}, {1,0}, {0,1}, {0,0}},
179 {{0,0}, {0,-1}, {1,0}, {0,0}},
180 {{0,-1}, {0,0}, {0,0}, {1,0}}
183 {{1,0}, {0,0}, {0,0}, {1,0}},
184 {{0,0}, {1,0}, {-1,0}, {0,0}},
185 {{0,0}, {-1,0}, {1,0}, {0,0}},
186 {{1,0}, {0,0}, {0,0}, {1,0}}
189 {{1,0}, {0,0}, {0,0}, {-1,0}},
190 {{0,0}, {1,0}, {1,0}, {0,0}},
191 {{0,0}, {1,0}, {1,0}, {0,0}},
192 {{-1,0}, {0,0}, {0,0}, {1,0}}
195 {{1,0}, {0,0}, {0,-1}, {0,0}},
196 {{0,0}, {1,0}, {0,0}, {0,1}},
197 {{0,1}, {0,0}, {1,0}, {0,0}},
198 {{0,0}, {0,-1}, {0,0}, {1,0}}
201 {{1,0}, {0,0}, {0,1}, {0,0}},
202 {{0,0}, {1,0}, {0,0}, {0,-1}},
203 {{0,-1}, {0,0}, {1,0}, {0,0}},
204 {{0,0}, {0,1}, {0,0}, {1,0}}
207 {{1,0}, {0,0}, {-1,0}, {0,0}},
208 {{0,0}, {1,0}, {0,0}, {-1,0}},
209 {{-1,0}, {0,0}, {1,0}, {0,0}},
210 {{0,0}, {-1,0}, {0,0}, {1,0}}
213 {{1,0}, {0,0}, {1,0}, {0,0}},
214 {{0,0}, {1,0}, {0,0}, {1,0}},
215 {{1,0}, {0,0}, {1,0}, {0,0}},
216 {{0,0}, {1,0}, {0,0}, {1,0}}
220 {{0,0}, {0,0}, {0,0}, {0,0}},
221 {{0,0}, {0,0}, {0,0}, {0,0}},
222 {{0,0}, {0,0}, {2,0}, {0,0}},
223 {{0,0}, {0,0}, {0,0}, {2,0}}
227 {{2,0}, {0,0}, {0,0}, {0,0}},
228 {{0,0}, {2,0}, {0,0}, {0,0}},
229 {{0,0}, {0,0}, {0,0}, {0,0}},
230 {{0,0}, {0,0}, {0,0}, {0,0}}
236 template <
typename Float>
238 for (
int i=0; i<4*3*2; i++) res[i] = 0.0;
240 for (
int s = 0; s < 4; s++) {
241 for (
int t = 0; t < 4; t++) {
245 for (
int m = 0; m < 3; m++) {
246 Float spinorRe = spinorIn[t*(3*2) + m*(2) + 0];
247 Float spinorIm = spinorIn[t*(3*2) + m*(2) + 1];
248 res[s*(3*2) + m*(2) + 0] += projRe*spinorRe - projIm*spinorIm;
249 res[s*(3*2) + m*(2) + 1] += projRe*spinorIm + projIm*spinorRe;
270 template <QudaPCType type,
typename sFloat,
typename gFloat>
276 for (
int i=0; i<
V5h*4*3*2; i++) res[i] = 0.0;
279 gFloat *gaugeEven[4], *gaugeOdd[4];
282 for (
int dir = 0; dir < 4; dir++) {
283 gaugeEven[dir] = gaugeFull[dir];
288 int sp_idx,gaugeOddBit;
289 for (
int xs=0;xs<
Ls;xs++) {
290 for (
int gge_idx = 0; gge_idx <
Vh; gge_idx++) {
291 for (
int dir = 0; dir < 8; dir++) {
292 sp_idx=gge_idx+
Vh*xs;
298 gaugeOddBit = (xs%2 == 0 || type ==
QUDA_4D_PC) ? oddBit : (oddBit+1) % 2;
299 gFloat *gauge =
gaugeLink_sgpu(gge_idx, dir, gaugeOddBit, gaugeEven, gaugeOdd);
303 sFloat *
spinor = spinorNeighbor_5d<type>(sp_idx, dir, oddBit, spinorField);
304 sFloat projectedSpinor[4*3*2], gaugedSpinor[4*3*2];
305 int projIdx = 2*(dir/2)+(dir+daggerBit)%2;
308 for (
int s = 0; s < 4; s++) {
310 su3Mul(&gaugedSpinor[s*(3*2)], gauge, &projectedSpinor[s*(3*2)]);
312 std::cout <<
"spinor:" << std::endl;
314 std::cout <<
"gauge:" << std::endl;
317 su3Tmul(&gaugedSpinor[s*(3*2)], gauge, &projectedSpinor[s*(3*2)]);
321 sum(&res[sp_idx*(4*3*2)], &res[sp_idx*(4*3*2)], gaugedSpinor, 4*3*2);
328 template <QudaPCType type,
typename sFloat,
typename gFloat>
329 void dslashReference_4d_mgpu(sFloat *res, gFloat **gaugeFull, gFloat **ghostGauge, sFloat *spinorField,
330 sFloat **fwdSpinor, sFloat **backSpinor,
int oddBit,
int daggerBit)
334 gFloat *gaugeEven[4], *gaugeOdd[4];
335 gFloat *ghostGaugeEven[4], *ghostGaugeOdd[4];
337 for (
int dir = 0; dir < 4; dir++)
339 gaugeEven[dir] = gaugeFull[dir];
342 ghostGaugeEven[dir] = ghostGauge[dir];
345 for (
int xs=0;xs<
Ls;xs++)
348 for (
int i = 0; i <
Vh; i++)
351 for (
int dir = 0; dir < 8; dir++)
353 int gaugeOddBit = (xs%2 == 0 || type ==
QUDA_4D_PC) ? oddBit : (oddBit + 1) % 2;
355 gFloat *gauge =
gaugeLink_mgpu(i, dir, gaugeOddBit, gaugeEven, gaugeOdd, ghostGaugeEven, ghostGaugeOdd, 1, 1);
356 sFloat *
spinor = spinorNeighbor_5d_mgpu<type>(sp_idx, dir, oddBit, spinorField, fwdSpinor, backSpinor, 1, 1);
359 int projIdx = 2 * (dir / 2) + (dir + daggerBit) % 2;
362 for (
int s = 0; s < 4; s++) {
364 su3Mul(&gaugedSpinor[s * (3 * 2)], gauge, &projectedSpinor[s * (3 * 2)]);
366 su3Tmul(&gaugedSpinor[s * (3 * 2)], gauge, &projectedSpinor[s * (3 * 2)]);
368 sum(&res[sp_idx * (4 * 3 * 2)], &res[sp_idx * (4 * 3 * 2)], gaugedSpinor, 4 * 3 * 2);
375 template <
bool plus,
class sFloat>
376 void axpby_ssp_project(sFloat *z, sFloat a, sFloat *x, sFloat b, sFloat *y,
int idx_cb_4d,
int s,
int sp)
382 for (
int spin = (
plus ? 0 : 2); spin < (
plus ? 2 : 4); spin++) {
383 for (
int color_comp = 0; color_comp < 6; color_comp++) {
384 z[(s *
Vh + idx_cb_4d) * 24 + spin * 6 + color_comp] = a * x[(s *
Vh + idx_cb_4d) * 24 + spin * 6 + color_comp]
385 + b * y[(sp *
Vh + idx_cb_4d) * 24 + spin * 6 + color_comp];
388 for (
int spin = (
plus ? 2 : 0); spin < (
plus ? 4 : 2); spin++) {
389 for (
int color_comp = 0; color_comp < 6; color_comp++) {
390 z[(s *
Vh + idx_cb_4d) * 24 + spin * 6 + color_comp] = a * x[(s *
Vh + idx_cb_4d) * 24 + spin * 6 + color_comp];
395 template <
typename sFloat>
396 void mdw_eofa_m5_ref(sFloat *res, sFloat *spinorField,
int oddBit,
int daggerBit, sFloat mferm, sFloat
m5, sFloat b,
405 sFloat alpha = b + c;
406 sFloat eofa_norm = alpha * (mq3 - mq2) *
std::pow(alpha + 1., 2 *
Ls)
410 sFloat
kappa = 0.5 * (c * (4. +
m5) - 1.) / (b * (4. +
m5) + 1.);
412 constexpr
int spinor_size = 4 * 3 * 2;
413 for (
int i = 0; i <
V5h; i++) {
414 for (
int one_site = 0; one_site < 24; one_site++) { res[i * spinor_size + one_site] = 0.; }
415 for (
int dir = 8; dir < 10; dir++) {
419 sFloat *
spinor = spinorNeighbor_5d<QUDA_4D_PC>(i, dir, oddBit, spinorField);
420 sFloat projectedSpinor[spinor_size];
421 int projIdx = 2 * (dir / 2) + (dir + daggerBit) % 2;
425 int xs = X / (
Z[3] *
Z[2] *
Z[1] *
Z[0]);
427 if ((xs == 0 && dir == 9) || (xs ==
Ls - 1 && dir == 8)) {
428 ax(projectedSpinor, -mferm, projectedSpinor, spinor_size);
430 sum(&res[i * spinor_size], &res[i * spinor_size], projectedSpinor, spinor_size);
433 axpby((sFloat)1., &spinorField[i * spinor_size],
kappa, &res[i * spinor_size], spinor_size);
437 std::vector<sFloat> shift_coeffs(
Ls);
445 N *= 1. / (b * (
m5 + 4.) + 1.);
446 for (
int s = 0; s <
Ls; s++) {
452 for (
int idx_cb_4d = 0; idx_cb_4d <
Vh; idx_cb_4d++) {
453 for (
int s = 0; s <
Ls; s++) {
454 if (daggerBit == 0) {
456 axpby_ssp_project<true>(res, (sFloat)1., res, shift_coeffs[s], spinorField, idx_cb_4d, s,
Ls - 1);
458 axpby_ssp_project<false>(res, (sFloat)1., res, shift_coeffs[s], spinorField, idx_cb_4d, s, 0);
462 axpby_ssp_project<true>(res, (sFloat)1., res, shift_coeffs[s], spinorField, idx_cb_4d,
Ls - 1, s);
464 axpby_ssp_project<false>(res, (sFloat)1., res, shift_coeffs[s], spinorField, idx_cb_4d, 0, s);
471 void mdw_eofa_m5(
void *res,
void *spinorField,
int oddBit,
int daggerBit,
double mferm,
double m5,
double b,
double c,
475 mdw_eofa_m5_ref<double>((
double *)res, (
double *)spinorField, oddBit, daggerBit, mferm,
m5, b, c, mq1, mq2, mq3,
478 mdw_eofa_m5_ref<float>((
float *)res, (
float *)spinorField, oddBit, daggerBit, mferm,
m5, b, c, mq1, mq2, mq3,
485 template <QudaPCType type,
bool zero_initialize = false,
typename sFloat>
488 for (
int i = 0; i <
V5h; i++) {
489 if (zero_initialize)
for(
int one_site = 0 ; one_site < 24 ; one_site++)
490 res[i*(4*3*2)+one_site] = 0.0;
491 for (
int dir = 8; dir < 10; dir++) {
495 sFloat *
spinor = spinorNeighbor_5d<type>(i, dir, oddBit, spinorField);
496 sFloat projectedSpinor[4*3*2];
497 int projIdx = 2*(dir/2)+(dir+daggerBit)%2;
501 int xs = X/(
Z[3]*
Z[2]*
Z[1]*
Z[0]);
503 if ( (xs == 0 && dir == 9) || (xs ==
Ls-1 && dir == 8) ) {
504 ax(projectedSpinor,(sFloat)(-mferm),projectedSpinor,4*3*2);
506 sum(&res[i*(4*3*2)], &res[i*(4*3*2)], projectedSpinor, 4*3*2);
512 template <
typename sFloat>
515 double *inv_Ftr = (
double*)malloc(
Ls*
sizeof(sFloat));
516 double *Ftr = (
double*)malloc(
Ls*
sizeof(sFloat));
517 for(
int xs = 0 ; xs <
Ls ; xs++)
519 inv_Ftr[xs] = 1.0/(1.0+
pow(2.0*
kappa[xs],
Ls)*mferm);
520 Ftr[xs] = -2.0*
kappa[xs]*mferm*inv_Ftr[xs];
521 for (
int i = 0; i <
Vh; i++) {
522 memcpy(&res[24*(i+
Vh*xs)], &spinorField[24*(i+
Vh*xs)], 24*
sizeof(sFloat));
528 for (
int i = 0; i <
Vh; i++) {
529 ax(&res[12+24*(i+
Vh*(
Ls-1))],(sFloat)(inv_Ftr[0]), &spinorField[12+24*(i+
Vh*(
Ls-1))], 12);
533 for(
int xs = 0 ; xs <=
Ls-2 ; ++xs)
535 for (
int i = 0; i <
Vh; i++) {
536 axpy((sFloat)(2.0*
kappa[xs]), &res[24*(i+
Vh*xs)], &res[24*(i+
Vh*(xs+1))], 12);
537 axpy((sFloat)Ftr[xs], &res[12+24*(i+
Vh*xs)], &res[12+24*(i+
Vh*(
Ls-1))], 12);
539 for (
int tmp_s = 0 ; tmp_s <
Ls ; tmp_s++)
540 Ftr[tmp_s] *= 2.0*
kappa[tmp_s];
542 for(
int xs = 0 ; xs <
Ls ; xs++)
544 Ftr[xs] = -
pow(2.0*
kappa[xs],
Ls-1)*mferm*inv_Ftr[xs];
547 for(
int xs =
Ls-2 ; xs >=0 ; --xs)
549 for (
int i = 0; i <
Vh; i++) {
550 axpy((sFloat)Ftr[xs], &res[24*(i+
Vh*(
Ls-1))], &res[24*(i+
Vh*xs)], 12);
551 axpy((sFloat)(2.0*
kappa[xs]), &res[12+24*(i+
Vh*(xs+1))], &res[12+24*(i+
Vh*xs)], 12);
553 for (
int tmp_s = 0 ; tmp_s <
Ls ; tmp_s++)
554 Ftr[tmp_s] /= 2.0*
kappa[tmp_s];
557 for (
int i = 0; i <
Vh; i++) {
558 ax(&res[24*(i+
Vh*(
Ls-1))], (sFloat)(inv_Ftr[
Ls-1]), &res[24*(i+
Vh*(
Ls-1))], 12);
564 for (
int i = 0; i <
Vh; i++) {
565 ax(&res[24*(i+
Vh*(
Ls-1))],(sFloat)(inv_Ftr[0]), &spinorField[24*(i+
Vh*(
Ls-1))], 12);
569 for(
int xs = 0 ; xs <=
Ls-2 ; ++xs)
571 for (
int i = 0; i <
Vh; i++) {
572 axpy((sFloat)Ftr[xs], &res[24*(i+
Vh*xs)], &res[24*(i+
Vh*(
Ls-1))], 12);
573 axpy((sFloat)(2.0*
kappa[xs]), &res[12+24*(i+
Vh*xs)], &res[12+24*(i+
Vh*(xs+1))], 12);
575 for (
int tmp_s = 0 ; tmp_s <
Ls ; tmp_s++)
576 Ftr[tmp_s] *= 2.0*
kappa[tmp_s];
578 for(
int xs = 0 ; xs <
Ls ; xs++)
580 Ftr[xs] = -
pow(2.0*
kappa[xs],
Ls-1)*mferm*inv_Ftr[xs];
583 for(
int xs =
Ls-2 ; xs >=0 ; --xs)
585 for (
int i = 0; i <
Vh; i++) {
586 axpy((sFloat)(2.0*
kappa[xs]), &res[24*(i+
Vh*(xs+1))], &res[24*(i+
Vh*xs)], 12);
587 axpy((sFloat)Ftr[xs], &res[12+24*(i+
Vh*(
Ls-1))], &res[12+24*(i+
Vh*xs)], 12);
589 for (
int tmp_s = 0 ; tmp_s <
Ls ; tmp_s++)
590 Ftr[tmp_s] /= 2.0*
kappa[tmp_s];
593 for (
int i = 0; i <
Vh; i++) {
594 ax(&res[12+24*(i+
Vh*(
Ls-1))], (sFloat)(inv_Ftr[
Ls-1]), &res[12+24*(i+
Vh*(
Ls-1))], 12);
601 template <
typename sComplex> sComplex
cpow(
const sComplex &x,
int y)
603 static_assert(
sizeof(sComplex) ==
sizeof(
Complex),
"C and C++ complex type sizes do not match");
607 sComplex z =
reinterpret_cast<sComplex &
>(z_);
612 template <
typename sFloat,
typename sComplex>
615 sComplex *inv_Ftr = (sComplex *)malloc(
Ls *
sizeof(sComplex));
616 sComplex *Ftr = (sComplex *)malloc(
Ls *
sizeof(sComplex));
617 for (
int xs = 0; xs <
Ls; xs++) {
618 inv_Ftr[xs] = 1.0 / (1.0 +
cpow(2.0 *
kappa[xs],
Ls) * mferm);
619 Ftr[xs] = -2.0 *
kappa[xs] * mferm * inv_Ftr[xs];
620 for (
int i = 0; i <
Vh; i++) {
621 memcpy(&res[24 * (i +
Vh * xs)], &spinorField[24 * (i +
Vh * xs)], 24 *
sizeof(sFloat));
624 if (daggerBit == 0) {
626 for (
int i = 0; i <
Vh; i++) {
627 ax((sComplex *)&res[12 + 24 * (i +
Vh * (
Ls - 1))], inv_Ftr[0],
628 (sComplex *)&spinorField[12 + 24 * (i +
Vh * (
Ls - 1))], 6);
632 for (
int xs = 0; xs <=
Ls - 2; ++xs) {
633 for (
int i = 0; i <
Vh; i++) {
634 axpy((2.0 *
kappa[xs]), (sComplex *)&res[24 * (i +
Vh * xs)], (sComplex *)&res[24 * (i +
Vh * (xs + 1))], 6);
635 axpy(Ftr[xs], (sComplex *)&res[12 + 24 * (i +
Vh * xs)], (sComplex *)&res[12 + 24 * (i +
Vh * (
Ls - 1))], 6);
637 for (
int tmp_s = 0; tmp_s <
Ls; tmp_s++) Ftr[tmp_s] *= 2.0 *
kappa[tmp_s];
639 for (
int xs = 0; xs <
Ls; xs++) Ftr[xs] = -
cpow(2.0 *
kappa[xs],
Ls - 1) * mferm * inv_Ftr[xs];
642 for (
int xs =
Ls - 2; xs >= 0; --xs) {
643 for (
int i = 0; i <
Vh; i++) {
644 axpy(Ftr[xs], (sComplex *)&res[24 * (i +
Vh * (
Ls - 1))], (sComplex *)&res[24 * (i +
Vh * xs)], 6);
645 axpy((2.0 *
kappa[xs]), (sComplex *)&res[12 + 24 * (i +
Vh * (xs + 1))],
646 (sComplex *)&res[12 + 24 * (i +
Vh * xs)], 6);
648 for (
int tmp_s = 0; tmp_s <
Ls; tmp_s++) Ftr[tmp_s] /= 2.0 *
kappa[tmp_s];
651 for (
int i = 0; i <
Vh; i++) {
652 ax((sComplex *)&res[24 * (i +
Vh * (
Ls - 1))], inv_Ftr[
Ls - 1], (sComplex *)&res[24 * (i +
Vh * (
Ls - 1))], 6);
656 for (
int i = 0; i <
Vh; i++) {
657 ax((sComplex *)&res[24 * (i +
Vh * (
Ls - 1))], inv_Ftr[0], (sComplex *)&spinorField[24 * (i +
Vh * (
Ls - 1))], 6);
661 for (
int xs = 0; xs <=
Ls - 2; ++xs) {
662 for (
int i = 0; i <
Vh; i++) {
663 axpy(Ftr[xs], (sComplex *)&res[24 * (i +
Vh * xs)], (sComplex *)&res[24 * (i +
Vh * (
Ls - 1))], 6);
664 axpy((2.0 *
kappa[xs]), (sComplex *)&res[12 + 24 * (i +
Vh * xs)],
665 (sComplex *)&res[12 + 24 * (i +
Vh * (xs + 1))], 6);
667 for (
int tmp_s = 0; tmp_s <
Ls; tmp_s++) Ftr[tmp_s] *= 2.0 *
kappa[tmp_s];
669 for (
int xs = 0; xs <
Ls; xs++) Ftr[xs] = -
cpow(2.0 *
kappa[xs],
Ls - 1) * mferm * inv_Ftr[xs];
672 for (
int xs =
Ls - 2; xs >= 0; --xs) {
673 for (
int i = 0; i <
Vh; i++) {
674 axpy((2.0 *
kappa[xs]), (sComplex *)&res[24 * (i +
Vh * (xs + 1))], (sComplex *)&res[24 * (i +
Vh * xs)], 6);
675 axpy(Ftr[xs], (sComplex *)&res[12 + 24 * (i +
Vh * (
Ls - 1))], (sComplex *)&res[12 + 24 * (i +
Vh * xs)], 6);
677 for (
int tmp_s = 0; tmp_s <
Ls; tmp_s++) Ftr[tmp_s] /= 2.0 *
kappa[tmp_s];
680 for (
int i = 0; i <
Vh; i++) {
681 ax((sComplex *)&res[12 + 24 * (i +
Vh * (
Ls - 1))], inv_Ftr[
Ls - 1],
682 (sComplex *)&res[12 + 24 * (i +
Vh * (
Ls - 1))], 6);
689 template <
typename sFloat>
690 void mdw_eofa_m5inv_ref(sFloat *res, sFloat *spinorField,
int oddBit,
int daggerBit, sFloat mferm, sFloat
m5, sFloat b,
699 sFloat alpha = b + c;
700 sFloat eofa_norm = alpha * (mq3 - mq2) *
std::pow(alpha + 1., 2 *
Ls)
703 sFloat
kappa5 = (c * (4. +
m5) - 1.) / (b * (4. +
m5) + 1.);
705 using sComplex =
double _Complex;
707 std::vector<sComplex> kappa_array(
Ls, -0.5 *
kappa5);
708 std::vector<sFloat> eofa_u(
Ls);
709 std::vector<sFloat> eofa_x(
Ls);
710 std::vector<sFloat> eofa_y(
Ls);
720 for (
int s = 0; s <
Ls; s++) {
724 sFloat sherman_morrison_fac;
726 sFloat factor = -
kappa5 * mferm;
730 eofa_x[0] = eofa_u[0];
731 for (
int s =
Ls - 1; s > 0; s--) {
732 eofa_x[0] -= factor * eofa_u[s];
735 eofa_x[0] /= 1. + factor;
736 for (
int s = 1; s <
Ls; s++) { eofa_x[s] = eofa_x[s - 1] * (-
kappa5) + eofa_u[s]; }
738 eofa_y[
Ls - 1] = 1. / (1. + factor);
739 sherman_morrison_fac = eofa_x[
Ls - 1];
740 for (
int s =
Ls - 1; s > 0; s--) { eofa_y[s - 1] = eofa_y[s] * (-
kappa5); }
744 eofa_x[
Ls - 1] = eofa_u[
Ls - 1];
745 for (
int s = 0; s <
Ls - 1; s++) {
746 eofa_x[
Ls - 1] -= factor * eofa_u[s];
749 eofa_x[
Ls - 1] /= 1. + factor;
750 for (
int s =
Ls - 1; s > 0; s--) { eofa_x[s - 1] = eofa_x[s] * (-
kappa5) + eofa_u[s - 1]; }
752 eofa_y[0] = 1. / (1. + factor);
753 sherman_morrison_fac = eofa_x[0];
754 for (
int s = 1; s <
Ls; s++) { eofa_y[s] = eofa_y[s - 1] * (-
kappa5); }
756 sherman_morrison_fac = -0.5 / (1. + sherman_morrison_fac);
759 for (
int idx_cb_4d = 0; idx_cb_4d <
Vh; idx_cb_4d++) {
760 for (
int s = 0; s <
Ls; s++) {
761 for (
int sp = 0; sp <
Ls; sp++) {
762 sFloat t = 2.0 * sherman_morrison_fac;
763 if (daggerBit == 0) {
764 t *= eofa_x[s] * eofa_y[sp];
766 axpby_ssp_project<true>(res, (sFloat)1., res, t, spinorField, idx_cb_4d, s, sp);
768 axpby_ssp_project<false>(res, (sFloat)1., res, t, spinorField, idx_cb_4d, s, sp);
771 t *= eofa_y[s] * eofa_x[sp];
773 axpby_ssp_project<true>(res, (sFloat)1., res, t, spinorField, idx_cb_4d, s, sp);
775 axpby_ssp_project<false>(res, (sFloat)1., res, t, spinorField, idx_cb_4d, s, sp);
783 void mdw_eofa_m5inv(
void *res,
void *spinorField,
int oddBit,
int daggerBit,
double mferm,
double m5,
double b,
double c,
787 mdw_eofa_m5inv_ref<double>((
double *)res, (
double *)spinorField, oddBit, daggerBit, mferm,
m5, b, c, mq1, mq2, mq3,
790 mdw_eofa_m5inv_ref<float>((
float *)res, (
float *)spinorField, oddBit, daggerBit, mferm,
m5, b, c, mq1, mq2, mq3,
802 dslashReference_4d_sgpu<QUDA_5D_PC>((
double*)out, (
double**)gauge, (
double*)in, oddBit, daggerBit);
803 dslashReference_5th<QUDA_5D_PC>((
double*)out, (
double*)in, oddBit, daggerBit, mferm);
805 dslashReference_4d_sgpu<QUDA_5D_PC>((
float*)out, (
float**)gauge, (
float*)in, oddBit, daggerBit);
806 dslashReference_5th<QUDA_5D_PC>((
float*)out, (
float*)in, oddBit, daggerBit, (
float)mferm);
813 void **ghostGauge = (
void**)cpu.
Ghost();
822 for (
int d=0; d<4; d++)
csParam.
x[d] =
Z[d];
840 else errorQuda(
"ERROR: full parity not supported in function %s", __FUNCTION__);
849 dslashReference_4d_mgpu<QUDA_5D_PC>((
double*)out, (
double**)gauge, (
double**)ghostGauge, (
double*)in,(
double**)fwd_nbr_spinor, (
double**)back_nbr_spinor, oddBit, daggerBit);
851 dslashReference_5th<QUDA_5D_PC>((
double*)out, (
double*)in, oddBit, daggerBit, mferm);
853 dslashReference_4d_mgpu<QUDA_5D_PC>((
float*)out, (
float**)gauge, (
float**)ghostGauge, (
float*)in,
854 (
float**)fwd_nbr_spinor, (
float**)back_nbr_spinor, oddBit, daggerBit);
855 dslashReference_5th<QUDA_5D_PC>((
float*)out, (
float*)in, oddBit, daggerBit, (
float)mferm);
866 dslashReference_4d_sgpu<QUDA_4D_PC>((
double*)out, (
double**)gauge, (
double*)in, oddBit, daggerBit);
868 dslashReference_4d_sgpu<QUDA_4D_PC>((
float*)out, (
float**)gauge, (
float*)in, oddBit, daggerBit);
875 void **ghostGauge = (
void**)cpu.
Ghost();
884 for (
int d=0; d<4; d++)
csParam.
x[d] =
Z[d];
902 else errorQuda(
"ERROR: full parity not supported in function %s", __FUNCTION__);
910 dslashReference_4d_mgpu<QUDA_4D_PC>((
double*)out, (
double**)gauge, (
double**)ghostGauge, (
double*)in,(
double**)fwd_nbr_spinor, (
double**)back_nbr_spinor, oddBit, daggerBit);
912 dslashReference_4d_mgpu<QUDA_4D_PC>((
float*)out, (
float**)gauge, (
float**)ghostGauge, (
float*)in,
913 (
float**)fwd_nbr_spinor, (
float**)back_nbr_spinor, oddBit, daggerBit);
923 if (zero_initialize) dslashReference_5th<QUDA_4D_PC, true>((
double*)out, (
double*)in, oddBit, daggerBit, mferm);
924 else dslashReference_5th<QUDA_4D_PC, false>((
double*)out, (
double*)in, oddBit, daggerBit, mferm);
926 if (zero_initialize) dslashReference_5th<QUDA_4D_PC, true>((
float*)out, (
float*)in, oddBit, daggerBit, (
float)mferm);
927 else dslashReference_5th<QUDA_4D_PC, false>((
float*)out, (
float*)in, oddBit, daggerBit, (
float)mferm);
954 if (zero_initialize) dslashReference_5th<QUDA_4D_PC,true>((
double*)out, (
double*)in, oddBit, daggerBit, mferm);
955 else dslashReference_5th<QUDA_4D_PC,false>((
double*)out, (
double*)in, oddBit, daggerBit, mferm);
957 if (zero_initialize) dslashReference_5th<QUDA_4D_PC,true>((
float*)out, (
float*)in, oddBit, daggerBit, (
float)mferm);
958 else dslashReference_5th<QUDA_4D_PC,false>((
float*)out, (
float*)in, oddBit, daggerBit, (
float)mferm);
960 for(
int xs = 0 ; xs <
Ls ; xs++) {
970 if (zero_initialize) dslashReference_5th<QUDA_4D_PC, true>((
double*)out, (
double*)in, oddBit, daggerBit, mferm);
971 else dslashReference_5th<QUDA_4D_PC, false>((
double*)out, (
double*)in, oddBit, daggerBit, mferm);
972 for(
int xs = 0 ; xs <
Ls ; xs++)
979 dslashReference_5th<QUDA_4D_PC, true>((
float *)out, (
float *)in, oddBit, daggerBit, (
float)mferm);
980 else dslashReference_5th<QUDA_4D_PC,false>((
float*)out, (
float*)in, oddBit, daggerBit, (
float)mferm);
981 for(
int xs = 0 ; xs <
Ls ; xs++)
1009 void *outEven = out;
1022 void mdw_mat(
void *out,
void **gauge,
void *in,
double _Complex *kappa_b,
double _Complex *kappa_c,
int dagger,
1026 double _Complex *
kappa5 = (
double _Complex *)malloc(
Ls *
sizeof(
double _Complex));
1028 for(
int xs = 0; xs <
Ls ; xs++)
kappa5[xs] = 0.5*kappa_b[xs]/kappa_c[xs];
1032 void *outEven = out;
1036 mdw_dslash_4_pre(
tmp, gauge, inEven, 0,
dagger, precision,
gauge_param, mferm,
b5,
c5,
true);
1041 mdw_dslash_4_pre(outOdd, gauge,
tmp, 0,
dagger, precision,
gauge_param, mferm,
b5,
c5,
true);
1045 for(
int xs = 0 ; xs <
Ls ; xs++) {
1051 mdw_dslash_4_pre(
tmp, gauge, inOdd, 1,
dagger, precision,
gauge_param, mferm,
b5,
c5,
true);
1056 mdw_dslash_4_pre(outEven, gauge,
tmp, 1,
dagger, precision,
gauge_param, mferm,
b5,
c5,
true);
1060 for(
int xs = 0 ; xs <
Ls ; xs++) {
1070 double mferm,
double m5,
double b,
double c,
double mq1,
double mq2,
double mq3,
int eofa_pm,
1075 using sComplex =
double _Complex;
1077 std::vector<sComplex> b_array(
Ls, b);
1078 std::vector<sComplex> c_array(
Ls, c);
1080 auto b5 = b_array.data();
1081 auto c5 = c_array.data();
1083 auto kappa_b = 0.5 / (b * (4. +
m5) + 1.);
1087 void *outEven = out;
1091 mdw_dslash_4_pre(
tmp, gauge, inEven, 0,
dagger, precision,
gauge_param, mferm,
b5,
c5,
true);
1093 mdw_eofa_m5(
tmp, inOdd, 1,
dagger, mferm,
m5, b, c, mq1, mq2, mq3,
eofa_pm,
eofa_shift, precision);
1096 mdw_dslash_4_pre(outOdd, gauge,
tmp, 0,
dagger, precision,
gauge_param, mferm,
b5,
c5,
true);
1097 mdw_eofa_m5(
tmp, inOdd, 1,
dagger, mferm,
m5, b, c, mq1, mq2, mq3,
eofa_pm,
eofa_shift, precision);
1100 for (
int xs = 0; xs <
Ls; xs++) {
1106 mdw_dslash_4_pre(
tmp, gauge, inOdd, 1,
dagger, precision,
gauge_param, mferm,
b5,
c5,
true);
1108 mdw_eofa_m5(
tmp, inEven, 0,
dagger, mferm,
m5, b, c, mq1, mq2, mq3,
eofa_pm,
eofa_shift, precision);
1111 mdw_dslash_4_pre(outEven, gauge,
tmp, 1,
dagger, precision,
gauge_param, mferm,
b5,
c5,
true);
1112 mdw_eofa_m5(
tmp, inEven, 0,
dagger, mferm,
m5, b, c, mq1, mq2, mq3,
eofa_pm,
eofa_shift, precision);
1115 for (
int xs = 0; xs <
Ls; xs++) {
1128 dagger_bit = (dagger_bit == 1) ? 0 : 1;
1157 double *
kappa5 = (
double*)malloc(
Ls*
sizeof(
double));
1158 for(
int xs = 0; xs <
Ls ; xs++)
1162 double *output = (
double*)out;
1170 if (symmetric && !dagger_bit) {
1176 }
else if (symmetric && dagger_bit) {
1194 void mdw_matpc(
void *out,
void **gauge,
void *in,
double _Complex *kappa_b,
double _Complex *kappa_c,
1196 double _Complex *
b5,
double _Complex *
c5)
1199 double _Complex *
kappa5 = (
double _Complex *)malloc(
Ls *
sizeof(
double _Complex));
1200 double _Complex *kappa2 = (
double _Complex *)malloc(
Ls *
sizeof(
double _Complex));
1201 double _Complex *kappa_mdwf = (
double _Complex *)malloc(
Ls *
sizeof(
double _Complex));
1202 for(
int xs = 0; xs <
Ls ; xs++)
1204 kappa5[xs] = 0.5*kappa_b[xs]/kappa_c[xs];
1205 kappa2[xs] = -kappa_b[xs]*kappa_b[xs];
1206 kappa_mdwf[xs] = -
kappa5[xs];
1213 if (symmetric && !
dagger) {
1214 mdw_dslash_4_pre(
tmp, gauge, in,
parity[1],
dagger, precision,
gauge_param, mferm,
b5,
c5,
true);
1217 mdw_dslash_4_pre(out, gauge,
tmp,
parity[0],
dagger, precision,
gauge_param, mferm,
b5,
c5,
true);
1220 for(
int xs = 0 ; xs <
Ls ; xs++) {
1224 }
else if (symmetric &&
dagger) {
1227 mdw_dslash_4_pre(
tmp, gauge, out,
parity[0],
dagger, precision,
gauge_param, mferm,
b5,
c5,
true);
1230 mdw_dslash_4_pre(out, gauge,
tmp,
parity[1],
dagger, precision,
gauge_param, mferm,
b5,
c5,
true);
1231 for(
int xs = 0 ; xs <
Ls ; xs++) {
1235 }
else if (!symmetric && !
dagger) {
1236 mdw_dslash_4_pre(out, gauge, in,
parity[1],
dagger, precision,
gauge_param, mferm,
b5,
c5,
true);
1239 mdw_dslash_4_pre(
tmp, gauge, out,
parity[0],
dagger, precision,
gauge_param, mferm,
b5,
c5,
true);
1242 for(
int xs = 0 ; xs <
Ls ; xs++) {
1246 }
else if (!symmetric &&
dagger) {
1248 mdw_dslash_4_pre(
tmp, gauge, out,
parity[1],
dagger, precision,
gauge_param, mferm,
b5,
c5,
true);
1251 mdw_dslash_4_pre(out, gauge,
tmp,
parity[0],
dagger, precision,
gauge_param, mferm,
b5,
c5,
true);
1253 for(
int xs = 0 ; xs <
Ls ; xs++) {
1273 using sComplex =
double _Complex;
1275 std::vector<sComplex> kappa2_array(
Ls, -0.25 / (b * (4. +
m5) + 1.) / (b * (4. +
m5) + 1.));
1276 std::vector<sComplex> b_array(
Ls, b);
1277 std::vector<sComplex> c_array(
Ls, c);
1279 auto kappa2 = kappa2_array.data();
1280 auto b5 = b_array.data();
1281 auto c5 = c_array.data();
1287 if (symmetric && !
dagger) {
1288 mdw_dslash_4_pre(
tmp, gauge, in,
parity[1],
dagger, precision,
gauge_param, mferm,
b5,
c5,
true);
1290 mdw_eofa_m5inv(
tmp, out,
parity[1],
dagger, mferm,
m5, b, c, mq1, mq2, mq3,
eofa_pm,
eofa_shift, precision);
1291 mdw_dslash_4_pre(out, gauge,
tmp,
parity[0],
dagger, precision,
gauge_param, mferm,
b5,
c5,
true);
1293 mdw_eofa_m5inv(out,
tmp,
parity[0],
dagger, mferm,
m5, b, c, mq1, mq2, mq3,
eofa_pm,
eofa_shift, precision);
1294 for (
int xs = 0; xs <
Ls; xs++) {
1298 }
else if (symmetric &&
dagger) {
1299 mdw_eofa_m5inv(
tmp, in,
parity[1],
dagger, mferm,
m5, b, c, mq1, mq2, mq3,
eofa_pm,
eofa_shift, precision);
1301 mdw_dslash_4_pre(
tmp, gauge, out,
parity[0],
dagger, precision,
gauge_param, mferm,
b5,
c5,
true);
1302 mdw_eofa_m5inv(out,
tmp,
parity[0],
dagger, mferm,
m5, b, c, mq1, mq2, mq3,
eofa_pm,
eofa_shift, precision);
1304 mdw_dslash_4_pre(out, gauge,
tmp,
parity[1],
dagger, precision,
gauge_param, mferm,
b5,
c5,
true);
1305 for (
int xs = 0; xs <
Ls; xs++) {
1309 }
else if (!symmetric && !
dagger) {
1310 mdw_dslash_4_pre(out, gauge, in,
parity[1],
dagger, precision,
gauge_param, mferm,
b5,
c5,
true);
1312 mdw_eofa_m5inv(out,
tmp,
parity[1],
dagger, mferm,
m5, b, c, mq1, mq2, mq3,
eofa_pm,
eofa_shift, precision);
1313 mdw_dslash_4_pre(
tmp, gauge, out,
parity[0],
dagger, precision,
gauge_param, mferm,
b5,
c5,
true);
1315 mdw_eofa_m5(
tmp, in,
parity[0],
dagger, mferm,
m5, b, c, mq1, mq2, mq3,
eofa_pm,
eofa_shift, precision);
1316 for (
int xs = 0; xs <
Ls; xs++) {
1320 }
else if (!symmetric &&
dagger) {
1322 mdw_dslash_4_pre(
tmp, gauge, out,
parity[1],
dagger, precision,
gauge_param, mferm,
b5,
c5,
true);
1323 mdw_eofa_m5inv(out,
tmp,
parity[0],
dagger, mferm,
m5, b, c, mq1, mq2, mq3,
eofa_pm,
eofa_shift, precision);
1325 mdw_dslash_4_pre(out, gauge,
tmp,
parity[0],
dagger, precision,
gauge_param, mferm,
b5,
c5,
true);
1326 mdw_eofa_m5(
tmp, in,
parity[0],
dagger, mferm,
m5, b, c, mq1, mq2, mq3,
eofa_pm,
eofa_shift, precision);
1327 for (
int xs = 0; xs <
Ls; xs++) {
1338 void mdw_mdagm_local(
void *out,
void **gauge,
void *in,
double _Complex *kappa_b,
double _Complex *kappa_c,
1340 double _Complex *
b5,
double _Complex *
c5)
1351 for (
int d = 0; d < 4; d++) {
1352 W[d] =
Z[d] + 2 * R[d];
1353 padded_V *=
Z[d] + 2 * R[d];
1355 int padded_V5 = padded_V *
Ls;
1356 int padded_Vh = padded_V / 2;
1357 int padded_V5h = padded_Vh *
Ls;
1359 static_assert(
sizeof(
char) == 1,
"This code assumes sizeof(char) == 1.");
1361 char *padded_in = (
char *)malloc(padded_V5h *
spinor_site_size * precision);
1363 char *padded_out = (
char *)malloc(padded_V5h *
spinor_site_size * precision);
1365 char *padded_tmp = (
char *)malloc(padded_V5h *
spinor_site_size * precision);
1368 char *in_alias = (
char *)in;
1369 char *out_alias = (
char *)out;
1371 for (
int s = 0; s <
Ls; s++) {
1372 for (
int index_cb_4d = 0; index_cb_4d <
Vh; index_cb_4d++) {
1378 memcpy(&padded_in[
spinor_site_size * precision * (s * padded_Vh + padded_index_cb_4d)],
1384 for (
int d = 0; d < 4; d++) { padded_gauge_param.
X[d] += 2 * R[d]; }
1386 void **padded_gauge_p = (
void **)(padded_gauge->
Gauge_p());
1396 for (
int d = 0; d < 4; d++) {
1402 mdw_matpc(padded_tmp, padded_gauge_p, padded_in, kappa_b, kappa_c,
matpc_type, 0, precision, padded_gauge_param,
1405 mdw_matpc(padded_out, padded_gauge_p, padded_tmp, kappa_b, kappa_c,
matpc_type, 1, precision, padded_gauge_param,
1412 for (
int d = 0; d < 4; d++) {
Z[d] = Z_old[d]; }
1414 for (
int s = 0; s <
Ls; s++) {
1415 for (
int index_cb_4d = 0; index_cb_4d <
Vh; index_cb_4d++) {
1422 &padded_out[
spinor_site_size * precision * (s * padded_Vh + padded_index_cb_4d)],
1431 delete padded_gauge;
1487 void matpc(
void *outEven,
void **gauge,
void *inEven,
double kappa,
QudaGammaBasis gammaBasis
void setPrecision(QudaPrecision precision, QudaPrecision ghost_precision=QUDA_INVALID_PRECISION, bool force_native=false)
QudaFieldOrder fieldOrder
const void ** Ghost() const
void exchangeGhost(QudaParity parity, int nFace, int dagger, const MemoryLocation *pack_destination=nullptr, const MemoryLocation *halo_location=nullptr, bool gdr_send=false, bool gdr_recv=false, QudaPrecision ghost_precision=QUDA_INVALID_PRECISION) const
This is a unified ghost exchange function for doing a complete halo exchange regardless of the type o...
static void * fwdGhostFaceBuffer[QUDA_MAX_DIM]
static void * backGhostFaceBuffer[QUDA_MAX_DIM]
int comm_dim_partitioned(int dim)
void * memset(void *s, int c, size_t n)
cudaColorSpinorField * tmp
cpuColorSpinorField * spinor
QudaGaugeParam gauge_param
void mdw_eofa_m5inv_ref(sFloat *res, sFloat *spinorField, int oddBit, int daggerBit, sFloat mferm, sFloat m5, sFloat b, sFloat c, sFloat mq1, sFloat mq2, sFloat mq3, int eofa_pm, sFloat eofa_shift)
void dw_mat(void *out, void **gauge, void *in, double kappa, int dagger_bit, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm)
void dw_4d_matpc(void *out, void **gauge, void *in, double kappa, QudaMatPCType matpc_type, int dagger_bit, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm)
void mdw_matpc(void *out, void **gauge, void *in, double _Complex *kappa_b, double _Complex *kappa_c, QudaMatPCType matpc_type, int dagger, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm, double _Complex *b5, double _Complex *c5)
void dslashReference_5th(sFloat *res, sFloat *spinorField, int oddBit, int daggerBit, sFloat mferm)
void dw_4d_mat(void *out, void **gauge, void *in, double kappa, int dagger_bit, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm)
void dslash_4_4d(void *out, void **gauge, void *in, int oddBit, int daggerBit, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm)
void mdw_mat(void *out, void **gauge, void *in, double _Complex *kappa_b, double _Complex *kappa_c, int dagger, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm, double _Complex *b5, double _Complex *c5)
void mdw_dslash_4_pre(void *out, void **gauge, void *in, int oddBit, int daggerBit, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm, double _Complex *b5, double _Complex *c5, bool zero_initialize)
void dw_dslash(void *out, void **gauge, void *in, int oddBit, int daggerBit, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm)
void dslashReference_4d_sgpu(sFloat *res, gFloat **gaugeFull, sFloat *spinorField, int oddBit, int daggerBit)
Float * gaugeLink_mgpu(int i, int dir, int oddBit, Float **gaugeEven, Float **gaugeOdd, Float **ghostGaugeEven, Float **ghostGaugeOdd, int n_ghost_faces, int nbr_distance)
void axpby_ssp_project(sFloat *z, sFloat a, sFloat *x, sFloat b, sFloat *y, int idx_cb_4d, int s, int sp)
void dw_dslash_5_4d(void *out, void **gauge, void *in, int oddBit, int daggerBit, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm, bool zero_initialize)
void mdw_dslash_5_inv(void *out, void **gauge, void *in, int oddBit, int daggerBit, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm, double _Complex *kappa)
void dw_matdagmat(void *out, void **gauge, void *in, double kappa, int dagger_bit, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm)
void mdw_eofa_m5_ref(sFloat *res, sFloat *spinorField, int oddBit, int daggerBit, sFloat mferm, sFloat m5, sFloat b, sFloat c, sFloat mq1, sFloat mq2, sFloat mq3, int eofa_pm, sFloat eofa_shift)
void mdw_eofa_m5(void *res, void *spinorField, int oddBit, int daggerBit, double mferm, double m5, double b, double c, double mq1, double mq2, double mq3, int eofa_pm, double eofa_shift, QudaPrecision precision)
Float * gaugeLink_sgpu(int i, int dir, int oddBit, Float **gaugeEven, Float **gaugeOdd)
void mdslashReference_5th_inv(sFloat *res, sFloat *spinorField, int oddBit, int daggerBit, sFloat mferm, sComplex *kappa)
void dslash_5_inv(void *out, void **gauge, void *in, int oddBit, int daggerBit, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm, double *kappa)
void multiplySpinorByDiracProjector5(Float *res, int projIdx, Float *spinorIn)
void mdw_mdagm_local(void *out, void **gauge, void *in, double _Complex *kappa_b, double _Complex *kappa_c, QudaMatPCType matpc_type, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm, double _Complex *b5, double _Complex *c5)
void mdw_eofa_matpc(void *out, void **gauge, void *in, QudaMatPCType matpc_type, int dagger, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm, double m5, double b, double c, double mq1, double mq2, double mq3, int eofa_pm, double eofa_shift)
void matpcdagmatpc(void *out, void **gauge, void *in, double kappa, QudaPrecision sPrecision, QudaPrecision gPrecision, double mferm, QudaMatPCType matpc_type)
void mdw_eofa_m5inv(void *res, void *spinorField, int oddBit, int daggerBit, double mferm, double m5, double b, double c, double mq1, double mq2, double mq3, int eofa_pm, double eofa_shift, QudaPrecision precision)
void dw_matpc(void *out, void **gauge, void *in, double kappa, QudaMatPCType matpc_type, int dagger_bit, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm)
const double projector[10][4][4][2]
void dslashReference_5th_inv(sFloat *res, sFloat *spinorField, int oddBit, int daggerBit, sFloat mferm, double *kappa)
void mdw_dslash_5(void *out, void **gauge, void *in, int oddBit, int daggerBit, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm, double _Complex *kappa, bool zero_initialize)
void matpc(void *outEven, void **gauge, void *inEven, double kappa, QudaMatPCType matpc_type, int dagger_bit, QudaPrecision sPrecision, QudaPrecision gPrecision, double mferm)
sComplex cpow(const sComplex &x, int y)
void matdagmat(void *out, void **gauge, void *in, double kappa, QudaPrecision sPrecision, QudaPrecision gPrecision, double mferm)
int neighborIndex_4d(int i, int oddBit, int dx4, int dx3, int dx2, int dx1)
void mdw_eofa_mat(void *out, void **gauge, void *in, int dagger, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm, double m5, double b, double c, double mq1, double mq2, double mq3, int eofa_pm, double eofa_shift)
enum QudaPrecision_s QudaPrecision
@ QUDA_PARITY_SITE_SUBSET
@ QUDA_DEGRAND_ROSSI_GAMMA_BASIS
@ QUDA_GHOST_EXCHANGE_PAD
@ QUDA_MATPC_ODD_ODD_ASYMMETRIC
@ QUDA_MATPC_EVEN_EVEN_ASYMMETRIC
enum QudaMatPCType_s QudaMatPCType
@ QUDA_EVEN_ODD_SITE_ORDER
@ QUDA_SPACE_SPIN_COLOR_FIELD_ORDER
@ QUDA_REFERENCE_FIELD_CREATE
enum QudaParity_s QudaParity
void cxpay(void *x, double _Complex a, void *y, int length, QudaPrecision precision)
int fullLatticeIndex_5d(int i, int oddBit)
int fullLatticeIndex_4d(int i, int oddBit)
void printSpinorElement(void *spinor, int X, QudaPrecision precision)
void coordinate_from_shrinked_index(int coordinate[4], int shrinked_index, const int shrinked_dim[4], const int shift[4], int parity)
int fullLatticeIndex(int dim[4], int index, int oddBit)
int index_4d_cb_from_coordinate_4d(const int coordinate[4], const int dim[4])
int fullLatticeIndex_5d_4dpc(int i, int oddBit)
void xpay(ColorSpinorField &x, double a, ColorSpinorField &y)
void axpy(double a, ColorSpinorField &x, ColorSpinorField &y)
void axpby(double a, ColorSpinorField &x, double b, ColorSpinorField &y)
void ax(const double &a, GaugeField &u)
Scale the gauge field by the scalar a.
std::complex< double > Complex
cudaGaugeField * createExtendedGauge(cudaGaugeField &in, const int *R, TimeProfile &profile, bool redundant_comms=false, QudaReconstructType recon=QUDA_RECONSTRUCT_INVALID)
__host__ __device__ ValueType pow(ValueType x, ExponentType e)
FloatingPoint< float > Float
__host__ __device__ T sum(const array< T, s > &a)
Main header file for the QUDA library.
QudaGhostExchange ghostExchange
QudaSiteSubset siteSubset