3 #include <dslash_helper.cuh>
4 #include <color_spinor_field_order.h>
5 #include <gauge_field_order.h>
6 #include <color_spinor.h>
7 #include <dslash_helper.cuh>
8 #include <index_helper.cuh>
9 #include <gauge_field.h>
10 #include <uint_to_char.h>
12 #include <dslash_policy.cuh>
13 #include <kernels/laplace.cuh>
/**
   This is the Laplacian derivative based on the basic gauged differential operator.
*/
22 template <typename Arg> class Laplace : public Dslash<laplace, Arg>
24 using Dslash = Dslash<laplace, Arg>;
29 Laplace(Arg &arg, const ColorSpinorField &out, const ColorSpinorField &in) : Dslash(arg, out, in) {}
31 void apply(const qudaStream_t &stream)
33 TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
36 // operator is Hermitian so do not instantiate dagger
37 if (arg.nParity == 1) {
39 Dslash::template instantiate<packStaggeredShmem, 1, false, true>(tp, stream);
41 Dslash::template instantiate<packStaggeredShmem, 1, false, false>(tp, stream);
42 } else if (arg.nParity == 2) {
44 Dslash::template instantiate<packStaggeredShmem, 2, false, true>(tp, stream);
46 Dslash::template instantiate<packStaggeredShmem, 2, false, false>(tp, stream);
50 long long flops() const
52 int mv_flops = (8 * in.Ncolor() - 2) * in.Ncolor(); // SU(3) matrix-vector flops
53 int num_mv_multiply = in.Nspin() == 4 ? 2 : 1;
54 int ghost_flops = (num_mv_multiply * mv_flops + 2 * in.Ncolor() * in.Nspin());
55 int xpay_flops = 2 * 2 * in.Ncolor() * in.Nspin(); // multiply and add per real component
56 int num_dir = (arg.dir == 4 ? 2 * 4 : 2 * 3); // 3D or 4D operator
60 switch (arg.kernel_type) {
61 case EXTERIOR_KERNEL_X:
62 case EXTERIOR_KERNEL_Y:
63 case EXTERIOR_KERNEL_Z:
64 case EXTERIOR_KERNEL_T:
65 flops_ = (ghost_flops + (arg.xpay ? xpay_flops : xpay_flops / 2)) * 2 * in.GhostFace()[arg.kernel_type];
67 case EXTERIOR_KERNEL_ALL: {
68 long long ghost_sites = 2 * (in.GhostFace()[0] + in.GhostFace()[1] + in.GhostFace()[2] + in.GhostFace()[3]);
69 flops_ = (ghost_flops + (arg.xpay ? xpay_flops : xpay_flops / 2)) * ghost_sites;
75 long long sites = in.Volume();
76 flops_ = (num_dir * (in.Nspin() / 4) * in.Ncolor() * in.Nspin() + // spin project (=0 for staggered)
77 num_dir * num_mv_multiply * mv_flops + // SU(3) matrix-vector multiplies
78 ((num_dir - 1) * 2 * in.Ncolor() * in.Nspin()))
79 * sites; // accumulation
80 if (arg.xpay) flops_ += xpay_flops * sites;
82 if (arg.kernel_type == KERNEL_POLICY) break;
83 // now correct for flops done by exterior kernel
84 long long ghost_sites = 0;
85 for (int d = 0; d < 4; d++)
86 if (arg.commDim[d]) ghost_sites += 2 * in.GhostFace()[d];
87 flops_ -= ghost_flops * ghost_sites;
96 virtual long long bytes() const
98 int gauge_bytes = arg.reconstruct * in.Precision();
99 int spinor_bytes = 2 * in.Ncolor() * in.Nspin() * in.Precision() + (isFixed<typename Arg::Float>::value ? sizeof(float) : 0);
100 int proj_spinor_bytes = in.Nspin() == 4 ? spinor_bytes / 2 : spinor_bytes;
101 int ghost_bytes = (proj_spinor_bytes + gauge_bytes) + 2 * spinor_bytes; // 2 since we have to load the partial
102 int num_dir = (arg.dir == 4 ? 2 * 4 : 2 * 3); // 3D or 4D operator
104 long long bytes_ = 0;
106 switch (arg.kernel_type) {
107 case EXTERIOR_KERNEL_X:
108 case EXTERIOR_KERNEL_Y:
109 case EXTERIOR_KERNEL_Z:
110 case EXTERIOR_KERNEL_T: bytes_ = ghost_bytes * 2 * in.GhostFace()[arg.kernel_type]; break;
111 case EXTERIOR_KERNEL_ALL: {
112 long long ghost_sites = 2 * (in.GhostFace()[0] + in.GhostFace()[1] + in.GhostFace()[2] + in.GhostFace()[3]);
113 bytes_ = ghost_bytes * ghost_sites;
116 case INTERIOR_KERNEL:
118 case KERNEL_POLICY: {
119 long long sites = in.Volume();
120 bytes_ = (num_dir * gauge_bytes + ((num_dir - 2) * spinor_bytes + 2 * proj_spinor_bytes) + spinor_bytes) * sites;
121 if (arg.xpay) bytes_ += spinor_bytes;
123 if (arg.kernel_type == KERNEL_POLICY) break;
124 // now correct for bytes done by exterior kernel
125 long long ghost_sites = 0;
126 for (int d = 0; d < 4; d++)
127 if (arg.commDim[d]) ghost_sites += 2 * in.GhostFace()[d];
128 bytes_ -= ghost_bytes * ghost_sites;
136 TuneKey tuneKey() const
138 // add laplace transverse dir to the key
139 char aux[TuneKey::aux_n];
141 (arg.pack_blocks > 0 && arg.kernel_type == INTERIOR_KERNEL) ? Dslash::aux_pack :
142 Dslash::aux[arg.kernel_type]);
143 strcat(aux, ",laplace=");
145 u32toa(laplace, arg.dir);
146 strcat(aux, laplace);
147 return TuneKey(in.VolString(), typeid(*this).name(), aux);
151 template <typename Float, int nColor, QudaReconstructType recon> struct LaplaceApply {
153 inline LaplaceApply(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, int dir,
154 double a, double b, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override,
155 TimeProfile &profile)
157 if (in.Nspin() == 1) {
158 #if defined(GPU_STAGGERED_DIRAC) && defined(GPU_LAPLACE)
159 constexpr int nDim = 4;
160 constexpr int nSpin = 1;
161 LaplaceArg<Float, nSpin, nColor, nDim, recon> arg(out, in, U, dir, a, b, x, parity, dagger, comm_override);
162 Laplace<decltype(arg)> laplace(arg, out, in);
164 dslash::DslashPolicyTune<decltype(laplace)> policy(
165 laplace, const_cast<cudaColorSpinorField *>(static_cast<const cudaColorSpinorField *>(&in)), in.VolumeCB(),
166 in.GhostFaceCB(), profile);
169 errorQuda("nSpin=1 Laplace operator required staggered dslash and laplace to be enabled");
171 } else if (in.Nspin() == 4) {
172 #if defined(GPU_WILSON_DIRAC) && defined(GPU_LAPLACE)
173 constexpr int nDim = 4;
174 constexpr int nSpin = 4;
175 LaplaceArg<Float, nSpin, nColor, nDim, recon> arg(out, in, U, dir, a, b, x, parity, dagger, comm_override);
176 Laplace<decltype(arg)> laplace(arg, out, in);
178 dslash::DslashPolicyTune<decltype(laplace)> policy(
179 laplace, const_cast<cudaColorSpinorField *>(static_cast<const cudaColorSpinorField *>(&in)), in.VolumeCB(),
180 in.GhostFaceCB(), profile);
183 errorQuda("nSpin=4 Laplace operator required wilson dslash and laplace to be enabled");
186 errorQuda("Unsupported nSpin= %d", in.Nspin());
191 // Apply the Laplace operator
192 // out(x) = M*in = - a*\sum_mu U_{-\mu}(x)in(x+mu) + U^\dagger_mu(x-mu)in(x-mu) + b*in(x)
193 // Omits direction 'dir' from the operator.
194 void ApplyLaplace(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, int dir, double a, double b,
195 const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile)
197 instantiate<LaplaceApply>(out, in, U, dir, a, b, x, parity, dagger, comm_override, profile);