#include <quda_internal.h>
#include <quda_matrix.h>
#include <tune_quda.h>
#include <gauge_field.h>
#include <gauge_field_order.h>
#include <comm_quda.h>
#include <pgauge_monte.h>
#include <instantiate.h>

namespace quda {
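
  /**
   * @brief Kernel argument container: the gauge-field accessor, the staging
   * array for the packed link elements, and the face/border/parity metadata
   * that selects which border slice is (un)packed.
   */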
  template <typename Float, QudaReconstructType recon>
  struct GaugeFixUnPackArg {
    int X[4]; // grid dimensions
    using Gauge = typename gauge_mapper<Float, recon>::type;
    Gauge dataOr;
    complex<Float> *array; // staging buffer for the packed link elements
    int parity, face, dir, borderid, size;
    GaugeFixUnPackArg(GaugeField &data) : dataOr(data)
    {
      for ( int dir = 0; dir < 4; ++dir ) X[dir] = data.X()[dir];
    }
  };
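
  /**
   * @brief Pack (pack == true) or unpack (pack == false) the gauge links on
   * border slice arg.borderid of face arg.face. Each thread rebuilds the full
   * site coordinate x[] from its checkerboarded index on the face, then moves
   * NElems/2 complex numbers per link through arg.array.
   */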
  template <int NElems, typename Float, bool pack, typename Arg>
  __global__ void Kernel_UnPack(Arg arg)
  {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if ( idx >= arg.size ) return;
    int X[4];
    for ( int dr = 0; dr < 4; ++dr ) X[dr] = arg.X[dr];
    int x[4];
    int za, xodd;
    switch ( arg.face ) {
    case 0: // X face
      za = idx / ( X[1] / 2 );
      x[3] = za / X[2];
      x[2] = za - x[3] * X[2];
      x[0] = arg.borderid;
      xodd = (arg.borderid + x[2] + x[3] + arg.parity) & 1;
      x[1] = (2 * idx + xodd) - za * X[1];
      break;
    case 1: // Y face
      za = idx / ( X[0] / 2 );
      x[3] = za / X[2];
      x[2] = za - x[3] * X[2];
      x[1] = arg.borderid;
      xodd = (arg.borderid + x[2] + x[3] + arg.parity) & 1;
      x[0] = (2 * idx + xodd) - za * X[0];
      break;
    case 2: // Z face
      za = idx / ( X[0] / 2 );
      x[3] = za / X[1];
      x[1] = za - x[3] * X[1];
      x[2] = arg.borderid;
      xodd = (arg.borderid + x[1] + x[3] + arg.parity) & 1;
      x[0] = (2 * idx + xodd) - za * X[0];
      break;
    case 3: // T face
      za = idx / ( X[0] / 2 );
      x[2] = za / X[1];
      x[1] = za - x[2] * X[1];
      x[3] = arg.borderid;
      xodd = (arg.borderid + x[1] + x[2] + arg.parity) & 1;
      x[0] = (2 * idx + xodd) - za * X[0];
      break;
    }
    // checkerboard index of the site within its parity
    int id = (((x[3] * X[2] + x[2]) * X[1] + x[1]) * X[0] + x[0]) >> 1;
    typedef complex<Float> Complex;
    typedef typename mapper<Float>::type RegType;
    RegType tmp[NElems];
    Complex data[9]; // full 3x3 link matrix in registers
    if ( pack ) {
      arg.dataOr.load(data, id, arg.dir, arg.parity);
      arg.dataOr.reconstruct.Pack(tmp, data, id);
      for ( int i = 0; i < NElems / 2; ++i ) arg.array[idx + arg.size * i] = Complex(tmp[2*i+0], tmp[2*i+1]);
    } else {
      for ( int i = 0; i < NElems / 2; ++i ) {
        tmp[2*i+0] = arg.array[idx + arg.size * i].real();
        tmp[2*i+1] = arg.array[idx + arg.size * i].imag();
      }
      arg.dataOr.reconstruct.Unpack(data, tmp, id, arg.dir, 0, arg.dataOr.X, arg.dataOr.R);
      arg.dataOr.save(data, id, arg.dir, arg.parity);
    }
  }
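
  // Persistent exchange state, allocated on first use: per-dimension device
  // staging buffers, a pinned host buffer, two streams so the forward and
  // backward faces can be packed and transferred concurrently, and the
  // message handles for both directions.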
  static void *send_d[4];
  static void *recv_d[4];
  static void *sendg_d[4];
  static void *recvg_d[4];
  static void *hostbuffer_h[4];
  static qudaStream_t GFStream[2];
  static MsgHandle *mh_recv_back[4];
  static MsgHandle *mh_recv_fwd[4];
  static MsgHandle *mh_send_fwd[4];
  static MsgHandle *mh_send_back[4];
  static int *X;
  static bool init = false;

  /**
   * @brief Release all allocated memory used to exchange data between nodes
   */
  void PGaugeExchangeFree()
  {
    if ( comm_dim_partitioned(0) || comm_dim_partitioned(1) || comm_dim_partitioned(2) || comm_dim_partitioned(3) ) {
      if (init) {
        cudaStreamDestroy(GFStream[0]);
        cudaStreamDestroy(GFStream[1]);
        for (int d = 0; d < 4; d++) {
          if (commDimPartitioned(d)) {
            comm_free(mh_send_fwd[d]);
            comm_free(mh_send_back[d]);
            comm_free(mh_recv_back[d]);
            comm_free(mh_recv_fwd[d]);
            device_free(send_d[d]);
            device_free(recv_d[d]);
            device_free(sendg_d[d]);
            device_free(recvg_d[d]);
            host_free(hostbuffer_h[d]);
          }
        }
        host_free(X);
        init = false;
      }
    }
  }
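
  /**
   * @brief Functor that exchanges the border links of a given direction and
   * parity with the neighboring nodes in every partitioned dimension,
   * overlapping packing kernels, host-device copies, and communication on
   * two streams.
   */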
  template <typename Float, int nColor, QudaReconstructType recon> struct PGaugeExchanger {
    PGaugeExchanger(GaugeField &data, const int dir, const int parity)
    {
      if (init) {
        for (int d = 0; d < 4; d++) {
          if (X[d] != data.X()[d]) {
            PGaugeExchangeFree();
            printfQuda("PGaugeExchange needs to be reinitialized...\n");
            break;
          }
        }
      }
      size_t bytes[4];
      void *send[4], *recv[4], *sendg[4], *recvg[4];
      for (int d = 0; d < 4; d++) {
        if (!commDimPartitioned(d)) continue;
        bytes[d] = sizeof(Float) * data.SurfaceCB(d) * recon;
      }
      if (!init) {
        X = (int*)safe_malloc(4 * sizeof(int));
        for (int d = 0; d < 4; d++) X[d] = data.X()[d];

        cudaStreamCreate(&GFStream[0]);
        cudaStreamCreate(&GFStream[1]);
        for (int d = 0; d < 4; d++) {
          if (!commDimPartitioned(d)) continue;
          // store both parities and directions in each
          send_d[d] = device_malloc(bytes[d]);
          recv_d[d] = device_malloc(bytes[d]);
          sendg_d[d] = device_malloc(bytes[d]);
          recvg_d[d] = device_malloc(bytes[d]);
          // one pinned host buffer per dimension, laid out as recv | send | sendg | recvg
          hostbuffer_h[d] = (void*)pinned_malloc(4 * bytes[d]);
          recv[d] = hostbuffer_h[d];
          send[d] = static_cast<char*>(hostbuffer_h[d]) + bytes[d];
          recvg[d] = static_cast<char*>(hostbuffer_h[d]) + 3 * bytes[d];
          sendg[d] = static_cast<char*>(hostbuffer_h[d]) + 2 * bytes[d];

          mh_recv_back[d] = comm_declare_receive_relative(recv[d], d, -1, bytes[d]);
          mh_recv_fwd[d] = comm_declare_receive_relative(recvg[d], d, +1, bytes[d]);
          mh_send_back[d] = comm_declare_send_relative(sendg[d], d, -1, bytes[d]);
          mh_send_fwd[d] = comm_declare_send_relative(send[d], d, +1, bytes[d]);
        }
        init = true;
      } else {
        // already initialized: recompute the host-buffer aliases only
        for (int d = 0; d < 4; d++) {
          if (!commDimPartitioned(d)) continue;
          recv[d] = hostbuffer_h[d];
          send[d] = static_cast<char*>(hostbuffer_h[d]) + bytes[d];
          recvg[d] = static_cast<char*>(hostbuffer_h[d]) + 3 * bytes[d];
          sendg[d] = static_cast<char*>(hostbuffer_h[d]) + 2 * bytes[d];
        }
      }
      GaugeFixUnPackArg<Float, recon> arg(data);

      qudaDeviceSynchronize();
      for (int d = 0; d < 4; d++) {
        if (!commDimPartitioned(d)) continue;
        comm_start(mh_recv_back[d]);
        comm_start(mh_recv_fwd[d]);

        TuneParam tp;
        tp.block = make_uint3(128, 1, 1);
        tp.grid = make_uint3((data.SurfaceCB(d) + tp.block.x - 1) / tp.block.x, 1, 1);

        arg.size = data.SurfaceCB(d);
        arg.parity = parity;
        arg.face = d;
        arg.dir = dir;

        // pack the top face (last interior slice, sent forward)
        arg.array = reinterpret_cast<complex<Float>*>(send_d[d]);
        arg.borderid = X[d] - data.R()[d] - 1;
        qudaLaunchKernel(Kernel_UnPack<recon, Float, true, decltype(arg)>, tp, GFStream[0], arg);

        // pack the bottom face (first interior slice, sent backward)
        arg.array = reinterpret_cast<complex<Float>*>(sendg_d[d]);
        arg.borderid = data.R()[d];
        qudaLaunchKernel(Kernel_UnPack<recon, Float, true, decltype(arg)>, tp, GFStream[1], arg);

        qudaMemcpyAsync(send[d], send_d[d], bytes[d], cudaMemcpyDeviceToHost, GFStream[0]);
        qudaMemcpyAsync(sendg[d], sendg_d[d], bytes[d], cudaMemcpyDeviceToHost, GFStream[1]);

        qudaStreamSynchronize(GFStream[0]);
        comm_start(mh_send_fwd[d]);

        qudaStreamSynchronize(GFStream[1]);
        comm_start(mh_send_back[d]);

        comm_wait(mh_recv_back[d]);
        qudaMemcpyAsync(recv_d[d], recv[d], bytes[d], cudaMemcpyHostToDevice, GFStream[0]);

        // unpack the halo received from the backward neighbor
        arg.array = reinterpret_cast<complex<Float>*>(recv_d[d]);
        arg.borderid = data.R()[d] - 1;
        qudaLaunchKernel(Kernel_UnPack<recon, Float, false, decltype(arg)>, tp, GFStream[0], arg);

        comm_wait(mh_recv_fwd[d]);
        qudaMemcpyAsync(recvg_d[d], recvg[d], bytes[d], cudaMemcpyHostToDevice, GFStream[1]);

        // unpack the halo received from the forward neighbor
        arg.array = reinterpret_cast<complex<Float>*>(recvg_d[d]);
        arg.borderid = X[d] - data.R()[d];
        qudaLaunchKernel(Kernel_UnPack<recon, Float, false, decltype(arg)>, tp, GFStream[1], arg);

        comm_wait(mh_send_back[d]);
        comm_wait(mh_send_fwd[d]);
        qudaStreamSynchronize(GFStream[0]);
        qudaStreamSynchronize(GFStream[1]);
      }
      qudaDeviceSynchronize();
    }
  };
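
  /**
   * @brief Exchange the border gauge links of direction @p dir and parity
   * @p parity between neighboring nodes; a no-op unless at least one
   * dimension is partitioned. An illustrative call site (hypothetical loop
   * structure, not taken from this file) would be a heatbath sweep:
   *
   *   for (int p = 0; p < 2; p++)        // parities
   *     for (int mu = 0; mu < 4; mu++) { // directions
   *       // ... update the (mu, p) links ...
   *       PGaugeExchange(data, mu, p);
   *     }
   */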
  void PGaugeExchange(GaugeField &data, const int dir, const int parity)
  {
#ifdef GPU_GAUGE_ALG
    if ( comm_dim_partitioned(0) || comm_dim_partitioned(1) || comm_dim_partitioned(2) || comm_dim_partitioned(3) ) {
      instantiate<PGaugeExchanger>(data, dir, parity);
    }
#else
    errorQuda("Pure gauge code has not been built");
#endif
  }

} // namespace quda