14 #ifdef GPU_GAUGE_TOOLS 16 template <
typename Float,
typename Gauge,
typename Mom>
17 struct UpdateGaugeArg {
23 UpdateGaugeArg(
const Gauge &out,
const Gauge &in,
24 const Mom &momentum, Float dt,
int nDim)
25 : out(out), in(in), momentum(momentum), dt(dt), nDim(nDim) { }
28 template<
typename Float,
typename Gauge,
typename Mom,
int N,
29 bool conj_mom,
bool exact>
30 __device__ __host__
void updateGaugeFieldCompute
31 (UpdateGaugeArg<Float,Gauge,Mom> &
arg,
int x,
int parity) {
35 for(
int dir=0; dir<arg.nDim; ++dir){
36 link = arg.in(dir, x, parity);
37 mom = arg.momentum(dir, x, parity);
40 mom(0,0) -= trace/
static_cast<Float
>(3.0);
41 mom(1,1) -= trace/
static_cast<Float
>(3.0);
42 mom(2,2) -= trace/
static_cast<Float
>(3.0);
49 for(
int r=N; r>0; r--)
50 result = (arg.dt/r)*mom*result + link;
52 for(
int r=N; r>0; r--)
53 result = (arg.dt/r)*
conj(mom)*result + link;
62 link =
conj(mom) * link;
68 arg.out(dir, x, parity) = result;
73 template<
typename Float,
typename Gauge,
typename Mom,
int N,
74 bool conj_mom,
bool exact>
77 for (
unsigned int parity=0; parity<2; parity++) {
78 for (
int x=0; x<arg.out.volumeCB; x++) {
79 updateGaugeFieldCompute<Float,Gauge,Mom,N,conj_mom,exact>
85 template<
typename Float,
typename Gauge,
typename Mom,
int N,
86 bool conj_mom,
bool exact>
87 __global__
void updateGaugeFieldKernel(UpdateGaugeArg<Float,Gauge,Mom> arg) {
88 int idx = blockIdx.x*blockDim.x + threadIdx.x;
89 if (idx >= 2*arg.out.volumeCB)
return;
90 int parity = (idx >= arg.out.volumeCB) ? 1 : 0;
91 idx -= parity*arg.out.volumeCB;
93 updateGaugeFieldCompute<Float,Gauge,Mom,N,conj_mom,exact>(
arg, idx,
parity);
96 template <
typename Float,
typename Gauge,
typename Mom,
int N,
97 bool conj_mom,
bool exact>
98 class UpdateGaugeField :
public Tunable {
100 UpdateGaugeArg<Float,Gauge,Mom>
arg;
101 const GaugeField &meta;
104 unsigned int sharedBytesPerThread()
const {
return 0; }
105 unsigned int sharedBytesPerBlock(
const TuneParam &)
const {
return 0; }
107 unsigned int minThreads()
const {
return 2*arg.in.volumeCB; }
108 bool tuneGridDim()
const {
return false; }
111 UpdateGaugeField(
const UpdateGaugeArg<Float,Gauge,Mom> &arg,
113 : arg(arg), meta(meta), location(location) {
114 writeAuxString(
"threads=%d,prec=%lu,stride=%d",
115 2*arg.in.volumeCB,
sizeof(Float), arg.in.stride);
117 virtual ~UpdateGaugeField() { }
119 void apply(
const cudaStream_t &
stream){
122 updateGaugeFieldKernel<Float,Gauge,Mom,N,conj_mom,exact>
123 <<<tp.grid,tp.block,tp.shared_bytes>>>(
arg);
125 updateGaugeField<Float,Gauge,Mom,N,conj_mom,exact>(
arg);
129 long long flops()
const {
131 return arg.nDim*2*arg.in.volumeCB*N*(Nc*Nc*2 +
132 (8*Nc*Nc*Nc - 2*Nc*Nc) +
135 long long bytes()
const {
return arg.nDim*2*arg.in.volumeCB*
136 (arg.in.Bytes() + arg.out.Bytes() + arg.momentum.Bytes()); }
138 TuneKey tuneKey()
const {
return TuneKey(meta.VolString(),
typeid(*this).name(), aux); }
141 template <
typename Float,
typename Gauge,
typename Mom>
143 double dt,
const GaugeField &meta,
bool conj_mom,
bool exact,
150 UpdateGaugeArg<Float, Gauge, Mom>
arg(out, in, mom, dt, 4);
151 UpdateGaugeField<Float,Gauge,Mom,N,true,true> updateGauge(arg, meta, location);
152 updateGauge.apply(0);
154 UpdateGaugeArg<Float, Gauge, Mom>
arg(out, in, mom, dt, 4);
155 UpdateGaugeField<Float,Gauge,Mom,N,true,false> updateGauge(arg, meta, location);
156 updateGauge.apply(0);
160 UpdateGaugeArg<Float, Gauge, Mom>
arg(out, in, mom, dt, 4);
161 UpdateGaugeField<Float,Gauge,Mom,N,false,true> updateGauge(arg, meta, location);
162 updateGauge.apply(0);
164 UpdateGaugeArg<Float, Gauge, Mom>
arg(out, in, mom, dt, 4);
165 UpdateGaugeField<Float,Gauge,Mom,N,false,false> updateGauge(arg, meta, location);
166 updateGauge.apply(0);
174 template <
typename Float,
typename Gauge>
176 double dt,
bool conj_mom,
bool exact,
181 updateGaugeField<Float>(
out,
in, gauge::FloatNOrder<Float,18,2,11>(mom), dt, mom, conj_mom, exact, location);
183 errorQuda(
"Reconstruction type not supported");
186 updateGaugeField<Float>(
out,
in, gauge::MILCOrder<Float,10>(mom), dt, mom, conj_mom, exact, location);
188 errorQuda(
"Gauge Field order %d not supported", mom.Order());
193 template <
typename Float>
194 void updateGaugeField(GaugeField &out,
const GaugeField &in,
const GaugeField &mom,
195 double dt,
bool conj_mom,
bool exact,
199 if (out.Ncolor() != Nc)
200 errorQuda(
"Ncolor=%d not supported at this time", out.Ncolor());
202 if (out.Order() != in.Order() || out.Reconstruct() != in.Reconstruct()) {
203 errorQuda(
"Input and output gauge field ordering and reconstruction must match");
206 if (out.isNative()) {
208 typedef typename gauge_mapper<Float,QUDA_RECONSTRUCT_NO>::type G;
209 updateGaugeField<Float>(G(out),G(in), mom, dt, conj_mom, exact, location);
211 typedef typename gauge_mapper<Float,QUDA_RECONSTRUCT_12>::type G;
212 updateGaugeField<Float>(G(out), G(in), mom, dt, conj_mom, exact, location);
214 errorQuda(
"Reconstruction type not supported");
217 updateGaugeField<Float>(gauge::MILCOrder<Float, Nc*Nc*2>(
out),
218 gauge::MILCOrder<Float, Nc*Nc*2>(in),
219 mom, dt, conj_mom, exact, location);
221 errorQuda(
"Gauge Field order %d not supported", out.Order());
228 const GaugeField& mom,
bool conj_mom,
bool exact)
230 #ifdef GPU_GAUGE_TOOLS 232 errorQuda(
"Gauge and momentum fields must have matching precision");
235 errorQuda(
"Gauge and momentum fields must have matching location");
238 updateGaugeField<double>(
out,
in, mom, dt, conj_mom, exact, out.
Location());
240 updateGaugeField<float>(
out,
in, mom, dt, conj_mom, exact, out.
Location());
QudaVerbosity getVerbosity()
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
Main header file for host and device accessors to GaugeFields.
std::complex< double > Complex
QudaFieldLocation Location() const
__device__ __host__ T getTrace(const Matrix< T, 3 > &a)
enum QudaFieldLocation_s QudaFieldLocation
cpuColorSpinorField * out
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
void updateGaugeField(GaugeField &out, double dt, const GaugeField &in, const GaugeField &mom, bool conj_mom, bool exact)
__host__ __device__ ValueType conj(ValueType x)
QudaTune getTuning()
Query whether autotuning is enabled or not. Default is enabled but can be overridden by setting QUDA_...
QudaPrecision Precision() const