QUDA  0.9.0
dslash_cuda_gen.py
Go to the documentation of this file.
1 import sys
2 
3 
4 
def complexify(a):
    """Return a new list holding every element of *a* converted to complex."""
    return list(map(complex, a))
7 
def complexToStr(c):
    """Render a complex number compactly for the projector comments.

    Integral parts print without a decimal point, a unit imaginary part
    prints as ``i`` / ``-i``, and zero prints as ``0``.  Examples:
    ``1``, ``-i``, ``2i``, ``1-i``, ``0.5+2i``.
    """
    # repr() replaces the Python-2-only backtick operator; output is identical.
    def fltToString(a):
        if a == int(a): return repr(int(a))
        else: return repr(a)

    def imToString(a):
        if a == 0: return "0i"
        elif a == -1: return "-i"
        elif a == 1: return "i"
        else: return fltToString(a)+"i"

    re = c.real
    im = c.imag
    if re == 0 and im == 0: return "0"
    elif re == 0: return imToString(im)
    elif im == 0: return fltToString(re)
    else:
        # print the sign explicitly and format the magnitude of the imaginary part
        im_str = "-"+imToString(-im) if im < 0 else "+"+imToString(im)
        return fltToString(re)+im_str
27 
28 
29 
30 
# Dirac matrices used to build the spin projectors.  Each is a 4x4 complex
# matrix flattened in row-major order: entry (row i, column j) lives at index
# 4*i+j, which is how projectorToStr, gen's proj() and twisted_rotate index
# these matrices (or matrices built from them).
# NOTE: `id` shadows the Python builtin id(); other code in this file refers
# to the identity matrix by this name, so it is not renamed here.

# 4x4 identity
id = complexify([
    1, 0, 0, 0,
    0, 1, 0, 0,
    0, 0, 1, 0,
    0, 0, 0, 1
])

gamma1 = complexify([
    0, 0, 0, 1j,
    0, 0, 1j, 0,
    0, -1j, 0, 0,
    -1j, 0, 0, 0
])

gamma2 = complexify([
    0, 0, 0, 1,
    0, 0, -1, 0,
    0, -1, 0, 0,
    1, 0, 0, 0
])

gamma3 = complexify([
    0, 0, 1j, 0,
    0, 0, 0, -1j,
    -1j, 0, 0, 0,
    0, 1j, 0, 0
])

# gamma4 is diagonal in this basis: diag(1, 1, -1, -1)
gamma4 = complexify([
    1, 0, 0, 0,
    0, 1, 0, 0,
    0, 0, -1, 0,
    0, 0, 0, -1
])

# presumably i*gamma_5 (per its name and the "i*mu*gamma_5" term it supplies
# in twisted_rotate) -- TODO confirm against the basis convention
igamma5 = complexify([
    0, 0, 1j, 0,
    0, 0, 0, 1j,
    1j, 0, 0, 0,
    0, 1j, 0, 0
])
72 
73 
def gplus(g1, g2):
    """Element-wise sum of two flattened matrices (truncated to the shorter)."""
    return [lhs + rhs for lhs, rhs in zip(g1, g2)]
76 
def gminus(g1, g2):
    """Element-wise difference g1 - g2 of two flattened matrices (truncated to the shorter)."""
    return [lhs - rhs for lhs, rhs in zip(g1, g2)]
79 
def projectorToStr(p):
    """Format a flattened 4x4 matrix *p* as four lines of text.

    Each entry is rendered with complexToStr() and followed by one space;
    each row ends with a newline.  The trailing space is kept so generated
    comments are unchanged.  (The `def` header of this function was missing
    from the listing; it is called as projectorToStr(projectors[...]) in gen.)
    """
    rows = []
    for i in range(0, 4):
        rows.append("".join(complexToStr(p[4*i+j]) + " " for j in range(0, 4)) + "\n")
    return "".join(rows)
87 
# Spin projectors, one pair per dimension:
#   projectors[2*k]   = 1 - gamma_{k+1}
#   projectors[2*k+1] = 1 + gamma_{k+1}     for k = 0..3
# gen(dir) uses entry `dir`, or its +/- partner in the same pair when
# dagger is set.
projectors = [
    gminus(id,gamma1), gplus(id,gamma1),
    gminus(id,gamma2), gplus(id,gamma2),
    gminus(id,gamma3), gplus(id,gamma3),
    gminus(id,gamma4), gplus(id,gamma4),
]
94 
95 
96 
def indent(code):
    """Prefix each line of *code* with one space, leaving '#' lines flush-left.

    Lines whose first character is '#' (preprocessor directives in the
    generated code) are copied unchanged.  Every returned line ends in a
    newline; an empty *code* yields "".
    """
    pieces = []
    for line in code.splitlines():
        if not line.startswith("#"):
            line = " " + line
        pieces.append(line + "\n")
    return "".join(pieces)

def block(code):
    """Wrap *code* in braces, indenting its body with indent()."""
    return "{\n%s}" % indent(code)
103 
def sign(x):
    """Return the C source prefix for a projector / rotation coefficient.

    Only the coefficients that occur in the (1 +/- gamma) projectors and the
    twisted-mass rotation are supported:
        +1 -> "+",  -1 -> "-",  +2 -> "+2*",  -2 -> "-2*"
    *x* may be an int or a float (callers pass e.g. ``c.real``).

    Raises:
        ValueError: for any other coefficient.  This reports the problem here,
        instead of returning None and failing later with a TypeError when the
        result is concatenated into the generated code.
    """
    if x == 1: return "+"
    elif x == -1: return "-"
    elif x == 2: return "+2*"
    elif x == -2: return "-2*"
    raise ValueError("unsupported coefficient for sign(): %r" % (x,))
109 
def nthFloat4(n):
    """Return "<index>.<component>" for scalar *n* packed four per vector.

    Components are x, y, z, w, e.g. 5 -> "1.y".  Floor division and repr()
    keep the output identical under Python 2 and 3 (true division would
    give "1.25.y", and backtick-repr does not exist in Python 3).
    """
    return repr(n // 4) + "." + ["x", "y", "z", "w"][n % 4]

def nthFloat2(n):
    """Return "<index>.<component>" for scalar *n* packed two per vector.

    Components are x, y, e.g. 3 -> "1.y".
    """
    return repr(n // 2) + "." + ["x", "y"][n % 2]
115 
116 
# Register-name helpers for the generated CUDA code.  Each one builds a C
# identifier from a prefix, its integer indices and a _re/_im suffix, e.g.
# in_re(0, 1) -> "i01_re".  repr() replaces the Python-2-only backtick
# operator and gives identical output for ints.

# input spinor
def in_re(s, c): return "i"+repr(s)+repr(c)+"_re"
def in_im(s, c): return "i"+repr(s)+repr(c)+"_im"
# gauge link: even d -> "g", odd d -> "gT" (def_gauge defines gT as the
# conjugate transpose of g)
def g_re(d, m, n): return ("g" if (d%2==0) else "gT")+repr(m)+repr(n)+"_re"
def g_im(d, m, n): return ("g" if (d%2==0) else "gT")+repr(m)+repr(n)+"_im"
# output spinor
def out_re(s, c): return "o"+repr(s)+repr(c)+"_re"
def out_im(s, c): return "o"+repr(s)+repr(c)+"_im"
# half spinors before (a/b) and after (A/B) the gauge multiply
def h1_re(h, c): return ["a","b"][h]+repr(c)+"_re"
def h1_im(h, c): return ["a","b"][h]+repr(c)+"_im"
def h2_re(h, c): return ["A","B"][h]+repr(c)+"_re"
def h2_im(h, c): return ["A","B"][h]+repr(c)+"_im"
# clover entries of chiral block b; spin indices are offset by 2*b
def c_re(b, sm, cm, sn, cn): return "c"+repr(sm+2*b)+repr(cm)+"_"+repr(sn+2*b)+repr(cn)+"_re"
def c_im(b, sm, cm, sn, cn): return "c"+repr(sm+2*b)+repr(cm)+"_"+repr(sn+2*b)+repr(cn)+"_im"
def a_re(b, s, c): return "a"+repr(s+2*b)+repr(c)+"_re"
def a_im(b, s, c): return "a"+repr(s+2*b)+repr(c)+"_im"

def acc_re(s, c): return "acc"+repr(s)+repr(c)+"_re"
def acc_im(s, c): return "acc"+repr(s)+repr(c)+"_im"

def tmp_re(s, c): return "tmp"+repr(s)+repr(c)+"_re"
def tmp_im(s, c): return "tmp"+repr(s)+repr(c)+"_im"

# generic spinor component: z == 0 selects the real part, anything else imag
def spinor(name, s, c, z):
    if z==0: return name+repr(s)+repr(c)+"_re"
    else: return name+repr(s)+repr(c)+"_im"
141 
def def_input_spinor():
    """Return #defines mapping input (and accumulator) spinor names onto registers.

    i<s><c>_re/_im are mapped onto the I registers and, for dslash kernels
    that are not pack kernels, acc<s><c>_re/_im onto the accum registers.
    Under SPINOR_DOUBLE two reals are packed per vector (nthFloat2),
    otherwise four (nthFloat4).  With sharedDslash the matching
    shared-memory read/write macros are selected as well.

    (The `def` header was missing from the listing; the function is called
    as def_input_spinor() from prolog.)
    """
    def components(reg, name_re, name_im, nth):
        # two reals (re, im) per (spin, colour) pair
        code = ""
        for s in range(0,4):
            for c in range(0,3):
                i = 3*s+c
                code += "#define "+name_re(s,c)+" "+reg+nth(2*i+0)+"\n"
                code += "#define "+name_im(s,c)+" "+reg+nth(2*i+1)+"\n"
        return code

    code = "// input spinor\n"
    code += "#ifdef SPINOR_DOUBLE\n"
    code += "#define spinorFloat double\n"
    if sharedDslash:
        code += "#define WRITE_SPINOR_SHARED WRITE_SPINOR_SHARED_DOUBLE2\n"
        code += "#define READ_SPINOR_SHARED READ_SPINOR_SHARED_DOUBLE2\n"
    code += components("I", in_re, in_im, nthFloat2)
    if dslash and not pack:
        code += components("accum", acc_re, acc_im, nthFloat2)

    code += "#else\n"
    code += "#define spinorFloat float\n"
    if sharedDslash:
        code += "#define WRITE_SPINOR_SHARED WRITE_SPINOR_SHARED_FLOAT4\n"
        code += "#define READ_SPINOR_SHARED READ_SPINOR_SHARED_FLOAT4\n"
    code += components("I", in_re, in_im, nthFloat4)
    if dslash and not pack:
        code += components("accum", acc_re, acc_im, nthFloat4)
    code += "#endif // SPINOR_DOUBLE\n\n"
    return code
180 # end def def_input_spinor
181 
182 
def def_gauge():
    """Return #defines naming the components of the 3x3 gauge link.

    g<m><n>_re/_im are mapped onto the G registers (two reals per vector
    under GAUGE_FLOAT2, four otherwise).  gT<m><n>_re/_im are defined as
    the conjugate transpose of g: gT[m][n] = conj(g[n][m]).
    """
    def links(nth):
        # two reals (re, im) per matrix element, followed by a blank line
        code = ""
        for m in range(0,3):
            for n in range(0,3):
                i = 3*m+n
                code += "#define "+g_re(0,m,n)+" G"+nth(2*i+0)+"\n"
                code += "#define "+g_im(0,m,n)+" G"+nth(2*i+1)+"\n"
        return code + "\n"

    code = "// gauge link\n"
    code += "#ifdef GAUGE_FLOAT2\n"
    code += links(nthFloat2)
    code += "#else\n"
    code += links(nthFloat4)
    # label the #endif with the macro actually tested above, as the spinor
    # and clover sections do
    code += "#endif // GAUGE_FLOAT2\n\n"

    code += "// conjugated gauge link\n"
    for m in range(0,3):
        for n in range(0,3):
            code += "#define "+g_re(1,m,n)+" (+"+g_re(0,n,m)+")\n"
            code += "#define "+g_im(1,m,n)+" (-"+g_im(0,n,m)+")\n"
    code += "\n"

    return code
212 # end def def_gauge
213 
214 
def def_clover():
    """Return #defines naming the entries of both 6x6 chiral clover blocks.

    For the first block, only the 6 real diagonal entries and the 15 complex
    entries with row > column are read from the C registers (36 reals,
    two per vector under CLOVER_DOUBLE, four otherwise).  The entries with
    row < column are defined as the complex conjugates of their mirrored
    partners, and the second chiral block reuses the same names.

    (The `def` header was missing from the listing; the function is called
    as def_clover() from prolog.)  divmod() keeps the spin/colour indices
    integral under Python 3, where m/3 would produce floats.
    """
    def packed(nth):
        # diagonal (real) entries first, then the strictly lower triangle
        code = ""
        i = 0
        for m in range(0,6):
            s, c = divmod(m, 3)
            code += "#define "+c_re(0,s,c,s,c)+" C"+nth(i)+"\n"
            i += 1
        for n in range(0,6):
            sn, cn = divmod(n, 3)
            for m in range(n+1,6):
                sm, cm = divmod(m, 3)
                code += "#define "+c_re(0,sm,cm,sn,cn)+" C"+nth(i)+"\n"
                code += "#define "+c_im(0,sm,cm,sn,cn)+" C"+nth(i+1)+"\n"
                i += 2
        return code

    code = "// first chiral block of inverted clover term\n"
    code += "#ifdef CLOVER_DOUBLE\n"
    code += packed(nthFloat2)
    code += "#else\n"
    code += packed(nthFloat4)
    code += "#endif // CLOVER_DOUBLE\n\n"

    # row < column: complex conjugate of the mirrored (column, row) entry
    for n in range(0,6):
        sn, cn = divmod(n, 3)
        for m in range(0,n):
            sm, cm = divmod(m, 3)
            code += "#define "+c_re(0,sm,cm,sn,cn)+" (+"+c_re(0,sn,cn,sm,cm)+")\n"
            code += "#define "+c_im(0,sm,cm,sn,cn)+" (-"+c_im(0,sn,cn,sm,cm)+")\n"
    code += "\n"

    code += "// second chiral block of inverted clover term (reuses C0,...,C9)\n"
    for n in range(0,6):
        sn, cn = divmod(n, 3)
        for m in range(0,6):
            sm, cm = divmod(m, 3)
            code += "#define "+c_re(1,sm,cm,sn,cn)+" "+c_re(0,sm,cm,sn,cn)+"\n"
            # diagonal entries are real, so they have no _im name
            if m != n: code += "#define "+c_im(1,sm,cm,sn,cn)+" "+c_im(0,sm,cm,sn,cn)+"\n"
    code += "\n"

    return code
273 # end def def_clover
274 
275 
def def_output_spinor():
    """Return declarations for the 24 real output-spinor registers o<s><c>_re/_im.

    When output spinors live in shared memory (sharedDslash is False), the
    first sharedFloats reals are #defined onto the shared array `s` at stride
    SHARED_STRIDE; every other component is declared as a VOLATILE local.

    (The `def` header was missing from the listing; the function is called
    as def_output_spinor() from prolog.)
    """
    # sharedDslash = True: input spinors stored in shared memory
    # sharedDslash = False: output spinors stored in shared memory
    code = "// output spinor\n"
    for s in range(0,4):
        for c in range(0,3):
            i = 3*s+c
            # real part first, then imaginary part, as consecutive reals
            for k, name in ((2*i+0, out_re(s,c)), (2*i+1, out_im(s,c))):
                if k < sharedFloats and not sharedDslash:
                    code += "#define "+name+" s["+repr(k)+"*SHARED_STRIDE]\n"
                else:
                    code += "VOLATILE spinorFloat "+name+";\n"
    return code
292 # end def def_output_spinor
293 
294 
def prolog():
    """Return the kernel preamble: macros, register definitions and site setup.

    For dslash kernels this emits the header comment, the shared-float count,
    the VOLATILE macro, the spinor/gauge(/clover)/output register defines,
    the shared-memory stride and pointer (when sharedFloats > 0), and the
    interior/exterior (MULTI_GPU) site-index computation.  For clover-only
    kernels it emits the clover defines and reads the input spinor.  If
    neither dslash nor clover is set, it prints "Undefined prolog" and exits
    with status 1.
    """
    global arch

    if dslash:
        prolog_str= ("// *** CUDA DSLASH ***\n\n" if not dagger else "// *** CUDA DSLASH DAGGER ***\n\n")
        prolog_str+= "#define DSLASH_SHARED_FLOATS_PER_THREAD "+str(sharedFloats)+"\n\n"
    elif clover:
        prolog_str= ("// *** CUDA CLOVER ***\n\n")
        prolog_str+= "#define CLOVER_SHARED_FLOATS_PER_THREAD "+str(sharedFloats)+"\n\n"
    else:
        print("Undefined prolog")
        # sys.exit() must actually be called: a bare `exit` is just a name
        # lookup and would fall through to use the unbound prolog_str below.
        sys.exit(1)

    prolog_str+= (
"""
#if ((CUDA_VERSION >= 4010) && (__COMPUTE_CAPABILITY__ >= 200)) // NVVM compiler
#define VOLATILE
#else // Open64 compiler
#define VOLATILE volatile
#endif
""")

    prolog_str+= def_input_spinor()
    # plain truthiness, consistent with the `if dslash:` test above
    if dslash: prolog_str+= def_gauge()
    if clover: prolog_str+= def_clover()
    prolog_str+= def_output_spinor()

    if (sharedFloats > 0):
        if (arch >= 200):
            prolog_str+= (
"""
#ifdef SPINOR_DOUBLE
#define SHARED_STRIDE 16 // to avoid bank conflicts on Fermi
#else
#define SHARED_STRIDE 32 // to avoid bank conflicts on Fermi
#endif
""")
        else:
            prolog_str+= (
"""
#ifdef SPINOR_DOUBLE
#define SHARED_STRIDE 8 // to avoid bank conflicts on G80 and GT200
#else
#define SHARED_STRIDE 16 // to avoid bank conflicts on G80 and GT200
#endif
""")


    # set the pointer if using shared memory for pseudo registers
    if sharedFloats > 0 and not sharedDslash:
        prolog_str += (
"""
extern __shared__ char s_data[];
""")

        if dslash:
            prolog_str += (
"""
VOLATILE spinorFloat *s = (spinorFloat*)s_data + DSLASH_SHARED_FLOATS_PER_THREAD*SHARED_STRIDE*(threadIdx.x/SHARED_STRIDE)
 + (threadIdx.x % SHARED_STRIDE);
""")
        else:
            prolog_str += (
"""
VOLATILE spinorFloat *s = (spinorFloat*)s_data + CLOVER_SHARED_FLOATS_PER_THREAD*SHARED_STRIDE*(threadIdx.x/SHARED_STRIDE)
 + (threadIdx.x % SHARED_STRIDE);
""")


    if dslash:
        prolog_str += (
"""
#include "read_gauge.h"
#include "read_clover.h"
#include "io_spinor.h"

int coord[5];
int X;

int sid;
""")

        if sharedDslash:
            # NOTE(review): the generated coordsFromIndex3D call passes `x`,
            # while the bounds checks after it read `coord`; confirm against
            # coordsFromIndex3D's signature.
            prolog_str += (
"""
#ifdef MULTI_GPU
int face_idx;
if (kernel_type == INTERIOR_KERNEL) {
#endif

 // Assume even dimensions
 coordsFromIndex3D<EVEN_X>(X, x, sid, param.parity, param.dc.X);

 // only need to check Y and Z dims currently since X and T set to match exactly
 if (coord[1] >= param.dc.X[1]) return;
 if (coord[2] >= param.dc.X[2]) return;

""")
        else:
            prolog_str += (
"""
#ifdef MULTI_GPU
int face_idx;
if (kernel_type == INTERIOR_KERNEL) {
#endif

 sid = blockIdx.x*blockDim.x + threadIdx.x;
 if (sid >= param.threads) return;

 // Assume even dimensions
 coordsFromIndex<4,QUDA_4D_PC,EVEN_X>(X, coord, sid, param);

""")

        # interior kernel: start from a zeroed output spinor
        out = ""
        for s in range(0,4):
            for c in range(0,3):
                out += out_re(s,c)+" = 0; "+out_im(s,c)+" = 0;\n"
        prolog_str+= indent(out)
        if asymClover: prolog_str += indent(clover_xpay())

        prolog_str+= (
"""
#ifdef MULTI_GPU
} else { // exterior kernel

 sid = blockIdx.x*blockDim.x + threadIdx.x;
 if (sid >= param.threads) return;

 const int face_volume = (param.threads >> 1); // volume of one face
 const int face_num = (sid >= face_volume); // is this thread updating face 0 or 1
 face_idx = sid - face_num*face_volume; // index into the respective face

 // ghostOffset is scaled to include body (includes stride) and number of FloatN arrays (SPINOR_HOP)
 // face_idx not sid since faces are spin projected and share the same volume index (modulo UP/DOWN reading)
 //sp_idx = face_idx + param.ghostOffset[dim];

 coordsFromFaceIndex<4,QUDA_4D_PC,kernel_type,1>(X, sid, coord, face_idx, face_num, param);

 READ_INTERMEDIATE_SPINOR(INTERTEX, param.sp_stride, sid, sid);

""")

        # exterior kernel: start from the intermediate spinor just read
        out = ""
        for s in range(0,4):
            for c in range(0,3):
                out += out_re(s,c)+" = "+in_re(s,c)+"; "+out_im(s,c)+" = "+in_im(s,c)+";\n"
        prolog_str+= indent(out)
        prolog_str+= "}\n"
        prolog_str+= "#endif // MULTI_GPU\n\n\n"

    else:
        prolog_str+=(
"""
#include "read_clover.h"
#include "io_spinor.h"

int sid = blockIdx.x*blockDim.x + threadIdx.x;
if (sid >= param.threads) return;

// read spinor from device memory
READ_SPINOR(SPINORTEX, param.sp_stride, sid, sid);
""")
    return prolog_str
459 # end def prolog
460 
461 
def gen(dir, pack_only=False):
    """Return the CUDA code for the hopping term in direction *dir* (0..7).

    Even dir is the forward neighbour and odd dir the backward neighbour in
    dimension dir//2 (X, Y, Z, T).  The generated code spin-projects the
    neighbouring spinor with the projector for this direction (its +/-
    partner when dagger is set), multiplies the half spinor by the gauge
    link, and accumulates the reconstructed result into the output registers.
    For dir >= 6 an additional param.gauge_fixed branch uses the identity
    instead of the link.

    If pack_only is True, only the load/declare/project code is returned,
    with every "sp_idx" renamed to "idx".
    """
    # Integers spliced into C source use // and repr() so the output is the
    # same under Python 2 and 3 (backtick-repr is gone in Python 3, and
    # dir/2 would be a float there, breaking both names and list indexing).
    projIdx = dir if not dagger else dir + (1 - 2*(dir%2))
    projStr = projectorToStr(projectors[projIdx])
    def proj(i,j):
        return projectors[projIdx][4*i+j]

    # if row(i) = (j, c), then the i'th row of the projector can be represented
    # as a multiple of the j'th row: row(i) = c row(j)
    def row(i):
        assert i==2 or i==3
        if proj(i,0) == 0j:
            return (1, proj(i,1))
        if proj(i,1) == 0j:
            return (0, proj(i,0))
        # fail here rather than returning None into the tuple unpack below
        raise ValueError("projector row "+repr(i)+" is not a multiple of row 0 or row 1")

    boundary = ["coord[0]==(param.dc.X[0]-1)", "coord[0]==0", "coord[1]==(param.dc.X[1]-1)", "coord[1]==0", "coord[2]==(param.dc.X[2]-1)", "coord[2]==0", "coord[3]==(param.dc.X[3]-1)", "coord[3]==0"]
    interior = ["coord[0]<(param.dc.X[0]-1)", "coord[0]>0", "coord[1]<(param.dc.X[1]-1)", "coord[1]>0", "coord[2]<(param.dc.X[2]-1)", "coord[2]>0", "coord[3]<(param.dc.X[3]-1)", "coord[3]>0"]
    dim = ["X", "Y", "Z", "T"]

    # index of neighboring site when not on boundary
    sp_idx = ["X+1", "X-1", "X+param.dc.X[0]", "X-param.dc.X[0]", "X+param.dc.X2X1", "X-param.dc.X2X1", "X+param.dc.X3X2X1", "X-param.dc.X3X2X1"]

    # index of neighboring site (across boundary)
    sp_idx_wrap = ["X-(param.dc.X[0]-1)", "X+(param.dc.X[0]-1)", "X-param.dc.X2X1mX1", "X+param.dc.X2X1mX1", "X-param.dc.X3X2X1mX2X1", "X+param.dc.X3X2X1mX2X1",
                   "X-param.dc.X4X3X2X1mX3X2X1", "X+param.dc.X4X3X2X1mX3X2X1"]

    cond = ""
    cond += "#ifdef MULTI_GPU\n"
    cond += "if ( (kernel_type == INTERIOR_KERNEL && (!param.ghostDim["+repr(dir//2)+"] || "+interior[dir]+")) ||\n"
    cond += " (kernel_type == EXTERIOR_KERNEL_"+dim[dir//2]+" && "+boundary[dir]+") )\n"
    cond += "#endif\n"

    str = ""

    projName = "P"+repr(dir//2)+["-","+"][projIdx%2]
    str += "// Projector "+projName+"\n"
    for l in projStr.splitlines():
        str += "// "+l+"\n"
    str += "\n"

    str += "#ifdef MULTI_GPU\n"
    str += "const int sp_idx = (kernel_type == INTERIOR_KERNEL) ? ("+boundary[dir]+" ? "+sp_idx_wrap[dir]+" : "+sp_idx[dir]+") >> 1 :\n"
    str += " face_idx + param.ghostOffset[static_cast<int>(kernel_type)][" + repr((dir+1)%2) + "];\n"
    str += "#if (DD_PREC==2) // half precision\n"
    str += "const int sp_norm_idx = face_idx + param.ghostNormOffset[static_cast<int>(kernel_type)][" + repr((dir+1)%2) + "];\n"
    str += "#endif\n"
    str += "#else\n"
    str += "const int sp_idx = ("+boundary[dir]+" ? "+sp_idx_wrap[dir]+" : "+sp_idx[dir]+") >> 1;\n"
    str += "#endif\n"

    str += "\n"
    if dir % 2 == 0:
        str += "const int ga_idx = sid;\n"
    else:
        str += "#ifdef MULTI_GPU\n"
        str += "const int ga_idx = ((kernel_type == INTERIOR_KERNEL) ? sp_idx : param.dc.Vh+face_idx);\n"
        str += "#else\n"
        str += "const int ga_idx = sp_idx;\n"
        str += "#endif\n"
    str += "\n"

    # scan the projector to determine which loads are required
    row_cnt = [0,0,0,0]
    for h in range(0,4):
        for s in range(0,4):
            re = proj(h,s).real
            im = proj(h,s).imag
            if re != 0 or im != 0:
                row_cnt[h] += 1
    row_cnt[0] += row_cnt[1]
    row_cnt[2] += row_cnt[3]

    decl_half = ""
    for h in range(0, 2):
        for c in range(0, 3):
            decl_half += "spinorFloat "+h1_re(h,c)+", "+h1_im(h,c)+";\n"
    decl_half += "\n"

    load_spinor = "// read spinor from device memory\n"
    if row_cnt[0] == 0:
        load_spinor += "READ_SPINOR_DOWN(SPINORTEX, param.sp_stride, sp_idx, sp_idx);\n"
    elif row_cnt[2] == 0:
        load_spinor += "READ_SPINOR_UP(SPINORTEX, param.sp_stride, sp_idx, sp_idx);\n"
    else:
        load_spinor += "READ_SPINOR(SPINORTEX, param.sp_stride, sp_idx, sp_idx);\n"
    load_spinor += "\n"

    load_half = ""
    load_half += "const int sp_stride_pad = param.dc.ghostFace[static_cast<int>(kernel_type)];\n"

    if dir >= 6: load_half += "const int t_proj_scale = TPROJSCALE;\n"
    load_half += "\n"
    load_half += "// read half spinor from device memory\n"

    # we have to use the same volume index for backwards and forwards gathers
    # instead of using READ_UP_SPINOR and READ_DOWN_SPINOR, just use READ_HALF_SPINOR with the appropriate shift
    load_half += "READ_SPINOR_GHOST(GHOSTSPINORTEX, sp_stride_pad, sp_idx, sp_norm_idx, "+repr(dir)+");\n\n"
    load_gauge = "// read gauge matrix from device memory\n"
    load_gauge += "READ_GAUGE_MATRIX(G, GAUGE"+repr(dir%2)+"TEX, "+repr(dir)+", ga_idx, param.gauge_stride);\n\n"

    reconstruct_gauge = "// reconstruct gauge matrix\n"
    reconstruct_gauge += "RECONSTRUCT_GAUGE_MATRIX("+repr(dir)+");\n\n"

    project = "// project spinor into half spinors\n"
    for h in range(0, 2):
        for c in range(0, 3):
            strRe = ""
            strIm = ""
            for s in range(0, 4):
                re = proj(h,s).real
                im = proj(h,s).imag
                if re==0 and im==0: pass
                elif im==0:
                    strRe += sign(re)+in_re(s,c)
                    strIm += sign(re)+in_im(s,c)
                elif re==0:
                    strRe += sign(-im)+in_im(s,c)
                    strIm += sign(im)+in_re(s,c)
            if row_cnt[0] == 0: # projector defined on lower half only
                for s in range(0, 4):
                    re = proj(h+2,s).real
                    im = proj(h+2,s).imag
                    if re==0 and im==0: pass
                    elif im==0:
                        strRe += sign(re)+in_re(s,c)
                        strIm += sign(re)+in_im(s,c)
                    elif re==0:
                        strRe += sign(-im)+in_im(s,c)
                        strIm += sign(im)+in_re(s,c)

            project += h1_re(h,c)+" = "+strRe+";\n"
            project += h1_im(h,c)+" = "+strIm+";\n"

    write_shared = (
"""// store spinor into shared memory
WRITE_SPINOR_SHARED(threadIdx.x, threadIdx.y, threadIdx.z, i);\n
""")

    load_shared_1 = (
"""// load spinor from shared memory
int tx = (threadIdx.x > 0) ? threadIdx.x-1 : blockDim.x-1;
__syncthreads();
READ_SPINOR_SHARED(tx, threadIdx.y, threadIdx.z);\n
""")

    load_shared_2 = (
"""// load spinor from shared memory
int tx = (threadIdx.x + blockDim.x - ((coord[0]+1)&1) ) % blockDim.x;
int ty = (threadIdx.y < blockDim.y - 1) ? threadIdx.y + 1 : 0;
READ_SPINOR_SHARED(tx, ty, threadIdx.z);\n
""")

    load_shared_3 = (
"""// load spinor from shared memory
int tx = (threadIdx.x + blockDim.x - ((coord[0]+1)&1)) % blockDim.x;
int ty = (threadIdx.y > 0) ? threadIdx.y - 1 : blockDim.y - 1;
READ_SPINOR_SHARED(tx, ty, threadIdx.z);\n
""")

    load_shared_4 = (
"""// load spinor from shared memory
int tx = (threadIdx.x + blockDim.x - ((coord[0]+1)&1) ) % blockDim.x;
int tz = (threadIdx.z < blockDim.z - 1) ? threadIdx.z + 1 : 0;
READ_SPINOR_SHARED(tx, threadIdx.y, tz);\n
""")

    load_shared_5 = (
"""// load spinor from shared memory
int tx = (threadIdx.x + blockDim.x - ((coord[0]+1)&1)) % blockDim.x;
int tz = (threadIdx.z > 0) ? threadIdx.z - 1 : blockDim.z - 1;
READ_SPINOR_SHARED(tx, threadIdx.y, tz);\n
""")


    # exterior kernels read an already-projected half spinor from the ghost zone
    copy_half = ""
    for h in range(0, 2):
        for c in range(0, 3):
            copy_half += h1_re(h,c)+" = "+("t_proj_scale*" if (dir >= 6) else "")+in_re(h,c)+"; "
            copy_half += h1_im(h,c)+" = "+("t_proj_scale*" if (dir >= 6) else "")+in_im(h,c)+";\n"
    copy_half += "\n"

    prep_half = ""
    prep_half += "#ifdef MULTI_GPU\n"
    prep_half += "if (kernel_type == INTERIOR_KERNEL) {\n"
    prep_half += "#endif\n"
    prep_half += "\n"

    if sharedDslash:
        if dir == 0:
            prep_half += indent(load_spinor)
            prep_half += indent(write_shared)
            prep_half += indent(project)
        elif dir == 1:
            prep_half += indent(load_shared_1)
            prep_half += indent(project)
        elif dir == 2:
            prep_half += indent("if (threadIdx.y == blockDim.y-1 && blockDim.y < param.dc.X[1] ) {\n")
            prep_half += indent(load_spinor)
            prep_half += indent(project)
            prep_half += indent("} else {")
            prep_half += indent(load_shared_2)
            prep_half += indent(project)
            prep_half += indent("}")
        elif dir == 3:
            prep_half += indent("if (threadIdx.y == 0 && blockDim.y < param.dc.X[1]) {\n")
            prep_half += indent(load_spinor)
            prep_half += indent(project)
            prep_half += indent("} else {")
            prep_half += indent(load_shared_3)
            prep_half += indent(project)
            prep_half += indent("}")
        elif dir == 4:
            prep_half += indent("if (threadIdx.z == blockDim.z-1 && blockDim.z < param.dc.X[2]) {\n")
            prep_half += indent(load_spinor)
            prep_half += indent(project)
            prep_half += indent("} else {")
            prep_half += indent(load_shared_4)
            prep_half += indent(project)
            prep_half += indent("}")
        elif dir == 5:
            prep_half += indent("if (threadIdx.z == 0 && blockDim.z < param.dc.X[2]) {\n")
            prep_half += indent(load_spinor)
            prep_half += indent(project)
            prep_half += indent("} else {")
            prep_half += indent(load_shared_5)
            prep_half += indent(project)
            prep_half += indent("}")
        else:
            prep_half += indent(load_spinor)
            prep_half += indent(project)
    else:
        prep_half += indent(load_spinor)
        prep_half += indent(project)

    prep_half += "\n"
    prep_half += "#ifdef MULTI_GPU\n"
    prep_half += "} else {\n"
    prep_half += "\n"
    prep_half += indent(load_half)
    prep_half += indent(copy_half)
    prep_half += "}\n"
    prep_half += "#endif // MULTI_GPU\n"
    prep_half += "\n"

    ident = "// identity gauge matrix\n"
    for m in range(0,3):
        for h in range(0,2):
            ident += "spinorFloat "+h2_re(h,m)+" = " + h1_re(h,m) + "; "
            ident += "spinorFloat "+h2_im(h,m)+" = " + h1_im(h,m) + ";\n"
    ident += "\n"

    # complex 3x3 matrix times each half spinor
    mult = ""
    for m in range(0,3):
        mult += "// multiply row "+repr(m)+"\n"
        for h in range(0,2):
            re = "spinorFloat "+h2_re(h,m)+" = 0;\n"
            im = "spinorFloat "+h2_im(h,m)+" = 0;\n"
            for c in range(0,3):
                re += h2_re(h,m) + " += " + g_re(dir,m,c) + " * "+h1_re(h,c)+";\n"
                re += h2_re(h,m) + " -= " + g_im(dir,m,c) + " * "+h1_im(h,c)+";\n"
                im += h2_im(h,m) + " += " + g_re(dir,m,c) + " * "+h1_im(h,c)+";\n"
                im += h2_im(h,m) + " += " + g_im(dir,m,c) + " * "+h1_re(h,c)+";\n"
            mult += re + im
        mult += "\n"

    reconstruct = ""
    if asymClover:
        reconstruct += "#ifdef SPINOR_DOUBLE\n"
        reconstruct += "spinorFloat a = param.a;\n"
        reconstruct += "#else\n"
        reconstruct += "spinorFloat a = param.a_f;\n"
        reconstruct += "#endif\n"

    for m in range(0,3):

        for h in range(0,2):
            h_out = h
            if row_cnt[0] == 0: # projector defined on lower half only
                h_out = h+2
            if not asymClover:
                reconstruct += out_re(h_out, m) + " += " + h2_re(h,m) + ";\n"
                reconstruct += out_im(h_out, m) + " += " + h2_im(h,m) + ";\n"
            else:
                reconstruct += out_re(h_out, m) + " += a*" + h2_re(h,m) + ";\n"
                reconstruct += out_im(h_out, m) + " += a*" + h2_im(h,m) + ";\n"

        # rows 2 and 3 are multiples of rows 0/1 (see row())
        for s in range(2,4):
            (h,c) = row(s)
            re = c.real
            im = c.imag
            if not asymClover:
                if im == 0 and re == 0: pass
                elif im == 0:
                    reconstruct += out_re(s, m) + " " + sign(re) + "= " + h2_re(h,m) + ";\n"
                    reconstruct += out_im(s, m) + " " + sign(re) + "= " + h2_im(h,m) + ";\n"
                elif re == 0:
                    reconstruct += out_re(s, m) + " " + sign(-im) + "= " + h2_im(h,m) + ";\n"
                    reconstruct += out_im(s, m) + " " + sign(+im) + "= " + h2_re(h,m) + ";\n"
            else:
                if im == 0 and re == 0: pass
                elif im == 0:
                    reconstruct += out_re(s, m) + " " + sign(re) + "= a*" + h2_re(h,m) + ";\n"
                    reconstruct += out_im(s, m) + " " + sign(re) + "= a*" + h2_im(h,m) + ";\n"
                elif re == 0:
                    reconstruct += out_re(s, m) + " " + sign(-im) + "= a*" + h2_im(h,m) + ";\n"
                    reconstruct += out_im(s, m) + " " + sign(+im) + "= a*" + h2_re(h,m) + ";\n"

        reconstruct += "\n"

    if dir >= 6:
        str += "if (param.gauge_fixed && ga_idx < param.dc.X4X3X2X1hmX3X2X1h)\n"
        str += block(decl_half + prep_half + ident + reconstruct)
        str += " else "
        str += block(decl_half + prep_half + load_gauge + reconstruct_gauge + mult + reconstruct)
    else:
        str += decl_half + prep_half + load_gauge + reconstruct_gauge + mult + reconstruct

    if pack_only:
        out = load_spinor + decl_half + project
        out = out.replace("sp_idx", "idx")
        return out
    else:
        return cond + block(str)+"\n\n"
785 # end def gen
786 
787 
def input_spinor(s,c,z):
    """Return the register name of spinor component (s, c), real part when z == 0.

    Dslash kernels take their input from the output registers (o...),
    other kernels from the input registers (i...).
    """
    if dslash:
        return out_re(s,c) if z == 0 else out_im(s,c)
    return in_re(s,c) if z == 0 else in_im(s,c)
795 
def _chiral_rotation(v_out, v_in, c, rows):
    """Return a braced block computing a<s><c> from two v_in components per spin s.

    rows[s] = (lead, j, op, k) generates, for both real and imaginary parts,
        spinorFloat a<s><c> = <lead>v_in[j][c]<op>v_in[k][c];
    and then copies every a<s><c> into v_out.
    """
    code = ""
    for s, (lead, j, op, k) in enumerate(rows):
        for z, a_name in ((0, a_re), (1, a_im)):
            code += "spinorFloat "+a_name(0,s,c)+" = "+lead+spinor(v_in,j,c,z)+op+spinor(v_in,k,c,z)+";\n"
    code += "\n"

    for s in range(0,4):
        code += spinor(v_out,s,c,0)+" = "+a_re(0,s,c)+"; "
        code += spinor(v_out,s,c,1)+" = "+a_im(0,s,c)+";\n"

    return block(code)+"\n\n"

def to_chiral_basis(v_out,v_in,c):
    """Return code rotating colour component c of v_in into the chiral basis (into v_out)."""
    return _chiral_rotation(v_out, v_in, c, [
        ("-", 1, " - ", 3),
        ("", 0, " + ", 2),
        ("-", 1, " + ", 3),
        ("", 0, " - ", 2),
    ])
# end def to_chiral_basis

def from_chiral_basis(v_out,v_in,c): # note: factor of 1/2 is included in clover term normalization
    """Return code rotating colour component c of v_in back from the chiral basis (into v_out)."""
    return _chiral_rotation(v_out, v_in, c, [
        ("", 1, " + ", 3),
        ("-", 0, " - ", 2),
        ("", 1, " - ", 3),
        ("-", 0, " + ", 2),
    ])
832 # end def from_chiral_basis
833 
834 
def clover_mult(v_out, v_in, chi):
    """Return a braced block multiplying by one 6x6 chiral block of the clover term.

    chi (0 or 1) selects the block.  It acts on spin components 2*chi and
    2*chi+1 of v_in, and the result is written to the same components of
    v_out.  Diagonal clover entries are real, so their imaginary terms are
    omitted.  repr(chi) replaces the Python-2-only backtick operator.
    """
    code = "READ_CLOVER(CLOVERTEX, "+repr(chi)+")\n\n"

    for s in range (0,2):
        for c in range (0,3):
            code += "spinorFloat "+a_re(chi,s,c)+" = 0; spinorFloat "+a_im(chi,s,c)+" = 0;\n"
    code += "\n"

    for sm in range (0,2):
        for cm in range (0,3):
            for sn in range (0,2):
                for cn in range (0,3):
                    off_diag = (sn != sm) or (cn != cm)
                    code += a_re(chi,sm,cm)+" += "+c_re(chi,sm,cm,sn,cn)+" * "+spinor(v_in,2*chi+sn,cn,0)+";\n"
                    if off_diag:
                        code += a_re(chi,sm,cm)+" -= "+c_im(chi,sm,cm,sn,cn)+" * "+spinor(v_in,2*chi+sn,cn,1)+";\n"
                    code += a_im(chi,sm,cm)+" += "+c_re(chi,sm,cm,sn,cn)+" * "+spinor(v_in,2*chi+sn,cn,1)+";\n"
                    if off_diag:
                        code += a_im(chi,sm,cm)+" += "+c_im(chi,sm,cm,sn,cn)+" * "+spinor(v_in,2*chi+sn,cn,0)+";\n"
            code += "\n"

    for s in range (0,2):
        for c in range (0,3):
            code += spinor(v_out,2*chi+s,c,0)+" = "+a_re(chi,s,c)+"; "
            code += spinor(v_out,2*chi+s,c,1)+" = "+a_im(chi,s,c)+";\n"
    code += "\n"

    return block(code)+"\n\n"
864 # end def clover_mult
865 
866 
def apply_clover(v_out,v_in):
    """Return code applying the clover term to spinor v_in, leaving the result in v_out.

    The spinor is rotated into the chiral basis (into v_out), each chiral
    block is applied to v_out in place, and v_out is rotated back.  For
    dslash kernels the code is wrapped in #ifdef DSLASH_CLOVER.
    """
    colors = range(0,3)
    parts = []
    if dslash:
        parts.append("#ifdef DSLASH_CLOVER\n\n")
    parts.append("// change to chiral basis\n")
    parts.extend([to_chiral_basis(v_out, v_in, col) for col in colors])
    parts.append("// apply first chiral block\n")
    parts.append(clover_mult(v_out, v_out, 0))
    parts.append("// apply second chiral block\n")
    parts.append(clover_mult(v_out, v_out, 1))
    parts.append("// change back from chiral basis\n")
    parts.append("// (note: required factor of 1/2 is included in clover term normalization)\n")
    parts.extend([from_chiral_basis(v_out, v_out, col) for col in colors])
    if dslash:
        parts.append("#endif // DSLASH_CLOVER\n\n")
    return "".join(parts)
882 # end def clover
883 
884 
def twisted_rotate(x):
    """Return declarations of tmp<h><c> = (1 + sign(x) * a * igamma5) applied to o.

    For each output row h and colour c, the identity term and the igamma5
    term (scaled by x and the C variable `a`) are expanded into one
    expression over the output registers o<s><c>.

    (The `def` header was missing from the listing; the function is called
    as twisted_rotate(+1) from twisted().)
    """
    code = "// apply twisted mass rotation\n"

    for h in range(0, 4):
        for c in range(0, 3):
            strRe = ""
            strIm = ""
            for s in range(0, 4):
                # identity
                re = id[4*h+s].real
                im = id[4*h+s].imag
                if re==0 and im==0: pass
                elif im==0:
                    strRe += sign(re)+out_re(s,c)
                    strIm += sign(re)+out_im(s,c)
                elif re==0:
                    strRe += sign(-im)+out_im(s,c)
                    strIm += sign(im)+out_re(s,c)

                # sign(x)*i*mu*gamma_5
                re = igamma5[4*h+s].real
                im = igamma5[4*h+s].imag
                if re==0 and im==0: pass
                elif im==0:
                    strRe += sign(re*x)+out_re(s,c) + "*a"
                    strIm += sign(re*x)+out_im(s,c) + "*a"
                elif re==0:
                    strRe += sign(-im*x)+out_im(s,c) + "*a"
                    strIm += sign(im*x)+out_re(s,c) + "*a"

            code += "VOLATILE spinorFloat "+tmp_re(h,c)+" = " + strRe + ";\n"
            code += "VOLATILE spinorFloat "+tmp_im(h,c)+" = " + strIm + ";\n"
        code += "\n"

    return code+"\n"
920 
921 
def twisted():
    """Return a braced block applying the twisted-mass rotation with x = +1.

    Without DSLASH_XPAY the rotated tmp values are scaled by b when copied
    back into the output registers; with DSLASH_XPAY they are copied as-is.
    """
    def copy_back(prefix):
        # o<s><c> = <prefix>tmp<s><c> for every spin/colour component
        lines = []
        for s in range(0,4):
            for c in range(0,3):
                lines.append(out_re(s,c) + " = " + prefix + tmp_re(s,c) + ";\n")
                lines.append(out_im(s,c) + " = " + prefix + tmp_im(s,c) + ";\n")
        return "".join(lines)

    body = twisted_rotate(+1)
    body += "#ifndef DSLASH_XPAY\n"
    body += "//scale by b = 1/(1 + a*a) \n"
    body += copy_back("b*")
    body += "#else\n"
    body += copy_back("")
    body += "#endif // DSLASH_XPAY\n"
    body += "\n"

    return block(body)+"\n"
941 # end def twisted
942 
943 
def clover_xpay():
    """Return code (guarded by DSLASH_CLOVER_XPAY) that reads the accumulator
    spinor, applies the clover term to it in place, and copies it into the
    output registers.

    (The `def` header was missing from the listing; the function is called
    as clover_xpay() from prolog.)
    """
    code = ""
    code += "#ifdef DSLASH_CLOVER_XPAY\n\n"
    code += "READ_ACCUM(ACCUMTEX, param.sp_stride)\n\n"

    code += apply_clover("acc","acc")

    for s in range(0,4):
        for c in range(0,3):
            code += out_re(s,c) +" = "+acc_re(s,c)+";\n"
            code += out_im(s,c) +" = "+acc_im(s,c)+";\n"

    code += "#endif // DSLASH_CLOVER_XPAY\n"

    return code
960 #end def clover_xpay
961 
def xpay():
    """Return code (guarded by DSLASH_XPAY) computing o = scale*o + acc.

    The accumulator spinor is read first; the precision-appropriate `a`
    is declared; scale is `a` for non-twisted kernels and `b` otherwise.
    """
    # explicit comparison preserved: only twist == False selects "a*"
    scale = "a*" if twist == False else "b*"
    pieces = [
        "#ifdef DSLASH_XPAY\n\n",
        "READ_ACCUM(ACCUMTEX, param.sp_stride)\n\n",
        "#ifdef SPINOR_DOUBLE\n",
        "spinorFloat a = param.a;\n",
        "#else\n",
        "spinorFloat a = param.a_f;\n",
        "#endif\n",
    ]
    for s in range(0,4):
        for c in range(0,3):
            pieces.append(out_re(s,c)+" = "+scale+out_re(s,c)+"+"+acc_re(s,c)+";\n")
            pieces.append(out_im(s,c)+" = "+scale+out_im(s,c)+"+"+acc_im(s,c)+";\n")
    pieces.append("#endif // DSLASH_XPAY\n")
    return "".join(pieces)
986 # end def xpay
987 
988 
# epilog(): return the closing part of a generated dslash/clover kernel core as
# C/CUDA source text. Output depends on the module-level flags dslash,
# asymClover, twist, clover, sharedDslash and sharedFloats. In order it emits:
#   * (dslash and not asymClover) a MULTI_GPU "incomplete" test; under the #if,
#     it makes the following scope conditional on `if (!incomplete)`;
#   * (not asymClover) a C scope holding twisted() or apply_clover(...),
#     followed by xpay();
#   * the WRITE_SPINOR store and #undef lines for the macros the kernel used.
# NOTE(review): indentation is not recoverable in this view, so the nesting of
# some statements is assumed, not shown. This covers the `str += block( block_str )`
# line, the SHARED_STRIDE undef, and each trailing `str += "\n"`. Confirm against
# the real source before editing. `str += block( block_str )` presumably sits
# under `if not asymClover:`; otherwise block_str would be unbound when
# asymClover is set.
989 def epilog():
990  str = ""
# The emitted switch falls through on purpose:
#   INTERIOR_KERNEL tests dimensions 3..0, EXTERIOR_KERNEL_T tests 2..0, and so on.
# A site lying on a communicated boundary (coord 0 or X-1) of any tested dimension
# is flagged incomplete.
991  if dslash and not asymClover:
992  if twist:
993  str += "#ifdef MULTI_GPU\n"
994  else:
995  str += "#if defined MULTI_GPU && (defined DSLASH_XPAY || defined DSLASH_CLOVER)\n"
996  str += (
997 """
998 int incomplete = 0; // Have all 8 contributions been computed for this site?
999 
1000 switch(kernel_type) { // intentional fall-through
1001 case INTERIOR_KERNEL:
1002  incomplete = incomplete || (param.commDim[3] && (coord[3]==0 || coord[3]==(param.dc.X[3]-1)));
1003 case EXTERIOR_KERNEL_T:
1004  incomplete = incomplete || (param.commDim[2] && (coord[2]==0 || coord[2]==(param.dc.X[2]-1)));
1005 case EXTERIOR_KERNEL_Z:
1006  incomplete = incomplete || (param.commDim[1] && (coord[1]==0 || coord[1]==(param.dc.X[1]-1)));
1007 case EXTERIOR_KERNEL_Y:
1008  incomplete = incomplete || (param.commDim[0] && (coord[0]==0 || coord[0]==(param.dc.X[0]-1)));
1009 }
1010 
1011 """)
1012  str += "if (!incomplete)\n"
1013  str += "#endif // MULTI_GPU\n"
1014 
# Assemble the per-site work:
#   twisted() when twist is set; otherwise apply_clover() on o (dslash) or i
#   (non-dslash), followed by xpay().
# The clover_xpay() branch is commented out.
# NOTE(review): the inner `if not asymClover:` before xpay() is redundant if it
# is nested under the outer one.
1015  if not asymClover:
1016  block_str = ""
1017  if twist: block_str += twisted()
1018  #elif asymClover: block_str += clover_xpay()
1019  elif dslash: block_str += apply_clover("o","o")
1020  else: block_str += apply_clover("o","i")
1021  if not asymClover: block_str += xpay()
1022 
1023  str += block( block_str )
1024 
1025  str += "\n\n"
1026  str += "// write spinor field back to device memory\n"
1027  str += "WRITE_SPINOR(param.sp_stride);\n\n"
1028 
1029  str += "// undefine to prevent warning when precision is changed\n"
1030  str += "#undef spinorFloat\n"
1031  if sharedDslash:
1032  str += "#undef WRITE_SPINOR_SHARED\n"
1033  str += "#undef READ_SPINOR_SHARED\n"
1034  if sharedFloats > 0: str += "#undef SHARED_STRIDE\n\n"
1035 
# Undefine the component macros:
#   gauge-link g_* (dslash only), input in_*, and accumulator acc_* (dslash only).
# The `i = ...` assignments in these three loops are never used.
1036  if dslash:
1037  for m in range(0,3):
1038  for n in range(0,3):
1039  i = 3*m+n
1040  str += "#undef "+g_re(0,m,n)+"\n"
1041  str += "#undef "+g_im(0,m,n)+"\n"
1042  str += "\n"
1043 
1044  for s in range(0,4):
1045  for c in range(0,3):
1046  i = 3*s+c
1047  str += "#undef "+in_re(s,c)+"\n"
1048  str += "#undef "+in_im(s,c)+"\n"
1049  str += "\n"
1050 
1051  if dslash:
1052  for s in range(0,4):
1053  for c in range(0,3):
1054  i = 3*s+c
1055  str += "#undef "+acc_re(s,c)+"\n"
1056  str += "#undef "+acc_im(s,c)+"\n"
1057  str += "\n"
1058 
# Clover macros run over the 6 combined indices m with s = m/3 and c = m%3:
#   first the diagonal real entries, then real and imaginary parts for every m > n.
# NOTE(review): m/3 and n/3 rely on Python 2 integer division. Under Python 3
# they yield floats and would need `//`.
1059  if clover == True:
1060  for m in range(0,6):
1061  s = m/3
1062  c = m%3
1063  str += "#undef "+c_re(0,s,c,s,c)+"\n"
1064  for n in range(0,6):
1065  sn = n/3
1066  cn = n%3
1067  for m in range(n+1,6):
1068  sm = m/3
1069  cm = m%3
1070  str += "#undef "+c_re(0,sm,cm,sn,cn)+"\n"
1071  str += "#undef "+c_im(0,sm,cm,sn,cn)+"\n"
1072  str += "\n"
1073 
# Only out_* components whose float index (2*i, 2*i+1) is below sharedFloats are
# #undef'd here. Presumably these are the ones defined as shared-memory aliases
# elsewhere; confirm in prolog().
1074  for s in range(0,4):
1075  for c in range(0,3):
1076  i = 3*s+c
1077  if 2*i < sharedFloats:
1078  str += "#undef "+out_re(s,c)+"\n"
1079  if 2*i+1 < sharedFloats:
1080  str += "#undef "+out_im(s,c)+"\n"
1081  str += "\n"
1082 
1083  str += "#undef VOLATILE\n"
1084 
1085  return str
1086 # end def epilog
1087 
1088 
def pack_face(facenum):
    """Return a C ``switch(dim)`` statement that packs one spinor face.

    For each dimension 0..3 a ``case`` is emitted. Each case contains:

    * the projection from ``gen(2*dim + facenum, pack_only=True)``;
    * a ``WRITE_HALF_SPINOR(face_volume, face_idx)`` store;
    * ``break``.

    The projection and store are wrapped in a C scope.

    facenum -- added to ``2*dim`` to select the direction passed to gen();
               generate_pack() passes 1 for ``face_num != 0`` and 0 otherwise.
    """
    code = "\n"
    code += "switch(dim) {\n"
    for dim in range(0, 4):
        # str(dim) replaces the Python-2-only backtick repr; for an int both
        # produce identical text. This works because the accumulator no longer
        # shadows the builtin `str`.
        code += "case " + str(dim) + ":\n"
        proj = gen(2 * dim + facenum, pack_only=True)
        proj += "\n"
        proj += "// write half spinor back to device memory\n"
        proj += "WRITE_HALF_SPINOR(face_volume, face_idx);\n"
        code += indent(block(proj) + "\n" + "break;\n")
    code += "}\n\n"
    return code
# end def pack_face
1102 
def generate_pack():
    """Return the face-packing kernel core as C source text.

    The output is built as follows:

    * the input-spinor definitions, followed by ``#include "io_spinor.h"``;
    * a run-time choice on ``face_num`` between pack_face(1) and pack_face(0);
    * ``#undef`` lines for spinorFloat, SHARED_STRIDE and the in_* component
      macros.

    The assert enforces that the module-level ``sharedFloats`` is 0. The
    top-level script sets it to 0 before calling this.
    """
    assert sharedFloats == 0
    code = ""
    code += def_input_spinor()
    code += "#include \"io_spinor.h\"\n\n"

    code += "if (face_num) "
    code += block(pack_face(1))
    code += " else "
    code += block(pack_face(0))

    code += "\n\n"
    code += "// undefine to prevent warning when precision is changed\n"
    code += "#undef spinorFloat\n"
    code += "#undef SHARED_STRIDE\n\n"

    # Undefine the input-spinor component macros: spin 0..3, colour 0..2.
    for s in range(0, 4):
        for c in range(0, 3):
            code += "#undef " + in_re(s, c) + "\n"
            code += "#undef " + in_im(s, c) + "\n"
    code += "\n"

    return code
# end def generate_pack
1128 
1129 
def generate_dslash():
    """Return a complete dslash kernel core as C source text.

    The core is prolog(), then gen(dir) for dir = 0..7 in order, then
    epilog(). The functions are called in the same left-to-right order as
    the original single concatenation expression.
    """
    code = prolog()
    for direction in range(8):
        code += gen(direction)
    return code + epilog()
1132 
# NOTE(review): the enclosing `def` header for this line is not present in this
# view, probably dropped when the file was rendered. It is presumably a generator
# that emits only prolog() + epilog(), with none of the gen(dir) terms that
# generate_dslash() adds. Confirm its name.
1134  return prolog() + epilog()
1135 
def _write_kernel(filename):
    """Write generate_dslash() output for the current global flags to filename."""
    print(sys.argv[0] + ": generating " + filename)
    # `with` guarantees the file is closed even if generation or writing fails.
    with open(filename, 'w') as f:
        f.write(generate_dslash())


# generate Wilson-like Dslash kernels
def generate_dslash_kernels(arch):
    """Generate the Wilson and asymmetric-clover dslash cores for one architecture.

    arch -- architecture number; arch // 10 is printed as the "sm" suffix
            (e.g. 200 -> sm20). It selects the shared-memory configuration:
            * arch >= 200: 24 shared floats plus shared dslash ("fermi");
            * arch >= 120: no shared floats ("gt200");
            * otherwise: 19 shared floats ("g80").

    Four files are written under dslash_core/: wilson_dslash and
    asym_wilson_clover_dslash, each in a non-dagger and a dagger variant.
    The module-level generator flags are set along the way. asymClover is
    reset to False at the end.
    """
    global sharedFloats
    global sharedDslash
    global dslash
    global dagger
    global clover
    global twist
    global asymClover

    print("Generating dslash kernel for sm" + str(arch // 10))

    if arch >= 200:
        sharedFloats = 24
        sharedDslash = True
        name = "fermi"
    elif arch >= 120:
        sharedFloats = 0
        sharedDslash = False
        name = "gt200"
    else:
        sharedFloats = 19
        sharedDslash = False
        name = "g80"

    print("Shared floats set to " + str(sharedFloats))

    dslash = True
    twist = False
    clover = True

    # Wilson(-clover) cores: non-dagger, then dagger.
    dagger = False
    _write_kernel('dslash_core/wilson_dslash_' + name + '_core.h')
    dagger = True
    _write_kernel('dslash_core/wilson_dslash_dagger_' + name + '_core.h')

    # Asymmetric-clover cores: non-dagger, then dagger.
    asymClover = True
    dagger = False
    _write_kernel('dslash_core/asym_wilson_clover_dslash_' + name + '_core.h')
    dagger = True
    _write_kernel('dslash_core/asym_wilson_clover_dslash_dagger_' + name + '_core.h')
    asymClover = False

    # Twisted-mass cores are currently disabled:
    # twist = True
    # clover = False
    # dagger = False
    # _write_kernel('dslash_core/tm_dslash_' + name + '_core.h')
    # dagger = True
    # _write_kernel('dslash_core/tm_dslash_dagger_' + name + '_core.h')
    # twist = False
    # dslash = False
1219 
1220 
# Module-level generator flags read by the generate_* functions above.
dslash = False
dagger = False
twist = False
clover = False
asymClover = False
sharedFloats = 0
sharedDslash = False
pack = False

# generate dslash kernels
# The arch assignments only take effect through these calls; without them
# both assignments would be dead.
arch = 200
generate_dslash_kernels(arch)

arch = 130
generate_dslash_kernels(arch)

# generate packing kernels
dslash = True
sharedFloats = 0
twist = False
clover = False
dagger = False
pack = True
print(sys.argv[0] + ": generating wilson_pack_face_core.h")
with open('dslash_core/wilson_pack_face_core.h', 'w') as f:
    f.write(generate_pack())

dagger = True
print(sys.argv[0] + ": generating wilson_pack_face_dagger_core.h")
with open('dslash_core/wilson_pack_face_dagger_core.h', 'w') as f:
    f.write(generate_pack())
dslash = False
pack = False

# generate clover solo term
#clover = True
#cloverSharedFloats = 0
#sharedFloats = cloverSharedFloats
1261 
def g_re(d, m, n)
def gplus(g1, g2)
def generate_dslash_kernels(arch)
def c_im(b, sm, cm, sn, cn)
def g_im(d, m, n)
def indent(code)
code generation ######################################################################## ...
def a_im(b, s, c)
def input_spinor(s, c, z)
def apply_clover(v_out, v_in)
def c_re(b, sm, cm, sn, cn)
def a_re(b, s, c)
def gminus(g1, g2)
def complexify(a)
complex numbers ######################################################################## ...
Definition: gen.py:1
def from_chiral_basis(v_out, v_in, c)
def gen(dir, pack_only=False)
def to_chiral_basis(v_out, v_in, c)
def spinor(name, s, c, z)
def clover_mult(v_out, v_in, chi)
def pack_face(facenum)