QUDA  v0.5.0
A library for QCD on GPUs
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
face_gauge.cpp
Go to the documentation of this file.
1 #include <cstdio>
2 #include <cstdlib>
3 #include <string.h>
4 #include <sys/time.h>
5 
6 #include <quda_internal.h>
7 #include <comm_quda.h>
8 #include <fat_force_quda.h>
9 #include <face_quda.h>
10 
11 using namespace quda;
12 
13 extern cudaStream_t *stream;
14 
15 /**************************************************************
16  * Staple exchange routine
17  * used in fat link computation
18  ***************************************************************/
19 #if defined(MULTI_GPU) && (defined(GPU_FATLINK) || defined(GPU_GAUGE_FORCE)|| defined(GPU_FERMION_FORCE) || defined(GPU_HISQ_FORCE))
20 
/* Direction labels: the four "UP" hops (X,Y,Z,T) followed by the four
 * "DOWN" hops in reverse order, so that UP + DOWN == 7 for every axis. */
enum {
  XUP = 0,
  YUP,
  ZUP,
  TUP,
  TDOWN,
  ZDOWN,
  YDOWN,
  XDOWN
};
31 
32 #define gaugeSiteSize 18
33 
34 #ifndef GPU_DIRECT
35 static void* fwd_nbr_staple_cpu[4];
36 static void* back_nbr_staple_cpu[4];
37 static void* fwd_nbr_staple_sendbuf_cpu[4];
38 static void* back_nbr_staple_sendbuf_cpu[4];
39 #endif
40 
41 static void* fwd_nbr_staple_gpu[4];
42 static void* back_nbr_staple_gpu[4];
43 
44 static void* fwd_nbr_staple[4];
45 static void* back_nbr_staple[4];
46 static void* fwd_nbr_staple_sendbuf[4];
47 static void* back_nbr_staple_sendbuf[4];
48 
49 static int dims[4];
50 static int X1,X2,X3,X4;
51 static int V;
52 static int volumeCB;
53 static int Vs[4], Vsh[4];
54 static int Vs_x, Vs_y, Vs_z, Vs_t;
55 static int Vsh_x, Vsh_y, Vsh_z, Vsh_t;
56 
57 static struct {
58  MsgHandle *fwd[4];
59  MsgHandle *back[4];
60 } llfat_recv, llfat_send;
61 
62 #include "gauge_field.h"
63 extern void setup_dims_in_gauge(int *XX);
64 
65 static void
66 setup_dims(int* X)
67 {
68  V = 1;
69  for (int d=0; d< 4; d++) {
70  V *= X[d];
71  dims[d] = X[d];
72  }
73  volumeCB = V/2;
74 
75  X1=X[0];
76  X2=X[1];
77  X3=X[2];
78  X4=X[3];
79 
80  Vs[0] = Vs_x = X[1]*X[2]*X[3];
81  Vs[1] = Vs_y = X[0]*X[2]*X[3];
82  Vs[2] = Vs_z = X[0]*X[1]*X[3];
83  Vs[3] = Vs_t = X[0]*X[1]*X[2];
84 
85  Vsh[0] = Vsh_x = Vs_x/2;
86  Vsh[1] = Vsh_y = Vs_y/2;
87  Vsh[2] = Vsh_z = Vs_z/2;
88  Vsh[3] = Vsh_t = Vs_t/2;
89 }
90 
91 
93 {
94  static bool initialized = false;
95 
96  if (initialized) return;
97  initialized = true;
98 
99  for (int i=0; i < 4; i++) {
100 
101  size_t packet_size = Vs[i]*gaugeSiteSize*prec;
102 
103  fwd_nbr_staple_gpu[i] = device_malloc(packet_size);
104  back_nbr_staple_gpu[i] = device_malloc(packet_size);
105 
106  fwd_nbr_staple[i] = pinned_malloc(packet_size);
107  back_nbr_staple[i] = pinned_malloc(packet_size);
108  fwd_nbr_staple_sendbuf[i] = pinned_malloc(packet_size);
109  back_nbr_staple_sendbuf[i] = pinned_malloc(packet_size);
110 
111 #ifndef GPU_DIRECT
112  fwd_nbr_staple_cpu[i] = safe_malloc(packet_size);
113  back_nbr_staple_cpu[i] = safe_malloc(packet_size);
114  fwd_nbr_staple_sendbuf_cpu[i] = safe_malloc(packet_size);
115  back_nbr_staple_sendbuf_cpu[i] = safe_malloc(packet_size);
116 #endif
117 
118  }
119 }
120 
121 
122 template<typename Float>
123 void exchange_sitelink_diag(int* X, Float** sitelink, Float** ghost_sitelink_diag, int optflag)
124 {
125  /*
126  nu | |
127  |__________|
128  mu
129 
130  * There are total 12 different combinations for (nu,mu)
131  * since nu/mu = X,Y,Z,T and nu != mu
132  * For each combination, we need to communicate with the corresponding
133  * neighbor and get the diag ghost data
134  * The neighbor we need to get data from is dx[nu]=-1, dx[mu]= +1
135  * and we need to send our data to neighbor with dx[nu]=+1, dx[mu]=-1
136  */
137 
138  for(int nu = XUP; nu <=TUP; nu++){
139  for(int mu = XUP; mu <= TUP; mu++){
140  if(nu == mu){
141  continue;
142  }
143  if(optflag && (!commDimPartitioned(mu) || !commDimPartitioned(nu))){
144  continue;
145  }
146 
147  int dir1, dir2; //other two dimensions
148  for(dir1=0; dir1 < 4; dir1 ++){
149  if(dir1 != nu && dir1 != mu){
150  break;
151  }
152  }
153  for(dir2=0; dir2 < 4; dir2 ++){
154  if(dir2 != nu && dir2 != mu && dir2 != dir1){
155  break;
156  }
157  }
158 
159  if(dir1 == 4 || dir2 == 4){
160  errorQuda("Invalid dir1/dir2");
161  }
162  int len = X[dir1]*X[dir2]*gaugeSiteSize*sizeof(Float);
163  void *sendbuf = safe_malloc(len);
164 
165  pack_gauge_diag(sendbuf, X, (void**)sitelink, nu, mu, dir1, dir2, (QudaPrecision)sizeof(Float));
166 
167  int dx[4] = {0};
168  dx[nu] = -1;
169  dx[mu] = +1;
170  MsgHandle *mh_recv = comm_declare_receive_displaced(ghost_sitelink_diag[nu*4+mu], dx, len);
171  comm_start(mh_recv);
172 
173  dx[nu] = +1;
174  dx[mu] = -1;
175  MsgHandle *mh_send = comm_declare_send_displaced(sendbuf, dx, len);
176  comm_start(mh_send);
177 
178  comm_wait(mh_send);
179  comm_wait(mh_recv);
180 
181  comm_free(mh_send);
182  comm_free(mh_recv);
183 
184  host_free(sendbuf);
185  }
186  }
187 }
188 
189 
190 template<typename Float>
191 void
192 exchange_sitelink(int*X, Float** sitelink, Float** ghost_sitelink, Float** ghost_sitelink_diag,
193  Float** sitelink_fwd_sendbuf, Float** sitelink_back_sendbuf, int optflag)
194 {
195 
196 
197 #if 0
198  int i;
199  int len = Vsh_t*gaugeSiteSize*sizeof(Float);
200  for(i=0;i < 4;i++){
201  Float* even_sitelink_back_src = sitelink[i];
202  Float* odd_sitelink_back_src = sitelink[i] + volumeCB*gaugeSiteSize;
203  Float* sitelink_back_dst = sitelink_back_sendbuf[3] + 2*i*Vsh_t*gaugeSiteSize;
204 
205  if(dims[3] % 2 == 0){
206  memcpy(sitelink_back_dst, even_sitelink_back_src, len);
207  memcpy(sitelink_back_dst + Vsh_t*gaugeSiteSize, odd_sitelink_back_src, len);
208  }else{
209  //switching odd and even ghost sitelink
210  memcpy(sitelink_back_dst, odd_sitelink_back_src, len);
211  memcpy(sitelink_back_dst + Vsh_t*gaugeSiteSize, even_sitelink_back_src, len);
212  }
213  }
214 
215  for(i=0;i < 4;i++){
216  Float* even_sitelink_fwd_src = sitelink[i] + (volumeCB - Vsh_t)*gaugeSiteSize;
217  Float* odd_sitelink_fwd_src = sitelink[i] + volumeCB*gaugeSiteSize + (volumeCB - Vsh_t)*gaugeSiteSize;
218  Float* sitelink_fwd_dst = sitelink_fwd_sendbuf[3] + 2*i*Vsh_t*gaugeSiteSize;
219  if(dims[3] % 2 == 0){
220  memcpy(sitelink_fwd_dst, even_sitelink_fwd_src, len);
221  memcpy(sitelink_fwd_dst + Vsh_t*gaugeSiteSize, odd_sitelink_fwd_src, len);
222  }else{
223  //switching odd and even ghost sitelink
224  memcpy(sitelink_fwd_dst, odd_sitelink_fwd_src, len);
225  memcpy(sitelink_fwd_dst + Vsh_t*gaugeSiteSize, even_sitelink_fwd_src, len);
226  }
227 
228  }
229 #else
230  int nFace =1;
231  for(int dir=0; dir < 4; dir++){
232  if(optflag && !commDimPartitioned(dir)) continue;
233  pack_ghost_all_links((void**)sitelink, (void**)sitelink_back_sendbuf, (void**)sitelink_fwd_sendbuf, dir, nFace, (QudaPrecision)(sizeof(Float)), X);
234  }
235 #endif
236 
237  for (int dir = 0; dir < 4; dir++) {
238  if(optflag && !commDimPartitioned(dir)) continue;
239  int len = Vsh[dir]*gaugeSiteSize*sizeof(Float);
240  Float* ghost_sitelink_back = ghost_sitelink[dir];
241  Float* ghost_sitelink_fwd = ghost_sitelink[dir] + 8*Vsh[dir]*gaugeSiteSize;
242 
243  MsgHandle *mh_recv_back;
244  MsgHandle *mh_recv_fwd;
245  MsgHandle *mh_send_fwd;
246  MsgHandle *mh_send_back;
247 
248  mh_recv_back = comm_declare_receive_relative(ghost_sitelink_back, dir, -1, 8*len);
249  mh_recv_fwd = comm_declare_receive_relative(ghost_sitelink_fwd, dir, +1, 8*len);
250  mh_send_fwd = comm_declare_send_relative(sitelink_fwd_sendbuf[dir], dir, +1, 8*len);
251  mh_send_back = comm_declare_send_relative(sitelink_back_sendbuf[dir], dir, -1, 8*len);
252 
253  comm_start(mh_recv_back);
254  comm_start(mh_recv_fwd);
255  comm_start(mh_send_fwd);
256  comm_start(mh_send_back);
257 
258  comm_wait(mh_send_fwd);
259  comm_wait(mh_send_back);
260  comm_wait(mh_recv_back);
261  comm_wait(mh_recv_fwd);
262 
263  comm_free(mh_send_fwd);
264  comm_free(mh_send_back);
265  comm_free(mh_recv_back);
266  comm_free(mh_recv_fwd);
267  }
268 
269  exchange_sitelink_diag(X, sitelink, ghost_sitelink_diag, optflag);
270 }
271 
272 
273 //this function is used for link fattening computation
274 //@optflag: if this flag is set, we only communicate in directions that are partitioned
275 // if not set, then we communicate in all directions regradless of partitions
276 void exchange_cpu_sitelink(int* X,
277  void** sitelink, void** ghost_sitelink,
278  void** ghost_sitelink_diag,
279  QudaPrecision gPrecision, QudaGaugeParam* param, int optflag)
280 {
281  setup_dims(X);
282  static void* sitelink_fwd_sendbuf[4];
283  static void* sitelink_back_sendbuf[4];
284  static bool allocated = false;
285 
286  if (!allocated) {
287  for (int i=0; i<4; i++) {
288  int nbytes = 4*Vs[i]*gaugeSiteSize*gPrecision;
289  sitelink_fwd_sendbuf[i] = safe_malloc(nbytes);
290  sitelink_back_sendbuf[i] = safe_malloc(nbytes);
291  memset(sitelink_fwd_sendbuf[i], 0, nbytes);
292  memset(sitelink_back_sendbuf[i], 0, nbytes);
293  }
294  allocated = true;
295  }
296 
297  if (gPrecision == QUDA_DOUBLE_PRECISION){
298  exchange_sitelink(X, (double**)sitelink, (double**)(ghost_sitelink), (double**)ghost_sitelink_diag,
299  (double**)sitelink_fwd_sendbuf, (double**)sitelink_back_sendbuf, optflag);
300  }else{ //single
301  exchange_sitelink(X, (float**)sitelink, (float**)(ghost_sitelink), (float**)ghost_sitelink_diag,
302  (float**)sitelink_fwd_sendbuf, (float**)sitelink_back_sendbuf, optflag);
303  }
304 
306  for(int i=0;i < 4;i++){
307  host_free(sitelink_fwd_sendbuf[i]);
308  host_free(sitelink_back_sendbuf[i]);
309  }
310  allocated = false;
311  }
312 }
313 
314 
/* Copy `num` consecutive sites' links from the extended field `sitelink`
 * (starting at checkerboard index src_idx) into the send buffer
 * ghost_buf[dir] (starting at dst_idx).  src_idx and dst_idx are modified in
 * place and therefore must be lvalues.
 *
 * Besides its arguments the macro reads these names from the expansion
 * scope: src_oddness, dst_oddness, Vh_ex, R, slice_3d, gaugebytes, cpu_order.
 */
#define MEMCOPY_GAUGE_FIELDS_GRID_TO_BUF(ghost_buf, dst_idx, sitelink, src_idx, num, dir) \
  /* move into the odd-parity half of the source / of the buffer */     \
  if (src_oddness) {                                                    \
    src_idx += Vh_ex;                                                   \
  }                                                                     \
  if (dst_oddness) {                                                    \
    dst_idx += R[dir]*slice_3d[dir]/2;                                  \
  }                                                                     \
  if (cpu_order != QUDA_QDP_GAUGE_ORDER) { /* QUDA_MILC_GAUGE_ORDER */  \
    char* pack_from = ((char*)sitelink)+ 4*(src_idx)*gaugebytes;        \
    char* pack_to = ((char*)ghost_buf[dir]) + 4*(dst_idx)*gaugebytes;   \
    memcpy(pack_to, pack_from, 4*gaugebytes*(num));                     \
  } else { /* QUDA_QDP_GAUGE_ORDER: one array per link direction */     \
    for (int pack_mu = 0; pack_mu < 4; pack_mu++) {                     \
      char* pack_from = (char*) sitelink[pack_mu] + (src_idx)*gaugebytes; \
      char* pack_to = ((char*)ghost_buf[dir])+ pack_mu*R[dir]*slice_3d[dir]*gaugebytes + (dst_idx)*gaugebytes; \
      memcpy(pack_to, pack_from, gaugebytes*(num));                     \
    }                                                                   \
  }
333 
/* Copy `num` consecutive sites' links into the extended field `sitelink`
 * (starting at checkerboard index dst_idx).  When dimension `dir` is
 * partitioned the source is the receive buffer ghost_buf[dir]; otherwise the
 * source is `sitelink` itself, with src_idx interpreted as an index into the
 * field.  src_idx and dst_idx are modified in place and must be lvalues.
 *
 * Besides its arguments the macro reads these names from the expansion
 * scope: oddness, Vh_ex, R, slice_3d, gaugebytes, cpu_order.
 */
#define MEMCOPY_GAUGE_FIELDS_BUF_TO_GRID(sitelink, dst_idx, ghost_buf, src_idx, num, dir) \
  if (oddness) {                                                        \
    /* odd-parity offset differs between buffer and field layout */     \
    src_idx += commDimPartitioned(dir) ? R[dir]*slice_3d[dir]/2 : Vh_ex; \
    dst_idx += Vh_ex;                                                   \
  }                                                                     \
  if (cpu_order == QUDA_QDP_GAUGE_ORDER) {                              \
    for (int unpack_mu = 0; unpack_mu < 4; unpack_mu++) {               \
      char* unpack_from = commDimPartitioned(dir)                       \
        ? ((char*)ghost_buf[dir])+ unpack_mu*R[dir]*slice_3d[dir]*gaugebytes + (src_idx)*gaugebytes \
        : ((char*)sitelink[unpack_mu])+ (src_idx)*gaugebytes;           \
      char* unpack_to = (char*) sitelink[unpack_mu] + (dst_idx)*gaugebytes; \
      memcpy(unpack_to, unpack_from, gaugebytes*(num));                 \
    }                                                                   \
  } else { /* QUDA_MILC_GAUGE_ORDER */                                  \
    char* unpack_from = commDimPartitioned(dir)                         \
      ? ((char*)ghost_buf[dir]) + 4*(src_idx)*gaugebytes                \
      : ((char*)sitelink)+ 4*(src_idx)*gaugebytes;                      \
    char* unpack_to = ((char*)sitelink) + 4*(dst_idx)*gaugebytes;       \
    memcpy(unpack_to, unpack_from, 4*gaugebytes*(num));                 \
  }
364 
/* Bulk version of the buffer-to-field copy: fills R[dir] whole slices of the
 * extended field starting at slice `dst_face` with two large copies (even
 * half, odd half) per link array.  When `dir` is partitioned the data comes
 * from ghost_buf[dir]; otherwise from the slices of `sitelink` starting at
 * `src_face`.
 *
 * The macro declares local variables and is therefore meant to be expanded
 * once per enclosing block.  Besides its arguments it reads these names from
 * the expansion scope: E3E2E1, Vh_ex, R, slice_3d, gaugebytes, cpu_order.
 */
#define MEMCOPY_GAUGE_FIELDS_BUF_TO_GRID_T(sitelink, ghost_buf, dst_face, src_face, dir) \
  /* first even / odd checkerboard index on the destination side */     \
  int t_dst_even = ((dst_face)*E3E2E1)/2;                               \
  int t_dst_odd = t_dst_even+Vh_ex;                                     \
  /* ... and on the source side (buffer or field) */                    \
  int t_src_even = commDimPartitioned(dir) ? 0 : ((src_face)*E3E2E1)/2; \
  int t_src_odd = commDimPartitioned(dir) ? R[dir]*slice_3d[dir]/2 : t_src_even+Vh_ex; \
  if (cpu_order == QUDA_QDP_GAUGE_ORDER) {                              \
    for (int t_mu = 0; t_mu < 4; t_mu++) {                              \
      char* t_to = (char*)sitelink[t_mu];                               \
      char* t_from = commDimPartitioned(dir)                            \
        ? ((char*)ghost_buf[dir]) + t_mu*R[dir]*slice_3d[dir]*gaugebytes \
        : (char*)sitelink[t_mu];                                        \
      memcpy(t_to + t_dst_even * gaugebytes, t_from + t_src_even*gaugebytes, R[dir]*slice_3d[dir]*gaugebytes/2); \
      memcpy(t_to + t_dst_odd * gaugebytes, t_from + t_src_odd*gaugebytes, R[dir]*slice_3d[dir]*gaugebytes/2); \
    }                                                                   \
  } else { /* QUDA_MILC_GAUGE_ORDER */                                  \
    char* t_to = (char*)sitelink;                                       \
    char* t_from = commDimPartitioned(dir) ? (char*)ghost_buf[dir] : (char*)sitelink; \
    memcpy(t_to+4*t_dst_even*gaugebytes, t_from+4*t_src_even*gaugebytes, 4*R[dir]*slice_3d[dir]*gaugebytes/2); \
    memcpy(t_to+4*t_dst_odd*gaugebytes, t_from+4*t_src_odd*gaugebytes, 4*R[dir]*slice_3d[dir]*gaugebytes/2); \
  }
405 
406 /* This function exchange the sitelink and store them in the correspoinding portion of
407  * the extended sitelink memory region
408  * @sitelink: this is stored according to dimension size (X4+R4) * (X1+R1) * (X2+R2) * (X3+R3)
409  */
410 
411 void exchange_cpu_sitelink_ex(int* X, int *R, void** sitelink, QudaGaugeFieldOrder cpu_order,
412  QudaPrecision gPrecision, int optflag)
413 {
414  int E1,E2,E3,E4;
415  E1 = X[0]+2*R[0]; E2 = X[1]+2*R[1]; E3 = X[2]+2*R[2]; E4 = X[3]+2*R[3];
416  int E3E2E1=E3*E2*E1;
417  int E2E1=E2*E1;
418  int E4E3E2=E4*E3*E2;
419  int E3E2=E3*E2;
420  int E4E3E1=E4*E3*E1;
421  int E3E1=E3*E1;
422  int E4E2E1=E4*E2*E1;
423  int Vh_ex = E4*E3*E2*E1/2;
424 
425  //...............x.........y.....z......t
426  int starta[] = {R[3], R[3], R[3], 0};
427  int enda[] = {X[3]+R[3], X[3]+R[3], X[3]+R[3], X[2]+2*R[2]};
428 
429  int startb[] = {R[2], R[2], 0, 0};
430  int endb[] = {X[2]+R[2], X[2]+R[2], X[1]+2*R[1], X[1]+2*R[1]};
431 
432  int startc[] = {R[1], 0, 0, 0};
433  int endc[] = {X[1]+R[1], X[0]+2*R[0], X[0]+2*R[0], X[0]+2*R[0]};
434 
435  int f_main[4][4] = {
436  {E3E2E1, E2E1, E1, 1},
437  {E3E2E1, E2E1, 1, E1},
438  {E3E2E1, E1, 1, E2E1},
439  {E2E1, E1, 1, E3E2E1}
440  };
441 
442  int f_bound[4][4]={
443  {E3E2, E2, 1, E4E3E2},
444  {E3E1, E1, 1, E4E3E1},
445  {E2E1, E1, 1, E4E2E1},
446  {E2E1, E1, 1, E3E2E1}
447  };
448 
449  int slice_3d[] = { E4E3E2, E4E3E1, E4E2E1, E3E2E1};
450  int len[4];
451  for(int i=0;i < 4;i++){
452  len[i] = slice_3d[i] * R[i] * 4*gaugeSiteSize*gPrecision; //2 slices, 4 directions' links
453  }
454 
455  void* ghost_sitelink_fwd_sendbuf[4];
456  void* ghost_sitelink_back_sendbuf[4];
457  void* ghost_sitelink_fwd[4];
458  void* ghost_sitelink_back[4];
459 
460  for(int i=0; i<4; i++) {
461  if(!commDimPartitioned(i)) continue;
462  ghost_sitelink_fwd_sendbuf[i] = safe_malloc(len[i]);
463  ghost_sitelink_back_sendbuf[i] = safe_malloc(len[i]);
464  ghost_sitelink_fwd[i] = safe_malloc(len[i]);
465  ghost_sitelink_back[i] = safe_malloc(len[i]);
466  }
467 
468  int gaugebytes = gaugeSiteSize*gPrecision;
469  int a, b, c,d;
470  for(int dir =0;dir < 4;dir++){
471  if( (!commDimPartitioned(dir)) && optflag) continue;
472  if(commDimPartitioned(dir)){
473  //fill the sendbuf here
474  //back
475  for(d=R[dir]; d < 2*R[dir]; d++)
476  for(a=starta[dir];a < enda[dir]; a++)
477  for(b=startb[dir]; b < endb[dir]; b++)
478 
479  if(f_main[dir][2] != 1 || f_bound[dir][2] !=1){
480  for (c=startc[dir]; c < endc[dir]; c++){
481  int oddness = (a+b+c+d)%2;
482  int src_idx = ( a*f_main[dir][0] + b*f_main[dir][1]+ c*f_main[dir][2] + d*f_main[dir][3])>> 1;
483  int dst_idx = ( a*f_bound[dir][0] + b*f_bound[dir][1]+ c*f_bound[dir][2] + (d-R[dir])*f_bound[dir][3])>> 1;
484 
485  int src_oddness = oddness;
486  int dst_oddness = oddness;
487  if((X[dir] % 2 ==1) && (commDim(dir) > 1)){ //switch even/odd position
488  dst_oddness = 1-oddness;
489  }
490 
491  MEMCOPY_GAUGE_FIELDS_GRID_TO_BUF(ghost_sitelink_back_sendbuf, dst_idx, sitelink, src_idx, 1, dir);
492 
493  }//c
494  }else{
495  for(int loop=0; loop < 2; loop++){
496  c=startc[dir]+loop;
497  if(c < endc[dir]){
498  int oddness = (a+b+c+d)%2;
499  int src_idx = ( a*f_main[dir][0] + b*f_main[dir][1]+ c*f_main[dir][2] + d*f_main[dir][3])>> 1;
500  int dst_idx = ( a*f_bound[dir][0] + b*f_bound[dir][1]+ c*f_bound[dir][2] + (d-R[dir])*f_bound[dir][3])>> 1;
501 
502  int src_oddness = oddness;
503  int dst_oddness = oddness;
504  if((X[dir] % 2 ==1) && (commDim(dir) > 1)){ //switch even/odd position
505  dst_oddness = 1-oddness;
506  }
507  MEMCOPY_GAUGE_FIELDS_GRID_TO_BUF(ghost_sitelink_back_sendbuf, dst_idx, sitelink, src_idx, (endc[dir]-c+1)/2, dir);
508 
509  }//if c
510  }//for loop
511  }//if
512 
513 
514  //fwd
515  for(d=X[dir]; d < X[dir]+R[dir]; d++) {
516  for(a=starta[dir];a < enda[dir]; a++) {
517  for(b=startb[dir]; b < endb[dir]; b++) {
518 
519  if(f_main[dir][2] != 1 || f_bound[dir][2] !=1){
520  for (c=startc[dir]; c < endc[dir]; c++){
521  int oddness = (a+b+c+d)%2;
522  int src_idx = ( a*f_main[dir][0] + b*f_main[dir][1]+ c*f_main[dir][2] + d*f_main[dir][3])>> 1;
523  int dst_idx = ( a*f_bound[dir][0] + b*f_bound[dir][1]+ c*f_bound[dir][2] + (d-X[dir])*f_bound[dir][3])>> 1;
524 
525  int src_oddness = oddness;
526  int dst_oddness = oddness;
527  if((X[dir] % 2 ==1) && (commDim(dir) > 1)){ //switch even/odd position
528  dst_oddness = 1-oddness;
529  }
530 
531  MEMCOPY_GAUGE_FIELDS_GRID_TO_BUF(ghost_sitelink_fwd_sendbuf, dst_idx, sitelink, src_idx, 1,dir);
532  }//c
533  }else{
534  for(int loop=0; loop < 2; loop++){
535  c=startc[dir]+loop;
536  if(c < endc[dir]){
537  int oddness = (a+b+c+d)%2;
538  int src_idx = ( a*f_main[dir][0] + b*f_main[dir][1]+ c*f_main[dir][2] + d*f_main[dir][3])>> 1;
539  int dst_idx = ( a*f_bound[dir][0] + b*f_bound[dir][1]+ c*f_bound[dir][2] + (d-X[dir])*f_bound[dir][3])>> 1;
540 
541  int src_oddness = oddness;
542  int dst_oddness = oddness;
543  if((X[dir] % 2 ==1) && (commDim(dir) > 1)){ //switch even/odd position
544  dst_oddness = 1-oddness;
545  }
546  MEMCOPY_GAUGE_FIELDS_GRID_TO_BUF(ghost_sitelink_fwd_sendbuf, dst_idx, sitelink, src_idx, (endc[dir]-c+1)/2,dir);
547  }
548  }//for loop
549  }//if
550 
551  }
552  }
553  }
554 
555  MsgHandle *mh_recv_back;
556  MsgHandle *mh_recv_fwd;
557  MsgHandle *mh_send_fwd;
558  MsgHandle *mh_send_back;
559 
560  mh_recv_back = comm_declare_receive_relative(ghost_sitelink_back[dir], dir, -1, len[dir]);
561  mh_recv_fwd = comm_declare_receive_relative(ghost_sitelink_fwd[dir], dir, +1, len[dir]);
562  mh_send_fwd = comm_declare_send_relative(ghost_sitelink_fwd_sendbuf[dir], dir, +1, len[dir]);
563  mh_send_back = comm_declare_send_relative(ghost_sitelink_back_sendbuf[dir], dir, -1, len[dir]);
564 
565  comm_start(mh_recv_back);
566  comm_start(mh_recv_fwd);
567  comm_start(mh_send_fwd);
568  comm_start(mh_send_back);
569 
570  comm_wait(mh_send_fwd);
571  comm_wait(mh_send_back);
572  comm_wait(mh_recv_back);
573  comm_wait(mh_recv_fwd);
574 
575  comm_free(mh_send_fwd);
576  comm_free(mh_send_back);
577  comm_free(mh_recv_back);
578  comm_free(mh_recv_fwd);
579 
580  }//if
581 
582  //use the messages to fill the sitelink data
583  //back
584  if (dir < 3 ) {
585 
586  for(d=0; d < R[dir]; d++) {
587  for(a=starta[dir];a < enda[dir]; a++) {
588  for(b=startb[dir]; b < endb[dir]; b++) {
589 
590  if(f_main[dir][2] != 1 || f_bound[dir][2] !=1){
591  for (c=startc[dir]; c < endc[dir]; c++){
592  int oddness = (a+b+c+d)%2;
593  int dst_idx = ( a*f_main[dir][0] + b*f_main[dir][1]+ c*f_main[dir][2] + d*f_main[dir][3])>> 1;
594  int src_idx;
595  if(commDimPartitioned(dir)){
596  src_idx = ( a*f_bound[dir][0] + b*f_bound[dir][1]+ c*f_bound[dir][2] + d*f_bound[dir][3])>> 1;
597  }else{
598  src_idx = ( a*f_main[dir][0] + b*f_main[dir][1]+ c*f_main[dir][2] + (d+X[dir])*f_main[dir][3])>> 1;
599  }
600 
601  MEMCOPY_GAUGE_FIELDS_BUF_TO_GRID(sitelink, dst_idx, ghost_sitelink_back, src_idx, 1, dir);
602 
603  }//c
604  }else{
605  //optimized copy
606  //first half: startc[dir] -> end[dir] with step=2
607 
608  for(int loop =0;loop <2;loop++){
609  int c=startc[dir]+loop;
610  if(c < endc[dir]){
611  int oddness = (a+b+c+d)%2;
612  int dst_idx = ( a*f_main[dir][0] + b*f_main[dir][1]+ c*f_main[dir][2] + d*f_main[dir][3])>> 1;
613  int src_idx;
614  if(commDimPartitioned(dir)){
615  src_idx = ( a*f_bound[dir][0] + b*f_bound[dir][1]+ c*f_bound[dir][2] + d*f_bound[dir][3])>> 1;
616  }else{
617  src_idx = ( a*f_main[dir][0] + b*f_main[dir][1]+ c*f_main[dir][2] + (d+X[dir])*f_main[dir][3])>> 1;
618  }
619 
620  MEMCOPY_GAUGE_FIELDS_BUF_TO_GRID(sitelink, dst_idx, ghost_sitelink_back, src_idx, (endc[dir]-c+1)/2, dir);
621 
622  }//if c
623  }//for loop
624  }//if
625 
626  }
627  }
628  }
629 
630  }else{
631  //when dir == 3 (T direction), the data layout format in sitelink and the message is the same, we can do large copys
632 
633  MEMCOPY_GAUGE_FIELDS_BUF_TO_GRID_T(sitelink, ghost_sitelink_back, 0, X[3], dir)
634  }//if
635 
636  //fwd
637  if( dir < 3 ){
638 
639  for(d=X[dir]+R[dir]; d < X[dir]+2*R[dir]; d++) {
640  for(a=starta[dir];a < enda[dir]; a++) {
641  for(b=startb[dir]; b < endb[dir]; b++) {
642 
643  if(f_main[dir][2] != 1 || f_bound[dir][2] != 1){
644  for (c=startc[dir]; c < endc[dir]; c++){
645  int oddness = (a+b+c+d)%2;
646  int dst_idx = ( a*f_main[dir][0] + b*f_main[dir][1]+ c*f_main[dir][2] + d*f_main[dir][3])>> 1;
647  int src_idx;
648  if(commDimPartitioned(dir)){
649  src_idx = ( a*f_bound[dir][0] + b*f_bound[dir][1]+ c*f_bound[dir][2] + (d-X[dir]-R[dir])*f_bound[dir][3])>> 1;
650  }else{
651  src_idx = ( a*f_main[dir][0] + b*f_main[dir][1]+ c*f_main[dir][2] + (d-X[dir])*f_main[dir][3])>> 1;
652  }
653 
654  MEMCOPY_GAUGE_FIELDS_BUF_TO_GRID(sitelink, dst_idx, ghost_sitelink_fwd, src_idx, 1, dir);
655 
656  }//c
657  }else{
658  for(int loop =0; loop < 2; loop++){
659  //for (c=startc[dir]; c < endc[dir]; c++){
660  c=startc[dir] + loop;
661  if(c < endc[dir]){
662  int oddness = (a+b+c+d)%2;
663  int dst_idx = ( a*f_main[dir][0] + b*f_main[dir][1]+ c*f_main[dir][2] + d*f_main[dir][3])>> 1;
664  int src_idx;
665  if(commDimPartitioned(dir)){
666  src_idx = ( a*f_bound[dir][0] + b*f_bound[dir][1]+ c*f_bound[dir][2] + (d-X[dir]-R[dir])*f_bound[dir][3])>> 1;
667  }else{
668  src_idx = ( a*f_main[dir][0] + b*f_main[dir][1]+ c*f_main[dir][2] + (d-X[dir])*f_main[dir][3])>> 1;
669  }
670  MEMCOPY_GAUGE_FIELDS_BUF_TO_GRID(sitelink, dst_idx, ghost_sitelink_fwd, src_idx, (endc[dir]-c+1)/2, dir);
671  }//if
672  }//for loop
673  }//if
674 
675  }
676  }
677  }
678 
679 
680  } else {
681 
682  //when dir == 3 (T direction), the data layout format in sitelink and the message is the same, we can do large copys
683  MEMCOPY_GAUGE_FIELDS_BUF_TO_GRID_T(sitelink, ghost_sitelink_fwd, (X[3]+R[3]), 2, dir) // TESTME 2
684 
685  }//if
686 
687  }//dir for loop
688 
689 
690  for(int dir=0;dir < 4;dir++){
691  if(!commDimPartitioned(dir)) continue;
692  host_free(ghost_sitelink_fwd_sendbuf[dir]);
693  host_free(ghost_sitelink_back_sendbuf[dir]);
694  host_free(ghost_sitelink_fwd[dir]);
695  host_free(ghost_sitelink_back[dir]);
696  }
697 
698 }
699 
700 
701 
702 template<typename Float>
703 void
704 do_exchange_cpu_staple(Float* staple, Float** ghost_staple, Float** staple_fwd_sendbuf, Float** staple_back_sendbuf, int* X)
705 {
706 
707 
708 #if 0
709  int len = Vsh_t*gaugeSiteSize*sizeof(Float);
710  Float* even_staple_back_src = staple;
711  Float* odd_staple_back_src = staple + volumeCB*gaugeSiteSize;
712  Float* staple_back_dst = staple_back_sendbuf[3];
713 
714  if(dims[3] % 2 == 0){
715  memcpy(staple_back_dst, even_staple_back_src, len);
716  memcpy(staple_back_dst + Vsh_t*gaugeSiteSize, odd_staple_back_src, len);
717  }else{
718  //switching odd and even ghost staple
719  memcpy(staple_back_dst, odd_staple_back_src, len);
720  memcpy(staple_back_dst + Vsh_t*gaugeSiteSize, even_staple_back_src, len);
721  }
722 
723 
724  Float* even_staple_fwd_src = staple + (volumeCB - Vsh_t)*gaugeSiteSize;
725  Float* odd_staple_fwd_src = staple + volumeCB*gaugeSiteSize + (volumeCB - Vsh_t)*gaugeSiteSize;
726  Float* staple_fwd_dst = staple_fwd_sendbuf[3];
727  if(dims[3] % 2 == 0){
728  memcpy(staple_fwd_dst, even_staple_fwd_src, len);
729  memcpy(staple_fwd_dst + Vsh_t*gaugeSiteSize, odd_staple_fwd_src, len);
730  }else{
731  //switching odd and even ghost staple
732  memcpy(staple_fwd_dst, odd_staple_fwd_src, len);
733  memcpy(staple_fwd_dst + Vsh_t*gaugeSiteSize, even_staple_fwd_src, len);
734  }
735 #else
736  int nFace =1;
737  pack_ghost_all_staples_cpu(staple, (void**)staple_back_sendbuf,
738  (void**)staple_fwd_sendbuf, nFace, (QudaPrecision)(sizeof(Float)), X);
739 
740 #endif
741 
742  int Vsh[4] = {Vsh_x, Vsh_y, Vsh_z, Vsh_t};
743  int len[4] = {
744  Vsh_x*gaugeSiteSize*sizeof(Float),
745  Vsh_y*gaugeSiteSize*sizeof(Float),
746  Vsh_z*gaugeSiteSize*sizeof(Float),
747  Vsh_t*gaugeSiteSize*sizeof(Float)
748  };
749 
750  for (int dir=0;dir < 4; dir++) {
751 
752  Float *ghost_staple_back = ghost_staple[dir];
753  Float *ghost_staple_fwd = ghost_staple[dir] + 2*Vsh[dir]*gaugeSiteSize;
754 
755  MsgHandle *mh_recv_back;
756  MsgHandle *mh_recv_fwd;
757  MsgHandle *mh_send_fwd;
758  MsgHandle *mh_send_back;
759 
760  mh_recv_back = comm_declare_receive_relative(ghost_staple_back, dir, -1, 2*len[dir]);
761  mh_recv_fwd = comm_declare_receive_relative(ghost_staple_fwd, dir, +1, 2*len[dir]);
762  mh_send_fwd = comm_declare_send_relative(staple_fwd_sendbuf[dir], dir, +1, 2*len[dir]);
763  mh_send_back = comm_declare_send_relative(staple_back_sendbuf[dir], dir, -1, 2*len[dir]);
764 
765  comm_start(mh_recv_back);
766  comm_start(mh_recv_fwd);
767  comm_start(mh_send_fwd);
768  comm_start(mh_send_back);
769 
770  comm_wait(mh_send_fwd);
771  comm_wait(mh_send_back);
772  comm_wait(mh_recv_back);
773  comm_wait(mh_recv_fwd);
774 
775  comm_free(mh_send_fwd);
776  comm_free(mh_send_back);
777  comm_free(mh_recv_back);
778  comm_free(mh_recv_fwd);
779  }
780 }
781 
782 
783 //this function is used for link fattening computation
784 void exchange_cpu_staple(int* X, void* staple, void** ghost_staple, QudaPrecision gPrecision)
785 {
786  setup_dims(X);
787 
788  int Vs[4] = {Vs_x, Vs_y, Vs_z, Vs_t};
789  void *staple_fwd_sendbuf[4];
790  void *staple_back_sendbuf[4];
791 
792  for(int i=0;i < 4; i++){
793  staple_fwd_sendbuf[i] = safe_malloc(Vs[i]*gaugeSiteSize*gPrecision);
794  staple_back_sendbuf[i] = safe_malloc(Vs[i]*gaugeSiteSize*gPrecision);
795  }
796 
797  if (gPrecision == QUDA_DOUBLE_PRECISION) {
798  do_exchange_cpu_staple((double*)staple, (double**)ghost_staple,
799  (double**)staple_fwd_sendbuf, (double**)staple_back_sendbuf, X);
800  } else { //single
801  do_exchange_cpu_staple((float*)staple, (float**)ghost_staple,
802  (float**)staple_fwd_sendbuf, (float**)staple_back_sendbuf, X);
803  }
804 
805  for (int i=0;i < 4;i++) {
806  host_free(staple_fwd_sendbuf[i]);
807  host_free(staple_back_sendbuf[i]);
808  }
809 }
810 
811 //@whichway indicates send direction
812 void
813 exchange_gpu_staple_start(int* X, void* _cudaStaple, int dir, int whichway, cudaStream_t * stream)
814 {
815  setup_dims(X);
816 
817  cudaGaugeField* cudaStaple = (cudaGaugeField*) _cudaStaple;
818  exchange_llfat_init(cudaStaple->Precision());
819 
820 
821  void* even = cudaStaple->Even_p();
822  void* odd = cudaStaple->Odd_p();
823  int volume = cudaStaple->VolumeCB();
824  QudaPrecision prec = cudaStaple->Precision();
825  int stride = cudaStaple->Stride();
826 
827  packGhostStaple(X, even, odd, volume, prec, stride,
828  dir, whichway, fwd_nbr_staple_gpu, back_nbr_staple_gpu,
829  fwd_nbr_staple_sendbuf, back_nbr_staple_sendbuf, stream);
830 }
831 
832 
833 void exchange_gpu_staple_comms(int* X, void* _cudaStaple, int dim, int send_dir, cudaStream_t *stream)
834 {
835  cudaGaugeField* cudaStaple = (cudaGaugeField*) _cudaStaple;
836  QudaPrecision prec = cudaStaple->Precision();
837 
838  cudaStreamSynchronize(*stream);
839 
840  int recv_dir = (send_dir == QUDA_BACKWARDS) ? QUDA_FORWARDS : QUDA_BACKWARDS;
841 
842  int len = Vs[dim]*gaugeSiteSize*prec;
843 
844  if (recv_dir == QUDA_BACKWARDS) {
845 
846 #ifdef GPU_DIRECT
847  llfat_recv.back[dim] = comm_declare_receive_relative(back_nbr_staple[dim], dim, -1, len);
848  llfat_send.fwd[dim] = comm_declare_send_relative(fwd_nbr_staple_sendbuf[dim], dim, +1, len);
849 #else
850  llfat_recv.back[dim] = comm_declare_receive_relative(back_nbr_staple_cpu[dim], dim, -1, len);
851  memcpy(fwd_nbr_staple_sendbuf_cpu[dim], fwd_nbr_staple_sendbuf[dim], len);
852  llfat_send.fwd[dim] = comm_declare_send_relative(fwd_nbr_staple_sendbuf_cpu[dim], dim, +1, len);
853 #endif
854 
855  comm_start(llfat_recv.back[dim]);
856  comm_start(llfat_send.fwd[dim]);
857 
858  } else { // QUDA_FORWARDS
859 
860 #ifdef GPU_DIRECT
861  llfat_recv.fwd[dim] = comm_declare_receive_relative(fwd_nbr_staple[dim], dim, +1, len);
862  llfat_send.back[dim] = comm_declare_send_relative(back_nbr_staple_sendbuf[dim], dim, -1, len);
863 #else
864  llfat_recv.fwd[dim] = comm_declare_receive_relative(fwd_nbr_staple_cpu[dim], dim, +1, len);
865  memcpy(back_nbr_staple_sendbuf_cpu[dim], back_nbr_staple_sendbuf[dim], len);
866  llfat_send.back[dim] = comm_declare_send_relative(back_nbr_staple_sendbuf_cpu[dim], dim, -1, len);
867 #endif
868 
869  comm_start(llfat_recv.fwd[dim]);
870  comm_start(llfat_send.back[dim]);
871 
872  }
873 }
874 
875 
876 //@whichway indicates send direction
877 //we use recv_whichway to indicate recv direction
878 void
879 exchange_gpu_staple_wait(int* X, void* _cudaStaple, int dim, int send_dir, cudaStream_t * stream)
880 {
881  cudaGaugeField* cudaStaple = (cudaGaugeField*) _cudaStaple;
882 
883  void* even = cudaStaple->Even_p();
884  void* odd = cudaStaple->Odd_p();
885  int volume = cudaStaple->VolumeCB();
886  QudaPrecision prec = cudaStaple->Precision();
887  int stride = cudaStaple->Stride();
888 
889  int recv_dir = (send_dir == QUDA_BACKWARDS) ? QUDA_FORWARDS : QUDA_BACKWARDS;
890 
891 #ifndef GPU_DIRECT
892  int len = Vs[dim]*gaugeSiteSize*prec;
893 #endif
894 
895  if (recv_dir == QUDA_BACKWARDS) {
896 
897  comm_wait(llfat_send.fwd[dim]);
898  comm_wait(llfat_recv.back[dim]);
899 
900  comm_free(llfat_send.fwd[dim]);
901  comm_free(llfat_recv.back[dim]);
902 
903 #ifdef GPU_DIRECT
904  unpackGhostStaple(X, even, odd, volume, prec, stride,
905  dim, QUDA_BACKWARDS, fwd_nbr_staple, back_nbr_staple, stream);
906 #else
907  memcpy(back_nbr_staple[dim], back_nbr_staple_cpu[dim], len);
908  unpackGhostStaple(X, even, odd, volume, prec, stride,
909  dim, QUDA_BACKWARDS, fwd_nbr_staple, back_nbr_staple, stream);
910 #endif
911 
912  } else { // QUDA_FORWARDS
913 
914  comm_wait(llfat_send.back[dim]);
915  comm_wait(llfat_recv.fwd[dim]);
916 
917  comm_free(llfat_send.back[dim]);
918  comm_free(llfat_recv.fwd[dim]);
919 
920 #ifdef GPU_DIRECT
921  unpackGhostStaple(X, even, odd, volume, prec, stride,
922  dim, QUDA_FORWARDS, fwd_nbr_staple, back_nbr_staple, stream);
923 #else
924  memcpy(fwd_nbr_staple[dim], fwd_nbr_staple_cpu[dim], len);
925  unpackGhostStaple(X, even, odd, volume, prec, stride,
926  dim, QUDA_FORWARDS, fwd_nbr_staple, back_nbr_staple, stream);
927 #endif
928 
929  }
930 }
931 
932 
933 void exchange_llfat_cleanup(void)
934 {
935  for (int i=0; i<4; i++) {
936 
937  if(fwd_nbr_staple_gpu[i]){
938  device_free(fwd_nbr_staple_gpu[i]); fwd_nbr_staple_gpu[i] = NULL;
939  }
940  if(back_nbr_staple_gpu[i]){
941  device_free(back_nbr_staple_gpu[i]); back_nbr_staple_gpu[i] = NULL;
942  }
943 
944 #ifndef GPU_DIRECT
945  if(fwd_nbr_staple_cpu[i]){
946  host_free(fwd_nbr_staple_cpu[i]); fwd_nbr_staple_cpu[i] = NULL;
947  }
948  if(back_nbr_staple_cpu[i]){
949  host_free(back_nbr_staple_cpu[i]);back_nbr_staple_cpu[i] = NULL;
950  }
951  if(fwd_nbr_staple_sendbuf_cpu[i]){
952  host_free(fwd_nbr_staple_sendbuf_cpu[i]); fwd_nbr_staple_sendbuf_cpu[i] = NULL;
953  }
954  if(back_nbr_staple_sendbuf_cpu[i]){
955  host_free(back_nbr_staple_sendbuf_cpu[i]); back_nbr_staple_sendbuf_cpu[i] = NULL;
956  }
957 #endif
958 
959  if(fwd_nbr_staple[i]){
960  host_free(fwd_nbr_staple[i]); fwd_nbr_staple[i] = NULL;
961  }
962  if(back_nbr_staple[i]){
963  host_free(back_nbr_staple[i]); back_nbr_staple[i] = NULL;
964  }
965  if(fwd_nbr_staple_sendbuf[i]){
966  host_free(fwd_nbr_staple_sendbuf[i]); fwd_nbr_staple_sendbuf[i] = NULL;
967  }
968  if(back_nbr_staple_sendbuf[i]){
969  host_free(back_nbr_staple_sendbuf[i]); back_nbr_staple_sendbuf[i] = NULL;
970  }
971 
972  }
973  checkCudaError();
974 }
975 
976 #endif