QUDA v0.4.0
A library for QCD on GPUs
00001 #include <stdio.h> 00002 #include <stdlib.h> 00003 #include <unistd.h> 00004 #include <string.h> 00005 #include <mpi.h> 00006 #include <comm_quda.h> 00007 00008 char hostname[128]; 00009 static int fwd_nbr=-1; 00010 static int back_nbr=-1; 00011 static int rank = 0; 00012 static int size = -1; 00013 extern int verbose; 00014 static int num_nodes; 00015 extern int getGpuCount(); 00016 static int which_gpu = -1; 00017 00018 static int x_fwd_nbr=-1; 00019 static int y_fwd_nbr=-1; 00020 static int z_fwd_nbr=-1; 00021 static int t_fwd_nbr=-1; 00022 static int x_back_nbr=-1; 00023 static int y_back_nbr=-1; 00024 static int z_back_nbr=-1; 00025 static int t_back_nbr=-1; 00026 00027 static int xgridsize=1; 00028 static int ygridsize=1; 00029 static int zgridsize=1; 00030 static int tgridsize=1; 00031 static int xgridid = -1; 00032 static int ygridid = -1; 00033 static int zgridid = -1; 00034 static int tgridid = -1; 00035 00036 static int manual_set_partition[4] ={0, 0, 0, 0}; 00037 00038 void 00039 comm_set_gridsize(int x, int y, int z, int t) 00040 { 00041 xgridsize = x; 00042 ygridsize = y; 00043 zgridsize = z; 00044 tgridsize = t; 00045 00046 return; 00047 } 00048 00049 /* This function is for and testing debugging purpose only 00050 * The partitioning schedume should be generated automically 00051 * in production runs. 
Don't use this function if you don't know 00052 * what you are doing 00053 */ 00054 void 00055 comm_dim_partitioned_set(int dir) 00056 { 00057 manual_set_partition[dir] = 1; 00058 return; 00059 } 00060 00061 00062 int 00063 comm_dim_partitioned(int dir) 00064 { 00065 int ret = 0; 00066 00067 switch(dir){ 00068 case 0: 00069 ret = (xgridsize > 1); 00070 break; 00071 case 1: 00072 ret = (ygridsize > 1); 00073 break; 00074 case 2: 00075 ret = (zgridsize > 1); 00076 break; 00077 case 3: 00078 ret = (tgridsize > 1); 00079 break; 00080 default: 00081 printf("ERROR: invalid direction\n"); 00082 comm_exit(1); 00083 } 00084 00085 if( manual_set_partition[dir]){ 00086 ret = manual_set_partition[dir]; 00087 } 00088 00089 return ret; 00090 } 00091 00092 static void 00093 comm_partition(void) 00094 { 00095 /* 00096 printf("xgridsize=%d\n", xgridsize); 00097 printf("ygridsize=%d\n", ygridsize); 00098 printf("zgridsize=%d\n", zgridsize); 00099 printf("tgridsize=%d\n", tgridsize); 00100 */ 00101 if(xgridsize*ygridsize*zgridsize*tgridsize != size){ 00102 if (rank ==0){ 00103 printf("ERROR: Invalid configuration (t,z,y,x gridsize=%d %d %d %d) " 00104 "but # of MPI processes is %d\n", tgridsize, zgridsize, ygridsize, xgridsize, size); 00105 } 00106 comm_exit(1); 00107 } 00108 00109 int leftover; 00110 00111 #if 0 00112 tgridid = rank/(zgridsize*ygridsize*xgridsize); 00113 leftover = rank%(zgridsize*ygridsize*xgridsize); 00114 zgridid = leftover/(ygridsize*xgridsize); 00115 leftover = leftover%(ygridsize*xgridsize); 00116 ygridid = leftover/xgridsize; 00117 xgridid = leftover%xgridsize; 00118 #define GRID_ID(xid,yid,zid,tid) (tid*zgridsize*ygridsize*xgridsize+zid*ygridsize*xgridsize+yid*xgridsize+xid) 00119 #else 00120 00121 xgridid = rank/(ygridsize*zgridsize*tgridsize); 00122 leftover = rank%(ygridsize*zgridsize*tgridsize); 00123 ygridid = leftover/(zgridsize*tgridsize); 00124 leftover = leftover%(zgridsize*tgridsize); 00125 zgridid = leftover/tgridsize; 00126 tgridid = 
leftover%tgridsize; 00127 #define GRID_ID(xid,yid,zid,tid) (xid*ygridsize*zgridsize*tgridsize+yid*zgridsize*tgridsize+zid*tgridsize+tid) 00128 #endif 00129 printf("My rank: %d, gridid(t,z,y,x): %d %d %d %d\n", rank, tgridid, zgridid, ygridid, xgridid); 00130 00131 00132 int xid, yid, zid, tid; 00133 //X direction neighbors 00134 yid =ygridid; 00135 zid =zgridid; 00136 tid =tgridid; 00137 xid=(xgridid +1)%xgridsize; 00138 x_fwd_nbr = GRID_ID(xid,yid,zid,tid); 00139 xid=(xgridid -1+xgridsize)%xgridsize; 00140 x_back_nbr = GRID_ID(xid,yid,zid,tid); 00141 00142 //Y direction neighbors 00143 xid =xgridid; 00144 zid =zgridid; 00145 tid =tgridid; 00146 yid =(ygridid+1)%ygridsize; 00147 y_fwd_nbr = GRID_ID(xid,yid,zid,tid); 00148 yid=(ygridid -1+ygridsize)%ygridsize; 00149 y_back_nbr = GRID_ID(xid,yid,zid,tid); 00150 00151 //Z direction neighbors 00152 xid =xgridid; 00153 yid =ygridid; 00154 tid =tgridid; 00155 zid =(zgridid+1)%zgridsize; 00156 z_fwd_nbr = GRID_ID(xid,yid,zid,tid); 00157 zid=(zgridid -1+zgridsize)%zgridsize; 00158 z_back_nbr = GRID_ID(xid,yid,zid,tid); 00159 00160 //T direction neighbors 00161 xid =xgridid; 00162 yid =ygridid; 00163 zid =zgridid; 00164 tid =(tgridid+1)%tgridsize; 00165 t_fwd_nbr = GRID_ID(xid,yid,zid,tid); 00166 tid=(tgridid -1+tgridsize)%tgridsize; 00167 t_back_nbr = GRID_ID(xid,yid,zid,tid); 00168 00169 printf("MPI rank: rank=%d, hostname=%s, x_fwd_nbr=%d, x_back_nbr=%d\n", rank, hostname, x_fwd_nbr, x_back_nbr); 00170 printf("MPI rank: rank=%d, hostname=%s, y_fwd_nbr=%d, y_back_nbr=%d\n", rank, hostname, y_fwd_nbr, y_back_nbr); 00171 printf("MPI rank: rank=%d, hostname=%s, z_fwd_nbr=%d, z_back_nbr=%d\n", rank, hostname, z_fwd_nbr, z_back_nbr); 00172 printf("MPI rank: rank=%d, hostname=%s, t_fwd_nbr=%d, t_back_nbr=%d\n", rank, hostname, t_fwd_nbr, t_back_nbr); 00173 00174 00175 } 00176 00177 int 00178 comm_get_neighbor_rank(int dx, int dy, int dz, int dt) 00179 { 00180 int ret; 00181 #if 0 00182 #define GRID_ID(xid,yid,zid,tid) 
(tid*zgridsize*ygridsize*xgridsize+zid*ygridsize*xgridsize+yid*xgridsize+xid) 00183 #else 00184 #define GRID_ID(xid,yid,zid,tid) (xid*ygridsize*zgridsize*tgridsize+yid*zgridsize*tgridsize+zid*tgridsize+tid) 00185 #endif 00186 00187 00188 int xid, yid, zid, tid; 00189 xid=(xgridid + dx + xgridsize)%xgridsize; 00190 yid=(ygridid + dy + ygridsize)%ygridsize; 00191 zid=(zgridid + dz + zgridsize)%zgridsize; 00192 tid=(tgridid + dt + tgridsize)%tgridsize; 00193 00194 ret = GRID_ID(xid,yid,zid,tid); 00195 00196 return ret; 00197 } 00198 00199 00200 void 00201 comm_init() 00202 { 00203 int i; 00204 00205 static int firsttime=1; 00206 if (!firsttime){ 00207 return; 00208 } 00209 firsttime = 0; 00210 00211 gethostname(hostname, 128); 00212 MPI_Comm_size(MPI_COMM_WORLD, &size); 00213 MPI_Comm_rank(MPI_COMM_WORLD, &rank); 00214 00215 int gpus_per_node = getGpuCount(); 00216 00217 comm_partition(); 00218 00219 back_nbr = (rank -1 + size)%size; 00220 fwd_nbr = (rank +1)%size; 00221 num_nodes=size / getGpuCount(); 00222 if(num_nodes ==0) { 00223 num_nodes=1; 00224 } 00225 00226 //determine which gpu this MPI process is going to use 00227 char* hostname_recv_buf = (char*)malloc(128*size); 00228 if(hostname_recv_buf == NULL){ 00229 printf("ERROR: malloc failed for host_recv_buf\n"); 00230 comm_exit(1); 00231 } 00232 00233 gethostname(hostname, 128); 00234 int rc = MPI_Allgather(hostname, 128, MPI_CHAR, hostname_recv_buf, 128, MPI_CHAR, MPI_COMM_WORLD); 00235 if (rc != MPI_SUCCESS){ 00236 printf("ERROR: MPI_Allgather failed for hostname\n"); 00237 comm_exit(1); 00238 } 00239 00240 which_gpu=0; 00241 for(i=0;i < size; i++){ 00242 if (i == rank){ 00243 break; 00244 } 00245 if (strncmp(hostname, hostname_recv_buf + 128*i, 128) == 0){ 00246 which_gpu ++; 00247 } 00248 } 00249 00250 if (which_gpu >= gpus_per_node){ 00251 printf("ERROR: invalid gpu(%d) to use in rank=%d mpi process\n", which_gpu, rank); 00252 comm_exit(1); 00253 } 00254 00255 srand(rank*999); 00256 00257 
free(hostname_recv_buf); 00258 return; 00259 } 00260 00261 int comm_gpuid() 00262 { 00263 //int gpu = rank%getGpuCount(); 00264 00265 return which_gpu; 00266 } 00267 int 00268 comm_rank(void) 00269 { 00270 return rank; 00271 } 00272 00273 int 00274 comm_size(void) 00275 { 00276 return size; 00277 } 00278 00279 int 00280 comm_dim(int dir) { 00281 00282 int i; 00283 switch(dir) { 00284 case 0: 00285 i = xgridsize; 00286 break; 00287 case 1: 00288 i = ygridsize; 00289 break; 00290 case 2: 00291 i = zgridsize; 00292 break; 00293 case 3: 00294 i = tgridsize; 00295 break; 00296 default: 00297 printf("Cannot get direction %d", dir); 00298 comm_exit(1); 00299 } 00300 00301 return i; 00302 } 00303 00304 int 00305 comm_coords(int dir) { 00306 00307 int i; 00308 switch(dir) { 00309 case 0: 00310 i = xgridid; 00311 break; 00312 case 1: 00313 i = ygridid; 00314 break; 00315 case 2: 00316 i = zgridid; 00317 break; 00318 case 3: 00319 i = tgridid; 00320 break; 00321 default: 00322 printf("Cannot get direction %d", dir); 00323 comm_exit(1); 00324 } 00325 00326 return i; 00327 } 00328 00329 unsigned long 00330 comm_send(void* buf, int len, int dst, void* _request) 00331 { 00332 00333 MPI_Request* request = (MPI_Request*)_request; 00334 if (request == NULL){ 00335 printf("ERROR: malloc failed for mpi request\n"); 00336 comm_exit(1); 00337 } 00338 00339 int dstproc; 00340 int sendtag=99; 00341 if (dst == BACK_NBR){ 00342 dstproc = back_nbr; 00343 sendtag = BACK_NBR; 00344 }else if (dst == FWD_NBR){ 00345 dstproc = fwd_nbr; 00346 sendtag = FWD_NBR; 00347 }else{ 00348 printf("ERROR: invalid dest\n"); 00349 comm_exit(1); 00350 } 00351 00352 MPI_Isend(buf, len, MPI_BYTE, dstproc, sendtag, MPI_COMM_WORLD, request); 00353 return (unsigned long)request; 00354 } 00355 00356 unsigned long 00357 comm_send_to_rank(void* buf, int len, int dst_rank, void* _request) 00358 { 00359 00360 MPI_Request* request = (MPI_Request*)_request; 00361 if (request == NULL){ 00362 printf("ERROR: malloc failed for 
mpi request\n"); 00363 comm_exit(1); 00364 } 00365 00366 if(dst_rank < 0 || dst_rank >= comm_size()){ 00367 printf("ERROR: Invalid dst rank(%d)\n", dst_rank); 00368 comm_exit(1); 00369 } 00370 int sendtag = 99; 00371 MPI_Isend(buf, len, MPI_BYTE, dst_rank, sendtag, MPI_COMM_WORLD, request); 00372 return (unsigned long)request; 00373 } 00374 00375 unsigned long 00376 comm_send_with_tag(void* buf, int len, int dst, int tag, void*_request) 00377 { 00378 00379 MPI_Request* request = (MPI_Request*)_request; 00380 if (request == NULL){ 00381 printf("ERROR: malloc failed for mpi request\n"); 00382 comm_exit(1); 00383 } 00384 00385 int dstproc = -1; 00386 switch(dst){ 00387 case X_BACK_NBR: 00388 dstproc = x_back_nbr; 00389 break; 00390 case X_FWD_NBR: 00391 dstproc = x_fwd_nbr; 00392 break; 00393 case Y_BACK_NBR: 00394 dstproc = y_back_nbr; 00395 break; 00396 case Y_FWD_NBR: 00397 dstproc = y_fwd_nbr; 00398 break; 00399 case Z_BACK_NBR: 00400 dstproc = z_back_nbr; 00401 break; 00402 case Z_FWD_NBR: 00403 dstproc = z_fwd_nbr; 00404 break; 00405 case T_BACK_NBR: 00406 dstproc = t_back_nbr; 00407 break; 00408 case T_FWD_NBR: 00409 dstproc = t_fwd_nbr; 00410 break; 00411 default: 00412 printf("ERROR: invalid dest, line %d, file %s\n", __LINE__, __FILE__); 00413 comm_exit(1); 00414 } 00415 00416 MPI_Isend(buf, len, MPI_BYTE, dstproc, tag, MPI_COMM_WORLD, request); 00417 return (unsigned long)request; 00418 } 00419 00420 00421 00422 unsigned long 00423 comm_recv(void* buf, int len, int src, void*_request) 00424 { 00425 MPI_Request* request = (MPI_Request*)_request; 00426 if (request == NULL){ 00427 printf("ERROR: malloc failed for mpi request\n"); 00428 comm_exit(1); 00429 } 00430 00431 int srcproc=-1; 00432 int recvtag=99; //recvtag is opposite to the sendtag 00433 if (src == BACK_NBR){ 00434 srcproc = back_nbr; 00435 recvtag = FWD_NBR; 00436 }else if (src == FWD_NBR){ 00437 srcproc = fwd_nbr; 00438 recvtag = BACK_NBR; 00439 }else{ 00440 printf("ERROR: invalid source\n"); 
00441 comm_exit(1); 00442 } 00443 00444 MPI_Irecv(buf, len, MPI_BYTE, srcproc, recvtag, MPI_COMM_WORLD, request); 00445 00446 return (unsigned long)request; 00447 } 00448 00449 unsigned long 00450 comm_recv_from_rank(void* buf, int len, int src_rank, void* _request) 00451 { 00452 MPI_Request* request = (MPI_Request*)_request; 00453 if (request == NULL){ 00454 printf("ERROR: malloc failed for mpi request\n"); 00455 comm_exit(1); 00456 } 00457 00458 if(src_rank < 0 || src_rank >= comm_size()){ 00459 printf("ERROR: Invalid src rank(%d)\n", src_rank); 00460 comm_exit(1); 00461 } 00462 00463 int recvtag = 99; 00464 MPI_Irecv(buf, len, MPI_BYTE, src_rank, recvtag, MPI_COMM_WORLD, request); 00465 00466 return (unsigned long)request; 00467 } 00468 00469 unsigned long 00470 comm_recv_with_tag(void* buf, int len, int src, int tag, void* _request) 00471 { 00472 MPI_Request* request = (MPI_Request*)_request; 00473 if (request == NULL){ 00474 printf("ERROR: malloc failed for mpi request\n"); 00475 comm_exit(1); 00476 } 00477 00478 int srcproc=-1; 00479 switch (src){ 00480 case X_BACK_NBR: 00481 srcproc = x_back_nbr; 00482 break; 00483 case X_FWD_NBR: 00484 srcproc = x_fwd_nbr; 00485 break; 00486 case Y_BACK_NBR: 00487 srcproc = y_back_nbr; 00488 break; 00489 case Y_FWD_NBR: 00490 srcproc = y_fwd_nbr; 00491 break; 00492 case Z_BACK_NBR: 00493 srcproc = z_back_nbr; 00494 break; 00495 case Z_FWD_NBR: 00496 srcproc = z_fwd_nbr; 00497 break; 00498 case T_BACK_NBR: 00499 srcproc = t_back_nbr; 00500 break; 00501 case T_FWD_NBR: 00502 srcproc = t_fwd_nbr; 00503 break; 00504 default: 00505 printf("ERROR: invalid source, line %d, file %s\n", __LINE__, __FILE__); 00506 comm_exit(1); 00507 } 00508 MPI_Irecv(buf, len, MPI_BYTE, srcproc, tag, MPI_COMM_WORLD, request); 00509 00510 return (unsigned long)request; 00511 } 00512 00513 int comm_query(void* request) 00514 { 00515 MPI_Status status; 00516 int query; 00517 int rc = MPI_Test( (MPI_Request*)request, &query, &status); 00518 if (rc != 
MPI_SUCCESS) { 00519 printf("ERROR: MPI_Test failed\n"); 00520 comm_exit(1); 00521 } 00522 00523 return query; 00524 } 00525 00526 void comm_free(void* request) { 00527 free((void*)request); 00528 return; 00529 } 00530 00531 //this request should be some return value from comm_recv 00532 void 00533 comm_wait(void* request) 00534 { 00535 00536 MPI_Status status; 00537 int rc = MPI_Wait( (MPI_Request*)request, &status); 00538 if (rc != MPI_SUCCESS){ 00539 printf("ERROR: MPI_Wait failed\n"); 00540 comm_exit(1); 00541 } 00542 00543 return; 00544 } 00545 00546 //we always reduce one double value 00547 void 00548 comm_allreduce(double* data) 00549 { 00550 double recvbuf; 00551 int rc = MPI_Allreduce ( data, &recvbuf,1,MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD); 00552 if (rc != MPI_SUCCESS){ 00553 printf("ERROR: MPI_Allreduce failed\n"); 00554 comm_exit(1); 00555 } 00556 00557 *data = recvbuf; 00558 00559 return; 00560 } 00561 00562 //reduce n double value 00563 void 00564 comm_allreduce_array(double* data, size_t size) 00565 { 00566 double recvbuf[size]; 00567 int rc = MPI_Allreduce ( data, &recvbuf,size,MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD); 00568 if (rc != MPI_SUCCESS){ 00569 printf("ERROR: MPI_Allreduce failed\n"); 00570 comm_exit(1); 00571 } 00572 00573 memcpy(data, recvbuf, sizeof(recvbuf)); 00574 00575 return; 00576 } 00577 00578 //we always reduce one double value 00579 void 00580 comm_allreduce_max(double* data) 00581 { 00582 double recvbuf; 00583 int rc = MPI_Allreduce ( data, &recvbuf,1,MPI_DOUBLE, MPI_MAX, MPI_COMM_WORLD); 00584 if (rc != MPI_SUCCESS){ 00585 printf("ERROR: MPI_Allreduce failed\n"); 00586 comm_exit(1); 00587 } 00588 00589 *data = recvbuf; 00590 00591 return; 00592 } 00593 00594 // broadcast from rank 0 00595 void 00596 comm_broadcast(void *data, size_t nbytes) 00597 { 00598 MPI_Bcast(data, (int)nbytes, MPI_BYTE, 0, MPI_COMM_WORLD); 00599 } 00600 00601 void 00602 comm_barrier(void) 00603 { 00604 MPI_Barrier(MPI_COMM_WORLD); 00605 } 00606 void 00607 
comm_cleanup() 00608 { 00609 MPI_Finalize(); 00610 } 00611 00612 void 00613 comm_exit(int ret) 00614 { 00615 MPI_Finalize(); 00616 exit(ret); 00617 }