QUDA  0.9.0
multi_blas_core.h
Go to the documentation of this file.
1 
// Dispatch routine: inspects the location, precision, spin and colour of the
// first field in each set (x[0], y[0], ...) and forwards to the matching
// multiblasCuda<NXZ, ..., M, Functor, write> instantiation with a length
// argument, or reports an error through errorQuda when no instantiation applies.
//
// NOTE(review): listing lines 6-8 (the function's own signature) are not shown
// in this extract; presumably this is the multiblasCuda(a, b, c, x, y, z, w)
// overload taking coeff_array<T> coefficients -- confirm against the header.
//
// Template parameters visible here:
//   NXZ     - integer forwarded unchanged as the first template argument.
//   Functor - class template over <int, typename, typename>, forwarded unchanged.
//   write   - type forwarded unchanged.
//   T       - not referenced in the visible body (presumably the coefficient
//             element type used in the hidden signature -- TODO confirm).
4 template <int NXZ, template < int MXZ, typename Float, typename FloatN> class Functor,
5  typename write, typename T>
9 
// GPU path: taken when checkLocation over the first field of each set yields
// QUDA_CUDA_FIELD_LOCATION.
10  if (checkLocation(*x[0], *y[0], *z[0], *w[0]) == QUDA_CUDA_FIELD_LOCATION) {
11 
// Only the precisions of x[0] and y[0] are compared in this block; z, w and the
// remaining elements of x and y are not inspected here.
// Double precision: double2 is used for all three type arguments.
12  if (y[0]->Precision() == QUDA_DOUBLE_PRECISION && x[0]->Precision() == QUDA_DOUBLE_PRECISION) {
13 
14 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC) || defined(GPU_STAGGERED_DIRAC)
15  const int M = 1;
// Length argument is x[0]->Length() divided by 2*M (2 matches the component
// count of double2).
16  multiblasCuda<NXZ,double2,double2,double2,M,Functor,write>(a,b,c,x,y,z,w,x[0]->Length()/(2*M));
17 #else
// NOTE(review): this branch does not test Nspin, yet the message reports it.
18  errorQuda("blas has not been built for Nspin=%d fields", x[0]->Nspin());
19 #endif
20 
// Single precision: the vector type is chosen by Nspin.
21  } else if (y[0]->Precision() == QUDA_SINGLE_PRECISION && x[0]->Precision() == QUDA_SINGLE_PRECISION) {
22 
// Nspin == 4: float4 instantiation, compiled only with the Wilson or
// domain-wall macros defined.
23  if (x[0]->Nspin() == 4) {
24 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC)
25  const int M = 1;
26  multiblasCuda<NXZ,float4,float4,float4,M,Functor,write>(a,b,c,x,y,z,w,x[0]->Length()/(4*M));
27 #else
28  errorQuda("blas has not been built for Nspin=%d fields", x[0]->Nspin());
29 #endif
30 
// Nspin == 2 or 1: float2 instantiation, also enabled by the staggered macro.
31  } else if (x[0]->Nspin()==2 || x[0]->Nspin()==1) {
32 
33 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC) || defined(GPU_STAGGERED_DIRAC)
34  const int M = 1;
35  multiblasCuda<NXZ,float2,float2,float2,M,Functor,write>(a,b,c,x,y,z,w,x[0]->Length()/(2*M));
36 #else
37  errorQuda("blas has not been built for Nspin=%d fields", x[0]->Nspin());
38 #endif
39  } else { errorQuda("nSpin=%d is not supported\n", x[0]->Nspin()); }
40 
// Half precision: float4/float2 as the first type argument and short4/short2
// as the second and third. The length argument is x[0]->Volume() rather than
// a scaled Length().
41  } else if (y[0]->Precision() == QUDA_HALF_PRECISION && x[0]->Precision() == QUDA_HALF_PRECISION) {
42 
// Half precision is restricted to Ncolor == 3 in this block.
43  if (x[0]->Ncolor() != 3) { errorQuda("nColor = %d is not supported", x[0]->Ncolor()); }
44  if (x[0]->Nspin() == 4) { //wilson
45 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC)
// presumably M = 6 short4 vectors per site (6*4 = 24 = 4 spin * 3 colour * 2
// real components) -- TODO confirm
46  const int M = 6;
47  multiblasCuda<NXZ,float4,short4,short4,M,Functor,write>(a,b,c,x,y,z,w,x[0]->Volume());
48 #else
49  errorQuda("blas has not been built for Nspin=%d fields", x[0]->Nspin());
50 #endif
51  } else if (x[0]->Nspin() == 1) {//staggered
52 #ifdef GPU_STAGGERED_DIRAC
// presumably M = 3 short2 vectors per site (3*2 = 6 = 3 colour * 2 real
// components) -- TODO confirm
53  const int M = 3;
54  multiblasCuda<NXZ,float2,short2,short2,M,Functor,write>(a,b,c,x,y,z,w,x[0]->Volume());
55 #else
56  errorQuda("blas has not been built for Nspin=%d fields", x[0]->Nspin());
57 #endif
58  } else {
59  errorQuda("nSpin=%d is not supported\n", x[0]->Nspin());
60  }
61 
// Reached when x[0] and y[0] do not share one of the three precisions tested
// above (including any case where their precisions differ).
62  } else {
63 
// NOTE(review): only x's precision is printed although the tests above involve
// both x[0] and y[0]; consider reporting y[0]->Precision() as well. The
// trailing "\n" is also inconsistent with other errorQuda messages here.
64  errorQuda("Precision combination x=%d not supported\n", x[0]->Precision());
65 
66  }
67  } else { // fields on the cpu
// The host implementation below is commented out; the only live statement in
// this branch is the errorQuda("Not implemented") call.
68  // using namespace quda::colorspinor;
69  // if (x[0]->Precision() == QUDA_DOUBLE_PRECISION) {
70  // Functor<NXZ, NYW, double2, double2> f(a, b, c);
71  // genericMultBlas<double, double, writeX, writeY, writeZ, writeW>(x, y, z, w, f);
72  // } else if (x[0]->Precision() == QUDA_SINGLE_PRECISION) {
73  // Functor<NXZ, NYW, float2, float2> f(a, make_float2(b.x,b.y), make_float2(c.x,c.y) );
74  // genericMultBlas<float, float, writeX, writeY, writeZ, writeW>(x, y, z, w, f);
75  // } else {
76  errorQuda("Not implemented");
77  // }
78  }
79 
80 }
#define errorQuda(...)
Definition: util_quda.h:90
int Nspin
Definition: blas_test.cu:45
#define b
std::vector< ColorSpinorField * > CompositeColorSpinorField
int int int w
#define checkLocation(...)
void multiblasCuda(const coeff_array< T > &a, const coeff_array< T > &b, const coeff_array< T > &c, CompositeColorSpinorField &x, CompositeColorSpinorField &y, CompositeColorSpinorField &z, CompositeColorSpinorField &w)
int Ncolor
Definition: blas_test.cu:46
const void * c
#define a