QUDA  1.0.0
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
field_strength_tensor.cuh
Go to the documentation of this file.
1 #include <gauge_field_order.h>
2 #include <index_helper.cuh>
3 #include <quda_matrix.h>
4 
5 namespace quda
6 {
7 
/**
   @brief Parameter bundle consumed by the field-strength (Fmunu) compute routines.

   Holds the accessor that receives the result, the accessor the links are
   read from, and the lattice geometry taken from two GaugeField
   descriptors: `meta` supplies the dimensions stored in X, and `meta_ex`
   supplies the larger dimensions from which the per-direction border is
   derived.

   @tparam Float Real type; not referenced inside this struct
   @tparam Fmunu Type of the accessor that receives the result
   @tparam Gauge Type of the accessor the links are read from
*/
template <typename Float, typename Fmunu, typename Gauge> struct FmunuArg {
  int threads;   // number of active threads required (set from meta.VolumeCB())
  int X[4];      // grid dimensions, copied from meta.X()
  int border[4]; // per direction: half of (meta_ex dimension - meta dimension)
  Fmunu f;       // output accessor
  Gauge gauge;   // input link accessor

  /**
     @param f Accessor that receives the result (stored by copy)
     @param gauge Accessor the links are read from (stored by copy)
     @param meta Field descriptor providing VolumeCB() and the dimensions X
     @param meta_ex Field descriptor providing the dimensions used to size the border
  */
  FmunuArg(Fmunu &f, Gauge &gauge, const GaugeField &meta, const GaugeField &meta_ex) :
    threads(meta.VolumeCB()),
    f(f),
    gauge(gauge)
  {
    const int *inner = meta.X();
    const int *outer = meta_ex.X();
    for (int d = 0; d < 4; d++) {
      X[d] = inner[d];
      // the size difference is split evenly between the two ends of a direction
      border[d] = (outer[d] - X[d]) / 2;
    }
  }
};
26 
/**
   @brief Compute one (mu,nu) component of the field strength at a single
   checkerboard site and store it through arg.f.

   Four plaquette-like products of links in the mu-nu plane, each touching
   the site x, are summed into F. The result is then replaced by
   (F - conj(F)) / 8 and written to arg.f at the lower-triangular index
   mu*(mu-1)/2 + nu.

   arg.X and arg.border are only read here. The border-extended dimensions
   needed for neighbour indexing are built in a local array, so the same
   arg can safely be reused for any number of calls (as computeFmunuCPU
   does).

   @tparam mu First direction index (callers in this file use mu > nu)
   @tparam nu Second direction index
   @tparam Float Real type underlying the complex 3x3 link matrices
   @tparam Arg Parameter struct type providing X, border, gauge and f
   @param arg Parameter struct (geometry is read, arg.f is written)
   @param idx Checkerboard site index within the grid described by arg.X
   @param parity Parity of the site
*/
template <int mu, int nu, typename Float, typename Arg>
__device__ __host__ __forceinline__ void computeFmunuCore(Arg &arg, int idx, int parity)
{
  typedef Matrix<complex<Float>, 3> Link;

  int x[4]; // site coordinates, shifted by the border below
  int X[4]; // border-extended dimensions: a private copy, arg.X must not be modified

  // coordinates are first resolved using the dimensions stored in arg.X
  getCoords(x, idx, arg.X, parity);
  for (int dir = 0; dir < 4; ++dir) {
    x[dir] += arg.border[dir];
    X[dir] = arg.X[dir] + 2 * arg.border[dir];
  }

  Link F;
  { // U(x,mu) U(x+mu,nu) U[dagger](x+nu,mu) U[dagger](x,nu)

    // load U(x)_(+mu)
    int dx[4] = {0, 0, 0, 0};
    Link U1 = arg.gauge(mu, linkIndexShift(x, dx, X), parity);

    // load U(x+mu)_(+nu)
    dx[mu]++;
    Link U2 = arg.gauge(nu, linkIndexShift(x, dx, X), 1 - parity);
    dx[mu]--;

    // load U(x+nu)_(+mu)
    dx[nu]++;
    Link U3 = arg.gauge(mu, linkIndexShift(x, dx, X), 1 - parity);
    dx[nu]--;

    // load U(x)_(+nu)
    Link U4 = arg.gauge(nu, linkIndexShift(x, dx, X), parity);

    // compute plaquette
    F = U1 * U2 * conj(U3) * conj(U4);
  }

  { // U(x,nu) U[dagger](x+nu-mu,mu) U[dagger](x-mu,nu) U(x-mu, mu)

    // load U(x)_(+nu)
    int dx[4] = {0, 0, 0, 0};
    Link U1 = arg.gauge(nu, linkIndexShift(x, dx, X), parity);

    // load U(x+nu)_(-mu) = U(x+nu-mu)_(+mu)
    dx[nu]++;
    dx[mu]--;
    Link U2 = arg.gauge(mu, linkIndexShift(x, dx, X), parity);
    dx[nu]--; // dx now points at x-mu, which serves both remaining loads

    // load U(x-mu)_nu
    Link U3 = arg.gauge(nu, linkIndexShift(x, dx, X), 1 - parity);

    // load U(x)_(-mu) = U(x-mu)_(+mu)
    Link U4 = arg.gauge(mu, linkIndexShift(x, dx, X), 1 - parity);

    // sum this contribution to Fmunu
    F += U1 * conj(U2) * conj(U3) * U4;
  }

  { // U[dagger](x-nu,nu) U(x-nu,mu) U(x+mu-nu,nu) U[dagger](x,mu)

    // load U(x)_(-nu) = U(x-nu)_(+nu)
    int dx[4] = {0, 0, 0, 0};
    dx[nu]--;
    Link U1 = arg.gauge(nu, linkIndexShift(x, dx, X), 1 - parity);

    // load U(x-nu)_(+mu)
    Link U2 = arg.gauge(mu, linkIndexShift(x, dx, X), 1 - parity);

    // load U(x+mu-nu)_(+nu)
    dx[mu]++;
    Link U3 = arg.gauge(nu, linkIndexShift(x, dx, X), parity);
    dx[mu]--;
    dx[nu]++;

    // load U(x)_(+mu)
    Link U4 = arg.gauge(mu, linkIndexShift(x, dx, X), parity);

    // sum this contribution to Fmunu
    F += conj(U1) * U2 * U3 * conj(U4);
  }

  { // U[dagger](x-mu,mu) U[dagger](x-mu-nu,nu) U(x-mu-nu,mu) U(x-nu,nu)

    // load U(x)_(-mu) = U(x-mu)_(+mu)
    int dx[4] = {0, 0, 0, 0};
    dx[mu]--;
    Link U1 = arg.gauge(mu, linkIndexShift(x, dx, X), 1 - parity);

    // load U(x-mu)_(-nu) = U(x-mu-nu)_(+nu)
    dx[nu]--;
    Link U2 = arg.gauge(nu, linkIndexShift(x, dx, X), parity);

    // load U(x-mu-nu)_(+mu)
    Link U3 = arg.gauge(mu, linkIndexShift(x, dx, X), parity);

    // load U(x)_(-nu) = U(x-nu)_(+nu)
    dx[mu]++;
    Link U4 = arg.gauge(nu, linkIndexShift(x, dx, X), 1 - parity);

    // sum this contribution to Fmunu
    F += conj(U1) * conj(U2) * U3 * U4;
  }
  // 3 matrix additions, 12 matrix-matrix multiplications, 8 matrix conjugations
  // Each matrix conjugation involves 9 unary minus operations but these are not included in the operation count
  // Each matrix addition involves 18 real additions
  // Each matrix-matrix multiplication involves 9*3 complex multiplications and 9*2 complex additions
  // = 9*3*6 + 9*2*2 = 198 floating-point ops
  // => Total number of floating point ops per site above is
  // 3*18 + 12*198 = 54 + 2376 = 2430
  {
    F -= conj(F);                   // 18 real subtractions + one matrix conjugation
    F *= static_cast<Float>(0.125); // 18 real multiplications
    // 36 floating point operations here
  }

  constexpr int munu_idx = (mu * (mu - 1)) / 2 + nu; // lower-triangular indexing
  arg.f(munu_idx, idx, parity) = F;
}
165 
/**
   @brief Device entry point: each thread computes one field-strength
   component at one checkerboard site.

   Thread-to-work mapping, as read from the index arithmetic below:
     x dimension -> checkerboard site index (bounded by arg.threads)
     y dimension -> parity
     z dimension -> plane index 0..5, selecting
                    F[1,0], F[2,0], F[2,1], F[3,0], F[3,1], F[3,2]

   NOTE(review): parity is not range-checked here; presumably the launch
   configuration guarantees it is 0 or 1 -- confirm against the launcher.

   @tparam Float Real type forwarded to computeFmunuCore
   @tparam Arg Parameter struct type
   @param arg Kernel parameters, passed by value
*/
template <typename Float, typename Arg> __global__ void computeFmunuKernel(Arg arg)
{
  const int site = blockIdx.x * blockDim.x + threadIdx.x;
  const int par = blockIdx.y * blockDim.y + threadIdx.y;
  const int plane = blockIdx.z * blockDim.z + threadIdx.z;

  // threads past the end of the site range or the six planes have no work
  if (site >= arg.threads || plane >= 6) return;

  // mu and nu are template parameters, so each plane needs its own instantiation
  if (plane == 0) {
    computeFmunuCore<1, 0, Float>(arg, site, par);
  } else if (plane == 1) {
    computeFmunuCore<2, 0, Float>(arg, site, par);
  } else if (plane == 2) {
    computeFmunuCore<2, 1, Float>(arg, site, par);
  } else if (plane == 3) {
    computeFmunuCore<3, 0, Float>(arg, site, par);
  } else if (plane == 4) {
    computeFmunuCore<3, 1, Float>(arg, site, par);
  } else if (plane == 5) {
    computeFmunuCore<3, 2, Float>(arg, site, par);
  }
}
183 
/**
   @brief Host reference path: serially computes all six field-strength
   components at every checkerboard site of both parities.

   The loop order is parity (outermost), then site, then plane. Plane
   indices 0..5 correspond to
   F[1,0], F[2,0], F[2,1], F[3,0], F[3,1], F[3,2],
   i.e. the index mu*(mu-1)/2 + nu for nu < mu, visited in increasing order.

   @tparam Float Real type forwarded to computeFmunuCore
   @tparam Arg Parameter struct type
   @param arg Parameters; the same object is handed to every core call
*/
template <typename Float, typename Arg> void computeFmunuCPU(Arg &arg)
{
  constexpr int n_planes = 6; // number of (mu,nu) pairs with nu < mu in four dimensions

  for (int parity = 0; parity < 2; ++parity) {
    for (int site = 0; site < arg.threads; ++site) {
      for (int plane = 0; plane < n_planes; ++plane) {
        // mu and nu are template parameters, hence the explicit dispatch
        switch (plane) {
        case 0: computeFmunuCore<1, 0, Float>(arg, site, parity); break;
        case 1: computeFmunuCore<2, 0, Float>(arg, site, parity); break;
        case 2: computeFmunuCore<2, 1, Float>(arg, site, parity); break;
        case 3: computeFmunuCore<3, 0, Float>(arg, site, parity); break;
        case 4: computeFmunuCore<3, 1, Float>(arg, site, parity); break;
        case 5: computeFmunuCore<3, 2, Float>(arg, site, parity); break;
        }
      }
    }
  }
}
204 
205 } // namespace quda
double mu
Definition: test_util.cpp:1648
static __device__ __host__ int linkIndexShift(const I x[], const J dx[], const K X[4])
Main header file for host and device accessors to GaugeFields.
__device__ __host__ __forceinline__ void computeFmunuCore(Arg &arg, int idx, int parity)
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
__global__ void computeFmunuKernel(Arg arg)
void computeFmunuCPU(Arg &arg)
__host__ __device__ ValueType conj(ValueType x)
Definition: complex_quda.h:130
FmunuArg(Fmunu &f, Gauge &gauge, const GaugeField &meta, const GaugeField &meta_ex)
QudaParity parity
Definition: covdev_test.cpp:54
__host__ __device__ int getCoords(int coord[], const Arg &arg, int &idx, int parity, int &dim)
Compute the space-time coordinates we are at.
const int * X() const