QUDA  1.0.0
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
blas_cublas.cu
Go to the documentation of this file.
1 #ifdef CUBLAS_LIB
2 #include <blas_cublas.h>
3 #include <cublas_v2.h>
4 #endif
5 #include <malloc_quda.h>
6 
// Operation-count formulas for LU factorisation (GETRF) of an m x n matrix and
// for inversion from LU factors (GETRI) of an n x n matrix.  FMULS_* and
// FADDS_* are combined by the FLOPS_* macros with weights 6 and 2 respectively
// (presumably the real-flop cost of a complex multiply and a complex add).

// GETRF counts with the dimensions already ordered: s_ is the smaller
// dimension and l_ the larger one.
#define FMULS_GETRF_ORDERED(s_, l_) \
  (0.5 * (s_) * ((s_) * ((l_) - (1./3.) * (s_) - 1. ) + (l_)) + (2. / 3.) * (s_))
#define FADDS_GETRF_ORDERED(s_, l_) \
  (0.5 * (s_) * ((s_) * ((l_) - (1./3.) * (s_) ) - (l_)) + (1. / 6.) * (s_))

// The GETRF formulas are symmetric in (m_, n_): select the ordering so that the
// smaller dimension is always passed first.
#define FMULS_GETRF(m_, n_) \
  ( ((m_) < (n_)) ? FMULS_GETRF_ORDERED(m_, n_) : FMULS_GETRF_ORDERED(n_, m_) )
#define FADDS_GETRF(m_, n_) \
  ( ((m_) < (n_)) ? FADDS_GETRF_ORDERED(m_, n_) : FADDS_GETRF_ORDERED(n_, m_) )

// Total GETRF flop count; the double- (Z) and single-precision (C) variants
// use the same formula.
#define FLOPS_ZGETRF(m_, n_) \
  (6. * FMULS_GETRF((double)(m_), (double)(n_)) + 2.0 * FADDS_GETRF((double)(m_), (double)(n_)) )
#define FLOPS_CGETRF(m_, n_) \
  (6. * FMULS_GETRF((double)(m_), (double)(n_)) + 2.0 * FADDS_GETRF((double)(m_), (double)(n_)) )

// GETRI counts for an n_ x n_ matrix.
#define FMULS_GETRI(n_) \
  ( (n_) * ((5. / 6.) + (n_) * ((2. / 3.) * (n_) + 0.5)) )
#define FADDS_GETRI(n_) \
  ( (n_) * ((5. / 6.) + (n_) * ((2. / 3.) * (n_) - 1.5)) )

// Total GETRI flop count; the Z and C variants use the same formula.
#define FLOPS_ZGETRI(n_) \
  (6. * FMULS_GETRI((double)(n_)) + 2.0 * FADDS_GETRI((double)(n_)) )
#define FLOPS_CGETRI(n_) \
  (6. * FMULS_GETRI((double)(n_)) + 2.0 * FADDS_GETRI((double)(n_)) )
22 
23 namespace quda {
24 
25  namespace cublas {
26 
#ifdef CUBLAS_LIB
// File-local cuBLAS context: created by init(), released by destroy() and
// passed to the batched cuBLAS calls in this file.
static cublasHandle_t handle;
#endif

/**
   @brief Create the cuBLAS context.

   Calls cublasCreate on the file-local handle and reports a failure through
   errorQuda.  Compiles to an empty function when CUBLAS_LIB is not defined.
*/
void init()
{
#ifdef CUBLAS_LIB
  const cublasStatus_t status = cublasCreate(&handle);
  if (status != CUBLAS_STATUS_SUCCESS) {
    errorQuda("cublasCreate failed with error %d", status);
  }
#endif
}
37 
/**
   @brief Destroy the cuBLAS context.

   Calls cublasDestroy on the file-local handle and reports a failure through
   errorQuda.  Compiles to an empty function when CUBLAS_LIB is not defined.
*/
void destroy() {
#ifdef CUBLAS_LIB
  cublasStatus_t error = cublasDestroy(handle);
  if (error != CUBLAS_STATUS_SUCCESS) errorQuda("\nError in destroying cublas context, error code = %d\n", error);
#endif
}
44 
45  // mini kernel to set the array of pointers needed for batched cublas
46  template<typename T>
47  __global__ void set_pointer(T **output_array_a, T *input_a, T **output_array_b, T *input_b, int batch_offset)
48  {
49  output_array_a[blockIdx.x] = input_a + blockIdx.x * batch_offset;
50  output_array_b[blockIdx.x] = input_b + blockIdx.x * batch_offset;
51  }
52 
53  // FIXME do this in pipelined fashion to reduce memory overhead.
54  long long BatchInvertMatrix(void *Ainv, void* A, const int n, const int batch, QudaPrecision prec, QudaFieldLocation location)
55  {
56  long long flops = 0;
57 #ifdef CUBLAS_LIB
58  timeval start, stop;
59  gettimeofday(&start, NULL);
60 
61  size_t size = 2*n*n*prec*batch;
62  void *A_d = location == QUDA_CUDA_FIELD_LOCATION ? A : pool_device_malloc(size);
63  void *Ainv_d = location == QUDA_CUDA_FIELD_LOCATION ? Ainv : pool_device_malloc(size);
64  if (location == QUDA_CPU_FIELD_LOCATION) qudaMemcpy(A_d, A, size, cudaMemcpyHostToDevice);
65 
66  int *dipiv = static_cast<int*>(pool_device_malloc(batch*n*sizeof(int)));
67  int *dinfo_array = static_cast<int*>(pool_device_malloc(batch*sizeof(int)));
68  int *info_array = static_cast<int*>(pool_pinned_malloc(batch*sizeof(int)));
69 
70  if (prec == QUDA_SINGLE_PRECISION) {
71  typedef cuFloatComplex C;
72  C **A_array = static_cast<C**>(pool_device_malloc(batch*sizeof(C*)));
73  C **Ainv_array = static_cast<C**>(pool_device_malloc(batch*sizeof(C*)));
74 
75  set_pointer<C><<<batch,1>>>(A_array, (C*)A_d, Ainv_array, (C*)Ainv_d, n*n);
76 
77  cublasStatus_t error = cublasCgetrfBatched(handle, n, A_array, n, dipiv, dinfo_array, batch);
78  flops += batch*FLOPS_CGETRF(n,n);
79 
80  if (error != CUBLAS_STATUS_SUCCESS)
81  errorQuda("\nError in LU decomposition (cublasCgetrfBatched), error code = %d\n", error);
82 
83  qudaMemcpy(info_array, dinfo_array, batch*sizeof(int), cudaMemcpyDeviceToHost);
84  for (int i=0; i<batch; i++) {
85  if (info_array[i] < 0) {
86  errorQuda("%d argument had an illegal value or another error occured, such as memory allocation failed", i);
87  } else if (info_array[i] > 0) {
88  errorQuda("%d factorization completed but the factor U is exactly singular", i);
89  }
90  }
91 
92  error = cublasCgetriBatched(handle, n, (const C**)A_array, n, dipiv, Ainv_array, n, dinfo_array, batch);
93  flops += batch*FLOPS_CGETRI(n);
94 
95  if (error != CUBLAS_STATUS_SUCCESS)
96  errorQuda("\nError in matrix inversion (cublasCgetriBatched), error code = %d\n", error);
97 
98  qudaMemcpy(info_array, dinfo_array, batch*sizeof(int), cudaMemcpyDeviceToHost);
99 
100  for (int i=0; i<batch; i++) {
101  if (info_array[i] < 0) {
102  errorQuda("%d argument had an illegal value or another error occured, such as memory allocation failed", i);
103  } else if (info_array[i] > 0) {
104  errorQuda("%d factorization completed but the factor U is exactly singular", i);
105  }
106  }
107 
108  pool_device_free(Ainv_array);
109  pool_device_free(A_array);
110 
111  } else {
112  errorQuda("%s not implemented for precision=%d", __func__, prec);
113  }
114 
115  if (location == QUDA_CPU_FIELD_LOCATION) {
116  qudaMemcpy(Ainv, Ainv_d, size, cudaMemcpyDeviceToHost);
117  pool_device_free(Ainv_d);
118  pool_device_free(A_d);
119  }
120 
121  pool_device_free(dipiv);
122  pool_device_free(dinfo_array);
123  pool_pinned_free(info_array);
124 
126  gettimeofday(&stop, NULL);
127  long ds = stop.tv_sec - start.tv_sec;
128  long dus = stop.tv_usec - start.tv_usec;
129  double time = ds + 0.000001*dus;
130 
131  if (getVerbosity() >= QUDA_VERBOSE)
132  printfQuda("Batched matrix inversion completed in %f seconds with GFLOPS = %f\n", time, 1e-9 * flops / time);
133 #endif // CUBLAS_LIB
134 
135  return flops;
136  }
137 
138  } // namespace cublas
139 
140 } // namespace quda
#define FLOPS_CGETRI(n_)
Definition: blas_cublas.cu:21
#define qudaMemcpy(dst, src, count, kind)
Definition: quda_cuda_api.h:33
#define pool_pinned_free(ptr)
Definition: malloc_quda.h:128
enum QudaPrecision_s QudaPrecision
void destroy()
Destroy the CUBLAS context.
Definition: blas_cublas.cu:38
__global__ void set_pointer(T **output_array_a, T *input_a, T **output_array_b, T *input_b, int batch_offset)
Definition: blas_cublas.cu:47
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:21
#define errorQuda(...)
Definition: util_quda.h:121
#define qudaDeviceSynchronize()
long long BatchInvertMatrix(void *Ainv, void *A, const int n, const int batch, QudaPrecision precision, QudaFieldLocation location)
Definition: blas_cublas.cu:54
#define pool_device_malloc(size)
Definition: malloc_quda.h:125
constexpr int size
void init()
Create the CUBLAS context.
Definition: blas_cublas.cu:31
#define FLOPS_CGETRF(m_, n_)
Definition: blas_cublas.cu:15
#define pool_pinned_malloc(size)
Definition: malloc_quda.h:127
enum QudaFieldLocation_s QudaFieldLocation
#define printfQuda(...)
Definition: util_quda.h:115
unsigned long long flops
Definition: blas_quda.cu:22
#define pool_device_free(ptr)
Definition: malloc_quda.h:126
QudaPrecision prec
Definition: test_util.cpp:1608