1 #include <quda_internal.h>
2 #include <quda_matrix.h>
4 #include <gauge_field.h>
5 #include <gauge_field_order.h>
6 #include <launch_kernel.cuh>
8 #include <reduce_helper.h>
9 #include <index_helper.cuh>
10 #include <instantiate.h>
// Kernel argument for the link determinant / trace reduction (see `compute`).
// Carries the colour count and reconstruction type as compile-time constants,
// the number of active threads, and the local lattice dimensions with the
// border stripped off.
// NOTE(review): `border` (assigned in the constructor) and the gauge accessor
// `dataOr` (read by the kernel) are declared on lines missing from this
// listing, as are the remaining constructor initialisers and closing braces.
14 template <typename Float, int nColor_, QudaReconstructType recon_>
15 struct KernelArg : public ReduceArg<double2> {
16 static constexpr int nColor = nColor_;
17 static constexpr QudaReconstructType recon = recon_;
// Precision used for the per-link arithmetic, derived from the storage type
// Float (presumably via QUDA's mapper trait -- confirm in the type headers).
18 using real = typename mapper<Float>::type;
// Gauge-field accessor type for this precision and reconstruction.
19 using Gauge = typename gauge_mapper<real, recon>::type;
20 int threads; // number of active threads required
21 int X[4]; // local lattice dimensions with the border removed (see constructor)
// @param data Gauge field whose links are reduced; its R() widths are treated
//             as a border that is excluded from X.
25 KernelArg(const GaugeField &data) :
// Presumably the single-parity (checkerboard) local volume: the kernel takes
// the parity from threadIdx.y and loops idx over [0, threads).
28 threads(data.LocalVolumeCB())
30 for (int dir=0; dir<4; ++dir) {
// Border width in direction dir (presumably 0 for a non-extended field).
31 border[dir] = data.R()[dir];
// Strip the border from both ends to get the interior extent.
32 X[dir] = data.X()[dir] - border[dir]*2;
// Reduction kernel. Each thread accumulates, over its sites, either the
// determinant (type == 0) or the trace (type == 1) of the four links
// U_mu, mu = 0..3. The complex sum is then reduced across the block.
// Expected launch layout: x threads/blocks index sites within one parity, and
// threadIdx.y selects the parity. The reduce2d<blockSize,2> call below suggests
// blockDim.y == 2 -- confirm against TunableLocalParityReduction.
// @tparam blockSize x-extent of the thread block, forwarded to reduce2d
// @tparam type      0 = determinant, 1 = trace; any other value adds nothing
37 template <int blockSize, int type, typename Arg>
38 __global__ void compute(Arg arg)
40 int idx = threadIdx.x + blockIdx.x*blockDim.x;
41 int parity = threadIdx.y;
// Accumulate in double precision independently of Arg::real.
43 complex<double> val(0.0, 0.0);
// Stride loop: each pass moves idx forward by the total number of x-threads in
// the grid (increment at the bottom of the loop body).
44 while (idx < arg.threads) {
// Start from the border-free dimensions stored in arg
// (x[4] and X[4] are declared on lines missing from this listing).
47 for(int dr=0; dr<4; ++dr) X[dr] = arg.X[dr];
// Presumably converts the checkerboard index idx into 4-d site coordinates x
// on the interior lattice.
50 getCoords(x, idx, X, parity);
// Shift coordinates and dimensions onto the border-extended lattice.
52 for(int dr=0; dr<4; ++dr) {
53 x[dr] += arg.border[dr];
54 X[dr] += 2*arg.border[dr];
// NOTE(review): the link lookup below indexes with idx, and the extended x / X
// are not visibly used afterwards. Presumably a line missing from this listing
// converts x to an extended-lattice index. Confirm that it does not overwrite
// the loop counter idx; otherwise the stride increment below would start from
// the wrong value whenever the border is non-zero.
58 for (int mu = 0; mu < 4; mu++) {
59 Matrix<complex<typename Arg::real>, Arg::nColor> U = arg.dataOr(mu, idx, parity);
60 if (type == 0) val += getDeterminant(U);
61 else if (type == 1) val += getTrace(U);
// Advance to this thread's next site.
64 idx += blockDim.x*gridDim.x;
// Combine every thread's val across the block; the host retrieves the result
// via arg.complete().
67 arg.template reduce2d<blockSize,2>(val);
// Host-side driver for `compute`: tunes and launches the kernel on a gauge
// field and leaves the globally averaged complex result in `result`.
// type == 0 computes link determinants; type == 1 computes link traces.
// NOTE(review): the `result` / `u` members used below, the constructor body and
// the closing braces are on lines missing from this listing.
70 template <typename Float, int nColor, QudaReconstructType recon, int type>
71 class CalcFunc : TunableLocalParityReduction {
// @param[out] result receives the averaged value
// @param[in]  u      gauge field to reduce
76 CalcFunc(double2 &result, const GaugeField &u) :
83 void apply(const qudaStream_t &stream)
85 TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
86 KernelArg<Float, nColor, recon> arg(u);
87 LAUNCH_KERNEL_LOCAL_PARITY(compute, (*this), tp, stream, arg, type, decltype(arg));
// Fetch the block-reduced sum into result.
88 arg.complete(result, stream);
// Presumably skipped while the autotuner is benchmarking launch configurations,
// so that only the final run pays for the global sum.
89 if (!activeTuning()) {
// Sum the two doubles of result across all ranks.
90 comm_allreduce_array((double*)&result, 2);
// Divide by the global link count: 4 links per site (matching mu < 4 in the
// kernel) times the local site count times the number of ranks.
// NOTE(review): this assumes every rank holds the same LocalVolume(). If
// LocalVolume() returns a 32-bit int, 4*LocalVolume()*comm_size() can overflow
// on large runs -- consider computing the divisor once in double.
91 result.x /= (double)(4*u.LocalVolume()*comm_size());
92 result.y /= (double)(4*u.LocalVolume()*comm_size());
// Tuning cache key: the field's volume string, this instantiation's type name
// and the field's aux string.
96 TuneKey tuneKey() const { return TuneKey(u.VolString(), typeid(*this).name(), u.AuxString()); }
// NOTE(review): the determinant count (264 per site) is given only for
// Ncolor() == 3 and does not scale with u.Geometry(), whereas the trace count
// does. The remaining case(s) are on lines missing from this listing.
98 long long flops() const {
99 if (u.Ncolor()==3 && type == 0) return 264LL*u.LocalVolume();
100 else if (type == 1) return 2*u.Geometry()*u.Ncolor()*u.LocalVolume();
// Reports the byte count of the whole field u (presumably including any border).
104 long long bytes() const { return u.Bytes(); }
// Dispatch target for instantiate<> (used by getLinkDeterminant). Constructing
// this object runs the determinant reduction (CalcFunc with type 0) for the
// given template parameters and writes the averaged result into det.
107 template <typename Float, int nColor, QudaReconstructType recon> struct computeDeterminant {
108 computeDeterminant(GaugeField &data, double2 &det)
// Unnamed temporary, created only for the work its constructor performs
// (the constructor body is not shown; presumably it calls apply()).
110 CalcFunc<Float, nColor, recon, 0>(det, data);
// Dispatch target for instantiate<> (used by getLinkTrace). Constructing this
// object runs the trace reduction (CalcFunc with type 1) for the given
// template parameters and writes the averaged result into trace.
114 template <typename Float, int nColor, QudaReconstructType recon> struct computeTrace {
115 computeTrace(GaugeField &data, double2 &trace)
// Unnamed temporary, created only for the work its constructor performs
// (the constructor body is not shown; presumably it calls apply()).
117 CalcFunc<Float, nColor, recon, 1>(trace, data);
122 * @brief Calculate the average determinant of the gauge links
124 * @param[in] data Gauge field
125 * @returns double2 complex determinant averaged over all links on all ranks (presumably x = real part, y = imaginary part)
127 double2 getLinkDeterminant(GaugeField& data)
129 double2 det = make_double2(0.0,0.0);
// Presumably selects the precision / colour / reconstruction instantiation
// matching data and runs computeDeterminant for it.
131 instantiate<computeDeterminant>(data, det);
// The error below is the fallback for builds without GPU_GAUGE_ALG (the
// #ifdef / #else lines and the final return are missing from this listing).
133 errorQuda("Pure gauge code has not been built");
134 #endif // GPU_GAUGE_ALG
139 * @brief Calculate the average trace of the gauge links
141 * @param[in] data Gauge field
142 * @returns double2 complex trace averaged over all links on all ranks (presumably x = real part, y = imaginary part)
144 double2 getLinkTrace(GaugeField& data)
// Despite its name, det receives the link trace here.
146 double2 det = make_double2(0.0,0.0);
// Presumably selects the precision / colour / reconstruction instantiation
// matching data and runs computeTrace for it.
148 instantiate<computeTrace>(data, det);
// The error below is the fallback for builds without GPU_GAUGE_ALG (the
// #ifdef / #else lines and the final return are missing from this listing).
150 errorQuda("Pure gauge code has not been built");
151 #endif // GPU_GAUGE_ALG