QUDA  1.0.0
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
dslash_wilson.cuh
Go to the documentation of this file.
1 #pragma once
2 
3 #include <dslash_helper.cuh>
5 #include <gauge_field_order.h>
6 #include <color_spinor.h>
7 #include <dslash_helper.cuh>
8 #include <index_helper.cuh>
9 
10 namespace quda
11 {
12 
/**
   @brief Parameter structure for driving the Wilson operator.

   Bundles the field accessors (out, in, x, U) and the scalar a that the
   Wilson kernels read through their `arg` parameter, on top of whatever
   DslashArg<Float> already provides.

   @tparam Float Storage precision; the arithmetic type `real` is derived from it via mapper<Float>
   @tparam nColor Number of colors
   @tparam reconstruct_ Gauge-field reconstruction type, re-exported as `reconstruct`
*/
template <typename Float, int nColor, QudaReconstructType reconstruct_> struct WilsonArg : DslashArg<Float> {
  static constexpr int nSpin = 4;
  static constexpr bool spin_project = true;
  static constexpr bool spinor_direct_load = false; // false means texture load

  static constexpr QudaReconstructType reconstruct = reconstruct_;
  static constexpr bool gauge_direct_load = false; // false means texture load

  typedef typename mapper<Float>::type real;

  // NOTE(review): the accessor types F (spinor) and G (gauge) are used below but
  // their typedefs are not present in the visible part of this struct - confirm
  // they are declared here or inherited before relying on this definition.
  F out;        // output spinor accessor (written by the kernels)
  const F in;   // input spinor accessor
  const F x;    // spinor accessor added to the result in the xpay case
  const G U;    // gauge-link accessor
  const real a; // scale factor applied to the operator result in the xpay case

  /**
     @param[out] out Output spinor field
     @param[in] in Input spinor field
     @param[in] U Gauge field
     @param[in] a Scale factor; a non-zero value is forwarded to DslashArg as `true`
     (NOTE(review): presumably the xpay flag - confirm against DslashArg)
     @param[in] x Spinor field used in the xpay accumulation
     @param[in] parity Forwarded to DslashArg
     @param[in] dagger Forwarded to DslashArg
     @param[in] comm_override Forwarded to DslashArg
  */
  WilsonArg(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a,
            const ColorSpinorField &x, int parity, bool dagger, const int *comm_override) :
    DslashArg<Float>(in, U, parity, dagger, a != 0.0, 1, spin_project, comm_override),
    // initializers are listed in declaration order (out, in, x, U, a), which is
    // the order the language actually runs them in
    out(out),
    in(in),
    x(x),
    U(U),
    a(a)
  {
    // inside the body the names refer to the constructor parameters (the fields), not the accessors
    if (!out.isNative() || !x.isNative() || !in.isNative() || !U.isNative())
      errorQuda("Unsupported field order colorspinor=%d gauge=%d combination\n", in.FieldOrder(), U.FieldOrder());
  }
};
48 
/**
   @brief Applies the off-diagonal part of the Wilson operator.

   For each of the four directions d the routine accumulates into `out` one
   forward and one backward hopping term.  Each term is read either from the
   bulk accessors (arg.in / arg.U) or from the ghost accessors
   (arg.in.Ghost / arg.U.Ghost), depending on whether the neighbour lies
   within arg.nFace of the local boundary and on what doHalo / doBulk
   report for kernel_type.

   @param[in,out] out Accumulator for the result (only ever updated with +=)
   @param[in] arg Parameter struct (accessors, dimensions, index helper data)
   @param[in] coord Site coordinates
   @param[in] x_cb Checkerboard site index
   @param[in] s Index that scales the volume_4d_cb / ghostFaceCB offsets
   (NOTE(review): presumably the fifth-dimension index - confirm)
   @param[in] parity Site parity
   @param[in] idx Thread index; used directly as the ghost-face index unless
   the fused exterior kernel is updating a face other than thread_dim
   @param[in] thread_dim Dimension this thread is working on (fused exterior kernel)
   @param[in,out] active Handed by reference to isActive
*/
template <typename Float, int nDim, int nColor, int nParity, bool dagger, KernelType kernel_type, typename Arg, typename Vector>
__device__ __host__ inline void applyWilson(
  Vector &out, Arg &arg, int coord[nDim], int x_cb, int s, int parity, int idx, int thread_dim, bool &active)
{
  typedef typename mapper<Float>::type real;
  typedef ColorSpinor<real, nColor, 2> HalfVector;
  typedef Matrix<complex<real>, nColor> Link;

  // parity index used for every neighbour spinor load
  const int their_spinor_parity = (nParity == 2) ? (1 - parity) : 0;

  // gauge-field parity; in the 5-d case fold in the residual parity of the 5-d => 4-d checkerboarding
  const int gauge_parity = (nDim == 5) ? (x_cb / arg.dc.volume_4d_cb + parity) % 2 : parity;

  // direction-independent index pieces, computed once
  const int local_gauge_idx = (nDim == 5) ? (x_cb % arg.dc.volume_4d_cb) : x_cb;
  const int bulk_shift = s * arg.dc.volume_4d_cb;

#pragma unroll 4
  for (int d = 0; d < 4; d++) { // the four hopping directions

    { // forward hop: link stored at this site, spinor taken from the +1 neighbour
      const int up_site = getNeighborIndexCB<nDim>(coord, d, +1, arg.dc);
      constexpr int sign = dagger ? +1 : -1;

      // isActive is only consulted when the neighbour is across the boundary
      const bool off_node
        = (coord[d] + arg.nFace >= arg.dim[d]) && isActive<kernel_type>(active, thread_dim, d, coord, arg);

      if (doHalo<kernel_type>(d) && off_node) {
        // a face that does not belong to this thread needs its face index recomputed
        const int face_idx = (kernel_type == EXTERIOR_KERNEL_ALL && d != thread_dim) ?
          ghostFaceIndex<1, nDim>(coord, arg.dim, d, arg.nFace) :
          idx;

        Link link = arg.U(d, local_gauge_idx, gauge_parity);
        HalfVector psi = arg.in.Ghost(d, 1, face_idx + s * arg.dc.ghostFaceCB[d], their_spinor_parity);
        if (d == 3) psi *= arg.t_proj_scale; // d == 3 ghost spinors carry an extra scale factor

        out += (link * psi).reconstruct(d, sign);
      } else if (doBulk<kernel_type>() && !off_node) {

        Link link = arg.U(d, local_gauge_idx, gauge_parity);
        Vector psi = arg.in(up_site + bulk_shift, their_spinor_parity);

        out += (link * psi.project(d, sign)).reconstruct(d, sign);
      }
    }

    { // backward hop: link and spinor both taken from the -1 neighbour, link is conjugated
      const int down_site = getNeighborIndexCB<nDim>(coord, d, -1, arg.dc);
      const int down_gauge_idx = (nDim == 5) ? (down_site % arg.dc.volume_4d_cb) : down_site;
      constexpr int sign = dagger ? -1 : +1;

      const bool off_node = (coord[d] - arg.nFace < 0) && isActive<kernel_type>(active, thread_dim, d, coord, arg);

      if (doHalo<kernel_type>(d) && off_node) {
        // a face that does not belong to this thread needs its face index recomputed
        const int face_idx = (kernel_type == EXTERIOR_KERNEL_ALL && d != thread_dim) ?
          ghostFaceIndex<0, nDim>(coord, arg.dim, d, arg.nFace) :
          idx;

        const int link_face_idx = (nDim == 5) ? (face_idx % arg.dc.ghostFaceCB[d]) : face_idx;
        Link link = arg.U.Ghost(d, link_face_idx, 1 - gauge_parity);
        HalfVector psi = arg.in.Ghost(d, 0, face_idx + s * arg.dc.ghostFaceCB[d], their_spinor_parity);
        if (d == 3) psi *= arg.t_proj_scale;

        out += (conj(link) * psi).reconstruct(d, sign);
      } else if (doBulk<kernel_type>() && !off_node) {

        Link link = arg.U(d, down_gauge_idx, 1 - gauge_parity);
        Vector psi = arg.in(down_site + bulk_shift, their_spinor_parity);

        out += (conj(link) * psi.project(d, sign)).reconstruct(d, sign);
      }
    }
  } // direction loop
}
132 
/**
   @brief Evaluate the Wilson operator at one site and write the result to arg.out.

   The hopping term is obtained from applyWilson.  What happens next depends on
   kernel_type:
   - interior kernel: with xpay the result becomes arg.x + arg.a * hop, otherwise hop is stored as is;
   - any other kernel, when the thread is active: the value already in arg.out
     is read back and the (optionally arg.a-scaled) hop is added to it.
   The fused exterior kernel stores only from active threads; every other
   kernel type always stores.

   @param[in] arg Parameter struct
   @param[in] idx Thread index handed to getCoords
   @param[in] s Index that scales the arg.dc.volume_4d_cb offset
   @param[in] parity Site parity
*/
template <typename Float, int nDim, int nColor, int nParity, bool dagger, bool xpay, KernelType kernel_type, typename Arg>
__device__ __host__ inline void wilson(Arg &arg, int idx, int s, int parity)
{
  typedef typename mapper<Float>::type real;
  // NOTE(review): `Vector` is used below but its declaration is not visible in
  // this listing - confirm it is defined at this point in the real header.

  // threads of the fused exterior kernel start out inactive; all others start active
  bool active = (kernel_type != EXTERIOR_KERNEL_ALL);
  int thread_dim; // which dimension is thread working on (fused kernel only); handed to getCoords
  int coord[nDim];
  const int x_cb = getCoords<nDim, QUDA_4D_PC, kernel_type>(coord, arg, idx, parity, thread_dim);

  const int my_spinor_parity = (nParity == 2) ? parity : 0;

  Vector out;
  applyWilson<Float, nDim, nColor, nParity, dagger, kernel_type>(
    out, arg, coord, x_cb, s, parity, idx, thread_dim, active);

  const int xs = x_cb + s * arg.dc.volume_4d_cb;

  if (kernel_type == INTERIOR_KERNEL) {
    if (xpay) {
      Vector x = arg.x(xs, my_spinor_parity);
      out = x + arg.a * out;
    }
  } else if (active) {
    // exterior update: accumulate on top of what is already stored in arg.out
    Vector stored = arg.out(xs, my_spinor_parity);
    out = stored + (xpay ? arg.a * out : out);
  }

  const bool do_store = (kernel_type != EXTERIOR_KERNEL_ALL) || active;
  if (do_store) arg.out(xs, my_spinor_parity) = out;
}
162 
/**
   @brief CPU kernel for applying the Wilson operator to a vector.

   Visits every parity handled by this call and, for each, every
   checkerboard site index in [0, arg.threads), calling wilson() with s = 0.

   @param[in] arg Parameter struct (taken by value)
*/
template <typename Float, int nDim, int nColor, int nParity, bool dagger, bool xpay, KernelType kernel_type, typename Arg>
void wilsonCPU(Arg arg)
{
  // A separate counter drives the loop so the trip count is always exactly
  // nParity, whatever value arg.parity holds.
  for (int p = 0; p < nParity; p++) {
    // for full fields then set parity from loop else use arg setting
    const int parity = (nParity == 2) ? p : arg.parity;

    for (int x_cb = 0; x_cb < arg.threads; x_cb++) { // 4-d volume
      wilson<Float, nDim, nColor, nParity, dagger, xpay, kernel_type>(arg, x_cb, 0, parity);
    } // 4-d volumeCB
  }   // parity
}
177 
/**
   @brief GPU kernel for applying the Wilson operator to a vector.

   Launch layout as read by this kernel: the x thread index enumerates
   checkerboard sites and threads with index >= arg.threads exit immediately;
   when nParity == 2 the z thread index selects the parity, otherwise
   arg.parity is used.  A parity other than 0 or 1 results in no work.

   @param[in] arg Parameter struct (passed by value to the kernel)
*/
template <typename Float, int nDim, int nColor, int nParity, bool dagger, bool xpay, KernelType kernel_type, typename Arg>
__global__ void wilsonGPU(Arg arg)
{
  const int site = blockIdx.x * blockDim.x + threadIdx.x;
  if (site >= arg.threads) return; // guard the grid tail

  // for full fields set parity from z thread index else use arg setting
  const int parity = (nParity == 2) ? (blockDim.z * blockIdx.z + threadIdx.z) : arg.parity;

  // dispatch with a literal parity so each call site sees a compile-time constant
  if (parity == 0) {
    wilson<Float, nDim, nColor, nParity, dagger, xpay, kernel_type>(arg, site, 0, 0);
  } else if (parity == 1) {
    wilson<Float, nDim, nColor, nParity, dagger, xpay, kernel_type>(arg, site, 0, 1);
  }
}
193 
194 } // namespace quda
KernelType kernel_type
QudaGaugeFieldOrder FieldOrder() const
Definition: gauge_field.h:257
static constexpr bool gauge_direct_load
#define errorQuda(...)
Definition: util_quda.h:121
static constexpr QudaGhostExchange ghost
static constexpr bool spinor_direct_load
static constexpr QudaReconstructType reconstruct
Parameter structure for driving the Wilson operator.
void wilsonCPU(Arg arg)
__device__ __host__ void wilson(Arg &arg, int idx, int s, int parity)
const int nColor
Definition: covdev_test.cpp:75
static constexpr bool spin_project
enum QudaGhostExchange_s QudaGhostExchange
mapper< Float >::type real
Main header file for host and device accessors to GaugeFields.
gauge_mapper< Float, reconstruct, 18, QUDA_STAGGERED_PHASE_NO, gauge_direct_load, ghost >::type G
enum QudaReconstructType_s QudaReconstructType
__shared__ float s[]
__global__ void wilsonGPU(Arg arg)
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
VectorXcd Vector
static constexpr int nSpin
colorspinor_mapper< Float, nSpin, nColor, spin_project, spinor_direct_load >::type F
__host__ __device__ ValueType conj(ValueType x)
Definition: complex_quda.h:130
bool isNative() const
QudaFieldOrder FieldOrder() const
__device__ __host__ void applyWilson(Vector &out, Arg &arg, int coord[nDim], int x_cb, int s, int parity, int idx, int thread_dim, bool &active)
Applies the off-diagonal part of the Wilson operator.
WilsonArg(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override)