QUDA  v1.1.0
A library for QCD on GPUs
blas_lapack.h
Go to the documentation of this file.
1 #include <quda_internal.h>
2 
3 #pragma once
4 
/*
  Operation counts for LU factorization (GETRF) of an m x n matrix and for
  matrix inversion from an LU factorization (GETRI) of an n x n matrix.

  FMULS_* counts multiplications and FADDS_* counts additions. The FLOPS_Z*
  and FLOPS_C* wrappers cast their arguments to double and combine the two
  counts as 6 * muls + 2 * adds (the real-flop cost of a complex multiply and
  a complex add). The Z and C variants (LAPACK-style double / single complex
  prefixes) therefore share one formula.

  The GETRF counts depend only on which dimension is smaller: both branches
  of the conditional evaluate the same expression with the roles of m and n
  swapped. The QUDA_*_GETRF_MINMAX_ helpers hold that expression once, with
  k_ the smaller and l_ the other dimension (k_ = n_ when m_ == n_).

  As with any function-like macro, arguments are evaluated more than once,
  so do not pass expressions with side effects.
*/
#define QUDA_FMULS_GETRF_MINMAX_(k_, l_) \
  (0.5 * (k_) * ((k_) * ((l_) - (1. / 3.) * (k_) - 1.) + (l_)) + (2. / 3.) * (k_))
#define QUDA_FADDS_GETRF_MINMAX_(k_, l_) \
  (0.5 * (k_) * ((k_) * ((l_) - (1. / 3.) * (k_)) - (l_)) + (1. / 6.) * (k_))

#define FMULS_GETRF(m_, n_) \
  (((m_) < (n_)) ? QUDA_FMULS_GETRF_MINMAX_(m_, n_) : QUDA_FMULS_GETRF_MINMAX_(n_, m_))
#define FADDS_GETRF(m_, n_) \
  (((m_) < (n_)) ? QUDA_FADDS_GETRF_MINMAX_(m_, n_) : QUDA_FADDS_GETRF_MINMAX_(n_, m_))

#define FLOPS_ZGETRF(m_, n_) \
  (6. * FMULS_GETRF((double)(m_), (double)(n_)) + 2.0 * FADDS_GETRF((double)(m_), (double)(n_)))
// Single-complex count is the same expression as the double-complex one.
#define FLOPS_CGETRF(m_, n_) FLOPS_ZGETRF(m_, n_)

#define FMULS_GETRI(n_) ((n_) * ((5. / 6.) + (n_) * ((2. / 3.) * (n_) + 0.5)))
#define FADDS_GETRI(n_) ((n_) * ((5. / 6.) + (n_) * ((2. / 3.) * (n_) - 1.5)))

#define FLOPS_ZGETRI(n_) (6. * FMULS_GETRI((double)(n_)) + 2.0 * FADDS_GETRI((double)(n_)))
// Single-complex count is the same expression as the double-complex one.
#define FLOPS_CGETRI(n_) FLOPS_ZGETRI(n_)
22 
23 namespace quda
24 {
25 
26  namespace blas_lapack
27  {
28 
29  bool use_native();
30  void set_native(bool native);
31 
37  namespace native
38  {
39 
43  void init();
44 
48  void destroy();
49 
60  long long BatchInvertMatrix(void *Ainv, void *A, const int n, const uint64_t batch, QudaPrecision precision,
61  QudaFieldLocation location);
62 
91  long long stridedBatchGEMM(void *A, void *B, void *C, QudaBLASParam blas_param, QudaFieldLocation location);
92 
93  } // namespace native
94 
100  namespace generic
101  {
102 
106  void init();
107 
111  void destroy();
112 
123  long long BatchInvertMatrix(void *Ainv, void *A, const int n, const uint64_t batch, QudaPrecision precision,
124  QudaFieldLocation location);
125 
154  long long stridedBatchGEMM(void *A, void *B, void *C, QudaBLASParam blas_param, QudaFieldLocation location);
155 
156  } // namespace generic
157  } // namespace blas_lapack
158 } // namespace quda
enum QudaPrecision_s QudaPrecision
enum QudaFieldLocation_s QudaFieldLocation
void init()
Create the BLAS context.
long long stridedBatchGEMM(void *A, void *B, void *C, QudaBLASParam blas_param, QudaFieldLocation location)
Strided batch GEMM. This function performs N GEMM-type operations in a strided, batched fashion. ...
long long BatchInvertMatrix(void *Ainv, void *A, const int n, const uint64_t batch, QudaPrecision precision, QudaFieldLocation location)
Batch inversion of the matrix field using an LU decomposition method.
void destroy()
Destroy the BLAS context.
long long BatchInvertMatrix(void *Ainv, void *A, const int n, const uint64_t batch, QudaPrecision precision, QudaFieldLocation location)
Batch inversion of the matrix field using an LU decomposition method.
void init()
Create the BLAS context.
long long stridedBatchGEMM(void *A, void *B, void *C, QudaBLASParam blas_param, QudaFieldLocation location)
Strided batch GEMM. This function performs N GEMM-type operations in a strided, batched fashion. ...
void destroy()
Destroy the BLAS context.
void set_native(bool native)