QUDA  v0.5.0
A library for QCD on GPUs
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
misc_helpers.cu
Go to the documentation of this file.
1 
2 #include <misc_helpers.h>
3 #define gaugeSiteSize 18
4 #define BLOCKSIZE 64
5 
6 
7 
8 /*
9  * MILC order, CPU->GPU
10  *
11  *This function converts format in CPU form
12  * into forms in GPU so as to enable coalesce access
13  * The function only converts half(even or odd) of the links
14  * Therefore the entire link conversion need to call this
15  * function twice
16  *
17  * Without loss of generality, the parity is assumed to be even.
18  * The actual data format in cpu is following
19  * [a0a1 .... a17][b0b1...b17][c..][d...][a18a19 .....a35] ...[b0b1 ... b17] ...
20  * X links Y links T,Z links X Links
21  * where a0->a17 is the X link in the first site
22  * b0->b17 is the Y link in the first site
23  * c0->c17 is the Z link in the first site
24  * d0->d17 is the T link in the first site
25  * a18->a35 is the X link in the second site
26  * etc
27  *
28  * The GPU format of data looks like the following
29  * [a0a1][a18a19] ....[pad][a2a3][a20a21]..... [b0b1][b18b19]....
30  * X links Y links T,Z links
31  *
32  * N: # of FloatN in one gauge field
33  * 9 for QUDA_RECONSTRUCT_NO, SP/DP
34  * 6 for QUDA_RECONSTRUCT_12, DP
35  * 3 for QUDA_RECONSTRUCT_12, SP
36  */
37 
38 namespace quda {
39 
/*
 * Reorders per-direction link arrays from src into the strided dst layout.
 *
 * For each of the 4 directions the block first stages its links in shared
 * memory (threads read consecutive Float2 elements of src), then every
 * thread writes the N FloatN elements of its own link to dst: element j
 * goes to dst[dir*N*stride + j*stride + tid], with stride = Vh+pad.
 * When N != 9 only the first 6 of each link's 9 Float2 values are kept.
 *
 * Launch: 1D grid, one thread per link of a direction. The shared buffer
 * holds N*BLOCKSIZE elements, so blockDim.x must not exceed BLOCKSIZE.
 *
 * dst          destination array (FloatN elements)
 * src          source array, 9 Float2 per link; consecutive directions are
 *              Vh links apart (Vh+ghostV when MULTI_GPU is defined)
 * reconstruct  not read by the kernel
 * Vh, pad      the dst stride is Vh+pad
 * ghostV       only read when MULTI_GPU is defined
 * threads      number of links to convert per direction
 */
template<int N, typename FloatN, typename Float2>
__global__ void
do_link_format_cpu_to_gpu(FloatN* dst, Float2* src,
                          int reconstruct,
                          int Vh, int pad, int ghostV, size_t threads)
{
  __shared__ FloatN buf[N*BLOCKSIZE];
  // The staging buffer is filled through a Float2 view, whatever FloatN is.
  Float2* const staging = (Float2*)buf;

  const int first_tid = blockIdx.x * blockDim.x;  // global id of this block's thread 0
  const int tid = first_tid + threadIdx.x;
  const int stride = Vh+pad;

  for (int dir = 0; dir < 4; dir++) {
#ifdef MULTI_GPU
    Float2* const src_start = src + dir*9*(Vh+ghostV) + first_tid*9;
#else
    Float2* const src_start = src + dir*9*(Vh) + first_tid*9;
#endif

    // Load phase: the block cooperatively copies 9*blockDim.x Float2 values,
    // adjacent threads touching adjacent elements of src.
    for (int j = 0; j < 9; j++) {
      const unsigned int offset = j*blockDim.x + threadIdx.x;
      // Stop once past the last Float2 of the last link.
      if (first_tid*9 + offset >= 9*threads) break;

      if (N == 9) {
        staging[offset] = src_start[offset];
      } else {
        // Keep only the first 6 Float2 of every 9, packed link by link.
        const int modval = (int)offset % 9;
        const int divval = (int)offset / 9;
        if (modval < 6) {
          staging[divval*6+modval] = src_start[offset];
        }
      }
    }

    // Every thread of the block reaches this barrier (the break above only
    // leaves the load loop), so the staged data is complete before it is read.
    __syncthreads();

    if (tid < threads) {
      FloatN* const dst_start = dst + dir*N*stride;
      for (int j = 0; j < N; j++) {
        dst_start[tid + j*stride] = buf[N*threadIdx.x + j];
      }
    }

    // Do not let the next direction overwrite buf while it is still being read.
    __syncthreads();
  }//dir
}
84 
85 
86  /*
87  *
88  * N: # of FloatN in one gauge field
89  * 9 for QUDA_RECONSTRUCT_NO, SP/DP
90  * 6 for QUDA_RECONSTRUCT_12, DP
91  * 3 for QUDA_RECONSTRUCT_12, SP
92  *
93  * FloatN: float2/double2
94  * Float: float/double
95  *
96  * This is the reverse process for the function do_link_format_gpu_to_cpu()
97  *
98  */
99 
/*
 * Converts links stored site-major in src (the 4 direction links of a site
 * are adjacent, 9 Float2 each) into the strided dst layout.
 *
 * The block stages its share of src in shared memory, reading consecutive
 * Float2 elements; when N != 9 only the first 6 of each link's 9 Float2
 * values are kept. After the barrier each thread writes the N FloatN
 * elements of one (site, direction) pair to
 *   dst[pos_idx + mydir*N*stride + j*stride],  stride = Vh+pad.
 *
 * Launch: 1D grid, one thread per link (4 threads per site). The index
 * arithmetic below hard-codes groups of 16 threads per direction and a
 * 64-thread block, matching BLOCKSIZE.
 *
 * dst          destination array (FloatN elements)
 * src          source array (Float2 elements)
 * reconstruct  not read by the kernel
 * Vh, pad      the dst stride is Vh+pad
 * ghostV       not read by the kernel
 * threads      total number of links, i.e. 4 * number of sites
 */
template<int N, typename FloatN, typename Float2>
__global__ void
do_link_format_cpu_to_gpu_milc(FloatN* dst, Float2* src,
                               int reconstruct,
                               int Vh, int pad, int ghostV, size_t threads)
{

  __shared__ FloatN buf[N*BLOCKSIZE];
  int block_idx = blockIdx.x*blockDim.x/4;                // first site handled by this block
  int local_idx = 16*(threadIdx.x/64) + threadIdx.x%16;   // site index within the block
  int pos_idx = blockIdx.x * blockDim.x/4 + 16*(threadIdx.x/64) + threadIdx.x%16;  // global site index
  int mydir = (threadIdx.x >> 4)% 4;                      // direction handled by this thread
  int j;
  int stride = Vh+pad;

  // Load phase: adjacent threads read adjacent Float2 elements of src.
  for(j=0; j < 9; j++){
    // Stop once past the last Float2 of the last link.
    if(block_idx*9*4 + j*blockDim.x+threadIdx.x >= 9*threads) break;
    if(N == 9){
      ((Float2*)buf)[j*blockDim.x + threadIdx.x] = src[block_idx*9*4 + j*blockDim.x + threadIdx.x];
    }else{
      // Keep only the first 6 Float2 of every 9, packed link by link.
      int idx = j*blockDim.x + threadIdx.x;
      int modval = idx % 9;
      int divval = idx / 9;
      if(modval < 6){
        ((Float2*)buf)[divval*6+modval] = src[block_idx*9*4 + idx];
      }
    }
  }

  // All threads reach the barrier; the early return comes only afterwards.
  __syncthreads();

  if(pos_idx >= threads/4) return;

  // Store phase: staged links are laid out [site][direction][element].
  for(j=0; j < N; j++){
    dst[pos_idx + mydir*N*stride + j*stride] = buf[local_idx*4*N+mydir*N+j];
  }
}
141 
/*
 * Host entry point: launches the kernel that converts a link field from
 * its CPU ordering into the strided GPU layout.
 *
 * dst, src     destination / source buffers handed to the kernel
 * reconstruct  QUDA_RECONSTRUCT_NO or QUDA_RECONSTRUCT_12
 * Vh, pad      the kernels use Vh+pad as the destination stride
 * ghostV       extra links per direction, counted only when MULTI_GPU is defined
 * prec         QUDA_DOUBLE_PRECISION or QUDA_SINGLE_PRECISION
 * cpu_order    QUDA_QDP_GAUGE_ORDER or QUDA_MILC_GAUGE_ORDER
 * stream       stream every kernel is launched on
 *
 * Any other reconstruct / precision / ordering is reported via errorQuda.
 *
 * NOTE(review): the declared types and relative order of prec and cpu_order
 * are assumed here - confirm against the prototype in misc_helpers.h.
 */
void
link_format_cpu_to_gpu(void* dst, void* src,
                       int reconstruct, int Vh, int pad,
                       int ghostV,
                       QudaPrecision prec, QudaGaugeFieldOrder cpu_order,
                       cudaStream_t stream)
{
  dim3 blockDim(BLOCKSIZE);
  if(cpu_order == QUDA_QDP_GAUGE_ORDER){
    // One thread per link of a single direction.
#ifdef MULTI_GPU
    size_t threads=Vh+ghostV;
#else
    size_t threads=Vh;
#endif
    dim3 gridDim ((threads + BLOCKSIZE -1)/BLOCKSIZE);

    switch (prec){
    case QUDA_DOUBLE_PRECISION:
      switch( reconstruct){
      case QUDA_RECONSTRUCT_NO:
        do_link_format_cpu_to_gpu<9><<<gridDim, blockDim, 0, stream>>>((double2*)dst, (double2*)src, reconstruct, Vh, pad, ghostV, threads);
        break;
      case QUDA_RECONSTRUCT_12:
        do_link_format_cpu_to_gpu<6><<<gridDim, blockDim, 0, stream>>>((double2*)dst, (double2*)src, reconstruct, Vh, pad, ghostV, threads);
        break;
      default:
        errorQuda("reconstruct type not supported\n");
      }
      break;

    case QUDA_SINGLE_PRECISION:
      switch( reconstruct){
      case QUDA_RECONSTRUCT_NO:
        do_link_format_cpu_to_gpu<9><<<gridDim, blockDim, 0, stream>>>((float2*)dst, (float2*)src, reconstruct, Vh, pad, ghostV, threads);
        break;
      case QUDA_RECONSTRUCT_12:
        // Launched on the caller's stream like every other variant.
        do_link_format_cpu_to_gpu<3><<<gridDim, blockDim, 0, stream>>>((float4*)dst, (float2*)src, reconstruct, Vh, pad, ghostV, threads);
        break;
      default:
        errorQuda("reconstruct type not supported\n");
      }
      break;

    default:
      errorQuda("ERROR: half precision not support in %s\n", __FUNCTION__);
    }
  }else if (cpu_order == QUDA_MILC_GAUGE_ORDER){
    // One thread per link: 4 threads per site.
#ifdef MULTI_GPU
    int threads=4*(Vh+ghostV);
#else
    int threads=4*Vh;
#endif
    dim3 gridDim ((threads + BLOCKSIZE -1)/BLOCKSIZE);

    switch (prec){
    case QUDA_DOUBLE_PRECISION:
      switch( reconstruct){
      case QUDA_RECONSTRUCT_NO:
        do_link_format_cpu_to_gpu_milc<9><<<gridDim, blockDim, 0, stream>>>((double2*)dst, (double2*)src, reconstruct, Vh, pad, ghostV, threads);
        break;
      case QUDA_RECONSTRUCT_12:
        do_link_format_cpu_to_gpu_milc<6><<<gridDim, blockDim, 0, stream>>>((double2*)dst, (double2*)src, reconstruct, Vh, pad, ghostV, threads);
        break;
      default:
        errorQuda("reconstruct type not supported\n");
      }
      break;

    case QUDA_SINGLE_PRECISION:
      switch( reconstruct){
      case QUDA_RECONSTRUCT_NO:
        do_link_format_cpu_to_gpu_milc<9><<<gridDim, blockDim, 0, stream>>>((float2*)dst, (float2*)src, reconstruct, Vh, pad, ghostV, threads);
        break;
      case QUDA_RECONSTRUCT_12:
        do_link_format_cpu_to_gpu_milc<3><<<gridDim, blockDim, 0, stream>>>((float4*)dst, (float2*)src, reconstruct, Vh, pad, ghostV, threads);
        break;
      default:
        errorQuda("reconstruct type not supported\n");
      }
      break;

    default:
      errorQuda("ERROR: half precision not support in %s\n", __FUNCTION__);
    }

  }else{
    errorQuda("ERROR: invalid cpu ordering (%d)\n", cpu_order);
  }

  return;

}
234  /*
235  * src format: the normal link format in GPU that has stride size @stride
236  * the src is stored with 9 double2
237  * dst format: an array of links where x,y,z,t links with the same node id is stored next to each other
238  * This format is used in destination in fatlink computation in cpu
239  * Without loss of generality, the parity is assumed to be even.
240  * The actual data format in GPU is the following
241  * [a0a1][a18a19] ....[pad][a2a3][a20a21]..... [b0b1][b18b19]....
242  * X links Y links T,Z links
243  * The temporary data store in GPU shared memory and the CPU format of data are the following
244  * [a0a1 .... a17] [b0b1 .....b17] [c0c1 .....c17] [d0d1 .....d17] [a18a19....a35] ....
245  * |<------------------------site 0 ---------------------------->|<----- site 2 ----->
246  *
247  *
248  * In loading phase the indices for all threads in the first block is the following (assume block size is 64)
249  * (half warp works on one direction)
250  * threadIdx.x pos_idx mydir
251  * 0 0 0
252  * 1 1 0
253  * 2 2 0
254  * 3 3 0
255  * 4 4 0
256  * 5 5 0
257  * 6 6 0
258  * 7 7 0
259  * 8 8 0
260  * 9 9 0
261  * 10 10 0
262  * 11 11 0
263  * 12 12 0
264  * 13 13 0
265  * 14 14 0
266  * 15 15 0
267  * 16 0 1
268  * 17 1 1
269  * 18 2 1
270  * 19 3 1
271  * 20 4 1
272  * 21 5 1
273  * 22 6 1
274  * 23 7 1
275  * 24 8 1
276  * 25 9 1
277  * 26 10 1
278  * 27 11 1
279  * 28 12 1
280  * 29 13 1
281  * 30 14 1
282  * 31 15 1
283  * 32 0 2
284  * 33 1 2
285  * 34 2 2
286  * 35 3 2
287  * 36 4 2
288  * 37 5 2
289  * 38 6 2
290  * 39 7 2
291  * 40 8 2
292  * 41 9 2
293  * 42 10 2
294  * 43 11 2
295  * 44 12 2
296  * 45 13 2
297  * 46 14 2
298  * 47 15 2
299  * 48 0 3
300  * 49 1 3
301  * 50 2 3
302  * 51 3 3
303  * 52 4 3
304  * 53 5 3
305  * 54 6 3
306  * 55 7 3
307  * 56 8 3
308  * 57 9 3
309  * 58 10 3
310  * 59 11 3
311  * 60 12 3
312  * 61 13 3
313  * 62 14 3
314  * 63 15 3
315  *
316  */
317 
/*
 * Gathers links from the strided src layout into a site-major dst array
 * in which the 4 direction links of a site are adjacent (9 FloatN each).
 *
 * Each thread reads the 9 elements of one (site, direction) pair from
 *   src[pos_idx + mydir*9*stride + j*stride]
 * into shared memory; after the barrier the block writes the staged data
 * to dst with adjacent threads storing adjacent elements.
 *
 * Launch: 1D grid, 4 threads per site. The index arithmetic hard-codes
 * groups of 16 threads per direction and a 64-thread block, matching
 * BLOCKSIZE. There is no bounds check, so the grid must cover the sites
 * exactly; the host wrapper rejects 4*Vh that is not a multiple of the
 * block size.
 *
 * dst     destination array
 * src     source array in the strided layout
 * Vh      not read by the kernel
 * stride  distance, in FloatN elements, between successive elements of a link
 */
template<typename FloatN>
__global__ void
do_link_format_gpu_to_cpu(FloatN* dst, FloatN* src,
                          int Vh, int stride)
{
  __shared__ FloatN buf[gaugeSiteSize/2*BLOCKSIZE];

  int j;

  int block_idx = blockIdx.x*blockDim.x/4;                // first site handled by this block
  int local_idx = 16*(threadIdx.x/64) + threadIdx.x%16;   // site index within the block
  int pos_idx = blockIdx.x * blockDim.x/4 + 16*(threadIdx.x/64) + threadIdx.x%16;  // global site index
  int mydir = (threadIdx.x >> 4)% 4;                      // direction handled by this thread

  // Load phase: stage this thread's link as [site][direction][element].
  for(j=0; j < 9; j++){
    buf[local_idx*4*9+mydir*9+j] = src[pos_idx + mydir*9*stride + j*stride];
  }
  __syncthreads();

  // Store phase: adjacent threads write adjacent elements of dst.
  for(j=0; j < 9; j++){
    dst[block_idx*9*4 + j*blockDim.x + threadIdx.x ] = buf[j*blockDim.x + threadIdx.x];
  }

}
341 
342 
343 
/*
 * Host entry point: launches do_link_format_gpu_to_cpu on the given stream.
 *
 * dst, src  destination / source buffers handed to the kernel
 * Vh        number of sites to convert; 4*Vh must be a multiple of BLOCKSIZE
 * stride    stride of the source layout, forwarded to the kernel
 * prec      QUDA_DOUBLE_PRECISION or QUDA_SINGLE_PRECISION; any other value
 *           prints an error and exits
 * stream    stream the kernel is launched on
 */
void
link_format_gpu_to_cpu(void* dst, void* src,
                       int Vh, int stride, QudaPrecision prec, cudaStream_t stream)
{

  dim3 blockDim(BLOCKSIZE);
  dim3 gridDim(4*Vh/blockDim.x); //every 4 threads process one site's x,y,z,t links
  //4*Vh must be a multiple of BLOCKSIZE or the kernel does not work
  if ((4*Vh) % blockDim.x != 0){
    // Report the quantity that was actually tested (4*Vh, not Vh).
    errorQuda("ERROR: 4*Vh(%d) is not multiple of blocksize(%d), exitting\n", 4*Vh, blockDim.x);
  }
  if(prec == QUDA_DOUBLE_PRECISION){
    do_link_format_gpu_to_cpu<<<gridDim, blockDim, 0, stream>>>((double2*)dst, (double2*)src, Vh, stride);
  }else if(prec == QUDA_SINGLE_PRECISION){
    do_link_format_gpu_to_cpu<<<gridDim, blockDim, 0, stream>>>((float2*)dst, (float2*)src, Vh, stride);
  }else{
    printf("ERROR: half precision is not supported in %s\n",__FUNCTION__);
    exit(1);
  }

}
365 
/*
 * READ_ST_STAPLE declares P0..P8 (of type Float2, which must be in scope at
 * the point of use) and loads them from staple[idx + k*mystride], k = 0..8.
 * Every macro argument is parenthesized so that an expression argument such
 * as a/2 or a+b is scaled by k as a whole.
 */
#define READ_ST_STAPLE(staple, idx, mystride)           \
  Float2 P0 = (staple)[(idx) + 0*(mystride)];           \
  Float2 P1 = (staple)[(idx) + 1*(mystride)];           \
  Float2 P2 = (staple)[(idx) + 2*(mystride)];           \
  Float2 P3 = (staple)[(idx) + 3*(mystride)];           \
  Float2 P4 = (staple)[(idx) + 4*(mystride)];           \
  Float2 P5 = (staple)[(idx) + 5*(mystride)];           \
  Float2 P6 = (staple)[(idx) + 6*(mystride)];           \
  Float2 P7 = (staple)[(idx) + 7*(mystride)];           \
  Float2 P8 = (staple)[(idx) + 8*(mystride)];

/*
 * WRITE_ST_STAPLE stores P0..P8 (which must already be in scope) to
 * staple[idx + k*mystride], k = 0..8.
 */
#define WRITE_ST_STAPLE(staple, idx, mystride)          \
  (staple)[(idx) + 0*(mystride)] = P0;                  \
  (staple)[(idx) + 1*(mystride)] = P1;                  \
  (staple)[(idx) + 2*(mystride)] = P2;                  \
  (staple)[(idx) + 3*(mystride)] = P3;                  \
  (staple)[(idx) + 4*(mystride)] = P4;                  \
  (staple)[(idx) + 5*(mystride)] = P5;                  \
  (staple)[(idx) + 6*(mystride)] = P6;                  \
  (staple)[(idx) + 7*(mystride)] = P7;                  \
  (staple)[(idx) + 8*(mystride)] = P8;
387 
388 
389 
/*
 * Copies the staples that lie on one boundary face into a packed buffer.
 *
 * Each thread handles the site with index sid = blockIdx.x*blockDim.x +
 * threadIdx.x. It splits sid into coordinates using X1h, X2 and X3, loads
 * the site's 9 staple elements, and - if the site lies on the face selected
 * by <dir, whichway> - stores them at the site's face index in
 * nbr_staple_gpu.
 *
 * dir             0..3, selects the coordinate (x1..x4) the face is normal to
 * whichway        QUDA_BACKWARDS selects coordinate < 1,
 *                 QUDA_FORWARDS selects coordinate >= extent - 1
 * in              staple field of one parity, read with stride fl.staple_stride
 * oddBit          parity bit used to recover the full x1 coordinate
 * nbr_staple_gpu  output buffer for the face
 *
 * X1h, X1..X4 and fl are defined outside this block - presumably the
 * lattice extents and a parameter struct; verify against their definitions.
 * There is no bounds check on sid.
 */
template<int dir, int whichway, typename Float2>
__global__ void
collectGhostStapleKernel(Float2* in, const int oddBit,
                         Float2* nbr_staple_gpu)
{
  const int sid = blockIdx.x*blockDim.x + threadIdx.x;

  // Peel the coordinates off the site index, fastest-running first.
  const int z1 = sid / X1h;
  const int x1h = sid - z1*X1h;
  const int z2 = z1 / X2;
  const int x2 = z1 - z2*X2;
  const int x4 = z2 / X3;
  const int x3 = z2 - x4*X3;
  const int x1odd = (x2 + x3 + x4 + oddBit) & 1;
  const int x1 = 2*x1h + x1odd;

  READ_ST_STAPLE(in, sid, fl.staple_stride);

  // dir and whichway are template constants, so only one of the tests
  // below can apply in a given instantiation.
  const bool backwards = (whichway == QUDA_BACKWARDS);
  const bool forwards = (whichway == QUDA_FORWARDS);

  if (dir == 0) {
    if ((backwards && x1 < 1) || (forwards && x1 >= X1 - 1)) {
      const int ghost_face_idx = (x4*(X3*X2)+x3*X2 +x2)>>1;
      WRITE_ST_STAPLE(nbr_staple_gpu, ghost_face_idx, X4*X3*X2/2);
    }
  }

  if (dir == 1) {
    if ((backwards && x2 < 1) || (forwards && x2 >= X2 - 1)) {
      const int ghost_face_idx = (x4*X3*X1+x3*X1+x1)>>1;
      WRITE_ST_STAPLE(nbr_staple_gpu, ghost_face_idx, X4*X3*X1/2);
    }
  }

  if (dir == 2) {
    if ((backwards && x3 < 1) || (forwards && x3 >= X3 - 1)) {
      const int ghost_face_idx = (x4*X2*X1+x2*X1+x1)>>1;
      WRITE_ST_STAPLE(nbr_staple_gpu, ghost_face_idx, X4*X2*X1/2);
    }
  }

  if (dir == 3) {
    if ((backwards && x4 < 1) || (forwards && x4 >= X4 - 1)) {
      const int ghost_face_idx = (x3*X2*X1+x2*X1+x1)>>1;
      WRITE_ST_STAPLE(nbr_staple_gpu, ghost_face_idx, X3*X2*X1/2);
    }
  }

}
467 
468 
// Launches the even- and odd-parity collection kernels for one compile-time
// <dir, whichway> pair.
template<int dir, int whichway, typename Float2>
static void
launchCollectGhostStaple(Float2* even, Float2* odd,
                         Float2* gpu_buf_even, Float2* gpu_buf_odd,
                         dim3 gridDim, dim3 blockDim, cudaStream_t stream)
{
  const int even_parity = 0;
  const int odd_parity = 1;

  collectGhostStapleKernel<dir, whichway><<<gridDim, blockDim, 0, stream>>>(even, even_parity, gpu_buf_even);
  collectGhostStapleKernel<dir, whichway><<<gridDim, blockDim, 0, stream>>>(odd, odd_parity, gpu_buf_odd);
}

// Maps the runtime whichway onto the matching kernel instantiation for a
// compile-time dir.
template<int dir, typename Float2>
static void
collectGhostStapleForDir(Float2* even, Float2* odd,
                         Float2* gpu_buf_even, Float2* gpu_buf_odd,
                         int whichway,
                         dim3 gridDim, dim3 blockDim, cudaStream_t stream)
{
  switch(whichway){
  case QUDA_BACKWARDS:
    launchCollectGhostStaple<dir, QUDA_BACKWARDS>(even, odd, gpu_buf_even, gpu_buf_odd, gridDim, blockDim, stream);
    break;
  case QUDA_FORWARDS:
    launchCollectGhostStaple<dir, QUDA_FORWARDS>(even, odd, gpu_buf_even, gpu_buf_odd, gridDim, blockDim, stream);
    break;
  default:
    errorQuda("Invalid whichway");
    break;
  }
}

// Maps the runtime dir (already validated to be 0..3) onto the matching
// compile-time instantiation.
template<typename Float2>
static void
collectGhostStapleForPrecision(Float2* even, Float2* odd,
                               Float2* gpu_buf_even, Float2* gpu_buf_odd,
                               int dir, int whichway,
                               dim3 gridDim, dim3 blockDim, cudaStream_t stream)
{
  switch(dir){
  case 0:
    collectGhostStapleForDir<0>(even, odd, gpu_buf_even, gpu_buf_odd, whichway, gridDim, blockDim, stream);
    break;
  case 1:
    collectGhostStapleForDir<1>(even, odd, gpu_buf_even, gpu_buf_odd, whichway, gridDim, blockDim, stream);
    break;
  case 2:
    collectGhostStapleForDir<2>(even, odd, gpu_buf_even, gpu_buf_odd, whichway, gridDim, blockDim, stream);
    break;
  case 3:
    collectGhostStapleForDir<3>(even, odd, gpu_buf_even, gpu_buf_odd, whichway, gridDim, blockDim, stream);
    break;
  }
}

/*
 * Collects the staples on one boundary face of the local lattice into
 * ghost_staple_gpu, launching one kernel per parity on *stream.
 *
 * X                 lattice extents X[0..3]
 * even, odd         staple fields of the two parities
 * volume            number of threads to launch per kernel; the grid size is
 *                   volume/BLOCKSIZE (integer division), so sites in a
 *                   remainder would not be processed
 * precision         QUDA_DOUBLE_PRECISION or QUDA_SINGLE_PRECISION; any other
 *                   value prints an error and exits. Its numeric value is
 *                   also used as a factor in the byte offset of the second
 *                   half of the buffer - presumably the element size in
 *                   bytes; verify against the QudaPrecision definition.
 * ghost_staple_gpu  output buffer; the two parities occupy consecutive
 *                   halves, swapped when X[dir] is odd
 * dir               0, 1, 2, 3 (X,Y,Z,T directions); anything else is
 *                   reported via errorQuda
 * whichway          QUDA_FORWARDS or QUDA_BACKWARDS; anything else is
 *                   reported via errorQuda
 * stream            stream the kernels are launched on
 */
void
collectGhostStaple(int* X, void* even, void* odd, int volume, QudaPrecision precision,
                   void* ghost_staple_gpu,
                   int dir, int whichway, cudaStream_t* stream)
{
  // Validate dir before it is used to index Vsh below.
  if (dir < 0 || dir > 3){
    errorQuda("Invalid dir");
    return;
  }

  // Half the number of sites on the face normal to each direction.
  int Vsh[4] = {X[1]*X[2]*X[3]/2,
                X[0]*X[2]*X[3]/2,
                X[0]*X[1]*X[3]/2,
                X[0]*X[1]*X[2]/2};

  dim3 gridDim(volume/BLOCKSIZE, 1, 1);
  dim3 blockDim(BLOCKSIZE, 1, 1);

  void* gpu_buf_even = ghost_staple_gpu;
  void* gpu_buf_odd = ((char*)ghost_staple_gpu) + Vsh[dir]*gaugeSiteSize*precision ;
  if (X[dir] % 2 ==1){ //need switch even/odd
    gpu_buf_odd = ghost_staple_gpu;
    gpu_buf_even = ((char*)ghost_staple_gpu) + Vsh[dir]*gaugeSiteSize*precision ;
  }

  if (precision == QUDA_DOUBLE_PRECISION){
    collectGhostStapleForPrecision((double2*)even, (double2*)odd,
                                   (double2*)gpu_buf_even, (double2*)gpu_buf_odd,
                                   dir, whichway, gridDim, blockDim, *stream);
  }else if(precision == QUDA_SINGLE_PRECISION){
    collectGhostStapleForPrecision((float2*)even, (float2*)odd,
                                   (float2*)gpu_buf_even, (float2*)gpu_buf_odd,
                                   dir, whichway, gridDim, blockDim, *stream);
  }else{
    printf("ERROR: invalid precision for %s\n", __FUNCTION__);
    exit(1);
  }

}
635 
636 } // namespace quda
637 
638 #undef gaugeSiteSize
639 #undef BLOCKSIZE