QUDA  0.9.0
dw_dslash_4D_cuda_gen.py
Go to the documentation of this file.
1 # -*- coding: utf-8 -*-
2 import sys
3 
4 
5 
def complexify(a):
    """Return a new list with every element of the iterable *a* converted to complex."""
    return list(map(complex, a))
8 
def complexToStr(c):
    """Render a complex number compactly, e.g. for projector comments.

    Whole-number parts are printed without a decimal point, zero parts are
    omitted, and a unit imaginary part is printed as ``i`` / ``-i``:
    ``0 -> "0"``, ``2 -> "2"``, ``-1j -> "-i"``, ``1+2j -> "1+2i"``,
    ``0.5-0.5j -> "0.5-0.5i"``.
    """
    def fltToString(a):
        # repr() is exactly what the Python-2-only backtick syntax produced.
        if a == int(a):
            return repr(int(a))
        return repr(a)

    def imToString(a):
        if a == 0:
            return "0i"
        if a == -1:
            return "-i"
        if a == 1:
            return "i"
        return fltToString(a) + "i"

    real = c.real
    imag = c.imag
    if real == 0 and imag == 0:
        return "0"
    if real == 0:
        return imToString(imag)
    if imag == 0:
        return fltToString(real)
    # Print the sign explicitly and the magnitude of the imaginary part.
    im_str = "-" + imToString(-imag) if imag < 0 else "+" + imToString(imag)
    return fltToString(real) + im_str
28 
29 
30 
31 
# Spin-space matrices, each stored row-major as a flat list of 16 complex
# numbers: entry (row, col) lives at index 4*row + col.
#
# NOTE: `id` shadows the builtin id(); the name is kept because the rest of
# the generator refers to it.


def _spin_matrix(*rows):
    """Flatten the given four rows of a 4x4 matrix into a row-major complex list."""
    return complexify([entry for row in rows for entry in row])


id = _spin_matrix(
    (1, 0, 0, 0),
    (0, 1, 0, 0),
    (0, 0, 1, 0),
    (0, 0, 0, 1),
)

gamma1 = _spin_matrix(
    (0, 0, 0, 1j),
    (0, 0, 1j, 0),
    (0, -1j, 0, 0),
    (-1j, 0, 0, 0),
)

gamma2 = _spin_matrix(
    (0, 0, 0, 1),
    (0, 0, -1, 0),
    (0, -1, 0, 0),
    (1, 0, 0, 0),
)

gamma3 = _spin_matrix(
    (0, 0, 1j, 0),
    (0, 0, 0, -1j),
    (-1j, 0, 0, 0),
    (0, 1j, 0, 0),
)

gamma4 = _spin_matrix(
    (1, 0, 0, 0),
    (0, 1, 0, 0),
    (0, 0, -1, 0),
    (0, 0, 0, -1),
)

# i * gamma5
igamma5 = _spin_matrix(
    (0, 0, 1j, 0),
    (0, 0, 0, 1j),
    (1j, 0, 0, 0),
    (0, 1j, 0, 0),
)

# Twice the chiral projectors, 2 P_L = 1 - gamma5 and 2 P_R = 1 + gamma5,
# with gamma5 recovered element-wise as igamma5 / 1j.
two_P_L = list(one - ig5 / 1j for one, ig5 in zip(id, igamma5))
two_P_R = list(one + ig5 / 1j for one, ig5 in zip(id, igamma5))
76 
77 # for s1 in range(0,4) :
78 # for s2 in range (0,4): print "%8s" % two_P_L[s1*4+s2],
79 # print " ",
80 # for s2 in range (0,4): print "%8s" % two_P_R[s1*4+s2],
81 # print ""
82 
83 
def _elementwise(op, g1, g2):
    """Apply the binary callable *op* pairwise to two flattened spin matrices.

    Pairs are formed with zip(), so the result is as long as the shorter input.
    """
    return [op(left, right) for left, right in zip(g1, g2)]


def gplus(g1, g2):
    """Return the element-wise sum g1 + g2 of two flattened spin matrices."""
    return _elementwise(lambda left, right: left + right, g1, g2)


def gminus(g1, g2):
    """Return the element-wise difference g1 - g2 of two flattened spin matrices."""
    return _elementwise(lambda left, right: left - right, g1, g2)
89 
def projectorToStr(p):
    """Render a flattened 4x4 spin matrix as a four-line text table.

    Entry ``p[4*i + j]`` is formatted with complexToStr and right-aligned in
    a field of at least 3 characters; every row, including the last, ends
    with a newline.  The result is echoed into comments of the generated code.
    """
    rows = []
    for i in range(4):
        cells = ["%3s" % complexToStr(p[4 * i + j]) for j in range(4)]
        rows.append("".join(cells) + "\n")
    return "".join(rows)
97 
# Spin projectors indexed by 2*mu + k: even entries are (1 - gamma_mu), odd
# entries are (1 + gamma_mu), for gamma1..gamma4 in order.
projectors = list(
    combine(id, gamma)
    for gamma in (gamma1, gamma2, gamma3, gamma4)
    for combine in (gminus, gplus)
)
104 
105 
106 
def indent(code, n=1):
    """Prefix every non-empty line of *code* with *n* spaces.

    Lines that start with '#' (preprocessor directives such as ``#ifdef`` in
    the generated C source) and empty lines are left untouched.  Every output
    line, including the last, is terminated by a newline.
    """
    pad = " " * n
    pieces = []
    for line in code.splitlines():
        if line and not line.startswith("#"):
            line = pad + line
        pieces.append(line + "\n")
    return "".join(pieces)
110 
def block(code):
    """Wrap *code* in braces, with its lines passed through indent().

    The opening brace is followed by a newline; the closing brace is not.
    """
    body = indent(code)
    return "{\n%s}" % body
113 
def sign(x):
    """Return the C prefix that multiplies a term by the coefficient *x*.

    Supported coefficients: +1 -> "+", -1 -> "-", +2 -> "+2*", -2 -> "-2*".

    Raises:
        ValueError: for any other coefficient.  Raising here makes an
            unexpected projector entry fail where it occurs, rather than in a
            later string concatenation with None.
    """
    if x == 1:
        return "+"
    if x == -1:
        return "-"
    if x == 2:
        return "+2*"
    if x == -2:
        return "-2*"
    raise ValueError("unsupported coefficient for sign(): %r" % (x,))
119 
def nthFloat4(n):
    """Return the accessor of scalar *n* packed into four-component vectors.

    Scalar n is component x/y/z/w (by n % 4) of vector n // 4,
    e.g. ``nthFloat4(5) == "1.y"``.
    """
    # `//` keeps integer division under Python 3 as well as Python 2.
    return str(n // 4) + "." + "xyzw"[n % 4]


def nthFloat2(n):
    """Return the accessor of scalar *n* packed into two-component vectors.

    Scalar n is component x/y (by n % 2) of vector n // 2,
    e.g. ``nthFloat2(3) == "1.y"``.
    """
    return str(n // 2) + "." + "xy"[n % 2]
125 
126 
# Names of the scalar C variables used by the generated kernels.  Indices are
# written in decimal and concatenated without separators, so they are
# expected to be single digits.


def in_re(s, c):
    """Input spinor, spin s, colour c, real part (e.g. ``i12_re``)."""
    return "i" + str(s) + str(c) + "_re"


def in_im(s, c):
    """Input spinor, spin s, colour c, imaginary part."""
    return "i" + str(s) + str(c) + "_im"


def g_re(d, m, n):
    """Gauge link entry (m, n), real part: ``g`` for even d, ``gT`` for odd d."""
    return ("g" if d % 2 == 0 else "gT") + str(m) + str(n) + "_re"


def g_im(d, m, n):
    """Gauge link entry (m, n), imaginary part: ``g`` for even d, ``gT`` for odd d."""
    return ("g" if d % 2 == 0 else "gT") + str(m) + str(n) + "_im"


def out_re(s, c):
    """Output spinor, spin s, colour c, real part (e.g. ``o00_re``)."""
    return "o" + str(s) + str(c) + "_re"


def out_im(s, c):
    """Output spinor, spin s, colour c, imaginary part."""
    return "o" + str(s) + str(c) + "_im"


def h1_re(h, c):
    """Half spinor h (0 -> ``a``, 1 -> ``b``), colour c, real part."""
    return ["a", "b"][h] + str(c) + "_re"


def h1_im(h, c):
    """Half spinor h (0 -> ``a``, 1 -> ``b``), colour c, imaginary part."""
    return ["a", "b"][h] + str(c) + "_im"


def h2_re(h, c):
    """Half spinor h (0 -> ``A``, 1 -> ``B``), colour c, real part."""
    return ["A", "B"][h] + str(c) + "_re"


def h2_im(h, c):
    """Half spinor h (0 -> ``A``, 1 -> ``B``), colour c, imaginary part."""
    return ["A", "B"][h] + str(c) + "_im"


def c_re(b, sm, cm, sn, cn):
    """Clover entry ((sm, cm), (sn, cn)) of chiral block b (spins offset by 2*b), real part."""
    return "c" + str(sm + 2 * b) + str(cm) + "_" + str(sn + 2 * b) + str(cn) + "_re"


def c_im(b, sm, cm, sn, cn):
    """Clover entry ((sm, cm), (sn, cn)) of chiral block b (spins offset by 2*b), imaginary part."""
    return "c" + str(sm + 2 * b) + str(cm) + "_" + str(sn + 2 * b) + str(cn) + "_im"


def a_re(b, s, c):
    """``a``-prefixed variable for spin s + 2*b, colour c, real part."""
    return "a" + str(s + 2 * b) + str(c) + "_re"


def a_im(b, s, c):
    """``a``-prefixed variable for spin s + 2*b, colour c, imaginary part."""
    return "a" + str(s + 2 * b) + str(c) + "_im"


def tmp_re(s, c):
    """Temporary variable for spin s, colour c, real part."""
    return "tmp" + str(s) + str(c) + "_re"


def tmp_im(s, c):
    """Temporary variable for spin s, colour c, imaginary part."""
    return "tmp" + str(s) + str(c) + "_im"
144 
145 
def def_input_spinor():
    """Return the C preprocessor header for the input spinor and parameters.

    Under SPINOR_DOUBLE, spinorFloat is double and the 12 complex input
    components i<s><c>_re/_im are mapped onto two-component (.x/.y)
    accessors of I; otherwise spinorFloat is float and four-component
    (.x/.y/.z/.w) accessors are used.  The m5 / mdwf_b5 / mdwf_c5 / mferm
    aliases select the matching-precision fields of ``param``.
    """
    def spinor_defines(nth):
        # Real and imaginary parts of component (s, c) occupy slots 2i, 2i+1.
        out = ""
        for s in range(4):
            for c in range(3):
                i = 3 * s + c
                out += "#define " + in_re(s, c) + " I" + nth(2 * i) + "\n"
                out += "#define " + in_im(s, c) + " I" + nth(2 * i + 1) + "\n"
        return out

    parts = [
        "// input spinor\n",
        "#ifdef SPINOR_DOUBLE\n",
        "#define spinorFloat double\n",
        "// workaround for C++11 bug in CUDA 6.5/7.0\n",
        "#if CUDA_VERSION >= 6050 && CUDA_VERSION < 7050\n",
        "#define POW(a, b) pow(a, static_cast<spinorFloat>(b))\n",
        "#else\n",
        "#define POW(a, b) pow(a, b)\n",
        "#endif\n\n",
        spinor_defines(nthFloat2),
        "#define m5 param.m5_d\n",
        "#define mdwf_b5 param.mdwf_b5_d\n",
        "#define mdwf_c5 param.mdwf_c5_d\n",
        "#define mferm param.mferm\n",
        "#define a param.a\n",
        "#define b param.b\n",
        "#else\n",
        "#define spinorFloat float\n",
        "#define POW(a, b) __fast_pow(a, b)\n",
        spinor_defines(nthFloat4),
        "#define m5 param.m5_f\n",
        "#define mdwf_b5 param.mdwf_b5_f\n",
        "#define mdwf_c5 param.mdwf_c5_f\n",
        "#define mferm param.mferm_f\n",
        "#define a param.a\n",
        "#define b param.b\n",
        "#endif // SPINOR_DOUBLE\n\n",
    ]
    return "".join(parts)
184 # end def def_input_spinor
185 
186 
def def_gauge():
    """Return the C preprocessor header naming the gauge-link components.

    The nine complex entries g<m><n> (m, n in 0..2) are aliased onto
    consecutive components of the packed link G: two-component (.x/.y)
    accessors under GAUGE_FLOAT2, four-component (.x/.y/.z/.w) accessors
    otherwise.  gT<m><n> is defined as the conjugate transpose,
    gT<m><n> = conj(g<n><m>).
    """
    def link_defines(nth):
        # Real and imaginary parts of entry (m, n) occupy slots 2i and 2i+1.
        out = ""
        for m in range(3):
            for n in range(3):
                i = 3 * m + n
                out += "#define " + g_re(0, m, n) + " G" + nth(2 * i) + "\n"
                out += "#define " + g_im(0, m, n) + " G" + nth(2 * i + 1) + "\n"
        return out

    code = "// gauge link\n"
    code += "#ifdef GAUGE_FLOAT2\n"
    code += link_defines(nthFloat2)
    code += "\n"
    code += "#else\n"
    code += link_defines(nthFloat4)
    code += "\n"
    # The closing comment names the macro tested by the matching #ifdef.
    code += "#endif // GAUGE_FLOAT2\n\n"

    code += "// conjugated gauge link\n"
    for m in range(3):
        for n in range(3):
            code += "#define " + g_re(1, m, n) + " (+" + g_re(0, n, m) + ")\n"
            code += "#define " + g_im(1, m, n) + " (-" + g_im(0, n, m) + ")\n"
    code += "\n"

    return code
216 # end def def_gauge
217 
218 
def def_clover():
    """Return the C preprocessor header naming the inverted clover-term entries.

    The first chiral block is indexed by (spin, colour) pairs with spin 0..1
    and colour 0..2 (6 x 6 entries).  Its 6 diagonal entries (real part only)
    and then the 15 complex entries c(m, n) with m > n are mapped, in that
    order, onto consecutive packed components of C.  Two-component (.x/.y)
    accessors are used under CLOVER_DOUBLE, four-component ones otherwise.
    Entries with m < n are defined as the complex conjugate of c(n, m).  The
    second chiral block's names are aliased to the first block's names.
    """
    def packed_defines(nth):
        # Diagonal (real) entries first, then the complex entries with m > n.
        out = ""
        i = 0
        for m in range(6):
            s, c = divmod(m, 3)
            out += "#define " + c_re(0, s, c, s, c) + " C" + nth(i) + "\n"
            i += 1
        for n in range(6):
            sn, cn = divmod(n, 3)
            for m in range(n + 1, 6):
                sm, cm = divmod(m, 3)
                out += "#define " + c_re(0, sm, cm, sn, cn) + " C" + nth(i) + "\n"
                out += "#define " + c_im(0, sm, cm, sn, cn) + " C" + nth(i + 1) + "\n"
                i += 2
        return out

    code = "// first chiral block of inverted clover term\n"
    code += "#ifdef CLOVER_DOUBLE\n"
    code += packed_defines(nthFloat2)
    code += "#else\n"
    code += packed_defines(nthFloat4)
    code += "#endif // CLOVER_DOUBLE\n\n"

    # Entries with m < n: complex conjugate of the stored (n, m) entry.
    for n in range(6):
        sn, cn = divmod(n, 3)
        for m in range(n):
            sm, cm = divmod(m, 3)
            code += "#define " + c_re(0, sm, cm, sn, cn) + " (+" + c_re(0, sn, cn, sm, cm) + ")\n"
            code += "#define " + c_im(0, sm, cm, sn, cn) + " (-" + c_im(0, sn, cn, sm, cm) + ")\n"
    code += "\n"

    code += "// second chiral block of inverted clover term (reuses C0,...,C9)\n"
    for n in range(6):
        sn, cn = divmod(n, 3)
        for m in range(6):
            sm, cm = divmod(m, 3)
            code += "#define " + c_re(1, sm, cm, sn, cn) + " " + c_re(0, sm, cm, sn, cn) + "\n"
            if m != n:  # diagonal entries have no imaginary-part name
                code += "#define " + c_im(1, sm, cm, sn, cn) + " " + c_im(0, sm, cm, sn, cn) + "\n"
    code += "\n"

    return code
277 # end def def_clover
278 
def def_output_spinor():
    """Return the C declarations for the 24 real output-spinor components.

    The real / imaginary part of component (s, c) uses float slot
    k = 2*(3*s + c) / k + 1.  Slots below the module-level ``sharedFloats``
    (set outside this function) are #defined onto the shared-memory array
    ``s`` as ``s[k*SHARED_STRIDE]``; the remaining ones are declared as
    ``VOLATILE spinorFloat`` locals.
    """
    def declare(name, k):
        if k < sharedFloats:
            return "#define " + name + " s[" + str(k) + "*SHARED_STRIDE]\n"
        return "VOLATILE spinorFloat " + name + ";\n"

    code = "// output spinor\n"
    for s in range(4):
        for c in range(3):
            k = 2 * (3 * s + c)
            code += declare(out_re(s, c), k)
            code += declare(out_im(s, c), k + 1)
    return code
293 # end def def_output_spinor
294 
295 
def prolog():
    """Return the kernel prolog emitted before the hopping-term code.

    The prolog contains the precision-dependent macro headers (input spinor,
    gauge links, clover term, output spinor, SHARED_STRIDE), the per-thread
    shared-memory pointer, the thread-index and coordinate setup, the
    output-spinor initialisation and, for the 4-D hopping kernel, the
    interior/exterior kernel split.

    Reads the module-level generator flags ``dslash``, ``dslash4``,
    ``clover``, ``domain_wall``, ``dagger`` and ``sharedFloats``, which are
    set outside this function.

    Raises:
        SystemExit: if neither ``dslash`` nor ``clover`` is set.
    """
    if dslash:
        prolog_str = ("// *** CUDA DSLASH ***\n\n" if not dagger
                      else "// *** CUDA DSLASH DAGGER ***\n\n")
        prolog_str += "#define DSLASH_SHARED_FLOATS_PER_THREAD " + str(sharedFloats) + "\n\n"
    elif clover:
        prolog_str = "// *** CUDA CLOVER ***\n\n"
        prolog_str += "#define CLOVER_SHARED_FLOATS_PER_THREAD " + str(sharedFloats) + "\n\n"
    else:
        # sys.exit must actually be called: a bare `exit` name is a no-op and
        # would fall through to an unbound prolog_str below.
        sys.exit("Undefined prolog")

    prolog_str += (
"""
#if (CUDA_VERSION >= 4010)
#define VOLATILE
#else
#define VOLATILE volatile
#endif
""")

    prolog_str += def_input_spinor()
    if dslash and dslash4:
        prolog_str += def_gauge()
    if clover:
        prolog_str += def_clover()
    prolog_str += def_output_spinor()

    prolog_str += (
"""
#ifdef SPINOR_DOUBLE
#if (__COMPUTE_CAPABILITY__ >= 200)
#define SHARED_STRIDE 16 // to avoid bank conflicts on Fermi
#else
#define SHARED_STRIDE 8 // to avoid bank conflicts on G80 and GT200
#endif
#else
#if (__COMPUTE_CAPABILITY__ >= 200)
#define SHARED_STRIDE 32 // to avoid bank conflicts on Fermi
#else
#define SHARED_STRIDE 16 // to avoid bank conflicts on G80 and GT200
#endif
#endif
""")

    if sharedFloats > 0:
        prolog_str += (
"""
extern __shared__ char s_data[];
""")
        # Per-thread pointer into s_data, sized by the *_SHARED_FLOATS_PER_THREAD
        # macro emitted at the top of the prolog.
        per_thread = "DSLASH" if dslash else "CLOVER"
        prolog_str += (
            "\nVOLATILE spinorFloat *s = (spinorFloat*)s_data + " + per_thread
            + "_SHARED_FLOATS_PER_THREAD*SHARED_STRIDE*(threadIdx.x/SHARED_STRIDE)\n"
            + "+ (threadIdx.x % SHARED_STRIDE);\n")

    if dslash:
        if dslash4:
            prolog_str += "\n#include \"read_gauge.h\"\n"
        if not domain_wall:
            prolog_str += "#include \"read_clover.h\"\n"
        prolog_str += "#include \"io_spinor.h\"\n"
        prolog_str += (
"""
int sid = ((blockIdx.y*blockDim.y + threadIdx.y)*gridDim.x + blockIdx.x)*blockDim.x + threadIdx.x;
if (sid >= param.threads*param.dc.Ls) return;

""")
        # Only the 5-D domain-wall kernel needs boundaryCrossing.
        if domain_wall and not dslash4:
            prolog_str += "\nint X, coord[5], boundaryCrossing;\n"
        else:
            prolog_str += "\nint X, coord[5];\n"

        if dslash4:
            prolog_str += (
"""
int face_idx;
if (kernel_type == INTERIOR_KERNEL) {
""")
        prolog_str += "\n\n"

        # Coordinates / checkerboard index of the site handled by this thread.
        if domain_wall and dslash4:
            prolog_str += (
"""

 // Assume even dimensions
 coordsFromIndex<5,QUDA_4D_PC,EVEN_X>(X, coord, sid, param);

""")
        elif domain_wall:
            prolog_str += (
"""
boundaryCrossing = sid/param.dc.Xh[0] + sid/(param.dc.X[1]*param.dc.Xh[0]) + sid/(param.dc.X[2]*param.dc.X[1]*param.dc.Xh[0]);

X = 2*sid + (boundaryCrossing + param.parity) % 2;
coord[4] = X/(param.dc.X[0]*param.dc.X[1]*param.dc.X[2]*param.dc.X[3]);

""")
        else:
            prolog_str += (
"""
X = 2*sid;
int aux1 = X / param.dc.X[0];
x1 = X - aux1 * param.dc.X[0];
int aux2 = aux1 / param.dc.X[1];
x2 = aux1 - aux2 * param.dc.X[1];
x4 = aux2 / param.dc.X[2];
x3 = aux2 - x4 * param.dc.X[2];
aux1 = (param.parity + x4 + x3 + x2) & 1;
x1 += aux1;
X += aux1;

""")

        # Interior kernel (or the whole kernel when not dslash4): start from zero.
        zero_out = "".join(
            out_re(s, c) + " = 0; " + out_im(s, c) + " = 0;\n"
            for s in range(4) for c in range(3))
        prolog_str += indent(zero_out)

        if dslash4:
            prolog_str += (
"""
} else { // exterior kernel

const int face_volume = (param.threads*param.dc.Ls >> 1); // volume of one face
const int face_num = (sid >= face_volume); // is this thread updating face 0 or 1
face_idx = sid - face_num*face_volume; // index into the respective face

// ghostOffset is scaled to include body (includes stride) and number of FloatN arrays (SPINOR_HOP)
// face_idx not sid since faces are spin projected and share the same volume index (modulo UP/DOWN reading)
//sp_idx = face_idx + param.ghostOffset[dim];

coordsFromFaceIndex<5,QUDA_4D_PC,kernel_type,1>(X, sid, coord, face_idx, face_num, param);

READ_INTERMEDIATE_SPINOR(INTERTEX, param.sp_stride, sid, sid);
""")
            # Exterior kernel: continue accumulating onto the intermediate spinor.
            copy_in = "".join(
                out_re(s, c) + " = " + in_re(s, c) + "; " + out_im(s, c) + " = " + in_im(s, c) + ";\n"
                for s in range(4) for c in range(3))
            prolog_str += indent(copy_in)
            prolog_str += "}\n"

        if domain_wall and dslash4:
            prolog_str += (
                "\n// declare G## here and use ASSN below instead of READ\n"
                "#ifdef GAUGE_FLOAT2\n"
                "#if (DD_PREC==0) //temporal hack\n"
                + "".join("double2 G%d;\n" % k for k in range(9))
                + "#else\n"
                + "".join("float2 G%d;\n" % k for k in range(9))
                + "#endif\n"
                "#else\n"
                + "".join("float4 G%d;\n" % k for k in range(5))
                + "#endif\n\n")

        prolog_str += "\n\n"

    else:
        prolog_str += (
"""
#include "read_clover.h"
#include "io_spinor.h"

int sid = blockIdx.x*blockDim.x + threadIdx.x;
if (sid >= param.threads) return;

// read spinor from device memory
READ_SPINOR(SPINORTEX, param.sp_stride, sid, sid);

""")
    return prolog_str
516 # end def prolog
517 
518 
def gen(dir, pack_only=False):
    """Return the C source for the 4-D hopping term in direction *dir*.

    ``dir = 2*mu + k`` with mu in 0..3 (X, Y, Z, T).  Even dir hops to the
    forward neighbour and odd dir to the backward one.  Without ``dagger``,
    even dir uses the projector (1 - gamma_mu) and odd dir uses
    (1 + gamma_mu); ``dagger`` swaps them.  The generated code:

    * loads the neighbouring spinor, or the ghost half spinor in the
      exterior kernel;
    * projects it into two half spinors;
    * multiplies them by the gauge link, or by the identity for gauge-fixed
      temporal links;
    * accumulates the reconstructed result into the output spinor.

    If *pack_only* is true, only the spinor load, half-spinor declarations
    and projection code are returned, with every ``sp_idx`` renamed to
    ``idx``.

    Reads the module-level flags ``dagger`` and ``domain_wall``.
    """
    projIdx = dir if not dagger else dir + (+1 if dir % 2 == 0 else -1)
    projStr = projectorToStr(projectors[projIdx])
    mu = dir // 2  # lattice dimension of this hop (0..3 = X, Y, Z, T)

    def proj(i, j):
        return projectors[projIdx][4 * i + j]

    # if row(i) = (j, c), then the i'th row of the projector can be represented
    # as a multiple of the j'th row: row(i) = c row(j)
    def row(i):
        assert i == 2 or i == 3
        if proj(i, 0) == 0j:
            return (1, proj(i, 1))
        if proj(i, 1) == 0j:
            return (0, proj(i, 0))
        # Fail here instead of returning None into the reconstruction below.
        raise ValueError("projector row %d is not a multiple of an upper row" % i)

    boundary = ["coord[0]==(param.dc.X[0]-1)", "coord[0]==0",
                "coord[1]==(param.dc.X[1]-1)", "coord[1]==0",
                "coord[2]==(param.dc.X[2]-1)", "coord[2]==0",
                "coord[3]==(param.dc.X[3]-1)", "coord[3]==0"]
    interior = ["coord[0]<(param.dc.X[0]-1)", "coord[0]>0",
                "coord[1]<(param.dc.X[1]-1)", "coord[1]>0",
                "coord[2]<(param.dc.X[2]-1)", "coord[2]>0",
                "coord[3]<(param.dc.X[3]-1)", "coord[3]>0"]
    dim = ["X", "Y", "Z", "T"]

    # index of neighboring site when not on boundary
    neighbor = ["X+1", "X-1", "X+param.dc.X[0]", "X-param.dc.X[0]",
                "X+param.dc.X2X1", "X-param.dc.X2X1",
                "X+param.dc.X3X2X1", "X-param.dc.X3X2X1"]

    # index of neighboring site (across boundary)
    neighbor_wrap = ["X-(param.dc.X[0]-1)", "X+(param.dc.X[0]-1)",
                     "X-param.dc.X2X1mX1", "X+param.dc.X2X1mX1",
                     "X-param.dc.X3X2X1mX2X1", "X+param.dc.X3X2X1mX2X1",
                     "X-param.dc.X4X3X2X1mX3X2X1", "X+param.dc.X4X3X2X1mX3X2X1"]

    cond = "#ifdef MULTI_GPU\n"
    cond += "if ( (kernel_type == INTERIOR_KERNEL && (!param.ghostDim[" + str(mu) + "] || " + interior[dir] + ")) ||\n"
    cond += " (kernel_type == EXTERIOR_KERNEL_" + dim[mu] + " && " + boundary[dir] + ") )\n"
    cond += "#endif\n"

    code = ""

    projName = "P" + str(mu) + ["-", "+"][projIdx % 2]
    code += "// Projector " + projName + "\n"
    for l in projStr.splitlines():
        code += "//" + l + "\n"
    code += "\n"

    ghost_dir = str((dir + 1) % 2)
    code += "#ifdef MULTI_GPU\n"
    code += "const int sp_idx = (kernel_type == INTERIOR_KERNEL) ? (" + boundary[dir] + " ? " + neighbor_wrap[dir] + " : " + neighbor[dir] + ") >> 1 :\n"
    code += " face_idx + param.ghostOffset[static_cast<int>(kernel_type)][" + ghost_dir + "];\n"
    code += "#if (DD_PREC==2) // half precision\n"
    code += "const int sp_norm_idx = face_idx + param.ghostNormOffset[static_cast<int>(kernel_type)][" + ghost_dir + "];\n"
    code += "#endif\n"
    code += "#else\n"
    code += "const int sp_idx = (" + boundary[dir] + " ? " + neighbor_wrap[dir] + " : " + neighbor[dir] + ") >> 1;\n"
    code += "#endif\n"

    code += "\n"
    if dir % 2 == 0:
        if domain_wall:
            code += "const int ga_idx = sid % param.dc.volume_4d_cb;\n"
        else:
            code += "const int ga_idx = sid;\n"
    else:
        code += "#ifdef MULTI_GPU\n"
        if domain_wall:
            code += "const int ga_idx = ((kernel_type == INTERIOR_KERNEL) ? sp_idx % param.dc.volume_4d_cb : param.dc.volume_4d_cb+(face_idx % param.dc.ghostFace[static_cast<int>(kernel_type)]));\n"
        else:
            code += "const int ga_idx = ((kernel_type == INTERIOR_KERNEL) ? sp_idx : param.dc.volume_4d_cb+face_idx);\n"
        code += "#else\n"
        if domain_wall:
            code += "const int ga_idx = sp_idx % param.dc.volume_4d_cb;\n"
        else:
            code += "const int ga_idx = sp_idx;\n"
        code += "#endif\n"
    code += "\n"

    # scan the projector to determine which loads are required:
    # afterwards row_cnt[0] / row_cnt[2] count the non-zero entries in the
    # upper / lower two rows
    row_cnt = [sum(1 for s in range(4) if proj(h, s) != 0) for h in range(4)]
    row_cnt[0] += row_cnt[1]
    row_cnt[2] += row_cnt[3]

    decl_half = ""
    for h in range(2):
        for c in range(3):
            decl_half += "spinorFloat " + h1_re(h, c) + ", " + h1_im(h, c) + ";\n"
    decl_half += "\n"

    load_spinor = "// read spinor from device memory\n"
    if row_cnt[0] == 0:
        load_spinor += "READ_SPINOR_DOWN(SPINORTEX, param.sp_stride, sp_idx, sp_idx);\n"
    elif row_cnt[2] == 0:
        load_spinor += "READ_SPINOR_UP(SPINORTEX, param.sp_stride, sp_idx, sp_idx);\n"
    else:
        load_spinor += "READ_SPINOR(SPINORTEX, param.sp_stride, sp_idx, sp_idx);\n"
    load_spinor += "\n"

    load_half = ""
    if domain_wall:
        load_half += "const int sp_stride_pad = param.dc.Ls*param.dc.ghostFace[static_cast<int>(kernel_type)];\n"
    else:
        load_half += "const int sp_stride_pad = param.dc.ghostFace[static_cast<int>(kernel_type)];\n"
    if dir >= 6:
        load_half += "const int t_proj_scale = TPROJSCALE;\n"
    load_half += "\n"
    load_half += "// read half spinor from device memory\n"
    # we have to use the same volume index for backwards and forwards gathers;
    # READ_SPINOR_GHOST takes the direction to select the appropriate shift
    load_half += "READ_SPINOR_GHOST(GHOSTSPINORTEX, sp_stride_pad, sp_idx, sp_norm_idx, " + str(dir) + ");\n\n"

    load_gauge = "// read gauge matrix from device memory\n"
    load_gauge += "ASSN_GAUGE_MATRIX(G, GAUGE" + str(dir % 2) + "TEX, " + str(dir) + ", ga_idx, param.gauge_stride);\n\n"

    reconstruct_gauge = "// reconstruct gauge matrix\n"
    reconstruct_gauge += "RECONSTRUCT_GAUGE_MATRIX(" + str(dir) + ");\n\n"

    def projected_terms(prow, c):
        # C expressions (real, imag) for sum_s proj(prow, s) * in(s, c);
        # only purely real or purely imaginary coefficients are emitted.
        re_terms, im_terms = "", ""
        for s in range(4):
            p = proj(prow, s)
            if p == 0:
                continue
            if p.imag == 0:
                re_terms += sign(p.real) + in_re(s, c)
                im_terms += sign(p.real) + in_im(s, c)
            elif p.real == 0:
                re_terms += sign(-p.imag) + in_im(s, c)
                im_terms += sign(p.imag) + in_re(s, c)
        return re_terms, im_terms

    project = "// project spinor into half spinors\n"
    for h in range(2):
        for c in range(3):
            strRe, strIm = projected_terms(h, c)
            if row_cnt[0] == 0:  # projector defined on lower half only
                lower_re, lower_im = projected_terms(h + 2, c)
                strRe += lower_re
                strIm += lower_im
            project += h1_re(h, c) + " = " + strRe + ";\n"
            project += h1_im(h, c) + " = " + strIm + ";\n"

    scale = "t_proj_scale*" if dir >= 6 else ""
    copy_half = ""
    for h in range(2):
        for c in range(3):
            copy_half += h1_re(h, c) + " = " + scale + in_re(h, c) + "; "
            copy_half += h1_im(h, c) + " = " + scale + in_im(h, c) + ";\n"
    copy_half += "\n"

    prep_half = ""
    prep_half += "#ifdef MULTI_GPU\n"
    prep_half += "if (kernel_type == INTERIOR_KERNEL) {\n"
    prep_half += "#endif\n"
    prep_half += "\n"
    prep_half += indent(load_spinor)
    prep_half += indent(project)
    prep_half += "\n"
    prep_half += "#ifdef MULTI_GPU\n"
    prep_half += "} else {\n"
    prep_half += "\n"
    prep_half += indent(load_half)
    prep_half += indent(copy_half)
    prep_half += "}\n"
    prep_half += "#endif // MULTI_GPU\n"
    prep_half += "\n"

    ident = "// identity gauge matrix\n"
    for m in range(3):
        for h in range(2):
            ident += "spinorFloat " + h2_re(h, m) + " = " + h1_re(h, m) + "; "
            ident += "spinorFloat " + h2_im(h, m) + " = " + h1_im(h, m) + ";\n"
    ident += "\n"

    mult = ""
    for m in range(3):
        mult += "// multiply row " + str(m) + "\n"
        for h in range(2):
            re_part = "spinorFloat " + h2_re(h, m) + " = 0;\n"
            im_part = "spinorFloat " + h2_im(h, m) + " = 0;\n"
            for c in range(3):
                re_part += h2_re(h, m) + " += " + g_re(dir, m, c) + " * " + h1_re(h, c) + ";\n"
                re_part += h2_re(h, m) + " -= " + g_im(dir, m, c) + " * " + h1_im(h, c) + ";\n"
                im_part += h2_im(h, m) + " += " + g_re(dir, m, c) + " * " + h1_im(h, c) + ";\n"
                im_part += h2_im(h, m) + " += " + g_im(dir, m, c) + " * " + h1_re(h, c) + ";\n"
            mult += re_part + im_part
        mult += "\n"

    reconstruct = ""
    for m in range(3):
        for h in range(2):
            # projector defined on lower half only -> write the lower rows
            h_out = h + 2 if row_cnt[0] == 0 else h
            reconstruct += out_re(h_out, m) + " += " + h2_re(h, m) + ";\n"
            reconstruct += out_im(h_out, m) + " += " + h2_im(h, m) + ";\n"

        for s in range(2, 4):
            h, coef = row(s)
            cre = coef.real
            cim = coef.imag
            if cim == 0 and cre == 0:
                continue
            if cim == 0:
                reconstruct += out_re(s, m) + " " + sign(cre) + "= " + h2_re(h, m) + ";\n"
                reconstruct += out_im(s, m) + " " + sign(cre) + "= " + h2_im(h, m) + ";\n"
            elif cre == 0:
                reconstruct += out_re(s, m) + " " + sign(-cim) + "= " + h2_im(h, m) + ";\n"
                reconstruct += out_im(s, m) + " " + sign(+cim) + "= " + h2_re(h, m) + ";\n"

        if m < 2:
            reconstruct += "\n"

    if dir >= 6:
        code += "if (param.gauge_fixed && ga_idx < param.dc.X4X3X2X1hmX3X2X1h)\n"
        code += block(decl_half + prep_half + ident + reconstruct)
        code += " else "
        code += block(load_gauge + decl_half + prep_half + reconstruct_gauge + mult + reconstruct)
    else:
        code += load_gauge + decl_half + prep_half + reconstruct_gauge + mult + reconstruct

    if pack_only:
        out = load_spinor + decl_half + project
        return out.replace("sp_idx", "idx")
    return cond + block(code) + "\n\n"
747 # end def gen
748 
def gen_dw():
    """Return the C source for the hop in the fifth (s) dimension.

    2 P_L is applied to the spinor read from one s-neighbour and 2 P_R to the
    spinor read from the other; ``dagger`` selects which side is which.  On
    the edge slice (``coord[4]`` equal to the boundary value), the neighbour
    index wraps around and the contribution is scaled by ``-mferm``.  When
    ``normalDWF`` is set, the block is restricted to the interior kernel
    under MULTI_GPU.  When ``dslash5`` is set, the MDWF combination with the
    b_5 / c_5 coefficients is appended inside ``#ifdef MDWF_mode``.

    Reads the module-level flags ``dagger``, ``normalDWF`` and ``dslash5``.
    """
    if dagger:
        lsign, ledge, rsign, redge = '-', '0', '+', 'param.dc.Ls-1'
    else:
        lsign, ledge, rsign, redge = '+', 'param.dc.Ls-1', '-', '0'

    code = "\n\n"
    code += "// 5th dimension -- NB: not partitionable!\n"
    if normalDWF:
        code += "#ifdef MULTI_GPU\nif(kernel_type == INTERIOR_KERNEL)\n#endif\n"
    code += "{\n// 2 P_L = 2 P_- = ( ( +1, -1 ), ( -1, +1 ) )\n"
    code += " {\n"
    code += " int sp_idx = ( coord[4] == %s ? X%s(param.dc.Ls-1)*2*param.dc.volume_4d_cb : X%s2*param.dc.volume_4d_cb ) / 2;\n" % (ledge, rsign, lsign)
    code += "\n"
    code += "// read spinor from device memory\n"
    code += " READ_SPINOR( SPINORTEX, param.sp_stride, sp_idx, sp_idx );\n"
    code += "\n"
    code += " if ( coord[4] != %s )\n" % ledge
    code += " {\n"

    # out += (2 P_L) in, one statement per real/imaginary part; used for the
    # non-edge slice.
    out_L = ""
    for s1 in range(4):
        for c in range(3):
            re_rhs, im_rhs = "", ""
            for s2 in range(4):
                coef = two_P_L[4 * s1 + s2]
                if coef.real != 0:
                    re_rhs += sign(coef.real) + in_re(s2, c)
                    im_rhs += sign(coef.real) + in_im(s2, c)
                if coef.imag != 0:
                    re_rhs += sign(-coef.imag) + in_im(s2, c)
                    im_rhs += sign(coef.imag) + in_re(s2, c)
            out_L += 3 * " " + out_re(s1, c) + " += " + re_rhs + ";"
            out_L += 3 * " " + out_im(s1, c) + " += " + im_rhs + ";\n"
        if s1 < 3:
            out_L += "\n"

    def edge(terms):
        # Edge slice: the same terms, scaled by -mferm.
        return terms.replace(" += ", " += -mferm*(").replace(";", ");")

    code += out_L
    code += " }\n"
    code += " else\n"
    code += " {\n"
    code += edge(out_L)
    code += " } // end if ( coord[4] != %s )\n" % ledge
    code += " } // end P_L\n\n"

    code += " // 2 P_R = 2 P_+ = ( ( +1, +1 ), ( +1, +1 ) )\n"
    code += " {\n"
    code += " int sp_idx = ( coord[4] == %s ? X%s(param.dc.Ls-1)*2*param.dc.volume_4d_cb : X%s2*param.dc.volume_4d_cb ) / 2;\n" % (redge, lsign, rsign)
    code += "\n"
    code += "// read spinor from device memory\n"
    code += " READ_SPINOR( SPINORTEX, param.sp_stride, sp_idx, sp_idx );\n"
    code += "\n"
    code += " if ( coord[4] != %s )\n" % redge
    code += " {\n"

    # The 2 P_R terms are the 2 P_L terms with every '-' flipped to '+'.
    out_R = out_L.replace("-", "+")
    code += out_R
    code += " }\n"
    code += " else\n"
    code += " {\n"
    code += edge(out_R)
    code += " } // end if ( coord[4] != %s )\n" % redge
    code += " } // end P_R\n\n"

    if dslash5:
        # Backslashes are doubled so the LaTeX-style comment text is emitted
        # verbatim without relying on invalid Python escape sequences.
        code += " // MDWF Dslash_5 operator is given as follow\n"
        code += " // Dslash4pre = [c_5(s)(P_+\\delta_{s,s`+1} - mP_+\\delta_{s,0}\\delta_{s`,L_s-1}\n"
        code += " // + P_-\\delta_{s,s`-1}-mP_-\\delta_{s,L_s-1}\\delta_{s`,0})\n"
        code += " // + b_5(s)\\delta_{s,s`}]\\delta_{x,x`}\n"
        code += " // For Dslash4pre\n"
        code += " // C_5 \\equiv c_5(s)*0.5\n"
        code += " // B_5 \\equiv b_5(s)\n"
        code += " // For Dslash5\n"
        code += " // C_5 \\equiv 0.5*{c_5(s)(4+M_5)-1}/{b_5(s)(4+M_5)+1}\n"
        code += " // B_5 \\equiv 1.0\n"
        code += "#ifdef MDWF_mode // Check whether MDWF option is enabled\n"
        code += "#if (MDWF_mode==1)\n"
        code += " VOLATILE spinorFloat C_5;\n"
        code += " VOLATILE spinorFloat B_5;\n"
        code += " C_5 = mdwf_c5[ coord[4] ]*static_cast<spinorFloat>(0.5);\n"
        code += " B_5 = mdwf_b5[ coord[4] ];\n\n"
        code += " READ_SPINOR( SPINORTEX, param.sp_stride, X/2, X/2 );\n"
        for s in range(4):
            for c in range(3):
                code += " " + out_re(s, c) + " = C_5*" + out_re(s, c) + " + B_5*" + in_re(s, c) + ";\n"
                code += " " + out_im(s, c) + " = C_5*" + out_im(s, c) + " + B_5*" + in_im(s, c) + ";\n"
        code += "#elif (MDWF_mode==2)\n"
        code += " VOLATILE spinorFloat C_5;\n"
        code += " C_5 = static_cast<spinorFloat>(0.5)*(mdwf_c5[ coord[4] ]*(m5+static_cast<spinorFloat>(4.0)) - static_cast<spinorFloat>(1.0))/(mdwf_b5[ coord[4] ]*(m5+static_cast<spinorFloat>(4.0)) + static_cast<spinorFloat>(1.0));\n\n"
        code += " READ_SPINOR( SPINORTEX, param.sp_stride, X/2, X/2 );\n"
        for s in range(4):
            for c in range(3):
                code += " " + out_re(s, c) + " = C_5*" + out_re(s, c) + " + " + in_re(s, c) + ";\n"
                code += " " + out_im(s, c) + " = C_5*" + out_im(s, c) + " + " + in_im(s, c) + ";\n"
        code += "#endif // select MDWF mode\n"
        code += "#endif // check MDWF on/off\n"
    code += "} // end 5th dimension\n\n"

    return code
898 # end def gen_dw
899 
901 
902  str = "\n"
903  str += "VOLATILE spinorFloat kappa;\n\n"
904  str += "#ifdef MDWF_mode // Check whether MDWF option is enabled\n"
905  str += " kappa = -(mdwf_c5[ coord[4] ]*(static_cast<spinorFloat>(4.0) + m5) - static_cast<spinorFloat>(1.0))/(mdwf_b5[ coord[4] ]*(static_cast<spinorFloat>(4.0) + m5) + static_cast<spinorFloat>(1.0));\n"
906  str += "#else\n"
907  str += " kappa = static_cast<spinorFloat>(2.0)*a;\n"
908  str += "#endif // select MDWF mode\n\n"
909  str += "// M5_inv operation -- NB: not partitionable!\n\n"
910  str += "// In this part, we will do the following operation in parallel way.\n\n"
911  str += "// w = M5inv * v\n"
912  str += "// 'w' means output vector\n"
913  str += "// 'v' means input vector\n"
914  str += "{\n"
915  str += " int base_idx = sid%param.dc.volume_4d_cb;\n"
916  str += " int sp_idx;\n\n"
917  str += "// let's assume the index,\n"
918  str += "// s = output vector index,\n"
919  str += "// s' = input vector index and\n"
920  str += "// 'a'= kappa5\n"
921  str += "\n"
922  str += " spinorFloat inv_d_n = static_cast<spinorFloat>(0.5) / ( static_cast<spinorFloat>(1.0) + POW(kappa,param.dc.Ls)*mferm );\n"
923  str += " spinorFloat factorR;\n"
924  str += " spinorFloat factorL;\n"
925  str += "\n"
926  str += " for(int s = 0; s < param.dc.Ls; s++)\n {\n"
927  if dagger == True :
928  str += " int exponent = coord[4] > s ? param.dc.Ls-coord[4]+s : s-coord[4];\n"
929  str += " factorR = inv_d_n * POW(kappa,exponent) * ( coord[4] > s ? -mferm : static_cast<spinorFloat>(1.0) );\n\n"
930  else :
931  str += " int exponent = coord[4] < s ? param.dc.Ls-s+coord[4] : coord[4]-s;\n"
932  str += " factorR = inv_d_n * POW(kappa,exponent) * ( coord[4] < s ? -mferm : static_cast<spinorFloat>(1.0) );\n\n"
933  str += " sp_idx = base_idx + s*param.dc.volume_4d_cb;\n"
934  str += " // read spinor from device memory\n"
935  str += " READ_SPINOR( SPINORTEX, param.sp_stride, sp_idx, sp_idx );\n\n"
936  str += " o00_re += factorR*(i00_re + i20_re);\n"
937  str += " o00_im += factorR*(i00_im + i20_im);\n"
938  str += " o20_re += factorR*(i00_re + i20_re);\n"
939  str += " o20_im += factorR*(i00_im + i20_im);\n"
940  str += " o01_re += factorR*(i01_re + i21_re);\n"
941  str += " o01_im += factorR*(i01_im + i21_im);\n"
942  str += " o21_re += factorR*(i01_re + i21_re);\n"
943  str += " o21_im += factorR*(i01_im + i21_im);\n"
944  str += " o02_re += factorR*(i02_re + i22_re);\n"
945  str += " o02_im += factorR*(i02_im + i22_im);\n"
946  str += " o22_re += factorR*(i02_re + i22_re);\n"
947  str += " o22_im += factorR*(i02_im + i22_im);\n"
948  str += " o10_re += factorR*(i10_re + i30_re);\n"
949  str += " o10_im += factorR*(i10_im + i30_im);\n"
950  str += " o30_re += factorR*(i10_re + i30_re);\n"
951  str += " o30_im += factorR*(i10_im + i30_im);\n"
952  str += " o11_re += factorR*(i11_re + i31_re);\n"
953  str += " o11_im += factorR*(i11_im + i31_im);\n"
954  str += " o31_re += factorR*(i11_re + i31_re);\n"
955  str += " o31_im += factorR*(i11_im + i31_im);\n"
956  str += " o12_re += factorR*(i12_re + i32_re);\n"
957  str += " o12_im += factorR*(i12_im + i32_im);\n"
958  str += " o32_re += factorR*(i12_re + i32_re);\n"
959  str += " o32_im += factorR*(i12_im + i32_im);\n\n"
960 
961  if dagger == True :
962  str += " int exponent2 = coord[4] < s ? param.dc.Ls-s+coord[4] : coord[4]-s;\n"
963  str += " factorL = inv_d_n * POW(kappa,exponent2) * ( coord[4] < s ? -mferm : static_cast<spinorFloat>(1.0));\n\n"
964  else :
965  str += " int exponent2 = coord[4] > s ? param.dc.Ls-coord[4]+s : s-coord[4];\n"
966  str += " factorL = inv_d_n * POW(kappa,exponent2) * ( coord[4] > s ? -mferm : static_cast<spinorFloat>(1.0));\n\n"
967 
968  str += " o00_re += factorL*(i00_re - i20_re);\n"
969  str += " o00_im += factorL*(i00_im - i20_im);\n"
970  str += " o01_re += factorL*(i01_re - i21_re);\n"
971  str += " o01_im += factorL*(i01_im - i21_im);\n"
972  str += " o02_re += factorL*(i02_re - i22_re);\n"
973  str += " o02_im += factorL*(i02_im - i22_im);\n"
974  str += " o10_re += factorL*(i10_re - i30_re);\n"
975  str += " o10_im += factorL*(i10_im - i30_im);\n"
976  str += " o11_re += factorL*(i11_re - i31_re);\n"
977  str += " o11_im += factorL*(i11_im - i31_im);\n"
978  str += " o12_re += factorL*(i12_re - i32_re);\n"
979  str += " o12_im += factorL*(i12_im - i32_im);\n"
980  str += " o20_re += factorL*(i20_re - i00_re);\n"
981  str += " o20_im += factorL*(i20_im - i00_im);\n"
982  str += " o21_re += factorL*(i21_re - i01_re);\n"
983  str += " o21_im += factorL*(i21_im - i01_im);\n"
984  str += " o22_re += factorL*(i22_re - i02_re);\n"
985  str += " o22_im += factorL*(i22_im - i02_im);\n"
986  str += " o30_re += factorL*(i30_re - i10_re);\n"
987  str += " o30_im += factorL*(i30_im - i10_im);\n"
988  str += " o31_re += factorL*(i31_re - i11_re);\n"
989  str += " o31_im += factorL*(i31_im - i11_im);\n"
990  str += " o32_re += factorL*(i32_re - i12_re);\n"
991  str += " o32_im += factorL*(i32_im - i12_im);\n"
992  str += " }\n"
993  str += "} // end of M5inv dimension\n\n"
994  str += "#undef POW\n"
995 
996  return str
997 # end def gen_dw_inv
998 
999 
def input_spinor(s, c, z):
    """Return the register name of spinor component (s, c), real or imaginary.

    dslash kernels operate on the accumulated output registers (out_re/out_im);
    all other kernels operate on the input registers (in_re/in_im).
    z == 0 selects the real part; any other value selects the imaginary part.
    """
    real_reg, imag_reg = (out_re, out_im) if dslash else (in_re, in_im)
    chosen = real_reg if z == 0 else imag_reg
    return chosen(s, c)
1007 
1009  str = ""
1010  str += "spinorFloat "+a_re(0,0,c)+" = -"+input_spinor(1,c,0)+" - "+input_spinor(3,c,0)+";\n"
1011  str += "spinorFloat "+a_im(0,0,c)+" = -"+input_spinor(1,c,1)+" - "+input_spinor(3,c,1)+";\n"
1012  str += "spinorFloat "+a_re(0,1,c)+" = "+input_spinor(0,c,0)+" + "+input_spinor(2,c,0)+";\n"
1013  str += "spinorFloat "+a_im(0,1,c)+" = "+input_spinor(0,c,1)+" + "+input_spinor(2,c,1)+";\n"
1014  str += "spinorFloat "+a_re(0,2,c)+" = -"+input_spinor(1,c,0)+" + "+input_spinor(3,c,0)+";\n"
1015  str += "spinorFloat "+a_im(0,2,c)+" = -"+input_spinor(1,c,1)+" + "+input_spinor(3,c,1)+";\n"
1016  str += "spinorFloat "+a_re(0,3,c)+" = "+input_spinor(0,c,0)+" - "+input_spinor(2,c,0)+";\n"
1017  str += "spinorFloat "+a_im(0,3,c)+" = "+input_spinor(0,c,1)+" - "+input_spinor(2,c,1)+";\n"
1018  str += "\n"
1019 
1020  for s in range (0,4):
1021  str += out_re(s,c)+" = "+a_re(0,s,c)+"; "
1022  str += out_im(s,c)+" = "+a_im(0,s,c)+";\n"
1023 
1024  return block(str)+"\n\n"
1025 # end def to_chiral_basis
1026 
1027 
def from_chiral_basis(c): # note: factor of 1/2 is included in clover term normalization
    """Generate code rotating colour component c of the output spinor back from the chiral basis.

    Each new spin component is a signed sum/difference of two existing output
    components, stored first into temporaries a_re/a_im(0, s, c) and then
    copied back into out_re/out_im.  The result is wrapped by block() and
    followed by two newlines.
    """
    # per target spin: (leading sign, first source spin, operator, second source spin)
    recipe = (
        ("", 1, "+", 3),
        ("-", 0, "-", 2),
        ("", 1, "-", 3),
        ("-", 0, "+", 2),
    )
    lines = []
    for s, (sign, first, op, second) in enumerate(recipe):
        for tmp, src in ((a_re, out_re), (a_im, out_im)):
            lines.append("spinorFloat %s = %s%s %s %s;\n"
                         % (tmp(0, s, c), sign, src(first, c), op, src(second, c)))
    lines.append("\n")

    # copy the temporaries back into the output registers
    for s in range(4):
        lines.append("%s = %s; %s = %s;\n"
                     % (out_re(s, c), a_re(0, s, c), out_im(s, c), a_im(0, s, c)))

    return block("".join(lines)) + "\n\n"
# end def from_chiral_basis
1046 
1047 
def clover_mult(chi):
    """Generate code applying chiral block ``chi`` of the clover term to the output spinor.

    Emits READ_CLOVER for block ``chi``, zero-initialises the temporaries
    a_re/a_im(chi, s, c), accumulates the complex product of the clover
    coefficients with spin components 2*chi and 2*chi+1 of the output spinor,
    then copies the temporaries back into those output registers.

    Returns the generated code wrapped by block() followed by two newlines.
    """
    # repr() is exactly what the Python 2 backtick conversion did; backticks are
    # a syntax error in Python 3.  The local is deliberately not named ``str``
    # so the builtin is not shadowed.
    code = "READ_CLOVER(CLOVERTEX, " + repr(chi) + ")\n\n"

    for s in range(0, 2):
        for c in range(0, 3):
            code += "spinorFloat " + a_re(chi, s, c) + " = 0; spinorFloat " + a_im(chi, s, c) + " = 0;\n"
    code += "\n"

    for sm in range(0, 2):
        for cm in range(0, 3):
            for sn in range(0, 2):
                for cn in range(0, 3):
                    # no imaginary-coefficient term is emitted for diagonal entries
                    # (presumably because the Hermitian block has a real diagonal)
                    off_diagonal = (sn != sm) or (cn != cm)
                    code += a_re(chi, sm, cm) + " += " + c_re(chi, sm, cm, sn, cn) + " * " + out_re(2*chi + sn, cn) + ";\n"
                    if off_diagonal:
                        code += a_re(chi, sm, cm) + " -= " + c_im(chi, sm, cm, sn, cn) + " * " + out_im(2*chi + sn, cn) + ";\n"
                    code += a_im(chi, sm, cm) + " += " + c_re(chi, sm, cm, sn, cn) + " * " + out_im(2*chi + sn, cn) + ";\n"
                    if off_diagonal:
                        code += a_im(chi, sm, cm) + " += " + c_im(chi, sm, cm, sn, cn) + " * " + out_re(2*chi + sn, cn) + ";\n"
            code += "\n"

    # copy the accumulated block result back into the output registers
    for s in range(0, 2):
        for c in range(0, 3):
            code += out_re(2*chi + s, c) + " = " + a_re(chi, s, c) + "; "
            code += out_im(2*chi + s, c) + " = " + a_im(chi, s, c) + ";\n"
    code += "\n"

    return block(code) + "\n\n"
# end def clover_mult
1078 
1079 
1081  if domain_wall: return ""
1082  str = ""
1083  if dslash: str += "#ifdef DSLASH_CLOVER\n\n"
1084  str += "// change to chiral basis\n"
1085  str += to_chiral_basis(0) + to_chiral_basis(1) + to_chiral_basis(2)
1086  str += "// apply first chiral block\n"
1087  str += clover_mult(0)
1088  str += "// apply second chiral block\n"
1089  str += clover_mult(1)
1090  str += "// change back from chiral basis\n"
1091  str += "// (note: required factor of 1/2 is included in clover term normalization)\n"
1093  if dslash: str += "#endif // DSLASH_CLOVER\n\n"
1094 
1095  return str
1096 # end def clover
1097 
1098 
def coeff():
    """Return the C expression used to scale the output spinor in xpay().

    Every domain-wall kernel variant (4D hopping, 5D hopping, 5D inverse)
    scales by the local ``coeff`` variable that xpay() declares for those
    variants; any other configuration scales by the ``a`` parameter.
    """
    uses_local_coeff = dslash4 or dslash5 or dslash5inv
    return "coeff" if uses_local_coeff else "a"
# end def coeff()
1108 
def ypax():
    """Generate the ``out = out + coeff*accum`` update for all 24 spinor reals.

    The double-precision variant reads the accumulator through nthFloat2 and
    the single/half-precision variant through nthFloat4; the two are selected
    at C compile time by SPINOR_DOUBLE.
    """
    parts = ["#ifdef SPINOR_DOUBLE\n"]
    # each precision variant is followed by the preprocessor line that closes it
    for nth, closer in ((nthFloat2, "#else\n"), (nthFloat4, "#endif // SPINOR_DOUBLE\n")):
        for i in range(12):
            s, c = divmod(i, 3)
            for z, out in enumerate((out_re, out_im)):
                reg = out(s, c)
                parts.append(reg + " = " + reg + " + coeff*accum" + nth(2*i + z) + ";\n")
        parts.append(closer)
    return "".join(parts)
# end def ypax
1130 
def xpay():
    """Generate the DSLASH_XPAY epilogue: out = coeff()*out + accum.

    For the domain-wall kernels a local ``coeff`` is declared first; under
    MDWF_mode it is derived from mdwf_b5[coord[4]] and m5, otherwise it is
    ``a`` (or ``b`` for the 5D inverse).  For dslash5 the update can be
    replaced at C compile time by ypax() when YPAX is defined.
    """
    def accumulate(nth):
        # one line per spinor real; nth selects the accumulator element accessor
        rows = []
        for s in range(4):
            for c in range(3):
                i = 3*s + c
                for z, out in enumerate((out_re, out_im)):
                    rows.append(out(s, c) + " = " + coeff() + "*" + out(s, c) + " + accum" + nth(2*i + z) + ";\n")
        return "".join(rows)

    pieces = ["#ifdef DSLASH_XPAY\n", "READ_ACCUM(ACCUMTEX, param.sp_stride)\n"]

    if dslash4:
        pieces += [
            "VOLATILE spinorFloat coeff;\n\n",
            "#ifdef MDWF_mode\n",
            "coeff = static_cast<spinorFloat>(0.5)*a/(mdwf_b5[coord[4]]*(m5+static_cast<spinorFloat>(4.0)) + static_cast<spinorFloat>(1.0));\n",
            "#else\n",
            "coeff = a;\n",
            "#endif\n\n",
        ]
    elif dslash5 or dslash5inv:
        pieces += [
            "VOLATILE spinorFloat coeff;\n\n",
            "#ifdef MDWF_mode\n",
            "coeff = static_cast<spinorFloat>(0.5)/(mdwf_b5[coord[4]]*(m5+static_cast<spinorFloat>(4.0)) + static_cast<spinorFloat>(1.0));\n",
            "coeff *= coeff;\n",
            "coeff *= a;\n",
            "#else\n",
        ]
        if dslash5:
            pieces.append("coeff = a;\n")
        elif dslash5inv:
            pieces.append("coeff = b;\n")
        pieces.append("#endif\n\n")

    if dslash5:
        pieces += ["#ifdef YPAX\n", ypax(), "#else\n"]

    pieces += [
        "#ifdef SPINOR_DOUBLE\n",
        accumulate(nthFloat2),
        "#else\n",
        accumulate(nthFloat4),
        "#endif // SPINOR_DOUBLE\n",
    ]

    if dslash5:
        pieces.append("#endif // YPAX\n")
    pieces.append("#endif // DSLASH_XPAY\n")

    return "".join(pieces)
# end def xpay
1183 
1184 
def epilog():
    """Generate the common tail of every kernel body.

    Emits, in order:
      * for 4D dslash kernels, a MULTI_GPU/DSLASH_XPAY-guarded check that
        marks sites still waiting on exterior (halo) contributions, followed
        by ``if (!incomplete)`` so the next block is skipped for them;
      * the clover and xpay epilogue wrapped in block();
      * WRITE_SPINOR back to device memory;
      * #undef lines for every macro the generated header defined, so that
        re-including it for another precision does not warn.
    """
    code = ""
    if dslash:
        if domain_wall:
            guard = "#if defined MULTI_GPU && defined DSLASH_XPAY\n" if dslash4 else ""
        else:
            guard = "#if defined MULTI_GPU && (defined DSLASH_XPAY || defined DSLASH_CLOVER)\n"
        code += guard
        if dslash4:
            code += (
"""
int incomplete = 0; // Have all 8 contributions been computed for this site?

switch(kernel_type) { // intentional fall-through
case INTERIOR_KERNEL:
  incomplete = incomplete || (param.commDim[3] && (coord[3]==0 || coord[3]==(param.dc.X[3]-1)));
case EXTERIOR_KERNEL_T:
  incomplete = incomplete || (param.commDim[2] && (coord[2]==0 || coord[2]==(param.dc.X[2]-1)));
case EXTERIOR_KERNEL_Z:
  incomplete = incomplete || (param.commDim[1] && (coord[1]==0 || coord[1]==(param.dc.X[1]-1)));
case EXTERIOR_KERNEL_Y:
  incomplete = incomplete || (param.commDim[0] && (coord[0]==0 || coord[0]==(param.dc.X[0]-1)));
}

""")
            code += "if (!incomplete)\n"
        # close exactly the guard that was opened so the #if/#endif always balance
        if guard:
            code += "#endif // MULTI_GPU\n"

    code += block("\n" + (apply_clover()) + xpay())

    code += "\n\n"
    code += "// write spinor field back to device memory\n"
    code += "WRITE_SPINOR(param.sp_stride);\n\n"

    code += "// undefine to prevent warning when precision is changed\n"
    for macro in ("m5", "mdwf_b5", "mdwf_c5", "mferm", "a", "b", "spinorFloat", "POW"):
        code += "#undef " + macro + "\n"
    code += "#undef SHARED_STRIDE\n\n"

    if dslash and dslash4:
        for m in range(0, 3):
            for n in range(0, 3):
                code += "#undef " + g_re(0, m, n) + "\n"
                code += "#undef " + g_im(0, m, n) + "\n"
        code += "\n"

    for s in range(0, 4):
        for c in range(0, 3):
            code += "#undef " + in_re(s, c) + "\n"
            code += "#undef " + in_im(s, c) + "\n"
    code += "\n"

    if clover:
        # divmod keeps the spin/colour indices integral; ``m/3`` yields a float
        # under Python 3 and corrupts the generated macro names.
        for m in range(0, 6):
            s, c = divmod(m, 3)
            code += "#undef " + c_re(0, s, c, s, c) + "\n"
        for n in range(0, 6):
            sn, cn = divmod(n, 3)
            for m in range(n + 1, 6):
                sm, cm = divmod(m, 3)
                code += "#undef " + c_re(0, sm, cm, sn, cn) + "\n"
                code += "#undef " + c_im(0, sm, cm, sn, cn) + "\n"
        code += "\n"

    # only the output reals that were mapped to shared memory got a macro
    for s in range(0, 4):
        for c in range(0, 3):
            i = 3*s + c
            if 2*i < sharedFloats:
                code += "#undef " + out_re(s, c) + "\n"
            if 2*i + 1 < sharedFloats:
                code += "#undef " + out_im(s, c) + "\n"
    code += "\n"

    code += "#undef VOLATILE\n"

    return code
# end def epilog
1274 
1275 
def pack_face(facenum):
    """Generate a C ``switch(dim)`` that packs one face of the spinor for halo exchange.

    For each dimension 0..3 the projection for direction 2*dim+facenum is
    produced by gen(..., pack_only=True), followed by WRITE_HALF_SPINOR; each
    case body is wrapped in block(), terminated with ``break;`` and indented.
    """
    code = "\n"
    code += "switch(dim) {\n"
    for dim in range(0, 4):
        # repr() reproduces the Python 2 backtick conversion, which Python 3 removed;
        # the local is not named ``str`` so the builtin is not shadowed.
        code += "case " + repr(dim) + ":\n"
        proj = gen(2*dim + facenum, pack_only=True)
        proj += "\n"
        proj += "// write half spinor back to device memory\n"
        proj += "WRITE_HALF_SPINOR(face_volume, face_idx);\n"
        code += indent(block(proj) + "\n" + "break;\n")
    code += "}\n\n"
    return code
# end def pack_face
1289 
1291  assert (sharedFloats == 0)
1292  str = ""
1293  str += def_input_spinor()
1294  str += "#include \"io_spinor.h\"\n\n"
1295 
1296  str += "if (face_num) "
1297  str += block(pack_face(1))
1298  str += " else "
1299  str += block(pack_face(0))
1300 
1301  str += "\n\n"
1302  str += "// undefine to prevent warning when precision is changed\n"
1303  str += "#undef spinorFloat\n"
1304  str += "#undef SHARED_STRIDE\n\n"
1305 
1306  for s in range(0,4):
1307  for c in range(0,3):
1308  i = 3*s+c
1309  str += "#undef "+in_re(s,c)+"\n"
1310  str += "#undef "+in_im(s,c)+"\n"
1311  str += "\n"
1312 
1313  return str
1314 # end def generate_pack
1315 
1317  r = prolog()
1318  for i in range(0,8) :
1319  r += gen( i )
1320  if domain_wall:
1321  r += gen_dw()
1322  r += epilog()
1323  return r
1324 
1326  r = prolog()
1327  for i in range(0,8) :
1328  r += gen( i )
1329  r += epilog()
1330  return r
1331 
1333  r = prolog()
1334  r += gen_dw()
1335  r += epilog()
1336  return r
1337 
1339  r = prolog()
1340  r += gen_dw_inv()
1341  r += epilog()
1342  return r
1343 
1345  return prolog() + epilog()
1346 
# To fit 192 threads/SM (single precision) with 16K shared memory, set sharedFloats to 19 or smaller

sharedFloats = 0
cloverSharedFloats = 0
if len(sys.argv) > 1 and sys.argv[1] == '--shared':
    # fail with a usage message instead of an IndexError traceback
    if len(sys.argv) < 3:
        sys.exit(sys.argv[0] + ": --shared requires an integer argument")
    sharedFloats = int(sys.argv[2])
print("Shared floats set to " + str(sharedFloats))

# generate Domain Wall Dslash kernels
domain_wall = True
clover = False
normalDWF = False
dslash = True

# The generator functions read these module-level flags, so they are rebound
# (as module globals, since this loop runs at module level) before each kernel.
# Columns: output header, generator, dslash4, dslash5, dslash5inv, dagger.
kernels = [
    ('dw_dslash4_core.h', generate_dslash4D, True, False, False, False),
    ('dw_dslash4_dagger_core.h', generate_dslash4D, True, False, False, True),
    ('dw_dslash5_core.h', generate_dslash5D, False, True, False, False),
    ('dw_dslash5_dagger_core.h', generate_dslash5D, False, True, False, True),
    ('dw_dslash5inv_core.h', generate_dslash5D_inv, False, False, True, False),
    ('dw_dslash5inv_dagger_core.h', generate_dslash5D_inv, False, False, True, True),
]

for filename, generator, dslash4, dslash5, dslash5inv, dagger in kernels:
    print(sys.argv[0] + ": generating " + filename)
    # the with-statement closes the header even if generation raises
    with open('dslash_core/' + filename, 'w') as f:
        f.write(generator())
def gen(dir, pack_only=False)
Definition: gen.py:1
def c_im(b, sm, cm, sn, cn)
def complexify(a)
complex numbers ######################################################################## ...
def c_re(b, sm, cm, sn, cn)
if(err !=cudaSuccess)
def indent(code, n=1)
code generation ######################################################################## ...