QUDA  v0.5.0
A library for QCD on GPUs
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
llfat_quda.cu
Go to the documentation of this file.
1 #include <stdio.h>
2 #include <cuda_runtime.h>
3 #include <cuda.h>
4 
5 #include <quda_internal.h>
6 #include <llfat_quda.h>
7 #include <read_gauge.h>
8 #include "gauge_field.h"
9 #include <force_common.h>
10 
11 namespace quda {
12 
// Texture-fetch switches consumed by the loader macros below.
// Site (thin) links are read through textures only when building for
// __COMPUTE_CAPABILITY__ >= 200; on older targets they are read directly.
#if (__COMPUTE_CAPABILITY__ >= 200)
#define SITE_MATRIX_LOAD_TEX 1
#else
#define SITE_MATRIX_LOAD_TEX 0
#endif
// mu-links and fat links use the texture path on every target.
#define MULINK_LOAD_TEX 1
#define FATLINK_LOAD_TEX 1

// Threads per block for the kernel launches in this file.
#define BLOCK_DIM 64
24 
// Store FAT0..FAT8 (names resolved at the expansion site) as the nine
// components of direction `dir` at index `idx` of `gauge`.
// Component k of direction d lives at (idx) + (d*9 + k) * fl.fat_ga_stride.
// Every argument is parenthesised so that compound arguments index correctly,
// e.g. `odd_bit ? a : b` for `gauge` or `d + 1` for `dir`.
#define WRITE_FAT_MATRIX(gauge, dir, idx) do {                        \
    (gauge)[(idx) + ((dir)*9 + 0) * fl.fat_ga_stride] = FAT0;         \
    (gauge)[(idx) + ((dir)*9 + 1) * fl.fat_ga_stride] = FAT1;         \
    (gauge)[(idx) + ((dir)*9 + 2) * fl.fat_ga_stride] = FAT2;         \
    (gauge)[(idx) + ((dir)*9 + 3) * fl.fat_ga_stride] = FAT3;         \
    (gauge)[(idx) + ((dir)*9 + 4) * fl.fat_ga_stride] = FAT4;         \
    (gauge)[(idx) + ((dir)*9 + 5) * fl.fat_ga_stride] = FAT5;         \
    (gauge)[(idx) + ((dir)*9 + 6) * fl.fat_ga_stride] = FAT6;         \
    (gauge)[(idx) + ((dir)*9 + 7) * fl.fat_ga_stride] = FAT7;         \
    (gauge)[(idx) + ((dir)*9 + 8) * fl.fat_ga_stride] = FAT8;         \
  } while (0)
35 
36 
// Store STAPLE0..STAPLE8 (names resolved at the expansion site) to `gauge`
// at index `idx`, with component k at (idx) + k * fl.staple_stride.
// The stores are grouped in braces so `if (c) WRITE_STAPLE_MATRIX(g, i);`
// guards all nine of them. Braces (not do/while) keep call sites that omit the
// trailing semicolon compiling, as they did before.
#define WRITE_STAPLE_MATRIX(gauge, idx) {                   \
    (gauge)[(idx)]                      = STAPLE0;           \
    (gauge)[(idx) + fl.staple_stride]   = STAPLE1;           \
    (gauge)[(idx) + 2*fl.staple_stride] = STAPLE2;           \
    (gauge)[(idx) + 3*fl.staple_stride] = STAPLE3;           \
    (gauge)[(idx) + 4*fl.staple_stride] = STAPLE4;           \
    (gauge)[(idx) + 5*fl.staple_stride] = STAPLE5;           \
    (gauge)[(idx) + 6*fl.staple_stride] = STAPLE6;           \
    (gauge)[(idx) + 7*fl.staple_stride] = STAPLE7;           \
    (gauge)[(idx) + 8*fl.staple_stride] = STAPLE8;           \
  }
47 
48 
// c = a * b for a scalar `a` and a 3x3 complex matrix whose eighteen components
// are the scalars b##ij_re / b##ij_im (written to c##ij_re / c##ij_im).
// `a` is parenthesised so an expression such as `x + y` scales as a whole.
// The statements are braced into a single statement. Each element is
// independent, so c may alias b.
#define SCALAR_MULT_SU3_MATRIX(a, b, c) { \
  c##00_re = (a)*b##00_re; \
  c##00_im = (a)*b##00_im; \
  c##01_re = (a)*b##01_re; \
  c##01_im = (a)*b##01_im; \
  c##02_re = (a)*b##02_re; \
  c##02_im = (a)*b##02_im; \
  c##10_re = (a)*b##10_re; \
  c##10_im = (a)*b##10_im; \
  c##11_re = (a)*b##11_re; \
  c##11_im = (a)*b##11_im; \
  c##12_re = (a)*b##12_re; \
  c##12_im = (a)*b##12_im; \
  c##20_re = (a)*b##20_re; \
  c##20_im = (a)*b##20_im; \
  c##21_re = (a)*b##21_re; \
  c##21_im = (a)*b##21_im; \
  c##22_re = (a)*b##22_re; \
  c##22_im = (a)*b##22_im; \
}
68 
69  /*
70  #define LOAD_MATRIX_12_SINGLE_DECLARE(gauge, dir, idx, var, stride) \
71  float2 var##0 = gauge[idx + dir*6*stride]; \
72  float2 var##1 = gauge[idx + dir*6*stride + stride]; \
73  float2 var##2 = gauge[idx + dir*6*stride + 2*stride]; \
74  float2 var##3 = gauge[idx + dir*6*stride + 3*stride]; \
75  float2 var##4 = gauge[idx + dir*6*stride + 4*stride]; \
76  float2 var##5 = gauge[idx + dir*6*stride + 5*stride]; \
77  float2 var##6, var##7, var##8;
78 
79  #define LOAD_MATRIX_12_SINGLE_TEX_DECLARE(gauge, dir, idx, var, stride) \
80  float2 var##0 = tex1Dfetch(gauge, idx + dir*6*stride); \
81  float2 var##1 = tex1Dfetch(gauge, idx + dir*6*stride + stride); \
82  float2 var##2 = tex1Dfetch(gauge, idx + dir*6*stride + 2*stride); \
83  float2 var##3 = tex1Dfetch(gauge, idx + dir*6*stride + 3*stride); \
84  float2 var##4 = tex1Dfetch(gauge, idx + dir*6*stride + 4*stride); \
85  float2 var##5 = tex1Dfetch(gauge, idx + dir*6*stride + 5*stride); \
86  float2 var##6, var##7, var##8;
87  */
// Declare var0..var2 and load the 12-real single-precision link (three float4)
// for direction `dir` at index `idx`. Component k is read from
// (idx) + (dir)*3*(stride) + k*(stride).
// var3 and var4 are declared but left uninitialised; presumably the
// 12-reconstruct step fills them (TODO confirm).
// Arguments are parenthesised so compound expressions expand correctly.
#define LOAD_MATRIX_12_SINGLE_DECLARE(gauge, dir, idx, var, stride)        \
  float4 var##0 = (gauge)[(idx) + (dir)*3*(stride)];                       \
  float4 var##1 = (gauge)[(idx) + (dir)*3*(stride) + (stride)];            \
  float4 var##2 = (gauge)[(idx) + (dir)*3*(stride) + 2*(stride)];          \
  float4 var##3, var##4;

// Texture-fetch variant: `gauge` is the texture reference handed to tex1Dfetch.
#define LOAD_MATRIX_12_SINGLE_TEX_DECLARE(gauge, dir, idx, var, stride)    \
  float4 var##0 = tex1Dfetch(gauge, (idx) + (dir)*3*(stride));             \
  float4 var##1 = tex1Dfetch(gauge, (idx) + (dir)*3*(stride) + (stride));  \
  float4 var##2 = tex1Dfetch(gauge, (idx) + (dir)*3*(stride) + 2*(stride));\
  float4 var##3, var##4;
99 
// Declare var0..var8 and load the full 18-real single-precision link
// (nine float2) for direction `dir` at index `idx`. Component k is read from
// (idx) + (dir)*9*(stride) + k*(stride).
// Arguments are parenthesised so compound expressions expand correctly.
#define LOAD_MATRIX_18_SINGLE_DECLARE(gauge, dir, idx, var, stride)        \
  float2 var##0 = (gauge)[(idx) + (dir)*9*(stride)];                       \
  float2 var##1 = (gauge)[(idx) + (dir)*9*(stride) + (stride)];            \
  float2 var##2 = (gauge)[(idx) + (dir)*9*(stride) + 2*(stride)];          \
  float2 var##3 = (gauge)[(idx) + (dir)*9*(stride) + 3*(stride)];          \
  float2 var##4 = (gauge)[(idx) + (dir)*9*(stride) + 4*(stride)];          \
  float2 var##5 = (gauge)[(idx) + (dir)*9*(stride) + 5*(stride)];          \
  float2 var##6 = (gauge)[(idx) + (dir)*9*(stride) + 6*(stride)];          \
  float2 var##7 = (gauge)[(idx) + (dir)*9*(stride) + 7*(stride)];          \
  float2 var##8 = (gauge)[(idx) + (dir)*9*(stride) + 8*(stride)];

// Texture-fetch variant: `gauge` is the texture reference handed to tex1Dfetch.
#define LOAD_MATRIX_18_SINGLE_TEX_DECLARE(gauge, dir, idx, var, stride)    \
  float2 var##0 = tex1Dfetch(gauge, (idx) + (dir)*9*(stride));             \
  float2 var##1 = tex1Dfetch(gauge, (idx) + (dir)*9*(stride) + (stride));  \
  float2 var##2 = tex1Dfetch(gauge, (idx) + (dir)*9*(stride) + 2*(stride));\
  float2 var##3 = tex1Dfetch(gauge, (idx) + (dir)*9*(stride) + 3*(stride));\
  float2 var##4 = tex1Dfetch(gauge, (idx) + (dir)*9*(stride) + 4*(stride));\
  float2 var##5 = tex1Dfetch(gauge, (idx) + (dir)*9*(stride) + 5*(stride));\
  float2 var##6 = tex1Dfetch(gauge, (idx) + (dir)*9*(stride) + 6*(stride));\
  float2 var##7 = tex1Dfetch(gauge, (idx) + (dir)*9*(stride) + 7*(stride));\
  float2 var##8 = tex1Dfetch(gauge, (idx) + (dir)*9*(stride) + 8*(stride));
// Declare var0..var8 and load the full 18-real double-precision link
// (nine double2) for direction `dir` at index `idx`. Component k is read from
// (idx) + (dir)*9*(stride) + k*(stride).
// Arguments are parenthesised so compound expressions expand correctly.
#define LOAD_MATRIX_18_DOUBLE_DECLARE(gauge, dir, idx, var, stride)        \
  double2 var##0 = (gauge)[(idx) + (dir)*9*(stride)];                      \
  double2 var##1 = (gauge)[(idx) + (dir)*9*(stride) + (stride)];           \
  double2 var##2 = (gauge)[(idx) + (dir)*9*(stride) + 2*(stride)];         \
  double2 var##3 = (gauge)[(idx) + (dir)*9*(stride) + 3*(stride)];         \
  double2 var##4 = (gauge)[(idx) + (dir)*9*(stride) + 4*(stride)];         \
  double2 var##5 = (gauge)[(idx) + (dir)*9*(stride) + 5*(stride)];         \
  double2 var##6 = (gauge)[(idx) + (dir)*9*(stride) + 6*(stride)];         \
  double2 var##7 = (gauge)[(idx) + (dir)*9*(stride) + 7*(stride)];         \
  double2 var##8 = (gauge)[(idx) + (dir)*9*(stride) + 8*(stride)];

// Texture variant. READ_DOUBLE2_TEXTURE (defined elsewhere) receives both the
// texture reference and the raw field pointer.
#define LOAD_MATRIX_18_DOUBLE_TEX_DECLARE(gauge_tex, gauge, dir, idx, var, stride)          \
  double2 var##0 = READ_DOUBLE2_TEXTURE(gauge_tex, gauge, (idx) + (dir)*9*(stride));        \
  double2 var##1 = READ_DOUBLE2_TEXTURE(gauge_tex, gauge, (idx) + (dir)*9*(stride) + (stride)); \
  double2 var##2 = READ_DOUBLE2_TEXTURE(gauge_tex, gauge, (idx) + (dir)*9*(stride) + 2*(stride)); \
  double2 var##3 = READ_DOUBLE2_TEXTURE(gauge_tex, gauge, (idx) + (dir)*9*(stride) + 3*(stride)); \
  double2 var##4 = READ_DOUBLE2_TEXTURE(gauge_tex, gauge, (idx) + (dir)*9*(stride) + 4*(stride)); \
  double2 var##5 = READ_DOUBLE2_TEXTURE(gauge_tex, gauge, (idx) + (dir)*9*(stride) + 5*(stride)); \
  double2 var##6 = READ_DOUBLE2_TEXTURE(gauge_tex, gauge, (idx) + (dir)*9*(stride) + 6*(stride)); \
  double2 var##7 = READ_DOUBLE2_TEXTURE(gauge_tex, gauge, (idx) + (dir)*9*(stride) + 7*(stride)); \
  double2 var##8 = READ_DOUBLE2_TEXTURE(gauge_tex, gauge, (idx) + (dir)*9*(stride) + 8*(stride));
// Declare var0..var5 and load the 12-real double-precision link (six double2)
// for direction `dir` at index `idx`. Component k is read from
// (idx) + (dir)*6*(stride) + k*(stride).
// var6..var8 are declared but left uninitialised; presumably the
// 12-reconstruct step fills them (TODO confirm).
// Arguments are parenthesised so compound expressions expand correctly.
#define LOAD_MATRIX_12_DOUBLE_DECLARE(gauge, dir, idx, var, stride)        \
  double2 var##0 = (gauge)[(idx) + (dir)*6*(stride)];                      \
  double2 var##1 = (gauge)[(idx) + (dir)*6*(stride) + (stride)];           \
  double2 var##2 = (gauge)[(idx) + (dir)*6*(stride) + 2*(stride)];         \
  double2 var##3 = (gauge)[(idx) + (dir)*6*(stride) + 3*(stride)];         \
  double2 var##4 = (gauge)[(idx) + (dir)*6*(stride) + 4*(stride)];         \
  double2 var##5 = (gauge)[(idx) + (dir)*6*(stride) + 5*(stride)];         \
  double2 var##6, var##7, var##8;

// Texture variant. READ_DOUBLE2_TEXTURE (defined elsewhere) receives both the
// texture reference and the raw field pointer.
#define LOAD_MATRIX_12_DOUBLE_TEX_DECLARE(gauge_tex, gauge, dir, idx, var, stride)          \
  double2 var##0 = READ_DOUBLE2_TEXTURE(gauge_tex, gauge, (idx) + (dir)*6*(stride));        \
  double2 var##1 = READ_DOUBLE2_TEXTURE(gauge_tex, gauge, (idx) + (dir)*6*(stride) + (stride)); \
  double2 var##2 = READ_DOUBLE2_TEXTURE(gauge_tex, gauge, (idx) + (dir)*6*(stride) + 2*(stride)); \
  double2 var##3 = READ_DOUBLE2_TEXTURE(gauge_tex, gauge, (idx) + (dir)*6*(stride) + 3*(stride)); \
  double2 var##4 = READ_DOUBLE2_TEXTURE(gauge_tex, gauge, (idx) + (dir)*6*(stride) + 4*(stride)); \
  double2 var##5 = READ_DOUBLE2_TEXTURE(gauge_tex, gauge, (idx) + (dir)*6*(stride) + 5*(stride)); \
  double2 var##6, var##7, var##8;
// mc = ma + mb for 3x3 complex matrices stored as eighteen scalars named
// <prefix>ij_re / <prefix>ij_im (the prefix is token-pasted from the argument).
// Each element is computed independently, so mc may alias ma or mb.
// The statements are braced so the expansion is a single statement and an
// unbraced `if (c) LLFAT_ADD_SU3_MATRIX(...)` guards all eighteen assignments.
#define LLFAT_ADD_SU3_MATRIX(ma, mb, mc) { \
  mc##00_re = ma##00_re + mb##00_re; \
  mc##00_im = ma##00_im + mb##00_im; \
  mc##01_re = ma##01_re + mb##01_re; \
  mc##01_im = ma##01_im + mb##01_im; \
  mc##02_re = ma##02_re + mb##02_re; \
  mc##02_im = ma##02_im + mb##02_im; \
  mc##10_re = ma##10_re + mb##10_re; \
  mc##10_im = ma##10_im + mb##10_im; \
  mc##11_re = ma##11_re + mb##11_re; \
  mc##11_im = ma##11_im + mb##11_im; \
  mc##12_re = ma##12_re + mb##12_re; \
  mc##12_im = ma##12_im + mb##12_im; \
  mc##20_re = ma##20_re + mb##20_re; \
  mc##20_im = ma##20_im + mb##20_im; \
  mc##21_re = ma##21_re + mb##21_re; \
  mc##21_im = ma##21_im + mb##21_im; \
  mc##22_re = ma##22_re + mb##22_re; \
  mc##22_im = ma##22_im + mb##22_im; \
}
 // Direction lookup tables in constant memory, indexed nu*4 + mu. For each
 // ordered pair nu != mu they hold the two remaining directions, smaller first.
 // They are filled by the one-time initialisation routine below via
 // cudaMemcpyToSymbol.
 191  __constant__ int dir1_array[16];
 192  __constant__ int dir2_array[16];
// Byte count passed to cudaBindTexture by BIND_MU_LINK() for the mu-link textures.
// NOTE(review): initialised to 0 here and not assigned anywhere in the visible
// part of this file. Confirm it is set before BIND_MU_LINK() is expanded.
 194  unsigned long staple_bytes=0;
195 
 // One-time initialisation of the constant-memory data used by the
 // link-fattening kernels on the ordinary (non-extended) lattice.
 //
 // NOTE(review): the function's name/parameter line is missing from this
 // listing. `param` appears to be a gauge-parameter struct exposing X[4],
 // site_ga_pad, staple_pad and llfat_ga_pad; confirm against the header.
 // NOTE(review): the static flag turns every call after the first into a no-op.
 // A later call with different dimensions or pads will NOT refresh `fl` or the
 // direction tables. cudaMemcpyToSymbol return codes are not checked.
 196  void
 198  {
 199  static int llfat_init_cuda_flag = 0;
 200  if (llfat_init_cuda_flag){
 201  return;
 202  }
 203 
 204  llfat_init_cuda_flag = 1;
 205 
 // Checkerboard (half) volume of the lattice.
 206  int Vh = param->X[0]*param->X[1]*param->X[2]*param->X[3]/2;
 208 
 // Per-field strides between successive matrix components: half volume plus
 // that field's padding. fl_h is staged on the host and copied into the `fl`
 // constant symbol (its declaration is not visible in this listing).
 211  fl_h.site_ga_stride = param->site_ga_pad + Vh;
 212  fl_h.staple_stride = param->staple_pad + Vh;
 213  fl_h.fat_ga_stride = param->llfat_ga_pad + Vh;
 214  cudaMemcpyToSymbol(fl, &fl_h, sizeof(fat_force_const_t));
 // For every ordered pair nu != mu, record the two directions orthogonal to
 // both: d1 is the smallest such direction and d2 the next one (d1 < d2).
 // Entries with nu == mu are left uninitialised but are still copied.
 216  int dir1[16];
 217  int dir2[16];
 218  for(int nu =0; nu < 4; nu++)
 219  for(int mu=0; mu < 4; mu++){
 220  if(nu == mu) continue;
 221  int d1, d2;
 222  for(d1=0; d1 < 4; d1 ++){
 223  if(d1 != nu && d1 != mu){
 224  break;
 225  }
 226  }
 227  dir1[nu*4+mu] = d1;
 229  for(d2=0; d2 < 4; d2 ++){
 230  if(d2 != nu && d2 != mu && d2 != d1){
 231  break;
 232  }
 233  }
 234 
 235  dir2[nu*4+mu] = d2;
 236  }
 237 
 238  cudaMemcpyToSymbol(dir1_array, &dir1, sizeof(dir1));
 239  cudaMemcpyToSymbol(dir2_array, &dir2, sizeof(dir2));
 240 
 242  }
 // One-time initialisation of the `fl` constant symbol for the
 // extended-lattice ("_ex") kernels.
 //
 // NOTE(review): the name/parameter line is missing from this listing.
 // `param_ex` appears to describe the extended lattice, whose X[] includes the
 // halo. The interior half volume is computed from (X[i]-4) in every dimension,
 // which presumably means a 2-site halo on each side; confirm.
 // This routine has its own static flag, independent of the non-extended
 // initialiser above, and it does not fill dir1_array/dir2_array. Both routines
 // write the same `fl` symbol, so the one that ran most recently determines the
 // strides.
 245  void
 247  {
 248  static int llfat_init_cuda_flag = 0;
 249  if (llfat_init_cuda_flag){
 250  return;
 251  }
 253  llfat_init_cuda_flag = 1;
 254 
 // Half volume of the extended lattice and of its interior.
 255  int Vh_ex = param_ex->X[0]*param_ex->X[1]*param_ex->X[2]*param_ex->X[3]/2;
 256  int Vh = (param_ex->X[0]-4)*(param_ex->X[1]-4)*(param_ex->X[2]-4)*(param_ex->X[3]-4)/2;
 257 
 // Site-link and staple strides use the extended half volume; the fat-link
 // stride uses the interior half volume.
 259  fl_h.site_ga_stride = param_ex->site_ga_pad + Vh_ex;
 260  fl_h.staple_stride = param_ex->staple_pad + Vh_ex;
 261  fl_h.fat_ga_stride = param_ex->llfat_ga_pad + Vh;
 262  cudaMemcpyToSymbol(fl, &fl_h, sizeof(fat_force_const_t));
 263  }
// Kernel-name builders: LLFAT_KERNEL(foo, 12) -> foo12Kernel and
// LLFAT_KERNEL_EX(foo, 12) -> foo12Kernel_ex. The outer *_KERNEL* macros add
// one level of indirection so that macro arguments (for example RECONSTRUCT)
// are expanded before ## pastes them. They are presumably used inside
// llfat_core.h (not visible here).
268 #define LLFAT_CONCAT(a,b) a##b##Kernel
269 #define LLFAT_CONCAT_EX(a,b) a##b##Kernel_ex
270 #define LLFAT_KERNEL(a,b) LLFAT_CONCAT(a,b)
271 #define LLFAT_KERNEL_EX(a,b) LLFAT_CONCAT_EX(a,b)
272 
 273  //precision: 0 is for double, 1 is for single
 275  //single precision, common macro
// Single-precision settings shared by the 12- and 18-reconstruct
// instantiations of llfat_core.h that follow. The mu-link and fat-link loaders
// declare 18-component (nine float2) matrices. With *_LOAD_TEX == 1 they read
// through the single-precision textures; otherwise they read directly through
// the mulink_* / fatlink_* pointers.
// The "even" loader uses texture 1 when odd_bit is set and texture 0 otherwise,
// and the "odd" loader does the reverse, so even/odd is relative to the
// current parity. odd_bit must be in scope at the expansion site (presumably
// inside llfat_core.h).
276 #define PRECISION 1
277 #define Float float
278 #define LOAD_FAT_MATRIX(gauge, dir, idx) LOAD_MATRIX_18_SINGLE_DECLARE(gauge, dir, idx, FAT, fl.fat_ga_stride)
279 #if (MULINK_LOAD_TEX == 1)
280 #define LOAD_EVEN_MULINK_MATRIX(dir, idx, var) LOAD_MATRIX_18_SINGLE_TEX_DECLARE((odd_bit?muLink1TexSingle:muLink0TexSingle), dir, idx, var, fl.staple_stride)
281 #define LOAD_ODD_MULINK_MATRIX(dir, idx, var) LOAD_MATRIX_18_SINGLE_TEX_DECLARE((odd_bit?muLink0TexSingle:muLink1TexSingle), dir, idx, var, fl.staple_stride)
282 #else
283 #define LOAD_EVEN_MULINK_MATRIX(dir, idx, var) LOAD_MATRIX_18_SINGLE_DECLARE(mulink_even, dir, idx, var, fl.staple_stride)
284 #define LOAD_ODD_MULINK_MATRIX(dir, idx, var) LOAD_MATRIX_18_SINGLE_DECLARE(mulink_odd, dir, idx, var, fl.staple_stride)
285 #endif
// NOTE(review): the two texture-path FAT loaders below end in an extra ';' and
// the non-texture ones do not. That is harmless inside a function body, but
// inconsistent.
287 #if (FATLINK_LOAD_TEX == 1)
288 #define LOAD_EVEN_FAT_MATRIX(dir, idx) LOAD_MATRIX_18_SINGLE_TEX_DECLARE((odd_bit?fatGauge1TexSingle:fatGauge0TexSingle), dir, idx, FAT, fl.fat_ga_stride);
289 #define LOAD_ODD_FAT_MATRIX(dir, idx) LOAD_MATRIX_18_SINGLE_TEX_DECLARE((odd_bit?fatGauge0TexSingle:fatGauge1TexSingle), dir, idx, FAT, fl.fat_ga_stride);
290 #else
291 #define LOAD_EVEN_FAT_MATRIX(dir, idx) LOAD_MATRIX_18_SINGLE_DECLARE(fatlink_even, dir, idx, FAT, fl.fat_ga_stride)
292 #define LOAD_ODD_FAT_MATRIX(dir, idx) LOAD_MATRIX_18_SINGLE_DECLARE(fatlink_odd, dir, idx, FAT, fl.fat_ga_stride)
293 #endif
 296  //single precision, 12-reconstruct
// Instantiate llfat_core.h for single precision with 12-reconstruct site links.
// Site links are loaded as three float4 per direction from the *_recon textures
// (or directly from sitelink_* when SITE_MATRIX_LOAD_TEX is 0). They are then
// completed by RECONSTRUCT_LINK_12, which is defined elsewhere.
// DECLARE_VAR_SIGN supplies a `short sign` for that reconstruction
// (presumably expanded inside llfat_core.h; confirm).
// Every macro defined here is #undef'd right after the include so that the
// next instantiation can redefine it.
297 #define DECLARE_VAR_SIGN short sign=1
298 #define SITELINK0TEX siteLink0TexSingle_recon
299 #define SITELINK1TEX siteLink1TexSingle_recon
300 #if (SITE_MATRIX_LOAD_TEX == 1)
301 #define LOAD_EVEN_SITE_MATRIX(dir, idx, var) LOAD_MATRIX_12_SINGLE_TEX_DECLARE((odd_bit?SITELINK1TEX:SITELINK0TEX), dir, idx, var, fl.site_ga_stride)
302 #define LOAD_ODD_SITE_MATRIX(dir, idx, var) LOAD_MATRIX_12_SINGLE_TEX_DECLARE((odd_bit?SITELINK0TEX:SITELINK1TEX), dir, idx, var, fl.site_ga_stride)
303 #else
304 #define LOAD_EVEN_SITE_MATRIX(dir, idx, var) LOAD_MATRIX_12_SINGLE_DECLARE(sitelink_even, dir, idx, var, fl.site_ga_stride)
305 #define LOAD_ODD_SITE_MATRIX(dir, idx, var) LOAD_MATRIX_12_SINGLE_DECLARE(sitelink_odd, dir, idx, var, fl.site_ga_stride)
306 #endif
307 #define LOAD_SITE_MATRIX(sitelink, dir, idx, var) LOAD_MATRIX_12_SINGLE_DECLARE(sitelink, dir, idx, var, fl.site_ga_stride)
308 
// FloatN is the site-link element type (float4 for 12-reconstruct); FloatM is
// used for the 18-component fields. sd_data names this instantiation's tuning
// data (presumably; defined elsewhere).
309 #define RECONSTRUCT_SITE_LINK(sign, var) RECONSTRUCT_LINK_12(sign, var);
310 #define FloatN float4
311 #define FloatM float2
312 #define RECONSTRUCT 12
313 #define sd_data float_12_sd_data
314 #include "llfat_core.h"
315 #undef DECLARE_VAR_SIGN
316 #undef SITELINK0TEX
317 #undef SITELINK1TEX
318 #undef LOAD_EVEN_SITE_MATRIX
319 #undef LOAD_ODD_SITE_MATRIX
320 #undef LOAD_SITE_MATRIX
321 #undef RECONSTRUCT_SITE_LINK
322 #undef FloatN
323 #undef FloatM
324 #undef RECONSTRUCT
325 #undef sd_data
326 
 327  //single precision, 18-reconstruct
// Instantiate llfat_core.h for single precision with uncompressed (18-real)
// site links. They are loaded as nine float2 per direction from the *_norecon
// textures, or directly when SITE_MATRIX_LOAD_TEX is 0. No reconstruction is
// needed, so RECONSTRUCT_SITE_LINK expands to nothing.
// NOTE(review): DECLARE_VAR_SIGN is not defined for this instantiation (only the
// 12-reconstruct one defines it); confirm llfat_core.h does not require it here.
328 #define SITELINK0TEX siteLink0TexSingle_norecon
329 #define SITELINK1TEX siteLink1TexSingle_norecon
330 #if (SITE_MATRIX_LOAD_TEX == 1)
331 #define LOAD_EVEN_SITE_MATRIX(dir, idx, var) LOAD_MATRIX_18_SINGLE_TEX_DECLARE((odd_bit?SITELINK1TEX:SITELINK0TEX), dir, idx, var, fl.site_ga_stride)
332 #define LOAD_ODD_SITE_MATRIX(dir, idx, var) LOAD_MATRIX_18_SINGLE_TEX_DECLARE((odd_bit?SITELINK0TEX:SITELINK1TEX), dir, idx, var, fl.site_ga_stride)
333 #else
334 #define LOAD_EVEN_SITE_MATRIX(dir, idx, var) LOAD_MATRIX_18_SINGLE_DECLARE(sitelink_even, dir, idx, var, fl.site_ga_stride)
335 #define LOAD_ODD_SITE_MATRIX(dir, idx, var) LOAD_MATRIX_18_SINGLE_DECLARE(sitelink_odd, dir, idx, var, fl.site_ga_stride)
336 #endif
337 #define LOAD_SITE_MATRIX(sitelink, dir, idx, var) LOAD_MATRIX_18_SINGLE_DECLARE(sitelink, dir, idx, var, fl.site_ga_stride)
338 #define RECONSTRUCT_SITE_LINK(sign, var)
339 #define FloatN float2
340 #define FloatM float2
341 #define RECONSTRUCT 18
342 #define sd_data float_18_sd_data
343 #include "llfat_core.h"
344 #undef SITELINK0TEX
345 #undef SITELINK1TEX
346 #undef LOAD_EVEN_SITE_MATRIX
347 #undef LOAD_ODD_SITE_MATRIX
348 #undef LOAD_SITE_MATRIX
349 #undef RECONSTRUCT_SITE_LINK
350 #undef FloatN
351 #undef FloatM
352 #undef RECONSTRUCT
353 #undef sd_data
354 
355 
// Both single-precision instantiations are done: drop the shared
// single-precision settings before the double-precision ones are defined.
356 #undef PRECISION
357 #undef Float
358 #undef LOAD_FAT_MATRIX
359 #undef LOAD_EVEN_MULINK_MATRIX
360 #undef LOAD_ODD_MULINK_MATRIX
361 #undef LOAD_EVEN_FAT_MATRIX
362 #undef LOAD_ODD_FAT_MATRIX
363 
364 
365  //double precision, common macro
// Double-precision settings shared by the 18- and 12-reconstruct
// instantiations of llfat_core.h that follow. Unlike the single-precision
// loaders, the double-precision texture loaders also receive the raw field
// pointer, because READ_DOUBLE2_TEXTURE takes both.
#define PRECISION 0
#define Float double
#define LOAD_FAT_MATRIX(gauge, dir, idx) LOAD_MATRIX_18_DOUBLE_DECLARE(gauge, dir, idx, FAT, fl.fat_ga_stride)

// mu-link loaders. In the texture form the whole conditional must be
// parenthesised so it reaches LOAD_MATRIX_18_DOUBLE_TEX_DECLARE as a single
// argument; the EVEN loader was missing its opening parenthesis. The direct form
// must declare var0..var8 like every other loader, hence *_DECLARE.
#if (MULINK_LOAD_TEX == 1)
#define LOAD_EVEN_MULINK_MATRIX(dir, idx, var) LOAD_MATRIX_18_DOUBLE_TEX_DECLARE((odd_bit?muLink1TexDouble:muLink0TexDouble), mulink_even, dir, idx, var, fl.staple_stride)
#define LOAD_ODD_MULINK_MATRIX(dir, idx, var) LOAD_MATRIX_18_DOUBLE_TEX_DECLARE((odd_bit?muLink0TexDouble:muLink1TexDouble), mulink_odd, dir, idx, var, fl.staple_stride)
#else
#define LOAD_EVEN_MULINK_MATRIX(dir, idx, var) LOAD_MATRIX_18_DOUBLE_DECLARE(mulink_even, dir, idx, var, fl.staple_stride)
#define LOAD_ODD_MULINK_MATRIX(dir, idx, var) LOAD_MATRIX_18_DOUBLE_DECLARE(mulink_odd, dir, idx, var, fl.staple_stride)
#endif

// Fat-link loaders. "Even" uses texture 1 when odd_bit is set and texture 0
// otherwise ("odd" does the reverse).
#if (FATLINK_LOAD_TEX == 1)
#define LOAD_EVEN_FAT_MATRIX(dir, idx) LOAD_MATRIX_18_DOUBLE_TEX_DECLARE((odd_bit?fatGauge1TexDouble:fatGauge0TexDouble), fatlink_even, dir, idx, FAT, fl.fat_ga_stride)
#define LOAD_ODD_FAT_MATRIX(dir, idx) LOAD_MATRIX_18_DOUBLE_TEX_DECLARE((odd_bit?fatGauge0TexDouble:fatGauge1TexDouble), fatlink_odd, dir, idx, FAT, fl.fat_ga_stride)
#else
#define LOAD_EVEN_FAT_MATRIX(dir, idx) LOAD_MATRIX_18_DOUBLE_DECLARE(fatlink_even, dir, idx, FAT, fl.fat_ga_stride)
#define LOAD_ODD_FAT_MATRIX(dir, idx) LOAD_MATRIX_18_DOUBLE_DECLARE(fatlink_odd, dir, idx, FAT, fl.fat_ga_stride)
#endif
384 
 385  //double precision, 18-reconstruct
// Instantiate llfat_core.h for double precision with uncompressed (18-real)
// site links: nine double2 per direction, read through the siteLink*TexDouble
// textures via READ_DOUBLE2_TEXTURE (or directly when SITE_MATRIX_LOAD_TEX is 0).
// No reconstruction is needed, so RECONSTRUCT_SITE_LINK expands to nothing.
386 #define SITELINK0TEX siteLink0TexDouble
387 #define SITELINK1TEX siteLink1TexDouble
388 #if (SITE_MATRIX_LOAD_TEX == 1)
389 #define LOAD_EVEN_SITE_MATRIX(dir, idx, var) LOAD_MATRIX_18_DOUBLE_TEX_DECLARE((odd_bit?SITELINK1TEX:SITELINK0TEX), sitelink_even, dir, idx, var, fl.site_ga_stride)
390 #define LOAD_ODD_SITE_MATRIX(dir, idx, var) LOAD_MATRIX_18_DOUBLE_TEX_DECLARE((odd_bit?SITELINK0TEX:SITELINK1TEX), sitelink_odd, dir, idx, var, fl.site_ga_stride)
391 #else
392 #define LOAD_EVEN_SITE_MATRIX(dir, idx, var) LOAD_MATRIX_18_DOUBLE_DECLARE(sitelink_even, dir, idx, var, fl.site_ga_stride)
393 #define LOAD_ODD_SITE_MATRIX(dir, idx, var) LOAD_MATRIX_18_DOUBLE_DECLARE(sitelink_odd, dir, idx, var, fl.site_ga_stride)
394 #endif
395 #define LOAD_SITE_MATRIX(sitelink, dir, idx, var) LOAD_MATRIX_18_DOUBLE_DECLARE(sitelink, dir, idx, var, fl.site_ga_stride)
396 #define RECONSTRUCT_SITE_LINK(sign, var)
397 #define FloatN double2
398 #define FloatM double2
399 #define RECONSTRUCT 18
400 #define sd_data double_18_sd_data
401 #include "llfat_core.h"
402 #undef SITELINK0TEX
403 #undef SITELINK1TEX
404 #undef LOAD_EVEN_SITE_MATRIX
405 #undef LOAD_ODD_SITE_MATRIX
406 #undef LOAD_SITE_MATRIX
407 #undef RECONSTRUCT_SITE_LINK
408 #undef FloatN
409 #undef FloatM
410 #undef RECONSTRUCT
411 #undef sd_data
412 
414 
// Double-precision, 12-reconstruct instantiation (the `#if 1` keeps it always
// enabled). Site links are six double2 per direction and use the same
// siteLink*TexDouble textures as the 18-reconstruct instantiation. They are
// completed by RECONSTRUCT_LINK_12, which is defined elsewhere.
// NOTE(review): unlike the single-precision 12-reconstruct instantiation,
// DECLARE_VAR_SIGN is not defined here; confirm llfat_core.h does not need it.
415 #if 1
 416  //double precision, 12-reconstruct
417 #define SITELINK0TEX siteLink0TexDouble
418 #define SITELINK1TEX siteLink1TexDouble
419 #if (SITE_MATRIX_LOAD_TEX == 1)
420 #define LOAD_EVEN_SITE_MATRIX(dir, idx, var) LOAD_MATRIX_12_DOUBLE_TEX_DECLARE((odd_bit?SITELINK1TEX:SITELINK0TEX), sitelink_even, dir, idx, var, fl.site_ga_stride)
421 #define LOAD_ODD_SITE_MATRIX(dir, idx, var) LOAD_MATRIX_12_DOUBLE_TEX_DECLARE((odd_bit?SITELINK0TEX:SITELINK1TEX), sitelink_odd, dir, idx, var, fl.site_ga_stride)
422 #else
423 #define LOAD_EVEN_SITE_MATRIX(dir, idx, var) LOAD_MATRIX_12_DOUBLE_DECLARE(sitelink_even, dir, idx, var, fl.site_ga_stride)
424 #define LOAD_ODD_SITE_MATRIX(dir, idx, var) LOAD_MATRIX_12_DOUBLE_DECLARE(sitelink_odd, dir, idx, var, fl.site_ga_stride)
425 #endif
426 #define LOAD_SITE_MATRIX(sitelink, dir, idx, var) LOAD_MATRIX_12_DOUBLE_DECLARE(sitelink, dir, idx, var, fl.site_ga_stride)
427 #define RECONSTRUCT_SITE_LINK(sign, var) RECONSTRUCT_LINK_12(sign, var);
428 #define FloatN double2
429 #define FloatM double2
430 #define RECONSTRUCT 12
431 #define sd_data double_12_sd_data
432 #include "llfat_core.h"
433 #undef SITELINK0TEX
434 #undef SITELINK1TEX
435 #undef LOAD_EVEN_SITE_MATRIX
436 #undef LOAD_ODD_SITE_MATRIX
437 #undef LOAD_SITE_MATRIX
438 #undef RECONSTRUCT_SITE_LINK
439 #undef FloatN
440 #undef FloatM
441 #undef RECONSTRUCT
442 #undef sd_data
443 #endif
444 
// All instantiations are done: drop the double-precision settings and the
// kernel-name builders.
// NOTE(review): LLFAT_CONCAT_EX / LLFAT_KERNEL_EX are not #undef'd here.
445 #undef PRECISION
446 #undef Float
447 #undef LOAD_FAT_MATRIX
448 #undef LOAD_EVEN_MULINK_MATRIX
449 #undef LOAD_ODD_MULINK_MATRIX
450 #undef LOAD_EVEN_FAT_MATRIX
451 #undef LOAD_ODD_FAT_MATRIX
452 
453 #undef LLFAT_CONCAT
454 #undef LLFAT_KERNEL
455 
// Unbind every texture reference used by the link-fattening kernels for the
// precision `prec`. In single precision the site-link texture pair is chosen
// from cudaSiteLink.Reconstruct(), the same accessor BIND_SITE_AND_FAT_LINK
// uses when binding.
// Expects `prec` and `cudaSiteLink` in scope at the expansion site.
#define UNBIND_ALL_TEXTURE do{ \
  if(prec == QUDA_DOUBLE_PRECISION){ \
    cudaUnbindTexture(siteLink0TexDouble); \
    cudaUnbindTexture(siteLink1TexDouble); \
    cudaUnbindTexture(fatGauge0TexDouble); \
    cudaUnbindTexture(fatGauge1TexDouble); \
    cudaUnbindTexture(muLink0TexDouble); \
    cudaUnbindTexture(muLink1TexDouble); \
  }else{ \
    if(cudaSiteLink.Reconstruct() == QUDA_RECONSTRUCT_NO){ \
      cudaUnbindTexture(siteLink0TexSingle_norecon); \
      cudaUnbindTexture(siteLink1TexSingle_norecon); \
    }else{ \
      cudaUnbindTexture(siteLink0TexSingle_recon); \
      cudaUnbindTexture(siteLink1TexSingle_recon); \
    } \
    cudaUnbindTexture(fatGauge0TexSingle); \
    cudaUnbindTexture(fatGauge1TexSingle); \
    cudaUnbindTexture(muLink0TexSingle); \
    cudaUnbindTexture(muLink1TexSingle); \
  } \
}while(0)
478 
// Unbind the site-link and fat-link textures for precision `prec`; the mu-link
// textures are left bound. In single precision the site-link texture pair is
// chosen from cudaSiteLink.Reconstruct(), the same accessor
// BIND_SITE_AND_FAT_LINK uses when binding.
#define UNBIND_SITE_AND_FAT_LINK do{ \
  if(prec == QUDA_DOUBLE_PRECISION){ \
    cudaUnbindTexture(siteLink0TexDouble); \
    cudaUnbindTexture(siteLink1TexDouble); \
    cudaUnbindTexture(fatGauge0TexDouble); \
    cudaUnbindTexture(fatGauge1TexDouble); \
  }else { \
    if(cudaSiteLink.Reconstruct() == QUDA_RECONSTRUCT_NO){ \
      cudaUnbindTexture(siteLink0TexSingle_norecon); \
      cudaUnbindTexture(siteLink1TexSingle_norecon); \
    }else{ \
      cudaUnbindTexture(siteLink0TexSingle_recon); \
      cudaUnbindTexture(siteLink1TexSingle_recon); \
    } \
    cudaUnbindTexture(fatGauge0TexSingle); \
    cudaUnbindTexture(fatGauge1TexSingle); \
  } \
}while(0)
497 
498 
// Bind the mu-link textures for precision `prec`: texture 0 to mulink_even and
// texture 1 to mulink_odd, each covering staple_bytes bytes.
// Expects `prec`, `mulink_even` and `mulink_odd` in scope at the expansion site.
#define BIND_MU_LINK() do{ \
  if(prec == QUDA_DOUBLE_PRECISION){ \
    cudaBindTexture(0, muLink0TexDouble, mulink_even, staple_bytes); \
    cudaBindTexture(0, muLink1TexDouble, mulink_odd, staple_bytes); \
  }else{ \
    cudaBindTexture(0, muLink0TexSingle, mulink_even, staple_bytes); \
    cudaBindTexture(0, muLink1TexSingle, mulink_odd, staple_bytes); \
  } \
}while(0)

// Unbind the mu-link textures that BIND_MU_LINK() bound. Each precision branch
// releases the textures of that same precision; previously the double branch
// released the single-precision textures and vice versa, leaving the bound
// ones attached.
#define UNBIND_MU_LINK() do{ \
  if(prec == QUDA_DOUBLE_PRECISION){ \
    cudaUnbindTexture(muLink0TexDouble); \
    cudaUnbindTexture(muLink1TexDouble); \
  }else{ \
    cudaUnbindTexture(muLink0TexSingle); \
    cudaUnbindTexture(muLink1TexSingle); \
  } \
}while(0)

// Bind the site-link and fat-link textures for precision `prec`: texture 0 to
// the even half and texture 1 to the odd half, using the field accessors. In
// single precision the site-link texture pair depends on whether cudaSiteLink
// is stored without reconstruction (QUDA_RECONSTRUCT_NO) or not.
#define BIND_SITE_AND_FAT_LINK do { \
  if(prec == QUDA_DOUBLE_PRECISION){ \
    cudaBindTexture(0, siteLink0TexDouble, cudaSiteLink.Even_p(), cudaSiteLink.Bytes()); \
    cudaBindTexture(0, siteLink1TexDouble, cudaSiteLink.Odd_p(), cudaSiteLink.Bytes()); \
    cudaBindTexture(0, fatGauge0TexDouble, cudaFatLink.Even_p(), cudaFatLink.Bytes()); \
    cudaBindTexture(0, fatGauge1TexDouble, cudaFatLink.Odd_p(), cudaFatLink.Bytes()); \
  }else{ \
    if(cudaSiteLink.Reconstruct() == QUDA_RECONSTRUCT_NO){ \
      cudaBindTexture(0, siteLink0TexSingle_norecon, cudaSiteLink.Even_p(), cudaSiteLink.Bytes()); \
      cudaBindTexture(0, siteLink1TexSingle_norecon, cudaSiteLink.Odd_p(), cudaSiteLink.Bytes()); \
    }else{ \
      cudaBindTexture(0, siteLink0TexSingle_recon, cudaSiteLink.Even_p(), cudaSiteLink.Bytes()); \
      cudaBindTexture(0, siteLink1TexSingle_recon, cudaSiteLink.Odd_p(), cudaSiteLink.Bytes()); \
    } \
    cudaBindTexture(0, fatGauge0TexSingle, cudaFatLink.Even_p(), cudaFatLink.Bytes()); \
    cudaBindTexture(0, fatGauge1TexSingle, cudaFatLink.Odd_p(), cudaFatLink.Bytes()); \
  } \
}while(0)
558 
// Same as BIND_SITE_AND_FAT_LINK but with the parities swapped: texture 1 is
// bound to the even half and texture 0 to the odd half. It uses the same field
// accessors (Even_p/Odd_p/Bytes/Reconstruct) as BIND_SITE_AND_FAT_LINK rather
// than reaching into the fields' data members.
#define BIND_SITE_AND_FAT_LINK_REVERSE do { \
  if(prec == QUDA_DOUBLE_PRECISION){ \
    cudaBindTexture(0, siteLink1TexDouble, cudaSiteLink.Even_p(), cudaSiteLink.Bytes()); \
    cudaBindTexture(0, siteLink0TexDouble, cudaSiteLink.Odd_p(), cudaSiteLink.Bytes()); \
    cudaBindTexture(0, fatGauge1TexDouble, cudaFatLink.Even_p(), cudaFatLink.Bytes()); \
    cudaBindTexture(0, fatGauge0TexDouble, cudaFatLink.Odd_p(), cudaFatLink.Bytes()); \
  }else{ \
    if(cudaSiteLink.Reconstruct() == QUDA_RECONSTRUCT_NO){ \
      cudaBindTexture(0, siteLink1TexSingle_norecon, cudaSiteLink.Even_p(), cudaSiteLink.Bytes()); \
      cudaBindTexture(0, siteLink0TexSingle_norecon, cudaSiteLink.Odd_p(), cudaSiteLink.Bytes()); \
    }else{ \
      cudaBindTexture(0, siteLink1TexSingle_recon, cudaSiteLink.Even_p(), cudaSiteLink.Bytes()); \
      cudaBindTexture(0, siteLink0TexSingle_recon, cudaSiteLink.Odd_p(), cudaSiteLink.Bytes()); \
    } \
    cudaBindTexture(0, fatGauge1TexSingle, cudaFatLink.Even_p(), cudaFatLink.Bytes()); \
    cudaBindTexture(0, fatGauge0TexSingle, cudaFatLink.Odd_p(), cudaFatLink.Bytes()); \
  } \
}while(0)
577 
578 
579 
// Map the runtime direction pair (mu, nu) onto CALL_FUNCTION(MU, NU) with
// literal constants, so that the caller-defined CALL_FUNCTION can use them as
// template arguments. The pair is folded into the key 4*mu + nu:
//   - mu == nu (keys 0, 5, 10, 15): print an error and exit(1);
//   - any other pair with 0 <= mu, nu < 4: expand CALL_FUNCTION once;
//   - pairs outside that range: do nothing.
// CALL_FUNCTION must be #defined at the expansion site. The arguments are
// evaluated more than once, so pass plain variables.
#define ENUMERATE_FUNCS(mu,nu) { \
  if (0 <= (mu) && (mu) < 4 && 0 <= (nu) && (nu) < 4) { \
    switch (4*(mu) + (nu)) { \
      case 0: case 5: case 10: case 15: \
        printf("ERROR: invalid direction combination\n"); exit(1); \
        break; \
      case 1:  CALL_FUNCTION(0,1); break; \
      case 2:  CALL_FUNCTION(0,2); break; \
      case 3:  CALL_FUNCTION(0,3); break; \
      case 4:  CALL_FUNCTION(1,0); break; \
      case 6:  CALL_FUNCTION(1,2); break; \
      case 7:  CALL_FUNCTION(1,3); break; \
      case 8:  CALL_FUNCTION(2,0); break; \
      case 9:  CALL_FUNCTION(2,1); break; \
      case 11: CALL_FUNCTION(2,3); break; \
      case 12: CALL_FUNCTION(3,0); break; \
      case 13: CALL_FUNCTION(3,1); break; \
      case 14: CALL_FUNCTION(3,2); break; \
    } \
  } \
}
646 
// Helper for ENUMERATE_FUNCS_SAVE. It dispatches (mu, nu) exactly like
// ENUMERATE_FUNCS, but expands CALL_FUNCTION(MU, NU, SAVE) with the literal
// staple-saving flag SAVE. Diagonal pairs print an error and exit(1);
// out-of-range pairs do nothing.
#define LLFAT_ENUMERATE_PAIRS_SAVE_(mu, nu, SAVE) \
  if (0 <= (mu) && (mu) < 4 && 0 <= (nu) && (nu) < 4) { \
    switch (4*(mu) + (nu)) { \
      case 0: case 5: case 10: case 15: \
        printf("ERROR: invalid direction combination\n"); exit(1); \
        break; \
      case 1:  CALL_FUNCTION(0,1,SAVE); break; \
      case 2:  CALL_FUNCTION(0,2,SAVE); break; \
      case 3:  CALL_FUNCTION(0,3,SAVE); break; \
      case 4:  CALL_FUNCTION(1,0,SAVE); break; \
      case 6:  CALL_FUNCTION(1,2,SAVE); break; \
      case 7:  CALL_FUNCTION(1,3,SAVE); break; \
      case 8:  CALL_FUNCTION(2,0,SAVE); break; \
      case 9:  CALL_FUNCTION(2,1,SAVE); break; \
      case 11: CALL_FUNCTION(2,3,SAVE); break; \
      case 12: CALL_FUNCTION(3,0,SAVE); break; \
      case 13: CALL_FUNCTION(3,1,SAVE); break; \
      case 14: CALL_FUNCTION(3,2,SAVE); break; \
    } \
  }

// Like ENUMERATE_FUNCS, but also lifts the runtime flag save_staple into a
// literal third argument (1 when non-zero, otherwise 0) of CALL_FUNCTION,
// which must be #defined at the expansion site.
#define ENUMERATE_FUNCS_SAVE(mu,nu, save_staple) if(save_staple){ \
  LLFAT_ENUMERATE_PAIRS_SAVE_(mu, nu, 1) \
}else{ \
  LLFAT_ENUMERATE_PAIRS_SAVE_(mu, nu, 0) \
}
 // Launch the site "generic staple" kernels for one direction pair (mu, nu).
 //
 // Two launches are queued on *stream, one per parity, each with grid
 // halfGridDim, BLOCK_DIM threads per block and no dynamic shared memory.
 // The first uses template parity flag 0 with the pointers in (even, odd)
 // order; the second uses flag 1 with the pointers swapped. The stream is not
 // synchronised here, so the work may still be running when this returns.
 //
 // Kernel selection:
 //   prec == QUDA_DOUBLE_PRECISION -> double2 pointers, coefficient cast to double;
 //   otherwise                     -> float2 pointers (float4 for 12-reconstruct
 //                                    site links), coefficient cast to float;
 //   recon == QUDA_RECONSTRUCT_NO  -> *18Kernel, anything else -> *12Kernel.
 //
 // NOTE(review): `prec` and `recon` are used below, but their declarations are
 // not visible in this listing (presumably the elided parameter line); confirm.
 // NOTE(review): mu == nu makes ENUMERATE_FUNCS print an error and exit(1).
 // Launch errors are not checked here (no cudaGetLastError()).
 783  void siteComputeGenStapleParityKernel(void* staple_even, void* staple_odd,
 784  const void* sitelink_even, const void* sitelink_odd,
 785  void* fatlink_even, void* fatlink_odd,
 786  int mu, int nu, double mycoeff,
 788  dim3 halfGridDim, llfat_kernel_param_t kparam,
 789  cudaStream_t* stream)
 790  {
 791 
 792  // Compute the even and the odd parity with one launch each.
 793 
// CALL_FUNCTION(mu, nu) is expanded by ENUMERATE_FUNCS with literal directions,
// so mu and nu can be template arguments. It relies on the local `blockDim`
// declared just before ENUMERATE_FUNCS below. (Host-side local; it does not
// touch the device built-in of the same name.)
794 #define CALL_FUNCTION(mu, nu) \
 795  if (prec == QUDA_DOUBLE_PRECISION){ \
 796  if(recon == QUDA_RECONSTRUCT_NO){ \
 797  do_siteComputeGenStapleParity18Kernel<mu,nu, 0> \
 798  <<<halfGridDim, blockDim, 0, *stream>>>((double2*)staple_even, (double2*)staple_odd, \
 799  (const double2*)sitelink_even, (const double2*)sitelink_odd, \
 800  (double2*)fatlink_even, (double2*)fatlink_odd, \
 801  (double)mycoeff, kparam); \
 802  do_siteComputeGenStapleParity18Kernel<mu,nu, 1> \
 803  <<<halfGridDim, blockDim, 0, *stream>>>((double2*)staple_odd, (double2*)staple_even, \
 804  (const double2*)sitelink_odd, (const double2*)sitelink_even, \
 805  (double2*)fatlink_odd, (double2*)fatlink_even, \
 806  (double)mycoeff, kparam); \
 807  }else{ \
 808  do_siteComputeGenStapleParity12Kernel<mu,nu, 0> \
 809  <<<halfGridDim, blockDim, 0, *stream>>>((double2*)staple_even, (double2*)staple_odd, \
 810  (const double2*)sitelink_even, (const double2*)sitelink_odd, \
 811  (double2*)fatlink_even, (double2*)fatlink_odd, \
 812  (double)mycoeff, kparam); \
 813  do_siteComputeGenStapleParity12Kernel<mu,nu, 1> \
 814  <<<halfGridDim, blockDim, 0, *stream>>>((double2*)staple_odd, (double2*)staple_even, \
 815  (const double2*)sitelink_odd, (const double2*)sitelink_even, \
 816  (double2*)fatlink_odd, (double2*)fatlink_even, \
 817  (double)mycoeff, kparam); \
 818  } \
 819  }else { \
 820  if(recon == QUDA_RECONSTRUCT_NO){ \
 821  do_siteComputeGenStapleParity18Kernel<mu,nu, 0> \
 822  <<<halfGridDim, blockDim, 0, *stream>>>((float2*)staple_even, (float2*)staple_odd, \
 823  (const float2*)sitelink_even, (const float2*)sitelink_odd, \
 824  (float2*)fatlink_even, (float2*)fatlink_odd, \
 825  (float)mycoeff, kparam); \
 826  do_siteComputeGenStapleParity18Kernel<mu,nu, 1> \
 827  <<<halfGridDim, blockDim, 0, *stream>>>((float2*)staple_odd, (float2*)staple_even, \
 828  (const float2*)sitelink_odd, (const float2*)sitelink_even, \
 829  (float2*)fatlink_odd, (float2*)fatlink_even, \
 830  (float)mycoeff, kparam); \
 831  }else{ \
 832  do_siteComputeGenStapleParity12Kernel<mu,nu, 0> \
 833  <<<halfGridDim, blockDim, 0, *stream>>>((float2*)staple_even, (float2*)staple_odd, \
 834  (const float4*)sitelink_even, (const float4*)sitelink_odd, \
 835  (float2*)fatlink_even, (float2*)fatlink_odd, \
 836  (float)mycoeff, kparam); \
 837  do_siteComputeGenStapleParity12Kernel<mu,nu, 1> \
 838  <<<halfGridDim, blockDim, 0, *stream>>>((float2*)staple_odd, (float2*)staple_even, \
 839  (const float4*)sitelink_odd, (const float4*)sitelink_even, \
 840  (float2*)fatlink_odd, (float2*)fatlink_even, \
 841  (float)mycoeff, kparam); \
 842  } \
 843  }
 844 
 845 
 // 1-D thread blocks of BLOCK_DIM threads, used by every launch above.
 846  dim3 blockDim(BLOCK_DIM , 1, 1);
 847  ENUMERATE_FUNCS(mu,nu);
 848 
849 #undef CALL_FUNCTION
 850 
 851 
 852  }
853 
854 
// Host-side launcher for the do_computeGenStapleFieldParity{18,12}Kernel
// templates. For the runtime (mu, nu, save_staple) selection it issues two
// launches on *stream: parity template argument 0 with the (even, odd) field
// pointers, then parity 1 with every even/odd pointer pair swapped.
//   prec   : QUDA_DOUBLE_PRECISION -> double2 casts for all fields; otherwise
//            float2, except 12-reconstruct site links, which are cast to float4.
//   recon  : QUDA_RECONSTRUCT_NO -> ...Parity18Kernel; otherwise ...Parity12Kernel.
//   launch : halfGridDim blocks (caller-supplied) x BLOCK_DIM threads in x,
//            no dynamic shared memory.
// NOTE(review): original lines 856 and 862 (the function-name line and,
// presumably, the prec/recon parameters used in the body) are missing from
// this listing.
// NOTE(review): no cudaGetLastError() follows the launches, so a bad launch
// configuration is only reported by some later CUDA call.
855  void
857  const void* sitelink_even, const void* sitelink_odd,
858  void* fatlink_even, void* fatlink_odd,
859  const void* mulink_even, const void* mulink_odd,
860  int mu, int nu, int save_staple,
861  double mycoeff,
863  dim3 halfGridDim, llfat_kernel_param_t kparam,
864  cudaStream_t* stream)
865  {
866 
  // CALL_FUNCTION is a local helper macro (#undef'd at the end of this
  // function). mu, nu and save_staple are used as template arguments, so every
  // expansion must supply compile-time constants.
867 #define CALL_FUNCTION(mu, nu, save_staple) \
868  if (prec == QUDA_DOUBLE_PRECISION){ \
869  if(recon == QUDA_RECONSTRUCT_NO){ \
870  do_computeGenStapleFieldParity18Kernel<mu,nu, 0, save_staple> \
871  <<<halfGridDim, blockDim, 0, *stream>>>((double2*)staple_even, (double2*)staple_odd, \
872  (const double2*)sitelink_even, (const double2*)sitelink_odd, \
873  (double2*)fatlink_even, (double2*)fatlink_odd, \
874  (const double2*)mulink_even, (const double2*)mulink_odd, \
875  (double)mycoeff, kparam); \
876  do_computeGenStapleFieldParity18Kernel<mu,nu, 1, save_staple> \
877  <<<halfGridDim, blockDim, 0, *stream>>>((double2*)staple_odd, (double2*)staple_even, \
878  (const double2*)sitelink_odd, (const double2*)sitelink_even, \
879  (double2*)fatlink_odd, (double2*)fatlink_even, \
880  (const double2*)mulink_odd, (const double2*)mulink_even, \
881  (double)mycoeff, kparam); \
882  }else{ \
883  do_computeGenStapleFieldParity12Kernel<mu,nu, 0, save_staple> \
884  <<<halfGridDim, blockDim, 0, *stream>>>((double2*)staple_even, (double2*)staple_odd, \
885  (const double2*)sitelink_even, (const double2*)sitelink_odd, \
886  (double2*)fatlink_even, (double2*)fatlink_odd, \
887  (const double2*)mulink_even, (const double2*)mulink_odd, \
888  (double)mycoeff, kparam); \
889  do_computeGenStapleFieldParity12Kernel<mu,nu, 1, save_staple> \
890  <<<halfGridDim, blockDim, 0, *stream>>>((double2*)staple_odd, (double2*)staple_even, \
891  (const double2*)sitelink_odd, (const double2*)sitelink_even, \
892  (double2*)fatlink_odd, (double2*)fatlink_even, \
893  (const double2*)mulink_odd, (const double2*)mulink_even, \
894  (double)mycoeff, kparam); \
895  } \
896  }else{ \
897  if(recon == QUDA_RECONSTRUCT_NO){ \
898  do_computeGenStapleFieldParity18Kernel<mu,nu, 0, save_staple> \
899  <<<halfGridDim, blockDim, 0, *stream>>>((float2*)staple_even, (float2*)staple_odd, \
900  (const float2*)sitelink_even, (const float2*)sitelink_odd, \
901  (float2*)fatlink_even, (float2*)fatlink_odd, \
902  (const float2*)mulink_even, (const float2*)mulink_odd, \
903  (float)mycoeff, kparam); \
904  do_computeGenStapleFieldParity18Kernel<mu,nu, 1, save_staple> \
905  <<<halfGridDim, blockDim, 0, *stream>>>((float2*)staple_odd, (float2*)staple_even, \
906  (const float2*)sitelink_odd, (const float2*)sitelink_even, \
907  (float2*)fatlink_odd, (float2*)fatlink_even, \
908  (const float2*)mulink_odd, (const float2*)mulink_even, \
909  (float)mycoeff, kparam); \
910  }else{ \
911  do_computeGenStapleFieldParity12Kernel<mu,nu, 0, save_staple> \
912  <<<halfGridDim, blockDim, 0, *stream>>>((float2*)staple_even, (float2*)staple_odd, \
913  (const float4*)sitelink_even, (const float4*)sitelink_odd, \
914  (float2*)fatlink_even, (float2*)fatlink_odd, \
915  (const float2*)mulink_even, (const float2*)mulink_odd, \
916  (float)mycoeff, kparam); \
917  do_computeGenStapleFieldParity12Kernel<mu,nu, 1, save_staple> \
918  <<<halfGridDim, blockDim, 0, *stream>>>((float2*)staple_odd, (float2*)staple_even, \
919  (const float4*)sitelink_odd, (const float4*)sitelink_even, \
920  (float2*)fatlink_odd, (float2*)fatlink_even, \
921  (const float2*)mulink_odd, (const float2*)mulink_even, \
922  (float)mycoeff, kparam); \
923  } \
924  }
925 
  // BIND_MU_LINK()/UNBIND_MU_LINK() (defined elsewhere) bracket the launches;
  // presumably they bind/unbind mulink_even/odd as textures -- confirm.
926  BIND_MU_LINK();
  // The CALL_FUNCTION expansion refers to this local blockDim.
927  dim3 blockDim(BLOCK_DIM , 1, 1);
  // Presumably expands CALL_FUNCTION for the matching compile-time
  // (mu, nu, save_staple) combination (macro defined elsewhere -- confirm).
928  ENUMERATE_FUNCS_SAVE(mu,nu,save_staple);
929 
930  UNBIND_MU_LINK();
931 
932 #undef CALL_FUNCTION
933 
934  }
935 
936 
// Host-side launcher for the do_siteComputeGenStapleParity{18,12}Kernel_ex
// templates. For the runtime (mu, nu) selection it launches parity template
// argument 0 with the (even, odd) pointers, then parity 1 with the pointers
// swapped to (odd, even).
//   prec   : QUDA_DOUBLE_PRECISION -> double2 casts; otherwise float2, except
//            12-reconstruct site links, which are cast to float4.
//   recon  : QUDA_RECONSTRUCT_NO -> ...Parity18Kernel_ex; otherwise ...Parity12Kernel_ex.
//   launch : kparam.halfGridDim blocks x kparam.blockDim threads, plus dynamic
//            shared memory of sbytes_dp (double) or sbytes_sp (single) bytes.
//   stream : no stream argument is given, so these launches go to the default
//            stream (the non-_ex launchers in this file pass *stream instead).
// NOTE(review): original lines 941-942 are missing from this listing; they
// presumably declare prec, recon and kparam, which the body uses.
// NOTE(review): no cudaGetLastError() check follows the launches.
937  void siteComputeGenStapleParityKernel_ex(void* staple_even, void* staple_odd,
938  const void* sitelink_even, const void* sitelink_odd,
939  void* fatlink_even, void* fatlink_odd,
940  int mu, int nu, double mycoeff,
943  {
944 
945  //compute even and odd
946  dim3 blockDim = kparam.blockDim;
947  dim3 halfGridDim = kparam.halfGridDim;
  // Dynamic shared memory: 5 two-component vectors of the active precision
  // per thread in x.
948  int sbytes_dp = blockDim.x*5*sizeof(double2);
949  int sbytes_sp = blockDim.x*5*sizeof(float2);
950 
  // CALL_FUNCTION is a local helper macro (#undef'd below); mu and nu are used
  // as template arguments, so each expansion must supply compile-time constants.
951 #define CALL_FUNCTION(mu, nu) \
952  if (prec == QUDA_DOUBLE_PRECISION){ \
953  if(recon == QUDA_RECONSTRUCT_NO){ \
954  do_siteComputeGenStapleParity18Kernel_ex<mu,nu, 0> \
955  <<<halfGridDim, blockDim, sbytes_dp>>>((double2*)staple_even, (double2*)staple_odd, \
956  (const double2*)sitelink_even, (const double2*)sitelink_odd, \
957  (double2*)fatlink_even, (double2*)fatlink_odd, \
958  (double)mycoeff, kparam); \
959  do_siteComputeGenStapleParity18Kernel_ex<mu,nu, 1> \
960  <<<halfGridDim, blockDim, sbytes_dp>>>((double2*)staple_odd, (double2*)staple_even, \
961  (const double2*)sitelink_odd, (const double2*)sitelink_even, \
962  (double2*)fatlink_odd, (double2*)fatlink_even, \
963  (double)mycoeff, kparam); \
964  }else{ \
965  do_siteComputeGenStapleParity12Kernel_ex<mu,nu, 0> \
966  <<<halfGridDim, blockDim, sbytes_dp>>>((double2*)staple_even, (double2*)staple_odd, \
967  (const double2*)sitelink_even, (const double2*)sitelink_odd, \
968  (double2*)fatlink_even, (double2*)fatlink_odd, \
969  (double)mycoeff, kparam); \
970  do_siteComputeGenStapleParity12Kernel_ex<mu,nu, 1> \
971  <<<halfGridDim, blockDim, sbytes_dp>>>((double2*)staple_odd, (double2*)staple_even, \
972  (const double2*)sitelink_odd, (const double2*)sitelink_even, \
973  (double2*)fatlink_odd, (double2*)fatlink_even, \
974  (double)mycoeff, kparam); \
975  } \
976  }else { \
977  if(recon == QUDA_RECONSTRUCT_NO){ \
978  do_siteComputeGenStapleParity18Kernel_ex<mu,nu, 0> \
979  <<<halfGridDim, blockDim, sbytes_sp>>>((float2*)staple_even, (float2*)staple_odd, \
980  (const float2*)sitelink_even, (const float2*)sitelink_odd, \
981  (float2*)fatlink_even, (float2*)fatlink_odd, \
982  (float)mycoeff, kparam); \
983  do_siteComputeGenStapleParity18Kernel_ex<mu,nu, 1> \
984  <<<halfGridDim, blockDim, sbytes_sp>>>((float2*)staple_odd, (float2*)staple_even, \
985  (const float2*)sitelink_odd, (const float2*)sitelink_even, \
986  (float2*)fatlink_odd, (float2*)fatlink_even, \
987  (float)mycoeff, kparam); \
988  }else{ \
989  do_siteComputeGenStapleParity12Kernel_ex<mu,nu, 0> \
990  <<<halfGridDim, blockDim, sbytes_sp>>>((float2*)staple_even, (float2*)staple_odd, \
991  (const float4*)sitelink_even, (const float4*)sitelink_odd, \
992  (float2*)fatlink_even, (float2*)fatlink_odd, \
993  (float)mycoeff, kparam); \
994  do_siteComputeGenStapleParity12Kernel_ex<mu,nu, 1> \
995  <<<halfGridDim, blockDim, sbytes_sp>>>((float2*)staple_odd, (float2*)staple_even, \
996  (const float4*)sitelink_odd, (const float4*)sitelink_even, \
997  (float2*)fatlink_odd, (float2*)fatlink_even, \
998  (float)mycoeff, kparam); \
999  } \
1000  }
1001 
1002 
  // Presumably expands CALL_FUNCTION for the matching compile-time (mu, nu)
  // pair (macro defined elsewhere -- confirm).
1003  ENUMERATE_FUNCS(mu,nu);
1004 
1005 #undef CALL_FUNCTION
1006 
1007 
1008  }
1009 
1010 
1011 
// Host-side launcher for the do_computeGenStapleFieldParity{18,12}Kernel_ex
// templates. For the runtime (mu, nu, save_staple) selection it launches parity
// template argument 0 with the (even, odd) pointers, then parity 1 with every
// even/odd pointer pair swapped.
//   prec   : QUDA_DOUBLE_PRECISION -> double2 casts; otherwise float2, except
//            12-reconstruct site links, which are cast to float4.
//   recon  : QUDA_RECONSTRUCT_NO -> ...Parity18Kernel_ex; otherwise ...Parity12Kernel_ex.
//   launch : kparam.halfGridDim blocks x kparam.blockDim threads, plus dynamic
//            shared memory of sbytes_dp (double) or sbytes_sp (single) bytes.
//   stream : no stream argument is given, so these launches go to the default
//            stream (the non-_ex launchers in this file pass *stream instead).
// NOTE(review): original lines 1013, 1019 and 1020 (the function-name line and,
// presumably, the prec/recon/kparam parameters used below) are missing here.
// NOTE(review): no cudaGetLastError() check follows the launches.
1012  void
1014  const void* sitelink_even, const void* sitelink_odd,
1015  void* fatlink_even, void* fatlink_odd,
1016  const void* mulink_even, const void* mulink_odd,
1017  int mu, int nu, int save_staple,
1018  double mycoeff,
1021  {
1022 
1023  dim3 blockDim = kparam.blockDim;
1024  dim3 halfGridDim= kparam.halfGridDim;
1025 
  // Dynamic shared memory: 5 two-component vectors of the active precision
  // per thread in x.
1026  int sbytes_dp = blockDim.x*5*sizeof(double2);
1027  int sbytes_sp = blockDim.x*5*sizeof(float2);
1028 
  // CALL_FUNCTION is a local helper macro (#undef'd below); mu, nu and
  // save_staple are template arguments, so each expansion must supply
  // compile-time constants.
1029 #define CALL_FUNCTION(mu, nu, save_staple) \
1030  if (prec == QUDA_DOUBLE_PRECISION){ \
1031  if(recon == QUDA_RECONSTRUCT_NO){ \
1032  do_computeGenStapleFieldParity18Kernel_ex<mu,nu, 0, save_staple> \
1033  <<<halfGridDim, blockDim, sbytes_dp>>>((double2*)staple_even, (double2*)staple_odd, \
1034  (const double2*)sitelink_even, (const double2*)sitelink_odd, \
1035  (double2*)fatlink_even, (double2*)fatlink_odd, \
1036  (const double2*)mulink_even, (const double2*)mulink_odd, \
1037  (double)mycoeff, kparam); \
1038  do_computeGenStapleFieldParity18Kernel_ex<mu,nu, 1, save_staple> \
1039  <<<halfGridDim, blockDim, sbytes_dp>>>((double2*)staple_odd, (double2*)staple_even, \
1040  (const double2*)sitelink_odd, (const double2*)sitelink_even, \
1041  (double2*)fatlink_odd, (double2*)fatlink_even, \
1042  (const double2*)mulink_odd, (const double2*)mulink_even, \
1043  (double)mycoeff, kparam); \
1044  }else{ \
1045  do_computeGenStapleFieldParity12Kernel_ex<mu,nu, 0, save_staple> \
1046  <<<halfGridDim, blockDim, sbytes_dp>>>((double2*)staple_even, (double2*)staple_odd, \
1047  (const double2*)sitelink_even, (const double2*)sitelink_odd, \
1048  (double2*)fatlink_even, (double2*)fatlink_odd, \
1049  (const double2*)mulink_even, (const double2*)mulink_odd, \
1050  (double)mycoeff, kparam); \
1051  do_computeGenStapleFieldParity12Kernel_ex<mu,nu, 1, save_staple> \
1052  <<<halfGridDim, blockDim, sbytes_dp>>>((double2*)staple_odd, (double2*)staple_even, \
1053  (const double2*)sitelink_odd, (const double2*)sitelink_even, \
1054  (double2*)fatlink_odd, (double2*)fatlink_even, \
1055  (const double2*)mulink_odd, (const double2*)mulink_even, \
1056  (double)mycoeff, kparam); \
1057  } \
1058  }else{ \
1059  if(recon == QUDA_RECONSTRUCT_NO){ \
1060  do_computeGenStapleFieldParity18Kernel_ex<mu,nu, 0, save_staple> \
1061  <<<halfGridDim, blockDim, sbytes_sp>>>((float2*)staple_even, (float2*)staple_odd, \
1062  (const float2*)sitelink_even, (const float2*)sitelink_odd, \
1063  (float2*)fatlink_even, (float2*)fatlink_odd, \
1064  (const float2*)mulink_even, (const float2*)mulink_odd, \
1065  (float)mycoeff, kparam); \
1066  do_computeGenStapleFieldParity18Kernel_ex<mu,nu, 1, save_staple> \
1067  <<<halfGridDim, blockDim, sbytes_sp>>>((float2*)staple_odd, (float2*)staple_even, \
1068  (const float2*)sitelink_odd, (const float2*)sitelink_even, \
1069  (float2*)fatlink_odd, (float2*)fatlink_even, \
1070  (const float2*)mulink_odd, (const float2*)mulink_even, \
1071  (float)mycoeff, kparam); \
1072  }else{ \
1073  do_computeGenStapleFieldParity12Kernel_ex<mu,nu, 0, save_staple> \
1074  <<<halfGridDim, blockDim, sbytes_sp>>>((float2*)staple_even, (float2*)staple_odd, \
1075  (const float4*)sitelink_even, (const float4*)sitelink_odd, \
1076  (float2*)fatlink_even, (float2*)fatlink_odd, \
1077  (const float2*)mulink_even, (const float2*)mulink_odd, \
1078  (float)mycoeff, kparam); \
1079  do_computeGenStapleFieldParity12Kernel_ex<mu,nu, 1, save_staple> \
1080  <<<halfGridDim, blockDim, sbytes_sp>>>((float2*)staple_odd, (float2*)staple_even, \
1081  (const float4*)sitelink_odd, (const float4*)sitelink_even, \
1082  (float2*)fatlink_odd, (float2*)fatlink_even, \
1083  (const float2*)mulink_odd, (const float2*)mulink_even, \
1084  (float)mycoeff, kparam); \
1085  } \
1086  }
1087 
  // BIND_MU_LINK() is defined elsewhere; presumably it binds mulink_even/odd as
  // textures -- confirm.
1088  BIND_MU_LINK();
  // Presumably expands CALL_FUNCTION for the matching compile-time
  // (mu, nu, save_staple) combination (macro defined elsewhere -- confirm).
1089  ENUMERATE_FUNCS_SAVE(mu,nu,save_staple);
1090 
  // NOTE(review): original line 1091 is missing from this listing, and no
  // UNBIND_MU_LINK() is visible here, although the non-_ex launcher calls it
  // after its launches. Confirm that the binding made above is released.
1092 
1093 #undef CALL_FUNCTION
1094 
1095  }
1096 
1097 
1098 
1099 
// Host-side launcher for llfatOneLink18Kernel (QUDA_RECONSTRUCT_NO) or
// llfatOneLink12Kernel (any other recon). It makes a single launch, passing
// both the even and odd site-link and fat-link pointers, with
// act_path_coeff[0] and act_path_coeff[5] as the two scalar coefficients.
// Precision and reconstruct type are read from cudaSiteLink.
//   launch : volume/BLOCK_DIM blocks x BLOCK_DIM threads, where
//            volume = X[0]*X[1]*X[2]*X[3], on the default stream.
// NOTE(review): original lines 1100 (the function-name line, presumably
// declaring cudaFatLink/cudaSiteLink) and 1107 are missing from this listing.
// NOTE(review): cudaStaple1 is not referenced in the visible lines.
1101 cudaGaugeField& cudaStaple, cudaGaugeField& cudaStaple1,
1102  QudaGaugeParam* param, double* act_path_coeff)
1103  {
1104  QudaPrecision prec = cudaSiteLink.Precision();
1105  QudaReconstructType recon = cudaSiteLink.Reconstruct();
1106 
1108  int volume = param->X[0]*param->X[1]*param->X[2]*param->X[3];
  // NOTE(review): the integer division truncates. If volume is not a multiple
  // of BLOCK_DIM, the last volume % BLOCK_DIM sites get no thread. Confirm that
  // lattice sizes guarantee divisibility. Switching to ceil-div would also need
  // a bounds check inside the kernels.
1109  dim3 gridDim(volume/BLOCK_DIM,1,1);
1110  dim3 blockDim(BLOCK_DIM , 1, 1);
1111 
  // staple_bytes is not declared in this function; presumably it is a
  // file-scope variable consumed by texture-binding code elsewhere -- confirm.
1112  staple_bytes = cudaStaple.Bytes();
1113 
  // Single-precision 12-reconstruct site links are passed as float4; all other
  // fields use the 2-component vector type of the selected precision.
  // NOTE(review): no cudaGetLastError() check follows these launches.
1114  if(prec == QUDA_DOUBLE_PRECISION){
1115  if(recon == QUDA_RECONSTRUCT_NO){
1116  llfatOneLink18Kernel<<<gridDim, blockDim>>>((const double2*)cudaSiteLink.Even_p(), (const double2*)cudaSiteLink.Odd_p(),
1117  (double2*)cudaFatLink.Even_p(), (double2*)cudaFatLink.Odd_p(),
1118  (double)act_path_coeff[0], (double)act_path_coeff[5]);
1119  }else{
1120 
1121  llfatOneLink12Kernel<<<gridDim, blockDim>>>((const double2*)cudaSiteLink.Even_p(), (const double2*)cudaSiteLink.Odd_p(),
1122  (double2*)cudaFatLink.Even_p(), (double2*)cudaFatLink.Odd_p(),
1123  (double)act_path_coeff[0], (double)act_path_coeff[5]);
1124 
1125  }
1126  }else{ //single precision
1127  if(recon == QUDA_RECONSTRUCT_NO){
1128  llfatOneLink18Kernel<<<gridDim, blockDim>>>((const float2*)cudaSiteLink.Even_p(), (const float2*)cudaSiteLink.Odd_p(),
1129  (float2*)cudaFatLink.Even_p(), (float2*)cudaFatLink.Odd_p(),
1130  (float)act_path_coeff[0], (float)act_path_coeff[5]);
1131  }else{
1132  llfatOneLink12Kernel<<<gridDim, blockDim>>>((const float4*)cudaSiteLink.Even_p(), (const float4*)cudaSiteLink.Odd_p(),
1133  (float2*)cudaFatLink.Even_p(), (float2*)cudaFatLink.Odd_p(),
1134  (float)act_path_coeff[0], (float)act_path_coeff[5]);
1135  }
1136  }
1137  }
1138 
1139 
1140 
// Host-side launcher for llfatOneLink18Kernel_ex (QUDA_RECONSTRUCT_NO) or
// llfatOneLink12Kernel_ex (any other recon). It makes a single launch, passing
// both the even and odd site-link and fat-link pointers, act_path_coeff[0] and
// act_path_coeff[5], and kparam.
// Precision and reconstruct type are read from cudaSiteLink.
//   launch : 2*kparam.halfGridDim.x blocks (1-D grid) x kparam.blockDim
//            threads, on the default stream, with no dynamic shared memory.
// NOTE(review): original lines 1141 (the function-name line, presumably
// declaring cudaFatLink/cudaSiteLink), 1144 (presumably the kparam parameter)
// and 1149 are missing from this listing.
// NOTE(review): cudaStaple1 is not referenced in the visible lines.
1142 cudaGaugeField& cudaStaple, cudaGaugeField& cudaStaple1,
1143  QudaGaugeParam* param, double* act_path_coeff,
1145  {
1146  QudaPrecision prec = cudaSiteLink.Precision();
1147  QudaReconstructType recon = cudaSiteLink.Reconstruct();
1148 
1150 
  // One launch covers twice the half-grid, i.e. the even and odd halves together.
1151  dim3 gridDim;
1152  dim3 blockDim = kparam.blockDim;
1153  gridDim.x = 2* kparam.halfGridDim.x;
1154  gridDim.y = 1;
1155  gridDim.z = 1;
  // staple_bytes is not declared in this function; presumably it is a
  // file-scope variable consumed by texture-binding code elsewhere -- confirm.
1156  staple_bytes = cudaStaple.Bytes();
1157 
  // Single-precision 12-reconstruct site links are passed as float4; all other
  // fields use the 2-component vector type of the selected precision.
  // NOTE(review): no cudaGetLastError() check follows these launches.
1158  if(prec == QUDA_DOUBLE_PRECISION){
1159  if(recon == QUDA_RECONSTRUCT_NO){
1160  llfatOneLink18Kernel_ex<<<gridDim, blockDim>>>((const double2*)cudaSiteLink.Even_p(), (const double2*)cudaSiteLink.Odd_p(),
1161  (double2*)cudaFatLink.Even_p(), (double2*)cudaFatLink.Odd_p(),
1162  (double)act_path_coeff[0], (double)act_path_coeff[5], kparam);
1163  }else{
1164 
1165  llfatOneLink12Kernel_ex<<<gridDim, blockDim>>>((const double2*)cudaSiteLink.Even_p(), (const double2*)cudaSiteLink.Odd_p(),
1166  (double2*)cudaFatLink.Even_p(), (double2*)cudaFatLink.Odd_p(),
1167  (double)act_path_coeff[0], (double)act_path_coeff[5], kparam);
1168 
1169  }
1170  }else{ //single precision
1171  if(recon == QUDA_RECONSTRUCT_NO){
1172  llfatOneLink18Kernel_ex<<<gridDim, blockDim>>>((const float2*)cudaSiteLink.Even_p(), (const float2*)cudaSiteLink.Odd_p(),
1173  (float2*)cudaFatLink.Even_p(), (float2*)cudaFatLink.Odd_p(),
1174  (float)act_path_coeff[0], (float)act_path_coeff[5], kparam);
1175  }else{
1176  llfatOneLink12Kernel_ex<<<gridDim, blockDim>>>((const float4*)cudaSiteLink.Even_p(), (const float4*)cudaSiteLink.Odd_p(),
1177  (float2*)cudaFatLink.Even_p(), (float2*)cudaFatLink.Odd_p(),
1178  (float)act_path_coeff[0], (float)act_path_coeff[5], kparam);
1179  }
1180  }
1181  }
1182 
1183 #undef BLOCK_DIM
1184 
1185 } // namespace quda