QUDA  0.9.0
gauge_fix_fft.cu
Go to the documentation of this file.
1 #include <quda_internal.h>
2 #include <quda_matrix.h>
3 #include <tune_quda.h>
4 #include <gauge_field.h>
5 #include <gauge_field_order.h>
6 #include <launch_kernel.cuh>
7 #include <unitarization_links.h>
8 #include <atomic.cuh>
9 #include <cub_helper.cuh>
10 #include <index_helper.cuh>
11 
12 #include <cufft.h>
13 
14 #ifdef GPU_GAUGE_ALG
15 #include <CUFFT_Plans.h>
16 #endif
17 
18 namespace quda {
19 
20 #ifdef GPU_GAUGE_ALG
21 
22 //Comment if you don't want to use textures for Delta(x) and g(x)
23 #define GAUGEFIXING_SITE_MATRIX_LOAD_TEX
24 
25 //UNCOMMENT THIS IF YOU WANT TO USE LESS MEMORY
26 #define GAUGEFIXING_DONT_USE_GX
27 //Without using the precalculation of g(x),
28 //we lose some performance, because Delta(x) is written in normal lattice coordinates needed for the FFTs
29 //and the gauge array in even/odd format
30 
31 
32 
33 #ifdef HOST_DEBUG
34 #ifdef GAUGEFIXING_DONT_USE_GX
35 #warning Not using precalculated g(x)
36 #else
37 #warning Using precalculated g(x)
38 #endif
39 #endif
40 
41 
42 #ifndef FL_UNITARIZE_PI
43 #define FL_UNITARIZE_PI 3.14159265358979323846
44 #endif
45 
46 
 //Texture references used to fetch g(x) and Delta(x) site matrices when
 //GAUGEFIXING_SITE_MATRIX_LOAD_TEX is defined: float2 for single precision,
 //int4 (two packed doubles) for double precision.
 texture<float2, 1, cudaReadModeElementType> GXTexSingle;
 texture<int4, 1, cudaReadModeElementType> GXTexDouble;
//Delta is only stored using 12 real number parameters,
// (0,0), (0,1), (0,2), (1,1), (1,2) and (2,2)
// (0,0), (1,1) and (0,1) don't have real part, however we need a complex for the FFTs
 texture<float2, 1, cudaReadModeElementType> DELTATexSingle;
 texture<int4, 1, cudaReadModeElementType> DELTATexDouble;


 /**
  * @brief Fetch element @c id of g(x) from texture memory.
  * The generic template is a placeholder returning 0; only the
  * complex<float> / complex<double> specializations below are used.
  */
 template <class T>
 inline __device__ T TEXTURE_GX(int id){
   return 0.0;
 }
 template <>
 inline __device__ complex<float> TEXTURE_GX<complex<float> >(int id){
   return tex1Dfetch(GXTexSingle, id);
 }
 template <>
 inline __device__ complex<double> TEXTURE_GX<complex<double> >(int id){
   // double precision is stored as int4: reassemble the two doubles
   int4 u = tex1Dfetch(GXTexDouble, id);
   return complex<double>(__hiloint2double(u.y, u.x), __hiloint2double(u.w, u.z));
 }
 /**
  * @brief Fetch element @c id of Delta(x) from texture memory.
  * Generic placeholder returns 0; only the specializations are used.
  */
 template <class T>
 inline __device__ T TEXTURE_DELTA(int id){
   return 0.0;
 }
 template <>
 inline __device__ complex<float> TEXTURE_DELTA<complex<float> >(int id){
   return tex1Dfetch(DELTATexSingle, id);
 }
 template <>
 inline __device__ complex<double> TEXTURE_DELTA<complex<double> >(int id){
   // double precision is stored as int4: reassemble the two doubles
   int4 u = tex1Dfetch(DELTATexDouble, id);
   return complex<double>(__hiloint2double(u.y, u.x), __hiloint2double(u.w, u.z));
 }
82 
 /**
  * @brief Bind the Delta(x) (and, when precalculated g(x) is in use, the g(x))
  * device arrays to the single-precision texture references.
  * No-op unless GAUGEFIXING_SITE_MATRIX_LOAD_TEX is defined; the g(x) binding
  * is skipped when GAUGEFIXING_DONT_USE_GX is defined.
  */
 static void BindTex(complex<float> *delta, complex<float> *gx, size_t bytes){
#ifdef GAUGEFIXING_SITE_MATRIX_LOAD_TEX
#ifndef GAUGEFIXING_DONT_USE_GX
 cudaBindTexture(0, GXTexSingle, gx, bytes);
#endif
 cudaBindTexture(0, DELTATexSingle, delta, bytes);
#endif
 }

 /** @brief Double-precision overload of BindTex (same conditional logic). */
 static void BindTex(complex<double> *delta, complex<double> *gx, size_t bytes){
#ifdef GAUGEFIXING_SITE_MATRIX_LOAD_TEX
#ifndef GAUGEFIXING_DONT_USE_GX
 cudaBindTexture(0, GXTexDouble, gx, bytes);
#endif
 cudaBindTexture(0, DELTATexDouble, delta, bytes);
#endif
 }

 /** @brief Unbind the single-precision textures bound by BindTex. */
 static void UnBindTex(complex<float> *delta, complex<float> *gx){
#ifdef GAUGEFIXING_SITE_MATRIX_LOAD_TEX
#ifndef GAUGEFIXING_DONT_USE_GX
 cudaUnbindTexture(GXTexSingle);
#endif
 cudaUnbindTexture(DELTATexSingle);
#endif
 }

 /** @brief Unbind the double-precision textures bound by BindTex. */
 static void UnBindTex(complex<double> *delta, complex<double> *gx){
#ifdef GAUGEFIXING_SITE_MATRIX_LOAD_TEX
#ifndef GAUGEFIXING_DONT_USE_GX
 cudaUnbindTexture(GXTexDouble);
#endif
 cudaUnbindTexture(DELTATexDouble);
#endif
 }
118 
119 
120  template <typename Float>
121  struct GaugeFixFFTRotateArg {
122  int threads; // number of active threads required
123  int X[4]; // grid dimensions
124  complex<Float> *tmp0;
125  complex<Float> *tmp1;
126  GaugeFixFFTRotateArg(const cudaGaugeField &data){
127  for ( int dir = 0; dir < 4; ++dir ) X[dir] = data.X()[dir];
128  threads = X[0] * X[1] * X[2] * X[3];
129  tmp0 = 0;
130  tmp1 = 0;
131  }
132  };
133 
134 
135 
136  template <int direction, typename Float>
137  __global__ void fft_rotate_kernel_2D2D(GaugeFixFFTRotateArg<Float> arg){ //Cmplx *data_in, Cmplx *data_out){
138  int id = blockIdx.x * blockDim.x + threadIdx.x;
139  if ( id >= arg.threads ) return;
140  if ( direction == 0 ) {
141  int x3 = id / (arg.X[0] * arg.X[1] * arg.X[2]);
142  int x2 = (id / (arg.X[0] * arg.X[1])) % arg.X[2];
143  int x1 = (id / arg.X[0]) % arg.X[1];
144  int x0 = id % arg.X[0];
145 
146  int id = x0 + (x1 + (x2 + x3 * arg.X[2]) * arg.X[1]) * arg.X[0];
147  int id_out = x2 + (x3 + (x0 + x1 * arg.X[0]) * arg.X[3]) * arg.X[2];
148  arg.tmp1[id_out] = arg.tmp0[id];
149  //data_out[id_out] = data_in[id];
150  }
151  if ( direction == 1 ) {
152 
153  int x1 = id / (arg.X[2] * arg.X[3] * arg.X[0]);
154  int x0 = (id / (arg.X[2] * arg.X[3])) % arg.X[0];
155  int x3 = (id / arg.X[2]) % arg.X[3];
156  int x2 = id % arg.X[2];
157 
158  int id = x2 + (x3 + (x0 + x1 * arg.X[0]) * arg.X[3]) * arg.X[2];
159  int id_out = x0 + (x1 + (x2 + x3 * arg.X[2]) * arg.X[1]) * arg.X[0];
160  arg.tmp1[id_out] = arg.tmp0[id];
161  //data_out[id_out] = data_in[id];
162  }
163  }
164 
165 
166 
167 
168 
169 
170  template<typename Float>
171  class GaugeFixFFTRotate : Tunable {
172  GaugeFixFFTRotateArg<Float> arg;
173  int direction;
174  mutable char aux_string[128]; // used as a label in the autotuner
175  private:
176  unsigned int sharedBytesPerThread() const {
177  return 0;
178  }
179  unsigned int sharedBytesPerBlock(const TuneParam &param) const {
180  return 0;
181  }
182  //bool tuneSharedBytes() const { return false; } // Don't tune shared memory
183  bool tuneGridDim() const {
184  return false;
185  } // Don't tune the grid dimensions.
186  unsigned int minThreads() const {
187  return arg.threads;
188  }
189 
190  public:
191  GaugeFixFFTRotate(GaugeFixFFTRotateArg<Float> &arg) : arg(arg) {
192  direction = 0;
193  }
194  ~GaugeFixFFTRotate () {
195  }
196  void setDirection(int dir, complex<Float> *data_in, complex<Float> *data_out){
197  direction = dir;
198  arg.tmp0 = data_in;
199  arg.tmp1 = data_out;
200  }
201 
202  void apply(const cudaStream_t &stream){
203  TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
204  if ( direction == 0 )
205  fft_rotate_kernel_2D2D<0, Float ><< < tp.grid, tp.block, 0, stream >> > (arg);
206  else if ( direction == 1 )
207  fft_rotate_kernel_2D2D<1, Float ><< < tp.grid, tp.block, 0, stream >> > (arg);
208  else
209  errorQuda("Error in GaugeFixFFTRotate option.\n");
210  }
211 
212  TuneKey tuneKey() const {
213  std::stringstream vol;
214  vol << arg.X[0] << "x";
215  vol << arg.X[1] << "x";
216  vol << arg.X[2] << "x";
217  vol << arg.X[3];
218  sprintf(aux_string,"threads=%d,prec=%lu", arg.threads, sizeof(Float));
219  return TuneKey(vol.str().c_str(), typeid(*this).name(), aux_string);
220 
221  }
222 
223  long long flops() const {
224  return 0;
225  }
226  long long bytes() const {
227  return 4LL * sizeof(Float) * arg.threads;
228  }
229 
230  };
231 
232 
 /**
  * Argument container for the gauge-fixing quality measurement. Derives from
  * ReduceArg<double2> so the per-site (action, theta) pair can be reduced
  * over the grid; result_h is the host-visible result provided by ReduceArg.
  */
 template <typename Float, typename Gauge>
 struct GaugeFixQualityArg : public ReduceArg<double2> {
 int threads; // number of active threads required
 int X[4]; // grid dimensions
 Gauge dataOr;
 complex<Float> *delta; // output: Delta(x), 6 stored entries per site

 GaugeFixQualityArg(const Gauge &dataOr, const cudaGaugeField &data, complex<Float> * delta)
   : ReduceArg<double2>(), dataOr(dataOr), delta(delta) {
 for ( int dir = 0; dir < 4; ++dir ) X[dir] = data.X()[dir];
 threads = data.VolumeCB(); // checkerboard volume: one thread per even/odd site
 }
 double getAction(){ return result_h[0].x; } // reduced action value
 double getTheta(){ return result_h[0].y; }  // reduced theta value
 };
248 
249 
250 
251  template<int blockSize, int Elems, typename Float, typename Gauge, int gauge_dir>
252  __global__ void computeFix_quality(GaugeFixQualityArg<Float, Gauge> argQ){
253  int idx = threadIdx.x + blockIdx.x * blockDim.x;
254  int parity = threadIdx.y;
255 
256  double2 data = make_double2(0.0,0.0);
257  if ( idx < argQ.threads ) {
258  typedef complex<Float> Cmplx;
259 
260  int x[4];
261  getCoords(x, idx, argQ.X, parity);
263  setZero(&delta);
264  //idx = linkIndex(x,X);
265  for ( int mu = 0; mu < gauge_dir; mu++ ) {
266  Matrix<Cmplx,3> U;
267  argQ.dataOr.load((Float*)(U.data),idx, mu, parity);
268  delta -= U;
269  }
270  //18*gauge_dir
271  data.x = -delta(0,0).x - delta(1,1).x - delta(2,2).x;
272  //2
273  for ( int mu = 0; mu < gauge_dir; mu++ ) {
274  Matrix<Cmplx,3> U;
275  argQ.dataOr.load((Float*)(U.data),linkIndexM1(x,argQ.X,mu), mu, 1 - parity);
276  delta += U;
277  }
278  //18*gauge_dir
279  delta -= conj(delta);
280  //18
281  //SAVE DELTA!!!!!
283  idx = getIndexFull(idx, argQ.X, parity);
284  //Saving Delta
285  argQ.delta[idx] = delta(0,0);
286  argQ.delta[idx + 2 * argQ.threads] = delta(0,1);
287  argQ.delta[idx + 4 * argQ.threads] = delta(0,2);
288  argQ.delta[idx + 6 * argQ.threads] = delta(1,1);
289  argQ.delta[idx + 8 * argQ.threads] = delta(1,2);
290  argQ.delta[idx + 10 * argQ.threads] = delta(2,2);
291  //12
292  data.y = getRealTraceUVdagger(delta, delta);
293  //35
294  //T=36*gauge_dir+65
295  }
296 
297  reduce2d<blockSize,2>(argQ, data);
298  }
299 
300 
301 
302  template<int Elems, typename Float, typename Gauge, int gauge_dir>
303  class GaugeFixQuality : TunableLocalParity {
304  GaugeFixQualityArg<Float, Gauge> argQ;
305  mutable char aux_string[128]; // used as a label in the autotuner
306  private:
307 
308  unsigned int minThreads() const { return argQ.threads; }
309 
310  public:
311  GaugeFixQuality(GaugeFixQualityArg<Float, Gauge> &argQ)
312  : argQ(argQ) {
313  }
314  ~GaugeFixQuality () { }
315 
316  void apply(const cudaStream_t &stream){
317  TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
318  argQ.result_h[0] = make_double2(0.0,0.0);
319  LAUNCH_KERNEL_LOCAL_PARITY(computeFix_quality, tp, stream, argQ, Elems, Float, Gauge, gauge_dir);
321  argQ.result_h[0].x /= (double)(3 * gauge_dir * 2 * argQ.threads);
322  argQ.result_h[0].y /= (double)(3 * 2 * argQ.threads);
323  }
324 
325  TuneKey tuneKey() const {
326  std::stringstream vol;
327  vol << argQ.X[0] << "x" << argQ.X[1] << "x" << argQ.X[2] << "x" << argQ.X[3];
328  sprintf(aux_string,"threads=%d,prec=%lu,gaugedir=%d", argQ.threads, sizeof(Float), gauge_dir);
329  return TuneKey(vol.str().c_str(), typeid(*this).name(), aux_string);
330  }
331 
332  long long flops() const {
333  return (36LL * gauge_dir + 65LL) * 2 * argQ.threads;
334  } // Only correct if there is no link reconstruction, no cub reduction accounted also
335  long long bytes() const {
336  return (2LL * gauge_dir + 2LL) * Elems * 2 * argQ.threads * sizeof(Float);
337  } //Not accounting the reduction!!!
338 
339  };
340 
341 
342 
343  template <typename Float>
344  struct GaugeFixArg {
345  int threads; // number of active threads required
346  int X[4]; // grid dimensions
347  cudaGaugeField &data;
348  Float *invpsq;
349  complex<Float> *delta;
350  complex<Float> *gx;
351 
352  GaugeFixArg( cudaGaugeField & data, const int Elems) : data(data){
353  for ( int dir = 0; dir < 4; ++dir ) X[dir] = data.X()[dir];
354  threads = X[0] * X[1] * X[2] * X[3];
355  invpsq = (Float*)device_malloc(sizeof(Float) * threads);
356  delta = (complex<Float>*)device_malloc(sizeof(complex<Float>) * threads * 6);
357 #ifdef GAUGEFIXING_DONT_USE_GX
358  gx = (complex<Float>*)device_malloc(sizeof(complex<Float>) * threads);
359 #else
360  gx = (complex<Float>*)device_malloc(sizeof(complex<Float>) * threads * Elems);
361 #endif
362  BindTex(delta, gx, sizeof(complex<Float>) * threads * Elems);
363  }
364  void free(){
365  UnBindTex(delta, gx);
366  device_free(invpsq);
368  device_free(gx);
369  }
370  };
371 
372 
373 
374 
375  template <typename Float>
376  __global__ void kernel_gauge_set_invpsq(GaugeFixArg<Float> arg){
377  int id = blockIdx.x * blockDim.x + threadIdx.x;
378  if ( id >= arg.threads ) return;
379  int x1 = id / (arg.X[2] * arg.X[3] * arg.X[0]);
380  int x0 = (id / (arg.X[2] * arg.X[3])) % arg.X[0];
381  int x3 = (id / arg.X[2]) % arg.X[3];
382  int x2 = id % arg.X[2];
383  //id = x2 + (x3 + (x0 + x1 * arg.X[0]) * arg.X[3]) * arg.X[2];
384  Float sx = sin( (Float)x0 * FL_UNITARIZE_PI / (Float)arg.X[0]);
385  Float sy = sin( (Float)x1 * FL_UNITARIZE_PI / (Float)arg.X[1]);
386  Float sz = sin( (Float)x2 * FL_UNITARIZE_PI / (Float)arg.X[2]);
387  Float st = sin( (Float)x3 * FL_UNITARIZE_PI / (Float)arg.X[3]);
388  Float sinsq = sx * sx + sy * sy + sz * sz + st * st;
389  Float prcfact = 0.0;
390  //The FFT normalization is done here
391  if ( sinsq > 0.00001 ) prcfact = 4.0 / (sinsq * (Float)arg.threads);
392  arg.invpsq[id] = prcfact;
393  }
394 
395 
396  template<typename Float>
397  class GaugeFixSETINVPSP : Tunable {
398  GaugeFixArg<Float> arg;
399  mutable char aux_string[128]; // used as a label in the autotuner
400  private:
401  unsigned int sharedBytesPerThread() const {
402  return 0;
403  }
404  unsigned int sharedBytesPerBlock(const TuneParam &param) const {
405  return 0;
406  }
407  bool tuneSharedBytes() const {
408  return false;
409  } // Don't tune shared memory
410  bool tuneGridDim() const {
411  return false;
412  } // Don't tune the grid dimensions.
413  unsigned int minThreads() const {
414  return arg.threads;
415  }
416 
417  public:
418  GaugeFixSETINVPSP(GaugeFixArg<Float> &arg) : arg(arg) { }
419  ~GaugeFixSETINVPSP () { }
420 
421  void apply(const cudaStream_t &stream){
422  TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
423  kernel_gauge_set_invpsq<Float><< < tp.grid, tp.block, 0, stream >> > (arg);
424  }
425 
426  TuneKey tuneKey() const {
427  std::stringstream vol;
428  vol << arg.X[0] << "x";
429  vol << arg.X[1] << "x";
430  vol << arg.X[2] << "x";
431  vol << arg.X[3];
432  sprintf(aux_string,"threads=%d,prec=%lu", arg.threads, sizeof(Float));
433  return TuneKey(vol.str().c_str(), typeid(*this).name(), aux_string);
434 
435  }
436 
437  long long flops() const {
438  return 21 * arg.threads;
439  }
440  long long bytes() const {
441  return sizeof(Float) * arg.threads;
442  }
443 
444  };
445 
446  template<typename Float>
447  __global__ void kernel_gauge_mult_norm_2D(GaugeFixArg<Float> arg){
448  int id = blockIdx.x * blockDim.x + threadIdx.x;
449  if ( id < arg.threads ) arg.gx[id] = arg.gx[id] * arg.invpsq[id];
450  }
451 
452 
 /**
  * @brief Tunable wrapper around kernel_gauge_mult_norm_2D, which applies the
  * precomputed 1/p^2 filter to the Fourier-transformed field in arg.gx.
  */
 template<typename Float>
 class GaugeFixINVPSP : Tunable {
 GaugeFixArg<Float> arg; // held by value: pointer swaps below only affect this copy
 mutable char aux_string[128]; // used as a label in the autotuner
 private:
 unsigned int sharedBytesPerThread() const {
 return 0;
 }
 unsigned int sharedBytesPerBlock(const TuneParam &param) const {
 return 0;
 }
 //bool tuneSharedBytes() const { return false; } // Don't tune shared memory
 bool tuneGridDim() const {
 return false;
 } // Don't tune the grid dimensions.
 unsigned int minThreads() const {
 return arg.threads;
 }

 public:
 GaugeFixINVPSP(GaugeFixArg<Float> &arg)
   : arg(arg){
 // kernel is a simple streaming multiply: prefer L1 over shared memory
 cudaFuncSetCacheConfig( kernel_gauge_mult_norm_2D<Float>, cudaFuncCachePreferL1);
 }
 ~GaugeFixINVPSP () {
 }

 void apply(const cudaStream_t &stream){
 TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
 kernel_gauge_mult_norm_2D<Float><< < tp.grid, tp.block, 0, stream >> > (arg);
 }

 TuneKey tuneKey() const {
 std::stringstream vol;
 vol << arg.X[0] << "x";
 vol << arg.X[1] << "x";
 vol << arg.X[2] << "x";
 vol << arg.X[3];
 sprintf(aux_string,"threads=%d,prec=%lu", arg.threads, sizeof(Float));
 return TuneKey(vol.str().c_str(), typeid(*this).name(), aux_string);

 }

 void preTune(){
 //since delta contents are irrelevant at this point, we can swap gx with delta
 complex<Float> *tmp = arg.gx;
 arg.gx = arg.delta;
 arg.delta = tmp;
 }
 void postTune(){
 // restores gx to the buffer saved in preTune; delta is left aliasing the
 // same buffer, which is harmless because this copy of arg never uses delta
 arg.gx = arg.delta;
 }
 long long flops() const {
 return 2LL * arg.threads;
 }
 long long bytes() const {
 return 5LL * sizeof(Float) * arg.threads;
 }

 };
513 
514 
515 
 /**
  * @brief Project a 3x3 complex matrix back onto SU(3) via Gram-Schmidt:
  * normalize row 0, orthogonalize row 1 against row 0 and normalize it, then
  * reconstruct row 2 as the conjugate cross product of rows 0 and 1.
  * The bare numeric comments are the flop counts of each step (total T=130).
  */
 template <typename Float>
 __host__ __device__ inline void reunit_link( Matrix<complex<Float>,3> &U ){

 complex<Float> t2((Float)0.0, (Float)0.0);
 Float t1 = 0.0;
 //first normalize first row
 //sum of squares of row
#pragma unroll
 for ( int c = 0; c < 3; c++ ) t1 += norm(U(0,c));
 t1 = (Float)1.0 / sqrt(t1);
 //14
 //used to normalize row
#pragma unroll
 for ( int c = 0; c < 3; c++ ) U(0,c) *= t1;
 //6
 // t2 = <row0, row1>, the component of row 1 along row 0
#pragma unroll
 for ( int c = 0; c < 3; c++ ) t2 += conj(U(0,c)) * U(1,c);
 //24
#pragma unroll
 for ( int c = 0; c < 3; c++ ) U(1,c) -= t2 * U(0,c);
 //24
 //normalize second row
 //sum of squares of row
 t1 = 0.0;
#pragma unroll
 for ( int c = 0; c < 3; c++ ) t1 += norm(U(1,c));
 t1 = (Float)1.0 / sqrt(t1);
 //14
 //used to normalize row
#pragma unroll
 for ( int c = 0; c < 3; c++ ) U(1, c) *= t1;
 //6
 //Reconstruct last row
 U(2,0) = conj(U(0,1) * U(1,2) - U(0,2) * U(1,1));
 U(2,1) = conj(U(0,2) * U(1,0) - U(0,0) * U(1,2));
 U(2,2) = conj(U(0,0) * U(1,1) - U(0,1) * U(1,0));
 //42
 //T=130
 }
555 
556 #ifdef GAUGEFIXING_DONT_USE_GX
557 
558  template <typename Float, typename Gauge>
559  __global__ void kernel_gauge_fix_U_EO_NEW( GaugeFixArg<Float> arg, Gauge dataOr, Float half_alpha){
560  int id = threadIdx.x + blockIdx.x * blockDim.x;
561  int parity = threadIdx.y;
562 
563  if ( id >= arg.threads/2 ) return;
564 
565  typedef complex<Float> Cmplx;
566 
567  int x[4];
568  getCoords(x, id, arg.X, parity);
569  int idx = ((x[3] * arg.X[2] + x[2]) * arg.X[1] + x[1]) * arg.X[0] + x[0];
570  Matrix<Cmplx,3> de;
571  //Read Delta
572 #ifdef GAUGEFIXING_SITE_MATRIX_LOAD_TEX
573  de(0,0) = TEXTURE_DELTA<Cmplx>(idx + 0 * arg.threads);
574  de(0,1) = TEXTURE_DELTA<Cmplx>(idx + 1 * arg.threads);
575  de(0,2) = TEXTURE_DELTA<Cmplx>(idx + 2 * arg.threads);
576  de(1,1) = TEXTURE_DELTA<Cmplx>(idx + 3 * arg.threads);
577  de(1,2) = TEXTURE_DELTA<Cmplx>(idx + 4 * arg.threads);
578  de(2,2) = TEXTURE_DELTA<Cmplx>(idx + 5 * arg.threads);
579 #else
580  de(0,0) = arg.delta[idx + 0 * arg.threads];
581  de(0,1) = arg.delta[idx + 1 * arg.threads];
582  de(0,2) = arg.delta[idx + 2 * arg.threads];
583  de(1,1) = arg.delta[idx + 3 * arg.threads];
584  de(1,2) = arg.delta[idx + 4 * arg.threads];
585  de(2,2) = arg.delta[idx + 5 * arg.threads];
586 #endif
587  de(1,0) = Cmplx(-de(0,1).x, de(0,1).y);
588  de(2,0) = Cmplx(-de(0,2).x, de(0,2).y);
589  de(2,1) = Cmplx(-de(1,2).x, de(1,2).y);
590  Matrix<Cmplx,3> g;
591  setIdentity(&g);
592  g += de * half_alpha;
593  //36
594  reunit_link<Float>( g );
595  //130
596 
597 
598  for ( int mu = 0; mu < 4; mu++ ) {
599  Matrix<Cmplx,3> U;
600  Matrix<Cmplx,3> g0;
601  dataOr.load((Float*)(U.data),id, mu, parity);
602  U = g * U;
603  //198
605  //Read Delta
606 #ifdef GAUGEFIXING_SITE_MATRIX_LOAD_TEX
607  de(0,0) = TEXTURE_DELTA<Cmplx>(idx + 0 * arg.threads);
608  de(0,1) = TEXTURE_DELTA<Cmplx>(idx + 1 * arg.threads);
609  de(0,2) = TEXTURE_DELTA<Cmplx>(idx + 2 * arg.threads);
610  de(1,1) = TEXTURE_DELTA<Cmplx>(idx + 3 * arg.threads);
611  de(1,2) = TEXTURE_DELTA<Cmplx>(idx + 4 * arg.threads);
612  de(2,2) = TEXTURE_DELTA<Cmplx>(idx + 5 * arg.threads);
613 #else
614  de(0,0) = arg.delta[idx + 0 * arg.threads];
615  de(0,1) = arg.delta[idx + 1 * arg.threads];
616  de(0,2) = arg.delta[idx + 2 * arg.threads];
617  de(1,1) = arg.delta[idx + 3 * arg.threads];
618  de(1,2) = arg.delta[idx + 4 * arg.threads];
619  de(2,2) = arg.delta[idx + 5 * arg.threads];
620 #endif
621  de(1,0) = Cmplx(-de(0,1).x, de(0,1).y);
622  de(2,0) = Cmplx(-de(0,2).x, de(0,2).y);
623  de(2,1) = Cmplx(-de(1,2).x, de(1,2).y);
624 
625  setIdentity(&g0);
626  g0 += de * half_alpha;
627  //36
628  reunit_link<Float>( g0 );
629  //130
630 
631  U = U * conj(g0);
632  //198
633  dataOr.save((Float*)(U.data),id, mu, parity);
634  }
635  }
636 
637 
638  template<typename Float, typename Gauge>
639  class GaugeFixNEW : TunableLocalParity {
640  GaugeFixArg<Float> arg;
641  Float half_alpha;
642  Gauge dataOr;
643  mutable char aux_string[128]; // used as a label in the autotuner
644  private:
645 
646  // since GaugeFixArg is used by other kernels that don't use
647  // tunableLocalParity, arg.threads stores Volume and not VolumeCB
648  // so we need to divide by two
649  unsigned int minThreads() const { return arg.threads/2; }
650 
651  public:
652  GaugeFixNEW(Gauge & dataOr, GaugeFixArg<Float> &arg, Float alpha)
653  : dataOr(dataOr), arg(arg) {
654  half_alpha = alpha * 0.5;
655  cudaFuncSetCacheConfig( kernel_gauge_fix_U_EO_NEW<Float, Gauge>, cudaFuncCachePreferL1);
656  }
657  ~GaugeFixNEW () { }
658 
659  void setAlpha(Float alpha){ half_alpha = alpha * 0.5; }
660 
661  void apply(const cudaStream_t &stream){
662  TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
663  kernel_gauge_fix_U_EO_NEW<Float, Gauge><< < tp.grid, tp.block, 0, stream >> > (arg, dataOr, half_alpha);
664  }
665 
666  TuneKey tuneKey() const {
667  std::stringstream vol;
668  vol << arg.X[0] << "x" << arg.X[1] << "x" << arg.X[2] << "x" << arg.X[3];
669  sprintf(aux_string,"threads=%d,prec=%lu", arg.threads, sizeof(Float));
670  return TuneKey(vol.str().c_str(), typeid(*this).name(), aux_string);
671 
672  }
673 
674  //need this
675  void preTune() {
676  arg.data.backup();
677  }
678  void postTune() {
679  arg.data.restore();
680  }
681  long long flops() const {
682  return 2414LL * arg.threads;
683  //Not accounting here the reconstruction of the gauge if 12 or 8!!!!!!
684  }
685  long long bytes() const {
686  return ( dataOr.Bytes() * 4LL + 5 * 12LL * sizeof(Float)) * arg.threads;
687  }
688 
689  };
690 
691 
692 
693 #else
 /**
  * @brief Precalculate the gauge transformation
  * g(x) = reunit(1 + half_alpha * Delta(x)) for every lattice site.
  * Delta(x) is read in normal lattice order (6 stored entries per site,
  * stride arg.threads); g is written in even/odd order with Elems complex
  * entries per site so the link-update kernel can consume it directly.
  */
 template <int Elems, typename Float>
 __global__ void kernel_gauge_GX(GaugeFixArg<Float> arg, Float half_alpha){

 int id = blockIdx.x * blockDim.x + threadIdx.x;

 if ( id >= arg.threads ) return;

 typedef complex<Float> Cmplx;

 Matrix<Cmplx,3> de;
 //Read Delta
 #ifdef GAUGEFIXING_SITE_MATRIX_LOAD_TEX
 de(0,0) = TEXTURE_DELTA<Cmplx>(id);
 de(0,1) = TEXTURE_DELTA<Cmplx>(id + arg.threads);
 de(0,2) = TEXTURE_DELTA<Cmplx>(id + 2 * arg.threads);
 de(1,1) = TEXTURE_DELTA<Cmplx>(id + 3 * arg.threads);
 de(1,2) = TEXTURE_DELTA<Cmplx>(id + 4 * arg.threads);
 de(2,2) = TEXTURE_DELTA<Cmplx>(id + 5 * arg.threads);
 #else
 de(0,0) = arg.delta[id];
 de(0,1) = arg.delta[id + arg.threads];
 de(0,2) = arg.delta[id + 2 * arg.threads];
 de(1,1) = arg.delta[id + 3 * arg.threads];
 de(1,2) = arg.delta[id + 4 * arg.threads];
 de(2,2) = arg.delta[id + 5 * arg.threads];
 #endif
 // fill the lower triangle from the stored upper triangle
 de(1,0) = makeComplex(-de(0,1).x, de(0,1).y);
 de(2,0) = makeComplex(-de(0,2).x, de(0,2).y);
 de(2,1) = makeComplex(-de(1,2).x, de(1,2).y);


 Matrix<Cmplx,3> g;
 setIdentity(&g);
 g += de * half_alpha;
 //36
 reunit_link<Float>( g );
 //130
 //gx is represented in even/odd order
 //normal lattice index to even/odd index
 int x3 = id / (arg.X[0] * arg.X[1] * arg.X[2]);
 int x2 = (id / (arg.X[0] * arg.X[1])) % arg.X[2];
 int x1 = (id / arg.X[0]) % arg.X[1];
 int x0 = id % arg.X[0];
 // checkerboard index, with odd sites offset by half the volume
 id = (x0 + (x1 + (x2 + x3 * arg.X[2]) * arg.X[1]) * arg.X[0]) >> 1;
 id += ((x0 + x1 + x2 + x3) & 1 ) * arg.threads / 2;

 for ( int i = 0; i < Elems; i++ ) arg.gx[id + i * arg.threads] = g.data[i];
 //T=166 for Elems 9
 //T=208 for Elems 6
 }
744 
745 
746 
747 
748  template<int Elems, typename Float>
749  class GaugeFix_GX : Tunable {
750  GaugeFixArg<Float> arg;
751  Float half_alpha;
752  mutable char aux_string[128]; // used as a label in the autotuner
753  private:
754  unsigned int sharedBytesPerThread() const {
755  return 0;
756  }
757  unsigned int sharedBytesPerBlock(const TuneParam &param) const {
758  return 0;
759  }
760  //bool tuneSharedBytes() const { return false; } // Don't tune shared memory
761  bool tuneGridDim() const {
762  return false;
763  } // Don't tune the grid dimensions.
764  unsigned int minThreads() const {
765  return arg.threads;
766  }
767 
768  public:
769  GaugeFix_GX(GaugeFixArg<Float> &arg, Float alpha)
770  : arg(arg) {
771  half_alpha = alpha * 0.5;
772  cudaFuncSetCacheConfig( kernel_gauge_GX<Elems, Float>, cudaFuncCachePreferL1);
773  }
774  ~GaugeFix_GX () {
775  }
776 
777  void setAlpha(Float alpha){
778  half_alpha = alpha * 0.5;
779  }
780 
781 
782  void apply(const cudaStream_t &stream){
783  TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
784  kernel_gauge_GX<Elems, Float><< < tp.grid, tp.block, 0, stream >> > (arg, half_alpha);
785  }
786 
787  TuneKey tuneKey() const {
788  std::stringstream vol;
789  vol << arg.X[0] << "x";
790  vol << arg.X[1] << "x";
791  vol << arg.X[2] << "x";
792  vol << arg.X[3];
793  sprintf(aux_string,"threads=%d,prec=%lu", arg.threads, sizeof(Float));
794  return TuneKey(vol.str().c_str(), typeid(*this).name(), aux_string);
795 
796  }
797 
798  long long flops() const {
799  if ( Elems == 6 ) return 208LL * arg.threads;
800  else return 166LL * arg.threads;
801  }
802  long long bytes() const {
803  return 4LL * Elems * sizeof(Float) * arg.threads;
804  }
805 
806  };
807 
808 
 /**
  * @brief Apply the precalculated gauge transformation to the links:
  * U_mu(x) <- g(x) * U_mu(x) * g^dagger(x+mu).
  * g is stored in even/odd order with Elems (6 or 9) complex entries per
  * site; for Elems == 6 the third row is reconstructed from the first two.
  * One thread per full-volume site: the first arg.threads/2 threads handle
  * even sites, the remainder odd sites.
  */
 template <int Elems, typename Float, typename Gauge>
 __global__ void kernel_gauge_fix_U_EO( GaugeFixArg<Float> arg, Gauge dataOr){
 int idd = threadIdx.x + blockIdx.x * blockDim.x;

 if ( idd >= arg.threads ) return;

 // split the flat index into (parity, checkerboard index)
 int parity = 0;
 int id = idd;
 if ( idd >= arg.threads / 2 ) {
 parity = 1;
 id -= arg.threads / 2;
 }
 typedef complex<Float> Cmplx;

 Matrix<Cmplx,3> g;
 //for(int i = 0; i < Elems; i++) g.data[i] = arg.gx[idd + i * arg.threads];
 for ( int i = 0; i < Elems; i++ ) {
 #ifdef GAUGEFIXING_SITE_MATRIX_LOAD_TEX
 g.data[i] = TEXTURE_GX<Cmplx>(idd + i * arg.threads);
 #else
 g.data[i] = arg.gx[idd + i * arg.threads];
 #endif
 }
 if ( Elems == 6 ) {
 // reconstruct the third row from the first two
 g(2,0) = conj(g(0,1) * g(1,2) - g(0,2) * g(1,1));
 g(2,1) = conj(g(0,2) * g(1,0) - g(0,0) * g(1,2));
 g(2,2) = conj(g(0,0) * g(1,1) - g(0,1) * g(1,0));
 //42
 }
 int x[4];
 getCoords(x, id, arg.X, parity);
 for ( int mu = 0; mu < 4; mu++ ) {
 Matrix<Cmplx,3> U;
 Matrix<Cmplx,3> g0;
 dataOr.load((Float*)(U.data),id, mu, parity);
 U = g * U;
 //198
 // even/odd index of the forward neighbor x+mu
 // NOTE(review): the name idm1 is misleading — linkIndexP1 is the +mu direction
 int idm1 = linkIndexP1(x,arg.X,mu);
 idm1 += (1 - parity) * arg.threads / 2;
 //for(int i = 0; i < Elems; i++) g0.data[i] = arg.gx[idm1 + i * arg.threads];
 for ( int i = 0; i < Elems; i++ ) {
 #ifdef GAUGEFIXING_SITE_MATRIX_LOAD_TEX
 g0.data[i] = TEXTURE_GX<Cmplx>(idm1 + i * arg.threads);
 #else
 g0.data[i] = arg.gx[idm1 + i * arg.threads];
 #endif
 }
 if ( Elems == 6 ) {
 g0(2,0) = conj(g0(0,1) * g0(1,2) - g0(0,2) * g0(1,1));
 g0(2,1) = conj(g0(0,2) * g0(1,0) - g0(0,0) * g0(1,2));
 g0(2,2) = conj(g0(0,0) * g0(1,1) - g0(0,1) * g0(1,0));
 //42
 }
 U = U * conj(g0);
 //198
 dataOr.save((Float*)(U.data),id, mu, parity);
 }
 //T=42+4*(198*2+42) Elems=6
 //T=4*(198*2) Elems=9
 //Not accounting here the reconstruction of the gauge if 12 or 8!!!!!!
 }
870 
871 
872  template<int Elems, typename Float, typename Gauge>
873  class GaugeFix : Tunable {
874  GaugeFixArg<Float> arg;
875  Gauge dataOr;
876  mutable char aux_string[128]; // used as a label in the autotuner
877  private:
878  unsigned int sharedBytesPerThread() const {
879  return 0;
880  }
881  unsigned int sharedBytesPerBlock(const TuneParam &param) const {
882  return 0;
883  }
884  //bool tuneSharedBytes() const { return false; } // Don't tune shared memory
885  bool tuneGridDim() const {
886  return false;
887  } // Don't tune the grid dimensions.
888  unsigned int minThreads() const {
889  return arg.threads;
890  }
891 
892  public:
893  GaugeFix(Gauge & dataOr, GaugeFixArg<Float> &arg)
894  : dataOr(dataOr), arg(arg) {
895  cudaFuncSetCacheConfig( kernel_gauge_fix_U_EO<Elems, Float, Gauge>, cudaFuncCachePreferL1);
896  }
897  ~GaugeFix () { }
898 
899 
900  void apply(const cudaStream_t &stream){
901  TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
902  kernel_gauge_fix_U_EO<Elems, Float, Gauge><< < tp.grid, tp.block, 0, stream >> > (arg, dataOr);
903  }
904 
905  TuneKey tuneKey() const {
906  std::stringstream vol;
907  vol << arg.X[0] << "x";
908  vol << arg.X[1] << "x";
909  vol << arg.X[2] << "x";
910  vol << arg.X[3];
911  sprintf(aux_string,"threads=%d,prec=%lu", arg.threads, sizeof(Float));
912  return TuneKey(vol.str().c_str(), typeid(*this).name(), aux_string);
913 
914  }
915 
916  //need this
917  void preTune() {
918  arg.data.backup();
919  }
920  void postTune() {
921  arg.data.restore();
922  }
923  long long flops() const {
924  if ( Elems == 6 ) return 1794LL * arg.threads;
925  else return 1536LL * arg.threads;
926  //Not accounting here the reconstruction of the gauge if 12 or 8!!!!!!
927  }
928  long long bytes() const {
929  return 26LL * Elems * sizeof(Float) * arg.threads;
930  }
931 
932  };
933 #endif
934 //GAUGEFIXING_DONT_USE_GX
935 
936 
937  template<int Elems, typename Float, typename Gauge, int gauge_dir>
938  void gaugefixingFFT( Gauge dataOr, cudaGaugeField& data, \
939  const int Nsteps, const int verbose_interval, \
940  const Float alpha0, const int autotune, const double tolerance, \
941  const int stopWtheta) {
942 
943  TimeProfile profileInternalGaugeFixFFT("InternalGaugeFixQudaFFT", false);
944 
945  profileInternalGaugeFixFFT.TPSTART(QUDA_PROFILE_COMPUTE);
946 
947  Float alpha = alpha0;
948  std::cout << "\tAlpha parameter of the Steepest Descent Method: " << alpha << std::endl;
949  if ( autotune ) std::cout << "\tAuto tune active: yes" << std::endl;
950  else std::cout << "\tAuto tune active: no" << std::endl;
951  std::cout << "\tStop criterium: " << tolerance << std::endl;
952  if ( stopWtheta ) std::cout << "\tStop criterium method: theta" << std::endl;
953  else std::cout << "\tStop criterium method: Delta" << std::endl;
954  std::cout << "\tMaximum number of iterations: " << Nsteps << std::endl;
955  std::cout << "\tPrint convergence results at every " << verbose_interval << " steps" << std::endl;
956 
957 
958  unsigned int delta_pad = data.X()[0] * data.X()[1] * data.X()[2] * data.X()[3];
959  int4 size = make_int4( data.X()[0], data.X()[1], data.X()[2], data.X()[3] );
960  cufftHandle plan_xy;
961  cufftHandle plan_zt;
962 
963  GaugeFixArg<Float> arg(data, Elems);
964  SetPlanFFT2DMany( plan_zt, size, 0, arg.delta); //for space and time ZT
965  SetPlanFFT2DMany( plan_xy, size, 1, arg.delta); //with space only XY
966 
967 
968  GaugeFixFFTRotateArg<Float> arg_rotate(data);
969  GaugeFixFFTRotate<Float> GFRotate(arg_rotate);
970 
971  GaugeFixSETINVPSP<Float> setinvpsp(arg);
972  setinvpsp.apply(0);
973  GaugeFixINVPSP<Float> invpsp(arg);
974 
975 
976 #ifdef GAUGEFIXING_DONT_USE_GX
977  //without using GX, gx will be created only for plane rotation but with less size
978  GaugeFixNEW<Float, Gauge> gfixNew(dataOr, arg, alpha);
979 #else
980  //using GX
981  GaugeFix_GX<Elems, Float> calcGX(arg, alpha);
982  GaugeFix<Elems, Float, Gauge> gfix(dataOr, arg);
983 #endif
984 
985  GaugeFixQualityArg<Float, Gauge> argQ(dataOr, data, arg.delta);
986  GaugeFixQuality<Elems, Float, Gauge, gauge_dir> gfixquality(argQ);
987 
988  gfixquality.apply(0);
989  double action0 = argQ.getAction();
990  printf("Step: %d\tAction: %.16e\ttheta: %.16e\n", 0, argQ.getAction(), argQ.getTheta());
991 
992  double diff = 0.0;
993  int iter = 0;
994  for ( iter = 0; iter < Nsteps; iter++ ) {
995  for ( int k = 0; k < 6; k++ ) {
996  //------------------------------------------------------------------------
997  // Set a pointer do the element k in lattice volume
998  // each element is stored with stride lattice volume
999  // it uses gx as temporary array!!!!!!
1000  //------------------------------------------------------------------------
1001  complex<Float> *_array = arg.delta + k * delta_pad;
1003  //------------------------------------------------------------------------
1004  // Perform FFT on xy plane
1005  //------------------------------------------------------------------------
1006  ApplyFFT(plan_xy, _array, arg.gx, CUFFT_FORWARD);
1007  //------------------------------------------------------------------------
1008  // Rotate hypercube, xyzt -> ztxy
1009  //------------------------------------------------------------------------
1010  GFRotate.setDirection(0, arg.gx, _array);
1011  GFRotate.apply(0);
1012  //------------------------------------------------------------------------
1013  // Perform FFT on zt plane
1014  //------------------------------------------------------------------------
1015  ApplyFFT(plan_zt, _array, arg.gx, CUFFT_FORWARD);
1016  //------------------------------------------------------------------------
1017  // Normalize FFT and apply pmax^2/p^2
1018  //------------------------------------------------------------------------
1019  invpsp.apply(0);
1020  //------------------------------------------------------------------------
1021  // Perform IFFT on zt plane
1022  //------------------------------------------------------------------------
1023  ApplyFFT(plan_zt, arg.gx, _array, CUFFT_INVERSE);
1024  //------------------------------------------------------------------------
1025  // Rotate hypercube, ztxy -> xyzt
1026  //------------------------------------------------------------------------
1027  GFRotate.setDirection(1, _array, arg.gx);
1028  GFRotate.apply(0);
1029  //------------------------------------------------------------------------
1030  // Perform IFFT on xy plane
1031  //------------------------------------------------------------------------
1032  ApplyFFT(plan_xy, arg.gx, _array, CUFFT_INVERSE);
1033  }
1034  #ifdef GAUGEFIXING_DONT_USE_GX
1035  //------------------------------------------------------------------------
1036  // Apply gauge fix to current gauge field
1037  //------------------------------------------------------------------------
1038  gfixNew.apply(0);
1039  #else
1040  //------------------------------------------------------------------------
1041  // Calculate g(x)
1042  //------------------------------------------------------------------------
1043  calcGX.apply(0);
1044  //------------------------------------------------------------------------
1045  // Apply gauge fix to current gauge field
1046  //------------------------------------------------------------------------
1047  gfix.apply(0);
1048  #endif
1049  //------------------------------------------------------------------------
1050  // Measure gauge quality and recalculate new Delta(x)
1051  //------------------------------------------------------------------------
1052  gfixquality.apply(0);
1053  double action = argQ.getAction();
1054  diff = abs(action0 - action);
1055  if ((iter % verbose_interval) == (verbose_interval - 1))
1056  printf("Step: %d\tAction: %.16e\ttheta: %.16e\tDelta: %.16e\n", iter + 1, argQ.getAction(), argQ.getTheta(), diff);
1057  if ( autotune && ((action - action0) < -1e-14) ) {
1058  if ( alpha > 0.01 ) {
1059  alpha = 0.95 * alpha;
1060  #ifdef GAUGEFIXING_DONT_USE_GX
1061  gfixNew.setAlpha(alpha);
1062  #else
1063  calcGX.setAlpha(alpha);
1064  #endif
1065  printf(">>>>>>>>>>>>>> Warning: changing alpha down -> %.4e\n", alpha );
1066  }
1067  }
1068  //------------------------------------------------------------------------
1069  // Check gauge fix quality criterium
1070  //------------------------------------------------------------------------
1071  if ( stopWtheta ) { if ( argQ.getTheta() < tolerance ) break; }
1072  else { if ( diff < tolerance ) break; }
1073 
1074  action0 = action;
1075  }
1076  if ((iter % verbose_interval) != 0 )
1077  printf("Step: %d\tAction: %.16e\ttheta: %.16e\tDelta: %.16e\n", iter, argQ.getAction(), argQ.getTheta(), diff);
1078 
1079  // Reunitarize at end
1080  const double unitarize_eps = 1e-14;
1081  const double max_error = 1e-10;
1082  const int reunit_allow_svd = 1;
1083  const int reunit_svd_only = 0;
1084  const double svd_rel_error = 1e-6;
1085  const double svd_abs_error = 1e-6;
1089  int num_failures = 0;
1090  int* num_failures_dev = static_cast<int*>(pool_device_malloc(sizeof(int)));
1091  cudaMemset(num_failures_dev, 0, sizeof(int));
1092  unitarizeLinks(data, data, num_failures_dev);
1093  qudaMemcpy(&num_failures, num_failures_dev, sizeof(int), cudaMemcpyDeviceToHost);
1094 
1096  if ( num_failures > 0 ) {
1097  errorQuda("Error in the unitarization\n");
1098  exit(1);
1099  }
1100  // end reunitarize
1101 
1102 
1103  arg.free();
1104  CUFFT_SAFE_CALL(cufftDestroy(plan_zt));
1105  CUFFT_SAFE_CALL(cufftDestroy(plan_xy));
1106  checkCudaError();
1108  profileInternalGaugeFixFFT.TPSTOP(QUDA_PROFILE_COMPUTE);
1109 
1110  if (getVerbosity() > QUDA_SUMMARIZE){
1111  double secs = profileInternalGaugeFixFFT.Last(QUDA_PROFILE_COMPUTE);
1112  double fftflop = 5.0 * (log2((double)( data.X()[0] * data.X()[1]) ) + log2( (double)(data.X()[2] * data.X()[3] )));
1113  fftflop *= (double)( data.X()[0] * data.X()[1] * data.X()[2] * data.X()[3] );
1114  double gflops = setinvpsp.flops() + gfixquality.flops();
1115  double gbytes = setinvpsp.bytes() + gfixquality.bytes();
1116  double flop = invpsp.flops() * Elems;
1117  double byte = invpsp.bytes() * Elems;
1118  flop += (GFRotate.flops() + fftflop) * Elems * 2;
1119  byte += GFRotate.bytes() * Elems * 4; //includes FFT reads, assuming 1 read and 1 write per site
1120  #ifdef GAUGEFIXING_DONT_USE_GX
1121  flop += gfixNew.flops();
1122  byte += gfixNew.bytes();
1123  #else
1124  flop += calcGX.flops();
1125  byte += calcGX.bytes();
1126  flop += gfix.flops();
1127  byte += gfix.bytes();
1128  #endif
1129  flop += gfixquality.flops();
1130  byte += gfixquality.bytes();
1131  gflops += flop * iter;
1132  gbytes += byte * iter;
1133  gflops += 4588.0 * data.X()[0]*data.X()[1]*data.X()[2]*data.X()[3]; //Reunitarize at end
1134  gbytes += 8.0 * data.X()[0]*data.X()[1]*data.X()[2]*data.X()[3] * dataOr.Bytes() ; //Reunitarize at end
1135 
1136  gflops = (gflops * 1e-9) / (secs);
1137  gbytes = gbytes / (secs * 1e9);
1138  printfQuda("Time: %6.6f s, Gflop/s = %6.1f, GB/s = %6.1f\n", secs, gflops, gbytes);
1139  }
1140  }
1141 
1142  template<int Elems, typename Float, typename Gauge>
1143  void gaugefixingFFT( Gauge dataOr, cudaGaugeField& data, const int gauge_dir, \
1144  const int Nsteps, const int verbose_interval, const Float alpha, const int autotune, \
1145  const double tolerance, const int stopWtheta) {
1146  if ( gauge_dir != 3 ) {
1147  printf("Starting Landau gauge fixing with FFTs...\n");
1148  gaugefixingFFT<Elems, Float, Gauge, 4>(dataOr, data, Nsteps, verbose_interval, alpha, autotune, tolerance, stopWtheta);
1149  }
1150  else {
1151  printf("Starting Coulomb gauge fixing with FFTs...\n");
1152  gaugefixingFFT<Elems, Float, Gauge, 3>(dataOr, data, Nsteps, verbose_interval, alpha, autotune, tolerance, stopWtheta);
1153  }
1154  }
1155 
1156 
1157 
1158  template<typename Float>
1159  void gaugefixingFFT( cudaGaugeField& data, const int gauge_dir, \
1160  const int Nsteps, const int verbose_interval, const Float alpha, const int autotune, \
1161  const double tolerance, const int stopWtheta) {
1162 
1163  // Switching to FloatNOrder for the gauge field in order to support RECONSTRUCT_12
1164  // Need to fix this!!
1165  //9 and 6 means the number of complex elements used to store g(x) and Delta(x)
1166  if ( data.isNative() ) {
1167  if ( data.Reconstruct() == QUDA_RECONSTRUCT_NO ) {
1168  //printfQuda("QUDA_RECONSTRUCT_NO\n");
1169  typedef typename gauge_mapper<Float,QUDA_RECONSTRUCT_NO>::type Gauge;
1170  gaugefixingFFT<9, Float>(Gauge(data), data, gauge_dir, Nsteps, verbose_interval, alpha, autotune, tolerance, stopWtheta);
1171  } else if ( data.Reconstruct() == QUDA_RECONSTRUCT_12 ) {
1172  //printfQuda("QUDA_RECONSTRUCT_12\n");
1173  typedef typename gauge_mapper<Float,QUDA_RECONSTRUCT_12>::type Gauge;
1174  gaugefixingFFT<6, Float>(Gauge(data), data, gauge_dir, Nsteps, verbose_interval, alpha, autotune, tolerance, stopWtheta);
1175  } else if ( data.Reconstruct() == QUDA_RECONSTRUCT_8 ) {
1176  //printfQuda("QUDA_RECONSTRUCT_8\n");
1177  typedef typename gauge_mapper<Float,QUDA_RECONSTRUCT_8>::type Gauge;
1178  gaugefixingFFT<6, Float>(Gauge(data), data, gauge_dir, Nsteps, verbose_interval, alpha, autotune, tolerance, stopWtheta);
1179 
1180  } else {
1181  errorQuda("Reconstruction type %d of gauge field not supported", data.Reconstruct());
1182  }
1183  } else {
1184  errorQuda("Invalid Gauge Order\n");
1185  }
1186  }
1187 
1188 #endif // GPU_GAUGE_ALG
1189 
1190 
1202  void gaugefixingFFT( cudaGaugeField& data, const int gauge_dir, \
1203  const int Nsteps, const int verbose_interval, const double alpha, const int autotune, \
1204  const double tolerance, const int stopWtheta) {
1205 
1206 #ifdef GPU_GAUGE_ALG
1207 #ifdef MULTI_GPU
1209  errorQuda("Gauge Fixing with FFTs in multi-GPU support NOT implemented yet!\n");
1210 #endif
1211  if ( data.Precision() == QUDA_HALF_PRECISION ) {
1212  errorQuda("Half precision not supported\n");
1213  }
1214  if ( data.Precision() == QUDA_SINGLE_PRECISION ) {
1215  gaugefixingFFT<float> (data, gauge_dir, Nsteps, verbose_interval, (float)alpha, autotune, tolerance, stopWtheta);
1216  } else if ( data.Precision() == QUDA_DOUBLE_PRECISION ) {
1217  gaugefixingFFT<double>(data, gauge_dir, Nsteps, verbose_interval, alpha, autotune, tolerance, stopWtheta);
1218  } else {
1219  errorQuda("Precision %d not supported", data.Precision());
1220  }
1221 #else
1222  errorQuda("Gauge fixing has bot been built");
1223 #endif
1224  }
1225 
1226 
1227 
1228 }
static __device__ __host__ int getIndexFull(int cb_index, const I X[4], int parity)
#define qudaMemcpy(dst, src, count, kind)
Definition: quda_cuda_api.h:32
dim3 dim3 blockDim
void free(void *)
double mu
Definition: test_util.cpp:1643
__device__ __host__ void setZero(Matrix< T, N > *m)
Definition: quda_matrix.h:592
#define LAUNCH_KERNEL_LOCAL_PARITY(kernel, tp, stream, arg,...)
__host__ __device__ ValueType norm(const complex< ValueType > &z)
Returns the magnitude of z squared.
Definition: complex_quda.h:896
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:20
#define errorQuda(...)
Definition: util_quda.h:90
void setUnitarizeLinksConstants(double unitarize_eps, double max_error, bool allow_svd, bool svd_only, double svd_rel_error, double svd_abs_error)
int * num_failures_dev
__host__ __device__ ValueType sqrt(ValueType x)
Definition: complex_quda.h:105
void SetPlanFFT2DMany(cufftHandle &plan, int4 size, int dim, float2 *data)
Creates a CUFFT plan supporting 4D (2D+2D) data layouts for single-precision complex-to-complex.
Definition: CUFFT_Plans.h:96
cudaStream_t * stream
cudaColorSpinorField * tmp
Definition: covdev_test.cpp:44
__device__ __host__ double getRealTraceUVdagger(const Matrix< T, 3 > &a, const Matrix< T, 3 > &b)
Definition: quda_matrix.h:1021
int num_failures
void exit(int) __attribute__((noreturn))
double log2(double)
QudaGaugeParam param
Definition: pack_test.cpp:17
static unsigned int delta
void unitarizeLinks(cudaGaugeField &outfield, const cudaGaugeField &infield, int *fails)
int printf(const char *,...) __attribute__((__format__(__printf__
static __device__ __host__ int linkIndexM1(const int x[], const I X[4], const int mu)
__host__ __device__ ValueType sin(ValueType x)
Definition: complex_quda.h:40
def id
projector matrices ######################################################################## ...
for(int s=0;s< param.dc.Ls;s++)
T data[N *N]
Definition: quda_matrix.h:74
#define pool_device_malloc(size)
Definition: malloc_quda.h:113
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
Definition: tune.cpp:603
Main header file for host and device accessors to GaugeFields.
#define tmp1
Definition: tmc_core.h:15
__device__ __host__ void SubTraceUnit(Matrix< T, 3 > &a)
Definition: quda_matrix.h:1015
cudaError_t qudaDeviceSynchronize()
Wrapper around cudaDeviceSynchronize or cuDeviceSynchronize.
__device__ __host__ void setIdentity(Matrix< T, N > *m)
Definition: quda_matrix.h:543
void ApplyFFT(cufftHandle &plan, float2 *data_in, float2 *data_out, int direction)
Call CUFFT to perform a single-precision complex-to-complex transform plan in the transform direction...
Definition: CUFFT_Plans.h:29
void gaugefixingFFT(cudaGaugeField &data, const int gauge_dir, const int Nsteps, const int verbose_interval, const double alpha, const int autotune, const double tolerance, const int stopWtheta)
Gauge fixing with Steepest descent method with FFTs with support for single GPU only.
int sprintf(char *, const char *,...) __attribute__((__format__(__printf__
#define printfQuda(...)
Definition: util_quda.h:84
int VolumeCB() const
unsigned long long flops
Definition: blas_quda.cu:42
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
Definition: complex_quda.h:880
#define device_malloc(size)
Definition: malloc_quda.h:52
const void * c
QudaReconstructType Reconstruct() const
Definition: gauge_field.h:203
__host__ __device__ ValueType abs(ValueType x)
Definition: complex_quda.h:110
#define pool_device_free(ptr)
Definition: malloc_quda.h:114
#define checkCudaError()
Definition: util_quda.h:129
__host__ __device__ ValueType conj(ValueType x)
Definition: complex_quda.h:115
QudaTune getTuning()
Query whether autotuning is enabled or not. Default is enabled but can be overridden by setting QUDA_...
Definition: util_quda.cpp:51
QudaPrecision Precision() const
static __device__ __host__ int linkIndexP1(const int x[], const I X[4], const int mu)
bool isNative() const
QudaParity parity
Definition: covdev_test.cpp:53
static __device__ __host__ int linkNormalIndexP1(const int x[], const I X[4], const int mu)
#define tmp0
Definition: tmc_core.h:14
#define CUFFT_SAFE_CALL(call)
Definition: CUFFT_Plans.h:10
unsigned long long bytes
Definition: blas_quda.cu:43
int comm_dim_partitioned(int dim)
const int * X() const
#define device_free(ptr)
Definition: malloc_quda.h:57
static __device__ __host__ void getCoords(int x[], int cb_index, const I X[], int parity)