QUDA  0.9.0
blas_core.h
Go to the documentation of this file.
1 
4 template <template <typename Float, typename FloatN> class Functor,
5  int writeX, int writeY, int writeZ, int writeW>
6  void blasCuda(const double2 &a, const double2 &b, const double2 &c,
7  ColorSpinorField &x, ColorSpinorField &y,
8  ColorSpinorField &z, ColorSpinorField &w) {
9 
11 
12  if (x.Precision() == QUDA_DOUBLE_PRECISION) {
13 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC) || defined(GPU_STAGGERED_DIRAC)
14  const int M = 1;
15  blasCuda<double2,double2,double2,M,Functor,writeX,writeY,writeZ,writeW>(a,b,c,x,y,z,w,x.Length()/(2*M));
16 #else
17  errorQuda("blas has not been built for Nspin=%d fields", x.Nspin());
18 #endif
19  } else if (x.Precision() == QUDA_SINGLE_PRECISION) {
20  if (x.Nspin() == 4) {
21 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC)
22  const int M = 1;
23  blasCuda<float4,float4,float4,M,Functor,writeX,writeY,writeZ,writeW>(a,b,c,x,y,z,w,x.Length()/(4*M));
24 #else
25  errorQuda("blas has not been built for Nspin=%d fields", x.Nspin());
26 #endif
27  } else if (x.Nspin()==2 || x.Nspin()==1) {
28 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC) || defined(GPU_STAGGERED_DIRAC)
29  const int M = 1;
30  blasCuda<float2,float2,float2,M,Functor,writeX,writeY,writeZ,writeW>(a,b,c,x,y,z,w,x.Length()/(2*M));
31 #else
32  errorQuda("blas has not been built for Nspin=%d fields", x.Nspin());
33 #endif
34  } else { errorQuda("nSpin=%d is not supported\n", x.Nspin()); }
35  } else {
36  if (x.Ncolor() != 3) { errorQuda("nColor = %d is not supported", x.Ncolor()); }
37  if (x.Nspin() == 4) { //wilson
38 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC)
39  const int M = 6;
40  blasCuda<float4,short4,short4,M,Functor,writeX,writeY,writeZ,writeW>(a,b,c,x,y,z,w,x.Volume());
41 #else
42  errorQuda("blas has not been built for Nspin=%d fields", x.Nspin());
43 #endif
44  } else if (x.Nspin() == 1) {//staggered
45 #ifdef GPU_STAGGERED_DIRAC
46  const int M = 3;
47  blasCuda<float2,short2,short2,M,Functor,writeX,writeY,writeZ,writeW>(a,b,c,x,y,z,w,x.Volume());
48 #else
49  errorQuda("blas has not been built for Nspin=%d fields", x.Nspin());
50 #endif
51  } else {
52  errorQuda("nSpin=%d is not supported\n", x.Nspin());
53  }
54  }
55  } else { // fields on the cpu
56  using namespace quda::colorspinor;
57  if (x.Precision() == QUDA_DOUBLE_PRECISION) {
58  Functor<double2, double2> f(a, b, c);
59  genericBlas<double, double, writeX, writeY, writeZ, writeW>(x, y, z, w, f);
60  } else if (x.Precision() == QUDA_SINGLE_PRECISION) {
61  Functor<float2, float2> f(make_float2(a.x,a.y), make_float2(b.x,b.y), make_float2(c.x,c.y) );
62  genericBlas<float, float, writeX, writeY, writeZ, writeW>(x, y, z, w, f);
63  } else {
64  errorQuda("Not implemented");
65  }
66  }
67 
68 }
69 
#define errorQuda(...)
Definition: util_quda.h:90
#define b
int int int w
#define checkLocation(...)
int int int enum cudaChannelFormatKind f
void blasCuda(const double2 &a, const double2 &b, const double2 &c, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, ColorSpinorField &w)
Definition: blas_core.h:6
const void * c
#define a