// CUDA stream handle(s) defined in another translation unit. This file
// dereferences the pointer when calling cudaStreamSynchronize(*stream) and
// also passes the pointer itself to the staple pack/unpack calls.
13 extern cudaStream_t *
stream;
19 #if defined(MULTI_GPU) && (defined(GPU_FATLINK) || defined(GPU_GAUGE_FORCE)|| defined(GPU_FERMION_FORCE) || defined(GPU_HISQ_FORCE))
32 #define gaugeSiteSize 18
// File-local staple exchange buffers, one slot per direction (index 0..3).
// "fwd"/"back" select the neighbour side; "sendbuf" marks outgoing data.
//
// *_cpu buffers: released with host_free() in the cleanup loop. The
// *_sendbuf_cpu ones are allocated with safe_malloc(packet_size) and are the
// destination of a memcpy from the matching *_sendbuf buffer.
35 static void* fwd_nbr_staple_cpu[4];
36 static void* back_nbr_staple_cpu[4];
37 static void* fwd_nbr_staple_sendbuf_cpu[4];
38 static void* back_nbr_staple_sendbuf_cpu[4];
// *_gpu buffers: released with device_free() in the cleanup loop.
41 static void* fwd_nbr_staple_gpu[4];
42 static void* back_nbr_staple_gpu[4];
// Un-suffixed buffers: released with host_free(); received data is memcpy'd
// from the *_cpu buffers into fwd_nbr_staple/back_nbr_staple.
// NOTE(review): presumably pinned/staging host memory used for the
// host<->device transfers — the allocation site is not visible; confirm.
44 static void* fwd_nbr_staple[4];
45 static void* back_nbr_staple[4];
46 static void* fwd_nbr_staple_sendbuf[4];
47 static void* back_nbr_staple_sendbuf[4];
// Vs[d]: assigned elsewhere in this file as the product of the three lattice
// extents other than d; used to size the per-direction exchange buffers.
// Vsh[d]: its assignment is not visible here — presumably Vs[d]/2
// (one parity); TODO confirm.
53 static int Vs[4],
Vsh[4];
60 } llfat_recv, llfat_send;
63 extern void setup_dims_in_gauge(
int *XX);
69 for (
int d=0; d< 4; d++) {
// Each Vs[d] is the product of the three extents of X other than d; the same
// value is also stored in the matching Vs_x/Vs_y/Vs_z/Vs_t variable (declared
// outside the visible lines).
80 Vs[0] =
Vs_x = X[1]*X[2]*X[3];
81 Vs[1] =
Vs_y = X[0]*X[2]*X[3];
82 Vs[2] =
Vs_z = X[0]*X[1]*X[3];
83 Vs[3] =
Vs_t = X[0]*X[1]*X[2];
94 static bool initialized =
false;
96 if (initialized)
return;
99 for (
int i=0; i < 4; i++) {
114 fwd_nbr_staple_sendbuf_cpu[i] =
safe_malloc(packet_size);
115 back_nbr_staple_sendbuf_cpu[i] =
safe_malloc(packet_size);
122 template<
typename Float>
123 void exchange_sitelink_diag(
int* X,
Float** sitelink,
Float** ghost_sitelink_diag,
int optflag)
138 for(
int nu =
XUP; nu <=
TUP; nu++){
148 for(dir1=0; dir1 < 4; dir1 ++){
149 if(dir1 != nu && dir1 !=
mu){
153 for(dir2=0; dir2 < 4; dir2 ++){
154 if(dir2 != nu && dir2 !=
mu && dir2 != dir1){
159 if(dir1 == 4 || dir2 == 4){
190 template<
typename Float>
192 exchange_sitelink(
int*X,
Float** sitelink,
Float** ghost_sitelink,
Float** ghost_sitelink_diag,
193 Float** sitelink_fwd_sendbuf,
Float** sitelink_back_sendbuf,
int optflag)
// Backward face: the even-parity source starts at the beginning of
// sitelink[i]. (The odd-parity source pointer, the destination pointer and
// `len` are defined on lines not visible here.)
201 Float* even_sitelink_back_src = sitelink[i];
// When dims[3] is even, the even-parity slab is written first and the
// odd-parity slab follows at an offset of Vsh_t*gaugeSiteSize Floats.
205 if(dims[3] % 2 == 0){
206 memcpy(sitelink_back_dst, even_sitelink_back_src, len);
207 memcpy(sitelink_back_dst +
Vsh_t*gaugeSiteSize, odd_sitelink_back_src, len);
// Same two copies with the parity order swapped.
// NOTE(review): presumably the `else` branch (odd dims[3]); the `else` line
// itself is not visible — confirm.
210 memcpy(sitelink_back_dst, odd_sitelink_back_src, len);
211 memcpy(sitelink_back_dst +
Vsh_t*gaugeSiteSize, even_sitelink_back_src, len);
// Forward face: each source points Vsh_t sites before the end of a
// volumeCB-site region; the odd-parity region is offset by
// volumeCB*gaugeSiteSize Floats from the even one.
216 Float* even_sitelink_fwd_src = sitelink[i] + (volumeCB -
Vsh_t)*gaugeSiteSize;
217 Float* odd_sitelink_fwd_src = sitelink[i] + volumeCB*gaugeSiteSize + (volumeCB -
Vsh_t)*gaugeSiteSize;
// Same even/odd ordering rule as for the backward face.
219 if(dims[3] % 2 == 0){
220 memcpy(sitelink_fwd_dst, even_sitelink_fwd_src, len);
221 memcpy(sitelink_fwd_dst +
Vsh_t*gaugeSiteSize, odd_sitelink_fwd_src, len);
// Parity order swapped (presumably the `else` branch — confirm).
224 memcpy(sitelink_fwd_dst, odd_sitelink_fwd_src, len);
225 memcpy(sitelink_fwd_dst +
Vsh_t*gaugeSiteSize, even_sitelink_fwd_src, len);
240 Float* ghost_sitelink_back = ghost_sitelink[
dir];
269 exchange_sitelink_diag(X, sitelink, ghost_sitelink_diag, optflag);
277 void** sitelink,
void** ghost_sitelink,
278 void** ghost_sitelink_diag,
// Function-local static send buffers, one per direction; being static they
// persist across calls. `allocated` starts out false.
// NOTE(review): the allocation itself and the test of `allocated` are not
// visible here — presumably a one-time allocation guard; confirm.
282 static void* sitelink_fwd_sendbuf[4];
283 static void* sitelink_back_sendbuf[4];
284 static bool allocated =
false;
287 for (
int i=0; i<4; i++) {
// Byte count for direction i: 4 * Vs[i] * gaugeSiteSize * gPrecision.
// NOTE(review): the factor 4 presumably counts the four link directions —
// confirm. The product is evaluated in `int`; verify it cannot overflow for
// the largest supported local lattice.
288 int nbytes = 4*
Vs[i]*gaugeSiteSize*gPrecision;
// Both send buffers are zero-filled over the full nbytes.
291 memset(sitelink_fwd_sendbuf[i], 0, nbytes);
292 memset(sitelink_back_sendbuf[i], 0, nbytes);
// Dispatch to the templated exchange_sitelink, once with all buffers viewed
// as double** and once as float**. The selecting condition is not visible
// here (presumably based on gPrecision — confirm).
298 exchange_sitelink(X, (
double**)sitelink, (
double**)(ghost_sitelink), (
double**)ghost_sitelink_diag,
299 (
double**)sitelink_fwd_sendbuf, (
double**)sitelink_back_sendbuf, optflag);
301 exchange_sitelink(X, (
float**)sitelink, (
float**)(ghost_sitelink), (
float**)ghost_sitelink_diag,
302 (
float**)sitelink_fwd_sendbuf, (
float**)sitelink_back_sendbuf, optflag);
306 for(
int i=0;i < 4;i++){
315 #define MEMCOPY_GAUGE_FIELDS_GRID_TO_BUF(ghost_buf, dst_idx, sitelink, src_idx, num, dir) \
320 dst_idx += R[dir]*slice_3d[dir]/2; \
322 if(cpu_order == QUDA_QDP_GAUGE_ORDER){ \
323 for(int linkdir=0; linkdir < 4; linkdir++){ \
324 char* src = (char*) sitelink[linkdir] + (src_idx)*gaugebytes; \
325 char* dst = ((char*)ghost_buf[dir])+ linkdir*R[dir]*slice_3d[dir]*gaugebytes + (dst_idx)*gaugebytes; \
326 memcpy(dst, src, gaugebytes*(num)); \
329 char* src = ((char*)sitelink)+ 4*(src_idx)*gaugebytes; \
330 char* dst = ((char*)ghost_buf[dir]) + 4*(dst_idx)*gaugebytes; \
331 memcpy(dst, src, 4*gaugebytes*(num)); \
334 #define MEMCOPY_GAUGE_FIELDS_BUF_TO_GRID(sitelink, dst_idx, ghost_buf, src_idx, num, dir) \
336 if(commDimPartitioned(dir)){ \
337 src_idx += R[dir]*slice_3d[dir]/2; \
343 if(cpu_order == QUDA_QDP_GAUGE_ORDER){ \
344 for(int linkdir=0; linkdir < 4; linkdir++){ \
346 if(commDimPartitioned(dir)){ \
347 src = ((char*)ghost_buf[dir])+ linkdir*R[dir]*slice_3d[dir]*gaugebytes + (src_idx)*gaugebytes; \
349 src = ((char*)sitelink[linkdir])+ (src_idx)*gaugebytes; \
351 char* dst = (char*) sitelink[linkdir] + (dst_idx)*gaugebytes; \
352 memcpy(dst, src, gaugebytes*(num)); \
356 if(commDimPartitioned(dir)){ \
357 src=((char*)ghost_buf[dir]) + 4*(src_idx)*gaugebytes; \
359 src = ((char*)sitelink)+ 4*(src_idx)*gaugebytes; \
361 char* dst = ((char*)sitelink) + 4*(dst_idx)*gaugebytes; \
362 memcpy(dst, src, 4*gaugebytes*(num)); \
365 #define MEMCOPY_GAUGE_FIELDS_BUF_TO_GRID_T(sitelink, ghost_buf, dst_face, src_face, dir) \
367 int even_dst_idx = (dst_face*E3E2E1)/2; \
369 if(commDimPartitioned(dir)){ \
372 even_src_idx = (src_face*E3E2E1)/2; \
375 int odd_dst_idx = even_dst_idx+Vh_ex; \
377 if(commDimPartitioned(dir)){ \
378 odd_src_idx = R[dir]*slice_3d[dir]/2; \
380 odd_src_idx = even_src_idx+Vh_ex; \
382 if(cpu_order == QUDA_QDP_GAUGE_ORDER){ \
383 for(int linkdir=0; linkdir < 4; linkdir ++){ \
384 char* dst = (char*)sitelink[linkdir]; \
386 if(commDimPartitioned(dir)){ \
387 src = ((char*)ghost_buf[dir]) + linkdir*R[dir]*slice_3d[dir]*gaugebytes; \
389 src = (char*)sitelink[linkdir]; \
391 memcpy(dst + even_dst_idx * gaugebytes, src + even_src_idx*gaugebytes, R[dir]*slice_3d[dir]*gaugebytes/2); \
392 memcpy(dst + odd_dst_idx * gaugebytes, src + odd_src_idx*gaugebytes, R[dir]*slice_3d[dir]*gaugebytes/2); \
395 char* dst = (char*)sitelink; \
397 if(commDimPartitioned(dir)){ \
398 src = (char*)ghost_buf[dir]; \
400 src = (char*)sitelink; \
402 memcpy(dst+4*even_dst_idx*gaugebytes, src+4*even_src_idx*gaugebytes, 4*R[dir]*slice_3d[dir]*gaugebytes/2); \
403 memcpy(dst+4*odd_dst_idx*gaugebytes, src+4*odd_src_idx*gaugebytes, 4*R[dir]*slice_3d[dir]*gaugebytes/2); \
// Extended extents: each local extent X[i] grown by 2*R[i].
415 E1 = X[0]+2*R[0]; E2 = X[1]+2*R[1]; E3 = X[2]+2*R[2]; E4 = X[3]+2*R[3];
// Half of the extended 4-d volume; used in this file as the offset from an
// even-parity index to the matching odd-parity index.
423 int Vh_ex = E4*E3*E2*E1/2;
// Loop-bound tables indexed by exchange direction `dir`. The pack/unpack
// loops iterate a in [starta[dir], enda[dir]), b in [startb[dir], endb[dir])
// and c in [startc[dir], endc[dir]).
// Entries of the form {R[k], X[k]+R[k]} cover the range [R[k], X[k]+R[k]);
// entries of the form {0, X[k]+2*R[k]} cover the whole extended extent.
426 int starta[] = {R[3], R[3], R[3], 0};
427 int enda[] = {X[3]+R[3], X[3]+R[3], X[3]+R[3], X[2]+2*R[2]};
429 int startb[] = {R[2], R[2], 0, 0};
430 int endb[] = {X[2]+R[2], X[2]+R[2], X[1]+2*R[1], X[1]+2*R[1]};
432 int startc[] = {R[1], 0, 0, 0};
433 int endc[] = {X[1]+R[1], X[0]+2*R[0], X[0]+2*R[0], X[0]+2*R[0]};
443 {E3E2,
E2, 1, E4E3E2},
444 {E3E1,
E1, 1, E4E3E1},
// slice_3d[i]: the E-product of the three extended extents other than i
// (the E4E3E2-style names are defined on lines not visible here).
449 int slice_3d[] = { E4E3E2, E4E3E1, E4E2E1, E3E2E1};
451 for(
int i=0;i < 4;i++){
// Ghost-buffer size in bytes for direction i:
// slice_3d[i] * R[i] sites, times 4 * gaugeSiteSize * gPrecision per site.
// NOTE(review): the factor 4 presumably counts link directions — confirm.
452 len[i] = slice_3d[i] * R[i] * 4*gaugeSiteSize*gPrecision;
// Per-direction send and receive ghost buffers for this call.
455 void* ghost_sitelink_fwd_sendbuf[4];
456 void* ghost_sitelink_back_sendbuf[4];
457 void* ghost_sitelink_fwd[4];
458 void* ghost_sitelink_back[4];
460 for(
int i=0; i<4; i++) {
// Send buffers are allocated at the full len[i]; they are released with
// host_free() in this file. (Allocation of the receive buffers is not
// visible here.)
462 ghost_sitelink_fwd_sendbuf[i] =
safe_malloc(len[i]);
463 ghost_sitelink_back_sendbuf[i] =
safe_malloc(len[i]);
// Bytes per single link matrix; used by the MEMCOPY_GAUGE_FIELDS_* macros.
468 int gaugebytes = gaugeSiteSize*gPrecision;
475 for(d=R[
dir]; d < 2*R[
dir]; d++)
476 for(a=starta[
dir];a < enda[
dir]; a++)
477 for(b=startb[
dir]; b < endb[
dir]; b++)
479 if(f_main[
dir][2] != 1 || f_bound[
dir][2] !=1){
480 for (c=startc[
dir]; c < endc[
dir]; c++){
481 int oddness = (a+b+c+d)%2;
482 int src_idx = ( a*f_main[
dir][0] + b*f_main[
dir][1]+ c*f_main[
dir][2] + d*f_main[
dir][3])>> 1;
483 int dst_idx = ( a*f_bound[
dir][0] + b*f_bound[
dir][1]+ c*f_bound[
dir][2] + (d-R[
dir])*f_bound[
dir][3])>> 1;
485 int src_oddness = oddness;
486 int dst_oddness = oddness;
487 if((X[dir] % 2 ==1) && (
commDim(dir) > 1)){
488 dst_oddness = 1-oddness;
491 MEMCOPY_GAUGE_FIELDS_GRID_TO_BUF(ghost_sitelink_back_sendbuf, dst_idx, sitelink, src_idx, 1, dir);
495 for(
int loop=0; loop < 2; loop++){
498 int oddness = (a+b+c+d)%2;
499 int src_idx = ( a*f_main[
dir][0] + b*f_main[
dir][1]+ c*f_main[
dir][2] + d*f_main[
dir][3])>> 1;
500 int dst_idx = ( a*f_bound[
dir][0] + b*f_bound[
dir][1]+ c*f_bound[
dir][2] + (d-R[
dir])*f_bound[dir][3])>> 1;
502 int src_oddness = oddness;
503 int dst_oddness = oddness;
504 if((X[dir] % 2 ==1) && (
commDim(dir) > 1)){
505 dst_oddness = 1-oddness;
507 MEMCOPY_GAUGE_FIELDS_GRID_TO_BUF(ghost_sitelink_back_sendbuf, dst_idx, sitelink, src_idx, (endc[dir]-c+1)/2, dir);
515 for(d=X[dir]; d < X[
dir]+R[
dir]; d++) {
516 for(a=starta[dir];a < enda[
dir]; a++) {
517 for(b=startb[dir]; b < endb[
dir]; b++) {
519 if(f_main[dir][2] != 1 || f_bound[dir][2] !=1){
520 for (c=startc[dir]; c < endc[
dir]; c++){
521 int oddness = (a+b+c+d)%2;
522 int src_idx = ( a*f_main[
dir][0] + b*f_main[
dir][1]+ c*f_main[
dir][2] + d*f_main[
dir][3])>> 1;
523 int dst_idx = ( a*f_bound[
dir][0] + b*f_bound[
dir][1]+ c*f_bound[
dir][2] + (d-X[
dir])*f_bound[dir][3])>> 1;
525 int src_oddness = oddness;
526 int dst_oddness = oddness;
527 if((X[dir] % 2 ==1) && (
commDim(dir) > 1)){
528 dst_oddness = 1-oddness;
531 MEMCOPY_GAUGE_FIELDS_GRID_TO_BUF(ghost_sitelink_fwd_sendbuf, dst_idx, sitelink, src_idx, 1,dir);
534 for(
int loop=0; loop < 2; loop++){
537 int oddness = (a+b+c+d)%2;
538 int src_idx = ( a*f_main[
dir][0] + b*f_main[
dir][1]+ c*f_main[
dir][2] + d*f_main[
dir][3])>> 1;
539 int dst_idx = ( a*f_bound[
dir][0] + b*f_bound[
dir][1]+ c*f_bound[
dir][2] + (d-X[
dir])*f_bound[dir][3])>> 1;
541 int src_oddness = oddness;
542 int dst_oddness = oddness;
543 if((X[dir] % 2 ==1) && (
commDim(dir) > 1)){
544 dst_oddness = 1-oddness;
546 MEMCOPY_GAUGE_FIELDS_GRID_TO_BUF(ghost_sitelink_fwd_sendbuf, dst_idx, sitelink, src_idx, (endc[dir]-c+1)/2,dir);
586 for(d=0; d < R[
dir]; d++) {
587 for(a=starta[dir];a < enda[
dir]; a++) {
588 for(b=startb[dir]; b < endb[
dir]; b++) {
590 if(f_main[dir][2] != 1 || f_bound[dir][2] !=1){
591 for (c=startc[dir]; c < endc[
dir]; c++){
592 int oddness = (a+b+c+d)%2;
593 int dst_idx = ( a*f_main[
dir][0] + b*f_main[
dir][1]+ c*f_main[
dir][2] + d*f_main[
dir][3])>> 1;
596 src_idx = ( a*f_bound[
dir][0] + b*f_bound[
dir][1]+ c*f_bound[
dir][2] + d*f_bound[
dir][3])>> 1;
598 src_idx = ( a*f_main[
dir][0] + b*f_main[
dir][1]+ c*f_main[
dir][2] + (d+X[
dir])*f_main[dir][3])>> 1;
601 MEMCOPY_GAUGE_FIELDS_BUF_TO_GRID(sitelink, dst_idx, ghost_sitelink_back, src_idx, 1, dir);
608 for(
int loop =0;loop <2;loop++){
609 int c=startc[
dir]+loop;
611 int oddness = (a+b+c+d)%2;
612 int dst_idx = ( a*f_main[
dir][0] + b*f_main[
dir][1]+ c*f_main[
dir][2] + d*f_main[
dir][3])>> 1;
615 src_idx = ( a*f_bound[
dir][0] + b*f_bound[
dir][1]+ c*f_bound[
dir][2] + d*f_bound[
dir][3])>> 1;
617 src_idx = ( a*f_main[
dir][0] + b*f_main[
dir][1]+ c*f_main[
dir][2] + (d+X[
dir])*f_main[dir][3])>> 1;
620 MEMCOPY_GAUGE_FIELDS_BUF_TO_GRID(sitelink, dst_idx, ghost_sitelink_back, src_idx, (endc[dir]-c+1)/2, dir);
633 MEMCOPY_GAUGE_FIELDS_BUF_TO_GRID_T(sitelink, ghost_sitelink_back, 0, X[3], dir)
639 for(d=X[dir]+R[dir]; d < X[
dir]+2*R[
dir]; d++) {
640 for(a=starta[dir];a < enda[
dir]; a++) {
641 for(b=startb[dir]; b < endb[
dir]; b++) {
643 if(f_main[dir][2] != 1 || f_bound[dir][2] != 1){
644 for (c=startc[dir]; c < endc[
dir]; c++){
645 int oddness = (a+b+c+d)%2;
646 int dst_idx = ( a*f_main[
dir][0] + b*f_main[
dir][1]+ c*f_main[
dir][2] + d*f_main[
dir][3])>> 1;
649 src_idx = ( a*f_bound[
dir][0] + b*f_bound[
dir][1]+ c*f_bound[
dir][2] + (d-X[
dir]-R[
dir])*f_bound[dir][3])>> 1;
651 src_idx = ( a*f_main[
dir][0] + b*f_main[
dir][1]+ c*f_main[
dir][2] + (d-X[
dir])*f_main[dir][3])>> 1;
654 MEMCOPY_GAUGE_FIELDS_BUF_TO_GRID(sitelink, dst_idx, ghost_sitelink_fwd, src_idx, 1, dir);
658 for(
int loop =0; loop < 2; loop++){
660 c=startc[
dir] + loop;
662 int oddness = (a+b+c+d)%2;
663 int dst_idx = ( a*f_main[
dir][0] + b*f_main[
dir][1]+ c*f_main[
dir][2] + d*f_main[
dir][3])>> 1;
666 src_idx = ( a*f_bound[
dir][0] + b*f_bound[
dir][1]+ c*f_bound[
dir][2] + (d-X[
dir]-R[
dir])*f_bound[dir][3])>> 1;
668 src_idx = ( a*f_main[
dir][0] + b*f_main[
dir][1]+ c*f_main[
dir][2] + (d-X[
dir])*f_main[dir][3])>> 1;
670 MEMCOPY_GAUGE_FIELDS_BUF_TO_GRID(sitelink, dst_idx, ghost_sitelink_fwd, src_idx, (endc[dir]-c+1)/2, dir);
683 MEMCOPY_GAUGE_FIELDS_BUF_TO_GRID_T(sitelink, ghost_sitelink_fwd, (X[3]+R[3]), 2, dir)
// Release the per-direction ghost send buffers obtained with safe_malloc().
690 for(
int dir=0;dir < 4;dir++){
692 host_free(ghost_sitelink_fwd_sendbuf[dir]);
693 host_free(ghost_sitelink_back_sendbuf[dir]);
702 template<
typename Float>
704 do_exchange_cpu_staple(
Float* staple,
Float** ghost_staple,
Float** staple_fwd_sendbuf,
Float** staple_back_sendbuf,
int* X)
// Backward face in direction index 3: the even-parity source is the start of
// `staple`; data is packed into staple_back_sendbuf[3]. (The odd-parity
// source pointer and `len` are defined on lines not visible here.)
710 Float* even_staple_back_src = staple;
712 Float* staple_back_dst = staple_back_sendbuf[3];
// When dims[3] is even, the even-parity slab goes first and the odd-parity
// slab follows at an offset of Vsh_t*gaugeSiteSize Floats.
714 if(dims[3] % 2 == 0){
715 memcpy(staple_back_dst, even_staple_back_src, len);
716 memcpy(staple_back_dst +
Vsh_t*gaugeSiteSize, odd_staple_back_src, len);
// Parity order swapped.
// NOTE(review): presumably the `else` branch (odd dims[3]); the `else` line
// is not visible — confirm.
719 memcpy(staple_back_dst, odd_staple_back_src, len);
720 memcpy(staple_back_dst +
Vsh_t*gaugeSiteSize, even_staple_back_src, len);
// Forward face: each source points Vsh_t sites before the end of a
// volumeCB-site region; the odd-parity region is offset by
// volumeCB*gaugeSiteSize Floats from the even one.
724 Float* even_staple_fwd_src = staple + (volumeCB -
Vsh_t)*gaugeSiteSize;
725 Float* odd_staple_fwd_src = staple + volumeCB*gaugeSiteSize + (volumeCB -
Vsh_t)*gaugeSiteSize;
726 Float* staple_fwd_dst = staple_fwd_sendbuf[3];
// Same even/odd ordering rule as for the backward face.
727 if(dims[3] % 2 == 0){
728 memcpy(staple_fwd_dst, even_staple_fwd_src, len);
729 memcpy(staple_fwd_dst +
Vsh_t*gaugeSiteSize, odd_staple_fwd_src, len);
// Parity order swapped (presumably the `else` branch — confirm).
732 memcpy(staple_fwd_dst, odd_staple_fwd_src, len);
733 memcpy(staple_fwd_dst +
Vsh_t*gaugeSiteSize, even_staple_fwd_src, len);
750 for (
int dir=0;dir < 4; dir++) {
752 Float *ghost_staple_back = ghost_staple[
dir];
// Per-direction temporary send buffers for the CPU staple exchange.
789 void *staple_fwd_sendbuf[4];
790 void *staple_back_sendbuf[4];
792 for(
int i=0;i < 4; i++){
// Each buffer holds Vs[i] sites of gaugeSiteSize values at gPrecision bytes.
// NOTE(review): the product is evaluated in `int`; verify it cannot overflow
// for the largest supported local lattice.
793 staple_fwd_sendbuf[i] =
safe_malloc(Vs[i]*gaugeSiteSize*gPrecision);
794 staple_back_sendbuf[i] =
safe_malloc(Vs[i]*gaugeSiteSize*gPrecision);
// Dispatch to the templated do_exchange_cpu_staple, once with the buffers
// viewed as double and once as float. The selecting condition is not
// visible here (presumably based on gPrecision — confirm).
798 do_exchange_cpu_staple((
double*)staple, (
double**)ghost_staple,
799 (
double**)staple_fwd_sendbuf, (
double**)staple_back_sendbuf, X);
801 do_exchange_cpu_staple((
float*)staple, (
float**)ghost_staple,
802 (
float**)staple_fwd_sendbuf, (
float**)staple_back_sendbuf, X);
805 for (
int i=0;i < 4;i++) {
821 void* even = cudaStaple->
Even_p();
822 void* odd = cudaStaple->
Odd_p();
823 int volume = cudaStaple->
VolumeCB();
828 dir, whichway, fwd_nbr_staple_gpu, back_nbr_staple_gpu,
829 fwd_nbr_staple_sendbuf, back_nbr_staple_sendbuf, stream);
838 cudaStreamSynchronize(*stream);
842 int len = Vs[dim]*gaugeSiteSize*
prec;
851 memcpy(fwd_nbr_staple_sendbuf_cpu[dim], fwd_nbr_staple_sendbuf[dim], len);
865 memcpy(back_nbr_staple_sendbuf_cpu[dim], back_nbr_staple_sendbuf[dim], len);
883 void* even = cudaStaple->
Even_p();
884 void* odd = cudaStaple->
Odd_p();
885 int volume = cudaStaple->
VolumeCB();
887 int stride = cudaStaple->
Stride();
892 int len = Vs[dim]*gaugeSiteSize*
prec;
907 memcpy(back_nbr_staple[dim], back_nbr_staple_cpu[dim], len);
922 dim,
QUDA_FORWARDS, fwd_nbr_staple, back_nbr_staple, stream);
924 memcpy(fwd_nbr_staple[dim], fwd_nbr_staple_cpu[dim], len);
926 dim,
QUDA_FORWARDS, fwd_nbr_staple, back_nbr_staple, stream);
// Release every file-static staple exchange buffer, direction by direction.
// Each pointer is freed only when non-NULL and is reset to NULL afterwards,
// so the guard skips it if this loop runs again.
935 for (
int i=0; i<4; i++) {
// *_gpu buffers are released with device_free().
937 if(fwd_nbr_staple_gpu[i]){
938 device_free(fwd_nbr_staple_gpu[i]); fwd_nbr_staple_gpu[i] = NULL;
940 if(back_nbr_staple_gpu[i]){
941 device_free(back_nbr_staple_gpu[i]); back_nbr_staple_gpu[i] = NULL;
// *_cpu receive and send buffers are released with host_free().
945 if(fwd_nbr_staple_cpu[i]){
946 host_free(fwd_nbr_staple_cpu[i]); fwd_nbr_staple_cpu[i] = NULL;
948 if(back_nbr_staple_cpu[i]){
949 host_free(back_nbr_staple_cpu[i]);back_nbr_staple_cpu[i] = NULL;
951 if(fwd_nbr_staple_sendbuf_cpu[i]){
952 host_free(fwd_nbr_staple_sendbuf_cpu[i]); fwd_nbr_staple_sendbuf_cpu[i] = NULL;
954 if(back_nbr_staple_sendbuf_cpu[i]){
955 host_free(back_nbr_staple_sendbuf_cpu[i]); back_nbr_staple_sendbuf_cpu[i] = NULL;
// The un-suffixed buffers are also released with host_free().
959 if(fwd_nbr_staple[i]){
960 host_free(fwd_nbr_staple[i]); fwd_nbr_staple[i] = NULL;
962 if(back_nbr_staple[i]){
963 host_free(back_nbr_staple[i]); back_nbr_staple[i] = NULL;
965 if(fwd_nbr_staple_sendbuf[i]){
966 host_free(fwd_nbr_staple_sendbuf[i]); fwd_nbr_staple_sendbuf[i] = NULL;
968 if(back_nbr_staple_sendbuf[i]){
969 host_free(back_nbr_staple_sendbuf[i]); back_nbr_staple_sendbuf[i] = NULL;