QUDA  1.0.0
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
dslash_helper.cuh
Go to the documentation of this file.
1 #pragma once
2 
3 #include <color_spinor_field.h>
4 #include <gauge_field.h>
5 #include <register_traits.h>
6 #include <index_helper.cuh>
7 
8 namespace quda
9 {
10 
17  template <KernelType type> __host__ __device__ inline bool doHalo(int dim = -1)
18  {
19  switch (type) {
20  case EXTERIOR_KERNEL_ALL: return true;
21  case EXTERIOR_KERNEL_X: return dim == 0 || dim == -1 ? true : false;
22  case EXTERIOR_KERNEL_Y: return dim == 1 || dim == -1 ? true : false;
23  case EXTERIOR_KERNEL_Z: return dim == 2 || dim == -1 ? true : false;
24  case EXTERIOR_KERNEL_T: return dim == 3 || dim == -1 ? true : false;
25  case INTERIOR_KERNEL: return false;
26  }
27  return false;
28  }
29 
35  template <KernelType type> __host__ __device__ inline bool doBulk()
36  {
37  switch (type) {
39  case EXTERIOR_KERNEL_X:
40  case EXTERIOR_KERNEL_Y:
41  case EXTERIOR_KERNEL_Z:
42  case EXTERIOR_KERNEL_T: return false;
43  case INTERIOR_KERNEL: return true;
44  }
45  return false;
46  }
47 
  /**
     @brief Helper function to determine if the application of the
     derivative in the dslash is complete at the given site, i.e.,
     whether no further (exterior) contributions are pending.
     @param[in] arg Kernel argument struct (provides commDim[] partition
     flags and dc.X[] local lattice extents)
     @param[in] coord Full lattice coordinate of this site
     @return True if all contributions to this site have been applied
   */
  template <KernelType type, typename Arg> __host__ __device__ inline bool isComplete(const Arg &arg, int coord[])
  {

    int incomplete = 0; // Have all 8 contributions been computed for this site?

    // each case accumulates the "on a partitioned boundary in this or a
    // lower dimension" tests by falling through to the cases below it
    switch (type) { // intentional fall-through
    case EXTERIOR_KERNEL_ALL: incomplete = false; break; // all active threads are complete
    case INTERIOR_KERNEL:
      // interior kernel: site incomplete if on any partitioned boundary (t, z, y, x)
      incomplete = incomplete || (arg.commDim[3] && (coord[3] == 0 || coord[3] == (arg.dc.X[3] - 1)));
    case EXTERIOR_KERNEL_T:
      // t-exterior kernel: t contribution done, check remaining z, y, x boundaries
      incomplete = incomplete || (arg.commDim[2] && (coord[2] == 0 || coord[2] == (arg.dc.X[2] - 1)));
    case EXTERIOR_KERNEL_Z:
      incomplete = incomplete || (arg.commDim[1] && (coord[1] == 0 || coord[1] == (arg.dc.X[1] - 1)));
    case EXTERIOR_KERNEL_Y:
      incomplete = incomplete || (arg.commDim[0] && (coord[0] == 0 || coord[0] == (arg.dc.X[0] - 1)));
    case EXTERIOR_KERNEL_X: break; // x-exterior kernel: nothing further pending
    }

    return !incomplete;
  }
75 
  /**
     @brief Compute the space-time coordinates we are at.
     @param[out] coord Resulting space-time coordinates
     @param[in] arg Kernel argument struct
     @param[in,out] idx Thread index; for exterior kernels this is
     rebased to a face-local index as a side effect
     @param[in] parity Field parity
     @param[in,out] dim For the fused exterior kernel, set to the
     dimension this thread's face belongs to
     @return The checkerboard space-time index of this site
   */
  template <int nDim, QudaPCType pc_type, KernelType kernel_type, typename Arg, int nface_ = 1>
  __host__ __device__ inline int getCoords(int coord[], const Arg &arg, int &idx, int parity, int &dim)
  {

    int x_cb, X;
    dim = kernel_type; // keep compiler happy

    // only for 5-d checkerboarding where we need to include the fifth dimension
    const int Ls = (nDim == 5 && pc_type == QUDA_5D_PC ? (int)arg.dim[4] : 1);

    if (kernel_type == INTERIOR_KERNEL) {
      // interior: thread index is the checkerboard site index directly
      x_cb = idx;
      if (nDim == 5)
        getCoords5CB(coord, idx, arg.dim, arg.X0h, parity, pc_type);
      else
        getCoordsCB(coord, idx, arg.dim, arg.X0h, parity);
    } else if (kernel_type != EXTERIOR_KERNEL_ALL) {

      // compute face index and then compute coords
      // face_num selects between the two faces (lower/upper) of this dimension
      const int face_size = nface_ * arg.dc.ghostFaceCB[kernel_type] * Ls;
      const int face_num = idx >= face_size;
      idx -= face_num * face_size;
      coordsFromFaceIndex<nDim, pc_type, kernel_type, nface_>(X, x_cb, coord, idx, face_num, parity, arg);

    } else { // fused kernel

      // work out which dimension this thread corresponds to, then compute coords;
      // threadDimMapLower/Upper partition the thread index range by dimension
      if (idx < arg.threadDimMapUpper[0] * Ls) { // x face
        dim = 0;
        // NOTE(review): no threadDimMapLower[0] subtraction here — presumably
        // the x-face range starts at zero; confirm against where the map is set
        const int face_size = nface_ * arg.dc.ghostFaceCB[dim] * Ls;
        const int face_num = idx >= face_size; // lower (0) or upper (1) face
        idx -= face_num * face_size;
        coordsFromFaceIndex<nDim, pc_type, 0, nface_>(X, x_cb, coord, idx, face_num, parity, arg);
      } else if (idx < arg.threadDimMapUpper[1] * Ls) { // y face
        dim = 1;
        idx -= arg.threadDimMapLower[1] * Ls; // rebase to this dimension's range
        const int face_size = nface_ * arg.dc.ghostFaceCB[dim] * Ls;
        const int face_num = idx >= face_size;
        idx -= face_num * face_size;
        coordsFromFaceIndex<nDim, pc_type, 1, nface_>(X, x_cb, coord, idx, face_num, parity, arg);
      } else if (idx < arg.threadDimMapUpper[2] * Ls) { // z face
        dim = 2;
        idx -= arg.threadDimMapLower[2] * Ls;
        const int face_size = nface_ * arg.dc.ghostFaceCB[dim] * Ls;
        const int face_num = idx >= face_size;
        idx -= face_num * face_size;
        coordsFromFaceIndex<nDim, pc_type, 2, nface_>(X, x_cb, coord, idx, face_num, parity, arg);
      } else { // t face
        dim = 3;
        idx -= arg.threadDimMapLower[3] * Ls;
        const int face_size = nface_ * arg.dc.ghostFaceCB[dim] * Ls;
        const int face_num = idx >= face_size;
        idx -= face_num * face_size;
        coordsFromFaceIndex<nDim, pc_type, 3, nface_>(X, x_cb, coord, idx, face_num, parity, arg);
      }
    }

    return x_cb;
  }
146 
155  template <int dim, typename Arg> inline __host__ __device__ bool inBoundary(const int coord[], const Arg &arg)
156  {
157  return ((coord[dim] >= arg.dim[dim] - arg.nFace) || (coord[dim] < arg.nFace));
158  }
159 
187  template <KernelType kernel_type, typename Arg>
188  inline __device__ bool isActive(bool &active, int threadDim, int offsetDim, const int coord[], const Arg &arg)
189  {
190  // Threads with threadDim = t can handle t,z,y,x offsets
191  // Threads with threadDim = z can handle z,y,x offsets
192  // Threads with threadDim = y can handle y,x offsets
193  // Threads with threadDim = x can handle x offsets
194  if (!arg.ghostDim[offsetDim]) return false;
195 
196  if (kernel_type == EXTERIOR_KERNEL_ALL) {
197  if (threadDim < offsetDim) return false;
198 
199  switch (threadDim) {
200  case 3: // threadDim = T
201  break;
202 
203  case 2: // threadDim = Z
204  if (!arg.ghostDim[3]) break;
205  if (arg.ghostDim[3] && inBoundary<3>(coord, arg)) return false;
206  break;
207 
208  case 1: // threadDim = Y
209  if ((!arg.ghostDim[3]) && (!arg.ghostDim[2])) break;
210  if (arg.ghostDim[3] && inBoundary<3>(coord, arg)) return false;
211  if (arg.ghostDim[2] && inBoundary<2>(coord, arg)) return false;
212  break;
213 
214  case 0: // threadDim = X
215  if ((!arg.ghostDim[3]) && (!arg.ghostDim[2]) && (!arg.ghostDim[1])) break;
216  if (arg.ghostDim[3] && inBoundary<3>(coord, arg)) return false;
217  if (arg.ghostDim[2] && inBoundary<2>(coord, arg)) return false;
218  if (arg.ghostDim[1] && inBoundary<1>(coord, arg)) return false;
219  break;
220 
221  default: break;
222  }
223  }
224 
225  active = true;
226  return true;
227  }
228 
  /**
     @brief Base argument struct shared by the dslash kernels, holding
     the geometry, partitioning and policy parameters.
   */
  template <typename Float> struct DslashArg {

    typedef typename mapper<Float>::type real; // register precision mapped from storage precision

    const int parity; // only use this for single parity fields
    const int nParity; // number of parities we're working on
    const int nFace; // hard code to 1 for now

    const int_fastdiv dim[5]; // full lattice dimensions
    const int volumeCB; // checkerboarded volume
    int commDim[4]; // whether a given dimension is partitioned or not (potentially overridden for Schwarz)
    int ghostDim[4]; // always equal to actual dimension partitioning (used inside kernel to ensure correct indexing)

    const bool dagger; // dagger
    const bool xpay; // whether we are doing xpay or not

    real t_proj_scale; // factor to correct for T-dimensional spin projection

    DslashConstant dc; // pre-computed dslash constants for optimized indexing
    KernelType kernel_type; // interior, exterior_t, etc.
    bool remote_write; // used by the autotuner to switch on/off remote writing vs using copy engines

    int_fastdiv threads; // number of threads in x-thread dimension

    const bool spin_project; // whether to spin project nSpin=4 fields (generally true, except for, e.g., covariant derivative)

    // these are set with symmetric preconditioned twisted-mass dagger
    // operator for the packing (which needs to do a twist)
    real twist_a; // scale factor
    real twist_b; // chiral twist
    real twist_c; // flavor twist

    // NOTE(review): the constructor below also initializes reconstruct, X0h,
    // threadDimMapLower and threadDimMapUpper; their declarations are not
    // visible in this rendering — confirm against the full header.

    // constructor needed for staggered to set xpay from derived class
    DslashArg(const ColorSpinorField &in, const GaugeField &U, int parity, bool dagger, bool xpay, int nFace,
              int spin_project, const int *comm_override) :
      parity(parity),
      nParity(in.SiteSubset()),
      nFace(nFace),
      reconstruct(U.Reconstruct()),
      X0h(nParity == 2 ? in.X(0) / 2 : in.X(0)), // checkerboarded x-extent
      dim {(3 - nParity) * in.X(0), in.X(1), in.X(2), in.X(3), in.Ndim() == 5 ? in.X(4) : 1},
      volumeCB(in.VolumeCB()),
      dagger(dagger),
      xpay(xpay),
      threads(in.VolumeCB()),
      threadDimMapLower {},
      threadDimMapUpper {},
      spin_project(spin_project),
      twist_a(0.0),
      twist_b(0.0),
      twist_c(0.0)
    {
      for (int d = 0; d < 4; d++) {
        ghostDim[d] = comm_dim_partitioned(d);
        // comm_override can disable communication in a dimension (e.g. Schwarz)
        commDim[d] = (comm_override[d] == 0) ? 0 : comm_dim_partitioned(d);
      }

      if (in.Location() == QUDA_CUDA_FIELD_LOCATION) {
        // create comms buffers - need to do this before we grab the dslash constants
        ColorSpinorField *in_ = const_cast<ColorSpinorField *>(&in);
        static_cast<cudaColorSpinorField *>(in_)->createComms(nFace, spin_project);
      }
      dc = in.getDslashConstant();
    }
  };
299 
300  template <typename Float> std::ostream &operator<<(std::ostream &out, const DslashArg<Float> &arg)
301  {
302  out << "parity = " << arg.parity << std::endl;
303  out << "nParity = " << arg.nParity << std::endl;
304  out << "nFace = " << arg.nFace << std::endl;
305  out << "reconstruct = " << arg.reconstruct << std::endl;
306  out << "X0h = " << arg.X0h << std::endl;
307  out << "dim = { ";
308  for (int i = 0; i < 5; i++) out << arg.dim[i] << (i < 4 ? ", " : " }");
309  out << std::endl;
310  out << "commDim = { ";
311  for (int i = 0; i < 4; i++) out << arg.commDim[i] << (i < 3 ? ", " : " }");
312  out << std::endl;
313  out << "ghostDim = { ";
314  for (int i = 0; i < 4; i++) out << arg.ghostDim[i] << (i < 3 ? ", " : " }");
315  out << std::endl;
316  out << "volumeCB = " << arg.volumeCB << std::endl;
317  out << "dagger = " << arg.dagger << std::endl;
318  out << "xpay = " << arg.xpay << std::endl;
319  out << "kernel_type = " << arg.kernel_type << std::endl;
320  out << "remote_write = " << arg.remote_write << std::endl;
321  out << "threads = " << arg.threads << std::endl;
322  out << "threadDimMapLower = { ";
323  for (int i = 0; i < 4; i++) out << arg.threadDimMapLower[i] << (i < 3 ? ", " : " }");
324  out << std::endl;
325  out << "threadDimMapUpper = { ";
326  for (int i = 0; i < 4; i++) out << arg.threadDimMapUpper[i] << (i < 3 ? ", " : " }");
327  out << std::endl;
328  out << "twist_a = " << arg.twist_a;
329  out << "twist_b = " << arg.twist_b;
330  out << "twist_c = " << arg.twist_c;
331  return out;
332  }
333 
334 } // namespace quda
KernelType kernel_type
DslashConstant dc
Constants used by dslash and packing kernels.
static __device__ __host__ void getCoords5CB(int x[5], int cb_index, const I X[5], J X0h, int parity, QudaPCType pc_type)
const bool spin_project
static __device__ __host__ void getCoordsCB(int x[], int cb_index, const I X[], J X0h, int parity)
__host__ __device__ bool doBulk()
Helper function to determine if we should do interior computation.
const QudaReconstructType reconstruct
int Ls
Definition: test_util.cpp:38
int_fastdiv threads
const int_fastdiv dim[5]
mapper< Float >::type real
__device__ bool isActive(bool &active, int threadDim, int offsetDim, const int coord[], const Arg &arg)
Compute whether this thread should be active for updating a given offsetDim halo. For non-fused halo update kernels this is a trivial kernel that just checks if the given dimension is partitioned and if so, return true.
const int_fastdiv X0h
cpuColorSpinorField * in
Generic reconstruction helper with no reconstruction.
__host__ __device__ bool doHalo(int dim=-1)
Helper function to determine if we should do halo computation.
Provides precision abstractions and defines the register precision given the storage precision using ...
int X[4]
Definition: covdev_test.cpp:70
__host__ __device__ bool inBoundary(const int coord[], const Arg &arg)
Compute whether the provided coordinate is within the halo region boundary of a given dimension...
cpuColorSpinorField * out
enum QudaReconstructType_s QudaReconstructType
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
const int * X() const
DslashArg(const ColorSpinorField &in, const GaugeField &U, int parity, bool dagger, bool xpay, int nFace, int spin_project, const int *comm_override)
QudaParity parity
Definition: covdev_test.cpp:54
int comm_dim_partitioned(int dim)
__host__ __device__ int getCoords(int coord[], const Arg &arg, int &idx, int parity, int &dim)
Compute the space-time coordinates we are at.
__host__ __device__ bool isComplete(const Arg &arg, int coord[])
Helper function to determine if the application of the derivative in the dslash is complete...