QUDA  0.9.0
dslash_util.h
Go to the documentation of this file.
1 #ifndef _DSLASH_UTIL_H
2 #define _DSLASH_UTIL_H
3 
#include <cstdio>
#include <cstdlib>

#include <test_util.h>
#include <comm_quda.h>
6 
// Element-wise addition: dst[k] = a[k] + b[k] for every k in [0, cnt).
// Does nothing when cnt <= 0.
template <typename Float>
static inline void sum(Float *dst, Float *a, Float *b, int cnt)
{
  int k = 0;
  while (k < cnt) {
    dst[k] = a[k] + b[k];
    ++k;
  }
}
12 
// Element-wise difference: dst[k] = a[k] - b[k] for every k in [0, cnt).
// Does nothing when cnt <= 0.
template <typename Float>
static inline void sub(Float *dst, Float *a, Float *b, int cnt)
{
  int k = 0;
  while (k < cnt) {
    dst[k] = a[k] - b[k];
    ++k;
  }
}
18 
// Scaled copy: dst[k] = a * x[k] for every k in [0, cnt).
// Does nothing when cnt <= 0.
template <typename Float>
static inline void ax(Float *dst, Float a, Float *x, int cnt)
{
  int k = 0;
  while (k < cnt) {
    dst[k] = a * x[k];
    ++k;
  }
}
24 
// In-place update y[k] = a*x[k] + y[k] for every k in [0, len).
// Does nothing when len <= 0.
template <typename Float>
static inline void axpy(Float a, Float *x, Float *y, int len)
{
  for (int k = 0; k < len; ++k) {
    y[k] = a * x[k] + y[k];
  }
}
30 
// In-place update y[k] = a*x[k] + b*y[k] for every k in [0, len).
// Does nothing when len <= 0.
template <typename Float>
static inline void axpby(Float a, Float *x, Float b, Float *y, int len)
{
  for (int k = 0; k < len; ++k) {
    y[k] = a * x[k] + b * y[k];
  }
}
36 
// In-place update y[k] = a*x[k] - y[k] for every k in [0, len).
// Note the argument order (x first, then the scalar a).
// Does nothing when len <= 0.
template <typename Float>
static inline void axmy(Float *x, Float a, Float *y, int len)
{
  for (int k = 0; k < len; ++k) {
    y[k] = a * x[k] - y[k];
  }
}
42 
// Squared Euclidean norm of v[0..len): sum of v[k]*v[k].
// Each product is formed in Float and then accumulated in double.
// Returns 0.0 when len <= 0.
template <typename Float>
static double norm2(Float *v, int len)
{
  double acc = 0.0;
  for (int k = 0; k < len; ++k) {
    acc += v[k] * v[k];
  }
  return acc;
}
49 
// Negates every element of x[0..len) in place. Does nothing when len <= 0.
template <typename Float>
static inline void negx(Float *x, int len)
{
  for (int k = 0; k < len; ++k) {
    x[k] = -x[k];
  }
}
54 
// Unconjugated complex dot product of two 3-component complex vectors:
//   res = sum_c a[c] * b[c]
// Vectors are stored as interleaved (re, im) pairs, so a and b each hold 6
// values and res receives one (re, im) pair. The entries of a are converted
// to sFloat before multiplying.
template <typename sFloat, typename gFloat>
static inline void dot(sFloat *res, gFloat *a, sFloat *b)
{
  res[0] = res[1] = 0;
  for (int c = 0; c < 3; ++c) {
    const sFloat ar = a[2 * c];
    const sFloat ai = a[2 * c + 1];
    const sFloat br = b[2 * c];
    const sFloat bi = b[2 * c + 1];
    // (ar + i ai) * (br + i bi)
    res[0] += ar * br - ai * bi;
    res[1] += ar * bi + ai * br;
  }
}
67 
// Despite the name, this writes the *conjugate* transpose (Hermitian
// adjoint) of the 3x3 complex matrix mat into res:
//   res[row][col] = conj(mat[col][row])
// Matrices are row-major, 3 complex entries per row, each stored as an
// interleaved (re, im) pair (18 values total).
template <typename Float>
static inline void su3Transpose(Float *res, Float *mat)
{
  for (int row = 0; row < 3; ++row) {
    for (int col = 0; col < 3; ++col) {
      const int dst = row * (3 * 2) + col * 2;
      const int src = col * (3 * 2) + row * 2;
      res[dst + 0] = mat[src + 0];
      res[dst + 1] = -mat[src + 1];
    }
  }
}
77 
78 
// Complex 3x3 matrix times 3-vector: res = mat * vec.
// mat is row-major with interleaved (re, im) entries (6 values per row);
// vec and res hold 3 interleaved complex values. Row `row` of the result
// is the unconjugated dot product of matrix row `row` with vec.
template <typename sFloat, typename gFloat>
static inline void su3Mul(sFloat *res, gFloat *mat, sFloat *vec)
{
  for (int row = 0; row < 3; ++row) {
    dot(res + 2 * row, mat + (3 * 2) * row, vec);
  }
}
83 
// Multiplies vec by the conjugate transpose of mat: res = mat^dagger * vec.
// su3Transpose conjugates while transposing, so the temporary holds the
// Hermitian adjoint, which is then applied with su3Mul.
template <typename sFloat, typename gFloat>
static inline void su3Tmul(sFloat *res, gFloat *mat, sFloat *vec)
{
  gFloat adjoint[3 * 3 * 2];
  su3Transpose(adjoint, mat);
  su3Mul(res, adjoint, vec);
}
90 
91 
92 // i represents a "half index" into an even or odd "half lattice".
93 // when oddBit={0,1} the half lattice is {even,odd}.
94 //
95 // the displacements, such as dx, refer to the full lattice coordinates.
96 //
97 // neighborIndex() takes a "half index", displaces it, and returns the
98 // new "half index", which can be an index into either the even or odd lattices.
99 // displacements of magnitude one always interchange odd and even lattices.
100 //
101 
102 
103 template <typename Float>
104 static inline Float *gaugeLink(int i, int dir, int oddBit, Float **gaugeEven, Float **gaugeOdd, int nbr_distance) {
105  Float **gaugeField;
106  int j;
107  int d = nbr_distance;
108  if (dir % 2 == 0) {
109  j = i;
110  gaugeField = (oddBit ? gaugeOdd : gaugeEven);
111  }
112  else {
113  switch (dir) {
114  case 1: j = neighborIndex(i, oddBit, 0, 0, 0, -d); break;
115  case 3: j = neighborIndex(i, oddBit, 0, 0, -d, 0); break;
116  case 5: j = neighborIndex(i, oddBit, 0, -d, 0, 0); break;
117  case 7: j = neighborIndex(i, oddBit, -d, 0, 0, 0); break;
118  default: j = -1; break;
119  }
120  gaugeField = (oddBit ? gaugeEven : gaugeOdd);
121  }
122 
123  return &gaugeField[dir/2][j*(3*3*2)];
124 }
125 
126 template <typename Float>
127 static inline Float *spinorNeighbor(int i, int dir, int oddBit, Float *spinorField, int neighbor_distance)
128 {
129  int j;
130  int nb = neighbor_distance;
131  switch (dir) {
132  case 0: j = neighborIndex(i, oddBit, 0, 0, 0, +nb); break;
133  case 1: j = neighborIndex(i, oddBit, 0, 0, 0, -nb); break;
134  case 2: j = neighborIndex(i, oddBit, 0, 0, +nb, 0); break;
135  case 3: j = neighborIndex(i, oddBit, 0, 0, -nb, 0); break;
136  case 4: j = neighborIndex(i, oddBit, 0, +nb, 0, 0); break;
137  case 5: j = neighborIndex(i, oddBit, 0, -nb, 0, 0); break;
138  case 6: j = neighborIndex(i, oddBit, +nb, 0, 0, 0); break;
139  case 7: j = neighborIndex(i, oddBit, -nb, 0, 0, 0); break;
140  default: j = -1; break;
141  }
142 
143  return &spinorField[j*(mySpinorSiteSize)];
144 }
145 
146 
147 // i represents a "half index" into an even or odd "half lattice".
148 // when oddBit={0,1} the half lattice is {even,odd}.
149 //
150 // the displacements, such as dx, refer to the full lattice coordinates.
151 //
152 // neighborIndex() takes a "half index", displaces it, and returns the
153 // new "half index", which can be an index into either the even or odd lattices.
154 // displacements of magnitude one always interchange odd and even lattices.
155 //
156 //
157 template<QudaDWFPCType type>
158 int neighborIndex_5d(int i, int oddBit, int dxs, int dx4, int dx3, int dx2, int dx1) {
159  // fullLatticeIndex was modified for fullLatticeIndex_4d. It is in util_quda.cpp.
160  // This code bit may not properly perform 5dPC.
161  int X = type == QUDA_5D_PC ? fullLatticeIndex_5d(i, oddBit) : fullLatticeIndex_5d_4dpc(i, oddBit);
162  // Checked that this matches code in dslash_core_ante.h.
163  int xs = X/(Z[3]*Z[2]*Z[1]*Z[0]);
164  int x4 = (X/(Z[2]*Z[1]*Z[0])) % Z[3];
165  int x3 = (X/(Z[1]*Z[0])) % Z[2];
166  int x2 = (X/Z[0]) % Z[1];
167  int x1 = X % Z[0];
168  // Displace and project back into domain 0,...,Ls-1.
169  // Note that we add Ls to avoid the negative problem
170  // of the C % operator.
171  xs = (xs+dxs+Ls) % Ls;
172  // Etc.
173  x4 = (x4+dx4+Z[3]) % Z[3];
174  x3 = (x3+dx3+Z[2]) % Z[2];
175  x2 = (x2+dx2+Z[1]) % Z[1];
176  x1 = (x1+dx1+Z[0]) % Z[0];
177  // Return linear half index. Remember that integer division
178  // rounds down.
179  return (xs*(Z[3]*Z[2]*Z[1]*Z[0]) + x4*(Z[2]*Z[1]*Z[0]) + x3*(Z[1]*Z[0]) + x2*(Z[0]) + x1) / 2;
180 }
181 
182 
183 template <QudaDWFPCType type, typename Float>
184  Float *spinorNeighbor_5d(int i, int dir, int oddBit, Float *spinorField, int neighbor_distance=1, int siteSize=24) {
185  int nb = neighbor_distance;
186  int j;
187  switch (dir) {
188  case 0: j = neighborIndex_5d<type>(i, oddBit, 0, 0, 0, 0, +nb); break;
189  case 1: j = neighborIndex_5d<type>(i, oddBit, 0, 0, 0, 0, -nb); break;
190  case 2: j = neighborIndex_5d<type>(i, oddBit, 0, 0, 0, +nb, 0); break;
191  case 3: j = neighborIndex_5d<type>(i, oddBit, 0, 0, 0, -nb, 0); break;
192  case 4: j = neighborIndex_5d<type>(i, oddBit, 0, 0, +nb, 0, 0); break;
193  case 5: j = neighborIndex_5d<type>(i, oddBit, 0, 0, -nb, 0, 0); break;
194  case 6: j = neighborIndex_5d<type>(i, oddBit, 0, +nb, 0, 0, 0); break;
195  case 7: j = neighborIndex_5d<type>(i, oddBit, 0, -nb, 0, 0, 0); break;
196  case 8: j = neighborIndex_5d<type>(i, oddBit, +nb, 0, 0, 0, 0); break;
197  case 9: j = neighborIndex_5d<type>(i, oddBit, -nb, 0, 0, 0, 0); break;
198  default: j = -1; break;
199  }
200  return &spinorField[j*siteSize];
201 }
202 
203 
204 #ifdef MULTI_GPU
205 
206 static inline int
207 x4_mg(int i, int oddBit)
208 {
209  int Y = fullLatticeIndex(i, oddBit);
210  int x4 = Y/(Z[2]*Z[1]*Z[0]);
211  return x4;
212 }
213 
214 template <typename Float>
215 static inline Float *gaugeLink_mg4dir(int i, int dir, int oddBit, Float **gaugeEven, Float **gaugeOdd,
216  Float** ghostGaugeEven, Float** ghostGaugeOdd, int n_ghost_faces, int nbr_distance) {
217  Float **gaugeField;
218  int j;
219  int d = nbr_distance;
220  if (dir % 2 == 0) {
221  j = i;
222  gaugeField = (oddBit ? gaugeOdd : gaugeEven);
223  }
224  else {
225 
226  int Y = fullLatticeIndex(i, oddBit);
227  int x4 = Y/(Z[2]*Z[1]*Z[0]);
228  int x3 = (Y/(Z[1]*Z[0])) % Z[2];
229  int x2 = (Y/Z[0]) % Z[1];
230  int x1 = Y % Z[0];
231  int X1= Z[0];
232  int X2= Z[1];
233  int X3= Z[2];
234  int X4= Z[3];
235  Float* ghostGaugeField;
236 
237  switch (dir) {
238  case 1:
239  { //-X direction
240  int new_x1 = (x1 - d + X1 )% X1;
241  if (x1 -d < 0 && comm_dim_partitioned(0)){
242  ghostGaugeField = (oddBit?ghostGaugeEven[0]: ghostGaugeOdd[0]);
243  int offset = (n_ghost_faces + x1 -d)*X4*X3*X2/2 + (x4*X3*X2 + x3*X2+x2)/2;
244  return &ghostGaugeField[offset*(3*3*2)];
245  }
246  j = (x4*X3*X2*X1 + x3*X2*X1 + x2*X1 + new_x1) / 2;
247  break;
248  }
249  case 3:
250  { //-Y direction
251  int new_x2 = (x2 - d + X2 )% X2;
252  if (x2 -d < 0 && comm_dim_partitioned(1)){
253  ghostGaugeField = (oddBit?ghostGaugeEven[1]: ghostGaugeOdd[1]);
254  int offset = (n_ghost_faces + x2 -d)*X4*X3*X1/2 + (x4*X3*X1 + x3*X1+x1)/2;
255  return &ghostGaugeField[offset*(3*3*2)];
256  }
257  j = (x4*X3*X2*X1 + x3*X2*X1 + new_x2*X1 + x1) / 2;
258  break;
259 
260  }
261  case 5:
262  { //-Z direction
263  int new_x3 = (x3 - d + X3 )% X3;
264  if (x3 -d < 0 && comm_dim_partitioned(2)){
265  ghostGaugeField = (oddBit?ghostGaugeEven[2]: ghostGaugeOdd[2]);
266  int offset = (n_ghost_faces + x3 -d)*X4*X2*X1/2 + (x4*X2*X1 + x2*X1+x1)/2;
267  return &ghostGaugeField[offset*(3*3*2)];
268  }
269  j = (x4*X3*X2*X1 + new_x3*X2*X1 + x2*X1 + x1) / 2;
270  break;
271  }
272  case 7:
273  { //-T direction
274  int new_x4 = (x4 - d + X4)% X4;
275  if (x4 -d < 0 && comm_dim_partitioned(3)){
276  ghostGaugeField = (oddBit?ghostGaugeEven[3]: ghostGaugeOdd[3]);
277  int offset = (n_ghost_faces + x4 -d)*X1*X2*X3/2 + (x3*X2*X1 + x2*X1+x1)/2;
278  return &ghostGaugeField[offset*(3*3*2)];
279  }
280  j = (new_x4*(X3*X2*X1) + x3*(X2*X1) + x2*(X1) + x1) / 2;
281  break;
282  }//7
283 
284  default: j = -1; printf("ERROR: wrong dir \n"); exit(1);
285  }
286  gaugeField = (oddBit ? gaugeEven : gaugeOdd);
287 
288  }
289 
290  return &gaugeField[dir/2][j*(3*3*2)];
291 }
292 
293 template <typename Float>
294 static inline Float *spinorNeighbor_mg4dir(int i, int dir, int oddBit, Float *spinorField, Float** fwd_nbr_spinor,
295  Float** back_nbr_spinor, int neighbor_distance, int nFace)
296 {
297  int j;
298  int nb = neighbor_distance;
299  int Y = fullLatticeIndex(i, oddBit);
300  int x4 = Y/(Z[2]*Z[1]*Z[0]);
301  int x3 = (Y/(Z[1]*Z[0])) % Z[2];
302  int x2 = (Y/Z[0]) % Z[1];
303  int x1 = Y % Z[0];
304  int X1= Z[0];
305  int X2= Z[1];
306  int X3= Z[2];
307  int X4= Z[3];
308 
309  switch (dir) {
310  case 0://+X
311  {
312  int new_x1 = (x1 + nb)% X1;
313  if(x1+nb >=X1 && comm_dim_partitioned(0) ){
314  int offset = ( x1 + nb -X1)*X4*X3*X2/2+(x4*X3*X2 + x3*X2+x2)/2;
315  return fwd_nbr_spinor[0] + offset*mySpinorSiteSize;
316  }
317  j = (x4*X3*X2*X1 + x3*X2*X1 + x2*X1 + new_x1) / 2;
318  break;
319  }
320  case 1://-X
321  {
322  int new_x1 = (x1 - nb + X1)% X1;
323  if(x1 - nb < 0 && comm_dim_partitioned(0)){
324  int offset = ( x1+nFace- nb)*X4*X3*X2/2+(x4*X3*X2 + x3*X2+x2)/2;
325  return back_nbr_spinor[0] + offset*mySpinorSiteSize;
326  }
327  j = (x4*X3*X2*X1 + x3*X2*X1 + x2*X1 + new_x1) / 2;
328  break;
329  }
330  case 2://+Y
331  {
332  int new_x2 = (x2 + nb)% X2;
333  if(x2+nb >=X2 && comm_dim_partitioned(1)){
334  int offset = ( x2 + nb -X2)*X4*X3*X1/2+(x4*X3*X1 + x3*X1+x1)/2;
335  return fwd_nbr_spinor[1] + offset*mySpinorSiteSize;
336  }
337  j = (x4*X3*X2*X1 + x3*X2*X1 + new_x2*X1 + x1) / 2;
338  break;
339  }
340  case 3:// -Y
341  {
342  int new_x2 = (x2 - nb + X2)% X2;
343  if(x2 - nb < 0 && comm_dim_partitioned(1)){
344  int offset = ( x2 + nFace -nb)*X4*X3*X1/2+(x4*X3*X1 + x3*X1+x1)/2;
345  return back_nbr_spinor[1] + offset*mySpinorSiteSize;
346  }
347  j = (x4*X3*X2*X1 + x3*X2*X1 + new_x2*X1 + x1) / 2;
348  break;
349  }
350  case 4://+Z
351  {
352  int new_x3 = (x3 + nb)% X3;
353  if(x3+nb >=X3 && comm_dim_partitioned(2)){
354  int offset = ( x3 + nb -X3)*X4*X2*X1/2+(x4*X2*X1 + x2*X1+x1)/2;
355  return fwd_nbr_spinor[2] + offset*mySpinorSiteSize;
356  }
357  j = (x4*X3*X2*X1 + new_x3*X2*X1 + x2*X1 + x1) / 2;
358  break;
359  }
360  case 5://-Z
361  {
362  int new_x3 = (x3 - nb + X3)% X3;
363  if(x3 - nb < 0 && comm_dim_partitioned(2)){
364  int offset = ( x3 + nFace -nb)*X4*X2*X1/2+(x4*X2*X1 + x2*X1+x1)/2;
365  return back_nbr_spinor[2] + offset*mySpinorSiteSize;
366  }
367  j = (x4*X3*X2*X1 + new_x3*X2*X1 + x2*X1 + x1) / 2;
368  break;
369  }
370  case 6://+T
371  {
372  j = neighborIndex_mg(i, oddBit, +nb, 0, 0, 0);
373  int x4 = x4_mg(i, oddBit);
374  if ( (x4 + nb) >= Z[3] && comm_dim_partitioned(3)){
375  int offset = (x4+nb - Z[3])*Vsh_t;
376  return &fwd_nbr_spinor[3][(offset+j)*mySpinorSiteSize];
377  }
378  break;
379  }
380  case 7://-T
381  {
382  j = neighborIndex_mg(i, oddBit, -nb, 0, 0, 0);
383  int x4 = x4_mg(i, oddBit);
384  if ( (x4 - nb) < 0 && comm_dim_partitioned(3)){
385  int offset = ( x4 - nb +nFace)*Vsh_t;
386  return &back_nbr_spinor[3][(offset+j)*mySpinorSiteSize];
387  }
388  break;
389  }
390  default: j = -1; printf("ERROR: wrong dir\n"); exit(1);
391  }
392 
393  return &spinorField[j*(mySpinorSiteSize)];
394 }
395 
396 template<QudaDWFPCType type>
397 int neighborIndex_5d_mgpu(int i, int oddBit, int dxs, int dx4, int dx3, int dx2, int dx1)
398 {
399  int ret;
400 
401  int Y = (type == QUDA_5D_PC) ? fullLatticeIndex_5d(i, oddBit) : fullLatticeIndex_5d_4dpc(i, oddBit);
402 
403  int xs = Y/(Z[3]*Z[2]*Z[1]*Z[0]);
404  int x4 = (Y/(Z[2]*Z[1]*Z[0])) % Z[3];
405  int x3 = (Y/(Z[1]*Z[0])) % Z[2];
406  int x2 = (Y/Z[0]) % Z[1];
407  int x1 = Y % Z[0];
408  int ghost_x4 = x4+ dx4;
409 
410  xs = (xs+dxs+Ls) % Ls;
411  x4 = (x4+dx4+Z[3]) % Z[3];
412  x3 = (x3+dx3+Z[2]) % Z[2];
413  x2 = (x2+dx2+Z[1]) % Z[1];
414  x1 = (x1+dx1+Z[0]) % Z[0];
415 
416  if ( (ghost_x4 >= 0 && ghost_x4) < Z[3] || !comm_dim_partitioned(3)){
417  ret = (xs*Z[3]*Z[2]*Z[1]*Z[0] + x4*Z[2]*Z[1]*Z[0] + x3*Z[1]*Z[0] + x2*Z[0] + x1) >> 1;
418  }else{
419  ret = (xs*Z[2]*Z[1]*Z[0] + x3*Z[1]*Z[0] + x2*Z[0] + x1) >> 1;
420  }
421 
422  return ret;
423 }
424 
425 template <QudaDWFPCType type>
426 int x4_5d_mgpu(int i, int oddBit)
427 {
428  int Y = (type == QUDA_5D_PC) ? fullLatticeIndex_5d(i, oddBit) : fullLatticeIndex_5d_4dpc(i, oddBit);
429  return (Y/(Z[2]*Z[1]*Z[0])) % Z[3];
430 }
431 
432 
433 template <QudaDWFPCType type, typename Float>
434 Float *spinorNeighbor_5d_mgpu(int i, int dir, int oddBit, Float *spinorField, Float** fwd_nbr_spinor, Float** back_nbr_spinor, int neighbor_distance, int nFace, int spinorSize = 24)
435 {
436  int j;
437  int nb = neighbor_distance;
438  int Y = (type == QUDA_5D_PC) ? fullLatticeIndex_5d(i, oddBit) : fullLatticeIndex_5d_4dpc(i, oddBit);
439 
440  int xs = Y/(Z[3]*Z[2]*Z[1]*Z[0]);
441  int x4 = (Y/(Z[2]*Z[1]*Z[0])) % Z[3];
442  int x3 = (Y/(Z[1]*Z[0])) % Z[2];
443  int x2 = (Y/Z[0]) % Z[1];
444  int x1 = Y % Z[0];
445 
446  int X1= Z[0];
447  int X2= Z[1];
448  int X3= Z[2];
449  int X4= Z[3];
450  switch (dir) {
451  case 0://+X
452  {
453  int new_x1 = (x1 + nb)% X1;
454  if(x1+nb >=X1 && comm_dim_partitioned(0)) {
455  int offset = ((x1 + nb -X1)*Ls*X4*X3*X2+xs*X4*X3*X2+x4*X3*X2 + x3*X2+x2) >> 1;
456  return fwd_nbr_spinor[0] + offset*spinorSize;
457  }
458  j = (xs*X4*X3*X2*X1 + x4*X3*X2*X1 + x3*X2*X1 + x2*X1 + new_x1) >> 1;
459  break;
460  }
461  case 1://-X
462  {
463  int new_x1 = (x1 - nb + X1)% X1;
464  if(x1 - nb < 0 && comm_dim_partitioned(0)) {
465  int offset = (( x1+nFace- nb)*Ls*X4*X3*X2 + xs*X4*X3*X2 + x4*X3*X2 + x3*X2 + x2) >> 1;
466  return back_nbr_spinor[0] + offset*spinorSize;
467  }
468  j = (xs*X4*X3*X2*X1 + x4*X3*X2*X1 + x3*X2*X1 + x2*X1 + new_x1) >> 1;
469  break;
470  }
471  case 2://+Y
472  {
473  int new_x2 = (x2 + nb)% X2;
474  if(x2+nb >=X2 && comm_dim_partitioned(1)) {
475  int offset = (( x2 + nb -X2)*Ls*X4*X3*X1+xs*X4*X3*X1+x4*X3*X1 + x3*X1+x1) >> 1;
476  return fwd_nbr_spinor[1] + offset*spinorSize;
477  }
478  j = (xs*X4*X3*X2*X1 + x4*X3*X2*X1 + x3*X2*X1 + new_x2*X1 + x1) >> 1;
479  break;
480  }
481  case 3:// -Y
482  {
483  int new_x2 = (x2 - nb + X2)% X2;
484  if(x2 - nb < 0 && comm_dim_partitioned(1)) {
485  int offset = (( x2 + nFace -nb)*Ls*X4*X3*X1+xs*X4*X3*X1+ x4*X3*X1 + x3*X1+x1) >> 1;
486  return back_nbr_spinor[1] + offset*spinorSize;
487  }
488  j = (xs*X4*X3*X2*X1 + x4*X3*X2*X1 + x3*X2*X1 + new_x2*X1 + x1) >> 1;
489  break;
490  }
491  case 4://+Z
492  {
493  int new_x3 = (x3 + nb)% X3;
494  if(x3+nb >=X3 && comm_dim_partitioned(2)) {
495  int offset = (( x3 + nb -X3)*Ls*X4*X2*X1+xs*X4*X2*X1+x4*X2*X1 + x2*X1+x1) >> 1;
496  return fwd_nbr_spinor[2] + offset*spinorSize;
497  }
498  j = (xs*X4*X3*X2*X1 + x4*X3*X2*X1 + new_x3*X2*X1 + x2*X1 + x1) >> 1;
499  break;
500  }
501  case 5://-Z
502  {
503  int new_x3 = (x3 - nb + X3)% X3;
504  if(x3 - nb < 0 && comm_dim_partitioned(2)){
505  int offset = (( x3 + nFace -nb)*Ls*X4*X2*X1+xs*X4*X2*X1+x4*X2*X1+x2*X1+x1) >> 1;
506  return back_nbr_spinor[2] + offset*spinorSize;
507  }
508  j = (xs*X4*X3*X2*X1 + x4*X3*X2*X1 + new_x3*X2*X1 + x2*X1 + x1) >> 1;
509  break;
510  }
511  case 6://+T
512  {
513  int x4 = x4_5d_mgpu<type>(i, oddBit);
514  if ( (x4 + nb) >= Z[3] && comm_dim_partitioned(3)) {
515  int offset = ((x4 + nb - Z[3])*Ls*X3*X2*X1+xs*X3*X2*X1+x3*X2*X1+x2*X1+x1) >> 1;
516  return fwd_nbr_spinor[3] + offset*spinorSize;
517  }
518  j = neighborIndex_5d_mgpu<type>(i, oddBit, 0, +nb, 0, 0, 0);
519  break;
520  }
521  case 7://-T
522  {
523  int x4 = x4_5d_mgpu<type>(i, oddBit);
524  if ( (x4 - nb) < 0 && comm_dim_partitioned(3)) {
525  int offset = (( x4 - nb +nFace)*Ls*X3*X2*X1+xs*X3*X2*X1+x3*X2*X1+x2*X1+x1) >> 1;
526  return back_nbr_spinor[3] + offset*spinorSize;
527  }
528  j = neighborIndex_5d_mgpu<type>(i, oddBit, 0, -nb, 0, 0, 0);
529  break;
530  }
531  default: j = -1; printf("ERROR: wrong dir\n"); exit(1);
532  }
533 
534  return &spinorField[j*(spinorSize)];
535 }
536 
537 
538 #endif // MULTI_GPU
539 
540 #endif // _DSLASH_UTIL_H
__device__ __forceinline__ int neighborIndex(const unsigned int &cb_idx, const int(&shift)[4], const bool(&partitioned)[4], const unsigned int &parity)
static void sum(Float *dst, Float *a, Float *b, int cnt)
Definition: dslash_util.h:8
Float * spinorNeighbor_5d(int i, int dir, int oddBit, Float *spinorField, int neighbor_distance=1, int siteSize=24)
Definition: dslash_util.h:184
static double norm2(Float *v, int len)
Definition: dslash_util.h:44
static void sub(Float *dst, Float *a, Float *b, int cnt)
Definition: dslash_util.h:14
int fullLatticeIndex_5d_4dpc(int i, int oddBit)
Definition: test_util.cpp:697
static void axmy(Float *x, Float a, Float *y, int len)
Definition: dslash_util.h:39
static unsigned cnt
void exit(int) __attribute__((noreturn))
static void axpby(Float a, Float *x, Float b, Float *y, int len)
Definition: dslash_util.h:33
static Float * spinorNeighbor(int i, int dir, int oddBit, Float *spinorField, int neighbor_distance)
Definition: dslash_util.h:127
size_t size_t offset
int fullLatticeIndex(int dim[4], int index, int oddBit)
Definition: test_util.cpp:442
int Ls
Definition: test_util.cpp:39
#define b
int Vsh_t
Definition: test_util.cpp:31
int printf(const char *,...) __attribute__((__format__(__printf__
#define mySpinorSiteSize
int Z[4]
Definition: test_util.cpp:27
static void ax(Float *dst, Float a, Float *x, int cnt)
Definition: dslash_util.h:20
int neighborIndex_5d(int i, int oddBit, int dxs, int dx4, int dx3, int dx2, int dx1)
Definition: dslash_util.h:158
int neighborIndex_mg(int i, int oddBit, int dx4, int dx3, int dx2, int dx1)
Definition: test_util.cpp:527
static Float * gaugeLink(int i, int dir, int oddBit, Float **gaugeEven, Float **gaugeOdd, int nbr_distance)
Definition: dslash_util.h:104
int fullLatticeIndex_5d(int i, int oddBit)
Definition: test_util.cpp:692
static void su3Mul(sFloat *res, gFloat *mat, sFloat *vec)
Definition: dslash_util.h:80
static void su3Tmul(sFloat *res, gFloat *mat, sFloat *vec)
Definition: dslash_util.h:85
void mat(void *out, void **link, void *in, int dagger_bit, int mu, QudaPrecision sPrecision, QudaPrecision gPrecision)
static __inline__ size_t size_t d
static void axpy(Float a, Float *x, Float *y, int len)
Definition: dslash_util.h:27
#define a
int comm_dim_partitioned(int dim)
static void negx(Float *x, int len)
Definition: dslash_util.h:51
static void su3Transpose(Float *res, Float *mat)
Definition: dslash_util.h:69
static void dot(sFloat *res, gFloat *a, sFloat *b)
Definition: dslash_util.h:56