QUDA v0.7.0
A library for QCD on GPUs
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
comm_mpi.cpp
Go to the documentation of this file.
1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <string.h>
4 #include <mpi.h>
5 #include <quda_internal.h>
6 #include <comm_quda.h>
7 
8 
// Wrapper for MPI calls: on failure, retrieve the implementation's error
// string and abort via errorQuda().  The buffer is sized with
// MPI_MAX_ERROR_STRING because MPI_Error_string() may write up to that many
// characters -- the previous fixed 128-byte buffer could overflow on
// implementations whose MPI_MAX_ERROR_STRING exceeds 128.
#define MPI_CHECK(mpi_call) do {                      \
  int status = mpi_call;                              \
  if (status != MPI_SUCCESS) {                        \
    char err_string[MPI_MAX_ERROR_STRING];            \
    int err_len;                                      \
    MPI_Error_string(status, err_string, &err_len);   \
    err_string[MPI_MAX_ERROR_STRING-1] = '\0';        \
    errorQuda("(MPI) %s", err_string);                \
  }                                                   \
} while (0)
19 
20 
// Handle for a persistent point-to-point communication.  Created by the
// comm_declare_*_displaced() functions below (MPI_Send_init/MPI_Recv_init)
// and started/completed via comm_start()/comm_wait().
struct MsgHandle_s {
  MPI_Request request;    // persistent request, reusable across comm_start()/comm_wait() cycles
  MPI_Datatype datatype;  // strided type; only initialized by the strided declare functions
};
25 
static int rank = -1;   // this process's rank in MPI_COMM_WORLD (-1 until comm_init())
static int size = -1;   // total number of MPI ranks (-1 until comm_init())
static int gpuid = -1;  // CUDA device index assigned to this rank (-1 until comm_init())
29 
30 
31 void comm_init(int ndim, const int *dims, QudaCommsMap rank_from_coords, void *map_data)
32 {
33  int initialized;
34  MPI_CHECK( MPI_Initialized(&initialized) );
35 
36  if (!initialized) {
37  errorQuda("MPI has not been initialized");
38  }
39 
40  MPI_CHECK( MPI_Comm_rank(MPI_COMM_WORLD, &rank) );
41  MPI_CHECK( MPI_Comm_size(MPI_COMM_WORLD, &size) );
42 
43  int grid_size = 1;
44  for (int i = 0; i < ndim; i++) {
45  grid_size *= dims[i];
46  }
47  if (grid_size != size) {
48  errorQuda("Communication grid size declared via initCommsGridQuda() does not match"
49  " total number of MPI ranks (%d != %d)", grid_size, size);
50  }
51 
52  Topology *topo = comm_create_topology(ndim, dims, rank_from_coords, map_data);
54 
55  // determine which GPU this MPI rank will use
56  char *hostname = comm_hostname();
57  char *hostname_recv_buf = (char *)safe_malloc(128*size);
58 
59  MPI_CHECK( MPI_Allgather(hostname, 128, MPI_CHAR, hostname_recv_buf, 128, MPI_CHAR, MPI_COMM_WORLD) );
60 
61  gpuid = 0;
62  for (int i = 0; i < rank; i++) {
63  if (!strncmp(hostname, &hostname_recv_buf[128*i], 128)) {
64  gpuid++;
65  }
66  }
67  host_free(hostname_recv_buf);
68 
69  int device_count;
70  cudaGetDeviceCount(&device_count);
71  if (device_count == 0) {
72  errorQuda("No CUDA devices found");
73  }
74  if (gpuid >= device_count) {
75  errorQuda("Too few GPUs available on %s", hostname);
76  }
77 }
78 
79 
80 int comm_rank(void)
81 {
82  return rank;
83 }
84 
85 
86 int comm_size(void)
87 {
88  return size;
89 }
90 
91 
92 int comm_gpuid(void)
93 {
94  return gpuid;
95 }
96 
97 
101 MsgHandle *comm_declare_send_displaced(void *buffer, const int displacement[], size_t nbytes)
102 {
104 
105  int rank = comm_rank_displaced(topo, displacement);
106  int tag = comm_rank();
107  MsgHandle *mh = (MsgHandle *)safe_malloc(sizeof(MsgHandle));
108  MPI_CHECK( MPI_Send_init(buffer, nbytes, MPI_BYTE, rank, tag, MPI_COMM_WORLD, &(mh->request)) );
109 
110  return mh;
111 }
112 
113 
117 MsgHandle *comm_declare_receive_displaced(void *buffer, const int displacement[], size_t nbytes)
118 {
120 
121  int rank = comm_rank_displaced(topo, displacement);
122  int tag = rank;
123  MsgHandle *mh = (MsgHandle *)safe_malloc(sizeof(MsgHandle));
124  MPI_CHECK( MPI_Recv_init(buffer, nbytes, MPI_BYTE, rank, tag, MPI_COMM_WORLD, &(mh->request)) );
125 
126  return mh;
127 }
128 
129 
133 MsgHandle *comm_declare_strided_send_displaced(void *buffer, const int displacement[],
134  size_t blksize, int nblocks, size_t stride)
135 {
137 
138  int rank = comm_rank_displaced(topo, displacement);
139  int tag = comm_rank();
140  MsgHandle *mh = (MsgHandle *)safe_malloc(sizeof(MsgHandle));
141 
142  // create a new strided MPI type
143  MPI_CHECK( MPI_Type_vector(nblocks, blksize, stride, MPI_BYTE, &(mh->datatype)) );
144  MPI_CHECK( MPI_Type_commit(&(mh->datatype)) );
145 
146  MPI_CHECK( MPI_Send_init(buffer, 1, mh->datatype, rank, tag, MPI_COMM_WORLD, &(mh->request)) );
147 
148  return mh;
149 }
150 
151 
155 MsgHandle *comm_declare_strided_receive_displaced(void *buffer, const int displacement[],
156  size_t blksize, int nblocks, size_t stride)
157 {
159 
160  int rank = comm_rank_displaced(topo, displacement);
161  int tag = rank;
162  MsgHandle *mh = (MsgHandle *)safe_malloc(sizeof(MsgHandle));
163 
164  // create a new strided MPI type
165  MPI_CHECK( MPI_Type_vector(nblocks, blksize, stride, MPI_BYTE, &(mh->datatype)) );
166  MPI_CHECK( MPI_Type_commit(&(mh->datatype)) );
167 
168  MPI_CHECK( MPI_Recv_init(buffer, 1, mh->datatype, rank, tag, MPI_COMM_WORLD, &(mh->request)) );
169 
170  return mh;
171 }
172 
173 
175 {
176  host_free(mh);
177 }
178 
179 
181 {
182  MPI_CHECK( MPI_Start(&(mh->request)) );
183 }
184 
185 
187 {
188  MPI_CHECK( MPI_Wait(&(mh->request), MPI_STATUS_IGNORE) );
189 }
190 
191 
193 {
194  int query;
195  MPI_CHECK( MPI_Test(&(mh->request), &query, MPI_STATUS_IGNORE) );
196 
197  return query;
198 }
199 
200 
201 void comm_allreduce(double* data)
202 {
203  double recvbuf;
204  MPI_CHECK( MPI_Allreduce(data, &recvbuf, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD) );
205  *data = recvbuf;
206 }
207 
208 
209 void comm_allreduce_max(double* data)
210 {
211  double recvbuf;
212  MPI_CHECK( MPI_Allreduce(data, &recvbuf, 1, MPI_DOUBLE, MPI_MAX, MPI_COMM_WORLD) );
213  *data = recvbuf;
214 }
215 
216 void comm_allreduce_array(double* data, size_t size)
217 {
218  double *recvbuf = new double[size];
219  MPI_CHECK( MPI_Allreduce(data, recvbuf, size, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD) );
220  memcpy(data, recvbuf, size*sizeof(double));
221  delete []recvbuf;
222 }
223 
224 
225 void comm_allreduce_int(int* data)
226 {
227  int recvbuf;
228  MPI_CHECK( MPI_Allreduce(data, &recvbuf, 1, MPI_INT, MPI_SUM, MPI_COMM_WORLD) );
229  *data = recvbuf;
230 }
231 
232 
234 void comm_broadcast(void *data, size_t nbytes)
235 {
236  MPI_CHECK( MPI_Bcast(data, (int)nbytes, MPI_BYTE, 0, MPI_COMM_WORLD) );
237 }
238 
239 
240 void comm_barrier(void)
241 {
242  MPI_CHECK( MPI_Barrier(MPI_COMM_WORLD) );
243 }
244 
245 
246 void comm_abort(int status)
247 {
248  MPI_Abort(MPI_COMM_WORLD, status) ;
249 }
void comm_allreduce(double *data)
Definition: comm_mpi.cpp:201
MsgHandle * comm_declare_send_displaced(void *buffer, const int displacement[], size_t nbytes)
Definition: comm_mpi.cpp:101
MPI_Request request
Definition: comm_mpi.cpp:22
int comm_query(MsgHandle *mh)
Definition: comm_mpi.cpp:192
Topology * comm_create_topology(int ndim, const int *dims, QudaCommsMap rank_from_coords, void *map_data)
Definition: comm_common.cpp:93
void comm_wait(MsgHandle *mh)
Definition: comm_mpi.cpp:186
void comm_allreduce_array(double *data, size_t size)
Definition: comm_mpi.cpp:216
void comm_allreduce_max(double *data)
Definition: comm_mpi.cpp:209
#define errorQuda(...)
Definition: util_quda.h:73
void comm_allreduce_int(int *data)
Definition: comm_mpi.cpp:225
#define host_free(ptr)
Definition: malloc_quda.h:29
#define MPI_CHECK(mpi_call)
Definition: comm_mpi.cpp:9
Topology * comm_default_topology(void)
MsgHandle * comm_declare_strided_receive_displaced(void *buffer, const int displacement[], size_t blksize, int nblocks, size_t stride)
Definition: comm_mpi.cpp:155
int comm_gpuid(void)
Definition: comm_mpi.cpp:92
char * comm_hostname(void)
Definition: comm_common.cpp:57
int comm_rank(void)
Definition: comm_mpi.cpp:80
void comm_start(MsgHandle *mh)
Definition: comm_mpi.cpp:180
void comm_barrier(void)
Definition: comm_mpi.cpp:240
void comm_init(int ndim, const int *dims, QudaCommsMap rank_from_coords, void *map_data)
Definition: comm_mpi.cpp:31
void comm_broadcast(void *data, size_t nbytes)
Definition: comm_mpi.cpp:234
int comm_rank_displaced(const Topology *topo, const int displacement[])
MPI_Datatype datatype
Definition: comm_mpi.cpp:23
#define safe_malloc(size)
Definition: malloc_quda.h:25
MsgHandle * comm_declare_receive_displaced(void *buffer, const int displacement[], size_t nbytes)
Definition: comm_mpi.cpp:117
void comm_set_default_topology(Topology *topo)
void comm_free(MsgHandle *mh)
Definition: comm_mpi.cpp:174
int(* QudaCommsMap)(const int *coords, void *fdata)
Definition: comm_quda.h:12
MsgHandle * comm_declare_strided_send_displaced(void *buffer, const int displacement[], size_t blksize, int nblocks, size_t stride)
Definition: comm_mpi.cpp:133
void comm_abort(int status)
Definition: comm_mpi.cpp:246
int comm_size(void)
Definition: comm_mpi.cpp:86