QUDA  0.9.0
blas_test.cu
Go to the documentation of this file.
1 #include <stdio.h>
2 #include <stdlib.h>
3 
4 #include <quda_internal.h>
5 #include <color_spinor_field.h>
6 #include <blas_quda.h>
7 
8 #include <test_util.h>
9 
10 // include because of nasty globals used in the tests
11 #include <dslash_util.h>
12 
13 // google test
14 #include <gtest.h>
15 
// Test-configuration globals defined outside this file (the definition index at the end of
// this listing points e.g. nvec, Nsrc, Msrc, niter, xdim at test_util.cpp).
// test_type == -1 means "run every kernel" (see skip_kernel / main).
16 extern int test_type;
// QUDA_INVALID_PRECISION means "run every precision" (see skip_kernel).
17 extern QudaPrecision prec;
20 extern int nvec;
21 extern int device;
22 extern int xdim;
23 extern int ydim;
24 extern int zdim;
25 extern int tdim;
26 extern int gridsize_from_cmdline[];
27 extern int niter;
28 
29 extern bool verify_results;
// Number of components in the composite (block) fields: Nsrc for xm/zm, Msrc for ym.
30 extern int Nsrc;
31 extern int Msrc;
33 
34 extern void usage(char** );
35 
// Number of kernels exercised; kernel index k selects case k in benchmark()/test()
// and names[k] for the googletest instance name.
36 const int Nkernels = 42;
37 
38 using namespace quda;
39 
// "H" fields: host-side reference fields (cast to cpuColorSpinorField in initFields).
40 ColorSpinorField *xH, *yH, *zH, *wH, *vH, *hH, *lH;
// "D" fields: device-side counterparts compared against the host results.
// hD/lD presumably hold the auxiliary high/low precision sources for the copy kernels (0, 1)
// -- their allocation is not visible in this listing. xmD/ymD/zmD are composite fields
// (accessed via Component()/Components()).
41 ColorSpinorField *xD, *yD, *zD, *wD, *vD, *hD, *lD, *xmD, *ymD, *zmD;
// Host components of the composite fields: Nsrc entries for xmH/zmH, Msrc for ymH.
42 std::vector<cpuColorSpinorField*> xmH;
43 std::vector<cpuColorSpinorField*> ymH;
44 std::vector<cpuColorSpinorField*> zmH;
// Field layout chosen in main(): Nspin = 2 and Ncolor = nvec for MG coarse grids, else Ncolor = 3.
45 int Nspin;
46 int Ncolor;
47 
49 {
50  param.precision = precision;
51  if (Nspin == 1 || Nspin == 2 || precision == QUDA_DOUBLE_PRECISION) {
52  param.fieldOrder = QUDA_FLOAT2_FIELD_ORDER;
53  } else {
54  param.fieldOrder = QUDA_FLOAT4_FIELD_ORDER;
55  }
56 }
57 
58 void
60 {
61  printfQuda("running the following test:\n");
62  printfQuda("S_dimension T_dimension Nspin Ncolor\n");
63  printfQuda("%3d /%3d / %3d %3d %d %d\n", xdim, ydim, zdim, tdim, Nspin, Ncolor);
64  printfQuda("Grid partition info: X Y Z T\n");
65  printfQuda(" %d %d %d %d\n",
66  dimPartitioned(0),
67  dimPartitioned(1),
68  dimPartitioned(2),
69  dimPartitioned(3));
70  return;
71 }
72 
73 int Nprec = 3;
74 
75 bool skip_kernel(int precision, int kernel) {
76  // if we've selected a given kernel then make sure we only run that
77  if (test_type != -1 && kernel != test_type) return true;
78 
79  // if we've selected a given precision then make sure we only run that
80  QudaPrecision this_prec = precision == 2 ? QUDA_DOUBLE_PRECISION : precision == 1 ? QUDA_SINGLE_PRECISION : QUDA_HALF_PRECISION;
81  if (prec != QUDA_INVALID_PRECISION && this_prec != prec) return true;
82 
83  if ( Nspin == 2 && precision == 0) {
84  // avoid half precision tests if doing coarse fields
85  return true;
86  } else if (Nspin == 2 && kernel == 1) {
87  // avoid low-precision copy if doing coarse fields
88  return true;
89  } else if (Ncolor != 3 && (kernel == 31 || kernel == 32)) {
90  // only benchmark heavy-quark norm if doing 3 colors
91  return true;
92  } else if ((Nprec < 3) && (kernel == 0)) {
93  // only benchmark high-precision copy() if double is supported
94  return true;
95  }
96 
97  return false;
98 }
99 
100 void initFields(int prec)
101 {
102  // precisions used for the source field in the copyCuda() benchmark
103  QudaPrecision high_aux_prec = QUDA_INVALID_PRECISION;
104  QudaPrecision low_aux_prec = QUDA_INVALID_PRECISION;
105 
107  param.nColor = Ncolor;
108  param.nSpin = Nspin;
109  param.nDim = 4; // number of spacetime dimensions
110 
111  param.pad = 0; // padding must be zero for cpu fields
112 
114  param.siteSubset = QUDA_PARITY_SITE_SUBSET;
115  } else if (solve_type == QUDA_DIRECT_SOLVE) {
116  param.siteSubset = QUDA_FULL_SITE_SUBSET;
117  } else {
118  errorQuda("Unexpected solve_type=%d\n", solve_type);
119  }
120 
121  if (param.siteSubset == QUDA_PARITY_SITE_SUBSET) param.x[0] = xdim/2;
122  else param.x[0] = xdim;
123  param.x[1] = ydim;
124  param.x[2] = zdim;
125  param.x[3] = tdim;
126 
127  param.siteOrder = QUDA_EVEN_ODD_SITE_ORDER;
129  param.precision = QUDA_DOUBLE_PRECISION;
131 
132  param.create = QUDA_ZERO_FIELD_CREATE;
133 
141 
142 // create composite fields
143 
144  // xmH = new cpuColorSpinorField(param);
145  // ymH = new cpuColorSpinorField(param);
146 
147 
148 
149  xmH.reserve(Nsrc);
150  for (int cid = 0; cid < Nsrc; cid++) xmH.push_back(new cpuColorSpinorField(param));
151  ymH.reserve(Msrc);
152  for (int cid = 0; cid < Msrc; cid++) ymH.push_back(new cpuColorSpinorField(param));
153  zmH.reserve(Nsrc);
154  for (int cid = 0; cid < Nsrc; cid++) zmH.push_back(new cpuColorSpinorField(param));
155 
156 
157  static_cast<cpuColorSpinorField*>(vH)->Source(QUDA_RANDOM_SOURCE, 0, 0, 0);
158  static_cast<cpuColorSpinorField*>(wH)->Source(QUDA_RANDOM_SOURCE, 0, 0, 0);
159  static_cast<cpuColorSpinorField*>(xH)->Source(QUDA_RANDOM_SOURCE, 0, 0, 0);
160  static_cast<cpuColorSpinorField*>(yH)->Source(QUDA_RANDOM_SOURCE, 0, 0, 0);
161  static_cast<cpuColorSpinorField*>(zH)->Source(QUDA_RANDOM_SOURCE, 0, 0, 0);
162  static_cast<cpuColorSpinorField*>(hH)->Source(QUDA_RANDOM_SOURCE, 0, 0, 0);
163  static_cast<cpuColorSpinorField*>(lH)->Source(QUDA_RANDOM_SOURCE, 0, 0, 0);
164  for(int i=0; i<Nsrc; i++){
165  static_cast<cpuColorSpinorField*>(xmH[i])->Source(QUDA_RANDOM_SOURCE, 0, 0, 0);
166  }
167  for(int i=0; i<Msrc; i++){
168  static_cast<cpuColorSpinorField*>(ymH[i])->Source(QUDA_RANDOM_SOURCE, 0, 0, 0);
169  }
170  // Now set the parameters for the cuda fields
171  //param.pad = xdim*ydim*zdim/2;
172 
173  if (param.nSpin == 4) param.gammaBasis = QUDA_UKQCD_GAMMA_BASIS;
174  param.create = QUDA_ZERO_FIELD_CREATE;
175 
176  switch(prec) {
177  case 0:
179  high_aux_prec = QUDA_DOUBLE_PRECISION;
180  low_aux_prec = QUDA_SINGLE_PRECISION;
181  break;
182  case 1:
184  high_aux_prec = QUDA_DOUBLE_PRECISION;
185  low_aux_prec = QUDA_HALF_PRECISION;
186  break;
187  case 2:
189  high_aux_prec = QUDA_SINGLE_PRECISION;
190  low_aux_prec = QUDA_HALF_PRECISION;
191  break;
192  default:
193  errorQuda("Precision option not defined");
194  }
195 
196  checkCudaError();
197 
203 
204  param.is_composite = true;
205  param.is_component = false;
206 
207 // create composite fields
208  param.composite_dim = Nsrc;
210 
211  param.composite_dim = Msrc;
213 
214  param.composite_dim = Nsrc;
216 
217  param.is_composite = false;
218  param.is_component = false;
219  param.composite_dim = 1;
220 
221  setPrec(param, high_aux_prec);
223 
224  setPrec(param, low_aux_prec);
226 
227  // check for successful allocation
228  checkCudaError();
229 
230  // only do copy if not doing half precision with mg
231  bool flag = !(param.nSpin == 2 &&
232  (prec == 0 || low_aux_prec == QUDA_HALF_PRECISION) );
233 
234  if ( flag ) {
235  *vD = *vH;
236  *wD = *wH;
237  *xD = *xH;
238  *yD = *yH;
239  *zD = *zH;
240  *hD = *hH;
241  *lD = *lH;
242  // for (int i=0; i < Nsrc; i++){
243  // xmD->Component(i) = *(xmH[i]);
244  // ymD->Component(i) = *(ymH[i]);
245  // }
246  // *ymD = *ymH;
247  }
248 }
249 
250 
252 {
253 
254  // release memory
255  delete vD;
256  delete wD;
257  delete xD;
258  delete yD;
259  delete zD;
260  delete hD;
261  delete lD;
262  delete xmD;
263  delete ymD;
264  delete zmD;
265 
266  // release memory
267  delete vH;
268  delete wH;
269  delete xH;
270  delete yH;
271  delete zH;
272  delete hH;
273  delete lH;
274  for (int i=0; i < Nsrc; i++) delete xmH[i];
275  for (int i=0; i < Msrc; i++) delete ymH[i];
276  for (int i=0; i < Nsrc; i++) delete zmH[i];
277  xmH.clear();
278  ymH.clear();
279  zmH.clear();
280 }
281 
282 
// Times `niter` back-to-back calls of blas kernel `kernel` on the device fields
// (xD, yD, zD, wD, vD, hD, lD and the composite xmD/ymD/zmD), bracketed by two CUDA
// events recorded on stream 0, and returns the elapsed time in seconds.
//  kernel : index 0..Nkernels-1, matching the case labels below and names[].
//  niter  : number of repetitions (the benchmark test calls this once with 1 to tune,
//           then again with the global niter for the measurement).
// Note: this function does not consult skip_kernel() itself.
283 double benchmark(int kernel, const int niter) {
284 
// NOTE(review): a, b and c are never assigned before being passed to the kernels below,
// so they hold indeterminate values (reading them is undefined behavior). a2/b2/c2 are
// presumably zero-initialised if quda::Complex is std::complex<double> -- verify.
// Consider giving them fixed values as test() does.
285  double a, b, c;
286  quda::Complex a2, b2, c2;
// A, B and C (used by kernels 36, 37 and 41 and released with delete[] below) are declared
// on source lines that are not shown in this listing.
290  quda::Complex * A2 = new quda::Complex[Nsrc*Nsrc]; // for the block cDotProductNorm test
291 
// NOTE(review): CUDA API return codes in this function are not checked.
292  cudaEvent_t start, end;
293  cudaEventCreate(&start);
294  cudaEventCreate(&end);
295  cudaEventRecord(start, 0);
296 
297  {
298  switch (kernel) {
299 
300  case 0:
301  for (int i=0; i < niter; ++i) blas::copy(*yD, *hD);
302  break;
303 
304  case 1:
305  for (int i=0; i < niter; ++i) blas::copy(*yD, *lD);
306  break;
307 
308  case 2:
309  for (int i=0; i < niter; ++i) blas::axpby(a, *xD, b, *yD);
310  break;
311 
312  case 3:
313  for (int i=0; i < niter; ++i) blas::xpy(*xD, *yD);
314  break;
315 
316  case 4:
317  for (int i=0; i < niter; ++i) blas::axpy(a, *xD, *yD);
318  break;
319 
320  case 5:
321  for (int i=0; i < niter; ++i) blas::xpay(*xD, a, *yD);
322  break;
323 
324  case 6:
325  for (int i=0; i < niter; ++i) blas::mxpy(*xD, *yD);
326  break;
327 
328  case 7:
329  for (int i=0; i < niter; ++i) blas::ax(a, *xD);
330  break;
331 
332  case 8:
333  for (int i=0; i < niter; ++i) blas::caxpy(a2, *xD, *yD);
334  break;
335 
336  case 9:
337  for (int i=0; i < niter; ++i) blas::caxpby(a2, *xD, b2, *yD);
338  break;
339 
340  case 10:
341  for (int i=0; i < niter; ++i) blas::cxpaypbz(*xD, a2, *yD, b2, *zD);
342  break;
343 
344  case 11:
345  for (int i=0; i < niter; ++i) blas::axpyBzpcx(a, *xD, *yD, b, *zD, c);
346  break;
347 
348  case 12:
349  for (int i=0; i < niter; ++i) blas::axpyZpbx(a, *xD, *yD, *zD, b);
350  break;
351 
352  case 13:
353  for (int i=0; i < niter; ++i) blas::caxpbypzYmbw(a2, *xD, b2, *yD, *zD, *wD);
354  break;
355 
356  case 14:
357  for (int i=0; i < niter; ++i) blas::cabxpyAx(a, b2, *xD, *yD);
358  break;
359 
360  case 15:
361  for (int i=0; i < niter; ++i) blas::caxpbypz(a2, *xD, b2, *yD, *zD);
362  break;
363 
364  case 16:
365  for (int i=0; i < niter; ++i) blas::caxpbypczpw(a2, *xD, b2, *yD, c2, *zD, *wD);
366  break;
367 
368  case 17:
369  for (int i=0; i < niter; ++i) blas::caxpyXmaz(a2, *xD, *yD, *zD);
370  break;
371 
372  // double
373  case 18:
374  for (int i=0; i < niter; ++i) blas::norm2(*xD);
375  break;
376 
377  case 19:
378  for (int i=0; i < niter; ++i) blas::reDotProduct(*xD, *yD);
379  break;
380 
381  case 20:
382  for (int i=0; i < niter; ++i) blas::axpyNorm(a, *xD, *yD);
383  break;
384 
385  case 21:
386  for (int i=0; i < niter; ++i) blas::xmyNorm(*xD, *yD);
387  break;
388 
389  case 22:
390  for (int i=0; i < niter; ++i) blas::caxpyNorm(a2, *xD, *yD);
391  break;
392 
393  case 23:
394  for (int i=0; i < niter; ++i) blas::caxpyXmazNormX(a2, *xD, *yD, *zD);
395  break;
396 
397  case 24:
398  for (int i=0; i < niter; ++i) blas::cabxpyAxNorm(a, b2, *xD, *yD);
399  break;
400 
401  // double2
402  case 25:
403  for (int i=0; i < niter; ++i) blas::cDotProduct(*xD, *yD);
404  break;
405 
406  case 26:
407  for (int i=0; i < niter; ++i) blas::xpaycDotzy(*xD, a, *yD, *zD);
408  break;
409 
410  case 27:
411  for (int i=0; i < niter; ++i) blas::caxpyDotzy(a2, *xD, *yD, *zD);
412  break;
413 
414  // double3
415  case 28:
416  for (int i=0; i < niter; ++i) blas::cDotProductNormA(*xD, *yD);
417  break;
418 
419  case 29:
420  for (int i=0; i < niter; ++i) blas::cDotProductNormB(*xD, *yD);
421  break;
422 
423  case 30:
424  for (int i=0; i < niter; ++i) blas::caxpbypzYmbwcDotProductUYNormY(a2, *xD, b2, *yD, *zD, *wD, *vD);
425  break;
426 
427  case 31:
428  for (int i=0; i < niter; ++i) blas::HeavyQuarkResidualNorm(*xD, *yD);
429  break;
430 
431  case 32:
432  for (int i=0; i < niter; ++i) blas::xpyHeavyQuarkResidualNorm(*xD, *yD, *zD);
433  break;
434 
435  case 33:
436  for (int i=0; i < niter; ++i) blas::tripleCGReduction(*xD, *yD, *zD);
437  break;
438 
439  case 34:
440  for (int i=0; i < niter; ++i) blas::tripleCGUpdate(a, b, *xD, *yD, *zD, *wD);
441  break;
442 
443  case 35:
444  for (int i=0; i < niter; ++i) blas::axpyReDot(a, *xD, *yD);
445  break;
446 
// block (multi-component) kernels: coefficient arrays A/B/C/A2 over the composite fields
447  case 36:
448  for (int i=0; i < niter; ++i) blas::caxpy(A, *xmD,* ymD);
449  break;
450 
451  case 37:
452  for (int i=0; i < niter; ++i) blas::axpyBzpcx((double*)A, xmD->Components(), zmD->Components(), (double*)B, *yD, (double*)C);
453  break;
454 
455  case 38:
456  for (int i=0; i < niter; ++i) blas::caxpyBxpz(a2, *xD, *yD, b2, *zD);
457  break;
458 
459  case 39:
460  for (int i=0; i < niter; ++i) blas::caxpyBzpx(a2, *xD, *yD, b2, *zD);
461  break;
462 
// Nsrc x Nsrc block dot product of xm with itself (hence the Nsrc*Nsrc-sized A2)
463  case 40:
464  for (int i=0; i < niter; ++i) blas::cDotProduct(A2, xmD->Components(), xmD->Components());
465  break;
466 
467  case 41:
468  for (int i=0; i < niter; ++i) blas::cDotProduct(A, xmD->Components(), ymD->Components());
469  break;
470 
471  default:
472  errorQuda("Undefined blas kernel %d\n", kernel);
473  }
474  }
475 
// Wait for the `end` event so the measured interval covers the queued work; this assumes
// the blas kernels are issued on (or ordered with) stream 0 -- TODO confirm.
476  cudaEventRecord(end, 0);
477  cudaEventSynchronize(end);
478  float runTime;
479  cudaEventElapsedTime(&runTime, start, end);
480  cudaEventDestroy(start);
481  cudaEventDestroy(end);
482  delete[] A;
483  delete[] B;
484  delete[] C;
485  delete[] A2;
// cudaEventElapsedTime reports milliseconds; convert to seconds.
486  double secs = runTime / 1000;
487  return secs;
488 }
489 
490 #define ERROR(a) fabs(blas::norm2(*a##D) - blas::norm2(*a##H)) / blas::norm2(*a##H)
491 
492 double test(int kernel) {
493 
494  double a = M_PI, b = M_PI*exp(1.0), c = sqrt(M_PI);
495  quda::Complex a2(a, b), b2(b, -c), c2(a+b, c*a);
496  double error = 0;
500  quda::Complex * A2 = new quda::Complex[Nsrc*Nsrc]; // for the block cDotProductNorm test
501  quda::Complex * B2 = new quda::Complex[Nsrc*Nsrc]; // for the block cDotProductNorm test
502  for(int i=0; i < Nsrc*Msrc; i++){
503  A[i] = a2* (1.0*((i/Nsrc) + i)) + b2 * (1.0*i) + c2 *(1.0*(Nsrc*Msrc/2-i));
504  B[i] = a2* (1.0*((i/Nsrc) + i)) - b2 * (M_PI*i) + c2 *(1.0*(Nsrc*Msrc/2-i));
505  C[i] = a2* (1.0*((M_PI/Nsrc) + i)) + b2 * (1.0*i) + c2 *(1.0*(Nsrc*Msrc/2-i));
506  }
507  for(int i=0; i < Nsrc*Nsrc; i++){
508  A2[i] = a2* (1.0*((i/Nsrc) + i)) + b2 * (1.0*i) + c2 *(1.0*(Nsrc*Nsrc/2-i));
509  B2[i] = a2* (1.0*((i/Nsrc) + i)) - b2 * (M_PI*i) + c2 *(1.0*(Nsrc*Nsrc/2-i));
510  }
511  // A[0] = a2;
512  // A[1] = 0.;
513  // A[2] = 0.;
514  // A[3] = 0.;
515 
516  switch (kernel) {
517 
518  case 0:
519  *hD = *hH;
520  blas::copy(*yD, *hD);
521  blas::copy(*yH, *hH);
522  error = ERROR(y);
523  break;
524 
525  case 1:
526  *lD = *lH;
527  blas::copy(*yD, *lD);
528  blas::copy(*yH, *lH);
529  error = ERROR(y);
530  break;
531 
532  case 2:
533  *xD = *xH;
534  *yD = *yH;
535  blas::axpby(a, *xD, b, *yD);
536  blas::axpby(a, *xH, b, *yH);
537  error = ERROR(y);
538  break;
539 
540  case 3:
541  *xD = *xH;
542  *yD = *yH;
543  blas::xpy(*xD, *yD);
544  blas::xpy(*xH, *yH);
545  error = ERROR(y);
546  break;
547 
548  case 4:
549  *xD = *xH;
550  *yD = *yH;
551  blas::axpy(a, *xD, *yD);
552  blas::axpy(a, *xH, *yH);
553  *zH = *yD;
554  error = ERROR(y);
555  break;
556 
557  case 5:
558  *xD = *xH;
559  *yD = *yH;
560  blas::xpay(*xD, a, *yD);
561  blas::xpay(*xH, a, *yH);
562  error = ERROR(y);
563  break;
564 
565  case 6:
566  *xD = *xH;
567  *yD = *yH;
568  blas::mxpy(*xD, *yD);
569  blas::mxpy(*xH, *yH);
570  error = ERROR(y);
571  break;
572 
573  case 7:
574  *xD = *xH;
575  blas::ax(a, *xD);
576  blas::ax(a, *xH);
577  error = ERROR(x);
578  break;
579 
580  case 8:
581  *xD = *xH;
582  *yD = *yH;
583  blas::caxpy(a2, *xD, *yD);
584  blas::caxpy(a2, *xH, *yH);
585  error = ERROR(y);
586  break;
587 
588  case 9:
589  *xD = *xH;
590  *yD = *yH;
591  blas::caxpby(a2, *xD, b2, *yD);
592  blas::caxpby(a2, *xH, b2, *yH);
593  error = ERROR(y);
594  break;
595 
596  case 10:
597  *xD = *xH;
598  *yD = *yH;
599  *zD = *zH;
600  blas::cxpaypbz(*xD, a2, *yD, b2, *zD);
601  blas::cxpaypbz(*xH, a2, *yH, b2, *zH);
602  error = ERROR(z);
603  break;
604 
605  case 11:
606  *xD = *xH;
607  *yD = *yH;
608  *zD = *zH;
609  blas::axpyBzpcx(a, *xD, *yD, b, *zD, c);
610  blas::axpyBzpcx(a, *xH, *yH, b, *zH, c);
611  error = ERROR(x) + ERROR(y);
612  break;
613 
614  case 12:
615  *xD = *xH;
616  *yD = *yH;
617  *zD = *zH;
618  blas::axpyZpbx(a, *xD, *yD, *zD, b);
619  blas::axpyZpbx(a, *xH, *yH, *zH, b);
620  error = ERROR(x) + ERROR(y);
621  break;
622 
623  case 13:
624  *xD = *xH;
625  *yD = *yH;
626  *zD = *zH;
627  *wD = *wH;
628  blas::caxpbypzYmbw(a2, *xD, b2, *yD, *zD, *wD);
629  blas::caxpbypzYmbw(a2, *xH, b2, *yH, *zH, *wH);
630  error = ERROR(z) + ERROR(y);
631  break;
632 
633  case 14:
634  *xD = *xH;
635  *yD = *yH;
636  blas::cabxpyAx(a, b2, *xD, *yD);
637  blas::cabxpyAx(a, b2, *xH, *yH);
638  error = ERROR(y) + ERROR(x);
639  break;
640 
641  case 15:
642  *xD = *xH;
643  *yD = *yH;
644  *zD = *zH;
645  {blas::caxpbypz(a2, *xD, b2, *yD, *zD);
646  blas::caxpbypz(a2, *xH, b2, *yH, *zH);
647  error = ERROR(z); }
648  break;
649 
650  case 16:
651  *xD = *xH;
652  *yD = *yH;
653  *zD = *zH;
654  *wD = *wH;
655  {blas::caxpbypczpw(a2, *xD, b2, *yD, c2, *zD, *wD);
656  blas::caxpbypczpw(a2, *xH, b2, *yH, c2, *zH, *wH);
657  error = ERROR(w); }
658  break;
659 
660  case 17:
661  *xD = *xH;
662  *yD = *yH;
663  *zD = *zH;
664  {blas::caxpyXmaz(a, *xD, *yD, *zD);
665  blas::caxpyXmaz(a, *xH, *yH, *zH);
666  error = ERROR(y) + ERROR(x);}
667  break;
668 
669  // double
670  case 18:
671  *xD = *xH;
672  *yH = *xD;
673  error = fabs(blas::norm2(*xD) - blas::norm2(*xH)) / blas::norm2(*xH);
674  break;
675 
676  case 19:
677  *xD = *xH;
678  *yD = *yH;
680  break;
681 
682  case 20:
683  *xD = *xH;
684  *yD = *yH;
685  {double d = blas::axpyNorm(a, *xD, *yD);
686  double h = blas::axpyNorm(a, *xH, *yH);
687  error = ERROR(y) + fabs(d-h)/fabs(h);}
688  break;
689 
690  case 21:
691  *xD = *xH;
692  *yD = *yH;
693  {double d = blas::xmyNorm(*xD, *yD);
694  double h = blas::xmyNorm(*xH, *yH);
695  error = ERROR(y) + fabs(d-h)/fabs(h);}
696  break;
697 
698  case 22:
699  *xD = *xH;
700  *yD = *yH;
701  {double d = blas::caxpyNorm(a, *xD, *yD);
702  double h = blas::caxpyNorm(a, *xH, *yH);
703  error = ERROR(y) + fabs(d-h)/fabs(h);}
704  break;
705 
706  case 23:
707  *xD = *xH;
708  *yD = *yH;
709  *zD = *zH;
710  {double d = blas::caxpyXmazNormX(a, *xD, *yD, *zD);
711  double h = blas::caxpyXmazNormX(a, *xH, *yH, *zH);
712  error = ERROR(y) + ERROR(x) + fabs(d-h)/fabs(h);}
713  break;
714 
715  case 24:
716  *xD = *xH;
717  *yD = *yH;
718  {double d = blas::cabxpyAxNorm(a, b2, *xD, *yD);
719  double h = blas::cabxpyAxNorm(a, b2, *xH, *yH);
720  error = ERROR(x) + ERROR(y) + fabs(d-h)/fabs(h);}
721  break;
722 
723  // double2
724  case 25:
725  *xD = *xH;
726  *yD = *yH;
728  break;
729 
730  case 26:
731  *xD = *xH;
732  *yD = *yH;
733  *zD = *zH;
734  { quda::Complex d = blas::xpaycDotzy(*xD, a, *yD, *zD);
736  error = fabs(blas::norm2(*yD) - blas::norm2(*yH)) / blas::norm2(*yH) + abs(d-h)/abs(h);
737  }
738  break;
739 
740  case 27:
741  *xD = *xH;
742  *yD = *yH;
743  *zD = *zH;
746  error = ERROR(y) + abs(d-h)/abs(h);}
747  break;
748 
749  // double3
750  case 28:
751  *xD = *xH;
752  *yD = *yH;
753  { double3 d = blas::cDotProductNormA(*xD, *yD);
754  double3 h = blas::cDotProductNormA(*xH, *yH);
755  error = fabs(d.x - h.x) / fabs(h.x) + fabs(d.y - h.y) / fabs(h.y) + fabs(d.z - h.z) / fabs(h.z); }
756  break;
757 
758  case 29:
759  *xD = *xH;
760  *yD = *yH;
761  { double3 d = blas::cDotProductNormB(*xD, *yD);
762  double3 h = blas::cDotProductNormB(*xH, *yH);
763  error = fabs(d.x - h.x) / fabs(h.x) + fabs(d.y - h.y) / fabs(h.y) + fabs(d.z - h.z) / fabs(h.z); }
764  break;
765 
766  case 30:
767  *xD = *xH;
768  *yD = *yH;
769  *zD = *zH;
770  *wD = *wH;
771  *vD = *vH;
772  { double3 d = blas::caxpbypzYmbwcDotProductUYNormY(a2, *xD, b2, *yD, *zD, *wD, *vD);
773  double3 h = blas::caxpbypzYmbwcDotProductUYNormY(a2, *xH, b2, *yH, *zH, *wH, *vH);
774  error = ERROR(z) + ERROR(y) + fabs(d.x - h.x) / fabs(h.x) +
775  fabs(d.y - h.y) / fabs(h.y) + fabs(d.z - h.z) / fabs(h.z); }
776  break;
777 
778  case 31:
779  *xD = *xH;
780  *yD = *yH;
781  { double3 d = blas::HeavyQuarkResidualNorm(*xD, *yD);
782  double3 h = blas::HeavyQuarkResidualNorm(*xH, *yH);
783  error = fabs(d.x - h.x) / fabs(h.x) +
784  fabs(d.y - h.y) / fabs(h.y) + fabs(d.z - h.z) / fabs(h.z); }
785  break;
786 
787  case 32:
788  *xD = *xH;
789  *yD = *yH;
790  *zD = *zH;
791  { double3 d = blas::xpyHeavyQuarkResidualNorm(*xD, *yD, *zD);
792  double3 h = blas::xpyHeavyQuarkResidualNorm(*xH, *yH, *zH);
793  error = ERROR(y) + fabs(d.x - h.x) / fabs(h.x) +
794  fabs(d.y - h.y) / fabs(h.y) + fabs(d.z - h.z) / fabs(h.z); }
795  break;
796 
797  case 33:
798  *xD = *xH;
799  *yD = *yH;
800  *zD = *zH;
801  { double3 d = blas::tripleCGReduction(*xD, *yD, *zD);
802  double3 h = make_double3(blas::norm2(*xH), blas::norm2(*yH), blas::reDotProduct(*yH, *zH));
803  error = fabs(d.x - h.x) / fabs(h.x) +
804  fabs(d.y - h.y) / fabs(h.y) + fabs(d.z - h.z) / fabs(h.z); }
805  break;
806 
807  case 34:
808  *xD = *xH;
809  *yD = *yH;
810  *zD = *zH;
811  *wD = *wH;
812  { blas::tripleCGUpdate(a, b, *xD, *yD, *zD, *wD);
813  blas::tripleCGUpdate(a, b, *xH, *yH, *zH, *wH);
814  error = ERROR(y) + ERROR(z) + ERROR(w); }
815  break;
816 
817  case 35:
818  *xD = *xH;
819  *yD = *yH;
820  { double d = blas::axpyReDot(a, *xD, *yD);
821  double h = blas::axpyReDot(a, *xH, *yH);
822  error = ERROR(y) + fabs(d-h)/fabs(h); }
823  break;
824 
825  case 36:
826  for (int i=0; i < Nsrc; i++) xmD->Component(i) = *(xmH[i]);
827  for (int i=0; i < Msrc; i++) ymD->Component(i) = *(ymH[i]);
828 
829  blas::caxpy(A, *xmD, *ymD);
830  for (int i=0; i < Nsrc; i++){
831  for(int j=0; j < Msrc; j++){
832  blas::caxpy(A[Msrc*i+j], *(xmH[i]), *(ymH[j]));
833  }
834  }
835  error = 0;
836  for (int i=0; i < Msrc; i++){
837  error+= fabs(blas::norm2((ymD->Component(i))) - blas::norm2(*(ymH[i]))) / blas::norm2(*(ymH[i]));
838  }
839  error/= Msrc;
840  break;
841 
842  case 37:
843  for (int i=0; i < Nsrc; i++) {
844  xmD->Component(i) = *(xmH[i]);
845  zmD->Component(i) = *(zmH[i]);
846  }
847  *yD = *yH;
848 
849  blas::axpyBzpcx((double*)A, xmD->Components(), zmD->Components(), (double*)B, *yD, (const double*)C);
850 
851  for (int i=0; i<Nsrc; i++) {
852  blas::axpyBzpcx(((double*)A)[i], *xmH[i], *zmH[i], ((double*)B)[i], *yH, ((double*)C)[i]);
853  }
854 
855  error = 0;
856  for (int i=0; i < Nsrc; i++){
857  error+= fabs(blas::norm2((xmD->Component(i))) - blas::norm2(*(xmH[i]))) / blas::norm2(*(xmH[i]));
858  //error+= fabs(blas::norm2((zmD->Component(i))) - blas::norm2(*(zmH[i]))) / blas::norm2(*(zmH[i]));
859  }
860  error/= Nsrc;
861  break;
862 
863  case 38:
864  *xD = *xH;
865  *yD = *yH;
866  *zD = *zH;
867  {blas::caxpyBxpz(a, *xD, *yD, b2, *zD);
868  blas::caxpyBxpz(a, *xH, *yH, b2, *zH);
869  error = ERROR(x) + ERROR(z);}
870  break;
871 
872  case 39:
873  *xD = *xH;
874  *yD = *yH;
875  *zD = *zH;
876  {blas::caxpyBzpx(a, *xD, *yD, b2, *zD);
877  blas::caxpyBzpx(a, *xH, *yH, b2, *zH);
878  error = ERROR(x) + ERROR(z);}
879  break;
880 
881  case 40:
882  for (int i=0; i < Nsrc; i++) xmD->Component(i) = *(xmH[i]);
884  error = 0.0;
885  for (int i = 0; i < Nsrc; i++) {
886  for (int j = 0; j < Nsrc; j++) {
888  error += std::abs(A2[i*Nsrc+j] - B2[i*Nsrc+j])/std::abs(B2[i*Nsrc+j]);
889  }
890  }
891  error /= Nsrc*Nsrc;
892  break;
893 
894  case 41:
895  for (int i=0; i < Nsrc; i++) xmD->Component(i) = *(xmH[i]);
896  for (int i=0; i < Msrc; i++) ymD->Component(i) = *(ymH[i]);
898  error = 0.0;
899  for (int i = 0; i < Nsrc; i++) {
900  for (int j = 0; j < Msrc; j++) {
902  error += std::abs(A[i*Msrc+j] - B[i*Msrc+j])/std::abs(B[i*Msrc+j]);
903  }
904  }
905  error /= Nsrc*Msrc;
906  break;
907 
908  default:
909  errorQuda("Undefined blas kernel %d\n", kernel);
910  }
911  delete[] A;
912  delete[] B;
913  delete[] C;
914  delete[] A2;
915  delete[] B2;
916  return error;
917 }
918 
// Human-readable name for precision index 0/1/2 (same mapping as skip_kernel: 0 = half,
// 1 = single, 2 = double); used to build the googletest instance names.
919 const char *prec_str[] = {"half", "single", "double"};
920 
921 
// Kernel names, indexed by kernel number: names[k] labels case k of benchmark()/test().
// NOTE(review): this array has 43 entries but Nkernels is 42, so names[42]
// ("caxpy_composite") has no kernel and is never used by the test instantiation.
// names[40] ("cDotProductNorm_block") actually runs the block cDotProduct of xm with itself;
// renaming it would change the googletest test names.
922 // For googletest names must be non-empty, unique, and may only contain ASCII
923 // alphanumeric characters or underscore
924 const char *names[] = {
925  "copyHS",
926  "copyLS",
927  "axpby",
928  "xpy",
929  "axpy",
930  "xpay",
931  "mxpy",
932  "ax",
933  "caxpy",
934  "caxpby",
935  "cxpaypbz",
936  "axpyBzpcx",
937  "axpyZpbx",
938  "caxpbypzYmbw",
939  "cabxpyAx",
940  "caxpbypz",
941  "caxpbypczpw",
942  "caxpyXmaz",
943  "norm",
944  "reDotProduct",
945  "axpyNorm",
946  "xmyNorm",
947  "caxpyNorm",
948  "caxpyXmazNormX",
949  "cabxpyAxNorm",
950  "cDotProduct",
951  "xpaycDotzy",
952  "caxpyDotzy",
953  "cDotProductNormA",
954  "cDotProductNormB",
955  "caxpbypzYmbwcDotProductUYNormY",
956  "HeavyQuarkResidualNorm",
957  "xpyHeavyQuarkResidualNorm",
958  "tripleCGReduction",
959  "tripleCGUpdate",
960  "axpyReDot",
961  "caxpy_block",
962  "axpyBzpcx_block",
963  "caxpyBxpz",
964  "caxpyBzpx",
965  "cDotProductNorm_block",
966  "cDotProduct_block",
967  "caxpy_composite"
968 };
969 
// Entry point: parses googletest and QUDA test options, chooses the field layout
// (Nspin/Ncolor), initialises communications and QUDA, runs the googletest suite when
// verify_results is set, then shuts QUDA and communications down.
// Returns the RUN_ALL_TESTS() result, or 0 if the suite was not run.
970 int main(int argc, char** argv)
971 {
972 
// googletest consumes its own --gtest_* flags from argv first.
973  ::testing::InitGoogleTest(&argc, argv);
974  int result = 0;
975 
// -1 = run every kernel (see skip_kernel); the option loop below may override it
// (presumably via a --test option handled in test_util -- confirm).
977  test_type = -1;
978 
// process_command_line_option returns 0 when it recognised argv[i]; it receives &i, so it
// can step past an option's argument. Anything else is reported as invalid.
// NOTE(review): whether usage() terminates the program is not visible here.
979  for (int i = 1; i < argc; i++){
980  if(process_command_line_option(argc, argv, &i) == 0){
981  continue;
982  }
983  printfQuda("ERROR: Invalid option:%s\n", argv[i]);
984  usage(argv);
985  }
986 
987  // override spin setting if mg solver is set to test coarse grids
988  if (inv_type == QUDA_MG_INVERTER) {
// coarse-grid layout: 2 spins, nvec colours
989  Nspin = 2;
990  Ncolor = nvec;
991  } else {
992  // set spin according to the type of dslash
995  Ncolor = 3;
996  }
997 
998  setSpinorSiteSize(24);
999  initComms(argc, argv, gridsize_from_cmdline);
1001  initQuda(device);
1002 
1004 
1005  // clear the error state
1006  cudaGetLastError();
1007 
// NOTE(review): RUN_ALL_TESTS() is the only place any TEST_P runs, so when verify_results
// is false neither the verify tests nor (presumably) the benchmark tests execute.
1008  // lastly check for correctness
1009  if (verify_results) {
1010  result = RUN_ALL_TESTS();
1011  }
1012 
1013  endQuda();
1014 
1015  finalizeComms();
1016  return result;
1017 }
1018 
1019 // The following tests each kernel at each precision using the google testing framework
1020 
1021 using ::testing::TestWithParam;
1022 using ::testing::Bool;
1023 using ::testing::Values;
1024 using ::testing::Range;
1025 using ::testing::Combine;
1026 
1027 class BlasTest : public ::testing::TestWithParam<::testing::tuple<int, int>> {
1028 protected:
1029  ::testing::tuple<int, int> param;
1030 
1031 public:
1032  virtual ~BlasTest() { }
1033  virtual void SetUp() {
1034  param = GetParam();
1035  initFields(::testing::get<0>(GetParam()));
1036  }
1037  virtual void TearDown() { freeFields(); }
1038 
1039 };
1040 
1041 
1042 TEST_P(BlasTest, verify) {
1043  int prec = ::testing::get<0>(GetParam());
1044  int kernel = ::testing::get<1>(GetParam());
1045 
1046  // certain tests will fail to run for coarse grids so mark these as
1047  // failed without running
1048  double deviation = skip_kernel(prec,kernel) ? 1.0 : test(kernel);
1049  // printfQuda("%-35s error = %e\n", names[kernel], deviation);
1050  double tol = (prec == 2 ? 1e-10 : (prec == 1 ? 1e-5 : 1e-3));
1051  tol = (kernel < 2) ? 1e-4 : tol; // use different tolerance for copy
1052  EXPECT_LE(deviation, tol) << "CPU and CUDA implementations do not agree";
1053 }
1054 
1056  int prec = ::testing::get<0>(GetParam());
1057  int kernel = ::testing::get<1>(GetParam());
1058 // do the initial tune
1059  benchmark(kernel, 1);
1060 
1061  // now rerun with more iterations to get accurate speed measurements
1062  quda::blas::flops = 0;
1063  quda::blas::bytes = 0;
1064 
1065  double secs = benchmark(kernel, niter);
1066 
1067  double gflops = (quda::blas::flops*1e-9)/(secs);
1068  double gbytes = quda::blas::bytes/(secs*1e9);
1069  RecordProperty("Gflops", std::to_string(gflops));
1070  RecordProperty("GBs", std::to_string(gbytes));
1071  printfQuda("%-31s: Gflop/s = %6.1f, GB/s = %6.1f\n", names[kernel], gflops, gbytes);
1072 }
1073 
1074 
1075 std::string getblasname(testing::TestParamInfo<::testing::tuple<int, int>> param){
1076  int prec = ::testing::get<0>(param.param);
1077  int kernel = ::testing::get<1>(param.param);
1078  std::string str(names[kernel]);
1079  str += std::string("_");
1080  str += std::string(prec_str[prec]);
1081  return str;//names[kernel] + "_" + prec_str[prec];
1082 }
1083 
1084 // all precisions (indices 0..2: half, single, double) x all kernels (0..Nkernels-1); instance names from getblasname
1085 INSTANTIATE_TEST_CASE_P(QUDA, BlasTest, Combine( Range(0,3), Range(0, Nkernels) ), getblasname);
1086 
QudaDslashType dslash_type
Definition: test_util.cpp:1626
ColorSpinorField * ymD
Definition: blas_test.cu:41
int dimPartitioned(int dim)
Definition: test_util.cpp:1686
void endQuda(void)
void xpay(ColorSpinorField &x, const double &a, ColorSpinorField &y)
Definition: blas_quda.cu:173
enum QudaPrecision_s QudaPrecision
double3 cDotProductNormA(ColorSpinorField &a, ColorSpinorField &b)
Definition: reduce_quda.cu:572
double test(int kernel)
Definition: blas_test.cu:492
int Msrc
Definition: test_util.cpp:1629
double caxpyNorm(const Complex &a, ColorSpinorField &x, ColorSpinorField &y)
Definition: reduce_quda.cu:402
void display_test_info()
Definition: blas_test.cu:59
__host__ __device__ ValueType exp(ValueType x)
Definition: complex_quda.h:85
#define errorQuda(...)
Definition: util_quda.h:90
double norm2(const ColorSpinorField &a)
Definition: reduce_quda.cu:241
QudaInverterType inv_type
Definition: test_util.cpp:1638
cudaEvent_t start
int xdim
Definition: test_util.cpp:1620
enum QudaSolveType_s QudaSolveType
__host__ __device__ ValueType sqrt(ValueType x)
Definition: complex_quda.h:105
Complex cDotProduct(ColorSpinorField &, ColorSpinorField &)
Definition: reduce_quda.cu:500
std::complex< double > Complex
Definition: eig_variables.h:13
int process_command_line_option(int argc, char **argv, int *idx)
Definition: test_util.cpp:1795
ColorSpinorField * yH
Definition: blas_test.cu:40
::testing::tuple< int, int > param
Definition: blas_test.cu:1029
ColorSpinorField * vH
Definition: blas_test.cu:40
virtual ~BlasTest()
Definition: blas_test.cu:1032
double3 xpyHeavyQuarkResidualNorm(ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &r)
Definition: reduce_quda.cu:742
std::vector< cpuColorSpinorField * > ymH
Definition: blas_test.cu:43
double axpyNorm(const double &a, ColorSpinorField &x, ColorSpinorField &y)
Definition: reduce_quda.cu:325
CompositeColorSpinorField & Components()
ColorSpinorField * yD
Definition: blas_test.cu:41
int Nprec
Definition: blas_test.cu:73
bool skip_kernel(int precision, int kernel)
Definition: blas_test.cu:75
std::vector< cpuColorSpinorField * > xmH
Definition: blas_test.cu:42
double reDotProduct(ColorSpinorField &x, ColorSpinorField &y)
Definition: reduce_quda.cu:277
int zdim
Definition: test_util.cpp:1622
void copy(ColorSpinorField &dst, const ColorSpinorField &src)
Definition: copy_quda.cu:263
void finalizeComms()
Definition: test_util.cpp:107
void ax(const double &a, ColorSpinorField &x)
Definition: blas_quda.cu:209
ColorSpinorField & Component(const int idx) const
double xmyNorm(ColorSpinorField &x, ColorSpinorField &y)
Definition: reduce_quda.cu:364
int nvec
Definition: test_util.cpp:1635
void caxpyBzpx(const Complex &, ColorSpinorField &, ColorSpinorField &, const Complex &, ColorSpinorField &)
Definition: blas_quda.cu:412
int Nspin
Definition: blas_test.cu:45
void caxpyBxpz(const Complex &, ColorSpinorField &, ColorSpinorField &, const Complex &, ColorSpinorField &)
Definition: blas_quda.cu:438
ColorSpinorField * xD
Definition: blas_test.cu:41
ColorSpinorField * wD
Definition: blas_test.cu:41
int tdim
ColorSpinorField * wH
Definition: blas_test.cu:40
QudaGaugeParam param
Definition: pack_test.cpp:17
#define b
void usage(char **)
Definition: test_util.cpp:1693
const char * names[]
Definition: blas_test.cu:924
void cabxpyAx(const double &a, const Complex &b, ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.cu:484
void initQuda(int device)
double cabxpyAxNorm(const double &a, const Complex &b, ColorSpinorField &x, ColorSpinorField &y)
Definition: reduce_quda.cu:449
double tol
Definition: test_util.cpp:1647
double benchmark(int kernel, const int niter)
Definition: blas_test.cu:283
void initFields(int prec)
Definition: blas_test.cu:100
void axpyZpbx(const double &a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, const double &b)
Definition: blas_quda.cu:384
virtual void SetUp()
Definition: blas_test.cu:1033
ColorSpinorField * zD
Definition: blas_test.cu:41
virtual void TearDown()
Definition: blas_test.cu:1037
ColorSpinorField * xmD
Definition: blas_test.cu:41
void setSpinorSiteSize(int n)
Definition: test_util.cpp:192
ColorSpinorField * vD
Definition: blas_test.cu:41
int int int w
void caxpbypzYmbw(const Complex &, ColorSpinorField &, const Complex &, ColorSpinorField &, ColorSpinorField &, ColorSpinorField &)
Definition: blas_quda.cu:464
void tripleCGUpdate(const double &alpha, const double &beta, ColorSpinorField &q, ColorSpinorField &r, ColorSpinorField &x, ColorSpinorField &p)
Definition: blas_quda.cu:610
const char * prec_str[]
Definition: blas_test.cu:919
int niter
Definition: test_util.cpp:1630
double3 HeavyQuarkResidualNorm(ColorSpinorField &x, ColorSpinorField &r)
Definition: reduce_quda.cu:703
ColorSpinorField * xH
Definition: blas_test.cu:40
double caxpyXmazNormX(const Complex &a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
Definition: reduce_quda.cu:424
const int Nkernels
Definition: blas_test.cu:36
Complex caxpyDotzy(const Complex &a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
Definition: reduce_quda.cu:544
ColorSpinorField * lH
Definition: blas_test.cu:40
void caxpy(const Complex &a, ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.cu:246
static __inline__ size_t h
QudaSolveType solve_type
Definition: test_util.cpp:1653
void caxpbypczpw(const Complex &, ColorSpinorField &, const Complex &, ColorSpinorField &, const Complex &, ColorSpinorField &, ColorSpinorField &)
Definition: blas_quda.cu:527
void axpy(const double &a, ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.cu:150
void axpby(const double &a, ColorSpinorField &x, const double &b, ColorSpinorField &y)
Definition: blas_quda.cu:106
int abs(int) __attribute__((const))
ColorSpinorField * hD
Definition: blas_test.cu:41
double3 caxpbypzYmbwcDotProductUYNormY(const Complex &a, ColorSpinorField &x, const Complex &b, ColorSpinorField &y, ColorSpinorField &z, ColorSpinorField &w, ColorSpinorField &u)
Definition: reduce_quda.cu:619
void caxpbypz(const Complex &, ColorSpinorField &, const Complex &, ColorSpinorField &, ColorSpinorField &)
Definition: blas_quda.cu:505
ColorSpinorField * lD
Definition: blas_test.cu:41
int Ncolor
Definition: blas_test.cu:46
double axpyReDot(const double &a, ColorSpinorField &x, ColorSpinorField &y)
Definition: reduce_quda.cu:345
int Nsrc
Definition: test_util.cpp:1628
int ydim
Definition: test_util.cpp:1621
void axpyBzpcx(const double &a, ColorSpinorField &x, ColorSpinorField &y, const double &b, ColorSpinorField &z, const double &c)
Definition: blas_quda.cu:356
void caxpyXmaz(const Complex &a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
Definition: blas_quda.cu:549
Complex xpaycDotzy(ColorSpinorField &x, const double &a, ColorSpinorField &y, ColorSpinorField &z)
Definition: reduce_quda.cu:521
ColorSpinorField * zmD
Definition: blas_test.cu:41
void setPrec(ColorSpinorParam &param, const QudaPrecision precision)
Definition: blas_test.cu:48
#define printfQuda(...)
Definition: util_quda.h:84
double fabs(double)
INSTANTIATE_TEST_CASE_P(QUDA, BlasTest, Combine(Range(0, 3), Range(0, Nkernels)), getblasname)
unsigned long long flops
Definition: blas_quda.cu:42
void xpy(ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.cu:128
std::string getblasname(testing::TestParamInfo<::testing::tuple< int, int >> param)
Definition: blas_test.cu:1075
enum QudaDslashType_s QudaDslashType
void caxpby(const Complex &a, ColorSpinorField &x, const Complex &b, ColorSpinorField &y)
Definition: blas_quda.cu:292
const void * c
ColorSpinorField * zH
Definition: blas_test.cu:40
void mxpy(ColorSpinorField &x, ColorSpinorField &y)
Definition: blas_quda.cu:192
__host__ __device__ ValueType abs(ValueType x)
Definition: complex_quda.h:110
int test_type
Definition: test_util.cpp:1634
void cxpaypbz(ColorSpinorField &, const Complex &b, ColorSpinorField &y, const Complex &c, ColorSpinorField &z)
Definition: blas_quda.cu:335
#define checkCudaError()
Definition: util_quda.h:129
TEST_P(BlasTest, verify)
Definition: blas_test.cu:1042
bool verify_results
Definition: test_util.cpp:1641
std::vector< cpuColorSpinorField * > zmH
Definition: blas_test.cu:44
double3 cDotProductNormB(ColorSpinorField &a, ColorSpinorField &b)
Definition: reduce_quda.cu:599
static __inline__ size_t size_t d
ColorSpinorField * hH
Definition: blas_test.cu:40
QudaPrecision prec
Definition: test_util.cpp:1615
#define a
void initComms(int argc, char **argv, const int *commDims)
Definition: test_util.cpp:72
void freeFields()
Definition: blas_test.cu:251
int gridsize_from_cmdline[]
Definition: test_util.cpp:50
void setVerbosity(const QudaVerbosity verbosity)
Definition: util_quda.cpp:24
enum QudaInverterType_s QudaInverterType
unsigned long long bytes
Definition: blas_quda.cu:43
int main(int argc, char **argv)
Definition: blas_test.cu:970
double3 tripleCGReduction(ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z)
Definition: reduce_quda.cu:767
cudaEvent_t cudaEvent_t end
#define ERROR(a)
Definition: blas_test.cu:490