3 #include <dslash_helper.cuh>
4 #include <color_spinor_field_order.h>
5 #include <gauge_field_order.h>
6 #include <color_spinor.h>
7 #include <dslash_helper.cuh>
8 #include <index_helper.cuh>
9 #include <gauge_field.h>
10 #include <uint_to_char.h>
12 #include <dslash_policy.cuh>
13 #include <kernels/covDev.cuh>
16 This is the covariant derivative based on the basic gauged Laplace operator
// CovDev: Dslash-derived launcher for the covariant-derivative kernel (tag `covDev`),
// parameterized on the kernel argument type Arg. It provides the launch entry point
// (apply), the performance-model counters (flops, bytes) and the autotuning key (tuneKey).
// NOTE(review): this excerpt omits several original lines (braces, break/return statements
// and some local declarations); the comments below describe only what is visible.
22 template <typename Arg> class CovDev : public Dslash<covDev, Arg>
// Local alias for the base class; used below for Dslash::template instantiate, Dslash::aux, ...
24 using Dslash = Dslash<covDev, Arg>;
// Forwards the kernel arguments and the output/input fields to the Dslash base constructor.
29 CovDev(Arg &arg, const ColorSpinorField &out, const ColorSpinorField &in) : Dslash(arg, out, in) {}
// Launch the kernel on `stream`. Only the full-field (arg.nParity == 2), non-xpay configuration
// is supported; any other configuration is rejected through errorQuda.
// Note: the tuning parameters are looked up (tuneLaunch) before the configuration checks run.
31 void apply(const qudaStream_t &stream)
33 TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
35 if (arg.xpay) errorQuda("Covariant derivative operator only defined without xpay");
36 if (arg.nParity != 2) errorQuda("Covariant derivative operator only defined for full field");
// The checks above pin these values, so they are passed as compile-time template arguments.
38 constexpr bool xpay = false;
39 constexpr int nParity = 2;
40 Dslash::template instantiate<packShmem, nParity, xpay>(tp, stream);
// Floating-point operation count for the current arg.kernel_type (performance model).
// Each site performs in.Nspin() color matrix-vector products; one Ncolor x Ncolor complex
// matrix-vector product costs (8*Ncolor - 2)*Ncolor flops (Ncolor^2 complex multiplies at
// 6 flops plus Ncolor*(Ncolor-1) complex adds at 2 flops).
// NOTE(review): `dim` and `flops_` are declared on lines not shown in this excerpt.
43 long long flops() const
45 int mv_flops = (8 * in.Ncolor() - 2) * in.Ncolor(); // SU(3) matrix-vector flops
46 int num_mv_multiply = in.Nspin();
// Flops per ghost-face site.
47 int ghost_flops = num_mv_multiply * mv_flops;
51 switch (arg.kernel_type) {
// Single-dimension exterior kernels: work is only counted when this is the kernel for `dim`.
52 case EXTERIOR_KERNEL_X:
53 case EXTERIOR_KERNEL_Y:
54 case EXTERIOR_KERNEL_Z:
55 case EXTERIOR_KERNEL_T:
56 if (arg.kernel_type != dim) break;
// NOTE(review): this product is evaluated in int (assuming GhostFace() yields int -- confirm),
// whereas the EXTERIOR_KERNEL_ALL case below widens through a long long; a cast would guard
// against overflow on very large ghost faces.
57 flops_ = (ghost_flops)*in.GhostFace()[dim];
// Fused exterior kernel: counts the ghost-face work of dimension `dim`.
59 case EXTERIOR_KERNEL_ALL: {
60 long long ghost_sites = in.GhostFace()[dim];
61 flops_ = ghost_flops * ghost_sites;
// Volume-wide count (the case labels for this path are not visible here; at least
// KERNEL_POLICY reaches it, per the check below).
67 long long sites = in.Volume();
68 flops_ = num_mv_multiply * mv_flops * sites; // SU(3) matrix-vector multiplies
70 if (arg.kernel_type == KERNEL_POLICY) break;
71 // now correct for flops done by exterior kernel
// Ghost work is only subtracted when communication is enabled in `dim` (arg.commDim[dim]).
72 long long ghost_sites = arg.commDim[dim] ? in.GhostFace()[dim] : 0;
73 flops_ -= ghost_flops * ghost_sites;
// Memory traffic in bytes for the current arg.kernel_type; mirrors the structure of flops().
// NOTE(review): `dim` and `bytes_` are declared on lines not shown in this excerpt.
82 long long bytes() const
// Bytes per gauge link: arg.reconstruct values (presumably reals stored per link -- confirm)
// times the field precision.
84 int gauge_bytes = arg.reconstruct * in.Precision();
// Bytes per spinor site: 2 * Ncolor * Nspin values (presumably real/imag parts) at the field
// precision, plus one float when Arg::Float is a fixed-point type (presumably a per-site norm).
85 int spinor_bytes = 2 * in.Ncolor() * in.Nspin() * in.Precision() +
86 (isFixed<typename Arg::Float>::value ? sizeof(float) : 0);
87 int ghost_bytes = gauge_bytes + 3 * spinor_bytes; // 3 since we have to load the partial
91 switch (arg.kernel_type) {
// Single-dimension exterior kernels: traffic is only counted when this is the kernel for `dim`.
92 case EXTERIOR_KERNEL_X:
93 case EXTERIOR_KERNEL_Y:
94 case EXTERIOR_KERNEL_Z:
95 case EXTERIOR_KERNEL_T:
96 if (arg.kernel_type != dim) break;
// NOTE(review): same int-width concern as in flops(); EXTERIOR_KERNEL_ALL widens via long long.
97 bytes_ = ghost_bytes * in.GhostFace()[dim];
// Fused exterior kernel: counts the ghost-face traffic of dimension `dim`.
99 case EXTERIOR_KERNEL_ALL: {
100 long long ghost_sites = in.GhostFace()[dim];
101 bytes_ = ghost_bytes * ghost_sites;
104 case INTERIOR_KERNEL:
106 case KERNEL_POLICY: {
// Volume-wide traffic: one link and two spinors per site (presumably one read, one write).
107 long long sites = in.Volume();
108 bytes_ = (gauge_bytes + 2 * spinor_bytes) * sites;
110 if (arg.kernel_type == KERNEL_POLICY) break;
111 // now correct for bytes done by exterior kernel
// Ghost traffic is only subtracted when communication is enabled in `dim` (arg.commDim[dim]).
112 long long ghost_sites = arg.commDim[dim] ? in.GhostFace()[dim] : 0;
113 bytes_ -= ghost_bytes * ghost_sites;
// Autotuning key: the field volume string, the dynamic type name of this launcher, and an aux
// string. The aux base is Dslash::aux_pack for an interior kernel with arg.pack_blocks > 0,
// and Dslash::aux[arg.kernel_type] otherwise.
// NOTE(review): the call consuming that expression and any further edits to `aux` are on lines
// not shown in this excerpt.
121 TuneKey tuneKey() const
124 char aux[TuneKey::aux_n];
126 (arg.pack_blocks > 0 && arg.kernel_type == INTERIOR_KERNEL) ? Dslash::aux_pack :
127 Dslash::aux[arg.kernel_type]);
132 return TuneKey(in.VolString(), typeid(*this).name(), aux);
// CovDevApply: functor templated on (Float, nColor, recon), used through instantiate<> by
// ApplyCovDev. Its constructor builds the covariant-derivative kernel arguments for a
// 4-dimensional lattice, creates the CovDev launcher, and wraps it in a DslashPolicyTune.
// NOTE(review): the statements after the policy construction are not shown in this excerpt.
136 template <typename Float, int nColor, QudaReconstructType recon> struct CovDevApply {
138 inline CovDevApply(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, int mu, int parity,
139 bool dagger, const int *comm_override, TimeProfile &profile)
// Number of lattice dimensions passed to CovDevArg.
142 constexpr int nDim = 4;
// All user-facing parameters except `profile` are forwarded unchanged into the kernel arguments.
143 CovDevArg<Float, nColor, recon, nDim> arg(out, in, U, mu, parity, dagger, comm_override);
144 CovDev<decltype(arg)> covDev(arg, out, in);
// The const_cast presumably exists because DslashPolicyTune takes a non-const field pointer.
// NOTE(review): the static_cast assumes `in` really is a cudaColorSpinorField; it is unchecked.
146 dslash::DslashPolicyTune<decltype(covDev)> policy(
147 covDev, const_cast<cudaColorSpinorField *>(static_cast<const cudaColorSpinorField *>(&in)), in.VolumeCB(),
148 in.GhostFaceCB(), profile);
153 // Apply the covariant derivative operator:
154 // out(x) = U_{\mu}(x) in(x + \hat{\mu})                        for mu = 0...3
155 // out(x) = U^\dagger_{\mu'}(x - \hat{\mu'}) in(x - \hat{\mu'})   for mu = 4...7, with mu' = mu - 4
// Parameters (as far as visible here):
//   out, in        output and input color-spinor fields
//   U              gauge field
//   mu             direction index as described above (0...7)
//   parity, dagger, comm_override
//                  forwarded unchanged to CovDevApply / CovDevArg; their effect is not visible
//                  in this excerpt
//   profile        forwarded to the tuning policy
// When the covariant-derivative kernels are built, the call dispatches through
// instantiate<CovDevApply> (presumably choosing Float/nColor/recon from the fields); otherwise it
// reports an error. The preprocessor guard selecting between the two is not shown in this excerpt.
156 void ApplyCovDev(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, int mu, int parity,
157 bool dagger, const int *comm_override, TimeProfile &profile)
160 instantiate<CovDevApply>(out, in, U, mu, parity, dagger, comm_override, profile);
// Fallback path when the kernels are not compiled in.
162 errorQuda("Covariant derivative kernels have not been built");