1 #include <color_spinor_field.h>
2 #include <color_spinor_field_order.h>
3 #include <dslash_quda.h>
4 #include <index_helper.cuh>
5 #include <instantiate.h>
7 #include <kernels/dslash_domain_wall_m5.cuh>
/**
   Host-side driver class for the fifth-dimension kernels dslash5GPU and
   dslash5invGPU.  It derives from TunableVectorYZ, holds the kernel argument
   struct, and supplies the flop/byte counts, shared-memory sizing and launch
   parameters used when a kernel is launched through tuneLaunch().
   @tparam Float Scalar type forwarded to Dslash5Arg and to the kernel templates
   @tparam nColor Number of colors forwarded to Dslash5Arg and to the kernel templates
 */
12 template <typename Float, int nColor> class Dslash5 : public TunableVectorYZ
// Argument struct handed to every kernel launch; its type, xpay and dagger fields select the kernel variant
14 Dslash5Arg<Float, nColor> arg;
// Field queried for geometry (X(4), Volume(), Ncolor(), Nspin()), location, and the strings used in the tune key
15 const ColorSpinorField &meta;
16 static constexpr bool shared = true; // whether to use shared memory cache blocking for M5inv
18 /** Whether to use variable or fixed coefficient algorithm. Must be true if using ZMOBIUS */
// When true, sharedBytesPerThread() sizes shared memory for Nspin() / 2 spin components instead of Nspin()
19 static constexpr bool var_inverse = true;
/**
   Flop-count estimate for the operator variant selected by arg.type.
   The dslash variants scale a per-component count by n = Ncolor * Nspin; the
   M5 inverse variants use a count proportional to Ls * Volume().  Every
   variant adds an extra term when arg.xpay is set.  An unrecognised type is
   reported via errorQuda.
 */
21 long long flops() const
// extent of dimension 4 of the field
23 long long Ls = meta.X(4);
// (Ls - 2) slices of Volume() / Ls sites each -- presumably the interior of the
// fifth dimension; assumes Volume() spans all Ls slices (TODO confirm)
24 long long bulk = (Ls - 2) * (meta.Volume() / Ls);
// two slices' worth of sites -- presumably the two boundaries of the fifth dimension
25 long long wall = 2 * meta.Volume() / Ls;
// spin-color components per site
26 long long n = meta.Ncolor() * meta.Nspin();
30 case DSLASH5_DWF: flops_ = n * (8ll * bulk + 10ll * wall + (arg.xpay ? 4ll * meta.Volume() : 0)); break;
31 case DSLASH5_MOBIUS_PRE:
32 flops_ = n * (8ll * bulk + 10ll * wall + 14ll * meta.Volume() + (arg.xpay ? 8ll * meta.Volume() : 0));
35 flops_ = n * (8ll * bulk + 10ll * wall + 8ll * meta.Volume() + (arg.xpay ? 8ll * meta.Volume() : 0));
// The FIXME below marks this inverse count as unverified
38 case M5_INV_MOBIUS: // FIXME flops
39 flops_ = ((2 + 8 * n) * Ls + (arg.xpay ? 4ll : 0)) * meta.Volume();
// ZMOBIUS inverse uses larger coefficients: (12 + 16 n) per Ls instead of (2 + 8 n), and 8 instead of 4 for xpay
41 case M5_INV_ZMOBIUS: flops_ = ((12 + 16 * n) * Ls + (arg.xpay ? 8ll : 0)) * meta.Volume(); break;
42 default: errorQuda("Unknown Dslash5Type %d", arg.type);
/**
   Memory-traffic estimate, in bytes, for the operator variant selected by
   arg.type.  Every case sums arg.out.Bytes(), a multiple of arg.in.Bytes()
   and, when arg.xpay is set, arg.x.Bytes().  The multiple is 2 for
   DSLASH5_DWF, 3 for the two Mobius dslash variants, and Ls for the three
   M5 inverse variants.  An unrecognised type is reported via errorQuda.
 */
48 long long bytes() const
// extent of dimension 4 of the field; used as the input multiplier for the inverse variants
50 long long Ls = meta.X(4);
52 case DSLASH5_DWF: return arg.out.Bytes() + 2 * arg.in.Bytes() + (arg.xpay ? arg.x.Bytes() : 0);
53 case DSLASH5_MOBIUS_PRE: return arg.out.Bytes() + 3 * arg.in.Bytes() + (arg.xpay ? arg.x.Bytes() : 0);
54 case DSLASH5_MOBIUS: return arg.out.Bytes() + 3 * arg.in.Bytes() + (arg.xpay ? arg.x.Bytes() : 0);
55 case M5_INV_DWF: return arg.out.Bytes() + Ls * arg.in.Bytes() + (arg.xpay ? arg.x.Bytes() : 0);
56 case M5_INV_MOBIUS: return arg.out.Bytes() + Ls * arg.in.Bytes() + (arg.xpay ? arg.x.Bytes() : 0);
57 case M5_INV_ZMOBIUS: return arg.out.Bytes() + Ls * arg.in.Bytes() + (arg.xpay ? arg.x.Bytes() : 0);
58 default: errorQuda("Unknown Dslash5Type %d", arg.type);
// Always false; by its name this opts out of grid-dimension tuning in the base tuner (confirm against TunableVectorYZ)
63 bool tuneGridDim() const { return false; }
// Returns arg.volume_4d_cb -- presumably the checkerboarded 4-d volume, i.e. one thread per 4-d site (confirm in Dslash5Arg)
64 unsigned int minThreads() const { return arg.volume_4d_cb; }
// Always 4; presumably the increment the tuner applies to the block size (confirm against the base tuner)
65 int blockStep() const { return 4; }
// Always 4; presumably the smallest block size the tuner will try (confirm against the base tuner)
66 int blockMin() const { return 4; }
/**
   Shared memory required per thread.  For the M5 inverse variants with
   shared-memory blocking enabled this is
   2 * nSpin * nColor * sizeof(mapper<Float>::type), where nSpin is
   Nspin() / 2 when var_inverse is set and Nspin() otherwise.
 */
67 unsigned int sharedBytesPerThread() const
69 if (shared && (arg.type == M5_INV_DWF || arg.type == M5_INV_MOBIUS || arg.type == M5_INV_ZMOBIUS)) {
70 // spin components in shared depend on inversion algorithm
71 int nSpin = var_inverse ? meta.Nspin() / 2 : meta.Nspin();
// leading factor 2 is presumably the real and imaginary parts of each component -- TODO confirm
72 return 2 * nSpin * nColor * sizeof(typename mapper<Float>::type);
78 // overloaded to return max dynamic shared memory if doing shared-memory inverse
// The condition matches the one in sharedBytesPerThread(); the last return defers to the TunableVectorYZ limit.
79 unsigned int maxSharedBytesPerBlock() const
81 if (shared && (arg.type == M5_INV_DWF || arg.type == M5_INV_MOBIUS || arg.type == M5_INV_ZMOBIUS)) {
82 return maxDynamicSharedBytesPerBlock();
84 return TunableVectorYZ::maxSharedBytesPerBlock();
/**
   Build the kernel arguments and the tune-key aux string, then launch.
   The base class is constructed from in.X(4) and in.SiteSubset(); all
   parameters are forwarded unchanged to Dslash5Arg.  The aux string is
   meta.AuxString() followed by ",Dagger" and ",xpay" when the corresponding
   arg flags are set, and by a tag naming the Dslash5Type.
   @param[out] out Output field
   @param[in] in Input field
   @param[in] x Auxiliary field (its Bytes() are counted only when arg.xpay is set)
   @param[in] m_f,m_5,b_5,c_5,a Parameters forwarded to Dslash5Arg (presumably
   masses, Mobius coefficients and the xpay scale -- confirm in Dslash5Arg)
   @param[in] dagger Forwarded to Dslash5Arg; arg.dagger selects the dagger kernel instantiation in apply()
   @param[in] type Operator variant; an unrecognised value is reported via errorQuda
 */
89 Dslash5(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &x, double m_f,
90 double m_5, const Complex *b_5, const Complex *c_5, double a, bool dagger, Dslash5Type type) :
91 TunableVectorYZ(in.X(4), in.SiteSubset()),
92 arg(out, in, x, m_f, m_5, b_5, c_5, a, dagger, type),
// NOTE(review): strcpy/strcat do not check the capacity of aux -- confirm it holds the longest combination of tags
95 strcpy(aux, meta.AuxString());
96 if (arg.dagger) strcat(aux, ",Dagger");
97 if (arg.xpay) strcat(aux, ",xpay");
// append one tag per operator variant so each variant gets its own tune-cache entry
99 case DSLASH5_DWF: strcat(aux, ",DSLASH5_DWF"); break;
100 case DSLASH5_MOBIUS_PRE: strcat(aux, ",DSLASH5_MOBIUS_PRE"); break;
101 case DSLASH5_MOBIUS: strcat(aux, ",DSLASH5_MOBIUS"); break;
102 case M5_INV_DWF: strcat(aux, ",M5_INV_DWF"); break;
103 case M5_INV_MOBIUS: strcat(aux, ",M5_INV_MOBIUS"); break;
104 case M5_INV_ZMOBIUS: strcat(aux, ",M5_INV_ZMOBIUS"); break;
105 default: errorQuda("Unknown Dslash5Type %d", arg.type);
// the kernel is launched as part of construction, on the last entry of streams
108 apply(streams[Nstream - 1]);
/**
   Launch kernel f through qudaLaunchKernel with the given launch parameters.
   For the M5 inverse variants with shared-memory blocking enabled,
   tp.set_max_shared_bytes is set first.
   @param[in] f Kernel to launch
   @param[in,out] tp Launch parameters; set_max_shared_bytes may be modified
   @param[in] arg Kernel argument struct (shadows the class member of the same name)
   @param[in] stream Stream passed to qudaLaunchKernel
 */
111 template <typename T, typename Arg> inline void launch(T *f, TuneParam &tp, Arg &arg, const qudaStream_t &stream)
113 if (shared && (arg.type == M5_INV_DWF || arg.type == M5_INV_MOBIUS || arg.type == M5_INV_ZMOBIUS)) {
114 // if inverse kernel uses shared memory then maximize total shared memory pool
115 tp.set_max_shared_bytes = true;
117 qudaLaunchKernel(f, tp, stream, arg);
/**
   Launch the kernel instantiation selected by arg.type and arg.dagger on the
   given stream, using launch parameters returned by tuneLaunch().  A field
   located on the CPU is rejected with errorQuda.
   @param[in] stream Stream passed through to launch()
 */
120 void apply(const qudaStream_t &stream)
122 using Arg = decltype(arg);
123 if (meta.Location() == QUDA_CPU_FIELD_LOCATION) {
124 errorQuda("CPU variant not instantiated");
// launch parameters for this object, keyed by tuneKey()
126 TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
// Each branch instantiates all four combinations of the two boolean template
// arguments for one operator type.  The first boolean follows arg.dagger; the
// second is presumably arg.xpay -- confirm against the kernel header.
127 if (arg.type == DSLASH5_DWF) {
129 arg.dagger ? launch(dslash5GPU<Float, nColor, true, true, DSLASH5_DWF, Arg>, tp, arg, stream) :
130 launch(dslash5GPU<Float, nColor, false, true, DSLASH5_DWF, Arg>, tp, arg, stream);
132 arg.dagger ? launch(dslash5GPU<Float, nColor, true, false, DSLASH5_DWF, Arg>, tp, arg, stream) :
133 launch(dslash5GPU<Float, nColor, false, false, DSLASH5_DWF, Arg>, tp, arg, stream);
134 } else if (arg.type == DSLASH5_MOBIUS_PRE) {
136 arg.dagger ? launch(dslash5GPU<Float, nColor, true, true, DSLASH5_MOBIUS_PRE, Arg>, tp, arg, stream) :
137 launch(dslash5GPU<Float, nColor, false, true, DSLASH5_MOBIUS_PRE, Arg>, tp, arg, stream);
139 arg.dagger ? launch(dslash5GPU<Float, nColor, true, false, DSLASH5_MOBIUS_PRE, Arg>, tp, arg, stream) :
140 launch(dslash5GPU<Float, nColor, false, false, DSLASH5_MOBIUS_PRE, Arg>, tp, arg, stream);
141 } else if (arg.type == DSLASH5_MOBIUS) {
143 arg.dagger ? launch(dslash5GPU<Float, nColor, true, true, DSLASH5_MOBIUS, Arg>, tp, arg, stream) :
144 launch(dslash5GPU<Float, nColor, false, true, DSLASH5_MOBIUS, Arg>, tp, arg, stream);
146 arg.dagger ? launch(dslash5GPU<Float, nColor, true, false, DSLASH5_MOBIUS, Arg>, tp, arg, stream) :
147 launch(dslash5GPU<Float, nColor, false, false, DSLASH5_MOBIUS, Arg>, tp, arg, stream);
// The M5 inverse types use dslash5invGPU, which additionally takes the shared and var_inverse compile-time switches
148 } else if (arg.type == M5_INV_DWF) {
151 launch(dslash5invGPU<Float, nColor, true, true, M5_INV_DWF, shared, var_inverse, Arg>, tp, arg, stream) :
152 launch(dslash5invGPU<Float, nColor, false, true, M5_INV_DWF, shared, var_inverse, Arg>, tp, arg, stream);
155 launch(dslash5invGPU<Float, nColor, true, false, M5_INV_DWF, shared, var_inverse, Arg>, tp, arg, stream) :
156 launch(dslash5invGPU<Float, nColor, false, false, M5_INV_DWF, shared, var_inverse, Arg>, tp, arg, stream);
157 } else if (arg.type == M5_INV_MOBIUS) {
160 dslash5invGPU<Float, nColor, true, true, M5_INV_MOBIUS, shared, var_inverse, Arg>, tp, arg, stream) :
161 launch(dslash5invGPU<Float, nColor, false, true, M5_INV_MOBIUS, shared, var_inverse, Arg>, tp,
165 dslash5invGPU<Float, nColor, true, false, M5_INV_MOBIUS, shared, var_inverse, Arg>, tp, arg, stream) :
166 launch(dslash5invGPU<Float, nColor, false, false, M5_INV_MOBIUS, shared, var_inverse, Arg>, tp,
168 } else if (arg.type == M5_INV_ZMOBIUS) {
171 dslash5invGPU<Float, nColor, true, true, M5_INV_ZMOBIUS, shared, var_inverse, Arg>, tp, arg, stream) :
172 launch(dslash5invGPU<Float, nColor, false, true, M5_INV_ZMOBIUS, shared, var_inverse, Arg>, tp,
176 dslash5invGPU<Float, nColor, true, false, M5_INV_ZMOBIUS, shared, var_inverse, Arg>, tp, arg, stream) :
177 launch(dslash5invGPU<Float, nColor, false, false, M5_INV_ZMOBIUS, shared, var_inverse, Arg>,
/**
   Set the initial launch parameters for tuning.  Starts from
   TunableVectorYZ::initTuneParam, pins block.y to arg.Ls for the
   shared-memory M5 inverse variants, and sets shared_bytes to
   sharedBytesPerThread() times the number of threads in the block.
   @param[in,out] param Launch parameters, modified in place (hence taken by reference)
 */
183 void initTuneParam(TuneParam &param) const
185 TunableVectorYZ::initTuneParam(param);
186 if (shared && (arg.type == M5_INV_DWF || arg.type == M5_INV_MOBIUS || arg.type == M5_INV_ZMOBIUS)) {
187 param.block.y = arg.Ls; // Ls must be contained in the block
// dynamic shared memory scales with the full thread-block size
189 param.shared_bytes = sharedBytesPerThread() * param.block.x * param.block.y * param.block.z;
/**
   Set the default (untuned) launch parameters.  Starts from
   TunableVectorYZ::defaultTuneParam, pins block.y to arg.Ls for the
   shared-memory M5 inverse variants, and sets shared_bytes to
   sharedBytesPerThread() times the number of threads in the block.
   @param[in,out] param Launch parameters, modified in place (hence taken by reference)
 */
193 void defaultTuneParam(TuneParam &param) const
195 TunableVectorYZ::defaultTuneParam(param);
196 if (shared && (arg.type == M5_INV_DWF || arg.type == M5_INV_MOBIUS || arg.type == M5_INV_ZMOBIUS)) {
197 param.block.y = arg.Ls; // Ls must be contained in the block
// dynamic shared memory scales with the full thread-block size
199 param.shared_bytes = sharedBytesPerThread() * param.block.x * param.block.y * param.block.z;
// Tune-cache key: the field's volume string, the dynamic type name of this object, and the aux string built in the constructor
203 TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), aux); }
206 // Apply the 5th dimension dslash operator to a colorspinor field
/**
   Public entry point.  When GPU_DOMAIN_WALL_DIRAC is defined it requires a
   4-d preconditioned input field, checks that out, in and x share a location,
   and dispatches to Dslash5 through instantiate<>; constructing Dslash5
   launches the kernel.  All parameters are forwarded unchanged to the
   Dslash5 constructor.
   @param[out] out Output field
   @param[in] in Input field; in.PCType() must be QUDA_4D_PC
   @param[in] x Auxiliary field
   @param[in] m_f,m_5,b_5,c_5,a Parameters forwarded to Dslash5
   @param[in] dagger Whether to apply the dagger variant
   @param[in] type Which fifth-dimension operator to apply
 */
208 void ApplyDslash5(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &x, double m_f,
209 double m_5, const Complex *b_5, const Complex *c_5, double a, bool dagger, Dslash5Type type)
211 #ifdef GPU_DOMAIN_WALL_DIRAC
212 if (in.PCType() != QUDA_4D_PC) errorQuda("Only 4-d preconditioned fields are supported");
213 checkLocation(out, in, x); // check all locations match
214 instantiate<Dslash5>(out, in, x, m_f, m_5, b_5, c_5, a, dagger, type);
// presumably reached only when GPU_DOMAIN_WALL_DIRAC is not defined -- confirm the preprocessor structure
216 errorQuda("Domain wall dslash has not been built");