QUDA v0.4.0
A library for QCD on GPUs
#include <misc_helpers.h>
#define gaugeSiteSize 18
#define BLOCKSIZE 64

/*
 * MILC order, CPU->GPU
 *
 * This function converts links in the CPU format into the GPU format,
 * so as to enable coalesced access.
 * The function only converts half (even or odd) of the links;
 * therefore converting the entire link field requires calling this
 * function twice.
 *
 * Without loss of generality, the parity is assumed to be even.
 * The actual data format on the CPU is the following:
 *   [a0a1 .... a17][b0b1...b17][c..][d...][a18a19 .....a35] ... [b0b1 ... b17] ...
 *     X links       Y links     Z,T links   X links              Y links
 * where a0->a17  is the X link in the first site
 *       b0->b17  is the Y link in the first site
 *       c0->c17  is the Z link in the first site
 *       d0->d17  is the T link in the first site
 *       a18->a35 is the X link in the second site
 *       etc.
 *
 * The GPU format of the data looks like the following:
 *   [a0a1][a18a19] ....[pad][a2a3][a20a21]..... [b0b1][b18b19]....
 *     X links                                     Y links          T,Z links
 *
 * N: # of FloatN in one gauge field
 *    9 for QUDA_RECONSTRUCT_NO, SP/DP
 *    6 for QUDA_RECONSTRUCT_12, DP
 *    3 for QUDA_RECONSTRUCT_12, SP
 */

template<int N, typename FloatN, typename Float2>
__global__ void
do_link_format_cpu_to_gpu(FloatN* dst, Float2* src,
                          int reconstruct,
                          int Vh, int pad, int ghostV, size_t threads)
{
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int thread0_tid = blockIdx.x * blockDim.x;
  __shared__ FloatN buf[N*BLOCKSIZE];

  int dir;
  int j;
  int stride = Vh + pad;
  for (dir = 0; dir < 4; dir++) {
#ifdef MULTI_GPU
    Float2* src_start = src + dir*9*(Vh + ghostV) + thread0_tid*9;
#else
    Float2* src_start = src + dir*9*(Vh) + thread0_tid*9;
#endif
    // stage the link elements of this block's sites into shared memory with coalesced loads
    for (j = 0; j < 9; j++) {
      if (thread0_tid*9 + j*blockDim.x + threadIdx.x >= 9*threads) break;
      if (N == 9) {
        ((Float2*)buf)[j*blockDim.x + threadIdx.x] = src_start[j*blockDim.x + threadIdx.x];
      } else {
        // 12-reconstruct: keep only the first 6 Float2 (two rows) of each link
        int idx = j*blockDim.x + threadIdx.x;
        int modval = idx % 9;
        int divval = idx / 9;
        if (modval < 6) {
          ((Float2*)buf)[divval*6 + modval] = src_start[idx];
        }
      }
    }

    __syncthreads();
    // write out in the strided GPU layout
    if (tid < threads) {
      FloatN* dst_start = (FloatN*)(dst + dir*N*stride);
      for (j = 0; j < N; j++) {
        dst_start[tid + j*stride] = buf[N*threadIdx.x + j];
      }
    }
    __syncthreads();
  } // dir
}
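/*
 * Illustrative sketch (not part of the QUDA API): a stand-alone host routine
 * showing the index mapping implemented by do_link_format_cpu_to_gpu for the
 * N == 9 (QUDA_RECONSTRUCT_NO), double-precision case, for one parity.
 * The function name and arguments are hypothetical.
 *
 *   src: direction-major, then site-major, 9 double2 per link:
 *        src[dir*9*V + site*9 + j]
 *   dst: direction-major, element-major with stride = Vh + pad:
 *        dst[dir*9*stride + j*stride + site]
 */
static void reorder_links_reference(double2 *dst, const double2 *src,
                                    int V /* Vh (+ ghostV) */, int stride /* Vh + pad */)
{
  for (int dir = 0; dir < 4; dir++)
    for (int site = 0; site < V; site++)
      for (int j = 0; j < 9; j++)
        dst[dir*9*stride + j*stride + site] = src[dir*9*V + site*9 + j];
}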
/*
 * N: # of FloatN in one gauge field
 *    9 for QUDA_RECONSTRUCT_NO, SP/DP
 *    6 for QUDA_RECONSTRUCT_12, DP
 *    3 for QUDA_RECONSTRUCT_12, SP
 *
 * FloatN: float2/double2
 * Float:  float/double
 *
 * This is the reverse process of the function do_link_format_gpu_to_cpu().
 */

template<int N, typename FloatN, typename Float2>
__global__ void
do_link_format_cpu_to_gpu_milc(FloatN* dst, Float2* src,
                               int reconstruct,
                               int Vh, int pad, int ghostV, size_t threads)
{
  __shared__ FloatN buf[N*BLOCKSIZE];
  int block_idx = blockIdx.x*blockDim.x/4;
  int local_idx = 16*(threadIdx.x/64) + threadIdx.x%16;
  int pos_idx = blockIdx.x*blockDim.x/4 + 16*(threadIdx.x/64) + threadIdx.x%16;
  int mydir = (threadIdx.x >> 4) % 4;
  int j;
  int stride = Vh + pad;

  for (j = 0; j < 9; j++) {
    if (block_idx*9*4 + j*blockDim.x + threadIdx.x >= 9*threads) break;
    if (N == 9) {
      ((Float2*)buf)[j*blockDim.x + threadIdx.x] = src[block_idx*9*4 + j*blockDim.x + threadIdx.x];
    } else {
      int idx = j*blockDim.x + threadIdx.x;
      int modval = idx % 9;
      int divval = idx / 9;
      if (modval < 6) {
        ((Float2*)buf)[divval*6 + modval] = src[block_idx*9*4 + idx];
      }
    }
  }

  __syncthreads();

  if (pos_idx >= threads/4) return;

  for (j = 0; j < N; j++) {
    if (N == 9) {
      dst[pos_idx + mydir*N*stride + j*stride] = buf[local_idx*4*9 + mydir*9 + j];
    } else {
      dst[pos_idx + mydir*N*stride + j*stride] = buf[local_idx*4*N + mydir*N + j];
    }
  }
}
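/*
 * Illustrative sketch (not QUDA code): the thread-to-work mapping used by
 * do_link_format_cpu_to_gpu_milc for one 64-thread block.  Each half warp of
 * 16 threads handles one direction (mydir) for the block's 16 sites
 * (local_idx); pos_idx adds the block offset.  The helper name is hypothetical.
 */
#include <cstdio>

static void print_milc_thread_mapping(int blockIdxX)
{
  const int blockDimX = BLOCKSIZE;                          // 64
  for (int threadIdxX = 0; threadIdxX < blockDimX; threadIdxX++) {
    int local_idx = 16*(threadIdxX/64) + threadIdxX%16;     // site within the block
    int mydir     = (threadIdxX >> 4) % 4;                  // direction handled by this thread
    int pos_idx   = blockIdxX*blockDimX/4 + local_idx;      // global (single-parity) site index
    printf("threadIdx.x=%2d  pos_idx=%3d  mydir=%d\n", threadIdxX, pos_idx, mydir);
  }
}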
void
link_format_cpu_to_gpu(void* dst, void* src,
                       int reconstruct, int Vh, int pad,
                       int ghostV,
                       QudaPrecision prec, QudaGaugeFieldOrder cpu_order,
                       cudaStream_t stream)
{
  dim3 blockDim(BLOCKSIZE);

  if (cpu_order == QUDA_QDP_GAUGE_ORDER) {
#ifdef MULTI_GPU
    size_t threads = Vh + ghostV;
#else
    size_t threads = Vh;
#endif
    dim3 gridDim((threads + BLOCKSIZE - 1)/BLOCKSIZE);

    switch (prec) {
    case QUDA_DOUBLE_PRECISION:
      switch (reconstruct) {
      case QUDA_RECONSTRUCT_NO:
        do_link_format_cpu_to_gpu<9><<<gridDim, blockDim, 0, stream>>>((double2*)dst, (double2*)src, reconstruct, Vh, pad, ghostV, threads);
        break;
      case QUDA_RECONSTRUCT_12:
        do_link_format_cpu_to_gpu<6><<<gridDim, blockDim, 0, stream>>>((double2*)dst, (double2*)src, reconstruct, Vh, pad, ghostV, threads);
        break;
      default:
        errorQuda("reconstruct type not supported\n");
      }
      break;

    case QUDA_SINGLE_PRECISION:
      switch (reconstruct) {
      case QUDA_RECONSTRUCT_NO:
        do_link_format_cpu_to_gpu<9><<<gridDim, blockDim, 0, stream>>>((float2*)dst, (float2*)src, reconstruct, Vh, pad, ghostV, threads);
        break;
      case QUDA_RECONSTRUCT_12:
        do_link_format_cpu_to_gpu<3><<<gridDim, blockDim, 0, stream>>>((float4*)dst, (float2*)src, reconstruct, Vh, pad, ghostV, threads);
        break;
      default:
        errorQuda("reconstruct type not supported\n");
      }
      break;

    default:
      errorQuda("ERROR: half precision not supported in %s\n", __FUNCTION__);
    }
  } else if (cpu_order == QUDA_MILC_GAUGE_ORDER) {
#ifdef MULTI_GPU
    int threads = 4*(Vh + ghostV);
#else
    int threads = 4*Vh;
#endif
    dim3 gridDim((threads + BLOCKSIZE - 1)/BLOCKSIZE);

    switch (prec) {
    case QUDA_DOUBLE_PRECISION:
      switch (reconstruct) {
      case QUDA_RECONSTRUCT_NO:
        do_link_format_cpu_to_gpu_milc<9><<<gridDim, blockDim, 0, stream>>>((double2*)dst, (double2*)src, reconstruct, Vh, pad, ghostV, threads);
        break;
      case QUDA_RECONSTRUCT_12:
        do_link_format_cpu_to_gpu_milc<6><<<gridDim, blockDim, 0, stream>>>((double2*)dst, (double2*)src, reconstruct, Vh, pad, ghostV, threads);
        break;
      default:
        errorQuda("reconstruct type not supported\n");
      }
      break;

    case QUDA_SINGLE_PRECISION:
      switch (reconstruct) {
      case QUDA_RECONSTRUCT_NO:
        do_link_format_cpu_to_gpu_milc<9><<<gridDim, blockDim, 0, stream>>>((float2*)dst, (float2*)src, reconstruct, Vh, pad, ghostV, threads);
        break;
      case QUDA_RECONSTRUCT_12:
        do_link_format_cpu_to_gpu_milc<3><<<gridDim, blockDim, 0, stream>>>((float4*)dst, (float2*)src, reconstruct, Vh, pad, ghostV, threads);
        break;
      default:
        errorQuda("reconstruct type not supported\n");
      }
      break;

    default:
      errorQuda("ERROR: half precision not supported in %s\n", __FUNCTION__);
    }
  } else {
    errorQuda("ERROR: invalid cpu ordering (%d)\n", cpu_order);
  }

  return;
}
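/*
 * A hedged usage sketch for link_format_cpu_to_gpu: converting a MILC-ordered,
 * double-precision gauge field for one parity.  The wrapper function name, the
 * 16^3 x 32 lattice, and the pad/ghost values below are illustrative only and
 * are not taken from QUDA's internal allocation logic.
 */
static void convert_links_example(void* links_gpu, void* links_cpu, cudaStream_t stream)
{
  int Vh     = (16*16*16*32)/2;  // half (single-parity) volume of the local lattice
  int pad    = 0;                // illustrative; QUDA normally uses a nonzero pad
  int ghostV = 0;                // no ghost sites assumed here

  link_format_cpu_to_gpu(links_gpu, links_cpu,
                         QUDA_RECONSTRUCT_NO, Vh, pad, ghostV,
                         QUDA_DOUBLE_PRECISION, QUDA_MILC_GAUGE_ORDER, stream);
}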
/*
 * src format: the normal link format on the GPU, with stride size @stride;
 *             the src is stored as 9 double2 per link
 * dst format: an array of links where the x,y,z,t links with the same site id are stored next to each other.
 *             This format is used as the destination in the fat-link computation on the CPU.
 * Without loss of generality, the parity is assumed to be even.
 * The actual data format on the GPU is the following:
 *   [a0a1][a18a19] ....[pad][a2a3][a20a21]..... [b0b1][b18b19]....
 *     X links                                     Y links          T,Z links
 * The temporary data stored in GPU shared memory, and the CPU format of the data, are the following:
 *   [a0a1 .... a17] [b0b1 .....b17] [c0c1 .....c17] [d0d1 .....d17] [a18a19....a35] ....
 *   |<------------------------site 0 ---------------------------->|<----- site 1 ----->
 *
 * In the loading phase, the indices for all threads in the first block are the following
 * (assuming a block size of 64; each half warp works on one direction):
 *
 *   threadIdx.x    pos_idx    mydir
 *    0 - 15        0 - 15       0
 *   16 - 31        0 - 15       1
 *   32 - 47        0 - 15       2
 *   48 - 63        0 - 15       3
 */

template<typename FloatN>
__global__ void
do_link_format_gpu_to_cpu(FloatN* dst, FloatN* src,
                          int Vh, int stride)
{
  __shared__ FloatN buf[gaugeSiteSize/2*BLOCKSIZE];

  int j;

  int block_idx = blockIdx.x*blockDim.x/4;
  int local_idx = 16*(threadIdx.x/64) + threadIdx.x%16;
  int pos_idx = blockIdx.x*blockDim.x/4 + 16*(threadIdx.x/64) + threadIdx.x%16;
  int mydir = (threadIdx.x >> 4) % 4;
  for (j = 0; j < 9; j++) {
    buf[local_idx*4*9 + mydir*9 + j] = src[pos_idx + mydir*9*stride + j*stride];
  }
  __syncthreads();

  for (j = 0; j < 9; j++) {
    dst[block_idx*9*4 + j*blockDim.x + threadIdx.x] = buf[j*blockDim.x + threadIdx.x];
  }
}

void
link_format_gpu_to_cpu(void* dst, void* src,
                       int Vh, int stride, QudaPrecision prec, cudaStream_t stream)
{
  dim3 blockDim(BLOCKSIZE);
  dim3 gridDim(4*Vh/blockDim.x); // every 4 threads process one site's x,y,z,t links
  // 4*Vh must be a multiple of BLOCKSIZE, or the kernel does not work
  if ((4*Vh) % blockDim.x != 0) {
    errorQuda("ERROR: 4*Vh(%d) is not a multiple of blocksize(%d), exiting\n", Vh, blockDim.x);
  }

  if (prec == QUDA_DOUBLE_PRECISION) {
    do_link_format_gpu_to_cpu<<<gridDim, blockDim, 0, stream>>>((double2*)dst, (double2*)src, Vh, stride);
  } else if (prec == QUDA_SINGLE_PRECISION) {
    do_link_format_gpu_to_cpu<<<gridDim, blockDim, 0, stream>>>((float2*)dst, (float2*)src, Vh, stride);
  } else {
    printf("ERROR: half precision is not supported in %s\n", __FUNCTION__);
    exit(1);
  }
}
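/*
 * A hedged usage sketch for link_format_gpu_to_cpu: writing the even-parity
 * links of the same hypothetical 16^3 x 32 lattice into the per-site
 * (x,y,z,t adjacent) layout described above.  Note that dst here is a device
 * buffer of at least 4*Vh*gaugeSiteSize reals, and 4*Vh must be a multiple of
 * BLOCKSIZE (64), which the wrapper checks.  Names and sizes are illustrative only.
 */
static void export_links_example(void* links_site_major_gpu, void* links_gpu,
                                 int stride, cudaStream_t stream)
{
  int Vh = (16*16*16*32)/2;
  link_format_gpu_to_cpu(links_site_major_gpu, links_gpu,
                         Vh, stride, QUDA_DOUBLE_PRECISION, stream);
}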
#define READ_ST_STAPLE(staple, idx, mystride)  \
  Float2 P0 = staple[idx + 0*mystride];        \
  Float2 P1 = staple[idx + 1*mystride];        \
  Float2 P2 = staple[idx + 2*mystride];        \
  Float2 P3 = staple[idx + 3*mystride];        \
  Float2 P4 = staple[idx + 4*mystride];        \
  Float2 P5 = staple[idx + 5*mystride];        \
  Float2 P6 = staple[idx + 6*mystride];        \
  Float2 P7 = staple[idx + 7*mystride];        \
  Float2 P8 = staple[idx + 8*mystride];

#define WRITE_ST_STAPLE(staple, idx, mystride) \
  staple[idx + 0*mystride] = P0;               \
  staple[idx + 1*mystride] = P1;               \
  staple[idx + 2*mystride] = P2;               \
  staple[idx + 3*mystride] = P3;               \
  staple[idx + 4*mystride] = P4;               \
  staple[idx + 5*mystride] = P5;               \
  staple[idx + 6*mystride] = P6;               \
  staple[idx + 7*mystride] = P7;               \
  staple[idx + 8*mystride] = P8;

template<int dir, int whichway, typename Float2>
__global__ void
collectGhostStapleKernel(Float2* in, const int oddBit,
                         Float2* nbr_staple_gpu)
{
  int sid = blockIdx.x*blockDim.x + threadIdx.x;
  int z1 = sid / X1h;
  int x1h = sid - z1*X1h;
  int z2 = z1 / X2;
  int x2 = z1 - z2*X2;
  int x4 = z2 / X3;
  int x3 = z2 - x4*X3;
  int x1odd = (x2 + x3 + x4 + oddBit) & 1;
  int x1 = 2*x1h + x1odd;
  //int X = 2*sid + x1odd;

  READ_ST_STAPLE(in, sid, staple_stride);
  int ghost_face_idx;

  if (dir == 0 && whichway == QUDA_BACKWARDS) {
    if (x1 < 1) {
      ghost_face_idx = (x4*(X3*X2) + x3*X2 + x2) >> 1;
      WRITE_ST_STAPLE(nbr_staple_gpu, ghost_face_idx, X4*X3*X2/2);
    }
  }

  if (dir == 0 && whichway == QUDA_FORWARDS) {
    if (x1 >= X1 - 1) {
      ghost_face_idx = (x4*(X3*X2) + x3*X2 + x2) >> 1;
      WRITE_ST_STAPLE(nbr_staple_gpu, ghost_face_idx, X4*X3*X2/2);
    }
  }

  if (dir == 1 && whichway == QUDA_BACKWARDS) {
    if (x2 < 1) {
      ghost_face_idx = (x4*X3*X1 + x3*X1 + x1) >> 1;
      WRITE_ST_STAPLE(nbr_staple_gpu, ghost_face_idx, X4*X3*X1/2);
    }
  }

  if (dir == 1 && whichway == QUDA_FORWARDS) {
    if (x2 >= X2 - 1) {
      ghost_face_idx = (x4*X3*X1 + x3*X1 + x1) >> 1;
      WRITE_ST_STAPLE(nbr_staple_gpu, ghost_face_idx, X4*X3*X1/2);
    }
  }

  if (dir == 2 && whichway == QUDA_BACKWARDS) {
    if (x3 < 1) {
      ghost_face_idx = (x4*X2*X1 + x2*X1 + x1) >> 1;
      WRITE_ST_STAPLE(nbr_staple_gpu, ghost_face_idx, X4*X2*X1/2);
    }
  }

  if (dir == 2 && whichway == QUDA_FORWARDS) {
    if (x3 >= X3 - 1) {
      ghost_face_idx = (x4*X2*X1 + x2*X1 + x1) >> 1;
      WRITE_ST_STAPLE(nbr_staple_gpu, ghost_face_idx, X4*X2*X1/2);
    }
  }

  if (dir == 3 && whichway == QUDA_BACKWARDS) {
    if (x4 < 1) {
      ghost_face_idx = (x3*X2*X1 + x2*X1 + x1) >> 1;
      WRITE_ST_STAPLE(nbr_staple_gpu, ghost_face_idx, X3*X2*X1/2);
    }
  }

  if (dir == 3 && whichway == QUDA_FORWARDS) {
    if (x4 >= X4 - 1) {
      ghost_face_idx = (x3*X2*X1 + x2*X1 + x1) >> 1;
      WRITE_ST_STAPLE(nbr_staple_gpu, ghost_face_idx, X3*X2*X1/2);
    }
  }
}
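/*
 * A stand-alone host sketch (illustrative only, not part of the QUDA API) of the
 * coordinate decoding used in collectGhostStapleKernel: recover (x1,x2,x3,x4)
 * from a single-parity site index sid, then compute the ghost-face index for a
 * site on the forward T boundary (dir == 3, QUDA_FORWARDS).  dimX1..dimX4 stand
 * in for the local lattice dimensions X1..X4 used by the kernel; a return value
 * of -1 means the site is not on that boundary.
 */
static int ghost_face_index_t_forward(int sid, int oddBit,
                                      int dimX1, int dimX2, int dimX3, int dimX4)
{
  int x1h_max = dimX1/2;                      // X1h in the kernel above
  int z1  = sid / x1h_max;
  int x1h = sid - z1*x1h_max;
  int z2  = z1 / dimX2;
  int x2  = z1 - z2*dimX2;
  int x4  = z2 / dimX3;
  int x3  = z2 - x4*dimX3;
  int x1odd = (x2 + x3 + x4 + oddBit) & 1;    // checkerboard offset in x1
  int x1  = 2*x1h + x1odd;

  if (x4 < dimX4 - 1) return -1;              // not on the forward T face
  return (x3*dimX2*dimX1 + x2*dimX1 + x1) >> 1;  // index into the X3*X2*X1/2 face buffer
}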
//@dir can be 0, 1, 2, 3 (X,Y,Z,T directions)
//@whichway can be QUDA_FORWARDS, QUDA_BACKWARDS
void
collectGhostStaple(int* X, void* even, void* odd, int volume, QudaPrecision precision,
                   void* ghost_staple_gpu,
                   int dir, int whichway, cudaStream_t* stream)
{
  int Vsh_x, Vsh_y, Vsh_z, Vsh_t;

  Vsh_x = X[1]*X[2]*X[3]/2;
  Vsh_y = X[0]*X[2]*X[3]/2;
  Vsh_z = X[0]*X[1]*X[3]/2;
  Vsh_t = X[0]*X[1]*X[2]/2;

  dim3 gridDim(volume/BLOCKSIZE, 1, 1);
  dim3 blockDim(BLOCKSIZE, 1, 1);
  int Vsh[4] = {Vsh_x, Vsh_y, Vsh_z, Vsh_t};

  void* gpu_buf_even = ghost_staple_gpu;
  void* gpu_buf_odd = ((char*)ghost_staple_gpu) + Vsh[dir]*gaugeSiteSize*precision;
  if (X[dir] % 2 == 1) { // need to switch even/odd
    gpu_buf_odd = ghost_staple_gpu;
    gpu_buf_even = ((char*)ghost_staple_gpu) + Vsh[dir]*gaugeSiteSize*precision;
  }

  int even_parity = 0;
  int odd_parity = 1;

  if (precision == QUDA_DOUBLE_PRECISION) {
    switch (dir) {
    case 0:
      switch (whichway) {
      case QUDA_BACKWARDS:
        collectGhostStapleKernel<0, QUDA_BACKWARDS><<<gridDim, blockDim, 0, *stream>>>((double2*)even, even_parity, (double2*)gpu_buf_even);
        collectGhostStapleKernel<0, QUDA_BACKWARDS><<<gridDim, blockDim, 0, *stream>>>((double2*)odd, odd_parity, (double2*)gpu_buf_odd);
        break;
      case QUDA_FORWARDS:
        collectGhostStapleKernel<0, QUDA_FORWARDS><<<gridDim, blockDim, 0, *stream>>>((double2*)even, even_parity, (double2*)gpu_buf_even);
        collectGhostStapleKernel<0, QUDA_FORWARDS><<<gridDim, blockDim, 0, *stream>>>((double2*)odd, odd_parity, (double2*)gpu_buf_odd);
        break;
      default:
        errorQuda("Invalid whichway");
        break;
      }
      break;

    case 1:
      switch (whichway) {
      case QUDA_BACKWARDS:
        collectGhostStapleKernel<1, QUDA_BACKWARDS><<<gridDim, blockDim, 0, *stream>>>((double2*)even, even_parity, (double2*)gpu_buf_even);
        collectGhostStapleKernel<1, QUDA_BACKWARDS><<<gridDim, blockDim, 0, *stream>>>((double2*)odd, odd_parity, (double2*)gpu_buf_odd);
        break;
      case QUDA_FORWARDS:
        collectGhostStapleKernel<1, QUDA_FORWARDS><<<gridDim, blockDim, 0, *stream>>>((double2*)even, even_parity, (double2*)gpu_buf_even);
        collectGhostStapleKernel<1, QUDA_FORWARDS><<<gridDim, blockDim, 0, *stream>>>((double2*)odd, odd_parity, (double2*)gpu_buf_odd);
        break;
      default:
        errorQuda("Invalid whichway");
        break;
      }
      break;

    case 2:
      switch (whichway) {
      case QUDA_BACKWARDS:
        collectGhostStapleKernel<2, QUDA_BACKWARDS><<<gridDim, blockDim, 0, *stream>>>((double2*)even, even_parity, (double2*)gpu_buf_even);
        collectGhostStapleKernel<2, QUDA_BACKWARDS><<<gridDim, blockDim, 0, *stream>>>((double2*)odd, odd_parity, (double2*)gpu_buf_odd);
        break;
      case QUDA_FORWARDS:
        collectGhostStapleKernel<2, QUDA_FORWARDS><<<gridDim, blockDim, 0, *stream>>>((double2*)even, even_parity, (double2*)gpu_buf_even);
        collectGhostStapleKernel<2, QUDA_FORWARDS><<<gridDim, blockDim, 0, *stream>>>((double2*)odd, odd_parity, (double2*)gpu_buf_odd);
        break;
      default:
        errorQuda("Invalid whichway");
        break;
      }
      break;

    case 3:
      switch (whichway) {
      case QUDA_BACKWARDS:
        collectGhostStapleKernel<3, QUDA_BACKWARDS><<<gridDim, blockDim, 0, *stream>>>((double2*)even, even_parity, (double2*)gpu_buf_even);
        collectGhostStapleKernel<3, QUDA_BACKWARDS><<<gridDim, blockDim, 0, *stream>>>((double2*)odd, odd_parity, (double2*)gpu_buf_odd);
        break;
      case QUDA_FORWARDS:
        collectGhostStapleKernel<3, QUDA_FORWARDS><<<gridDim, blockDim, 0, *stream>>>((double2*)even, even_parity, (double2*)gpu_buf_even);
        collectGhostStapleKernel<3, QUDA_FORWARDS><<<gridDim, blockDim, 0, *stream>>>((double2*)odd, odd_parity, (double2*)gpu_buf_odd);
        break;
      default:
        errorQuda("Invalid whichway");
        break;
      }
      break;
    }
  } else if (precision == QUDA_SINGLE_PRECISION) {
    switch (dir) {
    case 0:
      switch (whichway) {
      case QUDA_BACKWARDS:
        collectGhostStapleKernel<0, QUDA_BACKWARDS><<<gridDim, blockDim, 0, *stream>>>((float2*)even, even_parity, (float2*)gpu_buf_even);
        collectGhostStapleKernel<0, QUDA_BACKWARDS><<<gridDim, blockDim, 0, *stream>>>((float2*)odd, odd_parity, (float2*)gpu_buf_odd);
        break;
      case QUDA_FORWARDS:
        collectGhostStapleKernel<0, QUDA_FORWARDS><<<gridDim, blockDim, 0, *stream>>>((float2*)even, even_parity, (float2*)gpu_buf_even);
        collectGhostStapleKernel<0, QUDA_FORWARDS><<<gridDim, blockDim, 0, *stream>>>((float2*)odd, odd_parity, (float2*)gpu_buf_odd);
        break;
      default:
        errorQuda("Invalid whichway");
        break;
      }
      break;

    case 1:
      switch (whichway) {
      case QUDA_BACKWARDS:
        collectGhostStapleKernel<1, QUDA_BACKWARDS><<<gridDim, blockDim, 0, *stream>>>((float2*)even, even_parity, (float2*)gpu_buf_even);
        collectGhostStapleKernel<1, QUDA_BACKWARDS><<<gridDim, blockDim, 0, *stream>>>((float2*)odd, odd_parity, (float2*)gpu_buf_odd);
        break;
      case QUDA_FORWARDS:
        collectGhostStapleKernel<1, QUDA_FORWARDS><<<gridDim, blockDim, 0, *stream>>>((float2*)even, even_parity, (float2*)gpu_buf_even);
        collectGhostStapleKernel<1, QUDA_FORWARDS><<<gridDim, blockDim, 0, *stream>>>((float2*)odd, odd_parity, (float2*)gpu_buf_odd);
        break;
      default:
        errorQuda("Invalid whichway");
        break;
      }
      break;

    case 2:
      switch (whichway) {
      case QUDA_BACKWARDS:
        collectGhostStapleKernel<2, QUDA_BACKWARDS><<<gridDim, blockDim, 0, *stream>>>((float2*)even, even_parity, (float2*)gpu_buf_even);
        collectGhostStapleKernel<2, QUDA_BACKWARDS><<<gridDim, blockDim, 0, *stream>>>((float2*)odd, odd_parity, (float2*)gpu_buf_odd);
        break;
      case QUDA_FORWARDS:
        collectGhostStapleKernel<2, QUDA_FORWARDS><<<gridDim, blockDim, 0, *stream>>>((float2*)even, even_parity, (float2*)gpu_buf_even);
        collectGhostStapleKernel<2, QUDA_FORWARDS><<<gridDim, blockDim, 0, *stream>>>((float2*)odd, odd_parity, (float2*)gpu_buf_odd);
        break;
      default:
        errorQuda("Invalid whichway");
        break;
      }
      break;

    case 3:
      switch (whichway) {
      case QUDA_BACKWARDS:
        collectGhostStapleKernel<3, QUDA_BACKWARDS><<<gridDim, blockDim, 0, *stream>>>((float2*)even, even_parity, (float2*)gpu_buf_even);
        collectGhostStapleKernel<3, QUDA_BACKWARDS><<<gridDim, blockDim, 0, *stream>>>((float2*)odd, odd_parity, (float2*)gpu_buf_odd);
        break;
      case QUDA_FORWARDS:
        collectGhostStapleKernel<3, QUDA_FORWARDS><<<gridDim, blockDim, 0, *stream>>>((float2*)even, even_parity, (float2*)gpu_buf_even);
        collectGhostStapleKernel<3, QUDA_FORWARDS><<<gridDim, blockDim, 0, *stream>>>((float2*)odd, odd_parity, (float2*)gpu_buf_odd);
        break;
      default:
        errorQuda("Invalid whichway");
        break;
      }
      break;
    }
  } else {
    printf("ERROR: invalid precision for %s\n", __FUNCTION__);
    exit(1);
  }
}

#undef gaugeSiteSize
#undef BLOCKSIZE
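/*
 * A hedged usage sketch for collectGhostStaple: gathering the forward T-face of
 * a staple field into a device-side ghost buffer before it is exchanged with a
 * neighbouring process.  The wrapper name is hypothetical; the site count passed
 * as volume must follow whatever convention the surrounding QUDA code uses for
 * this kernel's grid (and be a multiple of the 64-thread block size), and the
 * ghost buffer must hold two faces of Vsh_t sites with 18 reals per site.
 */
static void collect_t_forward_example(int* X, void* staple_even, void* staple_odd,
                                      int volume, void* ghost_staple_gpu,
                                      cudaStream_t* stream)
{
  collectGhostStaple(X, staple_even, staple_odd, volume, QUDA_DOUBLE_PRECISION,
                     ghost_staple_gpu, 3 /* T direction */, QUDA_FORWARDS, stream);
}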