QUDA  0.9.0
reduce_core.h
Go to the documentation of this file.
/*
  Per-thread work factor M for each precision/field type, written as
  "M without siteUnroll / M with siteUnroll" (half precision always
  unrolls a full site):

  Wilson
    double  double2  M = 1/12
    single  float4   M = 1/6
    half    short4   M = 6/6

  Staggered
    double  double2  M = 1/3
    single  float2   M = 1/3
    half    short2   M = 3/3
*/
12 
/**
 * Host-side dispatch for a BLAS-style reduction over up to five
 * color-spinor fields.
 *
 * Selects the concrete GPU kernel instantiation (storage/register
 * vector types and per-thread work factor M) from the fields'
 * precision and spin, or falls back to a generic CPU reduction when
 * the fields live in host memory. The partial result is then combined
 * across processes with reduceDoubleArray().
 *
 * @tparam doubleN     Return type of the reduction (double, double2, ...).
 * @tparam ReduceType  Device-side accumulation type.
 * @tparam Reducer     Functor template implementing the reduction operation.
 * @tparam writeX..writeV  Whether the kernel writes back each field.
 * @tparam siteUnroll  Whether a thread processes a full lattice site
 *                     (requires Ncolor == 3; not supported for nSpin == 2).
 *
 * @param a, b  Scalar coefficients forwarded to the reducer.
 * @param x, y, z, w, v  Input (and possibly output) fields.
 * @return The globally reduced value.
 */
template <typename doubleN, typename ReduceType,
          template <typename ReducerType, typename Float, typename FloatN> class Reducer,
          int writeX, int writeY, int writeZ, int writeW, int writeV, bool siteUnroll>
doubleN reduceCuda(const double2 &a, const double2 &b, ColorSpinorField &x,
                   ColorSpinorField &y, ColorSpinorField &z, ColorSpinorField &w,
                   ColorSpinorField &v) {

  doubleN value;
  if (checkLocation(x, y, z, w, v) == QUDA_CUDA_FIELD_LOCATION) {

    // cannot do site unrolling for arbitrary color (needs JIT)
    if (siteUnroll && x.Ncolor()!=3) errorQuda("Not supported");

    // number of real elements each thread-grid pass must cover
    int reduce_length = siteUnroll ? x.RealLength() : x.Length();

    if (x.Precision() == QUDA_DOUBLE_PRECISION) {
      if (x.Nspin() == 4 || x.Nspin() == 2) { //wilson
#if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC) || defined(GPU_MULTIGRID)
        const int M = siteUnroll ? 12 : 1; // determines how much work per thread to do
        if (x.Nspin() == 2 && siteUnroll) errorQuda("siteUnroll not supported for nSpin==2");
        // double2 carries 2 reals, so each thread consumes 2*M reals
        value = reduceCuda<doubleN,ReduceType,double2,double2,double2,M,Reducer,
                           writeX,writeY,writeZ,writeW,writeV>
          (a, b, x, y, z, w, v, reduce_length/(2*M));
#else
        errorQuda("blas has not been built for Nspin=%d fields", x.Nspin());
#endif
      } else if (x.Nspin() == 1) { //staggered
#ifdef GPU_STAGGERED_DIRAC
        const int M = siteUnroll ? 3 : 1; // determines how much work per thread to do
        value = reduceCuda<doubleN,ReduceType,double2,double2,double2,M,Reducer,
                           writeX,writeY,writeZ,writeW,writeV>
          (a, b, x, y, z, w, v, reduce_length/(2*M));
#else
        errorQuda("blas has not been built for Nspin=%d fields", x.Nspin());
#endif
      } else { errorQuda("ERROR: nSpin=%d is not supported\n", x.Nspin()); }
    } else if (x.Precision() == QUDA_SINGLE_PRECISION) {
      if (x.Nspin() == 4) { //wilson
#if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC)
        const int M = siteUnroll ? 6 : 1; // determines how much work per thread to do
        // float4 carries 4 reals, so each thread consumes 4*M reals
        value = reduceCuda<doubleN,ReduceType,float4,float4,float4,M,Reducer,
                           writeX,writeY,writeZ,writeW,writeV>
          (a, b, x, y, z, w, v, reduce_length/(4*M));
#else
        errorQuda("blas has not been built for Nspin=%d fields", x.Nspin());
#endif
      } else if (x.Nspin() == 1 || x.Nspin() == 2) { //staggered or coarse
#if defined(GPU_STAGGERED_DIRAC) || defined(GPU_MULTIGRID)
        const int M = siteUnroll ? 3 : 1; // determines how much work per thread to do
        if (x.Nspin() == 2 && siteUnroll) errorQuda("siteUnroll not supported for nSpin==2");
        value = reduceCuda<doubleN,ReduceType,float2,float2,float2,M,Reducer,
                           writeX,writeY,writeZ,writeW,writeV>
          (a, b, x, y, z, w, v, reduce_length/(2*M));
#else
        errorQuda("blas has not been built for Nspin=%d fields", x.Nspin());
#endif
      } else { errorQuda("ERROR: nSpin=%d is not supported\n", x.Nspin()); }
    } else if (x.Precision() == QUDA_HALF_PRECISION) { // half precision
      // half precision always processes a full site per thread, so the
      // grid length is the field volume rather than reduce_length
      if (x.Nspin() == 4) { //wilson
#if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC)
        const int M = 6; // determines how much work per thread to do
        value = reduceCuda<doubleN,ReduceType,float4,short4,short4,M,Reducer,
                           writeX,writeY,writeZ,writeW,writeV>
          (a, b, x, y, z, w, v, y.Volume());
#else
        errorQuda("blas has not been built for Nspin=%d fields", x.Nspin());
#endif
      } else if (x.Nspin() == 1) { //staggered
#ifdef GPU_STAGGERED_DIRAC
        const int M = 3; // determines how much work per thread to do
        value = reduceCuda<doubleN,ReduceType,float2,short2,short2,M,Reducer,
                           writeX,writeY,writeZ,writeW,writeV>
          (a, b, x, y, z, w, v, y.Volume());
#else
        errorQuda("blas has not been built for Nspin=%d fields", x.Nspin());
#endif
      } else { errorQuda("nSpin=%d is not supported\n", x.Nspin()); }
    } else {
      errorQuda("precision=%d is not supported\n", x.Precision());
    }
  } else { // fields are on the CPU
    // we don't have quad precision support on the GPU so use doubleN instead of ReduceType
    if (x.Precision() == QUDA_DOUBLE_PRECISION) {
      Reducer<doubleN, double2, double2> r(a, b);
      value = genericReduce<doubleN,doubleN,double,double,writeX,writeY,writeZ,writeW,writeV,Reducer<doubleN,double2,double2> >(x,y,z,w,v,r);
    } else if (x.Precision() == QUDA_SINGLE_PRECISION) {
      // narrow the double2 coefficients to float2 for the single-precision reducer
      Reducer<doubleN, float2, float2> r(make_float2(a.x, a.y), make_float2(b.x, b.y));
      value = genericReduce<doubleN,doubleN,float,float,writeX,writeY,writeZ,writeW,writeV,Reducer<doubleN,float2,float2> >(x,y,z,w,v,r);
    } else {
      errorQuda("Precision %d not implemented", x.Precision());
    }
  }

  // combine the local partial result across processes (MPI reduction)
  const int Nreduce = sizeof(doubleN) / sizeof(double);
  reduceDoubleArray((double*)&value, Nreduce);

  return value;
}
#define errorQuda(...)
Definition: util_quda.h:90
void reduceDoubleArray(double *, const int len)
#define b
int int int w
#define checkLocation(...)
doubleN reduceCuda(const double2 &a, const double2 &b, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, ColorSpinorField &w, ColorSpinorField &v)
Definition: reduce_core.h:21
#define a