QUDA v0.4.0
A library for QCD on GPUs
00001 #include <dirac_quda.h> 00002 #include <blas_quda.h> 00003 00004 DiracStaggered::DiracStaggered(const DiracParam ¶m) : 00005 Dirac(param), fatGauge(param.fatGauge), longGauge(param.longGauge), 00006 face(param.fatGauge->X(), 4, 6, 3, param.fatGauge->Precision()) 00007 //FIXME: this may break mixed precision multishift solver since may not have fatGauge initializeed yet 00008 { 00009 00010 } 00011 00012 DiracStaggered::DiracStaggered(const DiracStaggered &dirac) : Dirac(dirac), 00013 fatGauge(dirac.fatGauge), longGauge(dirac.longGauge), face(dirac.face) { } 00014 00015 DiracStaggered::~DiracStaggered() 00016 { 00017 00018 } 00019 00020 DiracStaggered& DiracStaggered::operator=(const DiracStaggered &dirac) 00021 { 00022 if (&dirac != this) { 00023 Dirac::operator=(dirac); 00024 fatGauge = dirac.fatGauge; 00025 longGauge = dirac.longGauge; 00026 face = dirac.face; 00027 } 00028 00029 return *this; 00030 } 00031 00032 void DiracStaggered::checkParitySpinor(const cudaColorSpinorField &in, const cudaColorSpinorField &out) const 00033 { 00034 if (in.Precision() != out.Precision()) { 00035 errorQuda("Input and output spinor precisions don't match in dslash_quda"); 00036 } 00037 00038 if (in.Stride() != out.Stride()) { 00039 errorQuda("Input %d and output %d spinor strides don't match in dslash_quda", in.Stride(), out.Stride()); 00040 } 00041 00042 if (in.SiteSubset() != QUDA_PARITY_SITE_SUBSET || out.SiteSubset() != QUDA_PARITY_SITE_SUBSET) { 00043 errorQuda("ColorSpinorFields are not single parity, in = %d, out = %d", 00044 in.SiteSubset(), out.SiteSubset()); 00045 } 00046 00047 if ((out.Volume() != 2*fatGauge->VolumeCB() && out.SiteSubset() == QUDA_FULL_SITE_SUBSET) || 00048 (out.Volume() != fatGauge->VolumeCB() && out.SiteSubset() == QUDA_PARITY_SITE_SUBSET) ) { 00049 errorQuda("Spinor volume %d doesn't match gauge volume %d", out.Volume(), fatGauge->VolumeCB()); 00050 } 00051 } 00052 00053 00054 void DiracStaggered::Dslash(cudaColorSpinorField &out, const 
cudaColorSpinorField &in, 00055 const QudaParity parity) const 00056 { 00057 if (!initDslash) { 00058 initDslashConstants(*fatGauge, in.Stride()); 00059 initStaggeredConstants(*fatGauge, *longGauge); 00060 } 00061 checkParitySpinor(in, out); 00062 00063 setFace(face); // FIXME: temporary hack maintain C linkage for dslashCuda 00064 staggeredDslashCuda(&out, *fatGauge, *longGauge, &in, parity, dagger, 0, 0, commDim); 00065 00066 flops += 1146*in.Volume(); 00067 } 00068 00069 void DiracStaggered::DslashXpay(cudaColorSpinorField &out, const cudaColorSpinorField &in, 00070 const QudaParity parity, const cudaColorSpinorField &x, 00071 const double &k) const 00072 { 00073 if (!initDslash){ 00074 initDslashConstants(*fatGauge, in.Stride()); 00075 initStaggeredConstants(*fatGauge, *longGauge); 00076 } 00077 checkParitySpinor(in, out); 00078 00079 setFace(face); // FIXME: temporary hack maintain C linkage for dslashCuda 00080 staggeredDslashCuda(&out, *fatGauge, *longGauge, &in, parity, dagger, &x, k, commDim); 00081 00082 flops += (1146+12)*in.Volume(); 00083 } 00084 00085 // Full staggered operator 00086 void DiracStaggered::M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const 00087 { 00088 if (!initDslash){ 00089 initDslashConstants(*fatGauge, in.Stride()); 00090 initStaggeredConstants(*fatGauge, *longGauge); 00091 } 00092 00093 bool reset = newTmp(&tmp1, in.Even()); 00094 00095 DslashXpay(out.Even(), in.Odd(), QUDA_EVEN_PARITY, *tmp1, 2*mass); 00096 DslashXpay(out.Odd(), in.Even(), QUDA_ODD_PARITY, *tmp1, 2*mass); 00097 00098 deleteTmp(&tmp1, reset); 00099 } 00100 00101 void DiracStaggered::MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const 00102 { 00103 00104 if (!initDslash){ 00105 initDslashConstants(*fatGauge, in.Stride()); 00106 initStaggeredConstants(*fatGauge, *longGauge); 00107 } 00108 00109 bool reset = newTmp(&tmp1, in); 00110 00111 cudaColorSpinorField* mytmp = dynamic_cast<cudaColorSpinorField*>(&(tmp1->Even())); 00112 
cudaColorSpinorField* ineven = dynamic_cast<cudaColorSpinorField*>(&(in.Even())); 00113 cudaColorSpinorField* inodd = dynamic_cast<cudaColorSpinorField*>(&(in.Odd())); 00114 cudaColorSpinorField* outeven = dynamic_cast<cudaColorSpinorField*>(&(out.Even())); 00115 cudaColorSpinorField* outodd = dynamic_cast<cudaColorSpinorField*>(&(out.Odd())); 00116 00117 //even 00118 Dslash(*mytmp, *ineven, QUDA_ODD_PARITY); 00119 DslashXpay(*outeven, *mytmp, QUDA_EVEN_PARITY, *ineven, 4*mass*mass); 00120 00121 //odd 00122 Dslash(*mytmp, *inodd, QUDA_EVEN_PARITY); 00123 DslashXpay(*outodd, *mytmp, QUDA_ODD_PARITY, *inodd, 4*mass*mass); 00124 00125 deleteTmp(&tmp1, reset); 00126 } 00127 00128 void DiracStaggered::prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol, 00129 cudaColorSpinorField &x, cudaColorSpinorField &b, 00130 const QudaSolutionType solType) const 00131 { 00132 if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) { 00133 errorQuda("Preconditioned solution requires a preconditioned solve_type"); 00134 } 00135 00136 src = &b; 00137 sol = &x; 00138 } 00139 00140 void DiracStaggered::reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b, 00141 const QudaSolutionType solType) const 00142 { 00143 // do nothing 00144 } 00145 00146 00147 DiracStaggeredPC::DiracStaggeredPC(const DiracParam ¶m) 00148 : DiracStaggered(param) 00149 { 00150 00151 } 00152 00153 DiracStaggeredPC::DiracStaggeredPC(const DiracStaggeredPC &dirac) 00154 : DiracStaggered(dirac) 00155 { 00156 00157 } 00158 00159 DiracStaggeredPC::~DiracStaggeredPC() 00160 { 00161 00162 } 00163 00164 DiracStaggeredPC& DiracStaggeredPC::operator=(const DiracStaggeredPC &dirac) 00165 { 00166 if (&dirac != this) { 00167 DiracStaggered::operator=(dirac); 00168 } 00169 00170 return *this; 00171 } 00172 00173 void DiracStaggeredPC::M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const 00174 { 00175 errorQuda("DiracStaggeredPC::M() is not implemented\n"); 00176 } 
00177 00178 void DiracStaggeredPC::MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const 00179 { 00180 if (!initDslash){ 00181 initDslashConstants(*fatGauge, in.Stride()); 00182 initStaggeredConstants(*fatGauge, *longGauge); 00183 } 00184 00185 bool reset = newTmp(&tmp1, in); 00186 00187 QudaParity parity = QUDA_INVALID_PARITY; 00188 QudaParity other_parity = QUDA_INVALID_PARITY; 00189 if (matpcType == QUDA_MATPC_EVEN_EVEN) { 00190 parity = QUDA_EVEN_PARITY; 00191 other_parity = QUDA_ODD_PARITY; 00192 } else if (matpcType == QUDA_MATPC_ODD_ODD) { 00193 parity = QUDA_ODD_PARITY; 00194 other_parity = QUDA_EVEN_PARITY; 00195 } else { 00196 errorQuda("Invalid matpcType(%d) in function\n", matpcType); 00197 } 00198 Dslash(*tmp1, in, other_parity); 00199 DslashXpay(out, *tmp1, parity, in, 4*mass*mass); 00200 00201 deleteTmp(&tmp1, reset); 00202 } 00203 00204 void DiracStaggeredPC::prepare(cudaColorSpinorField* &src, cudaColorSpinorField* &sol, 00205 cudaColorSpinorField &x, cudaColorSpinorField &b, 00206 const QudaSolutionType solType) const 00207 { 00208 src = &b; 00209 sol = &x; 00210 } 00211 00212 void DiracStaggeredPC::reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b, 00213 const QudaSolutionType solType) const 00214 { 00215 // do nothing 00216 } 00217 00218 00219 00220