// QUDA 0.9.0 — domain_wall_dslash_reference.cpp
// (Doxygen-generated page header converted to a comment so the file parses.)
1 #include <iostream>
2 #include <stdio.h>
3 #include <stdlib.h>
4 #include <string.h>
5 #include <math.h>
6 
7 #include <quda.h>
8 #include <test_util.h>
9 #include <dslash_util.h>
11 #include <blas_reference.h>
12 
13 #include <gauge_field.h>
14 #include <color_spinor_field.h>
15 
16 using namespace quda;
17 
18 // i represents a "half index" into an even or odd "half lattice".
19 // when oddBit={0,1} the half lattice is {even,odd}.
20 //
21 // the displacements, such as dx, refer to the full lattice coordinates.
22 //
23 // neighborIndex() takes a "half index", displaces it, and returns the
24 // new "half index", which can be an index into either the even or odd lattices.
25 // displacements of magnitude one always interchange odd and even lattices.
26 //
27 //
28 int neighborIndex_4d(int i, int oddBit, int dx4, int dx3, int dx2, int dx1) {
29  // On input i should be in the range [0 , ... , Z[0]*Z[1]*Z[2]*Z[3]/2-1].
30  if (i < 0 || i >= (Z[0]*Z[1]*Z[2]*Z[3]/2))
31  { printf("i out of range in neighborIndex_4d\n"); exit(-1); }
32  // Compute the linear index. Then dissect.
33  // fullLatticeIndex_4d is in util_quda.cpp.
34  // The gauge fields live on a 4d sublattice.
35  int X = fullLatticeIndex_4d(i, oddBit);
36  int x4 = X/(Z[2]*Z[1]*Z[0]);
37  int x3 = (X/(Z[1]*Z[0])) % Z[2];
38  int x2 = (X/Z[0]) % Z[1];
39  int x1 = X % Z[0];
40 
41  x4 = (x4+dx4+Z[3]) % Z[3];
42  x3 = (x3+dx3+Z[2]) % Z[2];
43  x2 = (x2+dx2+Z[1]) % Z[1];
44  x1 = (x1+dx1+Z[0]) % Z[0];
45 
46  return (x4*(Z[2]*Z[1]*Z[0]) + x3*(Z[1]*Z[0]) + x2*(Z[0]) + x1) / 2;
47 }
48 
49 
50 
51 //#ifndef MULTI_GPU
52 // This is just a copy of gaugeLink() from the quda code, except
53 // that neighborIndex() is replaced by the renamed version
54 // neighborIndex_4d().
55 //ok
56 template <typename Float>
57 Float *gaugeLink_sgpu(int i, int dir, int oddBit, Float **gaugeEven,
58  Float **gaugeOdd) {
59  Float **gaugeField;
60  int j;
61 
62  // If going forward, just grab link at site, U_\mu(x).
63  if (dir % 2 == 0) {
64  j = i;
65  // j will get used in the return statement below.
66  gaugeField = (oddBit ? gaugeOdd : gaugeEven);
67  } else {
68  // If going backward, a shift must occur, U_\mu(x-\muhat)^\dagger;
69  // dagger happens elsewhere, here we're just doing index gymnastics.
70  switch (dir) {
71  case 1: j = neighborIndex_4d(i, oddBit, 0, 0, 0, -1); break;
72  case 3: j = neighborIndex_4d(i, oddBit, 0, 0, -1, 0); break;
73  case 5: j = neighborIndex_4d(i, oddBit, 0, -1, 0, 0); break;
74  case 7: j = neighborIndex_4d(i, oddBit, -1, 0, 0, 0); break;
75  default: j = -1; break;
76  }
77  gaugeField = (oddBit ? gaugeEven : gaugeOdd);
78  }
79 
80  return &gaugeField[dir/2][j*(3*3*2)];
81 }
82 
83 
84 //#else
85 
86 //Standard 4d version (nothing to change)
// Multi-GPU variant of gaugeLink_sgpu(): returns a pointer to the link
// matrix (18 reals) for hopping from half-index site i in direction dir
// (even = forward, odd = backward; mu = dir/2). When a backward hop of
// distance d crosses a partitioned lattice boundary, the link is fetched
// from the ghost-gauge buffers instead of the local field.
//   n_ghost_faces  depth of the ghost zone
//   nbr_distance   hop distance d
template <typename Float>
Float *gaugeLink_mgpu(int i, int dir, int oddBit, Float **gaugeEven, Float **gaugeOdd, Float** ghostGaugeEven, Float** ghostGaugeOdd, int n_ghost_faces, int nbr_distance) {
  Float **gaugeField;
  int j;
  int d = nbr_distance;
  if (dir % 2 == 0) {
    // Forward hop: U_mu(x) lives at the site itself, same-parity array.
    j = i;
    gaugeField = (oddBit ? gaugeOdd : gaugeEven);
  }
  else {
    // Backward hop: reconstruct the full 4d coordinates of site i so we
    // can detect whether the displaced site falls off the local lattice.
    int Y = fullLatticeIndex(i, oddBit);
    int x4 = Y/(Z[2]*Z[1]*Z[0]);
    int x3 = (Y/(Z[1]*Z[0])) % Z[2];
    int x2 = (Y/Z[0]) % Z[1];
    int x1 = Y % Z[0];
    int X1= Z[0];
    int X2= Z[1];
    int X3= Z[2];
    int X4= Z[3];
    Float* ghostGaugeField;

    switch (dir) {
    case 1:
      { //-X direction
        int new_x1 = (x1 - d + X1 )% X1;
        // Crossing a partitioned -X boundary: read the link from the
        // ghost buffer of direction 0 (note the parity swap for odd hops).
        if (x1 -d < 0 && comm_dim_partitioned(0)){
          ghostGaugeField = (oddBit?ghostGaugeEven[0]: ghostGaugeOdd[0]);
          // Face-local half index: slab (n_ghost_faces + x1 - d), then the
          // linearized position within the X-face, halved for parity.
          int offset = (n_ghost_faces + x1 -d)*X4*X3*X2/2 + (x4*X3*X2 + x3*X2+x2)/2;
          return &ghostGaugeField[offset*(3*3*2)];
        }
        j = (x4*X3*X2*X1 + x3*X2*X1 + x2*X1 + new_x1) / 2;
        break;
      }
    case 3:
      { //-Y direction
        int new_x2 = (x2 - d + X2 )% X2;
        if (x2 -d < 0 && comm_dim_partitioned(1)){
          ghostGaugeField = (oddBit?ghostGaugeEven[1]: ghostGaugeOdd[1]);
          int offset = (n_ghost_faces + x2 -d)*X4*X3*X1/2 + (x4*X3*X1 + x3*X1+x1)/2;
          return &ghostGaugeField[offset*(3*3*2)];
        }
        j = (x4*X3*X2*X1 + x3*X2*X1 + new_x2*X1 + x1) / 2;
        break;

      }
    case 5:
      { //-Z direction
        int new_x3 = (x3 - d + X3 )% X3;
        if (x3 -d < 0 && comm_dim_partitioned(2)){
          ghostGaugeField = (oddBit?ghostGaugeEven[2]: ghostGaugeOdd[2]);
          int offset = (n_ghost_faces + x3 -d)*X4*X2*X1/2 + (x4*X2*X1 + x2*X1+x1)/2;
          return &ghostGaugeField[offset*(3*3*2)];
        }
        j = (x4*X3*X2*X1 + new_x3*X2*X1 + x2*X1 + x1) / 2;
        break;
      }
    case 7:
      { //-T direction
        int new_x4 = (x4 - d + X4)% X4;
        if (x4 -d < 0 && comm_dim_partitioned(3)){
          ghostGaugeField = (oddBit?ghostGaugeEven[3]: ghostGaugeOdd[3]);
          int offset = (n_ghost_faces + x4 -d)*X1*X2*X3/2 + (x3*X2*X1 + x2*X1+x1)/2;
          return &ghostGaugeField[offset*(3*3*2)];
        }
        j = (new_x4*(X3*X2*X1) + x3*(X2*X1) + x2*(X1) + x1) / 2;
        break;
      }//7

    default: j = -1; printf("ERROR: wrong dir \n"); exit(1);
    }
    // Backward hop that stayed on the local lattice: opposite parity.
    gaugeField = (oddBit ? gaugeEven : gaugeOdd);

  }

  return &gaugeField[dir/2][j*(3*3*2)];
}
164 
165 
166 //J Directions 0..7 were used in the 4d code.
167 //J Directions 8,9 will be for P_- and P_+, chiral
168 //J projectors.
// Spin projection matrices, indexed as projector[projIdx][row][col][re/im].
// Entries 0..7 come in pairs (2*mu, 2*mu+1) for the projectors used by the
// forward/backward hops in direction mu; the pair member is selected at the
// call sites via projIdx = 2*(dir/2) + (dir+daggerBit)%2, so the dagger
// swaps the two. Entries 8 and 9 (used by the 5th-dimension hops) are twice
// the chiral projectors: 2*P_+ and 2*P_- (note the factor-2 diagonal).
const double projector[10][4][4][2] = {
  // dir pair mu=0 (x), forward member
  {
    {{1,0}, {0,0}, {0,0}, {0,-1}},
    {{0,0}, {1,0}, {0,-1}, {0,0}},
    {{0,0}, {0,1}, {1,0}, {0,0}},
    {{0,1}, {0,0}, {0,0}, {1,0}}
  },
  // dir pair mu=0 (x), backward member
  {
    {{1,0}, {0,0}, {0,0}, {0,1}},
    {{0,0}, {1,0}, {0,1}, {0,0}},
    {{0,0}, {0,-1}, {1,0}, {0,0}},
    {{0,-1}, {0,0}, {0,0}, {1,0}}
  },
  // dir pair mu=1 (y)
  {
    {{1,0}, {0,0}, {0,0}, {1,0}},
    {{0,0}, {1,0}, {-1,0}, {0,0}},
    {{0,0}, {-1,0}, {1,0}, {0,0}},
    {{1,0}, {0,0}, {0,0}, {1,0}}
  },
  {
    {{1,0}, {0,0}, {0,0}, {-1,0}},
    {{0,0}, {1,0}, {1,0}, {0,0}},
    {{0,0}, {1,0}, {1,0}, {0,0}},
    {{-1,0}, {0,0}, {0,0}, {1,0}}
  },
  // dir pair mu=2 (z)
  {
    {{1,0}, {0,0}, {0,-1}, {0,0}},
    {{0,0}, {1,0}, {0,0}, {0,1}},
    {{0,1}, {0,0}, {1,0}, {0,0}},
    {{0,0}, {0,-1}, {0,0}, {1,0}}
  },
  {
    {{1,0}, {0,0}, {0,1}, {0,0}},
    {{0,0}, {1,0}, {0,0}, {0,-1}},
    {{0,-1}, {0,0}, {1,0}, {0,0}},
    {{0,0}, {0,1}, {0,0}, {1,0}}
  },
  // dir pair mu=3 (t)
  {
    {{1,0}, {0,0}, {-1,0}, {0,0}},
    {{0,0}, {1,0}, {0,0}, {-1,0}},
    {{-1,0}, {0,0}, {1,0}, {0,0}},
    {{0,0}, {-1,0}, {0,0}, {1,0}}
  },
  {
    {{1,0}, {0,0}, {1,0}, {0,0}},
    {{0,0}, {1,0}, {0,0}, {1,0}},
    {{1,0}, {0,0}, {1,0}, {0,0}},
    {{0,0}, {1,0}, {0,0}, {1,0}}
  },
  // index 8: 2 * (P_+ = P_R) — lower two spin components
  {
    {{0,0}, {0,0}, {0,0}, {0,0}},
    {{0,0}, {0,0}, {0,0}, {0,0}},
    {{0,0}, {0,0}, {2,0}, {0,0}},
    {{0,0}, {0,0}, {0,0}, {2,0}}
  },
  // index 9: 2 * (P_- = P_L) — upper two spin components
  {
    {{2,0}, {0,0}, {0,0}, {0,0}},
    {{0,0}, {2,0}, {0,0}, {0,0}},
    {{0,0}, {0,0}, {0,0}, {0,0}},
    {{0,0}, {0,0}, {0,0}, {0,0}}
  }
};
233 
234 
235 // todo pass projector
236 template <typename Float>
237 void multiplySpinorByDiracProjector5(Float *res, int projIdx, Float *spinorIn) {
238  for (int i=0; i<4*3*2; i++) res[i] = 0.0;
239 
240  for (int s = 0; s < 4; s++) {
241  for (int t = 0; t < 4; t++) {
242  Float projRe = projector[projIdx][s][t][0];
243  Float projIm = projector[projIdx][s][t][1];
244 
245  for (int m = 0; m < 3; m++) {
246  Float spinorRe = spinorIn[t*(3*2) + m*(2) + 0];
247  Float spinorIm = spinorIn[t*(3*2) + m*(2) + 1];
248  res[s*(3*2) + m*(2) + 0] += projRe*spinorRe - projIm*spinorIm;
249  res[s*(3*2) + m*(2) + 1] += projRe*spinorIm + projIm*spinorRe;
250  }
251  }
252  }
253 }
254 
255 
256 //#ifndef MULTI_GPU
257 // dslashReference_4d()
258 //J This is just the 4d wilson dslash of quda code, with a
259 //J few small changes to take into account that the spinors
260 //J are 5d and the gauge fields are 4d.
261 //
262 // if oddBit is zero: calculate odd parity spinor elements (using even parity spinor)
263 // if oddBit is one: calculate even parity spinor elements
264 //
265 // if daggerBit is zero: perform ordinary dslash operator
266 // if daggerBit is one: perform hermitian conjugate of dslash
267 //
268 //An "ok" will only be granted once check2.tex is deemed complete,
269 //since the logic in this function is important and nontrivial.
// Single-GPU 4d part of the domain-wall dslash: a Wilson-style stencil
// applied independently on each 5th-dimension slice of a 5d half-spinor.
// res is zeroed here and accumulated over the 8 spacetime hop directions.
//   oddBit    parity of the OUTPUT sites
//   daggerBit 1 applies the hermitian conjugate (swaps projector pairs)
template <QudaDWFPCType type, typename sFloat, typename gFloat>
void dslashReference_4d_sgpu(sFloat *res, gFloat **gaugeFull, sFloat *spinorField,
 int oddBit, int daggerBit) {

  // Initialize the return half-spinor to zero. Note that it is a
  // 5d spinor, hence the use of V5h.
  for (int i=0; i<V5h*4*3*2; i++) res[i] = 0.0;

  // Some pointers that we use to march through arrays.
  gFloat *gaugeEven[4], *gaugeOdd[4];
  // Initialize to beginning of even and odd parts of
  // gauge array.
  for (int dir = 0; dir < 4; dir++) {
    gaugeEven[dir] = gaugeFull[dir];
    // Note the use of Vh here, since the gauge fields
    // are 4-dim'l.
    gaugeOdd[dir] = gaugeFull[dir]+Vh*gaugeSiteSize;
  }
  int sp_idx,gaugeOddBit;
  for (int xs=0;xs<Ls;xs++) {
    for (int gge_idx = 0; gge_idx < Vh; gge_idx++) {
      for (int dir = 0; dir < 8; dir++) {
        sp_idx=gge_idx+Vh*xs;
        // Here is a function call to study. It is defined near
        // Line 90 of this file.
        // Here we have to switch oddBit depending on the value of xs. E.g., suppose
        // xs=1. Then the odd spinor site x1=x2=x3=x4=0 wants the even gauge array
        // element 0, so that we get U_\mu(0). For 5d preconditioning the
        // gauge parity therefore alternates with xs; 4d preconditioning
        // keeps the spinor parity.
        gaugeOddBit = (xs%2 == 0 || type == QUDA_4D_PC) ? oddBit : (oddBit+1) % 2;
        gFloat *gauge = gaugeLink_sgpu(gge_idx, dir, gaugeOddBit, gaugeEven, gaugeOdd);

        // Even though we're doing the 4d part of the dslash, we need
        // to use a 5d neighbor function, to get the offsets right.
        sFloat *spinor = spinorNeighbor_5d<type>(sp_idx, dir, oddBit, spinorField);
        sFloat projectedSpinor[4*3*2], gaugedSpinor[4*3*2];
        // Pick the forward/backward projector of pair dir/2; daggerBit
        // swaps the two members of the pair.
        int projIdx = 2*(dir/2)+(dir+daggerBit)%2;
        multiplySpinorByDiracProjector5(projectedSpinor, projIdx, spinor);

        for (int s = 0; s < 4; s++) {
          if (dir % 2 == 0) {
            // Forward hop: multiply by U_mu(x).
            su3Mul(&gaugedSpinor[s*(3*2)], gauge, &projectedSpinor[s*(3*2)]);
#ifdef DBUG_VERBOSE
            std::cout << "spinor:" << std::endl;
            printSpinorElement(&projectedSpinor[s*(3*2)],0,QUDA_DOUBLE_PRECISION);
            std::cout << "gauge:" << std::endl;
#endif
          } else {
            // Backward hop: multiply by U_mu(x-muhat)^dagger.
            su3Tmul(&gaugedSpinor[s*(3*2)], gauge, &projectedSpinor[s*(3*2)]);
          }
        }

        // Accumulate this direction's contribution into the output site.
        sum(&res[sp_idx*(4*3*2)], &res[sp_idx*(4*3*2)], gaugedSpinor, 4*3*2);
      }
    }
  }
}
326 
327 #ifdef MULTI_GPU
// Multi-GPU version of the 4d part of the domain-wall dslash: same stencil
// as dslashReference_4d_sgpu(), but neighbor spinors and links that live
// across a partitioned boundary are read from the ghost buffers
// (ghostGauge, fwdSpinor, backSpinor) filled by a prior ghost exchange.
template <QudaDWFPCType type, typename sFloat, typename gFloat>
void dslashReference_4d_mgpu(sFloat *res, gFloat **gaugeFull, gFloat **ghostGauge, sFloat *spinorField,
 sFloat **fwdSpinor, sFloat **backSpinor, int oddBit, int daggerBit)
{
  int mySpinorSiteSize = 24;
  // Zero the 5d output half-spinor before accumulating.
  for (int i=0; i<V5h*mySpinorSiteSize; i++) res[i] = 0.0;

  gFloat *gaugeEven[4], *gaugeOdd[4];
  gFloat *ghostGaugeEven[4], *ghostGaugeOdd[4];

  for (int dir = 0; dir < 4; dir++)
  {
    gaugeEven[dir] = gaugeFull[dir];
    gaugeOdd[dir] = gaugeFull[dir]+Vh*gaugeSiteSize;

    // Ghost links are packed per direction, even parity half first.
    ghostGaugeEven[dir] = ghostGauge[dir];
    ghostGaugeOdd[dir] = ghostGauge[dir] + (faceVolume[dir]/2)*gaugeSiteSize;
  }
  for (int xs=0;xs<Ls;xs++)
  {
    int sp_idx;
    for (int i = 0; i < Vh; i++)
    {
      sp_idx = i + Vh*xs;
      for (int dir = 0; dir < 8; dir++)
      {
        // Gauge parity alternates with xs for 5d preconditioning; it
        // matches the spinor parity for 4d preconditioning.
        int gaugeOddBit = (xs%2 == 0 || type == QUDA_4D_PC) ? oddBit : (oddBit + 1) % 2;

        gFloat *gauge = gaugeLink_mgpu(i, dir, gaugeOddBit, gaugeEven, gaugeOdd, ghostGaugeEven, ghostGaugeOdd, 1, 1);//this is unchanged from MPi version
        sFloat *spinor = spinorNeighbor_5d_mgpu<type>(sp_idx, dir, oddBit, spinorField, fwdSpinor, backSpinor, 1, 1);

        sFloat projectedSpinor[mySpinorSiteSize], gaugedSpinor[mySpinorSiteSize];
        // Forward/backward projector of pair dir/2; daggerBit swaps them.
        int projIdx = 2*(dir/2)+(dir+daggerBit)%2;
        multiplySpinorByDiracProjector5(projectedSpinor, projIdx, spinor);

        for (int s = 0; s < 4; s++)
        {
          // Forward hops multiply by U, backward hops by U^dagger.
          if (dir % 2 == 0) su3Mul(&gaugedSpinor[s*(3*2)], gauge, &projectedSpinor[s*(3*2)]);
          else su3Tmul(&gaugedSpinor[s*(3*2)], gauge, &projectedSpinor[s*(3*2)]);
        }
        sum(&res[sp_idx*(4*3*2)], &res[sp_idx*(4*3*2)], gaugedSpinor, 4*3*2);
      }
    }
  }
}
373 #endif
374 
375 //Currently we consider only spacetime decomposition (not in 5th dim), so this operator is local
// Apply the 5th-dimension hopping term of the domain-wall operator.
// dir==8 is the forward hop in s (projector index 8), dir==9 the backward
// hop (index 9); daggerBit swaps them via the projIdx formula. At the
// chiral walls (s==0 hopping backward, s==Ls-1 hopping forward) the
// wrapped-around contribution is weighted by -mferm.
// zero_initialize selects whether res is cleared before accumulation.
template <QudaDWFPCType type, bool zero_initialize=false, typename sFloat>
void dslashReference_5th(sFloat *res, sFloat *spinorField,
 int oddBit, int daggerBit, sFloat mferm) {
  for (int i = 0; i < V5h; i++) {
    // Optionally zero the 24 reals of this output site first.
    if (zero_initialize) for(int one_site = 0 ; one_site < 24 ; one_site++)
      res[i*(4*3*2)+one_site] = 0.0;
    for (int dir = 8; dir < 10; dir++) {
      // Calls for an extension of the original function.
      // 8 is forward hop, which wants P_+, 9 is backward hop,
      // which wants P_-. Dagger reverses these.
      sFloat *spinor = spinorNeighbor_5d<type>(i, dir, oddBit, spinorField);
      sFloat projectedSpinor[4*3*2];
      int projIdx = 2*(dir/2)+(dir+daggerBit)%2;
      multiplySpinorByDiracProjector5(projectedSpinor, projIdx, spinor);
      //J Need a conditional here for s=0 and s=Ls-1.
      // Recover the 5th coordinate xs from the full 5d lattice index
      // (indexing differs between 5d and 4d preconditioning).
      int X = (type == QUDA_5D_PC) ? fullLatticeIndex_5d(i, oddBit) : fullLatticeIndex_5d_4dpc(i, oddBit);
      int xs = X/(Z[3]*Z[2]*Z[1]*Z[0]);

      // Chiral boundary condition: scale the wall-crossing hop by -mferm.
      if ( (xs == 0 && dir == 9) || (xs == Ls-1 && dir == 8) ) {
        ax(projectedSpinor,(sFloat)(-mferm),projectedSpinor,4*3*2);
      }
      sum(&res[i*(4*3*2)], &res[i*(4*3*2)], projectedSpinor, 4*3*2);
    }
  }
}
401 
402 //Currently we consider only spacetime decomposition (not in 5th dim), so this operator is local
403 template <typename sFloat>
404 void dslashReference_5th_inv(sFloat *res, sFloat *spinorField,
405  int oddBit, int daggerBit, sFloat mferm, double *kappa) {
406  double *inv_Ftr = (double*)malloc(Ls*sizeof(sFloat));
407  double *Ftr = (double*)malloc(Ls*sizeof(sFloat));
408  for(int xs = 0 ; xs < Ls ; xs++)
409  {
410  inv_Ftr[xs] = 1.0/(1.0+pow(2.0*kappa[xs], Ls)*mferm);
411  Ftr[xs] = -2.0*kappa[xs]*mferm*inv_Ftr[xs];
412  for (int i = 0; i < Vh; i++) {
413  memcpy(&res[24*(i+Vh*xs)], &spinorField[24*(i+Vh*xs)], 24*sizeof(sFloat));
414  }
415  }
416  if(daggerBit == 0)
417  {
418  // s = 0
419  for (int i = 0; i < Vh; i++) {
420  ax(&res[12+24*(i+Vh*(Ls-1))],(sFloat)(inv_Ftr[0]), &spinorField[12+24*(i+Vh*(Ls-1))], 12);
421  }
422 
423  // s = 1 ... ls-2
424  for(int xs = 0 ; xs <= Ls-2 ; ++xs)
425  {
426  for (int i = 0; i < Vh; i++) {
427  axpy((sFloat)(2.0*kappa[xs]), &res[24*(i+Vh*xs)], &res[24*(i+Vh*(xs+1))], 12);
428  axpy((sFloat)Ftr[xs], &res[12+24*(i+Vh*xs)], &res[12+24*(i+Vh*(Ls-1))], 12);
429  }
430  for (int tmp_s = 0 ; tmp_s < Ls ; tmp_s++)
431  Ftr[tmp_s] *= 2.0*kappa[tmp_s];
432  }
433  for(int xs = 0 ; xs < Ls ; xs++)
434  {
435  Ftr[xs] = -pow(2.0*kappa[xs],Ls-1)*mferm*inv_Ftr[xs];
436  }
437  // s = ls-2 ... 0
438  for(int xs = Ls-2 ; xs >=0 ; --xs)
439  {
440  for (int i = 0; i < Vh; i++) {
441  axpy((sFloat)Ftr[xs], &res[24*(i+Vh*(Ls-1))], &res[24*(i+Vh*xs)], 12);
442  axpy((sFloat)(2.0*kappa[xs]), &res[12+24*(i+Vh*(xs+1))], &res[12+24*(i+Vh*xs)], 12);
443  }
444  for (int tmp_s = 0 ; tmp_s < Ls ; tmp_s++)
445  Ftr[tmp_s] /= 2.0*kappa[tmp_s];
446  }
447  // s = ls -1
448  for (int i = 0; i < Vh; i++) {
449  ax(&res[24*(i+Vh*(Ls-1))], (sFloat)(inv_Ftr[Ls-1]), &res[24*(i+Vh*(Ls-1))], 12);
450  }
451  }
452  else
453  {
454  // s = 0
455  for (int i = 0; i < Vh; i++) {
456  ax(&res[24*(i+Vh*(Ls-1))],(sFloat)(inv_Ftr[0]), &spinorField[24*(i+Vh*(Ls-1))], 12);
457  }
458 
459  // s = 1 ... ls-2
460  for(int xs = 0 ; xs <= Ls-2 ; ++xs)
461  {
462  for (int i = 0; i < Vh; i++) {
463  axpy((sFloat)Ftr[xs], &res[24*(i+Vh*xs)], &res[24*(i+Vh*(Ls-1))], 12);
464  axpy((sFloat)(2.0*kappa[xs]), &res[12+24*(i+Vh*xs)], &res[12+24*(i+Vh*(xs+1))], 12);
465  }
466  for (int tmp_s = 0 ; tmp_s < Ls ; tmp_s++)
467  Ftr[tmp_s] *= 2.0*kappa[tmp_s];
468  }
469  for(int xs = 0 ; xs < Ls ; xs++)
470  {
471  Ftr[xs] = -pow(2.0*kappa[xs],Ls-1)*mferm*inv_Ftr[xs];
472  }
473  // s = ls-2 ... 0
474  for(int xs = Ls-2 ; xs >=0 ; --xs)
475  {
476  for (int i = 0; i < Vh; i++) {
477  axpy((sFloat)(2.0*kappa[xs]), &res[24*(i+Vh*(xs+1))], &res[24*(i+Vh*xs)], 12);
478  axpy((sFloat)Ftr[xs], &res[12+24*(i+Vh*(Ls-1))], &res[12+24*(i+Vh*xs)], 12);
479  }
480  for (int tmp_s = 0 ; tmp_s < Ls ; tmp_s++)
481  Ftr[tmp_s] /= 2.0*kappa[tmp_s];
482  }
483  // s = ls -1
484  for (int i = 0; i < Vh; i++) {
485  ax(&res[12+24*(i+Vh*(Ls-1))], (sFloat)(inv_Ftr[Ls-1]), &res[12+24*(i+Vh*(Ls-1))], 12);
486  }
487  }
488  free(inv_Ftr);
489  free(Ftr);
490 }
491 
492 
493 // this actually applies the preconditioned dslash, e.g., D_ee^{-1} D_eo or D_oo^{-1} D_oe
// Apply the 4d hopping term plus the 5th-dimension hopping term of the
// (5d-preconditioned) domain-wall dslash at the requested precision.
//   oddBit    parity of the output sites
//   daggerBit 1 applies the hermitian conjugate
// In the multi-GPU build the input spinor is wrapped in a cpuColorSpinorField
// so its ghost faces can be exchanged before the stencil is applied.
void dw_dslash(void *out, void **gauge, void *in, int oddBit, int daggerBit, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm)
{
#ifndef MULTI_GPU
  if (precision == QUDA_DOUBLE_PRECISION) {
    dslashReference_4d_sgpu<QUDA_5D_PC>((double*)out, (double**)gauge, (double*)in, oddBit, daggerBit);
    dslashReference_5th<QUDA_5D_PC>((double*)out, (double*)in, oddBit, daggerBit, mferm);
  } else {
    dslashReference_4d_sgpu<QUDA_5D_PC>((float*)out, (float**)gauge, (float*)in, oddBit, daggerBit);
    dslashReference_5th<QUDA_5D_PC>((float*)out, (float*)in, oddBit, daggerBit, (float)mferm);
  }
#else

  // Wrap the raw gauge field so the ghost links are available.
  GaugeFieldParam gauge_field_param(gauge, gauge_param);
  gauge_field_param.ghostExchange = QUDA_GHOST_EXCHANGE_PAD;
  cpuGaugeField cpu(gauge_field_param);
  void **ghostGauge = (void**)cpu.Ghost();

  // Get spinor ghost fields
  // First wrap the input spinor into a ColorSpinorField
  // NOTE(review): the declaration of csParam (a ColorSpinorParam) and a few
  // of its field assignments are missing from this extracted copy
  // (original lines 513, 522, 524-528) — confirm against the upstream file.
  csParam.v = in;
  csParam.nColor = 3;
  csParam.nSpin = 4;
  csParam.nDim = 5; //for DW dslash
  for (int d=0; d<4; d++) csParam.x[d] = Z[d];
  csParam.x[4] = Ls;//5th dimention
  csParam.precision = precision;
  csParam.pad = 0;
  csParam.x[0] /= 2;

  cpuColorSpinorField inField(csParam);

  { // Now do the exchange
    QudaParity otherParity = QUDA_INVALID_PARITY;
    if (oddBit == QUDA_EVEN_PARITY) otherParity = QUDA_ODD_PARITY;
    else if (oddBit == QUDA_ODD_PARITY) otherParity = QUDA_EVEN_PARITY;
    else errorQuda("ERROR: full parity not supported in function %s", __FUNCTION__);
    const int nFace = 1;

    inField.exchangeGhost(otherParity, nFace, daggerBit);
  }
  void** fwd_nbr_spinor = inField.fwdGhostFaceBuffer;
  void** back_nbr_spinor = inField.backGhostFaceBuffer;
  //NOTE: hopping in 5th dimension does not use MPI.
  if (precision == QUDA_DOUBLE_PRECISION) {
    dslashReference_4d_mgpu<QUDA_5D_PC>((double*)out, (double**)gauge, (double**)ghostGauge, (double*)in,(double**)fwd_nbr_spinor, (double**)back_nbr_spinor, oddBit, daggerBit);
    //dslashReference_4d_sgpu<QUDA_5D_PC>((double*)out, (double**)gauge, (double*)in, oddBit, daggerBit);
    dslashReference_5th<QUDA_5D_PC>((double*)out, (double*)in, oddBit, daggerBit, mferm);
  } else {
    dslashReference_4d_mgpu<QUDA_5D_PC>((float*)out, (float**)gauge, (float**)ghostGauge, (float*)in,
        (float**)fwd_nbr_spinor, (float**)back_nbr_spinor, oddBit, daggerBit);
    dslashReference_5th<QUDA_5D_PC>((float*)out, (float*)in, oddBit, daggerBit, (float)mferm);
  }

#endif

}
557 
// Apply only the 4d hopping term of the 4d-preconditioned domain-wall
// dslash at the requested precision (no 5th-dimension term here).
// In the multi-GPU build the input spinor's ghost faces are exchanged
// before applying the stencil.
void dslash_4_4d(void *out, void **gauge, void *in, int oddBit, int daggerBit, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm)
{
#ifndef MULTI_GPU
  if (precision == QUDA_DOUBLE_PRECISION) {
    dslashReference_4d_sgpu<QUDA_4D_PC>((double*)out, (double**)gauge, (double*)in, oddBit, daggerBit);
  } else {
    dslashReference_4d_sgpu<QUDA_4D_PC>((float*)out, (float**)gauge, (float*)in, oddBit, daggerBit);
  }
#else

  // Wrap the raw gauge field so the ghost links are available.
  GaugeFieldParam gauge_field_param(gauge, gauge_param);
  gauge_field_param.ghostExchange = QUDA_GHOST_EXCHANGE_PAD;
  cpuGaugeField cpu(gauge_field_param);
  void **ghostGauge = (void**)cpu.Ghost();

  // Get spinor ghost fields
  // First wrap the input spinor into a ColorSpinorField
  // NOTE(review): the declaration of csParam (a ColorSpinorParam) and a few
  // of its field assignments are missing from this extracted copy
  // (original lines 575, 584, 586-590) — confirm against the upstream file.
  csParam.v = in;
  csParam.nColor = 3;
  csParam.nSpin = 4;
  csParam.nDim = 5; //for DW dslash
  for (int d=0; d<4; d++) csParam.x[d] = Z[d];
  csParam.x[4] = Ls;//5th dimention
  csParam.precision = precision;
  csParam.pad = 0;
  csParam.x[0] /= 2;

  cpuColorSpinorField inField(csParam);

  { // Now do the exchange
    QudaParity otherParity = QUDA_INVALID_PARITY;
    if (oddBit == QUDA_EVEN_PARITY) otherParity = QUDA_ODD_PARITY;
    else if (oddBit == QUDA_ODD_PARITY) otherParity = QUDA_EVEN_PARITY;
    else errorQuda("ERROR: full parity not supported in function %s", __FUNCTION__);
    const int nFace = 1;

    inField.exchangeGhost(otherParity, nFace, daggerBit);
  }
  void** fwd_nbr_spinor = inField.fwdGhostFaceBuffer;
  void** back_nbr_spinor = inField.backGhostFaceBuffer;
  if (precision == QUDA_DOUBLE_PRECISION) {
    dslashReference_4d_mgpu<QUDA_4D_PC>((double*)out, (double**)gauge, (double**)ghostGauge, (double*)in,(double**)fwd_nbr_spinor, (double**)back_nbr_spinor, oddBit, daggerBit);
  } else {
    dslashReference_4d_mgpu<QUDA_4D_PC>((float*)out, (float**)gauge, (float**)ghostGauge, (float*)in,
        (float**)fwd_nbr_spinor, (float**)back_nbr_spinor, oddBit, daggerBit);
  }

#endif

}
615 
616 void dw_dslash_5_4d(void *out, void **gauge, void *in, int oddBit, int daggerBit, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm, bool zero_initialize)
617 {
618  if (precision == QUDA_DOUBLE_PRECISION) {
619  if (zero_initialize) dslashReference_5th<QUDA_4D_PC, true>((double*)out, (double*)in, oddBit, daggerBit, mferm);
620  else dslashReference_5th<QUDA_4D_PC, false>((double*)out, (double*)in, oddBit, daggerBit, mferm);
621  } else {
622  if (zero_initialize) dslashReference_5th<QUDA_4D_PC, true>((float*)out, (float*)in, oddBit, daggerBit, (float)mferm);
623  else dslashReference_5th<QUDA_4D_PC, false>((float*)out, (float*)in, oddBit, daggerBit, (float)mferm);
624  }
625 }
626 
627 void dslash_5_inv(void *out, void **gauge, void *in, int oddBit, int daggerBit, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm, double *kappa)
628 {
629  if (precision == QUDA_DOUBLE_PRECISION) {
630  dslashReference_5th_inv((double*)out, (double*)in, oddBit, daggerBit, mferm, kappa);
631  } else {
632  dslashReference_5th_inv((float*)out, (float*)in, oddBit, daggerBit, (float)mferm, kappa);
633  }
634 }
635 
636 void mdw_dslash_5(void *out, void **gauge, void *in, int oddBit, int daggerBit, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm, double *kappa, bool zero_initialize)
637 {
638  if (precision == QUDA_DOUBLE_PRECISION) {
639  if (zero_initialize) dslashReference_5th<QUDA_4D_PC,true>((double*)out, (double*)in, oddBit, daggerBit, mferm);
640  else dslashReference_5th<QUDA_4D_PC,false>((double*)out, (double*)in, oddBit, daggerBit, mferm);
641  } else {
642  if (zero_initialize) dslashReference_5th<QUDA_4D_PC,true>((float*)out, (float*)in, oddBit, daggerBit, (float)mferm);
643  else dslashReference_5th<QUDA_4D_PC,false>((float*)out, (float*)in, oddBit, daggerBit, (float)mferm);
644  }
645  for(int xs = 0 ; xs < Ls ; xs++) {
646  xpay((char*)in + precision*Vh*spinorSiteSize*xs, kappa[xs], (char*)out + precision*Vh*spinorSiteSize*xs, Vh*spinorSiteSize, precision);
647  }
648 }
649 
650 void mdw_dslash_4_pre(void *out, void **gauge, void *in, int oddBit, int daggerBit, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm, double *b5, double *c5, bool zero_initialize)
651 {
652  if (precision == QUDA_DOUBLE_PRECISION) {
653  if (zero_initialize) dslashReference_5th<QUDA_4D_PC, true>((double*)out, (double*)in, oddBit, daggerBit, mferm);
654  else dslashReference_5th<QUDA_4D_PC, false>((double*)out, (double*)in, oddBit, daggerBit, mferm);
655  for(int xs = 0 ; xs < Ls ; xs++)
656  {
657  axpby(b5[xs],(double*)in + Vh*spinorSiteSize*xs,0.5*c5[xs], (double*)out + Vh*spinorSiteSize*xs, Vh*spinorSiteSize);
658  }
659  } else {
660  if (zero_initialize) dslashReference_5th<QUDA_4D_PC, true>((float*)out, (float*)in, oddBit, daggerBit, (float)mferm);
661  else dslashReference_5th<QUDA_4D_PC,false>((float*)out, (float*)in, oddBit, daggerBit, (float)mferm);
662  for(int xs = 0 ; xs < Ls ; xs++)
663  {
664  axpby((float)(b5[xs]),(float*)in + Vh*spinorSiteSize*xs, (float)(0.5*c5[xs]), (float*)out + Vh*spinorSiteSize*xs, Vh*spinorSiteSize);
665  }
666  }
667 
668 }
669 
670 void dw_mat(void *out, void **gauge, void *in, double kappa, int dagger_bit, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm) {
671 
672  void *inEven = in;
673  void *inOdd = (char*)in + V5h*spinorSiteSize*precision;
674  void *outEven = out;
675  void *outOdd = (char*)out + V5h*spinorSiteSize*precision;
676 
677  dw_dslash(outOdd, gauge, inEven, 1, dagger_bit, precision, gauge_param, mferm);
678  dw_dslash(outEven, gauge, inOdd, 0, dagger_bit, precision, gauge_param, mferm);
679 
680  // lastly apply the kappa term
681  xpay(in, -kappa, out, V5*spinorSiteSize, precision);
682 }
683 
684 void dw_4d_mat(void *out, void **gauge, void *in, double kappa, int dagger_bit, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm) {
685 
686  void *inEven = in;
687  void *inOdd = (char*)in + V5h*spinorSiteSize*precision;
688  void *outEven = out;
689  void *outOdd = (char*)out + V5h*spinorSiteSize*precision;
690 
691  dslash_4_4d(outOdd, gauge, inEven, 1, dagger_bit, precision, gauge_param, mferm);
692  dw_dslash_5_4d(outOdd, gauge, inOdd, 1, dagger_bit, precision, gauge_param, mferm, false);
693 
694  dslash_4_4d(outEven, gauge, inOdd, 0, dagger_bit, precision, gauge_param, mferm);
695  dw_dslash_5_4d(outEven, gauge, inEven, 0, dagger_bit, precision, gauge_param, mferm, false);
696 
697  // lastly apply the kappa term
698  xpay(in, -kappa, out, V5*spinorSiteSize, precision);
699 }
700 
701 void mdw_mat(void *out, void **gauge, void *in, double *kappa_b, double *kappa_c, int dagger, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm, double *b5, double *c5) {
702 
703  void *tmp = malloc(V5h*spinorSiteSize*precision);
704  double *kappa5 = (double*)malloc(Ls*sizeof(double));
705 
706  for(int xs = 0; xs < Ls ; xs++) kappa5[xs] = 0.5*kappa_b[xs]/kappa_c[xs];
707 
708  void *inEven = in;
709  void *inOdd = (char*)in + V5h*spinorSiteSize*precision;
710  void *outEven = out;
711  void *outOdd = (char*)out + V5h*spinorSiteSize*precision;
712 
713  mdw_dslash_4_pre(tmp, gauge, inEven, 0, dagger, precision, gauge_param, mferm, b5, c5, true);
714  dslash_4_4d(outOdd, gauge, tmp, 1, dagger, precision, gauge_param, mferm);
715  mdw_dslash_5(tmp, gauge, inOdd, 1, dagger, precision, gauge_param, mferm, kappa5, true);
716 
717  for(int xs = 0 ; xs < Ls ; xs++) {
718  xpay((char*)tmp + precision*Vh*spinorSiteSize*xs, -kappa_b[xs], (char*)outOdd + precision*Vh*spinorSiteSize*xs,
719  Vh*spinorSiteSize, precision);
720  }
721 
722  mdw_dslash_4_pre(tmp, gauge, inOdd, 1, dagger, precision, gauge_param, mferm, b5, c5, true);
723  dslash_4_4d(outEven, gauge, tmp, 0, dagger, precision, gauge_param, mferm);
724  mdw_dslash_5(tmp, gauge, inEven, 0, dagger, precision, gauge_param, mferm, kappa5, true);
725 
726  for(int xs = 0 ; xs < Ls ; xs++) {
727  xpay((char*)tmp + precision*Vh*spinorSiteSize*xs, -kappa_b[xs], (char*)outEven + precision*Vh*spinorSiteSize*xs,
728  Vh*spinorSiteSize, precision);
729  }
730 
731  free(kappa5);
732  free(tmp);
733 }
734 
735 //
736 void dw_matdagmat(void *out, void **gauge, void *in, double kappa, int dagger_bit, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm)
737 {
738 
739  void *tmp = malloc(V5*spinorSiteSize*precision);
740  dw_mat(tmp, gauge, in, kappa, dagger_bit, precision, gauge_param, mferm);
741  dagger_bit = (dagger_bit == 1) ? 0 : 1;
742  dw_mat(out, gauge, tmp, kappa, dagger_bit, precision, gauge_param, mferm);
743 
744  free(tmp);
745 }
746 
747 void dw_matpc(void *out, void **gauge, void *in, double kappa, QudaMatPCType matpc_type, int dagger_bit, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm)
748 {
749  void *tmp = malloc(V5h*spinorSiteSize*precision);
750 
752  dw_dslash(tmp, gauge, in, 1, dagger_bit, precision, gauge_param, mferm);
753  dw_dslash(out, gauge, tmp, 0, dagger_bit, precision, gauge_param, mferm);
754  } else {
755  dw_dslash(tmp, gauge, in, 0, dagger_bit, precision, gauge_param, mferm);
756  dw_dslash(out, gauge, tmp, 1, dagger_bit, precision, gauge_param, mferm);
757  }
758 
759  // lastly apply the kappa term
760  double kappa2 = -kappa*kappa;
761  xpay(in, kappa2, out, V5h*spinorSiteSize, precision);
762 
763  free(tmp);
764 }
765 
766 
767 void dw_4d_matpc(void *out, void **gauge, void *in, double kappa, QudaMatPCType matpc_type, int dagger_bit, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm)
768 {
769  double kappa2 = -kappa*kappa;
770  double *kappa5 = (double*)malloc(Ls*sizeof(double));
771  for(int xs = 0; xs < Ls ; xs++)
772  kappa5[xs] = kappa;
773  void *tmp = malloc(V5h*spinorSiteSize*precision);
774  //------------------------------------------
775  double *output = (double*)out;
776  for(int k = 0 ; k< V5h*spinorSiteSize; k++)
777  output[k] = 0.0;
778  //------------------------------------------
779 
781  bool symmetric =(matpc_type == QUDA_MATPC_EVEN_EVEN || matpc_type == QUDA_MATPC_ODD_ODD) ? true : false;
782  QudaParity parity[2] = {static_cast<QudaParity>((1 + odd_bit) % 2), static_cast<QudaParity>((0 + odd_bit) % 2)};
783 
784  if (symmetric && !dagger_bit) {
785  dslash_4_4d(tmp, gauge, in, parity[0], dagger_bit, precision, gauge_param, mferm);
786  dslash_5_inv(out, gauge, tmp, parity[0], dagger_bit, precision, gauge_param, mferm, kappa5);
787  dslash_4_4d(tmp, gauge, out, parity[1], dagger_bit, precision, gauge_param, mferm);
788  dslash_5_inv(out, gauge, tmp, parity[1], dagger_bit, precision, gauge_param, mferm, kappa5);
789  xpay(in, kappa2, out, V5h*spinorSiteSize, precision);
790  } else if (symmetric && dagger_bit) {
791  dslash_5_inv(tmp, gauge, in, parity[1], dagger_bit, precision, gauge_param, mferm, kappa5);
792  dslash_4_4d(out, gauge, tmp, parity[0], dagger_bit, precision, gauge_param, mferm);
793  dslash_5_inv(tmp, gauge, out, parity[0], dagger_bit, precision, gauge_param, mferm, kappa5);
794  dslash_4_4d(out, gauge, tmp, parity[1], dagger_bit, precision, gauge_param, mferm);
795  xpay(in, kappa2, out, V5h*spinorSiteSize, precision);
796  } else {
797  dslash_4_4d(tmp, gauge, in, parity[0], dagger_bit, precision, gauge_param, mferm);
798  dslash_5_inv(out, gauge, tmp, parity[0], dagger_bit, precision, gauge_param, mferm, kappa5);
799  dslash_4_4d(tmp, gauge, out, parity[1], dagger_bit, precision, gauge_param, mferm);
800  xpay(in, kappa2, tmp, V5h*spinorSiteSize, precision);
801  dw_dslash_5_4d(out, gauge, in, parity[1], dagger_bit, precision, gauge_param, mferm, true);
802  xpay(tmp, -kappa, out, V5h*spinorSiteSize, precision);
803  }
804  free(tmp);
805  free(kappa5);
806 }
807 
808 void mdw_matpc(void *out, void **gauge, void *in, double *kappa_b, double *kappa_c, QudaMatPCType matpc_type, int dagger, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm, double *b5, double *c5)
809 {
810  void *tmp = malloc(V5h*spinorSiteSize*precision);
811  double *kappa5 = (double*)malloc(Ls*sizeof(double));
812  double *kappa2 = (double*)malloc(Ls*sizeof(double));
813  double *kappa_mdwf = (double*)malloc(Ls*sizeof(double));
814  for(int xs = 0; xs < Ls ; xs++)
815  {
816  kappa5[xs] = 0.5*kappa_b[xs]/kappa_c[xs];
817  kappa2[xs] = -kappa_b[xs]*kappa_b[xs];
818  kappa_mdwf[xs] = -kappa5[xs];
819  }
820 
822  bool symmetric =(matpc_type == QUDA_MATPC_EVEN_EVEN || matpc_type == QUDA_MATPC_ODD_ODD) ? true : false;
823  QudaParity parity[2] = {static_cast<QudaParity>((1 + odd_bit) % 2), static_cast<QudaParity>((0 + odd_bit) % 2)};
824 
825  if (symmetric && !dagger) {
826  mdw_dslash_4_pre(tmp, gauge, in, parity[1], dagger, precision, gauge_param, mferm, b5, c5, true);
827  dslash_4_4d(out, gauge, tmp, parity[0], dagger, precision, gauge_param, mferm);
828  dslash_5_inv(tmp, gauge, out, parity[1], dagger, precision, gauge_param, mferm, kappa_mdwf);
829  mdw_dslash_4_pre(out, gauge, tmp, parity[0], dagger, precision, gauge_param, mferm, b5, c5, true);
830  dslash_4_4d(tmp, gauge, out, parity[1], dagger, precision, gauge_param, mferm);
831  dslash_5_inv(out, gauge, tmp, parity[0], dagger, precision, gauge_param, mferm, kappa_mdwf);
832  for(int xs = 0 ; xs < Ls ; xs++) {
833  xpay((char*)in + precision*Vh*spinorSiteSize*xs, kappa2[xs], (char*)out + precision*Vh*spinorSiteSize*xs,
834  Vh*spinorSiteSize, precision);
835  }
836  } else if (symmetric && dagger) {
837  dslash_5_inv(tmp, gauge, in, parity[1], dagger, precision, gauge_param, mferm, kappa_mdwf);
838  dslash_4_4d(out, gauge, tmp, parity[0], dagger, precision, gauge_param, mferm);
839  mdw_dslash_4_pre(tmp, gauge, out, parity[0], dagger, precision, gauge_param, mferm, b5, c5, true);
840  dslash_5_inv(out, gauge, tmp, parity[0], dagger, precision, gauge_param, mferm, kappa_mdwf);
841  dslash_4_4d(tmp, gauge, out, parity[1], dagger, precision, gauge_param, mferm);
842  mdw_dslash_4_pre(out, gauge, tmp, parity[1], dagger, precision, gauge_param, mferm, b5, c5, true);
843  for(int xs = 0 ; xs < Ls ; xs++) {
844  xpay((char*)in + precision*Vh*spinorSiteSize*xs, kappa2[xs], (char*)out + precision*Vh*spinorSiteSize*xs,
845  Vh*spinorSiteSize, precision);
846  }
847  } else if (!symmetric && !dagger) {
848  mdw_dslash_4_pre(out, gauge, in, parity[1], dagger, precision, gauge_param, mferm, b5, c5, true);
849  dslash_4_4d(tmp, gauge, out, parity[0], dagger, precision, gauge_param, mferm);
850  dslash_5_inv(out, gauge, tmp, parity[1], dagger, precision, gauge_param, mferm, kappa_mdwf);
851  mdw_dslash_4_pre(tmp, gauge, out, parity[0], dagger, precision, gauge_param, mferm, b5, c5, true);
852  dslash_4_4d(out, gauge, tmp, parity[1], dagger, precision, gauge_param, mferm);
853  mdw_dslash_5(tmp, gauge, in, parity[0], dagger, precision, gauge_param, mferm, kappa5, true);
854  for(int xs = 0 ; xs < Ls ; xs++) {
855  xpay((char*)tmp + precision*Vh*spinorSiteSize*xs, kappa2[xs], (char*)out + precision*Vh*spinorSiteSize*xs,
856  Vh*spinorSiteSize, precision);
857  }
858  } else if (!symmetric && dagger) {
859  dslash_4_4d(out, gauge, in, parity[0], dagger, precision, gauge_param, mferm);
860  mdw_dslash_4_pre(tmp, gauge, out, parity[1], dagger, precision, gauge_param, mferm, b5, c5, true);
861  dslash_5_inv(out, gauge, tmp, parity[0], dagger, precision, gauge_param, mferm, kappa_mdwf);
862  dslash_4_4d(tmp, gauge, out, parity[1], dagger, precision, gauge_param, mferm);
863  mdw_dslash_4_pre(out, gauge, tmp, parity[0], dagger, precision, gauge_param, mferm, b5, c5, true);
864  mdw_dslash_5(tmp, gauge, in, parity[0], dagger, precision, gauge_param, mferm, kappa5, true);
865  for(int xs = 0 ; xs < Ls ; xs++) {
866  xpay((char*)tmp + precision*Vh*spinorSiteSize*xs, kappa2[xs], (char*)out + precision*Vh*spinorSiteSize*xs,
867  Vh*spinorSiteSize, precision);
868  }
869  } else {
870  errorQuda("Unsupported matpc_type=%d dagger=%d", matpc_type, dagger);
871  }
872 
873  free(tmp);
874  free(kappa5);
875  free(kappa2);
876  free(kappa_mdwf);
877 }
878 
879 /*
880 // Apply the even-odd preconditioned Dirac operator
881 template <typename sFloat, typename gFloat>
882 void MatPC(sFloat *outEven, gFloat **gauge, sFloat *inEven, sFloat kappa,
883  QudaMatPCType matpc_type, sFloat mferm) {
884 
885  sFloat *tmp = (sFloat*)malloc(V5h*spinorSiteSize*sizeof(sFloat));
886 
887  // full dslash operator
888  if (matpc_type == QUDA_MATPC_EVEN_EVEN) {
889  dslashReference_4d(tmp, gauge, inEven, 1, 0);
890  dslashReference_5th(tmp, inEven, 1, 0, mferm);
891  dslashReference_4d(outEven, gauge, tmp, 0, 0);
892  dslashReference_5th(outEven, tmp, 0, 0, mferm);
893  } else {
894  dslashReference_4d(tmp, gauge, inEven, 0, 0);
895  dslashReference_5th(tmp, inEven, 0, 0, mferm);
896  dslashReference_4d(outEven, gauge, tmp, 1, 0);
897  dslashReference_5th(outEven, tmp, 1, 0, mferm);
898  }
899 
900  // lastly apply the kappa term
901  sFloat kappa2 = -kappa*kappa;
902  xpay(inEven, kappa2, outEven, V5h*spinorSiteSize);
903  free(tmp);
904 }
905 
906 // Apply the even-odd preconditioned Dirac operator
907 template <typename sFloat, typename gFloat>
908 void MatPCDag(sFloat *outEven, gFloat **gauge, sFloat *inEven, sFloat kappa,
909  QudaMatPCType matpc_type, sFloat mferm) {
910 
911  sFloat *tmp = (sFloat*)malloc(V5h*spinorSiteSize*sizeof(sFloat));
912 
913  // full dslash operator
914  if (matpc_type == QUDA_MATPC_EVEN_EVEN) {
915  dslashReference_4d(tmp, gauge, inEven, 1, 1);
916  dslashReference_5th(tmp, inEven, 1, 1, mferm);
917  dslashReference_4d(outEven, gauge, tmp, 0, 1);
918  dslashReference_5th(outEven, tmp, 0, 1, mferm);
919  } else {
920  dslashReference_4d(tmp, gauge, inEven, 0, 1);
921  dslashReference_5th(tmp, inEven, 0, 1, mferm);
922  dslashReference_4d(outEven, gauge, tmp, 1, 1);
923  dslashReference_5th(outEven, tmp, 1, 1, mferm);
924  }
925 
926  sFloat kappa2 = -kappa*kappa;
927  xpay(inEven, kappa2, outEven, V5h*spinorSiteSize);
928  free(tmp);
929 }
930 */
931 
// Precision-dispatch wrapper for the even-odd preconditioned domain-wall
// operator (MatPC / MatPCDag, selected by dagger_bit).
// NOTE(review): the entire dispatch body below is commented out, so this
// function is currently a no-op stub -- outEven is left untouched. Confirm
// whether callers still rely on it before re-enabling.
void matpc(void *outEven, void **gauge, void *inEven, double kappa,
    QudaMatPCType matpc_type, int dagger_bit, QudaPrecision sPrecision, QudaPrecision gPrecision,
    double mferm) {
/*
  if (!dagger_bit) {
    if (sPrecision == QUDA_DOUBLE_PRECISION)
      if (gPrecision == QUDA_DOUBLE_PRECISION)
        MatPC((double*)outEven, (double**)gauge, (double*)inEven, (double)kappa, matpc_type, (double)mferm);
      else
        MatPC((double*)outEven, (float**)gauge, (double*)inEven, (double)kappa, matpc_type, (double)mferm);
    else
      if (gPrecision == QUDA_DOUBLE_PRECISION)
        MatPC((float*)outEven, (double**)gauge, (float*)inEven, (float)kappa, matpc_type, (float)mferm);
      else
        MatPC((float*)outEven, (float**)gauge, (float*)inEven, (float)kappa, matpc_type, (float)mferm);
  } else {
    if (sPrecision == QUDA_DOUBLE_PRECISION)
      if (gPrecision == QUDA_DOUBLE_PRECISION)
        MatPCDag((double*)outEven, (double**)gauge, (double*)inEven, (double)kappa, matpc_type, (double)mferm);
      else
        MatPCDag((double*)outEven, (float**)gauge, (double*)inEven, (double)kappa, matpc_type, (double)mferm);
    else
      if (gPrecision == QUDA_DOUBLE_PRECISION)
        MatPCDag((float*)outEven, (double**)gauge, (float*)inEven, (float)kappa, matpc_type, (float)mferm);
      else
        MatPCDag((float*)outEven, (float**)gauge, (float*)inEven, (float)kappa, matpc_type, (float)mferm);
  }
*/
}
961 
962 /*
963 template <typename sFloat, typename gFloat>
964 void MatDagMat(sFloat *out, gFloat **gauge, sFloat *in, sFloat kappa, sFloat mferm)
965 {
966  // Allocate a full spinor.
967  sFloat *tmp = (sFloat*)malloc(V5*spinorSiteSize*sizeof(sFloat));
968  // Call templates above.
969  Mat(tmp, gauge, in, kappa, mferm);
970  MatDag(out, gauge, tmp, kappa, mferm);
971  free(tmp);
972 }
973 
974 template <typename sFloat, typename gFloat>
975 void MatPCDagMatPC(sFloat *out, gFloat **gauge, sFloat *in, sFloat kappa,
976  QudaMatPCType matpc_type, sFloat mferm)
977 {
978 
979  // Allocate half spinor
980  sFloat *tmp = (sFloat*)malloc(V5h*spinorSiteSize*sizeof(sFloat));
981  // Apply the PC templates above
982  MatPC(tmp, gauge, in, kappa, matpc_type, mferm);
983  MatPCDag(out, gauge, tmp, kappa, matpc_type, mferm);
984  free(tmp);
985 }
986 */
987 // Wrapper to templates that handles different precisions.
// Precision-dispatch wrapper for the MatDagMat (M^dag M) domain-wall operator.
// NOTE(review): the dispatch body below is commented out, so this function is
// currently a no-op stub -- 'out' is left untouched. The MatDagMat templates
// it would call are also commented out above.
void matdagmat(void *out, void **gauge, void *in, double kappa,
    QudaPrecision sPrecision, QudaPrecision gPrecision, double mferm)
{
/*
  if (sPrecision == QUDA_DOUBLE_PRECISION) {
    if (gPrecision == QUDA_DOUBLE_PRECISION)
      MatDagMat((double*)out, (double**)gauge, (double*)in, (double)kappa,
          (double)mferm);
    else
      MatDagMat((double*)out, (float**)gauge, (double*)in, (double)kappa, (double)mferm);
  } else {
    if (gPrecision == QUDA_DOUBLE_PRECISION)
      MatDagMat((float*)out, (double**)gauge, (float*)in, (float)kappa,
          (float)mferm);
    else
      MatDagMat((float*)out, (float**)gauge, (float*)in, (float)kappa, (float)mferm);
  }
*/
}
1007 
1008 // Wrapper to templates that handles different precisions.
// Precision-dispatch wrapper for the preconditioned MatPCDagMatPC operator.
// NOTE(review): the dispatch body below is commented out, so this function is
// currently a no-op stub -- 'out' is left untouched. The MatPCDagMatPC
// templates it would call are also commented out above.
void matpcdagmatpc(void *out, void **gauge, void *in, double kappa,
    QudaPrecision sPrecision, QudaPrecision gPrecision, double mferm, QudaMatPCType matpc_type)
{
/*
  if (sPrecision == QUDA_DOUBLE_PRECISION) {
    if (gPrecision == QUDA_DOUBLE_PRECISION)
      MatPCDagMatPC((double*)out, (double**)gauge, (double*)in, (double)kappa,
          matpc_type, (double)mferm);
    else
      MatPCDagMatPC((double*)out, (float**)gauge, (double*)in, (double)kappa,
          matpc_type, (double)mferm);
  } else {
    if (gPrecision == QUDA_DOUBLE_PRECISION)
      MatPCDagMatPC((float*)out, (double**)gauge, (float*)in, (float)kappa,
          matpc_type, (float)mferm);
    else
      MatPCDagMatPC((float*)out, (float**)gauge, (float*)in, (float)kappa,
          matpc_type, (float)mferm);
  }
*/
}
1030 
1031 
QudaGhostExchange ghostExchange
Definition: lattice_field.h:60
void dw_4d_matpc(void *out, void **gauge, void *in, double kappa, QudaMatPCType matpc_type, int dagger_bit, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm)
void dw_dslash_5_4d(void *out, void **gauge, void *in, int oddBit, int daggerBit, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm, bool zero_initialize)
void free(void *)
void xpay(ColorSpinorField &x, const double &a, ColorSpinorField &y)
Definition: blas_quda.cu:173
enum QudaPrecision_s QudaPrecision
void dslashReference_5th(sFloat *res, sFloat *spinorField, int oddBit, int daggerBit, sFloat mferm)
#define errorQuda(...)
Definition: util_quda.h:90
void printSpinorElement(void *spinor, int X, QudaPrecision precision)
Definition: test_util.cpp:204
int sp_idx
cudaColorSpinorField * tmp
Definition: covdev_test.cpp:44
int fullLatticeIndex_5d_4dpc(int i, int oddBit)
Definition: test_util.cpp:697
QudaPrecision precision
Definition: lattice_field.h:54
void dw_dslash(void *out, void **gauge, void *in, int oddBit, int daggerBit, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm)
QudaGaugeParam gauge_param
void mdw_mat(void *out, void **gauge, void *in, double *kappa_b, double *kappa_c, int dagger, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm, double *b5, double *c5)
void dw_mat(void *out, void **gauge, void *in, double kappa, int dagger_bit, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm)
#define mferm
#define spinorSiteSize
QudaSiteSubset siteSubset
Definition: lattice_field.h:55
void exit(int) __attribute__((noreturn))
static void axpby(Float a, Float *x, Float b, Float *y, int len)
Definition: dslash_util.h:33
size_t size_t offset
int fullLatticeIndex(int dim[4], int index, int oddBit)
Definition: test_util.cpp:442
int Ls
Definition: test_util.cpp:39
void matpc(void *outEven, void **gauge, void *inEven, double kappa, QudaMatPCType matpc_type, int dagger_bit, QudaPrecision sPrecision, QudaPrecision gPrecision, double mferm)
int x[QUDA_MAX_DIM]
Definition: lattice_field.h:50
void dw_matdagmat(void *out, void **gauge, void *in, double kappa, int dagger_bit, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm)
void dw_matpc(void *out, void **gauge, void *in, double kappa, QudaMatPCType matpc_type, int dagger_bit, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm)
void ax(const double &a, GaugeField &u)
Scale the gauge field by the scalar a.
void * malloc(size_t __size) __attribute__((__warn_unused_result__)) __attribute__((alloc_size(1)))
int printf(const char *,...) __attribute__((__format__(__printf__
void mdw_dslash_5(void *out, void **gauge, void *in, int oddBit, int daggerBit, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm, double *kappa, bool zero_initialize)
VOLATILE spinorFloat kappa
__host__ __device__ void sum(double &a, double &b)
int V5h
Definition: test_util.cpp:41
ColorSpinorParam csParam
Definition: pack_test.cpp:24
cpuColorSpinorField * in
void mdw_dslash_4_pre(void *out, void **gauge, void *in, int oddBit, int daggerBit, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm, double *b5, double *c5, bool zero_initialize)
#define mySpinorSiteSize
void dw_4d_mat(void *out, void **gauge, void *in, double kappa, int dagger_bit, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm)
Float * gaugeLink_sgpu(int i, int dir, int oddBit, Float **gaugeEven, Float **gaugeOdd)
enum QudaMatPCType_s QudaMatPCType
#define gaugeSiteSize
Definition: test_util.h:6
void mdw_matpc(void *out, void **gauge, void *in, double *kappa_b, double *kappa_c, QudaMatPCType matpc_type, int dagger, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm, double *b5, double *c5)
void dslashReference_4d_sgpu(sFloat *res, gFloat **gaugeFull, sFloat *spinorField, int oddBit, int daggerBit)
__host__ __device__ ValueType pow(ValueType x, ExponentType e)
Definition: complex_quda.h:100
int Z[4]
Definition: test_util.cpp:27
const void ** Ghost() const
Definition: gauge_field.h:254
enum QudaParity_s QudaParity
QudaMatPCType matpc_type
Definition: test_util.cpp:1652
void matpcdagmatpc(void *out, void **gauge, void *in, double kappa, QudaPrecision sPrecision, QudaPrecision gPrecision, double mferm, QudaMatPCType matpc_type)
void * memcpy(void *__dst, const void *__src, size_t __n)
void dslashReference_5th_inv(sFloat *res, sFloat *spinorField, int oddBit, int daggerBit, sFloat mferm, double *kappa)
int fullLatticeIndex_5d(int i, int oddBit)
Definition: test_util.cpp:692
cpuColorSpinorField * out
Main header file for the QUDA library.
int fullLatticeIndex_4d(int i, int oddBit)
Definition: test_util.cpp:658
const double projector[10][4][4][2]
int Vh
Definition: test_util.cpp:29
void dslash_4_4d(void *out, void **gauge, void *in, int oddBit, int daggerBit, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm)
int V5
Definition: test_util.cpp:40
int faceVolume[4]
Definition: test_util.cpp:32
void matdagmat(void *out, void **gauge, void *in, double kappa, QudaPrecision sPrecision, QudaPrecision gPrecision, double mferm)
static void su3Mul(sFloat *res, gFloat *mat, sFloat *vec)
Definition: dslash_util.h:80
__device__ void axpy(real a, const real *x, Link &y)
static void su3Tmul(sFloat *res, gFloat *mat, sFloat *vec)
Definition: dslash_util.h:85
Float * gaugeLink_mgpu(int i, int dir, int oddBit, Float **gaugeEven, Float **gaugeOdd, Float **ghostGaugeEven, Float **ghostGaugeOdd, int n_ghost_faces, int nbr_distance)
static __inline__ size_t size_t d
int neighborIndex_4d(int i, int oddBit, int dx4, int dx3, int dx2, int dx1)
QudaParity parity
Definition: covdev_test.cpp:53
void dslash_5_inv(void *out, void **gauge, void *in, int oddBit, int daggerBit, QudaPrecision precision, QudaGaugeParam &gauge_param, double mferm, double *kappa)
void multiplySpinorByDiracProjector5(Float *res, int projIdx, Float *spinorIn)
double kappa5
cpuColorSpinorField * spinor
Definition: covdev_test.cpp:41
int comm_dim_partitioned(int dim)