QUDA  1.0.0
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
contraction.cuh
Go to the documentation of this file.
1 #pragma once
2 
4 #include <index_helper.cuh>
5 #include <quda_matrix.h>
6 #include <matrix_field.h>
7 #include <su3_project.cuh>
8 
9 namespace quda
10 {
11 
12  template <typename real> struct ContractionArg {
13  int threads; // number of active threads required
14  int X[4]; // grid dimensions
15 
16  static constexpr int nSpin = 4;
17  static constexpr int nColor = 3;
18  static constexpr bool spin_project = true;
19  static constexpr bool spinor_direct_load = false; // false means texture load
20 
21  // Create a typename F for the ColorSpinorField (F for fermion)
23 
24  F x;
25  F y;
27 
28  ContractionArg(const ColorSpinorField &x, const ColorSpinorField &y, complex<real> *s) :
29  threads(x.VolumeCB()),
30  x(x),
31  y(y),
32  s(s, x.VolumeCB())
33  {
34  for (int dir = 0; dir < 4; dir++) X[dir] = x.X()[dir];
35  }
36  };
37 
38  template <typename real, typename Arg> __global__ void computeColorContraction(Arg arg)
39  {
40  int x_cb = threadIdx.x + blockIdx.x * blockDim.x;
41  int parity = threadIdx.y + blockIdx.y * blockDim.y;
42  if (x_cb >= arg.threads) return;
43 
44  constexpr int nSpin = Arg::nSpin;
45  constexpr int nColor = Arg::nColor;
47 
48  Vector x = arg.x(x_cb, parity);
49  Vector y = arg.y(x_cb, parity);
50 
51  Matrix<complex<real>, nSpin> A;
52 #pragma unroll
53  for (int mu = 0; mu < nSpin; mu++) {
54 #pragma unroll
55  for (int nu = 0; nu < nSpin; nu++) {
56  // Color inner product: <\phi(x)_{\mu} | \phi(y)_{\nu}>
57  // The Bra is conjugated
58  A(mu, nu) = innerProduct(x, y, mu, nu);
59  }
60  }
61 
62  arg.s.save(A, x_cb, parity);
63  }
64 
65  template <typename real, typename Arg> __global__ void computeDegrandRossiContraction(Arg arg)
66  {
67  int x_cb = threadIdx.x + blockIdx.x * blockDim.x;
68  int parity = threadIdx.y + blockIdx.y * blockDim.y;
69  const int nSpin = arg.nSpin;
70  const int nColor = arg.nColor;
71 
72  if (x_cb >= arg.threads) return;
73 
75 
76  Vector x = arg.x(x_cb, parity);
77  Vector y = arg.y(x_cb, parity);
78 
79  complex<real> I(0.0, 1.0);
80  complex<real> spin_elem[nSpin][nSpin];
81  complex<real> result_local(0.0, 0.0);
82 
83  // Color contract: <\phi(x)_{\mu} | \phi(y)_{\nu}>
84  // The Bra is conjugated
85  for (int mu = 0; mu < nSpin; mu++) {
86  for (int nu = 0; nu < nSpin; nu++) { spin_elem[mu][nu] = innerProduct(x, y, mu, nu); }
87  }
88 
89  complex<real> A[nSpin * nSpin];
90 
91  // Spin contract: <\phi(x)_{\mu} \Gamma_{mu,nu}^{rho,tau} \phi(y)_{\nu}>
92  // The rho index runs slowest.
93  // Layout is defined in enum_quda.h: G_idx = 4*rho + tau
94  // DMH: Hardcoded to Degrand-Rossi. Need a template on Gamma basis.
95 
96  int G_idx = 0;
97 
98  // SCALAR
99  // G_idx = 0: I
100  result_local = 0.0;
101  result_local += spin_elem[0][0];
102  result_local += spin_elem[1][1];
103  result_local += spin_elem[2][2];
104  result_local += spin_elem[3][3];
105  A[G_idx++] = result_local;
106 
107  // VECTORS
108  // G_idx = 1: \gamma_1
109  result_local = 0.0;
110  result_local += I * spin_elem[0][3];
111  result_local += I * spin_elem[1][2];
112  result_local -= I * spin_elem[2][1];
113  result_local -= I * spin_elem[3][0];
114  A[G_idx++] = result_local;
115 
116  // G_idx = 2: \gamma_2
117  result_local = 0.0;
118  result_local -= spin_elem[0][3];
119  result_local += spin_elem[1][2];
120  result_local += spin_elem[2][1];
121  result_local -= spin_elem[3][0];
122  A[G_idx++] = result_local;
123 
124  // G_idx = 3: \gamma_3
125  result_local = 0.0;
126  result_local += I * spin_elem[0][2];
127  result_local -= I * spin_elem[1][3];
128  result_local -= I * spin_elem[2][0];
129  result_local += I * spin_elem[3][1];
130  A[G_idx++] = result_local;
131 
132  // G_idx = 4: \gamma_4
133  result_local = 0.0;
134  result_local += spin_elem[0][2];
135  result_local += spin_elem[1][3];
136  result_local += spin_elem[2][0];
137  result_local += spin_elem[3][1];
138  A[G_idx++] = result_local;
139 
140  // PSEUDO-SCALAR
141  // G_idx = 5: \gamma_5
142  result_local = 0.0;
143  result_local += spin_elem[0][0];
144  result_local += spin_elem[1][1];
145  result_local -= spin_elem[2][2];
146  result_local -= spin_elem[3][3];
147  A[G_idx++] = result_local;
148 
149  // PSEUDO-VECTORS
150  // DMH: Careful here... we may wish to use \gamma_1,2,3,4\gamma_5 for pseudovectors
151  // G_idx = 6: \gamma_5\gamma_1
152  result_local = 0.0;
153  result_local += I * spin_elem[0][3];
154  result_local += I * spin_elem[1][2];
155  result_local += I * spin_elem[2][1];
156  result_local += I * spin_elem[3][0];
157  A[G_idx++] = result_local;
158 
159  // G_idx = 7: \gamma_5\gamma_2
160  result_local = 0.0;
161  result_local -= spin_elem[0][3];
162  result_local += spin_elem[1][2];
163  result_local -= spin_elem[2][1];
164  result_local += spin_elem[3][0];
165  A[G_idx++] = result_local;
166 
167  // G_idx = 8: \gamma_5\gamma_3
168  result_local = 0.0;
169  result_local += I * spin_elem[0][2];
170  result_local -= I * spin_elem[1][3];
171  result_local += I * spin_elem[2][0];
172  result_local -= I * spin_elem[3][1];
173  A[G_idx++] = result_local;
174 
175  // G_idx = 9: \gamma_5\gamma_4
176  result_local = 0.0;
177  result_local += spin_elem[0][2];
178  result_local += spin_elem[1][3];
179  result_local -= spin_elem[2][0];
180  result_local -= spin_elem[3][1];
181  A[G_idx++] = result_local;
182 
183  // TENSORS
184  // G_idx = 10: (i/2) * [\gamma_1, \gamma_2]
185  result_local = 0.0;
186  result_local += spin_elem[0][0];
187  result_local -= spin_elem[1][1];
188  result_local += spin_elem[2][2];
189  result_local -= spin_elem[3][3];
190  A[G_idx++] = result_local;
191 
192  // G_idx = 11: (i/2) * [\gamma_1, \gamma_3]
193  result_local = 0.0;
194  result_local -= I * spin_elem[0][2];
195  result_local -= I * spin_elem[1][3];
196  result_local += I * spin_elem[2][0];
197  result_local += I * spin_elem[3][1];
198  A[G_idx++] = result_local;
199 
200  // G_idx = 12: (i/2) * [\gamma_1, \gamma_4]
201  result_local = 0.0;
202  result_local -= spin_elem[0][1];
203  result_local -= spin_elem[1][0];
204  result_local += spin_elem[2][3];
205  result_local += spin_elem[3][2];
206  A[G_idx++] = result_local;
207 
208  // G_idx = 13: (i/2) * [\gamma_2, \gamma_3]
209  result_local = 0.0;
210  result_local += spin_elem[0][1];
211  result_local += spin_elem[1][0];
212  result_local += spin_elem[2][3];
213  result_local += spin_elem[3][2];
214  A[G_idx++] = result_local;
215 
216  // G_idx = 14: (i/2) * [\gamma_2, \gamma_4]
217  result_local = 0.0;
218  result_local -= I * spin_elem[0][1];
219  result_local += I * spin_elem[1][0];
220  result_local += I * spin_elem[2][3];
221  result_local -= I * spin_elem[3][2];
222  A[G_idx++] = result_local;
223 
224  // G_idx = 15: (i/2) * [\gamma_3, \gamma_4]
225  result_local = 0.0;
226  result_local -= spin_elem[0][0];
227  result_local -= spin_elem[1][1];
228  result_local += spin_elem[2][2];
229  result_local += spin_elem[3][3];
230  A[G_idx++] = result_local;
231 
232  arg.s.save(A, x_cb, parity);
233  }
234 } // namespace quda
double mu
Definition: test_util.cpp:1648
__global__ void computeColorContraction(Arg arg)
Definition: contraction.cuh:38
matrix_field< complex< real >, nSpin > s
Definition: contraction.cuh:26
static constexpr bool spinor_direct_load
Definition: contraction.cuh:19
const int nColor
Definition: covdev_test.cpp:75
colorspinor_mapper< real, nSpin, nColor, spin_project, spinor_direct_load >::type F
Definition: contraction.cuh:22
__global__ void computeDegrandRossiContraction(Arg arg)
Definition: contraction.cuh:65
static constexpr int nSpin
Definition: contraction.cuh:16
__device__ __host__ complex< Float > innerProduct(const ColorSpinor< Float, Nc, Ns > &a, const ColorSpinor< Float, Nc, Ns > &b)
Compute the inner product over color and spin: dot = \sum_{s,c} conj(a(s,c)) * b(s,c)
Definition: color_spinor.h:914
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
VectorXcd Vector
static constexpr int nColor
Definition: contraction.cuh:17
const int * X() const
static constexpr bool spin_project
Definition: contraction.cuh:18
QudaParity parity
Definition: covdev_test.cpp:54
ContractionArg(const ColorSpinorField &x, const ColorSpinorField &y, complex< real > *s)
Definition: contraction.cuh:28