6 template <
typename Float>
8 for (
int i = 0; i < cnt; i++)
12 template <
typename Float>
14 for (
int i = 0; i < cnt; i++)
18 template <
typename Float>
20 for (
int i = 0; i < cnt; i++)
25 template <
typename Float>
27 for (
int i=0; i<len; i++) y[i] = x[i] + a*y[i];
31 template <
typename Float>
33 for (
int i=0; i<len; i++) y[i] = a*x[i] + y[i];
37 template <
typename Float>
39 for (
int i=0; i<len; i++) y[i] = a*x[i] + b*y[i];
43 template <
typename Float>
45 for (
int i=0; i<len; i++) y[i] = a*x[i] - y[i];
48 template <
typename Float>
51 for (
int i=0; i<len; i++) sum += v[i]*v[i];
55 template <
typename Float>
56 static inline void negx(
Float *x,
int len) {
57 for (
int i=0; i<len; i++) x[i] = -x[i];
60 template <
typename sFloat,
typename gFloat>
61 static inline void dot(sFloat* res, gFloat* a, sFloat* b) {
63 for (
int m = 0; m < 3; m++) {
64 sFloat a_re = a[2*m+0];
65 sFloat a_im = a[2*m+1];
66 sFloat b_re = b[2*m+0];
67 sFloat b_im = b[2*m+1];
68 res[0] += a_re * b_re - a_im * b_im;
69 res[1] += a_re * b_im + a_im * b_re;
73 template <
typename Float>
75 for (
int m = 0; m < 3; m++) {
76 for (
int n = 0; n < 3; n++) {
77 res[m*(3*2) + n*(2) + 0] = + mat[n*(3*2) + m*(2) + 0];
78 res[m*(3*2) + n*(2) + 1] = - mat[n*(3*2) + m*(2) + 1];
84 template <
typename sFloat,
typename gFloat>
85 static inline void su3Mul(sFloat *res, gFloat *mat, sFloat *vec) {
86 for (
int n = 0; n < 3; n++) dot(&res[n*(2)], &mat[n*(3*2)], vec);
89 template <
typename sFloat,
typename gFloat>
90 static inline void su3Tmul(sFloat *res, gFloat *mat, sFloat *vec) {
92 su3Transpose(matT, mat);
93 su3Mul(res, matT, vec);
108 template <
typename Float>
109 static inline Float *gaugeLink(
int i,
int dir,
int oddBit,
Float **gaugeEven,
Float **gaugeOdd,
int nbr_distance) {
112 int d = nbr_distance;
115 gaugeField = (oddBit ? gaugeOdd : gaugeEven);
123 default: j = -1;
break;
125 gaugeField = (oddBit ? gaugeEven : gaugeOdd);
128 return &gaugeField[dir/2][j*(3*3*2)];
131 template <
typename Float>
132 static inline Float *spinorNeighbor(
int i,
int dir,
int oddBit,
Float *spinorField,
int neighbor_distance)
135 int nb = neighbor_distance;
145 default: j = -1;
break;
155 x4_mg(
int i,
int oddBit)
158 int x4 = Y/(
Z[2]*
Z[1]*
Z[0]);
162 template <
typename Float>
163 static inline Float *gaugeLink_mg4dir(
int i,
int dir,
int oddBit,
Float **gaugeEven,
Float **gaugeOdd,
164 Float** ghostGaugeEven,
Float** ghostGaugeOdd,
int n_ghost_faces,
int nbr_distance) {
167 int d = nbr_distance;
170 gaugeField = (oddBit ? gaugeOdd : gaugeEven);
175 int x4 = Y/(
Z[2]*
Z[1]*
Z[0]);
176 int x3 = (Y/(
Z[1]*
Z[0])) %
Z[2];
177 int x2 = (Y/
Z[0]) %
Z[1];
183 Float* ghostGaugeField;
188 int new_x1 = (x1 - d +
X1 )% X1;
190 ghostGaugeField = (oddBit?ghostGaugeEven[0]: ghostGaugeOdd[0]);
191 int offset = (n_ghost_faces + x1 -d)*X4*X3*X2/2 + (x4*X3*X2 + x3*X2+x2)/2;
192 return &ghostGaugeField[offset*(3*3*2)];
194 j = (x4*X3*X2*X1 + x3*X2*X1 + x2*X1 + new_x1) / 2;
199 int new_x2 = (x2 - d +
X2 )% X2;
201 ghostGaugeField = (oddBit?ghostGaugeEven[1]: ghostGaugeOdd[1]);
202 int offset = (n_ghost_faces + x2 -d)*X4*X3*X1/2 + (x4*X3*X1 + x3*X1+x1)/2;
203 return &ghostGaugeField[offset*(3*3*2)];
205 j = (x4*X3*X2*X1 + x3*X2*X1 + new_x2*X1 +
x1) / 2;
211 int new_x3 = (x3 - d +
X3 )% X3;
213 ghostGaugeField = (oddBit?ghostGaugeEven[2]: ghostGaugeOdd[2]);
214 int offset = (n_ghost_faces + x3 -d)*X4*X2*X1/2 + (x4*X2*X1 + x2*X1+x1)/2;
215 return &ghostGaugeField[offset*(3*3*2)];
217 j = (x4*X3*X2*X1 + new_x3*X2*X1 + x2*X1 +
x1) / 2;
222 int new_x4 = (x4 - d +
X4)% X4;
224 ghostGaugeField = (oddBit?ghostGaugeEven[3]: ghostGaugeOdd[3]);
225 int offset = (n_ghost_faces + x4 -d)*X1*X2*X3/2 + (x3*X2*X1 + x2*X1+x1)/2;
226 return &ghostGaugeField[offset*(3*3*2)];
228 j = (new_x4*(X3*X2*
X1) + x3*(X2*X1) + x2*(
X1) + x1) / 2;
232 default: j = -1; printf(
"ERROR: wrong dir \n"); exit(1);
234 gaugeField = (oddBit ? gaugeEven : gaugeOdd);
238 return &gaugeField[dir/2][j*(3*3*2)];
241 template <
typename Float>
242 static inline Float *spinorNeighbor_mg4dir(
int i,
int dir,
int oddBit,
Float *spinorField,
Float** fwd_nbr_spinor,
243 Float** back_nbr_spinor,
int neighbor_distance,
int nFace)
246 int nb = neighbor_distance;
248 int x4 = Y/(
Z[2]*
Z[1]*
Z[0]);
249 int x3 = (Y/(
Z[1]*
Z[0])) %
Z[2];
250 int x2 = (Y/
Z[0]) %
Z[1];
260 int new_x1 = (x1 + nb)% X1;
262 int offset = ( x1 + nb -
X1)*X4*X3*X2/2+(x4*X3*X2 + x3*X2+x2)/2;
265 j = (x4*X3*X2*X1 + x3*X2*X1 + x2*X1 + new_x1) / 2;
271 int new_x1 = (x1 - nb +
X1)% X1;
273 int offset = ( x1+nFace- nb)*X4*X3*X2/2+(x4*X3*X2 + x3*X2+x2)/2;
276 j = (x4*X3*X2*X1 + x3*X2*X1 + x2*X1 + new_x1) / 2;
281 int new_x2 = (x2 + nb)% X2;
283 int offset = ( x2 + nb -
X2)*X4*X3*X1/2+(x4*X3*X1 + x3*X1+x1)/2;
286 j = (x4*X3*X2*X1 + x3*X2*X1 + new_x2*X1 +
x1) / 2;
291 int new_x2 = (x2 - nb +
X2)% X2;
293 int offset = ( x2 + nFace -nb)*X4*X3*X1/2+(x4*X3*X1 + x3*X1+x1)/2;
296 j = (x4*X3*X2*X1 + x3*X2*X1 + new_x2*X1 +
x1) / 2;
301 int new_x3 = (x3 + nb)% X3;
303 int offset = ( x3 + nb -
X3)*X4*X2*X1/2+(x4*X2*X1 + x2*X1+x1)/2;
306 j = (x4*X3*X2*X1 + new_x3*X2*X1 + x2*X1 +
x1) / 2;
311 int new_x3 = (x3 - nb +
X3)% X3;
313 int offset = ( x3 + nFace -nb)*X4*X2*X1/2+(x4*X2*X1 + x2*X1+x1)/2;
316 j = (x4*X3*X2*X1 + new_x3*X2*X1 + x2*X1 +
x1) / 2;
322 int x4 = x4_mg(i, oddBit);
323 if ( (x4 + nb) >=
Z[3]){
324 int offset = (x4+nb -
Z[3])*
Vsh_t;
332 int x4 = x4_mg(i, oddBit);
334 int offset = ( x4 - nb +nFace)*
Vsh_t;
339 default: j = -1; printf(
"ERROR: wrong dir\n"); exit(1);
347 #endif // _DSLASH_UTIL_H
__device__ __forceinline__ int neighborIndex(const unsigned int &cb_idx, const int(&shift)[4], const bool(&partitioned)[4], const unsigned int &parity)
void axpby(const Float &a, const Float *x, const Float &b, Float *y, const int N)
void mat(void *out, void **fatlink, void **longlink, void *in, double kappa, int dagger_bit, QudaPrecision sPrecision, QudaPrecision gPrecision)
void ax(double a, void *x, int len, QudaPrecision precision)
int fullLatticeIndex(int dim[4], int index, int oddBit)
FloatingPoint< float > Float
double norm2(Float *v, int len)
int neighborIndex_mg(int i, int oddBit, int dx4, int dx3, int dx2, int dx1)
void axpy(double a, void *x, void *y, int len, QudaPrecision precision)