#define MPI_CHECK(mpi_call) do {                    \
  int status = mpi_call;                            \
  if (status != MPI_SUCCESS) {                      \
    char err_string[128];                           \
    int err_len;                                    \
    MPI_Error_string(status, err_string, &err_len); \
    err_string[127] = '\0';                         \
    errorQuda("(MPI) %s", err_string);              \
  }                                                 \
} while (0)
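Because the body is wrapped in do { ... } while (0), MPI_CHECK(...) behaves like a single statement and can sit in any control-flow position. A minimal usage sketch (the surrounding helper is hypothetical, not from this file):

void example_sync_point()   // hypothetical helper, purely for illustration
{
  // any return value other than MPI_SUCCESS is reported through errorQuda()
  MPI_CHECK( MPI_Barrier(MPI_COMM_WORLD) );
}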
static int gpuid = -1;
  // inside comm_init():
  int initialized;
  MPI_CHECK( MPI_Initialized(&initialized) );

  if (!initialized) {
    errorQuda("MPI has not been initialized");
  }

  MPI_CHECK( MPI_Comm_rank(MPI_COMM_WORLD, &rank) );
  MPI_CHECK( MPI_Comm_size(MPI_COMM_WORLD, &size) );

  int grid_size = 1;
  for (int i = 0; i < ndim; i++) {
    grid_size *= dims[i];
  }

  if (grid_size != size) {
    errorQuda("Communication grid size declared via initCommsGridQuda() does not match"
              " total number of MPI ranks (%d != %d)", grid_size, size);
  }
  // determine which GPU this rank will use, based on which other ranks share its hostname
  char *hostname = comm_hostname();
  char *hostname_recv_buf = (char *)safe_malloc(128*size);

  MPI_CHECK( MPI_Allgather(hostname, 128, MPI_CHAR, hostname_recv_buf, 128, MPI_CHAR, MPI_COMM_WORLD) );
  gpuid = 0;
  for (int i = 0; i < rank; i++) {
    if (!strncmp(hostname, &hostname_recv_buf[128*i], 128)) {
      gpuid++;   // a lower-ranked process shares this host, so take the next device
    }
  }
  int device_count;
  cudaGetDeviceCount(&device_count);
  if (device_count == 0) {
    errorQuda("No CUDA devices found");
  }

  if (gpuid >= device_count) {
    errorQuda("Too few GPUs available on %s", hostname);
  }
  // comm_declare_send_displaced(): set up a persistent send request for this handle
  MPI_CHECK( MPI_Send_init(buffer, nbytes, MPI_BYTE, rank, tag, MPI_COMM_WORLD, &(mh->request)) );

  // comm_declare_receive_displaced(): set up a persistent receive request for this handle
  MPI_CHECK( MPI_Recv_init(buffer, nbytes, MPI_BYTE, rank, tag, MPI_COMM_WORLD, &(mh->request)) );
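MPI_Send_init and MPI_Recv_init only describe the transfer; nothing moves until the handle is started. A hedged sketch of the handle life cycle built from the declarations listed at the end of this file (the helper name, buffers, displacement vectors, and iteration count are all made up for illustration):

void exchange_t_faces(void *sendbuf, void *recvbuf, size_t nbytes, int niter)   // hypothetical helper
{
  int to_fwd[4]    = {0, 0, 0, +1};   // forward neighbour in the t direction (made-up convention)
  int from_back[4] = {0, 0, 0, -1};   // backward neighbour in the t direction

  MsgHandle *send = comm_declare_send_displaced(sendbuf, to_fwd, nbytes);
  MsgHandle *recv = comm_declare_receive_displaced(recvbuf, from_back, nbytes);

  for (int iter = 0; iter < niter; iter++) {
    comm_start(recv);   // restart the persistent receive
    comm_start(send);   // restart the persistent send
    comm_wait(send);
    comm_wait(recv);    // recvbuf now holds the neighbour's face data
  }

  comm_free(send);
  comm_free(recv);
}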
  // comm_declare_strided_send_displaced(..., size_t blksize, int nblocks, size_t stride):
  // describe the strided send buffer as an MPI derived datatype
  MPI_CHECK( MPI_Type_vector(nblocks, blksize, stride, MPI_BYTE, &(mh->datatype)) );

  // comm_declare_strided_receive_displaced(..., size_t blksize, int nblocks, size_t stride):
  // describe the strided receive buffer with the same datatype layout
  MPI_CHECK( MPI_Type_vector(nblocks, blksize, stride, MPI_BYTE, &(mh->datatype)) );
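MPI_Type_vector(nblocks, blksize, stride, MPI_BYTE, ...) describes nblocks contiguous runs of blksize bytes whose starting addresses are stride bytes apart, which is exactly the layout of a strided face buffer. A small worked example with made-up sizes, assuming buffer and displacement are set up as in the non-strided case:

// blksize = 16, nblocks = 4, stride = 64 bytes describes this layout in the source buffer:
//   |16 used|48 skipped|16 used|48 skipped|16 used|48 skipped|16 used|
MsgHandle *mh = comm_declare_strided_send_displaced(buffer, displacement,
                                                    16 /*blksize*/, 4 /*nblocks*/, 64 /*stride*/);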
  // comm_allreduce(): global sum of a single double
  MPI_CHECK( MPI_Allreduce(data, &recvbuf, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD) );

  // comm_allreduce_max(): global maximum of a single double
  MPI_CHECK( MPI_Allreduce(data, &recvbuf, 1, MPI_DOUBLE, MPI_MAX, MPI_COMM_WORLD) );

  // comm_allreduce_array(): element-wise global sum of an array of doubles
  double *recvbuf = new double[size];
  MPI_CHECK( MPI_Allreduce(data, recvbuf, size, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD) );
  memcpy(data, recvbuf, size*sizeof(double));
  delete [] recvbuf;

  // comm_allreduce_int(): global sum of a single int
  MPI_CHECK( MPI_Allreduce(data, &recvbuf, 1, MPI_INT, MPI_SUM, MPI_COMM_WORLD) );
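A typical use of these reductions is to accumulate a per-rank partial value and then reduce it in place through the data pointer. A hedged sketch (the quantities being reduced are invented for illustration):

double local_norm2 = 0.0;           // e.g. this rank's partial sum of |x|^2 over local sites
/* ... accumulate local_norm2 ... */
comm_allreduce(&local_norm2);       // local_norm2 now holds the sum over all ranks

double partials[2] = {0.0, 0.0};    // e.g. two residuals computed per rank
comm_allreduce_array(partials, 2);  // element-wise global sum, written back in place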
  // comm_broadcast(): broadcast nbytes of data from rank 0 to all ranks
  MPI_CHECK( MPI_Bcast(data, (int)nbytes, MPI_BYTE, 0, MPI_COMM_WORLD) );

  // barrier across all ranks
  MPI_CHECK( MPI_Barrier(MPI_COMM_WORLD) );

  // comm_abort(): abort the whole job with the given status code
  MPI_Abort(MPI_COMM_WORLD, status);
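comm_broadcast() replicates rank 0's copy of a buffer on every rank, which is the natural way to distribute values that only one rank has read or computed. A hedged sketch (the parameter struct is invented for illustration):

struct RunParams { int seed; double tol; };                   // hypothetical parameter block
RunParams params;
if (rank == 0) { params.seed = 1234; params.tol = 1e-10; }    // e.g. read from an input file on rank 0 only
comm_broadcast(&params, sizeof(params));                      // now identical on every rank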
// Declarations referenced above:

#define MPI_CHECK(mpi_call)
#define safe_malloc(size)

typedef int (*QudaCommsMap)(const int *coords, void *fdata);

char *comm_hostname(void);
void comm_init(int ndim, const int *dims, QudaCommsMap rank_from_coords, void *map_data);
Topology *comm_create_topology(int ndim, const int *dims, QudaCommsMap rank_from_coords, void *map_data);
void comm_set_default_topology(Topology *topo);
Topology *comm_default_topology(void);
int comm_rank_displaced(const Topology *topo, const int displacement[]);

MsgHandle *comm_declare_send_displaced(void *buffer, const int displacement[], size_t nbytes);
MsgHandle *comm_declare_receive_displaced(void *buffer, const int displacement[], size_t nbytes);
MsgHandle *comm_declare_strided_send_displaced(void *buffer, const int displacement[], size_t blksize, int nblocks, size_t stride);
MsgHandle *comm_declare_strided_receive_displaced(void *buffer, const int displacement[], size_t blksize, int nblocks, size_t stride);
void comm_start(MsgHandle *mh);
int comm_query(MsgHandle *mh);
void comm_wait(MsgHandle *mh);
void comm_free(MsgHandle *mh);

void comm_allreduce(double *data);
void comm_allreduce_max(double *data);
void comm_allreduce_array(double *data, size_t size);
void comm_allreduce_int(int *data);
void comm_broadcast(void *data, size_t nbytes);
void comm_abort(int status);
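For orientation, a hedged end-to-end sketch of how this layer is initialized: the application calls MPI_Init itself, defines a coordinate-to-rank mapping with the QudaCommsMap signature above, and then calls comm_init(). The 1x1x2x2 grid and the lexicographical mapping below are illustrative choices, not QUDA defaults:

// hypothetical mapping: lexicographical ordering of the 4-d grid coordinates
static int lex_rank_from_coords(const int *coords, void *fdata)
{
  const int *dims = static_cast<const int *>(fdata);
  int r = coords[0];
  for (int i = 1; i < 4; i++) r = r * dims[i] + coords[i];
  return r;
}

// after the application has called MPI_Init():
int dims[4] = {1, 1, 2, 2};                      // 4 ranks arranged as a 1x1x2x2 grid
comm_init(4, dims, lex_rank_from_coords, dims);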