QUDA  1.0.0
domain_wall_dslash_reference.cpp
1 #include <iostream>
2 #include <stdio.h>
3 #include <stdlib.h>
4 #include <string.h>
5 #include <math.h>
6 #include <complex.h>
7 
8 #include <quda.h>
9 #include <test_util.h>
10 #include <dslash_util.h>
12 #include <blas_reference.h>
13 
14 #include <gauge_field.h>
15 #include <color_spinor_field.h>
16 
17 using namespace quda;
18 
19 // i represents a "half index" into an even or odd "half lattice".
20 // when oddBit={0,1} the half lattice is {even,odd}.
21 //
22 // the displacements, such as dx, refer to the full lattice coordinates.
23 //
24 // neighborIndex() takes a "half index", displaces it, and returns the
25 // new "half index", which can be an index into either the even or odd lattices.
26 // displacements of magnitude one always interchange odd and even lattices.
27 //
28 //
29 int neighborIndex_4d(int i, int oddBit, int dx4, int dx3, int dx2, int dx1) {
30  // On input i should be in the range [0 , ... , Z[0]*Z[1]*Z[2]*Z[3]/2-1].
31  if (i < 0 || i >= (Z[0]*Z[1]*Z[2]*Z[3]/2))
32  { printf("i out of range in neighborIndex_4d\n"); exit(-1); }
33  // Compute the linear index. Then dissect.
34  // fullLatticeIndex_4d is in test_util.cpp.
35  // The gauge fields live on a 4d sublattice.
36  int X = fullLatticeIndex_4d(i, oddBit);
37  int x4 = X/(Z[2]*Z[1]*Z[0]);
38  int x3 = (X/(Z[1]*Z[0])) % Z[2];
39  int x2 = (X/Z[0]) % Z[1];
40  int x1 = X % Z[0];
41 
42  x4 = (x4+dx4+Z[3]) % Z[3];
43  x3 = (x3+dx3+Z[2]) % Z[2];
44  x2 = (x2+dx2+Z[1]) % Z[1];
45  x1 = (x1+dx1+Z[0]) % Z[0];
46 
47  return (x4*(Z[2]*Z[1]*Z[0]) + x3*(Z[1]*Z[0]) + x2*(Z[0]) + x1) / 2;
48 }
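// Editorial sketch (not part of the original file): a worked example of the
// half-index round trip above, assuming a global lattice Z = {4,4,4,4}.
// Half index i = 0 on the even sublattice is site (0,0,0,0); a -x1 hop wraps
// to the odd site (3,0,0,0), whose half index is (0*64 + 0*16 + 0*4 + 3)/2 = 1.
static void neighborIndexExample_sketch() {
  int j = neighborIndex_4d(0, /*oddBit=*/0, 0, 0, 0, -1);
  printf("-x1 neighbor of even half-index 0 is odd half-index %d (expect 1)\n", j);
}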
49 
50 
51 
52 //#ifndef MULTI_GPU
53 // This is just a copy of gaugeLink() from the quda code, except
54 // that neighborIndex() is replaced by the renamed version
55 // neighborIndex_4d().
56 //ok
57 template <typename Float>
58 Float *gaugeLink_sgpu(int i, int dir, int oddBit, Float **gaugeEven,
59  Float **gaugeOdd) {
60  Float **gaugeField;
61  int j;
62 
63  // If going forward, just grab link at site, U_\mu(x).
64  if (dir % 2 == 0) {
65  j = i;
66  // j will get used in the return statement below.
67  gaugeField = (oddBit ? gaugeOdd : gaugeEven);
68  } else {
69  // If going backward, a shift must occur, U_\mu(x-\muhat)^\dagger;
70  // dagger happens elsewhere, here we're just doing index gymnastics.
71  switch (dir) {
72  case 1: j = neighborIndex_4d(i, oddBit, 0, 0, 0, -1); break;
73  case 3: j = neighborIndex_4d(i, oddBit, 0, 0, -1, 0); break;
74  case 5: j = neighborIndex_4d(i, oddBit, 0, -1, 0, 0); break;
75  case 7: j = neighborIndex_4d(i, oddBit, -1, 0, 0, 0); break;
76  default: j = -1; break;
77  }
78  gaugeField = (oddBit ? gaugeEven : gaugeOdd);
79  }
80 
81  return &gaugeField[dir/2][j*(3*3*2)];
82 }
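// Editorial note (added): the "dir" argument packs direction mu = dir/2 and
// orientation dir%2 (0 = forward, 1 = backward), so dir = 0,2,4,6 fetch
// U_mu(x) and dir = 1,3,5,7 fetch U_mu(x - muhat), whose dagger is taken by
// the caller. A hedged usage sketch, using only names defined above:
static void gaugeLinkExample_sketch(double **gaugeEven, double **gaugeOdd) {
  // U_{x2}(x) at even half-index 0, and U_{x2}(x - x2hat) from the odd array:
  double *u_fwd = gaugeLink_sgpu<double>(0, 2, 0, gaugeEven, gaugeOdd);
  double *u_bwd = gaugeLink_sgpu<double>(0, 3, 0, gaugeEven, gaugeOdd);
  printf("each link is 3*3*2 reals: %p %p\n", (void *)u_fwd, (void *)u_bwd);
}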
83 
84 
85 //#else
86 
87 //Standard 4d version (nothing to change)
88 template <typename Float>
89 Float *gaugeLink_mgpu(int i, int dir, int oddBit, Float **gaugeEven, Float **gaugeOdd, Float** ghostGaugeEven, Float** ghostGaugeOdd, int n_ghost_faces, int nbr_distance) {
90  Float **gaugeField;
91  int j;
92  int d = nbr_distance;
93  if (dir % 2 == 0) {
94  j = i;
95  gaugeField = (oddBit ? gaugeOdd : gaugeEven);
96  }
97  else {
98 
99  int Y = fullLatticeIndex(i, oddBit);
100  int x4 = Y/(Z[2]*Z[1]*Z[0]);
101  int x3 = (Y/(Z[1]*Z[0])) % Z[2];
102  int x2 = (Y/Z[0]) % Z[1];
103  int x1 = Y % Z[0];
104  int X1= Z[0];
105  int X2= Z[1];
106  int X3= Z[2];
107  int X4= Z[3];
108  Float* ghostGaugeField;
109 
110  switch (dir) {
111  case 1:
112  { //-X direction
113  int new_x1 = (x1 - d + X1 )% X1;
114  if (x1 -d < 0 && comm_dim_partitioned(0)){
115  ghostGaugeField = (oddBit?ghostGaugeEven[0]: ghostGaugeOdd[0]);
116  int offset = (n_ghost_faces + x1 -d)*X4*X3*X2/2 + (x4*X3*X2 + x3*X2+x2)/2;
117  return &ghostGaugeField[offset*(3*3*2)];
118  }
119  j = (x4*X3*X2*X1 + x3*X2*X1 + x2*X1 + new_x1) / 2;
120  break;
121  }
122  case 3:
123  { //-Y direction
124  int new_x2 = (x2 - d + X2 )% X2;
125  if (x2 -d < 0 && comm_dim_partitioned(1)){
126  ghostGaugeField = (oddBit?ghostGaugeEven[1]: ghostGaugeOdd[1]);
127  int offset = (n_ghost_faces + x2 -d)*X4*X3*X1/2 + (x4*X3*X1 + x3*X1+x1)/2;
128  return &ghostGaugeField[offset*(3*3*2)];
129  }
130  j = (x4*X3*X2*X1 + x3*X2*X1 + new_x2*X1 + x1) / 2;
131  break;
132 
133  }
134  case 5:
135  { //-Z direction
136  int new_x3 = (x3 - d + X3 )% X3;
137  if (x3 -d < 0 && comm_dim_partitioned(2)){
138  ghostGaugeField = (oddBit?ghostGaugeEven[2]: ghostGaugeOdd[2]);
139  int offset = (n_ghost_faces + x3 -d)*X4*X2*X1/2 + (x4*X2*X1 + x2*X1+x1)/2;
140  return &ghostGaugeField[offset*(3*3*2)];
141  }
142  j = (x4*X3*X2*X1 + new_x3*X2*X1 + x2*X1 + x1) / 2;
143  break;
144  }
145  case 7:
146  { //-T direction
147  int new_x4 = (x4 - d + X4)% X4;
148  if (x4 -d < 0 && comm_dim_partitioned(3)){
149  ghostGaugeField = (oddBit?ghostGaugeEven[3]: ghostGaugeOdd[3]);
150  int offset = (n_ghost_faces + x4 -d)*X1*X2*X3/2 + (x3*X2*X1 + x2*X1+x1)/2;
151  return &ghostGaugeField[offset*(3*3*2)];
152  }
153  j = (new_x4*(X3*X2*X1) + x3*(X2*X1) + x2*(X1) + x1) / 2;
154  break;
155  }//7
156 
157  default: j = -1; printf("ERROR: wrong dir \n"); exit(1);
158  }
159  gaugeField = (oddBit ? gaugeEven : gaugeOdd);
160 
161  }
162 
163  return &gaugeField[dir/2][j*(3*3*2)];
164 }
165 
166 
167 //J Directions 0..7 were used in the 4d code.
168 //J Directions 8,9 will be for P_+ and P_-, chiral
169 //J projectors.
170 const double projector[10][4][4][2] = {
171  {
172  {{1,0}, {0,0}, {0,0}, {0,-1}},
173  {{0,0}, {1,0}, {0,-1}, {0,0}},
174  {{0,0}, {0,1}, {1,0}, {0,0}},
175  {{0,1}, {0,0}, {0,0}, {1,0}}
176  },
177  {
178  {{1,0}, {0,0}, {0,0}, {0,1}},
179  {{0,0}, {1,0}, {0,1}, {0,0}},
180  {{0,0}, {0,-1}, {1,0}, {0,0}},
181  {{0,-1}, {0,0}, {0,0}, {1,0}}
182  },
183  {
184  {{1,0}, {0,0}, {0,0}, {1,0}},
185  {{0,0}, {1,0}, {-1,0}, {0,0}},
186  {{0,0}, {-1,0}, {1,0}, {0,0}},
187  {{1,0}, {0,0}, {0,0}, {1,0}}
188  },
189  {
190  {{1,0}, {0,0}, {0,0}, {-1,0}},
191  {{0,0}, {1,0}, {1,0}, {0,0}},
192  {{0,0}, {1,0}, {1,0}, {0,0}},
193  {{-1,0}, {0,0}, {0,0}, {1,0}}
194  },
195  {
196  {{1,0}, {0,0}, {0,-1}, {0,0}},
197  {{0,0}, {1,0}, {0,0}, {0,1}},
198  {{0,1}, {0,0}, {1,0}, {0,0}},
199  {{0,0}, {0,-1}, {0,0}, {1,0}}
200  },
201  {
202  {{1,0}, {0,0}, {0,1}, {0,0}},
203  {{0,0}, {1,0}, {0,0}, {0,-1}},
204  {{0,-1}, {0,0}, {1,0}, {0,0}},
205  {{0,0}, {0,1}, {0,0}, {1,0}}
206  },
207  {
208  {{1,0}, {0,0}, {-1,0}, {0,0}},
209  {{0,0}, {1,0}, {0,0}, {-1,0}},
210  {{-1,0}, {0,0}, {1,0}, {0,0}},
211  {{0,0}, {-1,0}, {0,0}, {1,0}}
212  },
213  {
214  {{1,0}, {0,0}, {1,0}, {0,0}},
215  {{0,0}, {1,0}, {0,0}, {1,0}},
216  {{1,0}, {0,0}, {1,0}, {0,0}},
217  {{0,0}, {1,0}, {0,0}, {1,0}}
218  },
219  // P_+ = P_R
220  {
221  {{0,0}, {0,0}, {0,0}, {0,0}},
222  {{0,0}, {0,0}, {0,0}, {0,0}},
223  {{0,0}, {0,0}, {2,0}, {0,0}},
224  {{0,0}, {0,0}, {0,0}, {2,0}}
225  },
226  // P_- = P_L
227  {
228  {{2,0}, {0,0}, {0,0}, {0,0}},
229  {{0,0}, {2,0}, {0,0}, {0,0}},
230  {{0,0}, {0,0}, {0,0}, {0,0}},
231  {{0,0}, {0,0}, {0,0}, {0,0}}
232  }
233 };
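// Editorial note (added): in the DeGrand-Rossi gamma basis QUDA uses, entry
// 2*mu of this table is the forward-hop projector 1 - gamma_mu and entry
// 2*mu + 1 the backward-hop projector 1 + gamma_mu (e.g. entry 0 is
// 1 - gamma_1). Entries 8 and 9 are twice the chiral projectors, 2*P_+ and
// 2*P_-, used for the 5th-dimension hops, normalized with the same factor of
// 2 as entries 0..7.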
234 
235 
236 // todo pass projector
237 template <typename Float>
238 void multiplySpinorByDiracProjector5(Float *res, int projIdx, Float *spinorIn) {
239  for (int i=0; i<4*3*2; i++) res[i] = 0.0;
240 
241  for (int s = 0; s < 4; s++) {
242  for (int t = 0; t < 4; t++) {
243  Float projRe = projector[projIdx][s][t][0];
244  Float projIm = projector[projIdx][s][t][1];
245 
246  for (int m = 0; m < 3; m++) {
247  Float spinorRe = spinorIn[t*(3*2) + m*(2) + 0];
248  Float spinorIm = spinorIn[t*(3*2) + m*(2) + 1];
249  res[s*(3*2) + m*(2) + 0] += projRe*spinorRe - projIm*spinorIm;
250  res[s*(3*2) + m*(2) + 1] += projRe*spinorIm + projIm*spinorRe;
251  }
252  }
253  }
254 }
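// Editorial sketch (added): a self-contained check of the routine above.
// Applying entry 8 (2*P_+) to a spinor of all ones must zero spin components
// 0 and 1 and double components 2 and 3, per the table entries at [2][2] and
// [3][3].
static void projectorExample_sketch() {
  double in[4 * 3 * 2], out[4 * 3 * 2];
  for (int k = 0; k < 4 * 3 * 2; k++) in[k] = 1.0;
  multiplySpinorByDiracProjector5(out, 8, in);
  for (int k = 0; k < 12; k++)
    if (out[k] != 0.0) printf("upper spin block not projected out\n");
  for (int k = 12; k < 24; k++)
    if (out[k] != 2.0) printf("lower spin block not doubled\n");
}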
255 
256 
257 //#ifndef MULTI_GPU
258 // dslashReference_4d()
259 //J This is just the 4d wilson dslash of quda code, with a
260 //J few small changes to take into account that the spinors
261 //J are 5d and the gauge fields are 4d.
262 //
263 // if oddBit is zero: calculate odd parity spinor elements (using even parity spinor)
264 // if oddBit is one: calculate even parity spinor elements
265 //
266 // if daggerBit is zero: perform ordinary dslash operator
267 // if daggerBit is one: perform hermitian conjugate of dslash
268 //
269 //An "ok" will only be granted once check2.tex is deemed complete,
270 //since the logic in this function is important and nontrivial.
271 template <QudaPCType type, typename sFloat, typename gFloat>
272 void dslashReference_4d_sgpu(sFloat *res, gFloat **gaugeFull, sFloat *spinorField, int oddBit, int daggerBit)
273 {
274 
275  // Initialize the return half-spinor to zero. Note that it is a
276  // 5d spinor, hence the use of V5h.
277  for (int i=0; i<V5h*4*3*2; i++) res[i] = 0.0;
278 
279  // Some pointers that we use to march through arrays.
280  gFloat *gaugeEven[4], *gaugeOdd[4];
281  // Initialize to beginning of even and odd parts of
282  // gauge array.
283  for (int dir = 0; dir < 4; dir++) {
284  gaugeEven[dir] = gaugeFull[dir];
285  // Note the use of Vh here, since the gauge fields
286  // are 4-dim'l.
287  gaugeOdd[dir] = gaugeFull[dir]+Vh*gaugeSiteSize;
288  }
289  int sp_idx,gaugeOddBit;
290  for (int xs=0;xs<Ls;xs++) {
291  for (int gge_idx = 0; gge_idx < Vh; gge_idx++) {
292  for (int dir = 0; dir < 8; dir++) {
293  sp_idx=gge_idx+Vh*xs;
294  // Here is a function call to study. It is defined near
295  // Line 58 of this file.
296  // Here we have to switch oddBit depending on the value of xs. E.g., suppose
297  // xs=1. Then the odd spinor site x1=x2=x3=x4=0 wants the even gauge array
298  // element 0, so that we get U_\mu(0).
299  gaugeOddBit = (xs%2 == 0 || type == QUDA_4D_PC) ? oddBit : (oddBit+1) % 2;
300  gFloat *gauge = gaugeLink_sgpu(gge_idx, dir, gaugeOddBit, gaugeEven, gaugeOdd);
301 
302  // Even though we're doing the 4d part of the dslash, we need
303  // to use a 5d neighbor function, to get the offsets right.
304  sFloat *spinor = spinorNeighbor_5d<type>(sp_idx, dir, oddBit, spinorField);
305  sFloat projectedSpinor[4*3*2], gaugedSpinor[4*3*2];
306  int projIdx = 2*(dir/2)+(dir+daggerBit)%2;
307  multiplySpinorByDiracProjector5(projectedSpinor, projIdx, spinor);
308 
309  for (int s = 0; s < 4; s++) {
310  if (dir % 2 == 0) {
311  su3Mul(&gaugedSpinor[s*(3*2)], gauge, &projectedSpinor[s*(3*2)]);
312 #ifdef DBUG_VERBOSE
313  std::cout << "spinor:" << std::endl;
314  printSpinorElement(&projectedSpinor[s*(3*2)],0,QUDA_DOUBLE_PRECISION);
315  std::cout << "gauge:" << std::endl;
316 #endif
317  } else {
318  su3Tmul(&gaugedSpinor[s*(3*2)], gauge, &projectedSpinor[s*(3*2)]);
319  }
320  }
321 
322  sum(&res[sp_idx*(4*3*2)], &res[sp_idx*(4*3*2)], gaugedSpinor, 4*3*2);
323  }
324  }
325  }
326 }
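// Editorial note (added): on the projector selection inside the loop above,
// projIdx = 2*(dir/2) + (dir+daggerBit)%2 picks the forward/backward member
// of each projector pair, and daggerBit flips the choice. Worked through for
// dir = 3 (backward x2 hop): daggerBit = 0 gives projIdx = 2*1 + 1 = 3
// (1 + gamma_2); daggerBit = 1 gives projIdx = 2*1 + 0 = 2 (1 - gamma_2).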
327 
328 #ifdef MULTI_GPU
329 template <QudaPCType type, typename sFloat, typename gFloat>
330 void dslashReference_4d_mgpu(sFloat *res, gFloat **gaugeFull, gFloat **ghostGauge, sFloat *spinorField,
331  sFloat **fwdSpinor, sFloat **backSpinor, int oddBit, int daggerBit)
332 {
333  int mySpinorSiteSize = 24;
334  for (int i=0; i<V5h*mySpinorSiteSize; i++) res[i] = 0.0;
335 
336  gFloat *gaugeEven[4], *gaugeOdd[4];
337  gFloat *ghostGaugeEven[4], *ghostGaugeOdd[4];
338 
339  for (int dir = 0; dir < 4; dir++)
340  {
341  gaugeEven[dir] = gaugeFull[dir];
342  gaugeOdd[dir] = gaugeFull[dir]+Vh*gaugeSiteSize;
343 
344  ghostGaugeEven[dir] = ghostGauge[dir];
345  ghostGaugeOdd[dir] = ghostGauge[dir] + (faceVolume[dir]/2)*gaugeSiteSize;
346  }
347  for (int xs=0;xs<Ls;xs++)
348  {
349  int sp_idx;
350  for (int i = 0; i < Vh; i++)
351  {
352  sp_idx = i + Vh*xs;
353  for (int dir = 0; dir < 8; dir++)
354  {
355  int gaugeOddBit = (xs%2 == 0 || type == QUDA_4D_PC) ? oddBit : (oddBit + 1) % 2;
356 
357  gFloat *gauge = gaugeLink_mgpu(i, dir, gaugeOddBit, gaugeEven, gaugeOdd, ghostGaugeEven, ghostGaugeOdd, 1, 1);//this is unchanged from MPi version
358  sFloat *spinor = spinorNeighbor_5d_mgpu<type>(sp_idx, dir, oddBit, spinorField, fwdSpinor, backSpinor, 1, 1);
359 
360  sFloat projectedSpinor[mySpinorSiteSize], gaugedSpinor[mySpinorSiteSize];
361  int projIdx = 2*(dir/2)+(dir+daggerBit)%2;
362  multiplySpinorByDiracProjector5(projectedSpinor, projIdx, spinor);
363 
364  for (int s = 0; s < 4; s++)
365  {
366  if (dir % 2 == 0) su3Mul(&gaugedSpinor[s*(3*2)], gauge, &projectedSpinor[s*(3*2)]);
367  else su3Tmul(&gaugedSpinor[s*(3*2)], gauge, &projectedSpinor[s*(3*2)]);
368  }
369  sum(&res[sp_idx*(4*3*2)], &res[sp_idx*(4*3*2)], gaugedSpinor, 4*3*2);
370  }
371  }
372  }
373 }
374 #endif
375 
376 //Currently we consider only spacetime decomposition (not in 5th dim), so this operator is local
377 template <QudaPCType type, bool zero_initialize = false, typename sFloat>
378 void dslashReference_5th(sFloat *res, sFloat *spinorField, int oddBit, int daggerBit, sFloat mferm)
379 {
380  for (int i = 0; i < V5h; i++) {
381  if (zero_initialize) for(int one_site = 0 ; one_site < 24 ; one_site++)
382  res[i*(4*3*2)+one_site] = 0.0;
383  for (int dir = 8; dir < 10; dir++) {
384  // Calls for an extension of the original function.
385  // 8 is forward hop, which wants P_+, 9 is backward hop,
386  // which wants P_-. Dagger reverses these.
387  sFloat *spinor = spinorNeighbor_5d<type>(i, dir, oddBit, spinorField);
388  sFloat projectedSpinor[4*3*2];
389  int projIdx = 2*(dir/2)+(dir+daggerBit)%2;
390  multiplySpinorByDiracProjector5(projectedSpinor, projIdx, spinor);
391  //J Need a conditional here for s=0 and s=Ls-1.
392  int X = (type == QUDA_5D_PC) ? fullLatticeIndex_5d(i, oddBit) : fullLatticeIndex_5d_4dpc(i, oddBit);
393  int xs = X/(Z[3]*Z[2]*Z[1]*Z[0]);
394 
395  if ( (xs == 0 && dir == 9) || (xs == Ls-1 && dir == 8) ) {
396  ax(projectedSpinor,(sFloat)(-mferm),projectedSpinor,4*3*2);
397  }
398  sum(&res[i*(4*3*2)], &res[i*(4*3*2)], projectedSpinor, 4*3*2);
399  }
400  }
401 }
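// Editorial note (added): the xs/dir checks above implement the domain-wall
// boundary condition in the 5th dimension; the hop that wraps around the wall
// picks up a factor of -mferm:
//   at s = 0,    the backward hop (dir 9) brings in -mferm * P_- psi(s = Ls-1);
//   at s = Ls-1, the forward  hop (dir 8) brings in -mferm * P_+ psi(s = 0).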
402 
403 //Currently we consider only spacetime decomposition (not in 5th dim), so this operator is local
404 template <typename sFloat>
405 void dslashReference_5th_inv(sFloat *res, sFloat *spinorField, int oddBit, int daggerBit, sFloat mferm, double *kappa)
406 {
407  double *inv_Ftr = (double*)malloc(Ls*sizeof(double)); // note: sizeof(double), not sizeof(sFloat), which under-allocates in single precision
408  double *Ftr = (double*)malloc(Ls*sizeof(double));
409  for(int xs = 0 ; xs < Ls ; xs++)
410  {
411  inv_Ftr[xs] = 1.0/(1.0+pow(2.0*kappa[xs], Ls)*mferm);
412  Ftr[xs] = -2.0*kappa[xs]*mferm*inv_Ftr[xs];
413  for (int i = 0; i < Vh; i++) {
414  memcpy(&res[24*(i+Vh*xs)], &spinorField[24*(i+Vh*xs)], 24*sizeof(sFloat));
415  }
416  }
417  if(daggerBit == 0)
418  {
419  // s = 0
420  for (int i = 0; i < Vh; i++) {
421  ax(&res[12+24*(i+Vh*(Ls-1))],(sFloat)(inv_Ftr[0]), &spinorField[12+24*(i+Vh*(Ls-1))], 12);
422  }
423 
424  // s = 1 ... ls-2
425  for(int xs = 0 ; xs <= Ls-2 ; ++xs)
426  {
427  for (int i = 0; i < Vh; i++) {
428  axpy((sFloat)(2.0*kappa[xs]), &res[24*(i+Vh*xs)], &res[24*(i+Vh*(xs+1))], 12);
429  axpy((sFloat)Ftr[xs], &res[12+24*(i+Vh*xs)], &res[12+24*(i+Vh*(Ls-1))], 12);
430  }
431  for (int tmp_s = 0 ; tmp_s < Ls ; tmp_s++)
432  Ftr[tmp_s] *= 2.0*kappa[tmp_s];
433  }
434  for(int xs = 0 ; xs < Ls ; xs++)
435  {
436  Ftr[xs] = -pow(2.0*kappa[xs],Ls-1)*mferm*inv_Ftr[xs];
437  }
438  // s = ls-2 ... 0
439  for(int xs = Ls-2 ; xs >=0 ; --xs)
440  {
441  for (int i = 0; i < Vh; i++) {
442  axpy((sFloat)Ftr[xs], &res[24*(i+Vh*(Ls-1))], &res[24*(i+Vh*xs)], 12);
443  axpy((sFloat)(2.0*kappa[xs]), &res[12+24*(i+Vh*(xs+1))], &res[12+24*(i+Vh*xs)], 12);
444  }
445  for (int tmp_s = 0 ; tmp_s < Ls ; tmp_s++)
446  Ftr[tmp_s] /= 2.0*kappa[tmp_s];
447  }
448  // s = ls -1
449  for (int i = 0; i < Vh; i++) {
450  ax(&res[24*(i+Vh*(Ls-1))], (sFloat)(inv_Ftr[Ls-1]), &res[24*(i+Vh*(Ls-1))], 12);
451  }
452  }
453  else
454  {
455  // s = 0
456  for (int i = 0; i < Vh; i++) {
457  ax(&res[24*(i+Vh*(Ls-1))],(sFloat)(inv_Ftr[0]), &spinorField[24*(i+Vh*(Ls-1))], 12);
458  }
459 
460  // s = 1 ... ls-2
461  for(int xs = 0 ; xs <= Ls-2 ; ++xs)
462  {
463  for (int i = 0; i < Vh; i++) {
464  axpy((sFloat)Ftr[xs], &res[24*(i+Vh*xs)], &res[24*(i+Vh*(Ls-1))], 12);
465  axpy((sFloat)(2.0*kappa[xs]), &res[12+24*(i+Vh*xs)], &res[12+24*(i+Vh*(xs+1))], 12);
466  }
467  for (int tmp_s = 0 ; tmp_s < Ls ; tmp_s++)
468  Ftr[tmp_s] *= 2.0*kappa[tmp_s];
469  }
470  for(int xs = 0 ; xs < Ls ; xs++)
471  {
472  Ftr[xs] = -pow(2.0*kappa[xs],Ls-1)*mferm*inv_Ftr[xs];
473  }
474  // s = ls-2 ... 0
475  for(int xs = Ls-2 ; xs >=0 ; --xs)
476  {
477  for (int i = 0; i < Vh; i++) {
478  axpy((sFloat)(2.0*kappa[xs]), &res[24*(i+Vh*(xs+1))], &res[24*(i+Vh*xs)], 12);
479  axpy((sFloat)Ftr[xs], &res[12+24*(i+Vh*(Ls-1))], &res[12+24*(i+Vh*xs)], 12);
480  }
481  for (int tmp_s = 0 ; tmp_s < Ls ; tmp_s++)
482  Ftr[tmp_s] /= 2.0*kappa[tmp_s];
483  }
484  // s = ls -1
485  for (int i = 0; i < Vh; i++) {
486  ax(&res[12+24*(i+Vh*(Ls-1))], (sFloat)(inv_Ftr[Ls-1]), &res[12+24*(i+Vh*(Ls-1))], 12);
487  }
488  }
489  free(inv_Ftr);
490  free(Ftr);
491 }
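// Editorial note (added): the routine above applies the inverse of the
// 5th-dimension operator directly rather than iteratively: a forward sweep in
// s, a back substitution, and a final rescale of the s = Ls-1 slice. The
// -mferm corner entries coupling s = 0 and s = Ls-1 generate the
// Sherman-Morrison-like denominators inverted in
//   inv_Ftr[s] = 1 / (1 + (2*kappa[s])^Ls * mferm),
// with Ftr[s] carrying the accumulated mass-term contributions per sweep.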
492 
493 // Currently we consider only spacetime decomposition (not in 5th dim), so this operator is local
494 template <typename sFloat, typename sComplex>
495 void mdslashReference_5th_inv(sFloat *res, sFloat *spinorField, int oddBit, int daggerBit, sFloat mferm, sComplex *kappa)
496 {
497  sComplex *inv_Ftr = (sComplex *)malloc(Ls * sizeof(sComplex));
498  sComplex *Ftr = (sComplex *)malloc(Ls * sizeof(sComplex));
499  for (int xs = 0; xs < Ls; xs++) {
500  inv_Ftr[xs] = 1.0 / (1.0 + cpow(2.0 * kappa[xs], Ls) * mferm);
501  Ftr[xs] = -2.0 * kappa[xs] * mferm * inv_Ftr[xs];
502  for (int i = 0; i < Vh; i++) {
503  memcpy(&res[24 * (i + Vh * xs)], &spinorField[24 * (i + Vh * xs)], 24 * sizeof(sFloat));
504  }
505  }
506  if (daggerBit == 0) {
507  // s = 0
508  for (int i = 0; i < Vh; i++) {
509  ax((sComplex *)&res[12 + 24 * (i + Vh * (Ls - 1))], inv_Ftr[0],
510  (sComplex *)&spinorField[12 + 24 * (i + Vh * (Ls - 1))], 6);
511  }
512 
513  // s = 1 ... ls-2
514  for (int xs = 0; xs <= Ls - 2; ++xs) {
515  for (int i = 0; i < Vh; i++) {
516  axpy((2.0 * kappa[xs]), (sComplex *)&res[24 * (i + Vh * xs)], (sComplex *)&res[24 * (i + Vh * (xs + 1))], 6);
517  axpy(Ftr[xs], (sComplex *)&res[12 + 24 * (i + Vh * xs)], (sComplex *)&res[12 + 24 * (i + Vh * (Ls - 1))], 6);
518  }
519  for (int tmp_s = 0; tmp_s < Ls; tmp_s++) Ftr[tmp_s] *= 2.0 * kappa[tmp_s];
520  }
521  for (int xs = 0; xs < Ls; xs++) Ftr[xs] = -cpow(2.0 * kappa[xs], Ls - 1) * mferm * inv_Ftr[xs];
522 
523  // s = ls-2 ... 0
524  for (int xs = Ls - 2; xs >= 0; --xs) {
525  for (int i = 0; i < Vh; i++) {
526  axpy(Ftr[xs], (sComplex *)&res[24 * (i + Vh * (Ls - 1))], (sComplex *)&res[24 * (i + Vh * xs)], 6);
527  axpy((2.0 * kappa[xs]), (sComplex *)&res[12 + 24 * (i + Vh * (xs + 1))],
528  (sComplex *)&res[12 + 24 * (i + Vh * xs)], 6);
529  }
530  for (int tmp_s = 0; tmp_s < Ls; tmp_s++) Ftr[tmp_s] /= 2.0 * kappa[tmp_s];
531  }
532  // s = ls -1
533  for (int i = 0; i < Vh; i++) {
534  ax((sComplex *)&res[24 * (i + Vh * (Ls - 1))], inv_Ftr[Ls - 1], (sComplex *)&res[24 * (i + Vh * (Ls - 1))], 6);
535  }
536  } else {
537  // s = 0
538  for (int i = 0; i < Vh; i++) {
539  ax((sComplex *)&res[24 * (i + Vh * (Ls - 1))], inv_Ftr[0], (sComplex *)&spinorField[24 * (i + Vh * (Ls - 1))], 6);
540  }
541 
542  // s = 1 ... ls-2
543  for (int xs = 0; xs <= Ls - 2; ++xs) {
544  for (int i = 0; i < Vh; i++) {
545  axpy(Ftr[xs], (sComplex *)&res[24 * (i + Vh * xs)], (sComplex *)&res[24 * (i + Vh * (Ls - 1))], 6);
546  axpy((2.0 * kappa[xs]), (sComplex *)&res[12 + 24 * (i + Vh * xs)],
547  (sComplex *)&res[12 + 24 * (i + Vh * (xs + 1))], 6);
548  }
549  for (int tmp_s = 0; tmp_s < Ls; tmp_s++) Ftr[tmp_s] *= 2.0 * kappa[tmp_s];
550  }
551  for (int xs = 0; xs < Ls; xs++) Ftr[xs] = -cpow(2.0 * kappa[xs], Ls - 1) * mferm * inv_Ftr[xs];
552 
553  // s = ls-2 ... 0
554  for (int xs = Ls - 2; xs >= 0; --xs) {
555  for (int i = 0; i < Vh; i++) {
556  axpy((2.0 * kappa[xs]), (sComplex *)&res[24 * (i + Vh * (xs + 1))], (sComplex *)&res[24 * (i + Vh * xs)], 6);
557  axpy(Ftr[xs], (sComplex *)&res[12 + 24 * (i + Vh * (Ls - 1))], (sComplex *)&res[12 + 24 * (i + Vh * xs)], 6);
558  }
559  for (int tmp_s = 0; tmp_s < Ls; tmp_s++) Ftr[tmp_s] /= 2.0 * kappa[tmp_s];
560  }
561  // s = ls -1
562  for (int i = 0; i < Vh; i++) {
563  ax((sComplex *)&res[12 + 24 * (i + Vh * (Ls - 1))], inv_Ftr[Ls - 1],
564  (sComplex *)&res[12 + 24 * (i + Vh * (Ls - 1))], 6);
565  }
566  }
567  free(inv_Ftr);
568  free(Ftr);
569 }
570 
571 // Applies the 4-d dslash plus the 5th-dimension hopping term on one parity, i.e. the full hopping term D_eo or D_oe on the 5-d checkerboard (no D_ee^{-1} or D_oo^{-1} factor is applied here).
572 void dw_dslash(void *out, void **gauge, void *in, int oddBit, int daggerBit, QudaPrecision precision,
573  QudaGaugeParam &gauge_param, double mferm)
574 {
575 #ifndef MULTI_GPU
576  if (precision == QUDA_DOUBLE_PRECISION) {
577  dslashReference_4d_sgpu<QUDA_5D_PC>((double*)out, (double**)gauge, (double*)in, oddBit, daggerBit);
578  dslashReference_5th<QUDA_5D_PC>((double*)out, (double*)in, oddBit, daggerBit, mferm);
579  } else {
580  dslashReference_4d_sgpu<QUDA_5D_PC>((float*)out, (float**)gauge, (float*)in, oddBit, daggerBit);
581  dslashReference_5th<QUDA_5D_PC>((float*)out, (float*)in, oddBit, daggerBit, (float)mferm);
582  }
583 #else
584 
585  GaugeFieldParam gauge_field_param(gauge, gauge_param);
586  gauge_field_param.ghostExchange = QUDA_GHOST_EXCHANGE_PAD;
587  cpuGaugeField cpu(gauge_field_param);
588  void **ghostGauge = (void**)cpu.Ghost();
589 
590  // Get spinor ghost fields
591  // First wrap the input spinor into a ColorSpinorField
593  csParam.v = in;
594  csParam.nColor = 3;
595  csParam.nSpin = 4;
596  csParam.nDim = 5; //for DW dslash
597  for (int d=0; d<4; d++) csParam.x[d] = Z[d];
598  csParam.x[4] = Ls; // 5th dimension
599  csParam.setPrecision(precision);
600  csParam.pad = 0;
602  csParam.x[0] /= 2;
607  csParam.pc_type = QUDA_5D_PC;
608 
609  cpuColorSpinorField inField(csParam);
610 
611  { // Now do the exchange
612  QudaParity otherParity = QUDA_INVALID_PARITY;
613  if (oddBit == QUDA_EVEN_PARITY) otherParity = QUDA_ODD_PARITY;
614  else if (oddBit == QUDA_ODD_PARITY) otherParity = QUDA_EVEN_PARITY;
615  else errorQuda("ERROR: full parity not supported in function %s", __FUNCTION__);
616  const int nFace = 1;
617 
618  inField.exchangeGhost(otherParity, nFace, daggerBit);
619  }
620  void** fwd_nbr_spinor = inField.fwdGhostFaceBuffer;
621  void** back_nbr_spinor = inField.backGhostFaceBuffer;
622  //NOTE: hopping in 5th dimension does not use MPI.
623  if (precision == QUDA_DOUBLE_PRECISION) {
624  dslashReference_4d_mgpu<QUDA_5D_PC>((double*)out, (double**)gauge, (double**)ghostGauge, (double*)in,(double**)fwd_nbr_spinor, (double**)back_nbr_spinor, oddBit, daggerBit);
625  //dslashReference_4d_sgpu<QUDA_5D_PC>((double*)out, (double**)gauge, (double*)in, oddBit, daggerBit);
626  dslashReference_5th<QUDA_5D_PC>((double*)out, (double*)in, oddBit, daggerBit, mferm);
627  } else {
628  dslashReference_4d_mgpu<QUDA_5D_PC>((float*)out, (float**)gauge, (float**)ghostGauge, (float*)in,
629  (float**)fwd_nbr_spinor, (float**)back_nbr_spinor, oddBit, daggerBit);
630  dslashReference_5th<QUDA_5D_PC>((float*)out, (float*)in, oddBit, daggerBit, (float)mferm);
631  }
632 
633 #endif
634 
635 }
636 
637 void dslash_4_4d(void *out, void **gauge, void *in, int oddBit, int daggerBit, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm)
638 {
639 #ifndef MULTI_GPU
640  if (precision == QUDA_DOUBLE_PRECISION) {
641  dslashReference_4d_sgpu<QUDA_4D_PC>((double*)out, (double**)gauge, (double*)in, oddBit, daggerBit);
642  } else {
643  dslashReference_4d_sgpu<QUDA_4D_PC>((float*)out, (float**)gauge, (float*)in, oddBit, daggerBit);
644  }
645 #else
646 
647  GaugeFieldParam gauge_field_param(gauge, gauge_param);
648  gauge_field_param.ghostExchange = QUDA_GHOST_EXCHANGE_PAD;
649  cpuGaugeField cpu(gauge_field_param);
650  void **ghostGauge = (void**)cpu.Ghost();
651 
652  // Get spinor ghost fields
653  // First wrap the input spinor into a ColorSpinorField
655  csParam.v = in;
656  csParam.nColor = 3;
657  csParam.nSpin = 4;
658  csParam.nDim = 5; //for DW dslash
659  for (int d=0; d<4; d++) csParam.x[d] = Z[d];
660  csParam.x[4] = Ls; // 5th dimension
661  csParam.setPrecision(precision);
662  csParam.pad = 0;
664  csParam.x[0] /= 2;
669  csParam.pc_type = QUDA_4D_PC;
670 
671  cpuColorSpinorField inField(csParam);
672 
673  { // Now do the exchange
674  QudaParity otherParity = QUDA_INVALID_PARITY;
675  if (oddBit == QUDA_EVEN_PARITY) otherParity = QUDA_ODD_PARITY;
676  else if (oddBit == QUDA_ODD_PARITY) otherParity = QUDA_EVEN_PARITY;
677  else errorQuda("ERROR: full parity not supported in function %s", __FUNCTION__);
678  const int nFace = 1;
679 
680  inField.exchangeGhost(otherParity, nFace, daggerBit);
681  }
682  void** fwd_nbr_spinor = inField.fwdGhostFaceBuffer;
683  void** back_nbr_spinor = inField.backGhostFaceBuffer;
684  if (precision == QUDA_DOUBLE_PRECISION) {
685  dslashReference_4d_mgpu<QUDA_4D_PC>((double*)out, (double**)gauge, (double**)ghostGauge, (double*)in,(double**)fwd_nbr_spinor, (double**)back_nbr_spinor, oddBit, daggerBit);
686  } else {
687  dslashReference_4d_mgpu<QUDA_4D_PC>((float*)out, (float**)gauge, (float**)ghostGauge, (float*)in,
688  (float**)fwd_nbr_spinor, (float**)back_nbr_spinor, oddBit, daggerBit);
689  }
690 
691 #endif
692 
693 }
694 
695 void dw_dslash_5_4d(void *out, void **gauge, void *in, int oddBit, int daggerBit, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm, bool zero_initialize)
696 {
697  if (precision == QUDA_DOUBLE_PRECISION) {
698  if (zero_initialize) dslashReference_5th<QUDA_4D_PC, true>((double*)out, (double*)in, oddBit, daggerBit, mferm);
699  else dslashReference_5th<QUDA_4D_PC, false>((double*)out, (double*)in, oddBit, daggerBit, mferm);
700  } else {
701  if (zero_initialize) dslashReference_5th<QUDA_4D_PC, true>((float*)out, (float*)in, oddBit, daggerBit, (float)mferm);
702  else dslashReference_5th<QUDA_4D_PC, false>((float*)out, (float*)in, oddBit, daggerBit, (float)mferm);
703  }
704 }
705 
706 void dslash_5_inv(void *out, void **gauge, void *in, int oddBit, int daggerBit, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm, double *kappa)
707 {
708  if (precision == QUDA_DOUBLE_PRECISION) {
709  dslashReference_5th_inv((double*)out, (double*)in, oddBit, daggerBit, mferm, kappa);
710  } else {
711  dslashReference_5th_inv((float*)out, (float*)in, oddBit, daggerBit, (float)mferm, kappa);
712  }
713 }
714 
715 void mdw_dslash_5_inv(void *out, void **gauge, void *in, int oddBit, int daggerBit, QudaPrecision precision,
716  QudaGaugeParam &gauge_param, double mferm, double _Complex *kappa)
717 {
718  if (precision == QUDA_DOUBLE_PRECISION) {
719  mdslashReference_5th_inv((double *)out, (double *)in, oddBit, daggerBit, mferm, kappa);
720  } else {
721  mdslashReference_5th_inv((float *)out, (float *)in, oddBit, daggerBit, (float)mferm, kappa);
722  }
723 }
724 
725 void mdw_dslash_5(void *out, void **gauge, void *in, int oddBit, int daggerBit, QudaPrecision precision,
726  QudaGaugeParam &gauge_param, double mferm, double _Complex *kappa, bool zero_initialize)
727 {
728  if (precision == QUDA_DOUBLE_PRECISION) {
729  if (zero_initialize) dslashReference_5th<QUDA_4D_PC,true>((double*)out, (double*)in, oddBit, daggerBit, mferm);
730  else dslashReference_5th<QUDA_4D_PC,false>((double*)out, (double*)in, oddBit, daggerBit, mferm);
731  } else {
732  if (zero_initialize) dslashReference_5th<QUDA_4D_PC,true>((float*)out, (float*)in, oddBit, daggerBit, (float)mferm);
733  else dslashReference_5th<QUDA_4D_PC,false>((float*)out, (float*)in, oddBit, daggerBit, (float)mferm);
734  }
735  for(int xs = 0 ; xs < Ls ; xs++) {
736  cxpay((char *)in + precision * Vh * spinorSiteSize * xs, kappa[xs],
737  (char *)out + precision * Vh * spinorSiteSize * xs, Vh * spinorSiteSize, precision);
738  }
739 }
740 
741 void mdw_dslash_4_pre(void *out, void **gauge, void *in, int oddBit, int daggerBit, QudaPrecision precision,
742  QudaGaugeParam &gauge_param, double mferm, double _Complex *b5, double _Complex *c5, bool zero_initialize)
743 {
744  if (precision == QUDA_DOUBLE_PRECISION) {
745  if (zero_initialize) dslashReference_5th<QUDA_4D_PC, true>((double*)out, (double*)in, oddBit, daggerBit, mferm);
746  else dslashReference_5th<QUDA_4D_PC, false>((double*)out, (double*)in, oddBit, daggerBit, mferm);
747  for(int xs = 0 ; xs < Ls ; xs++)
748  {
749  axpby(b5[xs], (double _Complex *)in + Vh * spinorSiteSize / 2 * xs, 0.5 * c5[xs],
750  (double _Complex *)out + Vh * spinorSiteSize / 2 * xs, Vh * spinorSiteSize / 2);
751  }
752  } else {
753  if (zero_initialize)
754  dslashReference_5th<QUDA_4D_PC, true>((float *)out, (float *)in, oddBit, daggerBit, (float)mferm);
755  else dslashReference_5th<QUDA_4D_PC,false>((float*)out, (float*)in, oddBit, daggerBit, (float)mferm);
756  for(int xs = 0 ; xs < Ls ; xs++)
757  {
758  axpby((float _Complex)(b5[xs]), (float _Complex *)in + Vh * (spinorSiteSize / 2) * xs,
759  (float _Complex)(0.5 * c5[xs]), (float _Complex *)out + Vh * (spinorSiteSize / 2) * xs,
760  Vh * spinorSiteSize / 2);
761  }
762  }
763 
764 }
765 
766 void dw_mat(void *out, void **gauge, void *in, double kappa, int dagger_bit, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm) {
767 
768  void *inEven = in;
769  void *inOdd = (char*)in + V5h*spinorSiteSize*precision;
770  void *outEven = out;
771  void *outOdd = (char*)out + V5h*spinorSiteSize*precision;
772 
773  dw_dslash(outOdd, gauge, inEven, 1, dagger_bit, precision, gauge_param, mferm);
774  dw_dslash(outEven, gauge, inOdd, 0, dagger_bit, precision, gauge_param, mferm);
775 
776  // lastly apply the kappa term
777  xpay(in, -kappa, out, V5*spinorSiteSize, precision);
778 }
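// Editorial note (added): with xpay(x, a, y, ...) computing y = x + a*y, the
// two dw_dslash calls plus the xpay above assemble the full unpreconditioned
// domain-wall matrix on a parity-ordered 5-d spinor:
//   out = in - kappa * D * in,  D = 4-d hopping term + 5th-dimension term.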
779 
780 void dw_4d_mat(void *out, void **gauge, void *in, double kappa, int dagger_bit, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm) {
781 
782  void *inEven = in;
783  void *inOdd = (char*)in + V5h*spinorSiteSize*precision;
784  void *outEven = out;
785  void *outOdd = (char*)out + V5h*spinorSiteSize*precision;
786 
787  dslash_4_4d(outOdd, gauge, inEven, 1, dagger_bit, precision, gauge_param, mferm);
788  dw_dslash_5_4d(outOdd, gauge, inOdd, 1, dagger_bit, precision, gauge_param, mferm, false);
789 
790  dslash_4_4d(outEven, gauge, inOdd, 0, dagger_bit, precision, gauge_param, mferm);
791  dw_dslash_5_4d(outEven, gauge, inEven, 0, dagger_bit, precision, gauge_param, mferm, false);
792 
793  // lastly apply the kappa term
794  xpay(in, -kappa, out, V5*spinorSiteSize, precision);
795 }
796 
797 void mdw_mat(void *out, void **gauge, void *in, double _Complex *kappa_b, double _Complex *kappa_c, int dagger,
798  QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm, double _Complex *b5, double _Complex *c5)
799 {
800 
801  void *tmp = malloc(V5h*spinorSiteSize*precision);
802  double _Complex *kappa5 = (double _Complex *)malloc(Ls * sizeof(double _Complex));
803 
804  for(int xs = 0; xs < Ls ; xs++) kappa5[xs] = 0.5*kappa_b[xs]/kappa_c[xs];
805 
806  void *inEven = in;
807  void *inOdd = (char*)in + V5h*spinorSiteSize*precision;
808  void *outEven = out;
809  void *outOdd = (char*)out + V5h*spinorSiteSize*precision;
810 
811  mdw_dslash_4_pre(tmp, gauge, inEven, 0, dagger, precision, gauge_param, mferm, b5, c5, true);
812  dslash_4_4d(outOdd, gauge, tmp, 1, dagger, precision, gauge_param, mferm);
813  mdw_dslash_5(tmp, gauge, inOdd, 1, dagger, precision, gauge_param, mferm, kappa5, true);
814 
815  for(int xs = 0 ; xs < Ls ; xs++) {
816  cxpay((char *)tmp + precision * Vh * spinorSiteSize * xs, -kappa_b[xs],
817  (char *)outOdd + precision * Vh * spinorSiteSize * xs, Vh * spinorSiteSize, precision);
818  }
819 
820  mdw_dslash_4_pre(tmp, gauge, inOdd, 1, dagger, precision, gauge_param, mferm, b5, c5, true);
821  dslash_4_4d(outEven, gauge, tmp, 0, dagger, precision, gauge_param, mferm);
822  mdw_dslash_5(tmp, gauge, inEven, 0, dagger, precision, gauge_param, mferm, kappa5, true);
823 
824  for(int xs = 0 ; xs < Ls ; xs++) {
825  cxpay((char *)tmp + precision * Vh * spinorSiteSize * xs, -kappa_b[xs],
826  (char *)outEven + precision * Vh * spinorSiteSize * xs, Vh * spinorSiteSize, precision);
827  }
828 
829  free(kappa5);
830  free(tmp);
831 }
832 
833 //
834 void dw_matdagmat(void *out, void **gauge, void *in, double kappa, int dagger_bit, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm)
835 {
836 
837  void *tmp = malloc(V5*spinorSiteSize*precision);
838  dw_mat(tmp, gauge, in, kappa, dagger_bit, precision, gauge_param, mferm);
839  dagger_bit = (dagger_bit == 1) ? 0 : 1;
840  dw_mat(out, gauge, tmp, kappa, dagger_bit, precision, gauge_param, mferm);
841 
842  free(tmp);
843 }
844 
845 void dw_matpc(void *out, void **gauge, void *in, double kappa, QudaMatPCType matpc_type, int dagger_bit, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm)
846 {
847  void *tmp = malloc(V5h*spinorSiteSize*precision);
848 
849  if (matpc_type == QUDA_MATPC_EVEN_EVEN || matpc_type == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC) {
850  dw_dslash(tmp, gauge, in, 1, dagger_bit, precision, gauge_param, mferm);
851  dw_dslash(out, gauge, tmp, 0, dagger_bit, precision, gauge_param, mferm);
852  } else {
853  dw_dslash(tmp, gauge, in, 0, dagger_bit, precision, gauge_param, mferm);
854  dw_dslash(out, gauge, tmp, 1, dagger_bit, precision, gauge_param, mferm);
855  }
856 
857  // lastly apply the kappa term
858  double kappa2 = -kappa*kappa;
859  xpay(in, kappa2, out, V5h*spinorSiteSize, precision);
860 
861  free(tmp);
862 }
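// Editorial note (added): the sequence above builds the even-even (or
// odd-odd) preconditioned operator; with kappa2 = -kappa*kappa the final
// xpay leaves
//   out = in - kappa^2 * D * D * in,  i.e.  M_pc = 1 - kappa^2 D_eo D_oe
// (or with e and o interchanged for the ODD_ODD cases).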
863 
864 
865 void dw_4d_matpc(void *out, void **gauge, void *in, double kappa, QudaMatPCType matpc_type, int dagger_bit, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm)
866 {
867  double kappa2 = -kappa*kappa;
868  double *kappa5 = (double*)malloc(Ls*sizeof(double));
869  for(int xs = 0; xs < Ls ; xs++)
870  kappa5[xs] = kappa;
871  void *tmp = malloc(V5h*spinorSiteSize*precision);
872  //------------------------------------------
873  // Zero the output buffer in a precision-safe way (a cast to double*
874  // would overrun it in single precision).
875  memset(out, 0, V5h*spinorSiteSize*precision);
876  //------------------------------------------
877 
878  int odd_bit = (matpc_type == QUDA_MATPC_ODD_ODD || matpc_type == QUDA_MATPC_ODD_ODD_ASYMMETRIC) ? 1 : 0;
879  bool symmetric = (matpc_type == QUDA_MATPC_EVEN_EVEN || matpc_type == QUDA_MATPC_ODD_ODD);
880  QudaParity parity[2] = {static_cast<QudaParity>((1 + odd_bit) % 2), static_cast<QudaParity>((0 + odd_bit) % 2)};
881 
882  if (symmetric && !dagger_bit) {
883  dslash_4_4d(tmp, gauge, in, parity[0], dagger_bit, precision, gauge_param, mferm);
884  dslash_5_inv(out, gauge, tmp, parity[0], dagger_bit, precision, gauge_param, mferm, kappa5);
885  dslash_4_4d(tmp, gauge, out, parity[1], dagger_bit, precision, gauge_param, mferm);
886  dslash_5_inv(out, gauge, tmp, parity[1], dagger_bit, precision, gauge_param, mferm, kappa5);
887  xpay(in, kappa2, out, V5h*spinorSiteSize, precision);
888  } else if (symmetric && dagger_bit) {
889  dslash_5_inv(tmp, gauge, in, parity[1], dagger_bit, precision, gauge_param, mferm, kappa5);
890  dslash_4_4d(out, gauge, tmp, parity[0], dagger_bit, precision, gauge_param, mferm);
891  dslash_5_inv(tmp, gauge, out, parity[0], dagger_bit, precision, gauge_param, mferm, kappa5);
892  dslash_4_4d(out, gauge, tmp, parity[1], dagger_bit, precision, gauge_param, mferm);
893  xpay(in, kappa2, out, V5h*spinorSiteSize, precision);
894  } else {
895  dslash_4_4d(tmp, gauge, in, parity[0], dagger_bit, precision, gauge_param, mferm);
896  dslash_5_inv(out, gauge, tmp, parity[0], dagger_bit, precision, gauge_param, mferm, kappa5);
897  dslash_4_4d(tmp, gauge, out, parity[1], dagger_bit, precision, gauge_param, mferm);
898  xpay(in, kappa2, tmp, V5h*spinorSiteSize, precision);
899  dw_dslash_5_4d(out, gauge, in, parity[1], dagger_bit, precision, gauge_param, mferm, true);
900  xpay(tmp, -kappa, out, V5h*spinorSiteSize, precision);
901  }
902  free(tmp);
903  free(kappa5);
904 }
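// Editorial note (added, for the symmetric branches only): with kappa5[s] =
// kappa for every s, the call sequence above applies
//   M_pc = 1 - kappa^2 (D5^{-1} D4) (D5^{-1} D4),
// where D4 is the 4-d dslash and D5^{-1} the 5th-dimension inverse; the
// asymmetric branch instead folds a factor of D5 back in via the final
// dw_dslash_5_4d/xpay pair.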
905 
906 void mdw_matpc(void *out, void **gauge, void *in, double _Complex *kappa_b, double _Complex *kappa_c,
907  QudaMatPCType matpc_type, int dagger, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm,
908  double _Complex *b5, double _Complex *c5)
909 {
910  void *tmp = malloc(V5h*spinorSiteSize*precision);
911  double _Complex *kappa5 = (double _Complex *)malloc(Ls * sizeof(double _Complex));
912  double _Complex *kappa2 = (double _Complex *)malloc(Ls * sizeof(double _Complex));
913  double _Complex *kappa_mdwf = (double _Complex *)malloc(Ls * sizeof(double _Complex));
914  for(int xs = 0; xs < Ls ; xs++)
915  {
916  kappa5[xs] = 0.5*kappa_b[xs]/kappa_c[xs];
917  kappa2[xs] = -kappa_b[xs]*kappa_b[xs];
918  kappa_mdwf[xs] = -kappa5[xs];
919  }
920 
921  int odd_bit = (matpc_type == QUDA_MATPC_ODD_ODD || matpc_type == QUDA_MATPC_ODD_ODD_ASYMMETRIC) ? 1 : 0;
922  bool symmetric = (matpc_type == QUDA_MATPC_EVEN_EVEN || matpc_type == QUDA_MATPC_ODD_ODD);
923  QudaParity parity[2] = {static_cast<QudaParity>((1 + odd_bit) % 2), static_cast<QudaParity>((0 + odd_bit) % 2)};
924 
925  if (symmetric && !dagger) {
926  mdw_dslash_4_pre(tmp, gauge, in, parity[1], dagger, precision, gauge_param, mferm, b5, c5, true);
927  dslash_4_4d(out, gauge, tmp, parity[0], dagger, precision, gauge_param, mferm);
928  mdw_dslash_5_inv(tmp, gauge, out, parity[1], dagger, precision, gauge_param, mferm, kappa_mdwf);
929  mdw_dslash_4_pre(out, gauge, tmp, parity[0], dagger, precision, gauge_param, mferm, b5, c5, true);
930  dslash_4_4d(tmp, gauge, out, parity[1], dagger, precision, gauge_param, mferm);
931  mdw_dslash_5_inv(out, gauge, tmp, parity[0], dagger, precision, gauge_param, mferm, kappa_mdwf);
932  for(int xs = 0 ; xs < Ls ; xs++) {
933  cxpay((char *)in + precision * Vh * spinorSiteSize * xs, kappa2[xs],
934  (char *)out + precision * Vh * spinorSiteSize * xs, Vh * spinorSiteSize, precision);
935  }
936  } else if (symmetric && dagger) {
937  mdw_dslash_5_inv(tmp, gauge, in, parity[1], dagger, precision, gauge_param, mferm, kappa_mdwf);
938  dslash_4_4d(out, gauge, tmp, parity[0], dagger, precision, gauge_param, mferm);
939  mdw_dslash_4_pre(tmp, gauge, out, parity[0], dagger, precision, gauge_param, mferm, b5, c5, true);
940  mdw_dslash_5_inv(out, gauge, tmp, parity[0], dagger, precision, gauge_param, mferm, kappa_mdwf);
941  dslash_4_4d(tmp, gauge, out, parity[1], dagger, precision, gauge_param, mferm);
942  mdw_dslash_4_pre(out, gauge, tmp, parity[1], dagger, precision, gauge_param, mferm, b5, c5, true);
943  for(int xs = 0 ; xs < Ls ; xs++) {
944  cxpay((char *)in + precision * Vh * spinorSiteSize * xs, kappa2[xs],
945  (char *)out + precision * Vh * spinorSiteSize * xs, Vh * spinorSiteSize, precision);
946  }
947  } else if (!symmetric && !dagger) {
948  mdw_dslash_4_pre(out, gauge, in, parity[1], dagger, precision, gauge_param, mferm, b5, c5, true);
949  dslash_4_4d(tmp, gauge, out, parity[0], dagger, precision, gauge_param, mferm);
950  mdw_dslash_5_inv(out, gauge, tmp, parity[1], dagger, precision, gauge_param, mferm, kappa_mdwf);
951  mdw_dslash_4_pre(tmp, gauge, out, parity[0], dagger, precision, gauge_param, mferm, b5, c5, true);
952  dslash_4_4d(out, gauge, tmp, parity[1], dagger, precision, gauge_param, mferm);
953  mdw_dslash_5(tmp, gauge, in, parity[0], dagger, precision, gauge_param, mferm, kappa5, true);
954  for(int xs = 0 ; xs < Ls ; xs++) {
955  cxpay((char *)tmp + precision * Vh * spinorSiteSize * xs, kappa2[xs],
956  (char *)out + precision * Vh * spinorSiteSize * xs, Vh * spinorSiteSize, precision);
957  }
958  } else if (!symmetric && dagger) {
959  dslash_4_4d(out, gauge, in, parity[0], dagger, precision, gauge_param, mferm);
960  mdw_dslash_4_pre(tmp, gauge, out, parity[1], dagger, precision, gauge_param, mferm, b5, c5, true);
961  mdw_dslash_5_inv(out, gauge, tmp, parity[0], dagger, precision, gauge_param, mferm, kappa_mdwf);
962  dslash_4_4d(tmp, gauge, out, parity[1], dagger, precision, gauge_param, mferm);
963  mdw_dslash_4_pre(out, gauge, tmp, parity[0], dagger, precision, gauge_param, mferm, b5, c5, true);
964  mdw_dslash_5(tmp, gauge, in, parity[0], dagger, precision, gauge_param, mferm, kappa5, true);
965  for(int xs = 0 ; xs < Ls ; xs++) {
966  cxpay((char *)tmp + precision * Vh * spinorSiteSize * xs, kappa2[xs],
967  (char *)out + precision * Vh * spinorSiteSize * xs, Vh * spinorSiteSize, precision);
968  }
969  } else {
970  errorQuda("Unsupported matpc_type=%d dagger=%d", matpc_type, dagger);
971  }
972 
973  free(tmp);
974  free(kappa5);
975  free(kappa2);
976  free(kappa_mdwf);
977 }
978 
979 /*
980 // Apply the even-odd preconditioned Dirac operator
981 template <typename sFloat, typename gFloat>
982 void MatPC(sFloat *outEven, gFloat **gauge, sFloat *inEven, sFloat kappa,
983  QudaMatPCType matpc_type, sFloat mferm) {
984 
985  sFloat *tmp = (sFloat*)malloc(V5h*spinorSiteSize*sizeof(sFloat));
986 
987  // full dslash operator
988  if (matpc_type == QUDA_MATPC_EVEN_EVEN) {
989  dslashReference_4d(tmp, gauge, inEven, 1, 0);
990  dslashReference_5th(tmp, inEven, 1, 0, mferm);
991  dslashReference_4d(outEven, gauge, tmp, 0, 0);
992  dslashReference_5th(outEven, tmp, 0, 0, mferm);
993  } else {
994  dslashReference_4d(tmp, gauge, inEven, 0, 0);
995  dslashReference_5th(tmp, inEven, 0, 0, mferm);
996  dslashReference_4d(outEven, gauge, tmp, 1, 0);
997  dslashReference_5th(outEven, tmp, 1, 0, mferm);
998  }
999 
1000  // lastly apply the kappa term
1001  sFloat kappa2 = -kappa*kappa;
1002  xpay(inEven, kappa2, outEven, V5h*spinorSiteSize);
1003  free(tmp);
1004 }
1005 
1006 // Apply the even-odd preconditioned Dirac operator
1007 template <typename sFloat, typename gFloat>
1008 void MatPCDag(sFloat *outEven, gFloat **gauge, sFloat *inEven, sFloat kappa,
1009  QudaMatPCType matpc_type, sFloat mferm) {
1010 
1011  sFloat *tmp = (sFloat*)malloc(V5h*spinorSiteSize*sizeof(sFloat));
1012 
1013  // full dslash operator
1014  if (matpc_type == QUDA_MATPC_EVEN_EVEN) {
1015  dslashReference_4d(tmp, gauge, inEven, 1, 1);
1016  dslashReference_5th(tmp, inEven, 1, 1, mferm);
1017  dslashReference_4d(outEven, gauge, tmp, 0, 1);
1018  dslashReference_5th(outEven, tmp, 0, 1, mferm);
1019  } else {
1020  dslashReference_4d(tmp, gauge, inEven, 0, 1);
1021  dslashReference_5th(tmp, inEven, 0, 1, mferm);
1022  dslashReference_4d(outEven, gauge, tmp, 1, 1);
1023  dslashReference_5th(outEven, tmp, 1, 1, mferm);
1024  }
1025 
1026  sFloat kappa2 = -kappa*kappa;
1027  xpay(inEven, kappa2, outEven, V5h*spinorSiteSize);
1028  free(tmp);
1029 }
1030 */
1031 
1032 void matpc(void *outEven, void **gauge, void *inEven, double kappa,
1033  QudaMatPCType matpc_type, int dagger_bit, QudaPrecision sPrecision, QudaPrecision gPrecision,
1034  double mferm) {
1035 /*
1036  if (!dagger_bit) {
1037  if (sPrecision == QUDA_DOUBLE_PRECISION)
1038  if (gPrecision == QUDA_DOUBLE_PRECISION)
1039  MatPC((double*)outEven, (double**)gauge, (double*)inEven, (double)kappa, matpc_type, (double)mferm);
1040  else
1041  MatPC((double*)outEven, (float**)gauge, (double*)inEven, (double)kappa, matpc_type, (double)mferm);
1042  else
1043  if (gPrecision == QUDA_DOUBLE_PRECISION)
1044  MatPC((float*)outEven, (double**)gauge, (float*)inEven, (float)kappa, matpc_type, (float)mferm);
1045  else
1046  MatPC((float*)outEven, (float**)gauge, (float*)inEven, (float)kappa, matpc_type, (float)mferm);
1047  } else {
1048  if (sPrecision == QUDA_DOUBLE_PRECISION)
1049  if (gPrecision == QUDA_DOUBLE_PRECISION)
1050  MatPCDag((double*)outEven, (double**)gauge, (double*)inEven, (double)kappa, matpc_type, (double)mferm);
1051  else
1052  MatPCDag((double*)outEven, (float**)gauge, (double*)inEven, (double)kappa, matpc_type, (double)mferm);
1053  else
1054  if (gPrecision == QUDA_DOUBLE_PRECISION)
1055  MatPCDag((float*)outEven, (double**)gauge, (float*)inEven, (float)kappa, matpc_type, (float)mferm);
1056  else
1057  MatPCDag((float*)outEven, (float**)gauge, (float*)inEven, (float)kappa, matpc_type, (float)mferm);
1058  }
1059 */
1060 }
1061 
1062 /*
1063 template <typename sFloat, typename gFloat>
1064 void MatDagMat(sFloat *out, gFloat **gauge, sFloat *in, sFloat kappa, sFloat mferm)
1065 {
1066  // Allocate a full spinor.
1067  sFloat *tmp = (sFloat*)malloc(V5*spinorSiteSize*sizeof(sFloat));
1068  // Call templates above.
1069  Mat(tmp, gauge, in, kappa, mferm);
1070  MatDag(out, gauge, tmp, kappa, mferm);
1071  free(tmp);
1072 }
1073 
1074 template <typename sFloat, typename gFloat>
1075 void MatPCDagMatPC(sFloat *out, gFloat **gauge, sFloat *in, sFloat kappa,
1076  QudaMatPCType matpc_type, sFloat mferm)
1077 {
1078 
1079  // Allocate half spinor
1080  sFloat *tmp = (sFloat*)malloc(V5h*spinorSiteSize*sizeof(sFloat));
1081  // Apply the PC templates above
1082  MatPC(tmp, gauge, in, kappa, matpc_type, mferm);
1083  MatPCDag(out, gauge, tmp, kappa, matpc_type, mferm);
1084  free(tmp);
1085 }
1086 */
1087 // Wrapper to templates that handles different precisions.
1088 void matdagmat(void *out, void **gauge, void *in, double kappa,
1089  QudaPrecision sPrecision, QudaPrecision gPrecision, double mferm)
1090 {
1091 /*
1092  if (sPrecision == QUDA_DOUBLE_PRECISION) {
1093  if (gPrecision == QUDA_DOUBLE_PRECISION)
1094  MatDagMat((double*)out, (double**)gauge, (double*)in, (double)kappa,
1095  (double)mferm);
1096  else
1097  MatDagMat((double*)out, (float**)gauge, (double*)in, (double)kappa, (double)mferm);
1098  } else {
1099  if (gPrecision == QUDA_DOUBLE_PRECISION)
1100  MatDagMat((float*)out, (double**)gauge, (float*)in, (float)kappa,
1101  (float)mferm);
1102  else
1103  MatDagMat((float*)out, (float**)gauge, (float*)in, (float)kappa, (float)mferm);
1104  }
1105 */
1106 }
1107 
1108 // Wrapper to templates that handles different precisions.
1109 void matpcdagmatpc(void *out, void **gauge, void *in, double kappa,
1110  QudaPrecision sPrecision, QudaPrecision gPrecision, double mferm, QudaMatPCType matpc_type)
1111 {
1112 /*
1113  if (sPrecision == QUDA_DOUBLE_PRECISION) {
1114  if (gPrecision == QUDA_DOUBLE_PRECISION)
1115  MatPCDagMatPC((double*)out, (double**)gauge, (double*)in, (double)kappa,
1116  matpc_type, (double)mferm);
1117  else
1118  MatPCDagMatPC((double*)out, (float**)gauge, (double*)in, (double)kappa,
1119  matpc_type, (double)mferm);
1120  } else {
1121  if (gPrecision == QUDA_DOUBLE_PRECISION)
1122  MatPCDagMatPC((float*)out, (double**)gauge, (float*)in, (float)kappa,
1123  matpc_type, (float)mferm);
1124  else
1125  MatPCDagMatPC((float*)out, (float**)gauge, (float*)in, (float)kappa,
1126  matpc_type, (float)mferm);
1127  }
1128 */
1129 }
1130 
1131 