QUDA  1.0.0
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
covDev.cuh
Go to the documentation of this file.
1 #pragma once
2 
3 #include <dslash_helper.cuh>
5 #include <gauge_field_order.h>
6 #include <color_spinor.h>
7 #include <dslash_helper.cuh>
8 #include <index_helper.cuh>
9 
10 namespace quda
11 {
12 
16  template <typename Float, int nColor, QudaReconstructType reconstruct_> struct CovDevArg : DslashArg<Float> {
17  static constexpr int nSpin = 4;
18  static constexpr bool spin_project = false;
19  static constexpr bool spinor_direct_load = false; // false means texture load
21 
22  static constexpr QudaReconstructType reconstruct = reconstruct_;
23  static constexpr bool gauge_direct_load = false; // false means texture load
26 
27  typedef typename mapper<Float>::type real;
28 
29  F out;
30  const F in;
31  const G U;
32  int mu;
34  CovDevArg(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, int mu, int parity, bool dagger,
35  const int *comm_override) :
36 
37  DslashArg<Float>(in, U, parity, dagger, false, 1, spin_project, comm_override),
38  out(out),
39  in(in),
40  U(U),
41  mu(mu)
42  {
43  if (!out.isNative() || !in.isNative() || !U.isNative())
44  errorQuda("Unsupported field order colorspinor(in)=%d gauge=%d combination\n", in.FieldOrder(), U.FieldOrder());
45  }
46  };
47 
  /**
     @brief Gather the neighbor spinor in the direction encoded by mu and
     multiply by the connecting link, accumulating into out.  For mu in [0,3]
     the forward neighbor is fetched and out += U * in; for mu in [4,7] the
     backward neighbor is fetched and out += conj(U) * in, with the link read
     from the backward site.  Boundary sites route through the ghost-zone
     accessors when running a halo (exterior) kernel.
     @param[in,out] out Accumulated result vector
     @param[in,out] arg Parameter struct (field accessors, dims, ghost buffers)
     @param[in] coord Full lattice coordinates of this site
     @param[in] x_cb Checkerboarded site index
     @param[in] parity Parity of the destination site
     @param[in] idx Thread index (unused here; kept for interface uniformity)
     @param[in] thread_dim Dimension this thread owns (fused exterior kernel only)
     @param[in,out] active Whether this thread remains active (fused exterior kernel only)
  */
  template <typename Float, int nDim, int nColor, int nParity, bool dagger, KernelType kernel_type, int mu,
            typename Arg, typename Vector>
  __device__ __host__ inline void applyCovDev(Vector &out, Arg &arg, int coord[nDim], int x_cb, int parity, int idx,
                                              int thread_dim, bool &active)
  {

    typedef typename mapper<Float>::type real;
    typedef Matrix<complex<real>, nColor> Link;
    // for a full (two-parity) field the neighbor spinor lives on the opposite parity
    const int their_spinor_parity = (arg.nParity == 2) ? 1 - parity : 0;

    const int d = mu % 4; // dimension of the hop (sign is carried by mu < 4 vs mu >= 4)

    if (mu < 4) { // Forward gather - compute fwd offset for vector fetch

      const int fwd_idx = getNeighborIndexCB<nDim>(coord, d, +1, arg.dc);
      // true when the forward neighbor falls off the local lattice in dimension d
      const bool ghost = (coord[d] + 1 >= arg.dim[d]) && isActive<kernel_type>(active, thread_dim, d, coord, arg);

      const Link U = arg.U(d, x_cb, parity);

      if (doHalo<kernel_type>(d) && ghost) {

        // neighbor is off-node: read the spinor from the forward ghost face
        const int ghost_idx = ghostFaceIndex<1>(coord, arg.dim, d, arg.nFace);
        const Vector in = arg.in.Ghost(d, 1, ghost_idx, their_spinor_parity);

        out += U * in;

      } else if (doBulk<kernel_type>() && !ghost) {

        const Vector in = arg.in(fwd_idx, their_spinor_parity);
        out += U * in;
      }

    } else { // Backward gather - compute back offset for spinor and gauge fetch

      const int back_idx = getNeighborIndexCB<nDim>(coord, d, -1, arg.dc);
      const int gauge_idx = back_idx; // the connecting link is stored on the backward site

      // true when the backward neighbor falls off the local lattice in dimension d
      const bool ghost = (coord[d] - 1 < 0) && isActive<kernel_type>(active, thread_dim, d, coord, arg);

      if (doHalo<kernel_type>(d) && ghost) {

        // neighbor is off-node: link and spinor both come from the backward ghost face
        const int ghost_idx = ghostFaceIndex<0>(coord, arg.dim, d, arg.nFace);
        const Link U = arg.U.Ghost(d, ghost_idx, 1 - parity);
        const Vector in = arg.in.Ghost(d, 0, ghost_idx, their_spinor_parity);

        out += conj(U) * in;
      } else if (doBulk<kernel_type>() && !ghost) {

        const Link U = arg.U(d, gauge_idx, 1 - parity);
        const Vector in = arg.in(back_idx, their_spinor_parity);

        out += conj(U) * in;
      }
    } // Forward/backward derivative
  }
116 
117  // out(x) = M*in
118  template <typename Float, int nDim, int nColor, int nParity, bool dagger, KernelType kernel_type, typename Arg>
119  __device__ __host__ inline void covDev(Arg &arg, int idx, int parity)
120  {
121 
122  using real = typename mapper<Float>::type;
124 
125  // is thread active (non-trival for fused kernel only)
126  bool active = kernel_type == EXTERIOR_KERNEL_ALL ? false : true;
127 
128  // which dimension is thread working on (fused kernel only)
129  int thread_dim;
130 
131  int coord[nDim];
132  int x_cb = getCoords<nDim, QUDA_4D_PC, kernel_type, Arg>(coord, arg, idx, parity, thread_dim);
133 
134  const int my_spinor_parity = nParity == 2 ? parity : 0;
135  Vector out;
136 
137  switch (arg.mu) { // ensure that mu is known to compiler for indexing in applyCovDev (avoid register spillage)
138  case 0:
139  applyCovDev<Float, nDim, nColor, nParity, dagger, kernel_type, 0>(out, arg, coord, x_cb, parity, idx, thread_dim,
140  active);
141  break;
142  case 1:
143  applyCovDev<Float, nDim, nColor, nParity, dagger, kernel_type, 1>(out, arg, coord, x_cb, parity, idx, thread_dim,
144  active);
145  break;
146  case 2:
147  applyCovDev<Float, nDim, nColor, nParity, dagger, kernel_type, 2>(out, arg, coord, x_cb, parity, idx, thread_dim,
148  active);
149  break;
150  case 3:
151  applyCovDev<Float, nDim, nColor, nParity, dagger, kernel_type, 3>(out, arg, coord, x_cb, parity, idx, thread_dim,
152  active);
153  break;
154  case 4:
155  applyCovDev<Float, nDim, nColor, nParity, dagger, kernel_type, 4>(out, arg, coord, x_cb, parity, idx, thread_dim,
156  active);
157  break;
158  case 5:
159  applyCovDev<Float, nDim, nColor, nParity, dagger, kernel_type, 5>(out, arg, coord, x_cb, parity, idx, thread_dim,
160  active);
161  break;
162  case 6:
163  applyCovDev<Float, nDim, nColor, nParity, dagger, kernel_type, 6>(out, arg, coord, x_cb, parity, idx, thread_dim,
164  active);
165  break;
166  case 7:
167  applyCovDev<Float, nDim, nColor, nParity, dagger, kernel_type, 7>(out, arg, coord, x_cb, parity, idx, thread_dim,
168  active);
169  break;
170  }
171 
172  if (kernel_type != INTERIOR_KERNEL) {
173  Vector x = arg.out(x_cb, my_spinor_parity);
174  out += x;
175  }
176 
177  if (kernel_type != EXTERIOR_KERNEL_ALL || active) arg.out(x_cb, my_spinor_parity) = out;
178  }
179 
180  // GPU Kernel for applying the covariant derivative operator to a vector
181  template <typename Float, int nDim, int nColor, int nParity, bool dagger, bool xpay, KernelType kernel_type, typename Arg>
182  __global__ void covDevGPU(Arg arg)
183  {
184  int x_cb = blockIdx.x * blockDim.x + threadIdx.x;
185  if (x_cb >= arg.threads) return;
186 
187  // for full fields set parity from z thread index else use arg setting
188  int parity = nParity == 2 ? blockDim.z * blockIdx.z + threadIdx.z : arg.parity;
189 
190  switch (parity) {
191  case 0: covDev<Float, nDim, nColor, nParity, dagger, kernel_type>(arg, x_cb, 0); break;
192  case 1: covDev<Float, nDim, nColor, nParity, dagger, kernel_type>(arg, x_cb, 1); break;
193  }
194  }
195 } // namespace quda
KernelType kernel_type
QudaGaugeFieldOrder FieldOrder() const
Definition: gauge_field.h:257
#define errorQuda(...)
Definition: util_quda.h:121
__global__ void covDevGPU(Arg arg)
Definition: covDev.cuh:182
const F in
Definition: covDev.cuh:30
colorspinor_mapper< Float, nSpin, nColor, spin_project, spinor_direct_load >::type F
Definition: covDev.cuh:20
static constexpr bool spin_project
Definition: covDev.cuh:18
const G U
Definition: covDev.cuh:31
static constexpr bool spinor_direct_load
Definition: covDev.cuh:19
mapper< Float >::type real
Definition: covDev.cuh:27
__device__ __host__ void covDev(Arg &arg, int idx, int parity)
Definition: covDev.cuh:119
__device__ __host__ void applyCovDev(Vector &out, Arg &arg, int coord[nDim], int x_cb, int parity, int idx, int thread_dim, bool &active)
Definition: covDev.cuh:63
const int nColor
Definition: covdev_test.cpp:75
enum QudaGhostExchange_s QudaGhostExchange
Parameter structure for driving the covariant derivative operator.
Definition: covDev.cuh:16
Main header file for host and device accessors to GaugeFields.
const int nParity
Definition: spinor_noise.cu:25
CovDevArg(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, int mu, int parity, bool dagger, const int *comm_override)
Definition: covDev.cuh:34
enum QudaReconstructType_s QudaReconstructType
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
VectorXcd Vector
gauge_mapper< Float, reconstruct, 18, QUDA_STAGGERED_PHASE_NO, gauge_direct_load, ghost >::type G
Definition: covDev.cuh:25
static constexpr int nSpin
Definition: covDev.cuh:17
static constexpr QudaReconstructType reconstruct
Definition: covDev.cuh:22
__host__ __device__ ValueType conj(ValueType x)
Definition: complex_quda.h:130
static constexpr QudaGhostExchange ghost
Definition: covDev.cuh:24
bool isNative() const
QudaFieldOrder FieldOrder() const
static constexpr bool gauge_direct_load
Definition: covDev.cuh:23