1 #include <gauge_field_order.h>
3 #include <complex_quda.h>
4 #include <index_helper.cuh>
6 #include <instantiate.h>
/**
   This code has not been checked. In particular, I suspect it is
   erroneous in multi-GPU runs, since the halo (ghost) region does not
   appear to be handled here.
*/
/**
   @brief Kernel argument for applying (or, if already applied,
   removing) staggered phases, the temporal boundary condition and an
   optional imaginary chemical potential phase to a gauge field.

   @tparam Float_ Storage precision of the gauge field
   @tparam nColor_ Number of colors (only 3 is supported)
   @tparam recon_ Reconstruction type used by the gauge field accessor
   @tparam phase_ Staggered phase convention to apply
 */
template <typename Float_, int nColor_, QudaReconstructType recon_, QudaStaggeredPhase phase_>
struct GaugePhaseArg {
  using Float = Float_;
  static constexpr int nColor = nColor_;
  static_assert(nColor == 3, "Only nColor=3 enabled at this time");
  static constexpr QudaReconstructType recon = recon_;
  static constexpr QudaStaggeredPhase phase = phase_;
  typedef typename gauge_mapper<Float,recon>::type Gauge;

  Gauge u;                   // accessor to the gauge field, updated in place
  int X[4];                  // local lattice dimensions
  int threads;               // number of checkerboard sites (one thread per site)
  Float tBoundary;           // factor for temporal links on the last local time slice
  Float i_mu;                // imaginary chemical potential (0 disables it)
  complex<Float> i_mu_phase; // phase applied to temporal links when i_mu != 0

  GaugePhaseArg(GaugeField &u) :
    u(u),
    threads(u.VolumeCB()),
    i_mu(u.iMu())
  {
    // if staggered phases are applied, then we are removing them
    // else we are applying them
    Float dir = u.StaggeredPhaseApplied() ? -1.0 : 1.0;

    // the chemical potential angle is spread over the global temporal
    // extent (local extent times the number of nodes in t)
    const double angle = M_PI * u.iMu() / (u.X()[3] * comm_dim(3));
    i_mu_phase = complex<Float>(cos(angle), dir * sin(angle));

    for (int d = 0; d < 4; d++) X[d] = u.X()[d];

    // only set the boundary condition on the last time slice of nodes
    bool last_node_in_t = (comm_coord(3) == comm_dim(3) - 1);
    tBoundary = (Float)(last_node_in_t ? u.TBoundary() : QUDA_PERIODIC_T);
  }
};
/**
   @brief Phase factor for the link in direction dim at site (x, y, z, t).

   Each non-trivial factor is +/-1 and has the form 1 - 2 * (n % 2),
   where n is a convention-dependent sum of coordinates.  For the
   temporal link on the last local time slice the factor also includes
   arg.tBoundary.

   FIXME need to check this with odd local volumes (the coordinates
   passed in appear to be local ones, so the phases may differ from the
   global ones when a local extent is odd; to be confirmed).

   @tparam dim Link direction, 0..3
   @param[in] x, y, z, t Site coordinates
   @param[in] arg Kernel argument (supplies X, tBoundary and the convention)
   @return The factor by which the link is multiplied
 */
template <int dim, typename Arg> constexpr auto getPhase(int x, int y, int z, int t, Arg &arg) {
  static_assert(dim >= 0 && dim < 4, "dim must be in the range [0, 4)");
  typename Arg::Float phase = 1.0;
  if (Arg::phase == QUDA_STAGGERED_PHASE_MILC) {
    if (dim == 0) {
      phase = (1.0 - 2.0 * (t % 2));
    } else if (dim == 1) {
      phase = (1.0 - 2.0 * ((t + x) % 2));
    } else if (dim == 2) {
      phase = (1.0 - 2.0 * ((t + x + y) % 2));
    } else if (dim == 3) { // also apply boundary condition
      phase = (t == arg.X[3] - 1) ? arg.tBoundary : 1.0;
    }
  } else if (Arg::phase == QUDA_STAGGERED_PHASE_TIFR) {
    if (dim == 0) {
      phase = (1.0 - 2.0 * ((3 + t + z + y) % 2));
    } else if (dim == 1) {
      phase = (1.0 - 2.0 * ((2 + t + z) % 2));
    } else if (dim == 2) {
      phase = (1.0 - 2.0 * ((1 + t) % 2));
    } else if (dim == 3) { // also apply boundary condition
      phase = (t == arg.X[3] - 1) ? arg.tBoundary : 1.0;
    }
  } else if (Arg::phase == QUDA_STAGGERED_PHASE_CPS) {
    if (dim == 0) {
      phase = 1.0; // no phase on dim-0 links in this convention
    } else if (dim == 1) {
      phase = (1.0 - 2.0 * ((1 + x) % 2));
    } else if (dim == 2) {
      phase = (1.0 - 2.0 * ((1 + x + y) % 2));
    } else if (dim == 3) { // also apply boundary condition
      phase = ((t == arg.X[3] - 1) ? arg.tBoundary : 1.0) *
        (1.0 - 2.0 * ((1 + x + y + z) % 2));
    }
  }
  return phase;
}
/**
   @brief Multiply the link in direction dim at one checkerboard site
   by its staggered phase (including the boundary condition).  For the
   temporal links (dim == 3), also multiply by the imaginary chemical
   potential phase when arg.i_mu is non-zero.  The link is written back
   in place.

   @tparam dim Link direction, 0..3
   @param[in] indexCB Checkerboard site index
   @param[in] parity Site parity
   @param[in,out] arg Kernel argument holding the gauge field accessor
 */
template <int dim, typename Arg>
__device__ __host__ void gaugePhase(int indexCB, int parity, Arg &arg) {
  typedef typename mapper<typename Arg::Float>::type real;

  // lattice coordinates of this site
  int x[4];
  getCoords(x, indexCB, arg.X, parity);

  real phase = getPhase<dim>(x[0], x[1], x[2], x[3], arg);
  Matrix<complex<real>,Arg::nColor> u = arg.u(dim, indexCB, parity);
  u = phase * u;

  // apply imaginary chemical potential if needed
  if (dim == 3 && arg.i_mu != 0.0) u *= arg.i_mu_phase;

  arg.u(dim, indexCB, parity) = u;
}
/**
   Generic GPU staggered phase application.

   Expects a 2-d launch: the x dimension indexes the arg.threads
   checkerboard sites and the y dimension indexes the two parities.
   Each thread updates all four links of its site.
*/
template <typename Arg>
__global__ void gaugePhaseKernel(Arg arg) {
  int indexCB = blockIdx.x * blockDim.x + threadIdx.x;
  if (indexCB >= arg.threads) return;
  int parity = blockIdx.y * blockDim.y + threadIdx.y;
  if (parity >= 2) return; // guard against a y-extent larger than the two parities
  gaugePhase<0>(indexCB, parity, arg);
  gaugePhase<1>(indexCB, parity, arg);
  gaugePhase<2>(indexCB, parity, arg);
  gaugePhase<3>(indexCB, parity, arg);
}
/**
   @brief Tunable launcher for gaugePhaseKernel.  The y dimension of the
   launch covers the two parities (TunableVectorY(2)).
 */
template <typename Arg>
class GaugePhase : TunableVectorY {
  Arg &arg;
  const GaugeField &meta; // used for meta data only

  bool tuneGridDim() const override { return false; } // Don't tune the grid dimensions.
  unsigned int minThreads() const override { return arg.threads; }

public:
  GaugePhase(Arg &arg, const GaugeField &meta)
    : TunableVectorY(2), arg(arg), meta(meta) { }

  void apply(const qudaStream_t &stream) override {
    TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
    qudaLaunchKernel(gaugePhaseKernel<Arg>, tp, stream, arg);
  }

  TuneKey tuneKey() const override { return TuneKey(meta.VolString(), typeid(*this).name(), meta.AuxString()); }

  // The kernel modifies the field in place; presumably save()/load()
  // back up and restore it so that tuning launches do not compound the
  // phases (NOTE(review): confirm against the accessor implementation).
  void preTune() override { arg.u.save(); }
  void postTune() override { arg.u.load(); }

  long long flops() const override { return 0; }
  long long bytes() const override { return 2 * meta.Bytes(); } // 2 from i/o
};
/**
   @brief Dispatch on the staggered phase convention stored in the field
   and launch the phase kernel for it.  Constructed via instantiate<> in
   applyGaugePhase.

   @param[in,out] u Gauge field whose links are multiplied in place
 */
template <typename Float, int nColor, QudaReconstructType recon>
struct GaugePhase_ {
  GaugePhase_(GaugeField &u) {
    if (u.StaggeredPhase() == QUDA_STAGGERED_PHASE_MILC) {
      launch<QUDA_STAGGERED_PHASE_MILC>(u);
    } else if (u.StaggeredPhase() == QUDA_STAGGERED_PHASE_CPS) {
      launch<QUDA_STAGGERED_PHASE_CPS>(u);
    } else if (u.StaggeredPhase() == QUDA_STAGGERED_PHASE_TIFR) {
      launch<QUDA_STAGGERED_PHASE_TIFR>(u);
    } else {
      errorQuda("Undefined phase type");
    }
  }

private:
  // Build the kernel argument for the given convention and run the tuned kernel.
  template <QudaStaggeredPhase phase_type> static void launch(GaugeField &u) {
    GaugePhaseArg<Float, nColor, recon, phase_type> arg(u);
    GaugePhase<decltype(arg)> phase(arg, u);
    phase.apply(0);
  }
};
/**
   @brief Apply the staggered phases to the gauge field u in place, or
   remove them if they are already applied (see GaugePhaseArg).
   Afterwards the ghost zone is refreshed if the field uses pad-based
   ghost exchange.  Errors out if QUDA was built without gauge tools.

   @param[in,out] u Gauge field to modify
 */
void applyGaugePhase(GaugeField &u) {
#ifdef GPU_GAUGE_TOOLS
  instantiate<GaugePhase_, ReconstructNone>(u);
  // ensure that ghosts are updated if needed
  if (u.GhostExchange() == QUDA_GHOST_EXCHANGE_PAD) u.exchangeGhost();
#else
  errorQuda("Gauge tools are not built");
#endif
}