QUDA  v1.1.0
A library for QCD on GPUs
llfat_quda.cu
#include <cstdio>

#include <quda_internal.h>
#include <gauge_field.h>
#include <llfat_quda.h>
#include <index_helper.cuh>
#include <gauge_field_order.h>
#include <fast_intdiv.h>
#include <tune_quda.h>
#include <instantiate.h>

#define MIN_COEFF 1e-7
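// Staple coefficients whose magnitude falls below MIN_COEFF are treated as
// negligible: fatKSLink() below skips the corresponding staple terms.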

namespace quda {

  template <typename Float_, int nColor_, QudaReconstructType recon>
  struct LinkArg {
    using Float = Float_;
    static constexpr int nColor = nColor_;
    typedef typename gauge_mapper<Float, QUDA_RECONSTRUCT_NO>::type Link;
    typedef typename gauge_mapper<Float, recon, 18, QUDA_STAGGERED_PHASE_MILC>::type Gauge;

    Link link;
    Gauge u;
    Float coeff;

    unsigned int threads;

    int_fastdiv X[4];
    int_fastdiv E[4];
    int border[4];

    /** This keeps track of any parity changes that result from using a
        radius-1 extended border (the staple computations use such an
        extension, and if an odd number of dimensions are partitioned
        then we have to correct for this when computing the local index) */
    int odd_bit;

    LinkArg(GaugeField &link, const GaugeField &u, Float coeff) :
      threads(link.VolumeCB()),
      link(link),
      u(u),
      coeff(coeff)
    {
      if (u.StaggeredPhase() != QUDA_STAGGERED_PHASE_MILC && u.Reconstruct() != QUDA_RECONSTRUCT_NO)
        errorQuda("Staggered phase type %d not supported", u.StaggeredPhase());
      for (int d = 0; d < 4; d++) {
        X[d] = link.X()[d];
        E[d] = u.X()[d];
        border[d] = (E[d] - X[d]) / 2;
      }
    }
  };

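  /**
     Compute the long link in direction dir as the product of three
     consecutive links, scaled by the input coefficient:

       L_dir(x) = coeff * U_dir(x) * U_dir(x + dir) * U_dir(x + 2*dir)

     Coordinates are shifted by the border so that the extended gauge
     field u is indexed correctly.
  */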
  template <int dir, typename Arg>
  __device__ void longLinkDir(Arg &arg, int idx, int parity) {
    int x[4];
    int dx[4] = {0, 0, 0, 0};

    auto y = arg.u.coords;
    getCoords(x, idx, arg.X, parity);
    for (int d = 0; d < 4; d++) x[d] += arg.border[d];

    using Link = Matrix<complex<typename Arg::Float>, Arg::nColor>;

    Link a = arg.u(dir, linkIndex(y, x, arg.E), parity);

    dx[dir]++;
    Link b = arg.u(dir, linkIndexShift(y, x, dx, arg.E), 1 - parity);

    dx[dir]++;
    Link c = arg.u(dir, linkIndexShift(y, x, dx, arg.E), parity);
    dx[dir] -= 2;

    arg.link(dir, idx, parity) = arg.coeff * a * b * c;
  }

  template <typename Arg>
  __global__ void computeLongLink(Arg arg) {

    int idx = blockIdx.x*blockDim.x + threadIdx.x;
    int parity = blockIdx.y*blockDim.y + threadIdx.y;
    int dir = blockIdx.z*blockDim.z + threadIdx.z;
    if (idx >= arg.threads) return;
    if (dir >= 4) return;

    switch (dir) {
    case 0: longLinkDir<0>(arg, idx, parity); break;
    case 1: longLinkDir<1>(arg, idx, parity); break;
    case 2: longLinkDir<2>(arg, idx, parity); break;
    case 3: longLinkDir<3>(arg, idx, parity); break;
    }
    return;
  }

  template <typename Float, int nColor, QudaReconstructType recon>
  class LongLink : public TunableVectorYZ {
    LinkArg<Float, nColor, recon> arg;
    const GaugeField &meta;
    unsigned int minThreads() const { return arg.threads; }
    bool tuneGridDim() const { return false; }

  public:
    LongLink(const GaugeField &u, GaugeField &lng, double coeff) :
      TunableVectorYZ(2, 4),
      arg(lng, u, coeff),
      meta(lng)
    {
      strcpy(aux, meta.AuxString());
      strcat(aux, comm_dim_partitioned_string());

      apply(0);
    }

    void apply(const qudaStream_t &stream) {
      TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
      qudaLaunchKernel(computeLongLink<decltype(arg)>, tp, stream, arg);
    }

    TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), aux); }
    long long flops() const { return 2*4*arg.threads*198; }
    long long bytes() const { return 2*4*arg.threads*(3*arg.u.Bytes() + arg.link.Bytes()); }
  };

  void computeLongLink(GaugeField &lng, const GaugeField &u, double coeff)
  {
    instantiate<LongLink, ReconstructNo12>(u, lng, coeff); // u first arg so we pick its recon
  }

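  /**
     The one-link term simply rescales each link of the input field:

       F_dir(x) = coeff * U_dir(x)

     In fatKSLink() below the coefficient passed in is coeff[0] - 6.0*coeff[5].
  */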
  template <typename Arg>
  __global__ void computeOneLink(Arg arg)
  {
    int idx = blockIdx.x*blockDim.x + threadIdx.x;
    int parity = blockIdx.y*blockDim.y + threadIdx.y;
    int dir = blockIdx.z*blockDim.z + threadIdx.z;
    if (idx >= arg.threads) return;
    if (dir >= 4) return;

    auto x = arg.u.coords;
    getCoords(x, idx, arg.X, parity);
    for (int d = 0; d < 4; d++) x[d] += arg.border[d];

    using Link = Matrix<complex<typename Arg::Float>, Arg::nColor>;

    Link a = arg.u(dir, linkIndex(x, arg.E), parity);

    arg.link(dir, idx, parity) = arg.coeff * a;

    return;
  }

  template <typename Float, int nColor, QudaReconstructType recon>
  class OneLink : public TunableVectorYZ {
    LinkArg<Float, nColor, recon> arg;
    const GaugeField &meta;
    unsigned int minThreads() const { return arg.threads; }
    bool tuneGridDim() const { return false; }

  public:
    OneLink(const GaugeField &u, GaugeField &fat, double coeff) :
      TunableVectorYZ(2, 4),
      arg(fat, u, coeff),
      meta(fat)
    {
      strcpy(aux, meta.AuxString());
      strcat(aux, comm_dim_partitioned_string());

      apply(0);
    }

    void apply(const qudaStream_t &stream) {
      TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
      qudaLaunchKernel(computeOneLink<decltype(arg)>, tp, stream, arg);
    }

    TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), aux); }
    long long flops() const { return 2*4*arg.threads*18; }
    long long bytes() const { return 2*4*arg.threads*(arg.u.Bytes() + arg.link.Bytes()); }
  };

  void computeOneLink(GaugeField &fat, const GaugeField &u, double coeff)
  {
    if (u.StaggeredPhase() != QUDA_STAGGERED_PHASE_MILC && u.Reconstruct() != QUDA_RECONSTRUCT_NO)
      errorQuda("Staggered phase type %d not supported", u.StaggeredPhase());
    instantiate<OneLink, ReconstructNo12>(u, fat, coeff);
  }

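  /**
     Argument struct for the staple computation.  Three nested volumes
     are tracked: E is the fully extended local volume of the input
     field u, inner_X is the interior volume of the fat-link field, and
     X (the thread domain) sits halfway between the two, so the staples
     are evaluated on a radius-1 extension of the interior.  The border
     arrays hold the corresponding halo radii used to offset local
     coordinates into the extended field.
  */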
  template <typename Float_, int nColor_, typename Fat, typename Staple, typename Mulink, typename Gauge>
  struct StapleArg {
    using Float = Float_;
    static constexpr int nColor = nColor_;
    unsigned int threads;

    int_fastdiv X[4];
    int_fastdiv E[4];
    int border[4];

    int_fastdiv inner_X[4];
    int inner_border[4];

    /** This keeps track of any parity changes that result from using a
        radius-1 extended border (the staple computations use such an
        extension, and if an odd number of dimensions are partitioned
        then we have to correct for this when computing the local index) */
    int odd_bit;

    Gauge u;
    Fat fat;
    Staple staple;
    Mulink mulink;
    Float coeff;

    int n_mu;
    int mu_map[4];

    StapleArg(Fat fat, Staple staple, Mulink mulink, Gauge u, Float coeff,
              const GaugeField &fat_meta, const GaugeField &u_meta) :
      threads(1), fat(fat), staple(staple), mulink(mulink), u(u), coeff(coeff),
      odd_bit( (commDimPartitioned(0) + commDimPartitioned(1) +
                commDimPartitioned(2) + commDimPartitioned(3)) % 2 )
    {
      for (int d = 0; d < 4; d++) {
        X[d] = (fat_meta.X()[d] + u_meta.X()[d]) / 2;
        E[d] = u_meta.X()[d];
        border[d] = (E[d] - X[d]) / 2;
        threads *= X[d];

        inner_X[d] = fat_meta.X()[d];
        inner_border[d] = (E[d] - inner_X[d]) / 2;
      }
      threads /= 2; // account for parity in y dimension
    }
  };

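  /**
     Accumulate the upper and lower staples in the (mu, nu) plane into
     the output matrix:

       S(x) = U_nu(x) M_mu(x+nu) U_nu(x+mu)^dagger
            + U_nu(x-nu)^dagger M_mu(x-nu) U_nu(x-nu+mu)

     where U is the thin gauge field and M is the mulink field (either
     the thin links themselves or a previously computed staple field).
  */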
  template <int mu, int nu, typename Arg>
  __device__ inline void computeStaple(Matrix<complex<typename Arg::Float>, Arg::nColor> &staple, Arg &arg, int x[], int parity)
  {
    using Link = Matrix<complex<typename Arg::Float>, Arg::nColor>;
    int *y = arg.u.coords, *y_mu = arg.mulink.coords, dx[4] = {0, 0, 0, 0};

    /* Compute the upper staple:
     *
     *         mu (B)
     *       +-------+
     *   nu  |       |
     *  (A)  |       | (C)
     *       X       X
     */
    {
      /* load matrix A */
      Link a = arg.u(nu, linkIndex(y, x, arg.E), parity);

      /* load matrix B */
      dx[nu]++;
      Link b = arg.mulink(mu, linkIndexShift(y_mu, x, dx, arg.E), 1 - parity);
      dx[nu]--;

      /* load matrix C */
      dx[mu]++;
      Link c = arg.u(nu, linkIndexShift(y, x, dx, arg.E), 1 - parity);
      dx[mu]--;

      staple = a * b * conj(c);
    }

    /* Compute the lower staple:
     *
     *       X       X
     *   nu  |       |
     *  (A)  |       | (C)
     *       +-------+
     *         mu (B)
     */
    {
      /* load matrix A */
      dx[nu]--;
      Link a = arg.u(nu, linkIndexShift(y, x, dx, arg.E), 1 - parity);

      /* load matrix B */
      Link b = arg.mulink(mu, linkIndexShift(y_mu, x, dx, arg.E), 1 - parity);

      /* load matrix C */
      dx[mu]++;
      Link c = arg.u(nu, linkIndexShift(y, x, dx, arg.E), parity);
      dx[mu]--;
      dx[nu]++;

      staple = staple + conj(a) * b * c;
    }
  }

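  /**
     Driver kernel for the staple computation.  The z thread index runs
     over the n_mu directions distinct from nu (and from dir1, dir2 when
     those are given), mapped through mu_map.  The site parity is offset
     by odd_bit to compensate for the radius-1 extension when an odd
     number of dimensions are partitioned.  The accumulation into the
     fat-link field is restricted to the interior volume, while the
     staple itself (when requested) is stored on the extended volume.
  */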
  template <bool save_staple, typename Arg>
  __global__ void computeStaple(Arg arg, int nu)
  {
    int idx = blockIdx.x*blockDim.x + threadIdx.x;
    int parity = blockIdx.y*blockDim.y + threadIdx.y;
    if (idx >= arg.threads) return;

    int mu_idx = blockIdx.z*blockDim.z + threadIdx.z;
    if (mu_idx >= arg.n_mu) return;
    int mu;
    switch (mu_idx) {
    case 0: mu = arg.mu_map[0]; break;
    case 1: mu = arg.mu_map[1]; break;
    case 2: mu = arg.mu_map[2]; break;
    }

    int x[4];
    getCoords(x, idx, arg.X, (parity + arg.odd_bit) % 2);
    for (int d = 0; d < 4; d++) x[d] += arg.border[d];

    using Link = Matrix<complex<typename Arg::Float>, Arg::nColor>;
    Link staple;
    switch (mu) {
    case 0:
      switch (nu) {
      case 1: computeStaple<0,1>(staple, arg, x, parity); break;
      case 2: computeStaple<0,2>(staple, arg, x, parity); break;
      case 3: computeStaple<0,3>(staple, arg, x, parity); break;
      } break;
    case 1:
      switch (nu) {
      case 0: computeStaple<1,0>(staple, arg, x, parity); break;
      case 2: computeStaple<1,2>(staple, arg, x, parity); break;
      case 3: computeStaple<1,3>(staple, arg, x, parity); break;
      } break;
    case 2:
      switch (nu) {
      case 0: computeStaple<2,0>(staple, arg, x, parity); break;
      case 1: computeStaple<2,1>(staple, arg, x, parity); break;
      case 3: computeStaple<2,3>(staple, arg, x, parity); break;
      } break;
    case 3:
      switch (nu) {
      case 0: computeStaple<3,0>(staple, arg, x, parity); break;
      case 1: computeStaple<3,1>(staple, arg, x, parity); break;
      case 2: computeStaple<3,2>(staple, arg, x, parity); break;
      } break;
    }

    // exclude the inner halo
    if ( !(x[0] < arg.inner_border[0] || x[0] >= arg.inner_X[0] + arg.inner_border[0] ||
           x[1] < arg.inner_border[1] || x[1] >= arg.inner_X[1] + arg.inner_border[1] ||
           x[2] < arg.inner_border[2] || x[2] >= arg.inner_X[2] + arg.inner_border[2] ||
           x[3] < arg.inner_border[3] || x[3] >= arg.inner_X[3] + arg.inner_border[3]) ) {
      // convert to inner coordinates
      int inner_x[] = {x[0] - arg.inner_border[0], x[1] - arg.inner_border[1], x[2] - arg.inner_border[2], x[3] - arg.inner_border[3]};
      Link fat = arg.fat(mu, linkIndex(inner_x, arg.inner_X), parity);
      fat += arg.coeff * staple;
      arg.fat(mu, linkIndex(inner_x, arg.inner_X), parity) = fat;
    }

    if (save_staple) arg.staple(mu, linkIndex(x, arg.E), parity) = staple;
    return;
  }

  template <typename Float, typename Arg>
  class Staple : public TunableVectorYZ {
    Arg &arg;
    const GaugeField &meta;
    unsigned int minThreads() const { return arg.threads; }
    bool tuneGridDim() const { return false; }
    int nu;
    int dir1;
    int dir2;
    bool save_staple;

  public:
    Staple(Arg &arg, int nu, int dir1, int dir2, bool save_staple, const GaugeField &meta)
      : TunableVectorYZ(2, (3 - ( (dir1 > -1) ? 1 : 0 ) - ( (dir2 > -1) ? 1 : 0 ))),
        arg(arg), meta(meta), nu(nu), dir1(dir1), dir2(dir2), save_staple(save_staple)
    {
      // compute the map from the z thread index to the mu index in the kernel
      // mu != nu                (3 directions) -> n_mu = 3
      // mu != nu != rho         (2 directions) -> n_mu = 2
      // mu != nu != rho != sig  (1 direction)  -> n_mu = 1
      arg.n_mu = 3 - ( (dir1 > -1) ? 1 : 0 ) - ( (dir2 > -1) ? 1 : 0 );
      int j = 0;
      for (int i = 0; i < 4; i++) {
        if (i == nu || i == dir1 || i == dir2) continue; // skip these dimensions
        arg.mu_map[j++] = i;
      }
      assert(j == arg.n_mu);
    }

    void apply(const qudaStream_t &stream) {
      TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
      if (save_staple)
        qudaLaunchKernel(computeStaple<true, Arg>, tp, stream, arg, nu);
      else
        qudaLaunchKernel(computeStaple<false, Arg>, tp, stream, arg, nu);
    }

    TuneKey tuneKey() const {
      std::stringstream aux;
      aux << meta.AuxString() << comm_dim_partitioned_string();
      aux << ",nu=" << nu << ",dir1=" << dir1 << ",dir2=" << dir2 << ",save=" << save_staple;
      return TuneKey(meta.VolString(), typeid(*this).name(), aux.str().c_str());
    }

    void preTune() { arg.fat.save(); arg.staple.save(); }
    void postTune() { arg.fat.load(); arg.staple.load(); }

    long long flops() const {
      return 2*arg.n_mu*arg.threads*( 4*198 + 18 + 36 );
    }
    long long bytes() const {
      return arg.n_mu*2*meta.VolumeCB()*arg.fat.Bytes()*2 // fat load/store is only done on the interior
        + arg.n_mu*2*arg.threads*(4*arg.u.Bytes() + 2*arg.mulink.Bytes() + (save_staple ? arg.staple.Bytes() : 0));
    }
  };

  template <typename Float, int nColor, QudaReconstructType recon>
  struct Staple_ {
    Staple_(const GaugeField &u, GaugeField &fat, GaugeField &staple, const GaugeField &mulink,
            int nu, int dir1, int dir2, double coeff, bool save_staple)
    { // FIXME - incorporate another level of reconstruct peel off in instantiate
      typedef typename gauge_mapper<Float, QUDA_RECONSTRUCT_NO>::type L;
      typedef typename gauge_mapper<Float, recon, 18, QUDA_STAGGERED_PHASE_MILC>::type G;
      if (mulink.Reconstruct() == QUDA_RECONSTRUCT_NO) {
        StapleArg<Float, nColor, L, L, L, G> arg(L(fat), L(staple), L(mulink), G(u), coeff, fat, u);
        Staple<Float, decltype(arg)> stapler(arg, nu, dir1, dir2, save_staple, fat);
        stapler.apply(0);
      } else if (mulink.Reconstruct() == recon) {
        StapleArg<Float, nColor, L, L, G, G> arg(L(fat), L(staple), G(mulink), G(u), coeff, fat, u);
        Staple<Float, decltype(arg)> stapler(arg, nu, dir1, dir2, save_staple, fat);
        stapler.apply(0);
      } else {
        errorQuda("Reconstruct %d is not supported\n", u.Reconstruct());
      }
    }
  };

  // Compute the staple field for direction nu, excluding the directions dir1 and dir2.
  void computeStaple(GaugeField &fat, GaugeField &staple, const GaugeField &mulink, const GaugeField &u,
                     int nu, int dir1, int dir2, double coeff, bool save_staple)
  {
    instantiate<Staple_, ReconstructNo12>(u, fat, staple, mulink, nu, dir1, dir2, coeff, save_staple);
  }

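  /**
     The coeff array indexing used below follows the order in which the
     terms are applied: coeff[0] one-link, coeff[1] long (three-hop)
     link, coeff[2] three-link staple, coeff[3] five-link staple,
     coeff[4] seven-link staple, coeff[5] Lepage term (this naming
     presumes the usual MILC-style path-coefficient ordering; the code
     itself only relies on the indices).
  */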
  void longKSLink(GaugeField *lng, const GaugeField &u, const double *coeff)
  {
    computeLongLink(*lng, u, coeff[1]);
  }

  void fatKSLink(GaugeField *fat, const GaugeField &u, const double *coeff)
  {
#ifdef GPU_FATLINK
    GaugeFieldParam gParam(u);
    gParam.reconstruct = QUDA_RECONSTRUCT_NO;
    gParam.setPrecision(gParam.Precision());
    gParam.create = QUDA_NULL_FIELD_CREATE;
    auto staple = GaugeField::Create(gParam);
    auto staple1 = GaugeField::Create(gParam);

    if ( ((fat->X()[0] % 2 != 0) || (fat->X()[1] % 2 != 0) || (fat->X()[2] % 2 != 0) || (fat->X()[3] % 2 != 0))
         && (u.Reconstruct() != QUDA_RECONSTRUCT_NO)) {
      errorQuda("Reconstruct %d with an odd dimension size is not supported by the link fattening code (yet)\n",
                u.Reconstruct());
    }

    computeOneLink(*fat, u, coeff[0] - 6.0*coeff[5]);

    // Only compute the staple terms if at least one of the staple coefficients is non-negligible
    if (fabs(coeff[2]) >= MIN_COEFF || fabs(coeff[3]) >= MIN_COEFF ||
        fabs(coeff[4]) >= MIN_COEFF || fabs(coeff[5]) >= MIN_COEFF) {

      for (int nu = 0; nu < 4; nu++) {
        computeStaple(*fat, *staple, u, u, nu, -1, -1, coeff[2], 1);

        if (coeff[5] != 0.0) computeStaple(*fat, *staple, *staple, u, nu, -1, -1, coeff[5], 0);

        for (int rho = 0; rho < 4; rho++) {
          if (rho != nu) {

            computeStaple(*fat, *staple1, *staple, u, rho, nu, -1, coeff[3], 1);

            if (fabs(coeff[4]) > MIN_COEFF) {
              for (int sig = 0; sig < 4; sig++) {
                if (sig != nu && sig != rho) {
                  computeStaple(*fat, *staple, *staple1, u, sig, nu, rho, coeff[4], 0);
                }
              } // sig
            } // MIN_COEFF
          }
        } // rho
      } // nu
    }

    qudaDeviceSynchronize();

    delete staple;
    delete staple1;
#else
    errorQuda("Fat-link computation not enabled");
#endif
  }
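  /*
     Illustrative call sequence (the field setup is only sketched; the
     interface code that drives these routines performs the actual
     construction of the extended input field and the output fields):

       double path_coeff[6] = { ... };      // ordering as noted above
       GaugeField *fat = ..., *lng = ...;   // unreconstructed output fields
       const GaugeField &u = ...;           // input links, extended for the staple halo

       fatKSLink(fat, u, path_coeff);
       longKSLink(lng, u, path_coeff);
  */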

#undef MIN_COEFF

} // namespace quda