QUDA  0.9.0
multi_blas_mixed_core.h
Go to the documentation of this file.
1 namespace mixed {
2 
6  template <int NXZ, template < int MXZ, typename Float, typename FloatN> class Functor,
7  typename write, typename T>
11 
12  if (checkLocation(*x[0], *y[0], *z[0], *w[0]) == QUDA_CUDA_FIELD_LOCATION) {
13 
14  if (y[0]->Precision() == QUDA_DOUBLE_PRECISION && x[0]->Precision() == QUDA_SINGLE_PRECISION) {
15 
16  if (x[0]->Nspin() == 4) {
17 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC)
18  const int M = 12;
19  multiblasCuda<NXZ,double2,float4,double2,M,Functor,write>(a,b,c,x,y,z,w,x[0]->Volume());
20 #else
21  errorQuda("blas has not been built for Nspin=%d fields", x[0]->Nspin());
22 #endif
23  } else if (x[0]->Nspin() == 1) {
24 #if defined(GPU_STAGGERED_DIRAC)
25  const int M = 3;
26  multiblasCuda<NXZ,double2,float2,double2,M,Functor,write>(a,b,c,x,y,z,w,x[0]->Volume());
27 #else
28  errorQuda("blas has not been built for Nspin=%d fields", x[0]->Nspin());
29 #endif
30  }
31 
32  } else if (y[0]->Precision() == QUDA_DOUBLE_PRECISION && x[0]->Precision() == QUDA_HALF_PRECISION) {
33 
34  if (x[0]->Nspin() == 4) {
35 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC)
36  const int M = 12;
37  multiblasCuda<NXZ,double2,short4,double2,M,Functor,write>(a,b,c,x,y,z,w,x[0]->Volume());
38 #else
39  errorQuda("blas has not been built for Nspin=%d fields", x[0]->Nspin());
40 #endif
41  } else if (x[0]->Nspin() == 1) {
42 #if defined(GPU_STAGGERED_DIRAC)
43  const int M = 3;
44  multiblasCuda<NXZ,double2,short2,double2,M,Functor,write>(a,b,c,x,y,z,w,x[0]->Volume());
45 #else
46  errorQuda("blas has not been built for Nspin=%d fields", x[0]->Nspin());
47 #endif
48  }
49 
50  } else if (y[0]->Precision() == QUDA_SINGLE_PRECISION && x[0]->Precision() == QUDA_HALF_PRECISION) {
51 
52  if (x[0]->Nspin() == 4) {
53 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC)
54  const int M = 6;
55  multiblasCuda<NXZ,float4,short4,float4,M,Functor,write>(a,b,c,x,y,z,w,x[0]->Volume());
56 #else
57  errorQuda("blas has not been built for Nspin=%d fields", x[0]->Nspin());
58 #endif
59 
60  } else if (x[0]->Nspin()==2 || x[0]->Nspin()==1) {
61 
62 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC) || defined(GPU_STAGGERED_DIRAC)
63  const int M = 3;
64  multiblasCuda<NXZ,float2,short2,float2,M,Functor,write>(a,b,c,x,y,z,w,x[0]->Volume());
65 #else
66  errorQuda("blas has not been built for Nspin=%d fields", x[0]->Nspin());
67 #endif
68  } else { errorQuda("nSpin=%d is not supported\n", x[0]->Nspin()); }
69 
70  } else {
71  errorQuda("Precision combination x=%d y=%d not supported\n", x[0]->Precision(), y[0]->Precision());
72  }
73  } else { // fields on the cpu
74  // using namespace quda::colorspinor;
75  // if (x[0]->Precision() == QUDA_DOUBLE_PRECISION) {
76  // Functor<NXZ, NYW, double2, double2> f(a, b, c);
77  // genericMultBlas<double, double, writeX, writeY, writeZ, writeW>(x, y, z, w, f);
78  // } else if (x[0]->Precision() == QUDA_SINGLE_PRECISION) {
79  // Functor<NXZ, NYW, float2, float2> f(a, make_float2(b.x,b.y), make_float2(c.x,c.y) );
80  // genericMultBlas<float, float, writeX, writeY, writeZ, writeW>(x, y, z, w, f);
81  // } else {
82  errorQuda("Not implemented");
83  // }
84  }
85 
86  }
87 
88 }
#define errorQuda(...)
Definition: util_quda.h:90
int Nspin
Definition: blas_test.cu:45
#define b
std::vector< ColorSpinorField * > CompositeColorSpinorField
int int int w
#define checkLocation(...)
const void * c
void multiblasCuda(const coeff_array< T > &a, const coeff_array< T > &b, const coeff_array< T > &c, CompositeColorSpinorField &x, CompositeColorSpinorField &y, CompositeColorSpinorField &z, CompositeColorSpinorField &w)
#define a