QUDA  v1.1.0
A library for QCD on GPUs
clover_invert.cu
Go to the documentation of this file.
1 #include <tune_quda.h>
2 #include <clover_field.h>
3 #include <launch_kernel.cuh>
4 #include <instantiate.h>
5 
6 #include <jitify_helper.cuh>
7 #include <kernels/clover_invert.cuh>
8 
9 namespace quda {
10 
11  using namespace clover;
12 
template <typename store_t>
// Tunable driver that inverts a clover field by launching cloverInvertKernel,
// optionally also computing the trace log of the clover term.  Constructing an
// instance performs the work immediately (the constructor calls apply()).
class CloverInvert : TunableLocalParityReduction {
  CloverInvertArg<store_t> arg;
  const CloverField &meta; // used for meta data only

public:
  /**
     @brief Invert the clover field (and optionally accumulate its trace log).
     @param[in,out] clover Field whose inverse is computed; when compute_tr_log is
            set, the result is also written to the two-element array returned by
            clover.TrLog() (presumably one entry per parity -- TODO confirm)
     @param[in] compute_tr_log Whether to compute the trace log alongside the inverse
  */
  CloverInvert(CloverField &clover, bool compute_tr_log) :
    arg(clover, compute_tr_log),
    meta(clover)
  {
    // sizeof yields size_t, whose matching printf conversion is %zu; %lu is only
    // correct on platforms where size_t happens to be unsigned long.
    writeAuxString("stride=%d,prec=%zu,trlog=%s,twist=%s", arg.clover.stride, sizeof(store_t),
                   compute_tr_log ? "true" : "false", arg.twist ? "true" : "false");
    if (meta.Location() == QUDA_CUDA_FIELD_LOCATION) {
#ifdef JITIFY
      create_jitify_program("kernels/clover_invert.cuh");
#endif
    }

    apply(0);
    if (compute_tr_log) {
      // Finalize the local reduction into the field's TrLog storage, then sum the
      // two entries across all processes.
      arg.complete(*clover.TrLog());
      comm_allreduce_array(clover.TrLog(), 2);
    }
  }

  /**
     @brief Tune (if needed) and launch the inversion kernel on the given stream.
     Only device-resident fields are supported; other locations raise an error.
     Without JITIFY, the trace-log + twist combination has no precompiled
     instantiation and raises an error.
  */
  void apply(const qudaStream_t &stream) {
    TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
    using Arg = decltype(arg);
    if (meta.Location() == QUDA_CUDA_FIELD_LOCATION) {
#ifdef JITIFY
      using namespace jitify::reflection;
      jitify_error = program->kernel("quda::cloverInvertKernel")
                       .instantiate((int)tp.block.x, Type<Arg>(), arg.compute_tr_log, arg.twist)
                       .configure(tp.grid, tp.block, tp.shared_bytes, stream)
                       .launch(arg);
#else
      if (arg.compute_tr_log) {
        if (arg.twist) {
          errorQuda("Not instantiated");
        } else {
          // The reduction path needs the block size as a compile-time parameter,
          // so the macro dispatches over the tuned tp.block.x.
          LAUNCH_KERNEL_LOCAL_PARITY(cloverInvertKernel, (*this), tp, stream, arg, Arg, true, false);
        }
      } else {
        // No reduction: the block-size template argument is fixed at 1
        // (presumably unused when compute_tr_log is false -- TODO confirm).
        if (arg.twist) {
          qudaLaunchKernel(cloverInvertKernel<1,Arg,false,true>, tp, stream, arg);
        } else {
          qudaLaunchKernel(cloverInvertKernel<1,Arg,false,false>, tp, stream, arg);
        }
      }
#endif
    } else {
      errorQuda("Not implemented");
    }
  }

  TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), aux); }

  // Flops are not counted for this kernel.
  long long flops() const { return 0; }

  // Bytes moved: clover read + inverse written, per site, times volumeCB, times 2
  // (presumably both parities).  The 2ll literal forces the product into 64-bit
  // arithmetic so it cannot overflow before conversion to the long long return.
  long long bytes() const { return 2ll * arg.clover.volumeCB * (arg.inverse.Bytes() + arg.clover.Bytes()); }

  // When the inverse is stored in place over the clover field, tuning may launch
  // the kernel several times; back up the storage before tuning and restore it
  // afterwards so the repeated launches do not corrupt the result.
  void preTune() { if (arg.clover.clover == arg.inverse.clover) arg.inverse.save(); }
  void postTune() { if (arg.clover.clover == arg.inverse.clover) arg.inverse.load(); }
};
74 
// Public entry point for clover-field inversion.  If QUDA was configured without
// GPU_CLOVER_DIRAC there is no kernel to dispatch to and an error is reported via
// errorQuda.  Otherwise the field is handed to instantiate<CloverInvert>, which
// presumably selects the CloverInvert<store_t> specialization matching the field's
// storage precision and constructs it; the constructor itself runs the inversion.
void cloverInvert(CloverField &clover, bool computeTraceLog)
{
#ifndef GPU_CLOVER_DIRAC
  errorQuda("Clover has not been built");
#else
  instantiate<CloverInvert>(clover, computeTraceLog);
#endif
}
84 
85 } // namespace quda