  DslashArg<Float>(in, U, parity, dagger, a == 0.0 ? false : true, improved_ ? 3 : 1, spin_project, comm_override),
  in(in, improved_ ? 3 : 1),
  tboundary(U.TBoundary()),
  is_first_time_slice(comm_coord(3) == 0 ? true : false),
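
/**
   @brief Applies the off-diagonal (hopping) part of the staggered / asqtad operator.
*/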
template <typename Float, int nDim, int nColor, int nParity, bool dagger, KernelType kernel_type, typename Arg, typename Vector>
  const int their_spinor_parity = (arg.nParity == 2) ? 1 - parity : 0;
  for (int d = 0; d < 4; d++) {
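    // standard one-hop term, forward direction: gather the neighbor spinor from
    // the forward ghost face when it lives on a remote rank, otherwise from the bulk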
    const bool ghost = (coord[d] + 1 >= arg.dim[d]) && isActive<kernel_type>(active, thread_dim, d, coord, arg);
    if (doHalo<kernel_type>(d) && ghost) {
      const int ghost_idx = ghostFaceIndexStaggered<1>(coord, arg.dim, d, 1);
      const Link U = arg.improved ? arg.U(d, x_cb, parity) : arg.U(d, x_cb, parity, StaggeredPhase(coord, d, +1, arg));
      Vector in = arg.in.Ghost(d, 1, ghost_idx, their_spinor_parity);

      if (x_cb == 0 && parity == 0 && d == 0) printLink(U);
    } else if (doBulk<kernel_type>() && !ghost) {
      const int fwd_idx = linkIndexP1(coord, arg.dim, d);
      const Link U = arg.improved ? arg.U(d, x_cb, parity) : arg.U(d, x_cb, parity, StaggeredPhase(coord, d, +1, arg));
      Vector in = arg.in(fwd_idx, their_spinor_parity);
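
    // improved (asqtad) three-hop term, forward direction: uses the long-link field L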
    const bool ghost = (coord[d] + 3 >= arg.dim[d]) && isActive<kernel_type>(active, thread_dim, d, coord, arg);
    if (doHalo<kernel_type>(d) && ghost) {
      const int ghost_idx = ghostFaceIndexStaggered<1>(coord, arg.dim, d, arg.nFace);
      const Link L = arg.L(d, x_cb, parity);
      const Vector in = arg.in.Ghost(d, 1, ghost_idx, their_spinor_parity);

    } else if (doBulk<kernel_type>() && !ghost) {
      const int fwd3_idx = linkIndexP3(coord, arg.dim, d);
      const Link L = arg.L(d, x_cb, parity);
      const Vector in = arg.in(fwd3_idx, their_spinor_parity);
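
    // standard one-hop term, backward direction: the link comes from the backward
    // neighbor (opposite parity), via the gauge ghost buffer when it is off-rank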
    const bool ghost = (coord[d] - 1 < 0) && isActive<kernel_type>(active, thread_dim, d, coord, arg);
    if (doHalo<kernel_type>(d) && ghost) {
      const int ghost_idx2 = ghostFaceIndexStaggered<0>(coord, arg.dim, d, 1);
      const int ghost_idx = arg.improved ? ghostFaceIndexStaggered<0>(coord, arg.dim, d, 3) : ghost_idx2;
      const int back_idx = linkIndexM1(coord, arg.dim, d);
      const Link U = arg.improved ? arg.U.Ghost(d, ghost_idx2, 1 - parity) :
        arg.U.Ghost(d, ghost_idx2, 1 - parity, StaggeredPhase(coord, d, -1, arg));
      Vector in = arg.in.Ghost(d, 0, ghost_idx, their_spinor_parity);

    } else if (doBulk<kernel_type>() && !ghost) {
      const int back_idx = linkIndexM1(coord, arg.dim, d);
      const int gauge_idx = back_idx;
      const Link U = arg.improved ? arg.U(d, gauge_idx, 1 - parity) :
        arg.U(d, gauge_idx, 1 - parity, StaggeredPhase(coord, d, -1, arg));
      Vector in = arg.in(back_idx, their_spinor_parity);
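
    // improved (asqtad) three-hop term, backward direction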
    const bool ghost = (coord[d] - 3 < 0) && isActive<kernel_type>(active, thread_dim, d, coord, arg);
    if (doHalo<kernel_type>(d) && ghost) {
      const int ghost_idx = ghostFaceIndexStaggered<0>(coord, arg.dim, d, 1);
      const Link L = arg.L.Ghost(d, ghost_idx, 1 - parity);
      const Vector in = arg.in.Ghost(d, 0, ghost_idx, their_spinor_parity);

    } else if (doBulk<kernel_type>() && !ghost) {
      const int back3_idx = linkIndexM3(coord, arg.dim, d);
      const int gauge_idx = back3_idx;
      const Link L = arg.L(d, gauge_idx, 1 - parity);
      const Vector in = arg.in(back3_idx, their_spinor_parity);
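
/**
   @brief Per-site driver: compute the lattice coordinates of this thread's site,
   apply the hopping terms via applyStaggered, then perform the optional xpay update.
*/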
template <typename Float, int nDim, int nColor, int nParity, bool dagger, bool xpay, KernelType kernel_type, typename Arg>
  int x_cb = arg.improved ? getCoords<nDim, QUDA_4D_PC, kernel_type, Arg, 3>(coord, arg, idx, parity, thread_dim) :
                            getCoords<nDim, QUDA_4D_PC, kernel_type, Arg, 1>(coord, arg, idx, parity, thread_dim);
  applyStaggered<Float, nDim, nColor, nParity, dagger, kernel_type>(out, arg, coord, x_cb, parity, idx, thread_dim, active);
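
  // xpay: at this point out holds the hopping-term result, so the update
  // below computes out = a * x - (hopping term)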
  Vector x = arg.x(x_cb, my_spinor_parity);
  out = arg.a * x - out;

  Vector x = arg.out(x_cb, my_spinor_parity);
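
// GPU kernel entry point: map the CUDA thread and block indices onto a
// checkerboard site index (x_cb) and a parity, then dispatch to staggered()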
template <typename Float, int nDim, int nColor, int nParity, bool dagger, bool xpay, KernelType kernel_type, typename Arg>
  int x_cb = blockIdx.x * blockDim.x + threadIdx.x;
  if (x_cb >= arg.threads) return;

  int parity = nParity == 2 ? blockDim.z * blockIdx.z + threadIdx.z : arg.parity;
  case 0: staggered<Float, nDim, nColor, nParity, dagger, xpay, kernel_type>(arg, x_cb, 0); break;
  case 1: staggered<Float, nDim, nColor, nParity, dagger, xpay, kernel_type>(arg, x_cb, 1); break;