2 #include <gauge_field.h>
3 #include <jitify_helper.cuh>
4 #include <kernels/field_strength_tensor.cuh>
5 #include <instantiate.h>
// Fmunu: tunable launcher for computeFmunuKernel (from
// kernels/field_strength_tensor.cuh), templated on the floating-point type,
// the number of colors and the gauge-link reconstruction type.
// NOTE(review): this excerpt omits several lines of the class (braces, access
// specifiers, parts of the constructor); annotations describe only visible code.
10 template <typename Float, int nColor, QudaReconstructType recon> class Fmunu : TunableVectorYZ
// Kernel argument bundle handed to the kernel launch in apply().
12 FmunuArg<Float, nColor, recon> arg;
// Field whose metadata (AuxString, VolString, Location) drives the tuning key
// and the device-location check; the field it is bound to is set in an
// initializer not visible in this excerpt.
13 const GaugeField &meta;
// Thread count the autotuner must cover, taken directly from arg.threads.
15 unsigned int minThreads() const { return arg.threads; }
// Returning false opts out of grid-dimension tuning. Presumably only the block
// configuration is tuned; the grid-size logic lives in the Tunable base class,
// which is not shown here.
16 bool tuneGridDim() const { return false; }
// Constructor: u is the (const) input gauge field; f is non-const and is
// presumably the output written by the kernel.
// TunableVectorYZ(2, 6) sets the y/z launch extents. They match the "6 * 2"
// factor in flops()/bytes(), presumably parity x the six (mu,nu) planes;
// confirm against the kernel.
19 Fmunu(const GaugeField &u, GaugeField &f) :
20 TunableVectorYZ(2, 6),
// Build the tuning-key suffix: the field's aux string followed by
// comm_dim_partitioned_string() (presumably encoding which dimensions are
// split across processes), so differently partitioned runs tune separately.
24 strcpy(aux, meta.AuxString());
25 strcat(aux, comm_dim_partitioned_string());
// The JIT program is only created for device-resident fields.
// NOTE(review): any preprocessor guard around this call is not visible here.
26 if (meta.Location() == QUDA_CUDA_FIELD_LOCATION) {
28 create_jitify_program("kernels/field_strength_tensor.cuh");
// Block the host until outstanding device work has completed.
// NOTE(review): the kernel launch preceding this sync is not visible in this excerpt.
32 qudaDeviceSynchronize();
// Launch the Fmunu kernel on `stream` using the autotuned launch parameters.
35 void apply(const qudaStream_t &stream)
// Obtain the grid/block/shared-memory configuration from tuneLaunch
// (presumably cached per tuneKey()).
37 TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
// JIT path: instantiate quda::computeFmunuKernel for this argument type at run
// time, configure it with the tuned grid/block/shared bytes and launch it on
// `stream`. The launch status is stored in jitify_error.
39 using namespace jitify::reflection;
40 jitify_error = program->kernel("quda::computeFmunuKernel").instantiate(Type<decltype(arg)>())
41 .configure(tp.grid, tp.block, tp.shared_bytes, stream).launch(arg);
// Precompiled path: launch computeFmunuKernel<decltype(arg)> through qudaLaunchKernel.
// NOTE(review): the #ifdef/#else lines selecting between the two paths are not
// visible in this excerpt.
43 qudaLaunchKernel(computeFmunuKernel<decltype(arg)>, tp, stream, arg);
// Tuning-cache key made of three parts:
//   - the volume string of meta;
//   - the dynamic type name of this object, which differs per template
//     instantiation;
//   - the aux string assembled in the constructor.
47 TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), aux); }
// Flop-count model: (2430 + 36) flops for each of the 6 * 2 work items per
// thread. The constants are not derived here; presumably they are the
// clover-leaf products plus a final projection (confirm against the kernel).
// arg.threads is cast to long long so the product is formed in 64-bit arithmetic.
49 long long flops() const { return (2430 + 36) * 6 * 2 * (long long)arg.threads; }
// Memory-traffic model: 16 * arg.u.Bytes() + arg.f.Bytes() bytes per work item
// (presumably 16 link reads from u and one write of f), times the 6 * 2 work
// items per thread.
// NOTE(review): unlike flops(), nothing here is explicitly widened to
// long long. If Bytes() and arg.threads are both 32-bit int, the product can
// overflow before conversion to the return type. Confirm Bytes()'s return type
// or add a cast.
50 long long bytes() const
52 return ((16 * arg.u.Bytes() + arg.f.Bytes()) * 6 * 2 * arg.threads);
53 } // Ignores link reconstruction
// Compute f from the gauge field u by running Fmunu through
// instantiate<Fmunu, ReconstructWilson>. That call presumably selects the
// Float/nColor/recon template arguments from the first field passed, hence
// u is passed first.
// When built without GPU_GAUGE_TOOLS, reports "Gauge tools are not built" via errorQuda.
// NOTE(review): the #else line and the closing brace are not visible in this excerpt.
57 void computeFmunu(GaugeField &f, const GaugeField &u)
59 #ifdef GPU_GAUGE_TOOLS
61 instantiate<Fmunu,ReconstructWilson>(u, f); // u must be first here for correct template instantiation
63 errorQuda("Gauge tools are not built");
64 #endif // GPU_GAUGE_TOOLS