QUDA  v1.1.0
A library for QCD on GPUs
dirac.cpp
Go to the documentation of this file.
1 #include <dirac_quda.h>
2 #include <dslash_quda.h>
3 #include <blas_quda.h>
4 
5 #include <iostream>
6 
7 namespace quda {
8 
9  // FIXME: At the moment, it's unsafe for more than one Dirac operator to be active unless
10  // they all have the same volume, etc. (used to initialize the various CUDA constants).
11 
13  gauge(param.gauge),
14  kappa(param.kappa),
15  mass(param.mass),
17  matpcType(param.matpcType),
19  flops(0),
20  tmp1(param.tmp1),
21  tmp2(param.tmp2),
22  type(param.type),
23  halo_precision(param.halo_precision),
24  profile("Dirac", false)
25  {
26  for (int i=0; i<4; i++) commDim[i] = param.commDim[i];
27  }
28 
30  gauge(dirac.gauge),
31  kappa(dirac.kappa),
33  matpcType(dirac.matpcType),
35  flops(0),
36  tmp1(dirac.tmp1),
37  tmp2(dirac.tmp2),
38  type(dirac.type),
39  halo_precision(dirac.halo_precision),
40  profile("Dirac", false)
41  {
42  for (int i=0; i<4; i++) commDim[i] = dirac.commDim[i];
43  }
44 
45  // Destroy
48  }
49 
50  // Assignment
52  {
53  if (&dirac != this) {
54  gauge = dirac.gauge;
55  kappa = dirac.kappa;
56  laplace3D = dirac.laplace3D;
57  matpcType = dirac.matpcType;
58  dagger = dirac.dagger;
59  flops = 0;
60  tmp1 = dirac.tmp1;
61  tmp2 = dirac.tmp2;
62 
63  for (int i=0; i<4; i++) commDim[i] = dirac.commDim[i];
64 
65  profile = dirac.profile;
66 
67  if (type != dirac.type) errorQuda("Trying to copy between incompatible types %d %d", type, dirac.type);
68  }
69  return *this;
70  }
71 
73  if (*tmp) return false;
75  param.create = QUDA_ZERO_FIELD_CREATE; // need to zero elements else padded region will be junk
76 
77  if (typeid(a) == typeid(cudaColorSpinorField)) *tmp = new cudaColorSpinorField(a, param);
78  else *tmp = new cpuColorSpinorField(param);
79 
80  return true;
81  }
82 
83  void Dirac::deleteTmp(ColorSpinorField **a, const bool &reset) const {
84  if (reset) {
85  delete *a;
86  *a = nullptr;
87  }
88  }
89 
90 #define flip(x) (x) = ((x) == QUDA_DAG_YES ? QUDA_DAG_NO : QUDA_DAG_YES)
91 
92  void Dirac::Mdag(ColorSpinorField &out, const ColorSpinorField &in) const
93  {
94  flip(dagger);
95  M(out, in);
96  flip(dagger);
97  }
98 
99  void Dirac::MMdag(ColorSpinorField &out, const ColorSpinorField &in) const
100  {
101  flip(dagger);
102  MdagM(out, in);
103  flip(dagger);
104  }
105 
106 #undef flip
107 
109  {
111  in.Nspin() == 4) {
112  errorQuda("CUDA Dirac operator requires UKQCD basis, out = %d, in = %d",
113  out.GammaBasis(), in.GammaBasis());
114  }
115 
117  errorQuda("ColorSpinorFields are not single parity: in = %d, out = %d",
118  in.SiteSubset(), out.SiteSubset());
119  }
120 
121  if (!static_cast<const cudaColorSpinorField&>(in).isNative()) errorQuda("Input field is not in native order");
122  if (!static_cast<const cudaColorSpinorField&>(out).isNative()) errorQuda("Output field is not in native order");
123 
124  if (out.Ndim() != 5) {
125  if ((out.Volume() != gauge->Volume() && out.SiteSubset() == QUDA_FULL_SITE_SUBSET) ||
126  (out.Volume() != gauge->VolumeCB() && out.SiteSubset() == QUDA_PARITY_SITE_SUBSET) ) {
127  errorQuda("Spinor volume %lu doesn't match gauge volume %lu", out.Volume(), gauge->VolumeCB());
128  }
129  } else {
130  // Domain wall fermions, compare 4d volumes not 5d
131  if ((out.Volume()/out.X(4) != gauge->Volume() && out.SiteSubset() == QUDA_FULL_SITE_SUBSET) ||
132  (out.Volume()/out.X(4) != gauge->VolumeCB() && out.SiteSubset() == QUDA_PARITY_SITE_SUBSET) ) {
133  errorQuda("Spinor volume %lu doesn't match gauge volume %lu", out.Volume(), gauge->VolumeCB());
134  }
135  }
136  }
137 
138  void Dirac::checkFullSpinor(const ColorSpinorField &out, const ColorSpinorField &in) const
139  {
141  errorQuda("ColorSpinorFields are not full fields: in = %d, out = %d",
142  in.SiteSubset(), out.SiteSubset());
143  }
144  }
145 
147  if (a.V() == b.V()) errorQuda("Aliasing pointers");
148  }
149 
150  // Dirac operator factory
152  {
153  if (param.type == QUDA_WILSON_DIRAC) {
154  if (getVerbosity() >= QUDA_DEBUG_VERBOSE) printfQuda("Creating a DiracWilson operator\n");
155  return new DiracWilson(param);
156  } else if (param.type == QUDA_WILSONPC_DIRAC) {
157  if (getVerbosity() >= QUDA_DEBUG_VERBOSE) printfQuda("Creating a DiracWilsonPC operator\n");
158  return new DiracWilsonPC(param);
159  } else if (param.type == QUDA_CLOVER_DIRAC) {
160  if (getVerbosity() >= QUDA_DEBUG_VERBOSE) printfQuda("Creating a DiracClover operator\n");
161  return new DiracClover(param);
163  if (getVerbosity() >= QUDA_DEBUG_VERBOSE) printfQuda("Creating a DiracCloverHasenbuschTwist operator\n");
164  return new DiracCloverHasenbuschTwist(param);
166  if (getVerbosity() >= QUDA_DEBUG_VERBOSE) printfQuda("Creating a DiracCloverHasenbuschTwistPC operator\n");
168  } else if (param.type == QUDA_CLOVERPC_DIRAC) {
169  if (getVerbosity() >= QUDA_DEBUG_VERBOSE) printfQuda("Creating a DiracCloverPC operator\n");
170  return new DiracCloverPC(param);
171  } else if (param.type == QUDA_DOMAIN_WALL_DIRAC) {
172  if (getVerbosity() >= QUDA_DEBUG_VERBOSE) printfQuda("Creating a DiracDomainWall operator\n");
173  return new DiracDomainWall(param);
174  } else if (param.type == QUDA_DOMAIN_WALLPC_DIRAC) {
175  if (getVerbosity() >= QUDA_DEBUG_VERBOSE) printfQuda("Creating a DiracDomainWallPC operator\n");
176  return new DiracDomainWallPC(param);
177  } else if (param.type == QUDA_DOMAIN_WALL_4D_DIRAC) {
178  if (getVerbosity() >= QUDA_DEBUG_VERBOSE) printfQuda("Creating a DiracDomainWall4D operator\n");
179  return new DiracDomainWall4D(param);
180  } else if (param.type == QUDA_DOMAIN_WALL_4DPC_DIRAC) {
181  if (getVerbosity() >= QUDA_DEBUG_VERBOSE) printfQuda("Creating a DiracDomainWall4DPC operator\n");
182  return new DiracDomainWall4DPC(param);
183  } else if (param.type == QUDA_MOBIUS_DOMAIN_WALL_DIRAC) {
184  if (getVerbosity() >= QUDA_DEBUG_VERBOSE) printfQuda("Creating a DiracMobius operator\n");
185  return new DiracMobius(param);
187  if (getVerbosity() >= QUDA_DEBUG_VERBOSE) printfQuda("Creating a DiracMobiusPC operator\n");
188  return new DiracMobiusPC(param);
190  if (getVerbosity() >= QUDA_DEBUG_VERBOSE) printfQuda("Creating a DiracMobiusEofa operator\n");
191  return new DiracMobiusEofa(param);
193  if (getVerbosity() >= QUDA_DEBUG_VERBOSE) printfQuda("Creating a DiracMobiusEofaPC operator\n");
194  return new DiracMobiusEofaPC(param);
195  } else if (param.type == QUDA_STAGGERED_DIRAC) {
196  if (getVerbosity() >= QUDA_DEBUG_VERBOSE) printfQuda("Creating a DiracStaggered operator\n");
197  return new DiracStaggered(param);
198  } else if (param.type == QUDA_STAGGEREDPC_DIRAC) {
199  if (getVerbosity() >= QUDA_DEBUG_VERBOSE) printfQuda("Creating a DiracStaggeredPC operator\n");
200  return new DiracStaggeredPC(param);
201  } else if (param.type == QUDA_STAGGEREDKD_DIRAC) {
202  if (getVerbosity() >= QUDA_DEBUG_VERBOSE) printfQuda("Creating a DiracStaggeredKD operator\n");
203  return new DiracStaggeredKD(param);
204  } else if (param.type == QUDA_ASQTAD_DIRAC) {
205  if (getVerbosity() >= QUDA_DEBUG_VERBOSE) printfQuda("Creating a DiracImprovedStaggered operator\n");
206  return new DiracImprovedStaggered(param);
207  } else if (param.type == QUDA_ASQTADPC_DIRAC) {
208  if (getVerbosity() >= QUDA_DEBUG_VERBOSE) printfQuda("Creating a DiracImprovedStaggeredPC operator\n");
209  return new DiracImprovedStaggeredPC(param);
210  } else if (param.type == QUDA_ASQTADKD_DIRAC) {
211  if (getVerbosity() >= QUDA_DEBUG_VERBOSE) printfQuda("Creating a DiracImprovedStaggeredKD operator\n");
212  return new DiracImprovedStaggeredKD(param);
213  } else if (param.type == QUDA_TWISTED_CLOVER_DIRAC) {
214  if (getVerbosity() >= QUDA_DEBUG_VERBOSE) printfQuda("Creating a DiracTwistedClover operator (%d flavor(s))\n", param.Ls);
215  if (param.Ls == 1) {
216  return new DiracTwistedClover(param, 4);
217  } else {
218  errorQuda("Cannot create DiracTwistedClover operator for %d flavors\n", param.Ls);
219  }
220  } else if (param.type == QUDA_TWISTED_CLOVERPC_DIRAC) {
221  if (getVerbosity() >= QUDA_DEBUG_VERBOSE) printfQuda("Creating a DiracTwistedCloverPC operator (%d flavor(s))\n", param.Ls);
222  if (param.Ls == 1) {
223  return new DiracTwistedCloverPC(param, 4);
224  } else {
225  errorQuda("Cannot create DiracTwistedCloverPC operator for %d flavors\n", param.Ls);
226  }
227  } else if (param.type == QUDA_TWISTED_MASS_DIRAC) {
228  if (getVerbosity() >= QUDA_DEBUG_VERBOSE) printfQuda("Creating a DiracTwistedMass operator (%d flavor(s))\n", param.Ls);
229  if (param.Ls == 1) return new DiracTwistedMass(param, 4);
230  else return new DiracTwistedMass(param, 5);
231  } else if (param.type == QUDA_TWISTED_MASSPC_DIRAC) {
233  printfQuda("Creating a DiracTwistedMassPC operator (%d flavor(s))\n", param.Ls);
234  if (param.Ls == 1)
235  return new DiracTwistedMassPC(param, 4);
236  else
237  return new DiracTwistedMassPC(param, 5);
238  } else if (param.type == QUDA_COARSE_DIRAC) {
239  if (getVerbosity() >= QUDA_DEBUG_VERBOSE) printfQuda("Creating a DiracCoarse operator\n");
240  return new DiracCoarse(param);
241  } else if (param.type == QUDA_COARSEPC_DIRAC) {
242  if (getVerbosity() >= QUDA_DEBUG_VERBOSE) printfQuda("Creating a DiracCoarsePC operator\n");
243  return new DiracCoarsePC(param);
244  } else if (param.type == QUDA_GAUGE_COVDEV_DIRAC) {
245  if (getVerbosity() >= QUDA_DEBUG_VERBOSE) printfQuda("Creating a GaugeCovDev operator\n");
246  return new GaugeCovDev(param);
247  } else if (param.type == QUDA_GAUGE_LAPLACE_DIRAC) {
248  if (getVerbosity() >= QUDA_DEBUG_VERBOSE) printfQuda("Creating a GaugeLaplace operator\n");
249  return new GaugeLaplace(param);
250  } else if (param.type == QUDA_GAUGE_LAPLACEPC_DIRAC) {
251  if (getVerbosity() >= QUDA_DEBUG_VERBOSE) printfQuda("Creating a GaugeLaplacePC operator\n");
252  return new GaugeLaplacePC(param);
253  } else {
254  errorQuda("Unsupported Dirac type %d", param.type);
255  }
256 
257  return nullptr;
258  }
259 
260  // Count the number of stencil applications per dslash application.
262  {
263  int steps = 0;
264  switch (type)
265  {
266  case QUDA_COARSE_DIRAC: // single fused operator
269  steps = 1;
270  break;
271  case QUDA_WILSON_DIRAC:
272  case QUDA_CLOVER_DIRAC:
277  case QUDA_ASQTAD_DIRAC:
280  steps = 2; // For D_{eo} and D_{oe} piece.
281  break;
282  case QUDA_WILSONPC_DIRAC:
283  case QUDA_CLOVERPC_DIRAC:
289  case QUDA_ASQTADPC_DIRAC:
292  case QUDA_COARSEPC_DIRAC:
294  steps = 2;
295  break;
296  default:
297  errorQuda("Unsupported Dslash type %d.\n", type);
298  steps = 0;
299  break;
300  }
301 
302  return steps;
303  }
304 
306  {
307  if (gauge) gauge->prefetch(mem_space, stream);
308  if (tmp1) tmp1->prefetch(mem_space, stream);
309  if (tmp2) tmp2->prefetch(mem_space, stream);
310  }
311 
312 } // namespace quda
QudaSiteSubset SiteSubset() const
QudaGammaBasis GammaBasis() const
const int * X() const
virtual void M(ColorSpinorField &out, const ColorSpinorField &in) const =0
Apply M for the dirac op. E.g. the Schur Complement operator.
unsigned long long flops
Definition: dirac_quda.h:150
QudaDiracType type
Definition: dirac_quda.h:153
bool newTmp(ColorSpinorField **, const ColorSpinorField &) const
Definition: dirac.cpp:72
void MMdag(ColorSpinorField &out, const ColorSpinorField &in) const
Apply the normal operator M Mdag.
Definition: dirac.cpp:99
double kappa
Definition: dirac_quda.h:145
virtual void MdagM(ColorSpinorField &out, const ColorSpinorField &in) const =0
Apply MdagM operator which may be optimized.
virtual void prefetch(QudaFieldLocation mem_space, qudaStream_t stream=0) const
If managed memory and prefetch is enabled, prefetch the gauge field and temporary spinors to the CPU ...
Definition: dirac.cpp:305
QudaMatPCType matpcType
Definition: dirac_quda.h:148
cudaGaugeField * gauge
Definition: dirac_quda.h:144
void deleteTmp(ColorSpinorField **, const bool &reset) const
Definition: dirac.cpp:83
Dirac(const DiracParam &param)
Definition: dirac.cpp:12
virtual ~Dirac()
Definition: dirac.cpp:46
virtual void checkParitySpinor(const ColorSpinorField &, const ColorSpinorField &) const
Check parity spinors are usable (check geometry ?)
Definition: dirac.cpp:108
ColorSpinorField * tmp1
Definition: dirac_quda.h:151
static Dirac * create(const DiracParam &param)
Creates a subclass from parameters.
Definition: dirac.cpp:151
TimeProfile profile
Definition: dirac_quda.h:161
QudaDagType dagger
Definition: dirac_quda.h:149
void checkSpinorAlias(const ColorSpinorField &, const ColorSpinorField &) const
check spinors do not alias
Definition: dirac.cpp:146
Dirac & operator=(const Dirac &dirac)
Definition: dirac.cpp:51
virtual void checkFullSpinor(const ColorSpinorField &, const ColorSpinorField &) const
check full spinors are compatible (check geometry ?)
Definition: dirac.cpp:138
ColorSpinorField * tmp2
Definition: dirac_quda.h:152
int getStencilSteps() const
Count the number of stencil applications per dslash application.
Definition: dirac.cpp:261
int commDim[QUDA_MAX_DIM]
Definition: dirac_quda.h:159
void Mdag(ColorSpinorField &out, const ColorSpinorField &in) const
Apply Mdag (daggered operator of M).
Definition: dirac.cpp:92
Full Covariant Derivative operator. Although not a Dirac operator per se, it's a linear operator so i...
Definition: dirac_quda.h:1858
Full Gauge Laplace operator. Although not a Dirac operator per se, it's a linear operator so it's con...
Definition: dirac_quda.h:1805
Even-odd preconditioned Gauge Laplace operator.
Definition: dirac_quda.h:1833
size_t Volume() const
size_t VolumeCB() const
virtual void prefetch(QudaFieldLocation mem_space, qudaStream_t stream=0) const
If managed memory and prefetch is enabled, prefetch all relevant memory fields to the current device ...
void Print()
Definition: timer.cpp:7
void prefetch(QudaFieldLocation mem_space, qudaStream_t stream=0) const
If managed memory and prefetch is enabled, prefetch the gauge field and buffers to the CPU or the GPU...
double kappa
int laplace3D
double mass
bool dagger
GaugeCovDev * dirac
Definition: covdev_test.cpp:42
cudaColorSpinorField * tmp
Definition: covdev_test.cpp:34
#define flip(x)
Definition: dirac.cpp:90
@ QUDA_TWISTED_MASSPC_DIRAC
Definition: enum_quda.h:312
@ QUDA_GAUGE_LAPLACE_DIRAC
Definition: enum_quda.h:317
@ QUDA_GAUGE_COVDEV_DIRAC
Definition: enum_quda.h:319
@ QUDA_TWISTED_CLOVERPC_DIRAC
Definition: enum_quda.h:314
@ QUDA_MOBIUS_DOMAIN_WALLPC_EOFA_DIRAC
Definition: enum_quda.h:304
@ QUDA_CLOVER_HASENBUSCH_TWIST_DIRAC
Definition: enum_quda.h:295
@ QUDA_TWISTED_MASS_DIRAC
Definition: enum_quda.h:311
@ QUDA_COARSEPC_DIRAC
Definition: enum_quda.h:316
@ QUDA_STAGGERED_DIRAC
Definition: enum_quda.h:305
@ QUDA_CLOVER_HASENBUSCH_TWISTPC_DIRAC
Definition: enum_quda.h:296
@ QUDA_DOMAIN_WALL_4D_DIRAC
Definition: enum_quda.h:299
@ QUDA_ASQTAD_DIRAC
Definition: enum_quda.h:308
@ QUDA_STAGGEREDPC_DIRAC
Definition: enum_quda.h:306
@ QUDA_MOBIUS_DOMAIN_WALL_EOFA_DIRAC
Definition: enum_quda.h:303
@ QUDA_ASQTADPC_DIRAC
Definition: enum_quda.h:309
@ QUDA_COARSE_DIRAC
Definition: enum_quda.h:315
@ QUDA_GAUGE_LAPLACEPC_DIRAC
Definition: enum_quda.h:318
@ QUDA_MOBIUS_DOMAIN_WALLPC_DIRAC
Definition: enum_quda.h:302
@ QUDA_ASQTADKD_DIRAC
Definition: enum_quda.h:310
@ QUDA_TWISTED_CLOVER_DIRAC
Definition: enum_quda.h:313
@ QUDA_DOMAIN_WALL_4DPC_DIRAC
Definition: enum_quda.h:300
@ QUDA_STAGGEREDKD_DIRAC
Definition: enum_quda.h:307
@ QUDA_CLOVER_DIRAC
Definition: enum_quda.h:293
@ QUDA_MOBIUS_DOMAIN_WALL_DIRAC
Definition: enum_quda.h:301
@ QUDA_DOMAIN_WALLPC_DIRAC
Definition: enum_quda.h:298
@ QUDA_DOMAIN_WALL_DIRAC
Definition: enum_quda.h:297
@ QUDA_WILSONPC_DIRAC
Definition: enum_quda.h:292
@ QUDA_CLOVERPC_DIRAC
Definition: enum_quda.h:294
@ QUDA_WILSON_DIRAC
Definition: enum_quda.h:291
@ QUDA_DEBUG_VERBOSE
Definition: enum_quda.h:268
@ QUDA_VERBOSE
Definition: enum_quda.h:267
@ QUDA_FULL_SITE_SUBSET
Definition: enum_quda.h:333
@ QUDA_PARITY_SITE_SUBSET
Definition: enum_quda.h:332
@ QUDA_UKQCD_GAMMA_BASIS
Definition: enum_quda.h:369
enum QudaFieldLocation_s QudaFieldLocation
@ QUDA_ZERO_FIELD_CREATE
Definition: enum_quda.h:361
unsigned long long flops
bool isNative(QudaCloverFieldOrder order, QudaPrecision precision)
Definition: clover_field.h:14
qudaStream_t * stream
QudaGaugeParam param
Definition: pack_test.cpp:18
cudaStream_t qudaStream_t
Definition: quda_api.h:9
QudaLinkType type
Definition: quda.h:41
#define printfQuda(...)
Definition: util_quda.h:114
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:21
#define errorQuda(...)
Definition: util_quda.h:120