QUDA v0.4.0
A library for QCD on GPUs
quda/lib/misc_helpers.cu

#include <misc_helpers.h>
#define gaugeSiteSize 18
#define BLOCKSIZE 64

/*
 * QDP order, CPU->GPU
 *
 * This function converts the CPU data layout into the GPU layout
 * so as to enable coalesced access.
 * The function only converts half (even or odd) of the links;
 * converting the entire link field therefore requires calling this
 * function twice.
 *
 * Without loss of generality, the parity is assumed to be even.
 * The actual CPU data layout is the following (all X links are stored
 * contiguously, followed by all Y, Z and T links):
 * [a0a1 .... a17][a18a19 .... a35] ....[b0b1 ... b17][b18b19 ... b35] ....
 *  X links                             Y links, then Z and T links
 * where a0->a17  is the X link on the first site
 *       a18->a35 is the X link on the second site
 *       b0->b17  is the Y link on the first site
 *       etc.
 *
 * The GPU data layout looks like the following
 * [a0a1][a18a19]  ....[pad][a2a3][a20a21]..... [b0b1][b18b19]....
 *  X links                                      Y links, then Z and T links
 *
 * N: # of FloatN in one gauge field
 *    9 for QUDA_RECONSTRUCT_NO, SP/DP
 *    6 for QUDA_RECONSTRUCT_12, DP
 *    3 for QUDA_RECONSTRUCT_12, SP
 */


template<int N, typename FloatN, typename Float2>
  __global__ void
  do_link_format_cpu_to_gpu(FloatN* dst, Float2* src,
                            int reconstruct,
                            int Vh, int pad, int ghostV, size_t threads)
{
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int thread0_tid = blockIdx.x * blockDim.x;
  __shared__ FloatN buf[N*BLOCKSIZE];

  int dir;
  int j;
  int stride = Vh+pad;
  for(dir = 0; dir < 4; dir++){
#ifdef MULTI_GPU
    Float2* src_start = src + dir*9*(Vh+ghostV) + thread0_tid*9;
#else
    Float2* src_start = src + dir*9*(Vh) + thread0_tid*9;
#endif
    // stage this block's links into shared memory; when N != 9 only the
    // first 6 of the 9 Float2 per link (the first two rows) are kept
    for(j=0; j < 9; j++){
      if(thread0_tid*9+j*blockDim.x+threadIdx.x >= 9*threads) break;
      if( N == 9){
        ((Float2*)buf)[j*blockDim.x + threadIdx.x] = src_start[j*blockDim.x + threadIdx.x];
      }else{
        int idx = j*blockDim.x + threadIdx.x;
        int modval = idx % 9;
        int divval = idx / 9;
        if(modval < 6){
          ((Float2*)buf)[divval*6+modval] = src_start[idx];
        }
      }
    }

    __syncthreads();
    // write out in the strided, direction-major GPU layout with coalesced stores
    if(tid < threads){
      FloatN* dst_start = (FloatN*)(dst+dir*N*stride);
      for(j=0; j < N; j++){
        dst_start[tid + j*stride] = buf[N*threadIdx.x + j];
      }
    }
    __syncthreads();
  }//dir
}
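
/* A hedged, host-side reference of the same reordering, illustrative only and
   not part of the library (the function name is an assumption). It covers the
   N = 9 (QUDA_RECONSTRUCT_NO) case; V stands for Vh (plus ghostV when
   MULTI_GPU is defined) and stride = Vh + pad, as in the kernel above. */
#if 0
template<typename Float2>
void link_format_cpu_to_gpu_reference(Float2* dst, const Float2* src,
                                      int V, int stride)
{
  for(int dir = 0; dir < 4; dir++)
    for(int i = 0; i < V; i++)
      for(int j = 0; j < 9; j++)
        // CPU: the 9 Float2 of one link are contiguous, one direction at a time
        // GPU: direction-major, each of the 9 components strided by `stride`
        dst[dir*9*stride + j*stride + i] = src[dir*9*V + i*9 + j];
}
#endif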


/*
 * MILC order, CPU->GPU
 *
 * N: # of FloatN in one gauge field
 *    9 for QUDA_RECONSTRUCT_NO, SP/DP
 *    6 for QUDA_RECONSTRUCT_12, DP
 *    3 for QUDA_RECONSTRUCT_12, SP
 *
 * FloatN: float2/double2 (float4 for QUDA_RECONSTRUCT_12 in SP)
 * Float: float/double
 *
 * This is the reverse process of the function do_link_format_gpu_to_cpu()
 *
 */

template<int N, typename FloatN, typename Float2>
  __global__ void
  do_link_format_cpu_to_gpu_milc(FloatN* dst, Float2* src,
                                 int reconstruct,
                                 int Vh, int pad, int ghostV, size_t threads)
{

  __shared__ FloatN buf[N*BLOCKSIZE];
  // each block of 64 threads handles 16 sites; a half warp (16 threads)
  // works on one direction
  int block_idx = blockIdx.x*blockDim.x/4;
  int local_idx = 16*(threadIdx.x/64) + threadIdx.x%16;
  int pos_idx = blockIdx.x * blockDim.x/4 + 16*(threadIdx.x/64) + threadIdx.x%16;
  int mydir = (threadIdx.x >> 4)% 4;
  int j;
  int stride = Vh+pad;

  // stage the site-major MILC data into shared memory with coalesced reads
  for(j=0; j < 9; j++){
    if(block_idx*9*4 + j*blockDim.x+threadIdx.x >= 9*threads) break;
    if(N == 9){
      ((Float2*)buf)[j*blockDim.x + threadIdx.x] = src[block_idx*9*4 + j*blockDim.x + threadIdx.x];
    }else{
      int idx = j*blockDim.x + threadIdx.x;
      int modval = idx % 9;
      int divval = idx / 9;
      if(modval < 6){
        ((Float2*)buf)[divval*6+modval] = src[block_idx*9*4 + idx];
      }
    }
  }

  __syncthreads();

  if(pos_idx >= threads/4) return;

  // write out in the strided, direction-major GPU layout
  for(j=0; j < N; j++){
    if(N == 9){
      dst[pos_idx + mydir*N*stride + j*stride] = buf[local_idx*4*9+mydir*9+j];
    }else{
      dst[pos_idx + mydir*N*stride + j*stride] = buf[local_idx*4*N+mydir*N+j];
    }
  }
}
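
/* A similar illustrative host-side sketch (name assumed, not part of the
   library) of do_link_format_cpu_to_gpu_milc for N = 9: MILC stores the
   x,y,z,t links of one site next to each other, site by site. */
#if 0
template<typename Float2>
void link_format_cpu_to_gpu_milc_reference(Float2* dst, const Float2* src,
                                           int V, int stride)
{
  for(int i = 0; i < V; i++)
    for(int dir = 0; dir < 4; dir++)
      for(int j = 0; j < 9; j++)
        // site-major MILC input -> strided, direction-major GPU output
        dst[i + dir*9*stride + j*stride] = src[i*4*9 + dir*9 + j];
}
#endif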

void
link_format_cpu_to_gpu(void* dst, void* src,
                       int reconstruct, int Vh, int pad,
                       int ghostV,
                       QudaPrecision prec, QudaGaugeFieldOrder cpu_order,
                       cudaStream_t stream)
{
  dim3 blockDim(BLOCKSIZE);
  if(cpu_order == QUDA_QDP_GAUGE_ORDER){
#ifdef MULTI_GPU
    size_t threads=Vh+ghostV;
#else
    size_t threads=Vh;
#endif
    dim3 gridDim ((threads + BLOCKSIZE -1)/BLOCKSIZE);

    switch (prec){
    case QUDA_DOUBLE_PRECISION:
      switch( reconstruct){
      case QUDA_RECONSTRUCT_NO:
        do_link_format_cpu_to_gpu<9><<<gridDim, blockDim, 0, stream>>>((double2*)dst, (double2*)src, reconstruct, Vh, pad, ghostV, threads);
        break;
      case QUDA_RECONSTRUCT_12:
        do_link_format_cpu_to_gpu<6><<<gridDim, blockDim, 0, stream>>>((double2*)dst, (double2*)src, reconstruct, Vh, pad, ghostV, threads);
        break;
      default:
        errorQuda("reconstruct type not supported\n");
      }
      break;

    case QUDA_SINGLE_PRECISION:
      switch( reconstruct){
      case QUDA_RECONSTRUCT_NO:
        do_link_format_cpu_to_gpu<9><<<gridDim, blockDim, 0, stream>>>((float2*)dst, (float2*)src, reconstruct, Vh, pad, ghostV, threads);
        break;
      case QUDA_RECONSTRUCT_12:
        do_link_format_cpu_to_gpu<3><<<gridDim, blockDim, 0, stream>>>((float4*)dst, (float2*)src, reconstruct, Vh, pad, ghostV, threads);
        break;
      default:
        errorQuda("reconstruct type not supported\n");
      }
      break;

    default:
      errorQuda("ERROR: half precision not supported in %s\n", __FUNCTION__);
    }
  }else if (cpu_order == QUDA_MILC_GAUGE_ORDER){
#ifdef MULTI_GPU
    int threads=4*(Vh+ghostV);
#else
    int threads=4*Vh;
#endif
    dim3 gridDim ((threads + BLOCKSIZE -1)/BLOCKSIZE);

    switch (prec){
    case QUDA_DOUBLE_PRECISION:
      switch( reconstruct){
      case QUDA_RECONSTRUCT_NO:
        do_link_format_cpu_to_gpu_milc<9><<<gridDim, blockDim, 0, stream>>>((double2*)dst, (double2*)src, reconstruct, Vh, pad, ghostV, threads);
        break;
      case QUDA_RECONSTRUCT_12:
        do_link_format_cpu_to_gpu_milc<6><<<gridDim, blockDim, 0, stream>>>((double2*)dst, (double2*)src, reconstruct, Vh, pad, ghostV, threads);
        break;
      default:
        errorQuda("reconstruct type not supported\n");
      }
      break;

    case QUDA_SINGLE_PRECISION:
      switch( reconstruct){
      case QUDA_RECONSTRUCT_NO:
        do_link_format_cpu_to_gpu_milc<9><<<gridDim, blockDim, 0, stream>>>((float2*)dst, (float2*)src, reconstruct, Vh, pad, ghostV, threads);
        break;
      case QUDA_RECONSTRUCT_12:
        do_link_format_cpu_to_gpu_milc<3><<<gridDim, blockDim, 0, stream>>>((float4*)dst, (float2*)src, reconstruct, Vh, pad, ghostV, threads);
        break;
      default:
        errorQuda("reconstruct type not supported\n");
      }
      break;

    default:
      errorQuda("ERROR: half precision not supported in %s\n", __FUNCTION__);
    }

  }else{
    errorQuda("ERROR: invalid cpu ordering (%d)\n", cpu_order);
  }

  return;

}
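
/* Illustrative usage sketch, not part of the library; the buffer names and
   the assumption that the MILC-ordered source already resides in device
   memory are hypothetical. */
#if 0
void example_cpu_to_gpu(void* gpu_links, void* milc_links_device,
                        int Vh, int pad, int ghostV, cudaStream_t stream)
{
  // convert one parity of a double-precision, no-reconstruct gauge field
  link_format_cpu_to_gpu(gpu_links, milc_links_device,
                         QUDA_RECONSTRUCT_NO, Vh, pad, ghostV,
                         QUDA_DOUBLE_PRECISION, QUDA_MILC_GAUGE_ORDER, stream);
}
#endif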
/*
 * src format: the normal link format on the GPU, with stride size @stride;
 *             each link is stored as 9 FloatN (double2/float2)
 * dst format: an array of links where the x,y,z,t links with the same site id
 *             are stored next to each other.
 *             This format is used as the destination in the fatlink
 *             computation on the CPU.
 * Without loss of generality, the parity is assumed to be even.
 * The actual GPU data layout is the following
 *    [a0a1][a18a19]  ....[pad][a2a3][a20a21]..... [b0b1][b18b19]....
 *    X links                                      Y links, then Z and T links
 * The temporary data staged in GPU shared memory, and the CPU data layout,
 * are the following
 *    [a0a1 .... a17] [b0b1 .....b17] [c0c1 .....c17] [d0d1 .....d17] [a18a19....a35] ....
 *    |<------------------------site 0 ---------------------------->|<----- site 1 ----->
 *
 * In the loading phase, the indices for the threads of the first block are
 * the following (the block size is assumed to be 64; each half warp works
 * on one direction):
 *
 *   threadIdx.x    pos_idx    mydir
 *    0 - 15         0 - 15      0
 *   16 - 31         0 - 15      1
 *   32 - 47         0 - 15      2
 *   48 - 63         0 - 15      3
 *
 */

template<typename FloatN>
__global__ void
do_link_format_gpu_to_cpu(FloatN* dst, FloatN* src,
                          int Vh, int stride)
{
  __shared__ FloatN buf[gaugeSiteSize/2*BLOCKSIZE];

  int j;

  // each block of 64 threads handles 16 sites; a half warp loads one direction
  int block_idx = blockIdx.x*blockDim.x/4;
  int local_idx = 16*(threadIdx.x/64) + threadIdx.x%16;
  int pos_idx = blockIdx.x * blockDim.x/4 + 16*(threadIdx.x/64) + threadIdx.x%16;
  int mydir = (threadIdx.x >> 4)% 4;
  for(j=0; j < 9; j++){
    buf[local_idx*4*9+mydir*9+j] = src[pos_idx + mydir*9*stride + j*stride];
  }
  __syncthreads();

  // write the site-major CPU layout out of shared memory with coalesced stores
  for(j=0; j < 9; j++){
    dst[block_idx*9*4 + j*blockDim.x + threadIdx.x ] = buf[j*blockDim.x + threadIdx.x];
  }

}
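
/* A hedged host-side reference of the same reordering (name assumed,
   illustrative only), handy for validating the kernel output. */
#if 0
template<typename FloatN>
void link_format_gpu_to_cpu_reference(FloatN* dst, const FloatN* src,
                                      int Vh, int stride)
{
  for(int i = 0; i < Vh; i++)
    for(int dir = 0; dir < 4; dir++)
      for(int j = 0; j < 9; j++)
        // strided, direction-major GPU layout -> site-major CPU layout
        dst[i*4*9 + dir*9 + j] = src[i + dir*9*stride + j*stride];
}
#endif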


void
link_format_gpu_to_cpu(void* dst, void* src,
                       int Vh, int stride, QudaPrecision prec, cudaStream_t stream)
{

  dim3 blockDim(BLOCKSIZE);
  dim3 gridDim(4*Vh/blockDim.x); //every 4 threads process one site's x,y,z,t links
  //4*Vh must be a multiple of BLOCKSIZE or the kernel does not work
  if ((4*Vh) % blockDim.x != 0){
    errorQuda("ERROR: 4*Vh(%d) is not a multiple of blocksize(%d), exiting\n", 4*Vh, blockDim.x);
  }
  if(prec == QUDA_DOUBLE_PRECISION){
    do_link_format_gpu_to_cpu<<<gridDim, blockDim, 0, stream>>>((double2*)dst, (double2*)src, Vh, stride);
  }else if(prec == QUDA_SINGLE_PRECISION){
    do_link_format_gpu_to_cpu<<<gridDim, blockDim, 0, stream>>>((float2*)dst, (float2*)src, Vh, stride);
  }else{
    printf("ERROR: half precision is not supported in %s\n",__FUNCTION__);
    exit(1);
  }

}

#define READ_ST_STAPLE(staple, idx, mystride)           \
  Float2 P0 = staple[idx + 0*mystride];                 \
  Float2 P1 = staple[idx + 1*mystride];                 \
  Float2 P2 = staple[idx + 2*mystride];                 \
  Float2 P3 = staple[idx + 3*mystride];                 \
  Float2 P4 = staple[idx + 4*mystride];                 \
  Float2 P5 = staple[idx + 5*mystride];                 \
  Float2 P6 = staple[idx + 6*mystride];                 \
  Float2 P7 = staple[idx + 7*mystride];                 \
  Float2 P8 = staple[idx + 8*mystride];

#define WRITE_ST_STAPLE(staple, idx, mystride)          \
  staple[idx + 0*mystride] = P0;                        \
  staple[idx + 1*mystride] = P1;                        \
  staple[idx + 2*mystride] = P2;                        \
  staple[idx + 3*mystride] = P3;                        \
  staple[idx + 4*mystride] = P4;                        \
  staple[idx + 5*mystride] = P5;                        \
  staple[idx + 6*mystride] = P6;                        \
  staple[idx + 7*mystride] = P7;                        \
  staple[idx + 8*mystride] = P8;


template<int dir, int whichway, typename Float2>
  __global__ void
  collectGhostStapleKernel(Float2* in, const int oddBit,
                           Float2* nbr_staple_gpu)
{

  int sid = blockIdx.x*blockDim.x + threadIdx.x;
  // decompose the even/odd site index into the 4-d coordinates (x1,x2,x3,x4)
  int z1 = sid / X1h;
  int x1h = sid - z1*X1h;
  int z2 = z1 / X2;
  int x2 = z1 - z2*X2;
  int x4 = z2 / X3;
  int x3 = z2 - x4*X3;
  int x1odd = (x2 + x3 + x4 + oddBit) & 1;
  int x1 = 2*x1h + x1odd;
  //int X = 2*sid + x1odd;

  READ_ST_STAPLE(in, sid, staple_stride);
  int ghost_face_idx;

  // if this site lies on the requested boundary face, write its staple
  // to the contiguous ghost buffer
  if ( dir == 0 && whichway == QUDA_BACKWARDS){
    if (x1 < 1){
      ghost_face_idx = (x4*(X3*X2)+x3*X2 +x2)>>1;
      WRITE_ST_STAPLE(nbr_staple_gpu, ghost_face_idx, X4*X3*X2/2);
    }
  }

  if ( dir == 0 && whichway == QUDA_FORWARDS){
    if (x1 >= X1 - 1){
      ghost_face_idx = (x4*(X3*X2)+x3*X2 +x2)>>1;
      WRITE_ST_STAPLE(nbr_staple_gpu, ghost_face_idx, X4*X3*X2/2);
    }
  }

  if ( dir == 1 && whichway == QUDA_BACKWARDS){
    if (x2 < 1){
      ghost_face_idx = (x4*X3*X1+x3*X1+x1)>>1;
      WRITE_ST_STAPLE(nbr_staple_gpu, ghost_face_idx, X4*X3*X1/2);
    }
  }

  if ( dir == 1 && whichway == QUDA_FORWARDS){
    if (x2 >= X2 - 1){
      ghost_face_idx = (x4*X3*X1+x3*X1+x1)>>1;
      WRITE_ST_STAPLE(nbr_staple_gpu, ghost_face_idx, X4*X3*X1/2);
    }
  }

  if ( dir == 2 && whichway == QUDA_BACKWARDS){
    if (x3 < 1){
      ghost_face_idx = (x4*X2*X1+x2*X1+x1)>>1;
      WRITE_ST_STAPLE(nbr_staple_gpu, ghost_face_idx, X4*X2*X1/2);
    }
  }

  if ( dir == 2 && whichway == QUDA_FORWARDS){
    if (x3 >= X3 - 1){
      ghost_face_idx = (x4*X2*X1 + x2*X1 + x1)>>1;
      WRITE_ST_STAPLE(nbr_staple_gpu, ghost_face_idx, X4*X2*X1/2);
    }
  }

  if ( dir == 3 && whichway == QUDA_BACKWARDS){
    if (x4 < 1){
      ghost_face_idx = (x3*X2*X1+x2*X1+x1)>>1;
      WRITE_ST_STAPLE(nbr_staple_gpu, ghost_face_idx, X3*X2*X1/2);
    }
  }

  if ( dir == 3 && whichway == QUDA_FORWARDS){
    if (x4 >= X4 - 1){
      ghost_face_idx = (x3*X2*X1+x2*X1+x1)>>1;
      WRITE_ST_STAPLE(nbr_staple_gpu, ghost_face_idx, X3*X2*X1/2);
    }
  }

}

//@dir can be 0, 1, 2, 3 (X,Y,Z,T directions)
//@whichway can be QUDA_FORWARDS, QUDA_BACKWARDS
void
collectGhostStaple(int* X, void* even, void* odd, int volume, QudaPrecision precision,
                   void* ghost_staple_gpu,
                   int dir, int whichway, cudaStream_t* stream)
{
  int Vsh_x, Vsh_y, Vsh_z, Vsh_t;

  Vsh_x = X[1]*X[2]*X[3]/2;
  Vsh_y = X[0]*X[2]*X[3]/2;
  Vsh_z = X[0]*X[1]*X[3]/2;
  Vsh_t = X[0]*X[1]*X[2]/2;

  dim3 gridDim(volume/BLOCKSIZE, 1, 1); //volume is assumed to be a multiple of BLOCKSIZE
  dim3 blockDim(BLOCKSIZE, 1, 1);
  int Vsh[4] = {Vsh_x, Vsh_y, Vsh_z, Vsh_t};

  void* gpu_buf_even = ghost_staple_gpu;
  void* gpu_buf_odd = ((char*)ghost_staple_gpu) + Vsh[dir]*gaugeSiteSize*precision;
  if (X[dir] % 2 == 1){ //need to switch even/odd
    gpu_buf_odd = ghost_staple_gpu;
    gpu_buf_even = ((char*)ghost_staple_gpu) + Vsh[dir]*gaugeSiteSize*precision;
  }

  int even_parity = 0;
  int odd_parity = 1;

  if (precision == QUDA_DOUBLE_PRECISION){
    switch(dir){
    case 0:
      switch(whichway){
      case QUDA_BACKWARDS:
        collectGhostStapleKernel<0, QUDA_BACKWARDS><<<gridDim, blockDim, 0, *stream>>>((double2*)even, even_parity, (double2*)gpu_buf_even);
        collectGhostStapleKernel<0, QUDA_BACKWARDS><<<gridDim, blockDim, 0, *stream>>>((double2*)odd, odd_parity, (double2*)gpu_buf_odd);
        break;
      case QUDA_FORWARDS:
        collectGhostStapleKernel<0, QUDA_FORWARDS><<<gridDim, blockDim, 0, *stream>>>((double2*)even, even_parity, (double2*)gpu_buf_even);
        collectGhostStapleKernel<0, QUDA_FORWARDS><<<gridDim, blockDim, 0, *stream>>>((double2*)odd, odd_parity, (double2*)gpu_buf_odd);
        break;
      default:
        errorQuda("Invalid whichway");
        break;
      }
      break;

    case 1:
      switch(whichway){
      case QUDA_BACKWARDS:
        collectGhostStapleKernel<1, QUDA_BACKWARDS><<<gridDim, blockDim, 0, *stream>>>((double2*)even, even_parity, (double2*)gpu_buf_even);
        collectGhostStapleKernel<1, QUDA_BACKWARDS><<<gridDim, blockDim, 0, *stream>>>((double2*)odd, odd_parity, (double2*)gpu_buf_odd);
        break;
      case QUDA_FORWARDS:
        collectGhostStapleKernel<1, QUDA_FORWARDS><<<gridDim, blockDim, 0, *stream>>>((double2*)even, even_parity, (double2*)gpu_buf_even);
        collectGhostStapleKernel<1, QUDA_FORWARDS><<<gridDim, blockDim, 0, *stream>>>((double2*)odd, odd_parity, (double2*)gpu_buf_odd);
        break;
      default:
        errorQuda("Invalid whichway");
        break;
      }
      break;

    case 2:
      switch(whichway){
      case QUDA_BACKWARDS:
        collectGhostStapleKernel<2, QUDA_BACKWARDS><<<gridDim, blockDim, 0, *stream>>>((double2*)even, even_parity, (double2*)gpu_buf_even);
        collectGhostStapleKernel<2, QUDA_BACKWARDS><<<gridDim, blockDim, 0, *stream>>>((double2*)odd, odd_parity, (double2*)gpu_buf_odd);
        break;
      case QUDA_FORWARDS:
        collectGhostStapleKernel<2, QUDA_FORWARDS><<<gridDim, blockDim, 0, *stream>>>((double2*)even, even_parity, (double2*)gpu_buf_even);
        collectGhostStapleKernel<2, QUDA_FORWARDS><<<gridDim, blockDim, 0, *stream>>>((double2*)odd, odd_parity, (double2*)gpu_buf_odd);
        break;
      default:
        errorQuda("Invalid whichway");
        break;
      }
      break;

    case 3:
      switch(whichway){
      case QUDA_BACKWARDS:
        collectGhostStapleKernel<3, QUDA_BACKWARDS><<<gridDim, blockDim, 0, *stream>>>((double2*)even, even_parity, (double2*)gpu_buf_even);
        collectGhostStapleKernel<3, QUDA_BACKWARDS><<<gridDim, blockDim, 0, *stream>>>((double2*)odd, odd_parity, (double2*)gpu_buf_odd);
        break;
      case QUDA_FORWARDS:
        collectGhostStapleKernel<3, QUDA_FORWARDS><<<gridDim, blockDim, 0, *stream>>>((double2*)even, even_parity, (double2*)gpu_buf_even);
        collectGhostStapleKernel<3, QUDA_FORWARDS><<<gridDim, blockDim, 0, *stream>>>((double2*)odd, odd_parity, (double2*)gpu_buf_odd);
        break;
      default:
        errorQuda("Invalid whichway");
        break;
      }
      break;
    }
  }else if(precision == QUDA_SINGLE_PRECISION){
    switch(dir){
    case 0:
      switch(whichway){
      case QUDA_BACKWARDS:
        collectGhostStapleKernel<0, QUDA_BACKWARDS><<<gridDim, blockDim, 0, *stream>>>((float2*)even, even_parity, (float2*)gpu_buf_even);
        collectGhostStapleKernel<0, QUDA_BACKWARDS><<<gridDim, blockDim, 0, *stream>>>((float2*)odd, odd_parity, (float2*)gpu_buf_odd);
        break;
      case QUDA_FORWARDS:
        collectGhostStapleKernel<0, QUDA_FORWARDS><<<gridDim, blockDim, 0, *stream>>>((float2*)even, even_parity, (float2*)gpu_buf_even);
        collectGhostStapleKernel<0, QUDA_FORWARDS><<<gridDim, blockDim, 0, *stream>>>((float2*)odd, odd_parity, (float2*)gpu_buf_odd);
        break;
      default:
        errorQuda("Invalid whichway");
        break;
      }
      break;

    case 1:
      switch(whichway){
      case QUDA_BACKWARDS:
        collectGhostStapleKernel<1, QUDA_BACKWARDS><<<gridDim, blockDim, 0, *stream>>>((float2*)even, even_parity, (float2*)gpu_buf_even);
        collectGhostStapleKernel<1, QUDA_BACKWARDS><<<gridDim, blockDim, 0, *stream>>>((float2*)odd, odd_parity, (float2*)gpu_buf_odd);
        break;
      case QUDA_FORWARDS:
        collectGhostStapleKernel<1, QUDA_FORWARDS><<<gridDim, blockDim, 0, *stream>>>((float2*)even, even_parity, (float2*)gpu_buf_even);
        collectGhostStapleKernel<1, QUDA_FORWARDS><<<gridDim, blockDim, 0, *stream>>>((float2*)odd, odd_parity, (float2*)gpu_buf_odd);
        break;
      default:
        errorQuda("Invalid whichway");
        break;
      }
      break;

    case 2:
      switch(whichway){
      case QUDA_BACKWARDS:
        collectGhostStapleKernel<2, QUDA_BACKWARDS><<<gridDim, blockDim, 0, *stream>>>((float2*)even, even_parity, (float2*)gpu_buf_even);
        collectGhostStapleKernel<2, QUDA_BACKWARDS><<<gridDim, blockDim, 0, *stream>>>((float2*)odd, odd_parity, (float2*)gpu_buf_odd);
        break;
      case QUDA_FORWARDS:
        collectGhostStapleKernel<2, QUDA_FORWARDS><<<gridDim, blockDim, 0, *stream>>>((float2*)even, even_parity, (float2*)gpu_buf_even);
        collectGhostStapleKernel<2, QUDA_FORWARDS><<<gridDim, blockDim, 0, *stream>>>((float2*)odd, odd_parity, (float2*)gpu_buf_odd);
        break;
      default:
        errorQuda("Invalid whichway");
        break;
      }
      break;

    case 3:
      switch(whichway){
      case QUDA_BACKWARDS:
        collectGhostStapleKernel<3, QUDA_BACKWARDS><<<gridDim, blockDim, 0, *stream>>>((float2*)even, even_parity, (float2*)gpu_buf_even);
        collectGhostStapleKernel<3, QUDA_BACKWARDS><<<gridDim, blockDim, 0, *stream>>>((float2*)odd, odd_parity, (float2*)gpu_buf_odd);
        break;
      case QUDA_FORWARDS:
        collectGhostStapleKernel<3, QUDA_FORWARDS><<<gridDim, blockDim, 0, *stream>>>((float2*)even, even_parity, (float2*)gpu_buf_even);
        collectGhostStapleKernel<3, QUDA_FORWARDS><<<gridDim, blockDim, 0, *stream>>>((float2*)odd, odd_parity, (float2*)gpu_buf_odd);
        break;
      default:
        errorQuda("Invalid whichway");
        break;
      }
      break;
    }
  }else{
    printf("ERROR: invalid precision for %s\n", __FUNCTION__);
    exit(1);
  }

}
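
/* Illustrative usage sketch (the buffer names are assumptions): collect the
   forward T-direction ghost staples of both parities into one contiguous
   device buffer, e.g. before exchanging them with a neighboring node. */
#if 0
void example_collect_ghost_staple(int* X, void* staple_even, void* staple_odd,
                                  int volume, void* ghost_buf,
                                  cudaStream_t* stream)
{
  collectGhostStaple(X, staple_even, staple_odd, volume, QUDA_DOUBLE_PRECISION,
                     ghost_buf, 3 /*T*/, QUDA_FORWARDS, stream);
}
#endif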


#undef gaugeSiteSize
#undef BLOCKSIZE