  if (i < 0 || i >= (Z[0] * Z[1] * Z[2] * Z[3] / 2)) { printf("i out of range in neighborIndex_4d\n"); exit(-1); }
  int x4 = X / (Z[2] * Z[1] * Z[0]);
  int x3 = (X / (Z[1] * Z[0])) % Z[2];
  int x2 = (X / Z[0]) % Z[1];
  int x1 = X % Z[0]; // restored from context: x1 is shifted and reused below
  // shift each coordinate, adding the extent before the modulus so the
  // result stays non-negative for negative displacements
  x4 = (x4 + dx4 + Z[3]) % Z[3];
  x3 = (x3 + dx3 + Z[2]) % Z[2];
  x2 = (x2 + dx2 + Z[1]) % Z[1];
  x1 = (x1 + dx1 + Z[0]) % Z[0];
  return (x4 * (Z[2] * Z[1] * Z[0]) + x3 * (Z[1] * Z[0]) + x2 * Z[0] + x1) / 2;
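// ... (end of neighborIndex_4d)

// Example (sketch): the checkerboard index of the forward-t neighbor of site
// i would be neighborIndex_4d(i, oddBit, +1, 0, 0, 0).

// gaugeLink_sgpu: returns a pointer to the 3x3 complex link matrix used when
// hopping from site i in direction dir (single-process build). Forward (even)
// directions read the link stored at the site itself; backward (odd)
// directions read it from the opposite-parity neighbor.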
template <typename Float>
Float *gaugeLink_sgpu(int i, int dir, int oddBit, Float **gaugeEven, Float **gaugeOdd) // signature completed from the declaration
{
  Float **gaugeField;
  int j;
  // ...
  gaugeField = (oddBit ? gaugeOdd : gaugeEven); // forward hop: j = i, link at the site itself
  // ... (backward hops: j comes from neighborIndex_4d with a -1 step in
  // direction dir/2, inside a switch over dir)
  default: j = -1; break;
  // ...
  gaugeField = (oddBit ? gaugeEven : gaugeOdd); // backward hop: opposite parity
  // ...
  return &gaugeField[dir / 2][j * (3 * 3 * 2)]; // 18 reals per 3x3 complex link
}
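// gaugeLink_mgpu: multi-process variant of gaugeLink_sgpu. When a backward
// hop crosses the local lattice boundary, the link is read from the ghost
// (halo) buffers received from the neighboring process instead.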
template <typename Float>
Float *gaugeLink_mgpu(int i, int dir, int oddBit, Float **gaugeEven, Float **gaugeOdd, Float **ghostGaugeEven,
                      Float **ghostGaugeOdd, int n_ghost_faces, int nbr_distance)
{
  Float **gaugeField;
  int j;
  int d = nbr_distance; // restored from context: the hop distance, used as d below
  // ...
  gaugeField = (oddBit ? gaugeOdd : gaugeEven); // forward hop: link at the site itself
  // decompose the full 4-d index Y of site i into coordinates
  int x4 = Y / (Z[2] * Z[1] * Z[0]);
  int x3 = (Y / (Z[1] * Z[0])) % Z[2];
  int x2 = (Y / Z[0]) % Z[1];
  int x1 = Y % Z[0]; // restored from context: x1 is reused below
  // ... (backward hops: switch over dir)
  Float *ghostGaugeField;
  // backward hop in x
  int new_x1 = (x1 - d + X1) % X1;
  if (x1 - d < 0 && comm_dim_partitioned(0)) { // boundary guard restored from context
    ghostGaugeField = (oddBit ? ghostGaugeEven[0] : ghostGaugeOdd[0]);
    int offset = (n_ghost_faces + x1 - d) * X4 * X3 * X2 / 2 + (x4 * X3 * X2 + x3 * X2 + x2) / 2;
    return &ghostGaugeField[offset * (3 * 3 * 2)];
  }
  j = (x4 * X3 * X2 * X1 + x3 * X2 * X1 + x2 * X1 + new_x1) / 2;
  // backward hop in y
  int new_x2 = (x2 - d + X2) % X2;
  if (x2 - d < 0 && comm_dim_partitioned(1)) { // boundary guard restored from context
    ghostGaugeField = (oddBit ? ghostGaugeEven[1] : ghostGaugeOdd[1]);
    int offset = (n_ghost_faces + x2 - d) * X4 * X3 * X1 / 2 + (x4 * X3 * X1 + x3 * X1 + x1) / 2;
    return &ghostGaugeField[offset * (3 * 3 * 2)];
  }
  j = (x4 * X3 * X2 * X1 + x3 * X2 * X1 + new_x2 * X1 + x1) / 2;
  // backward hop in z
  int new_x3 = (x3 - d + X3) % X3;
  if (x3 - d < 0 && comm_dim_partitioned(2)) { // boundary guard restored from context
    ghostGaugeField = (oddBit ? ghostGaugeEven[2] : ghostGaugeOdd[2]);
    int offset = (n_ghost_faces + x3 - d) * X4 * X2 * X1 / 2 + (x4 * X2 * X1 + x2 * X1 + x1) / 2;
    return &ghostGaugeField[offset * (3 * 3 * 2)];
  }
  j = (x4 * X3 * X2 * X1 + new_x3 * X2 * X1 + x2 * X1 + x1) / 2;
  // backward hop in t
  int new_x4 = (x4 - d + X4) % X4;
  if (x4 - d < 0 && comm_dim_partitioned(3)) { // boundary guard restored from context
    ghostGaugeField = (oddBit ? ghostGaugeEven[3] : ghostGaugeOdd[3]);
    int offset = (n_ghost_faces + x4 - d) * X1 * X2 * X3 / 2 + (x3 * X2 * X1 + x2 * X1 + x1) / 2;
    return &ghostGaugeField[offset * (3 * 3 * 2)];
  }
  j = (new_x4 * (X3 * X2 * X1) + x3 * (X2 * X1) + x2 * X1 + x1) / 2;
  default: j = -1; printf("ERROR: wrong dir\n"); exit(1);
  }
  gaugeField = (oddBit ? gaugeEven : gaugeOdd); // backward hop: opposite parity
  // ...
  return &gaugeField[dir / 2][j * (3 * 3 * 2)];
}
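// projector[10][4][4][2]: the ten 4x4 spin matrices applied before the gauge
// link, stored as (re,im) pairs. Entries 0-7 are the Wilson projectors
// (1 +/- gamma_mu) for the eight 4-d hop directions; entries 8 and 9, twice
// the chiral projectors (diagonals 0,0,2,2 and 2,2,0,0 in this basis), serve
// the hops in the fifth dimension.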
const double projector[10][4][4][2] = { // declaration restored from context
  {
    {{1,0}, {0,0}, {0,0}, {0,-1}},
    {{0,0}, {1,0}, {0,-1}, {0,0}},
    {{0,0}, {0,1}, {1,0}, {0,0}},
    {{0,1}, {0,0}, {0,0}, {1,0}}
  },
  {
    {{1,0}, {0,0}, {0,0}, {0,1}},
    {{0,0}, {1,0}, {0,1}, {0,0}},
    {{0,0}, {0,-1}, {1,0}, {0,0}},
    {{0,-1}, {0,0}, {0,0}, {1,0}}
  },
  {
    {{1,0}, {0,0}, {0,0}, {1,0}},
    {{0,0}, {1,0}, {-1,0}, {0,0}},
    {{0,0}, {-1,0}, {1,0}, {0,0}},
    {{1,0}, {0,0}, {0,0}, {1,0}}
  },
  {
    {{1,0}, {0,0}, {0,0}, {-1,0}},
    {{0,0}, {1,0}, {1,0}, {0,0}},
    {{0,0}, {1,0}, {1,0}, {0,0}},
    {{-1,0}, {0,0}, {0,0}, {1,0}}
  },
  {
    {{1,0}, {0,0}, {0,-1}, {0,0}},
    {{0,0}, {1,0}, {0,0}, {0,1}},
    {{0,1}, {0,0}, {1,0}, {0,0}},
    {{0,0}, {0,-1}, {0,0}, {1,0}}
  },
  {
    {{1,0}, {0,0}, {0,1}, {0,0}},
    {{0,0}, {1,0}, {0,0}, {0,-1}},
    {{0,-1}, {0,0}, {1,0}, {0,0}},
    {{0,0}, {0,1}, {0,0}, {1,0}}
  },
  {
    {{1,0}, {0,0}, {-1,0}, {0,0}},
    {{0,0}, {1,0}, {0,0}, {-1,0}},
    {{-1,0}, {0,0}, {1,0}, {0,0}},
    {{0,0}, {-1,0}, {0,0}, {1,0}}
  },
  {
    {{1,0}, {0,0}, {1,0}, {0,0}},
    {{0,0}, {1,0}, {0,0}, {1,0}},
    {{1,0}, {0,0}, {1,0}, {0,0}},
    {{0,0}, {1,0}, {0,0}, {1,0}}
  },
  // fifth-dimension chiral projectors (scaled by 2)
  {
    {{0,0}, {0,0}, {0,0}, {0,0}},
    {{0,0}, {0,0}, {0,0}, {0,0}},
    {{0,0}, {0,0}, {2,0}, {0,0}},
    {{0,0}, {0,0}, {0,0}, {2,0}}
  },
  {
    {{2,0}, {0,0}, {0,0}, {0,0}},
    {{0,0}, {2,0}, {0,0}, {0,0}},
    {{0,0}, {0,0}, {0,0}, {0,0}},
    {{0,0}, {0,0}, {0,0}, {0,0}}
  }
};
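// multiplySpinorByDiracProjector5: res = projector[projIdx] * spinorIn for a
// single site, looping over spin indices s,t and the color index m of the
// 4 x 3 complex spinor.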
template <typename Float>
void multiplySpinorByDiracProjector5(Float *res, int projIdx, Float *spinorIn) // signature completed from the declaration
{
  for (int i = 0; i < 4 * 3 * 2; i++) res[i] = 0.0;

  for (int s = 0; s < 4; s++) {
    for (int t = 0; t < 4; t++) {
      Float projRe = projector[projIdx][s][t][0]; // restored from context
      Float projIm = projector[projIdx][s][t][1];
      for (int m = 0; m < 3; m++) {
        Float spinorRe = spinorIn[t * (3 * 2) + m * 2 + 0];
        Float spinorIm = spinorIn[t * (3 * 2) + m * 2 + 1];
        res[s * (3 * 2) + m * 2 + 0] += projRe * spinorRe - projIm * spinorIm;
        res[s * (3 * 2) + m * 2 + 1] += projRe * spinorIm + projIm * spinorRe;
      }
    }
  }
}
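// (The callers below choose projIdx = 2*(dir/2) + (dir+daggerBit)%2, so
// daggering swaps the +/- projector within each direction pair.)

// dslashReference_4d_sgpu: host reference for the 4-d part of the domain wall
// dslash (single process). For each site and each of the 8 directions it
// gathers the neighboring spinor, applies the spin projector and the gauge
// link, and accumulates into res.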
template <QudaPCType type, typename sFloat, typename gFloat>
void dslashReference_4d_sgpu(sFloat *res, gFloat **gaugeFull, sFloat *spinorField, int oddBit, int daggerBit)
{
  for (int i = 0; i < V5h * 4 * 3 * 2; i++) res[i] = 0.0;

  gFloat *gaugeEven[4], *gaugeOdd[4];
  for (int dir = 0; dir < 4; dir++) {
    gaugeEven[dir] = gaugeFull[dir];
    gaugeOdd[dir] = gaugeFull[dir] + Vh * gaugeSiteSize; // restored from context
  }

  int sp_idx, gaugeOddBit;
  for (int xs = 0; xs < Ls; xs++) {
    for (int gge_idx = 0; gge_idx < Vh; gge_idx++) {
      for (int dir = 0; dir < 8; dir++) {
        sp_idx = gge_idx + Vh * xs;
        // in the 5-d preconditioned operator the gauge parity alternates with
        // the fifth coordinate; in the 4-d case it does not
        gaugeOddBit = (xs % 2 == 0 || type == QUDA_4D_PC) ? oddBit : (oddBit + 1) % 2;
        gFloat *gauge = gaugeLink_sgpu(gge_idx, dir, gaugeOddBit, gaugeEven, gaugeOdd);

        // gather the neighbor, project in spin, then apply U or U^dagger
        sFloat *spinor = spinorNeighbor_5d<type>(sp_idx, dir, oddBit, spinorField);
        sFloat projectedSpinor[4 * 3 * 2], gaugedSpinor[4 * 3 * 2];
        int projIdx = 2 * (dir / 2) + (dir + daggerBit) % 2;
        multiplySpinorByDiracProjector5(projectedSpinor, projIdx, spinor); // restored from context

        for (int s = 0; s < 4; s++) {
          if (dir % 2 == 0) {
            su3Mul(&gaugedSpinor[s * (3 * 2)], gauge, &projectedSpinor[s * (3 * 2)]);
            // debug output (its guard is omitted in this listing):
            // std::cout << "spinor:" << std::endl;
            // std::cout << "gauge:" << std::endl;
          } else {
            su3Tmul(&gaugedSpinor[s * (3 * 2)], gauge, &projectedSpinor[s * (3 * 2)]);
          }
        }
        sum(&res[sp_idx * (4 * 3 * 2)], &res[sp_idx * (4 * 3 * 2)], gaugedSpinor, 4 * 3 * 2);
      }
    }
  }
}
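// dslashReference_4d_mgpu: multi-process variant of the kernel above; gauge
// links and spinor neighbors that cross the local boundary are read from the
// ghost buffers (ghostGauge, fwdSpinor, backSpinor), with one ghost face and
// distance-1 hops.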
template <QudaPCType type, typename sFloat, typename gFloat>
void dslashReference_4d_mgpu(sFloat *res, gFloat **gaugeFull, gFloat **ghostGauge, sFloat *spinorField,
                             sFloat **fwdSpinor, sFloat **backSpinor, int oddBit, int daggerBit)
{
  for (int i = 0; i < V5h * 4 * 3 * 2; i++) res[i] = 0.0; // restored from context

  gFloat *gaugeEven[4], *gaugeOdd[4];
  gFloat *ghostGaugeEven[4], *ghostGaugeOdd[4];
  for (int dir = 0; dir < 4; dir++) {
    gaugeEven[dir] = gaugeFull[dir];
    gaugeOdd[dir] = gaugeFull[dir] + Vh * gaugeSiteSize; // restored from context
    ghostGaugeEven[dir] = ghostGauge[dir];
    ghostGaugeOdd[dir] = ghostGauge[dir] + (faceVolume[dir] / 2) * gaugeSiteSize;
  }
  for (int xs = 0; xs < Ls; xs++) {
    for (int i = 0; i < Vh; i++) {
      int sp_idx = i + Vh * xs; // restored from context
      for (int dir = 0; dir < 8; dir++) {
        int gaugeOddBit = (xs % 2 == 0 || type == QUDA_4D_PC) ? oddBit : (oddBit + 1) % 2;
        gFloat *gauge = gaugeLink_mgpu(i, dir, gaugeOddBit, gaugeEven, gaugeOdd, ghostGaugeEven, ghostGaugeOdd,
                                       1, 1); // one ghost face, distance-1 hops

        sFloat *spinor = spinorNeighbor_5d_mgpu<type>(sp_idx, dir, oddBit, spinorField, fwdSpinor, backSpinor, 1, 1);
        sFloat projectedSpinor[4 * 3 * 2], gaugedSpinor[4 * 3 * 2]; // restored from context
        int projIdx = 2 * (dir / 2) + (dir + daggerBit) % 2;
        multiplySpinorByDiracProjector5(projectedSpinor, projIdx, spinor); // restored from context

        for (int s = 0; s < 4; s++) {
          if (dir % 2 == 0)
            su3Mul(&gaugedSpinor[s * (3 * 2)], gauge, &projectedSpinor[s * (3 * 2)]);
          else
            su3Tmul(&gaugedSpinor[s * (3 * 2)], gauge, &projectedSpinor[s * (3 * 2)]);
        }
        sum(&res[sp_idx * (4 * 3 * 2)], &res[sp_idx * (4 * 3 * 2)], gaugedSpinor, 4 * 3 * 2);
      }
    }
  }
}
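// dslashReference_5th: the hopping term in the fifth dimension. dir 8/9 are
// the forward/backward s-hops; crossing the s boundary multiplies the
// projected spinor by -mferm (the domain wall fermion mass).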
template <QudaPCType type, bool zero_initialize = false, typename sFloat>
void dslashReference_5th(sFloat *res, sFloat *spinorField, int oddBit, int daggerBit, sFloat mferm) // head completed from the declaration
{
  for (int i = 0; i < V5h; i++) {
    if (zero_initialize)
      for (int one_site = 0; one_site < 24; one_site++) res[i * (4 * 3 * 2) + one_site] = 0.0;
    for (int dir = 8; dir < 10; dir++) {
      sFloat *spinor = spinorNeighbor_5d<type>(i, dir, oddBit, spinorField);
      sFloat projectedSpinor[4 * 3 * 2];
      int projIdx = 2 * (dir / 2) + (dir + daggerBit) % 2;
      multiplySpinorByDiracProjector5(projectedSpinor, projIdx, spinor); // restored from context
      // recover the fifth coordinate of site i (index reconstruction restored
      // from context; type selects the 5-d or 4-d PC layout)
      int X = (type == QUDA_5D_PC) ? fullLatticeIndex_5d(i, oddBit) : fullLatticeIndex_5d_4dpc(i, oddBit);
      int xs = X / (Z[3] * Z[2] * Z[1] * Z[0]);
      // wrapping across the wall picks up the domain wall mass -mferm
      if ((xs == 0 && dir == 9) || (xs == Ls - 1 && dir == 8)) {
        ax(projectedSpinor, (sFloat)(-mferm), projectedSpinor, 4 * 3 * 2);
      }
      sum(&res[i * (4 * 3 * 2)], &res[i * (4 * 3 * 2)], projectedSpinor, 4 * 3 * 2);
    }
  }
}
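// dslashReference_5th_inv: applies the inverse of the fifth-dimension hopping
// matrix by a forward and a backward substitution sweep in s, one 12-real
// chirality block at a time; Ftr/inv_Ftr carry the factors that the -mferm
// wall coupling generates across the Ls slices.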
template <typename sFloat>
void dslashReference_5th_inv(sFloat *res, sFloat *spinorField, int oddBit, int daggerBit, sFloat mferm, double *kappa) // head completed from the declaration
{
  double *inv_Ftr = (double *)malloc(Ls * sizeof(double)); // was sizeof(sFloat): under-allocates when sFloat is float
  double *Ftr = (double *)malloc(Ls * sizeof(double));
  for (int xs = 0; xs < Ls; xs++) {
    inv_Ftr[xs] = 1.0 / (1.0 + pow(2.0 * kappa[xs], Ls) * mferm);
    Ftr[xs] = -2.0 * kappa[xs] * mferm * inv_Ftr[xs];
    for (int i = 0; i < Vh; i++) {
      memcpy(&res[24 * (i + Vh * xs)], &spinorField[24 * (i + Vh * xs)], 24 * sizeof(sFloat));
    }
  }
  if (daggerBit == 0) { // branch structure restored from context
    // initialize the s = Ls-1 slice, lower chirality (last 12 reals per site)
    for (int i = 0; i < Vh; i++) {
      ax(&res[12 + 24 * (i + Vh * (Ls - 1))], (sFloat)(inv_Ftr[0]), &spinorField[12 + 24 * (i + Vh * (Ls - 1))], 12);
    }
    // forward substitution in the fifth dimension
    for (int xs = 0; xs <= Ls - 2; ++xs) {
      for (int i = 0; i < Vh; i++) {
        axpy((sFloat)(2.0 * kappa[xs]), &res[24 * (i + Vh * xs)], &res[24 * (i + Vh * (xs + 1))], 12);
        axpy((sFloat)Ftr[xs], &res[12 + 24 * (i + Vh * xs)], &res[12 + 24 * (i + Vh * (Ls - 1))], 12);
      }
      for (int tmp_s = 0; tmp_s < Ls; tmp_s++) Ftr[tmp_s] *= 2.0 * kappa[tmp_s];
    }
    for (int xs = 0; xs < Ls; xs++) {
      Ftr[xs] = -pow(2.0 * kappa[xs], Ls - 1) * mferm * inv_Ftr[xs];
    }
    // backward substitution
    for (int xs = Ls - 2; xs >= 0; --xs) {
      for (int i = 0; i < Vh; i++) {
        axpy((sFloat)Ftr[xs], &res[24 * (i + Vh * (Ls - 1))], &res[24 * (i + Vh * xs)], 12);
        axpy((sFloat)(2.0 * kappa[xs]), &res[12 + 24 * (i + Vh * (xs + 1))], &res[12 + 24 * (i + Vh * xs)], 12);
      }
      for (int tmp_s = 0; tmp_s < Ls; tmp_s++) Ftr[tmp_s] /= 2.0 * kappa[tmp_s];
    }
    // normalize the s = Ls-1 slice
    for (int i = 0; i < Vh; i++) {
      ax(&res[24 * (i + Vh * (Ls - 1))], (sFloat)(inv_Ftr[Ls - 1]), &res[24 * (i + Vh * (Ls - 1))], 12);
    }
  } else {
    // dagger: the same sweeps with the roles of the two chiralities exchanged
    for (int i = 0; i < Vh; i++) {
      ax(&res[24 * (i + Vh * (Ls - 1))], (sFloat)(inv_Ftr[0]), &spinorField[24 * (i + Vh * (Ls - 1))], 12);
    }
    for (int xs = 0; xs <= Ls - 2; ++xs) {
      for (int i = 0; i < Vh; i++) {
        axpy((sFloat)Ftr[xs], &res[24 * (i + Vh * xs)], &res[24 * (i + Vh * (Ls - 1))], 12);
        axpy((sFloat)(2.0 * kappa[xs]), &res[12 + 24 * (i + Vh * xs)], &res[12 + 24 * (i + Vh * (xs + 1))], 12);
      }
      for (int tmp_s = 0; tmp_s < Ls; tmp_s++) Ftr[tmp_s] *= 2.0 * kappa[tmp_s];
    }
    for (int xs = 0; xs < Ls; xs++) {
      Ftr[xs] = -pow(2.0 * kappa[xs], Ls - 1) * mferm * inv_Ftr[xs];
    }
    for (int xs = Ls - 2; xs >= 0; --xs) {
      for (int i = 0; i < Vh; i++) {
        axpy((sFloat)(2.0 * kappa[xs]), &res[24 * (i + Vh * (xs + 1))], &res[24 * (i + Vh * xs)], 12);
        axpy((sFloat)Ftr[xs], &res[12 + 24 * (i + Vh * (Ls - 1))], &res[12 + 24 * (i + Vh * xs)], 12);
      }
      for (int tmp_s = 0; tmp_s < Ls; tmp_s++) Ftr[tmp_s] /= 2.0 * kappa[tmp_s];
    }
    for (int i = 0; i < Vh; i++) {
      ax(&res[12 + 24 * (i + Vh * (Ls - 1))], (sFloat)(inv_Ftr[Ls - 1]), &res[12 + 24 * (i + Vh * (Ls - 1))], 12);
    }
  }
  free(inv_Ftr); // cleanup restored from context
  free(Ftr);
}
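// mdslashReference_5th_inv: complex-kappa (Mobius) variant of the inverse
// above; identical sweep structure, acting on 6 complex numbers per chirality
// and using cpow for the complex kappa factors.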
template <typename sFloat, typename sComplex>
void mdslashReference_5th_inv(sFloat *res, sFloat *spinorField, int oddBit, int daggerBit, sFloat mferm, sComplex *kappa) // head completed from the declaration
{
  sComplex *inv_Ftr = (sComplex *)malloc(Ls * sizeof(sComplex));
  sComplex *Ftr = (sComplex *)malloc(Ls * sizeof(sComplex));
  for (int xs = 0; xs < Ls; xs++) {
    inv_Ftr[xs] = 1.0 / (1.0 + cpow(2.0 * kappa[xs], Ls) * mferm);
    Ftr[xs] = -2.0 * kappa[xs] * mferm * inv_Ftr[xs];
    for (int i = 0; i < Vh; i++) {
      memcpy(&res[24 * (i + Vh * xs)], &spinorField[24 * (i + Vh * xs)], 24 * sizeof(sFloat));
    }
  }
  if (daggerBit == 0) {
    // initialize the s = Ls-1 slice, lower chirality (6 complex components)
    for (int i = 0; i < Vh; i++) {
      ax((sComplex *)&res[12 + 24 * (i + Vh * (Ls - 1))], inv_Ftr[0],
         (sComplex *)&spinorField[12 + 24 * (i + Vh * (Ls - 1))], 6);
    }
    // forward substitution in the fifth dimension
    for (int xs = 0; xs <= Ls - 2; ++xs) {
      for (int i = 0; i < Vh; i++) {
        axpy((2.0 * kappa[xs]), (sComplex *)&res[24 * (i + Vh * xs)], (sComplex *)&res[24 * (i + Vh * (xs + 1))], 6);
        axpy(Ftr[xs], (sComplex *)&res[12 + 24 * (i + Vh * xs)], (sComplex *)&res[12 + 24 * (i + Vh * (Ls - 1))], 6);
      }
      for (int tmp_s = 0; tmp_s < Ls; tmp_s++) Ftr[tmp_s] *= 2.0 * kappa[tmp_s];
    }
    for (int xs = 0; xs < Ls; xs++) Ftr[xs] = -cpow(2.0 * kappa[xs], Ls - 1) * mferm * inv_Ftr[xs];
    // backward substitution
    for (int xs = Ls - 2; xs >= 0; --xs) {
      for (int i = 0; i < Vh; i++) {
        axpy(Ftr[xs], (sComplex *)&res[24 * (i + Vh * (Ls - 1))], (sComplex *)&res[24 * (i + Vh * xs)], 6);
        axpy((2.0 * kappa[xs]), (sComplex *)&res[12 + 24 * (i + Vh * (xs + 1))],
             (sComplex *)&res[12 + 24 * (i + Vh * xs)], 6);
      }
      for (int tmp_s = 0; tmp_s < Ls; tmp_s++) Ftr[tmp_s] /= 2.0 * kappa[tmp_s];
    }
    for (int i = 0; i < Vh; i++) {
      ax((sComplex *)&res[24 * (i + Vh * (Ls - 1))], inv_Ftr[Ls - 1], (sComplex *)&res[24 * (i + Vh * (Ls - 1))], 6);
    }
  } else {
    // dagger: chiralities exchanged
    for (int i = 0; i < Vh; i++) {
      ax((sComplex *)&res[24 * (i + Vh * (Ls - 1))], inv_Ftr[0], (sComplex *)&spinorField[24 * (i + Vh * (Ls - 1))], 6);
    }
    for (int xs = 0; xs <= Ls - 2; ++xs) {
      for (int i = 0; i < Vh; i++) {
        axpy(Ftr[xs], (sComplex *)&res[24 * (i + Vh * xs)], (sComplex *)&res[24 * (i + Vh * (Ls - 1))], 6);
        axpy((2.0 * kappa[xs]), (sComplex *)&res[12 + 24 * (i + Vh * xs)],
             (sComplex *)&res[12 + 24 * (i + Vh * (xs + 1))], 6);
      }
      for (int tmp_s = 0; tmp_s < Ls; tmp_s++) Ftr[tmp_s] *= 2.0 * kappa[tmp_s];
    }
    for (int xs = 0; xs < Ls; xs++) Ftr[xs] = -cpow(2.0 * kappa[xs], Ls - 1) * mferm * inv_Ftr[xs];
    for (int xs = Ls - 2; xs >= 0; --xs) {
      for (int i = 0; i < Vh; i++) {
        axpy((2.0 * kappa[xs]), (sComplex *)&res[24 * (i + Vh * (xs + 1))], (sComplex *)&res[24 * (i + Vh * xs)], 6);
        axpy(Ftr[xs], (sComplex *)&res[12 + 24 * (i + Vh * (Ls - 1))], (sComplex *)&res[12 + 24 * (i + Vh * xs)], 6);
      }
      for (int tmp_s = 0; tmp_s < Ls; tmp_s++) Ftr[tmp_s] /= 2.0 * kappa[tmp_s];
    }
    for (int i = 0; i < Vh; i++) {
      ax((sComplex *)&res[12 + 24 * (i + Vh * (Ls - 1))], inv_Ftr[Ls - 1],
         (sComplex *)&res[12 + 24 * (i + Vh * (Ls - 1))], 6);
    }
  }
  free(inv_Ftr); // cleanup restored from context
  free(Ftr);
}
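// Precision-dispatch wrappers: each routine below casts the void* host fields
// to double or float and invokes the reference kernels above; QUDA_5D_PC vs
// QUDA_4D_PC selects the preconditioned layout at compile time.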
// dw_dslash: one parity application of the 5-d domain wall dslash.
void dw_dslash(void *out, void **gauge, void *in, int oddBit, int daggerBit, QudaPrecision precision,
               QudaGaugeParam &gauge_param, double mferm) // head restored from the declaration
{
#ifndef MULTI_GPU // single-/multi-process split restored from context
  if (precision == QUDA_DOUBLE_PRECISION) {
    dslashReference_4d_sgpu<QUDA_5D_PC>((double *)out, (double **)gauge, (double *)in, oddBit, daggerBit);
    dslashReference_5th<QUDA_5D_PC>((double *)out, (double *)in, oddBit, daggerBit, mferm);
  } else {
    dslashReference_4d_sgpu<QUDA_5D_PC>((float *)out, (float **)gauge, (float *)in, oddBit, daggerBit);
    dslashReference_5th<QUDA_5D_PC>((float *)out, (float *)in, oddBit, daggerBit, (float)mferm);
  }
#else
  // multi-process path: wrap the host fields and exchange ghost zones
  // (most of the setup is omitted in this listing)
  void **ghostGauge = (void **)cpu.Ghost();
  // ...
  for (int d = 0; d < 4; d++) csParam.x[d] = Z[d];
  // ...
  else errorQuda("ERROR: full parity not supported in function %s", __FUNCTION__);
  // ...
  if (precision == QUDA_DOUBLE_PRECISION) { // dispatch guard restored from context
    dslashReference_4d_mgpu<QUDA_5D_PC>((double *)out, (double **)gauge, (double **)ghostGauge, (double *)in,
                                        (double **)fwd_nbr_spinor, (double **)back_nbr_spinor, oddBit, daggerBit);
    dslashReference_5th<QUDA_5D_PC>((double *)out, (double *)in, oddBit, daggerBit, mferm);
  } else {
    dslashReference_4d_mgpu<QUDA_5D_PC>((float *)out, (float **)gauge, (float **)ghostGauge, (float *)in,
                                        (float **)fwd_nbr_spinor, (float **)back_nbr_spinor, oddBit, daggerBit);
    dslashReference_5th<QUDA_5D_PC>((float *)out, (float *)in, oddBit, daggerBit, (float)mferm);
  }
#endif
}
// dslash_4_4d: the 4-d hop alone, in the 4-d preconditioned layout.
void dslash_4_4d(void *out, void **gauge, void *in, int oddBit, int daggerBit, QudaPrecision precision,
                 QudaGaugeParam &gauge_param, double mferm) // head restored from the declaration
{
#ifndef MULTI_GPU
  if (precision == QUDA_DOUBLE_PRECISION) {
    dslashReference_4d_sgpu<QUDA_4D_PC>((double *)out, (double **)gauge, (double *)in, oddBit, daggerBit);
  } else {
    dslashReference_4d_sgpu<QUDA_4D_PC>((float *)out, (float **)gauge, (float *)in, oddBit, daggerBit);
  }
#else
  // multi-process path, mirroring dw_dslash's ghost setup (mostly omitted here)
  void **ghostGauge = (void **)cpu.Ghost();
  // ...
  for (int d = 0; d < 4; d++) csParam.x[d] = Z[d];
  // ...
  else errorQuda("ERROR: full parity not supported in function %s", __FUNCTION__);
  // ...
  if (precision == QUDA_DOUBLE_PRECISION) { // dispatch guard restored from context
    dslashReference_4d_mgpu<QUDA_4D_PC>((double *)out, (double **)gauge, (double **)ghostGauge, (double *)in,
                                        (double **)fwd_nbr_spinor, (double **)back_nbr_spinor, oddBit, daggerBit);
  } else {
    dslashReference_4d_mgpu<QUDA_4D_PC>((float *)out, (float **)gauge, (float **)ghostGauge, (float *)in,
                                        (float **)fwd_nbr_spinor, (float **)back_nbr_spinor, oddBit, daggerBit);
  }
#endif
}
// dw_dslash_5_4d: the 5th-dimension hop alone, optionally zero-initializing out.
void dw_dslash_5_4d(void *out, void **gauge, void *in, int oddBit, int daggerBit, QudaPrecision precision,
                    QudaGaugeParam &gauge_param, double mferm, bool zero_initialize) // head restored from the declaration
{
  if (precision == QUDA_DOUBLE_PRECISION) {
    if (zero_initialize)
      dslashReference_5th<QUDA_4D_PC, true>((double *)out, (double *)in, oddBit, daggerBit, mferm);
    else
      dslashReference_5th<QUDA_4D_PC, false>((double *)out, (double *)in, oddBit, daggerBit, mferm);
  } else {
    if (zero_initialize)
      dslashReference_5th<QUDA_4D_PC, true>((float *)out, (float *)in, oddBit, daggerBit, (float)mferm);
    else
      dslashReference_5th<QUDA_4D_PC, false>((float *)out, (float *)in, oddBit, daggerBit, (float)mferm);
  }
}
// mdw_dslash_5: Mobius 5th-dimension hop, followed by a per-slice scaling
// with the complex kappa array.
void mdw_dslash_5(void *out, void **gauge, void *in, int oddBit, int daggerBit, QudaPrecision precision,
                  QudaGaugeParam &gauge_param, double mferm, double _Complex *kappa, bool zero_initialize) // head restored from the declaration
{
  if (precision == QUDA_DOUBLE_PRECISION) {
    if (zero_initialize)
      dslashReference_5th<QUDA_4D_PC, true>((double *)out, (double *)in, oddBit, daggerBit, mferm);
    else
      dslashReference_5th<QUDA_4D_PC, false>((double *)out, (double *)in, oddBit, daggerBit, mferm);
  } else {
    if (zero_initialize)
      dslashReference_5th<QUDA_4D_PC, true>((float *)out, (float *)in, oddBit, daggerBit, (float)mferm);
    else
      dslashReference_5th<QUDA_4D_PC, false>((float *)out, (float *)in, oddBit, daggerBit, (float)mferm);
  }
  for (int xs = 0; xs < Ls; xs++) {
    // ... (cxpay combines slice xs of in and out with kappa[xs])
  }
}
// mdw_dslash_4_pre: per slice, out = b5[xs]*in + 0.5*c5[xs]*out after the
// 5th-dimension hop has been written into out.
void mdw_dslash_4_pre(void *out, void **gauge, void *in, int oddBit, int daggerBit, QudaPrecision precision,
                      QudaGaugeParam &gauge_param, double mferm, double _Complex *b5, double _Complex *c5,
                      bool zero_initialize) // head restored from the declaration
{
  if (precision == QUDA_DOUBLE_PRECISION) {
    if (zero_initialize)
      dslashReference_5th<QUDA_4D_PC, true>((double *)out, (double *)in, oddBit, daggerBit, mferm);
    else
      dslashReference_5th<QUDA_4D_PC, false>((double *)out, (double *)in, oddBit, daggerBit, mferm);
    for (int xs = 0; xs < Ls; xs++) {
      // ... (axpby of b5[xs]*in and 0.5*c5[xs]*out on slice xs, double precision)
    }
  } else {
    if (zero_initialize)
      dslashReference_5th<QUDA_4D_PC, true>((float *)out, (float *)in, oddBit, daggerBit, (float)mferm);
    else
      dslashReference_5th<QUDA_4D_PC, false>((float *)out, (float *)in, oddBit, daggerBit, (float)mferm);
    for (int xs = 0; xs < Ls; xs++) {
      axpby((float _Complex)(b5[xs]), (float _Complex *)in + Vh * (spinorSiteSize / 2) * xs, // leading args restored from context
            (float _Complex)(0.5 * c5[xs]), (float _Complex *)out + Vh * (spinorSiteSize / 2) * xs,
            Vh * spinorSiteSize / 2);
    }
  }
}
// dw_mat (fragment): apply dw_dslash to each parity half; the omitted lines
// set up the even/odd pointers and finish with xpay(in, -kappa, out, ...).
dw_dslash(outOdd, gauge, inEven, 1, dagger_bit, precision, gauge_param, mferm);
dw_dslash(outEven, gauge, inOdd, 0, dagger_bit, precision, gauge_param, mferm);
// ...
// dw_4d_mat (fragment): 4-d hop and 5th-dimension hop per parity; the
// omitted lines finish with the -kappa diagonal term.
dslash_4_4d(outOdd, gauge, inEven, 1, dagger_bit, precision, gauge_param, mferm);
dw_dslash_5_4d(outOdd, gauge, inOdd, 1, dagger_bit, precision, gauge_param, mferm, false);
dslash_4_4d(outEven, gauge, inOdd, 0, dagger_bit, precision, gauge_param, mferm);
dw_dslash_5_4d(outEven, gauge, inEven, 0, dagger_bit, precision, gauge_param, mferm, false);
// ...
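// mdw_mat: the full Mobius domain wall operator; kappa5 folds kappa_b and
// kappa_c into the effective fifth-dimension hopping coefficient.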
void mdw_mat(void *out, void **gauge, void *in, double _Complex *kappa_b, double _Complex *kappa_c, int dagger,
             QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm, double _Complex *b5,
             double _Complex *c5) // trailing parameters completed from the declaration
{
  double _Complex *kappa5 = (double _Complex *)malloc(Ls * sizeof(double _Complex));
  // ... (tmp allocation and even/odd pointer setup omitted in this listing)
  for (int xs = 0; xs < Ls; xs++) kappa5[xs] = 0.5 * kappa_b[xs] / kappa_c[xs];
  // odd output: b5/c5 pre-application, 4-d hop, then the 5-d pieces
  mdw_dslash_4_pre(tmp, gauge, inEven, 0, dagger, precision, gauge_param, mferm, b5, c5, true);
  dslash_4_4d(outOdd, gauge, tmp, 1, dagger, precision, gauge_param, mferm);
  mdw_dslash_5(tmp, gauge, inOdd, 1, dagger, precision, gauge_param, mferm, kappa5, true);
  // ...
  for (int xs = 0; xs < Ls; xs++) {
    // ... (combine the two pieces slice by slice)
  }
  // even output: the same pattern with parities swapped
  mdw_dslash_4_pre(tmp, gauge, inOdd, 1, dagger, precision, gauge_param, mferm, b5, c5, true);
  dslash_4_4d(outEven, gauge, tmp, 0, dagger, precision, gauge_param, mferm);
  mdw_dslash_5(tmp, gauge, inEven, 0, dagger, precision, gauge_param, mferm, kappa5, true);
  // ...
  for (int xs = 0; xs < Ls; xs++) {
    // ... (combine the two pieces slice by slice)
  }
  // ... (free kappa5 and tmp)
}
// dw_matdagmat (fragment): out = D^dagger D in, flipping the dagger bit
// between the two applications.
dw_mat(tmp, gauge, in, kappa, dagger_bit, precision, gauge_param, mferm);
dagger_bit = (dagger_bit == 1) ? 0 : 1;
dw_mat(out, gauge, tmp, kappa, dagger_bit, precision, gauge_param, mferm);
// dw_matpc (fragment): even-odd preconditioned operator; two dw_dslash
// applications followed by the -kappa^2 diagonal term.
if (matpc_type == QUDA_MATPC_EVEN_EVEN || matpc_type == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC) { // guard restored from context
  dw_dslash(tmp, gauge, in, 1, dagger_bit, precision, gauge_param, mferm);
  dw_dslash(out, gauge, tmp, 0, dagger_bit, precision, gauge_param, mferm);
} else {
  dw_dslash(tmp, gauge, in, 0, dagger_bit, precision, gauge_param, mferm);
  dw_dslash(out, gauge, tmp, 1, dagger_bit, precision, gauge_param, mferm);
}
double kappa2 = -kappa * kappa;
// ... (xpay(in, kappa2, out, V5h*spinorSiteSize, precision))
// dw_4d_matpc: 4-d preconditioned operator built from dslash_4_4d and the
// fifth-dimension inverse.
double kappa2 = -kappa * kappa;
double *kappa5 = (double *)malloc(Ls * sizeof(double));
for (int xs = 0; xs < Ls; xs++)
  kappa5[xs] = kappa; // restored from context: uniform kappa on every slice
// ...
double *output = (double *)out;
if (symmetric && !dagger_bit) {
  dslash_4_4d(tmp, gauge, in, parity[0], dagger_bit, precision, gauge_param, mferm);
  dslash_5_inv(out, gauge, tmp, parity[0], dagger_bit, precision, gauge_param, mferm, kappa5);
  dslash_4_4d(tmp, gauge, out, parity[1], dagger_bit, precision, gauge_param, mferm);
  dslash_5_inv(out, gauge, tmp, parity[1], dagger_bit, precision, gauge_param, mferm, kappa5);
  xpay(in, kappa2, out, V5h * spinorSiteSize, precision);
} else if (symmetric && dagger_bit) {
  dslash_5_inv(tmp, gauge, in, parity[1], dagger_bit, precision, gauge_param, mferm, kappa5);
  dslash_4_4d(out, gauge, tmp, parity[0], dagger_bit, precision, gauge_param, mferm);
  dslash_5_inv(tmp, gauge, out, parity[0], dagger_bit, precision, gauge_param, mferm, kappa5);
  dslash_4_4d(out, gauge, tmp, parity[1], dagger_bit, precision, gauge_param, mferm);
  xpay(in, kappa2, out, V5h * spinorSiteSize, precision);
} else { // asymmetric ("else" restored from context)
  dslash_4_4d(tmp, gauge, in, parity[0], dagger_bit, precision, gauge_param, mferm);
  dslash_5_inv(out, gauge, tmp, parity[0], dagger_bit, precision, gauge_param, mferm, kappa5);
  dslash_4_4d(tmp, gauge, out, parity[1], dagger_bit, precision, gauge_param, mferm);
  xpay(in, kappa2, tmp, V5h * spinorSiteSize, precision);
  dw_dslash_5_4d(out, gauge, in, parity[1], dagger_bit, precision, gauge_param, mferm, true);
  xpay(tmp, -kappa, out, V5h * spinorSiteSize, precision);
}
// ... (free kappa5 and tmp)
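// mdw_matpc: Mobius even-odd preconditioned operator. kappa5 feeds the 5-d
// hop, kappa2 = -kappa_b^2 enters the Schur-complement diagonal, and
// kappa_mdwf drives the fifth-dimension inverse; the four branches cover
// symmetric/asymmetric preconditioning with and without dagger.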
void mdw_matpc(void *out, void **gauge, void *in, double _Complex *kappa_b, double _Complex *kappa_c,
               QudaMatPCType matpc_type, int dagger, QudaPrecision precision, QudaGaugeParam &gauge_param,
               double mferm, double _Complex *b5, double _Complex *c5) // middle parameters completed from the declaration
{
  double _Complex *kappa5 = (double _Complex *)malloc(Ls * sizeof(double _Complex));
  double _Complex *kappa2 = (double _Complex *)malloc(Ls * sizeof(double _Complex));
  double _Complex *kappa_mdwf = (double _Complex *)malloc(Ls * sizeof(double _Complex));
  for (int xs = 0; xs < Ls; xs++) {
    kappa5[xs] = 0.5 * kappa_b[xs] / kappa_c[xs];
    kappa2[xs] = -kappa_b[xs] * kappa_b[xs];
    kappa_mdwf[xs] = -kappa5[xs];
  }
  // ... (tmp allocation and parity setup omitted in this listing)
  if (symmetric && !dagger) {
    mdw_dslash_4_pre(tmp, gauge, in, parity[1], dagger, precision, gauge_param, mferm, b5, c5, true);
    dslash_4_4d(out, gauge, tmp, parity[0], dagger, precision, gauge_param, mferm);
    mdw_dslash_5_inv(tmp, gauge, out, parity[1], dagger, precision, gauge_param, mferm, kappa_mdwf);
    mdw_dslash_4_pre(out, gauge, tmp, parity[0], dagger, precision, gauge_param, mferm, b5, c5, true);
    dslash_4_4d(tmp, gauge, out, parity[1], dagger, precision, gauge_param, mferm);
    mdw_dslash_5_inv(out, gauge, tmp, parity[0], dagger, precision, gauge_param, mferm, kappa_mdwf);
    for (int xs = 0; xs < Ls; xs++) {
      // ... (cxpay combines in and out with kappa2[xs] on slice xs)
    }
  } else if (symmetric && dagger) {
    mdw_dslash_5_inv(tmp, gauge, in, parity[1], dagger, precision, gauge_param, mferm, kappa_mdwf);
    dslash_4_4d(out, gauge, tmp, parity[0], dagger, precision, gauge_param, mferm);
    mdw_dslash_4_pre(tmp, gauge, out, parity[0], dagger, precision, gauge_param, mferm, b5, c5, true);
    mdw_dslash_5_inv(out, gauge, tmp, parity[0], dagger, precision, gauge_param, mferm, kappa_mdwf);
    dslash_4_4d(tmp, gauge, out, parity[1], dagger, precision, gauge_param, mferm);
    mdw_dslash_4_pre(out, gauge, tmp, parity[1], dagger, precision, gauge_param, mferm, b5, c5, true);
    for (int xs = 0; xs < Ls; xs++) {
      // ... (cxpay combines in and out with kappa2[xs] on slice xs)
    }
  } else if (!symmetric && !dagger) {
    mdw_dslash_4_pre(out, gauge, in, parity[1], dagger, precision, gauge_param, mferm, b5, c5, true);
    dslash_4_4d(tmp, gauge, out, parity[0], dagger, precision, gauge_param, mferm);
    mdw_dslash_5_inv(out, gauge, tmp, parity[1], dagger, precision, gauge_param, mferm, kappa_mdwf);
    mdw_dslash_4_pre(tmp, gauge, out, parity[0], dagger, precision, gauge_param, mferm, b5, c5, true);
    dslash_4_4d(out, gauge, tmp, parity[1], dagger, precision, gauge_param, mferm);
    mdw_dslash_5(tmp, gauge, in, parity[0], dagger, precision, gauge_param, mferm, kappa5, true);
    for (int xs = 0; xs < Ls; xs++) {
      // ... (combine out and tmp with kappa2[xs] on slice xs)
    }
  } else if (!symmetric && dagger) {
    dslash_4_4d(out, gauge, in, parity[0], dagger, precision, gauge_param, mferm);
    mdw_dslash_4_pre(tmp, gauge, out, parity[1], dagger, precision, gauge_param, mferm, b5, c5, true);
    mdw_dslash_5_inv(out, gauge, tmp, parity[0], dagger, precision, gauge_param, mferm, kappa_mdwf);
    dslash_4_4d(tmp, gauge, out, parity[1], dagger, precision, gauge_param, mferm);
    mdw_dslash_4_pre(out, gauge, tmp, parity[0], dagger, precision, gauge_param, mferm, b5, c5, true);
    mdw_dslash_5(tmp, gauge, in, parity[0], dagger, precision, gauge_param, mferm, kappa5, true);
    for (int xs = 0; xs < Ls; xs++) {
      // ... (combine out and tmp with kappa2[xs] on slice xs)
    }
  } else {
    errorQuda("Unsupported matpc_type=%d dagger=%d", matpc_type, dagger);
  }
  // ... (free kappa5, kappa2, kappa_mdwf, and tmp)
}
void matpc(void *outEven, void **gauge, void *inEven, double kappa, QudaMatPCType matpc_type, int dagger_bit,
           QudaPrecision sPrecision, QudaPrecision gPrecision, double mferm); // body not included in this listing