QUDA  0.9.0
blas_mixed_core.h
Go to the documentation of this file.
1 namespace mixed {
2 
6  template <template <typename Float, typename FloatN> class Functor,
7  int writeX, int writeY, int writeZ, int writeW>
8  void blasCuda(const double2 &a, const double2 &b, const double2 &c,
9  ColorSpinorField &x, ColorSpinorField &y,
10  ColorSpinorField &z, ColorSpinorField &w) {
11 
13  if (x.Precision() == QUDA_SINGLE_PRECISION && y.Precision() == QUDA_DOUBLE_PRECISION) {
14  if (x.Nspin() == 4) {
15  const int M = 12;
16  blas::blasCuda<double2,float4,double2,M,Functor,writeX,writeY,writeZ,writeW>(a,b,c,x,y,z,w,x.Volume());
17  } else if (x.Nspin() == 1) {
18  const int M = 3;
19  blas::blasCuda<double2,float2,double2,M,Functor,writeX,writeY,writeZ,writeW>(a,b,c,x,y,z,w,x.Volume());
20  }
21  } else if (x.Precision() == QUDA_HALF_PRECISION && y.Precision() == QUDA_DOUBLE_PRECISION) {
22  if (x.Nspin() == 4) {
23  const int M = 12;
24  blas::blasCuda<double2,short4,double2,M,Functor,writeX,writeY,writeZ,writeW>(a,b,c,x,y,z,w,x.Volume());
25  } else if (x.Nspin() == 1) {
26  const int M = 3;
27  blas::blasCuda<double2,short2,double2,M,Functor,writeX,writeY,writeZ,writeW>(a,b,c,x,y,z,w,x.Volume());
28  }
29  } else if (x.Precision() == QUDA_HALF_PRECISION && y.Precision() == QUDA_SINGLE_PRECISION) {
30  if (x.Nspin() == 4) {
31  const int M = 6;
32  blas::blasCuda<float4,short4,float4,M,Functor,writeX,writeY,writeZ,writeW>(a,b,c,x,y,z,w,x.Volume());
33  } else if (x.Nspin() == 1) {
34  const int M = 3;
35  blas::blasCuda<float2,short2,float2,M,Functor,writeX,writeY,writeZ,writeW>(a,b,c,x,y,z,w,x.Volume());
36  }
37  } else {
38  errorQuda("Not implemented for this precision combination");
39  }
40  } else { // fields on the cpu
41  using namespace quda::colorspinor;
42  if (x.Precision() == QUDA_SINGLE_PRECISION && y.Precision() == QUDA_DOUBLE_PRECISION) {
43  Functor<double2, double2> f(a, b, c);
44  genericBlas<float, double, writeX, writeY, writeZ, writeW>(x, y, z, w, f);
45  } else {
46  errorQuda("Not implemented");
47  }
48  }
49  }
50 
51 }
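Cross-references generated for the listing above:
- errorQuda(...): macro (#define errorQuda(...)). Definition: util_quda.h:90
- blasCuda: void blasCuda(const double2 &a, const double2 &b, const double2 &c, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, ColorSpinorField &w)
- b: auto-linked to an unrelated "#define b"; in this listing b is a parameter of blasCuda
- w: auto-linked to an unrelated "int int int w"; in this listing w is a parameter of blasCuda
- checkLocation(...): macro (#define checkLocation(...))
- f: auto-linked to an unrelated "int int int enum cudaChannelFormatKind f"; in this listing f is the local Functor instance
- c: auto-linked to an unrelated "const void * c"; in this listing c is a parameter of blasCuda
- a: auto-linked to an unrelated "#define a"; in this listing a is a parameter of blasCuda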