QUDA  v1.1.0
A library for QCD on GPUs
register_traits.h
Go to the documentation of this file.
1 #ifndef _REGISTER_TRAITS_H
2 #define _REGISTER_TRAITS_H
3 
11 #include <quda_internal.h>
12 #include <generics/ldg.h>
13 #include <complex_quda.h>
14 #include <inline_ptx.h>
15 
16 namespace quda {
17 
/*
  PromoteTypeId<T, U>::type names the "greater" of two types, i.e. the type
  used for a computation that mixes a T with a U.

  The primary template falls back to T, which covers T == U as well as any
  pair that is not specialized below.  The specializations encode:
    real     with complex  -> the complex type
    integer  with floating -> the floating type
    double   with float    -> double
    short / int8_t with floating -> the floating type
    short    with int8_t   -> short
  Every specialization must define `type`; an empty body would make
  PromoteTypeId<...>::type ill-formed for that pair.
*/
template <class T, class U> struct PromoteTypeId {
  typedef T type;
};

// real scalar mixed with a complex number of the same precision -> complex
template <> struct PromoteTypeId<complex<float>, float> {
  typedef complex<float> type;
};
template <> struct PromoteTypeId<float, complex<float>> {
  typedef complex<float> type;
};
template <> struct PromoteTypeId<complex<double>, double> {
  typedef complex<double> type;
};
template <> struct PromoteTypeId<double, complex<double>> {
  typedef complex<double> type;
};

// int mixed with a floating type -> the floating type
template <> struct PromoteTypeId<double, int> {
  typedef double type;
};
template <> struct PromoteTypeId<int, double> {
  typedef double type;
};
template <> struct PromoteTypeId<float, int> {
  typedef float type;
};
template <> struct PromoteTypeId<int, float> {
  typedef float type;
};

// double mixed with float -> double
template <> struct PromoteTypeId<double, float> {
  typedef double type;
};
template <> struct PromoteTypeId<float, double> {
  typedef double type;
};

// fixed-point storage types mixed with double -> double
template <> struct PromoteTypeId<double, short> {
  typedef double type;
};
template <> struct PromoteTypeId<short, double> {
  typedef double type;
};
template <> struct PromoteTypeId<double, int8_t> {
  typedef double type;
};
template <> struct PromoteTypeId<int8_t, double> {
  typedef double type;
};

// fixed-point storage types mixed with float -> float
template <> struct PromoteTypeId<float, short> {
  typedef float type;
};
template <> struct PromoteTypeId<short, float> {
  typedef float type;
};
template <> struct PromoteTypeId<float, int8_t> {
  typedef float type;
};
template <> struct PromoteTypeId<int8_t, float> {
  typedef float type;
};

// short mixed with int8_t -> short (the wider fixed-point type)
template <> struct PromoteTypeId<short, int8_t> {
  typedef short type;
};
template <> struct PromoteTypeId<int8_t, short> {
  typedef short type;
};
84 
/*
  mapper<Storage>::type -- the register (compute) type that corresponds to a
  storage type:
    double -> double
    float  -> float
    short  -> float   (half precision storage)
    int8_t -> float   (quarter precision storage)
  The same rule is applied component-wise to the 2-, 4- and 8-wide vector
  types, so the register type can be derived from the storage template
  type.  Types that are not listed have no `type` member.
*/
template <typename> struct mapper {
};

// scalars
template <> struct mapper<double> { using type = double; };
template <> struct mapper<float> { using type = float; };
template <> struct mapper<short> { using type = float; };
template <> struct mapper<int8_t> { using type = float; };

// 2-component vectors
template <> struct mapper<double2> { using type = double2; };
template <> struct mapper<float2> { using type = float2; };
template <> struct mapper<short2> { using type = float2; };
template <> struct mapper<char2> { using type = float2; };

// 4-component vectors
template <> struct mapper<double4> { using type = double4; };
template <> struct mapper<float4> { using type = float4; };
template <> struct mapper<short4> { using type = float4; };
template <> struct mapper<char4> { using type = float4; };

// 8-component vectors
template <> struct mapper<double8> { using type = double8; };
template <> struct mapper<float8> { using type = float8; };
template <> struct mapper<short8> { using type = float8; };
template <> struct mapper<char8> { using type = float8; };
124 
/*
  bridge_mapper<A, B>::type -- the register type selected for a pair of
  vector types A and B.  Only the pairs listed below are defined; any other
  combination has no `type` member and fails to compile when used.
  NOTE(review): presumably A is the register-side type and B the storage
  type of the field being bridged -- confirm against the callers.
*/
template <typename, typename> struct bridge_mapper {
};

// A = double2
template <> struct bridge_mapper<double2, double2> { using type = double2; };
template <> struct bridge_mapper<double2, float2> { using type = double2; };
template <> struct bridge_mapper<double2, short2> { using type = float2; };
template <> struct bridge_mapper<double2, char2> { using type = float2; };
template <> struct bridge_mapper<double2, float4> { using type = double4; };
template <> struct bridge_mapper<double2, short4> { using type = float4; };
template <> struct bridge_mapper<double2, char4> { using type = float4; };
template <> struct bridge_mapper<double2, short8> { using type = double8; };
template <> struct bridge_mapper<double2, char8> { using type = double8; };

// A = float4
template <> struct bridge_mapper<float4, double2> { using type = float2; };
template <> struct bridge_mapper<float4, float4> { using type = float4; };
template <> struct bridge_mapper<float4, short4> { using type = float4; };
template <> struct bridge_mapper<float4, char4> { using type = float4; };
template <> struct bridge_mapper<float4, short8> { using type = float8; };
template <> struct bridge_mapper<float4, char8> { using type = float8; };

// A = float2
template <> struct bridge_mapper<float2, double2> { using type = float2; };
template <> struct bridge_mapper<float2, float2> { using type = float2; };
template <> struct bridge_mapper<float2, short2> { using type = float2; };
template <> struct bridge_mapper<float2, char2> { using type = float2; };

// A = float8
template <> struct bridge_mapper<float8, short8> { using type = float8; };
template <> struct bridge_mapper<float8, char8> { using type = float8; };
160 
/*
  vec_length<T>::value -- the number of scalar components in T.
  Scalars report 1, vector types report their width, complex types report 2,
  and any type not listed reports 0.
*/
template <typename> struct vec_length { static const int value = 0; };

// double precision
template <> struct vec_length<double> { static const int value = 1; };
template <> struct vec_length<double2> { static const int value = 2; };
template <> struct vec_length<double3> { static const int value = 3; };
template <> struct vec_length<double4> { static const int value = 4; };
template <> struct vec_length<double8> { static const int value = 8; };

// single precision
template <> struct vec_length<float> { static const int value = 1; };
template <> struct vec_length<float2> { static const int value = 2; };
template <> struct vec_length<float3> { static const int value = 3; };
template <> struct vec_length<float4> { static const int value = 4; };
template <> struct vec_length<float8> { static const int value = 8; };

// half precision (short storage)
template <> struct vec_length<short> { static const int value = 1; };
template <> struct vec_length<short2> { static const int value = 2; };
template <> struct vec_length<short3> { static const int value = 3; };
template <> struct vec_length<short4> { static const int value = 4; };
template <> struct vec_length<short8> { static const int value = 8; };

// quarter precision (int8_t / char storage)
template <> struct vec_length<int8_t> { static const int value = 1; };
template <> struct vec_length<char2> { static const int value = 2; };
template <> struct vec_length<char3> { static const int value = 3; };
template <> struct vec_length<char4> { static const int value = 4; };
template <> struct vec_length<char8> { static const int value = 8; };

// complex numbers: real + imaginary part
template <> struct vec_length<Complex> { static const int value = 2; };
template <> struct vec_length<complex<double>> { static const int value = 2; };
template <> struct vec_length<complex<float>> { static const int value = 2; };
template <> struct vec_length<complex<short>> { static const int value = 2; };
template <> struct vec_length<complex<int8_t>> { static const int value = 2; };
216 
/*
  vector<Float, N> -- wraps an N-component vector value in the member `a`
  and converts implicitly back to the wrapped vector type `type`.
  Only N == 2 is specialized; the primary template is an empty struct.
*/
template <typename, int N> struct vector {
};

template <> struct vector<double, 2> {
  typedef double2 type;
  double2 a; // storage read by the constructor and the conversion operator
  vector(const type &a) { this->a.x = a.x; this->a.y = a.y; }
  operator type() const { return a; }
};

// Constructed from a double2; the components are converted to float.
template <> struct vector<float, 2> {
  typedef float2 type;
  float2 a;
  vector(const double2 &a) { this->a.x = a.x; this->a.y = a.y; }
  operator type() const { return a; }
};

template <> struct vector<int, 2> {
  typedef int2 type;
  int2 a;
  vector(const int2 &a) { this->a.x = a.x; this->a.y = a.y; }
  operator type() const { return a; }
};
239 
/*
  scalar<T>::type -- the component type of a vector or complex type; for the
  plain scalar types it is the type itself.  The char-based vectors map to
  int8_t.  Types not listed have no `type` member.
*/
template <typename> struct scalar {
};

// double precision
template <> struct scalar<double> { using type = double; };
template <> struct scalar<double2> { using type = double; };
template <> struct scalar<double3> { using type = double; };
template <> struct scalar<double4> { using type = double; };
template <> struct scalar<double8> { using type = double; };

// single precision
template <> struct scalar<float> { using type = float; };
template <> struct scalar<float2> { using type = float; };
template <> struct scalar<float3> { using type = float; };
template <> struct scalar<float4> { using type = float; };
template <> struct scalar<float8> { using type = float; };

// half precision (short storage)
template <> struct scalar<short> { using type = short; };
template <> struct scalar<short2> { using type = short; };
template <> struct scalar<short3> { using type = short; };
template <> struct scalar<short4> { using type = short; };
template <> struct scalar<short8> { using type = short; };

// quarter precision (int8_t / char storage)
template <> struct scalar<int8_t> { using type = int8_t; };
template <> struct scalar<char2> { using type = int8_t; };
template <> struct scalar<char3> { using type = int8_t; };
template <> struct scalar<char4> { using type = int8_t; };
template <> struct scalar<char8> { using type = int8_t; };

// complex numbers map to their real component type
template <> struct scalar<complex<double>> { using type = double; };
template <> struct scalar<complex<float>> { using type = float; };

#ifdef QUAD_SUM
// double-double types, only present when QUAD_SUM is enabled
template <> struct scalar<doubledouble> { using type = doubledouble; };
template <> struct scalar<doubledouble2> { using type = doubledouble; };
template <> struct scalar<doubledouble3> { using type = doubledouble; };
template <> struct scalar<doubledouble4> { using type = doubledouble; };
template <> struct vector<doubledouble, 2> { using type = doubledouble2; };
#endif
326 
/*
  Precision-classification traits.  In each, `value` is false for every type
  except the explicit specializations listed.
*/

// isHalf<T>::value -- true for the half-precision (short-based) types.
template <typename T> struct isHalf { static const bool value = false; };
template <> struct isHalf<short> { static const bool value = true; };
template <> struct isHalf<short2> { static const bool value = true; };
template <> struct isHalf<short4> { static const bool value = true; };
template <> struct isHalf<short8> { static const bool value = true; };

// isQuarter<T>::value -- true for the quarter-precision (int8_t / char) types.
template <typename T> struct isQuarter { static const bool value = false; };
template <> struct isQuarter<int8_t> { static const bool value = true; };
template <> struct isQuarter<char2> { static const bool value = true; };
template <> struct isQuarter<char4> { static const bool value = true; };
template <> struct isQuarter<char8> { static const bool value = true; };

// isFixed<T>::value -- true for every fixed-point type, i.e. exactly the
// half- and quarter-precision types listed above.
template <typename T> struct isFixed { static const bool value = false; };
template <> struct isFixed<short> { static const bool value = true; };
template <> struct isFixed<short2> { static const bool value = true; };
template <> struct isFixed<short4> { static const bool value = true; };
template <> struct isFixed<short8> { static const bool value = true; };
template <> struct isFixed<int8_t> { static const bool value = true; };
template <> struct isFixed<char2> { static const bool value = true; };
template <> struct isFixed<char4> { static const bool value = true; };
template <> struct isFixed<char8> { static const bool value = true; };
363 
/*
  Trig<isFixed, T> -- thin wrappers over the trigonometric functions.

  Generic version: forwards directly to the overloaded atan2 / sin / cos /
  sincos for T.  The isFixed flag is not used here.
  NOTE(review): only Trig<true, float> applies the "angle in units of pi"
  convention; Trig<true, double> falls back to this unscaled version.
*/
template <bool isFixed, typename T> struct Trig {
  __device__ __host__ static T Atan2(const T &a, const T &b) { return atan2(a, b); }
  __device__ __host__ static T Sin(const T &a) { return sin(a); }
  __device__ __host__ static T Cos(const T &a) { return cos(a); }
  __device__ __host__ static void SinCos(const T &a, T *s, T *c) { sincos(a, s, c); }
};

/*
  Single precision, angles in radians.  Device code uses the fast
  (reduced-accuracy) __sinf / __cosf / __sincosf intrinsics; host code uses
  the libm float functions.
*/
template <> struct Trig<false, float> {
  __device__ __host__ static float Atan2(const float &a, const float &b) { return atan2f(a, b); }

  __device__ __host__ static float Sin(const float &a)
  {
#ifdef __CUDA_ARCH__
    return __sinf(a);
#else
    return sinf(a);
#endif
  }

  __device__ __host__ static float Cos(const float &a)
  {
#ifdef __CUDA_ARCH__
    return __cosf(a);
#else
    return cosf(a);
#endif
  }

  __device__ __host__ static void SinCos(const float &a, float *s, float *c)
  {
#ifdef __CUDA_ARCH__
    __sincosf(a, s, c);
#else
    sincosf(a, s, c);
#endif
  }
};

/*
  Single precision, angles expressed in units of pi.  Sin/Cos/SinCos multiply
  their argument by pi before evaluating, and Atan2 divides its result by pi,
  so the [-pi, pi] range of atan2f maps to [-1, 1].
*/
template <> struct Trig<true, float> {
  __device__ __host__ static float Atan2(const float &a, const float &b)
  {
    // Divide by pi as a float: a bare M_PI is a double and would promote
    // this division to double precision, which is slow in device code.
    return atan2f(a, b) / static_cast<float>(M_PI);
  }

  __device__ __host__ static float Sin(const float &a)
  {
#ifdef __CUDA_ARCH__
    return __sinf(a * static_cast<float>(M_PI));
#else
    return sinf(a * static_cast<float>(M_PI));
#endif
  }

  __device__ __host__ static float Cos(const float &a)
  {
#ifdef __CUDA_ARCH__
    return __cosf(a * static_cast<float>(M_PI));
#else
    return cosf(a * static_cast<float>(M_PI));
#endif
  }

  __device__ __host__ static void SinCos(const float &a, float *s, float *c)
  {
#ifdef __CUDA_ARCH__
    __sincosf(a * static_cast<float>(M_PI), s, c);
#else
    sincosf(a * static_cast<float>(M_PI), s, c);
#endif
  }
};
440 
441 
/*
  VectorType<Float, number>::type -- the `number`-wide vector type whose
  components are Float (the inverse of scalar<> / vec_length<>).  The primary
  template is declared but never defined, so requesting an unsupported
  combination is a compile-time error.
*/
template <typename Float, int number> struct VectorType;

// double precision
template <> struct VectorType<double, 1> { using type = double; };
template <> struct VectorType<double, 2> { using type = double2; };
template <> struct VectorType<double, 3> { using type = double3; };
template <> struct VectorType<double, 4> { using type = double4; };
template <> struct VectorType<double, 8> { using type = double8; };

// single precision
template <> struct VectorType<float, 1> { using type = float; };
template <> struct VectorType<float, 2> { using type = float2; };
template <> struct VectorType<float, 3> { using type = float3; };
template <> struct VectorType<float, 4> { using type = float4; };
template <> struct VectorType<float, 8> { using type = float8; };

// half precision
template <> struct VectorType<short, 1> { using type = short; };
template <> struct VectorType<short, 2> { using type = short2; };
template <> struct VectorType<short, 3> { using type = short3; };
template <> struct VectorType<short, 4> { using type = short4; };
template <> struct VectorType<short, 8> { using type = short8; };

// quarter precision
template <> struct VectorType<int8_t, 1> { using type = int8_t; };
template <> struct VectorType<int8_t, 2> { using type = char2; };
template <> struct VectorType<int8_t, 3> { using type = char3; };
template <> struct VectorType<int8_t, 4> { using type = char4; };
template <> struct VectorType<int8_t, 8> { using type = char8; };
493 
/**
  Load element `idx` from the array of VectorType that starts at `ptr`.
  In device code compiled for 320 <= __CUDA_ARCH__ < 520 the load is issued
  through __ldg (generics/ldg.h); everywhere else it is an ordinary
  dereference.
  NOTE(review): `ptr` is assumed to be suitably aligned for VectorType.
*/
template <typename VectorType> __device__ __host__ inline VectorType vector_load(const void *ptr, int idx)
{
  const VectorType *base = reinterpret_cast<const VectorType *>(ptr);
#if (__CUDA_ARCH__ >= 320 && __CUDA_ARCH__ < 520)
  return __ldg(base + idx);
#else
  return base[idx];
#endif
}

/**
  short8: fetch the element as a float4 through the generic path above and
  reinterpret the sizeof(float4) bytes with memcpy.
*/
template <> __device__ __host__ inline short8 vector_load<short8>(const void *ptr, int idx)
{
  const float4 raw = vector_load<float4>(ptr, idx);
  short8 bits;
  memcpy(&bits, &raw, sizeof raw);
  return bits;
}

/**
  char8: fetch the element as a float2 through the generic path above and
  reinterpret the sizeof(float2) bytes with memcpy.
*/
template <> __device__ __host__ inline char8 vector_load<char8>(const void *ptr, int idx)
{
  const float2 raw = vector_load<float2>(ptr, idx);
  char8 bits;
  memcpy(&bits, &raw, sizeof raw);
  return bits;
}
518 
/**
  Store `value` into element `idx` of the array of VectorType starting at
  `ptr`.  The generic version is a plain store.  The specializations below
  route device-side stores through the store_streaming_* helpers
  (inline_ptx.h) and use a plain store on the host.
  Order matters: every specialization that forwards to another store
  (char2 -> short, short8 -> float4, char8 -> float2) comes after its target.
*/
template <typename VectorType>
__device__ __host__ inline void vector_store(void *ptr, int idx, const VectorType &value)
{
  VectorType *const dst = reinterpret_cast<VectorType *>(ptr) + idx;
  *dst = value;
}

template <>
__device__ __host__ inline void vector_store<double2>(void *ptr, int idx, const double2 &value)
{
  double2 *const dst = reinterpret_cast<double2 *>(ptr) + idx;
#if defined(__CUDA_ARCH__)
  store_streaming_double2(dst, value.x, value.y);
#else
  *dst = value;
#endif
}

template <>
__device__ __host__ inline void vector_store<float4>(void *ptr, int idx, const float4 &value)
{
  float4 *const dst = reinterpret_cast<float4 *>(ptr) + idx;
#if defined(__CUDA_ARCH__)
  store_streaming_float4(dst, value.x, value.y, value.z, value.w);
#else
  *dst = value;
#endif
}

template <>
__device__ __host__ inline void vector_store<float2>(void *ptr, int idx, const float2 &value)
{
  float2 *const dst = reinterpret_cast<float2 *>(ptr) + idx;
#if defined(__CUDA_ARCH__)
  store_streaming_float2(dst, value.x, value.y);
#else
  *dst = value;
#endif
}

template <>
__device__ __host__ inline void vector_store<short4>(void *ptr, int idx, const short4 &value)
{
  short4 *const dst = reinterpret_cast<short4 *>(ptr) + idx;
#if defined(__CUDA_ARCH__)
  store_streaming_short4(dst, value.x, value.y, value.z, value.w);
#else
  *dst = value;
#endif
}

template <>
__device__ __host__ inline void vector_store<short2>(void *ptr, int idx, const short2 &value)
{
  short2 *const dst = reinterpret_cast<short2 *>(ptr) + idx;
#if defined(__CUDA_ARCH__)
  store_streaming_short2(dst, value.x, value.y);
#else
  *dst = value;
#endif
}

// A char4 occupies the same four bytes as a short2, so on the device it is
// written with the short2 streaming store.
template <>
__device__ __host__ inline void vector_store<char4>(void *ptr, int idx, const char4 &value)
{
#if defined(__CUDA_ARCH__)
  const short2 &as_short2 = *reinterpret_cast<const short2 *>(&value);
  store_streaming_short2(reinterpret_cast<short2 *>(ptr) + idx, as_short2.x, as_short2.y);
#else
  reinterpret_cast<char4 *>(ptr)[idx] = value;
#endif
}

// A char2 occupies the same two bytes as a short; on the device it goes
// through the generic short store.
template <>
__device__ __host__ inline void vector_store<char2>(void *ptr, int idx, const char2 &value)
{
#if defined(__CUDA_ARCH__)
  vector_store<short>(ptr, idx, *reinterpret_cast<const short *>(&value));
#else
  reinterpret_cast<char2 *>(ptr)[idx] = value;
#endif
}

// short8 is written on the device as a float4 (same 16 bytes).
template <> __device__ __host__ inline void vector_store<short8>(void *ptr, int idx, const short8 &value)
{
#if defined(__CUDA_ARCH__)
  vector_store<float4>(ptr, idx, *reinterpret_cast<const float4 *>(&value));
#else
  reinterpret_cast<short8 *>(ptr)[idx] = value;
#endif
}

// char8 is written on the device as a float2 (same 8 bytes).
template <> __device__ __host__ inline void vector_store<char8>(void *ptr, int idx, const char8 &value)
{
#if defined(__CUDA_ARCH__)
  vector_store<float2>(ptr, idx, *reinterpret_cast<const float2 *>(&value));
#else
  reinterpret_cast<char8 *>(ptr)[idx] = value;
#endif
}
605 
/*
  AllocType<large_alloc>::type -- size_t when large_alloc is true, int
  otherwise.
  NOTE(review): presumably the integer type for allocation sizes/offsets,
  chosen by whether the allocation may exceed the int range -- confirm at
  the call sites.
*/
template <bool large_alloc> struct AllocType {
};
template <> struct AllocType<false> { using type = int; };
template <> struct AllocType<true> { using type = size_t; };
609 
610 } // namespace quda
611 
612 #endif // _REGISTER_TRAITS_H
cudaColorSpinorField * tmp
Definition: covdev_test.cpp:34
__device__ __forceinline__ T __ldg(const T *ptr)
Definition: ldg.h:44
__device__ void store_streaming_float2(float2 *addr, float x, float y)
Definition: inline_ptx.h:93
__host__ __device__ ValueType cos(ValueType x)
Definition: complex_quda.h:46
__host__ __device__ ValueType atan2(ValueType x, ValueType y)
Definition: complex_quda.h:76
__device__ void store_streaming_float4(float4 *addr, float x, float y, float z, float w)
Definition: inline_ptx.h:78
__device__ void store_streaming_double2(double2 *addr, double x, double y)
Definition: inline_ptx.h:88
__host__ __device__ ValueType sin(ValueType x)
Definition: complex_quda.h:51
std::complex< double > Complex
Definition: quda_internal.h:86
__device__ void store_streaming_short4(short4 *addr, short x, short y, short z, short w)
Definition: inline_ptx.h:83
__device__ void store_streaming_short2(short2 *addr, short x, short y)
Definition: inline_ptx.h:98
__device__ __host__ void vector_store(void *ptr, int idx, const VectorType &value)
__device__ __host__ VectorType vector_load(const void *ptr, int idx)
__device__ static __host__ float Cos(const float &a)
__device__ static __host__ void SinCos(const float &a, float *s, float *c)
__device__ static __host__ float Atan2(const float &a, const float &b)
__device__ static __host__ float Sin(const float &a)
__device__ static __host__ float Cos(const float &a)
__device__ static __host__ float Atan2(const float &a, const float &b)
__device__ static __host__ void SinCos(const float &a, float *s, float *c)
__device__ static __host__ float Sin(const float &a)
__device__ static __host__ T Sin(const T &a)
__device__ static __host__ void SinCos(const T &a, T *s, T *c)
__device__ static __host__ T Cos(const T &a)
__device__ static __host__ T Atan2(const T &a, const T &b)
static const bool value
static const bool value
static const bool value
static const int value
vector(const double2 &a)