QUDA  0.9.0
dw_dslash_cuda_gen.py
Go to the documentation of this file.
1 # -*- coding: utf-8 -*-
2 import sys
3 
4 
5 
def complexify(a):
    """Return a new list with every element of *a* converted to ``complex``."""
    return list(map(complex, a))
8 
def complexToStr(c):
    """Format the complex number *c* compactly, e.g. ``0``, ``i``, ``-2i``, ``1-i``.

    Integral parts are printed without a decimal point; a unit imaginary part
    is printed as a bare ``i``.
    """
    def fltToString(a):
        # repr() replaces the Python-2-only backtick syntax.
        if a == int(a):
            return repr(int(a))
        return repr(a)

    def imToString(a):
        if a == 0:
            return "0i"
        if a == -1:
            return "-i"
        if a == 1:
            return "i"
        return fltToString(a) + "i"

    re, im = c.real, c.imag
    if re == 0 and im == 0:
        return "0"
    if re == 0:
        return imToString(im)
    if im == 0:
        return fltToString(re)

    # Both parts present: emit an explicit sign between them.
    if im < 0:
        im_str = "-" + imToString(-im)
    else:
        im_str = "+" + imToString(im)
    return fltToString(re) + im_str
28 
29 
30 
31 
# 4x4 matrices stored as flat lists of 16 complex entries; entry (i, j) lives
# at index 4*i+j (see the proj() helpers further down).
# NOTE: the name `id` shadows the builtin but is referenced by other
# definitions in this file, so it is kept.
id = complexify([
    1, 0, 0, 0,
    0, 1, 0, 0,
    0, 0, 1, 0,
    0, 0, 0, 1,
])

gamma1 = complexify([
    0, 0, 0, 1j,
    0, 0, 1j, 0,
    0, -1j, 0, 0,
    -1j, 0, 0, 0,
])

gamma2 = complexify([
    0, 0, 0, 1,
    0, 0, -1, 0,
    0, -1, 0, 0,
    1, 0, 0, 0,
])

gamma3 = complexify([
    0, 0, 1j, 0,
    0, 0, 0, -1j,
    -1j, 0, 0, 0,
    0, 1j, 0, 0,
])

gamma4 = complexify([
    1, 0, 0, 0,
    0, 1, 0, 0,
    0, 0, -1, 0,
    0, 0, 0, -1,
])

igamma5 = complexify([
    0, 0, 1j, 0,
    0, 0, 0, 1j,
    1j, 0, 0, 0,
    0, 1j, 0, 0,
])

# Element-wise id -/+ (igamma5 / i); used by gen_dw() for the 5th-dimension hop.
two_P_L = [one - ig5 / 1j for (one, ig5) in zip(id, igamma5)]
two_P_R = [one + ig5 / 1j for (one, ig5) in zip(id, igamma5)]
76 
77 # for s1 in range(0,4) :
78 # for s2 in range (0,4): print "%8s" % two_P_L[s1*4+s2],
79 # print " ",
80 # for s2 in range (0,4): print "%8s" % two_P_R[s1*4+s2],
81 # print ""
82 
83 
def gplus(g1, g2):
    """Return the element-wise sum of two equally long flat matrices."""
    return [a + b for a, b in zip(g1, g2)]
86 
def gminus(g1, g2):
    """Return the element-wise difference ``g1 - g2`` of two flat matrices."""
    return [a - b for a, b in zip(g1, g2)]
89 
def projectorToStr(p):
    """Render the flat 4x4 matrix *p* as four text lines.

    Each entry is formatted with complexToStr() and right-aligned in a
    3-character field; every row ends with a newline.

    NOTE(review): the ``def`` line was missing from the listing; the name and
    single parameter are restored from the call in gen().
    """
    rows = []
    for i in range(4):
        cells = ['%3s' % complexToStr(p[4 * i + j]) for j in range(4)]
        rows.append("".join(cells) + "\n")
    return "".join(rows)
97 
# Spin projectors indexed by hop direction: entry 2*mu is (id - gamma_mu),
# entry 2*mu+1 is (id + gamma_mu).
projectors = [
    gminus(id, gamma1), gplus(id, gamma1),
    gminus(id, gamma2), gplus(id, gamma2),
    gminus(id, gamma3), gplus(id, gamma3),
    gminus(id, gamma4), gplus(id, gamma4),
]
104 
105 
106 
def indent(code, n=1):
    """Indent every line of *code* by *n* spaces and terminate each with a newline.

    Empty lines and lines beginning with ``#`` (preprocessor directives in the
    generated output) are left unindented.
    """
    prefix = n * " "
    result = []
    for line in code.splitlines():
        if line and not line.startswith("#"):
            line = prefix + line
        result.append(line + "\n")
    return "".join(result)
110 
def block(code):
    """Wrap *code* (indented via indent()) in a pair of braces."""
    return "".join(["{\n", indent(code), "}"])
113 
def sign(x):
    """Map a coefficient of +-1 or +-2 to its textual prefix.

    Returns ``"+"``, ``"-"``, ``"+2*"`` or ``"-2*"``; any other value yields
    ``None`` (unchanged from the original, which fell off the end).
    """
    for value, text in ((1, "+"), (-1, "-"), (2, "+2*"), (-2, "-2*")):
        if x == value:
            return text
    return None
119 
def nthFloat4(n):
    """Return ``"<k>.<xyzw>"`` addressing scalar *n* in a sequence of 4-vectors."""
    # Floor division keeps the vector index integral on Python 3 as well.
    vec, lane = divmod(n, 4)
    return repr(vec) + "." + ["x", "y", "z", "w"][lane]
122 
def nthFloat2(n):
    """Return ``"<k>.<xy>"`` addressing scalar *n* in a sequence of 2-vectors."""
    # Floor division keeps the vector index integral on Python 3 as well.
    vec, lane = divmod(n, 2)
    return repr(vec) + "." + ["x", "y"][lane]
125 
126 
# Identifier builders for the generated code.  Each returns the name of the
# real ("_re") or imaginary ("_im") part of one component.  repr() replaces
# the Python-2-only backtick syntax.

def in_re(s, c):
    """Input spinor component (s, c), real part: ``i<s><c>_re``."""
    return "i" + repr(s) + repr(c) + "_re"

def in_im(s, c):
    """Input spinor component (s, c), imaginary part: ``i<s><c>_im``."""
    return "i" + repr(s) + repr(c) + "_im"

def g_re(d, m, n):
    """Gauge link entry (m, n), real part; prefix ``g`` for even d, ``gT`` for odd d."""
    return ("g" if (d % 2 == 0) else "gT") + repr(m) + repr(n) + "_re"

def g_im(d, m, n):
    """Gauge link entry (m, n), imaginary part; prefix ``g`` for even d, ``gT`` for odd d."""
    return ("g" if (d % 2 == 0) else "gT") + repr(m) + repr(n) + "_im"

def out_re(s, c):
    """Output spinor component (s, c), real part: ``o<s><c>_re``."""
    return "o" + repr(s) + repr(c) + "_re"

def out_im(s, c):
    """Output spinor component (s, c), imaginary part: ``o<s><c>_im``."""
    return "o" + repr(s) + repr(c) + "_im"

def h1_re(h, c):
    """Half-spinor temporary ``a``/``b`` (h = 0/1), real part."""
    return ["a", "b"][h] + repr(c) + "_re"

def h1_im(h, c):
    """Half-spinor temporary ``a``/``b`` (h = 0/1), imaginary part."""
    return ["a", "b"][h] + repr(c) + "_im"

def h2_re(h, c):
    """Half-spinor temporary ``A``/``B`` (h = 0/1), real part."""
    return ["A", "B"][h] + repr(c) + "_re"

def h2_im(h, c):
    """Half-spinor temporary ``A``/``B`` (h = 0/1), imaginary part."""
    return ["A", "B"][h] + repr(c) + "_im"

def c_re(b, sm, cm, sn, cn):
    """Clover entry, real part; block *b* offsets both spin indices by 2*b."""
    return "c" + repr(sm + 2 * b) + repr(cm) + "_" + repr(sn + 2 * b) + repr(cn) + "_re"

def c_im(b, sm, cm, sn, cn):
    """Clover entry, imaginary part; block *b* offsets both spin indices by 2*b."""
    return "c" + repr(sm + 2 * b) + repr(cm) + "_" + repr(sn + 2 * b) + repr(cn) + "_im"

def a_re(b, s, c):
    """Accumulator ``a<s+2b><c>_re``."""
    return "a" + repr(s + 2 * b) + repr(c) + "_re"

def a_im(b, s, c):
    """Accumulator ``a<s+2b><c>_im``."""
    return "a" + repr(s + 2 * b) + repr(c) + "_im"


def tmp_re(s, c):
    """Temporary ``tmp<s><c>_re``."""
    return "tmp" + repr(s) + repr(c) + "_re"

def tmp_im(s, c):
    """Temporary ``tmp<s><c>_im``."""
    return "tmp" + repr(s) + repr(c) + "_im"
144 
145 
def def_input_spinor():
    """Return the ``#define`` block naming the input spinor components.

    Emits ``spinorFloat`` plus one macro per real/imaginary component, mapping
    onto ``I<k>.x/.y`` under ``SPINOR_DOUBLE`` and ``I<k>.x/.y/.z/.w`` otherwise.

    NOTE(review): the ``def`` line was missing from the listing; the name is
    restored from the call sites and the trailing "end def" comment.
    """
    def component_defines(accessor):
        # One pair of #defines per component, addressed through `accessor`.
        text = ""
        for s in range(4):
            for c in range(3):
                i = 3 * s + c
                text += "#define " + in_re(s, c) + " I" + accessor(2 * i + 0) + "\n"
                text += "#define " + in_im(s, c) + " I" + accessor(2 * i + 1) + "\n"
        return text

    out = ""
    out += "// input spinor\n"
    out += "#ifdef SPINOR_DOUBLE\n"
    out += "#define spinorFloat double\n"
    out += component_defines(nthFloat2)
    out += "#else\n"
    out += "#define spinorFloat float\n"
    out += component_defines(nthFloat4)
    out += "#endif // SPINOR_DOUBLE\n\n"
    return out
# end def def_input_spinor
166 
167 
def def_gauge():
    """Return the ``#define`` block naming gauge-link components.

    The direct link maps onto ``G<k>.x/.y`` under ``GAUGE_FLOAT2`` and onto
    ``G<k>.x/.y/.z/.w`` otherwise; the conjugated link ``gT`` is defined as the
    transposed entry with the imaginary part negated.
    """
    def link_defines(accessor):
        # One pair of #defines per matrix entry, addressed through `accessor`.
        text = ""
        for m in range(3):
            for n in range(3):
                i = 3 * m + n
                text += "#define " + g_re(0, m, n) + " G" + accessor(2 * i + 0) + "\n"
                text += "#define " + g_im(0, m, n) + " G" + accessor(2 * i + 1) + "\n"
        return text

    out = "// gauge link\n"
    out += "#ifdef GAUGE_FLOAT2\n"
    out += link_defines(nthFloat2)
    out += "\n"
    out += "#else\n"
    out += link_defines(nthFloat4)
    out += "\n"
    # NOTE(review): the trailing comment says GAUGE_DOUBLE although the guard
    # is GAUGE_FLOAT2; left as-is so the generated header is unchanged.
    out += "#endif // GAUGE_DOUBLE\n\n"

    out += "// conjugated gauge link\n"
    for m in range(3):
        for n in range(3):
            out += "#define " + g_re(1, m, n) + " (+" + g_re(0, n, m) + ")\n"
            out += "#define " + g_im(1, m, n) + " (-" + g_im(0, n, m) + ")\n"
    out += "\n"

    return out
# end def def_gauge
198 
199 
def def_clover():
    """Return the ``#define`` block naming the inverted clover term entries.

    Six real diagonal entries come first, followed by re/im pairs for the
    entries below the diagonal; entries above the diagonal are defined as the
    conjugate of their mirror image.  The second chiral block aliases the
    macros of the first.

    NOTE(review): the ``def`` line was missing from the listing; the name is
    restored from the call in prolog() and the trailing "end def" comment.
    """
    def packed_defines(accessor):
        # Diagonal (real only) then strictly-lower entries, in storage order.
        text = ""
        i = 0
        for m in range(6):
            s, c = divmod(m, 3)
            text += "#define " + c_re(0, s, c, s, c) + " C" + accessor(i) + "\n"
            i += 1
        for n in range(6):
            sn, cn = divmod(n, 3)
            for m in range(n + 1, 6):
                sm, cm = divmod(m, 3)
                text += "#define " + c_re(0, sm, cm, sn, cn) + " C" + accessor(i) + "\n"
                text += "#define " + c_im(0, sm, cm, sn, cn) + " C" + accessor(i + 1) + "\n"
                i += 2
        return text

    out = "// first chiral block of inverted clover term\n"
    out += "#ifdef CLOVER_DOUBLE\n"
    out += packed_defines(nthFloat2)
    out += "#else\n"
    out += packed_defines(nthFloat4)
    out += "#endif // CLOVER_DOUBLE\n\n"

    # Upper triangle: conjugate of the mirrored lower-triangle entry.
    for n in range(6):
        sn, cn = divmod(n, 3)
        for m in range(n):
            sm, cm = divmod(m, 3)
            out += "#define " + c_re(0, sm, cm, sn, cn) + " (+" + c_re(0, sn, cn, sm, cm) + ")\n"
            out += "#define " + c_im(0, sm, cm, sn, cn) + " (-" + c_im(0, sn, cn, sm, cm) + ")\n"
    out += "\n"

    out += "// second chiral block of inverted clover term (reuses C0,...,C9)\n"
    for n in range(6):
        sn, cn = divmod(n, 3)
        for m in range(6):
            sm, cm = divmod(m, 3)
            out += "#define " + c_re(1, sm, cm, sn, cn) + " " + c_re(0, sm, cm, sn, cn) + "\n"
            if m != n:
                # Diagonal entries have no imaginary macro.
                out += "#define " + c_im(1, sm, cm, sn, cn) + " " + c_im(0, sm, cm, sn, cn) + "\n"
    out += "\n"

    return out
# end def def_clover
259 
def def_output_spinor():
    """Return declarations for the output spinor components.

    The first ``sharedFloats`` scalars (module-level setting) become macros
    indexing the array ``s`` with stride ``SHARED_STRIDE``; the remaining ones
    are declared as ``VOLATILE spinorFloat`` variables.

    NOTE(review): the ``def`` line was missing from the listing; the name is
    restored from the call in prolog() and the trailing "end def" comment.
    """
    def declare(name, slot):
        if slot < sharedFloats:
            return "#define " + name + " s[" + repr(slot) + "*SHARED_STRIDE]\n"
        return "VOLATILE spinorFloat " + name + ";\n"

    out = "// output spinor\n"
    for s in range(4):
        for c in range(3):
            i = 3 * s + c
            out += declare(out_re(s, c), 2 * i + 0)
            out += declare(out_im(s, c), 2 * i + 1)
    return out
# end def def_output_spinor
275 
276 
def prolog():
    """Return the opening part of the generated kernel body.

    Reads the module-level switches ``dslash``, ``clover``, ``dagger``,
    ``domain_wall`` and ``sharedFloats`` and emits, in order: a banner and the
    shared-floats macro, the VOLATILE definition, the spinor/gauge/clover macro
    blocks, SHARED_STRIDE, the optional shared-memory pointer, the includes,
    and the site-index / coordinate set-up.

    Exits the process with the message "Undefined prolog" when neither
    ``dslash`` nor ``clover`` is set.  The original used a bare ``exit``, which
    is a no-op, and then failed later with an UnboundLocalError.
    """
    if dslash:
        prolog_str = ("// *** CUDA DSLASH ***\n\n" if not dagger else "// *** CUDA DSLASH DAGGER ***\n\n")
        prolog_str += "#define DSLASH_SHARED_FLOATS_PER_THREAD " + str(sharedFloats) + "\n\n"
    elif clover:
        prolog_str = "// *** CUDA CLOVER ***\n\n"
        prolog_str += "#define CLOVER_SHARED_FLOATS_PER_THREAD " + str(sharedFloats) + "\n\n"
    else:
        sys.exit("Undefined prolog")

    if domain_wall:
        prolog_str += "// NB! Don't trust any MULTI_GPU code\n"

    prolog_str += (
"""
#if (CUDA_VERSION >= 4010)
#define VOLATILE
#else
#define VOLATILE volatile
#endif
""")

    prolog_str += def_input_spinor()
    if dslash:
        prolog_str += def_gauge()
    if clover:
        prolog_str += def_clover()
    prolog_str += def_output_spinor()

    prolog_str += (
"""
#ifdef SPINOR_DOUBLE
#if (__COMPUTE_CAPABILITY__ >= 200)
#define SHARED_STRIDE 16 // to avoid bank conflicts on Fermi
#else
#define SHARED_STRIDE 8 // to avoid bank conflicts on G80 and GT200
#endif
#else
#if (__COMPUTE_CAPABILITY__ >= 200)
#define SHARED_STRIDE 32 // to avoid bank conflicts on Fermi
#else
#define SHARED_STRIDE 16 // to avoid bank conflicts on G80 and GT200
#endif
#endif
""")

    if sharedFloats > 0:
        prolog_str += (
"""
extern __shared__ char s_data[];
""")

        # The pointer `s` refers to s_data, so it is only emitted together with it.
        if dslash:
            prolog_str += (
"""
VOLATILE spinorFloat *s = (spinorFloat*)s_data + DSLASH_SHARED_FLOATS_PER_THREAD*SHARED_STRIDE*(threadIdx.x/SHARED_STRIDE)
+ (threadIdx.x % SHARED_STRIDE);
""")
        else:
            prolog_str += (
"""
VOLATILE spinorFloat *s = (spinorFloat*)s_data + CLOVER_SHARED_FLOATS_PER_THREAD*SHARED_STRIDE*(threadIdx.x/SHARED_STRIDE)
+ (threadIdx.x % SHARED_STRIDE);
""")

    if dslash:
        prolog_str += "\n#include \"read_gauge.h\"\n"
        if not domain_wall:
            prolog_str += "#include \"read_clover.h\"\n"
        prolog_str += "#include \"io_spinor.h\"\n"
        prolog_str += (
"""

int sid = ((blockIdx.y*blockDim.y + threadIdx.y)*gridDim.x + blockIdx.x)*blockDim.x + threadIdx.x;
if (sid >= param.threads*param.dc.Ls) return;

int X, coord[5];

int s_parity;

#ifdef MULTI_GPU
int face_idx;
if (kernel_type == INTERIOR_KERNEL) {
#endif

""")
        if domain_wall:
            prolog_str += (
"""
coordsFromIndex<5,QUDA_5D_PC,EVEN_X>(X, coord, sid, param);
s_parity = ( sid/param.dc.volume_4d_cb ) % 2;

""")
        else:
            # Fixed: the first divisor read "param.dc.X[0]1", which is not valid C.
            prolog_str += (
"""
X = 2*sid;
int aux1 = X / param.dc.X[0];
x1 = X - aux1 * param.dc.X[0];
int aux2 = aux1 / param.dc.X[1];
x2 = aux1 - aux2 * param.dc.X[1];
x4 = aux2 / param.dc.X[2];
x3 = aux2 - x4 * param.dc.X[2];
aux1 = (param.parity + x4 + x3 + x2) & 1;
x1 += aux1;
X += aux1;

""")

        # Interior kernel: start from a zeroed output spinor.
        zero_init = ""
        for s in range(4):
            for c in range(3):
                zero_init += out_re(s, c) + " = 0; " + out_im(s, c) + " = 0;\n"
        prolog_str += indent(zero_init)

        prolog_str += (
"""
#ifdef MULTI_GPU
} else { // exterior kernel

const int face_volume = (param.threads*param.dc.Ls >> 1); // volume of one face
const int face_num = (sid >= face_volume); // is this thread updating face 0 or 1
face_idx = sid - face_num*face_volume; // index into the respective face

// ghostOffset is scaled to include body (includes stride) and number of FloatN arrays (SPINOR_HOP)
// face_idx not sid since faces are spin projected and share the same volume index (modulo UP/DOWN reading)
//sp_idx = face_idx + param.ghostOffset[dim];

coordsFromFaceIndex<5,QUDA_5D_PC,kernel_type,1>(X, sid, coord, face_idx, face_num, param);
s_parity = ( sid/param.dc.volume_4d_cb ) % 2;

READ_INTERMEDIATE_SPINOR(INTERTEX, param.sp_stride, sid, sid);

""")

        # Exterior kernel: start from the intermediate spinor just read.
        copy_in = ""
        for s in range(4):
            for c in range(3):
                copy_in += out_re(s, c) + " = " + in_re(s, c) + "; " + out_im(s, c) + " = " + in_im(s, c) + ";\n"
        prolog_str += indent(copy_in)
        prolog_str += "}\n"
        prolog_str += "#endif // MULTI_GPU\n"

        if domain_wall:
            prolog_str += (
"""
// declare G## here and use ASSN below instead of READ
#ifdef GAUGE_FLOAT2
#if (DD_PREC==0) //temporal hack
double2 G0;
double2 G1;
double2 G2;
double2 G3;
double2 G4;
double2 G5;
double2 G6;
double2 G7;
double2 G8;
#else
float2 G0;
float2 G1;
float2 G2;
float2 G3;
float2 G4;
float2 G5;
float2 G6;
float2 G7;
float2 G8;
#endif
#else
float4 G0;
float4 G1;
float4 G2;
float4 G3;
float4 G4;
#endif

""")

        # NOTE(review): the listing lost indentation; this line is assumed to
        # belong to the `if dslash:` level rather than `if domain_wall:`.
        # The two readings agree whenever domain_wall is true.
        prolog_str += "\n\n"

    elif domain_wall:
        prolog_str += (
"""
#include "io_spinor.h"

int sid = blockIdx.x*blockDim.x + threadIdx.x;
if (sid >= param.threads) return;

// read spinor from device memory
READ_SPINOR(SPINORTEX, param.sp_stride, sid, sid);

""")
    else:
        prolog_str += (
"""
#include "read_clover.h"
#include "io_spinor.h"

int sid = blockIdx.x*blockDim.x + threadIdx.x;
if (sid >= param.threads) return;

// read spinor from device memory
READ_SPINOR(SPINORTEX, param.sp_stride, sid, sid);

""")
    return prolog_str
# end def prolog
484 
485 
def gen(dir, pack_only=False):
    """Return the generated code for the hop in direction *dir* (0..7).

    Even *dir* selects projectors[dir] and odd *dir* projectors[dir] as well,
    unless the module-level ``dagger`` flag is set, in which case the two
    projectors of each pair are swapped.

    With ``pack_only=True`` only the spinor load, half-spinor declarations and
    projection are returned, with ``sp_idx`` renamed to ``idx``.  Otherwise the
    full braced block is returned, preceded by its MULTI_GPU guard condition.
    """
    projIdx = dir if not dagger else dir + (+1 if dir % 2 == 0 else -1)
    projStr = projectorToStr(projectors[projIdx])
    mu = dir // 2  # floor division: stays an int on Python 3 too

    def proj(i, j):
        return projectors[projIdx][4 * i + j]

    # if row(i) = (j, c), then the i'th row of the projector can be represented
    # as a multiple of the j'th row: row(i) = c row(j)
    def row(i):
        assert i == 2 or i == 3
        if proj(i, 0) == 0j:
            return (1, proj(i, 1))
        if proj(i, 1) == 0j:
            return (0, proj(i, 0))

    def project_terms(h, c):
        # Expressions for the real and imaginary part of projector row h
        # applied to the input spinor.  Entries that are neither purely real
        # nor purely imaginary contribute nothing, as in the original.
        re_expr, im_expr = "", ""
        for s in range(4):
            re = proj(h, s).real
            im = proj(h, s).imag
            if re == 0 and im == 0:
                continue
            if im == 0:
                re_expr += sign(re) + in_re(s, c)
                im_expr += sign(re) + in_im(s, c)
            elif re == 0:
                re_expr += sign(-im) + in_im(s, c)
                im_expr += sign(im) + in_re(s, c)
        return re_expr, im_expr

    boundary = ["coord[0]==(param.dc.X[0]-1)", "coord[0]==0", "coord[1]==(param.dc.X[1]-1)", "coord[1]==0", "coord[2]==(param.dc.X[2]-1)", "coord[2]==0", "coord[3]==(param.dc.X[3]-1)", "coord[3]==0"]
    interior = ["coord[0]<(param.dc.X[0]-1)", "coord[0]>0", "coord[1]<(param.dc.X[1]-1)", "coord[1]>0", "coord[2]<(param.dc.X[2]-1)", "coord[2]>0", "coord[3]<(param.dc.X[3]-1)", "coord[3]>0"]
    dim_names = ["X", "Y", "Z", "T"]

    # index of neighboring site when not on boundary
    neighbor = ["X+1", "X-1", "X+param.dc.X[0]", "X-param.dc.X[0]", "X+param.dc.X2X1", "X-param.dc.X2X1", "X+param.dc.X3X2X1", "X-param.dc.X3X2X1"]

    # index of neighboring site (across boundary)
    neighbor_wrap = ["X-(param.dc.X[0]-1)", "X+(param.dc.X[0]-1)", "X-param.dc.X2X1mX1", "X+param.dc.X2X1mX1", "X-param.dc.X3X2X1mX2X1", "X+param.dc.X3X2X1mX2X1",
                     "X-param.dc.X4X3X2X1mX3X2X1", "X+param.dc.X4X3X2X1mX3X2X1"]

    cond = ""
    cond += "#ifdef MULTI_GPU\n"
    cond += "if ( (kernel_type == INTERIOR_KERNEL && (!param.ghostDim[" + repr(mu) + "] || " + interior[dir] + ")) ||\n"
    cond += " (kernel_type == EXTERIOR_KERNEL_" + dim_names[mu] + " && " + boundary[dir] + ") )\n"
    cond += "#endif\n"

    body = ""

    projName = "P" + repr(mu) + ["-", "+"][projIdx % 2]
    body += "// Projector " + projName + "\n"
    for l in projStr.splitlines():
        body += "//" + l + "\n"
    body += "\n"

    ghost_slot = repr((dir + 1) % 2)
    site_expr = "(" + boundary[dir] + " ? " + neighbor_wrap[dir] + " : " + neighbor[dir] + ") >> 1"
    body += "#ifdef MULTI_GPU\n"
    body += "const int sp_idx = (kernel_type == INTERIOR_KERNEL) ? " + site_expr + " :\n"
    body += " face_idx + param.ghostOffset[static_cast<int>(kernel_type)][" + ghost_slot + "];\n"
    body += "#if (DD_PREC==2) // half precision\n"
    body += "const int sp_norm_idx = face_idx + param.ghostNormOffset[static_cast<int>(kernel_type)][" + ghost_slot + "];\n"
    body += "#endif\n"
    body += "#else\n"
    body += "const int sp_idx = " + site_expr + ";\n"
    body += "#endif\n"

    body += "\n"
    if dir % 2 == 0:
        if domain_wall:
            body += "const int ga_idx = sid % param.dc.volume_4d_cb;\n"
        else:
            body += "const int ga_idx = sid;\n"
    else:
        body += "#ifdef MULTI_GPU\n"
        if domain_wall:
            body += "const int ga_idx = ((kernel_type == INTERIOR_KERNEL) ? sp_idx % param.dc.volume_4d_cb : param.dc.volume_4d_cb+(face_idx % param.dc.ghostFace[static_cast<int>(kernel_type)]));\n"
        else:
            body += "const int ga_idx = ((kernel_type == INTERIOR_KERNEL) ? sp_idx : param.dc.volume_4d_cb+face_idx);\n"
        body += "#else\n"
        if domain_wall:
            body += "const int ga_idx = sp_idx % param.dc.volume_4d_cb;\n"
        else:
            body += "const int ga_idx = sp_idx;\n"
        body += "#endif\n"
    body += "\n"

    # scan the projector to determine which loads are required
    row_cnt = [0, 0, 0, 0]
    for h in range(4):
        for s in range(4):
            if proj(h, s).real != 0 or proj(h, s).imag != 0:
                row_cnt[h] += 1
    row_cnt[0] += row_cnt[1]  # non-zero entries in the upper two rows
    row_cnt[2] += row_cnt[3]  # non-zero entries in the lower two rows
    lower_only = (row_cnt[0] == 0)  # projector defined on lower half only

    decl_half = ""
    for h in range(2):
        for c in range(3):
            decl_half += "spinorFloat " + h1_re(h, c) + ", " + h1_im(h, c) + ";\n"
    decl_half += "\n"

    load_spinor = "// read spinor from device memory\n"
    if lower_only:
        load_spinor += "READ_SPINOR_DOWN(SPINORTEX, param.sp_stride, sp_idx, sp_idx);\n"
    elif row_cnt[2] == 0:
        load_spinor += "READ_SPINOR_UP(SPINORTEX, param.sp_stride, sp_idx, sp_idx);\n"
    else:
        load_spinor += "READ_SPINOR(SPINORTEX, param.sp_stride, sp_idx, sp_idx);\n"
    load_spinor += "\n"

    load_half = ""
    if domain_wall:
        load_half += "const int sp_stride_pad = param.dc.Ls*param.dc.ghostFace[static_cast<int>(kernel_type)];\n"
    else:
        load_half += "const int sp_stride_pad = param.dc.ghostFace[static_cast<int>(kernel_type)];\n"

    if dir >= 6:
        load_half += "const int t_proj_scale = TPROJSCALE;\n"
    load_half += "\n"
    load_half += "// read half spinor from device memory\n"

    # we have to use the same volume index for backwards and forwards gathers
    load_half += "READ_SPINOR_GHOST(GHOSTSPINORTEX, sp_stride_pad, sp_idx, sp_norm_idx, " + repr(dir) + ");\n\n"

    load_gauge = "// read gauge matrix from device memory\n"
    if domain_wall:
        load_gauge += "if ( ! s_parity ) { ASSN_GAUGE_MATRIX(G, GAUGE" + repr(dir % 2) + "TEX, " + repr(dir) + ", ga_idx, param.gauge_stride); }\n"
        load_gauge += "else { ASSN_GAUGE_MATRIX(G, GAUGE" + repr(1 - dir % 2) + "TEX, " + repr(dir) + ", ga_idx, param.gauge_stride); }\n\n"
    else:
        load_gauge += "READ_GAUGE_MATRIX(G, GAUGE" + repr(dir % 2) + "TEX, " + repr(dir) + ", ga_idx, param.gauge_stride);\n\n"

    reconstruct_gauge = "// reconstruct gauge matrix\n"
    reconstruct_gauge += "RECONSTRUCT_GAUGE_MATRIX(" + repr(dir) + ");\n\n"

    project = "// project spinor into half spinors\n"
    for h in range(2):
        for c in range(3):
            strRe, strIm = project_terms(h, c)
            if lower_only:
                # Upper rows are empty; take the contribution from row h+2.
                lowRe, lowIm = project_terms(h + 2, c)
                strRe += lowRe
                strIm += lowIm
            project += h1_re(h, c) + " = " + strRe + ";\n"
            project += h1_im(h, c) + " = " + strIm + ";\n"

    scale = "t_proj_scale*" if (dir >= 6) else ""
    copy_half = ""
    for h in range(2):
        for c in range(3):
            copy_half += h1_re(h, c) + " = " + scale + in_re(h, c) + "; "
            copy_half += h1_im(h, c) + " = " + scale + in_im(h, c) + ";\n"
    copy_half += "\n"

    prep_half = ""
    prep_half += "#ifdef MULTI_GPU\n"
    prep_half += "if (kernel_type == INTERIOR_KERNEL) {\n"
    prep_half += "#endif\n"
    prep_half += "\n"
    prep_half += indent(load_spinor)
    prep_half += indent(project)
    prep_half += "\n"
    prep_half += "#ifdef MULTI_GPU\n"
    prep_half += "} else {\n"
    prep_half += "\n"
    prep_half += indent(load_half)
    prep_half += indent(copy_half)
    prep_half += "}\n"
    prep_half += "#endif // MULTI_GPU\n"
    prep_half += "\n"

    ident = "// identity gauge matrix\n"
    for m in range(3):
        for h in range(2):
            ident += "spinorFloat " + h2_re(h, m) + " = " + h1_re(h, m) + "; "
            ident += "spinorFloat " + h2_im(h, m) + " = " + h1_im(h, m) + ";\n"
    ident += "\n"

    mult = ""
    for m in range(3):
        mult += "// multiply row " + repr(m) + "\n"
        for h in range(2):
            re = "spinorFloat " + h2_re(h, m) + " = 0;\n"
            im = "spinorFloat " + h2_im(h, m) + " = 0;\n"
            for c in range(3):
                re += h2_re(h, m) + " += " + g_re(dir, m, c) + " * " + h1_re(h, c) + ";\n"
                re += h2_re(h, m) + " -= " + g_im(dir, m, c) + " * " + h1_im(h, c) + ";\n"
                im += h2_im(h, m) + " += " + g_re(dir, m, c) + " * " + h1_im(h, c) + ";\n"
                im += h2_im(h, m) + " += " + g_im(dir, m, c) + " * " + h1_re(h, c) + ";\n"
            mult += re + im
        mult += "\n"

    reconstruct = ""
    for m in range(3):
        for h in range(2):
            h_out = h + 2 if lower_only else h
            reconstruct += out_re(h_out, m) + " += " + h2_re(h, m) + ";\n"
            reconstruct += out_im(h_out, m) + " += " + h2_im(h, m) + ";\n"

        for s in range(2, 4):
            (h, c) = row(s)
            re = c.real
            im = c.imag
            if im == 0 and re == 0:
                continue
            if im == 0:
                reconstruct += out_re(s, m) + " " + sign(re) + "= " + h2_re(h, m) + ";\n"
                reconstruct += out_im(s, m) + " " + sign(re) + "= " + h2_im(h, m) + ";\n"
            elif re == 0:
                reconstruct += out_re(s, m) + " " + sign(-im) + "= " + h2_im(h, m) + ";\n"
                reconstruct += out_im(s, m) + " " + sign(+im) + "= " + h2_re(h, m) + ";\n"

        if m < 2:
            reconstruct += "\n"

    if dir >= 6:
        # Directions 6 and 7 get an extra path that skips the gauge multiply.
        body += "if (param.gauge_fixed && ga_idx < param.dc.X4X3X2X1hmX3X2X1h)\n"
        body += block(decl_half + prep_half + ident + reconstruct)
        body += " else "
        body += block(load_gauge + decl_half + prep_half + reconstruct_gauge + mult + reconstruct)
    else:
        body += load_gauge + decl_half + prep_half + reconstruct_gauge + mult + reconstruct

    if pack_only:
        out = load_spinor + decl_half + project
        return out.replace("sp_idx", "idx")
    return cond + block(body) + "\n\n"
# end def gen
719 
720 
def gen_dw():
    """Return the generated code for the hop along the fifth dimension.

    Emits one block built from two_P_L and one for P_R.  The P_R text is
    derived from the P_L text by turning every '-' into '+'.  On the edge
    slice (``coord[4]`` equal to the edge value) each contribution is wrapped
    in ``-param.mferm*( ... )``.  The module-level ``dagger`` flag swaps the
    edge values and hop signs of the two blocks.
    """
    if dagger:
        lsign, ledge, rsign, redge = '-', '0', '+', 'param.dc.Ls-1'
    else:
        lsign, ledge, rsign, redge = '+', 'param.dc.Ls-1', '-', '0'

    def proj(i, j):
        return two_P_L[4 * i + j]

    # Accumulation statements for the P_L projection of the spinor just read.
    out_L = ""
    for s1 in range(4):
        for c in range(3):
            re_rhs, im_rhs = "", ""
            for s2 in range(4):
                re, im = proj(s1, s2).real, proj(s1, s2).imag
                if re != 0:
                    re_rhs += sign(re) + in_re(s2, c)
                    im_rhs += sign(re) + in_im(s2, c)
                if im != 0:
                    re_rhs += sign(-im) + in_im(s2, c)
                    im_rhs += sign(im) + in_re(s2, c)
            out_L += 3 * " " + out_re(s1, c) + " += " + re_rhs + ";"
            out_L += 3 * " " + out_im(s1, c) + " += " + im_rhs + ";\n"
        if s1 < 3:
            out_L += "\n"

    # Order matters for the P_R variants: '-' -> '+' must run before the
    # "-param.mferm*(" prefix is inserted, or its minus sign would be lost.
    out_R = out_L.replace("-", "+")

    def with_mass(text):
        return text.replace(" += ", " += -param.mferm*(").replace(";", ");")

    sp_idx_fmt = " int sp_idx = ( coord[4] == %s ? X%s(param.dc.Ls-1)*2*param.dc.volume_4d_cb : X%s2*param.dc.volume_4d_cb ) / 2;\n"

    code = "\n\n"
    code += "// 5th dimension -- NB: not partitionable!\n"
    code += "#ifdef MULTI_GPU\nif(kernel_type == INTERIOR_KERNEL)\n#endif\n{\n"
    code += "// 2 P_L = 2 P_- = ( ( +1, -1 ), ( -1, +1 ) )\n"
    code += " {\n"
    code += sp_idx_fmt % (ledge, rsign, lsign)
    code += "\n"
    code += "// read spinor from device memory\n"
    code += " READ_SPINOR( SPINORTEX, param.sp_stride, sp_idx, sp_idx );\n"
    code += "\n"
    code += " if ( coord[4] != %s )\n" % ledge
    code += " {\n"
    code += out_L
    code += " }\n"
    code += " else\n"
    code += " {\n"
    code += with_mass(out_L)
    code += " } // end if ( coord[4]!= %s )\n" % ledge
    code += " } // end P_L\n\n"
    code += " // 2 P_R = 2 P_+ = ( ( +1, +1 ), ( +1, +1 ) )\n"
    code += " {\n"
    code += sp_idx_fmt % (redge, lsign, rsign)
    code += "\n"
    code += "// read spinor from device memory\n"
    code += " READ_SPINOR( SPINORTEX, param.sp_stride, sp_idx, sp_idx );\n"
    code += "\n"
    code += " if ( coord[4] != %s )\n" % redge
    code += " {\n"
    code += out_R
    code += " }\n"
    code += " else\n"
    code += " {\n"
    code += with_mass(out_R)
    code += " } // end if ( coord[4] != %s )\n" % redge
    code += " } // end P_R\n"
    code += "} // end 5th dimension\n\n\n"

    return code
# end def gen_dw
797 
798 
def input_spinor(s, c, z):
    """Return the name of the spinor component the clover code reads.

    When the module-level ``dslash`` flag is set the *output* names are used,
    otherwise the *input* names.  ``z == 0`` selects the real part, any other
    value the imaginary part.
    """
    want_real = (z == 0)
    if dslash:
        return out_re(s, c) if want_real else out_im(s, c)
    return in_re(s, c) if want_real else in_im(s, c)
806 
def to_chiral_basis(c):
    """Return a braced block that rewrites the four components with second index *c*.

    The new components are formed as signed sums of pairs of the components
    named by input_spinor() and then stored into the output spinor.

    NOTE(review): the ``def`` line was missing from the listing; the name comes
    from the trailing "end def" comment and the parameter from the body and
    the parallel from_chiral_basis(c) -- confirm against the repository.
    """
    # (target index, leading sign, first source, joiner, second source)
    combos = (
        (0, "-", 1, " - ", 3),
        (1, "", 0, " + ", 2),
        (2, "-", 1, " + ", 3),
        (3, "", 0, " - ", 2),
    )
    code = ""
    for (s, lead, p, joiner, q) in combos:
        code += "spinorFloat " + a_re(0, s, c) + " = " + lead + input_spinor(p, c, 0) + joiner + input_spinor(q, c, 0) + ";\n"
        code += "spinorFloat " + a_im(0, s, c) + " = " + lead + input_spinor(p, c, 1) + joiner + input_spinor(q, c, 1) + ";\n"
    code += "\n"

    for s in range(4):
        code += out_re(s, c) + " = " + a_re(0, s, c) + "; "
        code += out_im(s, c) + " = " + a_im(0, s, c) + ";\n"

    return block(code) + "\n\n"
# end def to_chiral_basis
825 
826 
def from_chiral_basis(c):  # note: factor of 1/2 is included in clover term normalization
    """Return a braced block that rewrites the four output components with second index *c*.

    The new components are formed as signed sums of pairs of the current
    output components and written back to the output spinor.
    """
    # (target index, leading sign, first source, joiner, second source)
    combos = (
        (0, "", 1, " + ", 3),
        (1, "-", 0, " - ", 2),
        (2, "", 1, " - ", 3),
        (3, "-", 0, " + ", 2),
    )
    code = ""
    for (s, lead, p, joiner, q) in combos:
        code += "spinorFloat " + a_re(0, s, c) + " = " + lead + out_re(p, c) + joiner + out_re(q, c) + ";\n"
        code += "spinorFloat " + a_im(0, s, c) + " = " + lead + out_im(p, c) + joiner + out_im(q, c) + ";\n"
    code += "\n"

    for s in range(4):
        code += out_re(s, c) + " = " + a_re(0, s, c) + "; "
        code += out_im(s, c) + " = " + a_im(0, s, c) + ";\n"

    return block(code) + "\n\n"
# end def from_chiral_basis
845 
846 
def clover_mult(chi):
    """Return a braced block multiplying half of the output spinor by clover block *chi*.

    The block reads the clover data, accumulates the 6x6 complex product into
    temporaries ``a..`` and copies the result back over output components
    ``2*chi`` and ``2*chi+1``.  Imaginary clover terms are skipped on the
    diagonal, for which no imaginary macro is defined.
    """
    code = "READ_CLOVER(CLOVERTEX, " + repr(chi) + ")\n\n"

    for s in range(2):
        for c in range(3):
            code += "spinorFloat " + a_re(chi, s, c) + " = 0; spinorFloat " + a_im(chi, s, c) + " = 0;\n"
    code += "\n"

    for sm in range(2):
        for cm in range(3):
            for sn in range(2):
                for cn in range(3):
                    off_diagonal = (sn != sm) or (cn != cm)
                    src_re = out_re(2 * chi + sn, cn)
                    src_im = out_im(2 * chi + sn, cn)
                    code += a_re(chi, sm, cm) + " += " + c_re(chi, sm, cm, sn, cn) + " * " + src_re + ";\n"
                    if off_diagonal:
                        code += a_re(chi, sm, cm) + " -= " + c_im(chi, sm, cm, sn, cn) + " * " + src_im + ";\n"
                    code += a_im(chi, sm, cm) + " += " + c_re(chi, sm, cm, sn, cn) + " * " + src_im + ";\n"
                    if off_diagonal:
                        code += a_im(chi, sm, cm) + " += " + c_im(chi, sm, cm, sn, cn) + " * " + src_re + ";\n"
            # NOTE(review): indentation was lost in the listing; the blank line
            # is assumed to follow each accumulated component.
            code += "\n"

    for s in range(2):
        for c in range(3):
            code += out_re(2 * chi + s, c) + " = " + a_re(chi, s, c) + "; "
            code += out_im(2 * chi + s, c) + " = " + a_im(chi, s, c) + ";\n"
    code += "\n"

    return block(code) + "\n\n"
# end def clover_mult
877 
878 
880  if domain_wall: return ""
881  str = ""
882  if dslash: str += "#ifdef DSLASH_CLOVER\n\n"
883  str += "// change to chiral basis\n"
885  str += "// apply first chiral block\n"
886  str += clover_mult(0)
887  str += "// apply second chiral block\n"
888  str += clover_mult(1)
889  str += "// change back from chiral basis\n"
890  str += "// (note: required factor of 1/2 is included in clover term normalization)\n"
892  if dslash: str += "#endif // DSLASH_CLOVER\n\n"
893 
894  return str
895 # end def clover
896 
897 
def xpay_lmem_pre():
    """Return the block that adds ``a_inv * accum`` to every output component.

    The whole block is guarded by ``DSLASH_XPAY`` and, under MULTI_GPU, is
    executed only by the interior kernel.

    NOTE(review): the ``def`` line was missing from the listing; the name is
    restored from the call in epilog() and the trailing "end def" comment.
    """
    def accumulate(accessor):
        text = ""
        for s in range(4):
            for c in range(3):
                i = 3 * s + c
                text += " " + out_re(s, c) + " = " + out_re(s, c) + " + a_inv*accum" + accessor(2 * i + 0) + ";\n"
                text += " " + out_im(s, c) + " = " + out_im(s, c) + " + a_inv*accum" + accessor(2 * i + 1) + ";\n"
        return text

    code = ""
    code += "#if defined MULTI_GPU && defined DSLASH_XPAY\n"
    code += "if (kernel_type == INTERIOR_KERNEL)\n"
    code += "#endif\n"
    code += "{\n"
    code += "#ifdef DSLASH_XPAY\n"
    code += " READ_ACCUM(ACCUMTEX, param.sp_stride)\n"
    code += "#ifdef SPINOR_DOUBLE\n"
    code += "spinorFloat a_inv = param.a_inv;\n"
    code += "#else\n"
    code += "spinorFloat a_inv = param.a_inv_f;\n"
    code += "#endif\n"
    code += "#ifdef SPINOR_DOUBLE\n"
    code += accumulate(nthFloat2)
    code += "#else\n"
    code += accumulate(nthFloat4)
    code += "#endif // SPINOR_DOUBLE\n\n"
    code += "#endif // DSLASH_XPAY\n"
    code += "}\n\n"

    return code
# end def xpay_lmem_pre
933 
934 
def xpay_lmem():
    """Return the ``DSLASH_XPAY`` block that scales every output component by ``a``.

    The scaling statements are identical in the double and single precision
    branches, so they are built once and emitted twice.
    """
    scale = ""
    for s in range(4):
        for c in range(3):
            scale += " " + out_re(s, c) + " = a*" + out_re(s, c) + ";\n"
            scale += " " + out_im(s, c) + " = a*" + out_im(s, c) + ";\n"

    code = ""
    code += "#ifdef DSLASH_XPAY\n"

    code += "#ifdef SPINOR_DOUBLE\n"
    code += "spinorFloat a = param.a;\n"
    code += "#else\n"
    code += "spinorFloat a = param.a_f;\n"
    code += "#endif\n"

    code += "#ifdef SPINOR_DOUBLE\n"
    code += scale
    code += "#else\n"
    code += scale
    code += "#endif // SPINOR_DOUBLE\n\n"
    code += "#endif // DSLASH_XPAY\n"

    return code
# end def xpay_lmem
966 
967 
def epilog():
    """Return the closing part of the generated kernel body.

    Emits the MULTI_GPU "incomplete" guard (dslash only), the clover/twist and
    xpay post-processing, the spinor write-back, and ``#undef`` lines for the
    macros introduced by prolog().  Reads the module-level switches ``dslash``,
    ``twist``, ``domain_wall``, ``clover`` and ``sharedFloats``.
    """
    code = ""
    if dslash:
        if twist:
            code += "#ifdef MULTI_GPU\n"
        else:
            if domain_wall:
                code += xpay_lmem_pre()
                code += "#if defined MULTI_GPU && defined DSLASH_XPAY\n"
            else:
                code += "#if defined MULTI_GPU && (defined DSLASH_XPAY || defined DSLASH_CLOVER)\n"
        code += (
"""
int incomplete = 0; // Have all 8 contributions been computed for this site?

switch(kernel_type) { // intentional fall-through
case INTERIOR_KERNEL:
 incomplete = incomplete || (param.commDim[3] && (coord[3]==0 || coord[3]==(param.dc.X[3]-1)));
case EXTERIOR_KERNEL_T:
 incomplete = incomplete || (param.commDim[2] && (coord[2]==0 || coord[2]==(param.dc.X[2]-1)));
case EXTERIOR_KERNEL_Z:
 incomplete = incomplete || (param.commDim[1] && (coord[1]==0 || coord[1]==(param.dc.X[1]-1)));
case EXTERIOR_KERNEL_Y:
 incomplete = incomplete || (param.commDim[0] && (coord[0]==0 || coord[0]==(param.dc.X[0]-1)));
}

""")
        code += "if (!incomplete)\n"
        code += "#endif // MULTI_GPU\n"

    # twisted() is only looked up when `twist` is set.
    code += block("\n" + (twisted() if twist else apply_clover()) + xpay_lmem())

    code += "\n\n"
    code += "// write spinor field back to device memory\n"
    code += "WRITE_SPINOR(param.sp_stride);\n\n"

    code += "// undefine to prevent warning when precision is changed\n"
    code += "#undef spinorFloat\n"
    code += "#undef SHARED_STRIDE\n\n"

    if dslash:
        for m in range(3):
            for n in range(3):
                code += "#undef " + g_re(0, m, n) + "\n"
                code += "#undef " + g_im(0, m, n) + "\n"
        code += "\n"

    for s in range(4):
        for c in range(3):
            code += "#undef " + in_re(s, c) + "\n"
            code += "#undef " + in_im(s, c) + "\n"
    code += "\n"

    if clover:
        for m in range(6):
            s, c = divmod(m, 3)
            code += "#undef " + c_re(0, s, c, s, c) + "\n"
        for n in range(6):
            sn, cn = divmod(n, 3)
            for m in range(n + 1, 6):
                sm, cm = divmod(m, 3)
                code += "#undef " + c_re(0, sm, cm, sn, cn) + "\n"
                code += "#undef " + c_im(0, sm, cm, sn, cn) + "\n"
        code += "\n"

    # Only output components that were #defined (the first sharedFloats
    # scalars) need an #undef.
    for s in range(4):
        for c in range(3):
            i = 3 * s + c
            if 2 * i < sharedFloats:
                code += "#undef " + out_re(s, c) + "\n"
            if 2 * i + 1 < sharedFloats:
                code += "#undef " + out_im(s, c) + "\n"
    code += "\n"

    code += "#undef VOLATILE\n"

    return code
# end def epilog
1051 
1052 
def pack_face(facenum):
    """Return a ``switch(dim)`` statement that projects and writes a half spinor.

    Each of the four cases contains the pack-only output of
    gen(2*dim + facenum) followed by a WRITE_HALF_SPINOR call.
    """
    code = "\n"
    code += "switch(dim) {\n"
    for dim in range(4):
        code += "case " + repr(dim) + ":\n"
        proj = gen(2 * dim + facenum, pack_only=True)
        proj += "\n"
        proj += "// write half spinor back to device memory\n"
        proj += "WRITE_HALF_SPINOR(face_volume, face_idx);\n"
        code += indent(block(proj) + "\n" + "break;\n")
    code += "}\n\n"
    return code
# end def pack_face
1066 
def generate_pack():
    """Return the complete face-packing kernel body.

    Requires the module-level ``sharedFloats`` to be 0.

    NOTE(review): the ``def`` line was missing from the listing; the name is
    taken from the trailing "end def" comment -- confirm against the repository.
    """
    assert (sharedFloats == 0)
    code = ""
    code += def_input_spinor()
    code += "#include \"io_spinor.h\"\n\n"

    code += "if (face_num) "
    code += block(pack_face(1))
    code += " else "
    code += block(pack_face(0))

    code += "\n\n"
    code += "// undefine to prevent warning when precision is changed\n"
    code += "#undef spinorFloat\n"
    code += "#undef SHARED_STRIDE\n\n"

    for s in range(4):
        for c in range(3):
            code += "#undef " + in_re(s, c) + "\n"
            code += "#undef " + in_im(s, c) + "\n"
    code += "\n"

    return code
# end def generate_pack
1092 
1093 
def generate_dslash():
    """Return the full kernel body: prolog, the eight hops, optional 5th-dimension hop, epilog.

    NOTE(review): the ``def`` line was missing from the listing; the name is
    restored from the calls in the driver code at the bottom of the file.
    """
    parts = [prolog()]
    parts.extend(gen(i) for i in range(8))
    if domain_wall:
        parts.append(gen_dw())
    parts.append(epilog())
    return "".join(parts)
1102 
1104  return prolog() + epilog()
1105 
1106 
1107 # To fit 192 threads/SM (single precision) with 16K shared memory, set sharedFloats to 19 or smaller
1108 
sharedFloats = 0
cloverSharedFloats = 0
if len(sys.argv) > 1 and sys.argv[1] == '--shared':
    # The original indexed sys.argv[2] unconditionally and died with an
    # IndexError when the value was omitted.
    if len(sys.argv) < 3:
        sys.exit(sys.argv[0] + ": --shared requires an integer argument")
    sharedFloats = int(sys.argv[2])
# Single-argument print(...) behaves the same on Python 2 and Python 3.
print("Shared floats set to " + str(sharedFloats))

# generate Domain Wall Dslash kernels
domain_wall = True
twist = False
clover = False

# The generator functions read `dslash` and `dagger` as module-level globals,
# so they are assigned here rather than passed as arguments.
dslash = True
for dagger, header in ((False, 'dw_dslash_core.h'), (True, 'dw_dslash_dagger_core.h')):
    print(sys.argv[0] + ": generating " + header)
    # `with` closes the file even if generation raises part-way through.
    with open('dslash_core/' + header, 'w') as f:
        f.write(generate_dslash())
1134 
1135 
1136 
1137 
1138 
1139 
def complexify(a)
complex numbers ######################################################################## ...
def gen(dir, pack_only=False)
def indent(code, n=1)
code generation ######################################################################## ...
Definition: gen.py:1
if(err !=cudaSuccess)
def c_im(b, sm, cm, sn, cn)
def c_re(b, sm, cm, sn, cn)