QUDA  1.0.0
dslash_staggered.cuh
#pragma once

#include <dslash_helper.cuh>
#include <color_spinor_field_order.h>
#include <gauge_field_order.h>
#include <color_spinor.h>
#include <index_helper.cuh>

namespace quda
{

  /**
     @brief Parameter structure for driving the Staggered Dslash operator
  */
  template <typename Float, int nColor, QudaReconstructType reconstruct_u_, QudaReconstructType reconstruct_l_,
            bool improved_, QudaStaggeredPhase phase_ = QUDA_STAGGERED_PHASE_MILC>
  struct StaggeredArg : DslashArg<Float> {
    typedef typename mapper<Float>::type real;
    static constexpr int nSpin = 1;
    static constexpr bool spin_project = false;
    static constexpr bool spinor_direct_load = false; // false means texture load
    using F = typename colorspinor_mapper<Float, nSpin, nColor, spin_project, spinor_direct_load>::type;

    static constexpr QudaReconstructType reconstruct_u = reconstruct_u_;
    static constexpr QudaReconstructType reconstruct_l = reconstruct_l_;
    static constexpr bool gauge_direct_load = false; // false means texture load
    static constexpr QudaGhostExchange ghost = QUDA_GHOST_EXCHANGE_PAD;
    static constexpr bool use_inphase = improved_ ? false : true;
    static constexpr QudaStaggeredPhase phase = phase_;
    using GU = typename gauge_mapper<Float, reconstruct_u, 18, phase, gauge_direct_load, ghost, use_inphase>::type;
    using GL = typename gauge_mapper<Float, reconstruct_l, 18, QUDA_STAGGERED_PHASE_NO, gauge_direct_load, ghost,
                                     use_inphase>::type;

    F out;      // output vector field
    const F in; // input vector field
    const F x;  // input vector for the xpay accumulation
    const GU U; // the one-hop gauge field (fat links for the improved operator)
    const GL L; // the long-link gauge field (improved operator only)

    const real a;                   // xpay scale factor
    const real tboundary;           // temporal boundary condition
    const bool is_first_time_slice; // are we on the first (global) time slice?
    const bool is_last_time_slice;  // are we on the last (global) time slice?
    static constexpr bool improved = improved_;

    StaggeredArg(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, const GaugeField &L, double a,
                 const ColorSpinorField &x, int parity, bool dagger, const int *comm_override) :
      DslashArg<Float>(in, U, parity, dagger, a == 0.0 ? false : true, improved_ ? 3 : 1, spin_project, comm_override),
      out(out),
      in(in, improved_ ? 3 : 1),
      U(U),
      L(L),
      x(x),
      a(a),
      tboundary(U.TBoundary()),
      is_first_time_slice(comm_coord(3) == 0 ? true : false),
      is_last_time_slice(comm_coord(3) == comm_dim(3) - 1 ? true : false)
    {
      if (!out.isNative() || !x.isNative() || !in.isNative() || !U.isNative())
        errorQuda("Unsupported field order colorspinor=%d gauge=%d combination\n", in.FieldOrder(), U.FieldOrder());
    }
  };

  /**
     @brief Applies the off-diagonal part of the Staggered / Asqtad operator.
  */
  template <typename Float, int nDim, int nColor, int nParity, bool dagger, KernelType kernel_type, typename Arg, typename Vector>
  __device__ __host__ inline void applyStaggered(
      Vector &out, Arg &arg, int coord[nDim], int x_cb, int parity, int idx, int thread_dim, bool &active)
  {
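    // For each dimension we gather the one-hop neighbors through the gauge
    // field U in the forward and backward directions; for the improved
    // (asqtad) operator we additionally gather the three-hop neighbors
    // through the long-link field L.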
    typedef typename mapper<Float>::type real;
    typedef Matrix<complex<real>, nColor> Link;
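    // neighboring sites always live on the opposite parity; a field that
    // stores only a single parity is indexed with parity 0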
    const int their_spinor_parity = (arg.nParity == 2) ? 1 - parity : 0;

#pragma unroll 4
    for (int d = 0; d < 4; d++) { // loop over dimension

      // standard - forward direction
      {
        const bool ghost = (coord[d] + 1 >= arg.dim[d]) && isActive<kernel_type>(active, thread_dim, d, coord, arg);
        if (doHalo<kernel_type>(d) && ghost) {
          const int ghost_idx = ghostFaceIndexStaggered<1>(coord, arg.dim, d, 1);
          const Link U = arg.improved ? arg.U(d, x_cb, parity) : arg.U(d, x_cb, parity, StaggeredPhase(coord, d, +1, arg));
          Vector in = arg.in.Ghost(d, 1, ghost_idx, their_spinor_parity);
          out += (U * in);

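          // debug: print the first forward link on the first site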
          if (x_cb == 0 && parity == 0 && d == 0) printLink(U);
        } else if (doBulk<kernel_type>() && !ghost) {
          const int fwd_idx = linkIndexP1(coord, arg.dim, d);
          const Link U = arg.improved ? arg.U(d, x_cb, parity) : arg.U(d, x_cb, parity, StaggeredPhase(coord, d, +1, arg));
          Vector in = arg.in(fwd_idx, their_spinor_parity);
          out += (U * in);
        }
      }

      // improved - forward direction
      if (arg.improved) {
        const bool ghost = (coord[d] + 3 >= arg.dim[d]) && isActive<kernel_type>(active, thread_dim, d, coord, arg);
        if (doHalo<kernel_type>(d) && ghost) {
          const int ghost_idx = ghostFaceIndexStaggered<1>(coord, arg.dim, d, arg.nFace);
          const Link L = arg.L(d, x_cb, parity);
          const Vector in = arg.in.Ghost(d, 1, ghost_idx, their_spinor_parity);
          out += L * in;
        } else if (doBulk<kernel_type>() && !ghost) {
          const int fwd3_idx = linkIndexP3(coord, arg.dim, d);
          const Link L = arg.L(d, x_cb, parity);
          const Vector in = arg.in(fwd3_idx, their_spinor_parity);
          out += L * in;
        }
      }

      {
        // Backward gather - compute back offset for spinor and gauge fetch
        const bool ghost = (coord[d] - 1 < 0) && isActive<kernel_type>(active, thread_dim, d, coord, arg);

        if (doHalo<kernel_type>(d) && ghost) {
          const int ghost_idx2 = ghostFaceIndexStaggered<0>(coord, arg.dim, d, 1);
          const int ghost_idx = arg.improved ? ghostFaceIndexStaggered<0>(coord, arg.dim, d, 3) : ghost_idx2;
          const Link U = arg.improved ? arg.U.Ghost(d, ghost_idx2, 1 - parity) :
                                        arg.U.Ghost(d, ghost_idx2, 1 - parity, StaggeredPhase(coord, d, -1, arg));
          Vector in = arg.in.Ghost(d, 0, ghost_idx, their_spinor_parity);
          out -= (conj(U) * in);
        } else if (doBulk<kernel_type>() && !ghost) {
          const int back_idx = linkIndexM1(coord, arg.dim, d);
          const int gauge_idx = back_idx;
          const Link U = arg.improved ? arg.U(d, gauge_idx, 1 - parity) :
                                        arg.U(d, gauge_idx, 1 - parity, StaggeredPhase(coord, d, -1, arg));
          Vector in = arg.in(back_idx, their_spinor_parity);
          out -= (conj(U) * in);
        }
      }

141 
142  // improved - backward direction
143  if (arg.improved) {
144  const bool ghost = (coord[d] - 3 < 0) && isActive<kernel_type>(active, thread_dim, d, coord, arg);
145  if (doHalo<kernel_type>(d) && ghost) {
146  // when updating replace arg.nFace with 1 here
147  const int ghost_idx = ghostFaceIndexStaggered<0>(coord, arg.dim, d, 1);
148  const Link L = arg.L.Ghost(d, ghost_idx, 1 - parity);
149  const Vector in = arg.in.Ghost(d, 0, ghost_idx, their_spinor_parity);
150  out -= conj(L) * in;
151  } else if (doBulk<kernel_type>() && !ghost) {
152  const int back3_idx = linkIndexM3(coord, arg.dim, d);
153  const int gauge_idx = back3_idx;
154  const Link L = arg.L(d, gauge_idx, 1 - parity);
155  const Vector in = arg.in(back3_idx, their_spinor_parity);
156  out -= conj(L) * in;
157  }
158  }
159  } // nDim
160  }

  // out(x) = M*in = (m - D) in(x), where D couples x to x +/- mu (and to x +/- 3*mu for the improved operator)
  template <typename Float, int nDim, int nColor, int nParity, bool dagger, bool xpay, KernelType kernel_type, typename Arg>
  __device__ __host__ inline void staggered(Arg &arg, int idx, int parity)
  {
    using real = typename mapper<Float>::type;
    using Vector = ColorSpinor<real, nColor, 1>;

    bool active
      = kernel_type == EXTERIOR_KERNEL_ALL ? false : true; // is thread active (non-trivial for fused kernel only)
    int thread_dim;                                        // which dimension is thread working on (fused kernel only)
    int coord[nDim];
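    // the improved operator requires a halo of depth 3 for the long links,
    // the naive operator a halo of depth 1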
    int x_cb = arg.improved ? getCoords<nDim, QUDA_4D_PC, kernel_type, Arg, 3>(coord, arg, idx, parity, thread_dim) :
                              getCoords<nDim, QUDA_4D_PC, kernel_type, Arg, 1>(coord, arg, idx, parity, thread_dim);

    const int my_spinor_parity = nParity == 2 ? parity : 0;

    Vector out;

    applyStaggered<Float, nDim, nColor, nParity, dagger, kernel_type>(
        out, arg, coord, x_cb, parity, idx, thread_dim, active);

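    // the off-diagonal part of the staggered operator is anti-Hermitian, so
    // applying the dagger amounts to a sign flip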
    if (dagger) { out = -out; }

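    // xpay: the interior kernel computes out = a*x - D*in; the exterior
    // kernels accumulate their halo contributions onto the partial result
    // already stored in out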
    if (xpay && kernel_type == INTERIOR_KERNEL) {
      Vector x = arg.x(x_cb, my_spinor_parity);
      out = arg.a * x - out;
    } else if (kernel_type != INTERIOR_KERNEL) {
      Vector x = arg.out(x_cb, my_spinor_parity);
      out = x + (xpay ? -out : out);
    }
    if (kernel_type != EXTERIOR_KERNEL_ALL || active) arg.out(x_cb, my_spinor_parity) = out;
  }

  // GPU Kernel for applying the staggered operator to a vector
  template <typename Float, int nDim, int nColor, int nParity, bool dagger, bool xpay, KernelType kernel_type, typename Arg>
  __global__ void staggeredGPU(Arg arg)
  {
    int x_cb = blockIdx.x * blockDim.x + threadIdx.x;
    if (x_cb >= arg.threads) return;

    // for full fields set parity from z thread index else use arg setting
    int parity = nParity == 2 ? blockDim.z * blockIdx.z + threadIdx.z : arg.parity;

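    // dispatch with parity as a literal so the compiler can fold it through
    // the downstream indexing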
    switch (parity) {
    case 0: staggered<Float, nDim, nColor, nParity, dagger, xpay, kernel_type>(arg, x_cb, 0); break;
    case 1: staggered<Float, nDim, nColor, nParity, dagger, xpay, kernel_type>(arg, x_cb, 1); break;
    }
  }
} // namespace quda
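staggeredGPU maps the checkerboarded site index to the x grid dimension and, for full (two-parity) fields, the parity to the z grid dimension. Below is a minimal, self-contained sketch of that launch pattern; the ToyArg struct and toyKernel are hypothetical stand-ins for StaggeredArg and the real kernel, and in QUDA itself such kernels are launched through the library's autotuning machinery rather than a bare triple-chevron call.

#include <cstdio>
#include <cuda_runtime.h>

struct ToyArg {
  int threads; // number of checkerboard sites per parity
  int parity;  // fixed parity used when nParity == 1
};

template <int nParity> __global__ void toyKernel(ToyArg arg)
{
  // same indexing as staggeredGPU: checkerboard site from the x dimension...
  int x_cb = blockIdx.x * blockDim.x + threadIdx.x;
  if (x_cb >= arg.threads) return;

  // ...and parity from the z dimension when both parities are processed
  int parity = nParity == 2 ? blockDim.z * blockIdx.z + threadIdx.z : arg.parity;

  if (x_cb == 0) printf("block handles parity %d\n", parity);
}

int main()
{
  ToyArg arg{1024, 0};
  dim3 block(128, 1, 1);
  // the z extent of the grid spans the two parities of a full field
  dim3 grid((arg.threads + block.x - 1) / block.x, 1, 2);
  toyKernel<2><<<grid, block>>>(arg);
  cudaDeviceSynchronize();
  return 0;
}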