10 #define PRESERVE_SPINOR_NORM
12 #ifdef PRESERVE_SPINOR_NORM // Preserve the norm regardless of basis
13 #define kP (1.0/sqrt(2.0))
14 #define kU (1.0/sqrt(2.0))
15 #else // More numerically accurate not to preserve the norm between basis
28 for(
int i=0; i<4; i++){
29 if(R[i] > nFace) nFace = R[i];
34 int gatherCompleted[2] = {0,0};
35 int commsCompleted[2] = {0,0};
38 for(
int dir=0; dir<2; dir++) cudaEventCreate(&gatherEnd[dir], cudaEventDisableTiming);
44 cudaDeviceSynchronize();
45 for(
int dir=1; dir<=0; dir--){
46 spinor->
gather(nFace, dagger, 2*
dim + dir);
47 cudaEventRecord(gatherEnd[dir],
streams[2*
dim+dir]);
52 while(completeSum < 2){
53 if(!gatherCompleted[dir]){
54 if(cudaSuccess == cudaEventQuery(gatherEnd[dir])){
57 gatherCompleted[dir--] = 1;
61 gatherCompleted[0] = gatherCompleted[1] = 0;
65 while(completeSum < 4){
66 if(!commsCompleted[dir]){
70 commsCompleted[dir--] = 1;
74 commsCompleted[0] = commsCompleted[1] = 0;
75 cudaDeviceSynchronize();
78 for(
int dir=0; dir<2; dir++) cudaEventDestroy(gatherEnd[dir]);
86 template <
typename FloatOut,
typename FloatIn,
int Ns,
int Nc>
88 typedef typename mapper<FloatIn>::type RegTypeIn;
89 typedef typename mapper<FloatOut>::type RegTypeOut;
91 __device__ __host__
inline void operator()(RegTypeOut
out[Ns*Nc*2],
const RegTypeIn
in[Ns*Nc*2]) {
92 for (
int s=0;
s<Ns;
s++) {
93 for (
int c=0; c<Nc; c++) {
94 for (
int z=0; z<2; z++) {
95 out[(
s*Nc+c)*2+z] =
in[(
s*Nc+c)*2+z];
103 template <
typename FloatOut,
typename FloatIn,
int Ns,
int Nc>
108 int s1[4] = {1, 2, 3, 0};
109 int s2[4] = {3, 0, 1, 2};
112 for (
int s=0;
s<Ns;
s++) {
113 for (
int c=0; c<Nc; c++) {
114 for (
int z=0; z<2; z++) {
115 out[(
s*Nc+c)*2+z] = K1[
s]*
in[(s1[
s]*Nc+c)*2+z] + K2[
s]*
in[(s2[
s]*Nc+c)*2+z];
124 template <
typename FloatOut,
typename FloatIn,
int Ns,
int Nc>
129 int s1[4] = {1, 2, 3, 0};
130 int s2[4] = {3, 0, 1, 2};
133 for (
int s=0;
s<Ns;
s++) {
134 for (
int c=0; c<Nc; c++) {
135 for (
int z=0; z<2; z++) {
136 out[(
s*Nc+c)*2+z] = K1[
s]*
in[(s1[
s]*Nc+c)*2+z] + K2[
s]*
in[(s2[
s]*Nc+c)*2+z];
146 template<
typename OutOrder,
typename InOrder,
typename Basis>
157 : out(out), in(in), basis(basis), parity(parity)
160 for(
int d=0; d<4; d++){
169 template<
typename FloatOut,
typename FloatIn,
int Ns,
int Nc,
typename OutOrder,
typename InOrder,
typename Basis,
bool extend>
174 for(
int d=0; d<4; d++) R[d] = (arg.
E[d] - arg.
X[d]) >> 1;
176 int za = X/(arg.
X[0]/2);
177 int x0h = X - za*(arg.
X[0]/2);
178 int zb = za/arg.
X[1];
179 x[1] = za - zb*arg.
X[1];
180 x[3] = zb / arg.
X[2];
181 x[2] = zb - x[3]*arg.
X[2];
182 x[0] = 2*x0h + ((x[1] + x[2] + x[3] + arg.
parity) & 1);
185 int Y = ((((x[3]+R[3])*arg.
E[2] + (x[2]+R[2]))*arg.
E[1] + (x[1]+R[1]))*arg.
E[0]+(x[0]+R[0])) >> 1;
190 RegTypeIn
in[Ns*Nc*2];
191 RegTypeOut
out[Ns*Nc*2];
196 arg.
out.save(out, Y);
200 arg.
out.save(out, Y);
205 template<
typename FloatOut,
typename FloatIn,
int Ns,
int Nc,
typename OutOrder,
typename InOrder,
typename Basis,
bool extend>
208 int cb_idx = blockIdx.x*blockDim.x + threadIdx.x;
210 while(cb_idx < arg.
length){
211 copyInterior<FloatOut,FloatIn,Ns,Nc,OutOrder,InOrder,Basis,extend>(
arg,cb_idx);
212 cb_idx += gridDim.x*blockDim.x;
219 template<
typename FloatOut,
typename FloatIn,
int Ns,
int Nc,
typename OutOrder,
typename InOrder,
typename Basis,
bool extend>
222 for(
int cb_idx=0; cb_idx<arg.
length; cb_idx++){
223 copyInterior<FloatOut,FloatIn,Ns,Nc,OutOrder,InOrder,Basis,extend>(
arg, cb_idx);
230 template<
typename FloatOut,
typename FloatIn,
int Ns,
int Nc,
typename OutOrder,
typename InOrder,
typename Basis,
bool extend>
238 unsigned int sharedBytesPerThread()
const {
return 0; }
239 unsigned int sharedBytesPerBlock(
const TuneParam &
param)
const {
return 0; }
240 bool advanceSharedBytes(
TuneParam ¶m)
const {
return false; }
241 bool tuneGridDim()
const {
return false; }
242 unsigned int minThreads()
const {
return arg.length; }
246 : arg(arg), meta(meta), location(location) {
255 copyInterior<FloatOut,FloatIn,Ns,Nc,OutOrder,InOrder,Basis,extend>(arg);
257 copyInteriorKernel<FloatOut,FloatIn,Ns,Nc,OutOrder,InOrder,Basis,extend>
265 std::stringstream ps;
266 ps <<
"block=(" << param.
block.x <<
"," << param.
block.y <<
"," << param.
block.z <<
")";
271 long long flops()
const {
return 0; }
273 return arg.length*2*Nc*Ns*(
sizeof(FloatIn) +
sizeof(FloatOut));
280 template<
typename FloatOut,
typename FloatIn,
int Ns,
int Nc,
typename OutOrder,
typename InOrder,
typename Basis>
281 void copySpinorEx(OutOrder outOrder,
const InOrder inOrder,
const Basis basis,
const int *
E,
295 template<
typename FloatOut,
typename FloatIn,
int Ns,
int Nc,
typename OutOrder,
typename InOrder>
297 const int*
E,
const int*
X,
const int parity,
const bool extend,
300 if(inBasis == outBasis){
302 copySpinorEx<FloatOut, FloatIn, Ns, Nc, OutOrder, InOrder, PreserveBasis<FloatOut,FloatIn,Ns,Nc> >
305 if(Ns != 4)
errorQuda(
"Can only change basis with Nspin = 4, not Nspin = %d", Ns);
307 copySpinorEx<FloatOut, FloatIn, Ns, Nc, OutOrder, InOrder, NonRelBasis<FloatOut,FloatIn,Ns,Nc> >
310 if(Ns != 4)
errorQuda(
"Can only change basis with Nspin = 4, not Nspin = %d", Ns);
312 copySpinorEx<FloatOut, FloatIn, Ns, Nc, OutOrder, InOrder, RelBasis<FloatOut,FloatIn,Ns,Nc> >
322 template<
typename FloatOut,
typename FloatIn,
int Ns,
int Nc,
typename InOrder>
329 copySpinorEx<FloatOut,FloatIn,Ns,Nc>
333 copySpinorEx<FloatOut,FloatIn,Ns,Nc>
338 copySpinorEx<FloatOut,FloatIn,Ns,Nc>
342 copySpinorEx<FloatOut,FloatIn,Ns,Nc>
345 #ifdef BUILD_QDPJIT_INTERFACE
347 copySpinorEx<FloatOut,FloatIn,Ns,Nc>
350 errorQuda(
"QDPJIT interface has not been built\n");
358 template<
typename FloatOut,
typename FloatIn,
int Ns,
int Nc>
361 float* outNorm,
float *inNorm){
368 for(
int d=0; d<4; d++){
373 for(
int d=0; d<4; d++){
378 X[0] *= 2; E[0] *= 2;
383 extendedCopyColorSpinor<FloatOut,FloatIn,Ns,Nc>(inOrder,
out, in.
GammaBasis(),
E,
X,
parity, extend,
location, Out, outNorm);
386 extendedCopyColorSpinor<FloatOut,FloatIn,Ns,Nc>(inOrder,
out, in.
GammaBasis(),
E,
X,
parity, extend,
location, Out, outNorm);
390 extendedCopyColorSpinor<FloatOut,FloatIn,Ns,Nc>(inOrder,
out, in.
GammaBasis(),
E,
X,
parity, extend,
location, Out, outNorm);
393 extendedCopyColorSpinor<FloatOut,FloatIn,Ns,Nc>(inOrder,
out, in.
GammaBasis(),
E,
X,
parity, extend,
location, Out, outNorm);
395 #ifdef BUILD_QDPJIT_INTERFACE
397 extendedCopyColorSpinor<FloatOut,FloatIn,Ns,Nc>(inOrder,
out, in.
GammaBasis(),
E,
X,
parity, extend,
location, Out, outNorm);
399 errorQuda(
"QDPJIT interface has not been built\n");
412 template<
int Ns,
typename dstFloat,
typename srcFloat>
415 float *dstNorm,
float *srcNorm) {
439 errorQuda(
"Copying to full fields with lexicographical ordering is not currently supported");
445 errorQuda(
"QDPJIT field ordering not supported for full site fields");
449 srcFloat *srcEven = Src ? Src : (srcFloat*)src.
V();
450 srcFloat* srcOdd = (srcFloat*)((
char*)srcEven + src.
Bytes()/2);
451 float *srcNormEven = srcNorm ? srcNorm : (
float*)src.
Norm();
452 float *srcNormOdd = (
float*)((
char*)srcNormEven + src.
NormBytes()/2);
454 std::swap<srcFloat*>(srcEven, srcOdd);
455 std::swap<float*>(srcNormEven, srcNormOdd);
459 dstFloat *dstEven = Dst ? Dst : (dstFloat*)dst.
V();
460 dstFloat *dstOdd = (dstFloat*)((
char*)dstEven + dst.
Bytes()/2);
461 float *dstNormEven = dstNorm ? dstNorm : (
float*)dst.
Norm();
462 float *dstNormOdd = (
float*)((
char*)dstNormEven + dst.
NormBytes()/2);
464 std::swap<dstFloat*>(dstEven, dstOdd);
465 std::swap<float*>(dstNormEven, dstNormOdd);
469 extendedCopyColorSpinor<dstFloat, srcFloat, Ns, Nc>
470 (dst, src, 0,
location, dstEven, srcEven, dstNormEven, srcNormEven);
471 extendedCopyColorSpinor<dstFloat, srcFloat, Ns, Nc>
472 (dst, src, 1,
location, dstOdd, srcOdd, dstNormOdd, srcNormOdd);
474 extendedCopyColorSpinor<dstFloat, srcFloat, Ns, Nc>
480 template<
typename dstFloat,
typename srcFloat>
483 float *dstNorm=0,
float *srcNorm=0)
486 errorQuda(
"source and destination spins must match");
488 if(dst.
Nspin() == 4){
489 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC)
490 copyExtendedColorSpinor<4>(dst, src,
parity,
location, Dst, Src, dstNorm, srcNorm);
492 errorQuda(
"Extended copy has not been built for Nspin=%d fields",dst.Nspin());
494 }
else if(dst.
Nspin() == 1){
495 #ifdef GPU_STAGGERED_DIRAC
496 copyExtendedColorSpinor<1>(dst, src,
parity,
location, Dst, Src, dstNorm, srcNorm);
498 errorQuda(
"Extended copy has not been built for Nspin=%d fields", dst.Nspin());
509 void *dstNorm,
void *srcNorm){
517 CopyExtendedColorSpinor(dst, src, parity, location, static_cast<double*>(Dst), static_cast<short*>(Src), 0, static_cast<float*>(srcNorm));
527 CopyExtendedColorSpinor(dst, src, parity, location, static_cast<float*>(Dst), static_cast<short*>(Src), 0, static_cast<float*>(srcNorm));
533 CopyExtendedColorSpinor(dst, src, parity, location, static_cast<short*>(Dst), static_cast<double*>(Src), static_cast<float*>(dstNorm), 0);
535 CopyExtendedColorSpinor(dst, src, parity, location, static_cast<short*>(Dst), static_cast<float*>(Src), static_cast<float*>(dstNorm), 0);
537 CopyExtendedColorSpinor(dst, src, parity, location, static_cast<short*>(Dst), static_cast<short*>(Src), static_cast<float*>(dstNorm), static_cast<float*>(srcNorm));
CopySpinorEx(CopySpinorExArg< OutOrder, InOrder, Basis > &arg, const ColorSpinorField &meta, QudaFieldLocation location)
__device__ __host__ void operator()(RegTypeOut out[Ns *Nc *2], const RegTypeIn in[Ns *Nc *2])
mapper< FloatOut >::type RegTypeOut
QudaVerbosity getVerbosity()
mapper< FloatOut >::type RegTypeOut
void gather(int nFace, int dagger, int dir, cudaStream_t *stream_p=NULL)
mapper< FloatIn >::type RegTypeIn
int commsQuery(int nFace, int dir, int dagger=0)
void CopyExtendedColorSpinor(ColorSpinorField &dst, const ColorSpinorField &src, const int parity, const QudaFieldLocation location, dstFloat *Dst, srcFloat *Src, float *dstNorm=0, float *srcNorm=0)
void copySpinorEx(OutOrder outOrder, const InOrder inOrder, const Basis basis, const int *E, const int *X, const int parity, const bool extend, const ColorSpinorField &meta, QudaFieldLocation location)
std::string paramString(const TuneParam ¶m) const
void scatterExtended(int nFace, int parity, int dagger, int dir)
cpuColorSpinorField * spinor
__device__ __host__ void operator()(RegTypeOut out[Ns *Nc *2], const RegTypeIn in[Ns *Nc *2])
void writeAuxString(const char *format,...)
const QudaFieldLocation location
void apply(const cudaStream_t &stream)
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
void packExtended(const int nFace, const int R[], const int parity, const int dagger, const int dim, cudaStream_t *stream_p, const bool zeroCopyPack=false)
void exchangeExtendedGhost(cudaColorSpinorField *spinor, int R[], int parity, cudaStream_t *stream_p)
void extendedCopyColorSpinor(InOrder &inOrder, ColorSpinorField &out, QudaGammaBasis inBasis, const int *E, const int *X, const int parity, const bool extend, QudaFieldLocation location, FloatOut *Out, float *outNorm)
const char * VolString() const
cudaEvent_t gatherEnd[Nstream]
QudaSiteOrder SiteOrder() const
QudaFieldOrder FieldOrder() const
enum QudaFieldLocation_s QudaFieldLocation
cpuColorSpinorField * out
void copyExtendedColorSpinor(ColorSpinorField &dst, const ColorSpinorField &src, QudaFieldLocation location, const int parity, void *Dst, void *Src, void *dstNorm, void *srcNorm)
mapper< FloatIn >::type RegTypeIn
enum QudaGammaBasis_s QudaGammaBasis
QudaPrecision Precision() const
QudaGammaBasis GammaBasis() const
CopySpinorExArg(const OutOrder &out, const InOrder &in, const Basis &basis, const int *E, const int *X, const int parity)
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
__global__ void copyInteriorKernel(CopySpinorExArg< OutOrder, InOrder, Basis > arg)
#define QUDA_MAX_DIM
Maximum number of dimensions supported by QUDA. In practice, no routines make use of more than 5...
void commsStart(int nFace, int dir, int dagger=0)
__device__ __host__ void copyInterior(CopySpinorExArg< OutOrder, InOrder, Basis > &arg, int X)
QudaSiteSubset SiteSubset() const
__device__ __host__ void operator()(RegTypeOut out[Ns *Nc *2], const RegTypeIn in[Ns *Nc *2])