pgauge_heatbath.cu
1 #include <quda_internal.h>
2 #include <quda_matrix.h>
3 #include <tune_quda.h>
4 #include <gauge_field.h>
5 #include <gauge_field_order.h>
6 #include <launch_kernel.cuh>
7 #include <comm_quda.h>
8 #include <pgauge_monte.h>
9 #include <gauge_tools.h>
10 #include <random_quda.h>
11 #include <index_helper.cuh>
12 #include <atomic.cuh>
13 #include <cub_helper.cuh>
14 
15 
16 
17 #ifndef PI
18 #define PI 3.1415926535897932384626433832795 // pi
19 #endif
20 #ifndef PII
21 #define PII 6.2831853071795864769252867665590 // 2 * pi
22 #endif
23 
24 namespace quda {
25 
26 #ifdef GPU_GAUGE_ALG
27 
28 
34  template<int NCOLORS>
35  __host__ __device__ static inline int2 IndexBlock(int block){
36  int2 id;
37  int i1;
38  int found = 0;
39  int del_i = 0;
40  int index = -1;
41  while ( del_i < (NCOLORS - 1) && found == 0 ) {
42  del_i++;
43  for ( i1 = 0; i1 < (NCOLORS - del_i); i1++ ) {
44  index++;
45  if ( index == block ) {
46  found = 1;
47  break;
48  }
49  }
50  }
51  id.y = i1 + del_i;
52  id.x = i1;
53  return id;
54  }
61  template<int NCOLORS>
62  __host__ __device__ static inline void IndexBlock(int block, int &p, int &q){
63  if ( NCOLORS == 3 ) {
64  if ( block == 0 ) { p = 0; q = 1; }
65  else if ( block == 1 ) { p = 1; q = 2; }
66  else{ p = 0; q = 2; }
67  }
68  else if ( NCOLORS > 3 ) {
69  int i1;
70  int found = 0;
71  int del_i = 0;
72  int index = -1;
73  while ( del_i < (NCOLORS - 1) && found == 0 ) {
74  del_i++;
75  for ( i1 = 0; i1 < (NCOLORS - del_i); i1++ ) {
76  index++;
77  if ( index == block ) {
78  found = 1;
79  break;
80  }
81  }
82  }
83  q = i1 + del_i;
84  p = i1;
85  }
86  }
87 
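  // Editorial note (not in the original source): the scalar block index used
  // throughout this file enumerates the NCOLORS*(NCOLORS-1)/2 SU(2) subgroups
  // of SU(NCOLORS), i.e. the index pairs (p,q) with p < q. A small host-side
  // check of the mapping, using only the overloads above, could look like:
  //
  //   for ( int b = 0; b < NCOLORS * (NCOLORS - 1) / 2; b++ ) {
  //     int p, q;
  //     IndexBlock<NCOLORS>(b, p, q);
  //     printf("block %d -> (p,q) = (%d,%d)\n", b, p, q);
  //   }
  //
  // For NCOLORS = 3 this prints (0,1), (1,2), (0,2); for NCOLORS = 4 it prints
  // (0,1), (1,2), (2,3), (0,2), (1,3), (0,3).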
94  template <class T>
95  __device__ static inline Matrix<T,2> generate_su2_matrix_milc(T al, cuRNGState& localState){
96  T xr1, xr2, xr3, xr4, d, r;
97  int k;
98  xr1 = Random<T>(localState);
99  xr1 = (log((xr1 + 1.e-10)));
100  xr2 = Random<T>(localState);
101  xr2 = (log((xr2 + 1.e-10)));
102  xr3 = Random<T>(localState);
103  xr4 = Random<T>(localState);
104  xr3 = cos(PII * xr3);
105  d = -(xr2 + xr1 * xr3 * xr3 ) / al;
106  //now beat each site into submission
107  int nacd = 0;
108  if ((1.00 - 0.5 * d) > xr4 * xr4 ) nacd = 1;
109  if ( nacd == 0 && al > 2.0 ) { // Kennedy-Pendleton algorithm
110  for ( k = 0; k < 20; k++ ) {
111  //get four random numbers (add a small increment to prevent taking log(0.))
112  xr1 = Random<T>(localState);
113  xr1 = (log((xr1 + 1.e-10)));
114  xr2 = Random<T>(localState);
115  xr2 = (log((xr2 + 1.e-10)));
116  xr3 = Random<T>(localState);
117  xr4 = Random<T>(localState);
118  xr3 = cos(PII * xr3);
119  d = -(xr2 + xr1 * xr3 * xr3) / al;
120  if ((1.00 - 0.5 * d) > xr4 * xr4 ) break;
121  }
122  } //endif nacd
123  Matrix<T,2> a;
124  if ( nacd == 0 && al <= 2.0 ) { // Creutz algorithm
125  xr3 = exp(-2.0 * al);
126  xr4 = 1.0 - xr3;
127  for ( k = 0; k < 20; k++ ) {
128  //get two random numbers
129  xr1 = Random<T>(localState);
130  xr2 = Random<T>(localState);
131  r = xr3 + xr4 * xr1;
132  a(0,0) = 1.00 + log(r) / al;
133  if ((1.0 - a(0,0) * a(0,0)) > xr2 * xr2 ) break;
134  }
135  d = 1.0 - a(0,0);
136  } //endif nacd
137  //generate the four su(2) elements
138  //find a0 = 1 - d
139  a(0,0) = 1.0 - d;
140  //compute r
141  xr3 = 1.0 - a(0,0) * a(0,0);
142  xr3 = abs(xr3);
143  r = sqrt(xr3);
144  //compute a3
145  a(1,1) = (2.0 * Random<T>(localState) - 1.0) * r;
146  //compute a1 and a2
147  xr1 = xr3 - a(1,1) * a(1,1);
148  xr1 = abs(xr1);
149  xr1 = sqrt(xr1);
150  //xr2 is a random number between 0 and 2*pi
151  xr2 = PII * Random<T>(localState);
152  a(0,1) = xr1 * cos(xr2);
153  a(1,0) = xr1 * sin(xr2);
154  return a;
155  }
156 
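  // Editorial note (not in the original source): up to normalization, the
  // routine above returns an SU(2) element a = a0*1 + i*(a1,a2,a3).sigma whose
  // a0 component is distributed as
  //
  //   P(a0) da0 ~ sqrt(1 - a0*a0) * exp(al * a0) da0,
  //
  // with (a1,a2,a3) uniform on the sphere of radius sqrt(1 - a0*a0), which is
  // the weight required by the SU(2) heatbath. The branch labelled
  // "Kennedy-Pendleton" handles al > 2.0, the Creutz accept/reject method
  // handles al <= 2.0, and both rejection loops are capped at 20 trials.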
157 
164  template < class T>
165  __host__ __device__ static inline Matrix<T,2> get_block_su2( Matrix<complex<T>,3> tmp1, int block ){
166  Matrix<T,2> r;
167  switch ( block ) {
168  case 0:
169  r(0,0) = tmp1(0,0).x + tmp1(1,1).x;
170  r(0,1) = tmp1(0,1).y + tmp1(1,0).y;
171  r(1,0) = tmp1(0,1).x - tmp1(1,0).x;
172  r(1,1) = tmp1(0,0).y - tmp1(1,1).y;
173  break;
174  case 1:
175  r(0,0) = tmp1(1,1).x + tmp1(2,2).x;
176  r(0,1) = tmp1(1,2).y + tmp1(2,1).y;
177  r(1,0) = tmp1(1,2).x - tmp1(2,1).x;
178  r(1,1) = tmp1(1,1).y - tmp1(2,2).y;
179  break;
180  case 2:
181  r(0,0) = tmp1(0,0).x + tmp1(2,2).x;
182  r(0,1) = tmp1(0,2).y + tmp1(2,0).y;
183  r(1,0) = tmp1(0,2).x - tmp1(2,0).x;
184  r(1,1) = tmp1(0,0).y - tmp1(2,2).y;
185  break;
186  }
187  return r;
188  }
189 
196  template <class T, int NCOLORS>
197  __host__ __device__ static inline Matrix<T,2> get_block_su2( Matrix<complex<T>,NCOLORS> tmp1, int2 id ){
198  Matrix<T,2> r;
199  r(0,0) = tmp1(id.x,id.x).x + tmp1(id.y,id.y).x;
200  r(0,1) = tmp1(id.x,id.y).y + tmp1(id.y,id.x).y;
201  r(1,0) = tmp1(id.x,id.y).x - tmp1(id.y,id.x).x;
202  r(1,1) = tmp1(id.x,id.x).y - tmp1(id.y,id.y).y;
203  return r;
204  }
205 
212  template <class T, int NCOLORS>
213  __host__ __device__ static inline Matrix<complex<T>,NCOLORS> block_su2_to_sun( Matrix<T,2> rr, int2 id ){
214  Matrix<complex<T>,NCOLORS> tmp1;
215  setIdentity(&tmp1);
216  tmp1(id.x,id.x) = complex<T>( rr(0,0), rr(1,1) );
217  tmp1(id.x,id.y) = complex<T>( rr(1,0), rr(0,1) );
218  tmp1(id.y,id.x) = complex<T>(-rr(1,0), rr(0,1) );
219  tmp1(id.y,id.y) = complex<T>( rr(0,0),-rr(1,1) );
220  return tmp1;
221  }
228  template <class T, int NCOLORS>
229  __host__ __device__ static inline void mul_block_sun( Matrix<T,2> u, Matrix<complex<T>,NCOLORS> &link, int2 id ){
230  for ( int j = 0; j < NCOLORS; j++ ) {
231  complex<T> tmp = complex<T>( u(0,0), u(1,1) ) * link(id.x, j) + complex<T>( u(1,0), u(0,1) ) * link(id.y, j);
232  link(id.y, j) = complex<T>(-u(1,0), u(0,1) ) * link(id.x, j) + complex<T>( u(0,0),-u(1,1) ) * link(id.y, j);
233  link(id.x, j) = tmp;
234  }
235  }
236 
246  template <class Cmplx>
247  __host__ __device__ static inline void block_su2_to_su3( Matrix<Cmplx,3> &U, Cmplx a00, Cmplx a01, Cmplx a10, Cmplx a11, int block ){
248  Cmplx tmp;
249  switch ( block ) {
250  case 0:
251  tmp = a00 * U(0,0) + a01 * U(1,0);
252  U(1,0) = a10 * U(0,0) + a11 * U(1,0);
253  U(0,0) = tmp;
254  tmp = a00 * U(0,1) + a01 * U(1,1);
255  U(1,1) = a10 * U(0,1) + a11 * U(1,1);
256  U(0,1) = tmp;
257  tmp = a00 * U(0,2) + a01 * U(1,2);
258  U(1,2) = a10 * U(0,2) + a11 * U(1,2);
259  U(0,2) = tmp;
260  break;
261  case 1:
262  tmp = a00 * U(1,0) + a01 * U(2,0);
263  U(2,0) = a10 * U(1,0) + a11 * U(2,0);
264  U(1,0) = tmp;
265  tmp = a00 * U(1,1) + a01 * U(2,1);
266  U(2,1) = a10 * U(1,1) + a11 * U(2,1);
267  U(1,1) = tmp;
268  tmp = a00 * U(1,2) + a01 * U(2,2);
269  U(2,2) = a10 * U(1,2) + a11 * U(2,2);
270  U(1,2) = tmp;
271  break;
272  case 2:
273  tmp = a00 * U(0,0) + a01 * U(2,0);
274  U(2,0) = a10 * U(0,0) + a11 * U(2,0);
275  U(0,0) = tmp;
276  tmp = a00 * U(0,1) + a01 * U(2,1);
277  U(2,1) = a10 * U(0,1) + a11 * U(2,1);
278  U(0,1) = tmp;
279  tmp = a00 * U(0,2) + a01 * U(2,2);
280  U(2,2) = a10 * U(0,2) + a11 * U(2,2);
281  U(0,2) = tmp;
282  break;
283  }
284  }
285 
286 
287 
288 // v * u^dagger
289  template <class Float>
290  __host__ __device__ static inline Matrix<Float,2> mulsu2UVDagger(Matrix<Float,2> v, Matrix<Float,2> u){
291  Matrix<Float,2> b;
292  b(0,0) = v(0,0) * u(0,0) + v(0,1) * u(0,1) + v(1,0) * u(1,0) + v(1,1) * u(1,1);
293  b(0,1) = v(0,1) * u(0,0) - v(0,0) * u(0,1) + v(1,0) * u(1,1) - v(1,1) * u(1,0);
294  b(1,0) = v(1,0) * u(0,0) - v(0,0) * u(1,0) + v(1,1) * u(0,1) - v(0,1) * u(1,1);
295  b(1,1) = v(1,1) * u(0,0) - v(0,0) * u(1,1) + v(0,1) * u(1,0) - v(1,0) * u(0,1);
296  return b;
297  }
298 
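  // Editorial note (not in the original source): in the SU(2) helpers above
  // and below, a Matrix<Float,2> is used as a plain container for the four
  // real parameters (a0, a1, a2, a3) of an SU(2) element
  //
  //   a = a0*1 + i*(a1*sigma_x + a2*sigma_y + a3*sigma_z),
  //
  // stored as a(0,0)=a0, a(0,1)=a1, a(1,0)=a2, a(1,1)=a3 (compare
  // block_su2_to_sun above). With that convention, mulsu2UVDagger returns the
  // four parameters of the 2x2 matrix product v * u^dagger.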
305  template <class Float, int NCOLORS>
306  __device__ inline void heatBathSUN( Matrix<complex<Float>,NCOLORS>& U, Matrix<complex<Float>,NCOLORS> F,
307  cuRNGState& localState, Float BetaOverNc ){
308 
309  if ( NCOLORS == 3 ) {
311  /*
312  for( int block = 0; block < NCOLORS; block++ ) {
313  Matrix<complex<T>,3> tmp1 = U * F;
314  Matrix<T,2> r = get_block_su2<T>(tmp1, block);
315  T k = sqrt(r(0,0)*r(0,0)+r(0,1)*r(0,1)+r(1,0)*r(1,0)+r(1,1)*r(1,1));
316  T ap = BetaOverNc * k;
317  k = (T)1.0 / k;
318  r *= k;
319  //Matrix<T,2> a = generate_su2_matrix<T4, T>(ap, localState);
320  Matrix<T,2> a = generate_su2_matrix_milc<T>(ap, localState);
321  r = mulsu2UVDagger_4<T>( a, r);
323  block_su2_to_su3<T>( U, complex( r(0,0), r(1,1) ), complex( r(1,0), r(0,1) ), complex(-r(1,0), r(0,1) ), complex( r(0,0),-r(1,1) ), block );
324  //FLOP_min = (198 + 4 + 15 + 28 + 28 + 84) * 3 = 1071
325  }*/
327 
328  for ( int block = 0; block < NCOLORS; block++ ) {
329  int p,q;
330  IndexBlock<NCOLORS>(block, p, q);
331  complex<Float> a0((Float)0.0, (Float)0.0);
332  complex<Float> a1 = a0;
333  complex<Float> a2 = a0;
334  complex<Float> a3 = a0;
335 
336  for ( int j = 0; j < NCOLORS; j++ ) {
337  a0 += U(p,j) * F(j,p);
338  a1 += U(p,j) * F(j,q);
339  a2 += U(q,j) * F(j,p);
340  a3 += U(q,j) * F(j,q);
341  }
342  Matrix<Float,2> r;
343  r(0,0) = a0.x + a3.x;
344  r(0,1) = a1.y + a2.y;
345  r(1,0) = a1.x - a2.x;
346  r(1,1) = a0.y - a3.y;
347  Float k = sqrt(r(0,0) * r(0,0) + r(0,1) * r(0,1) + r(1,0) * r(1,0) + r(1,1) * r(1,1));
348  Float ap = BetaOverNc * k;
349  k = 1.0 / k;
350  r *= k;
351  Matrix<Float,2> a = generate_su2_matrix_milc<Float>(ap, localState);
352  r = mulsu2UVDagger<Float>( a, r);
354  a0 = complex<Float>( r(0,0), r(1,1) );
355  a1 = complex<Float>( r(1,0), r(0,1) );
356  a2 = complex<Float>(-r(1,0), r(0,1) );
357  a3 = complex<Float>( r(0,0),-r(1,1) );
358  complex<Float> tmp0;
359 
360  for ( int j = 0; j < NCOLORS; j++ ) {
361  tmp0 = a0 * U(p,j) + a1 * U(q,j);
362  U(q,j) = a2 * U(p,j) + a3 * U(q,j);
363  U(p,j) = tmp0;
364  }
365  //FLOP_min = (NCOLORS * 64 + 19 + 28 + 28) * 3 = NCOLORS * 192 + 225
366  }
368  }
369  else if ( NCOLORS > 3 ) {
371  // Tested in SU(4) single precision: this variant is slower
372  Matrix<complex<Float>,NCOLORS> M = U * F;
373  for ( int block = 0; block < NCOLORS * ( NCOLORS - 1) / 2; block++ ) {
374  int2 id = IndexBlock<NCOLORS>( block );
375  Matrix<Float,2> r = get_block_su2<Float>(M, id);
376  Float k = sqrt(r(0,0) * r(0,0) + r(0,1) * r(0,1) + r(1,0) * r(1,0) + r(1,1) * r(1,1));
377  Float ap = BetaOverNc * k;
378  k = 1.0 / k;
379  r *= k;
380  Matrix<Float,2> a = generate_su2_matrix_milc<Float>(ap, localState);
381  Matrix<Float,2> rr = mulsu2UVDagger<Float>( a, r);
383  mul_block_sun<Float, NCOLORS>( rr, U, id);
384  mul_block_sun<Float, NCOLORS>( rr, M, id);
386  }
387  /* Tested in SU(4) single precision: this variant is faster
388  for ( int block = 0; block < NCOLORS * ( NCOLORS - 1) / 2; block++ ) {
389  int2 id = IndexBlock<NCOLORS>( block );
390  complex a0 = complex::zero();
391  complex a1 = complex::zero();
392  complex a2 = complex::zero();
393  complex a3 = complex::zero();
394 
395  for ( int j = 0; j < NCOLORS; j++ ) {
396  a0 += U(id.x, j) * F.e[j][id.x];
397  a1 += U(id.x, j) * F.e[j][id.y];
398  a2 += U(id.y, j) * F.e[j][id.x];
399  a3 += U(id.y, j) * F.e[j][id.y];
400  }
401  Matrix<T,2> r;
402  r(0,0) = a0.x + a3.x;
403  r(0,1) = a1.y + a2.y;
404  r(1,0) = a1.x - a2.x;
405  r(1,1) = a0.y - a3.y;
406  T k = sqrt(r(0,0) * r(0,0) + r(0,1) * r(0,1) + r(1,0) * r(1,0) + r(1,1) * r(1,1));
407  T ap = BetaOverNc * k;
408  k = (T)1.0 / k;
409  r *= k;
410  //Matrix<T,2> a = generate_su2_matrix<T4, T>(ap, localState);
411  Matrix<T,2> a = generate_su2_matrix_milc<T>(ap, localState);
412  r = mulsu2UVDagger<T>( a, r);
413  mul_block_sun<T>( r, U, id); */
414  /*
415  a0 = complex( r(0,0), r(1,1) );
416  a1 = complex( r(1,0), r(0,1) );
417  a2 = complex(-r(1,0), r(0,1) );
418  a3 = complex( r(0,0),-r(1,1) );
419  complex tmp0;
420 
421  for ( int j = 0; j < NCOLORS; j++ ) {
422  tmp0 = a0 * U(id.x, j) + a1 * U(id.y, j);
423  U(id.y, j) = a2 * U(id.x, j) + a3 * U(id.y, j);
424  U(id.x, j) = tmp0;
425  } */
426  // }
427 
428  }
430  }
431 
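  // Editorial note (not in the original source): heatBathSUN implements the
  // Cabibbo-Marinari pseudo-heatbath. The SU(NCOLORS) link is updated by
  // cycling over SU(2) subgroups: the product U * staple is projected onto
  // each subgroup, an SU(2) element is heatbath-sampled there with
  // generate_su2_matrix_milc, and the result is re-embedded into the full
  // link. For NCOLORS == 3 the three subgroups (0,1), (1,2), (0,2) are
  // visited; for NCOLORS > 3 all NCOLORS*(NCOLORS-1)/2 subgroups are visited.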
433 
438  template <class Float, int NCOLORS>
439  __device__ inline void overrelaxationSUN( Matrix<complex<Float>,NCOLORS>& U, Matrix<complex<Float>,NCOLORS> F ){
440 
441  if ( NCOLORS == 3 ) {
443  /*
444  for( int block = 0; block < 3; block++ ) {
445  Matrix<complex<T>,3> tmp1 = U * F;
446  Matrix<T,2> r = get_block_su2<T>(tmp1, block);
447  //normalize and conjugate
448  Float norm = 1.0 / sqrt(r(0,0)*r(0,0)+r(0,1)*r(0,1)+r(1,0)*r(1,0)+r(1,1)*r(1,1));
449  r(0,0) *= norm;
450  r(0,1) *= -norm;
451  r(1,0) *= -norm;
452  r(1,1) *= -norm;
454  complex a00 = complex( r(0,0), r(1,1) );
455  complex a01 = complex( r(1,0), r(0,1) );
456  complex a10 = complex(-r(1,0), r(0,1) );
457  complex a11 = complex( r(0,0),-r(1,1) );
458  block_su2_to_su3<T>( U, a00, a01, a10, a11, block );
459  block_su2_to_su3<T>( U, a00, a01, a10, a11, block );
460 
461  //FLOP = (198 + 17 + 84 * 2) * 3 = 1149
462  }*/
464  // This version avoids recomputing the full product tmp1 = U * F for every block.
466 
467  for ( int block = 0; block < 3; block++ ) {
468  int p,q;
469  IndexBlock<NCOLORS>(block, p, q);
470  complex<Float> a0((Float)0., (Float)0.);
471  complex<Float> a1 = a0;
472  complex<Float> a2 = a0;
473  complex<Float> a3 = a0;
474 
475  for ( int j = 0; j < NCOLORS; j++ ) {
476  a0 += U(p,j) * F(j,p);
477  a1 += U(p,j) * F(j,q);
478  a2 += U(q,j) * F(j,p);
479  a3 += U(q,j) * F(j,q);
480  }
481  Matrix<Float,2> r;
482  r(0,0) = a0.x + a3.x;
483  r(0,1) = a1.y + a2.y;
484  r(1,0) = a1.x - a2.x;
485  r(1,1) = a0.y - a3.y;
486  //normalize and conjugate
487  //r = r.conj_normalize();
488  Float norm = 1.0 / sqrt(r(0,0) * r(0,0) + r(0,1) * r(0,1) + r(1,0) * r(1,0) + r(1,1) * r(1,1));
489  r(0,0) *= norm;
490  r(0,1) *= -norm;
491  r(1,0) *= -norm;
492  r(1,1) *= -norm;
493 
494 
496  a0 = complex<Float>( r(0,0), r(1,1) );
497  a1 = complex<Float>( r(1,0), r(0,1) );
498  a2 = complex<Float>(-r(1,0), r(0,1) );
499  a3 = complex<Float>( r(0,0),-r(1,1) );
500  complex<Float> tmp0, tmp1;
501 
502  for ( int j = 0; j < NCOLORS; j++ ) {
503  tmp0 = a0 * U(p,j) + a1 * U(q,j);
504  tmp1 = a2 * U(p,j) + a3 * U(q,j);
505  U(p,j) = a0 * tmp0 + a1 * tmp1;
506  U(q,j) = a2 * tmp0 + a3 * tmp1;
507  }
508  //FLOP = (NCOLORS * 88 + 17) * 3
509  }
511  }
512  else if ( NCOLORS > 3 ) {
514  Matrix<complex<Float>,NCOLORS> M = U * F;
515  for ( int block = 0; block < NCOLORS * ( NCOLORS - 1) / 2; block++ ) {
516  int2 id = IndexBlock<NCOLORS>( block );
517  Matrix<Float,2> r = get_block_su2<Float, NCOLORS>(M, id);
518  //normalize and conjugate
519  Float norm = 1.0 / sqrt(r(0,0) * r(0,0) + r(0,1) * r(0,1) + r(1,0) * r(1,0) + r(1,1) * r(1,1));
520  r(0,0) *= norm;
521  r(0,1) *= -norm;
522  r(1,0) *= -norm;
523  r(1,1) *= -norm;
524  mul_block_sun<Float, NCOLORS>( r, U, id);
525  mul_block_sun<Float, NCOLORS>( r, U, id);
526  mul_block_sun<Float, NCOLORS>( r, M, id);
527  mul_block_sun<Float, NCOLORS>( r, M, id);
529  }
530  /* Tested in SU(4) single precision: this variant is slower
531  for( int block = 0; block < NCOLORS * ( NCOLORS - 1) / 2; block++ ) {
532  int2 id = IndexBlock<NCOLORS>( block );
533  complex a0 = complex::zero();
534  complex a1 = complex::zero();
535  complex a2 = complex::zero();
536  complex a3 = complex::zero();
537 
538  for(int j = 0; j < NCOLORS; j++){
539  a0 += U(id.x, j) * F.e[j][id.x];
540  a1 += U(id.x, j) * F.e[j][id.y];
541  a2 += U(id.y, j) * F.e[j][id.x];
542  a3 += U(id.y, j) * F.e[j][id.y];
543  }
544  Matrix<T,2> r;
545  r(0,0) = a0.x + a3.x;
546  r(0,1) = a1.y + a2.y;
547  r(1,0) = a1.x - a2.x;
548  r(1,1) = a0.y - a3.y;
549  //normalize and conjugate
550  Float norm = 1.0 / sqrt(r(0,0)*r(0,0)+r(0,1)*r(0,1)+r(1,0)*r(1,0)+r(1,1)*r(1,1));
551  r(0,0) *= norm;
552  r(0,1) *= -norm;
553  r(1,0) *= -norm;
554  r(1,1) *= -norm;
555  //mul_block_sun<T>( r, U, id);
556  //mul_block_sun<T>( r, U, id);
558  a0 = complex( r(0,0), r(1,1) );
559  a1 = complex( r(1,0), r(0,1) );
560  a2 = complex(-r(1,0), r(0,1) );
561  a3 = complex( r(0,0),-r(1,1) );
562  complex tmp0, tmp1;
563 
564  for(int j = 0; j < NCOLORS; j++){
565  tmp0 = a0 * U(id.x, j) + a1 * U(id.y, j);
566  tmp1 = a2 * U(id.x, j) + a3 * U(id.y, j);
567  U(id.x, j) = a0 * tmp0 + a1 * tmp1;
568  U(id.y, j) = a2 * tmp0 + a3 * tmp1;
569  }
570  }
571  */
572  }
573  }
574 
575 
576  template <typename Gauge, typename Float, int NCOLORS>
577  struct MonteArg {
578  int threads; // number of active threads required
579  int X[4]; // grid dimensions
580 #ifdef MULTI_GPU
581  int border[4];
582 #endif
583  Gauge dataOr;
584  cudaGaugeField &data;
585  Float BetaOverNc;
586  RNG rngstate;
587  MonteArg(const Gauge &dataOr, cudaGaugeField & data, Float Beta, RNG &rngstate)
588  : dataOr(dataOr), data(data), rngstate(rngstate) {
589  BetaOverNc = Beta / (Float)NCOLORS;
590 #ifdef MULTI_GPU
591  for ( int dir = 0; dir < 4; ++dir ) {
592  border[dir] = data.R()[dir];
593  X[dir] = data.X()[dir] - border[dir] * 2;
594  }
595 #else
596  for ( int dir = 0; dir < 4; ++dir ) X[dir] = data.X()[dir];
597 #endif
598  threads = X[0] * X[1] * X[2] * X[3] >> 1;
599  }
600  };
601 
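  // Editorial note (not in the original source): threads is half of the local
  // volume because links are updated in a checkerboard (even/odd) pattern, one
  // direction and one parity per kernel launch, so no staple evaluated in a
  // launch contains a link that is being updated in that same launch.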
602 
603  template<typename Float, typename Gauge, int NCOLORS, bool HeatbathOrRelax>
604  __global__ void compute_heatBath(MonteArg<Gauge, Float, NCOLORS> arg, int mu, int parity){
605  int idx = threadIdx.x + blockIdx.x * blockDim.x;
606  if ( idx >= arg.threads ) return;
607  int id = idx;
608  int X[4];
609  #pragma unroll
610  for ( int dr = 0; dr < 4; ++dr ) X[dr] = arg.X[dr];
611 
612  int x[4];
613  getCoords(x, idx, X, parity);
614 #ifdef MULTI_GPU
615  #pragma unroll
616  for ( int dr = 0; dr < 4; ++dr ) {
617  x[dr] += arg.border[dr];
618  X[dr] += 2 * arg.border[dr];
619  }
620  idx = linkIndex(x,X);
621 #endif
622 
623  Matrix<complex<Float>,NCOLORS> staple;
624  setZero(&staple);
625 
626  Matrix<complex<Float>,NCOLORS> U;
627  for ( int nu = 0; nu < 4; nu++ ) if ( mu != nu ) {
628  int dx[4] = { 0, 0, 0, 0 };
629  Matrix<complex<Float>,NCOLORS> link = arg.dataOr(nu, idx, parity);
630  dx[nu]++;
631  U = arg.dataOr(mu, linkIndexShift(x,dx,X), 1 - parity);
632  link *= U;
633  dx[nu]--;
634  dx[mu]++;
635  U = arg.dataOr(nu, linkIndexShift(x,dx,X), 1 - parity);
636  link *= conj(U);
637  staple += link;
638  dx[mu]--;
639  dx[nu]--;
640  link = arg.dataOr(nu, linkIndexShift(x,dx,X), 1 - parity);
641  U = arg.dataOr(mu, linkIndexShift(x,dx,X), 1 - parity);
642  link = conj(link) * U;
643  dx[mu]++;
644  U = arg.dataOr(nu, linkIndexShift(x,dx,X), parity);
645  link *= U;
646  staple += link;
647  }
648  U = arg.dataOr(mu, idx, parity);
649  if ( HeatbathOrRelax ) {
650  cuRNGState localState = arg.rngstate.State()[ id ];
651  heatBathSUN<Float, NCOLORS>( U, conj(staple), localState, arg.BetaOverNc );
652  arg.rngstate.State()[ id ] = localState;
653  }
654  else{
655  overrelaxationSUN<Float, NCOLORS>( U, conj(staple) );
656  }
657  arg.dataOr(mu, idx, parity) = U;
658  }
659 
660 
661  template<typename Float, typename Gauge, int NCOLORS, int NElems, bool HeatbathOrRelax>
662  class GaugeHB : Tunable {
663  MonteArg<Gauge, Float, NCOLORS> arg;
664  int mu;
665  int parity;
666  mutable char aux_string[128]; // used as a label in the autotuner
667  private:
668  unsigned int sharedBytesPerThread() const {
669  return 0;
670  }
671  unsigned int sharedBytesPerBlock(const TuneParam &param) const {
672  return 0;
673  }
674  //bool tuneSharedBytes() const { return false; } // Don't tune shared memory
675  bool tuneGridDim() const {
676  return false;
677  } // Don't tune the grid dimensions.
678  unsigned int minThreads() const {
679  return arg.threads;
680  }
681 
682  public:
683  GaugeHB(MonteArg<Gauge, Float, NCOLORS> &arg)
684  : arg(arg), mu(0), parity(0) {
685  }
686  ~GaugeHB () {
687  }
688  void SetParam(int _mu, int _parity){
689  mu = _mu;
690  parity = _parity;
691  }
692  void apply(const cudaStream_t &stream){
693  TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
694  compute_heatBath<Float, Gauge, NCOLORS, HeatbathOrRelax > <<< tp.grid,tp.block, tp.shared_bytes, stream >>> (arg, mu, parity);
695  }
696 
697  TuneKey tuneKey() const {
698  std::stringstream vol;
699  vol << arg.X[0] << "x";
700  vol << arg.X[1] << "x";
701  vol << arg.X[2] << "x";
702  vol << arg.X[3];
703  sprintf(aux_string,"threads=%d,prec=%lu",arg.threads, sizeof(Float));
704  return TuneKey(vol.str().c_str(), typeid(*this).name(), aux_string);
705  }
706 
707  void preTune() {
708  arg.data.backup();
709  if(HeatbathOrRelax) arg.rngstate.backup();
710  }
711  void postTune() {
712  arg.data.restore();
713  if(HeatbathOrRelax) arg.rngstate.restore();
714  }
715  long long flops() const {
716 
717  // NOTE: these flop counts still need to be verified
718  if ( NCOLORS == 3 ) {
719  long long flop = 2268LL;
720  if ( HeatbathOrRelax ) {
721  flop += 801LL;
722  }
723  else{
724  flop += 843LL;
725  }
726  flop *= arg.threads;
727  return flop;
728  }
729  else{
730  long long flop = NCOLORS * NCOLORS * NCOLORS * 84LL;
731  if ( HeatbathOrRelax ) {
732  flop += NCOLORS * NCOLORS * NCOLORS + (NCOLORS * ( NCOLORS - 1) / 2) * (46LL + 48LL + 56LL * NCOLORS);
733  }
734  else{
735  flop += NCOLORS * NCOLORS * NCOLORS + (NCOLORS * ( NCOLORS - 1) / 2) * (17LL + 112LL * NCOLORS);
736  }
737  flop *= arg.threads;
738  return flop;
739  }
740  }
741  long long bytes() const {
742  // NOTE: these byte counts still need to be verified
743  if ( NCOLORS == 3 ) {
744  long long byte = 20LL * NElems * sizeof(Float);
745  if ( HeatbathOrRelax ) byte += 2LL * sizeof(cuRNGState);
746  byte *= arg.threads;
747  return byte;
748  }
749  else{
750  long long byte = 20LL * NCOLORS * NCOLORS * 2 * sizeof(Float);
751  if ( HeatbathOrRelax ) byte += 2LL * sizeof(cuRNGState);
752  byte *= arg.threads;
753  return byte;
754  }
755  }
756  };
757 
758 
759 
760 
761 
762 
763 
764 
765 
766  template<typename Float, int NElems, int NCOLORS, typename Gauge>
767  void Monte( Gauge dataOr, cudaGaugeField& data, RNG &rngstate, Float Beta, int nhb, int nover) {
768 
769  TimeProfile profileHBOVR("HeatBath_OR_Relax", false);
770  MonteArg<Gauge, Float, NCOLORS> montearg(dataOr, data, Beta, rngstate);
771  if ( getVerbosity() >= QUDA_SUMMARIZE ) profileHBOVR.TPSTART(QUDA_PROFILE_COMPUTE);
772  GaugeHB<Float, Gauge, NCOLORS, NElems, true> hb(montearg);
773  for ( int step = 0; step < nhb; ++step ) {
774  for ( int parity = 0; parity < 2; ++parity ) {
775  for ( int mu = 0; mu < 4; ++mu ) {
776  hb.SetParam(mu, parity);
777  hb.apply(0);
778  #ifdef MULTI_GPU
779  PGaugeExchange( data, mu, parity);
780  #endif
781  }
782  }
783  }
784  if ( getVerbosity() >= QUDA_SUMMARIZE ) {
786  profileHBOVR.TPSTOP(QUDA_PROFILE_COMPUTE);
787  double secs = profileHBOVR.Last(QUDA_PROFILE_COMPUTE);
788  double gflops = (hb.flops() * 8 * nhb * 1e-9) / (secs);
789  double gbytes = hb.bytes() * 8 * nhb / (secs * 1e9);
790  #ifdef MULTI_GPU
791  printfQuda("HB: Time = %6.6f s, Gflop/s = %6.1f, GB/s = %6.1f\n", secs, gflops * comm_size(), gbytes * comm_size());
792  #else
793  printfQuda("HB: Time = %6.6f s, Gflop/s = %6.1f, GB/s = %6.1f\n", secs, gflops, gbytes);
794  #endif
795  }
796 
797  if ( getVerbosity() >= QUDA_SUMMARIZE ) profileHBOVR.TPSTART(QUDA_PROFILE_COMPUTE);
798  GaugeHB<Float, Gauge, NCOLORS, NElems, false> relax(montearg);
799  for ( int step = 0; step < nover; ++step ) {
800  for ( int parity = 0; parity < 2; ++parity ) {
801  for ( int mu = 0; mu < 4; ++mu ) {
802  relax.SetParam(mu, parity);
803  relax.apply(0);
804  #ifdef MULTI_GPU
805  PGaugeExchange( data, mu, parity);
806  #endif
807  }
808  }
809  }
810  if ( getVerbosity() >= QUDA_SUMMARIZE ) {
812  profileHBOVR.TPSTOP(QUDA_PROFILE_COMPUTE);
813  double secs = profileHBOVR.Last(QUDA_PROFILE_COMPUTE);
814  double gflops = (relax.flops() * 8 * nover * 1e-9) / (secs);
815  double gbytes = relax.bytes() * 8 * nover / (secs * 1e9);
816  #ifdef MULTI_GPU
817  printfQuda("OVR: Time = %6.6f s, Gflop/s = %6.1f, GB/s = %6.1f\n", secs, gflops * comm_size(), gbytes * comm_size());
818  #else
819  printfQuda("OVR: Time = %6.6f s, Gflop/s = %6.1f, GB/s = %6.1f\n", secs, gflops, gbytes);
820  #endif
821  }
822  }
823 
824 
825 
826  template<typename Float>
827  void Monte( cudaGaugeField& data, RNG &rngstate, Float Beta, int nhb, int nover) {
828 
829  if ( data.isNative() ) {
830  if ( data.Reconstruct() == QUDA_RECONSTRUCT_NO ) {
831  typedef typename gauge_mapper<Float,QUDA_RECONSTRUCT_NO>::type Gauge;
832  Monte<Float, 18, 3>(Gauge(data), data, rngstate, Beta, nhb, nover);
833  } else if ( data.Reconstruct() == QUDA_RECONSTRUCT_12 ) {
834  typedef typename gauge_mapper<Float,QUDA_RECONSTRUCT_12>::type Gauge;
835  Monte<Float, 12, 3>(Gauge(data), data, rngstate, Beta, nhb, nover);
836  } else if ( data.Reconstruct() == QUDA_RECONSTRUCT_8 ) {
837  typedef typename gauge_mapper<Float,QUDA_RECONSTRUCT_8>::type Gauge;
838  Monte<Float, 8, 3>(Gauge(data), data, rngstate, Beta, nhb, nover);
839  } else {
840  errorQuda("Reconstruction type %d of gauge field not supported", data.Reconstruct());
841  }
842  } else {
843  errorQuda("Invalid Gauge Order\n");
844  }
845  }
846 #endif // GPU_GAUGE_ALG
847 
856  void Monte( cudaGaugeField& data, RNG &rngstate, double Beta, int nhb, int nover) {
857 #ifdef GPU_GAUGE_ALG
858  if ( data.Precision() == QUDA_SINGLE_PRECISION ) {
859  Monte<float> (data, rngstate, (float)Beta, nhb, nover);
860  } else if ( data.Precision() == QUDA_DOUBLE_PRECISION ) {
861  Monte<double>(data, rngstate, Beta, nhb, nover);
862  } else {
863  errorQuda("Precision %d not supported", data.Precision());
864  }
865 #else
866  errorQuda("Pure gauge code has not been built");
867 #endif // GPU_GAUGE_ALG
868  }
869 
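  // Usage sketch (editorial, not in the original source): assuming a resident
  // cudaGaugeField "gauge" in a native order and an RNG "rng" already seeded
  // for this lattice, a combined update at coupling beta = 6.2 with one
  // heatbath sweep followed by four overrelaxation sweeps would be
  //
  //   quda::Monte(gauge, rng, 6.2, 1, 4);
  //
  // Each sweep loops over both parities and all four directions, calling
  // PGaugeExchange after every (mu, parity) update when MULTI_GPU is defined.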
870 
871 }