QUDA  1.0.0
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
gauge_ape.cuh
Go to the documentation of this file.
1 #include <gauge_field_order.h>
2 #include <index_helper.cuh>
3 #include <quda_matrix.h>
4 #include <su3_project.cuh>
5 
6 namespace quda
7 {
8 
/**
 * Argument bundle handed to the APE smearing kernel.
 *
 * Float   - arithmetic type used for alpha and tolerance
 * GaugeOr - accessor type for the gauge field that is read
 * GaugeDs - accessor type for the gauge field that is written
 */
template <typename Float, typename GaugeOr, typename GaugeDs> struct GaugeAPEArg {
  int threads;           // number of active threads required (half the interior volume)
  int X[4];              // lattice extents with the border removed on both sides
  int border[4];         // border width per dimension, taken from data.R()
  GaugeOr origin;        // accessor for the input links
  const Float alpha;     // smearing weight
  const Float tolerance; // tolerance forwarded to the SU(3) projection

  GaugeDs dest;          // accessor for the output links

  /**
   * @param origin    accessor for the field to be smeared
   * @param dest      accessor for the field receiving the smeared links
   * @param data      gauge field supplying the extents X() and border widths R()
   * @param alpha     smearing weight
   * @param tolerance tolerance used by the projection step
   */
  GaugeAPEArg(GaugeOr &origin, GaugeDs &dest, const GaugeField &data, const Float alpha, const Float tolerance) :
    threads(1),
    origin(origin),
    alpha(alpha),
    tolerance(tolerance),
    dest(dest)
  {
    // Strip the border from each extent and accumulate the interior volume.
    int sites = 1;
    for (int d = 0; d < 4; d++) {
      const int halo = data.R()[d];
      const int interior = data.X()[d] - 2 * halo;
      border[d] = halo;
      X[d] = interior;
      sites *= interior;
    }

    // One thread per site of a single parity, hence half the interior volume.
    threads = sites / 2;
  }
};
34 
/**
 * Sum the staples attached to the link in direction `dir` at the site
 * identified by (idx, parity), and store the result in `staple`.
 *
 * Only the orthogonal directions mu = 0, 1, 2 are visited, so direction 3
 * never contributes a staple.
 *
 * @param arg    kernel arguments; X, border and origin are read
 * @param idx    site index within the border-free lattice for the given parity
 * @param parity parity of the site
 * @param dir    direction of the link whose staples are summed
 * @param staple output; zeroed first, then accumulated into
 */
template <typename Float, typename Arg, typename Link>
__host__ __device__ void computeStaple(Arg &arg, int idx, int parity, int dir, Link &staple)
{
  // Decode the site coordinates on the border-free lattice ...
  int dims[4];
  int site[4];
  for (int d = 0; d < 4; d++) dims[d] = arg.X[d];
  getCoords(site, idx, dims, parity);

  // ... then shift coordinates and extents so they address the full lattice.
  for (int d = 0; d < 4; d++) {
    site[d] += arg.border[d];
    dims[d] += 2 * arg.border[d];
  }

  setZero(&staple);

  int nu = dir;
  int other = 1 - parity; // parity of any nearest neighbour

  // Direction 3 is deliberately excluded: most users are not expected to
  // want time staples in the smearing.
  for (int mu = 0; mu < 3; mu++) {
    // Staples live only in directions orthogonal to the link.
    if (mu == nu) continue;

    int shift[4] = {0, 0, 0, 0};
    Link A, B, C;

    // ---- forward staple ----
    // U_{\mu}(x)
    A = arg.origin(mu, linkIndexShift(site, shift, dims), parity);

    // U_{\nu}(x+\mu)
    shift[mu] = 1;
    B = arg.origin(nu, linkIndexShift(site, shift, dims), other);

    // U_{\mu}(x+\nu)
    shift[mu] = 0;
    shift[nu] = 1;
    C = arg.origin(mu, linkIndexShift(site, shift, dims), other);

    // staple += U_{\mu}(x) * U_{\nu}(x+\mu) * U^\dag_{\mu}(x+\nu)
    staple = staple + A * B * conj(C);

    // ---- backward staple ----
    // U_{\mu}(x-\mu) and U_{\nu}(x-\mu) share the same site
    shift[mu] = -1;
    shift[nu] = 0;
    A = arg.origin(mu, linkIndexShift(site, shift, dims), other);
    B = arg.origin(nu, linkIndexShift(site, shift, dims), other);

    // U_{\mu}(x-\mu+\nu)
    shift[nu] = 1;
    C = arg.origin(mu, linkIndexShift(site, shift, dims), parity);

    // staple += U^\dag_{\mu}(x-\mu) * U_{\nu}(x-\mu) * U_{\mu}(x-\mu+\nu)
    staple = staple + conj(A) * B * C;
  }
}
95 
/**
 * Kernel performing one APE smearing step on a single link per thread.
 *
 * Thread mapping, as read from the index computation below:
 *   x -> site index idx within one parity (guarded by arg.threads)
 *   y -> parity (0 or 1)
 *   z -> link direction dir; only dir = 0, 1, 2 are smeared
 *
 * With S the staple sum returned by computeStaple, the link is replaced by
 *
 *   U' = P[ (1 - alpha) * I + (alpha / 4) * S * U^dag ] * U
 *
 * where P is the in-place operation applied by polarSu3 with arg.tolerance
 * (presumably a projection back onto SU(3) -- see su3_project.cuh).
 * The factor 4 = 2 * (3 - 1) is the number of staples summed per link.
 *
 * The result is written to arg.dest; arg.origin is only read.
 */
template <typename Float, typename Arg> __global__ void computeAPEStep(Arg arg)
{
  int idx = threadIdx.x + blockIdx.x * blockDim.x;
  int parity = threadIdx.y + blockIdx.y * blockDim.y;
  int dir = threadIdx.z + blockIdx.z * blockDim.z;

  // Discard threads that fall outside the work domain in any launch dimension.
  if (idx >= arg.threads) return;
  if (parity >= 2) return; // guards against a y-extent that overshoots the two parities
  if (dir >= 3) return;    // only spatial dimensions are smeared

  typedef Matrix<complex<Float>, 3> Link;

  // Decode the site on the border-free lattice, then shift coordinates and
  // extents so they address the full lattice.
  int X[4];
  for (int dr = 0; dr < 4; ++dr) X[dr] = arg.X[dr];

  int x[4];
  getCoords(x, idx, X, parity);
  for (int dr = 0; dr < 4; ++dr) {
    x[dr] += arg.border[dr];
    X[dr] += 2 * arg.border[dr];
  }

  // Index of this site on the full lattice (zero displacement); computed
  // once and used for both the load and the store.
  int dx[4] = {0, 0, 0, 0};
  const int link_idx = linkIndexShift(x, dx, X);

  Link U, S, TestU, I;

  // S = sum of the length-3 staples attached to this link.
  computeStaple<Float>(arg, idx, parity, dir, S);

  // Get link U
  U = arg.origin(dir, link_idx, parity);

  // Weight the staple sum by alpha / (2 * (3 - 1)).
  S = S * (arg.alpha / static_cast<Float>(2 * (3 - 1)));
  setIdentity(&I);

  // Keep the scalar in Float so a single-precision instantiation does not
  // promote to double arithmetic.
  TestU = I * (static_cast<Float>(1.0) - arg.alpha) + S * conj(U);
  polarSu3<Float>(TestU, arg.tolerance);
  U = TestU * U;

  arg.dest(dir, link_idx, parity) = U;
}
144 
145 } // namespace quda
__global__ void computeAPEStep(Arg arg)
Definition: gauge_ape.cuh:96
double mu
Definition: test_util.cpp:1648
__device__ __host__ void setZero(Matrix< T, N > *m)
Definition: quda_matrix.h:702
static __device__ __host__ int linkIndexShift(const I x[], const J dx[], const K X[4])
const Float tolerance
Definition: gauge_ape.cuh:15
GaugeAPEArg(GaugeOr &origin, GaugeDs &dest, const GaugeField &data, const Float alpha, const Float tolerance)
Definition: gauge_ape.cuh:19
This is just a dummy structure we use for trove to define the required structure size.
const int * R() const
__host__ __device__ void computeStaple(Arg &arg, int idx, int parity, int dir, Link &staple)
Definition: gauge_ape.cuh:36
Main header file for host and device accessors to GaugeFields.
std::complex< double > Complex
Definition: quda_internal.h:46
__device__ __host__ void setIdentity(Matrix< T, N > *m)
Definition: quda_matrix.h:653
const Float alpha
Definition: gauge_ape.cuh:14
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
__host__ __device__ ValueType conj(ValueType x)
Definition: complex_quda.h:130
QudaParity parity
Definition: covdev_test.cpp:54
__host__ __device__ int getCoords(int coord[], const Arg &arg, int &idx, int parity, int &dim)
Compute the space-time coordinates we are at.
const int * X() const