16 #define PRESERVE_SPINOR_NORM
18 #ifdef PRESERVE_SPINOR_NORM // Preserve the norm regardless of basis
19 #define kP (1.0/sqrt(2.0))
20 #define kU (1.0/sqrt(2.0))
21 #else // More numerically accurate not to preserve the norm between basis
29 template <
typename FloatOut,
typename FloatIn,
int Ns,
int Nc>
34 __device__ __host__
inline void operator()(RegTypeOut
out[Ns*Nc*2],
const RegTypeIn
in[Ns*Nc*2]) {
35 for (
int s=0;
s<Ns;
s++) {
36 for (
int c=0; c<Nc; c++) {
37 for (
int z=0; z<2; z++) {
38 out[(
s*Nc+c)*2+z] =
in[(
s*Nc+c)*2+z];
46 template <
typename FloatOut,
typename FloatIn,
int Ns,
int Nc>
51 int s1[4] = {1, 2, 3, 0};
52 int s2[4] = {3, 0, 1, 2};
55 for (
int s=0;
s<Ns;
s++) {
56 for (
int c=0; c<Nc; c++) {
57 for (
int z=0; z<2; z++) {
58 out[(
s*Nc+c)*2+z] = K1[
s]*
in[(s1[
s]*Nc+c)*2+z] + K2[
s]*
in[(s2[
s]*Nc+c)*2+z];
66 template <
typename FloatOut,
typename FloatIn,
int Ns,
int Nc>
71 int s1[4] = {1, 2, 3, 0};
72 int s2[4] = {3, 0, 1, 2};
75 for (
int s=0;
s<Ns;
s++) {
76 for (
int c=0; c<Nc; c++) {
77 for (
int z=0; z<2; z++) {
78 out[(
s*Nc+c)*2+z] = K1[
s]*
in[(s1[
s]*Nc+c)*2+z] + K2[
s]*
in[(s2[
s]*Nc+c)*2+z];
86 template <
typename FloatOut,
typename FloatIn,
int Ns,
int Nc>
91 int s1[4] = {0, 1, 0, 1};
92 int s2[4] = {2, 3, 2, 3};
95 for (
int s=0;
s<Ns;
s++) {
96 for (
int c=0; c<Nc; c++) {
97 for (
int z=0; z<2; z++) {
98 out[(
s*Nc+c)*2+z] = K1[
s]*
in[(s1[
s]*Nc+c)*2+z] + K2[
s]*
in[(s2[
s]*Nc+c)*2+z];
106 template <
typename FloatOut,
typename FloatIn,
int Ns,
int Nc>
111 int s1[4] = {0, 1, 0, 1};
112 int s2[4] = {2, 3, 2, 3};
115 for (
int s=0;
s<Ns;
s++) {
116 for (
int c=0; c<Nc; c++) {
117 for (
int z=0; z<2; z++) {
118 out[(
s*Nc+c)*2+z] = K1[
s]*
in[(s1[
s]*Nc+c)*2+z] + K2[
s]*
in[(s2[
s]*Nc+c)*2+z];
126 template <
typename FloatOut,
typename FloatIn,
int Ns,
int Nc,
typename OutOrder,
typename InOrder,
typename Basis>
127 void packSpinor(OutOrder &outOrder,
const InOrder &inOrder, Basis basis,
int volume) {
130 for (
int x=0;
x<volume;
x++) {
131 RegTypeIn
in[Ns*Nc*2];
132 RegTypeOut
out[Ns*Nc*2];
135 outOrder.save(out,
x);
140 template <
typename FloatOut,
typename FloatIn,
int Ns,
int Nc,
typename OutOrder,
typename InOrder,
typename Basis>
141 __global__
void packSpinorKernel(OutOrder outOrder,
const InOrder inOrder, Basis basis,
int volume) {
145 int x = blockIdx.x * blockDim.x + threadIdx.x;
146 RegTypeIn
in[Ns*Nc*2];
147 RegTypeOut
out[Ns*Nc*2];
151 outOrder.save(out, x);
154 template <
typename FloatOut,
typename FloatIn,
int Ns,
int Nc,
typename OutOrder,
typename InOrder,
typename Basis>
162 unsigned int sharedBytesPerThread()
const {
163 size_t regSize =
sizeof(FloatOut) >
sizeof(FloatIn) ?
sizeof(FloatOut) :
sizeof(FloatIn);
164 return Ns*Nc*2*regSize;
// Shared memory requested per thread block: sharedBytesPerThread() bytes for
// each of (block.x + 1) threads. NOTE(review): the reason for the extra
// thread's worth of storage is not visible here (presumably padding) — confirm.
unsigned int sharedBytesPerBlock(const TuneParam &param) const {
  const unsigned int paddedThreadCount = param.block.x + 1;
  return paddedThreadCount * sharedBytesPerThread();
}
// Shared-memory size is not stepped through when tuning this kernel: this
// always returns false, so no further shared-memory configuration is offered.
// The tuning parameter is taken by non-const reference to match the other
// advance*() hooks of this class, but it is never modified here.
bool advanceSharedBytes(TuneParam &param) const { return false; }
// Grid dimensions are not autotuned for this kernel (presumably the grid is
// derived from minThreads() and the block size — confirm in the Tunable base).
bool tuneGridDim() const
{
  const bool gridIsTuned = false;
  return gridIsTuned;
}
// Minimum number of threads to launch: the checkerboarded volume of the
// metadata field (presumably one thread per site — matches the `volume`
// argument packSpinor receives; confirm against the launch in apply()).
unsigned int minThreads() const
{
  const unsigned int siteCount = meta.VolumeCB();
  return siteCount;
}
172 bool advanceBlockDim(
TuneParam ¶m)
const {
181 : out(out), in(in), basis(basis), meta(meta) {
182 writeAuxString(
"out_stride=%d,in_stride=%d", out.stride, in.stride);
188 packSpinorKernel<FloatOut, FloatIn, Ns, Nc, OutOrder, InOrder, Basis>
196 std::stringstream ps;
197 ps <<
"block=(" << param.
block.x <<
"," << param.
block.y <<
"," << param.
block.z <<
"), ";
// Floating-point operation count reported to the tuner/profiler. This kernel
// reports zero; note that basis-changing variants do perform a few
// multiply-adds per component, which are deliberately not counted here.
long long flops() const
{
  const long long countedFlops = 0;
  return countedFlops;
}
// Bytes moved by the copy: the input order's Bytes() plus the output
// order's Bytes() (everything read plus everything written).
long long bytes() const
{
  const auto bytesRead = in.Bytes();
  const auto bytesWritten = out.Bytes();
  return bytesRead + bytesWritten;
}
208 template <
typename FloatOut,
typename FloatIn,
int Ns,
int Nc,
typename OutOrder,
typename InOrder>
212 if (dstBasis==srcBasis) {
215 packSpinor<FloatOut, FloatIn, Ns, Nc>(outOrder, inOrder, basis, out.
VolumeCB());
218 pack(outOrder, inOrder, basis, out);
222 if (Ns != 4)
errorQuda(
"Can only change basis with Nspin = 4, not Nspin = %d", Ns);
225 packSpinor<FloatOut, FloatIn, Ns, Nc>(outOrder, inOrder, basis, out.
VolumeCB());
228 pack(outOrder, inOrder, basis, out);
232 if (Ns != 4)
errorQuda(
"Can only change basis with Nspin = 4, not Nspin = %d", Ns);
235 packSpinor<FloatOut, FloatIn, Ns, Nc>(outOrder, inOrder, basis, out.
VolumeCB());
238 pack(outOrder, inOrder, basis, out);
242 if (Ns != 4)
errorQuda(
"Can only change basis with Nspin = 4, not Nspin = %d", Ns);
245 packSpinor<FloatOut, FloatIn, Ns, Nc>(outOrder, inOrder, basis, out.
VolumeCB());
248 pack(outOrder, inOrder, basis, out);
252 if (Ns != 4)
errorQuda(
"Can only change basis with Nspin = 4, not Nspin = %d", Ns);
255 packSpinor<FloatOut, FloatIn, Ns, Nc>(outOrder, inOrder, basis, out.
VolumeCB());
258 pack(outOrder, inOrder, basis, out);
267 template <
typename FloatOut,
typename FloatIn,
int Ns,
int Nc,
typename InOrder>
270 FloatOut *Out,
float *outNorm) {
273 genericCopyColorSpinor<FloatOut,FloatIn,Ns,Nc>
277 genericCopyColorSpinor<FloatOut,FloatIn,Ns,Nc>
281 genericCopyColorSpinor<FloatOut,FloatIn,Ns,Nc>
285 genericCopyColorSpinor<FloatOut,FloatIn,Ns,Nc>
289 #ifdef BUILD_QDPJIT_INTERFACE
291 genericCopyColorSpinor<FloatOut,FloatIn,Ns,Nc>
294 errorQuda(
"QDPJIT interface has not been built\n");
304 template <
typename FloatOut,
typename FloatIn,
int Ns,
int Nc>
307 float *outNorm,
float *inNorm) {
310 genericCopyColorSpinor<FloatOut,FloatIn,Ns,Nc>(inOrder,
out, in.
GammaBasis(),
location, Out, outNorm);
313 genericCopyColorSpinor<FloatOut,FloatIn,Ns,Nc>(inOrder,
out, in.
GammaBasis(),
location, Out, outNorm);
316 genericCopyColorSpinor<FloatOut,FloatIn,Ns,Nc>(inOrder,
out, in.
GammaBasis(),
location, Out, outNorm);
319 genericCopyColorSpinor<FloatOut,FloatIn,Ns,Nc>(inOrder,
out, in.
GammaBasis(),
location, Out, outNorm);
322 #ifdef BUILD_QDPJIT_INTERFACE
324 genericCopyColorSpinor<FloatOut,FloatIn,Ns,Nc>(inOrder,
out, in.
GammaBasis(),
location, Out, outNorm);
326 errorQuda(
"QDPJIT interface has not been built\n");
336 template <
int Ns,
typename dstFloat,
typename srcFloat>
339 float *dstNorm,
float *srcNorm) {
364 errorQuda(
"Copying to full fields with lexicographical ordering is not currently supported");
370 errorQuda(
"QDPJIT field ordering not supported for full site fields");
374 srcFloat *srcEven = Src ? Src : (srcFloat*)src.
V();
375 srcFloat *srcOdd = (srcFloat*)((
char*)srcEven + src.
Bytes()/2);
376 float *srcNormEven = srcNorm ? srcNorm : (
float*)src.
Norm();
377 float *srcNormOdd = (
float*)((
char*)srcNormEven + src.
NormBytes()/2);
379 std::swap<srcFloat*>(srcEven, srcOdd);
380 std::swap<float*>(srcNormEven, srcNormOdd);
384 dstFloat *dstEven = Dst ? Dst : (dstFloat*)dst.
V();
385 dstFloat *dstOdd = (dstFloat*)((
char*)dstEven + dst.
Bytes()/2);
386 float *dstNormEven = dstNorm ? dstNorm : (
float*)dst.
Norm();
387 float *dstNormOdd = (
float*)((
char*)dstNormEven + dst.
NormBytes()/2);
389 std::swap<dstFloat*>(dstEven, dstOdd);
390 std::swap<float*>(dstNormEven, dstNormOdd);
393 genericCopyColorSpinor<dstFloat, srcFloat, Ns, Nc>
394 (dst, src,
location, dstEven, srcEven, dstNormEven, srcNormEven);
395 genericCopyColorSpinor<dstFloat, srcFloat, Ns, Nc>
396 (dst, src,
location, dstOdd, srcOdd, dstNormOdd, srcNormOdd);
398 genericCopyColorSpinor<dstFloat, srcFloat, Ns, Nc>
399 (dst, src,
location, Dst, Src, dstNorm, srcNorm);
404 template <
typename dstFloat,
typename srcFloat>
407 float *dstNorm=0,
float *srcNorm=0) {
410 errorQuda(
"source and destination spins must match");
412 if (dst.
Nspin() == 4) {
413 copyGenericColorSpinor<4>(dst, src,
location, Dst, Src, dstNorm, srcNorm);
414 }
else if (dst.
Nspin() == 1) {
415 copyGenericColorSpinor<1>(dst, src,
location, Dst, Src, dstNorm, srcNorm);
424 void *dstNorm,
void *srcNorm) {
std::string paramString(const TuneParam ¶m) const
__device__ __host__ void operator()(RegTypeOut out[Ns *Nc *2], const RegTypeIn in[Ns *Nc *2])
mapper< FloatOut >::type RegTypeOut
QudaVerbosity getVerbosity()
mapper< FloatOut >::type RegTypeOut
void CopyGenericColorSpinor(ColorSpinorField &dst, const ColorSpinorField &src, QudaFieldLocation location, dstFloat *Dst, srcFloat *Src, float *dstNorm=0, float *srcNorm=0)
mapper< FloatIn >::type RegTypeIn
void copyGenericColorSpinor(ColorSpinorField &dst, const ColorSpinorField &src, QudaFieldLocation location, void *Dst=0, void *Src=0, void *dstNorm=0, void *srcNorm=0)
__device__ __host__ void operator()(RegTypeOut out[Ns *Nc *2], const RegTypeIn in[Ns *Nc *2])
void writeAuxString(const char *format,...)
void apply(const cudaStream_t &stream)
const QudaFieldLocation location
__global__ void packSpinorKernel(OutOrder outOrder, const InOrder inOrder, Basis basis, int volume)
virtual bool advanceBlockDim(TuneParam ¶m) const
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
const char * VolString() const
__device__ __host__ void operator()(RegTypeOut out[Ns *Nc *2], const RegTypeIn in[Ns *Nc *2])
QudaSiteOrder SiteOrder() const
mapper< FloatOut >::type RegTypeOut
mapper< FloatIn >::type RegTypeIn
QudaFieldOrder FieldOrder() const
__device__ __host__ void operator()(RegTypeOut out[Ns *Nc *2], const RegTypeIn in[Ns *Nc *2])
enum QudaFieldLocation_s QudaFieldLocation
mapper< FloatIn >::type RegTypeIn
cpuColorSpinorField * out
PackSpinor(OutOrder &out, const InOrder &in, Basis &basis, const ColorSpinorField &meta)
mapper< FloatIn >::type RegTypeIn
enum QudaGammaBasis_s QudaGammaBasis
mapper< FloatOut >::type RegTypeOut
QudaPrecision Precision() const
QudaGammaBasis GammaBasis() const
void genericCopyColorSpinor(OutOrder &outOrder, const InOrder &inOrder, QudaGammaBasis dstBasis, QudaGammaBasis srcBasis, const ColorSpinorField &out, QudaFieldLocation location)
QudaSiteSubset SiteSubset() const
void packSpinor(OutOrder &outOrder, const InOrder &inOrder, Basis basis, int volume)
__device__ __host__ void operator()(RegTypeOut out[Ns *Nc *2], const RegTypeIn in[Ns *Nc *2])