QUDA  0.9.0
ndeg_tm_dslash_cuda_gen.py
Go to the documentation of this file.
1 # -*- coding: utf-8 -*-
2 import sys
3 
4 
5 
def complexify(a):
    """Return a new list holding every element of *a* converted to ``complex``."""
    return list(map(complex, a))
8 
def complexToStr(c):
    """Render a complex number compactly, e.g. ``0``, ``2``, ``-i``, ``1+2i``.

    Integral parts are printed without a trailing ``.0``; unit imaginary parts
    are printed as a bare ``i`` / ``-i``.
    """
    def fmt_real(x):
        # Integral values print as ints so 2.0 renders as "2".
        return repr(int(x)) if x == int(x) else repr(x)

    def fmt_imag(x):
        if x == 0:
            return "0i"
        if x == -1:
            return "-i"
        if x == 1:
            return "i"
        return fmt_real(x) + "i"

    real_part = c.real
    imag_part = c.imag
    if real_part == 0:
        return "0" if imag_part == 0 else fmt_imag(imag_part)
    if imag_part == 0:
        return fmt_real(real_part)
    if imag_part < 0:
        return fmt_real(real_part) + "-" + fmt_imag(-imag_part)
    return fmt_real(real_part) + "+" + fmt_imag(imag_part)
28 
29 
30 
31 
# Spin matrices used to build the projectors.  Each matrix is a flat list of
# 16 complex entries in row-major order: entry (row, col) sits at index
# 4*row + col, which is how projectorToStr() and gen() index them.
# NOTE(review): `id` shadows the builtin id(); the name is kept because the
# projector table is built from it.
id = complexify(
    [1, 0, 0, 0]
    + [0, 1, 0, 0]
    + [0, 0, 1, 0]
    + [0, 0, 0, 1]
)

gamma1 = complexify(
    [0, 0, 0, 1j]
    + [0, 0, 1j, 0]
    + [0, -1j, 0, 0]
    + [-1j, 0, 0, 0]
)

gamma2 = complexify(
    [0, 0, 0, 1]
    + [0, 0, -1, 0]
    + [0, -1, 0, 0]
    + [1, 0, 0, 0]
)

gamma3 = complexify(
    [0, 0, 1j, 0]
    + [0, 0, 0, -1j]
    + [-1j, 0, 0, 0]
    + [0, 1j, 0, 0]
)

gamma4 = complexify(
    [1, 0, 0, 0]
    + [0, 1, 0, 0]
    + [0, 0, -1, 0]
    + [0, 0, 0, -1]
)

igamma5 = complexify(
    [0, 0, 1j, 0]
    + [0, 0, 0, 1j]
    + [1j, 0, 0, 0]
    + [0, 1j, 0, 0]
)
73 
74 
def gplus(g1, g2):
    """Element-wise sum of two flattened matrices (stops at the shorter input)."""
    return list(lhs + rhs for lhs, rhs in zip(g1, g2))
77 
def gminus(g1, g2):
    """Element-wise difference g1 - g2 of two flattened matrices."""
    return list(lhs - rhs for lhs, rhs in zip(g1, g2))
80 
def projectorToStr(p):
    """Render a flattened 4x4 matrix as four text rows.

    Args:
        p: 16 complex entries in row-major order (entry (i, j) at 4*i + j).

    Returns:
        Four newline-terminated rows; each entry is formatted by
        complexToStr() and followed by a single space.
    """
    rows = []
    for i in range(4):
        cells = "".join(complexToStr(p[4 * i + j]) + " " for j in range(4))
        rows.append(cells + "\n")
    return "".join(rows)
88 
# Spin projectors, stored pairwise per gamma matrix:
# index 2*k holds (1 - gamma_{k+1}) and index 2*k + 1 holds (1 + gamma_{k+1}).
# A generator expression is used so no loop names leak into module scope.
projectors = list(
    combine(id, gamma)
    for gamma in (gamma1, gamma2, gamma3, gamma4)
    for combine in (gminus, gplus)
)
95 
96 
97 
def indent(code):
    """Prefix every line of *code* with one space and terminate it with a newline.

    Lines starting with '#' (preprocessor directives) are left in column 0.
    """
    pieces = []
    for line in code.splitlines():
        prefix = "" if line.startswith("#") else " "
        pieces.append(prefix + line + "\n")
    return "".join(pieces)
101 
def block(code):
    """Wrap *code* in a C compound statement ``{ ... }``, indenting its body."""
    return "".join(("{\n", indent(code), "}"))
104 
def sign(x):
    """Map a coefficient of 1, -1, 2 or -2 to its C prefix; other values give None."""
    for value, prefix in ((1, "+"), (-1, "-"), (2, "+2*"), (-2, "-2*")):
        if x == value:
            return prefix
    return None
110 
def nthFloat4(n):
    """Name the n-th float in an array of float4 values, e.g. 5 -> '1.y'."""
    element, lane = divmod(n, 4)
    return repr(element) + "." + "xyzw"[lane]
113 
def nthFloat2(n):
    """Name the n-th float in an array of two-component values, e.g. 3 -> '1.y'."""
    element, lane = divmod(n, 2)
    return repr(element) + "." + "xy"[lane]
116 
117 
# Names of the generated C variables/macros.  s = spin index, c/m/n = color
# indices, h = half-spinor index; %r prints the integer indices directly.

# input spinor components
def in_re(s, c): return "i%r%r_re" % (s, c)
def in_im(s, c): return "i%r%r_im" % (s, c)

# gauge link entries; odd directions use the conjugated link "gT"
def g_re(d, m, n): return ("g" if d % 2 == 0 else "gT") + "%r%r_re" % (m, n)
def g_im(d, m, n): return ("g" if d % 2 == 0 else "gT") + "%r%r_im" % (m, n)

# output spinor components, one set per flavor
def out1_re(s, c): return "o1_%r%r_re" % (s, c)
def out1_im(s, c): return "o1_%r%r_im" % (s, c)
def out2_re(s, c): return "o2_%r%r_re" % (s, c)
def out2_im(s, c): return "o2_%r%r_im" % (s, c)

# projected half spinor (lower case) and the link applied to it (upper case)
def h1_re(h, c): return "ab"[h] + "%r_re" % (c,)
def h1_im(h, c): return "ab"[h] + "%r_im" % (c,)
def h2_re(h, c): return "AB"[h] + "%r_re" % (c,)
def h2_im(h, c): return "AB"[h] + "%r_im" % (c,)
def a_re(b, s, c): return "a%r%r_re" % (s + 2*b, c)
def a_im(b, s, c): return "a%r%r_im" % (s + 2*b, c)

def tmp_re(s, c): return "tmp%r%r_re" % (s, c)
def tmp_im(s, c): return "tmp%r%r_im" % (s, c)

# accumulator components
def acc_re(s, c): return "acc_%r%r_re" % (s, c)
def acc_im(s, c): return "acc_%r%r_im" % (s, c)
def acc1_re(s, c): return "acc1_%r%r_re" % (s, c)
def acc1_im(s, c): return "acc1_%r%r_im" % (s, c)
def acc2_re(s, c): return "acc2_%r%r_re" % (s, c)
def acc2_im(s, c): return "acc2_%r%r_im" % (s, c)
140 def acc2_re(s, c): return "acc2_"+`s`+`c`+"_re"
141 def acc2_im(s, c): return "acc2_"+`s`+`c`+"_im"
142 
143 
def def_input_spinor():
    """Return the macro block naming the components of the input spinor.

    Each of the 12 complex components (spin s in 0..3, color c in 0..2) gets
    two macros, ``i<s><c>_re`` and ``i<s><c>_im``.  They alias consecutive
    float fields of ``I``: ``.x``/``.y`` pairs (nthFloat2) under SPINOR_DOUBLE
    and ``.x``..``.w`` quads (nthFloat4) otherwise.  When the global
    ``sharedDslash`` is set, the matching WRITE/READ_SPINOR_SHARED macros are
    selected too.
    """
    def component_defines(nth_float):
        # Component i = 3*s + c occupies floats 2*i (re) and 2*i + 1 (im).
        defs = ""
        for s in range(0, 4):
            for c in range(0, 3):
                i = 3*s + c
                defs += "#define " + in_re(s, c) + " I" + nth_float(2*i + 0) + "\n"
                defs += "#define " + in_im(s, c) + " I" + nth_float(2*i + 1) + "\n"
        return defs

    out = "// input spinor\n"
    out += "#ifdef SPINOR_DOUBLE\n"
    out += "#define spinorFloat double\n"
    if sharedDslash:
        out += "#define WRITE_SPINOR_SHARED WRITE_SPINOR_SHARED_DOUBLE2\n"
        out += "#define READ_SPINOR_SHARED READ_SPINOR_SHARED_DOUBLE2\n"
    out += component_defines(nthFloat2)
    out += "#else\n"
    out += "#define spinorFloat float\n"
    if sharedDslash:
        out += "#define WRITE_SPINOR_SHARED WRITE_SPINOR_SHARED_FLOAT4\n"
        out += "#define READ_SPINOR_SHARED READ_SPINOR_SHARED_FLOAT4\n"
    out += component_defines(nthFloat4)
    out += "#endif // SPINOR_DOUBLE\n\n"
    return out
# end def def_input_spinor
171 
172 
def def_gauge():
    """Return the macro block naming the gauge-link components.

    ``g<m><n>_re``/``g<m><n>_im`` alias consecutive float fields of ``G``:
    ``.x``/``.y`` pairs under GAUGE_FLOAT2 and ``.x``..``.w`` quads otherwise.
    The conjugated link ``gT<m><n>`` is defined as the transposed entry
    ``g<n><m>`` with its imaginary part negated (the Hermitian conjugate).
    """
    def link_defines(nth_float):
        # Row-major walk over the 3x3 link; each complex entry uses two floats.
        text = ""
        for m in range(3):
            for n in range(3):
                slot = 2 * (3 * m + n)
                text += "#define %s G%s\n" % (g_re(0, m, n), nth_float(slot))
                text += "#define %s G%s\n" % (g_im(0, m, n), nth_float(slot + 1))
        return text

    conjugate = ""
    for m in range(3):
        for n in range(3):
            conjugate += "#define %s (+%s)\n" % (g_re(1, m, n), g_re(0, n, m))
            conjugate += "#define %s (-%s)\n" % (g_im(1, m, n), g_im(0, n, m))

    # NOTE(review): the emitted "#endif // GAUGE_DOUBLE" label does not match
    # the "#ifdef GAUGE_FLOAT2" it closes; kept verbatim to keep output stable.
    return ("// gauge link\n"
            "#ifdef GAUGE_FLOAT2\n"
            + link_defines(nthFloat2)
            + "\n#else\n"
            + link_defines(nthFloat4)
            + "\n#endif // GAUGE_DOUBLE\n\n"
            "// conjugated gauge link\n"
            + conjugate
            + "\n")
# end def def_gauge
203 
204 
205 
def def_output_spinor():
    """Return declarations for both flavors' output-spinor components.

    ``o1_<s><c>_re/_im`` (flavor 1) and ``o2_<s><c>_re/_im`` (flavor 2) are
    emitted either as ``#define`` aliases into the shared array ``s`` or as
    ``VOLATILE spinorFloat`` locals.  A component is aliased into ``s`` only
    when its float index within the flavor is below the global
    ``sharedFloatsPerFlavor`` and ``sharedDslash`` is false.  Flavor 2's
    shared slots start at offset ``sharedFloatsPerFlavor``.
    """
    # sharedDslash = True: input spinors stored in shared memory
    # sharedDslash = False: output spinors stored in shared memory
    def declare(name, slot, offset):
        if slot < sharedFloatsPerFlavor and not sharedDslash:
            return "#define " + name + " s[" + repr(slot + offset) + "*SHARED_STRIDE]\n"
        return "VOLATILE spinorFloat " + name + ";\n"

    def flavor(header, name_re, name_im, offset):
        # Component i = 3*s + c occupies float slots 2*i (re) and 2*i + 1 (im).
        text = header
        for s in range(0, 4):
            for c in range(0, 3):
                i = 3*s + c
                text += declare(name_re(s, c), 2*i, offset)
                text += declare(name_im(s, c), 2*i + 1, offset)
        return text

    return (flavor("// output spinor for flavor 1\n", out1_re, out1_im, 0)
            + flavor("// output spinor for flavor 2\n", out2_re, out2_im,
                     sharedFloatsPerFlavor))
# end def def_output_spinor
236 
237 
def prolog():
    """Return the kernel prolog for the non-degenerate twisted-mass dslash.

    The prolog contains:
      * the header comment and SHARED_TMNDEG_FLOATS_PER_THREAD / FLAVORS defines;
      * the VOLATILE selection;
      * the input-spinor, gauge-link and output-spinor macro blocks;
      * the SHARED_STRIDE / shared-pointer setup when sharedFloatsPerFlavor > 0;
      * the site-index code: zeroing both flavors' outputs for the interior
        kernel, or reading the intermediate spinor of both flavors for the
        exterior kernel.

    Reads the globals ``dslash``, ``dagger``, ``sharedFloatsPerFlavor``,
    ``sharedDslash`` and ``arch``.  Exits the program with "Undefined prolog"
    when ``dslash`` is false.
    """
    # WARNING: change for twisted mass!
    if dslash:
        prolog_str = ("// *** CUDA NDEG TWISTED MASS DSLASH ***\n\n" if not dagger else "// *** CUDA NDEG TWISTED MASS DSLASH DAGGER ***\n\n")
        prolog_str += ("// Arguments (double) mu, (double)eta and (double)delta \n")
        prolog_str += "#define SHARED_TMNDEG_FLOATS_PER_THREAD " + str(2*sharedFloatsPerFlavor) + "\n"
        prolog_str += "#define FLAVORS 2\n\n"
    else:
        # A bare `exit` only names the builtin without calling it, so the old
        # code fell through and crashed with UnboundLocalError on prolog_str.
        sys.exit("Undefined prolog")

    prolog_str += (
"""
#if ((CUDA_VERSION >= 4010) && (__COMPUTE_CAPABILITY__ >= 200)) // NVVM compiler
#define VOLATILE
#else // Open64 compiler
#define VOLATILE volatile
#endif
""")

    prolog_str += def_input_spinor()
    if dslash:
        prolog_str += def_gauge()
    prolog_str += def_output_spinor()

    if sharedFloatsPerFlavor > 0:
        if arch >= 200:
            prolog_str += (
"""
#ifdef SPINOR_DOUBLE
#define SHARED_STRIDE 16 // to avoid bank conflicts on Fermi
#else
#define SHARED_STRIDE 32 // to avoid bank conflicts on Fermi
#endif
""")
        else:
            prolog_str += (
"""
#ifdef SPINOR_DOUBLE
#define SHARED_STRIDE 8 // to avoid bank conflicts on G80 and GT200
#else
#define SHARED_STRIDE 16 // to avoid bank conflicts on G80 and GT200
#endif
""")

    # set the pointer if using shared memory for pseudo registers
    if sharedFloatsPerFlavor > 0:
        prolog_str += (
"""
extern __shared__ char s_data[];
""")

        if dslash:
            prolog_str += (
"""
VOLATILE spinorFloat *s = (spinorFloat*)s_data + SHARED_TMNDEG_FLOATS_PER_THREAD*SHARED_STRIDE*(threadIdx.x/SHARED_STRIDE)
 + (threadIdx.x % SHARED_STRIDE);
""")

    if dslash:
        prolog_str += (
"""
#include "read_gauge.h"
#include "io_spinor.h"

int coord[5];
int X;

int sid;
""")

        if sharedDslash:
            prolog_str += (
"""
#ifdef MULTI_GPU
int face_idx;
if (kernel_type == INTERIOR_KERNEL) {
#endif

 // Assume even dimensions
 coordsFromIndex3D<EVEN_X>(X, coord, sid, param);

 // only need to check Y and Z dims currently since X and T set to match exactly
 if (coord[1] >= param.dc.X[1]) return;
 if (coord[2] >= param.dc.X[2]) return;

""")
        else:
            prolog_str += (
"""
#ifdef MULTI_GPU
int face_idx;
if (kernel_type == INTERIOR_KERNEL) {
#endif

 sid = blockIdx.x*blockDim.x + threadIdx.x;
 if (sid >= param.threads) return;

 // Assume even dimensions
 coordsFromIndex<4,QUDA_4D_PC,EVEN_X>(X, coord, sid, param);

""")

        # Interior kernel: zero both flavors' output components.
        out = ""
        for s in range(0, 4):
            for c in range(0, 3):
                out += out1_re(s, c) + " = 0; " + out1_im(s, c) + " = 0;\n"

        out += "\n"

        for s in range(0, 4):
            for c in range(0, 3):
                out += out2_re(s, c) + " = 0; " + out2_im(s, c) + " = 0;\n"

        prolog_str += indent(out)

        prolog_str += (
"""
#ifdef MULTI_GPU
} else { // exterior kernel

 sid = blockIdx.x*blockDim.x + threadIdx.x;
 if (sid >= param.threads) return;

 const int face_volume = (param.threads >> 1); // volume of one face (per flavor)
 const int face_num = (sid >= face_volume); // is this thread updating face 0 or 1
 face_idx = sid - face_num*face_volume; // index into the respective face

 // ghostOffset is scaled to include body (includes stride) and number of FloatN arrays (SPINOR_HOP)
 // face_idx not sid since faces are spin projected and share the same volume index (modulo UP/DOWN reading)
 //sp_idx = face_idx + param.ghostOffset[dim];

 coordsFromFaceIndex<4,QUDA_4D_PC,kernel_type,1>(X, sid, coord, face_idx, face_num, param);

""")

        def copy_input(name_re, name_im):
            # Copy the freshly read input components into one flavor's outputs.
            text = " "
            for s in range(0, 4):
                for c in range(0, 3):
                    text += name_re(s, c) + " = " + in_re(s, c) + "; " + name_im(s, c) + " = " + in_im(s, c) + ";\n "
            return text

        # flavor 1
        prolog_str += (
"""
 {
 READ_INTERMEDIATE_SPINOR(INTERTEX, param.sp_stride, sid, sid);
""")
        prolog_str += indent(copy_input(out1_re, out1_im))

        # flavor 2
        prolog_str += (
"""
 }
 {
 READ_INTERMEDIATE_SPINOR(INTERTEX, param.sp_stride, sid+param.fl_stride, sid+param.fl_stride);
""")
        prolog_str += indent(copy_input(out2_re, out2_im))
        prolog_str += (
"""
 }
""")

        prolog_str += "}\n"
        prolog_str += "#endif // MULTI_GPU\n\n\n"

    else:
        # NOTE(review): unreachable while the non-dslash case above exits;
        # kept as the reference non-dslash prolog.
        prolog_str += (
"""
#include "io_spinor.h"

int sid = blockIdx.x*blockDim.x + threadIdx.x;
if (sid >= param.threads) return;

// read spinor from device memory
READ_SPINOR(SPINORTEX, param.sp_stride, sid, sid);
""")
    return prolog_str
# end def prolog
424 
425 
def gen(dir, pack_only=False):
    """Generate the CUDA code for the hop in direction ``dir`` for both flavors.

    ``dir // 2`` selects the dimension (X, Y, Z, T).  ``dir % 2`` selects the
    forward (even: ``X+...`` neighbour) or backward (odd: ``X-...``) hop.
    For the dagger operator the partner projector of the pair is used.  Both
    flavors are gathered, projected, multiplied by the same gauge link (the
    identity in the gauge-fixed temporal case) and accumulated into the
    ``o1_*`` / ``o2_*`` output registers.

    Args:
        dir: hop index; the tables below cover 0..7.
        pack_only: not supported by this two-flavor generator (see Raises).

    Returns:
        The generated C code for this hop, guarded by the MULTI_GPU
        interior/exterior-kernel condition.

    Raises:
        NotImplementedError: if ``pack_only`` is true.
    """
    if pack_only:
        # The pack-only path used to reference `load_spinor`, which this
        # generator never defines, so it always died with NameError.  Fail
        # explicitly rather than guess a two-flavor packing layout.
        raise NotImplementedError(
            "gen(pack_only=True) is not supported by the non-degenerate "
            "twisted-mass dslash generator")

    projIdx = dir if not dagger else dir + (1 - 2*(dir % 2))
    projStr = projectorToStr(projectors[projIdx])

    def proj(i, j):
        return projectors[projIdx][4*i + j]

    # if row(i) = (j, c), then the i'th row of the projector can be represented
    # as a multiple of the j'th row: row(i) = c row(j)
    def row(i):
        assert i == 2 or i == 3
        if proj(i, 0) == 0j:
            return (1, proj(i, 1))
        if proj(i, 1) == 0j:
            return (0, proj(i, 0))

    boundary = ["coord[0]==(param.dc.X[0]-1)", "coord[0]==0", "coord[1]==(param.dc.X[1]-1)", "coord[1]==0",
                "coord[2]==(param.dc.X[2]-1)", "coord[2]==0", "coord[3]==(param.dc.X[3]-1)", "coord[3]==0"]
    interior = ["coord[0]<(param.dc.X[0]-1)", "coord[0]>0", "coord[1]<(param.dc.X[1]-1)", "coord[1]>0",
                "coord[2]<(param.dc.X[2]-1)", "coord[2]>0", "coord[3]<(param.dc.X[3]-1)", "coord[3]>0"]
    dim = ["X", "Y", "Z", "T"]

    # index of neighboring site when not on boundary
    sp_idx = ["X+1", "X-1", "X+param.dc.X[0]", "X-param.dc.X[0]", "X+param.dc.X2X1", "X-param.dc.X2X1",
              "X+param.dc.X3X2X1", "X-param.dc.X3X2X1"]

    # index of neighboring site (across boundary)
    sp_idx_wrap = ["X-(param.dc.X[0]-1)", "X+(param.dc.X[0]-1)", "X-param.dc.X2X1mX1", "X+param.dc.X2X1mX1",
                   "X-param.dc.X3X2X1mX2X1", "X+param.dc.X3X2X1mX2X1",
                   "X-param.dc.X4X3X2X1mX3X2X1", "X+param.dc.X4X3X2X1mX3X2X1"]

    cond = ""
    cond += "#ifdef MULTI_GPU\n"
    cond += "if ( (kernel_type == INTERIOR_KERNEL && (!param.ghostDim[" + repr(dir // 2) + "] || " + interior[dir] + ")) ||\n"
    cond += " (kernel_type == EXTERIOR_KERNEL_" + dim[dir // 2] + " && " + boundary[dir] + ") )\n"
    cond += "#endif\n"

    body = ""

    projName = "P" + repr(dir // 2) + ["-", "+"][projIdx % 2]
    body += "// Projector " + projName + "\n"
    for line in projStr.splitlines():
        body += "// " + line + "\n"
    body += "\n"

    body += "#ifdef MULTI_GPU\n"
    body += "const int sp_idx = (kernel_type == INTERIOR_KERNEL) ? (" + boundary[dir] + " ? " + sp_idx_wrap[dir] + " : " + sp_idx[dir] + ") >> 1 :\n"
    body += " face_idx + param.ghostOffset[static_cast<int>(kernel_type)][" + repr((dir + 1) % 2) + "];\n"
    body += "#if (DD_PREC==2) // half precision\n"
    body += "const int sp_norm_idx = face_idx + param.ghostNormOffset[static_cast<int>(kernel_type)][" + repr((dir + 1) % 2) + "];\n"
    body += "#endif\n"
    body += "#else\n"
    body += "const int sp_idx = (" + boundary[dir] + " ? " + sp_idx_wrap[dir] + " : " + sp_idx[dir] + ") >> 1;\n"
    body += "#endif\n"

    body += "\n"
    if dir % 2 == 0:
        body += "const int ga_idx = sid;\n"
    else:
        body += "#ifdef MULTI_GPU\n"
        body += "const int ga_idx = ((kernel_type == INTERIOR_KERNEL) ? sp_idx : param.dc.Vh+face_idx);\n"
        body += "#else\n"
        body += "const int ga_idx = sp_idx;\n"
        body += "#endif\n"
    body += "\n"

    # scan the projector to determine which loads are required:
    # row_cnt[0] counts nonzeros in rows 0-1, row_cnt[2] in rows 2-3
    row_cnt = [0, 0, 0, 0]
    for h in range(0, 4):
        for s in range(0, 4):
            if proj(h, s) != 0:
                row_cnt[h] += 1
    row_cnt[0] += row_cnt[1]
    row_cnt[2] += row_cnt[3]

    decl_half = ""
    for h in range(0, 2):
        for c in range(0, 3):
            decl_half += "spinorFloat " + h1_re(h, c) + ", " + h1_im(h, c) + ";\n"
    decl_half += "\n"

    load_gauge = "// read gauge matrix from device memory\n"
    load_gauge += "READ_GAUGE_MATRIX(G, GAUGE" + repr(dir % 2) + "TEX, " + repr(dir) + ", ga_idx, param.gauge_stride);\n\n"

    reconstruct_gauge = "// reconstruct gauge matrix\n"
    reconstruct_gauge += "RECONSTRUCT_GAUGE_MATRIX(" + repr(dir) + ");\n\n"

    # Only the half of the spinor the projector touches needs to be read.
    if row_cnt[0] == 0:
        read_macro = "READ_SPINOR_DOWN"
    elif row_cnt[2] == 0:
        read_macro = "READ_SPINOR_UP"
    else:
        read_macro = "READ_SPINOR"
    load_flv1 = ("// read flavor 1 from device memory\n"
                 + read_macro + "(SPINORTEX, param.sp_stride, sp_idx, sp_idx);\n\n")
    load_flv2 = ("// read flavor 2 from device memory\n"
                 + read_macro + "(SPINORTEX, param.sp_stride, sp_idx+param.fl_stride, sp_idx+param.fl_stride);\n\n")

    load_half_cond = ""
    load_half_cond += "const int sp_stride_pad = FLAVORS*param.dc.ghostFace[static_cast<int>(kernel_type)];\n"
    if dir >= 6:
        load_half_cond += "const int t_proj_scale = TPROJSCALE;\n"
    load_half_cond += "\n"

    # Forward and backward ghost gathers share the same volume index; the
    # second flavor's ghost zone starts one ghostFace further on.
    load_half_flv1 = "// read half spinor for the first flavor from device memory\n"
    load_half_flv1 += "READ_SPINOR_GHOST(GHOSTSPINORTEX, sp_stride_pad, sp_idx, sp_norm_idx, " + repr(dir) + ");\n\n"

    load_half_flv2 = "// read half spinor for the second flavor from device memory\n"
    load_half_flv2 += "const int fl_idx = sp_idx + param.dc.ghostFace[static_cast<int>(kernel_type)];\n"
    load_half_flv2 += "READ_SPINOR_GHOST(GHOSTSPINORTEX, sp_stride_pad, fl_idx, sp_norm_idx+param.dc.ghostFace[static_cast<int>(kernel_type)]," + repr(dir) + ");\n\n"

    def project_terms(r, c):
        # Sum of projector row r applied to color c of the input spinor.
        # Entries with both real and imaginary parts nonzero are skipped
        # (none occur in the projectors built from the gamma matrices).
        terms_re = ""
        terms_im = ""
        for s in range(0, 4):
            re = proj(r, s).real
            im = proj(r, s).imag
            if re == 0 and im == 0:
                continue
            elif im == 0:
                terms_re += sign(re) + in_re(s, c)
                terms_im += sign(re) + in_im(s, c)
            elif re == 0:
                # (i*im) * (x + i*y) = -im*y + i*(im*x)
                terms_re += sign(-im) + in_im(s, c)
                terms_im += sign(im) + in_re(s, c)
        return terms_re, terms_im

    project = "// project spinor into half spinors\n"
    for h in range(0, 2):
        for c in range(0, 3):
            strRe, strIm = project_terms(h, c)
            if row_cnt[0] == 0:  # projector defined on lower half only
                lower_re, lower_im = project_terms(h + 2, c)
                strRe += lower_re
                strIm += lower_im
            project += h1_re(h, c) + " = " + strRe + ";\n"
            project += h1_im(h, c) + " = " + strIm + ";\n"

    # Ghost zones hold already-projected half spinors: copy (and, for the
    # temporal direction, rescale) instead of projecting.
    t_scale = "t_proj_scale*" if dir >= 6 else ""
    copy_half = ""
    for h in range(0, 2):
        for c in range(0, 3):
            copy_half += h1_re(h, c) + " = " + t_scale + in_re(h, c) + "; "
            copy_half += h1_im(h, c) + " = " + t_scale + in_im(h, c) + ";\n"
    copy_half += "\n"

    prep_half_cond1 = "#ifdef MULTI_GPU\n"
    prep_half_cond1 += "if (kernel_type == INTERIOR_KERNEL) {\n"
    prep_half_cond1 += "#endif\n"
    prep_half_cond1 += "\n"

    prep_half_flv1 = indent(load_flv1) + indent(project)
    prep_half_flv2 = indent(load_flv2) + indent(project)

    prep_half_cond2 = "\n"
    prep_half_cond2 += "#ifdef MULTI_GPU\n"
    prep_half_cond2 += "} else {\n"
    prep_half_cond2 += "\n"

    prep_face_flv1 = indent(load_half_flv1)
    prep_face_flv2 = indent(load_half_flv2)

    prep_half = indent(copy_half)

    prep_half_cond3 = "}\n"
    prep_half_cond3 += "#endif // MULTI_GPU\n"
    prep_half_cond3 += "\n"

    ident = "// identity gauge matrix\n"
    for m in range(0, 3):
        for h in range(0, 2):
            ident += "spinorFloat " + h2_re(h, m) + " = " + h1_re(h, m) + "; "
            ident += "spinorFloat " + h2_im(h, m) + " = " + h1_im(h, m) + ";\n"
    ident += "\n"

    # Complex matrix-vector product: A/B = G * a/b.
    mult = ""
    for m in range(0, 3):
        mult += "// multiply row " + repr(m) + "\n"
        for h in range(0, 2):
            re = "spinorFloat " + h2_re(h, m) + " = 0;\n"
            im = "spinorFloat " + h2_im(h, m) + " = 0;\n"
            for c in range(0, 3):
                re += h2_re(h, m) + " += " + g_re(dir, m, c) + " * " + h1_re(h, c) + ";\n"
                re += h2_re(h, m) + " -= " + g_im(dir, m, c) + " * " + h1_im(h, c) + ";\n"
                im += h2_im(h, m) + " += " + g_re(dir, m, c) + " * " + h1_im(h, c) + ";\n"
                im += h2_im(h, m) + " += " + g_im(dir, m, c) + " * " + h1_re(h, c) + ";\n"
            mult += re + im
        mult += "\n"

    def reconstruct(name_re, name_im):
        # Accumulate the multiplied half spinor into one flavor's outputs and
        # rebuild the remaining rows from row(s) = c * row(h).
        code = ""
        for m in range(0, 3):
            for h in range(0, 2):
                h_out = h + 2 if row_cnt[0] == 0 else h  # projector defined on lower half only
                code += name_re(h_out, m) + " += " + h2_re(h, m) + ";\n"
                code += name_im(h_out, m) + " += " + h2_im(h, m) + ";\n"

            for s in range(2, 4):
                (h, c) = row(s)
                re = c.real
                im = c.imag
                if im == 0 and re == 0:
                    pass
                elif im == 0:
                    code += name_re(s, m) + " " + sign(re) + "= " + h2_re(h, m) + ";\n"
                    code += name_im(s, m) + " " + sign(re) + "= " + h2_im(h, m) + ";\n"
                elif re == 0:
                    code += name_re(s, m) + " " + sign(-im) + "= " + h2_im(h, m) + ";\n"
                    code += name_im(s, m) + " " + sign(+im) + "= " + h2_re(h, m) + ";\n"

            code += "\n"
        return code

    reconstruct_flv1 = reconstruct(out1_re, out1_im)
    reconstruct_flv2 = reconstruct(out2_re, out2_im)

    def flavor_block(prep_half_flv, prep_face_flv, gauge_op, reconstruct_flv):
        # One flavor: interior load+project / exterior ghost read, then the
        # link multiplication and accumulation, in its own C scope.
        return ("{\n" + prep_half_cond1 + prep_half_flv + prep_half_cond2 + load_half_cond
                + prep_face_flv + prep_half + prep_half_cond3 + gauge_op + reconstruct_flv + "}\n")

    if dir >= 6:
        body += decl_half
        body += "if (param.gauge_fixed && ga_idx < param.dc.X4X3X2X1hmX3X2X1h)\n"
        body += block(flavor_block(prep_half_flv1, prep_face_flv1, ident, reconstruct_flv1)
                      + flavor_block(prep_half_flv2, prep_face_flv2, ident, reconstruct_flv2))
        body += " else "
        body += block(load_gauge + reconstruct_gauge
                      + flavor_block(prep_half_flv1, prep_face_flv1, mult, reconstruct_flv1)
                      + flavor_block(prep_half_flv2, prep_face_flv2, mult, reconstruct_flv2))
    else:
        body += decl_half + load_gauge + reconstruct_gauge
        body += flavor_block(prep_half_flv1, prep_face_flv1, mult, reconstruct_flv1)
        body += flavor_block(prep_half_flv2, prep_face_flv2, mult, reconstruct_flv2)

    return cond + block(body) + "\n\n"
# end def gen
763 
764 
def twisted():
    """Return the DSLASH_TWIST-guarded block that applies the twist rotation.

    As the emitted comment states, the rotation is
    (1 - i*a*gamma_5*tau_3 + b*tau_1), or (1 + i*a*gamma_5*tau_3 + b*tau_1)
    when the global ``dagger`` is set.  It mixes the o1_* (flavor 1) and
    o2_* (flavor 2) output registers in place.  ``a`` and ``b`` are read from
    ``param`` (``a_f``/``b_f`` in single precision).
    """
    # Sign of the a-term on the first flavor (toward_1) and the second (toward_2).
    toward_1, toward_2 = (" - a *", " + a *") if dagger else (" + a *", " - a *")

    parts = [
        "#ifdef SPINOR_DOUBLE\n",
        "const spinorFloat a = param.a;\n",
        "const spinorFloat b = param.b;\n",
        "#else\n",
        "const spinorFloat a = param.a_f;\n",
        "const spinorFloat b = param.b_f;\n",
        "#endif\n",
        "//Perform twist rotation first:\n",
        ("//(1 + i*a*gamma_5 * tau_3 + b * tau_1)\n" if dagger
         else "//(1 - i*a*gamma_5 * tau_3 + b * tau_1)\n"),
        "volatile spinorFloat x1_re, x1_im, y1_re, y1_im;\n",
        "volatile spinorFloat x2_re, x2_im, y2_re, y2_im;\n\n",
        "x1_re = 0.0, x1_im = 0.0;\n",
        "y1_re = 0.0, y1_im = 0.0;\n",
        "x2_re = 0.0, x2_im = 0.0;\n",
        "y2_re = 0.0, y2_im = 0.0;\n\n\n",
    ]

    for c in range(3):
        for up in range(2):
            down = up + 2  # spin component paired with `up` by gamma_5
            parts += [
                "// using o1 regs:\n",
                "x1_re = " + out1_re(up, c) + toward_1 + out1_im(down, c) + ";\n",
                "x1_im = " + out1_im(up, c) + toward_2 + out1_re(down, c) + ";\n",
                "x2_re = b * " + out1_re(up, c) + ";\n",
                "x2_im = b * " + out1_im(up, c) + ";\n\n",
                "y1_re = " + out1_re(down, c) + toward_1 + out1_im(up, c) + ";\n",
                "y1_im = " + out1_im(down, c) + toward_2 + out1_re(up, c) + ";\n",
                "y2_re = b * " + out1_re(down, c) + ";\n",
                "y2_im = b * " + out1_im(down, c) + ";\n\n\n",
                "// using o2 regs:\n",
                "x2_re += " + out2_re(up, c) + toward_2 + out2_im(down, c) + ";\n",
                "x2_im += " + out2_im(up, c) + toward_1 + out2_re(down, c) + ";\n",
                "x1_re += b * " + out2_re(up, c) + ";\n",
                "x1_im += b * " + out2_im(up, c) + ";\n\n",
                "y2_re += " + out2_re(down, c) + toward_2 + out2_im(up, c) + ";\n",
                "y2_im += " + out2_im(down, c) + toward_1 + out2_re(up, c) + ";\n",
                "y1_re += b * " + out2_re(down, c) + ";\n",
                "y1_im += b * " + out2_im(down, c) + ";\n",
                "\n\n",
                out1_re(up, c) + " = x1_re; " + out1_im(up, c) + " = x1_im;\n",
                out1_re(down, c) + " = y1_re; " + out1_im(down, c) + " = y1_im;\n",
                "\n",
                out2_re(up, c) + " = x2_re; " + out2_im(up, c) + " = x2_im;\n",
                out2_re(down, c) + " = y2_re; " + out2_im(down, c) + " = y2_im;\n\n",
            ]

    return "#ifdef DSLASH_TWIST\n" + block("".join(parts)) + "\n#endif\n"
# end def twisted
831 
832 
833 def xpay():
834  str = "\n"
835 
836  str += "#if !defined(DSLASH_XPAY) || defined(DSLASH_TWIST)\n"
837  str += "#ifdef SPINOR_DOUBLE\n"
838  str += "const spinorFloat c = param.c;\n"
839  str += "#else\n"
840  str += "const spinorFloat c = param.c_f;\n"
841  str += "#endif\n"
842  str += "#endif\n"
843 
844  str += "#ifndef DSLASH_XPAY\n"
845 
846  for s in range(0,4):
847  for c in range(0,3):
848  i = 3*s+c
849  str += out1_re(s,c) +" *= c;\n"
850  str += out1_im(s,c) +" *= c;\n"
851  str += "\n"
852 
853  for s in range(0,4):
854  for c in range(0,3):
855  i = 3*s+c
856  str += out2_re(s,c) +" *= c;\n"
857  str += out2_im(s,c) +" *= c;\n"
858 
859 
860  str += "#else\n"
861 
862  str += "#ifdef DSLASH_TWIST\n"
863  str += "// accum spinor\n"
864  str += "#ifdef SPINOR_DOUBLE\n"
865  str += "\n"
866  for s in range(0,4):
867  for c in range(0,3):
868  i = 3*s+c
869  str += "#define "+acc_re(s,c)+" accum"+nthFloat2(2*i+0)+"\n"
870  str += "#define "+acc_im(s,c)+" accum"+nthFloat2(2*i+1)+"\n"
871  str += "\n"
872  str += "#else\n"
873  for s in range(0,4):
874  for c in range(0,3):
875  i = 3*s+c
876  str += "#define "+acc_re(s,c)+" accum"+nthFloat4(2*i+0)+"\n"
877  str += "#define "+acc_im(s,c)+" accum"+nthFloat4(2*i+1)+"\n"
878  str += "\n"
879  str += "#endif // SPINOR_DOUBLE\n\n"
880  str += "{\n"
881  str += " READ_ACCUM(ACCUMTEX, param.sp_stride)\n\n"
882  for s in range(0,4):
883  for c in range(0,3):
884  i = 3*s+c
885  str += " " + out1_re(s,c) +" = c*"+out1_re(s,c)+ " + "+ acc_re(s,c)+";\n"
886  str += " " + out1_im(s,c) +" = c*"+out1_im(s,c)+ " + "+ acc_im(s,c)+";\n"
887  str += "\n"
888  str += " ASSN_ACCUM(ACCUMTEX, param.sp_stride, param.fl_stride)\n\n"
889  for s in range(0,4):
890  for c in range(0,3):
891  i = 3*s+c
892  str += " " + out2_re(s,c) +" = c*"+out2_re(s,c)+ " + "+ acc_re(s,c)+";\n"
893  str += " " + out2_im(s,c) +" = c*"+out2_im(s,c)+ " + "+ acc_im(s,c)+";\n"
894  str += "}\n"
895  str += "\n"
896  for s in range(0,4):
897  for c in range(0,3):
898  i = 3*s+c
899  str += "#undef "+acc_re(s,c)+"\n"
900  str += "#undef "+acc_im(s,c)+"\n"
901  str += "\n"
902  str += "#else\n"
903 
904  str += "// accum spinor\n"
905  str += "#ifdef SPINOR_DOUBLE\n"
906  str += "\n"
907  for s in range(0,4):
908  for c in range(0,3):
909  i = 3*s+c
910  str += "#define "+acc1_re(s,c)+" flv1_accum"+nthFloat2(2*i+0)+"\n"
911 
912  str += "#define "+acc1_im(s,c)+" flv1_accum"+nthFloat2(2*i+1)+"\n"
913  str += "\n"
914  for s in range(0,4):
915  for c in range(0,3):
916  i = 3*s+c
917  str += "#define "+acc2_re(s,c)+" flv2_accum"+nthFloat2(2*i+0)+"\n"
918  str += "#define "+acc2_im(s,c)+" flv2_accum"+nthFloat2(2*i+1)+"\n"
919  str += "\n"
920  str += "#else\n"
921  str += "\n"
922  for s in range(0,4):
923  for c in range(0,3):
924  i = 3*s+c
925  str += "#define "+acc1_re(s,c)+" flv1_accum"+nthFloat4(2*i+0)+"\n"
926  str += "#define "+acc1_im(s,c)+" flv1_accum"+nthFloat4(2*i+1)+"\n"
927  str += "\n"
928  for s in range(0,4):
929  for c in range(0,3):
930  i = 3*s+c
931  str += "#define "+acc2_re(s,c)+" flv2_accum"+nthFloat4(2*i+0)+"\n"
932  str += "#define "+acc2_im(s,c)+" flv2_accum"+nthFloat4(2*i+1)+"\n"
933  str += "\n"
934  str += "#endif // SPINOR_DOUBLE\n\n"
935 
936  str += "{\n"
937 
938  str += " READ_ACCUM_FLAVOR(ACCUMTEX, param.sp_stride, param.fl_stride)\n\n"
939 
940  str += "#ifdef SPINOR_DOUBLE\n"
941  str += "const spinorFloat a = param.a;\n"
942  str += "const spinorFloat b = param.b;\n"
943  str += "#else\n"
944  str += "const spinorFloat a = param.a_f;\n"
945  str += "const spinorFloat b = param.b_f;\n"
946  str += "#endif\n"
947 
948  str += " //Perform twist rotation:\n"
949  if dagger :
950  str += "//(1 + i*a*gamma_5 * tau_3 + b * tau_1)\n"
951  else:
952  str += "//(1 - i*a*gamma_5 * tau_3 + b * tau_1)\n"
953  str += " volatile spinorFloat x1_re, x1_im, y1_re, y1_im;\n"
954  str += " volatile spinorFloat x2_re, x2_im, y2_re, y2_im;\n\n"
955 
956  str += " x1_re = 0.0, x1_im = 0.0;\n"
957  str += " y1_re = 0.0, y1_im = 0.0;\n"
958  str += " x2_re = 0.0, x2_im = 0.0;\n"
959  str += " y2_re = 0.0, y2_im = 0.0;\n\n\n"
960 
961  a1 = ""
962  a2 = ""
963 
964  if dagger :
965  a1 += " - a *"
966  a2 += " + a *"
967  else:
968  a1 += " + a *"
969  a2 += " - a *"
970 
971  for c in range(0,3):
972  for h in range(0,2):
973  #h, h+2
974  str += " // using acc1 regs:\n"
975  str += " x1_re = " + acc1_re(h,c) + a1 + acc1_im(h+2,c) + ";\n"
976  str += " x1_im = " + acc1_im(h,c) + a2 + acc1_re(h+2,c) + ";\n"
977  str += " x2_re = " + "b * " + acc1_re(h,c) + ";\n"
978  str += " x2_im = " + "b * " + acc1_im(h,c) + ";\n\n"
979  str += " y1_re = " + acc1_re(h+2,c) + a1 + acc1_im(h,c) + ";\n"
980  str += " y1_im = " + acc1_im(h+2,c) + a2 + acc1_re(h,c) + ";\n"
981  str += " y2_re = " + "b * " + acc1_re(h+2,c) + ";\n"
982  str += " y2_im = " + "b * " + acc1_im(h+2,c) + ";\n\n\n"
983  str += " // using acc2 regs:\n"
984  str += " x2_re += " + acc2_re(h,c) + a2 + acc2_im(h+2,c) + ";\n"
985  str += " x2_im += " + acc2_im(h,c) + a1 + acc2_re(h+2,c) + ";\n"
986  str += " x1_re += " + "b * " + acc2_re(h,c) + ";\n"
987  str += " x1_im += " + "b * " + acc2_im(h,c) + ";\n\n"
988  str += " y2_re += " + acc2_re(h+2,c) + a2 + acc2_im(h,c) + ";\n"
989  str += " y2_im += " + acc2_im(h+2,c) + a1 + acc2_re(h,c) + ";\n"
990  str += " y1_re += " + "b * " + acc2_re(h+2,c) + ";\n"
991  str += " y1_im += " + "b * " + acc2_im(h+2,c) + ";\n"
992  str += "\n\n"
993  str += acc1_re(h,c) + " = x1_re; " + acc1_im(h,c) + " = x1_im;\n"
994  str += acc1_re(h+2,c) + " = y1_re; " + acc1_im(h+2,c) + " = y1_im;\n"
995  str += "\n"
996  str += acc2_re(h,c) + " = x2_re; " + acc2_im(h,c) + " = x2_im;\n"
997  str += acc2_re(h+2,c) + " = y2_re; " + acc2_im(h+2,c) + " = y2_im;\n\n"
998 
999 
1000  str += "#ifdef SPINOR_DOUBLE\n"
1001  str += "const spinorFloat k = param.d;\n"
1002  str += "#else\n"
1003  str += "const spinorFloat k = param.d_f;\n"
1004  str += "#endif\n"
1005 
1006  for s in range(0,4):
1007  for c in range(0,3):
1008  i = 3*s+c
1009  str += " " + out1_re(s,c) +" = k*"+out1_re(s,c) + " + "+ acc1_re(s,c)+";\n"
1010  str += " " + out1_im(s,c) +" = k*"+out1_im(s,c) + " + "+ acc1_im(s,c)+ ";\n"
1011 
1012  str += "\n"
1013 
1014  for s in range(0,4):
1015  for c in range(0,3):
1016  i = 3*s+c
1017  str += " " + out2_re(s,c) +" = k*"+out2_re(s,c) + " + "+ acc2_re(s,c)+ ";\n"
1018  str += " " + out2_im(s,c) +" = k*"+out2_im(s,c) + " + "+ acc2_im(s,c)+ ";\n"
1019 
1020  str += "}\n"
1021  str += "\n"
1022  for s in range(0,4):
1023  for c in range(0,3):
1024  i = 3*s+c
1025  str += "#undef "+acc1_re(s,c)+"\n"
1026  str += "#undef "+acc1_im(s,c)+"\n"
1027  str += "\n"
1028  for s in range(0,4):
1029  for c in range(0,3):
1030  i = 3*s+c
1031  str += "#undef "+acc2_re(s,c)+"\n"
1032  str += "#undef "+acc2_im(s,c)+"\n"
1033  str += "\n"
1034  str += "#endif//DSLASH_TWIST\n"
1035  str += "\n"
1036  str += "#endif // DSLASH_XPAY\n"
1037 
1038  return str
1039 # end def xpay
1040 
1041 
def epilog():
    """Return the closing section of the generated twisted-mass dslash kernel.

    The emitted code is, in order:
      * (only when ``dslash`` is set) a MULTI_GPU section that sets the
        ``incomplete`` flag for sites on a communicated boundary, followed by
        ``if (!incomplete)`` guarding the next block;
      * the twisted-mass rotation / xpay code produced by ``twisted()`` and
        ``xpay()``, passed through ``block()``;
      * ``WRITE_FLAVOR_SPINOR();`` to store the result;
      * ``#undef`` lines for the macros defined earlier in the header, so the
        header can be included again for another precision without
        redefinition warnings.

    Reads the module-level flags ``dslash``, ``sharedDslash`` and
    ``sharedFloatsPerFlavor``.

    Returns:
        The generated C/CUDA source text as a single string.
    """
    # Collect fragments and join once instead of repeated string `+=`;
    # this also avoids shadowing the builtin `str`.
    parts = []
    emit = parts.append

    if dslash:
        emit("#ifdef MULTI_GPU\n")
        emit(
"""
int incomplete = 0; // Have all 8 contributions been computed for this site?

switch(kernel_type) { // intentional fall-through
case INTERIOR_KERNEL:
  incomplete = incomplete || (param.commDim[3] && (coord[3]==0 || coord[3]==(param.dc.X[3]-1)));
case EXTERIOR_KERNEL_T:
  incomplete = incomplete || (param.commDim[2] && (coord[2]==0 || coord[2]==(param.dc.X[2]-1)));
case EXTERIOR_KERNEL_Z:
  incomplete = incomplete || (param.commDim[1] && (coord[1]==0 || coord[1]==(param.dc.X[1]-1)));
case EXTERIOR_KERNEL_Y:
  incomplete = incomplete || (param.commDim[0] && (coord[0]==0 || coord[0]==(param.dc.X[0]-1)));
}

""")
        emit("\n")
        # The emitted `if` applies to the block emitted right after the
        # #endif, so sites still missing contributions skip the rotation.
        emit("if (!incomplete)\n")
        emit("#endif // MULTI_GPU\n")
        emit("// apply twisted mass rotation\n")
    emit(block("\n" + twisted() + xpay()))

    emit("\n\n")
    emit("// write spinor field back to device memory\n")
    emit("WRITE_FLAVOR_SPINOR();\n\n")

    emit("// undefine to prevent warning when precision is changed\n")
    emit("#undef spinorFloat\n")
    if sharedDslash:
        emit("#undef WRITE_SPINOR_SHARED\n")
        emit("#undef READ_SPINOR_SHARED\n")
    if sharedFloatsPerFlavor > 0:
        emit("#undef SHARED_STRIDE\n\n")

    if dslash:
        # 3x3 g_re/g_im(0, m, n) macros (presumably the gauge-link matrix
        # elements; names defined elsewhere in this generator).
        for m in range(0, 3):
            for n in range(0, 3):
                emit("#undef " + g_re(0, m, n) + "\n")
                emit("#undef " + g_im(0, m, n) + "\n")
        emit("\n")

    # Input spinor macros: 4 spin x 3 color components.
    for s in range(0, 4):
        for c in range(0, 3):
            emit("#undef " + in_re(s, c) + "\n")
            emit("#undef " + in_im(s, c) + "\n")
    emit("\n")

    # FIXME (carried over from the original "#fixme"): only flavor-1 output
    # macros are undefined here, and only for float slots below
    # sharedFloatsPerFlavor. Whether the out2_* macros need the same
    # treatment depends on how they are defined earlier -- TODO confirm.
    for s in range(0, 4):
        for c in range(0, 3):
            i = 3 * s + c  # flat spinor component index; 2*i / 2*i+1 = re/im slot
            if 2 * i < sharedFloatsPerFlavor:
                emit("#undef " + out1_re(s, c) + "\n")
            if 2 * i + 1 < sharedFloatsPerFlavor:
                emit("#undef " + out1_im(s, c) + "\n")
    emit("\n")

    emit("#undef VOLATILE\n")

    return "".join(parts)
# end def epilog
1107 
1108 
1109 
1111  return prolog() + gen(0) + gen(1) + gen(2) + gen(3) + gen(4) + gen(5) + gen(6) + gen(7) + epilog()
1112 # return prolog() + epilog()
1113 
1114 
1115 # generate Wilson-like Dslash kernels
1117  print "Generating dslash kernel for sm" + str(arch/10)
1118 
1119  global sharedFloatsPerFlavor
1120  global sharedDslash
1121  global dslash
1122  global dagger
1123  global twist
1124 
1125  sharedFloatsPerFlavor = 0
1126  if arch >= 200:
1127  sharedFloatsPerFlavor = 0
1128  #sharedDslash = True
1129  sharedDslash = False
1130  name = "fermi"
1131  elif arch >= 120:
1132  sharedFloatsPerFlavor = 0
1133  sharedDslash = False
1134  name = "gt200"
1135  else:
1136  sharedFloatsPerFlavor = 19
1137  sharedDslash = False
1138  name = "g80"
1139 
1140  print "Shared floats set to " + str(sharedFloatsPerFlavor)
1141 
1142  dslash = True
1143  twist = False
1144  dagger = False
1145 
1146  twist = True
1147  dagger = False
1148  filename = 'dslash_core/tm_ndeg_dslash_core.h'
1149  print sys.argv[0] + ": generating " + filename;
1150  f = open(filename, 'w')
1151  f.write(generate_dslash())
1152  f.close()
1153 
1154  dagger = True
1155  filename = 'dslash_core/tm_ndeg_dslash_dagger_core.h'
1156  print sys.argv[0] + ": generating " + filename + "\n";
1157  f = open(filename, 'w')
1158  f.write(generate_dslash())
1159  f.close()
1160 
1161  dslash = False
1162 
1163 
1164 
# Default generator state. These module-level flags are read by the
# code-emitting functions (epilog() checks dslash, sharedDslash and
# sharedFloatsPerFlavor) and are rebound through `global` declarations by
# the kernel-generation routine before each header is written.
dslash = dagger = twist = False
sharedFloatsPerFlavor = 0
sharedDslash = False

# generate dslash kernels
# Target CUDA architecture code: 200 is reported as "sm20" and selects the
# "fermi" settings in the kernel-generation routine (100 would select "g80").
arch = 200
1177 
1178 #arch = 100
1179 #generate_dslash_kernels(arch)
def indent(code)
code generation ######################################################################## ...
def gen(dir, pack_only=False)
Definition: gen.py:1
def complexify(a)
complex numbers ######################################################################## ...