QUDA  v1.1.0
A library for QCD on GPUs
gauge_phase.cu
Go to the documentation of this file.
1 #include <gauge_field_order.h>
2 #include <comm_quda.h>
3 #include <complex_quda.h>
4 #include <index_helper.cuh>
5 #include <tune_quda.h>
6 #include <instantiate.h>
7 
8 /**
9  This code has not been checked. In particular, I suspect it is
10  erroneous in multi-GPU since it looks like the halo ghost region
11  isn't being treated here.
12  */
13 
14 namespace quda {
15 
/**
   Kernel argument bundle for applying or removing staggered phases.
   @tparam Float_ Storage precision of the gauge field
   @tparam nColor_ Number of colors (only 3 is supported)
   @tparam recon_ Gauge reconstruction type
   @tparam phase_ Staggered phase convention to apply
*/
template <typename Float_, int nColor_, QudaReconstructType recon_, QudaStaggeredPhase phase_>
struct GaugePhaseArg {
  using Float = Float_;
  static constexpr int nColor = nColor_;
  static_assert(nColor == 3, "Only nColor=3 enabled at this time");
  static constexpr QudaReconstructType recon = recon_;
  static constexpr QudaStaggeredPhase phase = phase_;
  using Gauge = typename gauge_mapper<Float, recon>::type;

  Gauge u;                   // accessor to the gauge field that is rewritten in place
  int X[4];                  // per-process lattice extents, copied from u.X()
  int threads;               // number of checkerboard sites, u.VolumeCB()
  Float tBoundary;           // temporal boundary factor (periodic except on the last node in t)
  Float i_mu;                // imaginary chemical potential, u.iMu()
  complex<Float> i_mu_phase; // extra factor for temporal links when i_mu != 0

  GaugePhaseArg(GaugeField &u) : u(u), threads(u.VolumeCB()), i_mu(u.iMu())
  {
    // A field that already carries the phases is being restored (conjugate
    // factor, dir = -1); otherwise the phases are being applied (dir = +1).
    const Float dir = u.StaggeredPhaseApplied() ? -1.0 : 1.0;

    // angle = pi * iMu / (local time extent * number of processes in t)
    const auto angle = M_PI * u.iMu() / (u.X()[3] * comm_dim(3));
    i_mu_phase = complex<Float>(cos(angle), dir * sin(angle));

    for (int mu = 0; mu < 4; ++mu) X[mu] = u.X()[mu];

    // Only the process holding the final time slice carries the requested
    // temporal boundary condition; all other processes use periodic links.
    const bool last_node_in_t = commCoords(3) == commDim(3) - 1;
    tBoundary = static_cast<Float>(last_node_in_t ? u.TBoundary() : QUDA_PERIODIC_T);
  }
};
50 
// FIXME need to check this with odd local volumes
/**
   Staggered sign factor for the link in direction `dim` at site
   (x, y, z, t), for the convention selected by Arg::phase.

   Writing s(n) = (-1)^n, the factors are
     MILC: dim 0: s(t),       dim 1: s(t+x),   dim 2: s(t+x+y), dim 3: 1
     TIFR: dim 0: s(3+t+z+y), dim 1: s(2+t+z), dim 2: s(1+t),   dim 3: 1
     CPS:  dim 0: 1,          dim 1: s(1+x),   dim 2: s(1+x+y), dim 3: s(1+x+y+z)
   For each of these three conventions the temporal link on the last local
   time slice (t == arg.X[3] - 1) is additionally multiplied by
   arg.tBoundary. Any other Arg::phase value yields 1.

   The signs depend only on coordinate parities. Those parities agree
   between per-process and global coordinates when the local extents are
   even, which is presumably why odd local volumes are flagged above.

   The function is marked __device__ __host__ explicitly because it is
   called from device code; it does not rely on relaxed-constexpr
   compilation. The ±1 is formed from an integer parity, so float
   instantiations do no double-precision arithmetic.

   @tparam dim Link direction (0..3)
   @param[in] x,y,z,t Site coordinates
   @param[in] arg Kernel arguments (reads arg.X[3] and arg.tBoundary only)
   @return The phase factor, of type typename Arg::Float
*/
template <int dim, typename Arg>
__device__ __host__ constexpr auto getPhase(int x, int y, int z, int t, const Arg &arg)
{
  using real = typename Arg::Float;

  int n = 0;               // exponent of the sign factor (-1)^n
  bool t_boundary = false; // whether this link may carry the temporal boundary factor

  if (Arg::phase == QUDA_STAGGERED_PHASE_MILC) {
    if (dim == 0) n = t;
    else if (dim == 1) n = t + x;
    else if (dim == 2) n = t + x + y;
    else if (dim == 3) t_boundary = true;
  } else if (Arg::phase == QUDA_STAGGERED_PHASE_TIFR) {
    if (dim == 0) n = 3 + t + z + y;
    else if (dim == 1) n = 2 + t + z;
    else if (dim == 2) n = 1 + t;
    else if (dim == 3) t_boundary = true;
  } else if (Arg::phase == QUDA_STAGGERED_PHASE_CPS) {
    if (dim == 1) n = 1 + x;
    else if (dim == 2) n = 1 + x + y;
    else if (dim == 3) {
      n = 1 + x + y + z;
      t_boundary = true;
    }
  }

  real phase = (n & 1) ? real(-1) : real(1);
  // the boundary factor only acts on the final local time slice
  if (t_boundary && t == arg.X[3] - 1) phase *= arg.tBoundary;
  return phase;
}
88 
/**
   Rescale, in place, the link in direction `dim` at checkerboard site
   indexCB of the given parity.

   The link is multiplied by its staggered phase from getPhase<dim>. For
   temporal links (dim == 3) with a non-zero arg.i_mu, it is also
   multiplied by arg.i_mu_phase.

   @tparam dim Link direction (0..3)
   @param[in] indexCB Checkerboard site index
   @param[in] parity Site parity
   @param[in,out] arg Kernel arguments; the link is read from and written back to arg.u
*/
template <int dim, typename Arg>
__device__ __host__ void gaugePhase(int indexCB, int parity, Arg &arg)
{
  using real = typename mapper<typename Arg::Float>::type;

  int coord[4];
  getCoords(coord, indexCB, arg.X, parity);

  real sign = getPhase<dim>(coord[0], coord[1], coord[2], coord[3], arg);
  Matrix<complex<real>, Arg::nColor> link = arg.u(dim, indexCB, parity);
  link *= sign;

  if (dim == 3) {
    // an imaginary chemical potential enters through the temporal links only
    if (arg.i_mu != 0.0) link *= arg.i_mu_phase;
  }

  arg.u(dim, indexCB, parity) = link;
}
105 
/**
   Generic GPU staggered phase application.

   Launch layout:
   - The x-dimension of the grid enumerates checkerboard site indices in
     [0, arg.threads).
   - The y-dimension enumerates the parity, which must be 0 or 1.

   Each thread rescales, in place, the four links (dim = 0..3) at its own
   (indexCB, parity). Threads do not access each other's entries, so no
   synchronization is used.

   @param[in,out] arg Kernel arguments, passed by value; writes go through the arg.u accessor
*/
template <typename Arg>
__global__ void gaugePhaseKernel(Arg arg)
{
  int indexCB = blockIdx.x * blockDim.x + threadIdx.x;
  int parity = blockIdx.y * blockDim.y + threadIdx.y;
  // Guard both launch dimensions. Without the parity check, an over-sized
  // y launch would index a non-existent third parity.
  if (indexCB >= arg.threads || parity >= 2) return;

  gaugePhase<0>(indexCB, parity, arg);
  gaugePhase<1>(indexCB, parity, arg);
  gaugePhase<2>(indexCB, parity, arg);
  gaugePhase<3>(indexCB, parity, arg);
}
119 
/**
   Tunable launcher for gaugePhaseKernel.

   The launch uses two entries in the y-dimension, one per parity. The
   kernel rewrites the gauge field in place, so the field is saved before
   tuning and restored afterwards.
*/
template <typename Arg>
class GaugePhase : private TunableVectorY
{
  Arg &arg;               // kernel arguments; the field behind arg.u is modified
  const GaugeField &meta; // consulted for meta data only (strings, byte count)

  // Grid dimensions are not tuned; minThreads() is the number of sites to cover.
  bool tuneGridDim() const { return false; }
  unsigned int minThreads() const { return arg.threads; }

public:
  /**
     @param[in,out] arg Kernel arguments
     @param[in] meta Gauge field used for meta data only
  */
  GaugePhase(Arg &arg, const GaugeField &meta) : TunableVectorY(2), arg(arg), meta(meta) {}

  TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), meta.AuxString()); }

  /** Tune (if needed) and launch the kernel on the given stream. */
  void apply(const qudaStream_t &stream)
  {
    TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
    qudaLaunchKernel(gaugePhaseKernel<Arg>, tp, stream, arg);
  }

  // no floating-point operations are counted for this kernel
  long long flops() const { return 0; }
  // the field is read once and written once
  long long bytes() const { return 2 * meta.Bytes(); }

  // tuning relaunches the in-place kernel, so back up and restore the field
  void preTune() { arg.u.save(); }
  void postTune() { arg.u.load(); }
};
145 
146 
/**
   Dispatcher instantiated per (precision, color, reconstruct).

   Selects the compile-time phase convention from u.StaggeredPhase() and
   runs the phase kernel on u. Unknown conventions raise an error.
*/
template <typename Float, int nColor, QudaReconstructType recon>
struct GaugePhase_ {
  /** Build the kernel arguments for convention `phase` and launch on stream 0. */
  template <QudaStaggeredPhase phase> static void launchPhase(GaugeField &u)
  {
    using Arg = GaugePhaseArg<Float, nColor, recon, phase>;
    Arg arg(u);
    GaugePhase<Arg> kernel(arg, u);
    kernel.apply(0);
  }

  GaugePhase_(GaugeField &u)
  {
    switch (u.StaggeredPhase()) {
    case QUDA_STAGGERED_PHASE_MILC: launchPhase<QUDA_STAGGERED_PHASE_MILC>(u); break;
    case QUDA_STAGGERED_PHASE_CPS: launchPhase<QUDA_STAGGERED_PHASE_CPS>(u); break;
    case QUDA_STAGGERED_PHASE_TIFR: launchPhase<QUDA_STAGGERED_PHASE_TIFR>(u); break;
    default: errorQuda("Undefined phase type");
    }
  }
};
167 
/**
   Apply the staggered phases to u in place, or remove them if
   u.StaggeredPhaseApplied() is already set. The phase convention is taken
   from u.StaggeredPhase(). Any non-zero imaginary chemical potential
   (u.iMu()) is folded into the temporal links in the same pass.

   Only the ReconstructNone reconstruction list is instantiated. Without
   GPU_GAUGE_TOOLS this function raises an error.

   NOTE(review): nothing here updates u's staggered-phase-applied state;
   presumably the caller does -- confirm.
   NOTE(review): only padded ghost zones (QUDA_GHOST_EXCHANGE_PAD) are
   refreshed afterwards; other ghost layouts are left untouched here.

   @param[in,out] u Gauge field to modify
*/
void applyGaugePhase(GaugeField &u)
{
#ifdef GPU_GAUGE_TOOLS
  instantiate<GaugePhase_, ReconstructNone>(u);
  // The kernel only rewrites the body links, so any padded ghost copies are
  // stale afterwards and must be re-exchanged.
  if (u.GhostExchange() == QUDA_GHOST_EXCHANGE_PAD) u.exchangeGhost();
#else
  errorQuda("Gauge tools are not built");
#endif
}
177 
178 } // namespace quda