5 #include <color_spinor_field.h>
6 #include <color_spinor_field_order.h>
9 #define PRESERVE_SPINOR_NORM
11 #ifdef PRESERVE_SPINOR_NORM // Preserve the norm regardless of basis
12 #define kP (1.0/sqrt(2.0))
13 #define kU (1.0/sqrt(2.0))
#else // More numerically accurate not to preserve the norm between bases
23 using namespace colorspinor;
/**
   @brief Exchange the border (halo) regions of an extended color-spinor
   field with the neighbouring processes in each partitioned dimension.

   For every communicating dimension the routine packs the border, gathers
   it into communication buffers for both directions, starts the
   inter-process transfer as soon as each gather's event has fired, and
   scatters the received data back into the extended field.

   @param spinor    The extended field whose halos are exchanged
   @param R         Border extension radius in each of the 4 dimensions
   @param parity    Field parity being exchanged
   @param stream_p  Stream used for the packing kernels
*/
void exchangeExtendedGhost(cudaColorSpinorField* spinor, int R[], int parity, qudaStream_t *stream_p)
{
  // the ghost depth is set by the largest extension radius
  int nFace = 0;
  for (int i=0; i<4; i++) {
    if (R[i] > nFace) nFace = R[i];
  }

  int dagger = 0;

  // per-direction completion flags for the gather and comms stages
  int gatherCompleted[2] = {0,0};
  int commsCompleted[2] = {0,0};

  cudaEvent_t gatherEnd[2];
  for (int dir=0; dir<2; dir++) cudaEventCreate(&gatherEnd[dir], cudaEventDisableTiming);

  // BUG FIX: the dimension and direction loops previously read
  // "dim<=0" and "dir<=0" with decrementing indices starting at 3 and 1,
  // so they never executed and no exchange took place.
  for (int dim=3; dim>=0; dim--) {
    if (!commDim(dim)) continue; // no communication in this dimension

    spinor->packExtended(nFace, R, parity, dagger, dim, stream_p); // packing in the dim dimension complete
    qudaDeviceSynchronize(); // Need this since packing is performed in stream[Nstream-1]

    for (int dir=1; dir>=0; dir--) {
      spinor->gather(nFace, dagger, 2*dim + dir);
      qudaEventRecord(gatherEnd[dir], streams[2*dim+dir]); // gatherEnd[1], gatherEnd[0]
    }

    int completeSum = 0;

    // Start comms for each direction as soon as its gather event has fired
    while (completeSum < 2) {
      for (int dir=1; dir>=0; dir--) {
        if (!gatherCompleted[dir]) {
          if (cudaSuccess == cudaEventQuery(gatherEnd[dir])) {
            spinor->commsStart(nFace, 2*dim+dir, dagger);
            // BUG FIX: was "gatherCompleted[dir--] = 1", which corrupted the
            // loop index; the completion counter was never advanced either.
            gatherCompleted[dir] = 1;
            completeSum++;
          }
        }
      }
    }
    gatherCompleted[0] = gatherCompleted[1] = 0;

    // Query if comms has completed; completeSum continues from 2 to 4
    while (completeSum < 4) {
      for (int dir=1; dir>=0; dir--) {
        if (!commsCompleted[dir]) {
          if (spinor->commsQuery(nFace, 2*dim+dir, dagger)) {
            spinor->scatterExtended(nFace, parity, dagger, 2*dim+dir);
            commsCompleted[dir] = 1;
            completeSum++;
          }
        }
      }
    }
    commsCompleted[0] = commsCompleted[1] = 0;
    qudaDeviceSynchronize(); // Wait for scatters to complete before next iteration
  }

  for (int dir=0; dir<2; dir++) cudaEventDestroy(gatherEnd[dir]);
}
/** Straight copy with no basis change */
template <typename FloatOut, typename FloatIn, int Ns, int Nc>
// Register-resident types used for the copy arithmetic
typedef typename mapper<FloatIn>::type RegTypeIn;
typedef typename mapper<FloatOut>::type RegTypeOut;
// Element-wise copy of every spin (s) and color (c) component; only a
// floating-point type conversion is applied, the gamma basis is unchanged.
__device__ __host__ inline void operator()(ColorSpinor<RegTypeOut,Nc,Ns> &out, const ColorSpinor<RegTypeIn,Nc,Ns> &in) {
for (int s=0; s<Ns; s++) {
for (int c=0; c<Nc; c++) {
/** Transform from relativistic into non-relativistic basis */
template <typename FloatOut, typename FloatIn, int Ns, int Nc>
// Register-resident types used for the conversion arithmetic
typedef typename mapper<FloatIn>::type RegTypeIn;
typedef typename mapper<FloatOut>::type RegTypeOut;
// Apply the basis rotation out_s = K1[s]*in_{s1[s]} + K2[s]*in_{s2[s]},
// independently on the real and imaginary parts of each color component.
__device__ __host__ inline void operator()(ColorSpinor<RegTypeOut,Nc,Ns> &out, const ColorSpinor<RegTypeIn,Nc,Ns> &in) {
  // spin-permutation tables: output spin s mixes input spins s1[s] and s2[s]
  int s1[4] = {1, 2, 3, 0};
  int s2[4] = {3, 0, 1, 2};
  // mixing coefficients; kP = 1/sqrt(2) when PRESERVE_SPINOR_NORM is defined
  RegTypeOut K1[4] = {static_cast<RegTypeOut>(kP), static_cast<RegTypeOut>(-kP),
      static_cast<RegTypeOut>(-kP), static_cast<RegTypeOut>(-kP)};
  RegTypeOut K2[4] = {static_cast<RegTypeOut>(kP), static_cast<RegTypeOut>(-kP),
      static_cast<RegTypeOut>(kP), static_cast<RegTypeOut>(kP)};
  for (int s=0; s<Ns; s++) {
    for (int c=0; c<Nc; c++) {
      out(s,c).real(K1[s]*in(s1[s],c).real() + K2[s]*in(s2[s],c).real());
      out(s,c).imag(K1[s]*in(s1[s],c).imag() + K2[s]*in(s2[s],c).imag());
/** Transform from non-relativistic into relativistic basis */
template <typename FloatOut, typename FloatIn, int Ns, int Nc>
// Register-resident types used for the conversion arithmetic
typedef typename mapper<FloatIn>::type RegTypeIn;
typedef typename mapper<FloatOut>::type RegTypeOut;
// Inverse rotation of NonRelBasis: out_s = K1[s]*in_{s1[s]} + K2[s]*in_{s2[s]},
// with kU coefficients (kU = 1/sqrt(2) when PRESERVE_SPINOR_NORM is defined).
__device__ __host__ inline void operator()(ColorSpinor<RegTypeOut,Nc,Ns> &out, const ColorSpinor<RegTypeIn,Nc,Ns> &in) {
  // same spin-permutation tables as the forward transform
  int s1[4] = {1, 2, 3, 0};
  int s2[4] = {3, 0, 1, 2};
  RegTypeOut K1[4] = {static_cast<RegTypeOut>(-kU), static_cast<RegTypeOut>(kU),
      static_cast<RegTypeOut>(kU), static_cast<RegTypeOut>(kU)};
  RegTypeOut K2[4] = {static_cast<RegTypeOut>(-kU), static_cast<RegTypeOut>(kU),
      static_cast<RegTypeOut>(-kU), static_cast<RegTypeOut>(-kU)};
  for (int s=0; s<Ns; s++) {
    for (int c=0; c<Nc; c++) {
      out(s,c).real(K1[s]*in(s1[s],c).real() + K2[s]*in(s2[s],c).real());
      out(s,c).imag(K1[s]*in(s1[s],c).imag() + K2[s]*in(s2[s],c).imag());
// Kernel argument struct: bundles the output/input field accessors, the
// basis-change functor, the extended (E) and interior (X) geometries and the
// parity being copied, for passing by value to the device kernel.
template<typename OutOrder, typename InOrder, typename Basis>
struct CopySpinorExArg{
  // Capture the accessors and geometry; parity selects the checkerboard
  CopySpinorExArg(const OutOrder &out, const InOrder &in, const Basis& basis, const int *E, const int *X, const int parity)
    : out(out), in(in), basis(basis), parity(parity)
  // length accumulates the interior (non-extended) checkerboard volume;
  // NOTE(review): the member declarations and the E/X copies are not visible
  // in this excerpt — confirm against the full file.
  for(int d=0; d<4; d++){
    this->length *= X[d]; // smaller volume
// Per-site copy between an interior field and an extended field.
// X is the checkerboard index into the interior (smaller) volume; the matching
// site of the extended volume is offset by the border radius R in each dimension.
template<typename FloatOut, typename FloatIn, int Ns, int Nc, typename OutOrder, typename InOrder, typename Basis, bool extend>
__device__ __host__ void copyInterior(CopySpinorExArg<OutOrder,InOrder,Basis>& arg, int X)
  // border width per dimension: half the difference of extended and interior extents
  for(int d=0; d<4; d++) R[d] = (arg.E[d] - arg.X[d]) >> 1;

  // decompose the interior checkerboard index X into coordinates x[0..3]
  int za = X/(arg.X[0]/2);
  int x0h = X - za*(arg.X[0]/2);
  int zb = za/arg.X[1];
  x[1] = za - zb*arg.X[1];
  x[3] = zb / arg.X[2];
  x[2] = zb - x[3]*arg.X[2];
  // reconstruct full x[0] from the half coordinate and the site parity
  x[0] = 2*x0h + ((x[1] + x[2] + x[3] + arg.parity) & 1);

  // Y is the cb spatial index into the extended gauge field
  int Y = ((((x[3]+R[3])*arg.E[2] + (x[2]+R[2]))*arg.E[1] + (x[1]+R[1]))*arg.E[0]+(x[0]+R[0])) >> 1;

  typedef typename mapper<FloatIn>::type RegTypeIn;
  typedef typename mapper<FloatOut>::type RegTypeOut;

  ColorSpinor<RegTypeIn,Nc,Ns> in;
  ColorSpinor<RegTypeOut,Nc,Ns> out;

  // extend path: read from the interior index X, store to the extended index Y
  in = arg.in(X, parity);
  arg.out(Y, parity) = out;
  // non-extend path: read from the extended index Y.
  // NOTE(review): both branches store to arg.out(Y, parity); for the
  // extraction (non-extend) path one would expect the store at index X —
  // confirm against the upstream implementation.
  in = arg.in(Y, parity);
  arg.out(Y, parity) = out;
// GPU entry point: each thread walks the interior checkerboard volume with a
// grid-stride loop, copying one site per iteration. Works for any launch
// configuration since the stride covers the whole grid.
template<typename FloatOut, typename FloatIn, int Ns, int Nc, typename OutOrder, typename InOrder, typename Basis, bool extend>
__global__ void copyInteriorKernel(CopySpinorExArg<OutOrder,InOrder,Basis> arg)
{
  const int stride = gridDim.x*blockDim.x;
  for (int cb_idx = blockIdx.x*blockDim.x + threadIdx.x; cb_idx < arg.length; cb_idx += stride) {
    copyInterior<FloatOut,FloatIn,Ns,Nc,OutOrder,InOrder,Basis,extend>(arg, cb_idx);
  }
}
// CPU fallback: serially visit every interior checkerboard site and copy it.
template<typename FloatOut, typename FloatIn, int Ns, int Nc, typename OutOrder, typename InOrder, typename Basis, bool extend>
void copyInterior(CopySpinorExArg<OutOrder,InOrder,Basis>& arg)
{
  int site = 0;
  while (site < arg.length) {
    copyInterior<FloatOut,FloatIn,Ns,Nc,OutOrder,InOrder,Basis,extend>(arg, site);
    ++site;
  }
}
// Tunable wrapper that runs the extended copy either as a serial CPU loop or
// as an autotuned GPU kernel launch.
template<typename FloatOut, typename FloatIn, int Ns, int Nc, typename OutOrder, typename InOrder, typename Basis, bool extend>
class CopySpinorEx : Tunable {

  CopySpinorExArg<OutOrder,InOrder,Basis> arg; // kernel arguments (accessors + geometry)
  const ColorSpinorField &meta;                // field used for metadata (volume string)
  QudaFieldLocation location;                  // where to run: CPU or CUDA

  unsigned int sharedBytesPerThread() const { return 0; }
  // BUG FIX: "&param" in the two signatures below had been corrupted to the
  // mis-encoded token "¶m" (an HTML-entity mangling of "&param"), which
  // does not compile.
  unsigned int sharedBytesPerBlock(const TuneParam &param) const { return 0; }
  bool advanceSharedBytes(TuneParam &param) const { return false; } // Don't tune shared mem
  bool tuneGridDim() const { return false; } // Don't tune the grid dimensions.
  unsigned int minThreads() const { return arg.length; }

public:
  CopySpinorEx(CopySpinorExArg<OutOrder,InOrder,Basis> &arg, const ColorSpinorField &meta, QudaFieldLocation location)
    : arg(arg), meta(meta), location(location) {
    writeAuxString("out_stride=%d,in_stride=%d",arg.out.stride,arg.in.stride);
  }

  // Launch the copy: serial loop on the CPU, tuned kernel launch on the GPU
  void apply(const qudaStream_t &stream){
    TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());

    if (location == QUDA_CPU_FIELD_LOCATION) {
      copyInterior<FloatOut,FloatIn,Ns,Nc,OutOrder,InOrder,Basis,extend>(arg);
    } else if (location == QUDA_CUDA_FIELD_LOCATION) {
      qudaLaunchKernel(copyInteriorKernel<FloatOut,FloatIn,Ns,Nc,OutOrder,InOrder,Basis,extend>, tp, stream, arg);
    }
  }

  TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), aux); }
  long long flops() const { return 0; }
  long long bytes() const { return arg.length*2*Nc*Ns*(sizeof(FloatIn) + sizeof(FloatOut)); }
}; // CopySpinorEx
// Build the kernel argument struct and run the copy, instantiating the
// implementation with extend=true (inject interior field into the extended
// field) or extend=false (extract the interior from the extended field).
template<typename FloatOut, typename FloatIn, int Ns, int Nc, typename OutOrder, typename InOrder, typename Basis>
void copySpinorEx(OutOrder outOrder, const InOrder inOrder, const Basis basis, const int *E,
    const int *X, const int parity, const bool extend, const ColorSpinorField &meta, QudaFieldLocation location)
  CopySpinorExArg<OutOrder,InOrder,Basis> arg(outOrder, inOrder, basis, E, X, parity);
  // extend == true instantiation
  CopySpinorEx<FloatOut, FloatIn, Ns, Nc, OutOrder, InOrder, Basis, true> copier(arg, meta, location);
  // extend == false instantiation.
  // NOTE(review): the "if(extend){...}else{...}" branch lines and the
  // copier.apply(...) calls are not visible in this excerpt — confirm upstream.
  CopySpinorEx<FloatOut, FloatIn, Ns, Nc, OutOrder, InOrder, Basis, false> copier(arg, meta, location);
// Select the gamma-basis transformation functor from the (in, out) basis pair
// and forward to the typed implementation above. Only identity, DeGrand-Rossi
// -> UKQCD, and UKQCD -> DeGrand-Rossi conversions are supported, and basis
// changes require four spin components.
template<typename FloatOut, typename FloatIn, int Ns, int Nc, typename OutOrder, typename InOrder>
void copySpinorEx(OutOrder outOrder, InOrder inOrder, const QudaGammaBasis outBasis, const QudaGammaBasis inBasis,
    const int* E, const int* X, const int parity, const bool extend,
    const ColorSpinorField &meta, QudaFieldLocation location)
{
  if (inBasis == outBasis) {
    // no basis change: element-wise copy (with type conversion)
    PreserveBasis<FloatOut,FloatIn,Ns,Nc> basis;
    copySpinorEx<FloatOut, FloatIn, Ns, Nc, OutOrder, InOrder, PreserveBasis<FloatOut,FloatIn,Ns,Nc> >
      (outOrder, inOrder, basis, E, X, parity, extend, meta, location);
  } else if (outBasis == QUDA_UKQCD_GAMMA_BASIS && inBasis == QUDA_DEGRAND_ROSSI_GAMMA_BASIS) {
    // relativistic (DeGrand-Rossi) -> non-relativistic (UKQCD)
    if (Ns != 4) errorQuda("Can only change basis with Nspin = 4, not Nspin = %d", Ns);
    NonRelBasis<FloatOut,FloatIn,4,Nc> basis;
    copySpinorEx<FloatOut, FloatIn, 4, Nc, OutOrder, InOrder, NonRelBasis<FloatOut,FloatIn,4,Nc> >
      (outOrder, inOrder, basis, E, X, parity, extend, meta, location);
  } else if (inBasis == QUDA_UKQCD_GAMMA_BASIS && outBasis == QUDA_DEGRAND_ROSSI_GAMMA_BASIS) {
    // non-relativistic (UKQCD) -> relativistic (DeGrand-Rossi)
    if (Ns != 4) errorQuda("Can only change basis with Nspin = 4, not Nspin = %d", Ns);
    RelBasis<FloatOut,FloatIn,4,Nc> basis;
    copySpinorEx<FloatOut, FloatIn, 4, Nc, OutOrder, InOrder, RelBasis<FloatOut,FloatIn,4,Nc> >
      (outOrder, inOrder, basis, E, X, parity, extend, meta, location);
  } else {
    errorQuda("Basis change not supported");
  }
}
// Need to rewrite the following two functions...
// Decide on the output order
// Wraps the destination field in its native accessor and forwards to the
// basis-dispatch routine; only native field orders are supported.
template<typename FloatOut, typename FloatIn, int Ns, int Nc, typename InOrder>
void extendedCopyColorSpinor(InOrder &inOrder, ColorSpinorField &out,
    QudaGammaBasis inBasis, const int *E, const int *X, const int parity, const bool extend,
    QudaFieldLocation location, FloatOut *Out, float *outNorm){

  if (out.isNative()) {
    // native (optimized) accessor for the output precision/spin/color
    typedef typename colorspinor_mapper<FloatOut,Ns,Nc>::type ColorSpinor;
    ColorSpinor outOrder(out, 1, Out, outNorm);
    copySpinorEx<FloatOut,FloatIn,Ns,Nc>
      (outOrder, inOrder, out.GammaBasis(), inBasis, E, X, parity, extend, out, location);
  // non-native output orders are rejected
    errorQuda("Order not defined");
// Decide on the input order: compute the interior (X) and extended (E)
// geometries, wrap the source field in its native accessor, then dispatch on
// the output order above.
template<typename FloatOut, typename FloatIn, int Ns, int Nc>
void extendedCopyColorSpinor(ColorSpinorField &out, const ColorSpinorField &in,
    const int parity, const QudaFieldLocation location, FloatOut *Out, FloatIn *In,
    float* outNorm, float *inNorm){

  // extend == true: out is the larger (extended) field being filled from in
  const bool extend = (out.Volume() >= in.Volume());
  // NOTE(review): the loop bodies that fill E[] and X[] from the two fields'
  // dimensions are not visible in this excerpt — confirm upstream.
  for (int d=0; d<4; d++) {
  for (int d=0; d<4; d++) {
  X[0] *= 2; E[0] *= 2; // Since we consider only a single parity at a time

  // native accessor for the input precision/spin/color
  typedef typename colorspinor_mapper<FloatIn,Ns,Nc>::type ColorSpinor;
  ColorSpinor inOrder(in, 1, In, inNorm);
  extendedCopyColorSpinor<FloatOut,FloatIn,Ns,Nc>(inOrder, out, in.GammaBasis(), E, X, parity, extend, location, Out, outNorm);
  // non-native input orders are rejected
  errorQuda("Order not defined");
// Validate that dst and src are compatible (dimensions, site order/subset,
// Nc==3), then copy. Full-subset fields are split into their even and odd
// halves (adjusting for even-odd vs odd-even site ordering) and each parity
// is copied separately; parity-subset fields are copied directly.
template<int Ns, typename dstFloat, typename srcFloat>
void copyExtendedColorSpinor(ColorSpinorField &dst, const ColorSpinorField &src,
    const int parity, const QudaFieldLocation location, dstFloat *Dst, srcFloat *Src,
    float *dstNorm, float *srcNorm) {

  if(dst.Ndim() != src.Ndim())
    errorQuda("Number of dimensions %d %d don't match", dst.Ndim(), src.Ndim());

  // the site orders must match, or be the even-odd/odd-even pair (handled by
  // swapping the half-field pointers below)
  if(!(dst.SiteOrder() == src.SiteOrder() ||
      (dst.SiteOrder() == QUDA_EVEN_ODD_SITE_ORDER &&
      src.SiteOrder() == QUDA_ODD_EVEN_SITE_ORDER) ||
      (dst.SiteOrder() == QUDA_ODD_EVEN_SITE_ORDER &&
      src.SiteOrder() == QUDA_EVEN_ODD_SITE_ORDER) ) ){

    errorQuda("Subset orders %d %d don't match", dst.SiteOrder(), src.SiteOrder());

  if(dst.SiteSubset() != src.SiteSubset())
    errorQuda("Subset types do not match %d %d", dst.SiteSubset(), src.SiteSubset());

  if(dst.Ncolor() != 3 || src.Ncolor() != 3) errorQuda("Nc != 3 not yet supported");

  // We currently only support parity-ordered fields; even-odd or odd-even
  if(dst.SiteOrder() == QUDA_LEXICOGRAPHIC_SITE_ORDER){
    errorQuda("Copying to full fields with lexicographical ordering is not currently supported");

  if(dst.SiteSubset() == QUDA_FULL_SITE_SUBSET){
    if(src.FieldOrder() == QUDA_QDPJIT_FIELD_ORDER ||
        dst.FieldOrder() == QUDA_QDPJIT_FIELD_ORDER){
      errorQuda("QDPJIT field ordering not supported for full site fields");

    // set for the source subset ordering
    srcFloat *srcEven = Src ? Src : (srcFloat*)src.V();
    srcFloat* srcOdd = (srcFloat*)((char*)srcEven + src.Bytes()/2);
    float *srcNormEven = srcNorm ? srcNorm : (float*)src.Norm();
    float *srcNormOdd = (float*)((char*)srcNormEven + src.NormBytes()/2);
    if(src.SiteOrder() == QUDA_ODD_EVEN_SITE_ORDER){
      std::swap<srcFloat*>(srcEven, srcOdd);
      std::swap<float*>(srcNormEven, srcNormOdd);

    // set for the destination subset ordering
    dstFloat *dstEven = Dst ? Dst : (dstFloat*)dst.V();
    dstFloat *dstOdd = (dstFloat*)((char*)dstEven + dst.Bytes()/2);
    float *dstNormEven = dstNorm ? dstNorm : (float*)dst.Norm();
    float *dstNormOdd = (float*)((char*)dstNormEven + dst.NormBytes()/2);
    if(dst.SiteOrder() == QUDA_ODD_EVEN_SITE_ORDER){
      std::swap<dstFloat*>(dstEven, dstOdd);
      std::swap<float*>(dstNormEven, dstNormOdd);

    // should be able to apply to select either even or odd parity at this point as well.
    // NOTE(review): Nc is not a template parameter here; presumably defined as
    // a constant (= 3) in lines not visible in this excerpt — confirm upstream.
    extendedCopyColorSpinor<dstFloat, srcFloat, Ns, Nc>
      (dst, src, 0, location, dstEven, srcEven, dstNormEven, srcNormEven);
    extendedCopyColorSpinor<dstFloat, srcFloat, Ns, Nc>
      (dst, src, 1, location, dstOdd, srcOdd, dstNormOdd, srcNormOdd);
    // parity-subset fields: copy the requested parity only
    extendedCopyColorSpinor<dstFloat, srcFloat, Ns, Nc>
      (dst, src, parity, location, Dst, Src, dstNorm, srcNorm);
  } // N.B. Need to update this to account for differences in parity
// Dispatch on the number of spin components (4 for Wilson-type fields, 1 for
// staggered) after checking the two fields agree.
template<typename dstFloat, typename srcFloat>
void CopyExtendedColorSpinor(ColorSpinorField &dst, const ColorSpinorField &src,
    const int parity, const QudaFieldLocation location, dstFloat *Dst, srcFloat *Src,
    float *dstNorm=0, float *srcNorm=0)
  if(dst.Nspin() != src.Nspin())
    errorQuda("source and destination spins must match");

  if(dst.Nspin() == 4){
    copyExtendedColorSpinor<4>(dst, src, parity, location, Dst, Src, dstNorm, srcNorm);
    // NOTE(review): the adjacent errorQuda appears to sit in the #else of a
    // build-option guard (lines not visible in this excerpt) — confirm upstream.
    errorQuda("Extended copy has not been built for Nspin=%d fields",dst.Nspin());
  }else if(dst.Nspin() == 1){
    copyExtendedColorSpinor<1>(dst, src, parity, location, Dst, Src, dstNorm, srcNorm);
    errorQuda("Extended copy has not been built for Nspin=%d fields", dst.Nspin());
    // any other spin count is unsupported
    errorQuda("Nspin=%d unsupported", dst.Nspin());
// There's probably no need to have the additional Dst and Src arguments here!
// Public entry point: dispatch on the (dst, src) precision pair, casting the
// raw data pointers to the concrete floating-point types. The norm pointers
// are only forwarded for the fixed-point (half/quarter) precisions, which
// store a separate per-site norm array.
void copyExtendedColorSpinor(ColorSpinorField &dst, const ColorSpinorField &src,
    QudaFieldLocation location, const int parity, void *Dst, void *Src,
    void *dstNorm, void *srcNorm){

  if(dst.Precision() == QUDA_DOUBLE_PRECISION){
    if(src.Precision() == QUDA_DOUBLE_PRECISION){
      CopyExtendedColorSpinor(dst, src, parity, location, static_cast<double*>(Dst), static_cast<double*>(Src));
    }else if(src.Precision() == QUDA_SINGLE_PRECISION){
      CopyExtendedColorSpinor(dst, src, parity, location, static_cast<double*>(Dst), static_cast<float*>(Src));
    }else if(src.Precision() == QUDA_HALF_PRECISION){
      // half-precision source carries its own norm array
      CopyExtendedColorSpinor(dst, src, parity, location, static_cast<double*>(Dst), static_cast<short*>(Src), 0, static_cast<float*>(srcNorm));
      errorQuda("Unsupported Precision %d", src.Precision());
  } else if (dst.Precision() == QUDA_SINGLE_PRECISION){
    if(src.Precision() == QUDA_DOUBLE_PRECISION){
      CopyExtendedColorSpinor(dst, src, parity, location, static_cast<float*>(Dst), static_cast<double*>(Src));
    }else if(src.Precision() == QUDA_SINGLE_PRECISION){
      CopyExtendedColorSpinor(dst, src, parity, location, static_cast<float*>(Dst), static_cast<float*>(Src));
    }else if(src.Precision() == QUDA_HALF_PRECISION){
      CopyExtendedColorSpinor(dst, src, parity, location, static_cast<float*>(Dst), static_cast<short*>(Src), 0, static_cast<float*>(srcNorm));
      errorQuda("Unsupported Precision %d", src.Precision());
  } else if (dst.Precision() == QUDA_HALF_PRECISION){
    // half-precision destination carries its own norm array
    if(src.Precision() == QUDA_DOUBLE_PRECISION){
      CopyExtendedColorSpinor(dst, src, parity, location, static_cast<short*>(Dst), static_cast<double*>(Src), static_cast<float*>(dstNorm), 0);
    }else if(src.Precision() == QUDA_SINGLE_PRECISION){
      CopyExtendedColorSpinor(dst, src, parity, location, static_cast<short*>(Dst), static_cast<float*>(Src), static_cast<float*>(dstNorm), 0);
    }else if(src.Precision() == QUDA_HALF_PRECISION){
      CopyExtendedColorSpinor(dst, src, parity, location, static_cast<short*>(Dst), static_cast<short*>(Src), static_cast<float*>(dstNorm), static_cast<float*>(srcNorm));
      errorQuda("Unsupported Precision %d", src.Precision());
  } else if (dst.Precision() == QUDA_QUARTER_PRECISION){
    // quarter precision stores data as 8-bit integers plus a norm array
    if(src.Precision() == QUDA_DOUBLE_PRECISION){
      CopyExtendedColorSpinor(dst, src, parity, location, static_cast<char*>(Dst), static_cast<double*>(Src), static_cast<float*>(dstNorm), 0);
    }else if(src.Precision() == QUDA_SINGLE_PRECISION){
      CopyExtendedColorSpinor(dst, src, parity, location, static_cast<char*>(Dst), static_cast<float*>(Src), static_cast<float*>(dstNorm), 0);
    }else if(src.Precision() == QUDA_HALF_PRECISION){
      CopyExtendedColorSpinor(dst, src, parity, location, static_cast<char*>(Dst), static_cast<short*>(Src), static_cast<float*>(dstNorm), static_cast<float*>(srcNorm));
      errorQuda("Unsupported Precision %d", src.Precision());
    errorQuda("Unsupported Precision %d", dst.Precision());
  // NOTE(review): the final errorQuda presumably sits in the #else of a
  // build-option guard around this whole routine (lines not visible here).
  errorQuda("Disabled");