QUDA
v0.5.0
A library for QCD on GPUs
Main Page
Namespaces
Classes
Files
File List
File Members
All
Classes
Namespaces
Files
Functions
Variables
Typedefs
Enumerations
Enumerator
Friends
Macros
Pages
quda
lib
dirac_staggered.cpp
Go to the documentation of this file.
1
#include <
dirac_quda.h
>
2
#include <
blas_quda.h
>
3
4
namespace
quda {
5
6
/**
 * Construct the unpreconditioned staggered Dirac operator.
 *
 * Initialises the Dirac base from @p param, binds the fat-link and long-link
 * gauge fields taken from it, builds the FaceBuffer from the fat-link
 * dimensions and precision, and finally calls initStaggeredConstants() on
 * the two gauge fields.
 *
 * @param param  operator parameters; param.fatGauge and param.longGauge are
 *               dereferenced unconditionally, so both must be non-null.
 */
DiracStaggered::DiracStaggered(const DiracParam &param)
  : Dirac(param),
    fatGauge(*param.fatGauge),
    longGauge(*param.longGauge),
    // NOTE(review): the literals 4, 6, 3 presumably mean (nDim, reals per
    // site, halo depth) -- confirm against the FaceBuffer constructor.
    // FIXME: this may break mixed precision multishift solver since may not
    // have fatGauge initializeed yet
    face(param.fatGauge->X(), 4, 6, 3, param.fatGauge->Precision())
{
  // Set up the staggered constants from the freshly bound gauge fields.
  initStaggeredConstants(fatGauge, longGauge);
}
13
14
/**
 * Copy-construct from another staggered operator.
 *
 * The Dirac base, both gauge-field members and the FaceBuffer are
 * initialised from @p dirac. initStaggeredConstants() is then re-run,
 * exactly as in the primary constructor.
 */
DiracStaggered::DiracStaggered(const DiracStaggered &dirac)
  : Dirac(dirac),
    fatGauge(dirac.fatGauge),
    longGauge(dirac.longGauge),
    face(dirac.face)
{
  initStaggeredConstants(fatGauge, longGauge);
}
19
20
/** Destructor: no explicit cleanup; member and base destructors run implicitly. */
DiracStaggered::~DiracStaggered()
{
}
21
22
/**
 * Copy-assign from another staggered operator.
 *
 * Self-assignment is a no-op. Otherwise the Dirac base is assigned first,
 * then the fat/long gauge members, then the FaceBuffer.
 *
 * NOTE(review): if fatGauge/longGauge are declared as references, the two
 * gauge assignments copy into the referenced objects rather than rebinding
 * them -- confirm this is intended. Unlike both constructors, this does not
 * call initStaggeredConstants().
 */
DiracStaggered& DiracStaggered::operator=(const DiracStaggered &dirac)
{
  if (&dirac == this) return *this;  // self-assignment: nothing to do

  Dirac::operator=(dirac);
  fatGauge  = dirac.fatGauge;
  longGauge = dirac.longGauge;
  face      = dirac.face;

  return *this;
}
32
33
void
DiracStaggered::checkParitySpinor
(
const
cudaColorSpinorField
&
in
,
const
cudaColorSpinorField
&
out
)
const
34
{
35
if
(in.
Precision
() != out.
Precision
()) {
36
errorQuda
(
"Input and output spinor precisions don't match in dslash_quda"
);
37
}
38
39
if
(in.
Stride
() != out.
Stride
()) {
40
errorQuda
(
"Input %d and output %d spinor strides don't match in dslash_quda"
, in.
Stride
(), out.
Stride
());
41
}
42
43
if
(in.
SiteSubset
() !=
QUDA_PARITY_SITE_SUBSET
|| out.
SiteSubset
() !=
QUDA_PARITY_SITE_SUBSET
) {
44
errorQuda
(
"ColorSpinorFields are not single parity, in = %d, out = %d"
,
45
in.
SiteSubset
(), out.
SiteSubset
());
46
}
47
48
if
((out.
Volume
() != 2*
fatGauge
.
VolumeCB
() && out.
SiteSubset
() ==
QUDA_FULL_SITE_SUBSET
) ||
49
(out.
Volume
() !=
fatGauge
.
VolumeCB
() && out.
SiteSubset
() ==
QUDA_PARITY_SITE_SUBSET
) ) {
50
errorQuda
(
"Spinor volume %d doesn't match gauge volume %d"
, out.
Volume
(),
fatGauge
.
VolumeCB
());
51
}
52
}
53
54
55
/**
 * Apply the single-parity staggered hopping term to @p in, writing @p out.
 *
 * @param out     output spinor (single parity)
 * @param in      input spinor (single parity)
 * @param parity  parity of the output field (e.g. MdagM passes ODD for an
 *                even input)
 */
void DiracStaggered::Dslash(cudaColorSpinorField &out, const cudaColorSpinorField &in,
                            const QudaParity parity) const
{
  checkParitySpinor(in, out);

  initSpinorConstants(in);
  setFace(face); // FIXME: temporary hack maintain C linkage for dslashCuda

  // Null accumulator field and zero scale: the non-xpay form of the kernel
  // (compare DslashXpay, which passes &x and k in these two slots).
  staggeredDslashCuda(&out, fatGauge, longGauge, &in, parity, dagger,
                      0 /* x */, 0 /* k */, commDim);

  // Flop count charged per site for this operation.
  flops += 1146ll * in.Volume();
}
66
67
/**
 * Apply the single-parity staggered hopping term and accumulate @p k * @p x.
 *
 * Presumably out = D in + k x (the "xpay" form). The exact combination and
 * sign convention are implemented in staggeredDslashCuda.
 *
 * @param out     output spinor (single parity)
 * @param in      input spinor (single parity)
 * @param parity  parity of the output field
 * @param x       field accumulated into the result
 * @param k       scale applied to x (M passes 2*mass, MdagM 4*mass*mass)
 */
void DiracStaggered::DslashXpay(cudaColorSpinorField &out, const cudaColorSpinorField &in,
                                const QudaParity parity, const cudaColorSpinorField &x,
                                const double &k) const
{
  checkParitySpinor(in, out);

  initSpinorConstants(in);
  setFace(face); // FIXME: temporary hack maintain C linkage for dslashCuda

  staggeredDslashCuda(&out, fatGauge, longGauge, &in, parity, dagger,
                      &x, k, commDim);

  // Flop count charged per site: the plain dslash plus the xpay update.
  flops += 1158ll * in.Volume();
}
79
80
// Full staggered operator
81
void
DiracStaggered::M
(
cudaColorSpinorField
&
out
,
const
cudaColorSpinorField
&
in
)
const
82
{
83
bool
reset =
newTmp
(&
tmp1
, in.
Even
());
84
85
DslashXpay
(out.
Even
(), in.
Odd
(),
QUDA_EVEN_PARITY
, *
tmp1
, 2*
mass
);
86
DslashXpay
(out.
Odd
(), in.
Even
(),
QUDA_ODD_PARITY
, *
tmp1
, 2*
mass
);
87
88
deleteTmp
(&tmp1, reset);
89
}
90
91
/**
 * Apply the normal operator M^dag M to a full field.
 *
 * As written, each output checkerboard depends only on the input
 * checkerboard of the same parity p. Both halves go through one scratch
 * buffer:
 *   scratch = D(in_p)                       (opposite parity)
 *   out_p   = D(scratch) (+) 4 m^2 in_p     (via DslashXpay)
 *
 * @param out  full output spinor
 * @param in   full input spinor
 */
void DiracStaggered::MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
{
  const bool freeTmp = newTmp(&tmp1, in);

  // Only the even half of the full-size temporary is used as scratch.
  cudaColorSpinorField *scratch = dynamic_cast<cudaColorSpinorField*>(&(tmp1->Even()));
  cudaColorSpinorField *srcEven = dynamic_cast<cudaColorSpinorField*>(&(in.Even()));
  cudaColorSpinorField *srcOdd  = dynamic_cast<cudaColorSpinorField*>(&(in.Odd()));
  cudaColorSpinorField *dstEven = dynamic_cast<cudaColorSpinorField*>(&(out.Even()));
  cudaColorSpinorField *dstOdd  = dynamic_cast<cudaColorSpinorField*>(&(out.Odd()));

  const double massTerm = 4*mass*mass;

  // even checkerboard
  Dslash(*scratch, *srcEven, QUDA_ODD_PARITY);
  DslashXpay(*dstEven, *scratch, QUDA_EVEN_PARITY, *srcEven, massTerm);

  // odd checkerboard
  Dslash(*scratch, *srcOdd, QUDA_EVEN_PARITY);
  DslashXpay(*dstOdd, *scratch, QUDA_ODD_PARITY, *srcOdd, massTerm);

  deleteTmp(&tmp1, freeTmp);
}
111
112
void
DiracStaggered::prepare
(
cudaColorSpinorField
* &src,
cudaColorSpinorField
* &sol,
113
cudaColorSpinorField
&
x
,
cudaColorSpinorField
&b,
114
const
QudaSolutionType
solType)
const
115
{
116
if
(solType ==
QUDA_MATPC_SOLUTION
|| solType ==
QUDA_MATPCDAG_MATPC_SOLUTION
) {
117
errorQuda
(
"Preconditioned solution requires a preconditioned solve_type"
);
118
}
119
120
src = &b;
121
sol = &
x
;
122
}
123
124
/**
 * Post-solve hook for the unpreconditioned operator.
 * prepare() hands the solver x and b directly, so nothing needs reconstructing.
 */
void DiracStaggered::reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b,
                                 const QudaSolutionType solType) const
{
  // intentionally empty
}
129
130
131
/** Construct the preconditioned operator; all setup is delegated to DiracStaggered. */
DiracStaggeredPC::DiracStaggeredPC(const DiracParam &param)
  : DiracStaggered(param)
{
}
136
137
/** Copy-construct; all state lives in, and is copied by, DiracStaggered. */
DiracStaggeredPC::DiracStaggeredPC(const DiracStaggeredPC &dirac)
  : DiracStaggered(dirac)
{
}
142
143
/** Destructor: no explicit cleanup; the base-class destructor runs implicitly. */
DiracStaggeredPC::~DiracStaggeredPC()
{
}
147
148
/** Copy-assign; self-assignment is a no-op, otherwise delegates to DiracStaggered. */
DiracStaggeredPC& DiracStaggeredPC::operator=(const DiracStaggeredPC &dirac)
{
  if (&dirac == this) return *this;

  DiracStaggered::operator=(dirac);
  return *this;
}
156
157
/**
 * Not implemented for the preconditioned staggered operator. Calling it
 * reports an error through errorQuda(); only MdagM is provided for this class.
 */
void DiracStaggeredPC::M(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
{
  errorQuda("DiracStaggeredPC::M() is not implemented\n");
}
161
162
/**
 * Apply the even-odd preconditioned normal operator on a single parity.
 *
 * The target parity comes from matpcType (EVEN_EVEN -> even, ODD_ODD -> odd).
 * Any other value is reported via errorQuda(). Then:
 *   tmp = D(in)                    (opposite parity)
 *   out = D(tmp) (+) 4 m^2 in      (target parity, via DslashXpay)
 *
 * @param out  single-parity output spinor
 * @param in   single-parity input spinor
 */
void DiracStaggeredPC::MdagM(cudaColorSpinorField &out, const cudaColorSpinorField &in) const
{
  const bool freeTmp = newTmp(&tmp1, in);

  QudaParity parity = QUDA_INVALID_PARITY;
  QudaParity other_parity = QUDA_INVALID_PARITY;
  switch (matpcType) {
  case QUDA_MATPC_EVEN_EVEN:
    parity = QUDA_EVEN_PARITY;
    other_parity = QUDA_ODD_PARITY;
    break;
  case QUDA_MATPC_ODD_ODD:
    parity = QUDA_ODD_PARITY;
    other_parity = QUDA_EVEN_PARITY;
    break;
  default:
    errorQuda("Invalid matpcType(%d) in function\n", matpcType);
    break;
  }

  Dslash(*tmp1, in, other_parity);
  DslashXpay(out, *tmp1, parity, in, 4*mass*mass);

  deleteTmp(&tmp1, freeTmp);
}
182
183
void
DiracStaggeredPC::prepare
(
cudaColorSpinorField
* &src,
cudaColorSpinorField
* &sol,
184
cudaColorSpinorField
&
x
,
cudaColorSpinorField
&b,
185
const
QudaSolutionType
solType)
const
186
{
187
src = &b;
188
sol = &
x
;
189
}
190
191
/**
 * Post-solve hook for the preconditioned operator.
 * No reconstruction is performed; x is left as the solver produced it.
 */
void DiracStaggeredPC::reconstruct(cudaColorSpinorField &x, const cudaColorSpinorField &b,
                                   const QudaSolutionType solType) const
{
  // intentionally empty
}
196
197
}
// namespace quda
Generated on Wed Mar 20 2013 12:52:14 for QUDA by doxygen 1.8.2