QUDA  1.0.0
contract.cu
Go to the documentation of this file.
1 #include <tune_quda.h>
2 #include <quda_internal.h>
3 #include <color_spinor_field.h>
4 #include <blas_quda.h>
5 
6 #include <contract_quda.h>
7 #include <jitify_helper.cuh>
9 
10 namespace quda {
11 
12 #ifdef GPU_CONTRACT
13  template <typename real, typename Arg> class Contraction : TunableVectorY
14  {
15 protected:
16  Arg &arg;
17  const ColorSpinorField &x;
18  const ColorSpinorField &y;
19  const QudaContractType cType;
20 
21 private:
22  bool tuneGridDim() const { return false; } // Don't tune the grid dimensions.
23  unsigned int minThreads() const { return arg.threads; }
24 
25 public:
26  Contraction(Arg &arg, const ColorSpinorField &x, const ColorSpinorField &y, const QudaContractType cType) :
27  TunableVectorY(2),
28  arg(arg),
29  x(x),
30  y(y),
31  cType(cType)
32  {
33  switch (cType) {
34  case QUDA_CONTRACT_TYPE_OPEN: strcat(aux, "open,"); break;
35  case QUDA_CONTRACT_TYPE_DR: strcat(aux, "degrand-rossi,"); break;
36  default: errorQuda("Unexpected contraction type %d", cType);
37  }
38  strcat(aux, x.AuxString());
39 #ifdef JITIFY
40  create_jitify_program("kernels/contraction.cuh");
41 #endif
42  }
43  virtual ~Contraction() {}
44 
45  void apply(const cudaStream_t &stream)
46  {
47  if (x.Location() == QUDA_CUDA_FIELD_LOCATION) {
48  TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
49 #ifdef JITIFY
50  std::string function_name;
51  switch (cType) {
52  case QUDA_CONTRACT_TYPE_OPEN: function_name = "quda::computeColorContraction"; break;
53  case QUDA_CONTRACT_TYPE_DR: function_name = "quda::computeDegrandRossiContraction"; break;
54  default: errorQuda("Unexpected contraction type %d", cType);
55  }
56 
57  using namespace jitify::reflection;
58  jitify_error = program->kernel(function_name)
59  .instantiate(Type<real>(), Type<Arg>())
60  .configure(tp.grid, tp.block, tp.shared_bytes, stream)
61  .launch(arg);
62 #else
63  switch (cType) {
64  case QUDA_CONTRACT_TYPE_OPEN: computeColorContraction<real><<<tp.grid, tp.block, tp.shared_bytes>>>(arg); break;
66  computeDegrandRossiContraction<real><<<tp.grid, tp.block, tp.shared_bytes>>>(arg);
67  break;
68  default: errorQuda("Unexpected contraction type %d", cType);
69  }
70 #endif
71  } else {
72  errorQuda("CPU not supported yet\n");
73  }
74  }
75 
76  TuneKey tuneKey() const { return TuneKey(x.VolString(), typeid(*this).name(), aux); }
77 
78  void preTune() {}
79  void postTune() {}
80 
81  long long flops() const
82  {
83  if (cType == QUDA_CONTRACT_TYPE_OPEN)
84  return 16 * 3 * 6ll * x.Volume();
85  else
86  return ((16 * 3 * 6ll) + (16 * (4 + 12))) * x.Volume();
87  }
88 
89  long long bytes() const
90  {
91  return x.Bytes() + y.Bytes() + x.Nspin() * x.Nspin() * x.Volume() * sizeof(complex<real>);
92  }
93  };
94 
95  template <typename real>
96  void contract_quda(const ColorSpinorField &x, const ColorSpinorField &y, complex<real> *result,
97  const QudaContractType cType)
98  {
99  ContractionArg<real> arg(x, y, result);
100  Contraction<real, ContractionArg<real>> contraction(arg, x, y, cType);
101  contraction.apply(0);
103  }
104 
105 #endif
106 
107  void contractQuda(const ColorSpinorField &x, const ColorSpinorField &y, void *result, const QudaContractType cType)
108  {
109 #ifdef GPU_CONTRACT
110  checkPrecision(x, y);
111 
113  errorQuda("Unexpected gamma basis x=%d y=%d", x.GammaBasis(), y.GammaBasis());
114  if (x.Ncolor() != 3 || y.Ncolor() != 3) errorQuda("Unexpected number of colors x=%d y=%d", x.Ncolor(), y.Ncolor());
115  if (x.Nspin() != 4 || y.Nspin() != 4) errorQuda("Unexpected number of spins x=%d y=%d", x.Nspin(), y.Nspin());
116 
117  if (x.Precision() == QUDA_SINGLE_PRECISION) {
118  contract_quda<float>(x, y, (complex<float> *)result, cType);
119  } else if (x.Precision() == QUDA_DOUBLE_PRECISION) {
120  contract_quda<double>(x, y, (complex<double> *)result, cType);
121  } else {
122  errorQuda("Precision %d not supported", x.Precision());
123  }
124 
125 #else
126  errorQuda("Contraction code has not been built");
127 #endif
128  }
129 } // namespace quda
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:21
#define checkPrecision(...)
#define errorQuda(...)
Definition: util_quda.h:121
Helper file when using jitify run-time compilation. This file should be included in source code...
cudaStream_t * stream
QudaGammaBasis GammaBasis() const
#define qudaDeviceSynchronize()
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
Definition: tune.cpp:643
void contractQuda(const ColorSpinorField &x, const ColorSpinorField &y, void *result, QudaContractType cType)
Definition: contract.cu:107
unsigned long long flops
Definition: blas_quda.cu:22
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
enum QudaContractType_s QudaContractType
QudaTune getTuning()
Query whether autotuning is enabled or not. Default is enabled but can be overridden by setting QUDA_...
Definition: util_quda.cpp:52
QudaPrecision Precision() const
unsigned long long bytes
Definition: blas_quda.cu:23