QUDA  0.9.0
dirac_improved_staggered.cpp
Go to the documentation of this file.
1 #include <dirac_quda.h>
2 #include <blas_quda.h>
3 
4 namespace quda {
5 
7  : Dirac(param), fatGauge(*(param.fatGauge)), longGauge(*(param.longGauge)) { }
8 
10  : Dirac(dirac), fatGauge(dirac.fatGauge), longGauge(dirac.longGauge) { }
11 
13 
15  {
16  if (&dirac != this) {
18  fatGauge = dirac.fatGauge;
19  longGauge = dirac.longGauge;
20  }
21  return *this;
22  }
23 
25  {
26  if (in.Ndim() != 5 || out.Ndim() != 5) {
27  errorQuda("Staggered dslash requires 5-d fermion fields");
28  }
29 
30  if (in.Precision() != out.Precision()) {
31  errorQuda("Input and output spinor precisions don't match in dslash_quda");
32  }
33 
34  if (in.Stride() != out.Stride()) {
35  errorQuda("Input %d and output %d spinor strides don't match in dslash_quda", in.Stride(), out.Stride());
36  }
37 
38  if (in.SiteSubset() != QUDA_PARITY_SITE_SUBSET || out.SiteSubset() != QUDA_PARITY_SITE_SUBSET) {
39  errorQuda("ColorSpinorFields are not single parity, in = %d, out = %d",
40  in.SiteSubset(), out.SiteSubset());
41  }
42 
43  if ((out.Volume()/out.X(4) != 2*fatGauge.VolumeCB() && out.SiteSubset() == QUDA_FULL_SITE_SUBSET) ||
44  (out.Volume()/out.X(4) != fatGauge.VolumeCB() && out.SiteSubset() == QUDA_PARITY_SITE_SUBSET) ) {
45  errorQuda("Spinor volume %d doesn't match gauge volume %d", out.Volume(), fatGauge.VolumeCB());
46  }
47  }
48 
49 
51  const QudaParity parity) const
52  {
54 
56  improvedStaggeredDslashCuda(&static_cast<cudaColorSpinorField&>(out), fatGauge, longGauge,
57  &static_cast<const cudaColorSpinorField&>(in), parity,
58  dagger, 0, 0, commDim, profile);
59  } else {
60  errorQuda("Not supported");
61  }
62 
63  flops += 1146ll*in.Volume();
64  }
65 
67  const QudaParity parity, const ColorSpinorField &x,
68  const double &k) const
69  {
71 
73  improvedStaggeredDslashCuda(&static_cast<cudaColorSpinorField&>(out), fatGauge, longGauge,
74  &static_cast<const cudaColorSpinorField&>(in), parity, dagger,
75  &static_cast<const cudaColorSpinorField&>(x), k, commDim, profile);
76  } else {
77  errorQuda("Not supported");
78  }
79 
80  flops += 1158ll*in.Volume();
81  }
82 
83  // Full staggered operator
85  {
86  DslashXpay(out.Even(), in.Odd(), QUDA_EVEN_PARITY, in.Even(), 2*mass);
87  DslashXpay(out.Odd(), in.Even(), QUDA_ODD_PARITY, in.Odd(), 2*mass);
88  }
89 
91  {
92  bool reset = newTmp(&tmp1, in);
93 
94  //even
95  Dslash(tmp1->Even(), in.Even(), QUDA_ODD_PARITY);
96  DslashXpay(out.Even(), tmp1->Even(), QUDA_EVEN_PARITY, in.Even(), 4*mass*mass);
97 
98  //odd
99  Dslash(tmp1->Even(), in.Odd(), QUDA_EVEN_PARITY);
100  DslashXpay(out.Odd(), tmp1->Even(), QUDA_ODD_PARITY, in.Odd(), 4*mass*mass);
101 
102  deleteTmp(&tmp1, reset);
103  }
104 
107  const QudaSolutionType solType) const
108  {
109  if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
110  errorQuda("Preconditioned solution requires a preconditioned solve_type");
111  }
112 
113  src = &b;
114  sol = &x;
115  }
116 
118  const QudaSolutionType solType) const
119  {
120  // do nothing
121  }
122 
123 
126  {
127 
128  }
129 
132  {
133 
134  }
135 
137  {
138 
139  }
140 
142  {
143  if (&dirac != this) {
145  }
146 
147  return *this;
148  }
149 
151  {
152  errorQuda("DiracImprovedStaggeredPC::M() is not implemented\n");
153  }
154 
156  {
157  bool reset = newTmp(&tmp1, in);
158 
160  QudaParity other_parity = QUDA_INVALID_PARITY;
163  other_parity = QUDA_ODD_PARITY;
164  } else if (matpcType == QUDA_MATPC_ODD_ODD) {
166  other_parity = QUDA_EVEN_PARITY;
167  } else {
168  errorQuda("Invalid matpcType(%d) in function\n", matpcType);
169  }
170  Dslash(*tmp1, in, other_parity);
171  DslashXpay(out, *tmp1, parity, in, 4*mass*mass);
172 
173  deleteTmp(&tmp1, reset);
174  }
175 
178  const QudaSolutionType solType) const
179  {
180  src = &b;
181  sol = &x;
182  }
183 
185  const QudaSolutionType solType) const
186  {
187  // do nothing
188  }
189 
190 } // namespace quda
DiracImprovedStaggeredPC & operator=(const DiracImprovedStaggeredPC &dirac)
DiracImprovedStaggeredPC(const DiracParam &param)
unsigned long long flops
Definition: dirac_quda.h:100
virtual void DslashXpay(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity, const ColorSpinorField &x, const double &k) const
const void * src
virtual void prepare(ColorSpinorField *&src, ColorSpinorField *&sol, ColorSpinorField &x, ColorSpinorField &b, const QudaSolutionType) const
#define errorQuda(...)
Definition: util_quda.h:90
const ColorSpinorField & Even() const
void deleteTmp(ColorSpinorField **, const bool &reset) const
Definition: dirac.cpp:64
TimeProfile profile
Definition: dirac_quda.h:112
virtual void reconstruct(ColorSpinorField &x, const ColorSpinorField &b, const QudaSolutionType) const
cudaGaugeField & fatGauge
Definition: dirac_quda.h:690
QudaGaugeParam param
Definition: pack_test.cpp:17
bool newTmp(ColorSpinorField **, const ColorSpinorField &) const
Definition: dirac.cpp:53
#define b
int commDim[QUDA_MAX_DIM]
Definition: dirac_quda.h:110
virtual void M(ColorSpinorField &out, const ColorSpinorField &in) const
cpuColorSpinorField * in
double mass
Definition: dirac_quda.h:97
virtual void MdagM(ColorSpinorField &out, const ColorSpinorField &in) const
#define checkLocation(...)
enum QudaSolutionType_s QudaSolutionType
QudaDagType dagger
Definition: dirac_quda.h:99
enum QudaParity_s QudaParity
QudaMatPCType matpcType
Definition: dirac_quda.h:98
virtual void MdagM(ColorSpinorField &out, const ColorSpinorField &in) const
void improvedStaggeredDslashCuda(cudaColorSpinorField *out, const cudaGaugeField &fatGauge, const cudaGaugeField &longGauge, const cudaColorSpinorField *in, const int parity, const int dagger, const cudaColorSpinorField *x, const double &k, const int *commDim, TimeProfile &profile)
Dirac & operator=(const Dirac &dirac)
Definition: dirac.cpp:32
GaugeCovDev * dirac
Definition: covdev_test.cpp:75
cpuColorSpinorField * out
virtual void prepare(ColorSpinorField *&src, ColorSpinorField *&sol, ColorSpinorField &x, ColorSpinorField &b, const QudaSolutionType) const
int VolumeCB() const
cudaGaugeField & longGauge
Definition: dirac_quda.h:691
virtual void checkParitySpinor(const ColorSpinorField &, const ColorSpinorField &) const
virtual void reconstruct(ColorSpinorField &x, const ColorSpinorField &b, const QudaSolutionType) const
virtual void Dslash(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity) const
QudaParity parity
Definition: covdev_test.cpp:53
DiracImprovedStaggered(const DiracParam &param)
virtual void M(ColorSpinorField &out, const ColorSpinorField &in) const
DiracImprovedStaggered & operator=(const DiracImprovedStaggered &dirac)
ColorSpinorField * tmp1
Definition: dirac_quda.h:101