// NOTE(review): extraction fragment — the opening of namespace fermion_force
// followed by two enumerators of the force-term enum (presumably
// HisqForceType, which the FatLinkArg constructors compare against). The
// remaining enumerators and the enum header are not visible in this view.
15 namespace fermion_force {
31 FORCE_LEPAGE_MIDDLE_LINK,
33 FORCE_SIDE_LINK_SHORT,
/**
   @brief Opposite of a direction index.
   Forward directions 0..3 pair with backward directions 7..4, so the
   opposite direction is the reflection of dir about 3.5.
   @param dir Direction index (0..7 at the call sites)
   @return 7 - dir
*/
__device__ __host__ constexpr inline int opp_dir(int dir)
{
  return 7 - dir;
}
/**
   @brief Test whether a direction index points forward.
   @param dir Direction index (0..7)
   @return 1 when dir is in the forward range (dir <= 3), otherwise 0
*/
__device__ __host__ constexpr inline int goes_forward(int dir)
{
  return !(dir > 3);
}
/**
   @brief Test whether a direction index points backward.
   @param dir Direction index (0..7)
   @return 1 when dir is in the backward range (dir > 3), otherwise 0
*/
__device__ __host__ constexpr inline int goes_backward(int dir)
{
  return 3 < dir;
}
/**
   @brief Sign factor applied to a staple coefficient.
   @param pos_dir Direction flag (callers pass 0/1 orientation flags)
   @param odd_lattice Site parity (0 or 1)
   @return +1 when pos_dir + odd_lattice is even, -1 when it is odd
*/
__device__ __host__ constexpr inline int CoeffSign(int pos_dir, int odd_lattice)
{
  return ((pos_dir + odd_lattice) & 1) ? -1 : 1;
}
/**
   @brief Parity sign.
   @param parity Site parity; any non-zero value counts as odd
   @return 1 for parity == 0, -1 otherwise
*/
__device__ __host__ constexpr inline int Sign(int parity)
{
  return 1 - 2 * (parity != 0);
}
/**
   @brief Fold a direction index onto its forward (0..3) representative.
   @param dir Direction index (0..7)
   @return dir itself when dir < 4, otherwise the opposite direction 7 - dir
*/
__device__ __host__ constexpr inline int posDir(int dir)
{
  return (dir < 4) ? dir : 7 - dir;
}
// Shift coordinate x[dir] by `shift` sites with periodic wrap-around on the
// extent arg.E[dir]. arg.E[dir] is added before the modulo so the result stays
// non-negative for any shift >= -arg.E[dir].
// NOTE(review): the closing brace of this function is not visible in this view.
47 template <
int dir,
typename Arg>
48 inline __device__ __host__
void updateCoords(
int x[],
int shift,
const Arg &
arg) {
49 x[dir] = (x[dir] + shift + arg.E[dir]) % arg.E[dir];
// Runtime-direction overload of updateCoords. It forwards to the
// compile-time specialisation updateCoords<dir> for dir = 0..3, so callers
// pass forward (0..3) direction indices, e.g. via posDir().
// NOTE(review): the enclosing switch header, any default case and the closing
// braces are on lines not visible in this view.
52 template <
typename Arg>
53 inline __device__ __host__
void updateCoords(
int x[],
int dir,
int shift,
const Arg &
arg) {
55 case 0: updateCoords<0>(x, shift,
arg);
break;
56 case 1: updateCoords<1>(x, shift,
arg);
break;
57 case 2: updateCoords<2>(x, shift,
arg);
break;
58 case 3: updateCoords<3>(x, shift,
arg);
break;
// Fragment of a path-coefficient holder templated on `real`. Its constructor
// copies path_coeff_array[0..5] into the members one, naik, three, five,
// seven and lepage, in that order.
// NOTE(review): the struct name and member declarations are not visible here.
63 template <
typename real>
72 : one(path_coeff_array[0]), naik(path_coeff_array[1]),
73 three(path_coeff_array[2]), five(path_coeff_array[3]),
74 seven(path_coeff_array[4]), lepage(path_coeff_array[5]) { }
// Fragment of BaseForceArg, the common base of the force-kernel argument
// structs. G is the gauge_mapper accessor type for the link field. The
// constructor stores the link accessor, sets threads to 1 (presumably
// overwritten on a hidden line), and per dimension takes the border width
// from link.R() and the interior extent X = E - 2*border. oddness_change is
// the parity (0/1) of the summed base_idx offset; kernels XOR it into the
// thread's parity after shifting coordinates by base_idx.
77 template <
typename real, QudaReconstructType reconstruct=QUDA_RECONSTRUCT_NO>
79 typedef typename gauge_mapper<real,reconstruct>::type G;
97 BaseForceArg(
const GaugeField &link,
int overlap) : link(link), threads(1),
100 for (
int d=0; d<4; d++) {
102 border[d] = link.R()[d];
103 X[d] = E[d] - 2*border[d];
109 oddness_change = (base_idx[0] + base_idx[1] + base_idx[2] + base_idx[3])&1;
// FatLinkArg: kernel arguments for the fat-link (staple) force terms.
// It extends BaseForceArg with accessors of type F, which always uses
// QUDA_RECONSTRUCT_NO, for the output fields (outA, outB) and the
// intermediate staple/product fields. It also holds the coefficients
// coeff/accumu_coeff and the p_mu/q_mu/q_prev flags. Each constructor checks
// the HisqForceType it is given and aborts via errorQuda on a mismatch.
113 template <
typename real, QudaReconstructType reconstruct=QUDA_RECONSTRUCT_NO>
114 struct FatLinkArg :
public BaseForceArg<real,reconstruct> {
116 typedef typename gauge_mapper<real,QUDA_RECONSTRUCT_NO>::type F;
127 const real accumu_coeff;
/**
   One-link constructor (FORCE_ONE_LINK).
   Both outputs alias `force`, and every intermediate accessor aliases
   `oProd`. No halo overlap is used, and all p/q flags are false.
*/
FatLinkArg(GaugeField &force, const GaugeField &oProd, const GaugeField &link,
           real coeff, HisqForceType type)
  : BaseForceArg<real,reconstruct>(link, 0),
    outA(force), outB(force),
    pMu(oProd), p3(oProd), qMu(oProd),
    oProd(oProd), qProd(oProd), qPrev(oProd),
    coeff(coeff), accumu_coeff(0),
    p_mu(false), q_mu(false), q_prev(false)
{
  // This overload is only meaningful for the one-link term.
  if (type == FORCE_ONE_LINK) return;
  errorQuda("This constructor is for FORCE_ONE_LINK");
}
/**
   Middle-link constructor (FORCE_MIDDLE_LINK) that also takes the
   previous-level Q field.
   The pMu, P3 and qMu fields are bound as outputs, oProd and qPrev as inputs,
   and p_mu, q_mu and q_prev are all set to true.
*/
FatLinkArg(GaugeField &newOprod, GaugeField &pMu, GaugeField &P3, GaugeField &qMu,
           const GaugeField &oProd, const GaugeField &qPrev, const GaugeField &link,
           real coeff, int overlap, HisqForceType type)
  : BaseForceArg<real,reconstruct>(link, overlap),
    outA(newOprod), outB(newOprod),
    pMu(pMu), p3(P3), qMu(qMu),
    oProd(oProd), qProd(oProd), qPrev(qPrev),
    coeff(coeff), accumu_coeff(0),
    p_mu(true), q_mu(true), q_prev(true)
{
  if (type == FORCE_MIDDLE_LINK) return;
  errorQuda("This constructor is for FORCE_MIDDLE_LINK");
}
/**
   Middle-link constructor (FORCE_MIDDLE_LINK) without a previous-level Q
   field.
   qPrev is only a placeholder aliasing qMu; p_mu and q_mu are true and
   q_prev is false.
*/
FatLinkArg(GaugeField &newOprod, GaugeField &pMu, GaugeField &P3, GaugeField &qMu,
           const GaugeField &oProd, const GaugeField &link,
           real coeff, int overlap, HisqForceType type)
  : BaseForceArg<real,reconstruct>(link, overlap),
    outA(newOprod), outB(newOprod),
    pMu(pMu), p3(P3), qMu(qMu),
    oProd(oProd), qProd(oProd), qPrev(qMu),
    coeff(coeff), accumu_coeff(0),
    p_mu(true), q_mu(true), q_prev(false)
{
  if (type == FORCE_MIDDLE_LINK) return;
  errorQuda("This constructor is for FORCE_MIDDLE_LINK");
}
/**
   Lepage middle-link constructor (FORCE_LEPAGE_MIDDLE_LINK).
   P3 receives the staple, and oProd and qPrev are the inputs. pMu and qMu
   merely alias P3 and qPrev: with p_mu and q_mu false, the middle-link
   kernel does not write them. q_prev is true.
*/
FatLinkArg(GaugeField &newOprod, GaugeField &P3, const GaugeField &oProd,
           const GaugeField &qPrev, const GaugeField &link,
           real coeff, int overlap, HisqForceType type)
  : BaseForceArg<real,reconstruct>(link, overlap), outA(newOprod), outB(newOprod), pMu(P3), p3(P3), qMu(qPrev),
    oProd(oProd), qProd(oProd), qPrev(qPrev), coeff(coeff), accumu_coeff(0), p_mu(false), q_mu(false), q_prev(true)
{
  // Report the type this overload actually requires: it checks for
  // FORCE_LEPAGE_MIDDLE_LINK, not FORCE_MIDDLE_LINK.
  if (type != FORCE_LEPAGE_MIDDLE_LINK) errorQuda("This constructor is for FORCE_LEPAGE_MIDDLE_LINK");
}
/**
   Side-link constructor (FORCE_SIDE_LINK).
   The staple is read from P3, accumu_coeff-weighted staples are accumulated
   into shortP (outB), and qProd supplies the Q product. The remaining
   accessors alias P3 or qProd, and all p/q flags are false.
*/
FatLinkArg(GaugeField &newOprod, GaugeField &shortP, const GaugeField &P3,
           const GaugeField &qProd, const GaugeField &link, real coeff, real accumu_coeff,
           int overlap, HisqForceType type)
  : BaseForceArg<real,reconstruct>(link, overlap), outA(newOprod), outB(shortP), pMu(P3), p3(P3), qMu(qProd), oProd(qProd), qProd(qProd),
    qPrev(qProd), coeff(coeff), accumu_coeff(accumu_coeff),
    p_mu(false), q_mu(false), q_prev(false)
{
  // Only FORCE_SIDE_LINK is accepted here; FORCE_ALL_LINK has its own
  // overload, so the message no longer mentions it.
  if (type != FORCE_SIDE_LINK) errorQuda("This constructor is for FORCE_SIDE_LINK");
}
/**
   Short side-link constructor (FORCE_SIDE_LINK_SHORT).
   Both outputs alias newOprod, and every intermediate accessor aliases P3.
   accumu_coeff is zero and all p/q flags are false.
*/
FatLinkArg(GaugeField &newOprod, GaugeField &P3, const GaugeField &link,
           real coeff, int overlap, HisqForceType type)
  : BaseForceArg<real,reconstruct>(link, overlap),
    outA(newOprod), outB(newOprod),
    pMu(P3), p3(P3), qMu(P3),
    oProd(P3), qProd(P3), qPrev(P3),
    coeff(coeff), accumu_coeff(0.0),
    p_mu(false), q_mu(false), q_prev(false)
{
  if (type == FORCE_SIDE_LINK_SHORT) return;
  errorQuda("This constructor is for FORCE_SIDE_LINK_SHORT");
}
/**
   All-link constructor (FORCE_ALL_LINK).
   The trailing bool exists only to distinguish this overload from the
   FORCE_SIDE_LINK constructor, whose leading parameters have the same types.
   Its value is ignored. outB/pMu/p3 alias shortP, and the Q accessors alias
   qPrev.
*/
FatLinkArg(GaugeField &newOprod, GaugeField &shortP, const GaugeField &oProd, const GaugeField &qPrev,
           const GaugeField &link, real coeff, real accumu_coeff, int overlap, HisqForceType type, bool dummy)
  // Initialisers are listed in the same member order as the sibling
  // constructors. Members are always initialised in declaration order; the
  // previous out-of-order list (oProd/qPrev first) only triggered -Wreorder.
  : BaseForceArg<real,reconstruct>(link, overlap), outA(newOprod), outB(shortP), pMu(shortP), p3(shortP), qMu(qPrev),
    oProd(oProd), qProd(qPrev), qPrev(qPrev), coeff(coeff), accumu_coeff(accumu_coeff),
    p_mu(false), q_mu(false), q_prev(false)
{
  if (type != FORCE_ALL_LINK) errorQuda("This constructor is for FORCE_ALL_LINK");
}
// One-link term kernel. Thread layout, from the index arithmetic below:
// x = site index x_cb, y = parity, z = link direction sig. Threads with
// x_cb >= arg.threads or sig >= 4 return immediately. Parity is not
// range-checked here, so this presumably relies on the launch's y-extent
// being 2.
// NOTE(review): the coordinate decode, the Link typedef and the e_cb
// computation are on lines not visible in this view.
183 template <
typename real,
typename Arg>
184 __global__
void oneLinkTermKernel(Arg arg)
187 int x_cb = blockIdx.x * blockDim.x + threadIdx.x;
188 if (x_cb >= arg.threads)
return;
189 int parity = blockIdx.y * blockDim.y + threadIdx.y;
190 int sig = blockIdx.z * blockDim.z + threadIdx.z;
191 if (sig >= 4)
return;
// Offset the coordinates by the border width (presumably moving from the
// interior lattice into the border-extended field).
196 for (
int d=0; d<4; d++) x[d] += arg.border[d];
// outA(sig) += coeff * oProd(sig) at this site.
199 Link w = arg.oProd(sig, e_cb, parity);
200 Link force = arg.outA(sig, e_cb, parity);
201 force += arg.coeff * w;
202 arg.outA(sig, e_cb, parity) = force;
// All-link (seven-staple) kernel. The template flags sig_positive and
// mu_positive say, at compile time, whether arg.sig and arg.mu point forward.
// One thread per site index (x) and parity (y). The kernel reads qPrev at
// point_d and oProd at point_c, adds staple contributions to outA on the sig
// and mu links, and accumulates an accumu_coeff-weighted term into outB.
// NOTE(review): the definitions of x, e_cb, point_b/c/d and the branch
// conditions around the outA updates are on lines not visible in this view.
229 template<
typename real,
int sig_positive,
int mu_positive,
typename Arg>
230 __global__
void allLinkKernel(Arg arg)
232 typedef Matrix<complex<real>,3> Link;
234 int x_cb = blockIdx.x * blockDim.x + threadIdx.x;
235 if (x_cb >= arg.threads)
return;
236 int parity = blockIdx.y * blockDim.y + threadIdx.y;
// Shift coordinates by base_idx; an odd total offset flips the site parity.
240 for (
int d=0; d<4; d++) x[d] += arg.base_idx[d];
242 parity = parity^arg.oddness_change;
// Staple coefficient carrying the sign CoeffSign(sig_positive, parity).
244 real mycoeff = CoeffSign(sig_positive,parity)*arg.coeff;
// y = x moved one step along sig (backward when !sig_positive).
246 int y[4] = {x[0], x[1], x[2], x[3]};
247 int mysig = posDir(arg.sig);
248 updateCoords(y, mysig, (sig_positive ? 1 : -1), arg);
250 int ab_link_nbr_idx = (sig_positive) ? e_cb : point_b;
252 for (
int d=0; d<4; d++) y[d] = x[d];
// mu is always the forward (0..3) index of the mu direction.
263 int mu = mu_positive ? arg.mu : opp_dir(arg.mu);
264 int dir = mu_positive ? -1 : 1;
266 updateCoords(y, mu, dir, arg);
268 updateCoords(y, mysig, (sig_positive ? 1 : -1), arg);
// Links along the staple edges plus the two outer-product inputs.
271 Link Uab = arg.link(posDir(arg.sig), ab_link_nbr_idx, sig_positive^(1-parity));
272 Link Uad = arg.link(mu, mu_positive ? point_d : e_cb, mu_positive ? 1-parity : parity);
273 Link Ubc = arg.link(mu, mu_positive ? point_c : point_b, mu_positive ? parity : 1-parity);
274 Link Ox = arg.qPrev(0, point_d, 1-parity);
275 Link Oy = arg.oProd(0, point_c, parity);
276 Link Oz = mu_positive ?
conj(Ubc)*Oy : Ubc*Oy;
// Contribution to the sig-link force at e_cb; the guarding condition
// (presumably sig_positive) is not visible.
279 Link force = arg.outA(arg.sig, e_cb, parity);
280 force +=
Sign(parity)*mycoeff*Oz*Ox* (mu_positive ? Uad :
conj(Uad));
281 arg.outA(arg.sig, e_cb, parity) = force;
// Contribution to the mu-link force, at point_d for forward mu or e_cb for
// backward mu.
287 Link force = arg.outA(mu, mu_positive ? point_d : e_cb, mu_positive ? 1-parity : parity);
288 force +=
Sign(mu_positive ? 1-parity : parity)*mycoeff* (mu_positive ? Oy*Ox :
conj(Ox)*
conj(Oy));
289 arg.outA(mu, mu_positive ? point_d : e_cb, mu_positive ? 1-parity : parity) = force;
// Accumulate accumu_coeff * Uad(or its conj) * Oy into outB (shortP) at point_d.
291 Link shortP = arg.outB(0, point_d, 1-parity);
292 shortP += arg.accumu_coeff* (mu_positive ? Uad :
conj(Uad)) *Oy;
293 arg.outB(0, point_d, 1-parity) = shortP;
// Middle-link kernel shared by the 3-staple, 5-staple and Lepage terms.
// sig_positive/mu_positive give the compile-time orientation of arg.sig and
// arg.mu. The kernel always writes the staple into arg.p3 at e_cb.
// Template flags, as used on the visible lines:
//   pMu   - also store the partial staple Ow into arg.pMu at point_b
//   qMu   - also store a link product into arg.qMu at e_cb
//   qPrev - used on lines not visible here; presumably selects the oProd
//           source and whether arg.qPrev is read
// NOTE(review): several lines (coordinate decode, point_b/c/d and the
// surrounding if/else structure) are not visible in this view.
338 template <
typename real,
int sig_positive,
int mu_positive,
bool pMu,
bool qMu,
bool qPrev,
typename Arg>
339 __global__
void middleLinkKernel(Arg arg)
341 typedef Matrix<complex<real>,3> Link;
343 int x_cb = blockIdx.x * blockDim.x + threadIdx.x;
344 if (x_cb >= arg.threads)
return;
345 int parity = blockIdx.y * blockDim.y + threadIdx.y;
// Shift coordinates by base_idx; an odd total offset flips the site parity.
358 for (
int d=0; d<4; d++) x[d] += arg.base_idx[d];
360 parity = parity ^ arg.oddness_change;
// y = x moved one step opposite to mu (presumably point_d).
361 int y[4] = {x[0], x[1], x[2], x[3]};
363 int mymu = posDir(arg.mu);
364 updateCoords(y, mymu, (mu_positive ? -1 : 1), arg);
367 int ad_link_nbr_idx = mu_positive ? point_d : e_cb;
// ...then one further step along sig (presumably point_c).
369 int mysig = posDir(arg.sig);
370 updateCoords(y, mysig, (sig_positive ? 1 : -1), arg);
// Reset to x and step along sig only (presumably point_b).
373 for (
int d=0; d<4; d++) y[d] = x[d];
374 updateCoords(y, mysig, (sig_positive ? 1 : -1), arg);
377 int bc_link_nbr_idx = mu_positive ? point_c : point_b;
378 int ab_link_nbr_idx = sig_positive ? e_cb : point_b;
// Links on the a-b (sig) and b-c (mu) edges.
381 Link Uab = arg.link(mysig, ab_link_nbr_idx, sig_positive^(1-parity));
384 Link Ubc = arg.link(mymu, bc_link_nbr_idx, mu_positive^(1-parity));
// Two alternative sources for Oy (selection not visible): oProd along sig,
// conjugated for backward sig, or component 0 at point_c.
388 Oy = arg.oProd(posDir(arg.sig), sig_positive ? point_d : point_c, sig_positive^parity);
389 if (!sig_positive) Oy =
conj(Oy);
391 Oy = arg.oProd(0, point_c, parity);
// Partial staple Ow = Ubc*Oy (conj(Ubc)*Oy for forward mu); optionally saved
// to pMu, then extended by Uab (or its conj) and written to p3.
394 Link Ow = !mu_positive ? Ubc*Oy :
conj(Ubc)*Oy;
396 if (pMu) arg.pMu(0, point_b, 1-parity) = Ow;
398 arg.p3(0, e_cb, parity) = sig_positive ? Uab*Ow :
conj(Uab)*Ow;
// a-d link along mu, conjugated for backward mu.
400 Link Uad = arg.link(mymu, ad_link_nbr_idx, mu_positive^parity);
401 if (!mu_positive) Uad =
conj(Uad);
// Q-field bookkeeping: two alternative paths (selection not visible), one
// without and one with arg.qPrev at point_d.
404 if (sig_positive) Oy = Ow*Uad;
405 if ( qMu ) arg.qMu(0, e_cb, parity) = Uad;
408 if ( qMu || sig_positive ) {
409 Oy = arg.qPrev(0, point_d, 1-parity);
412 if ( qMu ) arg.qMu(0, e_cb, parity) = Ox;
413 if (sig_positive) Oy = Ow*Ox;
// outA(sig) += coeff * Oy at e_cb (presumably only when sig_positive).
417 Link oprod = arg.outA(arg.sig, e_cb, parity);
418 oprod += arg.coeff*Oy;
419 arg.outA(arg.sig, e_cb, parity) = oprod;
// Side-link kernel; mu_positive gives the compile-time orientation of arg.mu.
// It reads the staple from arg.p3 at e_cb and multiplies it by the mu link
// (conjugated for backward mu). It accumulates accumu_coeff times that
// product into outB (shortP) at point_d. It then adds
// mycoeff * (Oy*Ox, or conj(Ox)*conj(Oy) for backward mu) to outA on the
// forward-indexed mu link, where Ox = qProd at point_d.
// NOTE(review): the coordinate decode, point_d/e_cb and the scopes that
// separate the two `Link Ow` declarations are on lines not visible here.
454 template <
typename real,
int mu_positive,
typename Arg>
455 __global__
void sideLinkKernel(Arg arg)
457 typedef Matrix<complex<real>, 3> Link;
458 int x_cb = blockIdx.x * blockDim.x + threadIdx.x;
459 if (x_cb >= arg.threads)
return;
460 int parity = blockIdx.y * blockDim.y + threadIdx.y;
// Shift coordinates by base_idx; an odd total offset flips the site parity.
464 for (
int d=0; d<4; d++) x[d] = x[d] + arg.base_idx[d];
466 parity = parity ^ arg.oddness_change;
// y = x moved one step opposite to mu (presumably yields point_d).
479 int mymu = posDir(arg.mu);
480 int y[4] = {x[0], x[1], x[2], x[3]};
481 updateCoords(y, mymu, (mu_positive ? -1 : 1), arg);
484 Link Oy = arg.p3(0, e_cb, parity);
487 int ad_link_nbr_idx = mu_positive ? point_d : e_cb;
489 Link Uad = arg.link(mymu, ad_link_nbr_idx, mu_positive^parity);
490 Link Ow = mu_positive ? Uad*Oy :
conj(Uad)*Oy;
// shortP (outB) += accumu_coeff * Ow at point_d.
492 Link shortP = arg.outB(0, point_d, 1-parity);
493 shortP += arg.accumu_coeff * Ow;
494 arg.outB(0, point_d, 1-parity) = shortP;
498 Link Ox = arg.qProd(0, point_d, 1-parity);
499 Link Ow = mu_positive ? Oy*Ox :
conj(Ox)*
conj(Oy);
// Sign depends on both the sig and mu orientations and the site parity.
501 real mycoeff = CoeffSign(goes_forward(arg.sig), parity)*CoeffSign(goes_forward(arg.mu),parity)*arg.coeff;
// The mu-link force is always stored under the forward direction index.
503 Link oprod = arg.outA(mu_positive ? arg.mu : opp_dir(arg.mu), mu_positive ? point_d : e_cb, mu_positive ? 1-parity : parity);
504 oprod += mycoeff * Ow;
505 arg.outA(mu_positive ? arg.mu : opp_dir(arg.mu), mu_positive ? point_d : e_cb, mu_positive ? 1-parity : parity) = oprod;
// Short side-link kernel: adds mycoeff * P3 (conjugated for backward mu) to
// outA on the forward-indexed mu link at point_d. point_d is the mu-shifted
// site for forward mu and e_cb otherwise.
// NOTE(review): `parity_` and e_cb are defined on lines not visible here;
// parity_ is presumably the parity matching point_d.
511 template<
typename real,
int mu_positive,
typename Arg>
512 __global__
void sideLinkShortKernel(Arg arg)
514 typedef Matrix<complex<real>,3> Link;
515 int x_cb = blockIdx.x * blockDim.x + threadIdx.x;
516 if (x_cb >= arg.threads)
return;
517 int parity = blockIdx.y * blockDim.y + threadIdx.y;
// Shift coordinates by base_idx; an odd total offset flips the site parity.
521 for (
int d=0; d<4; d++) x[d] = x[d] + arg.base_idx[d];
523 parity = parity ^ arg.oddness_change;
// Locate point_d one step opposite to mu (only needed for forward mu).
535 int mymu = posDir(arg.mu);
536 int y[4] = {x[0], x[1], x[2], x[3]};
537 updateCoords(y, mymu, (mu_positive ? -1 : 1), arg);
538 int point_d = mu_positive ?
linkIndex(y,arg.E) : e_cb;
541 real mycoeff = CoeffSign(goes_forward(arg.sig),parity)*CoeffSign(goes_forward(arg.mu),parity)*arg.coeff;
543 Link Oy = arg.p3(0, e_cb, parity);
544 Link oprod = arg.outA(posDir(arg.mu), point_d, parity_);
545 oprod += mu_positive ? mycoeff * Oy : mycoeff *
conj(Oy);
546 arg.outA(posDir(arg.mu), point_d, parity_) = oprod;
// FatLinkForce: Tunable launcher for the fat-link force kernels.
// It keeps the Arg, a `meta` field whose VolString() feeds the tune key, and
// the force-term type. minThreads() is arg.threads and grid-dim tuning is
// disabled. The TunableVectorYZ extents are 2 and (4 for FORCE_ONE_LINK,
// else 1), presumably the parity and direction dimensions seen in the
// kernels.
// NOTE(review): member declarations for arg and the constructor body
// (presumably storing sig/mu into arg) are on lines not visible here.
549 template <
typename real,
typename Arg>
550 class FatLinkForce :
public TunableVectorYZ {
554 const GaugeField &meta;
555 const HisqForceType type;
557 unsigned int minThreads()
const {
return arg.threads; }
558 bool tuneGridDim()
const {
return false; }
561 FatLinkForce(Arg &arg,
const GaugeField &meta,
int sig,
int mu, HisqForceType type)
562 : TunableVectorYZ(2,type == FORCE_ONE_LINK ? 4 : 1), arg(arg), meta(meta), type(type) {
566 virtual ~FatLinkForce() { }
// Tune-cache key. It combines meta's volume string, the class type name and
// an aux string built as follows:
//   ONE_LINK             - thread count
//   MIDDLE / LEPAGE      - threads, sig, mu and the p_mu/q_mu/q_prev flags
//   other (presumably; the `else` line is hidden) - threads and mu
// followed by a suffix naming the force type. Unknown types abort via
// errorQuda.
// NOTE(review): ",q_muu=" looks like a typo for ",q_mu="; fixing it changes
// the key and so invalidates previously cached tuning entries.
// NOTE(review): apply() specialises the ALL_LINK kernel on
// goes_forward(arg.sig), but if ALL_LINK takes the threads/mu branch, both
// sig orientations share one tune entry — confirm this is intended.
568 TuneKey tuneKey()
const {
569 std::stringstream aux;
570 if (type == FORCE_ONE_LINK) aux <<
"threads=" << arg.threads;
571 else if (type == FORCE_MIDDLE_LINK || type == FORCE_LEPAGE_MIDDLE_LINK)
572 aux <<
"threads=" << arg.threads <<
",sig=" << arg.sig <<
",mu=" << arg.mu <<
573 ",pMu=" << arg.p_mu <<
",q_muu=" << arg.q_mu <<
",q_prev=" << arg.q_prev;
575 aux <<
"threads=" << arg.threads <<
",mu=" << arg.mu;
578 case FORCE_ONE_LINK: aux <<
",ONE_LINK";
break;
579 case FORCE_ALL_LINK: aux <<
",ALL_LINK";
break;
580 case FORCE_MIDDLE_LINK: aux <<
",MIDDLE_LINK";
break;
581 case FORCE_LEPAGE_MIDDLE_LINK: aux <<
",LEPAGE_MIDDLE_LINK";
break;
582 case FORCE_SIDE_LINK: aux <<
",SIDE_LINK";
break;
583 case FORCE_SIDE_LINK_SHORT: aux <<
",SIDE_LINK_SHORT";
break;
584 default:
errorQuda(
"Undefined force type %d", type);
586 return TuneKey(meta.VolString(),
typeid(*this).name(), aux.str().c_str());
// Launch the kernel for `type` on `stream` with the tuned geometry tp.
// Direction-dependent kernels are dispatched to the template specialisation
// matching goes_forward/goes_backward of arg.sig and arg.mu.
// NOTE(review): the tuneLaunch call, the switch header and some case labels
// are on lines not visible here. No cudaGetLastError() check after the
// launches is visible either.
589 void apply(
const cudaStream_t &
stream) {
// FORCE_ONE_LINK (case label not visible).
593 oneLinkTermKernel<real,Arg> <<<tp.grid,tp.block,tp.shared_bytes,stream>>>(
arg);
// Presumably FORCE_ALL_LINK: four sig/mu orientation specialisations.
596 if (goes_forward(arg.sig) && goes_forward(arg.mu))
597 allLinkKernel<real,1,1,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
598 else if (goes_forward(arg.sig) && goes_backward(arg.mu))
599 allLinkKernel<real,1,0,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
600 else if (goes_backward(arg.sig) && goes_forward(arg.mu))
601 allLinkKernel<real,0,1,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
603 allLinkKernel<real,0,0,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(
arg);
// Middle link: p_mu and q_mu must both be set. The first dispatch uses
// qPrev=true and the second qPrev=false; the selecting condition (presumably
// arg.q_prev) is not visible.
605 case FORCE_MIDDLE_LINK:
606 if (!arg.p_mu || !arg.q_mu)
errorQuda(
"Expect p_mu=%d and q_mu=%d to both be true", arg.p_mu, arg.q_mu);
608 if (goes_forward(arg.sig) && goes_forward(arg.mu))
609 middleLinkKernel<real,1,1,true,true,true,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(
arg);
610 else if (goes_forward(arg.sig) && goes_backward(arg.mu))
611 middleLinkKernel<real,1,0,true,true,true,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
612 else if (goes_backward(arg.sig) && goes_forward(arg.mu))
613 middleLinkKernel<real,0,1,true,true,true,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
615 middleLinkKernel<real,0,0,true,true,true,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(
arg);
617 if (goes_forward(arg.sig) && goes_forward(arg.mu))
618 middleLinkKernel<real,1,1,true,true,false,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
619 else if (goes_forward(arg.sig) && goes_backward(arg.mu))
620 middleLinkKernel<real,1,0,true,true,false,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
621 else if (goes_backward(arg.sig) && goes_forward(arg.mu))
622 middleLinkKernel<real,0,1,true,true,false,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
624 middleLinkKernel<real,0,0,true,true,false,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(
arg);
// Lepage middle link: requires p_mu == q_mu == false and q_prev == true.
627 case FORCE_LEPAGE_MIDDLE_LINK:
628 if (arg.p_mu || arg.q_mu || !arg.q_prev)
629 errorQuda(
"Expect p_mu=%d and q_mu=%d to both be false and q_prev=%d true", arg.p_mu, arg.q_mu, arg.q_prev);
630 if (goes_forward(arg.sig) && goes_forward(arg.mu))
631 middleLinkKernel<real,1,1,false,false,true,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(
arg);
632 else if (goes_forward(arg.sig) && goes_backward(arg.mu))
633 middleLinkKernel<real,1,0,false,false,true,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
634 else if (goes_backward(arg.sig) && goes_forward(arg.mu))
635 middleLinkKernel<real,0,1,false,false,true,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
637 middleLinkKernel<real,0,0,false,false,true,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(
arg);
// Side-link kernels depend only on the orientation of mu.
639 case FORCE_SIDE_LINK:
640 if (goes_forward(arg.mu)) sideLinkKernel<real,1,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
641 else sideLinkKernel<real,0,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(
arg);
643 case FORCE_SIDE_LINK_SHORT:
644 if (goes_forward(arg.mu)) sideLinkShortKernel<real,1,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
645 else sideLinkShortKernel<real,0,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(
arg);
// Any other type is a programming error.
648 errorQuda(
"Undefined force type %d", type);
// Fragments of FatLinkForce's preTune()/postTune() switches. Only the case
// labels and the errorQuda default for unknown types are visible. The
// per-type bodies presumably save/restore the fields a kernel writes during
// tuning (compare HisqForce's outA.save()/load()); confirm against the full
// source.
661 case FORCE_MIDDLE_LINK:
664 case FORCE_LEPAGE_MIDDLE_LINK:
668 case FORCE_SIDE_LINK:
670 case FORCE_SIDE_LINK_SHORT:
673 default:
errorQuda(
"Undefined force type %d", type);
686 case FORCE_MIDDLE_LINK:
689 case FORCE_LEPAGE_MIDDLE_LINK:
693 case FORCE_SIDE_LINK:
695 case FORCE_SIDE_LINK_SHORT:
698 default:
errorQuda(
"Undefined force type %d", type);
// Flop estimate per launch, switched on `type`. The switch header and the
// first two case labels (presumably ONE_LINK and ALL_LINK) are not visible.
// The middle/Lepage count adds 198 per matrix product that the kernel
// performs given q_prev, q_mu and the sig orientation.
// NOTE(review): only the first two expressions contain a long long literal.
// The MIDDLE/LEPAGE, SIDE_LINK and SIDE_LINK_SHORT products are evaluated in
// the type of arg.threads. If that is int, 2*arg.threads*2*234 overflows
// once arg.threads exceeds ~2.29M; a leading 2ll would avoid this.
702 long long flops()
const {
705 return 2*4*arg.threads*36ll;
707 return 2*arg.threads*(goes_forward(arg.sig) ? 1242ll : 828ll);
708 case FORCE_MIDDLE_LINK:
709 case FORCE_LEPAGE_MIDDLE_LINK:
710 return 2*arg.threads*(2 * 198 +
711 (!arg.q_prev && goes_forward(arg.sig) ? 198 : 0) +
712 (arg.q_prev && (arg.q_mu || goes_forward(arg.sig) ) ? 198 : 0) +
713 ((arg.q_prev && goes_forward(arg.sig) ) ? 198 : 0) +
714 ( goes_forward(arg.sig) ? 216 : 0) );
715 case FORCE_SIDE_LINK:
return 2*arg.threads*2*234;
716 case FORCE_SIDE_LINK_SHORT:
return 2*arg.threads*36;
717 default:
errorQuda(
"Undefined force type %d", type);
// Memory-traffic estimate per launch, switched on `type`. It sums the
// per-site Bytes() of each accessor the kernel touches, counting outputs
// twice (read + write). The switch header and the first two case labels
// (presumably ONE_LINK and ALL_LINK) are not visible.
// NOTE(review): whether these products are computed in 64 bits depends on
// the types of arg.threads and Bytes(), which are not visible here. Check for
// int overflow on large lattices.
722 long long bytes()
const {
725 return 2*4*arg.threads*( arg.oProd.Bytes() + 2*arg.outA.Bytes() );
727 return 2*arg.threads*( (goes_forward(arg.sig) ? 4 : 2)*arg.outA.Bytes() + 3*arg.link.Bytes()
728 + arg.oProd.Bytes() + arg.qPrev.Bytes() + 2*arg.outB.Bytes());
729 case FORCE_MIDDLE_LINK:
730 case FORCE_LEPAGE_MIDDLE_LINK:
731 return 2*arg.threads*( ( goes_forward(arg.sig) ? 2*arg.outA.Bytes() : 0 ) +
732 (arg.p_mu ? arg.pMu.Bytes() : 0) +
733 (arg.q_mu ? arg.qMu.Bytes() : 0) +
734 ( ( goes_forward(arg.sig) || arg.q_mu ) ? arg.qPrev.Bytes() : 0) +
735 arg.p3.Bytes() + 3*arg.link.Bytes() + arg.oProd.Bytes() );
736 case FORCE_SIDE_LINK:
737 return 2*arg.threads*( 2*arg.outA.Bytes() + 2*arg.outB.Bytes() +
738 arg.p3.Bytes() + arg.link.Bytes() + arg.qProd.Bytes() );
739 case FORCE_SIDE_LINK_SHORT:
740 return 2*arg.threads*( 2*arg.outA.Bytes() + arg.p3.Bytes() );
741 default:
errorQuda(
"Undefined force type %d", type);
// Host driver for the fat-link (staple) force, templated on precision.
// It first applies the one-link term. It then loops over every ordered pair
// (sig, mu) of non-collinear directions in 0..7 (backward directions
// included) and accumulates the 3-staple, 5-staple, 7-staple (all-link) and
// Lepage contributions into newOprod. Pmu/P3/P5/Pnumu/Qmu/Qnumu appear to be
// temporaries that carry partial staples between stages.
// NOTE(review): the first signature line (with Pmu, P3, P5, Pnumu), the
// apply() calls and any coefficient guards are on lines not visible here.
747 template<
typename real>
749 GaugeField &
Qmu, GaugeField &
Qnumu, GaugeField &newOprod,
750 const GaugeField &oprod,
const GaugeField &link,
// Path coefficients in `real`, plus the negated ones used for alternating signs.
753 real OneLink = act_path_coeff.
one;
754 real ThreeSt = act_path_coeff.
three;
755 real mThreeSt = -ThreeSt;
756 real FiveSt = act_path_coeff.
five;
757 real mFiveSt = -FiveSt;
758 real SevenSt = act_path_coeff.
seven;
759 real Lepage = act_path_coeff.
lepage;
760 real mLepage = -Lepage;
// One-link term.
762 FatLinkArg<real>
arg(newOprod, oprod, link, OneLink, FORCE_ONE_LINK);
763 FatLinkForce<real, FatLinkArg<real> > oneLink(arg, link, 0, 0, FORCE_ONE_LINK);
766 for (
int sig=0; sig<8; sig++) {
767 for (
int mu=0; mu<8; mu++) {
768 if ( (mu == sig) || (mu == opp_dir(sig)))
continue;
// 3-staple middle link: writes Pmu, P3 and Qmu; coefficient -three; overlap 2.
772 FatLinkArg<real> middleLinkArg( newOprod, Pmu, P3, Qmu, oprod, link, mThreeSt, 2, FORCE_MIDDLE_LINK);
773 FatLinkForce<real, FatLinkArg<real> > middleLink(middleLinkArg, link, sig, mu, FORCE_MIDDLE_LINK);
776 for (
int nu=0; nu < 8; nu++) {
777 if (nu == sig || nu == opp_dir(sig) || nu == mu || nu == opp_dir(mu))
continue;
// 5-staple middle link along nu: consumes Pmu/Qmu and writes Pnumu, P5 and
// Qnumu; overlap 1.
781 FatLinkArg<real> middleLinkArg( newOprod, Pnumu, P5, Qnumu, Pmu, Qmu, link, FiveSt, 1, FORCE_MIDDLE_LINK);
782 FatLinkForce<real, FatLinkArg<real> > middleLink(middleLinkArg, link, sig, nu, FORCE_MIDDLE_LINK);
785 for (
int rho = 0; rho < 8; rho++) {
786 if (rho == sig || rho == opp_dir(sig) || rho == mu || rho == opp_dir(mu) || rho == nu || rho == opp_dir(nu))
continue;
// 7-staple all link along rho; accumu_coeff = SevenSt/FiveSt (0 if FiveSt == 0).
789 FatLinkArg<real>
arg(newOprod, P5, Pnumu, Qnumu, link, SevenSt, FiveSt != 0 ? SevenSt/FiveSt : 0, 1, FORCE_ALL_LINK,
true);
790 FatLinkForce<real, FatLinkArg<real> > all(arg, link, sig, rho, FORCE_ALL_LINK);
// 5-staple side link along nu: reads P5, accumulates into P3 with
// accumu_coeff = FiveSt/ThreeSt (0 if ThreeSt == 0).
796 FatLinkArg<real>
arg(newOprod, P3, P5, Qmu, link, mFiveSt, (ThreeSt != 0 ? FiveSt/ThreeSt : 0), 1, FORCE_SIDE_LINK);
797 FatLinkForce<real, FatLinkArg<real> > side(arg, link, sig, nu, FORCE_SIDE_LINK);
// Lepage middle link along mu: reads Pmu/Qmu and writes its staple into P5.
804 FatLinkArg<real> middleLinkArg( newOprod, P5, Pmu, Qmu, link, Lepage, 2, FORCE_LEPAGE_MIDDLE_LINK);
805 FatLinkForce<real, FatLinkArg<real> > middleLink(middleLinkArg, link, sig, mu, FORCE_LEPAGE_MIDDLE_LINK);
// Lepage side link: reads P5, accumulates into P3 with Lepage/ThreeSt.
808 FatLinkArg<real>
arg(newOprod, P3, P5, Qmu, link, mLepage, (ThreeSt != 0 ? Lepage/ThreeSt : 0), 2, FORCE_SIDE_LINK);
809 FatLinkForce<real, FatLinkArg<real> > side(arg, link, sig, mu, FORCE_SIDE_LINK);
// 3-staple short side link with coefficient ThreeSt, reading P3.
// NOTE(review): unlike the other launchers, `meta` here is P3 rather than
// link; this affects only the tune key's volume string.
814 FatLinkArg<real>
arg(newOprod, P3, link, ThreeSt, 1, FORCE_SIDE_LINK_SHORT);
815 FatLinkForce<real, FatLinkArg<real> > side(arg, P3, sig, mu, FORCE_SIDE_LINK_SHORT);
// Public entry point for the fat-link (staple) force. Aborts via errorQuda
// unless link, oprod and newOprod all use a native gauge order. It then
// calls the templated driver with temporaries Pmu, P3, P5, Pnumu, Qmu and
// Qnumu, presumably allocated on hidden lines, and act_path_coeff
// (presumably built from path_coeff_array). It ends with
// cudaDeviceSynchronize().
// NOTE(review): only the double-precision call is visible in this view, and
// the cudaDeviceSynchronize() return value is not checked.
822 void hisqStaplesForce(GaugeField &newOprod,
const GaugeField &oprod,
const GaugeField &link,
const double path_coeff_array[6])
824 if (!link.isNative())
errorQuda(
"Unsupported gauge order %d", link.Order());
825 if (!oprod.isNative())
errorQuda(
"Unsupported gauge order %d", oprod.Order());
826 if (!newOprod.isNative())
errorQuda(
"Unsupported gauge order %d", newOprod.Order());
845 hisqStaplesForce<double>(
Pmu,
P3,
P5,
Pnumu,
Qmu,
Qnumu, newOprod, oprod, link, act_path_coeff);
853 cudaDeviceSynchronize();
// Kernel arguments for the complete-force step. Both outA and oProd are
// bound to the same `force` field, so the kernel updates it in place. There
// is no halo overlap, and coeff is initialised to 0.
857 template <
typename real, QudaReconstructType reconstruct=QUDA_RECONSTRUCT_NO>
858 struct CompleteForceArg :
public BaseForceArg<real,reconstruct> {
860 typedef typename gauge_mapper<real,QUDA_RECONSTRUCT_NO>::type F;
865 CompleteForceArg(GaugeField &force,
const GaugeField &link)
866 : BaseForceArg<real,reconstruct>(link, 0), outA(force), oProd(force), coeff(0.0)
// Complete-force kernel. One thread per site index (x) and parity (y). For
// each direction sig = 0..3, it reads the link Uw and the accumulated force
// Ox at the site, forms Ow (on hidden lines), and writes coeff * Ow back to
// outA. coeff is -1 on parity 1 and +1 otherwise.
// NOTE(review): the computation of Ow from Uw and Ox is on lines not visible
// in this view.
872 template <
typename real,
typename Arg>
873 __global__
void completeForceKernel(Arg arg)
875 typedef Matrix<complex<real>,3> Link;
876 int x_cb = blockIdx.x * blockDim.x + threadIdx.x;
877 if (x_cb >= arg.threads)
return;
878 int parity = blockIdx.y * blockDim.y + threadIdx.y;
// Offset by the border width before computing the extended-field index.
883 for (
int d=0; d<4; d++) x[d] += arg.border[d];
884 int e_cb = linkIndex(x,arg.E);
887 for (
int sig=0; sig<4; ++sig) {
888 Link Uw = arg.link(sig, e_cb, parity);
889 Link Ox = arg.oProd(sig, e_cb, parity);
894 real coeff = (parity==1) ? -1.0 : 1.0;
895 arg.outA(sig, e_cb, parity) = coeff*Ow;
// Kernel arguments for the long-link (Naik) force term. outA is bound to
// newOprod, oProd to the input outer product, and coeff scales the
// contribution. There is no halo overlap. M is a FloatNOrder<real,18,2,11>
// accessor type (its use is not visible here) and F is the no-reconstruct
// gauge_mapper type.
899 template <
typename real, QudaReconstructType reconstruct=QUDA_RECONSTRUCT_NO>
900 struct LongLinkArg :
public BaseForceArg<real,reconstruct> {
902 typedef typename gauge::FloatNOrder<real,18,2,11> M;
903 typedef typename gauge_mapper<real,QUDA_RECONSTRUCT_NO>::type F;
908 LongLinkArg(GaugeField &newOprod,
const GaugeField &link,
const GaugeField &oprod, real coeff)
909 : BaseForceArg<real,reconstruct>(link,0), outA(newOprod), oProd(oprod), coeff(coeff)
// Long-link force kernel. One thread per site index (x) and parity (y). For
// each direction sig = 0..3 it gathers four links and three oProd matrices
// at neighbouring points a..e along sig, and adds
//   coeff * (Ude*Uef*Oz - Ude*Oy*Ubc + Ox*Uab*Ubc)
// to outA(sig) at the site.
// NOTE(review): the point_a..point_e index computations (using dx) are on
// lines not visible in this view.
917 template <
typename real,
typename Arg>
918 __global__
void longLinkKernel(Arg arg)
920 typedef Matrix<complex<real>,3> Link;
921 int x_cb = blockIdx.x * blockDim.x + threadIdx.x;
922 if (x_cb >= arg.threads)
return;
923 int parity = blockIdx.y * blockDim.y + threadIdx.y;
926 int dx[4] = {0,0,0,0};
// Offset by the border width before computing the extended-field index.
930 for (
int i=0; i<4; i++) x[i] += arg.border[i];
931 int e_cb = linkIndex(x,arg.E);
946 for (
int sig=0; sig<4; sig++) {
962 Link Uab = arg.link(sig, point_a, parity);
963 Link Ubc = arg.link(sig, point_b, 1-parity);
964 Link Ude = arg.link(sig, point_d, 1-parity);
965 Link Uef = arg.link(sig, point_e, parity);
967 Link Oz = arg.oProd(sig, point_c, parity);
968 Link Oy = arg.oProd(sig, point_b, 1-parity);
969 Link Ox = arg.oProd(sig, point_a, parity);
971 Link temp = Ude*Uef*Oz - Ude*Oy*Ubc + Ox*Uab*Ubc;
973 Link force = arg.outA(sig, e_cb, parity);
974 arg.outA(sig, e_cb, parity) = force + arg.coeff*temp;
// HisqForce: Tunable launcher for the long-link and complete-force kernels.
// It keeps the Arg, a `meta` field used for the tune key, and the force
// type. minThreads() is arg.threads, grid-dim tuning is disabled, and the
// TunableVectorY extent is 2 (presumably parity).
// NOTE(review): the arg member declaration and the constructor body are on
// lines not visible here.
979 template <
typename real,
typename Arg>
980 class HisqForce :
public TunableVectorY {
983 const GaugeField &meta;
984 const HisqForceType type;
986 unsigned int minThreads()
const {
return arg.threads; }
987 bool tuneGridDim()
const {
return false; }
990 HisqForce(Arg &arg,
const GaugeField &meta,
int sig,
int mu, HisqForceType type)
991 : TunableVectorY(2), arg(arg), meta(meta), type(type) {
995 virtual ~HisqForce() { }
// Launch the long-link or complete-force kernel on `stream` with the tuned
// geometry tp. Any other type aborts via errorQuda.
// NOTE(review): the tuneLaunch call, the switch header and the default label
// are on lines not visible here.
997 void apply(
const cudaStream_t &stream) {
1000 case FORCE_LONG_LINK:
1001 longLinkKernel<real,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(
arg);
break;
1002 case FORCE_COMPLETE:
1003 completeForceKernel<real,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(
arg);
break;
1005 errorQuda(
"Undefined force type %d", type);
// Tune-cache key: meta's volume string, the class type name, and an aux
// string of thread count, sizeof(real) as the precision, and a force-type
// suffix. Unknown types abort via errorQuda.
1009 TuneKey tuneKey()
const {
1010 std::stringstream aux;
1011 aux <<
"threads=" << arg.threads <<
",prec=" <<
sizeof(real);
1013 case FORCE_LONG_LINK: aux <<
",LONG_LINK";
break;
1014 case FORCE_COMPLETE: aux <<
",COMPLETE";
break;
1015 default:
errorQuda(
"Undefined force type %d", type);
1017 return TuneKey(meta.VolString(),
typeid(*this).name(), aux.str().c_str());
// preTune()/postTune() fragments. Both kernels write outA, so outA is
// saved before tuning and restored afterwards; the repeated launches made
// while tuning therefore do not accumulate into the result. Unknown types
// abort via errorQuda.
1022 case FORCE_LONG_LINK:
1023 case FORCE_COMPLETE:
1024 arg.outA.save();
break;
1025 default:
errorQuda(
"Undefined force type %d", type);
1031 case FORCE_LONG_LINK:
1032 case FORCE_COMPLETE:
1033 arg.outA.load();
break;
1034 default:
errorQuda(
"Undefined force type %d", type);
// Flop estimate per launch: 4968 per thread for the long link and 792 for
// the complete force, times 2 (presumably the two parities). The long long
// literals make the final multiplication 64-bit.
1038 long long flops()
const {
1040 case FORCE_LONG_LINK:
return 2*arg.threads*4968ll;
1041 case FORCE_COMPLETE:
return 2*arg.threads*792ll;
1042 default:
errorQuda(
"Undefined force type %d", type);
// Memory-traffic estimate per launch, built from the per-site Bytes() of
// each accessor touched, with outA counted twice for the long link
// (read + write).
// NOTE(review): these products contain no long long operand; whether they
// are 64-bit depends on the types of arg.threads and Bytes() (not visible).
1047 long long bytes()
const {
1049 case FORCE_LONG_LINK:
return 4*2*arg.threads*(2*arg.outA.Bytes() + 4*arg.link.Bytes() + 3*arg.oProd.Bytes());
1050 case FORCE_COMPLETE:
return 4*2*arg.threads*(arg.outA.Bytes() + arg.link.Bytes() + arg.oProd.Bytes());
1051 default:
errorQuda(
"Undefined force type %d", type);
// Public entry point for the long-link force contribution. Aborts via
// errorQuda unless link, oldOprod and newOprod use a native gauge order.
// It then builds a LongLinkArg/HisqForce pair in double or float precision;
// only QUDA_RECONSTRUCT_NO links are supported, and other reconstructs or
// precisions abort. It ends with cudaDeviceSynchronize().
// NOTE(review): the precision/reconstruct branch headers and the
// longLink.apply() calls are on lines not visible here. The
// cudaDeviceSynchronize() return value is not checked.
1057 void hisqLongLinkForce(GaugeField &newOprod,
const GaugeField &oldOprod,
const GaugeField &link,
double coeff)
1059 if (!link.isNative())
errorQuda(
"Unsupported gauge order %d", link.Order());
1060 if (!oldOprod.isNative())
errorQuda(
"Unsupported gauge order %d", oldOprod.Order());
1061 if (!newOprod.isNative())
errorQuda(
"Unsupported gauge order %d", newOprod.Order());
1067 typedef LongLinkArg<double,QUDA_RECONSTRUCT_NO> Arg;
1068 Arg
arg(newOprod, link, oldOprod, coeff);
1069 HisqForce<double,Arg> longLink(arg, link, 0, 0, FORCE_LONG_LINK);
1072 errorQuda(
"Reconstruct %d not supported", link.Reconstruct());
1076 typedef LongLinkArg<float,QUDA_RECONSTRUCT_NO> Arg;
1077 Arg
arg(newOprod, link, oldOprod, coeff);
1078 HisqForce<float, Arg> longLink(arg, link, 0, 0, FORCE_LONG_LINK);
1081 errorQuda(
"Reconstruct %d not supported", link.Reconstruct());
1084 errorQuda(
"Unsupported precision %d", precision);
1087 cudaDeviceSynchronize();
// Body fragment of the complete-force entry point (presumably
// hisqCompleteForce; its signature line is not visible). It aborts unless
// link and force use a native gauge order. It then applies the
// FORCE_COMPLETE kernel in double or float precision, with only
// QUDA_RECONSTRUCT_NO links supported, and ends with
// cudaDeviceSynchronize().
// NOTE(review): branch headers are on hidden lines, and the
// cudaDeviceSynchronize() return value is not checked.
1092 if (!link.isNative())
errorQuda(
"Unsupported gauge order %d", link.Order());
1093 if (!force.isNative())
errorQuda(
"Unsupported gauge order %d", force.Order());
1099 typedef CompleteForceArg<double,QUDA_RECONSTRUCT_NO> Arg;
1100 Arg
arg(force, link);
1101 HisqForce<double,Arg> completeForce(arg, link, 0, 0, FORCE_COMPLETE);
1102 completeForce.apply(0);
1104 errorQuda(
"Reconstruct %d not supported", link.Reconstruct());
1108 typedef CompleteForceArg<float,QUDA_RECONSTRUCT_NO> Arg;
1109 Arg
arg(force, link);
1110 HisqForce<float, Arg> completeForce(arg, link, 0, 0, FORCE_COMPLETE);
1111 completeForce.apply(0);
1113 errorQuda(
"Reconstruct %d not supported", link.Reconstruct());
1116 errorQuda(
"Unsupported precision %d", precision);
1119 cudaDeviceSynchronize();
1124 #endif // GPU_HISQ_FORCE
enum QudaPrecision_s QudaPrecision
static __device__ __host__ int linkIndexShift(const I x[], const J dx[], const K X[4])
static __device__ __host__ int linkIndex(const int x[], const I X[4])
void hisqLongLinkForce(GaugeField &newOprod, const GaugeField &oprod, const GaugeField &link, double coeff)
Compute the long-link contribution to the fermion force.
QudaVerbosity getVerbosity()
#define checkPrecision(...)
void hisqCompleteForce(GaugeField &oprod, const GaugeField &link)
Multiply the computed force matrix by the gauge field and perform a traceless anti-Hermitian projection.
QudaGaugeParam gauge_param
void hisqStaplesForce(GaugeField &newOprod, const GaugeField &oprod, const GaugeField &link, const double path_coeff[6])
Compute the fat-link contribution to the fermion force.
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
#define checkLocation(...)
Main header file for host and device accessors to GaugeFields.
QudaReconstructType reconstruct
static int commDim[QUDA_MAX_DIM]
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
__device__ __host__ void makeAntiHerm(Matrix< Complex, N > &m)
__host__ __device__ ValueType conj(ValueType x)
QudaTune getTuning()
Query whether autotuning is enabled or not. Default is enabled but can be overridden by setting QUDA_...
int comm_dim_partitioned(int dim)
__host__ __device__ int getCoords(int coord[], const Arg &arg, int &idx, int parity, int &dim)
Compute the space-time coordinates we are at.