default: errorQuda("Undefined QUDA BLAS data type %d\n", data_type);
// If the user passes a non-positive m, n, or k, error out.
int min_dim = std::min(blas_param.m, std::min(blas_param.n, blas_param.k));
if (min_dim <= 0)
  errorQuda("BLAS dims must be positive: m=%d, n=%d, k=%d", blas_param.m, blas_param.n, blas_param.k);
errorQuda("BLAS strides must be positive or zero: a_stride=%d, b_stride=%d, c_stride=%d", blas_param.a_stride,
          blas_param.b_stride, blas_param.c_stride);
errorQuda("BLAS offsets must be positive or zero: a_offset=%d, b_offset=%d, c_offset=%d", blas_param.a_offset,
          blas_param.b_offset, blas_param.c_offset);
// Leading-dimension checks: the leading dimension must be at least the number of rows
// of the matrix as it is stored, so the bound depends on the data order and on trans_a/trans_b.
if (blas_param.data_order == QUDA_BLAS_DATAORDER_COL) {
  if (blas_param.trans_a == QUDA_BLAS_OP_N) {
    if (blas_param.lda < std::max(1, blas_param.m))
      errorQuda("lda=%d must be >= max(1,m=%d)", blas_param.lda, blas_param.m);
  } else {
    if (blas_param.lda < std::max(1, blas_param.k))
      errorQuda("lda=%d must be >= max(1,k=%d)", blas_param.lda, blas_param.k);
  }
  if (blas_param.trans_b == QUDA_BLAS_OP_N) {
    if (blas_param.ldb < std::max(1, blas_param.k))
      errorQuda("ldb=%d must be >= max(1,k=%d)", blas_param.ldb, blas_param.k);
  } else {
    if (blas_param.ldb < std::max(1, blas_param.n))
      errorQuda("ldb=%d must be >= max(1,n=%d)", blas_param.ldb, blas_param.n);
  }
  if (blas_param.ldc < std::max(1, blas_param.m))
    errorQuda("ldc=%d must be >= max(1,m=%d)", blas_param.ldc, blas_param.m);
} else {
  // Row-major data order
  if (blas_param.trans_a == QUDA_BLAS_OP_N) {
    if (blas_param.lda < std::max(1, blas_param.k))
      errorQuda("lda=%d must be >= max(1,k=%d)", blas_param.lda, blas_param.k);
  } else {
    if (blas_param.lda < std::max(1, blas_param.m))
      errorQuda("lda=%d must be >= max(1,m=%d)", blas_param.lda, blas_param.m);
  }
  if (blas_param.trans_b == QUDA_BLAS_OP_N) {
    if (blas_param.ldb < std::max(1, blas_param.n))
      errorQuda("ldb=%d must be >= max(1,n=%d)", blas_param.ldb, blas_param.n);
  } else {
    if (blas_param.ldb < std::max(1, blas_param.k))
      errorQuda("ldb=%d must be >= max(1,k=%d)", blas_param.ldb, blas_param.k);
  }
  if (blas_param.ldc < std::max(1, blas_param.n))
    errorQuda("ldc=%d must be >= max(1,n=%d)", blas_param.ldc, blas_param.n);
}
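// Hedged sketch (hypothetical helper, not part of the original test): the checks above all
// enforce a single rule, "a leading dimension must be at least the row count of the matrix
// as stored". The rule for lda, written as a local lambda over the same parameter fields:
auto required_lda = [&blas_param]() {
  const bool col_major = (blas_param.data_order == QUDA_BLAS_DATAORDER_COL);
  const bool no_trans = (blas_param.trans_a == QUDA_BLAS_OP_N);
  // Column-major + untransposed (or row-major + transposed) stores m rows; otherwise k.
  return std::max(1, (col_major == no_trans) ? blas_param.m : blas_param.k);
};
(void)required_lda; // illustration only; the explicit checks above are authoritative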
// Reference data is stored as interleaved complex double (re, im).
size_t data_size = sizeof(double);
// Batch count for the test data; an extra matrix is added when non-zero offsets are used.
int batches_extra = 0;
int batches = blas_param.batch_count + batches_extra;
uint64_t refA_size = 0, refB_size = 0, refC_size = 0;
if (blas_param.data_order == QUDA_BLAS_DATAORDER_COL) {
  // Column-major: a stored matrix spans (leading dimension) * (number of columns) elements.
  if (blas_param.trans_a == QUDA_BLAS_OP_N)
    refA_size = blas_param.lda * blas_param.k;
  else
    refA_size = blas_param.lda * blas_param.m;
  if (blas_param.trans_b == QUDA_BLAS_OP_N)
    refB_size = blas_param.ldb * blas_param.n;
  else
    refB_size = blas_param.ldb * blas_param.k;
  refC_size = blas_param.ldc * blas_param.n;
} else {
  // Row-major: a stored matrix spans (leading dimension) * (number of rows) elements.
  if (blas_param.trans_a == QUDA_BLAS_OP_N)
    refA_size = blas_param.lda * blas_param.m;
  else
    refA_size = blas_param.lda * blas_param.k;
  if (blas_param.trans_b == QUDA_BLAS_OP_N)
    refB_size = blas_param.ldb * blas_param.k;
  else
    refB_size = blas_param.ldb * blas_param.n;
  refC_size = blas_param.ldc * blas_param.m;
}
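// Worked example (assumed values, for illustration only): with column-major data, no
// transposes, m=64, n=96, k=128 and minimal leading dims lda=64, ldb=128, ldc=64,
// the per-batch footprints would be
//   refA_size = lda * k = 64 * 128 = 8192 elements,
//   refB_size = ldb * n = 128 * 96 = 12288 elements,
//   refC_size = ldc * n = 64 * 96  = 6144 elements,
// each element being an interleaved complex double in the reference arrays.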
void *refCcopy = pinned_malloc(batches * refC_size * data_size);

memset(refA, 0, batches * refA_size * data_size);
memset(refB, 0, batches * refB_size * data_size);
memset(refC, 0, batches * refC_size * data_size);
memset(refCcopy, 0, batches * refC_size * data_size);
// Populate the real parts (even indices of the interleaved layout) with random data,
// keeping a copy of C for the host verification.
for (uint64_t i = 0; i < 2 * refA_size * batches; i += 2) { ((double *)refA)[i] = rand() / (double)RAND_MAX; }
for (uint64_t i = 0; i < 2 * refB_size * batches; i += 2) { ((double *)refB)[i] = rand() / (double)RAND_MAX; }
for (uint64_t i = 0; i < 2 * refC_size * batches; i += 2) {
  ((double *)refC)[i] = rand() / (double)RAND_MAX;
  ((double *)refCcopy)[i] = ((double *)refC)[i];
}

// Populate the imaginary parts (odd indices).
for (uint64_t i = 1; i < 2 * refA_size * batches; i += 2) { ((double *)refA)[i] = rand() / (double)RAND_MAX; }
for (uint64_t i = 1; i < 2 * refB_size * batches; i += 2) { ((double *)refB)[i] = rand() / (double)RAND_MAX; }
for (uint64_t i = 1; i < 2 * refC_size * batches; i += 2) {
  ((double *)refC)[i] = rand() / (double)RAND_MAX;
  ((double *)refCcopy)[i] = ((double *)refC)[i];
}
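// Hedged sketch (hypothetical alternative, not the test's code): the two passes above fill
// the real and imaginary halves of an interleaved complex-double layout. Assuming <complex>
// and <random> are included, the same fill could be written against std::complex as:
//
//   std::mt19937 rng(1234);
//   std::uniform_real_distribution<double> dist(0.0, 1.0);
//   auto *A = reinterpret_cast<std::complex<double> *>(refA);
//   for (uint64_t i = 0; i < refA_size * batches; i++) A[i] = {dist(rng), dist(rng)};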
void *arrayA = nullptr;
void *arrayB = nullptr;
void *arrayC = nullptr;
void *arrayCcopy = nullptr;
// Allocate the test arrays in the requested data type and populate them from the reference data.
switch (test_data_type) {
case QUDA_BLAS_DATATYPE_S:
  // Single-precision real: keep only the real parts of the complex-double reference data.
  arrayA = pinned_malloc(batches * refA_size * sizeof(float));
  arrayB = pinned_malloc(batches * refB_size * sizeof(float));
  arrayC = pinned_malloc(batches * refC_size * sizeof(float));
  arrayCcopy = pinned_malloc(batches * refC_size * sizeof(float));

  for (uint64_t i = 0; i < 2 * refA_size * batches; i += 2) { ((float *)arrayA)[i / 2] = ((double *)refA)[i]; }
  for (uint64_t i = 0; i < 2 * refB_size * batches; i += 2) { ((float *)arrayB)[i / 2] = ((double *)refB)[i]; }
  for (uint64_t i = 0; i < 2 * refC_size * batches; i += 2) {
    ((float *)arrayC)[i / 2] = ((double *)refC)[i];
    ((float *)arrayCcopy)[i / 2] = ((double *)refC)[i];
  }
  break;
case QUDA_BLAS_DATATYPE_D:
  // Double-precision real: keep only the real parts of the reference data.
  arrayA = pinned_malloc(batches * refA_size * sizeof(double));
  arrayB = pinned_malloc(batches * refB_size * sizeof(double));
  arrayC = pinned_malloc(batches * refC_size * sizeof(double));
  arrayCcopy = pinned_malloc(batches * refC_size * sizeof(double));

  for (uint64_t i = 0; i < 2 * refA_size * batches; i += 2) { ((double *)arrayA)[i / 2] = ((double *)refA)[i]; }
  for (uint64_t i = 0; i < 2 * refB_size * batches; i += 2) { ((double *)arrayB)[i / 2] = ((double *)refB)[i]; }
  for (uint64_t i = 0; i < 2 * refC_size * batches; i += 2) {
    ((double *)arrayC)[i / 2] = ((double *)refC)[i];
    ((double *)arrayCcopy)[i / 2] = ((double *)refC)[i];
  }
  break;
case QUDA_BLAS_DATATYPE_C:
  // Single-precision complex: copy both real and imaginary parts.
  arrayA = pinned_malloc(batches * refA_size * 2 * sizeof(float));
  arrayB = pinned_malloc(batches * refB_size * 2 * sizeof(float));
  arrayC = pinned_malloc(batches * refC_size * 2 * sizeof(float));
  arrayCcopy = pinned_malloc(batches * refC_size * 2 * sizeof(float));

  for (uint64_t i = 0; i < 2 * refA_size * batches; i++) { ((float *)arrayA)[i] = ((double *)refA)[i]; }
  for (uint64_t i = 0; i < 2 * refB_size * batches; i++) { ((float *)arrayB)[i] = ((double *)refB)[i]; }
  for (uint64_t i = 0; i < 2 * refC_size * batches; i++) {
    ((float *)arrayC)[i] = ((double *)refC)[i];
    ((float *)arrayCcopy)[i] = ((double *)refC)[i];
  }
  break;
case QUDA_BLAS_DATATYPE_Z:
  // Double-precision complex: the test data matches the reference layout exactly.
  arrayA = pinned_malloc(batches * refA_size * 2 * sizeof(double));
  arrayB = pinned_malloc(batches * refB_size * 2 * sizeof(double));
  arrayC = pinned_malloc(batches * refC_size * 2 * sizeof(double));
  arrayCcopy = pinned_malloc(batches * refC_size * 2 * sizeof(double));

  for (uint64_t i = 0; i < 2 * refA_size * batches; i++) { ((double *)arrayA)[i] = ((double *)refA)[i]; }
  for (uint64_t i = 0; i < 2 * refB_size * batches; i++) { ((double *)arrayB)[i] = ((double *)refB)[i]; }
  for (uint64_t i = 0; i < 2 * refC_size * batches; i++) {
    ((double *)arrayC)[i] = ((double *)refC)[i];
    ((double *)arrayCcopy)[i] = ((double *)refC)[i];
  }
  break;
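// Hedged sketch (hypothetical refactor, not the original code): the four cases differ only in
// the target precision and in whether the imaginary parts are kept, so the copy could be
// expressed once as a small template over the target scalar type:
//
//   template <typename real_t>
//   void convert_from_ref(void *dst, const void *src, uint64_t n_complex, bool complex_out)
//   {
//     auto *out = static_cast<real_t *>(dst);
//     auto *in = static_cast<const double *>(src); // reference data: interleaved re, im
//     for (uint64_t i = 0; i < 2 * n_complex; i++) {
//       if (complex_out) out[i] = in[i];         // keep real and imaginary parts
//       else if (i % 2 == 0) out[i / 2] = in[i]; // keep real parts only
//     }
//   }
//
// e.g. convert_from_ref<float>(arrayA, refA, refA_size * batches, false) would reproduce
// the QUDA_BLAS_DATATYPE_S branch above.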
default: errorQuda("Unrecognised data type %d\n", test_data_type);
}
double deviation = 0.0;

// Compare the device GEMM result against the host reference implementation.
deviation = blasGEMMQudaVerify(arrayA, arrayB, arrayC, arrayCcopy, refA_size, refB_size, refC_size, &blas_param);
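// Hedged sketch (illustration only; blasGEMMQudaVerify is a test-utility routine whose actual
// implementation is not shown here). Assuming <complex> and <cmath>, a host-side check for a
// single, non-batched, column-major, double-complex GEMM C = alpha*A*B + beta*C could compute
// an L2-style relative deviation like:
//
//   double host_gemm_deviation(const std::complex<double> *A, const std::complex<double> *B,
//                              const std::complex<double> *C_gpu, const std::complex<double> *C_in,
//                              int m, int n, int k, int lda, int ldb, int ldc,
//                              std::complex<double> alpha, std::complex<double> beta)
//   {
//     double num = 0.0, den = 0.0;
//     for (int j = 0; j < n; j++) {
//       for (int i = 0; i < m; i++) {
//         std::complex<double> acc = 0.0;
//         for (int l = 0; l < k; l++) acc += A[i + l * lda] * B[l + j * ldb];
//         std::complex<double> ref = alpha * acc + beta * C_in[i + j * ldc];
//         num += std::norm(ref - C_gpu[i + j * ldc]);
//         den += std::norm(ref);
//       }
//     }
//     return std::sqrt(num / den);
//   }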
using ::testing::TestWithParam;
auto data_type = GetParam();
auto deviation = test(data_type);
// The tolerance is set according to the precision of the data type under test.
decltype(deviation) tol;
EXPECT_LE(deviation, tol) << "CPU and CUDA implementations do not agree";
int data_type = param.param;
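// Hedged sketch (wiring shown for orientation; test and fixture names are illustrative):
// the deviation check above runs as a value-parameterized Google Test over the four
// QUDA BLAS data types, roughly:
//
//   class BLASTest : public ::testing::TestWithParam<int> { };
//
//   TEST_P(BLASTest, verify)
//   {
//     auto data_type = GetParam();
//     auto deviation = test(data_type);
//     EXPECT_LE(deviation, /* per-precision tolerance, placeholder value */ 1e-6)
//       << "CPU and CUDA implementations do not agree";
//   }
//
//   // One test instance per data-type index 0..3, named via getBLASName.
//   INSTANTIATE_TEST_SUITE_P(QUDA, BLASTest, ::testing::Range(0, 4), getBLASName);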
int main(int argc, char **argv)

// Parse the command line; CLI11 signals errors via exceptions.
try {
  app->parse(argc, argv);
} catch (const CLI::ParseError &e) {
if (result) warningQuda("Google tests for QUDA BLAS failed.");