1 #include <quda_internal.h>
2 #include <quda_matrix.h>
4 #include <gauge_field.h>
5 #include <gauge_field_order.h>
6 #include <launch_kernel.cuh>
8 #include <pgauge_monte.h>
9 #include <gauge_tools.h>
10 #include <random_quda.h>
11 #include <index_helper.cuh>
13 #include <instantiate.h>
// Floating-point constants. They are unsuffixed literals, so they have type double.
// PII (2*pi) scales Random<T> draws into angles in generate_su2_matrix_milc.
// NOTE(review): in float instantiations, expressions such as `PII * xr3` are evaluated in double.
16 #define PI 3.1415926535897932384626433832795 // pi
19 #define PII 6.2831853071795864769252867665590 // 2 * pi
25 @brief Calculate the SU(2) index block in the SU(Nc) matrix
26 @param block index of the SU(2) sub-block to locate; the total number of blocks is NCOLORS * ( NCOLORS - 1) / 2.
27 @return The two matrix indices of the sub-block as an int2, accessed by .x and .y.
// Convert a linear SU(2) sub-block number `block` (0 <= block < NCOLORS*(NCOLORS-1)/2)
// into the row/column index pair of that sub-block inside an SU(NCOLORS) matrix.
// NCOLORS is a template parameter; callers write IndexBlock<NCOLORS>( block ).
// The visible loops walk the index pairs diagonal by diagonal. For each offset del_i
// (guarded by del_i < NCOLORS - 1) and each start i1 in [0, NCOLORS - del_i), a running
// counter `index` is compared with `block`.
// NOTE(review): the locals, counter updates and return statement are not present in this
// copy of the file. The result is presumably (i1, i1 + del_i) in (.x, .y); confirm against
// the full source.
30 __host__ __device__ static inline int2 IndexBlock(int block){
36 while ( del_i < (NCOLORS - 1) && found == 0 ) {
38 for ( i1 = 0; i1 < (NCOLORS - del_i); i1++ ) {
40 if ( index == block ) {
52 @brief Calculate the SU(2) index block in the SU(Nc) matrix
53 @param block index of the SU(2) sub-block to locate; the total number of blocks is NCOLORS * ( NCOLORS - 1) / 2.
54 @param p store the first index
55 @param q store the second index
// Overload of IndexBlock that writes the two indices of SU(2) sub-block `block` into
// p (first index) and q (second index) instead of returning an int2.
// For NCOLORS == 3 the mapping is hard-coded: block 0 -> (0,1), block 1 -> (1,2), and any
// other value -> (0,2). Presumably this is the same order the diagonal walk below produces.
// NOTE(review): the opening `NCOLORS == 3` test (implied by the `} else if` below), the
// locals, and the assignments of p/q after the search are not present in this copy.
58 __host__ __device__ static inline void IndexBlock(int block, int &p, int &q){
60 if ( block == 0 ) { p = 0; q = 1; }
61 else if ( block == 1 ) { p = 1; q = 2; }
62 else { p = 0; q = 2; }
63 } else if ( NCOLORS > 3 ) {
// General NCOLORS: same diagonal-by-diagonal search as the int2-returning overload.
68 while ( del_i < (NCOLORS - 1) && found == 0 ) {
70 for ( i1 = 0; i1 < (NCOLORS - del_i); i1++ ) {
72 if ( index == block ) {
84 @brief Generate a random SU(2) matrix, stored as four real numbers instead of a 2x2 complex matrix. No link matrix is updated here; the caller applies the result.
87 @param localState CURAND RNG state (passed by reference)
// Draw a random SU(2) element `a` in the four-real representation (a(0,0) is the
// identity component). The heat-bath callers in this file pass al = BetaOverNc * k,
// where k is the norm of the projected staple.
// a(0,0) is sampled in stages: one initial trial, then up to 20 Kennedy-Pendleton
// trials when al > 2, or up to 20 Creutz trials when al <= 2 (per the inline comments).
// The other three components are then derived from xr3 = 1 - a(0,0)^2, one uniform
// draw for a(1,1), and one random angle in [0, 2*pi) for a(0,1)/a(1,0).
// NOTE(review): al == 0 divides by zero in `d = ... / al` and `log(r) / al`.
// NOTE(review): the literals 1.e-10, 1.00 and 2.0 are double, so float instantiations
// are promoted to double here.
// NOTE(review): the template header, the declarations of nacd/k/a, the nacd reset, the
// a(0,0) assignments after the first/k-p trials, the computations of r, and the return
// are not present in this copy.
90 __device__ static inline Matrix<T,2> generate_su2_matrix_milc(T al, cuRNGState& localState){
91 T xr1, xr2, xr3, xr4, d, r;
// First trial. 1.e-10 keeps log() away from log(0).
93 xr1 = Random<T>(localState);
94 xr1 = (log((xr1 + 1.e-10)));
95 xr2 = Random<T>(localState);
96 xr2 = (log((xr2 + 1.e-10)));
97 xr3 = Random<T>(localState);
98 xr4 = Random<T>(localState);
// NOTE(review): unlike the k-p loop, no `xr3 = cos(PII * xr3)` is visible before d is
// formed here; that line is presumably missing from this copy. Confirm.
100 d = -(xr2 + xr1 * xr3 * xr3 ) / al;
101 //now beat each site into submission
// Accept test for the first trial; nacd records whether a value has been accepted.
103 if ((1.00 - 0.5 * d) > xr4 * xr4 ) nacd = 1;
104 if ( nacd == 0 && al > 2.0 ) { //k-p algorithm
105 for ( k = 0; k < 20; k++ ) {
106 //get four random numbers (add a small increment to prevent taking log(0.)
107 xr1 = Random<T>(localState);
108 xr1 = (log((xr1 + 1.e-10)));
109 xr2 = Random<T>(localState);
110 xr2 = (log((xr2 + 1.e-10)));
111 xr3 = Random<T>(localState);
112 xr4 = Random<T>(localState);
113 xr3 = cos(PII * xr3);
114 d = -(xr2 + xr1 * xr3 * xr3) / al;
115 if ((1.00 - 0.5 * d) > xr4 * xr4 ) break;
119 if ( nacd == 0 && al <= 2.0 ) { //creutz algorithm
// exp(-2 al) is used in building r from xr1 (that expression is not present in this copy).
120 xr3 = exp(-2.0 * al);
122 for ( k = 0; k < 20; k++ ) {
123 //get two random numbers
124 xr1 = Random<T>(localState);
125 xr2 = Random<T>(localState);
127 a(0,0) = 1.00 + log(r) / al;
128 if ((1.0 - a(0,0) * a(0,0)) > xr2 * xr2 ) break;
132 //generate the four su(2) elements
// xr3 = 1 - a(0,0)^2 is the squared length left for the three remaining components.
136 xr3 = 1.0 - a(0,0) * a(0,0);
// a(1,1) is uniform in [-r, r]; r presumably equals sqrt(xr3) (line not present here).
140 a(1,1) = (2.0 * Random<T>(localState) - 1.0) * r;
142 xr1 = xr3 - a(1,1) * a(1,1);
145 //xr2 is a random number between 0 and 2*pi
146 xr2 = PII * Random<T>(localState);
// xr1 is presumably square-rooted on a line not present here before it is used as a radius.
147 a(0,1) = xr1 * cos(xr2);
148 a(1,0) = xr1 * sin(xr2);
153 @brief Return SU(2) subgroup (4 real numbers) from SU(3) matrix
154 @param tmp1 input SU(3) matrix
155 @param block to retrieve from 0 to 2.
156 @return 4 real numbers
// Extract the four real numbers of SU(2) sub-block `block` (0, 1 or 2) from a 3x3
// complex matrix. For an index pair (p, q) the components are
//   r(0,0) = Re t(p,p) + Re t(q,q),   r(0,1) = Im t(p,q) + Im t(q,p),
//   r(1,0) = Re t(p,q) - Re t(q,p),   r(1,1) = Im t(p,p) - Im t(q,q)
// (.x / .y assumed to be the real / imaginary parts of QUDA's complex type).
// The three groups below use (p,q) = (0,1), (1,2) and (0,2), presumably for blocks
// 0, 1 and 2. That matches the hard-coded NCOLORS == 3 order in IndexBlock(int, int&, int&).
// tmp1 is taken by value (a full 3x3 complex copy).
// NOTE(review): the template header, the declaration of r, the block selection and the
// return are not present in this copy.
159 __host__ __device__ static inline Matrix<T,2> get_block_su2( Matrix<complex<T>,3> tmp1, int block ){
// rows/columns 0 and 1
163 r(0,0) = tmp1(0,0).x + tmp1(1,1).x;
164 r(0,1) = tmp1(0,1).y + tmp1(1,0).y;
165 r(1,0) = tmp1(0,1).x - tmp1(1,0).x;
166 r(1,1) = tmp1(0,0).y - tmp1(1,1).y;
// rows/columns 1 and 2
169 r(0,0) = tmp1(1,1).x + tmp1(2,2).x;
170 r(0,1) = tmp1(1,2).y + tmp1(2,1).y;
171 r(1,0) = tmp1(1,2).x - tmp1(2,1).x;
172 r(1,1) = tmp1(1,1).y - tmp1(2,2).y;
// rows/columns 0 and 2
175 r(0,0) = tmp1(0,0).x + tmp1(2,2).x;
176 r(0,1) = tmp1(0,2).y + tmp1(2,0).y;
177 r(1,0) = tmp1(0,2).x - tmp1(2,0).x;
178 r(1,1) = tmp1(0,0).y - tmp1(2,2).y;
185 @brief Return SU(2) subgroup (4 real numbers) from SU(Nc) matrix
186 @param tmp1 input SU(Nc) matrix
187 @param id the two indices of the SU(2) block to retrieve
188 @return 4 real numbers
// Generic SU(NCOLORS) version: extract the four real numbers of the SU(2) sub-block with
// indices p = id.x, q = id.y, using the same formulas as the SU(3) overload:
//   r(0,0) = Re t(p,p) + Re t(q,q),   r(0,1) = Im t(p,q) + Im t(q,p),
//   r(1,0) = Re t(p,q) - Re t(q,p),   r(1,1) = Im t(p,p) - Im t(q,q).
// tmp1 is taken by value (a full NCOLORS x NCOLORS complex copy).
// NOTE(review): the declaration of r and the return are not present in this copy.
190 template <class T, int NCOLORS>
191 __host__ __device__ static inline Matrix<T,2> get_block_su2( Matrix<complex<T>,NCOLORS> tmp1, int2 id ){
193 r(0,0) = tmp1(id.x,id.x).x + tmp1(id.y,id.y).x;
194 r(0,1) = tmp1(id.x,id.y).y + tmp1(id.y,id.x).y;
195 r(1,0) = tmp1(id.x,id.y).x - tmp1(id.y,id.x).x;
196 r(1,1) = tmp1(id.x,id.x).y - tmp1(id.y,id.y).y;
201 @brief Create an SU(Nc) identity matrix and fill it with the SU(2) block
202 @param rr SU(2) matrix represented only by four real numbers
203 @param id the two indices to fill in the SU(Nc) matrix
204 @return SU(Nc) matrix
// Build an NCOLORS x NCOLORS complex matrix holding the SU(2) element rr (four-real
// representation) in rows/columns id.x and id.y:
//   (x,x) =  rr(0,0) + i rr(1,1),   (x,y) = rr(1,0) + i rr(0,1),
//   (y,x) = -rr(1,0) + i rr(0,1),   (y,y) = rr(0,0) - i rr(1,1).
// NOTE(review): per the header the remaining entries form the identity. The initialisation
// of tmp1 and the return statement are not present in this copy.
206 template <class T, int NCOLORS>
207 __host__ __device__ static inline Matrix<complex<T>,NCOLORS> block_su2_to_sun( Matrix<T,2> rr, int2 id ){
208 Matrix<complex<T>,NCOLORS> tmp1;
210 tmp1(id.x,id.x) = complex<T>( rr(0,0), rr(1,1) );
211 tmp1(id.x,id.y) = complex<T>( rr(1,0), rr(0,1) );
212 tmp1(id.y,id.x) = complex<T>(-rr(1,0), rr(0,1) );
213 tmp1(id.y,id.y) = complex<T>( rr(0,0),-rr(1,1) );
218 @brief Update the SU(Nc) link with the new SU(2) matrix, link <- u * link
219 @param u SU(2) matrix represented by four real numbers
220 @param link SU(Nc) matrix
// Left-multiply `link` in place by the SU(2) element u (four-real representation)
// embedded at rows id.x / id.y. Only those two rows change, one column j at a time.
// tmp holds the new row-id.x entry, so the old link(id.x, j) is still available when
// row id.y is updated.
// NOTE(review): the write-back (presumably `link(id.x, j) = tmp;`) and the closing braces
// are not present in this copy.
223 template <class T, int NCOLORS>
224 __host__ __device__ static inline void mul_block_sun( Matrix<T,2> u, Matrix<complex<T>,NCOLORS> &link, int2 id ){
225 for ( int j = 0; j < NCOLORS; j++ ) {
226 complex<T> tmp = complex<T>( u(0,0), u(1,1) ) * link(id.x, j) + complex<T>( u(1,0), u(0,1) ) * link(id.y, j);
227 link(id.y, j) = complex<T>(-u(1,0), u(0,1) ) * link(id.x, j) + complex<T>( u(0,0),-u(1,1) ) * link(id.y, j);
233 @brief Update the SU(3) link with the new SU(2) matrix, link <- u * link
234 @param U SU(3) matrix
235 @param a00 element (0,0) of the SU(2) matrix
236 @param a01 element (0,1) of the SU(2) matrix
237 @param a10 element (1,0) of the SU(2) matrix
238 @param a11 element (1,1) of the SU(2) matrix
239 @param block of the SU(3) matrix, 0,1 or 2
// Left-multiply two rows of the 3x3 matrix U in place by the 2x2 complex matrix
// [[a00, a01], [a10, a11]]. The three groups below act on rows (0,1), (1,2) and (0,2),
// presumably selected by block 0, 1 and 2. In each column, `tmp` holds the new first-row
// entry so that the old value is still used for the second row.
// NOTE(review): the block selection, the declaration of tmp, the write-backs of tmp into
// the first row, and the braces are not present in this copy.
241 template <class Cmplx>
242 __host__ __device__ static inline void block_su2_to_su3( Matrix<Cmplx,3> &U, Cmplx a00, Cmplx a01, Cmplx a10, Cmplx a11, int block ){
// rows 0 and 1
246 tmp = a00 * U(0,0) + a01 * U(1,0);
247 U(1,0) = a10 * U(0,0) + a11 * U(1,0);
249 tmp = a00 * U(0,1) + a01 * U(1,1);
250 U(1,1) = a10 * U(0,1) + a11 * U(1,1);
252 tmp = a00 * U(0,2) + a01 * U(1,2);
253 U(1,2) = a10 * U(0,2) + a11 * U(1,2);
// rows 1 and 2
257 tmp = a00 * U(1,0) + a01 * U(2,0);
258 U(2,0) = a10 * U(1,0) + a11 * U(2,0);
260 tmp = a00 * U(1,1) + a01 * U(2,1);
261 U(2,1) = a10 * U(1,1) + a11 * U(2,1);
263 tmp = a00 * U(1,2) + a01 * U(2,2);
264 U(2,2) = a10 * U(1,2) + a11 * U(2,2);
// rows 0 and 2
268 tmp = a00 * U(0,0) + a01 * U(2,0);
269 U(2,0) = a10 * U(0,0) + a11 * U(2,0);
271 tmp = a00 * U(0,1) + a01 * U(2,1);
272 U(2,1) = a10 * U(0,1) + a11 * U(2,1);
274 tmp = a00 * U(0,2) + a01 * U(2,2);
275 U(2,2) = a10 * U(0,2) + a11 * U(2,2);
// Product of two SU(2) elements in the four-real representation used by block_su2_to_sun,
// i.e. a = a(0,0) + i( a(0,1) sigma_1 + a(1,0) sigma_2 + a(1,1) sigma_3 ).
// The formulas below give b = v * u^dagger: b(0,0) is v0*u0 + v.u, and the other three
// are u0*v - v0*u + (v x u) component-wise.
// NOTE(review): the declaration of b and the return are not present in this copy.
282 template <class Float>
283 __host__ __device__ static inline Matrix<Float,2> mulsu2UVDagger(Matrix<Float,2> v, Matrix<Float,2> u){
285 b(0,0) = v(0,0) * u(0,0) + v(0,1) * u(0,1) + v(1,0) * u(1,0) + v(1,1) * u(1,1);
286 b(0,1) = v(0,1) * u(0,0) - v(0,0) * u(0,1) + v(1,0) * u(1,1) - v(1,1) * u(1,0);
287 b(1,0) = v(1,0) * u(0,0) - v(0,0) * u(1,0) + v(1,1) * u(0,1) - v(0,1) * u(1,1);
288 b(1,1) = v(1,1) * u(0,0) - v(0,0) * u(1,1) + v(0,1) * u(1,0) - v(1,0) * u(0,1);
293 @brief Link update by pseudo-heatbath
294 @param U link to be updated
296 @param localState CURAND RNG state (passed by reference)
// Pseudo-heat-bath update of the link U with staple F (compute_heatBath passes conj(staple)).
// It sweeps the SU(2) subgroups of U. For each subgroup it:
//   1. projects the matching 2x2 part of U*F onto four reals r;
//   2. sets k = |r| and ap = BetaOverNc * k;
//   3. draws a = generate_su2_matrix_milc(ap, localState);
//   4. forms a * r^dagger with mulsu2UVDagger;
//   5. left-multiplies the two affected rows of U by the result.
// F is taken by value; localState is advanced in place.
// NOTE(review): k == 0 gives ap == 0, which divides by zero inside generate_su2_matrix_milc.
// NOTE(review): the normalisation of r by k, the declarations of p, q, r and tmp0, the
// row-p write-back of tmp0, and several braces are not present in this copy.
298 template <class Float, int NCOLORS>
299 __device__ inline void heatBathSUN( Matrix<complex<Float>,NCOLORS>& U, Matrix<complex<Float>,NCOLORS> F,
300 cuRNGState& localState, Float BetaOverNc ){
302 if ( NCOLORS == 3 ) {
303 //////////////////////////////////////////////////////////////////
// NOTE(review): this first loop uses `T`, `complex` without a template argument, and
// `mulsu2UVDagger_4`, none of which are defined for this function. In the full source it
// is presumably commented out (the delimiters are not present in this copy).
305 for( int block = 0; block < NCOLORS; block++ ) {
306 Matrix<complex<T>,3> tmp1 = U * F;
307 Matrix<T,2> r = get_block_su2<T>(tmp1, block);
308 T k = sqrt(r(0,0)*r(0,0)+r(0,1)*r(0,1)+r(1,0)*r(1,0)+r(1,1)*r(1,1));
309 T ap = BetaOverNc * k;
312 //Matrix<T,2> a = generate_su2_matrix<T4, T>(ap, localState);
313 Matrix<T,2> a = generate_su2_matrix_milc<T>(ap, localState);
314 r = mulsu2UVDagger_4<T>( a, r);
315 ///////////////////////////////////////
316 block_su2_to_su3<T>( U, complex( r(0,0), r(1,1) ), complex( r(1,0), r(0,1) ), complex(-r(1,0), r(0,1) ), complex( r(0,0),-r(1,1) ), block );
317 //FLOP_min = (198 + 4 + 15 + 28 + 28 + 84) * 3 = 1071
319 //////////////////////////////////////////////////////////////////
// Active SU(3) path. Only the four entries (U F)(p,p), (p,q), (q,p), (q,q) needed for this
// block are computed, instead of the full matrix product.
321 for ( int block = 0; block < NCOLORS; block++ ) {
323 IndexBlock<NCOLORS>(block, p, q);
324 complex<Float> a0((Float)0.0, (Float)0.0);
325 complex<Float> a1 = a0;
326 complex<Float> a2 = a0;
327 complex<Float> a3 = a0;
329 for ( int j = 0; j < NCOLORS; j++ ) {
330 a0 += U(p,j) * F(j,p);
331 a1 += U(p,j) * F(j,q);
332 a2 += U(q,j) * F(j,p);
333 a3 += U(q,j) * F(j,q);
336 r(0,0) = a0.x + a3.x;
337 r(0,1) = a1.y + a2.y;
338 r(1,0) = a1.x - a2.x;
339 r(1,1) = a0.y - a3.y;
340 Float k = sqrt(r(0,0) * r(0,0) + r(0,1) * r(0,1) + r(1,0) * r(1,0) + r(1,1) * r(1,1));;
341 Float ap = BetaOverNc * k;
344 Matrix<Float,2> a = generate_su2_matrix_milc<Float>(ap, localState);
345 r = mulsu2UVDagger<Float>( a, r);
346 ///////////////////////////////////////
// Expand the four reals into the 2x2 complex matrix and left-multiply rows p and q of U.
347 a0 = complex<Float>( r(0,0), r(1,1) );
348 a1 = complex<Float>( r(1,0), r(0,1) );
349 a2 = complex<Float>(-r(1,0), r(0,1) );
350 a3 = complex<Float>( r(0,0),-r(1,1) );
353 for ( int j = 0; j < NCOLORS; j++ ) {
354 tmp0 = a0 * U(p,j) + a1 * U(q,j);
355 U(q,j) = a2 * U(p,j) + a3 * U(q,j);
358 //FLOP_min = (NCOLORS * 64 + 19 + 28 + 28) * 3 = NCOLORS * 192 + 225
360 //////////////////////////////////////////////////////////////////
361 } else if ( NCOLORS > 3 ) {
362 //////////////////////////////////////////////////////////////////
// NOTE(review): the active path is labelled WORST here and the commented-out alternative
// below is labelled FASTER; confirm which one is intended.
363 //TESTED IN SU(4) SP THIS IS WORST
// M = U * F is formed once. The same SU(2) update applied to U is also applied to M, so
// M stays equal to U * F for the later blocks without being recomputed.
364 Matrix<complex<Float>,NCOLORS> M = U * F;
365 for ( int block = 0; block < NCOLORS * ( NCOLORS - 1) / 2; block++ ) {
366 int2 id = IndexBlock<NCOLORS>( block );
367 Matrix<Float,2> r = get_block_su2<Float>(M, id);
368 Float k = sqrt(r(0,0) * r(0,0) + r(0,1) * r(0,1) + r(1,0) * r(1,0) + r(1,1) * r(1,1));
369 Float ap = BetaOverNc * k;
372 Matrix<Float,2> a = generate_su2_matrix_milc<Float>(ap, localState);
373 Matrix<Float,2> rr = mulsu2UVDagger<Float>( a, r);
374 ///////////////////////////////////////
375 mul_block_sun<Float, NCOLORS>( rr, U, id);
376 mul_block_sun<Float, NCOLORS>( rr, M, id);
377 ///////////////////////////////////////
379 /* / TESTED IN SU(4) SP THIS IS FASTER
380 for ( int block = 0; block < NCOLORS * ( NCOLORS - 1) / 2; block++ ) {
381 int2 id = IndexBlock<NCOLORS>( block );
382 complex a0 = complex::zero();
383 complex a1 = complex::zero();
384 complex a2 = complex::zero();
385 complex a3 = complex::zero();
387 for ( int j = 0; j < NCOLORS; j++ ) {
388 a0 += U(id.x, j) * F.e[j][id.x];
389 a1 += U(id.x, j) * F.e[j][id.y];
390 a2 += U(id.y, j) * F.e[j][id.x];
391 a3 += U(id.y, j) * F.e[j][id.y];
394 r(0,0) = a0.x + a3.x;
395 r(0,1) = a1.y + a2.y;
396 r(1,0) = a1.x - a2.x;
397 r(1,1) = a0.y - a3.y;
398 T k = sqrt(r(0,0) * r(0,0) + r(0,1) * r(0,1) + r(1,0) * r(1,0) + r(1,1) * r(1,1));
399 T ap = BetaOverNc * k;
402 //Matrix<T,2> a = generate_su2_matrix<T4, T>(ap, localState);
403 Matrix<T,2> a = generate_su2_matrix_milc<T>(ap, localState);
404 r = mulsu2UVDagger<T>( a, r);
405 mul_block_sun<T>( r, U, id); */
// NOTE(review): the lines below use `complex` without a template argument and `T`. In the
// full source they are presumably also inside a comment (delimiters not present in this copy).
407 a0 = complex( r(0,0), r(1,1) );
408 a1 = complex( r(1,0), r(0,1) );
409 a2 = complex(-r(1,0), r(0,1) );
410 a3 = complex( r(0,0),-r(1,1) );
413 for ( int j = 0; j < NCOLORS; j++ ) {
414 tmp0 = a0 * U(id.x, j) + a1 * U(id.y, j);
415 U(id.y, j) = a2 * U(id.x, j) + a3 * U(id.y, j);
421 //////////////////////////////////////////////////////////////////
424 //////////////////////////////////////////////////////////////////////////
426 @brief Link update by overrelaxation
427 @param U link to be updated
// Overrelaxation update of the link U with staple F (compute_heatBath passes conj(staple)).
// No random numbers are used. For every SU(2) subgroup it:
//   1. projects the matching 2x2 part of U*F onto four reals r;
//   2. computes norm = 1/|r| ("normalize and conjugate"; the scaling lines are not present
//      in this copy);
//   3. builds the 2x2 complex matrix A from r;
//   4. left-multiplies the two affected rows of U by A twice.
// NOTE(review): |r| == 0 makes norm infinite.
// NOTE(review): the declarations of p, q and r, and several braces, are not present in this copy.
430 template <class Float, int NCOLORS>
431 __device__ inline void overrelaxationSUN( Matrix<complex<Float>,NCOLORS>& U, Matrix<complex<Float>,NCOLORS> F ){
433 if ( NCOLORS == 3 ) {
434 //////////////////////////////////////////////////////////////////
// NOTE(review): this first loop uses `T` and `complex` without a template argument, neither
// defined for this function. In the full source it is presumably commented out (delimiters
// not present in this copy).
436 for( int block = 0; block < 3; block++ ) {
437 Matrix<complex<T>,3> tmp1 = U * F;
438 Matrix<T,2> r = get_block_su2<T>(tmp1, block);
439 //normalize and conjugate
440 Float norm = 1.0 / sqrt(r(0,0)*r(0,0)+r(0,1)*r(0,1)+r(1,0)*r(1,0)+r(1,1)*r(1,1));;
445 ///////////////////////////////////////
446 complex a00 = complex( r(0,0), r(1,1) );
447 complex a01 = complex( r(1,0), r(0,1) );
448 complex a10 = complex(-r(1,0), r(0,1) );
449 complex a11 = complex( r(0,0),-r(1,1) );
450 block_su2_to_su3<T>( U, a00, a01, a10, a11, block );
451 block_su2_to_su3<T>( U, a00, a01, a10, a11, block );
453 //FLOP = (198 + 17 + 84 * 2) * 3 = 1149
455 ///////////////////////////////////////////////////////////////////
456 //This version does not need to multiply all matrix at each block: tmp1 = U * F;
457 //////////////////////////////////////////////////////////////////
// Active SU(3) path: only the four needed entries of U*F are formed for each block.
459 for ( int block = 0; block < 3; block++ ) {
461 IndexBlock<NCOLORS>(block, p, q);
462 complex<Float> a0((Float)0., (Float)0.);
463 complex<Float> a1 = a0;
464 complex<Float> a2 = a0;
465 complex<Float> a3 = a0;
467 for ( int j = 0; j < NCOLORS; j++ ) {
468 a0 += U(p,j) * F(j,p);
469 a1 += U(p,j) * F(j,q);
470 a2 += U(q,j) * F(j,p);
471 a3 += U(q,j) * F(j,q);
474 r(0,0) = a0.x + a3.x;
475 r(0,1) = a1.y + a2.y;
476 r(1,0) = a1.x - a2.x;
477 r(1,1) = a0.y - a3.y;
478 //normalize and conjugate
479 //r = r.conj_normalize();
480 Float norm = 1.0 / sqrt(r(0,0) * r(0,0) + r(0,1) * r(0,1) + r(1,0) * r(1,0) + r(1,1) * r(1,1));;
487 ///////////////////////////////////////
488 a0 = complex<Float>( r(0,0), r(1,1) );
489 a1 = complex<Float>( r(1,0), r(0,1) );
490 a2 = complex<Float>(-r(1,0), r(0,1) );
491 a3 = complex<Float>( r(0,0),-r(1,1) );
492 complex<Float> tmp0, tmp1;
// Apply A twice to rows p and q: first (tmp0, tmp1) = A * rows, then rows = A * (tmp0, tmp1).
494 for ( int j = 0; j < NCOLORS; j++ ) {
495 tmp0 = a0 * U(p,j) + a1 * U(q,j);
496 tmp1 = a2 * U(p,j) + a3 * U(q,j);
497 U(p,j) = a0 * tmp0 + a1 * tmp1;
498 U(q,j) = a2 * tmp0 + a3 * tmp1;
500 //FLOP = (NCOLORS * 88 + 17) * 3
502 ///////////////////////////////////////////////////////////////////
504 else if ( NCOLORS > 3 ) {
505 ///////////////////////////////////////////////////////////////////
// M = U * F is formed once and receives the same updates as U, so it stays equal to U * F.
506 Matrix<complex<Float>,NCOLORS> M = U * F;
507 for ( int block = 0; block < NCOLORS * ( NCOLORS - 1) / 2; block++ ) {
508 int2 id = IndexBlock<NCOLORS>( block );
509 Matrix<Float,2> r = get_block_su2<Float, NCOLORS>(M, id);
510 //normalize and conjugate
511 Float norm = 1.0 / sqrt(r(0,0) * r(0,0) + r(0,1) * r(0,1) + r(1,0) * r(1,0) + r(1,1) * r(1,1));;
// Two applications of the same SU(2) element to U, mirrored on M.
516 mul_block_sun<Float, NCOLORS>( r, U, id);
517 mul_block_sun<Float, NCOLORS>( r, U, id);
518 mul_block_sun<Float, NCOLORS>( r, M, id);
519 mul_block_sun<Float, NCOLORS>( r, M, id);
520 ///////////////////////////////////////
522 /* //TESTED IN SU(4) SP THIS IS WORST
523 for( int block = 0; block < NCOLORS * ( NCOLORS - 1) / 2; block++ ) {
524 int2 id = IndexBlock<NCOLORS>( block );
525 complex a0 = complex::zero();
526 complex a1 = complex::zero();
527 complex a2 = complex::zero();
528 complex a3 = complex::zero();
530 for(int j = 0; j < NCOLORS; j++){
531 a0 += U(id.x, j) * F.e[j][id.x];
532 a1 += U(id.x, j) * F.e[j][id.y];
533 a2 += U(id.y, j) * F.e[j][id.x];
534 a3 += U(id.y, j) * F.e[j][id.y];
537 r(0,0) = a0.x + a3.x;
538 r(0,1) = a1.y + a2.y;
539 r(1,0) = a1.x - a2.x;
540 r(1,1) = a0.y - a3.y;
541 //normalize and conjugate
542 Float norm = 1.0 / sqrt(r(0,0)*r(0,0)+r(0,1)*r(0,1)+r(1,0)*r(1,0)+r(1,1)*r(1,1));;
547 //mul_block_sun<T>( r, U, id);
548 //mul_block_sun<T>( r, U, id);
549 ///////////////////////////////////////
550 a0 = complex( r(0,0), r(1,1) );
551 a1 = complex( r(1,0), r(0,1) );
552 a2 = complex(-r(1,0), r(0,1) );
553 a3 = complex( r(0,0),-r(1,1) );
556 for(int j = 0; j < NCOLORS; j++){
557 tmp0 = a0 * U(id.x, j) + a1 * U(id.y, j);
558 tmp1 = a2 * U(id.x, j) + a3 * U(id.y, j);
559 U(id.x, j) = a0 * tmp0 + a1 * tmp1;
560 U(id.y, j) = a2 * tmp0 + a3 * tmp1;
// Argument bundle for compute_heatBath: the gauge accessor, the field, the RNG, and
// BetaOverNc = Beta / NCOLORS.
// NOTE(review): the struct header, the member declarations of border/dataOr/data/
// BetaOverNc/rngstate, and the closing braces are not present in this copy.
567 template <typename Gauge, typename Float, int NCOLORS>
569 int threads; // number of active threads required
// X: local lattice extent without the halo, i.e. data.X() minus 2 * border in each direction.
570 int X[4]; // grid dimensions
576 MonteArg(const Gauge &dataOr, GaugeField & data, Float Beta, RNG &rngstate)
577 : dataOr(dataOr), data(data), rngstate(rngstate) {
578 BetaOverNc = Beta / (Float)NCOLORS;
579 for ( int dir = 0; dir < 4; ++dir ) {
580 border[dir] = data.R()[dir];
581 X[dir] = data.X()[dir] - border[dir] * 2;
// `*` binds tighter than `>>`, so this is (X0*X1*X2*X3) / 2: one thread per site of a single parity.
583 threads = X[0] * X[1] * X[2] * X[3] >> 1;
// Kernel: updates link mu at every interior site of the given parity, using heat bath
// (HeatbathOrRelax == true) or overrelaxation (false).
// Launch layout: 1D (threadIdx.x / blockIdx.x only), one thread per site, guarded by
// idx < arg.threads. No shared memory is used in the visible lines.
// NOTE(review): the declarations of x, X and id, the dx shifts, the accumulation into
// `staple`, and several braces are not present in this copy.
587 template<typename Float, typename Gauge, int NCOLORS, bool HeatbathOrRelax>
588 __global__ void compute_heatBath(MonteArg<Gauge, Float, NCOLORS> arg, int mu, int parity){
589 int idx = threadIdx.x + blockIdx.x * blockDim.x;
590 if ( idx >= arg.threads ) return;
594 for ( int dr = 0; dr < 4; ++dr ) X[dr] = arg.X[dr];
// Coordinates are found in the interior lattice, then shifted by arg.border into the
// halo-extended lattice. idx is then overwritten with the extended-lattice link index.
597 getCoords(x, idx, X, parity);
599 for ( int dr = 0; dr < 4; ++dr ) {
600 x[dr] += arg.border[dr];
601 X[dr] += 2 * arg.border[dr];
603 idx = linkIndex(x,X);
605 Matrix<complex<Float>,NCOLORS> staple;
608 Matrix<complex<Float>,NCOLORS> U;
// Staple over the three directions nu != mu. Presumably the first three reads form the
// forward staple and the next three (combined via conj(link) * U) the backward staple.
// Neighbours at odd offset are read with parity 1 - parity.
609 for ( int nu = 0; nu < 4; nu++ ) if ( mu != nu ) {
610 int dx[4] = { 0, 0, 0, 0 };
611 Matrix<complex<Float>,NCOLORS> link = arg.dataOr(nu, idx, parity);
613 U = arg.dataOr(mu, linkIndexShift(x,dx,X), 1 - parity);
617 U = arg.dataOr(nu, linkIndexShift(x,dx,X), 1 - parity);
622 link = arg.dataOr(nu, linkIndexShift(x,dx,X), 1 - parity);
623 U = arg.dataOr(mu, linkIndexShift(x,dx,X), 1 - parity);
624 link = conj(link) * U;
626 U = arg.dataOr(nu, linkIndexShift(x,dx,X), parity);
630 U = arg.dataOr(mu, idx, parity);
631 if ( HeatbathOrRelax ) {
// The RNG state is copied into a local, used, and stored back. `id` is presumably the
// thread index saved before idx was overwritten (TODO confirm).
632 cuRNGState localState = arg.rngstate.State()[ id ];
633 heatBathSUN<Float, NCOLORS>( U, conj(staple), localState, arg.BetaOverNc );
634 arg.rngstate.State()[ id ] = localState;
637 overrelaxationSUN<Float, NCOLORS>( U, conj(staple) );
639 arg.dataOr(mu, idx, parity) = U;
// Autotunable launcher for compute_heatBath. HeatbathOrRelax selects heat bath (true) or
// overrelaxation (false). SetParam chooses the direction mu and the parity updated by the
// next apply().
// NOTE(review): access specifiers, the mu/parity members, several function headers and
// bodies, and the closing braces are not present in this copy of the file.
643 template<typename Float, typename Gauge, int NCOLORS, int NElems, bool HeatbathOrRelax>
644 class GaugeHB : Tunable {
645 MonteArg<Gauge, Float, NCOLORS> arg;
648 mutable char aux_string[128]; // used as a label in the autotuner
649 unsigned int sharedBytesPerThread() const {
652 unsigned int sharedBytesPerBlock(const TuneParam &param) const {
655 //bool tuneSharedBytes() const { return false; } // Don't tune shared memory
656 bool tuneGridDim() const {
658 } // Don't tune the grid dimensions.
659 unsigned int minThreads() const {
664 GaugeHB(MonteArg<Gauge, Float, NCOLORS> &arg)
665 : arg(arg), mu(0), parity(0) {
668 void SetParam(int _mu, int _parity)
674 void apply(const qudaStream_t &stream)
676 TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
677 qudaLaunchKernel(compute_heatBath<Float, Gauge, NCOLORS, HeatbathOrRelax>, tp, stream, arg, mu, parity);
// Tune key: local lattice volume, thread count, and sizeof(Float) as the precision tag.
680 TuneKey tuneKey() const {
681 std::stringstream vol;
682 vol << arg.X[0] << "x";
683 vol << arg.X[1] << "x";
684 vol << arg.X[2] << "x";
// snprintf bounds the write to aux_string. %zu is the portable conversion for size_t;
// %lu is only correct where size_t happens to be unsigned long.
686 snprintf(aux_string, sizeof(aux_string), "threads=%d,prec=%zu", arg.threads, sizeof(Float));
687 return TuneKey(vol.str().c_str(), typeid(*this).name(), aux_string);
// For heat bath, the RNG state is backed up and restored around tuning, presumably so that
// tuning launches do not advance the random stream (enclosing functions not present here).
692 if(HeatbathOrRelax) arg.rngstate.backup();
696 if(HeatbathOrRelax) arg.rngstate.restore();
698 long long flops() const
700 //NEED TO CHECK THIS!!!!!!
701 if ( NCOLORS == 3 ) {
702 long long flop = 2268LL;
703 if ( HeatbathOrRelax ) {
712 long long flop = NCOLORS * NCOLORS * NCOLORS * 84LL;
713 if ( HeatbathOrRelax ) {
714 flop += NCOLORS * NCOLORS * NCOLORS + (NCOLORS * ( NCOLORS - 1) / 2) * (46LL + 48LL + 56LL * NCOLORS);
717 flop += NCOLORS * NCOLORS * NCOLORS + (NCOLORS * ( NCOLORS - 1) / 2) * (17LL + 112LL * NCOLORS);
724 long long bytes() const
726 //NEED TO CHECK THIS!!!!!!
// 20 link transfers per update: presumably 6 staples x 3 links, plus the read and write of U
// (cf. compute_heatBath). Heat bath adds one RNG state load and one store.
// NOTE(review): NElems is instantiated with the QudaReconstructType value (see MonteAlg). This
// assumes that value equals the number of reals stored per link; TODO confirm.
727 if ( NCOLORS == 3 ) {
728 long long byte = 20LL * NElems * sizeof(Float);
729 if ( HeatbathOrRelax ) byte += 2LL * sizeof(cuRNGState);
733 long long byte = 20LL * NCOLORS * NCOLORS * 2 * sizeof(Float);
734 if ( HeatbathOrRelax ) byte += 2LL * sizeof(cuRNGState);
// Host-side driver. It wraps `data` in the accessor chosen by gauge_mapper<Float, recon>,
// then runs nhb heat-bath sweeps followed by nover overrelaxation sweeps.
// Each sweep visits every (parity, mu) pair (8 updates) and calls PGaugeExchange after each.
// At QUDA_VERBOSE or higher, each phase is timed and its rates are printed
// (flops/bytes x 8 updates per sweep x number of sweeps, scaled by comm_size()).
// NOTE(review): the struct header, the kernel launch lines (presumably hb.apply(...) /
// relax.apply(...)) and the closing braces are not present in this copy of the file.
741 template <typename Float, int nColor, QudaReconstructType recon>
743 MonteAlg(GaugeField& data, RNG &rngstate, Float Beta, int nhb, int nover)
745 TimeProfile profileHBOVR("HeatBath_OR_Relax", false);
746 using Gauge = typename gauge_mapper<Float, recon>::type;
748 MonteArg<Gauge, Float, nColor> montearg(Gauge(data), data, Beta, rngstate);
// Start the heat-bath timer under the same verbosity guard as its stop/report below, so the
// timer is never left running without a matching stop (e.g. at QUDA_SUMMARIZE).
749 if (getVerbosity() >= QUDA_VERBOSE) profileHBOVR.TPSTART(QUDA_PROFILE_COMPUTE);
750 GaugeHB<Float, Gauge, nColor, recon, true> hb(montearg);
751 for ( int step = 0; step < nhb; ++step ) {
752 for ( int parity = 0; parity < 2; ++parity ) {
753 for ( int mu = 0; mu < 4; ++mu ) {
754 hb.SetParam(mu, parity);
756 PGaugeExchange(data, mu, parity);
760 if (getVerbosity() >= QUDA_VERBOSE) {
761 qudaDeviceSynchronize();
762 profileHBOVR.TPSTOP(QUDA_PROFILE_COMPUTE);
763 double secs = profileHBOVR.Last(QUDA_PROFILE_COMPUTE);
764 double gflops = (hb.flops() * 8 * nhb * 1e-9) / (secs);
765 double gbytes = hb.bytes() * 8 * nhb / (secs * 1e9);
766 printfQuda("HB: Time = %6.6f s, Gflop/s = %6.1f, GB/s = %6.1f\n", secs, gflops * comm_size(), gbytes * comm_size());
769 if (getVerbosity() >= QUDA_VERBOSE) profileHBOVR.TPSTART(QUDA_PROFILE_COMPUTE);
770 GaugeHB<Float, Gauge, nColor, recon, false> relax(montearg);
771 for ( int step = 0; step < nover; ++step ) {
772 for ( int parity = 0; parity < 2; ++parity ) {
773 for ( int mu = 0; mu < 4; ++mu ) {
774 relax.SetParam(mu, parity);
776 PGaugeExchange(data, mu, parity);
780 if (getVerbosity() >= QUDA_VERBOSE) {
781 qudaDeviceSynchronize();
782 profileHBOVR.TPSTOP(QUDA_PROFILE_COMPUTE);
783 double secs = profileHBOVR.Last(QUDA_PROFILE_COMPUTE);
784 double gflops = (relax.flops() * 8 * nover * 1e-9) / (secs);
785 double gbytes = relax.bytes() * 8 * nover / (secs * 1e9);
786 printfQuda("OVR: Time = %6.6f s, Gflop/s = %6.1f, GB/s = %6.1f\n", secs, gflops * comm_size(), gbytes * comm_size());
791 /** @brief Perform heatbath and overrelaxation. Performs nhb heatbath steps followed by nover overrelaxation steps.
793 * @param[in,out] data Gauge field
794 * @param[in,out] rngstate state of the CURAND random number generator
795 * @param[in] Beta inverse of the gauge coupling, beta = 2 Nc / g_0^2
796 * @param[in] nhb number of heatbath steps
797 * @param[in] nover number of overrelaxation steps
// Entry point: runs nhb heat-bath steps followed by nover overrelaxation steps on `data`.
// instantiate<MonteAlg> dispatches on the field's precision and reconstruction.
// Beta is forwarded unconverted. MonteAlg takes it as `Float`, so the single-precision path
// still receives a float, while the double-precision path keeps full precision. Casting to
// float first would round beta even for double-precision fields.
// NOTE(review): the GPU_GAUGE_ALG #ifdef/#else lines and the closing brace are not present in
// this copy. The errorQuda call is presumably the fallback when the code was not built.
799 void Monte(GaugeField& data, RNG &rngstate, double Beta, int nhb, int nover) {
801 instantiate<MonteAlg>(data, rngstate, Beta, nhb, nover);
803 errorQuda("Pure gauge code has not been built");
804 #endif // GPU_GAUGE_ALG