#include <quda_internal.h>
#include <quda_matrix.h>
#include <gauge_field.h>
#include <gauge_field_order.h>
#include <launch_kernel.cuh>
#include <unitarization_links.h>
#include <reduce_helper.h>
#include <index_helper.cuh>
#include <CUFFT_Plans.h>
#include <instantiate.h>

namespace quda {
//UNCOMMENT THIS IF YOU WANT TO USE LESS MEMORY
#define GAUGEFIXING_DONT_USE_GX
//Without the precalculation of g(x) we lose some performance, because Delta(x)
//has to be written in the normal lattice coordinates needed by the FFTs while
//the gauge array is kept in even/odd format
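//Rough memory comparison, based on the allocations in GaugeFixArg below: with
//precalculated g(x), gx holds Elems complex numbers per lattice site, while
//without it gx holds a single complex number per site (used only as FFT and
//rotation scratch), saving (Elems - 1) * Volume * sizeof(complex<Float>) bytes.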
#ifdef GAUGEFIXING_DONT_USE_GX
#warning Not using precalculated g(x)
#else
#warning Using precalculated g(x)
#endif
#ifndef FL_UNITARIZE_PI
#define FL_UNITARIZE_PI 3.14159265358979323846
#endif
template <typename Float>
struct GaugeFixFFTRotateArg {
  int threads; // number of active threads required
  int X[4]; // grid dimensions
  complex<Float> *tmp0;
  complex<Float> *tmp1;
  GaugeFixFFTRotateArg(const GaugeField &data){
    for ( int dir = 0; dir < 4; ++dir ) X[dir] = data.X()[dir];
    threads = X[0] * X[1] * X[2] * X[3];
    tmp0 = nullptr;
    tmp1 = nullptr;
  }
};
template <int direction, typename Arg>
__global__ void fft_rotate_kernel_2D2D(Arg arg){ //Cmplx *data_in, Cmplx *data_out){
  int id = blockIdx.x * blockDim.x + threadIdx.x;
  if ( id >= arg.threads ) return;
  if ( direction == 0 ) {
    int x3 = id / (arg.X[0] * arg.X[1] * arg.X[2]);
    int x2 = (id / (arg.X[0] * arg.X[1])) % arg.X[2];
    int x1 = (id / arg.X[0]) % arg.X[1];
    int x0 = id % arg.X[0];

    int id_in = x0 + (x1 + (x2 + x3 * arg.X[2]) * arg.X[1]) * arg.X[0];
    int id_out = x2 + (x3 + (x0 + x1 * arg.X[0]) * arg.X[3]) * arg.X[2];
    arg.tmp1[id_out] = arg.tmp0[id_in];
    //data_out[id_out] = data_in[id];
  }
  if ( direction == 1 ) {
    int x1 = id / (arg.X[2] * arg.X[3] * arg.X[0]);
    int x0 = (id / (arg.X[2] * arg.X[3])) % arg.X[0];
    int x3 = (id / arg.X[2]) % arg.X[3];
    int x2 = id % arg.X[2];

    int id_in = x2 + (x3 + (x0 + x1 * arg.X[0]) * arg.X[3]) * arg.X[2];
    int id_out = x0 + (x1 + (x2 + x3 * arg.X[2]) * arg.X[1]) * arg.X[0];
    arg.tmp1[id_out] = arg.tmp0[id_in];
    //data_out[id_out] = data_in[id];
  }
}
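//Example of the direction == 0 mapping on an X = (4,4,4,4) lattice: the site
//(x0,x1,x2,x3) read at the xyzt-ordered index x0 + 4*(x1 + 4*(x2 + 4*x3)) is
//written to the ztxy-ordered index x2 + 4*(x3 + 4*(x0 + 4*x1)), so z and t
//become the fastest-running (contiguous) coordinates for the zt-plane FFTs;
//direction == 1 is the inverse permutation.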
template <typename Float, typename Arg>
class GaugeFixFFTRotate : Tunable {
  Arg &arg;
  const GaugeField &meta;
  int direction;
  unsigned int sharedBytesPerThread() const { return 0; }
  unsigned int sharedBytesPerBlock(const TuneParam &param) const { return 0; }
  bool tuneGridDim() const { return false; }
  unsigned int minThreads() const { return arg.threads; }

public:
  GaugeFixFFTRotate(Arg &arg, const GaugeField &meta) :
    arg(arg),
    meta(meta),
    direction(0) { }

  void setDirection(int dir, complex<Float> *data_in, complex<Float> *data_out){
    direction = dir;
    arg.tmp0 = data_in;
    arg.tmp1 = data_out;
  }

  void apply(const qudaStream_t &stream){
    TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
    if ( direction == 0 ) qudaLaunchKernel(fft_rotate_kernel_2D2D<0, Arg>, tp, stream, arg);
    else if ( direction == 1 ) qudaLaunchKernel(fft_rotate_kernel_2D2D<1, Arg>, tp, stream, arg);
    else errorQuda("Error in GaugeFixFFTRotate option.\n");
  }

  TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), meta.AuxString()); }
  long long flops() const { return 0; }
  long long bytes() const { return 4LL * sizeof(Float) * arg.threads; }
};
template <typename Float, typename Gauge>
struct GaugeFixQualityArg : public ReduceArg<double2> {
  int threads; // number of active threads required
  int X[4]; // grid dimensions
  Gauge dataOr;
  complex<Float> *delta;
  double2 result;
  GaugeFixQualityArg(const Gauge &dataOr, const GaugeField &data, complex<Float> * delta)
    : ReduceArg<double2>(), dataOr(dataOr), delta(delta)
  {
    for ( int dir = 0; dir < 4; ++dir ) X[dir] = data.X()[dir];
    threads = data.VolumeCB();
  }
  double getAction() { return result.x; }
  double getTheta() { return result.y; }
};
template <int blockSize, int Elems, typename Float, typename Gauge, int gauge_dir>
__global__ void computeFix_quality(GaugeFixQualityArg<Float, Gauge> argQ)
{
  int idx_cb = threadIdx.x + blockIdx.x * blockDim.x;
  int parity = threadIdx.y;

  double2 data = make_double2(0.0,0.0);
  while (idx_cb < argQ.threads) {
    typedef complex<Float> Cmplx;

    int x[4];
    getCoords(x, idx_cb, argQ.X, parity);
    Matrix<Cmplx,3> delta;
    setZero(&delta);
    //idx = linkIndex(x,X);
    for ( int mu = 0; mu < gauge_dir; mu++ ) {
      Matrix<Cmplx,3> U = argQ.dataOr(mu, idx_cb, parity);
      delta -= U;
    }
    data.x += -delta(0, 0).x - delta(1, 1).x - delta(2, 2).x;
    for ( int mu = 0; mu < gauge_dir; mu++ ) {
      Matrix<Cmplx,3> U = argQ.dataOr(mu, linkIndexM1(x,argQ.X,mu), 1 - parity);
      delta += U;
    }
    delta -= conj(delta);

    int idx = getIndexFull(idx_cb, argQ.X, parity);
    argQ.delta[idx] = delta(0,0);
    argQ.delta[idx + 2 * argQ.threads] = delta(0,1);
    argQ.delta[idx + 4 * argQ.threads] = delta(0,2);
    argQ.delta[idx + 6 * argQ.threads] = delta(1,1);
    argQ.delta[idx + 8 * argQ.threads] = delta(1,2);
    argQ.delta[idx + 10 * argQ.threads] = delta(2,2);

    data.y += getRealTraceUVdagger(delta, delta);

    idx_cb += blockDim.x * gridDim.x;
  }
  argQ.template reduce2d<blockSize,2>(data);
}
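//The reduction above accumulates two volume sums: data.x adds Re Tr[ sum_mu U_mu(x) ]
//(the gauge fixing functional Fg) and data.y adds Re Tr[ Delta(x) Delta(x)^dagger ]
//(theta), where Delta(x) = A - A^dagger with A = sum_mu ( U_mu(x-mu) - U_mu(x) );
//GaugeFixQuality::apply() below divides out the volume, color, and direction factors.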
template<int Elems, typename Float, typename Gauge, int gauge_dir>
class GaugeFixQuality : TunableLocalParityReduction {
  GaugeFixQualityArg<Float, Gauge> &arg;
  const GaugeField &meta;

public:
  GaugeFixQuality(GaugeFixQualityArg<Float, Gauge> &arg, const GaugeField &meta) :
    arg(arg),
    meta(meta) { }

  void apply(const qudaStream_t &stream)
  {
    TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
    LAUNCH_KERNEL_LOCAL_PARITY(computeFix_quality, (*this), tp, stream, arg, Elems, Float, Gauge, gauge_dir);
    auto reset = true; // apply is called multiple times with the same arg instance so we need to reset
    arg.complete(arg.result, stream, reset);
    if (!activeTuning()) {
      arg.result.x /= (double)(3 * gauge_dir * 2 * arg.threads);
      arg.result.y /= (double)(3 * 2 * arg.threads);
    }
  }

  TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), meta.AuxString()); }
  long long flops() const { return (36LL * gauge_dir + 65LL) * 2 * arg.threads; }
  long long bytes() const { return (2LL * gauge_dir + 2LL) * Elems * 2 * arg.threads * sizeof(Float); }
};
template <typename Float>
struct GaugeFixArg {
  int threads; // number of active threads required
  int X[4]; // grid dimensions
  GaugeField &data;
  Float *invpsq;
  complex<Float> *delta;
  complex<Float> *gx;

  GaugeFixArg(GaugeField & data, const int Elems) : data(data){
    for ( int dir = 0; dir < 4; ++dir ) X[dir] = data.X()[dir];
    threads = X[0] * X[1] * X[2] * X[3];
    invpsq = (Float*)device_malloc(sizeof(Float) * threads);
    delta = (complex<Float>*)device_malloc(sizeof(complex<Float>) * threads * 6);
#ifdef GAUGEFIXING_DONT_USE_GX
    gx = (complex<Float>*)device_malloc(sizeof(complex<Float>) * threads);
#else
    gx = (complex<Float>*)device_malloc(sizeof(complex<Float>) * threads * Elems);
#endif
  }
  void dispose(){
    device_free(invpsq);
    device_free(delta);
    device_free(gx);
  }
};
template <typename Float>
__global__ void kernel_gauge_set_invpsq(GaugeFixArg<Float> arg){
  int id = blockIdx.x * blockDim.x + threadIdx.x;
  if ( id >= arg.threads ) return;
  int x1 = id / (arg.X[2] * arg.X[3] * arg.X[0]);
  int x0 = (id / (arg.X[2] * arg.X[3])) % arg.X[0];
  int x3 = (id / arg.X[2]) % arg.X[3];
  int x2 = id % arg.X[2];
  //id = x2 + (x3 + (x0 + x1 * arg.X[0]) * arg.X[3]) * arg.X[2];
  Float sx = sin( (Float)x0 * FL_UNITARIZE_PI / (Float)arg.X[0]);
  Float sy = sin( (Float)x1 * FL_UNITARIZE_PI / (Float)arg.X[1]);
  Float sz = sin( (Float)x2 * FL_UNITARIZE_PI / (Float)arg.X[2]);
  Float st = sin( (Float)x3 * FL_UNITARIZE_PI / (Float)arg.X[3]);
  Float sinsq = sx * sx + sy * sy + sz * sz + st * st;
  Float prcfact = 0.0;
  //The FFT normalization is done here
  if ( sinsq > 0.00001 ) prcfact = 4.0 / (sinsq * (Float)arg.threads);
  arg.invpsq[id] = prcfact;
}
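//For mode x on the lattice the momentum squared is p^2 = 4 * sum_mu sin^2( pi x_mu / L_mu )
//= 4 * sinsq, and its maximum is pmax^2 = 16, so the stored weight
//4 / (sinsq * V) equals (pmax^2 / p^2) / V: the pmax^2/p^2 factor applied in the
//iteration loop below, combined with the 1/V that cuFFT's unnormalized
//forward+inverse transform pair requires.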
template<typename Float>
class GaugeFixSETINVPSP : Tunable {
  GaugeFixArg<Float> arg;
  const GaugeField &meta;
  unsigned int sharedBytesPerThread() const { return 0; }
  unsigned int sharedBytesPerBlock(const TuneParam &param) const { return 0; }
  bool tuneSharedBytes() const { return false; }
  bool tuneGridDim() const { return false; }
  unsigned int minThreads() const { return arg.threads; }

public:
  GaugeFixSETINVPSP(GaugeFixArg<Float> &arg, const GaugeField &meta) :
    arg(arg),
    meta(meta) { }

  void apply(const qudaStream_t &stream){
    TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
    qudaLaunchKernel(kernel_gauge_set_invpsq<Float>, tp, stream, arg);
  }

  TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), meta.AuxString()); }
  long long flops() const { return 21 * arg.threads; }
  long long bytes() const { return sizeof(Float) * arg.threads; }
};
template<typename Float>
__global__ void kernel_gauge_mult_norm_2D(GaugeFixArg<Float> arg) {
  int id = blockIdx.x * blockDim.x + threadIdx.x;
  if ( id < arg.threads ) arg.gx[id] = arg.gx[id] * arg.invpsq[id];
}
template<typename Float>
class GaugeFixINVPSP : Tunable {
  GaugeFixArg<Float> arg;
  const GaugeField &meta;
  unsigned int sharedBytesPerThread() const { return 0; }
  unsigned int sharedBytesPerBlock(const TuneParam &param) const { return 0; }
  bool tuneGridDim() const { return false; }
  unsigned int minThreads() const { return arg.threads; }

public:
  GaugeFixINVPSP(GaugeFixArg<Float> &arg, const GaugeField &meta) :
    arg(arg),
    meta(meta) { }

  void apply(const qudaStream_t &stream) {
    TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
    qudaLaunchKernel(kernel_gauge_mult_norm_2D<Float>, tp, stream, arg);
  }

  TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), meta.AuxString()); }

  void preTune(){
    //since delta contents are irrelevant at this point, we can swap gx with delta
    complex<Float> *tmp = arg.gx;
    arg.gx = arg.delta;
    arg.delta = tmp;
  }
  void postTune(){
    complex<Float> *tmp = arg.gx;
    arg.gx = arg.delta;
    arg.delta = tmp;
  }
  long long flops() const { return 2LL * arg.threads; }
  long long bytes() const { return 5LL * sizeof(Float) * arg.threads; }
};
template <typename Float>
__host__ __device__ inline void reunit_link( Matrix<complex<Float>,3> &U ){

  complex<Float> t2((Float)0.0, (Float)0.0);
  Float t1 = 0.0;
  //first normalize first row
  //sum of squares of row
#pragma unroll
  for ( int c = 0; c < 3; c++ ) t1 += norm(U(0,c));
  t1 = (Float)1.0 / sqrt(t1);
  //used to normalize row
#pragma unroll
  for ( int c = 0; c < 3; c++ ) U(0,c) *= t1;
#pragma unroll
  for ( int c = 0; c < 3; c++ ) t2 += conj(U(0,c)) * U(1,c);
#pragma unroll
  for ( int c = 0; c < 3; c++ ) U(1,c) -= t2 * U(0,c);
  //normalize second row
  //sum of squares of row
  t1 = 0.0;
#pragma unroll
  for ( int c = 0; c < 3; c++ ) t1 += norm(U(1,c));
  t1 = (Float)1.0 / sqrt(t1);
  //used to normalize row
#pragma unroll
  for ( int c = 0; c < 3; c++ ) U(1, c) *= t1;
  //Reconstruct last row
  U(2,0) = conj(U(0,1) * U(1,2) - U(0,2) * U(1,1));
  U(2,1) = conj(U(0,2) * U(1,0) - U(0,0) * U(1,2));
  U(2,2) = conj(U(0,0) * U(1,1) - U(0,1) * U(1,0));
}
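//reunit_link is a Gram-Schmidt projection back onto SU(3): the first two rows are
//orthonormalized and the third is their conjugate cross product, which also fixes
//det(U) = 1. This is cheaper than a full polar/SVD unitarization and is adequate
//here since g = 1 + alpha/2 * Delta is already close to unitary for small alpha.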
#ifdef GAUGEFIXING_DONT_USE_GX
template <typename Float, typename Gauge>
__global__ void kernel_gauge_fix_U_EO_NEW(GaugeFixArg<Float> arg, Gauge dataOr, Float half_alpha)
{
  int id = threadIdx.x + blockIdx.x * blockDim.x;
  int parity = threadIdx.y + blockIdx.y * blockDim.y;
  if (id >= arg.threads/2) return;

  using complex = complex<Float>;
  using matrix = Matrix<complex, 3>;

  int x[4];
  getCoords(x, id, arg.X, parity);
  int idx = ((x[3] * arg.X[2] + x[2]) * arg.X[1] + x[1]) * arg.X[0] + x[0];

  matrix de;
  de(0,0) = arg.delta[idx + 0 * arg.threads];
  de(0,1) = arg.delta[idx + 1 * arg.threads];
  de(0,2) = arg.delta[idx + 2 * arg.threads];
  de(1,1) = arg.delta[idx + 3 * arg.threads];
  de(1,2) = arg.delta[idx + 4 * arg.threads];
  de(2,2) = arg.delta[idx + 5 * arg.threads];

  de(1,0) = complex(-de(0,1).real(), de(0,1).imag());
  de(2,0) = complex(-de(0,2).real(), de(0,2).imag());
  de(2,1) = complex(-de(1,2).real(), de(1,2).imag());

  matrix g;
  setIdentity(&g);
  g += de * half_alpha;

  reunit_link<Float>( g );

  for ( int mu = 0; mu < 4; mu++ ) {
    matrix U = dataOr(mu, id, parity);
    matrix g0;
    U = g * U;

    idx = linkNormalIndexP1(x,arg.X,mu);
    de(0,0) = arg.delta[idx + 0 * arg.threads];
    de(0,1) = arg.delta[idx + 1 * arg.threads];
    de(0,2) = arg.delta[idx + 2 * arg.threads];
    de(1,1) = arg.delta[idx + 3 * arg.threads];
    de(1,2) = arg.delta[idx + 4 * arg.threads];
    de(2,2) = arg.delta[idx + 5 * arg.threads];

    de(1,0) = complex(-de(0,1).real(), de(0,1).imag());
    de(2,0) = complex(-de(0,2).real(), de(0,2).imag());
    de(2,1) = complex(-de(1,2).real(), de(1,2).imag());

    setIdentity(&g0);
    g0 += de * half_alpha;

    reunit_link<Float>( g0 );

    U = U * conj(g0);
    dataOr(mu, id, parity) = U;
  }
}
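//Taken together, the kernel above applies one steepest descent update without a
//precalculated g(x): it builds g(x) = reunit(1 + alpha/2 * Delta(x)) on the fly
//and replaces each link with U_mu(x) -> g(x) U_mu(x) g(x+mu)^dagger (conj() on a
//Matrix is the adjoint), reading Delta at x and at the lexicographic neighbor x+mu.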
template<typename Float, typename Gauge>
class GaugeFixNEW : TunableVectorY {
  GaugeFixArg<Float> arg;
  const GaugeField &meta;
  Gauge dataOr;
  Float half_alpha;

  bool tuneGridDim() const { return false; }
  // since GaugeFixArg is used by other kernels that don't keep
  // parity separate, arg.threads stores Volume and not VolumeCB so
  // we need to divide by two
  unsigned int minThreads() const { return arg.threads/2; }

public:
  GaugeFixNEW(Gauge & dataOr, GaugeFixArg<Float> &arg, Float alpha, const GaugeField &meta) :
    TunableVectorY(2),
    arg(arg),
    meta(meta),
    dataOr(dataOr)
  {
    half_alpha = alpha * 0.5;
  }

  void setAlpha(Float alpha){ half_alpha = alpha * 0.5; }

  void apply(const qudaStream_t &stream){
    TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
    qudaLaunchKernel(kernel_gauge_fix_U_EO_NEW<Float, Gauge>, tp, stream, arg, dataOr, half_alpha);
  }

  TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), meta.AuxString()); }
  void preTune() { arg.data.backup(); }
  void postTune() { arg.data.restore(); }
  long long flops() const { return 2414LL * arg.threads; }
  long long bytes() const { return ( dataOr.Bytes() * 4LL + 5 * 12LL * sizeof(Float)) * arg.threads; }
};

#else
template <int Elems, typename Float>
__global__ void kernel_gauge_GX(GaugeFixArg<Float> arg, Float half_alpha)
{
  int id = blockIdx.x * blockDim.x + threadIdx.x;
  if (id >= arg.threads) return;

  using complex = complex<Float>;

  Matrix<complex,3> de;
  de(0,0) = arg.delta[id];
  de(0,1) = arg.delta[id + arg.threads];
  de(0,2) = arg.delta[id + 2 * arg.threads];
  de(1,1) = arg.delta[id + 3 * arg.threads];
  de(1,2) = arg.delta[id + 4 * arg.threads];
  de(2,2) = arg.delta[id + 5 * arg.threads];

  de(1,0) = complex(-de(0,1).x, de(0,1).y);
  de(2,0) = complex(-de(0,2).x, de(0,2).y);
  de(2,1) = complex(-de(1,2).x, de(1,2).y);

  Matrix<complex, 3> g;
  setIdentity(&g);
  g += de * half_alpha;

  reunit_link<Float>( g );

  //gx is represented in even/odd order
  //normal lattice index to even/odd index
  int x3 = id / (arg.X[0] * arg.X[1] * arg.X[2]);
  int x2 = (id / (arg.X[0] * arg.X[1])) % arg.X[2];
  int x1 = (id / arg.X[0]) % arg.X[1];
  int x0 = id % arg.X[0];
  id = (x0 + (x1 + (x2 + x3 * arg.X[2]) * arg.X[1]) * arg.X[0]) >> 1;
  id += ((x0 + x1 + x2 + x3) & 1 ) * arg.threads / 2;
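  //Worked example on an X = (4,4,4,4) lattice (Volume = 256): site (1,0,0,0) has
  //lexicographic index 1 and odd parity, so id = (1 >> 1) + 1 * 256/2 = 128;
  //site (2,0,0,0) is even and maps to id = (2 >> 1) + 0 = 1.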
  for ( int i = 0; i < Elems; i++ ) arg.gx[id + i * arg.threads] = g.data[i];
}
template<int Elems, typename Float>
class GaugeFix_GX : Tunable {
  GaugeFixArg<Float> arg;
  const GaugeField &meta;
  Float half_alpha;
  unsigned int sharedBytesPerThread() const { return 0; }
  unsigned int sharedBytesPerBlock(const TuneParam &param) const { return 0; }
  bool tuneGridDim() const { return false; }
  unsigned int minThreads() const { return arg.threads; }

public:
  GaugeFix_GX(GaugeFixArg<Float> &arg, Float alpha, const GaugeField &meta) :
    arg(arg),
    meta(meta)
  {
    half_alpha = alpha * 0.5;
  }

  void setAlpha(Float alpha) { half_alpha = alpha * 0.5; }

  void apply(const qudaStream_t &stream){
    TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
    qudaLaunchKernel(kernel_gauge_GX<Elems, Float>, tp, stream, arg, half_alpha);
  }

  TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), meta.AuxString()); }

  long long flops() const {
    if ( Elems == 6 ) return 208LL * arg.threads;
    else return 166LL * arg.threads;
  }
  long long bytes() const { return 4LL * Elems * sizeof(Float) * arg.threads; }
};
template <int Elems, typename Float, typename Gauge>
__global__ void kernel_gauge_fix_U_EO( GaugeFixArg<Float> arg, Gauge dataOr)
{
  int idd = threadIdx.x + blockIdx.x * blockDim.x;
  if ( idd >= arg.threads ) return;

  int parity = 0;
  int id = idd;
  if ( idd >= arg.threads / 2 ) {
    parity = 1;
    id -= arg.threads / 2;
  }
  typedef complex<Float> Cmplx;

  Matrix<Cmplx,3> g;
  //for(int i = 0; i < Elems; i++) g.data[i] = arg.gx[idd + i * arg.threads];
  for ( int i = 0; i < Elems; i++ ) {
    g.data[i] = arg.gx[idd + i * arg.threads];
  }
  if ( Elems == 6 ) {
    g(2,0) = conj(g(0,1) * g(1,2) - g(0,2) * g(1,1));
    g(2,1) = conj(g(0,2) * g(1,0) - g(0,0) * g(1,2));
    g(2,2) = conj(g(0,0) * g(1,1) - g(0,1) * g(1,0));
  }
  int x[4];
  getCoords(x, id, arg.X, parity);
  for ( int mu = 0; mu < 4; mu++ ) {
    Matrix<Cmplx,3> U = dataOr(mu, id, parity);
    Matrix<Cmplx,3> g0;
    U = g * U;

    int idm1 = linkIndexP1(x,arg.X,mu);
    idm1 += (1 - parity) * arg.threads / 2;
    //for(int i = 0; i < Elems; i++) g0.data[i] = arg.gx[idm1 + i * arg.threads];
    for ( int i = 0; i < Elems; i++ ) {
      g0.data[i] = arg.gx[idm1 + i * arg.threads];
    }
    if ( Elems == 6 ) {
      g0(2,0) = conj(g0(0,1) * g0(1,2) - g0(0,2) * g0(1,1));
      g0(2,1) = conj(g0(0,2) * g0(1,0) - g0(0,0) * g0(1,2));
      g0(2,2) = conj(g0(0,0) * g0(1,1) - g0(0,1) * g0(1,0));
    }
    U = U * conj(g0);
    dataOr(mu, id, parity) = U;
  }
}
//T=42+4*(198*2+42) flops when Elems=6
//T=4*(198*2) flops when Elems=9
//Not accounting for the gauge reconstruction when recon is 12 or 8
template<int Elems, typename Float, typename Gauge>
class GaugeFix : Tunable {
  GaugeFixArg<Float> arg;
  const GaugeField &meta;
  Gauge dataOr;
  unsigned int sharedBytesPerThread() const { return 0; }
  unsigned int sharedBytesPerBlock(const TuneParam &param) const { return 0; }
  bool tuneGridDim() const { return false; }
  unsigned int minThreads() const { return arg.threads; }

public:
  GaugeFix(Gauge & dataOr, GaugeFixArg<Float> &arg, const GaugeField &meta) :
    arg(arg),
    meta(meta),
    dataOr(dataOr) { }

  void apply(const qudaStream_t &stream) {
    TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
    qudaLaunchKernel(kernel_gauge_fix_U_EO<Elems, Float, Gauge>, tp, stream, arg, dataOr);
  }

  TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), meta.AuxString()); }

  void preTune() { arg.data.backup(); }
  void postTune() { arg.data.restore(); }
  long long flops() const {
    if ( Elems == 6 ) return 1794LL * arg.threads;
    else return 1536LL * arg.threads;
  }
  long long bytes() const { return 26LL * Elems * sizeof(Float) * arg.threads; }
};

#endif //GAUGEFIXING_DONT_USE_GX
template<int Elems, typename Float, typename Gauge, int gauge_dir>
void gaugefixingFFT(Gauge dataOr, GaugeField& data, const int Nsteps, const int verbose_interval,
                    const Float alpha0, const int autotune, const double tolerance, const int stopWtheta)
{
  TimeProfile profileInternalGaugeFixFFT("InternalGaugeFixQudaFFT", false);

  profileInternalGaugeFixFFT.TPSTART(QUDA_PROFILE_COMPUTE);

  Float alpha = alpha0;
  std::cout << "\tAlpha parameter of the Steepest Descent Method: " << alpha << std::endl;
  if ( autotune ) std::cout << "\tAuto tune active: yes" << std::endl;
  else std::cout << "\tAuto tune active: no" << std::endl;
  std::cout << "\tStop criterion: " << tolerance << std::endl;
  if ( stopWtheta ) std::cout << "\tStop criterion method: theta" << std::endl;
  else std::cout << "\tStop criterion method: Delta" << std::endl;
  std::cout << "\tMaximum number of iterations: " << Nsteps << std::endl;
  std::cout << "\tPrint convergence results at every " << verbose_interval << " steps" << std::endl;
  unsigned int delta_pad = data.X()[0] * data.X()[1] * data.X()[2] * data.X()[3];
  int4 size = make_int4( data.X()[0], data.X()[1], data.X()[2], data.X()[3] );
  cufftHandle plan_xy;
  cufftHandle plan_zt;

  GaugeFixArg<Float> arg(data, Elems);
  SetPlanFFT2DMany( plan_zt, size, 0, arg.delta); //for space and time ZT
  SetPlanFFT2DMany( plan_xy, size, 1, arg.delta); //with space only XY

  GaugeFixFFTRotateArg<Float> arg_rotate(data);
  GaugeFixFFTRotate<Float, decltype(arg_rotate)> GFRotate(arg_rotate, data);

  GaugeFixSETINVPSP<Float> setinvpsp(arg, data);
  setinvpsp.apply(0);
  GaugeFixINVPSP<Float> invpsp(arg, data);

#ifdef GAUGEFIXING_DONT_USE_GX
  //without using GX, gx is still allocated for the plane rotations, but with smaller size
  GaugeFixNEW<Float, Gauge> gfixNew(dataOr, arg, alpha, data);
#else
  GaugeFix_GX<Elems, Float> calcGX(arg, alpha, data);
  GaugeFix<Elems, Float, Gauge> gfix(dataOr, arg, data);
#endif

  GaugeFixQualityArg<Float, Gauge> argQ(dataOr, data, arg.delta);
  GaugeFixQuality<Elems, Float, Gauge, gauge_dir> gfixquality(argQ, data);

  gfixquality.apply(0);
  double action0 = argQ.getAction();
  printf("Step: %d\tAction: %.16e\ttheta: %.16e\n", 0, argQ.getAction(), argQ.getTheta());

  double diff = 0.0;
  int iter = 0;
  for ( iter = 0; iter < Nsteps; iter++ ) {
    for ( int k = 0; k < 6; k++ ) {
      //------------------------------------------------------------------------
      // Set a pointer to the element k in lattice volume
      // each element is stored with stride lattice volume
      // it uses gx as a temporary array
      //------------------------------------------------------------------------
      complex<Float> *_array = arg.delta + k * delta_pad;
      ////// 2D FFT + 2D FFT
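      //------------------------------------------------------------------------
      // The 4D FFT is decomposed into two batched 2D FFTs: the plans created by
      // SetPlanFFT2DMany act on the two fastest-running coordinates, so we FFT
      // the xy planes first, rotate the hypercube so that z and t become the
      // fastest-running coordinates, and then FFT the zt planes; the inverse
      // transforms below undo both steps.
      //------------------------------------------------------------------------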
      //------------------------------------------------------------------------
      // Perform FFT on xy plane
      //------------------------------------------------------------------------
      ApplyFFT(plan_xy, _array, arg.gx, CUFFT_FORWARD);
      //------------------------------------------------------------------------
      // Rotate hypercube, xyzt -> ztxy
      //------------------------------------------------------------------------
      GFRotate.setDirection(0, arg.gx, _array);
      GFRotate.apply(0);
      //------------------------------------------------------------------------
      // Perform FFT on zt plane
      //------------------------------------------------------------------------
      ApplyFFT(plan_zt, _array, arg.gx, CUFFT_FORWARD);
      //------------------------------------------------------------------------
      // Normalize FFT and apply pmax^2/p^2
      //------------------------------------------------------------------------
      invpsp.apply(0);
      //------------------------------------------------------------------------
      // Perform IFFT on zt plane
      //------------------------------------------------------------------------
      ApplyFFT(plan_zt, arg.gx, _array, CUFFT_INVERSE);
      //------------------------------------------------------------------------
      // Rotate hypercube, ztxy -> xyzt
      //------------------------------------------------------------------------
      GFRotate.setDirection(1, _array, arg.gx);
      GFRotate.apply(0);
      //------------------------------------------------------------------------
      // Perform IFFT on xy plane
      //------------------------------------------------------------------------
      ApplyFFT(plan_xy, arg.gx, _array, CUFFT_INVERSE);
    }
#ifdef GAUGEFIXING_DONT_USE_GX
    //------------------------------------------------------------------------
    // Apply gauge fix to current gauge field
    //------------------------------------------------------------------------
    gfixNew.apply(0);
#else
    //------------------------------------------------------------------------
    // Calculate g(x)
    //------------------------------------------------------------------------
    calcGX.apply(0);
    //------------------------------------------------------------------------
    // Apply gauge fix to current gauge field
    //------------------------------------------------------------------------
    gfix.apply(0);
#endif
    //------------------------------------------------------------------------
    // Measure gauge quality and recalculate new Delta(x)
    //------------------------------------------------------------------------
    gfixquality.apply(0);
    double action = argQ.getAction();
    diff = abs(action0 - action);
    if ((iter % verbose_interval) == (verbose_interval - 1))
      printf("Step: %d\tAction: %.16e\ttheta: %.16e\tDelta: %.16e\n", iter + 1, argQ.getAction(), argQ.getTheta(), diff);
    if ( autotune && ((action - action0) < -1e-14) ) {
      if ( alpha > 0.01 ) {
        alpha = 0.95 * alpha;
#ifdef GAUGEFIXING_DONT_USE_GX
        gfixNew.setAlpha(alpha);
#else
        calcGX.setAlpha(alpha);
#endif
        printf(">>>>>>>>>>>>>> Warning: changing alpha down -> %.4e\n", alpha );
      }
    }
    //------------------------------------------------------------------------
    // Check gauge fix quality criterion
    //------------------------------------------------------------------------
    if ( stopWtheta ) { if ( argQ.getTheta() < tolerance ) break; }
    else { if ( diff < tolerance ) break; }

    action0 = action;
  }
  if ((iter % verbose_interval) != 0 )
    printf("Step: %d\tAction: %.16e\ttheta: %.16e\tDelta: %.16e\n", iter, argQ.getAction(), argQ.getTheta(), diff);
  // Reunitarize at end
  const double unitarize_eps = 1e-14;
  const double max_error = 1e-10;
  const int reunit_allow_svd = 1;
  const int reunit_svd_only = 0;
  const double svd_rel_error = 1e-6;
  const double svd_abs_error = 1e-6;
  setUnitarizeLinksConstants(unitarize_eps, max_error,
                             reunit_allow_svd, reunit_svd_only,
                             svd_rel_error, svd_abs_error);
  int num_failures = 0;
  int* num_failures_dev = static_cast<int*>(pool_device_malloc(sizeof(int)));
  qudaMemset(num_failures_dev, 0, sizeof(int));
  unitarizeLinks(data, data, num_failures_dev);
  qudaMemcpy(&num_failures, num_failures_dev, sizeof(int), cudaMemcpyDeviceToHost);

  pool_device_free(num_failures_dev);
  if ( num_failures > 0 ) {
    errorQuda("Error in the unitarization\n");
  }
  arg.dispose();
  CUFFT_SAFE_CALL(cufftDestroy(plan_zt));
  CUFFT_SAFE_CALL(cufftDestroy(plan_xy));
  qudaDeviceSynchronize();
  profileInternalGaugeFixFFT.TPSTOP(QUDA_PROFILE_COMPUTE);
  if (getVerbosity() > QUDA_SUMMARIZE){
    double secs = profileInternalGaugeFixFFT.Last(QUDA_PROFILE_COMPUTE);
    double fftflop = 5.0 * (log2((double)( data.X()[0] * data.X()[1]) ) + log2( (double)(data.X()[2] * data.X()[3] )));
    fftflop *= (double)( data.X()[0] * data.X()[1] * data.X()[2] * data.X()[3] );
    double gflops = setinvpsp.flops() + gfixquality.flops();
    double gbytes = setinvpsp.bytes() + gfixquality.bytes();
    double flop = invpsp.flops() * Elems;
    double byte = invpsp.bytes() * Elems;
    flop += (GFRotate.flops() + fftflop) * Elems * 2;
    byte += GFRotate.bytes() * Elems * 4; //includes FFT reads, assuming 1 read and 1 write per site
#ifdef GAUGEFIXING_DONT_USE_GX
    flop += gfixNew.flops();
    byte += gfixNew.bytes();
#else
    flop += calcGX.flops();
    byte += calcGX.bytes();
    flop += gfix.flops();
    byte += gfix.bytes();
#endif
    flop += gfixquality.flops();
    byte += gfixquality.bytes();
    gflops += flop * iter;
    gbytes += byte * iter;
    gflops += 4588.0 * data.X()[0]*data.X()[1]*data.X()[2]*data.X()[3]; //Reunitarize at end
    gbytes += 8.0 * data.X()[0]*data.X()[1]*data.X()[2]*data.X()[3] * dataOr.Bytes(); //Reunitarize at end

    gflops = (gflops * 1e-9) / (secs);
    gbytes = gbytes / (secs * 1e9);
    printfQuda("Time: %6.6f s, Gflop/s = %6.1f, GB/s = %6.1f\n", secs, gflops, gbytes);
  }
}
template<typename Float, int nColors, QudaReconstructType recon> struct GaugeFixingFFT {
  GaugeFixingFFT(GaugeField& data, const int gauge_dir, const int Nsteps, const int verbose_interval, const Float alpha,
                 const int autotune, const double tolerance, const int stopWtheta)
  {
    using Gauge = typename gauge_mapper<Float, recon>::type;
    constexpr int n_element = recon / 2; // number of complex elements used to store g(x) and Delta(x)
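    // recon = 18 gives n_element = 9 (a full 3x3 complex matrix), while recon = 12
    // gives n_element = 6 (only the first two rows stored, the third rebuilt in the
    // kernels via the Elems == 6 branches).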
    if ( gauge_dir != 3 ) {
      printfQuda("Starting Landau gauge fixing with FFTs...\n");
      gaugefixingFFT<n_element, Float, Gauge, 4>(Gauge(data), data, Nsteps, verbose_interval, alpha, autotune, tolerance, stopWtheta);
    } else {
      printfQuda("Starting Coulomb gauge fixing with FFTs...\n");
      gaugefixingFFT<n_element, Float, Gauge, 3>(Gauge(data), data, Nsteps, verbose_interval, alpha, autotune, tolerance, stopWtheta);
    }
  }
};
/**
 * @brief Gauge fixing with the steepest descent method with FFTs, with support for a single GPU only.
 * @param[in,out] data, quda gauge field
 * @param[in] gauge_dir, 3 for Coulomb gauge fixing, other for Landau gauge fixing
 * @param[in] Nsteps, maximum number of steps to perform gauge fixing
 * @param[in] verbose_interval, print gauge fixing info when iteration count is a multiple of this
 * @param[in] alpha, gauge fixing parameter of the method, most common value is 0.08
 * @param[in] autotune, 1 to autotune the method, i.e., if the action Fg inverts its tendency we decrease the alpha value
 * @param[in] tolerance, tolerance value to stop the method; if this value is zero, the method stops when the iteration count reaches the maximum number of steps defined by Nsteps
 * @param[in] stopWtheta, 0 for the MILC criterion and 1 to use the theta value
 */
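//A minimal calling sketch (hypothetical parameter values; `U` stands for any
//device-resident cudaGaugeField prepared elsewhere):
//
//  const int gauge_dir = 3;            // 3 selects Coulomb gauge, else Landau
//  const int Nsteps = 10000;           // iteration cap
//  const int verbose_interval = 100;   // report every 100 steps
//  const double alpha = 0.08;          // common default noted above
//  const int autotune = 1;             // let alpha decrease if Fg regresses
//  const double tolerance = 1e-15;     // stopping threshold
//  const int stopWtheta = 1;           // stop on theta instead of Delta
//  gaugeFixingFFT(U, gauge_dir, Nsteps, verbose_interval, alpha,
//                 autotune, tolerance, stopWtheta);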
void gaugeFixingFFT(GaugeField& data, const int gauge_dir, const int Nsteps, const int verbose_interval, const double alpha,
                    const int autotune, const double tolerance, const int stopWtheta)
{
#ifdef GPU_GAUGE_ALG
  if (comm_dim_partitioned(0) || comm_dim_partitioned(1) || comm_dim_partitioned(2) || comm_dim_partitioned(3))
    errorQuda("Gauge fixing with FFTs is not yet supported in multi-GPU mode");

  instantiate<GaugeFixingFFT>(data, gauge_dir, Nsteps, verbose_interval, (float)alpha, autotune, tolerance, stopWtheta);
#else
  errorQuda("Gauge fixing has not been built");
#endif
}

} // namespace quda