QUDA  0.9.0
gauge_phase.cu
Go to the documentation of this file.
1 #include <gauge_field_order.h>
2 #include <comm_quda.h>
3 #include <complex_quda.h>
4 #include <index_helper.cuh>
5 
12 namespace quda {
13 
14 #ifdef GPU_GAUGE_TOOLS
15 
/**
   Kernel argument struct for applying/removing staggered phases.
   Captures the field accessor, local geometry, the temporal boundary
   factor and the imaginary-chemical-potential phase.
*/
template <typename Float, typename Order>
struct GaugePhaseArg {
  Order order;               // accessor used to load/store the gauge links
  int X[4];                  // local lattice dimensions
  int threads;               // checkerboard (single-parity) volume
  Float tBoundary;           // temporal boundary factor on the last time slice
  Float i_mu;                // imaginary chemical potential
  complex<Float> i_mu_phase; // exp(+/- i*pi*mu / (global T extent)) applied to t-links

  GaugePhaseArg(const Order &order, const GaugeField &u)
    : order(order), threads(u.VolumeCB()), i_mu(u.iMu())
  {
    // if staggered phases are applied, then we are removing them
    // else we are applying them
    Float dir = u.StaggeredPhaseApplied() ? -1.0 : 1.0;

    // phase per time slice; u.X()[3]*comm_dim(3) is the global temporal extent
    i_mu_phase = complex<Float>( cos(M_PI * u.iMu() / (u.X()[3]*comm_dim(3)) ),
				 dir * sin(M_PI * u.iMu() / (u.X()[3]*comm_dim(3))) );

    for (int d=0; d<4; d++) X[d] = u.X()[d];

    // only set the boundary condition on the last time slice of nodes
#ifdef MULTI_GPU
    bool last_node_in_t = (commCoords(3) == commDim(3)-1);
#else
    bool last_node_in_t = true;
#endif
    tBoundary = (Float)(last_node_in_t ? u.TBoundary() : QUDA_PERIODIC_T);
  }

  // Copy constructor: initializers listed in member-declaration order
  // (threads before tBoundary) to match actual initialization order and
  // silence -Wreorder.
  GaugePhaseArg(const GaugePhaseArg &arg)
    : order(arg.order), threads(arg.threads), tBoundary(arg.tBoundary),
      i_mu(arg.i_mu), i_mu_phase(arg.i_mu_phase) {
    for (int d=0; d<4; d++) X[d] = arg.X[d];
  }
};
50 
51 
52 
// FIXME need to check this with odd local volumes
/**
   Compute the staggered phase for site (x,y,z,t) in direction dim under
   the requested convention (MILC, TIFR or CPS).  For dim==3 the temporal
   boundary condition is folded into the phase on the last time slice.
   phaseType is a compile-time constant, so only one convention's branch
   survives instantiation.
*/
template <int dim, typename Float, QudaStaggeredPhase phaseType, typename Arg>
__device__ __host__ Float getPhase(int x, int y, int z, int t, Arg &arg) {
  Float phase = 1.0;
  if (phaseType == QUDA_STAGGERED_PHASE_MILC) {
    if (dim==0) {
      phase = (1.0 - 2.0 * (t % 2) );
    } else if (dim == 1) {
      phase = (1.0 - 2.0 * ((t + x) % 2) );
    } else if (dim == 2) {
      phase = (1.0 - 2.0 * ((t + x + y) % 2) );
    } else if (dim == 3) { // also apply boundary condition
      phase = (t == arg.X[3]-1) ? arg.tBoundary : 1.0;
    }
  } else if (phaseType == QUDA_STAGGERED_PHASE_TIFR) { // was a dangling "} if" detaching the CPS branch
    if (dim==0) {
      phase = (1.0 - 2.0 * ((3 + t + z + y) % 2) );
    } else if (dim == 1) {
      phase = (1.0 - 2.0 * ((2 + t + z) % 2) );
    } else if (dim == 2) {
      phase = (1.0 - 2.0 * ((1 + t) % 2) );
    } else if (dim == 3) { // also apply boundary condition
      phase = (t == arg.X[3]-1) ? arg.tBoundary : 1.0;
    }
  } else if (phaseType == QUDA_STAGGERED_PHASE_CPS) {
    if (dim==0) {
      phase = 1.0;
    } else if (dim == 1) {
      phase = (1.0 - 2.0 * ((1 + x) % 2) );
    } else if (dim == 2) {
      phase = (1.0 - 2.0 * ((1 + x + y) % 2) );
    } else if (dim == 3) { // also apply boundary condition
      phase = ((t == arg.X[3]-1) ? arg.tBoundary : 1.0) *
	(1.0 - 2.0 * ((1 + x + y + z) % 2) ); // normalized int literal 2 -> 2.0
    }
  }
  return phase;
}
91 
/**
   Apply the staggered phase to a single link: load the link in direction
   dim at checkerboard index indexCB/parity, scale it by the site phase,
   fold in the imaginary-chemical-potential factor for temporal links,
   and store it back.
*/
template <typename Float, int length, QudaStaggeredPhase phaseType, int dim, typename Arg>
__device__ __host__ void gaugePhase(int indexCB, int parity, Arg &arg) {
  typedef typename mapper<Float>::type RegType;

  int coord[4];
  getCoords(coord, indexCB, arg.X, parity);

  RegType phase = getPhase<dim,Float,phaseType>(coord[0], coord[1], coord[2], coord[3], arg);

  RegType link[length];
  arg.order.load(link, indexCB, dim, parity);

  // scale every real component by the (real) staggered phase
  for (int c=0; c<length; c++) link[c] *= phase;

  // apply imaginary chemical potential if needed (temporal links only)
  if (dim==3 && arg.i_mu != 0.0) {
    complex<RegType> *z = reinterpret_cast<complex<RegType>*>(link);
    for (int c=0; c<length/2; c++) z[c] *= arg.i_mu_phase;
  }

  arg.order.save(link, indexCB, dim, parity);
}
112 
/**
   CPU fallback: sweep both parities of the local checkerboard volume and
   apply the phase to all four link directions at every site.
*/
template <typename Float, int length, QudaStaggeredPhase phaseType, typename Arg>
void gaugePhase(Arg &arg) {
  for (int parity=0; parity<2; parity++) {
    for (int cb=0; cb<arg.threads; cb++) {
      gaugePhase<Float,length,phaseType,0>(cb, parity, arg);
      gaugePhase<Float,length,phaseType,1>(cb, parity, arg);
      gaugePhase<Float,length,phaseType,2>(cb, parity, arg);
      gaugePhase<Float,length,phaseType,3>(cb, parity, arg);
    }
  }
}
127 
/**
   GPU kernel: one thread per checkerboard site (x grid/block dimension),
   with parity carried in blockIdx.y (grid.y == 2).  Each thread handles
   all four link directions of its site.
*/
template <typename Float, int length, QudaStaggeredPhase phaseType, typename Arg>
__global__ void gaugePhaseKernel(Arg arg) {
  int cb = blockIdx.x * blockDim.x + threadIdx.x;
  if (cb >= arg.threads) return; // guard the grid tail
  int parity = blockIdx.y;

  gaugePhase<Float,length,phaseType,0>(cb, parity, arg);
  gaugePhase<Float,length,phaseType,1>(cb, parity, arg);
  gaugePhase<Float,length,phaseType,2>(cb, parity, arg);
  gaugePhase<Float,length,phaseType,3>(cb, parity, arg);
}
142 
/**
   Tunable launcher for the gauge-phase kernel.  The x block/grid
   dimensions are autotuned over the checkerboard volume; grid.y is
   pinned to 2 to index parity.
*/
template <typename Float, int length, QudaStaggeredPhase phaseType, typename Arg>
class GaugePhase : Tunable {
  Arg &arg;
  const GaugeField &meta; // used for meta data only
  QudaFieldLocation location;

private:
  unsigned int sharedBytesPerThread() const { return 0; }
  unsigned int sharedBytesPerBlock(const TuneParam &param) const { return 0; }

  bool tuneGridDim() const { return false; } // Don't tune the grid dimensions.
  unsigned int minThreads() const { return arg.threads; }

public:
  GaugePhase(Arg &arg, const GaugeField &meta, QudaFieldLocation location)
    : arg(arg), meta(meta), location(location) {
    writeAuxString("stride=%d,prec=%lu",arg.order.stride,sizeof(Float));
  }
  virtual ~GaugePhase() { ; }

  bool advanceBlockDim(TuneParam &param) const {
    bool rtn = Tunable::advanceBlockDim(param);
    param.grid.y = 2; // re-pin parity dimension after base-class advance
    return rtn;
  }

  void initTuneParam(TuneParam &param) const {
    Tunable::initTuneParam(param); // establish default block/grid first
    param.grid.y = 2;              // parity is the y grid dimension
  }

  void apply(const cudaStream_t &stream) {
    if (location == QUDA_CUDA_FIELD_LOCATION) {
      TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
      tp.grid.y = 2; // parity is the y grid dimension
      gaugePhaseKernel<Float, length, phaseType, Arg>
	<<<tp.grid, tp.block, tp.shared_bytes, stream>>>(arg);
    } else {
      gaugePhase<Float, length, phaseType, Arg>(arg);
    }
  }

  TuneKey tuneKey() const {
    return TuneKey(meta.VolString(), typeid(*this).name(), aux);
  }

  // back up / restore the field around tuning so trial launches don't
  // corrupt the data
  void preTune() { arg.order.save(); }
  void postTune() { arg.order.load(); }

  long long flops() const { return 0; }
  long long bytes() const { return 2 * arg.threads * 2 * arg.order.Bytes(); } // parity * e/o volume * i/o * vec size
};
195 
196 
/**
   Dispatch on the field's staggered-phase convention, instantiating the
   launcher with the matching compile-time phase type.
*/
template <typename Float, int length, typename Order>
void gaugePhase(Order order, const GaugeField &u, QudaFieldLocation location) {
  switch (u.StaggeredPhase()) {
  case QUDA_STAGGERED_PHASE_MILC: {
    GaugePhaseArg<Float,Order> arg(order, u);
    GaugePhase<Float,length,QUDA_STAGGERED_PHASE_MILC,
	       GaugePhaseArg<Float,Order> > phase(arg, u, location);
    phase.apply(0);
    break;
  }
  case QUDA_STAGGERED_PHASE_CPS: {
    GaugePhaseArg<Float,Order> arg(order, u);
    GaugePhase<Float,length,QUDA_STAGGERED_PHASE_CPS,
	       GaugePhaseArg<Float,Order> > phase(arg, u, location);
    phase.apply(0);
    break;
  }
  case QUDA_STAGGERED_PHASE_TIFR: {
    GaugePhaseArg<Float,Order> arg(order, u);
    GaugePhase<Float,length,QUDA_STAGGERED_PHASE_TIFR,
	       GaugePhaseArg<Float,Order> > phase(arg, u, location);
    phase.apply(0);
    break;
  }
  default:
    errorQuda("Undefined phase type");
  }

  if (location == QUDA_CUDA_FIELD_LOCATION) checkCudaError();
}
220 
/**
   Precision-level entry point: deduce the field's location from its
   dynamic type, then instantiate on the data order/reconstruction.
   Only native fields with no reconstruction are supported.
*/
template <typename Float>
void gaugePhase(GaugeField &u) {
  const int length = 18; // 3x3 complex matrix => 18 reals per link

  // a cudaGaugeField lives on the device; anything else is host-resident
  QudaFieldLocation location =
    (typeid(u) == typeid(cudaGaugeField)) ? QUDA_CUDA_FIELD_LOCATION : QUDA_CPU_FIELD_LOCATION;

  if (u.isNative()) {
    if (u.Reconstruct() == QUDA_RECONSTRUCT_NO) {
      typedef typename gauge_mapper<Float,QUDA_RECONSTRUCT_NO>::type Gauge;
      gaugePhase<Float,length>(Gauge(u), u, location);
    } else {
      errorQuda("Unsupported reconstruction type");
    }
  } else {
    errorQuda("Gauge field %d order not supported", u.Order());
  }
}
241 
242 #endif
243 
245 
246 #ifdef GPU_GAUGE_TOOLS
247  if (u.Precision() == QUDA_DOUBLE_PRECISION) {
248  gaugePhase<double>(u);
249  } else if (u.Precision() == QUDA_SINGLE_PRECISION) {
250  gaugePhase<float>(u);
251  } else {
252  errorQuda("Unknown precision type %d", u.Precision());
253  }
254 #else
255  errorQuda("Gauge tools are not build");
256 #endif
257 
258  }
259 
260 } // namespace quda
dim3 dim3 blockDim
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:20
#define errorQuda(...)
Definition: util_quda.h:90
int comm_dim(int dim)
int commCoords(int)
cudaStream_t * stream
static __inline__ dim3 dim3 void size_t cudaStream_t int dim
void applyGaugePhase(GaugeField &u)
Definition: gauge_phase.cu:244
QudaGaugeParam param
Definition: pack_test.cpp:17
__host__ __device__ ValueType sin(ValueType x)
Definition: complex_quda.h:40
int commDim(int)
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
Definition: tune.cpp:603
Main header file for host and device accessors to GaugeFields.
enum QudaFieldLocation_s QudaFieldLocation
unsigned long long flops
Definition: blas_quda.cu:42
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
Definition: complex_quda.h:880
void size_t length
__host__ __device__ ValueType cos(ValueType x)
Definition: complex_quda.h:35
virtual void initTuneParam(TuneParam &param) const
Definition: tune_quda.h:230
#define checkCudaError()
Definition: util_quda.h:129
virtual bool advanceBlockDim(TuneParam &param) const
Definition: tune_quda.h:102
QudaTune getTuning()
Query whether autotuning is enabled or not. Default is enabled but can be overridden by setting QUDA_...
Definition: util_quda.cpp:51
static __inline__ size_t size_t d
QudaPrecision Precision() const
QudaParity parity
Definition: covdev_test.cpp:53
unsigned long long bytes
Definition: blas_quda.cu:43
static __device__ __host__ void getCoords(int x[], int cb_index, const I X[], int parity)