QUDA v0.4.0
A library for QCD on GPUs
quda/lib/hisq_paths_force_quda.cu
Go to the documentation of this file.
00001 #include <read_gauge.h>
00002 #include <gauge_field.h>
00003 
00004 #include <hisq_force_quda.h>
00005 #include <hw_quda.h>
00006 #include <hisq_force_macros.h>
00007 #include<utility>
00008 
00009 
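// DEBUG: compile-control switches for the double/single-precision and
// 18/12-reconstruct variants (presumably checked in hisq_paths_force_core.h)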
00011 #define COMPILE_HISQ_DP_18 
00012 #define COMPILE_HISQ_DP_12 
00013 #define COMPILE_HISQ_SP_18 
00014 #define COMPILE_HISQ_SP_12
00015 
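// Texture reads are currently ENABLED for both the site links and newOprod.
// To disable them, set HISQ_SITE_MATRIX_LOAD_TEX to 0 and remove the
// HISQ_NEW_OPROD_LOAD_TEX definition (it is tested with #ifdef, so any value
// enables it). Need to revisit this.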
00017 #define HISQ_SITE_MATRIX_LOAD_TEX 1
00018 #define HISQ_NEW_OPROD_LOAD_TEX 1
00019 
00020 namespace hisq {
00021   namespace fermion_force {
00022 
00023 
00024 
    // Texture references for the newOprod field, one per parity
    // (0 = even, 1 = odd); they are selected through NEWOPROD_EVEN_TEX /
    // NEWOPROD_ODD_TEX further down this file.
    // Double precision uses int4 texels, read through READ_DOUBLE2_TEXTURE;
    // single precision uses float2 texels, read with tex1Dfetch.
    // They are only read when HISQ_NEW_OPROD_LOAD_TEX is defined.
    // NOTE(review): binding these to device memory is not done in this chunk.
    // NOTE(review): texture references are a legacy API (removed in CUDA 12);
    //               newer toolkits would need texture objects instead.
    texture<int4, 1> newOprod0TexDouble;
    texture<int4, 1> newOprod1TexDouble;
    texture<float2, 1, cudaReadModeElementType>  newOprod0TexSingle;
    texture<float2, 1, cudaReadModeElementType> newOprod1TexSingle;
00029     
00030     void hisqForceInitCuda(QudaGaugeParam* param)
00031     {
00032       static int hisq_force_init_cuda_flag = 0; 
00033       
00034         if (hisq_force_init_cuda_flag){
00035           return;
00036         }
00037         hisq_force_init_cuda_flag=1;
00038         init_kernel_cuda(param);    
00039     }
00040     
00041 
00042 
00043 
00044 
    // struct for holding the fattening path coefficients.
    // Real is the scalar type of the chosen precision (presumably float or
    // double). Member order matters if callers aggregate-initialise this.
    template<class Real>
      struct PathCoefficients
      {
        Real one;    // presumably the single-link path weight
        Real three;  // presumably the 3-link staple weight
        Real five;   // presumably the 5-link staple weight
        Real seven;  // presumably the 7-link staple weight
        Real naik;   // Naik-term weight (presumably the 3-hop straight path)
        Real lepage; // Lepage-term weight
      };
00056 
00057 
00058     inline __device__ float2 operator*(float a, const float2 & b)
00059     {
00060       return make_float2(a*b.x,a*b.y);
00061     }
00062 
00063     inline __device__ double2 operator*(double a, const double2 & b)
00064     {
00065       return make_double2(a*b.x,a*b.y);
00066     }
00067 
00068     inline __device__ const float2 & operator+=(float2 & a, const float2 & b)
00069     {
00070       a.x += b.x;
00071       a.y += b.y;
00072       return a;
00073     }
00074 
00075     inline __device__ const double2 & operator+=(double2 & a, const double2 & b)
00076     {
00077       a.x += b.x;
00078       a.y += b.y;
00079       return a;
00080     }
00081 
00082     inline __device__ const float4 & operator+=(float4 & a, const float4 & b)
00083     {
00084       a.x += b.x;
00085       a.y += b.y;
00086       a.z += b.z;
00087       a.w += b.w;
00088       return a;
00089     }
00090 
00091     // Replication of code 
00092     // This structure is already defined in 
00093     // unitarize_utilities.h
00094 
00095     template<class T>
00096       struct RealTypeId; 
00097 
00098     template<>
00099       struct RealTypeId<float2>
00100       {
00101         typedef float Type;
00102       };
00103 
00104     template<>
00105       struct RealTypeId<double2>
00106       {
00107         typedef double Type;
00108       };
00109 
00110 
00111     template<class T>
00112       inline __device__
00113       void adjointMatrix(T* mat)
00114       {
00115 #define CONJ_INDEX(i,j) j*3 + i
00116 
00117         T tmp;
00118         mat[CONJ_INDEX(0,0)] = conj(mat[0]);
00119         mat[CONJ_INDEX(1,1)] = conj(mat[4]);
00120         mat[CONJ_INDEX(2,2)] = conj(mat[8]);
00121         tmp  = conj(mat[1]);
00122         mat[CONJ_INDEX(1,0)] = conj(mat[3]);
00123         mat[CONJ_INDEX(0,1)] = tmp;     
00124         tmp = conj(mat[2]);
00125         mat[CONJ_INDEX(2,0)] = conj(mat[6]);
00126         mat[CONJ_INDEX(0,2)] = tmp;
00127         tmp = conj(mat[5]);
00128         mat[CONJ_INDEX(2,1)] = conj(mat[7]);
00129         mat[CONJ_INDEX(1,2)] = tmp;
00130 
00131 #undef CONJ_INDEX
00132         return;
00133       }
00134 
00135 
00136     template<int N, class T>
00137       inline __device__
00138       void loadMatrixFromField(const T* const field_even, const T* const field_odd,
00139                                int dir, int idx, T* const mat, int oddness)
00140     {
00141       const T* const field = (oddness)?field_odd:field_even;
00142       for(int i = 0;i < N ;i++){
00143           mat[i] = field[idx + dir*N*Vh + i*Vh];          
00144       }
00145       return;
00146     }
00147 
00148     template<class T>
00149       inline __device__
00150       void loadMatrixFromField(const T* const field_even, const T* const field_odd,
00151                                int dir, int idx, T* const mat, int oddness)
00152       {
00153         loadMatrixFromField<9> (field_even, field_odd, dir, idx, mat, oddness);
00154         return;
00155       }
00156     
00157     
00158 
00159     inline __device__
00160       void loadMatrixFromField(const float4* const field_even, const float4* const field_odd, 
00161                                int dir, int idx, float2* const mat, int oddness)
00162     {
00163       const float4* const field = oddness?field_odd: field_even;
00164       float4 tmp;
00165       tmp = field[idx + dir*Vhx3];
00166       mat[0] = make_float2(tmp.x, tmp.y);
00167       mat[1] = make_float2(tmp.z, tmp.w);
00168       tmp = field[idx + dir*Vhx3 + Vh];
00169       mat[2] = make_float2(tmp.x, tmp.y);
00170       mat[3] = make_float2(tmp.z, tmp.w);
00171       tmp = field[idx + dir*Vhx3 + 2*Vh];
00172       mat[4] = make_float2(tmp.x, tmp.y);
00173       mat[5] = make_float2(tmp.z, tmp.w);
00174       return;
00175     }
00176 
00177     template<class T>
00178       inline __device__
00179       void loadMatrixFromField(const T* const field_even, const T* const field_odd, int idx, T* const mat, int oddness)
00180       {
00181         const T* const field = (oddness)?field_odd:field_even;
00182         mat[0] = field[idx];
00183         mat[1] = field[idx + Vh];
00184         mat[2] = field[idx + Vhx2];
00185         mat[3] = field[idx + Vhx3];
00186         mat[4] = field[idx + Vhx4];
00187         mat[5] = field[idx + Vhx5];
00188         mat[6] = field[idx + Vhx6];
00189         mat[7] = field[idx + Vhx7];
00190         mat[8] = field[idx + Vhx8];
00191 
00192         return;
00193       }
00194     
00195 
// addMatrixToNewOprod: read-modify-write update
//     newOprod[dir](idx) += coeff * mat        (9 entries)
// on the parity-split field chosen by `oddness` (field_odd if non-zero,
// otherwise field_even).
// Current values are fetched with LOAD_TEX_ENTRY from NEWOPROD_ODD_TEX or
// NEWOPROD_EVEN_TEX. That is a texture fetch or a plain load, depending on
// HISQ_NEW_OPROD_LOAD_TEX. The sums are written back with ordinary global
// stores.
// The expansion site must provide RealA, Vh, Vhx2..Vhx9, and operator+ /
// operator* for RealA.
// Every macro argument is parenthesised so that expression arguments (e.g. a
// computed direction, or a coeff given as a sum/product) expand with the
// intended precedence. The base offset is computed once.
#define  addMatrixToNewOprod(mat,  dir, idx, coeff, field_even, field_odd, oddness)     do { \
      RealA* const field = (oddness)?(field_odd):(field_even);          \
      const int newOprodBase_ = (idx) + (dir)*Vhx9;                     \
      RealA value[9];                                                   \
      value[0] = LOAD_TEX_ENTRY( ((oddness)?NEWOPROD_ODD_TEX:NEWOPROD_EVEN_TEX), field, newOprodBase_); \
      value[1] = LOAD_TEX_ENTRY( ((oddness)?NEWOPROD_ODD_TEX:NEWOPROD_EVEN_TEX), field, newOprodBase_ + Vh); \
      value[2] = LOAD_TEX_ENTRY( ((oddness)?NEWOPROD_ODD_TEX:NEWOPROD_EVEN_TEX), field, newOprodBase_ + 2*Vh); \
      value[3] = LOAD_TEX_ENTRY( ((oddness)?NEWOPROD_ODD_TEX:NEWOPROD_EVEN_TEX), field, newOprodBase_ + 3*Vh); \
      value[4] = LOAD_TEX_ENTRY( ((oddness)?NEWOPROD_ODD_TEX:NEWOPROD_EVEN_TEX), field, newOprodBase_ + 4*Vh); \
      value[5] = LOAD_TEX_ENTRY( ((oddness)?NEWOPROD_ODD_TEX:NEWOPROD_EVEN_TEX), field, newOprodBase_ + 5*Vh); \
      value[6] = LOAD_TEX_ENTRY( ((oddness)?NEWOPROD_ODD_TEX:NEWOPROD_EVEN_TEX), field, newOprodBase_ + 6*Vh); \
      value[7] = LOAD_TEX_ENTRY( ((oddness)?NEWOPROD_ODD_TEX:NEWOPROD_EVEN_TEX), field, newOprodBase_ + 7*Vh); \
      value[8] = LOAD_TEX_ENTRY( ((oddness)?NEWOPROD_ODD_TEX:NEWOPROD_EVEN_TEX), field, newOprodBase_ + 8*Vh); \
      field[newOprodBase_]          = value[0] + (coeff)*(mat)[0];      \
      field[newOprodBase_ + Vh]     = value[1] + (coeff)*(mat)[1];      \
      field[newOprodBase_ + Vhx2]   = value[2] + (coeff)*(mat)[2];      \
      field[newOprodBase_ + Vhx3]   = value[3] + (coeff)*(mat)[3];      \
      field[newOprodBase_ + Vhx4]   = value[4] + (coeff)*(mat)[4];      \
      field[newOprodBase_ + Vhx5]   = value[5] + (coeff)*(mat)[5];      \
      field[newOprodBase_ + Vhx6]   = value[6] + (coeff)*(mat)[6];      \
      field[newOprodBase_ + Vhx7]   = value[7] + (coeff)*(mat)[7];      \
      field[newOprodBase_ + Vhx8]   = value[8] + (coeff)*(mat)[8];      \
  }while(0)
00218      
00219 
00220 
00221     // only works if Promote<T,U>::Type = T
00222 
00223     template<class T, class U>   
00224     inline __device__
00225       void addMatrixToField(const T* const mat, int dir, int idx, U coeff, 
00226                              T* const field_even, T* const field_odd, int oddness)
00227       {
00228         T* const field = (oddness)?field_odd: field_even;
00229         field[idx + dir*Vhx9]          += coeff*mat[0];
00230         field[idx + dir*Vhx9 + Vh]     += coeff*mat[1];
00231         field[idx + dir*Vhx9 + Vhx2]   += coeff*mat[2];
00232         field[idx + dir*Vhx9 + Vhx3]   += coeff*mat[3];
00233         field[idx + dir*Vhx9 + Vhx4]   += coeff*mat[4];
00234         field[idx + dir*Vhx9 + Vhx5]   += coeff*mat[5];
00235         field[idx + dir*Vhx9 + Vhx6]   += coeff*mat[6];
00236         field[idx + dir*Vhx9 + Vhx7]   += coeff*mat[7];
00237         field[idx + dir*Vhx9 + Vhx8]   += coeff*mat[8];
00238 
00239         return;
00240       }
00241 
00242 
00243     template<class T, class U>
00244     inline __device__
00245       void addMatrixToField(const T* const mat, int idx, U coeff, T* const field_even,
00246                              T* const field_odd, int oddness)
00247       {
00248         T* const field = (oddness)?field_odd: field_even;
00249         field[idx ]         += coeff*mat[0];
00250         field[idx + Vh]     += coeff*mat[1];
00251         field[idx + Vhx2]   += coeff*mat[2];
00252         field[idx + Vhx3]   += coeff*mat[3];
00253         field[idx + Vhx4]   += coeff*mat[4];
00254         field[idx + Vhx5]   += coeff*mat[5];
00255         field[idx + Vhx6]   += coeff*mat[6];
00256         field[idx + Vhx7]   += coeff*mat[7];
00257         field[idx + Vhx8]   += coeff*mat[8];
00258 
00259         return;
00260       }
00261 
00262 
00263    template<class T>
00264     inline __device__
00265      void storeMatrixToField(const T* const mat, int dir, int idx, T* const field_even, T* const field_odd, int oddness)
00266       {
00267         T* const field = (oddness)?field_odd: field_even;
00268         field[idx + dir*Vhx9]          = mat[0];
00269         field[idx + dir*Vhx9 + Vh]     = mat[1];
00270         field[idx + dir*Vhx9 + Vhx2]   = mat[2];
00271         field[idx + dir*Vhx9 + Vhx3]   = mat[3];
00272         field[idx + dir*Vhx9 + Vhx4]   = mat[4];
00273         field[idx + dir*Vhx9 + Vhx5]   = mat[5];
00274         field[idx + dir*Vhx9 + Vhx6]   = mat[6];
00275         field[idx + dir*Vhx9 + Vhx7]   = mat[7];
00276         field[idx + dir*Vhx9 + Vhx8]   = mat[8];
00277 
00278         return;
00279       }
00280 
00281 
00282     template<class T>
00283     inline __device__
00284       void storeMatrixToField(const T* const mat, int idx, T* const field_even, T* const field_odd, int oddness)
00285       {
00286         T* const field = (oddness)?field_odd: field_even;
00287         field[idx]          = mat[0];
00288         field[idx + Vh]     = mat[1];
00289         field[idx + Vhx2]   = mat[2];
00290         field[idx + Vhx3]   = mat[3];
00291         field[idx + Vhx4]   = mat[4];
00292         field[idx + Vhx5]   = mat[5];
00293         field[idx + Vhx6]   = mat[6];
00294         field[idx + Vhx7]   = mat[7];
00295         field[idx + Vhx8]   = mat[8];
00296 
00297         return;
00298       }
00299 
00300 
00301      template<class T, class U> 
00302      inline __device__
00303        void storeMatrixToMomentumField(const T* const mat, int dir, int idx, U coeff, 
00304                                         T* const mom_even, T* const mom_odd, int oddness)
00305         {
00306           T* const mom_field = (oddness)?mom_odd:mom_even;
00307           T temp2;
00308           temp2.x = (mat[1].x - mat[3].x)*0.5*coeff;
00309           temp2.y = (mat[1].y + mat[3].y)*0.5*coeff;
00310           mom_field[idx + dir*Vhx5] = temp2;    
00311 
00312           temp2.x = (mat[2].x - mat[6].x)*0.5*coeff;
00313           temp2.y = (mat[2].y + mat[6].y)*0.5*coeff;
00314           mom_field[idx + dir*Vhx5 + Vh] = temp2;
00315 
00316           temp2.x = (mat[5].x - mat[7].x)*0.5*coeff;
00317           temp2.y = (mat[5].y + mat[7].y)*0.5*coeff;
00318           mom_field[idx + dir*Vhx5 + Vhx2] = temp2;
00319 
00320           const typename RealTypeId<T>::Type temp = (mat[0].y + mat[4].y + mat[8].y)*0.3333333333333333333333333;
00321           temp2.x =  (mat[0].y-temp)*coeff; 
00322           temp2.y =  (mat[4].y-temp)*coeff;
00323           mom_field[idx + dir*Vhx5 + Vhx3] = temp2;
00324                   
00325           temp2.x = (mat[8].y - temp)*coeff;
00326           temp2.y = 0.0;
00327           mom_field[idx + dir*Vhx5 + Vhx4] = temp2;
00328  
00329           return;
00330         }
00331 
00332     // Struct to determine the coefficient sign at compile time
00333     template<int pos_dir, int odd_lattice>
00334       struct CoeffSign
00335       {
00336         static const int result = -1;
00337       };
00338 
00339     template<>
00340       struct CoeffSign<0,1>
00341       {
00342         static const int result = -1;
00343       }; 
00344 
00345     template<>
00346       struct CoeffSign<0,0>
00347       {
00348         static const int result = 1;
00349       };
00350 
00351     template<>
00352       struct CoeffSign<1,1>
00353       {
00354         static const int result = 1;
00355       };
00356 
00357     template<int odd_lattice>
00358         struct Sign
00359         {
00360           static const int result = 1;
00361         };
00362 
00363     template<>
00364         struct Sign<1>
00365         {
00366           static const int result = -1;
00367         };
00368 
    // ArrayLength<RealX>::result: number of RealX elements used for a local
    // matrix buffer (e.g. COLOR_MAT_W in do_one_link_term_kernel).
    // Default 9: one element per entry of a 3x3 complex matrix when RealX is
    // a two-component type such as float2/double2.
    template<class RealX>
      struct ArrayLength
      {
        static const int result=9;
      };

    // float4 holds two complex numbers per element. 5 float4 (20 reals) is
    // presumably enough to hold the 18 reals of a full 3x3 matrix -- confirm intent.
    template<>
      struct ArrayLength<float4>
      {
        static const int result=5;
      };
00380  
00381 
00382 
00383      
00384 
00385     // reconstructSign doesn't do anything right now, 
00386     // but it will, soon.
00387     template<typename T>
00388       __device__ void reconstructSign(int* const sign, int dir, const T i[4]){
00389 
00390  
00391       *sign=1;
00392       
00393       switch(dir){
00394       case XUP:
00395         if( (i[3]&1)==1) *sign=-1;
00396         break;    
00397 
00398       case YUP:
00399         if( ((i[3]+i[0])&1) == 1) *sign=-1; 
00400         break;
00401         
00402       case ZUP:
00403         if( ((i[3]+i[0]+i[1])&1) == 1) *sign=-1; 
00404         break;
00405         
00406       case TUP:
00407         if(i[3] == X4m1) *sign=-1; 
00408         break;
00409         
00410       default:
00411         printf("Error: invalid dir\n");
00412         break;
00413       }
00414       
00415       return;
00416     }
00417 
00418 
00419 
00420 
00421 
00422 
// One-link force term for a single parity (oddBit, compile-time 0/1):
//     output[sig](sid) += coeff * oprod[sig](sid)
// Only forward directions (GOES_FORWARDS(sig)) contribute.
// Expects a 1-D launch with one thread per half-lattice site. Threads with
// sid >= Vh exit early, because an out-of-range sid would index into the
// neighbouring planes (the stride is Vh) and corrupt them.
template<class RealA, int oddBit>
  __global__ void 
  do_one_link_term_kernel(const RealA* const oprodEven, const RealA* const oprodOdd,
                          int sig, typename RealTypeId<RealA>::Type coeff,
                          RealA* const outputEven, RealA* const outputOdd)
{
  const int sid = blockIdx.x * blockDim.x + threadIdx.x;
  if (sid >= Vh) return;          // grid-tail guard
  if (!GOES_FORWARDS(sig)) return;

  RealA COLOR_MAT_W[ArrayLength<RealA>::result];
  loadMatrixFromField(oprodEven, oprodOdd, sig, sid, COLOR_MAT_W, oddBit);
  addMatrixToField(COLOR_MAT_W, sig, sid, coeff, outputEven, outputOdd, oddBit);
}
00438 
00439 
// HISQ_KERNEL_NAME(a,b) pastes a ## b ## kernel (e.g. do_middle_link_ + sp_18_
// -> do_middle_link_sp_18_kernel); presumably used inside hisq_paths_force_core.h.
#define DD_CONCAT(n,r) n ## r ## kernel

#define HISQ_KERNEL_NAME(a,b) DD_CONCAT(a,b)
//precision: 0 is for double, 1 is for single

// ---- Double precision: newOprod is read through the int4 textures
// ---- (READ_DOUBLE2_TEXTURE) when HISQ_NEW_OPROD_LOAD_TEX is defined.
#define NEWOPROD_EVEN_TEX newOprod0TexDouble
#define NEWOPROD_ODD_TEX newOprod1TexDouble
#ifdef HISQ_NEW_OPROD_LOAD_TEX
#define LOAD_TEX_ENTRY(tex, field, idx)  READ_DOUBLE2_TEXTURE(tex, field, idx)
#else
#define LOAD_TEX_ENTRY(tex, field, idx) field[idx]
#endif

// hisq_paths_force_core.h is included once per (precision, recon) pair. Before
// each inclusion the following hooks are defined, and they are #undef'd after:
//   PRECISION, RECON         - instantiation selectors
//   HISQ_LOAD_LINK           - how a site link is fetched (texture or global)
//   COMPUTE_LINK_SIGN        - no-op for 18, reconstructSign() for 12
//   RECONSTRUCT_SITE_LINK    - no-op for 18, FF_RECONSTRUCT_LINK_12 for 12
//double precision, recon=18
#define PRECISION 0
#define RECON 18
#if (HISQ_SITE_MATRIX_LOAD_TEX == 1)
#define HISQ_LOAD_LINK(linkEven, linkOdd, dir, idx, var, oddness)   HISQ_LOAD_MATRIX_18_DOUBLE_TEX((oddness)?siteLink1TexDouble:siteLink0TexDouble,  (oddness)?linkOdd:linkEven, dir, idx, var, Vh)
#else
#define HISQ_LOAD_LINK(linkEven, linkOdd, dir, idx, var, oddness)   loadMatrixFromField(linkEven, linkOdd, dir, idx, var, oddness)
#endif
#define COMPUTE_LINK_SIGN(sign, dir, x)
#define RECONSTRUCT_SITE_LINK(var, sign)
#include "hisq_paths_force_core.h"
#undef PRECISION
#undef RECON
#undef HISQ_LOAD_LINK
#undef COMPUTE_LINK_SIGN
#undef RECONSTRUCT_SITE_LINK

//double precision, recon=12
// (non-texture path loads 6 double2 per link via loadMatrixFromField<6>)
#define PRECISION 0
#define RECON 12
#if (HISQ_SITE_MATRIX_LOAD_TEX == 1)
#define HISQ_LOAD_LINK(linkEven, linkOdd, dir, idx, var, oddness)   HISQ_LOAD_MATRIX_12_DOUBLE_TEX((oddness)?siteLink1TexDouble:siteLink0TexDouble,  (oddness)?linkOdd:linkEven,dir, idx, var, Vh)
#else
#define HISQ_LOAD_LINK(linkEven, linkOdd, dir, idx, var, oddness)   loadMatrixFromField<6>(linkEven, linkOdd, dir, idx, var, oddness)
#endif
#define COMPUTE_LINK_SIGN(sign, dir, x) reconstructSign(sign, dir, x)
#define RECONSTRUCT_SITE_LINK(var, sign)  FF_RECONSTRUCT_LINK_12(var, sign)
#include "hisq_paths_force_core.h"
#undef PRECISION
#undef RECON
#undef HISQ_LOAD_LINK
#undef COMPUTE_LINK_SIGN
#undef RECONSTRUCT_SITE_LINK
// Drop the double-precision newOprod texture hooks before the single-precision pass.
#undef NEWOPROD_EVEN_TEX
#undef NEWOPROD_ODD_TEX
#undef LOAD_TEX_ENTRY
00489 
00490 
// ---- Single precision: newOprod is read from the float2 textures with
// ---- tex1Dfetch when HISQ_NEW_OPROD_LOAD_TEX is defined.
#define NEWOPROD_EVEN_TEX newOprod0TexSingle
#define NEWOPROD_ODD_TEX newOprod1TexSingle

#ifdef HISQ_NEW_OPROD_LOAD_TEX
#define LOAD_TEX_ENTRY(tex, field, idx)  tex1Dfetch(tex,idx)
#else
#define LOAD_TEX_ENTRY(tex, field, idx) field[idx]
#endif

//single precision, recon=18
#define PRECISION 1
#define RECON 18
#if (HISQ_SITE_MATRIX_LOAD_TEX == 1)
#define HISQ_LOAD_LINK(linkEven, linkOdd, dir, idx, var, oddness)   HISQ_LOAD_MATRIX_18_SINGLE_TEX((oddness)?siteLink1TexSingle:siteLink0TexSingle, dir, idx, var, Vh)
#else
#define HISQ_LOAD_LINK(linkEven, linkOdd, dir, idx, var, oddness)   loadMatrixFromField(linkEven, linkOdd, dir, idx, var, oddness)
#endif
#define COMPUTE_LINK_SIGN(sign, dir, x)
#define RECONSTRUCT_SITE_LINK(var, sign)
#include "hisq_paths_force_core.h"
#undef PRECISION
#undef RECON
#undef HISQ_LOAD_LINK
#undef COMPUTE_LINK_SIGN
#undef RECONSTRUCT_SITE_LINK

//single precision, recon=12
// (links are float4; the non-texture path uses the float4 -> float2 overload
//  of loadMatrixFromField, and the sign/reconstruction hooks are active)
#define PRECISION 1
#define RECON 12
#if (HISQ_SITE_MATRIX_LOAD_TEX == 1)
#define HISQ_LOAD_LINK(linkEven, linkOdd, dir, idx, var, oddness)   HISQ_LOAD_MATRIX_12_SINGLE_TEX((oddness)?siteLink1TexSingle_recon:siteLink0TexSingle_recon, dir, idx, var, Vh)
#else
#define HISQ_LOAD_LINK(linkEven, linkOdd, dir, idx, var, oddness)   loadMatrixFromField(linkEven, linkOdd, dir, idx, var, oddness)
#endif
#define COMPUTE_LINK_SIGN(sign, dir, x) reconstructSign(sign, dir, x)
#define RECONSTRUCT_SITE_LINK(var, sign)  FF_RECONSTRUCT_LINK_12(var, sign)
#include "hisq_paths_force_core.h"
#undef PRECISION
#undef RECON
#undef HISQ_LOAD_LINK
#undef COMPUTE_LINK_SIGN
#undef RECONSTRUCT_SITE_LINK
#undef NEWOPROD_EVEN_TEX
#undef NEWOPROD_ODD_TEX
#undef LOAD_TEX_ENTRY
00536 
    // Host-side launcher for the middle-link step of the HISQ force.
    // It selects a kernel instantiation from:
    //   - precision: sizeof(RealA) == sizeof(float2) -> "sp", otherwise "dp";
    //   - link reconstruction: QUDA_RECONSTRUCT_NO -> 18, anything else -> 12;
    //   - GOES_FORWARDS/GOES_BACKWARDS of sig and mu -> sig_sign, mu_sign (1/0).
    // It then launches that kernel twice, parity 0 then parity 1, each on
    // gridDim.x/2 blocks of BlockDim threads, in the default stream.
    // The do_middle_link_*_kernel templates are presumably generated by the
    // hisq_paths_force_core.h inclusions earlier in this file.
    // Pmu*, P3* and Qmu* are marked write-only. newOprod* is presumably
    // accumulated into -- confirm in the core header.
    // NOTE(review): gridDim.x/2 truncates, so an odd gridDim.x drops a block.
    // NOTE(review): there is no cudaGetLastError() after the launches, so
    //               launch errors only surface at a later CUDA call.
    template<class RealA, class RealB>
      static void
      middle_link_kernel(
          const RealA* const oprodEven, const RealA* const oprodOdd, 
          const RealA* const QprevEven, const RealA* const QprevOdd,
          const RealB* const linkEven,  const RealB* const linkOdd, 
          const cudaGaugeField &link, int sig, int mu, 
          typename RealTypeId<RealA>::Type coeff,
          dim3 gridDim, dim3 BlockDim,
          RealA* const PmuEven,  RealA* const PmuOdd, // write only
          RealA* const P3Even,   RealA* const P3Odd,  // write only
          RealA* const QmuEven,  RealA* const QmuOdd,   // write only
          RealA* const newOprodEven,  RealA* const newOprodOdd)
      {
        QudaReconstructType recon = link.Reconstruct();
        // One parity per launch: half of the caller's grid.
        dim3 halfGridDim(gridDim.x/2, 1,1);
        
        // Launch configuration plus the argument list, with every pointer
        // cast to the concrete element types typeA (fields) / typeB (links).
#define CALL_ARGUMENTS(typeA, typeB) <<<halfGridDim, BlockDim>>>((typeA*)oprodEven, (typeA*)oprodOdd, \
                                                                 (typeA*)QprevEven, (typeA*)QprevOdd, \
                                                                 (typeB*)linkEven, (typeB*)linkOdd, \
                                                                 sig, mu, (typename RealTypeId<typeA>::Type)coeff, \
                                                                 (typeA*)PmuEven, (typeA*)PmuOdd, \
                                                                 (typeA*)P3Even, (typeA*)P3Odd, \
                                                                 (typeA*)QmuEven, (typeA*)QmuOdd, \
                                                                 (typeA*)newOprodEven, (typeA*)newOprodOdd)
        
        // Precision x reconstruction dispatch. Both branches are compiled for
        // every RealA, which is why the pointer casts above are needed. Note
        // that the double 12-reconstruct path uses double2 links.
#define CALL_MIDDLE_LINK_KERNEL(sig_sign, mu_sign)                      \
        if(sizeof(RealA) == sizeof(float2)){                            \
          if(recon  == QUDA_RECONSTRUCT_NO){                            \
            do_middle_link_sp_18_kernel<float2, float2, sig_sign, mu_sign, 0> CALL_ARGUMENTS(float2, float2); \
            do_middle_link_sp_18_kernel<float2, float2, sig_sign, mu_sign, 1> CALL_ARGUMENTS(float2, float2); \
          }else{                                                        \
            do_middle_link_sp_12_kernel<float2, float4, sig_sign, mu_sign, 0> CALL_ARGUMENTS(float2, float4); \
            do_middle_link_sp_12_kernel<float2, float4, sig_sign, mu_sign, 1> CALL_ARGUMENTS(float2, float4); \
          }                                                             \
        }else{                                                          \
          if(recon  == QUDA_RECONSTRUCT_NO){                            \
            do_middle_link_dp_18_kernel<double2, double2, sig_sign, mu_sign, 0> CALL_ARGUMENTS(double2, double2); \
            do_middle_link_dp_18_kernel<double2, double2, sig_sign, mu_sign, 1> CALL_ARGUMENTS(double2, double2); \
          }else{                                                        \
            do_middle_link_dp_12_kernel<double2, double2, sig_sign, mu_sign, 0> CALL_ARGUMENTS(double2, double2); \
            do_middle_link_dp_12_kernel<double2, double2, sig_sign, mu_sign, 1> CALL_ARGUMENTS(double2, double2); \
          }                                                             \
        }
        
        // Direction signs must be compile-time template arguments, hence the
        // explicit four-way branch.
        if (GOES_FORWARDS(sig) && GOES_FORWARDS(mu)){   
          CALL_MIDDLE_LINK_KERNEL(1,1);
        }else if (GOES_FORWARDS(sig) && GOES_BACKWARDS(mu)){
          CALL_MIDDLE_LINK_KERNEL(1,0);
        }else if (GOES_BACKWARDS(sig) && GOES_FORWARDS(mu)){
          CALL_MIDDLE_LINK_KERNEL(0,1);
        }else{
          CALL_MIDDLE_LINK_KERNEL(0,0);
        }
        
#undef CALL_ARGUMENTS   
#undef CALL_MIDDLE_LINK_KERNEL
        return;
      }
00596 
00597 
00598 
00599 
    // Host-side launcher for the side-link step of the HISQ force.
    // Dispatch works as in middle_link_kernel:
    //   - precision via sizeof(RealA);
    //   - 18 vs 12 reconstruction via link.Reconstruct() (anything other than
    //     QUDA_RECONSTRUCT_NO takes the 12 path);
    //   - sig/mu direction signs as template arguments.
    // The chosen kernel is launched for parity 0 then parity 1, each on
    // gridDim.x/2 blocks of blockDim threads, in the default stream.
    // coeff and accumu_coeff are cast to the kernel's scalar type.
    // shortP* and newOprod* are the output fields handed to the kernels.
    // NOTE(review): gridDim.x/2 truncates, and there is no error check after
    //               the launches.
    template<class RealA, class RealB>
      static void
      side_link_kernel(
          const RealA* const P3Even, const RealA* const P3Odd, 
          const RealA* const oprodEven, const RealA* const oprodOdd,
          const RealB* const linkEven,  const RealB* const linkOdd, 
          const cudaGaugeField &link, int sig, int mu, 
          typename RealTypeId<RealA>::Type coeff, 
          typename RealTypeId<RealA>::Type accumu_coeff,
          dim3 gridDim, dim3 blockDim,
          RealA* shortPEven,  RealA* shortPOdd,
          RealA* newOprodEven, RealA* newOprodOdd)
    {
      QudaReconstructType recon =link.Reconstruct();
      
      // One parity per launch: half of the caller's grid.
      dim3 halfGridDim(gridDim.x/2,1,1);

      // Launch configuration plus the argument list, cast to concrete types.
#define CALL_ARGUMENTS(typeA, typeB)    <<<halfGridDim, blockDim>>>((typeA*)P3Even, (typeA*)P3Odd, \
                                                                    (typeA*)oprodEven,  (typeA*)oprodOdd, \
                                                                    (typeB*)linkEven, (typeB*)linkOdd, \
                                                                    sig, mu, \
                                                                    (typename RealTypeId<typeA>::Type) coeff, \
                                                                    (typename RealTypeId<typeA>::Type) accumu_coeff, \
                                                                    (typeA*)shortPEven, (typeA*)shortPOdd, \
                                                                    (typeA*)newOprodEven, (typeA*)newOprodOdd)
      
      // Precision x reconstruction dispatch (both branches compiled for every RealA).
#define CALL_SIDE_LINK_KERNEL(sig_sign, mu_sign)                        \
      if(sizeof(RealA) == sizeof(float2)){                              \
        if(recon  == QUDA_RECONSTRUCT_NO){                              \
          do_side_link_sp_18_kernel<float2, float2, sig_sign, mu_sign, 0> CALL_ARGUMENTS(float2, float2); \
          do_side_link_sp_18_kernel<float2, float2, sig_sign, mu_sign, 1> CALL_ARGUMENTS(float2, float2); \
        }else{                                                          \
          do_side_link_sp_12_kernel<float2, float4, sig_sign, mu_sign, 0> CALL_ARGUMENTS(float2, float4); \
          do_side_link_sp_12_kernel<float2, float4, sig_sign, mu_sign, 1> CALL_ARGUMENTS(float2, float4); \
        }                                                               \
      }else{                                                            \
        if(recon  == QUDA_RECONSTRUCT_NO){                              \
          do_side_link_dp_18_kernel<double2, double2, sig_sign, mu_sign, 0> CALL_ARGUMENTS(double2, double2); \
          do_side_link_dp_18_kernel<double2, double2, sig_sign, mu_sign, 1> CALL_ARGUMENTS(double2, double2); \
        }else{                                                          \
          do_side_link_dp_12_kernel<double2, double2, sig_sign, mu_sign, 0> CALL_ARGUMENTS(double2, double2); \
          do_side_link_dp_12_kernel<double2, double2, sig_sign, mu_sign, 1> CALL_ARGUMENTS(double2, double2); \
        }                                                               \
      }
      
      // Direction signs are compile-time template arguments.
      if (GOES_FORWARDS(sig) && GOES_FORWARDS(mu)){
        CALL_SIDE_LINK_KERNEL(1,1);
      }else if (GOES_FORWARDS(sig) && GOES_BACKWARDS(mu)){
        CALL_SIDE_LINK_KERNEL(1,0);
        
      }else if (GOES_BACKWARDS(sig) && GOES_FORWARDS(mu)){
        CALL_SIDE_LINK_KERNEL(0,1);
      }else{
        CALL_SIDE_LINK_KERNEL(0,0);
      }
      
#undef CALL_SIDE_LINK_KERNEL
#undef CALL_ARGUMENTS      
      return;
    }
00660 
00661    
00662 
    // Host-side launcher for the all-link step of the HISQ force.
    // Dispatch works as in middle_link_kernel:
    //   - precision via sizeof(RealA);
    //   - 18 vs 12 reconstruction via link.Reconstruct() (anything other than
    //     QUDA_RECONSTRUCT_NO takes the 12 path);
    //   - sig/mu direction signs as template arguments.
    // The chosen kernel is launched for parity 0 then parity 1, each on
    // gridDim.x/2 blocks of blockDim threads, in the default stream.
    // shortP* and newOprod* are the output fields handed to the kernels.
    // NOTE(review): gridDim.x/2 truncates, and there is no error check after
    //               the launches.
    template<class RealA, class RealB>
      static void
      all_link_kernel(
          const RealA* const oprodEven, const RealA* const oprodOdd,
          const RealA* const QprevEven, const RealA* const QprevOdd, 
          const RealB* const linkEven,  const RealB* const linkOdd, 
          const cudaGaugeField &link, int sig, int mu,
          typename RealTypeId<RealA>::Type coeff, 
          typename RealTypeId<RealA>::Type  accumu_coeff,
          dim3 gridDim, dim3 blockDim,
          RealA* const shortPEven, RealA* const shortPOdd,
          RealA* const newOprodEven, RealA* const newOprodOdd)
    {
      QudaReconstructType recon = link.Reconstruct();
      // One parity per launch: half of the caller's grid.
      dim3 halfGridDim(gridDim.x/2, 1,1);
      
      // Launch configuration plus the argument list, cast to concrete types.
#define CALL_ARGUMENTS(typeA, typeB) <<<halfGridDim, blockDim>>>((typeA*)oprodEven, (typeA*)oprodOdd, \
                                                                 (typeA*)QprevEven, (typeA*)QprevOdd, \
                                                                 (typeB*)linkEven, (typeB*)linkOdd, \
                                                                 sig,  mu, \
                                                                 (typename RealTypeId<typeA>::Type)coeff, \
                                                                 (typename RealTypeId<typeA>::Type)accumu_coeff, \
                                                                 (typeA*)shortPEven,(typeA*)shortPOdd, \
                                                                 (typeA*)newOprodEven, (typeA*)newOprodOdd)

      // Precision x reconstruction dispatch (both branches compiled for every RealA).
#define CALL_ALL_LINK_KERNEL(sig_sign, mu_sign)                         \
      if(sizeof(RealA) == sizeof(float2)){                              \
        if(recon  == QUDA_RECONSTRUCT_NO){                              \
          do_all_link_sp_18_kernel<float2, float2, sig_sign, mu_sign, 0> CALL_ARGUMENTS(float2, float2); \
          do_all_link_sp_18_kernel<float2, float2, sig_sign, mu_sign, 1> CALL_ARGUMENTS(float2, float2); \
        }else{                                                          \
          do_all_link_sp_12_kernel<float2, float4, sig_sign, mu_sign, 0> CALL_ARGUMENTS(float2, float4); \
          do_all_link_sp_12_kernel<float2, float4, sig_sign, mu_sign, 1> CALL_ARGUMENTS(float2, float4); \
        }                                                               \
      }else{                                                            \
        if(recon  == QUDA_RECONSTRUCT_NO){                              \
          do_all_link_dp_18_kernel<double2, double2, sig_sign, mu_sign, 0> CALL_ARGUMENTS(double2, double2); \
          do_all_link_dp_18_kernel<double2, double2, sig_sign, mu_sign, 1> CALL_ARGUMENTS(double2, double2); \
        }else{                                                          \
          do_all_link_dp_12_kernel<double2, double2, sig_sign, mu_sign, 0> CALL_ARGUMENTS(double2, double2); \
          do_all_link_dp_12_kernel<double2, double2, sig_sign, mu_sign, 1> CALL_ARGUMENTS(double2, double2); \
        }                                                               \
      }
      
      // Direction signs are compile-time template arguments.
      if (GOES_FORWARDS(sig) && GOES_FORWARDS(mu)){
        CALL_ALL_LINK_KERNEL(1, 1);
      }else if (GOES_FORWARDS(sig) && GOES_BACKWARDS(mu)){
        CALL_ALL_LINK_KERNEL(1, 0);
      }else if (GOES_BACKWARDS(sig) && GOES_FORWARDS(mu)){
        CALL_ALL_LINK_KERNEL(0, 1);
      }else{
        CALL_ALL_LINK_KERNEL(0, 0);
      }
      
#undef CALL_ARGUMENTS
#undef CALL_ALL_LINK_KERNEL         
      
      return;
    }
00722     
00723 
/**
 * Launch the one-link contribution to the HISQ force for direction `sig`.
 *
 * Only forward directions contribute; for a backward `sig` nothing is
 * launched. Two kernels are issued (template flag 0 and 1), each on half of
 * `gridDim.x` -- presumably one per lattice parity; confirm against
 * do_one_link_term_kernel.
 *
 * @param oprodEven,oprodOdd              outer-product field (read)
 * @param sig                             signed direction index
 * @param coeff                           one-link path coefficient
 * @param naik_coeff                      accepted for interface symmetry; unused
 * @param gridDim,blockDim                full-lattice launch configuration
 * @param ForceMatrixEven,ForceMatrixOdd  output force field
 */
template<class RealA>
static void
one_link_term(
    const RealA* const oprodEven,
    const RealA* const oprodOdd,
    int sig,
    typename RealTypeId<RealA>::Type coeff,
    typename RealTypeId<RealA>::Type naik_coeff,
    dim3 gridDim, dim3 blockDim,
    RealA* const ForceMatrixEven,
    RealA* const ForceMatrixOdd)
{
  (void)naik_coeff;  // not used by the one-link term

  // Backward directions have no one-link contribution.
  if (!(GOES_FORWARDS(sig))) return;

  // Each of the two kernels covers half of the grid.
  const dim3 parityGrid(gridDim.x / 2, 1, 1);

  do_one_link_term_kernel<RealA, 0><<<parityGrid, blockDim>>>(
      oprodEven, oprodOdd, sig, coeff, ForceMatrixEven, ForceMatrixOdd);
  do_one_link_term_kernel<RealA, 1><<<parityGrid, blockDim>>>(
      oprodEven, oprodOdd, sig, coeff, ForceMatrixEven, ForceMatrixOdd);
}
00752 
/**
 * Launch the Naik (long-link) force kernels for direction `sig`.
 *
 * Only forward directions are valid; a backward `sig` is reported through
 * errorQuda. Dispatch is by precision (sizeof(RealA)) and by link
 * reconstruction:
 *   - single, 18-number links  -> do_longlink_sp_18_kernel<float2, float2>
 *   - single, compressed links -> do_longlink_sp_12_kernel<float2, float4>
 *   - double, 18-number links  -> do_longlink_dp_18_kernel<double2, double2>
 *   - double, compressed links -> do_longlink_dp_12_kernel<double2, double2>
 * Each variant is launched twice (template flag 0 and 1) on half of
 * gridDim.x -- presumably once per parity; confirm in the kernel.
 *
 * @param linkEven,linkOdd            gauge links, reinterpreted per variant
 * @param naikOprodEven,naikOprodOdd  outer-product input (read)
 * @param sig                         direction; must go forwards
 * @param naik_coeff                  Naik path coefficient
 * @param gridDim,blockDim            full-lattice launch configuration
 * @param link                        gauge field; only Reconstruct() is queried
 * @param outputEven,outputOdd        output field
 */
template<class RealA,class RealB>
void longlink_terms(const RealB* const linkEven, const RealB* const linkOdd,
                    const RealA* const naikOprodEven, const RealA* const naikOprodOdd,
                    int sig, typename RealTypeId<RealA>::Type naik_coeff,
                    dim3 gridDim, dim3 blockDim, const cudaGaugeField& link,
                    RealA* const outputEven, RealA* const outputOdd)
{
  const QudaReconstructType recon = link.Reconstruct();

  if (GOES_BACKWARDS(sig)) {
    errorQuda("sig does not go forward\n");
  }

  const bool singlePrec   = (sizeof(RealA) == sizeof(float2));
  const bool uncompressed = (recon == QUDA_RECONSTRUCT_NO);
  const dim3 parityGrid(gridDim.x / 2, 1, 1);

  // Launch configuration + argument list shared by every variant.
#define LONGLINK_LAUNCH_ARGS(typeA, typeB)                                 \
  <<<parityGrid, blockDim>>>((typeB*)linkEven, (typeB*)linkOdd,            \
                             (typeA*)naikOprodEven, (typeA*)naikOprodOdd,  \
                             sig, naik_coeff,                              \
                             (typeA*)outputEven, (typeA*)outputOdd)

  if (singlePrec && uncompressed) {
    do_longlink_sp_18_kernel<float2, float2, 0> LONGLINK_LAUNCH_ARGS(float2, float2);
    do_longlink_sp_18_kernel<float2, float2, 1> LONGLINK_LAUNCH_ARGS(float2, float2);
  } else if (singlePrec) {
    do_longlink_sp_12_kernel<float2, float4, 0> LONGLINK_LAUNCH_ARGS(float2, float4);
    do_longlink_sp_12_kernel<float2, float4, 1> LONGLINK_LAUNCH_ARGS(float2, float4);
  } else if (uncompressed) {
    do_longlink_dp_18_kernel<double2, double2, 0> LONGLINK_LAUNCH_ARGS(double2, double2);
    do_longlink_dp_18_kernel<double2, double2, 1> LONGLINK_LAUNCH_ARGS(double2, double2);
  } else {
    do_longlink_dp_12_kernel<double2, double2, 0> LONGLINK_LAUNCH_ARGS(double2, double2);
    do_longlink_dp_12_kernel<double2, double2, 1> LONGLINK_LAUNCH_ARGS(double2, double2);
  }

#undef LONGLINK_LAUNCH_ARGS
}
00794 
00795 
00796 
00797           
/**
 * Launch the force-completion kernels for direction `sig`.
 *
 * Dispatch is by precision (sizeof(RealA)) and by link reconstruction:
 *   - single, 18-number links  -> do_complete_force_sp_18_kernel<float2, float2>
 *   - single, compressed links -> do_complete_force_sp_12_kernel<float2, float4>
 *   - double, 18-number links  -> do_complete_force_dp_18_kernel<double2, double2>
 *   - double, compressed links -> do_complete_force_dp_12_kernel<double2, double2>
 * Each variant is launched twice (template flag 0 and 1) on half of
 * gridDim.x -- presumably once per parity; confirm in the kernel.
 *
 * @param oprodEven,oprodOdd  outer-product field (read)
 * @param linkEven,linkOdd    gauge links, reinterpreted per variant
 * @param link                gauge field; only Reconstruct() is queried
 * @param sig                 direction index
 * @param gridDim,blockDim    full-lattice launch configuration
 * @param momEven,momOdd      output field
 */
template<class RealA, class RealB>
static void
complete_force_kernel(const RealA* const oprodEven,
                      const RealA* const oprodOdd,
                      const RealB* const linkEven,
                      const RealB* const linkOdd,
                      const cudaGaugeField &link,
                      int sig, dim3 gridDim, dim3 blockDim,
                      RealA* const momEven,
                      RealA* const momOdd)
{
  const dim3 parityGrid(gridDim.x / 2, 1, 1);
  const QudaReconstructType recon = link.Reconstruct();
  const bool singlePrec   = (sizeof(RealA) == sizeof(float2));
  const bool uncompressed = (recon == QUDA_RECONSTRUCT_NO);

  // Launch configuration + argument list shared by every variant.
#define COMPLETE_FORCE_LAUNCH_ARGS(typeA, typeB)                        \
  <<<parityGrid, blockDim>>>((typeB*)linkEven, (typeB*)linkOdd,         \
                             (typeA*)oprodEven, (typeA*)oprodOdd,       \
                             sig,                                       \
                             (typeA*)momEven, (typeA*)momOdd)

  if (singlePrec && uncompressed) {
    do_complete_force_sp_18_kernel<float2, float2, 0> COMPLETE_FORCE_LAUNCH_ARGS(float2, float2);
    do_complete_force_sp_18_kernel<float2, float2, 1> COMPLETE_FORCE_LAUNCH_ARGS(float2, float2);
  } else if (singlePrec) {
    do_complete_force_sp_12_kernel<float2, float4, 0> COMPLETE_FORCE_LAUNCH_ARGS(float2, float4);
    do_complete_force_sp_12_kernel<float2, float4, 1> COMPLETE_FORCE_LAUNCH_ARGS(float2, float4);
  } else if (uncompressed) {
    do_complete_force_dp_18_kernel<double2, double2, 0> COMPLETE_FORCE_LAUNCH_ARGS(double2, double2);
    do_complete_force_dp_18_kernel<double2, double2, 1> COMPLETE_FORCE_LAUNCH_ARGS(double2, double2);
  } else {
    do_complete_force_dp_12_kernel<double2, double2, 0> COMPLETE_FORCE_LAUNCH_ARGS(double2, double2);
    do_complete_force_dp_12_kernel<double2, double2, 1> COMPLETE_FORCE_LAUNCH_ARGS(double2, double2);
  }

#undef COMPLETE_FORCE_LAUNCH_ARGS
}
00839 
00840 
00841     
/**
 * Bind the gauge-link field and the new outer-product field to the
 * file-scope textures read by the HISQ force kernels.
 *
 * The even half (Even_p()) and the odd half (Odd_p()) of each field are bound
 * separately, each with Bytes()/2 bytes. The texture set is chosen from
 * link.Precision(). Below double precision, link.Reconstruct() also selects
 * between the plain and *_recon link textures.
 * NOTE(review): newOprod's textures are selected from link's precision, so
 * both fields are assumed to share it -- not checked here.
 * cudaBindTexture return codes are not checked.
 * Callers in this file pair it with unbind_tex_link() on the same fields.
 */
static void
bind_tex_link(const cudaGaugeField& link, const cudaGaugeField& newOprod)
{
  if (link.Precision() != QUDA_DOUBLE_PRECISION) {
    // Single-precision textures; compressed links use the *_recon variants.
    if (link.Reconstruct() != QUDA_RECONSTRUCT_NO) {
      cudaBindTexture(0, siteLink0TexSingle_recon, link.Even_p(), link.Bytes()/2);
      cudaBindTexture(0, siteLink1TexSingle_recon, link.Odd_p(),  link.Bytes()/2);
    } else {
      cudaBindTexture(0, siteLink0TexSingle, link.Even_p(), link.Bytes()/2);
      cudaBindTexture(0, siteLink1TexSingle, link.Odd_p(),  link.Bytes()/2);
    }
    cudaBindTexture(0, newOprod0TexSingle, newOprod.Even_p(), newOprod.Bytes()/2);
    cudaBindTexture(0, newOprod1TexSingle, newOprod.Odd_p(),  newOprod.Bytes()/2);
    return;
  }

  // Double precision: a single link texture pair regardless of reconstruction.
  cudaBindTexture(0, siteLink0TexDouble, link.Even_p(), link.Bytes()/2);
  cudaBindTexture(0, siteLink1TexDouble, link.Odd_p(),  link.Bytes()/2);
  cudaBindTexture(0, newOprod0TexDouble, newOprod.Even_p(), newOprod.Bytes()/2);
  cudaBindTexture(0, newOprod1TexDouble, newOprod.Odd_p(),  newOprod.Bytes()/2);
}
00864 
/**
 * Unbind the textures that bind_tex_link() bound for `link` and `newOprod`.
 *
 * The texture set is selected exactly as in bind_tex_link(): from
 * link.Precision(), and below double precision also from link.Reconstruct().
 * `newOprod` itself is not inspected; it is accepted for symmetry with
 * bind_tex_link(). cudaUnbindTexture return codes are not checked.
 */
static void
unbind_tex_link(const cudaGaugeField& link, const cudaGaugeField& newOprod)
{
  if (link.Precision() != QUDA_DOUBLE_PRECISION) {
    if (link.Reconstruct() != QUDA_RECONSTRUCT_NO) {
      cudaUnbindTexture(siteLink0TexSingle_recon);
      cudaUnbindTexture(siteLink1TexSingle_recon);
    } else {
      cudaUnbindTexture(siteLink0TexSingle);
      cudaUnbindTexture(siteLink1TexSingle);
    }
    cudaUnbindTexture(newOprod0TexSingle);
    cudaUnbindTexture(newOprod1TexSingle);
    return;
  }

  cudaUnbindTexture(siteLink0TexDouble);
  cudaUnbindTexture(siteLink1TexDouble);
  cudaUnbindTexture(newOprod0TexDouble);
  cudaUnbindTexture(newOprod1TexDouble);
}
00885 
00886 
00887 
/**
 * Accumulate the HISQ staple (3-, 5-, 7-link, Lepage) and one-link force
 * contributions into `newOprod`.
 *
 * For every pair of signed directions (sig, mu), with mu neither equal nor
 * opposite to sig, this launches in order:
 *   - 3-link middle link (coefficient -three), writing Pmu, P3 and Qmu;
 *   - for each nu not parallel to sig or mu:
 *       - 5-link middle link (five), writing Pnumu, P5 and Qnumu;
 *       - for each rho not parallel to sig, mu or nu: the 7-link all-link
 *         term (seven, ratio seven/five);
 *       - 5-link side link (-five, ratio five/three), writing P3;
 *   - if lepage != 0: the Lepage middle link (lepage, writing P5) and the
 *     Lepage side link (-lepage, ratio lepage/three, writing P3);
 *   - 3-link side link (three, ratio 0).
 * Finally the one-link term is launched for each forward sig.
 * A ratio whose denominator is zero is replaced by zero.
 *
 * Scratch fields, whose contents are overwritten:
 *   tempmat[0..3]  -> Pmu, P3, P5, Pnumu
 *   tempCmat[0..1] -> Qmu, Qnumu
 *
 * @tparam Real   coefficient scalar type (float or double)
 * @tparam RealA  element type used for oprod / scratch / newOprod pointers
 * @tparam RealB  element type used for link pointers
 * @param act_path_coeff  path coefficients; `naik` is not used here
 * @param param           only param.X (lattice dimensions) is read
 * @param oprod           input outer product (read only)
 * @param link            gauge links (read only)
 * @param newOprod        output field written by every launched kernel
 */
template<class Real, class RealA, class RealB>
static void
do_hisq_staples_force_cuda(PathCoefficients<Real> act_path_coeff,
                           const QudaGaugeParam& param,
                           const cudaGaugeField &oprod,
                           const cudaGaugeField &link,
                           FullMatrix tempmat[4],
                           FullMatrix tempCmat[2],
                           cudaGaugeField &newOprod)
{
  // Scratch-field aliases. These were previously file-level macros
  // (#define P3 tempmat[1], ...), which leaked very short identifiers into
  // the preprocessor namespace; references keep the names local.
  FullMatrix& Pmu   = tempmat[0];
  FullMatrix& P3    = tempmat[1];
  FullMatrix& P5    = tempmat[2];
  FullMatrix& Pnumu = tempmat[3];
  FullMatrix& Qmu   = tempCmat[0];
  FullMatrix& Qnumu = tempCmat[1];

  const Real OneLink  = act_path_coeff.one;
  const Real ThreeSt  = act_path_coeff.three;
  const Real FiveSt   = act_path_coeff.five;
  const Real SevenSt  = act_path_coeff.seven;
  const Real Lepage   = act_path_coeff.lepage;
  const Real mThreeSt = -ThreeSt;
  const Real mFiveSt  = -FiveSt;
  const Real mLepage  = -Lepage;

  // Coefficient ratios consumed by the fused kernels. They are loop
  // invariant (previously recomputed inside the innermost loops); a zero
  // denominator yields a zero ratio, as before.
  const Real sevenOverFive   = (FiveSt  != 0) ? SevenSt / FiveSt  : Real(0);
  const Real fiveOverThree   = (ThreeSt != 0) ? FiveSt  / ThreeSt : Real(0);
  const Real lepageOverThree = (ThreeSt != 0) ? Lepage  / ThreeSt : Real(0);
  const Real noRatio         = 0;

  const int volume = param.X[0]*param.X[1]*param.X[2]*param.X[3];
  dim3 blockDim(BLOCK_DIM, 1, 1);
  dim3 gridDim(volume/blockDim.x, 1, 1);

  // Field pointers reused by every launch below.
  RealA* const oprodEven    = (RealA*)oprod.Even_p();
  RealA* const oprodOdd     = (RealA*)oprod.Odd_p();
  RealB* const linkEven     = (RealB*)link.Even_p();
  RealB* const linkOdd      = (RealB*)link.Odd_p();
  RealA* const newOprodEven = (RealA*)newOprod.Even_p();
  RealA* const newOprodOdd  = (RealA*)newOprod.Odd_p();

  for (int sig = 0; sig < 8; sig++) {
    for (int mu = 0; mu < 8; mu++) {
      if (mu == sig || mu == OPP_DIR(sig)) continue;

      // 3-link: middle link (kernel A).
      middle_link_kernel(
          oprodEven, oprodOdd,                              // read only
          (RealA*)NULL, (RealA*)NULL,                       // not supplied for the 3-link term
          linkEven, linkOdd,                                // read only
          link,
          sig, mu, mThreeSt,
          gridDim, blockDim,
          (RealA*)Pmu.even.data, (RealA*)Pmu.odd.data,      // write only
          (RealA*)P3.even.data,  (RealA*)P3.odd.data,       // write only
          (RealA*)Qmu.even.data, (RealA*)Qmu.odd.data,      // write only
          newOprodEven, newOprodOdd);
      checkCudaError();

      for (int nu = 0; nu < 8; nu++) {
        if (nu == sig || nu == OPP_DIR(sig) ||
            nu == mu  || nu == OPP_DIR(mu)) continue;

        // 5-link: middle link (kernel B).
        middle_link_kernel(
            (RealA*)Pmu.even.data, (RealA*)Pmu.odd.data,        // read only
            (RealA*)Qmu.even.data, (RealA*)Qmu.odd.data,        // read only
            linkEven, linkOdd,
            link,
            sig, nu, FiveSt,
            gridDim, blockDim,
            (RealA*)Pnumu.even.data, (RealA*)Pnumu.odd.data,    // write only
            (RealA*)P5.even.data,    (RealA*)P5.odd.data,       // write only
            (RealA*)Qnumu.even.data, (RealA*)Qnumu.odd.data,    // write only
            newOprodEven, newOprodOdd);
        checkCudaError();

        for (int rho = 0; rho < 8; rho++) {
          if (rho == sig || rho == OPP_DIR(sig) ||
              rho == mu  || rho == OPP_DIR(mu)  ||
              rho == nu  || rho == OPP_DIR(nu)) continue;

          // 7-link: middle link and side link in one kernel.
          all_link_kernel(
              (RealA*)Pnumu.even.data, (RealA*)Pnumu.odd.data,
              (RealA*)Qnumu.even.data, (RealA*)Qnumu.odd.data,
              linkEven, linkOdd,
              link,
              sig, rho, SevenSt, sevenOverFive,
              gridDim, blockDim,
              (RealA*)P5.even.data, (RealA*)P5.odd.data,
              newOprodEven, newOprodOdd);
          checkCudaError();
        } // rho

        // 5-link: side link.
        side_link_kernel(
            (RealA*)P5.even.data,  (RealA*)P5.odd.data,     // read only
            (RealA*)Qmu.even.data, (RealA*)Qmu.odd.data,    // read only
            linkEven, linkOdd,
            link,
            sig, nu, mFiveSt, fiveOverThree,
            gridDim, blockDim,
            (RealA*)P3.even.data,  (RealA*)P3.odd.data,     // write
            newOprodEven, newOprodOdd);
        checkCudaError();
      } // nu

      // Lepage term.
      if (Lepage != 0) {
        middle_link_kernel(
            (RealA*)Pmu.even.data, (RealA*)Pmu.odd.data,    // read only
            (RealA*)Qmu.even.data, (RealA*)Qmu.odd.data,    // read only
            linkEven, linkOdd,
            link,
            sig, mu, Lepage,
            gridDim, blockDim,
            (RealA*)NULL, (RealA*)NULL,                     // not written
            (RealA*)P5.even.data, (RealA*)P5.odd.data,      // write only
            (RealA*)NULL, (RealA*)NULL,                     // not written
            newOprodEven, newOprodOdd);
        checkCudaError();  // this launch was previously left unchecked

        side_link_kernel(
            (RealA*)P5.even.data,  (RealA*)P5.odd.data,     // read only
            (RealA*)Qmu.even.data, (RealA*)Qmu.odd.data,    // read only
            linkEven, linkOdd,
            link,
            sig, mu, mLepage, lepageOverThree,
            gridDim, blockDim,
            (RealA*)P3.even.data, (RealA*)P3.odd.data,      // write only
            newOprodEven, newOprodOdd);
        checkCudaError();
      } // Lepage != 0

      // 3-link: side link.
      side_link_kernel(
          (RealA*)P3.even.data, (RealA*)P3.odd.data,        // read only
          (RealA*)NULL, (RealA*)NULL,                       // not supplied
          linkEven, linkOdd,
          link,
          sig, mu, ThreeSt, noRatio,
          gridDim, blockDim,
          (RealA*)NULL, (RealA*)NULL,                       // not written
          newOprodEven, newOprodOdd);
      checkCudaError();
    } // mu
  } // sig

  // One-link term, forward directions only.
  for (int sig = 0; sig < 8; ++sig) {
    if (GOES_FORWARDS(sig)) {
      one_link_term(oprodEven, oprodOdd,
                    sig, OneLink, Real(0),
                    gridDim, blockDim,
                    newOprodEven, newOprodOdd);
    }
    checkCudaError();
  }
} // do_hisq_staples_force_cuda
01079 
01080 
/**
 * Complete the HISQ fermion force for the four forward directions.
 *
 * For sig = 0..3 the force-completion kernels combine `oprod` with `link`
 * and write into `*force`. Textures for `link` and `oprod` are bound for the
 * duration of the call.
 *
 * @param param  param.X gives the lattice dimensions; param.cuda_prec
 *               selects double or single precision. Any other precision is
 *               reported via errorQuda before any texture is bound.
 * @param oprod  outer-product field (read)
 * @param link   gauge links (read)
 * @param force  output field
 */
void hisqCompleteForceCuda(const QudaGaugeParam &param,
                           const cudaGaugeField &oprod,
                           const cudaGaugeField &link,
                           cudaGaugeField* force)
{
  // Validate once, up front. Previously this was re-checked on every
  // direction, after the textures had already been bound.
  const bool isDouble = (param.cuda_prec == QUDA_DOUBLE_PRECISION);
  if (!isDouble && param.cuda_prec != QUDA_SINGLE_PRECISION) {
    errorQuda("Unsupported precision");
    return;
  }

  const int volume = param.X[0]*param.X[1]*param.X[2]*param.X[3];
  dim3 blockDim(BLOCK_DIM, 1, 1);
  dim3 gridDim(volume/blockDim.x, 1, 1);

  bind_tex_link(link, oprod);

  for (int sig = 0; sig < 4; sig++) {
    if (isDouble) {
      complete_force_kernel((double2*)oprod.Even_p(), (double2*)oprod.Odd_p(),
                            (double2*)link.Even_p(), (double2*)link.Odd_p(),
                            link,
                            sig, gridDim, blockDim,
                            (double2*)force->Even_p(), (double2*)force->Odd_p());
    } else {
      complete_force_kernel((float2*)oprod.Even_p(), (float2*)oprod.Odd_p(),
                            (float2*)link.Even_p(), (float2*)link.Odd_p(),
                            link,
                            sig, gridDim, blockDim,
                            (float2*)force->Even_p(), (float2*)force->Odd_p());
    }
    // Surface launch failures per direction, as the staple force path does.
    checkCudaError();
  } // loop over directions

  unbind_tex_link(link, oprod);
}
01113 
01114    
01115 
01116 
01117 
/**
 * Add the Naik (long-link) contribution to the HISQ force.
 *
 * For each forward direction sig = 0..3, launches longlink_terms reading
 * `oldOprod` and `link` and writing `*newOprod`. `coeff` is passed as double
 * in double precision and cast to float in single precision. Textures for
 * `link` and `*newOprod` are bound for the duration of the call.
 *
 * @param coeff     Naik path coefficient
 * @param param     param.X gives the lattice dimensions; param.cuda_prec
 *                  selects double or single precision. Any other precision
 *                  is reported via errorQuda before any texture is bound.
 * @param oldOprod  outer-product input (read)
 * @param link      gauge links (read)
 * @param newOprod  output field
 */
void hisqLongLinkForceCuda(double coeff,
                           const QudaGaugeParam &param,
                           const cudaGaugeField &oldOprod,
                           const cudaGaugeField &link,
                           cudaGaugeField  *newOprod)
{
  // Validate once, up front. Previously this was re-checked on every
  // direction, after the textures had already been bound.
  const bool isDouble = (param.cuda_prec == QUDA_DOUBLE_PRECISION);
  if (!isDouble && param.cuda_prec != QUDA_SINGLE_PRECISION) {
    errorQuda("Unsupported precision");
    return;
  }

  const int volume = param.X[0]*param.X[1]*param.X[2]*param.X[3];
  dim3 blockDim(BLOCK_DIM, 1, 1);
  dim3 gridDim(volume/blockDim.x, 1, 1);

  bind_tex_link(link, *newOprod);

  for (int sig = 0; sig < 4; ++sig) {
    if (isDouble) {
      longlink_terms((double2*)link.Even_p(), (double2*)link.Odd_p(),
                     (double2*)oldOprod.Even_p(), (double2*)oldOprod.Odd_p(),
                     sig, coeff,
                     gridDim, blockDim, link,
                     (double2*)newOprod->Even_p(), (double2*)newOprod->Odd_p());
    } else {
      longlink_terms((float2*)link.Even_p(), (float2*)link.Odd_p(),
                     (float2*)oldOprod.Even_p(), (float2*)oldOprod.Odd_p(),
                     sig, static_cast<float>(coeff),
                     gridDim, blockDim, link,
                     (float2*)newOprod->Even_p(), (float2*)newOprod->Odd_p());
    }
    // Surface launch failures per direction, as the staple force path does.
    checkCudaError();
  } // loop over directions

  unbind_tex_link(link, *newOprod);
}
01151 
01152 
01153 
01154 
01155 
/**
 * Unpack the flat HISQ path-coefficient array
 * [one, naik, three, five, seven, lepage] into a PathCoefficients<Real>.
 */
template<class Real>
static PathCoefficients<Real>
unpack_hisq_path_coefficients(const double path_coeff_array[6])
{
  PathCoefficients<Real> c;
  c.one    = path_coeff_array[0];
  c.naik   = path_coeff_array[1];
  c.three  = path_coeff_array[2];
  c.five   = path_coeff_array[3];
  c.seven  = path_coeff_array[4];
  c.lepage = path_coeff_array[5];
  return c;
}

/**
 * Host entry point for the HISQ staple force.
 *
 * Allocates six scratch matrices at param.cuda_prec, binds the link and
 * new-outer-product textures, and runs do_hisq_staples_force_cuda in double
 * or single precision. It then waits for completion, unbinds the textures
 * and frees the scratch matrices.
 *
 * @param path_coeff_array  [one, naik, three, five, seven, lepage]
 * @param param             lattice dimensions (X) and precision (cuda_prec).
 *                          Precisions other than double/single are reported
 *                          via errorQuda before anything is allocated.
 * @param oprod             input outer product (read)
 * @param link              gauge links (read)
 * @param newOprod          output field
 */
void
hisqStaplesForceCuda(const double path_coeff_array[6],
                     const QudaGaugeParam &param,
                     const cudaGaugeField &oprod,
                     const cudaGaugeField &link,
                     cudaGaugeField* newOprod)
{
  // Reject unsupported precisions before allocating or binding anything.
  const bool isDouble = (param.cuda_prec == QUDA_DOUBLE_PRECISION);
  if (!isDouble && param.cuda_prec != QUDA_SINGLE_PRECISION) {
    errorQuda("Unsupported precision");
    return;
  }

  FullMatrix tempmat[4];
  for (int i = 0; i < 4; i++) {
    tempmat[i] = createMatQuda(param.X, param.cuda_prec);
  }

  FullMatrix tempCompmat[2];
  for (int i = 0; i < 2; i++) {
    tempCompmat[i] = createMatQuda(param.X, param.cuda_prec);
  }

  bind_tex_link(link, *newOprod);

  cudaEvent_t start, end;
  cudaEventCreate(&start);
  cudaEventCreate(&end);

  cudaEventRecord(start);
  if (isDouble) {
    do_hisq_staples_force_cuda<double, double2, double2>(
        unpack_hisq_path_coefficients<double>(path_coeff_array),
        param, oprod, link, tempmat, tempCompmat, *newOprod);
  } else {
    do_hisq_staples_force_cuda<float, float2, float2>(
        unpack_hisq_path_coefficients<float>(path_coeff_array),
        param, oprod, link, tempmat, tempCompmat, *newOprod);
  }
  cudaEventRecord(end);

  // Blocks the host until the staple kernels (recorded before `end`) have
  // finished, so the textures and scratch fields below are no longer in use.
  cudaEventSynchronize(end);
  float runtime;
  cudaEventElapsedTime(&runtime, start, end);
  //printfQuda("hisq staple time=%.2f ms\n", runtime);

  // Release the timing events; previously they leaked on every call.
  cudaEventDestroy(start);
  cudaEventDestroy(end);

  unbind_tex_link(link, *newOprod);

  for (int i = 0; i < 4; i++) {
    freeMatQuda(tempmat[i]);
  }
  for (int i = 0; i < 2; i++) {
    freeMatQuda(tempCompmat[i]);
  }
}
01241 
01242   } // namespace fermion_force
01243 } // namespace hisq
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines