QUDA v0.3.2
A library for QCD on GPUs

quda/tests/staggered_dslash_reference.cpp

Go to the documentation of this file.
00001 #include <stdio.h>
00002 #include <stdlib.h>
00003 #include <math.h>
00004 #include <string.h>
00005 
00006 #include <test_util.h>
00007 #include <quda_internal.h>
00008 #include <quda.h>
00009 #include <util_quda.h>
00010 #include <staggered_dslash_reference.h>
00011 #include "misc.h"
00012 extern void *memset(void *s, int c, size_t n);
00013 
// Number of real values stored per spinor site. It sizes the per-site scratch
// buffer in dslashReference(), which receives the three complex components
// written by su3Mul(). Declared const so that it is a compile-time constant
// and the scratch buffer is an ordinary fixed-size array rather than a
// variable-length array (a compiler extension in C++).
static const int mySpinorSiteSize = 6;

// Lattice geometry; all three are filled in by setDims().
int Z[4];  // extent of the lattice in each of the four dimensions
int V;     // product of the four extents (total number of sites)
int Vh;    // V/2 (number of sites in one parity half)
00019 
00020 void setDims(int *X) {
00021   V = 1;
00022   for (int d=0; d< 4; d++) {
00023     V *= X[d];
00024     Z[d] = X[d];
00025   }
00026   Vh = V/2;
00027 }
00028 
// Element-wise addition: dst[k] = a[k] + b[k] for every k in [0, cnt).
// dst may be the same array as a or b.
template <typename Float>
void sum(Float *dst, Float *a, Float *b, int cnt) {
  int k = 0;
  while (k < cnt) {
    dst[k] = a[k] + b[k];
    ++k;
  }
}
// Element-wise subtraction: dst[k] = a[k] - b[k] for every k in [0, cnt).
// dst may be the same array as a or b.
template <typename Float>
void sub(Float *dst, Float *a, Float *b, int cnt) {
  int k = 0;
  while (k < cnt) {
    dst[k] = a[k] - b[k];
    ++k;
  }
}
// In-place update of y: y[k] = x[k] + a*y[k] for every k in [0, len).
// x is only read.
template <typename Float>
void xpay(Float *x, Float a, Float *y, int len) {
    int k = 0;
    while (k < len) {
        y[k] = x[k] + a*y[k];
        ++k;
    }
}
// In-place update of y: y[k] = a*x[k] - y[k] for every k in [0, len).
// x is only read.
template <typename Float>
void axmy(Float *x, Float a, Float *y, int len) {
    int k = 0;
    while (k < len) {
        y[k] = a*x[k] - y[k];
        ++k;
    }
}
00049 
// Flip the sign of the first len entries of x, in place.
template <typename Float>
void negx(Float *x, int len) {
    int k = 0;
    while (k < len) {
        x[k] = -x[k];
        ++k;
    }
}
00054 
00055 // i represents a "half index" into an even or odd "half lattice".
00056 // when oddBit={0,1} the half lattice is {even,odd}.
00057 // 
00058 // the displacements, such as dx, refer to the full lattice coordinates. 
00059 //
00060 // neighborIndex() takes a "half index", displaces it, and returns the
00061 // new "half index", which can be an index into either the even or odd lattices.
00062 // displacements of magnitude one always interchange odd and even lattices.
00063 //
00064 
00065 
00066 template <typename Float>
00067 Float *gaugeLink(int i, int dir, int oddBit, Float **gaugeEven, Float **gaugeOdd, int nbr_distance) {
00068   Float **gaugeField;
00069   int j;
00070   int d = nbr_distance;
00071   if (dir % 2 == 0) {
00072     j = i;
00073     gaugeField = (oddBit ? gaugeOdd : gaugeEven);
00074   }
00075   else {
00076     switch (dir) {
00077     case 1: j = neighborIndex(i, oddBit, 0, 0, 0, -d); break;
00078     case 3: j = neighborIndex(i, oddBit, 0, 0, -d, 0); break;
00079     case 5: j = neighborIndex(i, oddBit, 0, -d, 0, 0); break;
00080     case 7: j = neighborIndex(i, oddBit, -d, 0, 0, 0); break;
00081     default: j = -1; break;
00082     }
00083     gaugeField = (oddBit ? gaugeEven : gaugeOdd);
00084   }
00085   
00086   return &gaugeField[dir/2][j*(3*3*2)];
00087 }
00088 
00089 
00090 
00091 template <typename Float>
00092 Float *spinorNeighbor(int i, int dir, int oddBit, Float *spinorField, int neighbor_distance) 
00093 {
00094     int j;
00095     int nb = neighbor_distance;
00096     switch (dir) {
00097     case 0: j = neighborIndex(i, oddBit, 0, 0, 0, +nb); break;
00098     case 1: j = neighborIndex(i, oddBit, 0, 0, 0, -nb); break;
00099     case 2: j = neighborIndex(i, oddBit, 0, 0, +nb, 0); break;
00100     case 3: j = neighborIndex(i, oddBit, 0, 0, -nb, 0); break;
00101     case 4: j = neighborIndex(i, oddBit, 0, +nb, 0, 0); break;
00102     case 5: j = neighborIndex(i, oddBit, 0, -nb, 0, 0); break;
00103     case 6: j = neighborIndex(i, oddBit, +nb, 0, 0, 0); break;
00104     case 7: j = neighborIndex(i, oddBit, -nb, 0, 0, 0); break;
00105     default: j = -1; break;
00106     }
00107     
00108     return &spinorField[j*(mySpinorSiteSize)];
00109 }
00110 
// Complex dot product of two 3-component vectors stored as interleaved
// (re, im) pairs: res = sum over c of a[c]*b[c]. Neither operand is
// conjugated. The entries of a are converted to sFloat before multiplying.
// res[0] receives the real part and res[1] the imaginary part.
template <typename sFloat, typename gFloat>
void dot(sFloat* res, gFloat* a, sFloat* b) {
  res[1] = 0;
  res[0] = res[1];
  for (int c = 0; c < 3; ++c) {
    const int re = 2*c;
    const int im = re + 1;
    sFloat ar = a[re];
    sFloat ai = a[im];
    sFloat br = b[re];
    sFloat bi = b[im];
    res[0] += ar * br - ai * bi;
    res[1] += ar * bi + ai * br;
  }
}
00123 
// Write into res the conjugate transpose of the 3x3 complex matrix mat.
// Despite the name this is more than a plain transpose: rows and columns are
// swapped AND every imaginary part is negated. Both matrices are row-major
// with interleaved (re, im) entries, 18 reals in total.
template <typename Float>
void su3Transpose(Float *res, Float *mat) {
  for (int row = 0; row < 3; ++row) {
    for (int col = 0; col < 3; ++col) {
      Float *dst = res + row*(3*2) + col*(2);
      Float *src = mat + col*(3*2) + row*(2);
      dst[0] = + src[0];
      dst[1] = - src[1];
    }
  }
}
00133 
00134 
// Matrix-vector product res = mat * vec for a row-major 3x3 complex matrix
// and a 3-component complex vector, all stored as interleaved (re, im) pairs.
// Each output component is the dot() of one matrix row with vec.
template <typename sFloat, typename gFloat>
void su3Mul(sFloat *res, gFloat *mat, sFloat *vec) {
  for (int row = 0; row < 3; ++row) {
    sFloat *component = res + row*(2);
    gFloat *matRow = mat + row*(3*2);
    dot(component, matRow, vec);
  }
}
00139 
// Multiply vec by the conjugate transpose of mat: the matrix is first run
// through su3Transpose() into a local 18-real buffer, then applied to vec
// with su3Mul(). mat itself is left untouched.
template <typename sFloat, typename gFloat>
void su3Tmul(sFloat *res, gFloat *mat, sFloat *vec) {
  gFloat adjoint[3*3*2];
  su3Transpose(adjoint, mat);
  su3Mul(res, adjoint, vec);
}
00146 
00147 
//
// dslashReference()  (defined further below, after display_link_internal())
//
// if oddBit is zero: calculate the even parity spinor elements (using the odd parity spinor)
// if oddBit is one:  calculate the odd parity spinor elements (using the even parity spinor)
//
// if daggerBit is zero: perform the ordinary dslash operator
// if daggerBit is one:  perform the hermitian conjugate of dslash
//
// Debug helper: print a 3x3 complex matrix (row-major, interleaved re/im) to
// stdout, one matrix row per line with entries formatted as "(re,im)", and a
// blank line after the last row.
template<typename Float>
void display_link_internal(Float* link)
{
    for (int row = 0; row < 3; ++row) {
        for (int col = 0; col < 3; ++col) {
            const int re = row*3*2 + col*2;
            printf("(%10f,%10f) \t", link[re], link[re + 1]);
        }
        printf("\n");
    }
    printf("\n");
}
00171 
00172 
00173 template <typename sFloat, typename gFloat>
00174 void dslashReference(sFloat *res, gFloat **fatlink, gFloat** longlink, sFloat *spinorField, int oddBit, int daggerBit) 
00175 {
00176     for (int i=0; i<Vh*1*3*2; i++) res[i] = 0.0;
00177     
00178     gFloat *fatlinkEven[4], *fatlinkOdd[4];
00179     gFloat *longlinkEven[4], *longlinkOdd[4];
00180     
00181     for (int dir = 0; dir < 4; dir++) {  
00182         fatlinkEven[dir] = fatlink[dir];
00183         fatlinkOdd[dir] = fatlink[dir] + Vh*gaugeSiteSize;
00184         longlinkEven[dir] =longlink[dir];
00185         longlinkOdd[dir] = longlink[dir] + Vh*gaugeSiteSize;    
00186     }
00187 
00188     for (int i = 0; i < Vh; i++) {
00189         memset(res + i*mySpinorSiteSize, 0, mySpinorSiteSize*sizeof(sFloat));
00190         for (int dir = 0; dir < 8; dir++) {
00191             gFloat* fatlnk = gaugeLink(i, dir, oddBit, fatlinkEven, fatlinkOdd, 1);
00192             gFloat* longlnk = gaugeLink(i, dir, oddBit, longlinkEven, longlinkOdd, 3);
00193 
00194             sFloat *first_neighbor_spinor = spinorNeighbor(i, dir, oddBit, spinorField, 1);
00195             sFloat *third_neighbor_spinor = spinorNeighbor(i, dir, oddBit, spinorField, 3);
00196 
00197 
00198             sFloat gaugedSpinor[mySpinorSiteSize];
00199 
00200             if (dir % 2 == 0){
00201                 su3Mul(gaugedSpinor, fatlnk, first_neighbor_spinor);
00202                 sum(&res[i*mySpinorSiteSize], &res[i*mySpinorSiteSize], gaugedSpinor, mySpinorSiteSize);            
00203                 su3Mul(gaugedSpinor, longlnk, third_neighbor_spinor);
00204                 sum(&res[i*mySpinorSiteSize], &res[i*mySpinorSiteSize], gaugedSpinor, mySpinorSiteSize);                                                                
00205             }
00206             else{
00207                 su3Tmul(gaugedSpinor, fatlnk, first_neighbor_spinor);
00208                 sub(&res[i*mySpinorSiteSize], &res[i*mySpinorSiteSize], gaugedSpinor, mySpinorSiteSize);        
00209                 
00210                 su3Tmul(gaugedSpinor, longlnk, third_neighbor_spinor);
00211                 sub(&res[i*mySpinorSiteSize], &res[i*mySpinorSiteSize], gaugedSpinor, mySpinorSiteSize);       
00212                 
00213             }               
00214         }
00215         if (daggerBit){
00216             negx(&res[i*mySpinorSiteSize], mySpinorSiteSize);
00217         }
00218     }
00219 
00220 }
00221 
00222 
00223 void staggered_dslash(void *res, void **fatlink, void** longlink, void *spinorField, int oddBit, int daggerBit,
00224                       QudaPrecision sPrecision, QudaPrecision gPrecision) {
00225     
00226     if (sPrecision == QUDA_DOUBLE_PRECISION) {
00227         if (gPrecision == QUDA_DOUBLE_PRECISION){
00228             dslashReference((double*)res, (double**)fatlink, (double**)longlink, (double*)spinorField, oddBit, daggerBit);
00229         }else{
00230             dslashReference((double*)res, (float**)fatlink, (float**)longlink, (double*)spinorField, oddBit, daggerBit);
00231         }
00232     }
00233     else{
00234         if (gPrecision == QUDA_DOUBLE_PRECISION){
00235             dslashReference((float*)res, (double**)fatlink, (double**)longlink, (float*)spinorField, oddBit, daggerBit);
00236         }else{
00237             dslashReference((float*)res, (float**)fatlink, (float**)longlink, (float*)spinorField, oddBit, daggerBit);
00238         }
00239     }
00240 }
00241 
00242 
00243 
00244 template <typename sFloat, typename gFloat>
00245 void Mat(sFloat *out, gFloat **fatlink, gFloat** longlink, sFloat *in, sFloat kappa, int daggerBit) 
00246 {
00247     sFloat *inEven = in;
00248     sFloat *inOdd  = in + Vh*mySpinorSiteSize;
00249     sFloat *outEven = out;
00250     sFloat *outOdd = out + Vh*mySpinorSiteSize;
00251     
00252     // full dslash operator
00253     dslashReference(outOdd, fatlink, longlink, inEven, 1, daggerBit);
00254     dslashReference(outEven, fatlink, longlink, inOdd, 0, daggerBit);
00255     
00256     // lastly apply the kappa term
00257     xpay(in, -kappa, out, V*mySpinorSiteSize);
00258 }
00259 
00260 
00261 void 
00262 mat(void *out, void **fatlink, void** longlink, void *in, double kappa, int dagger_bit,
00263        QudaPrecision sPrecision, QudaPrecision gPrecision) 
00264 {
00265     
00266     if (sPrecision == QUDA_DOUBLE_PRECISION){
00267         if (gPrecision == QUDA_DOUBLE_PRECISION) {
00268             Mat((double*)out, (double**)fatlink, (double**)longlink, (double*)in, (double)kappa, dagger_bit);
00269         }else {
00270             Mat((double*)out, (float**)fatlink, (float**)longlink, (double*)in, (double)kappa, dagger_bit);
00271         }
00272     }else{
00273         if (gPrecision == QUDA_DOUBLE_PRECISION){ 
00274             Mat((float*)out, (double**)fatlink, (double**)longlink, (float*)in, (float)kappa, dagger_bit);
00275         }else {
00276             Mat((float*)out, (float**)fatlink, (float**)longlink, (float*)in, (float)kappa, dagger_bit);
00277         }
00278     }
00279 }
00280 
00281 
00282 
00283 template <typename sFloat, typename gFloat>
00284 void
00285 Matdagmat_milc(sFloat *out, gFloat **fatlink, gFloat** longlink, sFloat *in, sFloat mass, int daggerBit, sFloat* tmp, MyQudaParity parity) 
00286 {
00287     
00288     sFloat msq_x4 = mass*mass*4;
00289 
00290     switch(parity){
00291     case QUDA_EVEN:
00292         {
00293             sFloat *inEven = in;
00294             sFloat *outEven = out;
00295             dslashReference(tmp, fatlink, longlink, inEven, 1, daggerBit);
00296             dslashReference(outEven, fatlink, longlink, tmp, 0, daggerBit);
00297             
00298             // lastly apply the mass term
00299             axmy(inEven, msq_x4, outEven, Vh*mySpinorSiteSize);
00300             break;
00301         }
00302     case QUDA_ODD:
00303         {
00304             sFloat *inOdd = in;
00305             sFloat *outOdd = out;
00306             dslashReference(tmp, fatlink, longlink, inOdd, 0, daggerBit);
00307             dslashReference(outOdd, fatlink, longlink, tmp, 1, daggerBit);
00308             
00309             // lastly apply the mass term
00310             axmy(inOdd, msq_x4, outOdd, Vh*mySpinorSiteSize);
00311             break;      
00312         }
00313         
00314     case QUDA_EVENODD:
00315         {
00316             sFloat *inEven = in;
00317             sFloat *inOdd = in + Vh*mySpinorSiteSize;
00318             sFloat *outEven = out;
00319             sFloat *outOdd = out + Vh*mySpinorSiteSize;
00320             sFloat *tmpEven = tmp;
00321             sFloat *tmpOdd = tmp + Vh*mySpinorSiteSize;
00322             
00323             dslashReference(tmpOdd, fatlink, longlink, inEven, 1, daggerBit);
00324             dslashReference(tmpEven, fatlink, longlink, inOdd, 0, daggerBit);
00325             
00326             dslashReference(outOdd, fatlink, longlink, tmpEven, 1, daggerBit);
00327             dslashReference(outEven, fatlink, longlink, tmpOdd, 0, daggerBit);
00328             
00329             // lastly apply the mass term
00330             axmy(in, msq_x4, out, V*mySpinorSiteSize);              
00331             break;
00332         }
00333     default:
00334         fprintf(stderr, "ERROR: invalid parity in %s,line %d\n", __FUNCTION__, __LINE__);
00335         break;
00336     }
00337     
00338 }
00339 
00340 
00341 void 
00342 matdagmat_milc(void *out, void **fatlink, void** longlink, void *in, double mass, int dagger_bit,
00343                QudaPrecision sPrecision, QudaPrecision gPrecision, void* tmp, MyQudaParity parity) 
00344 {
00345     
00346     if (sPrecision == QUDA_DOUBLE_PRECISION){
00347         if (gPrecision == QUDA_DOUBLE_PRECISION) {
00348             Matdagmat_milc((double*)out, (double**)fatlink, (double**)longlink, (double*)in, (double)mass, dagger_bit, (double*)tmp, parity);
00349         }else {
00350             Matdagmat_milc((double*)out, (float**)fatlink, (float**)longlink, (double*)in, (double)mass, dagger_bit, (double*) tmp, parity);
00351         }
00352     }else{
00353         if (gPrecision == QUDA_DOUBLE_PRECISION){ 
00354             Matdagmat_milc((float*)out, (double**)fatlink, (double**)longlink, (float*)in, (float)mass, dagger_bit, (float*)tmp, parity);
00355         }else {
00356             Matdagmat_milc((float*)out, (float**)fatlink, (float**)longlink, (float*)in, (float)mass, dagger_bit, (float*)tmp, parity);
00357         }
00358     }
00359 }
00360 
00361 
00362 
00363 // Apply the even-odd preconditioned Dirac operator
00364 template <typename sFloat, typename gFloat>
00365 static void MatPC(sFloat *outEven, gFloat **fatlink, gFloat** longlink, sFloat *inEven, sFloat kappa, 
00366               int daggerBit, MatPCType matpc_type) {
00367     
00368     sFloat *tmp = (sFloat*)malloc(Vh*mySpinorSiteSize*sizeof(sFloat));
00369     
00370     // full dslash operator
00371     if (matpc_type == QUDA_MATPC_EVEN_EVEN) {
00372         dslashReference(tmp, fatlink, longlink, inEven, 1, daggerBit);
00373         dslashReference(outEven, fatlink, longlink, tmp, 0, daggerBit);
00374 
00375         //dslashReference(outEven, fatlink, longlink, inEven, 1, daggerBit);
00376     } else {
00377         dslashReference(tmp, fatlink, longlink, inEven, 0, daggerBit);
00378         dslashReference(outEven, fatlink, longlink, tmp, 1, daggerBit);
00379     }    
00380   
00381     // lastly apply the kappa term
00382     
00383     sFloat kappa2 = -kappa*kappa;
00384     xpay(inEven, kappa2, outEven, Vh*mySpinorSiteSize);
00385     
00386     free(tmp);
00387 }
00388 
00389 
00390 void
00391 staggered_matpc(void *outEven, void **fatlink, void**longlink, void *inEven, double kappa, 
00392                 MatPCType matpc_type, int dagger_bit, QudaPrecision sPrecision, QudaPrecision gPrecision) 
00393 {
00394     
00395     if (sPrecision == QUDA_DOUBLE_PRECISION)
00396         if (gPrecision == QUDA_DOUBLE_PRECISION) {
00397             MatPC((double*)outEven, (double**)fatlink, (double**)longlink, (double*)inEven, (double)kappa, dagger_bit, matpc_type);
00398         }
00399         else{
00400             MatPC((double*)outEven, (double**)fatlink, (double**)longlink, (double*)inEven, (double)kappa, dagger_bit, matpc_type);
00401         }
00402     else {
00403         if (gPrecision == QUDA_DOUBLE_PRECISION){ 
00404             MatPC((float*)outEven, (double**)fatlink, (double**)longlink, (float*)inEven, (float)kappa, dagger_bit, matpc_type);
00405         }else{
00406             MatPC((float*)outEven, (float**)fatlink, (float**)longlink, (float*)inEven, (float)kappa, dagger_bit, matpc_type);
00407         }
00408     }
00409 }
 All Classes Files Functions Variables Typedefs Enumerations Enumerator Friends Defines