QUDA  1.0.0
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
dslash_twisted_mass_preconditioned.cuh
Go to the documentation of this file.
1 #pragma once
2 
4 
5 namespace quda
6 {
7 
8  template <typename Float, int nColor, QudaReconstructType reconstruct_>
   // Kernel argument for the preconditioned twisted-mass dslash: extends WilsonArg with the
   // twist coefficients read by applyWilsonTM and twistedMass in this header.
9  struct TwistedMassArg : WilsonArg<Float, nColor, reconstruct_> {
10  typedef typename mapper<Float>::type real;
   // Overall scale of the twist rotation; kernels apply it as a * (v + b * v.igamma(4)).
11  real a;
   // Twist coefficient; stored with its sign flipped when dagger is set (see initializer below).
12  real b;
   // Flavour-mixing coefficient, used only by the twist == 2 (doublet) path of applyWilsonTM;
   // always initialized to 0 by this constructor.
13  real c;
   // 1 / (a * (1 + b*b)), computed from the constructor arguments; used by twistedMass for the
   // xpay term. Requires a != 0.
14  real a_inv;
   // Negation of the stored b (i.e. -b for !dagger, +b for dagger).
15  real b_inv;
   // Runtime flag selecting the asymmetric-preconditioning kernel variants.
16  bool asymmetric;
   // Note: inside the mem-initializers below, the names a and b refer to the constructor
   // parameters (they shadow the members), so a_inv and b_inv use the unflipped input b.
   // xpay is forwarded to WilsonArg as a 1.0 / 0.0 coefficient (presumably the xpay scale; confirm in WilsonArg).
18  TwistedMassArg(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a, double b,
19  bool xpay, const ColorSpinorField &x, int parity, bool dagger, bool asymmetric, const int *comm_override) :
20  WilsonArg<Float, nColor, reconstruct_>(out, in, U, xpay ? 1.0 : 0.0, x, parity, dagger, comm_override),
21  a(a),
22  b(dagger ? -b : b), // if dagger flip the twist
23  c(0.0),
24  a_inv(1.0 / (a * (1 + b * b))),
25  b_inv(dagger ? b : -b),
26  asymmetric(asymmetric)
27  {
28  // set parameters for twisting in the packing kernel
29  if (dagger && !asymmetric) {
   // NOTE(review): the body of this branch (original source lines 30-31) is missing from this
   // extracted view, so as shown the branch does nothing. Presumably it forwards this->a / this->b
   // to the packing kernel's twist parameters — confirm against the real header.
32  }
33  }
34  };
35 
/**
 * @brief Load the neighbouring-parity spinor at checkerboard site nbr_idx and apply the
 * twist rotation to it (bulk, i.e. non-ghost, hopping terms of applyWilsonTM).
 *
 * twist == 1 (singlet): returns a * (in + b * in.igamma(4)), where in is read at
 *                       nbr_idx + s * volume_4d_cb.
 * twist == 2 (doublet): reads both flavour components (offsets 0 and volume_4d_cb) and returns
 *                       s == 0: a * (in0 + b * in0.igamma(4) + c * in1)
 *                       s != 0: a * (in1 - b * in1.igamma(4) + c * in0)
 *
 * @tparam twist          1 (singlet) or 2 (doublet)
 * @tparam Vector         four-spinor type returned
 * @param arg             kernel argument (reads arg.in, arg.a, arg.b, arg.c, arg.dc.volume_4d_cb)
 * @param nbr_idx         checkerboard index of the neighbouring site
 * @param s               fifth-dimension / flavour index handled by this thread
 * @param spinor_parity   parity index used when reading arg.in
 */
template <int twist, typename Vector, typename Arg>
__device__ __host__ inline Vector loadTwistedNeighbor(Arg &arg, int nbr_idx, int s, int spinor_parity)
{
  static_assert(twist == 1 || twist == 2, "twist must equal 1 or 2");
  Vector in;
  if (twist == 1) {
    in = arg.in(nbr_idx + s * arg.dc.volume_4d_cb, spinor_parity);
    in = arg.a * (in + arg.b * in.igamma(4)); // apply A^{-1} to in
  } else { // twisted doublet
    Vector in0 = arg.in(nbr_idx + 0 * arg.dc.volume_4d_cb, spinor_parity);
    Vector in1 = arg.in(nbr_idx + 1 * arg.dc.volume_4d_cb, spinor_parity);
    if (s == 0)
      in = arg.a * (in0 + arg.b * in0.igamma(4) + arg.c * in1);
    else
      in = arg.a * (in1 - arg.b * in1.igamma(4) + arg.c * in0);
  }
  return in;
}

/**
 * @brief Accumulate into out the nearest-neighbour (off-diagonal) Wilson hopping terms, with
 * the bulk neighbour spinors pre-multiplied by the twist rotation (see loadTwistedNeighbor).
 *
 * Used by twistedMass for the symmetric dagger operator. Ghost (halo) neighbours are used
 * as received from arg.in.Ghost without applying the twist here; presumably the twist was
 * applied when the halo was packed (TwistedMassArg configures the packing-kernel twist) —
 * confirm against the packing code.
 *
 * @tparam dagger      selects the sign of the spin projector in each direction
 * @tparam twist       1 (singlet) or 2 (doublet)
 * @param out          accumulator; contributions are added, it is not cleared here
 * @param arg          kernel argument
 * @param coord        site coordinates from getCoords
 * @param x_cb         checkerboard index of this site
 * @param s            fifth-dimension / flavour index
 * @param parity       parity of this site
 * @param idx          thread index (used as ghost face index for the thread's own dimension)
 * @param thread_dim   dimension handled by this thread (fused exterior kernel only)
 * @param active       in/out activity flag updated through isActive (fused exterior kernel)
 */
template <typename Float, int nDim, int nColor, int nParity, bool dagger, int twist, KernelType kernel_type,
          typename Arg, typename Vector>
__device__ __host__ inline void applyWilsonTM(Vector &out, Arg &arg, int coord[nDim], int x_cb, int s, int parity,
                                              int idx, int thread_dim, bool &active)
{
  static_assert(twist == 1 || twist == 2, "twist template must equal 1 or 2"); // ensure singlet or doublet
  typedef typename mapper<Float>::type real;
  typedef ColorSpinor<real, nColor, 2> HalfVector;
  typedef Matrix<complex<real>, nColor> Link;
  const int their_spinor_parity = nParity == 2 ? 1 - parity : 0;

#pragma unroll nDim
  for (int d = 0; d < nDim; d++) { // loop over dimension
    { // Forward gather - compute fwd offset for vector fetch
      const int fwd_idx = getNeighborIndexCB(coord, d, +1, arg.dc);
      constexpr int proj_dir = dagger ? +1 : -1;
      const bool ghost
        = (coord[d] + arg.nFace >= arg.dim[d]) && isActive<kernel_type>(active, thread_dim, d, coord, arg);

      if (doHalo<kernel_type>(d) && ghost) {
        // we need to compute the face index if we are updating a face that isn't ours
        const int ghost_idx = (kernel_type == EXTERIOR_KERNEL_ALL && d != thread_dim) ?
          ghostFaceIndex<1, nDim>(coord, arg.dim, d, arg.nFace) :
          idx;

        Link U = arg.U(d, x_cb, parity);
        HalfVector in = arg.in.Ghost(d, 1, ghost_idx + s * arg.dc.ghostFaceCB[d], their_spinor_parity);
        if (d == 3) in *= arg.t_proj_scale; // put this in the Ghost accessor and merge with any rescaling?

        out += (U * in).reconstruct(d, proj_dir);
      } else if (doBulk<kernel_type>() && !ghost) {
        Link U = arg.U(d, x_cb, parity);
        Vector in = loadTwistedNeighbor<twist, Vector>(arg, fwd_idx, s, their_spinor_parity);

        out += (U * in.project(d, proj_dir)).reconstruct(d, proj_dir);
      }
    }

    { // Backward gather - compute back offset for spinor and gauge fetch
      const int back_idx = getNeighborIndexCB(coord, d, -1, arg.dc);
      constexpr int proj_dir = dagger ? -1 : +1;
      const bool ghost = (coord[d] - arg.nFace < 0) && isActive<kernel_type>(active, thread_dim, d, coord, arg);

      if (doHalo<kernel_type>(d) && ghost) {
        // we need to compute the face index if we are updating a face that isn't ours
        const int ghost_idx = (kernel_type == EXTERIOR_KERNEL_ALL && d != thread_dim) ?
          ghostFaceIndex<0, nDim>(coord, arg.dim, d, arg.nFace) :
          idx;

        Link U = arg.U.Ghost(d, ghost_idx, 1 - parity);
        HalfVector in = arg.in.Ghost(d, 0, ghost_idx + s * arg.dc.ghostFaceCB[d], their_spinor_parity);
        if (d == 3) in *= arg.t_proj_scale;

        out += (conj(U) * in).reconstruct(d, proj_dir);
      } else if (doBulk<kernel_type>() && !ghost) {
        // the backward link lives on the neighbouring (opposite-parity) site
        Link U = arg.U(d, back_idx, 1 - parity);
        Vector in = loadTwistedNeighbor<twist, Vector>(arg, back_idx, s, their_spinor_parity);

        out += (conj(U) * in.project(d, proj_dir)).reconstruct(d, proj_dir);
      }
    }
  } // nDim
}
138 
/**
 * @brief Per-site body of the preconditioned twisted-mass dslash.
 *
 * When !dagger || asymmetric:
 *   out = D*in (applyWilson); for the xpay interior kernel, adds a_inv * (x + b_inv * x.igamma(4));
 *   once the site is complete, applies out = a * (out + b * out.igamma(4)).
 * Symmetric dagger (dagger && !asymmetric):
 *   the twist is applied to the neighbours inside applyWilsonTM, so x is added directly and
 *   no final twist is applied.
 * Non-interior kernels add their contribution onto the partial result already stored in arg.out.
 *
 * @param arg     kernel argument (TwistedMassArg-like)
 * @param idx     thread index passed to getCoords
 * @param parity  parity of the sites being updated
 */
template <typename Float, int nDim, int nColor, int nParity, bool dagger, bool asymmetric, bool xpay,
          KernelType kernel_type, typename Arg>
__device__ __host__ inline void twistedMass(Arg &arg, int idx, int parity)
{
  typedef typename mapper<Float>::type real;
  typedef ColorSpinor<real, nColor, 4> Vector; // four-spinor at a single site

  // true: plain Wilson hopping followed by an explicit twist; false: twist folded into applyWilsonTM
  constexpr bool explicit_twist = !dagger || asymmetric;

  bool active = kernel_type != EXTERIOR_KERNEL_ALL; // is thread active (non-trivial for fused kernel only)
  int thread_dim;                                   // which dimension is thread working on (fused kernel only)
  int coord[nDim];
  const int x_cb = getCoords<nDim, QUDA_4D_PC, kernel_type>(coord, arg, idx, parity, thread_dim);

  const int my_spinor_parity = nParity == 2 ? parity : 0;

  Vector out;

  if (explicit_twist) // defined in dslash_wilson.cuh
    applyWilson<Float, nDim, nColor, nParity, dagger, kernel_type>(
      out, arg, coord, x_cb, 0, parity, idx, thread_dim, active);
  else // special dslash for symmetric dagger
    applyWilsonTM<Float, nDim, nColor, nParity, dagger, 1, kernel_type>(
      out, arg, coord, x_cb, 0, parity, idx, thread_dim, active);

  if (xpay && kernel_type == INTERIOR_KERNEL) {
    Vector x = arg.x(x_cb, my_spinor_parity);
    if (explicit_twist) {
      out += arg.a_inv * (x + arg.b_inv * x.igamma(4)); // apply inverse twist which is undone below
    } else {
      out += x; // just directly add since twist already applied in the dslash
    }
  } else if (kernel_type != INTERIOR_KERNEL && active) {
    // if we're not the interior kernel, then we must sum the partial
    Vector x = arg.out(x_cb, my_spinor_parity);
    out += x;
  }

  if (isComplete<kernel_type>(arg, coord) && active) {
    if (explicit_twist) out = arg.a * (out + arg.b * out.igamma(4)); // apply A^{-1} to D*in
  }

  if (kernel_type != EXTERIOR_KERNEL_ALL || active) arg.out(x_cb, my_spinor_parity) = out;
}
188 
/**
 * @brief Host-side sweep: apply twistedMass to every checkerboard site of the requested
 * parity (or both parities for full fields, nParity == 2).
 *
 * @tparam asymmetric  compile-time asymmetric-preconditioning flag forwarded to twistedMass
 * @param arg          kernel argument; written through arg.out by twistedMass
 */
template <typename Float, int nDim, int nColor, int nParity, bool dagger, bool asymmetric, bool xpay,
          KernelType kernel_type, typename Arg>
void twistedMassPreconditionedCPUSweep(Arg &arg)
{
  for (int p = 0; p < nParity; p++) {
    // for full fields use the loop parity, otherwise the single parity stored in arg
    const int parity = nParity == 2 ? p : arg.parity;

    for (int x_cb = 0; x_cb < arg.threads; x_cb++) { // 4-d volumeCB
      twistedMass<Float, nDim, nColor, nParity, dagger, asymmetric, xpay, kernel_type>(arg, x_cb, parity);
    }
  }
}

// CPU kernel for applying the preconditioned twisted-mass operator to a vector.
// Dispatches the runtime arg.asymmetric flag to the matching compile-time sweep.
template <typename Float, int nDim, int nColor, int nParity, bool dagger, bool xpay, KernelType kernel_type, typename Arg>
void twistedMassPreconditionedCPU(Arg arg)
{
  if (arg.asymmetric)
    twistedMassPreconditionedCPUSweep<Float, nDim, nColor, nParity, dagger, true, xpay, kernel_type>(arg);
  else
    twistedMassPreconditionedCPUSweep<Float, nDim, nColor, nParity, dagger, false, xpay, kernel_type>(arg);
}
214 
// GPU kernel for the preconditioned twisted-mass operator.
// Launch layout: one thread per checkerboard site along x (threads beyond arg.threads exit);
// for full fields (nParity == 2) the parity is taken from the z thread index, otherwise from arg.parity.
// Parity is forwarded to twistedMass as a literal 0 or 1; any other z index does no work.
template <typename Float, int nDim, int nColor, int nParity, bool dagger, bool xpay, KernelType kernel_type, typename Arg>
__global__ void twistedMassPreconditionedGPU(Arg arg)
{
  const int site = threadIdx.x + blockIdx.x * blockDim.x;
  if (site >= arg.threads) return;

  const int par = (nParity == 2) ? (threadIdx.z + blockIdx.z * blockDim.z) : arg.parity;

  if (!arg.asymmetric) {
    if (par == 0)
      twistedMass<Float, nDim, nColor, nParity, dagger, false, xpay, kernel_type>(arg, site, 0);
    else if (par == 1)
      twistedMass<Float, nDim, nColor, nParity, dagger, false, xpay, kernel_type>(arg, site, 1);
  } else {
    // asymmetric implies dagger and !xpay, so only that variant is instantiated here
    if (par == 0)
      twistedMass<Float, nDim, nColor, nParity, true, true, false, kernel_type>(arg, site, 0);
    else if (par == 1)
      twistedMass<Float, nDim, nColor, nParity, true, true, false, kernel_type>(arg, site, 1);
  }
}
238 
239 } // namespace quda
KernelType kernel_type
__device__ __host__ void twistedMass(Arg &arg, int idx, int parity)
Apply the twisted-mass dslash out(x) = M*in = a * D * in + (1 + i*b*gamma_5)*x Note this routine only...
void twistedMassPreconditionedCPU(Arg arg)
static constexpr QudaGhostExchange ghost
static constexpr QudaReconstructType reconstruct
Parameter structure for driving the Wilson operator.
const int nColor
Definition: covdev_test.cpp:75
__global__ void twistedMassPreconditionedGPU(Arg arg)
TwistedMassArg(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a, double b, bool xpay, const ColorSpinorField &x, int parity, bool dagger, bool asymmetric, const int *comm_override)
__shared__ float s[]
static __device__ __host__ int getNeighborIndexCB(const int x[], int mu, int dir, const Arg &arg)
Compute the checkerboard 1-d index for the nearest neighbor.
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
VectorXcd Vector
__host__ __device__ ValueType conj(ValueType x)
Definition: complex_quda.h:130
__device__ __host__ void applyWilsonTM(Vector &out, Arg &arg, int coord[nDim], int x_cb, int s, int parity, int idx, int thread_dim, bool &active)
Applies the off-diagonal part of the Wilson operator premultiplied by twist rotation - this is requir...