QUDA  v1.1.0
A library for QCD on GPUs
pgauge_det_trace.cu
Go to the documentation of this file.
1 #include <quda_internal.h>
2 #include <quda_matrix.h>
3 #include <tune_quda.h>
4 #include <gauge_field.h>
5 #include <gauge_field_order.h>
6 #include <launch_kernel.cuh>
7 #include <comm_quda.h>
8 #include <reduce_helper.h>
9 #include <index_helper.cuh>
10 #include <instantiate.h>
11 
12 namespace quda {
13 
/**
   @brief Argument struct passed by value to the compute kernel.

   Carries the gauge-field accessor plus the lattice extents with the border
   removed, and inherits the reduction machinery from ReduceArg<double2>.

   @tparam Float   Type run through mapper<> to obtain the arithmetic type `real`
   @tparam nColor_ Number of colors, re-exported as `nColor`
   @tparam recon_  Reconstruction type, re-exported as `recon`
 */
template <typename Float, int nColor_, QudaReconstructType recon_>
struct KernelArg : public ReduceArg<double2> {
  static constexpr int nColor = nColor_;
  static constexpr QudaReconstructType recon = recon_;
  using real = typename mapper<Float>::type;
  using Gauge = typename gauge_mapper<real, recon>::type;
  int threads;   // number of active threads required
  int X[4];      // grid dimensions (border excluded)
  int border[4]; // per-dimension border width, copied from data.R()
  Gauge dataOr;  // accessor used by the kernel to load links

  /**
     @param[in] data Gauge field whose links will be reduced
   */
  KernelArg(const GaugeField &data) :
    ReduceArg<double2>(),
    // Members are constructed in declaration order regardless of how this
    // list is written, so keep the list in that same order.
    threads(data.LocalVolumeCB()),
    dataOr(data)
  {
    for (int dir = 0; dir < 4; ++dir) {
      border[dir] = data.R()[dir];
      X[dir] = data.X()[dir] - border[dir] * 2; // remove the border on both sides
    }
  }
};
36 
/**
   @brief Sum the determinant (type == 0) or the trace (type == 1) of the four
   links at every site handled by this thread, then pass the partial sum to the
   reduction provided by Arg.

   Launch layout as used by the code: threadIdx.y selects the parity, while the
   x dimension walks over arg.threads sites in a grid-stride loop (step
   blockDim.x * gridDim.x).

   @tparam blockSize Forwarded to arg.reduce2d
   @tparam type      0 accumulates getDeterminant(U), 1 accumulates getTrace(U)
   @tparam Arg       Kernel argument type (provides X, border, dataOr, threads, reduce2d)
   @param[in] arg    Kernel argument, passed by value
 */
template <int blockSize, int type, typename Arg>
__global__ void compute(Arg arg)
{
  const int parity = threadIdx.y;
  const int stride = blockDim.x * gridDim.x;

  complex<double> val(0.0, 0.0);
  for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < arg.threads; idx += stride) {
    int X[4];
#pragma unroll
    for (int dr = 0; dr < 4; ++dr) X[dr] = arg.X[dr];

    // Coordinates of this site on the border-free lattice.
    int x[4];
    getCoords(x, idx, X, parity);
#pragma unroll
    for (int dr = 0; dr < 4; ++dr) {
      // Shift into the bordered lattice that the accessor is indexed by.
      x[dr] += arg.border[dr];
      X[dr] += 2 * arg.border[dr];
    }

    // Use a dedicated variable for the field index: writing it back into idx
    // would corrupt the loop counter whenever the border is non-zero, breaking
    // both the stride increment and the idx < arg.threads termination test.
    const int link_idx = linkIndex(x, X);
#pragma unroll
    for (int mu = 0; mu < 4; mu++) {
      Matrix<complex<typename Arg::real>, Arg::nColor> U = arg.dataOr(mu, link_idx, parity);
      if (type == 0) val += getDeterminant(U);
      else if (type == 1) val += getTrace(U);
    }
  }

  // Reached by every thread, including those that processed no site.
  arg.template reduce2d<blockSize, 2>(val);
}
69 
/**
   @brief Tunable driver that launches the compute kernel and normalizes the
   reduced value.

   Construction immediately calls apply(0), so building a CalcFunc object is
   what performs the calculation; the answer is written into the double2
   reference supplied by the caller.

   @tparam Float  Forwarded to KernelArg
   @tparam nColor Number of colors
   @tparam recon  Reconstruction type
   @tparam type   Kernel selector: 0 = determinant, 1 = trace
 */
template <typename Float, int nColor, QudaReconstructType recon, int type>
class CalcFunc : TunableLocalParityReduction {
  double2 &result;      // caller-owned output
  const GaugeField &u;  // field being reduced

public:
  /**
     @param[out] result Receives the normalized reduction
     @param[in]  u      Gauge field to reduce
   */
  CalcFunc(double2 &result, const GaugeField &u) :
    result(result),
    u(u)
  {
    apply(0);
  }

  /**
     @brief Launch the kernel on the given stream, complete the reduction and,
     outside of tuning runs, all-reduce across ranks and normalize.
     @param[in] stream Stream the kernel is launched on
   */
  void apply(const qudaStream_t &stream)
  {
    TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
    KernelArg<Float, nColor, recon> arg(u);
    LAUNCH_KERNEL_LOCAL_PARITY(compute, (*this), tp, stream, arg, type, decltype(arg));
    arg.complete(result, stream);
    if (!activeTuning()) {
      comm_allreduce_array((double*)&result, 2);
      // Form the divisor once, in double precision: the product
      // 4 * local volume * number of ranks may exceed a 32-bit integer on
      // large lattices (NOTE(review): depends on LocalVolume()'s return type).
      const double n_links
        = 4.0 * static_cast<double>(u.LocalVolume()) * static_cast<double>(comm_size());
      result.x /= n_links;
      result.y /= n_links;
    }
  }

  TuneKey tuneKey() const { return TuneKey(u.VolString(), typeid(*this).name(), u.AuxString()); }

  /** @brief Flop count reported for this kernel. */
  long long flops() const {
    if (u.Ncolor()==3 && type == 0) return 264LL*u.LocalVolume();
    // 2LL promotes the whole product to 64 bits, matching the branch above.
    else if (type == 1) return 2LL*u.Geometry()*u.Ncolor()*u.LocalVolume();
    else return 0;
  }

  /** @brief Bytes moved, reported as the size of the gauge field. */
  long long bytes() const { return u.Bytes(); }
};
106 
/**
   @brief Functor-style wrapper: constructing it runs CalcFunc with the
   determinant selector (type == 0) and stores the outcome in det.
 */
template <typename Float, int nColor, QudaReconstructType recon> struct computeDeterminant {
  computeDeterminant(GaugeField &data, double2 &det)
  {
    constexpr int determinant_kernel = 0;
    using Calc = CalcFunc<Float, nColor, recon, determinant_kernel>;
    // The temporary's constructor performs the whole computation.
    Calc(det, data);
  }
};
113 
/**
   @brief Functor-style wrapper: constructing it runs CalcFunc with the
   trace selector (type == 1) and stores the outcome in trace.
 */
template <typename Float, int nColor, QudaReconstructType recon> struct computeTrace {
  computeTrace(GaugeField &data, double2 &trace)
  {
    constexpr int trace_kernel = 1;
    using Calc = CalcFunc<Float, nColor, recon, trace_kernel>;
    // The temporary's constructor performs the whole computation.
    Calc(trace, data);
  }
};
120 
/**
 * @brief Compute the link determinant of a gauge field.
 *
 * Dispatches to computeDeterminant through instantiate<>. When the library is
 * built without GPU_GAUGE_ALG, errorQuda is called instead.
 *
 * @param[in] data Gauge field
 * @returns double2 complex determinant value
 */
double2 getLinkDeterminant(GaugeField& data)
{
  double2 link_det = make_double2(0.0, 0.0);
#ifndef GPU_GAUGE_ALG
  errorQuda("Pure gauge code has not been built");
#else
  instantiate<computeDeterminant>(data, link_det);
#endif // !GPU_GAUGE_ALG
  return link_det;
}
137 
/**
 * @brief Compute the link trace of a gauge field.
 *
 * Dispatches to computeTrace through instantiate<>. When the library is built
 * without GPU_GAUGE_ALG, errorQuda is called instead.
 *
 * @param[in] data Gauge field
 * @returns double2 complex trace value
 */
double2 getLinkTrace(GaugeField& data)
{
  double2 link_trace = make_double2(0.0, 0.0);
#ifndef GPU_GAUGE_ALG
  errorQuda("Pure gauge code has not been built");
#else
  instantiate<computeTrace>(data, link_trace);
#endif // !GPU_GAUGE_ALG
  return link_trace;
}
154 
155 } // namespace quda