QUDA  v1.1.0
A library for QCD on GPUs
staggered_dslash_reference.cpp
Go to the documentation of this file.
1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <math.h>
4 #include <string.h>
5 
6 #include <host_utils.h>
7 #include <quda_internal.h>
8 #include <quda.h>
9 #include <util_quda.h>
11 #include <command_line_params.h>
12 #include "misc.h"
13 #include <blas_quda.h>
14 
15 extern void *memset(void *s, int c, size_t n);
16 
17 #include <dslash_reference.h>
18 
// display_link_internal()
//
// Print a link to stdout as three rows of three "(re,im)" pairs, each row on its own
// line, followed by one blank line.
//
// link: array read at indices 0..17; the entry for (row, col) has its first value at
//       link[(row * 3 + col) * 2] and its second value at the following index.
template <typename Float> void display_link_internal(Float *link)
{
  for (int row = 0; row < 3; ++row) {
    // walk the row two values at a time
    Float *entry = link + row * 3 * 2;
    for (int col = 0; col < 3; ++col, entry += 2) { printf("(%10f,%10f) \t", entry[0], entry[1]); }
    printf("\n");
  }
  printf("\n");
}
30 
31 // staggeredDslashReferenece()
32 //
33 // if oddBit is zero: calculate even parity spinor elements (using odd parity spinor)
34 // if oddBit is one: calculate odd parity spinor elements
35 // if daggerBit is zero: perform ordinary dslash operator
36 // if daggerBit is one: perform hermitian conjugate of dslash
37 template <typename sFloat, typename gFloat>
38 void staggeredDslashReference(sFloat *res, gFloat **fatlink, gFloat **longlink, gFloat **ghostFatlink,
39  gFloat **ghostLonglink, sFloat *spinorField, sFloat **fwd_nbr_spinor,
40  sFloat **back_nbr_spinor, int oddBit, int daggerBit, int nSrc, QudaDslashType dslash_type)
41 {
42  for (int i = 0; i < Vh * stag_spinor_site_size * nSrc; i++) res[i] = 0.0;
43 
44  gFloat *fatlinkEven[4], *fatlinkOdd[4];
45  gFloat *longlinkEven[4], *longlinkOdd[4];
46 
47 #ifdef MULTI_GPU
48  gFloat *ghostFatlinkEven[4], *ghostFatlinkOdd[4];
49  gFloat *ghostLonglinkEven[4], *ghostLonglinkOdd[4];
50 #endif
51 
52  for (int dir = 0; dir < 4; dir++) {
53  fatlinkEven[dir] = fatlink[dir];
54  fatlinkOdd[dir] = fatlink[dir] + Vh * gauge_site_size;
55  longlinkEven[dir] = longlink[dir];
56  longlinkOdd[dir] = longlink[dir] + Vh * gauge_site_size;
57 
58 #ifdef MULTI_GPU
59  ghostFatlinkEven[dir] = ghostFatlink[dir];
60  ghostFatlinkOdd[dir] = ghostFatlink[dir] + (faceVolume[dir] / 2) * gauge_site_size;
61  ghostLonglinkEven[dir] = ghostLonglink[dir];
62  ghostLonglinkOdd[dir] = ghostLonglink[dir] + 3 * (faceVolume[dir] / 2) * gauge_site_size;
63 #endif
64  }
65 
66  for (int xs = 0; xs < nSrc; xs++) {
67 
68  for (int i = 0; i < Vh; i++) {
69  int sid = i + xs * Vh;
70  int offset = stag_spinor_site_size * sid;
71 
72  for (int dir = 0; dir < 8; dir++) {
73 #ifdef MULTI_GPU
74  const int nFace = dslash_type == QUDA_ASQTAD_DSLASH ? 3 : 1;
75  gFloat *fatlnk
76  = gaugeLink_mg4dir(i, dir, oddBit, fatlinkEven, fatlinkOdd, ghostFatlinkEven, ghostFatlinkOdd, 1, 1);
77  gFloat *longlnk = dslash_type == QUDA_ASQTAD_DSLASH ?
78  gaugeLink_mg4dir(i, dir, oddBit, longlinkEven, longlinkOdd, ghostLonglinkEven, ghostLonglinkOdd, 3, 3) :
79  nullptr;
80  sFloat *first_neighbor_spinor = spinorNeighbor_5d_mgpu<QUDA_4D_PC>(
81  sid, dir, oddBit, spinorField, fwd_nbr_spinor, back_nbr_spinor, 1, nFace, stag_spinor_site_size);
82  sFloat *third_neighbor_spinor = dslash_type == QUDA_ASQTAD_DSLASH ?
83  spinorNeighbor_5d_mgpu<QUDA_4D_PC>(sid, dir, oddBit, spinorField, fwd_nbr_spinor, back_nbr_spinor, 3, nFace,
85  nullptr;
86 #else
87  gFloat *fatlnk = gaugeLink(i, dir, oddBit, fatlinkEven, fatlinkOdd, 1);
88  gFloat *longlnk
89  = dslash_type == QUDA_ASQTAD_DSLASH ? gaugeLink(i, dir, oddBit, longlinkEven, longlinkOdd, 3) : nullptr;
90  sFloat *first_neighbor_spinor
91  = spinorNeighbor_5d<QUDA_4D_PC>(sid, dir, oddBit, spinorField, 1, stag_spinor_site_size);
92  sFloat *third_neighbor_spinor = dslash_type == QUDA_ASQTAD_DSLASH ?
93  spinorNeighbor_5d<QUDA_4D_PC>(sid, dir, oddBit, spinorField, 3, stag_spinor_site_size) :
94  nullptr;
95 #endif
96  sFloat gaugedSpinor[stag_spinor_site_size];
97 
98  if (dir % 2 == 0) {
99  su3Mul(gaugedSpinor, fatlnk, first_neighbor_spinor);
100  sum(&res[offset], &res[offset], gaugedSpinor, stag_spinor_site_size);
101 
103  su3Mul(gaugedSpinor, longlnk, third_neighbor_spinor);
104  sum(&res[offset], &res[offset], gaugedSpinor, stag_spinor_site_size);
105  }
106  } else {
107  su3Tmul(gaugedSpinor, fatlnk, first_neighbor_spinor);
109  sum(&res[offset], &res[offset], gaugedSpinor, stag_spinor_site_size);
110  } else {
111  sub(&res[offset], &res[offset], gaugedSpinor, stag_spinor_site_size);
112  }
113 
115  su3Tmul(gaugedSpinor, longlnk, third_neighbor_spinor);
116  sub(&res[offset], &res[offset], gaugedSpinor, stag_spinor_site_size);
117  }
118  }
119  }
120 
121  if (daggerBit) negx(&res[offset], stag_spinor_site_size);
122  } // 4-d volume
123  } // right-hand-side
124 }
125 
126 void staggeredDslash(ColorSpinorField *out, void **fatlink, void **longlink, void **ghost_fatlink,
127  void **ghost_longlink, ColorSpinorField *in, int oddBit, int daggerBit, QudaPrecision sPrecision,
129 {
130  const int nSrc = in->X(4);
131 
132  QudaParity otherparity = QUDA_INVALID_PARITY;
133  if (oddBit == QUDA_EVEN_PARITY) {
134  otherparity = QUDA_ODD_PARITY;
135  } else if (oddBit == QUDA_ODD_PARITY) {
136  otherparity = QUDA_EVEN_PARITY;
137  } else {
138  errorQuda("ERROR: full parity not supported in function %s", __FUNCTION__);
139  }
140  const int nFace = dslash_type == QUDA_ASQTAD_DSLASH ? 3 : 1;
141 
142  in->exchangeGhost(otherparity, nFace, daggerBit);
143 
144  void **fwd_nbr_spinor = ((cpuColorSpinorField *)in)->fwdGhostFaceBuffer;
145  void **back_nbr_spinor = ((cpuColorSpinorField *)in)->backGhostFaceBuffer;
146 
147  if (sPrecision == QUDA_DOUBLE_PRECISION) {
148  if (gPrecision == QUDA_DOUBLE_PRECISION) {
149  staggeredDslashReference((double *)out->V(), (double **)fatlink, (double **)longlink, (double **)ghost_fatlink,
150  (double **)ghost_longlink, (double *)in->V(), (double **)fwd_nbr_spinor,
151  (double **)back_nbr_spinor, oddBit, daggerBit, nSrc, dslash_type);
152  } else {
153  staggeredDslashReference((double *)out->V(), (float **)fatlink, (float **)longlink, (float **)ghost_fatlink,
154  (float **)ghost_longlink, (double *)in->V(), (double **)fwd_nbr_spinor,
155  (double **)back_nbr_spinor, oddBit, daggerBit, nSrc, dslash_type);
156  }
157  } else {
158  if (gPrecision == QUDA_DOUBLE_PRECISION) {
159  staggeredDslashReference((float *)out->V(), (double **)fatlink, (double **)longlink, (double **)ghost_fatlink,
160  (double **)ghost_longlink, (float *)in->V(), (float **)fwd_nbr_spinor,
161  (float **)back_nbr_spinor, oddBit, daggerBit, nSrc, dslash_type);
162  } else {
163  staggeredDslashReference((float *)out->V(), (float **)fatlink, (float **)longlink, (float **)ghost_fatlink,
164  (float **)ghost_longlink, (float *)in->V(), (float **)fwd_nbr_spinor,
165  (float **)back_nbr_spinor, oddBit, daggerBit, nSrc, dslash_type);
166  }
167  }
168 }
169 
170 void staggeredMatDagMat(ColorSpinorField *out, void **fatlink, void **longlink, void **ghost_fatlink,
171  void **ghost_longlink, ColorSpinorField *in, double mass, int dagger_bit,
174 {
175  // assert sPrecision and gPrecision must be the same
176  if (sPrecision != gPrecision) { errorQuda("Spinor precision and gPrecison is not the same"); }
177 
178  QudaParity otherparity = QUDA_INVALID_PARITY;
179  if (parity == QUDA_EVEN_PARITY) {
180  otherparity = QUDA_ODD_PARITY;
181  } else if (parity == QUDA_ODD_PARITY) {
182  otherparity = QUDA_EVEN_PARITY;
183  } else {
184  errorQuda("ERROR: full parity not supported in function %s\n", __FUNCTION__);
185  }
186 
187  staggeredDslash(tmp, fatlink, longlink, ghost_fatlink, ghost_longlink, in, otherparity, dagger_bit, sPrecision,
188  gPrecision, dslash_type);
189 
190  staggeredDslash(out, fatlink, longlink, ghost_fatlink, ghost_longlink, tmp, parity, dagger_bit, sPrecision,
191  gPrecision, dslash_type);
192 
193  double msq_x4 = mass * mass * 4;
194  if (sPrecision == QUDA_DOUBLE_PRECISION) {
195  axmy((double *)in->V(), (double)msq_x4, (double *)out->V(), out->X(4) * Vh * stag_spinor_site_size);
196  } else {
197  axmy((float *)in->V(), (float)msq_x4, (float *)out->V(), out->X(4) * Vh * stag_spinor_site_size);
198  }
199 }
virtual void exchangeGhost(QudaParity parity, int nFace, int dagger, const MemoryLocation *pack_destination=nullptr, const MemoryLocation *halo_location=nullptr, bool gdr_send=false, bool gdr_recv=false, QudaPrecision ghost_precision=QUDA_INVALID_PRECISION) const =0
const int * X() const
double mass
QudaDslashType dslash_type
int Vh
Definition: host_utils.cpp:38
QudaParity parity
Definition: covdev_test.cpp:40
cudaColorSpinorField * tmp
Definition: covdev_test.cpp:34
enum QudaPrecision_s QudaPrecision
@ QUDA_ASQTAD_DSLASH
Definition: enum_quda.h:98
@ QUDA_LAPLACE_DSLASH
Definition: enum_quda.h:101
@ QUDA_EVEN_PARITY
Definition: enum_quda.h:284
@ QUDA_ODD_PARITY
Definition: enum_quda.h:284
@ QUDA_INVALID_PARITY
Definition: enum_quda.h:284
enum QudaDslashType_s QudaDslashType
@ QUDA_DOUBLE_PRECISION
Definition: enum_quda.h:65
enum QudaParity_s QudaParity
#define gauge_site_size
Definition: face_gauge.cpp:34
int faceVolume[4]
Definition: host_utils.cpp:41
#define stag_spinor_site_size
Definition: host_utils.h:10
FloatingPoint< float > Float
__host__ __device__ T sum(const array< T, s > &a)
Definition: utility.h:76
Main header file for the QUDA library.
void staggeredDslashReference(sFloat *res, gFloat **fatlink, gFloat **longlink, gFloat **ghostFatlink, gFloat **ghostLonglink, sFloat *spinorField, sFloat **fwd_nbr_spinor, sFloat **back_nbr_spinor, int oddBit, int daggerBit, int nSrc, QudaDslashType dslash_type)
void display_link_internal(Float *link)
void staggeredMatDagMat(ColorSpinorField *out, void **fatlink, void **longlink, void **ghost_fatlink, void **ghost_longlink, ColorSpinorField *in, double mass, int dagger_bit, QudaPrecision sPrecision, QudaPrecision gPrecision, ColorSpinorField *tmp, QudaParity parity, QudaDslashType dslash_type)
void staggeredDslash(ColorSpinorField *out, void **fatlink, void **longlink, void **ghost_fatlink, void **ghost_longlink, ColorSpinorField *in, int oddBit, int daggerBit, QudaPrecision sPrecision, QudaPrecision gPrecision, QudaDslashType dslash_type)
void * memset(void *s, int c, size_t n)
#define errorQuda(...)
Definition: util_quda.h:120