1 #include <blas_magma.h>
5 #include <quda_internal.h>
// Maximum of two values. Arguments and the whole expansion are parenthesized
// so the macro is safe inside larger expressions (e.g. 2 * MAX(a, b)); the
// original form `(a > b) ? a : b;` lacked parentheses and carried a trailing
// semicolon that broke expression contexts. NOTE: arguments are still
// evaluated twice — do not pass expressions with side effects.
#define MAX(a, b) ((a) > (b) ? (a) : (b))
11 #define MAGMA_17 // default version of the MAGMA library
// Short aliases for MAGMA enum constants (matrix fill / side / transpose /
// eigenvector options) used throughout the MAGMA call sites below.
32 #define _cU MagmaUpper
34 #define _cR MagmaRight
37 #define _cC MagmaConjTrans
38 #define _cN MagmaNoTrans
40 #define _cNV MagmaNoVec
46 //Column major format: Big matrix times Little matrix.
// Initialize the MAGMA runtime, aborting via errorQuda on failure.
// NOTE(review): this chunk elides the enclosing function signature and what
// is presumably a MAGMA_LIB-style #ifdef guard — confirm against the full file.
54 magma_int_t err = magma_init();
56 if(err != MAGMA_SUCCESS) errorQuda("\nError: cannot initialize MAGMA library\n");
// Query the linked MAGMA library's version components.
58 int major, minor, micro;
// Query the MAGMA version triple declared just above; `minor` is reported by
// the printfQuda on the next line. (Repaired mojibake: "&micro" had been
// HTML-entity-decoded into "µ", dropping the &minor/&micro arguments.)
magma_version(&major, &minor, &micro);
61 printfQuda("\nMAGMA library version: %d.%d\n\n", major, minor);
// Presumably the #else (MAGMA not compiled in) branch — confirm against full file.
63 errorQuda("\nError: MAGMA library was not compiled, check your compilation options...\n");
// Shutdown path: finalize the MAGMA runtime, erroring out on failure; the
// duplicated message below is presumably its own #else branch.
72 if(magma_finalize() != MAGMA_SUCCESS) errorQuda("\nError: cannot close MAGMA library\n");
74 errorQuda("\nError: MAGMA library was not compiled, check your compilation options...\n");
// Flop counts for LU factorization (GETRF) of an m x n matrix, split into
// real multiplies and real adds, following the standard LAPACK accounting
// (the two ternary branches handle m < n vs m >= n).
80 #define FMULS_GETRF(m_, n_) ( ((m_) < (n_)) \
81 ? (0.5 * (m_) * ((m_) * ((n_) - (1./3.) * (m_) - 1. ) + (n_)) + (2. / 3.) * (m_)) \
82 : (0.5 * (n_) * ((n_) * ((m_) - (1./3.) * (n_) - 1. ) + (m_)) + (2. / 3.) * (n_)) )
83 #define FADDS_GETRF(m_, n_) ( ((m_) < (n_)) \
84 ? (0.5 * (m_) * ((m_) * ((n_) - (1./3.) * (m_) ) - (n_)) + (1. / 6.) * (m_)) \
85 : (0.5 * (n_) * ((n_) * ((m_) - (1./3.) * (n_) ) - (m_)) + (1. / 6.) * (n_)) )
// Complex GETRF flops: 6 real flops per complex multiply, 2 per complex add;
// the single- (C) and double-precision (Z) formulas are identical.
87 #define FLOPS_ZGETRF(m_, n_) (6. * FMULS_GETRF((double)(m_), (double)(n_)) + 2.0 * FADDS_GETRF((double)(m_), (double)(n_)) )
88 #define FLOPS_CGETRF(m_, n_) (6. * FMULS_GETRF((double)(m_), (double)(n_)) + 2.0 * FADDS_GETRF((double)(m_), (double)(n_)) )
// Flop counts for n x n matrix inversion from an LU factorization (GETRI).
90 #define FMULS_GETRI(n_) ( (n_) * ((5. / 6.) + (n_) * ((2. / 3.) * (n_) + 0.5)) )
91 #define FADDS_GETRI(n_) ( (n_) * ((5. / 6.) + (n_) * ((2. / 3.) * (n_) - 1.5)) )
// Complex GETRI flops, same C/Z symmetry as above.
93 #define FLOPS_ZGETRI(n_) (6. * FMULS_GETRI((double)(n_)) + 2.0 * FADDS_GETRI((double)(n_)) )
94 #define FLOPS_CGETRI(n_) (6. * FMULS_GETRI((double)(n_)) + 2.0 * FADDS_GETRI((double)(n_)) )
// Invert `batch` n-by-n complex matrices on the GPU with MAGMA's batched LU:
// host input A_h is copied to the device, factorized, inverted out-of-place,
// and the inverses are copied back into Ainv_h. `prec` selects single (4) or
// double (8) precision real components.
// NOTE(review): this chunk elides the function's braces, the `if (prec == 4)`
// opener, and the tail (freeing A_d/Ainv_d and the closing brace) — the
// structural comments below are inferred and should be confirmed.
96 void BlasMagmaArgs::BatchInvertMatrix(void *Ainv_h, void* A_h, const int n, const uint64_t batch, const int prec)
// NOTE(review): `batch` is uint64_t but printed with %d — format mismatch;
// should be %lu / PRIu64.
99 printfQuda("%s with n=%d and batch=%d\n", __func__, n, batch);
// Default (0/NULL) MAGMA queue used for every MAGMA call below.
101 magma_queue_t queue = 0;
// Total bytes for the batch: 2 reals per complex element * n*n elements
// * prec bytes per real * batch matrices.
103 size_t size = 2*n*n*prec*batch;
104 void *A_d = device_malloc(size);
105 void *Ainv_d = device_malloc(size);
106 qudaMemcpy(A_d, A_h, size, cudaMemcpyHostToDevice);
// Pivot storage: dipiv_array[i] is made to point at matrix i's n pivots
// inside the contiguous dipiv_tmp buffer.
108 magma_int_t **dipiv_array = static_cast<magma_int_t**>(device_malloc(batch*sizeof(magma_int_t*)));
109 magma_int_t *dipiv_tmp = static_cast<magma_int_t*>(device_malloc(batch*n*sizeof(magma_int_t)));
110 set_ipointer(dipiv_array, dipiv_tmp, 1, 0, 0, n, batch, queue);
// The single-precision path factorizes without pivoting, but getri still
// consumes a pivot array, so upload the identity permutation (1-based).
112 magma_int_t *no_piv_array = static_cast<magma_int_t*>(safe_malloc(batch*n*sizeof(magma_int_t)));
// NOTE(review): `int i` vs uint64_t `batch` — overflows for batch > INT_MAX;
// the same signed/unsigned mix recurs in every info-check loop below.
114 for (int i=0; i<batch; i++) {
115 for (int j=0; j<n; j++) {
116 no_piv_array[i*n + j] = j+1;
119 qudaMemcpy(dipiv_tmp, no_piv_array, batch*n*sizeof(magma_int_t), cudaMemcpyHostToDevice);
121 host_free(no_piv_array);
// Per-matrix MAGMA status codes: device buffer for the kernels, host copy
// for error checking.
123 magma_int_t *dinfo_array = static_cast<magma_int_t*>(device_malloc(batch*sizeof(magma_int_t)));
124 magma_int_t *info_array = static_cast<magma_int_t*>(safe_malloc(batch*sizeof(magma_int_t)));
127 // FIXME do this in pipelined fashion to reduce memory overhead.
// --- Single-precision branch (the `if (prec == 4) {` line is elided).
// Build device-side arrays of per-matrix pointers into A_d / Ainv_d.
129 magmaFloatComplex **A_array = static_cast<magmaFloatComplex**>(device_malloc(batch*sizeof(magmaFloatComplex*)));
130 magmaFloatComplex **Ainv_array = static_cast<magmaFloatComplex**>(device_malloc(batch*sizeof(magmaFloatComplex*)));
132 cset_pointer(A_array, static_cast<magmaFloatComplex*>(A_d), n, 0, 0, n*n, batch, queue);
133 cset_pointer(Ainv_array, static_cast<magmaFloatComplex*>(Ainv_d), n, 0, 0, n*n, batch, queue);
// Timed in-place LU factorization (no pivoting; see identity pivots above).
135 double magma_time = magma_sync_wtime(queue);
136 //err = magma_cgetrf_batched(n, n, A_array, n, dipiv_array, dinfo_array, batch, queue);
137 err = magma_cgetrf_nopiv_batched(n, n, A_array, n, dinfo_array, batch, queue);
138 magma_time = magma_sync_wtime(queue) - magma_time;
139 printfQuda("LU factorization completed in %f seconds with GFLOPS = %f\n",
140 magma_time, 1e-9 * batch * FLOPS_CGETRF(n,n) / magma_time);
142 if(err != 0) errorQuda("\nError in LU decomposition (magma_cgetrf), error code = %d\n", err);
// Validate every matrix's factorization status on the host.
144 qudaMemcpy(info_array, dinfo_array, batch*sizeof(magma_int_t), cudaMemcpyDeviceToHost);
145 for (int i=0; i<batch; i++) {
146 if (info_array[i] < 0) {
147 errorQuda("%d argument had an illegal value or another error occured, such as memory allocation failed", i);
148 } else if (info_array[i] > 0) {
149 errorQuda("%d factorization completed but the factor U is exactly singular", i);
// Timed out-of-place inversion from the LU factors.
153 magma_time = magma_sync_wtime(queue);
154 err = magma_cgetri_outofplace_batched(n, A_array, n, dipiv_array, Ainv_array, n, dinfo_array, batch, queue);
155 magma_time = magma_sync_wtime(queue) - magma_time;
156 printfQuda("Matrix inversion completed in %f seconds with GFLOPS = %f\n",
157 magma_time, 1e-9 * batch * FLOPS_CGETRI(n) / magma_time);
159 if(err != 0) errorQuda("\nError in matrix inversion (magma_cgetri), error code = %d\n", err);
161 qudaMemcpy(info_array, dinfo_array, batch*sizeof(magma_int_t), cudaMemcpyDeviceToHost);
163 for (int i=0; i<batch; i++) {
164 if (info_array[i] < 0) {
165 errorQuda("%d argument had an illegal value or another error occured, such as memory allocation failed", i);
166 } else if (info_array[i] > 0) {
167 errorQuda("%d factorization completed but the factor U is exactly singular", i);
171 device_free(Ainv_array);
172 device_free(A_array);
// --- Double-precision branch: same sequence with the z-routines; unlike the
// float path this one uses pivoted LU (pivots written into dipiv_array).
173 } else if (prec == 8) {
174 magmaDoubleComplex **A_array = static_cast<magmaDoubleComplex**>(device_malloc(batch*sizeof(magmaDoubleComplex*)));
175 zset_pointer(A_array, static_cast<magmaDoubleComplex*>(A_d), n, 0, 0, n*n, batch, queue);
177 magmaDoubleComplex **Ainv_array = static_cast<magmaDoubleComplex**>(device_malloc(batch*sizeof(magmaDoubleComplex*)));
178 zset_pointer(Ainv_array, static_cast<magmaDoubleComplex*>(Ainv_d), n, 0, 0, n*n, batch, queue);
180 double magma_time = magma_sync_wtime(queue);
181 err = magma_zgetrf_batched(n, n, A_array, n, dipiv_array, dinfo_array, batch, queue);
182 magma_time = magma_sync_wtime(queue) - magma_time;
183 printfQuda("LU factorization completed in %f seconds with GFLOPS = %f\n",
184 magma_time, 1e-9 * batch * FLOPS_ZGETRF(n,n) / magma_time);
186 if(err != 0) errorQuda("\nError in LU decomposition (magma_zgetrf), error code = %d\n", err);
188 qudaMemcpy(info_array, dinfo_array, batch*sizeof(magma_int_t), cudaMemcpyDeviceToHost);
189 for (int i=0; i<batch; i++) {
190 if (info_array[i] < 0) {
191 errorQuda("%d argument had an illegal value or another error occured, such as memory allocation failed", i);
192 } else if (info_array[i] > 0) {
193 errorQuda("%d factorization completed but the factor U is exactly singular", i);
197 magma_time = magma_sync_wtime(queue);
198 err = magma_zgetri_outofplace_batched(n, A_array, n, dipiv_array, Ainv_array, n, dinfo_array, batch, queue);
199 magma_time = magma_sync_wtime(queue) - magma_time;
200 printfQuda("Matrix inversion completed in %f seconds with GFLOPS = %f\n",
201 magma_time, 1e-9 * batch * FLOPS_ZGETRI(n) / magma_time);
// NOTE(review): copy/paste defect — this is the double-precision (z) path
// but the message names magma_cgetri; should read magma_zgetri.
203 if(err != 0) errorQuda("\nError in matrix inversion (magma_cgetri), error code = %d\n", err);
205 qudaMemcpy(info_array, dinfo_array, batch*sizeof(magma_int_t), cudaMemcpyDeviceToHost);
207 for (int i=0; i<batch; i++) {
208 if (info_array[i] < 0) {
209 errorQuda("%d argument had an illegal value or another error occured, such as memory allocation failed", i);
210 } else if (info_array[i] > 0) {
211 errorQuda("%d factorization completed but the factor U is exactly singular", i);
215 device_free(Ainv_array);
216 device_free(A_array);
// Unsupported precision (else branch of the elided prec dispatch).
218 errorQuda("%s not implemented for precision=%d", __func__, prec);
// Copy the computed inverses back to the host and release the shared
// temporaries (A_d/Ainv_d frees are presumably in the elided tail).
221 qudaMemcpy(Ainv_h, Ainv_d, size, cudaMemcpyDeviceToHost);
223 device_free(dipiv_tmp);
224 device_free(dipiv_array);
225 device_free(dinfo_array);
226 host_free(info_array);