QUDA  v1.1.0
A library for QCD on GPUs
hisq_paths_force_quda.cu
Go to the documentation of this file.
1 #include <utility>
2 #include <quda_internal.h>
3 #include <gauge_field.h>
4 #include <ks_improved_force.h>
5 #include <quda_matrix.h>
6 #include <tune_quda.h>
7 #include <index_helper.cuh>
8 #include <gauge_field_order.h>
9 #include <instantiate.h>
10 
11 #ifdef GPU_HISQ_FORCE
12 
13 namespace quda {
14 
15  namespace fermion_force {
16 
17  enum {
18  XUP = 0,
19  YUP = 1,
20  ZUP = 2,
21  TUP = 3,
22  TDOWN = 4,
23  ZDOWN = 5,
24  YDOWN = 6,
25  XDOWN = 7
26  };
27 
28  enum HisqForceType {
29  FORCE_ALL_LINK,
30  FORCE_MIDDLE_LINK,
31  FORCE_LEPAGE_MIDDLE_LINK,
32  FORCE_SIDE_LINK,
33  FORCE_SIDE_LINK_SHORT,
34  FORCE_LONG_LINK,
35  FORCE_COMPLETE,
36  FORCE_ONE_LINK,
37  FORCE_INVALID
38  };
39 
40  constexpr int opp_dir(int dir) { return 7-dir; }
41  constexpr int goes_forward(int dir) { return dir<=3; }
42  constexpr int goes_backward(int dir) { return dir>3; }
43  constexpr int CoeffSign(int pos_dir, int odd_lattice) { return 2*((pos_dir + odd_lattice + 1) & 1) - 1; }
44  constexpr int Sign(int parity) { return parity ? -1 : 1; }
45  constexpr int posDir(int dir) { return (dir >= 4) ? 7-dir : dir; }
46 
47  template <int dir, typename Arg>
48  constexpr void updateCoords(int x[], int shift, const Arg &arg) {
49  x[dir] = (x[dir] + shift + arg.E[dir]) % arg.E[dir];
50  }
51 
52  template <typename Arg>
53  constexpr void updateCoords(int x[], int dir, int shift, const Arg &arg) {
54  switch (dir) {
55  case 0: updateCoords<0>(x, shift, arg); break;
56  case 1: updateCoords<1>(x, shift, arg); break;
57  case 2: updateCoords<2>(x, shift, arg); break;
58  case 3: updateCoords<3>(x, shift, arg); break;
59  }
60  }
61 
62  //struct for holding the fattening path coefficients
63  template <typename real>
64  struct PathCoefficients {
65  const real one;
66  const real three;
67  const real five;
68  const real seven;
69  const real naik;
70  const real lepage;
71  PathCoefficients(const double *path_coeff_array)
72  : one(path_coeff_array[0]), naik(path_coeff_array[1]),
73  three(path_coeff_array[2]), five(path_coeff_array[3]),
74  seven(path_coeff_array[4]), lepage(path_coeff_array[5]) { }
75  };
76 
77  template <typename real_, int nColor_, QudaReconstructType reconstruct=QUDA_RECONSTRUCT_NO>
78  struct BaseForceArg {
79  using real = real_;
80  static constexpr int nColor = nColor_;
81  typedef typename gauge_mapper<real,reconstruct>::type G;
82  const G link;
83  int threads;
84  int X[4]; // regular grid dims
85  int D[4]; // working set grid dims
86  int E[4]; // extended grid dims
87 
88  int commDim[4];
89  int border[4];
90  int base_idx[4]; // the offset into the extended field
91  int oddness_change;
92  int mu;
93  int sig;
94 
95  /**
96  @param[in] link Gauge field
97  @param[in] overlap Radius of additional redundant computation to do
98  */
99  BaseForceArg(const GaugeField &link, int overlap) : link(link), threads(1),
100  commDim{ comm_dim_partitioned(0), comm_dim_partitioned(1), comm_dim_partitioned(2), comm_dim_partitioned(3) }
101  {
102  for (int d=0; d<4; d++) {
103  E[d] = link.X()[d];
104  border[d] = link.R()[d];
105  X[d] = E[d] - 2*border[d];
106  D[d] = comm_dim_partitioned(d) ? X[d]+overlap*2 : X[d];
107  base_idx[d] = comm_dim_partitioned(d) ? border[d]-overlap : 0;
108  threads *= D[d];
109  }
110  threads /= 2;
111  oddness_change = (base_idx[0] + base_idx[1] + base_idx[2] + base_idx[3])&1;
112  }
113  };
114 
115  template <typename real, int nColor, QudaReconstructType reconstruct=QUDA_RECONSTRUCT_NO>
116  struct FatLinkArg : public BaseForceArg<real, nColor, reconstruct> {
117  using BaseForceArg = BaseForceArg<real, nColor, reconstruct>;
118  typedef typename gauge_mapper<real,QUDA_RECONSTRUCT_NO>::type F;
119  F outA;
120  F outB;
121  F pMu;
122  F p3;
123  F qMu;
124 
125  const F oProd;
126  const F qProd;
127  const F qPrev;
128  const real coeff;
129  const real accumu_coeff;
130 
131  const bool p_mu;
132  const bool q_mu;
133  const bool q_prev;
134 
135  FatLinkArg(GaugeField &force, const GaugeField &oProd, const GaugeField &link, real coeff, HisqForceType type)
136  : BaseForceArg(link, 0), outA(force), outB(force), pMu(oProd), p3(oProd), qMu(oProd),
137  oProd(oProd), qProd(oProd), qPrev(oProd), coeff(coeff), accumu_coeff(0),
138  p_mu(false), q_mu(false), q_prev(false)
139  { if (type != FORCE_ONE_LINK) errorQuda("This constructor is for FORCE_ONE_LINK"); }
140 
141  FatLinkArg(GaugeField &newOprod, GaugeField &pMu, GaugeField &P3, GaugeField &qMu,
142  const GaugeField &oProd, const GaugeField &qPrev, const GaugeField &link,
143  real coeff, int overlap, HisqForceType type)
144  : BaseForceArg(link, overlap), outA(newOprod), outB(newOprod), pMu(pMu), p3(P3), qMu(qMu),
145  oProd(oProd), qProd(oProd), qPrev(qPrev), coeff(coeff), accumu_coeff(0), p_mu(true), q_mu(true), q_prev(true)
146  { if (type != FORCE_MIDDLE_LINK) errorQuda("This constructor is for FORCE_MIDDLE_LINK"); }
147 
148  FatLinkArg(GaugeField &newOprod, GaugeField &pMu, GaugeField &P3, GaugeField &qMu,
149  const GaugeField &oProd, const GaugeField &link,
150  real coeff, int overlap, HisqForceType type)
151  : BaseForceArg(link, overlap), outA(newOprod), outB(newOprod), pMu(pMu), p3(P3), qMu(qMu),
152  oProd(oProd), qProd(oProd), qPrev(qMu), coeff(coeff), accumu_coeff(0), p_mu(true), q_mu(true), q_prev(false)
153  { if (type != FORCE_MIDDLE_LINK) errorQuda("This constructor is for FORCE_MIDDLE_LINK"); }
154 
155  FatLinkArg(GaugeField &newOprod, GaugeField &P3, const GaugeField &oProd,
156  const GaugeField &qPrev, const GaugeField &link,
157  real coeff, int overlap, HisqForceType type)
158  : BaseForceArg(link, overlap), outA(newOprod), outB(newOprod), pMu(P3), p3(P3), qMu(qPrev),
159  oProd(oProd), qProd(oProd), qPrev(qPrev), coeff(coeff), accumu_coeff(0), p_mu(false), q_mu(false), q_prev(true)
160  { if (type != FORCE_LEPAGE_MIDDLE_LINK) errorQuda("This constructor is for FORCE_MIDDLE_LINK"); }
161 
162  FatLinkArg(GaugeField &newOprod, GaugeField &shortP, const GaugeField &P3,
163  const GaugeField &qProd, const GaugeField &link, real coeff, real accumu_coeff, int overlap, HisqForceType type)
164  : BaseForceArg(link, overlap), outA(newOprod), outB(shortP), pMu(P3), p3(P3), qMu(qProd), oProd(qProd), qProd(qProd),
165  qPrev(qProd), coeff(coeff), accumu_coeff(accumu_coeff),
166  p_mu(false), q_mu(false), q_prev(false)
167  { if (type != FORCE_SIDE_LINK) errorQuda("This constructor is for FORCE_SIDE_LINK or FORCE_ALL_LINK"); }
168 
169  FatLinkArg(GaugeField &newOprod, GaugeField &P3, const GaugeField &link,
170  real coeff, int overlap, HisqForceType type)
171  : BaseForceArg(link, overlap), outA(newOprod), outB(newOprod),
172  pMu(P3), p3(P3), qMu(P3), oProd(P3), qProd(P3), qPrev(P3), coeff(coeff), accumu_coeff(0.0),
173  p_mu(false), q_mu(false), q_prev(false)
174  { if (type != FORCE_SIDE_LINK_SHORT) errorQuda("This constructor is for FORCE_SIDE_LINK_SHORT"); }
175 
176  FatLinkArg(GaugeField &newOprod, GaugeField &shortP, const GaugeField &oProd, const GaugeField &qPrev,
177  const GaugeField &link, real coeff, real accumu_coeff, int overlap, HisqForceType type, bool dummy)
178  : BaseForceArg(link, overlap), outA(newOprod), outB(shortP), oProd(oProd), qPrev(qPrev),
179  pMu(shortP), p3(shortP), qMu(qPrev), qProd(qPrev), // dummy
180  coeff(coeff), accumu_coeff(accumu_coeff), p_mu(false), q_mu(false), q_prev(false)
181  { if (type != FORCE_ALL_LINK) errorQuda("This constructor is for FORCE_ALL_LINK"); }
182 
183  };
184 
185  template <typename Arg>
186  __global__ void oneLinkTermKernel(Arg arg)
187  {
188  typedef Matrix<complex<typename Arg::real>, Arg::nColor> Link;
189  int x_cb = blockIdx.x * blockDim.x + threadIdx.x;
190  if (x_cb >= arg.threads) return;
191  int parity = blockIdx.y * blockDim.y + threadIdx.y;
192  int sig = blockIdx.z * blockDim.z + threadIdx.z;
193  if (sig >= 4) return;
194 
195  int x[4];
196  getCoords(x, x_cb, arg.X, parity);
197 #pragma unroll
198  for (int d=0; d<4; d++) x[d] += arg.border[d];
199  int e_cb = linkIndex(x,arg.E);
200 
201  Link w = arg.oProd(sig, e_cb, parity);
202  Link force = arg.outA(sig, e_cb, parity);
203  force += arg.coeff * w;
204  arg.outA(sig, e_cb, parity) = force;
205  }
206 
207 
  /********************************allLinkKernel*********************************************
   *
   * Fused middle-link + side-link kernel used for the 7-link term.
   *
   * In this function we need
   *   READ
   *     3 LINKS:         ad_link, ab_link, bc_link
   *     5 COLOR MATRIX:  Qprev_at_D, oprod_at_C, newOprod_at_A(sig), newOprod_at_D/newOprod_at_A(mu), shortP_at_D
   *   WRITE:
   *     3 COLOR MATRIX:  newOprod_at_A(sig), newOprod_at_D/newOprod_at_A(mu), shortP_at_D,
   *
   * If sig is negative, then we don't need to read/write the color matrix newOprod_at_A(sig)
   *
   * Therefore the data traffic, in two-number pair (num_of_link, num_of_color_matrix)
   *
   *             if (sig is positive):    (3, 8)
   *             else               :     (3, 6)
   *
   * This function is called 384 times, half positive sig, half negative sig
   *
   * Flop count, in two-number pair (matrix_multi, matrix_add)
   *             if(sig is positive)      (6,3)
   *             else                     (4,2)
   *
   * Thread layout: x indexes checkerboard sites of the working set arg.D
   * (bounded by arg.threads), y indexes the parity.
   *
   ************************************************************************************************/
  template <int sig_positive, int mu_positive, typename Arg>
  __global__ void allLinkKernel(Arg arg)
  {
    typedef Matrix<complex<typename Arg::real>, Arg::nColor> Link;

    int x_cb = blockIdx.x * blockDim.x + threadIdx.x;
    if (x_cb >= arg.threads) return;
    int parity = blockIdx.y * blockDim.y + threadIdx.y;

    // Site A: working-set coordinates offset by base_idx into the extended field.
    int x[4];
    getCoords(x, x_cb, arg.D, parity);
    for (int d=0; d<4; d++) x[d] += arg.base_idx[d];
    int e_cb = linkIndex(x,arg.E);
    // Re-express parity on the extended field (oddness_change = parity of sum(base_idx)).
    parity = parity^arg.oddness_change;

    auto mycoeff = CoeffSign(sig_positive,parity)*arg.coeff;

    // Site B: A shifted one step along arg.sig.
    int y[4] = {x[0], x[1], x[2], x[3]};
    int mysig = posDir(arg.sig);
    updateCoords(y, mysig, (sig_positive ? 1 : -1), arg);
    int point_b = linkIndex(y,arg.E);
    // The A-B link is stored at A for positive sig, at B for negative sig.
    int ab_link_nbr_idx = (sig_positive) ? e_cb : point_b;

    for (int d=0; d<4; d++) y[d] = x[d];

    /*            sig
     *         A________B
     *    mu   |        |
     *       D |        |C
     *
     *   A is the current point (sid)
     *
     */

    // mu here is always the positive (0-3) dimension; D is A shifted one step
    // opposite to arg.mu, and C is D shifted one step along arg.sig.
    int mu = mu_positive ? arg.mu : opp_dir(arg.mu);
    int dir = mu_positive ? -1 : 1;

    updateCoords(y, mu, dir, arg);
    int point_d = linkIndex(y,arg.E);
    updateCoords(y, mysig, (sig_positive ? 1 : -1), arg);
    int point_c = linkIndex(y,arg.E);

    // Link parities: A has `parity`, B and D have 1-parity, C has `parity`.
    Link Uab = arg.link(posDir(arg.sig), ab_link_nbr_idx, sig_positive^(1-parity));
    Link Uad = arg.link(mu, mu_positive ? point_d : e_cb, mu_positive ? 1-parity : parity);
    Link Ubc = arg.link(mu, mu_positive ? point_c : point_b, mu_positive ? parity : 1-parity);
    Link Ox = arg.qPrev(0, point_d, 1-parity);
    Link Oy = arg.oProd(0, point_c, parity);
    Link Oz = mu_positive ? conj(Ubc)*Oy : Ubc*Oy;

    // Middle-link contribution to the sig link at A (only for positive sig).
    if (sig_positive) {
      Link force = arg.outA(arg.sig, e_cb, parity);
      force += Sign(parity)*mycoeff*Oz*Ox* (mu_positive ? Uad : conj(Uad));
      arg.outA(arg.sig, e_cb, parity) = force;
      Oy = Uab*Oz;
    } else {
      Oy = conj(Uab)*Oz;
    }

    // Side-link contribution to the mu link, stored at D (positive mu) or A (negative mu).
    Link force = arg.outA(mu, mu_positive ? point_d : e_cb, mu_positive ? 1-parity : parity);
    force += Sign(mu_positive ? 1-parity : parity)*mycoeff* (mu_positive ? Oy*Ox : conj(Ox)*conj(Oy));
    arg.outA(mu, mu_positive ? point_d : e_cb, mu_positive ? 1-parity : parity) = force;

    // Accumulate accumu_coeff-weighted term into shortP (outB) at D.
    Link shortP = arg.outB(0, point_d, 1-parity);
    shortP += arg.accumu_coeff* (mu_positive ? Uad : conj(Uad)) *Oy;
    arg.outB(0, point_d, 1-parity) = shortP;
  }
297 
298 
  /**************************middleLinkKernel*****************************
   *
   *
   * Generally we need
   * READ
   *    3 LINKS:         ab_link,     bc_link,    ad_link
   *    3 COLOR MATRIX:  newOprod_at_A, oprod_at_C,  Qprod_at_D
   * WRITE
   *    4 COLOR MATRIX:  newOprod_at_A, P3_at_A, Pmu_at_B, Qmu_at_A
   *
   * Three call variations:
   *   1. when Qprev == NULL:   Qprod_at_D does not exist and is not read in
   *   2. full read/write
   *   3. when Pmu/Qmu == NULL,   Pmu_at_B and Qmu_at_A are not written out
   *
   * In this templated version "Qprev == NULL" corresponds to qPrev == false, and
   * "Pmu/Qmu == NULL" to pMu == qMu == false (the FORCE_LEPAGE_MIDDLE_LINK launch).
   *
   *   In all three above case, if the direction sig is negative, newOprod_at_A is
   *   not read in or written out.
   *
   * Therefore the data traffic, in two-number pair (num_of_link, num_of_color_matrix)
   *   Call 1:  (called 48 times, half positive sig, half negative sig)
   *             if (sig is positive):    (3, 6)
   *             else               :     (3, 4)
   *   Call 2:  (called 192 times, half positive sig, half negative sig)
   *             if (sig is positive):    (3, 7)
   *             else               :     (3, 5)
   *   Call 3:  (called 48 times, half positive sig, half negative sig)
   *             if (sig is positive):    (3, 5)
   *             else               :     (3, 2) no need to loadQprod_at_D in this case
   *
   * note: oprod_at_C could actually be read in from D when it is the fresh outer product
   *       and we call it oprod_at_C to simplify naming. This does not affect our data traffic analysis
   *
   * Flop count, in two-number pair (matrix_multi, matrix_add)
   *   call 1:     if (sig is positive)  (3, 1)
   *               else                  (2, 0)
   *   call 2:     if (sig is positive)  (4, 1)
   *               else                  (3, 0)
   *   call 3:     if (sig is positive)  (4, 1)
   *   (Lepage)    else                  (2, 0)
   *
   ****************************************************************************/
  template <int sig_positive, int mu_positive, bool pMu, bool qMu, bool qPrev, typename Arg>
  __global__ void middleLinkKernel(Arg arg)
  {
    typedef Matrix<complex<typename Arg::real>, Arg::nColor> Link;

    int x_cb = blockIdx.x * blockDim.x + threadIdx.x;
    if (x_cb >= arg.threads) return;
    int parity = blockIdx.y * blockDim.y + threadIdx.y;

    int x[4];
    getCoords(x, x_cb, arg.D, parity);

    /*        A________B
     *   mu   |        |
     *       D|        |C
     *
     *   A is the current point (sid)
     *
     */

    // Site A on the extended field; parity re-expressed relative to the extended field.
    for (int d=0; d<4; d++) x[d] += arg.base_idx[d];
    int e_cb = linkIndex(x,arg.E);
    parity = parity ^ arg.oddness_change;
    int y[4] = {x[0], x[1], x[2], x[3]};

    // D: A shifted one step opposite to arg.mu.
    int mymu = posDir(arg.mu);
    updateCoords(y, mymu, (mu_positive ? -1 : 1), arg);

    int point_d = linkIndex(y, arg.E);
    int ad_link_nbr_idx = mu_positive ? point_d : e_cb;

    // C: D shifted one step along arg.sig.
    int mysig = posDir(arg.sig);
    updateCoords(y, mysig, (sig_positive ? 1 : -1), arg);
    int point_c = linkIndex(y, arg.E);

    // B: A shifted one step along arg.sig.
    for (int d=0; d<4; d++) y[d] = x[d];
    updateCoords(y, mysig, (sig_positive ? 1 : -1), arg);
    int point_b = linkIndex(y, arg.E);

    int bc_link_nbr_idx = mu_positive ? point_c : point_b;
    int ab_link_nbr_idx = sig_positive ? e_cb : point_b;

    // load the link variable connecting a and b
    Link Uab = arg.link(mysig, ab_link_nbr_idx, sig_positive^(1-parity));

    // load the link variable connecting b and c
    Link Ubc = arg.link(mymu, bc_link_nbr_idx, mu_positive^(1-parity));

    Link Oy;
    if (!qPrev) {
      // fresh outer product: read at D (positive sig) or C (negative sig, conjugated)
      Oy = arg.oProd(posDir(arg.sig), sig_positive ? point_d : point_c, sig_positive^parity);
      if (!sig_positive) Oy = conj(Oy);
    } else { // qPrev == true (legacy wording: QprevOdd != NULL)
      Oy = arg.oProd(0, point_c, parity);
    }

    Link Ow = !mu_positive ? Ubc*Oy : conj(Ubc)*Oy;

    // Pmu is stored at B, P3 at A.
    if (pMu) arg.pMu(0, point_b, 1-parity) = Ow;

    arg.p3(0, e_cb, parity) = sig_positive ? Uab*Ow : conj(Uab)*Ow;

    Link Uad = arg.link(mymu, ad_link_nbr_idx, mu_positive^parity);
    if (!mu_positive)  Uad = conj(Uad);

    // Qmu at A is Uad alone (no qPrev) or qPrev_at_D * Uad; Oy becomes the sig-link term.
    if (!qPrev) {
      if (sig_positive) Oy = Ow*Uad;
      if ( qMu ) arg.qMu(0, e_cb, parity) = Uad;
    } else {
      Link Ox;
      if ( qMu || sig_positive ) {
        Oy = arg.qPrev(0, point_d, 1-parity);
        Ox = Oy*Uad;
      }
      if ( qMu ) arg.qMu(0, e_cb, parity) = Ox;
      if (sig_positive) Oy = Ow*Ox;
    }

    // Only positive sig contributes directly to the force at A.
    if (sig_positive) {
      Link oprod = arg.outA(arg.sig, e_cb, parity);
      oprod += arg.coeff*Oy;
      arg.outA(arg.sig, e_cb, parity) = oprod;
    }

  }
425 
  /***********************************sideLinkKernel***************************
   *
   * In general we need
   * READ
   *    1  LINK:          ad_link
   *    4  COLOR MATRIX:  shortP_at_D, newOprod, P3_at_A, Qprod_at_D,
   * WRITE
   *    2  COLOR MATRIX:  shortP_at_D, newOprod,
   *
   * Two call variations:
   *   1. full read/write
   *   2. when shortP == NULL && Qprod == NULL:
   *          no need to read ad_link/shortP_at_D or write shortP_at_D
   *          Qprod_at_D does not exist and is not read in
   *
   * In this file variation 2 is implemented separately by sideLinkShortKernel
   * (FORCE_SIDE_LINK_SHORT); this kernel always performs the full read/write.
   *
   *
   * Therefore the data traffic, in two-number pair (num_of_links, num_of_color_matrix)
   *   Call 1:   (called 192 times)
   *                           (1, 6)
   *
   *   Call 2:   (called 48 times)
   *                           (0, 3)
   *
   * note: newOprod can be at point D or A, depending on if mu is positive or negative
   *
   * Flop count, in two-number pair (matrix_multi, matrix_add)
   *   call 1:       (2, 2)
   *   call 2:       (0, 1)
   *
   *********************************************************************************/
  template <int mu_positive, typename Arg>
  __global__ void sideLinkKernel(Arg arg)
  {
    typedef Matrix<complex<typename Arg::real>, Arg::nColor> Link;
    int x_cb = blockIdx.x * blockDim.x + threadIdx.x;
    if (x_cb >= arg.threads) return;
    int parity = blockIdx.y * blockDim.y + threadIdx.y;

    // Site A on the extended field; parity re-expressed relative to the extended field.
    int x[4];
    getCoords(x, x_cb ,arg.D, parity);
    for (int d=0; d<4; d++) x[d] = x[d] + arg.base_idx[d];
    int e_cb = linkIndex(x,arg.E);
    parity = parity ^ arg.oddness_change;

    /*      compute the side link contribution to the momentum
     *
     *             sig
     *          A________B
     *           |       |   mu
     *         D |       |C
     *
     *      A is the current point (x_cb)
     *
     */

    // D: A shifted one step opposite to arg.mu.
    int mymu = posDir(arg.mu);
    int y[4] = {x[0], x[1], x[2], x[3]};
    updateCoords(y, mymu, (mu_positive ? -1 : 1), arg);
    int point_d = linkIndex(y,arg.E);

    Link Oy = arg.p3(0, e_cb, parity);

    {
      // shortP_at_D += accumu_coeff * (U_ad or its conjugate) * P3_at_A
      int ad_link_nbr_idx = mu_positive ? point_d : e_cb;

      Link Uad = arg.link(mymu, ad_link_nbr_idx, mu_positive^parity);
      Link Ow = mu_positive ? Uad*Oy : conj(Uad)*Oy;

      Link shortP = arg.outB(0, point_d, 1-parity);
      shortP += arg.accumu_coeff * Ow;
      arg.outB(0, point_d, 1-parity) = shortP;
    }

    {
      // force on the mu link (stored at D for positive mu, at A otherwise)
      Link Ox = arg.qProd(0, point_d, 1-parity);
      Link Ow = mu_positive ? Oy*Ox : conj(Ox)*conj(Oy);

      auto mycoeff = CoeffSign(goes_forward(arg.sig), parity)*CoeffSign(goes_forward(arg.mu),parity)*arg.coeff;

      Link oprod = arg.outA(mu_positive ? arg.mu : opp_dir(arg.mu), mu_positive ? point_d : e_cb, mu_positive ? 1-parity : parity);
      oprod += mycoeff * Ow;
      arg.outA(mu_positive ? arg.mu : opp_dir(arg.mu), mu_positive ? point_d : e_cb, mu_positive ? 1-parity : parity) = oprod;
    }
  }
510 
511  // Flop count, in two-number pair (matrix_mult, matrix_add)
512  // (0,1)
513  template <int mu_positive, typename Arg>
514  __global__ void sideLinkShortKernel(Arg arg)
515  {
516  typedef Matrix<complex<typename Arg::real>, Arg::nColor> Link;
517  int x_cb = blockIdx.x * blockDim.x + threadIdx.x;
518  if (x_cb >= arg.threads) return;
519  int parity = blockIdx.y * blockDim.y + threadIdx.y;
520 
521  int x[4];
522  getCoords(x, x_cb, arg.D, parity);
523  for (int d=0; d<4; d++) x[d] = x[d] + arg.base_idx[d];
524  int e_cb = linkIndex(x,arg.E);
525  parity = parity ^ arg.oddness_change;
526 
527  /* compute the side link contribution to the momentum
528  *
529  * sig
530  * A________B
531  * | | mu
532  * D | |C
533  *
534  * A is the current point (x_cb)
535  *
536  */
537  int mymu = posDir(arg.mu);
538  int y[4] = {x[0], x[1], x[2], x[3]};
539  updateCoords(y, mymu, (mu_positive ? -1 : 1), arg);
540  int point_d = mu_positive ? linkIndex(y,arg.E) : e_cb;
541 
542  int parity_ = mu_positive ? 1-parity : parity;
543  auto mycoeff = CoeffSign(goes_forward(arg.sig),parity)*CoeffSign(goes_forward(arg.mu),parity)*arg.coeff;
544 
545  Link Oy = arg.p3(0, e_cb, parity);
546  Link oprod = arg.outA(posDir(arg.mu), point_d, parity_);
547  oprod += mu_positive ? mycoeff * Oy : mycoeff * conj(Oy);
548  arg.outA(posDir(arg.mu), point_d, parity_) = oprod;
549  }
550 
551  template <typename Arg>
552  class FatLinkForce : public TunableVectorYZ {
553 
554  Arg &arg;
555  const GaugeField &meta;
556  const HisqForceType type;
557 
558  unsigned int minThreads() const { return arg.threads; }
559  bool tuneGridDim() const { return false; }
560 
561  public:
562  FatLinkForce(Arg &arg, const GaugeField &meta, int sig, int mu, HisqForceType type)
563  : TunableVectorYZ(2,type == FORCE_ONE_LINK ? 4 : 1), arg(arg), meta(meta), type(type) {
564  arg.sig = sig;
565  arg.mu = mu;
566  }
567 
568  TuneKey tuneKey() const {
569  std::stringstream aux;
570  aux << meta.AuxString() << comm_dim_partitioned_string() << ",threads=" << arg.threads;
571  if (type == FORCE_MIDDLE_LINK || type == FORCE_LEPAGE_MIDDLE_LINK)
572  aux << ",sig=" << arg.sig << ",mu=" << arg.mu << ",pMu=" << arg.p_mu << ",q_muu=" << arg.q_mu << ",q_prev=" << arg.q_prev;
573  else if (type != FORCE_ONE_LINK)
574  aux << ",mu=" << arg.mu; // no sig dependence needed for side link
575 
576  switch (type) {
577  case FORCE_ONE_LINK: aux << ",ONE_LINK"; break;
578  case FORCE_ALL_LINK: aux << ",ALL_LINK"; break;
579  case FORCE_MIDDLE_LINK: aux << ",MIDDLE_LINK"; break;
580  case FORCE_LEPAGE_MIDDLE_LINK: aux << ",LEPAGE_MIDDLE_LINK"; break;
581  case FORCE_SIDE_LINK: aux << ",SIDE_LINK"; break;
582  case FORCE_SIDE_LINK_SHORT: aux << ",SIDE_LINK_SHORT"; break;
583  default: errorQuda("Undefined force type %d", type);
584  }
585  return TuneKey(meta.VolString(), typeid(*this).name(), aux.str().c_str());
586  }
587 
588  void apply(const qudaStream_t &stream)
589  {
590  TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
591  switch (type) {
592  case FORCE_ONE_LINK:
593  qudaLaunchKernel(oneLinkTermKernel<Arg>, tp, stream, arg);
594  break;
595  case FORCE_ALL_LINK:
596  if (goes_forward(arg.sig) && goes_forward(arg.mu))
597  qudaLaunchKernel(allLinkKernel<1,1,Arg>, tp, stream, arg);
598  else if (goes_forward(arg.sig) && goes_backward(arg.mu))
599  qudaLaunchKernel(allLinkKernel<1,0,Arg>, tp, stream, arg);
600  else if (goes_backward(arg.sig) && goes_forward(arg.mu))
601  qudaLaunchKernel(allLinkKernel<0,1,Arg>, tp, stream, arg);
602  else
603  qudaLaunchKernel(allLinkKernel<0,0,Arg>, tp, stream, arg);
604  break;
605  case FORCE_MIDDLE_LINK:
606  if (!arg.p_mu || !arg.q_mu) errorQuda("Expect p_mu=%d and q_mu=%d to both be true", arg.p_mu, arg.q_mu);
607  if (arg.q_prev) {
608  if (goes_forward(arg.sig) && goes_forward(arg.mu))
609  qudaLaunchKernel(middleLinkKernel<1,1,true,true,true,Arg>, tp, stream, arg);
610  else if (goes_forward(arg.sig) && goes_backward(arg.mu))
611  qudaLaunchKernel(middleLinkKernel<1,0,true,true,true,Arg>, tp, stream, arg);
612  else if (goes_backward(arg.sig) && goes_forward(arg.mu))
613  qudaLaunchKernel(middleLinkKernel<0,1,true,true,true,Arg>, tp, stream, arg);
614  else
615  qudaLaunchKernel(middleLinkKernel<0,0,true,true,true,Arg>, tp, stream, arg);
616  } else {
617  if (goes_forward(arg.sig) && goes_forward(arg.mu))
618  qudaLaunchKernel(middleLinkKernel<1,1,true,true,false,Arg>, tp, stream, arg);
619  else if (goes_forward(arg.sig) && goes_backward(arg.mu))
620  qudaLaunchKernel(middleLinkKernel<1,0,true,true,false,Arg>, tp, stream, arg);
621  else if (goes_backward(arg.sig) && goes_forward(arg.mu))
622  qudaLaunchKernel(middleLinkKernel<0,1,true,true,false,Arg>, tp, stream, arg);
623  else
624  qudaLaunchKernel(middleLinkKernel<0,0,true,true,false,Arg>, tp, stream, arg);
625  }
626  break;
627  case FORCE_LEPAGE_MIDDLE_LINK:
628  if (arg.p_mu || arg.q_mu || !arg.q_prev)
629  errorQuda("Expect p_mu=%d and q_mu=%d to both be false and q_prev=%d true", arg.p_mu, arg.q_mu, arg.q_prev);
630  if (goes_forward(arg.sig) && goes_forward(arg.mu))
631  qudaLaunchKernel(middleLinkKernel<1,1,false,false,true,Arg>, tp, stream, arg);
632  else if (goes_forward(arg.sig) && goes_backward(arg.mu))
633  qudaLaunchKernel(middleLinkKernel<1,0,false,false,true,Arg>, tp, stream, arg);
634  else if (goes_backward(arg.sig) && goes_forward(arg.mu))
635  qudaLaunchKernel(middleLinkKernel<0,1,false,false,true,Arg>, tp, stream, arg);
636  else
637  qudaLaunchKernel(middleLinkKernel<0,0,false,false,true,Arg>, tp, stream, arg);
638  break;
639  case FORCE_SIDE_LINK:
640  if (goes_forward(arg.mu)) qudaLaunchKernel(sideLinkKernel<1,Arg>, tp, stream, arg);
641  else qudaLaunchKernel(sideLinkKernel<0,Arg>, tp, stream, arg);
642  break;
643  case FORCE_SIDE_LINK_SHORT:
644  if (goes_forward(arg.mu)) qudaLaunchKernel(sideLinkShortKernel<1,Arg>, tp, stream, arg);
645  else qudaLaunchKernel(sideLinkShortKernel<0,Arg>, tp, stream, arg);
646  break;
647  default:
648  errorQuda("Undefined force type %d", type);
649  }
650  }
651 
652  void preTune() {
653  switch (type) {
654  case FORCE_ONE_LINK:
655  arg.outA.save();
656  break;
657  case FORCE_ALL_LINK:
658  arg.outA.save();
659  arg.outB.save();
660  break;
661  case FORCE_MIDDLE_LINK:
662  arg.pMu.save();
663  arg.qMu.save();
664  case FORCE_LEPAGE_MIDDLE_LINK:
665  arg.outA.save();
666  arg.p3.save();
667  break;
668  case FORCE_SIDE_LINK:
669  arg.outB.save();
670  case FORCE_SIDE_LINK_SHORT:
671  arg.outA.save();
672  break;
673  default: errorQuda("Undefined force type %d", type);
674  }
675  }
676 
677  void postTune() {
678  switch (type) {
679  case FORCE_ONE_LINK:
680  arg.outA.load();
681  break;
682  case FORCE_ALL_LINK:
683  arg.outA.load();
684  arg.outB.load();
685  break;
686  case FORCE_MIDDLE_LINK:
687  arg.pMu.load();
688  arg.qMu.load();
689  case FORCE_LEPAGE_MIDDLE_LINK:
690  arg.outA.load();
691  arg.p3.load();
692  break;
693  case FORCE_SIDE_LINK:
694  arg.outB.load();
695  case FORCE_SIDE_LINK_SHORT:
696  arg.outA.load();
697  break;
698  default: errorQuda("Undefined force type %d", type);
699  }
700  }
701 
702  long long flops() const {
703  switch (type) {
704  case FORCE_ONE_LINK:
705  return 2*4*arg.threads*36ll;
706  case FORCE_ALL_LINK:
707  return 2*arg.threads*(goes_forward(arg.sig) ? 1242ll : 828ll);
708  case FORCE_MIDDLE_LINK:
709  case FORCE_LEPAGE_MIDDLE_LINK:
710  return 2*arg.threads*(2 * 198 +
711  (!arg.q_prev && goes_forward(arg.sig) ? 198 : 0) +
712  (arg.q_prev && (arg.q_mu || goes_forward(arg.sig) ) ? 198 : 0) +
713  ((arg.q_prev && goes_forward(arg.sig) ) ? 198 : 0) +
714  ( goes_forward(arg.sig) ? 216 : 0) );
715  case FORCE_SIDE_LINK: return 2*arg.threads*2*234;
716  case FORCE_SIDE_LINK_SHORT: return 2*arg.threads*36;
717  default: errorQuda("Undefined force type %d", type);
718  }
719  return 0;
720  }
721 
722  long long bytes() const {
723  switch (type) {
724  case FORCE_ONE_LINK:
725  return 2*4*arg.threads*( arg.oProd.Bytes() + 2*arg.outA.Bytes() );
726  case FORCE_ALL_LINK:
727  return 2*arg.threads*( (goes_forward(arg.sig) ? 4 : 2)*arg.outA.Bytes() + 3*arg.link.Bytes()
728  + arg.oProd.Bytes() + arg.qPrev.Bytes() + 2*arg.outB.Bytes());
729  case FORCE_MIDDLE_LINK:
730  case FORCE_LEPAGE_MIDDLE_LINK:
731  return 2*arg.threads*( ( goes_forward(arg.sig) ? 2*arg.outA.Bytes() : 0 ) +
732  (arg.p_mu ? arg.pMu.Bytes() : 0) +
733  (arg.q_mu ? arg.qMu.Bytes() : 0) +
734  ( ( goes_forward(arg.sig) || arg.q_mu ) ? arg.qPrev.Bytes() : 0) +
735  arg.p3.Bytes() + 3*arg.link.Bytes() + arg.oProd.Bytes() );
736  case FORCE_SIDE_LINK:
737  return 2*arg.threads*( 2*arg.outA.Bytes() + 2*arg.outB.Bytes() +
738  arg.p3.Bytes() + arg.link.Bytes() + arg.qProd.Bytes() );
739  case FORCE_SIDE_LINK_SHORT:
740  return 2*arg.threads*( 2*arg.outA.Bytes() + arg.p3.Bytes() );
741  default: errorQuda("Undefined force type %d", type);
742  }
743  return 0;
744  }
745  };
746 
  /**
     Drives the full HISQ fat-link staple force: one-link term, then for every
     (sig, mu) pair the 3-link, 5-link, 7-link and (optionally) Lepage terms.

     The launch order matters. Temporaries written by one pass are read by later
     passes in the same iteration:
       - 3-link middle writes Pmu, P3, Qmu;
       - 5-link middle reads Pmu/Qmu and writes Pnumu, P5, Qnumu;
       - 7-link all-link reads Pnumu/Qnumu and accumulates into P5;
       - 5-link side reads P5/Qmu and accumulates into P3;
       - Lepage middle overwrites P5 from Pmu/Qmu, Lepage side accumulates into P3;
       - 3-link short side finally reads P3.

     All FatLinkArg objects have the same type, so `decltype(arg)` always names
     FatLinkArg<real, nColor>; the inner `arg` / `middleLinkArg` declarations
     intentionally shadow the outer ones.
  */
  template <typename real, int nColor, QudaReconstructType recon>
  struct HisqStaplesForce {
    HisqStaplesForce(GaugeField &Pmu, GaugeField &P3, GaugeField &P5, GaugeField &Pnumu,
                     GaugeField &Qmu, GaugeField &Qnumu, GaugeField &newOprod,
                     const GaugeField &oprod, const GaugeField &link,
                     const double *path_coeff_array)
    {
      PathCoefficients<real> act_path_coeff(path_coeff_array);
      real OneLink = act_path_coeff.one;
      real ThreeSt = act_path_coeff.three;
      real mThreeSt = -ThreeSt;
      real FiveSt = act_path_coeff.five;
      real mFiveSt = -FiveSt;
      real SevenSt = act_path_coeff.seven;
      real Lepage = act_path_coeff.lepage;
      real mLepage = -Lepage;

      // one-link term: newOprod += OneLink * oprod
      FatLinkArg<real, nColor> arg(newOprod, oprod, link, OneLink, FORCE_ONE_LINK);
      FatLinkForce<decltype(arg)> oneLink(arg, link, 0, 0, FORCE_ONE_LINK);
      oneLink.apply(0);

      for (int sig=0; sig<8; sig++) {
        for (int mu=0; mu<8; mu++) {
          if ( (mu == sig) || (mu == opp_dir(sig))) continue;

          //3-link
          //Kernel A: middle link (no previous Q field; produces Pmu, P3, Qmu)
          FatLinkArg<real, nColor> middleLinkArg( newOprod, Pmu, P3, Qmu, oprod, link, mThreeSt, 2, FORCE_MIDDLE_LINK);
          FatLinkForce<decltype(arg)> middleLink(middleLinkArg, link, sig, mu, FORCE_MIDDLE_LINK);
          middleLink.apply(0);

          for (int nu=0; nu < 8; nu++) {
            if (nu == sig || nu == opp_dir(sig) || nu == mu || nu == opp_dir(mu)) continue;

            //5-link: middle link
            //Kernel B (consumes Pmu/Qmu; produces Pnumu, P5, Qnumu)
            FatLinkArg<real, nColor> middleLinkArg( newOprod, Pnumu, P5, Qnumu, Pmu, Qmu, link, FiveSt, 1, FORCE_MIDDLE_LINK);
            FatLinkForce<decltype(arg)> middleLink(middleLinkArg, link, sig, nu, FORCE_MIDDLE_LINK);
            middleLink.apply(0);

            for (int rho = 0; rho < 8; rho++) {
              if (rho == sig || rho == opp_dir(sig) || rho == mu || rho == opp_dir(mu) || rho == nu || rho == opp_dir(nu)) continue;

              //7-link: middle link and side link (fused; accumulates into P5 as shortP)
              FatLinkArg<real, nColor> arg(newOprod, P5, Pnumu, Qnumu, link, SevenSt, FiveSt != 0 ? SevenSt/FiveSt : 0, 1, FORCE_ALL_LINK, true);
              FatLinkForce<decltype(arg)> all(arg, link, sig, rho, FORCE_ALL_LINK);
              all.apply(0);

            }//rho

            //5-link: side link (reads P5/Qmu; accumulates into P3 as shortP)
            FatLinkArg<real, nColor> arg(newOprod, P3, P5, Qmu, link, mFiveSt, (ThreeSt != 0 ? FiveSt/ThreeSt : 0), 1, FORCE_SIDE_LINK);
            FatLinkForce<decltype(arg)> side(arg, link, sig, nu, FORCE_SIDE_LINK);
            side.apply(0);

          } //nu

          //lepage
          if (Lepage != 0.) {
            // middle link writes P5 from Pmu/Qmu (no Pmu/Qmu output)
            FatLinkArg<real, nColor> middleLinkArg( newOprod, P5, Pmu, Qmu, link, Lepage, 2, FORCE_LEPAGE_MIDDLE_LINK);
            FatLinkForce<decltype(arg)> middleLink(middleLinkArg, link, sig, mu, FORCE_LEPAGE_MIDDLE_LINK);
            middleLink.apply(0);

            FatLinkArg<real, nColor> arg(newOprod, P3, P5, Qmu, link, mLepage, (ThreeSt != 0 ? Lepage/ThreeSt : 0), 2, FORCE_SIDE_LINK);
            FatLinkForce<decltype(arg)> side(arg, link, sig, mu, FORCE_SIDE_LINK);
            side.apply(0);
          } // Lepage != 0.0

          // 3-link side link
          // NOTE(review): meta is P3 here (link everywhere else); FatLinkForce only uses
          // meta to build the tune key, so this affects tuning keys, not results.
          FatLinkArg<real, nColor> arg(newOprod, P3, link, ThreeSt, 1, FORCE_SIDE_LINK_SHORT);
          FatLinkForce<decltype(arg)> side(arg, P3, sig, mu, FORCE_SIDE_LINK_SHORT);
          side.apply(0);
        }//mu
      }//sig
    }
  };
823 
/**
 * Host entry point for the HISQ staple (fat-link path) force.
 *
 * Validates the three input fields, allocates the six colour-matrix
 * temporaries (Pmu, P3, P5, Pnumu, Qmu, Qnumu) consumed by the
 * HisqStaplesForce functor, dispatches on precision via instantiate, and
 * blocks until the device has finished.
 *
 * @param newOprod          Output field handed to HisqStaplesForce
 * @param oprod             Input outer-product field
 * @param link              Gauge links; also the template for the temporaries
 * @param path_coeff_array  Six path coefficients; presumably in the order that
 *                          PathCoefficients expects (one, three, five, seven,
 *                          naik, lepage) -- TODO confirm against its constructor
 */
void hisqStaplesForce(GaugeField &newOprod, const GaugeField &oprod, const GaugeField &link, const double path_coeff_array[6])
{
  if (!link.isNative()) errorQuda("Unsupported gauge order %d", link.Order());
  if (!oprod.isNative()) errorQuda("Unsupported gauge order %d", oprod.Order());
  if (!newOprod.isNative()) errorQuda("Unsupported gauge order %d", newOprod.Order());
  if (checkLocation(newOprod,oprod,link) == QUDA_CPU_FIELD_LOCATION) errorQuda("CPU not implemented");

  // Check precision consistency before allocating any temporaries, so a
  // mismatch fails without first reserving six device fields. The returned
  // precision is not needed here, so it is not stored.
  checkPrecision(oprod, link, newOprod);

  // Colour-matrix temporaries: same dimensions/precision as `link`, but full
  // (unreconstructed) matrices in FLOAT2 order with scalar geometry.
  GaugeFieldParam gauge_param(link);
  gauge_param.reconstruct = QUDA_RECONSTRUCT_NO;
  gauge_param.order = QUDA_FLOAT2_GAUGE_ORDER;
  gauge_param.geometry = QUDA_SCALAR_GEOMETRY;

  cudaGaugeField Pmu(gauge_param);
  cudaGaugeField P3(gauge_param);
  cudaGaugeField P5(gauge_param);
  cudaGaugeField Pnumu(gauge_param);
  cudaGaugeField Qmu(gauge_param);
  cudaGaugeField Qnumu(gauge_param);

  instantiate<HisqStaplesForce, ReconstructNone>(Pmu, P3, P5, Pnumu, Qmu, Qnumu, newOprod, oprod, link, path_coeff_array);

  qudaDeviceSynchronize();
}
849 
/**
 * Kernel arguments for completeForceKernel.
 *
 * Both accessors are built on the same `force` field, so the kernel reads its
 * input from, and writes its result back to, that one field.
 */
template <typename real, int nColor, QudaReconstructType reconstruct=QUDA_RECONSTRUCT_NO>
struct CompleteForceArg : public BaseForceArg<real, nColor, reconstruct> {

  using F = typename gauge_mapper<real, QUDA_RECONSTRUCT_NO>::type;

  F outA;            // destination accessor (the `force` field)
  const F oProd;     // source accessor (also the `force` field)
  const real coeff;  // fixed at zero; completeForceKernel does not read it

  CompleteForceArg(GaugeField &force, const GaugeField &link) :
    BaseForceArg<real, nColor, reconstruct>(link, 0),
    outA(force),
    oProd(force),
    coeff(0.0)
  {
  }

};
863 
// Flops count: 4 matrix multiplications per lattice site = 792 Flops per site
/**
 * Final HISQ force step, applied per site and per direction.
 *
 * Launch layout: x indexes the arg.threads checkerboard sites; y indexes the
 * parity. For each of the 4 directions the kernel forms link * oProd, applies
 * makeAntiHerm, flips the sign on parity 1, and stores the result in outA.
 */
template <typename Arg>
__global__ void completeForceKernel(Arg arg)
{
  using real = typename Arg::real;
  using Link = Matrix<complex<real>, Arg::nColor>;

  const int x_cb = blockIdx.x * blockDim.x + threadIdx.x;
  if (x_cb >= arg.threads) return;
  const int parity = blockIdx.y * blockDim.y + threadIdx.y;

  // full coordinates, then shifted into the extended (bordered) lattice
  int coord[4];
  getCoords(coord, x_cb, arg.X, parity);
  for (int d = 0; d < 4; d++) coord[d] += arg.border[d];
  const int e_cb = linkIndex(coord, arg.E);

  // the sign depends only on parity, so compute it once
  real sign = (parity == 1) ? -1.0 : 1.0;

#pragma unroll
  for (int dir = 0; dir < 4; ++dir) {
    Link U = arg.link(dir, e_cb, parity);
    Link O = arg.oProd(dir, e_cb, parity);
    Link UO = U * O;
    makeAntiHerm(UO);
    arg.outA(dir, e_cb, parity) = sign * UO;
  }
}
891 
/**
 * Kernel arguments for longLinkKernel.
 *
 * outA is the field the kernel accumulates into; oProd is the read-only input
 * outer product; coeff scales the contribution that is added to outA.
 */
template <typename real, int nColor, QudaReconstructType reconstruct=QUDA_RECONSTRUCT_NO>
struct LongLinkArg : public BaseForceArg<real, nColor, reconstruct> {

  using M = gauge::FloatNOrder<real, 18, 2, 11>;  // NOTE(review): not referenced in this part of the file
  using F = typename gauge_mapper<real, QUDA_RECONSTRUCT_NO>::type;

  F outA;            // accumulation target (newOprod)
  const F oProd;     // input outer product (oprod)
  const real coeff;  // multiplier applied to each site's contribution

  LongLinkArg(GaugeField &newOprod, const GaugeField &link, const GaugeField &oprod, real coeff) :
    BaseForceArg<real, nColor, reconstruct>(link, 0),
    outA(newOprod),
    oProd(oprod),
    coeff(coeff)
  {
  }

};
906 
// Flops count, in two-number pair (matrix_mult, matrix_add)
// (24, 12)
// 4968 Flops per site in total
/**
 * Accumulate the long-link contribution into arg.outA.
 *
 * For each forward direction sig the kernel uses the five sites
 *
 *     A     B     C     D     E
 *     ----  ----  ----  ----
 *     ---> sig direction
 *
 * where C is the site owned by this thread. A and E are two hops from C
 * (same parity); B and D are one hop away (opposite parity). It adds
 *     coeff * (U_D U_E O_C - U_D O_B U_B + O_A U_A U_B)
 * to outA(sig, C).
 *
 * Launch layout: x indexes the arg.threads checkerboard sites; y indexes the
 * parity.
 */
template <typename Arg>
__global__ void longLinkKernel(Arg arg)
{
  using Link = Matrix<complex<typename Arg::real>, Arg::nColor>;

  const int x_cb = blockIdx.x * blockDim.x + threadIdx.x;
  if (x_cb >= arg.threads) return;
  const int parity = blockIdx.y * blockDim.y + threadIdx.y;
  const int other = 1 - parity;

  // full coordinates, then shifted into the extended (bordered) lattice
  int coord[4];
  getCoords(coord, x_cb, arg.X, parity);
  for (int i = 0; i < 4; i++) coord[i] += arg.border[i];
  const int site_c = linkIndex(coord, arg.E);

  int shift[4] = {0, 0, 0, 0};

#pragma unroll
  for (int sig = 0; sig < 4; sig++) {
    shift[sig] = 1;
    const int site_d = linkIndexShift(coord, shift, arg.E);
    shift[sig] = 2;
    const int site_e = linkIndexShift(coord, shift, arg.E);
    shift[sig] = -1;
    const int site_b = linkIndexShift(coord, shift, arg.E);
    shift[sig] = -2;
    const int site_a = linkIndexShift(coord, shift, arg.E);
    shift[sig] = 0;  // restore for the next direction

    Link U_a = arg.link(sig, site_a, parity);
    Link U_b = arg.link(sig, site_b, other);
    Link U_d = arg.link(sig, site_d, other);
    Link U_e = arg.link(sig, site_e, parity);

    Link O_c = arg.oProd(sig, site_c, parity);
    Link O_b = arg.oProd(sig, site_b, other);
    Link O_a = arg.oProd(sig, site_a, parity);

    Link contribution = U_d*U_e*O_c - U_d*O_b*U_b + O_a*U_a*U_b;

    Link accumulated = arg.outA(sig, site_c, parity);
    arg.outA(sig, site_c, parity) = accumulated + arg.coeff*contribution;
  }
}
971 
/**
 * Tunable launcher for longLinkKernel (FORCE_LONG_LINK) and
 * completeForceKernel (FORCE_COMPLETE).
 *
 * Launch layout: x covers the arg.threads checkerboard sites; the y dimension
 * (TunableVectorY(2)) covers the two parities. Any other force type is
 * reported through errorQuda.
 */
template <typename Arg>
class HisqForce : public TunableVectorY {

  Arg &arg;                  // kernel arguments; sig/mu are set by the constructor
  const GaugeField &meta;    // provides the volume/aux strings for the tune key
  const HisqForceType type;  // selects the kernel that apply() launches

  unsigned int minThreads() const { return arg.threads; }
  bool tuneGridDim() const { return false; }

public:
  HisqForce(Arg &arg, const GaugeField &meta, int sig, int mu, HisqForceType type) :
    TunableVectorY(2),
    arg(arg),
    meta(meta),
    type(type)
  {
    arg.sig = sig;
    arg.mu = mu;
  }

  void apply(const qudaStream_t &stream)
  {
    TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
    if (type == FORCE_LONG_LINK) {
      qudaLaunchKernel(longLinkKernel<Arg>, tp, stream, arg);
    } else if (type == FORCE_COMPLETE) {
      qudaLaunchKernel(completeForceKernel<Arg>, tp, stream, arg);
    } else {
      errorQuda("Undefined force type %d", type);
    }
  }

  TuneKey tuneKey() const
  {
    std::stringstream aux;
    aux << meta.AuxString() << comm_dim_partitioned_string() << ",threads=" << arg.threads;
    if (type == FORCE_LONG_LINK) {
      aux << ",LONG_LINK";
    } else if (type == FORCE_COMPLETE) {
      aux << ",COMPLETE";
    } else {
      errorQuda("Undefined force type %d", type);
    }
    return TuneKey(meta.VolString(), typeid(*this).name(), aux.str().c_str());
  }

  // Both kernels modify outA in place: the long-link kernel adds into it, and
  // the complete-force kernel reads and writes the same field. The field is
  // therefore snapshotted before tuning and restored afterwards.
  void preTune()
  {
    if (type == FORCE_LONG_LINK || type == FORCE_COMPLETE) {
      arg.outA.save();
    } else {
      errorQuda("Undefined force type %d", type);
    }
  }

  void postTune()
  {
    if (type == FORCE_LONG_LINK || type == FORCE_COMPLETE) {
      arg.outA.load();
    } else {
      errorQuda("Undefined force type %d", type);
    }
  }

  // per-site flop counts from the kernel comments, times 2 parities
  long long flops() const
  {
    if (type == FORCE_LONG_LINK) return 2*arg.threads*4968ll;
    if (type == FORCE_COMPLETE) return 2*arg.threads*792ll;
    errorQuda("Undefined force type %d", type);
    return 0;
  }

  // 4 directions x 2 parities x arg.threads sites x per-site matrix traffic
  long long bytes() const
  {
    if (type == FORCE_LONG_LINK)
      return 4*2*arg.threads*(2*arg.outA.Bytes() + 4*arg.link.Bytes() + 3*arg.oProd.Bytes());
    if (type == FORCE_COMPLETE)
      return 4*2*arg.threads*(arg.outA.Bytes() + arg.link.Bytes() + arg.oProd.Bytes());
    errorQuda("Undefined force type %d", type);
    return 0;
  }
};
1046 
/**
 * Instantiation functor for the long-link force. It runs longLinkKernel once,
 * adding coeff times the contribution built from oldOprod and link into
 * newOprod, then waits for the device.
 */
template <typename real, int nColor, QudaReconstructType recon>
struct HisqLongLinkForce {
  HisqLongLinkForce(GaugeField &newOprod, const GaugeField &oldOprod, const GaugeField &link, double coeff)
  {
    using ArgType = LongLinkArg<real, nColor, recon>;
    ArgType longLinkArg(newOprod, link, oldOprod, coeff);
    HisqForce<ArgType> launcher(longLinkArg, link, 0, 0, FORCE_LONG_LINK);
    launcher.apply(0);
    qudaDeviceSynchronize();
  }
};
1057 
/**
 * Host entry point for the long-link force.
 *
 * It requires native gauge orderings, device-resident fields, and matching
 * precisions, then dispatches HisqLongLinkForce over the supported precisions.
 */
void hisqLongLinkForce(GaugeField &newOprod, const GaugeField &oldOprod, const GaugeField &link, double coeff)
{
  // ordering checks, in the order link -> oldOprod -> newOprod
  if (!link.isNative()) errorQuda("Unsupported gauge order %d", link.Order());
  if (!oldOprod.isNative()) errorQuda("Unsupported gauge order %d", oldOprod.Order());
  if (!newOprod.isNative()) errorQuda("Unsupported gauge order %d", newOprod.Order());

  const bool on_host = (checkLocation(newOprod, oldOprod, link) == QUDA_CPU_FIELD_LOCATION);
  if (on_host) errorQuda("CPU not implemented");

  checkPrecision(newOprod, link, oldOprod);
  instantiate<HisqLongLinkForce, ReconstructNone>(newOprod, oldOprod, link, coeff);
}
1067 
/**
 * Instantiation functor for the complete-force step. It runs
 * completeForceKernel in place on `force`, using `link`, then waits for the
 * device.
 */
template <typename real, int nColor, QudaReconstructType recon>
struct HisqCompleteForce {
  HisqCompleteForce(GaugeField &force, const GaugeField &link)
  {
    using ArgType = CompleteForceArg<real, nColor, recon>;
    ArgType completeArg(force, link);
    HisqForce<ArgType> launcher(completeArg, link, 0, 0, FORCE_COMPLETE);
    launcher.apply(0);
    qudaDeviceSynchronize();
  }
};
1078 
/**
 * Host entry point that applies the complete-force step to `force` in place.
 *
 * It requires native gauge orderings, device-resident fields, and matching
 * precisions, then dispatches HisqCompleteForce over the supported precisions.
 */
void hisqCompleteForce(GaugeField &force, const GaugeField &link)
{
  // ordering checks, in the order link -> force
  if (!link.isNative()) errorQuda("Unsupported gauge order %d", link.Order());
  if (!force.isNative()) errorQuda("Unsupported gauge order %d", force.Order());

  const bool on_host = (checkLocation(force, link) == QUDA_CPU_FIELD_LOCATION);
  if (on_host) errorQuda("CPU not implemented");

  checkPrecision(link, force);
  instantiate<HisqCompleteForce, ReconstructNone>(force, link);
}
1087 
1088  } // namespace fermion_force
1089 
1090 } // namespace quda
1091 
1092 #endif // GPU_HISQ_FORCE