QUDA v0.4.0
A library for QCD on GPUs
|
/**
 * Generic element-wise BLAS kernel.
 *
 * Each thread starts at its flat global index and advances by
 * gridDim.x*blockDim.x until it reaches `length`, so correctness does not
 * depend on the chosen grid size. For every index it loads M FloatN values
 * from each of X, Y, Z and W and applies the functor f to each of the M
 * (x, y, z, w) tuples. It then stores back only the operands whose
 * write flag is non-zero. No shared memory is used.
 */
template <typename FloatN, int M, int writeX, int writeY, int writeZ, int writeW,
  typename InputX, typename InputY, typename InputZ, typename InputW,
  typename OutputX, typename OutputY, typename OutputZ, typename OutputW, typename Functor>
__global__ void blasKernel(InputX X, InputY Y, InputZ Z, InputW W, Functor f,
                           OutputX XX, OutputY YY, OutputZ ZZ, OutputW WW, int length) {
  unsigned int i = blockIdx.x*(blockDim.x) + threadIdx.x;
  unsigned int gridSize = gridDim.x*blockDim.x;
  // explicit cast: same unsigned comparison the implicit conversion performed
  while (i < static_cast<unsigned int>(length)) {
    FloatN x[M], y[M], z[M], w[M];
    X.load(x, i);
    Y.load(y, i);
    Z.load(z, i);
    W.load(w, i);

#pragma unroll
    for (int j=0; j<M; j++) f(x[j], y[j], z[j], w[j]);

    if (writeX) XX.save(x, i);
    if (writeY) YY.save(y, i);
    if (writeZ) ZZ.save(z, i);
    if (writeW) WW.save(w, i);
    i += gridSize;
  }
}

/**
 * Tunable launcher for blasKernel.
 *
 * All operands and the functor are held BY REFERENCE: every object passed to
 * the constructor must stay alive for as long as this instance is used
 * (in particular across apply()).
 *
 * Before each tuning trial the output fields that will be written are backed
 * up to host (preTune), and they are restored afterwards (postTune). This
 * keeps the trial launches from corrupting the data.
 */
template <typename FloatN, int M, int writeX, int writeY, int writeZ, int writeW,
  typename InputX, typename InputY, typename InputZ, typename InputW,
  typename OutputX, typename OutputY, typename OutputZ, typename OutputW, typename Functor>
class BlasCuda : public Tunable {

private:
  InputX &X;
  InputY &Y;
  InputZ &Z;
  InputW &W;
  OutputX &XX;
  OutputY &YY;
  OutputZ &ZZ;
  OutputW &WW;

  // host pointers used for backing up fields when tuning
  // these can't be curried into the Spinors because of Tesla argument length restriction
  char *X_h, *Y_h, *Z_h, *W_h;
  char *Xnorm_h, *Ynorm_h, *Znorm_h, *Wnorm_h;

  Functor &f;
  const int length;

  // number of real components in one FloatN (e.g. 2 for double2, 4 for float4)
  static size_t nComponents() { return sizeof(FloatN)/sizeof(((FloatN*)0)->x); }

  // bytes of field data saved/restored per output field around tuning
  size_t backupBytes() const { return XX.Precision()*nComponents()*M*XX.Stride(); }

  // bytes of norm data saved/restored per output field (non-zero only for half precision)
  size_t backupNormBytes() const {
    return (XX.Precision() == QUDA_HALF_PRECISION) ? sizeof(float)*length : 0;
  }

  int sharedBytesPerThread() const { return 0; }
  int sharedBytesPerBlock() const { return 0; }

  // Shared memory is never varied during tuning: set the requirement for the
  // next block size and report that there is nothing further to advance.
  virtual bool advanceSharedBytes(TuneParam &param) const
  {
    TuneParam next(param);
    advanceBlockDim(next); // to get next blockDim
    int nthreads = next.block.x * next.block.y * next.block.z;
    param.shared_bytes = sharedBytesPerThread()*nthreads > sharedBytesPerBlock() ?
      sharedBytesPerThread()*nthreads : sharedBytesPerBlock();
    return false;
  }

public:
  BlasCuda(InputX &X, InputY &Y, InputZ &Z, InputW &W, Functor &f,
           OutputX &XX, OutputY &YY, OutputZ &ZZ, OutputW &WW, int length) :
    X(X), Y(Y), Z(Z), W(W), XX(XX), YY(YY), ZZ(ZZ), WW(WW),
    X_h(0), Y_h(0), Z_h(0), W_h(0),
    Xnorm_h(0), Ynorm_h(0), Znorm_h(0), Wnorm_h(0),
    f(f), length(length)
  { ; }
  virtual ~BlasCuda() { }

  // Tuning key: lattice dimensions from blasConstants, the functor's type
  // name, and the stride/precision of the output field.
  TuneKey tuneKey() const {
    std::stringstream vol, aux;
    vol << blasConstants.x[0] << "x";
    vol << blasConstants.x[1] << "x";
    vol << blasConstants.x[2] << "x";
    vol << blasConstants.x[3];
    aux << "stride=" << blasConstants.stride << ",prec=" << XX.Precision();
    return TuneKey(vol.str(), typeid(f).name(), aux.str());
  }

  // Launch blasKernel on `stream` with the (possibly tuned) launch parameters.
  void apply(const cudaStream_t &stream) {
    TuneParam tp = tuneLaunch(*this, blasTuning, verbosity);
    blasKernel<FloatN,M,writeX,writeY,writeZ,writeW>
      <<<tp.grid, tp.block, tp.shared_bytes, stream>>>
      (X, Y, Z, W, f, XX, YY, ZZ, WW, length);
  }

  // Back up every output field that the kernel writes.
  void preTune() {
    size_t bytes = backupBytes();
    size_t norm_bytes = backupNormBytes();
    if (writeX) XX.save(&X_h, &Xnorm_h, bytes, norm_bytes);
    if (writeY) YY.save(&Y_h, &Ynorm_h, bytes, norm_bytes);
    if (writeZ) ZZ.save(&Z_h, &Znorm_h, bytes, norm_bytes);
    if (writeW) WW.save(&W_h, &Wnorm_h, bytes, norm_bytes);
  }

  // Restore the output fields backed up by preTune().
  void postTune() {
    size_t bytes = backupBytes();
    size_t norm_bytes = backupNormBytes();
    if (writeX) XX.load(&X_h, &Xnorm_h, bytes, norm_bytes);
    if (writeY) YY.load(&Y_h, &Ynorm_h, bytes, norm_bytes);
    if (writeZ) ZZ.load(&Z_h, &Znorm_h, bytes, norm_bytes);
    if (writeW) WW.load(&W_h, &Wnorm_h, bytes, norm_bytes);
  }

  long long flops() const { return f.flops()*nComponents()*length*M; }
  long long bytes() const {
    size_t site_bytes = XX.Precision()*nComponents()*M;
    if (XX.Precision() == QUDA_HALF_PRECISION) site_bytes += sizeof(float);
    return f.streams()*site_bytes*length;
  }
};

/**
 * Build a BlasCuda launcher on the stack and run it on *blasStream.
 *
 * Keeping the launcher local to this call guarantees that the operand and
 * functor objects it references (owned by the caller) are still alive
 * when the kernel is launched. All template arguments except FloatN, M and
 * the write flags are deduced from the arguments.
 */
template <typename FloatN, int M, int writeX, int writeY, int writeZ, int writeW,
  typename InputX, typename InputY, typename InputZ, typename InputW,
  typename OutputX, typename OutputY, typename OutputZ, typename OutputW, typename Functor>
void launchBlasCuda(InputX &X, InputY &Y, InputZ &Z, InputW &W, Functor &f,
                    OutputX &XX, OutputY &YY, OutputZ &ZZ, OutputW &WW, int length) {
  BlasCuda<FloatN,M,writeX,writeY,writeZ,writeW,
    InputX,InputY,InputZ,InputW,OutputX,OutputY,OutputZ,OutputW,Functor>
    blas(X, Y, Z, W, f, XX, YY, ZZ, WW, length);
  blas.apply(*blasStream);
}

/**
 * Host driver for the generic element-wise BLAS kernel.
 *
 * - A full-site-subset field is handled by recursing separately on its even
 *   and odd halves.
 * - Otherwise the call dispatches on precision:
 *   - double: double2 storage, M = 1;
 *   - single: float4 storage, M = 1;
 *   - half: short4 storage with M = 6 for Nspin == 4, or short2 storage with
 *     M = 3 for Nspin == 1.
 * - The y, z and w texture readers are only constructed from their field when
 *   that field's V() differs from x's. Otherwise they are left
 *   default-constructed. NOTE(review): confirm how a default SpinorTexture
 *   reads aliased data.
 * - Expected memory traffic and flops are accumulated into quda::blas_bytes
 *   and quda::blas_flops.
 */
template <template <typename Float, typename FloatN> class Functor,
  int writeX, int writeY, int writeZ, int writeW>
void blasCuda(const int kernel, const double2 &a, const double2 &b, const double2 &c,
              cudaColorSpinorField &x, cudaColorSpinorField &y,
              cudaColorSpinorField &z, cudaColorSpinorField &w) {
  checkSpinor(x, y);
  checkSpinor(x, z);
  checkSpinor(x, w);

  for (int d=0; d<QUDA_MAX_DIM; d++) blasConstants.x[d] = x.X()[d];
  blasConstants.stride = x.Stride();

  if (x.SiteSubset() == QUDA_FULL_SITE_SUBSET) {
    blasCuda<Functor,writeX,writeY,writeZ,writeW>
      (kernel, a, b, c, x.Even(), y.Even(), z.Even(), w.Even());
    blasCuda<Functor,writeX,writeY,writeZ,writeW>
      (kernel, a, b, c, x.Odd(), y.Odd(), z.Odd(), w.Odd());
    return;
  }

  // Each branch launches while its local operand/functor objects are still in
  // scope: BlasCuda only holds references to them.
  if (x.Precision() == QUDA_DOUBLE_PRECISION) {
    const int M = 1;
    SpinorTexture<double2,double2,double2,M,0> xTex(x);
    SpinorTexture<double2,double2,double2,M,1> yTex;
    if (x.V() != y.V()) yTex = SpinorTexture<double2,double2,double2,M,1>(y);
    SpinorTexture<double2,double2,double2,M,2> zTex;
    if (x.V() != z.V()) zTex = SpinorTexture<double2,double2,double2,M,2>(z);
    SpinorTexture<double2,double2,double2,M,3> wTex;
    if (x.V() != w.V()) wTex = SpinorTexture<double2,double2,double2,M,3>(w);
    Spinor<double2,double2,double2,M> X(x);
    Spinor<double2,double2,double2,M> Y(y);
    Spinor<double2,double2,double2,M> Z(z);
    Spinor<double2,double2,double2,M> W(w);
    Functor<double2, double2> f(a,b,c);
    launchBlasCuda<double2,M,writeX,writeY,writeZ,writeW>
      (xTex, yTex, zTex, wTex, f, X, Y, Z, W, x.Length()/(2*M));
  } else if (x.Precision() == QUDA_SINGLE_PRECISION) {
    const int M = 1;
    SpinorTexture<float4,float4,float4,M,0> xTex(x);
    SpinorTexture<float4,float4,float4,M,1> yTex;
    if (x.V() != y.V()) yTex = SpinorTexture<float4,float4,float4,M,1>(y);
    SpinorTexture<float4,float4,float4,M,2> zTex;
    if (x.V() != z.V()) zTex = SpinorTexture<float4,float4,float4,M,2>(z);
    SpinorTexture<float4,float4,float4,M,3> wTex;
    if (x.V() != w.V()) wTex = SpinorTexture<float4,float4,float4,M,3>(w);
    Spinor<float4,float4,float4,M> X(x);
    Spinor<float4,float4,float4,M> Y(y);
    Spinor<float4,float4,float4,M> Z(z);
    Spinor<float4,float4,float4,M> W(w);
    Functor<float2, float4> f(make_float2(a.x, a.y), make_float2(b.x, b.y), make_float2(c.x, c.y));
    launchBlasCuda<float4,M,writeX,writeY,writeZ,writeW>
      (xTex, yTex, zTex, wTex, f, X, Y, Z, W, x.Length()/(4*M));
  } else {
    if (x.Nspin() == 4) { //wilson
      SpinorTexture<float4,float4,short4,6,0> xTex(x);
      SpinorTexture<float4,float4,short4,6,1> yTex;
      if (x.V() != y.V()) yTex = SpinorTexture<float4,float4,short4,6,1>(y);
      SpinorTexture<float4,float4,short4,6,2> zTex;
      if (x.V() != z.V()) zTex = SpinorTexture<float4,float4,short4,6,2>(z);
      SpinorTexture<float4,float4,short4,6,3> wTex;
      if (x.V() != w.V()) wTex = SpinorTexture<float4,float4,short4,6,3>(w);
      Spinor<float4,float4,short4,6> xStore(x);
      Spinor<float4,float4,short4,6> yStore(y);
      Spinor<float4,float4,short4,6> zStore(z);
      Spinor<float4,float4,short4,6> wStore(w);
      Functor<float2, float4> f(make_float2(a.x, a.y), make_float2(b.x, b.y), make_float2(c.x, c.y));
      launchBlasCuda<float4,6,writeX,writeY,writeZ,writeW>
        (xTex, yTex, zTex, wTex, f, xStore, yStore, zStore, wStore, y.Volume());
    } else if (x.Nspin() == 1) { //staggered
      SpinorTexture<float2,float2,short2,3,0> xTex(x);
      SpinorTexture<float2,float2,short2,3,1> yTex;
      if (x.V() != y.V()) yTex = SpinorTexture<float2,float2,short2,3,1>(y);
      SpinorTexture<float2,float2,short2,3,2> zTex;
      if (x.V() != z.V()) zTex = SpinorTexture<float2,float2,short2,3,2>(z);
      SpinorTexture<float2,float2,short2,3,3> wTex;
      if (x.V() != w.V()) wTex = SpinorTexture<float2,float2,short2,3,3>(w);
      Spinor<float2,float2,short2,3> xStore(x);
      Spinor<float2,float2,short2,3> yStore(y);
      Spinor<float2,float2,short2,3> zStore(z);
      Spinor<float2,float2,short2,3> wStore(w);
      Functor<float2, float2> f(make_float2(a.x, a.y), make_float2(b.x, b.y), make_float2(c.x, c.y));
      launchBlasCuda<float2,3,writeX,writeY,writeZ,writeW>
        (xTex, yTex, zTex, wTex, f, xStore, yStore, zStore, wStore, y.Volume());
    } else { errorQuda("ERROR: nSpin=%d is not supported\n", x.Nspin()); }
    // half precision also streams one float norm per site
    quda::blas_bytes += Functor<double2,double2>::streams()*x.Volume()*sizeof(float);
  }
  quda::blas_bytes += Functor<double2,double2>::streams()*x.RealLength()*x.Precision();
  quda::blas_flops += Functor<double2,double2>::flops()*x.RealLength();

  checkCudaError();
}