QUDA  0.9.0
deg_tmc_dslash_cuda_gen.py
Go to the documentation of this file.
1 import sys
2 
3 
4 
def complexify(a):
    """Return a new list holding ``complex(x)`` for every element ``x`` of *a*."""
    return list(map(complex, a))
7 
def complexToStr(c):
    """Format the complex number *c* as a short string such as ``"1-i"`` or ``"2i"``.

    Integral components are printed without a decimal point, a unit
    imaginary part is printed as ``i``/``-i``, and zero components are
    omitted (``"0"`` when both are zero).
    """
    def fltToString(a):
        # repr() is the portable spelling of the Python-2-only backtick operator.
        return repr(int(a)) if a == int(a) else repr(a)

    def imToString(a):
        if a == 0:
            return "0i"
        if a == -1:
            return "-i"
        if a == 1:
            return "i"
        return fltToString(a)+"i"

    re, im = c.real, c.imag
    if re == 0 and im == 0:
        return "0"
    if re == 0:
        return imToString(im)
    if im == 0:
        return fltToString(re)
    # Both parts present: emit the sign explicitly and format the magnitude.
    im_str = "-"+imToString(-im) if im < 0 else "+"+imToString(im)
    return fltToString(re)+im_str
27 
28 
29 
30 
# 4x4 complex matrices stored as flat 16-element lists, four entries per row.
# NOTE: `id` shadows the builtin of the same name; it is kept because the
# `projectors` table refers to it by this name.
id = complexify([
    1, 0, 0, 0,
    0, 1, 0, 0,
    0, 0, 1, 0,
    0, 0, 0, 1
])

gamma1 = complexify([
    0, 0, 0, 1j,
    0, 0, 1j, 0,
    0, -1j, 0, 0,
    -1j, 0, 0, 0
])

gamma2 = complexify([
    0, 0, 0, 1,
    0, 0, -1, 0,
    0, -1, 0, 0,
    1, 0, 0, 0
])

gamma3 = complexify([
    0, 0, 1j, 0,
    0, 0, 0, -1j,
    -1j, 0, 0, 0,
    0, 1j, 0, 0
])

gamma4 = complexify([
    1, 0, 0, 0,
    0, 1, 0, 0,
    0, 0, -1, 0,
    0, 0, 0, -1
])

# presumably i*gamma5 (going by the name) -- not referenced by the code shown here
igamma5 = complexify([
    0, 0, 1j, 0,
    0, 0, 0, 1j,
    1j, 0, 0, 0,
    0, 1j, 0, 0
])
72 
73 
def gplus(g1, g2):
    """Return the element-wise sum of *g1* and *g2* as a new list.

    Pairs are formed with ``zip``, so the result is as long as the shorter
    input.
    """
    return [x + y for x, y in zip(g1, g2)]
76 
def gminus(g1, g2):
    """Return the element-wise difference ``g1 - g2`` as a new list.

    Pairs are formed with ``zip``, so the result is as long as the shorter
    input.
    """
    return [x - y for x, y in zip(g1, g2)]
79 
81  out = ""
82  for i in range(0, 4):
83  for j in range(0,4):
84  out += complexToStr(p[4*i+j]) + " "
85  out += "\n"
86  return out
87 
# Projector table indexed by direction 0..7: for each of the four gamma
# matrices, first (id - gamma) and then (id + gamma).
projectors = [
    gminus(id, gamma1), gplus(id, gamma1),
    gminus(id, gamma2), gplus(id, gamma2),
    gminus(id, gamma3), gplus(id, gamma3),
    gminus(id, gamma4), gplus(id, gamma4),
]
94 
95 
96 
def indent(code):
    """Return *code* with each line prefixed by one space and newline-terminated.

    Lines whose first character is ``#`` (preprocessor directives in the
    generated source) are left at column 0.
    """
    def indentline(line):
        return line if line.startswith("#") else " "+line
    return ''.join(indentline(line)+"\n" for line in code.splitlines())
100 
def block(code):
    """Return *code* wrapped in braces, with the body passed through ``indent``."""
    return "{\n" + indent(code) + "}"
103 
def sign(x):
    """Return the operator prefix for a coefficient of +-1 or +-2.

    Returns ``"+"``, ``"-"``, ``"+2*"`` or ``"-2*"``; any other value yields
    ``None``, exactly as the original if/elif chain did.
    """
    # Numeric keys compare by value, so 1.0 finds the entry for 1.
    return {1: "+", -1: "-", 2: "+2*", -2: "-2*"}.get(x)
109 
def nthFloat4(n):
    """Return ``"<n // 4>.<x|y|z|w>"``: the component that holds the n-th scalar
    when scalars are packed four per element."""
    # Floor division keeps the element index integral under true division too.
    return repr(n // 4) + "." + ["x", "y", "z", "w"][n % 4]
112 
def nthFloat2(n):
    """Return ``"<n // 2>.<x|y>"``: the component that holds the n-th scalar
    when scalars are packed two per element."""
    # Floor division keeps the element index integral under true division too.
    return repr(n // 2) + "." + ["x", "y"][n % 2]
115 
116 
# Identifier builders for the generated source.  Each returns the name of the
# real ("_re") or imaginary ("_im") part of one component; indices are
# formatted with %r, which matches the Python-2-only backtick operator the
# original used.

def in_re(s, c):
    """Name of the real part of input component (s, c)."""
    return "i%r%r_re" % (s, c)

def in_im(s, c):
    """Name of the imaginary part of input component (s, c)."""
    return "i%r%r_im" % (s, c)

def g_re(d, m, n):
    """Name of the real part of element (m, n): prefix "g" for even d, "gT" for odd d."""
    return ("g" if (d % 2 == 0) else "gT") + "%r%r_re" % (m, n)

def g_im(d, m, n):
    """Name of the imaginary part of element (m, n): prefix "g" for even d, "gT" for odd d."""
    return ("g" if (d % 2 == 0) else "gT") + "%r%r_im" % (m, n)

def out_re(s, c):
    """Name of the real part of output component (s, c)."""
    return "o%r%r_re" % (s, c)

def out_im(s, c):
    """Name of the imaginary part of output component (s, c)."""
    return "o%r%r_im" % (s, c)

def h1_re(h, c):
    """Name "a<c>_re" (h == 0) or "b<c>_re" (h == 1)."""
    return ["a", "b"][h] + "%r_re" % (c,)

def h1_im(h, c):
    """Name "a<c>_im" (h == 0) or "b<c>_im" (h == 1)."""
    return ["a", "b"][h] + "%r_im" % (c,)

def h2_re(h, c):
    """Name "A<c>_re" (h == 0) or "B<c>_re" (h == 1)."""
    return ["A", "B"][h] + "%r_re" % (c,)

def h2_im(h, c):
    """Name "A<c>_im" (h == 0) or "B<c>_im" (h == 1)."""
    return ["A", "B"][h] + "%r_im" % (c,)

def c_re(b, sm, cm, sn, cn):
    """Name of the real part of the "c" element; block b shifts both spin indices by 2*b."""
    return "c%r%r_%r%r_re" % (sm + 2*b, cm, sn + 2*b, cn)

def c_im(b, sm, cm, sn, cn):
    """Name of the imaginary part of the "c" element; block b shifts both spin indices by 2*b."""
    return "c%r%r_%r%r_im" % (sm + 2*b, cm, sn + 2*b, cn)

def cinv_re(b, sm, cm, sn, cn):
    """Name of the real part of the "cinv" element; block b shifts both spin indices by 2*b."""
    return "cinv%r%r_%r%r_re" % (sm + 2*b, cm, sn + 2*b, cn)

def cinv_im(b, sm, cm, sn, cn):
    """Name of the imaginary part of the "cinv" element; block b shifts both spin indices by 2*b."""
    return "cinv%r%r_%r%r_im" % (sm + 2*b, cm, sn + 2*b, cn)

def a_re(b, s, c):
    """Name "a<s+2b><c>_re"."""
    return "a%r%r_re" % (s + 2*b, c)

def a_im(b, s, c):
    """Name "a<s+2b><c>_im"."""
    return "a%r%r_im" % (s + 2*b, c)


def acc_re(s, c):
    """Name of the real part of accumulator component (s, c)."""
    return "acc%r%r_re" % (s, c)

def acc_im(s, c):
    """Name of the imaginary part of accumulator component (s, c)."""
    return "acc%r%r_im" % (s, c)


def tmp_re(s, c):
    """Name of the real part of temporary component (s, c)."""
    return "tmp%r%r_re" % (s, c)

def tmp_im(s, c):
    """Name of the imaginary part of temporary component (s, c)."""
    return "tmp%r%r_im" % (s, c)
139 
def spinor(name, s, c, z):
    """Return ``<name><s><c>_re`` when *z* is 0, otherwise ``<name><s><c>_im``."""
    # repr() replaces the Python-2-only backtick operator.
    suffix = "_re" if z == 0 else "_im"
    return name + repr(s) + repr(c) + suffix
143 
145  str = ""
146  str += "// input spinor\n"
147  str += "#ifdef SPINOR_DOUBLE\n"
148  str += "#define spinorFloat double\n"
149  if sharedDslash:
150  str += "#define WRITE_SPINOR_SHARED WRITE_SPINOR_SHARED_DOUBLE2\n"
151  str += "#define READ_SPINOR_SHARED READ_SPINOR_SHARED_DOUBLE2\n"
152 
153  for s in range(0,4):
154  for c in range(0,3):
155  i = 3*s+c
156  str += "#define "+in_re(s,c)+" I"+nthFloat2(2*i+0)+"\n"
157  str += "#define "+in_im(s,c)+" I"+nthFloat2(2*i+1)+"\n"
158  if dslash and not pack:
159  for s in range(0,4):
160  for c in range(0,3):
161  i = 3*s+c
162  str += "#define "+acc_re(s,c)+" accum"+nthFloat2(2*i+0)+"\n"
163  str += "#define "+acc_im(s,c)+" accum"+nthFloat2(2*i+1)+"\n"
164  str += "#else\n"
165  str += "#define spinorFloat float\n"
166  if sharedDslash:
167  str += "#define WRITE_SPINOR_SHARED WRITE_SPINOR_SHARED_FLOAT4\n"
168  str += "#define READ_SPINOR_SHARED READ_SPINOR_SHARED_FLOAT4\n"
169  for s in range(0,4):
170  for c in range(0,3):
171  i = 3*s+c
172  str += "#define "+in_re(s,c)+" I"+nthFloat4(2*i+0)+"\n"
173  str += "#define "+in_im(s,c)+" I"+nthFloat4(2*i+1)+"\n"
174  if dslash and not pack:
175  for s in range(0,4):
176  for c in range(0,3):
177  i = 3*s+c
178  str += "#define "+acc_re(s,c)+" accum"+nthFloat4(2*i+0)+"\n"
179  str += "#define "+acc_im(s,c)+" accum"+nthFloat4(2*i+1)+"\n"
180  str += "#endif // SPINOR_DOUBLE\n\n"
181  return str
182 # end def def_input_spinor
183 
184 
def def_gauge():
    """Return the ``#define`` lines that name the gauge-link components.

    The text has three parts: the g<m><n>_re/_im macros mapped onto the
    two-component ``G`` variables (under ``#ifdef GAUGE_FLOAT2``), the same
    macros mapped onto four-component ``G`` variables (``#else``), and the
    gT<m><n> macros defined as the conjugate transpose of the g macros.
    """
    def link_defs(nth):
        # One _re/_im pair per (m, n); scalars 2*i and 2*i+1 with i = 3*m + n.
        text = ""
        for m in range(3):
            for n in range(3):
                i = 3*m + n
                text += "#define "+g_re(0,m,n)+" G"+nth(2*i+0)+"\n"
                text += "#define "+g_im(0,m,n)+" G"+nth(2*i+1)+"\n"
        return text

    out = "// gauge link\n"
    out += "#ifdef GAUGE_FLOAT2\n"
    out += link_defs(nthFloat2)
    out += "\n"
    out += "#else\n"
    out += link_defs(nthFloat4)
    out += "\n"
    # The trailing label says GAUGE_DOUBLE although the guard is GAUGE_FLOAT2;
    # it is only a comment in the emitted text and is kept so output is unchanged.
    out += "#endif // GAUGE_DOUBLE\n\n"

    out += "// conjugated gauge link\n"
    for m in range(3):
        for n in range(3):
            # Swap the indices and negate the imaginary part.
            out += "#define "+g_re(1,m,n)+" (+"+g_re(0,n,m)+")\n"
            out += "#define "+g_im(1,m,n)+" (-"+g_im(0,n,m)+")\n"
    out += "\n"

    return out
# end def def_gauge
215 
def def_clover(pack_only=False):
    """Return the ``#define``/declaration text for the clover and inverse-clover terms.

    For each of the two name families (``c...`` and ``cinv...``) the text
    defines the first 6x6 block on the packed ``C`` variables (two-component
    layout under ``#ifdef CLOVER_DOUBLE``, four-component otherwise), fills
    in the remaining off-diagonal entries as conjugates, and aliases the
    second block onto the first.  Unless *pack_only* is true, declarations
    of the ``C`` variables follow; when the module-level ``dagger`` flag is
    set they are wrapped in ``#ifndef CLOVER_TWIST_INV_DSLASH``.
    """
    def packed(name_re, name_im, nth):
        # Six real diagonal entries first, then one (re, im) pair for every
        # entry with row index m greater than column index n.
        text = ""
        slot = 0
        for m in range(6):
            # divmod(m, 3) == (m // 3, m % 3); plain "/" would produce
            # floats (and names like "c0.00_...") under true division.
            s, c = divmod(m, 3)
            text += "#define "+name_re(0,s,c,s,c)+" C"+nth(slot)+"\n"
            slot += 1
        for n in range(6):
            sn, cn = divmod(n, 3)
            for m in range(n+1, 6):
                sm, cm = divmod(m, 3)
                text += "#define "+name_re(0,sm,cm,sn,cn)+" C"+nth(slot)+"\n"
                text += "#define "+name_im(0,sm,cm,sn,cn)+" C"+nth(slot+1)+"\n"
                slot += 2
        return text

    def term(label, name_re, name_im):
        text = "// first chiral block of "+label+"\n"
        text += "#ifdef CLOVER_DOUBLE\n"
        text += packed(name_re, name_im, nthFloat2)
        text += "#else\n"
        text += packed(name_re, name_im, nthFloat4)
        text += "#endif // CLOVER_DOUBLE\n\n"

        # Entries with m < n are defined as the conjugate of entry (n, m).
        for n in range(6):
            sn, cn = divmod(n, 3)
            for m in range(n):
                sm, cm = divmod(m, 3)
                text += "#define "+name_re(0,sm,cm,sn,cn)+" (+"+name_re(0,sn,cn,sm,cm)+")\n"
                text += "#define "+name_im(0,sm,cm,sn,cn)+" (-"+name_im(0,sn,cn,sm,cm)+")\n"
        text += "\n"

        text += "// second chiral block of "+label+" (reuses C0,...,C9)\n"
        for n in range(6):
            sn, cn = divmod(n, 3)
            for m in range(6):
                sm, cm = divmod(m, 3)
                text += "#define "+name_re(1,sm,cm,sn,cn)+" "+name_re(0,sm,cm,sn,cn)+"\n"
                # Diagonal entries have no imaginary macro.
                if m != n:
                    text += "#define "+name_im(1,sm,cm,sn,cn)+" "+name_im(0,sm,cm,sn,cn)+"\n"
        text += "\n\n"
        return text

    out = term("clover term", c_re, c_im)
    out += term("inverted clover term", cinv_re, cinv_im)

    guarded = dagger and not pack_only
    if guarded:
        out += "#ifndef CLOVER_TWIST_INV_DSLASH\n"
    if not pack_only:
        out += (
"""
// declare C## here and use ASSN below instead of READ
#ifdef CLOVER_DOUBLE
double2 C0;
double2 C1;
double2 C2;
double2 C3;
double2 C4;
double2 C5;
double2 C6;
double2 C7;
double2 C8;
double2 C9;
double2 C10;
double2 C11;
double2 C12;
double2 C13;
double2 C14;
double2 C15;
double2 C16;
double2 C17;
#else
float4 C0;
float4 C1;
float4 C2;
float4 C3;
float4 C4;
float4 C5;
float4 C6;
float4 C7;
float4 C8;

#if (DD_PREC==2)
float K;
#endif

#endif // CLOVER_DOUBLE
""")
    if guarded:
        out += "#endif\n\n"

    return out
# end def def_clover
376 
378 # sharedDslash = True: input spinors stored in shared memory
379 # sharedDslash = False: output spinors stored in shared memory
380  str = "// output spinor\n"
381  for s in range(0,4):
382  for c in range(0,3):
383  i = 3*s+c
384  if 2*i < sharedFloats and not sharedDslash:
385  str += "#define "+out_re(s,c)+" s["+`(2*i+0)`+"*SHARED_STRIDE]\n"
386  else:
387  str += "VOLATILE spinorFloat "+out_re(s,c)+";\n"
388  if 2*i+1 < sharedFloats and not sharedDslash:
389  str += "#define "+out_im(s,c)+" s["+`(2*i+1)`+"*SHARED_STRIDE]\n"
390  else:
391  str += "VOLATILE spinorFloat "+out_im(s,c)+";\n"
392  return str
393 # end def def_output_spinor
394 
395 
def prolog():
    """Return the text emitted ahead of the eight per-direction blocks.

    Reads the module-level flags ``dslash``, ``dagger``, ``sharedFloats``,
    ``sharedDslash`` and ``arch``.  The text contains the banner comment,
    the macro definitions produced by ``def_input_spinor``, ``def_gauge``,
    ``def_clover`` and ``def_output_spinor``, the ``SHARED_STRIDE``
    selection, the includes and index variables, and the statements that
    initialise the output components for the interior and exterior paths.

    Raises:
        SystemExit: if ``dslash`` is false.  The original printed
            "Undefined prolog", evaluated a bare ``exit`` (a no-op) and then
            failed on the unassigned ``prolog_str``.
    """
    if not dslash:
        raise SystemExit("Undefined prolog")

    prolog_str = ("// *** CUDA DSLASH ***\n\n" if not dagger else "// *** CUDA DSLASH DAGGER ***\n\n")
    prolog_str += "#define DSLASH_SHARED_FLOATS_PER_THREAD "+str(sharedFloats)+"\n\n"

    prolog_str += (
"""
#if ((CUDA_VERSION >= 4010) && (__COMPUTE_CAPABILITY__ >= 200)) // NVVM compiler
#define VOLATILE
#else // Open64 compiler
#define VOLATILE volatile
#endif
""")

    prolog_str += def_input_spinor()
    if dslash == True:
        prolog_str += def_gauge()
    prolog_str += def_clover()
    prolog_str += def_output_spinor()

    if sharedFloats > 0:
        if arch >= 200:
            prolog_str += (
"""
#ifdef SPINOR_DOUBLE
#define SHARED_STRIDE 16 // to avoid bank conflicts on Fermi
#else
#define SHARED_STRIDE 32 // to avoid bank conflicts on Fermi
#endif
""")
        else:
            prolog_str += (
"""
#ifdef SPINOR_DOUBLE
#define SHARED_STRIDE 8 // to avoid bank conflicts on G80 and GT200
#else
#define SHARED_STRIDE 16 // to avoid bank conflicts on G80 and GT200
#endif
""")

    # set the pointer if using shared memory for pseudo registers
    if sharedFloats > 0 and not sharedDslash:
        prolog_str += (
"""
extern __shared__ char s_data[];
""")

        # `s` is derived from s_data, so it is emitted only together with it.
        prolog_str += (
"""
VOLATILE spinorFloat *s = (spinorFloat*)s_data + DSLASH_SHARED_FLOATS_PER_THREAD*SHARED_STRIDE*(threadIdx.x/SHARED_STRIDE)
 + (threadIdx.x % SHARED_STRIDE);
""")

    prolog_str += (
"""
#include "read_gauge.h"
#include "io_spinor.h"
#include "read_clover.h"
#include "tmc_core.h"

int coord[5];
int X;

int sid;
""")

    if sharedDslash:
        prolog_str += (
"""
#ifdef MULTI_GPU
int face_idx;
if (kernel_type == INTERIOR_KERNEL) {
#endif

 // Assume even dimensions
 coordsFromIndex3D<EVEN_X>(X, coord, sid, param);

 // only need to check Y and Z dims currently since X and T set to match exactly
 if (coord[1] >= param.dc.X[1]) return;
 if (coord[2] >= param.dc.X[2]) return;

""")
    else:
        prolog_str += (
"""
#ifdef MULTI_GPU
int face_idx;
if (kernel_type == INTERIOR_KERNEL) {
#endif

 sid = blockIdx.x*blockDim.x + threadIdx.x;
 if (sid >= param.threads) return;

 // Assume even dimensions
 coordsFromIndex<4,QUDA_4D_PC,EVEN_X>(X, coord, sid, param);

""")

    # Interior path: every output component starts at zero.
    zero_out = ""
    for s in range(4):
        for c in range(3):
            zero_out += out_re(s,c)+" = 0; "+out_im(s,c)+" = 0;\n"
    prolog_str += indent(zero_out)

    prolog_str += (
"""
#ifdef MULTI_GPU
} else { // exterior kernel

 sid = blockIdx.x*blockDim.x + threadIdx.x;
 if (sid >= param.threads) return;

 const int face_volume = (param.threads >> 1); // volume of one face
 const int face_num = (sid >= face_volume); // is this thread updating face 0 or 1
 face_idx = sid - face_num*face_volume; // index into the respective face

 // ghostOffset is scaled to include body (includes stride) and number of FloatN arrays (SPINOR_HOP)
 // face_idx not sid since faces are spin projected and share the same volume index (modulo UP/DOWN reading)
 //sp_idx = face_idx + param.ghostOffset[dim];

 coordsFromFaceIndex<4,QUDA_4D_PC,kernel_type,1>(X, sid, coord, face_idx, face_num, param);

 READ_INTERMEDIATE_SPINOR(INTERTEX, param.sp_stride, sid, sid);

""")

    # Exterior path: every output component starts from the value just read.
    copy_in = ""
    for s in range(4):
        for c in range(3):
            copy_in += out_re(s,c)+" = "+in_re(s,c)+"; "+out_im(s,c)+" = "+in_im(s,c)+";\n"
    prolog_str += indent(copy_in)
    prolog_str += "}\n"
    prolog_str += "#endif // MULTI_GPU\n\n\n"

    return prolog_str
# end def prolog
540 
541 
def gen(dir, pack_only=False):
    """Return the braced block of generated source for direction *dir*.

    Args:
        dir: integer 0..7 used to index the per-direction tables below.
            ``dir // 2`` selects the dimension (X, Y, Z, T); even values
            read the gauge link at ``sid``, odd values at ``sp_idx``.
        pack_only: accepted for interface compatibility; not used here.

    The projector is ``projectors[dir]``, or its neighbour within the same
    +/- pair when the module-level ``dagger`` flag is set.  The emitted
    block loads (or reads from the ghost zone) the neighbouring spinor,
    projects it to two half-spinor rows, multiplies by the gauge link
    (or copies it unchanged on the gauge-fixed branch for ``dir >= 6``) and
    accumulates the result into the output components.

    NOTE(review): the nesting was rebuilt from a listing whose indentation
    had been flattened; compare against the upstream file before relying on
    it.
    """
    projIdx = dir if not dagger else dir + (1 - 2*(dir%2))
    projStr = projectorToStr(projectors[projIdx])
    # Floor division: "dir/2" is used as a list index and inside names, which
    # only works with integer division.
    axis = dir // 2

    def proj(i,j):
        return projectors[projIdx][4*i+j]

    # if row(i) = (j, c), then the i'th row of the projector can be represented
    # as a multiple of the j'th row: row(i) = c row(j)
    def row(i):
        assert i==2 or i==3
        if proj(i,0) == 0j:
            return (1, proj(i,1))
        if proj(i,1) == 0j:
            return (0, proj(i,0))

    boundary = ["coord[0]==(param.dc.X[0]-1)", "coord[0]==0", "coord[1]==(param.dc.X[1]-1)", "coord[1]==0", "coord[2]==(param.dc.X[2]-1)", "coord[2]==0", "coord[3]==(param.dc.X[3]-1)", "coord[3]==0"]
    interior = ["coord[0]<(param.dc.X[0]-1)", "coord[0]>0", "coord[1]<(param.dc.X[1]-1)", "coord[1]>0", "coord[2]<(param.dc.X[2]-1)", "coord[2]>0", "coord[3]<(param.dc.X[3]-1)", "coord[3]>0"]
    dim = ["X", "Y", "Z", "T"]

    # index of neighboring site when not on boundary
    sp_idx = ["X+1", "X-1", "X+param.dc.X[0]", "X-param.dc.X[0]", "X+param.dc.X2X1", "X-param.dc.X2X1", "X+param.dc.X3X2X1", "X-param.dc.X3X2X1"]

    # index of neighboring site (across boundary)
    sp_idx_wrap = ["X-(param.dc.X[0]-1)", "X+(param.dc.X[0]-1)", "X-param.dc.X2X1mX1", "X+param.dc.X2X1mX1", "X-param.dc.X3X2X1mX2X1", "X+param.dc.X3X2X1mX2X1",
                   "X-param.dc.X4X3X2X1mX3X2X1", "X+param.dc.X4X3X2X1mX3X2X1"]

    cond = ""
    cond += "#ifdef MULTI_GPU\n"
    cond += "if ( (kernel_type == INTERIOR_KERNEL && (!param.ghostDim["+repr(axis)+"] || "+interior[dir]+")) ||\n"
    cond += " (kernel_type == EXTERIOR_KERNEL_"+dim[axis]+" && "+boundary[dir]+") )\n"
    cond += "#endif\n"

    body = ""

    projName = "P"+repr(axis)+["-","+"][projIdx%2]
    body += "// Projector "+projName+"\n"
    for l in projStr.splitlines():
        body += "// "+l+"\n"
    body += "\n"

    body += "#ifdef MULTI_GPU\n"
    body += "const int sp_idx = (kernel_type == INTERIOR_KERNEL) ? ("+boundary[dir]+" ? "+sp_idx_wrap[dir]+" : "+sp_idx[dir]+") >> 1 :\n"
    body += " face_idx + param.ghostOffset[static_cast<int>(kernel_type)][" + repr((dir+1)%2) + "];\n"
    body += "#if (DD_PREC==2) // half precision\n"
    body += "const int sp_norm_idx = face_idx + param.ghostNormOffset[static_cast<int>(kernel_type)][" + repr((dir+1)%2) + "];\n"
    body += "#endif\n"
    body += "#else\n"
    body += "const int sp_idx = ("+boundary[dir]+" ? "+sp_idx_wrap[dir]+" : "+sp_idx[dir]+") >> 1;\n"
    body += "#endif\n"

    body += "\n"
    if dir % 2 == 0:
        body += "const int ga_idx = sid;\n"
    else:
        body += "#ifdef MULTI_GPU\n"
        body += "const int ga_idx = ((kernel_type == INTERIOR_KERNEL) ? sp_idx : param.dc.Vh+face_idx);\n"
        body += "#else\n"
        body += "const int ga_idx = sp_idx;\n"
        body += "#endif\n"
    body += "\n"

    # scan the projector to determine which loads are required
    row_cnt = [0,0,0,0]
    for h in range(4):
        for s in range(4):
            if proj(h,s).real != 0 or proj(h,s).imag != 0:
                row_cnt[h] += 1
    row_cnt[0] += row_cnt[1]   # non-zero entries in rows 0 and 1
    row_cnt[2] += row_cnt[3]   # non-zero entries in rows 2 and 3

    decl_half = ""
    for h in range(2):
        for c in range(3):
            decl_half += "spinorFloat "+h1_re(h,c)+", "+h1_im(h,c)+";\n"
    decl_half += "\n"

    load_spinor = "// read spinor from device memory\n"
    if row_cnt[0] == 0:
        load_spinor += "READ_SPINOR_DOWN(SPINORTEX, param.sp_stride, sp_idx, sp_idx);\n"
    elif row_cnt[2] == 0:
        load_spinor += "READ_SPINOR_UP(SPINORTEX, param.sp_stride, sp_idx, sp_idx);\n"
    else:
        load_spinor += "READ_SPINOR(SPINORTEX, param.sp_stride, sp_idx, sp_idx);\n"
    load_spinor += "\n"

    load_half = ""
    load_half += "const int sp_stride_pad = param.dc.ghostFace[static_cast<int>(kernel_type)];\n"
    if dir >= 6:
        load_half += "const int t_proj_scale = TPROJSCALE;\n"
    load_half += "\n"
    load_half += "// read half spinor from device memory\n"

    # we have to use the same volume index for backwards and forwards gathers
    load_half += "READ_SPINOR_GHOST(GHOSTSPINORTEX, sp_stride_pad, sp_idx, sp_norm_idx, "+repr(dir)+");\n\n"

    load_gauge = "// read gauge matrix from device memory\n"
    load_gauge += "READ_GAUGE_MATRIX(G, GAUGE"+repr(dir%2)+"TEX, "+repr(dir)+", ga_idx, param.gauge_stride);\n\n"

    reconstruct_gauge = "// reconstruct gauge matrix\n"
    reconstruct_gauge += "RECONSTRUCT_GAUGE_MATRIX("+repr(dir)+");\n\n"

    def projected_terms(r, c):
        # Signed input terms contributed by projector row r for colour c,
        # returned as (real-part expression, imaginary-part expression).
        re_terms = ""
        im_terms = ""
        for s in range(4):
            re = proj(r,s).real
            im = proj(r,s).imag
            if re == 0 and im == 0:
                continue
            if im == 0:
                re_terms += sign(re)+in_re(s,c)
                im_terms += sign(re)+in_im(s,c)
            elif re == 0:
                # multiplying by +-i swaps the parts and flips one sign
                re_terms += sign(-im)+in_im(s,c)
                im_terms += sign(im)+in_re(s,c)
        return re_terms, im_terms

    project = "// project spinor into half spinors\n"
    for h in range(2):
        for c in range(3):
            strRe, strIm = projected_terms(h, c)
            if row_cnt[0] == 0:  # projector defined on lower half only
                lowRe, lowIm = projected_terms(h+2, c)
                strRe += lowRe
                strIm += lowIm
            project += h1_re(h,c)+" = "+strRe+";\n"
            project += h1_im(h,c)+" = "+strIm+";\n"

    write_shared = (
"""// store spinor into shared memory
WRITE_SPINOR_SHARED(threadIdx.x, threadIdx.y, threadIdx.z, i);\n
""")

    load_shared_1 = (
"""// load spinor from shared memory
int tx = (threadIdx.x > 0) ? threadIdx.x-1 : blockDim.x-1;
__syncthreads();
READ_SPINOR_SHARED(tx, threadIdx.y, threadIdx.z);\n
""")

    load_shared_2 = (
"""// load spinor from shared memory
int tx = (threadIdx.x + blockDim.x - ((coord[0]+1)&1) ) % blockDim.x;
int ty = (threadIdx.y < blockDim.y - 1) ? threadIdx.y + 1 : 0;
READ_SPINOR_SHARED(tx, ty, threadIdx.z);\n
""")

    load_shared_3 = (
"""// load spinor from shared memory
int tx = (threadIdx.x + blockDim.x - ((coord[0]+1)&1)) % blockDim.x;
int ty = (threadIdx.y > 0) ? threadIdx.y - 1 : blockDim.y - 1;
READ_SPINOR_SHARED(tx, ty, threadIdx.z);\n
""")

    load_shared_4 = (
"""// load spinor from shared memory
int tx = (threadIdx.x + blockDim.x - ((coord[0]+1)&1) ) % blockDim.x;
int tz = (threadIdx.z < blockDim.z - 1) ? threadIdx.z + 1 : 0;
READ_SPINOR_SHARED(tx, threadIdx.y, tz);\n
""")

    load_shared_5 = (
"""// load spinor from shared memory
int tx = (threadIdx.x + blockDim.x - ((coord[0]+1)&1)) % blockDim.x;
int tz = (threadIdx.z > 0) ? threadIdx.z - 1 : blockDim.z - 1;
READ_SPINOR_SHARED(tx, threadIdx.y, tz);\n
""")

    # Exterior path: the half spinor that was read is copied (scaled by
    # t_proj_scale for dir >= 6) into the a/b variables.
    scale = "" if dir < 6 else "t_proj_scale*"
    copy_half = ""
    for h in range(2):
        for c in range(3):
            copy_half += h1_re(h,c)+" = "+scale+in_re(h,c)+"; "
            copy_half += h1_im(h,c)+" = "+scale+in_im(h,c)+";\n"
    copy_half += "\n"

    # For dir 2..5 with shared-memory input: the guard under which the spinor
    # is read from device memory, and the shared-memory read used otherwise.
    shared_guard = {
        2: ("if (threadIdx.y == blockDim.y-1 && blockDim.y < param.dc.X[1] ) {\n", load_shared_2),
        3: ("if (threadIdx.y == 0 && blockDim.y < param.dc.X[1]) {\n", load_shared_3),
        4: ("if (threadIdx.z == blockDim.z-1 && blockDim.z < X3) {\n", load_shared_4),
        5: ("if (threadIdx.z == 0 && blockDim.z < X3) {\n", load_shared_5),
    }

    prep_half = ""
    prep_half += "#ifdef MULTI_GPU\n"
    prep_half += "if (kernel_type == INTERIOR_KERNEL) {\n"
    prep_half += "#endif\n"
    prep_half += "\n"

    if sharedDslash and dir == 0:
        prep_half += indent(load_spinor)
        prep_half += indent(write_shared)
        prep_half += indent(project)
    elif sharedDslash and dir == 1:
        prep_half += indent(load_shared_1)
        prep_half += indent(project)
    elif sharedDslash and dir in shared_guard:
        guard, from_shared = shared_guard[dir]
        prep_half += indent(guard)
        prep_half += indent(load_spinor)
        prep_half += indent(project)
        prep_half += indent("} else {")
        prep_half += indent(from_shared)
        prep_half += indent(project)
        prep_half += indent("}")
    else:
        prep_half += indent(load_spinor)
        prep_half += indent(project)

    prep_half += "\n"
    prep_half += "#ifdef MULTI_GPU\n"
    prep_half += "} else {\n"
    prep_half += "\n"
    prep_half += indent(load_half)
    prep_half += indent(copy_half)
    prep_half += "}\n"
    prep_half += "#endif // MULTI_GPU\n"
    prep_half += "\n"

    ident = "// identity gauge matrix\n"
    for m in range(3):
        for h in range(2):
            ident += "spinorFloat "+h2_re(h,m)+" = " + h1_re(h,m) + "; "
            ident += "spinorFloat "+h2_im(h,m)+" = " + h1_im(h,m) + ";\n"
    ident += "\n"

    mult = ""
    for m in range(3):
        mult += "// multiply row "+repr(m)+"\n"
        for h in range(2):
            re_lines = "spinorFloat "+h2_re(h,m)+" = 0;\n"
            im_lines = "spinorFloat "+h2_im(h,m)+" = 0;\n"
            for c in range(3):
                re_lines += h2_re(h,m) + " += " + g_re(dir,m,c) + " * "+h1_re(h,c)+";\n"
                re_lines += h2_re(h,m) + " -= " + g_im(dir,m,c) + " * "+h1_im(h,c)+";\n"
                im_lines += h2_im(h,m) + " += " + g_re(dir,m,c) + " * "+h1_im(h,c)+";\n"
                im_lines += h2_im(h,m) + " += " + g_im(dir,m,c) + " * "+h1_re(h,c)+";\n"
            mult += re_lines + im_lines
        mult += "\n"

    reconstruct = ""
    for m in range(3):
        for h in range(2):
            # projector defined on lower half only -> accumulate into rows 2, 3
            h_out = h+2 if row_cnt[0] == 0 else h
            reconstruct += out_re(h_out, m) + " += " + h2_re(h,m) + ";\n"
            reconstruct += out_im(h_out, m) + " += " + h2_im(h,m) + ";\n"

        for s in range(2,4):
            (h,coef) = row(s)
            re = coef.real
            im = coef.imag
            if im == 0 and re == 0:
                continue
            if im == 0:
                reconstruct += out_re(s, m) + " " + sign(re) + "= " + h2_re(h,m) + ";\n"
                reconstruct += out_im(s, m) + " " + sign(re) + "= " + h2_im(h,m) + ";\n"
            elif re == 0:
                reconstruct += out_re(s, m) + " " + sign(-im) + "= " + h2_im(h,m) + ";\n"
                reconstruct += out_im(s, m) + " " + sign(+im) + "= " + h2_re(h,m) + ";\n"

        reconstruct += "\n"

    if dir >= 6:
        body += "if (param.gauge_fixed && ga_idx < param.dc.X4X3X2X1hmX3X2X1h)\n"
        body += block(decl_half + prep_half + ident + reconstruct)
        body += " else "
        body += block(decl_half + prep_half + load_gauge + reconstruct_gauge + mult + reconstruct)
    else:
        body += decl_half + prep_half + load_gauge + reconstruct_gauge + mult + reconstruct

    return cond + block(body)+"\n\n"
# end def gen
857 
858 
def input_spinor(s,c,z):
    """Return the name of component (s, c): real part when *z* is 0, else imaginary.

    When the module-level ``dslash`` flag is set the output-component names
    (``out_re``/``out_im``) are returned; otherwise the input-component
    names (``in_re``/``in_im``).
    """
    if dslash:
        return out_re(s,c) if z == 0 else out_im(s,c)
    return in_re(s,c) if z == 0 else in_im(s,c)
866 
867 
869  str = ""
870  if dagger:
871  str += "#if !defined(CLOVER_TWIST_INV_DSLASH)\n"
872  str += "#ifdef SPINOR_DOUBLE\n"
873  str += "spinorFloat a = param.a;\n"
874  str += "#else\n"
875  str += "spinorFloat a = param.a_f;\n"
876  str += "#endif\n"
877  if dagger:
878  str += "#endif\n"
879 
880  str += "#ifdef DSLASH_XPAY\n"
881 
882  str += "#ifdef SPINOR_DOUBLE\n"
883  str += "spinorFloat b = param.b;\n"
884  str += "#else\n"
885  str += "spinorFloat b = param.b_f;\n"
886  str += "#endif\n"
887 
888  str += "READ_ACCUM(ACCUMTEX, param.sp_stride)\n\n"
889 
890  if not dagger:
891  str += "#ifndef CLOVER_TWIST_XPAY\n"
892  str += "//perform invert twist first:\n"
893  str += "#ifndef DYNAMIC_CLOVER\n"
894  str += "APPLY_CLOVER_TWIST_INV(c, cinv, a, o);\n"
895  str += "#else\n"
896  str += "APPLY_CLOVER_TWIST_DYN_INV(c, a, o);\n"
897  str += "#endif\n"
898  for s in range(0,4):
899  for c in range(0,3):
900  i = 3*s+c
901  str += out_re(s,c) +" = b*"+out_re(s,c)+" + "+acc_re(s,c)+";\n"
902  str += out_im(s,c) +" = b*"+out_im(s,c)+" + "+acc_im(s,c)+";\n"
903  str += "#else\n"
904  str += "APPLY_CLOVER_TWIST(c, a, acc);\n"
905  for s in range(0,4):
906  for c in range(0,3):
907  i = 3*s+c
908  str += out_re(s,c) +" = b*"+out_re(s,c)+" + "+acc_re(s,c)+";\n"
909  str += out_im(s,c) +" = b*"+out_im(s,c)+" + "+acc_im(s,c)+";\n"
910  str += "#endif//CLOVER_TWIST_XPAY\n"
911  str += "#else //no XPAY\n"
912  str += "#ifndef DYNAMIC_CLOVER\n"
913  str += "APPLY_CLOVER_TWIST_INV(c, cinv, a, o);\n"
914  str += "#else\n"
915  str += "APPLY_CLOVER_TWIST_DYN_INV(c, a, o);\n"
916  str += "#endif\n"
917  str += "#endif\n"
918  else:
919  str += "#ifndef CLOVER_TWIST_INV_DSLASH\n"
920  str += "#ifndef CLOVER_TWIST_XPAY\n"
921  str += "//perform invert twist first:\n"
922  str += "#ifndef DYNAMIC_CLOVER\n"
923  str += "APPLY_CLOVER_TWIST_INV(c, cinv, -a, o);\n"
924  str += "#else\n"
925  str += "APPLY_CLOVER_TWIST_DYN_INV(c, -a, o);\n"
926  str += "#endif\n"
927  str += "#else\n"
928  str += "APPLY_CLOVER_TWIST(c, -a, acc);\n"
929  str += "#endif\n"
930  str += "#endif\n"
931  for s in range(0,4):
932  for c in range(0,3):
933  i = 3*s+c
934  str += out_re(s,c) +" = b*"+out_re(s,c)+" + "+acc_re(s,c)+";\n"
935  str += out_im(s,c) +" = b*"+out_im(s,c)+" + "+acc_im(s,c)+";\n"
936  str += "#else //no XPAY\n"
937  str += "#ifndef CLOVER_TWIST_INV_DSLASH\n"
938  str += "#ifndef DYNAMIC_CLOVER\n"
939  str += "APPLY_CLOVER_TWIST_INV(c, cinv, -a, o);\n"
940  str += "#else\n"
941  str += "APPLY_CLOVER_TWIST_DYN_INV(c, -a, o);\n"
942  str += "#endif\n"
943  str += "#endif\n"
944  str += "#endif\n"
945  return str
946 # end def clover_twisted_xpay
947 
948 
def epilog():
    """Return the text emitted after the eight per-direction blocks.

    Reads the module-level flags ``dslash``, ``twist``, ``sharedDslash`` and
    ``sharedFloats``.  The text contains the ``incomplete`` check (when
    ``dslash`` is set), the braced output of ``clover_twisted_xpay()``, the
    ``WRITE_SPINOR`` call, and ``#undef`` lines for the macros named by the
    helper functions used below.
    """
    def clover_undefs(name_re, name_im):
        # Same visiting order as the definitions: diagonal entries first,
        # then every (m, n) pair with m > n.
        text = ""
        for m in range(6):
            # divmod(m, 3) == (m // 3, m % 3); plain "/" would produce
            # floats under true division.
            s, c = divmod(m, 3)
            text += "#undef "+name_re(0,s,c,s,c)+"\n"
        for n in range(6):
            sn, cn = divmod(n, 3)
            for m in range(n+1, 6):
                sm, cm = divmod(m, 3)
                text += "#undef "+name_re(0,sm,cm,sn,cn)+"\n"
                text += "#undef "+name_im(0,sm,cm,sn,cn)+"\n"
        return text + "\n"

    out = ""
    if dslash:
        if twist:
            out += "#ifdef MULTI_GPU\n"
        else:
            out += "#if defined MULTI_GPU && (defined DSLASH_XPAY || defined DSLASH_CLOVER)\n"
        out += (
"""
int incomplete = 0; // Have all 8 contributions been computed for this site?

switch(kernel_type) { // intentional fall-through

case INTERIOR_KERNEL:
 incomplete = incomplete || (param.commDim[3] && (coord[3]==0 || coord[3]==(param.dc.X[3]-1)));
case EXTERIOR_KERNEL_T:
 incomplete = incomplete || (param.commDim[2] && (coord[2]==0 || coord[2]==(param.dc.X[2]-1)));
case EXTERIOR_KERNEL_Z:
 incomplete = incomplete || (param.commDim[1] && (coord[1]==0 || coord[1]==(param.dc.X[1]-1)));
case EXTERIOR_KERNEL_Y:
 incomplete = incomplete || (param.commDim[0] && (coord[0]==0 || coord[0]==(param.dc.X[0]-1)));
}

""")
        out += "if (!incomplete)\n"
        out += "#endif // MULTI_GPU\n"

    out += block(clover_twisted_xpay())

    out += "\n\n"
    out += "// write spinor field back to device memory\n"
    out += "WRITE_SPINOR(param.sp_stride);\n\n"

    out += "// undefine to prevent warning when precision is changed\n"
    out += "#undef spinorFloat\n"
    if sharedDslash:
        out += "#undef WRITE_SPINOR_SHARED\n"
        out += "#undef READ_SPINOR_SHARED\n"
    if sharedFloats > 0:
        out += "#undef SHARED_STRIDE\n\n"

    if dslash:
        for m in range(3):
            for n in range(3):
                out += "#undef "+g_re(0,m,n)+"\n"
                out += "#undef "+g_im(0,m,n)+"\n"
        out += "\n"

    for s in range(4):
        for c in range(3):
            out += "#undef "+in_re(s,c)+"\n"
            out += "#undef "+in_im(s,c)+"\n"
    out += "\n"

    out += clover_undefs(c_re, c_im)
    out += clover_undefs(cinv_re, cinv_im)

    if dslash:
        for s in range(4):
            for c in range(3):
                out += "#undef "+acc_re(s,c)+"\n"
                out += "#undef "+acc_im(s,c)+"\n"
        out += "\n"

    out += "\n"

    for s in range(4):
        for c in range(3):
            i = 3*s+c
            # Only the first sharedFloats scalars are candidates for #undef.
            if 2*i < sharedFloats:
                out += "#undef "+out_re(s,c)+"\n"
            if 2*i+1 < sharedFloats:
                out += "#undef "+out_im(s,c)+"\n"
    out += "\n"

    out += "#undef VOLATILE\n"

    return out
# end def epilog
1057 
1058 
1060  return prolog() + gen(0) + gen(1) + gen(2) + gen(3) + gen(4) + gen(5) + gen(6) + gen(7) + epilog()
1061 
1062 # generate Wilson-like Dslash kernels
1064  print "Generating dslash kernel for sm" + str(arch/10)
1065 
1066  global sharedFloats
1067  global sharedDslash
1068  global dslash
1069  global dagger
1070  global twist #deg_twist
1071 # global ndeg_twist #new!
1072 
1073  sharedFloats = 0
1074  if arch >= 200:
1075  sharedFloats = 24
1076  sharedDslash = True
1077  name = "fermi"
1078  elif arch >= 120:
1079  sharedFloats = 0
1080  sharedDslash = False
1081  name = "gt200"
1082  else:
1083  sharedFloats = 19
1084  sharedDslash = False
1085  name = "g80"
1086 
1087  print "Shared floats set to " + str(sharedFloats)
1088 
1089  dslash = True
1090  twist = True
1091  dagger = False
1092  filename = 'dslash_core/tmc_dslash_' + name + '_core.h'
1093  print sys.argv[0] + ": generating " + filename;
1094  f = open(filename, 'w')
1095  f.write(generate_dslash())
1096  f.close()
1097 
1098  dagger = True
1099  filename = 'dslash_core/tmc_dslash_dagger_' + name + '_core.h'
1100  print sys.argv[0] + ": generating " + filename + "\n";
1101  f = open(filename, 'w')
1102  f.write(generate_dslash())
1103  f.close()
1104 
1105  twist = False
1106  dslash = False
1107 
1108 
1109 
# Module-level flags read by the generator functions in this file.
dslash = False
dagger = False
twist = False
sharedFloats = 0
sharedDslash = False
pack = False

# generate dslash kernels
arch = 200
# NOTE(review): the listing skips one numbered line after each `arch`
# assignment; presumably a generator call -- confirm against the upstream file.

arch = 130
1123 
def indent(code)
code generation ######################################################################## ...
def cinv_re(b, sm, cm, sn, cn)
def gen(dir, pack_only=False)
Definition: gen.py:1
def complexify(a)
complex numbers ######################################################################## ...
def def_clover(pack_only=False)
def c_im(b, sm, cm, sn, cn)
def c_re(b, sm, cm, sn, cn)
def cinv_im(b, sm, cm, sn, cn)