7 return [complex(x)
for x
in a]
11 if a ==
int(a):
return `
int(a)`
15 if a == 0:
return "0i" 16 elif a == -1:
return "-i" 17 elif a == 1:
return "i" 18 else:
return fltToString(a)+
"i" 22 if re == 0
and im == 0:
return "0" 23 elif re == 0:
return imToString(im)
24 elif im == 0:
return fltToString(re)
26 im_str =
"-"+imToString(-im)
if im < 0
else "+"+imToString(im)
27 return fltToString(re)+im_str
74 two_P_L = [ id[x] - igamma5[x]/1j
for x
in range(0,4*4) ]
75 two_P_R = [ id[x] + igamma5[x]/1j
for x
in range(0,4*4) ]
85 return [x+y
for (x,y)
in zip(g1,g2)]
88 return [x-y
for (x,y)
in zip(g1,g2)]
108 def indentline(line):
return (n*
" "+line
if ( line
and line.count(
"#", 0, 1) == 0)
else line)
109 return ''.join([indentline(line)+
"\n" for line
in code.splitlines()])
112 return "{\n"+
indent(code)+
"}" 116 elif x==-1:
return "-" 117 elif x==+2:
return "+2*" 118 elif x==-2:
return "-2*" 121 return `(n/4)` +
"." + [
"x",
"y",
"z",
"w"][n%4]
124 return `(n/2)` +
"." + [
"x",
"y"][n%2]
127 def in_re(s, c):
return "i"+`s`+`c`+
"_re" 128 def in_im(s, c):
return "i"+`s`+`c`+
"_im" 129 def g_re(d, m, n):
return (
"g" if (d%2==0)
else "gT")+`m`+`n`+
"_re" 130 def g_im(d, m, n):
return (
"g" if (d%2==0)
else "gT")+`m`+`n`+
"_im" 131 def out_re(s, c):
return "o"+`s`+`c`+
"_re" 132 def out_im(s, c):
return "o"+`s`+`c`+
"_im" 133 def h1_re(h, c):
return [
"a",
"b"][h]+`c`+
"_re" 134 def h1_im(h, c):
return [
"a",
"b"][h]+`c`+
"_im" 135 def h2_re(h, c):
return [
"A",
"B"][h]+`c`+
"_re" 136 def h2_im(h, c):
return [
"A",
"B"][h]+`c`+
"_im" 137 def c_re(b, sm, cm, sn, cn):
return "c"+`(sm+2*b)`+`cm`+
"_"+`(sn+2*b)`+`cn`+
"_re" 138 def c_im(b, sm, cm, sn, cn):
return "c"+`(sm+2*b)`+`cm`+
"_"+`(sn+2*b)`+`cn`+
"_im" 139 def a_re(b, s, c):
return "a"+`(s+2*b)`+`c`+
"_re" 140 def a_im(b, s, c):
return "a"+`(s+2*b)`+`c`+
"_im" 142 def tmp_re(s, c):
return "tmp"+`s`+`c`+
"_re" 143 def tmp_im(s, c):
return "tmp"+`s`+`c`+
"_im" 148 str +=
"// input spinor\n" 149 str +=
"#ifdef SPINOR_DOUBLE\n" 150 str +=
"#define spinorFloat double\n" 151 str +=
"// workaround for C++11 bug in CUDA 6.5/7.0\n" 152 str +=
"#if CUDA_VERSION >= 6050 && CUDA_VERSION < 7050\n" 153 str +=
"#define POW(a, b) pow(a, static_cast<spinorFloat>(b))\n" 155 str +=
"#define POW(a, b) pow(a, b)\n" 162 str +=
"#define m5 param.m5_d\n" 163 str +=
"#define mdwf_b5 param.mdwf_b5_d\n" 164 str +=
"#define mdwf_c5 param.mdwf_c5_d\n" 165 str +=
"#define mferm param.mferm\n" 166 str +=
"#define a param.a\n" 167 str +=
"#define b param.b\n" 169 str +=
"#define spinorFloat float\n" 170 str +=
"#define POW(a, b) __fast_pow(a, b)\n" 176 str +=
"#define m5 param.m5_f\n" 177 str +=
"#define mdwf_b5 param.mdwf_b5_f\n" 178 str +=
"#define mdwf_c5 param.mdwf_c5_f\n" 179 str +=
"#define mferm param.mferm_f\n" 180 str +=
"#define a param.a\n" 181 str +=
"#define b param.b\n" 182 str +=
"#endif // SPINOR_DOUBLE\n\n" 188 str =
"// gauge link\n" 189 str +=
"#ifdef GAUGE_FLOAT2\n" 205 str +=
"#endif // GAUGE_DOUBLE\n\n" 207 str +=
"// conjugated gauge link\n" 211 str +=
"#define "+
g_re(1,m,n)+
" (+"+
g_re(0,n,m)+
")\n" 212 str +=
"#define "+
g_im(1,m,n)+
" (-"+
g_im(0,n,m)+
")\n" 220 str =
"// first chiral block of inverted clover term\n" 221 str +=
"#ifdef CLOVER_DOUBLE\n" 231 for m
in range(n+1,6):
235 str +=
"#define "+
c_im(0,sm,cm,sn,cn)+
" C"+
nthFloat2(i+1)+
"\n" 247 for m
in range(n+1,6):
251 str +=
"#define "+
c_im(0,sm,cm,sn,cn)+
" C"+
nthFloat4(i+1)+
"\n" 253 str +=
"#endif // CLOVER_DOUBLE\n\n" 261 str +=
"#define "+
c_re(0,sm,cm,sn,cn)+
" (+"+
c_re(0,sn,cn,sm,cm)+
")\n" 262 str +=
"#define "+
c_im(0,sm,cm,sn,cn)+
" (-"+
c_im(0,sn,cn,sm,cm)+
")\n" 265 str +=
"// second chiral block of inverted clover term (reuses C0,...,C9)\n" 272 str +=
"#define "+
c_re(1,sm,cm,sn,cn)+
" "+
c_re(0,sm,cm,sn,cn)+
"\n" 273 if m != n: str +=
"#define "+
c_im(1,sm,cm,sn,cn)+
" "+
c_im(0,sm,cm,sn,cn)+
"\n" 280 str =
"// output spinor\n" 284 if 2*i < sharedFloats:
285 str +=
"#define "+
out_re(s,c)+
" s["+`(2*i+0)`+
"*SHARED_STRIDE]\n" 287 str +=
"VOLATILE spinorFloat "+
out_re(s,c)+
";\n" 288 if 2*i+1 < sharedFloats:
289 str +=
"#define "+
out_im(s,c)+
" s["+`(2*i+1)`+
"*SHARED_STRIDE]\n" 291 str +=
"VOLATILE spinorFloat "+
out_im(s,c)+
";\n" 298 prolog_str= (
"// *** CUDA DSLASH ***\n\n" if not dagger
else "// *** CUDA DSLASH DAGGER ***\n\n")
299 prolog_str+=
"#define DSLASH_SHARED_FLOATS_PER_THREAD "+str(sharedFloats)+
"\n\n" 301 prolog_str= (
"// *** CUDA CLOVER ***\n\n")
302 prolog_str+=
"#define CLOVER_SHARED_FLOATS_PER_THREAD "+str(sharedFloats)+
"\n\n" 304 print "Undefined prolog" 309 #if (CUDA_VERSION >= 4010) 312 #define VOLATILE volatile 326 #if (__COMPUTE_CAPABILITY__ >= 200) 327 #define SHARED_STRIDE 16 // to avoid bank conflicts on Fermi 329 #define SHARED_STRIDE 8 // to avoid bank conflicts on G80 and GT200 332 #if (__COMPUTE_CAPABILITY__ >= 200) 333 #define SHARED_STRIDE 32 // to avoid bank conflicts on Fermi 335 #define SHARED_STRIDE 16 // to avoid bank conflicts on G80 and GT200 343 extern __shared__ char s_data[]; 349 VOLATILE spinorFloat *s = (spinorFloat*)s_data + DSLASH_SHARED_FLOATS_PER_THREAD*SHARED_STRIDE*(threadIdx.x/SHARED_STRIDE) 350 + (threadIdx.x % SHARED_STRIDE); 355 VOLATILE spinorFloat *s = (spinorFloat*)s_data + CLOVER_SHARED_FLOATS_PER_THREAD*SHARED_STRIDE*(threadIdx.x/SHARED_STRIDE) 356 + (threadIdx.x % SHARED_STRIDE); 362 prolog_str +=
"\n#include \"read_gauge.h\"\n" 364 prolog_str +=
"#include \"read_clover.h\"\n" 365 prolog_str +=
"#include \"io_spinor.h\"\n" 368 int sid = ((blockIdx.y*blockDim.y + threadIdx.y)*gridDim.x + blockIdx.x)*blockDim.x + threadIdx.x; 369 if (sid >= param.threads*param.dc.Ls) return; 381 int X, coord[5], boundaryCrossing; 392 if (kernel_type == INTERIOR_KERNEL) { 403 // Assume even dimensions 404 coordsFromIndex<5,QUDA_4D_PC,EVEN_X>(X, coord, sid, param); 410 boundaryCrossing = sid/param.dc.Xh[0] + sid/(param.dc.X[1]*param.dc.Xh[0]) + sid/(param.dc.X[2]*param.dc.X[1]*param.dc.Xh[0]); 412 X = 2*sid + (boundaryCrossing + param.parity) % 2; 413 coord[4] = X/(param.dc.X[0]*param.dc.X[1]*param.dc.X[2]*param.dc.X[3]); 421 int aux1 = X / param.dc.X[0]; 422 x1 = X - aux1 * param.dc.X[0]; 423 int aux2 = aux1 / param.dc.X[1]; 424 x2 = aux1 - aux2 * param.dc.X[1]; 425 x4 = aux2 / param.dc.X[2]; 426 x3 = aux2 - x4 * param.dc.X[2]; 427 aux1 = (param.parity + x4 + x3 + x2) & 1; 442 } else { // exterior kernel 444 const int face_volume = (param.threads*param.dc.Ls >> 1); // volume of one face 445 const int face_num = (sid >= face_volume); // is this thread updating face 0 or 1 446 face_idx = sid - face_num*face_volume; // index into the respective face 448 // ghostOffset is scaled to include body (includes stride) and number of FloatN arrays (SPINOR_HOP) 449 // face_idx not sid since faces are spin projected and share the same volume index (modulo UP/DOWN reading) 450 //sp_idx = face_idx + param.ghostOffset[dim]; 452 coordsFromFaceIndex<5,QUDA_4D_PC,kernel_type,1>(X, sid, coord, face_idx, face_num, param); 454 READ_INTERMEDIATE_SPINOR(INTERTEX, param.sp_stride, sid, sid); 467 // declare G## here and use ASSN below instead of READ 469 #if (DD_PREC==0) //temporal hack 505 #include "read_clover.h" 506 #include "io_spinor.h" 508 int sid = blockIdx.x*blockDim.x + threadIdx.x; 509 if (sid >= param.threads) return; 511 // read spinor from device memory 512 READ_SPINOR(SPINORTEX, param.sp_stride, sid, sid); 519 def gen(dir, pack_only=False):
520 projIdx = dir
if not dagger
else dir + ( +1
if dir%2 == 0
else -1 )
523 return projectors[projIdx][4*i+j]
530 return (1, proj(i,1))
532 return (0, proj(i,0))
534 boundary = [
"coord[0]==(param.dc.X[0]-1)",
"coord[0]==0",
"coord[1]==(param.dc.X[1]-1)",
"coord[1]==0",
"coord[2]==(param.dc.X[2]-1)",
"coord[2]==0",
"coord[3]==(param.dc.X[3]-1)",
"coord[3]==0"]
535 interior = [
"coord[0]<(param.dc.X[0]-1)",
"coord[0]>0",
"coord[1]<(param.dc.X[1]-1)",
"coord[1]>0",
"coord[2]<(param.dc.X[2]-1)",
"coord[2]>0",
"coord[3]<(param.dc.X[3]-1)",
"coord[3]>0"]
536 dim = [
"X",
"Y",
"Z",
"T"]
539 sp_idx = [
"X+1",
"X-1",
"X+param.dc.X[0]",
"X-param.dc.X[0]",
"X+param.dc.X2X1",
"X-param.dc.X2X1",
"X+param.dc.X3X2X1",
"X-param.dc.X3X2X1"]
542 sp_idx_wrap = [
"X-(param.dc.X[0]-1)",
"X+(param.dc.X[0]-1)",
"X-param.dc.X2X1mX1",
"X+param.dc.X2X1mX1",
"X-param.dc.X3X2X1mX2X1",
"X+param.dc.X3X2X1mX2X1",
543 "X-param.dc.X4X3X2X1mX3X2X1",
"X+param.dc.X4X3X2X1mX3X2X1"]
546 cond +=
"#ifdef MULTI_GPU\n" 547 cond +=
"if ( (kernel_type == INTERIOR_KERNEL && (!param.ghostDim["+`dir/2`+
"] || "+interior[dir]+
")) ||\n" 548 cond +=
" (kernel_type == EXTERIOR_KERNEL_"+dim[dir/2]+
" && "+boundary[dir]+
") )\n" 553 projName =
"P"+`dir/2`+[
"-",
"+"][projIdx%2]
554 str +=
"// Projector "+projName+
"\n" 555 for l
in projStr.splitlines():
559 str +=
"#ifdef MULTI_GPU\n" 560 str +=
"const int sp_idx = (kernel_type == INTERIOR_KERNEL) ? ("+boundary[dir]+
" ? "+sp_idx_wrap[dir]+
" : "+sp_idx[dir]+
") >> 1 :\n" 561 str +=
" face_idx + param.ghostOffset[static_cast<int>(kernel_type)][" + `(dir+1)%2` +
"];\n" 562 str +=
"#if (DD_PREC==2) // half precision\n" 563 str +=
"const int sp_norm_idx = face_idx + param.ghostNormOffset[static_cast<int>(kernel_type)][" + `(dir+1)%2` +
"];\n" 566 str +=
"const int sp_idx = ("+boundary[dir]+
" ? "+sp_idx_wrap[dir]+
" : "+sp_idx[dir]+
") >> 1;\n" 571 if domain_wall: str +=
"const int ga_idx = sid % param.dc.volume_4d_cb;\n" 572 else: str +=
"const int ga_idx = sid;\n" 574 str +=
"#ifdef MULTI_GPU\n" 575 if domain_wall: str +=
"const int ga_idx = ((kernel_type == INTERIOR_KERNEL) ? sp_idx % param.dc.volume_4d_cb : param.dc.volume_4d_cb+(face_idx % param.dc.ghostFace[static_cast<int>(kernel_type)]));\n" 576 else: str +=
"const int ga_idx = ((kernel_type == INTERIOR_KERNEL) ? sp_idx : param.dc.volume_4d_cb+face_idx);\n" 578 if domain_wall: str +=
"const int ga_idx = sp_idx % param.dc.volume_4d_cb;\n" 579 else: str +=
"const int ga_idx = sp_idx;\n" 584 row_cnt = ([0,0,0,0])
589 if re != 0
or im != 0:
591 row_cnt[0] += row_cnt[1]
592 row_cnt[2] += row_cnt[3]
595 for h
in range(0, 2):
596 for c
in range(0, 3):
597 decl_half +=
"spinorFloat "+
h1_re(h,c)+
", "+
h1_im(h,c)+
";\n";
600 load_spinor =
"// read spinor from device memory\n" 602 load_spinor +=
"READ_SPINOR_DOWN(SPINORTEX, param.sp_stride, sp_idx, sp_idx);\n" 603 elif row_cnt[2] == 0:
604 load_spinor +=
"READ_SPINOR_UP(SPINORTEX, param.sp_stride, sp_idx, sp_idx);\n" 606 load_spinor +=
"READ_SPINOR(SPINORTEX, param.sp_stride, sp_idx, sp_idx);\n" 611 load_half +=
"const int sp_stride_pad = param.dc.Ls*param.dc.ghostFace[static_cast<int>(kernel_type)];\n" 613 load_half +=
"const int sp_stride_pad = param.dc.ghostFace[static_cast<int>(kernel_type)];\n" 618 if dir >= 6: load_half +=
"const int t_proj_scale = TPROJSCALE;\n" 620 load_half +=
"// read half spinor from device memory\n" 624 load_half +=
"READ_SPINOR_GHOST(GHOSTSPINORTEX, sp_stride_pad, sp_idx, sp_norm_idx, "+`dir`+
");\n\n" 627 load_gauge =
"// read gauge matrix from device memory\n" 628 load_gauge +=
"ASSN_GAUGE_MATRIX(G, GAUGE"+`( dir%2)`+
"TEX, "+`dir`+
", ga_idx, param.gauge_stride);\n\n" 630 reconstruct_gauge =
"// reconstruct gauge matrix\n" 631 reconstruct_gauge +=
"RECONSTRUCT_GAUGE_MATRIX("+`dir`+
");\n\n" 633 project =
"// project spinor into half spinors\n" 634 for h
in range(0, 2):
635 for c
in range(0, 3):
638 for s
in range(0, 4):
641 if re==0
and im==0: ()
649 for s
in range(0, 4):
650 re = proj(h+2,s).real
651 im = proj(h+2,s).imag
652 if re==0
and im==0: ()
660 project +=
h1_re(h,c)+
" = "+strRe+
";\n" 661 project +=
h1_im(h,c)+
" = "+strIm+
";\n" 664 for h
in range(0, 2):
665 for c
in range(0, 3):
666 copy_half +=
h1_re(h,c)+
" = "+(
"t_proj_scale*" if (dir >= 6)
else "")+
in_re(h,c)+
"; " 667 copy_half +=
h1_im(h,c)+
" = "+(
"t_proj_scale*" if (dir >= 6)
else "")+
in_im(h,c)+
";\n" 671 prep_half +=
"#ifdef MULTI_GPU\n" 672 prep_half +=
"if (kernel_type == INTERIOR_KERNEL) {\n" 673 prep_half +=
"#endif\n" 675 prep_half +=
indent(load_spinor)
676 prep_half +=
indent(project)
678 prep_half +=
"#ifdef MULTI_GPU\n" 679 prep_half +=
"} else {\n" 681 prep_half +=
indent(load_half)
682 prep_half +=
indent(copy_half)
684 prep_half +=
"#endif // MULTI_GPU\n" 687 ident =
"// identity gauge matrix\n" 690 ident +=
"spinorFloat "+
h2_re(h,m)+
" = " +
h1_re(h,m) +
"; " 691 ident +=
"spinorFloat "+
h2_im(h,m)+
" = " +
h1_im(h,m) +
";\n" 696 mult +=
"// multiply row "+`m`+
"\n" 698 re =
"spinorFloat "+
h2_re(h,m)+
" = 0;\n" 699 im =
"spinorFloat "+
h2_im(h,m)+
" = 0;\n" 701 re +=
h2_re(h,m) +
" += " +
g_re(dir,m,c) +
" * "+
h1_re(h,c)+
";\n" 702 re +=
h2_re(h,m) +
" -= " +
g_im(dir,m,c) +
" * "+
h1_im(h,c)+
";\n" 703 im +=
h2_im(h,m) +
" += " +
g_re(dir,m,c) +
" * "+
h1_im(h,c)+
";\n" 704 im +=
h2_im(h,m) +
" += " +
g_im(dir,m,c) +
" * "+
h1_re(h,c)+
";\n" 715 reconstruct +=
out_re(h_out, m) +
" += " +
h2_re(h,m) +
";\n" 716 reconstruct +=
out_im(h_out, m) +
" += " +
h2_im(h,m) +
";\n" 722 if im == 0
and re == 0:
725 reconstruct +=
out_re(s, m) +
" " +
sign(re) +
"= " +
h2_re(h,m) +
";\n" 726 reconstruct +=
out_im(s, m) +
" " +
sign(re) +
"= " +
h2_im(h,m) +
";\n" 728 reconstruct +=
out_re(s, m) +
" " +
sign(-im) +
"= " +
h2_im(h,m) +
";\n" 729 reconstruct +=
out_im(s, m) +
" " +
sign(+im) +
"= " +
h2_re(h,m) +
";\n" 731 if ( m < 2 ): reconstruct +=
"\n" 734 str +=
"if (param.gauge_fixed && ga_idx < param.dc.X4X3X2X1hmX3X2X1h)\n" 735 str +=
block(decl_half + prep_half + ident + reconstruct)
737 str +=
block(load_gauge + decl_half + prep_half + reconstruct_gauge + mult + reconstruct)
739 str += load_gauge + decl_half + prep_half + reconstruct_gauge + mult + reconstruct
742 out = load_spinor + decl_half + project
743 out = out.replace(
"sp_idx",
"idx")
746 return cond +
block(str)+
"\n\n" 750 if dagger: lsign=
'-'; ledge =
'0'; rsign=
'+'; redge=
'param.dc.Ls-1' 751 else: lsign=
'+'; ledge =
'param.dc.Ls-1'; rsign=
'-'; redge=
'0' 754 str +=
"// 5th dimension -- NB: not partitionable!\n" 756 str +=
"#ifdef MULTI_GPU\nif(kernel_type == INTERIOR_KERNEL)\n#endif\n" 757 str +=
"{\n// 2 P_L = 2 P_- = ( ( +1, -1 ), ( -1, +1 ) )\n" 759 str +=
" int sp_idx = ( coord[4] == %s ? X%s(param.dc.Ls-1)*2*param.dc.volume_4d_cb : X%s2*param.dc.volume_4d_cb ) / 2;\n" % (ledge, rsign, lsign)
761 str +=
"// read spinor from device memory\n" 762 str +=
" READ_SPINOR( SPINORTEX, param.sp_stride, sp_idx, sp_idx );\n" 764 str +=
" if ( coord[4] != %s )\n" % ledge
768 return two_P_L[4*i+j]
772 for s1
in range(0,4):
775 re_rhs, im_rhs =
"",
"" 776 for s2
in range(0,4):
777 re, im = proj(s1,s2).real, proj(s1,s2).imag
784 out_L += 3*
" " +
out_re(s1,c) +
" += " + re_rhs +
";" 785 out_L += 3*
" " +
out_im(s1,c) +
" += " + im_rhs +
";\n" 786 if s1 < 3 : out_L +=
"\n" 796 str += out_L.replace(
" += ",
" += -mferm*(").replace(
";",
");")
798 str +=
" } // end if ( coord[4] != %s )\n" % ledge
799 str +=
" } // end P_L\n\n" 800 str +=
" // 2 P_R = 2 P_+ = ( ( +1, +1 ), ( +1, +1 ) )\n" 802 str +=
" int sp_idx = ( coord[4] == %s ? X%s(param.dc.Ls-1)*2*param.dc.volume_4d_cb : X%s2*param.dc.volume_4d_cb ) / 2;\n" % (redge, lsign, rsign)
804 str +=
"// read spinor from device memory\n" 805 str +=
" READ_SPINOR( SPINORTEX, param.sp_stride, sp_idx, sp_idx );\n" 807 str +=
" if ( coord[4] != %s )\n" % redge
811 str += out_L.replace(
"-",
"+")
818 str += out_L.replace(
"-",
"+").replace(
" += ",
" += -mferm*(").replace(
";",
");")
820 str +=
" } // end if ( coord[4] != %s )\n" % redge
821 str +=
" } // end P_R\n\n" 824 str +=
" // MDWF Dslash_5 operator is given as follow\n" 825 str +=
" // Dslash4pre = [c_5(s)(P_+\delta_{s,s`+1} - mP_+\delta_{s,0}\delta_{s`,L_s-1}\n" 826 str +=
" // + P_-\delta_{s,s`-1}-mP_-\delta_{s,L_s-1}\delta_{s`,0})\n" 827 str +=
" // + b_5(s)\delta_{s,s`}]\delta_{x,x`}\n" 828 str +=
" // For Dslash4pre\n" 829 str +=
" // C_5 \equiv c_5(s)*0.5\n" 830 str +=
" // B_5 \equiv b_5(s)\n" 831 str +=
" // For Dslash5\n" 832 str +=
" // C_5 \equiv 0.5*{c_5(s)(4+M_5)-1}/{b_5(s)(4+M_5)+1}\n" 833 str +=
" // B_5 \equiv 1.0\n" 834 str +=
"#ifdef MDWF_mode // Check whether MDWF option is enabled\n" 835 str +=
"#if (MDWF_mode==1)\n" 836 str +=
" VOLATILE spinorFloat C_5;\n" 837 str +=
" VOLATILE spinorFloat B_5;\n" 838 str +=
" C_5 = mdwf_c5[ coord[4] ]*static_cast<spinorFloat>(0.5);\n" 839 str +=
" B_5 = mdwf_b5[ coord[4] ];\n\n" 840 str +=
" READ_SPINOR( SPINORTEX, param.sp_stride, X/2, X/2 );\n" 841 str +=
" o00_re = C_5*o00_re + B_5*i00_re;\n" 842 str +=
" o00_im = C_5*o00_im + B_5*i00_im;\n" 843 str +=
" o01_re = C_5*o01_re + B_5*i01_re;\n" 844 str +=
" o01_im = C_5*o01_im + B_5*i01_im;\n" 845 str +=
" o02_re = C_5*o02_re + B_5*i02_re;\n" 846 str +=
" o02_im = C_5*o02_im + B_5*i02_im;\n" 847 str +=
" o10_re = C_5*o10_re + B_5*i10_re;\n" 848 str +=
" o10_im = C_5*o10_im + B_5*i10_im;\n" 849 str +=
" o11_re = C_5*o11_re + B_5*i11_re;\n" 850 str +=
" o11_im = C_5*o11_im + B_5*i11_im;\n" 851 str +=
" o12_re = C_5*o12_re + B_5*i12_re;\n" 852 str +=
" o12_im = C_5*o12_im + B_5*i12_im;\n" 853 str +=
" o20_re = C_5*o20_re + B_5*i20_re;\n" 854 str +=
" o20_im = C_5*o20_im + B_5*i20_im;\n" 855 str +=
" o21_re = C_5*o21_re + B_5*i21_re;\n" 856 str +=
" o21_im = C_5*o21_im + B_5*i21_im;\n" 857 str +=
" o22_re = C_5*o22_re + B_5*i22_re;\n" 858 str +=
" o22_im = C_5*o22_im + B_5*i22_im;\n" 859 str +=
" o30_re = C_5*o30_re + B_5*i30_re;\n" 860 str +=
" o30_im = C_5*o30_im + B_5*i30_im;\n" 861 str +=
" o31_re = C_5*o31_re + B_5*i31_re;\n" 862 str +=
" o31_im = C_5*o31_im + B_5*i31_im;\n" 863 str +=
" o32_re = C_5*o32_re + B_5*i32_re;\n" 864 str +=
" o32_im = C_5*o32_im + B_5*i32_im;\n" 865 str +=
"#elif (MDWF_mode==2)\n" 866 str +=
" VOLATILE spinorFloat C_5;\n" 867 str +=
" C_5 = static_cast<spinorFloat>(0.5)*(mdwf_c5[ coord[4] ]*(m5+static_cast<spinorFloat>(4.0)) - static_cast<spinorFloat>(1.0))/(mdwf_b5[ coord[4] ]*(m5+static_cast<spinorFloat>(4.0)) + static_cast<spinorFloat>(1.0));\n\n" 868 str +=
" READ_SPINOR( SPINORTEX, param.sp_stride, X/2, X/2 );\n" 869 str +=
" o00_re = C_5*o00_re + i00_re;\n" 870 str +=
" o00_im = C_5*o00_im + i00_im;\n" 871 str +=
" o01_re = C_5*o01_re + i01_re;\n" 872 str +=
" o01_im = C_5*o01_im + i01_im;\n" 873 str +=
" o02_re = C_5*o02_re + i02_re;\n" 874 str +=
" o02_im = C_5*o02_im + i02_im;\n" 875 str +=
" o10_re = C_5*o10_re + i10_re;\n" 876 str +=
" o10_im = C_5*o10_im + i10_im;\n" 877 str +=
" o11_re = C_5*o11_re + i11_re;\n" 878 str +=
" o11_im = C_5*o11_im + i11_im;\n" 879 str +=
" o12_re = C_5*o12_re + i12_re;\n" 880 str +=
" o12_im = C_5*o12_im + i12_im;\n" 881 str +=
" o20_re = C_5*o20_re + i20_re;\n" 882 str +=
" o20_im = C_5*o20_im + i20_im;\n" 883 str +=
" o21_re = C_5*o21_re + i21_re;\n" 884 str +=
" o21_im = C_5*o21_im + i21_im;\n" 885 str +=
" o22_re = C_5*o22_re + i22_re;\n" 886 str +=
" o22_im = C_5*o22_im + i22_im;\n" 887 str +=
" o30_re = C_5*o30_re + i30_re;\n" 888 str +=
" o30_im = C_5*o30_im + i30_im;\n" 889 str +=
" o31_re = C_5*o31_re + i31_re;\n" 890 str +=
" o31_im = C_5*o31_im + i31_im;\n" 891 str +=
" o32_re = C_5*o32_re + i32_re;\n" 892 str +=
" o32_im = C_5*o32_im + i32_im;\n" 893 str +=
"#endif // select MDWF mode\n" 894 str +=
"#endif // check MDWF on/off\n" 895 str +=
"} // end 5th dimension\n\n" 903 str +=
"VOLATILE spinorFloat kappa;\n\n" 904 str +=
"#ifdef MDWF_mode // Check whether MDWF option is enabled\n" 905 str +=
" kappa = -(mdwf_c5[ coord[4] ]*(static_cast<spinorFloat>(4.0) + m5) - static_cast<spinorFloat>(1.0))/(mdwf_b5[ coord[4] ]*(static_cast<spinorFloat>(4.0) + m5) + static_cast<spinorFloat>(1.0));\n" 907 str +=
" kappa = static_cast<spinorFloat>(2.0)*a;\n" 908 str +=
"#endif // select MDWF mode\n\n" 909 str +=
"// M5_inv operation -- NB: not partitionable!\n\n" 910 str +=
"// In this part, we will do the following operation in parallel way.\n\n" 911 str +=
"// w = M5inv * v\n" 912 str +=
"// 'w' means output vector\n" 913 str +=
"// 'v' means input vector\n" 915 str +=
" int base_idx = sid%param.dc.volume_4d_cb;\n" 916 str +=
" int sp_idx;\n\n" 917 str +=
"// let's assume the index,\n" 918 str +=
"// s = output vector index,\n" 919 str +=
"// s' = input vector index and\n" 920 str +=
"// 'a'= kappa5\n" 922 str +=
" spinorFloat inv_d_n = static_cast<spinorFloat>(0.5) / ( static_cast<spinorFloat>(1.0) + POW(kappa,param.dc.Ls)*mferm );\n" 923 str +=
" spinorFloat factorR;\n" 924 str +=
" spinorFloat factorL;\n" 926 str +=
" for(int s = 0; s < param.dc.Ls; s++)\n {\n" 928 str +=
" int exponent = coord[4] > s ? param.dc.Ls-coord[4]+s : s-coord[4];\n" 929 str +=
" factorR = inv_d_n * POW(kappa,exponent) * ( coord[4] > s ? -mferm : static_cast<spinorFloat>(1.0) );\n\n" 931 str +=
" int exponent = coord[4] < s ? param.dc.Ls-s+coord[4] : coord[4]-s;\n" 932 str +=
" factorR = inv_d_n * POW(kappa,exponent) * ( coord[4] < s ? -mferm : static_cast<spinorFloat>(1.0) );\n\n" 933 str +=
" sp_idx = base_idx + s*param.dc.volume_4d_cb;\n" 934 str +=
" // read spinor from device memory\n" 935 str +=
" READ_SPINOR( SPINORTEX, param.sp_stride, sp_idx, sp_idx );\n\n" 936 str +=
" o00_re += factorR*(i00_re + i20_re);\n" 937 str +=
" o00_im += factorR*(i00_im + i20_im);\n" 938 str +=
" o20_re += factorR*(i00_re + i20_re);\n" 939 str +=
" o20_im += factorR*(i00_im + i20_im);\n" 940 str +=
" o01_re += factorR*(i01_re + i21_re);\n" 941 str +=
" o01_im += factorR*(i01_im + i21_im);\n" 942 str +=
" o21_re += factorR*(i01_re + i21_re);\n" 943 str +=
" o21_im += factorR*(i01_im + i21_im);\n" 944 str +=
" o02_re += factorR*(i02_re + i22_re);\n" 945 str +=
" o02_im += factorR*(i02_im + i22_im);\n" 946 str +=
" o22_re += factorR*(i02_re + i22_re);\n" 947 str +=
" o22_im += factorR*(i02_im + i22_im);\n" 948 str +=
" o10_re += factorR*(i10_re + i30_re);\n" 949 str +=
" o10_im += factorR*(i10_im + i30_im);\n" 950 str +=
" o30_re += factorR*(i10_re + i30_re);\n" 951 str +=
" o30_im += factorR*(i10_im + i30_im);\n" 952 str +=
" o11_re += factorR*(i11_re + i31_re);\n" 953 str +=
" o11_im += factorR*(i11_im + i31_im);\n" 954 str +=
" o31_re += factorR*(i11_re + i31_re);\n" 955 str +=
" o31_im += factorR*(i11_im + i31_im);\n" 956 str +=
" o12_re += factorR*(i12_re + i32_re);\n" 957 str +=
" o12_im += factorR*(i12_im + i32_im);\n" 958 str +=
" o32_re += factorR*(i12_re + i32_re);\n" 959 str +=
" o32_im += factorR*(i12_im + i32_im);\n\n" 962 str +=
" int exponent2 = coord[4] < s ? param.dc.Ls-s+coord[4] : coord[4]-s;\n" 963 str +=
" factorL = inv_d_n * POW(kappa,exponent2) * ( coord[4] < s ? -mferm : static_cast<spinorFloat>(1.0));\n\n" 965 str +=
" int exponent2 = coord[4] > s ? param.dc.Ls-coord[4]+s : s-coord[4];\n" 966 str +=
" factorL = inv_d_n * POW(kappa,exponent2) * ( coord[4] > s ? -mferm : static_cast<spinorFloat>(1.0));\n\n" 968 str +=
" o00_re += factorL*(i00_re - i20_re);\n" 969 str +=
" o00_im += factorL*(i00_im - i20_im);\n" 970 str +=
" o01_re += factorL*(i01_re - i21_re);\n" 971 str +=
" o01_im += factorL*(i01_im - i21_im);\n" 972 str +=
" o02_re += factorL*(i02_re - i22_re);\n" 973 str +=
" o02_im += factorL*(i02_im - i22_im);\n" 974 str +=
" o10_re += factorL*(i10_re - i30_re);\n" 975 str +=
" o10_im += factorL*(i10_im - i30_im);\n" 976 str +=
" o11_re += factorL*(i11_re - i31_re);\n" 977 str +=
" o11_im += factorL*(i11_im - i31_im);\n" 978 str +=
" o12_re += factorL*(i12_re - i32_re);\n" 979 str +=
" o12_im += factorL*(i12_im - i32_im);\n" 980 str +=
" o20_re += factorL*(i20_re - i00_re);\n" 981 str +=
" o20_im += factorL*(i20_im - i00_im);\n" 982 str +=
" o21_re += factorL*(i21_re - i01_re);\n" 983 str +=
" o21_im += factorL*(i21_im - i01_im);\n" 984 str +=
" o22_re += factorL*(i22_re - i02_re);\n" 985 str +=
" o22_im += factorL*(i22_im - i02_im);\n" 986 str +=
" o30_re += factorL*(i30_re - i10_re);\n" 987 str +=
" o30_im += factorL*(i30_im - i10_im);\n" 988 str +=
" o31_re += factorL*(i31_re - i11_re);\n" 989 str +=
" o31_im += factorL*(i31_im - i11_im);\n" 990 str +=
" o32_re += factorL*(i32_re - i12_re);\n" 991 str +=
" o32_im += factorL*(i32_im - i12_im);\n" 993 str +=
"} // end of M5inv dimension\n\n" 994 str +=
"#undef POW\n" 1002 if z==0:
return out_re(s,c)
1005 if z==0:
return in_re(s,c)
1006 else:
return in_im(s,c)
1020 for s
in range (0,4):
1024 return block(str)+
"\n\n" 1030 str +=
"spinorFloat "+
a_re(0,0,c)+
" = "+
out_re(1,c)+
" + "+
out_re(3,c)+
";\n" 1031 str +=
"spinorFloat "+
a_im(0,0,c)+
" = "+
out_im(1,c)+
" + "+
out_im(3,c)+
";\n" 1032 str +=
"spinorFloat "+
a_re(0,1,c)+
" = -"+
out_re(0,c)+
" - "+
out_re(2,c)+
";\n" 1033 str +=
"spinorFloat "+
a_im(0,1,c)+
" = -"+
out_im(0,c)+
" - "+
out_im(2,c)+
";\n" 1034 str +=
"spinorFloat "+
a_re(0,2,c)+
" = "+
out_re(1,c)+
" - "+
out_re(3,c)+
";\n" 1035 str +=
"spinorFloat "+
a_im(0,2,c)+
" = "+
out_im(1,c)+
" - "+
out_im(3,c)+
";\n" 1036 str +=
"spinorFloat "+
a_re(0,3,c)+
" = -"+
out_re(0,c)+
" + "+
out_re(2,c)+
";\n" 1037 str +=
"spinorFloat "+
a_im(0,3,c)+
" = -"+
out_im(0,c)+
" + "+
out_im(2,c)+
";\n" 1040 for s
in range (0,4):
1044 return block(str)+
"\n\n" 1049 str =
"READ_CLOVER(CLOVERTEX, "+`chi`+
")\n\n" 1051 for s
in range (0,2):
1052 for c
in range (0,3):
1053 str +=
"spinorFloat "+
a_re(chi,s,c)+
" = 0; spinorFloat "+
a_im(chi,s,c)+
" = 0;\n" 1056 for sm
in range (0,2):
1057 for cm
in range (0,3):
1058 for sn
in range (0,2):
1059 for cn
in range (0,3):
1060 str +=
a_re(chi,sm,cm)+
" += "+
c_re(chi,sm,cm,sn,cn)+
" * "+
out_re(2*chi+sn,cn)+
";\n" 1061 if (sn != sm)
or (cn != cm):
1062 str +=
a_re(chi,sm,cm)+
" -= "+
c_im(chi,sm,cm,sn,cn)+
" * "+
out_im(2*chi+sn,cn)+
";\n" 1064 str +=
a_im(chi,sm,cm)+
" += "+
c_re(chi,sm,cm,sn,cn)+
" * "+
out_im(2*chi+sn,cn)+
";\n" 1065 if (sn != sm)
or (cn != cm):
1066 str +=
a_im(chi,sm,cm)+
" += "+
c_im(chi,sm,cm,sn,cn)+
" * "+
out_re(2*chi+sn,cn)+
";\n" 1070 for s
in range (0,2):
1071 for c
in range (0,3):
1072 str +=
out_re(2*chi+s,c)+
" = "+
a_re(chi,s,c)+
"; " 1073 str +=
out_im(2*chi+s,c)+
" = "+
a_im(chi,s,c)+
";\n" 1076 return block(str)+
"\n\n" 1081 if domain_wall:
return "" 1083 if dslash: str +=
"#ifdef DSLASH_CLOVER\n\n" 1084 str +=
"// change to chiral basis\n" 1086 str +=
"// apply first chiral block\n" 1088 str +=
"// apply second chiral block\n" 1090 str +=
"// change back from chiral basis\n" 1091 str +=
"// (note: required factor of 1/2 is included in clover term normalization)\n" 1093 if dslash: str +=
"#endif // DSLASH_CLOVER\n\n" 1102 elif dslash5
or dslash5inv:
1111 str +=
"#ifdef SPINOR_DOUBLE\n" 1113 for s
in range(0,4):
1114 for c
in range(0,3):
1121 for s
in range(0,4):
1122 for c
in range(0,3):
1127 str +=
"#endif // SPINOR_DOUBLE\n" 1133 str +=
"#ifdef DSLASH_XPAY\n" 1134 str +=
"READ_ACCUM(ACCUMTEX, param.sp_stride)\n" 1137 str +=
"VOLATILE spinorFloat coeff;\n\n" 1138 str +=
"#ifdef MDWF_mode\n" 1139 str +=
"coeff = static_cast<spinorFloat>(0.5)*a/(mdwf_b5[coord[4]]*(m5+static_cast<spinorFloat>(4.0)) + static_cast<spinorFloat>(1.0));\n" 1141 str +=
"coeff = a;\n" 1143 elif dslash5
or dslash5inv:
1144 str +=
"VOLATILE spinorFloat coeff;\n\n" 1145 str +=
"#ifdef MDWF_mode\n" 1146 str +=
"coeff = static_cast<spinorFloat>(0.5)/(mdwf_b5[coord[4]]*(m5+static_cast<spinorFloat>(4.0)) + static_cast<spinorFloat>(1.0));\n" 1147 str +=
"coeff *= coeff;\n" 1148 str +=
"coeff *= a;\n" 1151 str +=
"coeff = a;\n" 1153 str +=
"coeff = b;\n" 1156 str +=
"#ifdef YPAX\n" 1160 str +=
"#ifdef SPINOR_DOUBLE\n" 1162 for s
in range(0,4):
1163 for c
in range(0,3):
1170 for s
in range(0,4):
1171 for c
in range(0,3):
1176 str +=
"#endif // SPINOR_DOUBLE\n" 1178 str +=
"#endif // YPAX\n" 1179 str +=
"#endif // DSLASH_XPAY\n" 1190 str +=
"#if defined MULTI_GPU && defined DSLASH_XPAY\n" 1192 str +=
"#if defined MULTI_GPU && (defined DSLASH_XPAY || defined DSLASH_CLOVER)\n" 1196 int incomplete = 0; // Have all 8 contributions been computed for this site? 1198 switch(kernel_type) { // intentional fall-through 1199 case INTERIOR_KERNEL: 1200 incomplete = incomplete || (param.commDim[3] && (coord[3]==0 || coord[3]==(param.dc.X[3]-1))); 1201 case EXTERIOR_KERNEL_T: 1202 incomplete = incomplete || (param.commDim[2] && (coord[2]==0 || coord[2]==(param.dc.X[2]-1))); 1203 case EXTERIOR_KERNEL_Z: 1204 incomplete = incomplete || (param.commDim[1] && (coord[1]==0 || coord[1]==(param.dc.X[1]-1))); 1205 case EXTERIOR_KERNEL_Y: 1206 incomplete = incomplete || (param.commDim[0] && (coord[0]==0 || coord[0]==(param.dc.X[0]-1))); 1210 str +=
"if (!incomplete)\n" 1211 str +=
"#endif // MULTI_GPU\n" 1216 str +=
"// write spinor field back to device memory\n" 1217 str +=
"WRITE_SPINOR(param.sp_stride);\n\n" 1219 str +=
"// undefine to prevent warning when precision is changed\n" 1220 str +=
"#undef m5\n" 1221 str +=
"#undef mdwf_b5\n" 1222 str +=
"#undef mdwf_c5\n" 1223 str +=
"#undef mferm\n" 1226 str +=
"#undef spinorFloat\n" 1227 str +=
"#undef POW\n" 1228 str +=
"#undef SHARED_STRIDE\n\n" 1232 for m
in range(0,3):
1233 for n
in range(0,3):
1235 str +=
"#undef "+
g_re(0,m,n)+
"\n" 1236 str +=
"#undef "+
g_im(0,m,n)+
"\n" 1239 for s
in range(0,4):
1240 for c
in range(0,3):
1242 str +=
"#undef "+
in_re(s,c)+
"\n" 1243 str +=
"#undef "+
in_im(s,c)+
"\n" 1247 for m
in range(0,6):
1250 str +=
"#undef "+
c_re(0,s,c,s,c)+
"\n" 1251 for n
in range(0,6):
1254 for m
in range(n+1,6):
1257 str +=
"#undef "+
c_re(0,sm,cm,sn,cn)+
"\n" 1258 str +=
"#undef "+
c_im(0,sm,cm,sn,cn)+
"\n" 1261 for s
in range(0,4):
1262 for c
in range(0,3):
1264 if 2*i < sharedFloats:
1265 str +=
"#undef "+
out_re(s,c)+
"\n" 1266 if 2*i+1 < sharedFloats:
1267 str +=
"#undef "+
out_im(s,c)+
"\n" 1270 str +=
"#undef VOLATILE\n" 1278 str +=
"switch(dim) {\n" 1279 for dim
in range(0,4):
1280 str +=
"case "+`dim`+
":\n" 1281 proj =
gen(2*dim+facenum, pack_only=
True)
1283 proj +=
"// write half spinor back to device memory\n" 1284 proj +=
"WRITE_HALF_SPINOR(face_volume, face_idx);\n" 1291 assert (sharedFloats == 0)
1294 str +=
"#include \"io_spinor.h\"\n\n" 1296 str +=
"if (face_num) " 1302 str +=
"// undefine to prevent warning when precision is changed\n" 1303 str +=
"#undef spinorFloat\n" 1304 str +=
"#undef SHARED_STRIDE\n\n" 1306 for s
in range(0,4):
1307 for c
in range(0,3):
1309 str +=
"#undef "+
in_re(s,c)+
"\n" 1310 str +=
"#undef "+
in_im(s,c)+
"\n" 1318 for i
in range(0,8) :
1327 for i
in range(0,8) :
1350 cloverSharedFloats = 0
1351 if(
len(sys.argv) > 1):
1352 if (sys.argv[1] ==
'--shared'):
1353 sharedFloats =
int(sys.argv[2])
1354 print "Shared floats set to " + str(sharedFloats);
1364 print sys.argv[0] +
": generating dw_dslash4_core.h";
1367 f = open(
'dslash_core/dw_dslash4_core.h',
'w')
1371 print sys.argv[0] +
": generating dw_dslash4_dagger_core.h";
1374 f = open(
'dslash_core/dw_dslash4_dagger_core.h',
'w')
1381 print sys.argv[0] +
": generating dw_dslash5_core.h";
1384 f = open(
'dslash_core/dw_dslash5_core.h',
'w')
1388 print sys.argv[0] +
": generating dw_dslash5_dagger_core.h";
1391 f = open(
'dslash_core/dw_dslash5_dagger_core.h',
'w')
1398 print sys.argv[0] +
": generating dw_dslash5inv_core.h";
1401 f = open(
'dslash_core/dw_dslash5inv_core.h',
'w')
1405 print sys.argv[0] +
": generating dw_dslash5inv_dagger_core.h";
1408 f = open(
'dslash_core/dw_dslash5inv_dagger_core.h',
'w')
def gen(dir, pack_only=False)
def generate_dslash5D_inv()
def input_spinor(s, c, z)
def c_im(b, sm, cm, sn, cn)
def complexify(a)
complex numbers ######################################################################## ...
def c_re(b, sm, cm, sn, cn)
def indent(code, n=1)
code generation ######################################################################## ...