QUDA  v1.1.0
A library for QCD on GPUs
gauge_random.cu
Go to the documentation of this file.
1 #include <quda_internal.h>
2 #include <quda_matrix.h>
3 #include <tune_quda.h>
4 #include <gauge_field.h>
5 #include <gauge_field_order.h>
6 #include <launch_kernel.cuh>
7 #include <atomic.cuh>
8 #include <index_helper.cuh>
9 #include <random_quda.h>
10 #include <instantiate.h>
11 
12 namespace quda {
13 
/**
   @brief Parameter bundle handed by value to the computeGenGauss kernel.
   @tparam Float_ Storage precision type of the gauge field
   @tparam nColor_ Number of colors of the link matrices
   @tparam recon_ Reconstruction type used to select the field accessor
   @tparam group_ If true the kernel produces group elements
   (exponentiated, scaled algebra elements); if false it stores the
   raw algebra element
*/
template <typename Float_, int nColor_, QudaReconstructType recon_, bool group_> struct GaugeGaussArg {
  using Float = Float_;
  using real = typename mapper<Float>::type;
  static constexpr int nColor = nColor_;
  static constexpr QudaReconstructType recon = recon_;
  static constexpr bool group = group_;

  using Gauge = typename gauge_mapper<Float, recon>::type;

  int threads;   // number of active threads required (half of the interior volume)
  int E[4];      // extended grid dimensions (taken directly from U.X())
  int X[4];      // true grid dimensions (extended dimensions minus both borders)
  int border[4]; // border width in each dimension (taken from U.R())
  Gauge U;       // accessor for the field being written
  RNG rngstate;  // random number generator state container
  real sigma;    // where U = exp(sigma * H)

  /**
     @param[in] U Field to be filled; also provides the dimensions and border widths
     @param[in] rngstate Random number generator state
     @param[in] sigma Width parameter, converted to the compute precision
  */
  GaugeGaussArg(const GaugeField &U, RNG &rngstate, double sigma) : U(U), rngstate(rngstate), sigma(sigma)
  {
    for (int dir = 0; dir < 4; ++dir) {
      border[dir] = U.R()[dir];
      E[dir] = U.X()[dir];
      X[dir] = U.X()[dir] - border[dir] * 2;
    }
    // the kernel indexes sites by (x_cb, parity), so only half the interior volume is needed here
    threads = X[0] * X[1] * X[2] * X[3] / 2;
  }
};
43 
/**
   @brief Build a random 3x3 traceless anti-Hermitian matrix from eight
   random components.

   Eight uniform deviates are drawn from the generator in four
   (angle, radius) pairs.  Each pair is mapped to a cosine-like and a
   sine-like component using the angle 2*pi*u and the radius
   sqrt(-log(u')).

   @tparam real Arithmetic type used for the components
   @tparam Link Matrix type that is returned; only the entries (i,j)
   with i,j in {0,1,2} are written
   @param[in,out] localState Generator state, advanced by eight draws
   @return The assembled matrix
*/
template <typename real, typename Link> __device__ __host__ Link gauss_su3(cuRNGState &localState)
{
  real c[4]; // radius * cos(angle) for each pair
  real s[4]; // radius * sin(angle) for each pair

  for (int k = 0; k < 4; ++k) {
    // draw order per pair is (angle deviate, radius deviate)
    const real u_angle = Random<real>(localState);
    const real u_radius = Random<real>(localState);

    const real angle = 2.0 * M_PI * u_angle;
    const real r = sqrt(-log(u_radius));

    sincos(angle, &s[k], &c[k]);
    c[k] *= r;
    s[k] *= r;
  }

  const auto inv_sqrt3 = rsqrt(3.0);

  Link a;

  // purely imaginary diagonal whose three entries sum to zero
  a(0, 0) = complex<real>(0.0, c[2] + inv_sqrt3 * s[3]);
  a(1, 1) = complex<real>(0.0, -c[2] + inv_sqrt3 * s[3]);
  a(2, 2) = complex<real>(0.0, -2.0 * inv_sqrt3 * s[3]);

  // off-diagonal pairs satisfy a(j,i) = -conj(a(i,j))
  a(0, 1) = complex<real>(c[0], c[1]);
  a(1, 0) = complex<real>(-c[0], c[1]);

  a(0, 2) = complex<real>(c[3], s[0]);
  a(2, 0) = complex<real>(-c[3], s[0]);

  a(1, 2) = complex<real>(s[1], s[2]);
  a(2, 1) = complex<real>(-s[1], s[2]);

  return a;
}
75 
/**
   @brief Kernel that fills all four links of each interior site with
   either the identity, a random group element, or a random algebra
   element.

   Thread layout: the x thread index selects the checkerboard site
   index (bounded by arg.threads), the y thread index selects the
   parity.

   - Arg::group && sigma == 0: every link is set to the identity and
     the random number generator is not touched.
   - Arg::group otherwise: each link is sigma * gauss_su3 passed
     through expsu3.
   - !Arg::group: each link is the raw gauss_su3 result.

   @tparam Arg Argument struct type (see GaugeGaussArg)
   @param[in,out] arg Kernel arguments; arg.U and the generator state
   reachable through arg.rngstate are written
*/
template <typename Arg> __global__ void computeGenGauss(Arg arg)
{
  using real = typename mapper<typename Arg::Float>::type;
  using Link = Matrix<complex<real>, Arg::nColor>;

  int x_cb = threadIdx.x + blockIdx.x * blockDim.x;
  int parity = threadIdx.y + blockIdx.y * blockDim.y;
  if (x_cb >= arg.threads) return;
  // defensive guard: only parities 0 and 1 exist, so never index past them
  if (parity >= 2) return;

  int x[4];
  getCoords(x, x_cb, arg.X, parity);
  for (int dr = 0; dr < 4; ++dr) x[dr] += arg.border[dr]; // extended grid coordinates

  // the site index does not depend on the link direction, so compute it once
  const auto site = linkIndex(x, arg.E);

  if (arg.group && arg.sigma == 0.0) {
    // if sigma = 0 then we just set the output matrix to the identity and finish
    Link I;
    setIdentity(&I);
    for (int mu = 0; mu < 4; mu++) arg.U(mu, site, parity) = I;
  } else {
    // load this thread's generator state once, use it for all four
    // directions, and write it back once at the end
    const int rng_idx = parity * arg.threads + x_cb;
    cuRNGState localState = arg.rngstate.State()[rng_idx];

    for (int mu = 0; mu < 4; mu++) {
      // generate Gaussian distributed su(n) field
      Link u = gauss_su3<real, Link>(localState);
      if (arg.group) {
        u = arg.sigma * u;
        expsu3<real>(u); // expected to update u in place
      }
      arg.U(mu, site, parity) = u;
    }

    arg.rngstate.State()[rng_idx] = localState;
  }
}
109 
/**
   @brief Launcher for the computeGenGauss kernel, built on
   TunableVectorY with a y-extent of 2.
   @tparam Arg Kernel argument struct type
*/
template <typename Arg> class GaugeGauss : TunableVectorY
{
  Arg &arg;               // kernel arguments (held by reference, owned by the caller)
  const GaugeField &meta; // field used for the tune key and the byte count

  // Don't tune the grid dimensions.
  bool tuneGridDim() const { return false; }

  // The x-extent of the launch must cover arg.threads.
  unsigned int minThreads() const { return arg.threads; }

public:
  /**
     @param[in,out] arg Kernel arguments
     @param[in] meta Field that is being generated
  */
  GaugeGauss(Arg &arg, GaugeField &meta) : TunableVectorY(2), arg(arg), meta(meta) { }

  /**
     @brief Obtain launch parameters from tuneLaunch and launch the kernel.
     @param[in] stream Stream on which the kernel is launched
  */
  void apply(const qudaStream_t &stream)
  {
    TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
    qudaLaunchKernel(computeGenGauss<Arg>, tp, stream, arg);
  }

  // Key built from the field's volume string, this class's type name and the field's aux string.
  TuneKey tuneKey() const
  {
    return TuneKey(meta.VolString(), typeid(*this).name(), meta.AuxString());
  }

  // Save the generator state before tuning and put it back afterwards.
  void preTune() { arg.rngstate.backup(); }
  void postTune() { arg.rngstate.restore(); }

  long long flops() const { return 0; }
  long long bytes() const { return meta.Bytes(); }
};
138 
/**
   @brief Functor-style struct whose constructor generates group
   elements (GaugeGaussArg with group = true) into U.
   @tparam Float Storage precision type
   @tparam nColor Number of colors
   @tparam recon Reconstruction type
*/
template <typename Float, int nColor, QudaReconstructType recon> struct GenGaussGroup {
  /**
     @param[out] U Field to fill
     @param[in,out] rngstate Random number generator state
     @param[in] sigma Width parameter forwarded to the kernel arguments
  */
  GenGaussGroup(GaugeField &U, RNG &rngstate, double sigma)
  {
    using Arg = GaugeGaussArg<Float, nColor, recon, true>;
    Arg arg(U, rngstate, sigma);
    GaugeGauss<Arg> launcher(arg, U);
    launcher.apply(0);
  }
};
149 
/**
   @brief Functor-style struct whose constructor generates algebra
   elements (GaugeGaussArg with group = false) into U.
   @tparam Float Storage precision type
   @tparam nColor Number of colors
   @tparam recon Reconstruction type
*/
template <typename Float, int nColor, QudaReconstructType recon> struct GenGaussAlgebra {
  /**
     @param[out] U Field to fill
     @param[in,out] rngstate Random number generator state
     @param[in] sigma Width parameter forwarded to the kernel arguments
  */
  GenGaussAlgebra(GaugeField &U, RNG &rngstate, double sigma)
  {
    using Arg = GaugeGaussArg<Float, nColor, recon, false>;
    Arg arg(U, rngstate, sigma);
    GaugeGauss<Arg> launcher(arg, U);
    launcher.apply(0);
  }
};
160 
/**
   @brief Fill U with random values using an existing generator state.

   The link type of U selects what is generated: QUDA_SU3_LINKS
   dispatches to GenGaussGroup, QUDA_MOMENTUM_LINKS dispatches to
   GenGaussAlgebra, anything else is reported through errorQuda.
   After generation the ghost zones are exchanged according to the
   ghost-exchange mode of U.

   @param[out] U Field to fill; must be in a native order
   @param[in,out] rng Random number generator state
   @param[in] sigma Width parameter
*/
void gaugeGauss(GaugeField &U, RNG &rng, double sigma)
{
  if (!U.isNative()) errorQuda("Order %d with %d reconstruct not supported", U.Order(), U.Reconstruct());

  switch (U.LinkType()) {
  case QUDA_SU3_LINKS:
    if (getVerbosity() >= QUDA_SUMMARIZE)
      printfQuda("Creating Gaussian distrbuted gauge field with sigma = %e\n", sigma);
    instantiate<GenGaussGroup, ReconstructFull>(U, rng, sigma);
    break;

  case QUDA_MOMENTUM_LINKS:
    if (getVerbosity() >= QUDA_SUMMARIZE)
      printfQuda("Creating Gaussian distrbuted momentum field\n");
    instantiate<GenGaussAlgebra, ReconstructMom>(U, rng, sigma);
    break;

  default:
    errorQuda("Unexpected linkt type %d", U.LinkType());
    break;
  }

  // ensure multi-gpu consistency if required
  switch (U.GhostExchange()) {
  case QUDA_GHOST_EXCHANGE_EXTENDED: U.exchangeExtendedGhost(U.R()); break;
  case QUDA_GHOST_EXCHANGE_PAD: U.exchangeGhost(); break;
  default: break; // no exchange for any other mode
  }
}
188 
/**
   @brief Fill U with random values using a temporary generator
   created from a seed.

   A generator is constructed for U with the given seed, initialized,
   used for a single call to the RNG-taking gaugeGauss overload, and
   released again.

   @param[out] U Field to fill
   @param[in] seed Seed passed to the RNG constructor
   @param[in] sigma Width parameter
*/
void gaugeGauss(GaugeField &U, unsigned long long seed, double sigma)
{
  // automatic storage instead of new/delete: the object is destroyed
  // on every exit path and no heap allocation is needed
  RNG randstates(U, seed);
  randstates.Init();
  quda::gaugeGauss(U, randstates, sigma);
  randstates.Release();
}
197 }