QUDA v0.4.0
A library for QCD on GPUs
quda/lib/hw_quda.cpp
Go to the documentation of this file.
00001 #include <stdlib.h>
00002 #include <stdio.h>
00003 
00004 #include "quda.h"
00005 #include "hw_quda.h"
00006 #include "util_quda.h"
00007 
/* Number of reals stored per lattice site in a half-Wilson field:
 * Nc * Ns * 2 = 3 * 2 * 2, matching ret.length = volume*Nc*Ns*2 with
 * Nc = 3 and Ns = 2 as set in allocateParityHw. */
#define hwSiteSize (3 * 2 * 2)
00009 
/*
 * Allocate and zero-fill the device buffer for one parity (createHwQuda
 * builds an even and an odd one) of a half-Wilson ("hw") field.
 *
 * X          full lattice extents X[0..3]; the per-parity extent along
 *            dimension 0 is X[0]/2, the other dimensions are unchanged.
 * precision  precision of the device storage; anything other than double
 *            or single is sized as short (half precision).
 *
 * Returns a ParityHw with volume = X[0]/2 * X[1] * X[2] * X[3] sites and
 * length = volume * Nc * Ns * 2 reals (Nc = 3, Ns = 2; presumably colour
 * and spin components, times 2 for re/im).
 *
 * Terminates the process with a non-zero status if cudaMalloc or
 * cudaMemset fails. Half precision is not supported yet: an error is
 * printed, the data buffer is still allocated, and dataNorm is left NULL.
 */
static ParityHw
allocateParityHw(int *X, QudaPrecision precision)
{
    ParityHw ret;

    ret.precision = precision;
    ret.X[0] = X[0]/2;
    ret.volume = X[0]/2;
    for (int d=1; d<4; d++) {
        ret.X[d] = X[d];
        ret.volume *= X[d];
    }
    ret.Nc = 3;
    ret.Ns = 2;
    ret.length = ret.volume*ret.Nc*ret.Ns*2;

    // Start from a well-defined state. dataNorm is never allocated here, and
    // freeParityHwQuda calls cudaFree(dataNorm) for half precision; that is
    // only safe (a no-op) if the pointer is NULL rather than indeterminate.
    ret.data = NULL;
    ret.dataNorm = NULL;

    if (precision == QUDA_DOUBLE_PRECISION) ret.bytes = ret.length*sizeof(double);
    else if (precision == QUDA_SINGLE_PRECISION) ret.bytes = ret.length*sizeof(float);
    else ret.bytes = ret.length*sizeof(short);

    // Every CUDA failure is fatal (not only cudaErrorMemoryAllocation), and
    // the process exits with a failing status so callers/scripts notice.
    cudaError_t err = cudaMalloc((void**)&ret.data, ret.bytes);
    if (err != cudaSuccess) {
        printf("Error allocating half wilson: %s\n", cudaGetErrorString(err));
        exit(-1);
    }

    err = cudaMemset(ret.data, 0, ret.bytes);
    if (err != cudaSuccess) {
        printf("Error clearing half wilson: %s\n", cudaGetErrorString(err));
        exit(-1);
    }

    if (precision == QUDA_HALF_PRECISION) { // FIXME not supported yet
        // TODO: half precision would also need a per-site norm buffer in
        // dataNorm; until then the field is unusable at this precision.
        printf("ERROR: half precision not supported yet in function %s\n", __FUNCTION__);
    }

    return ret;
}
00047 
00048 
/*
 * Create a full half-Wilson field on the device as two independently
 * allocated, zero-filled parities. The even parity is allocated first,
 * then the odd one.
 */
FullHw
createHwQuda(int *X, QudaPrecision precision)
{
    FullHw field;
    field.even = allocateParityHw(X, precision);
    field.odd  = allocateParityHw(X, precision);
    return field;
}
00057 
00058 
/*
 * Release the device buffers owned by one parity of a half-Wilson field.
 *
 * parity_hw is received by value, so the caller's copy still holds the
 * (now dangling) pointers afterwards. Nulling them here would only modify
 * this local copy, so no such assignment is made. Callers must not reuse
 * or free the same ParityHw again.
 *
 * dataNorm is only released for half precision, the one precision whose
 * (currently unimplemented) layout would carry a norm buffer.
 */
static void
freeParityHwQuda(ParityHw parity_hw)
{
    cudaFree(parity_hw.data);
    if (parity_hw.precision == QUDA_HALF_PRECISION) {
        cudaFree(parity_hw.dataNorm);
    }
}
00071 
/* Release both parities of a half-Wilson field, even first and then odd. */
void
freeHwQuda(FullHw hw)
{
    ParityHw parities[2] = { hw.even, hw.odd };
    for (int p = 0; p < 2; p++) {
        freeParityHwQuda(parities[p]);
    }
}
00078 
00079 
/*
 * Scatter one site's 12 reals from b into three float4 slots of a, spaced
 * Vh elements apart:
 *   a[k*Vh] = (b[4k], b[4k+1], b[4k+2], b[4k+3])   for k = 0, 1, 2
 * Values are converted from Float to float on assignment.
 */
template <typename Float>
static inline void packHwVector(float4* a, Float *b, int Vh)
{
    for (int k = 0; k < 3; k++) {
        float4 &dst = a[k*Vh];
        const Float *src = b + 4*k;
        dst.x = src[0];
        dst.y = src[1];
        dst.z = src[2];
        dst.w = src[3];
    }
}
00099 
/*
 * Scatter one site's 12 reals from b into six float2 slots of a, spaced
 * Vh elements apart:
 *   a[k*Vh] = (b[2k], b[2k+1])   for k = 0..5
 * Values are converted from Float to float on assignment.
 */
template <typename Float>
static inline void packHwVector(float2* a, Float *b, int Vh)
{
    for (int k = 0; k < 6; k++) {
        a[k*Vh].x = b[2*k];
        a[k*Vh].y = b[2*k + 1];
    }
}
00121 
/*
 * Scatter one site's 12 reals from b into six double2 slots of a, spaced
 * Vh elements apart:
 *   a[k*Vh] = (b[2k], b[2k+1])   for k = 0..5
 */
template <typename Float>
static inline void packHwVector(double2* a, Float *b, int Vh)
{
    for (int k = 0; k < 6; ++k) {
        double2 &slot = a[k*Vh];
        slot.x = b[2*k];
        slot.y = b[2*k + 1];
    }
}
00143 
00144 
/*
 * Gather one site's 12 reals into a from three float4 slots of b, spaced
 * Vh elements apart. This is the inverse of packHwVector(float4*, ...):
 *   (a[4k], a[4k+1], a[4k+2], a[4k+3]) = b[k*Vh]   for k = 0, 1, 2
 *
 * Fix: the previous body read from the destination a instead of the source
 * b, and used a nonexistent `.t` member. float4 components are x, y, z, w,
 * so any instantiation failed to compile.
 */
template <typename Float>
static inline void unpackHwVector(Float *a, float4 *b, int Vh)
{
    for (int k = 0; k < 3; k++) {
        const float4 &src = b[k*Vh];
        a[4*k + 0] = src.x;
        a[4*k + 1] = src.y;
        a[4*k + 2] = src.z;
        a[4*k + 3] = src.w;
    }
}
00163 
00164 
/*
 * Gather one site's 12 reals into a from six float2 slots of b, spaced
 * Vh elements apart (inverse of packHwVector(float2*, ...)):
 *   (a[2k], a[2k+1]) = b[k*Vh]   for k = 0..5
 */
template <typename Float>
static inline void unpackHwVector(Float *a, float2 *b, int Vh)
{
    Float *out = a;
    for (int k = 0; k < 6; ++k) {
        const float2 v = b[k*Vh];
        *out++ = v.x;
        *out++ = v.y;
    }
}
00187 
/*
 * Gather one site's 12 reals into a from six double2 slots of b, spaced
 * Vh elements apart (inverse of packHwVector(double2*, ...)):
 *   (a[2k], a[2k+1]) = b[k*Vh]   for k = 0..5
 */
template <typename Float>
static inline void unpackHwVector(Float *a, double2 *b, int Vh)
{
    Float *out = a;
    for (int k = 0; k < 6; ++k) {
        const double2 v = b[k*Vh];
        *out++ = v.x;
        *out++ = v.y;
    }
}
00210 
/*
 * Pack a whole parity of Vh sites into the GPU layout.
 *
 * The host buffer hw is site-major: hwSiteSize consecutive reals per site.
 * packHwVector writes site i's components to res[i], res[i+Vh],
 * res[i+2*Vh], ...
 */
template <typename Float, typename FloatN>
void packParityHw(FloatN *res, Float *hw, int Vh)
{
    Float *site = hw;
    for (int i = 0; i < Vh; i++, site += hwSiteSize) {
        packHwVector(res + i, site, Vh);
    }
}
00218 
/*
 * Inverse of packParityHw: rebuild the site-major host layout in res
 * (hwSiteSize consecutive reals per site) from the packed buffer hwPacked,
 * where site i's components are read at hwPacked[i], hwPacked[i+Vh], ...
 */
template <typename Float, typename FloatN>
static void unpackParityHw(Float *res, FloatN *hwPacked, int Vh) {
    Float *site = res;
    for (int i = 0; i < Vh; ++i) {
        unpackHwVector(site, hwPacked + i, Vh);
        site += hwSiteSize;
    }
}
00226 
00227 
00228 
/*
 * Upload one parity of a half-Wilson field from host memory to the device.
 *
 * ret       destination parity; ret.data must already hold ret.bytes bytes
 *           (as set up by allocateParityHw).
 * hw        host buffer of ret.volume sites, hwSiteSize reals each, read as
 *           double if cpu_prec is QUDA_DOUBLE_PRECISION and float otherwise.
 * cpu_prec  precision of the host buffer.
 *
 * The host data is reordered by packParityHw into a pinned staging buffer
 * of the device precision, which is then copied with cudaMemcpy.
 * Double device precision requires double host precision. Half device
 * precision is not implemented: an error is printed and nothing is copied.
 * Any CUDA failure terminates the process.
 */
static void
loadParityHw(ParityHw ret, void *hw, QudaPrecision cpu_prec)
{
    if (ret.precision == QUDA_DOUBLE_PRECISION && cpu_prec != QUDA_DOUBLE_PRECISION) {
        printf("Error, cannot have CUDA double precision without double CPU precision\n");
        exit(-1);
    }

    if (ret.precision == QUDA_HALF_PRECISION) {
        // TODO: load into a temporary single-precision ParityHw and convert
        // it on the device once half precision is supported.
        printf("ERROR: half precision not supported yet in function %s\n", __FUNCTION__);
        return;
    }

    void *packedHw1 = 0;
    cudaError_t err = cudaMallocHost(&packedHw1, ret.bytes);
    if (err != cudaSuccess) {
        printf("Error allocating pinned staging buffer in %s: %s\n", __FUNCTION__, cudaGetErrorString(err));
        exit(-1);
    }

    if (ret.precision == QUDA_DOUBLE_PRECISION) {
        packParityHw((double2*)packedHw1, (double*)hw, ret.volume);
    } else if (cpu_prec == QUDA_DOUBLE_PRECISION) {
        packParityHw((float2*)packedHw1, (double*)hw, ret.volume);
    } else {
        packParityHw((float2*)packedHw1, (float*)hw, ret.volume);
    }

    err = cudaMemcpy(ret.data, packedHw1, ret.bytes, cudaMemcpyHostToDevice);
    cudaFreeHost(packedHw1);
    if (err != cudaSuccess) {
        printf("Error copying half wilson to device in %s: %s\n", __FUNCTION__, cudaGetErrorString(err));
        exit(-1);
    }
}
00266 
00267 
/*
 * Upload a full half-Wilson field from host memory to the device.
 *
 * hw holds the even parity followed immediately by the odd parity, each
 * ret.even.length reals long. The odd half is located using the same
 * element size that loadParityHw reads: double when cpu_prec is
 * QUDA_DOUBLE_PRECISION, float otherwise. Previously any precision other
 * than single was offset as double, which disagreed with the reader.
 */
void
loadHwToGPU(FullHw ret, void *hw, QudaPrecision cpu_prec)
{
    void *hw_odd;
    if (cpu_prec == QUDA_DOUBLE_PRECISION) {
        hw_odd = ((double*)hw) + ret.even.length;
    } else {
        hw_odd = ((float*)hw) + ret.even.length;
    }

    loadParityHw(ret.even, hw, cpu_prec);
    loadParityHw(ret.odd, hw_odd, cpu_prec);
}
00282 
/*
 * Download one parity of a half-Wilson field from the device into host memory.
 *
 * res       host destination for hw.volume sites, hwSiteSize reals each,
 *           written as double if cpu_prec is QUDA_DOUBLE_PRECISION and as
 *           float otherwise.
 * hw        source parity on the device.
 * cpu_prec  precision of the host buffer.
 *
 * The device buffer is copied into a pinned staging buffer and then
 * reordered by unpackParityHw into the site-major host layout.
 * Double device precision requires double host precision, the same rule
 * loadParityHw enforces. Half precision is not implemented: an error is
 * printed and res is left untouched. Any CUDA failure terminates the process.
 */
static void
retrieveParityHw(void *res, ParityHw hw, QudaPrecision cpu_prec)
{
    // Without this check a double-precision device field would be unpacked
    // as doubles into a host buffer laid out as floats, overrunning it.
    if (hw.precision == QUDA_DOUBLE_PRECISION && cpu_prec != QUDA_DOUBLE_PRECISION) {
        printf("Error, cannot have CUDA double precision without double CPU precision\n");
        exit(-1);
    }

    if (hw.precision == QUDA_HALF_PRECISION) {
        // TODO: copy into a temporary single-precision ParityHw on the device
        // and retrieve that once half precision is supported.
        printf("ERROR: half precision not supported yet in function %s\n", __FUNCTION__);
        return;
    }

    void *packedHw1 = 0;
    cudaError_t err = cudaMallocHost((void**)&packedHw1, hw.bytes);
    if (err != cudaSuccess) {
        printf("Error allocating pinned staging buffer in %s: %s\n", __FUNCTION__, cudaGetErrorString(err));
        exit(-1);
    }

    err = cudaMemcpy(packedHw1, hw.data, hw.bytes, cudaMemcpyDeviceToHost);
    if (err != cudaSuccess) {
        cudaFreeHost(packedHw1);
        printf("Error copying half wilson to host in %s: %s\n", __FUNCTION__, cudaGetErrorString(err));
        exit(-1);
    }

    if (hw.precision == QUDA_DOUBLE_PRECISION) {
        unpackParityHw((double*)res, (double2*)packedHw1, hw.volume);
    } else if (cpu_prec == QUDA_DOUBLE_PRECISION) {
        unpackParityHw((double*)res, (float2*)packedHw1, hw.volume);
    } else {
        unpackParityHw((float*)res, (float2*)packedHw1, hw.volume);
    }

    cudaFreeHost(packedHw1);
}
00313 
00314 
/*
 * Download a full half-Wilson field from the device into host memory.
 *
 * res receives the even parity followed immediately by the odd parity, each
 * hw.even.length reals long. The odd half's offset uses the same element
 * size that retrieveParityHw writes: double when cpu_prec is
 * QUDA_DOUBLE_PRECISION, float otherwise. Previously any precision other
 * than single was offset as double, which disagreed with the writer.
 */
void
retrieveHwField(void *res, FullHw hw, QudaPrecision cpu_prec)
{
    void *res_odd;
    if (cpu_prec == QUDA_DOUBLE_PRECISION) res_odd = (double*)res + hw.even.length;
    else res_odd = (float*)res + hw.even.length;

    retrieveParityHw(res, hw.even, cpu_prec);
    retrieveParityHw(res_odd, hw.odd, cpu_prec);
}
00326 
00327 
00328 /*
00329 void hwHalfPack(float *c, short *s0, float *f0, int V) {
00330 
00331   float *f = f0;
00332   short *s = s0;
00333   for (int i=0; i<24*V; i+=24) {
00334     c[i] = sqrt(f[0]*f[0] + f[1]*f[1]);
00335     for (int j=0; j<24; j+=2) {
00336       float k = sqrt(f[j]*f[j] + f[j+1]*f[j+1]);
00337       if (k > c[i]) c[i] = k;
00338     }
00339 
00340     for (int j=0; j<24; j++) s[j] = (short)(MAX_SHORT*f[j]/c[i]);
00341     f+=24;
00342     s+=24;
00343   }
00344 
00345 }
00346 
00347 void hwHalfUnpack(float *f0, float *c, short *s0, int V) {
00348   float *f = f0;
00349   short *s = s0;
00350   for (int i=0; i<24*V; i+=24) {
00351     for (int j=0; j<24; j++) f[j] = s[j] * (c[i] / MAX_SHORT);
00352     f+=24;
00353     s+=24;
00354   }
00355 
00356 }
00357 */
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines