// Operation counts used to report GFLOPS for the batched LU factorization
// (GETRF, m x n matrix) and inversion-from-LU (GETRI, n x n matrix) below.
// NOTE(review): these formulas appear to match the LAPACK Working Note 41
// counts as used in MAGMA/PLASMA flops.h -- verify against that reference.
//
// FMULS_* counts multiplications and FADDS_* counts additions. The branch on
// (m_) < (n_) selects the formula for the smaller dimension.
#define FMULS_GETRF(m_, n_) ( ((m_) < (n_)) \
    ? (0.5 * (m_) * ((m_) * ((n_) - (1./3.) * (m_) - 1. ) + (n_)) + (2. / 3.) * (m_)) \
    : (0.5 * (n_) * ((n_) * ((m_) - (1./3.) * (n_) - 1. ) + (m_)) + (2. / 3.) * (n_)) )
#define FADDS_GETRF(m_, n_) ( ((m_) < (n_)) \
    ? (0.5 * (m_) * ((m_) * ((n_) - (1./3.) * (m_) ) - (n_)) + (1. / 6.) * (m_)) \
    : (0.5 * (n_) * ((n_) * ((m_) - (1./3.) * (n_) ) - (m_)) + (1. / 6.) * (n_)) )

// Complex-arithmetic totals: each multiply is weighted by 6 real flops and each
// add by 2 real flops. The double-complex (Z) and single-complex (C) variants
// therefore share the same count; arguments are cast to double first.
#define FLOPS_ZGETRF(m_, n_) (6. * FMULS_GETRF((double)(m_), (double)(n_)) + 2.0 * FADDS_GETRF((double)(m_), (double)(n_)) )
#define FLOPS_CGETRF(m_, n_) (6. * FMULS_GETRF((double)(m_), (double)(n_)) + 2.0 * FADDS_GETRF((double)(m_), (double)(n_)) )

#define FMULS_GETRI(n_) ( (n_) * ((5. / 6.) + (n_) * ((2. / 3.) * (n_) + 0.5)) )
#define FADDS_GETRI(n_) ( (n_) * ((5. / 6.) + (n_) * ((2. / 3.) * (n_) - 1.5)) )

#define FLOPS_ZGETRI(n_) (6. * FMULS_GETRI((double)(n_)) + 2.0 * FADDS_GETRI((double)(n_)) )
#define FLOPS_CGETRI(n_) (6. * FMULS_GETRI((double)(n_)) + 2.0 * FADDS_GETRI((double)(n_)) )

// Shared cuBLAS handle. It is passed to cublasCreate / cublasDestroy and to the
// batched getrf/getri calls (presumably in init(), destroy() and
// BatchInvertMatrix respectively).
static cublasHandle_t handle;
// NOTE(review): this span looks like the body of init() ("Create the CUBLAS
// context." in the symbol list at the end of this listing). The function
// header and braces are missing, and the leading integers are line numbers
// from a rendered source listing, so it does not compile as-is.
// Creates the static cuBLAS handle and reports any non-success status
// through errorQuda.
33 cublasStatus_t error = cublasCreate(&handle);
34 if (error != CUBLAS_STATUS_SUCCESS)
errorQuda(
"cublasCreate failed with error %d", error);
// NOTE(review): this span looks like the body of destroy() ("Destroy the
// CUBLAS context." in the symbol list at the end of this listing). The header
// and braces are missing, and the leading integers are listing line numbers.
// Releases the static cuBLAS handle and reports any non-success status
// through errorQuda.
// NOTE(review): the error text reads "indestroying" (missing space).
40 cublasStatus_t error = cublasDestroy(handle);
41 if (error != CUBLAS_STATUS_SUCCESS)
errorQuda(
"\nError indestroying cublas context, error code = %d\n", error);
// Helper kernel that fills the pointer arrays required by the batched cuBLAS
// calls. NOTE(review): T appears to be a template parameter (the call site uses
// set_pointer<C>), but the template<typename T> header and the body braces are
// missing from this listing.
// Block b writes pointers to the b-th matrix of two contiguous buffers:
//   output_array_a[b] = input_a + b * batch_offset
//   output_array_b[b] = input_b + b * batch_offset
// Only blockIdx.x is used, so the expected launch is one block per matrix
// (see set_pointer<C><<<batch,1>>> below).
// NOTE(review): blockIdx.x is unsigned int, so blockIdx.x * batch_offset is
// evaluated in 32-bit unsigned arithmetic and could wrap once the total number
// of elements exceeds 2^32; consider (size_t)blockIdx.x * batch_offset.
47 __global__
void set_pointer(T **output_array_a, T *input_a, T **output_array_b, T *input_b,
int batch_offset)
49 output_array_a[blockIdx.x] = input_a + blockIdx.x * batch_offset;
50 output_array_b[blockIdx.x] = input_b + blockIdx.x * batch_offset;
// Fragment of BatchInvertMatrix (signature in the symbol list at the end of
// this listing): batched LU factorization followed by batched inversion of
// `batch` n x n complex matrices via cuBLAS, reporting wall-clock time and GFLOPS.
// NOTE(review): many lines of the original function are missing from this
// listing (allocations, branch headers, frees, closing braces, the flops
// accumulation), and the leading integers are listing line numbers; this span
// does not compile as-is.
// Wall-clock start timestamp; paired with the stop timestamp near the end.
59 gettimeofday(&start, NULL);
// Byte size of the matrix buffer.
// NOTE(review): n and batch are int and prec is an enum, so 2*n*n*prec*batch is
// evaluated in int and can overflow before widening to size_t for large
// batches; consider (size_t)2*n*n*prec*batch. The factor 2 presumably counts
// real+imaginary parts and prec presumably is bytes per real -- confirm.
61 size_t size = 2*n*n*prec*batch;
// Single-precision complex path (presumably inside an if-branch on prec that is
// not shown).
71 typedef cuFloatComplex C;
// Build per-matrix device pointer arrays for A and Ainv, one block per matrix,
// each matrix n*n elements apart.
// NOTE(review): no launch-error check (cudaGetLastError) is visible after this
// launch.
75 set_pointer<C><<<batch,1>>>(A_array, (C*)A_d, Ainv_array, (C*)Ainv_d, n*n);
// Batched LU factorization of A_array in place; pivot indices go to dipiv and
// per-matrix status codes to dinfo_array.
77 cublasStatus_t error = cublasCgetrfBatched(handle, n, A_array, n, dipiv, dinfo_array, batch);
80 if (error != CUBLAS_STATUS_SUCCESS)
81 errorQuda(
"\nError in LU decomposition (cublasCgetrfBatched), error code = %d\n", error);
// Copy per-matrix status codes to the host and reject any failed factorization.
// NOTE(review): the messages print the batch index i, not the info code. For
// info < 0 the offending argument is -info_array[i], and for info > 0 the zero
// pivot is U(info_array[i], info_array[i]); consider printing both.
83 qudaMemcpy(info_array, dinfo_array, batch*
sizeof(
int), cudaMemcpyDeviceToHost);
84 for (
int i=0; i<batch; i++) {
85 if (info_array[i] < 0) {
86 errorQuda(
"%d argument had an illegal value or another error occured, such as memory allocation failed", i);
87 }
else if (info_array[i] > 0) {
88 errorQuda(
"%d factorization completed but the factor U is exactly singular", i);
// Batched out-of-place inversion from the LU factors (A_array + dipiv) into
// Ainv_array.
92 error = cublasCgetriBatched(handle, n, (
const C**)A_array, n, dipiv, Ainv_array, n, dinfo_array, batch);
95 if (error != CUBLAS_STATUS_SUCCESS)
96 errorQuda(
"\nError in matrix inversion (cublasCgetriBatched), error code = %d\n", error);
// Same per-matrix status check as above, now for the inversion.
// NOTE(review): these messages are copied from the factorization check and
// still refer to the "factorization"; they also print the batch index i rather
// than the info code.
98 qudaMemcpy(info_array, dinfo_array, batch*
sizeof(
int), cudaMemcpyDeviceToHost);
100 for (
int i=0; i<batch; i++) {
101 if (info_array[i] < 0) {
102 errorQuda(
"%d argument had an illegal value or another error occured, such as memory allocation failed", i);
103 }
else if (info_array[i] > 0) {
104 errorQuda(
"%d factorization completed but the factor U is exactly singular", i);
// Reached for precisions without an implemented path (presumably the else
// branch of the precision dispatch).
112 errorQuda(
"%s not implemented for precision=%d", __func__, prec);
// Copy the inverses from the device buffer to the caller's Ainv (presumably
// only when location is a host location -- the guarding condition is not shown).
116 qudaMemcpy(Ainv, Ainv_d, size, cudaMemcpyDeviceToHost);
// Elapsed wall-clock seconds. A negative microsecond difference is handled
// correctly by combining it with the seconds difference.
126 gettimeofday(&stop, NULL);
127 long ds = stop.tv_sec - start.tv_sec;
128 long dus = stop.tv_usec - start.tv_usec;
129 double time = ds + 0.000001*dus;
// Report timing and GFLOPS; flops is accumulated in lines not shown.
// NOTE(review): if time evaluates to 0 the reported GFLOPS is inf/nan.
132 printfQuda(
"Batched matrix inversion completed in %f seconds with GFLOPS = %f\n", time, 1e-9 * flops / time);
// NOTE(review): the lines below appear to be symbol cross-reference residue
// from a documentation tool, not code. They consist of signatures without
// bodies or semicolons, interleaved with brief descriptions such as "Destroy
// the CUBLAS context.". The bodiless #define lines would define those names as
// empty macros if this text were compiled; this span should likely be removed
// from the source file.
#define qudaMemcpy(dst, src, count, kind)
#define pool_pinned_free(ptr)
enum QudaPrecision_s QudaPrecision
void destroy()
Destroy the CUBLAS context.
__global__ void set_pointer(T **output_array_a, T *input_a, T **output_array_b, T *input_b, int batch_offset)
QudaVerbosity getVerbosity()
#define qudaDeviceSynchronize()
long long BatchInvertMatrix(void *Ainv, void *A, const int n, const int batch, QudaPrecision precision, QudaFieldLocation location)
#define pool_device_malloc(size)
void init()
Create the CUBLAS context.
#define FLOPS_CGETRF(m_, n_)
#define pool_pinned_malloc(size)
enum QudaFieldLocation_s QudaFieldLocation
#define pool_device_free(ptr)