QUDA v1.1.0
A library for QCD on GPUs
blas_interface_test.cpp
1 #include <stdlib.h>
2 #include <stdio.h>
3 #include <time.h>
4 #include <math.h>
5 #include <string.h>
6 #include <complex>
7 #include <inttypes.h>
8 
9 #include <util_quda.h>
10 #include <host_utils.h>
11 #include <command_line_params.h>
12 #include "blas_reference.h"
13 #include "misc.h"
14 
15 // google test
16 #include <gtest/gtest.h>
17 
18 // In a typical application, quda.h is the only QUDA header required.
19 #include <quda.h>
20 
21 // For googletest, names must be non-empty, unique, and may only contain ASCII
22 // alphanumeric characters or underscore.
23 const char *data_type_str[] = {
24  "realSingle",
25  "realDouble",
26  "complexSingle",
27  "complexDouble",
28 };
29 
30 namespace quda
31 {
32  extern void setTransferGPU(bool);
33 }
34 
35 void display_test_info()
36 {
37  printfQuda("running the following test:\n");
38  printfQuda("BLAS interface test\n");
39  printfQuda("Grid partition info: X Y Z T\n");
40  printfQuda(" %d %d %d %d\n", dimPartitioned(0), dimPartitioned(1), dimPartitioned(2),
41  dimPartitioned(3));
42 }
43 
44 double test(int data_type)
45 {
46  QudaBLASDataType test_data_type = QUDA_BLAS_DATATYPE_INVALID;
47  switch (data_type) {
48  case 0: test_data_type = QUDA_BLAS_DATATYPE_S; break;
49  case 1: test_data_type = QUDA_BLAS_DATATYPE_D; break;
50  case 2: test_data_type = QUDA_BLAS_DATATYPE_C; break;
51  case 3: test_data_type = QUDA_BLAS_DATATYPE_Z; break;
52  default: errorQuda("Undefined QUDA BLAS data type %d\n", data_type);
53  }
54 
55  QudaBLASParam blas_param = newQudaBLASParam();
56  blas_param.trans_a = blas_trans_a;
57  blas_param.trans_b = blas_trans_b;
58  blas_param.m = blas_mnk[0];
59  blas_param.n = blas_mnk[1];
60  blas_param.k = blas_mnk[2];
61  blas_param.lda = blas_leading_dims[0];
62  blas_param.ldb = blas_leading_dims[1];
63  blas_param.ldc = blas_leading_dims[2];
64  blas_param.a_offset = blas_offsets[0];
65  blas_param.b_offset = blas_offsets[1];
66  blas_param.c_offset = blas_offsets[2];
67  blas_param.a_stride = blas_strides[0];
68  blas_param.b_stride = blas_strides[1];
69  blas_param.c_stride = blas_strides[2];
70  blas_param.alpha = (__complex__ double)blas_alpha_re_im[0];
71  blas_param.beta = (__complex__ double)blas_beta_re_im[0];
72  blas_param.data_order = blas_data_order;
73  blas_param.data_type = test_data_type;
74  blas_param.batch_count = blas_batch;
75 
76  // Sanity checks on parameters
77  //-------------------------------------------------------------------------
78  // If the user passes a non-positive M, N, or K, we error out
79  int min_dim = std::min(blas_param.m, std::min(blas_param.n, blas_param.k));
80  if (min_dim <= 0) {
81  errorQuda("BLAS dims must be positive: m=%d, n=%d, k=%d", blas_param.m, blas_param.n, blas_param.k);
82  }
83 
84  // If the user passes a negative stride, we error out as this has no meaning.
85  int min_stride = std::min(std::min(blas_param.a_stride, blas_param.b_stride), blas_param.c_stride);
86  if (min_stride < 0) {
87  errorQuda("BLAS strides must be positive or zero: a_stride=%d, b_stride=%d, c_stride=%d", blas_param.a_stride,
88  blas_param.b_stride, blas_param.c_stride);
89  }
90 
91  // If the user passes a negative offset, we error out as this has no meaning.
92  int min_offset = std::min(std::min(blas_param.a_offset, blas_param.b_offset), blas_param.c_offset);
93  if (min_offset < 0) {
94  errorQuda("BLAS offsets must be positive or zero: a_offset=%d, b_offset=%d, c_offset=%d", blas_param.a_offset,
95  blas_param.b_offset, blas_param.c_offset);
96  }
97 
98  // Leading dims are dependent on the matrix op type.
99  if (blas_param.data_order == QUDA_BLAS_DATAORDER_COL) {
100  if (blas_param.trans_a == QUDA_BLAS_OP_N) {
101  if (blas_param.lda < std::max(1, blas_param.m))
102  errorQuda("lda=%d must be >= max(1,m=%d)", blas_param.lda, blas_param.m);
103  } else {
104  if (blas_param.lda < std::max(1, blas_param.k))
105  errorQuda("lda=%d must be >= max(1,k=%d)", blas_param.lda, blas_param.k);
106  }
107 
108  if (blas_param.trans_b == QUDA_BLAS_OP_N) {
109  if (blas_param.ldb < std::max(1, blas_param.k))
110  errorQuda("ldb=%d must be >= max(1,k=%d)", blas_param.ldb, blas_param.k);
111  } else {
112  if (blas_param.ldb < std::max(1, blas_param.n))
113  errorQuda("ldb=%d must be >= max(1,n=%d)", blas_param.ldb, blas_param.n);
114  }
115  if (blas_param.ldc < std::max(1, blas_param.m))
116  errorQuda("ldc=%d must be >= max(1,m=%d)", blas_param.ldc, blas_param.m);
117  } else {
118  if (blas_param.trans_a == QUDA_BLAS_OP_N) {
119  if (blas_param.lda < std::max(1, blas_param.k))
120  errorQuda("lda=%d must be >= max(1,k=%d)", blas_param.lda, blas_param.k);
121  } else {
122  if (blas_param.lda < std::max(1, blas_param.m))
123  errorQuda("lda=%d must be >= max(1,m=%d)", blas_param.lda, blas_param.m);
124  }
125  if (blas_param.trans_b == QUDA_BLAS_OP_N) {
126  if (blas_param.ldb < std::max(1, blas_param.n))
127  errorQuda("ldb=%d must be >= max(1,n=%d)", blas_param.ldb, blas_param.n);
128  } else {
129  if (blas_param.ldb < std::max(1, blas_param.k))
130  errorQuda("ldb=%d must be >= max(1,k=%d)", blas_param.ldb, blas_param.k);
131  }
132  if (blas_param.ldc < std::max(1, blas_param.n))
133  errorQuda("ldc=%d must be >= max(1,n=%d)", blas_param.ldc, blas_param.n);
134  }
135 
136  // If the batch value is non-positive, we error out
137  if (blas_param.batch_count <= 0) { errorQuda("Batches must be positive: batches=%d", blas_param.batch_count); }
138  //-------------------------------------------------------------------------
139 
140  // Reference data is always in complex double
141  size_t data_size = sizeof(double);
142  int re_im = 2;
143  data_size *= re_im;
144 
145  // If the user passes non-zero offsets, add one extra
146  // matrix to the test data.
147  int batches_extra = 0;
148  if (blas_param.a_offset + blas_param.b_offset + blas_param.c_offset > 0) { batches_extra++; }
149  int batches = blas_param.batch_count + batches_extra;
150  uint64_t refA_size = 0, refB_size = 0, refC_size = 0;
151  if (blas_param.data_order == QUDA_BLAS_DATAORDER_COL) {
152  // in column-major order the leading dimension counts consecutive
153  // elements in a column, so sizes are the leading dimension times the number of columns
154  if (blas_param.trans_a == QUDA_BLAS_OP_N) {
155  refA_size = blas_param.lda * blas_param.k; // A_mk
156  } else {
157  refA_size = blas_param.lda * blas_param.m; // A_km
158  }
159 
160  if (blas_param.trans_b == QUDA_BLAS_OP_N) {
161  refB_size = blas_param.ldb * blas_param.n; // B_kn
162  } else {
163  refB_size = blas_param.ldb * blas_param.k; // B_nk
164  }
165  refC_size = blas_param.ldc * blas_param.n; // C_mn
166  } else {
167  // in row-major order the leading dimension counts consecutive
168  // elements in a row, so sizes are the leading dimension times the number of rows.
169  if (blas_param.trans_a == QUDA_BLAS_OP_N) {
170  refA_size = blas_param.lda * blas_param.m; // A_mk
171  } else {
172  refA_size = blas_param.lda * blas_param.k; // A_km
173  }
174  if (blas_param.trans_b == QUDA_BLAS_OP_N) {
175  refB_size = blas_param.ldb * blas_param.k; // B_kn
176  } else {
177  refB_size = blas_param.ldb * blas_param.n; // B_nk
178  }
179  refC_size = blas_param.ldc * blas_param.m; // C_mn
180  }
181 
182  void *refA = pinned_malloc(batches * refA_size * data_size);
183  void *refB = pinned_malloc(batches * refB_size * data_size);
184  void *refC = pinned_malloc(batches * refC_size * data_size);
185  void *refCcopy = pinned_malloc(batches * refC_size * data_size);
186 
187  memset(refA, 0, batches * refA_size * data_size);
188  memset(refB, 0, batches * refB_size * data_size);
189  memset(refC, 0, batches * refC_size * data_size);
190  memset(refCcopy, 0, batches * refC_size * data_size);
191 
192  // Populate the real part with rands
193  for (uint64_t i = 0; i < 2 * refA_size * batches; i += 2) { ((double *)refA)[i] = rand() / (double)RAND_MAX; }
194  for (uint64_t i = 0; i < 2 * refB_size * batches; i += 2) { ((double *)refB)[i] = rand() / (double)RAND_MAX; }
195  for (uint64_t i = 0; i < 2 * refC_size * batches; i += 2) {
196  ((double *)refC)[i] = rand() / (double)RAND_MAX;
197  ((double *)refCcopy)[i] = ((double *)refC)[i];
198  }
199 
200  // Populate the imaginary part with rands if needed
201  if (test_data_type == QUDA_BLAS_DATATYPE_C || test_data_type == QUDA_BLAS_DATATYPE_Z) {
202  for (uint64_t i = 1; i < 2 * refA_size * batches; i += 2) { ((double *)refA)[i] = rand() / (double)RAND_MAX; }
203  for (uint64_t i = 1; i < 2 * refB_size * batches; i += 2) { ((double *)refB)[i] = rand() / (double)RAND_MAX; }
204  for (uint64_t i = 1; i < 2 * refC_size * batches; i += 2) {
205  ((double *)refC)[i] = rand() / (double)RAND_MAX;
206  ((double *)refCcopy)[i] = ((double *)refC)[i];
207  }
208  }
209 
210  // Create new arrays appropriate for the requested problem, and copy over the data.
211  void *arrayA = nullptr;
212  void *arrayB = nullptr;
213  void *arrayC = nullptr;
214  void *arrayCcopy = nullptr;
215 
216  switch (test_data_type) {
217  case QUDA_BLAS_DATATYPE_S:
218  arrayA = pinned_malloc(batches * refA_size * sizeof(float));
219  arrayB = pinned_malloc(batches * refB_size * sizeof(float));
220  arrayC = pinned_malloc(batches * refC_size * sizeof(float));
221  arrayCcopy = pinned_malloc(batches * refC_size * sizeof(float));
222  // Populate
223  for (uint64_t i = 0; i < 2 * refA_size * batches; i += 2) { ((float *)arrayA)[i / 2] = ((double *)refA)[i]; }
224  for (uint64_t i = 0; i < 2 * refB_size * batches; i += 2) { ((float *)arrayB)[i / 2] = ((double *)refB)[i]; }
225  for (uint64_t i = 0; i < 2 * refC_size * batches; i += 2) {
226  ((float *)arrayC)[i / 2] = ((double *)refC)[i];
227  ((float *)arrayCcopy)[i / 2] = ((double *)refC)[i];
228  }
229  break;
230  case QUDA_BLAS_DATATYPE_D:
231  arrayA = pinned_malloc(batches * refA_size * sizeof(double));
232  arrayB = pinned_malloc(batches * refB_size * sizeof(double));
233  arrayC = pinned_malloc(batches * refC_size * sizeof(double));
234  arrayCcopy = pinned_malloc(batches * refC_size * sizeof(double));
235  // Populate
236  for (uint64_t i = 0; i < 2 * refA_size * batches; i += 2) { ((double *)arrayA)[i / 2] = ((double *)refA)[i]; }
237  for (uint64_t i = 0; i < 2 * refB_size * batches; i += 2) { ((double *)arrayB)[i / 2] = ((double *)refB)[i]; }
238  for (uint64_t i = 0; i < 2 * refC_size * batches; i += 2) {
239  ((double *)arrayC)[i / 2] = ((double *)refC)[i];
240  ((double *)arrayCcopy)[i / 2] = ((double *)refC)[i];
241  }
242  break;
243  case QUDA_BLAS_DATATYPE_C:
244  arrayA = pinned_malloc(batches * refA_size * 2 * sizeof(float));
245  arrayB = pinned_malloc(batches * refB_size * 2 * sizeof(float));
246  arrayC = pinned_malloc(batches * refC_size * 2 * sizeof(float));
247  arrayCcopy = pinned_malloc(batches * refC_size * 2 * sizeof(float));
248  // Populate
249  for (uint64_t i = 0; i < 2 * refA_size * batches; i++) { ((float *)arrayA)[i] = ((double *)refA)[i]; }
250  for (uint64_t i = 0; i < 2 * refB_size * batches; i++) { ((float *)arrayB)[i] = ((double *)refB)[i]; }
251  for (uint64_t i = 0; i < 2 * refC_size * batches; i++) {
252  ((float *)arrayC)[i] = ((double *)refC)[i];
253  ((float *)arrayCcopy)[i] = ((double *)refC)[i];
254  }
255  break;
256  case QUDA_BLAS_DATATYPE_Z:
257  arrayA = pinned_malloc(batches * refA_size * 2 * sizeof(double));
258  arrayB = pinned_malloc(batches * refB_size * 2 * sizeof(double));
259  arrayC = pinned_malloc(batches * refC_size * 2 * sizeof(double));
260  arrayCcopy = pinned_malloc(batches * refC_size * 2 * sizeof(double));
261  // Populate
262  for (uint64_t i = 0; i < 2 * refA_size * batches; i++) { ((double *)arrayA)[i] = ((double *)refA)[i]; }
263  for (uint64_t i = 0; i < 2 * refB_size * batches; i++) { ((double *)arrayB)[i] = ((double *)refB)[i]; }
264  for (uint64_t i = 0; i < 2 * refC_size * batches; i++) {
265  ((double *)arrayC)[i] = ((double *)refC)[i];
266  ((double *)arrayCcopy)[i] = ((double *)refC)[i];
267  }
268  break;
269  default: errorQuda("Unrecognised data type %d\n", test_data_type);
270  }
271 
272  // Perform the device-side BLAS GEMM operation
273  blasGEMMQuda(arrayA, arrayB, arrayC, native_blas_lapack ? QUDA_BOOLEAN_TRUE : QUDA_BOOLEAN_FALSE, &blas_param);
274 
275  double deviation = 0.0;
276  if (verify_results) {
277  deviation = blasGEMMQudaVerify(arrayA, arrayB, arrayC, arrayCcopy, refA_size, refB_size, refC_size, &blas_param);
278  }
279 
280  host_free(refA);
281  host_free(refB);
282  host_free(refC);
283  host_free(refCcopy);
284 
285  host_free(arrayA);
286  host_free(arrayB);
287  host_free(arrayC);
288  host_free(arrayCcopy);
289 
290  return deviation;
291 }
292 
293 // The following tests exercise each BLAS data type and precision using the Google Test framework
294 using ::testing::Bool;
295 using ::testing::Combine;
296 using ::testing::Range;
297 using ::testing::TestWithParam;
298 using ::testing::Values;
299 
300 class BLASTest : public ::testing::TestWithParam<int>
301 {
302 protected:
303  int param;
304 
305 public:
306  virtual ~BLASTest() {}
307  virtual void SetUp() { param = GetParam(); }
308 };
309 
310 // Sets up the Google test
311 TEST_P(BLASTest, verify)
312 {
313  auto data_type = GetParam();
314  auto deviation = test(data_type);
315  decltype(deviation) tol;
316  switch (data_type) {
317  case 0:
318  case 2: tol = 10 * std::numeric_limits<float>::epsilon(); break;
319  case 1:
320  case 3: tol = 10 * std::numeric_limits<double>::epsilon(); break;
321  }
322  EXPECT_LE(deviation, tol) << "CPU and CUDA implementations do not agree";
323 }
324 
325 // Helper function to construct the test name
326 std::string getBLASName(testing::TestParamInfo<int> param)
327 {
328  int data_type = param.param;
329  std::string str(data_type_str[data_type]);
330  return str;
331 }
332 
333 // Instantiate all test cases
334 INSTANTIATE_TEST_SUITE_P(QUDA, BLASTest, Range(0, 4), getBLASName);
335 
336 int main(int argc, char **argv)
337 {
338  // Start Google Test Suite
339  //-----------------------------------------------------------------------------
340  ::testing::InitGoogleTest(&argc, argv);
341 
342  // QUDA initialise
343  //-----------------------------------------------------------------------------
344  // command line options
345  auto app = make_app();
346  try {
347  app->parse(argc, argv);
348  } catch (const CLI::ParseError &e) {
349  return app->exit(e);
350  }
351 
352  // initialize QMP/MPI, QUDA comms grid and RNG (host_utils.cpp)
353  initComms(argc, argv, gridsize_from_cmdline);
354 
355  // Ensure gtest prints only from rank 0
356  ::testing::TestEventListeners &listeners = ::testing::UnitTest::GetInstance()->listeners();
357  if (comm_rank() != 0) { delete listeners.Release(listeners.default_result_printer()); }
358 
359  // call srand() with a rank-dependent seed
360  initRand();
361  setQudaPrecisions();
362  setVerbosity(verbosity);
363  display_test_info();
364 
365  // initialize the QUDA library
366  initQuda(device_ordinal);
367  int X[4] = {xdim, ydim, zdim, tdim};
368  setDims(X);
369  //-----------------------------------------------------------------------------
370 
371  int result = 0;
372  if (verify_results) {
373  // Run the full set of tests if we're doing a verification run
374  ::testing::TestEventListeners &listeners = ::testing::UnitTest::GetInstance()->listeners();
375  if (comm_rank() != 0) { delete listeners.Release(listeners.default_result_printer()); }
376  result = RUN_ALL_TESTS();
377  if (result) warningQuda("Google tests for QUDA BLAS failed.");
378  } else {
379  // Perform the BLAS op specified by the command line
380  switch (blas_data_type) {
381  case QUDA_BLAS_DATATYPE_S: test(0); break;
382  case QUDA_BLAS_DATATYPE_D: test(1); break;
383  case QUDA_BLAS_DATATYPE_C: test(2); break;
384  case QUDA_BLAS_DATATYPE_Z: test(3); break;
385  default: errorQuda("Undefined QUDA BLAS data type %d\n", blas_data_type);
386  }
387  }
388 
389  //-----------------------------------------------------------------------------
390 
391  // finalize the QUDA library
392  endQuda();
393 
394  // finalize the communications layer
395  finalizeComms();
396 
397  return result;
398 }
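
As a concrete illustration of the size bookkeeping in test() above, take hypothetical parameters m = 4, n = 3, k = 2, lda = ldb = ldc = 8, no transposes, column-major data order, batch_count = 2, and all offsets zero:

  refA_size = lda * k = 8 * 2 = 16 elements   (A_mk, lda >= m)
  refB_size = ldb * n = 8 * 3 = 24 elements   (B_kn, ldb >= k)
  refC_size = ldc * n = 8 * 3 = 24 elements   (C_mn, ldc >= m)

Each reference element is stored as a re/im pair of doubles (data_size = 2 * sizeof(double) = 16 bytes), and with zero offsets no extra matrix is appended, so batches = batch_count = 2 and refA, for example, is allocated with batches * refA_size * data_size = 2 * 16 * 16 = 512 bytes.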
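
Stripped of the test harness, the host-side call sequence this file exercises reduces to filling in a QudaBLASParam and calling blasGEMMQuda. The sketch below shows a single double-complex GEMM; it is a minimal sketch only, and the single-GPU initQuda(0) call, the zero offsets, the unit batch strides, the problem size, and the use of plain host allocations instead of pinned_malloc are assumptions rather than anything this test prescribes.

#include <cstdlib>
#include <vector>
#include <quda.h>

int main()
{
  initQuda(0); // assumption: single GPU, device 0; multi-GPU runs also need the comms setup shown in main() above

  const int m = 8, n = 8, k = 8;

  QudaBLASParam blas_param = newQudaBLASParam();
  blas_param.trans_a = QUDA_BLAS_OP_N;
  blas_param.trans_b = QUDA_BLAS_OP_N;
  blas_param.m = m;
  blas_param.n = n;
  blas_param.k = k;
  blas_param.lda = m; // column-major, OP_N: lda >= max(1, m)
  blas_param.ldb = k; // column-major, OP_N: ldb >= max(1, k)
  blas_param.ldc = m; // column-major:       ldc >= max(1, m)
  blas_param.a_offset = 0;
  blas_param.b_offset = 0;
  blas_param.c_offset = 0;
  blas_param.a_stride = 1; // assumption: unit batch strides (not exercised here since batch_count = 1)
  blas_param.b_stride = 1;
  blas_param.c_stride = 1;
  blas_param.alpha = (__complex__ double)1.0; // C = alpha * A * B + beta * C
  blas_param.beta = (__complex__ double)0.0;
  blas_param.data_order = QUDA_BLAS_DATAORDER_COL;
  blas_param.data_type = QUDA_BLAS_DATATYPE_Z; // interleaved re/im doubles
  blas_param.batch_count = 1;

  // Host buffers: two doubles (re, im) per element, sized as in test() above
  std::vector<double> A(2 * blas_param.lda * k, 0.0);
  std::vector<double> B(2 * blas_param.ldb * n, 0.0);
  std::vector<double> C(2 * blas_param.ldc * n, 0.0);
  for (std::size_t i = 0; i < A.size(); i += 2) A[i] = rand() / (double)RAND_MAX;
  for (std::size_t i = 0; i < B.size(); i += 2) B[i] = rand() / (double)RAND_MAX;

  // Strided-batched GEMM on the device; the fourth argument selects the native BLAS/LAPACK backend
  blasGEMMQuda(A.data(), B.data(), C.data(), QUDA_BOOLEAN_TRUE, &blas_param);

  endQuda();
  return 0;
}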