QUDA 0.9.0 — deg_tm_dslash_cuda_gen.py
Doxygen source listing of the generator script for the degenerate
twisted-mass dslash CUDA kernel cores.
1 import sys
2 
3 
4 
def complexify(a):
    """Convert a sequence of numbers into a list of Python complex values."""
    result = []
    for value in a:
        result.append(complex(value))
    return result
7 
def complexToStr(c):
    """Render a complex number as a compact string: "0", "1", "-i", "2i", "1+i", "2-i".

    Deprecated Python 2 backtick-repr syntax replaced with repr(), which
    behaves identically and is valid in Python 3 as well.
    """

    def fltToString(a):
        # print integral values without a trailing ".0"
        if a == int(a):
            return repr(int(a))
        else:
            return repr(a)

    def imToString(a):
        # imaginary part; unit coefficients abbreviate to "i"/"-i"
        if a == 0:
            return "0i"
        elif a == -1:
            return "-i"
        elif a == 1:
            return "i"
        else:
            return fltToString(a) + "i"

    re = c.real
    im = c.imag
    if re == 0 and im == 0:
        return "0"
    elif re == 0:
        return imToString(im)
    elif im == 0:
        return fltToString(re)
    else:
        im_str = "-" + imToString(-im) if im < 0 else "+" + imToString(im)
        return fltToString(re) + im_str
27 
28 
29 
30 
# Dirac gamma matrices, stored row-major as flat 16-element complex lists.
# NOTE(review): the explicit layout suggests the chiral (DeGrand-Rossi style)
# basis used by QUDA's dslash generators - confirm before reusing elsewhere.
id = complexify([
    1, 0, 0, 0,
    0, 1, 0, 0,
    0, 0, 1, 0,
    0, 0, 0, 1
])

gamma1 = complexify([
    0, 0, 0, 1j,
    0, 0, 1j, 0,
    0, -1j, 0, 0,
    -1j, 0, 0, 0
])

gamma2 = complexify([
    0, 0, 0, 1,
    0, 0, -1, 0,
    0, -1, 0, 0,
    1, 0, 0, 0
])

gamma3 = complexify([
    0, 0, 1j, 0,
    0, 0, 0, -1j,
    -1j, 0, 0, 0,
    0, 1j, 0, 0
])

gamma4 = complexify([
    1, 0, 0, 0,
    0, 1, 0, 0,
    0, 0, -1, 0,
    0, 0, 0, -1
])

# i*gamma5; not referenced elsewhere in this file
igamma5 = complexify([
    0, 0, 1j, 0,
    0, 0, 0, 1j,
    1j, 0, 0, 0,
    0, 1j, 0, 0
])
72 
73 
def gplus(g1, g2):
    """Element-wise sum of two flat matrices."""
    return [a + b for a, b in zip(g1, g2)]
76 
def gminus(g1, g2):
    """Element-wise difference of two flat matrices."""
    return [a - b for a, b in zip(g1, g2)]
79 
def projectorToStr(p):
    """Format a flat 4x4 projector matrix as four text rows of complex entries.

    NOTE(review): the ``def`` line (original line 80) was dropped by the
    doxygen extraction; the name is reconstructed from the call site in gen().
    """
    out = ""
    for i in range(0, 4):
        for j in range(0, 4):
            out += complexToStr(p[4*i+j]) + " "
        out += "\n"
    return out
87 
# The eight half-spinor projectors: even index = 1 - gamma_mu,
# odd index = 1 + gamma_mu, for mu = 1..4, indexed by direction in gen().
projectors = [
    gminus(id, gamma1), gplus(id, gamma1),
    gminus(id, gamma2), gplus(id, gamma2),
    gminus(id, gamma3), gplus(id, gamma3),
    gminus(id, gamma4), gplus(id, gamma4),
]
94 
95 
96 
def indent(code):
    """Indent each line by one space, except C-preprocessor lines starting with '#'."""
    out = []
    for line in code.splitlines():
        if line.startswith("#"):
            out.append(line + "\n")
        else:
            out.append(" " + line + "\n")
    return "".join(out)
100 
def block(code):
    """Wrap generated code in braces, indenting the body by one space."""
    return "{\n"+indent(code)+"}"
103 
def sign(x):
    """Map a coefficient in {+1, -1, +2, -2} to its C operator prefix.

    The original silently returned None for any other coefficient (falling
    through all branches); raise ValueError instead so bad projector entries
    fail loudly. Valid inputs behave exactly as before (floats such as 1.0
    hash equal to 1 and are accepted).
    """
    table = {1: "+", -1: "-", 2: "+2*", -2: "-2*"}
    try:
        return table[x]
    except KeyError:
        raise ValueError("unsupported coefficient: %r" % (x,))
109 
def nthFloat4(n):
    """Component n of a float4 register array, e.g. 5 -> "1.y".

    Backtick-repr replaced with repr(); Python 2 integer "/" replaced with
    "//", which is identical for ints in both Python versions.
    """
    return repr(n // 4) + "." + ["x", "y", "z", "w"][n % 4]
112 
def nthFloat2(n):
    """Component n of a float2/double2 register array, e.g. 3 -> "1.y".

    Backtick-repr replaced with repr(); integer "/" replaced with "//".
    """
    return repr(n // 2) + "." + ["x", "y"][n % 2]
115 
116 
# --- generated-variable naming helpers -------------------------------------
# s = spin index, c = color index, h = half-spinor index (0/1),
# m/n = gauge color row/column. Backtick-repr replaced with repr().

def in_re(s, c): return "i" + repr(s) + repr(c) + "_re"    # input spinor, real part
def in_im(s, c): return "i" + repr(s) + repr(c) + "_im"    # input spinor, imag part
def g_re(d, m, n): return ("g" if (d % 2 == 0) else "gT") + repr(m) + repr(n) + "_re"  # gauge link ("gT" = conjugated link for odd directions)
def g_im(d, m, n): return ("g" if (d % 2 == 0) else "gT") + repr(m) + repr(n) + "_im"
def out_re(s, c): return "o" + repr(s) + repr(c) + "_re"   # output spinor
def out_im(s, c): return "o" + repr(s) + repr(c) + "_im"
def h1_re(h, c): return ["a", "b"][h] + repr(c) + "_re"    # projected half spinor
def h1_im(h, c): return ["a", "b"][h] + repr(c) + "_im"
def h2_re(h, c): return ["A", "B"][h] + repr(c) + "_re"    # gauge-multiplied half spinor
def h2_im(h, c): return ["A", "B"][h] + repr(c) + "_im"
def a_re(b, s, c): return "a" + repr(s + 2*b) + repr(c) + "_re"
def a_im(b, s, c): return "a" + repr(s + 2*b) + repr(c) + "_im"

def acc_re(s, c): return "acc" + repr(s) + repr(c) + "_re"  # accumulator (xpay) spinor
def acc_im(s, c): return "acc" + repr(s) + repr(c) + "_im"

def tmp_re(s, c): return "tmp" + repr(s) + repr(c) + "_re"
def tmp_im(s, c): return "tmp" + repr(s) + repr(c) + "_im"
135 
def spinor(name, s, c, z):
    """Name of the z-component (0 = real, 1 = imaginary) of spinor entry (s, c)."""
    return name + repr(s) + repr(c) + ("_re" if z == 0 else "_im")
139 
def def_input_spinor():
    """Emit #defines mapping input-spinor (and, for non-pack dslash kernels,
    accumulator) components onto the I/accum register arrays, for both the
    double2 and float4 layouts.

    NOTE(review): the ``def`` line (original line 140) was dropped by the
    doxygen extraction; the name is reconstructed from the call sites in
    prolog() and generate_pack(). Local renamed from builtin-shadowing `str`.
    """
    code = "// input spinor\n"
    code += "#ifdef SPINOR_DOUBLE\n"
    code += "#define spinorFloat double\n"
    if sharedDslash:
        code += "#define WRITE_SPINOR_SHARED WRITE_SPINOR_SHARED_DOUBLE2\n"
        code += "#define READ_SPINOR_SHARED READ_SPINOR_SHARED_DOUBLE2\n"

    for s in range(0, 4):
        for c in range(0, 3):
            i = 3*s + c
            code += "#define " + in_re(s, c) + " I" + nthFloat2(2*i + 0) + "\n"
            code += "#define " + in_im(s, c) + " I" + nthFloat2(2*i + 1) + "\n"
    if dslash and not pack:
        for s in range(0, 4):
            for c in range(0, 3):
                i = 3*s + c
                code += "#define " + acc_re(s, c) + " accum" + nthFloat2(2*i + 0) + "\n"
                code += "#define " + acc_im(s, c) + " accum" + nthFloat2(2*i + 1) + "\n"
    code += "#else\n"
    code += "#define spinorFloat float\n"
    if sharedDslash:
        code += "#define WRITE_SPINOR_SHARED WRITE_SPINOR_SHARED_FLOAT4\n"
        code += "#define READ_SPINOR_SHARED READ_SPINOR_SHARED_FLOAT4\n"
    for s in range(0, 4):
        for c in range(0, 3):
            i = 3*s + c
            code += "#define " + in_re(s, c) + " I" + nthFloat4(2*i + 0) + "\n"
            code += "#define " + in_im(s, c) + " I" + nthFloat4(2*i + 1) + "\n"
    if dslash and not pack:
        for s in range(0, 4):
            for c in range(0, 3):
                i = 3*s + c
                code += "#define " + acc_re(s, c) + " accum" + nthFloat4(2*i + 0) + "\n"
                code += "#define " + acc_im(s, c) + " accum" + nthFloat4(2*i + 1) + "\n"
    code += "#endif // SPINOR_DOUBLE\n\n"
    return code
# end def def_input_spinor
179 
180 
def def_gauge():
    """Emit #defines mapping gauge-link components onto the G register array
    (float2 and float4 layouts), plus defines expressing the conjugated link
    in terms of the plain one."""
    code = "// gauge link\n"
    code += "#ifdef GAUGE_FLOAT2\n"
    for m in range(0, 3):
        for n in range(0, 3):
            i = 3*m + n
            code += "#define " + g_re(0, m, n) + " G" + nthFloat2(2*i + 0) + "\n"
            code += "#define " + g_im(0, m, n) + " G" + nthFloat2(2*i + 1) + "\n"
    code += "\n"
    code += "#else\n"
    for m in range(0, 3):
        for n in range(0, 3):
            i = 3*m + n
            code += "#define " + g_re(0, m, n) + " G" + nthFloat4(2*i + 0) + "\n"
            code += "#define " + g_im(0, m, n) + " G" + nthFloat4(2*i + 1) + "\n"
    code += "\n"
    code += "#endif // GAUGE_DOUBLE\n\n"

    # conjugate transpose: gTmn = conj(g nm)
    code += "// conjugated gauge link\n"
    for m in range(0, 3):
        for n in range(0, 3):
            code += "#define " + g_re(1, m, n) + " (+" + g_re(0, n, m) + ")\n"
            code += "#define " + g_im(1, m, n) + " (-" + g_im(0, n, m) + ")\n"
    code += "\n"

    return code
# end def def_gauge
211 
212 
def def_output_spinor():
    """Declare the output spinor: the first sharedFloats components live in
    shared memory (unless the shared-dslash variant stores the *input* there),
    the remainder in VOLATILE registers.

    NOTE(review): the ``def`` line (original line 213) was dropped by the
    doxygen extraction; the name is reconstructed from the call site in
    prolog(). Backtick-repr replaced with repr().
    """
    # sharedDslash = True: input spinors stored in shared memory
    # sharedDslash = False: output spinors stored in shared memory
    code = "// output spinor\n"
    for s in range(0, 4):
        for c in range(0, 3):
            i = 3*s + c
            if 2*i < sharedFloats and not sharedDslash:
                code += "#define " + out_re(s, c) + " s[" + repr(2*i + 0) + "*SHARED_STRIDE]\n"
            else:
                code += "VOLATILE spinorFloat " + out_re(s, c) + ";\n"
            if 2*i + 1 < sharedFloats and not sharedDslash:
                code += "#define " + out_im(s, c) + " s[" + repr(2*i + 1) + "*SHARED_STRIDE]\n"
            else:
                code += "VOLATILE spinorFloat " + out_im(s, c) + ";\n"
    return code
# end def def_output_spinor
230 
231 
def prolog():
    """Generate the kernel preamble: register #defines, shared-memory set-up,
    thread-coordinate computation and, for MULTI_GPU builds, the interior /
    exterior kernel split.

    Fixes relative to the extracted original:
      * the generated Y/Z boundary check used the undeclared identifier
        ``cpprd`` instead of ``coord`` (would not compile as CUDA C);
      * the non-dslash error branch referenced ``exit`` without calling it
        (a no-op expression) - now sys.exit(1).
    """
    global arch

    if dslash:
        prolog_str = ("// *** CUDA DSLASH ***\n\n" if not dagger else "// *** CUDA DSLASH DAGGER ***\n\n")
        prolog_str += "#define DSLASH_SHARED_FLOATS_PER_THREAD " + str(sharedFloats) + "\n\n"
    else:
        print("Undefined prolog")
        sys.exit(1)  # was a bare ``exit`` expression, which exited nothing

    prolog_str += (
"""
#if ((CUDA_VERSION >= 4010) && (__COMPUTE_CAPABILITY__ >= 200)) // NVVM compiler
#define VOLATILE
#else // Open64 compiler
#define VOLATILE volatile
#endif
""")

    prolog_str += def_input_spinor()
    if dslash == True: prolog_str += def_gauge()
    prolog_str += def_output_spinor()

    if (sharedFloats > 0):
        if (arch >= 200):
            prolog_str += (
"""
#ifdef SPINOR_DOUBLE
#define SHARED_STRIDE 16 // to avoid bank conflicts on Fermi
#else
#define SHARED_STRIDE 32 // to avoid bank conflicts on Fermi
#endif
""")
        else:
            prolog_str += (
"""
#ifdef SPINOR_DOUBLE
#define SHARED_STRIDE 8 // to avoid bank conflicts on G80 and GT200
#else
#define SHARED_STRIDE 16 // to avoid bank conflicts on G80 and GT200
#endif
""")

    # set the pointer if using shared memory for pseudo registers
    if sharedFloats > 0 and not sharedDslash:
        prolog_str += (
"""
extern __shared__ char s_data[];
""")

        if dslash:
            prolog_str += (
"""
VOLATILE spinorFloat *s = (spinorFloat*)s_data + DSLASH_SHARED_FLOATS_PER_THREAD*SHARED_STRIDE*(threadIdx.x/SHARED_STRIDE)
 + (threadIdx.x % SHARED_STRIDE);
""")

    if dslash:
        prolog_str += (
"""
#include "read_gauge.h"
#include "io_spinor.h"

int coord[5];
int X;

int sid;
""")

    if sharedDslash:
        prolog_str += (
"""
#ifdef MULTI_GPU
int face_idx;
if (kernel_type == INTERIOR_KERNEL) {
#endif

 // Assume even dimensions
 coordsFromIndex3D<EVEN_X>(X, coord, sid, param);

 // only need to check Y and Z dims currently since X and T set to match exactly
 if (coord[1] >= param.dc.X[1]) return;
 if (coord[2] >= param.dc.X[2]) return;

""")
    else:
        prolog_str += (
"""
#ifdef MULTI_GPU
int face_idx;
if (kernel_type == INTERIOR_KERNEL) {
#endif

 sid = blockIdx.x*blockDim.x + threadIdx.x;
 if (sid >= param.threads) return;

 // Assume even dimensions
 coordsFromIndex<4,QUDA_4D_PC,EVEN_X>(X, coord, sid, param);

""")

    # zero-initialise the output spinor accumulators (interior kernel)
    out = ""
    for s in range(0, 4):
        for c in range(0, 3):
            out += out_re(s, c) + " = 0; " + out_im(s, c) + " = 0;\n"
    prolog_str += indent(out)

    prolog_str += (
"""
#ifdef MULTI_GPU
} else { // exterior kernel

 sid = blockIdx.x*blockDim.x + threadIdx.x;
 if (sid >= param.threads) return;

 const int face_volume = (param.threads >> 1); // volume of one face
 const int face_num = (sid >= face_volume); // is this thread updating face 0 or 1
 face_idx = sid - face_num*face_volume; // index into the respective face

 // ghostOffset is scaled to include body (includes stride) and number of FloatN arrays (SPINOR_HOP)
 // face_idx not sid since faces are spin projected and share the same volume index (modulo UP/DOWN reading)
 //sp_idx = face_idx + param.ghostOffset[dim];

 coordsFromFaceIndex<4,QUDA_4D_PC,kernel_type,1>(X, sid, coord, face_idx, face_num, param);

 READ_INTERMEDIATE_SPINOR(INTERTEX, param.sp_stride, sid, sid);

""")

    # the exterior kernel starts from the interior kernel's partial result
    out = ""
    for s in range(0, 4):
        for c in range(0, 3):
            out += out_re(s, c) + " = " + in_re(s, c) + "; " + out_im(s, c) + " = " + in_im(s, c) + ";\n"
    prolog_str += indent(out)
    prolog_str += "}\n"
    prolog_str += "#endif // MULTI_GPU\n\n\n"

    return prolog_str
# end def prolog
373 
374 
def gen(dir, pack_only=False):
    """Generate the per-direction body of the twisted-mass dslash kernel.

    dir: hop index 0-7 = (X-, X+, Y-, Y+, Z-, Z+, T-, T+).
    pack_only: emit only the spinor load + projection (used by pack_face()).

    Changes relative to the extracted original: backtick-repr -> repr(),
    Python 2 integer "/" -> "//" (identical for ints), no-op "()" statements
    -> "pass", and the builtin-shadowing local ``str`` renamed to ``code``.
    """
    # dagger swaps the +/- projector within each dimension
    projIdx = dir if not dagger else dir + (1 - 2*(dir % 2))
    projStr = projectorToStr(projectors[projIdx])

    def proj(i, j):
        return projectors[projIdx][4*i + j]

    # if row(i) = (j, c), then the i'th row of the projector can be represented
    # as a multiple of the j'th row: row(i) = c row(j)
    # (returns None if neither entry vanishes - never the case for these projectors)
    def row(i):
        assert i == 2 or i == 3
        if proj(i, 0) == 0j:
            return (1, proj(i, 1))
        if proj(i, 1) == 0j:
            return (0, proj(i, 0))

    boundary = ["coord[0]==(param.dc.X[0]-1)", "coord[0]==0", "coord[1]==(param.dc.X[1]-1)", "coord[1]==0", "coord[2]==(param.dc.X[2]-1)", "coord[2]==0", "coord[3]==(param.dc.X[3]-1)", "coord[3]==0"]
    interior = ["coord[0]<(param.dc.X[0]-1)", "coord[0]>0", "coord[1]<(param.dc.X[1]-1)", "coord[1]>0", "coord[2]<(param.dc.X[2]-1)", "coord[2]>0", "coord[3]<(param.dc.X[3]-1)", "coord[3]>0"]
    dim = ["X", "Y", "Z", "T"]

    # index of neighboring site when not on boundary
    sp_idx = ["X+1", "X-1", "X+param.dc.X[0]", "X-param.dc.X[0]", "X+param.dc.X2X1", "X-param.dc.X2X1", "X+param.dc.X3X2X1", "X-param.dc.X3X2X1"]

    # index of neighboring site (across boundary)
    sp_idx_wrap = ["X-(param.dc.X[0]-1)", "X+(param.dc.X[0]-1)", "X-param.dc.X2X1mX1", "X+param.dc.X2X1mX1", "X-param.dc.X3X2X1mX2X1", "X+param.dc.X3X2X1mX2X1",
                   "X-param.dc.X4X3X2X1mX3X2X1", "X+param.dc.X4X3X2X1mX3X2X1"]

    cond = ""
    cond += "#ifdef MULTI_GPU\n"
    cond += "if ( (kernel_type == INTERIOR_KERNEL && (!param.ghostDim[" + repr(dir // 2) + "] || " + interior[dir] + ")) ||\n"
    cond += " (kernel_type == EXTERIOR_KERNEL_" + dim[dir // 2] + " && " + boundary[dir] + ") )\n"
    cond += "#endif\n"

    code = ""

    # document the projector in the generated source
    projName = "P" + repr(dir // 2) + ["-", "+"][projIdx % 2]
    code += "// Projector " + projName + "\n"
    for l in projStr.splitlines():
        code += "// " + l + "\n"
    code += "\n"

    code += "#ifdef MULTI_GPU\n"
    code += "const int sp_idx = (kernel_type == INTERIOR_KERNEL) ? (" + boundary[dir] + " ? " + sp_idx_wrap[dir] + " : " + sp_idx[dir] + ") >> 1 :\n"
    code += " face_idx + param.ghostOffset[static_cast<int>(kernel_type)][" + repr((dir + 1) % 2) + "];\n"
    code += "#if (DD_PREC==2) // half precision\n"
    code += "const int sp_norm_idx = face_idx + param.ghostNormOffset[static_cast<int>(kernel_type)][" + repr((dir + 1) % 2) + "];\n"
    code += "#endif\n"
    code += "#else\n"
    code += "const int sp_idx = (" + boundary[dir] + " ? " + sp_idx_wrap[dir] + " : " + sp_idx[dir] + ") >> 1;\n"
    code += "#endif\n"

    code += "\n"
    if dir % 2 == 0:
        code += "const int ga_idx = sid;\n"
    else:
        code += "#ifdef MULTI_GPU\n"
        code += "const int ga_idx = ((kernel_type == INTERIOR_KERNEL) ? sp_idx : param.dc.Vh+face_idx);\n"
        code += "#else\n"
        code += "const int ga_idx = sp_idx;\n"
        code += "#endif\n"
    code += "\n"

    # scan the projector to determine which loads are required
    row_cnt = [0, 0, 0, 0]
    for h in range(0, 4):
        for s in range(0, 4):
            re = proj(h, s).real
            im = proj(h, s).imag
            if re != 0 or im != 0:
                row_cnt[h] += 1
    row_cnt[0] += row_cnt[1]  # non-zeros in the upper half
    row_cnt[2] += row_cnt[3]  # non-zeros in the lower half

    decl_half = ""
    for h in range(0, 2):
        for c in range(0, 3):
            decl_half += "spinorFloat " + h1_re(h, c) + ", " + h1_im(h, c) + ";\n"
    decl_half += "\n"

    load_spinor = "// read spinor from device memory\n"

    load_spinor += "#ifdef TWIST_INV_DSLASH\n"
    load_spinor += "#ifdef SPINOR_DOUBLE\n"
    load_spinor += "const spinorFloat a = param.a;\n"
    load_spinor += "const spinorFloat b = param.b;\n"
    load_spinor += "#else\n"
    load_spinor += "const spinorFloat a = param.a_f;\n"
    load_spinor += "const spinorFloat b = param.b_f;\n"
    load_spinor += "#endif\n"
    load_spinor += "#endif\n"

    if row_cnt[0] == 0:  # upper half empty: only the lower half-spinor is read
        if not pack_only:
            load_spinor += "#ifndef TWIST_INV_DSLASH\n"
            load_spinor += "READ_SPINOR_DOWN(SPINORTEX, param.sp_stride, sp_idx, sp_idx);\n"
            load_spinor += "#else\n"
        load_spinor += "READ_SPINOR(SPINORTEX, param.sp_stride, sp_idx, sp_idx);\n"
        if not dagger:
            load_spinor += "APPLY_TWIST_INV( a, b, i);\n"
        else:
            load_spinor += "APPLY_TWIST_INV(-a, b, i);\n"
        if not pack_only:
            load_spinor += "#endif\n"
    elif row_cnt[2] == 0:  # lower half empty: only the upper half-spinor is read
        if not pack_only:
            load_spinor += "#ifndef TWIST_INV_DSLASH\n"
            load_spinor += "READ_SPINOR_UP(SPINORTEX, param.sp_stride, sp_idx, sp_idx);\n"
            load_spinor += "#else\n"
        load_spinor += "READ_SPINOR(SPINORTEX, param.sp_stride, sp_idx, sp_idx);\n"
        if not dagger:
            load_spinor += "APPLY_TWIST_INV( a, b, i);\n"
        else:
            load_spinor += "APPLY_TWIST_INV(-a, b, i);\n"
        if not pack_only:
            load_spinor += "#endif\n"
    else:
        load_spinor += "READ_SPINOR(SPINORTEX, param.sp_stride, sp_idx, sp_idx);\n"
        if not pack_only:
            load_spinor += "#ifdef TWIST_INV_DSLASH\n"
        if not dagger:
            load_spinor += "APPLY_TWIST_INV( a, b, i);\n"
        else:
            load_spinor += "APPLY_TWIST_INV(-a, b, i);\n"
        if not pack_only:
            load_spinor += "#endif\n"
    load_spinor += "\n"

    load_half = ""
    load_half += "const int sp_stride_pad = param.dc.ghostFace[static_cast<int>(kernel_type)];\n"
    #load_half += "#if (DD_PREC==2) // half precision\n"
    #load_half += "const int sp_norm_idx = sid + param.ghostNormOffset[static_cast<int>(kernel_type)];\n"
    #load_half += "#endif\n"

    if dir >= 6: load_half += "const int t_proj_scale = TPROJSCALE;\n"
    load_half += "\n"
    load_half += "// read half spinor from device memory\n"

    # we have to use the same volume index for backwards and forwards gathers
    # instead of using READ_UP_SPINOR and READ_DOWN_SPINOR, just use READ_HALF_SPINOR with the appropriate shift
    load_half += "READ_SPINOR_GHOST(GHOSTSPINORTEX, sp_stride_pad, sp_idx, sp_norm_idx, " + repr(dir) + ");\n\n"
    # if (dir+1) % 2 == 0: load_half += "READ_HALF_SPINOR(SPINORTEX, sp_stride_pad, sp_idx, sp_norm_idx);\n\n"
    # else: load_half += "READ_HALF_SPINOR(SPINORTEX, sp_stride_pad, sp_idx + (SPINOR_HOP/2)*sp_stride_pad, sp_norm_idx);\n\n"

    load_gauge = "// read gauge matrix from device memory\n"
    load_gauge += "READ_GAUGE_MATRIX(G, GAUGE" + repr(dir % 2) + "TEX, " + repr(dir) + ", ga_idx, param.gauge_stride);\n\n"

    reconstruct_gauge = "// reconstruct gauge matrix\n"
    reconstruct_gauge += "RECONSTRUCT_GAUGE_MATRIX(" + repr(dir) + ");\n\n"

    project = "// project spinor into half spinors\n"
    for h in range(0, 2):
        for c in range(0, 3):
            strRe = ""
            strIm = ""
            for s in range(0, 4):
                re = proj(h, s).real
                im = proj(h, s).imag
                if re == 0 and im == 0:
                    pass
                elif im == 0:
                    strRe += sign(re) + in_re(s, c)
                    strIm += sign(re) + in_im(s, c)
                elif re == 0:
                    strRe += sign(-im) + in_im(s, c)
                    strIm += sign(im) + in_re(s, c)
            if row_cnt[0] == 0:  # projector defined on lower half only
                for s in range(0, 4):
                    re = proj(h + 2, s).real
                    im = proj(h + 2, s).imag
                    if re == 0 and im == 0:
                        pass
                    elif im == 0:
                        strRe += sign(re) + in_re(s, c)
                        strIm += sign(re) + in_im(s, c)
                    elif re == 0:
                        strRe += sign(-im) + in_im(s, c)
                        strIm += sign(im) + in_re(s, c)

            project += h1_re(h, c) + " = " + strRe + ";\n"
            project += h1_im(h, c) + " = " + strIm + ";\n"

    write_shared = (
"""// store spinor into shared memory
WRITE_SPINOR_SHARED(threadIdx.x, threadIdx.y, threadIdx.z, i);\n
""")

    load_shared_1 = (
"""// load spinor from shared memory
int tx = (threadIdx.x > 0) ? threadIdx.x-1 : blockDim.x-1;
__syncthreads();
READ_SPINOR_SHARED(tx, threadIdx.y, threadIdx.z);\n
""")

    load_shared_2 = (
"""// load spinor from shared memory
int tx = (threadIdx.x + blockDim.x - ((coord[0]+1)&1) ) % blockDim.x;
int ty = (threadIdx.y < blockDim.y - 1) ? threadIdx.y + 1 : 0;
READ_SPINOR_SHARED(tx, ty, threadIdx.z);\n
""")

    load_shared_3 = (
"""// load spinor from shared memory
int tx = (threadIdx.x + blockDim.x - ((coord[0]+1)&1)) % blockDim.x;
int ty = (threadIdx.y > 0) ? threadIdx.y - 1 : blockDim.y - 1;
READ_SPINOR_SHARED(tx, ty, threadIdx.z);\n
""")

    load_shared_4 = (
"""// load spinor from shared memory
int tx = (threadIdx.x + blockDim.x - ((coord[0]+1)&1) ) % blockDim.x;
int tz = (threadIdx.z < blockDim.z - 1) ? threadIdx.z + 1 : 0;
READ_SPINOR_SHARED(tx, threadIdx.y, tz);\n
""")

    load_shared_5 = (
"""// load spinor from shared memory
int tx = (threadIdx.x + blockDim.x - ((coord[0]+1)&1)) % blockDim.x;
int tz = (threadIdx.z > 0) ? threadIdx.z - 1 : blockDim.z - 1;
READ_SPINOR_SHARED(tx, threadIdx.y, tz);\n
""")

    # exterior kernels copy the pre-projected half spinor (scaled for T hops)
    copy_half = ""
    for h in range(0, 2):
        for c in range(0, 3):
            copy_half += h1_re(h, c) + " = " + ("t_proj_scale*" if (dir >= 6) else "") + in_re(h, c) + "; "
            copy_half += h1_im(h, c) + " = " + ("t_proj_scale*" if (dir >= 6) else "") + in_im(h, c) + ";\n"
    copy_half += "\n"

    prep_half = ""
    prep_half += "#ifdef MULTI_GPU\n"
    prep_half += "if (kernel_type == INTERIOR_KERNEL) {\n"
    prep_half += "#endif\n"
    prep_half += "\n"

    if sharedDslash:
        # reuse spinors already staged in shared memory by neighbouring threads
        if dir == 0:
            prep_half += indent(load_spinor)
            prep_half += indent(write_shared)
            prep_half += indent(project)
        elif dir == 1:
            prep_half += indent(load_shared_1)
            prep_half += indent(project)
        elif dir == 2:
            prep_half += indent("if (threadIdx.y == blockDim.y-1 && blockDim.y < param.dc.X[1] ) {\n")
            prep_half += indent(load_spinor)
            prep_half += indent(project)
            prep_half += indent("} else {")
            prep_half += indent(load_shared_2)
            prep_half += indent(project)
            prep_half += indent("}")
        elif dir == 3:
            prep_half += indent("if (threadIdx.y == 0 && blockDim.y < param.dc.X[1]) {\n")
            prep_half += indent(load_spinor)
            prep_half += indent(project)
            prep_half += indent("} else {")
            prep_half += indent(load_shared_3)
            prep_half += indent(project)
            prep_half += indent("}")
        elif dir == 4:
            prep_half += indent("if (threadIdx.z == blockDim.z-1 && blockDim.z < X3) {\n")
            prep_half += indent(load_spinor)
            prep_half += indent(project)
            prep_half += indent("} else {")
            prep_half += indent(load_shared_4)
            prep_half += indent(project)
            prep_half += indent("}")
        elif dir == 5:
            prep_half += indent("if (threadIdx.z == 0 && blockDim.z < X3) {\n")
            prep_half += indent(load_spinor)
            prep_half += indent(project)
            prep_half += indent("} else {")
            prep_half += indent(load_shared_5)
            prep_half += indent(project)
            prep_half += indent("}")
        else:
            prep_half += indent(load_spinor)
            prep_half += indent(project)
    else:
        prep_half += indent(load_spinor)
        prep_half += indent(project)

    prep_half += "\n"
    prep_half += "#ifdef MULTI_GPU\n"
    prep_half += "} else {\n"
    prep_half += "\n"
    prep_half += indent(load_half)
    prep_half += indent(copy_half)
    prep_half += "}\n"
    prep_half += "#endif // MULTI_GPU\n"
    prep_half += "\n"

    ident = "// identity gauge matrix\n"
    for m in range(0, 3):
        for h in range(0, 2):
            ident += "spinorFloat " + h2_re(h, m) + " = " + h1_re(h, m) + "; "
            ident += "spinorFloat " + h2_im(h, m) + " = " + h1_im(h, m) + ";\n"
    ident += "\n"

    mult = ""
    for m in range(0, 3):
        mult += "// multiply row " + repr(m) + "\n"
        for h in range(0, 2):
            re = "spinorFloat " + h2_re(h, m) + " = 0;\n"
            im = "spinorFloat " + h2_im(h, m) + " = 0;\n"
            for c in range(0, 3):
                re += h2_re(h, m) + " += " + g_re(dir, m, c) + " * " + h1_re(h, c) + ";\n"
                re += h2_re(h, m) + " -= " + g_im(dir, m, c) + " * " + h1_im(h, c) + ";\n"
                im += h2_im(h, m) + " += " + g_re(dir, m, c) + " * " + h1_im(h, c) + ";\n"
                im += h2_im(h, m) + " += " + g_im(dir, m, c) + " * " + h1_re(h, c) + ";\n"
            mult += re + im
        mult += "\n"

    reconstruct = ""
    for m in range(0, 3):

        for h in range(0, 2):
            h_out = h
            if row_cnt[0] == 0:  # projector defined on lower half only
                h_out = h + 2

            reconstruct += out_re(h_out, m) + " += " + h2_re(h, m) + ";\n"
            reconstruct += out_im(h_out, m) + " += " + h2_im(h, m) + ";\n"

        # lower rows are multiples of the upper rows: row(s) = c * row(h)
        for s in range(2, 4):
            (h, c) = row(s)
            re = c.real
            im = c.imag
            if im == 0 and re == 0:
                pass
            elif im == 0:
                reconstruct += out_re(s, m) + " " + sign(re) + "= " + h2_re(h, m) + ";\n"
                reconstruct += out_im(s, m) + " " + sign(re) + "= " + h2_im(h, m) + ";\n"
            elif re == 0:
                reconstruct += out_re(s, m) + " " + sign(-im) + "= " + h2_im(h, m) + ";\n"
                reconstruct += out_im(s, m) + " " + sign(+im) + "= " + h2_re(h, m) + ";\n"

        reconstruct += "\n"

    if dir >= 6:
        # temporal links may be gauge fixed to the identity in the body
        code += "if (param.gauge_fixed && ga_idx < param.dc.X4X3X2X1hmX3X2X1h)\n"
        code += block(decl_half + prep_half + ident + reconstruct)
        code += " else "
        code += block(decl_half + prep_half + load_gauge + reconstruct_gauge + mult + reconstruct)
    else:
        code += decl_half + prep_half + load_gauge + reconstruct_gauge + mult + reconstruct

    if pack_only:
        out = load_spinor + decl_half + project
        out = out.replace("sp_idx", "idx")
        return out
    else:
        return cond + block(code) + "\n\n"
# end def gen
724 
725 
def input_spinor(s, c, z):
    """Name of the spinor component the twist/xpay stage should read:
    the dslash output register when a dslash was generated, else the input."""
    accessors = (out_re, out_im) if dslash else (in_re, in_im)
    return accessors[z](s, c)
733 
734 
def twisted_xpay():
    """Generate the twist / xpay stage applied to the accumulated dslash result.

    NOTE(review): the ``def`` line (original line 735) was dropped by the
    doxygen extraction; the name is reconstructed from the call site in
    epilog() and the trailing '# end def twisted_xpay' marker. Unused loop
    locals (i = 3*s+c) dropped; builtin-shadowing local renamed to ``code``.
    """
    code = ""
    code += "#ifndef TWIST_INV_DSLASH\n"
    code += "#ifdef SPINOR_DOUBLE\n"
    code += "const spinorFloat a = param.a;\n"
    code += "const spinorFloat b = param.b;\n"
    code += "#else\n"
    code += "const spinorFloat a = param.a_f;\n"
    code += "const spinorFloat b = param.b_f;\n"
    code += "#endif\n"
    code += "#endif\n"

    code += "#ifdef DSLASH_XPAY\n"
    code += "READ_ACCUM(ACCUMTEX, param.sp_stride)\n\n"
    code += "#ifndef TWIST_XPAY\n"
    code += "#ifndef TWIST_INV_DSLASH\n"
    code += "//perform invert twist first:\n"
    if not dagger:
        code += "APPLY_TWIST_INV( a, b, o);\n"
    else:
        code += "APPLY_TWIST_INV(-a, b, o);\n"
    code += "#endif\n"
    for s in range(0, 4):
        for c in range(0, 3):
            code += out_re(s, c) + " += " + acc_re(s, c) + ";\n"
            code += out_im(s, c) + " += " + acc_im(s, c) + ";\n"
    code += "#else\n"
    if not dagger:
        code += "APPLY_TWIST( a, acc);\n"
    else:
        code += "APPLY_TWIST(-a, acc);\n"
    code += "//warning! b is unrelated to the twisted mass parameter in this case!\n\n"
    for s in range(0, 4):
        for c in range(0, 3):
            code += out_re(s, c) + " = b*" + out_re(s, c) + "+" + acc_re(s, c) + ";\n"
            code += out_im(s, c) + " = b*" + out_im(s, c) + "+" + acc_im(s, c) + ";\n"
    code += "#endif//TWIST_XPAY\n"
    code += "#else //no XPAY\n"
    code += "#ifndef TWIST_INV_DSLASH\n"
    if not dagger:
        code += " APPLY_TWIST_INV( a, b, o);\n"
    else:
        code += " APPLY_TWIST_INV(-a, b, o);\n"
    code += "#endif\n"
    code += "#endif\n"
    return code
# end def twisted_xpay
784 
785 
def epilog():
    """Generate the kernel tail: the multi-GPU completeness check, the
    twist/xpay stage, the final spinor store, and #undefs of every name the
    preamble defined."""
    code = ""
    if dslash:
        if twist:
            code += "#ifdef MULTI_GPU\n"
        else:
            code += "#if defined MULTI_GPU && (defined DSLASH_XPAY || defined DSLASH_CLOVER)\n"
        code += (
"""
int incomplete = 0; // Have all 8 contributions been computed for this site?

switch(kernel_type) { // intentional fall-through
case INTERIOR_KERNEL:
 incomplete = incomplete || (param.commDim[3] && (coord[3]==0 || coord[3]==(param.dc.X[3]-1)));
case EXTERIOR_KERNEL_T:
 incomplete = incomplete || (param.commDim[2] && (coord[2]==0 || coord[2]==(param.dc.X[2]-1)));
case EXTERIOR_KERNEL_Z:
 incomplete = incomplete || (param.commDim[1] && (coord[1]==0 || coord[1]==(param.dc.X[1]-1)));
case EXTERIOR_KERNEL_Y:
 incomplete = incomplete || (param.commDim[0] && (coord[0]==0 || coord[0]==(param.dc.X[0]-1)));
}

""")
        code += "if (!incomplete)\n"
        code += "#endif // MULTI_GPU\n"

        xpay_body = ""
        xpay_body += twisted_xpay()
        code += block(xpay_body)

    code += "\n\n"
    code += "// write spinor field back to device memory\n"
    code += "WRITE_SPINOR(param.sp_stride);\n\n"

    code += "// undefine to prevent warning when precision is changed\n"
    code += "#undef spinorFloat\n"
    if sharedDslash:
        code += "#undef WRITE_SPINOR_SHARED\n"
        code += "#undef READ_SPINOR_SHARED\n"
    if sharedFloats > 0: code += "#undef SHARED_STRIDE\n\n"

    if dslash:
        for m in range(0, 3):
            for n in range(0, 3):
                code += "#undef " + g_re(0, m, n) + "\n"
                code += "#undef " + g_im(0, m, n) + "\n"
        code += "\n"

    for s in range(0, 4):
        for c in range(0, 3):
            code += "#undef " + in_re(s, c) + "\n"
            code += "#undef " + in_im(s, c) + "\n"
    code += "\n"

    if dslash:
        for s in range(0, 4):
            for c in range(0, 3):
                code += "#undef " + acc_re(s, c) + "\n"
                code += "#undef " + acc_im(s, c) + "\n"
        code += "\n"

    code += "\n"

    # only components that were #defined into shared memory need undefining
    for s in range(0, 4):
        for c in range(0, 3):
            i = 3*s + c
            if 2*i < sharedFloats:
                code += "#undef " + out_re(s, c) + "\n"
            if 2*i + 1 < sharedFloats:
                code += "#undef " + out_im(s, c) + "\n"
    code += "\n"

    code += "#undef VOLATILE\n"

    return code
# end def epilog
865 
866 
def pack_face(facenum):
    """Generate the face-packing switch over the four dimensions for one face
    (facenum 0 or 1): project the spinor and write the half spinor out.

    Backtick-repr replaced with repr().
    """
    code = "\n"
    code += "switch(dim) {\n"
    for dim in range(0, 4):
        code += "case " + repr(dim) + ":\n"
        proj = gen(2*dim + facenum, pack_only=True)
        proj += "\n"
        proj += "// write half spinor back to device memory\n"
        proj += "WRITE_HALF_SPINOR(face_volume, face_idx);\n"
        code += indent(block(proj) + "\n" + "break;\n")
    code += "}\n\n"
    return code
# end def pack_face
879 # end def pack_face
880 
def generate_pack():
    """Generate the twisted-mass face-packing kernel body.

    NOTE(review): the ``def`` line (original line 881) was dropped by the
    doxygen extraction; the name is reconstructed from the call sites at the
    bottom of the file. Unused loop local (i = 3*s+c) dropped.
    """
    assert (sharedFloats == 0)
    code = ""
    code += def_input_spinor()
    code += "#include \"io_spinor.h\"\n\n"

    code += "if (face_num) "
    code += block(pack_face(1))
    code += " else "
    code += block(pack_face(0))

    code += "\n\n"
    code += "// undefine to prevent warning when precision is changed\n"
    code += "#undef spinorFloat\n"
    code += "#undef SHARED_STRIDE\n\n"

    for s in range(0, 4):
        for c in range(0, 3):
            code += "#undef " + in_re(s, c) + "\n"
            code += "#undef " + in_im(s, c) + "\n"
    code += "\n"

    return code
# end def generate_pack
906 
907 
def generate_dslash():
    """Concatenate the prolog, the eight directional hop bodies and the epilog.

    NOTE(review): the ``def`` line (original line 908) was dropped by the
    doxygen extraction; the name is reconstructed from the call sites below.
    """
    return prolog() + gen(0) + gen(1) + gen(2) + gen(3) + gen(4) + gen(5) + gen(6) + gen(7) + epilog()
910 
# generate Wilson-like Dslash kernels
def generate_dslash_kernels():
    """Write the twisted-mass dslash (and dagger) core headers for the
    architecture selected by the global ``arch``.

    NOTE(review): the ``def`` line (original line 912) was dropped by the
    doxygen extraction; the no-argument, global-using signature is
    reconstructed from the body. Python-2-only print statements replaced
    with single-argument print(...), valid in both Python 2 and 3.
    """
    global sharedFloats
    global sharedDslash
    global dslash
    global dagger
    global twist  # deg_twist
    # global ndeg_twist  # new!

    print("Generating dslash kernel for sm" + str(arch // 10))

    # per-architecture shared-memory strategy
    sharedFloats = 0
    if arch >= 200:
        sharedFloats = 24
        sharedDslash = True
        name = "fermi"
    elif arch >= 120:
        sharedFloats = 0
        sharedDslash = False
        name = "gt200"
    else:
        sharedFloats = 19
        sharedDslash = False
        name = "g80"

    print("Shared floats set to " + str(sharedFloats))

    dslash = True
    twist = True
    dagger = False
    filename = 'dslash_core/tm_dslash_' + name + '_core.h'
    print(sys.argv[0] + ": generating " + filename)
    f = open(filename, 'w')
    f.write(generate_dslash())
    f.close()

    dagger = True
    filename = 'dslash_core/tm_dslash_dagger_' + name + '_core.h'
    print(sys.argv[0] + ": generating " + filename + "\n")
    f = open(filename, 'w')
    f.write(generate_dslash())
    f.close()

    twist = False
    dslash = False
956 
957 
958 
# module-level generation state consumed by the functions above
dslash = False
dagger = False
twist = False
sharedFloats = 0
sharedDslash = False
pack = False

# generate dslash kernels
arch = 200
# NOTE(review): the doxygen extraction dropped original lines 968 and 971
# here - presumably the per-architecture kernel-generation calls made after
# each ``arch`` assignment. Restore them from the original script.

arch = 130

# generate packing kernels (face projection for multi-GPU halo exchange)
dslash = True
sharedFloats = 0
twist = False
dagger = False
pack = True
print sys.argv[0] + ": generating wilson_pack_face_core.h";
f = open('dslash_core/wilson_pack_twisted_face_core.h', 'w')
f.write(generate_pack())
f.close()

dagger = True
print sys.argv[0] + ": generating wilson_pack_face_dagger_core.h";
f = open('dslash_core/wilson_pack_twisted_face_dagger_core.h', 'w')
f.write(generate_pack())
f.close()
dslash = False
pack = False
991 
992 
# --- Doxygen cross-reference residue (not part of the script) ---
# complexify(a): complex-number helpers; indent(code): code-generation
# helpers; gen(dir, pack_only=False): per-direction kernel body generator.