// NOTE(review): this excerpt keeps the original file's line numbers as a leading
// token on many lines and skips some lines, so only the visible parts are described.
//
// LinkArg: argument bundle used by both the one-link and the long-link kernels
// (it is instantiated for LongLink and OneLink further down in this file).
2 #include <cuda_runtime.h> 13 #define MIN_COEFF 1e-7 19 template <
typename Float,
typename Link,
typename Gauge>
// Constructor: stores the output accessor (link), the input accessor (u) and the
// coefficient; link_meta / u_meta supply the geometry of the two fields.
37 LinkArg(Link link, Gauge u, Float coeff,
const GaugeField &link_meta,
const GaugeField &u_meta)
// threads is taken from link_meta.VolumeCB()
38 : threads(link_meta.VolumeCB()), link(link), u(u), coeff(coeff)
40 for (
int d=0; d<4; d++) {
// X[d]: extent of the output field in dimension d
41 X[d] = link_meta.X()[d];
// border[d]: half the difference between E[d] and X[d].
// NOTE(review): the assignment to E[d] is not visible in this excerpt; StapleArg
// sets it from u_meta.X()[d], presumably the same happens here -- TODO confirm.
43 border[d] = (E[d] - X[d]) / 2;
// longLinkDir: device helper that writes one long link for the compile-time
// direction `dir` at checkerboard index `idx` and the given `parity`.
// The result stored is arg.coeff * a * b * c (see the last visible statement).
// NOTE(review): several lines of the body (declaration of x, loading of b and c,
// any use of dx) are not visible in this excerpt.
48 template <
typename Float,
int dir,
typename Arg>
49 __device__
void longLinkDir(Arg &
arg,
int idx,
int parity) {
// displacement vector, initialised to zero
51 int dx[4] = {0, 0, 0, 0};
// y aliases the coords member of the input accessor
53 int *y = arg.u.coords;
// shift every coordinate by arg.border[d] before indexing with arg.E
55 for (
int d=0; d<4; d++) x[d] += arg.border[d];
// a link is a 3x3 matrix of complex<Float>
57 typedef Matrix<complex<Float>,3> Link;
// first factor: the `dir` link read at index linkIndex(y, x, arg.E)
59 Link a = arg.u(dir,
linkIndex(y, x, arg.E), parity);
// store the scaled ordered product of the three factors into the output field
68 arg.link(dir, idx, parity) = arg.coeff * a * b * c;
// computeLongLink (kernel): maps the launch x dimension to the site index, the
// y dimension to parity and the z dimension to the link direction, then
// dispatches to longLinkDir<Float, dir> for dir = 0..3.
71 template <
typename Float,
typename Arg>
72 __global__
void computeLongLink(Arg arg) {
74 int idx = blockIdx.x*blockDim.x + threadIdx.x;
75 int parity = blockIdx.y*blockDim.y + threadIdx.y;
76 int dir = blockIdx.z*blockDim.z + threadIdx.z;
// threads beyond arg.threads in x have nothing to do
77 if (idx >= arg.threads)
return;
// NOTE(review): no `dir >= 4` guard is visible here, unlike computeOneLink;
// the lines between this guard and the dispatch are not shown -- TODO confirm.
// Dispatch on dir so that the direction is a template (compile-time) argument.
81 case 0: longLinkDir<Float, 0>(
arg, idx,
parity);
break;
82 case 1: longLinkDir<Float, 1>(
arg, idx,
parity);
break;
83 case 2: longLinkDir<Float, 2>(
arg, idx,
parity);
break;
84 case 3: longLinkDir<Float, 3>(
arg, idx,
parity);
break;
// LongLink: tunable launcher for the computeLongLink kernel.
// The base class is constructed as TunableVectorYZ(2,4); the kernel reads parity
// from the y dimension and direction from the z dimension.
89 template <
typename Float,
typename Arg>
90 class LongLink :
public TunableVectorYZ {
// field whose VolString() is used in the tune key
92 const GaugeField &meta;
// minimum number of threads reported to the tuner: arg.threads
93 unsigned int minThreads()
const {
return arg.threads; }
// the grid dimension is not tuned
94 bool tuneGridDim()
const {
return false; }
97 LongLink(Arg &arg,
const GaugeField &meta) : TunableVectorYZ(2,4),
arg(arg), meta(meta) {}
98 virtual ~LongLink() {}
// apply: launches the kernel with the grid, block and shared-memory size in tp.
// NOTE(review): tp is defined on a line not shown; the `stream` parameter is not
// passed to the launch in the visible line -- confirm whether that is intended.
100 void apply(
const cudaStream_t &
stream) {
102 computeLongLink<Float><<<tp.grid,tp.block,tp.shared_bytes>>>(
arg);
// tuneKey: key built from meta.VolString(), the dynamic type name, and an aux
// string "threads=<arg.threads>,prec=<sizeof(Float)>"
105 TuneKey tuneKey()
const {
106 std::stringstream aux;
107 aux <<
"threads=" << arg.threads <<
",prec=" <<
sizeof(Float);
108 return TuneKey(meta.VolString(),
typeid(*this).name(), aux.str().c_str());
// flops: 2*4*arg.threads*198 (presumably 2 parities x 4 directions -- TODO confirm)
111 long long flops()
const {
return 2*4*arg.threads*198; }
// bytes: per output, three u.Bytes() plus one link.Bytes(), with the same 2*4*threads factor
112 long long bytes()
const {
return 2*4*arg.threads*(3*arg.u.Bytes()+arg.link.Bytes()); }
// computeLongLink (host): builds the accessor types and LinkArg for the long-link
// computation writing into `lng` from `u` with coefficient `coeff`.
// NOTE(review): the if/else conditions selecting each branch are not visible in
// this excerpt; judging by the error messages they depend on u.Precision(),
// u.Reconstruct() and u.StaggeredPhase() -- TODO confirm.
115 void computeLongLink(GaugeField &lng,
const GaugeField &u,
double coeff)
// --- double precision: output accessor L uses QUDA_RECONSTRUCT_NO ---
118 typedef typename gauge_mapper<double,QUDA_RECONSTRUCT_NO>::type L;
// input also accessed without reconstruction
120 typedef LinkArg<double,L,L> Arg;
121 Arg
arg(L(lng), L(u), coeff, lng, u);
122 LongLink<double,Arg> longLink(arg,lng);
// input accessed as QUDA_RECONSTRUCT_12 with QUDA_STAGGERED_PHASE_MILC
126 typedef typename gauge_mapper<double, QUDA_RECONSTRUCT_12, 18, QUDA_STAGGERED_PHASE_MILC>::type G;
127 typedef LinkArg<double, L, G> Arg;
128 Arg
arg(L(lng), G(u), coeff, lng, u);
129 LongLink<double, Arg> longLink(arg, lng);
132 errorQuda(
"Staggered phase type %d not supported", u.StaggeredPhase());
135 errorQuda(
"Reconstruct %d is not supported\n", u.Reconstruct());
// --- single precision: same structure with float accessors ---
138 typedef typename gauge_mapper<float,QUDA_RECONSTRUCT_NO>::type L;
140 typedef LinkArg<float,L,L> Arg;
141 Arg
arg(L(lng), L(u), coeff, lng, u) ;
142 LongLink<float,Arg> longLink(arg,lng);
146 typedef typename gauge_mapper<float, QUDA_RECONSTRUCT_12, 18, QUDA_STAGGERED_PHASE_MILC>::type G;
147 typedef LinkArg<float, L, G> Arg;
148 Arg
arg(L(lng), G(u), coeff, lng, u);
149 LongLink<float, Arg> longLink(arg, lng);
152 errorQuda(
"Staggered phase type %d not supported", u.StaggeredPhase());
155 errorQuda(
"Reconstruct %d is not supported\n", u.Reconstruct());
// any other precision is rejected
158 errorQuda(
"Unsupported precision %d\n", u.Precision());
// computeOneLink (kernel): writes arg.coeff times the input link into the output
// field for one (site index, parity, direction) triple per thread.
// Launch x dimension -> idx, y -> parity, z -> dir.
163 template <
typename Float,
typename Arg>
164 __global__
void computeOneLink(Arg arg) {
166 int idx = blockIdx.x*blockDim.x + threadIdx.x;
167 int parity = blockIdx.y * blockDim.y + threadIdx.y;
168 int dir = blockIdx.z * blockDim.z + threadIdx.z;
// discard threads outside the valid site and direction ranges
169 if (idx >= arg.threads)
return;
170 if (dir >= 4)
return;
// x aliases the coords member of the input accessor
// NOTE(review): the line that fills x from idx is not visible in this excerpt.
172 int *x = arg.u.coords;
// shift every coordinate by arg.border[d] before indexing with arg.E
174 for (
int d=0; d<4; d++) x[d] += arg.border[d];
// a link is a 3x3 matrix of complex<Float>
176 typedef Matrix<complex<Float>,3> Link;
// read the input link at the shifted coordinate
178 Link a = arg.u(dir,
linkIndex(x,x,arg.E), parity);
// output = coeff * input link
180 arg.link(dir, idx, parity) = arg.coeff*a;
// OneLink: tunable launcher for the computeOneLink kernel.
// The base class is constructed as TunableVectorYZ(2,4); the kernel reads parity
// from the y dimension and direction from the z dimension.
185 template <
typename Float,
typename Arg>
186 class OneLink :
public TunableVectorYZ {
// field whose VolString() is used in the tune key
188 const GaugeField &meta;
// minimum number of threads reported to the tuner: arg.threads
189 unsigned int minThreads()
const {
return arg.threads; }
// the grid dimension is not tuned
190 bool tuneGridDim()
const {
return false; }
193 OneLink(Arg &arg,
const GaugeField &meta) : TunableVectorYZ(2,4),
arg(arg), meta(meta) {}
194 virtual ~OneLink() {}
// apply: launches the kernel with the grid and block in tp.
// NOTE(review): unlike LongLink::apply this launch does not pass tp.shared_bytes,
// and the `stream` parameter is not passed either -- confirm whether intended.
196 void apply(
const cudaStream_t &
stream) {
198 computeOneLink<Float><<<tp.grid,tp.block>>>(
arg);
// tuneKey: key built from meta.VolString(), the dynamic type name, and an aux
// string "threads=<arg.threads>,prec=<sizeof(Float)>"
201 TuneKey tuneKey()
const {
202 std::stringstream aux;
203 aux <<
"threads=" << arg.threads <<
",prec=" <<
sizeof(Float);
204 return TuneKey(meta.VolString(),
typeid(*this).name(), aux.str().c_str());
// flops: 2*4*arg.threads*18 (presumably 2 parities x 4 directions -- TODO confirm)
207 long long flops()
const {
return 2*4*arg.threads*18; }
// bytes: per output, one u.Bytes() plus one link.Bytes(), with the same 2*4*threads factor
208 long long bytes()
const {
return 2*4*arg.threads*(arg.u.Bytes()+arg.link.Bytes()); }
// computeOneLink (host): builds the accessor types and LinkArg for the one-link
// term written into `fat` from `u` with coefficient `coeff`.
// NOTE(review): the if/else conditions selecting each branch are not visible in
// this excerpt; judging by the error messages they depend on u.Precision() and
// u.Reconstruct() -- TODO confirm.
211 void computeOneLink(GaugeField &fat,
const GaugeField &u,
double coeff)
// --- double precision: output accessor L uses QUDA_RECONSTRUCT_NO ---
214 typedef typename gauge_mapper<double,QUDA_RECONSTRUCT_NO>::type L;
// input also accessed without reconstruction
216 typedef LinkArg<double,L,L> Arg;
217 Arg
arg(L(fat), L(u), coeff, fat, u);
218 OneLink<double,Arg> oneLink(arg,fat);
// input accessed as QUDA_RECONSTRUCT_12 with QUDA_STAGGERED_PHASE_MILC
221 typedef typename gauge_mapper<double,QUDA_RECONSTRUCT_12,18,QUDA_STAGGERED_PHASE_MILC>::type G;
222 typedef LinkArg<double,L,G> Arg;
223 Arg
arg(L(fat), G(u), coeff, fat, u);
224 OneLink<double,Arg> oneLink(arg,fat);
227 errorQuda(
"Reconstruct %d is not supported\n", u.Reconstruct());
// --- single precision: same structure with float accessors ---
230 typedef typename gauge_mapper<float,QUDA_RECONSTRUCT_NO>::type L;
232 typedef LinkArg<float,L,L> Arg;
233 Arg
arg(L(fat), L(u), coeff, fat, u);
234 OneLink<float,Arg> oneLink(arg,fat);
237 typedef typename gauge_mapper<float,QUDA_RECONSTRUCT_12,18,QUDA_STAGGERED_PHASE_MILC>::type G;
238 typedef LinkArg<float,L,G> Arg;
239 Arg
arg(L(fat), G(u), coeff, fat, u);
240 OneLink<float,Arg> oneLink(arg,fat);
243 errorQuda(
"Reconstruct %d is not supported\n", u.Reconstruct());
// any other precision is rejected
246 errorQuda(
"Unsupported precision %d\n", u.Precision());
// StapleArg: argument bundle for the staple kernel. It carries accessors for the
// fat-link output, the staple output, the "mulink" input and the gauge field u,
// plus the coefficient and the geometry derived in the constructor below.
// NOTE(review): most member declarations are not visible in this excerpt.
251 template <
typename Float,
typename Fat,
typename Staple,
typename Mulink,
typename Gauge>
// number of threads reported by Staple::minThreads()
253 unsigned int threads;
277 StapleArg(Fat fat, Staple staple, Mulink mulink, Gauge u, Float coeff,
278 const GaugeField &fat_meta,
const GaugeField &u_meta)
// threads starts at 1
// NOTE(review): presumably multiplied up by the X[d] extents on a line not shown -- TODO confirm.
279 : threads(1), fat(fat), staple(staple), mulink(mulink), u(u), coeff(coeff),
282 for (
int d=0; d<4; d++) {
// X[d]: midpoint between the fat-field extent and the u-field extent
283 X[d] = (fat_meta.X()[d] + u_meta.X()[d]) / 2;
// E[d]: extent of the u field
284 E[d] = u_meta.X()[d];
// border[d]: half the difference between E[d] and X[d]
285 border[d] = (E[d] - X[d]) / 2;
// inner_X[d]: extent of the fat field; inner_border[d]: its offset inside E[d]
288 inner_X[d] = fat_meta.X()[d];
289 inner_border[d] = (E[d] - inner_X[d]) / 2;
// computeStaple (device helper): accumulates into `staple` two terms built from
// three link matrices each, for compile-time directions mu and nu at site x[].
// Visible assignments: staple = a * b * conj(c), then staple += conj(a) * b * c.
// NOTE(review): the loads of b and c, the reloads between the two terms and the
// Link typedef are on lines not shown in this excerpt.
295 template<
typename Float,
int mu,
int nu,
typename Arg>
296 __device__
inline void computeStaple(
Matrix<complex<Float>,3> &staple, Arg &arg,
int x[],
int parity) {
// y / y_mu alias the coords members of the u and mulink accessors; dx is a zeroed displacement
298 int *y = arg.u.coords, *y_mu = arg.mulink.coords, dx[4] = {0, 0, 0, 0};
// first factor: the nu link of u read at index linkIndex(y, x, arg.E)
309 Link a = arg.u(nu,
linkIndex(y, x, arg.E), parity);
// first term of the staple
321 staple = a * b *
conj(c);
// second term, added to the first
345 staple = staple +
conj(a)*b*c;
// computeStaple (kernel): one thread per (site index, parity, mu index).
// It selects mu from arg.mu_map, computes the staple through the device helper
// computeStaple<Float,mu,nu>, adds arg.coeff * staple to the fat link when the
// site lies inside the inner region, and optionally stores the staple itself.
// NOTE(review): the kernel signature, the switch headers and the declarations of
// mu, x and staple are on lines not shown in this excerpt.
349 template<
typename Float,
bool save_staple,
typename Arg>
// launch x dimension -> site index, y dimension -> parity
352 int idx = blockIdx.x*blockDim.x + threadIdx.x;
353 int parity = blockIdx.y*blockDim.y + threadIdx.y;
354 if (idx >= arg.threads)
return;
// launch z dimension -> index into arg.mu_map; only arg.n_mu entries are valid
356 int mu_idx = blockIdx.z*blockDim.z + threadIdx.z;
357 if (mu_idx >= arg.n_mu)
return;
// look up the actual direction mu for this mu_idx
360 case 0: mu = arg.mu_map[0];
break;
361 case 1: mu = arg.mu_map[1];
break;
362 case 2: mu = arg.mu_map[2];
break;
// recover the coordinates from idx on the arg.X lattice (parity offset by
// arg.odd_bit), then shift them by arg.border[d]
366 getCoords(x, idx, arg.X, (parity+arg.odd_bit)%2);
367 for (
int d=0; d<4; d++) x[d] += arg.border[d];
// a link is a 3x3 matrix of complex<Float>
369 typedef Matrix<complex<Float>,3> Link;
// Dispatch to the helper so that both directions are template arguments.
// In each group the case label equals the second direction template argument (nu)
// and the first one (mu) is fixed.
// group with mu = 0
374 case 1: computeStaple<Float,0,1>(staple,
arg, x,
parity);
break;
375 case 2: computeStaple<Float,0,2>(staple,
arg, x,
parity);
break;
376 case 3: computeStaple<Float,0,3>(staple,
arg, x,
parity);
break;
// group with mu = 1
380 case 0: computeStaple<Float,1,0>(staple,
arg, x,
parity);
break;
381 case 2: computeStaple<Float,1,2>(staple,
arg, x,
parity);
break;
382 case 3: computeStaple<Float,1,3>(staple,
arg, x,
parity);
break;
// group with mu = 2
386 case 0: computeStaple<Float,2,0>(staple,
arg, x,
parity);
break;
387 case 1: computeStaple<Float,2,1>(staple,
arg, x,
parity);
break;
388 case 3: computeStaple<Float,2,3>(staple,
arg, x,
parity);
break;
// group with mu = 3
392 case 0: computeStaple<Float,3,0>(staple,
arg, x,
parity);
break;
393 case 1: computeStaple<Float,3,1>(staple,
arg, x,
parity);
break;
394 case 2: computeStaple<Float,3,2>(staple,
arg, x,
parity);
break;
// only sites inside the inner region (inner_border <= x < inner_X + inner_border
// in every dimension) update the fat link
399 if ( !(x[0] < arg.inner_border[0] || x[0] >= arg.inner_X[0] + arg.inner_border[0] ||
400 x[1] < arg.inner_border[1] || x[1] >= arg.inner_X[1] + arg.inner_border[1] ||
401 x[2] < arg.inner_border[2] || x[2] >= arg.inner_X[2] + arg.inner_border[2] ||
402 x[3] < arg.inner_border[3] || x[3] >= arg.inner_X[3] + arg.inner_border[3]) ) {
// coordinates relative to the inner region
404 int inner_x[] = {x[0]-arg.inner_border[0], x[1]-arg.inner_border[1], x[2]-arg.inner_border[2], x[3]-arg.inner_border[3]};
// read-modify-write: fat += coeff * staple
405 Link fat = arg.fat(mu,
linkIndex(inner_x, arg.inner_X), parity);
406 fat += arg.coeff * staple;
407 arg.fat(mu,
linkIndex(inner_x, arg.inner_X), parity) = fat;
// optionally keep the staple, indexed on the arg.E lattice
410 if (save_staple) arg.staple(mu,
linkIndex(x, arg.E), parity) = staple;
// Staple: tunable launcher for the computeStaple kernel.
// The z extent passed to TunableVectorYZ is 3 minus one for each of dir1 / dir2
// that is greater than -1; the same value is stored in arg.n_mu.
// NOTE(review): member declarations other than `meta` are not visible here.
414 template <
typename Float,
typename Arg>
415 class Staple :
public TunableVectorYZ {
// field whose VolString() / VolumeCB() are used for the tune key and byte count
417 const GaugeField &meta;
// minimum number of threads reported to the tuner: arg.threads
418 unsigned int minThreads()
const {
return arg.threads; }
// the grid dimension is not tuned
419 bool tuneGridDim()
const {
return false; }
426 Staple(Arg &arg,
int nu,
int dir1,
int dir2,
bool save_staple,
const GaugeField &meta)
427 : TunableVectorYZ(2,(3 - ( (dir1 > -1) ? 1 : 0 ) - ( (dir2 > -1) ? 1 : 0 ))),
428 arg(arg), meta(meta), nu(nu), dir1(dir1), dir2(dir2), save_staple(save_staple)
// number of mu directions handled by one launch
434 arg.n_mu = 3 - ( (dir1 > -1) ? 1 : 0 ) - ( (dir2 > -1) ? 1 : 0 );
// walk the four directions, skipping nu, dir1 and dir2
// NOTE(review): presumably the remaining ones are recorded in arg.mu_map on a
// line not shown, with j counting them -- TODO confirm.
436 for (
int i=0; i<4; i++) {
437 if (i==nu || i==dir1 || i==dir2)
continue;
// the number of directions kept must match n_mu
440 assert(j == arg.n_mu);
// apply: launches the kernel instantiated with save_staple = true or false.
// NOTE(review): the condition choosing between the two launches and the
// definition of tp are not visible; `stream` is not passed to either launch.
444 void apply(
const cudaStream_t &
stream) {
447 computeStaple<Float,true><<<tp.grid,tp.block>>>(
arg, nu);
449 computeStaple<Float,false><<<tp.grid,tp.block>>>(
arg, nu);
// tuneKey: aux string is
// "threads=<n>,prec=<sizeof(Float)>,nu=<nu>,dir1=<dir1>,dir2=<dir2>,save=<save_staple>"
452 TuneKey tuneKey()
const {
453 std::stringstream aux;
454 aux <<
"threads=" << arg.threads <<
",prec=" <<
sizeof(Float);
455 aux <<
",nu=" << nu <<
",dir1=" << dir1 <<
",dir2=" << dir2 <<
",save=" << save_staple;
456 return TuneKey(meta.VolString(),
typeid(*this).name(), aux.str().c_str());
// the kernel modifies fat and staple, so both are saved before tuning and
// reloaded afterwards
459 void preTune() { arg.fat.save(); arg.staple.save(); }
460 void postTune() { arg.fat.load(); arg.staple.load(); }
// flops: 2 * n_mu * threads * (4*198 + 18 + 36)
462 long long flops()
const {
463 return 2*arg.n_mu*arg.threads*( 4*198 + 18 + 36 );
// bytes: fat-link traffic counted twice over meta.VolumeCB(), plus per thread
// four u loads, two mulink loads and the staple store when save_staple is set
465 long long bytes()
const {
466 return arg.n_mu*2*meta.VolumeCB()*arg.fat.Bytes()*2
467 + arg.n_mu*2*arg.threads*(4*arg.u.Bytes() + 2*arg.mulink.Bytes() + (save_staple ? arg.staple.Bytes() : 0));
// computeStaple (host): builds the accessor types and StapleArg for one staple
// pass and constructs the Staple launcher.
//   fat         - fat-link field that receives coeff * staple
//   staple      - field that receives the staple itself when save_staple is set
//   mulink      - input field accessed through the Mulink accessor
//   u           - input gauge field
//   nu          - staple direction; dir1 / dir2 are excluded directions (-1 = none)
//   coeff       - coefficient applied to the staple
//   save_staple - whether the kernel stores the staple
// NOTE(review): the if/else conditions selecting each branch are not visible in
// this excerpt; judging by the error messages they depend on u.Precision() and on
// the reconstruct types of u and mulink -- TODO confirm.
472 void computeStaple(GaugeField &fat, GaugeField &staple,
const GaugeField &mulink,
const GaugeField &u,
473 int nu,
int dir1,
int dir2,
double coeff,
bool save_staple) {
// --- double precision: L is the QUDA_RECONSTRUCT_NO accessor ---
476 typedef typename gauge_mapper<double,QUDA_RECONSTRUCT_NO>::type L;
// u and mulink both without reconstruction
478 typedef StapleArg<double,L,L,L,L> Arg;
479 Arg
arg(L(fat), L(staple), L(mulink), L(u), coeff, fat, u);
480 Staple<double,Arg> stapler(arg, nu, dir1, dir2, save_staple, fat);
// G is the QUDA_RECONSTRUCT_12 accessor with QUDA_STAGGERED_PHASE_MILC
483 typedef typename gauge_mapper<double,QUDA_RECONSTRUCT_12,18,QUDA_STAGGERED_PHASE_MILC>::type G;
// u reconstructed, mulink not
485 typedef StapleArg<double,L,L,L,G> Arg;
486 Arg
arg(L(fat), L(staple), L(mulink), G(u), coeff, fat, u);
487 Staple<double,Arg> stapler(arg, nu, dir1, dir2, save_staple, fat);
// u and mulink both reconstructed
490 typedef StapleArg<double,L,L,G,G> Arg;
491 Arg
arg(L(fat), L(staple), G(mulink), G(u), coeff, fat, u);
492 Staple<double,Arg> stapler(arg, nu, dir1, dir2, save_staple, fat);
495 errorQuda(
"Reconstruct %d is not supported\n", u.Reconstruct());
498 errorQuda(
"Reconstruct %d is not supported\n", u.Reconstruct());
// --- single precision: same structure with float accessors ---
501 typedef typename gauge_mapper<float,QUDA_RECONSTRUCT_NO>::type L;
503 typedef StapleArg<float,L,L,L,L> Arg;
504 Arg
arg(L(fat), L(staple), L(mulink), L(u), coeff, fat, u);
505 Staple<float,Arg> stapler(arg, nu, dir1, dir2, save_staple, fat);
508 typedef typename gauge_mapper<float,QUDA_RECONSTRUCT_12,18,QUDA_STAGGERED_PHASE_MILC>::type G;
// The Float parameter of StapleArg must match the float accessors and the
// Staple<float,...> launcher; it was `double` here, which made arg.coeff a double
// inside the single-precision kernel.
510 typedef StapleArg<float,L,L,L,G> Arg;
511 Arg
arg(L(fat), L(staple), L(mulink), G(u), coeff, fat, u);
512 Staple<float,Arg> stapler(arg, nu, dir1, dir2, save_staple, fat);
// same correction for the branch where mulink is reconstructed as well
515 typedef StapleArg<float,L,L,G,G> Arg;
516 Arg
arg(L(fat), L(staple), G(mulink), G(u), coeff, fat, u);
517 Staple<float,Arg> stapler(arg, nu, dir1, dir2, save_staple, fat);
520 errorQuda(
"Reconstruct %d is not supported\n", u.Reconstruct());
523 errorQuda(
"Reconstruct %d is not supported\n", u.Reconstruct());
// any other precision is rejected
526 errorQuda(
"Unsupported precision %d\n", u.Precision());
// Body fragment of the top-level fat/long link driver.
// NOTE(review): the function header and several statements are not visible in
// this excerpt; `u`, `staple` and `staple1` are defined on lines not shown.
// Reject lattices where any of the four fat-field extents is odd.
// NOTE(review): the remainder of this condition is on a line not shown.
543 if( ((fat->
X()[0] % 2 != 0) || (fat->
X()[1] % 2 != 0) || (fat->
X()[2] % 2 != 0) || (fat->
X()[3] % 2 != 0))
545 errorQuda(
"Reconstruct %d and odd dimensionsize is not supported by link fattening code (yet)\n",
// one-link term, with coefficient coeff[0] - 6*coeff[5]
549 computeOneLink(*fat, u, coeff[0]-6.0*coeff[5]);
// long links are only computed when an output field is supplied
552 if (lng) computeLongLink(*lng, u, coeff[1]);
// staple contributions, one outer pass per direction nu
558 for (
int nu = 0; nu < 4; nu++) {
// pass with coefficient coeff[5]: no excluded directions, staple not saved
561 if (coeff[5] != 0.0)
computeStaple(*fat, staple, staple, u, nu, -1, -1, coeff[5], 0);
// second level: direction rho, excluding nu; coefficient coeff[3]; result saved in staple1
// NOTE(review): any guard on rho or on the coefficient is on lines not shown.
563 for (
int rho = 0; rho < 4; rho++) {
566 computeStaple(*fat, staple1, staple, u, rho, nu, -1, coeff[3], 1);
// third level: direction sig different from nu and rho; coefficient coeff[4];
// staple1 is the mulink input and the result is not saved
569 for (
int sig = 0; sig < 4; sig++) {
570 if (sig != nu && sig != rho) {
571 computeStaple(*fat, staple, staple1, u, sig, nu, rho, coeff[4], 0);
// NOTE(review): presumably reached only when fat-link support is compiled out -- TODO confirm
582 errorQuda(
"Fat-link computation not enabled");
void fatLongKSLink(cudaGaugeField *fat, cudaGaugeField *lng, const cudaGaugeField &gauge, const double *coeff)
Compute the fat and long links for improved staggered (Kogut-Susskind) fermions.
int commDimPartitioned(int dir)
static __device__ __host__ int linkIndexShift(const I x[], const J dx[], const K X[4])
static __device__ __host__ int linkIndex(const int x[], const I X[4])
QudaVerbosity getVerbosity()
#define qudaDeviceSynchronize()
__host__ __device__ void computeStaple(Arg &arg, int idx, int parity, int dir, Link &staple)
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
Main header file for host and device accessors to GaugeFields.
void setPrecision(QudaPrecision precision, bool force_native=false)
Helper function for setting the precision and corresponding field order for QUDA internal fields...
QudaPrecision Precision() const
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
QudaReconstructType reconstruct
QudaReconstructType Reconstruct() const
__host__ __device__ ValueType conj(ValueType x)
QudaTune getTuning()
Query whether autotuning is enabled or not. Default is enabled but can be overridden by setting QUDA_...
__host__ __device__ int getCoords(int coord[], const Arg &arg, int &idx, int parity, int &dim)
Compute the space-time coordinates we are at.