9 using namespace dslash_aux;
12 #ifdef READ_SPINOR_SINGLE 13 #undef READ_SPINOR_SINGLE 20 #ifndef _TWIST_QUDA_CONTRACT 21 #error "Contraction core undefined" 24 #ifndef _TWIST_QUDA_CONTRACT_PLUS 25 #error "Contraction core (plus) undefined" 28 #ifndef _TWIST_QUDA_CONTRACT_MINUS 29 #error "Contraction core (minus) undefined" 32 #define checkSpinor(a, b) \ 34 if (a.Precision() != b.Precision()) \ 35 errorQuda("precisions do not match: %d %d", a.Precision(), b.Precision()); \ 36 if (a.Length() != b.Length()) \ 37 errorQuda("lengths do not match: %d %d", a.Length(), b.Length()); \ 38 if (a.Stride() != b.Stride()) \ 39 errorQuda("strides do not match: %d %d", a.Stride(), b.Stride()); \ 46 template <
typename Float2,
typename rFloat>
47 class ContractCuda :
public Tunable {
// Returns a fixed per-thread byte count: the size of 16 values of the
// result element type rFloat (16 * sizeof(rFloat)).
// NOTE(review): presumably consumed by the Tunable base class when sizing the
// dynamic shared memory passed to the kernel launches (tp.shared_bytes) -- confirm.
81 unsigned int sharedBytesPerThread()
const {
return 16*
sizeof(rFloat); }
// Always returns 0; the TuneParam argument is accepted but not read.
// NOTE(review): presumably means no additional per-block shared memory is
// requested on top of the per-thread amount -- confirm against Tunable.
82 unsigned int sharedBytesPerBlock(
const TuneParam &
param)
const {
return 0; }
// Always returns false.
// NOTE(review): presumably tells the autotuner not to vary the grid
// dimension for these kernels -- confirm against Tunable.
83 bool tuneGridDim()
const {
return false; }
// Returns the product of the four extents x.X(0) .. x.X(3) of the field x.
// NOTE(review): the product is evaluated in whatever type X() returns and only
// then converted to unsigned int -- check for overflow on very large lattices.
// Presumably this is the minimum thread count the tuner may launch -- confirm.
84 unsigned int minThreads()
const {
return x.X(0) *
x.X(1) *
x.X(2) *
x.X(3); }
86 char *saveOut, *saveOutNorm;
// Copies the NUL-terminated string contract_str into the aux entry indexed by
// contract_type (aux[contract_type] is later passed to TuneKey in tuneKey()).
//
// contract_type  index of the aux slot to overwrite
// contract_str   source string; must be NUL-terminated
//
// NOTE(review): strcpy performs no bounds check and the declared size of the
// aux entries is not visible here -- verify that every string passed in fits,
// including the terminating NUL, and that contract_type is a valid index.
88 void fillAux(
QudaContractType contract_type,
const char *contract_str) {
strcpy(aux[contract_type], contract_str); }
92 x(
x),
y(
y), result(result),
parity(
parity), contract_type(contract_type), nTSlice(-1) {
104 dslashParam.
dc =
y.getDslashConstant();
105 bindSpinorTex<Float2>(&
x, &
y);
108 ContractCuda(
const cudaColorSpinorField &
x,
const cudaColorSpinorField &
y,
void *result,
const QudaParity parity,
const QudaContractType contract_type,
const int tSlice) :
109 x(
x),
y(
y), result(result),
parity(
parity), contract_type(contract_type), nTSlice(tSlice) {
122 dslashParam.
Vsh = (
x.X(0)*
x.X(1)*
x.X(2)) /
x.SiteSubset();
123 dslashParam.
dc =
y.getDslashConstant();
// Virtual destructor: calls unbindSpinorTex<Float2> on the two input fields
// x and y, the counterpart of the bindSpinorTex<Float2>(&x, &y) call made
// during construction.
// NOTE(review): nothing else is released here; whether saveOut / saveOutNorm
// need freeing at this point cannot be determined from this block -- confirm.
126 virtual ~ContractCuda() { unbindSpinorTex<Float2>(&
x, &
y); }
130 TuneKey tuneKey()
const 132 return TuneKey(
x.VolString(),
typeid(*this).name(), aux[contract_type]);
135 void apply(
const cudaStream_t &
stream)
138 switch (contract_type)
142 contractGamma5Kernel<<<tp.grid, tp.block, tp.shared_bytes>>>((rFloat*)result, (Float2*)
x.V(), (Float2*)
y.V(),
x.Stride(),
parity, dslashParam);
154 contractKernel<<<tp.grid, tp.block, tp.shared_bytes>>>((rFloat*)result, (Float2*)
x.V(), (Float2*)
y.V(),
x.Stride(),
parity, dslashParam);
158 contractPlusKernel<<<tp.grid, tp.block, tp.shared_bytes>>>((rFloat*)result, (Float2*)
x.V(), (Float2*)
y.V(),
x.Stride(),
parity, dslashParam);
162 contractMinusKernel<<<tp.grid, tp.block, tp.shared_bytes>>>((rFloat*)result, (Float2*)
x.V(), (Float2*)
y.V(),
x.Stride(),
parity, dslashParam);
166 contractTsliceKernel<<<tp.grid, tp.block, tp.shared_bytes>>>((rFloat*)result, (Float2*)
x.V(), (Float2*)
y.V(),
x.Stride(), nTSlice,
parity, dslashParam);
180 saveOut =
new char[dslashParam.
threads*
sizeof(Float2)*32];
181 cudaMemcpy(saveOut, result, dslashParam.
threads*
sizeof(Float2)*32, cudaMemcpyDeviceToHost);
185 cudaMemcpy(result, saveOut, dslashParam.
threads*
sizeof(Float2)*32, cudaMemcpyHostToDevice);
// Returns 120 * x.VolumeCB() as a long long; the 120ll literal forces the
// multiplication to be carried out in 64-bit arithmetic.
// NOTE(review): 120 is presumably the floating-point operation count per
// (checkerboard) site for one contraction -- confirm the figure.
189 long long flops()
const {
return 120ll *
x.VolumeCB(); }
// Returns the sum of the data and norm byte counts of both input fields:
// x.Bytes() + x.NormBytes() + y.Bytes() + y.NormBytes().
// NOTE(review): the output buffer (result) is not included in this total;
// presumably this is a memory-traffic estimate used for performance
// reporting -- confirm whether omitting the write traffic is intentional.
190 long long bytes()
const {
return x.Bytes() +
x.NormBytes() +
y.Bytes() +
y.NormBytes(); }
206 errorQuda(
"No time-slice specified for contraction\n");
216 contract =
new ContractCuda<double2,double2>(
x,
y, result,
parity, contract_type);
218 contract =
new ContractCuda<float4,float2>(
x,
y, result,
parity, contract_type);
220 errorQuda(
"Half precision not supported for gamma5 kernel yet");
237 errorQuda(
"Contraction code has not been built");
253 errorQuda(
"No time-slice input allowed for volume contractions\n");
263 contract =
new ContractCuda<double2,double2>(
x,
y, result,
parity, contract_type, nTSlice);
265 contract =
new ContractCuda<float4,float2>(
x,
y, result,
parity, contract_type, nTSlice);
267 errorQuda(
"Half precision not supported for gamma5 kernel yet");
283 errorQuda(
"Contraction code has not been built");
void contract(const cudaColorSpinorField x, const cudaColorSpinorField y, void *ctrn, const QudaContractType cType)
__global__ void contractKernel(double2 *out, double2 *in1, double2 *in2, int myStride, const int Parity, const DslashParam param)
__global__ void contractTslicePlusKernel(double2 *out, double2 *in1, double2 *in2, int myStride, const int Tslice, const int Parity, const DslashParam param)
QudaVerbosity getVerbosity()
char * strcpy(char *__dst, const char *__src)
__global__ void contractGamma5MinusKernel(double2 *out, double2 *in1, double2 *in2, int myStride, const int Parity, const DslashParam param)
cudaError_t qudaStreamSynchronize(cudaStream_t &stream)
Wrapper around cudaStreamSynchronize or cuStreamSynchronize.
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
__global__ void contractMinusKernel(double2 *out, double2 *in1, double2 *in2, int myStride, const int Parity, const DslashParam param)
__global__ void contractTsliceKernel(double2 *out, double2 *in1, double2 *in2, int myStride, const int Tslice, const int Parity, const DslashParam param)
enum QudaParity_s QudaParity
__global__ void contractGamma5PlusKernel(double2 *out, double2 *in1, double2 *in2, int myStride, const int Parity, const DslashParam param)
void contractCuda(const cudaColorSpinorField &x, const cudaColorSpinorField &y, void *result, const QudaContractType contract_type, const QudaParity parity, TimeProfile &profile)
__global__ void contractPlusKernel(double2 *out, double2 *in1, double2 *in2, int myStride, const int Parity, const DslashParam param)
enum QudaContractType_s QudaContractType
__global__ void contractGamma5Kernel(double2 *out, double2 *in1, double2 *in2, int myStride, const int Parity, const DslashParam param)
QudaTune getTuning()
Query whether autotuning is enabled or not. Default is enabled but can be overridden by setting QUDA_...
__global__ void contractTsliceMinusKernel(double2 *out, double2 *in1, double2 *in2, int myStride, const int Tslice, const int Parity, const DslashParam param)