QUDA v1.1.0 — a library for lattice QCD on GPUs.
File: pgauge_exchange.cu
(See the generated QUDA documentation for the full cross-referenced listing of this file.)
1 #include <quda_internal.h>
2 #include <quda_matrix.h>
3 #include <tune_quda.h>
4 #include <gauge_field.h>
5 #include <gauge_field_order.h>
6 #include <comm_quda.h>
7 #include <pgauge_monte.h>
8 #include <instantiate.h>
9 
10 namespace quda {
11 
12  template <typename Float, QudaReconstructType recon>
13  struct GaugeFixUnPackArg {
14  int X[4]; // grid dimensions
15  using Gauge = typename gauge_mapper<Float, recon>::type;
16  Gauge dataOr;
17  int size;
18  complex<Float> *array;
19  int parity;
20  int face;
21  int dir;
22  int borderid;
23  GaugeFixUnPackArg(GaugeField & data)
24  : dataOr(data)
25  {
26  for ( int dir = 0; dir < 4; ++dir ) X[dir] = data.X()[dir];
27  }
28  };
29 
  /**
   * @brief Pack (pack == true) a boundary slice of the gauge field into
   * arg.array, or unpack (pack == false) arg.array back into the field.
   *
   * Launch layout: 1-d grid, one thread per checkerboarded face site
   * (arg.size threads in total); excess threads exit immediately.
   *
   * @tparam NElems number of real elements per link after reconstruction
   *         (e.g. 12 or 8); the buffer stores NElems/2 complex values per site
   * @tparam pack   true = field -> buffer, false = buffer -> field
   */
  template <int NElems, typename Float, bool pack, typename Arg>
  __global__ void Kernel_UnPack(Arg arg)
  {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if ( idx >= arg.size ) return;
  int X[4];
  for ( int dr = 0; dr < 4; ++dr ) X[dr] = arg.X[dr];
  int x[4];
  int za, xodd;
  // Reconstruct the 4-d coordinate x[] of this face site from the 1-d
  // index idx. Each face is checkerboarded, so the fastest-running
  // coordinate covers only half its extent (hence the /2) and xodd
  // shifts odd rows so that only sites of the requested parity are hit.
  switch ( arg.face ) {
  case 0: //X FACE
  za = idx / ( X[1] / 2);
  x[3] = za / X[2];
  x[2] = za - x[3] * X[2];
  x[0] = arg.borderid;
  xodd = (arg.borderid + x[2] + x[3] + arg.parity) & 1;
  x[1] = (2 * idx + xodd) - za * X[1];
  break;
  case 1: //Y FACE
  za = idx / ( X[0] / 2);
  x[3] = za / X[2];
  x[2] = za - x[3] * X[2];
  x[1] = arg.borderid;
  xodd = (arg.borderid + x[2] + x[3] + arg.parity) & 1;
  x[0] = (2 * idx + xodd) - za * X[0];
  break;
  case 2: //Z FACE
  za = idx / ( X[0] / 2);
  x[3] = za / X[1];
  x[1] = za - x[3] * X[1];
  x[2] = arg.borderid;
  xodd = (arg.borderid + x[1] + x[3] + arg.parity) & 1;
  x[0] = (2 * idx + xodd) - za * X[0];
  break;
  case 3: //T FACE
  za = idx / ( X[0] / 2);
  x[2] = za / X[1];
  x[1] = za - x[2] * X[1];
  x[3] = arg.borderid;
  xodd = (arg.borderid + x[1] + x[2] + arg.parity) & 1;
  x[0] = (2 * idx + xodd) - za * X[0];
  break;
  }

  // checkerboarded (even/odd) linear site index within the local volume
  int id = (((x[3] * X[2] + x[2]) * X[1] + x[1]) * X[0] + x[0]) >> 1;
  typedef complex<Float> Complex;
  typedef typename mapper<Float>::type RegType;
  RegType tmp[NElems];
  Complex data[9]; // full 3x3 link matrix

  if (pack) {
  // load the link, compress it to NElems reals, and write NElems/2
  // complex values to the buffer in a strided (SoA) layout keyed by idx
  arg.dataOr.load(data, id, arg.dir, arg.parity);
  arg.dataOr.reconstruct.Pack(tmp, data, id);
  for ( int i = 0; i < NElems / 2; ++i ) arg.array[idx + arg.size * i] = Complex(tmp[2*i+0], tmp[2*i+1]);
  } else {
  // inverse path: gather the strided buffer, reconstruct the full
  // link, and store it back into the field
  for ( int i = 0; i < NElems / 2; ++i ) {
  tmp[2*i+0] = arg.array[idx + arg.size * i].real();
  tmp[2*i+1] = arg.array[idx + arg.size * i].imag();
  }
  arg.dataOr.reconstruct.Unpack(data, tmp, id, arg.dir, 0, arg.dataOr.X, arg.dataOr.R);
  arg.dataOr.save(data, id, arg.dir, arg.parity);
  }
  }
93 
  // Persistent per-dimension communication state, allocated lazily by
  // PGaugeExchanger and released by PGaugeExchangeFree().
  static void *send_d[4];           // device buffer: top face packed for the forward neighbor
  static void *recv_d[4];           // device buffer: face received from the backward neighbor
  static void *sendg_d[4];          // device buffer: bottom face packed for the backward neighbor
  static void *recvg_d[4];          // device buffer: face received from the forward neighbor
  static void *hostbuffer_h[4];     // pinned host staging area holding all four faces (4 * bytes[d])
  static qudaStream_t GFStream[2];  // streams used to overlap the top/bottom face transfers
  static MsgHandle *mh_recv_back[4];
  static MsgHandle *mh_recv_fwd[4];
  static MsgHandle *mh_send_fwd[4];
  static MsgHandle *mh_send_back[4];
  static int *X;                    // lattice dimensions the buffers were sized for
  static bool init = false;         // true once buffers and message handles exist
106 
107  /**
108  * @brief Release all allocated memory used to exchange data between nodes
109  */
110  void PGaugeExchangeFree()
111  {
112  if ( comm_dim_partitioned(0) || comm_dim_partitioned(1) || comm_dim_partitioned(2) || comm_dim_partitioned(3) ) {
113  if (init) {
114  cudaStreamDestroy(GFStream[0]);
115  cudaStreamDestroy(GFStream[1]);
116  for (int d = 0; d < 4; d++ ) {
117  if (commDimPartitioned(d)) {
118  comm_free(mh_send_fwd[d]);
119  comm_free(mh_send_back[d]);
120  comm_free(mh_recv_back[d]);
121  comm_free(mh_recv_fwd[d]);
122  device_free(send_d[d]);
123  device_free(recv_d[d]);
124  device_free(sendg_d[d]);
125  device_free(recvg_d[d]);
126  host_free(hostbuffer_h[d]);
127  }
128  }
129  host_free(X);
130  init = false;
131  }
132  }
133  }
134 
  /**
   * @brief Exchanges the border links of direction "dir" and parity
   * "parity" with the neighboring nodes, lazily (re)allocating the
   * static communication buffers declared above on first use or when
   * the lattice dimensions change.
   */
  template<typename Float, int nColor, QudaReconstructType recon> struct PGaugeExchanger {
  PGaugeExchanger(GaugeField& data, const int dir, const int parity)
  {
  // if the lattice dimensions changed since the last call, release the
  // stale buffers so they are rebuilt below
  if (init) {
  for (int d = 0; d < 4; d++) {
  if (X[d] != data.X()[d]) {
  PGaugeExchangeFree();
  printfQuda("PGaugeExchange needs to be reinitialized...\n");
  break;
  }
  }
  }

  size_t bytes[4];
  void *send[4];
  void *recv[4];
  void *sendg[4];
  void *recvg[4];
  // bytes[d] = size of one face worth of packed links in dimension d;
  // only defined (and only used) for partitioned dimensions
  for (int d = 0; d < 4; d++) {
  if (!commDimPartitioned(d)) continue;
  bytes[d] = sizeof(Float) * data.SurfaceCB(d) * recon;
  }

  if (!init) {
  X = (int*)safe_malloc(4 * sizeof(int));
  for (int d = 0; d < 4; d++) X[d] = data.X()[d];

  cudaStreamCreate(&GFStream[0]);
  cudaStreamCreate(&GFStream[1]);
  for (int d = 0; d < 4; d++ ) {
  if (!commDimPartitioned(d)) continue;
  // store both parities and directions in each
  send_d[d] = device_malloc(bytes[d]);
  recv_d[d] = device_malloc(bytes[d]);
  sendg_d[d] = device_malloc(bytes[d]);
  recvg_d[d] = device_malloc(bytes[d]);
  // single pinned host block carved into recv | send | sendg | recvg
  hostbuffer_h[d] = (void*)pinned_malloc(4 * bytes[d]);
  recv[d] = hostbuffer_h[d];
  send[d] = static_cast<char*>(hostbuffer_h[d]) + bytes[d];
  recvg[d] = static_cast<char*>(hostbuffer_h[d]) + 3 * bytes[d];
  sendg[d] = static_cast<char*>(hostbuffer_h[d]) + 2 * bytes[d];

  // send the top face forward, the bottom face backward
  mh_recv_back[d] = comm_declare_receive_relative(recv[d], d, -1, bytes[d]);
  mh_recv_fwd[d] = comm_declare_receive_relative(recvg[d], d, +1, bytes[d]);
  mh_send_back[d] = comm_declare_send_relative(sendg[d], d, -1, bytes[d]);
  mh_send_fwd[d] = comm_declare_send_relative(send[d], d, +1, bytes[d]);
  }
  init = true;
  } else {
  // buffers already exist: recompute the host sub-buffer pointers only
  for (int d = 0; d < 4; d++ ) {
  if (!commDimPartitioned(d)) continue;
  recv[d] = hostbuffer_h[d];
  send[d] = static_cast<char*>(hostbuffer_h[d]) + bytes[d];
  recvg[d] = static_cast<char*>(hostbuffer_h[d]) + 3 * bytes[d];
  sendg[d] = static_cast<char*>(hostbuffer_h[d]) + 2 * bytes[d];
  }
  }

  GaugeFixUnPackArg<Float, recon> arg(data);

  qudaDeviceSynchronize();
  // For each partitioned dimension: pack both boundary faces on the
  // device, stage them through pinned memory, exchange with the
  // neighbors, and unpack the received faces into the halo region.
  // GFStream[0] handles the forward-going (top) face, GFStream[1] the
  // backward-going (bottom) face, so the two transfers overlap.
  for (int d = 0; d < 4; d++) {
  if ( !commDimPartitioned(d)) continue;
  // post the receives before any sends are started
  comm_start(mh_recv_back[d]);
  comm_start(mh_recv_fwd[d]);

  // fixed launch configuration: one thread per face site
  TuneParam tp;
  tp.block = make_uint3(128, 1, 1);
  tp.grid = make_uint3((data.SurfaceCB(d) + tp.block.x - 1) / tp.block.x, 1, 1);

  arg.size = data.SurfaceCB(d);
  arg.parity = parity;
  arg.face = d;
  arg.dir = dir;

  //extract top face
  arg.array = reinterpret_cast<complex<Float>*>(send_d[d]);
  arg.borderid = X[d] - data.R()[d] - 1;
  qudaLaunchKernel(Kernel_UnPack<recon, Float, true, decltype(arg)>, tp, GFStream[0], arg);

  //extract bottom
  arg.array = reinterpret_cast<complex<Float>*>(sendg_d[d]);
  arg.borderid = data.R()[d];
  qudaLaunchKernel(Kernel_UnPack<recon, Float, true, decltype(arg)>, tp, GFStream[1], arg);

  // stage packed faces to pinned host memory
  qudaMemcpyAsync(send[d], send_d[d], bytes[d], cudaMemcpyDeviceToHost, GFStream[0]);
  qudaMemcpyAsync(sendg[d], sendg_d[d], bytes[d], cudaMemcpyDeviceToHost, GFStream[1]);

  // each send may only start once its device->host copy has completed
  qudaStreamSynchronize(GFStream[0]);
  comm_start(mh_send_fwd[d]);

  qudaStreamSynchronize(GFStream[1]);
  comm_start(mh_send_back[d]);

  // as each face arrives, upload it and unpack into the halo
  comm_wait(mh_recv_back[d]);
  qudaMemcpyAsync(recv_d[d], recv[d], bytes[d], cudaMemcpyHostToDevice, GFStream[0]);

  arg.array = reinterpret_cast<complex<Float>*>(recv_d[d]);
  arg.borderid = data.R()[d] - 1;
  qudaLaunchKernel(Kernel_UnPack<recon, Float, false, decltype(arg)>, tp, GFStream[0], arg);

  comm_wait(mh_recv_fwd[d]);
  qudaMemcpyAsync(recvg_d[d], recvg[d], bytes[d], cudaMemcpyHostToDevice, GFStream[1]);

  arg.array = reinterpret_cast<complex<Float>*>(recvg_d[d]);
  arg.borderid = X[d] - data.R()[d];
  qudaLaunchKernel(Kernel_UnPack<recon, Float, false, decltype(arg)>, tp, GFStream[1], arg);

  // drain sends and both streams before moving to the next dimension
  comm_wait(mh_send_back[d]);
  comm_wait(mh_send_fwd[d]);
  qudaStreamSynchronize(GFStream[0]);
  qudaStreamSynchronize(GFStream[1]);
  }
  qudaDeviceSynchronize();
  }
  };
251 
252  void PGaugeExchange(GaugeField& data, const int dir, const int parity)
253  {
254 #ifdef GPU_GAUGE_ALG
255  if ( comm_dim_partitioned(0) || comm_dim_partitioned(1) || comm_dim_partitioned(2) || comm_dim_partitioned(3) ) {
256  instantiate<PGaugeExchanger>(data, dir, parity);
257  }
258 #else
259  errorQuda("Pure gauge code has not been built");
260 #endif
261  }
262 }