|
QUDA v0.3.2
A library for QCD on GPUs
|
00001 #include <dirac_quda.h> 00002 #include <blas_quda.h> 00003 00004 DiracStaggeredPC::DiracStaggeredPC(const DiracParam ¶m) 00005 : Dirac(param), fatGauge(param.fatGauge), longGauge(param.longGauge) 00006 { 00007 00008 } 00009 00010 DiracStaggeredPC::DiracStaggeredPC(const DiracStaggeredPC &dirac) 00011 : Dirac(dirac) 00012 { 00013 00014 } 00015 00016 DiracStaggeredPC::~DiracStaggeredPC() 00017 { 00018 00019 } 00020 00021 DiracStaggeredPC& DiracStaggeredPC::operator=(const DiracStaggeredPC &dirac) 00022 { 00023 if (&dirac != this) { 00024 Dirac::operator=(dirac); 00025 } 00026 00027 return *this; 00028 } 00029 00030 00031 00032 void DiracStaggeredPC::checkParitySpinor(const cudaColorSpinorField &in, const cudaColorSpinorField &out) const 00033 { 00034 00035 if (in.Precision() != out.Precision()) { 00036 errorQuda("Input and output spinor precisions don't match in dslash_quda"); 00037 } 00038 00039 if (in.Stride() != out.Stride()) { 00040 errorQuda("Input %d and output %d spinor strides don't match in dslash_quda", in.Stride(), out.Stride()); 00041 } 00042 00043 if (in.SiteSubset() != QUDA_PARITY_SITE_SUBSET || out.SiteSubset() != QUDA_PARITY_SITE_SUBSET) { 00044 errorQuda("ColorSpinorFields are not single parity, in = %d, out = %d", 00045 in.SiteSubset(), out.SiteSubset()); 00046 } 00047 00048 if ((out.Volume() != 2*fatGauge->volume && out.SiteSubset() == QUDA_FULL_SITE_SUBSET) || 00049 (out.Volume() != fatGauge->volume && out.SiteSubset() == QUDA_PARITY_SITE_SUBSET) ) { 00050 errorQuda("Spinor volume %d doesn't match gauge volume %d", out.Volume(), fatGauge->volume); 00051 } 00052 00053 } 00054 00055 00056 void DiracStaggeredPC::Dslash(cudaColorSpinorField &out, const cudaColorSpinorField &in, 00057 const QudaParity parity) const 00058 { 00059 if (!initDslash) { 00060 initDslashConstants(*fatGauge, in.Stride(), 0); 00061 } 00062 checkParitySpinor(in, out); 00063 00064 staggeredDslashCuda(out.v, out.norm, *fatGauge, *longGauge, in.v, in.norm, parity, dagger, 00065 0, 0, 0, out.volume, out.length, in.Precision()); 00066 00067 flops += 1146*in.volume; 00068 } 00069 00070 void DiracStaggeredPC::DslashXpay(cudaColorSpinorField &out, const cudaColorSpinorField &in, 00071 const QudaParity parity, const cudaColorSpinorField &x, 00072 const double &k) const 00073 { 00074 if (!initDslash){ 00075 initDslashConstants(*fatGauge, in.Stride(), 0); 00076 } 00077 checkParitySpinor(in, out); 00078 00079 staggeredDslashCuda(out.v, out.norm, *fatGauge, *longGauge, in.v, in.norm, parity, dagger, x.v, x.norm, k, 00080 out.volume, out.length, in.Precision()); 00081 00082 flops += (1146+12)*in.volume; 00083 } 00084 00085 void DiracStaggeredPC::M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const 00086 { 00087 errorQuda("DiracStaggeredPC::M() is not implemented\n"); 00088 } 00089 00090 void DiracStaggeredPC::MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const 00091 { 00092 if (!initDslash){ 00093 initDslashConstants(*fatGauge, in.Stride(), 0); 00094 } 00095 00096 bool reset = newTmp(&tmp1, in); 00097 00098 QudaParity parity; 00099 QudaParity other_parity; 00100 if (matpcType == QUDA_MATPC_EVEN_EVEN) { 00101 parity = QUDA_EVEN_PARITY; 00102 other_parity = QUDA_ODD_PARITY; 00103 } else if (matpcType == QUDA_MATPC_ODD_ODD) { 00104 parity = QUDA_ODD_PARITY; 00105 other_parity = QUDA_EVEN_PARITY; 00106 } else { 00107 errorQuda("Invalid matpcType(%d) in function\n", matpcType); 00108 } 00109 00110 Dslash(*tmp1, in, other_parity); 00111 DslashXpay(out, *tmp1, parity, in, 4*mass*mass); 00112 00113 deleteTmp(&tmp1, reset); 00114 } 00115 00116 void DiracStaggeredPC::prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol, 00117 cudaColorSpinorField &x, cudaColorSpinorField &b, 00118 const QudaSolutionType solType) const 00119 { 00120 src = &b; 00121 sol = &x; 00122 } 00123 00124 void DiracStaggeredPC::reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b, 00125 const QudaSolutionType solType) const 00126 { 00127 // do nothing 00128 } 00129 00130 00131 00132 00133 DiracStaggered::DiracStaggered(const DiracParam ¶m) 00134 : Dirac(param), fatGauge(param.fatGauge), longGauge(param.longGauge) 00135 { 00136 00137 } 00138 00139 DiracStaggered::DiracStaggered(const DiracStaggered &dirac) 00140 : Dirac(dirac) 00141 { 00142 00143 } 00144 00145 DiracStaggered::~DiracStaggered() 00146 { 00147 00148 } 00149 00150 DiracStaggered& DiracStaggered::operator=(const DiracStaggered &dirac) 00151 { 00152 if (&dirac != this) { 00153 Dirac::operator=(dirac); 00154 } 00155 00156 return *this; 00157 } 00158 00159 00160 00161 void DiracStaggered::checkParitySpinor(const cudaColorSpinorField &in, const cudaColorSpinorField &out) const 00162 { 00163 if (in.Precision() != out.Precision()) { 00164 errorQuda("Input and output spinor precisions don't match in dslash_quda"); 00165 } 00166 00167 if (in.Stride() != out.Stride()) { 00168 errorQuda("Input %d and output %d spinor strides don't match in dslash_quda", in.Stride(), out.Stride()); 00169 } 00170 00171 if (in.SiteSubset() != QUDA_PARITY_SITE_SUBSET || out.SiteSubset() != QUDA_PARITY_SITE_SUBSET) { 00172 errorQuda("ColorSpinorFields are not single parity, in = %d, out = %d", 00173 in.SiteSubset(), out.SiteSubset()); 00174 } 00175 00176 if ((out.Volume() != 2*fatGauge->volume && out.SiteSubset() == QUDA_FULL_SITE_SUBSET) || 00177 (out.Volume() != fatGauge->volume && out.SiteSubset() == QUDA_PARITY_SITE_SUBSET) ) { 00178 errorQuda("Spinor volume %d doesn't match gauge volume %d", out.Volume(), fatGauge->volume); 00179 } 00180 } 00181 00182 00183 void DiracStaggered::Dslash(cudaColorSpinorField &out, const cudaColorSpinorField &in, 00184 const QudaParity parity) const 00185 { 00186 if (!initDslash) { 00187 initDslashConstants(*fatGauge, in.Stride(), 0); 00188 } 00189 checkParitySpinor(in, out); 00190 00191 staggeredDslashCuda(out.v, out.norm, *fatGauge, *longGauge, in.v, in.norm, parity, dagger, 00192 0, 0, 0, out.volume, out.length, in.Precision()); 00193 00194 flops += 1187*in.volume; 00195 } 00196 00197 void DiracStaggered::DslashXpay(cudaColorSpinorField &out, const cudaColorSpinorField &in, 00198 const QudaParity parity, const cudaColorSpinorField &x, 00199 const double &k) const 00200 { 00201 if (!initDslash){ 00202 initDslashConstants(*fatGauge, in.Stride(), 0); 00203 } 00204 checkParitySpinor(in, out); 00205 00206 staggeredDslashCuda(out.v, out.norm, *fatGauge, *longGauge, in.v, in.norm, parity, dagger, x.v, x.norm, k, 00207 out.volume, out.length, in.Precision()); 00208 00209 flops += (1187+12)*in.volume; 00210 } 00211 00212 void DiracStaggered::M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const 00213 { 00214 errorQuda("DiracStaggered::M() is not implemented"); 00215 } 00216 00217 void DiracStaggered::MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const 00218 { 00219 00220 if (!initDslash){ 00221 initDslashConstants(*fatGauge, in.Stride(), 0); 00222 } 00223 00224 bool reset = newTmp(&tmp1, in); 00225 00226 cudaColorSpinorField* mytmp = dynamic_cast<cudaColorSpinorField*>(tmp1->even); 00227 cudaColorSpinorField* ineven = dynamic_cast<cudaColorSpinorField*>(in.even); 00228 cudaColorSpinorField* inodd = dynamic_cast<cudaColorSpinorField*>(in.odd); 00229 cudaColorSpinorField* outeven = dynamic_cast<cudaColorSpinorField*>(out.even); 00230 cudaColorSpinorField* outodd = dynamic_cast<cudaColorSpinorField*>(out.odd); 00231 00232 //even 00233 Dslash(*mytmp, *ineven, QUDA_ODD_PARITY); 00234 DslashXpay(*outeven, *mytmp, QUDA_EVEN_PARITY, *ineven, 4*mass*mass); 00235 00236 //odd 00237 Dslash(*mytmp, *inodd, QUDA_EVEN_PARITY); 00238 DslashXpay(*outodd, *mytmp, QUDA_ODD_PARITY, *inodd, 4*mass*mass); 00239 00240 deleteTmp(&tmp1, reset); 00241 } 00242 00243 void DiracStaggered::prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol, 00244 cudaColorSpinorField &x, cudaColorSpinorField &b, 00245 const QudaSolutionType solType) const 00246 { 00247 if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) { 00248 errorQuda("Preconditioned solution requires a preconditioned solve_type"); 00249 } 00250 00251 src = &b; 00252 sol = &x; 00253 } 00254 00255 void DiracStaggered::reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b, 00256 const QudaSolutionType solType) const 00257 { 00258 // do nothing 00259 }
1.7.3