// NOTE(review): the leading integers (14, 16, 17, 23, ...) look like line
// numbers from an extracted listing that were fused into the text; the gaps
// between them suggest lines are missing here. Confirm against the real file.
//
// UpdateGaugeArg: argument bundle templated on the floating-point type and on
// the gauge / momentum accessor types. Only the constructor is visible; the
// member declarations and the closing brace are not shown in this listing.
14 #ifdef GPU_GAUGE_TOOLS 16 template <
typename Float,
typename Gauge,
typename Mom>
17 struct UpdateGaugeArg {
// Constructor: copies each parameter into the member of the same name
// (out, in, momentum, dt, nDim) through the initializer list; the body is
// empty.
23 UpdateGaugeArg(
const Gauge &
out,
const Gauge &
in,
24 const Mom &momentum, Float dt,
int nDim)
25 :
out(
out),
in(
in), momentum(momentum), dt(dt), nDim(nDim) { }
// expsu3: host/device routine that reads the nine entries q(0)..q(8) of a
// 3x3 complex matrix and, at the end of the visible fragment, overwrites all
// nine of them, i.e. it works in place on q.
// The parameter x is not used in any line visible here.
// NOTE(review): presumably this is an exponential of an su(3) element, going
// by the name only; many intermediate lines (al, z1, z2, y*, wl* definitions)
// are missing from this listing, so this cannot be confirmed from here.
31 template <
typename Float>
32 __device__ __host__
void expsu3(
Matrix<complex<Float>,3> &q,
int x) {
// a2: (q3*q1 + q7*q5 + q6*q2 - (q0*q4 + (q0+q4)*q8)) divided by 3.
35 Complex a2 = (q(3)*q(1)+q(7)*q(5)+q(6)*q(2) -
36 (q(0)*q(4)+(q(0)+q(4))*q(8))) / (Float)3.0 ;
// a3: the six-term signed sum of triple products; if q(k) is the flat index
// of a 3x3 matrix this is the usual expansion of its determinant
// (assumes flat 3x3 indexing for q(k) - TODO confirm).
37 Complex a3 = q(0)*q(4)*q(8) + q(1)*q(5)*q(6) + q(2)*q(3)*q(7) -
38 q(6)*q(4)*q(2) - q(3)*q(1)*q(8) - q(0)*q(7)*q(5);
// The three (wr2k, wr3k) pairs below use textually identical formulas in
// al, z1, z2 and q. NOTE(review): al, z1 and z2 are defined on lines not
// shown; presumably al is reassigned between the pairs - verify.
56 Complex wr21 = (z1+al*q(7)) / (z2+al*q(6));
57 Complex wr31 = (al-q(0)-wr21*q(3))/q(6);
60 Complex wr22 = (z1+al*q(7))/(z2+al*q(6));
61 Complex wr32 = (al-q(0)-wr22*q(3))/q(6);
64 Complex wr23 = (z1+al*q(7))/(z2+al*q(6));
65 Complex wr33 = (al-q(0)-wr23*q(3))/q(6);
// z1 and z2 are reassigned here from different products of q entries.
67 z1=q(3)*q(2) - q(0)*q(5);
68 z2=q(1)*q(5) - q(4)*q(2);
// Write-back: every q(k) is replaced by a three-term sum of the y* values,
// for q(3)..q(8) each term multiplied by the complex conjugate of a wl* value.
// The y* and wl* quantities are computed on lines not shown.
98 q(0) = y11 + y12 + y13;
99 q(1) = y21 + y22 + y23;
100 q(2) = y31 + y32 + y33;
101 q(3) = y11*
conj(wl21) + y12*
conj(wl22) + y13*
conj(wl23);
102 q(4) = y21*
conj(wl21) + y22*
conj(wl22) + y23*
conj(wl23);
103 q(5) = y31*
conj(wl21) + y32*
conj(wl22) + y33*
conj(wl23);
104 q(6) = y11*
conj(wl31) + y12*
conj(wl32) + y13*
conj(wl33);
105 q(7) = y21*
conj(wl31) + y22*
conj(wl32) + y23*
conj(wl33);
106 q(8) = y31*
conj(wl31) + y32*
conj(wl32) + y33*
conj(wl33);
// updateGaugeFieldCompute: host/device per-site routine taking the argument
// bundle, a site index x and a parity. Template parameters: Float, the Gauge
// and Mom accessor types, an integer N used as the iteration count of the
// loops below, and two compile-time flags conj_mom and exact.
// NOTE(review): the loads/stores of link and mom, the definition of trace and
// result, and the conditions that choose between the code paths are on lines
// missing from this listing; presumably conj_mom and exact select them.
109 template<
typename Float,
typename Gauge,
typename Mom,
int N,
110 bool conj_mom,
bool exact>
111 __device__ __host__
void updateGaugeFieldCompute
112 (UpdateGaugeArg<Float,Gauge,Mom> &
arg,
int x,
int parity) {
113 typedef complex<Float>
Complex;
// One pass per direction, dir = 0 .. arg.nDim-1.
116 for(
int dir=0; dir<
arg.nDim; ++dir){
// Subtract trace/3 from each of the three diagonal entries of mom.
121 mom(0,0) -= trace/
static_cast<Float
>(3.0);
122 mom(1,1) -= trace/
static_cast<Float
>(3.0);
123 mom(2,2) -= trace/
static_cast<Float
>(3.0);
// N iterations with r counting down from N to 1; each iteration replaces
// result by (arg.dt/r)*mom*result + link.
130 for(
int r=N; r>0; r--)
131 result = (
arg.dt/r)*mom*result + link;
// Same recurrence, but using conj(mom) in place of mom.
133 for(
int r=N; r>0; r--)
134 result = (
arg.dt/r)*
conj(mom)*result + link;
// Calls expsu3 on mom; the second argument is the integer x+dir+parity.
138 expsu3<Float>(mom,
x+dir+
parity);
// Left-multiplies link by conj(mom).
143 link =
conj(mom) * link;
// Host-side loop over sites: visits x = 0 .. arg.out.volumeCB-1 and invokes
// updateGaugeFieldCompute with the same template arguments for each x.
// NOTE(review): the function signature, any enclosing loop and the call's
// argument list are on lines missing from this listing.
154 template<
typename Float,
typename Gauge,
typename Mom,
int N,
155 bool conj_mom,
bool exact>
159 for (
int x=0;
x<
arg.out.volumeCB;
x++) {
160 updateGaugeFieldCompute<Float,Gauge,Mom,N,conj_mom,exact>
// updateGaugeFieldKernel: __global__ entry point that takes the argument
// bundle by value and forwards (arg, idx, parity) to updateGaugeFieldCompute.
// NOTE(review): the lines that compute idx and parity are missing from this
// listing, so the thread-to-site mapping cannot be stated from here.
166 template<
typename Float,
typename Gauge,
typename Mom,
int N,
167 bool conj_mom,
bool exact>
168 __global__
void updateGaugeFieldKernel(UpdateGaugeArg<Float,Gauge,Mom>
arg) {
// Bounds guard: exit when idx is at or beyond 2*arg.out.volumeCB.
170 if (
idx >= 2*
arg.out.volumeCB)
return;
174 updateGaugeFieldCompute<Float,Gauge,Mom,N,conj_mom,exact>(
arg,
idx,
parity);
// UpdateGaugeField: Tunable subclass that owns an UpdateGaugeArg and a
// reference to a GaugeField (meta), and launches either the kernel or the
// host loop from apply().
// NOTE(review): access specifiers, the declaration of the location member,
// part of the constructor signature and the class's closing brace are on
// lines missing from this listing.
177 template <
typename Float,
typename Gauge,
typename Mom,
int N,
178 bool conj_mom,
bool exact>
179 class UpdateGaugeField :
public Tunable {
181 UpdateGaugeArg<Float,Gauge,Mom>
arg;
182 const GaugeField &meta;
// Both shared-memory size queries return 0.
185 unsigned int sharedBytesPerThread()
const {
return 0; }
186 unsigned int sharedBytesPerBlock(
const TuneParam &)
const {
return 0; }
// Returns 2*arg.in.volumeCB.
188 unsigned int minThreads()
const {
return 2*
arg.in.volumeCB; }
// Always returns false.
189 bool tuneGridDim()
const {
return false; }
// Constructor: stores arg, meta and location, then formats the aux string
// from 2*arg.in.volumeCB, sizeof(Float) and arg.in.stride.
// NOTE(review): sizeof(Float) has type size_t but is formatted with %lu;
// confirm this matches on every supported platform.
192 UpdateGaugeField(
const UpdateGaugeArg<Float,Gauge,Mom> &
arg,
194 :
arg(
arg), meta(meta), location(location) {
195 writeAuxString(
"threads=%d,prec=%lu,stride=%d",
196 2*
arg.in.volumeCB,
sizeof(Float),
arg.in.stride);
198 virtual ~UpdateGaugeField() { }
// apply: the visible lines launch updateGaugeFieldKernel with tp.grid,
// tp.block and tp.shared_bytes, and separately call the host-side
// updateGaugeField(arg); the condition choosing between them is not shown.
// NOTE(review): the launch configuration visible here has three arguments
// and does not include the stream parameter - confirm whether the launch is
// meant to go to `stream`.
200 void apply(
const cudaStream_t &
stream){
203 updateGaugeFieldKernel<Float,Gauge,Mom,N,conj_mom,exact>
204 <<<tp.grid,tp.block,tp.shared_bytes>>>(
arg);
206 updateGaugeField<Float,Gauge,Mom,N,conj_mom,exact>(
arg);
// flops: the returned expression starts with arg.nDim*2*arg.in.volumeCB*N
// times a bracketed sum in Nc; the rest of the expression is cut off in
// this listing.
210 long long flops()
const {
212 return arg.nDim*2*
arg.in.volumeCB*N*(Nc*Nc*2 +
213 (8*Nc*Nc*Nc - 2*Nc*Nc) +
// bytes: arg.nDim*2*arg.in.volumeCB multiplied by the sum of the Bytes()
// values of the in, out and momentum accessors.
216 long long bytes()
const {
return arg.nDim*2*
arg.in.volumeCB*
217 (
arg.in.Bytes() +
arg.out.Bytes() +
arg.momentum.Bytes()); }
// tuneKey: built from meta.VolString(), the typeid name of this object and
// the aux string.
219 TuneKey tuneKey()
const {
return TuneKey(meta.VolString(),
typeid(*this).name(), aux); }
// Dispatch on the runtime flags conj_mom and exact: the four visible groups
// each build an UpdateGaugeArg from (out, in, mom, dt) with nDim given as the
// literal 4, instantiate UpdateGaugeField with one of the four
// <conj_mom, exact> combinations (true/true, true/false, false/true,
// false/false) and call apply(0).
// NOTE(review): the function name line, part of the parameter list, the
// definition of N and the if/else conditions selecting each group are on
// lines missing from this listing.
222 template <
typename Float,
typename Gauge,
typename Mom>
224 double dt,
const GaugeField &meta,
bool conj_mom,
bool exact,
// conj_mom = true, exact = true instantiation.
231 UpdateGaugeArg<Float, Gauge, Mom>
arg(
out,
in, mom, dt, 4);
232 UpdateGaugeField<Float,Gauge,Mom,N,true,true> updateGauge(
arg, meta, location);
233 updateGauge.apply(0);
// conj_mom = true, exact = false instantiation.
235 UpdateGaugeArg<Float, Gauge, Mom>
arg(
out,
in, mom, dt, 4);
236 UpdateGaugeField<Float,Gauge,Mom,N,true,false> updateGauge(
arg, meta, location);
237 updateGauge.apply(0);
// conj_mom = false, exact = true instantiation.
241 UpdateGaugeArg<Float, Gauge, Mom>
arg(
out,
in, mom, dt, 4);
242 UpdateGaugeField<Float,Gauge,Mom,N,false,true> updateGauge(
arg, meta, location);
243 updateGauge.apply(0);
// conj_mom = false, exact = false instantiation.
245 UpdateGaugeArg<Float, Gauge, Mom>
arg(
out,
in, mom, dt, 4);
246 UpdateGaugeField<Float,Gauge,Mom,N,false,false> updateGauge(
arg, meta, location);
247 updateGauge.apply(0);
// Dispatch on the momentum field: wraps mom either in
// gauge::FloatNOrder<Float,18,2,11> or in gauge::MILCOrder<Float,10> and
// forwards to the next updateGaugeField overload, passing mom itself in the
// position after dt. Two errorQuda calls report an unsupported
// reconstruction type and an unsupported field order (printing mom.Order()).
// NOTE(review): the function name line, part of the parameter list and the
// conditions that select each branch are on lines missing from this listing.
255 template <
typename Float,
typename Gauge>
257 double dt,
bool conj_mom,
bool exact,
262 updateGaugeField<Float>(
out,
in, gauge::FloatNOrder<Float,18,2,11>(mom), dt, mom, conj_mom, exact, location);
264 errorQuda(
"Reconstruction type not supported");
267 updateGaugeField<Float>(
out,
in, gauge::MILCOrder<Float,10>(mom), dt, mom, conj_mom, exact, location);
269 errorQuda(
"Gauge Field order %d not supported", mom.Order());
// Dispatch on the gauge field: first reports an error when out and in differ
// in Order() or Reconstruct(), then wraps out and in in an accessor type and
// forwards to the next updateGaugeField overload. Visible accessor choices:
// gauge_mapper<Float,QUDA_RECONSTRUCT_NO>::type,
// gauge_mapper<Float,QUDA_RECONSTRUCT_12>::type and
// gauge::MILCOrder<Float, Nc*Nc*2>.
// NOTE(review): the function name line, part of the parameter list and the
// conditions that select each branch are on lines missing from this listing.
274 template <
typename Float>
276 double dt,
bool conj_mom,
bool exact,
// Consistency check between the output and input fields.
283 if (
out.Order() !=
in.Order() ||
out.Reconstruct() !=
in.Reconstruct()) {
284 errorQuda(
"Input and output gauge field ordering and reconstruction must match");
289 typedef typename gauge_mapper<Float,QUDA_RECONSTRUCT_NO>::type G;
290 updateGaugeField<Float>(G(
out),G(
in), mom, dt, conj_mom, exact, location);
292 typedef typename gauge_mapper<Float,QUDA_RECONSTRUCT_12>::type G;
293 updateGaugeField<Float>(G(
out), G(
in), mom, dt, conj_mom, exact, location);
295 errorQuda(
"Reconstruction type not supported");
298 updateGaugeField<Float>(gauge::MILCOrder<Float, Nc*Nc*2>(
out),
299 gauge::MILCOrder<Float, Nc*Nc*2>(
in),
300 mom, dt, conj_mom, exact, location);
302 errorQuda(
"Gauge Field order %d not supported",
out.Order());
// Tail of the public updateGaugeField entry point. Inside the
// GPU_GAUGE_TOOLS guard the visible lines report errors for mismatched
// precision and mismatched location between the gauge and momentum fields,
// and forward to the double-precision instantiation using out.Location().
// NOTE(review): the start of the signature, the conditions guarding each
// errorQuda call, any other precision branches and the matching #else/#endif
// are on lines missing from this listing.
309 const GaugeField& mom,
bool conj_mom,
bool exact)
311 #ifdef GPU_GAUGE_TOOLS 313 errorQuda(
"Gauge and momentum fields must have matching precision");
316 errorQuda(
"Gauge and momentum fields must have matching location");
319 updateGaugeField<double>(
out,
in, mom, dt, conj_mom, exact,
out.
Location());
QudaVerbosity getVerbosity()
__host__ __device__ ValueType exp(ValueType x)
__host__ __device__ ValueType sqrt(ValueType x)
std::complex< double > Complex
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
Main header file for host and device accessors to GaugeFields.
QudaFieldLocation Location() const
__host__ __device__ ValueType log(ValueType x)
__device__ __host__ T getTrace(const Matrix< T, 3 > &a)
enum QudaFieldLocation_s QudaFieldLocation
cpuColorSpinorField * out
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
void updateGaugeField(GaugeField &out, double dt, const GaugeField &in, const GaugeField &mom, bool conj_mom, bool exact)
__host__ __device__ ValueType conj(ValueType x)
QudaTune getTuning()
Query whether autotuning is enabled or not. Default is enabled but can be overridden by setting QUDA_...
QudaPrecision Precision() const