7 return [complex(x)
for x
in a]
11 if a ==
int(a):
return `
int(a)`
15 if a == 0:
return "0i" 16 elif a == -1:
return "-i" 17 elif a == 1:
return "i" 18 else:
return fltToString(a)+
"i" 22 if re == 0
and im == 0:
return "0" 23 elif re == 0:
return imToString(im)
24 elif im == 0:
return fltToString(re)
26 im_str =
"-"+imToString(-im)
if im < 0
else "+"+imToString(im)
27 return fltToString(re)+im_str
74 two_P_L = [ id[x] - igamma5[x]/1j
for x
in range(0,4*4) ]
75 two_P_R = [ id[x] + igamma5[x]/1j
for x
in range(0,4*4) ]
85 return [x+y
for (x,y)
in zip(g1,g2)]
88 return [x-y
for (x,y)
in zip(g1,g2)]
108 def indentline(line):
return (n*
" "+line
if ( line
and line.count(
"#", 0, 1) == 0)
else line)
109 return ''.join([indentline(line)+
"\n" for line
in code.splitlines()])
112 return "{\n"+
indent(code)+
"}" 116 elif x==-1:
return "-" 117 elif x==+2:
return "+2*" 118 elif x==-2:
return "-2*" 121 return `(n/4)` +
"." + [
"x",
"y",
"z",
"w"][n%4]
124 return `(n/2)` +
"." + [
"x",
"y"][n%2]
127 def in_re(s, c):
return "i"+`s`+`c`+
"_re" 128 def in_im(s, c):
return "i"+`s`+`c`+
"_im" 129 def g_re(d, m, n):
return (
"g" if (d%2==0)
else "gT")+`m`+`n`+
"_re" 130 def g_im(d, m, n):
return (
"g" if (d%2==0)
else "gT")+`m`+`n`+
"_im" 131 def out_re(s, c):
return "o"+`s`+`c`+
"_re" 132 def out_im(s, c):
return "o"+`s`+`c`+
"_im" 133 def h1_re(h, c):
return [
"a",
"b"][h]+`c`+
"_re" 134 def h1_im(h, c):
return [
"a",
"b"][h]+`c`+
"_im" 135 def h2_re(h, c):
return [
"A",
"B"][h]+`c`+
"_re" 136 def h2_im(h, c):
return [
"A",
"B"][h]+`c`+
"_im" 137 def c_re(b, sm, cm, sn, cn):
return "c"+`(sm+2*b)`+`cm`+
"_"+`(sn+2*b)`+`cn`+
"_re" 138 def c_im(b, sm, cm, sn, cn):
return "c"+`(sm+2*b)`+`cm`+
"_"+`(sn+2*b)`+`cn`+
"_im" 139 def a_re(b, s, c):
return "a"+`(s+2*b)`+`c`+
"_re" 140 def a_im(b, s, c):
return "a"+`(s+2*b)`+`c`+
"_im" 142 def tmp_re(s, c):
return "tmp"+`s`+`c`+
"_re" 143 def tmp_im(s, c):
return "tmp"+`s`+`c`+
"_im" 148 str +=
"// input spinor\n" 149 str +=
"#ifdef SPINOR_DOUBLE\n" 150 str +=
"#define spinorFloat double\n" 156 str +=
"#define m5 param.m5_d\n" 157 str +=
"#define mdwf_b5 param.mdwf_b5_d\n" 158 str +=
"#define mdwf_c5 param.mdwf_c5_d\n" 159 str +=
"#define mferm param.mferm\n" 160 str +=
"#define a param.a\n" 161 str +=
"#define b param.b\n" 163 str +=
"#define spinorFloat float\n" 169 str +=
"#define m5 param.m5_f\n" 170 str +=
"#define mdwf_b5 param.mdwf_b5_f\n" 171 str +=
"#define mdwf_c5 param.mdwf_c5_f\n" 172 str +=
"#define mferm param.mferm_f\n" 173 str +=
"#define a param.a_f\n" 174 str +=
"#define b param.b_f\n" 175 str +=
"#endif // SPINOR_DOUBLE\n\n" 181 str =
"// gauge link\n" 182 str +=
"#ifdef GAUGE_FLOAT2\n" 198 str +=
"#endif // GAUGE_DOUBLE\n\n" 200 str +=
"// conjugated gauge link\n" 204 str +=
"#define "+
g_re(1,m,n)+
" (+"+
g_re(0,n,m)+
")\n" 205 str +=
"#define "+
g_im(1,m,n)+
" (-"+
g_im(0,n,m)+
")\n" 213 str =
"// first chiral block of inverted clover term\n" 214 str +=
"#ifdef CLOVER_DOUBLE\n" 224 for m
in range(n+1,6):
228 str +=
"#define "+
c_im(0,sm,cm,sn,cn)+
" C"+
nthFloat2(i+1)+
"\n" 240 for m
in range(n+1,6):
244 str +=
"#define "+
c_im(0,sm,cm,sn,cn)+
" C"+
nthFloat4(i+1)+
"\n" 246 str +=
"#endif // CLOVER_DOUBLE\n\n" 254 str +=
"#define "+
c_re(0,sm,cm,sn,cn)+
" (+"+
c_re(0,sn,cn,sm,cm)+
")\n" 255 str +=
"#define "+
c_im(0,sm,cm,sn,cn)+
" (-"+
c_im(0,sn,cn,sm,cm)+
")\n" 258 str +=
"// second chiral block of inverted clover term (reuses C0,...,C9)\n" 265 str +=
"#define "+
c_re(1,sm,cm,sn,cn)+
" "+
c_re(0,sm,cm,sn,cn)+
"\n" 266 if m != n: str +=
"#define "+
c_im(1,sm,cm,sn,cn)+
" "+
c_im(0,sm,cm,sn,cn)+
"\n" 273 str =
"// output spinor\n" 277 if 2*i < sharedFloats:
278 str +=
"#define "+
out_re(s,c)+
" s["+`(2*i+0)`+
"*SHARED_STRIDE]\n" 280 str +=
"VOLATILE spinorFloat "+
out_re(s,c)+
";\n" 281 if 2*i+1 < sharedFloats:
282 str +=
"#define "+
out_im(s,c)+
" s["+`(2*i+1)`+
"*SHARED_STRIDE]\n" 284 str +=
"VOLATILE spinorFloat "+
out_im(s,c)+
";\n" 290 prolog_str = (
"#ifdef MULTI_GPU\n\n")
292 prolog_str+= (
"// *** CUDA DSLASH ***\n\n" if not dagger
else "// *** CUDA DSLASH DAGGER ***\n\n")
293 prolog_str+=
"#define DSLASH_SHARED_FLOATS_PER_THREAD "+str(sharedFloats)+
"\n\n" 295 prolog_str= (
"// *** CUDA CLOVER ***\n\n")
296 prolog_str+=
"#define CLOVER_SHARED_FLOATS_PER_THREAD "+str(sharedFloats)+
"\n\n" 298 print "Undefined prolog" 304 #if (CUDA_VERSION >= 4010) 307 #define VOLATILE volatile 321 #if (__COMPUTE_CAPABILITY__ >= 200) 322 #define SHARED_STRIDE 16 // to avoid bank conflicts on Fermi 324 #define SHARED_STRIDE 8 // to avoid bank conflicts on G80 and GT200 327 #if (__COMPUTE_CAPABILITY__ >= 200) 328 #define SHARED_STRIDE 32 // to avoid bank conflicts on Fermi 330 #define SHARED_STRIDE 16 // to avoid bank conflicts on G80 and GT200 338 extern __shared__ char s_data[]; 344 VOLATILE spinorFloat *s = (spinorFloat*)s_data + DSLASH_SHARED_FLOATS_PER_THREAD*SHARED_STRIDE*(threadIdx.x/SHARED_STRIDE) 345 + (threadIdx.x % SHARED_STRIDE); 350 VOLATILE spinorFloat *s = (spinorFloat*)s_data + CLOVER_SHARED_FLOATS_PER_THREAD*SHARED_STRIDE*(threadIdx.x/SHARED_STRIDE) 351 + (threadIdx.x % SHARED_STRIDE); 357 prolog_str +=
"\n#include \"read_gauge.h\"\n" 359 prolog_str +=
"#include \"read_clover.h\"\n" 360 prolog_str +=
"#include \"io_spinor.h\"\n" 364 #if (DD_PREC==2) // half precision 366 #endif // half precision 370 int sid = ((blockIdx.y*blockDim.y + threadIdx.y)*gridDim.x + blockIdx.x)*blockDim.x + threadIdx.x; 371 if (sid >= param.threads*param.dc.Ls) return; 404 dim = dimFromFaceIndex<5>(sid, param); // sid is also modified 406 //const int face_volume = (param.threads*param.dc.Ls >> 1); // volume of one face 407 const int face_volume = ((param.threadDimMapUpper[dim] - param.threadDimMapLower[dim])*param.dc.Ls >> 1); 409 const int face_num = (sid >= face_volume); // is this thread updating face 0 or 1 410 face_idx = sid - face_num*face_volume; // index into the respective face 412 // ghostOffset is scaled to include body (includes stride) and number of FloatN arrays (SPINOR_HOP) 413 // face_idx not sid since faces are spin projected and share the same volume index (modulo UP/DOWN reading) 414 //sp_idx = face_idx + param.ghostOffset[dim]; 418 coordsFromFaceIndex<5,QUDA_4D_PC,0,1>(X, sid, coord, face_idx, face_num, param); 421 coordsFromFaceIndex<5,QUDA_4D_PC,1,1>(X, sid, coord, face_idx, face_num, param); 424 coordsFromFaceIndex<5,QUDA_4D_PC,2,1>(X, sid, coord, face_idx, face_num, param); 427 coordsFromFaceIndex<5,QUDA_4D_PC,3,1>(X, sid, coord, face_idx, face_num, param); 433 for(int dir=0; dir<4; ++dir){ 434 active = active || isActive(dim,dir,+1,coord,param.commDim,param.dc.X); 441 READ_INTERMEDIATE_SPINOR(INTERTEX, param.sp_stride, sid, sid); 454 // declare G## here and use ASSN below instead of READ 456 #if (DD_PREC==0) //temporal hack 492 #include "read_clover.h" 493 #include "io_spinor.h" 495 int sid = blockIdx.x*blockDim.x + threadIdx.x; 496 if (sid >= param.threads) return; 498 // read spinor from device memory 499 READ_SPINOR(SPINORTEX, param.sp_stride, sid, sid); 506 def gen(dir, pack_only=False):
507 projIdx = dir
if not dagger
else dir + ( +1
if dir%2 == 0
else -1 )
510 return projectors[projIdx][4*i+j]
517 return (1, proj(i,1))
519 return (0, proj(i,0))
521 boundary = [
"coord[0]==(param.dc.X[0]-1)",
"coord[0]==0",
"coord[1]==(param.dc.X[1]-1)",
"coord[1]==0",
"coord[2]==(param.dc.X[2]-1)",
"coord[2]==0",
"coord[3]==(param.dc.X[3]-1)",
"coord[3]==0"]
522 interior = [
"coord[0]<(param.dc.X[0]-1)",
"coord[0]>0",
"coord[1]<(param.dc.X[1]-1)",
"coord[1]>0",
"coord[2]<(param.dc.X[2]-1)",
"coord[2]>0",
"coord[3]<(param.dc.X[3]-1)",
"coord[3]>0"]
524 offset = [
"+1",
"-1",
"+1",
"-1",
"+1",
"-1",
"+1",
"-1"]
526 dim = [
"X",
"Y",
"Z",
"T"]
529 sp_idx = [
"X+1",
"X-1",
"X+param.dc.X[0]",
"X-param.dc.X[0]",
"X+param.dc.X2X1",
"X-param.dc.X2X1",
"X+param.dc.X3X2X1",
"X-param.dc.X3X2X1"]
532 sp_idx_wrap = [
"X-(param.dc.X[0]-1)",
"X+(param.dc.X[0]-1)",
"X-param.dc.X2X1mX1",
"X+param.dc.X2X1mX1",
"X-param.dc.X3X2X1mX2X1",
"X+param.dc.X3X2X1mX2X1",
533 "X-param.dc.X4X3X2X1mX3X2X1",
"X+param.dc.X4X3X2X1mX3X2X1"]
536 cond +=
"if (isActive(dim," + `dir/2` +
"," + offset[dir] +
",coord,param.commDim,param.dc.X) && " + boundary[dir] +
" )\n" 541 projName =
"P"+`dir/2`+[
"-",
"+"][projIdx%2]
542 str +=
"// Projector "+projName+
"\n" 543 for l
in projStr.splitlines():
547 str +=
"faceIndexFromCoords<5,1>(face_idx,coord," + `dir/2` +
",param);\n" 548 str +=
"const int sp_idx = face_idx + param.ghostOffset[" + `dir/2` +
"][" + `1-dir%2` +
"];\n" 549 str +=
"#if (DD_PREC==2) // half precision\n" 550 str +=
" sp_norm_idx = face_idx + " 553 str +=
"param.ghostNormOffset[" + `dir/2` +
"][" + `1-dir%2` +
"];\n" 560 if domain_wall: str +=
"const int ga_idx = sid % param.dc.volume_4d_cb;\n" 561 else: str +=
"const int ga_idx = sid;\n" 563 if domain_wall: str +=
"const int ga_idx = param.dc.volume_4d_cb+(face_idx % param.dc.ghostFace[" + `dir/2` +
"]);\n" 564 else: str +=
"const int ga_idx = param.dc.volume_4d_cb+face_idx;\n" 568 row_cnt = ([0,0,0,0])
573 if re != 0
or im != 0:
575 row_cnt[0] += row_cnt[1]
576 row_cnt[2] += row_cnt[3]
579 for h
in range(0, 2):
580 for c
in range(0, 3):
581 decl_half +=
"spinorFloat "+
h1_re(h,c)+
", "+
h1_im(h,c)+
";\n";
584 load_spinor =
"// read spinor from device memory\n" 586 load_spinor +=
"READ_SPINOR_DOWN(SPINORTEX, param.sp_stride, sp_idx, sp_idx);\n" 587 elif row_cnt[2] == 0:
588 load_spinor +=
"READ_SPINOR_UP(SPINORTEX, param.sp_stride, sp_idx, sp_idx);\n" 590 load_spinor +=
"READ_SPINOR(SPINORTEX, param.sp_stride, sp_idx, sp_idx);\n" 595 load_half +=
"const int sp_stride_pad = param.dc.Ls*param.dc.ghostFace[" + `dir/2` +
"];\n" 597 load_half +=
"const int sp_stride_pad = param.dc.ghostFace[" + `dir/2` +
"];\n" 599 if dir >= 6: load_half +=
"const int t_proj_scale = TPROJSCALE;\n" 601 load_half +=
"// read half spinor from device memory\n" 605 load_half +=
"READ_SPINOR_GHOST(GHOSTSPINORTEX, sp_stride_pad, sp_idx, sp_norm_idx, "+`dir`+
");\n\n" 608 load_gauge =
"// read gauge matrix from device memory\n" 609 load_gauge +=
"ASSN_GAUGE_MATRIX(G, GAUGE"+`( dir%2)`+
"TEX, "+`dir`+
", ga_idx, param.gauge_stride);\n\n" 611 reconstruct_gauge =
"// reconstruct gauge matrix\n" 612 reconstruct_gauge +=
"RECONSTRUCT_GAUGE_MATRIX("+`dir`+
");\n\n" 614 project =
"// project spinor into half spinors\n" 615 for h
in range(0, 2):
616 for c
in range(0, 3):
619 for s
in range(0, 4):
622 if re==0
and im==0: ()
630 for s
in range(0, 4):
631 re = proj(h+2,s).real
632 im = proj(h+2,s).imag
633 if re==0
and im==0: ()
641 project +=
h1_re(h,c)+
" = "+strRe+
";\n" 642 project +=
h1_im(h,c)+
" = "+strIm+
";\n" 645 for h
in range(0, 2):
646 for c
in range(0, 3):
647 copy_half +=
h1_re(h,c)+
" = "+(
"t_proj_scale*" if (dir >= 6)
else "")+
in_re(h,c)+
"; " 648 copy_half +=
h1_im(h,c)+
" = "+(
"t_proj_scale*" if (dir >= 6)
else "")+
in_im(h,c)+
";\n" 654 prep_half +=
indent(load_half)
655 prep_half +=
indent(copy_half)
658 ident =
"// identity gauge matrix\n" 661 ident +=
"spinorFloat "+
h2_re(h,m)+
" = " +
h1_re(h,m) +
"; " 662 ident +=
"spinorFloat "+
h2_im(h,m)+
" = " +
h1_im(h,m) +
";\n" 667 mult +=
"// multiply row "+`m`+
"\n" 669 re =
"spinorFloat "+
h2_re(h,m)+
" = 0;\n" 670 im =
"spinorFloat "+
h2_im(h,m)+
" = 0;\n" 672 re +=
h2_re(h,m) +
" += " +
g_re(dir,m,c) +
" * "+
h1_re(h,c)+
";\n" 673 re +=
h2_re(h,m) +
" -= " +
g_im(dir,m,c) +
" * "+
h1_im(h,c)+
";\n" 674 im +=
h2_im(h,m) +
" += " +
g_re(dir,m,c) +
" * "+
h1_im(h,c)+
";\n" 675 im +=
h2_im(h,m) +
" += " +
g_im(dir,m,c) +
" * "+
h1_re(h,c)+
";\n" 686 reconstruct +=
out_re(h_out, m) +
" += " +
h2_re(h,m) +
";\n" 687 reconstruct +=
out_im(h_out, m) +
" += " +
h2_im(h,m) +
";\n" 693 if im == 0
and re == 0:
696 reconstruct +=
out_re(s, m) +
" " +
sign(re) +
"= " +
h2_re(h,m) +
";\n" 697 reconstruct +=
out_im(s, m) +
" " +
sign(re) +
"= " +
h2_im(h,m) +
";\n" 699 reconstruct +=
out_re(s, m) +
" " +
sign(-im) +
"= " +
h2_im(h,m) +
";\n" 700 reconstruct +=
out_im(s, m) +
" " +
sign(+im) +
"= " +
h2_re(h,m) +
";\n" 702 if ( m < 2 ): reconstruct +=
"\n" 705 str +=
"if (param.gauge_fixed && ga_idx < param.dc.X4X3X2X1hmX3X2X1h)\n" 706 str +=
block(decl_half + prep_half + ident + reconstruct)
708 str +=
block(load_gauge + decl_half + prep_half + reconstruct_gauge + mult + reconstruct)
710 str += load_gauge + decl_half + prep_half + reconstruct_gauge + mult + reconstruct
713 out = load_spinor + decl_half + project
714 out = out.replace(
"sp_idx",
"idx")
717 return cond +
block(str)+
"\n\n" 724 str +=
"VOLATILE spinorFloat kappa;\n\n" 725 str +=
"#ifdef MDWF_mode // Check whether MDWF option is enabled\n" 726 str +=
" kappa = (spinorFloat)(-(mdwf_c5[coord[4]]*(4.0 + m5) - 1.0)/(mdwf_b5[coord[4]]*(4.0 + m5) + 1.0));\n" 728 str +=
" kappa = 2.0*a;\n" 729 str +=
"#endif // select MDWF mode\n\n" 730 str +=
"// M5_inv operation -- NB: not partitionable!\n\n" 731 str +=
"// In this part, we will do the following operation in parallel way.\n\n" 732 str +=
"// w = M5inv * v\n" 733 str +=
"// 'w' means output vector\n" 734 str +=
"// 'v' means input vector\n" 736 str +=
" int base_idx = sid%param.dc.volume_4d_cb;\n" 737 str +=
" int sp_idx;\n\n" 738 str +=
"// let's assume the index,\n" 739 str +=
"// s = output vector index,\n" 740 str +=
"// s' = input vector index and\n" 741 str +=
"// 'a'= kappa5\n" 743 str +=
" spinorFloat inv_d_n = 1.0 / ( 1.0 + pow(kappa,param.dc.Ls)*mferm);\n" 744 str +=
" spinorFloat factorR;\n" 745 str +=
" spinorFloat factorL;\n" 747 str +=
" for(int s = 0; s < param.dc.Ls; s++)\n {\n" 749 str +=
" factorR = ( coord[4] > s ? -inv_d_n*pow(kappa,param.dc.Ls-coord[4]+s)*mferm : inv_d_n*pow(kappa,s-coord[4]))/2.0;\n\n" 751 str +=
" factorR = ( coord[4] < s ? -inv_d_n*pow(kappa,param.dc.Ls-s+coord[4])*mferm : inv_d_n*pow(kappa,coord[4]-s))/2.0;\n\n" 752 str +=
" sp_idx = base_idx + s*param.dc.volume_4d_cb;\n" 753 str +=
" // read spinor from device memory\n" 754 str +=
" READ_SPINOR( SPINORTEX, param.sp_stride, sp_idx, sp_idx );\n\n" 755 str +=
" o00_re += factorR*(i00_re + i20_re);\n" 756 str +=
" o00_im += factorR*(i00_im + i20_im);\n" 757 str +=
" o20_re += factorR*(i00_re + i20_re);\n" 758 str +=
" o20_im += factorR*(i00_im + i20_im);\n" 759 str +=
" o01_re += factorR*(i01_re + i21_re);\n" 760 str +=
" o01_im += factorR*(i01_im + i21_im);\n" 761 str +=
" o21_re += factorR*(i01_re + i21_re);\n" 762 str +=
" o21_im += factorR*(i01_im + i21_im);\n" 763 str +=
" o02_re += factorR*(i02_re + i22_re);\n" 764 str +=
" o02_im += factorR*(i02_im + i22_im);\n" 765 str +=
" o22_re += factorR*(i02_re + i22_re);\n" 766 str +=
" o22_im += factorR*(i02_im + i22_im);\n" 767 str +=
" o10_re += factorR*(i10_re + i30_re);\n" 768 str +=
" o10_im += factorR*(i10_im + i30_im);\n" 769 str +=
" o30_re += factorR*(i10_re + i30_re);\n" 770 str +=
" o30_im += factorR*(i10_im + i30_im);\n" 771 str +=
" o11_re += factorR*(i11_re + i31_re);\n" 772 str +=
" o11_im += factorR*(i11_im + i31_im);\n" 773 str +=
" o31_re += factorR*(i11_re + i31_re);\n" 774 str +=
" o31_im += factorR*(i11_im + i31_im);\n" 775 str +=
" o12_re += factorR*(i12_re + i32_re);\n" 776 str +=
" o12_im += factorR*(i12_im + i32_im);\n" 777 str +=
" o32_re += factorR*(i12_re + i32_re);\n" 778 str +=
" o32_im += factorR*(i12_im + i32_im);\n\n" 781 str +=
" factorL = ( coord[4] < s ? -inv_d_n*pow(kappa,param.dc.Ls-s+coord[4])*mferm : inv_d_n*pow(kappa,coord[4]-s))/2.0;\n\n" 783 str +=
" factorL = ( coord[4] > s ? -inv_d_n*pow(kappa,param.dc.Ls-coord[4]+s)*mferm : inv_d_n*pow(kappa,s-coord[4]))/2.0;\n\n" 785 str +=
" o00_re += factorL*(i00_re - i20_re);\n" 786 str +=
" o00_im += factorL*(i00_im - i20_im);\n" 787 str +=
" o01_re += factorL*(i01_re - i21_re);\n" 788 str +=
" o01_im += factorL*(i01_im - i21_im);\n" 789 str +=
" o02_re += factorL*(i02_re - i22_re);\n" 790 str +=
" o02_im += factorL*(i02_im - i22_im);\n" 791 str +=
" o10_re += factorL*(i10_re - i30_re);\n" 792 str +=
" o10_im += factorL*(i10_im - i30_im);\n" 793 str +=
" o11_re += factorL*(i11_re - i31_re);\n" 794 str +=
" o11_im += factorL*(i11_im - i31_im);\n" 795 str +=
" o12_re += factorL*(i12_re - i32_re);\n" 796 str +=
" o12_im += factorL*(i12_im - i32_im);\n" 797 str +=
" o20_re += factorL*(i20_re - i00_re);\n" 798 str +=
" o20_im += factorL*(i20_im - i00_im);\n" 799 str +=
" o21_re += factorL*(i21_re - i01_re);\n" 800 str +=
" o21_im += factorL*(i21_im - i01_im);\n" 801 str +=
" o22_re += factorL*(i22_re - i02_re);\n" 802 str +=
" o22_im += factorL*(i22_im - i02_im);\n" 803 str +=
" o30_re += factorL*(i30_re - i10_re);\n" 804 str +=
" o30_im += factorL*(i30_im - i10_im);\n" 805 str +=
" o31_re += factorL*(i31_re - i11_re);\n" 806 str +=
" o31_im += factorL*(i31_im - i11_im);\n" 807 str +=
" o32_re += factorL*(i32_re - i12_re);\n" 808 str +=
" o32_im += factorL*(i32_im - i12_im);\n" 810 str +=
"} // end of M5inv dimension\n\n" 818 if z==0:
return out_re(s,c)
821 if z==0:
return in_re(s,c)
822 else:
return in_im(s,c)
836 for s
in range (0,4):
840 return block(str)+
"\n\n" 848 str +=
"spinorFloat "+
a_re(0,1,c)+
" = -"+
out_re(0,c)+
" - "+
out_re(2,c)+
";\n" 849 str +=
"spinorFloat "+
a_im(0,1,c)+
" = -"+
out_im(0,c)+
" - "+
out_im(2,c)+
";\n" 852 str +=
"spinorFloat "+
a_re(0,3,c)+
" = -"+
out_re(0,c)+
" + "+
out_re(2,c)+
";\n" 853 str +=
"spinorFloat "+
a_im(0,3,c)+
" = -"+
out_im(0,c)+
" + "+
out_im(2,c)+
";\n" 856 for s
in range (0,4):
860 return block(str)+
"\n\n" 865 str =
"READ_CLOVER(CLOVERTEX, "+`chi`+
")\n\n" 867 for s
in range (0,2):
868 for c
in range (0,3):
869 str +=
"spinorFloat "+
a_re(chi,s,c)+
" = 0; spinorFloat "+
a_im(chi,s,c)+
" = 0;\n" 872 for sm
in range (0,2):
873 for cm
in range (0,3):
874 for sn
in range (0,2):
875 for cn
in range (0,3):
876 str +=
a_re(chi,sm,cm)+
" += "+
c_re(chi,sm,cm,sn,cn)+
" * "+
out_re(2*chi+sn,cn)+
";\n" 877 if (sn != sm)
or (cn != cm):
878 str +=
a_re(chi,sm,cm)+
" -= "+
c_im(chi,sm,cm,sn,cn)+
" * "+
out_im(2*chi+sn,cn)+
";\n" 880 str +=
a_im(chi,sm,cm)+
" += "+
c_re(chi,sm,cm,sn,cn)+
" * "+
out_im(2*chi+sn,cn)+
";\n" 881 if (sn != sm)
or (cn != cm):
882 str +=
a_im(chi,sm,cm)+
" += "+
c_im(chi,sm,cm,sn,cn)+
" * "+
out_re(2*chi+sn,cn)+
";\n" 886 for s
in range (0,2):
887 for c
in range (0,3):
888 str +=
out_re(2*chi+s,c)+
" = "+
a_re(chi,s,c)+
"; " 889 str +=
out_im(2*chi+s,c)+
" = "+
a_im(chi,s,c)+
";\n" 892 return block(str)+
"\n\n" 897 if domain_wall:
return "" 899 if dslash: str +=
"#ifdef DSLASH_CLOVER\n\n" 900 str +=
"// change to chiral basis\n" 902 str +=
"// apply first chiral block\n" 904 str +=
"// apply second chiral block\n" 906 str +=
"// change back from chiral basis\n" 907 str +=
"// (note: required factor of 1/2 is included in clover term normalization)\n" 909 if dslash: str +=
"#endif // DSLASH_CLOVER\n\n" 927 str +=
"#ifdef SPINOR_DOUBLE\n" 943 str +=
"#endif // SPINOR_DOUBLE\n" 949 str +=
"#ifdef DSLASH_XPAY\n" 950 str +=
"READ_ACCUM(ACCUMTEX, param.sp_stride)\n" 953 str +=
"VOLATILE spinorFloat coeff;\n\n" 954 str +=
"#ifdef MDWF_mode\n" 955 str +=
"coeff = (spinorFloat)(0.5*a/(mdwf_b5[coord[4]]*(m5+4.0) + 1.0));\n" 957 str +=
"coeff = a;\n" 960 str +=
"VOLATILE spinorFloat coeff;\n\n" 961 str +=
"#ifdef MDWF_mode\n" 962 str +=
"coeff = (spinorFloat)(0.5/(mdwf_b5[coord[4]]*(m5+4.0) + 1.0));\n" 963 str +=
"coeff *= -coeff;\n" 965 str +=
"coeff = a;\n" 967 str +=
"#ifdef YPAX\n" 971 str +=
"#ifdef SPINOR_DOUBLE\n" 987 str +=
"#endif // SPINOR_DOUBLE\n" 989 str +=
"#endif // YPAX\n" 990 str +=
"#endif // DSLASH_XPAY\n" 1001 str +=
"// write spinor field back to device memory\n" 1002 str +=
"WRITE_SPINOR(param.sp_stride);\n\n" 1004 str +=
"// undefine to prevent warning when precision is changed\n" 1005 str +=
"#undef m5\n" 1006 str +=
"#undef mdwf_b5\n" 1007 str +=
"#undef mdwf_c5\n" 1008 str +=
"#undef mferm\n" 1011 str +=
"#undef spinorFloat\n" 1012 str +=
"#undef SHARED_STRIDE\n\n" 1016 for m
in range(0,3):
1017 for n
in range(0,3):
1019 str +=
"#undef "+
g_re(0,m,n)+
"\n" 1020 str +=
"#undef "+
g_im(0,m,n)+
"\n" 1023 for s
in range(0,4):
1024 for c
in range(0,3):
1026 str +=
"#undef "+
in_re(s,c)+
"\n" 1027 str +=
"#undef "+
in_im(s,c)+
"\n" 1031 for m
in range(0,6):
1034 str +=
"#undef "+
c_re(0,s,c,s,c)+
"\n" 1035 for n
in range(0,6):
1038 for m
in range(n+1,6):
1041 str +=
"#undef "+
c_re(0,sm,cm,sn,cn)+
"\n" 1042 str +=
"#undef "+
c_im(0,sm,cm,sn,cn)+
"\n" 1045 for s
in range(0,4):
1046 for c
in range(0,3):
1048 if 2*i < sharedFloats:
1049 str +=
"#undef "+
out_re(s,c)+
"\n" 1050 if 2*i+1 < sharedFloats:
1051 str +=
"#undef "+
out_im(s,c)+
"\n" 1054 str +=
"#undef VOLATILE\n" 1055 str +=
"#endif // MULTI_GPU\n" 1063 str +=
"switch(dim) {\n" 1064 for dim
in range(0,4):
1065 str +=
"case "+`dim`+
":\n" 1066 proj =
gen(2*dim+facenum, pack_only=
True)
1068 proj +=
"// write half spinor back to device memory\n" 1069 proj +=
"WRITE_HALF_SPINOR(face_volume, face_idx);\n" 1076 assert (sharedFloats == 0)
1079 str +=
"#include \"io_spinor.h\"\n\n" 1081 str +=
"if (face_num) " 1087 str +=
"// undefine to prevent warning when precision is changed\n" 1088 str +=
"#undef spinorFloat\n" 1089 str +=
"#undef SHARED_STRIDE\n\n" 1091 for s
in range(0,4):
1092 for c
in range(0,3):
1094 str +=
"#undef "+
in_re(s,c)+
"\n" 1095 str +=
"#undef "+
in_im(s,c)+
"\n" 1103 for i
in range(0,8) :
1110 for i
in range(0,8) :
1133 cloverSharedFloats = 0
1134 if(
len(sys.argv) > 1):
1135 if (sys.argv[1] ==
'--shared'):
1136 sharedFloats =
int(sys.argv[2])
1137 print "Shared floats set to " + str(sharedFloats);
1146 print sys.argv[0] +
": generating dw_fused_exterior_dslash4_core.h";
1149 f = open(
'dslash_core/dw_fused_exterior_dslash4_core.h',
'w')
1153 print sys.argv[0] +
": generating dw_fused_exterior_dslash4_dagger_core.h";
1156 f = open(
'dslash_core/dw_fused_exterior_dslash4_dagger_core.h',
'w')
def c_im(b, sm, cm, sn, cn)
def input_spinor(s, c, z)
def gen(dir, pack_only=False)
def c_re(b, sm, cm, sn, cn)
def complexify(a)
complex numbers ######################################################################## ...
def indent(code, n=1)
code generation ######################################################################## ...