QUDA  v1.1.0
A library for QCD on GPUs
gauge_fix_fft.cu
1 #include <quda_internal.h>
2 #include <quda_matrix.h>
3 #include <tune_quda.h>
4 #include <gauge_field.h>
5 #include <gauge_field_order.h>
6 #include <launch_kernel.cuh>
7 #include <unitarization_links.h>
8 #include <atomic.cuh>
9 #include <reduce_helper.h>
10 #include <index_helper.cuh>
11 
12 #include <cufft.h>
13 #include <CUFFT_Plans.h>
14 #include <instantiate.h>
15 
16 namespace quda {
17 
18 //UNCOMMENT THIS IF YOU WANT TO USE LESS MEMORY
19 #define GAUGEFIXING_DONT_USE_GX
20 //Without precalculating g(x) we lose some performance, because Delta(x) is
21 //written in normal lattice coordinates (as needed by the FFTs) while the
22 //gauge array is stored in even/odd format
23 
24 #ifdef HOST_DEBUG
25 #ifdef GAUGEFIXING_DONT_USE_GX
26 #warning Not using precalculated g(x)
27 #else
28 #warning Using precalculated g(x)
29 #endif
30 #endif
31 
32 #ifndef FL_UNITARIZE_PI
33 #define FL_UNITARIZE_PI 3.14159265358979323846
34 #endif
35 
36  template <typename Float>
37  struct GaugeFixFFTRotateArg {
38  int threads; // number of active threads required
39  int X[4]; // grid dimensions
40  complex<Float> *tmp0;
41  complex<Float> *tmp1;
42  GaugeFixFFTRotateArg(const GaugeField &data){
43  for ( int dir = 0; dir < 4; ++dir ) X[dir] = data.X()[dir];
44  threads = X[0] * X[1] * X[2] * X[3];
45  tmp0 = 0;
46  tmp1 = 0;
47  }
48  };
49 
50  template <int direction, typename Arg>
51  __global__ void fft_rotate_kernel_2D2D(Arg arg){
52  int id = blockIdx.x * blockDim.x + threadIdx.x;
53  if ( id >= arg.threads ) return;
54  if ( direction == 0 ) {
55  int x3 = id / (arg.X[0] * arg.X[1] * arg.X[2]);
56  int x2 = (id / (arg.X[0] * arg.X[1])) % arg.X[2];
57  int x1 = (id / arg.X[0]) % arg.X[1];
58  int x0 = id % arg.X[0];
59 
60  int id = x0 + (x1 + (x2 + x3 * arg.X[2]) * arg.X[1]) * arg.X[0];
61  int id_out = x2 + (x3 + (x0 + x1 * arg.X[0]) * arg.X[3]) * arg.X[2];
62  arg.tmp1[id_out] = arg.tmp0[id];
63  //data_out[id_out] = data_in[id];
64  }
65  if ( direction == 1 ) {
66 
67  int x1 = id / (arg.X[2] * arg.X[3] * arg.X[0]);
68  int x0 = (id / (arg.X[2] * arg.X[3])) % arg.X[0];
69  int x3 = (id / arg.X[2]) % arg.X[3];
70  int x2 = id % arg.X[2];
71 
72  int id = x2 + (x3 + (x0 + x1 * arg.X[0]) * arg.X[3]) * arg.X[2];
73  int id_out = x0 + (x1 + (x2 + x3 * arg.X[2]) * arg.X[1]) * arg.X[0];
74  arg.tmp1[id_out] = arg.tmp0[id];
75  //data_out[id_out] = data_in[id];
76  }
77  }
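 // [editor's note] A minimal host-side sketch, added for illustration (the helper
 // below is hypothetical and not part of QUDA): direction 0 of the kernel above
 // remaps a lexicographic xyzt site index to ztxy order so that cuFFT can batch
 // 2D FFTs over the zt planes; direction 1 is its exact inverse.
 inline int rotate_xyzt_to_ztxy(int id, const int X[4])
 {
   int x3 = id / (X[0] * X[1] * X[2]); // slowest-running coordinate
   int x2 = (id / (X[0] * X[1])) % X[2];
   int x1 = (id / X[0]) % X[1];
   int x0 = id % X[0];
   // in the output z runs fastest, then t, then x, then y
   return x2 + (x3 + (x0 + x1 * X[0]) * X[3]) * X[2];
 }
 // Applying direction 0 followed by direction 1 returns every site to its
 // original position, which the driver below relies on.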
78 
79  template <typename Float, typename Arg>
80  class GaugeFixFFTRotate : Tunable {
81  Arg &arg;
82  const GaugeField &meta;
83  int direction;
84  unsigned int sharedBytesPerThread() const { return 0; }
85  unsigned int sharedBytesPerBlock(const TuneParam &param) const { return 0; }
86  bool tuneGridDim() const { return false; }
87  unsigned int minThreads() const { return arg.threads; }
88 
89  public:
90  GaugeFixFFTRotate(Arg &arg, const GaugeField &meta) :
91  arg(arg),
92  meta(meta)
93  {
94  direction = 0;
95  }
96 
97  void setDirection(int dir, complex<Float> *data_in, complex<Float> *data_out){
98  direction = dir;
99  arg.tmp0 = data_in;
100  arg.tmp1 = data_out;
101  }
102 
103  void apply(const qudaStream_t &stream){
104  TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
105  if ( direction == 0 ) qudaLaunchKernel(fft_rotate_kernel_2D2D<0, Arg>, tp, stream, arg);
106  else if ( direction == 1 ) qudaLaunchKernel(fft_rotate_kernel_2D2D<1, Arg>, tp, stream, arg);
107  else errorQuda("Invalid direction %d in GaugeFixFFTRotate", direction);
108  }
109 
110  TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), meta.AuxString()); }
111  long long flops() const { return 0; }
112  long long bytes() const { return 4LL * sizeof(Float) * arg.threads; }
113  };
114 
115  template <typename Float, typename Gauge>
116  struct GaugeFixQualityArg : public ReduceArg<double2> {
117  int threads; // number of active threads required
118  int X[4]; // grid dimensions
119  Gauge dataOr;
120  complex<Float> *delta;
121  double2 result;
122 
123  GaugeFixQualityArg(const Gauge &dataOr, const GaugeField &data, complex<Float> * delta)
124  : ReduceArg<double2>(), dataOr(dataOr), delta(delta)
125  {
126  for ( int dir = 0; dir < 4; ++dir ) X[dir] = data.X()[dir];
127  threads = data.VolumeCB();
128  }
129  double getAction() { return result.x; }
130  double getTheta() { return result.y; }
131  };
132 
133  template <int blockSize, int Elems, typename Float, typename Gauge, int gauge_dir>
134  __global__ void computeFix_quality(GaugeFixQualityArg<Float, Gauge> argQ)
135  {
136  int idx_cb = threadIdx.x + blockIdx.x * blockDim.x;
137  int parity = threadIdx.y;
138 
139  double2 data = make_double2(0.0,0.0);
140  while (idx_cb < argQ.threads) {
141  typedef complex<Float> Cmplx;
142 
143  int x[4];
144  getCoords(x, idx_cb, argQ.X, parity);
145  Matrix<Cmplx,3> delta;
146  setZero(&delta);
147  //idx = linkIndex(x,X);
148  for ( int mu = 0; mu < gauge_dir; mu++ ) {
149  Matrix<Cmplx,3> U = argQ.dataOr(mu, idx_cb, parity);
150  delta -= U;
151  }
152  //18*gauge_dir
153  data.x += -delta(0, 0).x - delta(1, 1).x - delta(2, 2).x;
154  //2
155  for ( int mu = 0; mu < gauge_dir; mu++ ) {
156  Matrix<Cmplx,3> U = argQ.dataOr(mu, linkIndexM1(x,argQ.X,mu), 1 - parity);
157  delta += U;
158  }
159  //18*gauge_dir
160  delta -= conj(delta);
161  //18
162  //save Delta, in full-lattice (lexicographic) order for the FFTs
163  SubTraceUnit(delta);
164  int idx = getIndexFull(idx_cb, argQ.X, parity);
165  //each of the six elements is stored with stride 2 * threads (= volume)
166  argQ.delta[idx] = delta(0,0);
167  argQ.delta[idx + 2 * argQ.threads] = delta(0,1);
168  argQ.delta[idx + 4 * argQ.threads] = delta(0,2);
169  argQ.delta[idx + 6 * argQ.threads] = delta(1,1);
170  argQ.delta[idx + 8 * argQ.threads] = delta(1,2);
171  argQ.delta[idx + 10 * argQ.threads] = delta(2,2);
172  //12
173  data.y += getRealTraceUVdagger(delta, delta);
174  //35
175  //T=36*gauge_dir+65
176 
177  idx_cb += blockDim.x * gridDim.x;
178  }
179 
180  argQ.template reduce2d<blockSize,2>(data);
181  }
182 
183  template<int Elems, typename Float, typename Gauge, int gauge_dir>
184  class GaugeFixQuality : TunableLocalParityReduction {
185  GaugeFixQualityArg<Float, Gauge> &arg;
186  const GaugeField &meta;
187 
188  public:
189  GaugeFixQuality(GaugeFixQualityArg<Float, Gauge> &arg, const GaugeField &meta) :
190  arg(arg),
191  meta(meta) { }
192 
193  void apply(const qudaStream_t &stream)
194  {
195  TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
196  LAUNCH_KERNEL_LOCAL_PARITY(computeFix_quality, (*this), tp, stream, arg, Elems, Float, Gauge, gauge_dir);
197  auto reset = true; // apply is called multiple times with the same arg instance so we need to reset
198  arg.complete(arg.result, stream, reset);
199  if (!activeTuning()) {
200  arg.result.x /= (double)(3 * gauge_dir * 2 * arg.threads);
201  arg.result.y /= (double)(3 * 2 * arg.threads);
202  }
203  }
204 
205  TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), meta.AuxString()); }
206  long long flops() const { return (36LL * gauge_dir + 65LL) * 2 * arg.threads; }
207  long long bytes() const { return (2LL * gauge_dir + 2LL) * Elems * 2 * arg.threads * sizeof(Float); }
208  };
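 // [editor's note] For reference, the reduced quantities above correspond (a
 // reading of the kernel, stated with hedging) to
 //   action = (1 / (3 * gauge_dir * V)) * sum_x sum_mu Re Tr U_mu(x)
 //   theta  = (1 / (3 * V)) * sum_x Tr[ Delta(x) Delta(x)^dag ]
 // where Delta(x) = sum_mu [ U_mu(x-mu) - U_mu(x) ] - h.c., made traceless by
 // SubTraceUnit(). Gauge fixing drives the action up; theta -> 0 at a fixed point.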
209 
210  template <typename Float>
211  struct GaugeFixArg {
212  int threads; // number of active threads required
213  int X[4]; // grid dimensions
214  GaugeField &data;
215  Float *invpsq;
216  complex<Float> *delta;
217  complex<Float> *gx;
218 
219  GaugeFixArg(GaugeField & data, const int Elems) : data(data){
220  for ( int dir = 0; dir < 4; ++dir ) X[dir] = data.X()[dir];
221  threads = X[0] * X[1] * X[2] * X[3];
222  invpsq = (Float*)device_malloc(sizeof(Float) * threads);
223  delta = (complex<Float>*)device_malloc(sizeof(complex<Float>) * threads * 6);
224 #ifdef GAUGEFIXING_DONT_USE_GX
225  gx = (complex<Float>*)device_malloc(sizeof(complex<Float>) * threads);
226 #else
227  gx = (complex<Float>*)device_malloc(sizeof(complex<Float>) * threads * Elems);
228 #endif
229  }
230  void free(){
231  device_free(invpsq);
232  device_free(delta);
233  device_free(gx);
234  }
235  };
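 // [editor's note] Back-of-the-envelope device memory for these work arrays on a
 // lattice of volume V, in double precision (an estimate, not a measurement):
 //   invpsq: 8 B * V, delta: 96 B * V, and gx: 16 B * V with
 //   GAUGEFIXING_DONT_USE_GX (or 16 B * Elems * V otherwise).
 // For V = 32^3 * 64 that is roughly 17 MB + 201 MB + 34 MB, about 0.25 GB.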
236 
237  template <typename Float>
238  __global__ void kernel_gauge_set_invpsq(GaugeFixArg<Float> arg){
239  int id = blockIdx.x * blockDim.x + threadIdx.x;
240  if ( id >= arg.threads ) return;
241  int x1 = id / (arg.X[2] * arg.X[3] * arg.X[0]);
242  int x0 = (id / (arg.X[2] * arg.X[3])) % arg.X[0];
243  int x3 = (id / arg.X[2]) % arg.X[3];
244  int x2 = id % arg.X[2];
245  //id = x2 + (x3 + (x0 + x1 * arg.X[0]) * arg.X[3]) * arg.X[2];
246  Float sx = sin( (Float)x0 * FL_UNITARIZE_PI / (Float)arg.X[0]);
247  Float sy = sin( (Float)x1 * FL_UNITARIZE_PI / (Float)arg.X[1]);
248  Float sz = sin( (Float)x2 * FL_UNITARIZE_PI / (Float)arg.X[2]);
249  Float st = sin( (Float)x3 * FL_UNITARIZE_PI / (Float)arg.X[3]);
250  Float sinsq = sx * sx + sy * sy + sz * sz + st * st;
251  Float prcfact = 0.0;
252  //The FFT normalization is done here
253  if ( sinsq > 0.00001 ) prcfact = 4.0 / (sinsq * (Float)arg.threads);
254  arg.invpsq[id] = prcfact;
255  }
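 // [editor's note] The factor stored above combines the Fourier-space
 // preconditioner with the 1/V normalization that cuFFT leaves out: with
 //   phat^2 = 4 * sum_mu sin^2(pi * n_mu / L_mu),   pmax^2 = 16,
 // the kernel writes invpsq = pmax^2 / (phat^2 * V), and sets the zero mode
 // (phat^2 ~ 0) to zero so the global gauge rotation is left untouched.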
256 
257  template<typename Float>
258  class GaugeFixSETINVPSP : Tunable {
259  GaugeFixArg<Float> arg;
260  const GaugeField &meta;
261  unsigned int sharedBytesPerThread() const { return 0; }
262  unsigned int sharedBytesPerBlock(const TuneParam &param) const { return 0; }
263  bool tuneSharedBytes() const { return false; }
264  bool tuneGridDim() const { return false; }
265  unsigned int minThreads() const { return arg.threads; }
266 
267  public:
268  GaugeFixSETINVPSP(GaugeFixArg<Float> &arg, const GaugeField &meta) :
269  arg(arg),
270  meta(meta) { }
271 
272  void apply(const qudaStream_t &stream){
273  TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
274  qudaLaunchKernel(kernel_gauge_set_invpsq<Float>, tp, stream, arg);
275  }
276 
277  TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), meta.AuxString()); }
278  long long flops() const { return 21 * arg.threads; }
279  long long bytes() const { return sizeof(Float) * arg.threads; }
280  };
281 
282  template<typename Float>
283  __global__ void kernel_gauge_mult_norm_2D(GaugeFixArg<Float> arg) {
284  int id = blockIdx.x * blockDim.x + threadIdx.x;
285  if ( id < arg.threads ) arg.gx[id] = arg.gx[id] * arg.invpsq[id];
286  }
287 
288  template<typename Float>
289  class GaugeFixINVPSP : Tunable {
290  GaugeFixArg<Float> arg;
291  const GaugeField &meta;
292  unsigned int sharedBytesPerThread() const { return 0; }
293  unsigned int sharedBytesPerBlock(const TuneParam &param) const { return 0; }
294  bool tuneGridDim() const { return false; }
295  unsigned int minThreads() const { return arg.threads; }
296 
297  public:
298  GaugeFixINVPSP(GaugeFixArg<Float> &arg, const GaugeField &meta) :
299  arg(arg),
300  meta(meta)
301  { }
302 
303  void apply(const qudaStream_t &stream) {
304  TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
305  qudaLaunchKernel(kernel_gauge_mult_norm_2D<Float>, tp, stream, arg);
306  }
307 
308  TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), meta.AuxString()); }
309 
310  void preTune() {
311  //since delta contents are irrelevant at this point, we can swap gx with delta
312  complex<Float> *tmp = arg.gx;
313  arg.gx = arg.delta;
314  arg.delta = tmp;
315  }
316  void postTune() {
317  arg.gx = arg.delta;
318  }
319  long long flops() const { return 2LL * arg.threads; }
320  long long bytes() const { return 5LL * sizeof(Float) * arg.threads; }
321  };
322 
323  template <typename Float>
324  __host__ __device__ inline void reunit_link( Matrix<complex<Float>,3> &U ){
325 
326  complex<Float> t2((Float)0.0, (Float)0.0);
327  Float t1 = 0.0;
328  //first normalize first row
329  //sum of squares of row
330 #pragma unroll
331  for ( int c = 0; c < 3; c++ ) t1 += norm(U(0,c));
332  t1 = (Float)1.0 / sqrt(t1);
333  //14
334  //used to normalize row
335 #pragma unroll
336  for ( int c = 0; c < 3; c++ ) U(0,c) *= t1;
337  //6
338 #pragma unroll
339  for ( int c = 0; c < 3; c++ ) t2 += conj(U(0,c)) * U(1,c);
340  //24
341 #pragma unroll
342  for ( int c = 0; c < 3; c++ ) U(1,c) -= t2 * U(0,c);
343  //24
344  //normalize second row
345  //sum of squares of row
346  t1 = 0.0;
347 #pragma unroll
348  for ( int c = 0; c < 3; c++ ) t1 += norm(U(1,c));
349  t1 = (Float)1.0 / sqrt(t1);
350  //14
351  //used to normalize row
352 #pragma unroll
353  for ( int c = 0; c < 3; c++ ) U(1, c) *= t1;
354  //6
355  //Reconstruct last row
356  U(2,0) = conj(U(0,1) * U(1,2) - U(0,2) * U(1,1));
357  U(2,1) = conj(U(0,2) * U(1,0) - U(0,0) * U(1,2));
358  U(2,2) = conj(U(0,0) * U(1,1) - U(0,1) * U(1,0));
359  //42
360  //T=130
361  }
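 // [editor's note] reunit_link is a Gram-Schmidt projection back to SU(3): row 0
 // is normalized, row 1 is orthogonalized against row 0 and normalized, and row 2
 // is rebuilt as the conjugate cross product of the first two, which also fixes
 // det(U) = 1. A quick sanity check (sketch): U * conj(U) should equal the
 // identity up to rounding after this call.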
362 
363 #ifdef GAUGEFIXING_DONT_USE_GX
364 
365  template <typename Float, typename Gauge>
366  __global__ void kernel_gauge_fix_U_EO_NEW(GaugeFixArg<Float> arg, Gauge dataOr, Float half_alpha)
367  {
368  int id = threadIdx.x + blockIdx.x * blockDim.x;
369  int parity = threadIdx.y + blockIdx.y * blockDim.y;
370  if (id >= arg.threads/2) return;
371 
372  using complex = complex<Float>;
373  using matrix = Matrix<complex, 3>;
374 
375  int x[4];
376  getCoords(x, id, arg.X, parity);
377  int idx = ((x[3] * arg.X[2] + x[2]) * arg.X[1] + x[1]) * arg.X[0] + x[0];
378  matrix de;
379  //Read Delta
380  de(0,0) = arg.delta[idx + 0 * arg.threads];
381  de(0,1) = arg.delta[idx + 1 * arg.threads];
382  de(0,2) = arg.delta[idx + 2 * arg.threads];
383  de(1,1) = arg.delta[idx + 3 * arg.threads];
384  de(1,2) = arg.delta[idx + 4 * arg.threads];
385  de(2,2) = arg.delta[idx + 5 * arg.threads];
386 
387  de(1,0) = complex(-de(0,1).real(), de(0,1).imag());
388  de(2,0) = complex(-de(0,2).real(), de(0,2).imag());
389  de(2,1) = complex(-de(1,2).real(), de(1,2).imag());
390  matrix g;
391  setIdentity(&g);
392  g += de * half_alpha;
393  //36
394  reunit_link<Float>( g );
395  //130
396 
397  for ( int mu = 0; mu < 4; mu++ ) {
398  matrix U = dataOr(mu, id, parity);
399  matrix g0;
400  U = g * U;
401  //198
402  idx = linkNormalIndexP1(x,arg.X,mu);
403  //Read Delta
404  de(0,0) = arg.delta[idx + 0 * arg.threads];
405  de(0,1) = arg.delta[idx + 1 * arg.threads];
406  de(0,2) = arg.delta[idx + 2 * arg.threads];
407  de(1,1) = arg.delta[idx + 3 * arg.threads];
408  de(1,2) = arg.delta[idx + 4 * arg.threads];
409  de(2,2) = arg.delta[idx + 5 * arg.threads];
410 
411  de(1,0) = complex(-de(0,1).real(), de(0,1).imag());
412  de(2,0) = complex(-de(0,2).real(), de(0,2).imag());
413  de(2,1) = complex(-de(1,2).real(), de(1,2).imag());
414 
415  setIdentity(&g0);
416  g0 += de * half_alpha;
417  //36
418  reunit_link<Float>( g0 );
419  //130
420 
421  U = U * conj(g0);
422  //198
423  dataOr(mu, id, parity) = U;
424  }
425  }
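 // [editor's note] In formulas, the kernel above applies one steepest-descent
 // update without ever storing g(x):
 //   g(x)    = reunit( 1 + (alpha/2) * Delta(x) )
 //   U_mu(x) <- g(x) * U_mu(x) * g(x+mu)^dag
 // recomputing g at both x and x+mu from the Delta array (this is the
 // GAUGEFIXING_DONT_USE_GX code path).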
426 
427  template<typename Float, typename Gauge>
428  class GaugeFixNEW : TunableVectorY {
429  GaugeFixArg<Float> arg;
430  const GaugeField &meta;
431  Float half_alpha;
432  Gauge dataOr;
433 
434  bool tuneGridDim() const { return false; }
435  // since GaugeFixArg is used by other kernels that don't keep
436  // parity separate, arg.threads stores Volume and not VolumeCB so
437  // we need to divide by two
438  unsigned int minThreads() const { return arg.threads/2; }
439 
440  public:
441  GaugeFixNEW(Gauge & dataOr, GaugeFixArg<Float> &arg, Float alpha, const GaugeField &meta) :
442  TunableVectorY(2),
443  dataOr(dataOr),
444  arg(arg),
445  meta(meta)
446  {
447  half_alpha = alpha * 0.5;
448  }
449 
450  void setAlpha(Float alpha){ half_alpha = alpha * 0.5; }
451 
452  void apply(const qudaStream_t &stream){
453  TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
454  qudaLaunchKernel(kernel_gauge_fix_U_EO_NEW<Float, Gauge>, tp, stream, arg, dataOr, half_alpha);
455  }
456 
457  TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), meta.AuxString()); }
458  void preTune() { arg.data.backup(); }
459  void postTune() { arg.data.restore(); }
460  long long flops() const { return 2414LL * arg.threads; }
461  long long bytes() const { return ( dataOr.Bytes() * 4LL + 5 * 12LL * sizeof(Float)) * arg.threads; }
462  };
463 
464 #else
465 
466  template <int Elems, typename Float>
467  __global__ void kernel_gauge_GX(GaugeFixArg<Float> arg, Float half_alpha)
468  {
469  int id = blockIdx.x * blockDim.x + threadIdx.x;
470  if (id >= arg.threads) return;
471 
472  using complex = complex<Float>;
473 
474  Matrix<complex,3> de;
475  //Read Delta
476  de(0,0) = arg.delta[id];
477  de(0,1) = arg.delta[id + arg.threads];
478  de(0,2) = arg.delta[id + 2 * arg.threads];
479  de(1,1) = arg.delta[id + 3 * arg.threads];
480  de(1,2) = arg.delta[id + 4 * arg.threads];
481  de(2,2) = arg.delta[id + 5 * arg.threads];
482 
483  de(1,0) = complex(-de(0,1).x, de(0,1).y);
484  de(2,0) = complex(-de(0,2).x, de(0,2).y);
485  de(2,1) = complex(-de(1,2).x, de(1,2).y);
486 
487  Matrix<complex, 3> g;
488  setIdentity(&g);
489  g += de * half_alpha;
490  //36
491  reunit_link<Float>( g );
492  //130
493  //gx is represented in even/odd order
494  //normal lattice index to even/odd index
495  int x3 = id / (arg.X[0] * arg.X[1] * arg.X[2]);
496  int x2 = (id / (arg.X[0] * arg.X[1])) % arg.X[2];
497  int x1 = (id / arg.X[0]) % arg.X[1];
498  int x0 = id % arg.X[0];
499  id = (x0 + (x1 + (x2 + x3 * arg.X[2]) * arg.X[1]) * arg.X[0]) >> 1;
500  id += ((x0 + x1 + x2 + x3) & 1 ) * arg.threads / 2;
501 
502  for ( int i = 0; i < Elems; i++ ) arg.gx[id + i * arg.threads] = g.data[i];
503  //T=166 for Elems 9
504  //T=208 for Elems 6
505  }
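 // [editor's note] The tail of the kernel above converts a lexicographic site
 // index to even/odd (checkerboard) order; in sketch form:
 //   parity = (x0 + x1 + x2 + x3) & 1
 //   id_eo  = lex_id / 2 + parity * volume / 2
 // so each parity occupies one contiguous half of the gx array.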
506 
507  template<int Elems, typename Float>
508  class GaugeFix_GX : Tunable {
509  GaugeFixArg<Float> arg;
510  const GaugeField &meta;
511  Float half_alpha;
512  unsigned int sharedBytesPerThread() const { return 0; }
513  unsigned int sharedBytesPerBlock(const TuneParam &param) const { return 0; }
514  bool tuneGridDim() const { return false; }
515  unsigned int minThreads() const { return arg.threads; }
516 
517  public:
518  GaugeFix_GX(GaugeFixArg<Float> &arg, Float alpha, const GaugeField &meta) :
519  arg(arg),
520  meta(meta)
521  {
522  half_alpha = alpha * 0.5;
523  }
524 
525  void setAlpha(Float alpha) { half_alpha = alpha * 0.5; }
526 
527  void apply(const qudaStream_t &stream){
528  TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
529  qudaLaunchKernel(kernel_gauge_GX<Elems, Float>, tp, stream, arg, half_alpha);
530  }
531 
532  TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), meta.AuxString()); }
533 
534  long long flops() const {
535  if ( Elems == 6 ) return 208LL * arg.threads;
536  else return 166LL * arg.threads;
537  }
538  long long bytes() const { return 4LL * Elems * sizeof(Float) * arg.threads; }
539  };
540 
541  template <int Elems, typename Float, typename Gauge>
542  __global__ void kernel_gauge_fix_U_EO( GaugeFixArg<Float> arg, Gauge dataOr)
543  {
544  int idd = threadIdx.x + blockIdx.x * blockDim.x;
545  if ( idd >= arg.threads ) return;
546 
547  int parity = 0;
548  int id = idd;
549  if ( idd >= arg.threads / 2 ) {
550  parity = 1;
551  id -= arg.threads / 2;
552  }
553  typedef complex<Float> Cmplx;
554 
555  Matrix<Cmplx,3> g;
556  //for(int i = 0; i < Elems; i++) g.data[i] = arg.gx[idd + i * arg.threads];
557  for ( int i = 0; i < Elems; i++ ) {
558  g.data[i] = arg.gx[idd + i * arg.threads];
559  }
560  if ( Elems == 6 ) {
561  g(2,0) = conj(g(0,1) * g(1,2) - g(0,2) * g(1,1));
562  g(2,1) = conj(g(0,2) * g(1,0) - g(0,0) * g(1,2));
563  g(2,2) = conj(g(0,0) * g(1,1) - g(0,1) * g(1,0));
564  //42
565  }
566  int x[4];
567  getCoords(x, id, arg.X, parity);
568  for ( int mu = 0; mu < 4; mu++ ) {
569  Matrix<Cmplx,3> U = dataOr(mu, id, parity);
570  Matrix<Cmplx,3> g0;
571  U = g * U;
572  //198
573  int idm1 = linkIndexP1(x,arg.X,mu);
574  idm1 += (1 - parity) * arg.threads / 2;
575  //for(int i = 0; i < Elems; i++) g0.data[i] = arg.gx[idm1 + i * arg.threads];
576  for ( int i = 0; i < Elems; i++ ) {
577  g0.data[i] = arg.gx[idm1 + i * arg.threads];
578  }
579  if ( Elems == 6 ) {
580  g0(2,0) = conj(g0(0,1) * g0(1,2) - g0(0,2) * g0(1,1));
581  g0(2,1) = conj(g0(0,2) * g0(1,0) - g0(0,0) * g0(1,2));
582  g0(2,2) = conj(g0(0,0) * g0(1,1) - g0(0,1) * g0(1,0));
583  //42
584  }
585  U = U * conj(g0);
586  //198
587  dataOr(mu, id, parity) = U;
588  }
589  //T=42+4*(198*2+42) Elems=6
590  //T=4*(198*2) Elems=9
591  //Not accounting here for the gauge reconstruction when the 12- or 8-real compressed formats are used
592  }
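 // [editor's note] This is the two-pass variant: kernel_gauge_GX first stores
 // g(x) for every site (with Elems == 6 only the top two rows are stored and the
 // third is reconstructed on load), then this kernel applies
 //   U_mu(x) <- g(x) * U_mu(x) * g(x+mu)^dag
 // reading the neighbor g(x+mu) from the even/odd-ordered gx array.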
593 
594  template<int Elems, typename Float, typename Gauge>
595  class GaugeFix : Tunable {
596  GaugeFixArg<Float> arg;
597  const GaugeField &meta;
598  Gauge dataOr;
599  unsigned int sharedBytesPerThread() const { return 0; }
600  unsigned int sharedBytesPerBlock(const TuneParam &param) const { return 0; }
601  bool tuneGridDim() const { return false; }
602  unsigned int minThreads() const { return arg.threads; }
603 
604  public:
605  GaugeFix(Gauge & dataOr, GaugeFixArg<Float> &arg, const GaugeField &meta) :
606  dataOr(dataOr),
607  arg(arg),
608  meta(meta)
609  { }
610 
611  void apply(const qudaStream_t &stream) {
612  TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
613  qudaLaunchKernel(kernel_gauge_fix_U_EO<Elems, Float, Gauge>, tp, stream, arg, dataOr);
614  }
615 
616  TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), meta.AuxString()); }
617 
618  void preTune() { arg.data.backup(); }
619  void postTune() { arg.data.restore(); }
620  long long flops() const {
621  if ( Elems == 6 ) return 1794LL * arg.threads;
622  else return 1536LL * arg.threads;
623  }
624  long long bytes() const { return 26LL * Elems * sizeof(Float) * arg.threads; }
625  };
626 #endif
627 //GAUGEFIXING_DONT_USE_GX
628 
629  template<int Elems, typename Float, typename Gauge, int gauge_dir>
630  void gaugefixingFFT(Gauge dataOr, GaugeField& data, const int Nsteps, const int verbose_interval,
631  const Float alpha0, const int autotune, const double tolerance, const int stopWtheta)
632  {
633  TimeProfile profileInternalGaugeFixFFT("InternalGaugeFixQudaFFT", false);
634 
635  profileInternalGaugeFixFFT.TPSTART(QUDA_PROFILE_COMPUTE);
636 
637  Float alpha = alpha0;
638  std::cout << "\tAlpha parameter of the Steepest Descent Method: " << alpha << std::endl;
639  if ( autotune ) std::cout << "\tAuto tune active: yes" << std::endl;
640  else std::cout << "\tAuto tune active: no" << std::endl;
641  std::cout << "\tStop criterion: " << tolerance << std::endl;
642  if ( stopWtheta ) std::cout << "\tStop criterion method: theta" << std::endl;
643  else std::cout << "\tStop criterion method: Delta" << std::endl;
644  std::cout << "\tMaximum number of iterations: " << Nsteps << std::endl;
645  std::cout << "\tPrint convergence results at every " << verbose_interval << " steps" << std::endl;
646 
647 
648  unsigned int delta_pad = data.X()[0] * data.X()[1] * data.X()[2] * data.X()[3];
649  int4 size = make_int4( data.X()[0], data.X()[1], data.X()[2], data.X()[3] );
650  cufftHandle plan_xy;
651  cufftHandle plan_zt;
652 
653  GaugeFixArg<Float> arg(data, Elems);
654  SetPlanFFT2DMany( plan_zt, size, 0, arg.delta); //for space and time ZT
655  SetPlanFFT2DMany( plan_xy, size, 1, arg.delta); //with space only XY
656 
657  GaugeFixFFTRotateArg<Float> arg_rotate(data);
658  GaugeFixFFTRotate<Float, decltype(arg_rotate)> GFRotate(arg_rotate, data);
659 
660  GaugeFixSETINVPSP<Float> setinvpsp(arg, data);
661  setinvpsp.apply(0);
662  GaugeFixINVPSP<Float> invpsp(arg, data);
663 
664 #ifdef GAUGEFIXING_DONT_USE_GX
665  //without precalculating g(x), gx is allocated only as a temporary array for the FFTs, with a smaller size
666  GaugeFixNEW<Float, Gauge> gfixNew(dataOr, arg, alpha, data);
667 #else
668  //using GX
669  GaugeFix_GX<Elems, Float> calcGX(arg, alpha, data);
670  GaugeFix<Elems, Float, Gauge> gfix(dataOr, arg, data);
671 #endif
672 
673  GaugeFixQualityArg<Float, Gauge> argQ(dataOr, data, arg.delta);
674  GaugeFixQuality<Elems, Float, Gauge, gauge_dir> gfixquality(argQ, data);
675 
676  gfixquality.apply(0);
677  double action0 = argQ.getAction();
678  printf("Step: %d\tAction: %.16e\ttheta: %.16e\n", 0, argQ.getAction(), argQ.getTheta());
679 
680  double diff = 0.0;
681  int iter = 0;
682  for ( iter = 0; iter < Nsteps; iter++ ) {
683  for ( int k = 0; k < 6; k++ ) {
684  //------------------------------------------------------------------------
685  // Set a pointer to the element k in the lattice volume;
686  // each element is stored with stride equal to the lattice volume.
687  // gx is used as a temporary array here.
688  //------------------------------------------------------------------------
689  complex<Float> *_array = arg.delta + k * delta_pad;
690  ////// 2D FFT + 2D FFT
691  //------------------------------------------------------------------------
692  // Perform FFT on xy plane
693  //------------------------------------------------------------------------
694  ApplyFFT(plan_xy, _array, arg.gx, CUFFT_FORWARD);
695  //------------------------------------------------------------------------
696  // Rotate hypercube, xyzt -> ztxy
697  //------------------------------------------------------------------------
698  GFRotate.setDirection(0, arg.gx, _array);
699  GFRotate.apply(0);
700  //------------------------------------------------------------------------
701  // Perform FFT on zt plane
702  //------------------------------------------------------------------------
703  ApplyFFT(plan_zt, _array, arg.gx, CUFFT_FORWARD);
704  //------------------------------------------------------------------------
705  // Normalize FFT and apply pmax^2/p^2
706  //------------------------------------------------------------------------
707  invpsp.apply(0);
708  //------------------------------------------------------------------------
709  // Perform IFFT on zt plane
710  //------------------------------------------------------------------------
711  ApplyFFT(plan_zt, arg.gx, _array, CUFFT_INVERSE);
712  //------------------------------------------------------------------------
713  // Rotate hypercube, ztxy -> xyzt
714  //------------------------------------------------------------------------
715  GFRotate.setDirection(1, _array, arg.gx);
716  GFRotate.apply(0);
717  //------------------------------------------------------------------------
718  // Perform IFFT on xy plane
719  //------------------------------------------------------------------------
720  ApplyFFT(plan_xy, arg.gx, _array, CUFFT_INVERSE);
721  }
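 // [editor's note] Net effect of the loop above, for each of the six stored
 // elements of Delta: Delta(x) -FFT-> Delta(p), multiply by pmax^2/(phat^2 V),
 // then -IFFT-> back to Delta(x). This is the standard Fourier acceleration of
 // the steepest-descent step, realized as two batched 2D FFTs (xy and zt
 // planes) with a hypercube rotation in between.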
722  #ifdef GAUGEFIXING_DONT_USE_GX
723  //------------------------------------------------------------------------
724  // Apply gauge fix to current gauge field
725  //------------------------------------------------------------------------
726  gfixNew.apply(0);
727  #else
728  //------------------------------------------------------------------------
729  // Calculate g(x)
730  //------------------------------------------------------------------------
731  calcGX.apply(0);
732  //------------------------------------------------------------------------
733  // Apply gauge fix to current gauge field
734  //------------------------------------------------------------------------
735  gfix.apply(0);
736  #endif
737  //------------------------------------------------------------------------
738  // Measure gauge quality and recalculate new Delta(x)
739  //------------------------------------------------------------------------
740  gfixquality.apply(0);
741  double action = argQ.getAction();
742  diff = abs(action0 - action);
743  if ((iter % verbose_interval) == (verbose_interval - 1))
744  printf("Step: %d\tAction: %.16e\ttheta: %.16e\tDelta: %.16e\n", iter + 1, argQ.getAction(), argQ.getTheta(), diff);
745  if ( autotune && ((action - action0) < -1e-14) ) {
746  if ( alpha > 0.01 ) {
747  alpha = 0.95 * alpha;
748  #ifdef GAUGEFIXING_DONT_USE_GX
749  gfixNew.setAlpha(alpha);
750  #else
751  calcGX.setAlpha(alpha);
752  #endif
753  printf(">>>>>>>>>>>>>> Warning: changing alpha down -> %.4e\n", alpha );
754  }
755  }
756  //------------------------------------------------------------------------
757  // Check the gauge fixing quality stopping criterion
758  //------------------------------------------------------------------------
759  if ( stopWtheta ) { if ( argQ.getTheta() < tolerance ) break; }
760  else { if ( diff < tolerance ) break; }
761 
762  action0 = action;
763  }
764  if ((iter % verbose_interval) != 0 )
765  printf("Step: %d\tAction: %.16e\ttheta: %.16e\tDelta: %.16e\n", iter, argQ.getAction(), argQ.getTheta(), diff);
766 
767  // Reunitarize at end
768  const double unitarize_eps = 1e-14;
769  const double max_error = 1e-10;
770  const int reunit_allow_svd = 1;
771  const int reunit_svd_only = 0;
772  const double svd_rel_error = 1e-6;
773  const double svd_abs_error = 1e-6;
774  setUnitarizeLinksConstants(unitarize_eps, max_error,
775  reunit_allow_svd, reunit_svd_only,
776  svd_rel_error, svd_abs_error);
777  int num_failures = 0;
778  int* num_failures_dev = static_cast<int*>(pool_device_malloc(sizeof(int)));
779  qudaMemset(num_failures_dev, 0, sizeof(int));
780  unitarizeLinks(data, data, num_failures_dev);
781  qudaMemcpy(&num_failures, num_failures_dev, sizeof(int), cudaMemcpyDeviceToHost);
782 
783  pool_device_free(num_failures_dev);
784  if ( num_failures > 0 ) {
785  errorQuda("Error in the unitarization (%d link failures)", num_failures);
786  //errorQuda aborts, so no explicit exit is needed
787  }
788  // end reunitarize
789 
790  arg.free();
791  CUFFT_SAFE_CALL(cufftDestroy(plan_zt));
792  CUFFT_SAFE_CALL(cufftDestroy(plan_xy));
793  qudaDeviceSynchronize();
794  profileInternalGaugeFixFFT.TPSTOP(QUDA_PROFILE_COMPUTE);
795 
796  if (getVerbosity() > QUDA_SUMMARIZE){
797  double secs = profileInternalGaugeFixFFT.Last(QUDA_PROFILE_COMPUTE);
798  double fftflop = 5.0 * (log2((double)( data.X()[0] * data.X()[1]) ) + log2( (double)(data.X()[2] * data.X()[3] )));
799  fftflop *= (double)( data.X()[0] * data.X()[1] * data.X()[2] * data.X()[3] );
800  double gflops = setinvpsp.flops() + gfixquality.flops();
801  double gbytes = setinvpsp.bytes() + gfixquality.bytes();
802  double flop = invpsp.flops() * Elems;
803  double byte = invpsp.bytes() * Elems;
804  flop += (GFRotate.flops() + fftflop) * Elems * 2;
805  byte += GFRotate.bytes() * Elems * 4; //includes FFT reads, assuming 1 read and 1 write per site
806  #ifdef GAUGEFIXING_DONT_USE_GX
807  flop += gfixNew.flops();
808  byte += gfixNew.bytes();
809  #else
810  flop += calcGX.flops();
811  byte += calcGX.bytes();
812  flop += gfix.flops();
813  byte += gfix.bytes();
814  #endif
815  flop += gfixquality.flops();
816  byte += gfixquality.bytes();
817  gflops += flop * iter;
818  gbytes += byte * iter;
819  gflops += 4588.0 * data.X()[0]*data.X()[1]*data.X()[2]*data.X()[3]; //Reunitarize at end
820  gbytes += 8.0 * data.X()[0]*data.X()[1]*data.X()[2]*data.X()[3] * dataOr.Bytes() ; //Reunitarize at end
821 
822  gflops = (gflops * 1e-9) / (secs);
823  gbytes = gbytes / (secs * 1e9);
824  printfQuda("Time: %6.6f s, Gflop/s = %6.1f, GB/s = %6.1f\n", secs, gflops, gbytes);
825  }
826  }
827 
828  template<typename Float, int nColors, QudaReconstructType recon> struct GaugeFixingFFT {
829  GaugeFixingFFT(GaugeField& data, const int gauge_dir, const int Nsteps, const int verbose_interval, const Float alpha,
830  const int autotune, const double tolerance, const int stopWtheta)
831  {
832  using Gauge = typename gauge_mapper<Float, recon>::type;
833  constexpr int n_element = recon / 2; // number of complex elements used to store g(x) and Delta(x)
834  if ( gauge_dir != 3 ) {
835  printfQuda("Starting Landau gauge fixing with FFTs...\n");
836  gaugefixingFFT<n_element, Float, Gauge, 4>(Gauge(data), data, Nsteps, verbose_interval, alpha, autotune, tolerance, stopWtheta);
837  } else {
838  printfQuda("Starting Coulomb gauge fixing with FFTs...\n");
839  gaugefixingFFT<n_element, Float, Gauge, 3>(Gauge(data), data, Nsteps, verbose_interval, alpha, autotune, tolerance, stopWtheta);
840  }
841  }
842  };
843 
844  /**
845  * @brief Gauge fixing with the steepest descent method with FFT acceleration (single GPU only).
846  * @param[in,out] data quda gauge field
847  * @param[in] gauge_dir 3 for Coulomb gauge fixing, anything else for Landau gauge fixing
848  * @param[in] Nsteps maximum number of steps to perform gauge fixing
849  * @param[in] verbose_interval print gauge fixing info when the iteration count is a multiple of this
850  * @param[in] alpha gauge fixing parameter of the method; the most common value is 0.08
851  * @param[in] autotune 1 to autotune the method, i.e., if the action starts to decrease, reduce the alpha value
852  * @param[in] tolerance tolerance value to stop the method; if zero, the method stops only when the iteration count reaches Nsteps
853  * @param[in] stopWtheta 0 to stop based on the Delta (MILC) criterion, 1 to use the theta value
854  */
855  void gaugeFixingFFT(GaugeField& data, const int gauge_dir, const int Nsteps, const int verbose_interval, const double alpha,
856  const int autotune, const double tolerance, const int stopWtheta)
857  {
858 #ifdef GPU_GAUGE_ALG
859 #ifdef MULTI_GPU
860  if (comm_dim_partitioned(0) || comm_dim_partitioned(1) || comm_dim_partitioned(2) || comm_dim_partitioned(3))
861  errorQuda("Gauge fixing with FFTs is not yet supported in multi-GPU mode");
862 #endif
863  instantiate<GaugeFixingFFT>(data, gauge_dir, Nsteps, verbose_interval, (float)alpha, autotune, tolerance, stopWtheta);
864 #else
865  errorQuda("Gauge fixing has not been built");
866 #endif
867  }
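 // [editor's note] Hypothetical host-side usage (illustrative only, not from the
 // QUDA test suite): fix a device-resident gauge field to Coulomb gauge with the
 // parameters documented above.
 //
 //   GaugeField &U = ...; // resident device gauge field
 //   gaugeFixingFFT(U, 3 /*Coulomb*/, 10000 /*Nsteps*/, 100 /*verbose_interval*/,
 //                  0.08 /*alpha*/, 1 /*autotune*/, 1e-15 /*tolerance*/,
 //                  0 /*stopWtheta: use the Delta criterion*/);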
868 
869 }