QUDA v0.4.0
A library for QCD on GPUs
quda/lib/comm_mpi.cpp
Go to the documentation of this file.
00001 #include <stdio.h>
00002 #include <stdlib.h>
00003 #include <unistd.h>
00004 #include <string.h>
00005 #include <mpi.h>
00006 #include <comm_quda.h>
00007 
/* Host name of this process; filled in by comm_init() through gethostname(). */
char hostname[128];

/* Symbols provided by other translation units. */
extern int verbose;
extern int getGpuCount();

/* Rank of this process and number of processes in MPI_COMM_WORLD (set in comm_init()). */
static int rank = 0;
static int size = -1;

/* Neighbours in plain rank order: (rank - 1) and (rank + 1) modulo size. */
static int back_nbr = -1;
static int fwd_nbr = -1;

/* size / getGpuCount(), raised to 1 when the quotient is 0 (set in comm_init()). */
static int num_nodes;

/* Number of lower ranks that share this host name; returned by comm_gpuid(). */
static int which_gpu = -1;

/* Extent of the process grid along each axis (see comm_set_gridsize()). */
static int xgridsize = 1;
static int ygridsize = 1;
static int zgridsize = 1;
static int tgridsize = 1;

/* Coordinate of this process along each axis, derived in comm_partition(). */
static int xgridid = -1;
static int ygridid = -1;
static int zgridid = -1;
static int tgridid = -1;

/* Ranks of the periodic forward/backward neighbours along each axis. */
static int x_back_nbr = -1;
static int x_fwd_nbr = -1;
static int y_back_nbr = -1;
static int y_fwd_nbr = -1;
static int z_back_nbr = -1;
static int z_fwd_nbr = -1;
static int t_back_nbr = -1;
static int t_fwd_nbr = -1;

/* A non-zero entry makes comm_dim_partitioned() report that direction as partitioned. */
static int manual_set_partition[4] = {0, 0, 0, 0};
00037 
00038 void
00039 comm_set_gridsize(int x, int y, int z, int t)
00040 {
00041   xgridsize = x;
00042   ygridsize = y;
00043   zgridsize = z;
00044   tgridsize = t;
00045 
00046   return;
00047 }
00048 
00049 /* This function is for and testing debugging purpose only
00050  * The partitioning schedume should be generated automically 
00051  * in production runs. Don't use this function if you don't know
00052  * what you are doing
00053  */
00054 void
00055 comm_dim_partitioned_set(int dir)
00056 {
00057   manual_set_partition[dir] = 1;
00058   return;
00059 }
00060 
00061 
00062 int 
00063 comm_dim_partitioned(int dir)
00064 {
00065   int ret = 0;
00066   
00067   switch(dir){
00068   case 0: 
00069     ret = (xgridsize > 1);    
00070     break;
00071   case 1: 
00072     ret = (ygridsize > 1);
00073     break;
00074   case 2: 
00075     ret = (zgridsize > 1);
00076     break;
00077   case 3: 
00078     ret = (tgridsize > 1);
00079     break;    
00080   default:
00081     printf("ERROR: invalid direction\n");
00082     comm_exit(1);
00083   }
00084   
00085   if( manual_set_partition[dir]){
00086     ret = manual_set_partition[dir];
00087   }
00088   
00089   return ret;
00090 }
00091 
00092 static void
00093 comm_partition(void)
00094 {
00095   /*
00096   printf("xgridsize=%d\n", xgridsize);
00097   printf("ygridsize=%d\n", ygridsize);
00098   printf("zgridsize=%d\n", zgridsize);
00099   printf("tgridsize=%d\n", tgridsize);
00100   */
00101   if(xgridsize*ygridsize*zgridsize*tgridsize != size){
00102     if (rank ==0){
00103       printf("ERROR: Invalid configuration (t,z,y,x gridsize=%d %d %d %d) "
00104              "but # of MPI processes is %d\n", tgridsize, zgridsize, ygridsize, xgridsize, size);
00105     }
00106     comm_exit(1);
00107   }
00108 
00109   int leftover;
00110 
00111 #if 0
00112   tgridid  = rank/(zgridsize*ygridsize*xgridsize);
00113   leftover = rank%(zgridsize*ygridsize*xgridsize);
00114   zgridid  = leftover/(ygridsize*xgridsize);
00115   leftover = leftover%(ygridsize*xgridsize);
00116   ygridid  = leftover/xgridsize;
00117   xgridid  = leftover%xgridsize;
00118   #define GRID_ID(xid,yid,zid,tid) (tid*zgridsize*ygridsize*xgridsize+zid*ygridsize*xgridsize+yid*xgridsize+xid)
00119 #else
00120 
00121   xgridid  = rank/(ygridsize*zgridsize*tgridsize);
00122   leftover = rank%(ygridsize*zgridsize*tgridsize);
00123   ygridid  = leftover/(zgridsize*tgridsize);
00124   leftover = leftover%(zgridsize*tgridsize);
00125   zgridid  = leftover/tgridsize;
00126   tgridid  = leftover%tgridsize;  
00127 #define GRID_ID(xid,yid,zid,tid) (xid*ygridsize*zgridsize*tgridsize+yid*zgridsize*tgridsize+zid*tgridsize+tid)
00128 #endif
00129   printf("My rank: %d, gridid(t,z,y,x): %d %d %d %d\n", rank, tgridid, zgridid, ygridid, xgridid);
00130 
00131 
00132   int xid, yid, zid, tid;
00133   //X direction neighbors
00134   yid =ygridid;
00135   zid =zgridid;
00136   tid =tgridid;
00137   xid=(xgridid +1)%xgridsize;
00138   x_fwd_nbr = GRID_ID(xid,yid,zid,tid);
00139   xid=(xgridid -1+xgridsize)%xgridsize;
00140   x_back_nbr = GRID_ID(xid,yid,zid,tid);
00141 
00142   //Y direction neighbors
00143   xid =xgridid;
00144   zid =zgridid;
00145   tid =tgridid;
00146   yid =(ygridid+1)%ygridsize;
00147   y_fwd_nbr = GRID_ID(xid,yid,zid,tid);
00148   yid=(ygridid -1+ygridsize)%ygridsize;
00149   y_back_nbr = GRID_ID(xid,yid,zid,tid);
00150 
00151   //Z direction neighbors
00152   xid =xgridid;
00153   yid =ygridid;
00154   tid =tgridid;
00155   zid =(zgridid+1)%zgridsize;
00156   z_fwd_nbr = GRID_ID(xid,yid,zid,tid);
00157   zid=(zgridid -1+zgridsize)%zgridsize;
00158   z_back_nbr = GRID_ID(xid,yid,zid,tid);
00159 
00160   //T direction neighbors
00161   xid =xgridid;
00162   yid =ygridid;
00163   zid =zgridid;
00164   tid =(tgridid+1)%tgridsize;
00165   t_fwd_nbr = GRID_ID(xid,yid,zid,tid);
00166   tid=(tgridid -1+tgridsize)%tgridsize;
00167   t_back_nbr = GRID_ID(xid,yid,zid,tid);
00168 
00169   printf("MPI rank: rank=%d, hostname=%s, x_fwd_nbr=%d, x_back_nbr=%d\n", rank, hostname, x_fwd_nbr, x_back_nbr);
00170   printf("MPI rank: rank=%d, hostname=%s, y_fwd_nbr=%d, y_back_nbr=%d\n", rank, hostname, y_fwd_nbr, y_back_nbr);
00171   printf("MPI rank: rank=%d, hostname=%s, z_fwd_nbr=%d, z_back_nbr=%d\n", rank, hostname, z_fwd_nbr, z_back_nbr);
00172   printf("MPI rank: rank=%d, hostname=%s, t_fwd_nbr=%d, t_back_nbr=%d\n", rank, hostname, t_fwd_nbr, t_back_nbr);
00173 
00174   
00175 }
00176 
00177 int 
00178 comm_get_neighbor_rank(int dx, int dy, int dz, int dt)
00179 {
00180   int ret;
00181 #if 0
00182 #define GRID_ID(xid,yid,zid,tid) (tid*zgridsize*ygridsize*xgridsize+zid*ygridsize*xgridsize+yid*xgridsize+xid)
00183 #else
00184 #define GRID_ID(xid,yid,zid,tid) (xid*ygridsize*zgridsize*tgridsize+yid*zgridsize*tgridsize+zid*tgridsize+tid)
00185 #endif
00186 
00187   
00188   int xid, yid, zid, tid;
00189   xid=(xgridid + dx + xgridsize)%xgridsize;
00190   yid=(ygridid + dy + ygridsize)%ygridsize;
00191   zid=(zgridid + dz + zgridsize)%zgridsize;
00192   tid=(tgridid + dt + tgridsize)%tgridsize;
00193 
00194   ret = GRID_ID(xid,yid,zid,tid);
00195 
00196   return ret;
00197 }
00198 
00199 
00200 void 
00201 comm_init()
00202 {
00203   int i;
00204   
00205   static int firsttime=1;
00206   if (!firsttime){ 
00207     return;
00208   }
00209   firsttime = 0;
00210 
00211   gethostname(hostname, 128);  
00212   MPI_Comm_size(MPI_COMM_WORLD, &size);
00213   MPI_Comm_rank(MPI_COMM_WORLD, &rank);
00214 
00215   int gpus_per_node = getGpuCount();  
00216 
00217   comm_partition();
00218 
00219   back_nbr = (rank -1 + size)%size;
00220   fwd_nbr = (rank +1)%size;
00221   num_nodes=size / getGpuCount();
00222   if(num_nodes ==0) {
00223         num_nodes=1;
00224   }
00225 
00226   //determine which gpu this MPI process is going to use
00227   char* hostname_recv_buf = (char*)malloc(128*size);
00228   if(hostname_recv_buf == NULL){
00229     printf("ERROR: malloc failed for host_recv_buf\n");
00230     comm_exit(1);
00231   }
00232   
00233   gethostname(hostname, 128);
00234   int rc = MPI_Allgather(hostname, 128, MPI_CHAR, hostname_recv_buf, 128, MPI_CHAR, MPI_COMM_WORLD);
00235   if (rc != MPI_SUCCESS){
00236     printf("ERROR: MPI_Allgather failed for hostname\n");
00237     comm_exit(1);
00238   }
00239 
00240   which_gpu=0;
00241   for(i=0;i < size; i++){
00242     if (i == rank){
00243       break;
00244     }
00245     if (strncmp(hostname, hostname_recv_buf + 128*i, 128) == 0){
00246       which_gpu ++;
00247     }
00248   }
00249   
00250   if (which_gpu >= gpus_per_node){
00251     printf("ERROR: invalid gpu(%d) to use in rank=%d mpi process\n", which_gpu, rank);
00252     comm_exit(1);
00253   }
00254   
00255   srand(rank*999);
00256   
00257   free(hostname_recv_buf);
00258   return;
00259 }
00260 
00261 int comm_gpuid()
00262 {
00263   //int gpu = rank%getGpuCount();
00264 
00265   return which_gpu;
00266 }
00267 int
00268 comm_rank(void)
00269 {
00270   return rank;
00271 }
00272 
00273 int
00274 comm_size(void)
00275 {
00276   return size;
00277 }
00278 
00279 int
00280 comm_dim(int dir) {
00281 
00282   int i;
00283   switch(dir) {
00284   case 0:
00285     i = xgridsize;
00286     break;
00287   case 1:
00288     i = ygridsize;
00289     break;
00290   case 2:
00291     i = zgridsize;
00292     break;
00293   case 3:
00294     i = tgridsize;
00295     break;
00296   default:
00297     printf("Cannot get direction %d", dir);
00298     comm_exit(1);
00299   }
00300 
00301   return i;
00302 }
00303 
00304 int
00305 comm_coords(int dir) {
00306 
00307   int i;
00308   switch(dir) {
00309   case 0:
00310     i = xgridid;
00311     break;
00312   case 1:
00313     i = ygridid;
00314     break;
00315   case 2:
00316     i = zgridid;
00317     break;
00318   case 3:
00319     i = tgridid;
00320     break;
00321   default:
00322     printf("Cannot get direction %d", dir);
00323     comm_exit(1);
00324   }
00325 
00326   return i;
00327 }
00328 
00329 unsigned long
00330 comm_send(void* buf, int len, int dst, void* _request)
00331 {
00332   
00333   MPI_Request* request = (MPI_Request*)_request;
00334   if (request == NULL){
00335     printf("ERROR: malloc failed for mpi request\n");
00336     comm_exit(1);
00337   }
00338 
00339   int dstproc;
00340   int sendtag=99;
00341   if (dst == BACK_NBR){
00342     dstproc = back_nbr;
00343     sendtag = BACK_NBR;
00344   }else if (dst == FWD_NBR){
00345     dstproc = fwd_nbr;
00346     sendtag = FWD_NBR;
00347   }else{
00348     printf("ERROR: invalid dest\n");
00349     comm_exit(1);
00350   }
00351 
00352   MPI_Isend(buf, len, MPI_BYTE, dstproc, sendtag, MPI_COMM_WORLD, request);  
00353   return (unsigned long)request;  
00354 }
00355 
00356 unsigned long
00357 comm_send_to_rank(void* buf, int len, int dst_rank, void* _request)
00358 {
00359   
00360   MPI_Request* request = (MPI_Request*)_request;
00361   if (request == NULL){
00362     printf("ERROR: malloc failed for mpi request\n");
00363     comm_exit(1);
00364   }
00365   
00366   if(dst_rank < 0 || dst_rank >= comm_size()){
00367     printf("ERROR: Invalid dst rank(%d)\n", dst_rank);
00368     comm_exit(1);
00369   }
00370   int sendtag = 99;
00371   MPI_Isend(buf, len, MPI_BYTE, dst_rank, sendtag, MPI_COMM_WORLD, request);  
00372   return (unsigned long)request;  
00373 }
00374 
00375 unsigned long
00376 comm_send_with_tag(void* buf, int len, int dst, int tag, void*_request)
00377 {
00378 
00379   MPI_Request* request = (MPI_Request*)_request;
00380   if (request == NULL){
00381     printf("ERROR: malloc failed for mpi request\n");
00382     comm_exit(1);
00383   }
00384 
00385   int dstproc = -1;
00386   switch(dst){
00387   case X_BACK_NBR:
00388     dstproc = x_back_nbr;
00389     break;
00390   case X_FWD_NBR:
00391     dstproc = x_fwd_nbr;
00392     break;
00393   case Y_BACK_NBR:
00394     dstproc = y_back_nbr;
00395     break;
00396   case Y_FWD_NBR:
00397     dstproc = y_fwd_nbr;
00398     break;
00399   case Z_BACK_NBR:
00400     dstproc = z_back_nbr;
00401     break;
00402   case Z_FWD_NBR:
00403     dstproc = z_fwd_nbr;
00404     break;
00405   case T_BACK_NBR:
00406     dstproc = t_back_nbr;
00407     break;
00408   case T_FWD_NBR:
00409     dstproc = t_fwd_nbr;
00410     break;
00411   default:
00412     printf("ERROR: invalid dest, line %d, file %s\n", __LINE__, __FILE__);
00413     comm_exit(1);
00414   }
00415 
00416   MPI_Isend(buf, len, MPI_BYTE, dstproc, tag, MPI_COMM_WORLD, request);
00417   return (unsigned long)request;
00418 }
00419 
00420 
00421 
00422 unsigned long
00423 comm_recv(void* buf, int len, int src, void*_request)
00424 {
00425   MPI_Request* request = (MPI_Request*)_request;
00426   if (request == NULL){
00427     printf("ERROR: malloc failed for mpi request\n");
00428     comm_exit(1);
00429   }
00430   
00431   int srcproc=-1;
00432   int recvtag=99; //recvtag is opposite to the sendtag
00433   if (src == BACK_NBR){
00434     srcproc = back_nbr;
00435     recvtag = FWD_NBR;
00436   }else if (src == FWD_NBR){
00437     srcproc = fwd_nbr;
00438     recvtag = BACK_NBR;
00439   }else{
00440     printf("ERROR: invalid source\n");
00441     comm_exit(1);
00442   }
00443   
00444   MPI_Irecv(buf, len, MPI_BYTE, srcproc, recvtag, MPI_COMM_WORLD, request);
00445   
00446   return (unsigned long)request;
00447 }
00448 
00449 unsigned long
00450 comm_recv_from_rank(void* buf, int len, int src_rank, void* _request)
00451 {
00452   MPI_Request* request = (MPI_Request*)_request;
00453   if (request == NULL){
00454     printf("ERROR: malloc failed for mpi request\n");
00455     comm_exit(1);
00456   }
00457   
00458   if(src_rank < 0 || src_rank >= comm_size()){
00459     printf("ERROR: Invalid src rank(%d)\n", src_rank);
00460     comm_exit(1);
00461   }
00462   
00463   int recvtag = 99;
00464   MPI_Irecv(buf, len, MPI_BYTE, src_rank, recvtag, MPI_COMM_WORLD, request);
00465   
00466   return (unsigned long)request;
00467 }
00468 
00469 unsigned long
00470 comm_recv_with_tag(void* buf, int len, int src, int tag, void* _request)
00471 { 
00472   MPI_Request* request = (MPI_Request*)_request;
00473   if (request == NULL){
00474     printf("ERROR: malloc failed for mpi request\n");
00475     comm_exit(1);
00476   }
00477   
00478   int srcproc=-1;
00479   switch (src){
00480   case X_BACK_NBR:
00481     srcproc = x_back_nbr;
00482     break;
00483   case X_FWD_NBR:
00484     srcproc = x_fwd_nbr;
00485     break;
00486   case Y_BACK_NBR:
00487     srcproc = y_back_nbr;
00488     break;
00489   case Y_FWD_NBR:
00490     srcproc = y_fwd_nbr;
00491     break;
00492   case Z_BACK_NBR:
00493     srcproc = z_back_nbr;
00494     break;
00495   case Z_FWD_NBR:
00496     srcproc = z_fwd_nbr;
00497     break;
00498   case T_BACK_NBR:
00499     srcproc = t_back_nbr;
00500     break;
00501   case T_FWD_NBR:
00502     srcproc = t_fwd_nbr;
00503     break;
00504   default:
00505     printf("ERROR: invalid source, line %d, file %s\n", __LINE__, __FILE__);
00506     comm_exit(1);
00507   }
00508   MPI_Irecv(buf, len, MPI_BYTE, srcproc, tag, MPI_COMM_WORLD, request);
00509   
00510   return (unsigned long)request;
00511 }
00512 
00513 int comm_query(void* request) 
00514 {
00515   MPI_Status status;
00516   int query;
00517   int rc = MPI_Test( (MPI_Request*)request, &query, &status);
00518   if (rc != MPI_SUCCESS) {
00519     printf("ERROR: MPI_Test failed\n");
00520     comm_exit(1);
00521   }
00522 
00523   return query;
00524 }
00525 
/* Release request storage with free(); passing NULL is harmless. */
void comm_free(void* request) {
  void* storage = request;
  free(storage);
}
00530 
00531 //this request should be some return value from comm_recv
00532 void 
00533 comm_wait(void* request)
00534 {
00535   
00536   MPI_Status status;
00537   int rc = MPI_Wait( (MPI_Request*)request, &status);
00538   if (rc != MPI_SUCCESS){
00539     printf("ERROR: MPI_Wait failed\n");
00540     comm_exit(1);
00541   }
00542   
00543   return;
00544 }
00545 
00546 //we always reduce one double value
00547 void
00548 comm_allreduce(double* data)
00549 {
00550   double recvbuf;
00551   int rc = MPI_Allreduce ( data, &recvbuf,1,MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
00552   if (rc != MPI_SUCCESS){
00553     printf("ERROR: MPI_Allreduce failed\n");
00554     comm_exit(1);
00555   }
00556   
00557   *data = recvbuf;
00558   
00559   return;
00560 } 
00561 
00562 //reduce n double value
00563 void
00564 comm_allreduce_array(double* data, size_t size)
00565 {
00566   double recvbuf[size];
00567   int rc = MPI_Allreduce ( data, &recvbuf,size,MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
00568   if (rc != MPI_SUCCESS){
00569     printf("ERROR: MPI_Allreduce failed\n");
00570     comm_exit(1);
00571   }
00572   
00573   memcpy(data, recvbuf, sizeof(recvbuf));
00574   
00575   return;
00576 }
00577 
00578 //we always reduce one double value
00579 void
00580 comm_allreduce_max(double* data)
00581 {
00582   double recvbuf;
00583   int rc = MPI_Allreduce ( data, &recvbuf,1,MPI_DOUBLE, MPI_MAX, MPI_COMM_WORLD);
00584   if (rc != MPI_SUCCESS){
00585     printf("ERROR: MPI_Allreduce failed\n");
00586     comm_exit(1);
00587   }
00588   
00589   *data = recvbuf;
00590   
00591   return;
00592 } 
00593 
00594 // broadcast from rank 0
00595 void
00596 comm_broadcast(void *data, size_t nbytes)
00597 {
00598   MPI_Bcast(data, (int)nbytes, MPI_BYTE, 0, MPI_COMM_WORLD);
00599 }
00600 
00601 void
00602 comm_barrier(void)
00603 {
00604   MPI_Barrier(MPI_COMM_WORLD);  
00605 }
00606 void 
00607 comm_cleanup()
00608 {
00609   MPI_Finalize();
00610 }
00611 
00612 void
00613 comm_exit(int ret)
00614 {
00615   MPI_Finalize();
00616   exit(ret);
00617 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines