QUDA  v0.7.0
A library for QCD on GPUs
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
clover_invert.cu
Go to the documentation of this file.
1 #include <tune_quda.h>
2 #include <clover_field_order.h>
3 #include <complex_quda.h>
4 #include <cub/cub.cuh>
5 #include <launch_kernel.cuh>
6 #include <face_quda.h>
7 
8 namespace quda {
9 
10 #ifdef GPU_CLOVER_DIRAC
11 
/**
 * Argument bundle handed (by value) to the clover inversion routines.
 *
 * clover          - accessor for the input clover field (read only)
 * inverse         - accessor for the field that receives the inverse
 * computeTraceLog - whether the per-parity trace-log is accumulated
 * trlogA_h        - host pointer to the two-element (one entry per
 *                   parity) trace-log accumulator
 * trlogA_d        - device-side alias of trlogA_h, obtained through
 *                   cudaHostGetDevicePointer; stays null when trlogA_h
 *                   is null
 * twist / mu2     - twisted-mass flag and mu^2 read from the input
 *                   accessor
 */
template <typename Clover>
struct CloverInvertArg {
  const Clover clover;
  Clover inverse;
  bool computeTraceLog;
  double * const trlogA_h;
  double *trlogA_d;
  // extra attributes for twisted mass clover
  bool twist;
  double mu2;

  // Initializers are listed in declaration order, which is the order in
  // which C++ actually runs them.
  CloverInvertArg(Clover &inverse, const Clover &clover, bool computeTraceLog=0, double* const trlogA=0) :
    clover(clover), inverse(inverse), computeTraceLog(computeTraceLog), trlogA_h(trlogA), trlogA_d(0),
    twist(clover.Twisted()), mu2(clover.Mu2()) {
    // Set the matching device pointer.  A null host pointer (the default
    // argument) is skipped: querying it would only raise a CUDA error and
    // leave trlogA_d holding an indeterminate value.
    if (trlogA_h) cudaHostGetDevicePointer(&trlogA_d, trlogA_h, 0);
  }
};
27 
/**
 * Atomically add val to the double at addr and return the value that was
 * stored there before the addition.
 *
 * Devices of compute capability 6.0 and above execute this in hardware, so
 * the call is forwarded to the built-in overload there.  Older devices fall
 * back to a compare-and-swap loop on the 64-bit pattern of the double.
 */
static __inline__ __device__ double atomicAdd(double *addr, double val)
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600)
  // requires SM60+: native double-precision atomic add
  return ::atomicAdd(addr, val);
#else
  double old=*addr, assumed;

  do {
    assumed = old;
    old = __longlong_as_double( atomicCAS((unsigned long long int*)addr,
                                          __double_as_longlong(assumed),
                                          __double_as_longlong(val+assumed)));
    // compare the bit patterns, not the doubles, so a NaN cannot spin forever
  } while( __double_as_longlong(assumed)!=__double_as_longlong(old) );

  return old;
#endif
}
41 
/**
 * Invert the clover matrix of a single checkerboard site and return that
 * site's contribution to the trace-log.
 *
 * The 72 reals loaded for the site are treated as two independent blocks
 * of 36 reals (6 real diagonal entries followed by 15 complex
 * off-diagonal entries).  For each block the routine optionally replaces
 * the block T by T^2 + mu2 (when arg.twist is set), performs an in-place
 * Cholesky factorisation, accumulates 2*log of the factor's diagonal into
 * the trace-log, builds the inverse column by column with forward and
 * backward substitution, and writes the result back into the same 36
 * reals.  The inverted data is stored through arg.inverse.
 *
 * blockSize is not referenced in the body.
 *
 * @param arg    accessors and twisted-mass parameters
 * @param x      site index forwarded to load() and save()
 * @param parity parity forwarded to load() and save()
 * @return sum over both blocks of 2*log(diag[j]) of the Cholesky factor
 */
46  template <int blockSize, typename Float, typename Clover>
47  __device__ __host__ double cloverInvertCompute(CloverInvertArg<Clover> arg, int x, int parity) {
48 
 // two blocks of 36 reals: 6 diagonal reals + 15 complex off-diagonals each
49  Float A[72];
50  double trlogA = 0.0;
51 
52  // load the clover term into memory
53  arg.clover.load(A, x, parity);
54 
55  for (int ch=0; ch<2; ch++) {
56 
57  Float diag[6];
58  Float tmp[6]; // temporary storage
59  complex<Float> tri[15];
60 
61  // hack into the right order as MILC just to copy algorithm directly
62  // FIXME use native ordering in the Cholesky
63  // factor of two is inherent to QUDA clover storage
64  for (int i=0; i<6; i++) diag[i] = 2.0*A[ch*36+i];
65 
 // idtab maps the stored position i of an off-diagonal entry to its slot in tri[]
66  const int idtab[15]={0,1,3,6,10,2,4,7,11,5,8,12,9,13,14};
67  for (int i=0; i<15; i++) tri[idtab[i]] = complex<Float>(2.0*A[ch*36+6+2*i], 2.0*A[ch*36+6+2*i+1]);
68 
69 //Compute (T^2 + mu2) first, then invert (not optimized!):
70  if(arg.twist)
71  {
72  complex<Float> aux[15];//hmmm, better to reuse A-regs...
73  //another solution just to define (but compiler may not be happy with this, swapping everything in
74  //the global buffer):
75  //complex<Float>* aux = (complex<Float>*)&A[ch*36];
76  //compute off-diagonal terms:
77 //
78  aux[ 0] = tri[0]*diag[0]+diag[1]*tri[0]+conj(tri[2])*tri[1]+conj(tri[4])*tri[3]+conj(tri[7])*tri[6]+conj(tri[11])*tri[10];
79 //
80  aux[ 1] = tri[1]*diag[0]+diag[2]*tri[1]+tri[2]*tri[0]+conj(tri[5])*tri[3]+conj(tri[8])*tri[6]+conj(tri[12])*tri[10];
81 
82  aux[ 2] = tri[2]*diag[1]+diag[2]*tri[2]+tri[1]*conj(tri[0])+conj(tri[5])*tri[4]+conj(tri[8])*tri[7]+conj(tri[12])*tri[11];
83 //
84  aux[ 3] = tri[3]*diag[0]+diag[3]*tri[3]+tri[4]*tri[0]+tri[5]*tri[1]+conj(tri[9])*tri[6]+conj(tri[13])*tri[10];
85 
86  aux[ 4] = tri[4]*diag[1]+diag[3]*tri[4]+tri[3]*conj(tri[0])+tri[5]*tri[2]+conj(tri[9])*tri[7]+conj(tri[13])*tri[11];
87 
88  aux[ 5] = tri[5]*diag[2]+diag[3]*tri[5]+tri[3]*conj(tri[1])+tri[4]*conj(tri[2])+conj(tri[9])*tri[8]+conj(tri[13])*tri[12];
89 //
90  aux[ 6] = tri[6]*diag[0]+diag[4]*tri[6]+tri[7]*tri[0]+tri[8]*tri[1]+tri[9]*tri[3]+conj(tri[14])*tri[10];
91 
92  aux[ 7] = tri[7]*diag[1]+diag[4]*tri[7]+tri[6]*conj(tri[0])+tri[8]*tri[2]+tri[9]*tri[4]+conj(tri[14])*tri[11];
93 
94  aux[ 8] = tri[8]*diag[2]+diag[4]*tri[8]+tri[6]*conj(tri[1])+tri[7]*conj(tri[2])+tri[9]*tri[5]+conj(tri[14])*tri[12];
95 
96  aux[ 9] = tri[9]*diag[3]+diag[4]*tri[9]+tri[6]*conj(tri[3])+tri[7]*conj(tri[4])+tri[8]*conj(tri[5])+conj(tri[14])*tri[13];
97 //
98  aux[10] = tri[10]*diag[0]+diag[5]*tri[10]+tri[11]*tri[0]+tri[12]*tri[1]+tri[13]*tri[3]+tri[14]*tri[6];
99 
100  aux[11] = tri[11]*diag[1]+diag[5]*tri[11]+tri[10]*conj(tri[0])+tri[12]*tri[2]+tri[13]*tri[4]+tri[14]*tri[7];
101 
102  aux[12] = tri[12]*diag[2]+diag[5]*tri[12]+tri[10]*conj(tri[1])+tri[11]*conj(tri[2])+tri[13]*tri[5]+tri[14]*tri[8];
103 
104  aux[13] = tri[13]*diag[3]+diag[5]*tri[13]+tri[10]*conj(tri[3])+tri[11]*conj(tri[4])+tri[12]*conj(tri[5])+tri[14]*tri[9];
105 
106  aux[14] = tri[14]*diag[4]+diag[5]*tri[14]+tri[10]*conj(tri[6])+tri[11]*conj(tri[7])+tri[12]*conj(tri[8])+tri[13]*conj(tri[9]);
107 
108 
109  //update diagonal elements:
 // (uses the still-unmodified tri[]; the squared off-diagonals wait in aux[])
110  diag[0] = (Float)arg.mu2+diag[0]*diag[0]+norm(tri[ 0])+norm(tri[ 1])+norm(tri[ 3])+norm(tri[ 6])+norm(tri[10]);
111  diag[1] = (Float)arg.mu2+diag[1]*diag[1]+norm(tri[ 0])+norm(tri[ 2])+norm(tri[ 4])+norm(tri[ 7])+norm(tri[11]);
112  diag[2] = (Float)arg.mu2+diag[2]*diag[2]+norm(tri[ 1])+norm(tri[ 2])+norm(tri[ 5])+norm(tri[ 8])+norm(tri[12]);
113  diag[3] = (Float)arg.mu2+diag[3]*diag[3]+norm(tri[ 3])+norm(tri[ 4])+norm(tri[ 5])+norm(tri[ 9])+norm(tri[13]);
114  diag[4] = (Float)arg.mu2+diag[4]*diag[4]+norm(tri[ 6])+norm(tri[ 7])+norm(tri[ 8])+norm(tri[ 9])+norm(tri[14]);
115  diag[5] = (Float)arg.mu2+diag[5]*diag[5]+norm(tri[10])+norm(tri[11])+norm(tri[12])+norm(tri[13])+norm(tri[14]);
116 
117  //update off-diagonal elements:
118  for(int i = 0; i < 15; i++) tri[i] = aux[i];
119  }
120 //
 // In-place Cholesky factorisation.  tri[k*(k-1)/2+j] addresses element
 // (k,j) with k>j; tmp[j] keeps the reciprocal of the factor's j-th
 // diagonal entry for the substitution steps below.
121  for (int j=0; j<6; j++) {
122  diag[j] = sqrt(diag[j]);
123  tmp[j] = 1.0 / diag[j];
124 
 // scale column j below the diagonal
125  for (int k=j+1; k<6; k++) {
126  int kj = k*(k-1)/2+j;
127  tri[kj] *= tmp[j];
128  }
129 
 // subtract column j's contribution from the remaining lower-right part
130  for(int k=j+1;k<6;k++){
131  int kj=k*(k-1)/2+j;
132  diag[k] -= (tri[kj] * conj(tri[kj])).real();
133  for(int l=k+1;l<6;l++){
134  int lj=l*(l-1)/2+j;
135  int lk=l*(l-1)/2+k;
136  tri[lk] -= tri[lj] * conj(tri[kj]);
137  }
138  }
139  }
140 
141  /* Accumulate trlogA */
 // diag[] holds the factor's diagonal here, hence the factor of two
142  for (int j=0;j<6;j++) trlogA += (double)2.0*log((double)(diag[j]));
143 
144  /* Now use forward and backward substitution to construct inverse */
145  complex<Float> v1[6];
146  for (int k=0;k<6;k++) {
147  for(int l=0;l<k;l++) v1[l] = complex<Float>(0.0, 0.0);
148 
149  /* Forward substitute */
150  v1[k] = complex<Float>(tmp[k], 0.0);
151  for(int l=k+1;l<6;l++){
152  complex<Float> sum = complex<Float>(0.0, 0.0);
153  for(int j=k;j<l;j++){
154  int lj=l*(l-1)/2+j;
155  sum -= tri[lj] * v1[j];
156  }
157  v1[l] = sum * tmp[l];
158  }
159 
160  /* Backward substitute */
161  v1[5] = v1[5] * tmp[5];
162  for(int l=4;l>=k;l--){
163  complex<Float> sum = v1[l];
164  for(int j=l+1;j<6;j++){
165  int jl=j*(j-1)/2+l;
166  sum -= conj(tri[jl]) * v1[j];
167  }
168  v1[l] = sum * tmp[l];
169  }
170 
171  /* Overwrite column k */
172  diag[k] = v1[k].real();
173  for(int l=k+1;l<6;l++){
174  int lk=l*(l-1)/2+k;
175  tri[lk] = v1[l];
176  }
177  }
178 
 // write the block back in the stored ordering, undoing the factor of two applied on load
179  for (int i=0; i<6; i++) A[ch*36+i] = 0.5 * diag[i];
180  for (int i=0; i<15; i++) {
181  A[ch*36+6+2*i] = 0.5*tri[idtab[i]].real(); A[ch*36+6+2*i+1] = 0.5*tri[idtab[i]].imag();
182  }
183  }
184 
185  // save the inverted matrix
186  arg.inverse.save(A, x, parity);
187 
188  return trlogA;
189  }
190 
/**
 * Host-side reference path: invert every site of both parities in a
 * serial loop and, when requested, add each site's trace-log contribution
 * to the per-parity host accumulator arg.trlogA_h.
 */
template <int blockSize, typename Float, typename Clover>
void cloverInvert(CloverInvertArg<Clover> arg) {
  for (int parity=0; parity<2; ++parity) {
    for (int site=0; site<arg.clover.volumeCB; ++site) {
      // the accumulation below is not thread safe; revisit if this loop is
      // ever parallelised on the cpu
      const double siteTrLog = cloverInvertCompute<blockSize, Float>(arg, site, parity);
      if (!arg.computeTraceLog) continue;
      arg.trlogA_h[parity] += siteTrLog;
    }
  }
}
201 
/**
 * Device kernel: one thread per checkerboard site.
 *
 * Launch layout: blockIdx.x/threadIdx.x enumerate sites, blockIdx.y is the
 * parity.  blockSize is the compile-time thread count handed to
 * cub::BlockReduce (presumably it has to equal blockDim.x -- confirm
 * against the LAUNCH_KERNEL macro).
 *
 * When arg.computeTraceLog is set, the per-site trace-log values are summed
 * across the block and thread 0 adds the block total to arg.trlogA_d[parity].
 */
template <int blockSize, typename Float, typename Clover>
__global__ void cloverInvertKernel(CloverInvertArg<Clover> arg) {
  const int parity = blockIdx.y;
  const int site = blockIdx.x*blockDim.x + threadIdx.x;

  // Threads past the end of the volume must not return early: they still
  // take part in the block-wide reduction, contributing zero.
  double siteTrLog = 0.0;
  if (site < arg.clover.volumeCB) {
    siteTrLog = cloverInvertCompute<blockSize, Float>(arg, site, parity);
  }

  // uniform across the block, so every thread leaves together
  if (!arg.computeTraceLog) return;

  typedef cub::BlockReduce<double, blockSize> Reducer;
  __shared__ typename Reducer::TempStorage scratch;
  const double blockTotal = Reducer(scratch).Sum(siteTrLog);
  if (threadIdx.x == 0) atomicAdd(arg.trlogA_d+parity, blockTotal);
}
218 
/**
 * Tunable driver that runs the clover inversion either on the device
 * (autotuned kernel launch, one grid row per parity) or on the host
 * (serial loop), depending on the requested field location.
 */
template <typename Float, typename Clover>
class CloverInvert : Tunable {
  CloverInvertArg<Clover> arg;
  const CloverField &meta; // used for meta data only
  const QudaFieldLocation location; // where the inversion is executed

private:
  unsigned int sharedBytesPerThread() const { return 0; }
  unsigned int sharedBytesPerBlock(const TuneParam &param) const { return 0; }

  bool tuneSharedBytes() const { return false; } // Don't tune the shared memory
  bool tuneGridDim() const { return false; } // Don't tune the grid dimensions.
  unsigned int minThreads() const { return arg.clover.volumeCB; }

public:
  /**
   * @param arg      accessors and options; copied into the driver
   * @param meta     field used for its volume string in the tune key
   * @param location selects the device kernel or the host loop in apply()
   */
  CloverInvert(CloverInvertArg<Clover> &arg, const CloverField &meta, QudaFieldLocation location)
    : arg(arg), meta(meta), location(location) {
    writeAuxString("stride=%d,prec=%lu",arg.clover.stride,sizeof(Float));
  }
  virtual ~CloverInvert() { ; }

  /**
   * Run the inversion.  The per-parity trace-log accumulator is zeroed
   * first; when the trace-log is requested the device is synchronised
   * and the two values are passed to reduceDoubleArray.
   */
  void apply(const cudaStream_t &stream) {
    TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());

    // the accumulator pointer defaults to null in CloverInvertArg
    if (arg.trlogA_h) { arg.trlogA_h[0] = 0.0; arg.trlogA_h[1] = 0.0; }

    if (location == QUDA_CUDA_FIELD_LOCATION) {
      tp.grid.y = 2; // for parity
      LAUNCH_KERNEL(cloverInvertKernel, tp, stream, arg, Float, Clover);
    } else {
      cloverInvert<1, Float, Clover>(arg);
    }

    if (arg.computeTraceLog) {
      // the kernel writes the accumulator asynchronously; wait before reading
      cudaDeviceSynchronize();
      reduceDoubleArray(arg.trlogA_h, 2);
    }
  }

  TuneKey tuneKey() const {
    return TuneKey(meta.VolString(), typeid(*this).name(), aux);
  }

  std::string paramString(const TuneParam &param) const { // Don't bother printing the grid dim.
    std::stringstream ps;
    ps << "block=(" << param.block.x << "," << param.block.y << "," << param.block.z << "), ";
    ps << "shared=" << param.shared_bytes;
    return ps.str();
  }

  long long flops() const { return 0; }
  long long bytes() const { return 2*arg.clover.volumeCB*(arg.inverse.Bytes() + arg.clover.Bytes()); }
};
269 
/**
 * Bundle the accessors into a CloverInvertArg, run the CloverInvert driver
 * on the default stream and block until the device has finished.
 */
template <typename Float, typename Clover>
void cloverInvert(Clover inverse, const Clover clover, bool computeTraceLog,
                  double* const trlog, const CloverField &meta, QudaFieldLocation location) {
  CloverInvertArg<Clover> kernelArg(inverse, clover, computeTraceLog, trlog);

  CloverInvert<Float,Clover> worker(kernelArg, meta, location);
  worker.apply(0); // stream 0

  cudaDeviceSynchronize();
}
278 
/**
 * Pick the accessor type that matches the field's storage order and
 * forward to the accessor-level cloverInvert.  The accessor built with
 * flag 1 is passed as the destination (inverse), the one built with flag 0
 * as the source.  Any other storage order is reported through errorQuda.
 */
template <typename Float>
void cloverInvert(const CloverField &clover, bool computeTraceLog, QudaFieldLocation location) {
  switch (clover.Order()) {
  case QUDA_FLOAT2_CLOVER_ORDER: {
    typedef FloatNOrder<Float,72,2> Accessor;
    cloverInvert<Float>(Accessor(clover, 1), Accessor(clover, 0),
                        computeTraceLog, clover.TrLog(), clover, location);
    break;
  }
  case QUDA_FLOAT4_CLOVER_ORDER: {
    typedef FloatNOrder<Float,72,4> Accessor;
    cloverInvert<Float>(Accessor(clover, 1), Accessor(clover, 0),
                        computeTraceLog, clover.TrLog(), clover, location);
    break;
  }
  default:
    errorQuda("Clover field %d order not supported", clover.Order());
  }
}
294 
295 #endif
296 
// Public entry point: everything below it exists only to instantiate the
// templates required for the supported precisions.
//
// Dispatches on the field precision (double or single).  Without
// GPU_CLOVER_DIRAC the call only reports that clover support is missing.
void cloverInvert(CloverField &clover, bool computeTraceLog, QudaFieldLocation location) {

#ifdef GPU_CLOVER_DIRAC
  const QudaPrecision precision = clover.Precision();

  if (precision == QUDA_HALF_PRECISION && clover.Order() > 4)
    errorQuda("Half precision not supported for order %d", clover.Order());

  switch (precision) {
  case QUDA_DOUBLE_PRECISION:
    cloverInvert<double>(clover, computeTraceLog, location);
    break;
  case QUDA_SINGLE_PRECISION:
    cloverInvert<float>(clover, computeTraceLog, location);
    break;
  default:
    errorQuda("Precision %d not supported", clover.Precision());
  }
#else
  errorQuda("Clover has not been built");
#endif
}
315 
316 } // namespace quda
__host__ __device__ ValueType norm(const complex< ValueType > &z)
Returns the magnitude of z squared.
Definition: complex_quda.h:859
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:20
#define errorQuda(...)
Definition: util_quda.h:73
__host__ __device__ ValueType sqrt(ValueType x)
Definition: complex_quda.h:105
cudaStream_t * stream
::std::string string
Definition: gtest.h:1979
QudaGaugeParam param
Definition: pack_test.cpp:17
QudaPrecision Precision() const
void cloverInvert(CloverField &clover, bool computeTraceLog, QudaFieldLocation location)
cudaColorSpinorField * tmp
const QudaFieldLocation location
Definition: pack_test.cpp:46
FloatingPoint< float > Float
Definition: gtest.h:7350
QudaCloverFieldOrder Order() const
Definition: clover_field.h:66
void reduceDoubleArray(double *, const int len)
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
Definition: tune.cpp:271
int x[4]
__host__ __device__ ValueType log(ValueType x)
Definition: complex_quda.h:90
enum QudaFieldLocation_s QudaFieldLocation
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
Definition: complex_quda.h:843
__host__ __device__ ValueType conj(ValueType x)
Definition: complex_quda.h:115
QudaTune getTuning()
Definition: util_quda.cpp:32
const QudaParity parity
Definition: dslash_test.cpp:29