// QUDA v1.1.0 -- A library for QCD on GPUs
// copy_gauge_extended.cu: copy between regular and extended gauge fields
1 #include <tune_quda.h>
2 #include <gauge_field_order.h>
3 #include <quda_matrix.h>
4 
5 namespace quda {
6 
7  using namespace gauge;
8 
9  /**
10  Kernel argument struct
11  */
12  template <typename OutOrder, typename InOrder>
13  struct CopyGaugeExArg {
14  OutOrder out;
15  const InOrder in;
16  int Xin[QUDA_MAX_DIM];
17  int Xout[QUDA_MAX_DIM];
18  int volume;
19  int volumeEx;
20  int nDim;
21  int geometry;
22  int faceVolumeCB[QUDA_MAX_DIM];
23  bool regularToextended;
24  CopyGaugeExArg(const OutOrder &out, const InOrder &in, const int *Xout, const int *Xin,
25  const int *faceVolumeCB, int nDim, int geometry)
26  : out(out), in(in), nDim(nDim), geometry(geometry) {
27  for (int d=0; d<nDim; d++) {
28  this->Xout[d] = Xout[d];
29  this->Xin[d] = Xin[d];
30  this->faceVolumeCB[d] = faceVolumeCB[d];
31  }
32 
33  if (out.volumeCB > in.volumeCB) {
34  this->volume = 2*in.volumeCB;
35  this->volumeEx = 2*out.volumeCB;
36  this->regularToextended = true;
37  } else {
38  this->volume = 2*out.volumeCB;
39  this->volumeEx = 2*in.volumeCB;
40  this->regularToextended = false;
41  }
42  }
43 
44  };
45 
  /**
     Copy a regular/extended gauge field into an extended/regular gauge field.

     The template parameter regularToextended selects the direction of
     the copy.  X is always the checkerboard site index into the
     regular (smaller) field; the matching site in the extended field
     is found by offsetting each coordinate by the border width R[d].

     @param arg     Argument struct holding both field accessors and the geometry
     @param X       Checkerboard site index into the regular field
     @param parity  Site parity (0 or 1)
   */
  template <typename FloatOut, typename FloatIn, int length, typename OutOrder, typename InOrder, bool regularToextended>
  __device__ __host__ void copyGaugeEx(CopyGaugeExArg<OutOrder,InOrder> &arg, int X, int parity) {
    typedef typename mapper<FloatIn>::type RegTypeIn;
    typedef typename mapper<FloatOut>::type RegTypeOut;
    constexpr int nColor = Ncolor(length);

    int x[4];       // full 4-d lattice coordinate of the site in the regular field
    int R[4];       // border width per dimension (half the extent difference)
    int xin, xout;  // checkerboard site indices into the input/output fields
    if(regularToextended){
      //regular to extended
      for (int d=0; d<4; d++) R[d] = (arg.Xout[d] - arg.Xin[d]) >> 1;
      // reconstruct the 4-d coordinate x[0..3] from the checkerboard index X
      int za = X/(arg.Xin[0]/2);
      int x0h = X - za*(arg.Xin[0]/2);
      int zb = za/arg.Xin[1];
      x[1] = za - zb*arg.Xin[1];
      x[3] = zb / arg.Xin[2];
      x[2] = zb - x[3]*arg.Xin[2];
      // restore the full x[0] from its half extent using the site parity
      x[0] = 2*x0h + ((x[1] + x[2] + x[3] + parity) & 1);
      // xout is the cb spatial index into the extended gauge field
      xout = ((((x[3]+R[3])*arg.Xout[2] + (x[2]+R[2]))*arg.Xout[1] + (x[1]+R[1]))*arg.Xout[0]+(x[0]+R[0])) >> 1;
      xin = X;
    } else{
      //extended to regular gauge
      for (int d=0; d<4; d++) R[d] = (arg.Xin[d] - arg.Xout[d]) >> 1;
      // reconstruct the 4-d coordinate x[0..3] from the checkerboard index X
      int za = X/(arg.Xout[0]/2);
      int x0h = X - za*(arg.Xout[0]/2);
      int zb = za/arg.Xout[1];
      x[1] = za - zb*arg.Xout[1];
      x[3] = zb / arg.Xout[2];
      x[2] = zb - x[3]*arg.Xout[2];
      x[0] = 2*x0h + ((x[1] + x[2] + x[3] + parity) & 1);
      // xin is the cb spatial index into the extended gauge field
      xin = ((((x[3]+R[3])*arg.Xin[2] + (x[2]+R[2]))*arg.Xin[1] + (x[1]+R[1]))*arg.Xin[0]+(x[0]+R[0])) >> 1;
      xout = X;
    }
    // copy the link matrix for every direction of the field's geometry
    for (int d=0; d<arg.geometry; d++) {
      const Matrix<complex<RegTypeIn>,nColor> in = arg.in(d, xin, parity);
      Matrix<complex<RegTypeOut>,nColor> out = in;
      arg.out(d, xout, parity) = out;
    }//dir
  }
91 
92  template <typename FloatOut, typename FloatIn, int length, typename OutOrder, typename InOrder, bool regularToextended>
93  void copyGaugeEx(CopyGaugeExArg<OutOrder,InOrder> arg) {
94  for (int parity=0; parity<2; parity++) {
95  for(int X=0; X<arg.volume/2; X++){
96  copyGaugeEx<FloatOut, FloatIn, length, OutOrder, InOrder, regularToextended>(arg, X, parity);
97  }
98  }
99  }
100 
101  template <typename FloatOut, typename FloatIn, int length, typename OutOrder, typename InOrder, bool regularToextended>
102  __global__ void copyGaugeExKernel(CopyGaugeExArg<OutOrder,InOrder> arg) {
103  for (int parity=0; parity<2; parity++) {
104  int X = blockIdx.x * blockDim.x + threadIdx.x;
105  if (X >= arg.volume/2) return;
106  copyGaugeEx<FloatOut, FloatIn, length, OutOrder, InOrder, regularToextended>(arg, X, parity);
107  }
108  }
109 
  /**
     Tunable driver for the extended-gauge copy: dispatches to the CPU
     loop or launches the CUDA kernel, selecting the template
     instantiation that matches the direction of the copy.
  */
  template <typename FloatOut, typename FloatIn, int length, typename OutOrder, typename InOrder>
  class CopyGaugeEx : Tunable {
    CopyGaugeExArg<OutOrder,InOrder> arg;
    const GaugeField &meta; // use for metadata
    QudaFieldLocation location; // where the copy runs (CPU or CUDA)

  private:
    unsigned int sharedBytesPerThread() const { return 0; }
    unsigned int sharedBytesPerBlock(const TuneParam &param) const { return 0 ;}

    bool tuneGridDim() const { return false; } // Don't tune the grid dimensions.
    // one thread per checkerboard site of the regular field
    unsigned int minThreads() const { return arg.volume/2; }

  public:
    CopyGaugeEx(CopyGaugeExArg<OutOrder,InOrder> &arg, const GaugeField &meta, QudaFieldLocation location)
      : arg(arg), meta(meta), location(location) {
      // strides and geometry distinguish tuning entries for this kernel
      writeAuxString("out_stride=%d,in_stride=%d,geometry=%d",arg.out.stride,arg.in.stride,arg.geometry);
    }
    virtual ~CopyGaugeEx() { ; }

    // perform the copy on the given stream (tuned launch for the CUDA path)
    void apply(const qudaStream_t &stream) {
      TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());

      if (location == QUDA_CPU_FIELD_LOCATION) {
        if (arg.regularToextended) copyGaugeEx<FloatOut, FloatIn, length, OutOrder, InOrder, true>(arg);
        else copyGaugeEx<FloatOut, FloatIn, length, OutOrder, InOrder, false>(arg);
      } else if (location == QUDA_CUDA_FIELD_LOCATION) {
        if (arg.regularToextended)
          qudaLaunchKernel(copyGaugeExKernel<FloatOut, FloatIn, length, OutOrder, InOrder, true>, tp, stream, arg);
        else
          qudaLaunchKernel(copyGaugeExKernel<FloatOut, FloatIn, length, OutOrder, InOrder, false>, tp, stream, arg);
      }
    }

    TuneKey tuneKey() const {
      return TuneKey(meta.VolString(), typeid(*this).name(), aux);
    }

    long long flops() const { return 0; } // pure data movement, no arithmetic
    // bytes moved: read + write for all links (4 per site) of the regular volume
    long long bytes() const {
      int sites = 4*arg.volume/2;
      return 2 * sites * ( arg.in.Bytes() + arg.in.hasPhase*sizeof(FloatIn)
                           + arg.out.Bytes() + arg.out.hasPhase*sizeof(FloatOut) );
    }
  };
155 
156 
157  template <typename FloatOut, typename FloatIn, int length, typename OutOrder, typename InOrder>
158  void copyGaugeEx(OutOrder outOrder, const InOrder inOrder, const int *E,
159  const int *X, const int *faceVolumeCB, const GaugeField &meta, QudaFieldLocation location) {
160 
161  CopyGaugeExArg<OutOrder,InOrder>
162  arg(outOrder, inOrder, E, X, faceVolumeCB, meta.Ndim(), meta.Geometry());
163  CopyGaugeEx<FloatOut, FloatIn, length, OutOrder, InOrder> copier(arg, meta, location);
164  copier.apply(0);
165  }
166 
  /**
     Dispatch on the OUTPUT field's order/reconstruction, constructing the
     matching accessor and forwarding to the order-resolved copy above.

     @param inOrder   Already-resolved accessor for the input field
     @param X         Extents of the input field
     @param out       Output gauge field (may be the extended one)
     @param location  Where the copy runs
     @param Out       Backing data pointer forwarded to the output accessor
  */
  template <typename FloatOut, typename FloatIn, int length, typename InOrder>
  void copyGaugeEx(const InOrder &inOrder, const int *X, GaugeField &out,
                   QudaFieldLocation location, FloatOut *Out) {

    int faceVolumeCB[QUDA_MAX_DIM];
    for (int i=0; i<4; i++) faceVolumeCB[i] = out.SurfaceCB(i) * out.Nface();

    if (out.isNative()) {
      // native layouts: pick the accessor for the compiled-in reconstruction
      if (out.Reconstruct() == QUDA_RECONSTRUCT_NO) {
        typedef typename gauge_mapper<FloatOut, QUDA_RECONSTRUCT_NO>::type G;
        copyGaugeEx<FloatOut, FloatIn, length>(G(out, Out), inOrder, out.X(), X, faceVolumeCB, out, location);
      } else if (out.Reconstruct() == QUDA_RECONSTRUCT_12) {
#if QUDA_RECONSTRUCT & 2
        typedef typename gauge_mapper<FloatOut,QUDA_RECONSTRUCT_12>::type G;
        copyGaugeEx<FloatOut,FloatIn,length>
          (G(out, Out), inOrder, out.X(), X, faceVolumeCB, out, location);
#else
        errorQuda("QUDA_RECONSTRUCT=%d does not enable reconstruct-12", QUDA_RECONSTRUCT);
#endif
      } else if (out.Reconstruct() == QUDA_RECONSTRUCT_8) {
#if QUDA_RECONSTRUCT & 1
        typedef typename gauge_mapper<FloatOut,QUDA_RECONSTRUCT_8>::type G;
        copyGaugeEx<FloatOut,FloatIn,length>
          (G(out, Out), inOrder, out.X(), X, faceVolumeCB, out, location);
#else
        errorQuda("QUDA_RECONSTRUCT=%d does not enable reconstruct-8", QUDA_RECONSTRUCT);
#endif
#ifdef GPU_STAGGERED_DIRAC
      // reconstruct-13/9 variants only exist in staggered builds
      } else if (out.Reconstruct() == QUDA_RECONSTRUCT_13) {
#if QUDA_RECONSTRUCT & 2
        typedef typename gauge_mapper<FloatOut,QUDA_RECONSTRUCT_13>::type G;
        copyGaugeEx<FloatOut,FloatIn,length>
          (G(out, Out), inOrder, out.X(), X, faceVolumeCB, out, location);
#else
        errorQuda("QUDA_RECONSTRUCT=%d does not enable reconstruct-13", QUDA_RECONSTRUCT);
#endif
      } else if (out.Reconstruct() == QUDA_RECONSTRUCT_9) {
#if QUDA_RECONSTRUCT & 1
        typedef typename gauge_mapper<FloatOut,QUDA_RECONSTRUCT_9>::type G;
        copyGaugeEx<FloatOut,FloatIn,length>
          (G(out, Out), inOrder, out.X(), X, faceVolumeCB, out, location);
#else
        errorQuda("QUDA_RECONSTRUCT=%d does not enable reconstruct-9", QUDA_RECONSTRUCT);
#endif
#endif // GPU_STAGGERED_DIRAC
      } else {
        errorQuda("Reconstruction %d and order %d not supported", out.Reconstruct(), out.Order());
      }
    } else if (out.Order() == QUDA_QDP_GAUGE_ORDER) {

#ifdef BUILD_QDP_INTERFACE
      copyGaugeEx<FloatOut,FloatIn,length>
        (QDPOrder<FloatOut,length>(out, Out), inOrder, out.X(), X, faceVolumeCB, out, location);
#else
      errorQuda("QDP interface has not been built\n");
#endif

    } else if (out.Order() == QUDA_MILC_GAUGE_ORDER) {

#ifdef BUILD_MILC_INTERFACE
      copyGaugeEx<FloatOut,FloatIn,length>
        (MILCOrder<FloatOut,length>(out, Out), inOrder, out.X(), X, faceVolumeCB, out, location);
#else
      errorQuda("MILC interface has not been built\n");
#endif

    } else if (out.Order() == QUDA_TIFR_GAUGE_ORDER) {

#ifdef BUILD_TIFR_INTERFACE
      copyGaugeEx<FloatOut,FloatIn,length>
        (TIFROrder<FloatOut,length>(out, Out), inOrder, out.X(), X, faceVolumeCB, out, location);
#else
      errorQuda("TIFR interface has not been built\n");
#endif

    } else {
      errorQuda("Gauge field %d order not supported", out.Order());
    }

  }
247 
  /**
     Dispatch on the INPUT field's order/reconstruction, constructing the
     matching accessor and forwarding to the output-order dispatch above.

     @param out       Output gauge field
     @param in        Input gauge field
     @param location  Where the copy runs
     @param Out       Backing data pointer forwarded to the output accessor
     @param In        Backing data pointer forwarded to the input accessor
  */
  template <typename FloatOut, typename FloatIn, int length>
  void copyGaugeEx(GaugeField &out, const GaugeField &in, QudaFieldLocation location,
                   FloatOut *Out, FloatIn *In) {

    if (in.isNative()) {
      // native layouts: pick the accessor for the compiled-in reconstruction
      if (in.Reconstruct() == QUDA_RECONSTRUCT_NO) {
        typedef typename gauge_mapper<FloatIn, QUDA_RECONSTRUCT_NO>::type G;
        copyGaugeEx<FloatOut, FloatIn, length>(G(in, In), in.X(), out, location, Out);
      } else if (in.Reconstruct() == QUDA_RECONSTRUCT_12) {
#if QUDA_RECONSTRUCT & 2
        typedef typename gauge_mapper<FloatIn,QUDA_RECONSTRUCT_12>::type G;
        copyGaugeEx<FloatOut,FloatIn,length> (G(in, In), in.X(), out, location, Out);
#else
        errorQuda("QUDA_RECONSTRUCT=%d does not enable reconstruct-12", QUDA_RECONSTRUCT);
#endif
      } else if (in.Reconstruct() == QUDA_RECONSTRUCT_8) {
#if QUDA_RECONSTRUCT & 1
        typedef typename gauge_mapper<FloatIn,QUDA_RECONSTRUCT_8>::type G;
        copyGaugeEx<FloatOut,FloatIn,length> (G(in, In), in.X(), out, location, Out);
#else
        errorQuda("QUDA_RECONSTRUCT=%d does not enable reconstruct-8", QUDA_RECONSTRUCT);
#endif
#ifdef GPU_STAGGERED_DIRAC
      // reconstruct-13/9 variants only exist in staggered builds
      } else if (in.Reconstruct() == QUDA_RECONSTRUCT_13) {
#if QUDA_RECONSTRUCT & 2
        typedef typename gauge_mapper<FloatIn,QUDA_RECONSTRUCT_13>::type G;
        copyGaugeEx<FloatOut,FloatIn,length> (G(in, In), in.X(), out, location, Out);
#else
        errorQuda("QUDA_RECONSTRUCT=%d does not enable reconstruct-13", QUDA_RECONSTRUCT);
#endif
      } else if (in.Reconstruct() == QUDA_RECONSTRUCT_9) {
#if QUDA_RECONSTRUCT & 1
        typedef typename gauge_mapper<FloatIn,QUDA_RECONSTRUCT_9>::type G;
        copyGaugeEx<FloatOut,FloatIn,length> (G(in, In), in.X(), out, location, Out);
#else
        errorQuda("QUDA_RECONSTRUCT=%d does not enable reconstruct-9", QUDA_RECONSTRUCT);
#endif
#endif // GPU_STAGGERED_DIRAC
      } else {
        errorQuda("Reconstruction %d and order %d not supported", in.Reconstruct(), in.Order());
      }
    } else if (in.Order() == QUDA_QDP_GAUGE_ORDER) {

#ifdef BUILD_QDP_INTERFACE
      copyGaugeEx<FloatOut,FloatIn,length>(QDPOrder<FloatIn,length>(in, In),
                                           in.X(), out, location, Out);
#else
      errorQuda("QDP interface has not been built\n");
#endif

    } else if (in.Order() == QUDA_MILC_GAUGE_ORDER) {

#ifdef BUILD_MILC_INTERFACE
      copyGaugeEx<FloatOut,FloatIn,length>(MILCOrder<FloatIn,length>(in, In),
                                           in.X(), out, location, Out);
#else
      errorQuda("MILC interface has not been built\n");
#endif

    } else if (in.Order() == QUDA_TIFR_GAUGE_ORDER) {

#ifdef BUILD_TIFR_INTERFACE
      copyGaugeEx<FloatOut,FloatIn,length>(TIFROrder<FloatIn,length>(in, In),
                                           in.X(), out, location, Out);
#else
      errorQuda("TIFR interface has not been built\n");
#endif

    } else {
      errorQuda("Gauge field %d order not supported", in.Order());
    }

  }
321 
322  template <typename FloatOut, typename FloatIn>
323  void copyGaugeEx(GaugeField &out, const GaugeField &in, QudaFieldLocation location,
324  FloatOut *Out, FloatIn *In) {
325 
326  if (in.Ncolor() != 3 && out.Ncolor() != 3) {
327  errorQuda("Unsupported number of colors; out.Nc=%d, in.Nc=%d", out.Ncolor(), in.Ncolor());
328  }
329 
330  if (out.Geometry() != in.Geometry()) {
331  errorQuda("Field geometries %d %d do not match", out.Geometry(), in.Geometry());
332  }
333 
334  if (in.LinkType() != QUDA_ASQTAD_MOM_LINKS && out.LinkType() != QUDA_ASQTAD_MOM_LINKS) {
335  // we are doing gauge field packing
336  copyGaugeEx<FloatOut,FloatIn,18>(out, in, location, Out, In);
337  } else {
338  errorQuda("Not supported");
339  }
340  }
341 
342  void copyExtendedGauge(GaugeField &out, const GaugeField &in,
343  QudaFieldLocation location, void *Out, void *In) {
344 
345  for (int d=0; d<in.Ndim(); d++) {
346  if ( (out.X()[d] - in.X()[d]) % 2 != 0)
347  errorQuda("Cannot copy into an asymmetrically extended gauge field");
348  }
349 
350  if (out.Precision() == QUDA_DOUBLE_PRECISION) {
351  if (in.Precision() == QUDA_DOUBLE_PRECISION) {
352  copyGaugeEx(out, in, location, (double*)Out, (double*)In);
353  } else if (in.Precision() == QUDA_SINGLE_PRECISION) {
354 #if QUDA_PRECISION & 4
355  copyGaugeEx(out, in, location, (double*)Out, (float*)In);
356 #else
357  errorQuda("QUDA_PRECISION=%d does not enable single precision", QUDA_PRECISION);
358 #endif
359  } else {
360  errorQuda("Precision %d not instantiated", in.Precision());
361  }
362  } else if (out.Precision() == QUDA_SINGLE_PRECISION) {
363  if (in.Precision() == QUDA_DOUBLE_PRECISION) {
364  copyGaugeEx(out, in, location, (float *)Out, (double *)In);
365  } else if (in.Precision() == QUDA_SINGLE_PRECISION) {
366 #if QUDA_PRECISION & 4
367  copyGaugeEx(out, in, location, (float *)Out, (float *)In);
368 #else
369  errorQuda("QUDA_PRECISION=%d does not enable single precision", QUDA_PRECISION);
370 #endif
371  } else {
372  errorQuda("Precision %d not instantiated", in.Precision());
373  }
374  } else if (out.Precision() == QUDA_HALF_PRECISION) {
375  if (in.Precision() == QUDA_HALF_PRECISION) {
376 #if QUDA_PRECISION & 2
377  copyGaugeEx(out, in, location, (short *)Out, (short *)In);
378 #else
379  errorQuda("QUDA_PRECISION=%d does not enable single precision", QUDA_PRECISION);
380 #endif
381  } else {
382  errorQuda("Precision %d not instantiated", in.Precision());
383  }
384  } else if (out.Precision() == QUDA_QUARTER_PRECISION) {
385  if (in.Precision() == QUDA_QUARTER_PRECISION) {
386 #if QUDA_PRECISION & 1
387  copyGaugeEx(out, in, location, (int8_t *)Out, (int8_t *)In);
388 #else
389  errorQuda("QUDA_PRECISION=%d does not enable single precision", QUDA_PRECISION);
390 #endif
391  } else {
392  errorQuda("Precision %d not instantiated", in.Precision());
393  }
394  } else {
395  errorQuda("Precision %d not instantiated", out.Precision());
396  }
397  }
398 
399 } // namespace quda