QUDA  1.0.0
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
gauge_phase.cu
Go to the documentation of this file.
1 #include <gauge_field_order.h>
2 #include <comm_quda.h>
3 #include <complex_quda.h>
4 #include <index_helper.cuh>
5 #include <tune_quda.h>
6 
13 namespace quda {
14 
15 #ifdef GPU_GAUGE_TOOLS
16 
/**
   Parameter struct for applying (or removing) the staggered phases and the
   imaginary chemical potential phase on a gauge field, in place.
   @tparam Float Precision of the field data
   @tparam Nc Number of colors
   @tparam Order Accessor type used to read and write the links
 */
template <typename Float, int Nc, typename Order>
struct GaugePhaseArg {
  static constexpr int nColor = Nc;
  Order order;               // accessor for the links that are modified in place
  int X[4];                  // local lattice dimensions, copied from u.X()
  int threads;               // number of sites per parity, u.VolumeCB()
  Float tBoundary;           // temporal boundary factor (periodic unless last node in t)
  Float i_mu;                // imaginary chemical potential, u.iMu()
  complex<Float> i_mu_phase; // phase multiplied onto the temporal links when i_mu != 0

  GaugePhaseArg(const Order &order, const GaugeField &u)
    : order(order), threads(u.VolumeCB()), i_mu(u.iMu())
  {
    // if staggered phases are applied, then we are removing them
    // else we are applying them (removal uses the conjugate i_mu phase)
    const Float dir = u.StaggeredPhaseApplied() ? -1.0 : 1.0;

    // angle pi * mu / T, where T = u.X()[3] * comm_dim(3) is the global
    // temporal extent; computed once for both the cos and sin terms
    const double angle = M_PI * u.iMu() / (u.X()[3] * comm_dim(3));
    i_mu_phase = complex<Float>(cos(angle), dir * sin(angle));

    for (int d = 0; d < 4; d++) X[d] = u.X()[d];

    // only set the boundary condition on the last time slice of nodes
#ifdef MULTI_GPU
    const bool last_node_in_t = (comm_coord(3) == comm_dim(3) - 1);
#else
    const bool last_node_in_t = true;
#endif
    tBoundary = (Float)(last_node_in_t ? u.TBoundary() : QUDA_PERIODIC_T);
  }

  // Member-wise copy. The initializer list follows the declaration order
  // (order, threads, tBoundary, i_mu, i_mu_phase), which is the order the
  // members are actually initialized in, so -Wreorder stays quiet.
  GaugePhaseArg(const GaugePhaseArg &arg)
    : order(arg.order), threads(arg.threads), tBoundary(arg.tBoundary),
      i_mu(arg.i_mu), i_mu_phase(arg.i_mu_phase)
  {
    for (int d = 0; d < 4; d++) X[d] = arg.X[d];
  }
};
52 
53 
54 
/**
   Returns +1 when n is even and -1 when n is odd (n is expected to be
   non-negative), using integer arithmetic only.
 */
__device__ __host__ inline int gaugePhaseSign(int n) { return 1 - 2 * (n % 2); }

// FIXME need to check this with odd local volumes
/**
   Staggered phase factor for the link in direction dim at site (x, y, z, t).
   For the directions noted below it also includes the temporal boundary factor.

   NOTE(review): the caller passes coordinates from getCoords(..., arg.X, ...),
   i.e. local coordinates. The parity of a local coordinate matches the
   global one only when the local extent is even, which is presumably what
   the FIXME above refers to.

   All phases are +/-1 (or +/-1 times arg.tBoundary). They are built from
   integers and Float constants so that no double-precision arithmetic is
   generated when Float is float.

   @tparam dim Link direction (0..3)
   @tparam Float Precision of the returned phase
   @tparam phaseType Staggered phase convention (MILC, TIFR or CPS)
   @param x,y,z,t Site coordinates
   @param arg Provides X[3] (local temporal extent) and tBoundary
   @return Phase to multiply onto the link
 */
template <int dim, typename Float, QudaStaggeredPhase phaseType, typename Arg>
__device__ __host__ Float getPhase(int x, int y, int z, int t, Arg &arg) {
  const Float one = static_cast<Float>(1);
  Float phase = one;
  if (phaseType == QUDA_STAGGERED_PHASE_MILC) {
    if (dim == 0) {
      phase = static_cast<Float>(gaugePhaseSign(t));
    } else if (dim == 1) {
      phase = static_cast<Float>(gaugePhaseSign(t + x));
    } else if (dim == 2) {
      phase = static_cast<Float>(gaugePhaseSign(t + x + y));
    } else if (dim == 3) { // also apply boundary condition
      phase = (t == arg.X[3] - 1) ? arg.tBoundary : one;
    }
  } else if (phaseType == QUDA_STAGGERED_PHASE_TIFR) {
    if (dim == 0) {
      phase = static_cast<Float>(gaugePhaseSign(3 + t + z + y));
    } else if (dim == 1) {
      phase = static_cast<Float>(gaugePhaseSign(2 + t + z));
    } else if (dim == 2) {
      phase = static_cast<Float>(gaugePhaseSign(1 + t));
    } else if (dim == 3) { // also apply boundary condition
      phase = (t == arg.X[3] - 1) ? arg.tBoundary : one;
    }
  } else if (phaseType == QUDA_STAGGERED_PHASE_CPS) {
    if (dim == 0) {
      phase = one;
    } else if (dim == 1) {
      phase = static_cast<Float>(gaugePhaseSign(1 + x));
    } else if (dim == 2) {
      phase = static_cast<Float>(gaugePhaseSign(1 + x + y));
    } else if (dim == 3) { // also apply boundary condition
      const Float tbc = (t == arg.X[3] - 1) ? arg.tBoundary : one;
      phase = tbc * static_cast<Float>(gaugePhaseSign(1 + x + y + z));
    }
  }
  return phase;
}
93 
/**
   Multiply the link in direction dim at one checkerboard site by its
   staggered phase. For the temporal direction with non-zero arg.i_mu, also
   multiply by arg.i_mu_phase. The link is written back through arg.order.
   @param indexCB Checkerboard site index
   @param parity Site parity
   @param arg Kernel arguments
 */
template <typename Float, QudaStaggeredPhase phaseType, int dim, typename Arg>
__device__ __host__ void gaugePhase(int indexCB, int parity, Arg &arg) {
  typedef typename mapper<Float>::type real;
  typedef Matrix<complex<real>, Arg::nColor> Link;

  // recover the (x, y, z, t) coordinates of this site
  int coord[4];
  getCoords(coord, indexCB, arg.X, parity);

  const real factor = getPhase<dim, Float, phaseType>(coord[0], coord[1], coord[2], coord[3], arg);

  Link link = arg.order(dim, indexCB, parity);
  link *= factor;

  // the imaginary chemical potential only touches the temporal links
  if (dim == 3) {
    if (arg.i_mu != 0.0) link *= arg.i_mu_phase;
  }

  arg.order(dim, indexCB, parity) = link;
}
110 
/**
   Host implementation: serially visits every site of parity 0 and then
   parity 1, and updates the links in directions 0, 1, 2 and 3 in turn.
   @param arg Kernel arguments (arg.threads sites per parity)
 */
template <typename Float, QudaStaggeredPhase phaseType, typename Arg>
void gaugePhase(Arg &arg) {
  for (int p = 0; p < 2; ++p) {
    for (int site = 0; site < arg.threads; ++site) {
      gaugePhase<Float, phaseType, 0>(site, p, arg);
      gaugePhase<Float, phaseType, 1>(site, p, arg);
      gaugePhase<Float, phaseType, 2>(site, p, arg);
      gaugePhase<Float, phaseType, 3>(site, p, arg);
    }
  }
}
125 
/**
   Device kernel that applies the staggered phases to all four links of
   each site.
   Launch layout: the x dimension spans the arg.threads checkerboard sites,
   and the y dimension spans the two parities.
   @param arg Kernel arguments, passed by value; links are updated through arg.order
 */
template <typename Float, QudaStaggeredPhase phaseType, typename Arg>
__global__ void gaugePhaseKernel(Arg arg) {
  const int indexCB = blockIdx.x * blockDim.x + threadIdx.x;
  const int parity = blockIdx.y * blockDim.y + threadIdx.y;

  // The launch may cover more threads than there are sites. Parity is
  // checked as well, defensively, so a y launch extent above 2 cannot
  // index a non-existent parity.
  if (indexCB >= arg.threads || parity >= 2) return;

  gaugePhase<Float, phaseType, 0>(indexCB, parity, arg);
  gaugePhase<Float, phaseType, 1>(indexCB, parity, arg);
  gaugePhase<Float, phaseType, 2>(indexCB, parity, arg);
  gaugePhase<Float, phaseType, 3>(indexCB, parity, arg);
}
139 
/**
   Tunable wrapper that applies the staggered phases either with
   gaugePhaseKernel (device fields) or with the host loop (other locations).
   The y vector length is 2, one entry per parity.
 */
template <typename Float, QudaStaggeredPhase phaseType, typename Arg>
class GaugePhase : TunableVectorY {
  Arg &arg;
  const GaugeField &meta; // used for meta data only

private:
  bool tuneGridDim() const { return false; } // Don't tune the grid dimensions.
  unsigned int minThreads() const { return arg.threads; }

public:
  GaugePhase(Arg &arg, const GaugeField &meta)
    : TunableVectorY(2), arg(arg), meta(meta) {
    writeAuxString("stride=%d,prec=%lu",arg.order.stride,sizeof(Float));
  }
  virtual ~GaugePhase() { ; }

  /**
     Run the phase application: launch the tuned kernel for a device field,
     otherwise run the host loop.
     @param stream Stream used for the kernel launch
   */
  void apply(const cudaStream_t &stream) {
    if (meta.Location() == QUDA_CUDA_FIELD_LOCATION) {
      TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
      gaugePhaseKernel<Float, phaseType, Arg>
        <<<tp.grid, tp.block, tp.shared_bytes, stream>>>(arg);
    } else {
      gaugePhase<Float, phaseType, Arg>(arg);
    }
  }

  TuneKey tuneKey() const {
    return TuneKey(meta.VolString(), typeid(*this).name(), aux);
  }

  // The kernel overwrites the field in place, so back it up before tuning
  // and restore it afterwards.
  void preTune() { arg.order.save(); }
  void postTune() { arg.order.load(); }

  long long flops() const { return 0; }

  /**
     Bytes moved: every site of each parity loads and stores all four of its
     links. NOTE(review): assumes arg.order.Bytes() is the size of a single
     link, as the original "vec size" comment implies.
   */
  long long bytes() const {
    // parity * e/o volume * dims * i/o * link size
    return 2ll * arg.threads * 4 * 2 * arg.order.Bytes();
  }
};
176 
177 
/**
   Build the tunable for one phase convention and run it on stream 0.
   @param arg Kernel arguments
   @param u Gauge field, used for meta data
 */
template <typename Float, QudaStaggeredPhase phaseType, typename Arg>
void launchGaugePhase(Arg &arg, const GaugeField &u) {
  GaugePhase<Float, phaseType, Arg> phase(arg, u);
  phase.apply(0);
}

/**
   Dispatch on u.StaggeredPhase() and apply that convention's phases through
   the given accessor. Aborts via errorQuda for an unrecognized convention.
   For device fields, checks for CUDA errors afterwards.
   @param order Accessor for u
   @param u Gauge field being modified
 */
template <typename Float, int Nc, typename Order>
void gaugePhase(Order order, const GaugeField &u) {
  typedef GaugePhaseArg<Float, Nc, Order> PhaseArg;

  switch (u.StaggeredPhase()) {
  case QUDA_STAGGERED_PHASE_MILC: {
    PhaseArg arg(order, u);
    launchGaugePhase<Float, QUDA_STAGGERED_PHASE_MILC>(arg, u);
    break;
  }
  case QUDA_STAGGERED_PHASE_CPS: {
    PhaseArg arg(order, u);
    launchGaugePhase<Float, QUDA_STAGGERED_PHASE_CPS>(arg, u);
    break;
  }
  case QUDA_STAGGERED_PHASE_TIFR: {
    PhaseArg arg(order, u);
    launchGaugePhase<Float, QUDA_STAGGERED_PHASE_TIFR>(arg, u);
    break;
  }
  default:
    errorQuda("Undefined phase type");
  }

  if (u.Location() == QUDA_CUDA_FIELD_LOCATION) checkCudaError();
}
201 
/**
   Precision-level entry point. Only Nc = 3, native-order fields without
   reconstruction are supported; anything else aborts via errorQuda.
   @param u Gauge field to modify in place
 */
template <typename Float>
void gaugePhase(GaugeField &u) {
  if (u.Ncolor() != 3) errorQuda("Unsupported number of colors %d", u.Ncolor());
  constexpr int Nc = 3;

  if (!u.isNative()) {
    errorQuda("Gauge field %d order not supported", u.Order());
  } else if (u.Reconstruct() != QUDA_RECONSTRUCT_NO) {
    errorQuda("Unsupported reconstruction type");
  } else {
    typedef typename gauge_mapper<Float, QUDA_RECONSTRUCT_NO>::type Accessor;
    gaugePhase<Float, Nc>(Accessor(u), u);
  }
}
220 
221 #endif
222 
224 
225 #ifdef GPU_GAUGE_TOOLS
226  if (u.Precision() == QUDA_DOUBLE_PRECISION) {
227  gaugePhase<double>(u);
228  } else if (u.Precision() == QUDA_SINGLE_PRECISION) {
229  gaugePhase<float>(u);
230  } else {
231  errorQuda("Unknown precision type %d", u.Precision());
232  }
233 #else
234  errorQuda("Gauge tools are not build");
235 #endif
236 
238  // ensure that ghosts are updated if needed
239  u.exchangeGhost();
240  }
241 
242  }
243 
244 } // namespace quda
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:21
#define errorQuda(...)
Definition: util_quda.h:121
int comm_dim(int dim)
int commCoords(int)
cudaStream_t * stream
void applyGaugePhase(GaugeField &u)
Definition: gauge_phase.cu:223
bool last_node_in_t()
Definition: test_util.cpp:118
__host__ __device__ ValueType sin(ValueType x)
Definition: complex_quda.h:51
const int nColor
Definition: covdev_test.cpp:75
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
Definition: tune.cpp:643
Main header file for host and device accessors to GaugeFields.
int X[4]
Definition: covdev_test.cpp:70
static int commDim[QUDA_MAX_DIM]
Definition: dslash_pack.cuh:9
unsigned long long flops
Definition: blas_quda.cu:22
virtual void exchangeGhost(QudaLinkDirection=QUDA_LINK_BACKWARDS)=0
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
__host__ __device__ ValueType cos(ValueType x)
Definition: complex_quda.h:46
#define checkCudaError()
Definition: util_quda.h:161
QudaGhostExchange GhostExchange() const
QudaTune getTuning()
Query whether autotuning is enabled or not. Default is enabled but can be overridden by setting QUDA_...
Definition: util_quda.cpp:52
QudaPrecision Precision() const
QudaParity parity
Definition: covdev_test.cpp:54
unsigned long long bytes
Definition: blas_quda.cu:23
__host__ __device__ int getCoords(int coord[], const Arg &arg, int &idx, int parity, int &dim)
Compute the space-time coordinates we are at.