dslash_pack.cuh
#pragma once

#include <color_spinor.h>
#include <index_helper.cuh>
#include <dslash_helper.cuh>

namespace quda
{

  static int commDim[QUDA_MAX_DIM];

  template <typename Float_, int nColor_, int nSpin_, bool spin_project_ = true> struct PackArg {

    typedef Float_ Float;
    typedef typename mapper<Float>::type real;

    static constexpr int nColor = nColor_;
    static constexpr int nSpin = nSpin_;

    static constexpr bool spin_project = (nSpin == 4 && spin_project_ ? true : false);
    static constexpr bool spinor_direct_load = false; // false means texture load
    typedef typename colorspinor_mapper<Float, nSpin, nColor, spin_project, spinor_direct_load>::type F;

    const F in; // field we are packing

    const int nFace;
    const bool dagger;
    const int parity;         // only use this for single parity fields
    const int nParity;        // number of parities we are working on
    const QudaPCType pc_type; // preconditioning type (4-d or 5-d)

    const DslashConstant dc; // pre-computed dslash constants for optimized indexing

    real a; // preconditioned twisted-mass scaling parameter
    real b; // preconditioned twisted-mass chiral twist factor
    real c; // preconditioned twisted-mass flavor twist factor
    int twist; // whether we are doing preconditioned twisted-mass or not (1 - singlet, 2 - doublet)

    int_fastdiv threads;      // total number of face sites to pack (over all partitioned dimensions)
    int threadDimMapLower[4]; // first thread index assigned to each partitioned dimension
    int threadDimMapUpper[4]; // one past the last thread index assigned to each partitioned dimension

    int_fastdiv blocks_per_dir; // number of thread blocks per direction (used by the shmem kernels)
    int dim_map[4];             // map from compact partitioned-dimension index to lattice dimension
    int active_dims;            // number of partitioned dimensions

    int_fastdiv swizzle;
    int sites_per_block; // number of face sites handled by each thread block

    PackArg(void **ghost, const ColorSpinorField &in, int nFace, bool dagger, int parity, int threads, double a,
            double b, double c) :
      in(in, nFace, nullptr, nullptr, reinterpret_cast<Float **>(ghost)),
      nFace(nFace),
      dagger(dagger),
      parity(parity),
      nParity(in.SiteSubset()),
      threads(threads),
      pc_type(in.PCType()),
      dc(in.getDslashConstant()),
      a(a),
      b(b),
      c(c),
      twist((a != 0.0 && b != 0.0) ? (c != 0.0 ? 2 : 1) : 0)
    {
      if (!in.isNative()) errorQuda("Unsupported field order colorspinor=%d\n", in.FieldOrder());

      int d = 0;
      int prev = -1; // previous dimension that was partitioned
      for (int i = 0; i < 4; i++) {
        threadDimMapLower[i] = 0;
        threadDimMapUpper[i] = 0;
        if (!commDim[i]) continue;
        threadDimMapLower[i] = (prev >= 0 ? threadDimMapUpper[prev] : 0);
        threadDimMapUpper[i] = threadDimMapLower[i] + 2 * nFace * dc.ghostFaceCB[i];
        prev = i;

        dim_map[d++] = i;
      }
      active_dims = d;
    }
  };
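
  // Illustrative example of the thread-map partitioning set up in the constructor above (values are
  // hypothetical, not taken from this file): with a 16^4 local checkerboarded volume, nFace = 1, and
  // only the Z and T dimensions partitioned (commDim[2] = commDim[3] = 1), each face has
  // ghostFaceCB[2] = ghostFaceCB[3] = 16*16*16/2 = 2048 sites, so
  //
  //   threadDimMapLower[2] = 0,    threadDimMapUpper[2] = 2 * 1 * 2048 = 4096
  //   threadDimMapLower[3] = 4096, threadDimMapUpper[3] = 8192
  //   dim_map = {2, 3, -, -},      active_dims = 2
  //
  // i.e. each partitioned dimension owns a contiguous range of thread indices covering both of its
  // faces, and dimFromFaceIndex (used by the kernels below) uses these bounds to recover
  // (dim, ghost_idx) from a flat thread index.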

  template <bool dagger, int twist, int dim, QudaPCType pc, typename Arg>
  __device__ __host__ inline void pack(Arg &arg, int ghost_idx, int s, int parity)
  {

    typedef typename mapper<typename Arg::Float>::type real;
    typedef ColorSpinor<real, Arg::nColor, Arg::nSpin> Vector; // full-spinor type loaded from the field
    constexpr int nFace = 1;

    // this means we treat 4-d preconditioned fields as 4-d fields,
    // and don't fold in any fifth dimension until after we have
    // computed the 4-d indices (saves division)
    constexpr int nDim = pc;

    // for 5-d preconditioning the face_size includes the Ls dimension
    const int face_size = nFace * arg.dc.ghostFaceCB[dim] * (pc == QUDA_5D_PC ? arg.dc.Ls : 1);

    int spinor_parity = (arg.nParity == 2) ? parity : 0;

    // compute where the output is located
    // compute an index into the local volume from the index into the face
    // read spinor, spin-project, and write half spinor to face

    // face_num determines which end of the lattice we are packing: 0 = start, 1 = end
    const int face_num = (ghost_idx >= face_size) ? 1 : 0;
    ghost_idx -= face_num * face_size;

    // remove const to ensure we have non-const Ghost member
    typedef typename std::remove_const<decltype(arg.in)>::type T;
    T &in = const_cast<T &>(arg.in);

    if (face_num == 0) { // backwards

      int idx = indexFromFaceIndex<nDim, pc, dim, nFace, 0>(ghost_idx, parity, arg);
      constexpr int proj_dir = dagger ? +1 : -1;
      Vector f = arg.in(idx + s * arg.dc.volume_4d_cb, spinor_parity);
      if (twist == 1) {
        f = arg.a * (f + arg.b * f.igamma(4));
      } else if (twist == 2) {
        Vector f1 = arg.in(idx + (1 - s) * arg.dc.volume_4d_cb, spinor_parity); // load other flavor
        if (s == 0)
          f = arg.a * (f + arg.b * f.igamma(4) + arg.c * f1);
        else
          f = arg.a * (f - arg.b * f.igamma(4) + arg.c * f1);
      }
      if (arg.spin_project) {
        in.Ghost(dim, 0, ghost_idx + s * arg.dc.ghostFaceCB[dim], spinor_parity) = f.project(dim, proj_dir);
      } else {
        in.Ghost(dim, 0, ghost_idx + s * arg.dc.ghostFaceCB[dim], spinor_parity) = f;
      }
    } else { // forwards

      int idx = indexFromFaceIndex<nDim, pc, dim, nFace, 1>(ghost_idx, parity, arg);
      constexpr int proj_dir = dagger ? -1 : +1;
      Vector f = arg.in(idx + s * arg.dc.volume_4d_cb, spinor_parity);
      if (twist == 1) {
        f = arg.a * (f + arg.b * f.igamma(4));
      } else if (twist == 2) {
        Vector f1 = arg.in(idx + (1 - s) * arg.dc.volume_4d_cb, spinor_parity); // load other flavor
        if (s == 0)
          f = arg.a * (f + arg.b * f.igamma(4) + arg.c * f1);
        else
          f = arg.a * (f - arg.b * f.igamma(4) + arg.c * f1);
      }
      if (arg.spin_project) {
        in.Ghost(dim, 1, ghost_idx + s * arg.dc.ghostFaceCB[dim], spinor_parity) = f.project(dim, proj_dir);
      } else {
        in.Ghost(dim, 1, ghost_idx + s * arg.dc.ghostFaceCB[dim], spinor_parity) = f;
      }
    }
  }
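
  // Illustrative note on the index split in pack() above (numbers are hypothetical): ghost_idx
  // arrives in [0, 2 * face_size), covering both ends of the lattice in dimension `dim`.  For
  // example, with face_size = 2048, ghost_idx = 2050 gives face_num = 1 (forward face) and a
  // rebased ghost_idx = 2, which is then converted to a bulk site index by indexFromFaceIndex and
  // written to the face_num = 1 ghost buffer.  Note that the spin-projector sign proj_dir flips
  // with both the dagger flag and the face being packed, presumably so that the neighboring rank
  // applies the matching projector when it consumes the halo.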

  template <int dim, int nFace = 1, typename Arg>
  __device__ __host__ inline void packStaggered(Arg &arg, int ghost_idx, int s, int parity)
  {
    typedef typename mapper<typename Arg::Float>::type real;
    typedef ColorSpinor<real, Arg::nColor, Arg::nSpin> Vector; // staggered fields carry a single spin component

    int spinor_parity = (arg.nParity == 2) ? parity : 0;

    // compute where the output is located
    // compute an index into the local volume from the index into the face
    // read spinor and write spinor to face buffer

    // face_num determines which end of the lattice we are packing: 0 = start, 1 = end
    const int face_num = (ghost_idx >= nFace * arg.dc.ghostFaceCB[dim]) ? 1 : 0;
    ghost_idx -= face_num * nFace * arg.dc.ghostFaceCB[dim];

    // remove const to ensure we have non-const Ghost member
    typedef typename std::remove_const<decltype(arg.in)>::type T;
    T &in = const_cast<T &>(arg.in);

    if (face_num == 0) { // backwards
      int idx = indexFromFaceIndexStaggered<4, QUDA_4D_PC, dim, nFace, 0>(ghost_idx, parity, arg);
      Vector f = arg.in(idx + s * arg.dc.volume_4d_cb, spinor_parity);
      in.Ghost(dim, 0, ghost_idx + s * arg.dc.ghostFaceCB[dim], spinor_parity) = f;
    } else { // forwards
      int idx = indexFromFaceIndexStaggered<4, QUDA_4D_PC, dim, nFace, 1>(ghost_idx, parity, arg);
      Vector f = arg.in(idx + s * arg.dc.volume_4d_cb, spinor_parity);
      in.Ghost(dim, 1, ghost_idx + s * arg.dc.ghostFaceCB[dim], spinor_parity) = f;
    }
  }
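
  // Illustrative example for the improved-staggered case (values are hypothetical): with nFace = 3
  // and ghostFaceCB[dim] = 512, each end of the lattice contributes 3 * 512 = 1536 face sites, so
  // ghost_idx runs over [0, 3072).  A ghost_idx of 2000 selects face_num = 1 (forward face) and is
  // rebased to 2000 - 1536 = 464 before being mapped to a bulk site by indexFromFaceIndexStaggered.
  // No spin projection is performed here, since staggered fields have no spin structure to project.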

  template <bool dagger, int twist, QudaPCType pc, typename Arg> __global__ void packKernel(Arg arg)
  {
    const int sites_per_block = arg.sites_per_block;
    int local_tid = threadIdx.x;
    int tid = sites_per_block * blockIdx.x + local_tid;
    int s = blockDim.y * blockIdx.y + threadIdx.y;
    if (s >= arg.dc.Ls) return;

    // this is the parity used for load/store, but we use arg.parity for index mapping
    int parity = (arg.nParity == 2) ? blockDim.z * blockIdx.z + threadIdx.z : arg.parity;

    while (local_tid < sites_per_block && tid < arg.threads) {

      // determine which dimension we are packing
      int ghost_idx;
      const int dim = dimFromFaceIndex(ghost_idx, tid, arg);

      if (pc == QUDA_5D_PC) { // 5-d checkerboarded, include s (not ghostFaceCB since both faces)
        switch (dim) {
        case 0: pack<dagger, twist, 0, pc>(arg, ghost_idx + s * arg.dc.ghostFace[0], 0, parity); break;
        case 1: pack<dagger, twist, 1, pc>(arg, ghost_idx + s * arg.dc.ghostFace[1], 0, parity); break;
        case 2: pack<dagger, twist, 2, pc>(arg, ghost_idx + s * arg.dc.ghostFace[2], 0, parity); break;
        case 3: pack<dagger, twist, 3, pc>(arg, ghost_idx + s * arg.dc.ghostFace[3], 0, parity); break;
        }
      } else { // 4-d checkerboarding, keeping s separate (if it exists)
        switch (dim) {
        case 0: pack<dagger, twist, 0, pc>(arg, ghost_idx, s, parity); break;
        case 1: pack<dagger, twist, 1, pc>(arg, ghost_idx, s, parity); break;
        case 2: pack<dagger, twist, 2, pc>(arg, ghost_idx, s, parity); break;
        case 3: pack<dagger, twist, 3, pc>(arg, ghost_idx, s, parity); break;
        }
      }

      local_tid += blockDim.x;
      tid += blockDim.x;
    } // while tid
  }
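
  // Launch-geometry note (a sketch of assumptions, not taken from this file): the x grid dimension
  // indexes batches of face sites, with each block striding through arg.sites_per_block of them;
  // the y dimension covers the fifth dimension Ls; and the z dimension covers parity when both
  // parities are being packed.  A hypothetical host-side launch consistent with this indexing
  // would look like
  //
  //   dim3 block(64, 1, 1);
  //   dim3 grid((arg.threads + arg.sites_per_block - 1) / arg.sites_per_block, arg.dc.Ls, arg.nParity);
  //   packKernel<dagger, twist, pc><<<grid, block, 0, stream>>>(arg);
  //
  // in practice QUDA's autotuner chooses the block size and sites_per_block.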

  template <bool dagger, int twist, QudaPCType pc, typename Arg> __global__ void packShmemKernel(Arg arg)
  {
    // (active_dims * 2 + dir) * blocks_per_dir + local_block_idx
    int local_block_idx = blockIdx.x % arg.blocks_per_dir;
    int dim_dir = blockIdx.x / arg.blocks_per_dir;
    int dir = dim_dir % 2;
    int dim;
    switch (dim_dir / 2) {
    case 0: dim = arg.dim_map[0]; break;
    case 1: dim = arg.dim_map[1]; break;
    case 2: dim = arg.dim_map[2]; break;
    case 3: dim = arg.dim_map[3]; break;
    }

    int local_tid = local_block_idx * blockDim.x + threadIdx.x;

    int s = blockDim.y * blockIdx.y + threadIdx.y;
    if (s >= arg.dc.Ls) return;

    // this is the parity used for load/store, but we use arg.parity for index mapping
    int parity = (arg.nParity == 2) ? blockDim.z * blockIdx.z + threadIdx.z : arg.parity;

    switch (dim) {
    case 0:
      while (local_tid < arg.dc.ghostFaceCB[0]) {
        int ghost_idx = dir * arg.dc.ghostFaceCB[0] + local_tid;
        if (pc == QUDA_5D_PC)
          pack<dagger, twist, 0, pc>(arg, ghost_idx + s * arg.dc.ghostFace[0], 0, parity);
        else
          pack<dagger, twist, 0, pc>(arg, ghost_idx, s, parity);
        local_tid += arg.blocks_per_dir * blockDim.x;
      }
      break;
    case 1:
      while (local_tid < arg.dc.ghostFaceCB[1]) {
        int ghost_idx = dir * arg.dc.ghostFaceCB[1] + local_tid;
        if (pc == QUDA_5D_PC)
          pack<dagger, twist, 1, pc>(arg, ghost_idx + s * arg.dc.ghostFace[1], 0, parity);
        else
          pack<dagger, twist, 1, pc>(arg, ghost_idx, s, parity);
        local_tid += arg.blocks_per_dir * blockDim.x;
      }
      break;
    case 2:
      while (local_tid < arg.dc.ghostFaceCB[2]) {
        int ghost_idx = dir * arg.dc.ghostFaceCB[2] + local_tid;
        if (pc == QUDA_5D_PC)
          pack<dagger, twist, 2, pc>(arg, ghost_idx + s * arg.dc.ghostFace[2], 0, parity);
        else
          pack<dagger, twist, 2, pc>(arg, ghost_idx, s, parity);
        local_tid += arg.blocks_per_dir * blockDim.x;
      }
      break;
    case 3:
      while (local_tid < arg.dc.ghostFaceCB[3]) {
        int ghost_idx = dir * arg.dc.ghostFaceCB[3] + local_tid;
        if (pc == QUDA_5D_PC)
          pack<dagger, twist, 3, pc>(arg, ghost_idx + s * arg.dc.ghostFace[3], 0, parity);
        else
          pack<dagger, twist, 3, pc>(arg, ghost_idx, s, parity);
        local_tid += arg.blocks_per_dir * blockDim.x;
      }
      break;
    }
  }
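
  // Worked example of the blockIdx.x decomposition above (values are hypothetical): with
  // blocks_per_dir = 4 and two partitioned dimensions (say dim_map = {2, 3}), the grid uses
  // 2 * 2 * 4 = 16 blocks in x.  blockIdx.x = 13 gives local_block_idx = 13 % 4 = 1 and
  // dim_dir = 13 / 4 = 3, hence dir = 1 (forward) and dim = dim_map[3 / 2] = dim_map[1] = 3,
  // i.e. the second of the four blocks assigned to the forward T face.  Each such block then
  // strides through its face with a step of blocks_per_dir * blockDim.x.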

  template <typename Arg> __global__ void packStaggeredKernel(Arg arg)
  {
    const int sites_per_block = arg.sites_per_block;
    int local_tid = threadIdx.x;
    int tid = sites_per_block * blockIdx.x + local_tid;
    int s = blockDim.y * blockIdx.y + threadIdx.y;
    if (s >= arg.dc.Ls) return;

    // this is the parity used for load/store, but we use arg.parity for index mapping
    int parity = (arg.nParity == 2) ? blockDim.z * blockIdx.z + threadIdx.z : arg.parity;

    while (local_tid < sites_per_block && tid < arg.threads) {
      // determine which dimension we are packing
      int ghost_idx;
      const int dim = dimFromFaceIndex(ghost_idx, tid, arg);

      if (arg.nFace == 1) {
        switch (dim) {
        case 0: packStaggered<0, 1>(arg, ghost_idx, s, parity); break;
        case 1: packStaggered<1, 1>(arg, ghost_idx, s, parity); break;
        case 2: packStaggered<2, 1>(arg, ghost_idx, s, parity); break;
        case 3: packStaggered<3, 1>(arg, ghost_idx, s, parity); break;
        }
      } else if (arg.nFace == 3) {
        switch (dim) {
        case 0: packStaggered<0, 3>(arg, ghost_idx, s, parity); break;
        case 1: packStaggered<1, 3>(arg, ghost_idx, s, parity); break;
        case 2: packStaggered<2, 3>(arg, ghost_idx, s, parity); break;
        case 3: packStaggered<3, 3>(arg, ghost_idx, s, parity); break;
        }
      }

      local_tid += blockDim.x;
      tid += blockDim.x;
    } // while tid
  }

  template <typename Arg> __global__ void packStaggeredShmemKernel(Arg arg)
  {
    // (active_dims * 2 + dir) * blocks_per_dir + local_block_idx
    int local_block_idx = blockIdx.x % arg.blocks_per_dir;
    int dim_dir = blockIdx.x / arg.blocks_per_dir;
    int dir = dim_dir % 2;
    int dim;
    switch (dim_dir / 2) {
    case 0: dim = arg.dim_map[0]; break;
    case 1: dim = arg.dim_map[1]; break;
    case 2: dim = arg.dim_map[2]; break;
    case 3: dim = arg.dim_map[3]; break;
    }

    int local_tid = local_block_idx * blockDim.x + threadIdx.x;

    int s = blockDim.y * blockIdx.y + threadIdx.y;
    if (s >= arg.dc.Ls) return;

    // this is the parity used for load/store, but we use arg.parity for index mapping
    int parity = (arg.nParity == 2) ? blockDim.z * blockIdx.z + threadIdx.z : arg.parity;

    switch (dim) {
    case 0:
      while (local_tid < arg.nFace * arg.dc.ghostFaceCB[0]) {
        int ghost_idx = dir * arg.nFace * arg.dc.ghostFaceCB[0] + local_tid;
        if (arg.nFace == 1)
          packStaggered<0, 1>(arg, ghost_idx, s, parity);
        else
          packStaggered<0, 3>(arg, ghost_idx, s, parity);
        local_tid += arg.blocks_per_dir * blockDim.x;
      }
      break;
    case 1:
      while (local_tid < arg.nFace * arg.dc.ghostFaceCB[1]) {
        int ghost_idx = dir * arg.nFace * arg.dc.ghostFaceCB[1] + local_tid;
        if (arg.nFace == 1)
          packStaggered<1, 1>(arg, ghost_idx, s, parity);
        else
          packStaggered<1, 3>(arg, ghost_idx, s, parity);
        local_tid += arg.blocks_per_dir * blockDim.x;
      }
      break;
    case 2:
      while (local_tid < arg.nFace * arg.dc.ghostFaceCB[2]) {
        int ghost_idx = dir * arg.nFace * arg.dc.ghostFaceCB[2] + local_tid;
        if (arg.nFace == 1)
          packStaggered<2, 1>(arg, ghost_idx, s, parity);
        else
          packStaggered<2, 3>(arg, ghost_idx, s, parity);
        local_tid += arg.blocks_per_dir * blockDim.x;
      }
      break;
    case 3:
      while (local_tid < arg.nFace * arg.dc.ghostFaceCB[3]) {
        int ghost_idx = dir * arg.nFace * arg.dc.ghostFaceCB[3] + local_tid;
        if (arg.nFace == 1)
          packStaggered<3, 1>(arg, ghost_idx, s, parity);
        else
          packStaggered<3, 3>(arg, ghost_idx, s, parity);
        local_tid += arg.blocks_per_dir * blockDim.x;
      }
      break;
    }
  }

} // namespace quda