#ifdef NATIVE_LAPACK_LIB
// ...

#ifdef NATIVE_LAPACK_LIB
  // cuBLAS context, created once and shared by all native BLAS/LAPACK entry points
  static cublasHandle_t handle;
  static bool cublas_init = false;
#ifdef NATIVE_LAPACK_LIB
      cublasStatus_t error = cublasCreate(&handle);
      if (error != CUBLAS_STATUS_SUCCESS) errorQuda("cublasCreate failed with error %d", error);
#ifdef NATIVE_LAPACK_LIB
      cublasStatus_t error = cublasDestroy(handle);
      if (error != CUBLAS_STATUS_SUCCESS)
        errorQuda("\nError in destroying cublas context, error code = %d\n", error);
  template <typename EigenMatrix, typename Float>
  __host__ void checkEigen(std::complex<Float> *A_h, std::complex<Float> *Ainv_h, int n, uint64_t batch)
  {
    EigenMatrix A = EigenMatrix::Zero(n, n);
    EigenMatrix Ainv = EigenMatrix::Zero(n, n);
    // Copy the batch-th n x n matrix and its candidate inverse (column-major layout)
    for (int j = 0; j < n; j++) {
      for (int k = 0; k < n; k++) {
        A(k, j) = A_h[batch * n * n + j * n + k];
        Ainv(k, j) = Ainv_h[batch * n * n + j * n + k];
      }
    }

    // Check the result: the Frobenius norm of (A * Ainv - I), normalised by the
    // number of matrix entries, should be close to machine epsilon
    EigenMatrix unit = EigenMatrix::Identity(n, n);
    EigenMatrix prod = A * Ainv;
    Float L2norm = ((prod - unit).norm() / (n * n));
    printfQuda("cuBLAS: Norm of (A * Ainv - I) batch %lu = %e\n", batch, L2norm);
  }
#ifdef NATIVE_LAPACK_LIB
    printfQuda("BatchInvertMatrix (native - cuBLAS): Nc = %d, batch = %lu\n", n, batch);

    gettimeofday(&start, NULL);

    // Bytes per batched array: factor of 2 for the real and imaginary parts
    size_t size = 2 * n * n * prec * batch;
      // Host-side view of the input, used later for the Eigen verification: pinned
      // host memory if the data lives on the device, otherwise the input directly
      std::complex<float> *A_h
        = (location == QUDA_CUDA_FIELD_LOCATION ? static_cast<std::complex<float> *>(pool_pinned_malloc(size)) :
                                                  static_cast<std::complex<float> *>(A_d));

      // Zero-initialise the host info array (the original passed the character '0',
      // i.e. byte 0x30, to memset; integer 0 is what is meant here)
      memset(info_array, 0, batch * sizeof(int));
      typedef cuFloatComplex C;
      C **Ainv_array = A_array + batch;
      C **Ainv_array_h = A_array_h + batch;
      // Build the per-matrix pointer tables expected by the batched cuBLAS routines
      for (uint64_t i = 0; i < batch; i++) {
        A_array_h[i] = static_cast<C *>(A_d) + i * n * n;
        Ainv_array_h[i] = static_cast<C *>(Ainv_d) + i * n * n;
      }
      qudaMemcpy(A_array, A_array_h, 2 * batch * sizeof(C *), cudaMemcpyHostToDevice);
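      // Note (editorial): the cuBLAS batched LU routines (cublasCgetrfBatched /
      // cublasCgetriBatched) take an array of device pointers, one per matrix,
      // rather than a base pointer plus stride. The loop above carves the
      // contiguous A_d and Ainv_d allocations into per-matrix pointers on the
      // host; since Ainv_array_h sits immediately after A_array_h, a single
      // qudaMemcpy of 2 * batch pointers publishes both tables to the device.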
      cublasStatus_t error = cublasCgetrfBatched(handle, n, A_array, n, dipiv, dinfo_array, batch);

      if (error != CUBLAS_STATUS_SUCCESS)
        errorQuda("\nError in LU decomposition (cublasCgetrfBatched), error code = %d\n", error);
      qudaMemcpy(info_array, dinfo_array, batch * sizeof(int), cudaMemcpyDeviceToHost);
      for (uint64_t i = 0; i < batch; i++) {
        if (info_array[i] < 0) {
          errorQuda("%lu argument had an illegal value or another error occurred, such as memory allocation failed", i);
        } else if (info_array[i] > 0) {
          errorQuda("%lu factorization completed but the factor U is exactly singular", i);
        }
      }
      error = cublasCgetriBatched(handle, n, (const C **)A_array, n, dipiv, Ainv_array, n, dinfo_array, batch);

      if (error != CUBLAS_STATUS_SUCCESS)
        errorQuda("\nError in matrix inversion (cublasCgetriBatched), error code = %d\n", error);
      qudaMemcpy(info_array, dinfo_array, batch * sizeof(int), cudaMemcpyDeviceToHost);

      for (uint64_t i = 0; i < batch; i++) {
        if (info_array[i] < 0) {
          errorQuda("%lu argument had an illegal value or another error occurred, such as memory allocation failed", i);
        } else if (info_array[i] > 0) {
          errorQuda("%lu factorization completed but the factor U is exactly singular", i);
        }
      }
      std::complex<float> *Ainv_h = static_cast<std::complex<float> *>(pool_pinned_malloc(size));
      qudaMemcpy((void *)Ainv_h, Ainv_d, size, cudaMemcpyDeviceToHost);

      // Verify each inverse on the host with Eigen
      for (uint64_t i = 0; i < batch; i++) { checkEigen<MatrixXcf, float>(A_h, Ainv_h, n, i); }
      errorQuda("%s not implemented for precision=%d", __func__, prec);

    // ...
    qudaMemcpy(Ainv, Ainv_d, size, cudaMemcpyDeviceToHost);

    // ...
    gettimeofday(&stop, NULL);
    double time = ds + 0.000001 * dus;
    printfQuda("Batched matrix inversion completed in %f seconds with GFLOPS = %f\n", time, 1e-9 * flops / time);

#else
    errorQuda("Native BLAS not built. Please build and use native BLAS or use generic BLAS");
#endif
#ifdef NATIVE_LAPACK_LIB
    gettimeofday(&start, NULL);
    // Sanity checks on the BLAS parameters: all of m, n, k must be positive
    int min_dim = std::min(blas_param.m, std::min(blas_param.n, blas_param.k));
    if (min_dim <= 0) {
      errorQuda("BLAS dims must be positive: m=%d, n=%d, k=%d", blas_param.m, blas_param.n, blas_param.k);
    }

    if (min_stride < 0) {
      errorQuda("BLAS strides must be positive or zero: a_stride=%d, b_stride=%d, c_stride=%d", blas_param.a_stride,
                blas_param.b_stride, blas_param.c_stride);
    }

    if (min_offset < 0) {
      errorQuda("BLAS offsets must be positive or zero: a_offset=%d, b_offset=%d, c_offset=%d", blas_param.a_offset,
                blas_param.b_offset, blas_param.c_offset);
    }

    // Leading-dimension checks. For column-major data, op(A) is m x k, op(B) is
    // k x n and C is m x n, so each leading dimension must cover the stored rows
    // of its operand; the row-major branch swaps the roles of m/k and k/n.
    if (blas_param.data_order == QUDA_BLAS_DATAORDER_COL) {
      if (blas_param.trans_a == QUDA_BLAS_OP_N) {
        if (blas_param.lda < std::max(1, blas_param.m))
          errorQuda("lda=%d must be >= max(1,m=%d)", blas_param.lda, blas_param.m);
      } else {
        if (blas_param.lda < std::max(1, blas_param.k))
          errorQuda("lda=%d must be >= max(1,k=%d)", blas_param.lda, blas_param.k);
      }
      if (blas_param.trans_b == QUDA_BLAS_OP_N) {
        if (blas_param.ldb < std::max(1, blas_param.k))
          errorQuda("ldb=%d must be >= max(1,k=%d)", blas_param.ldb, blas_param.k);
      } else {
        if (blas_param.ldb < std::max(1, blas_param.n))
          errorQuda("ldb=%d must be >= max(1,n=%d)", blas_param.ldb, blas_param.n);
      }
      if (blas_param.ldc < std::max(1, blas_param.m))
        errorQuda("ldc=%d must be >= max(1,m=%d)", blas_param.ldc, blas_param.m);
    } else {
      // Row-major data order
      if (blas_param.trans_a == QUDA_BLAS_OP_N) {
        if (blas_param.lda < std::max(1, blas_param.k))
          errorQuda("lda=%d must be >= max(1,k=%d)", blas_param.lda, blas_param.k);
      } else {
        if (blas_param.lda < std::max(1, blas_param.m))
          errorQuda("lda=%d must be >= max(1,m=%d)", blas_param.lda, blas_param.m);
      }
      if (blas_param.trans_b == QUDA_BLAS_OP_N) {
        if (blas_param.ldb < std::max(1, blas_param.n))
          errorQuda("ldb=%d must be >= max(1,n=%d)", blas_param.ldb, blas_param.n);
      } else {
        if (blas_param.ldb < std::max(1, blas_param.k))
          errorQuda("ldb=%d must be >= max(1,k=%d)", blas_param.ldb, blas_param.k);
      }
      if (blas_param.ldc < std::max(1, blas_param.n))
        errorQuda("ldc=%d must be >= max(1,n=%d)", blas_param.ldc, blas_param.n);
    }
    // A max_stride of 0 (all user strides zero) is promoted to 1 so the batch
    // count is left unchanged
    if (max_stride == 0) max_stride = 1;

    // Number of strided-batched GEMM launches
    const uint64_t batch = blas_param.batch_count / max_stride;

    // Per-matrix footprint of each operand, in elements
    unsigned int A_batch_size = blas_param.lda * blas_param.k;
    unsigned int B_batch_size = blas_param.ldb * blas_param.n;
    unsigned int C_batch_size = blas_param.ldc * blas_param.n;

    // Element strides between consecutive batch matrices: a user stride of 0
    // advances by exactly one matrix
    unsigned int a_stride = blas_param.a_stride == 0 ? A_batch_size : A_batch_size * blas_param.a_stride;
    unsigned int b_stride = blas_param.b_stride == 0 ? B_batch_size : B_batch_size * blas_param.b_stride;
    unsigned int c_stride = blas_param.c_stride == 0 ? C_batch_size : C_batch_size * blas_param.c_stride;

    // Total transfer sizes in bytes
    size_t sizeAarr = A_batch_size * data_size * batch;
    size_t sizeBarr = B_batch_size * data_size * batch;
    size_t sizeCarr = C_batch_size * data_size * batch;

    // Stage the host operands on the device
    qudaMemcpy(A_d, A_data, sizeAarr, cudaMemcpyHostToDevice);
    qudaMemcpy(B_d, B_data, sizeBarr, cudaMemcpyHostToDevice);
    qudaMemcpy(C_d, C_data, sizeCarr, cudaMemcpyHostToDevice);
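    // Worked example (editorial, not from the original source): for 16 x 16
    // column-major matrices with lda = ldb = ldc = 16, batch_count = 64 and all
    // user strides left at 0, max_stride is promoted to 1, so batch = 64, each
    // *_batch_size is 256 elements, and GEMM i reads A_d + i * 256 and
    // B_d + i * 256 and updates C_d + i * 256.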
    // Default to non-transposed operands; overridden below according to
    // blas_param.trans_a / blas_param.trans_b
    cublasOperation_t trans_a = CUBLAS_OP_N;
    cublasOperation_t trans_b = CUBLAS_OP_N;
      typedef cuDoubleComplex Z;

      const Z alpha = make_double2((double)(static_cast<std::complex<double>>(blas_param.alpha).real()),
                                   (double)(static_cast<std::complex<double>>(blas_param.alpha).imag()));

      const Z beta = make_double2((double)(static_cast<std::complex<double>>(blas_param.beta).real()),
                                  (double)(static_cast<std::complex<double>>(blas_param.beta).imag()));
      cublasStatus_t error;

      error = cublasZgemmStridedBatched(handle, trans_a, trans_b, blas_param.m, blas_param.n, blas_param.k,
                                        &alpha, (Z *)A_d + blas_param.a_offset, blas_param.lda, a_stride,
                                        (Z *)B_d + blas_param.b_offset, blas_param.ldb, b_stride, &beta,
                                        (Z *)C_d + blas_param.c_offset, blas_param.ldc, c_stride, batch);

      if (error != CUBLAS_STATUS_SUCCESS)
        errorQuda("\nError in cuBLASZGEMMStridedBatched, error code = %d\n", error);
      error = cublasZgemm(handle, trans_a, trans_b, blas_param.m, blas_param.n, blas_param.k, &alpha,
                          (Z *)A_d + blas_param.a_offset, blas_param.lda, (Z *)B_d + blas_param.b_offset,
                          blas_param.ldb, &beta, (Z *)C_d + blas_param.c_offset, blas_param.ldc);

      if (error != CUBLAS_STATUS_SUCCESS) errorQuda("\nError in cuBLASZGEMM, error code = %d\n", error);
      typedef cuFloatComplex C;

      const C alpha = make_float2((float)(static_cast<std::complex<double>>(blas_param.alpha).real()),
                                  (float)(static_cast<std::complex<double>>(blas_param.alpha).imag()));

      const C beta = make_float2((float)(static_cast<std::complex<double>>(blas_param.beta).real()),
                                 (float)(static_cast<std::complex<double>>(blas_param.beta).imag()));
      cublasStatus_t error;

      error = cublasCgemmStridedBatched(handle, trans_a, trans_b, blas_param.m, blas_param.n, blas_param.k,
                                        &alpha, (C *)A_d + blas_param.a_offset, blas_param.lda, a_stride,
                                        (C *)B_d + blas_param.b_offset, blas_param.ldb, b_stride, &beta,
                                        (C *)C_d + blas_param.c_offset, blas_param.ldc, c_stride, batch);

      if (error != CUBLAS_STATUS_SUCCESS)
        errorQuda("\nError in cuBLASCGEMMStridedBatched, error code = %d\n", error);
      error = cublasCgemm(handle, trans_a, trans_b, blas_param.m, blas_param.n, blas_param.k, &alpha,
                          (C *)A_d + blas_param.a_offset, blas_param.lda, (C *)B_d + blas_param.b_offset,
                          blas_param.ldb, &beta, (C *)C_d + blas_param.c_offset, blas_param.ldc);

      if (error != CUBLAS_STATUS_SUCCESS) errorQuda("\nError in cuBLASCGEMM, error code = %d\n", error);
      typedef double D;

      const D alpha = (D)(static_cast<std::complex<double>>(blas_param.alpha).real());
      const D beta = (D)(static_cast<std::complex<double>>(blas_param.beta).real());
      cublasStatus_t error;

      error = cublasDgemmStridedBatched(handle, trans_a, trans_b, blas_param.m, blas_param.n, blas_param.k,
                                        &alpha, (D *)A_d + blas_param.a_offset, blas_param.lda, a_stride,
                                        (D *)B_d + blas_param.b_offset, blas_param.ldb, b_stride, &beta,
                                        (D *)C_d + blas_param.c_offset, blas_param.ldc, c_stride, batch);

      if (error != CUBLAS_STATUS_SUCCESS)
        errorQuda("\nError in cuBLASDGEMMStridedBatched, error code = %d\n", error);
      error = cublasDgemm(handle, trans_a, trans_b, blas_param.m, blas_param.n, blas_param.k, &alpha,
                          (D *)A_d + blas_param.a_offset, blas_param.lda, (D *)B_d + blas_param.b_offset,
                          blas_param.ldb, &beta, (D *)C_d + blas_param.c_offset, blas_param.ldc);

      if (error != CUBLAS_STATUS_SUCCESS) errorQuda("\nError in cuBLASDGEMM, error code = %d\n", error);
      typedef float S;

      const S alpha = (S)(static_cast<std::complex<float>>(blas_param.alpha).real());
      const S beta = (S)(static_cast<std::complex<float>>(blas_param.beta).real());
      cublasStatus_t error;

      error = cublasSgemmStridedBatched(handle, trans_a, trans_b, blas_param.m, blas_param.n, blas_param.k,
                                        &alpha, (S *)A_d + blas_param.a_offset, blas_param.lda, a_stride,
                                        (S *)B_d + blas_param.b_offset, blas_param.ldb, b_stride, &beta,
                                        (S *)C_d + blas_param.c_offset, blas_param.ldc, c_stride, batch);

      if (error != CUBLAS_STATUS_SUCCESS)
        errorQuda("\nError in cuBLASSGEMMStridedBatched, error code = %d\n", error);
      error = cublasSgemm(handle, trans_a, trans_b, blas_param.m, blas_param.n, blas_param.k, &alpha,
                          (S *)A_d + blas_param.a_offset, blas_param.lda, (S *)B_d + blas_param.b_offset,
                          blas_param.ldb, &beta, (S *)C_d + blas_param.c_offset, blas_param.ldc);

      if (error != CUBLAS_STATUS_SUCCESS) errorQuda("\nError in cuBLASSGEMM, error code = %d\n", error);
    // Copy the result back to the host
    qudaMemcpy(C_data, C_d, sizeCarr, cudaMemcpyDeviceToHost);

    // ...
    gettimeofday(&stop, NULL);
    double time = ds + 0.000001 * dus;
    printfQuda("Batched matrix GEMM completed in %f seconds with GFLOPS = %f\n", time, 1e-9 * flops / time);

#else
    errorQuda("Native BLAS not built. Please build and use native BLAS or use generic BLAS");
#endif