1 #include <color_spinor_field.h>
2 #include <color_spinor_field_order.h>
3 #include <index_helper.cuh>
4 #include <instantiate.h>
5 #include <instantiate_dslash.h>
7 #include <kernels/dslash_mobius_eofa.cuh>
// Tunable launcher for the Mobius EOFA fifth-dimension operators (M5_EOFA and M5INV_EOFA).
// The constructor packs the kernel arguments, builds the tuning aux string and then applies
// the operator immediately on streams[Nstream - 1].
template <typename storage_type, int nColor> class Dslash5 : public TunableVectorYZ
{
  Dslash5Arg<storage_type, nColor> arg;
  const ColorSpinorField &meta;
  static constexpr bool shared = true; // whether to use shared memory cache blocking for M5inv

  // True when the kernel for the current operator type runs with the shared-memory cache.
  // Single source for the predicate used by maxSharedBytesPerBlock() and launch().
  bool useSharedCache() const { return shared && (arg.type == M5_EOFA || arg.type == M5INV_EOFA); }

  long long flops() const
  {
    // FIXME: Fix the flop count
    long long Ls = meta.X(4);
    long long bulk = (Ls - 2) * (meta.Volume() / Ls); // sites in the (Ls - 2) interior fifth-dimension slices
    long long wall = 2 * meta.Volume() / Ls;           // sites in the 2 boundary slices
    long long n = meta.Ncolor() * meta.Nspin();

    long long flops_ = 0;
    switch (arg.type) {
    case M5_EOFA: // both operators currently share the same (approximate) count
    case M5INV_EOFA: flops_ = n * (8ll * bulk + 10ll * wall + (arg.xpay ? 4ll * meta.Volume() : 0)); break;
    default: errorQuda("Unknown Dslash5Type %d for EOFA", arg.type);
    }
    return flops_;
  }

  long long bytes() const
  {
    // Traffic model: `out` once, `in` twice, plus `x` when xpay is enabled.
    switch (arg.type) {
    case M5_EOFA:
    case M5INV_EOFA: return arg.out.Bytes() + 2 * arg.in.Bytes() + (arg.xpay ? arg.x.Bytes() : 0);
    default: errorQuda("Unknown Dslash5Type %d for EOFA", arg.type);
    }
    return 0ll;
  }

  bool tuneGridDim() const { return false; }
  unsigned int minThreads() const { return arg.volume_4d_cb; }
  int blockStep() const { return 4; }
  int blockMin() const { return 4; }

  unsigned int sharedBytesPerThread() const
  {
    // spin components in shared depend on inversion algorithm
    // 2 * nSpin * nColor values of the mapped compute type per thread -- presumably the real and
    // imaginary parts of one colorspinor; TODO confirm against the kernel.
    int nSpin = meta.Nspin();
    return 2 * nSpin * nColor * sizeof(typename mapper<storage_type>::type);
  }

  // No fixed per-block shared memory; the block-size dependent amount is set in initTuneParam().
  unsigned int sharedBytesPerBlock(const TuneParam &param) const { return 0; }

  // Overloaded to return the maximum dynamic shared memory when the shared-memory cache is in
  // use; otherwise defer to the base-class limit.
  unsigned int maxSharedBytesPerBlock() const
  {
    return useSharedCache() ? maxDynamicSharedBytesPerBlock() : TunableVectorYZ::maxSharedBytesPerBlock();
  }

  // Launch dslash5GPU<storage_type, nColor, dagger, pm, xpay, op, Arg>, with dagger taken from arg.dagger.
  // NOTE(review): bool template-argument order (dagger, pm, xpay) -- confirm against the dslash5GPU
  // declaration in kernels/dslash_mobius_eofa.cuh.
  template <Dslash5Type op, bool pm, bool xpay> void launchDagger(TuneParam &tp, const qudaStream_t &stream)
  {
    using Arg = decltype(arg);
    if (arg.dagger)
      launch(dslash5GPU<storage_type, nColor, true, pm, xpay, op, Arg>, tp, arg, stream);
    else
      launch(dslash5GPU<storage_type, nColor, false, pm, xpay, op, Arg>, tp, arg, stream);
  }

  // Map the runtime flags arg.eofa_pm and arg.xpay onto compile-time template parameters for
  // operator type `op` (arg.dagger is handled by launchDagger()).
  template <Dslash5Type op> void launchType(TuneParam &tp, const qudaStream_t &stream)
  {
    if (arg.eofa_pm) {
      if (arg.xpay)
        launchDagger<op, true, true>(tp, stream);
      else
        launchDagger<op, true, false>(tp, stream);
    } else {
      if (arg.xpay)
        launchDagger<op, false, true>(tp, stream);
      else
        launchDagger<op, false, false>(tp, stream);
    }
  }

public:
  Dslash5(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &x, const double m_f,
          const double m_5, const Complex *b_5, const Complex *c_5, double a, int eofa_pm, double inv,
          double kappa, const double *eofa_u, const double *eofa_x, const double *eofa_y,
          double sherman_morrison, bool dagger, Dslash5Type type) :
    TunableVectorYZ(in.X(4), in.SiteSubset()),
    arg(out, in, x, m_f, m_5, b_5, c_5, a, eofa_pm, inv, kappa, eofa_u, eofa_x, eofa_y, sherman_morrison, dagger, type),
    meta(in)
  {
    TunableVectorY::resizeStep(arg.Ls);

    // The aux string distinguishes tuning entries by every flag that selects a different kernel.
    strcpy(aux, meta.AuxString());
    if (arg.dagger) strcat(aux, ",Dagger");
    if (arg.xpay) strcat(aux, ",xpay");
    if (arg.eofa_pm) {
      strcat(aux, ",eofa_plus");
    } else {
      strcat(aux, ",eofa_minus");
    }
    switch (arg.type) {
    case M5_EOFA: strcat(aux, ",mobius_M5_EOFA"); break;
    case M5INV_EOFA: strcat(aux, ",mobius_M5INV_EOFA"); break;
    default: errorQuda("Unknown Dslash5Type %d", arg.type);
    }

    apply(streams[Nstream - 1]);
  }

  // Launch kernel f with tuned parameters tp; `arg` is always this object's argument struct.
  template <typename T, typename Arg> inline void launch(T *f, TuneParam &tp, Arg &arg, const qudaStream_t &stream)
  {
    // if the kernel uses the shared-memory cache then maximize total shared memory
    if (useSharedCache()) tp.set_max_shared_bytes = true;
    qudaLaunchKernel(f, tp, stream, arg);
  }

  void apply(const qudaStream_t &stream)
  {
    TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
    switch (arg.type) {
    case M5_EOFA: launchType<M5_EOFA>(tp, stream); break;
    case M5INV_EOFA: launchType<M5INV_EOFA>(tp, stream); break;
    default: errorQuda("Unknown Dslash5Type %d", arg.type);
    }
  }

  void initTuneParam(TuneParam &param) const
  {
    TunableVectorYZ::initTuneParam(param);
    param.block.y = arg.Ls; // Ls must be contained in the block
    param.grid.y = 1;       // ...so a single block spans the whole fifth dimension
    param.shared_bytes = sharedBytesPerThread() * param.block.x * param.block.y * param.block.z;
  }

  void defaultTuneParam(TuneParam &param) const { initTuneParam(param); }

  TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), aux); }
};
// Apply the 5th dimension dslash operator to a colorspinor field.
// Forwards all arguments to instantiate<Dslash5>() (presumably selecting the storage precision /
// nColor from the fields -- confirm in instantiate.h); the Dslash5 constructor performs the launch.
// When GPU_DOMAIN_WALL_DIRAC is not defined, this reports an error instead.
void apply_dslash5(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &x, double m_f,
                   double m_5, const Complex *b_5, const Complex *c_5, double a, int eofa_pm, double inv,
                   double kappa, const double *eofa_u, const double *eofa_x, const double *eofa_y,
                   double sherman_morrison, bool dagger, Dslash5Type type)
{
#ifdef GPU_DOMAIN_WALL_DIRAC
  checkLocation(out, in, x); // check all locations match
  instantiate<Dslash5>(out, in, x, m_f, m_5, b_5, c_5, a, eofa_pm, inv, kappa, eofa_u, eofa_x, eofa_y,
                       sherman_morrison, dagger, type);
#else
  errorQuda("Mobius EOFA dslash has not been built");
#endif
}
184 } // namespace mobius_eofa