QUDA  1.0.0
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
clover_sigma_outer_product.cuh
Go to the documentation of this file.
#include <stdexcept>
#include <vector>

#include <gauge_field_order.h>
#include <quda_matrix.h>
#include <color_spinor.h>
5 
6 namespace quda
7 {
8 
9 #include <texture.h> // we need to convert this kernel to using colorspinor accessors
10 
11  // This is the maximum number of color spinors we can process in a single kernel
12 #if (CUDA_VERSION < 8000)
13 #define MAX_NVECTOR 1 // multi-vector code doesn't seem to work well with CUDA 7.x
14 #else
15 #define MAX_NVECTOR 9
16 #endif
17 
18  template <typename Float, typename Output, typename InputA, typename InputB> struct CloverSigmaOprodArg {
19  Output oprod;
20  InputA inA[MAX_NVECTOR];
21  InputB inB[MAX_NVECTOR];
22  Float coeff[MAX_NVECTOR][2];
23  unsigned int length;
24  int nvector;
25 
26  CloverSigmaOprodArg(Output &oprod, InputA *inA_, InputB *inB_, const std::vector<std::vector<double>> &coeff_,
27  const GaugeField &meta, int nvector) :
28  oprod(oprod),
29  length(meta.VolumeCB()),
30  nvector(nvector)
31  {
32  for (int i = 0; i < nvector; i++) {
33  inA[i] = inA_[i];
34  inB[i] = inB_[i];
35  coeff[i][0] = coeff_[i][0];
36  coeff[i][1] = coeff_[i][1];
37  }
38  }
39  };
40 
41  template <typename real, int nvector, int mu, int nu, int parity, typename Arg>
42  inline __device__ void sigmaOprod(Arg &arg, int idx)
43  {
44  typedef complex<real> Complex;
45  Matrix<Complex, 3> result;
46 
47 #pragma unroll
48  for (int i = 0; i < nvector; i++) {
50 
51  arg.inA[i].load(static_cast<Complex *>(A.data), idx, parity);
52  arg.inB[i].load(static_cast<Complex *>(B.data), idx, parity);
53 
54  // multiply by sigma_mu_nu
55  ColorSpinor<real, 3, 4> C = A.sigma(nu, mu);
56  result += arg.coeff[i][parity] * outerProdSpinTrace(C, B);
57  }
58 
59  result -= conj(result);
60 
61  Matrix<Complex, 3> temp = arg.oprod((mu - 1) * mu / 2 + nu, idx, parity);
62  temp = result + temp;
63  arg.oprod((mu - 1) * mu / 2 + nu, idx, parity) = temp;
64  }
65 
66  template <int nvector, typename real, typename Arg> __global__ void sigmaOprodKernel(Arg arg)
67  {
68  typedef complex<real> Complex;
69  int idx = blockIdx.x * blockDim.x + threadIdx.x;
70  int parity = blockIdx.y * blockDim.y + threadIdx.y;
71  int mu_nu = blockIdx.z * blockDim.z + threadIdx.z;
72 
73  if (idx >= arg.length) return;
74  if (mu_nu >= 6) return;
75 
76  switch (parity) {
77  case 0:
78  switch (mu_nu) {
79  case 0: sigmaOprod<real, nvector, 1, 0, 0>(arg, idx); break;
80  case 1: sigmaOprod<real, nvector, 2, 0, 0>(arg, idx); break;
81  case 2: sigmaOprod<real, nvector, 2, 1, 0>(arg, idx); break;
82  case 3: sigmaOprod<real, nvector, 3, 0, 0>(arg, idx); break;
83  case 4: sigmaOprod<real, nvector, 3, 1, 0>(arg, idx); break;
84  case 5: sigmaOprod<real, nvector, 3, 2, 0>(arg, idx); break;
85  }
86  break;
87  case 1:
88  switch (mu_nu) {
89  case 0: sigmaOprod<real, nvector, 1, 0, 1>(arg, idx); break;
90  case 1: sigmaOprod<real, nvector, 2, 0, 1>(arg, idx); break;
91  case 2: sigmaOprod<real, nvector, 2, 1, 1>(arg, idx); break;
92  case 3: sigmaOprod<real, nvector, 3, 0, 1>(arg, idx); break;
93  case 4: sigmaOprod<real, nvector, 3, 1, 1>(arg, idx); break;
94  case 5: sigmaOprod<real, nvector, 3, 2, 1>(arg, idx); break;
95  }
96  break;
97  }
98 
99  } // sigmaOprodKernel
100 
101 } // namespace quda
double mu
Definition: test_util.cpp:1648
complex< Float > data[size]
Definition: color_spinor.h:27
#define MAX_NVECTOR
__device__ __host__ Matrix< complex< Float >, Nc > outerProdSpinTrace(const ColorSpinor< Float, Nc, Ns > &a, const ColorSpinor< Float, Nc, Ns > &b)
Definition: color_spinor.h:985
__device__ void sigmaOprod(Arg &arg, int idx)
__global__ void sigmaOprodKernel(Arg arg)
CloverSigmaOprodArg(Output &oprod, InputA *inA_, InputB *inB_, const std::vector< std::vector< double >> &coeff_, const GaugeField &meta, int nvector)
Main header file for host and device accessors to GaugeFields.
std::complex< double > Complex
Definition: quda_internal.h:46
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
__host__ __device__ ValueType conj(ValueType x)
Definition: complex_quda.h:130
QudaParity parity
Definition: covdev_test.cpp:54