QUDA  1.0.0
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
clover_invert.cu
Go to the documentation of this file.
1 #include <tune_quda.h>
2 #include <clover_field.h>
3 #include <launch_kernel.cuh>
4 
5 #include <jitify_helper.cuh>
7 
8 namespace quda {
9 
10  using namespace clover;
11 
12 #ifdef GPU_CLOVER_DIRAC
13 
/**
   Tunable wrapper that inverts every clover matrix of a field and,
   when Arg::computeTraceLog is set, accumulates a trace-log reduction
   into arg.result_h[0].

   Fields in CUDA memory are processed by cloverInvertKernel (run-time
   compiled when JITIFY is defined); any other location is handled on
   the host by the cloverInvert<Float, Arg, trlog, twist> template.
 */
template <typename Float, typename Arg>
class CloverInvert : TunableLocalParity {
  Arg arg;                 // kernel argument, copied from the caller's Arg
  const CloverField &meta; // used for meta data only

private:
  bool tuneGridDim() const { return true; }

  // Host path: pick the instantiation matching the run-time flags.
  void invertOnHost()
  {
    if (arg.computeTraceLog && arg.twist) {
      cloverInvert<Float, Arg, true, true>(arg);
    } else if (arg.computeTraceLog) {
      cloverInvert<Float, Arg, true, false>(arg);
    } else if (arg.twist) {
      cloverInvert<Float, Arg, false, true>(arg);
    } else {
      cloverInvert<Float, Arg, false, false>(arg);
    }
  }

public:
  CloverInvert(Arg &arg, const CloverField &meta) : arg(arg), meta(meta)
  {
    const char *trlog = arg.computeTraceLog ? "true" : "false";
    const char *twist = arg.twist ? "true" : "false";
    writeAuxString("stride=%d,prec=%lu,trlog=%s,twist=%s", arg.clover.stride, sizeof(Float), trlog, twist);
#ifdef JITIFY
    // only device-resident fields need the run-time compiled kernel
    if (meta.Location() == QUDA_CUDA_FIELD_LOCATION) create_jitify_program("kernels/clover_invert.cuh");
#endif
  }

  virtual ~CloverInvert() {}

  /**
     Zero the trace-log accumulator and run the inversion for the
     field's location.
     @param stream CUDA stream used for the device launch
   */
  void apply(const cudaStream_t &stream)
  {
    TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
    arg.result_h[0] = make_double2(0., 0.);

    if (meta.Location() != QUDA_CUDA_FIELD_LOCATION) {
      invertOnHost();
      return;
    }

#ifdef JITIFY
    using namespace jitify::reflection;
    jitify_error = program->kernel("quda::cloverInvertKernel")
                     .instantiate((int)tp.block.x, Type<Float>(), Type<Arg>(), arg.computeTraceLog, arg.twist)
                     .configure(tp.grid, tp.block, tp.shared_bytes, stream)
                     .launch(arg);
#else
    if (!arg.computeTraceLog) {
      // The first template argument looks like the block size (the JITIFY
      // path passes tp.block.x); without the trace-log reduction it is
      // fixed at 1.
      if (arg.twist) {
        cloverInvertKernel<1, Float, Arg, false, true> <<<tp.grid, tp.block, tp.shared_bytes, stream>>>(arg);
      } else {
        cloverInvertKernel<1, Float, Arg, false, false> <<<tp.grid, tp.block, tp.shared_bytes, stream>>>(arg);
      }
    } else if (!arg.twist) {
      LAUNCH_KERNEL_LOCAL_PARITY(cloverInvertKernel, tp, stream, arg, Float, Arg, true, false);
    } else {
      // trace log + twist has no non-JITIFY device instantiation
      errorQuda("Not instantiated");
    }
#endif
  }

  TuneKey tuneKey() const
  {
    const char *volume = meta.VolString();
    const char *name = typeid(*this).name();
    return TuneKey(volume, name, aux);
  }

  // no floating-point operation count is reported for this kernel
  long long flops() const { return 0; }

  // The factor 2 presumably covers both parities (volumeCB is a
  // checkerboard volume) — NOTE(review): confirm.
  long long bytes() const { return 2 * arg.clover.volumeCB * (arg.inverse.Bytes() + arg.clover.Bytes()); }

  // When clover and inverse share storage the inversion is in place, so
  // preserve the original data across the repeated tuning launches.
  void preTune()
  {
    if (arg.clover.clover == arg.inverse.clover) arg.inverse.save();
  }

  void postTune()
  {
    if (arg.clover.clover == arg.inverse.clover) arg.inverse.load();
  }
};
88 
/**
   Invert every clover matrix of a field at precision Float and, if
   requested, store the globally-reduced trace log in clover.TrLog().

   @param clover[in,out] Clover field to invert
   @param computeTraceLog Whether to accumulate the trace log
 */
template <typename Float>
void cloverInvert(CloverField &clover, bool computeTraceLog)
{
  CloverInvertArg<Float> arg(clover, computeTraceLog);
  CloverInvert<Float, CloverInvertArg<Float>> invert(arg, clover);
  invert.apply(0);

  if (arg.computeTraceLog) {
    // apply() launches the device kernel asynchronously; the reduction
    // result in arg.result_h is only valid once that work has completed,
    // so synchronize before reading it on the host.
    qudaDeviceSynchronize();

    // sum the two double components of the result across all ranks
    comm_allreduce_array((double *)arg.result_h, 2);
    clover.TrLog()[0] = arg.result_h[0].x;
    clover.TrLog()[1] = arg.result_h[0].y;
  }
}
102 
103 #endif
104 
/**
   Public entry point: invert the clover field in place (or into its
   inverse storage) at the field's own precision, optionally computing
   the trace log. Double and single precision are dispatched to the
   templated implementation; any other precision is rejected.

   @param clover[in,out] Clover field to invert
   @param computeTraceLog Whether to accumulate the trace log
 */
void cloverInvert(CloverField &clover, bool computeTraceLog)
{
#ifdef GPU_CLOVER_DIRAC
  const QudaPrecision precision = clover.Precision();

  if (precision == QUDA_HALF_PRECISION && clover.Order() > 4)
    errorQuda("Half precision not supported for order %d", clover.Order());

  switch (precision) {
  case QUDA_DOUBLE_PRECISION: cloverInvert<double>(clover, computeTraceLog); break;
  case QUDA_SINGLE_PRECISION: cloverInvert<float>(clover, computeTraceLog); break;
  default: errorQuda("Precision %d not supported", precision);
  }
#else
  errorQuda("Clover has not been built");
#endif
}
123 
124 } // namespace quda
double * TrLog() const
Definition: clover_field.h:88
#define LAUNCH_KERNEL_LOCAL_PARITY(kernel, tp, stream, arg,...)
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:21
#define errorQuda(...)
Definition: util_quda.h:121
Helper file when using jitify run-time compilation. This file should be included in source code...
cudaStream_t * stream
void comm_allreduce_array(double *data, size_t size)
Definition: comm_mpi.cpp:272
const char * VolString() const
QudaCloverFieldOrder Order() const
Definition: clover_field.h:93
void cloverInvert(CloverField &clover, bool computeTraceLog)
This function computes the Cholesky decomposition of each clover matrix and stores the clover inverse ...
#define qudaDeviceSynchronize()
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
Definition: tune.cpp:643
__global__ void cloverInvertKernel(Arg arg)
QudaFieldLocation Location() const
unsigned long long flops
Definition: blas_quda.cu:22
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
const int volumeCB
Definition: spinor_noise.cu:26
QudaTune getTuning()
Query whether autotuning is enabled or not. Default is enabled but can be overridden by setting QUDA_...
Definition: util_quda.cpp:52
QudaPrecision Precision() const
unsigned long long bytes
Definition: blas_quda.cu:23