QUDA  1.0.0
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
generic_reduce.cuh
Go to the documentation of this file.
1 
4 template <typename ReduceType, typename Float, int writeX, int writeY, int writeZ, int writeW, int writeV,
5  typename SpinorX, typename SpinorY, typename SpinorZ, typename SpinorW, typename SpinorV, typename Reducer>
6 ReduceType genericReduce(SpinorX &X, SpinorY &Y, SpinorZ &Z, SpinorW &W, SpinorV &V, Reducer r)
7 {
8 
9  ReduceType sum;
10  ::quda::zero(sum);
11 
12  for (int parity = 0; parity < X.Nparity(); parity++) {
13  for (int x = 0; x < X.VolumeCB(); x++) {
14  r.pre();
15  for (int s = 0; s < X.Nspin(); s++) {
16  for (int c = 0; c < X.Ncolor(); c++) {
17  complex<Float> X_ = X(parity, x, s, c);
18  complex<Float> Y_ = Y(parity, x, s, c);
19  complex<Float> Z_ = Z(parity, x, s, c);
20  complex<Float> W_ = W(parity, x, s, c);
21  complex<Float> V_ = V(parity, x, s, c);
22  r(sum, X_, Y_, Z_, W_, V_);
23  if (writeX) X(parity, x, s, c) = X_;
24  if (writeY) Y(parity, x, s, c) = Y_;
25  if (writeZ) Z(parity, x, s, c) = Z_;
26  if (writeW) W(parity, x, s, c) = W_;
27  if (writeV) V(parity, x, s, c) = V_;
28  }
29  }
30  r.post(sum);
31  }
32  }
33 
34  return sum;
35 }
36 
37 template <typename ReduceType, typename Float, typename zFloat, int nSpin, int nColor, QudaFieldOrder order, int writeX,
38  int writeY, int writeZ, int writeW, int writeV, typename R>
39 ReduceType genericReduce(
40  ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, ColorSpinorField &w, ColorSpinorField &v, R r)
41 {
42  colorspinor::FieldOrderCB<Float, nSpin, nColor, 1, order> X(x), Y(y), W(w), V(v);
43  colorspinor::FieldOrderCB<zFloat, nSpin, nColor, 1, order> Z(z);
44  return genericReduce<ReduceType, zFloat, writeX, writeY, writeZ, writeW, writeV>(X, Y, Z, W, V, r);
45 }
46 
47 template <typename ReduceType, typename Float, typename zFloat, int nSpin, QudaFieldOrder order, int writeX, int writeY,
48  int writeZ, int writeW, int writeV, typename R>
49 ReduceType genericReduce(
50  ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, ColorSpinorField &w, ColorSpinorField &v, R r)
51 {
52  ReduceType value;
53  if (x.Ncolor() == 3) {
54  value = genericReduce<ReduceType, Float, zFloat, nSpin, 3, order, writeX, writeY, writeZ, writeW, writeV, R>(
55  x, y, z, w, v, r);
56  } else if (x.Ncolor() == 4) {
57  value = genericReduce<ReduceType, Float, zFloat, nSpin, 4, order, writeX, writeY, writeZ, writeW, writeV, R>(
58  x, y, z, w, v, r);
59  } else if (x.Ncolor() == 6) { // free field Wilson
60  value = genericReduce<ReduceType, Float, zFloat, nSpin, 6, order, writeX, writeY, writeZ, writeW, writeV, R>(
61  x, y, z, w, v, r);
62  } else if (x.Ncolor() == 8) {
63  value = genericReduce<ReduceType, Float, zFloat, nSpin, 8, order, writeX, writeY, writeZ, writeW, writeV, R>(
64  x, y, z, w, v, r);
65  } else if (x.Ncolor() == 12) {
66  value = genericReduce<ReduceType, Float, zFloat, nSpin, 12, order, writeX, writeY, writeZ, writeW, writeV, R>(
67  x, y, z, w, v, r);
68  } else if (x.Ncolor() == 16) {
69  value = genericReduce<ReduceType, Float, zFloat, nSpin, 16, order, writeX, writeY, writeZ, writeW, writeV, R>(
70  x, y, z, w, v, r);
71  } else if (x.Ncolor() == 20) {
72  value = genericReduce<ReduceType, Float, zFloat, nSpin, 20, order, writeX, writeY, writeZ, writeW, writeV, R>(
73  x, y, z, w, v, r);
74  } else if (x.Ncolor() == 24) {
75  value = genericReduce<ReduceType, Float, zFloat, nSpin, 24, order, writeX, writeY, writeZ, writeW, writeV, R>(
76  x, y, z, w, v, r);
77  } else if (x.Ncolor() == 32) {
78  value = genericReduce<ReduceType, Float, zFloat, nSpin, 32, order, writeX, writeY, writeZ, writeW, writeV, R>(
79  x, y, z, w, v, r);
80  } else if (x.Ncolor() == 72) {
81  value = genericReduce<ReduceType, Float, zFloat, nSpin, 72, order, writeX, writeY, writeZ, writeW, writeV, R>(
82  x, y, z, w, v, r);
83  } else if (x.Ncolor() == 576) {
84  value = genericReduce<ReduceType, Float, zFloat, nSpin, 576, order, writeX, writeY, writeZ, writeW, writeV, R>(
85  x, y, z, w, v, r);
86  } else {
87  ::quda::zero(value);
88  errorQuda("nColor = %d not implemented", x.Ncolor());
89  }
90  return value;
91 }
92 
93 template <typename ReduceType, typename Float, typename zFloat, QudaFieldOrder order, int writeX, int writeY,
94  int writeZ, int writeW, int writeV, typename R>
95 ReduceType genericReduce(
96  ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, ColorSpinorField &w, ColorSpinorField &v, R r)
97 {
98  ReduceType value;
99  ::quda::zero(value);
100  if (x.Nspin() == 4) {
101  value = genericReduce<ReduceType, Float, zFloat, 4, order, writeX, writeY, writeZ, writeW, writeV, R>(
102  x, y, z, w, v, r);
103  } else if (x.Nspin() == 2) {
104  value = genericReduce<ReduceType, Float, zFloat, 2, order, writeX, writeY, writeZ, writeW, writeV, R>(
105  x, y, z, w, v, r);
106 #ifdef GPU_STAGGERED_DIRAC
107  } else if (x.Nspin() == 1) {
108  value = genericReduce<ReduceType, Float, zFloat, 1, order, writeX, writeY, writeZ, writeW, writeV, R>(
109  x, y, z, w, v, r);
110 #endif
111  } else {
112  errorQuda("nSpin = %d not implemented", x.Nspin());
113  }
114  return value;
115 }
116 
117 template <typename doubleN, typename ReduceType, typename Float, typename zFloat, int writeX, int writeY, int writeZ,
118  int writeW, int writeV, typename R>
120  ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, ColorSpinorField &w, ColorSpinorField &v, R r)
121 {
122  ReduceType value;
123  ::quda::zero(value);
124  if (x.FieldOrder() == QUDA_SPACE_SPIN_COLOR_FIELD_ORDER) {
125  value = genericReduce<ReduceType, Float, zFloat, QUDA_SPACE_SPIN_COLOR_FIELD_ORDER, writeX, writeY, writeZ, writeW,
126  writeV, R>(x, y, z, w, v, r);
127  } else {
128  warningQuda("CPU reductions not implemented for %d field order", x.FieldOrder());
129  }
130  return set(value);
131 }
static void sum(Float *dst, Float *a, Float *b, int cnt)
Definition: dslash_util.h:8
int Z[4]
Definition: test_util.cpp:26
#define errorQuda(...)
Definition: util_quda.h:121
enum QudaFieldOrder_s QudaFieldOrder
static int R[4]
const int nColor
Definition: covdev_test.cpp:75
#define warningQuda(...)
Definition: util_quda.h:133
int X[4]
Definition: covdev_test.cpp:70
void zero(ColorSpinorField &a)
Definition: blas_quda.cu:472
int V
Definition: test_util.cpp:27
__shared__ float s[]
ReduceType genericReduce(SpinorX &X, SpinorY &Y, SpinorZ &Z, SpinorW &W, SpinorV &V, Reducer r)
QudaParity parity
Definition: covdev_test.cpp:54