QUDA  v1.1.0
A library for QCD on GPUs
Go to the documentation of this file.
1 #include <cstdio>
3 #include <quda_internal.h>
4 #include <gauge_field.h>
5 #include <llfat_quda.h>
6 #include <index_helper.cuh>
7 #include <gauge_field_order.h>
8 #include <fast_intdiv.h>
9 #include <tune_quda.h>
10 #include <instantiate.h>
12 #define MIN_COEFF 1e-7
14 namespace quda {
/**
   @brief Kernel argument bundle used by the computeLongLink and
   computeOneLink kernels.

   Holds the accessor for the output field (link), the accessor for
   the input gauge field (u), the scalar coefficient applied to the
   result, and the geometry needed to map an output checkerboard index
   onto a site of the (possibly larger) input field.

   @tparam Float_ Floating-point type of the field accessors
   @tparam nColor_ Number of colors
   @tparam recon Reconstruction type used to read the input gauge field
*/
template <typename Float_, int nColor_, QudaReconstructType recon>
struct LinkArg {
  using Float = Float_;
  static constexpr int nColor = nColor_;
  typedef typename gauge_mapper<Float, QUDA_RECONSTRUCT_NO>::type Link;
  typedef typename gauge_mapper<Float, recon, 18, QUDA_STAGGERED_PHASE_MILC>::type Gauge;

  Link link;   // accessor for the output field (always unreconstructed)
  Gauge u;     // accessor for the input gauge field
  Float coeff; // scalar multiplying every output matrix

  unsigned int threads; // checkerboard volume of the output field

  int_fastdiv X[4]; // dimensions of the output field
  int_fastdiv E[4]; // dimensions of the input field
  int border[4];    // per-dimension offset (E - X) / 2 between the two index spaces

  /** This keeps track of any parity changes that result in using a
      radius of 1 for the extended border (the staple computations use
      such an extension, and if an odd number of dimensions are
      partitioned then we have to correct for this when computing the
      local index).  The kernels in this file that take a LinkArg do
      not read it, so it is simply zero-initialised here rather than
      being left indeterminate. */
  int odd_bit;

  /**
     @param[in,out] link Field that receives the result
     @param[in] u Input gauge field
     @param[in] coeff Scalar coefficient applied to the result
  */
  LinkArg(GaugeField &link, const GaugeField &u, Float coeff) :
    // listed in declaration order, which is the order in which members are actually initialised
    link(link),
    u(u),
    coeff(coeff),
    threads(link.VolumeCB()),
    odd_bit(0)
  {
    // a reconstructed input is only accepted when it carries MILC staggered phases
    if (u.StaggeredPhase() != QUDA_STAGGERED_PHASE_MILC && u.Reconstruct() != QUDA_RECONSTRUCT_NO)
      errorQuda("Staggered phase type %d not supported", u.StaggeredPhase());
    for (int d=0; d<4; d++) {
      X[d] = link.X()[d];
      E[d] = u.X()[d];
      border[d] = (E[d] - X[d]) / 2;
    }
  }
};
/**
   @brief Compute the scaled product of three consecutive links along
   direction dir, starting at the site identified by (idx, parity),
   and store it in the output field.

   @tparam dir Direction of the three links
   @param[in,out] arg Kernel argument (see LinkArg)
   @param[in] idx Checkerboard index in the output field
   @param[in] parity Parity of the output site
*/
template <int dir, typename Arg>
__device__ void longLinkDir(Arg &arg, int idx, int parity) {
  using Link = Matrix<complex<typename Arg::Float>, Arg::nColor>;

  // scratch coordinate storage handed to the index helpers
  auto scratch = arg.u.coords;

  // site coordinates in the output field, then moved into the input field's index space
  int site[4];
  getCoords(site, idx, arg.X, parity);
  for (int d = 0; d < 4; d++) site[d] += arg.border[d];

  int hop[4] = {0, 0, 0, 0};

  // link at the site itself
  Link first = arg.u(dir, linkIndex(scratch, site, arg.E), parity);

  // link one step along dir (opposite parity)
  hop[dir] = 1;
  Link second = arg.u(dir, linkIndexShift(scratch, site, hop, arg.E), 1 - parity);

  // link two steps along dir (same parity as the starting site)
  hop[dir] = 2;
  Link third = arg.u(dir, linkIndexShift(scratch, site, hop, arg.E), parity);

  arg.link(dir, idx, parity) = arg.coeff * first * second * third;
}
/**
   @brief Kernel computing the three-link product for every site,
   parity and direction.

   Thread layout: the x index selects the checkerboard site, the y
   index the parity and the z index the direction.  Threads whose site
   index or direction falls outside the valid range exit immediately.

   @param[in,out] arg Kernel argument (see LinkArg)
*/
template <typename Arg>
__global__ void computeLongLink(Arg arg) {
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  const int parity = blockIdx.y * blockDim.y + threadIdx.y;
  const int dir = blockIdx.z * blockDim.z + threadIdx.z;

  if (idx >= arg.threads || dir >= 4) return;

  // turn the run-time direction into a compile-time template argument
  if (dir == 0) {
    longLinkDir<0>(arg, idx, parity);
  } else if (dir == 1) {
    longLinkDir<1>(arg, idx, parity);
  } else if (dir == 2) {
    longLinkDir<2>(arg, idx, parity);
  } else if (dir == 3) {
    longLinkDir<3>(arg, idx, parity);
  }
}
/**
   @brief Tunable driver for the computeLongLink kernel.

   Construction builds the kernel argument, sets up the tuning string
   and immediately launches the kernel on stream 0.

   @tparam Float Floating-point type of the fields
   @tparam nColor Number of colors
   @tparam recon Reconstruction type of the input gauge field
*/
template <typename Float, int nColor, QudaReconstructType recon>
class LongLink : public TunableVectorYZ {
  LinkArg<Float, nColor, recon> arg;
  const GaugeField &meta; // output field, used for the tuning key
  unsigned int minThreads() const { return arg.threads; }
  bool tuneGridDim() const { return false; }

public:
  /**
     @param[in] u Input gauge field
     @param[in,out] lng Field that receives the result
     @param[in] coeff Scalar coefficient applied to the result
  */
  LongLink(const GaugeField &u, GaugeField &lng, double coeff) :
    TunableVectorYZ(2,4), // y dimension spans 2 parities, z dimension spans 4 directions
    arg(lng, u, coeff),
    meta(lng)
  {
    strcpy(aux, meta.AuxString());
    strcat(aux, comm_dim_partitioned_string());

    apply(0);
  }

  void apply(const qudaStream_t &stream) {
    TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
    qudaLaunchKernel(computeLongLink<decltype(arg)>, tp, stream, arg);
  }

  TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), aux); }

  // The counts are accumulated in 64-bit arithmetic: arg.threads is an
  // unsigned int, so the product would otherwise be evaluated in 32 bits
  // and wrap around for large local volumes before being widened.
  long long flops() const { return 2ll * 4 * arg.threads * 198; }
  long long bytes() const
  {
    return 2 * 4 * static_cast<long long>(arg.threads) * (3 * arg.u.Bytes() + arg.link.Bytes());
  }
};
/**
   @brief Host entry point: instantiate and run the LongLink driver.
   @param[in,out] lng Field that receives the result
   @param[in] u Input gauge field
   @param[in] coeff Scalar coefficient applied to the result
*/
void computeLongLink(GaugeField &lng, const GaugeField &u, double coeff)
{
  // The gauge field is deliberately the first argument so that the
  // instantiation is selected from its reconstruction type.
  instantiate<LongLink, ReconstructNo12>(u, lng, coeff);
}
/**
   @brief Kernel that copies every link of the input gauge field into
   the output field, scaled by arg.coeff.

   Thread layout: the x index selects the checkerboard site, the y
   index the parity and the z index the direction.  Threads whose site
   index or direction falls outside the valid range exit immediately.

   @param[in,out] arg Kernel argument (see LinkArg)
*/
template <typename Arg>
__global__ void computeOneLink(Arg arg)
{
  using Link = Matrix<complex<typename Arg::Float>, Arg::nColor>;

  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  const int parity = blockIdx.y * blockDim.y + threadIdx.y;
  const int dir = blockIdx.z * blockDim.z + threadIdx.z;

  if (idx >= arg.threads || dir >= 4) return;

  // the accessor's coordinate storage doubles as the working site coordinates
  auto site = arg.u.coords;
  getCoords(site, idx, arg.X, parity);
  for (int d = 0; d < 4; d++) site[d] += arg.border[d]; // shift into the input field's index space

  Link in = arg.u(dir, linkIndex(site, arg.E), parity);

  arg.link(dir, idx, parity) = arg.coeff * in;
}
/**
   @brief Tunable driver for the computeOneLink kernel.

   Construction builds the kernel argument, sets up the tuning string
   and immediately launches the kernel on stream 0.

   @tparam Float Floating-point type of the fields
   @tparam nColor Number of colors
   @tparam recon Reconstruction type of the input gauge field
*/
template <typename Float, int nColor, QudaReconstructType recon>
class OneLink : public TunableVectorYZ {
  LinkArg<Float, nColor, recon> arg;
  const GaugeField &meta; // output field, used for the tuning key
  unsigned int minThreads() const { return arg.threads; }
  bool tuneGridDim() const { return false; }

public:
  /**
     @param[in] u Input gauge field
     @param[in,out] fat Field that receives the result
     @param[in] coeff Scalar coefficient applied to the result
  */
  OneLink(const GaugeField &u, GaugeField &fat, double coeff) :
    TunableVectorYZ(2,4), // y dimension spans 2 parities, z dimension spans 4 directions
    arg(fat, u, coeff),
    meta(fat)
  {
    strcpy(aux, meta.AuxString());
    strcat(aux, comm_dim_partitioned_string());

    apply(0);
  }

  void apply(const qudaStream_t &stream) {
    TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
    qudaLaunchKernel(computeOneLink<decltype(arg)>, tp, stream, arg);
  }

  TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), aux); }

  // The counts are accumulated in 64-bit arithmetic: arg.threads is an
  // unsigned int, so the product would otherwise be evaluated in 32 bits
  // and could wrap around before being widened to the return type.
  long long flops() const { return 2ll * 4 * arg.threads * 18; }
  long long bytes() const
  {
    return 2 * 4 * static_cast<long long>(arg.threads) * (arg.u.Bytes() + arg.link.Bytes());
  }
};
/**
   @brief Host entry point: instantiate and run the OneLink driver.
   @param[in,out] fat Field that receives the result
   @param[in] u Input gauge field
   @param[in] coeff Scalar coefficient applied to the result
*/
void computeOneLink(GaugeField &fat, const GaugeField &u, double coeff)
{
  const bool milc_phase = (u.StaggeredPhase() == QUDA_STAGGERED_PHASE_MILC);
  const bool no_recon = (u.Reconstruct() == QUDA_RECONSTRUCT_NO);

  // a reconstructed gauge field is only accepted with MILC staggered phases
  if (!milc_phase && !no_recon) {
    errorQuda("Staggered phase type %d not supported", u.StaggeredPhase());
  }

  instantiate<OneLink, ReconstructNo12>(u, fat, coeff);
}
/**
   @brief Kernel argument bundle for the computeStaple kernel.

   @tparam Float_ Floating-point type of the field accessors
   @tparam nColor_ Number of colors
   @tparam Fat Accessor type of the accumulated output field
   @tparam Staple Accessor type of the field the staple is saved to
   @tparam Mulink Accessor type of the field supplying the mu-direction link
   @tparam Gauge Accessor type of the gauge field
*/
template <typename Float_, int nColor_, typename Fat, typename Staple, typename Mulink, typename Gauge>
struct StapleArg {
  using Float = Float_;
  static constexpr int nColor = nColor_;
  unsigned int threads; // number of sites per parity the kernel runs over

  int_fastdiv X[4]; // dimensions of the region the kernel runs over
  int_fastdiv E[4]; // dimensions of the gauge field
  int border[4];    // (E - X) / 2

  int_fastdiv inner_X[4]; // dimensions of the output (fat) field
  int inner_border[4];    // (E - inner_X) / 2

  /** This keeps track of any parity changes that result in using a
      radius of 1 for the extended border (the staple computations use
      such an extension, and if an odd number of dimensions are
      partitioned then we have to correct for this when computing the
      local index). */
  int odd_bit;

  Gauge u;
  Fat fat;
  Staple staple;
  Mulink mulink;
  Float coeff;

  int n_mu;      // number of mu directions handled; filled in by the launching class
  int mu_map[4]; // maps the z thread index onto a mu direction; filled in by the launching class

  /**
     @param[in] fat Accessor for the accumulated output field
     @param[in] staple Accessor for the saved staple
     @param[in] mulink Accessor supplying the mu-direction link
     @param[in] u Accessor for the gauge field
     @param[in] coeff Coefficient multiplying the staple when accumulating
     @param[in] fat_meta Output field, used for its dimensions
     @param[in] u_meta Gauge field, used for its dimensions
  */
  StapleArg(Fat fat, Staple staple, Mulink mulink, Gauge u, Float coeff,
            const GaugeField &fat_meta, const GaugeField &u_meta) :
    threads(1), u(u), fat(fat), staple(staple), mulink(mulink), coeff(coeff)
  {
    // parity of the number of partitioned dimensions
    odd_bit = (commDimPartitioned(0) + commDimPartitioned(1) + commDimPartitioned(2) + commDimPartitioned(3)) % 2;

    for (int d = 0; d < 4; d++) {
      const int inner = fat_meta.X()[d];
      const int outer = u_meta.X()[d];
      const int mid = (inner + outer) / 2; // half-way between the output and gauge-field extents

      X[d] = mid;
      E[d] = outer;
      border[d] = (outer - mid) / 2;
      threads *= mid;

      inner_X[d] = inner;
      inner_border[d] = (outer - inner) / 2;
    }

    threads /= 2; // account for parity in y dimension
  }
};
/**
   @brief Compute the sum of the upper and lower staples in the mu-nu
   plane at site x.

   @tparam mu Direction of the link supplied by arg.mulink
   @tparam nu Direction of the two side links taken from arg.u
   @param[out] staple Receives upper staple + lower staple
   @param[in,out] arg Kernel argument (its accessors' coordinate scratch is overwritten)
   @param[in] x Site coordinates in the index space of arg.E
   @param[in] parity Parity used to read the links at x
*/
template <int mu, int nu, typename Arg>
__device__ inline void computeStaple(Matrix<complex<typename Arg::Float>, Arg::nColor> &staple, Arg &arg, int x[], int parity)
{
  using Link = Matrix<complex<typename Arg::Float>, Arg::nColor>;

  const int other_parity = 1 - parity;
  int *u_coords = arg.u.coords;           // scratch for indexing into arg.u
  int *mulink_coords = arg.mulink.coords; // scratch for indexing into arg.mulink
  int shift[4] = {0, 0, 0, 0};

  /* Upper staple:
   *
   *          mu (B)
   *        +-------+
   *   nu   |       |
   *   (A)  |       | (C)
   *        X       X
   */

  // A: nu-link at x
  Link up_a = arg.u(nu, linkIndex(u_coords, x, arg.E), parity);

  // B: mu-link at x + nu
  shift[nu] += 1;
  Link up_b = arg.mulink(mu, linkIndexShift(mulink_coords, x, shift, arg.E), other_parity);
  shift[nu] -= 1;

  // C: nu-link at x + mu
  shift[mu] += 1;
  Link up_c = arg.u(nu, linkIndexShift(u_coords, x, shift, arg.E), other_parity);
  shift[mu] -= 1;

  const Link upper = up_a * up_b * conj(up_c);

  /* Lower staple:
   *
   *        X       X
   *   nu   |       |
   *   (A)  |       | (C)
   *        +-------+
   *          mu (B)
   */

  // A: nu-link at x - nu
  shift[nu] -= 1;
  Link down_a = arg.u(nu, linkIndexShift(u_coords, x, shift, arg.E), other_parity);

  // B: mu-link at x - nu
  Link down_b = arg.mulink(mu, linkIndexShift(mulink_coords, x, shift, arg.E), other_parity);

  // C: nu-link at x - nu + mu
  shift[mu] += 1;
  Link down_c = arg.u(nu, linkIndexShift(u_coords, x, shift, arg.E), parity);

  staple = upper + conj(down_a) * down_b * down_c;
}
/**
   @brief Kernel that evaluates the staple for every site, parity and
   mu direction, accumulates coeff * staple into the fat field on the
   interior sites and optionally stores the bare staple.

   Thread layout: the x index selects the checkerboard site, the y
   index the parity and the z index an entry of arg.mu_map.

   @tparam save_staple Whether the staple itself is written to arg.staple
   @param[in,out] arg Kernel argument (see StapleArg)
   @param[in] nu Direction of the side links of the staple
*/
template <bool save_staple, typename Arg>
__global__ void computeStaple(Arg arg, int nu)
{
  using Link = Matrix<complex<typename Arg::Float>, Arg::nColor>;

  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  const int parity = blockIdx.y * blockDim.y + threadIdx.y;
  if (idx >= arg.threads) return;

  const int mu_idx = blockIdx.z * blockDim.z + threadIdx.z;
  if (mu_idx >= arg.n_mu) return;

  // look up the direction using constant subscripts only
  const int mu = (mu_idx == 0) ? arg.mu_map[0] : (mu_idx == 1) ? arg.mu_map[1] : arg.mu_map[2];

  // site coordinates, corrected for the parity offset and moved into the gauge field's index space
  int x[4];
  getCoords(x, idx, arg.X, (parity + arg.odd_bit) % 2);
  for (int d = 0; d < 4; d++) x[d] += arg.border[d];

  // dispatch the run-time (mu, nu) pair onto the compile-time instantiations
  Link staple;
  switch (mu) {
  case 0:
    switch (nu) {
    case 1: computeStaple<0,1>(staple, arg, x, parity); break;
    case 2: computeStaple<0,2>(staple, arg, x, parity); break;
    case 3: computeStaple<0,3>(staple, arg, x, parity); break;
    }
    break;
  case 1:
    switch (nu) {
    case 0: computeStaple<1,0>(staple, arg, x, parity); break;
    case 2: computeStaple<1,2>(staple, arg, x, parity); break;
    case 3: computeStaple<1,3>(staple, arg, x, parity); break;
    }
    break;
  case 2:
    switch (nu) {
    case 0: computeStaple<2,0>(staple, arg, x, parity); break;
    case 1: computeStaple<2,1>(staple, arg, x, parity); break;
    case 3: computeStaple<2,3>(staple, arg, x, parity); break;
    }
    break;
  case 3:
    switch (nu) {
    case 0: computeStaple<3,0>(staple, arg, x, parity); break;
    case 1: computeStaple<3,1>(staple, arg, x, parity); break;
    case 2: computeStaple<3,2>(staple, arg, x, parity); break;
    }
    break;
  }

  // the fat field is only updated for sites that lie inside the inner region
  bool interior = true;
  for (int d = 0; d < 4; d++) {
    interior = interior && x[d] >= arg.inner_border[d] && x[d] < arg.inner_X[d] + arg.inner_border[d];
  }

  if (interior) {
    // convert to inner coords
    int inner_x[4];
    for (int d = 0; d < 4; d++) inner_x[d] = x[d] - arg.inner_border[d];

    const int fat_idx = linkIndex(inner_x, arg.inner_X);
    Link fat = arg.fat(mu, fat_idx, parity);
    fat += arg.coeff * staple;
    arg.fat(mu, fat_idx, parity) = fat;
  }

  if (save_staple) arg.staple(mu, linkIndex(x, arg.E), parity) = staple;
}
/**
   @brief Tunable driver for the computeStaple kernel.

   The constructor writes n_mu and mu_map into the referenced kernel
   argument; apply() launches the kernel.

   @tparam Float Floating-point type (not used inside the class)
   @tparam Arg Kernel argument type (see StapleArg)
*/
template <typename Float, typename Arg>
class Staple : public TunableVectorYZ {
  Arg &arg;
  const GaugeField &meta;
  unsigned int minThreads() const { return arg.threads; }
  bool tuneGridDim() const { return false; }
  int nu;
  int dir1;
  int dir2;
  bool save_staple;

  /** Number of mu directions left once nu and every excluded direction (a value > -1) are removed */
  static int numMu(int dir1, int dir2) { return 3 - ((dir1 > -1) ? 1 : 0) - ((dir2 > -1) ? 1 : 0); }

public:
  /**
     @param[in,out] arg Kernel argument; its n_mu and mu_map members are set here
     @param[in] nu Direction of the side links of the staple
     @param[in] dir1 First direction to exclude, or -1 for none
     @param[in] dir2 Second direction to exclude, or -1 for none
     @param[in] save_staple Whether the kernel stores the staple
     @param[in] meta Field used for the tuning key and the byte count
  */
  Staple(Arg &arg, int nu, int dir1, int dir2, bool save_staple, const GaugeField &meta)
    : TunableVectorYZ(2, numMu(dir1, dir2)),
      arg(arg), meta(meta), nu(nu), dir1(dir1), dir2(dir2), save_staple(save_staple)
  {
    // compute the map for z thread index to mu index in the kernel
    // mu != nu                   -> n_mu = 3
    // mu != nu != rho            -> n_mu = 2
    // mu != nu != rho != sig     -> n_mu = 1
    arg.n_mu = numMu(dir1, dir2);
    int j=0;
    for (int i=0; i<4; i++) {
      if (i==nu || i==dir1 || i==dir2) continue; // skip these dimensions
      arg.mu_map[j++] = i;
    }
    assert(j == arg.n_mu);
  }

  void apply(const qudaStream_t &stream) {
    TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
    if (save_staple)
      qudaLaunchKernel(computeStaple<true, Arg>, tp, stream, arg, nu);
    else
      qudaLaunchKernel(computeStaple<false, Arg>, tp, stream, arg, nu);
  }

  TuneKey tuneKey() const {
    std::stringstream aux;
    aux << meta.AuxString() << comm_dim_partitioned_string();
    aux << ",nu=" << nu << ",dir1=" << dir1 << ",dir2=" << dir2 << ",save=" << save_staple;
    return TuneKey(meta.VolString(), typeid(*this).name(), aux.str().c_str());
  }

  // the kernel accumulates into fat (and may write staple), so both are restored after tuning
  void preTune() { arg.fat.save(); arg.staple.save(); }
  void postTune() { arg.fat.load(); arg.staple.load(); }

  // The counts are accumulated in 64-bit arithmetic: arg.threads is an
  // unsigned int, so the product would otherwise be evaluated in 32 bits
  // and wrap around for large local volumes before being widened.
  long long flops() const {
    return 2ll * arg.n_mu * arg.threads * ( 4*198 + 18 + 36 );
  }
  long long bytes() const {
    const long long n_mu = arg.n_mu;
    return n_mu*2*meta.VolumeCB()*arg.fat.Bytes()*2 // fat load/store is only done on interior
      + n_mu*2*arg.threads*(4*arg.u.Bytes() + 2*arg.mulink.Bytes() + (save_staple ? arg.staple.Bytes() : 0));
  }
};
/**
   @brief Functor used with instantiate<> to build the kernel argument
   and run the Staple driver on stream 0.

   The gauge field u is read with reconstruction type recon.  The
   mulink field may either be unreconstructed or use the same
   reconstruction type as u; anything else is rejected.

   @tparam Float Floating-point type of the fields
   @tparam nColor Number of colors
   @tparam recon Reconstruction type of the gauge field u
*/
template <typename Float, int nColor, QudaReconstructType recon>
struct Staple_ {
  /**
     @param[in] u Gauge field
     @param[in,out] fat Field the staple is accumulated into
     @param[out] staple Field the staple is stored in when save_staple is set
     @param[in] mulink Field supplying the mu-direction link
     @param[in] nu Direction of the side links of the staple
     @param[in] dir1 First direction to exclude, or -1 for none
     @param[in] dir2 Second direction to exclude, or -1 for none
     @param[in] coeff Coefficient multiplying the staple when accumulating
     @param[in] save_staple Whether the staple is stored
  */
  Staple_(const GaugeField &u, GaugeField &fat, GaugeField &staple, const GaugeField &mulink,
          int nu, int dir1, int dir2, double coeff, bool save_staple)
  { // FIXME - incorporate another level of reconstruct peel off in instantiate
    typedef typename gauge_mapper<Float,QUDA_RECONSTRUCT_NO>::type L;
    typedef typename gauge_mapper<Float,recon,18,QUDA_STAGGERED_PHASE_MILC>::type G;
    if (mulink.Reconstruct() == QUDA_RECONSTRUCT_NO) {
      StapleArg<Float, nColor, L, L, L, G> arg(L(fat), L(staple), L(mulink), G(u), coeff, fat, u);
      Staple<Float,decltype(arg)> stapler(arg, nu, dir1, dir2, save_staple, fat);
      stapler.apply(0);
    } else if (mulink.Reconstruct() == recon) {
      StapleArg<Float, nColor, L, L, G, G> arg(L(fat), L(staple), G(mulink), G(u), coeff, fat, u);
      Staple<Float,decltype(arg)> stapler(arg, nu, dir1, dir2, save_staple, fat);
      stapler.apply(0);
    } else {
      // this branch is reached because of mulink's reconstruct type, so that is the value to report
      errorQuda("Reconstruct %d is not supported\n", mulink.Reconstruct());
    }
  }
};
/**
   @brief Host entry point: compute the staple field for direction nu,
   excluding the directions dir1 and dir2.
   @param[in,out] fat Field the staple is accumulated into
   @param[out] staple Field the staple is stored in when save_staple is set
   @param[in] mulink Field supplying the mu-direction link
   @param[in] u Gauge field
   @param[in] nu Direction of the side links of the staple
   @param[in] dir1 First direction to exclude, or -1 for none
   @param[in] dir2 Second direction to exclude, or -1 for none
   @param[in] coeff Coefficient multiplying the staple when accumulating
   @param[in] save_staple Whether the staple is stored
*/
void computeStaple(GaugeField &fat, GaugeField &staple, const GaugeField &mulink, const GaugeField &u,
                   int nu, int dir1, int dir2, double coeff, bool save_staple)
{
  // u leads the argument list handed to instantiate<>; the rest is forwarded to Staple_
  instantiate<Staple_, ReconstructNo12>(
    u, fat, staple, mulink, nu, dir1, dir2, coeff, save_staple);
}
/**
   @brief Compute the long links from the gauge field u.
   @param[out] lng Field that receives the long links
   @param[in] u Input gauge field
   @param[in] coeff Coefficient array; only entry 1 is read here
*/
void longKSLink(GaugeField *lng, const GaugeField &u, const double *coeff)
{
  const double long_coeff = coeff[1];
  GaugeField &out = *lng;
  computeLongLink(out, u, long_coeff);
}
/**
   @brief Compute the fat links from the gauge field u.

   The one-link term is always computed.  The staple terms are only
   computed when at least one of coeff[2..5] reaches MIN_COEFF in
   magnitude, and the two scratch fields they need are only allocated
   in that case.

   @param[in,out] fat Field that receives the fat links
   @param[in] u Input gauge field
   @param[in] coeff Coefficient array; entries 0 and 2..5 are read here
   (presumably one-link, 3-, 5-, 7-staple and Lepage terms - TODO confirm
   against the caller)
*/
void fatKSLink(GaugeField *fat, const GaugeField& u, const double *coeff)
{
#ifdef GPU_FATLINK
  // validate the input before allocating anything
  bool odd_dim = false;
  for (int d = 0; d < 4; d++) odd_dim = odd_dim || (fat->X()[d] % 2 != 0);
  if (odd_dim && (u.Reconstruct() != QUDA_RECONSTRUCT_NO)) {
    errorQuda("Reconstruct %d and odd dimensionsize is not supported by link fattening code (yet)\n",
              u.Reconstruct());
  }

  computeOneLink(*fat, u, coeff[0]-6.0*coeff[5]);

  // scratch fields for the staples; they stay null when no staple term is requested
  GaugeField *staple = nullptr;
  GaugeField *staple1 = nullptr;

  // Check the coefficients. If all of the following are zero, skip the staples.
  const bool need_staples = fabs(coeff[2]) >= MIN_COEFF || fabs(coeff[3]) >= MIN_COEFF ||
    fabs(coeff[4]) >= MIN_COEFF || fabs(coeff[5]) >= MIN_COEFF;

  if (need_staples) {
    GaugeFieldParam gParam(u);
    gParam.reconstruct = QUDA_RECONSTRUCT_NO;
    gParam.setPrecision(gParam.Precision());
    gParam.create = QUDA_NULL_FIELD_CREATE;
    staple = GaugeField::Create(gParam);
    staple1 = GaugeField::Create(gParam);

    for (int nu = 0; nu < 4; nu++) {
      // staple built from u in every mu != nu; saved in staple for the nested terms
      computeStaple(*fat, *staple, u, u, nu, -1, -1, coeff[2], true);

      // staple of the saved staple in the same direction nu
      if (coeff[5] != 0.0) computeStaple(*fat, *staple, *staple, u, nu, -1, -1, coeff[5], false);

      for (int rho = 0; rho < 4; rho++) {
        if (rho != nu) {

          // staple of the saved staple in direction rho, excluding nu; saved in staple1
          computeStaple(*fat, *staple1, *staple, u, rho, nu, -1, coeff[3], true);

          if (fabs(coeff[4]) > MIN_COEFF) {
            for (int sig = 0; sig < 4; sig++) {
              if (sig != nu && sig != rho) {
                // staple of staple1 in direction sig, excluding nu and rho; not saved
                computeStaple(*fat, *staple, *staple1, u, sig, nu, rho, coeff[4], false);
              }
            } //sig
          } // MIN_COEFF
        }
      } //rho
    } //nu
  }

  // make sure all kernels have finished before the scratch fields are released
  qudaDeviceSynchronize();

  delete staple;
  delete staple1;
#else
  errorQuda("Fat-link computation not enabled");
#endif
}
497 #undef MIN_COEFF
499 } // namespace quda