QUDA  1.0.0
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
gauge_fix_ovr_hit_devf.cuh
Go to the documentation of this file.
1 #ifndef _GAUGE_FIX_OVR_HIT_DEVF_H
2 #define _GAUGE_FIX_OVR_HIT_DEVF_H
3 
4 
5 #include <quda_internal.h>
6 #include <quda_matrix.h>
7 #include <atomic.cuh>
8 
9 namespace quda {
10 
11 
12 
13  template<class T>
14  struct SharedMemory
15  {
16  __device__ inline operator T*()
17  {
18  extern __shared__ int __smem[];
19  return (T*)__smem;
20  }
21 
22  __device__ inline operator const T*() const
23  {
24  extern __shared__ int __smem[];
25  return (T*)__smem;
26  }
27  };
28 
35  template<int NCOLORS>
36  static __host__ __device__ inline void IndexBlock(int block, int &p, int &q){
37  if ( NCOLORS == 3 ) {
38  if ( block == 0 ) { p = 0; q = 1; }
39  else if ( block == 1 ) { p = 1; q = 2; }
40  else{ p = 0; q = 2; }
41  }
42  else{
43  int i1;
44  int found = 0;
45  int del_i = 0;
46  int index = -1;
47  while ( del_i < (NCOLORS - 1) && found == 0 ) {
48  del_i++;
49  for ( i1 = 0; i1 < (NCOLORS - del_i); i1++ ) {
50  index++;
51  if ( index == block ) {
52  found = 1;
53  break;
54  }
55  }
56  }
57  q = i1 + del_i;
58  p = i1;
59  }
60  }
61 
62 
  /**
     @brief Perform one overrelaxation hit, sweeping every SU(2) subgroup of
     SU(NCOLORS), on a single link matrix.  The four SU(2) subgroup
     parameters per site are accumulated in shared memory with atomicAdd.

     @param[in,out] link        This thread's SU(NCOLORS) link matrix, rotated in place.
     @param[in]     relax_boost Overrelaxation boost parameter.
     @param[in]     tid         Lattice-site index within the thread block, in [0, blockSize).

     Shared memory: blockSize * 4 Floats (a0..a3 per site), zeroed here and
     accumulated into with atomicAdd.
     NOTE(review): the index arithmetic assumes the block is launched with
     8 * blockSize threads — threads [0, 4*blockSize) holding a site's four
     "upward" links and threads [4*blockSize, 8*blockSize) its four
     "downward" links — confirm against the launching kernel.
   */
  template<int blockSize, typename Float, int gauge_dir, int NCOLORS>
  __forceinline__ __device__ void GaugeFixHit_AtomicAdd(Matrix<complex<Float>,NCOLORS> &link, const Float relax_boost, const int tid){

    // Container for the four real parameters of the SU(2) subgroup in shared memory
    //__shared__ Float elems[blockSize * 4];
    Float *elems = SharedMemory<Float>();
    // initialize shared memory (it is accumulated into with atomicAdd below)
    if ( threadIdx.x < blockSize * 4 ) elems[threadIdx.x] = 0.0;
    __syncthreads();


    // Loop over all SU(2) subgroups of SU(N)
    //#pragma unroll
    for ( int block = 0; block < (NCOLORS * (NCOLORS - 1) / 2); block++ ) {
      int p, q;
      // Get the two indices of this SU(2) block inside the SU(N) matrix
      IndexBlock<NCOLORS>(block, p, q);
      // Sign of the off-diagonal contributions: -1 for the upward-link
      // threads (first 4*blockSize), +1 for the downward-link threads
      Float asq = 1.0;
      if ( threadIdx.x < blockSize * 4 ) asq = -1.0;
      // FOR COULOMB AND LANDAU GAUGE: only the first gauge_dir directions of
      // each half (upward and downward) contribute to the SU(2) parameters
      //if(nu0<gauge_dir){
      //In terms of thread index
      if ( threadIdx.x < blockSize * gauge_dir || (threadIdx.x >= blockSize * 4 && threadIdx.x < blockSize * (gauge_dir + 4))) {
        // Accumulate the four SU(2) parameters over the contributing links
        atomicAdd(elems + tid, (link(p,p)).x + (link(q,q)).x); //a0
        atomicAdd(elems + tid + blockSize, (link(p,q).y + link(q,p).y) * asq); //a1
        atomicAdd(elems + tid + blockSize * 2, (link(p,q).x - link(q,p).x) * asq); //a2
        atomicAdd(elems + tid + blockSize * 3, (link(p,p).y - link(q,q).y) * asq); //a3
      } //FLOP per lattice site = gauge_dir * 2 * (4 + 7) = gauge_dir * 22
      __syncthreads();
      if ( threadIdx.x < blockSize ) {
        // Over-relaxation boost: normalize (a0, a1, a2, a3) to an SU(2)
        // element, scaling the vector part by the boost factor x
        asq = elems[threadIdx.x + blockSize] * elems[threadIdx.x + blockSize];
        asq += elems[threadIdx.x + blockSize * 2] * elems[threadIdx.x + blockSize * 2];
        asq += elems[threadIdx.x + blockSize * 3] * elems[threadIdx.x + blockSize * 3];
        Float a0sq = elems[threadIdx.x] * elems[threadIdx.x];
        Float x = (relax_boost * a0sq + asq) / (a0sq + asq);
        Float r = rsqrt((a0sq + x * x * asq));
        elems[threadIdx.x] *= r;
        elems[threadIdx.x + blockSize] *= x * r;
        elems[threadIdx.x + blockSize * 2] *= x * r;
        elems[threadIdx.x + blockSize * 3] *= x * r;
      } //FLOP per lattice site = 22
      __syncthreads();
      //_____________
      if ( threadIdx.x < blockSize * 4 ) {
        complex<Float> m0;
        //Do SU(2) hit on all upward links
        //left multiply an su3_matrix by an su2 matrix
        //link <- u * link
        //#pragma unroll
        for ( int j = 0; j < NCOLORS; j++ ) {
          m0 = link(p,j);
          link(p,j) = complex<Float>( elems[tid], elems[tid + blockSize * 3] ) * m0 + complex<Float>( elems[tid + blockSize * 2], elems[tid + blockSize] ) * link(q,j);
          link(q,j) = complex<Float>(-elems[tid + blockSize * 2], elems[tid + blockSize]) * m0 + complex<Float>( elems[tid],-elems[tid + blockSize * 3] ) * link(q,j);
        }
      }
      else{
        complex<Float> m0;
        //Do SU(2) hit on all downward links
        //right multiply an su3_matrix by an su2 matrix
        //link <- link * u_adj
        //#pragma unroll
        for ( int j = 0; j < NCOLORS; j++ ) {
          m0 = link(j,p);
          link(j,p) = complex<Float>( elems[tid], -elems[tid + blockSize * 3] ) * m0 + complex<Float>( elems[tid + blockSize * 2], -elems[tid + blockSize] ) * link(j,q);
          link(j,q) = complex<Float>(-elems[tid + blockSize * 2], -elems[tid + blockSize]) * m0 + complex<Float>( elems[tid],elems[tid + blockSize * 3] ) * link(j,q);
        }
      }
      //_____________ //FLOP per lattice site = 8 * NCOLORS * 2 * (2*6+2) = NCOLORS * 224
      // re-zero the accumulators before the next subgroup (block is a loop
      // variable, uniform across the thread block, so this barrier is safe)
      if ( block < (NCOLORS * (NCOLORS - 1) / 2) - 1 ) {
        __syncthreads();
        //reset shared memory SU(2) elements
        if ( threadIdx.x < blockSize * 4 ) elems[threadIdx.x] = 0.0;
        __syncthreads();
      }
    } //FLOP per lattice site = (block < NCOLORS * ( NCOLORS - 1) / 2) * (22 + 28 gauge_dir + 224 NCOLORS)
    //the caller writes the updated link back to global memory
  }
151 
152 
153 
  /**
     @brief Perform one overrelaxation hit, sweeping every SU(2) subgroup of
     SU(NCOLORS), on a single link matrix.  Unlike the atomicAdd variant,
     every thread writes its SU(2) parameter contribution to a private
     shared-memory slot and one thread per site reduces them.

     @param[in,out] link        This thread's SU(NCOLORS) link matrix, rotated in place.
     @param[in]     relax_boost Overrelaxation boost parameter.
     @param[in]     tid         Lattice-site index within the thread block, in [0, blockSize).

     Shared memory: blockSize * 8 * 4 Floats (one a0..a3 slot per thread).
     NOTE(review): the index arithmetic assumes 8 * blockSize threads per
     block, threads [0, 4*blockSize) holding the upward links and
     [4*blockSize, 8*blockSize) the downward links — confirm against the
     launching kernel.
   */
  template<int blockSize, typename Float, int gauge_dir, int NCOLORS>
  __forceinline__ __device__ void GaugeFixHit_NoAtomicAdd(Matrix<complex<Float>,NCOLORS> &link, const Float relax_boost, const int tid){

    // Container for the four real parameters of the SU(2) subgroup in shared memory
    //__shared__ Float elems[blockSize * 4 * 8];
    Float *elems = SharedMemory<Float>();


    // Loop over all SU(2) subgroups of SU(N)
    //#pragma unroll
    for ( int block = 0; block < (NCOLORS * (NCOLORS - 1) / 2); block++ ) {
      int p, q;
      // Get the two indices of this SU(2) block inside the SU(N) matrix
      IndexBlock<NCOLORS>(block, p, q);
      /*Float asq = 1.0;
         if(threadIdx.x < blockSize * 4) asq = -1.0;
         if(threadIdx.x < blockSize * gauge_dir || (threadIdx.x >= blockSize * 4 && threadIdx.x < blockSize * (gauge_dir + 4))){
         elems[threadIdx.x] = link(p,p).x + link(q,q).x;
         elems[threadIdx.x + blockSize * 8] = (link(p,q).y + link(q,p).y) * asq;
         elems[threadIdx.x + blockSize * 8 * 2] = (link(p,q).x - link(q,p).x) * asq;
         elems[threadIdx.x + blockSize * 8 * 3] = (link(p,p).y - link(q,q).y) * asq;
         }*/ //FLOP per lattice site = gauge_dir * 2 * 7 = gauge_dir * 14
      // Upward links contribute with a minus sign on the off-diagonal terms
      if ( threadIdx.x < blockSize * gauge_dir ) {
        elems[threadIdx.x] = link(p,p).x + link(q,q).x;
        elems[threadIdx.x + blockSize * 8] = -(link(p,q).y + link(q,p).y);
        elems[threadIdx.x + blockSize * 8 * 2] = -(link(p,q).x - link(q,p).x);
        elems[threadIdx.x + blockSize * 8 * 3] = -(link(p,p).y - link(q,q).y);
      }
      // Downward links contribute with a plus sign
      if ((threadIdx.x >= blockSize * 4 && threadIdx.x < blockSize * (gauge_dir + 4))) {
        elems[threadIdx.x] = link(p,p).x + link(q,q).x;
        elems[threadIdx.x + blockSize * 8] = (link(p,q).y + link(q,p).y);
        elems[threadIdx.x + blockSize * 8 * 2] = (link(p,q).x - link(q,p).x);
        elems[threadIdx.x + blockSize * 8 * 3] = (link(p,p).y - link(q,q).y);
      }
      //FLOP per lattice site = gauge_dir * 2 * 7 = gauge_dir * 14
      __syncthreads();
      if ( threadIdx.x < blockSize ) {
        // One thread per site reduces the per-direction contributions
        // (upward slot i and downward slot i+4 for each direction i)
        Float a0, a1, a2, a3;
        a0 = 0.0; a1 = 0.0; a2 = 0.0; a3 = 0.0;
  #pragma unroll
        for ( int i = 0; i < gauge_dir; i++ ) {
          a0 += elems[tid + i * blockSize] + elems[tid + (i + 4) * blockSize];
          a1 += elems[tid + i * blockSize + blockSize * 8] + elems[tid + (i + 4) * blockSize + blockSize * 8];
          a2 += elems[tid + i * blockSize + blockSize * 8 * 2] + elems[tid + (i + 4) * blockSize + blockSize * 8 * 2];
          a3 += elems[tid + i * blockSize + blockSize * 8 * 3] + elems[tid + (i + 4) * blockSize + blockSize * 8 * 3];
        }
        //Over-relaxation boost: normalize (a0,a1,a2,a3) to an SU(2) element,
        //scaling the vector part by the boost factor x
        Float asq = a1 * a1 + a2 * a2 + a3 * a3;
        Float a0sq = a0 * a0;
        Float x = (relax_boost * a0sq + asq) / (a0sq + asq);
        Float r = rsqrt((a0sq + x * x * asq));
        elems[threadIdx.x] = a0 * r;
        elems[threadIdx.x + blockSize] = a1 * x * r;
        elems[threadIdx.x + blockSize * 2] = a2 * x * r;
        elems[threadIdx.x + blockSize * 3] = a3 * x * r;
      } //FLOP per lattice site = 22 + 8 * 4
      __syncthreads();
      //_____________
      if ( threadIdx.x < blockSize * 4 ) {
        complex<Float> m0;
        //Do SU(2) hit on all upward links
        //left multiply an su3_matrix by an su2 matrix
        //link <- u * link
        //#pragma unroll
        for ( int j = 0; j < NCOLORS; j++ ) {
          m0 = link(p,j);
          link(p,j) = complex<Float>( elems[tid], elems[tid + blockSize * 3] ) * m0 + complex<Float>( elems[tid + blockSize * 2], elems[tid + blockSize] ) * link(q,j);
          link(q,j) = complex<Float>(-elems[tid + blockSize * 2], elems[tid + blockSize]) * m0 + complex<Float>( elems[tid],-elems[tid + blockSize * 3] ) * link(q,j);
        }
      }
      else{
        complex<Float> m0;
        //Do SU(2) hit on all downward links
        //right multiply an su3_matrix by an su2 matrix
        //link <- link * u_adj
        //#pragma unroll
        for ( int j = 0; j < NCOLORS; j++ ) {
          m0 = link(j,p);
          link(j,p) = complex<Float>( elems[tid], -elems[tid + blockSize * 3] ) * m0 + complex<Float>( elems[tid + blockSize * 2], -elems[tid + blockSize] ) * link(j,q);
          link(j,q) = complex<Float>(-elems[tid + blockSize * 2], -elems[tid + blockSize]) * m0 + complex<Float>(elems[tid],elems[tid + blockSize * 3] ) * link(j,q);
        }
      }
      //_____________ //FLOP per lattice site = 8 * NCOLORS * 2 * (2*6+2) = NCOLORS * 224
      // barrier before the next subgroup overwrites the shared slots
      if ( block < (NCOLORS * (NCOLORS - 1) / 2) - 1 ) { __syncthreads(); }
    } //FLOP per lattice site = (NCOLORS * ( NCOLORS - 1) / 2) * (22 + 28 gauge_dir + 224 NCOLORS)
    //the caller writes the updated link back to global memory
  }
245 
246 
247 
  /**
     @brief Perform one overrelaxation hit, sweeping every SU(2) subgroup of
     SU(NCOLORS), on a single link matrix.  This variant uses the least
     shared memory: the eight per-direction thread groups accumulate their
     SU(2) parameter contributions sequentially into the same four slots per
     site, with a barrier between groups.

     @param[in,out] link        This thread's SU(NCOLORS) link matrix, rotated in place.
     @param[in]     relax_boost Overrelaxation boost parameter.
     @param[in]     tid         Lattice-site index within the thread block, in [0, blockSize).

     Shared memory: blockSize * 4 Floats (a0..a3 per site).
     NOTE(review): the group staging assumes 8 * blockSize threads per
     block — group g is threads [g*blockSize, (g+1)*blockSize), groups 0-3
     upward links, groups 4-7 downward links — confirm against the
     launching kernel.
   */
  template<int blockSize, typename Float, int gauge_dir, int NCOLORS>
  __forceinline__ __device__ void GaugeFixHit_NoAtomicAdd_LessSM(Matrix<complex<Float>,NCOLORS> &link, const Float relax_boost, const int tid){

    // Container for the four real parameters of the SU(2) subgroup in shared memory
    //__shared__ Float elems[blockSize * 4 * 8];
    Float *elems = SharedMemory<Float>();

    // Loop over all SU(2) subgroups of SU(N)
    //#pragma unroll
    for ( int block = 0; block < (NCOLORS * (NCOLORS - 1) / 2); block++ ) {
      int p, q;
      // Get the two indices of this SU(2) block inside the SU(N) matrix
      IndexBlock<NCOLORS>(block, p, q);

      // Group 0 (upward, direction 0) initializes the accumulators;
      // upward links enter with a minus sign on the off-diagonal terms
      if ( threadIdx.x < blockSize ) {
        elems[tid] = link(p,p).x + link(q,q).x;
        elems[tid + blockSize] = -(link(p,q).y + link(q,p).y);
        elems[tid + blockSize * 2] = -(link(p,q).x - link(q,p).x);
        elems[tid + blockSize * 3] = -(link(p,p).y - link(q,q).y);
      }
      __syncthreads();
      // Group 1 (upward, direction 1) accumulates
      if ( threadIdx.x < blockSize * 2 && threadIdx.x >= blockSize ) {
        elems[tid] += link(p,p).x + link(q,q).x;
        elems[tid + blockSize] -= (link(p,q).y + link(q,p).y);
        elems[tid + blockSize * 2] -= (link(p,q).x - link(q,p).x);
        elems[tid + blockSize * 3] -= (link(p,p).y - link(q,q).y);
      }
      __syncthreads();
      // Group 2 (upward, direction 2) accumulates
      if ( threadIdx.x < blockSize * 3 && threadIdx.x >= blockSize * 2 ) {
        elems[tid] += link(p,p).x + link(q,q).x;
        elems[tid + blockSize] -= (link(p,q).y + link(q,p).y);
        elems[tid + blockSize * 2] -= (link(p,q).x - link(q,p).x);
        elems[tid + blockSize * 3] -= (link(p,p).y - link(q,q).y);
      }
      // Group 3 (upward, direction 3) only contributes for Landau gauge
      // (gauge_dir is a template parameter, so the barrier is uniform)
      if ( gauge_dir == 4 ) {
        __syncthreads();
        if ( threadIdx.x < blockSize * 4 && threadIdx.x >= blockSize * 3 ) {
          elems[tid] += link(p,p).x + link(q,q).x;
          elems[tid + blockSize] -= (link(p,q).y + link(q,p).y);
          elems[tid + blockSize * 2] -= (link(p,q).x - link(q,p).x);
          elems[tid + blockSize * 3] -= (link(p,p).y - link(q,q).y);
        }
      }
      __syncthreads();
      // Group 4 (downward, direction 0); downward links enter with a plus sign
      if ( threadIdx.x < blockSize * 5 && threadIdx.x >= blockSize * 4 ) {
        elems[tid] += link(p,p).x + link(q,q).x;
        elems[tid + blockSize] += (link(p,q).y + link(q,p).y);
        elems[tid + blockSize * 2] += (link(p,q).x - link(q,p).x);
        elems[tid + blockSize * 3] += (link(p,p).y - link(q,q).y);
      }
      __syncthreads();
      // Group 5 (downward, direction 1)
      if ( threadIdx.x < blockSize * 6 && threadIdx.x >= blockSize * 5 ) {
        elems[tid] += link(p,p).x + link(q,q).x;
        elems[tid + blockSize] += (link(p,q).y + link(q,p).y);
        elems[tid + blockSize * 2] += (link(p,q).x - link(q,p).x);
        elems[tid + blockSize * 3] += (link(p,p).y - link(q,q).y);
      }
      __syncthreads();
      // Group 6 (downward, direction 2)
      if ( threadIdx.x < blockSize * 7 && threadIdx.x >= blockSize * 6 ) {
        elems[tid] += link(p,p).x + link(q,q).x;
        elems[tid + blockSize] += (link(p,q).y + link(q,p).y);
        elems[tid + blockSize * 2] += (link(p,q).x - link(q,p).x);
        elems[tid + blockSize * 3] += (link(p,p).y - link(q,q).y);
      }
      // Group 7 (downward, direction 3) only contributes for Landau gauge
      if ( gauge_dir == 4 ) {
        __syncthreads();
        if ( threadIdx.x < blockSize * 8 && threadIdx.x >= blockSize * 7 ) {
          elems[tid] += link(p,p).x + link(q,q).x;
          elems[tid + blockSize] += (link(p,q).y + link(q,p).y);
          elems[tid + blockSize * 2] += (link(p,q).x - link(q,p).x);
          elems[tid + blockSize * 3] += (link(p,p).y - link(q,q).y);
        }
      }
      //FLOP per lattice site = gauge_dir * 2 * 7 = gauge_dir * 14
      __syncthreads();
      if ( threadIdx.x < blockSize ) {
        // Over-relaxation boost: normalize (a0,a1,a2,a3) to an SU(2) element,
        // scaling the vector part by the boost factor x
        Float asq = elems[tid + blockSize] * elems[tid + blockSize];
        asq += elems[tid + blockSize * 2] * elems[tid + blockSize * 2];
        asq += elems[tid + blockSize * 3] * elems[tid + blockSize * 3];
        Float a0sq = elems[tid] * elems[tid];
        Float x = (relax_boost * a0sq + asq) / (a0sq + asq);
        Float r = rsqrt((a0sq + x * x * asq));
        elems[tid] *= r;
        elems[tid + blockSize] *= x * r;
        elems[tid + blockSize * 2] *= x * r;
        elems[tid + blockSize * 3] *= x * r;
      } //FLOP per lattice site = 22 + 8 * 4
      __syncthreads();
      //_____________
      if ( threadIdx.x < blockSize * 4 ) {
        complex<Float> m0;
        //Do SU(2) hit on all upward links
        //left multiply an su3_matrix by an su2 matrix
        //link <- u * link
        //#pragma unroll
        for ( int j = 0; j < NCOLORS; j++ ) {
          m0 = link(p,j);
          link(p,j) = complex<Float>( elems[tid], elems[tid + blockSize * 3] ) * m0 + complex<Float>( elems[tid + blockSize * 2], elems[tid + blockSize] ) * link(q,j);
          link(q,j) = complex<Float>(-elems[tid + blockSize * 2], elems[tid + blockSize]) * m0 + complex<Float>( elems[tid],-elems[tid + blockSize * 3] ) * link(q,j);
        }
      }
      else{
        complex<Float> m0;
        //Do SU(2) hit on all downward links
        //right multiply an su3_matrix by an su2 matrix
        //link <- link * u_adj
        //#pragma unroll
        for ( int j = 0; j < NCOLORS; j++ ) {
          m0 = link(j,p);
          link(j,p) = complex<Float>( elems[tid], -elems[tid + blockSize * 3] ) * m0 + complex<Float>( elems[tid + blockSize * 2], -elems[tid + blockSize] ) * link(j,q);
          link(j,q) = complex<Float>(-elems[tid + blockSize * 2], -elems[tid + blockSize]) * m0 + complex<Float>( elems[tid],elems[tid + blockSize * 3] ) * link(j,q);
        }
      }
      //_____________ //FLOP per lattice site = 8 * NCOLORS * 2 * (2*6+2) = NCOLORS * 224
      // barrier before the next subgroup overwrites the shared slots
      if ( block < (NCOLORS * (NCOLORS - 1) / 2) - 1 ) { __syncthreads(); }
    } //FLOP per lattice site = (NCOLORS * ( NCOLORS - 1) / 2) * (22 + 28 gauge_dir + 224 NCOLORS)
    //the caller writes the updated link back to global memory
  }
371 
372 
373 
374 
375 
376 
377 
378 
379 
380 
381 
382 
383 
384 
385 
  /**
     @brief Perform one overrelaxation hit, sweeping every SU(2) subgroup of
     SU(NCOLORS), on a pair of link matrices (upward and downward link of
     the same site and direction held by one thread).  SU(2) parameters are
     accumulated in shared memory with atomicAdd.

     @param[in,out] link        Upward SU(NCOLORS) link, rotated in place (u * link).
     @param[in,out] link1       Downward SU(NCOLORS) link, rotated in place (link1 * u_adj).
     @param[in]     relax_boost Overrelaxation boost parameter.
     @param[in]     tid         Lattice-site index within the thread block, in [0, blockSize).

     Shared memory: blockSize * 4 Floats (a0..a3 per site), zeroed here and
     accumulated into with atomicAdd.
     NOTE(review): only threads [0, gauge_dir*blockSize) contribute
     parameters, which suggests a launch of gauge_dir (or 4) * blockSize
     threads per block — confirm against the launching kernel.
   */
  template<int blockSize, typename Float, int gauge_dir, int NCOLORS>
  __forceinline__ __device__ void GaugeFixHit_AtomicAdd(Matrix<complex<Float>,NCOLORS> &link, Matrix<complex<Float>,NCOLORS> &link1,
                                                        const Float relax_boost, const int tid){

    // Container for the four real parameters of the SU(2) subgroup in shared memory
    //__shared__ Float elems[blockSize * 4];
    Float *elems = SharedMemory<Float>();
    // initialize shared memory (it is accumulated into with atomicAdd below)
    if ( threadIdx.x < blockSize * 4 ) elems[threadIdx.x] = 0.0;
    __syncthreads();


    // Loop over all SU(2) subgroups of SU(N)
    //#pragma unroll
    for ( int block = 0; block < (NCOLORS * (NCOLORS - 1) / 2); block++ ) {
      int p, q;
      // Get the two indices of this SU(2) block inside the SU(N) matrix
      IndexBlock<NCOLORS>(block, p, q);
      if ( threadIdx.x < blockSize * gauge_dir ) {
        // Accumulate the four SU(2) parameters; the downward link (link1)
        // enters with a plus sign, the upward link (link) with a minus sign
        atomicAdd(elems + tid, (link1(p,p)).x + (link1(q,q)).x + (link(p,p)).x + (link(q,q)).x); //a0
        atomicAdd(elems + tid + blockSize, (link1(p,q).y + link1(q,p).y) - (link(p,q).y + link(q,p).y)); //a1
        atomicAdd(elems + tid + blockSize * 2, (link1(p,q).x - link1(q,p).x) - (link(p,q).x - link(q,p).x)); //a2
        atomicAdd(elems + tid + blockSize * 3, (link1(p,p).y - link1(q,q).y) - (link(p,p).y - link(q,q).y)); //a3
      }
      __syncthreads();
      if ( threadIdx.x < blockSize ) {
        //Over-relaxation boost: normalize (a0,a1,a2,a3) to an SU(2) element,
        //scaling the vector part by the boost factor x
        Float asq = elems[threadIdx.x + blockSize] * elems[threadIdx.x + blockSize];
        asq += elems[threadIdx.x + blockSize * 2] * elems[threadIdx.x + blockSize * 2];
        asq += elems[threadIdx.x + blockSize * 3] * elems[threadIdx.x + blockSize * 3];
        Float a0sq = elems[threadIdx.x] * elems[threadIdx.x];
        Float x = (relax_boost * a0sq + asq) / (a0sq + asq);
        Float r = rsqrt((a0sq + x * x * asq));
        elems[threadIdx.x] *= r;
        elems[threadIdx.x + blockSize] *= x * r;
        elems[threadIdx.x + blockSize * 2] *= x * r;
        elems[threadIdx.x + blockSize * 3] *= x * r;
      } //FLOP per lattice site = 22
      __syncthreads();
      complex<Float> m0;
      //Do SU(2) hit on the upward link
      //left multiply an su3_matrix by an su2 matrix
      //link <- u * link
      //#pragma unroll
      for ( int j = 0; j < NCOLORS; j++ ) {
        m0 = link(p,j);
        link(p,j) = complex<Float>( elems[tid], elems[tid + blockSize * 3] ) * m0 +
                    complex<Float>( elems[tid + blockSize * 2], elems[tid + blockSize] ) * link(q,j);
        link(q,j) = complex<Float>(-elems[tid + blockSize * 2], elems[tid + blockSize]) * m0 +
                    complex<Float>( elems[tid],-elems[tid + blockSize * 3] ) * link(q,j);
      }
      //Do SU(2) hit on the downward link
      //right multiply an su3_matrix by an su2 matrix
      //link <- link * u_adj
      //#pragma unroll
      for ( int j = 0; j < NCOLORS; j++ ) {
        m0 = link1(j,p);
        link1(j,p) = complex<Float>( elems[tid], -elems[tid + blockSize * 3] ) * m0 +
                     complex<Float>( elems[tid + blockSize * 2], -elems[tid + blockSize] ) * link1(j,q);
        link1(j,q) = complex<Float>(-elems[tid + blockSize * 2], -elems[tid + blockSize]) * m0 +
                     complex<Float>( elems[tid],elems[tid + blockSize * 3] ) * link1(j,q);
      }
      // re-zero the accumulators before the next subgroup (block is uniform
      // across the thread block, so this barrier is safe)
      if ( block < (NCOLORS * (NCOLORS - 1) / 2) - 1 ) {
        __syncthreads();
        //reset shared memory SU(2) elements
        if ( threadIdx.x < blockSize * 4 ) elems[threadIdx.x] = 0.0;
        __syncthreads();
      }
    }
    //the caller writes the updated links back to global memory
  }
466 
467 
468 
469 
470 
471 
472 
473 
474 
475 
476 
477 
478 
479 
480 
  /**
     @brief Perform one overrelaxation hit, sweeping every SU(2) subgroup of
     SU(NCOLORS), on a pair of link matrices (upward and downward link held
     by one thread).  Each thread writes its SU(2) parameter contribution to
     a private shared-memory slot and one thread per site reduces them.

     @param[in,out] link        Upward SU(NCOLORS) link, rotated in place (u * link).
     @param[in,out] link1       Downward SU(NCOLORS) link, rotated in place (link1 * u_adj).
     @param[in]     relax_boost Overrelaxation boost parameter.
     @param[in]     tid         Lattice-site index within the thread block, in [0, blockSize).

     Shared memory: the slot layout spans blockSize * 4 * 4 Floats
     (threadIdx.x < 4*blockSize plus three strides of blockSize*4).
     NOTE(review): the commented-out declaration says blockSize*4*8 — the
     extra factor may be an over-allocation inherited from the single-link
     variant; confirm against the launching kernel.
   */
  template<int blockSize, typename Float, int gauge_dir, int NCOLORS>
  __forceinline__ __device__ void GaugeFixHit_NoAtomicAdd(Matrix<complex<Float>,NCOLORS> &link, Matrix<complex<Float>,NCOLORS> &link1,
                                                          const Float relax_boost, const int tid){

    // Container for the four real parameters of the SU(2) subgroup in shared memory
    //__shared__ Float elems[blockSize * 4 * 8];
    Float *elems = SharedMemory<Float>();
    // Loop over all SU(2) subgroups of SU(N)
    //#pragma unroll
    for ( int block = 0; block < (NCOLORS * (NCOLORS - 1) / 2); block++ ) {
      int p, q;
      // Get the two indices of this SU(2) block inside the SU(N) matrix
      IndexBlock<NCOLORS>(block, p, q);
      // Each contributing thread stores its combined (downward minus upward)
      // SU(2) parameters in its own slot
      if ( threadIdx.x < blockSize * gauge_dir ) {
        elems[threadIdx.x] = link1(p,p).x + link1(q,q).x + link(p,p).x + link(q,q).x;
        elems[threadIdx.x + blockSize * 4] = (link1(p,q).y + link1(q,p).y) - (link(p,q).y + link(q,p).y);
        elems[threadIdx.x + blockSize * 4 * 2] = (link1(p,q).x - link1(q,p).x) - (link(p,q).x - link(q,p).x);
        elems[threadIdx.x + blockSize * 4 * 3] = (link1(p,p).y - link1(q,q).y) - (link(p,p).y - link(q,q).y);
      }
      __syncthreads();
      if ( threadIdx.x < blockSize ) {
        // One thread per site reduces the per-direction contributions
        Float a0, a1, a2, a3;
        a0 = 0.0; a1 = 0.0; a2 = 0.0; a3 = 0.0;
  #pragma unroll
        for ( int i = 0; i < gauge_dir; i++ ) {
          a0 += elems[tid + i * blockSize];
          a1 += elems[tid + i * blockSize + blockSize * 4];
          a2 += elems[tid + i * blockSize + blockSize * 4 * 2];
          a3 += elems[tid + i * blockSize + blockSize * 4 * 3];
        }
        //Over-relaxation boost: normalize (a0,a1,a2,a3) to an SU(2) element,
        //scaling the vector part by the boost factor x
        Float asq = a1 * a1 + a2 * a2 + a3 * a3;
        Float a0sq = a0 * a0;
        Float x = (relax_boost * a0sq + asq) / (a0sq + asq);
        Float r = rsqrt((a0sq + x * x * asq));
        elems[threadIdx.x] = a0 * r;
        elems[threadIdx.x + blockSize] = a1 * x * r;
        elems[threadIdx.x + blockSize * 2] = a2 * x * r;
        elems[threadIdx.x + blockSize * 3] = a3 * x * r;
      } //FLOP per lattice site = 22 + 8 * 4
      __syncthreads();
      complex<Float> m0;
      //Do SU(2) hit on the upward link
      //left multiply an su3_matrix by an su2 matrix
      //link <- u * link
      //#pragma unroll
      for ( int j = 0; j < NCOLORS; j++ ) {
        m0 = link(p,j);
        link(p,j) = complex<Float>( elems[tid], elems[tid + blockSize * 3] ) * m0 +
                    complex<Float>( elems[tid + blockSize * 2], elems[tid + blockSize] ) * link(q,j);
        link(q,j) = complex<Float>(-elems[tid + blockSize * 2], elems[tid + blockSize]) * m0 +
                    complex<Float>( elems[tid],-elems[tid + blockSize * 3] ) * link(q,j);
      }
      //Do SU(2) hit on the downward link
      //right multiply an su3_matrix by an su2 matrix
      //link <- link * u_adj
      //#pragma unroll
      for ( int j = 0; j < NCOLORS; j++ ) {
        m0 = link1(j,p);
        link1(j,p) = complex<Float>( elems[tid], -elems[tid + blockSize * 3] ) * m0 +
                     complex<Float>( elems[tid + blockSize * 2], -elems[tid + blockSize] ) * link1(j,q);
        link1(j,q) = complex<Float>(-elems[tid + blockSize * 2], -elems[tid + blockSize]) * m0 +
                     complex<Float>( elems[tid],elems[tid + blockSize * 3] ) * link1(j,q);
      }
      // barrier before the next subgroup overwrites the shared slots
      if ( block < (NCOLORS * (NCOLORS - 1) / 2) - 1 ) { __syncthreads(); }
    }
    //the caller writes the updated links back to global memory
  }
552 
553 
554 
555 
556 
  /**
     @brief Perform one overrelaxation hit, sweeping every SU(2) subgroup of
     SU(NCOLORS), on a pair of link matrices (upward and downward link held
     by one thread).  Minimal-shared-memory variant: the per-direction
     thread groups accumulate their SU(2) parameter contributions
     sequentially into the same four slots per site, with a barrier between
     groups.

     @param[in,out] link        Upward SU(NCOLORS) link, rotated in place (u * link).
     @param[in,out] link1       Downward SU(NCOLORS) link, rotated in place (link1 * u_adj).
     @param[in]     relax_boost Overrelaxation boost parameter.
     @param[in]     tid         Lattice-site index within the thread block, in [0, blockSize).

     Shared memory: blockSize * 4 Floats (a0..a3 per site).
     NOTE(review): the group staging assumes gauge_dir (or 4) * blockSize
     threads per block, group g being threads [g*blockSize, (g+1)*blockSize)
     for direction g — confirm against the launching kernel.
   */
  template<int blockSize, typename Float, int gauge_dir, int NCOLORS>
  __forceinline__ __device__ void GaugeFixHit_NoAtomicAdd_LessSM(Matrix<complex<Float>,NCOLORS> &link, Matrix<complex<Float>,NCOLORS> &link1, const Float relax_boost, const int tid){

    // Container for the four real parameters of the SU(2) subgroup in shared memory
    //__shared__ Float elems[blockSize * 4 * 8];
    Float *elems = SharedMemory<Float>();

    // Loop over all SU(2) subgroups of SU(N)
    //#pragma unroll
    for ( int block = 0; block < (NCOLORS * (NCOLORS - 1) / 2); block++ ) {
      int p, q;
      // Get the two indices of this SU(2) block inside the SU(N) matrix
      IndexBlock<NCOLORS>(block, p, q);
      // Group 0 (direction 0) initializes the accumulators with the
      // combined (downward minus upward) contribution of its link pair
      if ( threadIdx.x < blockSize ) {
        elems[tid] = link1(p,p).x + link1(q,q).x + link(p,p).x + link(q,q).x;
        elems[tid + blockSize] = (link1(p,q).y + link1(q,p).y) - (link(p,q).y + link(q,p).y);
        elems[tid + blockSize * 2] = (link1(p,q).x - link1(q,p).x) - (link(p,q).x - link(q,p).x);
        elems[tid + blockSize * 3] = (link1(p,p).y - link1(q,q).y) - (link(p,p).y - link(q,q).y);
      }
      __syncthreads();
      // Group 1 (direction 1) accumulates
      if ( threadIdx.x < blockSize * 2 && threadIdx.x >= blockSize ) {
        elems[tid] += link1(p,p).x + link1(q,q).x + link(p,p).x + link(q,q).x;
        elems[tid + blockSize] += (link1(p,q).y + link1(q,p).y) - (link(p,q).y + link(q,p).y);
        elems[tid + blockSize * 2] += (link1(p,q).x - link1(q,p).x) - (link(p,q).x - link(q,p).x);
        elems[tid + blockSize * 3] += (link1(p,p).y - link1(q,q).y) - (link(p,p).y - link(q,q).y);
      }
      __syncthreads();
      // Group 2 (direction 2) accumulates
      if ( threadIdx.x < blockSize * 3 && threadIdx.x >= blockSize * 2 ) {
        elems[tid] += link1(p,p).x + link1(q,q).x + link(p,p).x + link(q,q).x;
        elems[tid + blockSize] += (link1(p,q).y + link1(q,p).y) - (link(p,q).y + link(q,p).y);
        elems[tid + blockSize * 2] += (link1(p,q).x - link1(q,p).x) - (link(p,q).x - link(q,p).x);
        elems[tid + blockSize * 3] += (link1(p,p).y - link1(q,q).y) - (link(p,p).y - link(q,q).y);
      }
      // Group 3 (direction 3) only contributes for Landau gauge
      // (gauge_dir is a template parameter, so the barrier is uniform)
      if ( gauge_dir == 4 ) {
        __syncthreads();
        if ( threadIdx.x < blockSize * 4 && threadIdx.x >= blockSize * 3 ) {
          elems[tid] += link1(p,p).x + link1(q,q).x + link(p,p).x + link(q,q).x;
          elems[tid + blockSize] += (link1(p,q).y + link1(q,p).y) - (link(p,q).y + link(q,p).y);
          elems[tid + blockSize * 2] += (link1(p,q).x - link1(q,p).x) - (link(p,q).x - link(q,p).x);
          elems[tid + blockSize * 3] += (link1(p,p).y - link1(q,q).y) - (link(p,p).y - link(q,q).y);
        }
      }
      __syncthreads();
      if ( threadIdx.x < blockSize ) {
        //Over-relaxation boost: normalize (a0,a1,a2,a3) to an SU(2) element,
        //scaling the vector part by the boost factor x
        Float asq = elems[tid + blockSize] * elems[tid + blockSize];
        asq += elems[tid + blockSize * 2] * elems[tid + blockSize * 2];
        asq += elems[tid + blockSize * 3] * elems[tid + blockSize * 3];
        Float a0sq = elems[tid] * elems[tid];
        Float x = (relax_boost * a0sq + asq) / (a0sq + asq);
        Float r = rsqrt((a0sq + x * x * asq));
        elems[tid] *= r;
        elems[tid + blockSize] *= x * r;
        elems[tid + blockSize * 2] *= x * r;
        elems[tid + blockSize * 3] *= x * r;
      }
      __syncthreads();
      complex<Float> m0;
      //Do SU(2) hit on the upward link
      //left multiply an su3_matrix by an su2 matrix
      //link <- u * link
      //#pragma unroll
      for ( int j = 0; j < NCOLORS; j++ ) {
        m0 = link(p,j);
        link(p,j) = complex<Float>( elems[tid], elems[tid + blockSize * 3] ) * m0 +
                    complex<Float>( elems[tid + blockSize * 2], elems[tid + blockSize] ) * link(q,j);
        link(q,j) = complex<Float>(-elems[tid + blockSize * 2], elems[tid + blockSize]) * m0 +
                    complex<Float>( elems[tid],-elems[tid + blockSize * 3] ) * link(q,j);
      }
      //Do SU(2) hit on the downward link
      //right multiply an su3_matrix by an su2 matrix
      //link <- link * u_adj
      //#pragma unroll
      for ( int j = 0; j < NCOLORS; j++ ) {
        m0 = link1(j,p);
        link1(j,p) = complex<Float>( elems[tid], -elems[tid + blockSize * 3] ) * m0 +
                     complex<Float>( elems[tid + blockSize * 2], -elems[tid + blockSize] ) * link1(j,q);
        link1(j,q) = complex<Float>(-elems[tid + blockSize * 2], -elems[tid + blockSize]) * m0 +
                     complex<Float>( elems[tid],elems[tid + blockSize * 3] ) * link1(j,q);
      }
      // barrier before the next subgroup overwrites the shared slots
      if ( block < (NCOLORS * (NCOLORS - 1) / 2) - 1 ) { __syncthreads(); }
    }
    //the caller writes the updated links back to global memory
  }
644 
645 }
646 #endif
static __host__ __device__ void IndexBlock(int block, int &p, int &q)
static __device__ double2 atomicAdd(double2 *addr, double2 val)
Implementation of double2 atomic addition using two double-precision additions.
Definition: atomic.cuh:51
__forceinline__ __device__ void GaugeFixHit_NoAtomicAdd(Matrix< complex< Float >, NCOLORS > &link, const Float relax_boost, const int tid)
__forceinline__ __device__ void GaugeFixHit_NoAtomicAdd_LessSM(Matrix< complex< Float >, NCOLORS > &link, const Float relax_boost, const int tid)
__forceinline__ __device__ void GaugeFixHit_AtomicAdd(Matrix< complex< Float >, NCOLORS > &link, const Float relax_boost, const int tid)
static int index(int ndim, const int *dims, const int *x)
Definition: comm_common.cpp:32