QUDA  1.0.0
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
gauge_random.cu
Go to the documentation of this file.
1 #include <quda_internal.h>
2 #include <quda_matrix.h>
3 #include <tune_quda.h>
4 #include <gauge_field.h>
5 #include <gauge_field_order.h>
6 #include <launch_kernel.cuh>
7 #include <atomic.cuh>
8 #include <cub_helper.cuh>
9 #include <index_helper.cuh>
10 #include <random_quda.h>
11 
12 namespace quda {
13 
// GaugeGaussArg: kernel argument struct for computeGenGauss.
//   Float  - storage precision of the gauge field (real = mapper<Float>::type is the compute type)
//   recon  - reconstruction type of the gauge field
//   group_ - true when generating SU(3) group links (gaugeGauss dispatches true for
//            QUDA_SU3_LINKS), false for su(3) momentum fields (QUDA_MOMENTUM_LINKS)
// NOTE(review): source lines 15, 22 and 23 are absent from this listing. The cross-reference
// lists `typename gauge_mapper<Float, recon>::type Gauge` at gauge_random.cu:15, and the
// constructor initializes members `U` and `rngstate`, so lines 22-23 presumably declare
// those two members -- confirm against the source file.
14  template <typename Float, QudaReconstructType recon, bool group_> struct GaugeGaussArg {
16  using real = typename mapper<Float>::type;
17  static constexpr bool group = group_; // true: exponentiate to SU(3); false: leave in su(3)
18  int threads; // number of active threads required
19  int E[4]; // extended grid dimensions
20  int X[4]; // true grid dimensions
21  int border[4]; // per-dimension halo width, copied from U.R()
24  real sigma; // where U = exp(sigma * H)
25 
  // Constructor: E = U.X() (full, possibly extended, dimensions), border = U.R(),
  // interior X = E - 2 * border, and threads = interior volume / 2 (sites per parity).
26  GaugeGaussArg(const GaugeField &U, RNG &rngstate, double sigma) : U(U), rngstate(rngstate), sigma(sigma)
27  {
28  int R = 0; // NOTE(review): R is accumulated below but never read -- looks removable
29  for (int dir = 0; dir < 4; ++dir) {
30  border[dir] = U.R()[dir];
31  E[dir] = U.X()[dir];
32  X[dir] = U.X()[dir] - border[dir] * 2;
33  R += border[dir];
34  }
35  threads = X[0]*X[1]*X[2]*X[3]/2;
36  }
37  };
38 
/**
   @brief Draw a random traceless anti-Hermitian 3x3 matrix from the RNG.

   Four pairs of uniform draws (u1, u2) are taken from localState, in the order
   u1, u2, u1, u2, ... Each pair is mapped as
     r = sqrt(-log(u2)),  angle = 2*pi*u1,
     temp1 = r*cos(angle), temp2 = r*sin(angle).
   (Note that the radius has no factor of 2 inside the square root.)
   The eight resulting reals fill the eight real degrees of freedom of the matrix:
   - The diagonal is purely imaginary and sums to zero.
   - Each off-diagonal pair satisfies A(j,i) = -conj(A(i,j)).

   NOTE(review): 2.0, M_PI and rsqrt(3.0) are double precision, so for real = float
   these intermediates are evaluated in double before narrowing. This is kept as-is
   to preserve the generated values exactly.

   @tparam real Floating-point type of the matrix elements
   @tparam Link 3x3 complex matrix type
   @param[in,out] localState RNG state, advanced by eight draws
   @return The generated matrix
*/
template <typename real, typename Link> __device__ __host__ Link gauss_su3(cuRNGState &localState)
{
  // temp1[k] holds r*cos(angle), temp2[k] holds r*sin(angle) for draw pair k
  real temp1[4], temp2[4];

  for (int k = 0; k < 4; ++k) {
    // The transform below never touches the RNG, so merging draws and transform
    // into one loop keeps the original draw order.
    const real u1 = Random<real>(localState);
    const real u2 = Random<real>(localState);
    const real angle = 2.0 * M_PI * u1;
    const real r = sqrt(-log(u2));
    sincos(angle, &temp2[k], &temp1[k]);
    temp1[k] *= r;
    temp2[k] *= r;
  }

  const auto inv_sqrt3 = rsqrt(3.0);
  Link A;

  // purely imaginary, traceless diagonal
  A(0, 0) = complex<real>(0.0, temp1[2] + inv_sqrt3 * temp2[3]);
  A(1, 1) = complex<real>(0.0, -temp1[2] + inv_sqrt3 * temp2[3]);
  A(2, 2) = complex<real>(0.0, -2.0 * inv_sqrt3 * temp2[3]);

  // off-diagonal elements, paired so that A(j,i) = -conj(A(i,j))
  A(0, 1) = complex<real>(temp1[0], temp1[1]);
  A(1, 0) = complex<real>(-temp1[0], temp1[1]);
  A(0, 2) = complex<real>(temp1[3], temp2[0]);
  A(2, 0) = complex<real>(-temp1[3], temp2[0]);
  A(1, 2) = complex<real>(temp2[1], temp2[2]);
  A(2, 1) = complex<real>(-temp2[1], temp2[2]);

  return A;
}
70 
/**
   @brief Fill a gauge field with Gaussian-distributed random links.

   Launch layout: the x dimension indexes the checkerboard site x_cb
   (x_cb < arg.threads) and the y dimension indexes the parity. Each thread writes
   all four link directions mu = 0..3 at its site. Coordinates are shifted by
   arg.border, so only the interior of an extended field is written.

   - If arg.group is true and arg.sigma == 0, every link is set to the identity and
     the RNG state is not touched.
   - Otherwise a random su(3) matrix is drawn per link with gauss_su3. If arg.group
     is true, it is scaled by sigma and passed through expsu3.

   The per-site RNG state is loaded into a local variable once, advanced by all
   four draws, and written back once. This yields exactly the same random stream
   as reloading and storing it per mu, with fewer global-memory round trips.

   @param[in,out] arg Kernel arguments (gauge accessor, RNG states, geometry, sigma)
*/
template <typename Float, typename Arg> __global__ void computeGenGauss(Arg arg)
{
  using real = typename mapper<Float>::type;
  using Link = Matrix<complex<real>, 3>;
  int x_cb = threadIdx.x + blockIdx.x * blockDim.x;
  int parity = threadIdx.y + blockIdx.y * blockDim.y;
  if (x_cb >= arg.threads) return;

  int x[4];
  getCoords(x, x_cb, arg.X, parity);
  for (int dr = 0; dr < 4; ++dr) x[dr] += arg.border[dr]; // extended grid coordinates

  // the site index is the same for all four link directions
  const int index = linkIndex(x, arg.E);

  if (arg.group && arg.sigma == 0.0) {
    // if sigma = 0 then we just set the output matrix to the identity and finish
    Link I;
    setIdentity(&I);
    for (int mu = 0; mu < 4; mu++) arg.U(mu, index, parity) = I;
  } else {
    // one RNG state per (parity, x_cb): load once, advance for all mu, store once
    const int rng_index = parity * arg.threads + x_cb;
    cuRNGState localState = arg.rngstate.State()[rng_index];

    for (int mu = 0; mu < 4; mu++) {
      // generate Gaussian distributed su(n) field
      Link u = gauss_su3<real, Link>(localState);
      if (arg.group) {
        u = arg.sigma * u;
        expsu3<real>(u);
      }
      arg.U(mu, index, parity) = u;
    }

    arg.rngstate.State()[rng_index] = localState;
  }
}
104 
/**
   @brief Autotunable launcher for the computeGenGauss kernel.

   The x dimension of the launch covers arg.threads checkerboard sites. The y
   dimension (TunableVectorY(2)) covers the two parities. Only device-resident
   fields (QUDA_CUDA_FIELD_LOCATION) are supported; any other location is reported
   with errorQuda.

   preTune()/postTune() back up and restore arg.rngstate. Presumably this is so that
   the trial launches made while autotuning do not advance the random stream.

   @tparam Float Storage precision forwarded to computeGenGauss
   @tparam Arg   Kernel argument type (a GaugeGaussArg instantiation)
*/
template <typename Float, typename Arg> class GaugeGauss : TunableVectorY
{
  Arg &arg;               // kernel arguments (gauge accessor, RNG states, geometry, sigma)
  const GaugeField &meta; // field supplying location, tuning key and byte count

private:
  unsigned int minThreads() const { return arg.threads; }
  bool tuneGridDim() const { return false; } // Don't tune the grid dimensions.

public:
  GaugeGauss(Arg &arg, GaugeField &meta) : TunableVectorY(2), arg(arg), meta(meta) {}

  /**
     @brief Tune (if enabled) and launch computeGenGauss on the given stream.
     @param[in] stream CUDA stream to launch on
  */
  void apply(const cudaStream_t &stream)
  {
    if (meta.Location() == QUDA_CUDA_FIELD_LOCATION) {
      TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
      // honor the caller's stream rather than always using the default stream
      computeGenGauss<Float><<<tp.grid, tp.block, tp.shared_bytes, stream>>>(arg);
    } else {
      errorQuda("Randomize GaugeFields on CPU not supported yet\n");
    }
  }

  TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), meta.AuxString()); }

  long long flops() const { return 0; }
  long long bytes() const { return meta.Bytes(); }

  void preTune() { arg.rngstate.backup(); }
  void postTune() { arg.rngstate.restore(); }
};
136 
// genGauss: build the kernel arguments for one (Float, recon, group) instantiation
// and run the GaugeGauss launcher on them, requesting stream 0.
// NOTE(review): source lines 138 and 141 are absent from this listing. The cross-reference
// gives the signature as `void genGauss(GaugeField &U, RNG &rngstate, double sigma)`, and
// line 141 presumably constructs the launcher object `gaugeGauss` (a GaugeGauss<Float, ...>)
// from `arg` and `U` -- confirm against the source file.
137  template <typename Float, QudaReconstructType recon, bool group>
139  {
140  GaugeGaussArg<Float, recon, group> arg(U, rngstate, sigma);
142  gaugeGauss.apply(0);
143  }
144 
/**
   @brief Generate a Gaussian random field for a fixed storage precision.

   Dispatches on U.LinkType() and then on U.Reconstruct():
   - QUDA_SU3_LINKS: group field, genGauss<Float, recon, true>. Supported
     reconstructions are NO, 13, 12, 9 and 8.
   - QUDA_MOMENTUM_LINKS: algebra (momentum) field, genGauss<Float, recon, false>.
     Supported reconstructions are NO and 10.
   Any other link type or reconstruction is reported with errorQuda instead of
   silently leaving U unmodified.

   @tparam Float Storage precision of U
   @param[in,out] U        Field to fill
   @param[in,out] rngstate RNG states, advanced by the generation
   @param[in]     sigma    Width of the generated field (used for group fields)
*/
template <typename Float> void gaugeGauss(GaugeField &U, RNG &rngstate, double sigma)
{
  if (U.LinkType() == QUDA_SU3_LINKS) { // generate Gaussian distributed SU(3) field
    if (getVerbosity() >= QUDA_SUMMARIZE)
      printfQuda("Creating Gaussian distrbuted gauge field with sigma = %e\n", sigma);
    switch (U.Reconstruct()) {
    case QUDA_RECONSTRUCT_NO: genGauss<Float, QUDA_RECONSTRUCT_NO, true>(U, rngstate, sigma); break;
    case QUDA_RECONSTRUCT_13: genGauss<Float, QUDA_RECONSTRUCT_13, true>(U, rngstate, sigma); break;
    case QUDA_RECONSTRUCT_12: genGauss<Float, QUDA_RECONSTRUCT_12, true>(U, rngstate, sigma); break;
    case QUDA_RECONSTRUCT_9: genGauss<Float, QUDA_RECONSTRUCT_9, true>(U, rngstate, sigma); break;
    case QUDA_RECONSTRUCT_8: genGauss<Float, QUDA_RECONSTRUCT_8, true>(U, rngstate, sigma); break;
    default: errorQuda("Reconstruction type %d of gauge field not supported", U.Reconstruct());
    }
  } else if (U.LinkType() == QUDA_MOMENTUM_LINKS) { // generate Gaussian distributed su(3) field
    if (getVerbosity() >= QUDA_SUMMARIZE) printfQuda("Creating Gaussian distrbuted momentum field\n");
    switch (U.Reconstruct()) {
    case QUDA_RECONSTRUCT_NO: genGauss<Float, QUDA_RECONSTRUCT_NO, false>(U, rngstate, sigma); break;
    case QUDA_RECONSTRUCT_10: genGauss<Float, QUDA_RECONSTRUCT_10, false>(U, rngstate, sigma); break;
    default: errorQuda("Reconstruction type %d of gauge field not supported", U.Reconstruct());
    }
  } else {
    // previously fell through silently, leaving U untouched
    errorQuda("Link type %d not supported", U.LinkType());
  }
}
167 
// gaugeGauss: public entry point that fills U with a Gaussian random field.
// Rejects non-native field orders and Nc != 3, dispatches on precision (double and
// single only; anything else is reported with errorQuda), then refreshes the halo.
// NOTE(review): source line 180 is absent from this listing. Given the `} else if`
// on line 182, it presumably opens `if (U.GhostExchange() == QUDA_GHOST_EXCHANGE_EXTENDED) {`
// -- confirm against the source file.
168  void gaugeGauss(GaugeField &U, RNG &rngstate, double sigma)
169  {
170  if (!U.isNative()) errorQuda("Order %d with %d reconstruct not supported", U.Order(), U.Reconstruct());
171  if (U.Ncolor() != 3) errorQuda("Nc = %d not supported", U.Ncolor());
172 
173  switch (U.Precision()) {
174  case QUDA_DOUBLE_PRECISION: gaugeGauss<double>(U, rngstate, sigma); break;
175  case QUDA_SINGLE_PRECISION: gaugeGauss<float>(U, rngstate, sigma); break;
176  default: errorQuda("Precision %d not supported", U.Precision());
177  }
178 
179  // ensure multi-gpu consistency if required
  // extended fields refill their border of width U.R(); padded fields exchange ghost zones
181  U.exchangeExtendedGhost(U.R());
182  } else if (U.GhostExchange() == QUDA_GHOST_EXCHANGE_PAD) {
183  U.exchangeGhost();
184  }
185  }
186 
/**
   @brief Seed-based convenience overload of gaugeGauss.

   Creates an RNG for U from seed and initializes its states. It then fills U via
   quda::gaugeGauss(U, rng, sigma) and releases the RNG's device memory. The RNG
   is a local object, so it is destroyed automatically when the function returns.

   @param[in,out] U     Gauge or momentum field to fill
   @param[in]     seed  Seed for the random number generator
   @param[in]     sigma Width parameter forwarded unchanged
*/
void gaugeGauss(GaugeField &U, unsigned long long seed, double sigma)
{
  RNG rng(U, seed);
  rng.Init(); // initialize CURAND RNG states
  quda::gaugeGauss(U, rng, sigma);
  rng.Release(); // release device memory for the RNG states
}
195 }
void Init()
Initialize CURAND RNG states.
Definition: random.cu:122
double mu
Definition: test_util.cpp:1648
struct curandStateMRG32k3a cuRNGState
Definition: random_quda.h:17
static __device__ __host__ int linkIndex(const int x[], const I X[4])
const char * AuxString() const
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:21
bool tuneGridDim() const
GaugeGaussArg(const GaugeField &U, RNG &rngstate, double sigma)
Definition: gauge_random.cu:26
#define errorQuda(...)
Definition: util_quda.h:121
GaugeGauss(Arg &arg, GaugeField &meta)
__host__ __device__ ValueType sqrt(ValueType x)
Definition: complex_quda.h:120
const GaugeField & meta
cudaStream_t * stream
QudaLinkType LinkType() const
Definition: gauge_field.h:255
static constexpr bool group
Definition: gauge_random.cu:17
__global__ void computeGenGauss(Arg arg)
Definition: gauge_random.cu:71
long long bytes() const
const char * VolString() const
static int R[4]
virtual void exchangeExtendedGhost(const int *R, bool no_comms_fill=false)=0
This routine will populate the border / halo region of a gauge field that has been created using...
int Ncolor() const
Definition: gauge_field.h:249
const int * R() const
void genGauss(GaugeField &U, RNG &rngstate, double sigma)
size_t Bytes() const
Definition: gauge_field.h:311
void Release()
Release device memory for CURAND RNG states.
Definition: random.cu:145
Class declaration to initialize and hold CURAND RNG states.
Definition: random_quda.h:23
unsigned int minThreads() const
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
Definition: tune.cpp:643
__device__ __host__ Link gauss_su3(cuRNGState &localState)
Definition: gauge_random.cu:39
Main header file for host and device accessors to GaugeFields.
__device__ __host__ void setIdentity(Matrix< T, N > *m)
Definition: quda_matrix.h:653
QudaFieldLocation Location() const
__host__ __device__ ValueType log(ValueType x)
Definition: complex_quda.h:101
#define printfQuda(...)
Definition: util_quda.h:115
virtual void exchangeGhost(QudaLinkDirection=QUDA_LINK_BACKWARDS)=0
typename mapper< Float >::type real
Definition: gauge_random.cu:16
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
QudaReconstructType Reconstruct() const
Definition: gauge_field.h:250
typename gauge_mapper< Float, recon >::type Gauge
Definition: gauge_random.cu:15
QudaGaugeFieldOrder Order() const
Definition: gauge_field.h:251
long long flops() const
TuneKey tuneKey() const
QudaGhostExchange GhostExchange() const
QudaTune getTuning()
Query whether autotuning is enabled or not. Default is enabled but can be overridden by setting QUDA_...
Definition: util_quda.cpp:52
QudaPrecision Precision() const
bool isNative() const
void apply(const cudaStream_t &stream)
QudaParity parity
Definition: covdev_test.cpp:54
void gaugeGauss(GaugeField &U, RNG &rngstate, double epsilon)
Generate Gaussian distributed su(N) or SU(N) fields. If U is a momentum field, then we generate rando...
__host__ __device__ int getCoords(int coord[], const Arg &arg, int &idx, int parity, int &dim)
Compute the space-time coordinates we are at.
const int * X() const