QUDA  v0.5.0
A library for QCD on GPUs
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
dirac_twisted_mass.cpp
Go to the documentation of this file.
1 #include <dirac_quda.h>
2 #include <blas_quda.h>
3 #include <iostream>
4 
5 namespace quda {
6 
// Definition of the static flag that records which twisted-mass constants
// were initialised last. The methods below compare it against 1 or 2 before
// calling the init*Constants helpers and then store 1 (parity-spinor paths)
// or 2 (full-spinor paths).
7  int DiracTwistedMass::initTMFlag = 0;//set to 1 for parity spinors, and 2 for full spinors
8 
// Copy constructor: copy-constructs the DiracWilson base from `dirac` and
// copies its mu and epsilon members. The body is empty.
9  DiracTwistedMass::DiracTwistedMass(const DiracTwistedMass &dirac) : DiracWilson(dirac), mu(dirac.mu), epsilon(dirac.epsilon) { }
10 
// Construct from a parameter struct: forwards (param, nDim) to the DiracWilson
// base and takes mu and epsilon from param. The body is empty.
11  DiracTwistedMass::DiracTwistedMass(const DiracParam &param, const int nDim) : DiracWilson(param, nDim), mu(param.mu), epsilon(param.epsilon) { }
12 
14 
// Assignment-style body: guards against self-assignment and returns *this.
// NOTE(review): the signature (source lines 13-15) and the statement inside
// the guard (source line 18) are missing from this listing; restore them from
// the original file before relying on this block.
16  {
17  if (&dirac != this) {
19  }
20  return *this;
21  }
22 
23  // Protected method for applying twist
24 
// Applies the gamma5 twist to a parity spinor (the error text below names this
// method DiracTwistedMass::twistedApply).
//  - out, in:   validated with checkParitySpinor(); `in` supplies the site
//               subset, volume and twist flavor used here.
//  - twistType: forwarded unchanged to twistGamma5Cuda().
// The first call for a given site subset initialises the twisted-mass
// constants and records that in initTMFlag (1 = parity subset, 2 = full
// subset). Only the branch calling twistGamma5Cuda() does work and adds
// 24 flops per site of `in`; the else branch reports that the flavor doublet
// is not implemented.
// NOTE(review): source lines 25, 34, 42, 47 and 50 are missing from this
// listing (start of the signature, one statement in each init branch, and the
// `if` conditions guarding the errorQuda call and the twist branch). They are
// deliberately not reconstructed here.
26  const QudaTwistGamma5Type twistType) const
27  {
28  checkParitySpinor(out, in);
29 
30  if(initTMFlag != 1 && in.SiteSubset() == QUDA_PARITY_SITE_SUBSET){
// The stride used to be chosen with
//   (flavor != QUDA_TWIST_PLUS || flavor != QUDA_TWIST_MINUS) ? Volume()/2 : Volume()
// That condition holds for every flavor (assuming the two enumerators are
// distinct), so Volume()/2 was always selected and the alternative was dead
// code. The value actually used is kept, matching the other parity-subset
// initialisations in this file, which use Volume()/2 unconditionally.
32  const int flv_stride = in.Volume()/2;
33 
35  initTwistedMassConstants(flv_stride);
36  initTMFlag = 1;
37  }
38  else if(initTMFlag != 2 && in.SiteSubset() == QUDA_FULL_SITE_SUBSET){
// Same always-true selector as above: Volume()/4 was always chosen, as in the
// full-subset initialisation elsewhere in this file.
40  const int flv_stride = in.Volume()/4;//extract half-volume
41 
43  initTwistedMassConstants(flv_stride);
44  initTMFlag = 2;
45  }
46 
48  errorQuda("Twist flavor not set %d\n", in.TwistFlavor());
49 
51  {
// The flavor value is used as a multiplier on mu.
52  double flavor_mu = in.TwistFlavor() * mu;
53  twistGamma5Cuda(&out, &in, dagger, kappa, flavor_mu, 0.0, twistType);
54  flops += 24ll*in.Volume();
55  }
56  else
57  {
58  errorQuda("DiracTwistedMass::twistedApply method for flavor doublet is not implemented..\n");
59  }
60  }
61 
62 
63  // Public method to apply the twist
// NOTE(review): the signature (source line 64) and the single body statement
// (source line 66) are missing from this listing; only the braces survive.
65  {
67  }
68 
// Applies the operator to a full spinor `in`, writing `out`.
// Both fields must pass checkFullSpinor() and carry the same twist flavor.
// One branch builds each parity of `out` with Twist() followed by
// DslashXpay(); the other (else) branch uses twistGamma5Cuda() and
// twistedMassDslashCuda() with a = -2*kappa*mu, b = -2*kappa*epsilon, c = 1.
// A parity-sized temporary is taken from tmp2 when available (its even half
// if tmp2 is a full field) or created by newTmp(), and released at the end.
// NOTE(review): source lines 69, 75 and 87 are missing from this listing
// (the signature, the condition guarding the "Twist flavor not set" error,
// and the condition opening the first branch). They are not reconstructed.
70  {
71  checkFullSpinor(out, in);
72  if (in.TwistFlavor() != out.TwistFlavor())
73  errorQuda("Twist flavors %d %d don't match", in.TwistFlavor(), out.TwistFlavor());
74 
76  errorQuda("Twist flavor not set %d\n", in.TwistFlavor());
77  }
78 
79  // We can eliminate this temporary at the expense of more kernels (like clover)
80  cudaColorSpinorField *tmp=0; // this hack allows for tmp2 to be full or parity field
81  if (tmp2) {
82  if (tmp2->SiteSubset() == QUDA_FULL_SITE_SUBSET) tmp = &(tmp2->Even());
83  else tmp = tmp2;
84  }
85  bool reset = newTmp(&tmp, in.Even());
86 
// out_odd is built from Twist(in_odd) and in_even, then out_even symmetrically.
88  Twist(*tmp, in.Odd());
89  DslashXpay(out.Odd(), in.Even(), QUDA_ODD_PARITY, *tmp, -kappa);
90  Twist(*tmp, in.Even());
91  DslashXpay(out.Even(), in.Odd(), QUDA_EVEN_PARITY, *tmp, -kappa);
92  }
93  else{
94  //errorQuda("Method is not implemented %d\n", in.TwistFlavor());
95  double a = -2.0 * kappa * mu;
96  double b = -2.0 * kappa * epsilon;
97  //if(d <= 0) errorQuda("Invalid twisted mass parameter\n");
98  double c = 1.0;
99 
100  setFace(face); // FIXME: temporary hack maintain C linkage for dslashCuda
101 
// Full-spinor path: (re)initialise constants unless the flag already says 2.
102  if(initTMFlag != 2){
104  int flv_stride = in.Volume()/4;//if (in.SiteSubset() == QUDA_FULL_SITE_SUBSET)
105  initTwistedMassConstants(flv_stride);
106  initTMFlag = 2;
107  }
108 
109  twistGamma5Cuda(tmp, &in.Odd(), dagger, a, b, c, QUDA_TWIST_GAMMA5_DIRECT);
110  twistedMassDslashCuda(&out.Odd(), gauge, &in.Even(), QUDA_ODD_PARITY, dagger, tmp, 0.0, 0.0, -kappa, commDim);
111 
112  twistGamma5Cuda(tmp, &in.Even(), dagger, a, b, c, QUDA_TWIST_GAMMA5_DIRECT);
113  twistedMassDslashCuda(&out.Even(), gauge, &in.Odd(), QUDA_EVEN_PARITY, dagger, tmp, 0.0, 0.0, -kappa, commDim);
114 
// The `ll` suffix forces 64-bit multiplication, as in the other flop updates
// in this file. Without it the product is computed in the type of Volume()
// (presumably int, given its assignment to `int flv_stride` above) and can
// overflow on large lattices.
115  flops += (1320ll+72+24)*in.Volume();
116  }
117  deleteTmp(&tmp, reset);
118  }
119 
// Applies M and then Mdag: M(in) is written to tmp1 and Mdag(tmp1) to out.
// `in` and `out` must pass checkFullSpinor(). tmp1 is obtained through
// newTmp() and handed back through deleteTmp() with the returned flag.
// NOTE(review): the signature (source line 120) is missing from this listing.
121  {
122  checkFullSpinor(out, in);
123  bool reset = newTmp(&tmp1, in);
124 
125  M(*tmp1, in);
126  Mdag(out, *tmp1);
127 
128  deleteTmp(&tmp1, reset);
129  }
130 
// Solver set-up for this (unpreconditioned) operator: the two preconditioned
// solution types are rejected with errorQuda(); otherwise src is pointed at b
// and sol at x, with no data being modified.
// NOTE(review): the start of the signature (source lines 131-132) is missing
// from this listing.
133  const QudaSolutionType solType) const
134  {
135  if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
136  errorQuda("Preconditioned solution requires a preconditioned solve_type");
137  }
138 
139  src = &b;
140  sol = &x;
141  }
142 
// Counterpart of the set-up step above; it intentionally does nothing.
// NOTE(review): the start of the signature (source line 143) is missing from
// this listing.
144  const QudaSolutionType solType) const
145  {
146  // do nothing
147  }
148 
149 
151 
153 
// Empty body.
// NOTE(review): source lines 150-154 are missing from this listing, so the
// definition this body belongs to (and any one-line definitions preceding it)
// cannot be identified from here.
155  {
156 
157  }
158 
// Assignment-style body: guards against self-assignment and returns *this.
// NOTE(review): the signature (source line 159) and the statement inside the
// guard (source line 162) are missing from this listing.
160  {
161  if (&dirac != this) {
163  }
164  return *this;
165  }
166 
167  // Public method to apply the inverse twist
// NOTE(review): the signature (source line 168) and the single body statement
// (source line 170) are missing from this listing; only the braces survive.
169  {
171  }
172 
173  // apply hopping term, then inverse twist: (A_ee^-1 D_eo) or (A_oo^-1 D_oe),
174  // and likewise for dagger: (D^dagger_eo D_ee^-1) or (D^dagger_oe A_oo^-1)
177  {
178  checkParitySpinor(in, out);
179  checkSpinorAlias(in, out);
180 
181  if (in.TwistFlavor() != out.TwistFlavor())
182  errorQuda("Twist flavors %d %d don't match", in.TwistFlavor(), out.TwistFlavor());
184  errorQuda("Twist flavor not set %d\n", in.TwistFlavor());
185 
187  if (!dagger || matpcType == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC || matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC)
188  {
189  double flavor_mu = in.TwistFlavor() * mu;
190  setFace(face); // FIXME: temporary hack maintain C linkage for dslashCuda
191  if(initTMFlag != 1){
192  initSpinorConstants(in);
193  initTMFlag = 1;
194  }
195  twistedMassDslashCuda(&out, gauge, &in, parity, dagger, 0, kappa, flavor_mu, 0.0, commDim);
196  flops += 1392ll*in.Volume();
197  } else { // safe to use tmp2 here which may alias in
198  bool reset = newTmp(&tmp2, in);
199 
200  if(initTMFlag != 1){
201  initSpinorConstants(in);
202  initTMFlag = 1;
203  }
204 
205  TwistInv(*tmp2, in);
206  DiracWilson::Dslash(out, *tmp2, parity);
207 
208  flops += 72ll*in.Volume();
209 
210  // if the pointers alias, undo the twist
211  if (tmp2->V() == in.V()) Twist(*tmp2, *tmp2);
212 
213  deleteTmp(&tmp2, reset);
214  }
215  }
216  else{//TWIST doublet :
217  double a = 2.0 * kappa * mu;
218  double b = 2.0 * kappa * epsilon;
219 
220  double d = (1.0 + a*a - b*b);
221  if(d <= 0) errorQuda("Invalid twisted mass parameter\n");
222  double c = 1.0 / d;
223 
224  setFace(face); // FIXME: temporary hack maintain C linkage for dslashCuda
225 
226  if(initTMFlag != 1){
228  int flv_stride = in.Volume()/2;//if (in.SiteSubset() == QUDA_PARITY_SITE_SUBSET)
229  initTwistedMassConstants(flv_stride);
230  initTMFlag = 1;
231  }
232 
233  if (!dagger || matpcType == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC || matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC) {
234  twistedMassDslashCuda(&out, gauge, &in, parity, dagger, 0, /*'kappa' = */a, /*mu = */b, /*epsilon = */c, commDim);//need to set 2km 2ke and c
235  flops += (1320+72+24)*in.Volume();
236  }
237  else{
238  cudaColorSpinorField *doubletTmp=0;
239  bool reset = newTmp(&doubletTmp, in);
240 
241  a *= -1.0;
242  twistGamma5Cuda(doubletTmp, &in, dagger, a, b, c, QUDA_TWIST_GAMMA5_INVERSE);//??
243  twistedMassDslashCuda(&out, gauge, doubletTmp, parity, dagger, 0, /*kappa = */0.0, /*mu = */0.0, /*epsilon = */1.0, commDim);
244 
245  flops += 1416ll*in.Volume();
246 
247  deleteTmp(&doubletTmp, reset);
248  }
249  }
250  }
251 
252  // xpay version of the above
// Same structure as the routine above, with the extra field `x` and scalar `k`
// passed on to the kernels (or applied through xpayCuda on the dagger path of
// the first branch). `in` and `out` must be parity spinors with matching twist
// flavor that do not alias.
// NOTE(review): source lines 253, 260 and 263 are missing from this listing
// (the start of the signature, the condition guarding "Twist flavor not set",
// and the condition opening the first branch).
254  (cudaColorSpinorField &out, const cudaColorSpinorField &in, const QudaParity parity, const cudaColorSpinorField &x, const double &k) const
255  {
256  checkParitySpinor(in, out);
257  checkSpinorAlias(in, out);
258  if (in.TwistFlavor() != out.TwistFlavor())
259  errorQuda("Twist flavors %d %d don't match", in.TwistFlavor(), out.TwistFlavor());
261  errorQuda("Twist flavor not set %d\n", in.TwistFlavor());
262 
264  if (!dagger) {
265  double flavor_mu = in.TwistFlavor() * mu;
266  setFace(face); // FIXME: temporary hack maintain C linkage for dslashCuda
267  if(initTMFlag != 1){
268  initSpinorConstants(in);
269  initTMFlag = 1;
270  }
271  twistedMassDslashCuda(&out, gauge, &in, parity, dagger, &x, kappa, flavor_mu, k, commDim);
272  flops += 1416ll*in.Volume();
273  } else { // tmp1 can alias in, but tmp2 can alias x so must not use this
274  bool reset = newTmp(&tmp1, in);
275  if(initTMFlag != 1){
276  initSpinorConstants(in);
277  initTMFlag = 1;
278  }
279  TwistInv(*tmp1, in);
280  DiracWilson::Dslash(out, *tmp1, parity);
// x is declared const in the signature; the cast removes that qualifier for the call.
281  xpayCuda((cudaColorSpinorField&)x, k, out);
282  flops += 96ll*in.Volume();
283 
284  // if the pointers alias, undo the twist
285  if (tmp1->V() == in.V()) Twist(*tmp1, *tmp1);
286 
287  deleteTmp(&tmp1, reset);
288  }
289  }
290  else{//TWIST_DOUBLET:
291  double a = 2.0 * kappa * mu;
292  double b = 2.0 * kappa * epsilon;
293 
294  double d = (1.0 + a*a - b*b);
295  if(d <= 0) errorQuda("Invalid twisted mass parameter\n");
296  double c = 1.0 / d;
297 
298  setFace(face); // FIXME: temporary hack maintain C linkage for dslashCuda
299 
300  if(initTMFlag != 1){
301  initSpinorConstants(in);
302  int flv_stride = in.Volume()/2;//if (in.SiteSubset() == QUDA_PARITY_SITE_SUBSET)
303  initTwistedMassConstants(flv_stride);
304  initTMFlag = 1;
305  }
306 
307  if (!dagger) {
// k is folded into c before the fused kernel call.
308  c *= k;//(-kappa*kappa)
309  twistedMassDslashCuda(&out, gauge, &in, parity, dagger, &x, a, b, c, commDim);
310  flops += 1440ll*in.Volume();
311  }
312  else{
313  cudaColorSpinorField *doubletTmp=0;
314  bool reset = newTmp(&doubletTmp, in);
315 
316  a *= -1.0;
317  twistGamma5Cuda(doubletTmp, &in, dagger, a, b, c, QUDA_TWIST_GAMMA5_INVERSE);
// After the inverse twist, c is reused to carry k into the dslash call.
318  c = k;
319  twistedMassDslashCuda(&out, gauge, doubletTmp, parity, dagger, &x, 0.0, 0.0, c, commDim);
320  flops += 1440ll*in.Volume();
321  deleteTmp(&doubletTmp, reset);
322  }
323  }
324  }
325 
// Applies the even-odd preconditioned operator to `in`, writing `out`.
// kappa2 = -kappa^2 is the coefficient passed to every xpay-style call, and
// tmp1 holds the intermediate opposite-parity field. The branch taken depends
// on matpcType (symmetric or asymmetric, even-even or odd-odd) and, at the
// outer level, on a flavor test that is not visible here.
// NOTE(review): this listing is missing source lines 326, 331-333, 337-338,
// 342, 346, 354, 371, 374, 376 and 379 (the signature, most branch conditions
// and a few statements), so the surviving braces do not pair up as shown.
// Restore the lines from the original file before editing the logic.
327  {
328  double kappa2 = -kappa*kappa;
329 
330  bool reset = newTmp(&tmp1, in);
331 
334  Dslash(*tmp1, in, QUDA_ODD_PARITY);
335  Twist(out, in);
336  DiracWilson::DslashXpay(out, *tmp1, QUDA_EVEN_PARITY, out, kappa2);
339  Twist(out, in);
340  DiracWilson::DslashXpay(out, *tmp1, QUDA_ODD_PARITY, out, kappa2);
341  } else { // symmetric preconditioning
343  Dslash(*tmp1, in, QUDA_ODD_PARITY);
344  DslashXpay(out, *tmp1, QUDA_EVEN_PARITY, in, kappa2);
345  } else if (matpcType == QUDA_MATPC_ODD_ODD) {
347  DslashXpay(out, *tmp1, QUDA_ODD_PARITY, in, kappa2);
348  } else {
349  errorQuda("Invalid matpcType");
350  }
351  }
352  }//Twist doublet
353  else{
355  Dslash(*tmp1, in, QUDA_ODD_PARITY);
356  DslashXpay(out, *tmp1, QUDA_EVEN_PARITY, in, kappa2);
357  } else if (matpcType == QUDA_MATPC_ODD_ODD){
358  Dslash(*tmp1, in, QUDA_EVEN_PARITY); // fused kernel
359  DslashXpay(out, *tmp1, QUDA_ODD_PARITY, in, kappa2);
360  }
361  else {// asymmetric preconditioning
362  //Parameter for invert twist (note the implemented operator: c*(1 - i *a * gamma_5 tau_3 + b * tau_1)):
363 
// The sign of a depends on `dagger`; b and c are fixed.
364  double a = !dagger ? -2.0 * kappa * mu : 2.0 * kappa * mu;
365  double b = -2.0 * kappa * epsilon;
366  double c = 1.0;
367 
// Second temporary, local to this branch, released at source line 382.
368  cudaColorSpinorField *asymTmp=0;
369  bool reset_asym = newTmp(&asymTmp, in);
370 
372  Dslash(*tmp1, in, QUDA_ODD_PARITY);
373  twistGamma5Cuda(asymTmp, &in, dagger, a, b, c, QUDA_TWIST_GAMMA5_DIRECT);//direct due to c and b
375  twistedMassDslashCuda(&out, gauge, tmp1, QUDA_EVEN_PARITY, dagger, asymTmp, 0.0, 0.0, kappa2, commDim);
377  Dslash(*tmp1, in, QUDA_EVEN_PARITY); // fused kernel
378  twistGamma5Cuda(asymTmp, &in, dagger, a, b, c, QUDA_TWIST_GAMMA5_DIRECT);
380  twistedMassDslashCuda(&out, gauge, tmp1, QUDA_ODD_PARITY, dagger, asymTmp, 0.0, 0.0, kappa2, commDim);
381  }
382  deleteTmp(&asymTmp, reset_asym);
383  }
384  }
385  deleteTmp(&tmp1, reset);
386  }
387 
// Applies M and then Mdag, using tmp2 for the intermediate result: M(in) is
// written to tmp2 and Mdag(tmp2) to out. tmp2 is obtained through newTmp()
// and handed back through deleteTmp() with the returned flag.
// NOTE(review): the signature (source line 388) is missing from this listing.
389  {
390  // need extra temporary because of symmetric preconditioning dagger
391  bool reset = newTmp(&tmp2, in);
392  M(*tmp2, in);
393  Mdag(out, *tmp2);
394  deleteTmp(&tmp2, reset);
395  }
396 
// Solver set-up for the preconditioned operator. For the two preconditioned
// solution types it only points src at b and sol at x and returns. Otherwise
// it selects parity halves of x as the source and solution fields according
// to matpcType, using tmp1 as scratch space, and builds the preconditioned
// source there (the per-branch comments give the intended formula).
// The doublet branch uses a = 2*kappa*mu, bb = 2*kappa*epsilon and
// c = 1/(1 + a^2 - bb^2), rejecting non-positive denominators.
// NOTE(review): this listing is missing many source lines (397-398, 411-412,
// 416, 423, 426, 430, 432, 436, 453, 466, 484, 488, 501, 505 and 518): the
// signature, the outer flavor test, several matpcType conditions and every
// hopping-term call. Source line 519 is the tail of a call whose first line is
// missing. Restore these from the original file before editing the logic.
399  const QudaSolutionType solType) const
400  {
401  // we desire solution to preconditioned system
402  if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
403  src = &b;
404  sol = &x;
405  return;
406  }
407 
408  bool reset = newTmp(&tmp1, b.Even());
409 
410  // we desire solution to full system
413  // src = A_ee^-1 (b_e + k D_eo A_oo^-1 b_o)
414  src = &(x.Odd());
415  TwistInv(*src, b.Odd());
417  TwistInv(*src, *tmp1);
418  sol = &(x.Even());
419  } else if (matpcType == QUDA_MATPC_ODD_ODD) {
420  // src = A_oo^-1 (b_o + k D_oe A_ee^-1 b_e)
421  src = &(x.Even());
422  TwistInv(*src, b.Even());
424  TwistInv(*src, *tmp1);
425  sol = &(x.Odd());
427  // src = b_e + k D_eo A_oo^-1 b_o
428  src = &(x.Odd());
429  TwistInv(*tmp1, b.Odd()); // safe even when *tmp1 = b.odd
431  sol = &(x.Even());
433  // src = b_o + k D_oe A_ee^-1 b_e
434  src = &(x.Even());
435  TwistInv(*tmp1, b.Even()); // safe even when *tmp1 = b.even
437  sol = &(x.Odd());
438  } else {
439  errorQuda("MatPCType %d not valid for DiracTwistedMassPC", matpcType);
440  }
441  }
442  else{//doublet:
443  // we desire solution to preconditioned system
444 
445  double a = 2.0 * kappa * mu;
446  double bb = 2.0 * kappa * epsilon;
447 
448  double d = (1.0 + a*a - bb*bb);
449  if(d <= 0) errorQuda("Invalid twisted mass parameter\n");
450  double c = 1.0 / d;
451 
452  // we desire solution to full system
454  // src = A_ee^-1(b_e + k D_eo A_oo^-1 b_o)
455  src = &(x.Odd());
456  setFace(face); // FIXME: temporary hack maintain C linkage for dslashCuda
457  //
// Parity-spinor constants are initialised once; the same guard is repeated in
// each of the four matpcType branches below.
458  if(initTMFlag != 1){
459  initSpinorConstants(*src);
460  int flv_stride = src->Volume()/2;//if (in.SiteSubset() == QUDA_PARITY_SITE_SUBSET)
461  initTwistedMassConstants(flv_stride);
462  initTMFlag = 1;
463  }
464 
465  twistGamma5Cuda(src, &b.Odd(), dagger, a, bb, c, QUDA_TWIST_GAMMA5_DIRECT);//temporal hack!
467  twistGamma5Cuda(src, tmp1, dagger, a, bb, c, QUDA_TWIST_GAMMA5_DIRECT);//temporal hack!
468 
469  sol = &(x.Even());
470 
471  } else if (matpcType == QUDA_MATPC_ODD_ODD) {
472  // src = A_oo^-1 (b_o + k D_oe A_ee^-1 b_e)
473  src = &(x.Even());
474  setFace(face); // FIXME: temporary hack maintain C linkage for dslashCuda
475  //
476  if(initTMFlag != 1){
477  initSpinorConstants(*src);
478  int flv_stride = src->Volume()/2;//if (in.SiteSubset() == QUDA_PARITY_SITE_SUBSET)
479  initTwistedMassConstants(flv_stride);
480  initTMFlag = 1;
481  }
482 
483  twistGamma5Cuda(src, &b.Even(), dagger, a, bb, c, QUDA_TWIST_GAMMA5_DIRECT);//temporal hack!
485  twistGamma5Cuda(src, tmp1, dagger, a, bb, c, QUDA_TWIST_GAMMA5_DIRECT);//temporal hack!
486 
487  sol = &(x.Odd());
489  // src = b_e + k D_eo A_oo^-1 b_o
490  src = &(x.Odd());
491  setFace(face); // FIXME: temporary hack maintain C linkage for dslashCuda
492  //
493  if(initTMFlag != 1){
494  initSpinorConstants(*src);
495  int flv_stride = src->Volume()/2;//if (in.SiteSubset() == QUDA_PARITY_SITE_SUBSET)
496  initTwistedMassConstants(flv_stride);
497  initTMFlag = 1;
498  }
499 
500  twistGamma5Cuda(tmp1, &b.Odd(), dagger, a, bb, c, QUDA_TWIST_GAMMA5_DIRECT);//temporal hack!
502 
503  sol = &(x.Even());
504 
506  // src = b_o + k D_oe A_ee^-1 b_e
507  src = &(x.Even());
508  setFace(face); // FIXME: temporary hack maintain C linkage for dslashCuda
509  //
510  if(initTMFlag != 1){
511  initSpinorConstants(*src);
512  int flv_stride = src->Volume()/2;//if (in.SiteSubset() == QUDA_PARITY_SITE_SUBSET)
513  initTwistedMassConstants(flv_stride);
514  initTMFlag = 1;
515  }
516 
517  twistGamma5Cuda(tmp1, &b.Even(), dagger, a, bb, c, QUDA_TWIST_GAMMA5_DIRECT);//temporal hack!
519 commDim);
520 
521  sol = &(x.Odd());
522  } else {
523  errorQuda("MatPCType %d not valid for DiracTwistedMassPC", matpcType);
524  }
525  }//end of doublet
526  // here we use final solution to store parity solution and parity source
527  // b is now up for grabs if we want
528 
529  deleteTmp(&tmp1, reset);
530  }
531 
// Rebuilds the full-system solution after a preconditioned solve. It returns
// immediately for the two preconditioned solution types; otherwise x and b
// must pass checkFullSpinor() and the missing parity of x is filled in per
// matpcType (the per-branch comments give the intended formula), with tmp1 as
// scratch space. The doublet branch uses a = 2*kappa*mu, bb = 2*kappa*epsilon
// and c = 1/(1 + a^2 - bb^2), rejecting non-positive denominators.
// NOTE(review): this listing is missing many source lines (532, 543-544, 546,
// 548, 550, 564, 569, 575-576, 578 and 589-590): the signature, the outer
// flavor test, the matpcType conditions and every kernel/hopping-term call in
// the doublet branch. Restore them from the original file before editing.
533  const QudaSolutionType solType) const
534  {
535  if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) {
536  return;
537  }
538 
539  checkFullSpinor(x, b);
540  bool reset = newTmp(&tmp1, b.Even());
541 
542  // create full solution
545  // x_o = A_oo^-1 (b_o + k D_oe x_e)
547  TwistInv(x.Odd(), *tmp1);
549  // x_e = A_ee^-1 (b_e + k D_eo x_o)
551  TwistInv(x.Even(), *tmp1);
552  } else {
553  errorQuda("MatPCType %d not valid for DiracTwistedMassPC", matpcType);
554  }
555  }//twist doublet:
556  else{
557  double a = 2.0 * kappa * mu;
558  double bb = 2.0 * kappa * epsilon;
559 
560  double d = (1.0 + a*a - bb*bb);
561  if(d <= 0) errorQuda("Invalid twisted mass parameter\n");
562  double c = 1.0 / d;
563 
565  // x_o = A_oo^-1 (b_o + k D_oe x_e)
566  setFace(face); // FIXME: temporary hack maintain C linkage for dslashCuda
567  //
// Parity-spinor constants are initialised once per branch if not already set.
568  if(initTMFlag != 1){
570  int flv_stride = x.Even().Volume()/2;//if (in.SiteSubset() == QUDA_PARITY_SITE_SUBSET)
571  initTwistedMassConstants(flv_stride);
572  initTMFlag = 1;
573  }
574 
577 
579  // x_e = A_ee^-1 (b_e + k D_eo x_o)
580  setFace(face); // FIXME: temporary hack maintain C linkage for dslashCuda
581  //
582  if(initTMFlag != 1){
583  initSpinorConstants(x.Odd());
584  int flv_stride = x.Odd().Volume()/2;//if (in.SiteSubset() == QUDA_PARITY_SITE_SUBSET)
585  initTwistedMassConstants(flv_stride);
586  initTMFlag = 1;
587  }
588 
591  } else {
592  errorQuda("MatPCType %d not valid for DiracTwistedMassPC", matpcType);
593  }
594  }//end of twist doublet...
595  deleteTmp(&tmp1, reset);
596  }
597 } // namespace quda