// Selects which normalisation constants kP / kU the spin-basis change
// functors further down use when they fill their K1/K2 coefficient tables.
// As written, PRESERVE_SPINOR_NORM is defined unconditionally, so the
// norm-preserving branch is the one taken.
35 #define PRESERVE_SPINOR_NORM
37 #ifdef PRESERVE_SPINOR_NORM // Preserve the norm regardless of basis
// Both constants equal 1/sqrt(2). They are double-precision expressions;
// they are converted to FloatOut where the K1/K2 arrays are initialised.
38 #define kP (1.0/sqrt(2.0))
39 #define kU (1.0/sqrt(2.0))
40 #else // More numerically accurate not to preserve the norm between basis
// NOTE(review): the definitions for this #else branch and the matching
// #endif are not visible in this chunk.
// Accessor template for a field stored in N-wide chunks. The Ns*Nc*2 reals
// of a site are split into chunks of N values. Chunk k of site x starts at
// Float offset (k*stride + x)*N (see the load/save bodies that follow).
// NOTE(review): the struct declaration line and the member declarations are
// outside the visible chunk. The constructor below copies its arguments
// into the same-named members field, volume and stride.
45 template <
typename Float,
int Ns,
int Nc,
int N>
51 : field(field), volume(volume), stride(stride) { ; }
// load (body only; the signature is outside the visible chunk).
// Gathers the Ns*Nc*2 reals of site x into v, which is ordered
// (s*Nc + c)*2 + z. Real number k (= internal_idx) of the site lives in
// chunk pad_idx = k / N at offset k % N. Successive chunks of one site are
// `stride` sites apart in field.
// NOTE(review): no bounds check on x is visible in this body.
55 for (
int s=0;
s<Ns;
s++) {
56 for (
int c=0; c<Nc; c++) {
57 for (
int z=0; z<2; z++) {
58 int internal_idx = (
s*Nc + c)*2 + z;
// Index of the N-wide chunk this real belongs to.
59 int pad_idx = internal_idx / N;
60 v[(
s*Nc+c)*2+z] =
field[(pad_idx *
stride + x)*N + internal_idx % N];
// save (body only; the signature is outside the visible chunk).
// Mirror of the chunked load: scatters v, ordered (s*Nc + c)*2 + z, into
// field. Real k goes to chunk k / N, offset k % N, at
// field[((k / N)*stride + x)*N + k % N].
// NOTE(review): no bounds check on x is visible in this body.
67 for (
int s=0;
s<Ns;
s++) {
68 for (
int c=0; c<Nc; c++) {
69 for (
int z=0; z<2; z++) {
70 int internal_idx = (
s*Nc + c)*2 + z;
// Index of the N-wide chunk this real belongs to.
71 int pad_idx = internal_idx / N;
72 field[(pad_idx *
stride +
x)*N + internal_idx % N] = v[(
s*Nc+c)*2+z];
// Vectorised load (body only; the enclosing declaration is outside the
// visible chunk). Reads the 4*3*2 = 24 reals of site x as six float4
// values. The j-th float4 (j = i/4) comes from ((float4*)field)[j*stride + x]
// and is unpacked into v[i..i+3].
// NOTE(review): presumably a specialisation of the chunked accessor for
// Float=float with 4-wide chunks. The float4 reinterpretation requires
// field to be 16-byte aligned.
85 for (
int i=0; i<4*3*2; i+=4) {
86 float4
tmp = ((float4*)field)[i/4 *
stride +
x];
87 v[i] = tmp.x; v[i+1] = tmp.y; v[i+2] = tmp.z; v[i+3] = tmp.w;
// Vectorised save (fragment). Packs four consecutive entries v[i..i+3]
// into a float4 for each of the 4*3*2 = 24 reals of the site.
// NOTE(review): the statement that stores tmp into field is not visible
// in this chunk.
94 for (
int i=0; i<4*3*2; i+=4) {
95 float4
tmp = make_float4(v[i], v[i+1], v[i+2], v[i+3]);
// Accessor template for an unpadded field whose per-site layout is
// colour-outer, spin-inner. Real z of (s, c) at site x is
// field[((x*Nc + c)*Ns + s)*2 + z] (see the load/save bodies that follow).
// The constructor rejects padded fields: it calls errorQuda when
// stride != volume.
// NOTE(review): the struct declaration and members are outside the visible chunk.
100 template <
typename Float,
int Ns,
int Nc>
106 : field(field), volume(volume), stride(stride)
107 {
if (volume != stride)
errorQuda(
"Stride must equal volume for this field order"); }
// load (body only; the signature is outside the visible chunk).
// Copies site x from the colour-outer/spin-inner layout
// field[((x*Nc + c)*Ns + s)*2 + z] into v, whose order is spin-outer/colour-inner:
// (s*Nc + c)*2 + z.
111 for (
int s=0;
s<Ns;
s++) {
112 for (
int c=0; c<Nc; c++) {
113 for (
int z=0; z<2; z++) {
114 v[(
s*Nc+c)*2+z] =
field[((x*Nc + c)*Ns +
s)*2 + z];
// save (body only; the signature is outside the visible chunk).
// Inverse of the colour-outer load: writes v, ordered (s*Nc + c)*2 + z,
// to field[((x*Nc + c)*Ns + s)*2 + z].
121 for (
int s=0;
s<Ns;
s++) {
122 for (
int c=0; c<Nc; c++) {
123 for (
int z=0; z<2; z++) {
124 field[((x*Nc + c)*Ns +
s)*2 + z] = v[(
s*Nc+c)*2+z];
// Shared-memory staged load, presumably load_shared. Its signature and the
// declarations of s_data, x0 and i are outside the visible chunk.
// Phase 1 (while loop): the block cooperatively copies its contiguous run of
//   block_dim sites from field into s_data. Each site has
//   vec_length = Ns*Nc*2 reals and the run starts at field[x0*vec_length];
//   x0 is presumably the block's first site. The copy is transposed: real
//   internal_idx of local site space_idx is stored at
//   internal_idx*(blockDim.x+1) + space_idx. The +1 row length is presumably
//   padding to avoid shared-memory bank conflicts.
// Phase 2 (nested loops): each thread reads column tid back into v. This
//   converts the in-site order (c*Ns + s)*2 + z to v's order (s*Nc + c)*2 + z.
// NOTE(review): threads read entries written by other threads, so a
// __syncthreads() must separate the two phases. The lines where it would
// sit (and the increment of i) are not visible here.
// NOTE(review): in the last block, threads with tid >= block_dim read
// s_data entries that phase 1 did not write. The caller must discard v for them.
133 template <
typename Float,
int Ns,
int Nc>
135 const int tid = threadIdx.x;
136 const int vec_length = Ns*Nc*2;
// Number of sites this block actually covers; the last block may be partial.
139 const int block_dim = (blockIdx.x == gridDim.x-1) ?
140 volume - (gridDim.x-1)*blockDim.x : blockDim.x;
146 while (i<vec_length*block_dim) {
147 int space_idx = i / vec_length;
148 int internal_idx = i - space_idx*vec_length;
149 int sh_idx = internal_idx*(blockDim.x+1) + space_idx;
150 s_data[sh_idx] = field[x0*vec_length + i];
157 for (
int s=0;
s<Ns;
s++)
159 for (
int c=0; c<Nc; c++)
161 for (
int z=0; z<2; z++) {
162 int sh_idx = ((c*Ns+
s)*2+z)*(blockDim.x+1) + tid;
163 v[(
s*Nc + c)*2 + z] = s_data[sh_idx];
// Shared-memory staged save, presumably save_shared. Its signature and the
// declarations of s_data, x0 and i are outside the visible chunk.
// Phase 1 (nested loops): each thread writes its v, ordered
//   (s*Nc + c)*2 + z, into column tid of s_data. It uses the in-site order
//   (c*Ns + s)*2 + z with rows of length blockDim.x+1.
// Phase 2 (while loop): the block cooperatively copies block_dim sites from
//   s_data to the contiguous output run field[x0*vec_length + i].
// NOTE(review): a __syncthreads() must separate the phases; those lines are
// not visible. If present, every thread of the block must reach it,
// including threads past the end of the volume in the last block.
168 template <
typename Float,
int Ns,
int Nc>
170 const int tid = threadIdx.x;
171 const int vec_length = Ns*Nc*2;
// Number of sites this block actually covers; the last block may be partial.
174 const int block_dim = (blockIdx.x == gridDim.x-1) ?
175 volume - (gridDim.x-1)*blockDim.x : blockDim.x;
180 for (
int s=0;
s<Ns;
s++)
182 for (
int c=0; c<Nc; c++)
184 for (
int z=0; z<2; z++) {
185 int sh_idx = ((c*Ns+
s)*2+z)*(blockDim.x+1) + tid;
186 s_data[sh_idx] = v[(
s*Nc + c)*2 + z];
193 while (i<vec_length*block_dim) {
194 int space_idx = i / vec_length;
195 int internal_idx = i - space_idx*vec_length;
196 int sh_idx = internal_idx*(blockDim.x+1) + space_idx;
197 field[x0*vec_length + i] = s_data[sh_idx];
// load (body fragment; the signature is outside the visible chunk).
// Two paths are visible:
//   1) a shared-memory staged load via load_shared<float, 4, 3>;
//   2) a plain per-element gather from ((x*Nc + c)*Ns + s)*2 + z.
// Both read the same colour-outer/spin-inner site layout. The lines that
// choose between them are not visible.
// NOTE(review): path 1 hard-codes <float, 4, 3> rather than
// <Float, Ns, Nc>. Confirm it is only compiled for that combination.
206 load_shared<float, 4, 3>(v, field,
x, volume);
210 for (
int s=0;
s<Ns;
s++) {
211 for (
int c=0; c<Nc; c++) {
212 for (
int z=0; z<2; z++) {
213 v[(
s*Nc+c)*2+z] = field[((x*Nc + c)*Ns +
s)*2 + z];
// save (body fragment; the signature is outside the visible chunk).
// Two paths are visible:
//   1) a shared-memory staged save via save_shared<float, 4, 3>;
//   2) a plain per-element scatter to ((x*Nc + c)*Ns + s)*2 + z.
// The lines that choose between them are not visible.
// NOTE(review): path 1 hard-codes <float, 4, 3> rather than
// <Float, Ns, Nc>. Confirm it is only compiled for that combination.
223 save_shared<float, 4, 3>(field, v,
x, volume);
227 for (
int s=0;
s<Ns;
s++) {
228 for (
int c=0; c<Nc; c++) {
229 for (
int z=0; z<2; z++) {
230 field[((x*Nc + c)*Ns +
s)*2 + z] = v[(
s*Nc+c)*2+z];
// Accessor template for an unpadded field whose per-site layout is
// spin-outer, colour-inner. Real z of (s, c) at site x is
// field[((x*Ns + s)*Nc + c)*2 + z], which is the same ordering as v.
// The constructor calls errorQuda when stride != volume.
// NOTE(review): the struct declaration and members are outside the visible chunk.
237 template <
typename Float,
int Ns,
int Nc>
243 : field(field), volume(volume), stride(stride)
244 {
if (volume != stride)
errorQuda(
"Stride must equal volume for this field order"); }
// load (body only; the signature is outside the visible chunk).
// The in-site order ((x*Ns + s)*Nc + c)*2 + z matches v's order
// (s*Nc + c)*2 + z. This is therefore a straight copy of the
// Ns*Nc*2 contiguous reals of site x.
248 for (
int s=0;
s<Ns;
s++) {
249 for (
int c=0; c<Nc; c++) {
250 for (
int z=0; z<2; z++) {
251 v[(
s*Nc+c)*2+z] =
field[((x*Ns +
s)*Nc + c)*2 + z];
// save (body only; the signature is outside the visible chunk).
// Straight copy of v into the Ns*Nc*2 contiguous reals of site x, using the
// same spin-outer/colour-inner ordering.
258 for (
int s=0;
s<Ns;
s++) {
259 for (
int c=0; c<Nc; c++) {
260 for (
int z=0; z<2; z++) {
261 field[((x*Ns +
s)*Nc + c)*2 + z] = v[(
s*Nc+c)*2+z];
// Identity basis functor. operator() copies in to out element by element,
// converting FloatIn to FloatOut, with no spin mixing.
// NOTE(review): the struct line is not visible. This is presumably
// PreserveBasis, which packParitySpinor uses when dstBasis == srcBasis.
271 template <
typename FloatOut,
typename FloatIn,
int Ns,
int Nc>
274 __device__ __host__
inline void operator()(FloatOut
out[Ns*Nc*2],
const FloatIn
in[Ns*Nc*2]) {
275 for (
int s=0;
s<Ns;
s++) {
276 for (
int c=0; c<Nc; c++) {
277 for (
int z=0; z<2; z++) {
278 out[(
s*Nc+c)*2+z] =
in[(
s*Nc+c)*2+z];
// Spin-basis change functor built from kP. For every colour c and
// component z, out spin s = K1[s]*in spin s1[s] + K2[s]*in spin s2[s]:
//   out0 =  kP*in1 + kP*in3
//   out1 = -kP*in2 - kP*in0
//   out2 = -kP*in3 + kP*in1
//   out3 = -kP*in0 + kP*in2
// The 4-entry tables require Ns <= 4. packParitySpinor calls errorQuda
// for Ns != 4 before using a basis change.
// NOTE(review): the struct line is not visible; presumably NonRelBasis.
286 template <
typename FloatOut,
typename FloatIn,
int Ns,
int Nc>
288 __device__ __host__
inline void operator()(FloatOut
out[Ns*Nc*2],
const FloatIn
in[Ns*Nc*2]) {
// Source spins combined into each output spin.
289 int s1[4] = {1, 2, 3, 0};
290 int s2[4] = {3, 0, 1, 2};
// Coefficients (kP is a double expression converted to FloatOut here).
291 FloatOut K1[4] = {
kP, -
kP, -
kP, -kP};
292 FloatOut K2[4] = {
kP, -
kP,
kP, kP};
293 for (
int s=0;
s<Ns;
s++) {
294 for (
int c=0; c<Nc; c++) {
295 for (
int z=0; z<2; z++) {
296 out[(
s*Nc+c)*2+z] = K1[
s]*
in[(s1[
s]*Nc+c)*2+z] + K2[
s]*
in[(s2[
s]*Nc+c)*2+z];
// Spin-basis change functor built from kU. For every colour c and
// component z, out spin s = K1[s]*in spin s1[s] + K2[s]*in spin s2[s]:
//   out0 = -kU*in1 - kU*in3
//   out1 =  kU*in2 + kU*in0
//   out2 =  kU*in3 - kU*in1
//   out3 =  kU*in0 - kU*in2
// The 4-entry tables require Ns <= 4. packParitySpinor calls errorQuda
// for Ns != 4 before using a basis change.
// NOTE(review): the struct line is not visible; presumably RelBasis.
304 template <
typename FloatOut,
typename FloatIn,
int Ns,
int Nc>
306 __device__ __host__
inline void operator()(FloatOut
out[Ns*Nc*2],
const FloatIn
in[Ns*Nc*2]) {
// Source spins combined into each output spin.
307 int s1[4] = {1, 2, 3, 0};
308 int s2[4] = {3, 0, 1, 2};
// Coefficients (kU is a double expression converted to FloatOut here).
309 FloatOut K1[4] = {-
kU,
kU,
kU, kU};
310 FloatOut K2[4] = {-
kU,
kU, -
kU, -kU};
311 for (
int s=0;
s<Ns;
s++) {
312 for (
int c=0; c<Nc; c++) {
313 for (
int z=0; z<2; z++) {
314 out[(
s*Nc+c)*2+z] = K1[
s]*
in[(s1[
s]*Nc+c)*2+z] + K2[
s]*
in[(s2[
s]*Nc+c)*2+z];
// Host-side reference path: serially visits every site x in [0, volume).
// For each site it loads from inOrder into `in`, then stores `out` through
// outOrder.
// NOTE(review): the declaration of `in` and the statement that fills `out`
// (presumably basis(out, in)) are not visible in this chunk.
322 template <
typename FloatOut,
typename FloatIn,
int Ns,
int Nc,
typename OutOrder,
typename InOrder,
typename Basis>
323 void packSpinor(OutOrder &outOrder,
const InOrder &inOrder, Basis basis,
int volume) {
324 for (
int x=0;
x<volume;
x++) {
326 FloatOut
out[Ns*Nc*2];
327 inOrder.load(in,
x, volume);
329 outOrder.save(out,
x, volume);
// Device kernel: one thread per site on a 1-D grid (x = global thread index).
// Thread x loads site x through inOrder and stores the converted site
// through outOrder. The PackSpinor tunable launches it with tp.shared_bytes
// of dynamic shared memory.
// NOTE(review): the declaration of `in` and the basis application
// (presumably basis(out, in)) are not visible in this chunk.
// NOTE(review): the early return sits between load() and save().
//   - Threads with x >= volume still call load(). The shared-memory load
//     handles a partial last block through block_dim, but the plain
//     accessors show no bounds check, so they may read out of range.
//   - Those same threads skip save(). If save() is the shared-memory variant
//     and contains __syncthreads(), part of the last block skips the barrier,
//     which is undefined behaviour.
// Confirm which accessors reach this kernel.
334 template <
typename FloatOut,
typename FloatIn,
int Ns,
int Nc,
typename OutOrder,
typename InOrder,
typename Basis>
335 __global__
void packSpinorKernel(OutOrder outOrder,
const InOrder inOrder, Basis basis,
int volume) {
336 int x = blockIdx.x * blockDim.x + threadIdx.x;
339 FloatOut
out[Ns*Nc*2];
340 inOrder.load(in, x, volume);
342 if (x >= volume)
return;
344 outOrder.save(out, x, volume);
// Template header of the PackSpinor tunable. The class declaration and its
// data members are outside the visible chunk.
347 template <
typename FloatOut,
typename FloatIn,
int Ns,
int Nc,
typename OutOrder,
typename InOrder,
typename Basis>
// Bytes needed to stage one site: Ns*Nc*2 reals in the wider of the output
// and input precisions.
// NOTE(review): the closing brace of this method is not visible.
355 int sharedBytesPerThread()
const {
356 size_t regSize =
sizeof(FloatOut) >
sizeof(FloatIn) ?
sizeof(FloatOut) :
sizeof(FloatIn);
357 return Ns*Nc*2*regSize;
// Dynamic shared memory needed by one block: (block.x + 1) sites' worth of
// staging space. This matches the (blockDim.x + 1)-strided rows used by the
// shared-memory load/save helpers.
int sharedBytesPerBlock(const TuneParam &param) const
{
  const unsigned int stagedSites = param.block.x + 1;
  return stagedSites * sharedBytesPerThread();
}
// Dynamic shared memory is tied to the block size (advanceBlockDim sets it
// from param.block.x), so it is never tuned as an independent dimension.
bool advanceSharedBytes(TuneParam &param) const
{
  return false;
}
// The grid is always recomputed from `volume` and the block size
// (ceil-divide), so it is never advanced independently.
bool advanceGridDim(TuneParam &param) const
{
  return false;
}
// Step to the next block size offered by the Tunable base class.
// When the block size changes, the grid is recomputed (ceil-divide) so it
// still covers all `volume` sites. The dynamic shared memory is then resized
// for the current block using sharedBytesPerBlock(): (block.x + 1) sites of
// staging space.
// Returns the result of Tunable::advanceBlockDim.
bool advanceBlockDim(TuneParam &param) const
{
  const bool advance = Tunable::advanceBlockDim(param);
  if (advance) param.grid = dim3((volume + param.block.x - 1) / param.block.x, 1, 1);
  // Single source of truth for the shared-memory size.
  param.shared_bytes = sharedBytesPerBlock(param);
  return advance;
}
// Stores the output/input accessors, the basis functor and the site count
// in the same-named members. No other work happens at construction.
PackSpinor(OutOrder &out, const InOrder &in, Basis &basis, int volume)
  : out(out), in(in), basis(basis), volume(volume)
{
}
// Launch the pack kernel with the tuned grid, block and dynamic shared
// memory size on `stream`. tp and stream come from lines not visible here.
// NOTE(review): no cudaGetLastError() check is visible after this launch.
378 packSpinorKernel<FloatOut, FloatIn, Ns, Nc, OutOrder, InOrder, Basis>
379 <<<tp.grid, tp.block, tp.shared_bytes, stream>>>
380 (out, in, basis, volume);
// Tuning cache key: vol.str(), the instantiation's typeid name, and the
// output/input strides.
// NOTE(review): vol is not written in the visible lines. Confirm the
// missing line fills it (presumably with the volume).
384 std::stringstream vol, aux;
386 aux <<
"out_stride=" << out.stride <<
",in_stride=" << in.stride;
387 return TuneKey(vol.str(),
typeid(*this).name(), aux.str());
// Human-readable description of a tuning point: block dimensions and
// dynamic shared bytes. The grid is not printed; this class recomputes it
// from volume and block.x.
391 std::stringstream ps;
392 ps <<
"block=(" << param.block.x <<
"," << param.block.y <<
"," << param.block.z <<
"), ";
393 ps <<
"shared=" << param.shared_bytes;
// Start from the base-class initial parameters, then size a 1-D grid to
// cover all `volume` sites (ceil-divide).
// NOTE(review): shared_bytes is not set in the visible lines, unlike
// advanceBlockDim. Confirm the base class sizes it for the staged path.
398 Tunable::initTuneParam(param);
399 param.grid = dim3( (volume+param.block.x-1) / param.block.x, 1, 1);
// Start from the base-class default parameters, then size a 1-D grid to
// cover all `volume` sites (ceil-divide).
// NOTE(review): shared_bytes is not set in the visible lines. Confirm the
// base class sizes it for the staged path.
404 Tunable::defaultTuneParam(param);
405 param.grid = dim3( (volume+param.block.x-1) / param.block.x, 1, 1);
// Reports zero floating-point operations for this kernel.
long long flops() const
{
  return 0LL;
}
// Memory traffic reported for tuning: out.Bytes() plus in.Bytes().
long long bytes() const
{
  const long long traffic = out.Bytes() + in.Bytes();
  return traffic;
}
// Pack one parity (Vh sites) from inOrder to outOrder, choosing the basis
// functor from dstBasis/srcBasis:
//   - same basis: PreserveBasis;
//   - otherwise NonRelBasis or RelBasis. The conditions selecting between
//     these two are not visible. Both are preceded by an errorQuda when
//     Ns != 4.
// Each case shows two paths: the serial host loop packSpinor, and a
// PackSpinor tunable object (its launch is not visible). The condition that
// selects between them is not visible either.
// NOTE(review): the rest of the parameter list (dstBasis, srcBasis,
// location) is outside the visible chunk. dst, src and pad are not
// referenced in the visible lines.
414 template <
int Ns,
int Nc,
typename OutOrder,
typename InOrder,
typename FloatOut,
typename FloatIn>
415 void packParitySpinor(FloatOut *dst, FloatIn *src, OutOrder &outOrder,
const InOrder &inOrder,
int Vh,
int pad,
// Same basis on both sides: straight copy with precision conversion.
417 if (dstBasis==srcBasis) {
420 packSpinor<FloatOut, FloatIn, Ns, Nc>(outOrder, inOrder, basis,
Vh);
422 PackSpinor<FloatOut, FloatIn, Ns, Nc, OutOrder, InOrder, PreserveBasis<FloatOut, FloatIn, Ns, Nc> > pack(outOrder, inOrder, basis, Vh);
// Basis change via NonRelBasis; the 4-entry spin tables need Ns == 4.
426 if (Ns != 4)
errorQuda(
"Can only change basis with Nspin = 4, not Nspin = %d", Ns);
429 packSpinor<FloatOut, FloatIn, Ns, Nc>(outOrder, inOrder, basis,
Vh);
431 PackSpinor<FloatOut, FloatIn, Ns, Nc, OutOrder, InOrder, NonRelBasis<FloatOut, FloatIn, Ns, Nc> > pack(outOrder, inOrder, basis, Vh);
// Basis change via RelBasis; the 4-entry spin tables need Ns == 4.
435 if (Ns != 4)
errorQuda(
"Can only change basis with Nspin = 4, not Nspin = %d", Ns);
438 packSpinor<FloatOut, FloatIn, Ns, Nc>(outOrder, inOrder, basis,
Vh);
440 PackSpinor<FloatOut, FloatIn, Ns, Nc, OutOrder, InOrder, RelBasis<FloatOut, FloatIn, Ns, Nc> > pack(outOrder, inOrder, basis, Vh);
// Copy driver templated on <Nc, Ns, N, dstFloat, srcFloat>. The signature
// line is outside the visible chunk.
// 1) Validates that dst and src agree in Ndim, Volume, SiteOrder (extra
//    accepted combinations are in lines not visible) and SiteSubset,
//    calling errorQuda otherwise.
// 2) Computes lengths in reals (Bytes()/Precision()).
// 3) Dispatches on the field orders to packParitySpinor:
//    - first section: two calls per order pair, one per parity (Vh sites).
//      The first writes Dst from Src+evenOff; the second writes
//      Dst + dstLength/2 from Src+oddOff.
//    - second section: single calls over V sites using Dst and Src directly.
// NOTE(review): the branch conditions, the declarations of Vh, Dst, Src,
// outOrder, inOrder, srcOrder, dstOrder, and the closing braces are not
// visible in this chunk.
449 template <
int Nc,
int Ns,
int N,
typename dstFloat,
typename srcFloat>
452 if (dst.Ndim() != src.Ndim()) {
453 errorQuda(
"Number of dimensions %d %d don't match", dst.Ndim(), src.Ndim());
456 if (dst.Volume() != src.Volume()) {
457 errorQuda(
"Volumes %d %d don't match", dst.Volume(), src.Volume());
460 if (!( dst.SiteOrder() == src.SiteOrder() ||
465 errorQuda(
"Subset orders %d %d don't match", dst.SiteOrder(), src.SiteOrder());
468 if (dst.SiteSubset() != src.SiteSubset()) {
469 errorQuda(
"Subset types do not match %d %d", dst.SiteSubset(), src.SiteSubset());
472 int V = dst.Volume();
// Field lengths in reals (bytes divided by bytes per real).
475 int dstLength = dst.Bytes() / dst.Precision();
476 int srcLength = src.Bytes() / src.Precision();
484 errorQuda(
"Copying to full fields with lexicographical ordering is not currently supported");
// Offsets (in reals) of the even and odd halves within Src. Depending on a
// condition not visible here, either the odd or the even half starts at
// srcLength/2.
489 unsigned int evenOff, oddOff;
492 oddOff = srcLength/2;
495 evenOff = srcLength/2;
// Per-parity copies. Dst holds even sites first and odd sites at
// Dst + dstLength/2.
501 if (src.Pad() != 0)
errorQuda(
"Non-zero pad not supported with fieldOrder %d\n", srcOrder);
506 packParitySpinor<Ns,Nc>(Dst, Src+evenOff, outOrder, inOrder,
Vh, dst.Pad(), dstBasis, srcBasis,
location);
511 packParitySpinor<Ns,Nc>(Dst + dstLength/2, Src+oddOff, outOrder, inOrder,
Vh, dst.Pad(), dstBasis, srcBasis,
location);
517 packParitySpinor<Ns,Nc>(Dst, Src+evenOff, outOrder, inOrder,
Vh, dst.Pad(), dstBasis, srcBasis,
location);
522 packParitySpinor<Ns,Nc>(Dst + dstLength/2, Src+oddOff, outOrder, inOrder,
Vh, dst.Pad(), dstBasis, srcBasis,
location);
527 if (dst.Pad() != 0)
errorQuda(
"Non-zero pad not supported with fieldOrder %d\n", dstOrder);
532 packParitySpinor<Ns,Nc>(Dst, Src+evenOff, outOrder, inOrder,
Vh, src.Pad(), dstBasis, srcBasis,
location);
537 packParitySpinor<Ns,Nc>(Dst + dstLength/2, Src+oddOff, outOrder, inOrder,
Vh, src.Pad(), dstBasis, srcBasis,
location);
543 packParitySpinor<Ns,Nc>(Dst, Src+evenOff, outOrder, inOrder,
Vh, src.Pad(), dstBasis, srcBasis,
location);
548 packParitySpinor<Ns,Nc>(Dst + dstLength/2, Src+oddOff, outOrder, inOrder,
Vh, src.Pad(), dstBasis, srcBasis,
location);
552 errorQuda(
"Field order conversion from %d to %d not supported", srcOrder, dstOrder);
// Single-call copies over all V sites, using Dst and Src directly.
559 if (src.Pad() != 0)
errorQuda(
"Non-zero pad not supported with fieldOrder %d\n", srcOrder);
563 packParitySpinor<Ns,Nc>(Dst, Src, outOrder, inOrder,
V, dst.Pad(), dstBasis, srcBasis,
location);
567 packParitySpinor<Ns,Nc>(Dst, Src, outOrder, inOrder,
V, dst.Pad(), dstBasis, srcBasis,
location);
571 if (dst.Pad() != 0)
errorQuda(
"Non-zero pad not supported with fieldOrder %d\n", dstOrder);
575 packParitySpinor<Ns,Nc>(Dst, Src, outOrder, inOrder,
V, src.Pad(), dstBasis, srcBasis,
location);
579 packParitySpinor<Ns,Nc>(Dst, Src, outOrder, inOrder,
V, src.Pad(), dstBasis, srcBasis,
location);
582 errorQuda(
"Field order conversion from %d to %d not supported", srcOrder, dstOrder);