QUDA  v1.1.0
A library for QCD on GPUs
blas_lapack_eigen.cpp
#include <blas_lapack.h>
#include <eigen_helper.h>

//#define _DEBUG

namespace quda
{
  namespace blas_lapack
  {

    // whether we are using the native blas-lapack library
    static bool native_blas_lapack = true;
    bool use_native() { return native_blas_lapack; }
    void set_native(bool native) { native_blas_lapack = native; }
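
    /*
      Usage note (a sketch, not normative): use_native() is presumably
      consulted by the dispatch layer declared in blas_lapack.h, so calling

        quda::blas_lapack::set_native(false);

      before a batched inversion or GEMM should route the work through the
      generic Eigen implementations in this file rather than the native GPU
      library.
    */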

    namespace generic
    {

      void init() {}

      void destroy() {}

      // Batched inversion checking
      //---------------------------------------------------
      template <typename EigenMatrix, typename Float>
      void invertEigen(std::complex<Float> *A_eig, std::complex<Float> *Ainv_eig, int n, uint64_t batch)
      {
        EigenMatrix res = EigenMatrix::Zero(n, n);
        EigenMatrix inv = EigenMatrix::Zero(n, n);
        for (int j = 0; j < n; j++) {
          for (int k = 0; k < n; k++) { res(k, j) = A_eig[batch * n * n + j * n + k]; }
        }

        inv = res.inverse();

        for (int j = 0; j < n; j++) {
          for (int k = 0; k < n; k++) { Ainv_eig[batch * n * n + j * n + k] = inv(k, j); }
        }

        // Check result:
#ifdef _DEBUG
        EigenMatrix unit = EigenMatrix::Identity(n, n);
        EigenMatrix prod = res * inv;
        Float L2norm = ((prod - unit).norm() / (n * n));
        printfQuda("Eigen: Norm of (A * Ainv - I) batch %lu = %e\n", batch, L2norm);
#endif
      }
      //---------------------------------------------------
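
      /*
        A minimal sketch of calling invertEigen directly (hypothetical data;
        MatrixXcd is the Eigen type used below, and <vector> is assumed).
        Each batch occupies a contiguous n * n block stored column-major:
        element (row k, column j) of batch b lives at index b*n*n + j*n + k.

          const int n = 2;
          const uint64_t batches = 2;
          std::vector<std::complex<double>> A(n * n * batches), Ainv(A.size());
          // ... fill A with two invertible 2x2 matrices ...
          for (uint64_t b = 0; b < batches; b++)
            invertEigen<MatrixXcd, double>(A.data(), Ainv.data(), n, b);
      */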

      // Batched Inversions
      //---------------------------------------------------
      long long BatchInvertMatrix(void *Ainv, void *A, const int n, const uint64_t batch, QudaPrecision prec,
                                  QudaFieldLocation location)
      {
        if (getVerbosity() >= QUDA_VERBOSE)
          printfQuda("BatchInvertMatrix (generic - Eigen): Nc = %d, batch = %lu\n", n, batch);

        size_t size = 2 * n * n * batch * prec;
        void *A_h = (location == QUDA_CUDA_FIELD_LOCATION ? pool_pinned_malloc(size) : A);
        void *Ainv_h = (location == QUDA_CUDA_FIELD_LOCATION ? pool_pinned_malloc(size) : Ainv);
        if (location == QUDA_CUDA_FIELD_LOCATION) { qudaMemcpy(A_h, A, size, cudaMemcpyDeviceToHost); }

        long long flops = 0;
        timeval start, stop;
        gettimeofday(&start, NULL);

        if (prec == QUDA_SINGLE_PRECISION) {
          std::complex<float> *A_eig = (std::complex<float> *)A_h;
          std::complex<float> *Ainv_eig = (std::complex<float> *)Ainv_h;

#ifdef _OPENMP
#pragma omp parallel for
#endif
          for (uint64_t i = 0; i < batch; i++) { invertEigen<MatrixXcf, float>(A_eig, Ainv_eig, n, i); }
          flops += batch * FLOPS_CGETRF(n, n);
        } else if (prec == QUDA_DOUBLE_PRECISION) {
          std::complex<double> *A_eig = (std::complex<double> *)A_h;
          std::complex<double> *Ainv_eig = (std::complex<double> *)Ainv_h;

#ifdef _OPENMP
#pragma omp parallel for
#endif
          for (uint64_t i = 0; i < batch; i++) { invertEigen<MatrixXcd, double>(A_eig, Ainv_eig, n, i); }
          flops += batch * FLOPS_ZGETRF(n, n);
        } else {
          errorQuda("%s not implemented for precision = %d", __func__, prec);
        }

        gettimeofday(&stop, NULL);
        long dsh = stop.tv_sec - start.tv_sec;
        long dush = stop.tv_usec - start.tv_usec;
        double timeh = dsh + 0.000001 * dush;

        if (getVerbosity() >= QUDA_VERBOSE) {
          int threads = 1;
#ifdef _OPENMP
          // omp_get_num_threads() returns 1 outside a parallel region, so
          // report the available thread count instead
          threads = omp_get_max_threads();
#endif
          printfQuda("CPU: Batched matrix inversion completed in %f seconds using %d threads with GFLOPS = %f\n", timeh,
                     threads, 1e-9 * flops / timeh);
        }

        if (location == QUDA_CUDA_FIELD_LOCATION) {
          // copy the result back to the device before releasing the pinned host buffers
          qudaMemcpy((void *)Ainv, Ainv_h, size, cudaMemcpyHostToDevice);
          pool_pinned_free(Ainv_h);
          pool_pinned_free(A_h);
        }

        return flops;
      }

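      /*
        A minimal host-side usage sketch (hypothetical sizes; <vector> is
        assumed). With QUDA_CPU_FIELD_LOCATION the input and output pointers
        are used in place, so no device transfers occur:

          const int n = 3;               // e.g. 3x3 colour matrices
          const uint64_t batch = 1024;
          std::vector<std::complex<double>> A(n * n * batch), Ainv(A.size());
          // ... fill A with invertible matrices ...
          long long flops = BatchInvertMatrix(Ainv.data(), A.data(), n, batch,
                                              QUDA_DOUBLE_PRECISION, QUDA_CPU_FIELD_LOCATION);
      */
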
      // Strided Batched GEMM helpers
      //--------------------------------------------------------------------------
      template <typename EigenMat, typename T>
      void fillArray(EigenMat &EigenArr, T *arr, int rows, int cols, int ld, int offset, bool fill_eigen)
      {
        int counter = offset;
        for (int i = 0; i < rows; i++) {
          for (int j = 0; j < cols; j++) {
            if (fill_eigen)
              EigenArr(i, j) = arr[counter];
            else
              arr[counter] = EigenArr(i, j);
            counter++;
          }
          counter += (ld - cols);
        }
      }
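
      /*
        Worked example of the ld/offset convention above (values hypothetical):
        reading a 2x3 matrix with leading dimension ld = 4 and offset = 0 takes
        arr[0..2] as row 0, skips arr[3] (the ld - cols = 1 padding element),
        then takes arr[4..6] as row 1:

          arr = { a00, a01, a02, pad, a10, a11, a12, pad, ... }

        With fill_eigen = false the same walk runs in reverse, scattering the
        Eigen matrix back into the padded array.
      */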

      template <typename EigenMat, typename T>
      void GEMM(void *A_h, void *B_h, void *C_h, T alpha, T beta, int max_stride, QudaBLASParam &blas_param)
      {
        // Problem parameters
        int m = blas_param.m;
        int n = blas_param.n;
        int k = blas_param.k;
        int lda = blas_param.lda;
        int ldb = blas_param.ldb;
        int ldc = blas_param.ldc;

        // If the user did not set the stride values (they default to 0),
        // fall back to a stride of 1.
        int a_stride = blas_param.a_stride == 0 ? 1 : blas_param.a_stride;
        int b_stride = blas_param.b_stride == 0 ? 1 : blas_param.b_stride;
        int c_stride = blas_param.c_stride == 0 ? 1 : blas_param.c_stride;
        int a_offset = blas_param.a_offset;
        int b_offset = blas_param.b_offset;
        int c_offset = blas_param.c_offset;
        int batches = blas_param.batch_count;

        // Number of data elements between consecutive batches
        unsigned int A_batch_size = blas_param.lda * blas_param.k;
        if (blas_param.trans_a != QUDA_BLAS_OP_N) A_batch_size = blas_param.lda * blas_param.m;
        unsigned int B_batch_size = blas_param.ldb * blas_param.n;
        if (blas_param.trans_b != QUDA_BLAS_OP_N) B_batch_size = blas_param.ldb * blas_param.k;
        unsigned int C_batch_size = blas_param.ldc * blas_param.n;

        T *A_ptr = static_cast<T *>(A_h);
        T *B_ptr = static_cast<T *>(B_h);
        T *C_ptr = static_cast<T *>(C_h);

        // Eigen objects to store data
        EigenMat Amat = EigenMat::Zero(m, k);
        EigenMat Bmat = EigenMat::Zero(k, n);
        EigenMat Cmat = EigenMat::Zero(m, n);

        for (int batch = 0; batch < batches; batch += max_stride) {

          // Populate Eigen objects
          fillArray<EigenMat, T>(Amat, A_ptr, m, k, lda, a_offset, true);
          fillArray<EigenMat, T>(Bmat, B_ptr, k, n, ldb, b_offset, true);
          fillArray<EigenMat, T>(Cmat, C_ptr, m, n, ldc, c_offset, true);

          // Apply op(A) and op(B)
          switch (blas_param.trans_a) {
          case QUDA_BLAS_OP_T: Amat.transposeInPlace(); break;
          case QUDA_BLAS_OP_C: Amat.adjointInPlace(); break;
          case QUDA_BLAS_OP_N: break;
          default: errorQuda("Unknown blas op type %d", blas_param.trans_a);
          }

          switch (blas_param.trans_b) {
          case QUDA_BLAS_OP_T: Bmat.transposeInPlace(); break;
          case QUDA_BLAS_OP_C: Bmat.adjointInPlace(); break;
          case QUDA_BLAS_OP_N: break;
          default: errorQuda("Unknown blas op type %d", blas_param.trans_b);
          }

          // Perform GEMM using Eigen
          Cmat = alpha * Amat * Bmat + beta * Cmat;

          // Write back to the C array
          fillArray<EigenMat, T>(Cmat, C_ptr, m, n, ldc, c_offset, false);

          a_offset += A_batch_size * a_stride;
          b_offset += B_batch_size * b_stride;
          c_offset += C_batch_size * c_stride;
        }
      }
      //---------------------------------------------------
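
      /*
        Offset bookkeeping for the loop above: after i iterations each array
        offset has advanced by i * stride * batch_size elements, i.e. batch i
        of A begins at a_offset + i * a_stride * A_batch_size (and similarly
        for B and C). Zero strides were promoted to 1, so the walk always
        moves forward through the arrays.
      */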

      // Strided Batched GEMM
      //---------------------------------------------------
      long long stridedBatchGEMM(void *A_data, void *B_data, void *C_data, QudaBLASParam blas_param,
                                 QudaFieldLocation location)
      {
        long long flops = 0;
        timeval start, stop;
        gettimeofday(&start, NULL);

        // Sanity checks on parameters
        //-------------------------------------------------------------------------
        // If the user passes a non-positive m, n, or k, we error out
        int min_dim = std::min(blas_param.m, std::min(blas_param.n, blas_param.k));
        if (min_dim <= 0) {
          errorQuda("BLAS dims must be positive: m=%d, n=%d, k=%d", blas_param.m, blas_param.n, blas_param.k);
        }

        // If the user passes a negative stride, we error out as this has no meaning.
        int min_stride = std::min(std::min(blas_param.a_stride, blas_param.b_stride), blas_param.c_stride);
        if (min_stride < 0) {
          errorQuda("BLAS strides must be positive or zero: a_stride=%d, b_stride=%d, c_stride=%d", blas_param.a_stride,
                    blas_param.b_stride, blas_param.c_stride);
        }

        // If the user passes a negative offset, we error out as this has no meaning.
        int min_offset = std::min(std::min(blas_param.a_offset, blas_param.b_offset), blas_param.c_offset);
        if (min_offset < 0) {
          errorQuda("BLAS offsets must be positive or zero: a_offset=%d, b_offset=%d, c_offset=%d", blas_param.a_offset,
                    blas_param.b_offset, blas_param.c_offset);
        }

        // If the batch value is non-positive, we error out
        if (blas_param.batch_count <= 0) { errorQuda("Batches must be positive: batches=%d", blas_param.batch_count); }

        // Leading dims are dependent on the matrix op type.
        if (blas_param.data_order == QUDA_BLAS_DATAORDER_COL) {
          if (blas_param.trans_a == QUDA_BLAS_OP_N) {
            if (blas_param.lda < std::max(1, blas_param.m))
              errorQuda("lda=%d must be >= max(1,m=%d)", blas_param.lda, blas_param.m);
          } else {
            if (blas_param.lda < std::max(1, blas_param.k))
              errorQuda("lda=%d must be >= max(1,k=%d)", blas_param.lda, blas_param.k);
          }

          if (blas_param.trans_b == QUDA_BLAS_OP_N) {
            if (blas_param.ldb < std::max(1, blas_param.k))
              errorQuda("ldb=%d must be >= max(1,k=%d)", blas_param.ldb, blas_param.k);
          } else {
            if (blas_param.ldb < std::max(1, blas_param.n))
              errorQuda("ldb=%d must be >= max(1,n=%d)", blas_param.ldb, blas_param.n);
          }
          if (blas_param.ldc < std::max(1, blas_param.m))
            errorQuda("ldc=%d must be >= max(1,m=%d)", blas_param.ldc, blas_param.m);
        } else {
          if (blas_param.trans_a == QUDA_BLAS_OP_N) {
            if (blas_param.lda < std::max(1, blas_param.k))
              errorQuda("lda=%d must be >= max(1,k=%d)", blas_param.lda, blas_param.k);
          } else {
            if (blas_param.lda < std::max(1, blas_param.m))
              errorQuda("lda=%d must be >= max(1,m=%d)", blas_param.lda, blas_param.m);
          }
          if (blas_param.trans_b == QUDA_BLAS_OP_N) {
            if (blas_param.ldb < std::max(1, blas_param.n))
              errorQuda("ldb=%d must be >= max(1,n=%d)", blas_param.ldb, blas_param.n);
          } else {
            if (blas_param.ldb < std::max(1, blas_param.k))
              errorQuda("ldb=%d must be >= max(1,k=%d)", blas_param.ldb, blas_param.k);
          }
          if (blas_param.ldc < std::max(1, blas_param.n))
            errorQuda("ldc=%d must be >= max(1,n=%d)", blas_param.ldc, blas_param.n);
        }
        //-------------------------------------------------------------------------

        // Parse parameters for Eigen
        //-------------------------------------------------------------------------
        // The Eigen helpers below assume row-major data, so a column-major
        // problem is recast as the equivalent row-major one by swapping the
        // roles of A and B (using (A B)^T = B^T A^T).
        if (blas_param.data_order == QUDA_BLAS_DATAORDER_COL) {
          std::swap(blas_param.m, blas_param.n);
          std::swap(blas_param.lda, blas_param.ldb);
          std::swap(blas_param.trans_a, blas_param.trans_b);
          std::swap(blas_param.a_offset, blas_param.b_offset);
          std::swap(blas_param.a_stride, blas_param.b_stride);
          std::swap(A_data, B_data);
        }

        // Get the maximum stride length to deduce the number of batches in
        // the computation
        int max_stride = std::max(std::max(blas_param.a_stride, blas_param.b_stride), blas_param.c_stride);

        // If the user gives strides of 0 for all arrays, we are essentially performing
        // a GEMM on the first matrices in the array N_{batch} times.
        // Give them what they ask for, YMMV...
        // If this evaluates to 0, the user did not set any strides.
        if (max_stride <= 0) max_stride = 1;

        // The number of GEMMs to compute
        const uint64_t batch = blas_param.batch_count / max_stride;

        uint64_t data_size
          = (blas_param.data_type == QUDA_BLAS_DATATYPE_S || blas_param.data_type == QUDA_BLAS_DATATYPE_C) ? 4 : 8;

        if (blas_param.data_type == QUDA_BLAS_DATATYPE_C || blas_param.data_type == QUDA_BLAS_DATATYPE_Z) {
          data_size *= 2;
        }

        // Number of data elements between consecutive batches
        unsigned int A_batch_size = blas_param.lda * blas_param.k;
        if (blas_param.trans_a != QUDA_BLAS_OP_N) A_batch_size = blas_param.lda * blas_param.m;
        unsigned int B_batch_size = blas_param.ldb * blas_param.n;
        if (blas_param.trans_b != QUDA_BLAS_OP_N) B_batch_size = blas_param.ldb * blas_param.k;
        unsigned int C_batch_size = blas_param.ldc * blas_param.n;

        // Size in bytes of each entire array
        size_t sizeAarr = A_batch_size * data_size * batch;
        size_t sizeBarr = B_batch_size * data_size * batch;
        size_t sizeCarr = C_batch_size * data_size * batch;

        // If already on the host, just use the given pointers. If the data is
        // on the device, allocate host memory and transfer
        void *A_h = location == QUDA_CPU_FIELD_LOCATION ? A_data : pool_pinned_malloc(sizeAarr);
        void *B_h = location == QUDA_CPU_FIELD_LOCATION ? B_data : pool_pinned_malloc(sizeBarr);
        void *C_h = location == QUDA_CPU_FIELD_LOCATION ? C_data : pool_pinned_malloc(sizeCarr);
        if (location == QUDA_CUDA_FIELD_LOCATION) {
          qudaMemcpy(A_h, A_data, sizeAarr, cudaMemcpyDeviceToHost);
          qudaMemcpy(B_h, B_data, sizeBarr, cudaMemcpyDeviceToHost);
          qudaMemcpy(C_h, C_data, sizeCarr, cudaMemcpyDeviceToHost);
        }

        if (blas_param.data_type == QUDA_BLAS_DATATYPE_Z) {

          typedef std::complex<double> Z;
          const Z alpha = blas_param.alpha;
          const Z beta = blas_param.beta;
          GEMM<MatrixXcd, Z>(A_h, B_h, C_h, alpha, beta, max_stride, blas_param);

        } else if (blas_param.data_type == QUDA_BLAS_DATATYPE_C) {

          typedef std::complex<float> C;
          const C alpha = blas_param.alpha;
          const C beta = blas_param.beta;
          GEMM<MatrixXcf, C>(A_h, B_h, C_h, alpha, beta, max_stride, blas_param);

        } else if (blas_param.data_type == QUDA_BLAS_DATATYPE_D) {

          typedef double D;
          const D alpha = (D)(static_cast<std::complex<double>>(blas_param.alpha).real());
          const D beta = (D)(static_cast<std::complex<double>>(blas_param.beta).real());
          GEMM<MatrixXd, D>(A_h, B_h, C_h, alpha, beta, max_stride, blas_param);

        } else if (blas_param.data_type == QUDA_BLAS_DATATYPE_S) {

          typedef float S;
          const S alpha = (S)(static_cast<std::complex<float>>(blas_param.alpha).real());
          const S beta = (S)(static_cast<std::complex<float>>(blas_param.beta).real());
          GEMM<MatrixXf, S>(A_h, B_h, C_h, alpha, beta, max_stride, blas_param);

        } else {
          errorQuda("blasGEMM type %d not implemented\n", blas_param.data_type);
        }

        // Restore the blas parameters to their original values
        if (blas_param.data_order == QUDA_BLAS_DATAORDER_COL) {
          std::swap(blas_param.m, blas_param.n);
          std::swap(blas_param.lda, blas_param.ldb);
          std::swap(blas_param.trans_a, blas_param.trans_b);
          std::swap(blas_param.a_offset, blas_param.b_offset);
          std::swap(blas_param.a_stride, blas_param.b_stride);
          std::swap(A_data, B_data);
        }

        // Transfer data
        if (location == QUDA_CUDA_FIELD_LOCATION) {
          qudaMemcpy(C_data, C_h, sizeCarr, cudaMemcpyHostToDevice);
          pool_pinned_free(A_h);
          pool_pinned_free(B_h);
          pool_pinned_free(C_h);
        }

        qudaDeviceSynchronize();
        gettimeofday(&stop, NULL);
        long ds = stop.tv_sec - start.tv_sec;
        long dus = stop.tv_usec - start.tv_usec;
        double time = ds + 0.000001 * dus;
        if (getVerbosity() >= QUDA_DEBUG_VERBOSE)
          printfQuda("Batched matrix GEMM completed in %f seconds with GFLOPS = %f\n", time, 1e-9 * flops / time);

        return flops;
      }
    } // namespace generic
  }   // namespace blas_lapack
} // namespace quda
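
/*
  A minimal end-to-end sketch of the generic path (hypothetical parameters;
  assumes <vector> and <complex>, and the QudaBLASParam fields declared in
  quda.h). This computes C = alpha * A * B + beta * C for 16 independent
  4x4 complex-double matrices resident on the host:

    QudaBLASParam p = {};
    p.trans_a = QUDA_BLAS_OP_N;
    p.trans_b = QUDA_BLAS_OP_N;
    p.m = p.n = p.k = 4;
    p.lda = p.ldb = p.ldc = 4;
    p.a_stride = p.b_stride = p.c_stride = 1;
    p.batch_count = 16;
    p.data_type = QUDA_BLAS_DATATYPE_Z;
    p.data_order = QUDA_BLAS_DATAORDER_ROW;
    p.alpha = 1.0;
    p.beta = 0.0;

    std::vector<std::complex<double>> A(4 * 4 * 16), B(4 * 4 * 16), C(4 * 4 * 16);
    // ... fill A and B ...
    quda::blas_lapack::generic::stridedBatchGEMM(A.data(), B.data(), C.data(), p,
                                                 QUDA_CPU_FIELD_LOCATION);
*/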