# QUDA 0.9.0
# fused_exterior_deg_tm_dslash_cuda_gen.py
1 import sys
2 
3 
4 
def complexify(a):
    """Return a new list holding every element of *a* converted to ``complex``."""
    return list(map(complex, a))
7 
def complexToStr(c):
    """Render a complex number compactly, e.g. "0", "1", "-i", "2+0.5i".

    Integral parts print without a trailing ".0", and a unit imaginary part
    prints as "i"/"-i".  Uses repr() rather than the Python-2-only backtick
    operator so the module also parses under Python 3.
    """
    def fltToString(a):
        # Drop the ".0" of integral floats: 2.0 -> "2".
        if a == int(a):
            return repr(int(a))
        return repr(a)

    def imToString(a):
        if a == 0:
            return "0i"
        if a == -1:
            return "-i"
        if a == 1:
            return "i"
        return fltToString(a) + "i"

    real = c.real
    imag = c.imag
    if real == 0 and imag == 0:
        return "0"
    if real == 0:
        return imToString(imag)
    if imag == 0:
        return fltToString(real)
    # Emit the sign explicitly so the imaginary magnitude is always printed positive.
    im_str = "-" + imToString(-imag) if imag < 0 else "+" + imToString(imag)
    return fltToString(real) + im_str
27 
28 
29 
30 
# Each 4x4 matrix below is stored row-major as a flat list of 16 complex
# entries: element (row i, column j) sits at index 4*i + j, which is how
# projectorToStr() and gen()'s proj() index them.
# NOTE(review): gamma4 is diag(1, 1, -1, -1) and igamma5 couples the upper and
# lower spin pairs, which looks like a non-chiral (Dirac-Pauli-style)
# Euclidean basis -- confirm against the QUDA conventions before relying on it.

# 4x4 identity.  NOTE: this rebinds the builtin name `id` at module scope;
# the `projectors` table refers to it under this name.
id = complexify([
    1, 0, 0, 0,
    0, 1, 0, 0,
    0, 0, 1, 0,
    0, 0, 0, 1
])

gamma1 = complexify([
    0, 0, 0, 1j,
    0, 0, 1j, 0,
    0, -1j, 0, 0,
    -1j, 0, 0, 0
])

gamma2 = complexify([
    0, 0, 0, 1,
    0, 0, -1, 0,
    0, -1, 0, 0,
    1, 0, 0, 0
])

gamma3 = complexify([
    0, 0, 1j, 0,
    0, 0, 0, -1j,
    -1j, 0, 0, 0,
    0, 1j, 0, 0
])

gamma4 = complexify([
    1, 0, 0, 0,
    0, 1, 0, 0,
    0, 0, -1, 0,
    0, 0, 0, -1
])

# Presumably i*gamma5, going by the name; not referenced by any code visible
# in this file.
igamma5 = complexify([
    0, 0, 1j, 0,
    0, 0, 0, 1j,
    1j, 0, 0, 0,
    0, 1j, 0, 0
])
72 
73 
def gplus(g1, g2):
    """Return the element-wise sum of two flattened matrices.

    Like zip(), stops at the end of the shorter input.
    """
    total = []
    for left, right in zip(g1, g2):
        total.append(left + right)
    return total
76 
def gminus(g1, g2):
    """Return the element-wise difference g1 - g2 of two flattened matrices.

    Like zip(), stops at the end of the shorter input.
    """
    difference = []
    for left, right in zip(g1, g2):
        difference.append(left - right)
    return difference
79 
def projectorToStr(p):
    """Return the flattened 4x4 matrix *p* as text, one row per line.

    Each entry is rendered with complexToStr() and followed by a single space,
    so a row reads e.g. "1 0 0 i \\n".  gen() embeds the result in C comments.
    (The `def` line of this function was missing; its name is the one gen()
    calls.)
    """
    out = ""
    for i in range(0, 4):
        row_text = " ".join(complexToStr(p[4*i + j]) for j in range(0, 4))
        out += row_text + " \n"
    return out
87 
# Spin projectors (1 - gamma_mu) and (1 + gamma_mu) for mu = 1..4, interleaved
# so that even indices hold the "minus" projector and odd indices the "plus"
# one: [1-g1, 1+g1, 1-g2, 1+g2, 1-g3, 1+g3, 1-g4, 1+g4].  gen() picks an entry
# by projIdx.
projectors = list(
    combine(id, gamma)
    for gamma in (gamma1, gamma2, gamma3, gamma4)
    for combine in (gminus, gplus)
)
94 
95 
96 
def indent(code):
    """Indent every line of *code* by one space and end each with a newline.

    Lines that begin with "#" (e.g. the C preprocessor directives this
    generator emits) are kept at column 0.  Empty input yields "".
    """
    result = ""
    for line in code.splitlines():
        if not line.startswith("#"):
            line = " " + line
        result += line + "\n"
    return result
100 
def block(code):
    """Wrap *code* in braces as a C compound statement.

    The body is passed through indent(); no newline follows the closing brace.
    """
    pieces = ["{\n", indent(code), "}"]
    return "".join(pieces)
103 
def sign(x):
    """Return the C source prefix for a projector coefficient *x*.

    Maps 1 -> "+", -1 -> "-", 2 -> "+2*", -2 -> "-2*".  Numerically equal
    floats (e.g. 2.0, as produced by complex.real) are accepted.

    Raises:
        ValueError: for any other coefficient.  Previously the function fell
            through and returned None, which only surfaced later as an opaque
            TypeError when the result was concatenated into generated code.
    """
    prefixes = {1: "+", -1: "-", 2: "+2*", -2: "-2*"}
    if x not in prefixes:
        raise ValueError("sign(): unsupported projector coefficient %r" % (x,))
    return prefixes[x]
109 
def nthFloat4(n):
    """Return "<n//4>.<x|y|z|w>", the accessor suffix for scalar *n* in a four-component packing.

    Prefixing a base name gives the component, e.g. "I" + nthFloat4(5) == "I1.y".
    Uses floor division and repr() (not Python-2 `/` and backticks) so the
    result is the same under Python 2 and 3.
    """
    return repr(n // 4) + "." + ["x", "y", "z", "w"][n % 4]
112 
def nthFloat2(n):
    """Return "<n//2>.<x|y>", the accessor suffix for scalar *n* in a two-component packing.

    Prefixing a base name gives the component, e.g. "I" + nthFloat2(3) == "I1.y".
    Uses floor division and repr() (not Python-2 `/` and backticks) so the
    result is the same under Python 2 and 3.
    """
    return repr(n // 2) + "." + ["x", "y"][n % 2]
115 
116 
# Names of the C identifiers the generator emits.  Each helper joins a prefix,
# the decimal indices and a _re/_im suffix, e.g. in_re(0, 1) -> "i01_re".
# repr() replaces the Python-2-only backtick operator so the module also
# parses under Python 3; for int arguments the output is unchanged.

# input spinor component (spin s, colour c)
def in_re(s, c): return "i" + repr(s) + repr(c) + "_re"
def in_im(s, c): return "i" + repr(s) + repr(c) + "_im"

# gauge-link entry (m, n); odd directions d use the "gT" names, which
# def_gauge() defines as the conjugate transpose of the "g" link.
def g_re(d, m, n): return ("g" if d % 2 == 0 else "gT") + repr(m) + repr(n) + "_re"
def g_im(d, m, n): return ("g" if d % 2 == 0 else "gT") + repr(m) + repr(n) + "_im"

# output spinor component (spin s, colour c)
def out_re(s, c): return "o" + repr(s) + repr(c) + "_re"
def out_im(s, c): return "o" + repr(s) + repr(c) + "_im"

# half-spinor component before (h1: a/b) and after (h2: A/B) the gauge multiply
def h1_re(h, c): return ["a", "b"][h] + repr(c) + "_re"
def h1_im(h, c): return ["a", "b"][h] + repr(c) + "_im"
def h2_re(h, c): return ["A", "B"][h] + repr(c) + "_re"
def h2_im(h, c): return ["A", "B"][h] + repr(c) + "_im"

# "a" names indexed by the combined spin index s + 2*b
def a_re(b, s, c): return "a" + repr(s + 2*b) + repr(c) + "_re"
def a_im(b, s, c): return "a" + repr(s + 2*b) + repr(c) + "_im"
129 
# Accumulator ("acc") and temporary ("tmp") spinor component names, e.g.
# acc_re(3, 2) -> "acc32_re".  repr() replaces the Python-2-only backticks.
def acc_re(s, c): return "acc" + repr(s) + repr(c) + "_re"
def acc_im(s, c): return "acc" + repr(s) + repr(c) + "_im"

def tmp_re(s, c): return "tmp" + repr(s) + repr(c) + "_re"
def tmp_im(s, c): return "tmp" + repr(s) + repr(c) + "_im"
135 
def spinor(name, s, c, z):
    """Return "<name><s><c>_re" when z == 0, otherwise "<name><s><c>_im".

    repr() replaces the Python-2-only backtick operator.
    """
    suffix = "_re" if z == 0 else "_im"
    return name + repr(s) + repr(c) + suffix
139 
def def_input_spinor():
    """Return the #define block naming the input-spinor (and accumulator) components.

    i{s}{c}_re/_im map onto components of I, and -- when the global `dslash`
    is set and `pack` is not -- acc{s}{c}_re/_im map onto components of accum.
    Under SPINOR_DOUBLE the two-component (.x/.y) layout is used, otherwise
    the four-component (.x...w) layout.  When `sharedDslash` is set the
    matching WRITE/READ_SPINOR_SHARED macros are selected too.

    (The `def` line was missing; the name matches the `# end def` marker
    and the call in prolog().)
    """
    def component_defines(nth):
        # Same defines for both precisions; only the vector packing differs.
        lines = ""
        for s in range(0, 4):
            for c in range(0, 3):
                i = 3*s + c
                lines += "#define " + in_re(s, c) + " I" + nth(2*i + 0) + "\n"
                lines += "#define " + in_im(s, c) + " I" + nth(2*i + 1) + "\n"
        if dslash and not pack:
            for s in range(0, 4):
                for c in range(0, 3):
                    i = 3*s + c
                    lines += "#define " + acc_re(s, c) + " accum" + nth(2*i + 0) + "\n"
                    lines += "#define " + acc_im(s, c) + " accum" + nth(2*i + 1) + "\n"
        return lines

    code = ""
    code += "// input spinor\n"
    code += "#ifdef SPINOR_DOUBLE\n"
    code += "#define spinorFloat double\n"
    if sharedDslash:
        code += "#define WRITE_SPINOR_SHARED WRITE_SPINOR_SHARED_DOUBLE2\n"
        code += "#define READ_SPINOR_SHARED READ_SPINOR_SHARED_DOUBLE2\n"
    code += component_defines(nthFloat2)
    code += "#else\n"
    code += "#define spinorFloat float\n"
    if sharedDslash:
        code += "#define WRITE_SPINOR_SHARED WRITE_SPINOR_SHARED_FLOAT4\n"
        code += "#define READ_SPINOR_SHARED READ_SPINOR_SHARED_FLOAT4\n"
    code += component_defines(nthFloat4)
    code += "#endif // SPINOR_DOUBLE\n\n"
    return code
# end def def_input_spinor
179 
180 
def def_gauge():
    """Return the #define block naming the gauge-link components.

    g{m}{n}_re/_im map onto components of G: the two-component (.x/.y) layout
    under GAUGE_FLOAT2, the four-component layout otherwise.  The "gT" names
    used for odd directions are defined as the conjugate transpose:
    gT{m}{n}_re = (+g{n}{m}_re), gT{m}{n}_im = (-g{n}{m}_im).

    The closing #endif is now labelled GAUGE_FLOAT2, the macro its #ifdef
    actually tests (it was previously mislabelled GAUGE_DOUBLE).
    """
    def link_defines(nth):
        lines = ""
        for m in range(0, 3):
            for n in range(0, 3):
                i = 3*m + n
                lines += "#define " + g_re(0, m, n) + " G" + nth(2*i + 0) + "\n"
                lines += "#define " + g_im(0, m, n) + " G" + nth(2*i + 1) + "\n"
        return lines + "\n"

    code = "// gauge link\n"
    code += "#ifdef GAUGE_FLOAT2\n"
    code += link_defines(nthFloat2)
    code += "#else\n"
    code += link_defines(nthFloat4)
    code += "#endif // GAUGE_FLOAT2\n\n"

    code += "// conjugated gauge link\n"
    for m in range(0, 3):
        for n in range(0, 3):
            code += "#define " + g_re(1, m, n) + " (+" + g_re(0, n, m) + ")\n"
            code += "#define " + g_im(1, m, n) + " (-" + g_im(0, n, m) + ")\n"
    code += "\n"

    return code
# end def def_gauge
211 
212 
def def_output_spinor():
    """Return the declarations of the output-spinor components o{s}{c}_re/_im.

    The first `sharedFloats` scalars become macros into the shared array `s`
    (s[k*SHARED_STRIDE]) when `sharedDslash` is false; every other component
    is declared as a local `VOLATILE spinorFloat`.

    (The `def` line was missing; the name matches the `# end def` marker
    and the call in prolog().)
    """
    # sharedDslash = True: input spinors stored in shared memory
    # sharedDslash = False: output spinors stored in shared memory
    code = "// output spinor\n"
    for s in range(0, 4):
        for c in range(0, 3):
            i = 3*s + c
            if 2*i < sharedFloats and not sharedDslash:
                code += "#define " + out_re(s, c) + " s[" + repr(2*i + 0) + "*SHARED_STRIDE]\n"
            else:
                code += "VOLATILE spinorFloat " + out_re(s, c) + ";\n"
            if 2*i + 1 < sharedFloats and not sharedDslash:
                code += "#define " + out_im(s, c) + " s[" + repr(2*i + 1) + "*SHARED_STRIDE]\n"
            else:
                code += "VOLATILE spinorFloat " + out_im(s, c) + ";\n"
    return code
# end def def_output_spinor
230 
231 
def prolog():
    """Return the preamble of the generated fused-exterior dslash kernel core.

    The emitted code:
      * opens `#ifdef MULTI_GPU` (the matching `#endif // MULTI_GPU` is
        emitted by epilog());
      * defines VOLATILE, the input/gauge/output spinor names, and
        SHARED_STRIDE when shared memory is used;
      * computes this thread's face index and coordinates, returning early
        for threads beyond param.threads or sites inactive in every direction;
      * reads the intermediate spinor and copies it into the output spinor.

    Reads the module globals dslash, dagger, sharedFloats, sharedDslash and
    arch.  Only the dslash prolog exists; any other configuration aborts the
    script with status 1.
    """
    if not dslash:
        # The original printed this message and then evaluated the bare name
        # `exit` (a no-op), silently continuing with an incomplete prolog.
        sys.exit("Undefined prolog")

    prolog_str = "#ifdef MULTI_GPU\n\n"
    prolog_str += ("// *** CUDA DSLASH ***\n\n" if not dagger else "// *** CUDA DSLASH DAGGER ***\n\n")
    prolog_str += "#define DSLASH_SHARED_FLOATS_PER_THREAD " + str(sharedFloats) + "\n\n"

    # VOLATILE is empty for NVVM and `volatile` for Open64 (per the emitted comments).
    prolog_str += (
"""
#if ((CUDA_VERSION >= 4010) && (__COMPUTE_CAPABILITY__ >= 200)) // NVVM compiler
#define VOLATILE
#else // Open64 compiler
#define VOLATILE volatile
#endif
""")

    prolog_str += def_input_spinor()
    prolog_str += def_gauge()
    prolog_str += def_output_spinor()

    if sharedFloats > 0:
        if arch >= 200:
            prolog_str += (
"""
#ifdef SPINOR_DOUBLE
#define SHARED_STRIDE 16 // to avoid bank conflicts on Fermi
#else
#define SHARED_STRIDE 32 // to avoid bank conflicts on Fermi
#endif
""")
        else:
            prolog_str += (
"""
#ifdef SPINOR_DOUBLE
#define SHARED_STRIDE 8 // to avoid bank conflicts on G80 and GT200
#else
#define SHARED_STRIDE 16 // to avoid bank conflicts on G80 and GT200
#endif
""")

    # set the pointer if using shared memory for pseudo registers
    if sharedFloats > 0 and not sharedDslash:
        prolog_str += (
"""
extern __shared__ char s_data[];
""")
        prolog_str += (
"""
VOLATILE spinorFloat *s = (spinorFloat*)s_data + DSLASH_SHARED_FLOATS_PER_THREAD*SHARED_STRIDE*(threadIdx.x/SHARED_STRIDE)
 + (threadIdx.x % SHARED_STRIDE);
""")

    prolog_str += (
"""
#include "read_gauge.h"
#include "io_spinor.h"

int coord[5];
int X;

#if (DD_PREC==2) // half precision
int sp_norm_idx;
#endif // half precision

int sid;
""")

    prolog_str += (
"""
 sid = blockIdx.x*blockDim.x + threadIdx.x;
 if (sid >= param.threads) return;


 int dim = dimFromFaceIndex(sid, param); // sid is also modified

 const int face_volume = ((param.threadDimMapUpper[dim] - param.threadDimMapLower[dim]) >> 1);
 const int face_num = (sid >= face_volume); // is this thread updating face 0 or 1
 int face_idx = sid - face_num*face_volume; // index into the respective face

 switch(dim) {
 case 0:
 coordsFromFaceIndex<4,QUDA_4D_PC,0,1>(X, sid, coord, face_idx, face_num, param);
 break;
 case 1:
 coordsFromFaceIndex<4,QUDA_4D_PC,1,1>(X, sid, coord, face_idx, face_num, param);
 break;
 case 2:
 coordsFromFaceIndex<4,QUDA_4D_PC,2,1>(X, sid, coord, face_idx, face_num, param);
 break;
 case 3:
 coordsFromFaceIndex<4,QUDA_4D_PC,3,1>(X, sid, coord, face_idx, face_num, param);
 break;
 }


 bool active = false;
 for(int dir=0; dir<4; ++dir){
 active = active || isActive(dim,dir,+1,coord,param.commDim,param.dc.X);
 }
 if(!active) return;


 READ_INTERMEDIATE_SPINOR(INTERTEX, param.sp_stride, sid, sid);

""")

    # Seed the output spinor with the intermediate spinor just read; the
    # per-direction code from gen() then accumulates into it.
    out = ""
    for s in range(0, 4):
        for c in range(0, 3):
            out += out_re(s, c) + " = " + in_re(s, c) + "; " + out_im(s, c) + " = " + in_im(s, c) + ";\n"
    prolog_str += indent(out)

    return prolog_str
# end def prolog
359 
360 
def gen(dir, pack_only=False):
    """Return the CUDA code for one exterior (ghost-zone) hop of the stencil.

    *dir* selects one of eight hops.  The dimension is dir // 2, and even/odd
    *dir* tests the upper/lower face (`boundary[dir]`, offset "+1"/"-1").  The
    spin projector is projectors[dir], or its partner dir ^ 1 when the global
    `dagger` is set.

    pack_only=False (default): returns a guarded C block.  The block reads the
    half spinor from the ghost buffer and multiplies it by the gauge link, or
    by the identity for dir >= 6 when param.gauge_fixed holds and
    ga_idx < param.dc.X4X3X2X1hmX3X2X1h.  It then accumulates the result into
    the output spinor.

    pack_only=True: returns only the code that reads a full spinor and projects
    it to a half spinor, with every "sp_idx" renamed to "idx".

    Uses repr()/`//` instead of Python-2 backticks and `/` so the output is
    identical under Python 2 and 3.
    """
    # Dagger swaps each forward/backward projector pair (dir ^ 1).
    projIdx = dir if not dagger else dir + (1 - 2*(dir % 2))
    projStr = projectorToStr(projectors[projIdx])
    mu = dir // 2            # dimension index of this hop
    ghost_side = 1 - dir % 2  # second index into param.ghostOffset / ghostNormOffset

    def proj(i, j):
        return projectors[projIdx][4*i + j]

    # if row(i) = (j, c), then the i'th row of the projector can be represented
    # as a multiple of the j'th row: row(i) = c row(j)
    def row(i):
        assert i == 2 or i == 3
        if proj(i, 0) == 0j:
            return (1, proj(i, 1))
        if proj(i, 1) == 0j:
            return (0, proj(i, 0))

    boundary = ["coord[0]==(param.dc.X[0]-1)", "coord[0]==0", "coord[1]==(param.dc.X[1]-1)", "coord[1]==0",
                "coord[2]==(param.dc.X[2]-1)", "coord[2]==0", "coord[3]==(param.dc.X[3]-1)", "coord[3]==0"]
    offset = ["+1", "-1", "+1", "-1", "+1", "-1", "+1", "-1"]

    cond = "if (isActive(dim," + repr(mu) + "," + offset[dir] + ",coord,param.commDim,param.dc.X) && " + boundary[dir] + " )\n"

    body = ""

    projName = "P" + repr(mu) + ["-", "+"][projIdx % 2]
    body += "// Projector " + projName + "\n"
    for line in projStr.splitlines():
        body += "// " + line + "\n"
    body += "\n"

    body += "faceIndexFromCoords<4,1>(face_idx,coord," + repr(mu) + ",param);\n"
    body += "const int sp_idx = face_idx + param.ghostOffset[" + repr(mu) + "][" + repr(ghost_side) + "];\n"

    body += "#if (DD_PREC==2)\n"
    body += " sp_norm_idx = face_idx + "
    body += "param.ghostNormOffset[" + repr(mu) + "][" + repr(ghost_side) + "];\n"
    body += "#endif\n"

    body += "\n"
    if dir % 2 == 0:
        body += "const int ga_idx = sid;\n"
    else:
        body += "const int ga_idx = param.dc.Vh+face_idx;\n"
    body += "\n"

    # scan the projector to determine which loads are required
    # (row_cnt[0]: non-zeros in the upper two rows, row_cnt[2]: in the lower two)
    row_cnt = [0, 0, 0, 0]
    for h in range(0, 4):
        for s in range(0, 4):
            if proj(h, s).real != 0 or proj(h, s).imag != 0:
                row_cnt[h] += 1
    row_cnt[0] += row_cnt[1]
    row_cnt[2] += row_cnt[3]

    decl_half = ""
    for h in range(0, 2):
        for c in range(0, 3):
            decl_half += "spinorFloat " + h1_re(h, c) + ", " + h1_im(h, c) + ";\n"
    decl_half += "\n"

    # --- full-spinor load (only used by the pack_only path) ---
    twist_inv = "APPLY_TWIST_INV(-a, b, i);\n" if dagger else "APPLY_TWIST_INV( a, b, i);\n"

    load_spinor = "// read spinor from device memory\n"
    load_spinor += "#ifdef TWIST_INV_DSLASH\n"
    load_spinor += "#ifdef SPINOR_DOUBLE\n"
    load_spinor += "const spinorFloat a = param.a;\n"
    load_spinor += "const spinorFloat b = param.b;\n"
    load_spinor += "#else\n"
    load_spinor += "const spinorFloat a = param.a_f;\n"
    load_spinor += "const spinorFloat b = param.b_f;\n"
    load_spinor += "#endif\n"
    load_spinor += "#endif\n"

    if row_cnt[0] == 0 or row_cnt[2] == 0:
        # One half of the projector is all zeros: outside pack mode, the
        # non-twisted path uses the half-read macro (DOWN when the upper rows
        # are zero, UP when the lower rows are).
        half_read = "READ_SPINOR_DOWN" if row_cnt[0] == 0 else "READ_SPINOR_UP"
        if not pack_only:
            load_spinor += "#ifndef TWIST_INV_DSLASH\n"
            load_spinor += half_read + "(SPINORTEX, param.sp_stride, sp_idx, sp_idx);\n"
            load_spinor += "#else\n"
        load_spinor += "READ_SPINOR(SPINORTEX, param.sp_stride, sp_idx, sp_idx);\n"
        load_spinor += twist_inv
        if not pack_only:
            load_spinor += "#endif\n"
    else:
        load_spinor += "READ_SPINOR(SPINORTEX, param.sp_stride, sp_idx, sp_idx);\n"
        if not pack_only:
            load_spinor += "#ifdef TWIST_INV_DSLASH\n"
        load_spinor += twist_inv
        if not pack_only:
            load_spinor += "#endif\n"
    load_spinor += "\n"

    # --- ghost half-spinor load ---
    load_half = ""
    load_half += "const int sp_stride_pad = param.dc.ghostFace[" + repr(mu) + "];\n"
    if dir >= 6:
        load_half += "const int t_proj_scale = TPROJSCALE;\n"
    load_half += "\n"
    load_half += "// read half spinor from device memory\n"
    # The same volume index is used for backward and forward gathers; the
    # direction is passed to READ_SPINOR_GHOST instead.
    load_half += "READ_SPINOR_GHOST(GHOSTSPINORTEX, sp_stride_pad, sp_idx, sp_norm_idx, " + repr(dir) + ");\n\n"

    load_gauge = "// read gauge matrix from device memory\n"
    load_gauge += "READ_GAUGE_MATRIX(G, GAUGE" + repr(dir % 2) + "TEX, " + repr(dir) + ", ga_idx, param.gauge_stride);\n\n"

    reconstruct_gauge = "// reconstruct gauge matrix\n"
    reconstruct_gauge += "RECONSTRUCT_GAUGE_MATRIX(" + repr(dir) + ");\n\n"

    # --- projection to half spinors (only used by the pack_only path) ---
    project = "// project spinor into half spinors\n"
    for h in range(0, 2):
        # Row h always contributes; row h+2 too when the projector is defined
        # on the lower half only.
        rows = (h, h + 2) if row_cnt[0] == 0 else (h,)
        for c in range(0, 3):
            strRe = ""
            strIm = ""
            for r in rows:
                for s in range(0, 4):
                    re = proj(r, s).real
                    im = proj(r, s).imag
                    if re == 0 and im == 0:
                        continue
                    if im == 0:
                        strRe += sign(re) + in_re(s, c)
                        strIm += sign(re) + in_im(s, c)
                    elif re == 0:
                        # multiplying by i*im swaps real and imaginary parts
                        strRe += sign(-im) + in_im(s, c)
                        strIm += sign(im) + in_re(s, c)
            project += h1_re(h, c) + " = " + strRe + ";\n"
            project += h1_im(h, c) + " = " + strIm + ";\n"

    copy_half = ""
    for h in range(0, 2):
        for c in range(0, 3):
            copy_half += h1_re(h, c) + " = " + ("t_proj_scale*" if (dir >= 6) else "") + in_re(h, c) + "; "
            copy_half += h1_im(h, c) + " = " + ("t_proj_scale*" if (dir >= 6) else "") + in_im(h, c) + ";\n"
    copy_half += "\n"

    prep_half = "\n" + load_half + copy_half

    ident = "// identity gauge matrix\n"
    for m in range(0, 3):
        for h in range(0, 2):
            ident += "spinorFloat " + h2_re(h, m) + " = " + h1_re(h, m) + "; "
            ident += "spinorFloat " + h2_im(h, m) + " = " + h1_im(h, m) + ";\n"
    ident += "\n"

    # Complex matrix-vector product: h2 = U * h1.
    mult = ""
    for m in range(0, 3):
        mult += "// multiply row " + repr(m) + "\n"
        for h in range(0, 2):
            re_code = "spinorFloat " + h2_re(h, m) + " = 0;\n"
            im_code = "spinorFloat " + h2_im(h, m) + " = 0;\n"
            for c in range(0, 3):
                re_code += h2_re(h, m) + " += " + g_re(dir, m, c) + " * " + h1_re(h, c) + ";\n"
                re_code += h2_re(h, m) + " -= " + g_im(dir, m, c) + " * " + h1_im(h, c) + ";\n"
                im_code += h2_im(h, m) + " += " + g_re(dir, m, c) + " * " + h1_im(h, c) + ";\n"
                im_code += h2_im(h, m) + " += " + g_im(dir, m, c) + " * " + h1_re(h, c) + ";\n"
            mult += re_code + im_code
        mult += "\n"

    # Accumulate the two half-spinor rows into the output, then rebuild the
    # other two rows from row() multiples.
    reconstruct = ""
    for m in range(0, 3):
        for h in range(0, 2):
            h_out = h + 2 if row_cnt[0] == 0 else h  # projector defined on lower half only
            reconstruct += out_re(h_out, m) + " += " + h2_re(h, m) + ";\n"
            reconstruct += out_im(h_out, m) + " += " + h2_im(h, m) + ";\n"

        for s in range(2, 4):
            (src, coeff) = row(s)
            re = coeff.real
            im = coeff.imag
            if im == 0 and re == 0:
                continue
            if im == 0:
                reconstruct += out_re(s, m) + " " + sign(re) + "= " + h2_re(src, m) + ";\n"
                reconstruct += out_im(s, m) + " " + sign(re) + "= " + h2_im(src, m) + ";\n"
            elif re == 0:
                reconstruct += out_re(s, m) + " " + sign(-im) + "= " + h2_im(src, m) + ";\n"
                reconstruct += out_im(s, m) + " " + sign(+im) + "= " + h2_re(src, m) + ";\n"

        reconstruct += "\n"

    if dir >= 6:
        body += "if (param.gauge_fixed && ga_idx < param.dc.X4X3X2X1hmX3X2X1h)\n"
        body += block(decl_half + prep_half + ident + reconstruct)
        body += " else "
        body += block(decl_half + prep_half + load_gauge + reconstruct_gauge + mult + reconstruct)
    else:
        body += decl_half + prep_half + load_gauge + reconstruct_gauge + mult + reconstruct

    if pack_only:
        out = load_spinor + decl_half + project
        return out.replace("sp_idx", "idx")
    return cond + block(body) + "\n\n"
# end def gen
650 
651 
def input_spinor(s, c, z):
    """Return the name of spinor component (s, c); z == 0 selects the real part.

    When the global `dslash` is set the output-spinor names (out_re/out_im)
    are used, otherwise the input-spinor names (in_re/in_im).
    """
    if dslash:
        real_name, imag_name = out_re, out_im
    else:
        real_name, imag_name = in_re, in_im
    return real_name(s, c) if z == 0 else imag_name(s, c)
659 
660 
def twisted_xpay():
    """Return the twisted-mass / xpay finishing code that epilog() wraps in a block.

    Emitted structure:
      * a, b are loaded from param (param.a_f/b_f in single precision) unless
        TWIST_INV_DSLASH is defined;
      * under DSLASH_XPAY the accumulator is read, then either
        - (no TWIST_XPAY) the inverse twist is applied to the output, unless
          TWIST_INV_DSLASH, and acc is added; or
        - (TWIST_XPAY) acc is twisted and o = b*o + acc is formed;
      * without XPAY only the inverse twist is applied, unless TWIST_INV_DSLASH.
    The twist parameter is emitted as -a instead of a when the global
    `dagger` is set.

    (The `def` line was missing; the name matches the `# end def` marker
    and the call in epilog().)
    """
    twist_a = "-a" if dagger else " a"

    code = ""
    code += "#ifndef TWIST_INV_DSLASH\n"
    code += "#ifdef SPINOR_DOUBLE\n"
    code += "const spinorFloat a = param.a;\n"
    code += "const spinorFloat b = param.b;\n"
    code += "#else\n"
    code += "const spinorFloat a = param.a_f;\n"
    code += "const spinorFloat b = param.b_f;\n"
    code += "#endif\n"
    code += "#endif\n"

    code += "#ifdef DSLASH_XPAY\n"
    code += "READ_ACCUM(ACCUMTEX, param.sp_stride)\n\n"

    code += "#ifndef TWIST_XPAY\n"
    code += "#ifndef TWIST_INV_DSLASH\n"
    code += "//perform invert twist first:\n"
    code += "APPLY_TWIST_INV(" + twist_a + ", b, o);\n"
    code += "#endif\n"
    for s in range(0, 4):
        for c in range(0, 3):
            code += out_re(s, c) + " += " + acc_re(s, c) + ";\n"
            code += out_im(s, c) + " += " + acc_im(s, c) + ";\n"
    code += "#else\n"
    code += "APPLY_TWIST(" + twist_a + ", acc);\n"
    code += "//warning! b is unrelated to the twisted mass parameter in this case!\n\n"
    for s in range(0, 4):
        for c in range(0, 3):
            code += out_re(s, c) + " = b*" + out_re(s, c) + "+" + acc_re(s, c) + ";\n"
            code += out_im(s, c) + " = b*" + out_im(s, c) + "+" + acc_im(s, c) + ";\n"
    code += "#endif//TWIST_XPAY\n"
    code += "#else //no XPAY\n"
    code += "#ifndef TWIST_INV_DSLASH\n"
    code += " APPLY_TWIST_INV(" + twist_a + ", b, o);\n"
    code += "#endif\n"
    code += "#endif\n"
    return code
# end def twisted_xpay
711 
712 
def epilog():
    """Return the code emitted after all hops.

    The emitted code contains, in order:
      * the twisted_xpay() code in its own brace block;
      * the WRITE_SPINOR store;
      * #undefs for the macros the prolog defined;
      * the `#endif // MULTI_GPU` that closes the prolog's #ifdef.

    The #undef conditions mirror the #define conditions exactly.  Output
    components are only macros when `2*i < sharedFloats and not sharedDslash`
    (see def_output_spinor).  Accumulator macros exist only when
    `dslash and not pack` (see def_input_spinor).  Previously these #undefs
    were also emitted for plain local variables, which was harmless but
    inconsistent.
    """
    code = block(twisted_xpay())

    code += "\n\n"
    code += "// write spinor field back to device memory\n"
    code += "WRITE_SPINOR(param.sp_stride);\n\n"

    code += "// undefine to prevent warning when precision is changed\n"
    code += "#undef spinorFloat\n"
    if sharedDslash:
        code += "#undef WRITE_SPINOR_SHARED\n"
        code += "#undef READ_SPINOR_SHARED\n"
    if sharedFloats > 0:
        code += "#undef SHARED_STRIDE\n\n"

    if dslash:
        for m in range(0, 3):
            for n in range(0, 3):
                code += "#undef " + g_re(0, m, n) + "\n"
                code += "#undef " + g_im(0, m, n) + "\n"
        code += "\n"

    for s in range(0, 4):
        for c in range(0, 3):
            code += "#undef " + in_re(s, c) + "\n"
            code += "#undef " + in_im(s, c) + "\n"
    code += "\n"

    if dslash and not pack:
        for s in range(0, 4):
            for c in range(0, 3):
                code += "#undef " + acc_re(s, c) + "\n"
                code += "#undef " + acc_im(s, c) + "\n"
        code += "\n"

    code += "\n"

    for s in range(0, 4):
        for c in range(0, 3):
            i = 3*s + c
            if 2*i < sharedFloats and not sharedDslash:
                code += "#undef " + out_re(s, c) + "\n"
            if 2*i + 1 < sharedFloats and not sharedDslash:
                code += "#undef " + out_im(s, c) + "\n"
    code += "\n"

    code += "#undef VOLATILE\n\n"

    code += "#endif // MULTI_GPU\n"

    return code
# end def epilog
770 
771 
def pack_face(facenum):
    """Return a C `switch(dim)` that packs one face for each of the four dimensions.

    Each case projects the spinor with gen(2*dim + facenum, pack_only=True)
    and writes the half spinor with WRITE_HALF_SPINOR.  repr() replaces the
    Python-2-only backtick operator.
    """
    code = "\n"
    code += "switch(dim) {\n"
    for dim in range(0, 4):
        code += "case " + repr(dim) + ":\n"
        proj = gen(2*dim + facenum, pack_only=True)
        proj += "\n"
        proj += "// write half spinor back to device memory\n"
        proj += "WRITE_HALF_SPINOR(face_volume, face_idx);\n"
        code += indent(block(proj) + "\n" + "break;\n")
    code += "}\n\n"
    return code
# end def pack_face
785 
def generate_pack():
    """Return the face-packing kernel core.

    The emitted code has three parts:
      * the input-spinor #defines;
      * an if (face_num) / else choosing pack_face(1) or pack_face(0);
      * #undefs of the names defined above.
    Requires the global `sharedFloats` to be 0.

    (The `def` line was missing.  The name matches the `# end def` marker;
    the body uses no parameters, so none are declared.)
    """
    assert (sharedFloats == 0)
    code = ""
    code += def_input_spinor()
    code += "#include \"io_spinor.h\"\n\n"

    code += "if (face_num) "
    code += block(pack_face(1))
    code += " else "
    code += block(pack_face(0))

    code += "\n\n"
    code += "// undefine to prevent warning when precision is changed\n"
    code += "#undef spinorFloat\n"
    code += "#undef SHARED_STRIDE\n\n"

    for s in range(0, 4):
        for c in range(0, 3):
            code += "#undef " + in_re(s, c) + "\n"
            code += "#undef " + in_im(s, c) + "\n"
    code += "\n"

    return code
# end def generate_pack
811 
812 
def generate_dslash():
    """Return the complete kernel core: prolog, the eight hops gen(0..7), epilog.

    Sections are generated in the same left-to-right order as before.
    (The `def` line was missing; the name and zero-argument signature match
    its call sites.)
    """
    return prolog() + "".join(gen(d) for d in range(8)) + epilog()
815 
# generate Wilson-like Dslash kernels
# NOTE(review): the `def` line that opens this function (listing line 817) is
# missing from this rendering, so its name and signature are not visible.
# The body reads `arch` without declaring it global, so it is presumably a
# parameter (e.g. 200 -> "sm20") -- restore the header from the original source.
#
# Sets the shared-memory globals for the target architecture.  It then writes
# the non-dagger and dagger twisted-mass fused-exterior kernel cores to
#   dslash_core/tm_fused_exterior_dslash[_dagger]_<fermi|gt200|g80>_core.h
# and finally resets `twist` and `dslash`.
    print "Generating dslash kernel for sm" + str(arch/10)

    global sharedFloats
    global sharedDslash
    global dslash
    global dagger
    global twist #deg_twist
# global ndeg_twist #new!

    # Shared-memory configuration per architecture generation.
    sharedFloats = 0
    if arch >= 200:
        sharedFloats = 24
        sharedDslash = True
        name = "fermi"
    elif arch >= 120:
        sharedFloats = 0
        sharedDslash = False
        name = "gt200"
    else:
        sharedFloats = 19
        sharedDslash = False
        name = "g80"

    print "Shared floats set to " + str(sharedFloats)

    # Non-dagger kernel core.
    # NOTE(review): open()/close() without `with` leaks the handle if
    # generate_dslash() raises.
    dslash = True
    twist = True
    dagger = False
    filename = 'dslash_core/tm_fused_exterior_dslash_' + name + '_core.h'
    print sys.argv[0] + ": generating " + filename;
    f = open(filename, 'w')
    f.write(generate_dslash())
    f.close()

    # Dagger kernel core.
    dagger = True
    filename = 'dslash_core/tm_fused_exterior_dslash_dagger_' + name + '_core.h'
    print sys.argv[0] + ": generating " + filename + "\n";
    f = open(filename, 'w')
    f.write(generate_dslash())
    f.close()

    twist = False
    dslash = False
861 
862 
863 
# Default values of the module-level configuration globals.  The
# code-generation functions above read dslash, dagger, sharedFloats,
# sharedDslash and pack; `twist` is only set, never read, in the code visible
# here.
dslash = False
dagger = False
twist = False
sharedFloats = 0
sharedDslash = False
pack = False

# generate dslash kernels
# NOTE(review): in this rendering the statements that run the generator after
# each `arch` assignment (listing line 873 and, presumably, one after
# `arch = 130`) are missing.  As shown, these assignments have no effect;
# restore the calls from the original source.
arch = 200

arch = 130
# Doxygen cross-reference residue from the rendered page (not part of the program):
#   def complexify(a)  -- documented under the section banner "complex numbers"
#   def indent(code)   -- documented under the section banner "code generation"
#   Definition: gen.py:1