4 #include <staggered_oprod.h>
6 #include <gauge_field_order.h>
7 #include <color_spinor_field_order.h>
8 #include <quda_matrix.h>
9 #include <index_helper.cuh>
// Selects which phase of the staggered outer-product computation a launch
// performs: the bulk (interior) lattice sites, or the boundary (exterior)
// sites whose neighbors live in ghost zones received from other ranks.
13 enum OprodKernelType { OPROD_INTERIOR_KERNEL, OPROD_EXTERIOR_KERNEL };
// Kernel argument struct for the staggered outer-product kernels.
// Bundles the two input color-spinor fields, the two output gauge fields
// (one-hop and three-hop), and launch/geometry metadata.
// NOTE(review): several member declarations (parity, dir, displacement,
// length, X, partitioned, coeff, ...) are not visible in this chunk.
15 template <typename Float, int nColor_> struct StaggeredOprodArg {
16 typedef typename mapper<Float>::type real;
17 static constexpr int nColor = nColor_;
18 static constexpr int nSpin = 1;  // staggered fermions: single spin component
// Field-order accessor types: F for the color-spinor inputs, GU/GL for the
// gauge outputs; both gauge accessors use no reconstruction (full 18 reals).
19 using F = typename colorspinor_mapper<Float, nSpin, nColor>::type;
20 using GU = typename gauge_mapper<Float, QUDA_RECONSTRUCT_NO, 18>::type;
21 using GL = typename gauge_mapper<Float, QUDA_RECONSTRUCT_NO, 18>::type;
23 const F inA; /** input vector field */
24 const F inB; /** input vector field */
25 GU U; /** output one-hop field */
26 GL L; /** output three-hop field */
// Which phase (interior/exterior) this argument set is configured for;
// mutated on the host between launches (see the driver below).
32 OprodKernelType kernelType;
// Constructor: wraps the fields in their accessors and caches geometry.
38 StaggeredOprodArg(GaugeField &U, GaugeField &L, const ColorSpinorField &inA, const ColorSpinorField &inB,
39 int parity, int dir, int displacement, const OprodKernelType &kernelType, int nFace, const double coeff[2]) :
47 displacement(displacement),
48 kernelType(kernelType),
// coeff[0] scales the one-hop term, coeff[1] the three-hop term
// (see their use in the interior kernel).
51 this->coeff[0] = coeff[0];
52 this->coeff[1] = coeff[1];
// Local lattice dimensions taken from the output gauge field.
53 for (int i = 0; i < 4; ++i) this->X[i] = U.X()[i];
// Record which dimensions are split across ranks, so the kernels can
// skip neighbors that live off-node.
54 for (int i = 0; i < 4; ++i) this->partitioned[i] = commDimPartitioned(i) ? true : false;
// Interior kernel: for every local checkerboard site, accumulates the
// outer product of a neighboring inB spinor with the local inA spinor into
// the one-hop field U (first neighbor, coeff[0]) and the three-hop field L
// (third neighbor, coeff[1]), for each of the 4 dimensions.
// Uses a grid-stride loop over arg.length sites, so any grid size is valid.
// NOTE(review): the `result` declaration and the shift[] updates between
// the two neighborIndex calls are elided from this view.
58 template<typename real, typename Arg> __global__ void interiorOprodKernel(Arg arg)
60 using complex = complex<real>;
61 using matrix = Matrix<complex, Arg::nColor>;
62 using vector = ColorSpinor<real, Arg::nColor, 1>;
64 unsigned int idx = blockIdx.x*blockDim.x + threadIdx.x;
65 const unsigned int gridSize = gridDim.x*blockDim.x;
69 while (idx < arg.length) {
70 const vector x = arg.inA(idx, 0);  // local spinor at this site
73 for (int dim=0; dim<4; ++dim) {
74 int shift[4] = {0,0,0,0};
// neighborIndex returns a negative index when the neighbor is not locally
// available (presumably an off-node site handled by the exterior kernel —
// TODO confirm against index_helper.cuh); such sites are skipped here.
76 const int first_nbr_idx = neighborIndex(idx, shift, arg.partitioned, arg.parity, arg.X);
77 if (first_nbr_idx >= 0) {
78 const vector y = arg.inB(first_nbr_idx, 0);
// U(dim, idx) += coeff[0] * (y ⊗ x^dag): read-modify-write accumulation.
79 result = outerProduct(y, x);
80 matrix tempA = arg.U(dim, idx, arg.parity);
81 result = tempA + result*arg.coeff[0];
83 arg.U(dim, idx, arg.parity) = result;
// Same pattern for the third neighbor, accumulated into L with coeff[1].
87 const int third_nbr_idx = neighborIndex(idx, shift, arg.partitioned, arg.parity, arg.X);
88 if (third_nbr_idx >= 0) {
89 const vector z = arg.inB(third_nbr_idx, 0);
90 result = outerProduct(z, x);
91 matrix tempB = arg.L(dim, idx, arg.parity);
92 result = tempB + result*arg.coeff[1];
93 arg.L(dim, idx, arg.parity) = result;
101 } // interiorOprodKernel
// Exterior kernel: completes the accumulation for boundary sites whose
// neighbor spinor lives in the ghost zone of dimension arg.dir.
// The output field and coefficient are chosen by arg.displacement:
// displacement 1 -> U with coeff[0], otherwise -> L with coeff[1].
// Grid-stride loop over the arg.length ghost-face sites.
// NOTE(review): the declarations of `x` and `result` are elided from this
// view; the template parameter `dim` is not referenced in the visible body
// (arg.dir is used instead) — confirm against the full file.
103 template<int dim, typename real, typename Arg> __global__ void exteriorOprodKernel(Arg arg)
105 using complex = complex<real>;
106 using matrix = Matrix<complex, Arg::nColor>;
107 using vector = ColorSpinor<real, Arg::nColor, 1>;
109 unsigned int cb_idx = blockIdx.x*blockDim.x + threadIdx.x;
110 const unsigned int gridSize = gridDim.x*blockDim.x;
// Select one-hop vs three-hop output once, outside the site loop.
114 auto &out = (arg.displacement == 1) ? arg.U : arg.L;
115 real coeff = (arg.displacement == 1) ? arg.coeff[0] : arg.coeff[1];
118 while (cb_idx < arg.length) {
// Map the face index to full-lattice coordinates, then to the bulk
// checkerboard index of the site being updated.
119 coordsFromIndexExterior(x, cb_idx, arg.X, arg.dir, arg.displacement, arg.parity);
120 const unsigned int bulk_cb_idx = ((((x[3]*arg.X[2] + x[2])*arg.X[1] + x[1])*arg.X[0] + x[0]) >> 1);
122 matrix inmatrix = out(arg.dir, bulk_cb_idx, arg.parity);
123 const vector a = arg.inA(bulk_cb_idx, 0);
// Neighbor spinor comes from the ghost buffer of inB in direction arg.dir.
124 const vector b = arg.inB.Ghost(arg.dir, 1, cb_idx, 0);
// out(dir, site) += coeff * (b ⊗ a^dag)
126 result = outerProduct(b, a);
127 result = inmatrix + result*coeff;
128 out(arg.dir, bulk_cb_idx, arg.parity) = result;
134 template<typename Float, typename Arg>
135 class StaggeredOprodField : public Tunable {
138 const GaugeField &meta;
140 unsigned int sharedBytesPerThread() const { return 0; }
141 unsigned int sharedBytesPerBlock(const TuneParam &) const { return 0; }
143 unsigned int minThreads() const { return arg.U.volumeCB; }
144 bool tunedGridDim() const { return false; }
147 StaggeredOprodField(Arg &arg, const GaugeField &meta)
148 : arg(arg), meta(meta) {
149 writeAuxString(meta.AuxString());
152 void apply(const qudaStream_t &stream) {
153 if (meta.Location() == QUDA_CUDA_FIELD_LOCATION) {
154 // Disable tuning for the time being
155 TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
156 if (arg.kernelType == OPROD_INTERIOR_KERNEL) {
157 qudaLaunchKernel(interiorOprodKernel<Float, Arg>, tp, stream, arg);
158 } else if (arg.kernelType == OPROD_EXTERIOR_KERNEL) {
159 if (arg.dir == 0) qudaLaunchKernel(exteriorOprodKernel<0,Float,Arg>, tp, stream, arg);
160 else if (arg.dir == 1) qudaLaunchKernel(exteriorOprodKernel<1,Float,Arg>, tp, stream, arg);
161 else if (arg.dir == 2) qudaLaunchKernel(exteriorOprodKernel<2,Float,Arg>, tp, stream, arg);
162 else if (arg.dir == 3) qudaLaunchKernel(exteriorOprodKernel<3,Float,Arg>, tp, stream, arg);
164 errorQuda("Kernel type not supported\n");
166 } else { // run the CPU code
167 errorQuda("No CPU support for staggered outer-product calculation\n");
171 void preTune() { arg.U.save(); arg.L.save(); }
172 void postTune() { arg.U.load(); arg.L.load(); }
174 long long flops() const { return 0; } // FIXME
175 long long bytes() const { return 0; } // FIXME
176 TuneKey tuneKey() const {
177 char aux[TuneKey::aux_n];
178 strcpy(aux, this->aux);
179 if (arg.kernelType == OPROD_EXTERIOR_KERNEL) {
180 strcat(aux, ",dir=");
182 u32toa(tmp, arg.dir);
184 strcat(aux, ",displacement=");
185 u32toa(tmp, arg.displacement);
188 return TuneKey(meta.VolString(), typeid(*this).name(), aux);
190 }; // StaggeredOprodField
// Precision-templated driver: runs the interior kernel over the full local
// volume, then for each partitioned dimension (highest first) runs the
// exterior kernel twice — once for the one-hop term and once for the
// three-hop term — reusing the same arg/launcher objects with updated state.
// NOTE(review): the apply() calls and the arg.dir update are elided from
// this view.
192 template <typename Float>
193 void computeStaggeredOprod(GaugeField &U, GaugeField &L, ColorSpinorField &inA, ColorSpinorField &inB, int parity, const double coeff[2], int nFace)
195 // Create the arguments for the interior kernel
196 StaggeredOprodArg<Float, 3> arg(U, L, inA, inB, parity, 0, 1, OPROD_INTERIOR_KERNEL, nFace, coeff);
197 StaggeredOprodField<Float, decltype(arg)> oprod(arg, U);
199 arg.kernelType = OPROD_INTERIOR_KERNEL;
200 arg.length = U.VolumeCB();  // one thread per interior checkerboard site
203 for (int i = 3; i >= 0; i--) {
204 if (commDimPartitioned(i)) {
205 // update parameters for this exterior kernel
206 arg.kernelType = OPROD_EXTERIOR_KERNEL;
209 // First, do the one hop term
211 arg.displacement = 1;
212 arg.length = inB.GhostFaceCB()[i];
216 // Now do the 3 hop term
218 arg.displacement = 3;
// Three-hop faces are `displacement` sites deep, hence the scaling.
219 arg.length = arg.displacement * inB.GhostFaceCB()[i];
225 } // computeStaggeredOprod
// Precision/parity dispatch front end: validates the output gauge orders,
// selects which input field acts as the "local" spinor (inA) vs the
// "neighbor" spinor (inB) from the requested parity, exchanges inB's ghost
// zones, and dispatches on the common precision of all four fields.
227 void computeStaggeredOprod(GaugeField &U, GaugeField &L, ColorSpinorField &inEven, ColorSpinorField &inOdd,
228 int parity, const double coeff[2], int nFace)
// Only FLOAT2-ordered gauge outputs are supported by the accessors above.
230 if (U.Order() != QUDA_FLOAT2_GAUGE_ORDER) errorQuda("Unsupported output ordering: %d\n", U.Order());
231 if (L.Order() != QUDA_FLOAT2_GAUGE_ORDER) errorQuda("Unsupported output ordering: %d\n", L.Order());
// For parity p, sites of parity p read neighbors of parity 1-p.
233 ColorSpinorField &inA = (parity & 1) ? inOdd : inEven;
234 ColorSpinorField &inB = (parity & 1) ? inEven : inOdd;
// Populate inB's ghost buffers so the exterior kernels can read them.
236 inB.exchangeGhost((QudaParity)(1-parity), nFace, 0);
238 auto prec = checkPrecision(inEven, inOdd, U, L);
239 if (prec == QUDA_DOUBLE_PRECISION) {
240 computeStaggeredOprod<double>(U, L, inA, inB, parity, coeff, nFace);
241 } else if (prec == QUDA_SINGLE_PRECISION) {
242 computeStaggeredOprod<float>(U, L, inA, inB, parity, coeff, nFace);
244 errorQuda("Unsupported precision: %d", prec);
// Flip the ghost buffer index back — presumably undoing the flip done by
// exchangeGhost so subsequent exchanges reuse the expected buffer; the
// explanatory comment is not visible in this chunk — TODO confirm.
247 inB.bufferIndex = (1 - inB.bufferIndex);
// Public entry point: computes the outer product for both parities.
// For the (elided) nFace==1 branch both parities accumulate into out[0],
// with the odd-parity coefficient negated; for nFace==3 the one-hop and
// three-hop terms go to out[0] and out[1] respectively.
// Compiled to an error stub unless GPU_STAGGERED_DIRAC is enabled.
// NOTE(review): the `if (nFace == 1)` line and the closing #endif are
// elided from this view.
250 void computeStaggeredOprod(GaugeField *out[], ColorSpinorField& in, const double coeff[], int nFace)
252 #ifdef GPU_STAGGERED_DIRAC
254 computeStaggeredOprod(*out[0], *out[0], in.Even(), in.Odd(), 0, coeff, nFace);
255 double coeff_[2] = {-coeff[0],0.0}; // need to multiply by -1 on odd sites
256 computeStaggeredOprod(*out[0], *out[0], in.Even(), in.Odd(), 1, coeff_, nFace);
257 } else if (nFace == 3) {
258 computeStaggeredOprod(*out[0], *out[1], in.Even(), in.Odd(), 0, coeff, nFace);
259 computeStaggeredOprod(*out[0], *out[1], in.Even(), in.Odd(), 1, coeff, nFace);
261 errorQuda("Invalid nFace=%d", nFace);
263 #else // GPU_STAGGERED_DIRAC not defined
264 errorQuda("Staggered Outer Product has not been built!");