QUDA  v0.7.0
A library for QCD on GPUs
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
dirac_wilson.cpp
Go to the documentation of this file.
1 #include <dirac_quda.h>
2 #include <blas_quda.h>
3 #include <iostream>
4 
5 namespace quda {
6 
// Textually includes dslash_init.cuh inside the nested `wilson` namespace, so
// whatever that header declares is scoped as wilson::... within namespace quda.
// NOTE(review): the wilson::initConstants and wilson::setFace calls later in
// this file presumably resolve to declarations from this header -- the header
// is not visible here, so confirm against dslash_init.cuh.
7  namespace wilson {
8 #include <dslash_init.cuh>
9  }
10 
12  Dirac(param), face1(param.gauge->X(), 4, 12, 1, param.gauge->Precision()),
13  face2(param.gauge->X(), 4, 12, 1, param.gauge->Precision())
14  {
15  wilson::initConstants(*param.gauge, profile);
16  }
17 
19  Dirac(dirac), face1(dirac.face1), face2(dirac.face2)
20  {
21  wilson::initConstants(dirac.gauge, profile);
22  }
23 
// Constructs a DiracWilson whose two face buffers take their dimension count
// from the caller rather than from a literal.
//
// param: forwarded unchanged to the Dirac base constructor. param.gauge
//        supplies X() and Precision() for both face buffers, and param.Ls is
//        passed as their final constructor argument.
// nDims: passed as the second constructor argument of both face1 and face2.
//
// NOTE(review): the literals 12 and 1 are FaceBuffer constructor arguments
// whose meaning is not visible in this file -- confirm against FaceBuffer.
24  DiracWilson::DiracWilson(const DiracParam &param, const int nDims) :
25  Dirac(param), face1(param.gauge->X(), nDims, 12, 1, param.gauge->Precision(), param.Ls),
26  face2(param.gauge->X(), nDims, 12, 1, param.gauge->Precision(), param.Ls)
27  {
28  wilson::initConstants(*param.gauge, profile); // called with the dereferenced gauge field and this object's profile
29 
30  }// temporary hack (for DW and TM operators)
31 
33 
35  {
36  if (&dirac != this) {
37  Dirac::operator=(dirac);
38  face1 = dirac.face1;
39  face2 = dirac.face2;
40  }
41  return *this;
42  }
43 
45  const QudaParity parity) const
46  {
47  wilson::setFace(face1,face2); // FIXME: temporary hack maintain C linkage for dslashCuda
48 
49  checkParitySpinor(in, out);
50  checkSpinorAlias(in, out);
51 
52  wilsonDslashCuda(&out, gauge, &in, parity, dagger, 0, 0.0, commDim, profile);
53 
54  flops += 1320ll*in.Volume();
55  }
56 
59  const double &k) const
60  {
61  wilson::setFace(face1,face2); // FIXME: temporary hack maintain C linkage for dslashCuda
62 
63  checkParitySpinor(in, out);
64  checkSpinorAlias(in, out);
65 
66  wilsonDslashCuda(&out, gauge, &in, parity, dagger, &x, k, commDim, profile);
67 
68  flops += 1368ll*in.Volume();
69  }
70 
72  {
73  checkFullSpinor(out, in);
74  DslashXpay(out.Odd(), in.Even(), QUDA_ODD_PARITY, in.Odd(), -kappa);
75  DslashXpay(out.Even(), in.Odd(), QUDA_EVEN_PARITY, in.Even(), -kappa);
76  }
77 
79  {
80  checkFullSpinor(out, in);
81 
82  bool reset = newTmp(&tmp1, in);
83  checkFullSpinor(*tmp1, in);
84 
85  M(*tmp1, in);
86  Mdag(out, *tmp1);
87 
88  deleteTmp(&tmp1, reset);
89  }
90 
93  const QudaSolutionType solType) const
94  {
95  if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
96  errorQuda("Preconditioned solution requires a preconditioned solve_type");
97  }
98 
99  src = &b;
100  sol = &x;
101  }
102 
104  const QudaSolutionType solType) const
105  {
106  // do nothing
107  }
108 
110  : DiracWilson(param)
111  {
112 
113  }
114 
116  : DiracWilson(dirac)
117  {
118 
119  }
120 
122  {
123 
124  }
125 
127  {
128  if (&dirac != this) {
129  DiracWilson::operator=(dirac);
130  }
131  return *this;
132  }
133 
135  {
136  double kappa2 = -kappa*kappa;
137 
138  bool reset = newTmp(&tmp1, in);
139 
141  Dslash(*tmp1, in, QUDA_ODD_PARITY);
142  DslashXpay(out, *tmp1, QUDA_EVEN_PARITY, in, kappa2);
143  } else if (matpcType == QUDA_MATPC_ODD_ODD) {
145  DslashXpay(out, *tmp1, QUDA_ODD_PARITY, in, kappa2);
146  } else {
147  errorQuda("MatPCType %d not valid for DiracWilsonPC", matpcType);
148  }
149 
150  deleteTmp(&tmp1, reset);
151  }
152 
154  {
155 #ifdef MULTI_GPU
156  bool reset = newTmp(&tmp2, in);
157  M(*tmp2, in);
158  Mdag(out, *tmp2);
159  deleteTmp(&tmp2, reset);
160 #else
161  M(out, in);
162  Mdag(out, out);
163 #endif
164  }
165 
168  const QudaSolutionType solType) const
169  {
170  // we desire solution to preconditioned system
171  if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
172  src = &b;
173  sol = &x;
174  } else {
175  // we desire solution to full system
177  // src = b_e + k D_eo b_o
178  DslashXpay(x.Odd(), b.Odd(), QUDA_EVEN_PARITY, b.Even(), kappa);
179  src = &(x.Odd());
180  sol = &(x.Even());
181  } else if (matpcType == QUDA_MATPC_ODD_ODD) {
182  // src = b_o + k D_oe b_e
183  DslashXpay(x.Even(), b.Even(), QUDA_ODD_PARITY, b.Odd(), kappa);
184  src = &(x.Even());
185  sol = &(x.Odd());
186  } else {
187  errorQuda("MatPCType %d not valid for DiracWilsonPC", matpcType);
188  }
189  // here we use final solution to store parity solution and parity source
190  // b is now up for grabs if we want
191  }
192 
193  }
194 
196  const QudaSolutionType solType) const
197  {
198  if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
199  return;
200  }
201 
202  // create full solution
203 
204  checkFullSpinor(x, b);
206  // x_o = b_o + k D_oe x_e
207  DslashXpay(x.Odd(), x.Even(), QUDA_ODD_PARITY, b.Odd(), kappa);
208  } else if (matpcType == QUDA_MATPC_ODD_ODD) {
209  // x_e = b_e + k D_eo x_o
210  DslashXpay(x.Even(), x.Odd(), QUDA_EVEN_PARITY, b.Even(), kappa);
211  } else {
212  errorQuda("MatPCType %d not valid for DiracWilsonPC", matpcType);
213  }
214  }
215 
216 } // namespace quda
FaceBuffer face1
Definition: dirac_quda.h:148
cudaGaugeField & gauge
Definition: dirac_quda.h:88
void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
virtual void checkParitySpinor(const cudaColorSpinorField &, const cudaColorSpinorField &) const
Definition: dirac.cpp:84
unsigned long long flops
Definition: dirac_quda.h:93
virtual void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b, const QudaSolutionType) const
void prepare(cudaColorSpinorField *&src, cudaColorSpinorField *&sol, cudaColorSpinorField &x, cudaColorSpinorField &b, const QudaSolutionType) const
DiracWilson(const DiracParam &param)
DiracWilsonPC & operator=(const DiracWilsonPC &dirac)
DiracWilsonPC(const DiracParam &param)
#define errorQuda(...)
Definition: util_quda.h:73
bool newTmp(cudaColorSpinorField **, const cudaColorSpinorField &) const
Definition: dirac.cpp:51
TimeProfile profile
Definition: dirac_quda.h:104
DiracWilson & operator=(const DiracWilson &dirac)
cudaGaugeField * gauge
Definition: dirac_quda.h:30
int Ls
Definition: test_util.cpp:40
QudaGaugeParam param
Definition: pack_test.cpp:17
void checkSpinorAlias(const cudaColorSpinorField &, const cudaColorSpinorField &) const
Definition: dirac.cpp:129
cudaColorSpinorField & Odd() const
int commDim[QUDA_MAX_DIM]
Definition: dirac_quda.h:102
virtual void Dslash(cudaColorSpinorField &out, const cudaColorSpinorField &in, const QudaParity parity) const
cpuColorSpinorField * in
void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
enum QudaSolutionType_s QudaSolutionType
void deleteTmp(cudaColorSpinorField **, const bool &reset) const
Definition: dirac.cpp:59
Dirac * dirac
Definition: dslash_test.cpp:45
QudaDagType dagger
Definition: dirac_quda.h:92
enum QudaParity_s QudaParity
virtual void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
int x[4]
void wilsonDslashCuda(cudaColorSpinorField *out, const cudaGaugeField &gauge, const cudaColorSpinorField *in, const int oddBit, const int daggerBit, const cudaColorSpinorField *x, const double &k, const int *commDim, TimeProfile &profile, const QudaDslashPolicy &dslashPolicy=QUDA_DSLASH2)
double kappa
Definition: dirac_quda.h:89
QudaMatPCType matpcType
Definition: dirac_quda.h:91
void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b, const QudaSolutionType) const
FaceBuffer face2
Definition: dirac_quda.h:148
cudaColorSpinorField * tmp2
Definition: dirac_quda.h:95
virtual void checkFullSpinor(const cudaColorSpinorField &, const cudaColorSpinorField &) const
Definition: dirac.cpp:121
Dirac & operator=(const Dirac &dirac)
Definition: dirac.cpp:32
virtual void DslashXpay(cudaColorSpinorField &out, const cudaColorSpinorField &in, const QudaParity parity, const cudaColorSpinorField &x, const double &k) const
cpuColorSpinorField * out
cudaColorSpinorField * tmp1
Definition: dirac_quda.h:94
virtual void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
void Mdag(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
Definition: dirac.cpp:68
virtual ~DiracWilson()
virtual void prepare(cudaColorSpinorField *&src, cudaColorSpinorField *&sol, cudaColorSpinorField &x, cudaColorSpinorField &b, const QudaSolutionType) const
const QudaParity parity
Definition: dslash_test.cpp:29
void * gauge[4]
Definition: su3_test.cpp:15
cudaColorSpinorField & Even() const