QUDA v1.1.0
A library for QCD on GPUs
communicator_mpi.cpp
#include <communicator_quda.h>

#define MPI_CHECK(mpi_call)                             \
  do {                                                  \
    int status = mpi_call;                              \
    if (status != MPI_SUCCESS) {                        \
      char err_string[128];                             \
      int err_len;                                      \
      MPI_Error_string(status, err_string, &err_len);   \
      err_string[127] = '\0';                           \
      errorQuda("(MPI) %s", err_string);                \
    }                                                   \
  } while (0)

struct MsgHandle_s {
  /** The persistent MPI request handle */
  MPI_Request request;

  /** Custom MPI datatype used for strided sends and receives */
  MPI_Datatype datatype;

  /** Whether a custom datatype was created, and hence must be freed in comm_free() */
  bool custom;
};

Communicator::Communicator(int nDim, const int *commDims, QudaCommsMap rank_from_coords, void *map_data,
                           bool user_set_comm_handle_, void *user_comm)
{
  user_set_comm_handle = user_set_comm_handle_;

  int initialized;
  MPI_CHECK(MPI_Initialized(&initialized));

  if (!initialized) { assert(false); }

  if (user_set_comm_handle) {
    // use the communicator supplied by the user
    MPI_COMM_HANDLE = *((MPI_Comm *)user_comm);
  } else {
    // otherwise take a private duplicate of MPI_COMM_WORLD
    MPI_Comm_dup(MPI_COMM_WORLD, &MPI_COMM_HANDLE);
  }

  comm_init(nDim, commDims, rank_from_coords, map_data);
}
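
/*
  Illustrative sketch (not part of the original source): QudaCommsMap is a user
  callback that maps a logical grid coordinate to an MPI rank.  A hypothetical
  lexicographic mapping for a 4-dimensional grid, with the grid dimensions passed
  through map_data, might look like:

    int lex_rank_from_coords(const int *coords, void *fdata)
    {
      const int *dims = (const int *)fdata; // hypothetical: dims supplied via map_data
      int rank = coords[0];
      for (int d = 1; d < 4; d++) rank = rank * dims[d] + coords[d];
      return rank;
    }
*/
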
Communicator::Communicator(Communicator &other, const int *comm_split)
{
  user_set_comm_handle = false;

  constexpr int nDim = 4;

  quda::CommKey comm_dims_split;

  quda::CommKey comm_key_split;
  quda::CommKey comm_color_split;

  for (int d = 0; d < nDim; d++) {
    assert(other.comm_dim(d) % comm_split[d] == 0);
    comm_dims_split[d] = other.comm_dim(d) / comm_split[d];
    comm_key_split[d] = other.comm_coord(d) % comm_dims_split[d];
    comm_color_split[d] = other.comm_coord(d) / comm_dims_split[d];
  }

  int key = index(nDim, comm_dims_split.data(), comm_key_split.data());
  int color = index(nDim, comm_split, comm_color_split.data());

  MPI_CHECK(MPI_Comm_split(other.MPI_COMM_HANDLE, color, key, &MPI_COMM_HANDLE));
  int my_rank_;
  MPI_CHECK(MPI_Comm_rank(MPI_COMM_HANDLE, &my_rank_));

  QudaCommsMap func = lex_rank_from_coords_dim_t;
  comm_init(nDim, comm_dims_split.data(), func, comm_dims_split.data());
}
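
/*
  Illustrative note (not part of the original source): the splitting constructor
  above partitions the parent grid into comm_split[d] pieces per dimension.
  Suppose the parent grid is 4x4x4x8 and comm_split = {1, 1, 1, 2}; every child
  grid is then 4x4x4x4.  A rank at parent coordinate (x, y, z, t) keeps the key
  coordinate (x, y, z, t % 4) inside its child grid, while the colour coordinate
  (0, 0, 0, t / 4) selects which child communicator it joins.  Assuming index()
  is a lexicographic index over the given dimensions, the values handed to
  MPI_Comm_split are, schematically:

    int dims_split[4]  = {4, 4, 4, 4};     // parent dims divided by comm_split
    int key_coord[4]   = {x, y, z, t % 4}; // coordinate within the child grid
    int color_coord[4] = {0, 0, 0, t / 4}; // which child grid this rank belongs to
    int key   = index(4, dims_split, key_coord);
    int color = index(4, comm_split, color_coord);
*/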

Communicator::~Communicator()
{
  comm_finalize();
  if (!user_set_comm_handle) { MPI_Comm_free(&MPI_COMM_HANDLE); }
}

void Communicator::comm_gather_hostname(char *hostname_recv_buf)
{
  // gather the hostname of every rank (later used to decide which GPU each rank uses)
  char *hostname = comm_hostname();
  MPI_CHECK(MPI_Allgather(hostname, 128, MPI_CHAR, hostname_recv_buf, 128, MPI_CHAR, MPI_COMM_HANDLE));
}

void Communicator::comm_gather_gpuid(int *gpuid_recv_buf)
{
  // gather the GPU id of every rank
  int gpuid = comm_gpuid();
  MPI_CHECK(MPI_Allgather(&gpuid, 1, MPI_INT, gpuid_recv_buf, 1, MPI_INT, MPI_COMM_HANDLE));
}

void Communicator::comm_init(int ndim, const int *dims, QudaCommsMap rank_from_coords, void *map_data)
{
  int initialized;
  MPI_CHECK(MPI_Initialized(&initialized));

  if (!initialized) { errorQuda("MPI has not been initialized"); }

  MPI_CHECK(MPI_Comm_rank(MPI_COMM_HANDLE, &rank));
  MPI_CHECK(MPI_Comm_size(MPI_COMM_HANDLE, &size));

  int grid_size = 1;
  for (int i = 0; i < ndim; i++) { grid_size *= dims[i]; }
  if (grid_size != size) {
    errorQuda("Communication grid size declared via initCommsGridQuda() does not match"
              " total number of MPI ranks (%d != %d)",
              grid_size, size);
  }

  comm_init_common(ndim, dims, rank_from_coords, map_data);
}

int Communicator::comm_rank(void) { return rank; }

int Communicator::comm_size(void) { return size; }

/**
 * Declare a message handle for sending nbytes to the given rank with the given tag
 */
MsgHandle *Communicator::comm_declare_send_rank(void *buffer, int rank, int tag, size_t nbytes)
{
  MsgHandle *mh = (MsgHandle *)safe_malloc(sizeof(MsgHandle));
  MPI_CHECK(MPI_Send_init(buffer, nbytes, MPI_BYTE, rank, tag, MPI_COMM_HANDLE, &(mh->request)));
  mh->custom = false;

  return mh;
}

/**
 * Declare a message handle for receiving nbytes from the given rank with the given tag
 */
MsgHandle *Communicator::comm_declare_recv_rank(void *buffer, int rank, int tag, size_t nbytes)
{
  MsgHandle *mh = (MsgHandle *)safe_malloc(sizeof(MsgHandle));
  MPI_CHECK(MPI_Recv_init(buffer, nbytes, MPI_BYTE, rank, tag, MPI_COMM_HANDLE, &(mh->request)));
  mh->custom = false;

  return mh;
}

/**
 * Declare a message handle for sending to the rank at the given displacement on the topology
 */
MsgHandle *Communicator::comm_declare_send_displaced(void *buffer, const int displacement[], size_t nbytes)
{
  Topology *topo = comm_default_topology();
  int ndim = comm_ndim(topo);
  check_displacement(displacement, ndim);

  int rank = comm_rank_displaced(topo, displacement);

  int tag = 0;
  for (int i = ndim - 1; i >= 0; i--) tag = tag * 4 * max_displacement + displacement[i] + max_displacement;
  tag = tag >= 0 ? tag : 2 * pow(4 * max_displacement, ndim) + tag;

  MsgHandle *mh = (MsgHandle *)safe_malloc(sizeof(MsgHandle));
  MPI_CHECK(MPI_Send_init(buffer, nbytes, MPI_BYTE, rank, tag, MPI_COMM_HANDLE, &(mh->request)));
  mh->custom = false;

  return mh;
}

/**
 * Declare a message handle for receiving from the rank at the given displacement on the topology
 */
MsgHandle *Communicator::comm_declare_receive_displaced(void *buffer, const int displacement[], size_t nbytes)
{
  Topology *topo = comm_default_topology();
  int ndim = comm_ndim(topo);
  check_displacement(displacement, ndim);

  int rank = comm_rank_displaced(topo, displacement);

  int tag = 0;
  for (int i = ndim - 1; i >= 0; i--) tag = tag * 4 * max_displacement - displacement[i] + max_displacement;
  tag = tag >= 0 ? tag : 2 * pow(4 * max_displacement, ndim) + tag;

  MsgHandle *mh = (MsgHandle *)safe_malloc(sizeof(MsgHandle));
  MPI_CHECK(MPI_Recv_init(buffer, nbytes, MPI_BYTE, rank, tag, MPI_COMM_HANDLE, &(mh->request)));
  mh->custom = false;

  return mh;
}
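
/*
  Illustrative note (not part of the original source): the tag above packs the
  displacement vector into a mixed-radix number, digit by digit, with the send
  using +displacement[i] and the receive using -displacement[i].  A send declared
  with displacement (+1, 0, 0, 0) therefore carries the same tag as the receive
  its neighbour declares with displacement (-1, 0, 0, 0), so the two persistent
  requests match.  For example, assuming max_displacement = 4 (hypothetical value)
  and ndim = 4, the per-digit base is 4 * max_displacement = 16:

    send, displacement = {+1, 0, 0, 0}:
      tag = ((((0*16 + 4)*16 + 4)*16 + 4)*16 + (+1 + 4)) = 17477
    matching receive on the neighbour, displacement = {-1, 0, 0, 0}:
      tag = ((((0*16 + 4)*16 + 4)*16 + 4)*16 - (-1) + 4) = 17477
*/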

/**
 * Declare a strided message handle for sending to the rank at the given displacement
 */
MsgHandle *Communicator::comm_declare_strided_send_displaced(void *buffer, const int displacement[], size_t blksize,
                                                             int nblocks, size_t stride)
{
  Topology *topo = comm_default_topology();
  int ndim = comm_ndim(topo);
  check_displacement(displacement, ndim);

  int rank = comm_rank_displaced(topo, displacement);

  int tag = 0;
  for (int i = ndim - 1; i >= 0; i--) tag = tag * 4 * max_displacement + displacement[i] + max_displacement;
  tag = tag >= 0 ? tag : 2 * pow(4 * max_displacement, ndim) + tag;

  MsgHandle *mh = (MsgHandle *)safe_malloc(sizeof(MsgHandle));

  // create a new strided MPI type
  MPI_CHECK(MPI_Type_vector(nblocks, blksize, stride, MPI_BYTE, &(mh->datatype)));
  MPI_CHECK(MPI_Type_commit(&(mh->datatype)));
  mh->custom = true;

  MPI_CHECK(MPI_Send_init(buffer, 1, mh->datatype, rank, tag, MPI_COMM_HANDLE, &(mh->request)));

  return mh;
}

/**
 * Declare a strided message handle for receiving from the rank at the given displacement
 */
MsgHandle *Communicator::comm_declare_strided_receive_displaced(void *buffer, const int displacement[], size_t blksize,
                                                                int nblocks, size_t stride)
{
  Topology *topo = comm_default_topology();
  int ndim = comm_ndim(topo);
  check_displacement(displacement, ndim);

  int rank = comm_rank_displaced(topo, displacement);

  int tag = 0;
  for (int i = ndim - 1; i >= 0; i--) tag = tag * 4 * max_displacement - displacement[i] + max_displacement;
  tag = tag >= 0 ? tag : 2 * pow(4 * max_displacement, ndim) + tag;

  MsgHandle *mh = (MsgHandle *)safe_malloc(sizeof(MsgHandle));

  // create a new strided MPI type
  MPI_CHECK(MPI_Type_vector(nblocks, blksize, stride, MPI_BYTE, &(mh->datatype)));
  MPI_CHECK(MPI_Type_commit(&(mh->datatype)));
  mh->custom = true;

  MPI_CHECK(MPI_Recv_init(buffer, 1, mh->datatype, rank, tag, MPI_COMM_HANDLE, &(mh->request)));

  return mh;
}
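
/*
  Illustrative note (not part of the original source): MPI_Type_vector(nblocks,
  blksize, stride, MPI_BYTE, ...) describes nblocks chunks of blksize bytes whose
  starting addresses are stride bytes apart, so a single persistent send or
  receive of count 1 moves a strided face without an explicit packing step.
  For instance, to exchange every fourth row of a hypothetical buffer with
  256-byte rows:

    size_t blksize = 256;      // bytes per row
    int    nblocks = 4;        // number of rows to transfer
    size_t stride  = 4 * 256;  // byte distance between consecutive transferred rows
    MsgHandle *mh = comm_declare_strided_send_displaced(buf, displacement, blksize, nblocks, stride);
*/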

void Communicator::comm_free(MsgHandle *&mh)
{
  MPI_CHECK(MPI_Request_free(&(mh->request)));
  if (mh->custom) MPI_CHECK(MPI_Type_free(&(mh->datatype)));
  host_free(mh);
  mh = nullptr;
}

void Communicator::comm_start(MsgHandle *mh) { MPI_CHECK(MPI_Start(&(mh->request))); }

void Communicator::comm_wait(MsgHandle *mh) { MPI_CHECK(MPI_Wait(&(mh->request), MPI_STATUS_IGNORE)); }

int Communicator::comm_query(MsgHandle *mh)
{
  int query;
  MPI_CHECK(MPI_Test(&(mh->request), &query, MPI_STATUS_IGNORE));

  return query;
}
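
/*
  Illustrative usage of the persistent handles (not part of the original source),
  assuming a caller with access to a Communicator instance (calls shown
  unqualified) that knows its source/destination ranks and has allocated
  sbuf/rbuf of nbytes each:

    MsgHandle *send = comm_declare_send_rank(sbuf, dest_rank, tag, nbytes);
    MsgHandle *recv = comm_declare_recv_rank(rbuf, src_rank, tag, nbytes);
    for (int iter = 0; iter < niter; iter++) {
      comm_start(recv);
      comm_start(send);
      comm_wait(send);
      comm_wait(recv); // or poll with comm_query() until it returns non-zero
    }
    comm_free(send);
    comm_free(recv);
*/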

void Communicator::comm_allreduce(double *data)
{
  if (!comm_deterministic_reduce()) {
    double recvbuf;
    MPI_CHECK(MPI_Allreduce(data, &recvbuf, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_HANDLE));
    *data = recvbuf;
  } else {
    // gather every rank's value and sum them in a fixed order for reproducibility
    const size_t n = comm_size();
    double *recv_buf = (double *)safe_malloc(n * sizeof(double));
    MPI_CHECK(MPI_Allgather(data, 1, MPI_DOUBLE, recv_buf, 1, MPI_DOUBLE, MPI_COMM_HANDLE));
    *data = deterministic_reduce(recv_buf, n);
    host_free(recv_buf);
  }
}

void Communicator::comm_allreduce_max(double *data)
{
  double recvbuf;
  MPI_CHECK(MPI_Allreduce(data, &recvbuf, 1, MPI_DOUBLE, MPI_MAX, MPI_COMM_HANDLE));
  *data = recvbuf;
}

void Communicator::comm_allreduce_min(double *data)
{
  double recvbuf;
  MPI_CHECK(MPI_Allreduce(data, &recvbuf, 1, MPI_DOUBLE, MPI_MIN, MPI_COMM_HANDLE));
  *data = recvbuf;
}

void Communicator::comm_allreduce_array(double *data, size_t size)
{
  if (!comm_deterministic_reduce()) {
    double *recvbuf = new double[size];
    MPI_CHECK(MPI_Allreduce(data, recvbuf, size, MPI_DOUBLE, MPI_SUM, MPI_COMM_HANDLE));
    memcpy(data, recvbuf, size * sizeof(double));
    delete[] recvbuf;
  } else {
    size_t n = comm_size();
    double *recv_buf = new double[size * n];
    MPI_CHECK(MPI_Allgather(data, size, MPI_DOUBLE, recv_buf, size, MPI_DOUBLE, MPI_COMM_HANDLE));

    double *recv_trans = new double[size * n];
    for (size_t i = 0; i < n; i++) {
      for (size_t j = 0; j < size; j++) { recv_trans[j * n + i] = recv_buf[i * size + j]; }
    }

    for (size_t i = 0; i < size; i++) { data[i] = deterministic_reduce(recv_trans + i * n, n); }

    delete[] recv_buf;
    delete[] recv_trans;
  }
}
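
/*
  Illustrative note (not part of the original source): in the deterministic branch
  above, the Allgather leaves recv_buf rank-major, i.e. recv_buf[i * size + j] is
  element j contributed by rank i.  The transpose regroups it element-major so the
  n = comm_size() contributions to each element sit contiguously and are summed in
  a fixed rank order, making the result independent of MPI's internal reduction
  order.  Schematically, for size = 2 elements and n = 3 ranks:

    recv_buf   = { a0, b0,   a1, b1,   a2, b2 };  // rank-major
    recv_trans = { a0, a1, a2,   b0, b1, b2 };    // element-major
    data[0] = deterministic_reduce(recv_trans + 0, 3); // reduces {a0, a1, a2}
    data[1] = deterministic_reduce(recv_trans + 3, 3); // reduces {b0, b1, b2}
*/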

void Communicator::comm_allreduce_max_array(double *data, size_t size)
{
  double *recvbuf = new double[size];
  MPI_CHECK(MPI_Allreduce(data, recvbuf, size, MPI_DOUBLE, MPI_MAX, MPI_COMM_HANDLE));
  memcpy(data, recvbuf, size * sizeof(double));
  delete[] recvbuf;
}

void Communicator::comm_allreduce_int(int *data)
{
  int recvbuf;
  MPI_CHECK(MPI_Allreduce(data, &recvbuf, 1, MPI_INT, MPI_SUM, MPI_COMM_HANDLE));
  *data = recvbuf;
}

void Communicator::comm_allreduce_xor(uint64_t *data)
{
  if (sizeof(uint64_t) != sizeof(unsigned long)) errorQuda("unsigned long is not 64-bit");
  uint64_t recvbuf;
  MPI_CHECK(MPI_Allreduce(data, &recvbuf, 1, MPI_UNSIGNED_LONG, MPI_BXOR, MPI_COMM_HANDLE));
  *data = recvbuf;
}

/** broadcast from rank 0 */
void Communicator::comm_broadcast(void *data, size_t nbytes)
{
  MPI_CHECK(MPI_Bcast(data, (int)nbytes, MPI_BYTE, 0, MPI_COMM_HANDLE));
}

void Communicator::comm_barrier(void) { MPI_CHECK(MPI_Barrier(MPI_COMM_HANDLE)); }

void Communicator::comm_abort_(int status) { MPI_Abort(MPI_COMM_WORLD, status); }

int Communicator::comm_rank_global()
{
  int rank;
  MPI_CHECK(MPI_Comm_rank(MPI_COMM_WORLD, &rank));
  return rank;
}