template <typename Float_, int nColor_, int nSpin_, bool spin_project_ = true>
struct PackArg {

  typedef Float_ Float;
  typedef typename mapper<Float>::type real;

  static constexpr int nColor = nColor_;
  static constexpr int nSpin = nSpin_;

  static constexpr bool spin_project = (nSpin == 4 && spin_project_ ? true : false);
  static constexpr bool spinor_direct_load = false; // false means texture load

  typedef typename colorspinor_mapper<Float, nSpin, nColor, spin_project, spinor_direct_load>::type F;

  const F in; // field we are packing

  const int nFace;
  const bool dagger;
  const int parity;        // only used for single-parity fields
  const int nParity;       // number of parities we are working on
  const int threads;       // number of active threads
  const DslashConstant dc; // pre-computed dslash constants for optimized indexing

  const real a, b, c; // twisted-mass scaling, chiral-twist and flavor-twist factors
  const int twist;    // 0 = no twist, 1 = twisted singlet, 2 = twisted doublet

  int threadDimMapLower[4]; // first thread index assigned to dimension i
  int threadDimMapUpper[4]; // one past the last thread index assigned to dimension i

  int_fastdiv blocks_per_dir; // number of blocks per (dimension, direction) pair
  int dim_map[4];             // compacted map from block index to partitioned dimension
  int sites_per_block;        // number of face sites handled by each block

  PackArg(void **ghost, const ColorSpinorField &in, int nFace, bool dagger, int parity, int threads, double a,
          double b, double c) :
    in(in, nFace, nullptr, nullptr, reinterpret_cast<Float **>(ghost)),
    nFace(nFace),
    dagger(dagger),
    parity(parity),
    nParity(in.SiteSubset()),
    threads(threads),
    dc(in.getDslashConstant()),
    a(a),
    b(b),
    c(c),
    twist((a != 0.0 && b != 0.0) ? (c != 0.0 ? 2 : 1) : 0)
  {
    int prev = -1; // previous dimension that was partitioned
    for (int i = 0; i < 4; i++) {
      threadDimMapLower[i] = 0;
      threadDimMapUpper[i] = 0;
      if (!commDim[i]) continue;
      threadDimMapLower[i] = (prev >= 0 ? threadDimMapUpper[prev] : 0);
      threadDimMapUpper[i] = threadDimMapLower[i] + 2 * nFace * dc.ghostFaceCB[i];
      prev = i;
    }
  }
};
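// Worked example (illustration, not part of the original file): suppose only
// dimensions 1 and 3 are partitioned, with nFace = 1 and
// dc.ghostFaceCB = {16, 16, 16, 16}. The loop above then produces
//
//   threadDimMapLower = {0,  0, 0, 32}
//   threadDimMapUpper = {0, 32, 0, 64}
//
// i.e. threads [0, 32) pack the two faces of dimension 1 and threads [32, 64)
// pack the two faces of dimension 3; dimFromFaceIndex() inverts this map to
// recover (dimension, face index) from a flat thread index.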
template <bool dagger, int twist, int dim, QudaPCType pc, typename Arg>
__device__ __host__ void pack(Arg &arg, int ghost_idx, int s, int parity)
{
  typedef typename mapper<typename Arg::Float>::type real;
  typedef ColorSpinor<real, Arg::nColor, Arg::nSpin> Vector;
  constexpr int nFace = 1;

  constexpr int nDim = pc; // QUDA_4D_PC -> 4, QUDA_5D_PC -> 5

  // for 5-d preconditioning the face spans the fifth dimension as well
  const int face_size = nFace * arg.dc.ghostFaceCB[dim] * (pc == QUDA_5D_PC ? arg.dc.Ls : 1);

  int spinor_parity = (arg.nParity == 2) ? parity : 0;

  // face_num determines which end of the lattice we are packing: 0 = start, 1 = end
  const int face_num = (ghost_idx >= face_size) ? 1 : 0;
  ghost_idx -= face_num * face_size;

  // remove const to ensure we have a non-const Ghost accessor
  typedef typename std::remove_const<decltype(arg.in)>::type T;
  T &in = const_cast<T &>(arg.in);

  if (face_num == 0) { // backwards

    int idx = indexFromFaceIndex<nDim, pc, dim, nFace, 0>(ghost_idx, parity, arg);
    constexpr int proj_dir = dagger ? +1 : -1;
    Vector f = arg.in(idx + s * arg.dc.volume_4d_cb, spinor_parity);
    if (twist == 1) {
      f = arg.a * (f + arg.b * f.igamma(4));
    } else if (twist == 2) {
      Vector f1 = arg.in(idx + (1 - s) * arg.dc.volume_4d_cb, spinor_parity); // load other flavor
      if (s == 0)
        f = arg.a * (f + arg.b * f.igamma(4) + arg.c * f1);
      else
        f = arg.a * (f - arg.b * f.igamma(4) + arg.c * f1);
    }
    if (arg.spin_project) {
      in.Ghost(dim, 0, ghost_idx + s * arg.dc.ghostFaceCB[dim], spinor_parity) = f.project(dim, proj_dir);
    } else {
      in.Ghost(dim, 0, ghost_idx + s * arg.dc.ghostFaceCB[dim], spinor_parity) = f;
    }

  } else { // forwards

    int idx = indexFromFaceIndex<nDim, pc, dim, nFace, 1>(ghost_idx, parity, arg);
    constexpr int proj_dir = dagger ? -1 : +1;
    Vector f = arg.in(idx + s * arg.dc.volume_4d_cb, spinor_parity);
    if (twist == 1) {
      f = arg.a * (f + arg.b * f.igamma(4));
    } else if (twist == 2) {
      Vector f1 = arg.in(idx + (1 - s) * arg.dc.volume_4d_cb, spinor_parity); // load other flavor
      if (s == 0)
        f = arg.a * (f + arg.b * f.igamma(4) + arg.c * f1);
      else
        f = arg.a * (f - arg.b * f.igamma(4) + arg.c * f1);
    }
    if (arg.spin_project) {
      in.Ghost(dim, 1, ghost_idx + s * arg.dc.ghostFaceCB[dim], spinor_parity) = f.project(dim, proj_dir);
    } else {
      in.Ghost(dim, 1, ghost_idx + s * arg.dc.ghostFaceCB[dim], spinor_parity) = f;
    }
  }
}
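// Worked example (illustration, not part of the original file): for the
// twisted doublet (twist == 2) the transform applied above is, per flavor s,
//
//   f_s -> a * ( f_s + (-1)^s i b gamma_5 f_s + c f_{1-s} ),
//
// where igamma(4) multiplies by i gamma_5 and the c term couples the two
// flavors; for the twisted singlet (twist == 1) the flavor-mixing term is
// absent. When spin_project is enabled only the half spinor obtained from
// the (1 +- gamma_dim) projector is stored, halving the communicated volume.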
template <int dim, int nFace = 1, typename Arg>
__device__ __host__ void packStaggered(Arg &arg, int ghost_idx, int s, int parity)
{
  typedef typename mapper<typename Arg::Float>::type real;
  typedef ColorSpinor<real, Arg::nColor, Arg::nSpin> Vector;

  int spinor_parity = (arg.nParity == 2) ? parity : 0;

  // face_num determines which end of the lattice we are packing: 0 = start, 1 = end
  const int face_num = (ghost_idx >= nFace * arg.dc.ghostFaceCB[dim]) ? 1 : 0;
  ghost_idx -= face_num * nFace * arg.dc.ghostFaceCB[dim];

  // remove const to ensure we have a non-const Ghost accessor
  typedef typename std::remove_const<decltype(arg.in)>::type T;
  T &in = const_cast<T &>(arg.in);

  if (face_num == 0) { // backwards
    int idx = indexFromFaceIndexStaggered<4, QUDA_4D_PC, dim, nFace, 0>(ghost_idx, parity, arg);
    Vector f = arg.in(idx + s * arg.dc.volume_4d_cb, spinor_parity);
    in.Ghost(dim, 0, ghost_idx + s * arg.dc.ghostFaceCB[dim], spinor_parity) = f;
  } else { // forwards
    int idx = indexFromFaceIndexStaggered<4, QUDA_4D_PC, dim, nFace, 1>(ghost_idx, parity, arg);
    Vector f = arg.in(idx + s * arg.dc.volume_4d_cb, spinor_parity);
    in.Ghost(dim, 1, ghost_idx + s * arg.dc.ghostFaceCB[dim], spinor_parity) = f;
  }
}
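// Note (illustration, not part of the original file): no spin projection is
// possible here since staggered fields have nSpin == 1. nFace = 1 packs the
// one-hop faces of the naive staggered operator, while nFace = 3 packs the
// three-hop faces needed by the improved-staggered (Naik) term, tripling the
// ghost-zone depth.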
template <bool dagger, int twist, QudaPCType pc, typename Arg>
__global__ void packKernel(Arg arg)
{
  const int sites_per_block = arg.sites_per_block;
  int local_tid = threadIdx.x;
  int tid = sites_per_block * blockIdx.x + local_tid;
  int s = blockDim.y * blockIdx.y + threadIdx.y;
  if (s >= arg.dc.Ls) return;

  // parity used for loads and stores; single-parity fields use arg.parity
  int parity = (arg.nParity == 2) ? blockDim.z * blockIdx.z + threadIdx.z : arg.parity;

  while (local_tid < sites_per_block && tid < arg.threads) {

    // determine which dimension we are packing; dimFromFaceIndex also
    // rescales ghost_idx to be relative to that dimension
    int ghost_idx;
    const int dim = dimFromFaceIndex(ghost_idx, tid, arg);

    if (pc == QUDA_5D_PC) { // 5-d checkerboarded field: fold s into the ghost index
      switch (dim) {
      case 0: pack<dagger, twist, 0, pc>(arg, ghost_idx + s * arg.dc.ghostFace[0], 0, parity); break;
      case 1: pack<dagger, twist, 1, pc>(arg, ghost_idx + s * arg.dc.ghostFace[1], 0, parity); break;
      case 2: pack<dagger, twist, 2, pc>(arg, ghost_idx + s * arg.dc.ghostFace[2], 0, parity); break;
      case 3: pack<dagger, twist, 3, pc>(arg, ghost_idx + s * arg.dc.ghostFace[3], 0, parity); break;
      }
    } else { // 4-d checkerboarded field: keep s separate
      switch (dim) {
      case 0: pack<dagger, twist, 0, pc>(arg, ghost_idx, s, parity); break;
      case 1: pack<dagger, twist, 1, pc>(arg, ghost_idx, s, parity); break;
      case 2: pack<dagger, twist, 2, pc>(arg, ghost_idx, s, parity); break;
      case 3: pack<dagger, twist, 3, pc>(arg, ghost_idx, s, parity); break;
      }
    }

    local_tid += blockDim.x;
    tid += blockDim.x;
  }
}
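// Minimal launch sketch (illustration only -- in QUDA the launch geometry is
// selected by the autotuner rather than hard-coded). The x grid dimension
// covers face sites, y the fifth dimension, z the parities:
//
//   dim3 block(64, 1, 1);
//   arg.sites_per_block = block.x; // hypothetical choice: one pass per thread
//   dim3 grid((arg.threads + arg.sites_per_block - 1) / arg.sites_per_block,
//             arg.dc.Ls, arg.nParity);
//   packKernel<false, 0, QUDA_4D_PC><<<grid, block>>>(arg);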
// shmem variant: each (dimension, direction) pair is statically assigned
// arg.blocks_per_dir blocks along grid x
template <bool dagger, int twist, QudaPCType pc, typename Arg>
__global__ void packShmemKernel(Arg arg)
{
  int local_block_idx = blockIdx.x % arg.blocks_per_dir;
  int dim_dir = blockIdx.x / arg.blocks_per_dir;
  int dir = dim_dir % 2;
  int dim = 0;
  switch (dim_dir / 2) {
  case 0: dim = arg.dim_map[0]; break;
  case 1: dim = arg.dim_map[1]; break;
  case 2: dim = arg.dim_map[2]; break;
  case 3: dim = arg.dim_map[3]; break;
  }

  int local_tid = local_block_idx * blockDim.x + threadIdx.x;

  int s = blockDim.y * blockIdx.y + threadIdx.y;
  if (s >= arg.dc.Ls) return;

  // parity used for loads and stores; single-parity fields use arg.parity
  int parity = (arg.nParity == 2) ? blockDim.z * blockIdx.z + threadIdx.z : arg.parity;

  switch (dim) {
  case 0:
    while (local_tid < arg.dc.ghostFaceCB[0]) {
      int ghost_idx = dir * arg.dc.ghostFaceCB[0] + local_tid;
      if (pc == QUDA_5D_PC)
        pack<dagger, twist, 0, pc>(arg, ghost_idx + s * arg.dc.ghostFace[0], 0, parity);
      else
        pack<dagger, twist, 0, pc>(arg, ghost_idx, s, parity);
      local_tid += arg.blocks_per_dir * blockDim.x;
    }
    break;
  case 1:
    while (local_tid < arg.dc.ghostFaceCB[1]) {
      int ghost_idx = dir * arg.dc.ghostFaceCB[1] + local_tid;
      if (pc == QUDA_5D_PC)
        pack<dagger, twist, 1, pc>(arg, ghost_idx + s * arg.dc.ghostFace[1], 0, parity);
      else
        pack<dagger, twist, 1, pc>(arg, ghost_idx, s, parity);
      local_tid += arg.blocks_per_dir * blockDim.x;
    }
    break;
  case 2:
    while (local_tid < arg.dc.ghostFaceCB[2]) {
      int ghost_idx = dir * arg.dc.ghostFaceCB[2] + local_tid;
      if (pc == QUDA_5D_PC)
        pack<dagger, twist, 2, pc>(arg, ghost_idx + s * arg.dc.ghostFace[2], 0, parity);
      else
        pack<dagger, twist, 2, pc>(arg, ghost_idx, s, parity);
      local_tid += arg.blocks_per_dir * blockDim.x;
    }
    break;
  case 3:
    while (local_tid < arg.dc.ghostFaceCB[3]) {
      int ghost_idx = dir * arg.dc.ghostFaceCB[3] + local_tid;
      if (pc == QUDA_5D_PC)
        pack<dagger, twist, 3, pc>(arg, ghost_idx + s * arg.dc.ghostFace[3], 0, parity);
      else
        pack<dagger, twist, 3, pc>(arg, ghost_idx, s, parity);
      local_tid += arg.blocks_per_dir * blockDim.x;
    }
    break;
  }
}
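// Grid-shape note (illustration, not part of the original file): with this
// decomposition grid.x must equal 2 * npartitioned * arg.blocks_per_dir,
// since each (dimension, direction) pair owns arg.blocks_per_dir consecutive
// x-blocks and arg.dim_map[] compacts the partitioned dimensions to the front.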
template <typename Arg>
__global__ void packStaggeredKernel(Arg arg)
{
  const int sites_per_block = arg.sites_per_block;
  int local_tid = threadIdx.x;
  int tid = sites_per_block * blockIdx.x + local_tid;
  int s = blockDim.y * blockIdx.y + threadIdx.y;
  if (s >= arg.dc.Ls) return;

  // parity used for loads and stores; single-parity fields use arg.parity
  int parity = (arg.nParity == 2) ? blockDim.z * blockIdx.z + threadIdx.z : arg.parity;

  while (local_tid < sites_per_block && tid < arg.threads) {

    // determine which dimension we are packing; dimFromFaceIndex also
    // rescales ghost_idx to be relative to that dimension
    int ghost_idx;
    const int dim = dimFromFaceIndex(ghost_idx, tid, arg);

    if (arg.nFace == 1) {
      switch (dim) {
      case 0: packStaggered<0, 1>(arg, ghost_idx, s, parity); break;
      case 1: packStaggered<1, 1>(arg, ghost_idx, s, parity); break;
      case 2: packStaggered<2, 1>(arg, ghost_idx, s, parity); break;
      case 3: packStaggered<3, 1>(arg, ghost_idx, s, parity); break;
      }
    } else if (arg.nFace == 3) {
      switch (dim) {
      case 0: packStaggered<0, 3>(arg, ghost_idx, s, parity); break;
      case 1: packStaggered<1, 3>(arg, ghost_idx, s, parity); break;
      case 2: packStaggered<2, 3>(arg, ghost_idx, s, parity); break;
      case 3: packStaggered<3, 3>(arg, ghost_idx, s, parity); break;
      }
    }

    local_tid += blockDim.x;
    tid += blockDim.x;
  }
}
// shmem variant of the staggered packer, blocked per (dimension, direction)
template <typename Arg>
__global__ void packStaggeredShmemKernel(Arg arg)
{
  int local_block_idx = blockIdx.x % arg.blocks_per_dir;
  int dim_dir = blockIdx.x / arg.blocks_per_dir;
  int dir = dim_dir % 2;
  int dim = 0;
  switch (dim_dir / 2) {
  case 0: dim = arg.dim_map[0]; break;
  case 1: dim = arg.dim_map[1]; break;
  case 2: dim = arg.dim_map[2]; break;
  case 3: dim = arg.dim_map[3]; break;
  }

  int local_tid = local_block_idx * blockDim.x + threadIdx.x;

  int s = blockDim.y * blockIdx.y + threadIdx.y;
  if (s >= arg.dc.Ls) return;

  // parity used for loads and stores; single-parity fields use arg.parity
  int parity = (arg.nParity == 2) ? blockDim.z * blockIdx.z + threadIdx.z : arg.parity;

  switch (dim) {
  case 0:
    while (local_tid < arg.nFace * arg.dc.ghostFaceCB[0]) {
      int ghost_idx = dir * arg.nFace * arg.dc.ghostFaceCB[0] + local_tid;
      if (arg.nFace == 1)
        packStaggered<0, 1>(arg, ghost_idx, s, parity);
      else
        packStaggered<0, 3>(arg, ghost_idx, s, parity);
      local_tid += arg.blocks_per_dir * blockDim.x;
    }
    break;
  case 1:
    while (local_tid < arg.nFace * arg.dc.ghostFaceCB[1]) {
      int ghost_idx = dir * arg.nFace * arg.dc.ghostFaceCB[1] + local_tid;
      if (arg.nFace == 1)
        packStaggered<1, 1>(arg, ghost_idx, s, parity);
      else
        packStaggered<1, 3>(arg, ghost_idx, s, parity);
      local_tid += arg.blocks_per_dir * blockDim.x;
    }
    break;
  case 2:
    while (local_tid < arg.nFace * arg.dc.ghostFaceCB[2]) {
      int ghost_idx = dir * arg.nFace * arg.dc.ghostFaceCB[2] + local_tid;
      if (arg.nFace == 1)
        packStaggered<2, 1>(arg, ghost_idx, s, parity);
      else
        packStaggered<2, 3>(arg, ghost_idx, s, parity);
      local_tid += arg.blocks_per_dir * blockDim.x;
    }
    break;
  case 3:
    while (local_tid < arg.nFace * arg.dc.ghostFaceCB[3]) {
      int ghost_idx = dir * arg.nFace * arg.dc.ghostFaceCB[3] + local_tid;
      if (arg.nFace == 1)
        packStaggered<3, 1>(arg, ghost_idx, s, parity);
      else
        packStaggered<3, 3>(arg, ghost_idx, s, parity);
      local_tid += arg.blocks_per_dir * blockDim.x;
    }
    break;
  }
}
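// End-to-end sketch (illustration only; the real dispatch lives in QUDA's
// dslash framework and the values below are hypothetical):
//
//   using Arg = PackArg<float, 3, 4>; // single precision, 3 colors, 4 spins
//   Arg arg(ghost, in, 1 /*nFace*/, false /*dagger*/, 0 /*parity*/, threads,
//           0.0, 0.0, 0.0);           // a == b == 0 -> twist == 0
//   packKernel<false, 0, QUDA_4D_PC><<<grid, block>>>(arg);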