QUDA  0.9.0
fused_exterior_deg_tmc_dslash_cuda_gen.py
Go to the documentation of this file.
1 import sys
2 
3 
4 
def complexify(a):
    """Return a new list with every element of `a` converted to `complex`."""
    return list(map(complex, a))
7 
def complexToStr(c):
    """Render a complex number as compact text.

    Examples: 0 -> "0", 1j -> "i", -1j -> "-i", 2+0j -> "2", 1-1j -> "1-i".
    Integral parts are printed without a fractional part, and a unit
    imaginary coefficient is printed as a bare "i".
    """
    def fltToString(a):
        # integral values print as ints ("2" rather than "2.0")
        return repr(int(a)) if a == int(a) else repr(a)

    def imToString(a):
        if a == 0:
            return "0i"
        if a == -1:
            return "-i"
        if a == 1:
            return "i"
        return fltToString(a) + "i"

    real_part, imag_part = c.real, c.imag
    if real_part == 0:
        return "0" if imag_part == 0 else imToString(imag_part)
    if imag_part == 0:
        return fltToString(real_part)
    if imag_part < 0:
        return fltToString(real_part) + "-" + imToString(-imag_part)
    return fltToString(real_part) + "+" + imToString(imag_part)
27 
28 
29 
30 
def _matrix4(*rows):
    """Flatten four rows of four entries into a row-major list of complex
    numbers (entry (r, c) lands at index 4*r + c)."""
    return complexify([entry for row in rows for entry in row])


# Identity, gamma1..gamma4 and igamma5 (by its name, presumably i*gamma5),
# each stored as a flattened row-major 4x4 list of complex entries.
id = _matrix4(
    (1, 0, 0, 0),
    (0, 1, 0, 0),
    (0, 0, 1, 0),
    (0, 0, 0, 1),
)

gamma1 = _matrix4(
    (0, 0, 0, 1j),
    (0, 0, 1j, 0),
    (0, -1j, 0, 0),
    (-1j, 0, 0, 0),
)

gamma2 = _matrix4(
    (0, 0, 0, 1),
    (0, 0, -1, 0),
    (0, -1, 0, 0),
    (1, 0, 0, 0),
)

gamma3 = _matrix4(
    (0, 0, 1j, 0),
    (0, 0, 0, -1j),
    (-1j, 0, 0, 0),
    (0, 1j, 0, 0),
)

gamma4 = _matrix4(
    (1, 0, 0, 0),
    (0, 1, 0, 0),
    (0, 0, -1, 0),
    (0, 0, 0, -1),
)

igamma5 = _matrix4(
    (0, 0, 1j, 0),
    (0, 0, 0, 1j),
    (1j, 0, 0, 0),
    (0, 1j, 0, 0),
)
72 
73 
def gplus(g1, g2):
    """Return the entry-wise sum g1 + g2 of two flattened matrices.

    Like zip(), stops at the shorter input.
    """
    return [left + right for left, right in zip(g1, g2)]
76 
def gminus(g1, g2):
    """Return the entry-wise difference g1 - g2 of two flattened matrices.

    Like zip(), stops at the shorter input.
    """
    return [left - right for left, right in zip(g1, g2)]
79 
def projectorToStr(p):
    """Format a flattened 4x4 matrix `p` as four text rows.

    Each entry is rendered with complexToStr() and followed by one space;
    each row ends with a newline.  (The `def` line of this function had been
    lost, leaving its body orphaned; it is restored here.)
    """
    return "".join(
        "".join(complexToStr(p[4*i + j]) + " " for j in range(4)) + "\n"
        for i in range(4)
    )
87 
# Spin projectors in the order
# [1-g1, 1+g1, 1-g2, 1+g2, 1-g3, 1+g3, 1-g4, 1+g4];
# gen() picks one of them by its projector index.
projectors = list(
    combine(id, gamma)
    for gamma in (gamma1, gamma2, gamma3, gamma4)
    for combine in (gminus, gplus)
)
94 
95 
96 
def indent(code):
    """Prefix each line of `code` with one space and terminate it with "\\n".

    Lines that begin with '#' (e.g. #ifdef/#endif in the generated C code)
    are kept in column 0.
    """
    pieces = []
    for line in code.splitlines():
        pieces.append((line if line.startswith("#") else " " + line) + "\n")
    return "".join(pieces)
100 
def block(code):
    """Wrap `code` in a C brace block, indenting the body via indent()."""
    return "".join(["{\n", indent(code), "}"])
103 
def sign(x):
    """Return the C source prefix for a coefficient of +-1 or +-2.

    +1 -> "+", -1 -> "-", +2 -> "+2*", -2 -> "-2*".  Float arguments equal to
    these values (e.g. 1.0) are accepted.

    Raises:
        ValueError: for any other coefficient.  The function used to fall
            through and return None, which only failed later with an opaque
            TypeError during string concatenation.
    """
    if x == 1:
        return "+"
    if x == -1:
        return "-"
    if x == 2:
        return "+2*"
    if x == -2:
        return "-2*"
    raise ValueError("unsupported coefficient for sign(): %r" % (x,))
109 
def nthFloat4(n):
    """Return the "<vector>.<component>" accessor for scalar n when scalars
    are grouped four per vector with components x, y, z, w.

    Example: nthFloat4(5) == "1.y".  Floor division (//) is used explicitly
    so the result is also correct on Python 3, where `/` is true division.
    """
    return repr(n // 4) + "." + "xyzw"[n % 4]
112 
def nthFloat2(n):
    """Return the "<vector>.<component>" accessor for scalar n when scalars
    are grouped two per vector with components x, y.

    Example: nthFloat2(3) == "1.y".  Floor division (//) is used explicitly
    so the result is also correct on Python 3, where `/` is true division.
    """
    return repr(n // 2) + "." + "xy"[n % 2]
115 
116 
# Names of generated C variables/macros.  s = spin index, c/m/n = colour
# indices, h = half-spinor index (0 -> a/A, 1 -> b/B), d = direction
# (odd directions use the conjugated link names "gT...").
def in_re(s, c):
    return "i%r%r_re" % (s, c)


def in_im(s, c):
    return "i%r%r_im" % (s, c)


def g_re(d, m, n):
    return "%s%r%r_re" % ("g" if d % 2 == 0 else "gT", m, n)


def g_im(d, m, n):
    return "%s%r%r_im" % ("g" if d % 2 == 0 else "gT", m, n)


def out_re(s, c):
    return "o%r%r_re" % (s, c)


def out_im(s, c):
    return "o%r%r_im" % (s, c)


def h1_re(h, c):
    return "%s%r_re" % ("ab"[h], c)


def h1_im(h, c):
    return "%s%r_im" % ("ab"[h], c)


def h2_re(h, c):
    return "%s%r_re" % ("AB"[h], c)


def h2_im(h, c):
    return "%s%r_im" % ("AB"[h], c)
def _clover_name(prefix, part, b, sm, cm, sn, cn):
    """Build "<prefix><sm+2b><cm>_<sn+2b><cn>_<part>"; b selects the chiral
    block by shifting both spin indices by 2."""
    return "%s%r%r_%r%r_%s" % (prefix, sm + 2*b, cm, sn + 2*b, cn, part)


def c_re(b, sm, cm, sn, cn):
    return _clover_name("c", "re", b, sm, cm, sn, cn)


def c_im(b, sm, cm, sn, cn):
    return _clover_name("c", "im", b, sm, cm, sn, cn)


def cinv_re(b, sm, cm, sn, cn):
    return _clover_name("cinv", "re", b, sm, cm, sn, cn)


def cinv_im(b, sm, cm, sn, cn):
    return _clover_name("cinv", "im", b, sm, cm, sn, cn)


def a_re(b, s, c):
    return "a%r%r_re" % (s + 2*b, c)


def a_im(b, s, c):
    return "a%r%r_im" % (s + 2*b, c)
133 
# Accumulator ("acc") and scratch ("tmp") register names for spin s, colour c.
def acc_re(s, c):
    return "acc%r%r_re" % (s, c)


def acc_im(s, c):
    return "acc%r%r_im" % (s, c)


def tmp_re(s, c):
    return "tmp%r%r_re" % (s, c)


def tmp_im(s, c):
    return "tmp%r%r_im" % (s, c)
139 
def spinor(name, s, c, z):
    """Return "<name><s><c>_re" when z == 0, otherwise "<name><s><c>_im"."""
    suffix = "_re" if z == 0 else "_im"
    return name + repr(s) + repr(c) + suffix
143 
def def_input_spinor():
    """Return the C block naming the input-spinor register components.

    For SPINOR_DOUBLE the components map onto I<k>.x/.y (nthFloat2), for
    single precision onto I<k>.x/.y/.z/.w (nthFloat4).  When the module
    global `sharedDslash` is set, the WRITE/READ_SPINOR_SHARED aliases are
    added; when `dslash` is set and `pack` is not, the accumulator names are
    mapped onto the `accum` registers as well.

    (The `def` line of this function had been lost; it is restored here.)
    """
    def component_defines(nth):
        # i<s><c>_re/_im -> I<...>, then acc<s><c>_re/_im -> accum<...>
        text = ""
        for s in range(4):
            for c in range(3):
                k = 2*(3*s + c)
                text += "#define " + in_re(s, c) + " I" + nth(k) + "\n"
                text += "#define " + in_im(s, c) + " I" + nth(k + 1) + "\n"
        if dslash and not pack:
            for s in range(4):
                for c in range(3):
                    k = 2*(3*s + c)
                    text += "#define " + acc_re(s, c) + " accum" + nth(k) + "\n"
                    text += "#define " + acc_im(s, c) + " accum" + nth(k + 1) + "\n"
        return text

    out = "// input spinor\n"
    out += "#ifdef SPINOR_DOUBLE\n"
    out += "#define spinorFloat double\n"
    if sharedDslash:
        out += "#define WRITE_SPINOR_SHARED WRITE_SPINOR_SHARED_DOUBLE2\n"
        out += "#define READ_SPINOR_SHARED READ_SPINOR_SHARED_DOUBLE2\n"
    out += component_defines(nthFloat2)
    out += "#else\n"
    out += "#define spinorFloat float\n"
    if sharedDslash:
        out += "#define WRITE_SPINOR_SHARED WRITE_SPINOR_SHARED_FLOAT4\n"
        out += "#define READ_SPINOR_SHARED READ_SPINOR_SHARED_FLOAT4\n"
    out += component_defines(nthFloat4)
    out += "#endif // SPINOR_DOUBLE\n\n"
    return out
# end def def_input_spinor
183 
184 
def def_gauge():
    """Return the C block naming the gauge-link components.

    g<m><n>_re/_im are mapped onto G registers (two-component layout under
    GAUGE_FLOAT2, four-component otherwise).  gT<m><n> are defined as the
    conjugate transpose: gT<m><n> = (+g<n><m>_re, -g<n><m>_im).
    """
    def link_defines(nth):
        lines = []
        for row in range(3):
            for col in range(3):
                k = 2*(3*row + col)
                lines.append("#define " + g_re(0, row, col) + " G" + nth(k) + "\n")
                lines.append("#define " + g_im(0, row, col) + " G" + nth(k + 1) + "\n")
        return "".join(lines)

    text = "// gauge link\n"
    text += "#ifdef GAUGE_FLOAT2\n"
    text += link_defines(nthFloat2)
    text += "\n"
    text += "#else\n"
    text += link_defines(nthFloat4)
    text += "\n"
    text += "#endif // GAUGE_DOUBLE\n\n"

    text += "// conjugated gauge link\n"
    for row in range(3):
        for col in range(3):
            text += "#define " + g_re(1, row, col) + " (+" + g_re(0, col, row) + ")\n"
            text += "#define " + g_im(1, row, col) + " (-" + g_im(0, col, row) + ")\n"
    text += "\n"
    return text
# end def def_gauge
215 
def def_clover(pack_only=False):
    """Return the C block naming the clover and inverted-clover components.

    For each matrix, the first chiral (6x6) block is emitted as follows:
    - The diagonal entries (real only) are mapped onto C registers, followed
      by the entries with row index m > column index n.  Two-component
      registers are used under CLOVER_DOUBLE, four-component otherwise.
    - The entries with m < n are defined as the complex conjugate of their
      mirror entry.
    - The second chiral block aliases the first.

    Unless pack_only is set, declarations of the C registers are appended.
    For the dagger kernel (module global `dagger`) they are wrapped in
    #ifndef CLOVER_TWIST_INV_DSLASH.

    Spin/colour indices are split with divmod() instead of `m/3`, which is
    floor division only on Python 2.
    """
    def chiral_blocks(name_re, name_im, title):
        def packed(nth):
            text = ""
            k = 0
            for m in range(6):
                s, c = divmod(m, 3)
                text += "#define " + name_re(0, s, c, s, c) + " C" + nth(k) + "\n"
                k += 1
            for n in range(6):
                sn, cn = divmod(n, 3)
                for m in range(n + 1, 6):
                    sm, cm = divmod(m, 3)
                    text += "#define " + name_re(0, sm, cm, sn, cn) + " C" + nth(k) + "\n"
                    text += "#define " + name_im(0, sm, cm, sn, cn) + " C" + nth(k + 1) + "\n"
                    k += 2
            return text

        text = "// first chiral block of " + title + "\n"
        text += "#ifdef CLOVER_DOUBLE\n"
        text += packed(nthFloat2)
        text += "#else\n"
        text += packed(nthFloat4)
        text += "#endif // CLOVER_DOUBLE\n\n"

        # entries with m < n: complex conjugate of the mirrored entry
        for n in range(6):
            sn, cn = divmod(n, 3)
            for m in range(n):
                sm, cm = divmod(m, 3)
                text += "#define " + name_re(0, sm, cm, sn, cn) + " (+" + name_re(0, sn, cn, sm, cm) + ")\n"
                text += "#define " + name_im(0, sm, cm, sn, cn) + " (-" + name_im(0, sn, cn, sm, cm) + ")\n"
        text += "\n"

        text += "// second chiral block of " + title + " (reuses C0,...,C9)\n"
        for n in range(6):
            sn, cn = divmod(n, 3)
            for m in range(6):
                sm, cm = divmod(m, 3)
                text += "#define " + name_re(1, sm, cm, sn, cn) + " " + name_re(0, sm, cm, sn, cn) + "\n"
                if m != n:  # diagonal entries have no imaginary part
                    text += "#define " + name_im(1, sm, cm, sn, cn) + " " + name_im(0, sm, cm, sn, cn) + "\n"
        text += "\n\n"
        return text

    out = chiral_blocks(c_re, c_im, "clover term")
    out += chiral_blocks(cinv_re, cinv_im, "inverted clover term")

    if not pack_only:
        declarations = (
            "\n"
            "// declare C## here and use ASSN below instead of READ\n"
            "#ifdef CLOVER_DOUBLE\n"
            + "".join("double2 C%d;\n" % k for k in range(18))
            + "#else\n"
            + "".join("float4 C%d;\n" % k for k in range(9))
            + "\n"
            "#if (DD_PREC==2)\n"
            "float K;\n"
            "#endif\n"
            "\n"
            "#endif // CLOVER_DOUBLE\n"
        )
        if dagger:
            out += "#ifndef CLOVER_TWIST_INV_DSLASH\n"
        out += declarations
        if dagger:
            out += "#endif\n\n"

    return out
# end def def_clover
376 
377 
def def_output_spinor():
    """Return the C declarations of the output-spinor components o<s><c>.

    When the module global `sharedDslash` is false, the first `sharedFloats`
    real components are #defined onto the shared array `s` (strided by
    SHARED_STRIDE).  All other components are declared as
    VOLATILE spinorFloat locals.

    (The `def` line of this function had been lost; it is restored here.)
    """
    # sharedDslash = True: input spinors stored in shared memory
    # sharedDslash = False: output spinors stored in shared memory
    def declare(name, k):
        if k < sharedFloats and not sharedDslash:
            return "#define " + name + " s[" + repr(k) + "*SHARED_STRIDE]\n"
        return "VOLATILE spinorFloat " + name + ";\n"

    out = "// output spinor\n"
    for s in range(4):
        for c in range(3):
            k = 2*(3*s + c)
            out += declare(out_re(s, c), k)
            out += declare(out_im(s, c), k + 1)
    return out
# end def def_output_spinor
395 
396 
def prolog():
    """Return the opening section of the exterior-dslash kernel core.

    The section contains, in order:
    - the MULTI_GPU guard and the VOLATILE / SHARED_STRIDE macros;
    - the input-spinor, gauge, clover and output-spinor name definitions;
    - the shared-memory pointer, when output spinors live in shared memory;
    - the thread-to-face coordinate set-up and the intermediate-spinor read;
    - the copy of the input registers into the output registers.

    The closing "#endif // MULTI_GPU" is emitted by epilog().

    Only the dslash kernel is supported.  If the module global `dslash` is
    false, an error is printed and the process exits with status 1.  The old
    code referenced `exit` without calling it, so generation silently carried
    on.
    """
    if not dslash:
        print("Undefined prolog")
        sys.exit(1)

    prolog_str = "#ifdef MULTI_GPU\n\n"
    if dagger:
        prolog_str += "// *** CUDA DSLASH DAGGER ***\n\n"
    else:
        prolog_str += "// *** CUDA DSLASH ***\n\n"
    prolog_str += "#define DSLASH_SHARED_FLOATS_PER_THREAD " + str(sharedFloats) + "\n\n"

    prolog_str += (
"""
#if ((CUDA_VERSION >= 4010) && (__COMPUTE_CAPABILITY__ >= 200)) // NVVM compiler
#define VOLATILE
#else // Open64 compiler
#define VOLATILE volatile
#endif
""")

    prolog_str += def_input_spinor()
    prolog_str += def_gauge()  # dslash is guaranteed true at this point
    prolog_str += def_clover()
    prolog_str += def_output_spinor()

    if sharedFloats > 0:
        if arch >= 200:
            prolog_str += (
"""
#ifdef SPINOR_DOUBLE
#define SHARED_STRIDE 16 // to avoid bank conflicts on Fermi
#else
#define SHARED_STRIDE 32 // to avoid bank conflicts on Fermi
#endif
""")
        else:
            prolog_str += (
"""
#ifdef SPINOR_DOUBLE
#define SHARED_STRIDE 8 // to avoid bank conflicts on G80 and GT200
#else
#define SHARED_STRIDE 16 // to avoid bank conflicts on G80 and GT200
#endif
""")

    # set the pointer if using shared memory for pseudo registers
    # (def_output_spinor() aliases o<s><c> onto `s` in exactly this case)
    if sharedFloats > 0 and not sharedDslash:
        prolog_str += (
"""
extern __shared__ char s_data[];
""")
        prolog_str += (
"""
VOLATILE spinorFloat *s = (spinorFloat*)s_data + DSLASH_SHARED_FLOATS_PER_THREAD*SHARED_STRIDE*(threadIdx.x/SHARED_STRIDE)
 + (threadIdx.x % SHARED_STRIDE);
""")

    prolog_str += (
"""
#include "read_gauge.h"
#include "io_spinor.h"
#include "read_clover.h"
#include "tmc_core.h"

int coord[5];
int X;

#if (DD_PREC==2) // half precision
int sp_norm_idx;
#endif // half precision

int sid;
""")

    prolog_str += (
"""
 sid = blockIdx.x*blockDim.x + threadIdx.x;
 if (sid >= param.threads) return;


 int dim = dimFromFaceIndex(sid, param); // sid is also modified

 const int face_volume = ((param.threadDimMapUpper[dim] - param.threadDimMapLower[dim]) >> 1);
 const int face_num = (sid >= face_volume); // is this thread updating face 0 or 1
 int face_idx = sid - face_num*face_volume; // index into the respective face

 switch(dim) {
 case 0:
 coordsFromFaceIndex<4,QUDA_4D_PC,0,1>(X, sid, coord, face_idx, face_num, param);
 break;
 case 1:
 coordsFromFaceIndex<4,QUDA_4D_PC,1,1>(X, sid, coord, face_idx, face_num, param);
 break;
 case 2:
 coordsFromFaceIndex<4,QUDA_4D_PC,2,1>(X, sid, coord, face_idx, face_num, param);
 break;
 case 3:
 coordsFromFaceIndex<4,QUDA_4D_PC,3,1>(X, sid, coord, face_idx, face_num, param);
 break;
 }


 bool active = false;
 for(int dir=0; dir<4; ++dir){
 active = active || isActive(dim,dir,+1,coord,param.commDim,param.dc.X);
 }
 if(!active) return;


 READ_INTERMEDIATE_SPINOR(INTERTEX, param.sp_stride, sid, sid);

""")

    # start from the interior result: o<s><c> = i<s><c>
    copy = ""
    for s in range(4):
        for c in range(3):
            copy += out_re(s, c) + " = " + in_re(s, c) + "; " + out_im(s, c) + " = " + in_im(s, c) + ";\n"
    prolog_str += indent(copy)

    return prolog_str
# end def prolog
526 
527 
def _gen_twist_inv(a, reg):
    """Return the preprocessor-guarded inverse clover-twist call.

    a   -- twist-coefficient expression, "a" or "-a"
    reg -- register-set argument passed to the macros (gen() uses "i")

    Under DYNAMIC_CLOVER, the APPLY_CLOVER_TWIST_DYN_INV macro is called
    instead, without the `cinv` argument.
    """
    return ("#ifndef DYNAMIC_CLOVER\n"
            + "APPLY_CLOVER_TWIST_INV(c, cinv, " + a + ", " + reg + ");\n"
            + "#else\n"
            + "APPLY_CLOVER_TWIST_DYN_INV(c, " + a + ", " + reg + ");\n"
            + "#endif\n")


def _gen_load_spinor(row_cnt, pack_only):
    """Return the code that reads the spinor at sp_idx for gen().

    row_cnt[0] == 0 means the first two projector rows vanish.  In that case
    the non-packing kernels read only the lower half (READ_SPINOR_DOWN), and
    when row_cnt[2] == 0 only the upper half (READ_SPINOR_UP).

    In the non-dagger kernel built with CLOVER_TWIST_INV_DSLASH, the full
    spinor is read and the inverse twist is applied instead.  The packing
    kernel always reads the full spinor and applies the inverse twist, with
    -a for the dagger operator (module global `dagger`).
    """
    full_read = "READ_SPINOR(SPINORTEX, param.sp_stride, sp_idx, sp_idx);\n"
    twist = _gen_twist_inv("-a" if dagger else "a", "i")

    code = "// read spinor from device memory\n"
    if row_cnt[0] == 0 or row_cnt[2] == 0:
        half = "DOWN" if row_cnt[0] == 0 else "UP"
        half_read = "READ_SPINOR_" + half + "(SPINORTEX, param.sp_stride, sp_idx, sp_idx);\n"
        if pack_only:
            code += full_read + twist
        elif dagger:
            code += half_read
        else:
            code += ("#ifndef CLOVER_TWIST_INV_DSLASH\n" + half_read
                     + "#else\n" + full_read + twist + "#endif\n")
    else:
        code += full_read
        if pack_only:
            code += twist
        elif not dagger:
            code += "#ifdef CLOVER_TWIST_INV_DSLASH\n" + twist + "#endif\n"
    return code + "\n"


def _gen_project(proj, row_cnt):
    """Return the code projecting i<s><c> onto the half spinors a<c>, b<c>.

    proj(i, j) is the projector entry; when row_cnt[0] == 0 the non-zero
    lower rows (h+2) supply the terms for half spinor h.
    """
    def terms(h, c):
        re_terms, im_terms = "", ""
        for s in range(4):
            re, im = proj(h, s).real, proj(h, s).imag
            if re == 0 and im == 0:
                continue
            if im == 0:
                re_terms += sign(re) + in_re(s, c)
                im_terms += sign(re) + in_im(s, c)
            elif re == 0:
                re_terms += sign(-im) + in_im(s, c)
                im_terms += sign(im) + in_re(s, c)
        return re_terms, im_terms

    code = "// project spinor into half spinors\n"
    for h in range(2):
        for c in range(3):
            str_re, str_im = terms(h, c)
            if row_cnt[0] == 0:  # projector defined on lower half only
                low_re, low_im = terms(h + 2, c)
                str_re += low_re
                str_im += low_im
            code += h1_re(h, c) + " = " + str_re + ";\n"
            code += h1_im(h, c) + " = " + str_im + ";\n"
    return code


def gen(dir, pack_only=False):
    """Return the code for the ghost-zone hop in direction `dir` (0..7).

    dir // 2 is the dimension (X, Y, Z, T).  Even dir uses offset +1 and the
    upper face (coord == X-1); odd dir uses offset -1 and the lower face
    (coord == 0).  For the dagger operator (module global `dagger`) the
    opposite projector is used.

    With the default pack_only=False, returns an isActive(...)-guarded block.
    The block reads the ghost half spinor, multiplies it by the gauge link
    (or by the identity when param.gauge_fixed applies for dir >= 6) and
    accumulates the result into the output registers.

    With pack_only=True (used by pack_face()), returns the code that reads
    the local spinor, applies the inverse twist if required and projects it
    to half spinors, with "sp_idx" renamed to "idx".
    """
    projIdx = dir if not dagger else dir + (1 - 2*(dir % 2))
    mu = dir // 2  # dimension index

    def proj(i, j):
        return projectors[projIdx][4*i + j]

    # if row(i) = (j, c), then the i'th row of the projector can be
    # represented as a multiple of the j'th row: row(i) = c row(j)
    def row(i):
        assert i == 2 or i == 3
        if proj(i, 0) == 0j:
            return (1, proj(i, 1))
        if proj(i, 1) == 0j:
            return (0, proj(i, 0))

    # non-zero entries per projector row; row_cnt[0] / row_cnt[2] then hold
    # the totals for the upper / lower two rows
    row_cnt = [sum(1 for s in range(4) if proj(h, s) != 0) for h in range(4)]
    row_cnt[0] += row_cnt[1]
    row_cnt[2] += row_cnt[3]

    decl_half = ""
    for h in range(2):
        for c in range(3):
            decl_half += "spinorFloat " + h1_re(h, c) + ", " + h1_im(h, c) + ";\n"
    decl_half += "\n"

    if pack_only:
        packed = _gen_load_spinor(row_cnt, pack_only) + decl_half + _gen_project(proj, row_cnt)
        return packed.replace("sp_idx", "idx")

    boundary = ["coord[0]==(param.dc.X[0]-1)", "coord[0]==0", "coord[1]==(param.dc.X[1]-1)", "coord[1]==0",
                "coord[2]==(param.dc.X[2]-1)", "coord[2]==0", "coord[3]==(param.dc.X[3]-1)", "coord[3]==0"]
    offset = ["+1", "-1", "+1", "-1", "+1", "-1", "+1", "-1"]

    cond = ("if (isActive(dim," + repr(mu) + "," + offset[dir]
            + ",coord,param.commDim,param.dc.X) && " + boundary[dir] + " )\n")

    body = "// Projector P" + repr(mu) + ["-", "+"][projIdx % 2] + "\n"
    for line in projectorToStr(projectors[projIdx]).splitlines():
        body += "// " + line + "\n"
    body += "\n"

    body += "faceIndexFromCoords<4,1>(face_idx,coord," + repr(mu) + ",param);\n"
    body += "const int sp_idx = face_idx + param.ghostOffset[" + repr(mu) + "][" + repr(1 - dir % 2) + "];\n"
    body += "#if (DD_PREC==2)\n"
    body += " sp_norm_idx = face_idx + "
    body += "param.ghostNormOffset[" + repr(mu) + "][" + repr(1 - dir % 2) + "];\n"
    body += "#endif\n"
    body += "\n"
    if dir % 2 == 0:
        body += "const int ga_idx = sid;\n"
    else:
        body += "const int ga_idx = param.dc.Vh+face_idx;\n"
    body += "\n"

    load_half = "const int sp_stride_pad = param.dc.ghostFace[" + repr(mu) + "];\n"
    if dir >= 6:
        load_half += "const int t_proj_scale = TPROJSCALE;\n"
    load_half += "\n"
    load_half += "// read half spinor from device memory\n"
    # the same volume index is used for backward and forward gathers, so the
    # ghost read takes the direction instead of separate UP/DOWN reads
    load_half += "READ_SPINOR_GHOST(GHOSTSPINORTEX, sp_stride_pad, sp_idx, sp_norm_idx, " + repr(dir) + ");\n\n"

    load_gauge = "// read gauge matrix from device memory\n"
    load_gauge += "READ_GAUGE_MATRIX(G, GAUGE" + repr(dir % 2) + "TEX, " + repr(dir) + ", ga_idx, param.gauge_stride);\n\n"

    reconstruct_gauge = "// reconstruct gauge matrix\n"
    reconstruct_gauge += "RECONSTRUCT_GAUGE_MATRIX(" + repr(dir) + ");\n\n"

    # ghost half spinor -> a/b registers (scaled by t_proj_scale for T)
    scale = "t_proj_scale*" if dir >= 6 else ""
    copy_half = ""
    for h in range(2):
        for c in range(3):
            copy_half += h1_re(h, c) + " = " + scale + in_re(h, c) + "; "
            copy_half += h1_im(h, c) + " = " + scale + in_im(h, c) + ";\n"
    copy_half += "\n"

    prep_half = "\n" + load_half + copy_half

    ident = "// identity gauge matrix\n"
    for m in range(3):
        for h in range(2):
            ident += "spinorFloat " + h2_re(h, m) + " = " + h1_re(h, m) + "; "
            ident += "spinorFloat " + h2_im(h, m) + " = " + h1_im(h, m) + ";\n"
    ident += "\n"

    # A/B = G * a/b, one colour row at a time
    mult = ""
    for m in range(3):
        mult += "// multiply row " + repr(m) + "\n"
        for h in range(2):
            re = "spinorFloat " + h2_re(h, m) + " = 0;\n"
            im = "spinorFloat " + h2_im(h, m) + " = 0;\n"
            for c in range(3):
                re += h2_re(h, m) + " += " + g_re(dir, m, c) + " * " + h1_re(h, c) + ";\n"
                re += h2_re(h, m) + " -= " + g_im(dir, m, c) + " * " + h1_im(h, c) + ";\n"
                im += h2_im(h, m) + " += " + g_re(dir, m, c) + " * " + h1_im(h, c) + ";\n"
                im += h2_im(h, m) + " += " + g_im(dir, m, c) + " * " + h1_re(h, c) + ";\n"
            mult += re + im
        mult += "\n"

    # accumulate A/B into the output spinor, rebuilding spins 2 and 3 from
    # the projector's row relations
    reconstruct = ""
    for m in range(3):
        for h in range(2):
            h_out = h + 2 if row_cnt[0] == 0 else h  # lower-half-only projector
            reconstruct += out_re(h_out, m) + " += " + h2_re(h, m) + ";\n"
            reconstruct += out_im(h_out, m) + " += " + h2_im(h, m) + ";\n"
        for s in (2, 3):
            src, coeff = row(s)
            re, im = coeff.real, coeff.imag
            if re == 0 and im == 0:
                continue
            if im == 0:
                reconstruct += out_re(s, m) + " " + sign(re) + "= " + h2_re(src, m) + ";\n"
                reconstruct += out_im(s, m) + " " + sign(re) + "= " + h2_im(src, m) + ";\n"
            elif re == 0:
                reconstruct += out_re(s, m) + " " + sign(-im) + "= " + h2_im(src, m) + ";\n"
                reconstruct += out_im(s, m) + " " + sign(im) + "= " + h2_re(src, m) + ";\n"
        reconstruct += "\n"

    if dir >= 6:
        body += "if (param.gauge_fixed && ga_idx < param.dc.X4X3X2X1hmX3X2X1h)\n"
        body += block(decl_half + prep_half + ident + reconstruct)
        body += " else "
        body += block(decl_half + prep_half + load_gauge + reconstruct_gauge + mult + reconstruct)
    else:
        body += decl_half + prep_half + load_gauge + reconstruct_gauge + mult + reconstruct

    return cond + block(body) + "\n\n"
# end def gen
850 
851 
def input_spinor(s, c, z):
    """Return the register name of spinor component (s, c).

    When the module global `dslash` is set, this is the output register
    o<s><c>; otherwise it is the input register i<s><c>.  z == 0 selects the
    real part, anything else the imaginary part.
    """
    if dslash:
        return out_re(s, c) if z == 0 else out_im(s, c)
    return in_re(s, c) if z == 0 else in_im(s, c)
859 
860 
def clover_twisted_xpay():
    """Return the code that applies the clover twist and the optional xpay
    (o = b*o + acc under DSLASH_XPAY) to the output registers.

    For the non-dagger kernel the inverse twist uses +a.  For the dagger
    kernel (module global `dagger`) it uses -a, and all twist code, including
    the declaration of `a`, is guarded by #ifndef CLOVER_TWIST_INV_DSLASH.

    (The `def` line of this function had been lost; it is restored here.)
    """
    def twist_inv(a):
        return ("#ifndef DYNAMIC_CLOVER\n"
                + "APPLY_CLOVER_TWIST_INV(c, cinv, " + a + ", o);\n"
                + "#else\n"
                + "APPLY_CLOVER_TWIST_DYN_INV(c, " + a + ", o);\n"
                + "#endif\n")

    xpay = ""
    for s in range(4):
        for c in range(3):
            xpay += out_re(s, c) + " = b*" + out_re(s, c) + " + " + acc_re(s, c) + ";\n"
            xpay += out_im(s, c) + " = b*" + out_im(s, c) + " + " + acc_im(s, c) + ";\n"

    code = ""
    if dagger:
        code += "#if !defined(CLOVER_TWIST_INV_DSLASH)\n"
    code += "#ifdef SPINOR_DOUBLE\n"
    code += "spinorFloat a = param.a;\n"
    code += "#else\n"
    code += "spinorFloat a = param.a_f;\n"
    code += "#endif\n"
    if dagger:
        code += "#endif\n"
    code += "#ifdef DSLASH_XPAY\n"

    code += "#ifdef SPINOR_DOUBLE\n"
    code += "spinorFloat b = param.b;\n"
    code += "#else\n"
    code += "spinorFloat b = param.b_f;\n"
    code += "#endif\n"

    code += "READ_ACCUM(ACCUMTEX, param.sp_stride)\n\n"
    if not dagger:
        code += "#ifndef CLOVER_TWIST_XPAY\n"
        code += "//perform invert twist first:\n"
        code += twist_inv("a")
        code += xpay
        code += "#else\n"
        code += "APPLY_CLOVER_TWIST(c, a, acc);\n"
        code += xpay
        code += "#endif//CLOVER_TWIST_XPAY\n"
        code += "#else //no XPAY\n"
        code += twist_inv("a")
        code += "#endif\n"
    else:
        code += "#ifndef CLOVER_TWIST_INV_DSLASH\n"
        code += "#ifndef CLOVER_TWIST_XPAY\n"
        code += "//perform invert twist first:\n"
        code += twist_inv("-a")
        code += "#else\n"
        code += "APPLY_CLOVER_TWIST(c, -a, acc);\n"
        code += "#endif\n"
        code += "#endif\n"
        code += xpay
        code += "#else //no XPAY\n"
        code += "#ifndef CLOVER_TWIST_INV_DSLASH\n"
        code += twist_inv("-a")
        code += "#endif\n"
        code += "#endif\n"
    return code
# end def clover_twisted_xpay
938 
939 
def epilog():
    """Return the closing section of the exterior-dslash kernel core.

    It holds the braced twist/xpay block, the WRITE_SPINOR write-back,
    #undefs of the helper macros defined by the prolog, and the closing
    "#endif // MULTI_GPU".  Spin/colour indices are split with divmod()
    instead of `m/3`, which is floor division only on Python 2.
    """
    def undef_clover(name_re, name_im):
        # mirrors the C-register mapping emitted by def_clover()
        text = ""
        for m in range(6):
            s, c = divmod(m, 3)
            text += "#undef " + name_re(0, s, c, s, c) + "\n"
        for n in range(6):
            sn, cn = divmod(n, 3)
            for m in range(n + 1, 6):
                sm, cm = divmod(m, 3)
                text += "#undef " + name_re(0, sm, cm, sn, cn) + "\n"
                text += "#undef " + name_im(0, sm, cm, sn, cn) + "\n"
        return text + "\n"

    out = block(clover_twisted_xpay())

    out += "\n\n"
    out += "// write spinor field back to device memory\n"
    out += "WRITE_SPINOR(param.sp_stride);\n\n"

    out += "// undefine to prevent warning when precision is changed\n"
    out += "#undef spinorFloat\n"
    if sharedDslash:
        out += "#undef WRITE_SPINOR_SHARED\n"
        out += "#undef READ_SPINOR_SHARED\n"
    if sharedFloats > 0:
        out += "#undef SHARED_STRIDE\n\n"

    if dslash:
        for m in range(3):
            for n in range(3):
                out += "#undef " + g_re(0, m, n) + "\n"
                out += "#undef " + g_im(0, m, n) + "\n"
        out += "\n"

    for s in range(4):
        for c in range(3):
            out += "#undef " + in_re(s, c) + "\n"
            out += "#undef " + in_im(s, c) + "\n"
    out += "\n"

    if dslash:
        for s in range(4):
            for c in range(3):
                out += "#undef " + acc_re(s, c) + "\n"
                out += "#undef " + acc_im(s, c) + "\n"
        out += "\n"

    out += "\n"

    for s in range(4):
        for c in range(3):
            k = 2*(3*s + c)
            if k < sharedFloats:
                out += "#undef " + out_re(s, c) + "\n"
            if k + 1 < sharedFloats:
                out += "#undef " + out_im(s, c) + "\n"
    out += "\n"

    out += undef_clover(c_re, c_im)
    out += undef_clover(cinv_re, cinv_im)

    out += "#undef VOLATILE\n\n"
    out += "#endif // MULTI_GPU\n"

    return out
# end def epilog
1024 
1025 
def pack_face(facenum):
    """Return a `switch(dim)` whose case mu packs face `facenum` of
    dimension mu.

    Each case holds gen(2*mu + facenum, pack_only=True) followed by
    WRITE_HALF_SPINOR.
    """
    cases = []
    for mu in range(4):
        packed = gen(2*mu + facenum, pack_only=True)
        packed += "\n"
        packed += "// write half spinor back to device memory\n"
        packed += "WRITE_HALF_SPINOR(face_volume, face_idx);\n"
        cases.append("case " + repr(mu) + ":\n" + indent(block(packed) + "\n" + "break;\n"))
    return "\n" + "switch(dim) {\n" + "".join(cases) + "}\n\n"
# end def pack_face
1039 
def generate_pack():
    """Return the face-packing kernel core.

    Requires the module global sharedFloats == 0.  Emits, in order:
    - the input-spinor and clover definitions and the required headers;
    - the packing code for face 1 or face 0, selected by face_num;
    - #undefs of the helper macros.

    (The `def` line of this function had been lost; it is restored here.
    Spin/colour indices are split with divmod() instead of `m/3`.)
    """
    assert (sharedFloats == 0)

    def undef_clover(name_re, name_im):
        # mirrors the C-register mapping emitted by def_clover()
        text = ""
        for m in range(6):
            s, c = divmod(m, 3)
            text += "#undef " + name_re(0, s, c, s, c) + "\n"
        for n in range(6):
            sn, cn = divmod(n, 3)
            for m in range(n + 1, 6):
                sm, cm = divmod(m, 3)
                text += "#undef " + name_re(0, sm, cm, sn, cn) + "\n"
                text += "#undef " + name_im(0, sm, cm, sn, cn) + "\n"
        return text + "\n"

    out = ""
    out += def_input_spinor()
    out += def_clover(True)
    out += "#include \"io_spinor.h\"\n\n"
    out += "#include \"read_clover.h\"\n\n"
    out += "#include \"tmc_core.h\"\n\n"

    out += "if (face_num) "
    out += block(pack_face(1))
    out += " else "
    out += block(pack_face(0))

    out += "\n\n"
    out += "// undefine to prevent warning when precision is changed\n"
    out += "#undef spinorFloat\n"
    out += "#undef SHARED_STRIDE\n\n"

    for s in range(4):
        for c in range(3):
            out += "#undef " + in_re(s, c) + "\n"
            out += "#undef " + in_im(s, c) + "\n"
    out += "\n"

    out += undef_clover(c_re, c_im)
    out += undef_clover(cinv_re, cinv_im)

    return out
# end def generate_pack
1096 
1097 
def generate_dslash():
    """Return the complete exterior-dslash kernel core.

    The core is the prolog, the eight hop blocks gen(0) .. gen(7) in order,
    and the epilog.  (The `def` line of this function had been lost; it is
    restored here.)
    """
    return prolog() + "".join(gen(d) for d in range(8)) + epilog()
1100 
# generate Wilson-like Dslash kernels
def generate_dslash_kernels(arch):
    """Write the non-dagger and dagger exterior-dslash cores for `arch`.

    The configuration depends on arch:
    - arch >= 200 selects "fermi" (24 shared floats, shared dslash);
    - arch >= 120 selects "gt200" (no shared floats);
    - anything lower selects "g80" (19 shared floats).

    Each core is written to dslash_core/tmc_fused_exterior_dslash[_dagger]_
    <name>_core.h.  Updates the module globals that the generator functions
    read.

    Changes from the previous version:
    - the lost `def` line is restored;
    - arch // 10 stays integral on Python 3;
    - files are closed even if generation raises.
    """
    global sharedFloats
    global sharedDslash
    global dslash
    global dagger
    global twist  # deg_twist

    print("Generating dslash kernel for sm" + str(arch // 10))

    if arch >= 200:
        sharedFloats, sharedDslash, name = 24, True, "fermi"
    elif arch >= 120:
        sharedFloats, sharedDslash, name = 0, False, "gt200"
    else:
        sharedFloats, sharedDslash, name = 19, False, "g80"

    print("Shared floats set to " + str(sharedFloats))

    dslash = True
    twist = True
    # (dagger flag, filename tag, extra text after the progress message)
    for is_dagger, tag, trailer in ((False, "", ""), (True, "dagger_", "\n")):
        dagger = is_dagger
        filename = 'dslash_core/tmc_fused_exterior_dslash_' + tag + name + '_core.h'
        print(sys.argv[0] + ": generating " + filename + trailer)
        with open(filename, 'w') as f:
            f.write(generate_dslash())

    twist = False
    dslash = False
1146 
1147 
1148 
# Module-level generator state.  The functions above read these globals
# (prolog() reads `arch`, def_input_spinor() reads `pack`), and
# generate_dslash_kernels() reassigns them.
dslash = False
dagger = False
twist = False
sharedFloats = 0
sharedDslash = False
pack = False

# generate dslash kernels
# (the generator calls had been lost, so the script set `arch` but never
# produced any output; they are restored here)
arch = 200
generate_dslash_kernels(arch)

arch = 130
generate_dslash_kernels(arch)
Definition: gen.py:1
def indent(code)
code generation ######################################################################## ...
def complexify(a)
complex numbers ######################################################################## ...