QUDA v1.1.0
A library for QCD on GPUs
blas_lapack_cublas.cpp
#include <complex.h>
#include <blas_lapack.h>
#ifdef NATIVE_LAPACK_LIB
#include <cublas_v2.h>
#include <malloc_quda.h>
#endif

//#define _DEBUG

#ifdef _DEBUG
#include <eigen_helper.h>
#endif

namespace quda
{

  namespace blas_lapack
  {

    namespace native
    {

#ifdef NATIVE_LAPACK_LIB
      static cublasHandle_t handle;
#endif
      static bool cublas_init = false;
      void init()
      {
        if (!cublas_init) {
#ifdef NATIVE_LAPACK_LIB
          cublasStatus_t error = cublasCreate(&handle);
          if (error != CUBLAS_STATUS_SUCCESS)
            errorQuda("cublasCreate failed with error %d", error);
          else
            printfQuda("cublasCreate succeeded\n");
          cublas_init = true;
#endif
        }
      }

      void destroy()
      {
        if (cublas_init) {
#ifdef NATIVE_LAPACK_LIB
          cublasStatus_t error = cublasDestroy(handle);
          if (error != CUBLAS_STATUS_SUCCESS)
            errorQuda("\nError in destroying cuBLAS context, error code = %d\n", error);
          cublas_init = false;
#endif
        }
      }

#ifdef _DEBUG
      template <typename EigenMatrix, typename Float>
      __host__ void checkEigen(std::complex<Float> *A_h, std::complex<Float> *Ainv_h, int n, uint64_t batch)
      {
        EigenMatrix A = EigenMatrix::Zero(n, n);
        EigenMatrix Ainv = EigenMatrix::Zero(n, n);
        for (int j = 0; j < n; j++) {
          for (int k = 0; k < n; k++) {
            A(k, j) = A_h[batch * n * n + j * n + k];
            Ainv(k, j) = Ainv_h[batch * n * n + j * n + k];
          }
        }

        // Check result:
        EigenMatrix unit = EigenMatrix::Identity(n, n);
        EigenMatrix prod = A * Ainv;
        Float L2norm = ((prod - unit).norm() / (n * n));
        printfQuda("cuBLAS: Norm of (A * Ainv - I) batch %lu = %e\n", batch, L2norm);
      }
#endif

      // FIXME do this in pipelined fashion to reduce memory overhead.
      long long BatchInvertMatrix(void *Ainv, void *A, const int n, const uint64_t batch, QudaPrecision prec,
                                  QudaFieldLocation location)
      {
#ifdef NATIVE_LAPACK_LIB
        init();
        if (getVerbosity() >= QUDA_VERBOSE)
          printfQuda("BatchInvertMatrix (native - cuBLAS): Nc = %d, batch = %lu\n", n, batch);

        long long flops = 0;
        timeval start, stop;
        gettimeofday(&start, NULL);

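        // Total bytes for the batch of complex n x n matrices: the QudaPrecision enumerator
        // values equal the byte width of one real number (e.g. QUDA_SINGLE_PRECISION == 4),
        // and the factor of 2 covers the real and imaginary parts.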
        size_t size = 2 * n * n * prec * batch;
        void *A_d = location == QUDA_CUDA_FIELD_LOCATION ? A : pool_device_malloc(size);
        void *Ainv_d = location == QUDA_CUDA_FIELD_LOCATION ? Ainv : pool_device_malloc(size);
        if (location == QUDA_CPU_FIELD_LOCATION) qudaMemcpy(A_d, A, size, cudaMemcpyHostToDevice);

#ifdef _DEBUG
        // Debug code: copy the original A matrix to the host (if the input is already on
        // the host, read it directly rather than dereferencing the device copy)
        std::complex<float> *A_h
          = (location == QUDA_CUDA_FIELD_LOCATION ? static_cast<std::complex<float> *>(pool_pinned_malloc(size)) :
                                                    static_cast<std::complex<float> *>(A));
        if (location == QUDA_CUDA_FIELD_LOCATION) qudaMemcpy((void *)A_h, A_d, size, cudaMemcpyDeviceToHost);
#endif

        int *dipiv = static_cast<int *>(pool_device_malloc(batch * n * sizeof(int)));
        int *dinfo_array = static_cast<int *>(pool_device_malloc(batch * sizeof(int)));
        int *info_array = static_cast<int *>(pool_pinned_malloc(batch * sizeof(int)));
        memset(info_array, 0, batch * sizeof(int)); // zero-initialize to silence memcheck warnings

        if (prec == QUDA_SINGLE_PRECISION) {
          typedef cuFloatComplex C;
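          // The batched LU routines consume arrays of device pointers, one entry per
          // matrix: build the pointer arrays on the host, then copy them to the device.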
          C **A_array = static_cast<C **>(pool_device_malloc(2 * batch * sizeof(C *)));
          C **Ainv_array = A_array + batch;
          C **A_array_h = static_cast<C **>(pool_pinned_malloc(2 * batch * sizeof(C *)));
          C **Ainv_array_h = A_array_h + batch;
          for (uint64_t i = 0; i < batch; i++) {
            A_array_h[i] = static_cast<C *>(A_d) + i * n * n;
            Ainv_array_h[i] = static_cast<C *>(Ainv_d) + i * n * n;
          }
          qudaMemcpy(A_array, A_array_h, 2 * batch * sizeof(C *), cudaMemcpyHostToDevice);

          cublasStatus_t error = cublasCgetrfBatched(handle, n, A_array, n, dipiv, dinfo_array, batch);
          flops += batch * FLOPS_CGETRF(n, n);

          if (error != CUBLAS_STATUS_SUCCESS)
            errorQuda("\nError in LU decomposition (cublasCgetrfBatched), error code = %d\n", error);

          qudaMemcpy(info_array, dinfo_array, batch * sizeof(int), cudaMemcpyDeviceToHost);
          for (uint64_t i = 0; i < batch; i++) {
            if (info_array[i] < 0) {
              errorQuda("matrix %lu: an argument had an illegal value or another error occurred, such as a failed memory allocation",
                        i);
            } else if (info_array[i] > 0) {
              errorQuda("matrix %lu: factorization completed but the factor U is exactly singular", i);
            }
          }

          error = cublasCgetriBatched(handle, n, (const C **)A_array, n, dipiv, Ainv_array, n, dinfo_array, batch);
          flops += batch * FLOPS_CGETRI(n);

          if (error != CUBLAS_STATUS_SUCCESS)
            errorQuda("\nError in matrix inversion (cublasCgetriBatched), error code = %d\n", error);

          qudaMemcpy(info_array, dinfo_array, batch * sizeof(int), cudaMemcpyDeviceToHost);

          for (uint64_t i = 0; i < batch; i++) {
            if (info_array[i] < 0) {
              errorQuda("matrix %lu: an argument had an illegal value or another error occurred, such as a failed memory allocation",
                        i);
            } else if (info_array[i] > 0) {
              errorQuda("matrix %lu: factorization completed but the factor U is exactly singular", i);
            }
          }

          pool_device_free(A_array);
          pool_pinned_free(A_array_h);

#ifdef _DEBUG
          // Debug code: Copy computed Ainv to host
          std::complex<float> *Ainv_h = static_cast<std::complex<float> *>(pool_pinned_malloc(size));
          qudaMemcpy((void *)Ainv_h, Ainv_d, size, cudaMemcpyDeviceToHost);

          for (uint64_t i = 0; i < batch; i++) { checkEigen<MatrixXcf, float>(A_h, Ainv_h, n, i); }
          pool_pinned_free(Ainv_h);
          // A_h aliases the user's host buffer when the input resides on the host
          if (location == QUDA_CUDA_FIELD_LOCATION) pool_pinned_free(A_h);
#endif
        } else {
          errorQuda("%s not implemented for precision=%d", __func__, prec);
        }

        if (location == QUDA_CPU_FIELD_LOCATION) {
          qudaMemcpy(Ainv, Ainv_d, size, cudaMemcpyDeviceToHost);
          pool_device_free(Ainv_d);
          pool_device_free(A_d);
        }

        pool_device_free(dipiv);
        pool_device_free(dinfo_array);
        pool_pinned_free(info_array);

        qudaDeviceSynchronize();
        gettimeofday(&stop, NULL);
        long ds = stop.tv_sec - start.tv_sec;
        long dus = stop.tv_usec - start.tv_usec;
        double time = ds + 0.000001 * dus;

        if (getVerbosity() >= QUDA_VERBOSE)
          printfQuda("Batched matrix inversion completed in %f seconds with GFLOPS = %f\n", time, 1e-9 * flops / time);

        return flops;
#else
        errorQuda("Native BLAS not built. Please build and use native BLAS or use generic BLAS");
        return 0; // Stops a compiler warning
#endif
      }

      long long stridedBatchGEMM(void *A_data, void *B_data, void *C_data, QudaBLASParam blas_param,
                                 QudaFieldLocation location)
      {
        long long flops = 0;
#ifdef NATIVE_LAPACK_LIB
        timeval start, stop;
        gettimeofday(&start, NULL);

        // Sanity checks on parameters
        //-------------------------------------------------------------------------
        // If the user passes a non-positive M, N, or K, we error out
        int min_dim = std::min(blas_param.m, std::min(blas_param.n, blas_param.k));
        if (min_dim <= 0) {
          errorQuda("BLAS dims must be positive: m=%d, n=%d, k=%d", blas_param.m, blas_param.n, blas_param.k);
        }

        // If the user passes a negative stride, we error out as this has no meaning.
        int min_stride = std::min(std::min(blas_param.a_stride, blas_param.b_stride), blas_param.c_stride);
        if (min_stride < 0) {
          errorQuda("BLAS strides must be positive or zero: a_stride=%d, b_stride=%d, c_stride=%d", blas_param.a_stride,
                    blas_param.b_stride, blas_param.c_stride);
        }

        // If the user passes a negative offset, we error out as this has no meaning.
        int min_offset = std::min(std::min(blas_param.a_offset, blas_param.b_offset), blas_param.c_offset);
        if (min_offset < 0) {
          errorQuda("BLAS offsets must be positive or zero: a_offset=%d, b_offset=%d, c_offset=%d", blas_param.a_offset,
                    blas_param.b_offset, blas_param.c_offset);
        }

        // If the batch count is non-positive, we error out
        if (blas_param.batch_count <= 0) { errorQuda("Batches must be positive: batches=%d", blas_param.batch_count); }

        // Leading dims are dependent on the matrix op type.
        if (blas_param.data_order == QUDA_BLAS_DATAORDER_COL) {
          if (blas_param.trans_a == QUDA_BLAS_OP_N) {
            if (blas_param.lda < std::max(1, blas_param.m))
              errorQuda("lda=%d must be >= max(1,m=%d)", blas_param.lda, blas_param.m);
          } else {
            if (blas_param.lda < std::max(1, blas_param.k))
              errorQuda("lda=%d must be >= max(1,k=%d)", blas_param.lda, blas_param.k);
          }

          if (blas_param.trans_b == QUDA_BLAS_OP_N) {
            if (blas_param.ldb < std::max(1, blas_param.k))
              errorQuda("ldb=%d must be >= max(1,k=%d)", blas_param.ldb, blas_param.k);
          } else {
            if (blas_param.ldb < std::max(1, blas_param.n))
              errorQuda("ldb=%d must be >= max(1,n=%d)", blas_param.ldb, blas_param.n);
          }
          if (blas_param.ldc < std::max(1, blas_param.m))
            errorQuda("ldc=%d must be >= max(1,m=%d)", blas_param.ldc, blas_param.m);
        } else {
          if (blas_param.trans_a == QUDA_BLAS_OP_N) {
            if (blas_param.lda < std::max(1, blas_param.k))
              errorQuda("lda=%d must be >= max(1,k=%d)", blas_param.lda, blas_param.k);
          } else {
            if (blas_param.lda < std::max(1, blas_param.m))
              errorQuda("lda=%d must be >= max(1,m=%d)", blas_param.lda, blas_param.m);
          }
          if (blas_param.trans_b == QUDA_BLAS_OP_N) {
            if (blas_param.ldb < std::max(1, blas_param.n))
              errorQuda("ldb=%d must be >= max(1,n=%d)", blas_param.ldb, blas_param.n);
          } else {
            if (blas_param.ldb < std::max(1, blas_param.k))
              errorQuda("ldb=%d must be >= max(1,k=%d)", blas_param.ldb, blas_param.k);
          }
          if (blas_param.ldc < std::max(1, blas_param.n))
            errorQuda("ldc=%d must be >= max(1,n=%d)", blas_param.ldc, blas_param.n);
        }
        //-------------------------------------------------------------------------

        // Parse parameters for cuBLAS
        //-------------------------------------------------------------------------
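        // cuBLAS assumes column-major storage. A row-major matrix is the transpose of the
        // same memory read as column-major, so a row-major GEMM C = op(A) op(B) is computed
        // as the column-major GEMM C^T = op(B)^T op(A)^T: swap A and B together with m/n,
        // lda/ldb, trans_a/trans_b, and the corresponding offsets and strides.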
        // Swap A and B if in row order
        if (blas_param.data_order == QUDA_BLAS_DATAORDER_ROW) {
          std::swap(blas_param.m, blas_param.n);
          std::swap(blas_param.lda, blas_param.ldb);
          std::swap(blas_param.trans_a, blas_param.trans_b);
          std::swap(blas_param.a_offset, blas_param.b_offset);
          std::swap(blas_param.a_stride, blas_param.b_stride);
          std::swap(A_data, B_data);
        }

        // Get the maximum stride length to deduce the number of batches in the
        // computation
        int max_stride = std::max(std::max(blas_param.a_stride, blas_param.b_stride), blas_param.c_stride);

        // If all strides are zero they have not been set, so we force max_stride to 1:
        // the effective strides below then default to the regular batch size and we
        // perform a plain batched GEMM over batch_count matrices.
        if (max_stride == 0) max_stride = 1;

        // The number of GEMMs to compute
        const uint64_t batch = blas_param.batch_count / max_stride;
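        // For example, batch_count = 8 with a_stride = b_stride = c_stride = 2 gives
        // max_stride = 2, hence 8 / 2 = 4 GEMMs, each advancing two matrices per step
        // through the arrays.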

        uint64_t data_size
          = (blas_param.data_type == QUDA_BLAS_DATATYPE_S || blas_param.data_type == QUDA_BLAS_DATATYPE_C) ? 4 : 8;

        if (blas_param.data_type == QUDA_BLAS_DATATYPE_C || blas_param.data_type == QUDA_BLAS_DATATYPE_Z) {
          data_size *= 2;
        }
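        // data_size now holds bytes per element: S -> 4, D -> 8, C -> 8, Z -> 16.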

        // Number of elements between consecutive matrices in each array
        unsigned int A_batch_size = blas_param.lda * blas_param.k;
        if (blas_param.trans_a != QUDA_BLAS_OP_N) A_batch_size = blas_param.lda * blas_param.m;
        unsigned int B_batch_size = blas_param.ldb * blas_param.n;
        if (blas_param.trans_b != QUDA_BLAS_OP_N) B_batch_size = blas_param.ldb * blas_param.k;
        unsigned int C_batch_size = blas_param.ldc * blas_param.n;

        // If a stride was left at its default of zero, the effective stride is the
        // regular batch size; otherwise it is the user-specified multiple of it.
        unsigned int a_stride = blas_param.a_stride == 0 ? A_batch_size : A_batch_size * blas_param.a_stride;
        unsigned int b_stride = blas_param.b_stride == 0 ? B_batch_size : B_batch_size * blas_param.b_stride;
        unsigned int c_stride = blas_param.c_stride == 0 ? C_batch_size : C_batch_size * blas_param.c_stride;

        // Data size of the entire array
        size_t sizeAarr = A_batch_size * data_size * batch;
        size_t sizeBarr = B_batch_size * data_size * batch;
        size_t sizeCarr = C_batch_size * data_size * batch;

        // If already on the device, just use the given pointer. If the data is on
        // the host, allocate device memory and transfer
        void *A_d = location == QUDA_CUDA_FIELD_LOCATION ? A_data : pool_device_malloc(sizeAarr);
        void *B_d = location == QUDA_CUDA_FIELD_LOCATION ? B_data : pool_device_malloc(sizeBarr);
        void *C_d = location == QUDA_CUDA_FIELD_LOCATION ? C_data : pool_device_malloc(sizeCarr);
        if (location == QUDA_CPU_FIELD_LOCATION) {
          qudaMemcpy(A_d, A_data, sizeAarr, cudaMemcpyHostToDevice);
          qudaMemcpy(B_d, B_data, sizeBarr, cudaMemcpyHostToDevice);
          qudaMemcpy(C_d, C_data, sizeCarr, cudaMemcpyHostToDevice);
        }

        cublasOperation_t trans_a = CUBLAS_OP_N;
        switch (blas_param.trans_a) {
        case QUDA_BLAS_OP_N: trans_a = CUBLAS_OP_N; break;
        case QUDA_BLAS_OP_T: trans_a = CUBLAS_OP_T; break;
        case QUDA_BLAS_OP_C: trans_a = CUBLAS_OP_C; break;
        default: errorQuda("Unknown QUDA_BLAS_OP type %d\n", blas_param.trans_a);
        }

        cublasOperation_t trans_b = CUBLAS_OP_N;
        switch (blas_param.trans_b) {
        case QUDA_BLAS_OP_N: trans_b = CUBLAS_OP_N; break;
        case QUDA_BLAS_OP_T: trans_b = CUBLAS_OP_T; break;
        case QUDA_BLAS_OP_C: trans_b = CUBLAS_OP_C; break;
        default: errorQuda("Unknown QUDA_BLAS_OP type %d\n", blas_param.trans_b);
        }
        //-------------------------------------------------------------------------

        // Call cuBLAS
        //-------------------------------------------------------------------------
        if (blas_param.data_type == QUDA_BLAS_DATATYPE_Z) {

          typedef cuDoubleComplex Z;

          const Z alpha = make_double2((double)(static_cast<std::complex<double>>(blas_param.alpha).real()),
                                       (double)(static_cast<std::complex<double>>(blas_param.alpha).imag()));

          const Z beta = make_double2((double)(static_cast<std::complex<double>>(blas_param.beta).real()),
                                      (double)(static_cast<std::complex<double>>(blas_param.beta).imag()));

          cublasStatus_t error;
          if (batch > 1) {
            error = cublasZgemmStridedBatched(handle, trans_a, trans_b, blas_param.m, blas_param.n, blas_param.k,
                                              &alpha, (Z *)A_d + blas_param.a_offset, blas_param.lda, a_stride,
                                              (Z *)B_d + blas_param.b_offset, blas_param.ldb, b_stride, &beta,
                                              (Z *)C_d + blas_param.c_offset, blas_param.ldc, c_stride, batch);

            if (error != CUBLAS_STATUS_SUCCESS)
              errorQuda("\nError in cublasZgemmStridedBatched, error code = %d\n", error);
          } else {
            error = cublasZgemm(handle, trans_a, trans_b, blas_param.m, blas_param.n, blas_param.k, &alpha,
                                (Z *)A_d + blas_param.a_offset, blas_param.lda, (Z *)B_d + blas_param.b_offset,
                                blas_param.ldb, &beta, (Z *)C_d + blas_param.c_offset, blas_param.ldc);

            if (error != CUBLAS_STATUS_SUCCESS) errorQuda("\nError in cublasZgemm, error code = %d\n", error);
          }
        } else if (blas_param.data_type == QUDA_BLAS_DATATYPE_C) {

          typedef cuFloatComplex C;

          const C alpha = make_float2((float)(static_cast<std::complex<double>>(blas_param.alpha).real()),
                                      (float)(static_cast<std::complex<double>>(blas_param.alpha).imag()));

          const C beta = make_float2((float)(static_cast<std::complex<double>>(blas_param.beta).real()),
                                     (float)(static_cast<std::complex<double>>(blas_param.beta).imag()));

          cublasStatus_t error;
          if (batch > 1) {
            error = cublasCgemmStridedBatched(handle, trans_a, trans_b, blas_param.m, blas_param.n, blas_param.k,
                                              &alpha, (C *)A_d + blas_param.a_offset, blas_param.lda, a_stride,
                                              (C *)B_d + blas_param.b_offset, blas_param.ldb, b_stride, &beta,
                                              (C *)C_d + blas_param.c_offset, blas_param.ldc, c_stride, batch);

            if (error != CUBLAS_STATUS_SUCCESS)
              errorQuda("\nError in cublasCgemmStridedBatched, error code = %d\n", error);
          } else {
            error = cublasCgemm(handle, trans_a, trans_b, blas_param.m, blas_param.n, blas_param.k, &alpha,
                                (C *)A_d + blas_param.a_offset, blas_param.lda, (C *)B_d + blas_param.b_offset,
                                blas_param.ldb, &beta, (C *)C_d + blas_param.c_offset, blas_param.ldc);

            if (error != CUBLAS_STATUS_SUCCESS) errorQuda("\nError in cublasCgemm, error code = %d\n", error);
          }
        } else if (blas_param.data_type == QUDA_BLAS_DATATYPE_D) {

          typedef double D;

          const D alpha = (D)(static_cast<std::complex<double>>(blas_param.alpha).real());
          const D beta = (D)(static_cast<std::complex<double>>(blas_param.beta).real());

          cublasStatus_t error;
          if (batch > 1) {
            error = cublasDgemmStridedBatched(handle, trans_a, trans_b, blas_param.m, blas_param.n, blas_param.k,
                                              &alpha, (D *)A_d + blas_param.a_offset, blas_param.lda, a_stride,
                                              (D *)B_d + blas_param.b_offset, blas_param.ldb, b_stride, &beta,
                                              (D *)C_d + blas_param.c_offset, blas_param.ldc, c_stride, batch);

            if (error != CUBLAS_STATUS_SUCCESS)
              errorQuda("\nError in cublasDgemmStridedBatched, error code = %d\n", error);
          } else {
            error = cublasDgemm(handle, trans_a, trans_b, blas_param.m, blas_param.n, blas_param.k, &alpha,
                                (D *)A_d + blas_param.a_offset, blas_param.lda, (D *)B_d + blas_param.b_offset,
                                blas_param.ldb, &beta, (D *)C_d + blas_param.c_offset, blas_param.ldc);

            if (error != CUBLAS_STATUS_SUCCESS) errorQuda("\nError in cublasDgemm, error code = %d\n", error);
          }
        } else if (blas_param.data_type == QUDA_BLAS_DATATYPE_S) {

          typedef float S;

          const S alpha = (S)(static_cast<std::complex<float>>(blas_param.alpha).real());
          const S beta = (S)(static_cast<std::complex<float>>(blas_param.beta).real());

          cublasStatus_t error;
          if (batch > 1) {
            error = cublasSgemmStridedBatched(handle, trans_a, trans_b, blas_param.m, blas_param.n, blas_param.k,
                                              &alpha, (S *)A_d + blas_param.a_offset, blas_param.lda, a_stride,
                                              (S *)B_d + blas_param.b_offset, blas_param.ldb, b_stride, &beta,
                                              (S *)C_d + blas_param.c_offset, blas_param.ldc, c_stride, batch);

            if (error != CUBLAS_STATUS_SUCCESS)
              errorQuda("\nError in cublasSgemmStridedBatched, error code = %d\n", error);
          } else {
            error = cublasSgemm(handle, trans_a, trans_b, blas_param.m, blas_param.n, blas_param.k, &alpha,
                                (S *)A_d + blas_param.a_offset, blas_param.lda, (S *)B_d + blas_param.b_offset,
                                blas_param.ldb, &beta, (S *)C_d + blas_param.c_offset, blas_param.ldc);

            if (error != CUBLAS_STATUS_SUCCESS) errorQuda("\nError in cublasSgemm, error code = %d\n", error);
          }
        } else {
          errorQuda("cublasGEMM type %d not implemented\n", blas_param.data_type);
        }
        //-------------------------------------------------------------------------

        // Clean up
        //-------------------------------------------------------------------------
        if (blas_param.data_order == QUDA_BLAS_DATAORDER_ROW) {
          std::swap(blas_param.m, blas_param.n);
          std::swap(blas_param.lda, blas_param.ldb);
          std::swap(blas_param.trans_a, blas_param.trans_b);
          std::swap(blas_param.a_offset, blas_param.b_offset);
          std::swap(blas_param.a_stride, blas_param.b_stride);
          std::swap(A_data, B_data);
        }

        if (location == QUDA_CPU_FIELD_LOCATION) {
          qudaMemcpy(C_data, C_d, sizeCarr, cudaMemcpyDeviceToHost);
          pool_device_free(A_d);
          pool_device_free(B_d);
          pool_device_free(C_d);
        }

        qudaDeviceSynchronize();
        gettimeofday(&stop, NULL);
        long ds = stop.tv_sec - start.tv_sec;
        long dus = stop.tv_usec - start.tv_usec;
        double time = ds + 0.000001 * dus;
        if (getVerbosity() >= QUDA_DEBUG_VERBOSE)
          printfQuda("Batched matrix GEMM completed in %f seconds with GFLOPS = %f\n", time, 1e-9 * flops / time);
        //-------------------------------------------------------------------------

        return flops;
#else
        errorQuda("Native BLAS not built. Please build and use native BLAS or use generic BLAS");
        return 0; // Stops a compiler warning
#endif
      }
    } // namespace native
  } // namespace blas_lapack
} // namespace quda
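For reference, the following is a minimal sketch of how these entry points might be driven from host code. It is illustrative only: the buffer sizes and parameter values are hypothetical, and it assumes a build with NATIVE_LAPACK_LIB defined plus the newQudaBLASParam() initializer and double_complex type from quda.h.

// Illustrative sketch, not part of the library.
#include <vector>
#include <complex>
#include <quda.h>        // QudaBLASParam, newQudaBLASParam(), enums
#include <blas_lapack.h> // quda::blas_lapack::native interface

void example()
{
  using namespace quda::blas_lapack::native;

  init(); // create the cuBLAS context

  // Invert 16 single-precision 3x3 colour matrices held on the host.
  const int n = 3;
  const int batches = 16;
  std::vector<std::complex<float>> A(batches * n * n), Ainv(batches * n * n);
  // ... fill A with non-singular matrices ...
  BatchInvertMatrix(Ainv.data(), A.data(), n, batches, QUDA_SINGLE_PRECISION, QUDA_CPU_FIELD_LOCATION);

  // Batched single-precision complex GEMM, C = A * B, on the same host data.
  QudaBLASParam blas_param = newQudaBLASParam();
  blas_param.trans_a = QUDA_BLAS_OP_N;
  blas_param.trans_b = QUDA_BLAS_OP_N;
  blas_param.m = n; blas_param.n = n; blas_param.k = n;
  blas_param.lda = n; blas_param.ldb = n; blas_param.ldc = n;
  blas_param.a_offset = blas_param.b_offset = blas_param.c_offset = 0;
  blas_param.a_stride = blas_param.b_stride = blas_param.c_stride = 0; // 0 -> regular batch size
  blas_param.alpha = double_complex(1.0, 0.0);
  blas_param.beta = double_complex(0.0, 0.0);
  blas_param.batch_count = batches;
  blas_param.data_type = QUDA_BLAS_DATATYPE_C;
  blas_param.data_order = QUDA_BLAS_DATAORDER_COL;

  std::vector<std::complex<float>> B(batches * n * n), C(batches * n * n);
  // ... fill B ...
  stridedBatchGEMM(A.data(), B.data(), C.data(), blas_param, QUDA_CPU_FIELD_LOCATION);

  destroy(); // release the cuBLAS context
}

In practice these routines are usually reached through QUDA's higher-level interface after library initialization; calling them directly as above assumes a CUDA context is already in place.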