QUDA  v0.7.0
A library for QCD on GPUs
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
dirac_improved_staggered.cpp
Go to the documentation of this file.
1 #include <dirac_quda.h>
2 #include <blas_quda.h>
3 
4 namespace quda {
5 
6  namespace improvedstaggered {
// The header is included *inside* this namespace, so everything it declares is
// scoped as improvedstaggered::<name>. The improvedstaggered::initConstants and
// improvedstaggered::setFace calls later in this file presumably resolve to
// declarations from this header (its contents are not visible here).
7 #include <dslash_init.cuh>
8  }
9 
11  Dirac(param), fatGauge(*(param.fatGauge)), longGauge(*(param.longGauge)),
// NOTE(review): the constructor's signature line (listing line 10) and listing
// line 17 are absent from this extract; only the initializer list and part of
// the body are visible.
// Initializer list: forwards param to the Dirac base, initialises
// fatGauge/longGauge from the objects that param.fatGauge/param.longGauge
// point to, and builds face1 and face2 from the fat gauge field's X() and
// Precision(). The literal arguments 4, 6, 3 are passed through unchanged;
// their meaning is defined by the face1/face2 constructor, which is not
// visible here -- TODO confirm.
12  face1(param.fatGauge->X(), 4, 6, 3, param.fatGauge->Precision()),
13  face2(param.fatGauge->X(), 4, 6, 3, param.fatGauge->Precision())
14  //FIXME: this may break the mixed-precision multishift solver, since fatGauge may not have been initialized yet
// (param.fatGauge is dereferenced in the initializer list above, so it must
// already point to a valid field when this constructor runs.)
15  {
// Initialise the improvedstaggered constants from the field param.gauge points to.
16  improvedstaggered::initConstants(*param.gauge, profile);
18  }
19 
21  : Dirac(dirac), fatGauge(dirac.fatGauge), longGauge(dirac.longGauge), face1(dirac.face1), face2(dirac.face2)
// NOTE(review): the signature line (listing line 20) and listing line 24 are
// absent from this extract.
// Copy construction: the Dirac base, both gauge members and both face members
// are initialised member-by-member from the source object `dirac`.
22  {
// Re-run constant initialisation using the source object's gauge member.
23  improvedstaggered::initConstants(dirac.gauge, profile);
25  }
26 
28 
30  {
// Copy-assignment body (the signature line is absent from this extract).
// Copies the Dirac base state and then each member from `dirac`, and returns
// *this. Self-assignment is a no-op.
31  if (&dirac != this) {
32  Dirac::operator=(dirac);
// NOTE(review): the Doxygen member list shows fatGauge and longGauge declared
// as `cudaGaugeField &`. If so, these two assignments do not rebind the
// references; they assign through them into the gauge-field objects this
// instance already refers to -- confirm that is the intended behaviour.
33  fatGauge = dirac.fatGauge;
34  longGauge = dirac.longGauge;
35  face1 = dirac.face1;
36  face2 = dirac.face2;
37  }
38  return *this;
39  }
40 
42  {
// Sanity checks on an (in, out) spinor pair; each failed check is reported
// through errorQuda. The signature line is absent from this extract.

// 1. Both fields must report the same Precision().
43  if (in.Precision() != out.Precision()) {
44  errorQuda("Input and output spinor precisions don't match in dslash_quda");
45  }
46 
// 2. Both fields must report the same Stride().
47  if (in.Stride() != out.Stride()) {
48  errorQuda("Input %d and output %d spinor strides don't match in dslash_quda", in.Stride(), out.Stride());
49  }
50 
// 3. NOTE(review): the `if (...)` condition guarding this error (listing line
// 51) is missing from the extract; judging by the message it tests the
// SiteSubset() of both fields for single parity -- TODO confirm against the
// real source.
52  errorQuda("ColorSpinorFields are not single parity, in = %d, out = %d",
53  in.SiteSubset(), out.SiteSubset());
54  }
55 
// 4. The output volume must agree with the fat gauge field's VolumeCB():
//    a full-site-subset field must have Volume() == 2 * VolumeCB(), and a
//    parity-site-subset field must have Volume() == VolumeCB().
56  if ((out.Volume() != 2*fatGauge.VolumeCB() && out.SiteSubset() == QUDA_FULL_SITE_SUBSET) ||
57  (out.Volume() != fatGauge.VolumeCB() && out.SiteSubset() == QUDA_PARITY_SITE_SUBSET) ) {
58  errorQuda("Spinor volume %d doesn't match gauge volume %d", out.Volume(), fatGauge.VolumeCB());
59  }
60  }
61 
62 
64  const QudaParity parity) const
65  {
// NOTE(review): the first signature line (listing line 63) and listing line 69
// are absent from this extract, so the statement that actually applies the
// operator is not visible here.

// Validate the in/out pair before doing any work.
66  checkParitySpinor(in, out);
67 
68  improvedstaggered::setFace(face1,face2); // FIXME: temporary hack to maintain C linkage for dslashCuda
70 
// Account for the work done: 1146 is added to the flops counter per unit of
// in.Volume(). The `ll` suffix makes the multiplication 64-bit.
71  flops += 1146ll*in.Volume();
72  }
73 
76  const double &k) const
77  {
// NOTE(review): the first signature lines (listing lines 74-75) are absent
// from this extract.

// Validate the in/out pair before doing any work.
78  checkParitySpinor(in, out);
79 
80  improvedstaggered::setFace(face1,face2); // FIXME: temporary hack to maintain C linkage for dslashCuda
// Apply the improved staggered dslash using both the fat and long gauge fields.
// The extra field x and scalar k are forwarded to the kernel wrapper; how they
// are combined with the dslash result is defined by improvedStaggeredDslashCuda,
// which is not visible here.
81  improvedStaggeredDslashCuda(&out, fatGauge, longGauge, &in, parity, dagger, &x, k, commDim, profile);
82 
// Account for the work done: 1158 is added to the flops counter per unit of
// in.Volume(). The `ll` suffix makes the multiplication 64-bit.
83  flops += 1158ll*in.Volume();
84  }
85 
86  // Full staggered operator
// Builds both parity halves of `out` with one DslashXpay call each, using
// 2*mass as the scalar. The signature line is absent from this extract.
88  {
// Obtain a temporary shaped like in.Even(); `reset` records what deleteTmp
// must undo at the end.
89  bool reset = newTmp(&tmp1, in.Even());
90 
// NOTE(review): *tmp1 is passed as the `x` argument of both calls, but nothing
// in this function writes to it after newTmp. If the 2*mass term is meant to
// multiply the same-parity input, in.Even() / in.Odd() would be expected here
// instead -- confirm against the intended operator definition.
91  DslashXpay(out.Even(), in.Odd(), QUDA_EVEN_PARITY, *tmp1, 2*mass);
92  DslashXpay(out.Odd(), in.Even(), QUDA_ODD_PARITY, *tmp1, 2*mass);
93 
94  deleteTmp(&tmp1, reset);
95  }
96 
98  {
// For each parity half of `in`, applies Dslash into a scratch field and then
// DslashXpay back onto the same half of `out`, with 4*mass*mass as the scalar.
// The signature line is absent from this extract.

// Temporary shaped like the full input; only its Even() half is used as scratch.
99  bool reset = newTmp(&tmp1, in);
100 
// NOTE(review): the dynamic_cast results are dereferenced below without a null
// check. The Doxygen member list shows Even()/Odd() returning
// cudaColorSpinorField&, in which case the casts cannot fail -- confirm.
101  cudaColorSpinorField* mytmp = dynamic_cast<cudaColorSpinorField*>(&(tmp1->Even()));
102  cudaColorSpinorField* ineven = dynamic_cast<cudaColorSpinorField*>(&(in.Even()));
103  cudaColorSpinorField* inodd = dynamic_cast<cudaColorSpinorField*>(&(in.Odd()));
104  cudaColorSpinorField* outeven = dynamic_cast<cudaColorSpinorField*>(&(out.Even()));
105  cudaColorSpinorField* outodd = dynamic_cast<cudaColorSpinorField*>(&(out.Odd()));
106 
107  //even
// scratch <- Dslash(in_even) at odd parity, then
// out_even <- DslashXpay(scratch) at even parity with x = in_even, k = 4*mass^2.
108  Dslash(*mytmp, *ineven, QUDA_ODD_PARITY);
109  DslashXpay(*outeven, *mytmp, QUDA_EVEN_PARITY, *ineven, 4*mass*mass);
110 
111  //odd
// Same two steps with the parities swapped; the scratch field is reused.
112  Dslash(*mytmp, *inodd, QUDA_EVEN_PARITY);
113  DslashXpay(*outodd, *mytmp, QUDA_ODD_PARITY, *inodd, 4*mass*mass);
114 
115  deleteTmp(&tmp1, reset);
116  }
117 
120  const QudaSolutionType solType) const
121  {
// NOTE(review): the first signature lines (listing lines 118-119) are absent
// from this extract.

// Preconditioned solution types are rejected here with an error.
122  if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
123  errorQuda("Preconditioned solution requires a preconditioned solve_type");
124  }
125 
// No transformation is applied: the solver source and solution pointers are
// simply aimed at the caller's b and x.
126  src = &b;
127  sol = &x;
128  }
129 
131  const QudaSolutionType solType) const
132  {
// Intentionally empty: no argument is read or modified in this body.
// (The first signature line, listing line 130, is absent from this extract.)
133  // do nothing
134  }
135 
136 
138  : DiracImprovedStaggered(param)
// Delegates all construction to the DiracImprovedStaggered base with the same
// param; the body adds nothing. (The signature line, listing line 137, is
// absent from this extract.)
139  {
140 
141  }
142 
144  : DiracImprovedStaggered(dirac)
// Delegates to the DiracImprovedStaggered base, constructing it from `dirac`;
// the body adds nothing. (The signature line, listing line 143, is absent from
// this extract.)
145  {
146 
147  }
148 
150  {
151 
152  }
153 
155  {
// Assignment body with a self-assignment guard; always returns *this.
// NOTE(review): the signature line and the guarded statement (listing line
// 157) are absent from this extract, so what is copied cannot be seen here.
156  if (&dirac != this) {
158  }
159 
160  return *this;
161  }
162 
164  {
// Not implemented: the only action is to report an error via errorQuda.
// (The signature line, listing line 163, is absent from this extract.)
165  errorQuda("DiracImprovedStaggeredPC::M() is not implemented\n");
166  }
167 
169  {
// Applies Dslash to `in` on one parity into a temporary, then DslashXpay on
// the opposite parity into `out`, with x = in and k = 4*mass*mass.
// NOTE(review): the signature line and listing lines 172 and 174 are absent
// from this extract; they presumably hold the declaration of `parity` and the
// opening `if` of the branch chain below -- TODO confirm against the real source.

// Temporary shaped like `in`; `reset` records what deleteTmp must undo.
170  bool reset = newTmp(&tmp1, in);
171 
173  QudaParity other_parity = QUDA_INVALID_PARITY;
// First branch (condition not visible): target even parity, hop via odd.
175  parity = QUDA_EVEN_PARITY;
176  other_parity = QUDA_ODD_PARITY;
177  } else if (matpcType == QUDA_MATPC_ODD_ODD) {
// Odd-odd preconditioning: target odd parity, hop via even.
178  parity = QUDA_ODD_PARITY;
179  other_parity = QUDA_EVEN_PARITY;
180  } else {
// Any other matpcType is reported as an error.
181  errorQuda("Invalid matpcType(%d) in function\n", matpcType);
182  }
// tmp1 <- Dslash(in) on the intermediate parity ...
183  Dslash(*tmp1, in, other_parity);
// ... then out <- DslashXpay(tmp1) on the target parity with x = in, k = 4*mass^2.
184  DslashXpay(out, *tmp1, parity, in, 4*mass*mass);
185 
186  deleteTmp(&tmp1, reset);
187  }
188 
191  const QudaSolutionType solType) const
192  {
// No transformation and no check of solType: the solver source and solution
// pointers are simply aimed at the caller's b and x.
// (The first signature lines, listing lines 189-190, are absent from this extract.)
193  src = &b;
194  sol = &x;
195  }
196 
198  const QudaSolutionType solType) const
199  {
// Intentionally empty: no argument is read or modified in this body.
// (The first signature line, listing line 197, is absent from this extract.)
200  // do nothing
201  }
202 
203 } // namespace quda
DiracImprovedStaggeredPC & operator=(const DiracImprovedStaggeredPC &dirac)
cudaGaugeField & gauge
Definition: dirac_quda.h:88
DiracImprovedStaggeredPC(const DiracParam &param)
void initStaggeredConstants(const cudaGaugeField &fatgauge, const cudaGaugeField &longgauge, TimeProfile &profile)
unsigned long long flops
Definition: dirac_quda.h:93
int VolumeCB() const
virtual void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
#define errorQuda(...)
Definition: util_quda.h:73
bool newTmp(cudaColorSpinorField **, const cudaColorSpinorField &) const
Definition: dirac.cpp:51
TimeProfile profile
Definition: dirac_quda.h:104
void improvedStaggeredDslashCuda(cudaColorSpinorField *out, const cudaGaugeField &fatGauge, const cudaGaugeField &longGauge, const cudaColorSpinorField *in, const int parity, const int dagger, const cudaColorSpinorField *x, const double &k, const int *commDim, TimeProfile &profile, const QudaDslashPolicy &dslashPolicy=QUDA_DSLASH2)
virtual void prepare(cudaColorSpinorField *&src, cudaColorSpinorField *&sol, cudaColorSpinorField &x, cudaColorSpinorField &b, const QudaSolutionType) const
cudaGaugeField * gauge
Definition: dirac_quda.h:30
cudaGaugeField & fatGauge
Definition: dirac_quda.h:523
QudaGaugeParam param
Definition: pack_test.cpp:17
cudaColorSpinorField & Odd() const
int commDim[QUDA_MAX_DIM]
Definition: dirac_quda.h:102
virtual void DslashXpay(cudaColorSpinorField &out, const cudaColorSpinorField &in, const QudaParity parity, const cudaColorSpinorField &x, const double &k) const
virtual void checkParitySpinor(const cudaColorSpinorField &, const cudaColorSpinorField &) const
cpuColorSpinorField * in
double mass
Definition: dirac_quda.h:90
enum QudaSolutionType_s QudaSolutionType
virtual void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
void deleteTmp(cudaColorSpinorField **, const bool &reset) const
Definition: dirac.cpp:59
Dirac * dirac
Definition: dslash_test.cpp:45
QudaDagType dagger
Definition: dirac_quda.h:92
enum QudaParity_s QudaParity
int x[4]
QudaMatPCType matpcType
Definition: dirac_quda.h:91
Dirac & operator=(const Dirac &dirac)
Definition: dirac.cpp:32
cpuColorSpinorField * out
cudaColorSpinorField * tmp1
Definition: dirac_quda.h:94
QudaPrecision Precision() const
cudaGaugeField & longGauge
Definition: dirac_quda.h:524
virtual void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b, const QudaSolutionType) const
virtual void Dslash(cudaColorSpinorField &out, const cudaColorSpinorField &in, const QudaParity parity) const
DiracImprovedStaggered(const DiracParam &param)
QudaSiteSubset SiteSubset() const
const QudaParity parity
Definition: dslash_test.cpp:29
cudaColorSpinorField & Even() const
virtual void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
virtual void prepare(cudaColorSpinorField *&src, cudaColorSpinorField *&sol, cudaColorSpinorField &x, cudaColorSpinorField &b, const QudaSolutionType) const
virtual void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b, const QudaSolutionType) const
virtual void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
DiracImprovedStaggered & operator=(const DiracImprovedStaggered &dirac)