// QUDA v0.4.0
// A library for QCD on GPUs
// Macro KERNEL_ENABLED is used to control compile time (debug purposes only).
// EXT is the precision/reconstruction suffix handed to HISQ_KERNEL_NAME; a
// kernel body is compiled only when the matching COMPILE_HISQ_* macro is set,
// otherwise the kernel is an empty stub.
#if (PRECISION == 0 && RECON == 18)
#define EXT _dp_18_
#ifdef COMPILE_HISQ_DP_18
#define KERNEL_ENABLED
#endif
#elif (PRECISION == 0 && RECON == 12)
#define EXT _dp_12_
#ifdef COMPILE_HISQ_DP_12
#define KERNEL_ENABLED
#endif
#elif (PRECISION == 1 && RECON == 18)
#define EXT _sp_18_
#ifdef COMPILE_HISQ_SP_18
#define KERNEL_ENABLED
#endif
#else
#define EXT _sp_12_
#ifdef COMPILE_HISQ_SP_12
#define KERNEL_ENABLED
#endif
#endif


/**************************do_middle_link_kernel*****************************
 *
 *
 * Generally we need
 * READ
 *    3 LINKS:         ab_link,     bc_link,    ad_link
 *    3 COLOR MATRIX:  newOprod_at_A, oprod_at_C,  Qprod_at_D
 * WRITE
 *    4 COLOR MATRIX:  newOprod_at_A, P3_at_A, Pmu_at_B, Qmu_at_A
 *
 * Three call variations:
 *   1. when Qprev == NULL:   Qprod_at_D does not exist and is not read in
 *   2. full read/write
 *   3. when Pmu/Qmu == NULL, Pmu_at_B and Qmu_at_A are not written out
 *
 *   In all three cases above, if the direction sig is negative, newOprod_at_A is
 *   not read in or written out.
 *
 * Therefore the data traffic, in two-number pair (num_of_link, num_of_color_matrix)
 *   Call 1:  (called 48 times, half positive sig, half negative sig)
 *             if (sig is positive):    (3, 6)
 *             else               :     (3, 4)
 *   Call 2:  (called 192 times, half positive sig, half negative sig)
 *             if (sig is positive):    (3, 7)
 *             else               :     (3, 5)
 *   Call 3:  (called 48 times, half positive sig, half negative sig)
 *             if (sig is positive):    (3, 5)
 *             else               :     (3, 2) no need to load Qprod_at_D in this case
 *
 * note: oprod_at_C could actually be read in from D when it is the fresh outer product
 *       and we call it oprod_at_C to simplify naming. This does not affect our data traffic analysis
 *
 * Flop count, in two-number pair (matrix_multi, matrix_add)
 *   call 1:     if (sig is positive)  (3, 1)
 *               else                  (2, 0)
 *   call 2:     if (sig is positive)  (4, 1)
 *               else                  (3, 0)
 *   call 3:     if (sig is positive)  (4, 1)
 *               else                  (2, 0)
 *
 * Indexing: one thread per site, sid = blockIdx.x*blockDim.x + threadIdx.x,
 * with the parity of the site fixed by the oddBit template argument.
 * NOTE(review): there is no bounds check on sid; presumably the launch grid
 * covers the half-lattice exactly -- confirm against the caller.
 * NOTE(review): the FF_*, HISQ_LOAD_LINK, COMPUTE_LINK_SIGN and
 * RECONSTRUCT_SITE_LINK macros are defined elsewhere. The *_UPDATE macros
 * appear to advance new_x as a side effect, and the *_link_sign variables are
 * only declared for RECON == 12, so local names and statement order must not
 * be changed without checking the macro definitions.
 *
 ****************************************************************************/
template<class RealA, class RealB, int sig_positive, int mu_positive, int oddBit>
  __global__ void
  HISQ_KERNEL_NAME(do_middle_link, EXT)(const RealA* const oprodEven, const RealA* const oprodOdd,
      const RealA* const QprevEven, const RealA* const QprevOdd,
      const RealB* const linkEven,  const RealB* const linkOdd,
      int sig, int mu,
      typename RealTypeId<RealA>::Type coeff,
      RealA* const PmuEven, RealA* const PmuOdd,
      RealA* const P3Even, RealA* const P3Odd,
      RealA* const QmuEven, RealA* const QmuOdd,
      RealA* const newOprodEven, RealA* const newOprodOdd)
{

#ifdef KERNEL_ENABLED
  int sid = blockIdx.x * blockDim.x + threadIdx.x;

  // Decompose the checkerboard index sid into 4-d site coordinates x[0..3].
  int x[4];
  int z1 = sid/X1h;
  int x1h = sid - z1*X1h;
  int z2 = z1/X2;
  x[1] = z1 - z2*X2;
  x[3] = z2/X3;
  x[2] = z2 - x[3]*X3;
  int x1odd = (x[1] + x[2] + x[3] + oddBit) & 1;
  x[0] = 2*x1h + x1odd;
  int X = 2*sid + x1odd;   // full (non-checkerboarded) index of the current site

  int new_x[4];
  int new_mem_idx;
#if(RECON == 12)
  int ad_link_sign;
  int ab_link_sign;
  int bc_link_sign;
#endif

  RealA ab_link[ArrayLength<RealA>::result];
  RealA bc_link[ArrayLength<RealA>::result];
  RealA ad_link[ArrayLength<RealA>::result];

  RealA COLOR_MAT_W[ArrayLength<RealA>::result];
  RealA COLOR_MAT_Y[ArrayLength<RealA>::result];
  RealA COLOR_MAT_X[ArrayLength<RealA>::result];

  /*        A________B
   *   mu   |        |
   *       D|        |C
   *
   *   A is the current point (sid)
   *
   */

  int point_b, point_c, point_d;
  int ad_link_nbr_idx, ab_link_nbr_idx, bc_link_nbr_idx;
  int mymu;

  new_x[0] = x[0];
  new_x[1] = x[1];
  new_x[2] = x[2];
  new_x[3] = x[3];

  // Step from A to D (one hop against the mu direction).
  if(mu_positive){
    mymu = mu;
    FF_COMPUTE_NEW_FULL_IDX_MINUS_UPDATE(mu, X, new_mem_idx);
  }else{
    mymu = OPP_DIR(mu);
    FF_COMPUTE_NEW_FULL_IDX_PLUS_UPDATE(OPP_DIR(mu), X, new_mem_idx);
  }
  point_d = (new_mem_idx >> 1);
  if (mu_positive){
    ad_link_nbr_idx = point_d;
    COMPUTE_LINK_SIGN(&ad_link_sign, mymu, new_x);
  }else{
    ad_link_nbr_idx = sid;
    COMPUTE_LINK_SIGN(&ad_link_sign, mymu, x);
  }

  // Step from D to C (one hop along the sig direction).
  int mysig;
  if(sig_positive){
    mysig = sig;
    FF_COMPUTE_NEW_FULL_IDX_PLUS_UPDATE(sig, new_mem_idx, new_mem_idx);
  }else{
    mysig = OPP_DIR(sig);
    FF_COMPUTE_NEW_FULL_IDX_MINUS_UPDATE(OPP_DIR(sig), new_mem_idx, new_mem_idx);
  }
  point_c = (new_mem_idx >> 1);
  if (mu_positive){
    bc_link_nbr_idx = point_c;
    COMPUTE_LINK_SIGN(&bc_link_sign, mymu, new_x);
  }

  // Restart from A to locate B.
  new_x[0] = x[0];
  new_x[1] = x[1];
  new_x[2] = x[2];
  new_x[3] = x[3];

  if(sig_positive){
    FF_COMPUTE_NEW_FULL_IDX_PLUS_UPDATE(sig, X, new_mem_idx);
  }else{
    FF_COMPUTE_NEW_FULL_IDX_MINUS_UPDATE(OPP_DIR(sig), X, new_mem_idx);
  }
  point_b = (new_mem_idx >> 1);

  if (!mu_positive){
    bc_link_nbr_idx = point_b;
    COMPUTE_LINK_SIGN(&bc_link_sign, mymu, new_x);
  }

  if(sig_positive){
    ab_link_nbr_idx = sid;
    COMPUTE_LINK_SIGN(&ab_link_sign, mysig, x);
  }else{
    ab_link_nbr_idx = point_b;
    COMPUTE_LINK_SIGN(&ab_link_sign, mysig, new_x);
  }
  // now we have ab_link_nbr_idx


  // load the link variable connecting a and b
  // Store in ab_link
  if(sig_positive){
    HISQ_LOAD_LINK(linkEven, linkOdd, mysig, ab_link_nbr_idx, ab_link, oddBit);
  }else{
    HISQ_LOAD_LINK(linkEven, linkOdd, mysig, ab_link_nbr_idx, ab_link, 1-oddBit);
  }
  RECONSTRUCT_SITE_LINK(ab_link, ab_link_sign)

  // load the link variable connecting b and c
  // Store in bc_link
  if(mu_positive){
    HISQ_LOAD_LINK(linkEven, linkOdd, mymu, bc_link_nbr_idx, bc_link, oddBit);
  }else{
    HISQ_LOAD_LINK(linkEven, linkOdd, mymu, bc_link_nbr_idx, bc_link, 1-oddBit);
  }
  RECONSTRUCT_SITE_LINK(bc_link, bc_link_sign)

  // Without a previous Q field the outer product is read per direction
  // (from D for positive sig, from C and adjointed for negative sig).
  if(QprevOdd == NULL){
    if(sig_positive){
      loadMatrixFromField(oprodEven, oprodOdd, sig, point_d, COLOR_MAT_Y, 1-oddBit);
    }else{
      loadMatrixFromField(oprodEven, oprodOdd, OPP_DIR(sig), point_c, COLOR_MAT_Y, oddBit);
      adjointMatrix(COLOR_MAT_Y);
    }
  }else{ // QprevOdd != NULL
    loadMatrixFromField(oprodEven, oprodOdd, point_c, COLOR_MAT_Y, oddBit);
  }


  MATRIX_PRODUCT(bc_link, COLOR_MAT_Y, !mu_positive, COLOR_MAT_W);
  if(PmuOdd){
    // Pmu is optional; it is written at B only when an output field is supplied.
    storeMatrixToField(COLOR_MAT_W, point_b, PmuEven, PmuOdd, 1-oddBit);
  }
  MATRIX_PRODUCT(ab_link, COLOR_MAT_W, sig_positive,COLOR_MAT_Y);
  storeMatrixToField(COLOR_MAT_Y, sid, P3Even, P3Odd, oddBit);


  if(mu_positive){
    HISQ_LOAD_LINK(linkEven, linkOdd, mymu, ad_link_nbr_idx, ad_link, 1-oddBit);
    RECONSTRUCT_SITE_LINK(ad_link, ad_link_sign)
  }else{
    HISQ_LOAD_LINK(linkEven, linkOdd, mymu, ad_link_nbr_idx, ad_link, oddBit);
    RECONSTRUCT_SITE_LINK(ad_link, ad_link_sign)
    adjointMatrix(ad_link);

  }


  if(QprevOdd == NULL){
    if(sig_positive){
      MAT_MUL_MAT(COLOR_MAT_W, ad_link, COLOR_MAT_Y);
    }
    if(QmuEven){
      // No previous Q: Qmu at A is just the (possibly adjointed) ad_link.
      ASSIGN_MAT(ad_link, COLOR_MAT_X);
      storeMatrixToField(COLOR_MAT_X, sid, QmuEven, QmuOdd, oddBit);
    }
  }else{
    if(QmuEven || sig_positive){
      // Qprev at D is needed only if Qmu is written or newOprod is updated.
      loadMatrixFromField(QprevEven, QprevOdd, point_d, COLOR_MAT_Y, 1-oddBit);
      MAT_MUL_MAT(COLOR_MAT_Y, ad_link, COLOR_MAT_X);
    }
    if(QmuEven){
      storeMatrixToField(COLOR_MAT_X, sid, QmuEven, QmuOdd, oddBit);
    }
    if(sig_positive){
      MAT_MUL_MAT(COLOR_MAT_W, COLOR_MAT_X, COLOR_MAT_Y);
    }
  }

  // newOprod at A is only touched for positive sig.
  if(sig_positive){
    //addMatrixToField(COLOR_MAT_Y, sig, sid, coeff, newOprodEven, newOprodOdd, oddBit);
    addMatrixToNewOprod(COLOR_MAT_Y, sig, sid, coeff, newOprodEven, newOprodOdd, oddBit);
  }

#endif
  return;
}

/***********************************do_side_link_kernel***************************
 *
 * In general we need
 * READ
 *    1  LINK:          ad_link
 *    4  COLOR MATRIX:  shortP_at_D, newOprod, P3_at_A, Qprod_at_D,
 * WRITE
 *    2  COLOR MATRIX:  shortP_at_D, newOprod,
 *
 * Two call variations:
 *   1. full read/write
 *   2. when shortP == NULL && Qprod == NULL:
 *          no need to read ad_link/shortP_at_D or write shortP_at_D
 *          Qprod_at_D does not exist and is not read in
 *
 *
 * Therefore the data traffic, in two-number pair (num_of_links, num_of_color_matrix)
 *   Call 1:   (called 192 times)
 *                           (1, 6)
 *
 *   Call 2:   (called 48 times)
 *                           (0, 3)
 *
 * note: newOprod can be at point D or A, depending on if mu is positive or negative
 *
 * Flop count, in two-number pair (matrix_multi, matrix_add)
 *   call 1:       (2, 2)
 *   call 2:       (0, 1)
 *
 * Indexing: one thread per site, sid = blockIdx.x*blockDim.x + threadIdx.x.
 * NOTE(review): no bounds check on sid -- confirm the launch configuration.
 *
 *********************************************************************************/

template<class RealA, class RealB, int sig_positive, int mu_positive, int oddBit>
  __global__ void
  HISQ_KERNEL_NAME(do_side_link, EXT)(const RealA* const P3Even, const RealA* const P3Odd,
      const RealA* const QprodEven, const RealA* const QprodOdd,
      const RealB* const linkEven,  const RealB* const linkOdd,
      int sig, int mu,
      typename RealTypeId<RealA>::Type coeff,
      typename RealTypeId<RealA>::Type accumu_coeff,
      RealA* const shortPEven, RealA* const shortPOdd,
      RealA* const newOprodEven, RealA* const newOprodOdd)
{
#ifdef KERNEL_ENABLED

  int sid = blockIdx.x * blockDim.x + threadIdx.x;

  // Decompose the checkerboard index sid into 4-d site coordinates x[0..3].
  int x[4];
  int z1 = sid/X1h;
  int x1h = sid - z1*X1h;
  int z2 = z1/X2;
  x[1] = z1 - z2*X2;
  x[3] = z2/X3;
  x[2] = z2 - x[3]*X3;
  int x1odd = (x[1] + x[2] + x[3] + oddBit) & 1;
  x[0] = 2*x1h + x1odd;
  int X = 2*sid + x1odd;   // full (non-checkerboarded) index of the current site

#if(RECON == 12)
  int ad_link_sign;
#endif


  RealA ad_link[ArrayLength<RealA>::result];

  RealA COLOR_MAT_W[ArrayLength<RealA>::result];
  RealA COLOR_MAT_X[ArrayLength<RealA>::result];
  RealA COLOR_MAT_Y[ArrayLength<RealA>::result];
  // The compiler probably knows to reorder so that loads are done early on
  loadMatrixFromField(P3Even, P3Odd, sid, COLOR_MAT_Y, oddBit);

  /*      compute the side link contribution to the momentum
   *
   *             sig
   *          A________B
   *           |       |   mu
   *         D |       |C
   *
   *      A is the current point (sid)
   *
   */

  typename RealTypeId<RealA>::Type mycoeff;
  int point_d;
  int ad_link_nbr_idx;
  int mymu;
  int new_mem_idx;

  int new_x[4];
  new_x[0] = x[0];
  new_x[1] = x[1];
  new_x[2] = x[2];
  new_x[3] = x[3];

  // Step from A to D (one hop against the mu direction).
  if(mu_positive){
    mymu=mu;
    FF_COMPUTE_NEW_FULL_IDX_MINUS_UPDATE(mymu,X, new_mem_idx);
  }else{
    mymu = OPP_DIR(mu);
    FF_COMPUTE_NEW_FULL_IDX_PLUS_UPDATE(mymu, X, new_mem_idx);
  }
  point_d = (new_mem_idx >> 1);


  // Should all be inside if (shortPOdd)
  if (shortPOdd){
    if (mu_positive){
      ad_link_nbr_idx = point_d;
      COMPUTE_LINK_SIGN(&ad_link_sign, mymu, new_x);
    }else{
      ad_link_nbr_idx = sid;
      COMPUTE_LINK_SIGN(&ad_link_sign, mymu, x);
    }


    if(mu_positive){
      HISQ_LOAD_LINK(linkEven, linkOdd, mymu, ad_link_nbr_idx, ad_link, 1-oddBit);
    }else{
      HISQ_LOAD_LINK(linkEven, linkOdd, mymu, ad_link_nbr_idx, ad_link, oddBit);
    }
    RECONSTRUCT_SITE_LINK(ad_link, ad_link_sign)

    // Accumulate ad_link (x) P3_at_A into shortP at D, scaled by accumu_coeff.
    MATRIX_PRODUCT(ad_link, COLOR_MAT_Y, mu_positive, COLOR_MAT_W);
    addMatrixToField(COLOR_MAT_W, point_d, accumu_coeff, shortPEven, shortPOdd, 1-oddBit);
  }


  mycoeff = CoeffSign<sig_positive,oddBit>::result*coeff;

  if(QprodOdd){
    loadMatrixFromField(QprodEven, QprodOdd, point_d, COLOR_MAT_X, 1-oddBit);
    if(mu_positive){
      MAT_MUL_MAT(COLOR_MAT_Y, COLOR_MAT_X, COLOR_MAT_W);

      // Added by J.F.
      if(!oddBit){ mycoeff = -mycoeff; }
      addMatrixToField(COLOR_MAT_W, mu, point_d, mycoeff, newOprodEven, newOprodOdd, 1-oddBit);
    }else{
      ADJ_MAT_MUL_ADJ_MAT(COLOR_MAT_X, COLOR_MAT_Y, COLOR_MAT_W);
      if(oddBit){ mycoeff = -mycoeff; }
      addMatrixToField(COLOR_MAT_W, OPP_DIR(mu), sid, mycoeff, newOprodEven, newOprodOdd, oddBit);
    }
  }

  // No Qprod field: P3_at_A (or its adjoint) is added to newOprod directly.
  if(!QprodOdd){
    if(mu_positive){
      if(!oddBit){ mycoeff = -mycoeff;}
      //addMatrixToField(COLOR_MAT_Y, mu, point_d, mycoeff, newOprodEven, newOprodOdd, 1-oddBit);
      addMatrixToNewOprod(COLOR_MAT_Y, mu, point_d, mycoeff, newOprodEven, newOprodOdd, 1-oddBit);
    }else{
      if(oddBit){ mycoeff = -mycoeff; }
      ADJ_MAT(COLOR_MAT_Y, COLOR_MAT_W);
      //addMatrixToField(COLOR_MAT_W, OPP_DIR(mu), sid, mycoeff, newOprodEven, newOprodOdd, oddBit);
      addMatrixToNewOprod(COLOR_MAT_W, OPP_DIR(mu), sid, mycoeff, newOprodEven, newOprodOdd, oddBit);
    }
  }
#endif
  return;
}

/********************************do_all_link_kernel*********************************************
 *
 * In this function we need
 *   READ
 *     3 LINKS:         ad_link, ab_link, bc_link
 *     5 COLOR MATRIX:  Qprev_at_D, oprod_at_C, newOprod_at_A(sig), newOprod_at_D/newOprod_at_A(mu), shortP_at_D
 *   WRITE:
 *     3 COLOR MATRIX:  newOprod_at_A(sig), newOprod_at_D/newOprod_at_A(mu), shortP_at_D,
 *
 * If sig is negative, then we don't need to read/write the color matrix newOprod_at_A(sig)
 *
 * Therefore the data traffic, in two-number pair (num_of_link, num_of_color_matrix)
 *
 *             if (sig is positive):    (3, 8)
 *             else               :     (3, 6)
 *
 * This function is called 384 times, half positive sig, half negative sig
 *
 * Flop count, in two-number pair (matrix_multi, matrix_add)
 *             if(sig is positive)      (6,3)
 *             else                     (4,2)
 *
 * Indexing: one thread per site, sid = blockIdx.x*blockDim.x + threadIdx.x.
 * NOTE(review): no bounds check on sid -- confirm the launch configuration.
 * NOTE(review): the site coordinates are held in short here (int in the other
 * kernels); presumably a register-pressure choice -- confirm lattice extents fit.
 *
 ************************************************************************************************/

template<class RealA, class RealB, short sig_positive, short mu_positive, short oddBit>
  __global__ void
  HISQ_KERNEL_NAME(do_all_link, EXT)(const RealA* const oprodEven, const RealA* const oprodOdd,
      const RealA* const QprevEven, const RealA* const QprevOdd,
      const RealB* const linkEven, const RealB* const linkOdd,
      short sig, short mu,
      typename RealTypeId<RealA>::Type coeff,
      typename RealTypeId<RealA>::Type accumu_coeff,
      RealA* const shortPEven, RealA* const shortPOdd,
      RealA* const newOprodEven, RealA* const newOprodOdd)
{
#ifdef KERNEL_ENABLED
  int sid = blockIdx.x * blockDim.x + threadIdx.x;
  // Decompose the checkerboard index sid into 4-d site coordinates x[0..3].
  short x[4];
  int z1 = sid/X1h;
  short x1h = sid - z1*X1h;
  int z2 = z1/X2;
  x[1] = z1 - z2*X2;
  x[3] = z2/X3;
  x[2] = z2 - x[3]*X3;
  short x1odd = (x[1] + x[2] + x[3] + oddBit) & 1;
  x[0] = 2*x1h + x1odd;
  int X = 2*sid + x1odd;   // full (non-checkerboarded) index of the current site

#if(RECON == 12)
  int ad_link_sign;
  int ab_link_sign;
  int bc_link_sign;
#endif

  short new_x[4];

  RealA ab_link[ArrayLength<RealA>::result];
  RealA bc_link[ArrayLength<RealA>::result];
  RealA ad_link[ArrayLength<RealA>::result];

  RealA COLOR_MAT_X[ArrayLength<RealA>::result];
  RealA COLOR_MAT_Y[ArrayLength<RealA>::result];
  RealA COLOR_MAT_Z[ArrayLength<RealA>::result];
  RealA COLOR_MAT_W[ArrayLength<RealA>::result];


  /*            sig
   *         A________B
   *      mu  |      |
   *        D |      |C
   *
   *   A is the current point (sid)
   *
   */

  int point_b, point_c, point_d;
  int ab_link_nbr_idx;
  int new_mem_idx;
  new_x[0] = x[0];
  new_x[1] = x[1];
  new_x[2] = x[2];
  new_x[3] = x[3];

  // Step from A to B (one hop along the sig direction).
  if(sig_positive){
    FF_COMPUTE_NEW_FULL_IDX_PLUS_UPDATE(sig, X, new_mem_idx);
  }else{
    FF_COMPUTE_NEW_FULL_IDX_MINUS_UPDATE(OPP_DIR(sig), X, new_mem_idx);
  }
  point_b = (new_mem_idx >> 1);
  ab_link_nbr_idx = (sig_positive) ? sid : point_b;
  if(sig_positive){
    COMPUTE_LINK_SIGN(&ab_link_sign, sig, x);
  }else{
    COMPUTE_LINK_SIGN(&ab_link_sign, OPP_DIR(sig), new_x);
  }
  // For negative mu the bc link sign is taken here, while new_x still holds B.
  if(!mu_positive){
    COMPUTE_LINK_SIGN(&bc_link_sign, OPP_DIR(mu), new_x);
  }
  new_x[0] = x[0];
  new_x[1] = x[1];
  new_x[2] = x[2];
  new_x[3] = x[3];


  const typename RealTypeId<RealA>::Type & mycoeff = CoeffSign<sig_positive,oddBit>::result*coeff;
  if(mu_positive){ //positive mu
    FF_COMPUTE_NEW_FULL_IDX_MINUS_UPDATE(mu, X, new_mem_idx);
    point_d = (new_mem_idx >> 1);
    loadMatrixFromField(QprevEven, QprevOdd, point_d, COLOR_MAT_X, 1-oddBit);   // COLOR_MAT_X
    COMPUTE_LINK_SIGN(&ad_link_sign, mu, new_x);
    HISQ_LOAD_LINK(linkEven, linkOdd, mu, point_d, ad_link, 1-oddBit);
    RECONSTRUCT_SITE_LINK(ad_link, ad_link_sign)

    if(sig_positive){
      FF_COMPUTE_NEW_FULL_IDX_PLUS_UPDATE(sig, new_mem_idx, new_mem_idx);
    }else{
      FF_COMPUTE_NEW_FULL_IDX_MINUS_UPDATE(OPP_DIR(sig), new_mem_idx, new_mem_idx);
    }
    point_c = (new_mem_idx >> 1);
    loadMatrixFromField(oprodEven,oprodOdd,  point_c, COLOR_MAT_Y, oddBit);     // COLOR_MAT_Y
    HISQ_LOAD_LINK(linkEven, linkOdd, mu, point_c, bc_link, oddBit);
    COMPUTE_LINK_SIGN(&bc_link_sign, mu, new_x);
    RECONSTRUCT_SITE_LINK(bc_link, bc_link_sign)

    MATRIX_PRODUCT(bc_link, COLOR_MAT_Y, 0, COLOR_MAT_Z); // COMPUTE_LINK_X


    if (sig_positive)
    {
      MAT_MUL_MAT(COLOR_MAT_X, ad_link, COLOR_MAT_Y);
      MAT_MUL_MAT(COLOR_MAT_Z, COLOR_MAT_Y, COLOR_MAT_W);
      //addMatrixToField(COLOR_MAT_W, sig, sid, Sign<oddBit>::result*mycoeff, newOprodEven, newOprodOdd, oddBit);
      addMatrixToNewOprod(COLOR_MAT_W, sig, sid, Sign<oddBit>::result*mycoeff, newOprodEven, newOprodOdd, oddBit);
    }

    if (sig_positive){
      HISQ_LOAD_LINK(linkEven, linkOdd, sig, ab_link_nbr_idx, ab_link, oddBit);
    }else{
      HISQ_LOAD_LINK(linkEven, linkOdd, OPP_DIR(sig), ab_link_nbr_idx, ab_link, 1-oddBit);
    }
    RECONSTRUCT_SITE_LINK(ab_link, ab_link_sign)

    MATRIX_PRODUCT(ab_link, COLOR_MAT_Z, sig_positive, COLOR_MAT_Y); // COLOR_MAT_Y is assigned here

    MAT_MUL_MAT(COLOR_MAT_Y, COLOR_MAT_X, COLOR_MAT_W);
    //addMatrixToField(COLOR_MAT_W, mu, point_d, -Sign<oddBit>::result*mycoeff, newOprodEven, newOprodOdd, 1-oddBit);
    addMatrixToNewOprod(COLOR_MAT_W, mu, point_d, -Sign<oddBit>::result*mycoeff, newOprodEven, newOprodOdd, 1-oddBit);

    MAT_MUL_MAT(ad_link, COLOR_MAT_Y, COLOR_MAT_W);
    addMatrixToField(COLOR_MAT_W, point_d, accumu_coeff, shortPEven, shortPOdd, 1-oddBit);
  } else{ //negative mu
    // From here on mu holds the corresponding positive direction.
    mu = OPP_DIR(mu);

    new_x[0] = x[0];
    new_x[1] = x[1];
    new_x[2] = x[2];
    new_x[3] = x[3];
    FF_COMPUTE_NEW_FULL_IDX_PLUS_UPDATE(mu, X, new_mem_idx);
    point_d = (new_mem_idx >> 1);
    loadMatrixFromField(QprevEven, QprevOdd, point_d, COLOR_MAT_X, 1-oddBit);         // COLOR_MAT_X used!
    HISQ_LOAD_LINK(linkEven, linkOdd, mu, sid, ad_link, oddBit);
    COMPUTE_LINK_SIGN(&ad_link_sign, mu, x);
    RECONSTRUCT_SITE_LINK(ad_link, ad_link_sign)

    if(sig_positive){
      FF_COMPUTE_NEW_FULL_IDX_PLUS_UPDATE(sig, new_mem_idx, new_mem_idx);
    }else{
      FF_COMPUTE_NEW_FULL_IDX_MINUS_UPDATE(OPP_DIR(sig), new_mem_idx, new_mem_idx);
    }
    point_c = (new_mem_idx >> 1);
    loadMatrixFromField(oprodEven, oprodOdd, point_c, COLOR_MAT_Y, oddBit);             // COLOR_MAT_Y used
    HISQ_LOAD_LINK(linkEven, linkOdd, mu, point_b, bc_link, 1-oddBit);
    RECONSTRUCT_SITE_LINK(bc_link, bc_link_sign) //bc_link_sign is computed earlier in the function

    if(sig_positive){
      MAT_MUL_ADJ_MAT(COLOR_MAT_X, ad_link, COLOR_MAT_W);
    }
    MAT_MUL_MAT(bc_link, COLOR_MAT_Y, COLOR_MAT_Z);
    if (sig_positive){
      MAT_MUL_MAT(COLOR_MAT_Z, COLOR_MAT_W, COLOR_MAT_Y);
      //addMatrixToField(COLOR_MAT_Y, sig, sid, Sign<oddBit>::result*mycoeff, newOprodEven, newOprodOdd, oddBit);
      addMatrixToNewOprod(COLOR_MAT_Y, sig, sid, Sign<oddBit>::result*mycoeff, newOprodEven, newOprodOdd, oddBit);
    }

    if (sig_positive){
      HISQ_LOAD_LINK(linkEven, linkOdd, sig, ab_link_nbr_idx, ab_link, oddBit);
    }else{
      HISQ_LOAD_LINK(linkEven, linkOdd, OPP_DIR(sig), ab_link_nbr_idx, ab_link, 1-oddBit);
    }
    RECONSTRUCT_SITE_LINK(ab_link, ab_link_sign)

    MATRIX_PRODUCT(ab_link, COLOR_MAT_Z, sig_positive, COLOR_MAT_Y);
    ADJ_MAT_MUL_ADJ_MAT(COLOR_MAT_X, COLOR_MAT_Y, COLOR_MAT_W);
    //addMatrixToField(COLOR_MAT_W, mu, sid, Sign<oddBit>::result*mycoeff, newOprodEven, newOprodOdd, oddBit);
    addMatrixToNewOprod(COLOR_MAT_W, mu, sid, Sign<oddBit>::result*mycoeff, newOprodEven, newOprodOdd, oddBit);

    MATRIX_PRODUCT(ad_link, COLOR_MAT_Y, 0, COLOR_MAT_W);
    addMatrixToField(COLOR_MAT_W, point_d, accumu_coeff, shortPEven, shortPOdd, 1-oddBit);
  }
#endif
  return;
}
/********************************do_longlink_kernel*********************************
 *
 * Long-link (Naik) contribution along direction sig.
 *
 *   READ
 *     4 LINKS:         ab_link, bc_link, de_link, ef_link   (all in direction sig)
 *     3 COLOR MATRIX:  naikOprod at A, B and C
 *   WRITE
 *     1 COLOR MATRIX:  output at C (accumulated with weight coeff)
 *
 * The kernel only does work when GOES_FORWARDS(sig) holds; for any other sig
 * the thread returns without touching memory.
 *
 * Indexing: one thread per site, sid = blockIdx.x*blockDim.x + threadIdx.x,
 * with the parity of the site fixed by the oddBit template argument.
 * NOTE(review): no bounds check on sid -- confirm the launch configuration.
 * NOTE(review): COMPUTE_LINK_SIGN / RECONSTRUCT_SITE_LINK / HISQ_LOAD_LINK are
 * macros defined elsewhere; the *_link_sign variables exist only for RECON == 12.
 *
 ***********************************************************************************/
template<class RealA, class RealB, int oddBit>
  __global__ void
  HISQ_KERNEL_NAME(do_longlink, EXT)(const RealB* const linkEven, const RealB* const linkOdd,
      const RealA* const naikOprodEven, const RealA* const naikOprodOdd,
      int sig, typename RealTypeId<RealA>::Type coeff,
      RealA* const outputEven, RealA* const outputOdd)
{
#ifdef KERNEL_ENABLED
  int sid = blockIdx.x * blockDim.x + threadIdx.x;

  // Decompose the checkerboard index sid into 4-d site coordinates x[0..3].
  int x[4];
  int z1 = sid/X1h;
  int x1h = sid - z1*X1h;
  int z2 = z1/X2;
  x[1] = z1 - z2*X2;
  x[3] = z2/X3;
  x[2] = z2 - x[3]*X3;
  int x1odd = (x[1] + x[2] + x[3] + oddBit) & 1;
  x[0] = 2*x1h + x1odd;

  int new_x[4];
  new_x[0] = x[0];
  new_x[1] = x[1];
  new_x[2] = x[2];
  new_x[3] = x[3];


  RealA ab_link[ArrayLength<RealA>::result];
  RealA bc_link[ArrayLength<RealA>::result];
  RealA de_link[ArrayLength<RealA>::result];
  RealA ef_link[ArrayLength<RealA>::result];

#if(RECON == 12)
  int ab_link_sign =1;
  int bc_link_sign =1;
  int de_link_sign =1;
  int ef_link_sign =1;
#endif

  RealA COLOR_MAT_U[ArrayLength<RealA>::result];
  RealA COLOR_MAT_V[ArrayLength<RealA>::result];
  RealA COLOR_MAT_W[ArrayLength<RealA>::result]; // used as a temporary
  RealA COLOR_MAT_X[ArrayLength<RealA>::result];
  RealA COLOR_MAT_Y[ArrayLength<RealA>::result];
  RealA COLOR_MAT_Z[ArrayLength<RealA>::result];


  const int & point_c = sid;
  int point_a, point_b, point_d, point_e;
  // need to work these indices
  int X[4];   // lattice extents, used for the periodic wrap-around below
  X[0] = X1;
  X[1] = X2;
  X[2] = X3;
  X[3] = X4;

  /*
   *
   *    A   B    C    D    E
   *    ---- ---- ---- ----
   *
   *   ---> sig direction
   *
   *   C is the current point (sid)
   *
   */

  // compute the force for forward long links
  if(GOES_FORWARDS(sig))
  {
    // D = C + 1 hop, E = C + 2 hops (periodic in direction sig)
    new_x[sig] = (x[sig] + 1 + X[sig])%X[sig];
    point_d = (new_x[3]*X3X2X1+new_x[2]*X2X1+new_x[1]*X1+new_x[0]) >> 1;
    COMPUTE_LINK_SIGN(&de_link_sign, sig, new_x);

    new_x[sig] = (new_x[sig] + 1 + X[sig])%X[sig];
    point_e = (new_x[3]*X3X2X1+new_x[2]*X2X1+new_x[1]*X1+new_x[0]) >> 1;
    COMPUTE_LINK_SIGN(&ef_link_sign, sig, new_x);

    // B = C - 1 hop, A = C - 2 hops (periodic in direction sig)
    new_x[sig] = (x[sig] - 1 + X[sig])%X[sig];
    point_b = (new_x[3]*X3X2X1+new_x[2]*X2X1+new_x[1]*X1+new_x[0]) >> 1;
    COMPUTE_LINK_SIGN(&bc_link_sign, sig, new_x);

    new_x[sig] = (new_x[sig] - 1 + X[sig])%X[sig];
    point_a = (new_x[3]*X3X2X1+new_x[2]*X2X1+new_x[1]*X1+new_x[0]) >> 1;
    COMPUTE_LINK_SIGN(&ab_link_sign, sig, new_x);

    // Sites an even number of hops from C share its parity; odd hops flip it.
    HISQ_LOAD_LINK(linkEven, linkOdd, sig, point_a, ab_link, oddBit);
    HISQ_LOAD_LINK(linkEven, linkOdd, sig, point_b, bc_link, 1-oddBit);
    HISQ_LOAD_LINK(linkEven, linkOdd, sig, point_d, de_link, 1-oddBit);
    HISQ_LOAD_LINK(linkEven, linkOdd, sig, point_e, ef_link, oddBit);

    RECONSTRUCT_SITE_LINK(ab_link, ab_link_sign);
    RECONSTRUCT_SITE_LINK(bc_link, bc_link_sign);
    RECONSTRUCT_SITE_LINK(de_link, de_link_sign);
    RECONSTRUCT_SITE_LINK(ef_link, ef_link_sign);

    loadMatrixFromField(naikOprodEven, naikOprodOdd, sig, point_c, COLOR_MAT_Z, oddBit);
    loadMatrixFromField(naikOprodEven, naikOprodOdd, sig, point_b, COLOR_MAT_Y, 1-oddBit);
    loadMatrixFromField(naikOprodEven, naikOprodOdd, sig, point_a, COLOR_MAT_X, oddBit);

    MAT_MUL_MAT(ef_link, COLOR_MAT_Z, COLOR_MAT_W); // link(d)*link(e)*Naik(c)
    MAT_MUL_MAT(de_link, COLOR_MAT_W, COLOR_MAT_V);

    MAT_MUL_MAT(de_link, COLOR_MAT_Y, COLOR_MAT_W);  // link(d)*Naik(b)*link(b)
    MAT_MUL_MAT(COLOR_MAT_W, bc_link, COLOR_MAT_U);
    SCALAR_MULT_ADD_MATRIX(COLOR_MAT_V, COLOR_MAT_U, -1, COLOR_MAT_V);

    MAT_MUL_MAT(COLOR_MAT_X, ab_link, COLOR_MAT_W);  // Naik(a)*link(a)*link(b)
    MAT_MUL_MAT(COLOR_MAT_W, bc_link, COLOR_MAT_U);
    SCALAR_MULT_ADD_MATRIX(COLOR_MAT_V, COLOR_MAT_U, 1, COLOR_MAT_V);

    addMatrixToField(COLOR_MAT_V, sig, sid, coeff, outputEven, outputOdd, oddBit);
  }
#endif
  return;
}


/******************************do_complete_force_kernel*****************************
 *
 * For the site sid and direction sig: loads the link and the outer product
 * stored at that site, forms link * oprod, and hands the product to
 * storeMatrixToMomentumField with a parity-dependent coefficient
 * (+1 for oddBit == 0, -1 for oddBit == 1).
 *
 *   READ
 *     1 LINK:          link at (sid, sig)
 *     1 COLOR MATRIX:  oprod at (sid, sig)
 *   WRITE
 *     force at (sid, sig), via storeMatrixToMomentumField
 *
 * Indexing: one thread per site, sid = blockIdx.x*blockDim.x + threadIdx.x.
 * NOTE(review): no bounds check on sid -- confirm the launch configuration.
 *
 ***********************************************************************************/
template<class RealA, class RealB, int oddBit>
  __global__ void
  HISQ_KERNEL_NAME(do_complete_force, EXT)(const RealB* const linkEven, const RealB* const linkOdd,
      const RealA* const oprodEven, const RealA* const oprodOdd,
      int sig,
      RealA* const forceEven, RealA* const forceOdd)
{
#ifdef KERNEL_ENABLED
  int sid = blockIdx.x * blockDim.x + threadIdx.x;

  // Site coordinates are only needed as input to COMPUTE_LINK_SIGN below.
  int x[4];
  int z1 = sid/X1h;
  int x1h = sid - z1*X1h;
  int z2 = z1/X2;
  x[1] = z1 - z2*X2;
  x[3] = z2/X3;
  x[2] = z2 - x[3]*X3;
  int x1odd = (x[1] + x[2] + x[3] + oddBit) & 1;
  x[0] = 2*x1h + x1odd;

#if(RECON == 12)
  int link_sign;
#endif

  RealA LINK_W[ArrayLength<RealA>::result];
  RealA COLOR_MAT_W[ArrayLength<RealA>::result];
  RealA COLOR_MAT_X[ArrayLength<RealA>::result];


  HISQ_LOAD_LINK(linkEven, linkOdd, sig, sid, LINK_W, oddBit);
  COMPUTE_LINK_SIGN(&link_sign, sig, x);
  RECONSTRUCT_SITE_LINK(LINK_W, link_sign);

  loadMatrixFromField(oprodEven, oprodOdd, sig, sid, COLOR_MAT_X, oddBit);

  typename RealTypeId<RealA>::Type coeff = (oddBit==1) ? -1 : 1;
  MAT_MUL_MAT(LINK_W, COLOR_MAT_X, COLOR_MAT_W);

  storeMatrixToMomentumField(COLOR_MAT_W, sig, sid, coeff, forceEven, forceOdd, oddBit);
#endif
  return;
}

// This file is written to be included once per PRECISION/RECON combination,
// so the per-inclusion macros are released here.
#undef EXT
#undef KERNEL_ENABLED