QUDA  v0.7.0
A library for QCD on GPUs
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
register_traits.h
Go to the documentation of this file.
1 #ifndef REGISTER_TRAITS_H
2 #define REGISTER_TRAITS_H
3 
4 #include <quda_internal.h>
5 
6 namespace quda {
7 
8  /*
9  Here we use traits to define the mapping between storage type and
10  register type:
11  double -> double
12  float -> float
13  short -> float
14  This allows us to encapsulate the register type in the storage template type
15  */
/**
 * Storage-type to register-type trait.
 *
 * mapper<S>::type names the register type paired with storage type S.
 * The primary template is deliberately empty, so asking for the mapping
 * of a type that has no specialization below is a compile-time error.
 */
template <typename Storage>
struct mapper {
};

/* Scalar types: double and float map to themselves, short maps to float. */
template <>
struct mapper<double> {
  typedef double type;
};

template <>
struct mapper<float> {
  typedef float type;
};

template <>
struct mapper<short> {
  typedef float type;
};

/* Two-component vector types follow the same pattern as the scalars. */
template <>
struct mapper<double2> {
  typedef double2 type;
};

template <>
struct mapper<float2> {
  typedef float2 type;
};

template <>
struct mapper<short2> {
  typedef float2 type;
};

/* Four-component vector types follow the same pattern as the scalars. */
template <>
struct mapper<double4> {
  typedef double4 type;
};

template <>
struct mapper<float4> {
  typedef float4 type;
};

template <>
struct mapper<short4> {
  typedef float4 type;
};
28 
/**
 * Compile-time test for half-precision storage.
 *
 * isHalf<T>::value is true only when T is exactly short; every other
 * type, including the short2/short4 vector types, yields false because
 * no specialization is provided for them.
 */
template <typename T>
struct isHalf {
  static const bool value = false;
};

template <>
struct isHalf<short> {
  static const bool value = true;
};
32 
/**
 * Generic element copy: plain assignment, using whatever implicit
 * conversion exists from T2 to T1.
 */
template <typename T1, typename T2>
__host__ __device__ inline void copy(T1 &a, const T2 &b)
{
  a = b;
}

/**
 * short -> float: convert the stored value to float, then divide by
 * MAX_SHORT.
 * NOTE(review): MAX_SHORT is defined outside this file; presumably it is
 * the largest short magnitude, which would map the result into [-1, 1] -
 * confirm against quda_internal.h.
 */
template <>
__host__ __device__ inline void copy(float &a, const short &b)
{
  const float widened = static_cast<float>(b);
  a = widened / MAX_SHORT;
}

/**
 * float -> short: multiply by MAX_SHORT and convert to short.
 * The conversion truncates toward zero; no rounding or clamping is
 * applied here, so the caller is responsible for keeping the scaled
 * value within the range of short.
 */
template <>
__host__ __device__ inline void copy(short &a, const float &b)
{
  a = static_cast<short>(b * MAX_SHORT);
}
36 
/**
 * Trigonometry wrappers selected at compile time by a boolean flag.
 *
 * This primary template forwards each call unchanged to the
 * corresponding unqualified atan2 / sin / cos function, so angles are
 * passed through without any scaling.
 */
template <bool half>
struct Trig {

  /** Returns atan2(a, b). */
  template <typename T>
  __device__ __host__ static T Atan2(const T &a, const T &b)
  {
    return atan2(a, b);
  }

  /** Returns sin(a). */
  template <typename T>
  __device__ __host__ static T Sin(const T &a)
  {
    return sin(a);
  }

  /** Returns cos(a). */
  template <typename T>
  __device__ __host__ static T Cos(const T &a)
  {
    return cos(a);
  }

  /**
   * Writes sin(a) to *s and cos(a) to *c.
   * a is read again after *s has been written, so s must not point at a.
   */
  template <typename T>
  __device__ __host__ static void SinCos(const T &a, T *s, T *c)
  {
    *s = sin(a);
    *c = cos(a);
  }
};
52 
/**
 * Specialization of Trig for the flag value true.
 *
 * Angles are expressed in units of pi: Atan2 divides its result by M_PI,
 * while Sin, Cos and SinCos multiply their argument by M_PI before
 * evaluating.
 * NOTE(review): presumably this keeps phases inside [-1, 1] so they fit
 * the fixed-point short storage - confirm against the callers.
 *
 * M_PI is a double constant, so for T = float the arithmetic is done in
 * double and converted back to T on return or store.
 */
template <>
struct Trig<true> {

  /** Returns atan2(a, b) / pi. */
  template <typename T>
  __device__ __host__ static T Atan2(const T &a, const T &b)
  {
    return atan2(a, b) / M_PI;
  }

  /** Returns sin(pi * a). */
  template <typename T>
  __device__ __host__ static T Sin(const T &a)
  {
    return sin(a * M_PI);
  }

  /** Returns cos(pi * a). */
  template <typename T>
  __device__ __host__ static T Cos(const T &a)
  {
    return cos(a * M_PI);
  }

  /**
   * Writes sin(pi * a) to *s and cos(pi * a) to *c.
   *
   * Added so that this specialization offers the same members as the
   * primary template; without it, generic code calling
   * Trig<flag>::SinCos fails to compile when flag is true.
   * a is read again after *s has been written, so s must not point at a.
   */
  template <typename T>
  __device__ __host__ static void SinCos(const T &a, T *s, T *c)
  {
    *s = sin(a * M_PI);
    *c = cos(a * M_PI);
  }
};
65 
66 
67 
68 
69 } // namespace quda
70 
71 #endif
__device__ static __host__ T Sin(const T &a)
__device__ static __host__ T Atan2(const T &a, const T &b)
__host__ __device__ void copy(T1 &a, const T2 &b)
static const bool value
__device__ static __host__ T Atan2(const T &a, const T &b)
__host__ __device__ ValueType sin(ValueType x)
Definition: complex_quda.h:40
__host__ __device__ ValueType atan2(ValueType x, ValueType y)
Definition: complex_quda.h:65
__device__ static __host__ T Cos(const T &a)
__device__ static __host__ void SinCos(const T &a, T *s, T *c)
__host__ __device__ ValueType cos(ValueType x)
Definition: complex_quda.h:35
#define MAX_SHORT
Definition: quda_internal.h:30
__device__ static __host__ T Cos(const T &a)
VOLATILE spinorFloat * s
__device__ static __host__ T Sin(const T &a)