3 #include <quda_internal.h>
/**
 * Argument bundle handed by value to the shift kernels.
 *
 * Visible members:
 *   length      - upper bound of the kernels' site loop (idx < arg.length)
 *   ghostOffset - added to the ghost index in the exterior kernel
 *   parity      - forwarded to the index helpers
 *   dir         - lattice direction; the interior kernel uses it to index its shift array
 *
 * NOTE(review): the constructor also initialises X, partitioned, shift, in and
 * out, and its initialiser list reads `shift` and `in`, but neither their
 * declarations nor the matching constructor parameters are visible in this
 * listing (the embedded line numbering has gaps) - confirm against the full file.
 */
7 template<typename Output, typename Input>
8 struct ShiftColorSpinorFieldArg {
9 const unsigned int length;
// Fixed: the type was misspelled "usigned", which is not a valid type name.
12 const unsigned int ghostOffset; // depends on the direction
14 const unsigned int parity;
15 const unsigned int dir;
// Copies the scalar parameters into the members, then fills the per-direction arrays.
20 ShiftColorSpinorFieldArg(const unsigned int length,
21 const unsigned int X[4],
22 const unsigned int ghostOffset,
23 const unsigned int parity,
24 const unsigned int dir,
27 const Output& out) : length(length),
29 ghostOffset(ghostOffset),
31 parity(parity), dir(dir), shift(shift), in(in), out(out)
// Copy the four lattice extents into the struct.
33 for(int i=0; i<4; ++i) this->X[i] = X[i];
// Record, per direction, whether commDimPartitioned(i) reports it as partitioned.
34 for(int i=0; i<4; ++i) partitioned[i] = commDimPartitioned(i) ? true : false;
/**
 * Returns the index of the site reached from checkerboard index `cb_idx`
 * by displacing its coordinates by `shift`.
 *
 * The coordinates (x, y, z, t) are obtained from coordsFromIndex. If a shifted
 * coordinate falls outside [0, X1), [0, X2), [0, X3) or [0, X4) the function
 * returns -1. Otherwise each shifted coordinate is wrapped modulo its extent
 * and the lexicographic index is returned halved (>> 1).
 *
 * NOTE(review): `partitioned` is not used in any visible line. The listing
 * skips a line before each range check, so the checks are presumably guarded
 * by partitioned[i] - confirm against the full file.
 * NOTE(review): the declarations of full_idx, x, y, z, t are not visible here.
 */
38 template<IndexType idxType, typename Int>
39 __device__ __forceinline__
40 int neighborIndex(const unsigned int& cb_idx, const int (&shift)[4], const bool (&partitioned)[4], const unsigned int& parity){
45 coordsFromIndex(full_idx, x, y, z, t, cb_idx, parity);
49 if( (x+shift[0])<0 || (x+shift[0])>=X1) return -1;
51 if( (y+shift[1])<0 || (y+shift[1])>=X2) return -1;
53 if( (z+shift[2])<0 || (z+shift[2])>=X3) return -1;
// Fixed: direction 3 must test the t coordinate against X4; the check used z.
55 if( (t+shift[3])<0 || (t+shift[3])>=X4) return -1;
// Adding the extent before the modulo keeps the operand non-negative for negative shifts.
58 x = shift[0] ? (x + shift[0] + X1) % X1 : x;
59 y = shift[1] ? (y + shift[1] + X2) % X2 : y;
60 z = shift[2] ? (z + shift[2] + X3) % X3 : z;
61 t = shift[3] ? (t + shift[3] + X4) % X4 : t;
62 return (((t*X3 + z)*X2 + y)*X1 + x) >> 1;
/**
 * Interior shift kernel: for each site index below arg.length, looks up the
 * neighbour index for the requested shift and loads the input field there.
 *
 * NOTE(review): the declaration of the buffer `x`, the store to arg.out and the
 * update of `idx` are not visible in this listing - confirm against the full file.
 * NOTE(review): neighborIndex is declared with template parameters
 * <IndexType idxType, typename Int> that do not appear in its function
 * parameters, so this call without explicit template arguments looks
 * non-deducible - confirm.
 */
65 template <typename FloatN, int N, typename Arg>
66 __global__ void shiftColorSpinorFieldKernel(Arg arg)
// Shift vector: zero in every direction except arg.dir.
68 int shift[4] = {0,0,0,0};
69 shift[arg.dir] = arg.shift;
// Flat global thread index and total number of launched threads.
71 unsigned int idx = blockIdx.x*(blockDim.x) + threadIdx.x;
72 unsigned int gridSize = gridDim.x*blockDim.x;
75 while(idx<arg.length){
// neighborIndex can return -1, so the result is kept signed.
76 const int new_idx = neighborIndex(idx, shift, arg.partitioned, arg.parity);
80 arg.in.load(x, new_idx);
/**
 * Exterior shift kernel: for each index below arg.length, converts the index to
 * coordinates, derives a ghost index offset by arg.ghostOffset, loads the input
 * at that ghost index and saves it to the output at `idx`.
 *
 * NOTE(review): the declaration of the buffer `x` and the update of `idx` are
 * not visible in this listing - confirm against the full file.
 */
89 template<typename FloatN, int N, typename Arg>
90 __global__ void shiftColorSpinorFieldExternalKernel(Arg arg)
// Flat global thread index and total number of launched threads.
92 unsigned int idx = blockIdx.x*(blockDim.x) + threadIdx.x;
93 unsigned int gridSize = gridDim.x*blockDim.x;
96 unsigned int coord[4];
97 while(idx<arg.length){
99 // compute the coordinates in the ghost zone
100 coordsFromIndex<1>(coord, idx, arg.X, arg.dir, arg.parity);
// Position in the input buffer: ghost-zone base offset plus the index derived from coord.
102 unsigned int ghost_idx = arg.ghostOffset + ghostIndexFromCoords<3,3>(arg.X, coord, arg.dir, arg.shift);
104 arg.in.load(x, ghost_idx);
105 arg.out.save(x, idx);
/**
 * Tunable launcher for the spinor-field shift kernel.
 *
 * Holds the kernel argument struct, provides the block/grid search used by the
 * tuner (the block size advances in warp-sized steps and the grid is derived
 * from it), and launches the interior kernel from apply().
 *
 * NOTE(review): the `location` member initialised by the constructor and the
 * access specifiers are not visible in this listing - confirm against the full file.
 */
111 template<typename Output, typename Input>
112 class ShiftColorSpinorField : public Tunable {
113 ShiftColorSpinorFieldArg<Output,Input> arg;
114 const int *X; // pointer to lattice dimensions
// The kernels request no shared memory.
116 int sharedBytesPerThread() const { return 0; }
// Fixed: the qualifier was misspelled "cont".
117 int sharedBytesPerBlock(const TuneParam &) const { return 0; }
119 // don't tune the grid dimension
120 bool advanceGridDim(TuneParam & param) const { return false; }
// Grows the block by one warp. When the block exceeds the thread or shared-memory
// limit it is reset to the smallest warp multiple that keeps the grid within max_blocks.
// Fixed: the parameter had been mangled to a pilcrow sequence; it is `&param`.
// NOTE(review): the lines producing the return value are not visible in this listing.
122 bool advanceBlockDim(TuneParam &param) const
124 const unsigned int max_threads = deviceProp.maxThreadsDim[0];
125 const unsigned int max_blocks = deviceProp.maxGridSize[0];
126 const unsigned int max_shared = 16384;
127 const int step = deviceProp.warpSize;
128 const int threads = arg.length;
131 param.block.x += step;
132 if(param.block.x > max_threads || sharedBytesPerThread()*param.block.x > max_shared){
133 param.block = dim3((threads+max_blocks-1)/max_blocks, 1, 1); // ensure the blockDim is large enough given the limit on gridDim
134 param.block.x = ((param.block.x+step-1)/step)*step;
135 if(param.block.x > max_threads) errorQuda("Local lattice volume is too large for device");
// Ceil-divide so that grid * block covers every one of `threads` sites.
140 param.grid = dim3((threads+param.block.x-1)/param.block.x,1,1);
// Fixed: the first parameter was typed as ShiftColorSpinorField<Output,Input>, which
// cannot initialise the ShiftColorSpinorFieldArg<Output,Input> member `arg`.
146 ShiftColorSpinorField(const ShiftColorSpinorFieldArg<Output,Input> &arg,
147 QudaFieldLocation location)
148 : arg(arg), location(location) {}
149 virtual ~ShiftColorSpinorField() {}
// Launches the interior kernel on `stream` when the field lives on the device; any
// other location is reported through errorQuda.
// NOTE(review): the kernel template is <FloatN, N, Arg> but only one template
// argument is supplied here - confirm the intended instantiation.
151 void apply(const qudaStream_t &stream){
152 if(location == QUDA_CUDA_FIELD_LOCATION){
153 TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
154 qudaLaunchKernel(shiftColorSpinorFieldKernel<decltype(arg)>, tp, stream, arg);
156 // Need to perform some communication and call exterior kernel, I guess
158 }else{ // run the CPU code
159 errorQuda("ShiftColorSpinorField is not yet implemented on the CPU\n");
// Initial launch configuration: the smallest warp-multiple block that keeps the grid
// within the device grid limit, with the grid sized to cover arg.length sites.
163 virtual void initTuneParam(TuneParam &param) const
165 const unsigned int max_threads = deviceProp.maxThreadsDim[0];
166 const unsigned int max_blocks = deviceProp.maxGridSize[0];
167 const int threads = arg.length;
168 const int step = deviceProp.warpSize;
169 param.block = dim3((threads+max_blocks-1)/max_blocks, 1, 1); // ensure the blockDim is large enough, given the limit on gridDim
170 param.block.x = ((param.block.x+step-1) / step) * step; // round up to the nearest "step"
171 if (param.block.x > max_threads) errorQuda("Local lattice volume is too large for device");
172 param.grid = dim3((threads+param.block.x-1)/param.block.x, 1, 1);
// Shared memory request is the larger of the per-thread total and the per-block amount.
173 param.shared_bytes = sharedBytesPerThread()*param.block.x > sharedBytesPerBlock(param) ?
174 sharedBytesPerThread()*param.block.x : sharedBytesPerBlock(param);
177 /** sets default values for when tuning is disabled */
178 void defaultTuneParam(TuneParam &param) const {
179 initTuneParam(param);
182 long long flops() const { return 0; } // fixme
183 long long bytes() const { return 0; } // fixme
// Builds the tuning-cache key from the volume string, this class's type name and an
// auxiliary string carrying the thread count, precision and stride.
// NOTE(review): the lines that fill `vol` are not visible in this listing. No separator
// is written between the "prec=" value and "stride="; left unchanged because editing
// the key text would change existing cache keys - confirm whether that is intended.
185 TuneKey tuneKey() const {
186 std::stringstream vol, aux;
191 aux << "threads=" << 2*arg.in.volumeCB << ",prec=" << sizeof(Complex)/2;
192 aux << "stride=" << arg.in.stride;
193 return TuneKey(vol.str(), typeid(*this).name(), aux.str());
198 // Should really have a parity
/**
 * Host entry point: writes into `dst` the field `src` shifted by `shift` sites along
 * dimension `dim`, for the given `parity`.
 *
 * Rejects unsupported spin counts and mismatched site subsets. A full-site field is
 * handled by recursing on its even and odd halves. Only nSpin == 1 with matching
 * double or single precision reaches the kernel path; anything else is reported
 * through errorQuda.
 *
 * Fixed: the launcher was built with QUDA_CPU_FIELD_LOCATION, but its apply() only
 * launches the kernel for QUDA_CUDA_FIELD_LOCATION and otherwise calls errorQuda, so
 * the device kernel could never run. Both fields are cudaColorSpinorField.
 *
 * NOTE(review): several conditions and declarations (the test before the first
 * errorQuda, the branch selecting between the two recursion pairs, `dagger`, `face`,
 * `streams`, the events) are not visible in this listing - confirm against the full file.
 */
199 void shiftColorSpinorField(cudaColorSpinorField &dst, const cudaColorSpinorField &src, const unsigned int parity, const unsigned int dim, const int shift) {
202 errorQuda("destination field is the same as source field\n");
206 if(src.Nspin() != 1 && src.Nspin() !=4) errorQuda("nSpin(%d) not supported\n", src.Nspin());
208 if(src.SiteSubset() != dst.SiteSubset())
209 errorQuda("Spinor fields do not have matching subsets\n");
211 if(src.SiteSubset() == QUDA_FULL_SITE_SUBSET){
// First pair swaps parities between source and destination; second pair keeps them.
// NOTE(review): presumably selected by whether the shift is odd or even - confirm.
213 shiftColorSpinorField(dst.Even(), src.Odd(), 0, dim, shift);
214 shiftColorSpinorField(dst.Odd(), src.Even(), 1, dim, shift);
216 shiftColorSpinorField(dst.Even(), src.Even(), 0, dim, shift);
217 shiftColorSpinorField(dst.Odd(), src.Odd(), 1, dim, shift);
223 const int dir = (shift>0) ? QUDA_BACKWARDS : QUDA_FORWARDS; // pack the start of the field if shift is positive
224 const int offset = (shift>0) ? 0 : 1;
228 if(dst.Precision() == QUDA_DOUBLE_PRECISION && src.Precision() == QUDA_DOUBLE_PRECISION){
229 if(src.Nspin() == 1){
230 Spinor<double2, double2, double2, 3, 0, 0> src_tex(src);
231 Spinor<double2, double2, double2, 3, 1> dst_spinor(dst);
// NOTE(review): these six arguments do not line up with the visible constructor of
// ShiftColorSpinorFieldArg (length, X, ghostOffset, parity, dir, ..., out) - confirm.
232 ShiftColorSpinorFieldArg arg(src.Volume(), parity, dim, shift, dst_spinor, src_tex);
233 ShiftColorSpinorField shiftColorSpinor(arg, QUDA_CUDA_FIELD_LOCATION);
// Partitioned dimensions other than 3: pack the face and start copying it to the host.
236 if(commDimPartitioned(dim) && dim!=3){
237 face->pack(src, 1-parity, dagger, dim, dir, streams); // pack in stream[1]
238 qudaEventRecord(packEnd, streams[1]);
239 qudaStreamWaitEvent(streams[1], packEnd, 0); // wait for pack to end in stream[1]
240 face->gather(src, dagger, 2*dim+offset, 1); // copy packed data from device buffer to host and do this in stream[1]
241 qudaEventRecord(gatherEnd, streams[1]); // record the completion of face->gather
// NOTE(review): apply() takes a const qudaStream_t&, but integer literals are passed
// here and further down - confirm the intended stream arguments.
245 shiftColorSpinor.apply(0); // shift the field in the interior region
248 if(commDimPartitioned(dim) && dim!=3){
// Start communication only once the gather event reports completion.
250 cudaError_t eventQuery = cudaEventQuery(gatherEnd);
251 if(eventQuery == cudaSuccess){
252 face->commsStart(2*dim + offset); // if argument is even, send backwards, else send forwards
257 // after communication, load data back on to device
258 // do this in stream[1]
260 if(face->commsQuery(2*dim + offset)){
261 face->scatter(src, dagger, 2*dim+offset, 1);
265 qudaEventRecord(scatterEnd, streams[1]);
266 qudaStreamWaitEvent(streams[1], scatterEnd, 0);
267 shiftColorSpinor.apply(1);
272 errorQuda("Only staggered fermions are currently supported\n");
274 }else if(dst.Precision() == QUDA_SINGLE_PRECISION && src.Precision() == QUDA_SINGLE_PRECISION){
275 if(src.Nspin() == 1 ){
276 Spinor<float2, float2, float2, 3, 0, 0> src_tex(src);
277 Spinor<float2, float2, float2, 3, 1> dst_spinor(dst);
278 ShiftColorSpinorFieldArg arg(src.Volume(), parity, dim, shift, dst_spinor, src_tex);
279 ShiftColorSpinorField shiftColorSpinor(arg, QUDA_CUDA_FIELD_LOCATION);
281 errorQuda("Only staggered fermions are currently supported\n");