QUDA  0.9.0
hisq_paths_force_quda.cu
Go to the documentation of this file.
1 #include <quda_internal.h>
2 #include <gauge_field.h>
3 #include <ks_improved_force.h>
4 #include <quda_matrix.h>
5 #include <tune_quda.h>
6 #include <index_helper.cuh>
7 #include <gauge_field_order.h>
8 
9 #ifdef GPU_HISQ_FORCE
10 
11 namespace quda {
12 
13  namespace fermion_force {
14 
// Direction labels. Values 0-3 are the forward hops and 4-7 the backward
// hops, numbered so that a direction and its opposite always sum to 7
// (opp_dir() and posDir() rely on this).
enum {
  XUP,    // 0
  YUP,    // 1
  ZUP,    // 2
  TUP,    // 3
  TDOWN,  // 4
  ZDOWN,  // 5
  YDOWN,  // 6
  XDOWN   // 7
};
25 
// Selects which force contribution a FatLinkForce launch computes; it also
// identifies which FatLinkArg constructor an argument struct came from.
enum HisqForceType {
  FORCE_ALL_LINK = 0,         // allLinkKernel
  FORCE_MIDDLE_LINK,          // middleLinkKernel with pMu/qMu outputs
  FORCE_LEPAGE_MIDDLE_LINK,   // middleLinkKernel without pMu/qMu outputs
  FORCE_SIDE_LINK,            // sideLinkKernel
  FORCE_SIDE_LINK_SHORT,      // sideLinkShortKernel
  FORCE_LONG_LINK,            // not dispatched by FatLinkForce
  FORCE_COMPLETE,             // not dispatched by FatLinkForce
  FORCE_ONE_LINK,             // oneLinkTermKernel
  FORCE_INVALID
};
37 
// Returns the direction label opposite to dir (labels are paired so they sum to 7).
__device__ __host__ constexpr inline int opp_dir(int dir) { return 7 - dir; }
// Non-zero when dir is one of the forward labels (0..3).
__device__ __host__ constexpr inline int goes_forward(int dir) { return dir < 4; }
// Non-zero when dir is one of the backward labels (4..7).
__device__ __host__ constexpr inline int goes_backward(int dir) { return dir >= 4; }
// Returns +1 when (pos_dir + odd_lattice) is even and -1 when it is odd.
__device__ __host__ constexpr inline int CoeffSign(int pos_dir, int odd_lattice) { return ((pos_dir + odd_lattice) & 1) ? -1 : 1; }
// Returns +1 for parity 0 and -1 for any non-zero parity.
__device__ __host__ constexpr inline int Sign(int parity) { return (parity == 0) ? 1 : -1; }
// Maps a direction label (0..7) onto its axis index (0..3): forward labels
// map to themselves, backward labels to 7 - dir.
__device__ __host__ constexpr inline int posDir(int dir) { return (dir < 4) ? dir : 7 - dir; }
44 
// Shifts coordinate x[dir] by `shift` sites, wrapping periodically inside the
// extent arg.E[dir]. Adding the extent before the modulo keeps the result
// non-negative for shifts no smaller than -E[dir].
template <int dir, typename Arg>
inline __device__ __host__ void updateCoords(int x[], int shift, const Arg &arg) {
  const int extent = arg.E[dir];
  x[dir] = (x[dir] + shift + extent) % extent;
}
49 
// Runtime-direction front end for updateCoords<dir>: forwards to the
// compile-time specialisation for dir in 0..3 and does nothing otherwise.
template <typename Arg>
inline __device__ __host__ void updateCoords(int x[], int dir, int shift, const Arg &arg) {
  if (dir == 0)      updateCoords<0>(x, shift, arg);
  else if (dir == 1) updateCoords<1>(x, shift, arg);
  else if (dir == 2) updateCoords<2>(x, shift, arg);
  else if (dir == 3) updateCoords<3>(x, shift, arg);
}
59 
// Holds the fattening path coefficients, converted to the working precision.
//
// The input array is laid out as
//   [0] one-link, [1] Naik, [2] three-staple, [3] five-staple,
//   [4] seven-staple, [5] Lepage
// Note that the Naik coefficient sits at index 1 even though the member is
// declared after `seven`.
template <typename real>
struct PathCoefficients {
  const real one;
  const real three;
  const real five;
  const real seven;
  const real naik;
  const real lepage;

  // Initializers are listed in declaration order (the order in which members
  // are actually constructed) so the list reads the way it executes.
  PathCoefficients(const double *path_coeff_array)
    : one(path_coeff_array[0]),
      three(path_coeff_array[2]),
      five(path_coeff_array[3]),
      seven(path_coeff_array[4]),
      naik(path_coeff_array[1]),
      lepage(path_coeff_array[5]) { }
};
74 
// Geometry and link-field accessor shared by all force kernels in this file.
//
// The link field is expected to live on an extended grid E; the regular grid
// X is E with the border R stripped from both sides. The working set D is X
// widened by `overlap` sites on each side in every partitioned dimension.
template <typename real, QudaReconstructType reconstruct=QUDA_RECONSTRUCT_NO>
struct BaseForceArg {
  typedef typename gauge_mapper<real,reconstruct>::type G;
  const G link;
  int threads;     // number of working-set sites divided by two
  int X[4];        // regular grid dims
  int D[4];        // working set grid dims
  int E[4];        // extended grid dims

  int commDim[4];  // comm_dim_partitioned(d) for each dimension
  int border[4];   // link.R(): border width of the extended field
  int base_idx[4]; // the offset into the extended field
  int oddness_change; // parity of the summed base_idx offsets
  int mu;          // set by the launcher before each kernel launch
  int sig;         // set by the launcher before each kernel launch

  // @param link    gauge field on the extended grid
  // @param overlap number of border sites (per side) included in the working
  //                set in partitioned dimensions
  BaseForceArg(const GaugeField &link, int overlap)
    : link(link), threads(1), oddness_change(0), mu(0), sig(0)
  {
    for (int d=0; d<4; d++) {
      commDim[d] = comm_dim_partitioned(d);
      E[d] = link.X()[d];
      border[d] = link.R()[d];
      X[d] = E[d] - 2*border[d];
      D[d] = comm_dim_partitioned(d) ? X[d]+overlap*2 : X[d];
      base_idx[d] = comm_dim_partitioned(d) ? border[d]-overlap : 0;
      threads *= D[d];
    }
    threads /= 2;
    // shifting the origin by an odd number of sites flips the site parity
    oddness_change = (base_idx[0] + base_idx[1] + base_idx[2] + base_idx[3])&1;
  }
};
110 
// Kernel argument struct for the one-link, middle-link, side-link and
// all-link force kernels. Each constructor serves exactly one HisqForceType
// and aborts via errorQuda if called with a different type. Accessors that a
// given kernel never touches are bound to an arbitrary field of the right
// kind so that every member is initialised.
template <typename real, QudaReconstructType reconstruct=QUDA_RECONSTRUCT_NO>
struct FatLinkArg : public BaseForceArg<real,reconstruct> {

  typedef typename gauge_mapper<real,QUDA_RECONSTRUCT_NO>::type F;
  F outA; // accumulated force / outer product
  F outB; // short-path accumulator (all-link and side-link kernels)
  F pMu;  // written by the middle-link kernel at point B
  F p3;   // written by the middle-link kernel at A, read by the side-link kernels
  F qMu;  // written by the middle-link kernel at A

  const F oProd;
  const F qProd;
  const F qPrev;
  const real coeff;
  const real accumu_coeff;

  // these flags choose the middleLinkKernel instantiation in FatLinkForce
  const bool p_mu;
  const bool q_mu;
  const bool q_prev;

  // FORCE_ONE_LINK
  FatLinkArg(GaugeField &force, const GaugeField &oProd, const GaugeField &link, real coeff, HisqForceType type)
    : BaseForceArg<real,reconstruct>(link, 0), outA(force), outB(force), pMu(oProd), p3(oProd), qMu(oProd),
      oProd(oProd), qProd(oProd), qPrev(oProd), coeff(coeff), accumu_coeff(0),
      p_mu(false), q_mu(false), q_prev(false)
  { if (type != FORCE_ONE_LINK) errorQuda("This constructor is for FORCE_ONE_LINK"); }

  // FORCE_MIDDLE_LINK with a previous Q field
  FatLinkArg(GaugeField &newOprod, GaugeField &pMu, GaugeField &P3, GaugeField &qMu,
             const GaugeField &oProd, const GaugeField &qPrev, const GaugeField &link,
             real coeff, int overlap, HisqForceType type)
    : BaseForceArg<real,reconstruct>(link, overlap), outA(newOprod), outB(newOprod), pMu(pMu), p3(P3), qMu(qMu),
      oProd(oProd), qProd(oProd), qPrev(qPrev), coeff(coeff), accumu_coeff(0), p_mu(true), q_mu(true), q_prev(true)
  { if (type != FORCE_MIDDLE_LINK) errorQuda("This constructor is for FORCE_MIDDLE_LINK"); }

  // FORCE_MIDDLE_LINK without a previous Q field
  FatLinkArg(GaugeField &newOprod, GaugeField &pMu, GaugeField &P3, GaugeField &qMu,
             const GaugeField &oProd, const GaugeField &link,
             real coeff, int overlap, HisqForceType type)
    : BaseForceArg<real,reconstruct>(link, overlap), outA(newOprod), outB(newOprod), pMu(pMu), p3(P3), qMu(qMu),
      oProd(oProd), qProd(oProd), qPrev(qMu), coeff(coeff), accumu_coeff(0), p_mu(true), q_mu(true), q_prev(false)
  { if (type != FORCE_MIDDLE_LINK) errorQuda("This constructor is for FORCE_MIDDLE_LINK"); }

  // FORCE_LEPAGE_MIDDLE_LINK (no pMu / qMu outputs)
  FatLinkArg(GaugeField &newOprod, GaugeField &P3, const GaugeField &oProd,
             const GaugeField &qPrev, const GaugeField &link,
             real coeff, int overlap, HisqForceType type)
    : BaseForceArg<real,reconstruct>(link, overlap), outA(newOprod), outB(newOprod), pMu(P3), p3(P3), qMu(qPrev),
      oProd(oProd), qProd(oProd), qPrev(qPrev), coeff(coeff), accumu_coeff(0), p_mu(false), q_mu(false), q_prev(true)
  { if (type != FORCE_LEPAGE_MIDDLE_LINK) errorQuda("This constructor is for FORCE_LEPAGE_MIDDLE_LINK"); }

  // FORCE_SIDE_LINK
  FatLinkArg(GaugeField &newOprod, GaugeField &shortP, const GaugeField &P3,
             const GaugeField &qProd, const GaugeField &link, real coeff, real accumu_coeff, int overlap, HisqForceType type)
    : BaseForceArg<real,reconstruct>(link, overlap), outA(newOprod), outB(shortP), pMu(P3), p3(P3), qMu(qProd), oProd(qProd), qProd(qProd),
      qPrev(qProd), coeff(coeff), accumu_coeff(accumu_coeff),
      p_mu(false), q_mu(false), q_prev(false)
  { if (type != FORCE_SIDE_LINK) errorQuda("This constructor is for FORCE_SIDE_LINK"); }

  // FORCE_SIDE_LINK_SHORT
  FatLinkArg(GaugeField &newOprod, GaugeField &P3, const GaugeField &link,
             real coeff, int overlap, HisqForceType type)
    : BaseForceArg<real,reconstruct>(link, overlap), outA(newOprod), outB(newOprod),
      pMu(P3), p3(P3), qMu(P3), oProd(P3), qProd(P3), qPrev(P3), coeff(coeff), accumu_coeff(0.0),
      p_mu(false), q_mu(false), q_prev(false)
  { if (type != FORCE_SIDE_LINK_SHORT) errorQuda("This constructor is for FORCE_SIDE_LINK_SHORT"); }

  // FORCE_ALL_LINK; the trailing `dummy` parameter is unused and only
  // distinguishes this overload from the others.
  FatLinkArg(GaugeField &newOprod, GaugeField &shortP, const GaugeField &oProd, const GaugeField &qPrev,
             const GaugeField &link, real coeff, real accumu_coeff, int overlap, HisqForceType type, bool dummy)
    : BaseForceArg<real,reconstruct>(link, overlap), outA(newOprod), outB(shortP),
      pMu(shortP), p3(shortP), qMu(qPrev), // placeholders, not used by allLinkKernel
      oProd(oProd),
      qProd(qPrev),                        // placeholder
      qPrev(qPrev),
      coeff(coeff), accumu_coeff(accumu_coeff), p_mu(false), q_mu(false), q_prev(false)
  { if (type != FORCE_ALL_LINK) errorQuda("This constructor is for FORCE_ALL_LINK"); }

};
180 
// One-link term: force(sig, x) += coeff * oProd(sig, x) on the interior sites.
//
// Thread layout: x indexes the checkerboard site (guarded by arg.threads),
// y is the parity and z is the direction sig (guarded to sig < 4).
template <typename real, typename Arg>
__global__ void oneLinkTermKernel(Arg arg)
{
  typedef Matrix<complex<real>,3> Link;

  const int x_cb = blockIdx.x * blockDim.x + threadIdx.x;
  if (x_cb >= arg.threads) return;
  const int parity = blockIdx.y * blockDim.y + threadIdx.y;
  const int sig = blockIdx.z * blockDim.z + threadIdx.z;
  if (sig >= 4) return;

  // coordinates on the regular grid, then shifted into the extended field
  int site[4];
  getCoords(site, x_cb, arg.X, parity);
#pragma unroll
  for (int d=0; d<4; d++) site[d] += arg.border[d];
  const int e_cb = linkIndex(site, arg.E);

  Link force = arg.outA(sig, e_cb, parity);
  Link w = arg.oProd(sig, e_cb, parity);
  force += arg.coeff * w;
  arg.outA(sig, e_cb, parity) = force;
}
202 
203 
/********************************allLinkKernel*********************************************
 *
 * Combined middle-link and side-link update for the longest path.
 *
 *           sig
 *        A________B
 *    mu  |        |
 *      D |        |C
 *
 * A is the current point. D is one hop from A along mu (backward when
 * mu_positive, forward otherwise), B is one hop from A along sig, and C is
 * one hop from D along sig.
 *
 * READ
 *   3 links:          ad_link, ab_link, bc_link
 *   5 colour matrices: Qprev_at_D, oprod_at_C, newOprod_at_A(sig),
 *                      newOprod_at_D/newOprod_at_A(mu), shortP_at_D
 * WRITE
 *   3 colour matrices: newOprod_at_A(sig), newOprod_at_D/newOprod_at_A(mu), shortP_at_D
 *
 * When sig is negative newOprod_at_A(sig) is neither read nor written, so the
 * traffic in (links, colour matrices) is (3, 8) for positive sig and (3, 6)
 * otherwise. According to the original author's accounting the kernel is
 * called 384 times, half with positive and half with negative sig.
 *
 * Flop count in (matrix_multi, matrix_add): (6,3) for positive sig, (4,2) otherwise.
 *
 ************************************************************************************************/
template<typename real, int sig_positive, int mu_positive, typename Arg>
__global__ void allLinkKernel(Arg arg)
{
  typedef Matrix<complex<real>,3> Link;

  int x_cb = blockIdx.x * blockDim.x + threadIdx.x;
  if (x_cb >= arg.threads) return;
  int parity = blockIdx.y * blockDim.y + threadIdx.y;

  // site A on the working set, shifted into the extended field
  int x[4];
  getCoords(x, x_cb, arg.D, parity);
  for (int d=0; d<4; d++) x[d] += arg.base_idx[d];
  const int e_cb = linkIndex(x,arg.E);
  parity = parity^arg.oddness_change;
  const int other_parity = 1-parity;

  const int sig_hop = sig_positive ? 1 : -1;
  const int mu_hop = mu_positive ? -1 : 1;
  const int mysig = posDir(arg.sig);
  const int mu = mu_positive ? arg.mu : opp_dir(arg.mu);

  real mycoeff = CoeffSign(sig_positive,parity)*arg.coeff;

  // B = A + sig
  int y[4] = {x[0], x[1], x[2], x[3]};
  updateCoords(y, mysig, sig_hop, arg);
  const int point_b = linkIndex(y,arg.E);

  // D = A -/+ mu, C = D + sig
  for (int d=0; d<4; d++) y[d] = x[d];
  updateCoords(y, mu, mu_hop, arg);
  const int point_d = linkIndex(y,arg.E);
  updateCoords(y, mysig, sig_hop, arg);
  const int point_c = linkIndex(y,arg.E);

  // a link is stored at the lower of the two sites it joins
  const int ab_link_nbr_idx = sig_positive ? e_cb : point_b;
  const int ad_idx = mu_positive ? point_d : e_cb;
  const int ad_parity = mu_positive ? other_parity : parity;
  const int bc_idx = mu_positive ? point_c : point_b;
  const int bc_parity = mu_positive ? parity : other_parity;

  Link Uab = arg.link(mysig, ab_link_nbr_idx, sig_positive^(1-parity));
  Link Uad = arg.link(mu, ad_idx, ad_parity);
  Link Ubc = arg.link(mu, bc_idx, bc_parity);
  Link Ox = arg.qPrev(0, point_d, other_parity);
  Link Oy = arg.oProd(0, point_c, parity);
  Link Oz = mu_positive ? conj(Ubc)*Oy : Ubc*Oy;

  // A-D link oriented along the path
  Link Uad_path = mu_positive ? Uad : conj(Uad);

  if (sig_positive) {
    Link force = arg.outA(arg.sig, e_cb, parity);
    force += Sign(parity)*mycoeff*Oz*Ox* Uad_path;
    arg.outA(arg.sig, e_cb, parity) = force;
    Oy = Uab*Oz;
  } else {
    Oy = conj(Uab)*Oz;
  }

  // side-link contribution in the mu direction
  Link force = arg.outA(mu, ad_idx, ad_parity);
  force += Sign(ad_parity)*mycoeff* (mu_positive ? Oy*Ox : conj(Ox)*conj(Oy));
  arg.outA(mu, ad_idx, ad_parity) = force;

  // accumulate into the shorter path at D
  Link shortP = arg.outB(0, point_d, other_parity);
  shortP += arg.accumu_coeff* Uad_path *Oy;
  arg.outB(0, point_d, other_parity) = shortP;
}
293 
294 
/**************************middleLinkKernel*****************************
 *
 *        A________B
 *    mu  |        |
 *       D|        |C
 *
 * A is the current point; D = A -/+ mu, C = D +/- sig, B = A +/- sig.
 *
 * Generally we need
 * READ
 *   3 links:           ab_link, bc_link, ad_link
 *   3 colour matrices: newOprod_at_A, oprod_at_C, Qprod_at_D
 * WRITE
 *   4 colour matrices: newOprod_at_A, P3_at_A, Pmu_at_B, Qmu_at_A
 *
 * Three call variations, selected by the template flags:
 *   1. qPrev == false: Qprod_at_D does not exist and is not read
 *   2. full read/write
 *   3. pMu == qMu == false (Lepage): Pmu_at_B and Qmu_at_A are not written
 *
 * In all three cases newOprod_at_A is neither read nor written when sig is
 * negative.
 *
 * Data traffic in (links, colour matrices), per the original accounting:
 *   Call 1 (48 calls):  positive sig (3, 6), negative sig (3, 4)
 *   Call 2 (192 calls): positive sig (3, 7), negative sig (3, 5)
 *   Call 3 (48 calls):  positive sig (3, 5), negative sig (3, 2)
 *                       (no need to load Qprod_at_D in the negative case)
 *
 * Note: oprod_at_C may actually be read from D when it is the fresh outer
 * product; it is called oprod_at_C to simplify naming. This does not affect
 * the traffic analysis.
 *
 * Flop count in (matrix_multi, matrix_add):
 *   call 1: positive sig (3, 1), else (2, 0)
 *   call 2: positive sig (4, 1), else (3, 0)
 *   call 3: positive sig (4, 1), else (2, 0)
 *
 ****************************************************************************/
template <typename real, int sig_positive, int mu_positive, bool pMu, bool qMu, bool qPrev, typename Arg>
__global__ void middleLinkKernel(Arg arg)
{
  typedef Matrix<complex<real>,3> Link;

  int x_cb = blockIdx.x * blockDim.x + threadIdx.x;
  if (x_cb >= arg.threads) return;
  int parity = blockIdx.y * blockDim.y + threadIdx.y;

  // site A on the working set, shifted into the extended field
  int x[4];
  getCoords(x, x_cb, arg.D, parity);
  for (int d=0; d<4; d++) x[d] += arg.base_idx[d];
  const int e_cb = linkIndex(x,arg.E);
  parity = parity ^ arg.oddness_change;

  const int mymu = posDir(arg.mu);
  const int mysig = posDir(arg.sig);
  const int mu_hop = mu_positive ? -1 : 1;
  const int sig_hop = sig_positive ? 1 : -1;

  // B = A + sig
  int y[4] = {x[0], x[1], x[2], x[3]};
  updateCoords(y, mysig, sig_hop, arg);
  const int point_b = linkIndex(y, arg.E);

  // D = A -/+ mu, C = D + sig
  for (int d=0; d<4; d++) y[d] = x[d];
  updateCoords(y, mymu, mu_hop, arg);
  const int point_d = linkIndex(y, arg.E);
  updateCoords(y, mysig, sig_hop, arg);
  const int point_c = linkIndex(y, arg.E);

  const int ad_link_nbr_idx = mu_positive ? point_d : e_cb;
  const int bc_link_nbr_idx = mu_positive ? point_c : point_b;
  const int ab_link_nbr_idx = sig_positive ? e_cb : point_b;

  // link joining A and B
  Link Uab = arg.link(mysig, ab_link_nbr_idx, sig_positive^(1-parity));

  // link joining B and C
  Link Ubc = arg.link(mymu, bc_link_nbr_idx, mu_positive^(1-parity));

  Link Oy;
  if (qPrev) {
    Oy = arg.oProd(0, point_c, parity);
  } else {
    // fresh outer product: stored per direction, conjugated for backward sig
    Oy = arg.oProd(mysig, sig_positive ? point_d : point_c, sig_positive^parity);
    if (!sig_positive) Oy = conj(Oy);
  }

  Link Ow = mu_positive ? conj(Ubc)*Oy : Ubc*Oy;

  if (pMu) arg.pMu(0, point_b, 1-parity) = Ow;

  arg.p3(0, e_cb, parity) = sig_positive ? Uab*Ow : conj(Uab)*Ow;

  // link joining A and D, oriented along the path
  Link Uad = arg.link(mymu, ad_link_nbr_idx, mu_positive^parity);
  if (!mu_positive) Uad = conj(Uad);

  if (qPrev) {
    Link Ox;
    if ( qMu || sig_positive ) {
      Oy = arg.qPrev(0, point_d, 1-parity);
      Ox = Oy*Uad;
    }
    if ( qMu ) arg.qMu(0, e_cb, parity) = Ox;
    if (sig_positive) Oy = Ow*Ox;
  } else {
    if (sig_positive) Oy = Ow*Uad;
    if ( qMu ) arg.qMu(0, e_cb, parity) = Uad;
  }

  // only forward sig contributes to the accumulated outer product at A
  if (sig_positive) {
    Link oprod = arg.outA(arg.sig, e_cb, parity);
    oprod += arg.coeff*Oy;
    arg.outA(arg.sig, e_cb, parity) = oprod;
  }

}
421 
/***********************************sideLinkKernel***************************
 *
 * Side-link contribution to the momentum.
 *
 *           sig
 *        A________B
 *        |        |   mu
 *      D |        |C
 *
 * A is the current point (x_cb); D = A -/+ mu.
 *
 * In general we need
 * READ
 *   1 link:            ad_link
 *   4 colour matrices: shortP_at_D, newOprod, P3_at_A, Qprod_at_D
 * WRITE
 *   2 colour matrices: shortP_at_D, newOprod
 *
 * The original accounting lists two call variations:
 *   1. full read/write                          (192 calls): traffic (1, 6)
 *   2. shortP == NULL && Qprod == NULL           (48 calls): traffic (0, 3)
 *      (no ad_link/shortP_at_D access; Qprod_at_D does not exist)
 *
 * Note: newOprod can be at point D or A, depending on whether mu is positive
 * or negative.
 *
 * Flop count in (matrix_multi, matrix_add): call 1 (2, 2), call 2 (0, 1).
 *
 *********************************************************************************/
template <typename real, int mu_positive, typename Arg>
__global__ void sideLinkKernel(Arg arg)
{
  typedef Matrix<complex<real>, 3> Link;

  int x_cb = blockIdx.x * blockDim.x + threadIdx.x;
  if (x_cb >= arg.threads) return;
  int parity = blockIdx.y * blockDim.y + threadIdx.y;

  // site A on the working set, shifted into the extended field
  int x[4];
  getCoords(x, x_cb ,arg.D, parity);
  for (int d=0; d<4; d++) x[d] += arg.base_idx[d];
  const int e_cb = linkIndex(x,arg.E);
  parity ^= arg.oddness_change;
  const int other_parity = 1-parity;

  // D = A -/+ mu
  const int mymu = posDir(arg.mu);
  int y[4] = {x[0], x[1], x[2], x[3]};
  updateCoords(y, mymu, (mu_positive ? -1 : 1), arg);
  const int point_d = linkIndex(y,arg.E);

  Link Oy = arg.p3(0, e_cb, parity);

  // accumulate the transported P3 into the shorter path at D
  {
    const int ad_link_nbr_idx = mu_positive ? point_d : e_cb;

    Link Uad = arg.link(mymu, ad_link_nbr_idx, mu_positive^parity);
    Link Ow = mu_positive ? Uad*Oy : conj(Uad)*Oy;

    Link shortP = arg.outB(0, point_d, other_parity);
    shortP += arg.accumu_coeff * Ow;
    arg.outB(0, point_d, other_parity) = shortP;
  }

  // side-link contribution to the outer product in the mu direction
  {
    const int out_dir = mu_positive ? arg.mu : opp_dir(arg.mu);
    const int out_idx = mu_positive ? point_d : e_cb;
    const int out_parity = mu_positive ? other_parity : parity;

    Link Ox = arg.qProd(0, point_d, other_parity);
    Link Ow = mu_positive ? Oy*Ox : conj(Ox)*conj(Oy);

    real mycoeff = CoeffSign(goes_forward(arg.sig), parity)*CoeffSign(goes_forward(arg.mu),parity)*arg.coeff;

    Link oprod = arg.outA(out_dir, out_idx, out_parity);
    oprod += mycoeff * Ow;
    arg.outA(out_dir, out_idx, out_parity) = oprod;
  }
}
506 
// Short side-link contribution to the momentum: adds +/- coeff * P3_at_A (or
// its conjugate for backward mu) to the outer product in the mu direction.
//
//           sig
//        A________B
//        |        |   mu
//      D |        |C
//
// A is the current point (x_cb). The result is stored at D for forward mu and
// at A itself for backward mu.
//
// Flop count, in two-number pair (matrix_mult, matrix_add): (0,1)
template<typename real, int mu_positive, typename Arg>
__global__ void sideLinkShortKernel(Arg arg)
{
  typedef Matrix<complex<real>,3> Link;

  int x_cb = blockIdx.x * blockDim.x + threadIdx.x;
  if (x_cb >= arg.threads) return;
  int parity = blockIdx.y * blockDim.y + threadIdx.y;

  // site A on the working set, shifted into the extended field
  int x[4];
  getCoords(x, x_cb, arg.D, parity);
  for (int d=0; d<4; d++) x[d] += arg.base_idx[d];
  const int e_cb = linkIndex(x,arg.E);
  parity ^= arg.oddness_change;

  const int mymu = posDir(arg.mu);
  int y[4] = {x[0], x[1], x[2], x[3]};
  updateCoords(y, mymu, (mu_positive ? -1 : 1), arg);

  // output site and parity: D for forward mu, A otherwise
  const int out_idx = mu_positive ? linkIndex(y,arg.E) : e_cb;
  const int out_parity = mu_positive ? 1-parity : parity;

  real mycoeff = CoeffSign(goes_forward(arg.sig),parity)*CoeffSign(goes_forward(arg.mu),parity)*arg.coeff;

  Link Oy = arg.p3(0, e_cb, parity);
  Link oprod = arg.outA(mymu, out_idx, out_parity);
  oprod += mu_positive ? mycoeff * Oy : mycoeff * conj(Oy);
  arg.outA(mymu, out_idx, out_parity) = oprod;
}
546 
// Autotuned launcher for the fat-link force kernels. One instance handles a
// single (type, sig, mu) combination; the constructor stores sig and mu into
// the argument struct, which is held by reference and must outlive this
// object.
template <typename real, typename Arg>
class FatLinkForce : public TunableVectorYZ {

private:
  Arg &arg;
  const GaugeField &meta;
  const HisqForceType type;

  unsigned int minThreads() const { return arg.threads; }
  bool tuneGridDim() const { return false; }

public:
  // The y extent is 2 (parity); the z extent is 4 (direction) for the
  // one-link kernel and 1 otherwise.
  FatLinkForce(Arg &arg, const GaugeField &meta, int sig, int mu, HisqForceType type)
    : TunableVectorYZ(2,type == FORCE_ONE_LINK ? 4 : 1), arg(arg), meta(meta), type(type) {
    arg.sig = sig;
    arg.mu = mu;
  }
  virtual ~FatLinkForce() { }

  // Builds the tuning key from the thread count, the directions that select
  // the kernel instantiation, and the force type.
  TuneKey tuneKey() const {
    std::stringstream aux;
    if (type == FORCE_ONE_LINK) aux << "threads=" << arg.threads;
    else if (type == FORCE_MIDDLE_LINK || type == FORCE_LEPAGE_MIDDLE_LINK)
      aux << "threads=" << arg.threads << ",sig=" << arg.sig << ",mu=" << arg.mu <<
        ",pMu=" << arg.p_mu << ",q_muu=" << arg.q_mu << ",q_prev=" << arg.q_prev;
    else
      aux << "threads=" << arg.threads << ",mu=" << arg.mu; // no sig dependence needed for side link

    switch (type) {
    case FORCE_ONE_LINK:           aux << ",ONE_LINK";           break;
    case FORCE_ALL_LINK:           aux << ",ALL_LINK";           break;
    case FORCE_MIDDLE_LINK:        aux << ",MIDDLE_LINK";        break;
    case FORCE_LEPAGE_MIDDLE_LINK: aux << ",LEPAGE_MIDDLE_LINK"; break;
    case FORCE_SIDE_LINK:          aux << ",SIDE_LINK";          break;
    case FORCE_SIDE_LINK_SHORT:    aux << ",SIDE_LINK_SHORT";    break;
    default: errorQuda("Undefined force type %d", type);
    }
    return TuneKey(meta.VolString(), typeid(*this).name(), aux.str().c_str());
  }

  // Launches the kernel instantiation matching the force type and the
  // orientation (forward/backward) of arg.sig and arg.mu.
  void apply(const cudaStream_t &stream) {
    TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
    switch (type) {
    case FORCE_ONE_LINK:
      oneLinkTermKernel<real,Arg> <<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
      break;
    case FORCE_ALL_LINK:
      if (goes_forward(arg.sig) && goes_forward(arg.mu))
        allLinkKernel<real,1,1,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
      else if (goes_forward(arg.sig) && goes_backward(arg.mu))
        allLinkKernel<real,1,0,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
      else if (goes_backward(arg.sig) && goes_forward(arg.mu))
        allLinkKernel<real,0,1,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
      else
        allLinkKernel<real,0,0,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
      break;
    case FORCE_MIDDLE_LINK:
      if (!arg.p_mu || !arg.q_mu) errorQuda("Expect p_mu=%d and q_mu=%d to both be true", arg.p_mu, arg.q_mu);
      if (arg.q_prev) {
        if (goes_forward(arg.sig) && goes_forward(arg.mu))
          middleLinkKernel<real,1,1,true,true,true,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
        else if (goes_forward(arg.sig) && goes_backward(arg.mu))
          middleLinkKernel<real,1,0,true,true,true,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
        else if (goes_backward(arg.sig) && goes_forward(arg.mu))
          middleLinkKernel<real,0,1,true,true,true,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
        else
          middleLinkKernel<real,0,0,true,true,true,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
      } else {
        if (goes_forward(arg.sig) && goes_forward(arg.mu))
          middleLinkKernel<real,1,1,true,true,false,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
        else if (goes_forward(arg.sig) && goes_backward(arg.mu))
          middleLinkKernel<real,1,0,true,true,false,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
        else if (goes_backward(arg.sig) && goes_forward(arg.mu))
          middleLinkKernel<real,0,1,true,true,false,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
        else
          middleLinkKernel<real,0,0,true,true,false,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
      }
      break;
    case FORCE_LEPAGE_MIDDLE_LINK:
      if (arg.p_mu || arg.q_mu || !arg.q_prev)
        errorQuda("Expect p_mu=%d and q_mu=%d to both be false and q_prev=%d true", arg.p_mu, arg.q_mu, arg.q_prev);
      if (goes_forward(arg.sig) && goes_forward(arg.mu))
        middleLinkKernel<real,1,1,false,false,true,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
      else if (goes_forward(arg.sig) && goes_backward(arg.mu))
        middleLinkKernel<real,1,0,false,false,true,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
      else if (goes_backward(arg.sig) && goes_forward(arg.mu))
        middleLinkKernel<real,0,1,false,false,true,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
      else
        middleLinkKernel<real,0,0,false,false,true,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
      break;
    case FORCE_SIDE_LINK:
      if (goes_forward(arg.mu)) sideLinkKernel<real,1,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
      else                      sideLinkKernel<real,0,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
      break;
    case FORCE_SIDE_LINK_SHORT:
      if (goes_forward(arg.mu)) sideLinkShortKernel<real,1,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
      else                      sideLinkShortKernel<real,0,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
      break;
    default:
      errorQuda("Undefined force type %d", type);
    }
  }

  // Saves every field the selected kernel writes, so that repeated tuning
  // launches do not corrupt the accumulated results.
  void preTune() {
    switch (type) {
    case FORCE_ONE_LINK:
      arg.outA.save();
      break;
    case FORCE_ALL_LINK:
      arg.outA.save();
      arg.outB.save();
      break;
    case FORCE_MIDDLE_LINK:
      arg.pMu.save();
      arg.qMu.save();
      // deliberate fall-through: the middle link also writes outA and p3
    case FORCE_LEPAGE_MIDDLE_LINK:
      arg.outA.save();
      arg.p3.save();
      break;
    case FORCE_SIDE_LINK:
      arg.outB.save();
      // deliberate fall-through: the side link also writes outA
    case FORCE_SIDE_LINK_SHORT:
      arg.outA.save();
      break;
    default: errorQuda("Undefined force type %d", type);
    }
  }

  // Restores the fields saved by preTune().
  void postTune() {
    switch (type) {
    case FORCE_ONE_LINK:
      arg.outA.load();
      break;
    case FORCE_ALL_LINK:
      arg.outA.load();
      arg.outB.load();
      break;
    case FORCE_MIDDLE_LINK:
      arg.pMu.load();
      arg.qMu.load();
      // deliberate fall-through, mirrors preTune()
    case FORCE_LEPAGE_MIDDLE_LINK:
      arg.outA.load();
      arg.p3.load();
      break;
    case FORCE_SIDE_LINK:
      arg.outB.load();
      // deliberate fall-through, mirrors preTune()
    case FORCE_SIDE_LINK_SHORT:
      arg.outA.load();
      break;
    default: errorQuda("Undefined force type %d", type);
    }
  }

  // Floating point operations per launch. The site count is widened to
  // long long before any multiplication so that large local volumes do not
  // overflow 32-bit arithmetic.
  long long flops() const {
    const long long sites = 2ll*arg.threads; // threads covers one of the two parities
    switch (type) {
    case FORCE_ONE_LINK:
      return sites*4*36ll;
    case FORCE_ALL_LINK:
      return sites*(goes_forward(arg.sig) ? 1242ll : 828ll);
    case FORCE_MIDDLE_LINK:
    case FORCE_LEPAGE_MIDDLE_LINK:
      return sites*(2 * 198 +
                    (!arg.q_prev && goes_forward(arg.sig) ? 198 : 0) +
                    (arg.q_prev && (arg.q_mu || goes_forward(arg.sig) ) ? 198 : 0) +
                    ((arg.q_prev && goes_forward(arg.sig) ) ?  198 : 0) +
                    ( goes_forward(arg.sig) ? 216 : 0) );
    case FORCE_SIDE_LINK:       return sites*2*234;
    case FORCE_SIDE_LINK_SHORT: return sites*36;
    default: errorQuda("Undefined force type %d", type);
    }
    return 0;
  }

  // Bytes moved per launch, counting each read-modify-write field twice.
  long long bytes() const {
    const long long sites = 2ll*arg.threads; // threads covers one of the two parities
    switch (type) {
    case FORCE_ONE_LINK:
      return sites*4*( arg.oProd.Bytes() + 2*arg.outA.Bytes() );
    case FORCE_ALL_LINK:
      return sites*( (goes_forward(arg.sig) ? 4 : 2)*arg.outA.Bytes() + 3*arg.link.Bytes()
                     + arg.oProd.Bytes() + arg.qPrev.Bytes() + 2*arg.outB.Bytes());
    case FORCE_MIDDLE_LINK:
    case FORCE_LEPAGE_MIDDLE_LINK:
      return sites*( ( goes_forward(arg.sig) ? 2*arg.outA.Bytes() : 0 ) +
                     (arg.p_mu ? arg.pMu.Bytes() : 0) +
                     (arg.q_mu ? arg.qMu.Bytes() : 0) +
                     ( ( goes_forward(arg.sig) || arg.q_mu ) ? arg.qPrev.Bytes() : 0) +
                     arg.p3.Bytes() + 3*arg.link.Bytes() + arg.oProd.Bytes() );
    case FORCE_SIDE_LINK:
      return sites*( 2*arg.outA.Bytes() + 2*arg.outB.Bytes() +
                     arg.p3.Bytes() + arg.link.Bytes() + arg.qProd.Bytes() );
    case FORCE_SIDE_LINK_SHORT:
      return sites*( 2*arg.outA.Bytes() + arg.p3.Bytes() );
    default: errorQuda("Undefined force type %d", type);
    }
    return 0;
  }
};
744 
// Accumulates the one-, three-, five-, seven-link and Lepage staple
// contributions into newOprod by looping over all admissible combinations
// of the path directions sig, mu, nu and rho (each in 0..7, mutually
// non-parallel). Pmu, P3, P5, Pnumu, Qmu and Qnumu are scratch fields shared
// between the kernels. All launches go to stream 0.
template<typename real>
static void hisqStaplesForce(GaugeField &Pmu, GaugeField &P3, GaugeField &P5, GaugeField &Pnumu,
                             GaugeField &Qmu, GaugeField &Qnumu, GaugeField &newOprod,
                             const GaugeField &oprod, const GaugeField &link,
                             const PathCoefficients<real> &act_path_coeff)
{
  typedef FatLinkArg<real> Arg;
  typedef FatLinkForce<real, Arg> Force;

  const real OneLink = act_path_coeff.one;
  const real ThreeSt = act_path_coeff.three;
  const real mThreeSt = -ThreeSt;
  const real FiveSt = act_path_coeff.five;
  const real mFiveSt = -FiveSt;
  const real SevenSt = act_path_coeff.seven;
  const real Lepage = act_path_coeff.lepage;
  const real mLepage = -Lepage;

  // ratios used to fold a longer path into the next shorter one; a vanishing
  // denominator switches the accumulation off instead of dividing by zero
  const real sevenOverFive = FiveSt != 0 ? SevenSt/FiveSt : 0;
  const real fiveOverThree = ThreeSt != 0 ? FiveSt/ThreeSt : 0;
  const real lepageOverThree = ThreeSt != 0 ? Lepage/ThreeSt : 0;

  // one-link term (all four directions in a single launch)
  Arg oneLinkArg(newOprod, oprod, link, OneLink, FORCE_ONE_LINK);
  Force oneLink(oneLinkArg, link, 0, 0, FORCE_ONE_LINK);
  oneLink.apply(0);

  for (int sig=0; sig<8; sig++) {
    for (int mu=0; mu<8; mu++) {
      if ( mu == sig || mu == opp_dir(sig) ) continue;

      // 3-link
      // Kernel A: middle link
      Arg threeMidArg( newOprod, Pmu, P3, Qmu, oprod, link, mThreeSt, 2, FORCE_MIDDLE_LINK);
      Force threeMid(threeMidArg, link, sig, mu, FORCE_MIDDLE_LINK);
      threeMid.apply(0);

      for (int nu=0; nu < 8; nu++) {
        if (nu == sig || nu == opp_dir(sig) || nu == mu || nu == opp_dir(mu)) continue;

        // 5-link: middle link
        // Kernel B
        Arg fiveMidArg( newOprod, Pnumu, P5, Qnumu, Pmu, Qmu, link, FiveSt, 1, FORCE_MIDDLE_LINK);
        Force fiveMid(fiveMidArg, link, sig, nu, FORCE_MIDDLE_LINK);
        fiveMid.apply(0);

        for (int rho = 0; rho < 8; rho++) {
          if (rho == sig || rho == opp_dir(sig) || rho == mu || rho == opp_dir(mu) || rho == nu || rho == opp_dir(nu)) continue;

          // 7-link: middle link and side link
          Arg sevenArg(newOprod, P5, Pnumu, Qnumu, link, SevenSt, sevenOverFive, 1, FORCE_ALL_LINK, true);
          Force seven(sevenArg, link, sig, rho, FORCE_ALL_LINK);
          seven.apply(0);

        } // rho

        // 5-link: side link
        Arg fiveSideArg(newOprod, P3, P5, Qmu, link, mFiveSt, fiveOverThree, 1, FORCE_SIDE_LINK);
        Force fiveSide(fiveSideArg, link, sig, nu, FORCE_SIDE_LINK);
        fiveSide.apply(0);

      } // nu

      // Lepage term, skipped entirely when its coefficient vanishes
      if (Lepage != 0.) {
        Arg lepageMidArg( newOprod, P5, Pmu, Qmu, link, Lepage, 2, FORCE_LEPAGE_MIDDLE_LINK);
        Force lepageMid(lepageMidArg, link, sig, mu, FORCE_LEPAGE_MIDDLE_LINK);
        lepageMid.apply(0);

        Arg lepageSideArg(newOprod, P3, P5, Qmu, link, mLepage, lepageOverThree, 2, FORCE_SIDE_LINK);
        Force lepageSide(lepageSideArg, link, sig, mu, FORCE_SIDE_LINK);
        lepageSide.apply(0);
      } // Lepage != 0.0

      // 3-link side link
      Arg threeSideArg(newOprod, P3, link, ThreeSt, 1, FORCE_SIDE_LINK_SHORT);
      Force threeSide(threeSideArg, P3, sig, mu, FORCE_SIDE_LINK_SHORT);
      threeSide.apply(0);
    } // mu
  } // sig

} // hisqStaplesForce
819 
// Host entry point for the HISQ staple force.
//
// @param newOprod         field the force contributions are accumulated into
// @param oprod            input outer-product field
// @param link             gauge links (extended field)
// @param path_coeff_array six path coefficients, see PathCoefficients
// @param flops            if non-null, incremented by the useful flop count
//
// Aborts via errorQuda on non-native field orders, CPU fields, or a
// precision other than single or double.
void hisqStaplesForce(GaugeField &newOprod, const GaugeField &oprod, const GaugeField &link, const double path_coeff_array[6], long long* flops)
{
  if (!link.isNative()) errorQuda("Unsupported gauge order %d", link.Order());
  if (!oprod.isNative()) errorQuda("Unsupported gauge order %d", oprod.Order());
  if (!newOprod.isNative()) errorQuda("Unsupported gauge order %d", newOprod.Order());
  if (checkLocation(newOprod,oprod,link) == QUDA_CPU_FIELD_LOCATION) errorQuda("CPU not implemented");

  // create color matrix fields with zero padding
  // NOTE(review): gauge_param is used exactly as copied from `link`; confirm
  // that its reconstruct/order/geometry/padding settings are what the scalar
  // temporaries below require.
  GaugeFieldParam gauge_param(link);

  cudaGaugeField Pmu(gauge_param);
  cudaGaugeField P3(gauge_param);
  cudaGaugeField P5(gauge_param);
  cudaGaugeField Pnumu(gauge_param);
  cudaGaugeField Qmu(gauge_param);
  cudaGaugeField Qnumu(gauge_param);

  QudaPrecision precision = checkPrecision(oprod, link, newOprod);
  if (precision == QUDA_DOUBLE_PRECISION) {
    PathCoefficients<double> act_path_coeff(path_coeff_array);
    hisqStaplesForce<double>(Pmu, P3, P5, Pnumu, Qmu, Qnumu, newOprod, oprod, link, act_path_coeff);
  } else if (precision == QUDA_SINGLE_PRECISION) {
    PathCoefficients<float> act_path_coeff(path_coeff_array);
    hisqStaplesForce<float>(Pmu, P3, P5, Pnumu, Qmu, Qnumu, newOprod, oprod, link, act_path_coeff);
  } else {
    errorQuda("Unsupported precision");
  }

  // surface any asynchronous kernel failure before returning
  cudaDeviceSynchronize();
  checkCudaError();

  if (flops) {
    // physical (border-stripped) volume is the PRODUCT of the four extents;
    // accumulate in long long so large lattices cannot overflow
    long long volume = 1;
    for (int d=0; d<4; d++) volume *= link.X()[d]-2*link.R()[d];
    // Middle Link, side link, short side link, AllLink, OneLink
    *flops += volume*(134784 + 24192 + 103680 + 864 + 397440 + 72 + (path_coeff_array[5] != 0 ? 28944 : 0));
  }

}
862 
// Argument bundle handed by value to completeForceKernel.
// Extends BaseForceArg (constructed from the link field with second
// argument 0) with the output and outer-product accessors.
template <typename real, QudaReconstructType reconstruct=QUDA_RECONSTRUCT_NO>
struct CompleteForceArg : public BaseForceArg<real,reconstruct> {

  using M = typename gauge::FloatNOrder<real,18,2,11>;
  using F = typename gauge_mapper<real,QUDA_RECONSTRUCT_NO>::type;

  M outA;            // accessor for the destination force field
  const F oProd;     // read-only accessor for the outer-product field
  const real coeff;  // always set to zero here; completeForceKernel does not read it

  CompleteForceArg(GaugeField &force, const GaugeField &link, const GaugeField &oprod)
    : BaseForceArg<real,reconstruct>(link, 0),
      outA(force),
      oProd(oprod),
      coeff(0.0)
  { }

};
877 
// Flops count: 4 matrix multiplications per lattice site = 792 Flops per site
//
// One thread per (checkerboard site, parity): the x launch dimension
// indexes the site and is bounds-checked against arg.threads, the y
// launch dimension supplies the parity.  For each of the four
// directions the thread forms link * oprod, passes it through
// makeAntiHerm, flips the sign on parity 1 and stores it to arg.outA.
template <typename real, typename Arg>
__global__ void completeForceKernel(Arg arg)
{
  typedef Matrix<complex<real>,3> Link;

  const int x_cb = blockIdx.x * blockDim.x + threadIdx.x;
  if (x_cb >= arg.threads) return;
  const int parity = blockIdx.y * blockDim.y + threadIdx.y;

  // coordinates of this site, then shifted by the border to index the
  // link / outer-product fields (e_cb); the output is indexed with x_cb
  int coords[4];
  getCoords(coords, x_cb, arg.X, parity);
  for (int d=0; d<4; d++) coords[d] += arg.border[d];
  const int e_cb = linkIndex(coords,arg.E);

  // the sign depends on parity only, so it is evaluated once for all directions
  const real sign = (parity==1) ? -1.0 : 1.0;

#pragma unroll
  for (int sig=0; sig<4; ++sig) {
    Link U = arg.link(sig, e_cb, parity);
    Link O = arg.oProd(sig, e_cb, parity);
    Link product = U*O;

    makeAntiHerm(product);

    arg.outA(sig, x_cb, parity) = sign*product;
  }
}
905 
// Argument bundle handed by value to longLinkKernel.
// Extends BaseForceArg (constructed from the link field with second
// argument 0) with the accumulation target, the input outer product
// and the scalar coefficient applied to the long-link term.
template <typename real, QudaReconstructType reconstruct=QUDA_RECONSTRUCT_NO>
struct LongLinkArg : public BaseForceArg<real,reconstruct> {

  using M = typename gauge::FloatNOrder<real,18,2,11>;
  using F = typename gauge_mapper<real,QUDA_RECONSTRUCT_NO>::type;

  F outA;            // accessor for newOprod; read and then written by the kernel
  const F oProd;     // read-only accessor for the input outer-product field
  const real coeff;  // multiplies the long-link term before it is added to outA

  LongLinkArg(GaugeField &newOprod, const GaugeField &link, const GaugeField &oprod, real coeff)
    : BaseForceArg<real,reconstruct>(link,0),
      outA(newOprod),
      oProd(oprod),
      coeff(coeff)
  { }

};
920 
// Flops count, in two-number pair (matrix_mult, matrix_add)
// (24, 12)
// 4968 Flops per site in total
//
// One thread per (checkerboard site, parity): the x launch dimension
// indexes the site and is bounds-checked against arg.threads, the y
// launch dimension supplies the parity.  For every direction the
// thread adds arg.coeff times the long-link term to arg.outA at its
// own site.
template <typename real, typename Arg>
__global__ void longLinkKernel(Arg arg)
{
  typedef Matrix<complex<real>,3> Link;

  const int x_cb = blockIdx.x * blockDim.x + threadIdx.x;
  if (x_cb >= arg.threads) return;
  const int parity = blockIdx.y * blockDim.y + threadIdx.y;
  const int other_parity = 1-parity;

  // coordinates of this site, shifted by the border
  int x[4];
  getCoords(x, x_cb, arg.X, parity);
  for (int i=0; i<4; i++) x[i] += arg.border[i];

  // C is the current point
  const int point_c = linkIndex(x,arg.E);

  /*
   *
   *    A    B    C    D    E
   *    ---- ---- ---- ----
   *
   *   ---> sig direction
   *
   *   A, B, D, E sit at offsets -2, -1, +1, +2 from C along sig
   *
   */

  // displacement scratch; only the sig entry is ever non-zero and it
  // is reset to zero at the end of every iteration
  int dx[4] = {0,0,0,0};

  // compute the force for forward long links
#pragma unroll
  for (int sig=0; sig<4; sig++) {
    dx[sig] = -2;
    const int point_a = linkIndexShift(x,dx,arg.E);

    dx[sig] = -1;
    const int point_b = linkIndexShift(x,dx,arg.E);

    dx[sig] = 1;
    const int point_d = linkIndexShift(x,dx,arg.E);

    dx[sig] = 2;
    const int point_e = linkIndexShift(x,dx,arg.E);

    dx[sig] = 0;

    // even offsets are read with this thread's parity, odd offsets with the other one
    Link link_a = arg.link(sig, point_a, parity);
    Link link_b = arg.link(sig, point_b, other_parity);
    Link link_d = arg.link(sig, point_d, other_parity);
    Link link_e = arg.link(sig, point_e, parity);

    Link oprod_c = arg.oProd(sig, point_c, parity);
    Link oprod_b = arg.oProd(sig, point_b, other_parity);
    Link oprod_a = arg.oProd(sig, point_a, parity);

    Link term = link_d*link_e*oprod_c - link_d*oprod_b*link_b + oprod_a*link_a*link_b;

    // read-modify-write of this site's entry
    Link previous = arg.outA(sig, point_c, parity);
    arg.outA(sig, point_c, parity) = previous + arg.coeff*term;
  } // loop over sig

}
985 
// Tunable launcher for the long-link and complete-force kernels.
// The force type given at construction selects which kernel apply()
// launches and which tuning key / flop / byte estimates are reported;
// any other type is reported through errorQuda.
template <typename real, typename Arg>
class HisqForce : public TunableVectorY {

  Arg &arg;
  const GaugeField &meta;
  const HisqForceType type;

  unsigned int minThreads() const { return arg.threads; }
  bool tuneGridDim() const { return false; }

public:
  // Binds the launcher to its kernel argument and records sig / mu in it.
  HisqForce(Arg &arg, const GaugeField &meta, int sig, int mu, HisqForceType type)
    : TunableVectorY(2), arg(arg), meta(meta), type(type)
  {
    arg.sig = sig;
    arg.mu = mu;
  }

  virtual ~HisqForce() { }

  // Obtain launch parameters from the autotuner and launch the selected kernel on stream.
  void apply(const cudaStream_t &stream) {
    TuneParam launch = tuneLaunch(*this, getTuning(), getVerbosity());
    if (type == FORCE_LONG_LINK) {
      longLinkKernel<real,Arg><<<launch.grid,launch.block,launch.shared_bytes,stream>>>(arg);
    } else if (type == FORCE_COMPLETE) {
      completeForceKernel<real,Arg><<<launch.grid,launch.block,launch.shared_bytes,stream>>>(arg);
    } else {
      errorQuda("Undefined force type %d", type);
    }
  }

  // Tuning key built from the volume string, this class's type name, and thread count / precision / force type.
  TuneKey tuneKey() const {
    std::stringstream key;
    key << "threads=" << arg.threads << ",prec=" << sizeof(real);
    if (type == FORCE_LONG_LINK) {
      key << ",LONG_LINK";
    } else if (type == FORCE_COMPLETE) {
      key << ",COMPLETE";
    } else {
      errorQuda("Undefined force type %d", type);
    }
    return TuneKey(meta.VolString(), typeid(*this).name(), key.str().c_str());
  }

  // The long-link kernel reads and then writes outA, so it is saved here
  // and restored in postTune(); the complete-force kernel only writes it.
  void preTune() {
    if (type == FORCE_LONG_LINK) {
      arg.outA.save();
    } else if (type == FORCE_COMPLETE) {
      // nothing to preserve
    } else {
      errorQuda("Undefined force type %d", type);
    }
  }

  // Restore what preTune() saved.
  void postTune() {
    if (type == FORCE_LONG_LINK) {
      arg.outA.load();
    } else if (type == FORCE_COMPLETE) {
      // nothing to restore
    } else {
      errorQuda("Undefined force type %d", type);
    }
  }

  // Flop estimate: per-site count times arg.threads times 2.
  long long flops() const {
    if (type == FORCE_LONG_LINK) return 2*arg.threads*4968ll;
    if (type == FORCE_COMPLETE) return 2*arg.threads*792ll;
    errorQuda("Undefined force type %d", type);
    return 0;
  }

  // Memory-traffic estimate built from the accessors' Bytes() values.
  long long bytes() const {
    if (type == FORCE_LONG_LINK)
      return 4*2*arg.threads*(2*arg.outA.Bytes() + 4*arg.link.Bytes() + 3*arg.oProd.Bytes());
    if (type == FORCE_COMPLETE)
      return 4*2*arg.threads*(arg.outA.Bytes() + arg.link.Bytes() + arg.oProd.Bytes());
    errorQuda("Undefined force type %d", type);
    return 0;
  }
};
1061 
// Compute the long-link contribution to the fermion force.
//
// newOprod  field that longLinkKernel accumulates into
// oldOprod  input outer-product field
// link      gauge links; only QUDA_RECONSTRUCT_NO is accepted
// coeff     scalar multiplying the long-link term
// flops     optional; when non-null the launcher's flop estimate is added to it
//
// All three fields must be in a native order and reside on the device,
// and must share one precision (double or single); anything else is
// reported through errorQuda.
void hisqLongLinkForce(GaugeField &newOprod, const GaugeField &oldOprod, const GaugeField &link, double coeff, long long* flops)
{
  if (!link.isNative()) errorQuda("Unsupported gauge order %d", link.Order());
  if (!oldOprod.isNative()) errorQuda("Unsupported gauge order %d", oldOprod.Order());
  if (!newOprod.isNative()) errorQuda("Unsupported gauge order %d", newOprod.Order());
  if (checkLocation(newOprod,oldOprod,link) == QUDA_CPU_FIELD_LOCATION) errorQuda("CPU not implemented");

  QudaPrecision precision = checkPrecision(newOprod, link, oldOprod);
  if (precision == QUDA_DOUBLE_PRECISION) {
    if (link.Reconstruct() == QUDA_RECONSTRUCT_NO) {
      typedef LongLinkArg<double,QUDA_RECONSTRUCT_NO> Arg;
      Arg arg(newOprod, link, oldOprod, coeff);
      HisqForce<double,Arg> longLink(arg, link, 0, 0, FORCE_LONG_LINK);
      longLink.apply(0);
      if (flops) (*flops) += longLink.flops();
    } else {
      errorQuda("Reconstruct %d not supported", link.Reconstruct());
    }
  } else if (precision == QUDA_SINGLE_PRECISION) {
    if (link.Reconstruct() == QUDA_RECONSTRUCT_NO) {
      typedef LongLinkArg<float,QUDA_RECONSTRUCT_NO> Arg;
      Arg arg(newOprod, link, oldOprod, coeff);
      HisqForce<float, Arg> longLink(arg, link, 0, 0, FORCE_LONG_LINK);
      longLink.apply(0);
      if (flops) (*flops) += longLink.flops();
    } else {
      errorQuda("Reconstruct %d not supported", link.Reconstruct());
    }
  } else {
    errorQuda("Unsupported precision %d", precision);
  }

  // Synchronize first so that faults raised while the kernel executes are
  // visible to the error check (same order as hisqStaplesForce).
  cudaDeviceSynchronize();
  checkCudaError();
}
1096 
// Multiply the computed force matrix by the gauge field and apply
// makeAntiHerm to the product (see completeForceKernel).
//
// force  destination field written by completeForceKernel
// oprod  input outer-product field
// link   gauge links; only QUDA_RECONSTRUCT_NO is accepted
// flops  optional; when non-null the launcher's flop estimate is added to it
//
// All three fields must be in a native order and reside on the device,
// and must share one precision (double or single); anything else is
// reported through errorQuda.
void hisqCompleteForce(GaugeField &force, const GaugeField &oprod, const GaugeField &link, long long* flops)
{
  if (!link.isNative()) errorQuda("Unsupported gauge order %d", link.Order());
  if (!oprod.isNative()) errorQuda("Unsupported gauge order %d", oprod.Order());
  if (!force.isNative()) errorQuda("Unsupported gauge order %d", force.Order());
  if (checkLocation(force,oprod,link) == QUDA_CPU_FIELD_LOCATION) errorQuda("CPU not implemented");

  QudaPrecision precision = checkPrecision(oprod, link, force);
  if (precision == QUDA_DOUBLE_PRECISION) {
    if (link.Reconstruct() == QUDA_RECONSTRUCT_NO) {
      typedef CompleteForceArg<double,QUDA_RECONSTRUCT_NO> Arg;
      Arg arg(force, link, oprod);
      HisqForce<double,Arg> completeForce(arg, link, 0, 0, FORCE_COMPLETE);
      completeForce.apply(0);
      if (flops) *flops += completeForce.flops();
    } else {
      errorQuda("Reconstruct %d not supported", link.Reconstruct());
    }
  } else if (precision == QUDA_SINGLE_PRECISION) {
    if (link.Reconstruct() == QUDA_RECONSTRUCT_NO) {
      typedef CompleteForceArg<float,QUDA_RECONSTRUCT_NO> Arg;
      Arg arg(force, link, oprod);
      HisqForce<float, Arg> completeForce(arg, link, 0, 0, FORCE_COMPLETE);
      completeForce.apply(0);
      if (flops) *flops += completeForce.flops();
    } else {
      errorQuda("Reconstruct %d not supported", link.Reconstruct());
    }
  } else {
    errorQuda("Unsupported precision %d", precision);
  }

  // Synchronize first so that faults raised while the kernel executes are
  // visible to the error check (same order as hisqStaplesForce).
  cudaDeviceSynchronize();
  checkCudaError();
}
1131  } // namespace fermion_force
1132 } // namespace quda
1133 
1134 #endif // GPU_HISQ_FORCE
dim3 dim3 blockDim
double mu
Definition: test_util.cpp:1643
enum QudaPrecision_s QudaPrecision
static __device__ __host__ int linkIndexShift(const I x[], const J dx[], const K X[4])
static __device__ __host__ int linkIndex(const int x[], const I X[4])
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:20
#define checkPrecision(...)
#define errorQuda(...)
Definition: util_quda.h:90
#define TDOWN
Definition: misc.h:64
cudaStream_t * stream
#define XUP
QudaGaugeParam gauge_param
int E[4]
Definition: test_util.cpp:36
#define YUP
#define Qmu
void hisqLongLinkForce(GaugeField &newOprod, const GaugeField &oprod, const GaugeField &link, double coeff, long long *flops=nullptr)
Compute the long-link contribution to the fermion force.
void hisqStaplesForce(GaugeField &newOprod, const GaugeField &oprod, const GaugeField &link, const double path_coeff[6], long long *flops=nullptr)
Compute the fat-link contribution to the fermion force.
int int int w
#define ZUP
int commDim(int)
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
Definition: tune.cpp:603
#define checkLocation(...)
static unsigned int unsigned int shift
Main header file for host and device accessors to GaugeFields.
QudaReconstructType reconstruct
Definition: quda.h:43
#define TUP
void hisqCompleteForce(GaugeField &momentum, const GaugeField &oprod, const GaugeField &link, long long *flops=nullptr)
Multiply the computed the force matrix by the gauge field and perform traceless anti-hermitian projec...
#define Pmu
#define P5
unsigned long long flops
Definition: blas_quda.cu:42
#define Qnumu
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
Definition: complex_quda.h:880
#define XDOWN
Definition: misc.h:67
#define ZDOWN
Definition: misc.h:65
__device__ __host__ void makeAntiHerm(Matrix< Complex, N > &m)
Definition: quda_matrix.h:636
#define checkCudaError()
Definition: util_quda.h:129
#define Pnumu
__host__ __device__ ValueType conj(ValueType x)
Definition: complex_quda.h:115
QudaTune getTuning()
Query whether autotuning is enabled or not. Default is enabled but can be overridden by setting QUDA_...
Definition: util_quda.cpp:51
static __inline__ size_t size_t d
QudaParity parity
Definition: covdev_test.cpp:53
#define YDOWN
Definition: misc.h:66
unsigned long long bytes
Definition: blas_quda.cu:43
#define P3
int comm_dim_partitioned(int dim)
static __device__ __host__ void getCoords(int x[], int cb_index, const I X[], int parity)