QUDA  v1.1.0
A library for QCD on GPUs
dslash_test.cpp
Go to the documentation of this file.
1 #include "dslash_test_utils.h"
2 
3 using namespace quda;
4 
6 
8 {
9  printfQuda("running the following test:\n");
10 
11  printfQuda("prec recon dtest_type matpc_type dagger S_dim T_dimension Ls_dimension "
12  "dslash_type niter\n");
13  printfQuda("%6s %2s %s %12s %d %3d/%3d/%3d %3d %2d %14s %d\n",
17  printfQuda("Grid partition info: X Y Z T\n");
18  printfQuda(" %d %d %d %d\n", dimPartitioned(0), dimPartitioned(1), dimPartitioned(2),
19  dimPartitioned(3));
20 
22  printfQuda("Testing with split grid: %d %d %d %d\n", grid_partition[0], grid_partition[1], grid_partition[2],
23  grid_partition[3]);
24  }
25 }
26 
27 TEST(dslash, verify)
28 {
29  double deviation = dslash_test_wrapper.verify();
31  // If we are using tensor cores, we tolerate a greater deviation
33  tol *= 10;
35  tol *= 10; // if recon 8, we tolerate a greater deviation
36 
37  ASSERT_LE(deviation, tol) << "CPU and CUDA implementations do not agree";
38 }
39 
40 int main(int argc, char **argv)
41 {
42  // initialize google test, includes command line options
43  ::testing::InitGoogleTest(&argc, argv);
44 
45  // return code for google test
46  int test_rc = 0;
47  // command line options
48  auto app = make_app();
49  app->add_option("--test", dslash_test_wrapper.dtest_type, "Test method")
50  ->transform(CLI::CheckedTransformer(dtest_type_map));
53 
54  try {
55  app->parse(argc, argv);
56  } catch (const CLI::ParseError &e) {
57  return app->exit(e);
58  }
59 
60  initComms(argc, argv, gridsize_from_cmdline);
61 
62  // Ensure gtest prints only from rank 0
64  if (comm_rank() != 0) { delete listeners.Release(listeners.default_result_printer()); }
65 
67  dslash_test_wrapper.init_test(argc, argv);
68 
70 
71  int attempts = 1;
73  for (int i=0; i<attempts; i++) {
75  if (verify_results) {
77  if (comm_rank() != 0) { delete listeners.Release(listeners.default_result_printer()); }
78 
79  test_rc = RUN_ALL_TESTS();
80  if (test_rc != 0) warningQuda("Tests failed");
81  }
82  }
84 
85  endQuda();
86 
87  finalizeComms();
88  return test_rc;
89 }
TestEventListener * Release(TestEventListener *listener)
TestEventListener * default_result_printer() const
Definition: gtest.h:1186
TestEventListeners & listeners()
static UnitTest * GetInstance()
int comm_rank(void)
std::shared_ptr< QUDAApp > make_app(std::string app_description, std::string app_name)
double tol
QudaReconstructType link_recon
int niter
int device_ordinal
int & ydim
std::array< int, 4 > grid_partition
bool verify_results
void add_eofa_option_group(std::shared_ptr< QUDAApp > quda_app)
int & zdim
QudaDslashType dslash_type
QudaMatPCType matpc_type
QudaPrecision prec
int Lsdim
int & tdim
int & xdim
void add_comms_option_group(std::shared_ptr< QUDAApp > quda_app)
std::array< int, 4 > gridsize_from_cmdline
bool dagger
std::string get_string(CLI::TransformPairs< T > &map, T val)
int main(int argc, char **argv)
Definition: dslash_test.cpp:40
TEST(dslash, verify)
Definition: dslash_test.cpp:27
DslashTestWrapper dslash_test_wrapper
Definition: dslash_test.cpp:5
void display_test_info()
Definition: dslash_test.cpp:7
CLI::TransformPairs< dslash_test_type > dtest_type_map
@ QUDA_MOBIUS_DWF_DSLASH
Definition: enum_quda.h:95
@ QUDA_RECONSTRUCT_8
Definition: enum_quda.h:72
#define ASSERT_LE(val1, val2)
Definition: gtest.h:2055
int RUN_ALL_TESTS() GTEST_MUST_USE_RESULT_
Definition: gtest.h:2468
int dimPartitioned(int dim)
Definition: host_utils.cpp:376
void initComms(int argc, char **argv, std::array< int, 4 > &commDims)
Definition: host_utils.cpp:255
void finalizeComms()
Definition: host_utils.cpp:292
double getTolerance(QudaPrecision prec)
Definition: host_utils.h:245
const char * get_matpc_str(QudaMatPCType type)
Definition: misc.cpp:200
const char * get_prec_str(QudaPrecision prec)
Definition: misc.cpp:26
const char * get_dslash_str(QudaDslashType type)
Definition: misc.cpp:118
const char * get_recon_str(QudaReconstructType recon)
Definition: misc.cpp:68
GTEST_API_ void InitGoogleTest(int *argc, char **argv)
void initQuda(int device)
void endQuda(void)
void init_test(int argc, char **argv)
dslash_test_type dtest_type
void run_test(int niter, bool print_metrics=false)
QudaGaugeParam gauge_param
QudaInvertParam inv_param
QudaReconstructType reconstruct
Definition: quda.h:49
QudaPrecision cuda_prec
Definition: quda.h:238
#define printfQuda(...)
Definition: util_quda.h:114
#define warningQuda(...)
Definition: util_quda.h:132