QUDA  1.0.0
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
hisq_paths_force_quda.cu
Go to the documentation of this file.
1 #include <utility>
2 #include <typeinfo>
3 #include <quda_internal.h>
4 #include <gauge_field.h>
5 #include <ks_improved_force.h>
6 #include <quda_matrix.h>
7 #include <tune_quda.h>
8 #include <index_helper.cuh>
9 #include <gauge_field_order.h>
10 
11 #ifdef GPU_HISQ_FORCE
12 
13 namespace quda {
14 
15  namespace fermion_force {
16 
// Direction labels used throughout the force kernels. Values 0-3 are the
// forward (UP) labels and 4-7 the backward (DOWN) ones, laid out so that
// label k and label 7-k refer to the same axis with opposite orientation.
enum { XUP, YUP, ZUP, TUP, TDOWN, ZDOWN, YDOWN, XDOWN };
27 
// Identifies which force term (and therefore which kernel and which
// FatLinkArg constructor) a launch corresponds to.
enum HisqForceType {
  FORCE_ALL_LINK           = 0,
  FORCE_MIDDLE_LINK        = 1,
  FORCE_LEPAGE_MIDDLE_LINK = 2,
  FORCE_SIDE_LINK          = 3,
  FORCE_SIDE_LINK_SHORT    = 4,
  FORCE_LONG_LINK          = 5,
  FORCE_COMPLETE           = 6,
  FORCE_ONE_LINK           = 7,
  FORCE_INVALID            = 8
};
39 
// Opposite orientation of a direction label (k <-> 7-k).
__device__ __host__ constexpr inline int opp_dir(int dir)
{
  return 7 - dir;
}

// Non-zero for the forward labels (0..3).
__device__ __host__ constexpr inline int goes_forward(int dir)
{
  return dir < 4;
}

// Non-zero for the backward labels (4 and above).
__device__ __host__ constexpr inline int goes_backward(int dir)
{
  return dir >= 4;
}

// +1 when pos_dir + odd_lattice is even, -1 when it is odd.
__device__ __host__ constexpr inline int CoeffSign(int pos_dir, int odd_lattice)
{
  return ((pos_dir + odd_lattice) & 1) ? -1 : 1;
}

// +1 for parity 0, -1 for any non-zero parity.
__device__ __host__ constexpr inline int Sign(int parity)
{
  return parity != 0 ? -1 : 1;
}

// Axis index (0..3) of a direction label, discarding its orientation.
__device__ __host__ constexpr inline int posDir(int dir)
{
  return dir < 4 ? dir : 7 - dir;
}
46 
// Move coordinate `dir` of x by `shift` sites, wrapping periodically on the
// extent arg.E[dir]. The extent is added before the modulo so that a shift
// of down to -arg.E[dir] still yields a non-negative result.
template <int dir, typename Arg>
inline __device__ __host__ void updateCoords(int x[], int shift, const Arg &arg) {
  const int extent = arg.E[dir];
  const int moved = x[dir] + shift + extent;
  x[dir] = moved % extent;
}
51 
// Run-time front end for the compile-time updateCoords<dir>: forwards to the
// instantiation matching `dir`. A `dir` outside 0..3 leaves x untouched.
template <typename Arg>
inline __device__ __host__ void updateCoords(int x[], int dir, int shift, const Arg &arg) {
  if (dir == 0) {
    updateCoords<0>(x, shift, arg);
  } else if (dir == 1) {
    updateCoords<1>(x, shift, arg);
  } else if (dir == 2) {
    updateCoords<2>(x, shift, arg);
  } else if (dir == 3) {
    updateCoords<3>(x, shift, arg);
  }
}
61 
// Struct for holding the fattening path coefficients, converted from double
// to the working precision `real`.
//
// The input array is read as:
//   [0] -> one, [1] -> naik, [2] -> three, [3] -> five, [4] -> seven, [5] -> lepage
template <typename real>
struct PathCoefficients {
  const real one;
  const real three;
  const real five;
  const real seven;
  const real naik;
  const real lepage;

  // path_coeff_array must point to at least six values laid out as above.
  // The initializers are written in member-declaration order, which is the
  // order C++ executes them in regardless of how the list is written; keeping
  // the two in sync avoids -Wreorder warnings.
  PathCoefficients(const double *path_coeff_array)
    : one(path_coeff_array[0]),
      three(path_coeff_array[2]),
      five(path_coeff_array[3]),
      seven(path_coeff_array[4]),
      naik(path_coeff_array[1]),
      lepage(path_coeff_array[5]) { }
};
76 
/**
   Geometry and link accessor shared by all HISQ force kernel arguments.

   The link field is expected to be an extended field: E holds its full
   dimensions, border its halo width (link.R()), and X = E - 2*border the
   interior dimensions.  Kernels run over the working set D, which equals X
   in unpartitioned dimensions and X + 2*overlap in partitioned ones;
   base_idx is the offset of the working set inside the extended field.
*/
template <typename real, QudaReconstructType reconstruct=QUDA_RECONSTRUCT_NO>
struct BaseForceArg {
  typedef typename gauge_mapper<real,reconstruct>::type G;
  const G link;
  int threads;     // number of checkerboard sites in the working set (prod(D)/2)
  int X[4];        // regular grid dims
  int D[4];        // working set grid dims
  int E[4];        // extended grid dims

  int commDim[4];  // comm_dim_partitioned(d) for each dimension
  int border[4];   // halo width of the extended field
  int base_idx[4]; // the offset into the extended field
  int oddness_change; // parity flip caused by an odd total base_idx offset
  int mu;          // direction label, set by the launcher before each launch
  int sig;         // direction label, set by the launcher before each launch

  /**
     @param[in] link Gauge field (extended)
     @param[in] overlap Number of extra sites per side added to the working
     set in each partitioned dimension
  */
  BaseForceArg(const GaugeField &link, int overlap) : link(link), threads(1), mu(0), sig(0)
  {
    for (int d=0; d<4; d++) {
      const int partitioned = comm_dim_partitioned(d);
      commDim[d] = partitioned;
      E[d] = link.X()[d];
      border[d] = link.R()[d];
      X[d] = E[d] - 2*border[d];
      D[d] = partitioned ? X[d]+overlap*2 : X[d];
      base_idx[d] = partitioned ? border[d]-overlap : 0;
      threads *= D[d];
    }
    threads /= 2; // one thread per site of a single parity
    oddness_change = (base_idx[0] + base_idx[1] + base_idx[2] + base_idx[3])&1;
  }
};
112 
/**
   Kernel argument for the one-link, middle-link, side-link and all-link
   force kernels.  One constructor exists per HisqForceType; each checks the
   `type` it is given and aborts through errorQuda on a mismatch.

   Every accessor member must be bound to some field, so accessors that the
   corresponding kernel does not touch are bound to one of the other fields
   as placeholders.  The p_mu / q_mu / q_prev flags record which optional
   fields are genuinely in use; the launcher uses them to select the kernel
   instantiation.

   All mem-initializer lists are written in member-declaration order, which
   is the order in which the members are actually initialized.
*/
template <typename real, QudaReconstructType reconstruct=QUDA_RECONSTRUCT_NO>
struct FatLinkArg : public BaseForceArg<real,reconstruct> {

  typedef typename gauge_mapper<real,QUDA_RECONSTRUCT_NO>::type F;
  F outA;  // primary accumulated output
  F outB;  // secondary accumulated output (side-link / all-link kernels)
  F pMu;   // written by the middle-link kernel when p_mu is set
  F p3;    // written by the middle-link kernel, read by the side-link kernels
  F qMu;   // written by the middle-link kernel when q_mu is set

  const F oProd; // input outer product
  const F qProd; // input read by the side-link kernel
  const F qPrev; // input read by the middle-link (q_prev set) and all-link kernels
  const real coeff;        // coefficient applied to the outA contribution
  const real accumu_coeff; // coefficient applied to the outB contribution

  const bool p_mu;
  const bool q_mu;
  const bool q_prev;

  // FORCE_ONE_LINK: force += coeff * oProd on the interior (overlap 0).
  FatLinkArg(GaugeField &force, const GaugeField &oProd, const GaugeField &link, real coeff, HisqForceType type)
    : BaseForceArg<real,reconstruct>(link, 0), outA(force), outB(force), pMu(oProd), p3(oProd), qMu(oProd),
      oProd(oProd), qProd(oProd), qPrev(oProd), coeff(coeff), accumu_coeff(0),
      p_mu(false), q_mu(false), q_prev(false)
  { if (type != FORCE_ONE_LINK) errorQuda("This constructor is for FORCE_ONE_LINK"); }

  // FORCE_MIDDLE_LINK with a previous Q field (q_prev = true).
  FatLinkArg(GaugeField &newOprod, GaugeField &pMu, GaugeField &P3, GaugeField &qMu,
             const GaugeField &oProd, const GaugeField &qPrev, const GaugeField &link,
             real coeff, int overlap, HisqForceType type)
    : BaseForceArg<real,reconstruct>(link, overlap), outA(newOprod), outB(newOprod), pMu(pMu), p3(P3), qMu(qMu),
      oProd(oProd), qProd(oProd), qPrev(qPrev), coeff(coeff), accumu_coeff(0),
      p_mu(true), q_mu(true), q_prev(true)
  { if (type != FORCE_MIDDLE_LINK) errorQuda("This constructor is for FORCE_MIDDLE_LINK"); }

  // FORCE_MIDDLE_LINK without a previous Q field (q_prev = false).
  FatLinkArg(GaugeField &newOprod, GaugeField &pMu, GaugeField &P3, GaugeField &qMu,
             const GaugeField &oProd, const GaugeField &link,
             real coeff, int overlap, HisqForceType type)
    : BaseForceArg<real,reconstruct>(link, overlap), outA(newOprod), outB(newOprod), pMu(pMu), p3(P3), qMu(qMu),
      oProd(oProd), qProd(oProd), qPrev(qMu), coeff(coeff), accumu_coeff(0),
      p_mu(true), q_mu(true), q_prev(false)
  { if (type != FORCE_MIDDLE_LINK) errorQuda("This constructor is for FORCE_MIDDLE_LINK"); }

  // FORCE_LEPAGE_MIDDLE_LINK: middle link that writes neither pMu nor qMu.
  FatLinkArg(GaugeField &newOprod, GaugeField &P3, const GaugeField &oProd,
             const GaugeField &qPrev, const GaugeField &link,
             real coeff, int overlap, HisqForceType type)
    : BaseForceArg<real,reconstruct>(link, overlap), outA(newOprod), outB(newOprod), pMu(P3), p3(P3), qMu(qPrev),
      oProd(oProd), qProd(oProd), qPrev(qPrev), coeff(coeff), accumu_coeff(0),
      p_mu(false), q_mu(false), q_prev(true)
  { if (type != FORCE_LEPAGE_MIDDLE_LINK) errorQuda("This constructor is for FORCE_LEPAGE_MIDDLE_LINK"); }

  // FORCE_SIDE_LINK: accumulates into newOprod (outA) and shortP (outB).
  FatLinkArg(GaugeField &newOprod, GaugeField &shortP, const GaugeField &P3,
             const GaugeField &qProd, const GaugeField &link, real coeff, real accumu_coeff, int overlap, HisqForceType type)
    : BaseForceArg<real,reconstruct>(link, overlap), outA(newOprod), outB(shortP), pMu(P3), p3(P3), qMu(qProd),
      oProd(qProd), qProd(qProd), qPrev(qProd), coeff(coeff), accumu_coeff(accumu_coeff),
      p_mu(false), q_mu(false), q_prev(false)
  { if (type != FORCE_SIDE_LINK) errorQuda("This constructor is for FORCE_SIDE_LINK"); }

  // FORCE_SIDE_LINK_SHORT: accumulates coeff * P3 into newOprod only.
  FatLinkArg(GaugeField &newOprod, GaugeField &P3, const GaugeField &link,
             real coeff, int overlap, HisqForceType type)
    : BaseForceArg<real,reconstruct>(link, overlap), outA(newOprod), outB(newOprod),
      pMu(P3), p3(P3), qMu(P3), oProd(P3), qProd(P3), qPrev(P3), coeff(coeff), accumu_coeff(0.0),
      p_mu(false), q_mu(false), q_prev(false)
  { if (type != FORCE_SIDE_LINK_SHORT) errorQuda("This constructor is for FORCE_SIDE_LINK_SHORT"); }

  // FORCE_ALL_LINK.  `dummy` is unused; it only distinguishes this overload
  // from the FORCE_SIDE_LINK constructor, whose other parameters have the
  // same types.  pMu, p3, qMu and qProd are placeholders here.
  FatLinkArg(GaugeField &newOprod, GaugeField &shortP, const GaugeField &oProd, const GaugeField &qPrev,
             const GaugeField &link, real coeff, real accumu_coeff, int overlap, HisqForceType type, bool dummy)
    : BaseForceArg<real,reconstruct>(link, overlap), outA(newOprod), outB(shortP),
      pMu(shortP), p3(shortP), qMu(qPrev), // dummy
      oProd(oProd),
      qProd(qPrev), // dummy
      qPrev(qPrev),
      coeff(coeff), accumu_coeff(accumu_coeff), p_mu(false), q_mu(false), q_prev(false)
  { if (type != FORCE_ALL_LINK) errorQuda("This constructor is for FORCE_ALL_LINK"); }

};
182 
// One-link term: outA(sig, site) += coeff * oProd(sig, site) for the four
// direction indices sig = 0..3.
//
// Thread layout: x indexes the checkerboard site (bounded by arg.threads),
// y is used directly as the parity and z as the direction index.  Sites are
// enumerated over the interior dims arg.X and shifted by arg.border into the
// extended field.
template <typename real, typename Arg>
__global__ void oneLinkTermKernel(Arg arg)
{
  typedef Matrix<complex<real>,3> Link;

  const int x_cb = blockIdx.x * blockDim.x + threadIdx.x;
  if (x_cb >= arg.threads) return;
  const int parity = blockIdx.y * blockDim.y + threadIdx.y;
  const int sig = blockIdx.z * blockDim.z + threadIdx.z;
  if (sig >= 4) return;

  // interior coordinates -> index into the extended field
  int site[4];
  getCoords(site, x_cb, arg.X, parity);
#pragma unroll
  for (int d=0; d<4; d++) site[d] += arg.border[d];
  const int e_cb = linkIndex(site, arg.E);

  Link w = arg.oProd(sig, e_cb, parity);
  Link accum = arg.outA(sig, e_cb, parity);
  accum += arg.coeff * w;
  arg.outA(sig, e_cb, parity) = accum;
}
204 
205 
206  /********************************allLinkKernel*********************************************
207  *
208  * In this function we need
209  * READ
210  * 3 LINKS: ad_link, ab_link, bc_link
211  * 5 COLOR MATRIX: Qprev_at_D, oprod_at_C, newOprod_at_A(sig), newOprod_at_D/newOprod_at_A(mu), shortP_at_D
212  * WRITE:
213  * 3 COLOR MATRIX: newOprod_at_A(sig), newOprod_at_D/newOprod_at_A(mu), shortP_at_D,
214  *
215  * If sig is negative, then we don't need to read/write the color matrix newOprod_at_A(sig)
216  *
217  * Therefore the data traffic, in two-number pair (num_of_link, num_of_color_matrix)
218  *
219  * if (sig is positive): (3, 8)
220  * else : (3, 6)
221  *
222  * This function is called 384 times, half positive sig, half negative sig
223  *
224  * Flop count, in two-number pair (matrix_multi, matrix_add)
225  * if(sig is positive) (6,3)
226  * else (4,2)
227  *
228  ************************************************************************************************/
// All-link kernel.  For the plaquette
//
//            sig
//         A________B
//      mu |        |
//       D |        | C
//
// with A the thread's own site, this reads the links AB, AD and BC together
// with qPrev at D and oProd at C, and accumulates into
//   - outA(arg.sig) at A         (only when sig_positive),
//   - outA(mu) at D or A         (D when mu_positive, A otherwise),
//   - outB at D.
//
// sig_positive / mu_positive are compile-time copies of
// goes_forward(arg.sig) / goes_forward(arg.mu) as chosen by the launcher.
// Thread layout: x = checkerboard site of the working set arg.D, y = parity.
template<typename real, int sig_positive, int mu_positive, typename Arg>
__global__ void allLinkKernel(Arg arg)
{
  typedef Matrix<complex<real>,3> Link;

  const int x_cb = blockIdx.x * blockDim.x + threadIdx.x;
  if (x_cb >= arg.threads) return;
  int parity = blockIdx.y * blockDim.y + threadIdx.y;

  // site A: working-set coordinates shifted into the extended field
  int a[4];
  getCoords(a, x_cb, arg.D, parity);
  for (int d=0; d<4; d++) a[d] += arg.base_idx[d];
  const int e_cb = linkIndex(a, arg.E);
  parity ^= arg.oddness_change;

  const real mycoeff = CoeffSign(sig_positive,parity)*arg.coeff;

  const int sig_axis = posDir(arg.sig);
  const int sig_step = sig_positive ? 1 : -1;
  const int mu = mu_positive ? arg.mu : opp_dir(arg.mu);
  const int mu_step = mu_positive ? -1 : 1;

  // site B = A displaced along sig
  int b[4] = {a[0], a[1], a[2], a[3]};
  updateCoords(b, sig_axis, sig_step, arg);
  const int point_b = linkIndex(b, arg.E);

  // site D = A displaced along mu, then site C = D displaced along sig
  int w[4] = {a[0], a[1], a[2], a[3]};
  updateCoords(w, mu, mu_step, arg);
  const int point_d = linkIndex(w, arg.E);
  updateCoords(w, sig_axis, sig_step, arg);
  const int point_c = linkIndex(w, arg.E);

  // a link is stored at whichever of its two end points it leaves in the
  // forward direction
  const int ab_site = sig_positive ? e_cb : point_b;
  const int ad_site = mu_positive ? point_d : e_cb;
  const int ad_parity = mu_positive ? 1-parity : parity;
  const int bc_site = mu_positive ? point_c : point_b;
  const int bc_parity = mu_positive ? parity : 1-parity;

  Link Uab = arg.link(sig_axis, ab_site, sig_positive^(1-parity));
  Link Uad = arg.link(mu, ad_site, ad_parity);
  Link Ubc = arg.link(mu, bc_site, bc_parity);
  Link Ox = arg.qPrev(0, point_d, 1-parity);
  Link Oy = arg.oProd(0, point_c, parity);
  Link Oz = mu_positive ? conj(Ubc)*Oy : Ubc*Oy;

  if (sig_positive) {
    Link force = arg.outA(arg.sig, e_cb, parity);
    force += Sign(parity)*mycoeff*Oz*Ox* (mu_positive ? Uad : conj(Uad));
    arg.outA(arg.sig, e_cb, parity) = force;
  }
  Oy = sig_positive ? Uab*Oz : conj(Uab)*Oz;

  Link force = arg.outA(mu, ad_site, ad_parity);
  force += Sign(ad_parity)*mycoeff* (mu_positive ? Oy*Ox : conj(Ox)*conj(Oy));
  arg.outA(mu, ad_site, ad_parity) = force;

  Link shortP = arg.outB(0, point_d, 1-parity);
  shortP += arg.accumu_coeff* (mu_positive ? Uad : conj(Uad)) *Oy;
  arg.outB(0, point_d, 1-parity) = shortP;
}
295 
296 
297  /**************************middleLinkKernel*****************************
298  *
299  *
300  * Generally we need
301  * READ
302  * 3 LINKS: ab_link, bc_link, ad_link
303  * 3 COLOR MATRIX: newOprod_at_A, oprod_at_C, Qprod_at_D
304  * WRITE
305  * 4 COLOR MATRIX: newOprod_at_A, P3_at_A, Pmu_at_B, Qmu_at_A
306  *
307  * Three call variations:
308  * 1. when Qprev == NULL: Qprod_at_D does not exist and is not read in
309  * 2. full read/write
310  * 3. when Pmu/Qmu == NULL, Pmu_at_B and Qmu_at_A are not written out
311  *
312  * In all three above case, if the direction sig is negative, newOprod_at_A is
313  * not read in or written out.
314  *
315  * Therefore the data traffic, in two-number pair (num_of_link, num_of_color_matrix)
316  * Call 1: (called 48 times, half positive sig, half negative sig)
317  * if (sig is positive): (3, 6)
318  * else : (3, 4)
319  * Call 2: (called 192 times, half positive sig, half negative sig)
320  * if (sig is positive): (3, 7)
321  * else : (3, 5)
322  * Call 3: (called 48 times, half positive sig, half negative sig)
323  * if (sig is positive): (3, 5)
324  * else : (3, 2) no need to loadQprod_at_D in this case
325  *
326  * note: oprod_at_C could actually be read in from D when it is the fresh outer product
327  * and we call it oprod_at_C to simplify naming. This does not affect our data traffic analysis
328  *
329  * Flop count, in two-number pair (matrix_multi, matrix_add)
330  * call 1: if (sig is positive) (3, 1)
331  * else (2, 0)
332  * call 2: if (sig is positive) (4, 1)
333  * else (3, 0)
334  * call 3: if (sig is positive) (4, 1)
335  * (Lepage) else (2, 0)
336  *
337  ****************************************************************************/
// Middle-link kernel.  For the plaquette
//
//         A________B
//      mu |        |
//        D|        |C
//
// with A the thread's own site, this
//   - writes pMu at B               (only when the pMu template flag is set),
//   - writes p3 at A                (always),
//   - writes qMu at A               (only when the qMu template flag is set),
//   - accumulates coeff * staple into outA(arg.sig) at A (only when
//     sig_positive).
// The qPrev template flag selects whether arg.qPrev at D enters the product
// and which element of arg.oProd is used as the starting matrix.
//
// sig_positive / mu_positive are compile-time copies of
// goes_forward(arg.sig) / goes_forward(arg.mu) as chosen by the launcher.
// Thread layout: x = checkerboard site of the working set arg.D, y = parity.
template <typename real, int sig_positive, int mu_positive, bool pMu, bool qMu, bool qPrev, typename Arg>
__global__ void middleLinkKernel(Arg arg)
{
  typedef Matrix<complex<real>,3> Link;

  const int x_cb = blockIdx.x * blockDim.x + threadIdx.x;
  if (x_cb >= arg.threads) return;
  int parity = blockIdx.y * blockDim.y + threadIdx.y;

  // site A: working-set coordinates shifted into the extended field
  int a[4];
  getCoords(a, x_cb, arg.D, parity);
  for (int d=0; d<4; d++) a[d] += arg.base_idx[d];
  const int e_cb = linkIndex(a, arg.E);
  parity ^= arg.oddness_change;

  const int mymu = posDir(arg.mu);
  const int mysig = posDir(arg.sig);
  const int mu_step = mu_positive ? -1 : 1;
  const int sig_step = sig_positive ? 1 : -1;

  // site D = A displaced along mu, then site C = D displaced along sig
  int w[4] = {a[0], a[1], a[2], a[3]};
  updateCoords(w, mymu, mu_step, arg);
  const int point_d = linkIndex(w, arg.E);
  updateCoords(w, mysig, sig_step, arg);
  const int point_c = linkIndex(w, arg.E);

  // site B = A displaced along sig
  int b[4] = {a[0], a[1], a[2], a[3]};
  updateCoords(b, mysig, sig_step, arg);
  const int point_b = linkIndex(b, arg.E);

  const int ad_link_nbr_idx = mu_positive ? point_d : e_cb;
  const int bc_link_nbr_idx = mu_positive ? point_c : point_b;
  const int ab_link_nbr_idx = sig_positive ? e_cb : point_b;

  // link variable connecting a and b
  Link Uab = arg.link(mysig, ab_link_nbr_idx, sig_positive^(1-parity));

  // link variable connecting b and c
  Link Ubc = arg.link(mymu, bc_link_nbr_idx, mu_positive^(1-parity));

  // starting matrix
  Link Oy;
  if (qPrev) {
    Oy = arg.oProd(0, point_c, parity);
  } else {
    Oy = arg.oProd(posDir(arg.sig), sig_positive ? point_d : point_c, sig_positive^parity);
    if (!sig_positive) Oy = conj(Oy);
  }

  Link Ow = !mu_positive ? Ubc*Oy : conj(Ubc)*Oy;

  if (pMu) arg.pMu(0, point_b, 1-parity) = Ow;

  arg.p3(0, e_cb, parity) = sig_positive ? Uab*Ow : conj(Uab)*Ow;

  // link variable connecting a and d, oriented as used below
  Link Uad = arg.link(mymu, ad_link_nbr_idx, mu_positive^parity);
  if (!mu_positive) Uad = conj(Uad);

  // `staple` is only assigned, and only consumed, when sig_positive
  Link staple;
  if (qPrev) {
    Link Ox;
    if ( qMu || sig_positive ) {
      Link Qd = arg.qPrev(0, point_d, 1-parity);
      Ox = Qd*Uad;
    }
    if ( qMu ) arg.qMu(0, e_cb, parity) = Ox;
    if (sig_positive) staple = Ow*Ox;
  } else {
    if (sig_positive) staple = Ow*Uad;
    if ( qMu ) arg.qMu(0, e_cb, parity) = Uad;
  }

  if (sig_positive) {
    Link oprod = arg.outA(arg.sig, e_cb, parity);
    oprod += arg.coeff*staple;
    arg.outA(arg.sig, e_cb, parity) = oprod;
  }

}
423 
424  /***********************************sideLinkKernel***************************
425  *
426  * In general we need
427  * READ
428  * 1 LINK: ad_link
429  * 4 COLOR MATRIX: shortP_at_D, newOprod, P3_at_A, Qprod_at_D,
430  * WRITE
431  * 2 COLOR MATRIX: shortP_at_D, newOprod,
432  *
433  * Two call variations:
434  * 1. full read/write
435  * 2. when shortP == NULL && Qprod == NULL:
436  * no need to read ad_link/shortP_at_D or write shortP_at_D
437  * Qprod_at_D does not exist and is not read in
438  *
439  *
440  * Therefore the data traffic, in two-number pair (num_of_links, num_of_color_matrix)
441  * Call 1: (called 192 times)
442  * (1, 6)
443  *
444  * Call 2: (called 48 times)
445  * (0, 3)
446  *
447  * note: newOprod can be at point D or A, depending on if mu is positive or negative
448  *
449  * Flop count, in two-number pair (matrix_multi, matrix_add)
450  * call 1: (2, 2)
451  * call 2: (0, 1)
452  *
453  *********************************************************************************/
// Side-link kernel.  With A the thread's own site and D its neighbour along
// mu,
//
//            sig
//         A________B
//         |        |   mu
//       D |        | C
//
// this reads p3 at A, the link AD and qProd at D, and accumulates
//   - accumu_coeff * (Uad or its conjugate) * p3 into outB at D,
//   - a signed coeff * (p3 * qProd, or the conjugated product) into
//     outA(mu axis) at D (mu_positive) or at A (otherwise).
//
// mu_positive is the compile-time copy of goes_forward(arg.mu) chosen by the
// launcher.  Thread layout: x = checkerboard site of the working set arg.D,
// y = parity.
template <typename real, int mu_positive, typename Arg>
__global__ void sideLinkKernel(Arg arg)
{
  typedef Matrix<complex<real>, 3> Link;

  const int x_cb = blockIdx.x * blockDim.x + threadIdx.x;
  if (x_cb >= arg.threads) return;
  int parity = blockIdx.y * blockDim.y + threadIdx.y;

  // site A: working-set coordinates shifted into the extended field
  int a[4];
  getCoords(a, x_cb, arg.D, parity);
  for (int d=0; d<4; d++) a[d] += arg.base_idx[d];
  const int e_cb = linkIndex(a, arg.E);
  parity ^= arg.oddness_change;

  // site D = A displaced along mu
  const int mymu = posDir(arg.mu);
  int nbr[4] = {a[0], a[1], a[2], a[3]};
  updateCoords(nbr, mymu, (mu_positive ? -1 : 1), arg);
  const int point_d = linkIndex(nbr, arg.E);

  Link Oy = arg.p3(0, e_cb, parity);

  // contribution to outB at D
  {
    const int ad_link_nbr_idx = mu_positive ? point_d : e_cb;

    Link Uad = arg.link(mymu, ad_link_nbr_idx, mu_positive^parity);
    Link Ow = mu_positive ? Uad*Oy : conj(Uad)*Oy;

    Link shortP = arg.outB(0, point_d, 1-parity);
    shortP += arg.accumu_coeff * Ow;
    arg.outB(0, point_d, 1-parity) = shortP;
  }

  // contribution to outA along the mu axis
  {
    Link Ox = arg.qProd(0, point_d, 1-parity);
    Link Ow = mu_positive ? Oy*Ox : conj(Ox)*conj(Oy);

    const real mycoeff = CoeffSign(goes_forward(arg.sig), parity)*CoeffSign(goes_forward(arg.mu),parity)*arg.coeff;

    const int out_dir = mu_positive ? arg.mu : opp_dir(arg.mu);
    const int out_site = mu_positive ? point_d : e_cb;
    const int out_parity = mu_positive ? 1-parity : parity;

    Link oprod = arg.outA(out_dir, out_site, out_parity);
    oprod += mycoeff * Ow;
    arg.outA(out_dir, out_site, out_parity) = oprod;
  }
}
508 
509  // Flop count, in two-number pair (matrix_mult, matrix_add)
510  // (0,1)
// Short side-link kernel: adds a signed coeff * p3(A) (conjugated when mu is
// a backward direction) to outA along the mu axis, at the neighbour D of A
// when mu_positive and at A itself otherwise.  No link is read.
//
//            sig
//         A________B
//         |        |   mu
//       D |        | C
//
// mu_positive is the compile-time copy of goes_forward(arg.mu) chosen by the
// launcher.  Thread layout: x = checkerboard site of the working set arg.D,
// y = parity.
template<typename real, int mu_positive, typename Arg>
__global__ void sideLinkShortKernel(Arg arg)
{
  typedef Matrix<complex<real>,3> Link;

  const int x_cb = blockIdx.x * blockDim.x + threadIdx.x;
  if (x_cb >= arg.threads) return;
  int parity = blockIdx.y * blockDim.y + threadIdx.y;

  // site A: working-set coordinates shifted into the extended field
  int a[4];
  getCoords(a, x_cb, arg.D, parity);
  for (int d=0; d<4; d++) a[d] += arg.base_idx[d];
  const int e_cb = linkIndex(a, arg.E);
  parity ^= arg.oddness_change;

  // destination of the update: D (opposite parity) or A (same parity)
  const int mu_axis = posDir(arg.mu);
  int out_site = e_cb;
  int out_parity = parity;
  if (mu_positive) {
    int nbr[4] = {a[0], a[1], a[2], a[3]};
    updateCoords(nbr, mu_axis, -1, arg);
    out_site = linkIndex(nbr, arg.E);
    out_parity = 1-parity;
  }

  const real mycoeff = CoeffSign(goes_forward(arg.sig),parity)*CoeffSign(goes_forward(arg.mu),parity)*arg.coeff;

  Link Oy = arg.p3(0, e_cb, parity);
  Link oprod = arg.outA(mu_axis, out_site, out_parity);
  oprod += mu_positive ? mycoeff * Oy : mycoeff * conj(Oy);
  arg.outA(mu_axis, out_site, out_parity) = oprod;
}
548 
/**
   Autotuned launcher for the fat-link force kernels.

   The constructor stores sig and mu into the (caller-owned) argument struct;
   apply() then launches the kernel matching `type`, instantiated on the
   orientation of sig / mu and, for the middle-link kernels, on the
   p_mu / q_mu / q_prev flags of the argument.

   The launch grid is (sites, 2 parities, 4 directions for FORCE_ONE_LINK or
   1 otherwise).
*/
template <typename real, typename Arg>
class FatLinkForce : public TunableVectorYZ {

private:
  Arg &arg;
  const GaugeField &meta;   // only used for the volume string of the tune key
  const HisqForceType type;

  unsigned int minThreads() const { return arg.threads; }
  bool tuneGridDim() const { return false; }

public:
  FatLinkForce(Arg &arg, const GaugeField &meta, int sig, int mu, HisqForceType type)
    : TunableVectorYZ(2,type == FORCE_ONE_LINK ? 4 : 1), arg(arg), meta(meta), type(type) {
    arg.sig = sig;
    arg.mu = mu;
  }
  virtual ~FatLinkForce() { }

  // Tune key: thread count plus every launch property that selects a
  // different kernel instantiation, followed by the force type.
  TuneKey tuneKey() const {
    std::stringstream aux;
    if (type == FORCE_ONE_LINK) aux << "threads=" << arg.threads;
    else if (type == FORCE_MIDDLE_LINK || type == FORCE_LEPAGE_MIDDLE_LINK)
      aux << "threads=" << arg.threads << ",sig=" << arg.sig << ",mu=" << arg.mu <<
        ",pMu=" << arg.p_mu << ",q_muu=" << arg.q_mu << ",q_prev=" << arg.q_prev;
    else
      aux << "threads=" << arg.threads << ",mu=" << arg.mu; // no sig dependence needed for side link

    switch (type) {
    case FORCE_ONE_LINK:           aux << ",ONE_LINK"; break;
    case FORCE_ALL_LINK:           aux << ",ALL_LINK"; break;
    case FORCE_MIDDLE_LINK:        aux << ",MIDDLE_LINK"; break;
    case FORCE_LEPAGE_MIDDLE_LINK: aux << ",LEPAGE_MIDDLE_LINK"; break;
    case FORCE_SIDE_LINK:          aux << ",SIDE_LINK"; break;
    case FORCE_SIDE_LINK_SHORT:    aux << ",SIDE_LINK_SHORT"; break;
    default: errorQuda("Undefined force type %d", type);
    }
    return TuneKey(meta.VolString(), typeid(*this).name(), aux.str().c_str());
  }

  // Launch the kernel for `type` on `stream` with the tuned configuration.
  void apply(const cudaStream_t &stream) {
    TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
    switch (type) {
    case FORCE_ONE_LINK:
      oneLinkTermKernel<real,Arg> <<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
      break;
    case FORCE_ALL_LINK:
      if (goes_forward(arg.sig) && goes_forward(arg.mu))
        allLinkKernel<real,1,1,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
      else if (goes_forward(arg.sig) && goes_backward(arg.mu))
        allLinkKernel<real,1,0,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
      else if (goes_backward(arg.sig) && goes_forward(arg.mu))
        allLinkKernel<real,0,1,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
      else
        allLinkKernel<real,0,0,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
      break;
    case FORCE_MIDDLE_LINK:
      if (!arg.p_mu || !arg.q_mu) errorQuda("Expect p_mu=%d and q_mu=%d to both be true", arg.p_mu, arg.q_mu);
      if (arg.q_prev) {
        if (goes_forward(arg.sig) && goes_forward(arg.mu))
          middleLinkKernel<real,1,1,true,true,true,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
        else if (goes_forward(arg.sig) && goes_backward(arg.mu))
          middleLinkKernel<real,1,0,true,true,true,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
        else if (goes_backward(arg.sig) && goes_forward(arg.mu))
          middleLinkKernel<real,0,1,true,true,true,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
        else
          middleLinkKernel<real,0,0,true,true,true,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
      } else {
        if (goes_forward(arg.sig) && goes_forward(arg.mu))
          middleLinkKernel<real,1,1,true,true,false,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
        else if (goes_forward(arg.sig) && goes_backward(arg.mu))
          middleLinkKernel<real,1,0,true,true,false,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
        else if (goes_backward(arg.sig) && goes_forward(arg.mu))
          middleLinkKernel<real,0,1,true,true,false,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
        else
          middleLinkKernel<real,0,0,true,true,false,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
      }
      break;
    case FORCE_LEPAGE_MIDDLE_LINK:
      if (arg.p_mu || arg.q_mu || !arg.q_prev)
        errorQuda("Expect p_mu=%d and q_mu=%d to both be false and q_prev=%d true", arg.p_mu, arg.q_mu, arg.q_prev);
      if (goes_forward(arg.sig) && goes_forward(arg.mu))
        middleLinkKernel<real,1,1,false,false,true,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
      else if (goes_forward(arg.sig) && goes_backward(arg.mu))
        middleLinkKernel<real,1,0,false,false,true,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
      else if (goes_backward(arg.sig) && goes_forward(arg.mu))
        middleLinkKernel<real,0,1,false,false,true,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
      else
        middleLinkKernel<real,0,0,false,false,true,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
      break;
    case FORCE_SIDE_LINK:
      if (goes_forward(arg.mu)) sideLinkKernel<real,1,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
      else                      sideLinkKernel<real,0,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
      break;
    case FORCE_SIDE_LINK_SHORT:
      if (goes_forward(arg.mu)) sideLinkShortKernel<real,1,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
      else                      sideLinkShortKernel<real,0,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
      break;
    default:
      errorQuda("Undefined force type %d", type);
    }
  }

  // Back up every field the selected kernel accumulates into or overwrites,
  // so that the repeated launches made while tuning can be undone.
  void preTune() {
    switch (type) {
    case FORCE_ONE_LINK:
      arg.outA.save();
      break;
    case FORCE_ALL_LINK:
      arg.outA.save();
      arg.outB.save();
      break;
    case FORCE_MIDDLE_LINK:
      arg.pMu.save();
      arg.qMu.save();
      // intentional fall-through: also save what the Lepage variant saves
    case FORCE_LEPAGE_MIDDLE_LINK:
      arg.outA.save();
      arg.p3.save();
      break;
    case FORCE_SIDE_LINK:
      arg.outB.save();
      // intentional fall-through: also save what the short variant saves
    case FORCE_SIDE_LINK_SHORT:
      arg.outA.save();
      break;
    default: errorQuda("Undefined force type %d", type);
    }
  }

  // Restore the fields backed up by preTune().
  void postTune() {
    switch (type) {
    case FORCE_ONE_LINK:
      arg.outA.load();
      break;
    case FORCE_ALL_LINK:
      arg.outA.load();
      arg.outB.load();
      break;
    case FORCE_MIDDLE_LINK:
      arg.pMu.load();
      arg.qMu.load();
      // intentional fall-through: also restore what the Lepage variant restores
    case FORCE_LEPAGE_MIDDLE_LINK:
      arg.outA.load();
      arg.p3.load();
      break;
    case FORCE_SIDE_LINK:
      arg.outB.load();
      // intentional fall-through: also restore what the short variant restores
    case FORCE_SIDE_LINK_SHORT:
      arg.outA.load();
      break;
    default: errorQuda("Undefined force type %d", type);
    }
  }

  // Floating point operations per launch.  The site count is widened to
  // long long before any multiplication: arg.threads is an int, and products
  // such as 2*threads*2*234 exceed INT_MAX once the working set reaches a
  // few million sites.
  long long flops() const {
    const long long sites = 2ll * arg.threads; // both parities
    switch (type) {
    case FORCE_ONE_LINK:
      return 4 * sites * 36ll;
    case FORCE_ALL_LINK:
      return sites * (goes_forward(arg.sig) ? 1242ll : 828ll);
    case FORCE_MIDDLE_LINK:
    case FORCE_LEPAGE_MIDDLE_LINK:
      return sites * (2 * 198 +
                      (!arg.q_prev && goes_forward(arg.sig) ? 198 : 0) +
                      (arg.q_prev && (arg.q_mu || goes_forward(arg.sig) ) ? 198 : 0) +
                      ((arg.q_prev && goes_forward(arg.sig) ) ?  198 : 0) +
                      ( goes_forward(arg.sig) ? 216 : 0) );
    case FORCE_SIDE_LINK:       return sites * 2 * 234;
    case FORCE_SIDE_LINK_SHORT: return sites * 36;
    default: errorQuda("Undefined force type %d", type);
    }
    return 0;
  }

  // Bytes read and written per launch.
  long long bytes() const {
    switch (type) {
    case FORCE_ONE_LINK:
      return 2*4*arg.threads*( arg.oProd.Bytes() + 2*arg.outA.Bytes() );
    case FORCE_ALL_LINK:
      return 2*arg.threads*( (goes_forward(arg.sig) ? 4 : 2)*arg.outA.Bytes() + 3*arg.link.Bytes()
                             + arg.oProd.Bytes() + arg.qPrev.Bytes() + 2*arg.outB.Bytes());
    case FORCE_MIDDLE_LINK:
    case FORCE_LEPAGE_MIDDLE_LINK:
      return 2*arg.threads*( ( goes_forward(arg.sig) ? 2*arg.outA.Bytes() : 0 ) +
                             (arg.p_mu ? arg.pMu.Bytes() : 0) +
                             (arg.q_mu ? arg.qMu.Bytes() : 0) +
                             ( ( goes_forward(arg.sig) || arg.q_mu ) ? arg.qPrev.Bytes() : 0) +
                             arg.p3.Bytes() + 3*arg.link.Bytes() + arg.oProd.Bytes() );
    case FORCE_SIDE_LINK:
      return 2*arg.threads*( 2*arg.outA.Bytes() + 2*arg.outB.Bytes() +
                             arg.p3.Bytes() + arg.link.Bytes() + arg.qProd.Bytes() );
    case FORCE_SIDE_LINK_SHORT:
      return 2*arg.threads*( 2*arg.outA.Bytes() + arg.p3.Bytes() );
    default: errorQuda("Undefined force type %d", type);
    }
    return 0;
  }
};
746 
// Drives the full sequence of staple force kernels for one precision.
//
// After the one-link term, every ordered pair (sig, mu) of directions on
// different axes gets a 3-link middle-link launch, the nested 5-link (nu)
// and 7-link (rho) launches over the remaining axes, the optional Lepage
// middle/side pair, and finally the 3-link short side-link launch.
// Pmu, P3, P5, Pnumu, Qmu and Qnumu are scratch fields passed between those
// launches; newOprod accumulates the result.  All launches are issued on
// stream 0 in program order.
template<typename real>
static void hisqStaplesForce(GaugeField &Pmu, GaugeField &P3, GaugeField &P5, GaugeField &Pnumu,
                             GaugeField &Qmu, GaugeField &Qnumu, GaugeField &newOprod,
                             const GaugeField &oprod, const GaugeField &link,
                             const PathCoefficients<real> &act_path_coeff)
{
  typedef FatLinkArg<real> Arg;
  typedef FatLinkForce<real, Arg> Force;

  real c_one    = act_path_coeff.one;
  real c_three  = act_path_coeff.three;
  real c_five   = act_path_coeff.five;
  real c_seven  = act_path_coeff.seven;
  real c_lepage = act_path_coeff.lepage;
  real neg_three  = -c_three;
  real neg_five   = -c_five;
  real neg_lepage = -c_lepage;

  // true when directions a and b lie on the same axis (either orientation)
  auto same_axis = [](int a, int b) { return a == b || a == opp_dir(b); };

  Arg oneLinkArg(newOprod, oprod, link, c_one, FORCE_ONE_LINK);
  Force oneLink(oneLinkArg, link, 0, 0, FORCE_ONE_LINK);
  oneLink.apply(0);

  for (int sig=0; sig<8; sig++) {
    for (int mu=0; mu<8; mu++) {
      if (same_axis(mu, sig)) continue;

      //3-link
      //Kernel A: middle link
      Arg mid3Arg(newOprod, Pmu, P3, Qmu, oprod, link, neg_three, 2, FORCE_MIDDLE_LINK);
      Force mid3(mid3Arg, link, sig, mu, FORCE_MIDDLE_LINK);
      mid3.apply(0);

      for (int nu=0; nu < 8; nu++) {
        if (same_axis(nu, sig) || same_axis(nu, mu)) continue;

        //5-link: middle link
        //Kernel B
        Arg mid5Arg(newOprod, Pnumu, P5, Qnumu, Pmu, Qmu, link, c_five, 1, FORCE_MIDDLE_LINK);
        Force mid5(mid5Arg, link, sig, nu, FORCE_MIDDLE_LINK);
        mid5.apply(0);

        for (int rho = 0; rho < 8; rho++) {
          if (same_axis(rho, sig) || same_axis(rho, mu) || same_axis(rho, nu)) continue;

          //7-link: middle link and side link
          Arg all7Arg(newOprod, P5, Pnumu, Qnumu, link, c_seven, c_five != 0 ? c_seven/c_five : 0, 1, FORCE_ALL_LINK, true);
          Force all7(all7Arg, link, sig, rho, FORCE_ALL_LINK);
          all7.apply(0);

        }//rho

        //5-link: side link
        Arg side5Arg(newOprod, P3, P5, Qmu, link, neg_five, (c_three != 0 ? c_five/c_three : 0), 1, FORCE_SIDE_LINK);
        Force side5(side5Arg, link, sig, nu, FORCE_SIDE_LINK);
        side5.apply(0);

      } //nu

      //lepage
      if (c_lepage != 0.) {
        Arg lepageMidArg(newOprod, P5, Pmu, Qmu, link, c_lepage, 2, FORCE_LEPAGE_MIDDLE_LINK);
        Force lepageMid(lepageMidArg, link, sig, mu, FORCE_LEPAGE_MIDDLE_LINK);
        lepageMid.apply(0);

        Arg lepageSideArg(newOprod, P3, P5, Qmu, link, neg_lepage, (c_three != 0 ? c_lepage/c_three : 0), 2, FORCE_SIDE_LINK);
        Force lepageSide(lepageSideArg, link, sig, mu, FORCE_SIDE_LINK);
        lepageSide.apply(0);
      } // Lepage != 0.0

      // 3-link side link
      Arg side3Arg(newOprod, P3, link, c_three, 1, FORCE_SIDE_LINK_SHORT);
      Force side3(side3Arg, P3, sig, mu, FORCE_SIDE_LINK_SHORT);
      side3.apply(0);
    }//mu
  }//sig

} // hisqStaplesForce
821 
// Host entry point for the HISQ staple force.
//
// Checks that all three fields use a native order and live on the device,
// allocates six scratch fields, and dispatches to the templated
// hisqStaplesForce<real> on the common precision of the fields (double or
// single only).  The result is accumulated into newOprod.
//
// path_coeff_array is read as the six values consumed by PathCoefficients.
// Blocks until the device is idle before returning.
822  void hisqStaplesForce(GaugeField &newOprod, const GaugeField &oprod, const GaugeField &link, const double path_coeff_array[6])
823  {
824  if (!link.isNative()) errorQuda("Unsupported gauge order %d", link.Order());
825  if (!oprod.isNative()) errorQuda("Unsupported gauge order %d", oprod.Order());
826  if (!newOprod.isNative()) errorQuda("Unsupported gauge order %d", newOprod.Order());
827  if (checkLocation(newOprod,oprod,link) == QUDA_CPU_FIELD_LOCATION) errorQuda("CPU not implemented");
828 
829  // create color matrix fields with zero padding
830  GaugeFieldParam gauge_param(link);
// NOTE(review): this listing jumps from line 830 to 834, so up to three
// source lines are not shown here. The kernels index the scratch fields
// with direction 0 through a QUDA_RECONSTRUCT_NO accessor, so gauge_param
// is presumably adjusted before use -- confirm against the original file.
834 
835  cudaGaugeField Pmu(gauge_param);
836  cudaGaugeField P3(gauge_param);
837  cudaGaugeField P5(gauge_param);
838  cudaGaugeField Pnumu(gauge_param);
839  cudaGaugeField Qmu(gauge_param);
840  cudaGaugeField Qnumu(gauge_param);
841 
// checkPrecision returns the precision shared by the three fields
// (presumably aborting on a mismatch -- defined outside this file).
842  QudaPrecision precision = checkPrecision(oprod, link, newOprod);
843  if (precision == QUDA_DOUBLE_PRECISION) {
844  PathCoefficients<double> act_path_coeff(path_coeff_array);
845  hisqStaplesForce<double>(Pmu, P3, P5, Pnumu, Qmu, Qnumu, newOprod, oprod, link, act_path_coeff);
846  } else if (precision == QUDA_SINGLE_PRECISION) {
847  PathCoefficients<float> act_path_coeff(path_coeff_array);
848  hisqStaplesForce<float>(Pmu, P3, P5, Pnumu, Qmu, Qnumu, newOprod, oprod, link, act_path_coeff);
849  } else {
850  errorQuda("Unsupported precision");
851  }
852 
// Wait for all queued kernels, then surface any launch or execution error.
853  cudaDeviceSynchronize();
854  checkCudaError();
855  }
856 
/**
   Argument bundle passed by value to completeForceKernel.

   Both accessors are bound to the same force field, so the kernel
   reads its input from, and writes its result to, that one field.
   The second BaseForceArg constructor argument is fixed at 0 here.
*/
template <typename real, QudaReconstructType reconstruct=QUDA_RECONSTRUCT_NO>
struct CompleteForceArg : public BaseForceArg<real,reconstruct> {

  // accessor type: always built with QUDA_RECONSTRUCT_NO, independent of
  // the reconstruct parameter handed to the base class
  using F = typename gauge_mapper<real,QUDA_RECONSTRUCT_NO>::type;

  F outA;           // force output accessor
  const F oProd;    // force input accessor (same underlying field as outA)
  const real coeff; // initialised to zero by the constructor

  CompleteForceArg(GaugeField &force, const GaugeField &link) :
    BaseForceArg<real,reconstruct>(link, 0),
    outA(force),
    oProd(force),
    coeff(0.0)
  { }

};
870 
// Flops count: 4 matrix multiplications per lattice site = 792 Flops per site
/**
   For each of the four directions at one site: multiply the link by the
   stored outer product, apply makeAntiHerm to the product, and write it
   back through outA with a sign that is -1 on parity 1 and +1 otherwise.

   Launch layout: the x thread index selects the checkerboard site
   (guarded against arg.threads), the y thread index selects the parity.
*/
template <typename real, typename Arg>
__global__ void completeForceKernel(Arg arg)
{
  using Link = Matrix<complex<real>,3>;

  int x_cb = blockIdx.x * blockDim.x + threadIdx.x;
  if (x_cb >= arg.threads) return;
  int parity = blockIdx.y * blockDim.y + threadIdx.y;

  // site coordinates, then shifted by the border to index the extended field
  int coord[4];
  getCoords(coord, x_cb, arg.X, parity);
  for (int d=0; d<4; d++) coord[d] += arg.border[d];
  int e_cb = linkIndex(coord, arg.E);

  // the sign depends only on parity, so evaluate it once for all directions
  const real sign = (parity==1) ? -1.0 : 1.0;

#pragma unroll
  for (int sig=0; sig<4; ++sig) {
    Link U = arg.link(sig, e_cb, parity);
    Link O = arg.oProd(sig, e_cb, parity);

    Link UO = U*O;
    makeAntiHerm(UO);

    arg.outA(sig, e_cb, parity) = sign*UO;
  }
}
898 
/**
   Argument bundle passed by value to longLinkKernel.

   Holds an accessor for the field being accumulated into (outA), a
   read-only accessor for the input outer product (oProd) and the
   scalar that multiplies the long-link term.  The second BaseForceArg
   constructor argument is fixed at 0 here.
*/
template <typename real, QudaReconstructType reconstruct=QUDA_RECONSTRUCT_NO>
struct LongLinkArg : public BaseForceArg<real,reconstruct> {

  // NOTE(review): M is not referenced anywhere in this struct - kept in
  // case code outside this view names LongLinkArg::M
  using M = typename gauge::FloatNOrder<real,18,2,11>;
  using F = typename gauge_mapper<real,QUDA_RECONSTRUCT_NO>::type;

  F outA;           // accumulated output accessor
  const F oProd;    // input outer-product accessor
  const real coeff; // multiplier applied to the long-link term

  LongLinkArg(GaugeField &newOprod, const GaugeField &link, const GaugeField &oprod, real coeff) :
    BaseForceArg<real,reconstruct>(link,0),
    outA(newOprod),
    oProd(oprod),
    coeff(coeff)
  { }

};
913 
// Flops count, in two-number pair (matrix_mult, matrix_add)
// (24, 12)
// 4968 Flops per site in total
/**
   Accumulate the long-link term into outA for every direction at one site.

   Launch layout: the x thread index selects the checkerboard site
   (guarded against arg.threads), the y thread index selects the parity.
*/
template <typename real, typename Arg>
__global__ void longLinkKernel(Arg arg)
{
  using Link = Matrix<complex<real>,3>;

  int x_cb = blockIdx.x * blockDim.x + threadIdx.x;
  if (x_cb >= arg.threads) return;
  int parity = blockIdx.y * blockDim.y + threadIdx.y;

  // site coordinates, then shifted by the border to index the extended field
  int x[4];
  getCoords(x, x_cb, arg.X, parity);
  for (int i=0; i<4; i++) x[i] += arg.border[i];
  int e_cb = linkIndex(x,arg.E);

  // displacement handed to linkIndexShift; only the component along the
  // current direction is ever non-zero, and it is reset after each use
  int shift[4] = {0,0,0,0};

  /*
   *
   *    A    B    C    D    E
   *    ---- ---- ---- ----
   *
   *   ---> sig direction
   *
   *   C is the current point (sid)
   *   A, B sit two and one sites behind C; D, E one and two sites ahead
   *
   */

  // compute the force for forward long links
#pragma unroll
  for (int sig=0; sig<4; sig++) {
    const int point_c = e_cb;

    shift[sig] = -2;
    const int point_a = linkIndexShift(x,shift,arg.E);

    shift[sig] = -1;
    const int point_b = linkIndexShift(x,shift,arg.E);

    shift[sig] = 1;
    const int point_d = linkIndexShift(x,shift,arg.E);

    shift[sig] = 2;
    const int point_e = linkIndexShift(x,shift,arg.E);

    shift[sig] = 0;

    // odd displacements (B, D) land on the opposite parity,
    // even displacements (A, C, E) stay on this thread's parity
    Link link_a = arg.link(sig, point_a, parity);
    Link link_b = arg.link(sig, point_b, 1-parity);
    Link link_d = arg.link(sig, point_d, 1-parity);
    Link link_e = arg.link(sig, point_e, parity);

    Link oprod_c = arg.oProd(sig, point_c, parity);
    Link oprod_b = arg.oProd(sig, point_b, 1-parity);
    Link oprod_a = arg.oProd(sig, point_a, parity);

    // three terms, one per outer product (taken at C, B and A respectively)
    Link term = link_d*link_e*oprod_c - link_d*oprod_b*link_b + oprod_a*link_a*link_b;

    // read-modify-write: add the scaled term to what outA already holds
    Link current = arg.outA(sig, e_cb, parity);
    arg.outA(sig, e_cb, parity) = current + arg.coeff*term;
  } // loop over sig

}
978 
/**
   Autotuned launcher for the long-link and complete-force kernels.

   The object keeps a reference to the caller's argument struct (and
   writes sig and mu into it on construction), selects the kernel from
   the HisqForceType, and provides the tuning key, save/restore hooks
   and flop/byte estimates.  Any type other than FORCE_LONG_LINK or
   FORCE_COMPLETE is reported through errorQuda.
*/
template <typename real, typename Arg>
class HisqForce : public TunableVectorY {

  Arg &arg;
  const GaugeField &meta;
  const HisqForceType type;

  unsigned int minThreads() const { return arg.threads; }
  bool tuneGridDim() const { return false; }

public:
  /**
     @param arg   kernel arguments; sig and mu are stored into it
     @param meta  field whose VolString() is used in the tune key
     @param sig   value stored in arg.sig
     @param mu    value stored in arg.mu
     @param type  which kernel this instance launches
  */
  HisqForce(Arg &arg, const GaugeField &meta, int sig, int mu, HisqForceType type) :
    TunableVectorY(2), // y extent of 2: both kernels take parity from the y index
    arg(arg),
    meta(meta),
    type(type)
  {
    arg.sig = sig;
    arg.mu = mu;
  }

  virtual ~HisqForce() { }

  /** Obtain launch parameters from the autotuner and launch the selected kernel on stream. */
  void apply(const cudaStream_t &stream) {
    TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
    if (type == FORCE_LONG_LINK) {
      longLinkKernel<real,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
    } else if (type == FORCE_COMPLETE) {
      completeForceKernel<real,Arg><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
    } else {
      errorQuda("Undefined force type %d", type);
    }
  }

  /** Tune key built from the volume string, this class's type name, and thread count / precision / kernel tag. */
  TuneKey tuneKey() const {
    std::stringstream aux;
    aux << "threads=" << arg.threads << ",prec=" << sizeof(real);
    if (type == FORCE_LONG_LINK) {
      aux << ",LONG_LINK";
    } else if (type == FORCE_COMPLETE) {
      aux << ",COMPLETE";
    } else {
      errorQuda("Undefined force type %d", type);
    }
    return TuneKey(meta.VolString(), typeid(*this).name(), aux.str().c_str());
  }

  /** Save the output field before tuning runs: both kernels write to outA. */
  void preTune() {
    if (type == FORCE_LONG_LINK || type == FORCE_COMPLETE) {
      arg.outA.save();
    } else {
      errorQuda("Undefined force type %d", type);
    }
  }

  /** Restore the output field saved by preTune(). */
  void postTune() {
    if (type == FORCE_LONG_LINK || type == FORCE_COMPLETE) {
      arg.outA.load();
    } else {
      errorQuda("Undefined force type %d", type);
    }
  }

  /** Flop estimate: per-site count from the kernel comments, times two parities. */
  long long flops() const {
    if (type == FORCE_LONG_LINK) return 2*arg.threads*4968ll;
    if (type == FORCE_COMPLETE) return 2*arg.threads*792ll;
    errorQuda("Undefined force type %d", type);
    return 0;
  }

  /** Memory-traffic estimate: four directions, two parities, per-site accessor bytes. */
  long long bytes() const {
    if (type == FORCE_LONG_LINK) {
      // per direction: outA read + written, four links and three outer products read
      return 4*2*arg.threads*(2*arg.outA.Bytes() + 4*arg.link.Bytes() + 3*arg.oProd.Bytes());
    }
    if (type == FORCE_COMPLETE) {
      // per direction: one link and one outer product read, one result written
      return 4*2*arg.threads*(arg.outA.Bytes() + arg.link.Bytes() + arg.oProd.Bytes());
    }
    errorQuda("Undefined force type %d", type);
    return 0;
  }
};
1056 
/**
   Compute the long-link contribution to the fermion force.

   Validates the fields, then launches longLinkKernel through HisqForce
   for the common precision of the three fields.  Only native field
   orders, device-located fields, single/double precision and
   QUDA_RECONSTRUCT_NO links are accepted; anything else is reported
   through errorQuda.

   @param newOprod  field the long-link term is accumulated into
   @param oldOprod  input outer-product field
   @param link      gauge links
   @param coeff     scalar multiplying the long-link term
*/
void hisqLongLinkForce(GaugeField &newOprod, const GaugeField &oldOprod, const GaugeField &link, double coeff)
{
  if (!link.isNative()) errorQuda("Unsupported gauge order %d", link.Order());
  if (!oldOprod.isNative()) errorQuda("Unsupported gauge order %d", oldOprod.Order());
  if (!newOprod.isNative()) errorQuda("Unsupported gauge order %d", newOprod.Order());
  if (checkLocation(newOprod,oldOprod,link) == QUDA_CPU_FIELD_LOCATION) errorQuda("CPU not implemented");

  QudaPrecision precision = checkPrecision(newOprod, link, oldOprod);
  if (precision == QUDA_DOUBLE_PRECISION) {
    if (link.Reconstruct() == QUDA_RECONSTRUCT_NO) {
      typedef LongLinkArg<double,QUDA_RECONSTRUCT_NO> Arg;
      Arg arg(newOprod, link, oldOprod, coeff);
      // sig and mu are passed as 0: the kernel loops over all directions itself
      HisqForce<double,Arg> longLink(arg, link, 0, 0, FORCE_LONG_LINK);
      longLink.apply(0);
    } else {
      errorQuda("Reconstruct %d not supported", link.Reconstruct());
    }
  } else if (precision == QUDA_SINGLE_PRECISION) {
    if (link.Reconstruct() == QUDA_RECONSTRUCT_NO) {
      typedef LongLinkArg<float,QUDA_RECONSTRUCT_NO> Arg;
      Arg arg(newOprod, link, oldOprod, coeff);
      HisqForce<float, Arg> longLink(arg, link, 0, 0, FORCE_LONG_LINK);
      longLink.apply(0);
    } else {
      errorQuda("Reconstruct %d not supported", link.Reconstruct());
    }
  } else {
    errorQuda("Unsupported precision %d", precision);
  }

  // Synchronise first, then check: the kernel launch is asynchronous, so a
  // fault raised while it runs is only observable once the device has been
  // waited on.  Same order as hisqStaplesForce.
  cudaDeviceSynchronize();
  checkCudaError();
}
1089 
/**
   Finish the force in place: multiply the stored force matrix by the
   gauge field and apply the anti-Hermitian projection (makeAntiHerm),
   as done by completeForceKernel.

   Only native field orders, device-located fields, single/double
   precision and QUDA_RECONSTRUCT_NO links are accepted; anything else
   is reported through errorQuda.

   @param force  field read as input and overwritten with the result
   @param link   gauge links
*/
void hisqCompleteForce(GaugeField &force, const GaugeField &link)
{
  if (!link.isNative()) errorQuda("Unsupported gauge order %d", link.Order());
  if (!force.isNative()) errorQuda("Unsupported gauge order %d", force.Order());
  if (checkLocation(force,link) == QUDA_CPU_FIELD_LOCATION) errorQuda("CPU not implemented");

  QudaPrecision precision = checkPrecision(link, force);
  if (precision == QUDA_DOUBLE_PRECISION) {
    if (link.Reconstruct() == QUDA_RECONSTRUCT_NO) {
      typedef CompleteForceArg<double,QUDA_RECONSTRUCT_NO> Arg;
      Arg arg(force, link);
      // sig and mu are passed as 0: the kernel loops over all directions itself
      HisqForce<double,Arg> completeForce(arg, link, 0, 0, FORCE_COMPLETE);
      completeForce.apply(0);
    } else {
      errorQuda("Reconstruct %d not supported", link.Reconstruct());
    }
  } else if (precision == QUDA_SINGLE_PRECISION) {
    if (link.Reconstruct() == QUDA_RECONSTRUCT_NO) {
      typedef CompleteForceArg<float,QUDA_RECONSTRUCT_NO> Arg;
      Arg arg(force, link);
      HisqForce<float, Arg> completeForce(arg, link, 0, 0, FORCE_COMPLETE);
      completeForce.apply(0);
    } else {
      errorQuda("Reconstruct %d not supported", link.Reconstruct());
    }
  } else {
    errorQuda("Unsupported precision %d", precision);
  }

  // Synchronise first, then check: the kernel launch is asynchronous, so a
  // fault raised while it runs is only observable once the device has been
  // waited on.  Same order as hisqStaplesForce.
  cudaDeviceSynchronize();
  checkCudaError();
}
1121  } // namespace fermion_force
1122 } // namespace quda
1123 
1124 #endif // GPU_HISQ_FORCE
double mu
Definition: test_util.cpp:1648
enum QudaPrecision_s QudaPrecision
static __device__ __host__ int linkIndexShift(const I x[], const J dx[], const K X[4])
static __device__ __host__ int linkIndex(const int x[], const I X[4])
void hisqLongLinkForce(GaugeField &newOprod, const GaugeField &oprod, const GaugeField &link, double coeff)
Compute the long-link contribution to the fermion force.
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:21
#define checkPrecision(...)
#define errorQuda(...)
Definition: util_quda.h:121
void hisqCompleteForce(GaugeField &oprod, const GaugeField &link)
Multiply the computed the force matrix by the gauge field and perform traceless anti-hermitian projec...
cudaStream_t * stream
QudaGaugeParam gauge_param
int E[4]
Definition: test_util.cpp:35
#define Qmu
void hisqStaplesForce(GaugeField &newOprod, const GaugeField &oprod, const GaugeField &link, const double path_coeff[6])
Compute the fat-link contribution to the fermion force.
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
Definition: tune.cpp:643
#define checkLocation(...)
Main header file for host and device accessors to GaugeFields.
int X[4]
Definition: covdev_test.cpp:70
QudaReconstructType reconstruct
Definition: quda.h:50
static int commDim[QUDA_MAX_DIM]
Definition: dslash_pack.cuh:9
#define Pmu
#define P5
unsigned long long flops
Definition: blas_quda.cu:22
#define Qnumu
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
__device__ __host__ void makeAntiHerm(Matrix< Complex, N > &m)
Definition: quda_matrix.h:746
#define checkCudaError()
Definition: util_quda.h:161
#define Pnumu
__host__ __device__ ValueType conj(ValueType x)
Definition: complex_quda.h:130
QudaTune getTuning()
Query whether autotuning is enabled or not. Default is enabled but can be overridden by setting QUDA_...
Definition: util_quda.cpp:52
QudaParity parity
Definition: covdev_test.cpp:54
unsigned long long bytes
Definition: blas_quda.cu:23
#define P3
int comm_dim_partitioned(int dim)
__host__ __device__ int getCoords(int coord[], const Arg &arg, int &idx, int parity, int &dim)
Compute the space-time coordinates we are at.