QUDA  1.0.0
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
laplace.cuh
Go to the documentation of this file.
1 #pragma once
2 
3 #include <dslash_helper.cuh>
5 #include <gauge_field_order.h>
6 #include <color_spinor.h>
7 #include <dslash_helper.cuh>
8 #include <index_helper.cuh>
9 
10 namespace quda
11 {
12 
  /**
     @brief Parameter structure for driving the Laplace operator
     @tparam Float Floating-point precision of the fields
     @tparam nColor Number of colors
     @tparam reconstruct_ Gauge-link reconstruction type
  */
  template <typename Float, int nColor, QudaReconstructType reconstruct_> struct LaplaceArg : DslashArg<Float> {
    static constexpr int nSpin = 1;                   // the Laplace operator acts on spin-less (staggered-like) fields
    static constexpr bool spin_project = false;       // no spin projection possible for a one-spin field
    static constexpr bool spinor_direct_load = false; // false means texture load

    static constexpr QudaReconstructType reconstruct = reconstruct_;
    static constexpr bool gauge_direct_load = false; // false means texture load

    // NOTE(review): the accessor typedefs F (colorspinor_mapper type) and G (gauge_mapper type)
    // and the ghost-exchange constant were collapsed out of this extraction -- confirm against
    // the original header before editing.

    typedef typename mapper<Float>::type real; // register ("mapped") precision corresponding to Float

    F out;        // output vector field
    const F in;   // input vector field
    const F x;    // auxiliary input field accumulated in the xpay update
    const G U;    // the gauge field
    const real a; // xpay scale factor
    int dir;      // operator selector: 4 = full 4-d Laplace, 3 = spatial-only (t omitted)

    /**
       @brief Constructor for the Laplace operator arguments
       @param[out] out Output result field
       @param[in] in Input field the operator is applied to
       @param[in] U Gauge field
       @param[in] dir Operator direction selector (must be 3 or 4; see dir member)
       @param[in] a Scale factor for the xpay update (xpay enabled iff a != 0)
       @param[in] x Auxiliary field used when xpay is enabled
       @param[in] parity Parity being acted upon
       @param[in] dagger Whether this is the Hermitian-conjugate application
       @param[in] comm_override Override for which dimensions are communicated
    */
    LaplaceArg(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, int dir, double a,
               const ColorSpinorField &x, int parity, bool dagger, const int *comm_override) :

      // base-class xpay flag is derived from a != 0; nFace is fixed at 1 for this stencil
      DslashArg<Float>(in, U, parity, dagger, a != 0.0 ? true : false, 1, false, comm_override),
      out(out),
      in(in),
      U(U),
      dir(dir),
      x(x),
      a(a)
    {
      // the accessor types used here only support native field orders
      if (!out.isNative() || !x.isNative() || !in.isNative() || !U.isNative())
        errorQuda("Unsupported field order colorspinor(in)=%d gauge=%d combination\n", in.FieldOrder(), U.FieldOrder());
      // dir = 3 selects the spatial (3-d) operator, dir = 4 the full 4-d operator
      if (dir < 3 || dir > 4) errorQuda("Unsupported laplace direction %d (must be 3 or 4)", dir);
    }
  };
52 
  /**
     @brief Accumulate the gauge-covariant Laplace stencil into out:

       out += sum_{d != dir} [ U_d(x) in(x+d) + U_d^dag(x-d) in(x-d) ]

     i.e. the forward and backward hopping terms in every dimension EXCEPT the one
     selected by the dir template parameter (dir = -1 excludes nothing, giving the
     full nDim-dimensional stencil; dir = 3 omits the t dimension).

     Neighbours that fall off the local lattice are fetched from the ghost buffers
     when running a halo (exterior) kernel, and skipped by the bulk (interior) kernel.

     @tparam dir Dimension excluded from the stencil (-1 to exclude none)
     @param[in,out] out Accumulated result vector
     @param[in] arg Parameter struct (LaplaceArg)
     @param[in] coord Full lattice coordinates of this site
     @param[in] x_cb Checkerboarded site index
     @param[in] parity Parity of the site being updated
     @param[in] idx Thread index (unused here; kept for interface uniformity with other stencils)
     @param[in] thread_dim Which dimension this thread handles (fused exterior kernel only)
     @param[in,out] active Whether this thread remains active (fused exterior kernel only)
  */
  template <typename Float, int nDim, int nColor, int nParity, bool dagger, KernelType kernel_type, int dir,
            typename Arg, typename Vector>
  __device__ __host__ inline void applyLaplace(Vector &out, Arg &arg, int coord[nDim], int x_cb, int parity, int idx,
                                               int thread_dim, bool &active)
  {

    typedef typename mapper<Float>::type real;
    typedef Matrix<complex<real>, nColor> Link;
    // neighbouring sites live on the opposite parity for a full (two-parity) field
    const int their_spinor_parity = (arg.nParity == 2) ? 1 - parity : 0;

#pragma unroll
    for (int d = 0; d < nDim; d++) { // loop over dimension
      if (d != dir) {                // skip the excluded dimension (none when dir == -1)
        {
          // Forward gather - compute fwd offset for vector fetch
          const bool ghost = (coord[d] + 1 >= arg.dim[d]) && isActive<kernel_type>(active, thread_dim, d, coord, arg);

          if (doHalo<kernel_type>(d) && ghost) {

            // forward neighbour is off-node: read it from the forward ghost buffer
            // const int ghost_idx = ghostFaceIndexStaggered<1>(coord, arg.dim, d, 1);
            const int ghost_idx = ghostFaceIndex<1>(coord, arg.dim, d, arg.nFace);
            const Link U = arg.U(d, x_cb, parity);
            const Vector in = arg.in.Ghost(d, 1, ghost_idx, their_spinor_parity);

            out += U * in;
          } else if (doBulk<kernel_type>() && !ghost) {

            // forward neighbour is local: hop with the link at this site
            const int fwd_idx = linkIndexP1(coord, arg.dim, d);
            const Link U = arg.U(d, x_cb, parity);
            const Vector in = arg.in(fwd_idx, their_spinor_parity);

            out += U * in;
          }
        }
        {
          // Backward gather - compute back offset for spinor and gauge fetch

          const int back_idx = linkIndexM1(coord, arg.dim, d);
          const int gauge_idx = back_idx; // the backward hop uses the link that lives on the neighbour site

          const bool ghost = (coord[d] - 1 < 0) && isActive<kernel_type>(active, thread_dim, d, coord, arg);

          if (doHalo<kernel_type>(d) && ghost) {

            // backward neighbour is off-node: link and spinor both come from the ghost buffers
            // const int ghost_idx = ghostFaceIndexStaggered<0>(coord, arg.dim, d, 1);
            const int ghost_idx = ghostFaceIndex<0>(coord, arg.dim, d, arg.nFace);

            const Link U = arg.U.Ghost(d, ghost_idx, 1 - parity);
            const Vector in = arg.in.Ghost(d, 0, ghost_idx, their_spinor_parity);

            out += conj(U) * in; // dagger of the link for the backward direction
          } else if (doBulk<kernel_type>() && !ghost) {

            const Link U = arg.U(d, gauge_idx, 1 - parity);
            const Vector in = arg.in(back_idx, their_spinor_parity);

            out += conj(U) * in;
          }
        }
      }
    }
  }
129 
130  // out(x) = M*in
  // out(x) = M*in
  /**
     @brief Driver for a single site of the Laplace operator: computes the stencil,
     optionally applies the xpay update out = x + a*stencil, and for exterior
     kernels accumulates onto the partial result already stored in arg.out.

     @param[in,out] arg Parameter struct (LaplaceArg)
     @param[in] idx Thread/site index being processed
     @param[in] parity Parity of the site being updated
  */
  template <typename Float, int nDim, int nColor, int nParity, bool dagger, bool xpay, KernelType kernel_type, typename Arg>
  __device__ __host__ inline void laplace(Arg &arg, int idx, int parity)
  {

    using real = typename mapper<Float>::type;
    // NOTE(review): the Vector typedef (the ColorSpinor type used below) was collapsed
    // out of this extraction -- confirm against the original header.

    // is thread active (non-trival for fused kernel only)
    bool active = kernel_type == EXTERIOR_KERNEL_ALL ? false : true;

    // which dimension is thread working on (fused kernel only)
    int thread_dim;

    int coord[nDim];
    int x_cb = getCoords<nDim, QUDA_4D_PC, kernel_type, Arg>(coord, arg, idx, parity, thread_dim);

    const int my_spinor_parity = nParity == 2 ? parity : 0;
    Vector out;

    //We instantiate two kernel types:
    //case 4 is an operator in all x,y,z,t dimensions
    //case 3 is a spatial operator only, the t dimension is omitted.
    switch (arg.dir) {
    case 3:
      // template dir = 3 excludes the t dimension from the stencil
      applyLaplace<Float, nDim, nColor, nParity, dagger, kernel_type, 3>(out, arg, coord, x_cb, parity, idx, thread_dim,
                                                                         active);
      break;
    case 4:
    default:
      // template dir = -1 excludes no dimension, i.e. the full 4-d stencil
      applyLaplace<Float, nDim, nColor, nParity, dagger, kernel_type, -1>(out, arg, coord, x_cb, parity, idx,
                                                                          thread_dim, active);
      break;
    }

    if (xpay && kernel_type == INTERIOR_KERNEL) {
      // interior xpay: out = x + a * stencil
      Vector x = arg.x(x_cb, my_spinor_parity);
      out = x + arg.a * out;
    } else if (kernel_type != INTERIOR_KERNEL) {
      // exterior kernel: accumulate onto the partial result the interior kernel stored
      Vector x = arg.out(x_cb, my_spinor_parity);
      out = x + (xpay ? arg.a * out : out);
    }

    // exterior fused kernel only writes back from threads that touched a halo face
    if (kernel_type != EXTERIOR_KERNEL_ALL || active) arg.out(x_cb, my_spinor_parity) = out;
  }
175 
176  // GPU Kernel for applying the covariant derivative operator to a vector
177  template <typename Float, int nDim, int nColor, int nParity, bool dagger, bool xpay, KernelType kernel_type, typename Arg>
178  __global__ void laplaceGPU(Arg arg)
179  {
180 
181  int x_cb = blockIdx.x * blockDim.x + threadIdx.x;
182  if (x_cb >= arg.threads) return;
183 
184  // for full fields set parity from z thread index else use arg setting
185  int parity = nParity == 2 ? blockDim.z * blockIdx.z + threadIdx.z : arg.parity;
186 
187  switch (parity) {
188  case 0: laplace<Float, nDim, nColor, nParity, dagger, xpay, kernel_type>(arg, x_cb, 0); break;
189  case 1: laplace<Float, nDim, nColor, nParity, dagger, xpay, kernel_type>(arg, x_cb, 1); break;
190  }
191  }
192 } // namespace quda
LaplaceArg(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, int dir, double a, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override)
Definition: laplace.cuh:36
KernelType kernel_type
QudaGaugeFieldOrder FieldOrder() const
Definition: gauge_field.h:257
#define errorQuda(...)
Definition: util_quda.h:121
gauge_mapper< Float, reconstruct, 18, QUDA_STAGGERED_PHASE_NO, gauge_direct_load, ghost >::type G
Definition: laplace.cuh:25
mapper< Float >::type real
Definition: laplace.cuh:27
Parameter structure for driving the Laplace operator.
Definition: laplace.cuh:16
static __device__ __host__ int linkIndexM1(const int x[], const I X[4], const int mu)
colorspinor_mapper< Float, nSpin, nColor, spin_project, spinor_direct_load >::type F
Definition: laplace.cuh:20
const int nColor
Definition: covdev_test.cpp:75
enum QudaGhostExchange_s QudaGhostExchange
Main header file for host and device accessors to GaugeFields.
static constexpr int nSpin
Definition: laplace.cuh:17
static constexpr QudaGhostExchange ghost
Definition: laplace.cuh:24
static constexpr bool spin_project
Definition: laplace.cuh:18
const int nParity
Definition: spinor_noise.cu:25
enum QudaReconstructType_s QudaReconstructType
__device__ __host__ void applyLaplace(Vector &out, Arg &arg, int coord[nDim], int x_cb, int parity, int idx, int thread_dim, bool &active)
Definition: laplace.cuh:69
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
VectorXcd Vector
static constexpr bool spinor_direct_load
Definition: laplace.cuh:19
__device__ __host__ void laplace(Arg &arg, int idx, int parity)
Definition: laplace.cuh:132
__global__ void laplaceGPU(Arg arg)
Definition: laplace.cuh:178
__host__ __device__ ValueType conj(ValueType x)
Definition: complex_quda.h:130
static __device__ __host__ int linkIndexP1(const int x[], const I X[4], const int mu)
bool isNative() const
QudaFieldOrder FieldOrder() const
const real a
Definition: laplace.cuh:33
static constexpr QudaReconstructType reconstruct
Definition: laplace.cuh:22
static constexpr bool gauge_direct_load
Definition: laplace.cuh:23