14 template <
typename Float, QudaReconstructType recon,
bool group_>
struct GaugeGaussArg {
17 static constexpr
bool group = group_;
29 for (
int dir = 0; dir < 4; ++dir) {
30 border[dir] = U.
R()[dir];
32 X[dir] = U.
X()[dir] - border[dir] * 2;
35 threads = X[0]*X[1]*X[2]*X[3]/2;
42 real rand1[4], rand2[4], phi[4], radius[4], temp1[4], temp2[4];
44 for (
int i = 0; i < 4; ++i) {
45 rand1[i] = Random<real>(localState);
46 rand2[i] = Random<real>(localState);
49 for (
int i = 0; i < 4; ++i) {
50 phi[i] = 2.0 * M_PI * rand1[i];
51 radius[i] =
sqrt(-
log(rand2[i]));
52 sincos(phi[i], &temp2[i], &temp1[i]);
53 temp1[i] *= radius[i];
54 temp2[i] *= radius[i];
58 ret(0, 0) = complex<real>(0.0, temp1[2] + rsqrt(3.0) * temp2[3]);
59 ret(1, 1) = complex<real>(0.0, -temp1[2] + rsqrt(3.0) * temp2[3]);
60 ret(2, 2) = complex<real>(0.0, -2.0 * rsqrt(3.0) * temp2[3]);
61 ret(0, 1) = complex<real>(temp1[0], temp1[1]);
62 ret(1, 0) = complex<real>(-temp1[0], temp1[1]);
63 ret(0, 2) = complex<real>(temp1[3], temp2[0]);
64 ret(2, 0) = complex<real>(-temp1[3], temp2[0]);
65 ret(1, 2) = complex<real>(temp2[1], temp2[2]);
66 ret(2, 1) = complex<real>(-temp2[1], temp2[2]);
75 int x_cb = threadIdx.x + blockIdx.x * blockDim.x;
76 int parity = threadIdx.y + blockIdx.y * blockDim.y;
77 if (x_cb >= arg.threads)
return;
81 for (
int dr = 0; dr < 4; ++dr) x[dr] += arg.border[dr];
83 if (arg.group && arg.sigma == 0.0) {
89 for (
int mu = 0;
mu < 4;
mu++) {
90 cuRNGState localState = arg.rngstate.State()[parity * arg.threads + x_cb];
93 Link u = gauss_su3<real, Link>(localState);
100 arg.rngstate.State()[parity * arg.threads + x_cb] = localState;
124 errorQuda(
"Randomize GaugeFields on CPU not supported yet\n");
130 long long flops()
const {
return 0; }
137 template <
typename Float, QudaReconstructType recon,
bool group>
149 printfQuda(
"Creating Gaussian distrbuted gauge field with sigma = %e\n", sigma);
156 default:
errorQuda(
"Reconstruction type %d of gauge field not supported", U.Reconstruct());
163 default:
errorQuda(
"Reconstruction type %d of gauge field not supported", U.Reconstruct());
176 default:
errorQuda(
"Precision %d not supported", U.Precision());
189 RNG *randstates =
new RNG(U, seed);
void Init()
Initialize CURAND RNG states.
struct curandStateMRG32k3a cuRNGState
static __device__ __host__ int linkIndex(const int x[], const I X[4])
const char * AuxString() const
QudaVerbosity getVerbosity()
GaugeGaussArg(const GaugeField &U, RNG &rngstate, double sigma)
GaugeGauss(Arg &arg, GaugeField &meta)
__host__ __device__ ValueType sqrt(ValueType x)
QudaLinkType LinkType() const
static constexpr bool group
__global__ void computeGenGauss(Arg arg)
const char * VolString() const
virtual void exchangeExtendedGhost(const int *R, bool no_comms_fill=false)=0
This routine will populate the border / halo region of a gauge field that has been created using...
void genGauss(GaugeField &U, RNG &rngstate, double sigma)
void Release()
Release Device memory for CURAND RNG states.
Class declaration to initialize and hold CURAND RNG states.
unsigned int minThreads() const
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
__device__ __host__ Link gauss_su3(cuRNGState &localState)
Main header file for host and device accessors to GaugeFields.
__device__ __host__ void setIdentity(Matrix< T, N > *m)
QudaFieldLocation Location() const
__host__ __device__ ValueType log(ValueType x)
virtual void exchangeGhost(QudaLinkDirection=QUDA_LINK_BACKWARDS)=0
typename mapper< Float >::type real
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
QudaReconstructType Reconstruct() const
typename gauge_mapper< Float, recon >::type Gauge
QudaGaugeFieldOrder Order() const
QudaGhostExchange GhostExchange() const
QudaTune getTuning()
Query whether autotuning is enabled or not. Default is enabled but can be overridden by setting QUDA_...
QudaPrecision Precision() const
void apply(const cudaStream_t &stream)
void gaugeGauss(GaugeField &U, RNG &rngstate, double epsilon)
Generate Gaussian distributed su(N) or SU(N) fields. If U is a momentum field, then we generate rando...
__host__ __device__ int getCoords(int coord[], const Arg &arg, int &idx, int parity, int &dim)
Compute the space-time coordinates we are at.