QUDA  v0.7.0
A library for QCD on GPUs
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
dirac_staggered.cpp
Go to the documentation of this file.
1 #include <dirac_quda.h>
2 #include <blas_quda.h>
3 
4 namespace quda {
5 
6  namespace staggered {
7 #include <dslash_init.cuh>
8  }
9 
11  Dirac(param), face1(param.gauge->X(), 4, 6, 1, param.gauge->Precision()), face2(param.gauge->X(), 4, 6, 1, param.gauge->Precision())
12  //FIXME: this may break mixed precision multishift solver since may not have fatGauge initializeed yet
13  {
14  staggered::initConstants(*param.gauge, profile);
15  }
16 
18  : Dirac(dirac), face1(dirac.face1), face2(dirac.face2)
19  {
20  staggered::initConstants(dirac.gauge, profile);
21  }
22 
24 
26  {
27  if (&dirac != this) {
28  Dirac::operator=(dirac);
29  face1 = dirac.face1;
30  face2 = dirac.face2;
31  }
32  return *this;
33  }
34 
36  {
37  if (in.Precision() != out.Precision()) {
38  errorQuda("Input and output spinor precisions don't match in dslash_quda");
39  }
40 
41  if (in.Stride() != out.Stride()) {
42  errorQuda("Input %d and output %d spinor strides don't match in dslash_quda", in.Stride(), out.Stride());
43  }
44 
46  errorQuda("ColorSpinorFields are not single parity, in = %d, out = %d",
47  in.SiteSubset(), out.SiteSubset());
48  }
49  }
50 
51 
53  const QudaParity parity) const
54  {
55  checkParitySpinor(in, out);
56 
57  staggered::setFace(face1, face2); // FIXME: temporary hack maintain C linkage for dslashCuda
58  staggeredDslashCuda(&out, gauge, &in, parity, dagger, 0, 0, commDim, profile);
59 
60  flops += 654ll*in.Volume();
61  }
62 
65  const double &k) const
66  {
67  checkParitySpinor(in, out);
68 
69  staggered::setFace(face1, face2); // FIXME: temporary hack maintain C linkage for dslashCuda
70  staggeredDslashCuda(&out, gauge, &in, parity, dagger, &x, k, commDim, profile);
71 
72  flops += 666ll*in.Volume();
73  }
74 
75  // Full staggered operator
77  {
78  bool reset = newTmp(&tmp1, in.Even());
79 
80  DslashXpay(out.Even(), in.Odd(), QUDA_EVEN_PARITY, *tmp1, 2*mass);
81  DslashXpay(out.Odd(), in.Even(), QUDA_ODD_PARITY, *tmp1, 2*mass);
82 
83  deleteTmp(&tmp1, reset);
84  }
85 
87  {
88  bool reset = newTmp(&tmp1, in);
89 
90  cudaColorSpinorField* mytmp = dynamic_cast<cudaColorSpinorField*>(&(tmp1->Even()));
91  cudaColorSpinorField* ineven = dynamic_cast<cudaColorSpinorField*>(&(in.Even()));
92  cudaColorSpinorField* inodd = dynamic_cast<cudaColorSpinorField*>(&(in.Odd()));
93  cudaColorSpinorField* outeven = dynamic_cast<cudaColorSpinorField*>(&(out.Even()));
94  cudaColorSpinorField* outodd = dynamic_cast<cudaColorSpinorField*>(&(out.Odd()));
95 
96  //even
97  Dslash(*mytmp, *ineven, QUDA_ODD_PARITY);
98  DslashXpay(*outeven, *mytmp, QUDA_EVEN_PARITY, *ineven, 4*mass*mass);
99 
100  //odd
101  Dslash(*mytmp, *inodd, QUDA_EVEN_PARITY);
102  DslashXpay(*outodd, *mytmp, QUDA_ODD_PARITY, *inodd, 4*mass*mass);
103 
104  deleteTmp(&tmp1, reset);
105  }
106 
109  const QudaSolutionType solType) const
110  {
111  if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
112  errorQuda("Preconditioned solution requires a preconditioned solve_type");
113  }
114 
115  src = &b;
116  sol = &x;
117  }
118 
120  const QudaSolutionType solType) const
121  {
122  // do nothing
123  }
124 
125 
127  : DiracStaggered(param)
128  {
129 
130  }
131 
133  : DiracStaggered(dirac)
134  {
135 
136  }
137 
139  {
140 
141  }
142 
144  {
145  if (&dirac != this) {
147  }
148 
149  return *this;
150  }
151 
153  {
154  errorQuda("DiracStaggeredPC::M() is not implemented\n");
155  }
156 
158  {
159  bool reset = newTmp(&tmp1, in);
160 
162  QudaParity other_parity = QUDA_INVALID_PARITY;
164  parity = QUDA_EVEN_PARITY;
165  other_parity = QUDA_ODD_PARITY;
166  } else if (matpcType == QUDA_MATPC_ODD_ODD) {
167  parity = QUDA_ODD_PARITY;
168  other_parity = QUDA_EVEN_PARITY;
169  } else {
170  errorQuda("Invalid matpcType(%d) in function\n", matpcType);
171  }
172  Dslash(*tmp1, in, other_parity);
173  DslashXpay(out, *tmp1, parity, in, 4*mass*mass);
174 
175  deleteTmp(&tmp1, reset);
176  }
177 
180  const QudaSolutionType solType) const
181  {
182  src = &b;
183  sol = &x;
184  }
185 
187  const QudaSolutionType solType) const
188  {
189  // do nothing
190  }
191 
192 } // namespace quda
cudaGaugeField & gauge
Definition: dirac_quda.h:88
virtual void DslashXpay(cudaColorSpinorField &out, const cudaColorSpinorField &in, const QudaParity parity, const cudaColorSpinorField &x, const double &k) const
virtual void checkParitySpinor(const cudaColorSpinorField &, const cudaColorSpinorField &) const
unsigned long long flops
Definition: dirac_quda.h:93
#define errorQuda(...)
Definition: util_quda.h:73
virtual void prepare(cudaColorSpinorField *&src, cudaColorSpinorField *&sol, cudaColorSpinorField &x, cudaColorSpinorField &b, const QudaSolutionType) const
DiracStaggeredPC(const DiracParam &param)
bool newTmp(cudaColorSpinorField **, const cudaColorSpinorField &) const
Definition: dirac.cpp:51
virtual void prepare(cudaColorSpinorField *&src, cudaColorSpinorField *&sol, cudaColorSpinorField &x, cudaColorSpinorField &b, const QudaSolutionType) const
void staggeredDslashCuda(cudaColorSpinorField *out, const cudaGaugeField &gauge, const cudaColorSpinorField *in, const int parity, const int dagger, const cudaColorSpinorField *x, const double &k, const int *commDim, TimeProfile &profile, const QudaDslashPolicy &dslashPolicy=QUDA_DSLASH2)
TimeProfile profile
Definition: dirac_quda.h:104
virtual void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
virtual void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b, const QudaSolutionType) const
virtual void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
cudaGaugeField * gauge
Definition: dirac_quda.h:30
QudaGaugeParam param
Definition: pack_test.cpp:17
DiracStaggered & operator=(const DiracStaggered &dirac)
cudaColorSpinorField & Odd() const
int commDim[QUDA_MAX_DIM]
Definition: dirac_quda.h:102
DiracStaggeredPC & operator=(const DiracStaggeredPC &dirac)
virtual void Dslash(cudaColorSpinorField &out, const cudaColorSpinorField &in, const QudaParity parity) const
cpuColorSpinorField * in
double mass
Definition: dirac_quda.h:90
enum QudaSolutionType_s QudaSolutionType
void deleteTmp(cudaColorSpinorField **, const bool &reset) const
Definition: dirac.cpp:59
Dirac * dirac
Definition: dslash_test.cpp:45
QudaDagType dagger
Definition: dirac_quda.h:92
enum QudaParity_s QudaParity
int x[4]
QudaMatPCType matpcType
Definition: dirac_quda.h:91
virtual void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
Dirac & operator=(const Dirac &dirac)
Definition: dirac.cpp:32
cpuColorSpinorField * out
cudaColorSpinorField * tmp1
Definition: dirac_quda.h:94
QudaPrecision Precision() const
virtual void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b, const QudaSolutionType) const
DiracStaggered(const DiracParam &param)
virtual void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
QudaSiteSubset SiteSubset() const
const QudaParity parity
Definition: dslash_test.cpp:29
void * gauge[4]
Definition: su3_test.cpp:15
cudaColorSpinorField & Even() const