6 template<
int NXZ,
typename doubleN,
typename ReduceType,
7 template <
int MXZ,
typename ReducerType,
typename Float,
typename FloatN>
class Reducer,
typename write,
bool siteUnroll,
typename T>
8 void multiReduceCuda(doubleN result[],
const reduce::coeff_array<T> &
a,
const reduce::coeff_array<T> &
b,
const reduce::coeff_array<T> &
c,
11 const int NYW =
y.size();
13 assert(siteUnroll==
true);
14 int reduce_length = siteUnroll ?
x[0]->RealLength() :
x[0]->Length();
19 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC) 21 multiReduceCuda<doubleN,ReduceType,double2,float4,double2,M,NXZ,Reducer,write>
22 (result,
a,
b,
c,
x,
y,
z,
w, reduce_length/(2*M));
24 errorQuda(
"blas has not been built for Nspin=%d fields",
x[0]->
Nspin());
26 }
else if (
x[0]->
Nspin() == 1) {
27 #ifdef GPU_STAGGERED_DIRAC 29 multiReduceCuda<doubleN,ReduceType,double2,float2,double2,M,NXZ,Reducer,write>
30 (result,
a,
b,
c,
x,
y,
z,
w, reduce_length/(2*M));
39 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC) 41 multiReduceCuda<doubleN,ReduceType,double2,short4,double2,M,NXZ,Reducer,write>
42 (result,
a,
b,
c,
x,
y,
z,
w, reduce_length/(4*M));
44 errorQuda(
"blas has not been built for Nspin=%d fields",
x[0]->
Nspin());
47 #if defined(GPU_STAGGERED_DIRAC) 49 multiReduceCuda<doubleN,ReduceType,double2,short2,double2,M,NXZ,Reducer,write>
50 (result,
a,
b,
c,
x,
y,
z,
w, reduce_length/(2*M));
52 errorQuda(
"blas has not been built for Nspin=%d fields",
x[0]->
Nspin());
59 #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC) 61 multiReduceCuda<doubleN,ReduceType,float4,short4,float4,M,NXZ,Reducer,write>
62 (result,
a,
b,
c,
x,
y,
z,
w,
x[0]->Volume());
64 errorQuda(
"blas has not been built for Nspin=%d fields",
x[0]->
Nspin());
66 }
else if(
x[0]->
Nspin() == 1) {
67 #ifdef GPU_STAGGERED_DIRAC 69 multiReduceCuda<doubleN,ReduceType,float2,short2,float2,M,NXZ,Reducer,write>
70 (result,
a,
b,
c,
x,
y,
z,
w,
x[0]->Volume());
72 errorQuda(
"blas has not been built for Nspin=%d fields",
x[0]->
Nspin());
77 errorQuda(
"Precision combination x=%d y=%d not supported\n",
x[0]->Precision(),
y[0]->Precision());
void multiReduceCuda(doubleN result[], const reduce::coeff_array< T > &a, const reduce::coeff_array< T > &b, const reduce::coeff_array< T > &c, CompositeColorSpinorField &x, CompositeColorSpinorField &y, CompositeColorSpinorField &z, CompositeColorSpinorField &w)
std::vector< ColorSpinorField * > CompositeColorSpinorField