#include <quda_internal.h>
#include <quda_matrix.h>
#include <gauge_field.h>
#include <gauge_field_order.h>
#include <launch_kernel.cuh>
#include <unitarization_links.h>
#include <gauge_fix_ovr_extra.h>
#include <gauge_fix_ovr_hit_devf.cuh>
#include <reduce_helper.h>
#include <index_helper.cuh>
#include <instantiate.h>
#define LAUNCH_KERNEL_GAUGEFIX(kernel, tp, stream, arg, parity, ...)                          \
  if (tp.aux.x == 0) {                                                                        \
    switch (tp.block.x) {                                                                     \
    case 256: qudaLaunchKernel(kernel<0, 32, __VA_ARGS__>, tp, stream, arg, parity); break;   \
    case 512: qudaLaunchKernel(kernel<0, 64, __VA_ARGS__>, tp, stream, arg, parity); break;   \
    case 768: qudaLaunchKernel(kernel<0, 96, __VA_ARGS__>, tp, stream, arg, parity); break;   \
    case 1024: qudaLaunchKernel(kernel<0, 128, __VA_ARGS__>, tp, stream, arg, parity); break; \
    default: errorQuda("%s not implemented for %d threads", #kernel, tp.block.x);             \
    }                                                                                         \
  } else if (tp.aux.x == 1) {                                                                 \
    switch (tp.block.x) {                                                                     \
    case 256: qudaLaunchKernel(kernel<1, 32, __VA_ARGS__>, tp, stream, arg, parity); break;   \
    case 512: qudaLaunchKernel(kernel<1, 64, __VA_ARGS__>, tp, stream, arg, parity); break;   \
    case 768: qudaLaunchKernel(kernel<1, 96, __VA_ARGS__>, tp, stream, arg, parity); break;   \
    case 1024: qudaLaunchKernel(kernel<1, 128, __VA_ARGS__>, tp, stream, arg, parity); break; \
    default: errorQuda("%s not implemented for %d threads", #kernel, tp.block.x);             \
    }                                                                                         \
  } else if (tp.aux.x == 2) {                                                                 \
    switch (tp.block.x) {                                                                     \
    case 256: qudaLaunchKernel(kernel<2, 32, __VA_ARGS__>, tp, stream, arg, parity); break;   \
    case 512: qudaLaunchKernel(kernel<2, 64, __VA_ARGS__>, tp, stream, arg, parity); break;   \
    case 768: qudaLaunchKernel(kernel<2, 96, __VA_ARGS__>, tp, stream, arg, parity); break;   \
    case 1024: qudaLaunchKernel(kernel<2, 128, __VA_ARGS__>, tp, stream, arg, parity); break; \
    default: errorQuda("%s not implemented for %d threads", #kernel, tp.block.x);             \
    }                                                                                         \
  } else if (tp.aux.x == 3) {                                                                 \
    switch (tp.block.x) {                                                                     \
    case 128: qudaLaunchKernel(kernel<3, 32, __VA_ARGS__>, tp, stream, arg, parity); break;   \
    case 256: qudaLaunchKernel(kernel<3, 64, __VA_ARGS__>, tp, stream, arg, parity); break;   \
    case 384: qudaLaunchKernel(kernel<3, 96, __VA_ARGS__>, tp, stream, arg, parity); break;   \
    case 512: qudaLaunchKernel(kernel<3, 128, __VA_ARGS__>, tp, stream, arg, parity); break;  \
    case 640: qudaLaunchKernel(kernel<3, 160, __VA_ARGS__>, tp, stream, arg, parity); break;  \
    case 768: qudaLaunchKernel(kernel<3, 192, __VA_ARGS__>, tp, stream, arg, parity); break;  \
    case 896: qudaLaunchKernel(kernel<3, 224, __VA_ARGS__>, tp, stream, arg, parity); break;  \
    case 1024: qudaLaunchKernel(kernel<3, 256, __VA_ARGS__>, tp, stream, arg, parity); break; \
    default: errorQuda("%s not implemented for %d threads", #kernel, tp.block.x);             \
    }                                                                                         \
  } else if (tp.aux.x == 4) {                                                                 \
    switch (tp.block.x) {                                                                     \
    case 128: qudaLaunchKernel(kernel<4, 32, __VA_ARGS__>, tp, stream, arg, parity); break;   \
    case 256: qudaLaunchKernel(kernel<4, 64, __VA_ARGS__>, tp, stream, arg, parity); break;   \
    case 384: qudaLaunchKernel(kernel<4, 96, __VA_ARGS__>, tp, stream, arg, parity); break;   \
    case 512: qudaLaunchKernel(kernel<4, 128, __VA_ARGS__>, tp, stream, arg, parity); break;  \
    case 640: qudaLaunchKernel(kernel<4, 160, __VA_ARGS__>, tp, stream, arg, parity); break;  \
    case 768: qudaLaunchKernel(kernel<4, 192, __VA_ARGS__>, tp, stream, arg, parity); break;  \
    case 896: qudaLaunchKernel(kernel<4, 224, __VA_ARGS__>, tp, stream, arg, parity); break;  \
    case 1024: qudaLaunchKernel(kernel<4, 256, __VA_ARGS__>, tp, stream, arg, parity); break; \
    default: errorQuda("%s not implemented for %d threads", #kernel, tp.block.x);             \
    }                                                                                         \
  } else if (tp.aux.x == 5) {                                                                 \
    switch (tp.block.x) {                                                                     \
    case 128: qudaLaunchKernel(kernel<5, 32, __VA_ARGS__>, tp, stream, arg, parity); break;   \
    case 256: qudaLaunchKernel(kernel<5, 64, __VA_ARGS__>, tp, stream, arg, parity); break;   \
    case 384: qudaLaunchKernel(kernel<5, 96, __VA_ARGS__>, tp, stream, arg, parity); break;   \
    case 512: qudaLaunchKernel(kernel<5, 128, __VA_ARGS__>, tp, stream, arg, parity); break;  \
    case 640: qudaLaunchKernel(kernel<5, 160, __VA_ARGS__>, tp, stream, arg, parity); break;  \
    case 768: qudaLaunchKernel(kernel<5, 192, __VA_ARGS__>, tp, stream, arg, parity); break;  \
    case 896: qudaLaunchKernel(kernel<5, 224, __VA_ARGS__>, tp, stream, arg, parity); break;  \
    case 1024: qudaLaunchKernel(kernel<5, 256, __VA_ARGS__>, tp, stream, arg, parity); break; \
    default: errorQuda("%s not implemented for %d threads", #kernel, tp.block.x);             \
    }                                                                                         \
  } else {                                                                                    \
    errorQuda("Not implemented for %d", tp.aux.x);                                            \
  }
  /**
   * @brief Container to pass parameters to the gauge fixing quality kernel
   */
  template <typename Gauge>
  struct GaugeFixQualityArg : public ReduceArg<double2> {
    int threads; // number of active threads required
    int X[4]; // grid dimensions
    int border[4];
    Gauge dataOr;
    double2 result;

    GaugeFixQualityArg(const Gauge &dataOr, const GaugeField &data) :
      ReduceArg<double2>(), dataOr(dataOr)
    {
      for ( int dir = 0; dir < 4; ++dir ) {
        X[dir] = data.X()[dir] - data.R()[dir] * 2;
        border[dir] = data.R()[dir];
      }
      threads = X[0] * X[1] * X[2] * X[3] / 2;
    }

    double getAction() { return result.x; }
    double getTheta() { return result.y; }
  };
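  // Note (added for clarity, not part of the original source): result.x accumulates the gauge
  // functional, i.e. the sum over sites and the first gauge_dir directions of Re Tr U_mu(x),
  // which the tuning class below normalizes into the average "action"; result.y accumulates the
  // local gauge-fixing violation theta = Tr[Delta(x) Delta(x)^dagger], used as a stopping measure.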
  /**
   * @brief Measure gauge fixing quality
   */
  template <int blockSize, typename Float, typename Gauge, int gauge_dir>
  __global__ void computeFix_quality(GaugeFixQualityArg<Gauge> argQ)
  {
    typedef complex<Float> Cmplx;

    int idx_cb = threadIdx.x + blockIdx.x * blockDim.x;
    int parity = threadIdx.y;

    double2 data = make_double2(0.0, 0.0);
    while (idx_cb < argQ.threads) {
      int X[4];
      for ( int dr = 0; dr < 4; ++dr ) X[dr] = argQ.X[dr];

      int x[4];
      getCoords(x, idx_cb, X, parity);
      for ( int dr = 0; dr < 4; ++dr ) {
        x[dr] += argQ.border[dr];
        X[dr] += 2 * argQ.border[dr];
      }

      Matrix<Cmplx,3> delta;
      setZero(&delta);
      // load upward links
      for ( int mu = 0; mu < gauge_dir; mu++ ) {
        Matrix<Cmplx,3> U = argQ.dataOr(mu, linkIndex(x, X), parity);
        delta -= U;
      }
      data.x += -delta(0, 0).x - delta(1, 1).x - delta(2, 2).x;

      // load downward links
      for ( int mu = 0; mu < gauge_dir; mu++ ) {
        Matrix<Cmplx,3> U = argQ.dataOr(mu, linkIndexM1(x, X, mu), 1 - parity);
        delta += U;
      }

      delta -= conj(delta);
      SubTraceUnit(delta);
      data.y += getRealTraceUVdagger(delta, delta);

      idx_cb += blockDim.x * gridDim.x;
    }
    argQ.template reduce2d<blockSize, 2>(data);
  }
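  // Explanatory note (not from the original source): per site the kernel forms
  //   Delta(x) = sum_mu [ U_mu(x - mu) - U_mu(x) ],   mu = 0 .. gauge_dir-1,
  // and then projects it onto its traceless anti-hermitian part via delta -= conj(delta) and
  // SubTraceUnit(delta).  data.x collects Re Tr of the upward links (the gauge functional) and
  // data.y collects Tr[Delta Delta^dagger] (theta); both are reduced over the local lattice.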
  /**
   * @brief Tunable object for the gauge fixing quality kernel
   */
  template <typename Float, typename Gauge, int gauge_dir>
  class GaugeFixQuality : TunableLocalParityReduction {
    GaugeFixQualityArg<Gauge> &arg;
    const GaugeField &meta;

  public:
    GaugeFixQuality(GaugeFixQualityArg<Gauge> &arg, const GaugeField &meta) :
      arg(arg),
      meta(meta) { }

    void apply(const qudaStream_t &stream)
    {
      TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
      LAUNCH_KERNEL_LOCAL_PARITY(computeFix_quality, (*this), tp, stream, arg, Float, Gauge, gauge_dir);
      auto reset = true; // apply is called multiple times with the same arg instance, so we need to reset
      arg.complete(arg.result, stream, reset);
      if (!activeTuning()) {
        comm_allreduce_array((double*)&arg.result, 2);
        arg.result.x /= (double)(3 * gauge_dir * 2 * arg.threads * comm_size());
        arg.result.y /= (double)(3 * 2 * arg.threads * comm_size());
      }
    }

    TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), meta.AuxString()); }
    long long flops() const { return (36LL * gauge_dir + 65LL) * meta.Volume(); }
    //long long bytes() const { return (1)*2*gauge_dir*arg.Bytes(); }
    long long bytes() const { return 2LL * gauge_dir * meta.Volume() * meta.Reconstruct() * sizeof(Float); }
  };
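  // Note (added for clarity): after the reduction, getAction() is normalized per site, per
  // direction and per color (division by 3 * gauge_dir * 2 * arg.threads * comm_size()), and
  // getTheta() per site and per color, assuming equal local volumes across ranks; these are the
  // quantities printed by gaugefixingOVR() below.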
  /**
   * @brief Container to pass parameters to the gauge fixing kernel
   */
  template <typename Float, typename Gauge>
  struct GaugeFixArg {
    int threads; // number of active threads required
    int X[4]; // grid dimensions
    Gauge dataOr;
    GaugeField &data;
    int border[4];
    const Float relax_boost;

    GaugeFixArg(Gauge &dataOr, GaugeField &data, const Float relax_boost) :
      dataOr(dataOr), data(data), relax_boost(relax_boost)
    {
      for ( int dir = 0; dir < 4; ++dir ) {
        X[dir] = data.X()[dir] - data.R()[dir] * 2;
        border[dir] = data.R()[dir];
      }
      threads = X[0] * X[1] * X[2] * X[3] >> 1;
    }
  };
  /**
   * @brief Kernel to perform gauge fixing with overrelaxation for single-GPU
   */
  template <int ImplementationType, int blockSize, typename Float, typename Gauge, int gauge_dir>
  __global__ void computeFix(GaugeFixArg<Float, Gauge> arg, int parity)
  {
    typedef complex<Float> Cmplx;
    int tid = (threadIdx.x + blockSize) % blockSize;
    int idx = blockIdx.x * blockSize + tid;

    if ( idx >= arg.threads ) return;

    // 8 threads per lattice site
    if ( ImplementationType < 3 ) {
      int X[4];
      for ( int dr = 0; dr < 4; ++dr ) X[dr] = arg.X[dr];

      int x[4];
      getCoords(x, idx, X, parity);
      for ( int dr = 0; dr < 4; ++dr ) {
        x[dr] += arg.border[dr];
        X[dr] += 2 * arg.border[dr];
      }

      int mu = (threadIdx.x / blockSize);
      int oddbit = parity;
      if ( threadIdx.x >= blockSize * 4 ) {
        mu -= 4;
        x[mu] = (x[mu] - 1 + X[mu]) % X[mu];
        oddbit = 1 - parity;
      }
      idx = (((x[3] * X[2] + x[2]) * X[1] + x[1]) * X[0] + x[0]) >> 1;
      Matrix<Cmplx,3> link = arg.dataOr(mu, idx, oddbit);
      // 8 threads per lattice site, the reduction is performed in shared memory without atomicAdd;
      // this implementation needs 8x more shared memory than the implementation using atomicAdd
      if ( ImplementationType == 0 ) GaugeFixHit_NoAtomicAdd<blockSize, Float, gauge_dir, 3>(link, arg.relax_boost, tid);
      // 8 threads per lattice site, the reduction is performed in shared memory using atomicAdd
      if ( ImplementationType == 1 ) GaugeFixHit_AtomicAdd<blockSize, Float, gauge_dir, 3>(link, arg.relax_boost, tid);
      // 8 threads per lattice site, the reduction is performed in shared memory without atomicAdd;
      // uses the same amount of shared memory as the atomicAdd implementation, with more thread block synchronization
      if ( ImplementationType == 2 ) GaugeFixHit_NoAtomicAdd_LessSM<blockSize, Float, gauge_dir, 3>(link, arg.relax_boost, tid);
      arg.dataOr(mu, idx, oddbit) = link;
    }
    // 4 threads per lattice site
    else {
      int X[4];
      for ( int dr = 0; dr < 4; ++dr ) X[dr] = arg.X[dr];

      int x[4];
      getCoords(x, idx, X, parity);
      for ( int dr = 0; dr < 4; ++dr ) {
        x[dr] += arg.border[dr];
        X[dr] += 2 * arg.border[dr];
      }

      int mu = (threadIdx.x / blockSize);
      idx = (((x[3] * X[2] + x[2]) * X[1] + x[1]) * X[0] + x[0]) >> 1;

      Matrix<Cmplx,3> link = arg.dataOr(mu, idx, parity);

      x[mu] = (x[mu] - 1 + X[mu]) % X[mu];
      int idx1 = (((x[3] * X[2] + x[2]) * X[1] + x[1]) * X[0] + x[0]) >> 1;

      Matrix<Cmplx,3> link1 = arg.dataOr(mu, idx1, 1 - parity);

      // 4 threads per lattice site, the reduction is performed in shared memory without atomicAdd;
      // this implementation needs 4x more shared memory than the implementation using atomicAdd
      if ( ImplementationType == 3 ) GaugeFixHit_NoAtomicAdd<blockSize, Float, gauge_dir, 3>(link, link1, arg.relax_boost, tid);
      // 4 threads per lattice site, the reduction is performed in shared memory using atomicAdd
      if ( ImplementationType == 4 ) GaugeFixHit_AtomicAdd<blockSize, Float, gauge_dir, 3>(link, link1, arg.relax_boost, tid);
      // 4 threads per lattice site, the reduction is performed in shared memory without atomicAdd;
      // uses the same amount of shared memory as the atomicAdd implementation, with more thread block synchronization
      if ( ImplementationType == 5 ) GaugeFixHit_NoAtomicAdd_LessSM<blockSize, Float, gauge_dir, 3>(link, link1, arg.relax_boost, tid);

      arg.dataOr(mu, idx, parity) = link;
      arg.dataOr(mu, idx1, 1 - parity) = link1;
    }
  }
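  // Explanatory note (not from the original source): each block processes blockSize lattice sites
  // of one parity.  For ImplementationType < 3, threadIdx.x / blockSize runs over eight values:
  // 0-3 load the link leaving the site in direction mu, 4-7 load the link arriving at the site
  // (hence x[mu] is decremented and the checkerboard parity flipped).  For ImplementationType >= 3
  // only four values are used and each thread loads both the outgoing (link) and incoming (link1)
  // link of its direction.  The relaxation hit itself is implemented by the GaugeFixHit_* routines
  // in gauge_fix_ovr_hit_devf.cuh.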
  /**
   * @brief Tunable object for the gauge fixing kernel
   */
  template <typename Float, typename Gauge, int gauge_dir>
  class GaugeFix : Tunable {
    GaugeFixArg<Float, Gauge> &arg;
    const GaugeField &meta;
    int parity;

    dim3 createGrid(const TuneParam &param) const
    {
      unsigned int blockx = param.block.x / 8;
      if (param.aux.x > 2) blockx = param.block.x / 4;
      unsigned int gx = std::max((arg.threads + blockx - 1) / blockx, 1u);
      return dim3(gx, 1, 1);
    }

    bool advanceBlockDim(TuneParam &param) const
    {
      // Use param.aux.x to tune and save state for the best kernel option:
      // with or without atomicAdd, and 4 or 8 threads per lattice site
      const unsigned int min_threads0 = 32 * 8;
      const unsigned int min_threads1 = 32 * 4;
      const unsigned int max_threads = 1024; // FIXME: use deviceProp.maxThreadsDim[0];
      const unsigned int atmadd = 0;
      unsigned int min_threads = min_threads0;
      param.aux.x += atmadd; // select the kernel option with or without atomicAdd
      if (param.aux.x > 2) min_threads = 32 * 4;
      param.block.x += min_threads;
      param.block.y = 1;
      param.grid = createGrid(param);

      if ((param.block.x >= min_threads) && (param.block.x <= max_threads)) {
        param.shared_bytes = sharedBytesPerBlock(param);
        return true;
      } else if (param.aux.x == 0) {
        param.block.x = min_threads0;
        param.block.y = 1;
        param.aux.x = 1; // USE FOR ATOMIC ADD
        param.grid = createGrid(param);
        param.shared_bytes = param.block.x * 4 * sizeof(Float) / 8;
        return true;
      } else if (param.aux.x == 1) {
        param.block.x = min_threads0;
        param.block.y = 1;
        param.aux.x = 2; // USE FOR NO ATOMIC ADD and LESS SHARED MEM
        param.grid = createGrid(param);
        param.shared_bytes = param.block.x * 4 * sizeof(Float) / 8;
        return true;
      } else if (param.aux.x == 2) {
        param.block.x = min_threads1;
        param.block.y = 1;
        param.aux.x = 3; // USE FOR NO ATOMIC ADD
        param.grid = createGrid(param);
        param.shared_bytes = param.block.x * 4 * sizeof(Float);
        return true;
      } else if (param.aux.x == 3) {
        param.block.x = min_threads1;
        param.block.y = 1;
        param.aux.x = 4;
        param.grid = createGrid(param);
        param.shared_bytes = param.block.x * sizeof(Float);
        return true;
      } else if (param.aux.x == 4) {
        param.block.x = min_threads1;
        param.block.y = 1;
        param.aux.x = 5;
        param.grid = createGrid(param);
        param.shared_bytes = param.block.x * sizeof(Float);
        return true;
      } else {
        return false;
      }
    }

    unsigned int sharedBytesPerThread() const { return 0; }
    unsigned int sharedBytesPerBlock(const TuneParam &param) const
    {
      switch (param.aux.x) {
      case 0: return param.block.x * 4 * sizeof(Float);
      case 1: return param.block.x * 4 * sizeof(Float) / 8;
      case 2: return param.block.x * 4 * sizeof(Float) / 8;
      case 3: return param.block.x * 4 * sizeof(Float);
      default: return param.block.x * sizeof(Float);
      }
    }

    bool tuneSharedBytes() const { return false; }
    bool tuneGridDim() const { return false; }
    unsigned int minThreads() const { return arg.threads; }

  public:
    GaugeFix(GaugeFixArg<Float, Gauge> &arg, const GaugeField &meta) :
      arg(arg),
      meta(meta),
      parity(0) { }

    void setParity(const int par) { parity = par; }

    void apply(const qudaStream_t &stream)
    {
      TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
      LAUNCH_KERNEL_GAUGEFIX(computeFix, tp, stream, arg, parity, Float, Gauge, gauge_dir);
    }

    virtual void initTuneParam(TuneParam &param) const
    {
      param.block = dim3(256, 1, 1);
      param.aux.x = 0;
      param.grid = createGrid(param);
      param.shared_bytes = sharedBytesPerBlock(param);
    }

    virtual void defaultTuneParam(TuneParam &param) const { initTuneParam(param); }

    TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), meta.AuxString()); }

    std::string paramString(const TuneParam &param) const
    {
      std::stringstream ps;
      ps << Tunable::paramString(param) << ", atomicadd=" << param.aux.x;
      return ps.str();
    }

    void preTune() { arg.data.backup(); }
    void postTune() { arg.data.restore(); }
    long long flops() const { return 3LL * (22 + 28 * gauge_dir + 224 * 3) * arg.threads; }
    long long bytes() const { return 8LL * 2 * arg.threads * meta.Reconstruct() * sizeof(Float); }
  };
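  // Explanatory note (not from the original source): autotuning here sweeps two things at once.
  // param.block.x is advanced in steps of 256 (8 threads/site) or 128 (4 threads/site), and when
  // the block-size range is exhausted param.aux.x advances to the next of the six kernel variants,
  // resetting the block size.  The shared-memory requirement differs per variant, which is why
  // sharedBytesPerBlock() switches on param.aux.x; the selected variant is reported in the tune
  // string as "atomicadd=<aux.x>".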
  template <typename Float, typename Gauge>
  struct GaugeFixInteriorPointsArg {
    int threads; // number of active threads required
    int X[4]; // grid dimensions
    Gauge dataOr;
    GaugeField &data;
    int border[4];
    const Float relax_boost;

    GaugeFixInteriorPointsArg(Gauge &dataOr, GaugeField &data, const Float relax_boost) :
      dataOr(dataOr), data(data), relax_boost(relax_boost)
    {
#ifdef MULTI_GPU
      for ( int dir = 0; dir < 4; ++dir ) {
        if ( comm_dim_partitioned(dir) ) border[dir] = data.R()[dir] + 1; // skip the border radius plus one face border point
        else border[dir] = 0;
      }
      for ( int dir = 0; dir < 4; ++dir ) X[dir] = data.X()[dir] - border[dir] * 2;
#else
      for ( int dir = 0; dir < 4; ++dir ) X[dir] = data.X()[dir];
#endif
      threads = X[0] * X[1] * X[2] * X[3] >> 1;
      if (this->threads == 0) errorQuda("Local volume is too small");
    }
  };
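  // Explanatory note (not from the original source): in the multi-GPU path the "interior" volume
  // excludes the halo (R) plus one extra layer in every partitioned direction, so relaxing an
  // interior site only involves links whose endpoints are stored locally and are not being
  // exchanged.  This is what allows gaugefixingOVR() below to overlap the interior update with
  // the border-face communication.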
  /**
   * @brief Kernel to perform gauge fixing with overrelaxation in the interior points for the multi-GPU implementation
   */
  template <int ImplementationType, int blockSize, typename Float, typename Gauge, int gauge_dir>
  __global__ void computeFixInteriorPoints(GaugeFixInteriorPointsArg<Float, Gauge> arg, int parity)
  {
    int tid = (threadIdx.x + blockSize) % blockSize;
    int idx = blockIdx.x * blockSize + tid;
    if ( idx >= arg.threads ) return;
    typedef complex<Float> Complex;

    int X[4];
    for ( int dr = 0; dr < 4; ++dr ) X[dr] = arg.X[dr];
    int x[4];
#ifdef MULTI_GPU
    int za = (idx / (X[0] / 2));
    int zb = (za / X[1]);
    x[1] = za - zb * X[1];
    x[3] = (zb / X[2]);
    x[2] = zb - x[3] * X[2];
    int p = 0; for ( int dr = 0; dr < 4; ++dr ) p += arg.border[dr];

    int x1odd = (x[1] + x[2] + x[3] + parity + p) & 1;
    //int x1odd = (x[1] + x[2] + x[3] + parity) & 1;
    x[0] = (2 * idx + x1odd) - za * X[0];
    for ( int dr = 0; dr < 4; ++dr ) {
      x[dr] += arg.border[dr];
      X[dr] += 2 * arg.border[dr];
    }
#else
    getCoords(x, idx, X, parity);
#endif
    int mu = (threadIdx.x / blockSize);

    // 8 threads per lattice site
    if ( ImplementationType < 3 ) {
      if ( threadIdx.x >= blockSize * 4 ) {
        mu -= 4;
        x[mu] = (x[mu] - 1 + X[mu]) % X[mu];
        parity = 1 - parity;
      }
      idx = (((x[3] * X[2] + x[2]) * X[1] + x[1]) * X[0] + x[0]) >> 1;
      Matrix<Complex,3> link = arg.dataOr(mu, idx, parity);
      // 8 threads per lattice site, the reduction is performed in shared memory without atomicAdd;
      // this implementation needs 8x more shared memory than the implementation using atomicAdd
      if ( ImplementationType == 0 ) GaugeFixHit_NoAtomicAdd<blockSize, Float, gauge_dir, 3>(link, arg.relax_boost, tid);
      // 8 threads per lattice site, the reduction is performed in shared memory using atomicAdd
      if ( ImplementationType == 1 ) GaugeFixHit_AtomicAdd<blockSize, Float, gauge_dir, 3>(link, arg.relax_boost, tid);
      // 8 threads per lattice site, the reduction is performed in shared memory without atomicAdd;
      // uses the same amount of shared memory as the atomicAdd implementation, with more thread block synchronization
      if ( ImplementationType == 2 ) GaugeFixHit_NoAtomicAdd_LessSM<blockSize, Float, gauge_dir, 3>(link, arg.relax_boost, tid);
      arg.dataOr(mu, idx, parity) = link;
    }
    // 4 threads per lattice site
    else {
      idx = (((x[3] * X[2] + x[2]) * X[1] + x[1]) * X[0] + x[0]) >> 1;
      Matrix<Complex,3> link = arg.dataOr(mu, idx, parity);

      x[mu] = (x[mu] - 1 + X[mu]) % X[mu];
      int idx1 = (((x[3] * X[2] + x[2]) * X[1] + x[1]) * X[0] + x[0]) >> 1;
      Matrix<Complex,3> link1 = arg.dataOr(mu, idx1, 1 - parity);

      // 4 threads per lattice site, the reduction is performed in shared memory without atomicAdd;
      // this implementation needs 4x more shared memory than the implementation using atomicAdd
      if ( ImplementationType == 3 ) GaugeFixHit_NoAtomicAdd<blockSize, Float, gauge_dir, 3>(link, link1, arg.relax_boost, tid);
      // 4 threads per lattice site, the reduction is performed in shared memory using atomicAdd
      if ( ImplementationType == 4 ) GaugeFixHit_AtomicAdd<blockSize, Float, gauge_dir, 3>(link, link1, arg.relax_boost, tid);
      // 4 threads per lattice site, the reduction is performed in shared memory without atomicAdd;
      // uses the same amount of shared memory as the atomicAdd implementation, with more thread block synchronization
      if ( ImplementationType == 5 ) GaugeFixHit_NoAtomicAdd_LessSM<blockSize, Float, gauge_dir, 3>(link, link1, arg.relax_boost, tid);

      arg.dataOr(mu, idx, parity) = link;
      arg.dataOr(mu, idx1, 1 - parity) = link1;
    }
  }
  /**
   * @brief Tunable object for the interior points of the gauge fixing
   * kernel in the multi-GPU implementation
   */
  template <typename Float, typename Gauge, int gauge_dir>
  class GaugeFixInteriorPoints : Tunable {
    GaugeFixInteriorPointsArg<Float, Gauge> &arg;
    const GaugeField &meta;
    int parity;

    dim3 createGrid(const TuneParam &param) const
    {
      unsigned int blockx = param.block.x / 8;
      if (param.aux.x > 2) blockx = param.block.x / 4;
      unsigned int gx = (arg.threads + blockx - 1) / blockx;
      return dim3(gx, 1, 1);
    }

    bool advanceBlockDim(TuneParam &param) const
    {
      // Use param.aux.x to tune and save state for the best kernel option:
      // with or without atomicAdd, and 4 or 8 threads per lattice site
      const unsigned int min_threads0 = 32 * 8;
      const unsigned int min_threads1 = 32 * 4;
      const unsigned int max_threads = 1024; // FIXME: use deviceProp.maxThreadsDim[0];
      const unsigned int atmadd = 0;
      unsigned int min_threads = min_threads0;
      param.aux.x += atmadd; // select the kernel option with or without atomicAdd
      if (param.aux.x > 2) min_threads = 32 * 4;
      param.block.x += min_threads;
      param.block.y = 1;
      param.grid = createGrid(param);

      if ((param.block.x >= min_threads) && (param.block.x <= max_threads)) {
        param.shared_bytes = sharedBytesPerBlock(param);
        return true;
      } else if (param.aux.x == 0) {
        param.block.x = min_threads0;
        param.block.y = 1;
        param.aux.x = 1; // USE FOR ATOMIC ADD
        param.grid = createGrid(param);
        param.shared_bytes = param.block.x * 4 * sizeof(Float) / 8;
        return true;
      } else if (param.aux.x == 1) {
        param.block.x = min_threads0;
        param.block.y = 1;
        param.aux.x = 2; // USE FOR NO ATOMIC ADD and LESS SHARED MEM
        param.grid = createGrid(param);
        param.shared_bytes = param.block.x * 4 * sizeof(Float) / 8;
        return true;
      } else if (param.aux.x == 2) {
        param.block.x = min_threads1;
        param.block.y = 1;
        param.aux.x = 3; // USE FOR NO ATOMIC ADD
        param.grid = createGrid(param);
        param.shared_bytes = param.block.x * 4 * sizeof(Float);
        return true;
      } else if (param.aux.x == 3) {
        param.block.x = min_threads1;
        param.block.y = 1;
        param.aux.x = 4;
        param.grid = createGrid(param);
        param.shared_bytes = param.block.x * sizeof(Float);
        return true;
      } else if (param.aux.x == 4) {
        param.block.x = min_threads1;
        param.block.y = 1;
        param.aux.x = 5;
        param.grid = createGrid(param);
        param.shared_bytes = param.block.x * sizeof(Float);
        return true;
      } else {
        return false;
      }
    }

    unsigned int sharedBytesPerThread() const { return 0; }
    unsigned int sharedBytesPerBlock(const TuneParam &param) const
    {
      switch (param.aux.x) {
      case 0: return param.block.x * 4 * sizeof(Float);
      case 1: return param.block.x * 4 * sizeof(Float) / 8;
      case 2: return param.block.x * 4 * sizeof(Float) / 8;
      case 3: return param.block.x * 4 * sizeof(Float);
      default: return param.block.x * sizeof(Float);
      }
    }

    bool tuneSharedBytes() const { return false; }
    bool tuneGridDim() const { return false; }
    unsigned int minThreads() const { return arg.threads; }

  public:
    GaugeFixInteriorPoints(GaugeFixInteriorPointsArg<Float, Gauge> &arg, const GaugeField &meta) :
      arg(arg),
      meta(meta),
      parity(0) { }

    void setParity(const int par) { parity = par; }

    void apply(const qudaStream_t &stream)
    {
      TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
      LAUNCH_KERNEL_GAUGEFIX(computeFixInteriorPoints, tp, stream, arg, parity, Float, Gauge, gauge_dir);
    }

    virtual void initTuneParam(TuneParam &param) const
    {
      param.block = dim3(256, 1, 1);
      param.aux.x = 0;
      param.grid = createGrid(param);
      param.shared_bytes = sharedBytesPerBlock(param);
    }

    virtual void defaultTuneParam(TuneParam &param) const { initTuneParam(param); }

    TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), meta.AuxString()); }

    std::string paramString(const TuneParam &param) const
    {
      std::stringstream ps;
      ps << Tunable::paramString(param) << ", atomicadd=" << param.aux.x;
      return ps.str();
    }

    void preTune() { arg.data.backup(); }
    void postTune() { arg.data.restore(); }
    long long flops() const { return 3LL * (22 + 28 * gauge_dir + 224 * 3) * arg.threads; }
    long long bytes() const { return 8LL * 2 * arg.threads * meta.Reconstruct() * sizeof(Float); }
  };
  template <typename Float, typename Gauge>
  struct GaugeFixBorderPointsArg {
    int threads; // number of active threads required
    int X[4]; // grid dimensions
    int border[4];
    int *borderpoints[2];
    int *faceindicessize[2];
    size_t faceVolume[4];
    size_t faceVolumeCB[4];
    Gauge dataOr;
    GaugeField &data;
    const Float relax_boost;

    GaugeFixBorderPointsArg(Gauge &dataOr, GaugeField &data, const Float relax_boost, size_t faceVolume_[4], size_t faceVolumeCB_[4]) :
      dataOr(dataOr), data(data), relax_boost(relax_boost)
    {
      for ( int dir = 0; dir < 4; ++dir ) {
        X[dir] = data.X()[dir] - data.R()[dir] * 2;
        border[dir] = data.R()[dir];
      }

      /*for(int dir=0; dir<4; ++dir){
         if(comm_dim_partitioned(dir)) border[dir] = BORDER_RADIUS;
         else border[dir] = 0;
         }
         for(int dir=0; dir<4; ++dir) X[dir] = data.X()[dir] - border[dir]*2;*/

      for ( int dir = 0; dir < 4; ++dir ) {
        faceVolume[dir] = faceVolume_[dir];
        faceVolumeCB[dir] = faceVolumeCB_[dir];
      }
      if ( comm_partitioned() ) PreCalculateLatticeIndices(faceVolume, faceVolumeCB, X, border, threads, borderpoints);
    }
  };
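  // Explanatory note (not from the original source): PreCalculateLatticeIndices (declared in
  // gauge_fix_ovr_extra.h) enumerates the border sites, i.e. the sites whose relaxation hit needs
  // links from a communicated face, stores their linear indices (one device array per parity) in
  // borderpoints, and sets `threads` accordingly.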
  /**
   * @brief Kernel to perform gauge fixing with overrelaxation in the border points for the multi-GPU implementation
   */
  template <int ImplementationType, int blockSize, typename Float, typename Gauge, int gauge_dir>
  __global__ void computeFixBorderPoints(GaugeFixBorderPointsArg<Float, Gauge> arg, int parity)
  {
    typedef complex<Float> Cmplx;

    int tid = (threadIdx.x + blockSize) % blockSize;
    int idx = blockIdx.x * blockSize + tid;
    if ( idx >= arg.threads ) return;
    int mu = (threadIdx.x / blockSize);
    idx = arg.borderpoints[parity][idx];
    int x[4];
    x[3] = idx / (arg.X[0] * arg.X[1] * arg.X[2]);
    x[2] = (idx / (arg.X[0] * arg.X[1])) % arg.X[2];
    x[1] = (idx / arg.X[0]) % arg.X[1];
    x[0] = idx % arg.X[0];

    for ( int dr = 0; dr < 4; ++dr ) x[dr] += arg.border[dr];
    int X[4];
    for ( int dr = 0; dr < 4; ++dr ) X[dr] = arg.X[dr] + 2 * arg.border[dr];

    // 8 threads per lattice site
    if ( ImplementationType < 3 ) {
      if ( threadIdx.x >= blockSize * 4 ) {
        mu -= 4;
        x[mu] = (x[mu] - 1 + X[mu]) % X[mu];
        parity = 1 - parity;
      }
      idx = (((x[3] * X[2] + x[2]) * X[1] + x[1]) * X[0] + x[0]) >> 1;
      Matrix<Cmplx,3> link = arg.dataOr(mu, idx, parity);
      // 8 threads per lattice site, the reduction is performed in shared memory without atomicAdd;
      // this implementation needs 8x more shared memory than the implementation using atomicAdd
      if ( ImplementationType == 0 ) GaugeFixHit_NoAtomicAdd<blockSize, Float, gauge_dir, 3>(link, arg.relax_boost, tid);
      // 8 threads per lattice site, the reduction is performed in shared memory using atomicAdd
      if ( ImplementationType == 1 ) GaugeFixHit_AtomicAdd<blockSize, Float, gauge_dir, 3>(link, arg.relax_boost, tid);
      // 8 threads per lattice site, the reduction is performed in shared memory without atomicAdd;
      // uses the same amount of shared memory as the atomicAdd implementation, with more thread block synchronization
      if ( ImplementationType == 2 ) GaugeFixHit_NoAtomicAdd_LessSM<blockSize, Float, gauge_dir, 3>(link, arg.relax_boost, tid);
      arg.dataOr(mu, idx, parity) = link;
    }
    // 4 threads per lattice site
    else {
      idx = (((x[3] * X[2] + x[2]) * X[1] + x[1]) * X[0] + x[0]) >> 1;
      Matrix<Cmplx,3> link = arg.dataOr(mu, idx, parity);

      x[mu] = (x[mu] - 1 + X[mu]) % X[mu];
      int idx1 = (((x[3] * X[2] + x[2]) * X[1] + x[1]) * X[0] + x[0]) >> 1;
      Matrix<Cmplx,3> link1 = arg.dataOr(mu, idx1, 1 - parity);

      // 4 threads per lattice site, the reduction is performed in shared memory without atomicAdd;
      // this implementation needs 4x more shared memory than the implementation using atomicAdd
      if ( ImplementationType == 3 ) GaugeFixHit_NoAtomicAdd<blockSize, Float, gauge_dir, 3>(link, link1, arg.relax_boost, tid);
      // 4 threads per lattice site, the reduction is performed in shared memory using atomicAdd
      if ( ImplementationType == 4 ) GaugeFixHit_AtomicAdd<blockSize, Float, gauge_dir, 3>(link, link1, arg.relax_boost, tid);
      // 4 threads per lattice site, the reduction is performed in shared memory without atomicAdd;
      // uses the same amount of shared memory as the atomicAdd implementation, with more thread block synchronization
      if ( ImplementationType == 5 ) GaugeFixHit_NoAtomicAdd_LessSM<blockSize, Float, gauge_dir, 3>(link, link1, arg.relax_boost, tid);

      arg.dataOr(mu, idx, parity) = link;
      arg.dataOr(mu, idx1, 1 - parity) = link1;
    }
  }
  /**
   * @brief Tunable object for the border points of the gauge fixing kernel in the multi-GPU implementation
   */
  template <typename Float, typename Gauge, int gauge_dir>
  class GaugeFixBorderPoints : Tunable {
    GaugeFixBorderPointsArg<Float, Gauge> &arg;
    const GaugeField &meta;
    int parity;

    dim3 createGrid(const TuneParam &param) const
    {
      unsigned int blockx = param.block.x / 8;
      if (param.aux.x > 2) blockx = param.block.x / 4;
      unsigned int gx = (arg.threads + blockx - 1) / blockx;
      return dim3(gx, 1, 1);
    }

    bool advanceBlockDim(TuneParam &param) const
    {
      // Use param.aux.x to tune and save state for the best kernel option:
      // with or without atomicAdd, and 4 or 8 threads per lattice site
      const unsigned int min_threads0 = 32 * 8;
      const unsigned int min_threads1 = 32 * 4;
      const unsigned int max_threads = 1024; // FIXME: use deviceProp.maxThreadsDim[0];
      const unsigned int atmadd = 0;
      unsigned int min_threads = min_threads0;
      param.aux.x += atmadd; // select the kernel option with or without atomicAdd
      if (param.aux.x > 2) min_threads = 32 * 4;
      param.block.x += min_threads;
      param.block.y = 1;
      param.grid = createGrid(param);

      if ((param.block.x >= min_threads) && (param.block.x <= max_threads)) {
        param.shared_bytes = sharedBytesPerBlock(param);
        return true;
      } else if (param.aux.x == 0) {
        param.block.x = min_threads0;
        param.block.y = 1;
        param.aux.x = 1; // USE FOR ATOMIC ADD
        param.grid = createGrid(param);
        param.shared_bytes = param.block.x * 4 * sizeof(Float) / 8;
        return true;
      } else if (param.aux.x == 1) {
        param.block.x = min_threads0;
        param.block.y = 1;
        param.aux.x = 2; // USE FOR NO ATOMIC ADD and LESS SHARED MEM
        param.grid = createGrid(param);
        param.shared_bytes = param.block.x * 4 * sizeof(Float) / 8;
        return true;
      } else if (param.aux.x == 2) {
        param.block.x = min_threads1;
        param.block.y = 1;
        param.aux.x = 3; // USE FOR NO ATOMIC ADD
        param.grid = createGrid(param);
        param.shared_bytes = param.block.x * 4 * sizeof(Float);
        return true;
      } else if (param.aux.x == 3) {
        param.block.x = min_threads1;
        param.block.y = 1;
        param.aux.x = 4;
        param.grid = createGrid(param);
        param.shared_bytes = param.block.x * sizeof(Float);
        return true;
      } else if (param.aux.x == 4) {
        param.block.x = min_threads1;
        param.block.y = 1;
        param.aux.x = 5;
        param.grid = createGrid(param);
        param.shared_bytes = param.block.x * sizeof(Float);
        return true;
      } else {
        return false;
      }
    }

    unsigned int sharedBytesPerThread() const { return 0; }
    unsigned int sharedBytesPerBlock(const TuneParam &param) const
    {
      switch (param.aux.x) {
      case 0: return param.block.x * 4 * sizeof(Float);
      case 1: return param.block.x * 4 * sizeof(Float) / 8;
      case 2: return param.block.x * 4 * sizeof(Float) / 8;
      case 3: return param.block.x * 4 * sizeof(Float);
      default: return param.block.x * sizeof(Float);
      }
    }

    bool tuneSharedBytes() const { return false; }
    bool tuneGridDim() const { return false; }
    unsigned int minThreads() const { return arg.threads; }

  public:
    GaugeFixBorderPoints(GaugeFixBorderPointsArg<Float, Gauge> &arg, const GaugeField &meta) :
      arg(arg),
      meta(meta),
      parity(0) { }

    ~GaugeFixBorderPoints() {
      if ( comm_partitioned() ) for ( int i = 0; i < 2; i++ ) pool_device_free(arg.borderpoints[i]);
    }

    void setParity(const int par) { parity = par; }

    void apply(const qudaStream_t &stream)
    {
      TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
      LAUNCH_KERNEL_GAUGEFIX(computeFixBorderPoints, tp, stream, arg, parity, Float, Gauge, gauge_dir);
    }

    virtual void initTuneParam(TuneParam &param) const
    {
      param.block = dim3(256, 1, 1);
      param.aux.x = 0;
      param.grid = createGrid(param);
      param.shared_bytes = sharedBytesPerBlock(param);
    }

    virtual void defaultTuneParam(TuneParam &param) const { initTuneParam(param); }

    TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), meta.AuxString()); }

    std::string paramString(const TuneParam &param) const
    {
      std::stringstream ps;
      ps << Tunable::paramString(param) << ", atomicadd=" << param.aux.x;
      return ps.str();
    }

    void preTune() { arg.data.backup(); }
    void postTune() { arg.data.restore(); }
    long long flops() const { return 3LL * (22 + 28 * gauge_dir + 224 * 3) * arg.threads; }
    //long long bytes() const { return (1)*8*2*arg.dataOr.Bytes(); } // Only correct if there is no link reconstruction load+save
    long long bytes() const { return 8LL * 2 * arg.threads * meta.Reconstruct() * sizeof(Float); }
  };
  template <int NElems_, typename Gauge>
  struct GaugeFixUnPackArg {
    static constexpr int NElems = NElems_;
    int X[4]; // grid dimensions
    Gauge dataOr;
    int border[4];

    GaugeFixUnPackArg(Gauge &dataOr, GaugeField &data) :
      dataOr(dataOr)
    {
      for ( int dir = 0; dir < 4; ++dir ) {
        X[dir] = data.X()[dir] - data.R()[dir] * 2;
        border[dir] = data.R()[dir];
      }
    }
  };
  template <typename Float, bool pack, typename Arg>
  __global__ void Kernel_UnPackGhost(int size, Arg arg, complex<Float> *array, int parity, int face, int dir)
  {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if ( idx >= size ) return;
    int X[4];
    for ( int dr = 0; dr < 4; ++dr ) X[dr] = arg.X[dr];
    int x[4];
    int za, xodd;
    int borderid = 0;
    parity = 1 - parity;
    switch ( face ) {
    case 0: // X FACE
      za = idx / ( X[1] / 2);
      x[3] = za / X[2];
      x[2] = za - x[3] * X[2];
      x[0] = borderid;
      xodd = (borderid + x[2] + x[3] + parity) & 1;
      x[1] = (2 * idx + xodd) - za * X[1];
      break;
    case 1: // Y FACE
      za = idx / ( X[0] / 2);
      x[3] = za / X[2];
      x[2] = za - x[3] * X[2];
      x[1] = borderid;
      xodd = (borderid + x[2] + x[3] + parity) & 1;
      x[0] = (2 * idx + xodd) - za * X[0];
      break;
    case 2: // Z FACE
      za = idx / ( X[0] / 2);
      x[3] = za / X[1];
      x[1] = za - x[3] * X[1];
      x[2] = borderid;
      xodd = (borderid + x[1] + x[3] + parity) & 1;
      x[0] = (2 * idx + xodd) - za * X[0];
      break;
    case 3: // T FACE
      za = idx / ( X[0] / 2);
      x[2] = za / X[1];
      x[1] = za - x[2] * X[1];
      x[3] = borderid;
      xodd = (borderid + x[1] + x[2] + parity) & 1;
      x[0] = (2 * idx + xodd) - za * X[0];
      break;
    }
    for ( int dr = 0; dr < 4; ++dr ) {
      x[dr] += arg.border[dr];
      X[dr] += 2 * arg.border[dr];
    }
    int id = (((x[3] * X[2] + x[2]) * X[1] + x[1]) * X[0] + x[0]) >> 1;
    typedef complex<Float> Cmplx;
    typedef typename mapper<Float>::type RegType;
    RegType tmp[Arg::NElems];
    Cmplx data[9];
    if ( pack ) {
      arg.dataOr.load(data, id, dir, parity);
      arg.dataOr.reconstruct.Pack(tmp, data, id);
      for ( int i = 0; i < Arg::NElems / 2; ++i ) {
        array[idx + size * i] = Cmplx(tmp[2*i+0], tmp[2*i+1]);
      }
    } else {
      for ( int i = 0; i < Arg::NElems / 2; ++i ) {
        tmp[2*i+0] = array[idx + size * i].real();
        tmp[2*i+1] = array[idx + size * i].imag();
      }
      arg.dataOr.reconstruct.Unpack(data, tmp, id, dir, 0, arg.dataOr.X, arg.dataOr.R);
      arg.dataOr.save(data, id, dir, parity);
    }
  }
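  // Explanatory note (not from the original source): with pack == true this kernel gathers the
  // links of one checkerboard face into a contiguous buffer of NElems/2 complex numbers per link
  // (after reconstruct.Pack), laid out so that component i of every link on the face is stored
  // contiguously; with pack == false it performs the inverse operation, unpacking a received
  // buffer and saving the links back into the gauge field.  Kernel_UnPackTop below is the mirror
  // image for the opposite face, using borderid = X[face] - 1 instead of 0.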
  template <typename Float, bool pack, typename Arg>
  __global__ void Kernel_UnPackTop(int size, Arg arg, complex<Float> *array, int parity, int face, int dir)
  {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if ( idx >= size ) return;
    int X[4];
    for ( int dr = 0; dr < 4; ++dr ) X[dr] = arg.X[dr];
    int x[4];
    int za, xodd;
    int borderid = arg.X[face] - 1;
    switch ( face ) {
    case 0: // X FACE
      za = idx / ( X[1] / 2);
      x[3] = za / X[2];
      x[2] = za - x[3] * X[2];
      x[0] = borderid;
      xodd = (borderid + x[2] + x[3] + parity) & 1;
      x[1] = (2 * idx + xodd) - za * X[1];
      break;
    case 1: // Y FACE
      za = idx / ( X[0] / 2);
      x[3] = za / X[2];
      x[2] = za - x[3] * X[2];
      x[1] = borderid;
      xodd = (borderid + x[2] + x[3] + parity) & 1;
      x[0] = (2 * idx + xodd) - za * X[0];
      break;
    case 2: // Z FACE
      za = idx / ( X[0] / 2);
      x[3] = za / X[1];
      x[1] = za - x[3] * X[1];
      x[2] = borderid;
      xodd = (borderid + x[1] + x[3] + parity) & 1;
      x[0] = (2 * idx + xodd) - za * X[0];
      break;
    case 3: // T FACE
      za = idx / ( X[0] / 2);
      x[2] = za / X[1];
      x[1] = za - x[2] * X[1];
      x[3] = borderid;
      xodd = (borderid + x[1] + x[2] + parity) & 1;
      x[0] = (2 * idx + xodd) - za * X[0];
      break;
    }
    for ( int dr = 0; dr < 4; ++dr ) {
      x[dr] += arg.border[dr];
      X[dr] += 2 * arg.border[dr];
    }
    int id = (((x[3] * X[2] + x[2]) * X[1] + x[1]) * X[0] + x[0]) >> 1;
    typedef complex<Float> Cmplx;
    typedef typename mapper<Float>::type RegType;
    RegType tmp[Arg::NElems];
    Cmplx data[9];
    if ( pack ) {
      arg.dataOr.load(data, id, dir, parity);
      arg.dataOr.reconstruct.Pack(tmp, data, id);
      for ( int i = 0; i < Arg::NElems / 2; ++i ) array[idx + size * i] = Cmplx(tmp[2*i+0], tmp[2*i+1]);
    } else {
      for ( int i = 0; i < Arg::NElems / 2; ++i ) {
        tmp[2*i+0] = array[idx + size * i].real();
        tmp[2*i+1] = array[idx + size * i].imag();
      }
      arg.dataOr.reconstruct.Unpack(data, tmp, id, dir, 0, arg.dataOr.X, arg.dataOr.R);
      arg.dataOr.save(data, id, dir, parity);
    }
  }
  template <typename Float, typename Gauge, int NElems, int gauge_dir>
  void gaugefixingOVR(Gauge dataOr, GaugeField &data,
                      const int Nsteps, const int verbose_interval,
                      const Float relax_boost, const double tolerance,
                      const int reunit_interval, const int stopWtheta)
  {
    TimeProfile profileInternalGaugeFixOVR("InternalGaugeFixQudaOVR", false);

    profileInternalGaugeFixOVR.TPSTART(QUDA_PROFILE_COMPUTE);
    double flop = 0;
    double byte = 0;

    printfQuda("\tOverrelaxation boost parameter: %lf\n", (double)relax_boost);
    printfQuda("\tStop criterion: %lf\n", tolerance);
    if ( stopWtheta ) printfQuda("\tStop criterion method: theta\n");
    else printfQuda("\tStop criterion method: Delta\n");
    printfQuda("\tMaximum number of iterations: %d\n", Nsteps);
    printfQuda("\tReunitarize at every %d steps\n", reunit_interval);
    printfQuda("\tPrint convergence results at every %d steps\n", verbose_interval);

    const double unitarize_eps = 1e-14;
    const double max_error = 1e-10;
    const int reunit_allow_svd = 1;
    const int reunit_svd_only = 0;
    const double svd_rel_error = 1e-6;
    const double svd_abs_error = 1e-6;
    setUnitarizeLinksConstants(unitarize_eps, max_error, reunit_allow_svd, reunit_svd_only,
                               svd_rel_error, svd_abs_error);
    int num_failures = 0;
    int *num_failures_dev = static_cast<int*>(pool_device_malloc(sizeof(int)));
    qudaMemset(num_failures_dev, 0, sizeof(int));

    GaugeFixQualityArg<Gauge> argQ(dataOr, data);
    GaugeFixQuality<Float, Gauge, gauge_dir> GaugeFixQuality(argQ, data);

    GaugeFixArg<Float, Gauge> arg(dataOr, data, relax_boost);
    GaugeFix<Float, Gauge, gauge_dir> gaugeFix(arg, data);
    void *send[4];
    void *recv[4];
    void *sendg[4];
    void *recvg[4];
    void *send_d[4];
    void *recv_d[4];
    void *sendg_d[4];
    void *recvg_d[4];
    void *hostbuffer_h[4];
    qudaStream_t GFStream[9];
    size_t offset[4];
    size_t bytes[4];
    size_t faceVolume[4];
    size_t faceVolumeCB[4];
    // do the exchange
    MsgHandle *mh_recv_back[4];
    MsgHandle *mh_recv_fwd[4];
    MsgHandle *mh_send_fwd[4];
    MsgHandle *mh_send_back[4];
    int X[4];
    TuneParam tp[4];

    if ( comm_partitioned() ) {

      for ( int dir = 0; dir < 4; ++dir ) {
        X[dir] = data.X()[dir] - data.R()[dir] * 2;
        if ( !commDimPartitioned(dir) && data.R()[dir] != 0 ) errorQuda("Not supported!");
      }
      for ( int i = 0; i < 4; i++ ) {
        faceVolume[i] = 1;
        for ( int j = 0; j < 4; j++ ) {
          if ( i == j ) continue;
          faceVolume[i] *= X[j];
        }
        faceVolumeCB[i] = faceVolume[i] / 2;
      }

      for ( int d = 0; d < 4; d++ ) {
        if ( !commDimPartitioned(d) ) continue;
        offset[d] = faceVolumeCB[d] * NElems;
        bytes[d] = sizeof(Float) * offset[d];
        send_d[d] = device_malloc(bytes[d]);
        recv_d[d] = device_malloc(bytes[d]);
        sendg_d[d] = device_malloc(bytes[d]);
        recvg_d[d] = device_malloc(bytes[d]);
        cudaStreamCreate(&GFStream[d]);
        cudaStreamCreate(&GFStream[4 + d]);
#ifndef GPU_COMMS
        hostbuffer_h[d] = (void*)pinned_malloc(4 * bytes[d]);
#endif
        tp[d].block = make_uint3(128, 1, 1);
        tp[d].grid = make_uint3((faceVolumeCB[d] + tp[d].block.x - 1) / tp[d].block.x, 1, 1);
      }
      cudaStreamCreate(&GFStream[8]);
      for ( int d = 0; d < 4; d++ ) {
        if ( !commDimPartitioned(d) ) continue;
#ifdef GPU_COMMS
        recv[d] = recv_d[d];
        send[d] = send_d[d];
        recvg[d] = recvg_d[d];
        sendg[d] = sendg_d[d];
#else
        recv[d] = hostbuffer_h[d];
        send[d] = static_cast<char*>(hostbuffer_h[d]) + bytes[d];
        recvg[d] = static_cast<char*>(hostbuffer_h[d]) + 3 * bytes[d];
        sendg[d] = static_cast<char*>(hostbuffer_h[d]) + 2 * bytes[d];
#endif
        mh_recv_back[d] = comm_declare_receive_relative(recv[d], d, -1, bytes[d]);
        mh_recv_fwd[d] = comm_declare_receive_relative(recvg[d], d, +1, bytes[d]);
        mh_send_back[d] = comm_declare_send_relative(sendg[d], d, -1, bytes[d]);
        mh_send_fwd[d] = comm_declare_send_relative(send[d], d, +1, bytes[d]);
      }
    }
    GaugeFixUnPackArg<NElems, Gauge> dataexarg(dataOr, data);
    GaugeFixBorderPointsArg<Float, Gauge> argBorder(dataOr, data, relax_boost, faceVolume, faceVolumeCB);
    GaugeFixBorderPoints<Float, Gauge, gauge_dir> gfixBorderPoints(argBorder, data);
    GaugeFixInteriorPointsArg<Float, Gauge> argInt(dataOr, data, relax_boost);
    GaugeFixInteriorPoints<Float, Gauge, gauge_dir> gfixIntPoints(argInt, data);
    GaugeFixQuality.apply(0);
    flop += (double)GaugeFixQuality.flops();
    byte += (double)GaugeFixQuality.bytes();
    double action0 = argQ.getAction();
    printfQuda("Step: %d\tAction: %.16e\ttheta: %.16e\n", 0, argQ.getAction(), argQ.getTheta());

    unitarizeLinks(data, data, num_failures_dev);
    qudaMemcpy(&num_failures, num_failures_dev, sizeof(int), cudaMemcpyDeviceToHost);
    if ( num_failures > 0 ) {
      pool_device_free(num_failures_dev);
      errorQuda("Error in the unitarization\n");
    }
    qudaMemset(num_failures_dev, 0, sizeof(int));
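    // Explanatory note (not from the original source): each iteration below updates both parities.
    // On a single partition the whole lattice is relaxed with gaugeFix.  When the lattice is
    // partitioned, the border sites are relaxed first, their faces are packed and sent while the
    // interior is relaxed on a separate stream (GFStream[8]), and the received faces are unpacked
    // and synchronized before the next parity starts, overlapping communication with computation.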
    int iter = 0;
    for ( iter = 0; iter < Nsteps; iter++ ) {
      for ( int p = 0; p < 2; p++ ) {
#ifndef MULTI_GPU
        gaugeFix.setParity(p);
        gaugeFix.apply(0);
        flop += (double)gaugeFix.flops();
        byte += (double)gaugeFix.bytes();
#else
        if ( !comm_partitioned() ) {
          gaugeFix.setParity(p);
          gaugeFix.apply(0);
          flop += (double)gaugeFix.flops();
          byte += (double)gaugeFix.bytes();
        } else {
          gfixIntPoints.setParity(p);
          gfixBorderPoints.setParity(p); // compute border points
          gfixBorderPoints.apply(0);
          flop += (double)gfixBorderPoints.flops();
          byte += (double)gfixBorderPoints.bytes();
          flop += (double)gfixIntPoints.flops();
          byte += (double)gfixIntPoints.bytes();
          for ( int d = 0; d < 4; d++ ) {
            if ( !commDimPartitioned(d) ) continue;
            comm_start(mh_recv_back[d]);
            comm_start(mh_recv_fwd[d]);
          }
          // wait for the update to the halo points before starting to pack...
          qudaDeviceSynchronize();
          for ( int d = 0; d < 4; d++ ) {
            if ( !commDimPartitioned(d) ) continue;
            // extract top face
            qudaLaunchKernel(Kernel_UnPackTop<Float, true, decltype(dataexarg)>, tp[d], GFStream[d],
                             faceVolumeCB[d], dataexarg, reinterpret_cast<complex<Float>*>(send_d[d]), p, d, d);
            // extract bottom ghost
            qudaLaunchKernel(Kernel_UnPackGhost<Float, true, decltype(dataexarg)>, tp[d], GFStream[4 + d],
                             faceVolumeCB[d], dataexarg, reinterpret_cast<complex<Float>*>(sendg_d[d]), 1 - p, d, d);
          }
#ifdef GPU_COMMS
          for ( int d = 0; d < 4; d++ ) {
            if ( !commDimPartitioned(d) ) continue;
            qudaStreamSynchronize(GFStream[d]);
            comm_start(mh_send_fwd[d]);
            qudaStreamSynchronize(GFStream[4 + d]);
            comm_start(mh_send_back[d]);
          }
#else
          for ( int d = 0; d < 4; d++ ) {
            if ( !commDimPartitioned(d) ) continue;
            qudaMemcpyAsync(send[d], send_d[d], bytes[d], cudaMemcpyDeviceToHost, GFStream[d]);
          }
          for ( int d = 0; d < 4; d++ ) {
            if ( !commDimPartitioned(d) ) continue;
            qudaMemcpyAsync(sendg[d], sendg_d[d], bytes[d], cudaMemcpyDeviceToHost, GFStream[4 + d]);
          }
#endif
          // compute interior points
          gfixIntPoints.apply(GFStream[8]);

#ifndef GPU_COMMS
          for ( int d = 0; d < 4; d++ ) {
            if ( !commDimPartitioned(d) ) continue;
            qudaStreamSynchronize(GFStream[d]);
            comm_start(mh_send_fwd[d]);
            qudaStreamSynchronize(GFStream[4 + d]);
            comm_start(mh_send_back[d]);
          }
          for ( int d = 0; d < 4; d++ ) {
            if ( !commDimPartitioned(d) ) continue;
            comm_wait(mh_recv_back[d]);
            qudaMemcpyAsync(recv_d[d], recv[d], bytes[d], cudaMemcpyHostToDevice, GFStream[d]);
          }
          for ( int d = 0; d < 4; d++ ) {
            if ( !commDimPartitioned(d) ) continue;
            comm_wait(mh_recv_fwd[d]);
            qudaMemcpyAsync(recvg_d[d], recvg[d], bytes[d], cudaMemcpyHostToDevice, GFStream[4 + d]);
          }
#endif
          for ( int d = 0; d < 4; d++ ) {
            if ( !commDimPartitioned(d) ) continue;
#ifdef GPU_COMMS
            comm_wait(mh_recv_back[d]);
#endif
            qudaLaunchKernel(Kernel_UnPackGhost<Float, false, decltype(dataexarg)>, tp[d], GFStream[d],
                             faceVolumeCB[d], dataexarg, reinterpret_cast<complex<Float>*>(recv_d[d]), p, d, d);
          }
          for ( int d = 0; d < 4; d++ ) {
            if ( !commDimPartitioned(d) ) continue;
#ifdef GPU_COMMS
            comm_wait(mh_recv_fwd[d]);
#endif
            qudaLaunchKernel(Kernel_UnPackTop<Float, false, decltype(dataexarg)>, tp[d], GFStream[4 + d],
                             faceVolumeCB[d], dataexarg, reinterpret_cast<complex<Float>*>(recvg_d[d]), 1 - p, d, d);
          }
          for ( int d = 0; d < 4; d++ ) {
            if ( !commDimPartitioned(d) ) continue;
            comm_wait(mh_send_back[d]);
            comm_wait(mh_send_fwd[d]);
            qudaStreamSynchronize(GFStream[d]);
            qudaStreamSynchronize(GFStream[4 + d]);
          }
          qudaStreamSynchronize(GFStream[8]);
        }
#endif
      }
      if ((iter % reunit_interval) == (reunit_interval - 1)) {
        unitarizeLinks(data, data, num_failures_dev);
        qudaMemcpy(&num_failures, num_failures_dev, sizeof(int), cudaMemcpyDeviceToHost);
        if ( num_failures > 0 ) errorQuda("Error in the unitarization\n");
        qudaMemset(num_failures_dev, 0, sizeof(int));
        flop += 4588.0 * data.X()[0] * data.X()[1] * data.X()[2] * data.X()[3];
        byte += 8.0 * data.X()[0] * data.X()[1] * data.X()[2] * data.X()[3] * dataOr.Bytes();
      }
      GaugeFixQuality.apply(0);
      flop += (double)GaugeFixQuality.flops();
      byte += (double)GaugeFixQuality.bytes();
      double action = argQ.getAction();
      double diff = abs(action0 - action);
      if ((iter % verbose_interval) == (verbose_interval - 1))
        printfQuda("Step: %d\tAction: %.16e\ttheta: %.16e\tDelta: %.16e\n", iter + 1, argQ.getAction(), argQ.getTheta(), diff);
      if ( stopWtheta ) {
        if ( argQ.getTheta() < tolerance ) break;
      } else {
        if ( diff < tolerance ) break;
      }
      action0 = action;
    }

    if ((iter % reunit_interval) != 0) {
      unitarizeLinks(data, data, num_failures_dev);
      qudaMemcpy(&num_failures, num_failures_dev, sizeof(int), cudaMemcpyDeviceToHost);
      if ( num_failures > 0 ) errorQuda("Error in the unitarization\n");
      qudaMemset(num_failures_dev, 0, sizeof(int));
      flop += 4588.0 * data.X()[0] * data.X()[1] * data.X()[2] * data.X()[3];
      byte += 8.0 * data.X()[0] * data.X()[1] * data.X()[2] * data.X()[3] * dataOr.Bytes();
    }
    if ((iter % verbose_interval) != 0) {
      GaugeFixQuality.apply(0);
      flop += (double)GaugeFixQuality.flops();
      byte += (double)GaugeFixQuality.bytes();
      double action = argQ.getAction();
      double diff = abs(action0 - action);
      printfQuda("Step: %d\tAction: %.16e\ttheta: %.16e\tDelta: %.16e\n", iter + 1, argQ.getAction(), argQ.getTheta(), diff);
    }
    pool_device_free(num_failures_dev);
    if ( comm_partitioned() ) {
      data.exchangeExtendedGhost(data.R(), false);
      for ( int d = 0; d < 4; d++ ) {
        if ( commDimPartitioned(d) ) {
          comm_free(mh_send_fwd[d]);
          comm_free(mh_send_back[d]);
          comm_free(mh_recv_back[d]);
          comm_free(mh_recv_fwd[d]);
          device_free(send_d[d]);
          device_free(recv_d[d]);
          device_free(sendg_d[d]);
          device_free(recvg_d[d]);
          cudaStreamDestroy(GFStream[d]);
          cudaStreamDestroy(GFStream[4 + d]);
#ifndef GPU_COMMS
          host_free(hostbuffer_h[d]);
#endif
        }
      }
      cudaStreamDestroy(GFStream[8]);
    }
    qudaDeviceSynchronize();
    profileInternalGaugeFixOVR.TPSTOP(QUDA_PROFILE_COMPUTE);
    if (getVerbosity() > QUDA_SUMMARIZE) {
      double secs = profileInternalGaugeFixOVR.Last(QUDA_PROFILE_COMPUTE);
      double gflops = (flop * 1e-9) / secs;
      double gbytes = byte / (secs * 1e9);
      printfQuda("Time: %6.6f s, Gflop/s = %6.1f, GB/s = %6.1f\n", secs, gflops * comm_size(), gbytes * comm_size());
    }
  }
  template <typename Float, int nColor, QudaReconstructType recon> struct GaugeFixingOVR {
    GaugeFixingOVR(GaugeField &data, const int gauge_dir, const int Nsteps, const int verbose_interval,
                   const Float relax_boost, const double tolerance, const int reunit_interval, const int stopWtheta)
    {
      using Gauge = typename gauge_mapper<Float, recon>::type;
      if (gauge_dir == 4) {
        printfQuda("Starting Landau gauge fixing...\n");
        gaugefixingOVR<Float, Gauge, recon, 4>(Gauge(data), data, Nsteps, verbose_interval, relax_boost, tolerance, reunit_interval, stopWtheta);
      } else if (gauge_dir == 3) {
        printfQuda("Starting Coulomb gauge fixing...\n");
        gaugefixingOVR<Float, Gauge, recon, 3>(Gauge(data), data, Nsteps, verbose_interval, relax_boost, tolerance, reunit_interval, stopWtheta);
      } else {
        errorQuda("Unexpected gauge_dir = %d", gauge_dir);
      }
    }
  };
  /**
   * @brief Gauge fixing with overrelaxation, with support for single and multi GPU.
   * @param[in,out] data quda gauge field
   * @param[in] gauge_dir 3 for Coulomb gauge fixing, other for Landau gauge fixing
   * @param[in] Nsteps maximum number of steps to perform gauge fixing
   * @param[in] verbose_interval print gauge fixing info when the iteration count is a multiple of this
   * @param[in] relax_boost gauge fixing parameter of the overrelaxation method; the most common values are 1.5 or 1.7
   * @param[in] tolerance tolerance value used to stop the method; if zero, the method stops only when the iteration count reaches the maximum number of steps defined by Nsteps
   * @param[in] reunit_interval reunitarize the gauge field when the iteration count is a multiple of this
   * @param[in] stopWtheta 0 to use the MILC criterion and 1 to use the theta value
   */
  void gaugeFixingOVR(GaugeField &data, const int gauge_dir, const int Nsteps, const int verbose_interval,
                      const double relax_boost, const double tolerance, const int reunit_interval, const int stopWtheta)
  {
#ifdef GPU_GAUGE_ALG
    instantiate<GaugeFixingOVR>(data, gauge_dir, Nsteps, verbose_interval, relax_boost, tolerance, reunit_interval, stopWtheta);
#else
    errorQuda("Gauge fixing has not been built");
#endif // GPU_GAUGE_ALG
  }
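  // Usage sketch (illustrative only, not part of this file): fixing a gauge field to Coulomb
  // gauge with overrelaxation could look like
  //
  //   gaugeFixingOVR(*gauge,   // quda::GaugeField, extended by R on partitioned dimensions
  //                  3,        // gauge_dir: 3 = Coulomb, 4 = Landau
  //                  10000,    // Nsteps
  //                  100,      // verbose_interval
  //                  1.5,      // relax_boost
  //                  1e-15,    // tolerance
  //                  10,       // reunit_interval
  //                  1);       // stopWtheta: stop on theta rather than on the action delta
  //
  // where `gauge` is assumed to be a pointer to a device gauge field created by the caller and the
  // numerical values are placeholders chosen only for illustration.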