QUDA  1.0.0
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
dslash_wilson_clover.cuh
Go to the documentation of this file.
1 #pragma once
2 
4 #include <clover_field_order.h>
5 #include <linalg.cuh>
6 
7 namespace quda
8 {
9 
/**
   @brief Kernel argument struct for the Wilson-clover dslash.

   Extends WilsonArg with a clover-field accessor A, the scale factor a that
   multiplies the applyWilson result in wilsonClover (out = A*x + a*out), and a
   twist parameter b that wilsonClover only reads when twist is true.

   @tparam Float        storage precision; arithmetic is done in mapper<Float>::type
   @tparam nColor       number of colors
   @tparam reconstruct_ gauge reconstruction type, forwarded to WilsonArg
   @tparam twist_       when true, wilsonClover adds +i*b (chirality 0) or -i*b
                        (chirality 1) to each clover block product; defaults to false

   NOTE(review): original lines 12, 16 and 21-22 are missing from this listing
   (dropped by the documentation export). Judging from the member list at the end
   of the page, they presumably bring nSpin into scope, declare the accessor type
   C as clover_mapper<Float, length, true>::type, and open the constructor
   signature WilsonCloverArg(ColorSpinorField &out, const ColorSpinorField &in,
   const GaugeField &U, const CloverField &A, ...). Restore them from the original
   header before compiling.
*/
10  template <typename Float, int nColor, QudaReconstructType reconstruct_, bool twist_ = false>
11  struct WilsonCloverArg : WilsonArg<Float, nColor, reconstruct_> {
// Length parameter for the clover accessor type C. With nSpin = 4 it evaluates to
// 8 * nColor * nColor (72 for nColor = 3). Presumably this is the number of reals
// stored per site for the two chiral Hermitian blocks -- confirm.
13  static constexpr int length = (nSpin / (nSpin / 2)) * 2 * nColor * nColor * (nSpin / 2) * (nSpin / 2) / 2;
// Compile-time copy of twist_; wilsonClover tests it via arg.twist.
14  static constexpr bool twist = twist_;
15 
17  typedef typename mapper<Float>::type real;
// Clover-field accessor, built in the constructor as A(A, false).
// NOTE(review): the meaning of the `false` flag is defined by the accessor type
// (possibly "not the inverse clover") -- confirm.
18  const C A;
// Scale applied to the applyWilson result: out = A*x + a*out.
19  const real a;
// Twist parameter, stored as +b/2 (or -b/2 when dagger). Only read when twist is true.
20  const real b;
// (The constructor signature begins on a line missing from this listing; see the
// note above the struct.) The scale a is forwarded to WilsonArg as well as stored here.
23  double a, double b, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override) :
24  WilsonArg<Float, nColor, reconstruct_>(out, in, U, a, x, parity, dagger, comm_override),
25  A(A, false),
26  a(a),
// The sign flip under dagger makes the twist term the Hermitian conjugate.
27  b(dagger ? -0.5 * b : 0.5 * b) // factor of 1/2 comes from clover normalization we need to correct for
28  {
29  }
30  };
31 
/**
   @brief Apply the Wilson-clover operator for the site selected by thread index idx.

   First, applyWilson (defined in dslash_wilson.cuh) writes its result into out.
   Then one of two branches applies:
     - In the clover branch, x = arg.x is rotated to the chiral basis. Each chiral
       half is multiplied by its Hermitian clover block (plus the twist term when
       Arg::twist is set), and the result is rotated back. The value written is
         out = A*x + a*out.
     - In the `else if (active)` branch, the partial result already stored in
       arg.out is read back, and the value written is
         out = arg.out + a*out.
   Finally, the result is written to arg.out unless kernel_type is
   EXTERIOR_KERNEL_ALL and active is false.

   @tparam nParity     2 for full fields (spinors indexed by parity); otherwise the
                       spinor parity index used is 0
   @tparam dagger      forwarded to applyWilson
   @tparam kernel_type selects interior/exterior behaviour. For EXTERIOR_KERNEL_ALL,
                       active starts out false and is passed to applyWilson
                       (presumably set there for threads that have work).
   @param arg    kernel arguments; this function reads x, out, A, a, b, twist and nSpin
   @param idx    thread index, mapped to lattice coordinates by getCoords
   @param parity parity of the sites being updated

   NOTE(review): original lines 41 and 57 are missing from this listing.
     - Line 41 presumably declares Vector, the full-spinor counterpart of
       HalfVector (likely ColorSpinor<real, nColor, 4>).
     - Line 57 opens the `if` that pairs with `} else if (active) {` below,
       presumably the interior-kernel test.
   Confirm both against the original header.
*/
37  template <typename Float, int nDim, int nColor, int nParity, bool dagger, KernelType kernel_type, typename Arg>
38  __device__ __host__ inline void wilsonClover(Arg &arg, int idx, int parity)
39  {
40  typedef typename mapper<Float>::type real;
42  typedef ColorSpinor<real, nColor, 2> HalfVector;
43 
44  bool active
45  = kernel_type == EXTERIOR_KERNEL_ALL ? false : true; // is thread active (non-trivial for fused kernel only)
46  int thread_dim; // which dimension is thread working on (fused kernel only)
47  int coord[nDim];
48  int x_cb = getCoords<nDim, QUDA_4D_PC, kernel_type>(coord, arg, idx, parity, thread_dim);
49 
// Single-parity spinor fields are always indexed with parity 0.
50  const int my_spinor_parity = nParity == 2 ? parity : 0;
51  Vector out;
52 
53  // defined in dslash_wilson.cuh
54  applyWilson<Float, nDim, nColor, nParity, dagger, kernel_type>(
55  out, arg, coord, x_cb, 0, parity, idx, thread_dim, active);
56 
58  Vector x = arg.x(x_cb, my_spinor_parity);
59  x.toRel(); // switch to chiral basis
60 
// Accumulated with += below. This relies on Vector's default constructor
// zero-initializing -- NOTE(review): confirm.
61  Vector tmp;
62 
// Apply the clover term one chiral block at a time. Each block is an n x n
// Hermitian matrix with n = nColor * nSpin / 2.
63 #pragma unroll
64  for (int chirality = 0; chirality < 2; chirality++) {
65  constexpr int n = nColor * Arg::nSpin / 2;
66  HMatrix<real, n> A = arg.A(x_cb, parity, chirality);
67  HalfVector x_chi = x.chiral_project(chirality);
68  HalfVector Ax_chi = A * x_chi;
69  if (arg.twist) {
// Twist term: +i*b on chirality 0, -i*b on chirality 1.
70  const complex<real> b(0.0, chirality == 0 ? static_cast<real>(arg.b) : -static_cast<real>(arg.b));
71  Ax_chi += b * x_chi;
72  }
73  tmp += Ax_chi.chiral_reconstruct(chirality);
74  }
75 
76  tmp.toNonRel(); // switch back to non-chiral basis
77 
78  out = tmp + arg.a * out;
79  } else if (active) {
// Accumulate a*out onto the partial result already stored in arg.out.
80  Vector x = arg.out(x_cb, my_spinor_parity);
81  out = x + arg.a * out;
82  }
83 
// Every thread writes, except inactive threads of an EXTERIOR_KERNEL_ALL launch.
84  if (kernel_type != EXTERIOR_KERNEL_ALL || active) arg.out(x_cb, my_spinor_parity) = out;
85  }
86 
/**
   @brief Host (CPU) driver that applies the Wilson-clover operator to a vector.

   Visits each parity being processed. Within each parity it calls wilsonClover
   for every thread index in [0, arg.threads).

   @tparam nParity 2 for full fields (parities 0 and 1 are visited in turn);
                   otherwise the single parity arg.parity is visited
   @tparam xpay    not referenced in this function
   @param arg      kernel arguments, taken by value (wilsonClover receives a
                   reference to this local copy)
*/
template <typename Float, int nDim, int nColor, int nParity, bool dagger, bool xpay, KernelType kernel_type, typename Arg>
void wilsonCloverCPU(Arg arg)
{
  for (int p = 0; p < nParity; p++) {
    // Full fields take the parity from the loop; otherwise use the arg setting.
    // A separate variable keeps the loop counter intact. Reassigning the counter
    // itself would make loop termination depend on arg.parity being non-negative.
    const int parity = nParity == 2 ? p : arg.parity;

    for (int x_cb = 0; x_cb < arg.threads; x_cb++) { // 4-d volume
      wilsonClover<Float, nDim, nColor, nParity, dagger, kernel_type>(arg, x_cb, parity);
    } // 4-d volumeCB
  }   // parity
}
100 
/**
   @brief CUDA kernel that drives wilsonClover on the device.

   Launch layout:
     - The x dimension of the grid and block enumerates the thread index passed
       to wilsonClover. Threads whose index is at or beyond arg.threads return
       immediately.
     - When nParity == 2, the z dimension supplies the parity. Otherwise
       arg.parity is used.
   Only parity values 0 and 1 produce work.

   The xpay template parameter is not referenced in this kernel.
*/
template <typename Float, int nDim, int nColor, int nParity, bool dagger, bool xpay, KernelType kernel_type, typename Arg>
__global__ void wilsonCloverGPU(Arg arg)
{
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= arg.threads) return;

  const int par = (nParity == 2) ? static_cast<int>(blockIdx.z * blockDim.z + threadIdx.z) : arg.parity;

  // Dispatch each of the two valid parities with a literal argument.
  if (par == 0) {
    wilsonClover<Float, nDim, nColor, nParity, dagger, kernel_type>(arg, tid, 0);
  } else if (par == 1) {
    wilsonClover<Float, nDim, nColor, nParity, dagger, kernel_type>(arg, tid, 1);
  }
}
116 
117 } // namespace quda
WilsonCloverArg(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, const CloverField &A, double a, double b, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override)
KernelType kernel_type
clover_mapper< Float, length, true >::type C
cudaColorSpinorField * tmp
Definition: covdev_test.cpp:44
void wilsonCloverCPU(Arg arg)
Main header file for host and device accessors to CloverFields.
__global__ void wilsonCloverGPU(Arg arg)
Parameter structure for driving the Wilson operator.
const int nColor
Definition: covdev_test.cpp:75
__device__ __host__ void wilsonClover(Arg &arg, int idx, int parity)
Apply the Wilson-clover dslash, out(x) = M*in = A(x)*x(x) + D*in(x-mu). Note this routine only exists...
Specialized container for Hermitian matrices (e.g., used for wrapping clover matrices) ...
Definition: quda_matrix.h:61
static constexpr int length
mapper< Float >::type real
static constexpr bool twist
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
VectorXcd Vector
static constexpr int nSpin