2 #include <quda_internal.h>
3 #include <color_spinor_field.h>
6 #include <contract_quda.h>
7 #include <jitify_helper.cuh>
8 #include <kernels/contraction.cuh>
// Launcher class for the contraction kernels. It derives from TunableVectorY and picks
// the kernel to run from the contraction type `cType`.
// NOTE(review): this listing is elided (the embedded original line numbers jump), so
// `arg`, `aux`, `program` and `jitify_error` are used below but declared outside the
// visible lines — presumably a member plus inherited/helper state; confirm in the full file.
13 template <typename real, typename Arg> class Contraction : TunableVectorY
// Input fields are held by const reference (not copied), together with the contraction type.
17 const ColorSpinorField &x;
18 const ColorSpinorField &y;
19 const QudaContractType cType;
22 bool tuneGridDim() const { return false; } // Don't tune the grid dimensions.
// Reports arg.threads as the minimum thread count.
23 unsigned int minThreads() const { return arg.threads; }
// Constructor: appends a tag for the contraction type to `aux` ("open," or
// "degrand-rossi,"), raises errorQuda for any other type, then appends x.AuxString().
// NOTE(review): the member-initializer list (original lines 27-33) is not visible here.
26 Contraction(Arg &arg, const ColorSpinorField &x, const ColorSpinorField &y, const QudaContractType cType) :
34 case QUDA_CONTRACT_TYPE_OPEN: strcat(aux, "open,"); break;
35 case QUDA_CONTRACT_TYPE_DR: strcat(aux, "degrand-rossi,"); break;
36 default: errorQuda("Unexpected contraction type %d", cType);
38 strcat(aux, x.AuxString());
// NOTE(review): presumably guarded by a preprocessor conditional in the elided lines — confirm.
40 create_jitify_program("kernels/contraction.cuh");
// Launches the contraction kernel on `stream`.
// When x.Location() is QUDA_CUDA_FIELD_LOCATION, launch parameters come from tuneLaunch().
// Two dispatch paths are visible below: one resolves the kernel by name and launches it
// through the jitify `program`, the other calls qudaLaunchKernel on the template kernel
// directly. NOTE(review): which path is compiled is presumably chosen by a preprocessor
// switch that is not visible in this listing.
// Both paths raise errorQuda for a contraction type other than OPEN or DR.
44 void apply(const qudaStream_t &stream)
46 if (x.Location() == QUDA_CUDA_FIELD_LOCATION) {
47 TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
// Fully qualified kernel name handed to the jitify program lookup.
49 std::string function_name;
51 case QUDA_CONTRACT_TYPE_OPEN: function_name = "quda::computeColorContraction"; break;
52 case QUDA_CONTRACT_TYPE_DR: function_name = "quda::computeDegrandRossiContraction"; break;
53 default: errorQuda("Unexpected contraction type %d", cType);
// Instantiate the named kernel for <real, Arg> and configure it with the tuned
// grid, block and shared-memory size on the given stream.
// NOTE(review): the end of this call chain (original line 60) is elided.
56 using namespace jitify::reflection;
57 jitify_error = program->kernel(function_name)
58 .instantiate(Type<real>(), Type<Arg>())
59 .configure(tp.grid, tp.block, tp.shared_bytes, stream)
// Direct launch of the statically instantiated kernel with the tuned parameters.
63 case QUDA_CONTRACT_TYPE_OPEN: qudaLaunchKernel(computeColorContraction<real, Arg>, tp, stream, arg); break;
64 case QUDA_CONTRACT_TYPE_DR: qudaLaunchKernel(computeDegrandRossiContraction<real, Arg>, tp, stream, arg); break;
65 default: errorQuda("Unexpected contraction type %d", cType);
// NOTE(review): presumably the branch taken when the field is not on the CUDA device
// (the `else` itself is elided) — confirm.
69 errorQuda("CPU not supported yet\n");
// Tuning key built from the volume string of x, the dynamic type name of this object and `aux`.
73 TuneKey tuneKey() const { return TuneKey(x.VolString(), typeid(*this).name(), aux); }
// Operation count reported for this launch: a per-site constant multiplied by x.Volume().
// The OPEN type uses 16*3*6 per site; the second return (reached for any other type)
// adds a further 16*(4+12) per site.
78 long long flops() const
80 if (cType == QUDA_CONTRACT_TYPE_OPEN)
81 return 16 * 3 * 6ll * x.Volume();
83 return ((16 * 3 * 6ll) + (16 * (4 + 12))) * x.Volume();
// Byte count reported for this launch: the sizes of x and y plus Nspin*Nspin
// complex<real> values per site of x (presumably the output buffer — confirm).
86 long long bytes() const
88 return x.Bytes() + y.Bytes() + x.Nspin() * x.Nspin() * x.Volume() * sizeof(complex<real>);
// Precision-templated helper: bundles x, y and the `result` pointer into a
// ContractionArg<real>, constructs the Contraction launcher for the requested
// contraction type, and synchronizes the device before returning.
// NOTE(review): the statement at original line 98 (between constructing the launcher
// and the synchronize) is elided in this listing — presumably the call that runs it; confirm.
92 template <typename real>
93 void contract_quda(const ColorSpinorField &x, const ColorSpinorField &y, complex<real> *result,
94 const QudaContractType cType)
96 ContractionArg<real> arg(x, y, result);
97 Contraction<real, ContractionArg<real>> contraction(arg, x, y, cType);
// Block the host until outstanding device work has finished.
99 qudaDeviceSynchronize();
// Type-erased entry point for contracting two color-spinor fields.
// Validates the inputs, then dispatches on the precision of x to the float or double
// instantiation of contract_quda, casting `result` to the matching complex<T> pointer.
// Raises errorQuda when:
//   - either field is not in the DeGrand-Rossi gamma basis,
//   - either field does not have 3 colors or 4 spins,
//   - the precision of x is neither single nor double.
104 void contractQuda(const ColorSpinorField &x, const ColorSpinorField &y, void *result, const QudaContractType cType)
// Precision check on the pair (x, y); only x.Precision() is consulted for dispatch below.
107 checkPrecision(x, y);
109 if (x.GammaBasis() != QUDA_DEGRAND_ROSSI_GAMMA_BASIS || y.GammaBasis() != QUDA_DEGRAND_ROSSI_GAMMA_BASIS)
110 errorQuda("Unexpected gamma basis x=%d y=%d", x.GammaBasis(), y.GammaBasis());
111 if (x.Ncolor() != 3 || y.Ncolor() != 3) errorQuda("Unexpected number of colors x=%d y=%d", x.Ncolor(), y.Ncolor());
112 if (x.Nspin() != 4 || y.Nspin() != 4) errorQuda("Unexpected number of spins x=%d y=%d", x.Nspin(), y.Nspin());
114 if (x.Precision() == QUDA_SINGLE_PRECISION) {
115 contract_quda<float>(x, y, (complex<float> *)result, cType);
116 } else if (x.Precision() == QUDA_DOUBLE_PRECISION) {
117 contract_quda<double>(x, y, (complex<double> *)result, cType);
119 errorQuda("Precision %d not supported", x.Precision());
// NOTE(review): presumably the fallback compiled when contraction support is disabled
// at build time (the surrounding preprocessor lines are elided) — confirm.
123 errorQuda("Contraction code has not been built");