QUDA  v1.1.0
A library for QCD on GPUs
dslash_ctest.cpp
Go to the documentation of this file.
1 #include "dslash_test_utils.h"
2 
3 using namespace quda;
4 
6 
7 // For loading the gauge fields
// Global copy of main()'s argv; main() assigns it before running the tests.
9 char **argv_copy;
10 
// Human-readable precision names, indexed by the integer precision parameter
// of the test tuple (used when building the generated test name).
11 const char *prec_str[] = {"quarter", "half", "single", "double"};
// Human-readable reconstruct names; not referenced in the visible part of this
// listing.
12 const char *recon_str[] = {"r18", "r12", "r8"};
13 
14 // For googletest, test names must be non-empty, unique, and may only contain
15 // ASCII alphanumeric characters or underscores
16 
// Body of what appears to be
//   void display_test_info(int precision, QudaReconstructType link_recon)
// NOTE(review): the signature line (listing line 17) is missing from this
// extract; the name is inferred from the call in DslashTest::SetUp -- confirm.
// Prints a summary of the test configuration through printfQuda.
18 {
19  auto prec = getPrecision(precision);
20  // printfQuda("running the following test:\n");
21 
// Column header, then the row of values. The value arguments of the second
// call (listing lines 24-26) are missing from this extract.
22  printfQuda("prec recon test_type matpc_type dagger S_dim T_dimension Ls_dimension dslash_type niter\n");
23  printfQuda("%6s %2s %s %12s %d %3d/%3d/%3d %3d %2d %14s %d\n",
27  // printfQuda("Grid partition info: X Y Z T\n");
28  // printfQuda(" %d %d %d %d\n",
29  // dimPartitioned(0),
30  // dimPartitioned(1),
31  // dimPartitioned(2),
32  // dimPartitioned(3));
33 
// Reports the four grid_partition entries. Listing line 34 is missing here;
// presumably it opens the conditional that the brace below closes -- confirm.
35  printfQuda("Testing with split grid: %d %d %d %d\n", grid_partition[0], grid_partition[1], grid_partition[2],
36  grid_partition[3]);
37  }
38 
39  return ;
40 
41 }
42 
43 using ::testing::TestWithParam;
48 
// Value-parameterized GoogleTest fixture for the dslash tests.
//
// The test parameter is a tuple<int, int, int>:
//   get<0> - integer precision index (passed to getPrecision; the same index
//            selects an entry of prec_str when the test name is generated),
//   get<1> - reconstruct type, cast to QudaReconstructType,
//   get<2> - integer whose four low bits are examined one by one in SetUp().
//
// NOTE(review): several listing lines (63, 85, 90, 97, 98, 101) are missing
// from this extract, so parts of the bodies below are incomplete.
49 class DslashTest : public ::testing::TestWithParam<::testing::tuple<int, int, int>>
50 {
51 protected:
// Not referenced anywhere in the visible lines of this class.
52  ::testing::tuple<int, int, int> param;
53 
// Returns true when the current parameter combination must not be run.
54  bool skip()
55  {
56  QudaReconstructType recon = static_cast<QudaReconstructType>(::testing::get<1>(GetParam()));
57 
// Skip if the requested precision has no bit set in QUDA_PRECISION, or the
// reconstruct nibble has no bit set in QUDA_RECONSTRUCT.
58  if ((QUDA_PRECISION & getPrecision(::testing::get<0>(GetParam()))) == 0
59  || (QUDA_RECONSTRUCT & getReconstructNibble(recon)) == 0) {
60  return true;
61  }
62 
// The first line of this condition (listing line 63) is missing. The visible
// part matches precision indices 2 and 3 ("single" and "double" in prec_str).
64  && (::testing::get<0>(GetParam()) == 2 || ::testing::get<0>(GetParam()) == 3)) {
65  warningQuda("Only fixed precision supported for MatPCDagMatPCLocal operator, skipping...");
66  return true;
67  }
68 
// Skip any non-zero third parameter when the wrapper's test_split_grid flag is set.
69  if (::testing::get<2>(GetParam()) > 0 && dslash_test_wrapper.test_split_grid) { return true; }
70 
71  return false;
72  }
73 
74 public:
75  virtual ~DslashTest() { }
// Per-test set-up: skips unsupported combinations, walks the bits of the third
// parameter, calls updateR() and prints the configuration.
76  virtual void SetUp() {
77  int prec = ::testing::get<0>(GetParam());
78  QudaReconstructType recon = static_cast<QudaReconstructType>(::testing::get<1>(GetParam()));
79 
80  if (skip()) GTEST_SKIP();
81 
82  int value = ::testing::get<2>(GetParam());
// Test bits 0..3 of the third parameter. The statement executed for a set bit
// (listing line 85) is missing; presumably it marks dimension j as partitioned
// -- TODO confirm against the full file.
83  for(int j=0; j < 4;j++){
84  if (value & (1 << j)){
86  }
87  }
88  updateR();
89 
// Listing line 90 is missing from this extract.
91  display_test_info(prec, recon);
92  }
93 
// Per-test tear-down. Re-evaluates skip() and issues GTEST_SKIP() for skipped
// combinations; the remaining statements (listing lines 97-98) are missing.
94  virtual void TearDown()
95  {
96  if (skip()) GTEST_SKIP();
99  }
100 
102 
103  // Per-test-case tear-down.
104  // Called after the last test in this test case.
105  // Can be omitted if not needed.
106  static void TearDownTestCase() {
107  endQuda();
108  }
109 };
110 
// Body of what appears to be TEST_P(DslashTest, verify).
// NOTE(review): the header line (listing line 111) is missing from this
// extract -- confirm. Asserts that the deviation reported by
// dslash_test_wrapper.verify() does not exceed the tolerance tol.
// The statement that initialises tol (listing lines 113-114) and the
// conditions guarding both "tol *= 10" statements (lines 117, 119, 121-122)
// are missing.
112 {
115 
116  double deviation = dslash_test_wrapper.verify();
118  // If we are using tensor core we tolerate a greater deviation
120  tol *= 10;
123  tol *= 10; // if recon 8, we tolerate a greater deviation
124 
// ASSERT_LE aborts this test on failure and attaches the message below.
125  ASSERT_LE(deviation, tol) << "CPU and CUDA implementations do not agree";
126 }
127 
129 
// Test-program entry point: initialises GoogleTest, parses the command line,
// initialises communications, runs all tests and returns the GoogleTest
// result code.
// NOTE(review): listing lines 140, 148-150 and 160 are missing from this
// extract.
130 int main(int argc, char **argv)
131 {
132  // initialize google test, includes command line options
133  ::testing::InitGoogleTest(&argc, argv);
134  // return code for google test
135  int test_rc = 0;
136  // command line options
137  auto app = make_app();
138  app->add_option("--test", dslash_test_wrapper.dtest_type, "Test method")
139  ->transform(CLI::CheckedTransformer(dtest_type_map));
// On a command-line parse error, return the exit code produced by app->exit(e).
141  try {
142  app->parse(argc, argv);
143  } catch (const CLI::ParseError &e) {
144  return app->exit(e);
145  }
146 
147  initComms(argc, argv, gridsize_from_cmdline);
148 
151 
152  // The 'SetUp()' method of the Google Test class from which DslashTest
153  // is derived has no arguments, but QUDA's implementation requires the
154  // use of argc and argv to set up the test via the function 'init'.
155  // As a workaround, we declare argc_copy and argv_copy as global variables
156  // so that they are visible inside the 'init' function.
157  argc_copy = argc;
158  argv_copy = argv;
159 
// On every rank other than 0, remove the default GoogleTest result printer
// from the listener list and delete it. The declaration of 'listeners'
// (listing line 160) is missing from this extract.
161  if (comm_rank() != 0) { delete listeners.Release(listeners.default_result_printer()); }
162  test_rc = RUN_ALL_TESTS();
163 
164  finalizeComms();
165  return test_rc;
166 }
167 
// Body of what appears to be getdslashtestname, the test-name generator.
// NOTE(review): the signature line (listing line 168) is missing from this
// extract -- confirm.
// Builds the name "<prec_str[prec]>_r<recon>_partition<part>" from the three
// tuple elements in param.param and returns it as a std::string. recon is
// streamed as its integer value.
169 {
170  const int prec = ::testing::get<0>(param.param);
171  const int recon = ::testing::get<1>(param.param);
172  const int part = ::testing::get<2>(param.param);
173  std::stringstream ss;
174  // std::cout << "getdslashtestname" << get_dslash_str(dslash_type) << "_" << prec_str[prec] << "_r" << recon <<
175  // "_partition" << part << std::endl; ss << get_dslash_str(dslash_type) << "_";
176  ss << prec_str[prec];
177  ss << "_r" << recon;
178  ss << "_partition" << part;
179  return ss.str();
180 }
181 
// Parameter generators for the DslashTest suite.
// NOTE(review): the INSTANTIATE_TEST_SUITE_P lines and the middle generator
// (listing lines 183, 185, 187, 189, 191, 193) are missing from this extract.
// Visible parts: the first generator is Range(0, 4) in both branches; the
// third is Range(0, 16) when MULTI_GPU is defined and the single value 0
// otherwise.
182 #ifdef MULTI_GPU
184  Combine(Range(0, 4),
186  Range(0, 16)),
188 #else
190  Combine(Range(0, 4),
192  ::testing::Values(0)),
194 #endif
double benchmark(Kernel kernel, const int niter)
Definition: blas_test.cpp:349
static void SetUpTestCase()
virtual void TearDown()
::testing::tuple< int, int, int > param
static void TearDownTestCase()
virtual ~DslashTest()
virtual void SetUp()
TestEventListener * Release(TestEventListener *listener)
TestEventListener * default_result_printer() const
Definition: gtest.h:1186
TestEventListeners & listeners()
static UnitTest * GetInstance()
int comm_rank(void)
void commDimPartitionedReset()
Reset the comm dim partitioned array to zero.
void commDimPartitionedSet(int dir)
std::shared_ptr< QUDAApp > make_app(std::string app_description, std::string app_name)
double tol
QudaReconstructType link_recon
int niter
int device_ordinal
int & ydim
std::array< int, 4 > grid_partition
int & zdim
QudaDslashType dslash_type
QudaMatPCType matpc_type
QudaPrecision prec
int Lsdim
int & tdim
int & xdim
void add_comms_option_group(std::shared_ptr< QUDAApp > quda_app)
std::array< int, 4 > gridsize_from_cmdline
bool dagger
std::string get_string(CLI::TransformPairs< T > &map, T val)
INSTANTIATE_TEST_SUITE_P(QUDA, DslashTest, Combine(Range(0, 4), ::testing::Values(QUDA_RECONSTRUCT_NO, QUDA_RECONSTRUCT_12, QUDA_RECONSTRUCT_8), ::testing::Values(0)), getdslashtestname)
std::string getdslashtestname(testing::TestParamInfo<::testing::tuple< int, int, int >> param)
const char * recon_str[]
int main(int argc, char **argv)
DslashTestWrapper dslash_test_wrapper
Definition: dslash_ctest.cpp:5
void display_test_info(int precision, QudaReconstructType link_recon)
const char * prec_str[]
int argc_copy
Definition: dslash_ctest.cpp:8
TEST_P(DslashTest, verify)
char ** argv_copy
Definition: dslash_ctest.cpp:9
CLI::TransformPairs< dslash_test_type > dtest_type_map
@ QUDA_MOBIUS_DWF_DSLASH
Definition: enum_quda.h:95
@ QUDA_RECONSTRUCT_NO
Definition: enum_quda.h:70
@ QUDA_RECONSTRUCT_12
Definition: enum_quda.h:71
@ QUDA_RECONSTRUCT_8
Definition: enum_quda.h:72
enum QudaReconstructType_s QudaReconstructType
@ QUDA_HALF_PRECISION
Definition: enum_quda.h:63
#define GTEST_SKIP()
Definition: gtest.h:1887
#define ASSERT_LE(val1, val2)
Definition: gtest.h:2055
int RUN_ALL_TESTS() GTEST_MUST_USE_RESULT_
Definition: gtest.h:2468
void initComms(int argc, char **argv, std::array< int, 4 > &commDims)
Definition: host_utils.cpp:255
void finalizeComms()
Definition: host_utils.cpp:292
int getReconstructNibble(QudaReconstructType recon)
Definition: host_utils.h:233
double getTolerance(QudaPrecision prec)
Definition: host_utils.h:245
QudaPrecision getPrecision(int i)
Definition: host_utils.h:222
const char * get_matpc_str(QudaMatPCType type)
Definition: misc.cpp:200
const char * get_prec_str(QudaPrecision prec)
Definition: misc.cpp:26
const char * get_dslash_str(QudaDslashType type)
Definition: misc.cpp:118
const char * get_recon_str(QudaReconstructType recon)
Definition: misc.cpp:68
::std::string string
Definition: gtest-port.h:891
internal::ParamGenerator< T > Range(T start, T end, IncrementT step)
internal::CartesianProductHolder< Generator... > Combine(const Generator &... g)
internal::ValueArray< T... > Values(T... v)
internal::ParamGenerator< bool > Bool()
GTEST_API_ void InitGoogleTest(int *argc, char **argv)
QudaGaugeParam param
Definition: pack_test.cpp:18
void initQuda(int device)
void endQuda(void)
void updateR()
update the radius for halos.
void init_ctest(int argc, char **argv, int precision, QudaReconstructType link_recon)
dslash_test_type dtest_type
void run_test(int niter, bool print_metrics=false)
QudaGaugeParam gauge_param
QudaInvertParam inv_param
QudaReconstructType reconstruct
Definition: quda.h:49
QudaPrecision cuda_prec
Definition: quda.h:238
#define printfQuda(...)
Definition: util_quda.h:114
#define warningQuda(...)
Definition: util_quda.h:132