QUDA  v1.1.0
A library for QCD on GPUs
blas_lapack_hipblas.cpp
Go to the documentation of this file.
1 #include <blas_lapack.h>
2 #ifdef NATIVE_LAPACK_LIB
3 #include <cublas_v2.h>
4 #include <malloc_quda.h>
5 #endif
6 
7 //#define _DEBUG
8 
9 #ifdef _DEBUG
10 #include <eigen_helper.h>
11 #endif
12 
13 namespace quda
14 {
15 
16  namespace blas_lapack
17  {
18 
19  namespace native
20  {
21 
22 #ifdef NATIVE_LAPACK_LIB
23  static cublasHandle_t handle;
24 #endif
25  static bool cublas_init = false;
26 
27  void init()
28  {
29  if (!cublas_init) {
30 #ifdef NATIVE_LAPACK_LIB
31  cublasStatus_t error = cublasCreate(&handle);
32  if (error != CUBLAS_STATUS_SUCCESS) errorQuda("cublasCreate failed with error %d", error);
33  cublas_init = true;
34 #endif
35  }
36  }
37 
38  void destroy()
39  {
40  if (cublas_init) {
41 #ifdef NATIVE_LAPACK_LIB
42  cublasStatus_t error = cublasDestroy(handle);
43  if (error != CUBLAS_STATUS_SUCCESS)
44  errorQuda("\nError indestroying cublas context, error code = %d\n", error);
45  cublas_init = false;
46 #endif
47  }
48  }
49 
50 #ifdef _DEBUG
51  template <typename EigenMatrix, typename Float>
52  __host__ void checkEigen(std::complex<Float> *A_h, std::complex<Float> *Ainv_h, int n, uint64_t batch)
53  {
54  EigenMatrix A = EigenMatrix::Zero(n, n);
55  EigenMatrix Ainv = EigenMatrix::Zero(n, n);
56  for (int j = 0; j < n; j++) {
57  for (int k = 0; k < n; k++) {
58  A(k, j) = A_h[batch * n * n + j * n + k];
59  Ainv(k, j) = Ainv_h[batch * n * n + j * n + k];
60  }
61  }
62 
63  // Check result:
64  EigenMatrix unit = EigenMatrix::Identity(n, n);
65  EigenMatrix prod = A * Ainv;
66  Float L2norm = ((prod - unit).norm() / (n * n));
67  printfQuda("cuBLAS: Norm of (A * Ainv - I) batch %lu = %e\n", batch, L2norm);
68  }
69 #endif
70 
71  // FIXME do this in pipelined fashion to reduce memory overhead.
72  long long BatchInvertMatrix(void *Ainv, void *A, const int n, const uint64_t batch, QudaPrecision prec,
73  QudaFieldLocation location)
74  {
75 #ifdef NATIVE_LAPACK_LIB
76  init();
77  if (getVerbosity() >= QUDA_VERBOSE)
78  printfQuda("BatchInvertMatrix (native - cuBLAS): Nc = %d, batch = %lu\n", n, batch);
79 
80  long long flops = 0;
81  timeval start, stop;
82  gettimeofday(&start, NULL);
83 
84  size_t size = 2 * n * n * prec * batch;
85  void *A_d = location == QUDA_CUDA_FIELD_LOCATION ? A : pool_device_malloc(size);
86  void *Ainv_d = location == QUDA_CUDA_FIELD_LOCATION ? Ainv : pool_device_malloc(size);
87  if (location == QUDA_CPU_FIELD_LOCATION) qudaMemcpy(A_d, A, size, cudaMemcpyHostToDevice);
88 
89 #ifdef _DEBUG
90  // Debug code: Copy original A matrix to host
91  std::complex<float> *A_h
92  = (location == QUDA_CUDA_FIELD_LOCATION ? static_cast<std::complex<float> *>(pool_pinned_malloc(size)) :
93  static_cast<std::complex<float> *>(A_d));
94  if (location == QUDA_CUDA_FIELD_LOCATION) qudaMemcpy((void *)A_h, A_d, size, cudaMemcpyDeviceToHost);
95 #endif
96 
97  int *dipiv = static_cast<int *>(pool_device_malloc(batch * n * sizeof(int)));
98  int *dinfo_array = static_cast<int *>(pool_device_malloc(batch * sizeof(int)));
99  int *info_array = static_cast<int *>(pool_pinned_malloc(batch * sizeof(int)));
100  memset(info_array, '0', batch * sizeof(int)); // silence memcheck warnings
101 
102  if (prec == QUDA_SINGLE_PRECISION) {
103  typedef cuFloatComplex C;
104  C **A_array = static_cast<C **>(pool_device_malloc(2 * batch * sizeof(C *)));
105  C **Ainv_array = A_array + batch;
106  C **A_array_h = static_cast<C **>(pool_pinned_malloc(2 * batch * sizeof(C *)));
107  C **Ainv_array_h = A_array_h + batch;
108  for (uint64_t i = 0; i < batch; i++) {
109  A_array_h[i] = static_cast<C *>(A_d) + i * n * n;
110  Ainv_array_h[i] = static_cast<C *>(Ainv_d) + i * n * n;
111  }
112  qudaMemcpy(A_array, A_array_h, 2 * batch * sizeof(C *), cudaMemcpyHostToDevice);
113 
114  cublasStatus_t error = cublasCgetrfBatched(handle, n, A_array, n, dipiv, dinfo_array, batch);
115  flops += batch * FLOPS_CGETRF(n, n);
116 
117  if (error != CUBLAS_STATUS_SUCCESS)
118  errorQuda("\nError in LU decomposition (cublasCgetrfBatched), error code = %d\n", error);
119 
120  qudaMemcpy(info_array, dinfo_array, batch * sizeof(int), cudaMemcpyDeviceToHost);
121  for (uint64_t i = 0; i < batch; i++) {
122  if (info_array[i] < 0) {
123  errorQuda("%lu argument had an illegal value or another error occured, such as memory allocation failed",
124  i);
125  } else if (info_array[i] > 0) {
126  errorQuda("%lu factorization completed but the factor U is exactly singular", i);
127  }
128  }
129 
130  error = cublasCgetriBatched(handle, n, (const C **)A_array, n, dipiv, Ainv_array, n, dinfo_array, batch);
131  flops += batch * FLOPS_CGETRI(n);
132 
133  if (error != CUBLAS_STATUS_SUCCESS)
134  errorQuda("\nError in matrix inversion (cublasCgetriBatched), error code = %d\n", error);
135 
136  qudaMemcpy(info_array, dinfo_array, batch * sizeof(int), cudaMemcpyDeviceToHost);
137 
138  for (uint64_t i = 0; i < batch; i++) {
139  if (info_array[i] < 0) {
140  errorQuda("%lu argument had an illegal value or another error occured, such as memory allocation failed",
141  i);
142  } else if (info_array[i] > 0) {
143  errorQuda("%lu factorization completed but the factor U is exactly singular", i);
144  }
145  }
146 
147  pool_device_free(A_array);
148  pool_pinned_free(A_array_h);
149 
150 #ifdef _DEBUG
151  // Debug code: Copy computed Ainv to host
152  std::complex<float> *Ainv_h = static_cast<std::complex<float> *>(pool_pinned_malloc(size));
153  qudaMemcpy((void *)Ainv_h, Ainv_d, size, cudaMemcpyDeviceToHost);
154 
155  for (uint64_t i = 0; i < batch; i++) { checkEigen<MatrixXcf, float>(A_h, Ainv_h, n, i); }
156  pool_pinned_free(Ainv_h);
157  pool_pinned_free(A_h);
158 #endif
159  } else {
160  errorQuda("%s not implemented for precision=%d", __func__, prec);
161  }
162 
163  if (location == QUDA_CPU_FIELD_LOCATION) {
164  qudaMemcpy(Ainv, Ainv_d, size, cudaMemcpyDeviceToHost);
165  pool_device_free(Ainv_d);
166  pool_device_free(A_d);
167  }
168 
169  pool_device_free(dipiv);
170  pool_device_free(dinfo_array);
171  pool_pinned_free(info_array);
172 
174  gettimeofday(&stop, NULL);
175  long ds = stop.tv_sec - start.tv_sec;
176  long dus = stop.tv_usec - start.tv_usec;
177  double time = ds + 0.000001 * dus;
178 
179  if (getVerbosity() >= QUDA_VERBOSE)
180  printfQuda("Batched matrix inversion completed in %f seconds with GFLOPS = %f\n", time, 1e-9 * flops / time);
181 
182  return flops;
183 #else
184  errorQuda("Native BLAS not built. Please build and use native BLAS or use generic BLAS");
185  return 0; // Stops a compiler warning
186 #endif
187  }
188 
189  } // namespace native
190  } // namespace blas_lapack
191 } // namespace quda
#define FLOPS_CGETRF(m_, n_)
Definition: blas_lapack.h:14
#define FLOPS_CGETRI(n_)
Definition: blas_lapack.h:21
QudaPrecision prec
void * memset(void *s, int c, size_t n)
enum QudaPrecision_s QudaPrecision
@ QUDA_CUDA_FIELD_LOCATION
Definition: enum_quda.h:326
@ QUDA_CPU_FIELD_LOCATION
Definition: enum_quda.h:325
@ QUDA_VERBOSE
Definition: enum_quda.h:267
enum QudaFieldLocation_s QudaFieldLocation
@ QUDA_SINGLE_PRECISION
Definition: enum_quda.h:64
#define pool_pinned_malloc(size)
Definition: malloc_quda.h:172
#define pool_device_malloc(size)
Definition: malloc_quda.h:170
#define pool_pinned_free(ptr)
Definition: malloc_quda.h:173
#define pool_device_free(ptr)
Definition: malloc_quda.h:171
long long BatchInvertMatrix(void *Ainv, void *A, const int n, const uint64_t batch, QudaPrecision precision, QudaFieldLocation location)
Batch inversion of the matrix field using an LU decomposition method.
void init()
Create the BLAS context.
void destroy()
Destroy the BLAS context.
unsigned long long flops
void stop()
Stop profiling.
Definition: device.cpp:228
void start()
Start profiling.
Definition: device.cpp:226
__host__ __device__ ValueType norm(const complex< ValueType > &z)
Returns the magnitude of z squared.
FloatingPoint< float > Float
#define qudaMemcpy(dst, src, count, kind)
Definition: quda_api.h:204
#define qudaDeviceSynchronize()
Definition: quda_api.h:250
#define printfQuda(...)
Definition: util_quda.h:114
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:21
#define errorQuda(...)
Definition: util_quda.h:120