QUDA v1.1.0
A library for QCD on GPUs
color_spinor_pack.cu
#include <color_spinor_field.h>
#include <tune_quda.h>

#include <jitify_helper.cuh>
#include <kernels/color_spinor_pack.cuh>

/**
   @file color_spinor_pack.cu

   @brief This is the implementation of the color-spinor halo packer
   for an arbitrary field.  This implementation uses the fine-grained
   accessors and should support all field types regardless of
   precision, number of colors or spins, etc.

   Using different precisions for the field and for the halo is
   supported, though only QUDA_SINGLE_PRECISION fields with
   QUDA_HALF_PRECISION or QUDA_QUARTER_PRECISION halos are
   instantiated.  When an integer format is requested for the halos,
   block-float format is used.

   As well as tuning basic block sizes, the autotuner also tunes the
   number of dimensions to assign to each thread.  E.g., dim_thread=1
   means we have one thread for all dimensions, while dim_thread=4
   means we have four threads (i.e., one per dimension).  We always
   use separate threads for the forwards and backwards directions.
   Dimension, direction and parity are assigned to the z thread
   dimension.

   If using block-float format, all spin and color components of a
   given site have to reside in the same thread block (to allow us to
   compute the max element), so we override the autotuner to keep the
   z thread dimensions in the grid and not the block, and allow for
   smaller tuning increments of the thread block dimension in x to
   ensure that we can always fit within a single thread block.  It is
   this constraint that gives rise to the need to cap the limit for
   block-float support, e.g., max_block_float_nc.

   At present we launch a volume of threads (actually multiples
   thereof for direction / dimension), and thus we have coalesced
   reads but not coalesced writes.  A more optimal implementation
   would launch a surface of threads for each halo, giving coalesced
   writes.
 */
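/*
   Illustrative sketch (not the kernel itself): with dim_thread = tp.aux.x the
   z dimension of the launch indexes parity, direction and a subset of the
   dimensions, e.g. for dim_thread=4 roughly

     int tz     = blockIdx.z*blockDim.z + threadIdx.z;
     int dim    = tz % 4;        // one thread per dimension
     int dir    = (tz / 4) % 2;  // forwards / backwards handled by separate threads
     int parity = tz / 8;

   (the exact index packing used by GenericPackGhostKernel may differ).  For
   block-float halos each site additionally stores a per-site maximum, and all
   spin/color components of that site must sit in one thread block so the
   maximum can be computed before the components are scaled to integers.
 */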

namespace quda {

  template <typename Float, bool block_float, int Ns, int Ms, int Nc, int Mc, typename Arg>
  class GenericPackGhostLauncher : public TunableVectorYZ {
    Arg &arg;
    const ColorSpinorField &meta;
    unsigned int minThreads() const { return arg.volumeCB; }
    bool tuneGridDim() const { return false; }
    bool tuneAuxDim() const { return true; }

  public:
    inline GenericPackGhostLauncher(Arg &arg, const ColorSpinorField &meta, MemoryLocation *destination)
      : TunableVectorYZ((Ns/Ms)*(Nc/Mc), 2*arg.nParity), arg(arg), meta(meta) {

      if (meta.Location() == QUDA_CUDA_FIELD_LOCATION) {
#ifdef JITIFY
        create_jitify_program("kernels/color_spinor_pack.cuh");
#endif
      }

      strcpy(aux,compile_type_str(meta));
      strcat(aux,meta.AuxString());
      switch(meta.GhostPrecision()) {
      case QUDA_DOUBLE_PRECISION: strcat(aux,",halo_prec=8"); break;
      case QUDA_SINGLE_PRECISION: strcat(aux,",halo_prec=4"); break;
      case QUDA_HALF_PRECISION: strcat(aux,",halo_prec=2"); break;
      case QUDA_QUARTER_PRECISION: strcat(aux,",halo_prec=1"); break;
      default: errorQuda("Unexpected precision = %d", meta.GhostPrecision());
      }
      strcat(aux,comm_dim_partitioned_string());
      strcat(aux,comm_dim_topology_string());

      // record the location of where each pack buffer is in [2*dim+dir] ordering
      // 0 - no packing
      // 1 - pack to local GPU memory
      // 2 - pack to local mapped CPU memory
      // 3 - pack to remote mapped GPU memory
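      // e.g., with all four dimensions partitioned and every buffer packed to
      // local GPU memory, the label built below becomes ",dest=11111111"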
      char label[15] = ",dest=";
      for (int dim=0; dim<4; dim++) {
        for (int dir=0; dir<2; dir++) {
          label[2*dim+dir+6] = !comm_dim_partitioned(dim) ? '0' : destination[2*dim+dir] == Device ? '1' : destination[2*dim+dir] == Host ? '2' : '3';
        }
      }
      label[14] = '\0';
      strcat(aux,label);
    }

    inline void apply(const qudaStream_t &stream) {
      if (meta.Location() == QUDA_CPU_FIELD_LOCATION) {
        if (arg.nDim == 5) GenericPackGhost<Float,block_float,Ns,Ms,Nc,Mc,5,Arg>(arg);
        else GenericPackGhost<Float,block_float,Ns,Ms,Nc,Mc,4,Arg>(arg);
      } else {
        const TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
        arg.nParity2dim_threads = arg.nParity*2*tp.aux.x;
#ifdef JITIFY
        using namespace jitify::reflection;
        jitify_error = program->kernel("quda::GenericPackGhostKernel")
          .instantiate(Type<Float>(),block_float,Ns,Ms,Nc,Mc,arg.nDim,(int)tp.aux.x,Type<Arg>())
          .configure(tp.grid,tp.block,tp.shared_bytes,stream).launch(arg);
#else
        switch(tp.aux.x) {
        case 1:
          if (arg.nDim == 5) qudaLaunchKernel(GenericPackGhostKernel<Float,block_float,Ns,Ms,Nc,Mc,5,1,Arg>, tp, stream, arg);
          else qudaLaunchKernel(GenericPackGhostKernel<Float,block_float,Ns,Ms,Nc,Mc,4,1,Arg>, tp, stream, arg);
          break;
        case 2:
          if (arg.nDim == 5) qudaLaunchKernel(GenericPackGhostKernel<Float,block_float,Ns,Ms,Nc,Mc,5,2,Arg>, tp, stream, arg);
          else qudaLaunchKernel(GenericPackGhostKernel<Float,block_float,Ns,Ms,Nc,Mc,4,2,Arg>, tp, stream, arg);
          break;
        case 4:
          if (arg.nDim == 5) qudaLaunchKernel(GenericPackGhostKernel<Float,block_float,Ns,Ms,Nc,Mc,5,4,Arg>, tp, stream, arg);
          else qudaLaunchKernel(GenericPackGhostKernel<Float,block_float,Ns,Ms,Nc,Mc,4,4,Arg>, tp, stream, arg);
          break;
        }
#endif
      }
    }

    // if doing block float then all spin-color components must be within the same block
    void setColorSpinBlock(TuneParam &param) const {
      param.block.y = (Ns/Ms)*(Nc/Mc);
      param.grid.y = 1;
      param.block.z = 1;
      param.grid.z = arg.nParity*2*param.aux.x;
    }

    bool advanceBlockDim(TuneParam &param) const {
      if (!block_float) {
        return TunableVectorYZ::advanceBlockDim(param);
      } else {
        bool advance = Tunable::advanceBlockDim(param);
        setColorSpinBlock(param); // if doing block float then all spin-color components must be within the same block
        return advance;
      }
    }

    int blockStep() const { return block_float ? 2 : TunableVectorYZ::blockStep(); }
    int blockMin() const { return block_float ? 2 : TunableVectorYZ::blockMin(); }
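
    // With block_float the y block dimension is pinned to (Ns/Ms)*(Nc/Mc) by
    // setColorSpinBlock, so stepping block.x in increments of 2 (rather than
    // the larger default step) helps keep block.x*block.y within the maximum
    // threads per block, as described in the file header.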

    bool advanceAux(TuneParam &param) const {
      if (param.aux.x < 4) {
        param.aux.x *= 2;
        const_cast<GenericPackGhostLauncher*>(this)->resizeVector((Ns/Ms)*(Nc/Mc), arg.nParity*2*param.aux.x);
        TunableVectorYZ::initTuneParam(param);
        if (block_float) setColorSpinBlock(param);
        return true;
      }
      param.aux.x = 1;
      const_cast<GenericPackGhostLauncher*>(this)->resizeVector((Ns/Ms)*(Nc/Mc), arg.nParity*2*param.aux.x);
      TunableVectorYZ::initTuneParam(param);
      if (block_float) setColorSpinBlock(param);
      return false;
    }

    TuneKey tuneKey() const {
      return TuneKey(meta.VolString(), typeid(*this).name(), aux);
    }

    virtual void initTuneParam(TuneParam &param) const {
      TunableVectorYZ::initTuneParam(param);
      param.aux = make_int4(1,1,1,1);
      if (block_float) setColorSpinBlock(param);
    }

    virtual void defaultTuneParam(TuneParam &param) const {
      TunableVectorYZ::defaultTuneParam(param);
      param.aux = make_int4(1,1,1,1);
      if (block_float) setColorSpinBlock(param);
    }

    long long flops() const { return 0; }
    long long bytes() const {
      size_t totalBytes = 0;
      for (int d=0; d<4; d++) {
        if (!comm_dim_partitioned(d)) continue;
        totalBytes += arg.nFace*2*Ns*Nc*meta.SurfaceCB(d)*(meta.Precision() + meta.GhostPrecision());
      }
      return totalBytes;
    }
  };

  template <typename Float, typename ghostFloat, QudaFieldOrder order, int Ns, int Nc>
  inline void genericPackGhost(void **ghost, const ColorSpinorField &a, QudaParity parity,
                               int nFace, int dagger, MemoryLocation *destination)
  {
    typedef typename mapper<Float>::type RegFloat;
    typedef typename colorspinor::FieldOrderCB<RegFloat,Ns,Nc,1,order,Float,ghostFloat> Q;
    Q field(a, nFace, 0, ghost);

    constexpr int spins_per_thread = Ns == 1 ? 1 : 2; // make this autotunable?
    constexpr int colors_per_thread = Nc%2 == 0 ? 2 : 1;
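    // e.g., a Wilson field (Ns=4, Nc=3) gives spins_per_thread=2 and
    // colors_per_thread=1, so each site is handled by (4/2)*(3/1) = 6 y-threads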
    PackGhostArg<Q> arg(field, a, parity, nFace, dagger);

    constexpr bool block_float_requested = sizeof(Float) == QUDA_SINGLE_PRECISION &&
      (sizeof(ghostFloat) == QUDA_HALF_PRECISION || sizeof(ghostFloat) == QUDA_QUARTER_PRECISION);

    // if we only have short precision for the ghost then this means we have block-float
    constexpr bool block_float = block_float_requested && Nc <= max_block_float_nc;

    // ensure we only compile supported block-float kernels
    constexpr int Nc_ = (block_float_requested && Nc > max_block_float_nc) ? max_block_float_nc : Nc;

    if (block_float_requested && Nc > max_block_float_nc)
      errorQuda("Block-float format not supported for Nc = %d", Nc);

    GenericPackGhostLauncher<RegFloat,block_float,Ns,spins_per_thread,Nc_,colors_per_thread,PackGhostArg<Q> >
      launch(arg, a, destination);

    launch.apply(0);
  }

  // traits used to ensure we only instantiate arbitrary colors for nSpin=2,4 fields, and only 3 colors otherwise
  template<typename T, typename G, int nSpin, int nColor_> struct precision_spin_color_mapper { static constexpr int nColor = nColor_; };
#ifndef NSPIN1
  template<typename T, typename G, int nColor_> struct precision_spin_color_mapper<T,G,1,nColor_> { static constexpr int nColor = 3; };
#endif

#ifdef NSPIN4
  // never need block-float format with nSpin=4 fields for arbitrary colors
  template<int nColor_> struct precision_spin_color_mapper<float,short,4,nColor_> { static constexpr int nColor = 3; };
  template<int nColor_> struct precision_spin_color_mapper<float,int8_t,4,nColor_> { static constexpr int nColor = 3; };
#endif

#ifdef NSPIN1
  // never need block-float format with nSpin=1 fields for arbitrary colors
  template<int nColor_> struct precision_spin_color_mapper<float,short,1,nColor_> { static constexpr int nColor = 3; };
  template<int nColor_> struct precision_spin_color_mapper<float,int8_t,1,nColor_> { static constexpr int nColor = 3; };
#endif

#ifndef GPU_MULTIGRID_DOUBLE
#ifdef NSPIN1
  template<int nColor_> struct precision_spin_color_mapper<double,double,1,nColor_> { static constexpr int nColor = 3; };
#endif
#ifdef NSPIN2
  template<int nColor_> struct precision_spin_color_mapper<double,double,2,nColor_> { static constexpr int nColor = 3; };
#endif
#ifdef NSPIN4
  template<int nColor_> struct precision_spin_color_mapper<double,double,4,nColor_> { static constexpr int nColor = 3; };
#endif
#endif
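
  // For precision/spin combinations where arbitrary colors are not instantiated,
  // the mapper collapses the requested nColor to 3, so the corresponding branches
  // below compile against the existing Nc=3 kernel; the run-time checks in the
  // next function ensure those branches are never reached for such fields.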

  template <typename Float, typename ghostFloat, QudaFieldOrder order, int Ns>
  inline void genericPackGhost(void **ghost, const ColorSpinorField &a, QudaParity parity,
                               int nFace, int dagger, MemoryLocation *destination) {

#ifndef NSPIN1
    if (a.Ncolor() != 3 && a.Nspin() == 1)
      errorQuda("Ncolor = %d not supported for Nspin = %d fields", a.Ncolor(), a.Nspin());
#endif
    if (a.Ncolor() != 3 && a.Nspin() == 4 && a.Precision() == QUDA_SINGLE_PRECISION &&
        (a.GhostPrecision() == QUDA_HALF_PRECISION || a.GhostPrecision() == QUDA_QUARTER_PRECISION) )
      errorQuda("Ncolor = %d not supported for Nspin = %d fields with precision = %d and ghost_precision = %d",
                a.Ncolor(), a.Nspin(), a.Precision(), a.GhostPrecision());
#ifndef GPU_MULTIGRID_DOUBLE
    if ( a.Ncolor() != 3 && a.Precision() == QUDA_DOUBLE_PRECISION)
      errorQuda("Ncolor = %d not supported for double precision fields", a.Ncolor());
#endif

    if (a.Ncolor() == 3) {
      genericPackGhost<Float,ghostFloat,order,Ns,precision_spin_color_mapper<Float,ghostFloat,Ns,3>::nColor>(ghost, a, parity, nFace, dagger, destination);
#ifdef GPU_MULTIGRID
    } else if (a.Ncolor() == 6) {
      genericPackGhost<Float,ghostFloat,order,Ns,precision_spin_color_mapper<Float,ghostFloat,Ns,6>::nColor>(ghost, a, parity, nFace, dagger, destination);
    } else if (a.Ncolor() == 18) { // Needed for two level free field Wilson
      genericPackGhost<Float,ghostFloat,order,Ns,precision_spin_color_mapper<Float,ghostFloat,Ns,18>::nColor>(ghost, a, parity, nFace, dagger, destination);
    } else if (a.Ncolor() == 24) { // Needed for K-D staggered Wilson
      genericPackGhost<Float,ghostFloat,order,Ns,precision_spin_color_mapper<Float,ghostFloat,Ns,24>::nColor>(ghost, a, parity, nFace, dagger, destination);
#ifdef NSPIN4
    } else if (a.Ncolor() == 32) { // Needed for Wilson
      genericPackGhost<Float,ghostFloat,order,Ns,precision_spin_color_mapper<Float,ghostFloat,Ns,32>::nColor>(ghost, a, parity, nFace, dagger, destination);
    } else if (a.Ncolor() == 36) { // Needed for three level free field Wilson
      genericPackGhost<Float,ghostFloat,order,Ns,precision_spin_color_mapper<Float,ghostFloat,Ns,36>::nColor>(ghost, a, parity, nFace, dagger, destination);
#endif // NSPIN4
#ifdef NSPIN1
    } else if (a.Ncolor() == 64) { // Needed for staggered Nc = 64
      genericPackGhost<Float,ghostFloat,order,Ns,precision_spin_color_mapper<Float,ghostFloat,Ns,64>::nColor>(ghost, a, parity, nFace, dagger, destination);
#endif // NSPIN1
    } else if (a.Ncolor() == 72) { // wilson 3 -> 24 nvec, or staggered 3 -> 24 nvec, which could end up getting used for Laplace...
      genericPackGhost<Float,ghostFloat,order,Ns,precision_spin_color_mapper<Float,ghostFloat,Ns,72>::nColor>(ghost, a, parity, nFace, dagger, destination);
    } else if (a.Ncolor() == 96) { // wilson 3 -> 32 nvec, or staggered Nc = 96
      genericPackGhost<Float,ghostFloat,order,Ns,precision_spin_color_mapper<Float,ghostFloat,Ns,96>::nColor>(ghost, a, parity, nFace, dagger, destination);
#ifdef NSPIN1
    } else if (a.Ncolor() == 192) { // staggered 3 -> 64 Nvec
      genericPackGhost<Float,ghostFloat,order,Ns,precision_spin_color_mapper<Float,ghostFloat,Ns,192>::nColor>(ghost, a, parity, nFace, dagger, destination);
    } else if (a.Ncolor() == 288) { // staggered 3 -> 96 Nvec
      genericPackGhost<Float,ghostFloat,order,Ns,precision_spin_color_mapper<Float,ghostFloat,Ns,288>::nColor>(ghost, a, parity, nFace, dagger, destination);
#endif // NSPIN1
    } else if (a.Ncolor() == 576) { // staggered KD free-field or wilson 24 -> 24 nvec
      genericPackGhost<Float,ghostFloat,order,Ns,precision_spin_color_mapper<Float,ghostFloat,Ns,576>::nColor>(ghost, a, parity, nFace, dagger, destination);
#ifdef NSPIN4
    } else if (a.Ncolor() == 768) { // wilson 24 -> 32 nvec
      genericPackGhost<Float,ghostFloat,order,Ns,precision_spin_color_mapper<Float,ghostFloat,Ns,768>::nColor>(ghost, a, parity, nFace, dagger, destination);
    } else if (a.Ncolor() == 1024) { // wilson 32 -> 32 nvec
      genericPackGhost<Float,ghostFloat,order,Ns,precision_spin_color_mapper<Float,ghostFloat,Ns,1024>::nColor>(ghost, a, parity, nFace, dagger, destination);
#endif // NSPIN4
#ifdef NSPIN1
    } else if (a.Ncolor() == 1536) { // staggered KD 24 -> 64 nvec
      genericPackGhost<Float,ghostFloat,order,Ns,precision_spin_color_mapper<Float,ghostFloat,Ns,1536>::nColor>(ghost, a, parity, nFace, dagger, destination);
    } else if (a.Ncolor() == 2304) { // staggered KD 24 -> 96 nvec
      genericPackGhost<Float,ghostFloat,order,Ns,precision_spin_color_mapper<Float,ghostFloat,Ns,2304>::nColor>(ghost, a, parity, nFace, dagger, destination);
    } else if (a.Ncolor() == 4096) { // staggered 64 -> 64
      genericPackGhost<Float,ghostFloat,order,Ns,precision_spin_color_mapper<Float,ghostFloat,Ns,4096>::nColor>(ghost, a, parity, nFace, dagger, destination);
    } else if (a.Ncolor() == 6144) { // staggered 64 -> 96 nvec
      genericPackGhost<Float,ghostFloat,order,Ns,precision_spin_color_mapper<Float,ghostFloat,Ns,6144>::nColor>(ghost, a, parity, nFace, dagger, destination);
    } else if (a.Ncolor() == 9216) { // staggered 96 -> 96 nvec
      genericPackGhost<Float,ghostFloat,order,Ns,precision_spin_color_mapper<Float,ghostFloat,Ns,9216>::nColor>(ghost, a, parity, nFace, dagger, destination);
#endif // NSPIN1
#endif // GPU_MULTIGRID
    } else {
      errorQuda("Unsupported nColor = %d", a.Ncolor());
    }

  }

  // traits used to ensure we only instantiate float4 for spin=4 fields
  template<int nSpin,QudaFieldOrder order_> struct spin_order_mapper { static constexpr QudaFieldOrder order = order_; };
  template<> struct spin_order_mapper<2,QUDA_FLOAT4_FIELD_ORDER> { static constexpr QudaFieldOrder order = QUDA_FLOAT2_FIELD_ORDER; };
  template<> struct spin_order_mapper<1,QUDA_FLOAT4_FIELD_ORDER> { static constexpr QudaFieldOrder order = QUDA_FLOAT2_FIELD_ORDER; };
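  // e.g., spin_order_mapper<2,QUDA_FLOAT4_FIELD_ORDER>::order is QUDA_FLOAT2_FIELD_ORDER,
  // so the nSpin=2 and nSpin=1 branches below never instantiate a float4 accessor
  // (their run-time field-order check errors out before the order matters).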

  template <typename Float, typename ghostFloat, QudaFieldOrder order>
  inline void genericPackGhost(void **ghost, const ColorSpinorField &a, QudaParity parity,
                               int nFace, int dagger, MemoryLocation *destination) {

    if (a.Nspin() == 4) {
#ifdef NSPIN4
      genericPackGhost<Float,ghostFloat,order,4>(ghost, a, parity, nFace, dagger, destination);
#else
      errorQuda("nSpin=4 not enabled for this build");
#endif
    } else if (a.Nspin() == 2) {
#ifdef NSPIN2
      if (order == QUDA_FLOAT4_FIELD_ORDER) errorQuda("Field order %d with nSpin = %d not supported", order, a.Nspin());
      genericPackGhost<Float,ghostFloat,spin_order_mapper<2,order>::order,2>(ghost, a, parity, nFace, dagger, destination);
#else
      errorQuda("nSpin=2 not enabled for this build");
#endif
    } else if (a.Nspin() == 1) {
#ifdef NSPIN1
      if (order == QUDA_FLOAT4_FIELD_ORDER) errorQuda("Field order %d with nSpin = %d not supported", order, a.Nspin());
      genericPackGhost<Float,ghostFloat,spin_order_mapper<1,order>::order,1>(ghost, a, parity, nFace, dagger, destination);
#else
      errorQuda("nSpin=1 not enabled for this build");
#endif
    } else {
      errorQuda("Unsupported nSpin = %d", a.Nspin());
    }

  }

  // traits used to ensure we only instantiate double and float templates for non-native fields
  template<typename> struct non_native_precision_mapper { };
  template<> struct non_native_precision_mapper<double> { typedef double type; };
  template<> struct non_native_precision_mapper<float> { typedef float type; };
  template<> struct non_native_precision_mapper<short> { typedef float type; };
  template<> struct non_native_precision_mapper<int8_t> { typedef float type; };

  // traits used to ensure we only instantiate float and lower precision for float4 fields
  template<typename T> struct float4_precision_mapper { typedef T type; };
  template<> struct float4_precision_mapper<double> { typedef float type; };
  template<> struct float4_precision_mapper<short> { typedef float type; };
  template<> struct float4_precision_mapper<int8_t> { typedef float type; };

  template <typename Float, typename ghostFloat>
  inline void genericPackGhost(void **ghost, const ColorSpinorField &a, QudaParity parity,
                               int nFace, int dagger, MemoryLocation *destination) {

    if (a.FieldOrder() == QUDA_FLOAT2_FIELD_ORDER) {

      // all precisions, color and spin can use this order
      genericPackGhost<Float,ghostFloat,QUDA_FLOAT2_FIELD_ORDER>(ghost, a, parity, nFace, dagger, destination);

    } else if (a.FieldOrder() == QUDA_FLOAT4_FIELD_ORDER) {

      // never have double fields here
      if (typeid(Float) != typeid(typename float4_precision_mapper<Float>::type))
        errorQuda("Precision %d not supported for field type %d", a.Precision(), a.FieldOrder());
      if (typeid(ghostFloat) != typeid(typename float4_precision_mapper<ghostFloat>::type))
        errorQuda("Ghost precision %d not supported for field type %d", a.GhostPrecision(), a.FieldOrder());
      genericPackGhost<typename float4_precision_mapper<Float>::type,
                       typename float4_precision_mapper<ghostFloat>::type,
                       QUDA_FLOAT4_FIELD_ORDER>(ghost, a, parity, nFace, dagger, destination);

    } else if (a.FieldOrder() == QUDA_SPACE_SPIN_COLOR_FIELD_ORDER) {
#ifndef GPU_MULTIGRID // with MG mma we need half-precision AoS exchange support
      if (typeid(Float) != typeid(typename non_native_precision_mapper<Float>::type))
        errorQuda("Precision %d not supported for field type %d", a.Precision(), a.FieldOrder());
      if (typeid(ghostFloat) != typeid(typename non_native_precision_mapper<ghostFloat>::type))
        errorQuda("Ghost precision %d not supported for field type %d", a.GhostPrecision(), a.FieldOrder());
      genericPackGhost<typename non_native_precision_mapper<Float>::type,
                       typename non_native_precision_mapper<ghostFloat>::type,
                       QUDA_SPACE_SPIN_COLOR_FIELD_ORDER>(ghost, a, parity, nFace, dagger, destination);
#else
      genericPackGhost<Float, ghostFloat, QUDA_SPACE_SPIN_COLOR_FIELD_ORDER>(ghost, a, parity, nFace, dagger,
                                                                             destination);
#endif
    } else {
      errorQuda("Unsupported field order = %d", a.FieldOrder());
    }

  }

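  // Typical call site (sketch, not taken from this file): pack one parity of
  // the halo with the default destinations, e.g.
  //
  //   quda::genericPackGhost(send_buffers, field, QUDA_EVEN_PARITY, 1, 0, nullptr);
  //
  // where send_buffers is the caller's array of per-dimension/direction ghost
  // buffers (a hypothetical name).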
  void genericPackGhost(void **ghost, const ColorSpinorField &a, QudaParity parity,
                        int nFace, int dagger, MemoryLocation *destination_) {

    if (a.FieldOrder() == QUDA_QOP_DOMAIN_WALL_FIELD_ORDER) {
      errorQuda("Field order %d not supported", a.FieldOrder());
    }

    // set default location to match field type
    MemoryLocation destination[2*QUDA_MAX_DIM];
    for (int i=0; i<4*2; i++) {
      destination[i] = destination_ ? destination_[i] : a.Location() == QUDA_CUDA_FIELD_LOCATION ? Device : Host;
    }

    // only do packing if one of the dimensions is partitioned
    bool partitioned = false;
    for (int d=0; d<4; d++)
      if (comm_dim_partitioned(d)) partitioned = true;
    if (!partitioned) return;

    if (a.Precision() == QUDA_DOUBLE_PRECISION) {
      if (a.GhostPrecision() == QUDA_DOUBLE_PRECISION) {
        genericPackGhost<double,double>(ghost, a, parity, nFace, dagger, destination);
      } else {
        errorQuda("precision = %d and ghost precision = %d not supported", a.Precision(), a.GhostPrecision());
      }
    } else if (a.Precision() == QUDA_SINGLE_PRECISION) {
      if (a.GhostPrecision() == QUDA_SINGLE_PRECISION) {
        genericPackGhost<float,float>(ghost, a, parity, nFace, dagger, destination);
      } else if (a.GhostPrecision() == QUDA_HALF_PRECISION) {
#if QUDA_PRECISION & 2
        genericPackGhost<float,short>(ghost, a, parity, nFace, dagger, destination);
#else
        errorQuda("QUDA_PRECISION=%d does not enable half precision", QUDA_PRECISION);
#endif
      } else if (a.GhostPrecision() == QUDA_QUARTER_PRECISION) {
#if QUDA_PRECISION & 1
        genericPackGhost<float,int8_t>(ghost, a, parity, nFace, dagger, destination);
#else
        errorQuda("QUDA_PRECISION=%d does not enable quarter precision", QUDA_PRECISION);
#endif
      } else {
        errorQuda("precision = %d and ghost precision = %d not supported", a.Precision(), a.GhostPrecision());
      }
    } else if (a.Precision() == QUDA_HALF_PRECISION) {
      if (a.GhostPrecision() == QUDA_HALF_PRECISION) {
#if QUDA_PRECISION & 2
        genericPackGhost<short,short>(ghost, a, parity, nFace, dagger, destination);
#else
        errorQuda("QUDA_PRECISION=%d does not enable half precision", QUDA_PRECISION);
#endif
      } else {
        errorQuda("precision = %d and ghost precision = %d not supported", a.Precision(), a.GhostPrecision());
      }
    } else {
      errorQuda("Unsupported precision %d", a.Precision());
    }

  }

} // namespace quda