QUDA v1.1.0
A library for QCD on GPUs
dirac_improved_staggered_kd.cpp
Go to the documentation of this file.
1 #include <dirac_quda.h>
2 #include <blas_quda.h>
3 #include <multigrid.h>
5 
6 namespace quda
7 {
8 
11  Xinv(param.xInvKD)
12  {
13  }
14 
17  Xinv(dirac.Xinv)
18  {
19  }
20 
22 
24  {
25  if (&dirac != this) {
27  Xinv = dirac.Xinv;
28  }
29  return *this;
30  }
31 
33  {
34  if (in.Ndim() != 5 || out.Ndim() != 5) { errorQuda("Staggered dslash requires 5-d fermion fields"); }
35 
36  if (in.Precision() != out.Precision()) {
37  errorQuda("Input and output spinor precisions don't match in dslash_quda");
38  }
39 
41  errorQuda("ColorSpinorFields are not full parity, in = %d, out = %d", in.SiteSubset(), out.SiteSubset());
42  }
43 
44  if (out.Volume() / out.X(4) != 2 * gauge->VolumeCB() && out.SiteSubset() == QUDA_FULL_SITE_SUBSET) {
45  errorQuda("Spinor volume %lu doesn't match gauge volume %lu", out.Volume(), gauge->VolumeCB());
46  }
47  }
48 
50  {
51  errorQuda("The improved staggered Kahler-Dirac operator does not have a single parity form");
52  }
53 
55  const ColorSpinorField &x, const double &k) const
56  {
57  errorQuda("The improved staggered Kahler-Dirac operator does not have a single parity form");
58  }
59 
60  // Full staggered operator
62  {
63  // Due to the staggered convention, the staggered part is applying
64  // ( 2m -D_eo ) (x_e) = (b_e)
65  // ( -D_oe 2m ) (x_o) = (b_o)
66  // ... but under the hood we need to catch the zero mass case.
67 
68  // TODO: add left vs right precond
69 
70  checkFullSpinor(out, in);
71 
72  bool reset = newTmp(&tmp2, in);
73 
74  bool right_block_precond = false;
75 
76  if (right_block_precond) {
77  if (dagger == QUDA_DAG_NO) {
78  // K-D op is right-block preconditioned
80  flops += (8ll * 48 - 2ll) * 48 * in.Volume() / 16; // for 2^4 block
81  if (mass == 0.) {
83  commDim, profile);
84  flops += 1146ll * in.Volume();
85  } else {
87  commDim, profile);
88  flops += 1158ll * in.Volume();
89  }
90  } else { // QUDA_DAG_YES
91 
92  if (mass == 0.) {
94  profile);
95  flops += 1146ll * in.Volume();
96  } else {
98  profile);
99  flops += 1158ll * in.Volume();
100  }
102  flops += (8ll * 48 - 2ll) * 48 * in.Volume() / 16; // for 2^4 block
103  }
104  } else { // left preconditioned
105  if (dagger == QUDA_DAG_NO) {
106 
107  if (mass == 0.) {
109  profile);
110  flops += 1146ll * in.Volume();
111  } else {
113  profile);
114  flops += 1158ll * in.Volume();
115  }
116 
118  flops += (8ll * 48 - 2ll) * 48 * in.Volume() / 16; // for 2^4 block
119 
120  } else { // QUDA_DAG_YES
121 
123  flops += (8ll * 48 - 2ll) * 48 * in.Volume() / 16; // for 2^4 block
124 
125  if (mass == 0.) {
127  commDim, profile);
128  flops += 1146ll * in.Volume();
129  } else {
131  commDim, profile);
132  flops += 1158ll * in.Volume();
133  }
134  }
135  }
136 
137  deleteTmp(&tmp2, reset);
138  }
139 
141  {
142  bool reset = newTmp(&tmp1, in);
143 
144  M(*tmp1, in);
145  Mdag(out, *tmp1);
146 
147  deleteTmp(&tmp1, reset);
148  }
149 
151  {
153  }
154 
156  ColorSpinorField &b, const QudaSolutionType solType) const
157  {
158  // TODO: technically KD is a different type of preconditioning.
159  // Should we support "preparing" and "reconstructing"?
160  if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
161  errorQuda("Preconditioned solution requires a preconditioned solve_type");
162  }
163 
164  src = &b;
165  sol = &x;
166  }
167 
169  ColorSpinorField &b, const QudaSolutionType solType) const
170  {
171  // TODO: technically KD is a different type of preconditioning.
172  // Should we support "preparing" and "reconstructing"?
173  if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
174  errorQuda("Preconditioned solution requires a preconditioned solve_type");
175  }
176 
177  checkFullSpinor(x, b);
178 
179  bool right_block_precond = false;
180 
181  if (right_block_precond) {
182  // need to modify the solution
183  src = &b;
184  sol = &x;
185  } else {
186  // need to modify rhs
187  bool reset = newTmp(&tmp1, b);
188 
189  KahlerDiracInv(*tmp1, b);
190  b = *tmp1;
191 
192  deleteTmp(&tmp1, reset);
193  sol = &x;
194  src = &b;
195  }
196  }
197 
199  const QudaSolutionType solType) const
200  {
201  // do nothing
202 
203  // TODO: technically KD is a different type of preconditioning.
204  // Should we support "preparing" and "reconstructing"?
205  }
206 
208  const QudaSolutionType solType) const
209  {
210  // do nothing
211 
212  // TODO: technically KD is a different type of preconditioning.
213  // Should we support "preparing" and "reconstructing"?
214 
215  checkFullSpinor(x, b);
216 
217  bool right_block_precond = false;
218 
219  if (right_block_precond) {
220  bool reset = newTmp(&tmp1, b.Even());
221 
222  KahlerDiracInv(*tmp1, x);
223  x = *tmp1;
224 
225  deleteTmp(&tmp1, reset);
226  }
227  // nothing required for left block preconditioning
228  }
229 
231  cudaGaugeField *long_gauge_in, cudaCloverField *clover_in)
232  {
233  Dirac::updateFields(fat_gauge_in, nullptr, nullptr, nullptr);
234  fatGauge = fat_gauge_in;
235  longGauge = long_gauge_in;
236 
237  // Recompute Xinv (I guess we should do that here?)
239  }
240 
242  double mass, double mu, double mu_factor) const
243  {
244  errorQuda("Staggered KD operators do not support MG coarsening yet");
245 
246  // if (T.getTransferType() != QUDA_TRANSFER_AGGREGATE)
247  // errorQuda("Staggered KD operators only support aggregation coarsening");
248  // StaggeredCoarseOp(Y, X, T, *fatGauge, Xinv, mass, QUDA_ASQTADKD_DIRAC, QUDA_MATPC_INVALID);
249  }
250 
252  {
254  if (Xinv != nullptr) Xinv->prefetch(mem_space, stream);
255  }
256 
257 } // namespace quda
QudaSiteSubset SiteSubset() const
const ColorSpinorField & Even() const
const int * X() const
unsigned long long flops
Definition: dirac_quda.h:150
bool newTmp(ColorSpinorField **, const ColorSpinorField &) const
Definition: dirac.cpp:72
virtual void updateFields(cudaGaugeField *gauge_in, cudaGaugeField *fat_gauge_in, cudaGaugeField *long_gauge_in, cudaCloverField *clover_in)
Update the internal gauge, fat gauge, long gauge, and clover field pointers as appropriate....
Definition: dirac_quda.h:360
cudaGaugeField * gauge
Definition: dirac_quda.h:144
double mass
Definition: dirac_quda.h:146
void deleteTmp(ColorSpinorField **, const bool &reset) const
Definition: dirac.cpp:83
ColorSpinorField * tmp1
Definition: dirac_quda.h:151
TimeProfile profile
Definition: dirac_quda.h:161
QudaDagType dagger
Definition: dirac_quda.h:149
virtual void checkFullSpinor(const ColorSpinorField &, const ColorSpinorField &) const
Check that full spinors are compatible (check geometry?)
Definition: dirac.cpp:138
ColorSpinorField * tmp2
Definition: dirac_quda.h:152
int commDim[QUDA_MAX_DIM]
Definition: dirac_quda.h:159
void Mdag(ColorSpinorField &out, const ColorSpinorField &in) const
Apply Mdag (the daggered operator of M).
Definition: dirac.cpp:92
DiracImprovedStaggered & operator=(const DiracImprovedStaggered &dirac)
cudaGaugeField * fatGauge
Definition: dirac_quda.h:1356
virtual void prefetch(QudaFieldLocation mem_space, qudaStream_t stream=0) const
If managed memory and prefetch is enabled, prefetch all relevant memory fields (fat+long links,...
cudaGaugeField * longGauge
Definition: dirac_quda.h:1357
DiracImprovedStaggeredKD(const DiracParam &param)
void createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, double kappa, double mass, double mu=0., double mu_factor=0.) const
Create the coarse improved staggered KD operator.
virtual void DslashXpay(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity, const ColorSpinorField &x, const double &k) const
Xpay version of Dslash.
virtual void prepareSpecialMG(ColorSpinorField *&src, ColorSpinorField *&sol, ColorSpinorField &x, ColorSpinorField &b, const QudaSolutionType solType) const
virtual void MdagM(ColorSpinorField &out, const ColorSpinorField &in) const
Apply MdagM operator which may be optimized.
virtual void checkParitySpinor(const ColorSpinorField &, const ColorSpinorField &) const
Check that parity spinors are usable (check geometry?)
virtual void reconstruct(ColorSpinorField &x, const ColorSpinorField &b, const QudaSolutionType) const
virtual void prefetch(QudaFieldLocation mem_space, qudaStream_t stream=0) const
If managed memory and prefetch is enabled, prefetch all relevant memory fields (gauge,...
void KahlerDiracInv(ColorSpinorField &out, const ColorSpinorField &in) const
virtual void prepare(ColorSpinorField *&src, ColorSpinorField *&sol, ColorSpinorField &x, ColorSpinorField &b, const QudaSolutionType) const
DiracImprovedStaggeredKD & operator=(const DiracImprovedStaggeredKD &dirac)
virtual void reconstructSpecialMG(ColorSpinorField &x, const ColorSpinorField &b, const QudaSolutionType solType) const
virtual void M(ColorSpinorField &out, const ColorSpinorField &in) const
Apply M for the Dirac operator, e.g. the Schur complement operator.
virtual void updateFields(cudaGaugeField *gauge_in, cudaGaugeField *fat_gauge_in, cudaGaugeField *long_gauge_in, cudaCloverField *clover_in)
Update the internal gauge, fat gauge, long gauge, and clover field pointers as appropriate....
virtual void Dslash(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity) const
Apply the 'dslash' operator for the DiracOp. This may be e.g. AD
QudaPrecision Precision() const
size_t VolumeCB() const
void prefetch(QudaFieldLocation mem_space, qudaStream_t stream=0) const
If managed memory and prefetch is enabled, prefetch the gauge field and buffers to the CPU or the GPU...
double kappa
double mass
double mu
quda::mgarray< double > mu_factor
GaugeCovDev * dirac
Definition: covdev_test.cpp:42
QudaParity parity
Definition: covdev_test.cpp:40
@ QUDA_DAG_NO
Definition: enum_quda.h:223
@ QUDA_DAG_YES
Definition: enum_quda.h:223
@ QUDA_FULL_SITE_SUBSET
Definition: enum_quda.h:333
@ QUDA_INVALID_PARITY
Definition: enum_quda.h:284
enum QudaSolutionType_s QudaSolutionType
enum QudaFieldLocation_s QudaFieldLocation
@ QUDA_MATPC_SOLUTION
Definition: enum_quda.h:159
@ QUDA_MATPCDAG_MATPC_SOLUTION
Definition: enum_quda.h:161
enum QudaParity_s QudaParity
void BuildStaggeredKahlerDiracInverse(GaugeField &Xinv, const cudaGaugeField &gauge, const double mass)
Build the Kahler-Dirac inverse block for KD operators.
qudaStream_t * stream
void ApplyStaggeredKahlerDiracInverse(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &Xinv, bool dagger)
Apply the (improved) staggered Kahler-Dirac inverse block to a color-spinor field.
void ApplyImprovedStaggered(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, const GaugeField &L, double a, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile)
Apply the improved staggered dslash operator to a color-spinor field.
QudaGaugeParam param
Definition: pack_test.cpp:18
cudaStream_t qudaStream_t
Definition: quda_api.h:9
#define errorQuda(...)
Definition: util_quda.h:120