3 #include <quda_internal.h>
5 #include <gauge_field.h>
6 #include <gauge_field_order.h>
7 #include <quda_matrix.h>
8 #include <float_vector.h>
9 #include <complex_quda.h>
10 #include <instantiate.h>
// Argument bundle handed to the gauge-update kernel.
//
// Template parameters:
//   Float_   - floating-point type, forwarded to gauge_mapper for both accessors
//   nColor_  - number of colors; the static_assert below restricts this to 3
//   recon_u  - reconstruction type used for the gauge-link accessor (Gauge)
//   recon_m  - reconstruction type used for the momentum accessor (Mom)
//   N_       - exposed as the compile-time constant N (used by compute() as the
//              trip count of its series loops)
//
// NOTE(review): the numbering embedded in these lines skips (15 -> 17, 21 -> 27),
// so some lines are not visible here. `Float` is used below without a visible
// typedef, and the members initialised by the constructor (out, in, mom, dt,
// nDim) have no visible declarations -- presumably both live in the elided
// lines; confirm against the full file.
14 template <typename Float_, int nColor_, QudaReconstructType recon_u, QudaReconstructType recon_m, int N_>
15 struct UpdateGaugeArg {
17 static constexpr int nColor = nColor_;
18 static constexpr int N = N_;
19 static_assert(nColor == 3, "Only nColor=3 enabled at this time");
// Accessor types selected by gauge_mapper from the precision and reconstruction type.
20 typedef typename gauge_mapper<Float,recon_u>::type Gauge;
21 typedef typename gauge_mapper<Float,recon_m>::type Mom;
// Builds the accessors from the output field, input field and momentum field,
// and stores the step size dt and the number of directions nDim.
27 UpdateGaugeArg(GaugeField &out, const GaugeField &in, const GaugeField &mom, Float dt, int nDim)
28 : out(out), in(in), mom(mom), dt(dt), nDim(nDim) { }
// Per-site update: for every direction dir < arg.nDim, reads the input link and
// the momentum at (dir, x, parity), builds a new link from them and writes it to
// arg.out at the same index.
//
// Template parameters:
//   conj_mom, exact - compile-time switches; the conditionals that consume them
//                     are not visible in this view (embedded numbering skips),
//                     so which statement below belongs to which case is an
//                     assumption -- TODO confirm against the full file.
//   Arg             - argument struct providing Float, nColor, N, in, mom, out,
//                     dt and nDim.
// Parameters:
//   arg    - kernel arguments (accessors plus dt and nDim)
//   x      - site index passed straight to the accessors
//   parity - parity index passed straight to the accessors
31 template <bool conj_mom, bool exact, typename Arg>
32 __device__ __host__ void compute(Arg &arg, int x, int parity)
34 using Float = typename Arg::Float;
35 typedef complex<Float> Complex;
36 Matrix<Complex, Arg::nColor> link, result, mom;
38 for (int dir=0; dir<arg.nDim; ++dir) {
39 link = arg.in(dir, x, parity);
40 mom = arg.mom(dir, x, parity);
// Subtract trace/nColor from each diagonal entry so the momentum matrix
// has zero trace before it is used.
42 Complex trace = getTrace(mom);
43 for (int c=0; c<Arg::nColor; c++) mom(c,c) -= trace/static_cast<Float>(Arg::nColor);
// The two loops below run r = N, N-1, ..., 1 and repeatedly apply
//   result = (dt/r) * M * result + link
// with M = mom in the first loop and M = conj(mom) in the second.
// NOTE(review): the initial value of `result` and the condition choosing
// between the two loops are in elided lines; presumably conj_mom selects
// the second form -- confirm.
48 // Nth order expansion of exponential
50 for (int r= Arg::N; r>0; r--)
51 result = (arg.dt/r)*mom*result + link;
53 for (int r= Arg::N; r>0; r--)
54 result = (arg.dt/r)*conj(mom)*result + link;
// Left-multiplies the link by conj(mom).
// NOTE(review): presumably part of the `exact` path, where mom has been
// transformed by code not visible here -- confirm.
63 link = conj(mom) * link;
// Store the updated link for this direction.
69 arg.out(dir, x, parity) = result;
// Kernel entry point: maps one thread to one (site, parity) pair and forwards
// to compute<conj_mom, exact>().
//
// Indexing as written: the x launch dimension gives the site index x_cb, the
// y launch dimension gives parity. Threads whose x_cb is at or beyond
// arg.out.volumeCB exit immediately.
// NOTE(review): parity is not range-checked in the visible lines; presumably
// the launch configuration bounds the y extent -- confirm.
73 template <bool conj_mom, bool exact, typename Arg>
74 __global__ void updateGaugeFieldKernel(Arg arg)
76 int x_cb = blockIdx.x*blockDim.x + threadIdx.x;
77 if (x_cb >= arg.out.volumeCB) return;
78 int parity = blockIdx.y*blockDim.y + threadIdx.y;
79 compute<conj_mom,exact>(arg, x_cb, parity);
// Launcher for updateGaugeFieldKernel, derived from TunableVectorY. It obtains
// launch parameters through tuneLaunch() and reports flop/byte estimates and a
// tuning key.
//
// Template parameters:
//   Arg             - kernel argument type (see UpdateGaugeArg)
//   conj_mom, exact - forwarded unchanged to the kernel template
//
// NOTE(review): the embedded numbering skips here, so the `arg` member, the
// access specifiers and the constructor's initializer list are not visible.
82 template <typename Arg, bool conj_mom, bool exact>
83 class UpdateGaugeField : public TunableVectorY {
85 const GaugeField &meta; // meta data
// Grid-dimension tuning is switched off for this kernel.
87 bool tuneGridDim() const { return false; }
// Minimum thread count reported to the tuner: the volumeCB of the input accessor.
88 unsigned int minThreads() const { return arg.in.volumeCB; }
// Constructor; its initializer list is in lines not visible here.
91 UpdateGaugeField(Arg &arg, const GaugeField &meta) :
// Looks up (or tunes) the launch parameters and launches the kernel on `stream`.
96 void apply(const qudaStream_t &stream){
97 TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
98 qudaLaunchKernel(updateGaugeFieldKernel<conj_mom,exact,Arg>, tp, stream, arg);
// Flop estimate: nDim * 2 * volumeCB * N times the per-iteration cost itemised
// in the trailing comments below.
// NOTE(review): the expression uses Arg::N regardless of the `exact` template
// flag; the leading factor 2 presumably counts both parities -- confirm.
101 long long flops() const {
102 const int Nc = Arg::nColor;
103 return arg.nDim*2*arg.in.volumeCB*Arg::N*(Nc*Nc*2 + // scalar-matrix multiply
104 (8*Nc*Nc*Nc - 2*Nc*Nc) + // matrix-matrix multiply
105 Nc*Nc*2); // matrix-matrix addition
// Byte estimate: nDim * 2 * volumeCB times the summed Bytes() of the in, out
// and mom accessors.
108 long long bytes() const { return arg.nDim*2*arg.in.volumeCB*(arg.in.Bytes() + arg.out.Bytes() + arg.mom.Bytes()); }
// Tuning key built from the meta field's volume string, this class's typeid
// name, and the meta field's aux string.
110 TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), meta.AuxString()); }
// Functor whose constructor builds the kernel arguments and runs one of four
// UpdateGaugeField instantiations, one per (conj_mom, exact) combination.
//
// Template parameters:
//   Float   - floating-point type forwarded to UpdateGaugeArg
//   nColor  - number of colors forwarded to UpdateGaugeArg
//   recon_u - reconstruction type for the gauge links
113 template <typename Float, int nColor, QudaReconstructType recon_u> struct UpdateGauge
// Parameters:
//   out, in  - output and input gauge fields
//   mom      - momentum field; must report QUDA_RECONSTRUCT_10, otherwise
//              errorQuda is called
//   dt       - step size forwarded to the kernel arguments
//   conj_mom, exact - runtime flags.
// NOTE(review): the if/else lines that map these flags onto the four template
// instantiations below are not visible in this view; the pairing of each flag
// value with each instantiation is presumed from the template arguments --
// confirm against the full file.
115 UpdateGauge(GaugeField &out, const GaugeField &in, const GaugeField &mom, double dt, bool conj_mom, bool exact)
117 if (mom.Reconstruct() != QUDA_RECONSTRUCT_10) errorQuda("Reconstruction type %d not supported", mom.Reconstruct());
118 constexpr QudaReconstructType recon_m = QUDA_RECONSTRUCT_10;
119 constexpr int N = 8; // degree of exponential expansion
// nDim is fixed to 4 here.
120 UpdateGaugeArg<Float, nColor, recon_u, recon_m, N> arg(out, in, mom, dt, 4);
// Each launcher is constructed with `in` as its meta field and applied with
// 0 as the stream argument.
// <conj_mom = true, exact = true>
123 UpdateGaugeField<decltype(arg),true,true> updateGauge(arg, in);
124 updateGauge.apply(0);
// <conj_mom = true, exact = false>
126 UpdateGaugeField<decltype(arg),true,false> updateGauge(arg, in);
127 updateGauge.apply(0);
// <conj_mom = false, exact = true>
131 UpdateGaugeField<decltype(arg),false,true> updateGauge(arg, in);
132 updateGauge.apply(0);
// <conj_mom = false, exact = false>
134 UpdateGaugeField<decltype(arg),false,false> updateGauge(arg, in);
135 updateGauge.apply(0);
// Public entry point for the gauge-field update.
//
// Parameters:
//   out      - gauge field that receives the result
//   dt       - step size
//   in       - input gauge field
//   mom      - momentum field
//   conj_mom, exact - flags forwarded to UpdateGauge
//
// When GPU_GAUGE_TOOLS is defined, the fields are validated with
// checkPrecision / checkLocation (all three fields) and checkReconstruct
// (out and in only), then instantiate<> dispatches to UpdateGauge with the
// ReconstructNo12 option set.
// NOTE(review): the #else / #endif lines are not visible in this view; the
// errorQuda call below is presumably the branch taken when GPU_GAUGE_TOOLS is
// not defined -- confirm.
141 void updateGaugeField(GaugeField &out, double dt, const GaugeField& in, const GaugeField& mom, bool conj_mom, bool exact)
143 #ifdef GPU_GAUGE_TOOLS
144 checkPrecision(out, in, mom);
145 checkLocation(out, in, mom);
146 checkReconstruct(out, in);
147 instantiate<UpdateGauge,ReconstructNo12>(out, in, mom, dt, conj_mom, exact);
// NOTE(review): the message text reads "not build"; likely intended "not built".
149 errorQuda("Gauge tools are not build");