2 #include <gauge_field_order.h>
3 #include <quda_matrix.h>
// Kernel argument struct for copying between regular and extended gauge fields.
template <typename OutOrder, typename InOrder>
struct CopyGaugeExArg {
  OutOrder out;      // output gauge-field accessor
  const InOrder in;  // input gauge-field accessor
  int volume;        // full volume of the *regular* (smaller) field = number of sites copied
  int volumeEx;      // full volume of the *extended* (larger) field
  int nDim;          // number of dimensions to copy metadata for
  int geometry;      // field geometry (number of links per site)
  int Xin[QUDA_MAX_DIM];           // local lattice dimensions of the input field
  int Xout[QUDA_MAX_DIM];          // local lattice dimensions of the output field
  int faceVolumeCB[QUDA_MAX_DIM];  // checkerboarded face volumes
  bool regularToextended;          // true if copying regular -> extended, false for the reverse

  /**
     @param out Output-field accessor
     @param in Input-field accessor
     @param Xout Output-field local dimensions
     @param Xin Input-field local dimensions
     @param faceVolumeCB Checkerboarded face volumes
     @param nDim Number of dimensions
     @param geometry Field geometry
  */
  CopyGaugeExArg(const OutOrder &out, const InOrder &in, const int *Xout, const int *Xin,
		 const int *faceVolumeCB, int nDim, int geometry)
    : out(out), in(in), nDim(nDim), geometry(geometry) {
    for (int d=0; d<nDim; d++) {
      this->Xout[d] = Xout[d];
      this->Xin[d] = Xin[d];
      this->faceVolumeCB[d] = faceVolumeCB[d];
    }

    // the copy direction is inferred from which field has the larger volume
    if (out.volumeCB > in.volumeCB) {
      this->volume = 2*in.volumeCB;
      this->volumeEx = 2*out.volumeCB;
      this->regularToextended = true;
    } else {
      this->volume = 2*out.volumeCB;
      this->volumeEx = 2*in.volumeCB;
      this->regularToextended = false;
    }
  }
};
// Copy a regular/extended gauge field into an extended/regular gauge field.
/**
   Copy the site at regular-field checkerboard index X (and given parity)
   between the regular and extended fields.  The regular field's site
   coordinates are reconstructed from X, offset by the border radius R,
   and re-linearized into the extended field's checkerboard index.
   @param arg Kernel argument struct
   @param X Checkerboard site index into the *regular* (smaller) field
   @param parity Site parity
*/
template <typename FloatOut, typename FloatIn, int length, typename OutOrder, typename InOrder, bool regularToextended>
__device__ __host__ void copyGaugeEx(CopyGaugeExArg<OutOrder,InOrder> &arg, int X, int parity) {
  typedef typename mapper<FloatIn>::type RegTypeIn;
  typedef typename mapper<FloatOut>::type RegTypeOut;
  constexpr int nColor = Ncolor(length);

  int x[4];   // full 4-d site coordinate in the regular field
  int R[4];   // border radius in each dimension (half the dimension difference)
  int xin, xout;

  if (regularToextended) {
    // regular to extended gauge field
    for (int d=0; d<4; d++) R[d] = (arg.Xout[d] - arg.Xin[d]) >> 1;
    // decompose the regular cb index X into the 4-d coordinate x[]
    int za = X/(arg.Xin[0]/2);
    int x0h = X - za*(arg.Xin[0]/2);
    int zb = za/arg.Xin[1];
    x[1] = za - zb*arg.Xin[1];
    x[3] = zb / arg.Xin[2];
    x[2] = zb - x[3]*arg.Xin[2];
    x[0] = 2*x0h + ((x[1] + x[2] + x[3] + parity) & 1);
    // xout is the cb spatial index into the extended gauge field
    xin = X;
    xout = ((((x[3]+R[3])*arg.Xout[2] + (x[2]+R[2]))*arg.Xout[1] + (x[1]+R[1]))*arg.Xout[0]+(x[0]+R[0])) >> 1;
  } else {
    // extended to regular gauge field
    for (int d=0; d<4; d++) R[d] = (arg.Xin[d] - arg.Xout[d]) >> 1;
    // decompose the regular cb index X into the 4-d coordinate x[]
    int za = X/(arg.Xout[0]/2);
    int x0h = X - za*(arg.Xout[0]/2);
    int zb = za/arg.Xout[1];
    x[1] = za - zb*arg.Xout[1];
    x[3] = zb / arg.Xout[2];
    x[2] = zb - x[3]*arg.Xout[2];
    x[0] = 2*x0h + ((x[1] + x[2] + x[3] + parity) & 1);
    // xin is the cb spatial index into the extended gauge field
    xout = X;
    xin = ((((x[3]+R[3])*arg.Xin[2] + (x[2]+R[2]))*arg.Xin[1] + (x[1]+R[1]))*arg.Xin[0]+(x[0]+R[0])) >> 1;
  }

  // copy all links at this site
  for (int d=0; d<arg.geometry; d++) {
    const Matrix<complex<RegTypeIn>,nColor> in = arg.in(d, xin, parity);
    Matrix<complex<RegTypeOut>,nColor> out = in;
    arg.out(d, xout, parity) = out;
  }
}
/**
   CPU implementation: serially visit both parities and every
   checkerboard site of the regular field.
   @param arg Kernel argument struct (passed by value; accessors are lightweight)
*/
template <typename FloatOut, typename FloatIn, int length, typename OutOrder, typename InOrder, bool regularToextended>
void copyGaugeEx(CopyGaugeExArg<OutOrder,InOrder> arg) {
  for (int parity=0; parity<2; parity++) {
    for (int X=0; X<arg.volume/2; X++) {
      copyGaugeEx<FloatOut, FloatIn, length, OutOrder, InOrder, regularToextended>(arg, X, parity);
    }
  }
}
/**
   GPU kernel: one thread per checkerboard site of the regular field;
   each thread copies both parities.  Expects a 1-d grid with at least
   arg.volume/2 threads (tail threads exit on the bounds guard).
   @param arg Kernel argument struct
*/
template <typename FloatOut, typename FloatIn, int length, typename OutOrder, typename InOrder, bool regularToextended>
__global__ void copyGaugeExKernel(CopyGaugeExArg<OutOrder,InOrder> arg) {
  // the index and guard are loop invariant, so compute them once up front
  int X = blockIdx.x * blockDim.x + threadIdx.x;
  if (X >= arg.volume/2) return;
  for (int parity=0; parity<2; parity++) {
    copyGaugeEx<FloatOut, FloatIn, length, OutOrder, InOrder, regularToextended>(arg, X, parity);
  }
}
/**
   Tunable wrapper that dispatches the extended-gauge copy to the CPU
   loop or the GPU kernel, selecting the copy direction at launch time.
*/
template <typename FloatOut, typename FloatIn, int length, typename OutOrder, typename InOrder>
class CopyGaugeEx : Tunable {
  CopyGaugeExArg<OutOrder,InOrder> arg;
  const GaugeField &meta; // use for metadata
  QudaFieldLocation location;

private:
  unsigned int sharedBytesPerThread() const { return 0; }
  // fix: was "const TuneParam ¶m" — mis-encoded "&param"
  unsigned int sharedBytesPerBlock(const TuneParam &param) const { return 0; }

  bool tuneGridDim() const { return false; } // Don't tune the grid dimensions.
  unsigned int minThreads() const { return arg.volume/2; }

public:
  CopyGaugeEx(CopyGaugeExArg<OutOrder,InOrder> &arg, const GaugeField &meta, QudaFieldLocation location)
    : arg(arg), meta(meta), location(location) {
    writeAuxString("out_stride=%d,in_stride=%d,geometry=%d",arg.out.stride,arg.in.stride,arg.geometry);
  }
  virtual ~CopyGaugeEx() { ; }

  /** Launch the copy on the given stream (GPU) or run it inline (CPU). */
  void apply(const qudaStream_t &stream) {
    TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());

    if (location == QUDA_CPU_FIELD_LOCATION) {
      if (arg.regularToextended) copyGaugeEx<FloatOut, FloatIn, length, OutOrder, InOrder, true>(arg);
      else copyGaugeEx<FloatOut, FloatIn, length, OutOrder, InOrder, false>(arg);
    } else if (location == QUDA_CUDA_FIELD_LOCATION) {
      if (arg.regularToextended)
        qudaLaunchKernel(copyGaugeExKernel<FloatOut, FloatIn, length, OutOrder, InOrder, true>, tp, stream, arg);
      else
        qudaLaunchKernel(copyGaugeExKernel<FloatOut, FloatIn, length, OutOrder, InOrder, false>, tp, stream, arg);
    }
  }

  TuneKey tuneKey() const {
    return TuneKey(meta.VolString(), typeid(*this).name(), aux);
  }

  long long flops() const { return 0; }
  long long bytes() const {
    // only the regular-field sites are read and written
    int sites = 4*arg.volume/2;
    return 2 * sites * ( arg.in.Bytes() + arg.in.hasPhase*sizeof(FloatIn)
			 + arg.out.Bytes() + arg.out.hasPhase*sizeof(FloatOut) );
  }
}; // class CopyGaugeEx
/**
   Build the argument struct and run the copy.  The visible excerpt
   constructed the copier but never invoked it; the apply() call is
   restored so the function actually performs the copy.
   @param outOrder Output-field accessor
   @param inOrder Input-field accessor
   @param E Output-field dimensions
   @param X Input-field dimensions
   @param faceVolumeCB Checkerboarded face volumes
   @param meta Field used for metadata (tuning key, geometry)
   @param location Where the copy runs (CPU or GPU)
*/
template <typename FloatOut, typename FloatIn, int length, typename OutOrder, typename InOrder>
void copyGaugeEx(OutOrder outOrder, const InOrder inOrder, const int *E,
		 const int *X, const int *faceVolumeCB, const GaugeField &meta, QudaFieldLocation location) {

  CopyGaugeExArg<OutOrder,InOrder>
    arg(outOrder, inOrder, E, X, faceVolumeCB, meta.Ndim(), meta.Geometry());
  CopyGaugeEx<FloatOut, FloatIn, length, OutOrder, InOrder> copier(arg, meta, location);
  copier.apply(0); // NOTE(review): default stream; confirm against project stream conventions
}
/**
   Dispatch on the *output* field's order/reconstruction, wrapping it in
   the matching accessor.  The excerpt was missing the #else/#endif
   closers and else branches of this chain; they are restored here.
   @param inOrder Input-field accessor (already resolved)
   @param X Input-field dimensions
   @param out Output gauge field
   @param location Where the copy runs
   @param Out Raw output data pointer
*/
template <typename FloatOut, typename FloatIn, int length, typename InOrder>
void copyGaugeEx(const InOrder &inOrder, const int *X, GaugeField &out,
		 QudaFieldLocation location, FloatOut *Out) {

  int faceVolumeCB[QUDA_MAX_DIM];
  for (int i=0; i<4; i++) faceVolumeCB[i] = out.SurfaceCB(i) * out.Nface();

  if (out.isNative()) {
    if (out.Reconstruct() == QUDA_RECONSTRUCT_NO) {
      typedef typename gauge_mapper<FloatOut, QUDA_RECONSTRUCT_NO>::type G;
      copyGaugeEx<FloatOut, FloatIn, length>(G(out, Out), inOrder, out.X(), X, faceVolumeCB, out, location);
    } else if (out.Reconstruct() == QUDA_RECONSTRUCT_12) {
#if QUDA_RECONSTRUCT & 2
      typedef typename gauge_mapper<FloatOut,QUDA_RECONSTRUCT_12>::type G;
      copyGaugeEx<FloatOut,FloatIn,length>
	(G(out, Out), inOrder, out.X(), X, faceVolumeCB, out, location);
#else
      errorQuda("QUDA_RECONSTRUCT=%d does not enable reconstruct-12", QUDA_RECONSTRUCT);
#endif
    } else if (out.Reconstruct() == QUDA_RECONSTRUCT_8) {
#if QUDA_RECONSTRUCT & 1
      typedef typename gauge_mapper<FloatOut,QUDA_RECONSTRUCT_8>::type G;
      copyGaugeEx<FloatOut,FloatIn,length>
	(G(out, Out), inOrder, out.X(), X, faceVolumeCB, out, location);
#else
      errorQuda("QUDA_RECONSTRUCT=%d does not enable reconstruct-8", QUDA_RECONSTRUCT);
#endif
#ifdef GPU_STAGGERED_DIRAC
    } else if (out.Reconstruct() == QUDA_RECONSTRUCT_13) {
#if QUDA_RECONSTRUCT & 2
      typedef typename gauge_mapper<FloatOut,QUDA_RECONSTRUCT_13>::type G;
      copyGaugeEx<FloatOut,FloatIn,length>
	(G(out, Out), inOrder, out.X(), X, faceVolumeCB, out, location);
#else
      errorQuda("QUDA_RECONSTRUCT=%d does not enable reconstruct-13", QUDA_RECONSTRUCT);
#endif
    } else if (out.Reconstruct() == QUDA_RECONSTRUCT_9) {
#if QUDA_RECONSTRUCT & 1
      typedef typename gauge_mapper<FloatOut,QUDA_RECONSTRUCT_9>::type G;
      copyGaugeEx<FloatOut,FloatIn,length>
	(G(out, Out), inOrder, out.X(), X, faceVolumeCB, out, location);
#else
      errorQuda("QUDA_RECONSTRUCT=%d does not enable reconstruct-9", QUDA_RECONSTRUCT);
#endif
#endif // GPU_STAGGERED_DIRAC
    } else {
      errorQuda("Reconstruction %d and order %d not supported", out.Reconstruct(), out.Order());
    }
  } else if (out.Order() == QUDA_QDP_GAUGE_ORDER) {

#ifdef BUILD_QDP_INTERFACE
    copyGaugeEx<FloatOut,FloatIn,length>
      (QDPOrder<FloatOut,length>(out, Out), inOrder, out.X(), X, faceVolumeCB, out, location);
#else
    errorQuda("QDP interface has not been built\n");
#endif

  } else if (out.Order() == QUDA_MILC_GAUGE_ORDER) {

#ifdef BUILD_MILC_INTERFACE
    copyGaugeEx<FloatOut,FloatIn,length>
      (MILCOrder<FloatOut,length>(out, Out), inOrder, out.X(), X, faceVolumeCB, out, location);
#else
    errorQuda("MILC interface has not been built\n");
#endif

  } else if (out.Order() == QUDA_TIFR_GAUGE_ORDER) {

#ifdef BUILD_TIFR_INTERFACE
    copyGaugeEx<FloatOut,FloatIn,length>
      (TIFROrder<FloatOut,length>(out, Out), inOrder, out.X(), X, faceVolumeCB, out, location);
#else
    errorQuda("TIFR interface has not been built\n");
#endif

  } else {
    errorQuda("Gauge field %d order not supported", out.Order());
  }
}
/**
   Dispatch on the *input* field's order/reconstruction, wrapping it in
   the matching accessor before resolving the output order.  The
   enclosing if (in.isNative()) and the #else/#endif closers were
   missing from the excerpt and are restored here.
   @param out Output gauge field
   @param in Input gauge field
   @param location Where the copy runs
   @param Out Raw output data pointer
   @param In Raw input data pointer
*/
template <typename FloatOut, typename FloatIn, int length>
void copyGaugeEx(GaugeField &out, const GaugeField &in, QudaFieldLocation location,
		 FloatOut *Out, FloatIn *In) {

  if (in.isNative()) {
    if (in.Reconstruct() == QUDA_RECONSTRUCT_NO) {
      typedef typename gauge_mapper<FloatIn, QUDA_RECONSTRUCT_NO>::type G;
      copyGaugeEx<FloatOut, FloatIn, length>(G(in, In), in.X(), out, location, Out);
    } else if (in.Reconstruct() == QUDA_RECONSTRUCT_12) {
#if QUDA_RECONSTRUCT & 2
      typedef typename gauge_mapper<FloatIn,QUDA_RECONSTRUCT_12>::type G;
      copyGaugeEx<FloatOut,FloatIn,length> (G(in, In), in.X(), out, location, Out);
#else
      errorQuda("QUDA_RECONSTRUCT=%d does not enable reconstruct-12", QUDA_RECONSTRUCT);
#endif
    } else if (in.Reconstruct() == QUDA_RECONSTRUCT_8) {
#if QUDA_RECONSTRUCT & 1
      typedef typename gauge_mapper<FloatIn,QUDA_RECONSTRUCT_8>::type G;
      copyGaugeEx<FloatOut,FloatIn,length> (G(in, In), in.X(), out, location, Out);
#else
      errorQuda("QUDA_RECONSTRUCT=%d does not enable reconstruct-8", QUDA_RECONSTRUCT);
#endif
#ifdef GPU_STAGGERED_DIRAC
    } else if (in.Reconstruct() == QUDA_RECONSTRUCT_13) {
#if QUDA_RECONSTRUCT & 2
      typedef typename gauge_mapper<FloatIn,QUDA_RECONSTRUCT_13>::type G;
      copyGaugeEx<FloatOut,FloatIn,length> (G(in, In), in.X(), out, location, Out);
#else
      errorQuda("QUDA_RECONSTRUCT=%d does not enable reconstruct-13", QUDA_RECONSTRUCT);
#endif
    } else if (in.Reconstruct() == QUDA_RECONSTRUCT_9) {
#if QUDA_RECONSTRUCT & 1
      typedef typename gauge_mapper<FloatIn,QUDA_RECONSTRUCT_9>::type G;
      copyGaugeEx<FloatOut,FloatIn,length> (G(in, In), in.X(), out, location, Out);
#else
      errorQuda("QUDA_RECONSTRUCT=%d does not enable reconstruct-9", QUDA_RECONSTRUCT);
#endif
#endif // GPU_STAGGERED_DIRAC
    } else {
      errorQuda("Reconstruction %d and order %d not supported", in.Reconstruct(), in.Order());
    }
  } else if (in.Order() == QUDA_QDP_GAUGE_ORDER) {

#ifdef BUILD_QDP_INTERFACE
    copyGaugeEx<FloatOut,FloatIn,length>(QDPOrder<FloatIn,length>(in, In),
					 in.X(), out, location, Out);
#else
    errorQuda("QDP interface has not been built\n");
#endif

  } else if (in.Order() == QUDA_MILC_GAUGE_ORDER) {

#ifdef BUILD_MILC_INTERFACE
    copyGaugeEx<FloatOut,FloatIn,length>(MILCOrder<FloatIn,length>(in, In),
					 in.X(), out, location, Out);
#else
    errorQuda("MILC interface has not been built\n");
#endif

  } else if (in.Order() == QUDA_TIFR_GAUGE_ORDER) {

#ifdef BUILD_TIFR_INTERFACE
    copyGaugeEx<FloatOut,FloatIn,length>(TIFROrder<FloatIn,length>(in, In),
					 in.X(), out, location, Out);
#else
    errorQuda("TIFR interface has not been built\n");
#endif

  } else {
    errorQuda("Gauge field %d order not supported", in.Order());
  }
}
/**
   Validate color count, geometry, and link type, then instantiate the
   copy for 18-element (3x3 complex) links.
   @param out Output gauge field
   @param in Input gauge field
   @param location Where the copy runs
   @param Out Raw output data pointer
   @param In Raw input data pointer
*/
template <typename FloatOut, typename FloatIn>
void copyGaugeEx(GaugeField &out, const GaugeField &in, QudaFieldLocation location,
		 FloatOut *Out, FloatIn *In) {

  // NOTE(review): "&&" only rejects when *both* fields are non-3-color;
  // looks like it may have been intended as "||" — confirm before changing
  if (in.Ncolor() != 3 && out.Ncolor() != 3) {
    errorQuda("Unsupported number of colors; out.Nc=%d, in.Nc=%d", out.Ncolor(), in.Ncolor());
  }

  if (out.Geometry() != in.Geometry()) {
    errorQuda("Field geometries %d %d do not match", out.Geometry(), in.Geometry());
  }

  if (in.LinkType() != QUDA_ASQTAD_MOM_LINKS && out.LinkType() != QUDA_ASQTAD_MOM_LINKS) {
    // we are doing gauge field packing
    copyGaugeEx<FloatOut,FloatIn,18>(out, in, location, Out, In);
  } else {
    errorQuda("Not supported");
  }
}
/**
   Public entry point: copy between a regular and an extended gauge
   field, dispatching on the precisions of both fields.  Fixes the
   copy-pasted error messages in the half/quarter branches, which
   previously reported "single precision".
   @param out Output gauge field
   @param in Input gauge field
   @param location Where the copy runs
   @param Out Raw output data pointer
   @param In Raw input data pointer
*/
void copyExtendedGauge(GaugeField &out, const GaugeField &in,
		       QudaFieldLocation location, void *Out, void *In) {

  // the border must be symmetric in every dimension
  for (int d=0; d<in.Ndim(); d++) {
    if ( (out.X()[d] - in.X()[d]) % 2 != 0)
      errorQuda("Cannot copy into an asymmetrically extended gauge field");
  }

  if (out.Precision() == QUDA_DOUBLE_PRECISION) {
    if (in.Precision() == QUDA_DOUBLE_PRECISION) {
      copyGaugeEx(out, in, location, (double*)Out, (double*)In);
    } else if (in.Precision() == QUDA_SINGLE_PRECISION) {
#if QUDA_PRECISION & 4
      copyGaugeEx(out, in, location, (double*)Out, (float*)In);
#else
      errorQuda("QUDA_PRECISION=%d does not enable single precision", QUDA_PRECISION);
#endif
    } else {
      errorQuda("Precision %d not instantiated", in.Precision());
    }
  } else if (out.Precision() == QUDA_SINGLE_PRECISION) {
    if (in.Precision() == QUDA_DOUBLE_PRECISION) {
      copyGaugeEx(out, in, location, (float *)Out, (double *)In);
    } else if (in.Precision() == QUDA_SINGLE_PRECISION) {
#if QUDA_PRECISION & 4
      copyGaugeEx(out, in, location, (float *)Out, (float *)In);
#else
      errorQuda("QUDA_PRECISION=%d does not enable single precision", QUDA_PRECISION);
#endif
    } else {
      errorQuda("Precision %d not instantiated", in.Precision());
    }
  } else if (out.Precision() == QUDA_HALF_PRECISION) {
    if (in.Precision() == QUDA_HALF_PRECISION) {
#if QUDA_PRECISION & 2
      copyGaugeEx(out, in, location, (short *)Out, (short *)In);
#else
      // fix: message previously said "single precision"
      errorQuda("QUDA_PRECISION=%d does not enable half precision", QUDA_PRECISION);
#endif
    } else {
      errorQuda("Precision %d not instantiated", in.Precision());
    }
  } else if (out.Precision() == QUDA_QUARTER_PRECISION) {
    if (in.Precision() == QUDA_QUARTER_PRECISION) {
#if QUDA_PRECISION & 1
      copyGaugeEx(out, in, location, (int8_t *)Out, (int8_t *)In);
#else
      // fix: message previously said "single precision"
      errorQuda("QUDA_PRECISION=%d does not enable quarter precision", QUDA_PRECISION);
#endif
    } else {
      errorQuda("Precision %d not instantiated", in.Precision());
    }
  } else {
    errorQuda("Precision %d not instantiated", out.Precision());
  }
}