2 #include <quda_internal.h>
3 #include <gauge_field.h>
4 #include <ks_improved_force.h>
5 #include <quda_matrix.h>
7 #include <index_helper.cuh>
8 #include <gauge_field_order.h>
9 #include <instantiate.h>
15 namespace fermion_force {
31 FORCE_LEPAGE_MIDDLE_LINK,
33 FORCE_SIDE_LINK_SHORT,
// Mirror a direction index within 0..7: 0<->7, 1<->6, 2<->5, 3<->4.
constexpr int opp_dir(int dir)
{
  constexpr int last_dir = 7;
  return last_dir - dir;
}
// Returns 1 for direction indices up to and including 3, otherwise 0.
constexpr int goes_forward(int dir)
{
  return dir < 4 ? 1 : 0;
}
// Returns 1 for direction indices of 4 and above, otherwise 0.
constexpr int goes_backward(int dir)
{
  return dir >= 4 ? 1 : 0;
}
// Sign factor: +1 when pos_dir + odd_lattice is even, -1 when it is odd.
constexpr int CoeffSign(int pos_dir, int odd_lattice)
{
  // lowest bit of (sum + 1) is set exactly when the sum itself is even
  const int sum_is_even = (pos_dir + odd_lattice + 1) & 1;
  return sum_is_even ? 1 : -1;
}
// Sign attached to a parity value: +1 for parity 0, -1 for any non-zero parity.
constexpr int Sign(int parity)
{
  if (parity) return -1;
  return 1;
}
// Fold a direction index onto 0..3: indices below 4 are returned unchanged,
// indices of 4 and above are mapped to 7 - dir.
constexpr int posDir(int dir)
{
  if (dir < 4) return dir;
  return 7 - dir;
}
// Shift the coordinate x[dir] by `shift` sites, wrapping periodically on the extent arg.E[dir].
// arg.E[dir] is added before the modulus so that, for a non-negative x[dir], a negative shift
// of magnitude up to the extent still produces a non-negative result.
47 template <int dir, typename Arg>
48 constexpr void updateCoords(int x[], int shift, const Arg &arg) {
49 x[dir] = (x[dir] + shift + arg.E[dir]) % arg.E[dir];
// Run-time direction variant: forwards to the compile-time updateCoords<dir> overload
// for dir = 0, 1, 2 or 3.
52 template <typename Arg>
53 constexpr void updateCoords(int x[], int dir, int shift, const Arg &arg) {
55 case 0: updateCoords<0>(x, shift, arg); break;
56 case 1: updateCoords<1>(x, shift, arg); break;
57 case 2: updateCoords<2>(x, shift, arg); break;
58 case 3: updateCoords<3>(x, shift, arg); break;
62 //struct for holding the fattening path coefficients
63 template <typename real>
64 struct PathCoefficients {
// Unpack a coefficient array of at least six entries. The slots are read in the order
// [0] one, [1] naik, [2] three, [3] five, [4] seven, [5] lepage.
71 PathCoefficients(const double *path_coeff_array)
72 : one(path_coeff_array[0]), naik(path_coeff_array[1]),
73 three(path_coeff_array[2]), five(path_coeff_array[3]),
74 seven(path_coeff_array[4]), lepage(path_coeff_array[5]) { }
77 template <typename real_, int nColor_, QudaReconstructType reconstruct=QUDA_RECONSTRUCT_NO>
80 static constexpr int nColor = nColor_;
81 typedef typename gauge_mapper<real,reconstruct>::type G;
84 int X[4]; // regular grid dims
85 int D[4]; // working set grid dims
86 int E[4]; // extended grid dims
90 int base_idx[4]; // the offset into the extended field
96 @param[in] link Gauge field
97 @param[in] overlap Radius of additional redundant computation to do
// Derive the grid geometry from the link field. `overlap` only has an effect in dimensions
// for which comm_dim_partitioned(d) is true.
99 BaseForceArg(const GaugeField &link, int overlap) : link(link), threads(1),
100 commDim{ comm_dim_partitioned(0), comm_dim_partitioned(1), comm_dim_partitioned(2), comm_dim_partitioned(3) }
102 for (int d=0; d<4; d++) {
// border width taken from the link field
104 border[d] = link.R()[d];
// regular extent = extended extent minus the border on both sides
105 X[d] = E[d] - 2*border[d];
// working extent: grown by `overlap` on each side, but only in partitioned dimensions
106 D[d] = comm_dim_partitioned(d) ? X[d]+overlap*2 : X[d];
// offset of the working set inside the extended field (0 in unpartitioned dimensions)
107 base_idx[d] = comm_dim_partitioned(d) ? border[d]-overlap : 0;
// parity of the summed offsets; kernels that shift by base_idx XOR this into the site parity
111 oddness_change = (base_idx[0] + base_idx[1] + base_idx[2] + base_idx[3])&1;
115 template <typename real, int nColor, QudaReconstructType reconstruct=QUDA_RECONSTRUCT_NO>
116 struct FatLinkArg : public BaseForceArg<real, nColor, reconstruct> {
117 using BaseForceArg = BaseForceArg<real, nColor, reconstruct>;
118 typedef typename gauge_mapper<real,QUDA_RECONSTRUCT_NO>::type F;
129 const real accumu_coeff;
/**
   Argument set for FORCE_ONE_LINK. `force` backs both output accessors and
   `oProd` backs every input accessor; no overlap is requested and the
   p_mu / q_mu / q_prev flags are all false.
*/
FatLinkArg(GaugeField &force, const GaugeField &oProd, const GaugeField &link, real coeff, HisqForceType type) :
  BaseForceArg(link, 0),
  outA(force),
  outB(force),
  pMu(oProd),
  p3(oProd),
  qMu(oProd),
  oProd(oProd),
  qProd(oProd),
  qPrev(oProd),
  coeff(coeff),
  accumu_coeff(0),
  p_mu(false),
  q_mu(false),
  q_prev(false)
{
  if (type != FORCE_ONE_LINK) { errorQuda("This constructor is for FORCE_ONE_LINK"); }
}
/**
   Argument set for FORCE_MIDDLE_LINK with a previous Q field: pMu, p3 and qMu
   are bound to their own fields, qPrev is bound to the supplied `qPrev`, and
   the p_mu / q_mu / q_prev flags are all true.
*/
FatLinkArg(GaugeField &newOprod, GaugeField &pMu, GaugeField &P3, GaugeField &qMu, const GaugeField &oProd,
           const GaugeField &qPrev, const GaugeField &link, real coeff, int overlap, HisqForceType type) :
  BaseForceArg(link, overlap),
  outA(newOprod),
  outB(newOprod),
  pMu(pMu),
  p3(P3),
  qMu(qMu),
  oProd(oProd),
  qProd(oProd),
  qPrev(qPrev),
  coeff(coeff),
  accumu_coeff(0),
  p_mu(true),
  q_mu(true),
  q_prev(true)
{
  if (type != FORCE_MIDDLE_LINK) { errorQuda("This constructor is for FORCE_MIDDLE_LINK"); }
}
/**
   Argument set for FORCE_MIDDLE_LINK without a previous Q field: the qPrev
   accessor is bound to `qMu` as a stand-in and the q_prev flag is false,
   while p_mu and q_mu are true.
*/
FatLinkArg(GaugeField &newOprod, GaugeField &pMu, GaugeField &P3, GaugeField &qMu, const GaugeField &oProd,
           const GaugeField &link, real coeff, int overlap, HisqForceType type) :
  BaseForceArg(link, overlap),
  outA(newOprod),
  outB(newOprod),
  pMu(pMu),
  p3(P3),
  qMu(qMu),
  oProd(oProd),
  qProd(oProd),
  qPrev(qMu),
  coeff(coeff),
  accumu_coeff(0),
  p_mu(true),
  q_mu(true),
  q_prev(false)
{
  if (type != FORCE_MIDDLE_LINK) { errorQuda("This constructor is for FORCE_MIDDLE_LINK"); }
}
/**
   Argument set for FORCE_LEPAGE_MIDDLE_LINK: only p3 and qPrev carry real
   fields. The pMu and qMu accessors are bound to `P3` and `qPrev` as
   stand-ins, with p_mu and q_mu false and q_prev true.

   @param[out] newOprod Field bound to both output accessors
   @param[out] P3 Field bound to the p3 (and placeholder pMu) accessor
   @param[in] oProd Field bound to the oProd and qProd accessors
   @param[in] qPrev Field bound to the qPrev (and placeholder qMu) accessor
   @param[in] link Gauge field
   @param[in] coeff Coefficient stored in `coeff`
   @param[in] overlap Forwarded to BaseForceArg
   @param[in] type Must be FORCE_LEPAGE_MIDDLE_LINK, otherwise errorQuda is raised
*/
FatLinkArg(GaugeField &newOprod, GaugeField &P3, const GaugeField &oProd,
           const GaugeField &qPrev, const GaugeField &link,
           real coeff, int overlap, HisqForceType type)
  : BaseForceArg(link, overlap), outA(newOprod), outB(newOprod), pMu(P3), p3(P3), qMu(qPrev),
    oProd(oProd), qProd(oProd), qPrev(qPrev), coeff(coeff), accumu_coeff(0), p_mu(false), q_mu(false), q_prev(true)
{
  // the message names the type that is actually checked (it used to say FORCE_MIDDLE_LINK)
  if (type != FORCE_LEPAGE_MIDDLE_LINK) errorQuda("This constructor is for FORCE_LEPAGE_MIDDLE_LINK");
}
/**
   Argument set for FORCE_SIDE_LINK: outA writes to `newOprod`, outB to
   `shortP`, p3 reads `P3` and qProd reads `qProd`. The pMu, qMu, oProd and
   qPrev accessors are bound to the same fields as stand-ins, and the
   p_mu / q_mu / q_prev flags are all false.

   @param[out] newOprod Field bound to outA
   @param[out] shortP Field bound to outB
   @param[in] P3 Field bound to p3 (and placeholder pMu)
   @param[in] qProd Field bound to qProd (and placeholder qMu, oProd, qPrev)
   @param[in] link Gauge field
   @param[in] coeff Coefficient stored in `coeff`
   @param[in] accumu_coeff Coefficient stored in `accumu_coeff`
   @param[in] overlap Forwarded to BaseForceArg
   @param[in] type Must be FORCE_SIDE_LINK, otherwise errorQuda is raised
*/
FatLinkArg(GaugeField &newOprod, GaugeField &shortP, const GaugeField &P3,
           const GaugeField &qProd, const GaugeField &link, real coeff, real accumu_coeff, int overlap, HisqForceType type)
  : BaseForceArg(link, overlap), outA(newOprod), outB(shortP), pMu(P3), p3(P3), qMu(qProd), oProd(qProd), qProd(qProd),
    qPrev(qProd), coeff(coeff), accumu_coeff(accumu_coeff),
    p_mu(false), q_mu(false), q_prev(false)
{
  // only FORCE_SIDE_LINK passes this guard, so the message no longer mentions FORCE_ALL_LINK
  if (type != FORCE_SIDE_LINK) errorQuda("This constructor is for FORCE_SIDE_LINK");
}
/**
   Argument set for FORCE_SIDE_LINK_SHORT: `newOprod` backs both output
   accessors and `P3` backs every input accessor; accumu_coeff is zero and
   the p_mu / q_mu / q_prev flags are all false.
*/
FatLinkArg(GaugeField &newOprod, GaugeField &P3, const GaugeField &link, real coeff, int overlap,
           HisqForceType type) :
  BaseForceArg(link, overlap),
  outA(newOprod),
  outB(newOprod),
  pMu(P3),
  p3(P3),
  qMu(P3),
  oProd(P3),
  qProd(P3),
  qPrev(P3),
  coeff(coeff),
  accumu_coeff(0.0),
  p_mu(false),
  q_mu(false),
  q_prev(false)
{
  if (type != FORCE_SIDE_LINK_SHORT) { errorQuda("This constructor is for FORCE_SIDE_LINK_SHORT"); }
}
/**
   Argument set for FORCE_ALL_LINK: outA writes to `newOprod`, outB to
   `shortP`, oProd and qPrev read their own fields. The trailing `dummy`
   parameter is never read; it only distinguishes this overload. The
   initialisers are listed in the same order as in the sibling constructors.
*/
FatLinkArg(GaugeField &newOprod, GaugeField &shortP, const GaugeField &oProd, const GaugeField &qPrev,
           const GaugeField &link, real coeff, real accumu_coeff, int overlap, HisqForceType type, bool dummy) :
  BaseForceArg(link, overlap),
  outA(newOprod),
  outB(shortP),
  pMu(shortP),  // placeholder binding
  p3(shortP),   // placeholder binding
  qMu(qPrev),   // placeholder binding
  oProd(oProd),
  qProd(qPrev), // placeholder binding
  qPrev(qPrev),
  coeff(coeff),
  accumu_coeff(accumu_coeff),
  p_mu(false),
  q_mu(false),
  q_prev(false)
{
  if (type != FORCE_ALL_LINK) { errorQuda("This constructor is for FORCE_ALL_LINK"); }
}
// One-link term: for every (site, parity, direction sig in 0..3) adds arg.coeff * oProd
// to the force stored in arg.outA.
// Thread mapping: x -> checkerboard site index, y -> parity, z -> direction.
185 template <typename Arg>
186 __global__ void oneLinkTermKernel(Arg arg)
188 typedef Matrix<complex<typename Arg::real>, Arg::nColor> Link;
189 int x_cb = blockIdx.x * blockDim.x + threadIdx.x;
// threads beyond the requested site count have nothing to do
190 if (x_cb >= arg.threads) return;
191 int parity = blockIdx.y * blockDim.y + threadIdx.y;
192 int sig = blockIdx.z * blockDim.z + threadIdx.z;
193 if (sig >= 4) return;
// presumably expands the checkerboard index on the regular grid X into coordinates x -- confirm against getCoords
196 getCoords(x, x_cb, arg.X, parity);
// shift by the border so that x addresses the extended grid E
198 for (int d=0; d<4; d++) x[d] += arg.border[d];
199 int e_cb = linkIndex(x,arg.E);
// read-modify-write of the force on this link
201 Link w = arg.oProd(sig, e_cb, parity);
202 Link force = arg.outA(sig, e_cb, parity);
203 force += arg.coeff * w;
204 arg.outA(sig, e_cb, parity) = force;
208 /********************************allLinkKernel*********************************************
210 * In this function we need
212 * 3 LINKS: ad_link, ab_link, bc_link
213 * 5 COLOR MATRIX: Qprev_at_D, oprod_at_C, newOprod_at_A(sig), newOprod_at_D/newOprod_at_A(mu), shortP_at_D
215 * 3 COLOR MATRIX: newOprod_at_A(sig), newOprod_at_D/newOprod_at_A(mu), shortP_at_D,
217 * If sig is negative, then we don't need to read/write the color matrix newOprod_at_A(sig)
219 * Therefore the data traffic, in two-number pair (num_of_link, num_of_color_matrix)
221 * if (sig is positive): (3, 8)
224 * This function is called 384 times, half positive sig, half negative sig
226 * Flop count, in two-number pair (matrix_multi, matrix_add)
227 * if(sig is positive) (6,3)
230 ************************************************************************************************/
231 template <int sig_positive, int mu_positive, typename Arg>
232 __global__ void allLinkKernel(Arg arg)
234 typedef Matrix<complex<typename Arg::real>, Arg::nColor> Link;
236 int x_cb = blockIdx.x * blockDim.x + threadIdx.x;
237 if (x_cb >= arg.threads) return;
238 int parity = blockIdx.y * blockDim.y + threadIdx.y;
241 getCoords(x, x_cb, arg.D, parity);
242 for (int d=0; d<4; d++) x[d] += arg.base_idx[d];
243 int e_cb = linkIndex(x,arg.E);
244 parity = parity^arg.oddness_change;
246 auto mycoeff = CoeffSign(sig_positive,parity)*arg.coeff;
248 int y[4] = {x[0], x[1], x[2], x[3]};
249 int mysig = posDir(arg.sig);
250 updateCoords(y, mysig, (sig_positive ? 1 : -1), arg);
251 int point_b = linkIndex(y,arg.E);
252 int ab_link_nbr_idx = (sig_positive) ? e_cb : point_b;
254 for (int d=0; d<4; d++) y[d] = x[d];
261 * A is the current point (sid)
265 int mu = mu_positive ? arg.mu : opp_dir(arg.mu);
266 int dir = mu_positive ? -1 : 1;
268 updateCoords(y, mu, dir, arg);
269 int point_d = linkIndex(y,arg.E);
270 updateCoords(y, mysig, (sig_positive ? 1 : -1), arg);
271 int point_c = linkIndex(y,arg.E);
273 Link Uab = arg.link(posDir(arg.sig), ab_link_nbr_idx, sig_positive^(1-parity));
274 Link Uad = arg.link(mu, mu_positive ? point_d : e_cb, mu_positive ? 1-parity : parity);
275 Link Ubc = arg.link(mu, mu_positive ? point_c : point_b, mu_positive ? parity : 1-parity);
276 Link Ox = arg.qPrev(0, point_d, 1-parity);
277 Link Oy = arg.oProd(0, point_c, parity);
278 Link Oz = mu_positive ? conj(Ubc)*Oy : Ubc*Oy;
281 Link force = arg.outA(arg.sig, e_cb, parity);
282 force += Sign(parity)*mycoeff*Oz*Ox* (mu_positive ? Uad : conj(Uad));
283 arg.outA(arg.sig, e_cb, parity) = force;
289 Link force = arg.outA(mu, mu_positive ? point_d : e_cb, mu_positive ? 1-parity : parity);
290 force += Sign(mu_positive ? 1-parity : parity)*mycoeff* (mu_positive ? Oy*Ox : conj(Ox)*conj(Oy));
291 arg.outA(mu, mu_positive ? point_d : e_cb, mu_positive ? 1-parity : parity) = force;
293 Link shortP = arg.outB(0, point_d, 1-parity);
294 shortP += arg.accumu_coeff* (mu_positive ? Uad : conj(Uad)) *Oy;
295 arg.outB(0, point_d, 1-parity) = shortP;
299 /**************************middleLinkKernel*****************************
304 * 3 LINKS: ab_link, bc_link, ad_link
305 * 3 COLOR MATRIX: newOprod_at_A, oprod_at_C, Qprod_at_D
307 * 4 COLOR MATRIX: newOprod_at_A, P3_at_A, Pmu_at_B, Qmu_at_A
309 * Three call variations:
310 * 1. when Qprev == NULL: Qprod_at_D does not exist and is not read in
312 * 3. when Pmu/Qmu == NULL, Pmu_at_B and Qmu_at_A are not written out
314 * In all three above case, if the direction sig is negative, newOprod_at_A is
315 * not read in or written out.
317 * Therefore the data traffic, in two-number pair (num_of_link, num_of_color_matrix)
318 * Call 1: (called 48 times, half positive sig, half negative sig)
319 * if (sig is positive): (3, 6)
321 * Call 2: (called 192 times, half positive sig, half negative sig)
322 * if (sig is positive): (3, 7)
324 * Call 3: (called 48 times, half positive sig, half negative sig)
325 * if (sig is positive): (3, 5)
326 * else : (3, 2) no need to load Qprod_at_D in this case
328 * note: oprod_at_C could actually be read in from D when it is the fresh outer product
329 * and we call it oprod_at_C to simplify naming. This does not affect our data traffic analysis
331 * Flop count, in two-number pair (matrix_multi, matrix_add)
332 * call 1: if (sig is positive) (3, 1)
334 * call 2: if (sig is positive) (4, 1)
336 * call 3: if (sig is positive) (4, 1)
337 * (Lepage) else (2, 0)
339 ****************************************************************************/
340 template <int sig_positive, int mu_positive, bool pMu, bool qMu, bool qPrev, typename Arg>
341 __global__ void middleLinkKernel(Arg arg)
343 typedef Matrix<complex<typename Arg::real>, Arg::nColor> Link;
345 int x_cb = blockIdx.x * blockDim.x + threadIdx.x;
346 if (x_cb >= arg.threads) return;
347 int parity = blockIdx.y * blockDim.y + threadIdx.y;
350 getCoords(x, x_cb, arg.D, parity);
356 * A is the current point (sid)
360 for (int d=0; d<4; d++) x[d] += arg.base_idx[d];
361 int e_cb = linkIndex(x,arg.E);
362 parity = parity ^ arg.oddness_change;
363 int y[4] = {x[0], x[1], x[2], x[3]};
365 int mymu = posDir(arg.mu);
366 updateCoords(y, mymu, (mu_positive ? -1 : 1), arg);
368 int point_d = linkIndex(y, arg.E);
369 int ad_link_nbr_idx = mu_positive ? point_d : e_cb;
371 int mysig = posDir(arg.sig);
372 updateCoords(y, mysig, (sig_positive ? 1 : -1), arg);
373 int point_c = linkIndex(y, arg.E);
375 for (int d=0; d<4; d++) y[d] = x[d];
376 updateCoords(y, mysig, (sig_positive ? 1 : -1), arg);
377 int point_b = linkIndex(y, arg.E);
379 int bc_link_nbr_idx = mu_positive ? point_c : point_b;
380 int ab_link_nbr_idx = sig_positive ? e_cb : point_b;
382 // load the link variable connecting a and b
383 Link Uab = arg.link(mysig, ab_link_nbr_idx, sig_positive^(1-parity));
385 // load the link variable connecting b and c
386 Link Ubc = arg.link(mymu, bc_link_nbr_idx, mu_positive^(1-parity));
390 Oy = arg.oProd(posDir(arg.sig), sig_positive ? point_d : point_c, sig_positive^parity);
391 if (!sig_positive) Oy = conj(Oy);
392 } else { // QprevOdd != NULL
393 Oy = arg.oProd(0, point_c, parity);
396 Link Ow = !mu_positive ? Ubc*Oy : conj(Ubc)*Oy;
398 if (pMu) arg.pMu(0, point_b, 1-parity) = Ow;
400 arg.p3(0, e_cb, parity) = sig_positive ? Uab*Ow : conj(Uab)*Ow;
402 Link Uad = arg.link(mymu, ad_link_nbr_idx, mu_positive^parity);
403 if (!mu_positive) Uad = conj(Uad);
406 if (sig_positive) Oy = Ow*Uad;
407 if ( qMu ) arg.qMu(0, e_cb, parity) = Uad;
410 if ( qMu || sig_positive ) {
411 Oy = arg.qPrev(0, point_d, 1-parity);
414 if ( qMu ) arg.qMu(0, e_cb, parity) = Ox;
415 if (sig_positive) Oy = Ow*Ox;
419 Link oprod = arg.outA(arg.sig, e_cb, parity);
420 oprod += arg.coeff*Oy;
421 arg.outA(arg.sig, e_cb, parity) = oprod;
426 /***********************************sideLinkKernel***************************
431 * 4 COLOR MATRIX: shortP_at_D, newOprod, P3_at_A, Qprod_at_D,
433 * 2 COLOR MATRIX: shortP_at_D, newOprod,
435 * Two call variations:
437 * 2. when shortP == NULL && Qprod == NULL:
438 * no need to read ad_link/shortP_at_D or write shortP_at_D
439 * Qprod_at_D does not exist and is not read in
442 * Therefore the data traffic, in two-number pair (num_of_links, num_of_color_matrix)
443 * Call 1: (called 192 times)
446 * Call 2: (called 48 times)
449 * note: newOprod can be at point D or A, depending on if mu is positive or negative
451 * Flop count, in two-number pair (matrix_multi, matrix_add)
455 *********************************************************************************/
456 template <int mu_positive, typename Arg>
457 __global__ void sideLinkKernel(Arg arg)
459 typedef Matrix<complex<typename Arg::real>, Arg::nColor> Link;
460 int x_cb = blockIdx.x * blockDim.x + threadIdx.x;
461 if (x_cb >= arg.threads) return;
462 int parity = blockIdx.y * blockDim.y + threadIdx.y;
465 getCoords(x, x_cb ,arg.D, parity);
466 for (int d=0; d<4; d++) x[d] = x[d] + arg.base_idx[d];
467 int e_cb = linkIndex(x,arg.E);
468 parity = parity ^ arg.oddness_change;
470 /* compute the side link contribution to the momentum
477 * A is the current point (x_cb)
481 int mymu = posDir(arg.mu);
482 int y[4] = {x[0], x[1], x[2], x[3]};
483 updateCoords(y, mymu, (mu_positive ? -1 : 1), arg);
484 int point_d = linkIndex(y,arg.E);
486 Link Oy = arg.p3(0, e_cb, parity);
489 int ad_link_nbr_idx = mu_positive ? point_d : e_cb;
491 Link Uad = arg.link(mymu, ad_link_nbr_idx, mu_positive^parity);
492 Link Ow = mu_positive ? Uad*Oy : conj(Uad)*Oy;
494 Link shortP = arg.outB(0, point_d, 1-parity);
495 shortP += arg.accumu_coeff * Ow;
496 arg.outB(0, point_d, 1-parity) = shortP;
500 Link Ox = arg.qProd(0, point_d, 1-parity);
501 Link Ow = mu_positive ? Oy*Ox : conj(Ox)*conj(Oy);
503 auto mycoeff = CoeffSign(goes_forward(arg.sig), parity)*CoeffSign(goes_forward(arg.mu),parity)*arg.coeff;
505 Link oprod = arg.outA(mu_positive ? arg.mu : opp_dir(arg.mu), mu_positive ? point_d : e_cb, mu_positive ? 1-parity : parity);
506 oprod += mycoeff * Ow;
507 arg.outA(mu_positive ? arg.mu : opp_dir(arg.mu), mu_positive ? point_d : e_cb, mu_positive ? 1-parity : parity) = oprod;
511 // Flop count, in two-number pair (matrix_mult, matrix_add)
// Short side-link update: adds mycoeff * p3 (or its conjugate when mu_positive is 0) from the
// current site into arg.outA in direction posDir(arg.mu).
// Thread mapping: x -> checkerboard site index, y -> parity.
513 template <int mu_positive, typename Arg>
514 __global__ void sideLinkShortKernel(Arg arg)
516 typedef Matrix<complex<typename Arg::real>, Arg::nColor> Link;
517 int x_cb = blockIdx.x * blockDim.x + threadIdx.x;
518 if (x_cb >= arg.threads) return;
519 int parity = blockIdx.y * blockDim.y + threadIdx.y;
// coordinates are generated on the working grid D, then offset by base_idx into the extended grid E
522 getCoords(x, x_cb, arg.D, parity);
523 for (int d=0; d<4; d++) x[d] = x[d] + arg.base_idx[d];
524 int e_cb = linkIndex(x,arg.E);
// account for the parity flip caused by the base_idx offset
525 parity = parity ^ arg.oddness_change;
527 /* compute the side link contribution to the momentum
534 * A is the current point (x_cb)
537 int mymu = posDir(arg.mu);
538 int y[4] = {x[0], x[1], x[2], x[3]};
// y is the neighbour one site away along mymu: backwards when mu_positive, forwards otherwise
539 updateCoords(y, mymu, (mu_positive ? -1 : 1), arg);
// the output site is that neighbour only when mu_positive; otherwise it is the current site
540 int point_d = mu_positive ? linkIndex(y,arg.E) : e_cb;
// the neighbour has the opposite parity to the current site
542 int parity_ = mu_positive ? 1-parity : parity;
543 auto mycoeff = CoeffSign(goes_forward(arg.sig),parity)*CoeffSign(goes_forward(arg.mu),parity)*arg.coeff;
// read-modify-write of the output link at (posDir(mu), point_d, parity_)
545 Link Oy = arg.p3(0, e_cb, parity);
546 Link oprod = arg.outA(posDir(arg.mu), point_d, parity_);
547 oprod += mu_positive ? mycoeff * Oy : mycoeff * conj(Oy);
548 arg.outA(posDir(arg.mu), point_d, parity_) = oprod;
551 template <typename Arg>
552 class FatLinkForce : public TunableVectorYZ {
555 const GaugeField &meta;
556 const HisqForceType type;
// Reports the thread count stored in arg.threads.
unsigned int minThreads() const
{
  return arg.threads;
}
// Always false. NOTE(review): presumably tells the autotuner not to vary the grid dimension -- confirm against the base class.
bool tuneGridDim() const
{
  return false;
}
562 FatLinkForce(Arg &arg, const GaugeField &meta, int sig, int mu, HisqForceType type)
563 : TunableVectorYZ(2,type == FORCE_ONE_LINK ? 4 : 1), arg(arg), meta(meta), type(type) {
568 TuneKey tuneKey() const {
569 std::stringstream aux;
570 aux << meta.AuxString() << comm_dim_partitioned_string() << ",threads=" << arg.threads;
571 if (type == FORCE_MIDDLE_LINK || type == FORCE_LEPAGE_MIDDLE_LINK)
572 aux << ",sig=" << arg.sig << ",mu=" << arg.mu << ",pMu=" << arg.p_mu << ",q_muu=" << arg.q_mu << ",q_prev=" << arg.q_prev;
573 else if (type != FORCE_ONE_LINK)
574 aux << ",mu=" << arg.mu; // no sig dependence needed for side link
577 case FORCE_ONE_LINK: aux << ",ONE_LINK"; break;
578 case FORCE_ALL_LINK: aux << ",ALL_LINK"; break;
579 case FORCE_MIDDLE_LINK: aux << ",MIDDLE_LINK"; break;
580 case FORCE_LEPAGE_MIDDLE_LINK: aux << ",LEPAGE_MIDDLE_LINK"; break;
581 case FORCE_SIDE_LINK: aux << ",SIDE_LINK"; break;
582 case FORCE_SIDE_LINK_SHORT: aux << ",SIDE_LINK_SHORT"; break;
583 default: errorQuda("Undefined force type %d", type);
585 return TuneKey(meta.VolString(), typeid(*this).name(), aux.str().c_str());
588 void apply(const qudaStream_t &stream)
590 TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
593 qudaLaunchKernel(oneLinkTermKernel<Arg>, tp, stream, arg);
596 if (goes_forward(arg.sig) && goes_forward(arg.mu))
597 qudaLaunchKernel(allLinkKernel<1,1,Arg>, tp, stream, arg);
598 else if (goes_forward(arg.sig) && goes_backward(arg.mu))
599 qudaLaunchKernel(allLinkKernel<1,0,Arg>, tp, stream, arg);
600 else if (goes_backward(arg.sig) && goes_forward(arg.mu))
601 qudaLaunchKernel(allLinkKernel<0,1,Arg>, tp, stream, arg);
603 qudaLaunchKernel(allLinkKernel<0,0,Arg>, tp, stream, arg);
605 case FORCE_MIDDLE_LINK:
606 if (!arg.p_mu || !arg.q_mu) errorQuda("Expect p_mu=%d and q_mu=%d to both be true", arg.p_mu, arg.q_mu);
608 if (goes_forward(arg.sig) && goes_forward(arg.mu))
609 qudaLaunchKernel(middleLinkKernel<1,1,true,true,true,Arg>, tp, stream, arg);
610 else if (goes_forward(arg.sig) && goes_backward(arg.mu))
611 qudaLaunchKernel(middleLinkKernel<1,0,true,true,true,Arg>, tp, stream, arg);
612 else if (goes_backward(arg.sig) && goes_forward(arg.mu))
613 qudaLaunchKernel(middleLinkKernel<0,1,true,true,true,Arg>, tp, stream, arg);
615 qudaLaunchKernel(middleLinkKernel<0,0,true,true,true,Arg>, tp, stream, arg);
617 if (goes_forward(arg.sig) && goes_forward(arg.mu))
618 qudaLaunchKernel(middleLinkKernel<1,1,true,true,false,Arg>, tp, stream, arg);
619 else if (goes_forward(arg.sig) && goes_backward(arg.mu))
620 qudaLaunchKernel(middleLinkKernel<1,0,true,true,false,Arg>, tp, stream, arg);
621 else if (goes_backward(arg.sig) && goes_forward(arg.mu))
622 qudaLaunchKernel(middleLinkKernel<0,1,true,true,false,Arg>, tp, stream, arg);
624 qudaLaunchKernel(middleLinkKernel<0,0,true,true,false,Arg>, tp, stream, arg);
627 case FORCE_LEPAGE_MIDDLE_LINK:
628 if (arg.p_mu || arg.q_mu || !arg.q_prev)
629 errorQuda("Expect p_mu=%d and q_mu=%d to both be false and q_prev=%d true", arg.p_mu, arg.q_mu, arg.q_prev);
630 if (goes_forward(arg.sig) && goes_forward(arg.mu))
631 qudaLaunchKernel(middleLinkKernel<1,1,false,false,true,Arg>, tp, stream, arg);
632 else if (goes_forward(arg.sig) && goes_backward(arg.mu))
633 qudaLaunchKernel(middleLinkKernel<1,0,false,false,true,Arg>, tp, stream, arg);
634 else if (goes_backward(arg.sig) && goes_forward(arg.mu))
635 qudaLaunchKernel(middleLinkKernel<0,1,false,false,true,Arg>, tp, stream, arg);
637 qudaLaunchKernel(middleLinkKernel<0,0,false,false,true,Arg>, tp, stream, arg);
639 case FORCE_SIDE_LINK:
640 if (goes_forward(arg.mu)) qudaLaunchKernel(sideLinkKernel<1,Arg>, tp, stream, arg);
641 else qudaLaunchKernel(sideLinkKernel<0,Arg>, tp, stream, arg);
643 case FORCE_SIDE_LINK_SHORT:
644 if (goes_forward(arg.mu)) qudaLaunchKernel(sideLinkShortKernel<1,Arg>, tp, stream, arg);
645 else qudaLaunchKernel(sideLinkShortKernel<0,Arg>, tp, stream, arg);
648 errorQuda("Undefined force type %d", type);
661 case FORCE_MIDDLE_LINK:
664 case FORCE_LEPAGE_MIDDLE_LINK:
668 case FORCE_SIDE_LINK:
670 case FORCE_SIDE_LINK_SHORT:
673 default: errorQuda("Undefined force type %d", type);
686 case FORCE_MIDDLE_LINK:
689 case FORCE_LEPAGE_MIDDLE_LINK:
693 case FORCE_SIDE_LINK:
695 case FORCE_SIDE_LINK_SHORT:
698 default: errorQuda("Undefined force type %d", type);
702 long long flops() const {
705 return 2*4*arg.threads*36ll;
707 return 2*arg.threads*(goes_forward(arg.sig) ? 1242ll : 828ll);
708 case FORCE_MIDDLE_LINK:
709 case FORCE_LEPAGE_MIDDLE_LINK:
710 return 2*arg.threads*(2 * 198 +
711 (!arg.q_prev && goes_forward(arg.sig) ? 198 : 0) +
712 (arg.q_prev && (arg.q_mu || goes_forward(arg.sig) ) ? 198 : 0) +
713 ((arg.q_prev && goes_forward(arg.sig) ) ? 198 : 0) +
714 ( goes_forward(arg.sig) ? 216 : 0) );
715 case FORCE_SIDE_LINK: return 2*arg.threads*2*234;
716 case FORCE_SIDE_LINK_SHORT: return 2*arg.threads*36;
717 default: errorQuda("Undefined force type %d", type);
722 long long bytes() const {
725 return 2*4*arg.threads*( arg.oProd.Bytes() + 2*arg.outA.Bytes() );
727 return 2*arg.threads*( (goes_forward(arg.sig) ? 4 : 2)*arg.outA.Bytes() + 3*arg.link.Bytes()
728 + arg.oProd.Bytes() + arg.qPrev.Bytes() + 2*arg.outB.Bytes());
729 case FORCE_MIDDLE_LINK:
730 case FORCE_LEPAGE_MIDDLE_LINK:
731 return 2*arg.threads*( ( goes_forward(arg.sig) ? 2*arg.outA.Bytes() : 0 ) +
732 (arg.p_mu ? arg.pMu.Bytes() : 0) +
733 (arg.q_mu ? arg.qMu.Bytes() : 0) +
734 ( ( goes_forward(arg.sig) || arg.q_mu ) ? arg.qPrev.Bytes() : 0) +
735 arg.p3.Bytes() + 3*arg.link.Bytes() + arg.oProd.Bytes() );
736 case FORCE_SIDE_LINK:
737 return 2*arg.threads*( 2*arg.outA.Bytes() + 2*arg.outB.Bytes() +
738 arg.p3.Bytes() + arg.link.Bytes() + arg.qProd.Bytes() );
739 case FORCE_SIDE_LINK_SHORT:
740 return 2*arg.threads*( 2*arg.outA.Bytes() + arg.p3.Bytes() );
741 default: errorQuda("Undefined force type %d", type);
747 template <typename real, int nColor, QudaReconstructType recon>
748 struct HisqStaplesForce {
749 HisqStaplesForce(GaugeField &Pmu, GaugeField &P3, GaugeField &P5, GaugeField &Pnumu,
750 GaugeField &Qmu, GaugeField &Qnumu, GaugeField &newOprod,
751 const GaugeField &oprod, const GaugeField &link,
752 const double *path_coeff_array)
754 PathCoefficients<real> act_path_coeff(path_coeff_array);
755 real OneLink = act_path_coeff.one;
756 real ThreeSt = act_path_coeff.three;
757 real mThreeSt = -ThreeSt;
758 real FiveSt = act_path_coeff.five;
759 real mFiveSt = -FiveSt;
760 real SevenSt = act_path_coeff.seven;
761 real Lepage = act_path_coeff.lepage;
762 real mLepage = -Lepage;
764 FatLinkArg<real, nColor> arg(newOprod, oprod, link, OneLink, FORCE_ONE_LINK);
765 FatLinkForce<decltype(arg)> oneLink(arg, link, 0, 0, FORCE_ONE_LINK);
768 for (int sig=0; sig<8; sig++) {
769 for (int mu=0; mu<8; mu++) {
770 if ( (mu == sig) || (mu == opp_dir(sig))) continue;
773 //Kernel A: middle link
774 FatLinkArg<real, nColor> middleLinkArg( newOprod, Pmu, P3, Qmu, oprod, link, mThreeSt, 2, FORCE_MIDDLE_LINK);
775 FatLinkForce<decltype(arg)> middleLink(middleLinkArg, link, sig, mu, FORCE_MIDDLE_LINK);
778 for (int nu=0; nu < 8; nu++) {
779 if (nu == sig || nu == opp_dir(sig) || nu == mu || nu == opp_dir(mu)) continue;
781 //5-link: middle link
783 FatLinkArg<real, nColor> middleLinkArg( newOprod, Pnumu, P5, Qnumu, Pmu, Qmu, link, FiveSt, 1, FORCE_MIDDLE_LINK);
784 FatLinkForce<decltype(arg)> middleLink(middleLinkArg, link, sig, nu, FORCE_MIDDLE_LINK);
787 for (int rho = 0; rho < 8; rho++) {
788 if (rho == sig || rho == opp_dir(sig) || rho == mu || rho == opp_dir(mu) || rho == nu || rho == opp_dir(nu)) continue;
790 //7-link: middle link and side link
791 FatLinkArg<real, nColor> arg(newOprod, P5, Pnumu, Qnumu, link, SevenSt, FiveSt != 0 ? SevenSt/FiveSt : 0, 1, FORCE_ALL_LINK, true);
792 FatLinkForce<decltype(arg)> all(arg, link, sig, rho, FORCE_ALL_LINK);
798 FatLinkArg<real, nColor> arg(newOprod, P3, P5, Qmu, link, mFiveSt, (ThreeSt != 0 ? FiveSt/ThreeSt : 0), 1, FORCE_SIDE_LINK);
799 FatLinkForce<decltype(arg)> side(arg, link, sig, nu, FORCE_SIDE_LINK);
806 FatLinkArg<real, nColor> middleLinkArg( newOprod, P5, Pmu, Qmu, link, Lepage, 2, FORCE_LEPAGE_MIDDLE_LINK);
807 FatLinkForce<decltype(arg)> middleLink(middleLinkArg, link, sig, mu, FORCE_LEPAGE_MIDDLE_LINK);
810 FatLinkArg<real, nColor> arg(newOprod, P3, P5, Qmu, link, mLepage, (ThreeSt != 0 ? Lepage/ThreeSt : 0), 2, FORCE_SIDE_LINK);
811 FatLinkForce<decltype(arg)> side(arg, link, sig, mu, FORCE_SIDE_LINK);
816 FatLinkArg<real, nColor> arg(newOprod, P3, link, ThreeSt, 1, FORCE_SIDE_LINK_SHORT);
817 FatLinkForce<decltype(arg)> side(arg, P3, sig, mu, FORCE_SIDE_LINK_SHORT);
824 void hisqStaplesForce(GaugeField &newOprod, const GaugeField &oprod, const GaugeField &link, const double path_coeff_array[6])
826 if (!link.isNative()) errorQuda("Unsupported gauge order %d", link.Order());
827 if (!oprod.isNative()) errorQuda("Unsupported gauge order %d", oprod.Order());
828 if (!newOprod.isNative()) errorQuda("Unsupported gauge order %d", newOprod.Order());
829 if (checkLocation(newOprod,oprod,link) == QUDA_CPU_FIELD_LOCATION) errorQuda("CPU not implemented");
831 // create color matrix fields with zero padding
832 GaugeFieldParam gauge_param(link);
833 gauge_param.reconstruct = QUDA_RECONSTRUCT_NO;
834 gauge_param.order = QUDA_FLOAT2_GAUGE_ORDER;
835 gauge_param.geometry = QUDA_SCALAR_GEOMETRY;
837 cudaGaugeField Pmu(gauge_param);
838 cudaGaugeField P3(gauge_param);
839 cudaGaugeField P5(gauge_param);
840 cudaGaugeField Pnumu(gauge_param);
841 cudaGaugeField Qmu(gauge_param);
842 cudaGaugeField Qnumu(gauge_param);
844 QudaPrecision precision = checkPrecision(oprod, link, newOprod);
845 instantiate<HisqStaplesForce, ReconstructNone>(Pmu, P3, P5, Pnumu, Qmu, Qnumu, newOprod, oprod, link, path_coeff_array);
847 qudaDeviceSynchronize();
850 template <typename real, int nColor, QudaReconstructType reconstruct=QUDA_RECONSTRUCT_NO>
851 struct CompleteForceArg : public BaseForceArg<real, nColor, reconstruct> {
853 typedef typename gauge_mapper<real,QUDA_RECONSTRUCT_NO>::type F;
854 F outA; // force output accessor
855 const F oProd; // force input accessor
858 CompleteForceArg(GaugeField &force, const GaugeField &link)
859 : BaseForceArg<real, nColor, reconstruct>(link, 0), outA(force), oProd(force), coeff(0.0)
864 // Flops count: 4 matrix multiplications per lattice site = 792 Flops per site
865 template <typename Arg>
866 __global__ void completeForceKernel(Arg arg)
868 typedef Matrix<complex<typename Arg::real>, Arg::nColor> Link;
869 int x_cb = blockIdx.x * blockDim.x + threadIdx.x;
870 if (x_cb >= arg.threads) return;
871 int parity = blockIdx.y * blockDim.y + threadIdx.y;
874 getCoords(x, x_cb, arg.X, parity);
876 for (int d=0; d<4; d++) x[d] += arg.border[d];
877 int e_cb = linkIndex(x,arg.E);
880 for (int sig=0; sig<4; ++sig) {
881 Link Uw = arg.link(sig, e_cb, parity);
882 Link Ox = arg.oProd(sig, e_cb, parity);
887 typename Arg::real coeff = (parity==1) ? -1.0 : 1.0;
888 arg.outA(sig, e_cb, parity) = coeff*Ow;
892 template <typename real, int nColor, QudaReconstructType reconstruct=QUDA_RECONSTRUCT_NO>
893 struct LongLinkArg : public BaseForceArg<real, nColor, reconstruct> {
895 typedef typename gauge::FloatNOrder<real,18,2,11> M;
896 typedef typename gauge_mapper<real,QUDA_RECONSTRUCT_NO>::type F;
901 LongLinkArg(GaugeField &newOprod, const GaugeField &link, const GaugeField &oprod, real coeff)
902 : BaseForceArg<real, nColor, reconstruct>(link,0), outA(newOprod), oProd(oprod), coeff(coeff)
907 // Flops count, in two-number pair (matrix_mult, matrix_add)
909 // 4968 Flops per site in total
910 template <typename Arg>
911 __global__ void longLinkKernel(Arg arg)
913 typedef Matrix<complex<typename Arg::real>, Arg::nColor> Link;
914 int x_cb = blockIdx.x * blockDim.x + threadIdx.x;
915 if (x_cb >= arg.threads) return;
916 int parity = blockIdx.y * blockDim.y + threadIdx.y;
919 int dx[4] = {0,0,0,0};
921 getCoords(x, x_cb, arg.X, parity);
923 for (int i=0; i<4; i++) x[i] += arg.border[i];
924 int e_cb = linkIndex(x,arg.E);
929 * ---- ---- ---- ----
933 * C is the current point (sid)
937 // compute the force for forward long links
939 for (int sig=0; sig<4; sig++) {
943 int point_d = linkIndexShift(x,dx,arg.E);
946 int point_e = linkIndexShift(x,dx,arg.E);
949 int point_b = linkIndexShift(x,dx,arg.E);
952 int point_a = linkIndexShift(x,dx,arg.E);
955 Link Uab = arg.link(sig, point_a, parity);
956 Link Ubc = arg.link(sig, point_b, 1-parity);
957 Link Ude = arg.link(sig, point_d, 1-parity);
958 Link Uef = arg.link(sig, point_e, parity);
960 Link Oz = arg.oProd(sig, point_c, parity);
961 Link Oy = arg.oProd(sig, point_b, 1-parity);
962 Link Ox = arg.oProd(sig, point_a, parity);
964 Link temp = Ude*Uef*Oz - Ude*Oy*Ubc + Ox*Uab*Ubc;
966 Link force = arg.outA(sig, e_cb, parity);
967 arg.outA(sig, e_cb, parity) = force + arg.coeff*temp;
972 template <typename Arg>
973 class HisqForce : public TunableVectorY {
976 const GaugeField &meta;
977 const HisqForceType type;
// Reports the thread count stored in arg.threads.
unsigned int minThreads() const
{
  return arg.threads;
}
// Always false. NOTE(review): presumably tells the autotuner not to vary the grid dimension -- confirm against the base class.
bool tuneGridDim() const
{
  return false;
}
983 HisqForce(Arg &arg, const GaugeField &meta, int sig, int mu, HisqForceType type)
984 : TunableVectorY(2), arg(arg), meta(meta), type(type) {
989 void apply(const qudaStream_t &stream) {
990 TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
992 case FORCE_LONG_LINK: qudaLaunchKernel(longLinkKernel<Arg>, tp, stream, arg); break;
993 case FORCE_COMPLETE: qudaLaunchKernel(completeForceKernel<Arg>, tp, stream, arg); break;
995 errorQuda("Undefined force type %d", type);
999 TuneKey tuneKey() const {
1000 std::stringstream aux;
1001 aux << meta.AuxString() << comm_dim_partitioned_string() << ",threads=" << arg.threads;
1003 case FORCE_LONG_LINK: aux << ",LONG_LINK"; break;
1004 case FORCE_COMPLETE: aux << ",COMPLETE"; break;
1005 default: errorQuda("Undefined force type %d", type);
1007 return TuneKey(meta.VolString(), typeid(*this).name(), aux.str().c_str());
1012 case FORCE_LONG_LINK:
1013 case FORCE_COMPLETE:
1014 arg.outA.save(); break;
1015 default: errorQuda("Undefined force type %d", type);
1021 case FORCE_LONG_LINK:
1022 case FORCE_COMPLETE:
1023 arg.outA.load(); break;
1024 default: errorQuda("Undefined force type %d", type);
1028 long long flops() const {
1030 case FORCE_LONG_LINK: return 2*arg.threads*4968ll;
1031 case FORCE_COMPLETE: return 2*arg.threads*792ll;
1032 default: errorQuda("Undefined force type %d", type);
1037 long long bytes() const {
1039 case FORCE_LONG_LINK: return 4*2*arg.threads*(2*arg.outA.Bytes() + 4*arg.link.Bytes() + 3*arg.oProd.Bytes());
1040 case FORCE_COMPLETE: return 4*2*arg.threads*(arg.outA.Bytes() + arg.link.Bytes() + arg.oProd.Bytes());
1041 default: errorQuda("Undefined force type %d", type);
1047 template <typename real, int nColor, QudaReconstructType recon>
1048 struct HisqLongLinkForce {
1049 HisqLongLinkForce(GaugeField &newOprod, const GaugeField &oldOprod, const GaugeField &link, double coeff)
1051 LongLinkArg<real, nColor, recon> arg(newOprod, link, oldOprod, coeff);
1052 HisqForce<decltype(arg)> longLink(arg, link, 0, 0, FORCE_LONG_LINK);
1054 qudaDeviceSynchronize();
1058 void hisqLongLinkForce(GaugeField &newOprod, const GaugeField &oldOprod, const GaugeField &link, double coeff)
1060 if (!link.isNative()) errorQuda("Unsupported gauge order %d", link.Order());
1061 if (!oldOprod.isNative()) errorQuda("Unsupported gauge order %d", oldOprod.Order());
1062 if (!newOprod.isNative()) errorQuda("Unsupported gauge order %d", newOprod.Order());
1063 if (checkLocation(newOprod,oldOprod,link) == QUDA_CPU_FIELD_LOCATION) errorQuda("CPU not implemented");
1064 checkPrecision(newOprod, link, oldOprod);
1065 instantiate<HisqLongLinkForce, ReconstructNone>(newOprod, oldOprod, link, coeff);
1068 template <typename real, int nColor, QudaReconstructType recon>
1069 struct HisqCompleteForce {
1070 HisqCompleteForce(GaugeField &force, const GaugeField &link)
1072 CompleteForceArg<real, nColor, recon> arg(force, link);
1073 HisqForce<decltype(arg)> completeForce(arg, link, 0, 0, FORCE_COMPLETE);
1074 completeForce.apply(0);
1075 qudaDeviceSynchronize();
1079 void hisqCompleteForce(GaugeField &force, const GaugeField &link)
1081 if (!link.isNative()) errorQuda("Unsupported gauge order %d", link.Order());
1082 if (!force.isNative()) errorQuda("Unsupported gauge order %d", force.Order());
1083 if (checkLocation(force,link) == QUDA_CPU_FIELD_LOCATION) errorQuda("CPU not implemented");
1084 checkPrecision(link, force);
1085 instantiate<HisqCompleteForce, ReconstructNone>(force, link);
1088 } // namespace fermion_force
1092 #endif // GPU_HISQ_FORCE