QUDA v0.3.2
A library for QCD on GPUs

quda/tests/staggered_invert_test.cpp

Go to the documentation of this file.
00001 #include <stdlib.h>
00002 #include <stdio.h>
00003 #include <time.h>
00004 #include <math.h>
00005 
00006 #include <test_util.h>
00007 #include <blas_reference.h>
00008 #include <staggered_dslash_reference.h>
00009 #include <quda.h>
00010 #include <string.h>
00011 #include "misc.h"
00012 
00013 #define mySpinorSiteSize 6
00014 
00015 int device = 0;
00016 QudaReconstructType link_recon = QUDA_RECONSTRUCT_12;
00017 QudaPrecision prec = QUDA_SINGLE_PRECISION;
00018 QudaPrecision cpu_prec = QUDA_DOUBLE_PRECISION;
00019 
00020 QudaReconstructType link_recon_sloppy = QUDA_RECONSTRUCT_INVALID;
00021 QudaPrecision  prec_sloppy = QUDA_INVALID_PRECISION;
00022 
00023 static double tol = 1e-8;
00024 
00025 static int testtype = 0;
00026 static int sdim = 24;
00027 static int tdim = 24;
00028 
00029 extern int V;
00030 
00031 template<typename Float>
00032 void constructSpinorField(Float *res) {
00033   for(int i = 0; i < V; i++) {
00034     for (int s = 0; s < 1; s++) {
00035       for (int m = 0; m < 3; m++) {
00036         res[i*(1*3*2) + s*(3*2) + m*(2) + 0] = rand() / (Float)RAND_MAX;
00037         res[i*(1*3*2) + s*(3*2) + m*(2) + 1] = rand() / (Float)RAND_MAX;
00038       }
00039     }
00040   }
00041 }
00042 
00043 
00044 static int
00045 invert_test(void)
00046 {
00047   void *fatlink[4];
00048   void *longlink[4];
00049     
00050   QudaGaugeParam gauge_param;
00051   QudaInvertParam inv_param;
00052 
00053   gauge_param.X[0] = sdim;
00054   gauge_param.X[1] = sdim;
00055   gauge_param.X[2] = sdim;
00056   gauge_param.X[3] = tdim;
00057   setDims(gauge_param.X);
00058     
00059   gauge_param.cpu_prec = cpu_prec;
00060     
00061   gauge_param.cuda_prec = prec;
00062   gauge_param.reconstruct = link_recon;
00063 
00064   gauge_param.cuda_prec_sloppy = prec_sloppy;
00065   gauge_param.reconstruct_sloppy = link_recon_sloppy;
00066   
00067   gauge_param.gauge_fix = QUDA_GAUGE_FIXED_NO;
00068 
00069   gauge_param.tadpole_coeff = 0.8;
00070 
00071   inv_param.verbosity = QUDA_VERBOSE;
00072   inv_param.inv_type = QUDA_CG_INVERTER;
00073 
00074   gauge_param.t_boundary = QUDA_ANTI_PERIODIC_T;
00075   gauge_param.gauge_order = QUDA_QDP_GAUGE_ORDER;
00076     
00077   double mass = 0.95;
00078   inv_param.mass = mass;
00079   inv_param.tol = tol;
00080   inv_param.maxiter = 100;
00081   inv_param.reliable_delta = 1e-3;
00082 
00083   inv_param.solution_type = QUDA_MATDAG_MAT_SOLUTION;
00084   inv_param.solve_type = QUDA_NORMEQ_PC_SOLVE;
00085   inv_param.matpc_type = QUDA_MATPC_EVEN_EVEN;
00086   inv_param.dagger = QUDA_DAG_NO;
00087   inv_param.mass_normalization = QUDA_MASS_NORMALIZATION;
00088 
00089   inv_param.cpu_prec = cpu_prec;
00090   inv_param.cuda_prec = prec; 
00091   inv_param.cuda_prec_sloppy = prec_sloppy;
00092   inv_param.preserve_source = QUDA_PRESERVE_SOURCE_YES;
00093   inv_param.dirac_order = QUDA_DIRAC_ORDER;
00094   inv_param.dslash_type = QUDA_ASQTAD_DSLASH;
00095   gauge_param.ga_pad = sdim*sdim*sdim;
00096   inv_param.sp_pad = sdim*sdim*sdim;
00097   
00098   size_t gSize = (gauge_param.cpu_prec == QUDA_DOUBLE_PRECISION) ? sizeof(double) : sizeof(float);
00099   size_t sSize = (inv_param.cpu_prec == QUDA_DOUBLE_PRECISION) ? sizeof(double) : sizeof(float);
00100   
00101   for (int dir = 0; dir < 4; dir++) {
00102     fatlink[dir] = malloc(V*gaugeSiteSize*gSize);
00103     longlink[dir] = malloc(V*gaugeSiteSize*gSize);
00104   }
00105   construct_fat_long_gauge_field(fatlink, longlink, 1, gauge_param.cpu_prec, &gauge_param);
00106     
00107   for (int dir = 0; dir < 4; dir++) {
00108     for(int i = 0;i < V*gaugeSiteSize;i++){
00109       if (gauge_param.cpu_prec == QUDA_DOUBLE_PRECISION){
00110         ((double*)fatlink[dir])[i] = 0.5 *rand()/RAND_MAX;
00111       }else{
00112         ((float*)fatlink[dir])[i] = 0.5* rand()/RAND_MAX;
00113       }
00114     }
00115   }
00116     
00117   void *spinorIn = malloc(V*mySpinorSiteSize*sSize);
00118   void *spinorOut = malloc(V*mySpinorSiteSize*sSize);
00119   void *spinorCheck = malloc(V*mySpinorSiteSize*sSize);
00120   void *tmp = malloc(V*mySpinorSiteSize*sSize);
00121     
00122   memset(spinorIn, 0, V*mySpinorSiteSize*sSize);
00123   memset(spinorOut, 0, V*mySpinorSiteSize*sSize);
00124   memset(spinorCheck, 0, V*mySpinorSiteSize*sSize);
00125   memset(tmp, 0, V*mySpinorSiteSize*sSize);
00126 
00127   if (inv_param.cpu_prec == QUDA_SINGLE_PRECISION){
00128     constructSpinorField((float*)spinorIn);    
00129   }else{
00130     constructSpinorField((double*)spinorIn);
00131   }
00132   
00133   void* spinorInOdd = ((char*)spinorIn) + Vh*mySpinorSiteSize*sSize;
00134   void* spinorOutOdd = ((char*)spinorOut) + Vh*mySpinorSiteSize*sSize;
00135   void* spinorCheckOdd = ((char*)spinorCheck) + Vh*mySpinorSiteSize*sSize;
00136   
00137   initQuda(device);
00138 
00139   gauge_param.type = QUDA_ASQTAD_FAT_LINKS;
00140   gauge_param.reconstruct = gauge_param.reconstruct_sloppy = QUDA_RECONSTRUCT_NO;
00141   loadGaugeQuda(fatlink, &gauge_param);
00142 
00143   gauge_param.type = QUDA_ASQTAD_LONG_LINKS;
00144   gauge_param.reconstruct = link_recon;
00145   gauge_param.reconstruct_sloppy = link_recon_sloppy;
00146   loadGaugeQuda(longlink, &gauge_param);
00147 
00148   double time0 = -((double)clock()); // Start the timer
00149   
00150   unsigned long volume = Vh;
00151   unsigned long nflops=2*1187; //from MILC's CG routine
00152   double nrm2=0;
00153   double src2=0;
00154   switch(testtype){
00155 
00156   case 0: //even
00157     volume = Vh;
00158     inv_param.solution_type = QUDA_MATPCDAG_MATPC_SOLUTION;
00159     inv_param.matpc_type = QUDA_MATPC_EVEN_EVEN;
00160     
00161     invertQuda(spinorOut, spinorIn, &inv_param);
00162     
00163     time0 += clock(); 
00164     time0 /= CLOCKS_PER_SEC;
00165     
00166     matdagmat_milc(spinorCheck, fatlink, longlink, spinorOut, mass, 0, inv_param.cpu_prec, gauge_param.cpu_prec, tmp, QUDA_EVEN);
00167     
00168     mxpy(spinorIn, spinorCheck, Vh*mySpinorSiteSize, inv_param.cpu_prec);
00169     nrm2 = norm_2(spinorCheck, Vh*mySpinorSiteSize, inv_param.cpu_prec);
00170     src2 = norm_2(spinorIn, Vh*mySpinorSiteSize, inv_param.cpu_prec);
00171     break;
00172 
00173   case 1: //odd
00174         
00175     volume = Vh;    
00176     inv_param.solution_type = QUDA_MATPCDAG_MATPC_SOLUTION;
00177     inv_param.matpc_type = QUDA_MATPC_ODD_ODD;
00178     invertQuda(spinorOutOdd, spinorInOdd, &inv_param);  
00179     time0 += clock(); // stop the timer
00180     time0 /= CLOCKS_PER_SEC;
00181     
00182     
00183     matdagmat_milc(spinorCheckOdd, fatlink, longlink, spinorOutOdd, mass, 0, inv_param.cpu_prec, gauge_param.cpu_prec, tmp, QUDA_ODD);  
00184     mxpy(spinorInOdd, spinorCheckOdd, Vh*mySpinorSiteSize, inv_param.cpu_prec);
00185     nrm2 = norm_2(spinorCheckOdd, Vh*mySpinorSiteSize, inv_param.cpu_prec);
00186     src2 = norm_2(spinorInOdd, Vh*mySpinorSiteSize, inv_param.cpu_prec);
00187         
00188     break;
00189     
00190   case 2: //full spinor
00191 
00192     volume = Vh; //FIXME: the time reported is only parity time
00193     inv_param.solve_type = QUDA_NORMEQ_SOLVE;
00194     inv_param.solution_type = QUDA_MATDAG_MAT_SOLUTION;
00195     invertQuda(spinorOut, spinorIn, &inv_param);
00196     
00197     time0 += clock(); // stop the timer
00198     time0 /= CLOCKS_PER_SEC;
00199     
00200     matdagmat_milc(spinorCheck, fatlink, longlink, spinorOut, mass, 0, inv_param.cpu_prec, gauge_param.cpu_prec, tmp, QUDA_EVENODD);
00201     
00202     mxpy(spinorIn, spinorCheck, V*mySpinorSiteSize, inv_param.cpu_prec);
00203     nrm2 = norm_2(spinorCheck, V*mySpinorSiteSize, inv_param.cpu_prec);
00204     src2 = norm_2(spinorIn, V*mySpinorSiteSize, inv_param.cpu_prec);
00205 
00206     break;
00207 
00208   case 3: //multi mass CG, even
00209   case 4:
00210   case 5:
00211 
00212 #define NUM_OFFSETS 4
00213         
00214     nflops = 2*(1205 + 15* NUM_OFFSETS); //from MILC's multimass CG routine
00215     double masses[NUM_OFFSETS] ={5.05, 1.23, 2.64, 2.33};
00216     double offsets[NUM_OFFSETS];        
00217     int num_offsets =NUM_OFFSETS;
00218     void* spinorOutArray[NUM_OFFSETS];
00219     void* in;
00220     int len;
00221     
00222     for (int i=0; i< num_offsets;i++){
00223       offsets[i] = 4*masses[i]*masses[i];
00224     }
00225     
00226     if (testtype == 3){
00227       in=spinorIn;
00228       len=Vh;
00229       volume = Vh;
00230       
00231       inv_param.solution_type = QUDA_MATPCDAG_MATPC_SOLUTION;
00232       inv_param.matpc_type = QUDA_MATPC_EVEN_EVEN;      
00233       
00234       spinorOutArray[0] = spinorOut;
00235       for (int i=1; i< num_offsets;i++){
00236         spinorOutArray[i] = malloc(Vh*mySpinorSiteSize*sSize);
00237       }         
00238     }
00239     
00240     else if (testtype ==4){
00241       in=spinorInOdd;
00242       len = Vh;
00243       volume = Vh;
00244 
00245       inv_param.solution_type = QUDA_MATPCDAG_MATPC_SOLUTION;
00246       inv_param.matpc_type = QUDA_MATPC_ODD_ODD;
00247       
00248       spinorOutArray[0] = spinorOutOdd;
00249       for (int i=1; i< num_offsets;i++){
00250         spinorOutArray[i] = malloc(Vh*mySpinorSiteSize*sSize);
00251       }
00252     }else { //testtype ==5
00253       in=spinorIn;
00254       len= V;
00255       inv_param.solution_type = QUDA_MATDAG_MAT_SOLUTION;
00256       inv_param.solve_type = QUDA_NORMEQ_SOLVE;
00257       volume = Vh; //FIXME: the time reported is only parity time
00258       spinorOutArray[0] = spinorOut;
00259       for (int i=1; i< num_offsets;i++){
00260         spinorOutArray[i] = malloc(V*mySpinorSiteSize*sSize);
00261       }         
00262     }
00263     
00264     double residue_sq;
00265     invertMultiShiftQuda(spinorOutArray, in, &inv_param, offsets, num_offsets, &residue_sq);    
00266     cudaThreadSynchronize();
00267     printf("Final residue squred =%g\n", residue_sq);
00268     time0 += clock(); // stop the timer
00269     time0 /= CLOCKS_PER_SEC;
00270     
00271     printf("done: total time = %g secs, %i iter / %g secs = %g gflops, \n", 
00272            time0, inv_param.iter, inv_param.secs,
00273            inv_param.gflops/inv_param.secs);
00274 
00275     
00276     printf("checking the solution\n");
00277     MyQudaParity parity;
00278     if (inv_param.solve_type == QUDA_NORMEQ_SOLVE){
00279       parity = QUDA_EVENODD;
00280     }else if (inv_param.matpc_type == QUDA_MATPC_EVEN_EVEN){
00281       parity = QUDA_EVEN;
00282     }else if (inv_param.matpc_type == QUDA_MATPC_ODD_ODD){
00283       parity = QUDA_ODD;
00284     }else{
00285       printf("ERROR: invalid spinor parity \n");
00286       exit(1);
00287     }
00288     
00289     for(int i=0;i < num_offsets;i++){
00290       printf("%dth solution: mass=%f", i, masses[i]);
00291       matdagmat_milc(spinorCheck, fatlink, longlink, spinorOutArray[i], masses[i], 0, inv_param.cpu_prec, gauge_param.cpu_prec, tmp, parity);
00292       mxpy(in, spinorCheck, len*mySpinorSiteSize, inv_param.cpu_prec);
00293       double nrm2 = norm_2(spinorCheck, len*mySpinorSiteSize, inv_param.cpu_prec);
00294       double src2 = norm_2(in, len*mySpinorSiteSize, inv_param.cpu_prec);
00295       printf("relative residual, requested = %g, actual = %g\n", inv_param.tol, sqrt(nrm2/src2));
00296     }
00297     
00298     for(int i=1; i < num_offsets;i++){
00299       free(spinorOutArray[i]);
00300     }
00301 
00302     
00303   }//switch
00304     
00305   if (testtype <=2){
00306     printf("Relative residual, requested = %g, actual = %g\n", inv_param.tol, sqrt(nrm2/src2));
00307         
00308     printf("done: total time = %g secs, %i iter / %g secs = %g gflops, \n", 
00309            time0, inv_param.iter, inv_param.secs,
00310            inv_param.gflops/inv_param.secs);
00311   }
00312   endQuda();
00313 
00314   if (tmp){
00315     free(tmp);
00316   }
00317   return 0;
00318 }
00319 
00320 
00321 
00322 
00323 void
00324 display_test_info()
00325 {
00326   printf("running the following test:\n");
00327     
00328   printf("prec    sloppy_prec    link_recon  sloppy_link_recon test_type  S_dimension T_dimension\n");
00329   printf("%s   %s             %s            %s            %s         %d          %d \n",
00330          get_prec_str(prec),get_prec_str(prec_sloppy),
00331          get_recon_str(link_recon), 
00332          get_recon_str(link_recon_sloppy), get_test_type(testtype), sdim, tdim);     
00333   return ;
00334   
00335 }
00336 
/*
 * Print the command-line options accepted by main() and terminate the
 * process with exit status 1.
 */
void
usage(char** argv )
{
  printf("Usage: %s <args>\n", argv[0]);
  printf("--prec         <double/single/half>     Spinor/gauge precision\n");
  printf("--prec_sloppy  <double/single/half>     Spinor/gauge sloppy precision\n");
  printf("--cprec        <double/single>          Host (CPU) spinor/gauge precision\n");
  printf("--recon        <8/12>                   Long link reconstruction type\n");
  printf("--recon_sloppy <8/12>                   Sloppy long link reconstruction type\n");
  printf("--test         <0/1/2/3/4/5>            Testing type(0=even, 1=odd, 2=full, 3=multimass even,\n"
         "                                                     4=multimass odd, 5=multimass full)\n");
  printf("--tol          <positive number>        Solver tolerance\n");
  printf("--tdim                                  T dimension\n");
  printf("--sdim                                  S dimension\n");
  printf("--device       <n>                      CUDA device number\n");
  printf("--help                                  Print out this message\n");
  exit(1);
}
00352 
00353 
00354 int main(int argc, char** argv)
00355 {
00356 
00357   int i;
00358   for (i =1;i < argc; i++){
00359         
00360     if( strcmp(argv[i], "--help")== 0){
00361       usage(argv);
00362     }
00363         
00364     if( strcmp(argv[i], "--prec") == 0){
00365       if (i+1 >= argc){
00366         usage(argv);
00367       }     
00368       prec = get_prec(argv[i+1]);
00369       i++;
00370       continue;     
00371     }
00372     
00373     if( strcmp(argv[i], "--prec_sloppy") == 0){
00374       if (i+1 >= argc){
00375         usage(argv);
00376       }     
00377       prec_sloppy =  get_prec(argv[i+1]);
00378       i++;
00379       continue;     
00380     }
00381     
00382     
00383     if( strcmp(argv[i], "--recon") == 0){
00384       if (i+1 >= argc){
00385         usage(argv);
00386       }     
00387       link_recon =  get_recon(argv[i+1]);
00388       i++;
00389       continue;     
00390     }
00391     if( strcmp(argv[i], "--tol") == 0){
00392       float tmpf;
00393       if (i+1 >= argc){
00394         usage(argv);
00395       }
00396       sscanf(argv[i+1], "%f", &tmpf);
00397       if (tol <= 0){
00398         PRINTF("ERROR: invalid tol(%f)\n", tmpf);
00399         usage(argv);
00400       }
00401       tol = tmpf;
00402       i++;
00403       continue;
00404     }
00405 
00406 
00407         
00408     if( strcmp(argv[i], "--recon_sloppy") == 0){
00409       if (i+1 >= argc){
00410         usage(argv);
00411       }     
00412       link_recon_sloppy =  get_recon(argv[i+1]);
00413       i++;
00414       continue;     
00415     }
00416         
00417     if( strcmp(argv[i], "--test") == 0){
00418       if (i+1 >= argc){
00419         usage(argv);
00420       }     
00421       testtype = atoi(argv[i+1]);
00422       i++;
00423       continue;     
00424     }
00425 
00426     if( strcmp(argv[i], "--cprec") == 0){
00427       if (i+1 >= argc){
00428         usage(argv);
00429       }
00430       cpu_prec= get_prec(argv[i+1]);
00431       i++;
00432       continue;
00433     }
00434 
00435     if( strcmp(argv[i], "--tdim") == 0){
00436       if (i+1 >= argc){
00437         usage(argv);
00438       }
00439       tdim= atoi(argv[i+1]);
00440       if (tdim < 0 || tdim > 128){
00441         printf("ERROR: invalid T dimention (%d)\n", tdim);
00442         usage(argv);
00443       }
00444       i++;
00445       continue;
00446     }           
00447     if( strcmp(argv[i], "--sdim") == 0){
00448       if (i+1 >= argc){
00449         usage(argv);
00450       }
00451       sdim= atoi(argv[i+1]);
00452       if (sdim < 0 || sdim > 128){
00453         printf("ERROR: invalid S dimention (%d)\n", sdim);
00454         usage(argv);
00455       }
00456       i++;
00457       continue;
00458     }
00459     if( strcmp(argv[i], "--device") == 0){
00460           if (i+1 >= argc){
00461               usage(argv);
00462           }
00463           device =  atoi(argv[i+1]);
00464           if (device < 0){
00465               fprintf(stderr, "Error: invalid device number(%d)\n", device);
00466               exit(1);
00467           }
00468           i++;
00469           continue;
00470     }
00471 
00472 
00473     fprintf(stderr, "ERROR: Invalid option:%s\n", argv[i]);
00474     usage(argv);
00475   }
00476 
00477 
00478   if (prec_sloppy == QUDA_INVALID_PRECISION){
00479     prec_sloppy = prec;
00480   }
00481   if (link_recon_sloppy == QUDA_RECONSTRUCT_INVALID){
00482     link_recon_sloppy = link_recon;
00483   }
00484   
00485   display_test_info();
00486   invert_test();
00487     
00488 
00489   return 0;
00490 }
 All Classes Files Functions Variables Typedefs Enumerations Enumerator Friends Defines