14 #ifdef GPU_GAUGE_TOOLS
// Argument bundle for the gauge-field update. It is passed by value to
// updateGaugeFieldKernel and by reference to updateGaugeFieldCompute.
// Judging from the constructor's initializer list it holds:
//   out      - accessor the updated links are saved through
//   in       - accessor the current links are loaded from
//   momentum - accessor the momentum field is loaded from
//   dt       - step size (in the "real" scalar type)
//   nDim     - number of link directions processed per site
// NOTE(review): the member declarations themselves are not visible in this
// excerpt; confirm their types (value vs reference) against the full file.
16 template <
typename Complex,
typename Gauge,
typename Mom>
17 struct UpdateGaugeArg {
// Real scalar type associated with Complex, as mapped by RealTypeId.
18 typedef typename RealTypeId<Complex>::Type real;
// Initializes every member directly from the argument of the same name.
24 UpdateGaugeArg(
const Gauge &
out,
const Gauge &
in,
25 const Mom &momentum, real dt,
int nDim)
26 : out(out), in(in), momentum(momentum), dt(dt), nDim(nDim) { }
// Matrix helper operating in place on the 3x3 matrix q, addressed as q(0)..q(8).
// NOTE(review): the signature lines are missing from this excerpt. This is
// presumably the expsu3 helper that updateGaugeFieldCompute invokes as
// expsu3<complex<real> >(mom, ...), i.e. it replaces q by its exponential.
// Confirm against the full source.
// The index pattern in the determinant below is consistent with row-major
// storage, q(3*row + col).
32 template <
typename complex,
typename Cmplx>
34 typedef typename RealTypeId<Cmplx>::Type real;
// a2 = -(sum of the three 2x2 principal minors of q) / 3.
36 complex a2 = (q(3)*q(1)+q(7)*q(5)+q(6)*q(2) -
37 (q(0)*q(4)+(q(0)+q(4))*q(8))) / (real)3.0 ;
// a3 = det(q) (expanded 3x3 determinant).
38 complex a3 = q(0)*q(4)*q(8) + q(1)*q(5)*q(6) + q(2)*q(3)*q(7) -
39 q(6)*q(4)*q(2) - q(3)*q(1)*q(8) - q(0)*q(7)*q(5);
// Three vectors (1, wr2k, wr3k), k = 1..3, one per value of al.
// wr3k is chosen so that q(0) + wr2k*q(3) + wr3k*q(6) == al.
// al, z1 and z2 are (re)assigned on lines not visible here, which is
// presumably once per eigenvalue. The three formulas are textually identical.
// NOTE(review): divides by q(6) and by (z2+al*q(6)) with no visible guard
// against these vanishing; confirm this is handled upstream.
57 complex wr21 = (z1+al*q(7)) / (z2+al*q(6));
58 complex wr31 = (al-q(0)-wr21*q(3))/q(6);
61 complex wr22 = (z1+al*q(7))/(z2+al*q(6));
62 complex wr32 = (al-q(0)-wr22*q(3))/q(6);
65 complex wr23 = (z1+al*q(7))/(z2+al*q(6));
66 complex wr33 = (al-q(0)-wr23*q(3))/q(6);
// New z1/z2 built from 2x2 cross products of q's entries, presumably for the
// wl components computed on the lines not visible here.
68 z1=q(3)*q(2) - q(0)*q(5);
69 z2=q(1)*q(5) - q(4)*q(2);
// Overwrite q in place.
// Row 0 is the plain sums of the y coefficients.
// Rows 1 and 2 weight the same y terms by conj(wl2k) and conj(wl3k).
// The y and wl values are computed on lines not visible in this excerpt.
99 q(0) = y11 + y12 + y13;
100 q(1) = y21 + y22 + y23;
101 q(2) = y31 + y32 + y33;
102 q(3) = y11*
conj(wl21) + y12*
conj(wl22) + y13*
conj(wl23);
103 q(4) = y21*
conj(wl21) + y22*
conj(wl22) + y23*
conj(wl23);
104 q(5) = y31*
conj(wl21) + y32*
conj(wl22) + y33*
conj(wl23);
105 q(6) = y11*
conj(wl31) + y12*
conj(wl32) + y13*
conj(wl33);
106 q(7) = y21*
conj(wl31) + y22*
conj(wl32) + y23*
conj(wl33);
107 q(8) = y31*
conj(wl31) + y32*
conj(wl32) + y33*
conj(wl33);
// In-place 3x3 product helper; its name and signature lines are not visible
// in this excerpt.
// Reading a(i), b(i) as row-major 3x3 matrices (i = 3*row + col), each entry
// is c(i,j) = sum_k b(i,k) * conj(a(j,k)).
// That is, c = b * a^dagger, and c is then copied back, so b <- b * a^dagger.
// The temporary c (declared on a line not visible here) is needed because
// every entry reads the original row of b.
114 template <
typename complex,
typename Cmplx>
117 c[0] =
conj(a(0))*b(0)+
conj(a(1))*b(1) +
conj(a(2))*b(2);
118 c[3] =
conj(a(0))*b(3)+
conj(a(1))*b(4) +
conj(a(2))*b(5);
119 c[6] =
conj(a(0))*b(6)+
conj(a(1))*b(7) +
conj(a(2))*b(8);
120 c[1] =
conj(a(3))*b(0)+
conj(a(4))*b(1) +
conj(a(5))*b(2);
121 c[4] =
conj(a(3))*b(3)+
conj(a(4))*b(4) +
conj(a(5))*b(5);
122 c[7] =
conj(a(3))*b(6)+
conj(a(4))*b(7) +
conj(a(5))*b(8);
123 c[2] =
conj(a(6))*b(0)+
conj(a(7))*b(1) +
conj(a(8))*b(2);
124 c[5] =
conj(a(6))*b(3)+
conj(a(7))*b(4) +
conj(a(8))*b(5);
125 c[8] =
conj(a(6))*b(6)+
conj(a(7))*b(7) +
conj(a(8))*b(8);
// Write the product back into b only after all nine entries are formed.
126 for (
int i=0; i<9; i++) b(i) = c[i];
// Per-site gauge update shared by the host loop and the CUDA kernel.
// For checkerboard site x of the given parity, and for each direction
// dir < arg.nDim, it:
//   1. loads the link from arg.in and the momentum from arg.momentum;
//   2. subtracts trace/3 from each diagonal entry of mom;
//   3. multiplies the link by an exponential of dt*mom (or of conj(mom));
//   4. saves the result through arg.out.
// Template parameters:
//   N        - order of the truncated series in the non-exact path
//   conj_mom - use conj(mom) instead of mom
//   exact    - use expsu3 instead of the truncated series
// NOTE(review): the declarations of link/mom/result, the trace computation
// and the conj_mom/exact branch conditions are on lines not visible here.
129 template<
typename Cmplx,
typename Gauge,
typename Mom,
int N,
130 bool conj_mom,
bool exact>
131 __device__ __host__
void updateGaugeFieldCompute
132 (UpdateGaugeArg<Cmplx,Gauge,Mom> &
arg,
int x,
int parity) {
134 typedef typename RealTypeId<Cmplx>::Type real;
136 for(
int dir=0; dir<arg.nDim; ++dir){
137 arg.in.load((real*)(link.data), x, dir, parity);
138 arg.momentum.load((real*)(mom.data), x, dir, parity);
// Remove trace/3 from the diagonal. If trace == Tr(mom) (computed on a line
// not visible here), this makes mom traceless.
// NOTE(review): 3.0 is a double literal; the expsu3 helper uses (real)3.0.
// Consider matching that to keep single-precision arithmetic in float.
141 mom(0,0) -= trace/3.0;
142 mom(1,1) -= trace/3.0;
143 mom(2,2) -= trace/3.0;
// Horner evaluation: for r = N..1, result <- link + (dt/r) * M * result.
// If result starts equal to link (initialization not visible), this gives
//   sum_{k=0..N} (dt*M)^k / k! * link,
// the order-N truncated Taylor series of exp(dt*M) applied to link.
// Here M = mom.
150 for(
int r=N; r>0; r--)
151 result = (arg.dt/r)*mom*result + link;
// Same series with M = conj(mom), presumably the conj_mom branch.
153 for(
int r=N; r>0; r--)
154 result = (arg.dt/r)*
conj(mom)*result + link;
// Exact path: hand mom to expsu3, presumably exponentiating it in place.
// The purpose of the x+dir+parity argument is not visible here.
158 expsu3<complex<real> >(mom, x+dir+
parity);
// Apply conj(mom) on the left of the link (presumably exact && conj_mom).
163 link =
conj(mom) * link;
170 arg.out.save((real*)(result.data), x, dir, parity);
// Host (CPU) execution path: serially visits both parities and every
// checkerboard site x < arg.out.volumeCB, and calls updateGaugeFieldCompute
// for each.
// NOTE(review): the function signature (original lines 177-178) is not
// visible. This is presumably the host overload updateGaugeField(arg) that
// UpdateGaugeField::apply calls on the non-GPU path.
175 template<
typename Cmplx,
typename Gauge,
typename Mom,
int N,
176 bool conj_mom,
bool exact>
179 for (
unsigned int parity=0; parity<2; parity++) {
180 for (
unsigned int x=0; x<arg.out.volumeCB; x++) {
181 updateGaugeFieldCompute<Cmplx,Gauge,Mom,N,conj_mom,exact>
// Device execution path: one thread per (checkerboard site, parity) pair.
// The flat 1D index covers [0, 2*volumeCB). The first half maps to parity 0
// and the second half to parity 1, and excess threads in the last block exit
// early. arg is passed to the kernel by value.
187 template<
typename Cmplx,
typename Gauge,
typename Mom,
int N,
188 bool conj_mom,
bool exact>
189 __global__
void updateGaugeFieldKernel(UpdateGaugeArg<Cmplx,Gauge,Mom> arg) {
190 int idx = blockIdx.x*blockDim.x + threadIdx.x;
// Bounds guard: the grid is rounded up to whole blocks.
191 if (idx >= 2*arg.out.volumeCB)
return;
// Split the flat index into parity and a per-parity site index.
192 int parity = (idx >= arg.out.volumeCB) ? 1 : 0;
193 idx -= parity*arg.out.volumeCB;
195 updateGaugeFieldCompute<Cmplx,Gauge,Mom,N,conj_mom,exact>
// Tunable launcher for the gauge-field update.
// apply() runs either updateGaugeFieldKernel on the device or the host
// updateGaugeField loop. The branch on location and the tuneLaunch call are on
// lines not visible in this excerpt.
// Launch constraints reported to the autotuner:
//   - no shared memory
//   - at least 2*volumeCB threads (one per site and parity)
//   - grid size not tuned independently
199 template <
typename Complex,
typename Gauge,
typename Mom,
int N,
200 bool conj_mom,
bool exact>
201 class UpdateGaugeField :
public Tunable {
203 UpdateGaugeArg<Complex,Gauge,Mom>
arg;
204 const GaugeField &meta;
207 unsigned int sharedBytesPerThread()
const {
return 0; }
208 unsigned int sharedBytesPerBlock(
const TuneParam &)
const {
return 0; }
// One thread per checkerboard site for each of the two parities.
210 unsigned int minThreads()
const {
return 2*arg.in.volumeCB; }
211 bool tuneGridDim()
const {
return false; }
214 UpdateGaugeField(
const UpdateGaugeArg<Complex,Gauge,Mom> &arg,
216 : arg(arg), meta(meta), location(location) {
// The aux string distinguishes tune-cache entries by thread count,
// precision and stride.
217 writeAuxString(
"threads=%d,prec=%lu,stride=%d",
218 2*arg.in.volumeCB,
sizeof(
Complex)/2, arg.in.stride);
220 virtual ~UpdateGaugeField() { }
222 void apply(
const cudaStream_t &
stream){
224 #if __COMPUTE_CAPABILITY__ >= 200
// Launch on the caller's stream rather than implicitly on the default
// stream, so apply(stream) honours its argument.
226 updateGaugeFieldKernel<Complex,Gauge,Mom,N,conj_mom,exact>
227 <<<tp.grid,tp.block,tp.shared_bytes, stream>>>(
arg);
229 errorQuda(
"Not supported on pre-Fermi architecture");
232 updateGaugeField<Complex,Gauge,Mom,N,conj_mom,exact>(
arg);
// Promote to long long before multiplying. The lattice volume times nDim*2*N
// times per-site flops easily exceeds INT_MAX for production lattices.
239 long long flops()
const {
241 return (long long)arg.nDim*2*arg.in.volumeCB*N*(Nc*Nc*2 +
242 (8*Nc*Nc*Nc - 2*Nc*Nc) +
// Bytes moved: read in + write out + read momentum, per link.
245 long long bytes()
const {
return (long long)arg.nDim*2*arg.in.volumeCB*
246 (arg.in.Bytes() + arg.out.Bytes() + arg.momentum.Bytes()); }
248 TuneKey tuneKey()
const {
return TuneKey(meta.VolString(),
typeid(*this).name(), aux); }
// Maps the runtime flags conj_mom and exact onto the compile-time template
// parameters of UpdateGaugeField.
// Each of the four branches:
//   - builds an UpdateGaugeArg with nDim fixed to 4;
//   - runs the update synchronously from the caller's view via apply(0),
//     i.e. on stream 0.
// The first two branches instantiate conj_mom=true and the last two
// conj_mom=false. Within each pair, exact=true comes first.
// NOTE(review): the function name/leading parameters, the definition of N and
// the if/else conditions are on lines not visible in this excerpt.
251 template <
typename Float,
typename Gauge,
typename Mom>
253 double dt,
const GaugeField &meta,
bool conj_mom,
bool exact,
258 typedef typename ComplexTypeId<Float>::Type
Complex;
261 UpdateGaugeArg<Complex, Gauge, Mom>
arg(out, in, mom, dt, 4);
262 UpdateGaugeField<Complex,Gauge,Mom,N,true,true> updateGauge(arg, meta, location);
263 updateGauge.apply(0);
265 UpdateGaugeArg<Complex, Gauge, Mom>
arg(out, in, mom, dt, 4);
266 UpdateGaugeField<Complex,Gauge,Mom,N,true,false> updateGauge(arg, meta, location);
267 updateGauge.apply(0);
271 UpdateGaugeArg<Complex, Gauge, Mom>
arg(out, in, mom, dt, 4);
272 UpdateGaugeField<Complex,Gauge,Mom,N,false,true> updateGauge(arg, meta, location);
273 updateGauge.apply(0);
275 UpdateGaugeArg<Complex, Gauge, Mom>
arg(out, in, mom, dt, 4);
276 UpdateGaugeField<Complex,Gauge,Mom,N,false,false> updateGauge(arg, meta, location);
277 updateGauge.apply(0);
// Wraps the momentum field in a typed accessor matching its storage layout,
// then forwards to the (Float, Gauge, Mom) overload.
// The two layouts handled are FloatNOrder<Float,18,2,11> and
// MILCOrder<Float,10>. Any other layout raises errorQuda.
// The momentum GaugeField itself is passed as the "meta" argument, which the
// tuner uses for its volume string.
// NOTE(review): the leading parameters and the order/reconstruct conditions
// selecting each branch are on lines not visible in this excerpt.
285 template <
typename Float,
typename Gauge>
287 double dt,
bool conj_mom,
bool exact,
292 updateGaugeField<Float>(
out,
in, FloatNOrder<Float,18,2,11>(mom), dt, mom, conj_mom, exact, location);
294 errorQuda(
"Reconstruction type not supported");
297 updateGaugeField<Float>(
out,
in, MILCOrder<Float,10>(mom), dt, mom, conj_mom, exact, location);
299 errorQuda(
"Gauge Field order %d not supported", mom.Order());
// Validates the gauge fields, then wraps out/in in typed accessors matching
// their shared order and reconstruction before forwarding the update.
// Preconditions enforced here:
//   - out.Ncolor() == Nc
//   - out and in have identical Order() and Reconstruct()
// Unsupported order/reconstruct combinations raise errorQuda. Each message
// reports the offending enum value.
// NOTE(review): the trailing parameter(s) and the branch conditions selecting
// each accessor are on lines not visible in this excerpt.
304 template <
typename Float>
305 void updateGaugeField(GaugeField &out,
const GaugeField &in,
const GaugeField &mom,
306 double dt,
bool conj_mom,
bool exact,
310 if (out.Ncolor() != Nc)
311 errorQuda(
"Ncolor=%d not supported at this time", out.Ncolor());
313 if (out.Order() != in.Order() || out.Reconstruct() != in.Reconstruct()) {
314 errorQuda(
"Input and output gauge field ordering and reconstruction must match");
319 updateGaugeField<Float>(FloatNOrder<Float, Nc*Nc*2, 2, 18>(
out),
320 FloatNOrder<Float, Nc*Nc*2, 2, 18>(in),
321 mom, dt, conj_mom, exact,
location);
323 updateGaugeField<Float>(FloatNOrder<Float, Nc*Nc*2, 2, 12>(
out),
324 FloatNOrder<Float, Nc*Nc*2, 2, 12>(in),
325 mom, dt, conj_mom, exact,
location);
// Report which reconstruction was rejected, to aid diagnosis.
327 errorQuda(
"Reconstruction type %d not supported", out.Reconstruct());
331 updateGaugeField<Float>(FloatNOrder<Float, Nc*Nc*2, 4, 12>(
out),
332 FloatNOrder<Float, Nc*Nc*2, 4, 12>(in),
333 mom, dt, conj_mom, exact,
location);
// The message names the reconstruction type, so print Reconstruct(),
// not Order().
335 errorQuda(
"Reconstruction type %d not supported", out.Reconstruct());
338 updateGaugeField<Float>(MILCOrder<Float, Nc*Nc*2>(
out),
339 MILCOrder<Float, Nc*Nc*2>(in),
340 mom, dt, conj_mom, exact,
location);
342 errorQuda(
"Gauge Field order %d not supported", out.Order());
// Public entry point (tail of its parameter list shown here). The gauge and
// momentum fields must match in precision and in location; otherwise
// errorQuda is raised.
// It then dispatches to the double or float instantiation and runs the update
// at out.Location().
// NOTE(review): the first parameters, the precision/location tests and the
// #else/#endif for GPU_GAUGE_TOOLS are on lines not visible in this excerpt.
349 const GaugeField& mom,
bool conj_mom,
bool exact)
351 #ifdef GPU_GAUGE_TOOLS
353 errorQuda(
"Gauge and momentum fields must have matching precision");
356 errorQuda(
"Gauge and momentum fields must have matching location");
359 updateGaugeField<double>(
out,
in, mom, dt, conj_mom, exact, out.
Location());
361 updateGaugeField<float>(
out,
in, mom, dt, conj_mom, exact, out.
Location());
QudaVerbosity getVerbosity()
__host__ __device__ ValueType exp(ValueType x)
QudaFieldLocation Location() const
__host__ __device__ ValueType sqrt(ValueType x)
std::complex< double > Complex
QudaPrecision Precision() const
const QudaFieldLocation location
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
__host__ __device__ ValueType log(ValueType x)
__device__ __host__ T getTrace(const Matrix< T, 3 > &a)
enum QudaFieldLocation_s QudaFieldLocation
cpuColorSpinorField * out
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
void updateGaugeField(GaugeField &out, double dt, const GaugeField &in, const GaugeField &mom, bool conj_mom, bool exact)
__host__ __device__ ValueType conj(ValueType x)