16 template <
typename Float,
int nColor, QudaReconstructType reconstruct_>
struct WilsonArg :
DslashArg<Float> {
37 DslashArg<Float>(in, U, parity, dagger, a != 0.0 ? true : false, 1, spin_project, comm_override),
61 template <
typename Float,
int nDim,
int nColor,
int nParity,
bool dagger, KernelType kernel_type,
typename Arg,
typename Vector>
63 Vector &
out,
Arg &
arg,
int coord[nDim],
int x_cb,
int s,
int parity,
int idx,
int thread_dim,
bool &active)
68 const int their_spinor_parity =
nParity == 2 ? 1 -
parity : 0;
71 const int gauge_parity = (nDim == 5 ? (x_cb / arg.dc.volume_4d_cb +
parity) % 2 : parity);
74 for (
int d = 0; d < 4; d++) {
76 const int fwd_idx = getNeighborIndexCB<nDim>(coord, d, +1, arg.dc);
77 const int gauge_idx = (nDim == 5 ? x_cb % arg.dc.volume_4d_cb : x_cb);
78 constexpr
int proj_dir =
dagger ? +1 : -1;
81 = (coord[d] + arg.nFace >= arg.dim[d]) && isActive<kernel_type>(active, thread_dim, d, coord, arg);
83 if (doHalo<kernel_type>(d) &&
ghost) {
86 ghostFaceIndex<1, nDim>(coord, arg.dim, d, arg.nFace) :
89 Link
U = arg.U(d, gauge_idx, gauge_parity);
90 HalfVector
in = arg.in.Ghost(d, 1, ghost_idx + s * arg.dc.ghostFaceCB[d], their_spinor_parity);
91 if (d == 3) in *= arg.t_proj_scale;
94 }
else if (doBulk<kernel_type>() && !
ghost) {
96 Link
U = arg.U(d, gauge_idx, gauge_parity);
97 Vector in = arg.in(fwd_idx + s * arg.dc.volume_4d_cb, their_spinor_parity);
99 out += (U * in.project(d, proj_dir)).
reconstruct(d, proj_dir);
104 const int back_idx = getNeighborIndexCB<nDim>(coord, d, -1, arg.dc);
105 const int gauge_idx = (nDim == 5 ? back_idx % arg.dc.volume_4d_cb : back_idx);
106 constexpr
int proj_dir =
dagger ? -1 : +1;
108 const bool ghost = (coord[d] - arg.nFace < 0) && isActive<kernel_type>(active, thread_dim, d, coord, arg);
110 if (doHalo<kernel_type>(d) &&
ghost) {
113 ghostFaceIndex<0, nDim>(coord, arg.dim, d, arg.nFace) :
116 const int gauge_ghost_idx = (nDim == 5 ? ghost_idx % arg.dc.ghostFaceCB[d] : ghost_idx);
117 Link
U = arg.U.Ghost(d, gauge_ghost_idx, 1 - gauge_parity);
118 HalfVector
in = arg.in.Ghost(d, 0, ghost_idx + s * arg.dc.ghostFaceCB[d], their_spinor_parity);
119 if (d == 3) in *= arg.t_proj_scale;
122 }
else if (doBulk<kernel_type>() && !
ghost) {
124 Link
U = arg.U(d, gauge_idx, 1 - gauge_parity);
125 Vector in = arg.in(back_idx + s * arg.dc.volume_4d_cb, their_spinor_parity);
134 template <
typename Float,
int nDim,
int nColor,
int nParity,
bool dagger,
bool xpay, KernelType kernel_type,
typename Arg>
144 int x_cb = getCoords<nDim, QUDA_4D_PC, kernel_type>(coord,
arg, idx,
parity, thread_dim);
148 applyWilson<Float, nDim, nColor, nParity, dagger, kernel_type>(
151 int xs = x_cb + s * arg.dc.volume_4d_cb;
153 Vector
x = arg.x(xs, my_spinor_parity);
154 out = x + arg.a *
out;
156 Vector
x = arg.out(xs, my_spinor_parity);
164 template <
typename Float,
int nDim,
int nColor,
int nParity,
bool dagger,
bool xpay, KernelType kernel_type,
typename Arg>
172 for (
int x_cb = 0; x_cb < arg.threads; x_cb++) {
173 wilson<Float, nDim, nColor, nParity, dagger, xpay, kernel_type>(
arg, x_cb, 0,
parity);
179 template <
typename Float,
int nDim,
int nColor,
int nParity,
bool dagger,
bool xpay, KernelType kernel_type,
typename Arg>
182 int x_cb = blockIdx.x * blockDim.x + threadIdx.x;
183 if (x_cb >= arg.threads)
return;
186 int parity =
nParity == 2 ? blockDim.z * blockIdx.z + threadIdx.z : arg.parity;
189 case 0: wilson<Float, nDim, nColor, nParity, dagger, xpay, kernel_type>(
arg, x_cb, 0, 0);
break;
190 case 1: wilson<Float, nDim, nColor, nParity, dagger, xpay, kernel_type>(
arg, x_cb, 0, 1);
break;
QudaGaugeFieldOrder FieldOrder() const
static constexpr bool gauge_direct_load
static constexpr QudaGhostExchange ghost
static constexpr bool spinor_direct_load
static constexpr QudaReconstructType reconstruct
Parameter structure for driving the Wilson operator.
__device__ __host__ void wilson(Arg &arg, int idx, int s, int parity)
static constexpr bool spin_project
enum QudaGhostExchange_s QudaGhostExchange
mapper< Float >::type real
Main header file for host and device accessors to GaugeFields.
gauge_mapper< Float, reconstruct, 18, QUDA_STAGGERED_PHASE_NO, gauge_direct_load, ghost >::type G
enum QudaReconstructType_s QudaReconstructType
__global__ void wilsonGPU(Arg arg)
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
static constexpr int nSpin
colorspinor_mapper< Float, nSpin, nColor, spin_project, spinor_direct_load >::type F
__host__ __device__ ValueType conj(ValueType x)
QudaFieldOrder FieldOrder() const
__device__ __host__ void applyWilson(Vector &out, Arg &arg, int coord[nDim], int x_cb, int s, int parity, int idx, int thread_dim, bool &active)
Applies the off-diagonal part of the Wilson operator.
WilsonArg(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override)