3 #include <quda_internal.h>
4 #include <gauge_field.h>
5 #include <llfat_quda.h>
6 #include <index_helper.cuh>
7 #include <gauge_field_order.h>
8 #include <fast_intdiv.h>
10 #include <instantiate.h>
// Magnitude threshold for path coefficients used by fatKSLink: the staple
// terms are only computed when at least one of coeff[2..5] reaches it, and
// the 7-link term additionally requires |coeff[4]| to exceed it.
8 #define MIN_COEFF 1e-7
/**
   Kernel argument shared by the one-link (OneLink) and long-link (LongLink)
   launchers: holds the output/input gauge accessors, the path coefficient and
   the local / extended lattice geometry.
   NOTE(review): the struct header and several member declarations are not
   visible in this excerpt; only the visible lines are described.
   @tparam Float_  floating-point precision of the accessors
   @tparam nColor_ number of colors
   @tparam recon   reconstruct type used by the Gauge accessor type
*/
16 template <typename Float_, int nColor_, QudaReconstructType recon>
19 static constexpr int nColor = nColor_;
// Accessor type with no link compression (QUDA_RECONSTRUCT_NO); presumably
// used for the output `link` field.
20 typedef typename gauge_mapper<Float, QUDA_RECONSTRUCT_NO>::type Link;
// Accessor type using `recon` and the MILC staggered-phase convention;
// presumably used for the input field `u`.
21 typedef typename gauge_mapper<Float, recon, 18, QUDA_STAGGERED_PHASE_MILC>::type Gauge;
33 /** This keeps track of any parity changes that result in using a
34 radius of 1 for the extended border (the staple computations use
35 such an extension, and if an odd number of dimensions are
36 partitioned then we have to correct for this when computing the local index */
39 LinkArg(GaugeField &link, const GaugeField &u, Float coeff) :
// One x-thread per checkerboard site of the output field.
40 threads(link.VolumeCB()),
// Compressed input links are only supported with the MILC staggered-phase convention.
45 if (u.StaggeredPhase() != QUDA_STAGGERED_PHASE_MILC && u.Reconstruct() != QUDA_RECONSTRUCT_NO)
46 errorQuda("Staggered phase type %d not supported", u.StaggeredPhase());
47 for (int d=0; d<4; d++) {
// Half the difference between the extended (E) and local (X) extents:
// presumably the halo width on each side of the extended input field.
50 border[d] = (E[d] - X[d]) / 2;
/**
   Compute the long (three-link) product in direction `dir` for checkerboard
   site `idx` of parity `parity` and store coeff * a * b * c in arg.link.
   a is read at the site itself (same parity), b at a shifted site of opposite
   parity and c at a further shifted site of the same parity. The shift vector
   dx is updated in lines not visible here; presumably b and c are one and two
   steps along `dir`.
   @param arg    kernel argument (LinkArg)
   @param idx    checkerboard site index in the local (non-extended) lattice
   @param parity site parity (0 or 1)
*/
55 template <int dir, typename Arg>
56 __device__ void longLinkDir(Arg &arg, int idx, int parity) {
58 int dx[4] = {0, 0, 0, 0};
60 auto y = arg.u.coords;
61 getCoords(x, idx, arg.X, parity);
// Move local coordinates into the extended input lattice.
62 for (int d=0; d<4; d++) x[d] += arg.border[d];
64 using Link = Matrix<complex<typename Arg::Float>, Arg::nColor>;
66 Link a = arg.u(dir, linkIndex(y, x, arg.E), parity);
69 Link b = arg.u(dir, linkIndexShift(y, x, dx, arg.E), 1-parity);
72 Link c = arg.u(dir, linkIndexShift(y, x, dx, arg.E), parity);
// The output is addressed with the local checkerboard index idx.
75 arg.link(dir, idx, parity) = arg.coeff * a * b * c;
/**
   Long-link kernel.
   Thread layout: x -> checkerboard site index (guarded by arg.threads),
   y -> parity, z -> link direction. The runtime direction is dispatched to
   the compile-time specialisation longLinkDir<dir>.
*/
78 template <typename Arg>
79 __global__ void computeLongLink(Arg arg) {
81 int idx = blockIdx.x*blockDim.x + threadIdx.x;
82 int parity = blockIdx.y*blockDim.y + threadIdx.y;
83 int dir = blockIdx.z*blockDim.z + threadIdx.z;
84 if (idx >= arg.threads) return;
88 case 0: longLinkDir<0>(arg, idx, parity); break;
89 case 1: longLinkDir<1>(arg, idx, parity); break;
90 case 2: longLinkDir<2>(arg, idx, parity); break;
91 case 3: longLinkDir<3>(arg, idx, parity); break;
/**
   Autotuned launcher for the computeLongLink kernel.
   Thread layout: x -> checkerboard sites (arg.threads), y -> the 2 parities,
   z -> the 4 link directions (TunableVectorYZ(2,4)).
   NOTE(review): several member and constructor lines are not visible in this
   excerpt; only the visible lines are documented.
*/
96 template <typename Float, int nColor, QudaReconstructType recon>
97 class LongLink : public TunableVectorYZ {
98 LinkArg<Float, nColor, recon> arg;
99 const GaugeField &meta;
// Number of x-threads required: one per checkerboard site.
100 unsigned int minThreads() const { return arg.threads; }
// Do not autotune the grid dimension.
101 bool tuneGridDim() const { return false; }
104 LongLink(const GaugeField &u, GaugeField &lng, double coeff) :
105 TunableVectorYZ(2,4),
// Tuning-key aux string: field meta data plus the communication-partitioning pattern.
109 strcpy(aux, meta.AuxString());
110 strcat(aux, comm_dim_partitioned_string());
115 void apply(const qudaStream_t &stream) {
116 TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
117 qudaLaunchKernel(computeLongLink<decltype(arg)>, tp, stream, arg);
120 TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), aux); }
// Per (parity, direction, site) the kernel evaluates coeff * a * b * c: one
// real-scalar times 3x3 complex matrix (18 flops) plus two 3x3 complex matrix
// products (198 flops each). The previous count of 198 omitted one product and
// the scaling.
// The leading 2ll forces 64-bit arithmetic. arg.threads is an unsigned integer
// (it is returned as unsigned int by minThreads), so 2*4*arg.threads*... was
// evaluated in 32 bits and wrapped for large lattices before being widened.
121 long long flops() const { return 2ll*4*arg.threads*(2*198 + 18); }
// Per (parity, direction, site): three input links read, one output link written.
122 long long bytes() const { return 2ll*4*arg.threads*(3*arg.u.Bytes()+arg.link.Bytes()); }
/**
   Fill lng with the long (three-link) products of u scaled by coeff
   (see longLinkDir). The LongLink specialisation is chosen by instantiate.
   u is passed first so that its reconstruct type is the one selected, per the
   inline comment.
   @param lng   output long-link field
   @param u     input gauge field (read through its extended halo)
   @param coeff path coefficient applied to every long link
*/
125 void computeLongLink(GaugeField &lng, const GaugeField &u, double coeff)
127 instantiate<LongLink, ReconstructNo12>(u, lng, coeff); // u first arg so we pick its recon
/**
   One-link kernel: arg.link(dir, idx) = arg.coeff * arg.u(dir, x + border).
   Thread layout: x -> checkerboard site (guarded by arg.threads),
   y -> parity, z -> link direction (guarded by dir < 4).
*/
130 template <typename Arg>
131 __global__ void computeOneLink(Arg arg)
133 int idx = blockIdx.x*blockDim.x + threadIdx.x;
134 int parity = blockIdx.y * blockDim.y + threadIdx.y;
135 int dir = blockIdx.z * blockDim.z + threadIdx.z;
136 if (idx >= arg.threads) return;
137 if (dir >= 4) return;
// The accessor's coords array is used as scratch storage for the site coordinates.
139 auto x = arg.u.coords;
140 getCoords(x, idx, arg.X, parity);
// Move local coordinates into the extended input lattice.
141 for (int d=0; d<4; d++) x[d] += arg.border[d];
143 using Link = Matrix<complex<typename Arg::Float>, Arg::nColor>;
145 Link a = arg.u(dir, linkIndex(x,arg.E), parity);
147 arg.link(dir, idx, parity) = arg.coeff*a;
/**
   Autotuned launcher for the computeOneLink kernel.
   Thread layout: x -> checkerboard sites (arg.threads), y -> the 2 parities,
   z -> the 4 link directions (TunableVectorYZ(2,4)).
   NOTE(review): several member and constructor lines are not visible in this
   excerpt; only the visible lines are documented.
*/
152 template <typename Float, int nColor, QudaReconstructType recon>
153 class OneLink : public TunableVectorYZ {
154 LinkArg<Float, nColor, recon> arg;
155 const GaugeField &meta;
// Number of x-threads required: one per checkerboard site.
156 unsigned int minThreads() const { return arg.threads; }
// Do not autotune the grid dimension.
157 bool tuneGridDim() const { return false; }
160 OneLink(const GaugeField &u, GaugeField &fat, double coeff) :
161 TunableVectorYZ(2,4),
// Tuning-key aux string: field meta data plus the communication-partitioning pattern.
165 strcpy(aux, meta.AuxString());
166 strcat(aux, comm_dim_partitioned_string());
171 void apply(const qudaStream_t &stream) {
172 TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
173 qudaLaunchKernel(computeOneLink<decltype(arg)>, tp, stream, arg);
176 TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), aux); }
// Per (parity, direction, site): one real-scalar times 3x3 complex matrix (18 flops).
// The leading 2ll forces 64-bit arithmetic. arg.threads is an unsigned integer,
// so 2*4*arg.threads*18 was evaluated in 32 bits and wrapped for large lattices.
177 long long flops() const { return 2ll*4*arg.threads*18; }
// Per (parity, direction, site): one input link read, one output link written.
178 long long bytes() const { return 2ll*4*arg.threads*(arg.u.Bytes()+arg.link.Bytes()); }
/**
   Overwrite fat with coeff * u (the one-link term), addressed on fat's local lattice.
   @param fat   output field
   @param u     input gauge field (read through its extended halo)
   @param coeff one-link coefficient
*/
181 void computeOneLink(GaugeField &fat, const GaugeField &u, double coeff)
// Same guard as in the LinkArg constructor, checked here before instantiation:
// compressed input links require the MILC staggered-phase convention.
183 if (u.StaggeredPhase() != QUDA_STAGGERED_PHASE_MILC && u.Reconstruct() != QUDA_RECONSTRUCT_NO)
184 errorQuda("Staggered phase type %d not supported", u.StaggeredPhase());
185 instantiate<OneLink, ReconstructNo12>(u, fat, coeff);
/**
   Kernel argument for the staple kernels.
   @tparam Fat     accessor type of the fat-link output (updated in place)
   @tparam Staple  accessor type of the optional saved-staple output
   @tparam Mulink  accessor type of the links placed in the middle of each
                   staple (read in the mu direction)
   @tparam Gauge   accessor type of the input gauge field u
   NOTE(review): the struct header and several member declarations are not
   visible in this excerpt.
*/
188 template <typename Float_, int nColor_, typename Fat, typename Staple, typename Mulink, typename Gauge>
190 using Float = Float_;
191 static constexpr int nColor = nColor_;
192 unsigned int threads;
// Local (non-extended) dimensions of the fat field; int_fastdiv presumably
// speeds up the divisions performed when computing indices on this lattice.
198 int_fastdiv inner_X[4];
201 /** This keeps track of any parity changes that result in using a
202 radius of 1 for the extended border (the staple computations use
203 such an extension, and if an odd number of dimensions are
204 partitioned then we have to correct for this when computing the local index */
216 StapleArg(Fat fat, Staple staple, Mulink mulink, Gauge u, Float coeff,
217 const GaugeField &fat_meta, const GaugeField &u_meta) :
218 threads(1), fat(fat), staple(staple), mulink(mulink), u(u), coeff(coeff),
// odd_bit = number of partitioned dimensions modulo 2.
219 odd_bit( (commDimPartitioned(0)+commDimPartitioned(1) +
220 commDimPartitioned(2)+commDimPartitioned(3))%2 )
222 for (int d=0; d<4; d++) {
// Working lattice: halfway between the fat (local) and u (extended) extents.
223 X[d] = (fat_meta.X()[d] + u_meta.X()[d]) / 2;
// Extended extents of u, and the offset of the working lattice inside them.
224 E[d] = u_meta.X()[d];
225 border[d] = (E[d] - X[d]) / 2;
// Local fat-field extents, and their offset inside the extended lattice.
228 inner_X[d] = fat_meta.X()[d];
229 inner_border[d] = (E[d] - inner_X[d]) / 2;
// NOTE(review): threads presumably accumulates the product of X[d] in the loop
// (that line is not visible here).
231 threads /= 2; // account for parity in y dimension
/**
   Set `staple` to the sum of the upper and lower staples in the mu-nu plane
   at extended-lattice site x with the given parity:
     upper = a * b * conj(c), a = u(nu) at x, b = mulink(mu), c = u(nu)
     lower = conj(a) * b * c, a = u(nu), b = mulink(mu), c = u(nu)
   The shifted sites are chosen via dx, whose updates are not visible in this
   excerpt. Presumably they are x+nu / x+mu for the upper staple and x-nu /
   x-nu+mu for the lower staple, consistent with the parities used.
   @param staple output matrix (overwritten, not accumulated into)
   @param arg    kernel argument (StapleArg)
   @param x      site coordinates on the extended lattice arg.E
   @param parity site parity
*/
235 template <int mu, int nu, typename Arg>
236 __device__ inline void computeStaple(Matrix<complex<typename Arg::Float>, Arg::nColor> &staple, Arg &arg, int x[], int parity)
238 using Link = Matrix<complex<typename Arg::Float>, Arg::nColor>;
239 int *y = arg.u.coords, *y_mu = arg.mulink.coords, dx[4] = {0, 0, 0, 0};
241 /* Computes the upper staple :
250 Link a = arg.u(nu, linkIndex(y, x, arg.E), parity);
254 Link b = arg.mulink(mu, linkIndexShift(y_mu, x, dx, arg.E), 1-parity);
259 Link c = arg.u(nu, linkIndexShift(y, x, dx, arg.E), 1-parity);
262 staple = a * b * conj(c);
265 /* Computes the lower staple :
275 Link a = arg.u(nu, linkIndexShift(y, x, dx, arg.E), 1-parity);
278 Link b = arg.mulink(mu, linkIndexShift(y_mu, x, dx, arg.E), 1-parity);
282 Link c = arg.u(nu, linkIndexShift(y, x, dx, arg.E), parity);
// Add the lower staple to the upper one.
286 staple = staple + conj(a)*b*c;
// Staple kernel.
// Thread layout: x -> site of the working lattice arg.X (guarded by arg.threads),
// y -> parity, z -> index into arg.mu_map (guarded by arg.n_mu).
// For mu = arg.mu_map[z] it computes the mu-nu staple (device computeStaple).
// For sites inside the inner (fat-field) region it adds arg.coeff * staple to
// arg.fat. When save_staple is true it also stores the staple at the
// extended-lattice index.
290 template <bool save_staple, typename Arg>
291 __global__ void computeStaple(Arg arg, int nu)
293 int idx = blockIdx.x*blockDim.x + threadIdx.x;
294 int parity = blockIdx.y*blockDim.y + threadIdx.y;
295 if (idx >= arg.threads) return;
297 int mu_idx = blockIdx.z*blockDim.z + threadIdx.z;
298 if (mu_idx >= arg.n_mu) return;
301 case 0: mu = arg.mu_map[0]; break;
302 case 1: mu = arg.mu_map[1]; break;
303 case 2: mu = arg.mu_map[2]; break;
// Coordinates use the parity corrected by odd_bit (see StapleArg), then are
// moved into the extended lattice of u.
307 getCoords(x, idx, arg.X, (parity+arg.odd_bit)%2);
308 for (int d=0; d<4; d++) x[d] += arg.border[d];
310 using Link = Matrix<complex<typename Arg::Float>, Arg::nColor>;
// Compile-time dispatch on (mu, nu); there is no entry for mu == nu.
315 case 1: computeStaple<0,1>(staple, arg, x, parity); break;
316 case 2: computeStaple<0,2>(staple, arg, x, parity); break;
317 case 3: computeStaple<0,3>(staple, arg, x, parity); break;
321 case 0: computeStaple<1,0>(staple, arg, x, parity); break;
322 case 2: computeStaple<1,2>(staple, arg, x, parity); break;
323 case 3: computeStaple<1,3>(staple, arg, x, parity); break;
327 case 0: computeStaple<2,0>(staple, arg, x, parity); break;
328 case 1: computeStaple<2,1>(staple, arg, x, parity); break;
329 case 3: computeStaple<2,3>(staple, arg, x, parity); break;
333 case 0: computeStaple<3,0>(staple, arg, x, parity); break;
334 case 1: computeStaple<3,1>(staple, arg, x, parity); break;
335 case 2: computeStaple<3,2>(staple, arg, x, parity); break;
339 // exclude inner halo
340 if ( !(x[0] < arg.inner_border[0] || x[0] >= arg.inner_X[0] + arg.inner_border[0] ||
341 x[1] < arg.inner_border[1] || x[1] >= arg.inner_X[1] + arg.inner_border[1] ||
342 x[2] < arg.inner_border[2] || x[2] >= arg.inner_X[2] + arg.inner_border[2] ||
343 x[3] < arg.inner_border[3] || x[3] >= arg.inner_X[3] + arg.inner_border[3]) ) {
344 // convert to inner coords
345 int inner_x[] = {x[0]-arg.inner_border[0], x[1]-arg.inner_border[1], x[2]-arg.inner_border[2], x[3]-arg.inner_border[3]};
// Read-modify-write of the fat link: fat += coeff * staple.
346 Link fat = arg.fat(mu, linkIndex(inner_x, arg.inner_X), parity);
347 fat += arg.coeff * staple;
348 arg.fat(mu, linkIndex(inner_x, arg.inner_X), parity) = fat;
// The saved staple is indexed on the extended dimensions arg.E (halo included).
351 if (save_staple) arg.staple(mu, linkIndex(x, arg.E), parity) = staple;
// Autotuned launcher for the computeStaple kernel.
// Thread layout: x -> sites of the working lattice (arg.threads), y -> the 2
// parities, z -> the arg.n_mu directions mu not equal to nu, dir1 or dir2.
// NOTE(review): some members (arg, nu, dir1, dir2, save_staple) and parts of
// the mu_map loop are not visible in this excerpt.
355 template <typename Float, typename Arg>
356 class Staple : public TunableVectorYZ {
358 const GaugeField &meta;
// One x-thread per site handled by the kernel.
359 unsigned int minThreads() const { return arg.threads; }
// Do not autotune the grid dimension.
360 bool tuneGridDim() const { return false; }
// dir1/dir2 are additional excluded directions (-1 means none). The z vector
// size passed to TunableVectorYZ is the same expression as arg.n_mu below.
367 Staple(Arg &arg, int nu, int dir1, int dir2, bool save_staple, const GaugeField &meta)
368 : TunableVectorYZ(2,(3 - ( (dir1 > -1) ? 1 : 0 ) - ( (dir2 > -1) ? 1 : 0 ))),
369 arg(arg), meta(meta), nu(nu), dir1(dir1), dir2(dir2), save_staple(save_staple)
371 // compute the map for z thread index to mu index in the kernel
372 // mu != nu 3 -> n_mu = 3
373 // mu != nu != rho 2 -> n_mu = 2
374 // mu != nu != rho != sig 1 -> n_mu = 1
375 arg.n_mu = 3 - ( (dir1 > -1) ? 1 : 0 ) - ( (dir2 > -1) ? 1 : 0 );
377 for (int i=0; i<4; i++) {
378 if (i==nu || i==dir1 || i==dir2) continue; // skip these dimensions
// j counts the entries written to arg.mu_map (declared/updated in lines not visible).
381 assert(j == arg.n_mu);
// Launch the variant that also stores the staple when saving is requested
// (the selecting condition is not visible in this excerpt).
384 void apply(const qudaStream_t &stream) {
385 TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
387 qudaLaunchKernel(computeStaple<true, Arg>, tp, stream, arg, nu);
389 qudaLaunchKernel(computeStaple<false, Arg>, tp, stream, arg, nu);
392 TuneKey tuneKey() const {
393 std::stringstream aux;
394 aux << meta.AuxString() << comm_dim_partitioned_string();
395 aux << ",nu=" << nu << ",dir1=" << dir1 << ",dir2=" << dir2 << ",save=" << save_staple;
396 return TuneKey(meta.VolString(), typeid(*this).name(), aux.str().c_str());
// The kernel accumulates into arg.fat in place and may overwrite arg.staple,
// so both are saved before tuning and restored afterwards. Otherwise every
// tuning launch would add the staple contribution again.
399 void preTune() { arg.fat.save(); arg.staple.save(); }
400 void postTune() { arg.fat.load(); arg.staple.load(); }
// Per (mu, parity, site): four 3x3 complex matrix products (198 flops each), one
// matrix addition joining upper and lower staple (18) and fat += coeff * staple (36).
// The leading 2ll forces 64-bit arithmetic. arg.threads is unsigned int, so the
// original 2*arg.n_mu*arg.threads*846 was evaluated in 32 bits. For n_mu = 3 it
// wrapped once arg.threads exceeded about 846 thousand sites.
402 long long flops() const {
403 return 2ll*arg.n_mu*arg.threads*( 4*198 + 18 + 36 );
// Fat links are read and written only for the meta.VolumeCB() interior sites.
// Per kernel thread: four u links, two mulink links and, optionally, one stored staple.
405 long long bytes() const {
406 return 2ll*arg.n_mu*meta.VolumeCB()*arg.fat.Bytes()*2 // fat load/store is only done on interior
407 + 2ll*arg.n_mu*arg.threads*(4*arg.u.Bytes() + 2*arg.mulink.Bytes() + (save_staple ? arg.staple.Bytes() : 0));
// Constructor of the Staple_ dispatcher (its class header is not visible here).
// It builds a StapleArg whose accessor types match u and mulink, then runs the
// Staple launcher.
// - fat and staple are always accessed uncompressed (L).
// - u uses the instantiated reconstruct type `recon` with MILC staggered phases (G).
// - mulink may be either uncompressed or use the same reconstruct type as u.
411 template <typename Float, int nColor, QudaReconstructType recon>
413 Staple_(const GaugeField &u, GaugeField &fat, GaugeField &staple, const GaugeField &mulink,
414 int nu, int dir1, int dir2, double coeff, bool save_staple)
415 { // FIXME - incorporate another level of reconstruct peel off in instantiate
416 typedef typename gauge_mapper<Float,QUDA_RECONSTRUCT_NO>::type L;
417 typedef typename gauge_mapper<Float,recon,18,QUDA_STAGGERED_PHASE_MILC>::type G;
418 if (mulink.Reconstruct() == QUDA_RECONSTRUCT_NO) {
419 StapleArg<Float, nColor, L, L, L, G> arg(L(fat), L(staple), L(mulink), G(u), coeff, fat, u);
420 Staple<Float,decltype(arg)> stapler(arg, nu, dir1, dir2, save_staple, fat);
422 } else if (mulink.Reconstruct() == recon) {
423 StapleArg<Float, nColor, L, L, G, G> arg(L(fat), L(staple), G(mulink), G(u), coeff, fat, u);
424 Staple<Float,decltype(arg)> stapler(arg, nu, dir1, dir2, save_staple, fat);
// Only mulink's reconstruct type can be unsupported here (u's type is the one
// this specialisation was instantiated for), so report mulink's type, not u's.
427 errorQuda("Reconstruct %d is not supported\n", mulink.Reconstruct());
432 // Compute the staples with fixed direction nu for every link direction mu not in {nu, dir1, dir2}
// (dir1/dir2 = -1 means "not excluded"), and add coeff * staple to fat.
// mulink provides the mu-direction links in the middle of each staple.
// When save_staple is true, the staples are also written to `staple` on the extended lattice.
// u is passed first to instantiate, so u's reconstruct type selects the specialisation.
433 void computeStaple(GaugeField &fat, GaugeField &staple, const GaugeField &mulink, const GaugeField &u,
434 int nu, int dir1, int dir2, double coeff, bool save_staple)
436 instantiate<Staple_, ReconstructNo12>(u, fat, staple, mulink, nu, dir1, dir2, coeff, save_staple);
// Compute the long (three-link) field lng from u, scaled by coeff[1].
// Only coeff[1] of the path-coefficient array is used here.
439 void longKSLink(GaugeField *lng, const GaugeField &u, const double *coeff)
441 computeLongLink(*lng, u, coeff[1]);
// Build the fat links in *fat from the (extended) gauge field u and the path coefficients:
//   coeff[0] - 6*coeff[5] : one-link term
//   coeff[2]              : 3-link staples
//   coeff[5]              : staples built from the saved 3-link staples in the
//                           same plane (presumably the Lepage term)
//   coeff[3]              : 5-link staples (3-link staple as middle link)
//   coeff[4]              : 7-link staples (5-link staple as middle link)
// Two uncompressed temporary fields hold the intermediate staples.
444 void fatKSLink(GaugeField *fat, const GaugeField& u, const double *coeff)
447 GaugeFieldParam gParam(u);
448 gParam.reconstruct = QUDA_RECONSTRUCT_NO;
// Re-apply the same precision; presumably so that precision-dependent settings
// are recomputed after changing the reconstruct type (TODO confirm).
449 gParam.setPrecision(gParam.Precision());
450 gParam.create = QUDA_NULL_FIELD_CREATE;
451 auto staple = GaugeField::Create(gParam);
452 auto staple1 = GaugeField::Create(gParam);
// Compressed u with an odd local extent in any dimension is rejected.
454 if ( ((fat->X()[0] % 2 != 0) || (fat->X()[1] % 2 != 0) || (fat->X()[2] % 2 != 0) || (fat->X()[3] % 2 != 0))
455 && (u.Reconstruct() != QUDA_RECONSTRUCT_NO)){
456 errorQuda("Reconstruct %d and odd dimensionsize is not supported by link fattening code (yet)\n",
// Initialise fat with the one-link term; the staple kernels then accumulate into it.
460 computeOneLink(*fat, u, coeff[0]-6.0*coeff[5]);
462 // Skip all staple terms unless at least one of coeff[2..5] is non-negligible (>= MIN_COEFF).
463 if (fabs(coeff[2]) >= MIN_COEFF || fabs(coeff[3]) >= MIN_COEFF ||
464 fabs(coeff[4]) >= MIN_COEFF || fabs(coeff[5]) >= MIN_COEFF) {
466 for (int nu = 0; nu < 4; nu++) {
// 3-link staples; saved in *staple for the terms below.
467 computeStaple(*fat, *staple, u, u, nu, -1, -1, coeff[2], 1);
// Staple of the saved staple in the same plane; not saved, so *staple is preserved.
469 if (coeff[5] != 0.0) computeStaple(*fat, *staple, *staple, u, nu, -1, -1, coeff[5], 0);
// NOTE(review): a rho != nu guard is presumably in lines not visible here.
471 for (int rho = 0; rho < 4; rho++) {
// 5-link staples (excluding nu), saved in *staple1.
474 computeStaple(*fat, *staple1, *staple, u, rho, nu, -1, coeff[3], 1);
476 if (fabs(coeff[4]) > MIN_COEFF) {
477 for (int sig = 0; sig < 4; sig++) {
478 if (sig != nu && sig != rho) {
// 7-link staples (excluding nu and rho). save_staple = 0, so *staple (still
// needed for later rho iterations) is not overwritten.
479 computeStaple(*fat, *staple, *staple1, u, sig, nu, rho, coeff[4], 0);
488 qudaDeviceSynchronize();
// Presumably the branch taken when fat-link support is compiled out (the
// guarding preprocessor lines are not visible here).
493 errorQuda("Fat-link computation not enabled");