QUDA  0.9.0
fused_exterior_dw_dslash_cuda_gen.py
Go to the documentation of this file.
1 # -*- coding: utf-8 -*-
2 import sys
3 
4 
5 
def complexify(a):
    """Return a new list containing ``complex(x)`` for each element ``x`` of ``a``."""
    return list(map(complex, a))
8 
def complexToStr(c):
    """Format the complex number ``c`` as a compact string such as ``1-2i``.

    Zero parts are omitted, a unit imaginary part is written as ``i`` / ``-i``,
    and integral values are printed without a decimal point.
    """
    def fltToString(a):
        # Integral values print as plain integers; anything else uses repr().
        # (The original used backtick-repr, which is not valid in Python 3.)
        if a == int(a):
            return str(int(a))
        return repr(a)

    def imToString(a):
        if a == 0:
            return "0i"
        elif a == -1:
            return "-i"
        elif a == 1:
            return "i"
        return fltToString(a) + "i"

    re = c.real
    im = c.imag
    if re == 0 and im == 0:
        return "0"
    elif re == 0:
        return imToString(im)
    elif im == 0:
        return fltToString(re)
    # Both parts present: emit the sign explicitly and format |im| after it.
    im_str = "-" + imToString(-im) if im < 0 else "+" + imToString(im)
    return fltToString(re) + im_str
28 
29 
30 
31 
# 4x4 matrices stored as flat lists of 16 complex entries; element (i, j)
# lives at index 4*i+j.

id = complexify([
    1, 0, 0, 0,
    0, 1, 0, 0,
    0, 0, 1, 0,
    0, 0, 0, 1,
])

gamma1 = complexify([
    0, 0, 0, 1j,
    0, 0, 1j, 0,
    0, -1j, 0, 0,
    -1j, 0, 0, 0,
])

gamma2 = complexify([
    0, 0, 0, 1,
    0, 0, -1, 0,
    0, -1, 0, 0,
    1, 0, 0, 0,
])

gamma3 = complexify([
    0, 0, 1j, 0,
    0, 0, 0, -1j,
    -1j, 0, 0, 0,
    0, 1j, 0, 0,
])

gamma4 = complexify([
    1, 0, 0, 0,
    0, 1, 0, 0,
    0, 0, -1, 0,
    0, 0, 0, -1,
])

igamma5 = complexify([
    0, 0, 1j, 0,
    0, 0, 0, 1j,
    1j, 0, 0, 0,
    0, 1j, 0, 0,
])

# Element-wise id -/+ igamma5/1j.
two_P_L = [unit - ig5/1j for (unit, ig5) in zip(id, igamma5)]
two_P_R = [unit + ig5/1j for (unit, ig5) in zip(id, igamma5)]
76 
77 # for s1 in range(0,4) :
78 # for s2 in range (0,4): print "%8s" % two_P_L[s1*4+s2],
79 # print " ",
80 # for s2 in range (0,4): print "%8s" % two_P_R[s1*4+s2],
81 # print ""
82 
83 
def gplus(g1, g2):
    """Return the element-wise sum of ``g1`` and ``g2`` as a new list.

    Pairs are formed with ``zip``, so the result is as long as the shorter input.
    """
    total = []
    for left, right in zip(g1, g2):
        total.append(left + right)
    return total
86 
def gminus(g1, g2):
    """Return the element-wise difference ``g1 - g2`` as a new list.

    Pairs are formed with ``zip``, so the result is as long as the shorter input.
    """
    diff = []
    for left, right in zip(g1, g2):
        diff.append(left - right)
    return diff
89 
91  out = ""
92  for i in range(0, 4):
93  for j in range(0,4):
94  out += '%3s' % complexToStr(p[4*i+j])
95  out += "\n"
96  return out
97 
# For each gamma matrix in order, store (id - gamma) followed by (id + gamma),
# giving eight entries indexed 0..7.
projectors = []
for _gamma in (gamma1, gamma2, gamma3, gamma4):
    projectors.append(gminus(id, _gamma))
    projectors.append(gplus(id, _gamma))
104 
105 
106 
def indent(code, n=1):
    """Prefix each line of ``code`` with ``n`` spaces and terminate it with a newline.

    Empty lines and lines whose first character is ``#`` (preprocessor
    directives in the generated code) are left unindented.
    """
    pad = n * " "
    pieces = []
    for line in code.splitlines():
        if line and not line.startswith("#"):
            line = pad + line
        pieces.append(line + "\n")
    return "".join(pieces)
110 
def block(code):
    """Wrap ``code`` in braces: ``{``, newline, the indented code, then ``}``."""
    return "".join(("{\n", indent(code), "}"))
113 
def sign(x):
    """Map a coefficient of +-1 or +-2 to the C operator prefix used in generated code.

    Returns ``"+"``, ``"-"``, ``"+2*"`` or ``"-2*"``; any other value yields ``None``.
    """
    prefixes = {1: "+", -1: "-", 2: "+2*", -2: "-2*"}
    return prefixes.get(x)
119 
def nthFloat4(n):
    """Return ``"<n//4>.<x|y|z|w>"``: the vector index and component name for scalar ``n``.

    Floor division is used explicitly so the index stays an integer on
    Python 3 (plain ``/`` would yield e.g. ``1.25.y``).
    """
    return str(n // 4) + "." + ["x", "y", "z", "w"][n % 4]
122 
def nthFloat2(n):
    """Return ``"<n//2>.<x|y>"``: the vector index and component name for scalar ``n``.

    Floor division is used explicitly so the index stays an integer on
    Python 3 (plain ``/`` would yield e.g. ``1.5.y``).
    """
    return str(n // 2) + "." + ["x", "y"][n % 2]
125 
126 
# Identifier builders for the generated C code.  Each returns a variable or
# macro name assembled from its integer indices; ``_re`` / ``_im`` select the
# real / imaginary part.  str() replaces the Python-2-only backtick repr.

def in_re(s, c): return "i" + str(s) + str(c) + "_re"
def in_im(s, c): return "i" + str(s) + str(c) + "_im"

# Even ``d`` selects the "g" names, odd ``d`` the "gT" names.
def g_re(d, m, n): return ("g" if (d % 2 == 0) else "gT") + str(m) + str(n) + "_re"
def g_im(d, m, n): return ("g" if (d % 2 == 0) else "gT") + str(m) + str(n) + "_im"

def out_re(s, c): return "o" + str(s) + str(c) + "_re"
def out_im(s, c): return "o" + str(s) + str(c) + "_im"

# ``h`` is 0 or 1 and picks the letter of the name.
def h1_re(h, c): return ["a", "b"][h] + str(c) + "_re"
def h1_im(h, c): return ["a", "b"][h] + str(c) + "_im"
def h2_re(h, c): return ["A", "B"][h] + str(c) + "_re"
def h2_im(h, c): return ["A", "B"][h] + str(c) + "_im"

# ``b`` offsets the first index of each pair by 2*b.
def c_re(b, sm, cm, sn, cn): return "c" + str(sm + 2*b) + str(cm) + "_" + str(sn + 2*b) + str(cn) + "_re"
def c_im(b, sm, cm, sn, cn): return "c" + str(sm + 2*b) + str(cm) + "_" + str(sn + 2*b) + str(cn) + "_im"
def a_re(b, s, c): return "a" + str(s + 2*b) + str(c) + "_re"
def a_im(b, s, c): return "a" + str(s + 2*b) + str(c) + "_im"

def tmp_re(s, c): return "tmp" + str(s) + str(c) + "_re"
def tmp_im(s, c): return "tmp" + str(s) + str(c) + "_im"
144 
145 
147  str = ""
148  str += "// input spinor\n"
149  str += "#ifdef SPINOR_DOUBLE\n"
150  str += "#define spinorFloat double\n"
151  for s in range(0,4):
152  for c in range(0,3):
153  i = 3*s+c
154  str += "#define "+in_re(s,c)+" I"+nthFloat2(2*i+0)+"\n"
155  str += "#define "+in_im(s,c)+" I"+nthFloat2(2*i+1)+"\n"
156  str += "#else\n"
157  str += "#define spinorFloat float\n"
158  for s in range(0,4):
159  for c in range(0,3):
160  i = 3*s+c
161  str += "#define "+in_re(s,c)+" I"+nthFloat4(2*i+0)+"\n"
162  str += "#define "+in_im(s,c)+" I"+nthFloat4(2*i+1)+"\n"
163  str += "#endif // SPINOR_DOUBLE\n\n"
164  return str
165 # end def def_input_spinor
166 
167 
def def_gauge():
    """Return the ``#define`` lines that name the gauge-link matrix elements.

    Two alternative mappings of g<m><n>_re/_im onto the ``G`` vectors are
    emitted (two or four scalars per vector, selected by GAUGE_FLOAT2),
    followed by the gT<m><n> names defined as the conjugate transpose of
    the g names.
    """
    sites = [(m, n) for m in range(0, 3) for n in range(0, 3)]

    def link_defines(component):
        # One real and one imaginary define per matrix element; ``component``
        # turns a scalar offset into "<vector index>.<field>".
        lines = []
        for m, n in sites:
            k = 2 * (3*m + n)
            lines.append("#define " + g_re(0, m, n) + " G" + component(k) + "\n")
            lines.append("#define " + g_im(0, m, n) + " G" + component(k + 1) + "\n")
        return "".join(lines)

    text = "// gauge link\n"
    text += "#ifdef GAUGE_FLOAT2\n"
    text += link_defines(nthFloat2)
    text += "\n"
    text += "#else\n"
    text += link_defines(nthFloat4)
    text += "\n"
    text += "#endif // GAUGE_DOUBLE\n\n"

    text += "// conjugated gauge link\n"
    for m, n in sites:
        # Indices are swapped and the imaginary part negated.
        text += "#define " + g_re(1, m, n) + " (+" + g_re(0, n, m) + ")\n"
        text += "#define " + g_im(1, m, n) + " (-" + g_im(0, n, m) + ")\n"
    text += "\n"

    return text
# end def def_gauge
197 # end def def_gauge
198 
199 
201  str = "// first chiral block of inverted clover term\n"
202  str += "#ifdef CLOVER_DOUBLE\n"
203  i = 0
204  for m in range(0,6):
205  s = m/3
206  c = m%3
207  str += "#define "+c_re(0,s,c,s,c)+" C"+nthFloat2(i)+"\n"
208  i += 1
209  for n in range(0,6):
210  sn = n/3
211  cn = n%3
212  for m in range(n+1,6):
213  sm = m/3
214  cm = m%3
215  str += "#define "+c_re(0,sm,cm,sn,cn)+" C"+nthFloat2(i)+"\n"
216  str += "#define "+c_im(0,sm,cm,sn,cn)+" C"+nthFloat2(i+1)+"\n"
217  i += 2
218  str += "#else\n"
219  i = 0
220  for m in range(0,6):
221  s = m/3
222  c = m%3
223  str += "#define "+c_re(0,s,c,s,c)+" C"+nthFloat4(i)+"\n"
224  i += 1
225  for n in range(0,6):
226  sn = n/3
227  cn = n%3
228  for m in range(n+1,6):
229  sm = m/3
230  cm = m%3
231  str += "#define "+c_re(0,sm,cm,sn,cn)+" C"+nthFloat4(i)+"\n"
232  str += "#define "+c_im(0,sm,cm,sn,cn)+" C"+nthFloat4(i+1)+"\n"
233  i += 2
234  str += "#endif // CLOVER_DOUBLE\n\n"
235 
236  for n in range(0,6):
237  sn = n/3
238  cn = n%3
239  for m in range(0,n):
240  sm = m/3
241  cm = m%3
242  str += "#define "+c_re(0,sm,cm,sn,cn)+" (+"+c_re(0,sn,cn,sm,cm)+")\n"
243  str += "#define "+c_im(0,sm,cm,sn,cn)+" (-"+c_im(0,sn,cn,sm,cm)+")\n"
244  str += "\n"
245 
246  str += "// second chiral block of inverted clover term (reuses C0,...,C9)\n"
247  for n in range(0,6):
248  sn = n/3
249  cn = n%3
250  for m in range(0,6):
251  sm = m/3
252  cm = m%3
253  str += "#define "+c_re(1,sm,cm,sn,cn)+" "+c_re(0,sm,cm,sn,cn)+"\n"
254  if m != n: str += "#define "+c_im(1,sm,cm,sn,cn)+" "+c_im(0,sm,cm,sn,cn)+"\n"
255  str += "\n"
256 
257  return str
258 # end def def_clover
259 
261  str = "// output spinor\n"
262  for s in range(0,4):
263  for c in range(0,3):
264  i = 3*s+c
265  if 2*i < sharedFloats:
266  str += "#define "+out_re(s,c)+" s["+`(2*i+0)`+"*SHARED_STRIDE]\n"
267  else:
268  str += "VOLATILE spinorFloat "+out_re(s,c)+";\n"
269  if 2*i+1 < sharedFloats:
270  str += "#define "+out_im(s,c)+" s["+`(2*i+1)`+"*SHARED_STRIDE]\n"
271  else:
272  str += "VOLATILE spinorFloat "+out_im(s,c)+";\n"
273  return str
274 # end def def_output_spinor
275 
276 
def prolog():
    """Return the text placed ahead of the per-direction code of a kernel body.

    Reads the module-level switches ``dslash``, ``clover``, ``dagger``,
    ``domain_wall`` and ``sharedFloats``.  For ``dslash`` it emits the
    macro definitions, the thread-index / face-coordinate setup of the
    exterior kernel and the initialisation of the output names from the
    input names; otherwise it emits the shorter spinor-read preamble.

    Exits the program when neither ``dslash`` nor ``clover`` is set.

    NOTE(review): this block was reconstructed from a listing with flattened
    indentation; the nesting of the shared-memory pointer section under
    ``sharedFloats > 0`` and of the trailing ``"\\n\\n"`` should be confirmed
    against the upstream file, as should leading whitespace inside the
    string literals.
    """
    if dslash:
        prolog_str = ("// *** CUDA DSLASH ***\n\n" if not dagger else "// *** CUDA DSLASH DAGGER ***\n\n")
        prolog_str += "#define DSLASH_SHARED_FLOATS_PER_THREAD "+str(sharedFloats)+"\n\n"
    elif clover:
        prolog_str = ("// *** CUDA CLOVER ***\n\n")
        prolog_str += "#define CLOVER_SHARED_FLOATS_PER_THREAD "+str(sharedFloats)+"\n\n"
    else:
        # The original had a bare ``exit`` here, which is a no-op expression;
        # execution then fell through to an UnboundLocalError on prolog_str.
        print("Undefined prolog")
        sys.exit(1)

    prolog_str += (
"""
#ifdef MULTI_GPU

#if (CUDA_VERSION >= 4010)
#define VOLATILE
#else
#define VOLATILE volatile
#endif
""")

    prolog_str += def_input_spinor()
    if dslash == True: prolog_str += def_gauge()
    if clover == True: prolog_str += def_clover()
    prolog_str += def_output_spinor()

    prolog_str += (
"""
#ifdef SPINOR_DOUBLE
#if (__COMPUTE_CAPABILITY__ >= 200)
#define SHARED_STRIDE 16 // to avoid bank conflicts on Fermi
#else
#define SHARED_STRIDE 8 // to avoid bank conflicts on G80 and GT200
#endif
#else
#if (__COMPUTE_CAPABILITY__ >= 200)
#define SHARED_STRIDE 32 // to avoid bank conflicts on Fermi
#else
#define SHARED_STRIDE 16 // to avoid bank conflicts on G80 and GT200
#endif
#endif
""")

    if sharedFloats > 0:
        prolog_str += (
"""
extern __shared__ char s_data[];
""")

        # The pointer ``s`` refers to s_data, which is only declared above,
        # so this section belongs inside the sharedFloats > 0 branch.
        if dslash:
            prolog_str += (
"""
VOLATILE spinorFloat *s = (spinorFloat*)s_data + DSLASH_SHARED_FLOATS_PER_THREAD*SHARED_STRIDE*(threadIdx.x/SHARED_STRIDE)
+ (threadIdx.x % SHARED_STRIDE);
""")
        else:
            prolog_str += (
"""
VOLATILE spinorFloat *s = (spinorFloat*)s_data + CLOVER_SHARED_FLOATS_PER_THREAD*SHARED_STRIDE*(threadIdx.x/SHARED_STRIDE)
+ (threadIdx.x % SHARED_STRIDE);
""")

    if dslash:
        prolog_str += "\n#include \"read_gauge.h\"\n"
        if not domain_wall:
            prolog_str += "#include \"read_clover.h\"\n"
        prolog_str += "#include \"io_spinor.h\"\n"
        prolog_str += (
"""
#if (DD_PREC==2) // half precision
int sp_norm_idx;
#endif // half precision

int sid = ((blockIdx.y*blockDim.y + threadIdx.y)*gridDim.x + blockIdx.x)*blockDim.x + threadIdx.x;
if (sid >= param.threads*param.dc.Ls) return;

int dim;
int face_idx;
int coord[5];
int X;
int s_parity;


""")

        prolog_str += (
"""
{ // exterior kernel

dim = dimFromFaceIndex<5>(sid, param); // sid is also modified

const int face_volume = ((param.threadDimMapUpper[dim] - param.threadDimMapLower[dim])*param.dc.Ls >> 1);
const int face_num = (sid >= face_volume); // is this thread updating face 0 or 1
face_idx = sid - face_num*face_volume; // index into the respective face

// ghostOffset is scaled to include body (includes stride) and number of FloatN arrays (SPINOR_HOP)
// face_idx not sid since faces are spin projected and share the same volume index (modulo UP/DOWN reading)
//sp_idx = face_idx + param.ghostOffset[dim];

switch(dim) {
case 0:
 coordsFromFaceIndex<5,QUDA_5D_PC,0,1>(X, sid, coord, face_idx, face_num, param);
 break;
case 1:
 coordsFromFaceIndex<5,QUDA_5D_PC,1,1>(X, sid, coord, face_idx, face_num, param);
 break;
case 2:
 coordsFromFaceIndex<5,QUDA_5D_PC,2,1>(X, sid, coord, face_idx, face_num, param);
 break;
case 3:
 coordsFromFaceIndex<5,QUDA_5D_PC,3,1>(X, sid, coord, face_idx, face_num, param);
 break;
}

bool active = false;
for(int dir=0; dir<4; ++dir){
 active = active || isActive(dim,dir,+1,coord,param.commDim,param.dc.X);
}
if(!active) return;


s_parity = ( sid/param.dc.volume_4d_cb ) % 2;

READ_INTERMEDIATE_SPINOR(INTERTEX, param.sp_stride, sid, sid);

""")

        # Initialise every output name from the matching input name, then
        # close the brace opened by "{ // exterior kernel".
        out = ""
        for s in range(0, 4):
            for c in range(0, 3):
                out += out_re(s,c)+" = "+in_re(s,c)+"; "+out_im(s,c)+" = "+in_im(s,c)+";\n"
        prolog_str += indent(out)
        prolog_str += "}\n"

        if domain_wall:
            prolog_str += (
"""
// declare G## here and use ASSN below instead of READ
#ifdef GAUGE_FLOAT2
#if (DD_PREC==0) //temporal hack
double2 G0;
double2 G1;
double2 G2;
double2 G3;
double2 G4;
double2 G5;
double2 G6;
double2 G7;
double2 G8;
#else
float2 G0;
float2 G1;
float2 G2;
float2 G3;
float2 G4;
float2 G5;
float2 G6;
float2 G7;
float2 G8;
#endif
#else
float4 G0;
float4 G1;
float4 G2;
float4 G3;
float4 G4;
#endif

""")

        prolog_str += "\n\n"

    elif domain_wall:
        prolog_str += (
"""
#include "io_spinor.h"

int sid = blockIdx.x*blockDim.x + threadIdx.x;
if (sid >= param.threads) return;

// read spinor from device memory
READ_SPINOR(SPINORTEX, param.sp_stride, sid, sid);

""")
    else:
        prolog_str += (
"""
#include "read_clover.h"
#include "io_spinor.h"

int sid = blockIdx.x*blockDim.x + threadIdx.x;
if (sid >= param.threads) return;

// read spinor from device memory
READ_SPINOR(SPINORTEX, param.sp_stride, sid, sid);

""")
    return prolog_str
# end def prolog
479 # end def prolog
480 
481 
def gen(dir, pack_only=False):
    """Return the generated C code for hopping direction ``dir`` (0..7).

    ``dir // 2`` is the dimension index and ``dir % 2`` distinguishes the
    two offsets ("+1" for even, "-1" for odd).  With ``dagger`` set the
    projector of the opposite direction in the same dimension is used.

    When ``pack_only`` is true only the spinor load, the half-spinor
    declarations and the projection code are returned, with ``sp_idx``
    renamed to ``idx``.  Otherwise the result is the ``isActive`` condition
    followed by a braced block holding the full update for this direction.

    Reads the module-level ``dagger``, ``domain_wall`` and ``projectors``.
    Unused lookup tables from the original were dropped, backtick repr was
    replaced with ``str()``, and ``dir/2`` with ``dir//2`` so that indices
    remain integers on Python 3.
    """
    mu = str(dir // 2)
    projIdx = dir if not dagger else dir + (+1 if dir % 2 == 0 else -1)
    projStr = projectorToStr(projectors[projIdx])

    def proj(i, j):
        return projectors[projIdx][4*i+j]

    # if row(i) = (j, c), then the i'th row of the projector can be represented
    # as a multiple of the j'th row: row(i) = c row(j)
    def row(i):
        assert i == 2 or i == 3
        if proj(i, 0) == 0j:
            return (1, proj(i, 1))
        if proj(i, 1) == 0j:
            return (0, proj(i, 0))

    def half_terms(r, c):
        # Expression fragments contributed by projector row ``r`` for index
        # ``c``; entries with both a real and an imaginary part are skipped,
        # as in the original.
        re_terms = ""
        im_terms = ""
        for s in range(0, 4):
            re = proj(r, s).real
            im = proj(r, s).imag
            if re == 0 and im == 0:
                continue
            elif im == 0:
                re_terms += sign(re)+in_re(s,c)
                im_terms += sign(re)+in_im(s,c)
            elif re == 0:
                re_terms += sign(-im)+in_im(s,c)
                im_terms += sign(im)+in_re(s,c)
        return re_terms, im_terms

    boundary = ["coord[0]==(param.dc.X[0]-1)", "coord[0]==0", "coord[1]==(param.dc.X[1]-1)", "coord[1]==0", "coord[2]==(param.dc.X[2]-1)", "coord[2]==0", "coord[3]==(param.dc.X[3]-1)", "coord[3]==0"]
    offset = ["+1", "-1", "+1", "-1", "+1", "-1", "+1", "-1"]

    cond = ""
    cond += "if (isActive(dim," + mu + "," + offset[dir] + ",coord,param.commDim,param.dc.X) && " + boundary[dir] + " )\n"

    code = ""

    projName = "P"+mu+["-","+"][projIdx%2]
    code += "// Projector "+projName+"\n"
    for l in projStr.splitlines():
        code += "//"+l+"\n"
    code += "\n"

    code += "faceIndexFromCoords<5,1>(face_idx,coord," + mu + ",param);\n"
    code += "const int sp_idx = face_idx + param.ghostOffset[" + mu + "][" + str(1-dir%2) + "];\n"
    code += "#if (DD_PREC==2) // half precision\n"
    code += " sp_norm_idx = face_idx + "
    code += "param.ghostNormOffset[" + mu + "][" + str(1-dir%2) + "];\n"
    code += "#endif\n\n"
    code += "\n"
    if dir % 2 == 0:
        if domain_wall: code += "const int ga_idx = sid % param.dc.volume_4d_cb;\n"
        else: code += "const int ga_idx = sid;\n"
    else:
        if domain_wall: code += "const int ga_idx = param.dc.volume_4d_cb+(face_idx % param.dc.ghostFace[" + mu + "]);\n"
        else: code += "const int ga_idx = param.dc.volume_4d_cb+face_idx;\n"
    code += "\n"

    # scan the projector to determine which loads are required
    row_cnt = [0, 0, 0, 0]
    for h in range(0, 4):
        for s in range(0, 4):
            entry = proj(h, s)
            if entry.real != 0 or entry.imag != 0:
                row_cnt[h] += 1
    # row_cnt[0] / row_cnt[2] now count the non-zero entries of the upper /
    # lower pair of rows.
    row_cnt[0] += row_cnt[1]
    row_cnt[2] += row_cnt[3]

    decl_half = ""
    for h in range(0, 2):
        for c in range(0, 3):
            decl_half += "spinorFloat "+h1_re(h,c)+", "+h1_im(h,c)+";\n"
    decl_half += "\n"

    load_spinor = "// read spinor from device memory\n"
    if row_cnt[0] == 0:
        load_spinor += "READ_SPINOR_DOWN(SPINORTEX, param.sp_stride, sp_idx, sp_idx);\n"
    elif row_cnt[2] == 0:
        load_spinor += "READ_SPINOR_UP(SPINORTEX, param.sp_stride, sp_idx, sp_idx);\n"
    else:
        load_spinor += "READ_SPINOR(SPINORTEX, param.sp_stride, sp_idx, sp_idx);\n"
    load_spinor += "\n"

    load_half = ""
    if domain_wall:
        load_half += "const int sp_stride_pad = param.dc.Ls*param.dc.ghostFace[" + mu + "];\n"
    else:
        load_half += "const int sp_stride_pad = param.dc.ghostFace[" + mu + "];\n"

    if dir >= 6: load_half += "const int t_proj_scale = TPROJSCALE;\n"
    load_half += "\n"
    load_half += "// read half spinor from device memory\n"

    # we have to use the same volume index for backwards and forwards gathers
    # instead of using READ_UP_SPINOR and READ_DOWN_SPINOR, just use READ_HALF_SPINOR with the appropriate shift
    load_half += "READ_SPINOR_GHOST(GHOSTSPINORTEX, sp_stride_pad, sp_idx, sp_norm_idx, "+str(dir)+");\n\n"

    load_gauge = "// read gauge matrix from device memory\n"
    if domain_wall:
        load_gauge += "if ( ! s_parity ) { ASSN_GAUGE_MATRIX(G, GAUGE"+str(dir%2)+"TEX, "+str(dir)+", ga_idx, param.gauge_stride); }\n"
        load_gauge += "else { ASSN_GAUGE_MATRIX(G, GAUGE"+str(1-dir%2)+"TEX, "+str(dir)+", ga_idx, param.gauge_stride); }\n\n"
    else:
        load_gauge += "READ_GAUGE_MATRIX(G, GAUGE"+str(dir%2)+"TEX, "+str(dir)+", ga_idx, param.gauge_stride);\n\n"

    reconstruct_gauge = "// reconstruct gauge matrix\n"
    reconstruct_gauge += "RECONSTRUCT_GAUGE_MATRIX("+str(dir)+");\n\n"

    project = "// project spinor into half spinors\n"
    for h in range(0, 2):
        for c in range(0, 3):
            strRe, strIm = half_terms(h, c)
            if row_cnt[0] == 0: # projector defined on lower half only
                lowRe, lowIm = half_terms(h+2, c)
                strRe += lowRe
                strIm += lowIm

            project += h1_re(h,c)+" = "+strRe+";\n"
            project += h1_im(h,c)+" = "+strIm+";\n"

    scale = "t_proj_scale*" if (dir >= 6) else ""
    copy_half = ""
    for h in range(0, 2):
        for c in range(0, 3):
            copy_half += h1_re(h,c)+" = "+scale+in_re(h,c)+"; "
            copy_half += h1_im(h,c)+" = "+scale+in_im(h,c)+";\n"
    copy_half += "\n"

    prep_half = ""
    prep_half += "{\n"
    prep_half += "\n"
    prep_half += indent(load_half)
    prep_half += indent(copy_half)
    prep_half += "}\n"
    prep_half += "\n"

    ident = "// identity gauge matrix\n"
    for m in range(0, 3):
        for h in range(0, 2):
            ident += "spinorFloat "+h2_re(h,m)+" = " + h1_re(h,m) + "; "
            ident += "spinorFloat "+h2_im(h,m)+" = " + h1_im(h,m) + ";\n"
    ident += "\n"

    mult = ""
    for m in range(0, 3):
        mult += "// multiply row "+str(m)+"\n"
        for h in range(0, 2):
            re_code = "spinorFloat "+h2_re(h,m)+" = 0;\n"
            im_code = "spinorFloat "+h2_im(h,m)+" = 0;\n"
            for c in range(0, 3):
                re_code += h2_re(h,m) + " += " + g_re(dir,m,c) + " * "+h1_re(h,c)+";\n"
                re_code += h2_re(h,m) + " -= " + g_im(dir,m,c) + " * "+h1_im(h,c)+";\n"
                im_code += h2_im(h,m) + " += " + g_re(dir,m,c) + " * "+h1_im(h,c)+";\n"
                im_code += h2_im(h,m) + " += " + g_im(dir,m,c) + " * "+h1_re(h,c)+";\n"
            mult += re_code + im_code
        mult += "\n"

    reconstruct = ""
    for m in range(0, 3):

        for h in range(0, 2):
            h_out = h
            if row_cnt[0] == 0: # projector defined on lower half only
                h_out = h+2
            reconstruct += out_re(h_out, m) + " += " + h2_re(h,m) + ";\n"
            reconstruct += out_im(h_out, m) + " += " + h2_im(h,m) + ";\n"

        for s in range(2, 4):
            (h, c) = row(s)
            re = c.real
            im = c.imag
            if im == 0 and re == 0:
                pass
            elif im == 0:
                reconstruct += out_re(s, m) + " " + sign(re) + "= " + h2_re(h,m) + ";\n"
                reconstruct += out_im(s, m) + " " + sign(re) + "= " + h2_im(h,m) + ";\n"
            elif re == 0:
                reconstruct += out_re(s, m) + " " + sign(-im) + "= " + h2_im(h,m) + ";\n"
                reconstruct += out_im(s, m) + " " + sign(+im) + "= " + h2_re(h,m) + ";\n"

        if m < 2: reconstruct += "\n"

    if dir >= 6:
        # For dir 6 and 7 the generated code can skip the link multiplication
        # at run time when param.gauge_fixed holds.
        code += "if (param.gauge_fixed && ga_idx < param.dc.X4X3X2X1hmX3X2X1h)\n"
        code += block(decl_half + prep_half + ident + reconstruct)
        code += " else "
        code += block(load_gauge + decl_half + prep_half + reconstruct_gauge + mult + reconstruct)
    else:
        code += load_gauge + decl_half + prep_half + reconstruct_gauge + mult + reconstruct

    if pack_only:
        out = load_spinor + decl_half + project
        out = out.replace("sp_idx", "idx")
        return out
    else:
        return cond + block(code)+"\n\n"
# end def gen
694 # end def gen
695 
696 
697 
698 
def input_spinor(s,c,z):
    """Return the name of the real (``z == 0``) or imaginary part for indices ``s``, ``c``.

    The output names are used when the module-level ``dslash`` is set,
    otherwise the input names.
    """
    if dslash:
        real_name, imag_name = out_re, out_im
    else:
        real_name, imag_name = in_re, in_im
    return real_name(s, c) if z == 0 else imag_name(s, c)
706 
708  str = ""
709  str += "spinorFloat "+a_re(0,0,c)+" = -"+input_spinor(1,c,0)+" - "+input_spinor(3,c,0)+";\n"
710  str += "spinorFloat "+a_im(0,0,c)+" = -"+input_spinor(1,c,1)+" - "+input_spinor(3,c,1)+";\n"
711  str += "spinorFloat "+a_re(0,1,c)+" = "+input_spinor(0,c,0)+" + "+input_spinor(2,c,0)+";\n"
712  str += "spinorFloat "+a_im(0,1,c)+" = "+input_spinor(0,c,1)+" + "+input_spinor(2,c,1)+";\n"
713  str += "spinorFloat "+a_re(0,2,c)+" = -"+input_spinor(1,c,0)+" + "+input_spinor(3,c,0)+";\n"
714  str += "spinorFloat "+a_im(0,2,c)+" = -"+input_spinor(1,c,1)+" + "+input_spinor(3,c,1)+";\n"
715  str += "spinorFloat "+a_re(0,3,c)+" = "+input_spinor(0,c,0)+" - "+input_spinor(2,c,0)+";\n"
716  str += "spinorFloat "+a_im(0,3,c)+" = "+input_spinor(0,c,1)+" - "+input_spinor(2,c,1)+";\n"
717  str += "\n"
718 
719  for s in range (0,4):
720  str += out_re(s,c)+" = "+a_re(0,s,c)+"; "
721  str += out_im(s,c)+" = "+a_im(0,s,c)+";\n"
722 
723  return block(str)+"\n\n"
724 # end def to_chiral_basis
725 
726 
def from_chiral_basis(c): # note: factor of 1/2 is included in clover term normalization
    """Return a braced C block that recombines the four output components for index ``c``.

    Each temporary a<s><c> is a signed sum of two output components; the
    temporaries are then copied back over the outputs.
    """
    # (target s, leading sign, first source, operator, second source)
    combos = [
        (0, "", 1, " + ", 3),
        (1, "-", 0, " - ", 2),
        (2, "", 1, " - ", 3),
        (3, "-", 0, " + ", 2),
    ]
    lines = []
    for s, lead, p, op, q in combos:
        lines.append("spinorFloat "+a_re(0,s,c)+" = "+lead+out_re(p,c)+op+out_re(q,c)+";\n")
        lines.append("spinorFloat "+a_im(0,s,c)+" = "+lead+out_im(p,c)+op+out_im(q,c)+";\n")
    lines.append("\n")

    for s in range(0, 4):
        lines.append(out_re(s,c)+" = "+a_re(0,s,c)+"; ")
        lines.append(out_im(s,c)+" = "+a_im(0,s,c)+";\n")

    return block("".join(lines))+"\n\n"
# end def from_chiral_basis
744 # end def from_chiral_basis
745 
746 
def clover_mult(chi):
    """Return a braced C block applying one 6x6 clover block (``chi`` is 0 or 1).

    The block reads the clover data, accumulates the complex matrix-vector
    product into temporaries a<..>, and copies the temporaries back over the
    output components ``2*chi`` and ``2*chi+1``.  On the diagonal only the
    real coefficient is used, so the imaginary terms are not emitted there.

    The accumulator was renamed from ``str`` so the builtin is available for
    formatting ``chi`` (the original used Python-2-only backticks).
    """
    code = "READ_CLOVER(CLOVERTEX, "+str(chi)+")\n\n"

    for s in range(0, 2):
        for c in range(0, 3):
            code += "spinorFloat "+a_re(chi,s,c)+" = 0; spinorFloat "+a_im(chi,s,c)+" = 0;\n"
    code += "\n"

    for sm in range(0, 2):
        for cm in range(0, 3):
            for sn in range(0, 2):
                for cn in range(0, 3):
                    off_diagonal = (sn != sm) or (cn != cm)
                    code += a_re(chi,sm,cm)+" += "+c_re(chi,sm,cm,sn,cn)+" * "+out_re(2*chi+sn,cn)+";\n"
                    if off_diagonal:
                        code += a_re(chi,sm,cm)+" -= "+c_im(chi,sm,cm,sn,cn)+" * "+out_im(2*chi+sn,cn)+";\n"
                    code += a_im(chi,sm,cm)+" += "+c_re(chi,sm,cm,sn,cn)+" * "+out_im(2*chi+sn,cn)+";\n"
                    if off_diagonal:
                        code += a_im(chi,sm,cm)+" += "+c_im(chi,sm,cm,sn,cn)+" * "+out_re(2*chi+sn,cn)+";\n"
            # NOTE(review): blank line after each (sm, cm) row — nesting was
            # reconstructed from a flattened listing; confirm against upstream.
            code += "\n"

    for s in range(0, 2):
        for c in range(0, 3):
            code += out_re(2*chi+s,c)+" = "+a_re(chi,s,c)+"; "
            code += out_im(2*chi+s,c)+" = "+a_im(chi,s,c)+";\n"
    code += "\n"

    return block(code)+"\n\n"
# end def clover_mult
776 # end def clover_mult
777 
778 
780  if domain_wall: return ""
781  str = ""
782  if dslash: str += "#ifdef DSLASH_CLOVER\n\n"
783  str += "// change to chiral basis\n"
785  str += "// apply first chiral block\n"
786  str += clover_mult(0)
787  str += "// apply second chiral block\n"
788  str += clover_mult(1)
789  str += "// change back from chiral basis\n"
790  str += "// (note: required factor of 1/2 is included in clover term normalization)\n"
792  if dslash: str += "#endif // DSLASH_CLOVER\n\n"
793 
794  return str
795 # end def clover
796 
797 
def xpay_lmem():
    """Return the DSLASH_XPAY section that scales every output component by ``a``.

    ``a`` is taken from ``param.a`` or ``param.a_f`` depending on
    SPINOR_DOUBLE; the scaling statements themselves are the same in both
    preprocessor branches.
    """
    scale = ""
    for s in range(0, 4):
        for c in range(0, 3):
            for name in (out_re(s, c), out_im(s, c)):
                scale += " " + name + " = a*" + name + ";\n"

    pieces = [
        "#ifdef DSLASH_XPAY\n",
        "#ifdef SPINOR_DOUBLE\n",
        "spinorFloat a = param.a;\n",
        "#else\n",
        "spinorFloat a = param.a_f;\n",
        "#endif\n",
        "#ifdef SPINOR_DOUBLE\n",
        scale,
        "#else\n",
        scale,
        "#endif // SPINOR_DOUBLE\n\n",
        "#endif // DSLASH_XPAY\n",
    ]
    return "".join(pieces)
# end def xpay_lmem
828 # end def xpay_lmem
829 
830 
def epilog():
    """Return the text that closes a generated kernel body.

    Emits a braced block holding the twisted or clover term followed by the
    xpay scaling, the spinor write-back, and ``#undef`` lines for the macros
    introduced earlier, ending with the ``#endif`` of MULTI_GPU.

    Reads the module-level ``twist``, ``dslash``, ``clover`` and
    ``sharedFloats``.  Index arithmetic uses ``//`` so the clover macro
    names stay integral on Python 3.
    """
    code = ""
    code += block( "\n" + (twisted() if twist else apply_clover()) + xpay_lmem() )

    code += "\n\n"
    code += "// write spinor field back to device memory\n"
    code += "WRITE_SPINOR(param.sp_stride);\n\n"

    code += "// undefine to prevent warning when precision is changed\n"
    code += "#undef spinorFloat\n"
    code += "#undef SHARED_STRIDE\n\n"

    if dslash:
        for m in range(0, 3):
            for n in range(0, 3):
                code += "#undef "+g_re(0,m,n)+"\n"
                code += "#undef "+g_im(0,m,n)+"\n"
        code += "\n"

    for s in range(0, 4):
        for c in range(0, 3):
            code += "#undef "+in_re(s,c)+"\n"
            code += "#undef "+in_im(s,c)+"\n"
    code += "\n"

    if clover == True:
        # Diagonal entries have a real part only.
        for m in range(0, 6):
            s = m // 3
            c = m % 3
            code += "#undef "+c_re(0,s,c,s,c)+"\n"
        # Entries with m > n have both parts.
        for n in range(0, 6):
            sn = n // 3
            cn = n % 3
            for m in range(n+1, 6):
                sm = m // 3
                cm = m % 3
                code += "#undef "+c_re(0,sm,cm,sn,cn)+"\n"
                code += "#undef "+c_im(0,sm,cm,sn,cn)+"\n"
        code += "\n"

    # Only the output names below the sharedFloats threshold are undefined.
    for s in range(0, 4):
        for c in range(0, 3):
            i = 3*s+c
            if 2*i < sharedFloats:
                code += "#undef "+out_re(s,c)+"\n"
            if 2*i+1 < sharedFloats:
                code += "#undef "+out_im(s,c)+"\n"
    code += "\n"

    code += "#undef VOLATILE\n"

    code += "\n"
    code += "#endif // MULTI_GPU\n"
    return code
# end def epilog
887 # end def epilog
888 
889 
def pack_face(facenum):
    """Return a C ``switch(dim)`` whose four cases project and write a half spinor.

    Case ``d`` uses the projection code of ``gen(2*d + facenum, pack_only=True)``
    followed by a WRITE_HALF_SPINOR call.

    The accumulator was renamed from ``str`` so the builtin is available for
    formatting the case label (the original used Python-2-only backticks).
    """
    code = "\n"
    code += "switch(dim) {\n"
    for dim in range(0, 4):
        code += "case "+str(dim)+":\n"
        proj = gen(2*dim+facenum, pack_only=True)
        proj += "\n"
        proj += "// write half spinor back to device memory\n"
        proj += "WRITE_HALF_SPINOR(face_volume, face_idx);\n"
        code += indent(block(proj)+"\n"+"break;\n")
    code += "}\n\n"
    return code
# end def pack_face
902 # end def pack_face
903 
905  assert (sharedFloats == 0)
906  str = ""
907  str += def_input_spinor()
908  str += "#include \"io_spinor.h\"\n\n"
909 
910  str += "if (face_num) "
911  str += block(pack_face(1))
912  str += " else "
913  str += block(pack_face(0))
914 
915  str += "\n\n"
916  str += "// undefine to prevent warning when precision is changed\n"
917  str += "#undef spinorFloat\n"
918  str += "#undef SHARED_STRIDE\n\n"
919 
920  for s in range(0,4):
921  for c in range(0,3):
922  i = 3*s+c
923  str += "#undef "+in_re(s,c)+"\n"
924  str += "#undef "+in_im(s,c)+"\n"
925  str += "\n"
926 
927  return str
928 # end def generate_pack
929 
930 
932  r = prolog()
933  for i in range(0,8) :
934  r += gen( i )
935  r += epilog()
936  return r
937 
939  return prolog() + epilog()
940 
941 
# To fit 192 threads/SM (single precision) with 16K shared memory, set sharedFloats to 19 or smaller

sharedFloats = 0
cloverSharedFloats = 0
if len(sys.argv) > 1:
    if sys.argv[1] == '--shared':
        # Fail with a clear message instead of an IndexError when the value is missing.
        if len(sys.argv) < 3:
            sys.exit(sys.argv[0] + ": --shared requires an integer argument")
        sharedFloats = int(sys.argv[2])
print("Shared floats set to " + str(sharedFloats))

# generate Domain Wall Dslash kernels
# These module-level switches are read by the generator functions.
domain_wall = True
twist = False
clover = False

# The progress messages now name the files that are actually written.
print(sys.argv[0] + ": generating dw_fused_exterior_dslash_core.h")
dslash = True
dagger = False
with open('dslash_core/dw_fused_exterior_dslash_core.h', 'w') as f:
    f.write(generate_dslash())

print(sys.argv[0] + ": generating dw_fused_exterior_dslash_dagger_core.h")
dslash = True
dagger = True
with open('dslash_core/dw_fused_exterior_dslash_dagger_core.h', 'w') as f:
    f.write(generate_dslash())
970 
971 
972 
973 
974 
def indent(code, n=1)
code generation ######################################################################## ...
Definition: gen.py:1
def complexify(a)
complex numbers ######################################################################## ...
if(err !=cudaSuccess)