QUDA  1.0.0
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
comm_mpi.cpp
Go to the documentation of this file.
1 #include <cstdio>
2 #include <cstdlib>
3 #include <cstring>
4 #include <algorithm>
5 #include <numeric>
6 #include <mpi.h>
7 #include <csignal>
8 #include <quda_internal.h>
9 #include <comm_quda.h>
10 #include <mpi_comm_handle.h>
11 
/**
 * Check the status returned by an MPI call; on failure, translate the
 * status code into a human-readable string and abort via errorQuda.
 *
 * Fix: MPI_Error_string may write up to MPI_MAX_ERROR_STRING characters
 * (including the terminator) into its buffer, so the buffer must be at
 * least that large; the previous fixed 128 bytes risked an overflow.
 */
#define MPI_CHECK(mpi_call) do {                   \
  int status = mpi_call;                           \
  if (status != MPI_SUCCESS) {                     \
    char err_string[MPI_MAX_ERROR_STRING];         \
    int err_len;                                   \
    MPI_Error_string(status, err_string, &err_len);\
    err_string[MPI_MAX_ERROR_STRING-1] = '\0';     \
    errorQuda("(MPI) %s", err_string);             \
  }                                                \
} while (0)
22 
23 
/**
 * Message handle for persistent point-to-point communication: wraps a
 * persistent MPI request plus an optional user-created strided datatype.
 */
struct MsgHandle_s {
  /** Persistent request created by MPI_Send_init / MPI_Recv_init */
  MPI_Request request;

  /** Strided datatype created by MPI_Type_vector; only valid when custom == true */
  MPI_Datatype datatype;

  /** Whether a custom datatype was created (and must be freed on comm_free) */
  bool custom;
};
43 
static int rank = -1;  // rank of this process in MPI_COMM_HANDLE; set in comm_init()
static int size = -1;  // total number of ranks in MPI_COMM_HANDLE; set in comm_init()
46 
/**
 * Gather every rank's hostname into hostname_recv_buf.
 * @param hostname_recv_buf Caller-allocated buffer of 128 * comm_size() chars;
 *        the i-th 128-byte record holds rank i's hostname.
 */
void comm_gather_hostname(char *hostname_recv_buf) {
  // gather fixed-width (128-byte) hostname records from all ranks
  char *hostname = comm_hostname();
  MPI_CHECK(MPI_Allgather(hostname, 128, MPI_CHAR, hostname_recv_buf, 128, MPI_CHAR, MPI_COMM_HANDLE));
}
52 
/**
 * Gather every rank's GPU device id into gpuid_recv_buf.
 * @param gpuid_recv_buf Caller-allocated array of comm_size() ints;
 *        entry i holds rank i's GPU id.
 */
void comm_gather_gpuid(int *gpuid_recv_buf) {
  int gpuid = comm_gpuid();
  MPI_CHECK(MPI_Allgather(&gpuid, 1, MPI_INT, gpuid_recv_buf, 1, MPI_INT, MPI_COMM_HANDLE));
}
57 
58 void comm_init(int ndim, const int *dims, QudaCommsMap rank_from_coords, void *map_data)
59 {
60  int initialized;
61  MPI_CHECK( MPI_Initialized(&initialized) );
62 
63  if (!initialized) {
64  errorQuda("MPI has not been initialized");
65  }
66 
67  MPI_CHECK(MPI_Comm_rank(MPI_COMM_HANDLE, &rank));
68  MPI_CHECK(MPI_Comm_size(MPI_COMM_HANDLE, &size));
69 
70  int grid_size = 1;
71  for (int i = 0; i < ndim; i++) {
72  grid_size *= dims[i];
73  }
74  if (grid_size != size) {
75  errorQuda("Communication grid size declared via initCommsGridQuda() does not match"
76  " total number of MPI ranks (%d != %d)", grid_size, size);
77  }
78 
79  comm_init_common(ndim, dims, rank_from_coords, map_data);
80 }
81 
/** Return this process's MPI rank (cached by comm_init(); -1 before then). */
int comm_rank(void)
{
  return rank;
}
86 
87 
/** Return the total number of MPI ranks (cached by comm_init(); -1 before then). */
int comm_size(void)
{
  return size;
}
92 
93 
94 static const int max_displacement = 4;
95 
96 static void check_displacement(const int displacement[], int ndim) {
97  for (int i=0; i<ndim; i++) {
98  if (abs(displacement[i]) > max_displacement){
99  errorQuda("Requested displacement[%d] = %d is greater than maximum allowed", i, displacement[i]);
100  }
101  }
102 }
103 
107 MsgHandle *comm_declare_send_displaced(void *buffer, const int displacement[], size_t nbytes)
108 {
110  int ndim = comm_ndim(topo);
111  check_displacement(displacement, ndim);
112 
113  int rank = comm_rank_displaced(topo, displacement);
114 
115  int tag = 0;
116  for (int i=ndim-1; i>=0; i--) tag = tag * 4 * max_displacement + displacement[i] + max_displacement;
117  tag = tag >= 0 ? tag : 2*pow(4*max_displacement,ndim) + tag;
118 
119  MsgHandle *mh = (MsgHandle *)safe_malloc(sizeof(MsgHandle));
120  MPI_CHECK(MPI_Send_init(buffer, nbytes, MPI_BYTE, rank, tag, MPI_COMM_HANDLE, &(mh->request)));
121  mh->custom = false;
122 
123  return mh;
124 }
125 
126 
130 MsgHandle *comm_declare_receive_displaced(void *buffer, const int displacement[], size_t nbytes)
131 {
133  int ndim = comm_ndim(topo);
134  check_displacement(displacement,ndim);
135 
136  int rank = comm_rank_displaced(topo, displacement);
137 
138  int tag = 0;
139  for (int i=ndim-1; i>=0; i--) tag = tag * 4 * max_displacement - displacement[i] + max_displacement;
140  tag = tag >= 0 ? tag : 2*pow(4*max_displacement,ndim) + tag;
141 
142  MsgHandle *mh = (MsgHandle *)safe_malloc(sizeof(MsgHandle));
143  MPI_CHECK(MPI_Recv_init(buffer, nbytes, MPI_BYTE, rank, tag, MPI_COMM_HANDLE, &(mh->request)));
144  mh->custom = false;
145 
146  return mh;
147 }
148 
149 
153 MsgHandle *comm_declare_strided_send_displaced(void *buffer, const int displacement[],
154  size_t blksize, int nblocks, size_t stride)
155 {
157  int ndim = comm_ndim(topo);
158  check_displacement(displacement, ndim);
159 
160  int rank = comm_rank_displaced(topo, displacement);
161 
162  int tag = 0;
163  for (int i=ndim-1; i>=0; i--) tag = tag * 4 * max_displacement + displacement[i] + max_displacement;
164  tag = tag >= 0 ? tag : 2*pow(4*max_displacement,ndim) + tag;
165 
166  MsgHandle *mh = (MsgHandle *)safe_malloc(sizeof(MsgHandle));
167 
168  // create a new strided MPI type
169  MPI_CHECK( MPI_Type_vector(nblocks, blksize, stride, MPI_BYTE, &(mh->datatype)) );
170  MPI_CHECK( MPI_Type_commit(&(mh->datatype)) );
171  mh->custom = true;
172 
173  MPI_CHECK(MPI_Send_init(buffer, 1, mh->datatype, rank, tag, MPI_COMM_HANDLE, &(mh->request)));
174 
175  return mh;
176 }
177 
178 
182 MsgHandle *comm_declare_strided_receive_displaced(void *buffer, const int displacement[],
183  size_t blksize, int nblocks, size_t stride)
184 {
186  int ndim = comm_ndim(topo);
187  check_displacement(displacement,ndim);
188 
189  int rank = comm_rank_displaced(topo, displacement);
190 
191  int tag = 0;
192  for (int i=ndim-1; i>=0; i--) tag = tag * 4 * max_displacement - displacement[i] + max_displacement;
193  tag = tag >= 0 ? tag : 2*pow(4*max_displacement,ndim) + tag;
194 
195  MsgHandle *mh = (MsgHandle *)safe_malloc(sizeof(MsgHandle));
196 
197  // create a new strided MPI type
198  MPI_CHECK( MPI_Type_vector(nblocks, blksize, stride, MPI_BYTE, &(mh->datatype)) );
199  MPI_CHECK( MPI_Type_commit(&(mh->datatype)) );
200  mh->custom = true;
201 
202  MPI_CHECK(MPI_Recv_init(buffer, 1, mh->datatype, rank, tag, MPI_COMM_HANDLE, &(mh->request)));
203 
204  return mh;
205 }
206 
208 {
209  MPI_CHECK(MPI_Request_free(&(mh->request)));
210  if (mh->custom) MPI_CHECK(MPI_Type_free(&(mh->datatype)));
211  host_free(mh);
212  mh = nullptr;
213 }
214 
215 
217 {
218  MPI_CHECK( MPI_Start(&(mh->request)) );
219 }
220 
221 
223 {
224  MPI_CHECK( MPI_Wait(&(mh->request), MPI_STATUS_IGNORE) );
225 }
226 
227 
229 {
230  int query;
231  MPI_CHECK( MPI_Test(&(mh->request), &query, MPI_STATUS_IGNORE) );
232 
233  return query;
234 }
235 
/**
 * Reduce an array in a deterministic order: sort the values ascending so
 * the summation order is independent of how contributions were gathered,
 * then accumulate.  The input array is modified (sorted) in place.
 *
 * Fix: the accumulator was initialized with the double literal 0.0, which
 * forces double-precision accumulation and a narrowing conversion on
 * return for any non-double T; accumulate in T instead.
 *
 * @param array  Values to reduce (sorted in place)
 * @param n      Number of values
 * @return The sum of the n values, accumulated in type T
 */
template <typename T> T deterministic_reduce(T *array, int n)
{
  std::sort(array, array + n); // ascending order => deterministic summation order
  return std::accumulate(array, array + n, T(0));
}
241 
242 void comm_allreduce(double* data)
243 {
244  if (!comm_deterministic_reduce()) {
245  double recvbuf;
246  MPI_CHECK(MPI_Allreduce(data, &recvbuf, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_HANDLE));
247  *data = recvbuf;
248  } else {
249  const size_t n = comm_size();
250  double *recv_buf = (double *)safe_malloc(n * sizeof(double));
251  MPI_CHECK(MPI_Allgather(data, 1, MPI_DOUBLE, recv_buf, 1, MPI_DOUBLE, MPI_COMM_HANDLE));
252  *data = deterministic_reduce(recv_buf, n);
253  host_free(recv_buf);
254  }
255 }
256 
257 
258 void comm_allreduce_max(double* data)
259 {
260  double recvbuf;
261  MPI_CHECK(MPI_Allreduce(data, &recvbuf, 1, MPI_DOUBLE, MPI_MAX, MPI_COMM_HANDLE));
262  *data = recvbuf;
263 }
264 
265 void comm_allreduce_min(double* data)
266 {
267  double recvbuf;
268  MPI_CHECK(MPI_Allreduce(data, &recvbuf, 1, MPI_DOUBLE, MPI_MIN, MPI_COMM_HANDLE));
269  *data = recvbuf;
270 }
271 
272 void comm_allreduce_array(double* data, size_t size)
273 {
274  if (!comm_deterministic_reduce()) {
275  double *recvbuf = new double[size];
276  MPI_CHECK(MPI_Allreduce(data, recvbuf, size, MPI_DOUBLE, MPI_SUM, MPI_COMM_HANDLE));
277  memcpy(data, recvbuf, size * sizeof(double));
278  delete[] recvbuf;
279  } else {
280  size_t n = comm_size();
281  double *recv_buf = new double[size * n];
282  MPI_CHECK(MPI_Allgather(data, size, MPI_DOUBLE, recv_buf, size, MPI_DOUBLE, MPI_COMM_HANDLE));
283 
284  double *recv_trans = new double[size * n];
285  for (size_t i = 0; i < n; i++) {
286  for (size_t j = 0; j < size; j++) { recv_trans[j * n + i] = recv_buf[i * size + j]; }
287  }
288 
289  for (size_t i = 0; i < size; i++) { data[i] = deterministic_reduce(recv_trans + i * n, n); }
290 
291  delete[] recv_buf;
292  delete[] recv_trans;
293  }
294 }
295 
296 void comm_allreduce_max_array(double* data, size_t size)
297 {
298  double *recvbuf = new double[size];
299  MPI_CHECK(MPI_Allreduce(data, recvbuf, size, MPI_DOUBLE, MPI_MAX, MPI_COMM_HANDLE));
300  memcpy(data, recvbuf, size*sizeof(double));
301  delete []recvbuf;
302 }
303 
304 void comm_allreduce_int(int* data)
305 {
306  int recvbuf;
307  MPI_CHECK(MPI_Allreduce(data, &recvbuf, 1, MPI_INT, MPI_SUM, MPI_COMM_HANDLE));
308  *data = recvbuf;
309 }
310 
311 void comm_allreduce_xor(uint64_t *data)
312 {
313  if (sizeof(uint64_t) != sizeof(unsigned long)) errorQuda("unsigned long is not 64-bit");
314  uint64_t recvbuf;
315  MPI_CHECK(MPI_Allreduce(data, &recvbuf, 1, MPI_UNSIGNED_LONG, MPI_BXOR, MPI_COMM_HANDLE));
316  *data = recvbuf;
317 }
318 
319 
/**
 * Broadcast nbytes of data from rank 0 to all ranks, in place.
 * NOTE(review): nbytes is truncated to int for the MPI count, so a
 * message of 2 GiB or more would overflow — confirm callers never
 * broadcast that much in one call.
 */
void comm_broadcast(void *data, size_t nbytes)
{
  MPI_CHECK(MPI_Bcast(data, (int)nbytes, MPI_BYTE, 0, MPI_COMM_HANDLE));
}
325 
326 void comm_barrier(void) { MPI_CHECK(MPI_Barrier(MPI_COMM_HANDLE)); }
327 
/**
 * Abort all processes in MPI_COMM_HANDLE with the given exit status.
 * In HOST_DEBUG builds a SIGINT is raised first so an attached debugger
 * can catch the failure before MPI tears the job down.
 */
void comm_abort(int status)
{
#ifdef HOST_DEBUG
  raise(SIGINT);
#endif
  MPI_Abort(MPI_COMM_HANDLE, status);
}
void comm_free(MsgHandle *&mh)
Definition: comm_mpi.cpp:207
void comm_allreduce(double *data)
Definition: comm_mpi.cpp:242
MsgHandle * comm_declare_send_displaced(void *buffer, const int displacement[], size_t nbytes)
Definition: comm_mpi.cpp:107
void comm_gather_hostname(char *hostname_recv_buf)
Gather all hostnames.
Definition: comm_mpi.cpp:47
_EXTERN_C_ int MPI_Allgather(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, MPI_Comm comm)
Definition: nvtx_pmpi.c:340
MPI_Request request
Definition: comm_mpi.cpp:29
_EXTERN_C_ int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
Definition: nvtx_pmpi.c:110
int comm_query(MsgHandle *mh)
Definition: comm_mpi.cpp:228
void comm_wait(MsgHandle *mh)
Definition: comm_mpi.cpp:222
void comm_allreduce_array(double *data, size_t size)
Definition: comm_mpi.cpp:272
void comm_allreduce_max(double *data)
Definition: comm_mpi.cpp:258
_EXTERN_C_ int MPI_Wait(MPI_Request *request, MPI_Status *status)
Definition: nvtx_pmpi.c:156
#define errorQuda(...)
Definition: util_quda.h:121
void comm_gather_gpuid(int *gpuid_recv_buf)
Gather all GPU ids.
Definition: comm_mpi.cpp:53
void comm_allreduce_int(int *data)
Definition: comm_mpi.cpp:304
#define host_free(ptr)
Definition: malloc_quda.h:71
static const int max_displacement
Definition: comm_mpi.cpp:94
#define MPI_CHECK(mpi_call)
Definition: comm_mpi.cpp:12
static int rank
Definition: comm_mpi.cpp:44
Topology * comm_default_topology(void)
MsgHandle * comm_declare_strided_receive_displaced(void *buffer, const int displacement[], size_t blksize, int nblocks, size_t stride)
Definition: comm_mpi.cpp:182
void comm_allreduce_max_array(double *data, size_t size)
Definition: comm_mpi.cpp:296
int comm_gpuid(void)
static int size
Definition: comm_mpi.cpp:45
static int ndim
Definition: layout_hyper.c:53
_EXTERN_C_ int MPI_Barrier(MPI_Comm comm)
Definition: nvtx_pmpi.c:501
char * comm_hostname(void)
Definition: comm_common.cpp:58
int comm_rank(void)
Definition: comm_mpi.cpp:82
void comm_start(MsgHandle *mh)
Definition: comm_mpi.cpp:216
void comm_barrier(void)
Definition: comm_mpi.cpp:326
void comm_allreduce_min(double *data)
Definition: comm_mpi.cpp:265
_EXTERN_C_ int MPI_Test(MPI_Request *request, int *flag, MPI_Status *status)
Definition: nvtx_pmpi.c:547
static bool initialized
Profiler for initQuda.
void comm_init(int ndim, const int *dims, QudaCommsMap rank_from_coords, void *map_data)
Initialize the communications, implemented in comm_single.cpp, comm_qmp.cpp, and comm_mpi.cpp.
Definition: comm_mpi.cpp:58
void comm_broadcast(void *data, size_t nbytes)
Definition: comm_mpi.cpp:321
__host__ __device__ ValueType pow(ValueType x, ExponentType e)
Definition: complex_quda.h:111
int comm_rank_displaced(const Topology *topo, const int displacement[])
_EXTERN_C_ int MPI_Send_init(const void *buf, int count, MPI_Datatype datatype, int dest, int tag, MPI_Comm comm, MPI_Request *request)
Definition: nvtx_pmpi.c:570
int(* QudaCommsMap)(const int *coords, void *fdata)
Definition: comm_quda.h:12
MPI_Datatype datatype
Definition: comm_mpi.cpp:35
void comm_init_common(int ndim, const int *dims, QudaCommsMap rank_from_coords, void *map_data)
Initialize the communications common to all communications abstractions.
_EXTERN_C_ int MPI_Recv_init(void *buf, int count, MPI_Datatype datatype, int source, int tag, MPI_Comm comm, MPI_Request *request)
Definition: nvtx_pmpi.c:593
#define safe_malloc(size)
Definition: malloc_quda.h:66
bool comm_deterministic_reduce()
static int dims[4]
Definition: face_gauge.cpp:41
_EXTERN_C_ int MPI_Start(MPI_Request *request)
Definition: nvtx_pmpi.c:524
bool custom
Definition: comm_mpi.cpp:41
MsgHandle * comm_declare_receive_displaced(void *buffer, const int displacement[], size_t nbytes)
Definition: comm_mpi.cpp:130
static int gpuid
int comm_ndim(const Topology *topo)
void comm_allreduce_xor(uint64_t *data)
Definition: comm_mpi.cpp:311
__host__ __device__ ValueType abs(ValueType x)
Definition: complex_quda.h:125
static void check_displacement(const int displacement[], int ndim)
Definition: comm_mpi.cpp:96
T deterministic_reduce(T *array, int n)
Definition: comm_mpi.cpp:236
MsgHandle * comm_declare_strided_send_displaced(void *buffer, const int displacement[], size_t blksize, int nblocks, size_t stride)
Definition: comm_mpi.cpp:153
void comm_abort(int status)
Definition: comm_mpi.cpp:328
_EXTERN_C_ int MPI_Bcast(void *buffer, int count, MPI_Datatype datatype, int root, MPI_Comm comm)
Definition: nvtx_pmpi.c:455
int comm_size(void)
Definition: comm_mpi.cpp:88