QUDA  0.9.0
dw_dslash5inv_dagger_core.h
Go to the documentation of this file.
1 // *** CUDA DSLASH DAGGER ***
2 
// Number of shared-memory floats per thread requested by this kernel body: zero.
// NOTE(review): presumably consumed by the launch/IO code outside this fragment - confirm.
3 #define DSLASH_SHARED_FLOATS_PER_THREAD 0
4 
5 
// VOLATILE expands to nothing when CUDA_VERSION >= 4010 and to `volatile` for
// older toolkits. It is removed again by the #undef at the end of this file.
6 #if (CUDA_VERSION >= 4010)
7 #define VOLATILE
8 #else
9 #define VOLATILE volatile
10 #endif
// Precision-dependent aliases.
//   spinorFloat      : scalar type used for all arithmetic in this fragment.
//   POW(a, b)        : power function used for the kappa^n factors.
//   iSC_re / iSC_im  : real / imaginary part of the input spinor component with
//                      spin index S (0..3) and colour index C (0..2), mapped onto
//                      the registers I0.. filled by READ_SPINOR.
//   m5, mdwf_b5, mdwf_c5, mferm, a, b : shorthands for fields of `param`.
// All of these macros are #undef'd at the end of the file so the fragment can be
// included again with a different precision.
11 // input spinor
12 #ifdef SPINOR_DOUBLE
13 #define spinorFloat double
14 // workaround for C++11 bug in CUDA 6.5/7.0
15 #if CUDA_VERSION >= 6050 && CUDA_VERSION < 7050
// the exponent is cast explicitly to spinorFloat for these toolkit versions
16 #define POW(a, b) pow(a, static_cast<spinorFloat>(b))
17 #else
18 #define POW(a, b) pow(a, b)
19 #endif
20 
// Double precision: one complex number per register (I0 .. I11), using .x / .y.
// NOTE(review): presumably the registers are double2 - confirm in io_spinor.h.
21 #define i00_re I0.x
22 #define i00_im I0.y
23 #define i01_re I1.x
24 #define i01_im I1.y
25 #define i02_re I2.x
26 #define i02_im I2.y
27 #define i10_re I3.x
28 #define i10_im I3.y
29 #define i11_re I4.x
30 #define i11_im I4.y
31 #define i12_re I5.x
32 #define i12_im I5.y
33 #define i20_re I6.x
34 #define i20_im I6.y
35 #define i21_re I7.x
36 #define i21_im I7.y
37 #define i22_re I8.x
38 #define i22_im I8.y
39 #define i30_re I9.x
40 #define i30_im I9.y
41 #define i31_re I10.x
42 #define i31_im I10.y
43 #define i32_re I11.x
44 #define i32_im I11.y
// Double-precision parameter fields.
45 #define m5 param.m5_d
46 #define mdwf_b5 param.mdwf_b5_d
47 #define mdwf_c5 param.mdwf_c5_d
48 #define mferm param.mferm
49 #define a param.a
50 #define b param.b
51 #else
52 #define spinorFloat float
// __fast_pow is not defined in this fragment; it must be provided by the includer.
53 #define POW(a, b) __fast_pow(a, b)
// Single precision: two complex numbers per register (I0 .. I5), using .x/.y and .z/.w.
// NOTE(review): presumably the registers are float4 - confirm in io_spinor.h.
54 #define i00_re I0.x
55 #define i00_im I0.y
56 #define i01_re I0.z
57 #define i01_im I0.w
58 #define i02_re I1.x
59 #define i02_im I1.y
60 #define i10_re I1.z
61 #define i10_im I1.w
62 #define i11_re I2.x
63 #define i11_im I2.y
64 #define i12_re I2.z
65 #define i12_im I2.w
66 #define i20_re I3.x
67 #define i20_im I3.y
68 #define i21_re I3.z
69 #define i21_im I3.w
70 #define i22_re I4.x
71 #define i22_im I4.y
72 #define i30_re I4.z
73 #define i30_im I4.w
74 #define i31_re I5.x
75 #define i31_im I5.y
76 #define i32_re I5.z
77 #define i32_im I5.w
// Single-precision parameter fields. Note that `a` and `b` map to the same
// fields (param.a, param.b) as in the double-precision branch.
78 #define m5 param.m5_f
79 #define mdwf_b5 param.mdwf_b5_f
80 #define mdwf_c5 param.mdwf_c5_f
81 #define mferm param.mferm_f
82 #define a param.a
83 #define b param.b
84 #endif // SPINOR_DOUBLE
85 
86 // output spinor
111 
// SHARED_STRIDE is chosen from the spinor precision and __COMPUTE_CAPABILITY__:
//   double: 16 (>= 200) or 8 (older);  single: 32 (>= 200) or 16 (older).
// It is defined before io_spinor.h is included and #undef'd at the end of the file.
// NOTE(review): presumably used by the spinor I/O macros in io_spinor.h - confirm.
112 #ifdef SPINOR_DOUBLE
113 #if (__COMPUTE_CAPABILITY__ >= 200)
114 #define SHARED_STRIDE 16 // to avoid bank conflicts on Fermi
115 #else
116 #define SHARED_STRIDE 8 // to avoid bank conflicts on G80 and GT200
117 #endif
118 #else
119 #if (__COMPUTE_CAPABILITY__ >= 200)
120 #define SHARED_STRIDE 32 // to avoid bank conflicts on Fermi
121 #else
122 #define SHARED_STRIDE 16 // to avoid bank conflicts on G80 and GT200
123 #endif
124 #endif
125 #include "io_spinor.h"
126 
// Flattened global thread index built from the 2D grid / 2D block coordinates.
127 int sid = ((blockIdx.y*blockDim.y + threadIdx.y)*gridDim.x + blockIdx.x)*blockDim.x + threadIdx.x;
// Threads beyond param.threads * Ls have no work and exit immediately.
128 if (sid >= param.threads*param.dc.Ls) return;
129 
130 
132 
133 
134 
// Sum of integer quotients of sid by the cumulative extents Xh[0], X[1]*Xh[0]
// and X[2]*X[1]*Xh[0]; together with param.parity its low bit selects the
// offset (0 or 1) added to 2*sid below.
135 boundaryCrossing = sid/param.dc.Xh[0] + sid/(param.dc.X[1]*param.dc.Xh[0]) + sid/(param.dc.X[2]*param.dc.X[1]*param.dc.Xh[0]);
136 
// X is the index obtained from 2*sid plus the parity-dependent offset.
137 X = 2*sid + (boundaryCrossing + param.parity) % 2;
// Fifth-dimension coordinate: X divided by the product X[0]*X[1]*X[2]*X[3].
138 coord[4] = X/(param.dc.X[0]*param.dc.X[1]*param.dc.X[2]*param.dc.X[3]);
139 
// Zero the 24 output accumulators (4 spin x 3 colour, real and imaginary parts).
// Their declarations are not visible in this listing.
140  o00_re = 0; o00_im = 0;
141  o01_re = 0; o01_im = 0;
142  o02_re = 0; o02_im = 0;
143  o10_re = 0; o10_im = 0;
144  o11_re = 0; o11_im = 0;
145  o12_re = 0; o12_im = 0;
146  o20_re = 0; o20_im = 0;
147  o21_re = 0; o21_im = 0;
148  o22_re = 0; o22_im = 0;
149  o30_re = 0; o30_im = 0;
150  o31_re = 0; o31_im = 0;
151  o32_re = 0; o32_im = 0;
152 
154 
// kappa for this thread's fifth-dimension coordinate:
//   MDWF_mode : -(c5[s]*(4 + m5) - 1) / (b5[s]*(4 + m5) + 1), with s = coord[4]
//   otherwise : 2 * a   (a expands to param.a)
155 #ifdef MDWF_mode // Check whether MDWF option is enabled
156  kappa = -(mdwf_c5[ coord[4] ]*(static_cast<spinorFloat>(4.0) + m5) - static_cast<spinorFloat>(1.0))/(mdwf_b5[ coord[4] ]*(static_cast<spinorFloat>(4.0) + m5) + static_cast<spinorFloat>(1.0));
157 #else
158  kappa = static_cast<spinorFloat>(2.0)*a;
159 #endif // select MDWF mode
160 
161 // M5_inv operation -- NB: not partitionable!
162 
163 // This section performs the following operation in parallel:
164 
165 // w = M5inv * v
166 // 'w' means output vector
167 // 'v' means input vector
// Accumulate w = M5inv * v for this thread: loop over all Ls fifth-dimension
// slices, read the input spinor of each slice at the same 4D site, and add it
// to the output accumulators o** weighted by factorR / factorL.
168 {
// Index of this thread's site within one fifth-dimension slice.
169  int base_idx = sid%param.dc.volume_4d_cb;
170  int sp_idx;
171 
172 // Index convention in the loop below:
173 // s = loop variable selecting the fifth-dimension slice that is read (input),
174 // coord[4] = fifth-dimension coordinate derived from this thread's sid, and
175 // 'a'= kappa5
176 
// Common prefactor 0.5 / (1 + kappa^Ls * mferm); independent of the loop variable.
177  spinorFloat inv_d_n = static_cast<spinorFloat>(0.5) / ( static_cast<spinorFloat>(1.0) + POW(kappa,param.dc.Ls)*mferm );
180 
181  for(int s = 0; s < param.dc.Ls; s++)
182  {
// Exponent is (s - coord[4]) wrapped into [0, Ls); the wrapped case
// (coord[4] > s) additionally picks up a factor of -mferm.
183  int exponent = coord[4] > s ? param.dc.Ls-coord[4]+s : s-coord[4];
184  factorR = inv_d_n * POW(kappa,exponent) * ( coord[4] > s ? -mferm : static_cast<spinorFloat>(1.0) );
185 
// Same 4D site, fifth-dimension slice s.
186  sp_idx = base_idx + s*param.dc.volume_4d_cb;
187  // read spinor from device memory
188  READ_SPINOR( SPINORTEX, param.sp_stride, sp_idx, sp_idx );
189 
// factorR term: the sum (spin 0/1 + spin 2/3) is added to both the upper
// (spin 0,1) and the lower (spin 2,3) output components.
// NOTE(review): presumably one chiral projection in this gamma basis - confirm.
190  o00_re += factorR*(i00_re + i20_re);
191  o00_im += factorR*(i00_im + i20_im);
192  o20_re += factorR*(i00_re + i20_re);
193  o20_im += factorR*(i00_im + i20_im);
194  o01_re += factorR*(i01_re + i21_re);
195  o01_im += factorR*(i01_im + i21_im);
196  o21_re += factorR*(i01_re + i21_re);
197  o21_im += factorR*(i01_im + i21_im);
198  o02_re += factorR*(i02_re + i22_re);
199  o02_im += factorR*(i02_im + i22_im);
200  o22_re += factorR*(i02_re + i22_re);
201  o22_im += factorR*(i02_im + i22_im);
202  o10_re += factorR*(i10_re + i30_re);
203  o10_im += factorR*(i10_im + i30_im);
204  o30_re += factorR*(i10_re + i30_re);
205  o30_im += factorR*(i10_im + i30_im);
206  o11_re += factorR*(i11_re + i31_re);
207  o11_im += factorR*(i11_im + i31_im);
208  o31_re += factorR*(i11_re + i31_re);
209  o31_im += factorR*(i11_im + i31_im);
210  o12_re += factorR*(i12_re + i32_re);
211  o12_im += factorR*(i12_im + i32_im);
212  o32_re += factorR*(i12_re + i32_re);
213  o32_im += factorR*(i12_im + i32_im);
214 
// Mirror case: exponent is (coord[4] - s) wrapped into [0, Ls); the wrapped
// case (coord[4] < s) picks up the factor of -mferm.
215  int exponent2 = coord[4] < s ? param.dc.Ls-s+coord[4] : coord[4]-s;
216  factorL = inv_d_n * POW(kappa,exponent2) * ( coord[4] < s ? -mferm : static_cast<spinorFloat>(1.0));
217 
// factorL term: the difference (upper - lower) is added to the upper output
// components and its negative (lower - upper) to the lower ones.
218  o00_re += factorL*(i00_re - i20_re);
219  o00_im += factorL*(i00_im - i20_im);
220  o01_re += factorL*(i01_re - i21_re);
221  o01_im += factorL*(i01_im - i21_im);
222  o02_re += factorL*(i02_re - i22_re);
223  o02_im += factorL*(i02_im - i22_im);
224  o10_re += factorL*(i10_re - i30_re);
225  o10_im += factorL*(i10_im - i30_im);
226  o11_re += factorL*(i11_re - i31_re);
227  o11_im += factorL*(i11_im - i31_im);
228  o12_re += factorL*(i12_re - i32_re);
229  o12_im += factorL*(i12_im - i32_im);
230  o20_re += factorL*(i20_re - i00_re);
231  o20_im += factorL*(i20_im - i00_im);
232  o21_re += factorL*(i21_re - i01_re);
233  o21_im += factorL*(i21_im - i01_im);
234  o22_re += factorL*(i22_re - i02_re);
235  o22_im += factorL*(i22_im - i02_im);
236  o30_re += factorL*(i30_re - i10_re);
237  o30_im += factorL*(i30_im - i10_im);
238  o31_re += factorL*(i31_re - i11_re);
239  o31_im += factorL*(i31_im - i11_im);
240  o32_re += factorL*(i32_re - i12_re);
241  o32_im += factorL*(i32_im - i12_im);
242  }
243 } // end of M5inv dimension
244 
245 #undef POW
// Optional xpay step (only when DSLASH_XPAY is defined): every output component
// becomes coeff * o + accum, where accum is loaded by READ_ACCUM.
246 {
247 
248 #ifdef DSLASH_XPAY
// Loads the accum* registers used below.
// NOTE(review): the declaration of `coeff` is not visible in this listing.
249  READ_ACCUM(ACCUMTEX, param.sp_stride)
251 
// coeff:
//   MDWF_mode : a * ( 0.5 / (b5[coord[4]]*(m5 + 4) + 1) )^2
//   otherwise : b   (b expands to param.b)
252 #ifdef MDWF_mode
253  coeff = static_cast<spinorFloat>(0.5)/(mdwf_b5[coord[4]]*(m5+static_cast<spinorFloat>(4.0)) + static_cast<spinorFloat>(1.0));
254  coeff *= coeff;
255  coeff *= a;
256 #else
257  coeff = b;
258 #endif
259 
// The accum registers follow the same packing as the input spinor registers:
// one complex number per register in double precision (accum0 .. accum11),
// two per register in single precision (accum0 .. accum5).
260 #ifdef SPINOR_DOUBLE
261  o00_re = coeff*o00_re + accum0.x;
262  o00_im = coeff*o00_im + accum0.y;
263  o01_re = coeff*o01_re + accum1.x;
264  o01_im = coeff*o01_im + accum1.y;
265  o02_re = coeff*o02_re + accum2.x;
266  o02_im = coeff*o02_im + accum2.y;
267  o10_re = coeff*o10_re + accum3.x;
268  o10_im = coeff*o10_im + accum3.y;
269  o11_re = coeff*o11_re + accum4.x;
270  o11_im = coeff*o11_im + accum4.y;
271  o12_re = coeff*o12_re + accum5.x;
272  o12_im = coeff*o12_im + accum5.y;
273  o20_re = coeff*o20_re + accum6.x;
274  o20_im = coeff*o20_im + accum6.y;
275  o21_re = coeff*o21_re + accum7.x;
276  o21_im = coeff*o21_im + accum7.y;
277  o22_re = coeff*o22_re + accum8.x;
278  o22_im = coeff*o22_im + accum8.y;
279  o30_re = coeff*o30_re + accum9.x;
280  o30_im = coeff*o30_im + accum9.y;
281  o31_re = coeff*o31_re + accum10.x;
282  o31_im = coeff*o31_im + accum10.y;
283  o32_re = coeff*o32_re + accum11.x;
284  o32_im = coeff*o32_im + accum11.y;
285 #else
286  o00_re = coeff*o00_re + accum0.x;
287  o00_im = coeff*o00_im + accum0.y;
288  o01_re = coeff*o01_re + accum0.z;
289  o01_im = coeff*o01_im + accum0.w;
290  o02_re = coeff*o02_re + accum1.x;
291  o02_im = coeff*o02_im + accum1.y;
292  o10_re = coeff*o10_re + accum1.z;
293  o10_im = coeff*o10_im + accum1.w;
294  o11_re = coeff*o11_re + accum2.x;
295  o11_im = coeff*o11_im + accum2.y;
296  o12_re = coeff*o12_re + accum2.z;
297  o12_im = coeff*o12_im + accum2.w;
298  o20_re = coeff*o20_re + accum3.x;
299  o20_im = coeff*o20_im + accum3.y;
300  o21_re = coeff*o21_re + accum3.z;
301  o21_im = coeff*o21_im + accum3.w;
302  o22_re = coeff*o22_re + accum4.x;
303  o22_im = coeff*o22_im + accum4.y;
304  o30_re = coeff*o30_re + accum4.z;
305  o30_im = coeff*o30_im + accum4.w;
306  o31_re = coeff*o31_re + accum5.x;
307  o31_im = coeff*o31_im + accum5.y;
308  o32_re = coeff*o32_re + accum5.z;
309  o32_im = coeff*o32_im + accum5.w;
310 #endif // SPINOR_DOUBLE
311 #endif // DSLASH_XPAY
312 }
313 
314 // write spinor field back to device memory
// WRITE_SPINOR is defined outside this fragment.
// NOTE(review): presumably stores the o** accumulators at this thread's sid - confirm.
315 WRITE_SPINOR(param.sp_stride);
316 
317 // undefine to prevent warning when precision is changed
// Removes every macro defined above so the fragment can be included again with
// a different SPINOR_DOUBLE setting.
318 #undef m5
319 #undef mdwf_b5
320 #undef mdwf_c5
321 #undef mferm
322 #undef a
323 #undef b
324 #undef spinorFloat
// POW was already undefined right after the M5inv block; repeating it is harmless.
325 #undef POW
326 #undef SHARED_STRIDE
327 
// Input spinor component aliases.
328 #undef i00_re
329 #undef i00_im
330 #undef i01_re
331 #undef i01_im
332 #undef i02_re
333 #undef i02_im
334 #undef i10_re
335 #undef i10_im
336 #undef i11_re
337 #undef i11_im
338 #undef i12_re
339 #undef i12_im
340 #undef i20_re
341 #undef i20_im
342 #undef i21_re
343 #undef i21_im
344 #undef i22_re
345 #undef i22_im
346 #undef i30_re
347 #undef i30_im
348 #undef i31_re
349 #undef i31_im
350 #undef i32_re
351 #undef i32_im
352 
353 
354 
355 #undef VOLATILE
VOLATILE spinorFloat o21_re
dim3 dim3 blockDim
VOLATILE spinorFloat o12_im
#define i10_re
VOLATILE spinorFloat o20_re
#define i00_re
#define i01_im
spinorFloat inv_d_n
VOLATILE spinorFloat o32_re
#define WRITE_SPINOR
int coord[5]
#define i22_re
#define i12_re
#define i02_re
VOLATILE spinorFloat o22_im
VOLATILE spinorFloat o11_im
#define i30_im
#define i11_im
VOLATILE spinorFloat o02_re
#define i10_im
VOLATILE spinorFloat o12_re
#define i01_re
spinorFloat factorL
VOLATILE spinorFloat kappa
VOLATILE spinorFloat o01_re
int boundaryCrossing
QudaGaugeParam param
Definition: pack_test.cpp:17
#define i30_re
VOLATILE spinorFloat o31_im
#define i21_re
#define m5
VOLATILE spinorFloat o10_im
#define i20_im
#define SPINORTEX
#define mdwf_c5
VOLATILE spinorFloat o02_im
#define READ_SPINOR
#define POW(a, b)
#define i31_im
int X[4]
Definition: quda.h:29
#define i32_im
VOLATILE spinorFloat o30_re
#define i21_im
VOLATILE spinorFloat o01_im
#define i31_re
VOLATILE spinorFloat o11_re
#define i22_im
#define mferm
#define mdwf_b5
VOLATILE spinorFloat o00_re
VOLATILE spinorFloat o32_im
spinorFloat factorR
VOLATILE spinorFloat o31_re
#define i02_im
#define a
VOLATILE spinorFloat o30_im
#define spinorFloat
#define VOLATILE
#define i12_im
#define b
VOLATILE spinorFloat o00_im
#define i00_im
VOLATILE spinorFloat o20_im
#define i11_re
VOLATILE spinorFloat o21_im
#define i32_re
VOLATILE spinorFloat o22_re
VOLATILE spinorFloat o10_re
#define i20_re