QUDA  v0.7.0
A library for QCD on GPUs
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
dirac_twisted_mass.cpp
Go to the documentation of this file.
1 #include <dirac_quda.h>
2 #include <blas_quda.h>
3 #include <iostream>
4 
5 namespace quda {
6 
7  namespace twisted {
8 #include <dslash_init.cuh>
9  }
10 
11  namespace ndegtwisted {
12 #include <dslash_init.cuh>
13  }
14 
15  namespace dslash_aux {
16 #include <dslash_init.cuh>
17  }
18 
20  : DiracWilson(param, nDim), mu(param.mu), epsilon(param.epsilon)
21  {
22  twisted::initConstants(*param.gauge,profile);
23  ndegtwisted::initConstants(*param.gauge,profile);
24  dslash_aux::initConstants(*param.gauge,profile);
25  }
26 
28  : DiracWilson(dirac), mu(dirac.mu), epsilon(dirac.epsilon)
29  {
30  twisted::initConstants(dirac.gauge,profile);
31  ndegtwisted::initConstants(dirac.gauge,profile);
32  dslash_aux::initConstants(dirac.gauge,profile);
33  }
34 
36 
38  {
39  if (&dirac != this) {
41  }
42  return *this;
43  }
44 
45  // Protected method for applying twist
47  const QudaTwistGamma5Type twistType) const
48  {
49  checkParitySpinor(out, in);
50 
52  errorQuda("Twist flavor not set %d\n", in.TwistFlavor());
53 
55  double flavor_mu = in.TwistFlavor() * mu;
56  twistGamma5Cuda(&out, &in, dagger, kappa, flavor_mu, 0.0, twistType);
57  flops += 24ll*in.Volume();
58  } else {
59  errorQuda("DiracTwistedMass::twistedApply method for flavor doublet is not implemented..\n");
60  }
61  }
62 
63 
64  // Public method to apply the twist
66  {
68  }
69 
71  {
72  checkFullSpinor(out, in);
73  if (in.TwistFlavor() != out.TwistFlavor())
74  errorQuda("Twist flavors %d %d don't match", in.TwistFlavor(), out.TwistFlavor());
75 
77  errorQuda("Twist flavor not set %d\n", in.TwistFlavor());
78  }
79 
80  // We can eliminate this temporary at the expense of more kernels (like clover)
81  cudaColorSpinorField *tmp=0; // this hack allows for tmp2 to be full or parity field
82  if (tmp2) {
83  if (tmp2->SiteSubset() == QUDA_FULL_SITE_SUBSET) tmp = &(tmp2->Even());
84  else tmp = tmp2;
85  }
86  bool reset = newTmp(&tmp, in.Even());
87 
88  twisted::setFace(face1,face2); // FIXME: temporary hack maintain C linkage for dslashCuda
89  ndegtwisted::setFace(face1,face2); // FIXME: temporary hack maintain C linkage for dslashCuda
90 
92  double a = 2.0 * kappa * in.TwistFlavor() * mu;//for direct twist (must be daggered separately)
95  flops += (1320ll+72ll)*in.Volume();
96  } else {
97  double a = -2.0 * kappa * mu; //for twist
98  double b = -2.0 * kappa * epsilon;//for twist
101 
102  flops += (1320ll+72ll+24ll)*in.Volume();//??
103  }
104  deleteTmp(&tmp, reset);
105  }
106 
108  {
109  checkFullSpinor(out, in);
110  bool reset = newTmp(&tmp1, in);
111 
112  M(*tmp1, in);
113  Mdag(out, *tmp1);
114 
115  deleteTmp(&tmp1, reset);
116  }
117 
120  const QudaSolutionType solType) const
121  {
122  if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
123  errorQuda("Preconditioned solution requires a preconditioned solve_type");
124  }
125 
126  src = &b;
127  sol = &x;
128  }
129 
131  const QudaSolutionType solType) const
132  {
133  // do nothing
134  }
135 
136 
138 
140 
142  {
143 
144  }
145 
147  {
148  if (&dirac != this) {
150  }
151  return *this;
152  }
153 
154  // Public method to apply the inverse twist
156  {
158  }
159 
160  // apply hopping term, then inverse twist: (A_ee^-1 D_eo) or (A_oo^-1 D_oe),
161  // and likewise for dagger: (D^dagger_eo A_ee^-1) or (D^dagger_oe A_oo^-1)
164  {
165  checkParitySpinor(in, out);
166  checkSpinorAlias(in, out);
167 
168  if (in.TwistFlavor() != out.TwistFlavor())
169  errorQuda("Twist flavors %d %d don't match", in.TwistFlavor(), out.TwistFlavor());
171  errorQuda("Twist flavor not set %d\n", in.TwistFlavor());
172 
173  twisted::setFace(face1,face2); // FIXME: temporary hack maintain C linkage for dslashCuda
174  ndegtwisted::setFace(face1,face2); // FIXME: temporary hack maintain C linkage for dslashCuda
175 
177  double a = -2.0 * kappa * in.TwistFlavor() * mu; //for invert twist (not daggered)
178  double b = 1.0 / (1.0 + a*a); //for invert twist
179  if (!dagger || matpcType == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC || matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC) {
180  twistedMassDslashCuda(&out, gauge, &in, parity, dagger, 0, QUDA_DEG_DSLASH_TWIST_INV, a, b, 0.0, 0.0, commDim, profile);
181  flops += 1392ll*in.Volume();
182  } else {
183  twistedMassDslashCuda(&out, gauge, &in, parity, dagger, 0, QUDA_DEG_TWIST_INV_DSLASH, a, b, 0.0, 0.0, commDim, profile);
184  flops += 1392ll*in.Volume();
185  }
186  } else {//TWIST doublet :
187  double a = 2.0 * kappa * mu;
188  double b = 2.0 * kappa * epsilon;
189  double c = 1.0 / (1.0 + a*a - b*b);
190 
191  if (!dagger || matpcType == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC || matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC) {
192  ndegTwistedMassDslashCuda(&out, gauge, &in, parity, dagger, 0, QUDA_NONDEG_DSLASH, a, b, c, 0.0, commDim, profile);
193  flops += (1320ll+120ll)*in.Volume();//per flavor 1320+16*6(rotation per flavor)+24 (scaling per flavor)
194  } else {
195  cudaColorSpinorField *doubletTmp=0;
196  bool reset = newTmp(&doubletTmp, in);
197 
198  twistGamma5Cuda(doubletTmp, &in, dagger, -a, b, c, QUDA_TWIST_GAMMA5_INVERSE);//note a -> -a
199  ndegTwistedMassDslashCuda(&out, gauge, doubletTmp, parity, dagger, 0, QUDA_NONDEG_DSLASH, 0.0, 0.0, 1.0, 0.0, commDim, profile); //merge this!
200 
201  flops += 1440ll*in.Volume();//as for the asymmetric case
202 
203  deleteTmp(&doubletTmp, reset);
204  }
205  }
206  }
207 
208  // xpay version of the above
210  (cudaColorSpinorField &out, const cudaColorSpinorField &in, const QudaParity parity, const cudaColorSpinorField &x, const double &k) const
211  {
212  checkParitySpinor(in, out);
213  checkSpinorAlias(in, out);
214  if (in.TwistFlavor() != out.TwistFlavor())
215  errorQuda("Twist flavors %d %d don't match", in.TwistFlavor(), out.TwistFlavor());
217  errorQuda("Twist flavor not set %d\n", in.TwistFlavor());
218 
219  twisted::setFace(face1,face2); // FIXME: temporary hack maintain C linkage for dslashCuda
220  ndegtwisted::setFace(face1,face2); // FIXME: temporary hack maintain C linkage for dslashCuda
221 
223  double a = -2.0 * kappa * in.TwistFlavor() * mu; //for invert twist
224  double b = k / (1.0 + a*a); //for invert twist
225  if (!dagger) {
226  twistedMassDslashCuda(&out, gauge, &in, parity, dagger, &x, QUDA_DEG_DSLASH_TWIST_INV, a, b, 0.0, 0.0, commDim, profile);
227  flops += 1416ll*in.Volume();
228  } else { // tmp1 can alias in, but tmp2 can alias x so must not use this
229  twistedMassDslashCuda(&out, gauge, &in, parity, dagger, &x, QUDA_DEG_TWIST_INV_DSLASH, a, b, 0.0, 0.0, commDim, profile);
230  flops += 1416ll*in.Volume();
231  }
232  } else {//TWIST_DOUBLET:
233  double a = 2.0 * kappa * mu;
234  double b = 2.0 * kappa * epsilon;
235  double c = 1.0 / (1.0 + a*a - b*b);
236 
237  if (!dagger) {
238  c *= k;//(-kappa*kappa)
239  ndegTwistedMassDslashCuda(&out, gauge, &in, parity, dagger, &x, QUDA_NONDEG_DSLASH, a, b, c, 0.0, commDim, profile);
240  flops += 1464ll*in.Volume();
241  } else {
242  cudaColorSpinorField *doubletTmp=0;
243  bool reset = newTmp(&doubletTmp, in);
244  twistGamma5Cuda(doubletTmp, &in, dagger, -a, b, c, QUDA_TWIST_GAMMA5_INVERSE);//note a -> -a
245  ndegTwistedMassDslashCuda(&out, gauge, doubletTmp, parity, dagger, &x, QUDA_NONDEG_DSLASH, 0.0, 0.0, k, 0.0, commDim, profile);
246  flops += 1464ll*in.Volume();
247  deleteTmp(&doubletTmp, reset);
248  }
249  }
250  }
251 
253  {
254  double kappa2 = -kappa*kappa;
255 
256  bool reset = newTmp(&tmp1, in);
257 
260  Dslash(*tmp1, in, QUDA_ODD_PARITY);
261  DslashXpay(out, *tmp1, QUDA_EVEN_PARITY, in, kappa2);
262  } else if (matpcType == QUDA_MATPC_ODD_ODD) {
264  DslashXpay(out, *tmp1, QUDA_ODD_PARITY, in, kappa2);
265  } else {//asymmetric preconditioning
266  double a = 2.0 * kappa * in.TwistFlavor() * mu;
268  Dslash(*tmp1, in, QUDA_ODD_PARITY);
270  flops += (1320ll+96ll)*in.Volume();
274  flops += (1320ll+96ll)*in.Volume();
275  }else { // symmetric preconditioning
276  errorQuda("Invalid matpcType");
277  }
278  }
279  } else { //Twist doublet
281  Dslash(*tmp1, in, QUDA_ODD_PARITY);
282  DslashXpay(out, *tmp1, QUDA_EVEN_PARITY, in, kappa2);
283  } else if (matpcType == QUDA_MATPC_ODD_ODD){
284  Dslash(*tmp1, in, QUDA_EVEN_PARITY); // fused kernel
285  DslashXpay(out, *tmp1, QUDA_ODD_PARITY, in, kappa2);
286  } else {// asymmetric preconditioning
287  //Parameter for invert twist (note the implemented operator: c*(1 - i *a * gamma_5 tau_3 + b * tau_1)):
288  //double a = !dagger ? -2.0 * kappa * mu : 2.0 * kappa * mu;
289  double a = -2.0 * kappa * mu;
290  double b = -2.0 * kappa * epsilon;
291  double c = 1.0;
292 
294  Dslash(*tmp1, in, QUDA_ODD_PARITY);
296  flops += (1464ll)*in.Volume();
298  Dslash(*tmp1, in, QUDA_EVEN_PARITY); // fused kernel
300  flops += (1464ll)*in.Volume();
301  }
302  }
303  }
304  deleteTmp(&tmp1, reset);
305  }
306 
308  {
309  // need extra temporary because of symmetric preconditioning dagger
310  bool reset = newTmp(&tmp2, in);
311  M(*tmp2, in);
312  Mdag(out, *tmp2);
313  deleteTmp(&tmp2, reset);
314  }
315 
318  const QudaSolutionType solType) const
319  {
320  twisted::setFace(face1,face2); // FIXME: temporary hack maintain C linkage for dslashCuda
321  ndegtwisted::setFace(face1,face2); // FIXME: temporary hack maintain C linkage for dslashCuda
322 
323  // we desire solution to preconditioned system
324  if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
325  src = &b;
326  sol = &x;
327  return;
328  }
329 
330  bool reset = newTmp(&tmp1, b.Even());
331 
332  // we desire solution to full system
335  // src = A_ee^-1 (b_e + k D_eo A_oo^-1 b_o)
336  src = &(x.Odd());
337  TwistInv(*src, b.Odd());
339  TwistInv(*src, *tmp1);
340  sol = &(x.Even());
341  } else if (matpcType == QUDA_MATPC_ODD_ODD) {
342  // src = A_oo^-1 (b_o + k D_oe A_ee^-1 b_e)
343  src = &(x.Even());
344  TwistInv(*src, b.Even());
346  TwistInv(*src, *tmp1);
347  sol = &(x.Odd());
349  // src = b_e + k D_eo A_oo^-1 b_o
350  src = &(x.Odd());
351  TwistInv(*tmp1, b.Odd()); // safe even when *tmp1 = b.odd
353  sol = &(x.Even());
355  // src = b_o + k D_oe A_ee^-1 b_e
356  src = &(x.Even());
357  TwistInv(*tmp1, b.Even()); // safe even when *tmp1 = b.even
359  sol = &(x.Odd());
360  } else {
361  errorQuda("MatPCType %d not valid for DiracTwistedMassPC", matpcType);
362  }
363  } else {//doublet:
364  // we desire solution to preconditioned system
365 
366  double a = 2.0 * kappa * mu;
367  double bb = 2.0 * kappa * epsilon;
368 
369  double d = (1.0 + a*a - bb*bb);
370  if(d <= 0) errorQuda("Invalid twisted mass parameter\n");
371  double c = 1.0 / d;
372 
373  // we desire solution to full system
375  // src = A_ee^-1(b_e + k D_eo A_oo^-1 b_o)
376  src = &(x.Odd());
377 
378  twistGamma5Cuda(src, &b.Odd(), dagger, a, bb, c, QUDA_TWIST_GAMMA5_DIRECT);//temporal hack!
380  twistGamma5Cuda(src, tmp1, dagger, a, bb, c, QUDA_TWIST_GAMMA5_DIRECT);//temporal hack!
381 
382  sol = &(x.Even());
383 
384  } else if (matpcType == QUDA_MATPC_ODD_ODD) {
385  // src = A_oo^-1 (b_o + k D_oe A_ee^-1 b_e)
386  src = &(x.Even());
387 
388  twistGamma5Cuda(src, &b.Even(), dagger, a, bb, c, QUDA_TWIST_GAMMA5_DIRECT);//temporal hack!
390  twistGamma5Cuda(src, tmp1, dagger, a, bb, c, QUDA_TWIST_GAMMA5_DIRECT);//temporal hack!
391 
392  sol = &(x.Odd());
394  // src = b_e + k D_eo A_oo^-1 b_o
395  src = &(x.Odd());
396 
397  twistGamma5Cuda(tmp1, &b.Odd(), dagger, a, bb, c, QUDA_TWIST_GAMMA5_DIRECT);//temporal hack!
399 
400  sol = &(x.Even());
401 
403  // src = b_o + k D_oe A_ee^-1 b_e
404  src = &(x.Even());
405 
406  twistGamma5Cuda(tmp1, &b.Even(), dagger, a, bb, c, QUDA_TWIST_GAMMA5_DIRECT);//temporal hack!
408 
409  sol = &(x.Odd());
410  } else {
411  errorQuda("MatPCType %d not valid for DiracTwistedMassPC", matpcType);
412  }
413  }//end of doublet
414  // here we use final solution to store parity solution and parity source
415  // b is now up for grabs if we want
416 
417  deleteTmp(&tmp1, reset);
418  }
419 
421  const QudaSolutionType solType) const
422  {
423  twisted::setFace(face1,face2); // FIXME: temporary hack maintain C linkage for dslashCuda
424  ndegtwisted::setFace(face1,face2); // FIXME: temporary hack maintain C linkage for dslashCuda
425 
426  if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
427  return;
428  }
429 
430  checkFullSpinor(x, b);
431  bool reset = newTmp(&tmp1, b.Even());
432 
433  // create full solution
436  // x_o = A_oo^-1 (b_o + k D_oe x_e)
438  TwistInv(x.Odd(), *tmp1);
440  // x_e = A_ee^-1 (b_e + k D_eo x_o)
442  TwistInv(x.Even(), *tmp1);
443  } else {
444  errorQuda("MatPCType %d not valid for DiracTwistedMassPC", matpcType);
445  }
446  } else { //twist doublet:
447  double a = 2.0 * kappa * mu;
448  double bb = 2.0 * kappa * epsilon;
449 
450  double d = (1.0 + a*a - bb*bb);
451  if(d <= 0) errorQuda("Invalid twisted mass parameter\n");
452  double c = 1.0 / d;
453 
455  // x_o = A_oo^-1 (b_o + k D_oe x_e)
456 
459 
461  // x_e = A_ee^-1 (b_e + k D_eo x_o)
462 
465  } else {
466  errorQuda("MatPCType %d not valid for DiracTwistedMassPC", matpcType);
467  }
468  }//end of twist doublet...
469  deleteTmp(&tmp1, reset);
470  }
471 } // namespace quda
FaceBuffer face1
Definition: dirac_quda.h:148
int commDim(int)
cudaGaugeField & gauge
Definition: dirac_quda.h:88
virtual void checkParitySpinor(const cudaColorSpinorField &, const cudaColorSpinorField &) const
Definition: dirac.cpp:84
void prepare(cudaColorSpinorField *&src, cudaColorSpinorField *&sol, cudaColorSpinorField &x, cudaColorSpinorField &b, const QudaSolutionType) const
unsigned long long flops
Definition: dirac_quda.h:93
virtual void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
DiracTwistedMassPC(const DiracTwistedMassPC &dirac)
#define errorQuda(...)
Definition: util_quda.h:73
void TwistInv(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
__global__ void const RealA *const const RealA *const const RealA *const const RealB *const const RealB *const int int mu
bool newTmp(cudaColorSpinorField **, const cudaColorSpinorField &) const
Definition: dirac.cpp:51
void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b, const QudaSolutionType) const
DiracTwistedMassPC & operator=(const DiracTwistedMassPC &dirac)
TimeProfile profile
Definition: dirac_quda.h:104
DiracWilson & operator=(const DiracWilson &dirac)
DiracTwistedMass & operator=(const DiracTwistedMass &dirac)
virtual void Dslash(cudaColorSpinorField &out, const cudaColorSpinorField &in, const QudaParity parity) const
QudaDagType dagger
Definition: test_util.cpp:1558
cudaGaugeField * gauge
Definition: dirac_quda.h:30
QudaGaugeParam param
Definition: pack_test.cpp:17
cudaColorSpinorField & Odd() const
int commDim[QUDA_MAX_DIM]
Definition: dirac_quda.h:102
cudaColorSpinorField * tmp
void ndegTwistedMassDslashCuda(cudaColorSpinorField *out, const cudaGaugeField &gauge, const cudaColorSpinorField *in, const int parity, const int dagger, const cudaColorSpinorField *x, const QudaTwistDslashType type, const double &kappa, const double &mu, const double &epsilon, const double &k, const int *commDim, TimeProfile &profile, const QudaDslashPolicy &dslashPolicy=QUDA_DSLASH)
virtual void reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b, const QudaSolutionType) const
VOLATILE spinorFloat kappa
cpuColorSpinorField * in
enum QudaSolutionType_s QudaSolutionType
void deleteTmp(cudaColorSpinorField **, const bool &reset) const
Definition: dirac.cpp:59
Dirac * dirac
Definition: dslash_test.cpp:45
QudaDagType dagger
Definition: dirac_quda.h:92
enum QudaParity_s QudaParity
void twistGamma5Cuda(cudaColorSpinorField *out, const cudaColorSpinorField *in, const int dagger, const double &kappa, const double &mu, const double &epsilon, const QudaTwistGamma5Type)
ndeg tm:
Definition: dslash_quda.cu:356
virtual void prepare(cudaColorSpinorField *&src, cudaColorSpinorField *&sol, cudaColorSpinorField &x, cudaColorSpinorField &b, const QudaSolutionType) const
int x[4]
double kappa
Definition: dirac_quda.h:89
QudaMatPCType matpcType
Definition: dirac_quda.h:91
void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
FaceBuffer face2
Definition: dirac_quda.h:148
cudaColorSpinorField * tmp2
Definition: dirac_quda.h:95
virtual void checkFullSpinor(const cudaColorSpinorField &, const cudaColorSpinorField &) const
Definition: dirac.cpp:121
virtual void DslashXpay(cudaColorSpinorField &out, const cudaColorSpinorField &in, const QudaParity parity, const cudaColorSpinorField &x, const double &k) const
virtual void DslashXpay(cudaColorSpinorField &out, const cudaColorSpinorField &in, const QudaParity parity, const cudaColorSpinorField &x, const double &k) const
cpuColorSpinorField * out
cudaColorSpinorField * tmp1
Definition: dirac_quda.h:94
QudaTwistFlavorType TwistFlavor() const
void twistedApply(cudaColorSpinorField &out, const cudaColorSpinorField &in, const QudaTwistGamma5Type twistType) const
virtual void MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
enum QudaTwistGamma5Type_s QudaTwistGamma5Type
void M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
DiracTwistedMass(const DiracTwistedMass &dirac)
void Mdag(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
Definition: dirac.cpp:68
void twistedMassDslashCuda(cudaColorSpinorField *out, const cudaGaugeField &gauge, const cudaColorSpinorField *in, const int parity, const int dagger, const cudaColorSpinorField *x, const QudaTwistDslashType type, const double &kappa, const double &mu, const double &epsilon, const double &k, const int *commDim, TimeProfile &profile, const QudaDslashPolicy &dslashPolicy=QUDA_DSLASH2)
QudaSiteSubset SiteSubset() const
const QudaParity parity
Definition: dslash_test.cpp:29
void * gauge[4]
Definition: su3_test.cpp:15
cudaColorSpinorField & Even() const
void Twist(cudaColorSpinorField &out, const cudaColorSpinorField &in) const