QUDA  v1.1.0
A library for QCD on GPUs
pgauge_heatbath.cu
Go to the documentation of this file.
1 #include <quda_internal.h>
2 #include <quda_matrix.h>
3 #include <tune_quda.h>
4 #include <gauge_field.h>
5 #include <gauge_field_order.h>
6 #include <launch_kernel.cuh>
7 #include <comm_quda.h>
8 #include <pgauge_monte.h>
9 #include <gauge_tools.h>
10 #include <random_quda.h>
11 #include <index_helper.cuh>
12 #include <atomic.cuh>
13 #include <instantiate.h>
14 
// Fallback definition of pi, used only when no PI macro is already visible.
// The literal has no suffix, so it is a double-precision constant.
15 #ifndef PI
16 #define PI 3.1415926535897932384626433832795 // pi
17 #endif
// Fallback definition of 2*pi. Also a double-precision literal: an expression
// such as PII * x with a float x is therefore evaluated in double.
18 #ifndef PII
19 #define PII 6.2831853071795864769252867665590 // 2 * pi
20 #endif
21 
22 namespace quda {
23 
/**
   @brief Map an SU(2) subgroup number onto the pair of colour indices it acts on.

   Pairs are enumerated one off-diagonal at a time: first every (i, i+1),
   then every (i, i+2), and so on, which yields NCOLORS * (NCOLORS - 1) / 2
   pairs in total.

   @param block Subgroup number, expected in [0, NCOLORS * (NCOLORS - 1) / 2)
   @return Index pair, .x holding the smaller and .y the larger colour index
*/
template <int NCOLORS>
__host__ __device__ static inline int2 IndexBlock(int block)
{
  int2 id;
  int counter = 0;
  for (int offset = 1; offset < NCOLORS; ++offset) {
    for (int row = 0; row < NCOLORS - offset; ++row, ++counter) {
      if (counter == block) {
        id.x = row;
        id.y = row + offset;
        return id;
      }
    }
  }
  // No pair carries this number: report the position at which the exhaustive
  // search runs off the last off-diagonal.
  id.x = 1;
  id.y = NCOLORS;
  return id;
}
50 
/**
   @brief Map an SU(2) subgroup number onto the pair of colour indices it acts on.

   For NCOLORS == 3 a fixed table is used: 0 -> (0,1), 1 -> (1,2) and any
   other value -> (0,2). For every other NCOLORS >= 2 the pairs are enumerated
   one off-diagonal at a time ((i, i+1) first, then (i, i+2), ...). This now
   includes NCOLORS == 2, for which the outputs used to be left unset.
   For NCOLORS < 2 there is no pair and the outputs are not touched.

   @param block Subgroup number, expected in [0, NCOLORS * (NCOLORS - 1) / 2)
   @param p Receives the first (smaller) colour index
   @param q Receives the second (larger) colour index
*/
template <int NCOLORS>
__host__ __device__ static inline void IndexBlock(int block, int &p, int &q)
{
  if (NCOLORS == 3) {
    if (block == 0) {
      p = 0;
      q = 1;
    } else if (block == 1) {
      p = 1;
      q = 2;
    } else {
      p = 0;
      q = 2;
    }
  } else if (NCOLORS >= 2) {
    int i1 = 0;
    int del_i = 0; // distance from the diagonal currently being scanned
    int index = -1;
    bool found = false;
    while (del_i < (NCOLORS - 1) && !found) {
      ++del_i;
      for (i1 = 0; i1 < (NCOLORS - del_i); ++i1) {
        if (++index == block) {
          found = true;
          break;
        }
      }
    }
    // when nothing matched, i1 and del_i hold the end-of-search position
    p = i1;
    q = i1 + del_i;
  }
}
82 
/**
   @brief Draw one candidate for d = 1 - a0 and run the accept/reject test on it.

   Consumes four random numbers. A small offset is added to the first two
   before taking the logarithm so that log(0) cannot occur.

   @param al weight
   @param localState RNG state, advanced by four draws
   @param d Receives the candidate value
   @return true when the candidate passes the test 1 - d/2 > u^2
*/
template <class T>
__device__ static inline bool su2_trial_milc(T al, cuRNGState &localState, T &d)
{
  T lnA = Random<T>(localState);
  lnA = log(lnA + 1.e-10);
  T lnB = Random<T>(localState);
  lnB = log(lnB + 1.e-10);
  T c = Random<T>(localState);
  const T u = Random<T>(localState);
  c = cos(PII * c);
  d = -(lnB + lnA * c * c) / al;
  return (1.00 - 0.5 * d) > u * u;
}

/**
   @brief Generate an SU(2) matrix, stored as four real numbers rather than a
   2x2 complex matrix, for the heatbath update. Adapted from the MILC code.

   One trial is always made. If it is rejected, up to 20 further trials
   follow for al > 2, and up to 20 iterations of the Creutz variant for
   al <= 2. When no trial is accepted within that limit, the last candidate
   is used as is.

   @param al weight
   @param localState RNG state
   @return The four real parameters, with a(0,0) = 1 - d and the remaining
   three forming a randomly oriented vector of length sqrt(|1 - a(0,0)^2|)
*/
template <class T>
__device__ static inline Matrix<T,2> generate_su2_matrix_milc(T al, cuRNGState& localState)
{
  T d;
  if (!su2_trial_milc<T>(al, localState, d)) {
    if (al > 2.0) {
      // retry the same trial
      for (int k = 0; k < 20; k++) {
        if (su2_trial_milc<T>(al, localState, d)) break;
      }
    } else if (al <= 2.0) {
      // Creutz algorithm
      const T lower = exp(-2.0 * al);
      const T span = 1.0 - lower;
      T a0;
      for (int k = 0; k < 20; k++) {
        const T u1 = Random<T>(localState);
        const T u2 = Random<T>(localState);
        const T arg = lower + span * u1;
        a0 = 1.00 + log(arg) / al;
        if ((1.0 - a0 * a0) > u2 * u2) break;
      }
      d = 1.0 - a0;
    }
  }

  Matrix<T,2> a;
  // a0 = 1 - d
  a(0,0) = 1.0 - d;

  // squared length of the remaining three components; abs() guards the sqrt
  T rSq = 1.0 - a(0,0) * a(0,0);
  rSq = abs(rSq);
  const T r = sqrt(rSq);

  // a3 lies in [-r, r]
  a(1,1) = (2.0 * Random<T>(localState) - 1.0) * r;

  // a1 and a2 share what is left, at a random angle in [0, 2*pi]
  T rho = rSq - a(1,1) * a(1,1);
  rho = abs(rho);
  rho = sqrt(rho);
  const T phi = PII * Random<T>(localState);
  a(0,1) = rho * cos(phi);
  a(1,0) = rho * sin(phi);
  return a;
}
151 
/**
   @brief Extract one SU(2) subgroup (four real numbers) from a 3x3 matrix.

   Block 0 uses rows/columns (0,1), block 1 uses (1,2) and block 2 uses (0,2).
   For any other block number the default-constructed result is returned.

   @param tmp1 Input 3x3 matrix
   @param block Subgroup to retrieve, 0 to 2
   @return The four real numbers
*/
template <class T>
__host__ __device__ static inline Matrix<T,2> get_block_su2( Matrix<complex<T>,3> tmp1, int block )
{
  Matrix<T,2> r;
  int p, q;
  if (block == 0) {
    p = 0;
    q = 1;
  } else if (block == 1) {
    p = 1;
    q = 2;
  } else if (block == 2) {
    p = 0;
    q = 2;
  } else {
    return r;
  }
  r(0,0) = tmp1(p,p).x + tmp1(q,q).x;
  r(0,1) = tmp1(p,q).y + tmp1(q,p).y;
  r(1,0) = tmp1(p,q).x - tmp1(q,p).x;
  r(1,1) = tmp1(p,p).y - tmp1(q,q).y;
  return r;
}
183 
/**
   @brief Extract one SU(2) subgroup (four real numbers) from an SU(Nc) matrix.
   @param tmp1 Input SU(Nc) matrix
   @param id The two row/column indices (.x and .y) selecting the 2x2 block
   @return The four real numbers built from the .x/.y components of the block
*/
template <class T, int NCOLORS>
__host__ __device__ static inline Matrix<T,2> get_block_su2( Matrix<complex<T>,NCOLORS> tmp1, int2 id )
{
  const int p = id.x;
  const int q = id.y;
  const complex<T> m_pp = tmp1(p,p);
  const complex<T> m_pq = tmp1(p,q);
  const complex<T> m_qp = tmp1(q,p);
  const complex<T> m_qq = tmp1(q,q);

  Matrix<T,2> r;
  r(0,0) = m_pp.x + m_qq.x;
  r(0,1) = m_pq.y + m_qp.y;
  r(1,0) = m_pq.x - m_qp.x;
  r(1,1) = m_pp.y - m_qq.y;
  return r;
}
199 
/**
   @brief Embed an SU(2) matrix into an SU(Nc) identity matrix.
   @param rr SU(2) matrix represented by four real numbers
   @param id The two row/column indices (.x and .y) that receive the 2x2 block
   @return SU(Nc) matrix equal to the identity outside the selected block
*/
template <class T, int NCOLORS>
__host__ __device__ static inline Matrix<complex<T>,NCOLORS> block_su2_to_sun( Matrix<T,2> rr, int2 id )
{
  const int p = id.x;
  const int q = id.y;
  const T a0 = rr(0,0);
  const T a1 = rr(0,1);
  const T a2 = rr(1,0);
  const T a3 = rr(1,1);

  Matrix<complex<T>,NCOLORS> out;
  setIdentity(&out);
  out(p,p) = complex<T>( a0, a3);
  out(p,q) = complex<T>( a2, a1);
  out(q,p) = complex<T>(-a2, a1);
  out(q,q) = complex<T>( a0, -a3);
  return out;
}
216 
/**
   @brief Left-multiply an SU(Nc) link by an SU(2) matrix, link <- u * link.

   Only the two rows selected by id change; every column is updated.

   @param u SU(2) matrix represented by four real numbers
   @param link SU(Nc) matrix, modified in place
   @param id The two row indices (.x and .y) on which u acts
*/
template <class T, int NCOLORS>
__host__ __device__ static inline void mul_block_sun( Matrix<T,2> u, Matrix<complex<T>,NCOLORS> &link, int2 id )
{
  // complex 2x2 form of u
  const complex<T> c00( u(0,0), u(1,1));
  const complex<T> c01( u(1,0), u(0,1));
  const complex<T> c10(-u(1,0), u(0,1));
  const complex<T> c11( u(0,0), -u(1,1));

  for (int col = 0; col < NCOLORS; ++col) {
    // read both old entries before either one is overwritten
    const complex<T> top = link(id.x, col);
    const complex<T> bottom = link(id.y, col);
    link(id.y, col) = c10 * top + c11 * bottom;
    link(id.x, col) = c00 * top + c01 * bottom;
  }
}
231 
/**
   @brief Left-multiply a 3x3 link by a 2x2 complex matrix acting on two of
   its rows, U <- a * U.

   Block 0 acts on rows (0,1), block 1 on rows (1,2) and block 2 on rows
   (0,2). Any other block number leaves U unchanged.

   @param U 3x3 matrix, modified in place
   @param a00 Element (0,0) of the 2x2 matrix
   @param a01 Element (0,1) of the 2x2 matrix
   @param a10 Element (1,0) of the 2x2 matrix
   @param a11 Element (1,1) of the 2x2 matrix
   @param block Subgroup selector, 0, 1 or 2
*/
template <class Cmplx>
__host__ __device__ static inline void block_su2_to_su3( Matrix<Cmplx,3> &U, Cmplx a00, Cmplx a01, Cmplx a10, Cmplx a11, int block )
{
  int p, q;
  if (block == 0) {
    p = 0;
    q = 1;
  } else if (block == 1) {
    p = 1;
    q = 2;
  } else if (block == 2) {
    p = 0;
    q = 2;
  } else {
    return;
  }

  for (int col = 0; col < 3; ++col) {
    // read both old entries before either one is overwritten
    const Cmplx top = U(p, col);
    const Cmplx bottom = U(q, col);
    U(q, col) = a10 * top + a11 * bottom;
    U(p, col) = a00 * top + a01 * bottom;
  }
}
280 
/**
   @brief Product v * u^dagger of two SU(2) matrices, each stored as four
   real numbers.
   @param v Left factor
   @param u Right factor, used conjugate-transposed
   @return The product in the same four-real representation
*/
template <class Float>
__host__ __device__ static inline Matrix<Float,2> mulsu2UVDagger(Matrix<Float,2> v, Matrix<Float,2> u)
{
  const Float v0 = v(0,0), v1 = v(0,1), v2 = v(1,0), v3 = v(1,1);
  const Float u0 = u(0,0), u1 = u(0,1), u2 = u(1,0), u3 = u(1,1);

  Matrix<Float,2> b;
  b(0,0) = v0 * u0 + v1 * u1 + v2 * u2 + v3 * u3;
  b(0,1) = v1 * u0 - v0 * u1 + v2 * u3 - v3 * u2;
  b(1,0) = v2 * u0 - v0 * u2 + v3 * u1 - v1 * u3;
  b(1,1) = v3 * u0 - v0 * u3 + v1 * u2 - v2 * u1;
  return b;
}
291 
/**
   @brief Heatbath step for a single SU(2) subgroup.

   Normalises the four-real block r, draws a new SU(2) element with weight
   BetaOverNc * |r| and returns that element multiplied by the conjugate of
   the normalised block.

   @param r SU(2) block extracted from U * F (not normalised)
   @param BetaOverNc Coupling beta divided by the number of colours
   @param localState RNG state
   @return SU(2) rotation, as four real numbers, to left-multiply onto the link
*/
template <class Float>
__device__ static inline Matrix<Float,2> heatbath_su2_subgroup(Matrix<Float,2> r, Float BetaOverNc, cuRNGState &localState)
{
  Float k = sqrt(r(0,0) * r(0,0) + r(0,1) * r(0,1) + r(1,0) * r(1,0) + r(1,1) * r(1,1));
  const Float ap = BetaOverNc * k;
  k = 1.0 / k;
  r *= k;
  Matrix<Float,2> a = generate_su2_matrix_milc<Float>(ap, localState);
  return mulsu2UVDagger<Float>(a, r);
}

/**
   @brief Link update by pseudo-heatbath.

   Sweeps over the SU(2) subgroups of SU(Nc), updating the link after each.
   For NCOLORS == 3 only the four entries of U * F that a subgroup needs are
   rebuilt from the current link; for NCOLORS > 3 the full product is formed
   once and rotated together with the link.

   @param U Link to be updated, modified in place
   @param F Staple
   @param localState RNG state
   @param BetaOverNc Coupling beta divided by the number of colours
*/
template <class Float, int NCOLORS>
__device__ inline void heatBathSUN( Matrix<complex<Float>,NCOLORS>& U, Matrix<complex<Float>,NCOLORS> F,
                                    cuRNGState& localState, Float BetaOverNc )
{
  if ( NCOLORS == 3 ) {
    // Computing the full U * F for every block and applying the rotation via
    // block_su2_to_su3() is the alternative formulation of this branch.
    // FLOP_min = (NCOLORS * 64 + 19 + 28 + 28) * 3 = NCOLORS * 192 + 225
    for ( int block = 0; block < NCOLORS; block++ ) {
      int p, q;
      IndexBlock<NCOLORS>(block, p, q);

      // the four entries of U * F touched by this subgroup
      const complex<Float> zero((Float)0.0, (Float)0.0);
      complex<Float> w_pp = zero;
      complex<Float> w_pq = zero;
      complex<Float> w_qp = zero;
      complex<Float> w_qq = zero;
      for ( int j = 0; j < NCOLORS; j++ ) {
        w_pp += U(p,j) * F(j,p);
        w_pq += U(p,j) * F(j,q);
        w_qp += U(q,j) * F(j,p);
        w_qq += U(q,j) * F(j,q);
      }

      Matrix<Float,2> r;
      r(0,0) = w_pp.x + w_qq.x;
      r(0,1) = w_pq.y + w_qp.y;
      r(1,0) = w_pq.x - w_qp.x;
      r(1,1) = w_pp.y - w_qq.y;
      r = heatbath_su2_subgroup<Float>(r, BetaOverNc, localState);

      // complex 2x2 form of the rotation, applied to rows p and q of U
      const complex<Float> c00( r(0,0), r(1,1));
      const complex<Float> c01( r(1,0), r(0,1));
      const complex<Float> c10(-r(1,0), r(0,1));
      const complex<Float> c11( r(0,0), -r(1,1));
      for ( int j = 0; j < NCOLORS; j++ ) {
        const complex<Float> top = U(p,j);
        const complex<Float> bottom = U(q,j);
        U(q,j) = c10 * top + c11 * bottom;
        U(p,j) = c00 * top + c01 * bottom;
      }
    }
  } else if ( NCOLORS > 3 ) {
    // Original author's note: tested in SU(4) single precision, this variant
    // was the slower one; rebuilding only the needed entries of U * F per
    // block (as in the NCOLORS == 3 branch) was the faster one.
    Matrix<complex<Float>,NCOLORS> M = U * F;
    for ( int block = 0; block < NCOLORS * ( NCOLORS - 1) / 2; block++ ) {
      int2 id = IndexBlock<NCOLORS>( block );
      Matrix<Float,2> r = get_block_su2<Float>(M, id);
      Matrix<Float,2> rr = heatbath_su2_subgroup<Float>(r, BetaOverNc, localState);
      // rotate the link and keep the cached product in step with it
      mul_block_sun<Float, NCOLORS>( rr, U, id);
      mul_block_sun<Float, NCOLORS>( rr, M, id);
    }
  }
}
423 
/**
   @brief Normalise an SU(2) block stored as four real numbers and conjugate it.

   Every component is divided by the Euclidean norm of the four numbers and
   the last three change sign.

   @param r Block to transform, modified in place
*/
template <class Float>
__device__ static inline void conj_normalize_su2(Matrix<Float,2> &r)
{
  const Float norm = 1.0 / sqrt(r(0,0) * r(0,0) + r(0,1) * r(0,1) + r(1,0) * r(1,0) + r(1,1) * r(1,1));
  r(0,0) *= norm;
  r(0,1) *= -norm;
  r(1,0) *= -norm;
  r(1,1) *= -norm;
}

/**
   @brief Link update by overrelaxation.

   Sweeps over the SU(2) subgroups of SU(Nc). For each one, the normalised,
   conjugated block of U * F is applied to the link twice.

   @param U Link to be updated, modified in place
   @param F Staple
*/
template <class Float, int NCOLORS>
__device__ inline void overrelaxationSUN( Matrix<complex<Float>,NCOLORS>& U, Matrix<complex<Float>,NCOLORS> F )
{
  if ( NCOLORS == 3 ) {
    // This version does not form the full product U * F for every block;
    // only the four entries a subgroup needs are rebuilt.
    // FLOP = (NCOLORS * 88 + 17) * 3
    for ( int block = 0; block < 3; block++ ) {
      int p, q;
      IndexBlock<NCOLORS>(block, p, q);

      const complex<Float> zero((Float)0., (Float)0.);
      complex<Float> w_pp = zero;
      complex<Float> w_pq = zero;
      complex<Float> w_qp = zero;
      complex<Float> w_qq = zero;
      for ( int j = 0; j < NCOLORS; j++ ) {
        w_pp += U(p,j) * F(j,p);
        w_pq += U(p,j) * F(j,q);
        w_qp += U(q,j) * F(j,p);
        w_qq += U(q,j) * F(j,q);
      }

      Matrix<Float,2> r;
      r(0,0) = w_pp.x + w_qq.x;
      r(0,1) = w_pq.y + w_qp.y;
      r(1,0) = w_pq.x - w_qp.x;
      r(1,1) = w_pp.y - w_qq.y;
      conj_normalize_su2<Float>(r);

      // complex 2x2 form of the rotation
      const complex<Float> c00( r(0,0), r(1,1));
      const complex<Float> c01( r(1,0), r(0,1));
      const complex<Float> c10(-r(1,0), r(0,1));
      const complex<Float> c11( r(0,0), -r(1,1));

      // apply the rotation twice to rows p and q
      for ( int j = 0; j < NCOLORS; j++ ) {
        const complex<Float> once_p = c00 * U(p,j) + c01 * U(q,j);
        const complex<Float> once_q = c10 * U(p,j) + c11 * U(q,j);
        U(p,j) = c00 * once_p + c01 * once_q;
        U(q,j) = c10 * once_p + c11 * once_q;
      }
    }
  } else if ( NCOLORS > 3 ) {
    // Original author's note: tested in SU(4) single precision, rebuilding
    // the needed entries of U * F per block was the slower alternative here.
    Matrix<complex<Float>,NCOLORS> M = U * F;
    for ( int block = 0; block < NCOLORS * ( NCOLORS - 1) / 2; block++ ) {
      int2 id = IndexBlock<NCOLORS>( block );
      Matrix<Float,2> r = get_block_su2<Float, NCOLORS>(M, id);
      conj_normalize_su2<Float>(r);
      // rotate the link twice, then bring the cached product up to date
      mul_block_sun<Float, NCOLORS>( r, U, id);
      mul_block_sun<Float, NCOLORS>( r, U, id);
      mul_block_sun<Float, NCOLORS>( r, M, id);
      mul_block_sun<Float, NCOLORS>( r, M, id);
    }
  }
}
566 
/**
   @brief Kernel argument bundle for the heatbath / overrelaxation update.
*/
template <typename Gauge, typename Float, int NCOLORS>
struct MonteArg {
  int threads;      // number of active threads required (half the border-free volume)
  int X[4];         // lattice extents with the border removed on both sides
  int border[4];    // border width per direction, taken from data.R()
  Gauge dataOr;     // accessor used to read and write links
  GaugeField &data; // the gauge field itself
  Float BetaOverNc; // Beta divided by NCOLORS
  RNG rngstate;     // random number generator handle

  MonteArg(const Gauge &dataOr, GaugeField & data, Float Beta, RNG &rngstate) :
    dataOr(dataOr), data(data), BetaOverNc(Beta / (Float)NCOLORS), rngstate(rngstate)
  {
    for ( int d = 0; d < 4; ++d ) {
      border[d] = data.R()[d];
      X[d] = data.X()[d] - 2 * border[d];
    }
    threads = (X[0] * X[1] * X[2] * X[3]) >> 1;
  }
};
586 
/**
   @brief Update the mu-direction links on one parity by heatbath or overrelaxation.

   Uses a one-dimensional thread index; threads whose index is not below
   arg.threads return immediately. Each remaining thread sums the two staples
   from every direction nu != mu around its link, updates the link and writes
   it back. In heatbath mode the thread also loads and stores the RNG state
   at its own (border-free) thread index.

   @tparam HeatbathOrRelax true selects heatbath, false overrelaxation
   @param arg Kernel arguments
   @param mu Direction of the links to update
   @param parity Parity of the sites to update
*/
template<typename Float, typename Gauge, int NCOLORS, bool HeatbathOrRelax>
__global__ void compute_heatBath(MonteArg<Gauge, Float, NCOLORS> arg, int mu, int parity)
{
  using Link = Matrix<complex<Float>,NCOLORS>;

  int tid = threadIdx.x + blockIdx.x * blockDim.x;
  if ( tid >= arg.threads ) return;

  int X[4];
#pragma unroll
  for ( int d = 0; d < 4; ++d ) X[d] = arg.X[d];

  // coordinates inside the border-free volume ...
  int x[4];
  getCoords(x, tid, X, parity);
  // ... shifted into the full volume that includes the border
#pragma unroll
  for ( int d = 0; d < 4; ++d ) {
    x[d] += arg.border[d];
    X[d] += 2 * arg.border[d];
  }
  int site = linkIndex(x,X);
  int otherParity = 1 - parity;

  Link staple;
  setZero(&staple);

  Link U;
  for ( int nu = 0; nu < 4; nu++ ) {
    if ( nu == mu ) continue;
    int dx[4] = { 0, 0, 0, 0 };

    // forward staple: U_nu(x) U_mu(x+nu) U_nu(x+mu)^dagger
    Link link = arg.dataOr(nu, site, parity);
    dx[nu] = 1;
    U = arg.dataOr(mu, linkIndexShift(x,dx,X), otherParity);
    link *= U;
    dx[nu] = 0;
    dx[mu] = 1;
    U = arg.dataOr(nu, linkIndexShift(x,dx,X), otherParity);
    link *= conj(U);
    staple += link;

    // backward staple: U_nu(x-nu)^dagger U_mu(x-nu) U_nu(x-nu+mu)
    dx[mu] = 0;
    dx[nu] = -1;
    link = arg.dataOr(nu, linkIndexShift(x,dx,X), otherParity);
    U = arg.dataOr(mu, linkIndexShift(x,dx,X), otherParity);
    link = conj(link) * U;
    dx[mu] = 1;
    U = arg.dataOr(nu, linkIndexShift(x,dx,X), parity);
    link *= U;
    staple += link;
  }

  U = arg.dataOr(mu, site, parity);
  if ( HeatbathOrRelax ) {
    cuRNGState localState = arg.rngstate.State()[ tid ];
    heatBathSUN<Float, NCOLORS>( U, conj(staple), localState, arg.BetaOverNc );
    arg.rngstate.State()[ tid ] = localState;
  } else {
    overrelaxationSUN<Float, NCOLORS>( U, conj(staple) );
  }
  arg.dataOr(mu, site, parity) = U;
}
641 
642 
/**
   @brief Autotuned launcher for compute_heatBath.

   One instance serves all directions and parities; call SetParam() before
   each apply().

   @tparam NElems Number of real numbers per link, used in the SU(3) byte count
   @tparam HeatbathOrRelax true selects heatbath, false overrelaxation
*/
template<typename Float, typename Gauge, int NCOLORS, int NElems, bool HeatbathOrRelax>
class GaugeHB : Tunable {
  MonteArg<Gauge, Float, NCOLORS> arg;
  int mu;
  int parity;
  mutable char aux_string[128]; // used as a label in the autotuner

  // the kernel uses no shared memory
  unsigned int sharedBytesPerThread() const { return 0; }
  unsigned int sharedBytesPerBlock(const TuneParam &) const { return 0; }
  // Don't tune the grid dimensions.
  bool tuneGridDim() const { return false; }
  unsigned int minThreads() const { return arg.threads; }

public:
  GaugeHB(MonteArg<Gauge, Float, NCOLORS> &arg) : arg(arg), mu(0), parity(0) { }

  /** @brief Select the link direction and site parity for the next apply(). */
  void SetParam(int _mu, int _parity)
  {
    mu = _mu;
    parity = _parity;
  }

  /** @brief Launch the update kernel on the given stream with tuned launch parameters. */
  void apply(const qudaStream_t &stream)
  {
    TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
    qudaLaunchKernel(compute_heatBath<Float, Gauge, NCOLORS, HeatbathOrRelax>, tp, stream, arg, mu, parity);
  }

  /** @brief Key under which tuned launch parameters are stored: volume, type name, thread count and precision. */
  TuneKey tuneKey() const
  {
    std::stringstream vol;
    vol << arg.X[0] << "x";
    vol << arg.X[1] << "x";
    vol << arg.X[2] << "x";
    vol << arg.X[3];
    // bounded write; %zu is the portable conversion for the size_t from sizeof
    snprintf(aux_string, sizeof(aux_string), "threads=%d,prec=%zu", arg.threads, sizeof(Float));
    return TuneKey(vol.str().c_str(), typeid(*this).name(), aux_string);
  }

  // Tuning runs the kernel repeatedly, so the gauge field (and, for the
  // heatbath, the RNG state) is saved beforehand and restored afterwards.
  void preTune()
  {
    arg.data.backup();
    if (HeatbathOrRelax) arg.rngstate.backup();
  }

  void postTune()
  {
    arg.data.restore();
    if (HeatbathOrRelax) arg.rngstate.restore();
  }

  /** @brief Estimated floating point operations of one launch (flagged as unverified by the original author). */
  long long flops() const
  {
    //NEED TO CHECK THIS!!!!!!
    long long flop;
    if ( NCOLORS == 3 ) {
      flop = 2268LL;
      flop += HeatbathOrRelax ? 801LL : 843LL;
    } else {
      flop = NCOLORS * NCOLORS * NCOLORS * 84LL;
      if ( HeatbathOrRelax ) {
        flop += NCOLORS * NCOLORS * NCOLORS + (NCOLORS * ( NCOLORS - 1) / 2) * (46LL + 48LL + 56LL * NCOLORS);
      } else {
        flop += NCOLORS * NCOLORS * NCOLORS + (NCOLORS * ( NCOLORS - 1) / 2) * (17LL + 112LL * NCOLORS);
      }
    }
    flop *= arg.threads;
    return flop;
  }

  /** @brief Estimated memory traffic in bytes of one launch (flagged as unverified by the original author). */
  long long bytes() const
  {
    //NEED TO CHECK THIS!!!!!!
    long long byte;
    if ( NCOLORS == 3 ) {
      byte = 20LL * NElems * sizeof(Float);
    } else {
      byte = 20LL * NCOLORS * NCOLORS * 2 * sizeof(Float);
    }
    // heatbath additionally reads and writes one RNG state per thread
    if ( HeatbathOrRelax ) byte += 2LL * sizeof(cuRNGState);
    byte *= arg.threads;
    return byte;
  }
};
740 
/**
   @brief Runs nhb heatbath sweeps followed by nover overrelaxation sweeps.

   All the work happens in the constructor. A sweep visits both parities
   and, for each, all four link directions, exchanging the updated links
   after every kernel launch. At verbosity QUDA_VERBOSE or higher each phase
   is timed and its throughput printed.
*/
template <typename Float, int nColor, QudaReconstructType recon>
struct MonteAlg {
  MonteAlg(GaugeField& data, RNG &rngstate, Float Beta, int nhb, int nover)
  {
    TimeProfile profileHBOVR("HeatBath_OR_Relax", false);
    using Gauge = typename gauge_mapper<Float, recon>::type;

    // The timer is started and stopped under the same condition, so it can
    // never be left running.
    const bool timing = getVerbosity() >= QUDA_VERBOSE;

    MonteArg<Gauge, Float, nColor> montearg(Gauge(data), data, Beta, rngstate);

    if (timing) profileHBOVR.TPSTART(QUDA_PROFILE_COMPUTE);
    GaugeHB<Float, Gauge, nColor, recon, true> hb(montearg);
    sweep(hb, data, nhb);
    if (timing) {
      qudaDeviceSynchronize();
      profileHBOVR.TPSTOP(QUDA_PROFILE_COMPUTE);
      double secs = profileHBOVR.Last(QUDA_PROFILE_COMPUTE);
      // 8 kernel launches per step: 2 parities x 4 directions
      double gflops = (hb.flops() * 8 * nhb * 1e-9) / (secs);
      double gbytes = hb.bytes() * 8 * nhb / (secs * 1e9);
      printfQuda("HB: Time = %6.6f s, Gflop/s = %6.1f, GB/s = %6.1f\n", secs, gflops * comm_size(), gbytes * comm_size());
    }

    if (timing) profileHBOVR.TPSTART(QUDA_PROFILE_COMPUTE);
    GaugeHB<Float, Gauge, nColor, recon, false> relax(montearg);
    sweep(relax, data, nover);
    if (timing) {
      qudaDeviceSynchronize();
      profileHBOVR.TPSTOP(QUDA_PROFILE_COMPUTE);
      double secs = profileHBOVR.Last(QUDA_PROFILE_COMPUTE);
      double gflops = (relax.flops() * 8 * nover * 1e-9) / (secs);
      double gbytes = relax.bytes() * 8 * nover / (secs * 1e9);
      printfQuda("OVR: Time = %6.6f s, Gflop/s = %6.1f, GB/s = %6.1f\n", secs, gflops * comm_size(), gbytes * comm_size());
    }
  }

private:
  // Run nsteps full sweeps: for each parity and each direction, launch the
  // update and then exchange the freshly updated links.
  template <typename Updater>
  static void sweep(Updater &updater, GaugeField &data, int nsteps)
  {
    for ( int step = 0; step < nsteps; ++step ) {
      for ( int parity = 0; parity < 2; ++parity ) {
        for ( int mu = 0; mu < 4; ++mu ) {
          updater.SetParam(mu, parity);
          updater.apply(0);
          PGaugeExchange(data, mu, parity);
        }
      }
    }
  }
};
790 
/** @brief Perform heatbath and overrelaxation. Performs nhb heatbath steps followed by nover overrelaxation steps.
 *
 * Beta is handed on unchanged and converted to the field's precision by the
 * MonteAlg constructor, so a double-precision field is no longer updated
 * with a coupling that was first rounded to float.
 *
 * @param[in,out] data Gauge field
 * @param[in,out] rngstate state of the CURAND random number generator
 * @param[in] Beta inverse of the gauge coupling, beta = 2 Nc / g_0^2
 * @param[in] nhb number of heatbath steps
 * @param[in] nover number of overrelaxation steps
 */
void Monte(GaugeField& data, RNG &rngstate, double Beta, int nhb, int nover) {
#ifdef GPU_GAUGE_ALG
  // NOTE(review): assumes instantiate<> forwards its trailing arguments to
  // the MonteAlg constructor by ordinary call, so the double converts
  // implicitly to Float - confirm against instantiate.h.
  instantiate<MonteAlg>(data, rngstate, Beta, nhb, nover);
#else
  errorQuda("Pure gauge code has not been built");
#endif // GPU_GAUGE_ALG
}
806 
807 }