13 namespace fermion_force {
29 FORCE_LEPAGE_MIDDLE_LINK,
31 FORCE_SIDE_LINK_SHORT,
/**
   Map a direction index onto its partner pointing the opposite way.
   Directions d and 7-d form a forward/backward pair.
   @param dir  Direction index
   @return     7 - dir
 */
__host__ __device__ inline constexpr int opp_dir(int dir)
{
  return 7 - dir;
}
/**
   Test whether a direction index is one of the forward directions.
   @param dir  Direction index
   @return     Non-zero when dir is at most 3, zero otherwise
 */
__host__ __device__ inline constexpr int goes_forward(int dir)
{
  return dir < 4;
}
/**
   Test whether a direction index is one of the backward directions.
   @param dir  Direction index
   @return     Non-zero when dir is greater than 3, zero otherwise
 */
__host__ __device__ inline constexpr int goes_backward(int dir)
{
  return 3 < dir;
}
/**
   Sign factor determined by the combined parity of two flags.
   @param pos_dir      Direction flag (callers pass a 0/1 "is forward" value)
   @param odd_lattice  Parity flag
   @return             +1 when pos_dir + odd_lattice is even, -1 when it is odd
 */
__host__ __device__ inline constexpr int CoeffSign(int pos_dir, int odd_lattice)
{
  return ((pos_dir + odd_lattice) & 1) ? -1 : 1;
}
/**
   Parity-dependent sign.
   @param parity  Parity value; any non-zero value counts as odd
   @return        +1 for parity 0, -1 otherwise
 */
__host__ __device__ inline constexpr int Sign(int parity)
{
  return 1 - 2 * (parity != 0);
}
/**
   Fold a direction index onto its forward representative.
   @param dir  Direction index
   @return     dir itself when below 4, otherwise the partner 7 - dir
 */
__host__ __device__ inline constexpr int posDir(int dir)
{
  return dir < 4 ? dir : 7 - dir;
}
45 template <
int dir,
typename Arg>
46 inline __device__ __host__
void updateCoords(
int x[],
int shift,
const Arg &
arg) {
50 template <
typename Arg>
51 inline __device__ __host__
void updateCoords(
int x[],
int dir,
int shift,
const Arg &
arg) {
53 case 0: updateCoords<0>(
x,
shift,
arg);
break;
54 case 1: updateCoords<1>(
x,
shift,
arg);
break;
55 case 2: updateCoords<2>(
x,
shift,
arg);
break;
56 case 3: updateCoords<3>(
x,
shift,
arg);
break;
61 template <
typename real>
70 : one(path_coeff_array[0]), naik(path_coeff_array[1]),
71 three(path_coeff_array[2]), five(path_coeff_array[3]),
72 seven(path_coeff_array[4]), lepage(path_coeff_array[5]) { }
75 template <
typename real, QudaReconstructType reconstruct=QUDA_RECONSTRUCT_NO>
77 typedef typename gauge_mapper<real,reconstruct>::type G;
95 BaseForceArg(
const GaugeField &link,
int overlap) : link(link), threads(1),
98 for (
int d=0;
d<4;
d++) {
100 border[
d] = link.R()[
d];
101 X[
d] =
E[
d] - 2*border[
d];
107 oddness_change = (base_idx[0] + base_idx[1] + base_idx[2] + base_idx[3])&1;
111 template <
typename real, QudaReconstructType reconstruct=QUDA_RECONSTRUCT_NO>
112 struct FatLinkArg :
public BaseForceArg<real,reconstruct> {
114 typedef typename gauge_mapper<real,QUDA_RECONSTRUCT_NO>::type F;
125 const real accumu_coeff;
/**
   Constructor for the FORCE_ONE_LINK term.

   The force field is bound to both output accessors. The outer-product
   field is bound to every input accessor. None of the optional
   P_mu / Q_mu / Q_prev flags is enabled.

   @param force  Field bound to outA and outB
   @param oProd  Field bound to pMu, p3, qMu, oProd, qProd and qPrev
   @param link   Gauge field forwarded to BaseForceArg with zero overlap
   @param coeff  Coefficient stored in arg.coeff
   @param type   Must be FORCE_ONE_LINK; any other value triggers errorQuda
 */
FatLinkArg(GaugeField &force, const GaugeField &oProd, const GaugeField &link,
           real coeff, HisqForceType type)
  : BaseForceArg<real,reconstruct>(link, 0),
    outA(force),
    outB(force),
    pMu(oProd),
    p3(oProd),
    qMu(oProd),
    oProd(oProd),
    qProd(oProd),
    qPrev(oProd),
    coeff(coeff),
    accumu_coeff(0),
    p_mu(false),
    q_mu(false),
    q_prev(false)
{
  // Only the one-link term is allowed to use this overload.
  if (type == FORCE_ONE_LINK) return;
  errorQuda("This constructor is for FORCE_ONE_LINK");
}
/**
   Constructor for FORCE_MIDDLE_LINK when a previous Q field is supplied.

   All three optional flags (p_mu, q_mu, q_prev) are enabled.

   @param newOprod  Field bound to outA and outB
   @param pMu       Field bound to the pMu accessor
   @param P3        Field bound to the p3 accessor
   @param qMu       Field bound to the qMu accessor
   @param oProd     Field bound to the oProd and qProd accessors
   @param qPrev     Field bound to the qPrev accessor
   @param link      Gauge field forwarded to BaseForceArg
   @param coeff     Coefficient stored in arg.coeff
   @param overlap   Overlap forwarded to BaseForceArg
   @param type      Must be FORCE_MIDDLE_LINK; any other value triggers errorQuda
 */
FatLinkArg(GaugeField &newOprod, GaugeField &pMu, GaugeField &P3, GaugeField &qMu,
           const GaugeField &oProd, const GaugeField &qPrev, const GaugeField &link,
           real coeff, int overlap, HisqForceType type)
  : BaseForceArg<real,reconstruct>(link, overlap),
    outA(newOprod),
    outB(newOprod),
    pMu(pMu),
    p3(P3),
    qMu(qMu),
    oProd(oProd),
    qProd(oProd),
    qPrev(qPrev),
    coeff(coeff),
    accumu_coeff(0),
    p_mu(true),
    q_mu(true),
    q_prev(true)
{
  // Reject any force type other than the middle-link term.
  if (type == FORCE_MIDDLE_LINK) return;
  errorQuda("This constructor is for FORCE_MIDDLE_LINK");
}
/**
   Constructor for FORCE_MIDDLE_LINK without a previous Q field.

   The qMu field is also bound to the qPrev accessor, and the q_prev flag
   is disabled; p_mu and q_mu are enabled.

   @param newOprod  Field bound to outA and outB
   @param pMu       Field bound to the pMu accessor
   @param P3        Field bound to the p3 accessor
   @param qMu       Field bound to the qMu and qPrev accessors
   @param oProd     Field bound to the oProd and qProd accessors
   @param link      Gauge field forwarded to BaseForceArg
   @param coeff     Coefficient stored in arg.coeff
   @param overlap   Overlap forwarded to BaseForceArg
   @param type      Must be FORCE_MIDDLE_LINK; any other value triggers errorQuda
 */
FatLinkArg(GaugeField &newOprod, GaugeField &pMu, GaugeField &P3, GaugeField &qMu,
           const GaugeField &oProd, const GaugeField &link,
           real coeff, int overlap, HisqForceType type)
  : BaseForceArg<real,reconstruct>(link, overlap),
    outA(newOprod),
    outB(newOprod),
    pMu(pMu),
    p3(P3),
    qMu(qMu),
    oProd(oProd),
    qProd(oProd),
    qPrev(qMu),
    coeff(coeff),
    accumu_coeff(0),
    p_mu(true),
    q_mu(true),
    q_prev(false)
{
  // Reject any force type other than the middle-link term.
  if (type == FORCE_MIDDLE_LINK) return;
  errorQuda("This constructor is for FORCE_MIDDLE_LINK");
}
/**
   Constructor for the FORCE_LEPAGE_MIDDLE_LINK term.

   P3 is bound to both the pMu and p3 accessors, and qPrev to both the qMu
   and qPrev accessors. Only the q_prev flag is enabled.

   @param newOprod  Field bound to outA and outB
   @param P3        Field bound to the pMu and p3 accessors
   @param oProd     Field bound to the oProd and qProd accessors
   @param qPrev     Field bound to the qMu and qPrev accessors
   @param link      Gauge field forwarded to BaseForceArg
   @param coeff     Coefficient stored in arg.coeff
   @param overlap   Overlap forwarded to BaseForceArg
   @param type      Must be FORCE_LEPAGE_MIDDLE_LINK; any other value triggers errorQuda
 */
FatLinkArg(GaugeField &newOprod, GaugeField &P3, const GaugeField &oProd,
           const GaugeField &qPrev, const GaugeField &link,
           real coeff, int overlap, HisqForceType type)
  : BaseForceArg<real,reconstruct>(link, overlap), outA(newOprod), outB(newOprod), pMu(P3), p3(P3), qMu(qPrev),
    oProd(oProd), qProd(oProd), qPrev(qPrev),
    coeff(coeff), accumu_coeff(0), p_mu(false), q_mu(false), q_prev(true)
{
  // The diagnostic must name the type this overload actually accepts
  // (it used to say FORCE_MIDDLE_LINK, which has its own overloads).
  if (type != FORCE_LEPAGE_MIDDLE_LINK) errorQuda("This constructor is for FORCE_LEPAGE_MIDDLE_LINK");
}
/**
   Constructor for the FORCE_SIDE_LINK term.

   newOprod and shortP are the two outputs (outA and outB). P3 feeds the
   pMu and p3 accessors, and qProd feeds qMu, oProd, qProd and qPrev.
   None of the optional P_mu / Q_mu / Q_prev flags is enabled.

   @param newOprod      Field bound to outA
   @param shortP        Field bound to outB
   @param P3            Field bound to the pMu and p3 accessors
   @param qProd         Field bound to the qMu, oProd, qProd and qPrev accessors
   @param link          Gauge field forwarded to BaseForceArg
   @param coeff         Coefficient stored in arg.coeff
   @param accumu_coeff  Coefficient stored in arg.accumu_coeff
   @param overlap       Overlap forwarded to BaseForceArg
   @param type          Must be FORCE_SIDE_LINK; any other value triggers errorQuda
 */
FatLinkArg(GaugeField &newOprod, GaugeField &shortP, const GaugeField &P3,
           const GaugeField &qProd, const GaugeField &link, real coeff, real accumu_coeff,
           int overlap, HisqForceType type)
  : BaseForceArg<real,reconstruct>(link, overlap), outA(newOprod), outB(shortP), pMu(P3), p3(P3),
    qMu(qProd), oProd(qProd), qProd(qProd), qPrev(qProd),
    coeff(coeff), accumu_coeff(accumu_coeff),
    p_mu(false), q_mu(false), q_prev(false)
{
  // Only FORCE_SIDE_LINK is accepted here; FORCE_ALL_LINK has a dedicated
  // overload, so the message no longer claims this one serves it too.
  if (type != FORCE_SIDE_LINK) errorQuda("This constructor is for FORCE_SIDE_LINK");
}
165 FatLinkArg(GaugeField &newOprod, GaugeField &
P3,
const GaugeField &link,
166 real
coeff,
int overlap, HisqForceType type)
167 : BaseForceArg<real,reconstruct>(link, overlap), outA(newOprod), outB(newOprod),
169 p_mu(false), q_mu(false), q_prev(false)
170 {
if (type != FORCE_SIDE_LINK_SHORT)
errorQuda(
"This constructor is for FORCE_SIDE_LINK_SHORT"); }
/**
   Constructor for the FORCE_ALL_LINK term.

   The trailing bool exists only to distinguish this overload from the
   side-link constructor, which otherwise has the same parameter list; its
   value is unused. The mem-initializers are listed in the same order as in
   the sibling constructors. Actual initialization order always follows the
   member declarations, so this reordering has no runtime effect.

   @param newOprod      Field bound to outA
   @param shortP        Field bound to outB and to the pMu and p3 accessors
   @param oProd         Field bound to the oProd accessor
   @param qPrev         Field bound to the qMu, qProd and qPrev accessors
   @param link          Gauge field forwarded to BaseForceArg
   @param coeff         Coefficient stored in arg.coeff
   @param accumu_coeff  Coefficient stored in arg.accumu_coeff
   @param overlap       Overlap forwarded to BaseForceArg
   @param type          Must be FORCE_ALL_LINK; any other value triggers errorQuda
 */
FatLinkArg(GaugeField &newOprod, GaugeField &shortP, const GaugeField &oProd, const GaugeField &qPrev,
           const GaugeField &link, real coeff, real accumu_coeff, int overlap, HisqForceType type,
           bool /* dummy: overload tag only */)
  : BaseForceArg<real,reconstruct>(link, overlap),
    outA(newOprod),
    outB(shortP),
    pMu(shortP),
    p3(shortP),
    qMu(qPrev),
    oProd(oProd),
    qProd(qPrev),
    qPrev(qPrev),
    coeff(coeff),
    accumu_coeff(accumu_coeff),
    p_mu(false),
    q_mu(false),
    q_prev(false)
{
  if (type == FORCE_ALL_LINK) return;
  errorQuda("This constructor is for FORCE_ALL_LINK");
}
181 template <
typename real,
typename Arg>
182 __global__
void oneLinkTermKernel(Arg
arg)
185 int x_cb = blockIdx.x *
blockDim.x + threadIdx.x;
186 if (x_cb >=
arg.threads)
return;
188 int sig = blockIdx.z *
blockDim.z + threadIdx.z;
189 if (sig >= 4)
return;
194 for (
int d=0;
d<4;
d++)
x[
d] +=
arg.border[
d];
198 Link force =
arg.outA(sig, e_cb,
parity);
199 force +=
arg.coeff *
w;
227 template<
typename real,
int sig_positive,
int mu_positive,
typename Arg>
228 __global__
void allLinkKernel(Arg
arg)
232 int x_cb = blockIdx.x *
blockDim.x + threadIdx.x;
233 if (x_cb >=
arg.threads)
return;
238 for (
int d=0;
d<4;
d++)
x[
d] +=
arg.base_idx[
d];
242 real mycoeff = CoeffSign(sig_positive,
parity)*
arg.coeff;
244 int y[4] = {
x[0],
x[1],
x[2],
x[3]};
245 int mysig = posDir(
arg.sig);
246 updateCoords(
y, mysig, (sig_positive ? 1 : -1),
arg);
248 int ab_link_nbr_idx = (sig_positive) ? e_cb : point_b;
250 for (
int d=0;
d<4;
d++)
y[
d] =
x[
d];
261 int mu = mu_positive ?
arg.mu : opp_dir(
arg.mu);
262 int dir = mu_positive ? -1 : 1;
264 updateCoords(
y,
mu, dir,
arg);
266 updateCoords(
y, mysig, (sig_positive ? 1 : -1),
arg);
269 Link Uab =
arg.link(posDir(
arg.sig), ab_link_nbr_idx, sig_positive^(1-
parity));
270 Link Uad =
arg.link(
mu, mu_positive ? point_d : e_cb, mu_positive ? 1-
parity :
parity);
271 Link Ubc =
arg.link(
mu, mu_positive ? point_c : point_b, mu_positive ?
parity : 1-
parity);
272 Link Ox =
arg.qPrev(0, point_d, 1-
parity);
274 Link Oz = mu_positive ?
conj(Ubc)*Oy : Ubc*Oy;
278 force +=
Sign(
parity)*mycoeff*Oz*Ox* (mu_positive ? Uad :
conj(Uad));
285 Link force =
arg.outA(
mu, mu_positive ? point_d : e_cb, mu_positive ? 1-
parity :
parity);
287 arg.outA(
mu, mu_positive ? point_d : e_cb, mu_positive ? 1-
parity :
parity) = force;
289 Link shortP =
arg.outB(0, point_d, 1-
parity);
290 shortP +=
arg.accumu_coeff* (mu_positive ? Uad :
conj(Uad)) *Oy;
336 template <
typename real,
int sig_positive,
int mu_positive,
bool pMu,
bool qMu,
bool qPrev,
typename Arg>
337 __global__
void middleLinkKernel(Arg
arg)
341 int x_cb = blockIdx.x *
blockDim.x + threadIdx.x;
342 if (x_cb >=
arg.threads)
return;
356 for (
int d=0;
d<4;
d++)
x[
d] +=
arg.base_idx[
d];
359 int y[4] = {
x[0],
x[1],
x[2],
x[3]};
361 int mymu = posDir(
arg.mu);
362 updateCoords(
y, mymu, (mu_positive ? -1 : 1),
arg);
365 int ad_link_nbr_idx = mu_positive ? point_d : e_cb;
367 int mysig = posDir(
arg.sig);
368 updateCoords(
y, mysig, (sig_positive ? 1 : -1),
arg);
371 for (
int d=0;
d<4;
d++)
y[
d] =
x[
d];
372 updateCoords(
y, mysig, (sig_positive ? 1 : -1),
arg);
375 int bc_link_nbr_idx = mu_positive ? point_c : point_b;
376 int ab_link_nbr_idx = sig_positive ? e_cb : point_b;
379 Link Uab =
arg.link(mysig, ab_link_nbr_idx, sig_positive^(1-
parity));
382 Link Ubc =
arg.link(mymu, bc_link_nbr_idx, mu_positive^(1-
parity));
386 Oy =
arg.oProd(posDir(
arg.sig), sig_positive ? point_d : point_c, sig_positive^
parity);
387 if (!sig_positive) Oy =
conj(Oy);
392 Link Ow = !mu_positive ? Ubc*Oy :
conj(Ubc)*Oy;
394 if (pMu)
arg.pMu(0, point_b, 1-
parity) = Ow;
396 arg.p3(0, e_cb,
parity) = sig_positive ? Uab*Ow :
conj(Uab)*Ow;
398 Link Uad =
arg.link(mymu, ad_link_nbr_idx, mu_positive^
parity);
399 if (!mu_positive) Uad =
conj(Uad);
402 if (sig_positive) Oy = Ow*Uad;
403 if ( qMu )
arg.qMu(0, e_cb,
parity) = Uad;
406 if ( qMu || sig_positive ) {
410 if ( qMu )
arg.qMu(0, e_cb,
parity) = Ox;
411 if (sig_positive) Oy = Ow*Ox;
416 oprod +=
arg.coeff*Oy;
452 template <
typename real,
int mu_positive,
typename Arg>
453 __global__
void sideLinkKernel(Arg
arg)
456 int x_cb = blockIdx.x *
blockDim.x + threadIdx.x;
457 if (x_cb >=
arg.threads)
return;
462 for (
int d=0;
d<4;
d++)
x[
d] =
x[
d] +
arg.base_idx[
d];
477 int mymu = posDir(
arg.mu);
478 int y[4] = {
x[0],
x[1],
x[2],
x[3]};
479 updateCoords(
y, mymu, (mu_positive ? -1 : 1),
arg);
485 int ad_link_nbr_idx = mu_positive ? point_d : e_cb;
487 Link Uad =
arg.link(mymu, ad_link_nbr_idx, mu_positive^
parity);
488 Link Ow = mu_positive ? Uad*Oy :
conj(Uad)*Oy;
490 Link shortP =
arg.outB(0, point_d, 1-
parity);
491 shortP +=
arg.accumu_coeff * Ow;
496 Link Ox =
arg.qProd(0, point_d, 1-
parity);
497 Link Ow = mu_positive ? Oy*Ox :
conj(Ox)*
conj(Oy);
499 real mycoeff = CoeffSign(goes_forward(
arg.sig),
parity)*CoeffSign(goes_forward(
arg.mu),
parity)*
arg.coeff;
501 Link oprod =
arg.outA(mu_positive ?
arg.mu : opp_dir(
arg.mu), mu_positive ? point_d : e_cb, mu_positive ? 1-
parity :
parity);
502 oprod += mycoeff * Ow;
503 arg.outA(mu_positive ?
arg.mu : opp_dir(
arg.mu), mu_positive ? point_d : e_cb, mu_positive ? 1-
parity :
parity) = oprod;
509 template<
typename real,
int mu_positive,
typename Arg>
510 __global__
void sideLinkShortKernel(Arg
arg)
513 int x_cb = blockIdx.x *
blockDim.x + threadIdx.x;
514 if (x_cb >=
arg.threads)
return;
519 for (
int d=0;
d<4;
d++)
x[
d] =
x[
d] +
arg.base_idx[
d];
533 int mymu = posDir(
arg.mu);
534 int y[4] = {
x[0],
x[1],
x[2],
x[3]};
535 updateCoords(
y, mymu, (mu_positive ? -1 : 1),
arg);
539 real mycoeff = CoeffSign(goes_forward(
arg.sig),
parity)*CoeffSign(goes_forward(
arg.mu),
parity)*
arg.coeff;
542 Link oprod =
arg.outA(posDir(
arg.mu), point_d, parity_);
543 oprod += mu_positive ? mycoeff * Oy : mycoeff *
conj(Oy);
544 arg.outA(posDir(
arg.mu), point_d, parity_) = oprod;
547 template <
typename real,
typename Arg>
548 class FatLinkForce :
public TunableVectorYZ {
552 const GaugeField &meta;
553 const HisqForceType type;
/**
   Minimum thread count reported to the tuner.
   This is arg.threads, the bound that the kernels check their
   checkerboard index x_cb against.
 */
unsigned int minThreads() const
{
  return this->arg.threads;
}
/**
   Always reports false, opting out of grid-dimension tuning.
   NOTE(review): presumably the Tunable base class then derives the grid
   from the block size and minThreads(); confirm against Tunable.
 */
bool tuneGridDim() const
{
  constexpr bool tune_grid = false;
  return tune_grid;
}
559 FatLinkForce(Arg &
arg,
const GaugeField &meta,
int sig,
int mu, HisqForceType type)
560 : TunableVectorYZ(2,type == FORCE_ONE_LINK ? 4 : 1),
arg(
arg), meta(meta), type(type) {
/** Virtual destructor; the class owns nothing that needs explicit release. */
virtual ~FatLinkForce() = default;
566 TuneKey tuneKey()
const {
567 std::stringstream aux;
568 if (type == FORCE_ONE_LINK) aux <<
"threads=" <<
arg.threads;
569 else if (type == FORCE_MIDDLE_LINK || type == FORCE_LEPAGE_MIDDLE_LINK)
570 aux <<
"threads=" <<
arg.threads <<
",sig=" <<
arg.sig <<
",mu=" <<
arg.mu <<
571 ",pMu=" <<
arg.p_mu <<
",q_muu=" <<
arg.q_mu <<
",q_prev=" <<
arg.q_prev;
573 aux <<
"threads=" <<
arg.threads <<
",mu=" <<
arg.mu;
576 case FORCE_ONE_LINK: aux <<
",ONE_LINK";
break;
577 case FORCE_ALL_LINK: aux <<
",ALL_LINK";
break;
578 case FORCE_MIDDLE_LINK: aux <<
",MIDDLE_LINK";
break;
579 case FORCE_LEPAGE_MIDDLE_LINK: aux <<
",LEPAGE_MIDDLE_LINK";
break;
580 case FORCE_SIDE_LINK: aux <<
",SIDE_LINK";
break;
581 case FORCE_SIDE_LINK_SHORT: aux <<
",SIDE_LINK_SHORT";
break;
582 default:
errorQuda(
"Undefined force type %d", type);
584 return TuneKey(meta.VolString(),
typeid(*this).name(), aux.str().c_str());
587 void apply(
const cudaStream_t &
stream) {
591 oneLinkTermKernel<real,Arg> <<<tp.grid,tp.block,tp.shared_bytes,
stream>>>(
arg);
594 if (goes_forward(
arg.sig) && goes_forward(
arg.mu))
595 allLinkKernel<real,1,1,Arg><<<tp.grid,tp.block,tp.shared_bytes,
stream>>>(
arg);
596 else if (goes_forward(
arg.sig) && goes_backward(
arg.mu))
597 allLinkKernel<real,1,0,Arg><<<tp.grid,tp.block,tp.shared_bytes,
stream>>>(
arg);
598 else if (goes_backward(
arg.sig) && goes_forward(
arg.mu))
599 allLinkKernel<real,0,1,Arg><<<tp.grid,tp.block,tp.shared_bytes,
stream>>>(
arg);
601 allLinkKernel<real,0,0,Arg><<<tp.grid,tp.block,tp.shared_bytes,
stream>>>(
arg);
603 case FORCE_MIDDLE_LINK:
604 if (!
arg.p_mu || !
arg.q_mu)
errorQuda(
"Expect p_mu=%d and q_mu=%d to both be true",
arg.p_mu,
arg.q_mu);
606 if (goes_forward(
arg.sig) && goes_forward(
arg.mu))
607 middleLinkKernel<real,1,1,true,true,true,Arg><<<tp.grid,tp.block,tp.shared_bytes,
stream>>>(
arg);
608 else if (goes_forward(
arg.sig) && goes_backward(
arg.mu))
609 middleLinkKernel<real,1,0,true,true,true,Arg><<<tp.grid,tp.block,tp.shared_bytes,
stream>>>(
arg);
610 else if (goes_backward(
arg.sig) && goes_forward(
arg.mu))
611 middleLinkKernel<real,0,1,true,true,true,Arg><<<tp.grid,tp.block,tp.shared_bytes,
stream>>>(
arg);
613 middleLinkKernel<real,0,0,true,true,true,Arg><<<tp.grid,tp.block,tp.shared_bytes,
stream>>>(
arg);
615 if (goes_forward(
arg.sig) && goes_forward(
arg.mu))
616 middleLinkKernel<real,1,1,true,true,false,Arg><<<tp.grid,tp.block,tp.shared_bytes,
stream>>>(
arg);
617 else if (goes_forward(
arg.sig) && goes_backward(
arg.mu))
618 middleLinkKernel<real,1,0,true,true,false,Arg><<<tp.grid,tp.block,tp.shared_bytes,
stream>>>(
arg);
619 else if (goes_backward(
arg.sig) && goes_forward(
arg.mu))
620 middleLinkKernel<real,0,1,true,true,false,Arg><<<tp.grid,tp.block,tp.shared_bytes,
stream>>>(
arg);
622 middleLinkKernel<real,0,0,true,true,false,Arg><<<tp.grid,tp.block,tp.shared_bytes,
stream>>>(
arg);
625 case FORCE_LEPAGE_MIDDLE_LINK:
627 errorQuda(
"Expect p_mu=%d and q_mu=%d to both be false and q_prev=%d true",
arg.p_mu,
arg.q_mu,
arg.q_prev);
628 if (goes_forward(
arg.sig) && goes_forward(
arg.mu))
629 middleLinkKernel<real,1,1,false,false,true,Arg><<<tp.grid,tp.block,tp.shared_bytes,
stream>>>(
arg);
630 else if (goes_forward(
arg.sig) && goes_backward(
arg.mu))
631 middleLinkKernel<real,1,0,false,false,true,Arg><<<tp.grid,tp.block,tp.shared_bytes,
stream>>>(
arg);
632 else if (goes_backward(
arg.sig) && goes_forward(
arg.mu))
633 middleLinkKernel<real,0,1,false,false,true,Arg><<<tp.grid,tp.block,tp.shared_bytes,
stream>>>(
arg);
635 middleLinkKernel<real,0,0,false,false,true,Arg><<<tp.grid,tp.block,tp.shared_bytes,
stream>>>(
arg);
637 case FORCE_SIDE_LINK:
638 if (goes_forward(
arg.mu)) sideLinkKernel<real,1,Arg><<<tp.grid,tp.block,tp.shared_bytes,
stream>>>(
arg);
639 else sideLinkKernel<real,0,Arg><<<tp.grid,tp.block,tp.shared_bytes,
stream>>>(
arg);
641 case FORCE_SIDE_LINK_SHORT:
642 if (goes_forward(
arg.mu)) sideLinkShortKernel<real,1,Arg><<<tp.grid,tp.block,tp.shared_bytes,
stream>>>(
arg);
643 else sideLinkShortKernel<real,0,Arg><<<tp.grid,tp.block,tp.shared_bytes,
stream>>>(
arg);
646 errorQuda(
"Undefined force type %d", type);
659 case FORCE_MIDDLE_LINK:
662 case FORCE_LEPAGE_MIDDLE_LINK:
666 case FORCE_SIDE_LINK:
668 case FORCE_SIDE_LINK_SHORT:
671 default:
errorQuda(
"Undefined force type %d", type);
684 case FORCE_MIDDLE_LINK:
687 case FORCE_LEPAGE_MIDDLE_LINK:
691 case FORCE_SIDE_LINK:
693 case FORCE_SIDE_LINK_SHORT:
696 default:
errorQuda(
"Undefined force type %d", type);
700 long long flops()
const {
703 return 2*4*
arg.threads*36ll;
705 return 2*
arg.threads*(goes_forward(
arg.sig) ? 1242ll : 828ll);
706 case FORCE_MIDDLE_LINK:
707 case FORCE_LEPAGE_MIDDLE_LINK:
708 return 2*
arg.threads*(2 * 198 +
709 (!
arg.q_prev && goes_forward(
arg.sig) ? 198 : 0) +
710 (
arg.q_prev && (
arg.q_mu || goes_forward(
arg.sig) ) ? 198 : 0) +
711 ((
arg.q_prev && goes_forward(
arg.sig) ) ? 198 : 0) +
712 ( goes_forward(
arg.sig) ? 216 : 0) );
713 case FORCE_SIDE_LINK:
return 2*
arg.threads*2*234;
714 case FORCE_SIDE_LINK_SHORT:
return 2*
arg.threads*36;
715 default:
errorQuda(
"Undefined force type %d", type);
720 long long bytes()
const {
723 return 2*4*
arg.threads*(
arg.oProd.Bytes() + 2*
arg.outA.Bytes() );
725 return 2*
arg.threads*( (goes_forward(
arg.sig) ? 4 : 2)*
arg.outA.Bytes() + 3*
arg.link.Bytes()
726 +
arg.oProd.Bytes() +
arg.qPrev.Bytes() + 2*
arg.outB.Bytes());
727 case FORCE_MIDDLE_LINK:
728 case FORCE_LEPAGE_MIDDLE_LINK:
729 return 2*
arg.threads*( ( goes_forward(
arg.sig) ? 2*
arg.outA.Bytes() : 0 ) +
730 (
arg.p_mu ?
arg.pMu.Bytes() : 0) +
731 (
arg.q_mu ?
arg.qMu.Bytes() : 0) +
732 ( ( goes_forward(
arg.sig) ||
arg.q_mu ) ?
arg.qPrev.Bytes() : 0) +
733 arg.p3.Bytes() + 3*
arg.link.Bytes() +
arg.oProd.Bytes() );
734 case FORCE_SIDE_LINK:
735 return 2*
arg.threads*( 2*
arg.outA.Bytes() + 2*
arg.outB.Bytes() +
736 arg.p3.Bytes() +
arg.link.Bytes() +
arg.qProd.Bytes() );
737 case FORCE_SIDE_LINK_SHORT:
738 return 2*
arg.threads*( 2*
arg.outA.Bytes() +
arg.p3.Bytes() );
739 default:
errorQuda(
"Undefined force type %d", type);
745 template<
typename real>
747 GaugeField &
Qmu, GaugeField &
Qnumu, GaugeField &newOprod,
748 const GaugeField &oprod,
const GaugeField &link,
751 real OneLink = act_path_coeff.
one;
752 real ThreeSt = act_path_coeff.
three;
753 real mThreeSt = -ThreeSt;
754 real FiveSt = act_path_coeff.
five;
755 real mFiveSt = -FiveSt;
756 real SevenSt = act_path_coeff.
seven;
757 real Lepage = act_path_coeff.
lepage;
758 real mLepage = -Lepage;
760 FatLinkArg<real>
arg(newOprod, oprod, link, OneLink, FORCE_ONE_LINK);
761 FatLinkForce<real, FatLinkArg<real> > oneLink(
arg, link, 0, 0, FORCE_ONE_LINK);
764 for (
int sig=0; sig<8; sig++) {
765 for (
int mu=0;
mu<8;
mu++) {
766 if ( (
mu == sig) || (
mu == opp_dir(sig)))
continue;
770 FatLinkArg<real> middleLinkArg( newOprod,
Pmu,
P3,
Qmu, oprod, link, mThreeSt, 2, FORCE_MIDDLE_LINK);
771 FatLinkForce<real, FatLinkArg<real> > middleLink(middleLinkArg, link, sig,
mu, FORCE_MIDDLE_LINK);
774 for (
int nu=0; nu < 8; nu++) {
775 if (nu == sig || nu == opp_dir(sig) || nu ==
mu || nu == opp_dir(
mu))
continue;
779 FatLinkArg<real> middleLinkArg( newOprod,
Pnumu,
P5,
Qnumu,
Pmu,
Qmu, link, FiveSt, 1, FORCE_MIDDLE_LINK);
780 FatLinkForce<real, FatLinkArg<real> > middleLink(middleLinkArg, link, sig, nu, FORCE_MIDDLE_LINK);
783 for (
int rho = 0; rho < 8; rho++) {
784 if (rho == sig || rho == opp_dir(sig) || rho ==
mu || rho == opp_dir(
mu) || rho == nu || rho == opp_dir(nu))
continue;
787 FatLinkArg<real>
arg(newOprod,
P5,
Pnumu,
Qnumu, link, SevenSt, FiveSt != 0 ? SevenSt/FiveSt : 0, 1, FORCE_ALL_LINK,
true);
788 FatLinkForce<real, FatLinkArg<real> > all(
arg, link, sig, rho, FORCE_ALL_LINK);
794 FatLinkArg<real>
arg(newOprod,
P3,
P5,
Qmu, link, mFiveSt, (ThreeSt != 0 ? FiveSt/ThreeSt : 0), 1, FORCE_SIDE_LINK);
795 FatLinkForce<real, FatLinkArg<real> > side(
arg, link, sig, nu, FORCE_SIDE_LINK);
802 FatLinkArg<real> middleLinkArg( newOprod,
P5,
Pmu,
Qmu, link, Lepage, 2, FORCE_LEPAGE_MIDDLE_LINK);
803 FatLinkForce<real, FatLinkArg<real> > middleLink(middleLinkArg, link, sig,
mu, FORCE_LEPAGE_MIDDLE_LINK);
806 FatLinkArg<real>
arg(newOprod,
P3,
P5,
Qmu, link, mLepage, (ThreeSt != 0 ? Lepage/ThreeSt : 0), 2, FORCE_SIDE_LINK);
807 FatLinkForce<real, FatLinkArg<real> > side(
arg, link, sig,
mu, FORCE_SIDE_LINK);
812 FatLinkArg<real>
arg(newOprod,
P3, link, ThreeSt, 1, FORCE_SIDE_LINK_SHORT);
813 FatLinkForce<real, FatLinkArg<real> > side(
arg,
P3, sig,
mu, FORCE_SIDE_LINK_SHORT);
820 void hisqStaplesForce(GaugeField &newOprod,
const GaugeField &oprod,
const GaugeField &link,
const double path_coeff_array[6],
long long*
flops)
822 if (!link.isNative())
errorQuda(
"Unsupported gauge order %d", link.Order());
823 if (!oprod.isNative())
errorQuda(
"Unsupported gauge order %d", oprod.Order());
824 if (!newOprod.isNative())
errorQuda(
"Unsupported gauge order %d", newOprod.Order());
843 hisqStaplesForce<double>(
Pmu,
P3,
P5,
Pnumu,
Qmu,
Qnumu, newOprod, oprod, link, act_path_coeff);
851 cudaDeviceSynchronize();
856 for (
int d=0;
d<4;
d++) volume += link.X()[
d]-2*link.R()[
d];
858 *
flops += (
long long)volume*(134784 + 24192 + 103680 + 864 + 397440 + 72 + (path_coeff_array[5] != 0 ? 28944 : 0));
863 template <
typename real, QudaReconstructType reconstruct=QUDA_RECONSTRUCT_NO>
864 struct CompleteForceArg :
public BaseForceArg<real,reconstruct> {
866 typedef typename gauge::FloatNOrder<real,18,2,11> M;
867 typedef typename gauge_mapper<real,QUDA_RECONSTRUCT_NO>::type F;
872 CompleteForceArg(GaugeField &force,
const GaugeField &link,
const GaugeField &oprod)
873 : BaseForceArg<real,reconstruct>(link, 0), outA(force), oProd(oprod),
coeff(0.0)
879 template <
typename real,
typename Arg>
880 __global__
void completeForceKernel(Arg
arg)
883 int x_cb = blockIdx.x *
blockDim.x + threadIdx.x;
884 if (x_cb >=
arg.threads)
return;
890 for (
int d=0;
d<4;
d++)
x[
d] +=
arg.border[
d];
894 for (
int sig=0; sig<4; ++sig) {
906 template <
typename real, QudaReconstructType reconstruct=QUDA_RECONSTRUCT_NO>
907 struct LongLinkArg :
public BaseForceArg<real,reconstruct> {
909 typedef typename gauge::FloatNOrder<real,18,2,11> M;
910 typedef typename gauge_mapper<real,QUDA_RECONSTRUCT_NO>::type F;
915 LongLinkArg(GaugeField &newOprod,
const GaugeField &link,
const GaugeField &oprod, real
coeff)
916 : BaseForceArg<real,reconstruct>(link,0), outA(newOprod), oProd(oprod),
coeff(
coeff)
924 template <
typename real,
typename Arg>
925 __global__
void longLinkKernel(Arg
arg)
928 int x_cb = blockIdx.x *
blockDim.x + threadIdx.x;
929 if (x_cb >=
arg.threads)
return;
933 int dx[4] = {0,0,0,0};
937 for (
int i=0;
i<4;
i++)
x[
i] +=
arg.border[
i];
953 for (
int sig=0; sig<4; sig++) {
969 Link Uab =
arg.link(sig, point_a,
parity);
970 Link Ubc =
arg.link(sig, point_b, 1-
parity);
971 Link Ude =
arg.link(sig, point_d, 1-
parity);
972 Link Uef =
arg.link(sig, point_e,
parity);
974 Link Oz =
arg.oProd(sig, point_c,
parity);
975 Link Oy =
arg.oProd(sig, point_b, 1-
parity);
976 Link Ox =
arg.oProd(sig, point_a,
parity);
978 Link temp = Ude*Uef*Oz - Ude*Oy*Ubc + Ox*Uab*Ubc;
980 Link force =
arg.outA(sig, e_cb,
parity);
986 template <
typename real,
typename Arg>
987 class HisqForce :
public TunableVectorY {
990 const GaugeField &meta;
991 const HisqForceType type;
/**
   Minimum thread count reported to the tuner.
   This is arg.threads, the bound that the long-link and complete-force
   kernels check their checkerboard index x_cb against.
 */
unsigned int minThreads() const
{
  return this->arg.threads;
}
/**
   Always reports false, opting out of grid-dimension tuning.
   NOTE(review): presumably the Tunable base class then derives the grid
   from the block size and minThreads(); confirm against Tunable.
 */
bool tuneGridDim() const
{
  constexpr bool tune_grid = false;
  return tune_grid;
}
997 HisqForce(Arg &
arg,
const GaugeField &meta,
int sig,
int mu, HisqForceType type)
998 : TunableVectorY(2),
arg(
arg), meta(meta), type(type) {
/** Virtual destructor; the class owns nothing that needs explicit release. */
virtual ~HisqForce() = default;
1004 void apply(
const cudaStream_t &
stream) {
1007 case FORCE_LONG_LINK:
1008 longLinkKernel<real,Arg><<<tp.grid,tp.block,tp.shared_bytes,
stream>>>(
arg);
break;
1009 case FORCE_COMPLETE:
1010 completeForceKernel<real,Arg><<<tp.grid,tp.block,tp.shared_bytes,
stream>>>(
arg);
break;
1012 errorQuda(
"Undefined force type %d", type);
1016 TuneKey tuneKey()
const {
1017 std::stringstream aux;
1018 aux <<
"threads=" <<
arg.threads <<
",prec=" <<
sizeof(real);
1020 case FORCE_LONG_LINK: aux <<
",LONG_LINK";
break;
1021 case FORCE_COMPLETE: aux <<
",COMPLETE";
break;
1022 default:
errorQuda(
"Undefined force type %d", type);
1024 return TuneKey(meta.VolString(),
typeid(*this).name(), aux.str().c_str());
1029 case FORCE_LONG_LINK:
arg.outA.save();
break;
1030 case FORCE_COMPLETE:
break;
1031 default:
errorQuda(
"Undefined force type %d", type);
1037 case FORCE_LONG_LINK:
arg.outA.load();
break;
1038 case FORCE_COMPLETE:
break;
1039 default:
errorQuda(
"Undefined force type %d", type);
1043 long long flops()
const {
1045 case FORCE_LONG_LINK:
return 2*
arg.threads*4968ll;
1046 case FORCE_COMPLETE:
return 2*
arg.threads*792ll;
1047 default:
errorQuda(
"Undefined force type %d", type);
1052 long long bytes()
const {
1054 case FORCE_LONG_LINK:
return 4*2*
arg.threads*(2*
arg.outA.Bytes() + 4*
arg.link.Bytes() + 3*
arg.oProd.Bytes());
1055 case FORCE_COMPLETE:
return 4*2*
arg.threads*(
arg.outA.Bytes() +
arg.link.Bytes() +
arg.oProd.Bytes());
1056 default:
errorQuda(
"Undefined force type %d", type);
1064 if (!link.isNative())
errorQuda(
"Unsupported gauge order %d", link.Order());
1065 if (!oldOprod.isNative())
errorQuda(
"Unsupported gauge order %d", oldOprod.Order());
1066 if (!newOprod.isNative())
errorQuda(
"Unsupported gauge order %d", newOprod.Order());
1072 typedef LongLinkArg<double,QUDA_RECONSTRUCT_NO> Arg;
1073 Arg
arg(newOprod, link, oldOprod,
coeff);
1074 HisqForce<double,Arg> longLink(
arg, link, 0, 0, FORCE_LONG_LINK);
1076 if (
flops) (*flops) += longLink.flops();
1078 errorQuda(
"Reconstruct %d not supported", link.Reconstruct());
1082 typedef LongLinkArg<float,QUDA_RECONSTRUCT_NO> Arg;
1083 Arg
arg(newOprod, link, oldOprod,
coeff);
1084 HisqForce<float, Arg> longLink(
arg, link, 0, 0, FORCE_LONG_LINK);
1086 if (
flops) (*flops) += longLink.flops();
1088 errorQuda(
"Reconstruct %d not supported", link.Reconstruct());
1091 errorQuda(
"Unsupported precision %d", precision);
1094 cudaDeviceSynchronize();
1097 void hisqCompleteForce(GaugeField &force,
const GaugeField &oprod,
const GaugeField &link,
long long*
flops)
1099 if (!link.isNative())
errorQuda(
"Unsupported gauge order %d", link.Order());
1100 if (!oprod.isNative())
errorQuda(
"Unsupported gauge order %d", oprod.Order());
1101 if (!force.isNative())
errorQuda(
"Unsupported gauge order %d", force.Order());
1107 typedef CompleteForceArg<double,QUDA_RECONSTRUCT_NO> Arg;
1108 Arg
arg(force, link, oprod);
1109 HisqForce<double,Arg> completeForce(
arg, link, 0, 0, FORCE_COMPLETE);
1110 completeForce.apply(0);
1113 errorQuda(
"Reconstruct %d not supported", link.Reconstruct());
1117 typedef CompleteForceArg<float,QUDA_RECONSTRUCT_NO> Arg;
1118 Arg
arg(force, link, oprod);
1119 HisqForce<float, Arg> completeForce(
arg, link, 0, 0, FORCE_COMPLETE);
1120 completeForce.apply(0);
1123 errorQuda(
"Reconstruct %d not supported", link.Reconstruct());
1126 errorQuda(
"Unsupported precision %d", precision);
1129 cudaDeviceSynchronize();
1134 #endif // GPU_HISQ_FORCE
enum QudaPrecision_s QudaPrecision
static __device__ __host__ int linkIndexShift(const I x[], const J dx[], const K X[4])
static __device__ __host__ int linkIndex(const int x[], const I X[4])
QudaVerbosity getVerbosity()
#define checkPrecision(...)
QudaGaugeParam gauge_param
void hisqLongLinkForce(GaugeField &newOprod, const GaugeField &oprod, const GaugeField &link, double coeff, long long *flops=nullptr)
Compute the long-link contribution to the fermion force.
void hisqStaplesForce(GaugeField &newOprod, const GaugeField &oprod, const GaugeField &link, const double path_coeff[6], long long *flops=nullptr)
Compute the fat-link contribution to the fermion force.
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
#define checkLocation(...)
static unsigned int unsigned int shift
Main header file for host and device accessors to GaugeFields.
QudaReconstructType reconstruct
void hisqCompleteForce(GaugeField &momentum, const GaugeField &oprod, const GaugeField &link, long long *flops=nullptr)
Multiply the computed force matrix by the gauge field and perform a traceless anti-Hermitian projection.
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
__device__ __host__ void makeAntiHerm(Matrix< Complex, N > &m)
__host__ __device__ ValueType conj(ValueType x)
QudaTune getTuning()
Query whether autotuning is enabled or not. Default is enabled but can be overridden by setting QUDA_...
static __inline__ size_t size_t d
int comm_dim_partitioned(int dim)
static __device__ __host__ void getCoords(int x[], int cb_index, const I X[], int parity)