12 #ifdef GPU_GAUGE_TOOLS
// Argument bundle for applying staggered phases to a gauge field.
// Carries the field accessor `order`, the four lattice extents X[0..3],
// the site count `volume`, and the temporal boundary factor `tBoundary`
// (multiplied into dim-3 links on the last t slice by getPhase).
// NOTE(review): this listing is a source-browser extract -- the embedded
// line numbers (14, 15, 20, 23, 32, ...) jump, so member declarations, the
// constructor initializer list/body and closing braces are missing here.
// Restore them from the real file before editing code.
14 template <
typename Float,
typename Order>
15 struct GaugePhaseArg {
// Builds the argument set from an accessor, the lattice extents X_ (4 ints)
// and the boundary condition tBoundary_.
20 GaugePhaseArg(
const Order &order,
const int *X_,
QudaTboundary tBoundary_)
// Copy the four lattice extents (loop body is on lines missing here).
23 for (
int d=0; d<4; d++) {
// NOTE(review): hard-coded to true in the visible text; in a multi-node run
// presumably only the node owning the last t slice should apply tBoundary --
// confirm against the missing lines 33-34.
32 bool last_node_in_t =
true;
// Debug output of the boundary factor per rank. NOTE(review): getVerbosity()
// appears in this file's symbol index, so this is presumably guarded by a
// verbosity check on a missing line -- confirm it does not print unconditionally.
35 printf(
"node=%d Tboundary = %e\n",
comm_rank(), tBoundary);
// Copy constructor: copies accessor, boundary factor, volume and all four extents.
37 GaugePhaseArg(
const GaugePhaseArg &
arg)
38 : order(arg.order), tBoundary(arg.tBoundary), volume(arg.volume) {
39 for (
int d=0; d<4; d++)
X[d] = arg.X[d];
// Returns the staggered-fermion sign for the link in direction `dim` at site
// (x, y, z, t) under the phase convention `phaseType`.
// Every visible formula has the form 1 - 2*(k % 2): +1 when k is even, -1 when
// k is odd (coordinates are non-negative, so k % 2 is 0 or 1).
// For dim == 3, links on the last t slice (t == arg.X[3]-1) also pick up
// arg.tBoundary; elsewhere that factor is 1.
// NOTE(review): three branch groups are visible, but their phaseType
// conditions (and the `phase` declaration) sit on lines missing from this
// listing (48-50, 58-60, 68-71) -- presumably one group per supported
// QudaStaggeredPhase; confirm.
// NOTE(review): the literals 1.0 / 2.0 are double, so with Float = float the
// arithmetic is promoted to double in device code; consider Float(1)/Float(2).
46 template <
int dim,
typename Float, QudaStaggeredPhase phaseType,
typename Arg>
47 __device__ __host__
Float getPhase(
int x,
int y,
int z,
int t, Arg &
arg) {
// First convention (condition on missing lines): dim 0 sign depends on t only.
51 phase = (1.0 - 2.0 * (t % 2) );
52 }
else if (
dim == 1) {
53 phase = (1.0 - 2.0 * ((t +
x) % 2) );
54 }
else if (
dim == 2) {
55 phase = (1.0 - 2.0 * ((t + x +
y) % 2) );
56 }
else if (
dim == 3) {
// Temporal links: boundary factor on the last t slice only.
57 phase = (t == arg.X[3]-1) ? arg.tBoundary : 1.0;
// Second convention (condition on missing lines 58-60); first visible case
// is presumably dim 0 since the next test is dim == 1.
61 phase = (1.0 - 2.0 * ((3 + t + z +
y) % 2) );
62 }
else if (
dim == 1) {
63 phase = (1.0 - 2.0 * ((2 + t + z) % 2) );
64 }
else if (
dim == 2) {
65 phase = (1.0 - 2.0 * ((1 + t) % 2) );
66 }
else if (
dim == 3) {
67 phase = (t == arg.X[3]-1) ? arg.tBoundary : 1.0;
// Third convention: its opening condition and dim-0 case are on missing
// lines 68-71.
72 }
else if (
dim == 1) {
73 phase = (1.0 - 2.0 * ((1 +
x) % 2) );
74 }
else if (
dim == 2) {
75 phase = (1.0 - 2.0 * ((1 + x +
y) % 2) );
76 }
else if (
dim == 3) {
// Here the temporal link carries both the boundary factor and a
// spatial-parity sign.
77 phase = ((t == arg.X[3]-1) ? arg.tBoundary : 1.0) *
78 (1.0 - 2 * ((1 + x + y + z) % 2) );
// Applies the staggered phase to one link: loads the `length` components of
// the direction-`dim` link at checkerboard site (xh, y, z, t, parity) through
// arg.order, multiplies each by getPhase<dim,...>(...), and saves it back in
// place.
// NOTE(review): `x` (the full x coordinate passed to getPhase) and the
// buffer `u` are declared on lines 88 and 90-92, which are missing from this
// listing; `x` is presumably rebuilt from xh and parity -- confirm.
// `RegType` is not used by any visible line (presumably the type of `u`).
84 template <
typename Float,
int length, QudaStaggeredPhase phaseType,
int dim,
typename Arg>
85 __device__ __host__
void gaugePhase(
int xh,
int y,
int z,
int t,
int parity, Arg &arg) {
86 typedef typename mapper<Float>::type RegType;
// Lexicographic site index with the x extent halved (X[0]>>1), xh fastest.
87 int indexCB = ((t*arg.X[2] + z)*arg.X[1] + y)*(arg.X[0]>>1) + xh;
89 Float phase = getPhase<dim,Float,phaseType>(
x,
y, z, t,
arg);
// Read-modify-write of the link through the field accessor.
93 arg.order.load(u, indexCB,
dim, parity);
94 for (
int i=0; i<
length; i++) u[i] *= phase;
95 arg.order.save(u, indexCB,
dim, parity);
// Host (CPU) path: serially visits every checkerboard site of both parities,
// with xh innermost (matching the indexCB layout), and applies the staggered
// phase to the links in all four directions.
// NOTE(review): closing braces of the loop nest are missing from this listing.
102 template <
typename Float,
int length, QudaStaggeredPhase phaseType,
typename Arg>
103 void gaugePhase(Arg &arg) {
104 for (
int parity=0; parity<2; parity++) {
105 for (
int t=0; t<arg.X[3]; t++) {
106 for (
int z=0; z<arg.X[2]; z++) {
107 for (
int y=0; y<arg.X[1]; y++) {
108 for (
int xh=0; xh<arg.X[0]>>1; xh++) {
// One call per link direction; dim is a template parameter so each
// direction's getPhase branch resolves at compile time.
109 gaugePhase<Float,length,phaseType,0>(xh,
y, z, t,
parity,
arg);
110 gaugePhase<Float,length,phaseType,1>(xh,
y, z, t,
parity,
arg);
111 gaugePhase<Float,length,phaseType,2>(xh,
y, z, t,
parity,
arg);
112 gaugePhase<Float,length,phaseType,3>(xh,
y, z, t,
parity,
arg);
// Device path: one thread per checkerboard site along x (guarded against
// X >= volume/2), with blockIdx.y selecting the parity. Each thread decomposes
// its flat index into (xh, y, z, t), xh fastest (the same order indexCB uses
// in gaugePhase), and applies the phase to all four link directions.
// NOTE(review): correctness needs a launch with gridDim.y == 2, one y-block
// per parity. The lines that set this up (in GaugePhase::apply) are missing
// from this listing -- confirm.
124 template <
typename Float,
int length, QudaStaggeredPhase phaseType,
typename Arg>
125 __global__
void gaugePhaseKernel(Arg arg) {
126 int X = blockIdx.x * blockDim.x + threadIdx.x;
// Tail guard: the x grid generally overshoots volume/2.
127 if (X >= (arg.volume>>1))
return;
128 int parity = blockIdx.y;
// Peel off coordinates from fastest (xh) to slowest (t).
130 int tzy = X / (arg.X[0]>>1);
131 int xh = X - tzy*(arg.X[0]>>1);
132 int tz = tzy / arg.X[1];
133 int y = tzy - tz*arg.X[1];
134 int t = tz / arg.X[2];
135 int z = tz - t * arg.X[2];
136 gaugePhase<Float,length,phaseType,0>(xh,
y, z, t,
parity,
arg);
137 gaugePhase<Float,length,phaseType,1>(xh,
y, z, t,
parity,
arg);
138 gaugePhase<Float,length,phaseType,2>(xh,
y, z, t,
parity,
arg);
139 gaugePhase<Float,length,phaseType,3>(xh,
y, z, t,
parity,
arg);
// Tunable launcher for the staggered-phase application. apply() either
// launches gaugePhaseKernel or runs the serial host gaugePhase loop. The
// location test that chooses between them is on lines missing from this
// listing (163-168).
// Tuning: no shared memory, grid dimension is not tuned, and minThreads is
// volume/2 (one thread per checkerboard site).
// NOTE(review): `class GaugePhase : Tunable` uses private inheritance by
// default (class, not struct); confirm that is intended for a Tunable.
// NOTE(review): member declarations (arg, location), the constructor
// signature and several method headers are missing from this listing.
142 template <
typename Float,
int length, QudaStaggeredPhase phaseType,
typename Arg>
143 class GaugePhase : Tunable {
145 const GaugeField &meta;
149 unsigned int sharedBytesPerThread()
const {
return 0; }
150 unsigned int sharedBytesPerBlock(
const TuneParam &
param)
const {
return 0 ;}
// Grid is sized from minThreads rather than tuned.
152 bool tuneGridDim()
const {
return false; }
153 unsigned int minThreads()
const {
return arg.volume>>1; }
157 : arg(arg), meta(meta), location(location) {
// NOTE(review): sizeof yields size_t; %zu is the portable specifier rather
// than %lu. Also confirm arg.order.stride is an int for %d.
158 writeAuxString(
"stride=%d,prec=%lu",arg.order.stride,
sizeof(
Float));
160 virtual ~GaugePhase() { ; }
162 void apply(
const cudaStream_t &
stream) {
// Device branch. NOTE(review): no launch-error check (e.g.
// cudaGetLastError) is visible after this launch.
166 gaugePhaseKernel<Float, length, phaseType, Arg>
167 <<<tp.grid, tp.block, tp.shared_bytes, stream>>>(
arg);
// Host branch: serial CPU loop over all sites.
169 gaugePhase<Float, length, phaseType, Arg>(
arg);
// Tune cache key: lattice volume string, concrete class type, and aux string.
173 TuneKey tuneKey()
const {
174 return TuneKey(meta.VolString(),
typeid(*this).name(), aux);
178 std::stringstream ps;
179 ps <<
"block=(" << param.block.x <<
"," << param.block.y <<
"," << param.block.z <<
"), ";
180 ps <<
"shared=" << param.shared_bytes;
// NOTE(review): reports 0 flops although each link component is multiplied
// once; this may understate the work in tuning reports.
184 long long flops()
const {
return 0; }
// Traffic estimate from volume and the accessor's Bytes(); the exact
// meaning of Bytes() is defined outside this file -- confirm.
185 long long bytes()
const {
return arg.volume * 2 * arg.order.Bytes(); }
// Dispatch on the field's staggered phase convention. Each visible branch
// builds a GaugePhaseArg from the accessor, u.X() and u.TBoundary(), then
// constructs a GaugePhase launcher with (arg, u, location).
// NOTE(review): the function signature (190-191), the branch conditions, and
// the GaugePhase template arguments (193/198/203, presumably differing only
// in phaseType) are missing from this listing. Also missing are the calls
// that run the launchers. The three visible branches are textually identical.
189 template <
typename Float,
int length,
typename Order>
192 GaugePhaseArg<Float,Order>
arg(order, u.X(), u.TBoundary());
194 GaugePhaseArg<Float,Order> > phase(arg, u, location);
197 GaugePhaseArg<Float,Order>
arg(order, u.X(), u.TBoundary());
199 GaugePhaseArg<Float,Order> > phase(arg, u, location);
202 GaugePhaseArg<Float,Order>
arg(order, u.X(), u.TBoundary());
204 GaugePhaseArg<Float,Order> > phase(arg, u, location);
// Selects the gauge-field accessor type for `u` and forwards to the
// accessor-templated gaugePhase. length = 18 reals per link (presumably a
// 3x3 complex matrix -- confirm).
// Visible paths:
//  - two FloatNOrder groups selected by reconstruction 19/18/12, erroring on
//    any other reconstruction;
//  - a TIFR-order path compiled only with BUILD_TIFR_INTERFACE;
//  - an error for any other field order.
// NOTE(review): the branch conditions and `location` declaration are on
// lines missing from this listing.
// NOTE(review): the second group uses third template argument 4 for
// reconstruct-12 but 1 for 19/18. This is presumably a deliberate
// per-precision choice -- confirm.
// NOTE(review): "recsontruction" in the error messages is a typo for
// "reconstruction"; fix it in a code change.
214 template <
typename Float>
215 void gaugePhase(GaugeField &u) {
216 const int length = 18;
224 gaugePhase<Float,length>(FloatNOrder<Float,length,2,19>(u), u, location);
226 gaugePhase<Float,length>(FloatNOrder<Float,length,2,18>(u), u, location);
229 gaugePhase<Float,length>(FloatNOrder<Float,length,2,12>(u), u, location);
231 errorQuda(
"Unsupported recsontruction type");
236 gaugePhase<Float,length>(FloatNOrder<Float,length,1,19>(u), u, location);
238 gaugePhase<Float,length>(FloatNOrder<Float,length,1,18>(u),u, location);
241 gaugePhase<Float,length>(FloatNOrder<Float,length,4,12>(u), u, location);
243 errorQuda(
"Unsupported recsontruction type");
// TIFR-ordered fields are only supported when the TIFR interface is built.
247 #ifdef BUILD_TIFR_INTERFACE
248 gaugePhase<Float,length>(TIFROrder<Float,length>(u), u, location);
250 errorQuda(
"TIFR interface has not been built\n");
254 errorQuda(
"Gauge field %d order not supported", u.Order());
// Body fragment of the public entry point (applyGaugePhase(GaugeField &u)
// per this file's symbol index). When GPU_GAUGE_TOOLS is defined it
// dispatches to gaugePhase<double> or gaugePhase<float>. NOTE(review): the
// precision test (presumably on u.Precision()), the #else error path and
// #endif are on lines missing from this listing -- confirm.
263 #ifdef GPU_GAUGE_TOOLS
265 gaugePhase<double>(u);
267 gaugePhase<float>(u);
QudaVerbosity getVerbosity()
void applyGaugePhase(GaugeField &u)
enum QudaTboundary_s QudaTboundary
QudaPrecision Precision() const
const QudaFieldLocation location
FloatingPoint< float > Float
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
enum QudaFieldLocation_s QudaFieldLocation
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.