QUDA  v1.1.0
A library for QCD on GPUs
staggered_dslash_ctest.cpp
Go to the documentation of this file.
2 
3 using namespace quda;
4 
6 
// File-scope flag, initialised to false.
// NOTE(review): no reader or writer of this flag is visible in this listing.
7 bool gauge_loaded = false;
8 
// Human-readable precision labels. The test-name generator in this file indexes
// this array with the integer precision test parameter (0..3).
9 const char *prec_str[] = {"quarter", "half", "single", "double"};
// Presumably labels for the gauge reconstruct options.
// NOTE(review): not referenced in the visible part of this listing — confirm it is still needed.
10 const char *recon_str[] = {"r18", "r13", "r9"};
11 
// Test-local initialisation helper.
// Forwards the precision index, link reconstruct type and partition value
// unchanged to dslash_test_wrapper.init_ctest().
12 void init(int precision, QudaReconstructType link_recon, int partition)
13 {
14  dslash_test_wrapper.init_ctest(precision, link_recon, partition);
15 }
16 
18 
20 {
 // NOTE(review): elsewhere in this file the integer precision index is treated as
 // 0 = quarter, 1 = half, 2 = single, 3 = double (see prec_str and the
 // "fixed precision" checks for 0/1 in StaggeredDslashTest::skip()). The mapping
 // below reports 2 as double, 1 as single and everything else as half, so the
 // printed precision appears to be off — confirm whether getPrecision(precision)
 // was intended here.
21  auto prec = precision == 2 ? QUDA_DOUBLE_PRECISION : precision == 1 ? QUDA_SINGLE_PRECISION : QUDA_HALF_PRECISION;
22 
 // Print a column header, then the values for the current parameter combination.
23  printfQuda("prec recon test_type dagger S_dim T_dimension\n");
24  printfQuda("%s %s %s %d %d/%d/%d %d \n", get_prec_str(prec), get_recon_str(link_recon),
26 }
27 
31 using ::testing::TestWithParam;
33 
// Value-parameterised googletest fixture for the staggered dslash tests.
// Each instance receives a tuple of three ints, read through GetParam():
//   <0> precision index (0 and 1 are treated as fixed precision by skip()),
//   <1> value cast to QudaReconstructType,
//   <2> partition value, bit-tested in SetUp() and passed on to init().
34 class StaggeredDslashTest : public ::testing::TestWithParam<::testing::tuple<int, int, int>>
35 {
36 protected:
 // NOTE(review): the visible methods read their parameters via GetParam(), not via this member.
37  ::testing::tuple<int, int, int> param;
38 
 // Returns true when the current parameter combination must not be run.
39  bool skip()
40  {
41  QudaReconstructType recon = static_cast<QudaReconstructType>(::testing::get<1>(GetParam()));
42 
 // Skip when the precision or the reconstruct type has no bit set in
 // QUDA_PRECISION / QUDA_RECONSTRUCT (presumably the sets enabled in this build — confirm).
43  if ((QUDA_PRECISION & getPrecision(::testing::get<0>(GetParam()))) == 0
44  || (QUDA_RECONSTRUCT & getReconstructNibble(recon)) == 0) {
45  return true;
46  }
47 
 // Precision indices 0 and 1 are rejected for the fat/long compute path
 // (the first line of this condition is missing from this listing).
49  && (::testing::get<0>(GetParam()) == 0 || ::testing::get<0>(GetParam()) == 1)) {
50  warningQuda("Fixed precision unsupported in fat/long compute, skipping...");
51  return true;
52  }
53 
 // Reconstruct 9 is rejected for the fat/long compute path
 // (the condition line is missing from this listing).
55  warningQuda("Reconstruct 9 unsupported in fat/long compute, skipping...");
56  return true;
57  }
58 
 // Precision indices 0 and 1 are rejected for the Laplace operator.
59  if (dslash_type == QUDA_LAPLACE_DSLASH && (::testing::get<0>(GetParam()) == 0 || ::testing::get<0>(GetParam()) == 1)) {
60  warningQuda("Fixed precision unsupported for Laplace operator, skipping...");
61  return true;
62  }
63 
 // A non-zero partition value is skipped when split-grid testing is enabled.
64  if (::testing::get<2>(GetParam()) > 0 && dslash_test_wrapper.test_split_grid) { return true; }
65  return false;
66  }
67 
68 public:
69  virtual ~StaggeredDslashTest() { }
 // Per-test set-up: bail out for skipped combinations, otherwise process the
 // partition bits, call updateR(), initialise the test and print its parameters.
70  virtual void SetUp() {
71  int prec = ::testing::get<0>(GetParam());
72  QudaReconstructType recon = static_cast<QudaReconstructType>(::testing::get<1>(GetParam()));
73 
74  if (skip()) GTEST_SKIP();
75 
 // The partition value is examined bit by bit (bits 0..3).
 // NOTE(review): the statement executed for a set bit is missing from this
 // listing (presumably commDimPartitionedSet(j) — confirm).
76  int value = ::testing::get<2>(GetParam());
77  for(int j=0; j < 4;j++){
78  if (value & (1 << j)){
80  }
81 
82  }
83  updateR();
84 
85  init(prec, recon, value);
86  display_test_info(prec, recon);
87  }
88 
 // Per-test tear-down. skip() is checked again before end() is called,
 // presumably so end() is not run for combinations whose SetUp() skipped before init().
89  virtual void TearDown()
90  {
91  if (skip()) GTEST_SKIP();
92  end();
93  }
94 
96 
97  // Per-test-case tear-down.
98  // Called after the last test in this test case.
99  // Can be omitted if not needed.
100  static void TearDownTestCase() { endQuda(); }
101 };
102 
104 {
 // deviation keeps its initial value of 1.0 unless a reference spinor exists
 // and verify() is called below.
105  double deviation = 1.0;
107  // check for skip_kernel
109  if (dslash_test_wrapper.spinorRef != nullptr) {
111  deviation = dslash_test_wrapper.verify();
112  }
 // The test fails when the deviation exceeds tol
 // (the declaration of tol is not visible in this listing).
113  ASSERT_LE(deviation, tol) << "CPU and CUDA implementations do not agree";
114 }
115 
117 
// Test driver entry point: initialises googletest, parses the command line,
// initialises communications, validates the dslash-related options, runs all
// tests, frees the loaded gauge links and returns the googletest result.
// NOTE(review): several statements of this function are missing from this listing.
118 int main(int argc, char **argv)
119 {
120  // hack for loading gauge fields
123 
125 
126  // initialize google test
127  ::testing::InitGoogleTest(&argc, argv);
128  auto app = make_app();
 // Register the test-specific "--test" option; its value is mapped through dtest_type_map.
129  app->add_option("--test", dtest_type, "Test method")->transform(CLI::CheckedTransformer(dtest_type_map));
 // On a parse error, exit with the code produced by the CLI app.
131  try {
132  app->parse(argc, argv);
133  } catch (const CLI::ParseError &e) {
134  return app->exit(e);
135  }
136  initComms(argc, argv, gridsize_from_cmdline);
137 
138  // Ensure gtest prints only from rank 0
140  if (comm_rank() != 0) { delete listeners.Release(listeners.default_result_printer()); }
141 
142  // Only certain fermion types are supported in this file. If an unsupported
143  // dslash_type was requested, report it and fall back to the default (improved staggered).
145  printfQuda("dslash_type %s not supported, defaulting to %s\n", get_dslash_str(dslash_type),
148  }
149 
150  // Sanity check: if you pass in a gauge field, want to test the asqtad/hisq dslash, and don't
151  // ask to build the fat/long links... it doesn't make sense.
 // NOTE(review): if errorQuda terminates the program, the assignment to
 // compute_fatlong below is never reached — confirm whether a warning was intended.
152  if (strcmp(latfile, "") && !compute_fatlong && dslash_type == QUDA_ASQTAD_DSLASH) {
153  errorQuda(
154  "Cannot load a gauge field and test the ASQTAD/HISQ operator without setting \"--compute-fat-long true\".\n");
155  compute_fatlong = true;
156  }
157 
158  // Set n_naiks to 2 if eps_naik != 0.0
 // This only happens when fat/long links are computed; otherwise eps_naik is reset to 0.
160  if (eps_naik != 0.0) {
161  if (compute_fatlong) {
162  n_naiks = 2;
163  printfQuda("Note: epsilon-naik != 0, testing epsilon correction links.\n");
164  } else {
165  eps_naik = 0.0;
166  printfQuda("Not computing fat-long, ignoring epsilon correction.\n");
167  }
168  } else {
169  printfQuda("Note: epsilon-naik = 0, testing original HISQ links.\n");
170  }
171  }
172 
 // Some test types are rejected for the Laplace operator
 // (the enclosing conditions are missing from this listing).
175  errorQuda("Test type %s is not supported for the Laplace operator.\n",
177  }
178  }
179 
180  // Run all tests; the result is kept so it can be returned after cleanup
181  int test_rc = RUN_ALL_TESTS();
182 
183  // Clean up loaded gauge field
184  for (int dir = 0; dir < 4; dir++) {
185  if (dslash_test_wrapper.qdp_inlink[dir] != nullptr) {
186  free(dslash_test_wrapper.qdp_inlink[dir]);
187  dslash_test_wrapper.qdp_inlink[dir] = nullptr;
188  }
189  }
190 
192 
193  finalizeComms();
194 
195  return test_rc;
196 }
197 
 // Builds the display name of one parameterised test instance from its
 // (precision index, reconstruct, partition) tuple, in the form
 // "<prec_str[prec]>_r<recon>_partition<part>".
 // The precision is written as its label; recon and part are written as integers.
199  const int prec = ::testing::get<0>(param.param);
200  const int recon = ::testing::get<1>(param.param);
201  const int part = ::testing::get<2>(param.param);
202  std::stringstream ss;
203  // ss << get_dslash_str(dslash_type) << "_";
204  ss << prec_str[prec];
205  ss << "_r" << recon;
206  ss << "_partition" << part;
207  return ss.str();
208  }
209 
210 #ifdef MULTI_GPU
212  Combine(Range(0, 4),
214  Range(0, 16)),
216 #else
218  Combine(Range(0, 4),
220  ::testing::Values(0)),
222 #endif
double benchmark(Kernel kernel, const int niter)
Definition: blas_test.cpp:349
::testing::tuple< int, int, int > param
TestEventListener * Release(TestEventListener *listener)
TestEventListener * default_result_printer() const
Definition: gtest.h:1186
TestEventListeners & listeners()
static UnitTest * GetInstance()
int comm_rank(void)
void commDimPartitionedSet(int dir)
std::shared_ptr< QUDAApp > make_app(std::string app_description, std::string app_name)
double tol
QudaReconstructType link_recon
int niter
int device_ordinal
int & ydim
char latfile[256]
int & zdim
double eps_naik
QudaDslashType dslash_type
bool compute_fatlong
QudaPrecision prec
int n_naiks
int & tdim
int & xdim
void add_comms_option_group(std::shared_ptr< QUDAApp > quda_app)
std::array< int, 4 > gridsize_from_cmdline
bool dagger
std::string get_string(CLI::TransformPairs< T > &map, T val)
CLI::TransformPairs< dslash_test_type > dtest_type_map
@ QUDA_STAGGERED_DSLASH
Definition: enum_quda.h:97
@ QUDA_ASQTAD_DSLASH
Definition: enum_quda.h:98
@ QUDA_LAPLACE_DSLASH
Definition: enum_quda.h:101
@ QUDA_RECONSTRUCT_NO
Definition: enum_quda.h:70
@ QUDA_RECONSTRUCT_12
Definition: enum_quda.h:71
@ QUDA_RECONSTRUCT_8
Definition: enum_quda.h:72
enum QudaReconstructType_s QudaReconstructType
@ QUDA_DOUBLE_PRECISION
Definition: enum_quda.h:65
@ QUDA_SINGLE_PRECISION
Definition: enum_quda.h:64
@ QUDA_HALF_PRECISION
Definition: enum_quda.h:63
#define GTEST_SKIP()
Definition: gtest.h:1887
#define ASSERT_LE(val1, val2)
Definition: gtest.h:2055
int RUN_ALL_TESTS() GTEST_MUST_USE_RESULT_
Definition: gtest.h:2468
void initComms(int argc, char **argv, std::array< int, 4 > &commDims)
Definition: host_utils.cpp:255
void finalizeComms()
Definition: host_utils.cpp:292
int getReconstructNibble(QudaReconstructType recon)
Definition: host_utils.h:233
double getTolerance(QudaPrecision prec)
Definition: host_utils.h:245
QudaPrecision getPrecision(int i)
Definition: host_utils.h:222
const char * get_prec_str(QudaPrecision prec)
Definition: misc.cpp:26
const char * get_dslash_str(QudaDslashType type)
Definition: misc.cpp:118
const char * get_recon_str(QudaReconstructType recon)
Definition: misc.cpp:68
::std::string string
Definition: gtest-port.h:891
internal::ParamGenerator< T > Range(T start, T end, IncrementT step)
internal::CartesianProductHolder< Generator... > Combine(const Generator &... g)
internal::ValueArray< T... > Values(T... v)
internal::ParamGenerator< bool > Bool()
GTEST_API_ void InitGoogleTest(int *argc, char **argv)
QudaGaugeParam param
Definition: pack_test.cpp:18
void initQuda(int device)
void endQuda(void)
void updateR()
update the radius for halos.
INSTANTIATE_TEST_SUITE_P(QUDA, StaggeredDslashTest, Combine(Range(0, 4), ::testing::Values(QUDA_RECONSTRUCT_NO, QUDA_RECONSTRUCT_12, QUDA_RECONSTRUCT_8), ::testing::Values(0)), getstaggereddslashtestname)
TEST_P(StaggeredDslashTest, verify)
const char * recon_str[]
int main(int argc, char **argv)
std::string getstaggereddslashtestname(testing::TestParamInfo<::testing::tuple< int, int, int >> param)
void end()
bool gauge_loaded
void display_test_info(int precision, QudaReconstructType link_recon)
const char * prec_str[]
StaggeredDslashTestWrapper dslash_test_wrapper
void init(int precision, QudaReconstructType link_recon, int partition)
dslash_test_type dtest_type
QudaPrecision cuda_prec
Definition: quda.h:238
void run_test(int niter, bool print_metrics=false)
void init_ctest(int precision, QudaReconstructType link_recon_, int partition)
#define printfQuda(...)
Definition: util_quda.h:114
#define warningQuda(...)
Definition: util_quda.h:132
#define errorQuda(...)
Definition: util_quda.h:120