QUDA  v1.1.0
A library for QCD on GPUs
contract.cu
Go to the documentation of this file.
1 #include <tune_quda.h>
2 #include <quda_internal.h>
3 #include <color_spinor_field.h>
4 #include <blas_quda.h>
5 
6 #include <contract_quda.h>
7 #include <jitify_helper.cuh>
8 #include <kernels/contraction.cuh>
9 
10 namespace quda {
11 
12 #ifdef GPU_CONTRACT
13  template <typename real, typename Arg> class Contraction : TunableVectorY
14  {
15 protected:
16  Arg &arg;
17  const ColorSpinorField &x;
18  const ColorSpinorField &y;
19  const QudaContractType cType;
20 
21 private:
22  bool tuneGridDim() const { return false; } // Don't tune the grid dimensions.
23  unsigned int minThreads() const { return arg.threads; }
24 
25 public:
26  Contraction(Arg &arg, const ColorSpinorField &x, const ColorSpinorField &y, const QudaContractType cType) :
27  TunableVectorY(2),
28  arg(arg),
29  x(x),
30  y(y),
31  cType(cType)
32  {
33  switch (cType) {
34  case QUDA_CONTRACT_TYPE_OPEN: strcat(aux, "open,"); break;
35  case QUDA_CONTRACT_TYPE_DR: strcat(aux, "degrand-rossi,"); break;
36  default: errorQuda("Unexpected contraction type %d", cType);
37  }
38  strcat(aux, x.AuxString());
39 #ifdef JITIFY
40  create_jitify_program("kernels/contraction.cuh");
41 #endif
42  }
43 
44  void apply(const qudaStream_t &stream)
45  {
46  if (x.Location() == QUDA_CUDA_FIELD_LOCATION) {
47  TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
48 #ifdef JITIFY
49  std::string function_name;
50  switch (cType) {
51  case QUDA_CONTRACT_TYPE_OPEN: function_name = "quda::computeColorContraction"; break;
52  case QUDA_CONTRACT_TYPE_DR: function_name = "quda::computeDegrandRossiContraction"; break;
53  default: errorQuda("Unexpected contraction type %d", cType);
54  }
55 
56  using namespace jitify::reflection;
57  jitify_error = program->kernel(function_name)
58  .instantiate(Type<real>(), Type<Arg>())
59  .configure(tp.grid, tp.block, tp.shared_bytes, stream)
60  .launch(arg);
61 #else
62  switch (cType) {
63  case QUDA_CONTRACT_TYPE_OPEN: qudaLaunchKernel(computeColorContraction<real, Arg>, tp, stream, arg); break;
64  case QUDA_CONTRACT_TYPE_DR: qudaLaunchKernel(computeDegrandRossiContraction<real, Arg>, tp, stream, arg); break;
65  default: errorQuda("Unexpected contraction type %d", cType);
66  }
67 #endif
68  } else {
69  errorQuda("CPU not supported yet\n");
70  }
71  }
72 
73  TuneKey tuneKey() const { return TuneKey(x.VolString(), typeid(*this).name(), aux); }
74 
75  void preTune() {}
76  void postTune() {}
77 
78  long long flops() const
79  {
80  if (cType == QUDA_CONTRACT_TYPE_OPEN)
81  return 16 * 3 * 6ll * x.Volume();
82  else
83  return ((16 * 3 * 6ll) + (16 * (4 + 12))) * x.Volume();
84  }
85 
86  long long bytes() const
87  {
88  return x.Bytes() + y.Bytes() + x.Nspin() * x.Nspin() * x.Volume() * sizeof(complex<real>);
89  }
90  };
91 
92  template <typename real>
93  void contract_quda(const ColorSpinorField &x, const ColorSpinorField &y, complex<real> *result,
94  const QudaContractType cType)
95  {
96  ContractionArg<real> arg(x, y, result);
97  Contraction<real, ContractionArg<real>> contraction(arg, x, y, cType);
98  contraction.apply(0);
99  qudaDeviceSynchronize();
100  }
101 
102 #endif
103 
104  void contractQuda(const ColorSpinorField &x, const ColorSpinorField &y, void *result, const QudaContractType cType)
105  {
106 #ifdef GPU_CONTRACT
107  checkPrecision(x, y);
108 
109  if (x.GammaBasis() != QUDA_DEGRAND_ROSSI_GAMMA_BASIS || y.GammaBasis() != QUDA_DEGRAND_ROSSI_GAMMA_BASIS)
110  errorQuda("Unexpected gamma basis x=%d y=%d", x.GammaBasis(), y.GammaBasis());
111  if (x.Ncolor() != 3 || y.Ncolor() != 3) errorQuda("Unexpected number of colors x=%d y=%d", x.Ncolor(), y.Ncolor());
112  if (x.Nspin() != 4 || y.Nspin() != 4) errorQuda("Unexpected number of spins x=%d y=%d", x.Nspin(), y.Nspin());
113 
114  if (x.Precision() == QUDA_SINGLE_PRECISION) {
115  contract_quda<float>(x, y, (complex<float> *)result, cType);
116  } else if (x.Precision() == QUDA_DOUBLE_PRECISION) {
117  contract_quda<double>(x, y, (complex<double> *)result, cType);
118  } else {
119  errorQuda("Precision %d not supported", x.Precision());
120  }
121 
122 #else
123  errorQuda("Contraction code has not been built");
124 #endif
125  }
126 } // namespace quda