QUDA  v0.7.0
A library for QCD on GPUs
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
staggered_dslash_reference.cpp
Go to the documentation of this file.
1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <math.h>
4 #include <string.h>
5 
6 #include <test_util.h>
7 #include <quda_internal.h>
8 #include <quda.h>
9 #include <util_quda.h>
11 #include "misc.h"
12 #include <blas_quda.h>
13 
14 #include <face_quda.h>
15 
16 extern void *memset(void *s, int c, size_t n);
17 
18 #include <dslash_util.h>
19 
20 //
21 // dslashReference()
22 //
23 // if oddBit is zero: calculate even parity spinor elements (using odd parity spinor)
24 // if oddBit is one: calculate odd parity spinor elements
25 //
26 // if daggerBit is zero: perform ordinary dslash operator
27 // if daggerBit is one: perform hermitian conjugate of dslash
28 //
/**
 * Print a 3x3 link matrix to stdout, one matrix row per output line,
 * followed by a blank line.
 *
 * Entry (i, j) is printed as the pair
 *   (link[i*3*2 + j*2], link[i*3*2 + j*2 + 1])
 * -- presumably the real and imaginary parts of a complex element
 * (NOTE(review): confirm against the gauge-field layout).
 *
 * @param link pointer to at least 18 readable values
 */
template<typename Float>
void display_link_internal(Float *link)
{
  for (int i = 0; i < 3; i++) {
    for (int j = 0; j < 3; j++) {
      printf("(%10f,%10f) \t", link[i*3*2 + j*2], link[i*3*2 + j*2 + 1]);
    }
    printf("\n");
  }
  printf("\n");
}
43 
44 
45 template <typename sFloat, typename gFloat>
46 void dslashReference(sFloat *res, gFloat **fatlink, gFloat** longlink, sFloat *spinorField,
47  int oddBit, int daggerBit)
48 {
49  for (int i=0; i<Vh*1*3*2; i++) res[i] = 0.0;
50 
51  gFloat *fatlinkEven[4], *fatlinkOdd[4];
52  gFloat *longlinkEven[4], *longlinkOdd[4];
53 
54  for (int dir = 0; dir < 4; dir++) {
55  fatlinkEven[dir] = fatlink[dir];
56  fatlinkOdd[dir] = fatlink[dir] + Vh*gaugeSiteSize;
57  longlinkEven[dir] =longlink[dir];
58  longlinkOdd[dir] = longlink[dir] + Vh*gaugeSiteSize;
59  }
60 
61  for (int i = 0; i < Vh; i++) {
62  memset(res + i*mySpinorSiteSize, 0, mySpinorSiteSize*sizeof(sFloat));
63  for (int dir = 0; dir < 8; dir++) {
64  gFloat* fatlnk = gaugeLink(i, dir, oddBit, fatlinkEven, fatlinkOdd, 1);
65  gFloat* longlnk = gaugeLink(i, dir, oddBit, longlinkEven, longlinkOdd, 3);
66 
67  sFloat *first_neighbor_spinor = spinorNeighbor(i, dir, oddBit, spinorField, 1);
68  sFloat *third_neighbor_spinor = spinorNeighbor(i, dir, oddBit, spinorField, 3);
69 
70 
71  sFloat gaugedSpinor[mySpinorSiteSize];
72 
73  if (dir % 2 == 0){
74  su3Mul(gaugedSpinor, fatlnk, first_neighbor_spinor);
75  sum(&res[i*mySpinorSiteSize], &res[i*mySpinorSiteSize], gaugedSpinor, mySpinorSiteSize);
76  su3Mul(gaugedSpinor, longlnk, third_neighbor_spinor);
77  sum(&res[i*mySpinorSiteSize], &res[i*mySpinorSiteSize], gaugedSpinor, mySpinorSiteSize);
78  } else {
79  su3Tmul(gaugedSpinor, fatlnk, first_neighbor_spinor);
80  sub(&res[i*mySpinorSiteSize], &res[i*mySpinorSiteSize], gaugedSpinor, mySpinorSiteSize);
81 
82  su3Tmul(gaugedSpinor, longlnk, third_neighbor_spinor);
83  sub(&res[i*mySpinorSiteSize], &res[i*mySpinorSiteSize], gaugedSpinor, mySpinorSiteSize);
84  }
85  }
86  if (daggerBit){
87  negx(&res[i*mySpinorSiteSize], mySpinorSiteSize);
88  }
89  }
90 
91 }
92 
93 
94 
95 
96 void staggered_dslash(void *res, void **fatlink, void** longlink, void *spinorField, int oddBit, int daggerBit,
97  QudaPrecision sPrecision, QudaPrecision gPrecision) {
98 
99  if (sPrecision == QUDA_DOUBLE_PRECISION) {
100  if (gPrecision == QUDA_DOUBLE_PRECISION){
101  dslashReference((double*)res, (double**)fatlink, (double**)longlink, (double*)spinorField, oddBit, daggerBit);
102  }else{
103  dslashReference((double*)res, (float**)fatlink, (float**)longlink, (double*)spinorField, oddBit, daggerBit);
104  }
105  }
106  else{
107  if (gPrecision == QUDA_DOUBLE_PRECISION){
108  dslashReference((float*)res, (double**)fatlink, (double**)longlink, (float*)spinorField, oddBit, daggerBit);
109  }else{
110  dslashReference((float*)res, (float**)fatlink, (float**)longlink, (float*)spinorField, oddBit, daggerBit);
111  }
112  }
113 }
114 
115 
116 
117 
118 template <typename sFloat, typename gFloat>
119 void Mat(sFloat *out, gFloat **fatlink, gFloat** longlink, sFloat *in, sFloat kappa, int daggerBit)
120 {
121  sFloat *inEven = in;
122  sFloat *inOdd = in + Vh*mySpinorSiteSize;
123  sFloat *outEven = out;
124  sFloat *outOdd = out + Vh*mySpinorSiteSize;
125 
126  // full dslash operator
127  dslashReference(outOdd, fatlink, longlink, inEven, 1, daggerBit);
128  dslashReference(outEven, fatlink, longlink, inOdd, 0, daggerBit);
129 
130  // lastly apply the kappa term
131  xpay(in, -kappa, out, V*mySpinorSiteSize);
132 }
133 
134 
135 void
136 mat(void *out, void **fatlink, void** longlink, void *in, double kappa, int dagger_bit,
137  QudaPrecision sPrecision, QudaPrecision gPrecision)
138 {
139 
140  if (sPrecision == QUDA_DOUBLE_PRECISION){
141  if (gPrecision == QUDA_DOUBLE_PRECISION) {
142  Mat((double*)out, (double**)fatlink, (double**)longlink, (double*)in, (double)kappa, dagger_bit);
143  }else {
144  Mat((double*)out, (float**)fatlink, (float**)longlink, (double*)in, (double)kappa, dagger_bit);
145  }
146  }else{
147  if (gPrecision == QUDA_DOUBLE_PRECISION){
148  Mat((float*)out, (double**)fatlink, (double**)longlink, (float*)in, (float)kappa, dagger_bit);
149  }else {
150  Mat((float*)out, (float**)fatlink, (float**)longlink, (float*)in, (float)kappa, dagger_bit);
151  }
152  }
153 }
154 
155 
156 
157 template <typename sFloat, typename gFloat>
158 void
159 Matdagmat(sFloat *out, gFloat **fatlink, gFloat** longlink, sFloat *in, sFloat mass, int daggerBit, sFloat* tmp, QudaParity parity)
160 {
161 
162  sFloat msq_x4 = mass*mass*4;
163 
164  switch(parity){
165  case QUDA_EVEN_PARITY:
166  {
167  sFloat *inEven = in;
168  sFloat *outEven = out;
169  dslashReference(tmp, fatlink, longlink, inEven, 1, daggerBit);
170  dslashReference(outEven, fatlink, longlink, tmp, 0, daggerBit);
171 
172  // lastly apply the mass term
173  axmy(inEven, msq_x4, outEven, Vh*mySpinorSiteSize);
174  break;
175  }
176  case QUDA_ODD_PARITY:
177  {
178  sFloat *inOdd = in;
179  sFloat *outOdd = out;
180  dslashReference(tmp, fatlink, longlink, inOdd, 0, daggerBit);
181  dslashReference(outOdd, fatlink, longlink, tmp, 1, daggerBit);
182 
183  // lastly apply the mass term
184  axmy(inOdd, msq_x4, outOdd, Vh*mySpinorSiteSize);
185  break;
186  }
187 
188  default:
189  fprintf(stderr, "ERROR: invalid parity in %s,line %d\n", __FUNCTION__, __LINE__);
190  break;
191  }
192 
193 }
194 
195 
196 
197 void
198 matdagmat(void *out, void **fatlink, void** longlink, void *in, double mass, int dagger_bit,
199  QudaPrecision sPrecision, QudaPrecision gPrecision, void* tmp, QudaParity parity)
200 {
201 
202  if (sPrecision == QUDA_DOUBLE_PRECISION){
203  if (gPrecision == QUDA_DOUBLE_PRECISION) {
204  Matdagmat((double*)out, (double**)fatlink, (double**)longlink, (double*)in, (double)mass, dagger_bit, (double*)tmp, parity);
205  }else {
206  Matdagmat((double*)out, (float**)fatlink, (float**)longlink, (double*)in, (double)mass, dagger_bit, (double*) tmp, parity);
207  }
208  }else{
209  if (gPrecision == QUDA_DOUBLE_PRECISION){
210  Matdagmat((float*)out, (double**)fatlink, (double**)longlink, (float*)in, (float)mass, dagger_bit, (float*)tmp, parity);
211  }else {
212  Matdagmat((float*)out, (float**)fatlink, (float**)longlink, (float*)in, (float)mass, dagger_bit, (float*)tmp, parity);
213  }
214  }
215 }
216 
217 
218 
219 
220 
221 // Apply the even-odd preconditioned Dirac operator
222 template <typename sFloat, typename gFloat>
223 static void MatPC(sFloat *outEven, gFloat **fatlink, gFloat** longlink, sFloat *inEven, sFloat kappa,
224  int daggerBit, QudaMatPCType matpc_type) {
225 
226  sFloat *tmp = (sFloat*)malloc(Vh*mySpinorSiteSize*sizeof(sFloat));
227 
228  // full dslash operator
229  if (matpc_type == QUDA_MATPC_EVEN_EVEN) {
230  dslashReference(tmp, fatlink, longlink, inEven, 1, daggerBit);
231  dslashReference(outEven, fatlink, longlink, tmp, 0, daggerBit);
232 
233  //dslashReference(outEven, fatlink, longlink, inEven, 1, daggerBit);
234  } else {
235  dslashReference(tmp, fatlink, longlink, inEven, 0, daggerBit);
236  dslashReference(outEven, fatlink, longlink, tmp, 1, daggerBit);
237  }
238 
239  // lastly apply the kappa term
240 
241  sFloat kappa2 = -kappa*kappa;
242  xpay(inEven, kappa2, outEven, Vh*mySpinorSiteSize);
243 
244  free(tmp);
245 }
246 
247 
248 void
249 staggered_matpc(void *outEven, void **fatlink, void**longlink, void *inEven, double kappa,
250  QudaMatPCType matpc_type, int dagger_bit, QudaPrecision sPrecision, QudaPrecision gPrecision)
251 {
252 
253  if (sPrecision == QUDA_DOUBLE_PRECISION)
254  if (gPrecision == QUDA_DOUBLE_PRECISION) {
255  MatPC((double*)outEven, (double**)fatlink, (double**)longlink, (double*)inEven, (double)kappa, dagger_bit, matpc_type);
256  }
257  else{
258  MatPC((double*)outEven, (double**)fatlink, (double**)longlink, (double*)inEven, (double)kappa, dagger_bit, matpc_type);
259  }
260  else {
261  if (gPrecision == QUDA_DOUBLE_PRECISION){
262  MatPC((float*)outEven, (double**)fatlink, (double**)longlink, (float*)inEven, (float)kappa, dagger_bit, matpc_type);
263  }else{
264  MatPC((float*)outEven, (float**)fatlink, (float**)longlink, (float*)inEven, (float)kappa, dagger_bit, matpc_type);
265  }
266  }
267 }
268 
269 #ifdef MULTI_GPU
270 
271 template <typename sFloat, typename gFloat>
272 void dslashReference_mg4dir(sFloat *res, gFloat **fatlink, gFloat** longlink,
273  gFloat** ghostFatlink, gFloat** ghostLonglink,
274  sFloat *spinorField, sFloat** fwd_nbr_spinor,
275  sFloat** back_nbr_spinor, int oddBit, int daggerBit)
276 {
277  for (int i=0; i<Vh*1*3*2; i++) res[i] = 0.0;
278 
279  int Vsh[4] = {Vsh_x, Vsh_y, Vsh_z, Vsh_t};
280  gFloat *fatlinkEven[4], *fatlinkOdd[4];
281  gFloat *longlinkEven[4], *longlinkOdd[4];
282  gFloat *ghostFatlinkEven[4], *ghostFatlinkOdd[4];
283  gFloat *ghostLonglinkEven[4], *ghostLonglinkOdd[4];
284 
285  for (int dir = 0; dir < 4; dir++) {
286  fatlinkEven[dir] = fatlink[dir];
287  fatlinkOdd[dir] = fatlink[dir] + Vh*gaugeSiteSize;
288  longlinkEven[dir] =longlink[dir];
289  longlinkOdd[dir] = longlink[dir] + Vh*gaugeSiteSize;
290 
291  ghostFatlinkEven[dir] = ghostFatlink[dir];
292  ghostFatlinkOdd[dir] = ghostFatlink[dir] + Vsh[dir]*gaugeSiteSize;
293  ghostLonglinkEven[dir] = ghostLonglink[dir];
294  ghostLonglinkOdd[dir] = ghostLonglink[dir] + 3*Vsh[dir]*gaugeSiteSize;
295  }
296 
297  for (int i = 0; i < Vh; i++) {
298  memset(res + i*mySpinorSiteSize, 0, mySpinorSiteSize*sizeof(sFloat));
299  for (int dir = 0; dir < 8; dir++) {
300  gFloat* fatlnk = gaugeLink_mg4dir(i, dir, oddBit, fatlinkEven, fatlinkOdd, ghostFatlinkEven, ghostFatlinkOdd, 1, 1);
301  gFloat* longlnk = gaugeLink_mg4dir(i, dir, oddBit, longlinkEven, longlinkOdd, ghostLonglinkEven, ghostLonglinkOdd, 3, 3);
302 
303  sFloat *first_neighbor_spinor = spinorNeighbor_mg4dir(i, dir, oddBit, spinorField, fwd_nbr_spinor, back_nbr_spinor, 1, 3);
304  sFloat *third_neighbor_spinor = spinorNeighbor_mg4dir(i, dir, oddBit, spinorField, fwd_nbr_spinor, back_nbr_spinor, 3, 3);
305 
306  sFloat gaugedSpinor[mySpinorSiteSize];
307 
308 
309  if (dir % 2 == 0){
310  su3Mul(gaugedSpinor, fatlnk, first_neighbor_spinor);
311  sum(&res[i*mySpinorSiteSize], &res[i*mySpinorSiteSize], gaugedSpinor, mySpinorSiteSize);
312  su3Mul(gaugedSpinor, longlnk, third_neighbor_spinor);
313  sum(&res[i*mySpinorSiteSize], &res[i*mySpinorSiteSize], gaugedSpinor, mySpinorSiteSize);
314  }
315  else{
316  su3Tmul(gaugedSpinor, fatlnk, first_neighbor_spinor);
317  sub(&res[i*mySpinorSiteSize], &res[i*mySpinorSiteSize], gaugedSpinor, mySpinorSiteSize);
318 
319  su3Tmul(gaugedSpinor, longlnk, third_neighbor_spinor);
320  sub(&res[i*mySpinorSiteSize], &res[i*mySpinorSiteSize], gaugedSpinor, mySpinorSiteSize);
321 
322  }
323 
324  }
325  if (daggerBit){
326  negx(&res[i*mySpinorSiteSize], mySpinorSiteSize);
327  }
328  }
329 
330 }
331 
332 
333 
334 void staggered_dslash_mg4dir(cpuColorSpinorField* out, void **fatlink, void** longlink, void** ghost_fatlink,
335  void** ghost_longlink, cpuColorSpinorField* in, int oddBit, int daggerBit,
336  QudaPrecision sPrecision, QudaPrecision gPrecision)
337 {
338 
339  QudaParity otherparity = QUDA_INVALID_PARITY;
340  if (oddBit == QUDA_EVEN_PARITY){
341  otherparity = QUDA_ODD_PARITY;
342  }else if (oddBit == QUDA_ODD_PARITY){
343  otherparity = QUDA_EVEN_PARITY;
344  }else{
345  errorQuda("ERROR: full parity not supported in function %s", __FUNCTION__);
346  }
347 
348  int Nc = 3;
349  int nFace = 3;
350  FaceBuffer faceBuf(Z, 4, 2*Nc, nFace, sPrecision);
351  faceBuf.exchangeCpuSpinor(*in, otherparity, daggerBit);
352 
353  void** fwd_nbr_spinor = in->fwdGhostFaceBuffer;
354  void** back_nbr_spinor = in->backGhostFaceBuffer;
355 
356  if (sPrecision == QUDA_DOUBLE_PRECISION) {
357  if (gPrecision == QUDA_DOUBLE_PRECISION){
358  dslashReference_mg4dir((double*)out->V(), (double**)fatlink, (double**)longlink,
359  (double**)ghost_fatlink, (double**)ghost_longlink, (double*)in->V(),
360  (double**)fwd_nbr_spinor, (double**)back_nbr_spinor, oddBit, daggerBit);
361  } else {
362  dslashReference_mg4dir((double*)out->V(), (float**)fatlink, (float**)longlink, (float**)ghost_fatlink, (float**)ghost_longlink,
363  (double*)in->V(), (double**)fwd_nbr_spinor, (double**)back_nbr_spinor, oddBit, daggerBit);
364  }
365  }
366  else{
367  if (gPrecision == QUDA_DOUBLE_PRECISION){
368  dslashReference_mg4dir((float*)out->V(), (double**)fatlink, (double**)longlink, (double**)ghost_fatlink, (double**)ghost_longlink,
369  (float*)in->V(), (float**)fwd_nbr_spinor, (float**)back_nbr_spinor, oddBit, daggerBit);
370  }else{
371  dslashReference_mg4dir((float*)out->V(), (float**)fatlink, (float**)longlink, (float**)ghost_fatlink, (float**)ghost_longlink,
372  (float*)in->V(), (float**)fwd_nbr_spinor, (float**)back_nbr_spinor, oddBit, daggerBit);
373  }
374  }
375 
376 
377 }
378 
379 void
380 matdagmat_mg4dir(cpuColorSpinorField* out, void **fatlink, void** longlink, void** ghost_fatlink, void** ghost_longlink,
381  cpuColorSpinorField* in, double mass, int dagger_bit,
382  QudaPrecision sPrecision, QudaPrecision gPrecision, cpuColorSpinorField* tmp, QudaParity parity)
383 {
384  //assert sPrecision and gPrecision must be the same
385  if (sPrecision != gPrecision){
386  errorQuda("Spinor precision and gPrecison is not the same");
387  }
388 
389  QudaParity otherparity = QUDA_INVALID_PARITY;
390  if (parity == QUDA_EVEN_PARITY){
391  otherparity = QUDA_ODD_PARITY;
392  }else if (parity == QUDA_ODD_PARITY){
393  otherparity = QUDA_EVEN_PARITY;
394  }else{
395  errorQuda("ERROR: full parity not supported in function %s\n", __FUNCTION__);
396  }
397 
398  staggered_dslash_mg4dir(tmp, fatlink, longlink, ghost_fatlink, ghost_longlink,
399  in, otherparity, dagger_bit, sPrecision, gPrecision);
400 
401  staggered_dslash_mg4dir(out, fatlink, longlink, ghost_fatlink, ghost_longlink,
402  tmp, parity, dagger_bit, sPrecision, gPrecision);
403 
404  double msq_x4 = mass*mass*4;
405  if (sPrecision == QUDA_DOUBLE_PRECISION){
406  axmy((double*)in->V(), (double)msq_x4, (double*)out->V(), Vh*mySpinorSiteSize);
407  }else{
408  axmy((float*)in->V(), (float)msq_x4, (float*)out->V(), Vh*mySpinorSiteSize);
409  }
410 
411 }
412 
413 #endif
414 
void display_link_internal(Float *link)
__constant__ int Vh
enum QudaPrecision_s QudaPrecision
int V
Definition: test_util.cpp:29
void matdagmat_mg4dir(cpuColorSpinorField *out, void **fatlink, void **longlink, void **ghost_fatlink, void **ghost_longlink, cpuColorSpinorField *in, double mass, int dagger_bit, QudaPrecision sPrecision, QudaPrecision gPrecision, cpuColorSpinorField *tmp, QudaParity parity)
void staggered_matpc(void *outEven, void **fatlink, void **longlink, void *inEven, double kappa, QudaMatPCType matpc_type, int dagger_bit, QudaPrecision sPrecision, QudaPrecision gPrecision)
__constant__ int Vsh
#define errorQuda(...)
Definition: util_quda.h:73
void staggered_dslash(void *res, void **fatlink, void **longlink, void *spinorField, int oddBit, int daggerBit, QudaPrecision sPrecision, QudaPrecision gPrecision)
void mat(void *out, void **fatlink, void **longlink, void *in, double kappa, int dagger_bit, QudaPrecision sPrecision, QudaPrecision gPrecision)
#define gaugeSiteSize
#define Vsh_t
Definition: llfat_core.h:4
void * longlink[4]
#define Vsh_z
Definition: llfat_core.h:3
cudaColorSpinorField * tmp
void matdagmat(void *out, void **fatlink, void **longlink, void *in, double mass, int dagger_bit, QudaPrecision sPrecision, QudaPrecision gPrecision, void *tmp, QudaParity parity)
VOLATILE spinorFloat kappa
FloatingPoint< float > Float
Definition: gtest.h:7350
cpuColorSpinorField * in
#define mySpinorSiteSize
enum QudaMatPCType_s QudaMatPCType
static void * backGhostFaceBuffer[QUDA_MAX_DIM]
enum QudaParity_s QudaParity
static void * fwdGhostFaceBuffer[QUDA_MAX_DIM]
cpuColorSpinorField * out
void * memset(void *s, int c, size_t n)
int Z[4]
Definition: test_util.cpp:28
Main header file for the QUDA library.
#define Vsh_y
Definition: llfat_core.h:2
QudaMatPCType matpc_type
Definition: test_util.cpp:1573
void staggered_dslash_mg4dir(cpuColorSpinorField *out, void **fatlink, void **longlink, void **ghost_fatlink, void **ghost_longlink, cpuColorSpinorField *in, int oddBit, int daggerBit, QudaPrecision sPrecision, QudaPrecision gPrecision)
#define Vsh_x
Definition: llfat_core.h:1
double mass
Definition: test_util.cpp:1569
VOLATILE spinorFloat * s
void dslashReference(sFloat *res, gFloat **fatlink, gFloat **longlink, sFloat *spinorField, int oddBit, int daggerBit)
void Matdagmat(sFloat *out, gFloat **fatlink, gFloat **longlink, sFloat *in, sFloat mass, int daggerBit, sFloat *tmp, QudaParity parity)
void Mat(sFloat *out, gFloat **fatlink, gFloat **longlink, sFloat *in, sFloat kappa, int daggerBit)
const QudaParity parity
Definition: dslash_test.cpp:29
int oddBit
void * fatlink[4]