QUDA  v0.5.0
A library for QCD on GPUs
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
fermion_force_quda.cu
Go to the documentation of this file.
1 
2 #include <read_gauge.h>
3 #include <gauge_field.h>
4 
5 #include <fermion_force_quda.h>
6 #include <force_common.h>
7 #include <hw_quda.h>
8 
9 namespace quda {
10 
// BLOCK_DIM: thread-block size constant; it is not referenced in this part of the file.
11 #define BLOCK_DIM 64
12
// Momentum (anti-Hermitian) loads use the direct, non-texture loader with the
// per-parity stride Vh.
13 #define LOAD_ANTI_HERMITIAN(src, dir, idx, var) LOAD_ANTI_HERMITIAN_DIRECT(src, dir, idx, var, Vh)
14
// LOAD_HW_SINGLE(hw_even, hw_odd, idx, var, oddness): copy the six Float2
// components of an HW field at checkerboard index idx into var0..var5.
// oddness != 0 selects hw_odd, otherwise hw_even. Storage is structure-of-arrays:
// component k lives at hw[idx + k*Vh], so idx must be < Vh.
// Requires the type name Float2 to be in scope (the kernels' template parameter).
15 #define LOAD_HW_SINGLE(hw_even, hw_odd, idx, var, oddness) do{ \
16  Float2* hw = (oddness)?hw_odd:hw_even; \
17  var##0 = hw[idx + 0*Vh]; \
18  var##1 = hw[idx + 1*Vh]; \
19  var##2 = hw[idx + 2*Vh]; \
20  var##3 = hw[idx + 3*Vh]; \
21  var##4 = hw[idx + 4*Vh]; \
22  var##5 = hw[idx + 5*Vh]; \
23  }while(0)
24
// WRITE_HW_SINGLE: inverse of LOAD_HW_SINGLE. Stores var0..var5 into the same
// structure-of-arrays layout of hw_odd (oddness != 0) or hw_even.
25 #define WRITE_HW_SINGLE(hw_even, hw_odd, idx, var, oddness) do{ \
26  Float2* hw = (oddness)?hw_odd:hw_even; \
27  hw[idx + 0*Vh] = var##0; \
28  hw[idx + 1*Vh] = var##1; \
29  hw[idx + 2*Vh] = var##2; \
30  hw[idx + 3*Vh] = var##3; \
31  hw[idx + 4*Vh] = var##4; \
32  hw[idx + 5*Vh] = var##5; \
33  }while(0)
34
// Generic names used by the kernels; both map to the single-precision layout.
35 #define LOAD_HW(hw_eve, hw_odd, idx, var, oddness) LOAD_HW_SINGLE(hw_eve, hw_odd, idx, var, oddness)
36 #define WRITE_HW(hw_even, hw_odd, idx, var, oddness) WRITE_HW_SINGLE(hw_even, hw_odd, idx, var, oddness)
// LOAD_MATRIX: link load from an explicit source array (LOAD_MATRIX_12_SINGLE is
// defined elsewhere). The kernels in this file use FF_LOAD_MATRIX instead.
37 #define LOAD_MATRIX(src, dir, idx, var) LOAD_MATRIX_12_SINGLE(src, dir, idx, var, Vh)
38
// FF_SITE_MATRIX_LOAD_TEX selects how FF_LOAD_MATRIX fetches gauge links:
//   1    -> from the texture references aliased below; the linkEven/linkOdd
//           kernel arguments are then not read by FF_LOAD_MATRIX.
//   else -> directly from the linkEven/linkOdd kernel arguments.
39 #define FF_SITE_MATRIX_LOAD_TEX 1
40
// Even- and odd-parity site-link textures (declared outside this file).
41 #define linkEvenTex siteLink0TexSingle_recon
42 #define linkOddTex siteLink1TexSingle_recon
43
// FF_LOAD_MATRIX(dir, idx, var, oddness): load the link in direction dir at
// checkerboard index idx into the var0.. registers. The source is the odd-parity
// one when oddness != 0, otherwise the even-parity one.
// NOTE(review): the "_12_" loaders presumably read the compressed 12-real
// format that RECONSTRUCT_LINK_12 later completes; confirm.
44 #if (FF_SITE_MATRIX_LOAD_TEX == 1)
45 #define FF_LOAD_MATRIX(dir, idx, var, oddness) LOAD_MATRIX_12_SINGLE_TEX(((oddness)?linkOddTex:linkEvenTex), dir, idx, var, Vh)
46 #else
47 #define FF_LOAD_MATRIX(dir, idx, var, oddness) LOAD_MATRIX_12_SINGLE(((oddness)?linkOdd:linkEven), dir, idx, var, Vh)
48 #endif
49 
50 
// Component aliases for a 3x3 complex matrix held in the float4 registers
// LINKA0..LINKA4. Element (i,j) is linkaij_re / linkaij_im. Elements are packed
// row-major, two complex numbers per float4 (.x/.y then .z/.w); element (2,2)
// occupies LINKA4.x/.y, and LINKA4.z/.w are not aliased.
51 #define linka00_re LINKA0.x
52 #define linka00_im LINKA0.y
53 #define linka01_re LINKA0.z
54 #define linka01_im LINKA0.w
55 #define linka02_re LINKA1.x
56 #define linka02_im LINKA1.y
57 #define linka10_re LINKA1.z
58 #define linka10_im LINKA1.w
59 #define linka11_re LINKA2.x
60 #define linka11_im LINKA2.y
61 #define linka12_re LINKA2.z
62 #define linka12_im LINKA2.w
63 #define linka20_re LINKA3.x
64 #define linka20_im LINKA3.y
65 #define linka21_re LINKA3.z
66 #define linka21_im LINKA3.w
67 #define linka22_re LINKA4.x
68 #define linka22_im LINKA4.y
69
// Same layout for the second matrix register set LINKB0..LINKB4.
70 #define linkb00_re LINKB0.x
71 #define linkb00_im LINKB0.y
72 #define linkb01_re LINKB0.z
73 #define linkb01_im LINKB0.w
74 #define linkb02_re LINKB1.x
75 #define linkb02_im LINKB1.y
76 #define linkb10_re LINKB1.z
77 #define linkb10_im LINKB1.w
78 #define linkb11_re LINKB2.x
79 #define linkb11_im LINKB2.y
80 #define linkb12_re LINKB2.z
81 #define linkb12_im LINKB2.w
82 #define linkb20_re LINKB3.x
83 #define linkb20_im LINKB3.y
84 #define linkb21_re LINKB3.z
85 #define linkb21_im LINKB3.w
86 #define linkb22_re LINKB4.x
87 #define linkb22_im LINKB4.y
88 
89 
// MAT_MUL_HW(M, HW, HWOUT): apply the 3x3 complex matrix M to both 3-component
// halves of HW, using complex products:
//   HWOUT##0j = sum_k M##jk * HW##0k
//   HWOUT##1j = sum_k M##jk * HW##1k        (j, k = 0..2)
// HWOUT must not alias HW, because outputs are written while later terms still
// read HW. The expansion is a bare sequence of statements (no do/while wrapper),
// so use it only as a complete statement inside braces.
90 #define MAT_MUL_HW(M, HW, HWOUT) \
91  HWOUT##00_re = (M##00_re * HW##00_re - M##00_im * HW##00_im) \
92  + (M##01_re * HW##01_re - M##01_im * HW##01_im) \
93  + (M##02_re * HW##02_re - M##02_im * HW##02_im); \
94  HWOUT##00_im = (M##00_re * HW##00_im + M##00_im * HW##00_re) \
95  + (M##01_re * HW##01_im + M##01_im * HW##01_re) \
96  + (M##02_re * HW##02_im + M##02_im * HW##02_re); \
97  HWOUT##01_re = (M##10_re * HW##00_re - M##10_im * HW##00_im) \
98  + (M##11_re * HW##01_re - M##11_im * HW##01_im) \
99  + (M##12_re * HW##02_re - M##12_im * HW##02_im); \
100  HWOUT##01_im = (M##10_re * HW##00_im + M##10_im * HW##00_re) \
101  + (M##11_re * HW##01_im + M##11_im * HW##01_re) \
102  + (M##12_re * HW##02_im + M##12_im * HW##02_re); \
103  HWOUT##02_re = (M##20_re * HW##00_re - M##20_im * HW##00_im) \
104  + (M##21_re * HW##01_re - M##21_im * HW##01_im) \
105  + (M##22_re * HW##02_re - M##22_im * HW##02_im); \
106  HWOUT##02_im = (M##20_re * HW##00_im + M##20_im * HW##00_re) \
107  + (M##21_re * HW##01_im + M##21_im * HW##01_re) \
108  + (M##22_re * HW##02_im + M##22_im * HW##02_re); \
109  HWOUT##10_re = (M##00_re * HW##10_re - M##00_im * HW##10_im) \
110  + (M##01_re * HW##11_re - M##01_im * HW##11_im) \
111  + (M##02_re * HW##12_re - M##02_im * HW##12_im); \
112  HWOUT##10_im = (M##00_re * HW##10_im + M##00_im * HW##10_re) \
113  + (M##01_re * HW##11_im + M##01_im * HW##11_re) \
114  + (M##02_re * HW##12_im + M##02_im * HW##12_re); \
115  HWOUT##11_re = (M##10_re * HW##10_re - M##10_im * HW##10_im) \
116  + (M##11_re * HW##11_re - M##11_im * HW##11_im) \
117  + (M##12_re * HW##12_re - M##12_im * HW##12_im); \
118  HWOUT##11_im = (M##10_re * HW##10_im + M##10_im * HW##10_re) \
119  + (M##11_re * HW##11_im + M##11_im * HW##11_re) \
120  + (M##12_re * HW##12_im + M##12_im * HW##12_re); \
121  HWOUT##12_re = (M##20_re * HW##10_re - M##20_im * HW##10_im) \
122  + (M##21_re * HW##11_re - M##21_im * HW##11_im) \
123  + (M##22_re * HW##12_re - M##22_im * HW##12_im); \
124  HWOUT##12_im = (M##20_re * HW##10_im + M##20_im * HW##10_re) \
125  + (M##21_re * HW##11_im + M##21_im * HW##11_re) \
126  + (M##22_re * HW##12_im + M##22_im * HW##12_re);
127 
128 
// ADJ_MAT_MUL_HW(M, HW, HWOUT): apply the adjoint (conjugate transpose) of the
// 3x3 complex matrix M to both 3-component halves of HW:
//   HWOUT##0j = sum_k conj(M##kj) * HW##0k
//   HWOUT##1j = sum_k conj(M##kj) * HW##1k      (j, k = 0..2)
// HWOUT must not alias HW. The expansion is a bare sequence of statements, so
// use it only as a complete statement inside braces.
129 #define ADJ_MAT_MUL_HW(M, HW, HWOUT) \
130  HWOUT##00_re = (M##00_re * HW##00_re + M##00_im * HW##00_im) \
131  + (M##10_re * HW##01_re + M##10_im * HW##01_im) \
132  + (M##20_re * HW##02_re + M##20_im * HW##02_im); \
133  HWOUT##00_im = (M##00_re * HW##00_im - M##00_im * HW##00_re) \
134  + (M##10_re * HW##01_im - M##10_im * HW##01_re) \
135  + (M##20_re * HW##02_im - M##20_im * HW##02_re); \
136  HWOUT##01_re = (M##01_re * HW##00_re + M##01_im * HW##00_im) \
137  + (M##11_re * HW##01_re + M##11_im * HW##01_im) \
138  + (M##21_re * HW##02_re + M##21_im * HW##02_im); \
139  HWOUT##01_im = (M##01_re * HW##00_im - M##01_im * HW##00_re) \
140  + (M##11_re * HW##01_im - M##11_im * HW##01_re) \
141  + (M##21_re * HW##02_im - M##21_im * HW##02_re); \
142  HWOUT##02_re = (M##02_re * HW##00_re + M##02_im * HW##00_im) \
143  + (M##12_re * HW##01_re + M##12_im * HW##01_im) \
144  + (M##22_re * HW##02_re + M##22_im * HW##02_im); \
145  HWOUT##02_im = (M##02_re * HW##00_im - M##02_im * HW##00_re) \
146  + (M##12_re * HW##01_im - M##12_im * HW##01_re) \
147  + (M##22_re * HW##02_im - M##22_im * HW##02_re); \
148  HWOUT##10_re = (M##00_re * HW##10_re + M##00_im * HW##10_im) \
149  + (M##10_re * HW##11_re + M##10_im * HW##11_im) \
150  + (M##20_re * HW##12_re + M##20_im * HW##12_im); \
151  HWOUT##10_im = (M##00_re * HW##10_im - M##00_im * HW##10_re) \
152  + (M##10_re * HW##11_im - M##10_im * HW##11_re) \
153  + (M##20_re * HW##12_im - M##20_im * HW##12_re); \
154  HWOUT##11_re = (M##01_re * HW##10_re + M##01_im * HW##10_im) \
155  + (M##11_re * HW##11_re + M##11_im * HW##11_im) \
156  + (M##21_re * HW##12_re + M##21_im * HW##12_im); \
157  HWOUT##11_im = (M##01_re * HW##10_im - M##01_im * HW##10_re) \
158  + (M##11_re * HW##11_im - M##11_im * HW##11_re) \
159  + (M##21_re * HW##12_im - M##21_im * HW##12_re); \
160  HWOUT##12_re = (M##02_re * HW##10_re + M##02_im * HW##10_im) \
161  + (M##12_re * HW##11_re + M##12_im * HW##11_im) \
162  + (M##22_re * HW##12_re + M##22_im * HW##12_im); \
163  HWOUT##12_im = (M##02_re * HW##10_im - M##02_im * HW##10_re) \
164  + (M##12_re * HW##11_im - M##12_im * HW##11_re) \
165  + (M##22_re * HW##12_im - M##22_im * HW##12_re);
166 
167 
// SU3_PROJECTOR(va, vb, m): outer product of two complex 3-vectors,
//   m##ij = va##i * conj(vb##j)      (i, j = 0..2)
// The expansion is a bare sequence of statements (no do/while wrapper).
168 #define SU3_PROJECTOR(va, vb, m) \
169  m##00_re = va##0_re * vb##0_re + va##0_im * vb##0_im; \
170  m##00_im = va##0_im * vb##0_re - va##0_re * vb##0_im; \
171  m##01_re = va##0_re * vb##1_re + va##0_im * vb##1_im; \
172  m##01_im = va##0_im * vb##1_re - va##0_re * vb##1_im; \
173  m##02_re = va##0_re * vb##2_re + va##0_im * vb##2_im; \
174  m##02_im = va##0_im * vb##2_re - va##0_re * vb##2_im; \
175  m##10_re = va##1_re * vb##0_re + va##1_im * vb##0_im; \
176  m##10_im = va##1_im * vb##0_re - va##1_re * vb##0_im; \
177  m##11_re = va##1_re * vb##1_re + va##1_im * vb##1_im; \
178  m##11_im = va##1_im * vb##1_re - va##1_re * vb##1_im; \
179  m##12_re = va##1_re * vb##2_re + va##1_im * vb##2_im; \
180  m##12_im = va##1_im * vb##2_re - va##1_re * vb##2_im; \
181  m##20_re = va##2_re * vb##0_re + va##2_im * vb##0_im; \
182  m##20_im = va##2_im * vb##0_re - va##2_re * vb##0_im; \
183  m##21_re = va##2_re * vb##1_re + va##2_im * vb##1_im; \
184  m##21_im = va##2_im * vb##1_re - va##2_re * vb##1_im; \
185  m##22_re = va##2_re * vb##2_re + va##2_im * vb##2_im; \
186  m##22_im = va##2_im * vb##2_re - va##2_re * vb##2_im;
187
188  // vc = va + vb*s, component-wise on a complex 3-vector with a real scalar s.
// vc may be the same vector as va: each component reads only its own va/vb
// component before writing it.
189 #define SCALAR_MULT_ADD_SU3_VECTOR(va, vb, s, vc) do { \
190  vc##0_re = va##0_re + vb##0_re * s; \
191  vc##0_im = va##0_im + vb##0_im * s; \
192  vc##1_re = va##1_re + vb##1_re * s; \
193  vc##1_im = va##1_im + vb##1_im * s; \
194  vc##2_re = va##2_re + vb##2_re * s; \
195  vc##2_im = va##2_im + vb##2_im * s; \
196  }while (0)
197 
198 
// FF_COMPUTE_NEW_FULL_IDX_PLUS_UPDATE(mydir, idx, new_idx): step one site forward
// along direction mydir (0..3) from the full-lattice index idx.
//  - At the upper edge (new_xN == XNm1) the step wraps back to coordinate 0,
//    shifting the index by -X1m1 / -X2X1mX1 / -X3X2X1mX2X1 / -X4X3X2X1mX3X2X1.
//  - The macro reads AND updates the enclosing scope's new_x1..new_x4, which must
//    hold idx's coordinates on entry.
//  - idx and new_idx may be the same variable.
//  - A mydir outside 0..3 leaves new_idx and the coordinates untouched.
199 #define FF_COMPUTE_NEW_FULL_IDX_PLUS_UPDATE(mydir, idx, new_idx) do { \
200  switch(mydir){ \
201  case 0: \
202  new_idx = ( (new_x1==X1m1)?idx-X1m1:idx+1); \
203  new_x1 = (new_x1==X1m1)?0:new_x1+1; \
204  break; \
205  case 1: \
206  new_idx = ( (new_x2==X2m1)?idx-X2X1mX1:idx+X1); \
207  new_x2 = (new_x2==X2m1)?0:new_x2+1; \
208  break; \
209  case 2: \
210  new_idx = ( (new_x3==X3m1)?idx-X3X2X1mX2X1:idx+X2X1); \
211  new_x3 = (new_x3==X3m1)?0:new_x3+1; \
212  break; \
213  case 3: \
214  new_idx = ( (new_x4==X4m1)?idx-X4X3X2X1mX3X2X1:idx+X3X2X1); \
215  new_x4 = (new_x4==X4m1)?0:new_x4+1; \
216  break; \
217  } \
218  }while(0)
219
// FF_COMPUTE_NEW_FULL_IDX_MINUS_UPDATE(mydir, idx, new_idx): mirror image of the
// PLUS variant. It steps one site backward along mydir and wraps to XNm1 when
// new_xN == 0. It also reads and updates new_x1..new_x4 from the enclosing scope.
220 #define FF_COMPUTE_NEW_FULL_IDX_MINUS_UPDATE(mydir, idx, new_idx) do { \
221  switch(mydir){ \
222  case 0: \
223  new_idx = ( (new_x1==0)?idx+X1m1:idx-1); \
224  new_x1 = (new_x1==0)?X1m1:new_x1 - 1; \
225  break; \
226  case 1: \
227  new_idx = ( (new_x2==0)?idx+X2X1mX1:idx-X1); \
228  new_x2 = (new_x2==0)?X2m1:new_x2 - 1; \
229  break; \
230  case 2: \
231  new_idx = ( (new_x3==0)?idx+X3X2X1mX2X1:idx-X2X1); \
232  new_x3 = (new_x3==0)?X3m1:new_x3 - 1; \
233  break; \
234  case 3: \
235  new_idx = ( (new_x4==0)?idx+X4X3X2X1mX3X2X1:idx-X3X2X1); \
236  new_x4 = (new_x4==0)?X4m1:new_x4 - 1; \
237  break; \
238  } \
239  }while(0)
240
241
// FF_COMPUTE_NEW_FULL_IDX_PLUS / _MINUS: same index arithmetic as the _UPDATE
// variants, but the coordinates are passed explicitly (old_x1..old_x4) and are
// not modified. Neither is used by the kernels in this section of the file.
242 #define FF_COMPUTE_NEW_FULL_IDX_PLUS(old_x1, old_x2, old_x3, old_x4, idx, mydir, new_idx) do { \
243  switch(mydir){ \
244  case 0: \
245  new_idx = ( (old_x1==X1m1)?idx-X1m1:idx+1); \
246  break; \
247  case 1: \
248  new_idx = ( (old_x2==X2m1)?idx-X2X1mX1:idx+X1); \
249  break; \
250  case 2: \
251  new_idx = ( (old_x3==X3m1)?idx-X3X2X1mX2X1:idx+X2X1); \
252  break; \
253  case 3: \
254  new_idx = ( (old_x4==X4m1)?idx-X4X3X2X1mX3X2X1:idx+X3X2X1); \
255  break; \
256  } \
257  }while(0)
258
259 #define FF_COMPUTE_NEW_FULL_IDX_MINUS(old_x1, old_x2, old_x3, old_x4, idx, mydir, new_idx) do { \
260  switch(mydir){ \
261  case 0: \
262  new_idx = ( (old_x1==0)?idx+X1m1:idx-1); \
263  break; \
264  case 1: \
265  new_idx = ( (old_x2==0)?idx+X2X1mX1:idx-X1); \
266  break; \
267  case 2: \
268  new_idx = ( (old_x3==0)?idx+X3X2X1mX2X1:idx-X2X1); \
269  break; \
270  case 3: \
271  new_idx = ( (old_x4==0)?idx+X4X3X2X1mX3X2X1:idx-X3X2X1); \
272  break; \
273  } \
274  }while(0)
275 
276  // This macro requires the linka, linkb (LINKA*/LINKB*) and ah (AH0..AH4)
// register variables, plus momEven/momOdd and the type Float2, to be in scope.
//
// ADD_FORCE_TO_MOM(hw1, hw2, idx, dir, cf, oddness): update the momentum at
// checkerboard index idx of momOdd (oddness != 0) or momEven.
//  - Direction: mydir = OPP_DIR(dir) if dir goes backwards, otherwise dir.
//  - Coefficient: cf is negated once if dir goes backwards and once more if
//    oddness is non-zero.
//  - Update: the stored momentum is uncompressed into linka, then
//      tmp_coeff.x * (hw1_0 x conj(hw2_0)) and tmp_coeff.y * (hw1_1 x conj(hw2_1))
//    are added to it, projected back with MAKE_ANTI_HERMITIAN, and written.
//    (This assumes SCALAR_MULT_ADD_SU3_MATRIX(a, b, s, c) computes c = a + b*s,
//    like SCALAR_MULT_ADD_SU3_VECTOR; confirm.)
// Side effect: linka and linkb are used as scratch and are clobbered.
// NOTE(review): oddness is substituted unparenthesized in "oddness?momOdd:momEven".
// The current arguments (oddBit, 1-oddBit) parse correctly, but an argument with
// lower precedence than ?: would not.
277 #define ADD_FORCE_TO_MOM(hw1, hw2, idx, dir, cf,oddness) do{ \
278  Float2 my_coeff; \
279  int mydir; \
280  if (GOES_BACKWARDS(dir)){ \
281  mydir=OPP_DIR(dir); \
282  my_coeff.x = -cf.x; \
283  my_coeff.y = -cf.y; \
284  }else{ \
285  mydir=dir; \
286  my_coeff.x = cf.x; \
287  my_coeff.y = cf.y; \
288  } \
289  Float2 tmp_coeff; \
290  tmp_coeff.x = my_coeff.x; \
291  tmp_coeff.y = my_coeff.y; \
292  if(oddness){ \
293  tmp_coeff.x = - my_coeff.x; \
294  tmp_coeff.y = - my_coeff.y; \
295  } \
296  Float2* mom = oddness?momOdd:momEven; \
297  LOAD_ANTI_HERMITIAN(mom, mydir, idx, AH); \
298  UNCOMPRESS_ANTI_HERMITIAN(ah, linka); \
299  SU3_PROJECTOR(hw1##0, hw2##0, linkb); \
300  SCALAR_MULT_ADD_SU3_MATRIX(linka, linkb, tmp_coeff.x, linka); \
301  SU3_PROJECTOR(hw1##1, hw2##1, linkb); \
302  SCALAR_MULT_ADD_SU3_MATRIX(linka, linkb, tmp_coeff.y, linka); \
303  MAKE_ANTI_HERMITIAN(linka, ah); \
304  WRITE_ANTI_HERMITIAN(mom, mydir, idx, AH, Vh); \
305  }while(0)
306 
307 
// FF_COMPUTE_RECONSTRUCT_SIGN(sign, dir, i1, i2, i3, i4): sets sign to +1, then
// to -1 when:
//   XUP: i4 is odd
//   YUP: i4+i1 is odd
//   ZUP: i4+i1+i2 is odd
//   TUP: i4 == X4m1 (last time slice)
// i3 is not used. The kernels pass the result to RECONSTRUCT_LINK_12.
// NOTE(review): these look like staggered phases with an antiperiodic time
// boundary folded into the link reconstruction; confirm against RECONSTRUCT_LINK_12.
308 #define FF_COMPUTE_RECONSTRUCT_SIGN(sign, dir, i1,i2,i3,i4) do { \
309  sign =1; \
310  switch(dir){ \
311  case XUP: \
312  if ( (i4 & 1) == 1){ \
313  sign = -1; \
314  } \
315  break; \
316  case YUP: \
317  if ( ((i4+i1) & 1) == 1){ \
318  sign = -1; \
319  } \
320  break; \
321  case ZUP: \
322  if ( ((i4+i1+i2) & 1) == 1){ \
323  sign = -1; \
324  } \
325  break; \
326  case TUP: \
327  if (i4 == X4m1 ){ \
328  sign = -1; \
329  } \
330  break; \
331  } \
332  }while (0)
333 
334 
// Component aliases for the six-Float2 HW registers HWA0..HWA5, and likewise
// HWB, HWC, HWD and HWE below. Each HW holds two complex 3-vectors:
//   hwa0j -> HWAj       (first vector,  j = 0..2)
//   hwa1j -> HWA(3+j)   (second vector)
// with _re = .x and _im = .y.
335 #define hwa00_re HWA0.x
336 #define hwa00_im HWA0.y
337 #define hwa01_re HWA1.x
338 #define hwa01_im HWA1.y
339 #define hwa02_re HWA2.x
340 #define hwa02_im HWA2.y
341 #define hwa10_re HWA3.x
342 #define hwa10_im HWA3.y
343 #define hwa11_re HWA4.x
344 #define hwa11_im HWA4.y
345 #define hwa12_re HWA5.x
346 #define hwa12_im HWA5.y
347
348 #define hwb00_re HWB0.x
349 #define hwb00_im HWB0.y
350 #define hwb01_re HWB1.x
351 #define hwb01_im HWB1.y
352 #define hwb02_re HWB2.x
353 #define hwb02_im HWB2.y
354 #define hwb10_re HWB3.x
355 #define hwb10_im HWB3.y
356 #define hwb11_re HWB4.x
357 #define hwb11_im HWB4.y
358 #define hwb12_re HWB5.x
359 #define hwb12_im HWB5.y
360
361 #define hwc00_re HWC0.x
362 #define hwc00_im HWC0.y
363 #define hwc01_re HWC1.x
364 #define hwc01_im HWC1.y
365 #define hwc02_re HWC2.x
366 #define hwc02_im HWC2.y
367 #define hwc10_re HWC3.x
368 #define hwc10_im HWC3.y
369 #define hwc11_re HWC4.x
370 #define hwc11_im HWC4.y
371 #define hwc12_re HWC5.x
372 #define hwc12_im HWC5.y
373
374 #define hwd00_re HWD0.x
375 #define hwd00_im HWD0.y
376 #define hwd01_re HWD1.x
377 #define hwd01_im HWD1.y
378 #define hwd02_re HWD2.x
379 #define hwd02_im HWD2.y
380 #define hwd10_re HWD3.x
381 #define hwd10_im HWD3.y
382 #define hwd11_re HWD4.x
383 #define hwd11_im HWD4.y
384 #define hwd12_re HWD5.x
385 #define hwd12_im HWD5.y
386
387 #define hwe00_re HWE0.x
388 #define hwe00_im HWE0.y
389 #define hwe01_re HWE1.x
390 #define hwe01_im HWE1.y
391 #define hwe02_re HWE2.x
392 #define hwe02_im HWE2.y
393 #define hwe10_re HWE3.x
394 #define hwe10_im HWE3.y
395 #define hwe11_re HWE4.x
396 #define hwe11_im HWE4.y
397 #define hwe12_re HWE5.x
398 #define hwe12_im HWE5.y
399 
400 
// NOTE(review): the signature line for this body (listing line 401) is missing
// from this extracted listing, so the function's name and parameters are not shown.
// In the visible body:
//  - the only runtime effect is setting a one-shot static flag;
//  - the MULTI_GPU guard turns any multi-GPU build of this file into a
//    compile-time error.
402  {
403
404 #ifdef MULTI_GPU
405 #error "multi gpu is not supported for fermion force computation"
406 #endif
407
// One-shot guard: the first call sets the flag and later calls return early.
// The flag is a plain static int with no synchronization.
408  static int fermion_force_init_cuda_flag = 0;
409
410  if (fermion_force_init_cuda_flag) return;
411
412  fermion_force_init_cuda_flag=1;
413
414  }
415 
/*
 * This function computes the contribution to momentum from the middle link in a staple.
 *
 * tempx: IN
 * Pmu:   OUT
 * P3:    OUT
 * mom:   IN/OUT (updated only when sig_positive)
 *
 * Thread mapping:
 *  - One thread per site of the checkerboard selected by oddBit.
 *  - sid is the checkerboard index; X = 2*sid + x1odd is the full-lattice index.
 *  - Threads with sid >= Vh return immediately. Every HW access is
 *    hw[idx + k*Vh], so a larger sid would spill into the next component.
 */
template<int sig_positive, int mu_positive, int oddBit, typename Float2>
__global__ void
do_middle_link_kernel(Float2* tempxEven, Float2* tempxOdd,
                      Float2* PmuEven, Float2* PmuOdd,
                      Float2* P3Even, Float2* P3Odd,
                      int sig, int mu, Float2 coeff,
                      float4* linkEven, float4* linkOdd,
                      Float2* momEven, Float2* momOdd)
{
  int sid = blockIdx.x * blockDim.x + threadIdx.x;
  if (sid >= Vh) return;

  // Lattice coordinates of this site (x1 runs fastest).
  int z1 = sid / X1h;
  int x1h = sid - z1*X1h;
  int z2 = z1 / X2;
  int x2 = z1 - z2*X2;
  int x4 = z2 / X3;
  int x3 = z2 - x4*X3;
  int x1odd = (x2 + x3 + x4 + oddBit) & 1;
  int x1 = 2*x1h + x1odd;
  int X = 2*sid + x1odd;

  // new_x1..new_x4 are read and updated by the FF_COMPUTE_NEW_FULL_IDX_*_UPDATE macros.
  int new_x1, new_x2, new_x3, new_x4;
  int new_mem_idx;
  int ad_link_sign=1;
  int ab_link_sign=1;
  int bc_link_sign=1;

  // Registers addressed through the hwa..hwd, linka/linkb and ah component aliases.
  Float2 HWA0, HWA1, HWA2, HWA3, HWA4, HWA5;
  Float2 HWB0, HWB1, HWB2, HWB3, HWB4, HWB5;
  Float2 HWC0, HWC1, HWC2, HWC3, HWC4, HWC5;
  Float2 HWD0, HWD1, HWD2, HWD3, HWD4, HWD5;
  float4 LINKA0, LINKA1, LINKA2, LINKA3, LINKA4;
  float4 LINKB0, LINKB1, LINKB2, LINKB3, LINKB4;
  Float2 AH0, AH1, AH2, AH3, AH4;

  /*       sig
   *    A________B
   * mu  |        |
   *    D|        |C
   *
   * A is the current point (sid)
   */

  int point_b, point_c, point_d;
  // Checkerboard indices at which the AD, AB and BC links are stored.
  // These must be declared here; they were previously used undeclared.
  int ad_link_nbr_idx, ab_link_nbr_idx, bc_link_nbr_idx;
  int mymu;

  new_x1 = x1;
  new_x2 = x2;
  new_x3 = x3;
  new_x4 = x4;

  // A -> D: step against mu.
  if(mu_positive){
    mymu =mu;
    FF_COMPUTE_NEW_FULL_IDX_MINUS_UPDATE(mu, X, new_mem_idx);
  }else{
    mymu = OPP_DIR(mu);
    FF_COMPUTE_NEW_FULL_IDX_PLUS_UPDATE(OPP_DIR(mu), X, new_mem_idx);
  }
  point_d = (new_mem_idx >> 1);
  if (mu_positive){
    ad_link_nbr_idx = point_d;
    FF_COMPUTE_RECONSTRUCT_SIGN(ad_link_sign, mymu, new_x1,new_x2,new_x3,new_x4);
  }else{
    ad_link_nbr_idx = sid;
    FF_COMPUTE_RECONSTRUCT_SIGN(ad_link_sign, mymu, x1, x2, x3, x4);
  }

  // D -> C: step along sig.
  int mysig;
  if(sig_positive){
    mysig = sig;
    FF_COMPUTE_NEW_FULL_IDX_PLUS_UPDATE(sig, new_mem_idx, new_mem_idx);
  }else{
    mysig = OPP_DIR(sig);
    FF_COMPUTE_NEW_FULL_IDX_MINUS_UPDATE(OPP_DIR(sig), new_mem_idx, new_mem_idx);
  }
  point_c = (new_mem_idx >> 1);
  if (mu_positive){
    bc_link_nbr_idx = point_c;
    FF_COMPUTE_RECONSTRUCT_SIGN(bc_link_sign, mymu, new_x1,new_x2,new_x3,new_x4);
  }

  // A -> B: step along sig, restarting from A's coordinates.
  new_x1 = x1;
  new_x2 = x2;
  new_x3 = x3;
  new_x4 = x4;
  if(sig_positive){
    FF_COMPUTE_NEW_FULL_IDX_PLUS_UPDATE(sig, X, new_mem_idx);
  }else{
    FF_COMPUTE_NEW_FULL_IDX_MINUS_UPDATE(OPP_DIR(sig), X, new_mem_idx);
  }
  point_b = (new_mem_idx >> 1);

  if (!mu_positive){
    bc_link_nbr_idx = point_b;
    FF_COMPUTE_RECONSTRUCT_SIGN(bc_link_sign, mymu, new_x1,new_x2,new_x3,new_x4);
  }

  if(sig_positive){
    ab_link_nbr_idx = sid;
    FF_COMPUTE_RECONSTRUCT_SIGN(ab_link_sign, mysig, x1, x2, x3, x4);
  }else{
    ab_link_nbr_idx = point_b;
    FF_COMPUTE_RECONSTRUCT_SIGN(ab_link_sign, mysig, new_x1,new_x2,new_x3,new_x4);
  }

  // Pmu(A): the AD link (its adjoint when mu_positive) applied to tempx(D).
  LOAD_HW(tempxEven, tempxOdd, point_d, HWA, 1-oddBit );
  if(mu_positive){
    FF_LOAD_MATRIX(mymu, ad_link_nbr_idx, LINKA, 1-oddBit);
  }else{
    FF_LOAD_MATRIX(mymu, ad_link_nbr_idx, LINKA, oddBit);
  }

  RECONSTRUCT_LINK_12(ad_link_sign, linka);
  if (mu_positive){
    ADJ_MAT_MUL_HW(linka, hwa, hwd);
  }else{
    MAT_MUL_HW(linka, hwa, hwd);
  }
  WRITE_HW(PmuEven,PmuOdd, sid, HWD, oddBit);

  // P3(A): the AB link applied to (the BC link applied to tempx(C)).
  LOAD_HW(tempxEven,tempxOdd, point_c, HWA, oddBit);
  if(mu_positive){
    FF_LOAD_MATRIX(mymu, bc_link_nbr_idx, LINKA, oddBit);
  }else{
    FF_LOAD_MATRIX(mymu, bc_link_nbr_idx, LINKA, 1-oddBit);
  }

  RECONSTRUCT_LINK_12(bc_link_sign, linka);
  if (mu_positive){
    ADJ_MAT_MUL_HW(linka, hwa, hwb);
  }else{
    MAT_MUL_HW(linka, hwa, hwb);
  }
  if(sig_positive){
    FF_LOAD_MATRIX(mysig, ab_link_nbr_idx, LINKB, oddBit);
  }else{
    FF_LOAD_MATRIX(mysig, ab_link_nbr_idx, LINKB, 1-oddBit);
  }

  RECONSTRUCT_LINK_12(ab_link_sign, linkb);
  if (sig_positive){
    MAT_MUL_HW(linkb, hwb, hwc);
  }else{
    ADJ_MAT_MUL_HW(linkb, hwb, hwc);
  }
  WRITE_HW(P3Even, P3Odd, sid, HWC, oddBit);

  if (sig_positive){
    //add the force to mom
    ADD_FORCE_TO_MOM(hwc, hwd, sid, sig, coeff, oddBit);
  }
}
577 
578 
template<typename Float2>
static void
middle_link_kernel(Float2* tempxEven, Float2* tempxOdd,
                   Float2* PmuEven, Float2* PmuOdd,
                   Float2* P3Even, Float2* P3Odd,
                   int sig, int mu, Float2 coeff,
                   float4* linkEven, float4* linkOdd, cudaGaugeField &siteLink,
                   Float2* momEven, Float2* momOdd,
                   dim3 gridDim, dim3 BlockDim)
{
  // Host-side dispatcher for do_middle_link_kernel.
  //  - Selects the template instantiation that matches the signs of sig and mu.
  //  - Launches it for parity 0 and then parity 1, each on half of gridDim.x.
  //  - siteLink is not referenced here.
  // NOTE: gridDim.x/2 truncates, so gridDim.x is expected to be even.
  dim3 halfGridDim(gridDim.x/2, 1,1);

#define CALL_MIDDLE_LINK_KERNEL(sig_sign, mu_sign) do { \
    do_middle_link_kernel<sig_sign, mu_sign, 0><<<halfGridDim, BlockDim>>>(tempxEven, tempxOdd, \
                                                                          PmuEven, PmuOdd, \
                                                                          P3Even, P3Odd, \
                                                                          sig, mu, coeff, \
                                                                          linkEven, linkOdd, \
                                                                          momEven, momOdd); \
    do_middle_link_kernel<sig_sign, mu_sign, 1><<<halfGridDim, BlockDim>>>(tempxEven, tempxOdd, \
                                                                          PmuEven, PmuOdd, \
                                                                          P3Even, P3Odd, \
                                                                          sig, mu, coeff, \
                                                                          linkEven, linkOdd, \
                                                                          momEven, momOdd); \
  } while (0)

  // Every sign combination must launch the matching instantiation; an empty
  // branch would silently skip this staple's middle-link contribution.
  if (GOES_FORWARDS(sig) && GOES_FORWARDS(mu)){
    CALL_MIDDLE_LINK_KERNEL(1,1);
  }else if (GOES_FORWARDS(sig) && GOES_BACKWARDS(mu)){
    CALL_MIDDLE_LINK_KERNEL(1,0);
  }else if (GOES_BACKWARDS(sig) && GOES_FORWARDS(mu)){
    CALL_MIDDLE_LINK_KERNEL(0,1);
  }else{
    CALL_MIDDLE_LINK_KERNEL(0,0);
  }
#undef CALL_MIDDLE_LINK_KERNEL

}
619 
/*
 * Computes the contribution to momentum from the side links in a staple.
 *
 * P3:     IN
 * P3mu:   not used
 * Tempx:  IN (read only when mu_positive)
 * Pmu:    IN (read only when !mu_positive)
 * shortP: IN/OUT (accumulated at point D; skipped when shortPOdd is null)
 * mom:    IN/OUT
 *
 * One thread per site of the oddBit checkerboard. Threads with sid >= Vh return
 * immediately, because every HW access is hw[idx + k*Vh].
 */
template<int sig_positive, int mu_positive, int oddBit, typename Float2>
__global__ void
do_side_link_kernel(Float2* P3Even, Float2* P3Odd,
                    Float2* P3muEven, Float2* P3muOdd,
                    Float2* TempxEven, Float2* TempxOdd,
                    Float2* PmuEven, Float2* PmuOdd,
                    Float2* shortPEven, Float2* shortPOdd,
                    int sig, int mu, Float2 coeff, Float2 accumu_coeff,
                    float4* linkEven, float4* linkOdd,
                    Float2* momEven, Float2* momOdd)
{
  Float2 mcoeff;
  mcoeff.x = -coeff.x;
  mcoeff.y = -coeff.y;

  int sid = blockIdx.x * blockDim.x + threadIdx.x;
  if (sid >= Vh) return;

  // Lattice coordinates of this site (x1 runs fastest).
  int z1 = sid / X1h;
  int x1h = sid - z1*X1h;
  int z2 = z1 / X2;
  int x2 = z1 - z2*X2;
  int x4 = z2 / X3;
  int x3 = z2 - x4*X3;
  int x1odd = (x2 + x3 + x4 + oddBit) & 1;
  int x1 = 2*x1h + x1odd;
  int X = 2*sid + x1odd;

  int ad_link_sign = 1;
  // Registers addressed through the hwa..hwc, linka/linkb and ah component aliases.
  // linkb and ah are also used as scratch inside ADD_FORCE_TO_MOM.
  Float2 HWA0, HWA1, HWA2, HWA3, HWA4, HWA5;
  Float2 HWB0, HWB1, HWB2, HWB3, HWB4, HWB5;
  Float2 HWC0, HWC1, HWC2, HWC3, HWC4, HWC5;
  float4 LINKA0, LINKA1, LINKA2, LINKA3, LINKA4;
  float4 LINKB0, LINKB1, LINKB2, LINKB3, LINKB4;
  Float2 AH0, AH1, AH2, AH3, AH4;

  /*
   * compute the side link contribution to the momentum
   *
   *       sig
   *    A________B
   *     |        | mu
   *    D|        |C
   *
   * A is the current point (sid)
   */

  int point_d;
  int ad_link_nbr_idx;
  int mymu;
  int new_mem_idx;

  // new_x1..new_x4 are read and updated by the FF_COMPUTE_NEW_FULL_IDX_*_UPDATE macros.
  int new_x1 = x1;
  int new_x2 = x2;
  int new_x3 = x3;
  int new_x4 = x4;

  // A -> D: step against mu.
  if(mu_positive){
    mymu =mu;
    FF_COMPUTE_NEW_FULL_IDX_MINUS_UPDATE(mymu,X, new_mem_idx);
  }else{
    mymu = OPP_DIR(mu);
    FF_COMPUTE_NEW_FULL_IDX_PLUS_UPDATE(mymu, X, new_mem_idx);
  }
  point_d = (new_mem_idx >> 1);

  if (mu_positive){
    ad_link_nbr_idx = point_d;
    FF_COMPUTE_RECONSTRUCT_SIGN(ad_link_sign, mymu, new_x1,new_x2,new_x3,new_x4);
  }else{
    ad_link_nbr_idx = sid;
    FF_COMPUTE_RECONSTRUCT_SIGN(ad_link_sign, mymu, x1, x2, x3, x4);
  }

  // hwb = AD link (adjoint when !mu_positive) applied to P3(A).
  LOAD_HW(P3Even, P3Odd, sid, HWA, oddBit);
  if(mu_positive){
    FF_LOAD_MATRIX(mymu, ad_link_nbr_idx, LINKA, 1 - oddBit);
  }else{
    FF_LOAD_MATRIX(mymu, ad_link_nbr_idx, LINKA, oddBit);
  }

  RECONSTRUCT_LINK_12(ad_link_sign, linka);
  if (mu_positive){
    MAT_MUL_HW(linka, hwa, hwb);
  }else{
    ADJ_MAT_MUL_HW(linka, hwa, hwb);
  }

  //start to add side link force
  if (mu_positive){
    LOAD_HW(TempxEven, TempxOdd, point_d, HWC, 1-oddBit);

    if (sig_positive){
      ADD_FORCE_TO_MOM(hwb, hwc, point_d, mu, coeff, 1-oddBit);
    }else{
      ADD_FORCE_TO_MOM(hwc, hwb, point_d, OPP_DIR(mu), mcoeff, 1- oddBit);
    }
  }else{
    LOAD_HW(PmuEven, PmuOdd, sid, HWC, oddBit);
    if (sig_positive){
      ADD_FORCE_TO_MOM(hwa, hwc, sid, mu, mcoeff, oddBit);
    }else{
      ADD_FORCE_TO_MOM(hwc, hwa, sid, OPP_DIR(mu), coeff, oddBit);
    }
  }

  // shortP(D) += accumu_coeff * hwb (per 3-vector half). Skipped when no
  // shortP buffer was supplied.
  if (shortPOdd){
    LOAD_HW(shortPEven, shortPOdd, point_d, HWA, 1-oddBit);
    SCALAR_MULT_ADD_SU3_VECTOR(hwa0, hwb0, accumu_coeff.x, hwa0);
    SCALAR_MULT_ADD_SU3_VECTOR(hwa1, hwb1, accumu_coeff.y, hwa1);
    WRITE_HW(shortPEven, shortPOdd, point_d, HWA, 1-oddBit);
  }

}
750 
751 
template<typename Float2>
static void
side_link_kernel(Float2* P3Even, Float2* P3Odd,
                 Float2* P3muEven, Float2* P3muOdd,
                 Float2* TempxEven, Float2* TempxOdd,
                 Float2* PmuEven, Float2* PmuOdd,
                 Float2* shortPEven, Float2* shortPOdd,
                 int sig, int mu, Float2 coeff, Float2 accumu_coeff,
                 float4* linkEven, float4* linkOdd, cudaGaugeField &siteLink,
                 Float2* momEven, Float2* momOdd,
                 dim3 gridDim, dim3 blockDim)
{
  // Host-side dispatcher for do_side_link_kernel.
  //  - Selects the template instantiation that matches the signs of sig and mu.
  //  - Launches it for parity 0 and then parity 1, each on half of gridDim.x.
  //  - siteLink is not referenced here.
  // NOTE: gridDim.x/2 truncates, so gridDim.x is expected to be even.
  dim3 halfGridDim(gridDim.x/2,1,1);

#define CALL_SIDE_LINK_KERNEL(sig_sign, mu_sign) do { \
    do_side_link_kernel<sig_sign,mu_sign,0><<<halfGridDim, blockDim>>>(P3Even, P3Odd, \
                                                                       P3muEven, P3muOdd, \
                                                                       TempxEven, TempxOdd, \
                                                                       PmuEven, PmuOdd, \
                                                                       shortPEven, shortPOdd, \
                                                                       sig, mu, coeff, accumu_coeff, \
                                                                       linkEven, linkOdd, \
                                                                       momEven, momOdd); \
    do_side_link_kernel<sig_sign,mu_sign,1><<<halfGridDim, blockDim>>>(P3Even, P3Odd, \
                                                                       P3muEven, P3muOdd, \
                                                                       TempxEven, TempxOdd, \
                                                                       PmuEven, PmuOdd, \
                                                                       shortPEven, shortPOdd, \
                                                                       sig, mu, coeff, accumu_coeff, \
                                                                       linkEven, linkOdd, \
                                                                       momEven, momOdd); \
  } while (0)

  // All four sign combinations must launch. The (backward sig, backward mu) case
  // was previously an empty branch, which dropped that side-link contribution.
  if (GOES_FORWARDS(sig) && GOES_FORWARDS(mu)){
    CALL_SIDE_LINK_KERNEL(1,1);
  }else if (GOES_FORWARDS(sig) && GOES_BACKWARDS(mu)){
    CALL_SIDE_LINK_KERNEL(1,0);
  }else if (GOES_BACKWARDS(sig) && GOES_FORWARDS(mu)){
    CALL_SIDE_LINK_KERNEL(0,1);
  }else{
    CALL_SIDE_LINK_KERNEL(0,0);
  }

#undef CALL_SIDE_LINK_KERNEL

}
797 
/*
 * This function computes the contribution to momentum from both the middle and
 * the side links of a staple in a single pass.
 *
 * tempx:  IN
 * Pmu:    not used
 * P3:     not used
 * P3mu:   not used
 * shortP: IN/OUT (accumulated at point D; unlike do_side_link_kernel there is
 *         no null check, so shortPEven/shortPOdd must be valid)
 * mom:    IN/OUT
 *
 * One thread per site of the oddBit checkerboard. Threads with sid >= Vh return
 * immediately, because every HW access is hw[idx + k*Vh].
 *
 * NOTE(review): this kernel uses the linkc / LINKC* component aliases. Only the
 * linka/linkb aliases are defined in this file before this point, so linkc must
 * come from an included header; confirm.
 */
template<int sig_positive, int mu_positive, int oddBit, typename Float2>
__global__ void
do_all_link_kernel(Float2* tempxEven, Float2* tempxOdd,
                   Float2* PmuEven, Float2* PmuOdd,
                   Float2* P3Even, Float2* P3Odd,
                   Float2* P3muEven, Float2* P3muOdd,
                   Float2* shortPEven, Float2* shortPOdd,
                   int sig, int mu, Float2 coeff, Float2 mcoeff, Float2 accumu_coeff,
                   float4* linkEven, float4* linkOdd,
                   Float2* momEven, Float2* momOdd)
{
  int sid = blockIdx.x * blockDim.x + threadIdx.x;
  if (sid >= Vh) return;

  // Lattice coordinates of this site (x1 runs fastest).
  int z1 = sid / X1h;
  int x1h = sid - z1*X1h;
  int z2 = z1 / X2;
  int x2 = z1 - z2*X2;
  int x4 = z2 / X3;
  int x3 = z2 - x4*X3;
  int x1odd = (x2 + x3 + x4 + oddBit) & 1;
  int x1 = 2*x1h + x1odd;
  int X = 2*sid + x1odd;

  // new_x1..new_x4 are read and updated by the FF_COMPUTE_NEW_FULL_IDX_*_UPDATE macros.
  int new_x1, new_x2, new_x3, new_x4;
  int ad_link_sign=1;
  int ab_link_sign=1;
  int bc_link_sign=1;

  // Registers addressed through the hwa..hwe, linka..linkc and ah component aliases.
  Float2 HWA0, HWA1, HWA2, HWA3, HWA4, HWA5;
  Float2 HWB0, HWB1, HWB2, HWB3, HWB4, HWB5;
  Float2 HWC0, HWC1, HWC2, HWC3, HWC4, HWC5;
  Float2 HWD0, HWD1, HWD2, HWD3, HWD4, HWD5;
  Float2 HWE0, HWE1, HWE2, HWE3, HWE4, HWE5;
  float4 LINKA0, LINKA1, LINKA2, LINKA3, LINKA4;
  float4 LINKB0, LINKB1, LINKB2, LINKB3, LINKB4;
  float4 LINKC0, LINKC1, LINKC2, LINKC3, LINKC4;
  Float2 AH0, AH1, AH2, AH3, AH4;


  /*       sig
   *    A________B
   * mu  |        |
   *    D|        |C
   *
   * A is the current point (sid)
   */


  int point_b, point_c, point_d;
  // Checkerboard indices at which the AD, AB and BC links are stored.
  // These must be declared here; they were previously used undeclared.
  int ad_link_nbr_idx, ab_link_nbr_idx, bc_link_nbr_idx;
  int mymu;
  int new_mem_idx;
  new_x1 = x1;
  new_x2 = x2;
  new_x3 = x3;
  new_x4 = x4;

  // A -> D: step against mu.
  if(mu_positive){
    mymu =mu;
    FF_COMPUTE_NEW_FULL_IDX_MINUS_UPDATE(mu, X, new_mem_idx);
  }else{
    mymu = OPP_DIR(mu);
    FF_COMPUTE_NEW_FULL_IDX_PLUS_UPDATE(OPP_DIR(mu), X, new_mem_idx);
  }
  point_d = (new_mem_idx >> 1);

  if (mu_positive){
    ad_link_nbr_idx = point_d;
    FF_COMPUTE_RECONSTRUCT_SIGN(ad_link_sign, mymu, new_x1,new_x2,new_x3,new_x4);
  }else{
    ad_link_nbr_idx = sid;
    FF_COMPUTE_RECONSTRUCT_SIGN(ad_link_sign, mymu, x1, x2, x3, x4);
  }

  // D -> C: step along sig.
  int mysig;
  if(sig_positive){
    mysig = sig;
    FF_COMPUTE_NEW_FULL_IDX_PLUS_UPDATE(sig, new_mem_idx, new_mem_idx);
  }else{
    mysig = OPP_DIR(sig);
    FF_COMPUTE_NEW_FULL_IDX_MINUS_UPDATE(OPP_DIR(sig), new_mem_idx, new_mem_idx);
  }
  point_c = (new_mem_idx >> 1);
  if (mu_positive){
    bc_link_nbr_idx = point_c;
    FF_COMPUTE_RECONSTRUCT_SIGN(bc_link_sign, mymu, new_x1,new_x2,new_x3,new_x4);
  }

  // A -> B: step along sig, restarting from A's coordinates.
  new_x1 = x1;
  new_x2 = x2;
  new_x3 = x3;
  new_x4 = x4;
  if(sig_positive){
    FF_COMPUTE_NEW_FULL_IDX_PLUS_UPDATE(sig, X, new_mem_idx);
  }else{
    FF_COMPUTE_NEW_FULL_IDX_MINUS_UPDATE(OPP_DIR(sig), X, new_mem_idx);
  }
  point_b = (new_mem_idx >> 1);
  if (!mu_positive){
    bc_link_nbr_idx = point_b;
    FF_COMPUTE_RECONSTRUCT_SIGN(bc_link_sign, mymu, new_x1,new_x2,new_x3,new_x4);
  }

  if(sig_positive){
    ab_link_nbr_idx = sid;
    FF_COMPUTE_RECONSTRUCT_SIGN(ab_link_sign, mysig, x1, x2, x3, x4);
  }else{
    ab_link_nbr_idx = point_b;
    FF_COMPUTE_RECONSTRUCT_SIGN(ab_link_sign, mysig, new_x1,new_x2,new_x3,new_x4);
  }

  // The AD link is kept in linkc because ADD_FORCE_TO_MOM clobbers linka/linkb,
  // and the AD link is needed again after the middle-link force is added.
  LOAD_HW(tempxEven, tempxOdd, point_d, HWE, 1-oddBit);
  if (mu_positive){
    FF_LOAD_MATRIX(mymu, ad_link_nbr_idx, LINKC, 1-oddBit);
  }else{
    FF_LOAD_MATRIX(mymu, ad_link_nbr_idx, LINKC, oddBit);
  }

  RECONSTRUCT_LINK_12(ad_link_sign, linkc);
  if (mu_positive){
    ADJ_MAT_MUL_HW(linkc, hwe, hwd);
  }else{
    MAT_MUL_HW(linkc, hwe, hwd);
  }
  //we do not need to write Pmu here
  //WRITE_HW(myPmu, sid, HWD);

  LOAD_HW(tempxEven, tempxOdd, point_c, HWA, oddBit);
  if (mu_positive){
    FF_LOAD_MATRIX(mymu, bc_link_nbr_idx, LINKA, oddBit);
  }else{
    FF_LOAD_MATRIX(mymu, bc_link_nbr_idx, LINKA, 1-oddBit);
  }

  RECONSTRUCT_LINK_12(bc_link_sign, linka);
  if (mu_positive){
    ADJ_MAT_MUL_HW(linka, hwa, hwb);
  }else{
    MAT_MUL_HW(linka, hwa, hwb);
  }
  if (sig_positive){
    FF_LOAD_MATRIX(mysig, ab_link_nbr_idx, LINKA, oddBit);
  }else{
    FF_LOAD_MATRIX(mysig, ab_link_nbr_idx, LINKA, 1-oddBit);
  }

  RECONSTRUCT_LINK_12(ab_link_sign, linka);
  if (sig_positive){
    MAT_MUL_HW(linka, hwb, hwc);
  }else{
    ADJ_MAT_MUL_HW(linka, hwb, hwc);
  }

  //we do not need to write P3 here
  //WRITE_HW(myP3, sid, HWC);

  //The middle link contribution
  if (sig_positive){
    //add the force to mom
    ADD_FORCE_TO_MOM(hwc, hwd, sid, sig, mcoeff, oddBit);
  }

  //P3 is hwc
  //ad_link is linkc
  if (mu_positive){
    MAT_MUL_HW(linkc, hwc, hwa);
  }else{
    ADJ_MAT_MUL_HW(linkc, hwc, hwa);
  }

  //accumulate P7rho to P5
  //WRITE_HW(otherP3mu, point_d, HWA);
  LOAD_HW(shortPEven, shortPOdd, point_d, HWB, 1-oddBit);
  SCALAR_MULT_ADD_SU3_VECTOR(hwb0, hwa0, accumu_coeff.x, hwb0);
  SCALAR_MULT_ADD_SU3_VECTOR(hwb1, hwa1, accumu_coeff.y, hwb1);
  WRITE_HW(shortPEven, shortPOdd, point_d, HWB, 1-oddBit);

  //hwe holds tempx at point_d
  //hwd holds Pmu at point A(sid)
  if (mu_positive){
    if (sig_positive){
      ADD_FORCE_TO_MOM(hwa, hwe, point_d, mu, coeff, 1-oddBit);
    }else{
      ADD_FORCE_TO_MOM(hwe, hwa, point_d, OPP_DIR(mu), mcoeff, 1- oddBit);
    }
  }else{
    if (sig_positive){
      ADD_FORCE_TO_MOM(hwc, hwd, sid, mu, mcoeff, oddBit);
    }else{
      ADD_FORCE_TO_MOM(hwd, hwc, sid, OPP_DIR(mu), coeff, oddBit);
    }

  }


}
1005 
1006 
1007 
template<typename Float2>
static void
all_link_kernel(Float2* tempxEven, Float2* tempxOdd,
                Float2* PmuEven, Float2* PmuOdd,
                Float2* P3Even, Float2* P3Odd,
                Float2* P3muEven, Float2* P3muOdd,
                Float2* shortPEven, Float2* shortPOdd,
                int sig, int mu, Float2 coeff, Float2 mcoeff, Float2 accumu_coeff,
                float4* linkEven, float4* linkOdd, cudaGaugeField &siteLink,
                Float2* momEven, Float2* momOdd,
                dim3 gridDim, dim3 blockDim)

{
  // Host-side dispatcher for do_all_link_kernel.
  //  - Works out which template instantiation the signs of sig and mu call for.
  //  - Launches it for parity 0 and then parity 1, each over half of the
  //    supplied grid (gridDim.x/2, truncating).
  //  - siteLink is not referenced.
  dim3 parityGrid(gridDim.x/2, 1, 1);

#define LAUNCH_ALL_LINK_BOTH_PARITIES(SIG_POS, MU_POS) do { \
    do_all_link_kernel<SIG_POS, MU_POS, 0><<<parityGrid, blockDim>>>( \
      tempxEven, tempxOdd, PmuEven, PmuOdd, P3Even, P3Odd, \
      P3muEven, P3muOdd, shortPEven, shortPOdd, \
      sig, mu, coeff, mcoeff, accumu_coeff, \
      linkEven, linkOdd, momEven, momOdd); \
    do_all_link_kernel<SIG_POS, MU_POS, 1><<<parityGrid, blockDim>>>( \
      tempxEven, tempxOdd, PmuEven, PmuOdd, P3Even, P3Odd, \
      P3muEven, P3muOdd, shortPEven, shortPOdd, \
      sig, mu, coeff, mcoeff, accumu_coeff, \
      linkEven, linkOdd, momEven, momOdd); \
  } while (0)

  // Variant codes: 0 = (+sig,+mu), 1 = (+sig,-mu), 2 = (-sig,+mu), 3 = anything else.
  int variant = 3;
  if (GOES_FORWARDS(sig) && GOES_FORWARDS(mu)) {
    variant = 0;
  } else if (GOES_FORWARDS(sig) && GOES_BACKWARDS(mu)) {
    variant = 1;
  } else if (GOES_BACKWARDS(sig) && GOES_FORWARDS(mu)) {
    variant = 2;
  }

  switch (variant) {
    case 0:
      LAUNCH_ALL_LINK_BOTH_PARITIES(1, 1);
      break;
    case 1:
      LAUNCH_ALL_LINK_BOTH_PARITIES(1, 0);
      break;
    case 2:
      LAUNCH_ALL_LINK_BOTH_PARITIES(0, 1);
      break;
    default:
      LAUNCH_ALL_LINK_BOTH_PARITIES(0, 0);
      break;
  }

#undef LAUNCH_ALL_LINK_BOTH_PARITIES

}
1055 
/* Periodic neighbour lookup used by do_one_and_naik_terms_kernel.
 *
 * Moves the full-lattice site (x1,x2,x3,x4) by `step` sites along direction
 * dir (expected in 0..3), wrapping around the lattice extents X1..X4, and
 * returns the shifted coordinates in (y1,y2,y3,y4) together with the
 * half-lattice (checkerboard) index of the shifted site in half_idx.
 */
static __device__ __forceinline__ void
ff_shifted_site(int x1, int x2, int x3, int x4, int dir, int step,
                int &y1, int &y2, int &y3, int &y4, int &half_idx)
{
  int dx[4] = {0, 0, 0, 0};
  dx[dir] = step;
  y1 = (x1 + dx[0] + X1) % X1;
  y2 = (x2 + dx[1] + X2) % X2;
  y3 = (x3 + dx[2] + X3) % X3;
  y4 = (x4 + dx[3] + X4) % X4;
  half_idx = (y4*X3X2X1 + y3*X2X1 + y2*X1 + y1) >> 1;
}

/* Computes the one-link and Naik terms' contribution to the momentum.
 *
 *   Tempx: IN
 *   Pmu:   IN
 *   Pnumu: IN
 *
 * One thread per site of parity oddBit; sid is that site's checkerboard
 * index. A backward mu adds the one-link term and two Naik pieces at sid.
 * A forward mu adds a single Naik piece. All contributions go through
 * ADD_FORCE_TO_MOM (defined elsewhere; presumably it writes momEven/momOdd,
 * which are otherwise unused here — confirm against its definition).
 */
template <int oddBit, typename Float2>
__global__ void
do_one_and_naik_terms_kernel(Float2* TempxEven, Float2* TempxOdd,
                             Float2* PmuEven, Float2* PmuOdd,
                             Float2* PnumuEven, Float2* PnumuOdd,
                             int mu, Float2 OneLink, Float2 Naik, Float2 mNaik,
                             float4* linkEven, float4* linkOdd,
                             Float2* momEven, Float2* momOdd)
{
  // Register storage addressed by token pasting in LOAD_HW (HWA -> HWA0..HWA5)
  // and by the linka..* aliases (LINKA0..). AH* and LINKB* are presumably
  // used inside other macros; keep every name unchanged.
  Float2 HWA0, HWA1, HWA2, HWA3, HWA4, HWA5;
  Float2 HWB0, HWB1, HWB2, HWB3, HWB4, HWB5;
  Float2 HWC0, HWC1, HWC2, HWC3, HWC4, HWC5;
  Float2 HWD0, HWD1, HWD2, HWD3, HWD4, HWD5;
  float4 LINKA0, LINKA1, LINKA2, LINKA3, LINKA4;
  float4 LINKB0, LINKB1, LINKB2, LINKB3, LINKB4;
  Float2 AH0, AH1, AH2, AH3, AH4;

  int sid = blockIdx.x * blockDim.x + threadIdx.x;

  // Recover full-lattice coordinates from the checkerboard index.
  int z1 = sid / X1h;
  int x1h = sid - z1*X1h;
  int z2 = z1 / X2;
  int x2 = z1 - z2*X2;
  int x4 = z2 / X3;
  int x3 = z2 - x4*X3;
  int x1odd = (x2 + x3 + x4 + oddBit) & 1;
  int x1 = 2*x1h + x1odd;

  int new_x1, new_x2, new_x3, new_x4, new_idx;
  int sign = 1;

  if (!(GOES_BACKWARDS(mu))) {
    // Forward mu: only the Naik piece built from Tempx one site along +mu.
    ff_shifted_site(x1, x2, x3, x4, mu, +1, new_x1, new_x2, new_x3, new_x4, new_idx);
    LOAD_HW(TempxEven, TempxOdd, new_idx, HWA, 1-oddBit);
    FF_LOAD_MATRIX(mu, sid, LINKA, oddBit);
    FF_COMPUTE_RECONSTRUCT_SIGN(sign, mu, x1, x2, x3, x4);
    RECONSTRUCT_LINK_12(sign, linka);
    MAT_MUL_HW(linka, hwa, hwb);

    LOAD_HW(PnumuEven, PnumuOdd, sid, HWC, oddBit);
    ADD_FORCE_TO_MOM(hwb, hwc, sid, mu, Naik, oddBit);
    return;
  }

  // Backward mu: every force is written in the forward direction opp_mu.
  const int opp_mu = OPP_DIR(mu);

  // One-link term.
  LOAD_HW(PmuEven, PmuOdd, sid, HWA, oddBit);
  LOAD_HW(TempxEven, TempxOdd, sid, HWB, oddBit);  // HWB (Tempx at sid) is reused below
  ADD_FORCE_TO_MOM(hwa, hwb, sid, opp_mu, OneLink, oddBit);

  // Naik term, first piece: site one step backwards along opp_mu.
  ff_shifted_site(x1, x2, x3, x4, opp_mu, -1, new_x1, new_x2, new_x3, new_x4, new_idx);
  LOAD_HW(TempxEven, TempxOdd, new_idx, HWA, 1-oddBit);
  FF_LOAD_MATRIX(opp_mu, new_idx, LINKA, 1-oddBit);
  FF_COMPUTE_RECONSTRUCT_SIGN(sign, opp_mu, new_x1, new_x2, new_x3, new_x4);
  RECONSTRUCT_LINK_12(sign, linka);
  ADJ_MAT_MUL_HW(linka, hwa, hwc);  // Popmu

  LOAD_HW(PnumuEven, PnumuOdd, sid, HWD, oddBit);
  ADD_FORCE_TO_MOM(hwd, hwc, sid, opp_mu, mNaik, oddBit);

  // Naik term, second piece: site one step forwards along opp_mu.
  ff_shifted_site(x1, x2, x3, x4, opp_mu, +1, new_x1, new_x2, new_x3, new_x4, new_idx);
  LOAD_HW(PnumuEven, PnumuOdd, new_idx, HWA, 1-oddBit);
  FF_LOAD_MATRIX(opp_mu, sid, LINKA, oddBit);
  FF_COMPUTE_RECONSTRUCT_SIGN(sign, opp_mu, x1, x2, x3, x4);
  RECONSTRUCT_LINK_12(sign, linka);
  MAT_MUL_HW(linka, hwa, hwc);
  ADD_FORCE_TO_MOM(hwc, hwb, sid, opp_mu, Naik, oddBit);
}
1151 
/* Host launcher for do_one_and_naik_terms_kernel.
 *
 * Launches the oddBit = 0 instantiation and then the oddBit = 1 one. Each
 * launch uses gridDim.x/2 blocks of blockDim threads and passes no stream
 * argument. Launch errors are not checked here; the caller in
 * do_fermion_force_cuda runs checkCudaError() afterwards.
 */
template<typename Float2>
static void
one_and_naik_terms_kernel(Float2* TempxEven, Float2* TempxOdd,
                          Float2* PmuEven, Float2* PmuOdd,
                          Float2* PnumuEven, Float2* PnumuOdd,
                          int mu, Float2 OneLink, Float2 Naik, Float2 mNaik,
                          float4* linkEven, float4* linkOdd,
                          Float2* momEven, Float2* momOdd,
                          dim3 gridDim, dim3 blockDim)
{
  const dim3 perParityGrid(gridDim.x/2, 1, 1);

  // Even-parity sites.
  do_one_and_naik_terms_kernel<0><<<perParityGrid, blockDim>>>(
      TempxEven, TempxOdd, PmuEven, PmuOdd, PnumuEven, PnumuOdd,
      mu, OneLink, Naik, mNaik, linkEven, linkOdd, momEven, momOdd);

  // Odd-parity sites.
  do_one_and_naik_terms_kernel<1><<<perParityGrid, blockDim>>>(
      TempxEven, TempxOdd, PmuEven, PmuOdd, PnumuEven, PnumuOdd,
      mu, OneLink, Naik, mNaik, linkEven, linkOdd, momEven, momOdd);
}
1178 
1179 
1180 
// Names for the eight scratch fields in tempvec[], as used by
// do_fermion_force_cuda. Some names share a slot:
//   tempvec[3] : P7, P5nu, P3mu
//   tempvec[4] : P7rho, Popmu, Pmumumu
#define Pmu      tempvec[0]
#define Pnumu    tempvec[1]
#define Prhonumu tempvec[2]
#define P7       tempvec[3]
#define P5nu     tempvec[3]
#define P3mu     tempvec[3]
#define P7rho    tempvec[4]
#define Popmu    tempvec[4]
#define Pmumumu  tempvec[4]
#define P7rhonu  tempvec[5]
#define P5       tempvec[6]
#define P3       tempvec[7]
1193 
/* Fermion-force driver. fermion_force_cuda instantiates it with Real = float;
 * all path coefficients are stored as float2.
 *
 * For every ordered pair (sig, mu) of directions in 0..7 with mu != sig and
 * mu != OPP_DIR(sig), it runs, in order:
 *   - the 3-link middle link;
 *   - for each valid nu: the 5-link middle link, the 7-link kernel for each
 *     valid rho, and the 5-link side link;
 *   - the Lepage middle and side links;
 *   - the 3-link side link;
 *   - the one-link/Naik kernel, once per mu (tracked by DirectLinks).
 * Every kernel receives the cudaMom pointers.
 *
 *  eps, weight1, weight2 : each coefficient's .x component carries
 *                          2*weight1*eps and its .y component 2*weight2*eps.
 *  act_path_coeff        : six coefficients read in the order one-link, Naik,
 *                          3-staple, 5-staple, 7-staple, Lepage.
 *  cudaHw                : Hw field read by the 3-link and one-link/Naik kernels.
 *  siteLink              : gauge links, bound to siteLink{0,1}TexSingle_recon
 *                          for the duration of the call.
 *  cudaMom               : momentum field passed to every kernel.
 *  tempvec               : eight scratch Hw fields (see the Pmu/P3/... aliases).
 *  param                 : lattice extents are read from param->X.
 */
template<typename Real>
static void
do_fermion_force_cuda(Real eps, Real weight1, Real weight2, Real* act_path_coeff, FullHw cudaHw,
                      cudaGaugeField &siteLink, cudaGaugeField &cudaMom, FullHw tempvec[8], QudaGaugeParam* param)
{
  int mu, nu, rho, sig;

  float2 OneLink, Lepage, Naik, FiveSt, ThreeSt, SevenSt;
  float2 mNaik, mLepage, mFiveSt, mThreeSt, mSevenSt;

  // .x components: weight1.
  Real ferm_epsilon;
  ferm_epsilon = 2.0*weight1*eps;
  OneLink.x = act_path_coeff[0]*ferm_epsilon;
  Naik.x    = act_path_coeff[1]*ferm_epsilon; mNaik.x    = -Naik.x;
  ThreeSt.x = act_path_coeff[2]*ferm_epsilon; mThreeSt.x = -ThreeSt.x;
  FiveSt.x  = act_path_coeff[3]*ferm_epsilon; mFiveSt.x  = -FiveSt.x;
  SevenSt.x = act_path_coeff[4]*ferm_epsilon; mSevenSt.x = -SevenSt.x;
  Lepage.x  = act_path_coeff[5]*ferm_epsilon; mLepage.x  = -Lepage.x;

  // .y components: weight2.
  ferm_epsilon = 2.0*weight2*eps;
  OneLink.y = act_path_coeff[0]*ferm_epsilon;
  Naik.y    = act_path_coeff[1]*ferm_epsilon; mNaik.y    = -Naik.y;
  ThreeSt.y = act_path_coeff[2]*ferm_epsilon; mThreeSt.y = -ThreeSt.y;
  FiveSt.y  = act_path_coeff[3]*ferm_epsilon; mFiveSt.y  = -FiveSt.y;
  SevenSt.y = act_path_coeff[4]*ferm_epsilon; mSevenSt.y = -SevenSt.y;
  Lepage.y  = act_path_coeff[5]*ferm_epsilon; mLepage.y  = -Lepage.y;

  // Ratio coefficients passed to the all-link and side-link kernels. They
  // depend only on the path coefficients, so compute them once here instead
  // of on every loop iteration. A zero denominator gives a zero ratio.
  float2 sevenOverFive, fiveOverThree, lepageOverThree, noAccumulation;
  sevenOverFive.x   = (FiveSt.x  != 0) ? SevenSt.x/FiveSt.x : 0.0f;
  sevenOverFive.y   = (FiveSt.y  != 0) ? SevenSt.y/FiveSt.y : 0.0f;
  fiveOverThree.x   = (ThreeSt.x != 0) ? FiveSt.x/ThreeSt.x : 0.0f;
  fiveOverThree.y   = (ThreeSt.y != 0) ? FiveSt.y/ThreeSt.y : 0.0f;
  lepageOverThree.x = (ThreeSt.x != 0) ? Lepage.x/ThreeSt.x : 0.0f;
  lepageOverThree.y = (ThreeSt.y != 0) ? Lepage.y/ThreeSt.y : 0.0f;
  noAccumulation.x  = noAccumulation.y = 0.0f;

  // DirectLinks[mu] != 0 once the one-link/Naik kernel has run for mu.
  int DirectLinks[8] = {0, 0, 0, 0, 0, 0, 0, 0};

  int volume = param->X[0]*param->X[1]*param->X[2]*param->X[3];
  dim3 blockDim(BLOCK_DIM, 1, 1);
  dim3 gridDim(volume/blockDim.x, 1, 1);

  cudaBindTexture(0, siteLink0TexSingle_recon, siteLink.Even_p(), siteLink.Bytes()/2);
  cudaBindTexture(0, siteLink1TexSingle_recon, siteLink.Odd_p(), siteLink.Bytes()/2);
  // Report binding failures here instead of after the first kernel launch.
  checkCudaError();

  for (sig = 0; sig < 8; sig++) {
    for (mu = 0; mu < 8; mu++) {
      if ((mu == sig) || (mu == OPP_DIR(sig))) {
        continue;
      }

      // 3-link, kernel A: middle link.
      middle_link_kernel((float2*)cudaHw.even.data, (float2*)cudaHw.odd.data,
                         (float2*)Pmu.even.data, (float2*)Pmu.odd.data,
                         (float2*)P3.even.data, (float2*)P3.odd.data,
                         sig, mu, mThreeSt,
                         (float4*)siteLink.Even_p(), (float4*)siteLink.Odd_p(), siteLink,
                         (float2*)cudaMom.Even_p(), (float2*)cudaMom.Odd_p(),
                         gridDim, blockDim);
      checkCudaError();

      for (nu = 0; nu < 8; nu++) {
        if (nu == sig || nu == OPP_DIR(sig)
            || nu == mu || nu == OPP_DIR(mu)) {
          continue;
        }

        // 5-link, kernel B: middle link.
        middle_link_kernel((float2*)Pmu.even.data, (float2*)Pmu.odd.data,
                           (float2*)Pnumu.even.data, (float2*)Pnumu.odd.data,
                           (float2*)P5.even.data, (float2*)P5.odd.data,
                           sig, nu, FiveSt,
                           (float4*)siteLink.Even_p(), (float4*)siteLink.Odd_p(), siteLink,
                           (float2*)cudaMom.Even_p(), (float2*)cudaMom.Odd_p(),
                           gridDim, blockDim);
        checkCudaError();

        for (rho = 0; rho < 8; rho++) {
          if (rho == sig || rho == OPP_DIR(sig)
              || rho == mu || rho == OPP_DIR(mu)
              || rho == nu || rho == OPP_DIR(nu)) {
            continue;
          }

          // 7-link, kernel C: middle link and side link.
          all_link_kernel((float2*)Pnumu.even.data, (float2*)Pnumu.odd.data,
                          (float2*)Prhonumu.even.data, (float2*)Prhonumu.odd.data,
                          (float2*)P7.even.data, (float2*)P7.odd.data,
                          (float2*)P7rho.even.data, (float2*)P7rho.odd.data,
                          (float2*)P5.even.data, (float2*)P5.odd.data,
                          sig, rho, SevenSt, mSevenSt, sevenOverFive,
                          (float4*)siteLink.Even_p(), (float4*)siteLink.Odd_p(), siteLink,
                          (float2*)cudaMom.Even_p(), (float2*)cudaMom.Odd_p(),
                          gridDim, blockDim);
          checkCudaError();
        }//rho

        // 5-link, kernel B2: side link.
        side_link_kernel((float2*)P5.even.data, (float2*)P5.odd.data,
                         (float2*)P5nu.even.data, (float2*)P5nu.odd.data,
                         (float2*)Pmu.even.data, (float2*)Pmu.odd.data,
                         (float2*)Pnumu.even.data, (float2*)Pnumu.odd.data,
                         (float2*)P3.even.data, (float2*)P3.odd.data,
                         sig, nu, mFiveSt, fiveOverThree,
                         (float4*)siteLink.Even_p(), (float4*)siteLink.Odd_p(), siteLink,
                         (float2*)cudaMom.Even_p(), (float2*)cudaMom.Odd_p(),
                         gridDim, blockDim);
        checkCudaError();
      }//nu

      // Lepage, kernel A2: middle link.
      middle_link_kernel((float2*)Pmu.even.data, (float2*)Pmu.odd.data,
                         (float2*)Pnumu.even.data, (float2*)Pnumu.odd.data,
                         (float2*)P5.even.data, (float2*)P5.odd.data,
                         sig, mu, Lepage,
                         (float4*)siteLink.Even_p(), (float4*)siteLink.Odd_p(), siteLink,
                         (float2*)cudaMom.Even_p(), (float2*)cudaMom.Odd_p(),
                         gridDim, blockDim);
      checkCudaError();

      // Lepage side link.
      side_link_kernel((float2*)P5.even.data, (float2*)P5.odd.data,
                       (float2*)P5nu.even.data, (float2*)P5nu.odd.data,
                       (float2*)Pmu.even.data, (float2*)Pmu.odd.data,
                       (float2*)Pnumu.even.data, (float2*)Pnumu.odd.data,
                       (float2*)P3.even.data, (float2*)P3.odd.data,
                       sig, mu, mLepage, lepageOverThree,
                       (float4*)siteLink.Even_p(), (float4*)siteLink.Odd_p(), siteLink,
                       (float2*)cudaMom.Even_p(), (float2*)cudaMom.Odd_p(),
                       gridDim, blockDim);
      checkCudaError();

      // 3-link side link. There is no shorter path to accumulate into, so the
      // output pointers are NULL and the ratio is zero.
      side_link_kernel((float2*)P3.even.data, (float2*)P3.odd.data,
                       (float2*)P3mu.even.data, (float2*)P3mu.odd.data,
                       (float2*)cudaHw.even.data, (float2*)cudaHw.odd.data,
                       (float2*)Pmu.even.data, (float2*)Pmu.odd.data,
                       (float2*)NULL, (float2*)NULL,
                       sig, mu, ThreeSt, noAccumulation,
                       (float4*)siteLink.Even_p(), (float4*)siteLink.Odd_p(), siteLink,
                       (float2*)cudaMom.Even_p(), (float2*)cudaMom.Odd_p(),
                       gridDim, blockDim);
      checkCudaError();

      // One-link and Naik terms, kernel Z: run only once per direction mu.
      if (!DirectLinks[mu]) {
        DirectLinks[mu] = 1;
        one_and_naik_terms_kernel((float2*)cudaHw.even.data, (float2*)cudaHw.odd.data,
                                  (float2*)Pmu.even.data, (float2*)Pmu.odd.data,
                                  (float2*)Pnumu.even.data, (float2*)Pnumu.odd.data,
                                  mu, OneLink, Naik, mNaik,
                                  (float4*)siteLink.Even_p(), (float4*)siteLink.Odd_p(),
                                  (float2*)cudaMom.Even_p(), (float2*)cudaMom.Odd_p(),
                                  gridDim, blockDim);
        checkCudaError();
      }
    }//mu
  }//sig

  cudaUnbindTexture(siteLink0TexSingle_recon);
  cudaUnbindTexture(siteLink1TexSingle_recon);
}
1372 
// Drop the tempvec[] slot aliases so they do not leak into the rest of the file.
#undef Pmu
#undef Pnumu
#undef Prhonumu
#undef P7
#undef P5nu
#undef P3mu
#undef P7rho
#undef Popmu
#undef Pmumumu
#undef P7rhonu
#undef P5
#undef P3
1385 
/* Public entry point for the fermion force.
 *
 * Checks the field layouts, precision and lattice size, allocates eight
 * scratch Hw fields, runs the single-precision driver do_fermion_force_cuda,
 * and frees the scratch fields again.
 *
 * Preconditions (errorQuda is raised otherwise):
 *  - siteLink uses QUDA_RECONSTRUCT_12;
 *  - cudaMom uses QUDA_RECONSTRUCT_10;
 *  - param->cuda_prec is not QUDA_DOUBLE_PRECISION (the double path is disabled);
 *  - the lattice volume is a multiple of 2*BLOCK_DIM. The launchers in this
 *    file run volume/BLOCK_DIM/2 blocks of BLOCK_DIM threads per parity and
 *    have no tail handling, so any other volume would silently skip sites.
 *
 * act_path_coeff is read as six floats by the driver.
 */
void
fermion_force_cuda(double eps, double weight1, double weight2, void* act_path_coeff,
                   FullHw cudaHw, cudaGaugeField &siteLink, cudaGaugeField &cudaMom, QudaGaugeParam* param)
{
  if (siteLink.Reconstruct() != QUDA_RECONSTRUCT_12)
    errorQuda("Reconstruct type %d not supported for gauge field", siteLink.Reconstruct());

  if (cudaMom.Reconstruct() != QUDA_RECONSTRUCT_10)
    errorQuda("Reconstruct type %d not supported for momentum field", cudaMom.Reconstruct());

  // Only the float driver is in use. The disabled double call was:
  //   do_fermion_force_cuda((double)eps, (double)weight1, (double)weight2, (double*)act_path_coeff,
  //                         cudaHw, siteLink, cudaMom, tempvec, param);
  // Reject double precision before any scratch field is allocated.
  if (param->cuda_prec == QUDA_DOUBLE_PRECISION)
    errorQuda("Double precision not supported?");

  const int volume = param->X[0]*param->X[1]*param->X[2]*param->X[3];
  if (volume % (2*BLOCK_DIM) != 0)
    errorQuda("Lattice volume %d is not a multiple of %d", volume, 2*BLOCK_DIM);

  FullHw tempvec[8];
  for (int i = 0; i < 8; i++) {
    tempvec[i] = createHwQuda(param->X, param->cuda_prec);
  }

  do_fermion_force_cuda((float)eps, (float)weight1, (float)weight2, (float*)act_path_coeff,
                        cudaHw, siteLink, cudaMom, tempvec, param);

  for (int i = 0; i < 8; i++) {
    freeHwQuda(tempvec[i]);
  }
}
1421 
1422 #undef BLOCK_DIM
1423 
1424 #undef FF_COMPUTE_NEW_FULL_IDX_PLUS_UPDATE
1425 #undef FF_COMPUTE_NEW_FULL_IDX_MINUS_UPDATE
1426 
1427 } // namespace quda