QUDA  v0.7.0
A library for QCD on GPUs
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
gauge_phase.cu
Go to the documentation of this file.
1 #include <gauge_field_order.h>
2 #include <comm_quda.h>
3 
10 namespace quda {
11 
12 #ifdef GPU_GAUGE_TOOLS
13 
/**
   Argument bundle shared by the host loop and the CUDA kernel that apply
   staggered phases to a gauge field.

   Holds the field accessor, the local lattice extents, the local volume,
   and the factor that multiplies t-links on the last local time slice.
 */
template <typename Float, typename Order>
struct GaugePhaseArg {
  Order order;     // accessor used to load/save the links in place
  int X[4];        // local lattice extents (x, y, z, t)
  int volume;      // product of X[0..3]
  Float tBoundary; // factor for t-links leaving the last local time slice

  /**
     @param order      Field accessor (copied)
     @param X_         Local lattice extents, 4 entries
     @param tBoundary_ Requested temporal boundary condition
   */
  GaugePhaseArg(const Order &order, const int *X_, QudaTboundary tBoundary_)
    : order(order), volume(1) {
    for (int d = 0; d < 4; d++) {
      X[d] = X_[d];
      volume *= X[d];
    }

    // only set the boundary condition on the last time slice of nodes;
    // every other node in t uses the periodic factor
#ifdef MULTI_GPU
    bool last_node_in_t = (commCoords(3) == commDim(3)-1);
#else
    bool last_node_in_t = true;
#endif
    // NOTE(review): relies on the QudaTboundary enumerator values being the
    // numeric boundary factor (e.g. +1 periodic / -1 anti-periodic) — confirm
    tBoundary = (Float)(last_node_in_t ? tBoundary_ : QUDA_PERIODIC_T);

    // Per-node diagnostic. It used to print unconditionally on every rank for
    // every construction; keep it for debug runs only.
    if (getVerbosity() >= QUDA_DEBUG_VERBOSE)
      printf("node=%d Tboundary = %e\n", comm_rank(), (double)tBoundary);
  }

  // Member-wise copy; initializers listed in declaration order.
  GaugePhaseArg(const GaugePhaseArg &arg)
    : order(arg.order), volume(arg.volume), tBoundary(arg.tBoundary) {
    for (int d = 0; d < 4; d++) X[d] = arg.X[d];
  }
};
42 
43 
44 
/**
   Return the phase factor for the link in direction `dim` leaving the site
   with full coordinates (x, y, z, t). It is a staggered sign (+1/-1) for
   the selected convention, times arg.tBoundary for t-links on the last
   local time slice.

   The sign (-1)^n is computed in integer arithmetic and converted once to
   Float. This keeps double-precision instructions out of single-precision
   device code and gives bit-identical ±1 results.

   FIXME need to check this with odd local volumes
 */
template <int dim, typename Float, QudaStaggeredPhase phaseType, typename Arg>
__device__ __host__ Float getPhase(int x, int y, int z, int t, Arg &arg) {
  // boundary factor used for t-links (dim == 3) on the last local time slice
  const Float tFactor = (t == arg.X[3]-1) ? arg.tBoundary : static_cast<Float>(1);

  Float phase = static_cast<Float>(1);
  if (phaseType == QUDA_MILC_STAGGERED_PHASE) {
    if (dim == 0) {
      phase = static_cast<Float>(1 - 2 * (t % 2));
    } else if (dim == 1) {
      phase = static_cast<Float>(1 - 2 * ((t + x) % 2));
    } else if (dim == 2) {
      phase = static_cast<Float>(1 - 2 * ((t + x + y) % 2));
    } else if (dim == 3) { // also apply boundary condition
      phase = tFactor;
    }
  } else if (phaseType == QUDA_TIFR_STAGGERED_PHASE) {
    if (dim == 0) {
      phase = static_cast<Float>(1 - 2 * ((3 + t + z + y) % 2));
    } else if (dim == 1) {
      phase = static_cast<Float>(1 - 2 * ((2 + t + z) % 2));
    } else if (dim == 2) {
      phase = static_cast<Float>(1 - 2 * ((1 + t) % 2));
    } else if (dim == 3) { // also apply boundary condition
      phase = tFactor;
    }
  } else if (phaseType == QUDA_CPS_STAGGERED_PHASE) {
    if (dim == 0) {
      phase = static_cast<Float>(1);
    } else if (dim == 1) {
      phase = static_cast<Float>(1 - 2 * ((1 + x) % 2));
    } else if (dim == 2) {
      phase = static_cast<Float>(1 - 2 * ((1 + x + y) % 2));
    } else if (dim == 3) { // also apply boundary condition
      phase = tFactor * static_cast<Float>(1 - 2 * ((1 + x + y + z) % 2));
    }
  }
  return phase;
}
83 
/**
   Multiply, in place, the direction-`dim` link at one checkerboard site by
   its staggered phase (see getPhase).

   @param xh      Checkerboarded (half) x index, 0 <= xh < X[0]/2
   @param y,z,t   Full y, z, t coordinates
   @param parity  Site parity
   @param arg     Argument bundle (accessor, extents, boundary factor)
 */
template <typename Float, int length, QudaStaggeredPhase phaseType, int dim, typename Arg>
__device__ __host__ void gaugePhase(int xh, int y, int z, int t, int parity, Arg &arg) {
  typedef typename mapper<Float>::type RegType;
  int indexCB = ((t*arg.X[2] + z)*arg.X[1] + y)*(arg.X[0]>>1) + xh;

  // Reconstruct the full x coordinate. Under the even/odd convention
  // parity == (x+y+z+t) & 1 (assumed here, as elsewhere in QUDA), the low
  // bit of x is (parity + y + z + t) & 1. Using 2*xh + parity puts x on the
  // wrong sub-lattice whenever y+z+t is odd, which flips the x-dependent
  // MILC/CPS phases.
  int x = 2*xh + ((parity + y + z + t) & 1);

  Float phase = getPhase<dim,Float,phaseType>(x, y, z, t, arg);
  RegType u[length];
  arg.order.load(u, indexCB, dim, parity);
  for (int i = 0; i < length; i++) u[i] *= phase;
  arg.order.save(u, indexCB, dim, parity);
}
97 
/**
   Host (CPU) path: serially visit every checkerboard site of both parities
   and apply the phase to each of its four links.
 */
template <typename Float, int length, QudaStaggeredPhase phaseType, typename Arg>
void gaugePhase(Arg &arg) {
  const int halfX = arg.X[0] >> 1; // checkerboarded extent along x
  for (int par = 0; par < 2; ++par) {
    for (int tt = 0; tt < arg.X[3]; ++tt) {
      for (int zz = 0; zz < arg.X[2]; ++zz) {
        for (int yy = 0; yy < arg.X[1]; ++yy) {
          for (int hx = 0; hx < halfX; ++hx) {
            // one update per link direction, x..t
            gaugePhase<Float,length,phaseType,0>(hx, yy, zz, tt, par, arg);
            gaugePhase<Float,length,phaseType,1>(hx, yy, zz, tt, par, arg);
            gaugePhase<Float,length,phaseType,2>(hx, yy, zz, tt, par, arg);
            gaugePhase<Float,length,phaseType,3>(hx, yy, zz, tt, par, arg);
          }
        }
      }
    }
  }
}
119 
/**
   Device path: one thread per checkerboard site; blockIdx.y selects the
   parity, so the launch needs grid.y == 2. Threads whose flat x index is
   at or beyond volume/2 do nothing.
 */
template <typename Float, int length, QudaStaggeredPhase phaseType, typename Arg>
__global__ void gaugePhaseKernel(Arg arg) {
  const int cbIndex = blockIdx.x * blockDim.x + threadIdx.x;
  const int cbVolume = arg.volume >> 1;
  if (cbIndex < cbVolume) {
    const int parity = blockIdx.y;
    const int halfX = arg.X[0] >> 1;

    // peel coordinates off the checkerboard index, fastest (x) first
    const int xh = cbIndex % halfX;
    int rem = cbIndex / halfX;
    const int y = rem % arg.X[1];
    rem /= arg.X[1];
    const int z = rem % arg.X[2];
    const int t = rem / arg.X[2];

    gaugePhase<Float,length,phaseType,0>(xh, y, z, t, parity, arg);
    gaugePhase<Float,length,phaseType,1>(xh, y, z, t, parity, arg);
    gaugePhase<Float,length,phaseType,2>(xh, y, z, t, parity, arg);
    gaugePhase<Float,length,phaseType,3>(xh, y, z, t, parity, arg);
  }
}
141 
/**
   Tunable driver for the phase application. On a CUDA-resident field it
   launches gaugePhaseKernel with autotuned block size; otherwise it runs
   the serial host loop.
 */
template <typename Float, int length, QudaStaggeredPhase phaseType, typename Arg>
class GaugePhase : Tunable {
  Arg &arg;
  const GaugeField &meta;     // used for meta data only
  QudaFieldLocation location; // selects GPU kernel vs host loop in apply()

private:
  unsigned int sharedBytesPerThread() const { return 0; }
  unsigned int sharedBytesPerBlock(const TuneParam &param) const { return 0; }

  bool tuneGridDim() const { return false; } // Don't tune the grid dimensions.
  unsigned int minThreads() const { return arg.volume>>1; } // one thread per checkerboard site

public:
  GaugePhase(Arg &arg, const GaugeField &meta, QudaFieldLocation location)
    : arg(arg), meta(meta), location(location) {
    writeAuxString("stride=%d,prec=%lu",arg.order.stride,sizeof(Float));
  }
  virtual ~GaugePhase() { ; }

  void apply(const cudaStream_t &stream) {
    if (location == QUDA_CUDA_FIELD_LOCATION) {
      TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
      tp.grid.y = 2; // parity is the y grid dimension
      gaugePhaseKernel<Float, length, phaseType, Arg>
        <<<tp.grid, tp.block, tp.shared_bytes, stream>>>(arg);
    } else {
      gaugePhase<Float, length, phaseType, Arg>(arg);
    }
  }

  TuneKey tuneKey() const {
    return TuneKey(meta.VolString(), typeid(*this).name(), aux);
  }

  std::string paramString(const TuneParam &param) const { // Don't bother printing the grid dim.
    std::stringstream ps;
    ps << "block=(" << param.block.x << "," << param.block.y << "," << param.block.z << "), ";
    ps << "shared=" << param.shared_bytes;
    return ps.str();
  }

  long long flops() const { return 0; }
  // Every site loads and stores the link in each of the 4 directions:
  // volume * nDim * i/o * bytes per link.
  // NOTE(review): assumes order.Bytes() is the per-link footprint — confirm.
  long long bytes() const { return (long long)arg.volume * 4 * 2 * arg.order.Bytes(); }
};
187 
188 
/**
   Build the argument bundle for a fixed phase convention and run the
   phase application (GPU kernel or host loop, per `location`).
 */
template <typename Float, int length, QudaStaggeredPhase phaseType, typename Order>
void gaugePhaseApply(Order order, const GaugeField &u, QudaFieldLocation location) {
  typedef GaugePhaseArg<Float,Order> Arg;
  Arg arg(order, u.X(), u.TBoundary());
  GaugePhase<Float, length, phaseType, Arg> phase(arg, u, location);
  phase.apply(0);
}

/**
   Apply the staggered phases recorded in u.StaggeredPhase() through the
   accessor `order`. Reports a CUDA error check afterwards when the field
   is device resident.

   @param order     Accessor wrapping the links of u
   @param u         Field supplying extents, boundary condition and phase type
   @param location  Where the field lives
 */
template <typename Float, int length, typename Order>
void gaugePhase(Order order, const GaugeField &u, QudaFieldLocation location) {
  switch (u.StaggeredPhase()) {
  case QUDA_MILC_STAGGERED_PHASE:
    gaugePhaseApply<Float, length, QUDA_MILC_STAGGERED_PHASE>(order, u, location);
    break;
  case QUDA_CPS_STAGGERED_PHASE:
    gaugePhaseApply<Float, length, QUDA_CPS_STAGGERED_PHASE>(order, u, location);
    break;
  case QUDA_TIFR_STAGGERED_PHASE:
    gaugePhaseApply<Float, length, QUDA_TIFR_STAGGERED_PHASE>(order, u, location);
    break;
  default:
    errorQuda("Undefined phase type");
  }

  if (location == QUDA_CUDA_FIELD_LOCATION) checkCudaError();
}
212 
/**
   Select the field accessor that matches u's storage order and
   reconstruction, then apply the staggered phases in place.
   Supported: FLOAT2 / FLOAT4 orders with reconstruct-18 or -12, and the
   TIFR order when BUILD_TIFR_INTERFACE is defined.
 */
template <typename Float>
void gaugePhase(GaugeField &u) {
  const int length = 18;

  QudaFieldLocation location =
    (typeid(u)==typeid(cudaGaugeField)) ? QUDA_CUDA_FIELD_LOCATION : QUDA_CPU_FIELD_LOCATION;

  if (u.Order() == QUDA_FLOAT2_GAUGE_ORDER) {
    if (u.Reconstruct() == QUDA_RECONSTRUCT_NO) {
      // short-precision fat links use the 19-element variant
      if (typeid(Float)==typeid(short) && u.LinkType() == QUDA_ASQTAD_FAT_LINKS) {
        gaugePhase<Float,length>(FloatNOrder<Float,length,2,19>(u), u, location);
      } else {
        gaugePhase<Float,length>(FloatNOrder<Float,length,2,18>(u), u, location);
      }
    } else if (u.Reconstruct() == QUDA_RECONSTRUCT_12) {
      gaugePhase<Float,length>(FloatNOrder<Float,length,2,12>(u), u, location);
    } else {
      errorQuda("Unsupported reconstruction type %d", u.Reconstruct());
    }
  } else if (u.Order() == QUDA_FLOAT4_GAUGE_ORDER) {
    if (u.Reconstruct() == QUDA_RECONSTRUCT_NO) {
      if (typeid(Float)==typeid(short) && u.LinkType() == QUDA_ASQTAD_FAT_LINKS) {
        gaugePhase<Float,length>(FloatNOrder<Float,length,1,19>(u), u, location);
      } else {
        gaugePhase<Float,length>(FloatNOrder<Float,length,1,18>(u), u, location);
      }
    } else if (u.Reconstruct() == QUDA_RECONSTRUCT_12) {
      gaugePhase<Float,length>(FloatNOrder<Float,length,4,12>(u), u, location);
    } else {
      errorQuda("Unsupported reconstruction type %d", u.Reconstruct());
    }
  } else if (u.Order() == QUDA_TIFR_GAUGE_ORDER) {

#ifdef BUILD_TIFR_INTERFACE
    gaugePhase<Float,length>(TIFROrder<Float,length>(u), u, location);
#else
    errorQuda("TIFR interface has not been built");
#endif

  } else {
    errorQuda("Gauge field %d order not supported", u.Order());
  }

}
258 
259 #endif
260 
/**
   Multiply the links of u, in place, by the staggered phases of the
   convention recorded in the field, including the temporal boundary factor.
   Only double and single precision fields are handled; any other precision
   raises an error, as does a build without GPU_GAUGE_TOOLS.
 */
void applyGaugePhase(GaugeField &u) {

#ifdef GPU_GAUGE_TOOLS
  if (u.Precision() == QUDA_DOUBLE_PRECISION) {
    gaugePhase<double>(u);
  } else if (u.Precision() == QUDA_SINGLE_PRECISION) {
    gaugePhase<float>(u);
  } else {
    errorQuda("Unknown precision type %d", u.Precision());
  }
#else
  errorQuda("Gauge tools are not built");
#endif

}
276 
277 } // namespace quda
int commDim(int)
int comm_rank(void)
Definition: comm_mpi.cpp:80
int y[4]
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:20
#define errorQuda(...)
Definition: util_quda.h:73
cudaStream_t * stream
::std::string string
Definition: gtest.h:1979
void applyGaugePhase(GaugeField &u)
Definition: gauge_phase.cu:261
int length[]
enum QudaTboundary_s QudaTboundary
QudaGaugeParam param
Definition: pack_test.cpp:17
QudaPrecision Precision() const
const QudaFieldLocation location
Definition: pack_test.cpp:46
FloatingPoint< float > Float
Definition: gtest.h:7350
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
Definition: tune.cpp:271
int commCoords(int)
int x[4]
enum QudaFieldLocation_s QudaFieldLocation
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
Definition: complex_quda.h:843
#define checkCudaError()
Definition: util_quda.h:110
QudaTune getTuning()
Definition: util_quda.cpp:32
const QudaParity parity
Definition: dslash_test.cpp:29