2 #include <gauge_field.h>
3 #include <color_spinor_field.h>
4 #include <register_traits.h>
5 #include <dslash_quda.h>
6 #include <instantiate.h>
8 #include <jitify_helper.cuh>
9 #include <kernels/staggered_kd_apply_xinv_kernel.cuh>
// ApplyStaggeredKDBlock: TunableVectorY-based launcher for the ApplyStaggeredKDBlockGPU<Arg>
// kernel from kernels/staggered_kd_apply_xinv_kernel.cuh, which applies the staggered
// Kahler-Dirac block inverse Xinv to a staggered spinor. Arg supplies the compile-time sizes
// (coarseDof, fineColor, paddedSpinorSizeKD, xinvPaddedColTileSize, Float) and the runtime
// volumes (coarseVolumeCB, fineVolumeCB) referenced below.
13 template <typename Arg>
14 class ApplyStaggeredKDBlock : public TunableVectorY {
// Spinor used as the metadata reference: location check in apply(), volume/aux strings for the
// tune key, and Bytes() for the bandwidth estimate.
17 const ColorSpinorField &meta;
// The KD block inverse field; contributes its AuxString() to the tune key and Bytes() to bytes().
18 const GaugeField &Xinv;
// Flop count: 2 * coarseVolumeCB coarse sites, each doing a coarseDof x coarseDof complex
// mat-vec, i.e. (8 * coarseDof - 2) real flops per output row (6N mul + 2(N-1) add flops).
20 long long flops() const {
21 // a coarse volume number of 48x48 mat-vec
22 return 2ll * arg.coarseVolumeCB * Arg::coarseDof * (8ll * Arg::coarseDof - 2);
// Byte count: two spinor-sized transfers (presumably read of in + write of out, assuming both
// have meta's size) plus one pass over Xinv.
25 long long bytes() const
27 return 2 * meta.Bytes() + Xinv.Bytes();
// Dynamic shared memory per thread. The first term is a total for 256 threads divided by 256,
// consistent with the block size being tuned in steps of 256 (see blockStep/blockMin).
// Assuming fineColor = 3, paddedSpinorSizeKD = 17 and Float = float (8-byte complex), it
// evaluates to 51 bytes, matching the estimate in the comment below; for Float = double it doubles.
// The second term reserves xinvPaddedColTileSize complex values per thread, presumably a padded
// tile of Xinv columns -- confirm against the kernel.
30 unsigned int sharedBytesPerThread() const {
31 // 16 threads needs to store 16 ColorVectors in the compute
32 // precision, times 2 for in vs out
33 // -> each thread needs to store 2 x size of ColorVector
34 // plus some padding to avoid bank conflicts: each KD block stores
35 // 17 ColorVectors. (2 * 16 threads * 51 complex * 8 bytes per complex) / 256 threads
36 // -> each thread needs 51 bytes
37 return 2 * Arg::fineColor * 16 * Arg::paddedSpinorSizeKD * sizeof(complex<typename Arg::Float>) / 256 +
38 Arg::xinvPaddedColTileSize * sizeof(complex<typename Arg::Float>);
// Thread blocks are tuned only in multiples of 256 threads, starting at 256.
41 int blockStep() const { return 256; }
42 int blockMin() const { return 256; }
// Minimum thread count for tuning: 2 * fineVolumeCB (presumably one thread per fine site
// over both parities).
44 unsigned int minThreads() const { return 2 * arg.fineVolumeCB; }
45 bool tuneGridDim() const { return false; } // don't tune the grid dimension
// Constructor: registers the kernel source with jitify (presumably only in JITIFY builds) and
// builds the tune-key aux string from meta's compile type and aux string plus Xinv's aux string.
// NOTE(review): strcpy/strcat write into the fixed-size aux buffer without bounds checks; the
// concatenation is assumed to fit -- confirm against the aux buffer capacity.
48 ApplyStaggeredKDBlock(Arg &arg, const ColorSpinorField &meta, const GaugeField &Xinv) :
55 create_jitify_program("kernels/staggered_kd_apply_xinv_kernel.cuh");
57 strcpy(aux, compile_type_str(meta));
59 strcat(aux, meta.AuxString());
60 strcat(aux, ",applyStaggeredKDBlock");
61 strcat(aux, ",Xinv:coarse_");
62 strcat(aux, Xinv.AuxString());
63 // should be all we need?
// Tunes (or looks up) the launch parameters, then launches ApplyStaggeredKDBlockGPU<Arg> on `stream`.
// The CPU-location branch is an empty placeholder: its only content is a commented-out CPU call.
// NOTE(review): CPU-resident fields silently do nothing here; ApplyStaggeredKahlerDiracInverse
// rejects CPU fields before reaching this point, but an explicit error here would be safer.
66 void apply(const qudaStream_t &stream)
68 TuneParam tp = tuneLaunch(*this, getTuning(), QUDA_VERBOSE);
70 if (meta.Location() == QUDA_CPU_FIELD_LOCATION) {
71 //ComputeStaggeredKDBlockCPU<Float,fineColor,coarseSpin,coarseColor>(arg);
// jitify path: instantiate the kernel for Arg at runtime and launch it with the tuned grid/block/shared sizes.
74 using namespace jitify::reflection;
75 jitify_error = program->kernel("quda::ApplyStaggeredKDBlockGPU")
76 .instantiate(Type<Arg>())
77 .configure(tp.grid,tp.block,tp.shared_bytes,stream).launch(arg);
// Direct (ahead-of-time compiled) launch path.
79 qudaLaunchKernel(ApplyStaggeredKDBlockGPU<Arg>, tp, stream, arg);
// Tune key: meta's volume string, this class's type name, and the aux string built in the constructor.
84 TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), aux); }
88 @brief Apply the staggered Kahler-Dirac block inverse
90 @param[out] out output staggered spinor accessor
91 @param[in] in input staggered spinor accessor
92 @param[in] Xinv KD block inverse accessor
93 @param[out] out_ output staggered spinor
94 @param[in] in_ input staggered spinor
95 @param[in] Xinv_ KD block inverse
// Accessor-level driver: validates the sizes, builds the ApplyStaggeredKDBlockArg for the given
// dagger setting, and constructs the ApplyStaggeredKDBlock tuner. out/in/Xinv are the field
// accessors; out_/in_/Xinv_ are the underlying fields, used for dimensions and tuning metadata.
97 template<typename vFloatSpinor, typename vFloatGauge, int fineColor, int coarseDof, bool dagger, typename fineColorSpinor, typename xInvGauge>
98 void applyStaggeredKDBlock(fineColorSpinor &out, const fineColorSpinor &in, const xInvGauge &Xinv,
99 ColorSpinorField &out_, const ColorSpinorField &in_, const GaugeField &Xinv_)
// Rejects unsupported fine color counts (the message expects nColor=3).
// NOTE(review): the text says "gauge field" but the value printed is fineColor, the spinor color count.
103 errorQuda("Input gauge field should have nColor=3, not nColor=%d\n", fineColor);
// Xinv here is the accessor (xInvGauge), not the GaugeField Xinv_; only 4-d is supported.
105 if (Xinv.Ndim() != 4) errorQuda("Number of dimensions not supported");
// A KD block holds 16 fine sites (2 per dimension, see the size check below), so the block dof
// must be 16 * fineColor (48 for fineColor = 3).
108 if (fineColor * 16 != coarseDof)
109 errorQuda("Fine nColor=%d is not consistent with KD dof %d", fineColor, coarseDof);
// Only the first nDim entries are filled below; the rest stay zero from the { } initializer.
111 int x_size[QUDA_MAX_DIM] = { };
112 int xc_size[QUDA_MAX_DIM] = { };
113 for (int i = 0; i < nDim; i++) {
114 x_size[i] = out_.X()[i];
115 xc_size[i] = Xinv_.X()[i];
116 // check that local volumes are consistent
117 if (2 * xc_size[i] != x_size[i]) {
118 errorQuda("Inconsistent fine dimension %d and coarse KD dimension %d", x_size[i], xc_size[i]);
// Note: ApplyStaggeredKDBlockArg takes coarseDof before fineColor, the reverse of this function's
// template parameter order.
122 using Arg = ApplyStaggeredKDBlockArg<vFloatSpinor,vFloatGauge,coarseDof,fineColor,dagger,fineColorSpinor,xInvGauge>;
123 Arg arg(out, in, Xinv, x_size, xc_size);
// Tuner/launcher wrapping arg, with out_ as the metadata field and Xinv_ for the tune key;
// presumably launched via y.apply(...) between the two debug messages below -- confirm.
125 ApplyStaggeredKDBlock<Arg> y(arg, out_, Xinv_);
127 if (getVerbosity() >= QUDA_DEBUG_VERBOSE) printfQuda("Applying KD block...\n");
130 if (getVerbosity() >= QUDA_DEBUG_VERBOSE) printfQuda("... done applying KD block\n");
133 // create accessors, specify dagger vs non-dagger
134 template <typename vFloatSpinor, typename vFloatGauge, int fineColor, int fineSpin, int coarseDof>
135 void applyStaggeredKDBlock(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &Xinv, bool dagger)
138 // Create the accessor for Xinv
139 constexpr QudaGaugeFieldOrder xOrder = QUDA_MILC_GAUGE_ORDER;
140 if (Xinv.FieldOrder() != xOrder) errorQuda("Unsupported field order %d\n", Xinv.FieldOrder());
141 using xInvCoarse = typename gauge::FieldOrder<typename mapper<vFloatGauge>::type,coarseDof,1,xOrder,true,vFloatGauge>;
142 xInvCoarse xInvAccessor(const_cast<GaugeField &>(Xinv));
144 // Create the accessors for out, in
145 constexpr bool spin_project = false;
146 constexpr bool spinor_direct_load = false; // seems legacy? false means texture load
147 using csFine = typename colorspinor_mapper<vFloatSpinor, fineSpin, fineColor, spin_project, spinor_direct_load>::type;
148 const csFine inAccessor(in);
149 csFine outAccessor(out);
151 if (dagger) applyStaggeredKDBlock<vFloatSpinor, vFloatGauge, fineColor, coarseDof, true>(outAccessor, inAccessor, xInvAccessor, out, in, Xinv);
152 else applyStaggeredKDBlock<vFloatSpinor, vFloatGauge, fineColor, coarseDof, false>(outAccessor, inAccessor, xInvAccessor, out, in, Xinv);
155 // template on coarse color, spin
156 template <typename vFloatSpinor, typename vFloatGauge, int fineColor, int fineSpin>
157 void applyStaggeredKDBlock(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &Xinv, bool dagger)
159 //constexpr int coarseSpin = 2;
160 const int coarseDof = Xinv.Ncolor(); // / coarseSpin;
162 if (coarseDof == 48) { // dof w/in a KD block
163 applyStaggeredKDBlock<vFloatSpinor, vFloatGauge, fineColor, fineSpin, 48>(out, in, Xinv, dagger);
165 errorQuda("Unsupported number of Kahler-Dirac dof %d\n", Xinv.Ncolor());
169 // template on fine colors, spin
170 template <typename vFloatSpinor, typename vFloatGauge>
171 void applyStaggeredKDBlock(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &Xinv, bool dagger)
173 if (out.Ncolor() != in.Ncolor())
174 errorQuda("Ncolor %d and %d do not match", out.Ncolor(), in.Ncolor());
176 if (out.Nspin() != in.Nspin())
177 errorQuda("Nspin %d and %d do not match", out.Nspin(), in.Nspin());
179 if (out.Ncolor() == 3 && out.Nspin() == 1) {
180 applyStaggeredKDBlock<vFloatSpinor, vFloatGauge, 3, 1>(out, in, Xinv, dagger);
182 errorQuda("Unsupported (color, spin) = (%d, %d)", out.Ncolor(), out.Nspin());
186 // template on Xinv precision (only half and single for now)
187 template <typename vFloatSpinor> struct StaggeredKDBlockApply {
188 StaggeredKDBlockApply(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &Xinv, bool dagger)
191 #if QUDA_PRECISION & 4
192 if (Xinv.Precision() == QUDA_SINGLE_PRECISION) {
193 applyStaggeredKDBlock<vFloatSpinor, float>(out, in, Xinv, dagger);
196 #if QUDA_PRECISION & 2
197 if (Xinv.Precision() == QUDA_HALF_PRECISION) {
198 applyStaggeredKDBlock<vFloatSpinor, short>(out, in, Xinv, dagger);
202 errorQuda("Unsupported precision %d", Xinv.Precision());
209 // Applies the staggered KD block inverse to a staggered ColorSpinor
210 void ApplyStaggeredKahlerDiracInverse(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &Xinv, bool dagger)
212 #if defined(GPU_STAGGERED_DIRAC)
213 auto location = checkLocation(out, in, Xinv);
215 if (location == QUDA_CPU_FIELD_LOCATION)
216 errorQuda("There is no support for applying the KD operator to CPU fields (yet)");
218 // the staggered KD block inverse can only be applied to a full field
219 if (out.SiteSubset() != QUDA_FULL_SITE_SUBSET || out.SiteSubset() != QUDA_FULL_SITE_SUBSET)
220 errorQuda("There is no meaning to applying the KD inverse to a single parity field");
222 checkPrecision(out, in);
224 // Instantiate based on ColorSpinor precision
225 // We don't have a constraint on the precision of Xinv matching
226 // the precision of the spinors.
227 instantiatePrecision<StaggeredKDBlockApply>(out, in, Xinv, dagger);
230 errorQuda("Staggered fermion support has not been built");