QUDA  v0.5.0
A library for QCD on GPUs
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
comm_mpi.cpp
Go to the documentation of this file.
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <mpi.h>
5 
6 #include <quda_internal.h>
7 #include <comm_quda.h>
8 
9 
/**
 * Evaluate an MPI call and report a failure through errorQuda().
 *
 * The return code of mpi_call is compared against MPI_SUCCESS.  On a
 * mismatch the textual description is obtained with MPI_Error_string()
 * and handed to errorQuda() with a "(MPI) " prefix.
 *
 * The message buffer holds MPI_MAX_ERROR_STRING characters, which is the
 * minimum storage MPI_Error_string() requires from its caller.  The
 * temporaries use trailing-underscore names so they cannot collide with
 * identifiers appearing in the expression passed as mpi_call.
 */
#define MPI_CHECK(mpi_call) do {                                      \
  int mpi_check_status_ = (mpi_call);                                 \
  if (mpi_check_status_ != MPI_SUCCESS) {                             \
    char mpi_check_msg_[MPI_MAX_ERROR_STRING];                        \
    int mpi_check_len_ = 0;                                           \
    MPI_Error_string(mpi_check_status_, mpi_check_msg_,               \
                     &mpi_check_len_);                                \
    /* belt and braces: guarantee termination before printing */      \
    mpi_check_msg_[MPI_MAX_ERROR_STRING - 1] = '\0';                  \
    errorQuda("(MPI) %s", mpi_check_msg_);                            \
  }                                                                   \
} while (0)
20 
21 
22 struct MsgHandle_s {
23  MPI_Request request;
24 };
25 
26 
// Per-process communication state, filled in by comm_init().
// All three hold -1 until comm_init() has run.
static int rank  = -1;  // this process's rank in MPI_COMM_WORLD
static int size  = -1;  // number of ranks in MPI_COMM_WORLD
static int gpuid = -1;  // GPU ordinal selected for this rank
30 
31 
32 void comm_init(int ndim, const int *dims, QudaCommsMap rank_from_coords, void *map_data)
33 {
34  int initialized;
35  MPI_CHECK( MPI_Initialized(&initialized) );
36 
37  if (!initialized) {
38  errorQuda("MPI has not been initialized");
39  }
40 
41  MPI_CHECK( MPI_Comm_rank(MPI_COMM_WORLD, &rank) );
42  MPI_CHECK( MPI_Comm_size(MPI_COMM_WORLD, &size) );
43 
44  int grid_size = 1;
45  for (int i = 0; i < ndim; i++) {
46  grid_size *= dims[i];
47  }
48  if (grid_size != size) {
49  errorQuda("Communication grid size declared via initCommsGridQuda() does not match"
50  " total number of MPI ranks (%d != %d)", grid_size, size);
51  }
52 
53  Topology *topo = comm_create_topology(ndim, dims, rank_from_coords, map_data);
55 
56  // determine which GPU this MPI rank will use
57 
58  char *hostname = comm_hostname();
59  char *hostname_recv_buf = (char *)safe_malloc(128*size);
60 
61  MPI_CHECK( MPI_Allgather(hostname, 128, MPI_CHAR, hostname_recv_buf, 128, MPI_CHAR, MPI_COMM_WORLD) );
62 
63  gpuid = 0;
64  for (int i = 0; i < rank; i++) {
65  if (!strncmp(hostname, &hostname_recv_buf[128*i], 128)) {
66  gpuid++;
67  }
68  }
69  host_free(hostname_recv_buf);
70 
71  int device_count;
72  cudaGetDeviceCount(&device_count);
73  if (device_count == 0) {
74  errorQuda("No CUDA devices found");
75  }
76  if (gpuid >= device_count) {
77  errorQuda("Too few GPUs available on %s", hostname);
78  }
79 }
80 
81 
82 int comm_rank(void)
83 {
84  return rank;
85 }
86 
87 
88 int comm_size(void)
89 {
90  return size;
91 }
92 
93 
94 int comm_gpuid(void)
95 {
96  return gpuid;
97 }
98 
99 
103 MsgHandle *comm_declare_send_displaced(void *buffer, const int displacement[], size_t nbytes)
104 {
106 
107  int rank = comm_rank_displaced(topo, displacement);
108  int tag = comm_rank();
109  MsgHandle *mh = (MsgHandle *)safe_malloc(sizeof(MsgHandle));
110  MPI_CHECK( MPI_Send_init(buffer, nbytes, MPI_BYTE, rank, tag, MPI_COMM_WORLD, &(mh->request)) );
111 
112  return mh;
113 }
114 
115 
119 MsgHandle *comm_declare_receive_displaced(void *buffer, const int displacement[], size_t nbytes)
120 {
122 
123  int rank = comm_rank_displaced(topo, displacement);
124  int tag = rank;
125  MsgHandle *mh = (MsgHandle *)safe_malloc(sizeof(MsgHandle));
126  MPI_CHECK( MPI_Recv_init(buffer, nbytes, MPI_BYTE, rank, tag, MPI_COMM_WORLD, &(mh->request)) );
127 
128  return mh;
129 }
130 
131 
133 {
134  host_free(mh);
135 }
136 
137 
139 {
140  MPI_CHECK( MPI_Start(&(mh->request)) );
141 }
142 
143 
145 {
146  MPI_CHECK( MPI_Wait(&(mh->request), MPI_STATUS_IGNORE) );
147 }
148 
149 
151 {
152  int query;
153  MPI_CHECK( MPI_Test(&(mh->request), &query, MPI_STATUS_IGNORE) );
154 
155  return query;
156 }
157 
158 
159 void comm_allreduce(double* data)
160 {
161  double recvbuf;
162  MPI_CHECK( MPI_Allreduce(data, &recvbuf, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD) );
163  *data = recvbuf;
164 }
165 
166 
167 void comm_allreduce_max(double* data)
168 {
169  double recvbuf;
170  MPI_CHECK( MPI_Allreduce(data, &recvbuf, 1, MPI_DOUBLE, MPI_MAX, MPI_COMM_WORLD) );
171  *data = recvbuf;
172 }
173 
174 
175 void comm_allreduce_array(double* data, size_t size)
176 {
177  double recvbuf[size];
178  MPI_CHECK( MPI_Allreduce(data, &recvbuf, size, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD) );
179  memcpy(data, recvbuf, sizeof(recvbuf));
180 }
181 
182 
183 void comm_allreduce_int(int* data)
184 {
185  int recvbuf;
186  MPI_CHECK( MPI_Allreduce(data, &recvbuf, 1, MPI_INT, MPI_SUM, MPI_COMM_WORLD) );
187  *data = recvbuf;
188 }
189 
190 
192 void comm_broadcast(void *data, size_t nbytes)
193 {
194  MPI_CHECK( MPI_Bcast(data, (int)nbytes, MPI_BYTE, 0, MPI_COMM_WORLD) );
195 }
196 
197 
198 void comm_barrier(void)
199 {
200  MPI_CHECK( MPI_Barrier(MPI_COMM_WORLD) );
201 }
202 
203 
204 void comm_abort(int status)
205 {
206  MPI_CHECK( MPI_Finalize() );
207  exit(status);
208 }