#include <color_spinor_field.h>
// STRIPED - spread the blocks throughout the workload to ensure we
// work on all directions/dimensions simultaneously to maximize NVLink saturation
// if not STRIPED, we assign one thread block per direction / dimension
// currently does not work with NVSHMEM
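// Illustrative sketch of the two assignment schemes (numbers assumed for the
// example): with 2 partitioned dimensions and 8 thread blocks, STRIPED spreads
// all 8 blocks across every face so all directions/dimensions are packed
// concurrently, whereas the non-STRIPED path dedicates 8 / (2 * 2) = 2 blocks
// to each direction of each partitioned dimension.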
#include <dslash_quda.h>
#include <kernels/dslash_pack.cuh>
#include <instantiate.h>
int *getPackComms() { return commDim; }
void setPackComms(const int *comm_dim)
for (int i = 0; i < 4; i++) commDim[i] = comm_dim[i];
for (int i = 4; i < QUDA_MAX_DIM; i++) commDim[i] = 0;
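// Example (illustrative): if only the t dimension is partitioned across
// ranks, the caller passes {0, 0, 0, 1}; entries beyond the four spacetime
// dimensions are cleared so stale values cannot leak in.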
template <typename Float, int nSpin, int nColor, bool spin_project>
std::ostream &operator<<(std::ostream &out, const PackArg<Float, nSpin, nColor, spin_project> &arg)
out << "parity = " << arg.parity << std::endl;
out << "nParity = " << arg.nParity << std::endl;
out << "pc_type = " << arg.pc_type << std::endl;
out << "nFace = " << arg.nFace << std::endl;
out << "dagger = " << arg.dagger << std::endl;
out << "a = " << arg.a << std::endl;
out << "b = " << arg.b << std::endl;
out << "c = " << arg.c << std::endl;
out << "twist = " << arg.twist << std::endl;
out << "threads = " << arg.threads << std::endl;
out << "threadDimMapLower = { ";
for (int i = 0; i < 4; i++) out << arg.threadDimMapLower[i] << (i < 3 ? ", " : " }");
out << "threadDimMapUpper = { ";
for (int i = 0; i < 4; i++) out << arg.threadDimMapUpper[i] << (i < 3 ? ", " : " }");
out << "sites_per_block = " << arg.sites_per_block << std::endl;
// FIXME - add CPU variant
template <typename Float, int nColor, bool spin_project> class Pack : TunableVectorYZ
const ColorSpinorField &in;
MemoryLocation location;
const bool dagger; // only has meaning for nSpin=4
int twist; // only has meaning for nSpin=4
static constexpr int shmem = 0;
bool tuneGridDim() const { return true; } // If striping, always tune grid dimension
unsigned int maxGridSize() const
if (location & Host) {
// if zero-copy policy then set the maximum number of blocks to
// 3 * the number of dimensions we are communicating
// if zero-copy policy then assign up to four thread blocks
// per direction per dimension (effectively no grid-size tuning)
for (int d = 0; d < in.Ndim(); d++) nDimComms += commDim[d];
return max * nDimComms;
return TunableVectorYZ::maxGridSize();
} // use no more than a quarter of the GPU
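// Illustrative (the actual value of max is set in the elided branch above):
// with all four dimensions partitioned, nDimComms = 4, so a cap of, say,
// max = 3 per direction/dimension limits the autotuner to 12 blocks in total.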
unsigned int minGridSize() const
if (location & Host || location & Shmem) {
// if zero-copy policy then set the minimum number of blocks to
// 1 * the number of dimensions we are communicating
// if zero-copy policy then assign exactly one thread block
// per direction per dimension (effectively no grid-size tuning)
for (int d = 0; d < in.Ndim(); d++) nDimComms += commDim[d];
return min * nDimComms;
return TunableVectorYZ::minGridSize();
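// Illustrative: with min = 1 (as the comment above implies) and three
// partitioned dimensions, the autotuner never drops below 3 blocks, so no
// communicating dimension is ever left without a packing block.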
return TunableVectorYZ::gridStep();
if (location & Host || location & Shmem) {
// the shmem kernel requires the grid-size autotuner to increment in
// steps of 2 * the number of partitioned dimensions, so that blocks
// divide equally among the directions/dimensions
for (int d = 0; d < in.Ndim(); d++) nDimComms += commDim[d];
return 2 * nDimComms;
return TunableVectorYZ::gridStep();
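// Illustrative: with nDimComms = 3 the candidate grid sizes are 6, 12, 18, ...,
// so every direction/dimension always receives the same whole number of blocks.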
bool tuneAuxDim() const { return true; } // Do tune the aux dimensions.
unsigned int minThreads() const { return threads; }
strcpy(aux, "policy_kernel,");
strcat(aux, in.AuxString());
for (int i = 0; i < 4; i++) comm[i] = (commDim[i] ? '1' : '0');
strcat(aux, ",comm=");
strcat(aux, comm_dim_topology_string());
if (in.PCType() == QUDA_5D_PC) { strcat(aux, ",5D_pc"); }
if (dagger && in.Nspin() == 4) { strcat(aux, ",dagger"); }
if (getKernelPackT()) { strcat(aux, ",kernelPackT"); }
case 1: strcat(aux, ",nFace=1"); break;
case 3: strcat(aux, ",nFace=3"); break;
default: errorQuda("Number of faces not supported");
twist = ((b != 0.0) ? (c != 0.0 ? 2 : 1) : 0);
if (twist && a == 0.0) errorQuda("Twisted packing requires non-zero scale factor a");
if (twist) strcat(aux, twist == 2 ? ",twist-doublet" : ",twist-singlet");
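// twist encodes the twisted-mass term applied while packing: 0 = none,
// 1 = singlet (b != 0, c == 0), 2 = doublet (b != 0, c != 0), matching the
// aux labels appended above.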
// label the locations we are packing to
// location label is nonp2p-p2p
switch ((int)location) {
case Device | Remote: strcat(aux, ",device-remote"); break;
case Host | Remote: strcat(aux, ",host-remote"); break;
case Device: strcat(aux, ",device-device"); break;
case Host: strcat(aux, comm_peer2peer_enabled_global() ? ",host-device" : ",host-host"); break;
case Shmem: strcat(aux, ",shmem"); break;
default: errorQuda("Unknown pack target location %d\n", location);
Pack(void *ghost[], const ColorSpinorField &in, MemoryLocation location, int nFace, bool dagger, int parity, double a,
double b, double c, int shmem) :
TunableVectorYZ((in.Ndim() == 5 ? in.X(4) : 1), in.SiteSubset()),
nParity(in.SiteSubset()),
// compute number of threads - really number of active work items we have to do
for (int i = 0; i < 4; i++) {
if (!commDim[i]) continue;
if (i == 3 && !getKernelPackT()) continue;
threads += 2 * nFace * in.getDslashConstant().ghostFaceCB[i]; // 2 for forwards and backwards faces
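// Illustrative: for a Wilson-type operator (nFace = 1) on a 16^3 x 32 local
// volume with only the t dimension partitioned, ghostFaceCB[3] = 16*16*16/2
// = 2048 checkerboarded face sites, so threads = 2 * 1 * 2048 = 4096.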
template <typename T, typename Arg>
inline void launch(T *f, const TuneParam &tp, Arg &arg, const qudaStream_t &stream)
qudaLaunchKernel(f, tp, stream, arg);
void apply(const qudaStream_t &stream)
TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
// enable max shared memory mode on GPUs that support it
if (deviceProp.major >= 7) tp.set_max_shared_bytes = true;
if (in.Nspin() == 4) {
using Arg = PackArg<Float, nColor, 4, spin_project>;
Arg arg(ghost, in, nFace, dagger, parity, threads, a, b, c, shmem);
arg.counter = dslash::get_shmem_sync_counter();
arg.swizzle = tp.aux.x;
arg.sites_per_block = (arg.threads + tp.grid.x - 1) / tp.grid.x;
arg.blocks_per_dir = tp.grid.x / (2 * arg.active_dims); // set number of blocks per direction
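// Illustrative: a tuned grid of tp.grid.x = 16 with two active (partitioned)
// dimensions gives blocks_per_dir = 16 / (2 * 2) = 4 blocks per face, while
// sites_per_block is simply the ceiling of threads / grid.x.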
if (in.PCType() == QUDA_4D_PC) {
case 0: launch(packKernel<true, 0, QUDA_4D_PC, Arg>, tp, arg, stream); break;
case 1: launch(packKernel<true, 1, QUDA_4D_PC, Arg>, tp, arg, stream); break;
case 2: launch(packKernel<true, 2, QUDA_4D_PC, Arg>, tp, arg, stream); break;
case 0: launch(packKernel<false, 0, QUDA_4D_PC, Arg>, tp, arg, stream); break;
default: errorQuda("Twisted packing only for dagger");
} else if (arg.pc_type == QUDA_5D_PC) {
if (arg.twist) errorQuda("Twist packing not defined");
launch(packKernel<true, 0, QUDA_5D_PC, Arg>, tp, arg, stream);
launch(packKernel<false, 0, QUDA_5D_PC, Arg>, tp, arg, stream);
errorQuda("Unexpected preconditioning type %d", in.PCType());
if (in.PCType() == QUDA_4D_PC) {
launch((location & Host || location & Shmem) ? packShmemKernel<true, 0, QUDA_4D_PC, Arg> :
packKernel<true, 0, QUDA_4D_PC, Arg>,
launch((location & Host || location & Shmem) ? packShmemKernel<true, 1, QUDA_4D_PC, Arg> :
packKernel<true, 1, QUDA_4D_PC, Arg>,
launch((location & Host || location & Shmem) ? packShmemKernel<true, 2, QUDA_4D_PC, Arg> :
packKernel<true, 2, QUDA_4D_PC, Arg>,
launch((location & Host || location & Shmem) ? packShmemKernel<false, 0, QUDA_4D_PC, Arg> :
packKernel<false, 0, QUDA_4D_PC, Arg>,
default: errorQuda("Twisted packing only for dagger");
} else if (arg.pc_type == QUDA_5D_PC) {
if (arg.twist) errorQuda("Twist packing not defined");
launch(packKernel<true, 0, QUDA_5D_PC, Arg>, tp, arg, stream);
launch(packKernel<false, 0, QUDA_5D_PC, Arg>, tp, arg, stream);
} else if (in.Nspin() == 1) {
using Arg = PackArg<Float, nColor, 1, false>;
Arg arg(ghost, in, nFace, dagger, parity, threads, a, b, c, shmem);
arg.counter = dslash::get_shmem_sync_counter();
arg.swizzle = tp.aux.x;
arg.sites_per_block = (arg.threads + tp.grid.x - 1) / tp.grid.x;
arg.blocks_per_dir = tp.grid.x / (2 * arg.active_dims); // set number of blocks per direction
launch(packStaggeredKernel<Arg>, tp, arg, stream);
launch((location & Host || location & Shmem) ? packStaggeredShmemKernel<Arg> : packStaggeredKernel<Arg>, tp,
errorQuda("Unsupported nSpin = %d\n", in.Nspin());
bool tuneSharedBytes() const { return false; }
// not used at present, but if tuneSharedBytes is enabled then
// this allows tuning up the full dynamic shared memory if needed
unsigned int maxSharedBytesPerBlock() const { return maxDynamicSharedBytesPerBlock(); }
void initTuneParam(TuneParam &param) const
TunableVectorYZ::initTuneParam(param);
// if doing a zero-copy policy then ensure that each thread block
// runs exclusively on a given SM - this is to ensure quality of
// service for the packing kernel when running concurrently.
if (location & Host) param.shared_bytes = maxDynamicSharedBytesPerBlock() / 2 + 1;
if (location & Host) param.grid.x = minGridSize();
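// Requesting just over half of the maximum dynamic shared memory means at
// most one block can be resident per SM, which is what enforces the
// exclusivity described above; pinning grid.x to minGridSize() keeps exactly
// one block per direction/dimension for the zero-copy path.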
void defaultTuneParam(TuneParam &param) const
TunableVectorYZ::defaultTuneParam(param);
// if doing a zero-copy policy then ensure that each thread block
// runs exclusively on a given SM - this is to ensure quality of
// service for the packing kernel when running concurrently.
if (location & Host) param.shared_bytes = maxDynamicSharedBytesPerBlock() / 2 + 1;
if (location & Host) param.grid.x = minGridSize();
TuneKey tuneKey() const { return TuneKey(in.VolString(), typeid(*this).name(), aux); }
int tuningIter() const { return 3; }
long long flops() const
// unless we are spin projecting (nSpin = 4), there are no flops to do
return in.Nspin() == 4 ? 2 * in.Nspin() / 2 * nColor * nParity * in.getDslashConstant().Ls * threads : 0;
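// Illustrative accounting: spin projection forms (1 +/- gamma_mu) psi, i.e.
// one complex add (2 real flops) for each of the Nspin/2 * nColor projected
// components, hence the 2 * Nspin/2 * nColor factor per packed site.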
long long bytes() const
size_t precision = sizeof(Float);
size_t faceBytes = 2 * ((in.Nspin() == 4 ? in.Nspin() / 2 : in.Nspin()) + in.Nspin()) * nColor * precision;
if (precision == QUDA_HALF_PRECISION || precision == QUDA_QUARTER_PRECISION)
faceBytes += 2 * sizeof(float); // 2 is from input and output
return faceBytes * nParity * in.getDslashConstant().Ls * threads;
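// The leading factor of 2 counts real and imaginary parts: each site reads a
// full Nspin * nColor spinor and writes the projected Nspin/2 * nColor half
// spinor (no projection for nSpin != 4). The half/quarter test works because
// QUDA's precision enum values coincide with sizeof(Float) for those types.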
template <typename Float, int nColor> struct GhostPack {
GhostPack(const ColorSpinorField &in, void *ghost[], MemoryLocation location, int nFace, bool dagger, int parity,
bool spin_project, double a, double b, double c, int shmem, const qudaStream_t &stream)
Pack<Float, nColor, true> pack(ghost, in, location, nFace, dagger, parity, a, b, c, shmem);
Pack<Float, nColor, false> pack(ghost, in, location, nFace, dagger, parity, a, b, c, shmem);
// Pack the ghost for the Dslash operator
void PackGhost(void *ghost[2 * QUDA_MAX_DIM], const ColorSpinorField &in, MemoryLocation location, int nFace,
bool dagger, int parity, bool spin_project, double a, double b, double c, int shmem,
const qudaStream_t &stream)
for (int d = 0; d < 4; d++) {
if (!commDim[d]) continue;
if (d != 3 || getKernelPackT()) nDimPack++;
if (!nDimPack) return; // if zero then we have nothing to pack
instantiate<GhostPack>(in, ghost, location, nFace, dagger, parity, spin_project, a, b, c, shmem, stream);