QUDA  1.0.0
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
thrust_helper.cuh
Go to the documentation of this file.
1 #pragma once
2 
3 #include <malloc_quda.h>
4 
5 #undef device_malloc
6 #undef device_free
7 
8 // ensures we use shfl_sync and not shfl when compiling with clang
9 #if defined(__clang__) && defined(__CUDA__) && CUDA_VERSION >= 9000
10 #define CUB_USE_COOPERATIVE_GROUPS
11 #endif
12 
13 #include <thrust/system/cuda/vector.h>
14 #include <thrust/system/cuda/execution_policy.h>
15 #include <thrust/transform_reduce.h>
16 #include <thrust/device_ptr.h>
17 #include <thrust/device_vector.h>
18 #include <thrust/sort.h>
19 
// Re-establish QUDA's call-site-tracking device_malloc/device_free macros (they record __func__,
// file name and __LINE__). They were #undef'd before the Thrust includes, presumably so the
// macros do not rewrite Thrust's own device_malloc/device_free identifiers -- TODO confirm.
20 #define device_malloc(size) quda::device_malloc_(__func__, quda::file_name(__FILE__), __LINE__, size)
21 #define device_free(ptr) quda::device_free_(__func__, quda::file_name(__FILE__), __LINE__, ptr)
22 
28 {
29 public:
30  // just allocate bytes
31  typedef char value_type;
32 
35 
/**
 * Allocator hook used by Thrust to obtain raw storage.
 * The request is forwarded to the pool_device_malloc macro (declared in malloc_quda.h).
 * @param num_bytes Number of bytes requested.
 * @return The obtained storage viewed as a char pointer (value_type is char).
 */
char *allocate(std::ptrdiff_t num_bytes)
{
  auto *storage = pool_device_malloc(num_bytes);
  return reinterpret_cast<char *>(storage);
}
/**
 * Allocator hook used by Thrust to release storage previously returned by allocate().
 * @param ptr Pointer obtained from allocate().
 * @param n   Size of the block as reported by the caller; not needed here because
 *            pool_device_free takes only the pointer.
 */
void deallocate(char *ptr, size_t n)
{
  (void)n; // explicitly unused: silences -Wunused-parameter without changing the signature
  pool_device_free(ptr);
}
38 
39 };
void deallocate(char *ptr, size_t n)
#define pool_device_malloc(size)
Definition: malloc_quda.h:125
char * allocate(std::ptrdiff_t num_bytes)
#define pool_device_free(ptr)
Definition: malloc_quda.h:126