static bool native_blas_lapack = true;

void set_native(bool native) { native_blas_lapack = native; }
template <typename EigenMatrix, typename Float>
void invertEigen(std::complex<Float> *A_eig, std::complex<Float> *Ainv_eig, int n, uint64_t batch)
{
  EigenMatrix res = EigenMatrix::Zero(n, n);
  EigenMatrix inv = EigenMatrix::Zero(n, n);

  // Unpack the batch'th input matrix (stored column-major in the flat array)
  // into an Eigen object
  for (int j = 0; j < n; j++) {
    for (int k = 0; k < n; k++) { res(k, j) = A_eig[batch * n * n + j * n + k]; }
  }

  // Invert, then pack the result back into the flat output array
  inv = res.inverse();
  for (int j = 0; j < n; j++) {
    for (int k = 0; k < n; k++) { Ainv_eig[batch * n * n + j * n + k] = inv(k, j); }
  }

  // Check the result: ||A * Ainv - I|| should be close to zero
  EigenMatrix unit = EigenMatrix::Identity(n, n);
  EigenMatrix prod = res * inv;
  Float L2norm = ((prod - unit).norm() / (n * n));
  printfQuda("Eigen: Norm of (A * Ainv - I) batch %lu = %e\n", batch, L2norm);
}
long long BatchInvertMatrix(void *Ainv, void *A, const int n, const uint64_t batch, QudaPrecision prec,
                            QudaFieldLocation location)
{
  printfQuda("BatchInvertMatrix (generic - Eigen): Nc = %d, batch = %lu\n", n, batch);

  // Bytes per batched array: n * n complex entries per matrix at the given precision
  size_t size = 2 * n * n * batch * prec;

  // ... (host buffers A_h / Ainv_h are staged here, using pinned memory when the
  //      data resides on the device; elided in this fragment)

  long long flops = 0;
  timeval start, stop;
  gettimeofday(&start, NULL);

  if (prec == QUDA_SINGLE_PRECISION) {
    std::complex<float> *A_eig = (std::complex<float> *)A_h;
    std::complex<float> *Ainv_eig = (std::complex<float> *)Ainv_h;

#pragma omp parallel for
    for (uint64_t i = 0; i < batch; i++) { invertEigen<MatrixXcf, float>(A_eig, Ainv_eig, n, i); }
    flops += batch * FLOPS_CGETRF(n, n);
  } else if (prec == QUDA_DOUBLE_PRECISION) {
    std::complex<double> *A_eig = (std::complex<double> *)A_h;
    std::complex<double> *Ainv_eig = (std::complex<double> *)Ainv_h;

#pragma omp parallel for
    for (uint64_t i = 0; i < batch; i++) { invertEigen<MatrixXcd, double>(A_eig, Ainv_eig, n, i); }
    flops += batch * FLOPS_ZGETRF(n, n);
  } else {
    errorQuda("%s not implemented for precision = %d", __func__, prec);
  }

  gettimeofday(&stop, NULL);
  // ... (elapsed seconds dsh and microseconds dush derived from start/stop; elided)
  double timeh = dsh + 0.000001 * dush;

  int threads = 1;
#ifdef _OPENMP
  // omp_get_max_threads() reports the size of the thread pool available to the
  // parallel loop; omp_get_num_threads() would always return 1 here, outside
  // of a parallel region
  threads = omp_get_max_threads();
#endif
  printfQuda("CPU: Batched matrix inversion completed in %f seconds using %d threads with GFLOPS = %f\n", timeh,
             threads, 1e-9 * flops / timeh);

  qudaMemcpy((void *)Ainv, Ainv_h, size, cudaMemcpyHostToDevice);
  // ... (cleanup and return of the flop count; elided in this fragment)
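// Illustrative call (not part of the original file): how a caller might invoke
// BatchInvertMatrix on host-resident data. QUDA_DOUBLE_PRECISION and
// QUDA_CPU_FIELD_LOCATION are real QUDA enums; the buffers and sizes below are
// hypothetical.
//
//   const int n = 3;              // e.g. 3x3 colour matrices
//   const uint64_t batch = 1024;  // number of matrices to invert
//   std::vector<std::complex<double>> A(batch * n * n), Ainv(batch * n * n);
//   // ... fill A ...
//   long long flops = BatchInvertMatrix(Ainv.data(), A.data(), n, batch,
//                                       QUDA_DOUBLE_PRECISION, QUDA_CPU_FIELD_LOCATION);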
template <typename EigenMat, typename T>
void fillArray(EigenMat &EigenArr, T *arr, int rows, int cols, int ld, int offset, bool fill_eigen)
{
  // Walk rows x cols entries of a flat array with leading dimension ld, either
  // filling the Eigen object from the array or writing it back out
  int counter = offset;
  for (int i = 0; i < rows; i++) {
    for (int j = 0; j < cols; j++) {
      if (fill_eigen)
        EigenArr(i, j) = arr[counter];
      else
        arr[counter] = EigenArr(i, j);
      counter++;
    }
    counter += (ld - cols); // skip the padding at the end of each row
  }
}
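// Illustrative check (not part of the original file): fillArray reads `cols`
// values per row and then skips `ld - cols` padding entries, so a matrix
// embedded in a larger leading dimension round-trips exactly. A standalone
// sketch of that traversal with plain doubles and Eigen::MatrixXd:
#include <Eigen/Dense>
#include <cstdio>
#include <vector>

int main()
{
  const int rows = 2, cols = 3, ld = 4, offset = 1;
  // Flat array: one offset entry, then two rows of three values plus one
  // padding entry each (padding marked with -1)
  std::vector<double> arr = {-1, 1, 2, 3, -1, 4, 5, 6, -1};

  Eigen::MatrixXd M(rows, cols);
  int counter = offset;
  for (int i = 0; i < rows; i++) {
    for (int j = 0; j < cols; j++) M(i, j) = arr[counter++];
    counter += (ld - cols); // skip the padding between rows
  }
  std::printf("M(1, 2) = %g (expect 6)\n", M(1, 2));
  return 0;
}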
template <typename EigenMat, typename T>
void GEMM(void *A_h, void *B_h, void *C_h, T alpha, T beta, int max_stride, QudaBLASParam &blas_param)
{
  // Problem dimensions and leading dimensions
  int m = blas_param.m;
  int n = blas_param.n;
  int k = blas_param.k;
  int lda = blas_param.lda;
  int ldb = blas_param.ldb;
  int ldc = blas_param.ldc;

  // ... (batches, strides, and initial a/b/c offsets are set up here; elided)

  // Elements between consecutive matrices in each batched array
  // (adjustments for transposed ops are on the elided lines)
  unsigned int A_batch_size = blas_param.lda * blas_param.k;
  unsigned int B_batch_size = blas_param.ldb * blas_param.n;
  unsigned int C_batch_size = blas_param.ldc * blas_param.n;

  T *A_ptr = (T *)A_h; // equivalent to the original (T *)(&A_h)[0]
  T *B_ptr = (T *)B_h;
  T *C_ptr = (T *)C_h;

  EigenMat Amat = EigenMat::Zero(m, k);
  EigenMat Bmat = EigenMat::Zero(k, n);
  EigenMat Cmat = EigenMat::Zero(m, n);

  for (int batch = 0; batch < batches; batch += max_stride) {
    // Populate the Eigen objects from the flat host arrays
    fillArray<EigenMat, T>(Amat, A_ptr, m, k, lda, a_offset, true);
    fillArray<EigenMat, T>(Bmat, B_ptr, k, n, ldb, b_offset, true);
    fillArray<EigenMat, T>(Cmat, C_ptr, m, n, ldc, c_offset, true);

    // ... (transpose/conjugate handling per trans_a/trans_b; elided)

    // Perform the GEMM: C <- alpha * A * B + beta * C
    Cmat = alpha * Amat * Bmat + beta * Cmat;

    // Write the result back into the flat C array
    fillArray<EigenMat, T>(Cmat, C_ptr, m, n, ldc, c_offset, false);

    // Advance to the next matrices in the batch
    a_offset += A_batch_size * a_stride;
    b_offset += B_batch_size * b_stride;
    c_offset += C_batch_size * c_stride;
  }
}
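// Illustrative reference (not part of the original file): the Eigen expression
// above implements the standard GEMM contract C <- alpha*A*B + beta*C. A
// standalone comparison against a hand-rolled triple loop:
#include <Eigen/Dense>
#include <cstdio>

int main()
{
  using Eigen::MatrixXd;
  const int m = 2, k = 3, n = 2;
  const double alpha = 2.0, beta = 0.5;
  MatrixXd A = MatrixXd::Random(m, k), B = MatrixXd::Random(k, n);
  MatrixXd C = MatrixXd::Random(m, n), Cref = C;

  C = alpha * A * B + beta * C; // the expression used in GEMM above

  // Hand-rolled reference computation
  for (int i = 0; i < m; i++)
    for (int j = 0; j < n; j++) {
      double acc = 0.0;
      for (int l = 0; l < k; l++) acc += A(i, l) * B(l, j);
      Cref(i, j) = alpha * acc + beta * Cref(i, j);
    }
  std::printf("max |C - Cref| = %e\n", (C - Cref).cwiseAbs().maxCoeff());
  return 0;
}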
long long stridedBatchGEMM(void *A_data, void *B_data, void *C_data, QudaBLASParam blas_param,
                           QudaFieldLocation location)
{
  timeval start, stop;
  gettimeofday(&start, NULL);

  // Sanity checks on the BLAS parameters
  //-------------------------------------------------------------------------
  // Error out on non-positive dimensions
  int min_dim = std::min(blas_param.m, std::min(blas_param.n, blas_param.k));
  if (min_dim <= 0) {
    errorQuda("BLAS dims must be positive: m=%d, n=%d, k=%d", blas_param.m, blas_param.n, blas_param.k);
  }

  // Error out on negative strides
  int min_stride = std::min(std::min(blas_param.a_stride, blas_param.b_stride), blas_param.c_stride);
  if (min_stride < 0) {
    errorQuda("BLAS strides must be positive or zero: a_stride=%d, b_stride=%d, c_stride=%d", blas_param.a_stride,
              blas_param.b_stride, blas_param.c_stride);
  }

  // Error out on negative offsets
  int min_offset = std::min(std::min(blas_param.a_offset, blas_param.b_offset), blas_param.c_offset);
  if (min_offset < 0) {
    errorQuda("BLAS offsets must be positive or zero: a_offset=%d, b_offset=%d, c_offset=%d", blas_param.a_offset,
              blas_param.b_offset, blas_param.c_offset);
  }

  // Leading-dimension requirements depend on the data order and the op type
  if (blas_param.data_order == QUDA_BLAS_DATAORDER_COL) {
    if (blas_param.trans_a == QUDA_BLAS_OP_N) {
      if (blas_param.lda < std::max(1, blas_param.m))
        errorQuda("lda=%d must be >= max(1,m=%d)", blas_param.lda, blas_param.m);
    } else {
      if (blas_param.lda < std::max(1, blas_param.k))
        errorQuda("lda=%d must be >= max(1,k=%d)", blas_param.lda, blas_param.k);
    }

    if (blas_param.trans_b == QUDA_BLAS_OP_N) {
      if (blas_param.ldb < std::max(1, blas_param.k))
        errorQuda("ldb=%d must be >= max(1,k=%d)", blas_param.ldb, blas_param.k);
    } else {
      if (blas_param.ldb < std::max(1, blas_param.n))
        errorQuda("ldb=%d must be >= max(1,n=%d)", blas_param.ldb, blas_param.n);
    }

    if (blas_param.ldc < std::max(1, blas_param.m))
      errorQuda("ldc=%d must be >= max(1,m=%d)", blas_param.ldc, blas_param.m);
  } else { // row-major data order
    if (blas_param.trans_a == QUDA_BLAS_OP_N) {
      if (blas_param.lda < std::max(1, blas_param.k))
        errorQuda("lda=%d must be >= max(1,k=%d)", blas_param.lda, blas_param.k);
    } else {
      if (blas_param.lda < std::max(1, blas_param.m))
        errorQuda("lda=%d must be >= max(1,m=%d)", blas_param.lda, blas_param.m);
    }

    if (blas_param.trans_b == QUDA_BLAS_OP_N) {
      if (blas_param.ldb < std::max(1, blas_param.n))
        errorQuda("ldb=%d must be >= max(1,n=%d)", blas_param.ldb, blas_param.n);
    } else {
      if (blas_param.ldb < std::max(1, blas_param.k))
        errorQuda("ldb=%d must be >= max(1,k=%d)", blas_param.ldb, blas_param.k);
    }

    if (blas_param.ldc < std::max(1, blas_param.n))
      errorQuda("ldc=%d must be >= max(1,n=%d)", blas_param.ldc, blas_param.n);
  }

  // ... (max_stride derived from the a/b/c strides; elided). Ensure a positive
  // stride so the batch count below is well defined.
  if (max_stride <= 0) max_stride = 1;

  // The number of strided GEMM calls needed to cover the full batch
  const uint64_t batch = blas_param.batch_count / max_stride;

  // Elements between consecutive matrices in each batched array
  unsigned int A_batch_size = blas_param.lda * blas_param.k;
  unsigned int B_batch_size = blas_param.ldb * blas_param.n;
  unsigned int C_batch_size = blas_param.ldc * blas_param.n;

  // Host-array sizes in bytes (data_size, the element size for the given data
  // type, is computed on elided lines)
  size_t sizeAarr = A_batch_size * data_size * batch;
  size_t sizeBarr = B_batch_size * data_size * batch;
  size_t sizeCarr = C_batch_size * data_size * batch;

  // Stage the device data on the host (allocation of A_h, B_h, C_h elided)
  qudaMemcpy(A_h, A_data, sizeAarr, cudaMemcpyDeviceToHost);
  qudaMemcpy(B_h, B_data, sizeBarr, cudaMemcpyDeviceToHost);
  qudaMemcpy(C_h, C_data, sizeCarr, cudaMemcpyDeviceToHost);

  // Dispatch on the BLAS data type
  if (blas_param.data_type == QUDA_BLAS_DATATYPE_Z) {
    typedef std::complex<double> Z;
    const Z alpha = blas_param.alpha;
    const Z beta = blas_param.beta;
    GEMM<MatrixXcd, Z>(A_h, B_h, C_h, alpha, beta, max_stride, blas_param);
  } else if (blas_param.data_type == QUDA_BLAS_DATATYPE_C) {
    typedef std::complex<float> C;
    const C alpha = blas_param.alpha;
    const C beta = blas_param.beta;
    GEMM<MatrixXcf, C>(A_h, B_h, C_h, alpha, beta, max_stride, blas_param);
  } else if (blas_param.data_type == QUDA_BLAS_DATATYPE_D) {
    typedef double D;
    // Real GEMM uses only the real parts of the complex-valued alpha/beta
    const D alpha = (D)(static_cast<std::complex<double>>(blas_param.alpha).real());
    const D beta = (D)(static_cast<std::complex<double>>(blas_param.beta).real());
    GEMM<MatrixXd, D>(A_h, B_h, C_h, alpha, beta, max_stride, blas_param);
  } else if (blas_param.data_type == QUDA_BLAS_DATATYPE_S) {
    typedef float S;
    const S alpha = (S)(static_cast<std::complex<float>>(blas_param.alpha).real());
    const S beta = (S)(static_cast<std::complex<float>>(blas_param.beta).real());
    GEMM<MatrixXf, S>(A_h, B_h, C_h, alpha, beta, max_stride, blas_param);
  }

  // Copy the result back to the device
  qudaMemcpy(C_data, C_h, sizeCarr, cudaMemcpyHostToDevice);

  gettimeofday(&stop, NULL);
  // ... (elapsed seconds ds and microseconds dus derived from start/stop, and
  //      the total flop count `flops`; elided)
  double time = ds + 0.000001 * dus;
  printfQuda("Batched matrix GEMM completed in %f seconds with GFLOPS = %f\n", time, 1e-9 * flops / time);
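// Illustrative setup (not part of the original file): how a caller might fill a
// QudaBLASParam for this routine. The field names follow the struct accesses
// above; newQudaBLASParam() is assumed to follow QUDA's newQuda*Param()
// convention, and the specific values and device pointers are hypothetical.
//
//   QudaBLASParam blas_param = newQudaBLASParam();
//   blas_param.trans_a = QUDA_BLAS_OP_N;
//   blas_param.trans_b = QUDA_BLAS_OP_N;
//   blas_param.m = 64; blas_param.n = 64; blas_param.k = 64;
//   blas_param.lda = 64; blas_param.ldb = 64; blas_param.ldc = 64;
//   blas_param.a_offset = 0; blas_param.b_offset = 0; blas_param.c_offset = 0;
//   blas_param.a_stride = 1; blas_param.b_stride = 1; blas_param.c_stride = 1;
//   blas_param.alpha = {1.0, 0.0}; blas_param.beta = {0.0, 0.0};
//   blas_param.batch_count = 16;
//   blas_param.data_type = QUDA_BLAS_DATATYPE_Z;
//   blas_param.data_order = QUDA_BLAS_DATAORDER_COL;
//
//   long long flops = stridedBatchGEMM(A_dev, B_dev, C_dev, blas_param,
//                                      QUDA_CUDA_FIELD_LOCATION);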