QUDA  1.0.0
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
gauge_stout.cuh
Go to the documentation of this file.
1 #include <gauge_field_order.h>
2 #include <index_helper.cuh>
3 #include <quda_matrix.h>
4 #include <su3_project.cuh>
5 
6 namespace quda
7 {
8 
9  template <typename Float, typename GaugeOr, typename GaugeDs> struct GaugeSTOUTArg {
10  int threads; // number of active threads required
11  int X[4]; // grid dimensions
12  int border[4];
13  GaugeOr origin;
14  const Float rho;
15  const Float tolerance;
16 
17  GaugeDs dest;
18 
19  GaugeSTOUTArg(GaugeOr &origin, GaugeDs &dest, const GaugeField &data, const Float rho, const Float tolerance) :
20  threads(1),
21  origin(origin),
22  dest(dest),
23  rho(rho),
24  tolerance(tolerance)
25  {
26  for (int dir = 0; dir < 4; ++dir) {
27  border[dir] = data.R()[dir];
28  X[dir] = data.X()[dir] - border[dir] * 2;
29  threads *= X[dir];
30  }
31  threads /= 2;
32  }
33  };
34 
35  template <typename Float, typename Arg, typename Link>
36  __host__ __device__ void computeStaple(Arg &arg, int idx, int parity, int dir, Link &staple)
37  {
38 
39  // compute spacetime dimensions and parity
40  int X[4];
41  for (int dr = 0; dr < 4; ++dr) X[dr] = arg.X[dr];
42 
43  int x[4];
44  getCoords(x, idx, X, parity);
45  for (int dr = 0; dr < 4; ++dr) {
46  x[dr] += arg.border[dr];
47  X[dr] += 2 * arg.border[dr];
48  }
49 
50  setZero(&staple);
51 
52  // I believe most users won't want to include time staples in smearing
53  for (int mu = 0; mu < 3; mu++) {
54 
55  // identify directions orthogonal to the link.
56  if (mu != dir) {
57 
58  int nu = dir;
59  {
60  int dx[4] = {0, 0, 0, 0};
61  Link U1, U2, U3;
62 
63  // Get link U_{\mu}(x)
64  U1 = arg.origin(mu, linkIndexShift(x, dx, X), parity);
65 
66  dx[mu]++;
67  // Get link U_{\nu}(x+\mu)
68  U2 = arg.origin(nu, linkIndexShift(x, dx, X), 1 - parity);
69 
70  dx[mu]--;
71  dx[nu]++;
72  // Get link U_{\mu}(x+\nu)
73  U3 = arg.origin(mu, linkIndexShift(x, dx, X), 1 - parity);
74 
75  // staple += U_{\mu}(x) * U_{\nu}(x+\mu) * U^\dag_{\mu}(x+\nu)
76  staple = staple + U1 * U2 * conj(U3);
77 
78  dx[mu]--;
79  dx[nu]--;
80  // Get link U_{\mu}(x-\mu)
81  U1 = arg.origin(mu, linkIndexShift(x, dx, X), 1 - parity);
82  // Get link U_{\nu}(x-\mu)
83  U2 = arg.origin(nu, linkIndexShift(x, dx, X), 1 - parity);
84 
85  dx[nu]++;
86  // Get link U_{\mu}(x-\mu+\nu)
87  U3 = arg.origin(mu, linkIndexShift(x, dx, X), parity);
88 
89  // staple += U^\dag_{\mu}(x-\mu) * U_{\nu}(x-\mu) * U_{\mu}(x-\mu+\nu)
90  staple = staple + conj(U1) * U2 * U3;
91  }
92  }
93  }
94  }
95 
96  template <typename Float, typename Arg> __global__ void computeSTOUTStep(Arg arg)
97  {
98 
99  int idx = threadIdx.x + blockIdx.x * blockDim.x;
100  int parity = threadIdx.y + blockIdx.y * blockDim.y;
101  int dir = threadIdx.z + blockIdx.z * blockDim.z;
102  if (idx >= arg.threads) return;
103  if (dir >= 3) return;
104  typedef complex<Float> Complex;
105  typedef Matrix<complex<Float>, 3> Link;
106 
107  int X[4];
108  for (int dr = 0; dr < 4; ++dr) X[dr] = arg.X[dr];
109 
110  int x[4];
111  getCoords(x, idx, X, parity);
112  for (int dr = 0; dr < 4; ++dr) {
113  x[dr] += arg.border[dr];
114  X[dr] += 2 * arg.border[dr];
115  }
116 
117  int dx[4] = {0, 0, 0, 0};
118  // Only spatial dimensions are smeared
119  {
120  Link U, UDag, Stap, Omega, OmegaDiff, ODT, Q, exp_iQ;
121  Complex OmegaDiffTr;
122  Complex i_2(0, 0.5);
123 
124  // This function gets stap = S_{mu,nu} i.e., the staple of length 3,
125  computeStaple<Float>(arg, idx, parity, dir, Stap);
126  //
127  // |- > -| /- > -/ /- > -
128  // ^ v ^ v ^
129  // | | / / /- < -
130  // + | | + + / / + + - > -/
131  // v ^ v ^ v
132  // |- > -| /- > -/ - < -/
133 
134  // Get link U
135  U = arg.origin(dir, linkIndexShift(x, dx, X), parity);
136 
137  // Compute Omega_{mu}=[Sum_{mu neq nu}rho_{mu,nu}C_{mu,nu}]*U_{mu}^dag
138 
139  // Get U^{\dagger}
140  UDag = inverse(U);
141 
142  // Compute \Omega = \rho * S * U^{\dagger}
143  Omega = (arg.rho * Stap) * UDag;
144 
145  // Compute \Q_{mu} = i/2[Omega_{mu}^dag - Omega_{mu}
146  // - 1/3 Tr(Omega_{mu}^dag - Omega_{mu})]
147 
148  OmegaDiff = conj(Omega) - Omega;
149 
150  Q = OmegaDiff;
151  OmegaDiffTr = getTrace(OmegaDiff);
152  OmegaDiffTr = (1.0 / 3.0) * OmegaDiffTr;
153 
154  // Matrix proportional to OmegaDiffTr
155  setIdentity(&ODT);
156 
157  Q = Q - OmegaDiffTr * ODT;
158  Q = i_2 * Q;
159  // Q is now defined.
160 
161 #if 0
162  //Test for Tracless:
163  //reuse OmegaDiffTr
164  OmegaDiffTr = getTrace(Q);
165  double error;
166  error = OmegaDiffTr.real();
167  printf("Trace test %d %d %.15e\n", idx, dir, error);
168 
169  //Test for hemiticity:
170  Link Q_diff = conj(Q);
171  Q_diff -= Q; //This should be the zero matrix. Test by ReTr(Q_diff^2);
172  Q_diff *= Q_diff;
173  //reuse OmegaDiffTr
174  OmegaDiffTr = getTrace(Q_diff);
175  error = OmegaDiffTr.real();
176  printf("Herm test %d %d %.15e\n", idx, dir, error);
177 #endif
178 
179  exponentiate_iQ(Q, &exp_iQ);
180 
181 #if 0
182  //Test for expiQ unitarity:
183  error = ErrorSU3(exp_iQ);
184  printf("expiQ test %d %d %.15e\n", idx, dir, error);
185 #endif
186 
187  U = exp_iQ * U;
188 #if 0
189  //Test for expiQ*U unitarity:
190  error = ErrorSU3(U);
191  printf("expiQ*u test %d %d %.15e\n", idx, dir, error);
192 #endif
193 
194  arg.dest(dir, linkIndexShift(x, dx, X), parity) = U;
195  }
196  }
197 
198  //------------------------//
199  // Over-Improved routines //
200  //------------------------//
201 
202  template <typename Float, typename GaugeOr, typename GaugeDs> struct GaugeOvrImpSTOUTArg {
203  int threads; // number of active threads required
204  int X[4]; // grid dimensions
205  int border[4];
206  GaugeOr origin;
207  const Float rho;
208  const Float epsilon;
209  const Float tolerance;
210 
211  GaugeDs dest;
212 
213  GaugeOvrImpSTOUTArg(GaugeOr &origin, GaugeDs &dest, const GaugeField &data, const Float rho, const Float epsilon,
214  const Float tolerance) :
215  threads(1),
216  origin(origin),
217  dest(dest),
218  rho(rho),
219  epsilon(epsilon),
220  tolerance(tolerance)
221  {
222  for (int dir = 0; dir < 4; ++dir) {
223  border[dir] = data.R()[dir];
224  X[dir] = data.X()[dir] - border[dir] * 2;
225  threads *= X[dir];
226  }
227  threads /= 2;
228  }
229  };
230 
231  template <typename Float, typename Arg, typename Link>
232  __host__ __device__ void computeStapleRectangle(Arg &arg, int idx, int parity, int dir, Link &staple, Link &rectangle)
233  {
234 
235  // compute spacetime dimensions and parity
236  int X[4];
237  for (int dr = 0; dr < 4; ++dr) X[dr] = arg.X[dr];
238 
239  int x[4];
240  getCoords(x, idx, X, parity);
241  for (int dr = 0; dr < 4; ++dr) {
242  x[dr] += arg.border[dr];
243  X[dr] += 2 * arg.border[dr];
244  }
245 
246  setZero(&staple);
247  setZero(&rectangle);
248 
249  // Over-Improved stout is usually done for topological
250  // measuremnts, so we include the temporal direction.
251  for (int mu = 0; mu < 4; mu++) {
252 
253  // identify directions orthogonal to the link.
254  if (mu != dir) {
255 
256  int nu = dir;
257 
258  // RECTANGLE calculation
259  // This is done in three parts. For some link U_nu(x) there are
260  // 1x2 rectangles (R12) and two sets of 2x1 rectangles, defined as
261  // 'forward' (R21f) and 'backward' (R21b).
262 
263  // STAPLE calculation
264  // This is done part way through the computation of (R21f) as the
265  // First two links of the staple are already in memory.
266 
267  // Memory usage and communications.
268  // There are 10 unique links to be fetched per direction. 3 of these
269  // links (the ones that form the simple staple) can be recycled on
270  // the fly. The two links immediately succeeding and preceding
271  // U_nu(x) in the nu directon are also reused when changing from
272  // +ve to -ve mu.
273 
274  {
275  int dx[4] = {0, 0, 0, 0};
276  Link U1, U2, U3, U4, U5, U6, U7;
277 
278  //--------//
279  // +ve mu //
280  //--------//
281 
282  //----------------------------------------------------------------
283  // R12 = U_mu(x)*U_mu(x+mu)*U_nu(x+2mu)*U^d_mu(x+nu+mu)*U^d_mu(x+nu)
284  // Get link U_mu(x)
285  U1 = arg.origin(mu, linkIndexShift(x, dx, X), parity);
286 
287  dx[mu]++;
288  // Get link U_mu(x+mu)
289  U2 = arg.origin(mu, linkIndexShift(x, dx, X), 1 - parity);
290 
291  dx[mu]++;
292  // Get link U_nu(x+2mu)
293  U3 = arg.origin(nu, linkIndexShift(x, dx, X), parity);
294 
295  dx[mu]--;
296  dx[nu]++;
297  // Get link U_mu(x+nu+mu)
298  U4 = arg.origin(mu, linkIndexShift(x, dx, X), parity);
299 
300  dx[mu]--;
301  // Get link U_mu(x+nu)
302  U5 = arg.origin(mu, linkIndexShift(x, dx, X), 1 - parity);
303 
304  rectangle = rectangle + U1 * U2 * U3 * conj(U4) * conj(U5);
305  //---------------------------------------------------------------
306 
307  // reset dx
308  dx[nu]--;
309  //---------------------------------------------------------------
310  // R21f=U_mu(x)*U_nu(x+mu)*U_nu(x+nu+mu)*U^d_mu(x+2nu)*U^d_nu(x+nu)
311  // Get link U_mu(x)
312  // Same as U1 from R12
313 
314  dx[mu]++;
315  // Get link U_nu(x+mu)
316  U2 = arg.origin(nu, linkIndexShift(x, dx, X), 1 - parity);
317 
319  // Here we get the third link in the staple and compute.
320  // Get U_mu(x+nu)
321  // Same as U5 from R12
322  staple = staple + U1 * U2 * conj(U5);
324 
325  dx[nu]++;
326  // Get link U_nu(x+nu+mu)
327  U3 = arg.origin(nu, linkIndexShift(x, dx, X), parity);
328 
329  dx[mu]--;
330  dx[nu]++;
331  // Get link U_mu(x+2nu)
332  U4 = arg.origin(mu, linkIndexShift(x, dx, X), parity);
333 
334  dx[nu]--;
335  // Get link U_nu(x+nu)
336  U6 = arg.origin(nu, linkIndexShift(x, dx, X), 1 - parity);
337 
338  rectangle = rectangle + U1 * U2 * U3 * conj(U4) * conj(U6);
339  //---------------------------------------------------------------
340 
341  // reset dx
342  dx[nu]--;
343  //---------------------------------------------------------------
344  // R21b=U^d_nu(x-nu)*U_mu(x-nu)*U_nu(x+nu+mu)*U^d_mu(x+2nu)*U^dag_nu(x+nu)
345 
346  // Get link U_nu(x-nu)
347  dx[nu]--;
348  U7 = arg.origin(nu, linkIndexShift(x, dx, X), 1 - parity);
349 
350  // Get link U_mu(x-nu)
351  U4 = arg.origin(mu, linkIndexShift(x, dx, X), 1 - parity);
352 
353  // Get link U_nu(x-nu+mu)
354  dx[mu]++;
355  U3 = arg.origin(nu, linkIndexShift(x, dx, X), parity);
356 
357  // Get link U_nu(x+mu)
358  // Same as U2 from R21f
359 
360  // Get link U_mu(x+nu)
361  // Same as U5 from R12
362 
363  rectangle = rectangle + conj(U7) * U4 * U3 * U2 * conj(U5);
364  //---------------------------------------------------------------
365 
366  //--------//
367  // -ve mu //
368  //--------//
369 
370  // reset dx
371  dx[mu]--;
372  dx[nu]++;
373  //---------------------------------------------------------------
374  // R12 = U^dag_mu(x-mu) * U^dag_mu(x-2mu) * U_nu(x-2mu) * U_mu(x-2mu+nu) * U_mu(x-mu+nu)
375 
376  dx[mu]--;
377  // Get link U_mu(x-mu)
378  U1 = arg.origin(mu, linkIndexShift(x, dx, X), 1 - parity);
379 
380  dx[mu]--;
381  // Get link U_mu(x-2mu)
382  U2 = arg.origin(mu, linkIndexShift(x, dx, X), parity);
383 
384  // Get link U_nu(x-2mu)
385  U3 = arg.origin(nu, linkIndexShift(x, dx, X), parity);
386 
387  dx[nu]++;
388  // Get link U_mu(x-2mu+nu)
389  U4 = arg.origin(mu, linkIndexShift(x, dx, X), 1 - parity);
390 
391  dx[mu]++;
392  // Get link U_mu(x-mu+nu)
393  U5 = arg.origin(mu, linkIndexShift(x, dx, X), parity);
394 
395  rectangle = rectangle + conj(U1) * conj(U2) * U3 * U4 * U5;
396  //---------------------------------------------------------------
397 
398  // reset dx
399  dx[mu]++;
400  dx[nu]--;
401  //---------------------------------------------------------------
402  // R21f = U^dag_mu(x-mu) * U_nu(x-mu) * U_nu(x-mu+nu) * U_mu(x-mu+2nu) * U^dag_nu(x+nu)
403 
404  // Get link U_mu(x-mu)
405  // Same as U1 from R12
406 
407  dx[mu]--;
408  // Get link U_nu(x-mu)
409  U2 = arg.origin(nu, linkIndexShift(x, dx, X), 1 - parity);
410 
412  // Here we get the third link in the staple and compute.
413  // Get U_mu(x-mu+nu)
414  // Same as U5 from R12
415  staple = staple + conj(U1) * U2 * U5;
417 
418  dx[nu]++;
419  // Get link U_nu(x-mu+nu)
420  U3 = arg.origin(nu, linkIndexShift(x, dx, X), parity);
421 
422  dx[nu]++;
423  // Get link U_mu(x-mu+2nu)
424  U4 = arg.origin(mu, linkIndexShift(x, dx, X), 1 - parity);
425 
426  // Get link U_nu(x+nu)
427  // Same as U6 from +ve R21f
428 
429  rectangle = rectangle + conj(U1) * U2 * U3 * U4 * conj(U6);
430  //---------------------------------------------------------------
431 
432  // reset dx
433  dx[nu]--;
434  dx[nu]--;
435  dx[mu]++;
436  //---------------------------------------------------------------
437  // R21b= U^dag_nu(x-nu) * U^dag_mu(x-mu-nu) * U_nu(x-mu-nu) * U_nu(x-mu) * U_mu(x-mu+nu)
438 
439  // Get link U_nu(x-nu)
440  // Same as U7 from +ve R21b
441 
442  // Get link U_mu(x-mu-nu)
443  dx[nu]--;
444  dx[mu]--;
445  U4 = arg.origin(mu, linkIndexShift(x, dx, X), 1 - parity);
446 
447  // Get link U_nu(x-nu-mu)
448  U3 = arg.origin(nu, linkIndexShift(x, dx, X), parity);
449 
450  // Get link U_nu(x-mu)
451  // Same as U2 from R21f
452 
453  // Get link U_mu(x-mu+nu)
454  // Same as U5 from R12
455 
456  rectangle = rectangle + conj(U7) * conj(U4) * U3 * U2 * U5;
457  //---------------------------------------------------------------
458  }
459  }
460  }
461  }
462 
463  template <typename Float, typename Arg> __global__ void computeOvrImpSTOUTStep(Arg arg)
464  {
465 
466  int idx = threadIdx.x + blockIdx.x * blockDim.x;
467  int parity = threadIdx.y + blockIdx.y * blockDim.y;
468  int dir = threadIdx.z + blockIdx.z * blockDim.z;
469  if (idx >= arg.threads) return;
470  // if (dir >= 3) return;
471  typedef complex<Float> Complex;
472  typedef Matrix<complex<Float>, 3> Link;
473 
474  int X[4];
475  for (int dr = 0; dr < 4; ++dr) X[dr] = arg.X[dr];
476 
477  int x[4];
478  getCoords(x, idx, X, parity);
479  for (int dr = 0; dr < 4; ++dr) {
480  x[dr] += arg.border[dr];
481  X[dr] += 2 * arg.border[dr];
482  }
483 
484  double staple_coeff = (5.0 - 2.0 * arg.epsilon) / 3.0;
485  double rectangle_coeff = (1.0 - arg.epsilon) / 12.0;
486 
487  int dx[4] = {0, 0, 0, 0};
488  // All dimensions are smeared
489  {
490  Link U, UDag, Stap, Rect, Omega, OmegaDiff, ODT, Q, exp_iQ;
491  Complex OmegaDiffTr;
492  Complex i_2(0, 0.5);
493 
494  // This function gets stap = S_{mu,nu} i.e., the staple of length 3,
495  // and the 1x2 and 2x1 rectangles of length 5. From the following paper:
496  // https://arxiv.org/abs/0801.1165
497  computeStapleRectangle<Float>(arg, idx, parity, dir, Stap, Rect);
498 
499  // Get link U
500  U = arg.origin(dir, linkIndexShift(x, dx, X), parity);
501 
502  // Compute Omega_{mu}=[Sum_{mu neq nu}rho_{mu,nu}C_{mu,nu}]*U_{mu}^dag
503  //-------------------------------------------------------------------
504 
505  // Get U^{\dagger}
506  UDag = inverse(U);
507 
508  // Compute \rho * staple_coeff * S
509  Omega = (arg.rho * staple_coeff) * (Stap);
510 
511  // Compute \rho * rectangle_coeff * R
512  Omega = Omega - (arg.rho * rectangle_coeff) * (Rect);
513  Omega = Omega * UDag;
514 
515  // Compute \Q_{mu} = i/2[Omega_{mu}^dag - Omega_{mu}
516  // - 1/3 Tr(Omega_{mu}^dag - Omega_{mu})]
517 
518  OmegaDiff = conj(Omega) - Omega;
519 
520  Q = OmegaDiff;
521  OmegaDiffTr = getTrace(OmegaDiff);
522  OmegaDiffTr = (1.0 / 3.0) * OmegaDiffTr;
523 
524  // Matrix proportional to OmegaDiffTr
525  setIdentity(&ODT);
526 
527  Q = Q - OmegaDiffTr * ODT;
528  Q = i_2 * Q;
529  // Q is now defined.
530 
531 #if 0
532  //Test for Tracless:
533  //reuse OmegaDiffTr
534  OmegaDiffTr = getTrace(Q);
535  double error;
536  error = OmegaDiffTr.real();
537  printf("Trace test %d %d %.15e\n", idx, dir, error);
538 
539  //Test for hemiticity:
540  Link Q_diff = conj(Q);
541  Q_diff -= Q; //This should be the zero matrix. Test by ReTr(Q_diff^2);
542  Q_diff *= Q_diff;
543  //reuse OmegaDiffTr
544  OmegaDiffTr = getTrace(Q_diff);
545  error = OmegaDiffTr.real();
546  printf("Herm test %d %d %.15e\n", idx, dir, error);
547 #endif
548 
549  exponentiate_iQ(Q, &exp_iQ);
550 
551 #if 0
552  //Test for expiQ unitarity:
553  error = ErrorSU3(exp_iQ);
554  printf("expiQ test %d %d %.15e\n", idx, dir, error);
555 #endif
556 
557  U = exp_iQ * U;
558 #if 0
559  //Test for expiQ*U unitarity:
560  error = ErrorSU3(U);
561  printf("expiQ*u test %d %d %.15e\n", idx, dir, error);
562 #endif
563 
564  arg.dest(dir, linkIndexShift(x, dx, X), parity) = U;
565  }
566  }
567 
568 } // namespace quda
__device__ __host__ double ErrorSU3(const Matrix< Cmplx, 3 > &matrix)
Definition: quda_matrix.h:1164
double mu
Definition: test_util.cpp:1648
__device__ __host__ void setZero(Matrix< T, N > *m)
Definition: quda_matrix.h:702
static __device__ __host__ int linkIndexShift(const I x[], const J dx[], const K X[4])
__global__ void computeSTOUTStep(Arg arg)
Definition: gauge_stout.cuh:96
const Float tolerance
Definition: gauge_stout.cuh:15
__global__ void computeOvrImpSTOUTStep(Arg arg)
__host__ __device__ void computeStapleRectangle(Arg &arg, int idx, int parity, int dir, Link &staple, Link &rectangle)
const int * R() const
__host__ __device__ void computeStaple(Arg &arg, int idx, int parity, int dir, Link &staple)
Definition: gauge_ape.cuh:36
Main header file for host and device accessors to GaugeFields.
std::complex< double > Complex
Definition: quda_internal.h:46
__device__ __host__ void setIdentity(Matrix< T, N > *m)
Definition: quda_matrix.h:653
GaugeSTOUTArg(GaugeOr &origin, GaugeDs &dest, const GaugeField &data, const Float rho, const Float tolerance)
Definition: gauge_stout.cuh:19
__device__ __host__ T getTrace(const Matrix< T, 3 > &a)
Definition: quda_matrix.h:415
__device__ __host__ Matrix< T, 3 > inverse(const Matrix< T, 3 > &u)
Definition: quda_matrix.h:611
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
GaugeOvrImpSTOUTArg(GaugeOr &origin, GaugeDs &dest, const GaugeField &data, const Float rho, const Float epsilon, const Float tolerance)
__host__ __device__ ValueType conj(ValueType x)
Definition: complex_quda.h:130
QudaParity parity
Definition: covdev_test.cpp:54
__device__ __host__ void exponentiate_iQ(const Matrix< T, 3 > &Q, Matrix< T, 3 > *exp_iQ)
Definition: quda_matrix.h:1191
__host__ __device__ int getCoords(int coord[], const Arg &arg, int &idx, int parity, int &dim)
Compute the space-time coordinates we are at.
const int * X() const