|
QUDA v0.3.2
A library for QCD on GPUs
|
00001 #include <quda_internal.h> 00002 #include <quda.h> 00003 00004 #include <iostream> 00005 #include <complex> 00006 typedef std::complex<double> Complex; 00007 00008 #ifndef _COLOR_SPINOR_FIELD_H 00009 #define _COLOR_SPINOR_FIELD_H 00010 00011 // Probably want some checking for this limit 00012 #define QUDA_MAX_DIM 6 00013 00014 class ColorSpinorParam { 00015 public: 00016 QudaFieldLocation fieldLocation; // cpu, cuda etc. 00017 int nColor; // Number of colors of the field 00018 int nSpin; // =1 for staggered, =2 for coarse Dslash, =4 for 4d spinor 00019 int nDim; // number of spacetime dimensions 00020 int x[QUDA_MAX_DIM]; // size of each dimension 00021 QudaPrecision precision; // Precision of the field 00022 int pad; // volumetric padding 00023 00024 QudaTwistFlavorType twistFlavor; // used by twisted mass 00025 00026 QudaSiteSubset siteSubset; // Full, even or odd 00027 QudaSiteOrder siteOrder; // defined for full fields 00028 00029 QudaFieldOrder fieldOrder; // Float, Float2, Float4 etc. 00030 QudaGammaBasis gammaBasis; 00031 QudaFieldCreate create; // 00032 00033 void *v; // pointer to field 00034 void *norm; 00035 00036 ColorSpinorParam() 00037 : fieldLocation(QUDA_INVALID_FIELD_LOCATION), nColor(0), nSpin(0), nDim(0), 00038 precision(QUDA_INVALID_PRECISION), pad(0), twistFlavor(QUDA_TWIST_INVALID), 00039 siteSubset(QUDA_INVALID_SITE_SUBSET), siteOrder(QUDA_INVALID_SITE_ORDER), 00040 fieldOrder(QUDA_INVALID_FIELD_ORDER), gammaBasis(QUDA_INVALID_GAMMA_BASIS), 00041 create(QUDA_INVALID_FIELD_CREATE) 00042 { for(int d=0; d<QUDA_MAX_DIM; d++) x[d] = 0;} 00043 00044 // used to create cpu params 00045 ColorSpinorParam(void *V, QudaInvertParam &inv_param, int *X) 00046 : fieldLocation(QUDA_CPU_FIELD_LOCATION), nColor(3), nSpin(4), nDim(4), 00047 precision(inv_param.cpu_prec), pad(0), twistFlavor(inv_param.twist_flavor), 00048 siteSubset(QUDA_INVALID_SITE_SUBSET), siteOrder(QUDA_INVALID_SITE_ORDER), 00049 fieldOrder(QUDA_INVALID_FIELD_ORDER), gammaBasis(QUDA_DEGRAND_ROSSI_GAMMA_BASIS), 00050 create(QUDA_REFERENCE_FIELD_CREATE), v(V) 00051 { 00052 00053 if (nDim > QUDA_MAX_DIM) errorQuda("Number of dimensions too great"); 00054 for (int d=0; d<nDim; d++) x[d] = X[d]; 00055 00056 if (inv_param.dslash_type == QUDA_DOMAIN_WALL_DSLASH) { 00057 nDim++; 00058 x[4] = inv_param.Ls; 00059 } 00060 00061 if (inv_param.dirac_order == QUDA_CPS_WILSON_DIRAC_ORDER) { 00062 fieldOrder = QUDA_SPACE_SPIN_COLOR_FIELD_ORDER; 00063 siteOrder = QUDA_ODD_EVEN_SITE_ORDER; 00064 } else if (inv_param.dirac_order == QUDA_QDP_DIRAC_ORDER) { 00065 fieldOrder = QUDA_SPACE_COLOR_SPIN_FIELD_ORDER; 00066 siteOrder = QUDA_EVEN_ODD_SITE_ORDER; 00067 } else if (inv_param.dirac_order == QUDA_DIRAC_ORDER) { 00068 fieldOrder = QUDA_SPACE_SPIN_COLOR_FIELD_ORDER; 00069 siteOrder = QUDA_EVEN_ODD_SITE_ORDER; 00070 } else { 00071 errorQuda("Dirac order %d not supported", inv_param.dirac_order); 00072 } 00073 } 00074 00075 // used to create cuda param from a cpu param 00076 ColorSpinorParam(ColorSpinorParam &cpuParam, QudaInvertParam &inv_param) 00077 : fieldLocation(QUDA_CUDA_FIELD_LOCATION), nColor(cpuParam.nColor), nSpin(cpuParam.nSpin), 00078 nDim(cpuParam.nDim), precision(inv_param.cuda_prec), pad(inv_param.sp_pad), 00079 twistFlavor(cpuParam.twistFlavor), siteSubset(cpuParam.siteSubset), 00080 siteOrder(QUDA_EVEN_ODD_SITE_ORDER), fieldOrder(QUDA_INVALID_FIELD_ORDER), 00081 gammaBasis(QUDA_UKQCD_GAMMA_BASIS), create(QUDA_COPY_FIELD_CREATE), v(0) 00082 { 00083 if (nDim > QUDA_MAX_DIM) errorQuda("Number of dimensions too great"); 00084 for (int d=0; d<nDim; d++) x[d] = cpuParam.x[d]; 00085 00086 if (precision == QUDA_DOUBLE_PRECISION) { 00087 fieldOrder = QUDA_FLOAT2_FIELD_ORDER; 00088 } else { 00089 fieldOrder = QUDA_FLOAT4_FIELD_ORDER; 00090 } 00091 00092 } 00093 00094 void print() { 00095 printfQuda("fieldLocation = %d\n", fieldLocation); 00096 printfQuda("nColor = %d\n", nColor); 00097 printfQuda("nSpin = %d\n", nSpin); 00098 printfQuda("twistFlavor = %d\n", twistFlavor); 00099 printfQuda("nDim = %d\n", nDim); 00100 for (int d=0; d<nDim; d++) printfQuda("x[%d] = %d\n", d, x[d]); 00101 printfQuda("precision = %d\n", precision); 00102 printfQuda("pad = %d\n", pad); 00103 printfQuda("siteSubset = %d\n", siteSubset); 00104 printfQuda("siteOrder = %d\n", siteOrder); 00105 printfQuda("fieldOrder = %d\n", fieldOrder); 00106 printfQuda("gammaBasis = %d\n", gammaBasis); 00107 printfQuda("create = %d\n", create); 00108 printfQuda("v = %lx\n", (unsigned long)v); 00109 printfQuda("norm = %lx\n", (unsigned long)norm); 00110 } 00111 00112 virtual ~ColorSpinorParam() { 00113 } 00114 }; 00115 00116 class ColorSpinorField { 00117 00118 private: 00119 void create(int nDim, const int *x, int Nc, int Ns, QudaTwistFlavorType Twistflavor, 00120 QudaPrecision precision, int pad, QudaFieldLocation location, QudaSiteSubset subset, 00121 QudaSiteOrder siteOrder, QudaFieldOrder fieldOrder, QudaGammaBasis gammaBasis); 00122 void destroy(); 00123 00124 protected: 00125 bool init; 00126 QudaPrecision precision; 00127 00128 int nColor; 00129 int nSpin; 00130 00131 int nDim; 00132 int x[QUDA_MAX_DIM]; 00133 00134 int volume; 00135 int pad; 00136 int stride; 00137 00138 QudaTwistFlavorType twistFlavor; 00139 00140 int real_length; 00141 int length; 00142 size_t bytes; 00143 00144 QudaFieldLocation fieldLocation; 00145 QudaSiteSubset siteSubset; 00146 QudaSiteOrder siteOrder; 00147 QudaFieldOrder fieldOrder; 00148 QudaGammaBasis gammaBasis; 00149 00150 // in the case of full fields, these are references to the even / odd sublattices 00151 ColorSpinorField *even; 00152 ColorSpinorField *odd; 00153 00154 // resets the above attributes based on contents of param 00155 void reset(const ColorSpinorParam &); 00156 void fill(ColorSpinorParam &); 00157 static void checkField(const ColorSpinorField &, const ColorSpinorField &); 00158 00159 public: 00160 //ColorSpinorField(); 00161 ColorSpinorField(const ColorSpinorField &); 00162 ColorSpinorField(const ColorSpinorParam &); 00163 00164 virtual ~ColorSpinorField(); 00165 00166 ColorSpinorField& operator=(const ColorSpinorField &); 00167 00168 QudaPrecision Precision() const { return precision; } 00169 int Ncolor() const { return nColor; } 00170 int Nspin() const { return nSpin; } 00171 int TwistFlavor() const { return twistFlavor; } 00172 int Ndim() const { return nDim; } 00173 int X(int d) const { return x[d]; } 00174 int Length() const { return length; } 00175 int Stride() const { return stride; } 00176 int Volume() const { return volume; } 00177 void PrintDims() const { printf("dimensions=%d %d %d %d\n", 00178 x[0], x[1], x[2], x[3]);} 00179 00180 QudaFieldLocation FieldLocation() const { return fieldLocation; } 00181 QudaSiteSubset SiteSubset() const { return siteSubset; } 00182 QudaSiteOrder SiteOrder() const { return siteOrder; } 00183 QudaFieldOrder FieldOrder() const { return fieldOrder; } 00184 QudaGammaBasis GammaBasis() const { return gammaBasis; } 00185 00186 friend std::ostream& operator<<(std::ostream &out, const ColorSpinorField &); 00187 }; 00188 00189 class cpuColorSpinorField; 00190 00191 // CUDA implementation 00192 class cudaColorSpinorField : public ColorSpinorField { 00193 00194 friend class cpuColorSpinorField; 00195 00196 friend double normEven(const cudaColorSpinorField &b); 00197 00198 friend class Dirac; 00199 friend class DiracWilson; 00200 friend class DiracClover; 00201 friend class DiracCloverPC; 00202 friend class DiracDomainWall; 00203 friend class DiracDomainWallPC; 00204 friend class DiracStaggered; 00205 friend class DiracStaggeredPC; 00206 friend class DiracTwistedMass; 00207 friend class DiracTwistedMassPC; 00208 friend void zeroCuda(cudaColorSpinorField &a); 00209 friend void copyCuda(cudaColorSpinorField &, const cudaColorSpinorField &); 00210 friend double axpyNormCuda(const double &a, cudaColorSpinorField &x, cudaColorSpinorField &y); 00211 friend double sumCuda(cudaColorSpinorField &b); 00212 friend double normCuda(const cudaColorSpinorField &b); 00213 friend double reDotProductCuda(cudaColorSpinorField &a, cudaColorSpinorField &b); 00214 friend double xmyNormCuda(cudaColorSpinorField &a, cudaColorSpinorField &b); 00215 friend void axpbyCuda(const double &a, cudaColorSpinorField &x, const double &b, cudaColorSpinorField &y); 00216 friend void axpyCuda(const double &a, cudaColorSpinorField &x, cudaColorSpinorField &y); 00217 friend void axCuda(const double &a, cudaColorSpinorField &x); 00218 friend void xpyCuda(cudaColorSpinorField &x, cudaColorSpinorField &y); 00219 friend void xpayCuda(const cudaColorSpinorField &x, const double &a, cudaColorSpinorField &y); 00220 friend void mxpyCuda(cudaColorSpinorField &x, cudaColorSpinorField &y); 00221 friend void axpyZpbxCuda(const double &a, cudaColorSpinorField &x, cudaColorSpinorField &y, 00222 cudaColorSpinorField &z, const double &b); 00223 friend void axpyBzpcxCuda(const double &a, cudaColorSpinorField& x, cudaColorSpinorField& y, 00224 const double &b, cudaColorSpinorField& z, const double &c); 00225 00226 friend void caxpbyCuda(const Complex &a, cudaColorSpinorField &x, const Complex &b, cudaColorSpinorField &y); 00227 friend void caxpyCuda(const Complex &a, cudaColorSpinorField &x, cudaColorSpinorField &y); 00228 friend void cxpaypbzCuda(cudaColorSpinorField &, const Complex &b, cudaColorSpinorField &y, 00229 const Complex &c, cudaColorSpinorField &z); 00230 friend void caxpbypzYmbwCuda(const Complex &, cudaColorSpinorField &, const Complex &, cudaColorSpinorField &, 00231 cudaColorSpinorField &, cudaColorSpinorField &); 00232 friend Complex cDotProductCuda(cudaColorSpinorField &, cudaColorSpinorField &); 00233 friend Complex xpaycDotzyCuda(cudaColorSpinorField &x, const double &a, cudaColorSpinorField &y, 00234 cudaColorSpinorField &z); 00235 friend double3 cDotProductNormACuda(cudaColorSpinorField &a, cudaColorSpinorField &b); 00236 friend double3 cDotProductNormBCuda(cudaColorSpinorField &a, cudaColorSpinorField &b); 00237 friend double3 caxpbypzYmbwcDotProductWYNormYCuda(const Complex &a, cudaColorSpinorField &x, const Complex &b, 00238 cudaColorSpinorField &y, cudaColorSpinorField &z, 00239 cudaColorSpinorField &w, cudaColorSpinorField &u); 00240 00241 private: 00242 void *v; // the field elements 00243 void *norm; // the normalization field 00244 bool alloc; // whether we allocated memory 00245 bool init; 00246 00247 static void *buffer;// pinned memory 00248 static bool bufferInit; 00249 static size_t bufferBytes; 00250 00251 void create(const QudaFieldCreate); 00252 void destroy(); 00253 void zero(); 00254 void copy(const cudaColorSpinorField &); 00255 00256 public: 00257 //cudaColorSpinorField(); 00258 cudaColorSpinorField(const cudaColorSpinorField&); 00259 cudaColorSpinorField(const ColorSpinorField&, const ColorSpinorParam&); 00260 cudaColorSpinorField(const ColorSpinorField&); 00261 cudaColorSpinorField(const ColorSpinorParam&); 00262 virtual ~cudaColorSpinorField(); 00263 00264 cudaColorSpinorField& operator=(const cudaColorSpinorField&); 00265 cudaColorSpinorField& operator=(const cpuColorSpinorField&); 00266 00267 void loadCPUSpinorField(const cpuColorSpinorField &src); 00268 void saveCPUSpinorField (cpuColorSpinorField &src) const; 00269 00270 cudaColorSpinorField& Even() const; 00271 cudaColorSpinorField& Odd() const; 00272 00273 static void freeBuffer(); 00274 00275 }; 00276 00277 00278 // CPU implementation 00279 class cpuColorSpinorField : public ColorSpinorField { 00280 00281 friend class cudaColorSpinorField; 00282 00283 friend double normCpu(const cpuColorSpinorField &); 00284 friend double dslashCUDA(); 00285 friend void dslashRef(); 00286 friend void staggeredDslashRef(); 00287 00288 private: 00289 void *v; // the field elements 00290 void *norm; // the normalization field 00291 bool init; 00292 00293 void create(const QudaFieldCreate); 00294 void destroy(); 00295 void copy(const cpuColorSpinorField&); 00296 void zero(); 00297 00298 public: 00299 //cpuColorSpinorField(); 00300 cpuColorSpinorField(const cpuColorSpinorField&); 00301 cpuColorSpinorField(const ColorSpinorField&); 00302 cpuColorSpinorField(const ColorSpinorParam&); 00303 virtual ~cpuColorSpinorField(); 00304 00305 cpuColorSpinorField& operator=(const cpuColorSpinorField&); 00306 cpuColorSpinorField& operator=(const cudaColorSpinorField&); 00307 00308 //cpuColorSpinorField& Even() const; 00309 //cpuColorSpinorField& Odd() const; 00310 00311 void Source(const QudaSourceType sourceType, const int st=0, const int s=0, const int c=0); 00312 static void Compare(const cpuColorSpinorField &a, const cpuColorSpinorField &b, const int resolution=1); 00313 void PrintVector(int vol); 00314 }; 00315 00316 #endif // _COLOR_SPINOR_FIELD_H 00317 00318 /* 00319 00320 // experimenting with functors for arbitrary ordering 00321 class spinorFunctor { 00322 00323 protected: 00324 const void *v; 00325 const int nColor; 00326 const int nSpin; 00327 const int volume; 00328 const Precision precision; 00329 00330 public: 00331 spinorFunctor(void *V, int Volume, int Nc, int Ns, Precision prec) 00332 : v(V), nColor(Nc), nSpin(Ns), volume(Volume), precision(prec) { ; } 00333 virtual ~spinorFunctor(); 00334 // return element at parity p, linear index x, spin s, color c and complexity z 00335 00336 virtual void* operator()(int p, int x, int s, int c, int z) const = 0; 00337 }; 00338 00339 // accessor for SPACE_SPIN_COLOR_ORDER 00340 class SSCfunctor : public spinorFunctor { 00341 00342 public: 00343 SSCfunctor(void *V, int volume, int Nc, int Ns, Precision prec) 00344 : spinorFunctor(V, volume, Nc, Ns, prec) { ; } 00345 virtual ~SSCfunctor(); 00346 00347 void* operator()(int p, int x, int s, int c, int z) const { 00348 switch (precision) { 00349 case QUDA_DOUBLE_PRECISION: 00350 return ((double*)v)+(((p*volume+x)*nSpin+s)*nColor+c)*2+z; 00351 case QUDA_SINGLE_PRECISION: 00352 return ((float*)v)+(((p*volume+x)*nSpin+s)*nColor+c)*2+z; 00353 default: 00354 errorQuda("Precision not defined"); 00355 } 00356 } 00357 00358 }; 00359 00360 // accessor for SPACE_COLOR_SPIN_ORDER 00361 class SCSfunctor : public spinorFunctor { 00362 00363 public: 00364 SCSfunctor(void *V, int volume, int Nc, int Ns, Precision prec) 00365 : spinorFunctor(V, volume, Nc, Ns, prec) { ; } 00366 virtual ~SCSfunctor(); 00367 00368 void* operator()(int p, int x, int s, int c, int z) const { 00369 switch (precision) { 00370 case QUDA_DOUBLE_PRECISION: 00371 return ((double*)v)+(((p*volume+x)*nColor+c)*nSpin+s)*2+z; 00372 case QUDA_SINGLE_PRECISION: 00373 return ((float*)v)+(((p*volume+x)*nColor+c)*nSpin+s)*2+z; 00374 default: 00375 errorQuda("Precision not defined"); 00376 } 00377 } 00378 00379 }; 00380 00381 */ 00382 00383
1.7.3