QUDA  1.0.0
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
dslash_util.h
Go to the documentation of this file.
1 #ifndef _DSLASH_UTIL_H
2 #define _DSLASH_UTIL_H
3 
#include <stdio.h>
#include <stdlib.h>

#include <test_util.h>
#include <comm_quda.h>
6 
/**
 * Element-wise sum of two arrays: dst[k] = a[k] + b[k] for 0 <= k < cnt.
 * Elements are processed in increasing k; each output is written only after
 * its two inputs have been read, so dst may be the same array as a or b.
 */
template <typename Float>
static inline void sum(Float *dst, Float *a, Float *b, int cnt) {
  int k = 0;
  while (k < cnt) {
    dst[k] = a[k] + b[k];
    ++k;
  }
}
12 
/** Element-wise difference of two arrays: dst[k] = a[k] - b[k] for 0 <= k < cnt. */
template <typename Float>
static inline void sub(Float *dst, Float *a, Float *b, int cnt) {
  for (int k = 0; k < cnt; ++k) {
    const Float minuend = a[k];
    const Float subtrahend = b[k];
    dst[k] = minuend - subtrahend;
  }
}
18 
/** Scaled copy: dst[k] = a * x[k] for 0 <= k < cnt (nothing is written when cnt <= 0). */
template <typename Float>
static inline void ax(Float *dst, Float a, Float *x, int cnt) {
  Float *out = dst;
  Float *in = x;
  for (int remaining = cnt; remaining > 0; --remaining) {
    *out++ = a * *in++;
  }
}
24 
/** In-place AXPY: y[k] <- a*x[k] + y[k] for 0 <= k < len. */
template <typename Float>
static inline void axpy(Float a, Float *x, Float *y, int len) {
  int k = 0;
  while (k < len) {
    y[k] = a * x[k] + y[k];
    k++;
  }
}
30 
/** In-place AXPBY: y[k] <- a*x[k] + b*y[k] for 0 <= k < len. */
template <typename Float>
static inline void axpby(Float a, Float *x, Float b, Float *y, int len) {
  for (int k = 0; k < len; k++) {
    const Float yk = y[k];
    y[k] = a * x[k] + b * yk;
  }
}
36 
/**
 * In-place "a x minus y": y[k] <- a*x[k] - y[k] for 0 <= k < len.
 * Note the argument order (x, a, y) differs from axpy/axpby.
 */
template <typename Float>
static inline void axmy(Float *x, Float a, Float *y, int len) {
  for (int k = 0; k < len; ++k) {
    const Float yk = y[k];
    y[k] = a * x[k] - yk;
  }
}
42 
/**
 * Squared Euclidean norm of v[0..len): sum of v[k]*v[k], accumulated in double.
 * Each square is formed in Float arithmetic before being added to the sum.
 * Returns 0.0 when len <= 0.
 */
template <typename Float>
static double norm2(Float *v, int len) {
  double acc = 0.0;
  for (int k = 0; k < len; ++k) {
    const Float vk = v[k];
    acc += vk * vk;
  }
  return acc;
}
49 
/** Negate every element of x[0..len) in place. */
template <typename Float>
static inline void negx(Float *x, int len) {
  Float *p = x;
  for (int remaining = len; remaining > 0; --remaining, ++p) {
    *p = -*p;
  }
}
54 
/**
 * Complex bilinear product of two 3-component complex vectors (no conjugation):
 *   res = sum_c a[c] * b[c],
 * stored as interleaved (re, im) pairs. res[0] is the real part, res[1] the imaginary part.
 * Components of a are converted to sFloat before multiplying.
 */
template <typename sFloat, typename gFloat>
static inline void dot(sFloat* res, gFloat* a, sFloat* b) {
  res[0] = 0;
  res[1] = 0;
  for (int c = 0; c < 3; ++c) {
    const sFloat ar = a[2*c];
    const sFloat ai = a[2*c + 1];
    const sFloat br = b[2*c];
    const sFloat bi = b[2*c + 1];
    // (ar + i ai)(br + i bi) = (ar br - ai bi) + i (ar bi + ai br)
    res[0] += ar * br - ai * bi;
    res[1] += ar * bi + ai * br;
  }
}
67 
/**
 * Hermitian conjugate (conjugate transpose) of a 3x3 complex matrix.
 * Despite the name, the imaginary parts are negated: res(row,col) = conj(mat(col,row)).
 * Matrices are row-major with interleaved (re, im) entries, 18 reals each.
 * res must not alias mat: in-place use would overwrite entries that are still to be read.
 */
template <typename Float>
static inline void su3Transpose(Float *res, Float *mat) {
  for (int row = 0; row < 3; ++row) {
    for (int col = 0; col < 3; ++col) {
      Float *dst = res + 2*(3*row + col);
      const Float *src = mat + 2*(3*col + row);
      dst[0] = +src[0];
      dst[1] = -src[1];
    }
  }
}
77 
78 
/**
 * Matrix-vector product res = mat * vec for a 3x3 complex matrix (row-major,
 * interleaved re/im, 18 reals) and a 3-component complex vector (6 reals).
 * Each output component is the dot() of one matrix row with vec.
 */
template <typename sFloat, typename gFloat>
static inline void su3Mul(sFloat *res, gFloat *mat, sFloat *vec) {
  for (int row = 0; row < 3; ++row) {
    sFloat *out = res + 2*row;
    gFloat *matRow = mat + 6*row;
    dot(out, matRow, vec);
  }
}
83 
/**
 * res = mat^dagger * vec: build the conjugate transpose of mat (see su3Transpose,
 * which also conjugates) into a local 18-real buffer, then apply su3Mul.
 */
template <typename sFloat, typename gFloat>
static inline void su3Tmul(sFloat *res, gFloat *mat, sFloat *vec) {
  gFloat matDagger[18];  // 3 x 3 complex = 18 reals
  su3Transpose(matDagger, mat);
  su3Mul(res, matDagger, vec);
}
90 
91 
92 // i represents a "half index" into an even or odd "half lattice".
93 // when oddBit={0,1} the half lattice is {even,odd}.
94 //
95 // the displacements, such as dx, refer to the full lattice coordinates.
96 //
97 // neighborIndex() takes a "half index", displaces it, and returns the
98 // new "half index", which can be an index into either the even or odd lattices.
99 // displacements of magnitude one always interchange odd and even lattices.
100 //
101 
102 
103 template <typename Float>
104 static inline Float *gaugeLink(int i, int dir, int oddBit, Float **gaugeEven, Float **gaugeOdd, int nbr_distance) {
105  Float **gaugeField;
106  int j;
107  int d = nbr_distance;
108  if (dir % 2 == 0) {
109  j = i;
110  gaugeField = (oddBit ? gaugeOdd : gaugeEven);
111  }
112  else {
113  switch (dir) {
114  case 1: j = neighborIndex(i, oddBit, 0, 0, 0, -d); break;
115  case 3: j = neighborIndex(i, oddBit, 0, 0, -d, 0); break;
116  case 5: j = neighborIndex(i, oddBit, 0, -d, 0, 0); break;
117  case 7: j = neighborIndex(i, oddBit, -d, 0, 0, 0); break;
118  default: j = -1; break;
119  }
120  gaugeField = (oddBit ? gaugeEven : gaugeOdd);
121  }
122 
123  return &gaugeField[dir/2][j*(3*3*2)];
124 }
125 
126 template <typename Float>
127 static inline Float *spinorNeighbor(int i, int dir, int oddBit, Float *spinorField, int neighbor_distance)
128 {
129  int j;
130  int nb = neighbor_distance;
131  switch (dir) {
132  case 0: j = neighborIndex(i, oddBit, 0, 0, 0, +nb); break;
133  case 1: j = neighborIndex(i, oddBit, 0, 0, 0, -nb); break;
134  case 2: j = neighborIndex(i, oddBit, 0, 0, +nb, 0); break;
135  case 3: j = neighborIndex(i, oddBit, 0, 0, -nb, 0); break;
136  case 4: j = neighborIndex(i, oddBit, 0, +nb, 0, 0); break;
137  case 5: j = neighborIndex(i, oddBit, 0, -nb, 0, 0); break;
138  case 6: j = neighborIndex(i, oddBit, +nb, 0, 0, 0); break;
139  case 7: j = neighborIndex(i, oddBit, -nb, 0, 0, 0); break;
140  default: j = -1; break;
141  }
142 
143  return &spinorField[j*(mySpinorSiteSize)];
144 }
145 
146 
147 // i represents a "half index" into an even or odd "half lattice".
148 // when oddBit={0,1} the half lattice is {even,odd}.
149 //
150 // the displacements, such as dx, refer to the full lattice coordinates.
151 //
152 // neighborIndex() takes a "half index", displaces it, and returns the
153 // new "half index", which can be an index into either the even or odd lattices.
154 // displacements of magnitude one always interchange odd and even lattices.
155 //
156 //
157 template <QudaPCType type> int neighborIndex_5d(int i, int oddBit, int dxs, int dx4, int dx3, int dx2, int dx1)
158 {
159  // fullLatticeIndex was modified for fullLatticeIndex_4d. It is in util_quda.cpp.
160  // This code bit may not properly perform 5dPC.
161  int X = type == QUDA_5D_PC ? fullLatticeIndex_5d(i, oddBit) : fullLatticeIndex_5d_4dpc(i, oddBit);
162  // Checked that this matches code in dslash_core_ante.h.
163  int xs = X/(Z[3]*Z[2]*Z[1]*Z[0]);
164  int x4 = (X/(Z[2]*Z[1]*Z[0])) % Z[3];
165  int x3 = (X/(Z[1]*Z[0])) % Z[2];
166  int x2 = (X/Z[0]) % Z[1];
167  int x1 = X % Z[0];
168  // Displace and project back into domain 0,...,Ls-1.
169  // Note that we add Ls to avoid the negative problem
170  // of the C % operator.
171  xs = (xs+dxs+Ls) % Ls;
172  // Etc.
173  x4 = (x4+dx4+Z[3]) % Z[3];
174  x3 = (x3+dx3+Z[2]) % Z[2];
175  x2 = (x2+dx2+Z[1]) % Z[1];
176  x1 = (x1+dx1+Z[0]) % Z[0];
177  // Return linear half index. Remember that integer division
178  // rounds down.
179  return (xs*(Z[3]*Z[2]*Z[1]*Z[0]) + x4*(Z[2]*Z[1]*Z[0]) + x3*(Z[1]*Z[0]) + x2*(Z[0]) + x1) / 2;
180 }
181 
182 template <QudaPCType type, typename Float>
183 Float *spinorNeighbor_5d(int i, int dir, int oddBit, Float *spinorField, int neighbor_distance = 1, int siteSize = 24)
184 {
185  int nb = neighbor_distance;
186  int j;
187  switch (dir) {
188  case 0: j = neighborIndex_5d<type>(i, oddBit, 0, 0, 0, 0, +nb); break;
189  case 1: j = neighborIndex_5d<type>(i, oddBit, 0, 0, 0, 0, -nb); break;
190  case 2: j = neighborIndex_5d<type>(i, oddBit, 0, 0, 0, +nb, 0); break;
191  case 3: j = neighborIndex_5d<type>(i, oddBit, 0, 0, 0, -nb, 0); break;
192  case 4: j = neighborIndex_5d<type>(i, oddBit, 0, 0, +nb, 0, 0); break;
193  case 5: j = neighborIndex_5d<type>(i, oddBit, 0, 0, -nb, 0, 0); break;
194  case 6: j = neighborIndex_5d<type>(i, oddBit, 0, +nb, 0, 0, 0); break;
195  case 7: j = neighborIndex_5d<type>(i, oddBit, 0, -nb, 0, 0, 0); break;
196  case 8: j = neighborIndex_5d<type>(i, oddBit, +nb, 0, 0, 0, 0); break;
197  case 9: j = neighborIndex_5d<type>(i, oddBit, -nb, 0, 0, 0, 0); break;
198  default: j = -1; break;
199  }
200  return &spinorField[j*siteSize];
201 }
202 
203 #ifdef MULTI_GPU
204 
205 inline int x4_mg(int i, int oddBit)
206 {
207  int Y = fullLatticeIndex(i, oddBit);
208  int x4 = Y/(Z[2]*Z[1]*Z[0]);
209  return x4;
210 }
211 
/**
 * Multi-GPU variant of gaugeLink(): return a pointer to the gauge link
 * (3x3 complex = 18 reals) used by the hop in direction `dir` from half-lattice
 * site `i` of parity `oddBit`.
 *
 * Even dir (0, 2, 4, 6): the link stored at site i, in the parity-oddBit field,
 *   per-direction array dir/2.
 * Odd dir (1, 3, 5, 7): the link stored at the site nbr_distance steps back along
 *   x, y, z or t respectively.
 *   - If that step crosses the lower boundary of a partitioned dimension
 *     (comm_dim_partitioned), the pointer is into the backward ghost gauge buffer
 *     of the opposite parity for that dimension.
 *   - Otherwise the coordinate wraps periodically and the local opposite-parity
 *     field (array dir/2) is used.
 * Any other dir prints an error and terminates the process.
 *
 * NOTE(review): the ghost offset term (n_ghost_faces + x - d) is non-negative only
 * when nbr_distance <= n_ghost_faces (at x = 0); presumably callers guarantee this.
 * NOTE(review): from the offset arithmetic, each ghost buffer appears to be ordered by
 * face slice (slowest), then the half-indexed 3-d face site — confirm against the
 * code that fills these buffers.
 */
template <typename Float>
static inline Float *gaugeLink_mg4dir(int i, int dir, int oddBit, Float **gaugeEven, Float **gaugeOdd,
                                      Float** ghostGaugeEven, Float** ghostGaugeOdd, int n_ghost_faces, int nbr_distance) {
  Float **gaugeField;
  int j;
  int d = nbr_distance;
  if (dir % 2 == 0) {
    // Forward hop: the link lives on the site itself.
    j = i;
    gaugeField = (oddBit ? gaugeOdd : gaugeEven);
  }
  else {

    // Full-lattice coordinates of site i (x fastest) and the local extents.
    int Y = fullLatticeIndex(i, oddBit);
    int x4 = Y/(Z[2]*Z[1]*Z[0]);
    int x3 = (Y/(Z[1]*Z[0])) % Z[2];
    int x2 = (Y/Z[0]) % Z[1];
    int x1 = Y % Z[0];
    int X1= Z[0];
    int X2= Z[1];
    int X3= Z[2];
    int X4= Z[3];
    Float* ghostGaugeField;

    switch (dir) {
    case 1:
      { //-X direction
        int new_x1 = (x1 - d + X1 )% X1;
        if (x1 -d < 0 && comm_dim_partitioned(0)){
          // Off-node: link comes from the x backward ghost zone (opposite parity).
          ghostGaugeField = (oddBit?ghostGaugeEven[0]: ghostGaugeOdd[0]);
          int offset = (n_ghost_faces + x1 -d)*X4*X3*X2/2 + (x4*X3*X2 + x3*X2+x2)/2;
          return &ghostGaugeField[offset*(3*3*2)];
        }
        // On-node (or periodic wrap): half index of the displaced site.
        j = (x4*X3*X2*X1 + x3*X2*X1 + x2*X1 + new_x1) / 2;
        break;
      }
    case 3:
      { //-Y direction
        int new_x2 = (x2 - d + X2 )% X2;
        if (x2 -d < 0 && comm_dim_partitioned(1)){
          // Off-node: link comes from the y backward ghost zone (opposite parity).
          ghostGaugeField = (oddBit?ghostGaugeEven[1]: ghostGaugeOdd[1]);
          int offset = (n_ghost_faces + x2 -d)*X4*X3*X1/2 + (x4*X3*X1 + x3*X1+x1)/2;
          return &ghostGaugeField[offset*(3*3*2)];
        }
        j = (x4*X3*X2*X1 + x3*X2*X1 + new_x2*X1 + x1) / 2;
        break;

      }
    case 5:
      { //-Z direction
        int new_x3 = (x3 - d + X3 )% X3;
        if (x3 -d < 0 && comm_dim_partitioned(2)){
          // Off-node: link comes from the z backward ghost zone (opposite parity).
          ghostGaugeField = (oddBit?ghostGaugeEven[2]: ghostGaugeOdd[2]);
          int offset = (n_ghost_faces + x3 -d)*X4*X2*X1/2 + (x4*X2*X1 + x2*X1+x1)/2;
          return &ghostGaugeField[offset*(3*3*2)];
        }
        j = (x4*X3*X2*X1 + new_x3*X2*X1 + x2*X1 + x1) / 2;
        break;
      }
    case 7:
      { //-T direction
        int new_x4 = (x4 - d + X4)% X4;
        if (x4 -d < 0 && comm_dim_partitioned(3)){
          // Off-node: link comes from the t backward ghost zone (opposite parity).
          ghostGaugeField = (oddBit?ghostGaugeEven[3]: ghostGaugeOdd[3]);
          int offset = (n_ghost_faces + x4 -d)*X1*X2*X3/2 + (x3*X2*X1 + x2*X1+x1)/2;
          return &ghostGaugeField[offset*(3*3*2)];
        }
        j = (new_x4*(X3*X2*X1) + x3*(X2*X1) + x2*(X1) + x1) / 2;
        break;
      }//7

    default: j = -1; printf("ERROR: wrong dir \n"); exit(1);
    }
    // Backward hop on-node: the neighbouring site has the opposite parity.
    gaugeField = (oddBit ? gaugeEven : gaugeOdd);

  }

  return &gaugeField[dir/2][j*(3*3*2)];
}
290 
/**
 * Multi-GPU variant of spinorNeighbor(): return a pointer to the spinor at the
 * neighbour of half-lattice site `i` (parity `oddBit`), displaced by
 * `neighbor_distance` along the axis and sign selected by `dir`:
 *   0/1 -> +x/-x, 2/3 -> +y/-y, 4/5 -> +z/-z, 6/7 -> +t/-t.
 *
 * If the hop leaves the local lattice in a partitioned dimension (comm_dim_partitioned),
 * the pointer is into that dimension's ghost buffer:
 *   - forward hops use fwd_nbr_spinor[dim], slice (x + nb - X), so slice 0 is the one
 *     just past the upper boundary;
 *   - backward hops use back_nbr_spinor[dim], slice (x + nFace - nb), so slice nFace-1
 *     is the one just below the lower boundary.
 * Otherwise the coordinate wraps periodically and the pointer is into `spinorField`.
 * Site stride is mySpinorSiteSize reals. Any other dir prints an error and exits.
 *
 * NOTE(review): for +/-T the index comes from neighborIndex_mg (test_util.cpp). When
 * the hop leaves a partitioned T dimension, it is presumably the half index within
 * one t-slice, to which the slice offset (in units of Vsh_t) is added — confirm.
 */
template <typename Float>
static inline Float *spinorNeighbor_mg4dir(int i, int dir, int oddBit, Float *spinorField, Float** fwd_nbr_spinor,
                                           Float** back_nbr_spinor, int neighbor_distance, int nFace)
{
  int j;
  int nb = neighbor_distance;
  // Full-lattice coordinates of site i (x fastest) and the local extents.
  int Y = fullLatticeIndex(i, oddBit);
  int x4 = Y/(Z[2]*Z[1]*Z[0]);
  int x3 = (Y/(Z[1]*Z[0])) % Z[2];
  int x2 = (Y/Z[0]) % Z[1];
  int x1 = Y % Z[0];
  int X1= Z[0];
  int X2= Z[1];
  int X3= Z[2];
  int X4= Z[3];

  switch (dir) {
  case 0://+X
    {
      int new_x1 = (x1 + nb)% X1;
      if(x1+nb >=X1 && comm_dim_partitioned(0) ){
        // Off-node: slice (x1+nb-X1) of the forward x ghost zone, half-indexed 3-d face site.
        int offset = ( x1 + nb -X1)*X4*X3*X2/2+(x4*X3*X2 + x3*X2+x2)/2;
        return fwd_nbr_spinor[0] + offset*mySpinorSiteSize;
      }
      j = (x4*X3*X2*X1 + x3*X2*X1 + x2*X1 + new_x1) / 2;
      break;
    }
  case 1://-X
    {
      int new_x1 = (x1 - nb + X1)% X1;
      if(x1 - nb < 0 && comm_dim_partitioned(0)){
        // Off-node: slice (x1+nFace-nb) of the backward x ghost zone.
        int offset = ( x1+nFace- nb)*X4*X3*X2/2+(x4*X3*X2 + x3*X2+x2)/2;
        return back_nbr_spinor[0] + offset*mySpinorSiteSize;
      }
      j = (x4*X3*X2*X1 + x3*X2*X1 + x2*X1 + new_x1) / 2;
      break;
    }
  case 2://+Y
    {
      int new_x2 = (x2 + nb)% X2;
      if(x2+nb >=X2 && comm_dim_partitioned(1)){
        int offset = ( x2 + nb -X2)*X4*X3*X1/2+(x4*X3*X1 + x3*X1+x1)/2;
        return fwd_nbr_spinor[1] + offset*mySpinorSiteSize;
      }
      j = (x4*X3*X2*X1 + x3*X2*X1 + new_x2*X1 + x1) / 2;
      break;
    }
  case 3:// -Y
    {
      int new_x2 = (x2 - nb + X2)% X2;
      if(x2 - nb < 0 && comm_dim_partitioned(1)){
        int offset = ( x2 + nFace -nb)*X4*X3*X1/2+(x4*X3*X1 + x3*X1+x1)/2;
        return back_nbr_spinor[1] + offset*mySpinorSiteSize;
      }
      j = (x4*X3*X2*X1 + x3*X2*X1 + new_x2*X1 + x1) / 2;
      break;
    }
  case 4://+Z
    {
      int new_x3 = (x3 + nb)% X3;
      if(x3+nb >=X3 && comm_dim_partitioned(2)){
        int offset = ( x3 + nb -X3)*X4*X2*X1/2+(x4*X2*X1 + x2*X1+x1)/2;
        return fwd_nbr_spinor[2] + offset*mySpinorSiteSize;
      }
      j = (x4*X3*X2*X1 + new_x3*X2*X1 + x2*X1 + x1) / 2;
      break;
    }
  case 5://-Z
    {
      int new_x3 = (x3 - nb + X3)% X3;
      if(x3 - nb < 0 && comm_dim_partitioned(2)){
        int offset = ( x3 + nFace -nb)*X4*X2*X1/2+(x4*X2*X1 + x2*X1+x1)/2;
        return back_nbr_spinor[2] + offset*mySpinorSiteSize;
      }
      j = (x4*X3*X2*X1 + new_x3*X2*X1 + x2*X1 + x1) / 2;
      break;
    }
  case 6://+T
    {
      j = neighborIndex_mg(i, oddBit, +nb, 0, 0, 0);
      // Recomputes (and shadows) the outer x4 with the same value.
      int x4 = x4_mg(i, oddBit);
      if ( (x4 + nb) >= Z[3] && comm_dim_partitioned(3)){
        // Off-node: t-slice (x4+nb-Z[3]) of the forward ghost zone, Vsh_t sites per slice.
        int offset = (x4+nb - Z[3])*Vsh_t;
        return &fwd_nbr_spinor[3][(offset+j)*mySpinorSiteSize];
      }
      break;
    }
  case 7://-T
    {
      j = neighborIndex_mg(i, oddBit, -nb, 0, 0, 0);
      // Recomputes (and shadows) the outer x4 with the same value.
      int x4 = x4_mg(i, oddBit);
      if ( (x4 - nb) < 0 && comm_dim_partitioned(3)){
        // Off-node: t-slice (x4-nb+nFace) of the backward ghost zone, Vsh_t sites per slice.
        int offset = ( x4 - nb +nFace)*Vsh_t;
        return &back_nbr_spinor[3][(offset+j)*mySpinorSiteSize];
      }
      break;
    }
  default: j = -1; printf("ERROR: wrong dir\n"); exit(1);
  }

  return &spinorField[j*(mySpinorSiteSize)];
}
393 
394 template <QudaPCType type> int neighborIndex_5d_mgpu(int i, int oddBit, int dxs, int dx4, int dx3, int dx2, int dx1)
395 {
396  int ret;
397 
398  int Y = (type == QUDA_5D_PC) ? fullLatticeIndex_5d(i, oddBit) : fullLatticeIndex_5d_4dpc(i, oddBit);
399 
400  int xs = Y/(Z[3]*Z[2]*Z[1]*Z[0]);
401  int x4 = (Y/(Z[2]*Z[1]*Z[0])) % Z[3];
402  int x3 = (Y/(Z[1]*Z[0])) % Z[2];
403  int x2 = (Y/Z[0]) % Z[1];
404  int x1 = Y % Z[0];
405  int ghost_x4 = x4+ dx4;
406 
407  xs = (xs+dxs+Ls) % Ls;
408  x4 = (x4+dx4+Z[3]) % Z[3];
409  x3 = (x3+dx3+Z[2]) % Z[2];
410  x2 = (x2+dx2+Z[1]) % Z[1];
411  x1 = (x1+dx1+Z[0]) % Z[0];
412 
413  if ( (ghost_x4 >= 0 && ghost_x4) < Z[3] || !comm_dim_partitioned(3)){
414  ret = (xs*Z[3]*Z[2]*Z[1]*Z[0] + x4*Z[2]*Z[1]*Z[0] + x3*Z[1]*Z[0] + x2*Z[0] + x1) >> 1;
415  }else{
416  ret = (xs*Z[2]*Z[1]*Z[0] + x3*Z[1]*Z[0] + x2*Z[0] + x1) >> 1;
417  }
418 
419  return ret;
420 }
421 
422 template <QudaPCType type> int x4_5d_mgpu(int i, int oddBit)
423 {
424  int Y = (type == QUDA_5D_PC) ? fullLatticeIndex_5d(i, oddBit) : fullLatticeIndex_5d_4dpc(i, oddBit);
425  return (Y/(Z[2]*Z[1]*Z[0])) % Z[3];
426 }
427 
/**
 * Multi-GPU variant of spinorNeighbor_5d(): return a pointer to the spinor at the
 * neighbour of 5-d half-lattice site `i` (parity `oddBit`), displaced by
 * `neighbor_distance` along the axis and sign selected by `dir`:
 *   0/1 -> +x/-x, 2/3 -> +y/-y, 4/5 -> +z/-z, 6/7 -> +t/-t.
 * Fifth-dimension hops (dir 8/9) are not handled here. Like any other dir, they
 * print an error and exit.
 *
 * If the hop leaves the local lattice in a partitioned dimension, the pointer is into
 * that dimension's ghost buffer. Forward hops use fwd_nbr_spinor[dim], slice (x + nb - X);
 * backward hops use back_nbr_spinor[dim], slice (x + nFace - nb). Each ghost slice holds
 * Ls copies of the 3-d face, and the combined index is halved.
 * Otherwise the coordinate wraps periodically and the pointer is into `spinorField`.
 * `spinorSize` is the number of Float values per site (default 24).
 */
template <QudaPCType type, typename Float>
Float *spinorNeighbor_5d_mgpu(int i, int dir, int oddBit, Float *spinorField, Float **fwd_nbr_spinor,
                              Float **back_nbr_spinor, int neighbor_distance, int nFace, int spinorSize = 24)
{
  int j;
  int nb = neighbor_distance;
  int Y = (type == QUDA_5D_PC) ? fullLatticeIndex_5d(i, oddBit) : fullLatticeIndex_5d_4dpc(i, oddBit);

  // Full 5-d coordinates of site i (x fastest, s slowest).
  int xs = Y/(Z[3]*Z[2]*Z[1]*Z[0]);
  int x4 = (Y/(Z[2]*Z[1]*Z[0])) % Z[3];
  int x3 = (Y/(Z[1]*Z[0])) % Z[2];
  int x2 = (Y/Z[0]) % Z[1];
  int x1 = Y % Z[0];

  int X1= Z[0];
  int X2= Z[1];
  int X3= Z[2];
  int X4= Z[3];
  switch (dir) {
  case 0://+X
    {
      int new_x1 = (x1 + nb)% X1;
      if(x1+nb >=X1 && comm_dim_partitioned(0)) {
        // Off-node: ghost slice (x1+nb-X1), then (s, t, z, y) within the face; halved.
        int offset = ((x1 + nb -X1)*Ls*X4*X3*X2+xs*X4*X3*X2+x4*X3*X2 + x3*X2+x2) >> 1;
        return fwd_nbr_spinor[0] + offset*spinorSize;
      }
      j = (xs*X4*X3*X2*X1 + x4*X3*X2*X1 + x3*X2*X1 + x2*X1 + new_x1) >> 1;
      break;
    }
  case 1://-X
    {
      int new_x1 = (x1 - nb + X1)% X1;
      if(x1 - nb < 0 && comm_dim_partitioned(0)) {
        // Off-node: backward ghost slice (x1+nFace-nb).
        int offset = (( x1+nFace- nb)*Ls*X4*X3*X2 + xs*X4*X3*X2 + x4*X3*X2 + x3*X2 + x2) >> 1;
        return back_nbr_spinor[0] + offset*spinorSize;
      }
      j = (xs*X4*X3*X2*X1 + x4*X3*X2*X1 + x3*X2*X1 + x2*X1 + new_x1) >> 1;
      break;
    }
  case 2://+Y
    {
      int new_x2 = (x2 + nb)% X2;
      if(x2+nb >=X2 && comm_dim_partitioned(1)) {
        int offset = (( x2 + nb -X2)*Ls*X4*X3*X1+xs*X4*X3*X1+x4*X3*X1 + x3*X1+x1) >> 1;
        return fwd_nbr_spinor[1] + offset*spinorSize;
      }
      j = (xs*X4*X3*X2*X1 + x4*X3*X2*X1 + x3*X2*X1 + new_x2*X1 + x1) >> 1;
      break;
    }
  case 3:// -Y
    {
      int new_x2 = (x2 - nb + X2)% X2;
      if(x2 - nb < 0 && comm_dim_partitioned(1)) {
        int offset = (( x2 + nFace -nb)*Ls*X4*X3*X1+xs*X4*X3*X1+ x4*X3*X1 + x3*X1+x1) >> 1;
        return back_nbr_spinor[1] + offset*spinorSize;
      }
      j = (xs*X4*X3*X2*X1 + x4*X3*X2*X1 + x3*X2*X1 + new_x2*X1 + x1) >> 1;
      break;
    }
  case 4://+Z
    {
      int new_x3 = (x3 + nb)% X3;
      if(x3+nb >=X3 && comm_dim_partitioned(2)) {
        int offset = (( x3 + nb -X3)*Ls*X4*X2*X1+xs*X4*X2*X1+x4*X2*X1 + x2*X1+x1) >> 1;
        return fwd_nbr_spinor[2] + offset*spinorSize;
      }
      j = (xs*X4*X3*X2*X1 + x4*X3*X2*X1 + new_x3*X2*X1 + x2*X1 + x1) >> 1;
      break;
    }
  case 5://-Z
    {
      int new_x3 = (x3 - nb + X3)% X3;
      if(x3 - nb < 0 && comm_dim_partitioned(2)){
        int offset = (( x3 + nFace -nb)*Ls*X4*X2*X1+xs*X4*X2*X1+x4*X2*X1+x2*X1+x1) >> 1;
        return back_nbr_spinor[2] + offset*spinorSize;
      }
      j = (xs*X4*X3*X2*X1 + x4*X3*X2*X1 + new_x3*X2*X1 + x2*X1 + x1) >> 1;
      break;
    }
  case 6://+T
    {
      // Recomputes (and shadows) the outer x4 with the same value.
      int x4 = x4_5d_mgpu<type>(i, oddBit);
      if ( (x4 + nb) >= Z[3] && comm_dim_partitioned(3)) {
        int offset = ((x4 + nb - Z[3])*Ls*X3*X2*X1+xs*X3*X2*X1+x3*X2*X1+x2*X1+x1) >> 1;
        return fwd_nbr_spinor[3] + offset*spinorSize;
      }
      // Reached only when the hop stays on-node or T is not partitioned.
      j = neighborIndex_5d_mgpu<type>(i, oddBit, 0, +nb, 0, 0, 0);
      break;
    }
  case 7://-T
    {
      // Recomputes (and shadows) the outer x4 with the same value.
      int x4 = x4_5d_mgpu<type>(i, oddBit);
      if ( (x4 - nb) < 0 && comm_dim_partitioned(3)) {
        int offset = (( x4 - nb +nFace)*Ls*X3*X2*X1+xs*X3*X2*X1+x3*X2*X1+x2*X1+x1) >> 1;
        return back_nbr_spinor[3] + offset*spinorSize;
      }
      // Reached only when the hop stays on-node or T is not partitioned.
      j = neighborIndex_5d_mgpu<type>(i, oddBit, 0, -nb, 0, 0, 0);
      break;
    }
  default: j = -1; printf("ERROR: wrong dir\n"); exit(1);
  }

  return &spinorField[j*(spinorSize)];
}
532 
533 
534 #endif // MULTI_GPU
535 
536 #endif // _DSLASH_UTIL_H
__device__ __forceinline__ int neighborIndex(const unsigned int &cb_idx, const int(&shift)[4], const bool(&partitioned)[4], const unsigned int &parity)
static void sum(Float *dst, Float *a, Float *b, int cnt)
Definition: dslash_util.h:8
int Z[4]
Definition: test_util.cpp:26
Float * spinorNeighbor_5d(int i, int dir, int oddBit, Float *spinorField, int neighbor_distance=1, int siteSize=24)
Definition: dslash_util.h:183
static double norm2(Float *v, int len)
Definition: dslash_util.h:44
static int X2
Definition: face_gauge.cpp:42
static void sub(Float *dst, Float *a, Float *b, int cnt)
Definition: dslash_util.h:14
int fullLatticeIndex_5d_4dpc(int i, int oddBit)
Definition: test_util.cpp:687
static void axmy(Float *x, Float a, Float *y, int len)
Definition: dslash_util.h:39
static void axpby(Float a, Float *x, Float b, Float *y, int len)
Definition: dslash_util.h:33
static Float * spinorNeighbor(int i, int dir, int oddBit, Float *spinorField, int neighbor_distance)
Definition: dslash_util.h:127
int fullLatticeIndex(int dim[4], int index, int oddBit)
Definition: test_util.cpp:439
int Ls
Definition: test_util.cpp:38
int Vsh_t
Definition: test_util.cpp:30
int X[4]
Definition: covdev_test.cpp:70
static void ax(Float *dst, Float a, Float *x, int cnt)
Definition: dslash_util.h:20
static int X3
Definition: face_gauge.cpp:42
static int X1
Definition: face_gauge.cpp:42
#define mySpinorSiteSize
int neighborIndex_5d(int i, int oddBit, int dxs, int dx4, int dx3, int dx2, int dx1)
Definition: dslash_util.h:157
int neighborIndex_mg(int i, int oddBit, int dx4, int dx3, int dx2, int dx1)
Definition: test_util.cpp:523
static Float * gaugeLink(int i, int dir, int oddBit, Float **gaugeEven, Float **gaugeOdd, int nbr_distance)
Definition: dslash_util.h:104
int fullLatticeIndex_5d(int i, int oddBit)
Definition: test_util.cpp:682
static void su3Mul(sFloat *res, gFloat *mat, sFloat *vec)
Definition: dslash_util.h:80
static void su3Tmul(sFloat *res, gFloat *mat, sFloat *vec)
Definition: dslash_util.h:85
void mat(void *out, void **link, void *in, int dagger_bit, int mu, QudaPrecision sPrecision, QudaPrecision gPrecision)
static void axpy(Float a, Float *x, Float *y, int len)
Definition: dslash_util.h:27
int comm_dim_partitioned(int dim)
static void negx(Float *x, int len)
Definition: dslash_util.h:51
static int X4
Definition: face_gauge.cpp:42
static void su3Transpose(Float *res, Float *mat)
Definition: dslash_util.h:69
static void dot(sFloat *res, gFloat *a, sFloat *b)
Definition: dslash_util.h:56