16 template <
typename Float,
int nColor, QudaReconstructType reconstruct_>
struct CovDevArg :
DslashArg<Float> {
35 const int *comm_override) :
37 DslashArg<Float>(in, U, parity, dagger, false, 1, spin_project, comm_override),
64 int thread_dim,
bool &active)
69 const int their_spinor_parity = (arg.
nParity == 2) ? 1 - parity : 0;
75 const int fwd_idx = getNeighborIndexCB<nDim>(coord, d, +1, arg.dc);
76 const bool ghost = (coord[d] + 1 >= arg.dim[d]) && isActive<kernel_type>(active, thread_dim, d, coord, arg);
78 const Link
U = arg.U(d, x_cb, parity);
80 if (doHalo<kernel_type>(d) && ghost) {
82 const int ghost_idx = ghostFaceIndex<1>(coord, arg.dim, d, arg.nFace);
83 const Vector in = arg.in.Ghost(d, 1, ghost_idx, their_spinor_parity);
87 }
else if (doBulk<kernel_type>() && !ghost) {
89 const Vector in = arg.in(fwd_idx, their_spinor_parity);
95 const int back_idx = getNeighborIndexCB<nDim>(coord, d, -1, arg.dc);
96 const int gauge_idx = back_idx;
98 const bool ghost = (coord[d] - 1 < 0) && isActive<kernel_type>(active, thread_dim, d, coord, arg);
100 if (doHalo<kernel_type>(d) &&
ghost) {
102 const int ghost_idx = ghostFaceIndex<0>(coord, arg.dim, d, arg.nFace);
103 const Link
U = arg.U.Ghost(d, ghost_idx, 1 - parity);
104 const Vector in = arg.in.Ghost(d, 0, ghost_idx, their_spinor_parity);
107 }
else if (doBulk<kernel_type>() && !ghost) {
109 const Link
U = arg.U(d, gauge_idx, 1 - parity);
110 const Vector in = arg.in(back_idx, their_spinor_parity);
118 template <
typename Float,
int nDim,
int nColor,
int nParity,
bool dagger, KernelType kernel_type,
typename Arg>
132 int x_cb = getCoords<nDim, QUDA_4D_PC, kernel_type, Arg>(coord,
arg, idx,
parity, thread_dim);
134 const int my_spinor_parity = nParity == 2 ?
parity : 0;
139 applyCovDev<Float, nDim, nColor, nParity, dagger, kernel_type, 0>(
out,
arg, coord, x_cb,
parity, idx, thread_dim,
143 applyCovDev<Float, nDim, nColor, nParity, dagger, kernel_type, 1>(
out,
arg, coord, x_cb,
parity, idx, thread_dim,
147 applyCovDev<Float, nDim, nColor, nParity, dagger, kernel_type, 2>(
out,
arg, coord, x_cb,
parity, idx, thread_dim,
151 applyCovDev<Float, nDim, nColor, nParity, dagger, kernel_type, 3>(
out,
arg, coord, x_cb,
parity, idx, thread_dim,
155 applyCovDev<Float, nDim, nColor, nParity, dagger, kernel_type, 4>(
out,
arg, coord, x_cb,
parity, idx, thread_dim,
159 applyCovDev<Float, nDim, nColor, nParity, dagger, kernel_type, 5>(
out,
arg, coord, x_cb,
parity, idx, thread_dim,
163 applyCovDev<Float, nDim, nColor, nParity, dagger, kernel_type, 6>(
out,
arg, coord, x_cb,
parity, idx, thread_dim,
167 applyCovDev<Float, nDim, nColor, nParity, dagger, kernel_type, 7>(
out,
arg, coord, x_cb,
parity, idx, thread_dim,
173 Vector x = arg.out(x_cb, my_spinor_parity);
181 template <
typename Float,
int nDim,
int nColor,
int nParity,
bool dagger,
bool xpay, KernelType kernel_type,
typename Arg>
184 int x_cb = blockIdx.x * blockDim.x + threadIdx.x;
185 if (x_cb >= arg.threads)
return;
188 int parity = nParity == 2 ? blockDim.z * blockIdx.z + threadIdx.z : arg.parity;
191 case 0: covDev<Float, nDim, nColor, nParity, dagger, kernel_type>(
arg, x_cb, 0);
break;
192 case 1: covDev<Float, nDim, nColor, nParity, dagger, kernel_type>(
arg, x_cb, 1);
break;
QudaGaugeFieldOrder FieldOrder() const
__global__ void covDevGPU(Arg arg)
colorspinor_mapper< Float, nSpin, nColor, spin_project, spinor_direct_load >::type F
static constexpr bool spin_project
static constexpr bool spinor_direct_load
mapper< Float >::type real
__device__ __host__ void covDev(Arg &arg, int idx, int parity)
__device__ __host__ void applyCovDev(Vector &out, Arg &arg, int coord[nDim], int x_cb, int parity, int idx, int thread_dim, bool &active)
enum QudaGhostExchange_s QudaGhostExchange
Parameter structure for driving the covariant derivative operator.
Main header file for host and device accessors to GaugeFields.
CovDevArg(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, int mu, int parity, bool dagger, const int *comm_override)
enum QudaReconstructType_s QudaReconstructType
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
gauge_mapper< Float, reconstruct, 18, QUDA_STAGGERED_PHASE_NO, gauge_direct_load, ghost >::type G
static constexpr int nSpin
static constexpr QudaReconstructType reconstruct
__host__ __device__ ValueType conj(ValueType x)
static constexpr QudaGhostExchange ghost
QudaFieldOrder FieldOrder() const
static constexpr bool gauge_direct_load