QUDA  0.9.0
pgauge_heatbath.cu
Go to the documentation of this file.
1 #include <quda_internal.h>
2 #include <quda_matrix.h>
3 #include <tune_quda.h>
4 #include <gauge_field.h>
5 #include <gauge_field_order.h>
6 #include <launch_kernel.cuh>
7 #include <comm_quda.h>
8 #include <pgauge_monte.h>
9 #include <gauge_tools.h>
10 #include <random_quda.h>
11 #include <index_helper.cuh>
12 #include <atomic.cuh>
13 #include <cub/cub.cuh>
14 
15 
16 
17 #ifndef PI
18 #define PI 3.1415926535897932384626433832795 // pi
19 #endif
20 #ifndef PII
21 #define PII 6.2831853071795864769252867665590 // 2 * pi
22 #endif
23 
24 namespace quda {
25 
26 #ifdef GPU_GAUGE_ALG
27 
28 
34  template<int NCOLORS>
35  __host__ __device__ static inline int2 IndexBlock(int block){
36  int2 id;
37  int i1;
38  int found = 0;
39  int del_i = 0;
40  int index = -1;
41  while ( del_i < (NCOLORS - 1) && found == 0 ) {
42  del_i++;
43  for ( i1 = 0; i1 < (NCOLORS - del_i); i1++ ) {
44  index++;
45  if ( index == block ) {
46  found = 1;
47  break;
48  }
49  }
50  }
51  id.y = i1 + del_i;
52  id.x = i1;
53  return id;
54  }
61  template<int NCOLORS>
62  __host__ __device__ static inline void IndexBlock(int block, int &p, int &q){
63  if ( NCOLORS == 3 ) {
64  if ( block == 0 ) { p = 0; q = 1; }
65  else if ( block == 1 ) { p = 1; q = 2; }
66  else{ p = 0; q = 2; }
67  }
68  else if ( NCOLORS > 3 ) {
69  int i1;
70  int found = 0;
71  int del_i = 0;
72  int index = -1;
73  while ( del_i < (NCOLORS - 1) && found == 0 ) {
74  del_i++;
75  for ( i1 = 0; i1 < (NCOLORS - del_i); i1++ ) {
76  index++;
77  if ( index == block ) {
78  found = 1;
79  break;
80  }
81  }
82  }
83  q = i1 + del_i;
84  p = i1;
85  }
86  }
87 
  /**
     @brief Generate an SU(2) matrix for the heatbath update, following the
     MILC implementation: a Kennedy-Pendleton style accept/reject for large
     al (al > 2) and a Creutz style accept/reject otherwise, each capped at
     20 attempts.
     @param al heatbath parameter (here BetaOverNc times the staple norm)
     @param localState CURAND state for this thread, advanced in place
     @return SU(2) matrix in the real 4-parameter representation
     (a(0,0), a(0,1), a(1,0), a(1,1)); presumably (a0, a1, a2, a3) of
     a0 + i*(a1 sigma_x + a2 sigma_y + a3 sigma_z) — TODO confirm against
     block_su2_to_sun's embedding.
   */
  template <class T>
  __device__ static inline Matrix<T,2> generate_su2_matrix_milc(T al, cuRNGState& localState){
    T xr1, xr2, xr3, xr4, d, r;
    int k;
    // first trial: draw four uniforms; small increment prevents log(0.)
    xr1 = Random<T>(localState);
    xr1 = (log((xr1 + 1.e-10)));
    xr2 = Random<T>(localState);
    xr2 = (log((xr2 + 1.e-10)));
    xr3 = Random<T>(localState);
    xr4 = Random<T>(localState);
    xr3 = cos(PII * xr3);
    d = -(xr2 + xr1 * xr3 * xr3 ) / al;
    //now beat each site into submission
    int nacd = 0;
    if ((1.00 - 0.5 * d) > xr4 * xr4 ) nacd = 1;  // accepted on first try
    if ( nacd == 0 && al > 2.0 ) { //k-p algorithm
      for ( k = 0; k < 20; k++ ) {
        //get four random numbers (add a small increment to prevent taking log(0.)
        xr1 = Random<T>(localState);
        xr1 = (log((xr1 + 1.e-10)));
        xr2 = Random<T>(localState);
        xr2 = (log((xr2 + 1.e-10)));
        xr3 = Random<T>(localState);
        xr4 = Random<T>(localState);
        xr3 = cos(PII * xr3);
        d = -(xr2 + xr1 * xr3 * xr3) / al;
        if ((1.00 - 0.5 * d) > xr4 * xr4 ) break;
      }
    } //endif nacd
    Matrix<T,2> a;
    if ( nacd == 0 && al <= 2.0 ) { //creutz algorithm
      xr3 = exp(-2.0 * al);
      xr4 = 1.0 - xr3;
      for ( k = 0; k < 20; k++ ) {
        //get two random numbers
        xr1 = Random<T>(localState);
        xr2 = Random<T>(localState);
        r = xr3 + xr4 * xr1;
        a(0,0) = 1.00 + log(r) / al;
        if ((1.0 - a(0,0) * a(0,0)) > xr2 * xr2 ) break;
      }
      d = 1.0 - a(0,0);
    } //endif nacd
    //generate the four su(2) elements
    //find a0 = 1 - d
    a(0,0) = 1.0 - d;
    //compute r, the radius of the remaining 3-sphere components
    xr3 = 1.0 - a(0,0) * a(0,0);
    xr3 = abs(xr3);  // guard against tiny negative values from rounding
    r = sqrt(xr3);
    //compute a3
    a(1,1) = (2.0 * Random<T>(localState) - 1.0) * r;
    //compute a1 and a2
    xr1 = xr3 - a(1,1) * a(1,1);
    xr1 = abs(xr1);
    xr1 = sqrt(xr1);
    //xr2 is a random number between 0 and 2*pi
    xr2 = PII * Random<T>(localState);
    a(0,1) = xr1 * cos(xr2);
    a(1,0) = xr1 * sin(xr2);
    return a;
  }
156 
157 
164  template < class T>
165  __host__ __device__ static inline Matrix<T,2> get_block_su2( Matrix<complex<T>,3> tmp1, int block ){
166  Matrix<T,2> r;
167  switch ( block ) {
168  case 0:
169  r(0,0) = tmp1(0,0).x + tmp1(1,1).x;
170  r(0,1) = tmp1(0,1).y + tmp1(1,0).y;
171  r(1,0) = tmp1(0,1).x - tmp1(1,0).x;
172  r(1,1) = tmp1(0,0).y - tmp1(1,1).y;
173  break;
174  case 1:
175  r(0,0) = tmp1(1,1).x + tmp1(2,2).x;
176  r(0,1) = tmp1(1,2).y + tmp1(2,1).y;
177  r(1,0) = tmp1(1,2).x - tmp1(2,1).x;
178  r(1,1) = tmp1(1,1).y - tmp1(2,2).y;
179  break;
180  case 2:
181  r(0,0) = tmp1(0,0).x + tmp1(2,2).x;
182  r(0,1) = tmp1(0,2).y + tmp1(2,0).y;
183  r(1,0) = tmp1(0,2).x - tmp1(2,0).x;
184  r(1,1) = tmp1(0,0).y - tmp1(2,2).y;
185  break;
186  }
187  return r;
188  }
189 
196  template <class T, int NCOLORS>
197  __host__ __device__ static inline Matrix<T,2> get_block_su2( Matrix<complex<T>,NCOLORS> tmp1, int2 id ){
198  Matrix<T,2> r;
199  r(0,0) = tmp1(id.x,id.x).x + tmp1(id.y,id.y).x;
200  r(0,1) = tmp1(id.x,id.y).y + tmp1(id.y,id.x).y;
201  r(1,0) = tmp1(id.x,id.y).x - tmp1(id.y,id.x).x;
202  r(1,1) = tmp1(id.x,id.x).y - tmp1(id.y,id.y).y;
203  return r;
204  }
205 
212  template <class T, int NCOLORS>
213  __host__ __device__ static inline Matrix<complex<T>,NCOLORS> block_su2_to_sun( Matrix<T,2> rr, int2 id ){
214  Matrix<complex<T>,NCOLORS> tmp1;
215  setIdentity(&tmp1);
216  tmp1(id.x,id.x) = complex<T>( rr(0,0), rr(1,1) );
217  tmp1(id.x,id.y) = complex<T>( rr(1,0), rr(0,1) );
218  tmp1(id.y,id.x) = complex<T>(-rr(1,0), rr(0,1) );
219  tmp1(id.y,id.y) = complex<T>( rr(0,0),-rr(1,1) );
220  return tmp1;
221  }
228  template <class T, int NCOLORS>
229  __host__ __device__ static inline void mul_block_sun( Matrix<T,2> u, Matrix<complex<T>,NCOLORS> &link, int2 id ){
230  for ( int j = 0; j < NCOLORS; j++ ) {
231  complex<T> tmp = complex<T>( u(0,0), u(1,1) ) * link(id.x, j) + complex<T>( u(1,0), u(0,1) ) * link(id.y, j);
232  link(id.y, j) = complex<T>(-u(1,0), u(0,1) ) * link(id.x, j) + complex<T>( u(0,0),-u(1,1) ) * link(id.y, j);
233  link(id.x, j) = tmp;
234  }
235  }
236 
246  template <class Cmplx>
247  __host__ __device__ static inline void block_su2_to_su3( Matrix<Cmplx,3> &U, Cmplx a00, Cmplx a01, Cmplx a10, Cmplx a11, int block ){
248  Cmplx tmp;
249  switch ( block ) {
250  case 0:
251  tmp = a00 * U(0,0) + a01 * U(1,0);
252  U(1,0) = a10 * U(0,0) + a11 * U(1,0);
253  U(0,0) = tmp;
254  tmp = a00 * U(0,1) + a01 * U(1,1);
255  U(1,1) = a10 * U(0,1) + a11 * U(1,1);
256  U(0,1) = tmp;
257  tmp = a00 * U(0,2) + a01 * U(1,2);
258  U(1,2) = a10 * U(0,2) + a11 * U(1,2);
259  U(0,2) = tmp;
260  break;
261  case 1:
262  tmp = a00 * U(1,0) + a01 * U(2,0);
263  U(2,0) = a10 * U(1,0) + a11 * U(2,0);
264  U(1,0) = tmp;
265  tmp = a00 * U(1,1) + a01 * U(2,1);
266  U(2,1) = a10 * U(1,1) + a11 * U(2,1);
267  U(1,1) = tmp;
268  tmp = a00 * U(1,2) + a01 * U(2,2);
269  U(2,2) = a10 * U(1,2) + a11 * U(2,2);
270  U(1,2) = tmp;
271  break;
272  case 2:
273  tmp = a00 * U(0,0) + a01 * U(2,0);
274  U(2,0) = a10 * U(0,0) + a11 * U(2,0);
275  U(0,0) = tmp;
276  tmp = a00 * U(0,1) + a01 * U(2,1);
277  U(2,1) = a10 * U(0,1) + a11 * U(2,1);
278  U(0,1) = tmp;
279  tmp = a00 * U(0,2) + a01 * U(2,2);
280  U(2,2) = a10 * U(0,2) + a11 * U(2,2);
281  U(0,2) = tmp;
282  break;
283  }
284  }
285 
286 
287 
288 // v * u^dagger
289  template <class Float>
290  __host__ __device__ static inline Matrix<Float,2> mulsu2UVDagger(Matrix<Float,2> v, Matrix<Float,2> u){
292  b(0,0) = v(0,0) * u(0,0) + v(0,1) * u(0,1) + v(1,0) * u(1,0) + v(1,1) * u(1,1);
293  b(0,1) = v(0,1) * u(0,0) - v(0,0) * u(0,1) + v(1,0) * u(1,1) - v(1,1) * u(1,0);
294  b(1,0) = v(1,0) * u(0,0) - v(0,0) * u(1,0) + v(1,1) * u(0,1) - v(0,1) * u(1,1);
295  b(1,1) = v(1,1) * u(0,0) - v(0,0) * u(1,1) + v(0,1) * u(1,0) - v(1,0) * u(0,1);
296  return b;
297  }
298 
  /**
     @brief Pseudo-heatbath update of one link over its SU(2) subgroups
     (Cabibbo-Marinari style): for each subgroup, project U*F onto SU(2),
     draw a heatbath SU(2) element, and apply it to U from the left.
     @param[in,out] U link matrix, updated in place
     @param F staple for this link (passed by value)
     @param[in,out] localState CURAND state for this thread, advanced in place
     @param BetaOverNc beta / NCOLORS
   */
  template <class Float, int NCOLORS>
  __device__ inline void heatBathSUN( Matrix<complex<Float>,NCOLORS>& U, Matrix<complex<Float>,NCOLORS> F,
                                      cuRNGState& localState, Float BetaOverNc ){

    if ( NCOLORS == 3 ) {
      /* alternative kept for reference: recompute tmp1 = U * F per block
      for( int block = 0; block < NCOLORS; block++ ) {
      Matrix<complex<T>,3> tmp1 = U * F;
      Matrix<T,2> r = get_block_su2<T>(tmp1, block);
      T k = sqrt(r(0,0)*r(0,0)+r(0,1)*r(0,1)+r(1,0)*r(1,0)+r(1,1)*r(1,1));
      T ap = BetaOverNc * k;
      k = (T)1.0 / k;
      r *= k;
      //Matrix<T,2> a = generate_su2_matrix<T4, T>(ap, localState);
      Matrix<T,2> a = generate_su2_matrix_milc<T>(ap, localState);
      r = mulsu2UVDagger_4<T>( a, r);
      block_su2_to_su3<T>( U, complex( r(0,0), r(1,1) ), complex( r(1,0), r(0,1) ), complex(-r(1,0), r(0,1) ), complex( r(0,0),-r(1,1) ), block );
      //FLOP_min = (198 + 4 + 15 + 28 + 28 + 84) * 3 = 1071
      }*/

      for ( int block = 0; block < NCOLORS; block++ ) {
        int p,q;
        IndexBlock<NCOLORS>(block, p, q);
        complex<Float> a0((Float)0.0, (Float)0.0);
        complex<Float> a1 = a0;
        complex<Float> a2 = a0;
        complex<Float> a3 = a0;

        // only the (p,q) 2x2 sub-block of U*F is needed: accumulate it directly
        // instead of forming the full product
        for ( int j = 0; j < NCOLORS; j++ ) {
          a0 += U(p,j) * F(j,p);
          a1 += U(p,j) * F(j,q);
          a2 += U(q,j) * F(j,p);
          a3 += U(q,j) * F(j,q);
        }
        // project the sub-block onto SU(2) (real 4-parameter form)
        Matrix<Float,2> r;
        r(0,0) = a0.x + a3.x;
        r(0,1) = a1.y + a2.y;
        r(1,0) = a1.x - a2.x;
        r(1,1) = a0.y - a3.y;
        Float k = sqrt(r(0,0) * r(0,0) + r(0,1) * r(0,1) + r(1,0) * r(1,0) + r(1,1) * r(1,1));;
        Float ap = BetaOverNc * k;
        k = 1.0 / k;
        r *= k;
        // draw the heatbath SU(2) element and rotate it by r^dagger
        Matrix<Float,2> a = generate_su2_matrix_milc<Float>(ap, localState);
        r = mulsu2UVDagger<Float>( a, r);
        // embed the resulting SU(2) element and left-multiply rows p and q of U
        a0 = complex<Float>( r(0,0), r(1,1) );
        a1 = complex<Float>( r(1,0), r(0,1) );
        a2 = complex<Float>(-r(1,0), r(0,1) );
        a3 = complex<Float>( r(0,0),-r(1,1) );
        complex<Float> tmp0;

        for ( int j = 0; j < NCOLORS; j++ ) {
          tmp0 = a0 * U(p,j) + a1 * U(q,j);
          U(q,j) = a2 * U(p,j) + a3 * U(q,j);
          U(p,j) = tmp0;
        }
        //FLOP_min = (NCOLORS * 64 + 19 + 28 + 28) * 3 = NCOLORS * 192 + 225
      }
    }
    else if ( NCOLORS > 3 ) {
      //TESTED IN SU(4) SP THIS IS WORST
      // form the full product once and keep it updated alongside U
      Matrix<complex<Float>,NCOLORS> M = U * F;
      for ( int block = 0; block < NCOLORS * ( NCOLORS - 1) / 2; block++ ) {
        int2 id = IndexBlock<NCOLORS>( block );
        Matrix<Float,2> r = get_block_su2<Float>(M, id);
        Float k = sqrt(r(0,0) * r(0,0) + r(0,1) * r(0,1) + r(1,0) * r(1,0) + r(1,1) * r(1,1));
        Float ap = BetaOverNc * k;
        k = 1.0 / k;
        r *= k;
        Matrix<Float,2> a = generate_su2_matrix_milc<Float>(ap, localState);
        Matrix<Float,2> rr = mulsu2UVDagger<Float>( a, r);
        // apply the update to both U and the cached product M
        mul_block_sun<Float, NCOLORS>( rr, U, id);
        mul_block_sun<Float, NCOLORS>( rr, M, id);
      }
      /* / TESTED IN SU(4) SP THIS IS FASTER
      for ( int block = 0; block < NCOLORS * ( NCOLORS - 1) / 2; block++ ) {
      int2 id = IndexBlock<NCOLORS>( block );
      complex a0 = complex::zero();
      complex a1 = complex::zero();
      complex a2 = complex::zero();
      complex a3 = complex::zero();

      for ( int j = 0; j < NCOLORS; j++ ) {
      a0 += U(id.x, j) * F.e[j][id.x];
      a1 += U(id.x, j) * F.e[j][id.y];
      a2 += U(id.y, j) * F.e[j][id.x];
      a3 += U(id.y, j) * F.e[j][id.y];
      }
      Matrix<T,2> r;
      r(0,0) = a0.x + a3.x;
      r(0,1) = a1.y + a2.y;
      r(1,0) = a1.x - a2.x;
      r(1,1) = a0.y - a3.y;
      T k = sqrt(r(0,0) * r(0,0) + r(0,1) * r(0,1) + r(1,0) * r(1,0) + r(1,1) * r(1,1));
      T ap = BetaOverNc * k;
      k = (T)1.0 / k;
      r *= k;
      //Matrix<T,2> a = generate_su2_matrix<T4, T>(ap, localState);
      Matrix<T,2> a = generate_su2_matrix_milc<T>(ap, localState);
      r = mulsu2UVDagger<T>( a, r);
      mul_block_sun<T>( r, U, id); */
      /*
      a0 = complex( r(0,0), r(1,1) );
      a1 = complex( r(1,0), r(0,1) );
      a2 = complex(-r(1,0), r(0,1) );
      a3 = complex( r(0,0),-r(1,1) );
      complex tmp0;

      for ( int j = 0; j < NCOLORS; j++ ) {
      tmp0 = a0 * U(id.x, j) + a1 * U(id.y, j);
      U(id.y, j) = a2 * U(id.x, j) + a3 * U(id.y, j);
      U(id.x, j) = tmp0;
      } */
      // }

    }
  }
431 
433 
  /**
     @brief Overrelaxation update of one link over its SU(2) subgroups:
     for each subgroup, the normalized-conjugated SU(2) projection of U*F
     is applied twice (a reflection), which preserves the action.
     @param[in,out] U link matrix, updated in place
     @param F staple for this link (passed by value)
   */
  template <class Float, int NCOLORS>
  __device__ inline void overrelaxationSUN( Matrix<complex<Float>,NCOLORS>& U, Matrix<complex<Float>,NCOLORS> F ){

    if ( NCOLORS == 3 ) {
      /* alternative kept for reference: recompute tmp1 = U * F per block
      for( int block = 0; block < 3; block++ ) {
      Matrix<complex<T>,3> tmp1 = U * F;
      Matrix<T,2> r = get_block_su2<T>(tmp1, block);
      //normalize and conjugate
      Float norm = 1.0 / sqrt(r(0,0)*r(0,0)+r(0,1)*r(0,1)+r(1,0)*r(1,0)+r(1,1)*r(1,1));;
      r(0,0) *= norm;
      r(0,1) *= -norm;
      r(1,0) *= -norm;
      r(1,1) *= -norm;
      complex a00 = complex( r(0,0), r(1,1) );
      complex a01 = complex( r(1,0), r(0,1) );
      complex a10 = complex(-r(1,0), r(0,1) );
      complex a11 = complex( r(0,0),-r(1,1) );
      block_su2_to_su3<T>( U, a00, a01, a10, a11, block );
      block_su2_to_su3<T>( U, a00, a01, a10, a11, block );

      //FLOP = (198 + 17 + 84 * 2) * 3 = 1149
      }*/
      //This version does not need to multiply all matrix at each block: tmp1 = U * F;

      for ( int block = 0; block < 3; block++ ) {
        int p,q;
        IndexBlock<NCOLORS>(block, p, q);
        complex<Float> a0((Float)0., (Float)0.);
        complex<Float> a1 = a0;
        complex<Float> a2 = a0;
        complex<Float> a3 = a0;

        // only the (p,q) 2x2 sub-block of U*F is needed: accumulate it directly
        for ( int j = 0; j < NCOLORS; j++ ) {
          a0 += U(p,j) * F(j,p);
          a1 += U(p,j) * F(j,q);
          a2 += U(q,j) * F(j,p);
          a3 += U(q,j) * F(j,q);
        }
        Matrix<Float,2> r;
        r(0,0) = a0.x + a3.x;
        r(0,1) = a1.y + a2.y;
        r(1,0) = a1.x - a2.x;
        r(1,1) = a0.y - a3.y;
        //normalize and conjugate
        //r = r.conj_normalize();
        Float norm = 1.0 / sqrt(r(0,0) * r(0,0) + r(0,1) * r(0,1) + r(1,0) * r(1,0) + r(1,1) * r(1,1));;
        r(0,0) *= norm;
        r(0,1) *= -norm;
        r(1,0) *= -norm;
        r(1,1) *= -norm;


        // embed r and apply it twice in one pass: U <- r * (r * U) on rows p, q
        a0 = complex<Float>( r(0,0), r(1,1) );
        a1 = complex<Float>( r(1,0), r(0,1) );
        a2 = complex<Float>(-r(1,0), r(0,1) );
        a3 = complex<Float>( r(0,0),-r(1,1) );
        complex<Float> tmp0, tmp1;

        for ( int j = 0; j < NCOLORS; j++ ) {
          tmp0 = a0 * U(p,j) + a1 * U(q,j);
          tmp1 = a2 * U(p,j) + a3 * U(q,j);
          U(p,j) = a0 * tmp0 + a1 * tmp1;
          U(q,j) = a2 * tmp0 + a3 * tmp1;
        }
        //FLOP = (NCOLORS * 88 + 17) * 3
      }
    }
    else if ( NCOLORS > 3 ) {
      // form the full product once and keep it updated alongside U
      Matrix<complex<Float>,NCOLORS> M = U * F;
      for ( int block = 0; block < NCOLORS * ( NCOLORS - 1) / 2; block++ ) {
        int2 id = IndexBlock<NCOLORS>( block );
        Matrix<Float,2> r = get_block_su2<Float, NCOLORS>(M, id);
        //normalize and conjugate
        Float norm = 1.0 / sqrt(r(0,0) * r(0,0) + r(0,1) * r(0,1) + r(1,0) * r(1,0) + r(1,1) * r(1,1));;
        r(0,0) *= norm;
        r(0,1) *= -norm;
        r(1,0) *= -norm;
        r(1,1) *= -norm;
        // apply the reflection twice to both U and the cached product M
        mul_block_sun<Float, NCOLORS>( r, U, id);
        mul_block_sun<Float, NCOLORS>( r, U, id);
        mul_block_sun<Float, NCOLORS>( r, M, id);
        mul_block_sun<Float, NCOLORS>( r, M, id);
      }
      /* //TESTED IN SU(4) SP THIS IS WORST
      for( int block = 0; block < NCOLORS * ( NCOLORS - 1) / 2; block++ ) {
      int2 id = IndexBlock<NCOLORS>( block );
      complex a0 = complex::zero();
      complex a1 = complex::zero();
      complex a2 = complex::zero();
      complex a3 = complex::zero();

      for(int j = 0; j < NCOLORS; j++){
      a0 += U(id.x, j) * F.e[j][id.x];
      a1 += U(id.x, j) * F.e[j][id.y];
      a2 += U(id.y, j) * F.e[j][id.x];
      a3 += U(id.y, j) * F.e[j][id.y];
      }
      Matrix<T,2> r;
      r(0,0) = a0.x + a3.x;
      r(0,1) = a1.y + a2.y;
      r(1,0) = a1.x - a2.x;
      r(1,1) = a0.y - a3.y;
      //normalize and conjugate
      Float norm = 1.0 / sqrt(r(0,0)*r(0,0)+r(0,1)*r(0,1)+r(1,0)*r(1,0)+r(1,1)*r(1,1));;
      r(0,0) *= norm;
      r(0,1) *= -norm;
      r(1,0) *= -norm;
      r(1,1) *= -norm;
      //mul_block_sun<T>( r, U, id);
      //mul_block_sun<T>( r, U, id);
      a0 = complex( r(0,0), r(1,1) );
      a1 = complex( r(1,0), r(0,1) );
      a2 = complex(-r(1,0), r(0,1) );
      a3 = complex( r(0,0),-r(1,1) );
      complex tmp0, tmp1;

      for(int j = 0; j < NCOLORS; j++){
      tmp0 = a0 * U(id.x, j) + a1 * U(id.y, j);
      tmp1 = a2 * U(id.x, j) + a3 * U(id.y, j);
      U(id.x, j) = a0 * tmp0 + a1 * tmp1;
      U(id.y, j) = a2 * tmp0 + a3 * tmp1;
      }
      }
      */
    }
  }
574 
575 
  /**
     @brief Kernel argument container for the heatbath/overrelaxation kernels.
   */
  template <typename Gauge, typename Float, int NCOLORS>
  struct MonteArg {
    int threads; // number of active threads required
    int X[4]; // grid dimensions (local interior volume when MULTI_GPU)
#ifdef MULTI_GPU
    int border[4]; // halo radius in each direction
#endif
    Gauge dataOr;         // gauge field accessor used by the kernel
    cudaGaugeField &data; // gauge field handle (needed for autotuner backup/restore)
    Float BetaOverNc;     // beta / NCOLORS
    RNG rngstate;         // CURAND states, one per thread (heatbath only)
    MonteArg(const Gauge &dataOr, cudaGaugeField & data, Float Beta, RNG &rngstate)
      : dataOr(dataOr), data(data), rngstate(rngstate) {
      BetaOverNc = Beta / (Float)NCOLORS;
#ifdef MULTI_GPU
      // strip the communication halo: X[] is the interior volume only
      for ( int dir = 0; dir < 4; ++dir ) {
        border[dir] = data.R()[dir];
        X[dir] = data.X()[dir] - border[dir] * 2;
      }
#else
      for ( int dir = 0; dir < 4; ++dir ) X[dir] = data.X()[dir];
#endif
      threads = X[0] * X[1] * X[2] * X[3] >> 1; // checkerboard (half-lattice) volume
    }
  };
601 
602 
  /**
     @brief Kernel performing a heatbath (HeatbathOrRelax = true) or
     overrelaxation (HeatbathOrRelax = false) update of all links in
     direction mu with the given checkerboard parity. One thread per
     even/odd lattice site; each thread builds the staple for its link,
     then updates and stores the link.
   */
  template<typename Float, typename Gauge, int NCOLORS, bool HeatbathOrRelax>
  __global__ void compute_heatBath(MonteArg<Gauge, Float, NCOLORS> arg, int mu, int parity){
    int idx = threadIdx.x + blockIdx.x * blockDim.x;
    if ( idx >= arg.threads ) return;
    int id = idx; // keep the original thread id: it indexes the RNG state array
    int X[4];
    #pragma unroll
    for ( int dr = 0; dr < 4; ++dr ) X[dr] = arg.X[dr];

    int x[4];
    getCoords(x, idx, X, parity);
#ifdef MULTI_GPU
    // shift coordinates into the interior of the extended (halo) lattice
    #pragma unroll
    for ( int dr = 0; dr < 4; ++dr ) {
      x[dr] += arg.border[dr];
      X[dr] += 2 * arg.border[dr];
    }
    idx = linkIndex(x,X);
#endif

    Matrix<complex<Float>,NCOLORS> staple;
    setZero(&staple);

    Matrix<complex<Float>,NCOLORS> U;
    // accumulate the staple: two plaquette contributions per direction nu != mu
    for ( int nu = 0; nu < 4; nu++ ) if ( mu != nu ) {
      int dx[4] = { 0, 0, 0, 0 };
      Matrix<complex<Float>,NCOLORS> link;
      // "upper" contribution: link(nu,x) * link(mu,x+nu) * conj(link(nu,x+mu))
      arg.dataOr.load((Float*)(link.data), idx, nu, parity);
      dx[nu]++;
      arg.dataOr.load((Float*)(U.data), linkIndexShift(x,dx,X), mu, 1 - parity);
      link *= U;
      dx[nu]--;
      dx[mu]++;
      arg.dataOr.load((Float*)(U.data), linkIndexShift(x,dx,X), nu, 1 - parity);
      link *= conj(U);
      staple += link;
      // "lower" contribution: conj(link(nu,x-nu)) * link(mu,x-nu) * link(nu,x-nu+mu)
      dx[mu]--;
      dx[nu]--;
      arg.dataOr.load((Float*)(link.data), linkIndexShift(x,dx,X), nu, 1 - parity);
      arg.dataOr.load((Float*)(U.data), linkIndexShift(x,dx,X), mu, 1 - parity);
      link = conj(link) * U;
      dx[mu]++;
      arg.dataOr.load((Float*)(U.data), linkIndexShift(x,dx,X), nu, parity);
      link *= U;
      staple += link;
    }
    arg.dataOr.load((Float*)(U.data), idx, mu, parity);
    if ( HeatbathOrRelax ) {
      cuRNGState localState = arg.rngstate.State()[ id ];
      heatBathSUN<Float, NCOLORS>( U, conj(staple), localState, arg.BetaOverNc );
      arg.rngstate.State()[ id ] = localState; // persist the advanced RNG state
    }
    else{
      overrelaxationSUN<Float, NCOLORS>( U, conj(staple) );
    }
    arg.dataOr.save((Float*)(U.data), idx, mu, parity);
  }
660 
661 
  /**
     @brief Tunable wrapper around the compute_heatBath kernel launch.
     HeatbathOrRelax selects heatbath (true) or overrelaxation (false);
     NElems is the number of real elements per stored link (reconstruction).
   */
  template<typename Float, typename Gauge, int NCOLORS, int NElems, bool HeatbathOrRelax>
  class GaugeHB : Tunable {
    MonteArg<Gauge, Float, NCOLORS> arg;
    int mu;     // direction updated by the next apply()
    int parity; // checkerboard parity updated by the next apply()
    mutable char aux_string[128]; // used as a label in the autotuner
    private:
    unsigned int sharedBytesPerThread() const {
      return 0;
    }
    unsigned int sharedBytesPerBlock(const TuneParam &param) const {
      return 0;
    }
    //bool tuneSharedBytes() const { return false; } // Don't tune shared memory
    bool tuneGridDim() const {
      return false;
    } // Don't tune the grid dimensions.
    unsigned int minThreads() const {
      return arg.threads;
    }

    public:
    GaugeHB(MonteArg<Gauge, Float, NCOLORS> &arg)
      : arg(arg), mu(0), parity(0) {
    }
    ~GaugeHB () {
    }
    // select which direction/parity the next apply() updates
    void SetParam(int _mu, int _parity){
      mu = _mu;
      parity = _parity;
    }
    void apply(const cudaStream_t &stream){
      TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
      compute_heatBath<Float, Gauge, NCOLORS, HeatbathOrRelax ><<< tp.grid,tp.block, tp.shared_bytes, stream >>> (arg, mu, parity);
    }

    TuneKey tuneKey() const {
      std::stringstream vol;
      vol << arg.X[0] << "x";
      vol << arg.X[1] << "x";
      vol << arg.X[2] << "x";
      vol << arg.X[3];
      sprintf(aux_string,"threads=%d,prec=%lu",arg.threads, sizeof(Float));
      return TuneKey(vol.str().c_str(), typeid(*this).name(), aux_string);
    }

    // back up the field (and the RNG state for heatbath) so autotuning
    // trial launches can be rolled back without perturbing the simulation
    void preTune() {
      arg.data.backup();
      if(HeatbathOrRelax) arg.rngstate.backup();
    }
    void postTune() {
      arg.data.restore();
      if(HeatbathOrRelax) arg.rngstate.restore();
    }
    long long flops() const {

      //NEED TO CHECK THIS!!!!!!
      if ( NCOLORS == 3 ) {
        long long flop = 2268LL; // staple construction cost per site
        if ( HeatbathOrRelax ) {
          flop += 801LL;
        }
        else{
          flop += 843LL;
        }
        flop *= arg.threads;
        return flop;
      }
      else{
        long long flop = NCOLORS * NCOLORS * NCOLORS * 84LL;
        if ( HeatbathOrRelax ) {
          flop += NCOLORS * NCOLORS * NCOLORS + (NCOLORS * ( NCOLORS - 1) / 2) * (46LL + 48LL + 56LL * NCOLORS);
        }
        else{
          flop += NCOLORS * NCOLORS * NCOLORS + (NCOLORS * ( NCOLORS - 1) / 2) * (17LL + 112LL * NCOLORS);
        }
        flop *= arg.threads;
        return flop;
      }
    }
    long long bytes() const {
      //NEED TO CHECK THIS!!!!!!
      // 20 link loads/stores per site (staple construction + update)
      if ( NCOLORS == 3 ) {
        long long byte = 20LL * NElems * sizeof(Float);
        if ( HeatbathOrRelax ) byte += 2LL * sizeof(cuRNGState); // RNG state load + store
        byte *= arg.threads;
        return byte;
      }
      else{
        long long byte = 20LL * NCOLORS * NCOLORS * 2 * sizeof(Float);
        if ( HeatbathOrRelax ) byte += 2LL * sizeof(cuRNGState);
        byte *= arg.threads;
        return byte;
      }
    }
  };
758 
759 
760 
761 
762 
763 
764 
765 
766 
767  template<typename Float, int NElems, int NCOLORS, typename Gauge>
768  void Monte( Gauge dataOr, cudaGaugeField& data, RNG &rngstate, Float Beta, int nhb, int nover) {
769 
770  TimeProfile profileHBOVR("HeatBath_OR_Relax", false);
771  MonteArg<Gauge, Float, NCOLORS> montearg(dataOr, data, Beta, rngstate);
772  if ( getVerbosity() >= QUDA_SUMMARIZE ) profileHBOVR.TPSTART(QUDA_PROFILE_COMPUTE);
773  GaugeHB<Float, Gauge, NCOLORS, NElems, true> hb(montearg);
774  for ( int step = 0; step < nhb; ++step ) {
775  for ( int parity = 0; parity < 2; ++parity ) {
776  for ( int mu = 0; mu < 4; ++mu ) {
777  hb.SetParam(mu, parity);
778  hb.apply(0);
779  #ifdef MULTI_GPU
780  PGaugeExchange( data, mu, parity);
781  #endif
782  }
783  }
784  }
785  if ( getVerbosity() >= QUDA_SUMMARIZE ) {
787  profileHBOVR.TPSTOP(QUDA_PROFILE_COMPUTE);
788  double secs = profileHBOVR.Last(QUDA_PROFILE_COMPUTE);
789  double gflops = (hb.flops() * 8 * nhb * 1e-9) / (secs);
790  double gbytes = hb.bytes() * 8 * nhb / (secs * 1e9);
791  #ifdef MULTI_GPU
792  printfQuda("HB: Time = %6.6f s, Gflop/s = %6.1f, GB/s = %6.1f\n", secs, gflops * comm_size(), gbytes * comm_size());
793  #else
794  printfQuda("HB: Time = %6.6f s, Gflop/s = %6.1f, GB/s = %6.1f\n", secs, gflops, gbytes);
795  #endif
796  }
797 
798  if ( getVerbosity() >= QUDA_SUMMARIZE ) profileHBOVR.TPSTART(QUDA_PROFILE_COMPUTE);
799  GaugeHB<Float, Gauge, NCOLORS, NElems, false> relax(montearg);
800  for ( int step = 0; step < nover; ++step ) {
801  for ( int parity = 0; parity < 2; ++parity ) {
802  for ( int mu = 0; mu < 4; ++mu ) {
803  relax.SetParam(mu, parity);
804  relax.apply(0);
805  #ifdef MULTI_GPU
806  PGaugeExchange( data, mu, parity);
807  #endif
808  }
809  }
810  }
811  if ( getVerbosity() >= QUDA_SUMMARIZE ) {
813  profileHBOVR.TPSTOP(QUDA_PROFILE_COMPUTE);
814  double secs = profileHBOVR.Last(QUDA_PROFILE_COMPUTE);
815  double gflops = (relax.flops() * 8 * nover * 1e-9) / (secs);
816  double gbytes = relax.bytes() * 8 * nover / (secs * 1e9);
817  #ifdef MULTI_GPU
818  printfQuda("OVR: Time = %6.6f s, Gflop/s = %6.1f, GB/s = %6.1f\n", secs, gflops * comm_size(), gbytes * comm_size());
819  #else
820  printfQuda("OVR: Time = %6.6f s, Gflop/s = %6.1f, GB/s = %6.1f\n", secs, gflops, gbytes);
821  #endif
822  }
823  }
824 
825 
826 
827  template<typename Float>
828  void Monte( cudaGaugeField& data, RNG &rngstate, Float Beta, int nhb, int nover) {
829 
830  if ( data.isNative() ) {
831  if ( data.Reconstruct() == QUDA_RECONSTRUCT_NO ) {
832  typedef typename gauge_mapper<Float,QUDA_RECONSTRUCT_NO>::type Gauge;
833  Monte<Float, 18, 3>(Gauge(data), data, rngstate, Beta, nhb, nover);
834  } else if ( data.Reconstruct() == QUDA_RECONSTRUCT_12 ) {
835  typedef typename gauge_mapper<Float,QUDA_RECONSTRUCT_12>::type Gauge;
836  Monte<Float, 12, 3>(Gauge(data), data, rngstate, Beta, nhb, nover);
837  } else if ( data.Reconstruct() == QUDA_RECONSTRUCT_8 ) {
838  typedef typename gauge_mapper<Float,QUDA_RECONSTRUCT_8>::type Gauge;
839  Monte<Float, 8, 3>(Gauge(data), data, rngstate, Beta, nhb, nover);
840  } else {
841  errorQuda("Reconstruction type %d of gauge field not supported", data.Reconstruct());
842  }
843  } else {
844  errorQuda("Invalid Gauge Order\n");
845  }
846  }
847 #endif // GPU_GAUGE_ALG
848 
857  void Monte( cudaGaugeField& data, RNG &rngstate, double Beta, int nhb, int nover) {
858 #ifdef GPU_GAUGE_ALG
859  if ( data.Precision() == QUDA_SINGLE_PRECISION ) {
860  Monte<float> (data, rngstate, (float)Beta, nhb, nover);
861  } else if ( data.Precision() == QUDA_DOUBLE_PRECISION ) {
862  Monte<double>(data, rngstate, Beta, nhb, nover);
863  } else {
864  errorQuda("Precision %d not supported", data.Precision());
865  }
866 #else
867  errorQuda("Pure gauge code has not been built");
868 #endif // GPU_GAUGE_ALG
869  }
870 
871 
872 }
dim3 dim3 blockDim
double mu
Definition: test_util.cpp:1643
struct curandStateMRG32k3a cuRNGState
Definition: random_quda.h:17
__device__ __host__ void setZero(Matrix< T, N > *m)
Definition: quda_matrix.h:592
static __device__ __host__ int linkIndexShift(const I x[], const J dx[], const K X[4])
static __device__ __host__ int linkIndex(const int x[], const I X[4])
__host__ __device__ ValueType norm(const complex< ValueType > &z)
Returns the magnitude of z squared.
Definition: complex_quda.h:896
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:20
#define PII
__host__ __device__ ValueType exp(ValueType x)
Definition: complex_quda.h:85
#define errorQuda(...)
Definition: util_quda.h:90
static __host__ __device__ void IndexBlock(int block, int &p, int &q)
__host__ __device__ ValueType sqrt(ValueType x)
Definition: complex_quda.h:105
cudaStream_t * stream
cudaColorSpinorField * tmp
Definition: covdev_test.cpp:44
void PGaugeExchange(cudaGaugeField &data, const int dir, const int parity)
Perform heatbath and overrelaxation. Performs nhb heatbath steps followed by nover overrelaxation ste...
char * index(const char *, int)
QudaGaugeParam param
Definition: pack_test.cpp:17
#define b
int comm_size(void)
Definition: comm_mpi.cpp:126
__host__ __device__ ValueType sin(ValueType x)
Definition: complex_quda.h:40
def id
projector matrices ######################################################################## ...
Class declaration to initialize and hold CURAND RNG states.
Definition: random_quda.h:23
static __inline__ size_t p
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
Definition: tune.cpp:603
Main header file for host and device accessors to GaugeFields.
#define tmp1
Definition: tmc_core.h:15
cudaError_t qudaDeviceSynchronize()
Wrapper around cudaDeviceSynchronize or cuDeviceSynchronize.
__device__ __host__ void setIdentity(Matrix< T, N > *m)
Definition: quda_matrix.h:543
__host__ __device__ ValueType log(ValueType x)
Definition: complex_quda.h:90
int sprintf(char *, const char *,...) __attribute__((__format__(__printf__
#define printfQuda(...)
Definition: util_quda.h:84
unsigned long long flops
Definition: blas_quda.cu:42
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
Definition: complex_quda.h:880
__host__ __device__ ValueType cos(ValueType x)
Definition: complex_quda.h:35
__host__ __device__ ValueType abs(ValueType x)
Definition: complex_quda.h:110
__host__ __device__ ValueType conj(ValueType x)
Definition: complex_quda.h:115
void Monte(cudaGaugeField &data, RNG &rngstate, double Beta, int nhb, int nover)
Perform heatbath and overrelaxation. Performs nhb heatbath steps followed by nover overrelaxation ste...
QudaTune getTuning()
Query whether autotuning is enabled or not. Default is enabled but can be overridden by setting QUDA_...
Definition: util_quda.cpp:51
static __inline__ size_t size_t d
QudaPrecision Precision() const
QudaParity parity
Definition: covdev_test.cpp:53
#define a
#define tmp0
Definition: tmc_core.h:14
unsigned long long bytes
Definition: blas_quda.cu:43
static __device__ __host__ void getCoords(int x[], int cb_index, const I X[], int parity)