QUDA  0.9.0
reduce_mixed_core.h
Go to the documentation of this file.
namespace mixed {

/*
  Vector type and unroll factor M used for each field precision:

  Wilson
    double  double2  M = 1/12
    single  float4   M = 1/6
    half    short4   M = 6/6

  Staggered
    double  double2  M = 1/3
    single  float2   M = 1/3
    half    short2   M = 3/3
*/

/**
   Mixed-precision reduction driver.  Dispatches on the precision pair
   (x, z) and the spin content of the fields to the matching
   reduce::reduceCuda template instantiation, then reduces the partial
   result across processes.

   NOTE(review): if the fields live on the GPU but none of the three
   supported precision combinations matches, `value` is returned
   uninitialized (no error is raised) — preserved from the original.

   @param a Scalar coefficient forwarded to the reducer
   @param b Scalar coefficient forwarded to the reducer
   @param x Field operand (low-precision side of the mixed pair)
   @param y Field operand
   @param z Field operand (high-precision side of the mixed pair)
   @param w Field operand
   @param v Field operand
   @return The globally reduced value
 */
template <typename doubleN, typename ReduceType,
          template <typename ReducerType, typename Float, typename FloatN> class Reducer,
          int writeX, int writeY, int writeZ, int writeW, int writeV, bool siteUnroll>
doubleN reduceCuda(const double2 &a, const double2 &b, ColorSpinorField &x,
                   ColorSpinorField &y, ColorSpinorField &z, ColorSpinorField &w,
                   ColorSpinorField &v) {

  doubleN value;

  if (checkLocation(x, y, z, w, v) == QUDA_CUDA_FIELD_LOCATION) {

    // cannot do site unrolling for arbitrary color (needs JIT)
    if (x.Ncolor() != 3) errorQuda("Not supported");

    if (x.Precision() == QUDA_SINGLE_PRECISION && z.Precision() == QUDA_DOUBLE_PRECISION) {

      // single-precision x against double-precision z
      if (x.Nspin() == 4) { // Wilson
#if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC)
        const int M = 12; // work per thread
        value = reduce::reduceCuda<doubleN,ReduceType,double2,float4,double2,M,Reducer,
                                   writeX,writeY,writeZ,writeW,writeV>
          (a, b, x, y, z, w, v, x.Volume());
#else
        errorQuda("blas has not been built for Nspin=%d fields", x.Nspin());
#endif
      } else if (x.Nspin() == 1) { // staggered
#ifdef GPU_STAGGERED_DIRAC
        // staggered fields optionally process a full site (3 complex
        // components) per thread; otherwise one component per thread
        const int M = siteUnroll ? 3 : 1;
        const int reduce_length = siteUnroll ? x.RealLength() : x.Length();
        value = reduce::reduceCuda<doubleN,ReduceType,double2,float2,double2,M,Reducer,
                                   writeX,writeY,writeZ,writeW,writeV>
          (a, b, x, y, z, w, v, reduce_length / (2 * M));
#else
        errorQuda("blas has not been built for Nspin=%d fields", x.Nspin());
#endif
      } else {
        errorQuda("ERROR: nSpin=%d is not supported\n", x.Nspin());
      }

    } else if (x.Precision() == QUDA_HALF_PRECISION && z.Precision() == QUDA_DOUBLE_PRECISION) {

      // half-precision x against double-precision z
      if (x.Nspin() == 4) { // Wilson
#if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC)
        const int M = 12; // work per thread
        value = reduce::reduceCuda<doubleN,ReduceType,double2,short4,double2,M,Reducer,
                                   writeX,writeY,writeZ,writeW,writeV>
          (a, b, x, y, z, w, v, x.Volume());
#else
        errorQuda("blas has not been built for Nspin=%d fields", x.Nspin());
#endif
      } else if (x.Nspin() == 1) { // staggered
#ifdef GPU_STAGGERED_DIRAC
        // half precision always unrolls the full site
        const int M = 3;
        value = reduce::reduceCuda<doubleN,ReduceType,double2,short2,double2,M,Reducer,
                                   writeX,writeY,writeZ,writeW,writeV>
          (a, b, x, y, z, w, v, x.Volume());
#else
        errorQuda("blas has not been built for Nspin=%d fields", x.Nspin());
#endif
      } else {
        errorQuda("ERROR: nSpin=%d is not supported\n", x.Nspin());
      }

    } else if (z.Precision() == QUDA_SINGLE_PRECISION) {

      // half-precision x against single-precision z
      if (x.Nspin() == 4) { // Wilson
#if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC)
        const int M = 6;
        value = reduce::reduceCuda<doubleN,ReduceType,float4,short4,float4,M,Reducer,
                                   writeX,writeY,writeZ,writeW,writeV>
          (a, b, x, y, z, w, v, x.Volume());
#else
        errorQuda("blas has not been built for Nspin=%d fields", x.Nspin());
#endif
      } else if (x.Nspin() == 1) { // staggered
#ifdef GPU_STAGGERED_DIRAC
        const int M = 3;
        value = reduce::reduceCuda<doubleN,ReduceType,float2,short2,float2,M,Reducer,
                                   writeX,writeY,writeZ,writeW,writeV>
          (a, b, x, y, z, w, v, x.Volume());
#else
        errorQuda("blas has not been built for Nspin=%d fields", x.Nspin());
#endif
      } else {
        errorQuda("ERROR: nSpin=%d is not supported\n", x.Nspin());
      }

      // account for the extra traffic of this mixed-precision combination
      blas::bytes += Reducer<ReduceType,double2,double2>::streams()*(unsigned long long)x.Volume()*sizeof(float);
    }

  } else {

    // CPU fallback: no quad-precision support on the host path, so the
    // accumulator type is doubleN rather than ReduceType
    if (x.Precision() == QUDA_SINGLE_PRECISION && z.Precision() == QUDA_DOUBLE_PRECISION) {
      Reducer<doubleN, double2, double2> r(a, b);
      value = genericReduce<doubleN,doubleN,float,double,writeX,writeY,writeZ,writeW,writeV,Reducer<doubleN,double2,double2> >(x,y,z,w,v,r);
    } else {
      errorQuda("Precision %d not implemented", x.Precision());
    }
  }

  // combine the partial results from all processes
  const int Nreduce = sizeof(doubleN) / sizeof(double);
  reduceDoubleArray((double*)&value, Nreduce);

  return value;
}

} // namespace mixed
Cross-referenced symbols (doxygen index):
- #define errorQuda(...) — defined in util_quda.h:90
- cudaStream_t *streams
- void reduceDoubleArray(double *, const int len)
- #define b
- int int int w
- #define checkLocation(...)
- doubleN reduceCuda(const double2 &a, const double2 &b, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, ColorSpinorField &w, ColorSpinorField &v)
- #define a
- unsigned long long bytes — defined in blas_quda.cu:43