18 template <
typename Float,
typename vFloat,
int fineSpin,
int fineColor,
33 const int *fine_to_coarse,
const int *coarse_to_fine,
int parity)
34 : out(out), in(in), V(V), fine_to_coarse(fine_to_coarse), coarse_to_fine(coarse_to_fine),
35 spin_map(), parity(parity), nParity(in.SiteSubset()), swizzle(1)
39 out(arg.out), in(arg.in), V(arg.V),
40 fine_to_coarse(arg.fine_to_coarse), coarse_to_fine(arg.coarse_to_fine), spin_map(),
41 parity(arg.parity), nParity(arg.nParity), swizzle(arg.swizzle)
48 template <
typename Float,
int fineSpin,
int fineColor,
int coarseColor,
int coarse_colors_per_thread,
49 class FineColor,
class Rotator>
51 const FineColor &
in,
const Rotator &
V,
52 int parity,
int nParity,
int x_cb,
int coarse_color_block) {
53 const int spinor_parity = (nParity == 2) ? parity : 0;
54 const int v_parity = (V.Nparity() == 2) ? parity : 0;
57 for (
int s=0;
s<fineSpin;
s++)
59 for (
int coarse_color_local=0; coarse_color_local<coarse_colors_per_thread; coarse_color_local++) {
60 out[
s*coarse_colors_per_thread+coarse_color_local] = 0.0;
64 for (
int coarse_color_local=0; coarse_color_local<coarse_colors_per_thread; coarse_color_local++) {
65 int i = coarse_color_block + coarse_color_local;
67 for (
int s=0;
s<fineSpin;
s++) {
69 constexpr
int color_unroll = fineColor == 3 ? 3 : 2;
71 complex<Float> partial[color_unroll];
73 for (
int k=0; k<color_unroll; k++) partial[k] = 0.0;
76 for (
int j=0; j<fineColor; j+=color_unroll) {
78 for (
int k=0; k<color_unroll; k++)
79 partial[k] +=
conj(
V(v_parity, x_cb,
s, j+k, i)) *
in(spinor_parity, x_cb,
s, j+k);
83 for (
int k=0; k<color_unroll; k++)
out[
s*coarse_colors_per_thread + coarse_color_local] += partial[k];
89 template <
typename Float,
int fineSpin,
int fineColor,
int coarseSpin,
int coarseColor,
int coarse_colors_per_thread,
typename Arg>
91 for (
int parity_coarse=0; parity_coarse<2; parity_coarse++)
92 for (
int x_coarse_cb=0; x_coarse_cb<arg.out.VolumeCB(); x_coarse_cb++)
93 for (
int s=0;
s<coarseSpin;
s++)
94 for (
int c=0; c<coarseColor; c++)
95 arg.out(parity_coarse, x_coarse_cb,
s, c) = 0.0;
101 for (
int x_cb=0; x_cb<arg.in.VolumeCB(); x_cb++) {
103 int x =
parity*arg.in.VolumeCB() + x_cb;
104 int x_coarse = arg.fine_to_coarse[x];
105 int parity_coarse = (x_coarse >= arg.out.VolumeCB()) ? 1 : 0;
106 int x_coarse_cb = x_coarse - parity_coarse*arg.out.VolumeCB();
108 for (
int coarse_color_block=0; coarse_color_block<coarseColor; coarse_color_block+=coarse_colors_per_thread) {
109 complex<Float>
tmp[fineSpin*coarse_colors_per_thread];
110 rotateCoarseColor<Float,fineSpin,fineColor,coarseColor,coarse_colors_per_thread>
113 for (
int s=0;
s<fineSpin;
s++) {
114 for (
int coarse_color_local=0; coarse_color_local<coarse_colors_per_thread; coarse_color_local++) {
115 int c = coarse_color_block + coarse_color_local;
116 arg.out(parity_coarse,x_coarse_cb,arg.spin_map(
s,
parity),c) += tmp[
s*coarse_colors_per_thread+coarse_color_local];
134 template <
int block_size,
typename Float,
int fineSpin,
int fineColor,
int coarseSpin,
135 int coarseColor,
int coarse_colors_per_thread,
typename Arg>
140 const int gridp = gridDim.x - gridDim.x % arg.swizzle;
142 int x_coarse = blockIdx.x;
143 if (blockIdx.x < gridp) {
145 const int i = blockIdx.x % arg.swizzle;
146 const int j = blockIdx.x / arg.swizzle;
149 x_coarse = i * (gridp / arg.swizzle) + j;
152 int x_coarse = blockIdx.x;
155 int parity_coarse = x_coarse >= arg.out.VolumeCB() ? 1 : 0;
156 int x_coarse_cb = x_coarse - parity_coarse*arg.out.VolumeCB();
167 int x_fine = arg.coarse_to_fine[ (x_coarse*2 +
parity) * blockDim.x + threadIdx.x];
168 int x_fine_cb = x_fine - parity*arg.in.VolumeCB();
170 int coarse_color_block = (blockDim.z*blockIdx.z + threadIdx.z) * coarse_colors_per_thread;
171 if (coarse_color_block >= coarseColor)
return;
173 complex<Float>
tmp[fineSpin*coarse_colors_per_thread];
174 rotateCoarseColor<Float,fineSpin,fineColor,coarseColor,coarse_colors_per_thread>
181 for (
int s=0;
s<fineSpin;
s++) {
182 for (
int v=0; v<coarse_colors_per_thread; v++) {
183 reduced[arg.spin_map(
s,parity)*coarse_colors_per_thread+v] += tmp[
s*coarse_colors_per_thread+v];
189 typedef cub::BlockReduce<vector, block_size, cub::BLOCK_REDUCE_WARP_REDUCTIONS, 2> BlockReduce;
190 __shared__
typename BlockReduce::TempStorage temp_storage;
193 reduced = BlockReduce(temp_storage).Sum(reduced);
195 typedef cub::BlockReduce<vector, block_size, cub::BLOCK_REDUCE_WARP_REDUCTIONS> BlockReduce;
196 __shared__
typename BlockReduce::TempStorage temp_storage;
199 reduced = BlockReduce(temp_storage).Sum(reduced);
202 if (threadIdx.x==0 && threadIdx.y == 0) {
203 for (
int s=0;
s<coarseSpin;
s++) {
204 for (
int coarse_color_local=0; coarse_color_local<coarse_colors_per_thread; coarse_color_local++) {
205 int v = coarse_color_block + coarse_color_local;
206 arg.out(parity_coarse, x_coarse_cb,
s, v) = reduced[
s*coarse_colors_per_thread+coarse_color_local];
FieldOrderCB< Float, coarseSpin, coarseColor, 1, order > out
__global__ void RestrictKernel(Arg arg)
enum QudaFieldOrder_s QudaFieldOrder
cudaColorSpinorField * tmp
RestrictArg(const RestrictArg< Float, vFloat, fineSpin, fineColor, coarseSpin, coarseColor, order > &arg)
const int * coarse_to_fine
const FieldOrderCB< Float, fineSpin, fineColor, 1, order > in
const FieldOrderCB< Float, fineSpin, fineColor, coarseColor, order, vFloat > V
__device__ __host__ void rotateCoarseColor(complex< Float > out[fineSpin *coarse_colors_per_thread], const FineColor &in, const Rotator &V, int parity, int nParity, int x_cb, int coarse_color_block)
const int * fine_to_coarse
cpuColorSpinorField * out
const spin_mapper< fineSpin, coarseSpin > spin_map
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
RestrictArg(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &V, const int *fine_to_coarse, const int *coarse_to_fine, int parity)
colorspinor::FieldOrderCB< real, Ns, Nc, 1, order > V
__host__ __device__ ValueType conj(ValueType x)