QUDA  v1.1.0
A library for QCD on GPUs
blas_test.cpp
Go to the documentation of this file.
1 #include <stdio.h>
2 #include <stdlib.h>
3 
4 #include <quda_internal.h>
5 #include <color_spinor_field.h>
6 #include <blas_quda.h>
7 
8 #include <host_utils.h>
9 #include <command_line_params.h>
10 
11 // include because of nasty globals used in the tests
12 #include <dslash_reference.h>
13 
14 // google test
15 #include <gtest/gtest.h>
16 
17 using namespace quda;
18 
31 // these are pointers to the host fields
32 ColorSpinorField *xH, *yH, *zH, *wH, *vH, *xoH, *yoH, *zoH;
33 
34 // these are pointers to the device fields that have "this precision"
36 
37 // these are pointers to the device multi-fields that have "this precision"
39 
40 // these are pointers to the device fields that have "this precision"
42 
43 // these are pointers to the device multi-fields that have "other precision"
45 
46 // these are pointers to the host multi-fields that have "this precision"
47 std::vector<cpuColorSpinorField *> xmH;
48 std::vector<cpuColorSpinorField *> ymH;
49 std::vector<cpuColorSpinorField *> zmH;
50 int Nspin;
51 int Ncolor;
52 
53 void setPrec(ColorSpinorParam &param, QudaPrecision precision) { param.setPrecision(precision, precision, true); }
54 
56 {
57  printfQuda("running the following test:\n");
58  printfQuda("S_dimension T_dimension Nspin Ncolor\n");
59  printfQuda("%3d /%3d / %3d %3d %d %d\n", xdim, ydim, zdim, tdim, Nspin, Ncolor);
60  printfQuda("Grid partition info: X Y Z T\n");
61  printfQuda(" %d %d %d %d\n", dimPartitioned(0), dimPartitioned(1), dimPartitioned(2),
62  dimPartitioned(3));
63 }
64 
65 using prec_pair_t = std::pair<QudaPrecision, QudaPrecision>;
66 
67 const std::map<QudaPrecision, std::string> prec_map = {{QUDA_QUARTER_PRECISION, "quarter"},
68  {QUDA_HALF_PRECISION, "half"},
69  {QUDA_SINGLE_PRECISION, "single"},
70  {QUDA_DOUBLE_PRECISION, "double"}};
71 
72 const int Nprec = prec_map.size();
73 
74 enum class Kernel {
75  copyHS,
76  copyLS,
77  axpbyz,
78  ax,
79  caxpy,
80  caxpby,
81  cxpaypbz,
82  axpyBzpcx,
83  axpyZpbx,
85  cabxpyAx,
86  caxpyXmaz,
87  norm2,
89  axpbyzNorm,
90  axpyCGNorm,
91  caxpyNorm,
95  caxpyDotzy,
102  axpyReDot,
103  caxpyBxpz,
104  caxpyBzpx,
105  axpy_block,
106  caxpy_block,
113 };
114 
115 // For googletest names must be non-empty, unique, and may only contain ASCII
116 // alphanumeric characters or underscore
117 const std::map<Kernel, std::string> kernel_map
118  = {{Kernel::copyHS, "copyHS"},
119  {Kernel::copyLS, "copyLS"},
120  {Kernel::axpbyz, "axpbyz"},
121  {Kernel::ax, "ax"},
122  {Kernel::caxpy, "caxpy"},
123  {Kernel::caxpby, "caxpby"},
124  {Kernel::cxpaypbz, "cxpaypbz"},
125  {Kernel::axpyBzpcx, "axpyBzpcx"},
126  {Kernel::axpyZpbx, "axpyZpbx"},
127  {Kernel::caxpbypzYmbw, "caxpbypzYmbw"},
128  {Kernel::cabxpyAx, "cabxpyAx"},
129  {Kernel::caxpyXmaz, "caxpyXmaz"},
130  {Kernel::norm2, "norm2"},
131  {Kernel::reDotProduct, "reDotProduct"},
132  {Kernel::axpbyzNorm, "axpbyzNorm"},
133  {Kernel::axpyCGNorm, "axpyCGNorm"},
134  {Kernel::caxpyNorm, "caxpyNorm"},
135  {Kernel::caxpyXmazNormX, "caxpyXmazNormX"},
136  {Kernel::cabxpyzAxNorm, "cabxpyzAxNorm"},
137  {Kernel::cDotProduct, "cDotProduct"},
138  {Kernel::caxpyDotzy, "caxpyDotzy"},
139  {Kernel::cDotProductNormA, "cDotProductNormA"},
140  {Kernel::caxpbypzYmbwcDotProductUYNormY, "caxpbypzYmbwcDotProductUYNormY"},
141  {Kernel::HeavyQuarkResidualNorm, "HeavyQuarkResidualNorm"},
142  {Kernel::xpyHeavyQuarkResidualNorm, "xpyHeavyQuarkResidualNorm"},
143  {Kernel::tripleCGReduction, "tripleCGReduction"},
144  {Kernel::tripleCGUpdate, "tripleCGUpdate"},
145  {Kernel::axpyReDot, "axpyReDot"},
146  {Kernel::caxpyBxpz, "caxpyBxpz"},
147  {Kernel::caxpyBzpx, "caxpyBzpx"},
148  {Kernel::axpy_block, "axpy_block"},
149  {Kernel::caxpy_block, "caxpy_block"},
150  {Kernel::axpyBzpcx_block, "axpyBzpcx_block"},
151  {Kernel::reDotProductNorm_block, "reDotProductNorm_block"},
152  {Kernel::reDotProduct_block, "reDotProduct_block"},
153  {Kernel::cDotProductNorm_block, "cDotProductNorm_block"},
154  {Kernel::cDotProduct_block, "cDotProduct_block"},
155  {Kernel::caxpyXmazMR, "caxpyXmazMR"}};
156 
157 const int Nkernels = kernel_map.size();
158 
159 // kernels that utilize multi-blas
160 bool is_multi(Kernel kernel)
161 {
162  return std::string(kernel_map.at(kernel)).find("_block") != std::string::npos ? true : false;
163 }
164 
165 bool is_copy(Kernel kernel) { return (kernel == Kernel::copyHS || kernel == Kernel::copyLS); }
166 
167 // kernels that require site unrolling
168 bool is_site_unroll(Kernel kernel)
169 {
171 }
172 
173 bool skip_kernel(prec_pair_t pair, Kernel kernel)
174 {
175  auto &this_prec = pair.first;
176  auto &other_prec = pair.second;
177 
178  if ((QUDA_PRECISION & this_prec) == 0) return true;
179  if ((QUDA_PRECISION & other_prec) == 0) return true;
180 
181  // if we've selected a given kernel then make sure we only run that
182  if (test_type != -1 && (int)kernel != test_type) return true;
183 
184  // if we've selected a given precision then make sure we only run that
185  if (prec != QUDA_INVALID_PRECISION && this_prec != prec) return true;
186 
187  // if we've selected a given precision then make sure we only run that
188  if (prec_sloppy != QUDA_INVALID_PRECISION && other_prec != prec_sloppy) return true;
189 
190  if (Nspin == 2 && this_prec < QUDA_SINGLE_PRECISION) {
191  // avoid quarter, half precision tests if doing coarse fields
192  return true;
193  } else if (Ncolor != 3 && is_site_unroll(kernel)) {
194  // only benchmark heavy-quark norm if doing 3 colors
195  return true;
196  }
197 
198  return false;
199 }
200 
201 void initFields(prec_pair_t prec_pair)
202 {
204  param.nColor = Ncolor;
205  param.nSpin = Nspin;
206  param.nDim = 4; // number of spacetime dimensions
207  param.pad = 0; // padding must be zero for cpu fields
208 
209  switch (solve_type) {
211  case QUDA_NORMOP_PC_SOLVE: param.siteSubset = QUDA_PARITY_SITE_SUBSET; break;
212  case QUDA_DIRECT_SOLVE:
213  case QUDA_NORMOP_SOLVE: param.siteSubset = QUDA_FULL_SITE_SUBSET; break;
214  default: errorQuda("Unexpected solve_type=%d\n", solve_type);
215  }
216 
217  if (param.siteSubset == QUDA_PARITY_SITE_SUBSET)
218  param.x[0] = xdim / 2;
219  else
220  param.x[0] = xdim;
221  param.x[1] = ydim;
222  param.x[2] = zdim;
223  param.x[3] = tdim;
224 
225  param.siteOrder = QUDA_EVEN_ODD_SITE_ORDER;
227  param.setPrecision(QUDA_DOUBLE_PRECISION);
229  param.create = QUDA_ZERO_FIELD_CREATE;
230 
236 
237  // all host fields are double precision, so the "other" fields just alias the regular fields
238  xoH = xH;
239  yoH = yH;
240  zoH = zH;
241 
242  xmH.reserve(Nsrc);
243  for (int cid = 0; cid < Nsrc; cid++) xmH.push_back(new cpuColorSpinorField(param));
244  ymH.reserve(Msrc);
245  for (int cid = 0; cid < Msrc; cid++) ymH.push_back(new cpuColorSpinorField(param));
246  zmH.reserve(Nsrc);
247  for (int cid = 0; cid < Nsrc; cid++) zmH.push_back(new cpuColorSpinorField(param));
248 
249  static_cast<cpuColorSpinorField *>(vH)->Source(QUDA_RANDOM_SOURCE, 0, 0, 0);
250  static_cast<cpuColorSpinorField *>(wH)->Source(QUDA_RANDOM_SOURCE, 0, 0, 0);
251  static_cast<cpuColorSpinorField *>(xH)->Source(QUDA_RANDOM_SOURCE, 0, 0, 0);
252  static_cast<cpuColorSpinorField *>(yH)->Source(QUDA_RANDOM_SOURCE, 0, 0, 0);
253  static_cast<cpuColorSpinorField *>(zH)->Source(QUDA_RANDOM_SOURCE, 0, 0, 0);
254  for (int i = 0; i < Nsrc; i++) { static_cast<cpuColorSpinorField *>(xmH[i])->Source(QUDA_RANDOM_SOURCE, 0, 0, 0); }
255  for (int i = 0; i < Msrc; i++) { static_cast<cpuColorSpinorField *>(ymH[i])->Source(QUDA_RANDOM_SOURCE, 0, 0, 0); }
256  // Now set the parameters for the cuda fields
257  // param.pad = xdim*ydim*zdim/2;
258 
259  if (param.nSpin == 4) param.gammaBasis = QUDA_UKQCD_GAMMA_BASIS;
260  param.create = QUDA_ZERO_FIELD_CREATE;
261 
262  QudaPrecision prec = prec_pair.first;
263  QudaPrecision prec_other = prec_pair.second;
264 
265  param.setPrecision(prec, prec, true);
271 
272  param.setPrecision(prec_other, prec_other, true);
278 
279  // create composite fields
280  param.is_composite = true;
281  param.is_component = false;
282 
283  param.setPrecision(prec, prec, true);
284  param.composite_dim = Nsrc;
286 
287  param.composite_dim = Msrc;
289 
290  param.composite_dim = Nsrc;
292 
293  param.setPrecision(prec_other, prec_other, true);
294  param.composite_dim = Nsrc;
296 
297  param.composite_dim = Msrc;
299 
300  param.composite_dim = Nsrc;
302 
303  // only do copy if not doing half precision with mg
304  bool flag = !(param.nSpin == 2 && (prec < QUDA_SINGLE_PRECISION || prec_other < QUDA_HALF_PRECISION));
305 
306  if (flag) {
307  *vD = *vH;
308  *wD = *wH;
309  *xD = *xH;
310  *yD = *yH;
311  *zD = *zH;
312  }
313 }
314 
316 {
317  // release memory
318  delete vD;
319  delete wD;
320  delete xD;
321  delete yD;
322  delete zD;
323  delete voD;
324  delete woD;
325  delete xoD;
326  delete yoD;
327  delete zoD;
328  delete xmD;
329  delete ymD;
330  delete zmD;
331  delete xmoD;
332  delete ymoD;
333  delete zmoD;
334 
335  // release memory
336  delete vH;
337  delete wH;
338  delete xH;
339  delete yH;
340  delete zH;
341  for (int i = 0; i < Nsrc; i++) delete xmH[i];
342  for (int i = 0; i < Msrc; i++) delete ymH[i];
343  for (int i = 0; i < Nsrc; i++) delete zmH[i];
344  xmH.clear();
345  ymH.clear();
346  zmH.clear();
347 }
348 
349 double benchmark(Kernel kernel, const int niter)
350 {
351  double a = 1.0, b = 2.0, c = 3.0;
352  quda::Complex a2, b2;
353  quda::Complex *A = new quda::Complex[Nsrc * Msrc];
354  quda::Complex *B = new quda::Complex[Nsrc * Msrc];
355  quda::Complex *C = new quda::Complex[Nsrc * Msrc];
356  quda::Complex *A2 = new quda::Complex[Nsrc * Nsrc]; // for the block cDotProductNorm test
357  double *Ar = new double[Nsrc * Msrc];
358 
359  cudaEvent_t start, end;
360  cudaEventCreate(&start);
361  cudaEventCreate(&end);
362  cudaEventRecord(start, 0);
363 
364  {
365  switch (kernel) {
366 
367  case Kernel::copyHS:
368  for (int i = 0; i < niter; ++i) blas::copy(*yD, *xoD);
369  break;
370 
371  case Kernel::copyLS:
372  for (int i = 0; i < niter; ++i) blas::copy(*yoD, *xD);
373  break;
374 
375  case Kernel::axpbyz:
376  for (int i = 0; i < niter; ++i) blas::axpbyz(a, *xD, b, *yoD, *zoD);
377  break;
378 
379  case Kernel::ax:
380  for (int i = 0; i < niter; ++i) blas::ax(a, *xD);
381  break;
382 
383  case Kernel::caxpy:
384  for (int i = 0; i < niter; ++i) blas::caxpy(a2, *xD, *yoD);
385  break;
386 
387  case Kernel::caxpby:
388  for (int i = 0; i < niter; ++i) blas::caxpby(a2, *xD, b2, *yD);
389  break;
390 
391  case Kernel::cxpaypbz:
392  for (int i = 0; i < niter; ++i) blas::cxpaypbz(*xD, a2, *yD, b2, *zD);
393  break;
394 
395  case Kernel::axpyBzpcx:
396  for (int i = 0; i < niter; ++i) blas::axpyBzpcx(a, *xD, *yoD, b, *zD, c);
397  break;
398 
399  case Kernel::axpyZpbx:
400  for (int i = 0; i < niter; ++i) blas::axpyZpbx(a, *xD, *yoD, *zD, b);
401  break;
402 
404  for (int i = 0; i < niter; ++i) blas::caxpbypzYmbw(a2, *xD, b2, *yD, *zD, *wD);
405  break;
406 
407  case Kernel::cabxpyAx:
408  for (int i = 0; i < niter; ++i) blas::cabxpyAx(a, b2, *xD, *yD);
409  break;
410 
411  case Kernel::caxpyXmaz:
412  for (int i = 0; i < niter; ++i) blas::caxpyXmaz(a2, *xD, *yD, *zD);
413  break;
414 
415  case Kernel::norm2:
416  for (int i = 0; i < niter; ++i) blas::norm2(*xD);
417  break;
418 
420  for (int i = 0; i < niter; ++i) blas::reDotProduct(*xD, *yD);
421  break;
422 
423  case Kernel::axpbyzNorm:
424  for (int i = 0; i < niter; ++i) blas::axpbyzNorm(a, *xD, b, *yD, *zD);
425  break;
426 
427  case Kernel::axpyCGNorm:
428  for (int i = 0; i < niter; ++i) blas::axpyCGNorm(a, *xD, *yoD);
429  break;
430 
431  case Kernel::caxpyNorm:
432  for (int i = 0; i < niter; ++i) blas::caxpyNorm(a2, *xD, *yD);
433  break;
434 
436  for (int i = 0; i < niter; ++i) blas::caxpyXmazNormX(a2, *xD, *yD, *zD);
437  break;
438 
440  for (int i = 0; i < niter; ++i) blas::cabxpyzAxNorm(a, b2, *xD, *yD, *yD);
441  break;
442 
443  case Kernel::cDotProduct:
444  for (int i = 0; i < niter; ++i) blas::cDotProduct(*xD, *yD);
445  break;
446 
447  case Kernel::caxpyDotzy:
448  for (int i = 0; i < niter; ++i) blas::caxpyDotzy(a2, *xD, *yD, *zD);
449  break;
450 
452  for (int i = 0; i < niter; ++i) blas::cDotProductNormA(*xD, *yD);
453  break;
454 
456  for (int i = 0; i < niter; ++i) blas::caxpbypzYmbwcDotProductUYNormY(a2, *xD, b2, *yD, *zoD, *wD, *vD);
457  break;
458 
460  for (int i = 0; i < niter; ++i) blas::HeavyQuarkResidualNorm(*xD, *yD);
461  break;
462 
464  for (int i = 0; i < niter; ++i) blas::xpyHeavyQuarkResidualNorm(*xD, *yD, *zD);
465  break;
466 
468  for (int i = 0; i < niter; ++i) blas::tripleCGReduction(*xD, *yD, *zD);
469  break;
470 
472  for (int i = 0; i < niter; ++i) blas::tripleCGUpdate(a, b, *xD, *yD, *zD, *wD);
473  break;
474 
475  case Kernel::axpyReDot:
476  for (int i = 0; i < niter; ++i) blas::axpyReDot(a, *xD, *yD);
477  break;
478 
479  case Kernel::caxpyBxpz:
480  for (int i = 0; i < niter; ++i) blas::caxpyBxpz(a2, *xD, *yD, b2, *zD);
481  break;
482 
483  case Kernel::caxpyBzpx:
484  for (int i = 0; i < niter; ++i) blas::caxpyBzpx(a2, *xD, *yD, b2, *zD);
485  break;
486 
487  case Kernel::axpy_block:
488  for (int i = 0; i < niter; ++i) blas::axpy(Ar, xmD->Components(), ymoD->Components());
489  break;
490 
491  case Kernel::caxpy_block:
492  for (int i = 0; i < niter; ++i) blas::caxpy(A, *xmD, *ymoD);
493  break;
494 
496  for (int i = 0; i < niter; ++i)
497  blas::axpyBzpcx((double *)A, xmD->Components(), zmoD->Components(), (double *)B, *yD, (double *)C);
498  break;
499 
501  for (int i = 0; i < niter; ++i) blas::reDotProduct((double *)A2, xmD->Components(), xmD->Components());
502  break;
503 
505  for (int i = 0; i < niter; ++i) blas::reDotProduct((double *)A, xmD->Components(), ymoD->Components());
506  break;
507 
509  for (int i = 0; i < niter; ++i) blas::cDotProduct(A2, xmD->Components(), xmD->Components());
510  break;
511 
513  for (int i = 0; i < niter; ++i) blas::cDotProduct(A, xmD->Components(), ymoD->Components());
514  break;
515 
516  case Kernel::caxpyXmazMR:
517  commAsyncReductionSet(true);
518  for (int i = 0; i < niter; ++i) blas::caxpyXmazMR(a, *xD, *yD, *zD);
519  commAsyncReductionSet(false);
520  break;
521 
522  default: errorQuda("Undefined blas kernel %s\n", kernel_map.at(kernel).c_str());
523  }
524  }
525 
526  cudaEventRecord(end, 0);
527  cudaEventSynchronize(end);
528  float runTime;
529  cudaEventElapsedTime(&runTime, start, end);
530  cudaEventDestroy(start);
531  cudaEventDestroy(end);
532  delete[] A;
533  delete[] B;
534  delete[] C;
535  delete[] A2;
536  delete[] Ar;
537  double secs = runTime / 1000;
538  return secs;
539 }
540 
// Relative deviation between device field a##D and host field a##H, measured
// as |norm2(a##D) - norm2(a##H)| / norm2(a##H). The whole expansion is wrapped
// in parentheses so the macro composes safely inside larger expressions
// (e.g. `x / ERROR(y)` or `-ERROR(y)`), not just as a sum of terms.
#define ERROR(a) (fabs(blas::norm2(*a##D) - blas::norm2(*a##H)) / blas::norm2(*a##H))
542 
543 double test(Kernel kernel)
544 {
545  double a = M_PI, b = M_PI * exp(1.0), c = sqrt(M_PI);
546  quda::Complex a2(a, b), b2(b, -c), c2(a + b, c * a);
547  double error = 0;
548  quda::Complex *A = new quda::Complex[Nsrc * Msrc];
549  quda::Complex *B = new quda::Complex[Nsrc * Msrc];
550  quda::Complex *C = new quda::Complex[Nsrc * Msrc];
551  quda::Complex *A2 = new quda::Complex[Nsrc * Nsrc]; // for the block cDotProductNorm test
552  quda::Complex *B2 = new quda::Complex[Nsrc * Nsrc]; // for the block cDotProductNorm test
553  double *Ar = new double[Nsrc * Msrc];
554 
555  for (int i = 0; i < Nsrc * Msrc; i++) {
556  A[i] = a2 * (1.0 * ((i / (double)Nsrc) + i)) + b2 * (1.0 * i) + c2 * (1.0 * (0.5 * Nsrc * Msrc - i));
557  B[i] = a2 * (1.0 * ((i / (double)Nsrc) + i)) - b2 * (M_PI * i) + c2 * (1.0 * (0.5 * Nsrc * Msrc - i));
558  C[i] = a2 * (1.0 * ((M_PI / (double)Nsrc) + i)) + b2 * (1.0 * i) + c2 * (1.0 * (0.5 * Nsrc * Msrc - i));
559  Ar[i] = A[i].real();
560  }
561  for (int i = 0; i < Nsrc * Nsrc; i++) {
562  A2[i] = a2 * (1.0 * ((i / (double)Nsrc) + i)) + b2 * (1.0 * i) + c2 * (1.0 * (0.5 * Nsrc * Nsrc - i));
563  B2[i] = a2 * (1.0 * ((i / (double)Nsrc) + i)) - b2 * (M_PI * i) + c2 * (1.0 * (0.5 * Nsrc * Nsrc - i));
564  }
565  // A[0] = a2;
566  // A[1] = 0.;
567  // A[2] = 0.;
568  // A[3] = 0.;
569 
570  switch (kernel) {
571 
572  case Kernel::copyHS:
573  *xoD = *xH;
574  blas::copy(*yD, *xoD);
575  blas::copy(*yH, *xH);
576  error = ERROR(y);
577  break;
578 
579  case Kernel::copyLS:
580  *xD = *xH;
581  blas::copy(*yoD, *xD);
582  blas::copy(*yH, *xH);
583  error = ERROR(yo);
584  break;
585 
586  case Kernel::axpbyz:
587  *xD = *xH;
588  *yoD = *yH;
589  blas::axpbyz(a, *xD, b, *yoD, *zoD);
590  blas::axpbyz(a, *xH, b, *yH, *zH);
591  error = ERROR(zo);
592  break;
593 
594  case Kernel::ax:
595  *xD = *xH;
596  blas::ax(a, *xD);
597  blas::ax(a, *xH);
598  error = ERROR(x);
599  break;
600 
601  case Kernel::caxpy:
602  *xD = *xH;
603  *yoD = *yH;
604  blas::caxpy(a2, *xD, *yoD);
605  blas::caxpy(a2, *xH, *yH);
606  error = ERROR(yo);
607  break;
608 
609  case Kernel::caxpby:
610  *xD = *xH;
611  *yD = *yH;
612  blas::caxpby(a2, *xD, b2, *yD);
613  blas::caxpby(a2, *xH, b2, *yH);
614  error = ERROR(y);
615  break;
616 
617  case Kernel::cxpaypbz:
618  *xD = *xH;
619  *yD = *yH;
620  *zD = *zH;
621  blas::cxpaypbz(*xD, a2, *yD, b2, *zD);
622  blas::cxpaypbz(*xH, a2, *yH, b2, *zH);
623  error = ERROR(z);
624  break;
625 
626  case Kernel::axpyBzpcx:
627  *xD = *xH;
628  *yoD = *yH;
629  *zD = *zH;
630  blas::axpyBzpcx(a, *xD, *yoD, b, *zD, c);
631  blas::axpyBzpcx(a, *xH, *yH, b, *zH, c);
632  error = ERROR(x) + ERROR(yo);
633  break;
634 
635  case Kernel::axpyZpbx:
636  *xD = *xH;
637  *yoD = *yH;
638  *zD = *zH;
639  blas::axpyZpbx(a, *xD, *yoD, *zD, b);
640  blas::axpyZpbx(a, *xH, *yH, *zH, b);
641  error = ERROR(x) + ERROR(yo);
642  break;
643 
645  *xD = *xH;
646  *yD = *yH;
647  *zD = *zH;
648  *wD = *wH;
649  blas::caxpbypzYmbw(a2, *xD, b2, *yD, *zD, *wD);
650  blas::caxpbypzYmbw(a2, *xH, b2, *yH, *zH, *wH);
651  error = ERROR(z) + ERROR(y);
652  break;
653 
654  case Kernel::cabxpyAx:
655  *xD = *xH;
656  *yD = *yH;
657  blas::cabxpyAx(a, b2, *xD, *yD);
658  blas::cabxpyAx(a, b2, *xH, *yH);
659  error = ERROR(y) + ERROR(x);
660  break;
661 
662  case Kernel::caxpyXmaz:
663  *xD = *xH;
664  *yD = *yH;
665  *zD = *zH;
666  {
667  blas::caxpyXmaz(a, *xD, *yD, *zD);
668  blas::caxpyXmaz(a, *xH, *yH, *zH);
669  error = ERROR(y) + ERROR(x);
670  }
671  break;
672 
673  case Kernel::norm2:
674  *xD = *xH;
675  error = fabs(blas::norm2(*xD) - blas::norm2(*xH)) / blas::norm2(*xH);
676  break;
677 
679  *xD = *xH;
680  *yD = *yH;
681  error = fabs(blas::reDotProduct(*xD, *yD) - blas::reDotProduct(*xH, *yH)) / fabs(blas::reDotProduct(*xH, *yH));
682  break;
683 
684  case Kernel::axpbyzNorm:
685  *xD = *xH;
686  *yD = *yH;
687  {
688  double d = blas::axpbyzNorm(a, *xD, b, *yD, *zD);
689  double h = blas::axpbyzNorm(a, *xH, b, *yH, *zH);
690  error = ERROR(z) + fabs(d - h) / fabs(h);
691  }
692  break;
693 
694  case Kernel::axpyCGNorm:
695  *xD = *xH;
696  *yoD = *yH;
697  {
699  quda::Complex h = blas::axpyCGNorm(a, *xH, *yH);
700  error = ERROR(yo) + fabs(d.real() - h.real()) / fabs(h.real()) + fabs(d.imag() - h.imag()) / fabs(h.imag());
701  }
702  break;
703 
704  case Kernel::caxpyNorm:
705  *xD = *xH;
706  *yD = *yH;
707  {
708  double d = blas::caxpyNorm(a, *xD, *yD);
709  double h = blas::caxpyNorm(a, *xH, *yH);
710  error = ERROR(y) + fabs(d - h) / fabs(h);
711  }
712  break;
713 
715  *xD = *xH;
716  *yD = *yH;
717  *zD = *zH;
718  {
719  double d = blas::caxpyXmazNormX(a, *xD, *yD, *zD);
720  double h = blas::caxpyXmazNormX(a, *xH, *yH, *zH);
721  error = ERROR(y) + ERROR(x) + fabs(d - h) / fabs(h);
722  }
723  break;
724 
726  *xD = *xH;
727  *yD = *yH;
728  {
729  double d = blas::cabxpyzAxNorm(a, b2, *xD, *yD, *yD);
730  double h = blas::cabxpyzAxNorm(a, b2, *xH, *yH, *yH);
731  error = ERROR(x) + ERROR(y) + fabs(d - h) / fabs(h);
732  }
733  break;
734 
735  case Kernel::cDotProduct:
736  *xD = *xH;
737  *yD = *yH;
739  break;
740 
741  case Kernel::caxpyDotzy:
742  *xD = *xH;
743  *yD = *yH;
744  *zD = *zH;
745  {
746  quda::Complex d = blas::caxpyDotzy(a, *xD, *yD, *zD);
747  quda::Complex h = blas::caxpyDotzy(a, *xH, *yH, *zH);
748  error = ERROR(y) + abs(d - h) / abs(h);
749  }
750  break;
751 
753  *xD = *xH;
754  *yD = *yH;
755  {
756  double3 d = blas::cDotProductNormA(*xD, *yD);
757  double3 h = blas::cDotProductNormA(*xH, *yH);
758  error = abs(Complex(d.x - h.x, d.y - h.y)) / abs(Complex(h.x, h.y)) + fabs(d.z - h.z) / fabs(h.z);
759  }
760  break;
761 
763  *xD = *xH;
764  *yD = *yH;
765  *zD = *zH;
766  *wD = *wH;
767  *vD = *vH;
768  {
769  double3 d = blas::caxpbypzYmbwcDotProductUYNormY(a2, *xD, b2, *yD, *zD, *wD, *vD);
770  double3 h = blas::caxpbypzYmbwcDotProductUYNormY(a2, *xH, b2, *yH, *zH, *wH, *vH);
771  error = ERROR(z) + ERROR(y) + abs(Complex(d.x - h.x, d.y - h.y)) / abs(Complex(h.x, h.y))
772  + fabs(d.z - h.z) / fabs(h.z);
773  }
774  break;
775 
777  *xD = *xH;
778  *yD = *yH;
779  {
780  double3 d = blas::HeavyQuarkResidualNorm(*xD, *yD);
781  double3 h = blas::HeavyQuarkResidualNorm(*xH, *yH);
782  error = fabs(d.x - h.x) / fabs(h.x) + fabs(d.y - h.y) / fabs(h.y) + fabs(d.z - h.z) / fabs(h.z);
783  }
784  break;
785 
787  *xD = *xH;
788  *yD = *yH;
789  *zD = *zH;
790  {
791  double3 d = blas::xpyHeavyQuarkResidualNorm(*xD, *yD, *zD);
792  double3 h = blas::xpyHeavyQuarkResidualNorm(*xH, *yH, *zH);
793  error = ERROR(y) + fabs(d.x - h.x) / fabs(h.x) + fabs(d.y - h.y) / fabs(h.y) + fabs(d.z - h.z) / fabs(h.z);
794  }
795  break;
796 
798  *xD = *xH;
799  *yD = *yH;
800  *zD = *zH;
801  {
802  double3 d = blas::tripleCGReduction(*xD, *yD, *zD);
803  double3 h = make_double3(blas::norm2(*xH), blas::norm2(*yH), blas::reDotProduct(*yH, *zH));
804  error = fabs(d.x - h.x) / fabs(h.x) + fabs(d.y - h.y) / fabs(h.y) + fabs(d.z - h.z) / fabs(h.z);
805  }
806  break;
807 
809  *xD = *xH;
810  *yD = *yH;
811  *zD = *zH;
812  *wD = *wH;
813  {
814  blas::tripleCGUpdate(a, b, *xD, *yD, *zD, *wD);
815  blas::tripleCGUpdate(a, b, *xH, *yH, *zH, *wH);
816  error = ERROR(y) + ERROR(z) + ERROR(w);
817  }
818  break;
819 
820  case Kernel::axpyReDot:
821  *xD = *xH;
822  *yD = *yH;
823  {
824  double d = blas::axpyReDot(a, *xD, *yD);
825  double h = blas::axpyReDot(a, *xH, *yH);
826  error = ERROR(y) + fabs(d - h) / fabs(h);
827  }
828  break;
829 
830  case Kernel::caxpyBxpz:
831  *xD = *xH;
832  *yD = *yH;
833  *zD = *zH;
834  {
835  blas::caxpyBxpz(a, *xD, *yD, b2, *zD);
836  blas::caxpyBxpz(a, *xH, *yH, b2, *zH);
837  error = ERROR(x) + ERROR(z);
838  }
839  break;
840 
841  case Kernel::caxpyBzpx:
842  *xD = *xH;
843  *yD = *yH;
844  *zD = *zH;
845  {
846  blas::caxpyBzpx(a, *xD, *yD, b2, *zD);
847  blas::caxpyBzpx(a, *xH, *yH, b2, *zH);
848  error = ERROR(x) + ERROR(z);
849  }
850  break;
851 
852  case Kernel::axpy_block:
853  for (int i = 0; i < Nsrc; i++) xmD->Component(i) = *(xmH[i]);
854  for (int i = 0; i < Msrc; i++) ymoD->Component(i) = *(ymH[i]);
855 
856  blas::axpy(Ar, *xmD, *ymoD);
857  for (int i = 0; i < Nsrc; i++) {
858  for (int j = 0; j < Msrc; j++) { blas::axpy(Ar[Msrc * i + j], *(xmH[i]), *(ymH[j])); }
859  }
860 
861  error = 0;
862  for (int i = 0; i < Msrc; i++) {
863  error += fabs(blas::norm2((ymoD->Component(i))) - blas::norm2(*(ymH[i]))) / blas::norm2(*(ymH[i]));
864  }
865  error /= Msrc;
866  break;
867 
868  case Kernel::caxpy_block:
869  for (int i = 0; i < Nsrc; i++) xmD->Component(i) = *(xmH[i]);
870  for (int i = 0; i < Msrc; i++) ymoD->Component(i) = *(ymH[i]);
871 
872  blas::caxpy(A, *xmD, *ymoD);
873  for (int i = 0; i < Nsrc; i++) {
874  for (int j = 0; j < Msrc; j++) { blas::caxpy(A[Msrc * i + j], *(xmH[i]), *(ymH[j])); }
875  }
876  error = 0;
877  for (int i = 0; i < Msrc; i++) {
878  error += fabs(blas::norm2((ymoD->Component(i))) - blas::norm2(*(ymH[i]))) / blas::norm2(*(ymH[i]));
879  }
880  error /= Msrc;
881  break;
882 
884  for (int i = 0; i < Nsrc; i++) {
885  xmD->Component(i) = *(xmH[i]);
886  zmoD->Component(i) = *(zmH[i]);
887  }
888  *yD = *yH;
889 
890  blas::axpyBzpcx((double *)A, xmD->Components(), zmoD->Components(), (double *)B, *yD, (const double *)C);
891 
892  for (int i = 0; i < Nsrc; i++) {
893  blas::axpyBzpcx(((double *)A)[i], *xmH[i], *zmH[i], ((double *)B)[i], *yH, ((double *)C)[i]);
894  }
895 
896  error = 0;
897  for (int i = 0; i < Nsrc; i++) {
898  error += fabs(blas::norm2((xmD->Component(i))) - blas::norm2(*(xmH[i]))) / blas::norm2(*(xmH[i]));
899  error += fabs(blas::norm2((zmoD->Component(i))) - blas::norm2(*(zmH[i]))) / blas::norm2(*(zmH[i]));
900  }
901  error /= Nsrc;
902  break;
903 
905  for (int i = 0; i < Nsrc; i++) xmD->Component(i) = *(xmH[i]);
906  blas::reDotProduct((double *)A2, xmD->Components(), xmD->Components());
907  error = 0.0;
908  for (int i = 0; i < Nsrc; i++) {
909  for (int j = 0; j < Nsrc; j++) {
910  ((double *)B2)[i * Nsrc + j] = blas::reDotProduct(xmD->Component(i), xmD->Component(j));
911  error += std::abs(((double *)A2)[i * Nsrc + j] - ((double *)B2)[i * Nsrc + j])
912  / std::abs(((double *)B2)[i * Nsrc + j]);
913  }
914  }
915  error /= Nsrc * Nsrc;
916  break;
917 
919  for (int i = 0; i < Nsrc; i++) xmD->Component(i) = *(xmH[i]);
920  for (int i = 0; i < Msrc; i++) ymoD->Component(i) = *(ymH[i]);
921  for (int i = 0; i < Msrc; i++) ymD->Component(i) = *(ymH[i]);
922  blas::reDotProduct((double *)A, xmD->Components(), ymoD->Components());
923  error = 0.0;
924  for (int i = 0; i < Nsrc; i++) {
925  for (int j = 0; j < Msrc; j++) {
926  ((double *)B)[i * Msrc + j] = blas::reDotProduct(xmD->Component(i), ymD->Component(j));
927  error
928  += std::abs(((double *)A)[i * Msrc + j] - ((double *)B)[i * Msrc + j]) / std::abs(((double *)B)[i * Msrc + j]);
929  }
930  }
931  error /= Nsrc * Msrc;
932  break;
933 
935  for (int i = 0; i < Nsrc; i++) xmD->Component(i) = *(xmH[i]);
937  error = 0.0;
938  for (int i = 0; i < Nsrc; i++) {
939  for (int j = 0; j < Nsrc; j++) {
940  B2[i * Nsrc + j] = blas::cDotProduct(xmD->Component(i), xmD->Component(j));
941  error += std::abs(A2[i * Nsrc + j] - B2[i * Nsrc + j]) / std::abs(B2[i * Nsrc + j]);
942  }
943  }
944  error /= Nsrc * Nsrc;
945  break;
946 
948  for (int i = 0; i < Nsrc; i++) xmD->Component(i) = *(xmH[i]);
949  for (int i = 0; i < Msrc; i++) ymoD->Component(i) = *(ymH[i]);
950  for (int i = 0; i < Msrc; i++) ymD->Component(i) = *(ymH[i]);
952  error = 0.0;
953  for (int i = 0; i < Nsrc; i++) {
954  for (int j = 0; j < Msrc; j++) {
955  B[i * Msrc + j] = blas::cDotProduct(xmD->Component(i), ymD->Component(j));
956  error += std::abs(A[i * Msrc + j] - B[i * Msrc + j]) / std::abs(B[i * Msrc + j]);
957  }
958  }
959  error /= Nsrc * Msrc;
960  break;
961 
962  case Kernel::caxpyXmazMR:
963  *xD = *xH;
964  *yD = *yH;
965  *zD = *zH;
966 
967  commGlobalReductionSet(false); // switch off global reductions for this test
968 
969  commAsyncReductionSet(true);
971  blas::caxpyXmazMR(a, *xD, *yD, *zD);
972  commAsyncReductionSet(false);
973 
974  *vD = *xH;
975  *wD = *yH;
976  *zD = *zH;
977  {
978  double3 Ar3 = blas::cDotProductNormA(*zD, *vD);
979  auto alpha = Complex(Ar3.x, Ar3.y) / Ar3.z;
980  blas::caxpyXmaz(a * alpha, *vD, *wD, *zD);
981  }
982  *xH = *vD;
983  *yH = *wD;
984 
985  commGlobalReductionSet(true); // restore global reductions
986 
987  error = ERROR(x) + ERROR(y);
988  break;
989 
990  default: errorQuda("Undefined blas kernel %s\n", kernel_map.at(kernel).c_str());
991  }
992  delete[] A;
993  delete[] B;
994  delete[] C;
995  delete[] A2;
996  delete[] B2;
997  delete[] Ar;
998  return error;
999 }
1000 
1001 int main(int argc, char **argv)
1002 {
1003  ::testing::InitGoogleTest(&argc, argv);
1004  int result = 0;
1005 
1008  test_type = -1;
1009 
1010  // command line options
1011  auto app = make_app();
1012  // add_eigen_option_group(app);
1013  // add_deflation_option_group(app);
1014  // add_multigrid_option_group(app);
1015 
1016  app->add_option("--test", test_type, "Kernel to test (-1: -> all kernels)")->check(CLI::Range(0, Nkernels - 1));
1017  try {
1018  app->parse(argc, argv);
1019  } catch (const CLI::ParseError &e) {
1020  return app->exit(e);
1021  }
1022 
1023  // override spin setting if mg solver is set to test coarse grids
1024  if (inv_type == QUDA_MG_INVERTER) {
1025  Nspin = 2;
1026  Ncolor = nvec[0];
1027  if (Ncolor == 0) Ncolor = 24;
1028  } else {
1029  // set spin according to the type of dslash
1031  Ncolor = 3;
1032  }
1033 
1034  initComms(argc, argv, gridsize_from_cmdline);
1037 
1039 
1040  // lastly check for correctness
1042  if (comm_rank() != 0) { delete listeners.Release(listeners.default_result_printer()); }
1043  result = RUN_ALL_TESTS();
1044 
1045  endQuda();
1046 
1047  finalizeComms();
1048  return result;
1049 }
1050 
1051 // The following tests each kernel at each precision using the google testing framework
1052 
1056 using ::testing::TestWithParam;
1058 
1059 // map the 1-d precision test index into 2-d mixed prec
1061 {
1062  switch (idx) {
1063  case 0: return std::make_pair(QUDA_QUARTER_PRECISION, QUDA_QUARTER_PRECISION);
1064  case 1: return std::make_pair(QUDA_QUARTER_PRECISION, QUDA_HALF_PRECISION);
1065  case 2: return std::make_pair(QUDA_QUARTER_PRECISION, QUDA_SINGLE_PRECISION);
1066  case 3: return std::make_pair(QUDA_QUARTER_PRECISION, QUDA_DOUBLE_PRECISION);
1067  case 4: return std::make_pair(QUDA_HALF_PRECISION, QUDA_HALF_PRECISION);
1068  case 5: return std::make_pair(QUDA_HALF_PRECISION, QUDA_SINGLE_PRECISION);
1069  case 6: return std::make_pair(QUDA_HALF_PRECISION, QUDA_DOUBLE_PRECISION);
1070  case 7: return std::make_pair(QUDA_SINGLE_PRECISION, QUDA_SINGLE_PRECISION);
1071  case 8: return std::make_pair(QUDA_SINGLE_PRECISION, QUDA_DOUBLE_PRECISION);
1072  case 9: return std::make_pair(QUDA_DOUBLE_PRECISION, QUDA_DOUBLE_PRECISION);
1073  default: errorQuda("Unexpect precision index %d", idx);
1074  }
1075  return std::make_pair(QUDA_INVALID_PRECISION, QUDA_INVALID_PRECISION);
1076 }
1077 
1078 class BlasTest : public ::testing::TestWithParam<::testing::tuple<int, int>>
1079 {
1080 protected:
1081  ::testing::tuple<int, int> param;
1083  const int &kernel;
1084 
1085 public:
1086  BlasTest() : param(GetParam()), prec_pair(prec_idx_map(::testing::get<0>(param))), kernel(::testing::get<1>(param)) {}
1087  virtual void SetUp()
1088  {
1089  if (!skip_kernel(prec_pair, (Kernel)kernel)) initFields(prec_pair);
1090  }
1091  virtual void TearDown()
1092  {
1093  if (!skip_kernel(prec_pair, (Kernel)kernel)) { freeFields(); }
1094  }
1095 };
1096 
1098 {
1099  prec_pair_t prec_pair = ::prec_idx_map(testing::get<0>(GetParam()));
1100  Kernel kernel = (Kernel)::testing::get<1>(GetParam());
1101  if (skip_kernel(prec_pair, kernel)) GTEST_SKIP();
1102 
1103  // certain tests will fail to run for coarse grids so mark these as
1104  // failed without running
1105  double deviation = test(kernel);
1106  // printfQuda("%-35s error = %e\n", names[kernel], deviation);
1107  double tol_x
1108  = (prec_pair.first == QUDA_DOUBLE_PRECISION ?
1109  1e-12 :
1110  (prec_pair.first == QUDA_SINGLE_PRECISION ? 1e-6 : (prec_pair.first == QUDA_HALF_PRECISION ? 1e-4 : 1e-2)));
1111  double tol_y
1112  = (prec_pair.second == QUDA_DOUBLE_PRECISION ?
1113  1e-12 :
1114  (prec_pair.second == QUDA_SINGLE_PRECISION ? 1e-6 : (prec_pair.second == QUDA_HALF_PRECISION ? 1e-4 : 1e-2)));
1115  double tol = std::max(tol_x, tol_y);
1116  tol = is_copy(kernel) ? 5e-2 : tol; // use different tolerance for copy
1117  EXPECT_LE(deviation, tol) << "CPU and CUDA implementations do not agree";
1118  EXPECT_EQ(false, std::isnan(deviation)) << "Nan has propagated into the result";
1119 }
1120 
1122 {
1123  prec_pair_t prec_pair = prec_idx_map(::testing::get<0>(GetParam()));
1124  Kernel kernel = (Kernel)::testing::get<1>(GetParam());
1125 
1126  if (skip_kernel(prec_pair, kernel)) GTEST_SKIP();
1127 
1128  // do the initial tune
1129  benchmark(kernel, 1);
1130 
1131  // now rerun with more iterations to get accurate speed measurements
1132  quda::blas::flops = 0;
1133  quda::blas::bytes = 0;
1134 
1135  double secs = benchmark(kernel, niter);
1136 
1137  double gflops = (quda::blas::flops * 1e-9) / (secs);
1138  double gbytes = quda::blas::bytes / (secs * 1e9);
1139  RecordProperty("Gflops", std::to_string(gflops));
1140  RecordProperty("GBs", std::to_string(gbytes));
1141  printfQuda("%-31s: Gflop/s = %6.1f, GB/s = %6.1f\n", kernel_map.at(kernel).c_str(), gflops, gbytes);
1142 }
1143 
1145 {
1146  prec_pair_t prec_pair = prec_idx_map(::testing::get<0>(param.param));
1147  int kernel = ::testing::get<1>(param.param);
1148  std::string str(kernel_map.at((Kernel)kernel));
1149  str += std::string("_") + prec_map.at(prec_pair.first) + std::string("_") + prec_map.at(prec_pair.second);
1150  return str;
1151 }
1152 
1153 // instantiate all test cases
ColorSpinorField * xmoD
Definition: blas_test.cpp:44
ColorSpinorField * xH
Definition: blas_test.cpp:32
int Nspin
Definition: blas_test.cpp:50
ColorSpinorField * zD
Definition: blas_test.cpp:35
const std::map< Kernel, std::string > kernel_map
Definition: blas_test.cpp:118
void freeFields()
Definition: blas_test.cpp:315
int Ncolor
Definition: blas_test.cpp:51
int main(int argc, char **argv)
Definition: blas_test.cpp:1001
bool skip_kernel(prec_pair_t pair, Kernel kernel)
Definition: blas_test.cpp:173
prec_pair_t prec_idx_map(int idx)
Definition: blas_test.cpp:1060
ColorSpinorField * xoH
Definition: blas_test.cpp:32
ColorSpinorField * yoH
Definition: blas_test.cpp:32
ColorSpinorField * yoD
Definition: blas_test.cpp:41
ColorSpinorField * ymoD
Definition: blas_test.cpp:44
void setPrec(ColorSpinorParam &param, QudaPrecision precision)
Definition: blas_test.cpp:53
ColorSpinorField * wH
Definition: blas_test.cpp:32
#define ERROR(a)
Definition: blas_test.cpp:541
INSTANTIATE_TEST_SUITE_P(QUDA, BlasTest, Combine(Range(0,(Nprec *(Nprec+1))/2), Range(0, Nkernels)), getblasname)
ColorSpinorField * ymD
Definition: blas_test.cpp:38
ColorSpinorField * vH
Definition: blas_test.cpp:32
ColorSpinorField * yH
Definition: blas_test.cpp:32
TEST_P(BlasTest, verify)
Definition: blas_test.cpp:1097
const int Nkernels
Definition: blas_test.cpp:157
ColorSpinorField * zH
Definition: blas_test.cpp:32
std::vector< cpuColorSpinorField * > ymH
Definition: blas_test.cpp:48
bool is_multi(Kernel kernel)
Definition: blas_test.cpp:160
std::string getblasname(testing::TestParamInfo<::testing::tuple< int, int >> param)
Definition: blas_test.cpp:1144
ColorSpinorField * zmD
Definition: blas_test.cpp:38
ColorSpinorField * zoD
Definition: blas_test.cpp:41
Kernel
Definition: blas_test.cpp:74
@ axpbyzNorm
@ axpyCGNorm
@ cabxpyzAxNorm
@ tripleCGReduction
@ reDotProduct_block
@ xpyHeavyQuarkResidualNorm
@ caxpyXmazNormX
@ axpyBzpcx_block
@ cDotProduct_block
@ cDotProduct
@ caxpyXmazMR
@ cDotProductNorm_block
@ caxpbypzYmbwcDotProductUYNormY
@ reDotProduct
@ caxpbypzYmbw
@ HeavyQuarkResidualNorm
@ axpy_block
@ tripleCGUpdate
@ caxpyDotzy
@ cDotProductNormA
@ reDotProductNorm_block
@ caxpy_block
double benchmark(Kernel kernel, const int niter)
Definition: blas_test.cpp:349
ColorSpinorField * woD
Definition: blas_test.cpp:41
bool is_site_unroll(Kernel kernel)
Definition: blas_test.cpp:168
const int Nprec
Definition: blas_test.cpp:72
std::vector< cpuColorSpinorField * > zmH
Definition: blas_test.cpp:49
ColorSpinorField * vD
Definition: blas_test.cpp:35
double test(Kernel kernel)
Definition: blas_test.cpp:543
ColorSpinorField * xD
Definition: blas_test.cpp:35
ColorSpinorField * xmD
Definition: blas_test.cpp:38
void initFields(prec_pair_t prec_pair)
Definition: blas_test.cpp:201
ColorSpinorField * zmoD
Definition: blas_test.cpp:44
ColorSpinorField * yD
Definition: blas_test.cpp:35
ColorSpinorField * wD
Definition: blas_test.cpp:35
void display_test_info()
Definition: blas_test.cpp:55
ColorSpinorField * xoD
Definition: blas_test.cpp:41
std::vector< cpuColorSpinorField * > xmH
Definition: blas_test.cpp:47
bool is_copy(Kernel kernel)
Definition: blas_test.cpp:165
const std::map< QudaPrecision, std::string > prec_map
Definition: blas_test.cpp:67
ColorSpinorField * zoH
Definition: blas_test.cpp:32
std::pair< QudaPrecision, QudaPrecision > prec_pair_t
Definition: blas_test.cpp:65
ColorSpinorField * voD
Definition: blas_test.cpp:41
const int & kernel
Definition: blas_test.cpp:1083
virtual void TearDown()
Definition: blas_test.cpp:1091
::testing::tuple< int, int > param
Definition: blas_test.cpp:1081
virtual void SetUp()
Definition: blas_test.cpp:1087
const prec_pair_t prec_pair
Definition: blas_test.cpp:1082
CompositeColorSpinorField & Components()
ColorSpinorField & Component(const int idx) const
TestEventListener * Release(TestEventListener *listener)
TestEventListener * default_result_printer() const
Definition: gtest.h:1186
TestEventListeners & listeners()
static UnitTest * GetInstance()
void commAsyncReductionSet(bool global_reduce)
int comm_rank(void)
void commGlobalReductionSet(bool global_reduce)
QudaSolveType solve_type
QudaInverterType inv_type
std::shared_ptr< QUDAApp > make_app(std::string app_description, std::string app_name)
double tol
int niter
int test_type
int device_ordinal
QudaVerbosity verbosity
int & ydim
int Msrc
quda::mgarray< int > nvec
int & zdim
QudaDslashType dslash_type
int Nsrc
QudaPrecision prec
int & tdim
int & xdim
std::array< int, 4 > gridsize_from_cmdline
QudaPrecision prec_sloppy
void end(void)
enum QudaPrecision_s QudaPrecision
@ QUDA_RANDOM_SOURCE
Definition: enum_quda.h:376
@ QUDA_STAGGERED_DSLASH
Definition: enum_quda.h:97
@ QUDA_ASQTAD_DSLASH
Definition: enum_quda.h:98
@ QUDA_FULL_SITE_SUBSET
Definition: enum_quda.h:333
@ QUDA_PARITY_SITE_SUBSET
Definition: enum_quda.h:332
@ QUDA_DEGRAND_ROSSI_GAMMA_BASIS
Definition: enum_quda.h:368
@ QUDA_UKQCD_GAMMA_BASIS
Definition: enum_quda.h:369
@ QUDA_MG_INVERTER
Definition: enum_quda.h:122
@ QUDA_EVEN_ODD_SITE_ORDER
Definition: enum_quda.h:340
@ QUDA_DOUBLE_PRECISION
Definition: enum_quda.h:65
@ QUDA_SINGLE_PRECISION
Definition: enum_quda.h:64
@ QUDA_INVALID_PRECISION
Definition: enum_quda.h:66
@ QUDA_QUARTER_PRECISION
Definition: enum_quda.h:62
@ QUDA_HALF_PRECISION
Definition: enum_quda.h:63
@ QUDA_SPACE_SPIN_COLOR_FIELD_ORDER
Definition: enum_quda.h:351
@ QUDA_ZERO_FIELD_CREATE
Definition: enum_quda.h:361
@ QUDA_DIRECT_SOLVE
Definition: enum_quda.h:167
@ QUDA_NORMOP_PC_SOLVE
Definition: enum_quda.h:170
@ QUDA_DIRECT_PC_SOLVE
Definition: enum_quda.h:169
@ QUDA_NORMOP_SOLVE
Definition: enum_quda.h:168
#define GTEST_SKIP()
Definition: gtest.h:1887
#define EXPECT_EQ(val1, val2)
Definition: gtest.h:2017
int RUN_ALL_TESTS() GTEST_MUST_USE_RESULT_
Definition: gtest.h:2468
#define EXPECT_LE(val1, val2)
Definition: gtest.h:2021
int dimPartitioned(int dim)
Definition: host_utils.cpp:376
void initComms(int argc, char **argv, std::array< int, 4 > &commDims)
Definition: host_utils.cpp:255
void finalizeComms()
Definition: host_utils.cpp:292
__host__ __device__ __forceinline__ T & get(array< T, m > &src)
Definition: array.h:87
double axpbyzNorm(double a, ColorSpinorField &x, double b, ColorSpinorField &y, ColorSpinorField &z)
double axpyReDot(double a, ColorSpinorField &x, ColorSpinorField &y)
Complex axpyCGNorm(double a, ColorSpinorField &x, ColorSpinorField &y)
Complex caxpyDotzy(const Complex &a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
void caxpbypzYmbw(const Complex &, ColorSpinorField &, const Complex &, ColorSpinorField &, ColorSpinorField &, ColorSpinorField &)
void axpyZpbx(double a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, double b)
void cabxpyAx(double a, const Complex &b, ColorSpinorField &x, ColorSpinorField &y)
void caxpby(const Complex &a, ColorSpinorField &x, const Complex &b, ColorSpinorField &y)
double3 HeavyQuarkResidualNorm(ColorSpinorField &x, ColorSpinorField &r)
double3 tripleCGReduction(ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
void caxpyBzpx(const Complex &, ColorSpinorField &, ColorSpinorField &, const Complex &, ColorSpinorField &)
unsigned long long flops
double caxpyXmazNormX(const Complex &a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
void axpbyz(double a, ColorSpinorField &x, double b, ColorSpinorField &y, ColorSpinorField &z)
double3 caxpbypzYmbwcDotProductUYNormY(const Complex &a, ColorSpinorField &x, const Complex &b, ColorSpinorField &y, ColorSpinorField &z, ColorSpinorField &w, ColorSpinorField &u)
unsigned long long bytes
void axpyBzpcx(double a, ColorSpinorField &x, ColorSpinorField &y, double b, ColorSpinorField &z, double c)
void caxpyBxpz(const Complex &, ColorSpinorField &, ColorSpinorField &, const Complex &, ColorSpinorField &)
double cabxpyzAxNorm(double a, const Complex &b, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
double3 xpyHeavyQuarkResidualNorm(ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &r)
void caxpyXmaz(const Complex &a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
void ax(double a, ColorSpinorField &x)
void caxpyXmazMR(const double &a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
double caxpyNorm(const Complex &a, ColorSpinorField &x, ColorSpinorField &y)
double norm2(const ColorSpinorField &a)
void tripleCGUpdate(double alpha, double beta, ColorSpinorField &q, ColorSpinorField &r, ColorSpinorField &x, ColorSpinorField &p)
double reDotProduct(ColorSpinorField &x, ColorSpinorField &y)
void axpy(double a, ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.h:43
double3 cDotProductNormA(ColorSpinorField &a, ColorSpinorField &b)
void caxpy(const Complex &a, ColorSpinorField &x, ColorSpinorField &y)
void cxpaypbz(ColorSpinorField &, const Complex &b, ColorSpinorField &y, const Complex &c, ColorSpinorField &z)
void copy(ColorSpinorField &dst, const ColorSpinorField &src)
Definition: blas_quda.h:24
Complex cDotProduct(ColorSpinorField &, ColorSpinorField &)
void start()
Start profiling.
Definition: device.cpp:226
double norm2(const CloverField &a, bool inverse=false)
void ax(const double &a, GaugeField &u)
Scale the gauge field by the scalar a.
std::complex< double > Complex
Definition: quda_internal.h:86
__host__ __device__ ValueType sqrt(ValueType x)
Definition: complex_quda.h:120
__host__ __device__ ValueType exp(ValueType x)
Definition: complex_quda.h:96
__host__ __device__ ValueType abs(ValueType x)
Definition: complex_quda.h:125
::std::string string
Definition: gtest-port.h:891
internal::ParamGenerator< T > Range(T start, T end, IncrementT step)
internal::CartesianProductHolder< Generator... > Combine(const Generator &... g)
internal::ValueArray< T... > Values(T... v)
internal::ParamGenerator< bool > Bool()
GTEST_API_ void InitGoogleTest(int *argc, char **argv)
QudaGaugeParam param
Definition: pack_test.cpp:18
void initQuda(int device)
void endQuda(void)
#define printfQuda(...)
Definition: util_quda.h:114
void setVerbosity(QudaVerbosity verbosity)
Definition: util_quda.cpp:25
#define errorQuda(...)
Definition: util_quda.h:120