QUDA v0.4.0
A library for QCD on GPUs
quda/tests/staggered_dslash_reference.cpp
Go to the documentation of this file.
00001 #include <stdio.h>
00002 #include <stdlib.h>
00003 #include <math.h>
00004 #include <string.h>
00005 
00006 #include <test_util.h>
00007 #include <quda_internal.h>
00008 #include <quda.h>
00009 #include <util_quda.h>
00010 #include <staggered_dslash_reference.h>
00011 #include "misc.h"
00012 #include <blas_quda.h>
00013 
00014 #include <face_quda.h>
00015 
// Number of real components in one staggered spinor site:
// 1 spin x 3 colors x 2 (real, imaginary).
// Declared const so that per-site scratch arrays sized by it are ordinary
// fixed-size arrays instead of variable-length arrays (a non-standard GNU
// extension in C++). memset() comes from <string.h>, included above.
static const int mySpinorSiteSize = 6;
00019 
00020 #include <dslash_util.h>
00021 
//
// dslashReference() -- documented here; defined after display_link_internal() below.
//
// if oddBit is zero: calculate even parity spinor elements (using the odd parity spinor)
// if oddBit is one:  calculate odd parity spinor elements (using the even parity spinor)
//
// if daggerBit is zero: perform the ordinary dslash operator
// if daggerBit is one:  perform the hermitian conjugate of dslash
//
// Debug helper: print a 3x3 complex link matrix to stdout.
// The matrix is read as 18 consecutive reals, row by row, each element an
// interleaved (real, imaginary) pair. Each matrix row goes on its own output
// line, and a blank line follows the matrix.
template<typename Float>
void display_link_internal(Float* link)
{
  const int nColor = 3;

  for (int row = 0; row < nColor; ++row) {
    const Float *rowStart = link + row * nColor * 2;
    for (int col = 0; col < nColor; ++col) {
      const Float *elem = rowStart + 2 * col;   // elem[0] = re, elem[1] = im
      printf("(%10f,%10f) \t", elem[0], elem[1]);
    }
    printf("\n");
  }
  printf("\n");
}
00045 
00046 
00047 template <typename sFloat, typename gFloat>
00048 void dslashReference(sFloat *res, gFloat **fatlink, gFloat** longlink, sFloat *spinorField, 
00049                      int oddBit, int daggerBit) 
00050 {
00051   for (int i=0; i<Vh*1*3*2; i++) res[i] = 0.0;
00052   
00053   gFloat *fatlinkEven[4], *fatlinkOdd[4];
00054   gFloat *longlinkEven[4], *longlinkOdd[4];
00055   
00056   for (int dir = 0; dir < 4; dir++) {  
00057     fatlinkEven[dir] = fatlink[dir];
00058     fatlinkOdd[dir] = fatlink[dir] + Vh*gaugeSiteSize;
00059     longlinkEven[dir] =longlink[dir];
00060     longlinkOdd[dir] = longlink[dir] + Vh*gaugeSiteSize;    
00061   }
00062   
00063   for (int i = 0; i < Vh; i++) {
00064     memset(res + i*mySpinorSiteSize, 0, mySpinorSiteSize*sizeof(sFloat));
00065     for (int dir = 0; dir < 8; dir++) {
00066       gFloat* fatlnk = gaugeLink(i, dir, oddBit, fatlinkEven, fatlinkOdd, 1);
00067       gFloat* longlnk = gaugeLink(i, dir, oddBit, longlinkEven, longlinkOdd, 3);
00068       
00069       sFloat *first_neighbor_spinor = spinorNeighbor(i, dir, oddBit, spinorField, 1);
00070       sFloat *third_neighbor_spinor = spinorNeighbor(i, dir, oddBit, spinorField, 3);
00071       
00072       
00073       sFloat gaugedSpinor[mySpinorSiteSize];
00074       
00075       if (dir % 2 == 0){
00076         su3Mul(gaugedSpinor, fatlnk, first_neighbor_spinor);
00077         sum(&res[i*mySpinorSiteSize], &res[i*mySpinorSiteSize], gaugedSpinor, mySpinorSiteSize);            
00078         su3Mul(gaugedSpinor, longlnk, third_neighbor_spinor);
00079         sum(&res[i*mySpinorSiteSize], &res[i*mySpinorSiteSize], gaugedSpinor, mySpinorSiteSize);                
00080       } else {
00081         su3Tmul(gaugedSpinor, fatlnk, first_neighbor_spinor);
00082         sub(&res[i*mySpinorSiteSize], &res[i*mySpinorSiteSize], gaugedSpinor, mySpinorSiteSize);        
00083         
00084         su3Tmul(gaugedSpinor, longlnk, third_neighbor_spinor);
00085         sub(&res[i*mySpinorSiteSize], &res[i*mySpinorSiteSize], gaugedSpinor, mySpinorSiteSize);
00086       }             
00087     }
00088     if (daggerBit){
00089       negx(&res[i*mySpinorSiteSize], mySpinorSiteSize);
00090     }
00091   }
00092   
00093 }
00094 
00095 
00096 
00097 
00098 void staggered_dslash(void *res, void **fatlink, void** longlink, void *spinorField, int oddBit, int daggerBit,
00099                       QudaPrecision sPrecision, QudaPrecision gPrecision) {
00100     
00101   if (sPrecision == QUDA_DOUBLE_PRECISION) {
00102     if (gPrecision == QUDA_DOUBLE_PRECISION){
00103       dslashReference((double*)res, (double**)fatlink, (double**)longlink, (double*)spinorField, oddBit, daggerBit);
00104     }else{
00105       dslashReference((double*)res, (float**)fatlink, (float**)longlink, (double*)spinorField, oddBit, daggerBit);
00106     }
00107   }
00108   else{
00109     if (gPrecision == QUDA_DOUBLE_PRECISION){
00110       dslashReference((float*)res, (double**)fatlink, (double**)longlink, (float*)spinorField, oddBit, daggerBit);
00111     }else{
00112       dslashReference((float*)res, (float**)fatlink, (float**)longlink, (float*)spinorField, oddBit, daggerBit);
00113     }
00114   }
00115 }
00116 
00117 
00118 
00119 
00120 template <typename sFloat, typename gFloat>
00121 void Mat(sFloat *out, gFloat **fatlink, gFloat** longlink, sFloat *in, sFloat kappa, int daggerBit) 
00122 {
00123   sFloat *inEven = in;
00124   sFloat *inOdd  = in + Vh*mySpinorSiteSize;
00125   sFloat *outEven = out;
00126   sFloat *outOdd = out + Vh*mySpinorSiteSize;
00127     
00128   // full dslash operator
00129   dslashReference(outOdd, fatlink, longlink, inEven, 1, daggerBit);
00130   dslashReference(outEven, fatlink, longlink, inOdd, 0, daggerBit);
00131     
00132   // lastly apply the kappa term
00133   xpay(in, -kappa, out, V*mySpinorSiteSize);
00134 }
00135 
00136 
00137 void 
00138 mat(void *out, void **fatlink, void** longlink, void *in, double kappa, int dagger_bit,
00139     QudaPrecision sPrecision, QudaPrecision gPrecision) 
00140 {
00141     
00142   if (sPrecision == QUDA_DOUBLE_PRECISION){
00143     if (gPrecision == QUDA_DOUBLE_PRECISION) {
00144       Mat((double*)out, (double**)fatlink, (double**)longlink, (double*)in, (double)kappa, dagger_bit);
00145     }else {
00146       Mat((double*)out, (float**)fatlink, (float**)longlink, (double*)in, (double)kappa, dagger_bit);
00147     }
00148   }else{
00149     if (gPrecision == QUDA_DOUBLE_PRECISION){ 
00150       Mat((float*)out, (double**)fatlink, (double**)longlink, (float*)in, (float)kappa, dagger_bit);
00151     }else {
00152       Mat((float*)out, (float**)fatlink, (float**)longlink, (float*)in, (float)kappa, dagger_bit);
00153     }
00154   }
00155 }
00156 
00157 
00158 
00159 template <typename sFloat, typename gFloat>
00160 void
00161 Matdagmat(sFloat *out, gFloat **fatlink, gFloat** longlink, sFloat *in, sFloat mass, int daggerBit, sFloat* tmp, QudaParity parity) 
00162 {
00163     
00164   sFloat msq_x4 = mass*mass*4;
00165 
00166   switch(parity){
00167   case QUDA_EVEN_PARITY:
00168     {
00169       sFloat *inEven = in;
00170       sFloat *outEven = out;
00171       dslashReference(tmp, fatlink, longlink, inEven, 1, daggerBit);
00172       dslashReference(outEven, fatlink, longlink, tmp, 0, daggerBit);
00173             
00174       // lastly apply the mass term
00175       axmy(inEven, msq_x4, outEven, Vh*mySpinorSiteSize);
00176       break;
00177     }
00178   case QUDA_ODD_PARITY:
00179     {
00180       sFloat *inOdd = in;
00181       sFloat *outOdd = out;
00182       dslashReference(tmp, fatlink, longlink, inOdd, 0, daggerBit);
00183       dslashReference(outOdd, fatlink, longlink, tmp, 1, daggerBit);
00184             
00185       // lastly apply the mass term
00186       axmy(inOdd, msq_x4, outOdd, Vh*mySpinorSiteSize);
00187       break;    
00188     }
00189         
00190   default:
00191     fprintf(stderr, "ERROR: invalid parity in %s,line %d\n", __FUNCTION__, __LINE__);
00192     break;
00193   }
00194     
00195 }
00196 
00197 
00198 
00199 void 
00200 matdagmat(void *out, void **fatlink, void** longlink, void *in, double mass, int dagger_bit,
00201           QudaPrecision sPrecision, QudaPrecision gPrecision, void* tmp, QudaParity parity) 
00202 {
00203   
00204   if (sPrecision == QUDA_DOUBLE_PRECISION){
00205     if (gPrecision == QUDA_DOUBLE_PRECISION) {
00206       Matdagmat((double*)out, (double**)fatlink, (double**)longlink, (double*)in, (double)mass, dagger_bit, (double*)tmp, parity);
00207     }else {
00208       Matdagmat((double*)out, (float**)fatlink, (float**)longlink, (double*)in, (double)mass, dagger_bit, (double*) tmp, parity);
00209     }
00210   }else{
00211     if (gPrecision == QUDA_DOUBLE_PRECISION){ 
00212       Matdagmat((float*)out, (double**)fatlink, (double**)longlink, (float*)in, (float)mass, dagger_bit, (float*)tmp, parity);
00213     }else {
00214       Matdagmat((float*)out, (float**)fatlink, (float**)longlink, (float*)in, (float)mass, dagger_bit, (float*)tmp, parity);
00215     }
00216   }
00217 }
00218 
00219 
00220 
00221 
00222 
00223 // Apply the even-odd preconditioned Dirac operator
00224 template <typename sFloat, typename gFloat>
00225 static void MatPC(sFloat *outEven, gFloat **fatlink, gFloat** longlink, sFloat *inEven, sFloat kappa, 
00226                   int daggerBit, MatPCType matpc_type) {
00227     
00228   sFloat *tmp = (sFloat*)malloc(Vh*mySpinorSiteSize*sizeof(sFloat));
00229     
00230   // full dslash operator
00231   if (matpc_type == QUDA_MATPC_EVEN_EVEN) {
00232     dslashReference(tmp, fatlink, longlink, inEven, 1, daggerBit);
00233     dslashReference(outEven, fatlink, longlink, tmp, 0, daggerBit);
00234 
00235     //dslashReference(outEven, fatlink, longlink, inEven, 1, daggerBit);
00236   } else {
00237     dslashReference(tmp, fatlink, longlink, inEven, 0, daggerBit);
00238     dslashReference(outEven, fatlink, longlink, tmp, 1, daggerBit);
00239   }    
00240   
00241   // lastly apply the kappa term
00242     
00243   sFloat kappa2 = -kappa*kappa;
00244   xpay(inEven, kappa2, outEven, Vh*mySpinorSiteSize);
00245     
00246   free(tmp);
00247 }
00248 
00249 
00250 void
00251 staggered_matpc(void *outEven, void **fatlink, void**longlink, void *inEven, double kappa, 
00252                 MatPCType matpc_type, int dagger_bit, QudaPrecision sPrecision, QudaPrecision gPrecision) 
00253 {
00254     
00255   if (sPrecision == QUDA_DOUBLE_PRECISION)
00256     if (gPrecision == QUDA_DOUBLE_PRECISION) {
00257       MatPC((double*)outEven, (double**)fatlink, (double**)longlink, (double*)inEven, (double)kappa, dagger_bit, matpc_type);
00258     }
00259     else{
00260       MatPC((double*)outEven, (double**)fatlink, (double**)longlink, (double*)inEven, (double)kappa, dagger_bit, matpc_type);
00261     }
00262   else {
00263     if (gPrecision == QUDA_DOUBLE_PRECISION){ 
00264       MatPC((float*)outEven, (double**)fatlink, (double**)longlink, (float*)inEven, (float)kappa, dagger_bit, matpc_type);
00265     }else{
00266       MatPC((float*)outEven, (float**)fatlink, (float**)longlink, (float*)inEven, (float)kappa, dagger_bit, matpc_type);
00267     }
00268   }
00269 }
00270 
00271 #ifdef MULTI_GPU
00272 
00273 template <typename sFloat, typename gFloat>
00274 void dslashReference_mg4dir(sFloat *res, gFloat **fatlink, gFloat** longlink, 
00275                             gFloat** ghostFatlink, gFloat** ghostLonglink,
00276                             sFloat *spinorField, sFloat** fwd_nbr_spinor, 
00277                             sFloat** back_nbr_spinor, int oddBit, int daggerBit)
00278 {
00279   for (int i=0; i<Vh*1*3*2; i++) res[i] = 0.0;
00280 
00281   int Vsh[4] = {Vsh_x, Vsh_y, Vsh_z, Vsh_t};
00282   gFloat *fatlinkEven[4], *fatlinkOdd[4];
00283   gFloat *longlinkEven[4], *longlinkOdd[4];
00284   gFloat *ghostFatlinkEven[4], *ghostFatlinkOdd[4];
00285   gFloat *ghostLonglinkEven[4], *ghostLonglinkOdd[4];
00286 
00287   for (int dir = 0; dir < 4; dir++) {
00288     fatlinkEven[dir] = fatlink[dir];
00289     fatlinkOdd[dir] = fatlink[dir] + Vh*gaugeSiteSize;
00290     longlinkEven[dir] =longlink[dir];
00291     longlinkOdd[dir] = longlink[dir] + Vh*gaugeSiteSize;
00292     
00293     ghostFatlinkEven[dir] = ghostFatlink[dir];
00294     ghostFatlinkOdd[dir] = ghostFatlink[dir] + Vsh[dir]*gaugeSiteSize;
00295     ghostLonglinkEven[dir] = ghostLonglink[dir];
00296     ghostLonglinkOdd[dir] = ghostLonglink[dir] + 3*Vsh[dir]*gaugeSiteSize;
00297   }
00298 
00299   for (int i = 0; i < Vh; i++) {
00300     memset(res + i*mySpinorSiteSize, 0, mySpinorSiteSize*sizeof(sFloat));
00301     for (int dir = 0; dir < 8; dir++) {
00302       gFloat* fatlnk = gaugeLink_mg4dir(i, dir, oddBit, fatlinkEven, fatlinkOdd, ghostFatlinkEven, ghostFatlinkOdd, 1, 1);
00303       gFloat* longlnk = gaugeLink_mg4dir(i, dir, oddBit, longlinkEven, longlinkOdd, ghostLonglinkEven, ghostLonglinkOdd, 3, 3);
00304 
00305       sFloat *first_neighbor_spinor = spinorNeighbor_mg4dir(i, dir, oddBit, spinorField, fwd_nbr_spinor, back_nbr_spinor, 1, 3);
00306       sFloat *third_neighbor_spinor = spinorNeighbor_mg4dir(i, dir, oddBit, spinorField, fwd_nbr_spinor, back_nbr_spinor, 3, 3);
00307 
00308       sFloat gaugedSpinor[mySpinorSiteSize];
00309 
00310 
00311       if (dir % 2 == 0){
00312         su3Mul(gaugedSpinor, fatlnk, first_neighbor_spinor);
00313         sum(&res[i*mySpinorSiteSize], &res[i*mySpinorSiteSize], gaugedSpinor, mySpinorSiteSize);
00314         su3Mul(gaugedSpinor, longlnk, third_neighbor_spinor);
00315         sum(&res[i*mySpinorSiteSize], &res[i*mySpinorSiteSize], gaugedSpinor, mySpinorSiteSize);                                                        
00316       }
00317       else{
00318         su3Tmul(gaugedSpinor, fatlnk, first_neighbor_spinor);
00319         sub(&res[i*mySpinorSiteSize], &res[i*mySpinorSiteSize], gaugedSpinor, mySpinorSiteSize);
00320 
00321         su3Tmul(gaugedSpinor, longlnk, third_neighbor_spinor);
00322         sub(&res[i*mySpinorSiteSize], &res[i*mySpinorSiteSize], gaugedSpinor, mySpinorSiteSize);
00323         
00324       }
00325 
00326     }
00327     if (daggerBit){
00328       negx(&res[i*mySpinorSiteSize], mySpinorSiteSize);
00329     }
00330   }
00331 
00332 }
00333 
00334 
00335 
00336 void staggered_dslash_mg4dir(cpuColorSpinorField* out, void **fatlink, void** longlink, void** ghost_fatlink, 
00337                              void** ghost_longlink, cpuColorSpinorField* in, int oddBit, int daggerBit, 
00338                              QudaPrecision sPrecision, QudaPrecision gPrecision)
00339 {
00340 
00341   QudaParity otherparity = QUDA_INVALID_PARITY;
00342   if (oddBit == QUDA_EVEN_PARITY){
00343     otherparity = QUDA_ODD_PARITY;
00344   }else if (oddBit == QUDA_ODD_PARITY){
00345     otherparity = QUDA_EVEN_PARITY;
00346   }else{
00347     errorQuda("ERROR: full parity not supported in function %s", __FUNCTION__);
00348   }
00349 
00350   int Nc = 3;
00351   int nFace = 3;
00352   FaceBuffer faceBuf(Z, 4, 2*Nc, nFace, sPrecision);
00353   faceBuf.exchangeCpuSpinor(*in, otherparity, daggerBit); 
00354   
00355   void** fwd_nbr_spinor = in->fwdGhostFaceBuffer;
00356   void** back_nbr_spinor = in->backGhostFaceBuffer;
00357 
00358   if (sPrecision == QUDA_DOUBLE_PRECISION) {
00359     if (gPrecision == QUDA_DOUBLE_PRECISION){
00360       dslashReference_mg4dir((double*)out->V(), (double**)fatlink, (double**)longlink,  
00361                              (double**)ghost_fatlink, (double**)ghost_longlink, (double*)in->V(), 
00362                              (double**)fwd_nbr_spinor, (double**)back_nbr_spinor, oddBit, daggerBit);
00363     } else {
00364       dslashReference_mg4dir((double*)out->V(), (float**)fatlink, (float**)longlink, (float**)ghost_fatlink, (float**)ghost_longlink,
00365                              (double*)in->V(), (double**)fwd_nbr_spinor, (double**)back_nbr_spinor, oddBit, daggerBit);
00366     }
00367   }
00368   else{
00369     if (gPrecision == QUDA_DOUBLE_PRECISION){
00370       dslashReference_mg4dir((float*)out->V(), (double**)fatlink, (double**)longlink, (double**)ghost_fatlink, (double**)ghost_longlink,
00371                              (float*)in->V(), (float**)fwd_nbr_spinor, (float**)back_nbr_spinor, oddBit, daggerBit);
00372     }else{
00373       dslashReference_mg4dir((float*)out->V(), (float**)fatlink, (float**)longlink, (float**)ghost_fatlink, (float**)ghost_longlink,
00374                              (float*)in->V(), (float**)fwd_nbr_spinor, (float**)back_nbr_spinor, oddBit, daggerBit);
00375     }
00376   }
00377   
00378   
00379 }
00380 
00381 void 
00382 matdagmat_mg4dir(cpuColorSpinorField* out, void **fatlink, void** longlink, void** ghost_fatlink, void** ghost_longlink, 
00383                  cpuColorSpinorField* in, double mass, int dagger_bit,
00384                  QudaPrecision sPrecision, QudaPrecision gPrecision, cpuColorSpinorField* tmp, QudaParity parity) 
00385 {
00386   //assert sPrecision and gPrecision must be the same
00387   if (sPrecision != gPrecision){
00388     errorQuda("Spinor precision and gPrecison is not the same");
00389   }
00390   
00391   QudaParity otherparity = QUDA_INVALID_PARITY;
00392   if (parity == QUDA_EVEN_PARITY){
00393     otherparity = QUDA_ODD_PARITY;
00394   }else if (parity == QUDA_ODD_PARITY){
00395     otherparity = QUDA_EVEN_PARITY;
00396   }else{
00397     errorQuda("ERROR: full parity not supported in function %s\n", __FUNCTION__);
00398   }
00399   
00400   staggered_dslash_mg4dir(tmp, fatlink, longlink, ghost_fatlink, ghost_longlink,
00401                           in, otherparity, dagger_bit, sPrecision, gPrecision);
00402 
00403   staggered_dslash_mg4dir(out, fatlink, longlink, ghost_fatlink, ghost_longlink,
00404                           tmp, parity, dagger_bit, sPrecision, gPrecision);
00405   
00406   double msq_x4 = mass*mass*4;
00407   if (sPrecision == QUDA_DOUBLE_PRECISION){
00408     axmy((double*)in->V(), (double)msq_x4, (double*)out->V(), Vh*mySpinorSiteSize);
00409   }else{
00410     axmy((float*)in->V(), (float)msq_x4, (float*)out->V(), Vh*mySpinorSiteSize);    
00411   }
00412 
00413 }
00414 
00415 #endif
00416 
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines