QUDA  v0.5.0
A library for QCD on GPUs
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
dirac_staggered.cpp
Go to the documentation of this file.
1 #include <dirac_quda.h>
2 #include <blas_quda.h>
3 
4 namespace quda {
5 
7  Dirac(param), fatGauge(*(param.fatGauge)), longGauge(*(param.longGauge)),
8  face(param.fatGauge->X(), 4, 6, 3, param.fatGauge->Precision())
9  //FIXME: this may break the mixed-precision multishift solver, since fatGauge may not have been initialized yet
10  {
12  }
13 
15  fatGauge(dirac.fatGauge), longGauge(dirac.longGauge), face(dirac.face)
16  {
18  }
19 
21 
23  { // Copy-assignment body: copies the base-class state and then the three members below.
24  if (&dirac != this) { // skip all work on self-assignment
25  Dirac::operator=(dirac); // let the Dirac base copy its own state first
26  fatGauge = dirac.fatGauge;
27  longGauge = dirac.longGauge;
28  face = dirac.face;
29  }
30  return *this; // returned in both the self-assignment and the copying case
31  }
32 
34  { // Consistency checks on an in/out spinor pair; each failed check reports through errorQuda().
35  if (in.Precision() != out.Precision()) { // both fields must report the same precision
36  errorQuda("Input and output spinor precisions don't match in dslash_quda");
37  }
38 
39  if (in.Stride() != out.Stride()) { // both fields must report the same stride
40  errorQuda("Input %d and output %d spinor strides don't match in dslash_quda", in.Stride(), out.Stride());
41  }
42 
44  errorQuda("ColorSpinorFields are not single parity, in = %d, out = %d",
45  in.SiteSubset(), out.SiteSubset());
46  } // NOTE(review): the condition that opens this branch (listing line 43) is not shown in this listing
47 
48  if ((out.Volume() != 2*fatGauge.VolumeCB() && out.SiteSubset() == QUDA_FULL_SITE_SUBSET) || // full subset: compared against twice the gauge VolumeCB()
49  (out.Volume() != fatGauge.VolumeCB() && out.SiteSubset() == QUDA_PARITY_SITE_SUBSET) ) { // parity subset: compared against VolumeCB() itself
50  errorQuda("Spinor volume %d doesn't match gauge volume %d", out.Volume(), fatGauge.VolumeCB()); // NOTE(review): prints VolumeCB() even when the failing comparison used 2*VolumeCB()
51  }
52  }
53 
54 
56  const QudaParity parity) const
57  { // Checks the in/out pair, then hands the work to staggeredDslashCuda() and updates flops.
58  checkParitySpinor(in, out); // runs before anything is launched
59 
61  setFace(face); // FIXME: temporary hack to maintain C linkage for dslashCuda
62  staggeredDslashCuda(&out, fatGauge, longGauge, &in, parity, dagger, 0, 0, commDim); // the two 0 arguments sit where the xpay variant in this file passes &x and k
63 
64  flops += 1146ll*in.Volume(); // adds 1146 * in.Volume() to the running flops total
65  }
66 
69  const double &k) const
70  { // Same call sequence as the plain Dslash in this file, but forwards the extra field x and scalar k.
71  checkParitySpinor(in, out); // only in and out are checked here; x is not
72 
74  setFace(face); // FIXME: temporary hack to maintain C linkage for dslashCuda
75  staggeredDslashCuda(&out, fatGauge, longGauge, &in, parity, dagger, &x, k, commDim); // x is passed by address, k by value
76 
77  flops += 1158ll*in.Volume(); // adds 1158 * in.Volume() to the running flops total
78  }
79 
80  // Full staggered operator
82  {
83  // Each parity half of out comes from one DslashXpay call on the opposite-parity
84  // half of in. Assuming DslashXpay forms out = Dslash(in) + k*x, the field x must
85  // be the same-parity half of in so that the 2*mass term acts on the input
86  // itself; no scratch field is needed for that.
87  DslashXpay(out.Even(), in.Odd(), QUDA_EVEN_PARITY, in.Even(), 2*mass);
88  DslashXpay(out.Odd(), in.Even(), QUDA_ODD_PARITY, in.Odd(), 2*mass);
89  }
90 
92  { // Two-pass operator: for each parity half, Dslash into a scratch field, then DslashXpay back with 4*mass*mass.
93  bool reset = newTmp(&tmp1, in); // reset is handed back to deleteTmp() at the end
94 
95  cudaColorSpinorField* mytmp = dynamic_cast<cudaColorSpinorField*>(&(tmp1->Even())); // only the Even() half of tmp1 is used, as scratch for both passes
96  cudaColorSpinorField* ineven = dynamic_cast<cudaColorSpinorField*>(&(in.Even()));
97  cudaColorSpinorField* inodd = dynamic_cast<cudaColorSpinorField*>(&(in.Odd()));
98  cudaColorSpinorField* outeven = dynamic_cast<cudaColorSpinorField*>(&(out.Even()));
99  cudaColorSpinorField* outodd = dynamic_cast<cudaColorSpinorField*>(&(out.Odd()));
100  // NOTE(review): none of the dynamic_cast results above is checked for null before being dereferenced below
101  //even
102  Dslash(*mytmp, *ineven, QUDA_ODD_PARITY); // scratch <- Dslash of the even half of in
103  DslashXpay(*outeven, *mytmp, QUDA_EVEN_PARITY, *ineven, 4*mass*mass); // even half of out; x is the even half of in, k is 4*mass*mass
104 
105  //odd
106  Dslash(*mytmp, *inodd, QUDA_EVEN_PARITY); // scratch is overwritten with Dslash of the odd half of in
107  DslashXpay(*outodd, *mytmp, QUDA_ODD_PARITY, *inodd, 4*mass*mass); // odd half of out; x is the odd half of in, k is 4*mass*mass
108 
109  deleteTmp(&tmp1, reset);
110  }
111 
114  const QudaSolutionType solType) const
115  { // Rejects the two preconditioned solution types, then points src/sol at the caller's b and x.
116  if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) { // both MATPC solution types are refused here
117  errorQuda("Preconditioned solution requires a preconditioned solve_type");
118  }
119 
120  src = &b; // b and x are used as they are; nothing is copied or transformed
121  sol = &x;
122  }
123 
125  const QudaSolutionType solType) const
126  {
127  // do nothing
128  }
129 
130 
132  : DiracStaggered(param)
133  {
134 
135  }
136 
138  : DiracStaggered(dirac)
139  {
140 
141  }
142 
144  {
145 
146  }
147 
149  {
150  if (&dirac != this) {
152  }
153 
154  return *this;
155  }
156 
158  {
159  errorQuda("DiracStaggeredPC::M() is not implemented\n");
160  }
161 
163  { // Picks parity/other_parity from matpcType, then applies Dslash into tmp1 followed by DslashXpay into out.
164  bool reset = newTmp(&tmp1, in); // reset is handed back to deleteTmp() at the end
165 
167  QudaParity other_parity = QUDA_INVALID_PARITY; // NOTE(review): the declaration of parity (listing line 166) is not shown in this listing
169  parity = QUDA_EVEN_PARITY; // NOTE(review): the if that opens this branch (listing line 168) is not shown in this listing
170  other_parity = QUDA_ODD_PARITY;
171  } else if (matpcType == QUDA_MATPC_ODD_ODD) {
172  parity = QUDA_ODD_PARITY;
173  other_parity = QUDA_EVEN_PARITY;
174  } else { // any other matpcType is reported as an error
175  errorQuda("Invalid matpcType(%d) in function\n", matpcType);
176  }
177  Dslash(*tmp1, in, other_parity); // tmp1 <- Dslash of in, using the opposite parity
178  DslashXpay(out, *tmp1, parity, in, 4*mass*mass); // out from tmp1; x is in, k is 4*mass*mass
179 
180  deleteTmp(&tmp1, reset);
181  }
182 
185  const QudaSolutionType solType) const
186  {
187  src = &b;
188  sol = &x;
189  }
190 
192  const QudaSolutionType solType) const
193  {
194  // do nothing
195  }
196 
197 } // namespace quda