1 #include <quda_internal.h>
2 #include <quda_matrix.h>
4 #include <gauge_field.h>
5 #include <gauge_field_order.h>
6 #include <launch_kernel.cuh>
8 #include <index_helper.cuh>
9 #include <random_quda.h>
10 #include <instantiate.h>
// Argument bundle passed by value to the computeGenGauss kernel.
// Template parameters fix, at compile time, the storage type (Float_), the
// number of colours (nColor_), the gauge reconstruction type (recon_) and
// whether group (true) or algebra (false) output was requested (group_).
14 template <typename Float_, int nColor_, QudaReconstructType recon_, bool group_> struct GaugeGaussArg {
// scalar type that mapper<> associates with Float; the kernel builds its link matrices from it
16 using real = typename mapper<Float>::type;
17 static constexpr int nColor = nColor_;
18 static constexpr QudaReconstructType recon = recon_;
19 static constexpr bool group = group_;
// accessor type selected by gauge_mapper<> for this precision / reconstruction pair
21 using Gauge = typename gauge_mapper<Float, recon>::type;
23 int threads; // number of active threads required
24 int E[4]; // extended grid dimensions
25 int X[4]; // true grid dimensions
29 real sigma; // where U = exp(sigma * H)
// Captures the gauge accessor, the RNG state and sigma, then derives the
// lattice geometry used by the kernel from the field U.
31 GaugeGaussArg(const GaugeField &U, RNG &rngstate, double sigma) : U(U), rngstate(rngstate), sigma(sigma)
34 for (int dir = 0; dir < 4; ++dir) {
// U.R() supplies the per-dimension border width
35 border[dir] = U.R()[dir];
// strip the border twice (once per side) from the field extent to get the true extent
37 X[dir] = U.X()[dir] - border[dir] * 2;
// half of the true volume: the kernel takes the site index from its x
// dimension and the parity from its y dimension
40 threads = X[0]*X[1]*X[2]*X[3]/2;
// Builds a random traceless anti-Hermitian 3x3 matrix from eight Gaussian-style
// random numbers drawn from localState.
//   real       - scalar type used for the arithmetic
//   Link       - matrix type being filled (indexed as ret(row, col))
//   localState - RNG state, taken by reference and handed to Random<real>()
// The indices below are hard-coded for a 3x3 matrix.
44 template <typename real, typename Link> __device__ __host__ Link gauss_su3(cuRNGState &localState)
47 real rand1[4], rand2[4], phi[4], radius[4], temp1[4], temp2[4];
// draw four pairs of random numbers from Random<real>()
49 for (int i = 0; i < 4; ++i) {
50 rand1[i] = Random<real>(localState);
51 rand2[i] = Random<real>(localState);
// Box-Muller-style polar transform: rand1 gives the angle, rand2 the radius.
// Note the radius is sqrt(-log r), i.e. without the factor 2 of the textbook form.
54 for (int i = 0; i < 4; ++i) {
55 phi[i] = 2.0 * M_PI * rand1[i];
56 radius[i] = sqrt(-log(rand2[i]));
// sincos writes sin(phi) to temp2 and cos(phi) to temp1
57 sincos(phi[i], &temp2[i], &temp1[i]);
58 temp1[i] *= radius[i];
59 temp2[i] *= radius[i];
62 // construct Anti-Hermitian matrix
// diagonal entries are purely imaginary and sum to zero (traceless)
63 ret(0, 0) = complex<real>(0.0, temp1[2] + rsqrt(3.0) * temp2[3]);
64 ret(1, 1) = complex<real>(0.0, -temp1[2] + rsqrt(3.0) * temp2[3]);
65 ret(2, 2) = complex<real>(0.0, -2.0 * rsqrt(3.0) * temp2[3]);
// each off-diagonal pair satisfies ret(j, i) = -conj(ret(i, j))
66 ret(0, 1) = complex<real>(temp1[0], temp1[1]);
67 ret(1, 0) = complex<real>(-temp1[0], temp1[1]);
68 ret(0, 2) = complex<real>(temp1[3], temp2[0]);
69 ret(2, 0) = complex<real>(-temp1[3], temp2[0]);
70 ret(1, 2) = complex<real>(temp2[1], temp2[2]);
71 ret(2, 1) = complex<real>(-temp2[1], temp2[2]);
// Kernel that fills the four links of one lattice site per thread.
// Launch layout: the x dimension enumerates checkerboard site indices
// (0 .. arg.threads-1), the y dimension enumerates parity.
76 template <typename Arg> __global__ void computeGenGauss(Arg arg)
78 using real = typename mapper<typename Arg::Float>::type;
79 using Link = Matrix<complex<real>, Arg::nColor>;
80 int x_cb = threadIdx.x + blockIdx.x * blockDim.x;
81 int parity = threadIdx.y + blockIdx.y * blockDim.y;
// threads beyond the required count do nothing
82 if (x_cb >= arg.threads) return;
// convert the checkerboard index to 4-d coordinates on the true grid
85 getCoords(x, x_cb, arg.X, parity);
86 for (int dr = 0; dr < 4; ++dr) x[dr] += arg.border[dr]; // extended grid coordinates
88 if (arg.group && arg.sigma == 0.0) {
89 // if sigma = 0 then we just set the output matrix to the identity and finish
92 for (int mu = 0; mu < 4; mu++) arg.U(mu, linkIndex(x, arg.E), parity) = I;
94 for (int mu = 0; mu < 4; mu++) {
// load the RNG state belonging to this (parity, site) pair
95 cuRNGState localState = arg.rngstate.State()[parity * arg.threads + x_cb];
97 // generate Gaussian distributed su(n) field
98 Link u = gauss_su3<real, Link>(localState);
// store the link for direction mu at this site's extended-grid index
103 arg.U(mu, linkIndex(x, arg.E), parity) = u;
// write the (possibly updated) RNG state back to its slot
105 arg.rngstate.State()[parity * arg.threads + x_cb] = localState;
// Host-side launcher for computeGenGauss, derived from TunableVectorY.
// Holds the kernel argument and the gauge field whose strings and byte count
// are used for the tuning key and the traffic estimate.
110 template <typename Arg> class GaugeGauss : TunableVectorY
113 const GaugeField &meta;
// the launch must cover at least arg.threads threads in x
115 unsigned int minThreads() const { return arg.threads; }
116 bool tuneGridDim() const { return false; } // Don't tune the grid dimensions.
119 GaugeGauss(Arg &arg, GaugeField &meta) :
// Obtain launch parameters via tuneLaunch() and launch the kernel on the given stream.
124 void apply(const qudaStream_t &stream)
126 TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
127 qudaLaunchKernel(computeGenGauss<Arg>, tp, stream, arg);
// tuning key: volume string, this class's type name, and the field's aux string
130 TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), meta.AuxString()); }
// no floating-point operations are reported for this kernel
132 long long flops() const { return 0; }
// reported memory traffic is the byte size of the gauge field
133 long long bytes() const { return meta.Bytes(); }
// back up the RNG state before tuning and restore it afterwards
// (presumably so that tuning launches leave the random stream unchanged - confirm)
135 void preTune() { arg.rngstate.backup(); }
136 void postTune() { arg.rngstate.restore(); }
// Functor-style wrapper used with instantiate<>: its constructor sets up the
// kernel argument and launcher with the group flag set to true.
139 template <typename Float, int nColor, QudaReconstructType recon>
140 struct GenGaussGroup {
141 GenGaussGroup(GaugeField &U, RNG &rngstate, double sigma)
// group = true selects the group-valued variant of GaugeGaussArg
143 constexpr bool group = true;
144 GaugeGaussArg<Float, nColor, recon, group> arg(U, rngstate, sigma);
145 GaugeGauss<decltype(arg)> gaugeGauss(arg, U);
// Functor-style wrapper used with instantiate<>: its constructor sets up the
// kernel argument and launcher with the group flag set to false.
150 template <typename Float, int nColor, QudaReconstructType recon>
151 struct GenGaussAlgebra {
152 GenGaussAlgebra(GaugeField &U, RNG &rngstate, double sigma)
// group = false selects the algebra-valued variant of GaugeGaussArg
154 constexpr bool group = false;
155 GaugeGaussArg<Float, nColor, recon, group> arg(U, rngstate, sigma);
156 GaugeGauss<decltype(arg)> gaugeGauss(arg, U);
// Fill the gauge field U with Gaussian-distributed values using an existing RNG.
//   U     - field to fill; must be in a native order
//   rng   - random number generator state used by the kernel
//   sigma - width parameter forwarded to the kernel argument
// Dispatches on U.LinkType(): SU(3) links use the group variant, momentum
// links use the algebra variant; any other link type is an error.
161 void gaugeGauss(GaugeField &U, RNG &rng, double sigma)
// only native field orders are handled
163 if (!U.isNative()) errorQuda("Order %d with %d reconstruct not supported", U.Order(), U.Reconstruct());
165 if (U.LinkType() == QUDA_SU3_LINKS) {
167 if (getVerbosity() >= QUDA_SUMMARIZE)
168 printfQuda("Creating Gaussian distrbuted gauge field with sigma = %e\n", sigma);
// instantiate over the ReconstructFull set of reconstruction types
169 instantiate<GenGaussGroup, ReconstructFull>(U, rng, sigma);
171 } else if (U.LinkType() == QUDA_MOMENTUM_LINKS) {
173 if (getVerbosity() >= QUDA_SUMMARIZE)
174 printfQuda("Creating Gaussian distrbuted momentum field\n");
// instantiate over the ReconstructMom set of reconstruction types
175 instantiate<GenGaussAlgebra, ReconstructMom>(U, rng, sigma);
178 errorQuda("Unexpected linkt type %d", U.LinkType());
181 // ensure multi-gpu consistency if required
// extended fields refresh their border region through exchangeExtendedGhost
182 if (U.GhostExchange() == QUDA_GHOST_EXCHANGE_EXTENDED) {
183 U.exchangeExtendedGhost(U.R());
184 } else if (U.GhostExchange() == QUDA_GHOST_EXCHANGE_PAD) {
// Convenience overload: builds a temporary RNG for U from the given seed,
// forwards to the RNG-based gaugeGauss, and releases the RNG afterwards.
//   U     - field to fill
//   seed  - seed passed to the RNG constructor
//   sigma - width parameter forwarded unchanged
189 void gaugeGauss(GaugeField &U, unsigned long long seed, double sigma)
191 RNG *randstates = new RNG(U, seed);
// explicit quda:: qualification selects the namespace-level overload taking an RNG
193 quda::gaugeGauss(U, *randstates, sigma);
194 randstates->Release();