// NOTE(review): this is a fragmentary source listing; several original lines
// (member declarations, closing braces) are not visible here.
14 #ifdef GPU_GAUGE_TOOLS 16 template <
typename Gauge>
// Argument bundle for the computeGenGauss kernel: the gauge-field accessor
// (dataDs), a reference to the RNG state (rngstate), and the lattice geometry.
17 struct GaugeGaussArg {
// Builds the geometry from the GaugeField, per dimension dir = 0..3:
//   border[dir] = data.R()[dir]                 border width reported by data.R()
//   E[dir]      = data.X()[dir]                 full extent, including borders
//   X[dir]      = data.X()[dir] - 2*border[dir] interior extent (borders stripped)
25 GaugeGaussArg(
const Gauge &dataDs,
const GaugeField &data,
RNG &rngstate)
26 : dataDs(dataDs), rngstate(rngstate)
29 for (
int dir=0; dir<4; ++dir){
30 border[dir] = data.
R()[dir];
31 E[dir] = data.
X()[dir];
32 X[dir] = data.
X()[dir] - border[dir]*2;
// Thread count is half of the interior volume X[0]*X[1]*X[2]*X[3]
// (presumably one thread per site of a single parity — TODO confirm).
35 threads =
X[0]*
X[1]*
X[2]*
X[3]/2;
40 template<
typename Float>
// genGaussSU3: draws random numbers from localState and assembles eight
// real coefficients into a 3x3 complex matrix `ret` (the function signature
// and return-type declaration are not visible in this listing).
//
// NOTE(review): despite the name, the visible code builds a Hermitian,
// traceless matrix  ret = sum_a c_a * lambda_a  over the eight Gell-Mann
// matrices lambda_1..lambda_8, with coefficients
//   c1=temp1[0]  c2=temp1[1]  c3=temp1[2]  c4=temp1[3]
//   c5=temp2[0]  c6=temp2[1]  c7=temp2[2]  c8=temp2[3]
// and lambda_8 = diag(1, 1, -2)/sqrt(3). Any mapping into the SU(3) group
// would have to happen in the caller — confirm.
46 Float rand1[4], rand2[4], phi[4], radius[4], temp1[4], temp2[4];
// Draw four pairs (rand1[i], rand2[i]) via Random<Float>(localState)
// (presumably uniform deviates — TODO confirm).
48 for (
int i=0;
i<4; ++
i)
50 rand1[
i]= Random<Float>(localState);
51 rand2[
i]= Random<Float>(localState);
// Angle for each pair: phi = 2*pi*rand1.
// NOTE(review): 2.0*M_PI is a double expression, so for Float=float this
// product is evaluated in double precision and then narrowed into phi[i].
54 for (
int i=0;
i<4; ++
i)
56 phi[
i]=2.0*M_PI*rand1[
i];
// radius[i] is computed on lines not visible in this listing. The pair
// (radius*cos(phi), radius*sin(phi)) presumably forms a Box-Muller transform
// giving two independent normal deviates per pair — TODO confirm.
60 temp1[
i] = radius[
i]*
cos(phi[
i]);
61 temp2[
i] = radius[
i]*
sin(phi[
i]);
// Assemble the matrix. Off-diagonal entries are complex-conjugate pairs
// (ret(j,i) == conj(ret(i,j))), and the three diagonal entries sum to zero.
64 ret(0,0) = complex<Float>( temp1[2] + 1./
sqrt(3.0)*temp2[3], 0.0);
65 ret(0,1) = complex<Float>( temp1[0], -temp1[1]);
66 ret(0,2) = complex<Float>( temp1[3], -temp2[0]);
67 ret(1,0) = complex<Float>( temp1[0], temp1[1] );
68 ret(1,1) = complex<Float>( -temp1[2] + 1./
sqrt(3.0) * temp2[3], 0.0 );
69 ret(1,2) = complex<Float>( temp2[1], -temp2[2] );
70 ret(2,0) = complex<Float>( temp1[3], temp2[0] );
71 ret(2,1) = complex<Float>( temp2[1], temp2[2] );
72 ret(2,2) = complex<Float>( - 2./
sqrt(3.0) * temp2[3], 0.0 );
78 template<
typename Float,
typename Gauge>
// Kernel: generates random matrices with genGaussSU3<Float> for the gauge
// field described by `arg`. `arg` is passed by value as a kernel parameter.
// NOTE(review): the thread-index/coordinate computation, the RNG-state load
// and the store into the field are on lines not visible in this listing.
79 __global__
void computeGenGauss(GaugeGaussArg<Gauge>
arg){
// Shift the coordinates by the per-dimension border width, presumably mapping
// interior coordinates onto the bordered lattice of extent arg.E — TODO confirm.
90 for (
int dr=0; dr<4; ++dr)
x[dr] +=
arg.border[dr];
// Zero displacement vector (presumably for a shifted link-index lookup — TODO confirm).
92 int dx[4] = {0, 0, 0, 0};
// One random matrix U is generated per direction mu = 0..3.
93 for(
int mu = 0;
mu < 4;
mu++){
96 Link U = genGaussSU3<Float>(localState);
105 template<
typename Float,
typename Gauge>
// Tunable wrapper that owns the kernel arguments (the class header and the
// enclosing method signatures are not visible in this listing).
107 GaugeGaussArg<Gauge>
arg;
// CPU branch (its condition is not visible): aborts via errorQuda because only
// the GPU path is implemented.
126 errorQuda(
"Randomize GaugeFields on CPU not supported yet\n");
// Tune key: a volume string "X0xX1xX2xX3" built from the interior extents
// arg.X, plus an aux string with the thread count and sizeof(Float) in bytes.
131 std::stringstream vol,
aux;
132 vol <<
arg.X[0] <<
"x" <<
arg.X[1] <<
"x" <<
arg.X[2] <<
"x" <<
arg.X[3];
133 aux <<
"threads=" <<
arg.threads <<
",prec=" <<
sizeof(Float);
134 return TuneKey(vol.str().c_str(),
typeid(*this).name(), aux.str().c_str());
// No flop or byte accounting is reported (both return 0).
// NOTE(review): bytes() could count the gauge-field writes so that tuning
// output reports achieved bandwidth.
137 long long flops()
const {
return 0; }
138 long long bytes()
const {
return 0; }
// Snapshot the RNG state ...
143 arg.rngstate.backup();
// ... and restore it later, presumably around autotuning so that the repeated
// trial launches do not advance the random stream — TODO confirm.
147 arg.rngstate.restore();
152 template<
typename Float,
typename Gauge>
// Host launcher for one precision (Float) and accessor type (Gauge): builds the
// kernel arguments from the accessor, the GaugeField and the RNG state.
// The rest of the body is not visible in this listing.
154 GaugeGaussArg<Gauge>
arg(dataDs, data, rngstate);
162 template<
typename Float>
// Precision-templated dispatch: forwards to genGauss<Float> with an accessor
// constructed as Gauge(dataDs). The three calls presumably sit in branches
// that select different accessor types Gauge (e.g. by reconstruct type); the
// branch conditions and typedefs are not visible here — TODO confirm.
167 genGauss<Float>(Gauge(dataDs), dataDs, rngstate);
170 genGauss<Float>(Gauge(dataDs), dataDs, rngstate);
173 genGauss<Float>(Gauge(dataDs), dataDs, rngstate);
// Public entry point (body fragment). Compiled only when GPU_GAUGE_TOOLS is
// defined; the #else fallback is not visible in this listing.
// Half precision is rejected via errorQuda. The float and double paths
// dispatch to gaugeGauss<float> and gaugeGauss<double>; the precision test
// itself is on lines not visible here.
184 #ifdef GPU_GAUGE_TOOLS 187 errorQuda(
"Half precision not supported\n");
194 gaugeGauss<float>(dataDs, rngstate);
196 gaugeGauss<double>(dataDs, rngstate);
struct curandStateMRG32k3a cuRNGState
static __device__ __host__ int linkIndexShift(const I x[], const J dx[], const K X[4])
QudaVerbosity getVerbosity()
__host__ __device__ ValueType sqrt(ValueType x)
virtual TuneKey tuneKey() const =0
void gaugeGauss(GaugeField &dataDs, RNG &rngstate)
__host__ __device__ ValueType sin(ValueType x)
virtual long long bytes() const
__device__ __host__ void genGauss(InOrder &inOrder, cuRNGState &localState, int x, int s, int c)
Class declaration to initialize and hold CURAND RNG states.
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
Main header file for host and device accessors to GaugeFields.
cudaError_t qudaDeviceSynchronize()
Wrapper around cudaDeviceSynchronize (runtime API) or cuCtxSynchronize (driver API).
QudaFieldLocation Location() const
__host__ __device__ ValueType log(ValueType x)
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
virtual unsigned int minThreads() const
__host__ __device__ ValueType cos(ValueType x)
QudaReconstructType Reconstruct() const
QudaGaugeFieldOrder Order() const
QudaTune getTuning()
Query whether autotuning is enabled or not. Default is enabled but can be overridden by setting QUDA_...
QudaPrecision Precision() const
virtual bool tuneGridDim() const
virtual long long flops() const =0
virtual void apply(const cudaStream_t &stream)=0
static __device__ __host__ void getCoords(int x[], int cb_index, const I X[], int parity)