QUDA  0.9.0
gauge_random.cu
Go to the documentation of this file.
1 #include <quda_internal.h>
2 #include <quda_matrix.h>
3 #include <tune_quda.h>
4 #include <gauge_field.h>
5 #include <gauge_field_order.h>
6 #include <launch_kernel.cuh>
7 #include <atomic.cuh>
8 #include <cub_helper.cuh>
9 #include <index_helper.cuh>
10 #include <random_quda.h>
11 
12 namespace quda {
13 
14 #ifdef GPU_GAUGE_TOOLS
15 
/**
   @brief Kernel argument bundle used when filling a gauge field with
   random links.
   @tparam Gauge Accessor type used to write links into the field
*/
template <typename Gauge>
struct GaugeGaussArg {
  int threads;   // number of active threads required (half of the true volume)
  int E[4];      // extended grid dimensions
  int X[4];      // true grid dimensions
  int border[4]; // per-dimension border width, read from data.R()
  Gauge dataDs;  // accessor through which the links are written
  RNG rngstate;  // handle to the random number generator states

  /**
     @param dataDs   Accessor for the field that will be written
     @param data     Field whose dimensions (X()) and border (R()) are read
     @param rngstate Random number generator state holder
  */
  GaugeGaussArg(const Gauge &dataDs, const GaugeField &data, RNG &rngstate)
    : dataDs(dataDs), rngstate(rngstate)
  {
    for (int dir=0; dir<4; ++dir) {
      border[dir] = data.R()[dir];
      E[dir] = data.X()[dir];
      // the true extent is the extended extent minus the border on both sides
      X[dir] = data.X()[dir] - border[dir]*2;
    }
    // one thread per site of a single parity
    threads = X[0]*X[1]*X[2]*X[3]/2;
  }
};
38 
39 
/**
   @brief Build a random 3x3 complex matrix from eight uniform samples.

   Four (angle, radius) pairs are drawn and turned into eight real
   numbers with a polar (Box-Muller-like) transform; note that the
   radius used here is sqrt(-log(u)), not sqrt(-2 log(u)).  The eight
   numbers are then placed so that the result is Hermitian
   (ret(j,i) == conj(ret(i,j))) and its diagonal sums to zero.

   @tparam Float Real type of the matrix elements
   @param[in,out] localState Generator state; advanced by eight draws
   @return The generated matrix
*/
template<typename Float>
__device__ __host__ Matrix<complex<Float>,3> genGaussSU3(cuRNGState &localState){
  Matrix<complex<Float>,3> ret;

  Float temp1[4], temp2[4];

  for (int i=0; i<4; ++i) {
    // Draw order (angle sample first, then radius sample, pair by pair)
    // is kept so the sequence consumed from the generator is unchanged.
    const Float rand1 = Random<Float>(localState);
    const Float rand2 = Random<Float>(localState);

    const Float phi = 2.0*M_PI*rand1;
    // NOTE(review): assumes Random<Float> never returns exactly 0, otherwise
    // log() diverges -- confirm against the Random<> implementation.
    const Float radius = sqrt( -log(rand2) );

    temp1[i] = radius*cos(phi);
    temp2[i] = radius*sin(phi);
  }

  // Diagonal: the three real entries sum to zero by construction.
  ret(0,0) = complex<Float>( temp1[2] + 1./sqrt(3.0)*temp2[3], 0.0);
  ret(1,1) = complex<Float>( -temp1[2] + 1./sqrt(3.0) * temp2[3], 0.0 );
  ret(2,2) = complex<Float>( - 2./sqrt(3.0) * temp2[3], 0.0 );

  // Off-diagonal: each lower-triangle entry is the conjugate of its
  // upper-triangle partner.
  ret(0,1) = complex<Float>( temp1[0], -temp1[1]);
  ret(1,0) = complex<Float>( temp1[0], temp1[1] );

  ret(0,2) = complex<Float>( temp1[3], -temp2[0]);
  ret(2,0) = complex<Float>( temp1[3], temp2[0] );

  ret(1,2) = complex<Float>( temp2[1], -temp2[2] );
  ret(2,1) = complex<Float>( temp2[1], temp2[2] );

  return ret;
}
76 
77 
/**
   @brief Kernel that writes one random matrix per direction on every
   site handled by the launch.

   Launch layout: the x-dimension enumerates the arg.threads sites of a
   single parity, the y-dimension enumerates the parity.  Each thread
   uses the generator state stored at index idx + parity*arg.threads.

   @tparam Float Real type of the generated matrix elements
   @tparam Gauge Accessor type used to write the links
   @param arg Kernel arguments (dimensions, border, accessor, RNG states)
*/
template<typename Float, typename Gauge>
__global__ void computeGenGauss(GaugeGaussArg<Gauge> arg){
  typedef Matrix<complex<Float>,3> Link;

  int idx = threadIdx.x + blockIdx.x*blockDim.x;
  int parity = threadIdx.y + blockIdx.y*blockDim.y;

  // Defensive guard: only two parities exist, and the x-grid may overshoot.
  if (idx >= arg.threads || parity >= 2) return;

  int x[4];
  getCoords(x, idx, arg.X, parity);
  for (int dr=0; dr<4; ++dr) x[dr] += arg.border[dr]; // extended grid coordinates

  // The coordinates are now expressed on the extended grid, so the linear
  // index must be computed with the extended dimensions E (E == X when
  // the border is zero).
  const int dx[4] = {0, 0, 0, 0};
  const int site = linkIndexShift(x, dx, arg.E);

  // Load the generator state once, draw all four links, then store it back.
  const int rng_idx = idx + parity*arg.threads;
  cuRNGState localState = arg.rngstate.State()[rng_idx];

  for (int mu = 0; mu < 4; mu++) {
    Link U = genGaussSU3<Float>(localState);
    arg.dataDs(mu, site, parity) = U;
  }

  arg.rngstate.State()[rng_idx] = localState;
}
104 
/**
   @brief Tunable wrapper that launches computeGenGauss.

   The y-dimension of the launch has length 2 (passed to TunableVectorY)
   and is used by the kernel as the parity index.

   @tparam Float Real type of the generated matrix elements
   @tparam Gauge Accessor type used to write the links
*/
template<typename Float, typename Gauge>
class GaugeGauss : TunableVectorY {
  GaugeGaussArg<Gauge> arg;
  GaugeField &gf;

private:
  unsigned int minThreads() const { return arg.threads; }
  bool tuneGridDim() const { return false; } // Don't tune the grid dimensions.

public:
  /**
     @param arg Kernel arguments (copied into this object)
     @param gf  Field being randomized; only its location is queried here
  */
  GaugeGauss(GaugeGaussArg<Gauge> &arg, GaugeField &gf)
    : TunableVectorY(2), arg(arg), gf(gf) {}
  ~GaugeGauss () { }

  /**
     @brief Launch the kernel for device-resident fields; report an
     error for any other location.
     @param stream Stream on which the kernel is launched
  */
  void apply(const cudaStream_t &stream){
    if (gf.Location() == QUDA_CUDA_FIELD_LOCATION) {
      TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());

      // honor the caller-supplied stream rather than the default one
      computeGenGauss<Float><<<tp.grid,tp.block,tp.shared_bytes,stream>>>(arg);
      qudaDeviceSynchronize();
    } else {
      errorQuda("Randomize GaugeFields on CPU not supported yet\n");
    }
  }

  /** @brief Key built from the true dimensions, thread count and precision size. */
  TuneKey tuneKey() const {
    std::stringstream vol, aux;
    vol << arg.X[0] << "x" << arg.X[1] << "x" << arg.X[2] << "x" << arg.X[3];
    aux << "threads=" << arg.threads << ",prec=" << sizeof(Float);
    return TuneKey(vol.str().c_str(), typeid(*this).name(), aux.str().c_str());
  }

  long long flops() const { return 0; }
  long long bytes() const { return 0; }

  // Only the generator state is saved and restored around tuning; the
  // field itself is not backed up.
  void preTune(){
    //gf.backup();
    arg.rngstate.backup();
  }
  void postTune(){
    //gf.restore();
    arg.rngstate.restore();
  }

};
151 
/**
   @brief Assemble the kernel arguments for the given accessor and run
   the randomization once.
   @param dataDs   Accessor for the field to be written
   @param data     Field being randomized
   @param rngstate Random number generator state holder
*/
template<typename Float, typename Gauge>
void genGauss(const Gauge dataDs, GaugeField& data, RNG &rngstate) {
  GaugeGaussArg<Gauge> kernelArg(dataDs, data, rngstate);
  GaugeGauss<Float,Gauge> launcher(kernelArg, data);

  // a stream argument of 0 is forwarded to apply()
  launcher.apply(0);
}
159 
160 
161 
/**
   @brief Select the accessor type matching the field's reconstruction
   type and forward to genGauss.
   @tparam Float Real type used for the generated links
   @param dataDs   Field to be randomized
   @param rngstate Random number generator state holder
*/
template<typename Float>
void gaugeGauss(GaugeField &dataDs, RNG &rngstate) {

  switch (dataDs.Reconstruct()) {
  case QUDA_RECONSTRUCT_NO: {
    typedef typename gauge_mapper<Float,QUDA_RECONSTRUCT_NO>::type Accessor;
    genGauss<Float>(Accessor(dataDs), dataDs, rngstate);
    break;
  }
  case QUDA_RECONSTRUCT_12: {
    typedef typename gauge_mapper<Float,QUDA_RECONSTRUCT_12>::type Accessor;
    genGauss<Float>(Accessor(dataDs), dataDs, rngstate);
    break;
  }
  case QUDA_RECONSTRUCT_8: {
    typedef typename gauge_mapper<Float,QUDA_RECONSTRUCT_8>::type Accessor;
    genGauss<Float>(Accessor(dataDs), dataDs, rngstate);
    break;
  }
  default:
    // any other reconstruction type is reported as unsupported
    errorQuda("Reconstruction type %d of origin gauge field not supported", dataDs.Reconstruct());
  }

}
179 
180 #endif
181 
/**
   @brief Public entry point: validate the field, then dispatch on its
   precision to the templated implementation.
   @param dataDs   Field to be randomized
   @param rngstate Random number generator state holder
*/
void gaugeGauss(GaugeField &dataDs, RNG &rngstate) {

#ifdef GPU_GAUGE_TOOLS

  // half precision is rejected before anything else is examined
  if (dataDs.Precision() == QUDA_HALF_PRECISION) {
    errorQuda("Half precision not supported\n");
  }

  // only native field orders are handled
  if (!dataDs.isNative()) {
    errorQuda("Order %d with %d reconstruct not supported", dataDs.Order(), dataDs.Reconstruct());
  }

  switch (dataDs.Precision()) {
  case QUDA_SINGLE_PRECISION:
    gaugeGauss<float>(dataDs, rngstate);
    break;
  case QUDA_DOUBLE_PRECISION:
    gaugeGauss<double>(dataDs, rngstate);
    break;
  default:
    errorQuda("Precision %d not supported", dataDs.Precision());
  }

#else
  errorQuda("Gauge tools are not build");
#endif
}
205 
206 }
dim3 dim3 blockDim
double mu
Definition: test_util.cpp:1643
struct curandStateMRG32k3a cuRNGState
Definition: random_quda.h:17
cudaStream_t stream
static __device__ __host__ int linkIndexShift(const I x[], const J dx[], const K X[4])
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:20
#define errorQuda(...)
Definition: util_quda.h:90
__host__ __device__ ValueType sqrt(ValueType x)
Definition: complex_quda.h:105
static int R[4]
int E[4]
Definition: test_util.cpp:36
virtual TuneKey tuneKey() const =0
void gaugeGauss(GaugeField &dataDs, RNG &rngstate)
const int * R() const
__host__ __device__ ValueType sin(ValueType x)
Definition: complex_quda.h:40
virtual long long bytes() const
Definition: tune_quda.h:64
__device__ __host__ void genGauss(InOrder &inOrder, cuRNGState &localState, int x, int s, int c)
Definition: spinor_gauss.cu:23
Class declaration to initialize and hold CURAND RNG states.
Definition: random_quda.h:23
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
Definition: tune.cpp:603
Main header file for host and device accessors to GaugeFields.
cudaError_t qudaDeviceSynchronize()
Wrapper around cudaDeviceSynchronize or cuDeviceSynchronize.
QudaFieldLocation Location() const
__host__ __device__ ValueType log(ValueType x)
Definition: complex_quda.h:90
virtual void preTune()
Definition: tune_quda.h:204
virtual void postTune()
Definition: tune_quda.h:205
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
Definition: complex_quda.h:880
virtual unsigned int minThreads() const
Definition: tune_quda.h:73
__host__ __device__ ValueType cos(ValueType x)
Definition: complex_quda.h:35
QudaReconstructType Reconstruct() const
Definition: gauge_field.h:203
QudaGaugeFieldOrder Order() const
Definition: gauge_field.h:204
QudaTune getTuning()
Query whether autotuning is enabled or not. Default is enabled but can be overridden by setting QUDA_...
Definition: util_quda.cpp:51
QudaPrecision Precision() const
bool isNative() const
virtual bool tuneGridDim() const
Definition: tune_quda.h:74
QudaParity parity
Definition: covdev_test.cpp:53
char aux[TuneKey::aux_n]
Definition: tune_quda.h:189
virtual long long flops() const =0
virtual void apply(const cudaStream_t &stream)=0
const int * X() const
static __device__ __host__ void getCoords(int x[], int cb_index, const I X[], int parity)