QUDA v0.4.0
A library for QCD on GPUs
quda/lib/blas_core.h
Go to the documentation of this file.
00001 
/**
 * Element-wise BLAS kernel.
 *
 * Every thread begins at its global index (blockIdx.x*blockDim.x + threadIdx.x)
 * and advances by the total thread count of the grid (gridDim.x*blockDim.x)
 * until the index reaches `length`.  Any 1-D launch configuration therefore
 * covers the whole range.  No shared memory is used.
 *
 * For each index the kernel:
 *   1. loads M FloatN values from each of the input accessors X, Y, Z, W;
 *   2. applies the functor f to the j-th component of all four, for j in [0, M);
 *   3. stores the (possibly modified) values back through XX/YY/ZZ/WW, but only
 *      for the outputs whose compile-time write flag is nonzero.
 */
template <typename FloatN, int M, int writeX, int writeY, int writeZ, int writeW,
  typename InputX, typename InputY, typename InputZ, typename InputW,
  typename OutputX, typename OutputY, typename OutputZ, typename OutputW, typename Functor>
__global__ void blasKernel(InputX X, InputY Y, InputZ Z, InputW W, Functor f,
                           OutputX XX, OutputY YY, OutputZ ZZ, OutputW WW, int length) {
  const unsigned int step = gridDim.x*blockDim.x;

  for (unsigned int idx = threadIdx.x + blockIdx.x*(blockDim.x); idx < length; idx += step) {
    // Per-index register buffers, one FloatN per component.
    FloatN xv[M], yv[M], zv[M], wv[M];

    X.load(xv, idx);
    Y.load(yv, idx);
    Z.load(zv, idx);
    W.load(wv, idx);

#pragma unroll
    for (int k = 0; k < M; ++k) {
      f(xv[k], yv[k], zv[k], wv[k]);
    }

    // Write flags are template constants, so unused stores are compiled away.
    if (writeX) XX.save(xv, idx);
    if (writeY) YY.save(yv, idx);
    if (writeZ) ZZ.save(zv, idx);
    if (writeW) WW.save(wv, idx);
  }
}
00028 
/**
 * Tunable launcher for blasKernel.
 *
 * Holds the input accessors (X, Y, Z, W), the output accessors (XX, YY, ZZ, WW)
 * and the functor f BY REFERENCE.  Every object passed to the constructor must
 * therefore stay alive until the last call to apply()/preTune()/postTune().
 *
 * Output fields selected by writeX..writeW are backed up to host memory in
 * preTune() and restored in postTune(), so that repeated tuning launches do not
 * corrupt them.
 */
template <typename FloatN, int M, int writeX, int writeY, int writeZ, int writeW,
  typename InputX, typename InputY, typename InputZ, typename InputW,
  typename OutputX, typename OutputY, typename OutputZ, typename OutputW, typename Functor>
class BlasCuda : public Tunable {

private:
  InputX &X;
  InputY &Y;
  InputZ &Z;
  InputW &W;
  OutputX &XX;
  OutputY &YY;
  OutputZ &ZZ;
  OutputW &WW;

  // host pointers used for backing up fields when tuning
  // these can't be curried into the Spinors because of Tesla argument length restriction
  char *X_h, *Y_h, *Z_h, *W_h;
  char *Xnorm_h, *Ynorm_h, *Znorm_h, *Wnorm_h;

  Functor &f;
  const int length;

  // Number of real components in one FloatN vector (the size of FloatN divided by
  // the size of its .x member).  Kept as size_t so the products below use the
  // same (size_t) arithmetic as before.
  static size_t realsPerVector() { return sizeof(FloatN)/sizeof(((FloatN*)0)->x); }

  // Size, in bytes, of the buffer used to back up / restore one output field
  // around tuning.
  size_t backupBytes() const { return XX.Precision()*realsPerVector()*M*XX.Stride(); }

  // Size, in bytes, of the norm buffer backed up alongside a field; nonzero only
  // for half precision.
  size_t backupNormBytes() const {
    return (XX.Precision() == QUDA_HALF_PRECISION) ? sizeof(float)*length : 0;
  }

  int sharedBytesPerThread() const { return 0; }
  int sharedBytesPerBlock() const { return 0; }

  // Sets param.shared_bytes for the next block size.  Both per-thread and
  // per-block requirements are 0 here.  This always returns false, presumably so
  // the tuner does not sweep shared memory as a separate dimension.
  virtual bool advanceSharedBytes(TuneParam &param) const
  {
    TuneParam next(param);
    advanceBlockDim(next); // to get next blockDim
    int nthreads = next.block.x * next.block.y * next.block.z;
    param.shared_bytes = sharedBytesPerThread()*nthreads > sharedBytesPerBlock() ?
      sharedBytesPerThread()*nthreads : sharedBytesPerBlock();
    return false;
  }

public:
  /**
   * @param X,Y,Z,W      input accessors (stored by reference)
   * @param f            element-wise functor (stored by reference)
   * @param XX,YY,ZZ,WW  output accessors (stored by reference)
   * @param length       number of kernel indices to process
   */
  BlasCuda(InputX &X, InputY &Y, InputZ &Z, InputW &W, Functor &f,
           OutputX &XX, OutputY &YY, OutputZ &ZZ, OutputW &WW, int length) :
    X(X), Y(Y), Z(Z), W(W), XX(XX), YY(YY), ZZ(ZZ), WW(WW),
    X_h(0), Y_h(0), Z_h(0), W_h(0),
    Xnorm_h(0), Ynorm_h(0), Znorm_h(0), Wnorm_h(0),
    f(f), length(length)
  { }
  virtual ~BlasCuda() { }

  // Tuning key: lattice dimensions from blasConstants, the functor's type name,
  // and the stride/precision of the output field.
  TuneKey tuneKey() const {
    std::stringstream vol, aux;
    vol << blasConstants.x[0] << "x";
    vol << blasConstants.x[1] << "x";
    vol << blasConstants.x[2] << "x";
    vol << blasConstants.x[3];
    aux << "stride=" << blasConstants.stride << ",prec=" << XX.Precision();
    return TuneKey(vol.str(), typeid(f).name(), aux.str());
  }

  // Launch blasKernel on `stream` with the launch parameters chosen by tuneLaunch.
  void apply(const cudaStream_t &stream) {
    TuneParam tp = tuneLaunch(*this, blasTuning, verbosity);
    blasKernel<FloatN,M,writeX,writeY,writeZ,writeW>
      <<<tp.grid, tp.block, tp.shared_bytes, stream>>>
      (X, Y, Z, W, f, XX, YY, ZZ, WW, length);
  }

  // Back up every output field that the kernel writes before tuning launches.
  void preTune() {
    const size_t field_bytes = backupBytes();
    const size_t norm_bytes = backupNormBytes();
    if (writeX) XX.save(&X_h, &Xnorm_h, field_bytes, norm_bytes);
    if (writeY) YY.save(&Y_h, &Ynorm_h, field_bytes, norm_bytes);
    if (writeZ) ZZ.save(&Z_h, &Znorm_h, field_bytes, norm_bytes);
    if (writeW) WW.save(&W_h, &Wnorm_h, field_bytes, norm_bytes);
  }

  // Restore the fields saved by preTune(), using the same sizes.
  void postTune() {
    const size_t field_bytes = backupBytes();
    const size_t norm_bytes = backupNormBytes();
    if (writeX) XX.load(&X_h, &Xnorm_h, field_bytes, norm_bytes);
    if (writeY) YY.load(&Y_h, &Ynorm_h, field_bytes, norm_bytes);
    if (writeZ) ZZ.load(&Z_h, &Znorm_h, field_bytes, norm_bytes);
    if (writeW) WW.load(&W_h, &Wnorm_h, field_bytes, norm_bytes);
  }

  long long flops() const { return f.flops()*realsPerVector()*length*M; }

  // Bytes moved: per-index payload (plus one float norm in half precision)
  // times the functor's stream count times length.
  long long bytes() const {
    size_t site_bytes = XX.Precision()*realsPerVector()*M;
    if (XX.Precision() == QUDA_HALF_PRECISION) site_bytes += sizeof(float);
    return f.streams()*site_bytes*length;
  }
};
00113 
/**
 * Generic host driver for the element-wise BLAS kernels.
 *
 * Applies Functor to every element of (x, y, z, w) and writes back the fields
 * whose template flag writeX/writeY/writeZ/writeW is nonzero.  A full-site-subset
 * field is handled by recursing on its even and odd parities.
 *
 * @tparam Functor  element-wise operation, constructed from (a, b, c); for the
 *                  single- and half-precision paths the coefficients are
 *                  converted to float2
 * @param kernel    not used in this body (kept for interface compatibility)
 * @param a,b,c     coefficients forwarded to the Functor constructor
 * @param x,y,z,w   fields operated on; each is checked against x with checkSpinor
 *
 * Side effects:
 * - sets blasConstants from x;
 * - launches on *blasStream;
 * - adds to quda::blas_bytes and quda::blas_flops;
 * - calls checkCudaError() afterwards.
 *
 * Half precision supports Nspin 4 and Nspin 1 only; any other value reports an
 * error through errorQuda.
 */
template <template <typename Float, typename FloatN> class Functor,
          int writeX, int writeY, int writeZ, int writeW>
void blasCuda(const int kernel, const double2 &a, const double2 &b, const double2 &c,
              cudaColorSpinorField &x, cudaColorSpinorField &y,
              cudaColorSpinorField &z, cudaColorSpinorField &w) {
  checkSpinor(x, y);
  checkSpinor(x, z);
  checkSpinor(x, w);

  for (int d=0; d<QUDA_MAX_DIM; d++) blasConstants.x[d] = x.X()[d];
  blasConstants.stride = x.Stride();

  if (x.SiteSubset() == QUDA_FULL_SITE_SUBSET) {
    blasCuda<Functor,writeX,writeY,writeZ,writeW>
      (kernel, a, b, c, x.Even(), y.Even(), z.Even(), w.Even());
    blasCuda<Functor,writeX,writeY,writeZ,writeW>
      (kernel, a, b, c, x.Odd(), y.Odd(), z.Odd(), w.Odd());
    return;
  }

  // BlasCuda stores *references* to the accessors and the functor.  Each
  // BlasCuda object is therefore created and launched inside the branch that
  // owns those locals.  Launching after the branch closes would go through
  // dangling references.
  //
  // For y, z and w, a distinct texture accessor is constructed only when the
  // field's storage differs from x (x.V() comparison); otherwise the
  // default-constructed accessor is kept.
  if (x.Precision() == QUDA_DOUBLE_PRECISION) {
    const int M = 1;
    typedef SpinorTexture<double2,double2,double2,M,0> TexX;
    typedef SpinorTexture<double2,double2,double2,M,1> TexY;
    typedef SpinorTexture<double2,double2,double2,M,2> TexZ;
    typedef SpinorTexture<double2,double2,double2,M,3> TexW;
    typedef Spinor<double2,double2,double2,M> Store;
    typedef Functor<double2, double2> Op;

    TexX xTex(x);
    TexY yTex;
    if (x.V() != y.V()) yTex = TexY(y);
    TexZ zTex;
    if (x.V() != z.V()) zTex = TexZ(z);
    TexW wTex;
    if (x.V() != w.V()) wTex = TexW(w);
    Store X(x);
    Store Y(y);
    Store Z(z);
    Store W(w);
    Op f(a,b,c);
    BlasCuda<double2,M,writeX,writeY,writeZ,writeW,
      TexX, TexY, TexZ, TexW, Store, Store, Store, Store, Op>
      blas(xTex, yTex, zTex, wTex, f, X, Y, Z, W, x.Length()/(2*M));
    blas.apply(*blasStream);
  } else if (x.Precision() == QUDA_SINGLE_PRECISION) {
    const int M = 1;
    typedef SpinorTexture<float4,float4,float4,M,0> TexX;
    typedef SpinorTexture<float4,float4,float4,M,1> TexY;
    typedef SpinorTexture<float4,float4,float4,M,2> TexZ;
    typedef SpinorTexture<float4,float4,float4,M,3> TexW;
    typedef Spinor<float4,float4,float4,M> Store;
    typedef Functor<float2, float4> Op;

    TexX xTex(x);
    TexY yTex;
    if (x.V() != y.V()) yTex = TexY(y);
    TexZ zTex;
    if (x.V() != z.V()) zTex = TexZ(z);
    TexW wTex;
    if (x.V() != w.V()) wTex = TexW(w);
    Store X(x);
    Store Y(y);
    Store Z(z);
    Store W(w);
    Op f(make_float2(a.x, a.y), make_float2(b.x, b.y), make_float2(c.x, c.y));
    BlasCuda<float4,M,writeX,writeY,writeZ,writeW,
      TexX, TexY, TexZ, TexW, Store, Store, Store, Store, Op>
      blas(xTex, yTex, zTex, wTex, f, X, Y, Z, W, x.Length()/(4*M));
    blas.apply(*blasStream);
  } else {
    if (x.Nspin() == 4){ //wilson
      typedef SpinorTexture<float4,float4,short4,6,0> TexX;
      typedef SpinorTexture<float4,float4,short4,6,1> TexY;
      typedef SpinorTexture<float4,float4,short4,6,2> TexZ;
      typedef SpinorTexture<float4,float4,short4,6,3> TexW;
      typedef Spinor<float4,float4,short4,6> Store;
      typedef Functor<float2, float4> Op;

      TexX xTex(x);
      TexY yTex;
      if (x.V() != y.V()) yTex = TexY(y);
      TexZ zTex;
      if (x.V() != z.V()) zTex = TexZ(z);
      TexW wTex;
      if (x.V() != w.V()) wTex = TexW(w);
      Store xStore(x);
      Store yStore(y);
      Store zStore(z);
      Store wStore(w);
      Op f(make_float2(a.x, a.y), make_float2(b.x, b.y), make_float2(c.x, c.y));
      BlasCuda<float4, 6, writeX, writeY, writeZ, writeW,
        TexX, TexY, TexZ, TexW, Store, Store, Store, Store, Op>
        blas(xTex, yTex, zTex, wTex, f, xStore, yStore, zStore, wStore, y.Volume());
      blas.apply(*blasStream);
    } else if (x.Nspin() == 1) {//staggered
      typedef SpinorTexture<float2,float2,short2,3,0> TexX;
      typedef SpinorTexture<float2,float2,short2,3,1> TexY;
      typedef SpinorTexture<float2,float2,short2,3,2> TexZ;
      typedef SpinorTexture<float2,float2,short2,3,3> TexW;
      typedef Spinor<float2,float2,short2,3> Store;
      typedef Functor<float2, float2> Op;

      TexX xTex(x);
      TexY yTex;
      if (x.V() != y.V()) yTex = TexY(y);
      TexZ zTex;
      if (x.V() != z.V()) zTex = TexZ(z);
      TexW wTex;
      if (x.V() != w.V()) wTex = TexW(w);
      Store xStore(x);
      Store yStore(y);
      Store zStore(z);
      Store wStore(w);
      Op f(make_float2(a.x, a.y), make_float2(b.x, b.y), make_float2(c.x, c.y));
      BlasCuda<float2, 3,writeX,writeY,writeZ,writeW,
        TexX, TexY, TexZ, TexW, Store, Store, Store, Store, Op>
        blas(xTex, yTex, zTex, wTex, f, xStore, yStore, zStore, wStore, y.Volume());
      blas.apply(*blasStream);
    } else { errorQuda("ERROR: nSpin=%d is not supported\n", x.Nspin()); }
    // half precision additionally streams one float norm per site
    quda::blas_bytes += Functor<double2,double2>::streams()*x.Volume()*sizeof(float);
  }
  quda::blas_bytes += Functor<double2,double2>::streams()*x.RealLength()*x.Precision();
  quda::blas_flops += Functor<double2,double2>::flops()*x.RealLength();

  checkCudaError();
}
00228 
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines