QUDA  0.9.0
contract.cu
Go to the documentation of this file.
1 namespace quda {
2 
3  namespace dslash_aux {
4  #include <dslash_constants.h>
5  #include <dslash_textures.h>
6  #include <io_spinor.h>
7  }
8 
9  using namespace dslash_aux;
10 
11 #ifdef GPU_CONTRACT
12 #ifdef READ_SPINOR_SINGLE
13 #undef READ_SPINOR_SINGLE
14 #endif
15 
16 #include "contract_core.h"
17 #include "contract_core_plus.h"
18 #include "contract_core_minus.h"
19 
20 #ifndef _TWIST_QUDA_CONTRACT
21 #error "Contraction core undefined"
22 #endif
23 
24 #ifndef _TWIST_QUDA_CONTRACT_PLUS
25 #error "Contraction core (plus) undefined"
26 #endif
27 
28 #ifndef _TWIST_QUDA_CONTRACT_MINUS
29 #error "Contraction core (minus) undefined"
30 #endif
31 
/**
 * Check that two spinor fields agree in precision, length and stride,
 * reporting the first mismatch found through errorQuda.
 *
 * Both arguments are parenthesized so that expression arguments (e.g. `*ptr`
 * or `cond ? u : v`) expand with the intended precedence.  Each argument is
 * evaluated several times, so do not pass expressions with side effects.
 *
 * The expansion is deliberately kept as a plain brace block so that existing
 * call sites, with or without a trailing semicolon, keep compiling.
 */
#define checkSpinor(a, b)                                                              \
  {                                                                                    \
    if ((a).Precision() != (b).Precision())                                            \
      errorQuda("precisions do not match: %d %d", (a).Precision(), (b).Precision());   \
    if ((a).Length() != (b).Length())                                                  \
      errorQuda("lengths do not match: %d %d", (a).Length(), (b).Length());            \
    if ((a).Stride() != (b).Stride())                                                  \
      errorQuda("strides do not match: %d %d", (a).Stride(), (b).Stride());            \
  }
41 
46  template <typename Float2, typename rFloat>
47  class ContractCuda : public Tunable {
48 
49  private:
50  DslashParam dslashParam;
51  const cudaColorSpinorField x; // Spinor to be contracted
52  const cudaColorSpinorField y; // Spinor to be contracted
53  const QudaParity parity; // Parity of the field, actual kernels act on parity spinors
54  const QudaContractType contract_type; // Type of contraction, to be detailed later
55 
75  void *result; // The output array with the result of the contraction
76 
77  const int nTSlice; // Time-slice in case of time-dilution
78 
79  char aux[16][TuneKey::aux_n]; // For tuning purposes
80 
81  unsigned int sharedBytesPerThread() const { return 16*sizeof(rFloat); }
82  unsigned int sharedBytesPerBlock(const TuneParam &param) const { return 0; }
83  bool tuneGridDim() const { return false; } // Don't tune the grid dimensions.
84  unsigned int minThreads() const { return x.X(0) * x.X(1) * x.X(2) * x.X(3); }
85 
86  char *saveOut, *saveOutNorm;
87 
88  void fillAux(QudaContractType contract_type, const char *contract_str) { strcpy(aux[contract_type], contract_str); }
89 
90  public:
91  ContractCuda(const cudaColorSpinorField &x, const cudaColorSpinorField &y, void *result, const QudaParity parity, const QudaContractType contract_type) :
92  x(x), y(y), result(result), parity(parity), contract_type(contract_type), nTSlice(-1) {
93  fillAux(QUDA_CONTRACT, "type=plain");
94  fillAux(QUDA_CONTRACT_PLUS, "type=plain-plus");
95  fillAux(QUDA_CONTRACT_MINUS, "type=plain-minus");
96  fillAux(QUDA_CONTRACT_GAMMA5, "type=gamma5");
97  fillAux(QUDA_CONTRACT_GAMMA5_PLUS, "type=gamma5-plus");
98  fillAux(QUDA_CONTRACT_GAMMA5_MINUS, "type=gamma5-minus");
99  fillAux(QUDA_CONTRACT_TSLICE, "type=tslice");
100  fillAux(QUDA_CONTRACT_TSLICE_PLUS, "type=tslice-plus");
101  fillAux(QUDA_CONTRACT_TSLICE_MINUS, "type=tslice-minus");
102 
103  dslashParam.threads = x.Volume();
104  dslashParam.dc = y.getDslashConstant();
105  bindSpinorTex<Float2>(&x, &y);
106  }
107 
108  ContractCuda(const cudaColorSpinorField &x, const cudaColorSpinorField &y, void *result, const QudaParity parity, const QudaContractType contract_type, const int tSlice) :
109  x(x), y(y), result(result), parity(parity), contract_type(contract_type), nTSlice(tSlice) {
110  fillAux(QUDA_CONTRACT, "type=plain");
111  fillAux(QUDA_CONTRACT_PLUS, "type=plain-plus");
112  fillAux(QUDA_CONTRACT_MINUS, "type=plain-minus");
113  fillAux(QUDA_CONTRACT_GAMMA5, "type=gamma5");
114  fillAux(QUDA_CONTRACT_GAMMA5_PLUS, "type=gamma5-plus");
115  fillAux(QUDA_CONTRACT_GAMMA5_MINUS, "type=gamma5-minus");
116  fillAux(QUDA_CONTRACT_TSLICE, "type=tslice");
117  fillAux(QUDA_CONTRACT_TSLICE_PLUS, "type=tslice-plus");
118  fillAux(QUDA_CONTRACT_TSLICE_MINUS, "type=tslice-minus");
119 
120  DslashParam dslashParam;
121  dslashParam.threads = x.X(0)*x.X(1)*x.X(2);
122  dslashParam.Vsh = (x.X(0)*x.X(1)*x.X(2)) / x.SiteSubset();
123  dslashParam.dc = y.getDslashConstant();
124  }
125 
126  virtual ~ContractCuda() { unbindSpinorTex<Float2>(&x, &y); } // if (tSlice != NULL) { cudaFreeHost(tSlice); } }
127 
128  QudaContractType ContractType() const { return contract_type; }
129 
130  TuneKey tuneKey() const
131  {
132  return TuneKey(x.VolString(), typeid(*this).name(), aux[contract_type]);
133  }
134 
135  void apply(const cudaStream_t &stream)
136  {
137  TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
138  switch (contract_type)
139  {
140  default:
141  case QUDA_CONTRACT_GAMMA5: // Calculates the volume contraction (x^+ g5)_\mu y_\nu and stores it in result
142  contractGamma5Kernel<<<tp.grid, tp.block, tp.shared_bytes>>>((rFloat*)result, (Float2*)x.V(), (Float2*)y.V(), x.Stride(), parity, dslashParam);
143  break;
144 
145  case QUDA_CONTRACT_GAMMA5_PLUS: // Calculates the volume contraction (x^+ g5)_\mu y_\nu and adds it to result
146  contractGamma5PlusKernel<<<tp.grid, tp.block, tp.shared_bytes>>>((rFloat*)result, (Float2*)x.V(), (Float2*)y.V(), x.Stride(), parity, dslashParam);
147  break;
148 
149  case QUDA_CONTRACT_GAMMA5_MINUS: // Calculates the volume contraction (x^+ g5)_\mu y_\nu and substracts it from result
150  contractGamma5MinusKernel<<<tp.grid, tp.block, tp.shared_bytes>>>((rFloat*)result, (Float2*)x.V(), (Float2*)y.V(), x.Stride(), parity, dslashParam);
151  break;
152 
153  case QUDA_CONTRACT: // Calculates the volume contraction x^+_\mu y_\nu and stores it in result
154  contractKernel<<<tp.grid, tp.block, tp.shared_bytes>>>((rFloat*)result, (Float2*)x.V(), (Float2*)y.V(), x.Stride(), parity, dslashParam);
155  break;
156 
157  case QUDA_CONTRACT_PLUS: // Calculates the volume contraction x^+_\mu y_\nu and adds it to result
158  contractPlusKernel<<<tp.grid, tp.block, tp.shared_bytes>>>((rFloat*)result, (Float2*)x.V(), (Float2*)y.V(), x.Stride(), parity, dslashParam);
159  break;
160 
161  case QUDA_CONTRACT_MINUS: // Calculates the volume contraction x^+_\mu y_\nu and substracts it from result
162  contractMinusKernel<<<tp.grid, tp.block, tp.shared_bytes>>>((rFloat*)result, (Float2*)x.V(), (Float2*)y.V(), x.Stride(), parity, dslashParam);
163  break;
164 
165  case QUDA_CONTRACT_TSLICE: // Calculates the time-slice contraction x^+_\mu y_\nu and stores it in result
166  contractTsliceKernel<<<tp.grid, tp.block, tp.shared_bytes>>>((rFloat*)result, (Float2*)x.V(), (Float2*)y.V(), x.Stride(), nTSlice, parity, dslashParam);
167  break;
168 
169  case QUDA_CONTRACT_TSLICE_PLUS: // Calculates the time-slice contraction x^+_\mu y_\nu and adds it to result
170  contractTslicePlusKernel<<<tp.grid, tp.block, tp.shared_bytes>>>((rFloat*)result, (Float2*)x.V(), (Float2*)y.V(), x.Stride(), nTSlice, parity, dslashParam);
171  break;
172 
173  case QUDA_CONTRACT_TSLICE_MINUS: // Calculates the time-slice contraction x^+_\mu y_\nu and substracts it from result
174  contractTsliceMinusKernel<<<tp.grid, tp.block, tp.shared_bytes>>>((rFloat*)result, (Float2*)x.V(), (Float2*)y.V(), x.Stride(), nTSlice, parity, dslashParam);
175  break;
176  }
177  }
178 
179  void preTune() {
180  saveOut = new char[dslashParam.threads*sizeof(Float2)*32];
181  cudaMemcpy(saveOut, result, dslashParam.threads*sizeof(Float2)*32, cudaMemcpyDeviceToHost);
182  }
183 
184  void postTune() {
185  cudaMemcpy(result, saveOut, dslashParam.threads*sizeof(Float2)*32, cudaMemcpyHostToDevice);
186  delete[] saveOut;
187  }
188 
189  long long flops() const { return 120ll * x.VolumeCB(); }
190  long long bytes() const { return x.Bytes() + x.NormBytes() + y.Bytes() + y.NormBytes(); }
191  };
192 #endif
193 
202  void contractCuda(const cudaColorSpinorField &x, const cudaColorSpinorField &y, void *result, const QudaContractType contract_type, const QudaParity parity, TimeProfile &profile)
203  {
204 #ifdef GPU_CONTRACT
205  if ((contract_type == QUDA_CONTRACT_TSLICE) || (contract_type == QUDA_CONTRACT_TSLICE_PLUS) || (contract_type == QUDA_CONTRACT_TSLICE_MINUS)) {
206  errorQuda("No time-slice specified for contraction\n");
207  return;
208  }
209 
210  profile.TPSTART(QUDA_PROFILE_TOTAL);
211  profile.TPSTART(QUDA_PROFILE_INIT);
212 
213  Tunable *contract = 0;
214 
215  if (x.Precision() == QUDA_DOUBLE_PRECISION) {
216  contract = new ContractCuda<double2,double2>(x, y, result, parity, contract_type);
217  } else if (x.Precision() == QUDA_SINGLE_PRECISION) {
218  contract = new ContractCuda<float4,float2>(x, y, result, parity, contract_type);
219  } else if (x.Precision() == QUDA_HALF_PRECISION) {
220  errorQuda("Half precision not supported for gamma5 kernel yet");
221  }
222  profile.TPSTOP(QUDA_PROFILE_INIT);
223 
224  profile.TPSTART(QUDA_PROFILE_COMPUTE);
225  contract->apply(streams[Nstream-1]);
227  profile.TPSTOP(QUDA_PROFILE_COMPUTE);
228 
229  profile.TPSTART(QUDA_PROFILE_EPILOGUE);
230  checkCudaError();
231 
232  delete contract;
233 
234  profile.TPSTOP(QUDA_PROFILE_EPILOGUE);
235  profile.TPSTOP(QUDA_PROFILE_TOTAL);
236 #else
237  errorQuda("Contraction code has not been built");
238 #endif
239  }
240 
248  void contractCuda(const cudaColorSpinorField &x, const cudaColorSpinorField &y, void *result, const QudaContractType contract_type,
249  const int nTSlice, const QudaParity parity, TimeProfile &profile)
250  {
251 #ifdef GPU_CONTRACT
252  if ((contract_type != QUDA_CONTRACT_TSLICE) || (contract_type != QUDA_CONTRACT_TSLICE_PLUS) || (contract_type != QUDA_CONTRACT_TSLICE_MINUS)) {
253  errorQuda("No time-slice input allowed for volume contractions\n");
254  return;
255  }
256 
257  profile.TPSTART(QUDA_PROFILE_TOTAL);
258  profile.TPSTART(QUDA_PROFILE_INIT);
259 
260  Tunable *contract = 0;
261 
262  if (x.Precision() == QUDA_DOUBLE_PRECISION) {
263  contract = new ContractCuda<double2,double2>(x, y, result, parity, contract_type, nTSlice);
264  } else if (x.Precision() == QUDA_SINGLE_PRECISION) {
265  contract = new ContractCuda<float4,float2>(x, y, result, parity, contract_type, nTSlice);
266  } else if (x.Precision() == QUDA_HALF_PRECISION) {
267  errorQuda("Half precision not supported for gamma5 kernel yet");
268  }
269  profile.TPSTOP(QUDA_PROFILE_INIT);
270 
271  profile.TPSTART(QUDA_PROFILE_COMPUTE);
272  contract->apply(streams[Nstream-1]);
274  profile.TPSTOP(QUDA_PROFILE_COMPUTE);
275 
276  profile.TPSTART(QUDA_PROFILE_EPILOGUE);
277  checkCudaError();
278  delete contract;
279 
280  profile.TPSTOP(QUDA_PROFILE_EPILOGUE);
281  profile.TPSTOP(QUDA_PROFILE_TOTAL);
282 #else
283  errorQuda("Contraction code has not been built");
284 #endif
285  }
286 
287 } // namespace quda
288 
void contract(const cudaColorSpinorField x, const cudaColorSpinorField y, void *ctrn, const QudaContractType cType)
__global__ void contractKernel(double2 *out, double2 *in1, double2 *in2, int myStride, const int Parity, const DslashParam param)
__global__ void contractTslicePlusKernel(double2 *out, double2 *in1, double2 *in2, int myStride, const int Tslice, const int Parity, const DslashParam param)
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:20
#define errorQuda(...)
Definition: util_quda.h:90
cudaStream_t * streams
cudaStream_t * stream
const int Nstream
char * strcpy(char *__dst, const char *__src)
__global__ void contractGamma5MinusKernel(double2 *out, double2 *in1, double2 *in2, int myStride, const int Parity, const DslashParam param)
QudaGaugeParam param
Definition: pack_test.cpp:17
cudaError_t qudaStreamSynchronize(cudaStream_t &stream)
Wrapper around cudaStreamSynchronize or cuStreamSynchronize.
DslashConstant dc
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
Definition: tune.cpp:603
__global__ void contractMinusKernel(double2 *out, double2 *in1, double2 *in2, int myStride, const int Parity, const DslashParam param)
__global__ void contractTsliceKernel(double2 *out, double2 *in1, double2 *in2, int myStride, const int Tslice, const int Parity, const DslashParam param)
enum QudaParity_s QudaParity
__global__ void contractGamma5PlusKernel(double2 *out, double2 *in1, double2 *in2, int myStride, const int Parity, const DslashParam param)
static const int aux_n
Definition: tune_key.h:12
void contractCuda(const cudaColorSpinorField &x, const cudaColorSpinorField &y, void *result, const QudaContractType contract_type, const QudaParity parity, TimeProfile &profile)
Definition: contract.cu:202
unsigned long long flops
Definition: blas_quda.cu:42
__global__ void contractPlusKernel(double2 *out, double2 *in1, double2 *in2, int myStride, const int Parity, const DslashParam param)
enum QudaContractType_s QudaContractType
#define checkCudaError()
Definition: util_quda.h:129
__global__ void contractGamma5Kernel(double2 *out, double2 *in1, double2 *in2, int myStride, const int Parity, const DslashParam param)
Definition: contract_core.h:65
QudaTune getTuning()
Query whether autotuning is enabled or not. Default is enabled but can be overridden by setting QUDA_...
Definition: util_quda.cpp:51
__global__ void contractTsliceMinusKernel(double2 *out, double2 *in1, double2 *in2, int myStride, const int Tslice, const int Parity, const DslashParam param)
QudaParity parity
Definition: covdev_test.cpp:53
unsigned long long bytes
Definition: blas_quda.cu:43