QUDA  v1.1.0
A library for QCD on GPUs
staggered_dslash_test.cpp
Go to the documentation of this file.
2 
3 using namespace quda;
4 
6 
// Runs the google-test suite when verify_results is set and returns its exit code.
//
// Returns: the value of test_rc, which is initialised to 0 (success) and is
// overwritten with the result of RUN_ALL_TESTS() when verification is enabled.
//
// NOTE(review): the numbering of this listing skips several lines of this
// function (dropped by the documentation extraction). Any setup, dslash
// execution (presumably via dslash_test_wrapper) or teardown done on those
// lines is not visible here.
7 static int dslashTest()
8 {
9  // return code for google test; 0 means success
10  int test_rc = 0;
12 
// Number of passes through the loop below (fixed at one here).
14  int attempts = 1;
15  for (int i = 0; i < attempts; i++) {
// Verification is optional: with verify_results unset, RUN_ALL_TESTS() is never called.
17  if (verify_results) {
// NOTE(review): each pass overwrites test_rc, so with attempts > 1 only the
// last pass's result would be reported.
18  test_rc = RUN_ALL_TESTS();
19  if (test_rc != 0) warningQuda("Tests failed");
20  }
21  }
23 
24  return test_rc;
25 }
26 
// Google-test case "dslash.verify".
// Compares the deviation reported by dslash_test_wrapper.verify() against the
// precision-dependent tolerance from getTolerance(prec), and fails when the
// deviation is larger.
27 TEST(dslash, verify)
28 {
29  double deviation = dslash_test_wrapper.verify();
// NOTE(review): this local appears to shadow the `double tol` listed in this
// page's symbol index. The tolerance used here comes only from getTolerance(prec).
30  double tol = getTolerance(prec);
// Fatal gtest assertion; the streamed text is reported as the failure message.
31  ASSERT_LE(deviation, tol) << "CPU and CUDA implementations do not agree";
32 }
33 
// NOTE(review): the line holding this function's signature was dropped by the
// documentation extraction. The symbol index on this page lists it as
// `void display_test_info()`.
// Prints a summary of the test configuration followed by the grid partitioning.
35 {
36  printfQuda("running the following test:\n");
37  printfQuda("prec recon test_type dagger S_dim T_dimension\n");
// NOTE(review): the continuation line of this call is missing from this page.
// Only the precision and link-reconstruct arguments are visible.
38  printfQuda("%s %s %s %d %d/%d/%d %d \n", get_prec_str(prec), get_recon_str(link_recon),
40  printfQuda("Grid partition info: X Y Z T\n");
// One value per dimension X, Y, Z, T, as returned by dimPartitioned(0..3).
41  printfQuda(" %d %d %d %d\n", dimPartitioned(0), dimPartitioned(1), dimPartitioned(2),
42  dimPartitioned(3));
43 }
44 
// Entry point of the staggered dslash test.
// Sets up google test and the command line, initialises communications,
// validates the requested dslash/test options, runs dslashTest(), and returns
// its result as the process exit code.
//
// NOTE(review): the numbering of this listing skips many lines of main()
// (dropped by the documentation extraction). These include several `if`
// conditions whose bodies and closing braces are still visible below. The
// comments describe only the statements that are shown.
45 int main(int argc, char **argv)
46 {
47  // hack for loading gauge fields (the statements it refers to are missing from this listing)
50 
51  // initialize google test before parsing our own command-line options
52  ::testing::InitGoogleTest(&argc, argv);
53 
54  // command line options
55  auto app = make_app();
// --test selects dtest_type; the string is mapped through dtest_type_map
// (CLI11's CheckedTransformer rejects values that are not in the map).
56  app->add_option("--test", dtest_type, "Test method")->transform(CLI::CheckedTransformer(dtest_type_map));
58  try {
59  app->parse(argc, argv);
60  } catch (const CLI::ParseError &e) {
// CLI11 reports the parse outcome and returns the exit code to use.
61  return app->exit(e);
62  }
63 
64  initComms(argc, argv, gridsize_from_cmdline);
65 
66  updateR();
67 
// NOTE(review): a line is missing here. initQuda(device_ordinal) appears in this
// page's symbol index, so device initialisation presumably happens there - confirm.
69 
70  // Ensure gtest prints only from rank 0
// (the declaration of `listeners` is on a line missing from this listing)
72  if (comm_rank() != 0) { delete listeners.Release(listeners.default_result_printer()); }
73 
74  // Only some dslash types are supported by this test. An unsupported dslash_type
75  // is reported and replaced by a default (improved staggered); the check is not shown.
77  printfQuda("dslash_type %s not supported, defaulting to %s\n", get_dslash_str(dslash_type),
80  }
81 
82  // Sanity check: if you pass in a gauge field, want to test the asqtad/hisq dslash,
83  // and don't ask to build the fat/long links... it doesn't make sense.
// strcmp(latfile, "") is non-zero exactly when a lattice file name was given.
84  if (strcmp(latfile,"") && !compute_fatlong && dslash_type == QUDA_ASQTAD_DSLASH) {
85  errorQuda("Cannot load a gauge field and test the ASQTAD/HISQ operator without setting \"--compute-fat-long true\".\n");
86  }
87 
88  // Non-zero eps_naik: test the epsilon correction links (n_naiks = 2) when fat/long links are computed, else reset eps_naik to 0
90  if (eps_naik != 0.0) {
91  if (compute_fatlong) {
92  n_naiks = 2;
93  printfQuda("Note: epsilon-naik != 0, testing epsilon correction links.\n");
94  } else {
95  eps_naik = 0.0;
96  printfQuda("Not computing fat-long, ignoring epsilon correction.\n");
97  }
98  } else {
99  printfQuda("Note: epsilon-naik = 0, testing original HISQ links.\n");
100  }
// closes an enclosing `if` whose condition is on a line missing from this listing
101  }
102 
// Rejects test types the Laplace operator does not support; the enclosing
// conditions and the remaining message argument are on lines missing from this listing.
105  errorQuda("Test type %s is not supported for the Laplace operator.\n",
107  }
108  }
109 
110  // If we're building fat/long links, half precision is rejected because the
111  // fat/long compute does not support it (the guarding condition is not shown).
113  if (prec < QUDA_SINGLE_PRECISION /* half */) { errorQuda("Half precision unsupported in fat/long compute"); }
114  }
115 
117 
118  // dslashTest() returns the google-test exit code (0 when verification is skipped)
119  int test_rc = dslashTest();
120 
121  endQuda();
122 
123  finalizeComms();
124 
125  return test_rc;
126 }
TestEventListener * Release(TestEventListener *listener)
TestEventListener * default_result_printer() const
Definition: gtest.h:1186
TestEventListeners & listeners()
static UnitTest * GetInstance()
int comm_rank(void)
std::shared_ptr< QUDAApp > make_app(std::string app_description, std::string app_name)
double tol
QudaReconstructType link_recon
int niter
int device_ordinal
int & ydim
bool verify_results
char latfile[256]
int & zdim
double eps_naik
QudaDslashType dslash_type
bool compute_fatlong
QudaPrecision prec
int n_naiks
int & tdim
int & xdim
void add_comms_option_group(std::shared_ptr< QUDAApp > quda_app)
std::array< int, 4 > gridsize_from_cmdline
bool dagger
std::string get_string(CLI::TransformPairs< T > &map, T val)
CLI::TransformPairs< dslash_test_type > dtest_type_map
@ QUDA_STAGGERED_DSLASH
Definition: enum_quda.h:97
@ QUDA_ASQTAD_DSLASH
Definition: enum_quda.h:98
@ QUDA_LAPLACE_DSLASH
Definition: enum_quda.h:101
@ QUDA_SINGLE_PRECISION
Definition: enum_quda.h:64
#define ASSERT_LE(val1, val2)
Definition: gtest.h:2055
int RUN_ALL_TESTS() GTEST_MUST_USE_RESULT_
Definition: gtest.h:2468
int dimPartitioned(int dim)
Definition: host_utils.cpp:376
void initComms(int argc, char **argv, std::array< int, 4 > &commDims)
Definition: host_utils.cpp:255
void finalizeComms()
Definition: host_utils.cpp:292
double getTolerance(QudaPrecision prec)
Definition: host_utils.h:245
const char * get_prec_str(QudaPrecision prec)
Definition: misc.cpp:26
const char * get_dslash_str(QudaDslashType type)
Definition: misc.cpp:118
const char * get_recon_str(QudaReconstructType recon)
Definition: misc.cpp:68
GTEST_API_ void InitGoogleTest(int *argc, char **argv)
void initQuda(int device)
void endQuda(void)
void updateR()
Update the radius for halos.
int main(int argc, char **argv)
TEST(dslash, verify)
void display_test_info()
StaggeredDslashTestWrapper dslash_test_wrapper
dslash_test_type dtest_type
void run_test(int niter, bool print_metrics=false)
#define printfQuda(...)
Definition: util_quda.h:114
#define warningQuda(...)
Definition: util_quda.h:132
#define errorQuda(...)
Definition: util_quda.h:120