#include <Eigen/Dense>
using namespace Eigen;
// Copy a rows x cols block of a padded host array (leading dimension ld) into
// an Eigen matrix, starting at element `offset`.
void fillEigenArray(MatrixXcd &EigenArr, complex<double> *arr, int rows, int cols, int ld, int offset)
{
  int counter = offset;
  for (int i = 0; i < rows; i++) {
    for (int j = 0; j < cols; j++) {
      EigenArr(i, j) = arr[counter];
      counter++;
    }
    // Skip the padding elements at the end of each row
    counter += (ld - cols);
  }
}
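// Illustrative sketch (not part of the original file): with rows = cols = 2,
// ld = 3 and offset = 0, fillEigenArray picks out elements {0, 1, 3, 4} of the
// host array and skips the per-row padding entries arr[2] and arr[5].
static void fillEigenArrayExample()
{
  complex<double> arr[6];
  for (int i = 0; i < 6; i++) arr[i] = complex<double>(i, 0.0);

  MatrixXcd M = MatrixXcd::Zero(2, 2);
  fillEigenArray(M, arr, 2, 2, 3, 0);
  // M is now [[0, 1], [3, 4]]
}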
double blasGEMMEigenVerify(void *A_data, void *B_data, void *C_data_copy, void *C_data, uint64_t refA_size,
                           uint64_t refB_size, uint64_t refC_size, QudaBLASParam *blas_param)
{
  // Sanity checks on the BLAS parameters
  int min_dim = std::min(blas_param->m, std::min(blas_param->n, blas_param->k));
  if (min_dim <= 0)
    errorQuda("BLAS dims must be positive: m=%d, n=%d, k=%d", blas_param->m, blas_param->n, blas_param->k);

  int min_stride = std::min(std::min(blas_param->a_stride, blas_param->b_stride), blas_param->c_stride);
  if (min_stride < 0)
    errorQuda("BLAS strides must be positive or zero: a_stride=%d, b_stride=%d, c_stride=%d", blas_param->a_stride,
              blas_param->b_stride, blas_param->c_stride);

  int min_offset = std::min(std::min(blas_param->a_offset, blas_param->b_offset), blas_param->c_offset);
  if (min_offset < 0)
    errorQuda("BLAS offsets must be positive or zero: a_offset=%d, b_offset=%d, c_offset=%d", blas_param->a_offset,
              blas_param->b_offset, blas_param->c_offset);
  // Leading-dimension checks depend on the data order and the matrix op type
  if (blas_param->data_order == QUDA_BLAS_DATAORDER_COL) {
    // Column-major data
    if (blas_param->trans_a == QUDA_BLAS_OP_N) {
      if (blas_param->lda < std::max(1, blas_param->m))
        errorQuda("lda=%d must be >= max(1,m=%d)", blas_param->lda, blas_param->m);
    } else {
      if (blas_param->lda < std::max(1, blas_param->k))
        errorQuda("lda=%d must be >= max(1,k=%d)", blas_param->lda, blas_param->k);
    }

    if (blas_param->trans_b == QUDA_BLAS_OP_N) {
      if (blas_param->ldb < std::max(1, blas_param->k))
        errorQuda("ldb=%d must be >= max(1,k=%d)", blas_param->ldb, blas_param->k);
    } else {
      if (blas_param->ldb < std::max(1, blas_param->n))
        errorQuda("ldb=%d must be >= max(1,n=%d)", blas_param->ldb, blas_param->n);
    }

    if (blas_param->ldc < std::max(1, blas_param->m))
      errorQuda("ldc=%d must be >= max(1,m=%d)", blas_param->ldc, blas_param->m);
  } else {
    // Row-major data
    if (blas_param->trans_a == QUDA_BLAS_OP_N) {
      if (blas_param->lda < std::max(1, blas_param->k))
        errorQuda("lda=%d must be >= max(1,k=%d)", blas_param->lda, blas_param->k);
    } else {
      if (blas_param->lda < std::max(1, blas_param->m))
        errorQuda("lda=%d must be >= max(1,m=%d)", blas_param->lda, blas_param->m);
    }

    if (blas_param->trans_b == QUDA_BLAS_OP_N) {
      if (blas_param->ldb < std::max(1, blas_param->n))
        errorQuda("ldb=%d must be >= max(1,n=%d)", blas_param->ldb, blas_param->n);
    } else {
      if (blas_param->ldb < std::max(1, blas_param->k))
        errorQuda("ldb=%d must be >= max(1,k=%d)", blas_param->ldb, blas_param->k);
    }

    if (blas_param->ldc < std::max(1, blas_param->n))
      errorQuda("ldc=%d must be >= max(1,n=%d)", blas_param->ldc, blas_param->n);
  }
  // Problem dimensions
  int m = blas_param->m;
  int n = blas_param->n;
  int k = blas_param->k;

  // Leading dimensions
  int lda = blas_param->lda;
  int ldb = blas_param->ldb;
  int ldc = blas_param->ldc;

  // Batch strides
  int a_stride = blas_param->a_stride;
  int b_stride = blas_param->b_stride;
  int c_stride = blas_param->c_stride;

  // Array offsets
  int a_offset = blas_param->a_offset;
  int b_offset = blas_param->b_offset;
  int c_offset = blas_param->c_offset;

  int batches = blas_param->batch_count;

  // Scale factors
  complex<double> alpha = blas_param->alpha;
  complex<double> beta = blas_param->beta;
  // Eigen objects used for the host-side reference computation
  MatrixXcd A = MatrixXcd::Zero(m, k);
  MatrixXcd B = MatrixXcd::Zero(k, n);
  MatrixXcd C_eigen = MatrixXcd::Zero(m, n);
  MatrixXcd C_gpu = MatrixXcd::Zero(m, n);
  MatrixXcd C_resid = MatrixXcd::Zero(m, n);

  // Pointers to the raw complex<double> host data
  complex<double> *A_ptr = (complex<double> *)A_data;
  complex<double> *B_ptr = (complex<double> *)B_data;
  complex<double> *C_ptr = (complex<double> *)C_data;
  complex<double> *Ccopy_ptr = (complex<double> *)C_data_copy;
  // A zero stride reuses the same matrices for every batch; make sure the
  // batch loop below still advances.
  int max_stride = std::max(std::max(a_stride, b_stride), c_stride);
  if (max_stride <= 0) max_stride = 1;

  printfQuda("Computing Eigen matrix operation a * A_{%lu,%lu} * B_{%lu,%lu} + b * C_{%lu,%lu} = C_{%lu,%lu}\n",
             A.rows(), A.cols(), B.rows(), B.cols(), C_eigen.rows(), C_eigen.cols(), C_eigen.rows(), C_eigen.cols());
  double max_relative_deviation = 0.0;
  for (int batch = 0; batch < batches; batch += max_stride) {

    // Populate A, B, C_eigen, and C_gpu from the host pointers via fillEigenArray,
    // applying the current offsets and any transposition (elided here).

    // Apply the GEMM to the Eigen objects
    C_eigen = alpha * A * B + beta * C_eigen;

    // Compare the host (Eigen) result with the GPU result
    C_resid = C_gpu - C_eigen;
    double deviation = C_resid.norm();
    double relative_deviation = deviation / C_eigen.norm();
    max_relative_deviation = std::max(max_relative_deviation, relative_deviation);

    printfQuda("batch %d: (C_host - C_gpu) Frobenius norm = %e. Relative deviation = %e\n", batch, deviation,
               relative_deviation);

    // Advance the offsets to the next matrices in the batch
    a_offset += refA_size * a_stride;
    b_offset += refB_size * b_stride;
    c_offset += refC_size * c_stride;
  }

  return max_relative_deviation;
}
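// Usage note (illustrative, not from the original file): the return value is the
// largest relative Frobenius-norm deviation over all batches,
//   max_b ||C_gpu(b) - C_eigen(b)||_F / ||C_eigen(b)||_F,
// so a caller would typically compare it against a precision-dependent
// tolerance, e.g. (hypothetical threshold)
//   if (blasGEMMEigenVerify(...) > 1e-10) errorQuda("GEMM verification failed");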
double blasGEMMQudaVerify(void *arrayA, void *arrayB, void *arrayC, void *arrayCcopy, uint64_t refA_size,
                          uint64_t refB_size, uint64_t refC_size, QudaBLASParam *blas_param)
{
  // The reference computation is always done in complex<double>: two doubles per element
  size_t data_size = 2 * sizeof(double);

  // If non-zero offsets are used, allocate one extra matrix worth of data
  int batches_extra = 0;
  if (blas_param->a_offset + blas_param->b_offset + blas_param->c_offset > 0) batches_extra++;
  int batches = blas_param->batch_count + batches_extra;

  // Pinned host buffers holding promoted copies of the input data
  void *checkA = pinned_malloc(refA_size * data_size * batches);
  void *checkB = pinned_malloc(refB_size * data_size * batches);
  void *checkC = pinned_malloc(refC_size * data_size * batches);
  void *checkCcopy = pinned_malloc(refC_size * data_size * batches);

  memset(checkA, 0, batches * refA_size * data_size);
  memset(checkB, 0, batches * refB_size * data_size);
  memset(checkC, 0, batches * refC_size * data_size);
  memset(checkCcopy, 0, batches * refC_size * data_size);
  // Promote the input data to complex<double> according to its QUDA data type
  switch (blas_param->data_type) {
  case QUDA_BLAS_DATATYPE_S:
    // Real single precision: promote to double in the real parts, imaginary parts stay zero
    for (uint64_t i = 0; i < 2 * refA_size * batches; i += 2) { ((double *)checkA)[i] = ((float *)arrayA)[i / 2]; }
    for (uint64_t i = 0; i < 2 * refB_size * batches; i += 2) { ((double *)checkB)[i] = ((float *)arrayB)[i / 2]; }
    for (uint64_t i = 0; i < 2 * refC_size * batches; i += 2) {
      ((double *)checkC)[i] = ((float *)arrayC)[i / 2];
      ((double *)checkCcopy)[i] = ((float *)arrayCcopy)[i / 2];
    }
    break;
  case QUDA_BLAS_DATATYPE_D:
    // Real double precision: copy into the real parts, imaginary parts stay zero
    for (uint64_t i = 0; i < 2 * refA_size * batches; i += 2) { ((double *)checkA)[i] = ((double *)arrayA)[i / 2]; }
    for (uint64_t i = 0; i < 2 * refB_size * batches; i += 2) { ((double *)checkB)[i] = ((double *)arrayB)[i / 2]; }
    for (uint64_t i = 0; i < 2 * refC_size * batches; i += 2) {
      ((double *)checkC)[i] = ((double *)arrayC)[i / 2];
      ((double *)checkCcopy)[i] = ((double *)arrayCcopy)[i / 2];
    }
    break;
  case QUDA_BLAS_DATATYPE_C:
    // Complex single precision: promote every float (real and imaginary) to double
    for (uint64_t i = 0; i < 2 * refA_size * batches; i++) { ((double *)checkA)[i] = ((float *)arrayA)[i]; }
    for (uint64_t i = 0; i < 2 * refB_size * batches; i++) { ((double *)checkB)[i] = ((float *)arrayB)[i]; }
    for (uint64_t i = 0; i < 2 * refC_size * batches; i++) {
      ((double *)checkC)[i] = ((float *)arrayC)[i];
      ((double *)checkCcopy)[i] = ((float *)arrayCcopy)[i];
    }
    break;
  case QUDA_BLAS_DATATYPE_Z:
    // Complex double precision: straight copy
    for (uint64_t i = 0; i < 2 * refA_size * batches; i++) { ((double *)checkA)[i] = ((double *)arrayA)[i]; }
    for (uint64_t i = 0; i < 2 * refB_size * batches; i++) { ((double *)checkB)[i] = ((double *)arrayB)[i]; }
    for (uint64_t i = 0; i < 2 * refC_size * batches; i++) {
      ((double *)checkC)[i] = ((double *)arrayC)[i];
      ((double *)checkCcopy)[i] = ((double *)arrayCcopy)[i];
    }
    break;
  default: errorQuda("Unrecognised data type %d\n", blas_param->data_type);
  }
  // Verify the GPU result against the Eigen host computation
  auto deviation = blasGEMMEigenVerify(checkA, checkB, checkCcopy, checkC, refA_size, refB_size, refC_size, blas_param);

  // Release the pinned host buffers
  host_free(checkA);
  host_free(checkB);
  host_free(checkC);
  host_free(checkCcopy);

  return deviation;
}
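// Illustrative usage sketch (hypothetical driver code, not part of this file):
// after the GPU GEMM has written its result into arrayC, with arrayCcopy holding
// the original C, a test might verify the result as follows. The refA/B/C sizes
// are taken to be the per-batch complex-element counts of each matrix, assuming
// column-major, non-transposed operands.
static void exampleVerifyCall(void *arrayA, void *arrayB, void *arrayC, void *arrayCcopy, QudaBLASParam *blas_param)
{
  uint64_t refA_size = (uint64_t)blas_param->lda * blas_param->k; // A is m x k, leading dimension lda
  uint64_t refB_size = (uint64_t)blas_param->ldb * blas_param->n; // B is k x n, leading dimension ldb
  uint64_t refC_size = (uint64_t)blas_param->ldc * blas_param->n; // C is m x n, leading dimension ldc

  double max_rel_dev = blasGEMMQudaVerify(arrayA, arrayB, arrayC, arrayCcopy, refA_size, refB_size, refC_size, blas_param);

  // Hypothetical tolerance: adjust for the data type and problem size
  if (max_rel_dev > 1e-10)
    errorQuda("GEMM verification failed: max relative deviation %e", max_rel_dev);
  else
    printfQuda("GEMM verification passed: max relative deviation %e\n", max_rel_dev);
}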