#define WARP_CONVERGED 0xffffffff
template<int m, bool ispo2=is_power_of_two<m>::value, bool isodd=is_odd<m>::value>
template<int m, typename Schema=typename tx_algorithm<m>::type>
static const int permute = m - 1;
template<int m, typename Schema=typename tx_algorithm<m>::type>
template<typename T, template<int> class Permute, int position=0>
template<typename T, int s, template<int> class Permute, int position>
static const int idx = Permute<position>::value;
template<typename Source>
trove::get<idx>(src),
template<typename T, template<int> class Permute, int position>
static const int idx = Permute<position>::value;
template<typename Source>
template<int m, int a, int b=0>
static const int value = (a * x + b) % m;
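The fragment above has the shape of a compile-time affine map: the outer template carries the modulus m and coefficients a and b, and the value member presumably sits in an inner template that supplies x, evaluating (a*x + b) mod m as a constant. A self-contained sketch of that idea, with invented names (affine_mod and eval are not the header's own identifiers):

template<int m, int a, int b = 0>
struct affine_mod {
    // Inner template supplies the argument x; everything folds at compile time.
    template<int x>
    struct eval {
        static const int value = (a * x + b) % m;
    };
};

// Example: (3*4 + 1) mod 5 == 3, checked at compile time.
static_assert(affine_mod<5, 3, 1>::eval<4>::value == 3, "affine map check");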
static const int p = m / c;
static const int value = (x * o - (x / p)) % m;
template<typename Array>
template<typename Array>
template<typename Array>
template<typename Array>
template<typename Array, int b, int o>
template<int s, int b, int o>
template<int b, int o>
return Array(offset);
template<int m, typename Schema>
int warp_id = ((threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x) & WARP_MASK;
return initial_offset;
int warp_id = ((threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x) & WARP_MASK;
return initial_offset;
template<int m, typename Schema>
int warp_id = ((threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x) & WARP_MASK;
int initial_offset = (warp_id * m) & WARP_MASK;
return initial_offset;
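In both fragments above the lane index is recovered from the linear thread index with & WARP_MASK, and the starting offset is a multiple of m wrapped the same way; because WARP_SIZE is a power of two, & WARP_MASK is an exact substitute for % WARP_SIZE. A host-side sketch of what the offset expression produces, with WARP_SIZE = 32 assumed and m = 3 chosen purely as an example:

#include <cstdio>

// Illustration only: evaluate the offset expression from the fragment above
// for every lane of one warp (WARP_SIZE = 32 and m = 3 are assumptions).
int main() {
    const int WARP_SIZE = 32, WARP_MASK = WARP_SIZE - 1, m = 3;
    for (int warp_id = 0; warp_id < WARP_SIZE; ++warp_id) {
        int initial_offset = (warp_id * m) & WARP_MASK;  // == (warp_id * m) % 32
        printf("%2d ", initial_offset);  // each lane starts m slots after its neighbour
    }
    printf("\n");
    return 0;
}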
template<int m, typename Schema>
constants::offset>::impl(initial_offset);
template<typename T, int m, int p = 0>
template<int s, int m, int p>
static const int mod_n = n - 1;
static const int mod_c = c - 1;
static const int n_div_c = n / c;
static const int mod_n_div_c = n_div_c - 1;
int offset = ((((idx >> log_c) * k) & mod_n_div_c) +
              ((idx & mod_c) << log_n_div_c)) & mod_n;
int new_idx = idx + n - 1;
new_idx = (p == m - c + (col & mod_c)) ? new_idx + m : new_idx;
::impl(new_idx, col));
template<int m, int p>
static const int mod_n = n - 1;
static const int mod_c = c - 1;
static const int n_div_c = n / c;
static const int mod_n_div_c = n_div_c - 1;
int offset = ((((idx >> log_c) * k) & mod_n_div_c) +
              ((idx & mod_c) << log_n_div_c)) & mod_n;
template<int index, int offset, int bound>
static const int value = (offset * index) % bound;
template<typename Array, int index, int m, typename Schema>
template<int s, int index, int m>
int current_offset = (initial_offset + offset) & WARP_MASK;
return Array(current_offset,
             index + 1, m, odd>::impl(initial_offset));
template<int index, int m>
int current_offset = (initial_offset + offset) & WARP_MASK;
return Array(current_offset);
template<int s, int index, int m>
int new_offset = (offset == lb) ? offset + m - 1 : offset - 1;
template<int index, int m>
return Array(offset);
template<typename T, int m>
template<int s, int m>
static const int mod_n = n - 1;
static const int n_div_c = n / c;
int new_offset = offset + 1;
new_offset = (new_offset == ub) ? lb : new_offset;
::impl(col, new_offset, lb, ub));
static const int mod_n = n - 1;
static const int n_div_c = n / c;
template<int m, typename Schema>
0, m, Schema>::impl(initial_offset);
template<typename Data, typename Indices>
template<typename T, int m>
#if (__CUDACC_VER_MAJOR__ >= 9 || CUDA_VERSION >= 9000)
#if (__CUDACC_VER_MAJOR__ >= 9 || CUDA_VERSION >= 9000)
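The WARP_CONVERGED mask defined at the top of the file and these two CUDA-version guards point at the usual split between the CUDA 9+ *_sync shuffle intrinsics and the older implicit-warp forms. A minimal sketch of that kind of dispatch, assuming this is what the guarded code chooses between (shuffle_lane is an invented name; only __shfl_sync, __shfl, and WARP_CONVERGED are taken from CUDA and from this header):

// Sketch only: a shuffle wrapper guarded the same way as the lines above.
__device__ __forceinline__
int shuffle_lane(const int& t, const int& lane) {
#if (__CUDACC_VER_MAJOR__ >= 9 || CUDA_VERSION >= 9000)
    // CUDA 9+: participating lanes must be named explicitly;
    // WARP_CONVERGED (0xffffffff) says the whole warp takes part.
    return __shfl_sync(WARP_CONVERGED, t, lane);
#else
    // Pre-CUDA-9 intrinsic with implicit full-warp participation.
    return __shfl(t, lane);
#endif
}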
template<typename Array, typename Schema>
template<typename Array>
__device__
static void impl(Array& indices, int& rotation) {
indices = detail::c2r_compute_offsets<Array::size, odd>();
int warp_id = ((threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x) & WARP_MASK;
int size = Array::size;
rotation = (warp_id * r) % size;
template<typename Array>
__device__
static void impl(Array& indices, int& rotation) {
indices = detail::c2r_compute_offsets<Array::size, power_of_two>();
int warp_id = ((threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x) & WARP_MASK;
int size = Array::size;
rotation = (size - warp_id) & (size - 1);
template<typename Array>
__device__
static void impl(Array& indices, int& rotation) {
int warp_id = ((threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x) & WARP_MASK;
rotation = warp_id % Array::size;
template<typename Array, typename Indices, typename Schema>
template<typename Array, typename Indices>
__device__
static void impl(Array& src, const Indices& indices, const int& rotation) {
template<typename Array, typename Indices>
__device__
static void impl(Array& src, const Indices& indices, const int& rotation) {
int warp_id = ((threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x) & WARP_MASK;
src = rotate(src, pre_rotation);
template<typename Array, typename Indices>
__device__
static void impl(Array& src, const Indices& indices, const int& rotation) {
int warp_id = ((threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x) & WARP_MASK;
src = rotate(src, pre_rotation);
src = rotate(src, rotation);
template<typename Array, typename Schema>
template<typename Array>
__device__
static void impl(Array& indices, int& rotation) {
detail::r2c_compute_offsets<Array::size, odd>();
int warp_id = ((threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x) & WARP_MASK;
int size = Array::size;
rotation = (warp_id * r) % size;
template<typename Array>
static const int m = Array::size;
static const int clear_m = ~(m-1);
static const int mod_n = n-1;
__device__
static void impl(Array& indices, int& rotation) {
int warp_id = ((threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x) & WARP_MASK;
int size = Array::size;
rotation = warp_id % size;
int initial_offset = ((warp_id << log_m) + (warp_id >> log_n_div_m)) & mod_n;
int lb = initial_offset & clear_m;
template<typename Array>
static const int size = Array::size;
__device__
static void impl(Array& indices, int& rotation) {
int warp_id = ((threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x) & WARP_MASK;
rotation = size - (warp_id % size);
int offset = lb + warp_id / (WARP_SIZE / c);
template<typename Array, typename Indices, typename Schema>
template<typename Array, typename Indices>
__device__
static void impl(Array& src, const Indices& indices, const int& rotation) {
Array rotated = rotate(src, rotation);
template<typename Array, typename Indices>
__device__
static void impl(Array& src, const Indices& indices, const int& rotation) {
Array rotated = rotate(src, rotation);
const int size = Array::size;
int warp_id = ((threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x) & WARP_MASK;
template<typename Array, typename Indices>
static const int size = Array::size;
__device__
static void impl(Array& src, const Indices& indices, const int& rotation) {
int warp_id = ((threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x) & WARP_MASK;
src = rotate(src, rotation);
::impl(indices, rotation);
template<typename T, int i>
impl(src, indices, rotation);
template<typename T, int i>
indices_array indices;
impl(src, indices, rotation);
::impl(indices, rotation);
template<typename T, int i>
::impl(src, indices, rotation);
template<typename T, int i>
indices_array indices;
::impl(src, indices, rotation);
__device__ __forceinline__ T __shfl(const T &t, const int &i)
__host__ __device__ Array composite_r2c_tx_permute(const Array &t)
__device__ array< int, m > c2r_compute_offsets()
__device__ array< int, m > r2c_compute_offsets()
__host__ __device__ Array c2r_tx_permute(const Array &t)
__host__ __device__ Array composite_c2r_tx_permute(const Array &t)
__host__ __device__ Array r2c_tx_permute(const Array &t)
__device__ void r2c_compute_indices(array< int, i > &indices, int &rotation)
__host__ __device__ array< T, i > rotate(const array< T, i > &t, int a)
__device__ void c2r_compute_indices(array< int, i > &indices, int &rotation)
__device__ void r2c_warp_transpose(array< T, i > &src, const array< int, i > &indices, int rotation)
__device__ void c2r_warp_transpose(array< T, i > &src, const array< int, i > &indices, int rotation)
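Taken together, c2r_compute_indices / c2r_warp_transpose and their r2c counterparts above form the warp-transpose interface of this header: compute the per-lane shuffle indices and rotation once, then apply them to each register-resident tile. A minimal usage sketch, assuming (as elsewhere in the library) that the functions and the array type live in the trove namespace and are declared by trove/transpose.h; the wrapper function itself is illustrative:

#include <trove/transpose.h>   // assumed header path for the declarations above

// Sketch: transpose one warp's register-resident m-element tile in place.
// Loading and storing the tile is omitted; only the two trove:: calls below
// come from the interface listed above.
template<typename T, int m>
__device__ void transpose_register_tile(trove::array<T, m>& tile) {
    trove::array<int, m> indices;
    int rotation;
    // The shuffle schedule depends only on the lane and on m, so it can be
    // computed once and reused across many tiles of the same size.
    trove::c2r_compute_indices(indices, rotation);
    // Permute the warp's tile via warp shuffles.
    trove::c2r_warp_transpose(tile, indices, rotation);
}

r2c_compute_indices and r2c_warp_transpose take the same argument shapes and provide the opposite direction of the transpose.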
array< int, 1 > result_type
__host__ static __device__ result_type impl(int idx, int col)
__host__ static __device__ result_type impl(int idx, int col)
array< int, s > result_type
static __device__ void impl(Array &indices, int &rotation)
static __device__ void impl(Array &indices, int &rotation)
static __device__ void impl(Array &indices, int &rotation)
static __device__ int impl()
c2r_offset_constants< m > constants
static __device__ int impl()
static __device__ Array impl(int offset)
static __device__ Array impl(int offset)
static __device__ void impl(Array &src, const Indices &indices, const int &rotation)
static __device__ void impl(Array &src, const Indices &indices, const int &rotation)
static __device__ void impl(Array &src, const Indices &indices, const int &rotation)
array< int, 1 > result_type
__host__ static __device__ result_type impl(int col, int offset, int lb, int ub)
array< int, s > result_type
__host__ static __device__ result_type impl(int col, int offset, int lb, int ub)
static __device__ void impl(Array &indices, int &rotation)
static __device__ void impl(Array &indices, int &rotation)
static __device__ void impl(Array &indices, int &rotation)
static __device__ int impl()
static __device__ Array impl(int initial_offset)
static __device__ Array impl(int offset, int lb)
static __device__ Array impl(int initial_offset)
static __device__ Array impl(int offset, int lb)
static __device__ void impl(Array &src, const Indices &indices, const int &rotation)
static __device__ void impl(Array &src, const Indices &indices, const int &rotation)
static __device__ void impl(Array &src, const Indices &indices, const int &rotation)
__host__ static __device__ Remaining impl(const Source &src)
__host__ static __device__ Remaining impl(const Source &src)
static __device__ void impl(array< T, 1 > &d, const array< int, 1 > &i)
static __device__ void impl(array< T, m > &d, const array< int, m > &i)