QUDA  1.0.0
dslash_quda.cuh
1 //#define DSLASH_TUNE_TILE
2 
3 #if (__COMPUTE_CAPABILITY__ >= 700)
4 // on Volta (and newer) we select the maximum shared-memory carveout, shrinking L1 so that global loads prefer to hit in L2
5 #define SET_CACHE(f) qudaFuncSetAttribute( (const void*)f, cudaFuncAttributePreferredSharedMemoryCarveout, (int)cudaSharedmemCarveoutMaxShared)
6 #else
7 #define SET_CACHE(f)
8 #endif
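
// Illustrative sketch (not part of the original header): on a device with
// compute capability >= 7.0, SET_CACHE(fooKernel<INTERIOR_KERNEL>) for a
// hypothetical fooKernel expands to
//
//   qudaFuncSetAttribute( (const void*)fooKernel<INTERIOR_KERNEL>,
//                         cudaFuncAttributePreferredSharedMemoryCarveout,
//                         (int)cudaSharedmemCarveoutMaxShared );
//
// while on older devices it expands to nothing.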
9 
10 #if 1
11 #define LAUNCH_KERNEL(f, grid, block, shared, stream, param) \
12  void *args[] = { &param }; \
13  void (*func)( const DslashParam ) = &(f); \
14  qudaLaunchKernel( (const void*)func, grid, block, args, shared, stream);
15 #else
16 #define LAUNCH_KERNEL(f, grid, block, shared, stream, param) f<<<grid, block, shared, stream>>>(param)
17 #endif
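
// Illustrative sketch (not part of the original header): for a hypothetical
// __global__ void fooKernel(const DslashParam param), the first branch above
// packages the argument into an array and goes through the launch wrapper,
//
//   void *args[] = { &param };
//   void (*func)(const DslashParam) = &fooKernel;
//   qudaLaunchKernel((const void*)func, grid, block, args, shared, stream);
//
// which is equivalent to the direct triple-chevron launch used by the #else
// branch, fooKernel<<<grid, block, shared, stream>>>(param).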
18 
19 #define EVEN_MORE_GENERIC_DSLASH(FUNC, FLOAT, DAG, X, kernel_type, gridDim, blockDim, shared, stream, param) \
20  if (x==0) { \
21  if (reconstruct == QUDA_RECONSTRUCT_NO) { \
22  SET_CACHE( FUNC ## FLOAT ## 18 ## DAG ## Kernel<kernel_type> ); \
23  LAUNCH_KERNEL( FUNC ## FLOAT ## 18 ## DAG ## Kernel<kernel_type>, gridDim, blockDim, shared, stream, param); \
24  } else if (reconstruct == QUDA_RECONSTRUCT_12) { \
25  SET_CACHE( FUNC ## FLOAT ## 12 ## DAG ## Kernel<kernel_type> ); \
26  LAUNCH_KERNEL( FUNC ## FLOAT ## 12 ## DAG ## Kernel<kernel_type>, gridDim, blockDim, shared, stream, param); \
27  } else if (reconstruct == QUDA_RECONSTRUCT_8) { \
28  SET_CACHE( FUNC ## FLOAT ## 8 ## DAG ## Kernel<kernel_type> ); \
29  LAUNCH_KERNEL( FUNC ## FLOAT ## 8 ## DAG ## Kernel<kernel_type>, gridDim, blockDim, shared, stream, param); \
30  } \
31  } else { \
32  if (reconstruct == QUDA_RECONSTRUCT_NO) { \
33  SET_CACHE( FUNC ## FLOAT ## 18 ## DAG ## X ## Kernel<kernel_type> ); \
34  LAUNCH_KERNEL( FUNC ## FLOAT ## 18 ## DAG ## X ## Kernel<kernel_type>, gridDim, blockDim, shared, stream, param); \
35  } else if (reconstruct == QUDA_RECONSTRUCT_12) { \
36  SET_CACHE( FUNC ## FLOAT ## 12 ## DAG ## X ## Kernel<kernel_type> ); \
37  LAUNCH_KERNEL( FUNC ## FLOAT ## 12 ## DAG ## X ## Kernel<kernel_type>, gridDim, blockDim, shared, stream, param); \
38  } else if (reconstruct == QUDA_RECONSTRUCT_8) { \
39  SET_CACHE( FUNC ## FLOAT ## 8 ## DAG ## X ## Kernel<kernel_type> ); \
40  LAUNCH_KERNEL( FUNC ## FLOAT ## 8 ## DAG ## X ## Kernel<kernel_type>, gridDim, blockDim, shared, stream, param); \
41  } \
42  }
43 
44 #define MORE_GENERIC_DSLASH(FUNC, DAG, X, kernel_type, gridDim, blockDim, shared, stream, param) \
45  if (typeid(sFloat) == typeid(double2)) { \
46  EVEN_MORE_GENERIC_DSLASH(FUNC, D, DAG, X, kernel_type, gridDim, blockDim, shared, stream, param) \
47  } else if (typeid(sFloat) == typeid(float4) || typeid(sFloat) == typeid(float2)) { \
48  EVEN_MORE_GENERIC_DSLASH(FUNC, S, DAG, X, kernel_type, gridDim, blockDim, shared, stream, param) \
49  } else if (typeid(sFloat)==typeid(short4) || typeid(sFloat) == typeid(short2)) { \
50  EVEN_MORE_GENERIC_DSLASH(FUNC, H, DAG, X, kernel_type, gridDim, blockDim, shared, stream, param) \
51  } else { \
52  errorQuda("Undefined precision type"); \
53  }
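
// Illustrative expansion (not part of the original header): the token pasting
// above selects a concrete kernel from the precision tag (D/S/H), the gauge
// reconstruct (18/12/8), the dagger tag and the xpay tag. Assuming a Wilson
// kernel family named wilsonDslash, a float4 spinor, a reconstruct-12 gauge
// field and x != 0,
//
//   MORE_GENERIC_DSLASH(wilsonDslash, Dagger, Xpay, INTERIOR_KERNEL,
//                       gridDim, blockDim, shared, stream, param)
//
// resolves to a launch of wilsonDslashS12DaggerXpayKernel<INTERIOR_KERNEL>.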
54 
55 
56 #define EVEN_MORE_GENERIC_STAGGERED_DSLASH(FUNC, FLOAT, DAG, X, kernel_type, gridDim, blockDim, shared, stream, param) \
57  if (x==0) { \
58  if (reconstruct == QUDA_RECONSTRUCT_NO) { \
59  LAUNCH_KERNEL( FUNC ## FLOAT ## 18 ## 18 ## DAG ## Kernel<kernel_type>, gridDim, blockDim, shared, stream, param); \
60  } else if (reconstruct == QUDA_RECONSTRUCT_13) { \
61  LAUNCH_KERNEL( FUNC ## FLOAT ## 18 ## 13 ## DAG ## Kernel<kernel_type>, gridDim, blockDim, shared, stream, param); \
62  } else if (reconstruct == QUDA_RECONSTRUCT_12) { \
63  LAUNCH_KERNEL( FUNC ## FLOAT ## 18 ## 12 ## DAG ## Kernel<kernel_type>, gridDim, blockDim, shared, stream, param); \
64  } else if (reconstruct == QUDA_RECONSTRUCT_9) { \
65  LAUNCH_KERNEL( FUNC ## FLOAT ## 18 ## 9 ## DAG ## Kernel<kernel_type>, gridDim, blockDim, shared, stream, param); \
66  } else if (reconstruct == QUDA_RECONSTRUCT_8) { \
67  LAUNCH_KERNEL( FUNC ## FLOAT ## 18 ## 8 ## DAG ## Kernel<kernel_type>, gridDim, blockDim, shared, stream, param); \
68  } \
69  } else { \
70  if (reconstruct == QUDA_RECONSTRUCT_NO) { \
71  LAUNCH_KERNEL( FUNC ## FLOAT ## 18 ## 18 ## DAG ## X ## Kernel<kernel_type>, gridDim, blockDim, shared, stream, param); \
72  } else if (reconstruct == QUDA_RECONSTRUCT_13) { \
73  LAUNCH_KERNEL( FUNC ## FLOAT ## 18 ## 13 ## DAG ## X ## Kernel<kernel_type>, gridDim, blockDim, shared, stream, param); \
74  } else if (reconstruct == QUDA_RECONSTRUCT_12) { \
75  LAUNCH_KERNEL( FUNC ## FLOAT ## 18 ## 12 ## DAG ## X ## Kernel<kernel_type>, gridDim, blockDim, shared, stream, param); \
76  } else if (reconstruct == QUDA_RECONSTRUCT_9) { \
77  LAUNCH_KERNEL( FUNC ## FLOAT ## 18 ## 9 ## DAG ## X ## Kernel<kernel_type>, gridDim, blockDim, shared, stream, param); \
78  } else if (reconstruct == QUDA_RECONSTRUCT_8) { \
79  LAUNCH_KERNEL( FUNC ## FLOAT ## 18 ## 8 ## DAG ## X ## Kernel<kernel_type>, gridDim, blockDim, shared, stream, param); \
80  } \
81  }
82 
83 #define MORE_GENERIC_STAGGERED_DSLASH(FUNC, DAG, X, kernel_type, gridDim, blockDim, shared, stream, param) \
84  if (typeid(sFloat) == typeid(double2)) { \
85  EVEN_MORE_GENERIC_STAGGERED_DSLASH(FUNC, D, DAG, X, kernel_type, gridDim, blockDim, shared, stream, param) \
86  } else if (typeid(sFloat) == typeid(float2)) { \
87  EVEN_MORE_GENERIC_STAGGERED_DSLASH(FUNC, S, DAG, X, kernel_type, gridDim, blockDim, shared, stream, param) \
88  } else if (typeid(sFloat)==typeid(short2)) { \
89  EVEN_MORE_GENERIC_STAGGERED_DSLASH(FUNC, H, DAG, X, kernel_type, gridDim, blockDim, shared, stream, param) \
90  } else { \
91  errorQuda("Undefined precision type"); \
92  }
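
// Illustrative expansion (not part of the original header): the staggered
// variant pastes two reconstruct tags, a fixed 18 for the fat links followed
// by the long-link reconstruct (18/13/12/9/8). Assuming the improved staggered
// family improvedStaggeredDslash used below, a double2 spinor, reconstruct-13
// long links and x == 0, the interior launch resolves to
// improvedStaggeredDslashD1813Kernel<INTERIOR_KERNEL>.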
93 
94 #ifndef MULTI_GPU
95 
96 #define GENERIC_DSLASH(FUNC, DAG, X, gridDim, blockDim, shared, stream, param) \
97  switch(param.kernel_type) { \
98  case INTERIOR_KERNEL: \
99  MORE_GENERIC_DSLASH(FUNC, DAG, X, INTERIOR_KERNEL, gridDim, blockDim, shared, stream, param) \
100  break; \
101  default: \
102  errorQuda("KernelType %d not defined for single GPU", param.kernel_type); \
103  }
104 
105 #define GENERIC_STAGGERED_DSLASH(FUNC, DAG, X, gridDim, blockDim, shared, stream, param) \
106  switch(param.kernel_type) { \
107  case INTERIOR_KERNEL: \
108  MORE_GENERIC_STAGGERED_DSLASH(FUNC, DAG, X, INTERIOR_KERNEL, gridDim, blockDim, shared, stream, param) \
109  break; \
110  default: \
111  errorQuda("KernelType %d not defined for single GPU", param.kernel_type); \
112  }
113 
114 
115 #else
116 
117 #define GENERIC_DSLASH(FUNC, DAG, X, gridDim, blockDim, shared, stream, param) \
118  switch(param.kernel_type) { \
119  case INTERIOR_KERNEL: \
120  MORE_GENERIC_DSLASH(FUNC, DAG, X, INTERIOR_KERNEL, gridDim, blockDim, shared, stream, param) \
121  break; \
122  case EXTERIOR_KERNEL_X: \
123  MORE_GENERIC_DSLASH(FUNC, DAG, X, EXTERIOR_KERNEL_X, gridDim, blockDim, shared, stream, param) \
124  break; \
125  case EXTERIOR_KERNEL_Y: \
126  MORE_GENERIC_DSLASH(FUNC, DAG, X, EXTERIOR_KERNEL_Y, gridDim, blockDim, shared, stream, param) \
127  break; \
128  case EXTERIOR_KERNEL_Z: \
129  MORE_GENERIC_DSLASH(FUNC, DAG, X, EXTERIOR_KERNEL_Z, gridDim, blockDim, shared, stream, param) \
130  break; \
131  case EXTERIOR_KERNEL_T: \
132  MORE_GENERIC_DSLASH(FUNC, DAG, X, EXTERIOR_KERNEL_T, gridDim, blockDim, shared, stream, param) \
133  break; \
134  case EXTERIOR_KERNEL_ALL: \
135  MORE_GENERIC_DSLASH(FUNC, DAG, X, EXTERIOR_KERNEL_ALL, gridDim, blockDim, shared, stream, param) \
136  break; \
137  default: \
138  break; \
139  }
140 
141 #define GENERIC_STAGGERED_DSLASH(FUNC, DAG, X, gridDim, blockDim, shared, stream, param) \
142  switch(param.kernel_type) { \
143  case INTERIOR_KERNEL: \
144  MORE_GENERIC_STAGGERED_DSLASH(FUNC, DAG, X, INTERIOR_KERNEL, gridDim, blockDim, shared, stream, param) \
145  break; \
146  case EXTERIOR_KERNEL_X: \
147  MORE_GENERIC_STAGGERED_DSLASH(FUNC, DAG, X, EXTERIOR_KERNEL_X, gridDim, blockDim, shared, stream, param) \
148  break; \
149  case EXTERIOR_KERNEL_Y: \
150  MORE_GENERIC_STAGGERED_DSLASH(FUNC, DAG, X, EXTERIOR_KERNEL_Y, gridDim, blockDim, shared, stream, param) \
151  break; \
152  case EXTERIOR_KERNEL_Z: \
153  MORE_GENERIC_STAGGERED_DSLASH(FUNC, DAG, X, EXTERIOR_KERNEL_Z, gridDim, blockDim, shared, stream, param) \
154  break; \
155  case EXTERIOR_KERNEL_T: \
156  MORE_GENERIC_STAGGERED_DSLASH(FUNC, DAG, X, EXTERIOR_KERNEL_T, gridDim, blockDim, shared, stream, param) \
157  break; \
158  case EXTERIOR_KERNEL_ALL: \
159  MORE_GENERIC_STAGGERED_DSLASH(FUNC, DAG, X, EXTERIOR_KERNEL_ALL, gridDim, blockDim, shared, stream, param) \
160  break; \
161  default: \
162  break; \
163  }
164 
165 
166 #endif
167 
168 // macro used for dslash types with dagger kernel defined (Wilson, domain wall, etc.)
169 #define DSLASH(FUNC, gridDim, blockDim, shared, stream, param) \
170  if (!dagger) { \
171  GENERIC_DSLASH(FUNC, , Xpay, gridDim, blockDim, shared, stream, param) \
172  } else { \
173  GENERIC_DSLASH(FUNC, Dagger, Xpay, gridDim, blockDim, shared, stream, param) \
174  }
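
// Illustrative note (not part of the original header): these top-level macros
// are what a DslashCuda subclass (defined below) invokes from its apply()
// method once autotuning has fixed the launch geometry, e.g.
//
//   DSLASH(wilsonDslash, tp.grid, tp.block, tp.shared_bytes, stream, dslashParam);
//
// where tp is the tuned TuneParam, wilsonDslash is an assumed kernel-family
// name, and the enclosing scope must provide the dagger, x, reconstruct and
// sFloat names that the dispatch macros test.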
175 
176 // macro used for staggered dslash
177 #define STAGGERED_DSLASH(gridDim, blockDim, shared, stream, param) \
178  if (!dagger) { \
179  GENERIC_DSLASH(staggeredDslash, , Axpy, gridDim, blockDim, shared, stream, param) \
180  } else { \
181  GENERIC_DSLASH(staggeredDslash, Dagger, Axpy, gridDim, blockDim, shared, stream, param) \
182  }
183 
184 // macro used for the TIFR variant of the staggered dslash
185 #define STAGGERED_DSLASH_TIFR(gridDim, blockDim, shared, stream, param) \
186  if (!dagger) { \
187  GENERIC_DSLASH(staggeredDslashTIFR, , Axpy, gridDim, blockDim, shared, stream, param) \
188  } else { \
189  GENERIC_DSLASH(staggeredDslashTIFR, Dagger, Axpy, gridDim, blockDim, shared, stream, param) \
190  }
191 
192 #define IMPROVED_STAGGERED_DSLASH(gridDim, blockDim, shared, stream, param) \
193  if (!dagger) { \
194  GENERIC_STAGGERED_DSLASH(improvedStaggeredDslash, , Axpy, gridDim, blockDim, shared, stream, param) \
195  } else { \
196  GENERIC_STAGGERED_DSLASH(improvedStaggeredDslash, Dagger, Axpy, gridDim, blockDim, shared, stream, param) \
197  }
198 
199 #define EVEN_MORE_GENERIC_ASYM_DSLASH(FUNC, FLOAT, DAG, X, kernel_type, gridDim, blockDim, shared, stream, param) \
200  if (reconstruct == QUDA_RECONSTRUCT_NO) { \
201  SET_CACHE( FUNC ## FLOAT ## 18 ## DAG ## X ## Kernel<kernel_type> ); \
202  LAUNCH_KERNEL( FUNC ## FLOAT ## 18 ## DAG ## X ## Kernel<kernel_type>, gridDim, blockDim, shared, stream, param); \
203  } else if (reconstruct == QUDA_RECONSTRUCT_12) { \
204  SET_CACHE( FUNC ## FLOAT ## 12 ## DAG ## X ## Kernel<kernel_type> ); \
205  LAUNCH_KERNEL( FUNC ## FLOAT ## 12 ## DAG ## X ## Kernel<kernel_type>, gridDim, blockDim, shared, stream, param); \
206  } else if (reconstruct == QUDA_RECONSTRUCT_8) { \
207  SET_CACHE( FUNC ## FLOAT ## 8 ## DAG ## X ## Kernel<kernel_type> ); \
208  LAUNCH_KERNEL( FUNC ## FLOAT ## 8 ## DAG ## X ## Kernel<kernel_type>, gridDim, blockDim, shared, stream, param); \
209  }
210 
211 #define MORE_GENERIC_ASYM_DSLASH(FUNC, DAG, X, kernel_type, gridDim, blockDim, shared, stream, param) \
212  if (typeid(sFloat) == typeid(double2)) { \
213  EVEN_MORE_GENERIC_ASYM_DSLASH(FUNC, D, DAG, X, kernel_type, gridDim, blockDim, shared, stream, param) \
214  } else if (typeid(sFloat) == typeid(float4)) { \
215  EVEN_MORE_GENERIC_ASYM_DSLASH(FUNC, S, DAG, X, kernel_type, gridDim, blockDim, shared, stream, param) \
216  } else if (typeid(sFloat)==typeid(short4)) { \
217  EVEN_MORE_GENERIC_ASYM_DSLASH(FUNC, H, DAG, X, kernel_type, gridDim, blockDim, shared, stream, param) \
218  }
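
// Note (added for clarity, not part of the original header): unlike
// EVEN_MORE_GENERIC_DSLASH above, the asymmetric variant has no x == 0 branch
// and always pastes the X (xpay) tag, presumably because the asymmetric
// preconditioned operators are only ever launched in their xpay form.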
219 
220 
221 #ifndef MULTI_GPU
222 
223 #define GENERIC_ASYM_DSLASH(FUNC, DAG, X, gridDim, blockDim, shared, stream, param) \
224  switch(param.kernel_type) { \
225  case INTERIOR_KERNEL: \
226  MORE_GENERIC_ASYM_DSLASH(FUNC, DAG, X, INTERIOR_KERNEL, gridDim, blockDim, shared, stream, param) \
227  break; \
228  default: \
229  errorQuda("KernelType %d not defined for single GPU", param.kernel_type); \
230  }
231 
232 #else
233 
234 #define GENERIC_ASYM_DSLASH(FUNC, DAG, X, gridDim, blockDim, shared, stream, param) \
235  switch(param.kernel_type) { \
236  case INTERIOR_KERNEL: \
237  MORE_GENERIC_ASYM_DSLASH(FUNC, DAG, X, INTERIOR_KERNEL, gridDim, blockDim, shared, stream, param) \
238  break; \
239  case EXTERIOR_KERNEL_X: \
240  MORE_GENERIC_ASYM_DSLASH(FUNC, DAG, X, EXTERIOR_KERNEL_X, gridDim, blockDim, shared, stream, param) \
241  break; \
242  case EXTERIOR_KERNEL_Y: \
243  MORE_GENERIC_ASYM_DSLASH(FUNC, DAG, X, EXTERIOR_KERNEL_Y, gridDim, blockDim, shared, stream, param) \
244  break; \
245  case EXTERIOR_KERNEL_Z: \
246  MORE_GENERIC_ASYM_DSLASH(FUNC, DAG, X, EXTERIOR_KERNEL_Z, gridDim, blockDim, shared, stream, param) \
247  break; \
248  case EXTERIOR_KERNEL_T: \
249  MORE_GENERIC_ASYM_DSLASH(FUNC, DAG, X, EXTERIOR_KERNEL_T, gridDim, blockDim, shared, stream, param) \
250  break; \
251  case EXTERIOR_KERNEL_ALL: \
252  MORE_GENERIC_ASYM_DSLASH(FUNC, DAG, X, EXTERIOR_KERNEL_ALL, gridDim, blockDim, shared, stream, param) \
253  break; \
254  default: \
255  break; \
256  }
257 
258 #endif
259 
260 // macro used for dslash types with dagger kernel defined (Wilson, domain wall, etc.)
261 #define ASYM_DSLASH(FUNC, gridDim, blockDim, shared, stream, param) \
262  if (!dagger) { \
263  GENERIC_ASYM_DSLASH(FUNC, , Xpay, gridDim, blockDim, shared, stream, param) \
264  } else { \
265  GENERIC_ASYM_DSLASH(FUNC, Dagger, Xpay, gridDim, blockDim, shared, stream, param) \
266  }
267 
268 
269 
270 // macro used for twisted mass dslash:
271 
272 #define EVEN_MORE_GENERIC_NDEG_TM_DSLASH(FUNC, FLOAT, DAG, X, kernel_type, gridDim, blockDim, shared, stream, param) \
273  if (x == 0 && d == 0) { \
274  if (reconstruct == QUDA_RECONSTRUCT_NO) { \
275  SET_CACHE( FUNC ## FLOAT ## 18 ## DAG ## Twist ## Kernel<kernel_type> ); \
276  LAUNCH_KERNEL( FUNC ## FLOAT ## 18 ## DAG ## Twist ## Kernel<kernel_type>, gridDim, blockDim, shared, stream, param); \
277  } else if (reconstruct == QUDA_RECONSTRUCT_12) { \
278  SET_CACHE( FUNC ## FLOAT ## 12 ## DAG ## Twist ## Kernel<kernel_type> ); \
279  LAUNCH_KERNEL( FUNC ## FLOAT ## 12 ## DAG ## Twist ## Kernel<kernel_type>, gridDim, blockDim, shared, stream, param); \
280  } else { \
281  SET_CACHE( FUNC ## FLOAT ## 8 ## DAG ## Twist ## Kernel<kernel_type> ); \
282  LAUNCH_KERNEL( FUNC ## FLOAT ## 8 ## DAG ## Twist ## Kernel<kernel_type>, gridDim, blockDim, shared, stream, param); \
283  } \
284  } else if (x != 0 && d == 0) { \
285  if (reconstruct == QUDA_RECONSTRUCT_NO) { \
286  SET_CACHE( FUNC ## FLOAT ## 18 ## DAG ## Twist ## X ## Kernel<kernel_type> ); \
287  LAUNCH_KERNEL( FUNC ## FLOAT ## 18 ## DAG ## Twist ## X ## Kernel<kernel_type>, gridDim, blockDim, shared, stream, param); \
288  } else if (reconstruct == QUDA_RECONSTRUCT_12) { \
289  SET_CACHE( FUNC ## FLOAT ## 12 ## DAG ## Twist ## X ## Kernel<kernel_type> ); \
290  LAUNCH_KERNEL( FUNC ## FLOAT ## 12 ## DAG ## Twist ## X ## Kernel<kernel_type>, gridDim, blockDim, shared, stream, param); \
291  } else if (reconstruct == QUDA_RECONSTRUCT_8) { \
292  SET_CACHE( FUNC ## FLOAT ## 8 ## DAG ## Twist ## X ## Kernel<kernel_type> ); \
293  LAUNCH_KERNEL( FUNC ## FLOAT ## 8 ## DAG ## Twist ## X ## Kernel<kernel_type>, gridDim, blockDim, shared, stream, param); \
294  } \
295  } else if (x == 0 && d != 0) { \
296  if (reconstruct == QUDA_RECONSTRUCT_NO) { \
297  SET_CACHE( FUNC ## FLOAT ## 18 ## DAG ## Kernel<kernel_type> ); \
298  LAUNCH_KERNEL( FUNC ## FLOAT ## 18 ## DAG ## Kernel<kernel_type>, gridDim, blockDim, shared, stream, param); \
299  } else if (reconstruct == QUDA_RECONSTRUCT_12) { \
300  SET_CACHE( FUNC ## FLOAT ## 12 ## DAG ## Kernel<kernel_type> ); \
301  LAUNCH_KERNEL( FUNC ## FLOAT ## 12 ## DAG ## Kernel<kernel_type>, gridDim, blockDim, shared, stream, param); \
302  } else { \
303  SET_CACHE( FUNC ## FLOAT ## 8 ## DAG ## Kernel<kernel_type> ); \
304  LAUNCH_KERNEL( FUNC ## FLOAT ## 8 ## DAG ## Kernel<kernel_type>, gridDim, blockDim, shared, stream, param); \
305  } \
306  } else { \
307  if (reconstruct == QUDA_RECONSTRUCT_NO) { \
308  SET_CACHE( FUNC ## FLOAT ## 18 ## DAG ## X ## Kernel<kernel_type> ); \
309  LAUNCH_KERNEL( FUNC ## FLOAT ## 18 ## DAG ## X ## Kernel<kernel_type>, gridDim, blockDim, shared, stream, param); \
310  } else if (reconstruct == QUDA_RECONSTRUCT_12) { \
311  SET_CACHE( FUNC ## FLOAT ## 12 ## DAG ## X ## Kernel<kernel_type> ); \
312  LAUNCH_KERNEL( FUNC ## FLOAT ## 12 ## DAG ## X ## Kernel<kernel_type>, gridDim, blockDim, shared, stream, param); \
313  } else if (reconstruct == QUDA_RECONSTRUCT_8) { \
314  SET_CACHE( FUNC ## FLOAT ## 8 ## DAG ## X ## Kernel<kernel_type> ); \
315  LAUNCH_KERNEL( FUNC ## FLOAT ## 8 ## DAG ## X ## Kernel<kernel_type>, gridDim, blockDim, shared, stream, param); \
316  } \
317  }
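
// Summary of the dispatch above (added for clarity, not part of the original
// header): the pair of flags (x, d) selects one of four kernel families:
//
//   x == 0, d == 0  ->  FUNC...DAG Twist Kernel    (twist term, no xpay)
//   x != 0, d == 0  ->  FUNC...DAG Twist X Kernel  (twist term, with xpay)
//   x == 0, d != 0  ->  FUNC...DAG Kernel          (no twist term, no xpay)
//   x != 0, d != 0  ->  FUNC...DAG X Kernel        (no twist term, with xpay)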
318 
319 #define MORE_GENERIC_NDEG_TM_DSLASH(FUNC, DAG, X, kernel_type, gridDim, blockDim, shared, stream, param) \
320  if (typeid(sFloat) == typeid(double2)) { \
321  EVEN_MORE_GENERIC_NDEG_TM_DSLASH(FUNC, D, DAG, X, kernel_type, gridDim, blockDim, shared, stream, param) \
322  } else if (typeid(sFloat) == typeid(float4)) { \
323  EVEN_MORE_GENERIC_NDEG_TM_DSLASH(FUNC, S, DAG, X, kernel_type, gridDim, blockDim, shared, stream, param) \
324  } else if (typeid(sFloat)==typeid(short4)) { \
325  EVEN_MORE_GENERIC_NDEG_TM_DSLASH(FUNC, H, DAG, X, kernel_type, gridDim, blockDim, shared, stream, param) \
326  } else { \
327  errorQuda("Undefined precision type"); \
328  }
329 
330 #ifndef MULTI_GPU
331 
332 #define GENERIC_NDEG_TM_DSLASH(FUNC, DAG, X, gridDim, blockDim, shared, stream, param) \
333  switch(param.kernel_type) { \
334  case INTERIOR_KERNEL: \
335  MORE_GENERIC_NDEG_TM_DSLASH(FUNC, DAG, X, INTERIOR_KERNEL, gridDim, blockDim, shared, stream, param) \
336  break; \
337  default: \
338  errorQuda("KernelType %d not defined for single GPU", param.kernel_type); \
339  }
340 
341 #else
342 
343 #define GENERIC_NDEG_TM_DSLASH(FUNC, DAG, X, gridDim, blockDim, shared, stream, param) \
344  switch(param.kernel_type) { \
345  case INTERIOR_KERNEL: \
346  MORE_GENERIC_NDEG_TM_DSLASH(FUNC, DAG, X, INTERIOR_KERNEL, gridDim, blockDim, shared, stream, param) \
347  break; \
348  case EXTERIOR_KERNEL_X: \
349  MORE_GENERIC_NDEG_TM_DSLASH(FUNC, DAG, X, EXTERIOR_KERNEL_X, gridDim, blockDim, shared, stream, param) \
350  break; \
351  case EXTERIOR_KERNEL_Y: \
352  MORE_GENERIC_NDEG_TM_DSLASH(FUNC, DAG, X, EXTERIOR_KERNEL_Y, gridDim, blockDim, shared, stream, param) \
353  break; \
354  case EXTERIOR_KERNEL_Z: \
355  MORE_GENERIC_NDEG_TM_DSLASH(FUNC, DAG, X, EXTERIOR_KERNEL_Z, gridDim, blockDim, shared, stream, param) \
356  break; \
357  case EXTERIOR_KERNEL_T: \
358  MORE_GENERIC_NDEG_TM_DSLASH(FUNC, DAG, X, EXTERIOR_KERNEL_T, gridDim, blockDim, shared, stream, param) \
359  break; \
360  case EXTERIOR_KERNEL_ALL: \
361  MORE_GENERIC_NDEG_TM_DSLASH(FUNC, DAG, X, EXTERIOR_KERNEL_ALL, gridDim, blockDim, shared, stream, param) \
362  break; \
363  default: \
364  break; \
365  }
366 
367 #endif
368 
369 #define NDEG_TM_DSLASH(FUNC, gridDim, blockDim, shared, stream, param) \
370  if (!dagger) { \
371  GENERIC_NDEG_TM_DSLASH(FUNC, , Xpay, gridDim, blockDim, shared, stream, param) \
372  } else { \
373  GENERIC_NDEG_TM_DSLASH(FUNC, Dagger, Xpay, gridDim, blockDim, shared, stream, param) \
374  }
375 // end of twisted mass dslash macro
376 
377 
378 // Use an abstract class interface to drive the different CUDA dslash
379 // kernels. All parameters are curried into the derived classes to
380 // allow a simple interface.
381 class DslashCuda : public Tunable {
382 
383 protected:
384  cudaColorSpinorField *out;
385  const cudaColorSpinorField *in;
386  const cudaColorSpinorField *x;
387  const GaugeField &gauge;
388  const QudaReconstructType reconstruct;
390  const int dagger;
391  static bool init;
392 
393  unsigned int sharedBytesPerBlock(const TuneParam &param) const { return 0; }
394  bool tuneGridDim() const { return false; } // Don't tune the grid dimensions.
395  bool tuneAuxDim() const { return true; } // Do tune the aux dimensions.
396  // all dslashes expect a 4-d volume here (dwf Ls is y thread dimension)
397  unsigned int minThreads() const { return dslashParam.threads; }
398  char aux_base[TuneKey::aux_n];
399  char aux[8][TuneKey::aux_n];
400  static char ghost_str[TuneKey::aux_n]; // string with ghostDim information
401 
406  inline void fillAuxBase() {
407  char comm[5];
408  comm[0] = (dslashParam.commDim[0] ? '1' : '0');
409  comm[1] = (dslashParam.commDim[1] ? '1' : '0');
410  comm[2] = (dslashParam.commDim[2] ? '1' : '0');
411  comm[3] = (dslashParam.commDim[3] ? '1' : '0');
412  comm[4] = '\0';
413  strcpy(aux_base,",comm=");
414  strcat(aux_base,comm);
415 
416  switch (reconstruct) {
417  case QUDA_RECONSTRUCT_NO: strcat(aux_base,",reconstruct=18"); break;
418  case QUDA_RECONSTRUCT_13: strcat(aux_base,",reconstruct=13"); break;
419  case QUDA_RECONSTRUCT_12: strcat(aux_base,",reconstruct=12"); break;
420  case QUDA_RECONSTRUCT_9: strcat(aux_base, ",reconstruct=9"); break;
421  case QUDA_RECONSTRUCT_8: strcat(aux_base, ",reconstruct=8"); break;
422  default: break;
423  }
424 
425  if (x) strcat(aux_base,",Xpay");
426  if (dagger) strcat(aux_base,",dagger");
427  }
428 
434  inline void fillAux(KernelType kernel_type, const char *kernel_str) {
435  strcpy(aux[kernel_type],kernel_str);
436  if (kernel_type == INTERIOR_KERNEL) strcat(aux[kernel_type],ghost_str);
437  strcat(aux[kernel_type],aux_base);
438  }
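
 // Example of the resulting tuning strings (illustrative values, not part of
 // the original header): with commDim = {1,1,0,0}, the same dimensions
 // partitioned, a reconstruct-12 gauge field, xpay enabled and no dagger,
 // fillAuxBase() produces ",comm=1100,reconstruct=12,Xpay", and
 // fillAux(INTERIOR_KERNEL, "policy_kernel=interior") then yields
 // "policy_kernel=interior,ghost=1100,comm=1100,reconstruct=12,Xpay".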
439 
445  inline void setParam()
446  {
447  // factor of 2 (or 1) for T-dimensional spin projection (FIXME - unnecessary)
448  dslashParam.tProjScale = getKernelPackT() ? 1.0 : 2.0;
450 
451  // update the ghosts for the non-p2p directions
452  for (int dim=0; dim<4; dim++) {
453  if (!dslashParam.commDim[dim]) continue;
454 
455  for (int dir=0; dir<2; dir++) {
456  /* if doing the interior kernel, then this is the initial call, so
457  we must set all ghost pointers; if doing an exterior kernel, then
458  we only have to update the non-p2p ghosts, since these may
459  have been assigned to zero-copy memory */
460  if (!comm_peer2peer_enabled(dir,dim) || dslashParam.kernel_type == INTERIOR_KERNEL) {
461  dslashParam.ghost[2*dim+dir] = (void*)in->Ghost2();
462  dslashParam.ghostNorm[2*dim+dir] = (float*)(in->Ghost2());
463 
464 #ifdef USE_TEXTURE_OBJECTS
465  dslashParam.ghostTex[2*dim+dir] = in->GhostTex();
466  dslashParam.ghostTexNorm[2*dim+dir] = in->GhostTexNorm();
467 #endif // USE_TEXTURE_OBJECTS
468  }
469  }
470  }
471 
472  }
473 
474 public:
476 
477  DslashCuda(cudaColorSpinorField *out, const cudaColorSpinorField *in,
478  const cudaColorSpinorField *x, const GaugeField &gauge,
479  const int parity, const int dagger, const int *commOverride)
480  : out(out), in(in), x(x), gauge(gauge), reconstruct(gauge.Reconstruct()),
481  dagger(dagger), saveOut(0), saveOutNorm(0) {
482 
483  if (in->Precision() != gauge.Precision())
484  errorQuda("Mixing gauge %d and spinor %d precision not supported", gauge.Precision(), in->Precision());
485 
486  constexpr int nDimComms = 4;
487  for (int i=0; i<nDimComms; i++){
488  dslashParam.ghostOffset[i][0] = in->GhostOffset(i,0)/in->FieldOrder();
489  dslashParam.ghostOffset[i][1] = in->GhostOffset(i,1)/in->FieldOrder();
490  dslashParam.ghostNormOffset[i][0] = in->GhostNormOffset(i,0);
491  dslashParam.ghostNormOffset[i][1] = in->GhostNormOffset(i,1);
492  dslashParam.ghostDim[i] = comm_dim_partitioned(i); // determines whether to use regular or ghost indexing at boundary
493  dslashParam.commDim[i] = (!commOverride[i]) ? 0 : comm_dim_partitioned(i); // switch off comms if override = 0
494  }
495 
496  // set parameters particular for this instance
497  dslashParam.out = (void*)out->V();
498  dslashParam.outNorm = (float*)out->Norm();
499  dslashParam.in = (void*)in->V();
500  dslashParam.inNorm = (float*)in->Norm();
501  dslashParam.x = x ? (void*)x->V() : nullptr;
502  dslashParam.xNorm = x ? (float*)x->Norm() : nullptr;
503 
504 #ifdef USE_TEXTURE_OBJECTS
505  dslashParam.inTex = in->Tex();
506  dslashParam.inTexNorm = in->TexNorm();
507  if (out) dslashParam.outTex = out->Tex();
508  if (out) dslashParam.outTexNorm = out->TexNorm();
509  if (x) dslashParam.xTex = x->Tex();
510  if (x) dslashParam.xTexNorm = x->TexNorm();
511 #endif // USE_TEXTURE_OBJECTS
512 
513  dslashParam.parity = parity;
514  bindGaugeTex(static_cast<const cudaGaugeField&>(gauge), parity, dslashParam);
515 
516  dslashParam.sp_stride = in->Stride();
517 
518  dslashParam.dc = in->getDslashConstant(); // get precomputed constants
519 
520  dslashParam.gauge_stride = gauge.Stride();
521  dslashParam.gauge_fixed = gauge.GaugeFixed();
522 
523  dslashParam.anisotropy = gauge.Anisotropy();
524  dslashParam.anisotropy_f = (float)dslashParam.anisotropy;
525 
526  dslashParam.t_boundary = (gauge.TBoundary() == QUDA_PERIODIC_T) ? 1.0 : -1.0;
527  dslashParam.t_boundary_f = (float)dslashParam.t_boundary;
528 
529  dslashParam.An2 = make_float2(gauge.Anisotropy(), 1.0 / (gauge.Anisotropy()*gauge.Anisotropy()));
530  dslashParam.TB2 = make_float2(dslashParam.t_boundary_f, 1.0 / (dslashParam.t_boundary * dslashParam.t_boundary));
531 
532  dslashParam.coeff = 1.0;
533  dslashParam.coeff_f = 1.0f;
534 
535  dslashParam.twist_a = 0.0;
536  dslashParam.twist_b = 0.0;
537 
538  dslashParam.No2 = make_float2(1.0f, 1.0f);
539  dslashParam.Pt0 = (comm_coord(3) == 0) ? true : false;
540  dslashParam.PtNm1 = (comm_coord(3) == comm_dim(3)-1) ? true : false;
541 
542  // this sets the communications pattern for the packing kernel
543  setPackComms(dslashParam.commDim);
544 
545  if (!init) { // these parameters are constant across all dslash instances for a given run
546  char ghost[5]; // set the ghost string
547  for (int dim=0; dim<nDimComms; dim++) ghost[dim] = (dslashParam.ghostDim[dim] ? '1' : '0');
548  ghost[4] = '\0';
549  strcpy(ghost_str,",ghost=");
550  strcat(ghost_str,ghost);
551  init = true;
552  }
553 
554  fillAuxBase();
555 #ifdef MULTI_GPU
556  fillAux(INTERIOR_KERNEL, "policy_kernel=interior");
557  fillAux(EXTERIOR_KERNEL_ALL, "policy_kernel=exterior_all");
558  fillAux(EXTERIOR_KERNEL_X, "policy_kernel=exterior_x");
559  fillAux(EXTERIOR_KERNEL_Y, "policy_kernel=exterior_y");
560  fillAux(EXTERIOR_KERNEL_Z, "policy_kernel=exterior_z");
561  fillAux(EXTERIOR_KERNEL_T, "policy_kernel=exterior_t");
562 #else
563  fillAux(INTERIOR_KERNEL, "policy_kernel=single-GPU");
564 #endif // MULTI_GPU
565  fillAux(KERNEL_POLICY, "policy");
566 
567  }
568 
569  virtual ~DslashCuda() { unbindGaugeTex(static_cast<const cudaGaugeField&>(gauge)); }
570  virtual TuneKey tuneKey() const
571  { return TuneKey(in->VolString(), typeid(*this).name(), aux[dslashParam.kernel_type]); }
572 
573  const char* getAux(KernelType type) const {
574  return aux[type];
575  }
576 
577  void setAux(KernelType type, const char *aux_) {
578  strcpy(aux[type], aux_);
579  }
580 
581  void augmentAux(KernelType type, const char *extra) {
582  strcat(aux[type], extra);
583  }
584 
585  virtual int Nface() const { return 2; }
586 
587  int Dagger() const { return dagger; }
588 
589 #if defined(DSLASH_TUNE_TILE)
590  // Experimental autotuning of the thread ordering
591  bool advanceAux(TuneParam &param) const
592  {
593  if (in->Nspin()==1 || in->Ndim()==5) return false;
594  const int *X = in->X();
595 
596  if (param.aux.w < X[3] && param.aux.x > 1 && param.aux.w < 2) {
597  do { param.aux.w++; } while( (X[3]) % param.aux.w != 0);
598  if (param.aux.w <= X[3]) return true;
599  } else {
600  param.aux.w = 1;
601 
602  if (param.aux.z < X[2] && param.aux.x > 1 && param.aux.z < 8) {
603  do { param.aux.z++; } while( (X[2]) % param.aux.z != 0);
604  if (param.aux.z <= X[2]) return true;
605  } else {
606 
607  param.aux.z = 1;
608  if (param.aux.y < X[1] && param.aux.x > 1 && param.aux.y < 32) {
609  do { param.aux.y++; } while( X[1] % param.aux.y != 0);
610  if (param.aux.y <= X[1]) return true;
611  } else {
612  param.aux.y = 1;
613  if (param.aux.x < (2*X[0]) && param.aux.x < 32) {
614  do { param.aux.x++; } while( (2*X[0]) % param.aux.x != 0);
615  if (param.aux.x <= (2*X[0]) ) return true;
616  }
617  }
618  }
619  }
620  param.aux = make_int4(2,1,1,1);
621  return false;
622  }
623 
624  void initTuneParam(TuneParam &param) const
625  {
626  Tunable::initTuneParam(param);
627  param.aux = make_int4(2,1,1,1);
628  }
629 
631  void defaultTuneParam(TuneParam &param) const
632  {
633  Tunable::defaultTuneParam(param);
634  param.aux = make_int4(2,1,1,1);
635  }
636 #endif
637 
638  virtual void preTune()
639  {
640  out->backup();
641  }
642 
643  virtual void postTune()
644  {
645  out->restore();
646  }
647 
648  /*void launch_auxiliary(cudaStream_t &stream) {
649  auxiliary.apply(stream);
650  }*/
651 
652  /*
653  per direction / dimension flops
654  spin project flops = Nc * Ns
655  SU(3) matrix-vector flops = (8 Nc - 2) * Nc
656  spin reconstruction flops = 2 * Nc * Ns (just an accumulation to all components)
657  xpay = 2 * 2 * Nc * Ns
658 
659  So for the full dslash we have the following, where the spin
660  reconstruction term carries a factor of (2 * Nd - 1), since the
661  first direction does not require an accumulation.
662 
663  flops = (2 * Nd * Nc * Ns) + (2 * Nd * (Ns/2) * (8*Nc-2) * Nc) + ((2 * Nd - 1) * 2 * Nc * Ns)
664  flops_xpay = flops + 2 * 2 * Nc * Ns
665 
666  For Wilson (Nc=3, Ns=4) this gives 1320 flops per site, and 1368 for the xpay equivalent
667  */
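 // Worked check of the flop-count formula above (added for clarity, not part
 // of the original header), with Nd=4, Nc=3, Ns=4: spin project = 2*4*3*4 = 96,
 // matrix-vector = 2*4*2*(8*3-2)*3 = 1056, accumulation = (2*4-1)*2*3*4 = 168,
 // for a total of 1320 flops per site; xpay adds 2*2*3*4 = 48, giving 1368.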
668  virtual long long flops() const {
669  int mv_flops = (8 * in->Ncolor() - 2) * in->Ncolor(); // SU(3) matrix-vector flops
670  int num_mv_multiply = in->Nspin() == 4 ? 2 : 1;
671  int ghost_flops = (num_mv_multiply * mv_flops + 2*in->Ncolor()*in->Nspin());
672  int xpay_flops = 2 * 2 * in->Ncolor() * in->Nspin(); // multiply and add per real component
673  int num_dir = 2 * 4;
674 
675  long long flops_ = 0;
676  switch(dslashParam.kernel_type) {
677  case EXTERIOR_KERNEL_X:
678  case EXTERIOR_KERNEL_Y:
679  case EXTERIOR_KERNEL_Z:
680  case EXTERIOR_KERNEL_T:
681  flops_ = (ghost_flops + (x ? xpay_flops : 0)) * 2 * in->GhostFace()[dslashParam.kernel_type];
682  break;
683  case EXTERIOR_KERNEL_ALL:
684  {
685  long long ghost_sites = 2 * (in->GhostFace()[0]+in->GhostFace()[1]+in->GhostFace()[2]+in->GhostFace()[3]);
686  flops_ = (ghost_flops + (x ? xpay_flops : 0)) * ghost_sites;
687  break;
688  }
689  case INTERIOR_KERNEL:
690  case KERNEL_POLICY:
691  {
692  long long sites = in->VolumeCB();
693  flops_ = (num_dir*(in->Nspin()/4)*in->Ncolor()*in->Nspin() + // spin project (=0 for staggered)
694  num_dir*num_mv_multiply*mv_flops + // SU(3) matrix-vector multiplies
695  ((num_dir-1)*2*in->Ncolor()*in->Nspin())) * sites; // accumulation
696  if (x) flops_ += xpay_flops * sites;
697 
698  if (dslashParam.kernel_type == KERNEL_POLICY) break;
699  // now correct for flops done by exterior kernel
700  long long ghost_sites = 0;
701  for (int d=0; d<4; d++) if (dslashParam.commDim[d]) ghost_sites += 2 * in->GhostFace()[d];
702  flops_ -= (ghost_flops + (x ? xpay_flops : 0)) * ghost_sites;
703 
704  break;
705  }
706  }
707  return flops_;
708  }
709 
710  virtual long long bytes() const {
711  int gauge_bytes = reconstruct * in->Precision();
712  bool isFixed = (in->Precision() == sizeof(short) || in->Precision() == sizeof(char)) ? true : false;
713  int spinor_bytes = 2 * in->Ncolor() * in->Nspin() * in->Precision() + (isFixed ? sizeof(float) : 0);
714  int proj_spinor_bytes = (in->Nspin()==4 ? 1 : 2) * in->Ncolor() * in->Nspin() * in->Precision() + (isFixed ? sizeof(float) : 0);
715  int ghost_bytes = (proj_spinor_bytes + gauge_bytes) + spinor_bytes;
716  int num_dir = 2 * 4; // set to 4 dimensions since we take care of 5-d fermions in derived classes where necessary
717 
718  long long bytes_=0;
719  switch(dslashParam.kernel_type) {
720  case EXTERIOR_KERNEL_X:
721  case EXTERIOR_KERNEL_Y:
722  case EXTERIOR_KERNEL_Z:
723  case EXTERIOR_KERNEL_T:
724  bytes_ = (ghost_bytes + (x ? spinor_bytes : 0)) * 2 * in->GhostFace()[dslashParam.kernel_type];
725  break;
726  case EXTERIOR_KERNEL_ALL:
727  {
728  long long ghost_sites = 2 * (in->GhostFace()[0]+in->GhostFace()[1]+in->GhostFace()[2]+in->GhostFace()[3]);
729  bytes_ = (ghost_bytes + (x ? spinor_bytes : 0)) * ghost_sites;
730  break;
731  }
732  case INTERIOR_KERNEL:
733  case KERNEL_POLICY:
734  {
735  long long sites = in->VolumeCB();
736  bytes_ = (num_dir*gauge_bytes + ((num_dir-2)*spinor_bytes + 2*proj_spinor_bytes) + spinor_bytes)*sites;
737  if (x) bytes_ += spinor_bytes * sites;
738 
739  if (dslashParam.kernel_type == KERNEL_POLICY) break;
740  // now correct for bytes done by exterior kernel
741  long long ghost_sites = 0;
742  for (int d=0; d<4; d++) if (dslashParam.commDim[d]) ghost_sites += 2*in->GhostFace()[d];
743  bytes_ -= (ghost_bytes + (x ? spinor_bytes : 0)) * ghost_sites;
744 
745  break;
746  }
747  }
748  return bytes_;
749  }
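 // Worked example of the interior traffic model above (added for clarity, not
 // part of the original header): for a double-precision Wilson dslash with
 // reconstruct-12, gauge_bytes = 12*8 = 96, spinor_bytes = 2*3*4*8 = 192 and
 // proj_spinor_bytes = 1*3*4*8 = 96, so the interior kernel moves
 // 8*96 + (6*192 + 2*96) + 192 = 2304 bytes per site before the xpay and
 // exterior corrections.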
750 
751 };
752 
753 //static declarations
754 bool DslashCuda::init = false;
755 char DslashCuda::ghost_str[TuneKey::aux_n];
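
// Illustrative sketch (hypothetical, not part of the original header) of how a
// concrete operator curries its extra parameters into a DslashCuda subclass and
// drives the launch macros; the real derived classes live in the dslash_*.cu
// translation units and differ in detail:
//
//   template <typename sFloat>
//   class MyWilsonDslashCuda : public SharedDslashCuda {
//   public:
//     MyWilsonDslashCuda(cudaColorSpinorField *out, const GaugeField &gauge,
//                        const cudaColorSpinorField *in, const cudaColorSpinorField *x,
//                        const double a, const int parity, const int dagger,
//                        const int *commOverride)
//       : SharedDslashCuda(out, in, x, gauge, parity, dagger, commOverride)
//     { dslashParam.coeff = a; dslashParam.coeff_f = (float)a; }
//
//     void apply(const cudaStream_t &stream) {
//       TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
//       setParam();
//       DSLASH(wilsonDslash, tp.grid, tp.block, tp.shared_bytes, stream, dslashParam);
//     }
//   };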
756 
760 #ifdef SHARED_WILSON_DSLASH
761 class SharedDslashCuda : public DslashCuda {
762 protected:
763  unsigned int sharedBytesPerBlock(const TuneParam &param) const { return 0; } // FIXME: this isn't quite true, but works
764  bool advanceSharedBytes(TuneParam &param) const {
765  if (dslashParam.kernel_type != INTERIOR_KERNEL) return DslashCuda::advanceSharedBytes(param);
766  else return false;
767  } // FIXME - shared memory tuning only supported on exterior kernels
768 
770  int sharedBytes(const dim3 &block) const {
771  int warpSize = 32; // FIXME - query from device properties
772  int block_xy = block.x*block.y;
773  if (block_xy % warpSize != 0) block_xy = ((block_xy / warpSize) + 1)*warpSize;
774  return block_xy*block.z*sharedBytesPerThread();
775  }
776 
778  dim3 createGrid(const dim3 &block) const {
779  unsigned int gx = (in->X(0)*in->X(3) + block.x - 1) / block.x;
780  unsigned int gy = (in->X(1) + block.y - 1 ) / block.y;
781  unsigned int gz = (in->X(2) + block.z - 1) / block.z;
782  return dim3(gx, gy, gz);
783  }
784 
786  bool advanceBlockDim(TuneParam &param) const {
787  if (dslashParam.kernel_type != INTERIOR_KERNEL) return DslashCuda::advanceBlockDim(param);
788  const unsigned int min_threads = 2;
789  const unsigned int max_threads = 512; // FIXME: use deviceProp.maxThreadsDim[0];
790  const unsigned int max_shared = 16384*3; // FIXME: use deviceProp.sharedMemPerBlock;
791 
792  // set the x-block dimension equal to the entire x dimension
793  bool set = false;
794  dim3 blockInit = param.block;
795  blockInit.z++;
796  for (unsigned bx=blockInit.x; bx<=in->X(0); bx++) {
797  //unsigned int gx = (in->X(0)*in->x(3) + bx - 1) / bx;
798  for (unsigned by=blockInit.y; by<=in->X(1); by++) {
799  unsigned int gy = (in->X(1) + by - 1 ) / by;
800 
801  if (by > 1 && (by%2) != 0) continue; // can't handle odd blocks yet except by=1
802 
803  for (unsigned bz=blockInit.z; bz<=in->X(2); bz++) {
804  unsigned int gz = (in->X(2) + bz - 1) / bz;
805 
806  if (bz > 1 && (bz%2) != 0) continue; // can't handle odd blocks yet except bz=1
807  if (bx*by*bz > max_threads) continue;
808  if (bx*by*bz < min_threads) continue;
809  // can't yet handle the last block properly in shared memory addressing
810  if (by*gy != in->X(1)) continue;
811  if (bz*gz != in->X(2)) continue;
812  if (sharedBytes(dim3(bx, by, bz)) > max_shared) continue;
813 
814  param.block = dim3(bx, by, bz);
815  set = true; break;
816  }
817  if (set) break;
818  blockInit.z = 1;
819  }
820  if (set) break;
821  blockInit.y = 1;
822  }
823 
824  if ((param.block.x > in->X(0) && param.block.y > in->X(1) && param.block.z > in->X(2)) || !set) {
825  //||sharedBytesPerThread()*param.block.x > max_shared) {
826  param.block = dim3(in->X(0), 1, 1);
827  return false;
828  } else {
829  param.grid = createGrid(param.block);
830  param.shared_bytes = sharedBytes(param.block);
831  return true;
832  }
833  }
834 
835 public:
836  SharedDslashCuda(cudaColorSpinorField *out, const cudaColorSpinorField *in,
837  const cudaColorSpinorField *x, const GaugeField &gauge,
838  int parity, int dagger, const int *commOverride)
839  : DslashCuda(out, in, x, gauge, parity, dagger, commOverride) { ; }
840  virtual ~SharedDslashCuda() { ; }
841 
842  virtual void initTuneParam(TuneParam &param) const
843  {
844  if (dslashParam.kernel_type != INTERIOR_KERNEL) return DslashCuda::initTuneParam(param);
845 
846  param.block = dim3(in->X(0), 1, 1);
847  param.grid = createGrid(param.block);
848  param.shared_bytes = sharedBytes(param.block);
849  }
850 
852  virtual void defaultTuneParam(TuneParam &param) const
853  {
854  if (dslashParam.kernel_type != INTERIOR_KERNEL) DslashCuda::defaultTuneParam(param);
855  else initTuneParam(param);
856  }
857 };
858 #else
859 class SharedDslashCuda : public DslashCuda {
860 public:
861  SharedDslashCuda(cudaColorSpinorField *out, const cudaColorSpinorField *in,
862  const cudaColorSpinorField *x, const GaugeField &gauge,
863  int parity, int dagger, const int *commOverride)
864  : DslashCuda(out, in, x, gauge, parity, dagger, commOverride) { }
865  virtual ~SharedDslashCuda() { }
866 };
867 #endif