QUDA  0.9.0
gauge_stout.cu
Go to the documentation of this file.
1 #include <quda_internal.h>
2 #include <quda_matrix.h>
3 #include <su3_project.cuh>
4 #include <tune_quda.h>
5 #include <gauge_field.h>
6 #include <gauge_field_order.h>
7 #include <index_helper.cuh>
8 
9 #define DOUBLE_TOL 1e-15
10 #define SINGLE_TOL 2e-6
11 
12 namespace quda {
13 
14 #ifdef GPU_GAUGE_TOOLS
15 
/**
 * Kernel argument bundle for one (standard) stout smearing step.
 *
 * Links are read through `origin` and written through `dest`. The geometry comes
 * from `data`: border[d] = data.R()[d] and X[d] is the extent of data with that
 * border stripped from both sides. `threads` is half the product of X[], i.e. one
 * thread per site of a single checkerboard parity.
 */
template <typename Float, typename GaugeOr, typename GaugeDs>
struct GaugeSTOUTArg {
  int threads;            // number of active threads required (sites per parity)
  int X[4];               // interior grid dimensions (border removed)
  int border[4];          // border depth per dimension, copied from data.R()
  GaugeOr origin;         // links are read through this
  const Float rho;        // smearing weight
  const Float tolerance;  // caller-supplied tolerance (not read by the kernels in this file)

  GaugeDs dest;           // smeared links are written through this

  GaugeSTOUTArg(GaugeOr &origin, GaugeDs &dest, const GaugeField &data, const Float rho, const Float tolerance)
    : threads(1), origin(origin), rho(rho), tolerance(tolerance), dest(dest)
  {
    const int *depth = data.R();
    const int *extended = data.X();
    int interiorVolume = 1;
    for (int d = 0; d < 4; d++) {
      border[d] = depth[d];
      X[d] = extended[d] - 2 * border[d];
      interiorVolume *= X[d];
    }
    threads = interiorVolume / 2;
  }
};
37 
38 
/**
 * Sum of the spatial 1x1 staples attached to the link U_dir(x).
 *
 * The site x is the checkerboard index `idx` on sublattice `parity` of the interior
 * lattice; its coordinates are shifted by arg.border before neighbours are read.
 * For every spatial direction mu != dir (mu = 0, 1, 2 -- time staples are not
 * included) the forward and backward staples
 *   U_mu(x)       U_dir(x+mu)  U^dag_mu(x+dir)
 *   U^dag_mu(x-mu) U_dir(x-mu) U_mu(x-mu+dir)
 * are added to `staple`, which is zeroed on entry.
 */
template <typename Float, typename GaugeOr, typename GaugeDs, typename Float2>
__host__ __device__ void computeStaple(GaugeSTOUTArg<Float,GaugeOr,GaugeDs>& arg, int idx, int parity, int dir, Matrix<Float2,3> &staple) {
  typedef Matrix<complex<Float>,3> Link;

  // Coordinates of this site, moved into the bordered lattice.
  int dims[4];
  int coord[4];
  for (int d = 0; d < 4; ++d) dims[d] = arg.X[d];
  getCoords(coord, idx, dims, parity);
  for (int d = 0; d < 4; ++d) {
    coord[d] += arg.border[d];
    dims[d] += 2 * arg.border[d];
  }

  const int nu = dir;  // direction of the link being smeared
  setZero(&staple);

  for (int mu = 0; mu < 3; mu++) {
    if (mu == nu) continue;

    int shift[4] = {0, 0, 0, 0};
    Link first, second, third;

    // Forward staple: U_mu(x) * U_nu(x+mu) * U^dag_mu(x+nu)
    first = arg.origin(mu, linkIndexShift(coord, shift, dims), parity);
    shift[mu] = 1;
    second = arg.origin(nu, linkIndexShift(coord, shift, dims), 1 - parity);
    shift[mu] = 0;
    shift[nu] = 1;
    third = arg.origin(mu, linkIndexShift(coord, shift, dims), 1 - parity);
    staple = staple + first * second * conj(third);

    // Backward staple: U^dag_mu(x-mu) * U_nu(x-mu) * U_mu(x-mu+nu)
    shift[mu] = -1;
    shift[nu] = 0;
    first = arg.origin(mu, linkIndexShift(coord, shift, dims), 1 - parity);
    second = arg.origin(nu, linkIndexShift(coord, shift, dims), 1 - parity);
    shift[nu] = 1;
    third = arg.origin(mu, linkIndexShift(coord, shift, dims), parity);
    staple = staple + conj(first) * second * third;
  }
}
100 
/**
 * Kernel: one stout smearing step of the spatial links.
 *
 * Thread layout: x = checkerboard site index (< arg.threads), y = parity,
 * z = link direction. Threads with dir >= 3 return immediately, so only the three
 * spatial directions are processed; temporal links are not written to arg.dest.
 *
 * For each processed link U:
 *   Omega = rho * C * U^dag      (C: spatial staple sum from computeStaple)
 *   Q     = (i/2) [ (Omega^dag - Omega) - (1/3) Tr(Omega^dag - Omega) ]
 *   U'    = exp(iQ) * U          (stored in arg.dest)
 *
 * With HOST_DEBUG defined, per-link diagnostics (trace and hermiticity of Q,
 * unitarity of exp(iQ) and of U') are printed.
 */
template<typename Float, typename GaugeOr, typename GaugeDs>
__global__ void computeSTOUTStep(GaugeSTOUTArg<Float,GaugeOr,GaugeDs> arg){
  const int idx    = blockIdx.x*blockDim.x + threadIdx.x;
  const int parity = blockIdx.y*blockDim.y + threadIdx.y;
  const int dir    = blockIdx.z*blockDim.z + threadIdx.z;
  if (idx >= arg.threads || dir >= 3) return;

  typedef complex<Float> Complex;
  typedef Matrix<complex<Float>,3> Link;

  // Coordinates of this site in the bordered lattice.
  int dims[4];
  int coord[4];
  for (int d = 0; d < 4; ++d) dims[d] = arg.X[d];
  getCoords(coord, idx, dims, parity);
  for (int d = 0; d < 4; ++d) {
    coord[d] += arg.border[d];
    dims[d]  += 2*arg.border[d];
  }

  // Zero displacement: the link being smeared sits at this thread's own site.
  int zeroShift[4] = {0, 0, 0, 0};
  const int linkIdx = linkIndexShift(coord, zeroShift, dims);

  Link staple;
  computeStaple<Float,GaugeOr,GaugeDs,Complex>(arg, idx, parity, dir, staple);

  Link link;
  link = arg.origin(dir, linkIdx, parity);

  // For an SU(3) link the inverse is the hermitian conjugate.
  Link linkDag;
  computeMatrixInverse(link, &linkDag);

  Link omega, omegaDiff, identity, Q, expiQ;
  Complex trace;
  Complex halfI(0,0.5);

  omega = (arg.rho * staple) * linkDag;

  // Q = (i/2) [ (Omega^dag - Omega) - (1/3) Tr(Omega^dag - Omega) * 1 ]
  omegaDiff = conj(omega) - omega;
  trace = getTrace(omegaDiff);
  trace = (1.0/3.0) * trace;
  setIdentity(&identity);
  Q = omegaDiff - trace * identity;
  Q = halfI * Q;

#ifdef HOST_DEBUG
  // Q should be traceless ...
  trace = getTrace(Q);
  double error;
  error = trace.real();
  printf("Trace test %d %d %.15e\n", idx, dir, error);

  // ... and hermitian: probe (Q^dag - Q) through ReTr((Q^dag - Q)^2).
  Link Q_diff = conj(Q);
  Q_diff -= Q;
  Q_diff *= Q_diff;
  trace = getTrace(Q_diff);
  error = trace.real();
  printf("Herm test %d %d %.15e\n", idx, dir, error);
#endif

  exponentiate_iQ(Q, &expiQ);

#ifdef HOST_DEBUG
  error = ErrorSU3(expiQ);
  printf("expiQ test %d %d %.15e\n", idx, dir, error);
#endif

  link = expiQ * link;
#ifdef HOST_DEBUG
  error = ErrorSU3(link);
  printf("expiQ*u test %d %d %.15e\n", idx, dir, error);
#endif

  arg.dest(dir, linkIdx, parity) = link;
}
202 
/**
 * Autotuned launcher for computeSTOUTStep.
 *
 * Launch geometry (TunableVectorYZ(2,3)): x covers arg.threads checkerboard sites,
 * y the 2 parities and z the 3 spatial link directions. The grid dimensions are
 * not tuned (tuneGridDim() returns false).
 */
template<typename Float, typename GaugeOr, typename GaugeDs>
class GaugeSTOUT : TunableVectorYZ {
  GaugeSTOUTArg<Float,GaugeOr,GaugeDs> arg;
  const GaugeField &meta;

private:
  bool tuneGridDim() const { return false; } // Don't tune the grid dimensions.
  unsigned int minThreads() const { return arg.threads; }

public:
  // (2,3) --- 2 for parity in the y thread dim, 3 corresponds to mapping direction to the z thread dim
  GaugeSTOUT(GaugeSTOUTArg<Float,GaugeOr,GaugeDs> &arg, const GaugeField &meta)
    : TunableVectorYZ(2,3), arg(arg), meta(meta) {}
  virtual ~GaugeSTOUT () {}

  /**
   * Launch one smearing step on `stream`.
   * Only device-resident fields are supported; otherwise errorQuda is raised.
   */
  void apply(const cudaStream_t &stream){
    if (meta.Location() == QUDA_CUDA_FIELD_LOCATION) {
      TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
      // Launch on the caller-supplied stream rather than silently using the default one.
      computeSTOUTStep<<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
    } else {
      errorQuda("CPU not supported yet\n");
      //computeSTOUTStepCPU(arg);
    }
  }

  TuneKey tuneKey() const {
    std::stringstream aux;
    aux << "threads=" << arg.threads << ",prec=" << sizeof(Float);
    return TuneKey(meta.VolString(), typeid(*this).name(), aux.str().c_str());
  }

  long long flops() const { return 3*(2+2*4)*198ll*arg.threads; } // just counts matrix multiplication
  long long bytes() const { return 3*((1+2*6)*arg.origin.Bytes()+arg.dest.Bytes())*arg.threads; }
}; // GaugeSTOUT
237 
// Accessor-level STOUT step: builds the kernel arguments for `origin` (links read)
// and `dest` (links written). It uses `dataOr` for the lattice geometry and to pick
// DOUBLE_TOL or SINGLE_TOL from its precision, then runs one autotuned
// GaugeSTOUT launch through apply(0).
// NOTE(review): this listing skips source line 243 (between apply() and the closing
// brace), so its content is not visible here. It is possibly a device
// synchronisation (qudaDeviceSynchronize is referenced by this file); confirm
// against the original.
238  template<typename Float,typename GaugeOr, typename GaugeDs>
239  void STOUTStep(GaugeOr origin, GaugeDs dest, const GaugeField& dataOr, Float rho) {
240  GaugeSTOUTArg<Float,GaugeOr,GaugeDs> arg(origin, dest, dataOr, rho, dataOr.Precision() == QUDA_DOUBLE_PRECISION ? DOUBLE_TOL : SINGLE_TOL);
241  GaugeSTOUT<Float,GaugeOr,GaugeDs> gaugeSTOUT(arg,dataOr);
242  gaugeSTOUT.apply(0);
244  }
245 
/**
 * Runs one STOUT step for the destination accessor type GDs. The origin accessor is
 * chosen from the reconstruction type of dataOr (NO, 12 or 8); any other
 * reconstruction is rejected with errorQuda.
 */
template<typename Float, typename GDs>
void STOUTStepDispatchOrigin(GaugeField &dataDs, const GaugeField& dataOr, Float rho) {
  switch (dataOr.Reconstruct()) {
  case QUDA_RECONSTRUCT_NO: {
    typedef typename gauge_mapper<Float,QUDA_RECONSTRUCT_NO>::type GOr;
    STOUTStep(GOr(dataOr), GDs(dataDs), dataOr, rho);
    break;
  }
  case QUDA_RECONSTRUCT_12: {
    typedef typename gauge_mapper<Float,QUDA_RECONSTRUCT_12>::type GOr;
    STOUTStep(GOr(dataOr), GDs(dataDs), dataOr, rho);
    break;
  }
  case QUDA_RECONSTRUCT_8: {
    typedef typename gauge_mapper<Float,QUDA_RECONSTRUCT_8>::type GOr;
    STOUTStep(GOr(dataOr), GDs(dataDs), dataOr, rho);
    break;
  }
  default:
    errorQuda("Reconstruction type %d of origin gauge field not supported", dataOr.Reconstruct());
  }
}

/**
 * Precision-templated dispatcher. It first selects the destination accessor from the
 * reconstruction type of dataDs (NO, 12 or 8), then selects the origin accessor, and
 * runs one STOUT step. An unsupported destination reconstruction is reported before
 * the origin is inspected.
 */
template<typename Float>
void STOUTStep(GaugeField &dataDs, const GaugeField& dataOr, Float rho) {
  switch (dataDs.Reconstruct()) {
  case QUDA_RECONSTRUCT_NO:
    STOUTStepDispatchOrigin<Float, typename gauge_mapper<Float,QUDA_RECONSTRUCT_NO>::type>(dataDs, dataOr, rho);
    break;
  case QUDA_RECONSTRUCT_12:
    STOUTStepDispatchOrigin<Float, typename gauge_mapper<Float,QUDA_RECONSTRUCT_12>::type>(dataDs, dataOr, rho);
    break;
  case QUDA_RECONSTRUCT_8:
    STOUTStepDispatchOrigin<Float, typename gauge_mapper<Float,QUDA_RECONSTRUCT_8>::type>(dataDs, dataOr, rho);
    break;
  default:
    errorQuda("Reconstruction type %d of destination gauge field not supported", dataDs.Reconstruct());
  }
}
297 
298 #endif
299 
/**
 * Host entry point: one stout smearing step reading dataOr and writing dataDs,
 * with smearing weight rho.
 *
 * Both fields must share the same precision (single or double; half is rejected)
 * and use a native field order. Without GPU_GAUGE_TOOLS it only raises errorQuda.
 */
void STOUTStep(GaugeField &dataDs, const GaugeField& dataOr, double rho) {
#ifdef GPU_GAUGE_TOOLS
  const QudaPrecision precision = dataDs.Precision();

  if (dataOr.Precision() != precision) {
    errorQuda("Origin and destination fields must have the same precision\n");
  }
  if (precision == QUDA_HALF_PRECISION) {
    errorQuda("Half precision not supported\n");
  }
  if (!dataOr.isNative()) {
    errorQuda("Order %d with %d reconstruct not supported", dataOr.Order(), dataOr.Reconstruct());
  }
  if (!dataDs.isNative()) {
    errorQuda("Order %d with %d reconstruct not supported", dataDs.Order(), dataDs.Reconstruct());
  }

  switch (precision) {
  case QUDA_SINGLE_PRECISION:
    STOUTStep<float>(dataDs, dataOr, (float) rho);
    break;
  case QUDA_DOUBLE_PRECISION:
    STOUTStep<double>(dataDs, dataOr, rho);
    break;
  default:
    errorQuda("Precision %d not supported", precision);
  }
#else
  errorQuda("Gauge tools are not build");
#endif
}
330 
331 
332  //------------------------//
333  // Over-Improved routines //
334  //------------------------//
335 
336 
// Kernel argument bundle for one over-improved stout smearing step.
// The constructor copies border[] from data.R() and sets X[] to data.X() minus
// twice the border. It then divides the product of X[] by 2, giving one thread
// per site of a single checkerboard parity.
// NOTE(review): this listing omits source line 338 (presumably the
// `struct GaugeOvrImpSTOUTArg {` header) and line 350 (presumably the
// constructor's member-initializer list, which must set `threads` before the
// `threads *= X[dir]` product below). Confirm against the original file.
337  template <typename Float, typename GaugeOr, typename GaugeDs>
339  int threads; // number of active threads required (sites per parity)
340  int X[4]; // interior grid dimensions (border removed)
341  int border[4]; // border depth per dimension, copied from data.R()
342  GaugeOr origin; // links are read through this
343  const Float rho; // smearing weight
344  const Float epsilon; // over-improvement parameter (enters the staple/rectangle coefficients)
345  const Float tolerance; // caller-supplied tolerance
346 
347  GaugeDs dest; // smeared links are written through this
348 
349  GaugeOvrImpSTOUTArg(GaugeOr &origin, GaugeDs &dest, const GaugeField &data, const Float rho, const Float epsilon, const Float tolerance)
351  for ( int dir = 0; dir < 4; ++dir ) {
352  border[dir] = data.R()[dir];
353  X[dir] = data.X()[dir] - border[dir] * 2;
354  threads *= X[dir];
355  }
356  threads /= 2;
357  }
358  };
359 
360 
/**
 * Staples and rectangles attached to the link U_nu(x), nu = dir, for over-improved
 * stout smearing (arXiv:0801.1165).
 *
 * The site x is the checkerboard index `idx` on sublattice `parity` of the interior
 * lattice; its coordinates are shifted by arg.border before neighbours are read.
 * For every direction mu != dir, including time, this accumulates:
 *   - into `staple`: the forward and backward 1x1 staples (length-3 paths x -> x+nu);
 *   - into `rectangle`: the 1x2 (R12) and the two 2x1 (R21f, R21b) rectangles for
 *     +mu and -mu (length-5 paths x -> x+nu).
 * Both outputs are zeroed on entry.
 *
 * Sublattice convention: a link at x + dx is read from `parity` when the total
 * displacement |dx| is even and from `1-parity` when it is odd.
 *
 * Link reuse: U5/U1 feed the 1x1 staples, U6 = U_nu(x+nu) and U7 = U_nu(x-nu) are
 * shared between +mu and -mu, and U2/U5 are shared between R12/R21f/R21b.
 */
template <typename Float, typename GaugeOr, typename GaugeDs, typename Float2>
__host__ __device__ void computeStapleRectangle(GaugeOvrImpSTOUTArg<Float,GaugeOr,GaugeDs>& arg, int idx, int parity, int dir,
                                                Matrix<Float2,3> &staple, Matrix<Float2,3> &rectangle) {
  typedef Matrix<complex<Float>,3> Link;

  // Coordinates of this site, moved into the bordered lattice.
  int X[4];
  for (int dr=0; dr<4; ++dr) X[dr] = arg.X[dr];

  int x[4];
  getCoords(x, idx, X, parity);
  for (int dr=0; dr<4; ++dr) {
    x[dr] += arg.border[dr];
    X[dr] += 2*arg.border[dr];
  }

  setZero(&staple);
  setZero(&rectangle);

  // Over-improved stout is usually done for topological measurements,
  // so the temporal direction is included.
  for (int mu=0; mu<4; mu++) {
    if (mu == dir) continue;  // only directions orthogonal to the link
    const int nu = dir;

    int dx[4] = {0, 0, 0, 0};
    Link U1, U2, U3, U4, U5, U6, U7;

    //--------//
    // +ve mu //
    //--------//

    // R12 = U_mu(x) U_mu(x+mu) U_nu(x+2mu) U^dag_mu(x+mu+nu) U^dag_mu(x+nu)
    U1 = arg.origin(mu, linkIndexShift(x,dx,X), parity);      // U_mu(x)
    dx[mu]++;
    U2 = arg.origin(mu, linkIndexShift(x,dx,X), 1-parity);    // U_mu(x+mu)
    dx[mu]++;
    U3 = arg.origin(nu, linkIndexShift(x,dx,X), parity);      // U_nu(x+2mu)
    dx[mu]--;
    dx[nu]++;
    U4 = arg.origin(mu, linkIndexShift(x,dx,X), parity);      // U_mu(x+mu+nu)
    dx[mu]--;
    U5 = arg.origin(mu, linkIndexShift(x,dx,X), 1-parity);    // U_mu(x+nu)
    rectangle = rectangle + U1*U2*U3*conj(U4)*conj(U5);

    dx[nu]--;  // back at x

    // R21f = U_mu(x) U_nu(x+mu) U_nu(x+mu+nu) U^dag_mu(x+2nu) U^dag_nu(x+nu)
    // U1 = U_mu(x) is reused from R12.
    dx[mu]++;
    U2 = arg.origin(nu, linkIndexShift(x,dx,X), 1-parity);    // U_nu(x+mu)

    // Forward 1x1 staple U_mu(x) U_nu(x+mu) U^dag_mu(x+nu), reusing U1 and U5.
    staple = staple + U1*U2*conj(U5);

    dx[nu]++;
    U3 = arg.origin(nu, linkIndexShift(x,dx,X), parity);      // U_nu(x+mu+nu)
    dx[mu]--;
    dx[nu]++;
    U4 = arg.origin(mu, linkIndexShift(x,dx,X), parity);      // U_mu(x+2nu)
    dx[nu]--;
    U6 = arg.origin(nu, linkIndexShift(x,dx,X), 1-parity);    // U_nu(x+nu), reused for -ve mu
    rectangle = rectangle + U1 * U2 * U3 * conj(U4) * conj(U6);

    dx[nu]--;  // back at x

    // R21b = U^dag_nu(x-nu) U_mu(x-nu) U_nu(x-nu+mu) U_nu(x+mu) U^dag_mu(x+nu)
    dx[nu]--;
    U7 = arg.origin(nu, linkIndexShift(x,dx,X), 1-parity);    // U_nu(x-nu), reused for -ve mu
    U4 = arg.origin(mu, linkIndexShift(x,dx,X), 1-parity);    // U_mu(x-nu)
    dx[mu]++;
    U3 = arg.origin(nu, linkIndexShift(x,dx,X), parity);      // U_nu(x-nu+mu)
    // U2 = U_nu(x+mu) from R21f, U5 = U_mu(x+nu) from R12.
    rectangle = rectangle + conj(U7) * U4 * U3 * U2 * conj(U5);

    //--------//
    // -ve mu //
    //--------//

    dx[mu]--;
    dx[nu]++;  // back at x

    // R12 = U^dag_mu(x-mu) U^dag_mu(x-2mu) U_nu(x-2mu) U_mu(x-2mu+nu) U_mu(x-mu+nu)
    dx[mu]--;
    U1 = arg.origin(mu, linkIndexShift(x,dx,X), 1-parity);    // U_mu(x-mu)
    dx[mu]--;
    U2 = arg.origin(mu, linkIndexShift(x,dx,X), parity);      // U_mu(x-2mu)
    U3 = arg.origin(nu, linkIndexShift(x,dx,X), parity);      // U_nu(x-2mu)
    dx[nu]++;
    U4 = arg.origin(mu, linkIndexShift(x,dx,X), 1-parity);    // U_mu(x-2mu+nu)
    dx[mu]++;
    U5 = arg.origin(mu, linkIndexShift(x,dx,X), parity);      // U_mu(x-mu+nu)
    rectangle = rectangle + conj(U1) * conj(U2) * U3 * U4 * U5;

    dx[mu]++;
    dx[nu]--;  // back at x

    // R21f = U^dag_mu(x-mu) U_nu(x-mu) U_nu(x-mu+nu) U_mu(x-mu+2nu) U^dag_nu(x+nu)
    // U1 = U_mu(x-mu) is reused from R12.
    dx[mu]--;
    U2 = arg.origin(nu, linkIndexShift(x,dx,X), 1-parity);    // U_nu(x-mu)

    // Backward 1x1 staple U^dag_mu(x-mu) U_nu(x-mu) U_mu(x-mu+nu), reusing U1 and U5.
    staple = staple + conj(U1) * U2 * U5;

    dx[nu]++;
    U3 = arg.origin(nu, linkIndexShift(x,dx,X), parity);      // U_nu(x-mu+nu)
    dx[nu]++;
    U4 = arg.origin(mu, linkIndexShift(x,dx,X), 1-parity);    // U_mu(x-mu+2nu)
    // U6 = U_nu(x+nu) from the +ve R21f.
    rectangle = rectangle + conj(U1) * U2 * U3 * U4 * conj(U6);

    dx[nu]--;
    dx[nu]--;
    dx[mu]++;  // back at x

    // R21b = U^dag_nu(x-nu) U^dag_mu(x-mu-nu) U_nu(x-mu-nu) U_nu(x-mu) U_mu(x-mu+nu)
    // U7 = U_nu(x-nu) is reused from the +ve R21b.
    dx[nu]--;
    dx[mu]--;
    // x-mu-nu is an even displacement from x, so this link is on the same
    // sublattice as x (it must be read from `parity`, not `1-parity`).
    U4 = arg.origin(mu, linkIndexShift(x,dx,X), parity);      // U_mu(x-mu-nu)
    U3 = arg.origin(nu, linkIndexShift(x,dx,X), parity);      // U_nu(x-mu-nu)
    // U2 = U_nu(x-mu) from R21f, U5 = U_mu(x-mu+nu) from R12.
    rectangle = rectangle + conj(U7) * conj(U4) * U3 * U2 * U5;
  }
}
596 
/**
 * Kernel: one over-improved stout smearing step (arXiv:0801.1165).
 *
 * Thread layout: x = checkerboard site index (< arg.threads), y = parity,
 * z = link direction. All four directions, time included, are smeared. Threads
 * with dir >= 4 return, because the tuned z-grid may be rounded up past 4.
 *
 * For each link U:
 *   Omega = rho * [ (5 - 2 eps)/3 * C - (1 - eps)/12 * R ] * U^dag
 *   Q     = (i/2) [ (Omega^dag - Omega) - (1/3) Tr(Omega^dag - Omega) ]
 *   U'    = exp(iQ) * U         (stored in arg.dest)
 * where C is the 1x1 staple sum and R the 1x2/2x1 rectangle sum from
 * computeStapleRectangle.
 *
 * With HOST_DEBUG defined, per-link diagnostics are printed.
 */
template<typename Float, typename GaugeOr, typename GaugeDs>
__global__ void computeOvrImpSTOUTStep(GaugeOvrImpSTOUTArg<Float,GaugeOr,GaugeDs> arg){
  int idx = threadIdx.x + blockIdx.x*blockDim.x;
  int parity = threadIdx.y + blockIdx.y*blockDim.y;
  int dir = threadIdx.z + blockIdx.z*blockDim.z;
  if (idx >= arg.threads) return;
  if (dir >= 4) return;  // z grid can overshoot the 4 link directions
  typedef complex<Float> Complex;
  typedef Matrix<complex<Float>,3> Link;

  // Coordinates of this site in the bordered lattice.
  int X[4];
  for (int dr=0; dr<4; ++dr) X[dr] = arg.X[dr];

  int x[4];
  getCoords(x, idx, X, parity);
  for (int dr=0; dr<4; ++dr) {
    x[dr] += arg.border[dr];
    X[dr] += 2*arg.border[dr];
  }

  // Keep the coefficients in the field precision. Double locals here would force
  // double-precision matrix arithmetic in the single-precision instantiation.
  const Float staple_coeff = (5.0 - 2.0*arg.epsilon)/3.0;
  const Float rectangle_coeff = (1.0 - arg.epsilon)/12.0;

  int dx[4] = {0, 0, 0, 0};
  Link U, UDag, Stap, Rect, Omega, OmegaDiff, ODT, Q, exp_iQ;
  Complex OmegaDiffTr;
  Complex i_2(0,0.5);

  // Stap: 1x1 staples; Rect: 1x2 and 2x1 rectangles.
  computeStapleRectangle<Float,GaugeOr,GaugeDs,Complex>(arg,idx,parity,dir,Stap,Rect);

  U = arg.origin(dir, linkIndexShift(x,dx,X), parity);

  // For an SU(3) link the inverse is U^dagger.
  computeMatrixInverse(U,&UDag);

  // Omega = rho * (staple_coeff * S - rectangle_coeff * R) * U^dagger
  Omega = (arg.rho*staple_coeff)*(Stap);
  Omega = Omega - (arg.rho*rectangle_coeff)*(Rect);
  Omega = Omega * UDag;

  // Q = i/2 [ Omega^dag - Omega - 1/3 Tr(Omega^dag - Omega) ]
  OmegaDiff = conj(Omega) - Omega;

  Q = OmegaDiff;
  OmegaDiffTr = getTrace(OmegaDiff);
  OmegaDiffTr = (1.0/3.0) * OmegaDiffTr;

  setIdentity(&ODT);

  Q = Q - OmegaDiffTr * ODT;
  Q = i_2 * Q;

#ifdef HOST_DEBUG
  // Test for tracelessness (reusing OmegaDiffTr).
  OmegaDiffTr = getTrace(Q);
  double error;
  error = OmegaDiffTr.real();
  printf("Trace test %d %d %.15e\n", idx, dir, error);

  // Test for hermiticity: Q^dag - Q should vanish; probe via ReTr((Q^dag - Q)^2).
  Link Q_diff = conj(Q);
  Q_diff -= Q;
  Q_diff *= Q_diff;
  OmegaDiffTr = getTrace(Q_diff);
  error = OmegaDiffTr.real();
  printf("Herm test %d %d %.15e\n", idx, dir, error);
#endif

  exponentiate_iQ(Q,&exp_iQ);

#ifdef HOST_DEBUG
  // Test for exp(iQ) unitarity.
  error = ErrorSU3(exp_iQ);
  printf("expiQ test %d %d %.15e\n", idx, dir, error);
#endif

  U = exp_iQ * U;
#ifdef HOST_DEBUG
  // Test for exp(iQ)*U unitarity.
  error = ErrorSU3(U);
  printf("expiQ*u test %d %d %.15e\n", idx, dir, error);
#endif

  arg.dest(dir, linkIndexShift(x,dx,X), parity) = U;
}


/**
 * Autotuned launcher for computeOvrImpSTOUTStep.
 *
 * Launch geometry (TunableVectorYZ(2,4)): x covers arg.threads checkerboard sites,
 * y the 2 parities and z all 4 link directions. The grid dimensions are not tuned.
 */
template<typename Float, typename GaugeOr, typename GaugeDs>
class GaugeOvrImpSTOUT : TunableVectorYZ {
  GaugeOvrImpSTOUTArg<Float,GaugeOr,GaugeDs> arg;
  const GaugeField &meta;

private:
  bool tuneGridDim() const { return false; } // Don't tune the grid dimensions.
  unsigned int minThreads() const { return arg.threads; }

public:
  // (2,4) --- 2 for parity in the y thread dim, 4 maps every link direction
  // (including time, which the kernel smears) to the z thread dim. A z extent of 3
  // would make smearing of dir = 3 depend on the tuned block.z.
  GaugeOvrImpSTOUT(GaugeOvrImpSTOUTArg<Float,GaugeOr,GaugeDs> &arg, const GaugeField &meta)
    : TunableVectorYZ(2,4), arg(arg), meta(meta) {}
  virtual ~GaugeOvrImpSTOUT () {}

  /**
   * Launch one over-improved smearing step on `stream`.
   * Only device-resident fields are supported; otherwise errorQuda is raised.
   */
  void apply(const cudaStream_t &stream){
    if (meta.Location() == QUDA_CUDA_FIELD_LOCATION) {
      TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
      computeOvrImpSTOUTStep<<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
    } else {
      errorQuda("CPU not supported yet\n");
      //computeOvrImpSTOUTStepCPU(arg);
    }
  }

  TuneKey tuneKey() const {
    std::stringstream aux;
    aux << "threads=" << arg.threads << ",prec=" << sizeof(Float);
    return TuneKey(meta.VolString(), typeid(*this).name(), aux.str().c_str());
  }

  long long flops() const { return 4*(18+2+2*4)*198ll*arg.threads; } // just counts matrix multiplication
  long long bytes() const { return 4*((1+2*12)*arg.origin.Bytes()+arg.dest.Bytes())*arg.threads; }
}; // GaugeOvrImpSTOUT
737 
738 
// Accessor-level over-improved STOUT step. It builds the kernel arguments for
// `origin` (links read) and `dest` (links written), using `dataOr` for the lattice
// geometry, then runs one autotuned GaugeOvrImpSTOUT launch through apply(0).
// NOTE(review): this listing omits source line 742 and line 745.
//  - Line 742 holds the remaining constructor argument(s) after `epsilon,`,
//    presumably the precision-dependent tolerance as in STOUTStep.
//  - Line 745 sits between apply() and the closing brace.
// Confirm both against the original file.
739  template<typename Float,typename GaugeOr, typename GaugeDs>
740  void OvrImpSTOUTStep(GaugeOr origin, GaugeDs dest, const GaugeField& dataOr, Float rho, Float epsilon) {
741  GaugeOvrImpSTOUTArg<Float,GaugeOr,GaugeDs> arg(origin, dest, dataOr, rho, epsilon,
743  GaugeOvrImpSTOUT<Float,GaugeOr,GaugeDs> gaugeOvrImpSTOUT(arg,dataOr);
744  gaugeOvrImpSTOUT.apply(0);
746  }
747 
// Precision-templated dispatcher for the over-improved STOUT step. It branches on
// the destination reconstruction type (NO, 12, 8), then on the origin
// reconstruction type (NO, 12, 8), and calls the accessor-level OvrImpSTOUTStep.
// Any other reconstruction is reported with errorQuda (destination checked first).
// NOTE(review): the listing omits the lines that define GDs/GOr in each branch
// (presumably gauge_mapper typedefs as in STOUTStep<Float>). Confirm against the
// original file.
748  template<typename Float>
749  void OvrImpSTOUTStep(GaugeField &dataDs, const GaugeField& dataOr, Float rho, Float epsilon) {
750 
751  if(dataDs.Reconstruct() == QUDA_RECONSTRUCT_NO) {
753 
754  if(dataOr.Reconstruct() == QUDA_RECONSTRUCT_NO) {
756  OvrImpSTOUTStep(GOr(dataOr), GDs(dataDs), dataOr, rho, epsilon);
757  }else if(dataOr.Reconstruct() == QUDA_RECONSTRUCT_12){
759  OvrImpSTOUTStep(GOr(dataOr), GDs(dataDs), dataOr, rho, epsilon);
760  }else if(dataOr.Reconstruct() == QUDA_RECONSTRUCT_8){
762  OvrImpSTOUTStep(GOr(dataOr), GDs(dataDs), dataOr, rho, epsilon);
763  }else{
764  errorQuda("Reconstruction type %d of origin gauge field not supported", dataOr.Reconstruct());
765  }
766  } else if(dataDs.Reconstruct() == QUDA_RECONSTRUCT_12){
768  if(dataOr.Reconstruct() == QUDA_RECONSTRUCT_NO){
770  OvrImpSTOUTStep(GOr(dataOr), GDs(dataDs), dataOr, rho, epsilon);
771  }else if(dataOr.Reconstruct() == QUDA_RECONSTRUCT_12){
773  OvrImpSTOUTStep(GOr(dataOr), GDs(dataDs), dataOr, rho, epsilon);
774  }else if(dataOr.Reconstruct() == QUDA_RECONSTRUCT_8){
776  OvrImpSTOUTStep(GOr(dataOr), GDs(dataDs), dataOr, rho, epsilon);
777  }else{
778  errorQuda("Reconstruction type %d of origin gauge field not supported", dataOr.Reconstruct());
779  }
780  } else if(dataDs.Reconstruct() == QUDA_RECONSTRUCT_8){
782  if(dataOr.Reconstruct() == QUDA_RECONSTRUCT_NO){
784  OvrImpSTOUTStep(GOr(dataOr), GDs(dataDs), dataOr, rho, epsilon);
785  }else if(dataOr.Reconstruct() == QUDA_RECONSTRUCT_12){
787  OvrImpSTOUTStep(GOr(dataOr), GDs(dataDs), dataOr, rho, epsilon);
788  }else if(dataOr.Reconstruct() == QUDA_RECONSTRUCT_8){
790  OvrImpSTOUTStep(GOr(dataOr), GDs(dataDs), dataOr, rho, epsilon);
791  }else{
792  errorQuda("Reconstruction type %d of origin gauge field not supported", dataOr.Reconstruct());
793  }
794  } else {
795  errorQuda("Reconstruction type %d of destination gauge field not supported", dataDs.Reconstruct());
796  }
797 
798  }
799 
800 
/**
 * Host entry point: one over-improved stout smearing step reading dataOr and
 * writing dataDs, with smearing weight rho and over-improvement parameter epsilon.
 *
 * Both fields must share the same precision (single or double; half is rejected)
 * and use a native field order. Without GPU_GAUGE_TOOLS it only raises errorQuda.
 */
void OvrImpSTOUTStep(GaugeField &dataDs, const GaugeField& dataOr, double rho, double epsilon) {
#ifdef GPU_GAUGE_TOOLS
  const QudaPrecision precision = dataDs.Precision();

  if (dataOr.Precision() != precision) {
    errorQuda("Origin and destination fields must have the same precision\n");
  }
  if (precision == QUDA_HALF_PRECISION) {
    errorQuda("Half precision not supported\n");
  }
  if (!dataOr.isNative()) {
    errorQuda("Order %d with %d reconstruct not supported", dataOr.Order(), dataOr.Reconstruct());
  }
  if (!dataDs.isNative()) {
    errorQuda("Order %d with %d reconstruct not supported", dataDs.Order(), dataDs.Reconstruct());
  }

  switch (precision) {
  case QUDA_SINGLE_PRECISION:
    // Both smearing parameters are narrowed to the field precision.
    OvrImpSTOUTStep<float>(dataDs, dataOr, (float) rho, (float) epsilon);
    break;
  case QUDA_DOUBLE_PRECISION:
    OvrImpSTOUTStep<double>(dataDs, dataOr, rho, epsilon);
    break;
  default:
    errorQuda("Precision %d not supported", precision);
  }
#else
  errorQuda("Gauge tools are not build");
#endif
}
831 }
dim3 dim3 blockDim
__device__ __host__ double ErrorSU3(const Matrix< Cmplx, 3 > &matrix)
Definition: quda_matrix.h:1083
double mu
Definition: test_util.cpp:1643
__device__ __host__ void setZero(Matrix< T, N > *m)
Definition: quda_matrix.h:592
static __device__ __host__ int linkIndexShift(const I x[], const J dx[], const K X[4])
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:20
#define errorQuda(...)
Definition: util_quda.h:90
void STOUTStep(GaugeField &dataDs, const GaugeField &dataOr, double rho)
Definition: gauge_stout.cu:300
std::complex< double > Complex
Definition: eig_variables.h:13
void apply(const cudaStream_t &stream)
Definition: gauge_stout.cu:718
cudaStream_t * stream
const char * VolString() const
TuneKey tuneKey() const
Definition: gauge_stout.cu:728
const int * R() const
#define SINGLE_TOL
Definition: gauge_stout.cu:10
int printf(const char *,...) __attribute__((__format__(__printf__
GaugeOvrImpSTOUTArg< Float, GaugeOr, GaugeDs > arg
Definition: gauge_stout.cu:705
GaugeOvrImpSTOUT(GaugeOvrImpSTOUTArg< Float, GaugeOr, GaugeDs > &arg, const GaugeField &meta)
Definition: gauge_stout.cu:714
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
Definition: tune.cpp:603
const GaugeField & meta
Definition: gauge_stout.cu:706
long long bytes() const
Definition: gauge_stout.cu:735
long long flops() const
Definition: gauge_stout.cu:734
Main header file for host and device accessors to GaugeFields.
void OvrImpSTOUTStep(GaugeField &dataDs, const GaugeField &dataOr, double rho, double epsilon)
Definition: gauge_stout.cu:801
cudaError_t qudaDeviceSynchronize()
Wrapper around cudaDeviceSynchronize or cuDeviceSynchronize.
__host__ __device__ void computeStapleRectangle(GaugeOvrImpSTOUTArg< Float, GaugeOr, GaugeDs > &arg, int idx, int parity, int dir, Matrix< Float2, 3 > &staple, Matrix< Float2, 3 > &rectangle)
Definition: gauge_stout.cu:362
bool tuneGridDim() const
Definition: gauge_stout.cu:709
__device__ __host__ void setIdentity(Matrix< T, N > *m)
Definition: quda_matrix.h:543
QudaFieldLocation Location() const
__device__ __host__ T getTrace(const Matrix< T, 3 > &a)
Definition: quda_matrix.h:305
#define DOUBLE_TOL
Definition: gauge_stout.cu:9
unsigned long long flops
Definition: blas_quda.cu:42
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
Definition: complex_quda.h:880
__device__ __host__ void computeMatrixInverse(const Matrix< T, 3 > &u, Matrix< T, 3 > *uinv)
Definition: quda_matrix.h:501
__global__ void computeOvrImpSTOUTStep(GaugeOvrImpSTOUTArg< Float, GaugeOr, GaugeDs > arg)
Definition: gauge_stout.cu:598
GaugeOvrImpSTOUTArg(GaugeOr &origin, GaugeDs &dest, const GaugeField &data, const Float rho, const Float epsilon, const Float tolerance)
Definition: gauge_stout.cu:349
QudaReconstructType Reconstruct() const
Definition: gauge_field.h:203
QudaGaugeFieldOrder Order() const
Definition: gauge_field.h:204
__host__ __device__ ValueType conj(ValueType x)
Definition: complex_quda.h:115
QudaTune getTuning()
Query whether autotuning is enabled or not. Default is enabled but can be overridden by setting QUDA_...
Definition: util_quda.cpp:51
QudaPrecision Precision() const
bool isNative() const
QudaParity parity
Definition: covdev_test.cpp:53
__device__ __host__ void exponentiate_iQ(const Matrix< T, 3 > &Q, Matrix< T, 3 > *exp_iQ)
Definition: quda_matrix.h:1110
char aux[TuneKey::aux_n]
Definition: tune_quda.h:189
unsigned long long bytes
Definition: blas_quda.cu:43
const int * X() const
static __device__ __host__ void getCoords(int x[], int cb_index, const I X[], int parity)
unsigned int minThreads() const
Definition: gauge_stout.cu:710