2 #include <clover_field.h>
3 #include <launch_kernel.cuh>
4 #include <instantiate.h>
6 #include <jitify_helper.cuh>
7 #include <kernels/clover_invert.cuh>
11 using namespace clover;
// CloverInvert: tunable launcher for cloverInvertKernel, templated on the
// clover storage type store_t. Derives from TunableLocalParityReduction so it
// can be autotuned via tuneLaunch() (see apply()).
// NOTE(review): this listing omits several lines of the class (access
// specifier, closing brace); confirm against the full file.
13 template <typename store_t>
14 class CloverInvert : TunableLocalParityReduction {
// Kernel argument bundle, constructed from the clover field and the
// compute_tr_log flag in the constructor; passed to every kernel launch.
15 CloverInvertArg<store_t> arg;
// Read-only reference to the source field; queried for Location() and
// VolString() only.
16 const CloverField &meta; // used for meta data only
// Constructor: builds the kernel argument from `clover`, records the tuning
// aux string, prepares the JIT program for device-resident fields, and
// finalises the trace-log result into clover.TrLog().
// NOTE(review): the meta(clover) initializer, the kernel launch and the
// compute_tr_log guard around the trace-log finalisation are not visible in
// this listing; presumably they sit between the lines below — confirm.
19 CloverInvert(CloverField &clover, bool compute_tr_log) :
20 arg(clover, compute_tr_log),
// Encode the parameters that distinguish tuning instances (stride, storage
// size in bytes, trace-log flag, twist flag); presumably this populates the
// `aux` string that tuneKey() includes.
// NOTE(review): "%lu" with sizeof(store_t) assumes size_t == unsigned long;
// "%zu" is the portable conversion for size_t.
23 writeAuxString("stride=%d,prec=%lu,trlog=%s,twist=%s", arg.clover.stride, sizeof(store_t),
24 compute_tr_log ? "true" : "false", arg.twist ? "true" : "false");
// The JIT program is only created for fields located on the CUDA device.
25 if (meta.Location() == QUDA_CUDA_FIELD_LOCATION) {
27 create_jitify_program("kernels/clover_invert.cuh");
// Hand the kernel's trace-log result to the field's TrLog() storage, then
// all-reduce the 2-element TrLog() array across communicating processes.
33 arg.complete(*clover.TrLog());
34 comm_allreduce_array(clover.TrLog(), 2);
// apply: launch the clover-inversion kernel on `stream` using autotuned launch
// parameters. Only device (QUDA_CUDA_FIELD_LOCATION) fields are handled; the
// "Not implemented" error below is presumably the branch for other locations.
// NOTE(review): braces, #ifdef JITIFY/#else/#endif and the if (arg.twist)
// branches are not visible in this listing; the branch pairing described in
// the comments below is inferred from the template arguments — confirm.
38 void apply(const qudaStream_t &stream) {
// Obtain (or run) the autotuned launch configuration for this instance.
39 TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
40 using Arg = decltype(arg);
41 if (meta.Location() == QUDA_CUDA_FIELD_LOCATION) {
// JIT path: instantiate cloverInvertKernel with the tuned block size, the Arg
// type and the compute_tr_log / twist flags, then configure grid, block and
// shared memory on `stream`.
// NOTE(review): the chain as listed ends at .configure(); presumably an
// elided .launch(arg) line follows — confirm.
43 using namespace jitify::reflection;
44 jitify_error = program->kernel("quda::cloverInvertKernel")
45 .instantiate((int)tp.block.x, Type<Arg>(), arg.compute_tr_log, arg.twist)
46 .configure(tp.grid, tp.block, tp.shared_bytes, stream)
// Precompiled path. The template arguments follow the JIT order
// (block size, Arg, compute_tr_log, twist).
49 if (arg.compute_tr_log) {
// Trace-log together with twist has no precompiled instantiation.
51 errorQuda("Not instantiated");
// Trace-log without twist: LAUNCH_KERNEL_LOCAL_PARITY picks the kernel
// instantiation matching the tuned block size.
53 LAUNCH_KERNEL_LOCAL_PARITY(cloverInvertKernel, (*this), tp, stream, arg, Arg, true, false);
// No trace-log: block-size parameter fixed at 1, twist = true, then twist = false.
57 qudaLaunchKernel(cloverInvertKernel<1,Arg,false,true>, tp, stream, arg);
59 qudaLaunchKernel(cloverInvertKernel<1,Arg,false,false>, tp, stream, arg);
64 errorQuda("Not implemented");
// Tuning key: the field's volume string, this concrete class's type name, and
// the aux string describing the instance parameters.
TuneKey tuneKey() const
{
  return TuneKey(meta.VolString(), typeid(*this).name(), aux);
}

// No floating-point operation count is reported for this kernel.
long long flops() const
{
  return 0;
}

// Bytes moved: one clover record plus one inverse record per site, times
// 2 * volumeCB sites (presumably both parities of the checkerboarded volume).
long long bytes() const
{
  const auto site_bytes = arg.inverse.Bytes() + arg.clover.Bytes();
  return 2 * arg.clover.volumeCB * site_bytes;
}

// When the inverse accessor points at the same storage as the clover accessor,
// snapshot the inverse before tuning and restore it afterwards, so that
// tuning launches do not leave their results behind.
void preTune()
{
  const bool shares_storage = (arg.clover.clover == arg.inverse.clover);
  if (shares_storage) arg.inverse.save();
}

void postTune()
{
  const bool shares_storage = (arg.clover.clover == arg.inverse.clover);
  if (shares_storage) arg.inverse.load();
}
75 // this is the function that is actually called, from here on down we instantiate all required templates
// cloverInvert: public entry point. When built with GPU_CLOVER_DIRAC, it
// forwards to instantiate<CloverInvert>, which presumably selects the
// CloverInvert<store_t> matching the field's precision. Otherwise it reports
// that clover support was not compiled in.
// Parameters: clover - field to invert; computeTraceLog - also compute the
// trace log (stored via clover.TrLog()).
// NOTE(review): the function braces and the #else/#endif separating the two
// statements below are not visible in this listing — confirm.
76 void cloverInvert(CloverField &clover, bool computeTraceLog)
78 #ifdef GPU_CLOVER_DIRAC
79 instantiate<CloverInvert>(clover, computeTraceLog);
81 errorQuda("Clover has not been built");