QUDA  v0.7.0
A library for QCD on GPUs
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
clover_trace_quda.cu
Go to the documentation of this file.
1 #include <quda_internal.h>
2 #include <quda_matrix.h>
3 #include <tune_quda.h>
4 #include <clover_field.h>
5 #include <gauge_field.h>
6 #include <gauge_field_order.h>
7 
// Include the clover accessor header inside a dedicated namespace. Code further
// down refers to the clover accessors as CloverOrder::quda::FloatNOrder, while
// the unqualified FloatNOrder is used for the gauge accessors -- presumably the
// wrapper exists to keep the two same-named templates apart (TODO confirm).
8 namespace CloverOrder {
9  using namespace quda;
10 #include <clover_field_order.h>
11 } // CloverOrder
12 
13 
14 namespace quda {
15 
16 #ifdef GPU_CLOVER_DIRAC
17 
/**
 * Argument bundle handed (by value) to the clover sigma-trace kernel and to
 * its host-side counterpart.
 *
 * Holds copies of the two clover accessors, the output gauge accessor and the
 * pair of directions that selects which plane is computed.
 */
template<typename Clover1, typename Clover2, typename Gauge>
struct CloverTraceArg {
  Clover1 clover1; // first clover accessor (read when parity == 0)
  Clover2 clover2; // second clover accessor (read otherwise)
  Gauge gauge;     // accessor the resulting matrix is saved through
  int dir1;        // first direction of the plane
  int dir2;        // second direction of the plane

  // Copies every accessor and both direction indices into the bundle.
  CloverTraceArg(Clover1 &clover1_, Clover2 &clover2_, Gauge &gauge_, int dir1_, int dir2_)
    : clover1(clover1_),
      clover2(clover2_),
      gauge(gauge_),
      dir1(dir1_),
      dir2(dir2_)
  {
  }
};
29 
30 
/**
 * Compute the 3x3 complex matrix associated with the (dir1, dir2) plane from
 * the clover term at one checkerboard site and store it through arg.gauge.
 *
 * The clover term is read as 72 reals: two blocks of 36, each made of 6 real
 * diagonal entries followed by 15 complex off-diagonal entries. Which
 * combination of those entries ends up in the result depends on the plane.
 * Presumably this is the spin trace of sigma_{dir1,dir2} against the clover
 * term (going by the function name) -- confirm against the derivation.
 *
 * If dir2 < dir1 the two directions are swapped and the result is negated.
 * For any pair that is not one of the six ordered planes (for instance
 * dir1 == dir2) the saved matrix is zero.
 *
 * @param arg    accessors and direction pair
 * @param x      checkerboard site index forwarded to the accessors
 * @param parity parity forwarded to the accessors; 0 reads arg.clover1,
 *               any other value reads arg.clover2
 */
template <typename Float, typename Clover1, typename Clover2, typename Gauge>
__device__ __host__ void cloverSigmaTraceCompute(CloverTraceArg<Clover1, Clover2, Gauge>& arg, int x, int parity)
{

  Float A[72];
  typedef typename ComplexTypeId<Float>::Type Complex;

  // result accumulator; starts at zero so untouched entries stay zero
  Matrix<Complex,3> mat;
  setZero(&mat);

  int dir1 = arg.dir1;
  int dir2 = arg.dir2;

  // Work with dir1 < dir2 only and remember the sign flip for the swap.
  Float sign = 1;
  if(dir2 < dir1){
    int tmp = dir2;
    dir2 = dir1;
    dir1 = tmp;
    sign = -1;
  }

  Float diag[2][6];
  complex<Float> tri[2][15];
  // remaps the stored order of the off-diagonal entries to the order used below
  const int idtab[15]={0,1,3,6,10,2,4,7,11,5,8,12,9,13,14};
  complex<Float> ctmp;

  // load the clover term into memory
  if(parity==0){
    arg.clover1.load(A,x,parity);
  }else{
    arg.clover2.load(A,x,parity);
  }

  // Typed as Float so that single-precision instantiations do not get
  // promoted to double arithmetic; scaling by two is exact either way.
  const Float two = 2;

  for(int ch=0; ch<2; ++ch){
    // factor of two is inherent to QUDA clover storage
    for (int i=0; i<6; i++) diag[ch][i] = two*A[ch*36+i];
    for (int i=0; i<15; i++) tri[ch][idtab[i]] = complex<Float>(two*A[ch*36+6+2*i], two*A[ch*36+6+2*i+1]);
  }

  if(dir1 == 0){
    if(dir2 == 1){ // X Y
      // diagonal part
      for(int j=0; j<3; ++j){
        mat(j,j).y = diag[0][j+3] + diag[1][j+3] - diag[0][j] - diag[1][j];
      }

      // triangular part: fill (j,k) and its mirror (k,j) together
      int jk=0;
      for(int j=1; j<3; ++j){
        int jk2 = (j+3)*(j+2)/2 + 3;
        for(int k=0; k<j; ++k){
          ctmp = tri[0][jk2] + tri[1][jk2] - tri[0][jk] - tri[1][jk];

          mat(j,k).x = -ctmp.imag();
          mat(j,k).y = ctmp.real();

          mat(k,j).x = ctmp.imag();
          mat(k,j).y = ctmp.real();

          jk++; jk2++;
        }
      }

    }else if(dir2 == 2){ // X Z

      for(int j=0; j<3; ++j){
        int jk = (j+3)*(j+2)/2;
        for(int k=0; k<3; ++k){
          int kj = (k+3)*(k+2)/2 + j;
          ctmp = conj(tri[0][kj]) - tri[0][jk] + conj(tri[1][kj]) - tri[1][jk];
          mat(j,k).x = ctmp.real();
          mat(j,k).y = ctmp.imag();
          jk++;
        }
      }

    }else if(dir2 == 3){ // X T
      for(int j=0; j<3; ++j){
        int jk = (j+3)*(j+2)/2;
        for(int k=0; k<3; ++k){
          int kj = (k+3)*(k+2)/2 + j;
          ctmp = conj(tri[0][kj]) + tri[0][jk] - conj(tri[1][kj]) - tri[1][jk];
          mat(j,k).x = -ctmp.imag();
          mat(j,k).y = ctmp.real();
          jk++;
        }
      }

    } // dir2 == 3

  }else if(dir1 == 1){
    if(dir2 == 2){ // Y Z
      for(int j=0; j<3; ++j){
        int jk = (j+3)*(j+2)/2;
        for(int k=0; k<3; ++k){
          int kj = (k+3)*(k+2)/2 + j;
          ctmp = conj(tri[0][kj]) + tri[0][jk] + conj(tri[1][kj]) + tri[1][jk];
          mat(j,k).x = ctmp.imag();
          mat(j,k).y = -ctmp.real();
          jk++;
        }
      }
    }else if(dir2 == 3){ // Y T
      for(int j=0; j<3; ++j){
        int jk = (j+3)*(j+2)/2;
        for(int k=0; k<3; ++k){
          int kj = (k+3)*(k+2)/2 + j;
          ctmp = conj(tri[0][kj]) - tri[0][jk] - conj(tri[1][kj]) + tri[1][jk];
          mat(j,k).x = ctmp.real();
          mat(j,k).y = ctmp.imag();
          jk++;
        }
      }
    } // dir2 == 3
  } // dir1 == 1
  else if(dir1 == 2){
    if(dir2 == 3){ // Z T
      // diagonal part
      for(int j=0; j<3; ++j){
        mat(j,j).y = diag[0][j] - diag[0][j+3] - diag[1][j] + diag[1][j+3];
      }
      // triangular part: fill (j,k) and its mirror (k,j) together
      int jk=0;
      for(int j=1; j<3; ++j){
        int jk2 = (j+3)*(j+2)/2 + 3;
        for(int k=0; k<j; ++k){
          ctmp = tri[0][jk] - tri[0][jk2] - tri[1][jk] + tri[1][jk2];
          mat(j,k).x = -ctmp.imag();
          mat(j,k).y = ctmp.real();

          mat(k,j).x = ctmp.imag();
          mat(k,j).y = ctmp.real();
          jk++; jk2++;
        }
      }
    }
  }

  // if dir1 and dir2 were swapped, multiply by -1
  mat *= sign;

  // the result is always written at direction index 0 of the output accessor
  arg.gauge.save((Float*)(mat.data), x, 0, parity);

  return;
}
178 
/**
 * Host-side driver: visits every checkerboard site in turn and evaluates the
 * sigma-trace matrix for it, always with parity 1.
 *
 * @param arg accessors and direction pair; arg.clover1.volumeCB gives the
 *            number of sites visited
 */
template<typename Float, typename Clover1, typename Clover2, typename Gauge>
void cloverSigmaTrace(CloverTraceArg<Clover1,Clover2,Gauge> arg)
{
  const int oddParity = 1;

  int site = 0;
  while(site < arg.clover1.volumeCB){
    cloverSigmaTraceCompute<Float,Clover1,Clover2,Gauge>(arg, site, oddParity);
    ++site;
  }
}
187 
188 
/**
 * Device-side driver: one thread per checkerboard site.
 *
 * Launch layout: 1-D grid of 1-D blocks. Threads whose flat index is at or
 * beyond arg.clover1.volumeCB do nothing, so the grid may overshoot the
 * number of sites. No shared memory is used.
 *
 * @param arg accessors and direction pair, passed by value
 */
template<typename Float, typename Clover1, typename Clover2, typename Gauge>
__global__ void cloverSigmaTraceKernel(CloverTraceArg<Clover1,Clover2,Gauge> arg)
{
  const int site = blockIdx.x*blockDim.x + threadIdx.x;

  if(site < arg.clover1.volumeCB){
    // odd parity
    cloverSigmaTraceCompute<Float,Clover1,Clover2,Gauge>(arg, site, 1);
  }
}
197 
/**
 * Tunable wrapper that runs the clover sigma-trace computation either as a
 * CUDA kernel or as a host loop, depending on the requested field location.
 *
 * Neither the shared-memory size nor the grid dimensions are tuned; the
 * kernel is launched with a fixed block size.
 */
template<typename Float, typename Clover1, typename Clover2, typename Gauge>
class CloverSigmaTrace : Tunable {
  CloverTraceArg<Clover1,Clover2,Gauge> arg; // accessors and direction pair
  const GaugeField &meta;                    // used only to build the tune key
  const QudaFieldLocation location;          // selects the device or host path

private:
  unsigned int sharedBytesPerThread() const { return 0; }
  unsigned int sharedBytesPerBlock(const TuneParam &param) const { return 0; }

  bool tuneSharedBytes() const { return false; } // Don't tune the shared memory
  bool tuneGridDim() const { return false; } // Don't tune the grid dimensions.
  unsigned int minThreads() const { return arg.clover1.volumeCB; }

public:
  /**
   * @param arg      accessors and direction pair (copied)
   * @param meta     gauge field whose volume string identifies the problem
   * @param location where the computation should run
   */
  CloverSigmaTrace(CloverTraceArg<Clover1,Clover2,Gauge> &arg, const GaugeField &meta, QudaFieldLocation location)
    : arg(arg), meta(meta), location(location) {
    writeAuxString("stride=%d", arg.clover1.stride);
  }
  virtual ~CloverSigmaTrace() {;}

  /**
   * Run the computation.
   *
   * @param stream CUDA stream the kernel is launched on when running on the
   *               device; unused on the host path
   */
  void apply(const cudaStream_t &stream){
    if(location == QUDA_CUDA_FIELD_LOCATION){
#if (__COMPUTE_CAPABILITY__ >= 200)
      // fixed launch geometry: 128 threads per block, enough blocks to cover
      // every checkerboard site (the kernel bounds-checks the tail)
      dim3 block(128, 1, 1);
      dim3 grid((arg.clover1.volumeCB + block.x - 1)/block.x, 1, 1);
      cloverSigmaTraceKernel<Float,Clover1,Clover2,Gauge><<<grid,block,0,stream>>>(arg);

      // a bad launch configuration is only reported through the runtime's
      // last-error state, so query it right away
      const cudaError_t launchStatus = cudaGetLastError();
      if(launchStatus != cudaSuccess){
        errorQuda("cloverSigmaTraceKernel launch failed: %s", cudaGetErrorString(launchStatus));
      }
#else
      errorQuda("cloverSigmaTrace not supported on pre-Fermi architecture");
#endif
    }else{
      cloverSigmaTrace<Float,Clover1,Clover2,Gauge>(arg);
    }
  }

  TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), aux); }

  std::string paramString(const TuneParam &param) const { // Don't print the grid dim.
    std::stringstream ps;
    ps << "block=(" << param.block.x << "," << param.block.y << "," << param.block.z << "), ";
    ps << "shared=" << param.shared_bytes;
    return ps.str();
  }

  long long flops() const { return 0; } // Fix this
  long long bytes() const { return 0; } // Fix this

}; // CloverSigmaTrace
246 
247 
/**
 * Accessor-level entry point: bundles the accessors and directions, runs the
 * computation at the requested location and waits for the device to finish.
 *
 * @param clover1  first clover accessor
 * @param clover2  second clover accessor
 * @param gauge    accessor the result is saved through
 * @param dir1     first direction of the plane
 * @param dir2     second direction of the plane
 * @param meta     gauge field used to identify the problem size
 * @param location where the computation should run
 */
template<typename Float, typename Clover1, typename Clover2, typename Gauge>
void computeCloverSigmaTrace(Clover1 clover1, Clover2 clover2, Gauge gauge, int dir1, int dir2,
                             const GaugeField &meta, QudaFieldLocation location)
{
  CloverTraceArg<Clover1, Clover2, Gauge> arg(clover1, clover2, gauge, dir1, dir2);
  CloverSigmaTrace<Float,Clover1,Clover2,Gauge> traceCompute(arg, meta, location);
  traceCompute.apply(0);

  // Faults raised while the kernel executes only surface at a synchronising
  // call, so the status returned here must not be dropped.
  const cudaError_t syncStatus = cudaDeviceSynchronize();
  if(syncStatus != cudaSuccess){
    errorQuda("cloverSigmaTrace failed: %s", cudaGetErrorString(syncStatus));
  }
  return;
}
258 
259 
260 
/**
 * Precision-level entry point: picks the clover and gauge accessor types that
 * match the fields' storage order and reconstruction, then forwards to the
 * accessor-level overload.
 *
 * Supported combinations:
 *   - FLOAT2 clover with FLOAT2 gauge, reconstruct NO or 12
 *   - FLOAT2 clover with FLOAT4 gauge, reconstruct 12
 *   - FLOAT4 clover with FLOAT2 gauge, reconstruct NO or 12
 * Anything else is reported through errorQuda.
 *
 * NOTE(review): each call builds two clover accessors from the same field,
 * the first with flag 0 and the second with flag 1 -- confirm that flag 1
 * selects the field the second accessor is meant to read.
 *
 * @param gauge    output gauge field
 * @param clover   input clover field
 * @param dir1     first direction of the plane
 * @param dir2     second direction of the plane
 * @param location where the computation should run
 */
template<typename Float>
void computeCloverSigmaTrace(GaugeField& gauge, const CloverField& clover, int dir1, int dir2,
                             QudaFieldLocation location){

  if(clover.Order() == QUDA_FLOAT2_CLOVER_ORDER){
    if(gauge.Order() == QUDA_FLOAT2_GAUGE_ORDER){
      if(gauge.Reconstruct() == QUDA_RECONSTRUCT_NO){
        computeCloverSigmaTrace<Float>(CloverOrder::quda::FloatNOrder<Float,72,2>(clover,0),
                                       CloverOrder::quda::FloatNOrder<Float,72,2>(clover,1),
                                       FloatNOrder<Float, 18, 2, 18>(gauge), dir1, dir2, gauge, location);
      }else if(gauge.Reconstruct() == QUDA_RECONSTRUCT_12){
        computeCloverSigmaTrace<Float>(CloverOrder::quda::FloatNOrder<Float,72,2>(clover,0),
                                       CloverOrder::quda::FloatNOrder<Float,72,2>(clover,1),
                                       FloatNOrder<Float, 18, 2, 12>(gauge), dir1, dir2, gauge, location);
      }else{
        errorQuda("Reconstruction type %d not supported",gauge.Reconstruct());
      }

    }else if(gauge.Order() == QUDA_FLOAT4_GAUGE_ORDER){
      if(gauge.Reconstruct() == QUDA_RECONSTRUCT_12){
        computeCloverSigmaTrace<Float>(CloverOrder::quda::FloatNOrder<Float,72,2>(clover,0),
                                       CloverOrder::quda::FloatNOrder<Float,72,2>(clover,1),
                                       FloatNOrder<Float,18,4,12>(gauge), dir1, dir2, gauge, location);
      }else{
        errorQuda("Reconstruction type %d not supported",gauge.Reconstruct());
      }
    }else{
      // previously fell through without computing anything
      errorQuda("Gauge order %d not supported", gauge.Order());
    }
  }else if(clover.Order() == QUDA_FLOAT4_CLOVER_ORDER){
    if(gauge.Order() == QUDA_FLOAT2_GAUGE_ORDER){
      if(gauge.Reconstruct() == QUDA_RECONSTRUCT_NO){
        computeCloverSigmaTrace<Float>(CloverOrder::quda::FloatNOrder<Float,72,4>(clover,0),
                                       CloverOrder::quda::FloatNOrder<Float,72,4>(clover,1),
                                       FloatNOrder<Float,18,2,18>(gauge), dir1, dir2, gauge, location);
      }else if(gauge.Reconstruct() == QUDA_RECONSTRUCT_12){
        computeCloverSigmaTrace<Float>(CloverOrder::quda::FloatNOrder<Float,72,4>(clover,0),
                                       CloverOrder::quda::FloatNOrder<Float,72,4>(clover,1),
                                       FloatNOrder<Float,18,2,12>(gauge), dir1, dir2, gauge, location);
      }else{
        errorQuda("Reconstruction type %d not supported",gauge.Reconstruct());
      }
    }else if(gauge.Order() == QUDA_FLOAT4_GAUGE_ORDER){
      // no FLOAT4 gauge accessor is instantiated for FLOAT4 clover
      errorQuda("Reconstruction type %d not supported",gauge.Reconstruct());
    }else{
      // previously fell through without computing anything
      errorQuda("Gauge order %d not supported", gauge.Order());
    }
  }else{
    // previously fell through without computing anything
    errorQuda("Clover order %d not supported", clover.Order());
  } // clover order
}
307 
308 #endif
309 
/**
 * Public entry point: dispatches on the clover field's precision.
 *
 * Single and double precision are forwarded to the templated implementation;
 * half precision and any other value are reported through errorQuda. When the
 * library is built without GPU_CLOVER_DIRAC the call only reports an error.
 *
 * @param gauge    output gauge field
 * @param clover   input clover field
 * @param dir1     first direction of the plane
 * @param dir2     second direction of the plane
 * @param location where the computation should run
 */
void computeCloverSigmaTrace(GaugeField& gauge, const CloverField& clover, int dir1, int dir2,
                             QudaFieldLocation location)
{
#ifdef GPU_CLOVER_DIRAC
  const QudaPrecision prec = clover.Precision();

  // half precision gets its own message before the generic dispatch below
  if(prec == QUDA_HALF_PRECISION){
    errorQuda("Half precision not supported\n");
  }

  if(prec == QUDA_SINGLE_PRECISION){
    computeCloverSigmaTrace<float>(gauge, clover, dir1, dir2, location);
  }else if(prec == QUDA_DOUBLE_PRECISION){
    computeCloverSigmaTrace<double>(gauge, clover, dir1, dir2, location);
  }else{
    errorQuda("Precision %d not supported", prec);
  }
#else
  errorQuda("Clover has not been built");
#endif
}
330 
331 
332 } // namespace quda
__device__ __host__ void setZero(Matrix< T, N > *m)
Definition: quda_matrix.h:640
Matrix< N, std::complex< T > > conj(const Matrix< N, std::complex< T > > &mat)
#define errorQuda(...)
Definition: util_quda.h:73
std::complex< double > Complex
Definition: eig_variables.h:13
QudaGaugeFieldOrder Order() const
Definition: gauge_field.h:169
cudaStream_t * stream
::std::string string
Definition: gtest.h:1979
void mat(void *out, void **fatlink, void **longlink, void *in, double kappa, int dagger_bit, QudaPrecision sPrecision, QudaPrecision gPrecision)
__host__ __device__ ValueType imag() const volatile
QudaGaugeParam param
Definition: pack_test.cpp:17
QudaPrecision Precision() const
cudaColorSpinorField * tmp
const QudaFieldLocation location
Definition: pack_test.cpp:46
FloatingPoint< float > Float
Definition: gtest.h:7350
QudaCloverFieldOrder Order() const
Definition: clover_field.h:66
T data[N *N]
Definition: quda_matrix.h:351
QudaReconstructType Reconstruct() const
Definition: gauge_field.h:168
__host__ __device__ ValueType real() const volatile
void computeCloverSigmaTrace(GaugeField &gauge, const CloverField &clover, int dir1, int dir2, QudaFieldLocation location)
int x[4]
enum QudaFieldLocation_s QudaFieldLocation
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
Definition: complex_quda.h:843
const QudaParity parity
Definition: dslash_test.cpp:29
void * gauge[4]
Definition: su3_test.cpp:15