QUDA  0.9.0
staggered_dslash_reference.cpp
Go to the documentation of this file.
1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <math.h>
4 #include <string.h>
5 
6 #include <test_util.h>
7 #include <quda_internal.h>
8 #include <quda.h>
9 #include <util_quda.h>
11 #include "misc.h"
12 #include <blas_quda.h>
13 
14 #include <blas_reference.h>
15 
16 extern void *memset(void *s, int c, size_t n);
17 
18 #include <dslash_util.h>
19 
//
// dslashReference() -- defined below display_link_internal()
//
// if oddBit is zero: calculate even parity spinor elements (using the odd parity spinor)
// if oddBit is one: calculate odd parity spinor elements (using the even parity spinor)
//
// if daggerBit is zero: perform the ordinary dslash operator
// if daggerBit is one: perform the hermitian conjugate of dslash
//
/**
 * Print a 3x3 complex matrix (e.g. a gauge link) to stdout, one row per line,
 * followed by a blank line.
 *
 * @param link  18 consecutive reals stored row-major as (re, im) pairs:
 *              element (row, col) starts at link[2*(3*row + col)].
 */
template<typename Float>
void display_link_internal(Float* link)
{
  const int nColor = 3;
  for (int row = 0; row < nColor; ++row) {
    const Float *rowStart = link + 2 * nColor * row;
    for (int col = 0; col < nColor; ++col) {
      const Float *element = rowStart + 2 * col;
      printf("(%10f,%10f) \t", element[0], element[1]);
    }
    printf("\n");
  }
  printf("\n");
}
43 
44 
45 template <typename sFloat, typename gFloat>
46 void dslashReference(sFloat *res, gFloat **fatlink, gFloat** longlink, sFloat *spinorField,
47  int oddBit, int daggerBit)
48 {
49  const int nSrc = Ls; // Ls should already be set
50 
51  for (int i=0; i<Vh*mySpinorSiteSize*nSrc; i++) res[i] = 0.0;
52 
53  gFloat *fatlinkEven[4], *fatlinkOdd[4];
54  gFloat *longlinkEven[4], *longlinkOdd[4];
55 
56  for (int dir = 0; dir < 4; dir++) {
57  fatlinkEven[dir] = fatlink[dir];
58  fatlinkOdd[dir] = fatlink[dir] + Vh*gaugeSiteSize;
59  longlinkEven[dir] =longlink[dir];
60  longlinkOdd[dir] = longlink[dir] + Vh*gaugeSiteSize;
61  }
62 
63  for (int xs=0; xs<nSrc; xs++) {
64 
65  for (int i = 0; i < Vh; i++) {
66  int sid = i + xs*Vh;
68 
69  for (int dir = 0; dir < 8; dir++) {
70  gFloat* fatlnk = gaugeLink(i, dir, oddBit, fatlinkEven, fatlinkOdd, 1);
71  gFloat* longlnk = gaugeLink(i, dir, oddBit, longlinkEven, longlinkOdd, 3);
72 
73  sFloat *first_neighbor_spinor = spinorNeighbor_5d<QUDA_4D_PC>(sid, dir, oddBit, spinorField, 1, mySpinorSiteSize);
74  sFloat *third_neighbor_spinor = spinorNeighbor_5d<QUDA_4D_PC>(sid, dir, oddBit, spinorField, 3, mySpinorSiteSize);
75 
76  sFloat gaugedSpinor[mySpinorSiteSize];
77 
78  if (dir % 2 == 0){
79  su3Mul(gaugedSpinor, fatlnk, first_neighbor_spinor);
80  sum(&res[offset], &res[offset], gaugedSpinor, mySpinorSiteSize);
81  su3Mul(gaugedSpinor, longlnk, third_neighbor_spinor);
82  sum(&res[offset], &res[offset], gaugedSpinor, mySpinorSiteSize);
83  } else {
84  su3Tmul(gaugedSpinor, fatlnk, first_neighbor_spinor);
85  sub(&res[offset], &res[offset], gaugedSpinor, mySpinorSiteSize);
86  su3Tmul(gaugedSpinor, longlnk, third_neighbor_spinor);
87  sub(&res[offset], &res[offset], gaugedSpinor, mySpinorSiteSize);
88  }
89  }
90 
91  if (daggerBit) negx(&res[offset], mySpinorSiteSize);
92  } // 4-d volume
93  } // right-hand-side
94 
95 }
96 
97 
98 
99 
100 void staggered_dslash(void *res, void **fatlink, void** longlink, void *spinorField, int oddBit, int daggerBit,
101  QudaPrecision sPrecision, QudaPrecision gPrecision) {
102 
103  if (sPrecision == QUDA_DOUBLE_PRECISION) {
104  if (gPrecision == QUDA_DOUBLE_PRECISION){
105  dslashReference((double*)res, (double**)fatlink, (double**)longlink, (double*)spinorField, oddBit, daggerBit);
106  }else{
107  dslashReference((double*)res, (float**)fatlink, (float**)longlink, (double*)spinorField, oddBit, daggerBit);
108  }
109  }
110  else{
111  if (gPrecision == QUDA_DOUBLE_PRECISION){
112  dslashReference((float*)res, (double**)fatlink, (double**)longlink, (float*)spinorField, oddBit, daggerBit);
113  }else{
114  dslashReference((float*)res, (float**)fatlink, (float**)longlink, (float*)spinorField, oddBit, daggerBit);
115  }
116  }
117 }
118 
119 
120 
121 
122 template <typename sFloat, typename gFloat>
123 void Mat(sFloat *out, gFloat **fatlink, gFloat** longlink, sFloat *in, sFloat kappa, int daggerBit)
124 {
125  sFloat *inEven = in;
126  sFloat *inOdd = in + Vh*mySpinorSiteSize;
127  sFloat *outEven = out;
128  sFloat *outOdd = out + Vh*mySpinorSiteSize;
129 
130  // full dslash operator
131  dslashReference(outOdd, fatlink, longlink, inEven, 1, daggerBit);
132  dslashReference(outEven, fatlink, longlink, inOdd, 0, daggerBit);
133  }
134 
135 
136 void
137 mat(void *out, void **fatlink, void** longlink, void *in, double kappa, int dagger_bit,
138  QudaPrecision sPrecision, QudaPrecision gPrecision)
139 {
140 
141  if (sPrecision == QUDA_DOUBLE_PRECISION){
142  if (gPrecision == QUDA_DOUBLE_PRECISION) {
143  Mat((double*)out, (double**)fatlink, (double**)longlink, (double*)in, (double)kappa, dagger_bit);
144  }else {
145  Mat((double*)out, (float**)fatlink, (float**)longlink, (double*)in, (double)kappa, dagger_bit);
146  }
147  }else{
148  if (gPrecision == QUDA_DOUBLE_PRECISION){
149  Mat((float*)out, (double**)fatlink, (double**)longlink, (float*)in, (float)kappa, dagger_bit);
150  }else {
151  Mat((float*)out, (float**)fatlink, (float**)longlink, (float*)in, (float)kappa, dagger_bit);
152  }
153  }
154 
155  // lastly apply the kappa term
156  xpay(in, -kappa, out, V*mySpinorSiteSize, sPrecision);
157 }
158 
159 
160 
161 template <typename sFloat, typename gFloat>
162 void
163 Matdagmat(sFloat *out, gFloat **fatlink, gFloat** longlink, sFloat *in, sFloat mass, int daggerBit, sFloat* tmp, QudaParity parity)
164 {
165 
166  sFloat msq_x4 = mass*mass*4;
167 
168  switch(parity){
169  case QUDA_EVEN_PARITY:
170  {
171  sFloat *inEven = in;
172  sFloat *outEven = out;
173  dslashReference(tmp, fatlink, longlink, inEven, 1, daggerBit);
174  dslashReference(outEven, fatlink, longlink, tmp, 0, daggerBit);
175 
176  // lastly apply the mass term
177  axmy(inEven, msq_x4, outEven, Ls*Vh*mySpinorSiteSize);
178  break;
179  }
180  case QUDA_ODD_PARITY:
181  {
182  sFloat *inOdd = in;
183  sFloat *outOdd = out;
184  dslashReference(tmp, fatlink, longlink, inOdd, 0, daggerBit);
185  dslashReference(outOdd, fatlink, longlink, tmp, 1, daggerBit);
186 
187  // lastly apply the mass term
188  axmy(inOdd, msq_x4, outOdd, Ls*Vh*mySpinorSiteSize);
189  break;
190  }
191 
192  default:
193  fprintf(stderr, "ERROR: invalid parity in %s,line %d\n", __FUNCTION__, __LINE__);
194  break;
195  }
196 
197 }
198 
199 
200 
201 void
202 matdagmat(void *out, void **fatlink, void** longlink, void *in, double mass, int dagger_bit,
203  QudaPrecision sPrecision, QudaPrecision gPrecision, void* tmp, QudaParity parity)
204 {
205 
206  if (sPrecision == QUDA_DOUBLE_PRECISION){
207  if (gPrecision == QUDA_DOUBLE_PRECISION) {
208  Matdagmat((double*)out, (double**)fatlink, (double**)longlink, (double*)in, (double)mass, dagger_bit, (double*)tmp, parity);
209  }else {
210  Matdagmat((double*)out, (float**)fatlink, (float**)longlink, (double*)in, (double)mass, dagger_bit, (double*) tmp, parity);
211  }
212  }else{
213  if (gPrecision == QUDA_DOUBLE_PRECISION){
214  Matdagmat((float*)out, (double**)fatlink, (double**)longlink, (float*)in, (float)mass, dagger_bit, (float*)tmp, parity);
215  }else {
216  Matdagmat((float*)out, (float**)fatlink, (float**)longlink, (float*)in, (float)mass, dagger_bit, (float*)tmp, parity);
217  }
218  }
219 }
220 
221 
222 
223 
224 
225 // Apply the even-odd preconditioned Dirac operator
226 template <typename sFloat, typename gFloat>
227 static void MatPC(sFloat *outEven, gFloat **fatlink, gFloat** longlink, sFloat *inEven, int dagger, QudaMatPCType matpc_type) {
228 
229  sFloat *tmp = (sFloat*)malloc(Vh*mySpinorSiteSize*sizeof(sFloat));
230 
231  // full dslash operator
233  dslashReference(tmp, fatlink, longlink, inEven, 1, dagger);
234  dslashReference(outEven, fatlink, longlink, tmp, 0, dagger);
235  } else {
236  dslashReference(tmp, fatlink, longlink, inEven, 0, dagger);
237  dslashReference(outEven, fatlink, longlink, tmp, 1, dagger);
238  }
239 
240  free(tmp);
241 }
242 
243 
244 void
245 staggered_matpc(void *outEven, void **fatlink, void**longlink, void *inEven, double kappa,
246  QudaMatPCType matpc_type, int dagger_bit, QudaPrecision sPrecision, QudaPrecision gPrecision)
247 {
248 
249  if (sPrecision == QUDA_DOUBLE_PRECISION)
250  if (gPrecision == QUDA_DOUBLE_PRECISION) {
251  MatPC((double*)outEven, (double**)fatlink, (double**)longlink, (double*)inEven, dagger_bit, matpc_type);
252  }
253  else{
254  MatPC((double*)outEven, (double**)fatlink, (double**)longlink, (double*)inEven, dagger_bit, matpc_type);
255  }
256  else {
257  if (gPrecision == QUDA_DOUBLE_PRECISION){
258  MatPC((float*)outEven, (double**)fatlink, (double**)longlink, (float*)inEven, dagger_bit, matpc_type);
259  }else{
260  MatPC((float*)outEven, (float**)fatlink, (float**)longlink, (float*)inEven, dagger_bit, matpc_type);
261  }
262  }
263 
264  // lastly apply the kappa term
265  double kappa2 = -kappa*kappa;
266  xpay(inEven, kappa2, outEven, Ls*Vh*mySpinorSiteSize, sPrecision);
267 }
268 
269 #ifdef MULTI_GPU
270 
271 template <typename sFloat, typename gFloat>
272 void dslashReference_mg4dir(sFloat *res, gFloat **fatlink, gFloat** longlink,
273  gFloat** ghostFatlink, gFloat** ghostLonglink,
274  sFloat *spinorField, sFloat** fwd_nbr_spinor,
275  sFloat** back_nbr_spinor, int oddBit, int daggerBit, int nSrc)
276 {
277  for (int i=0; i<Vh*mySpinorSiteSize*nSrc; i++) res[i] = 0.0;
278 
279  gFloat *fatlinkEven[4], *fatlinkOdd[4];
280  gFloat *longlinkEven[4], *longlinkOdd[4];
281  gFloat *ghostFatlinkEven[4], *ghostFatlinkOdd[4];
282  gFloat *ghostLonglinkEven[4], *ghostLonglinkOdd[4];
283 
284  for (int dir = 0; dir < 4; dir++) {
285  fatlinkEven[dir] = fatlink[dir];
286  fatlinkOdd[dir] = fatlink[dir] + Vh*gaugeSiteSize;
287  longlinkEven[dir] =longlink[dir];
288  longlinkOdd[dir] = longlink[dir] + Vh*gaugeSiteSize;
289 
290  ghostFatlinkEven[dir] = ghostFatlink[dir];
291  ghostFatlinkOdd[dir] = ghostFatlink[dir] + (faceVolume[dir]/2)*gaugeSiteSize;
292  ghostLonglinkEven[dir] = ghostLonglink[dir];
293  ghostLonglinkOdd[dir] = ghostLonglink[dir] + 3*(faceVolume[dir]/2)*gaugeSiteSize;
294  }
295 
296  for (int xs=0; xs<nSrc; xs++) {
297 
298  for (int i = 0; i < Vh; i++) {
299  int sid = i + xs*Vh;
301 
302  for (int dir = 0; dir < 8; dir++) {
303  gFloat* fatlnk = gaugeLink_mg4dir(i, dir, oddBit, fatlinkEven, fatlinkOdd, ghostFatlinkEven, ghostFatlinkOdd, 1, 1);
304  gFloat* longlnk = gaugeLink_mg4dir(i, dir, oddBit, longlinkEven, longlinkOdd, ghostLonglinkEven, ghostLonglinkOdd, 3, 3);
305 
306  sFloat *first_neighbor_spinor = spinorNeighbor_5d_mgpu<QUDA_4D_PC>(sid, dir, oddBit, spinorField, fwd_nbr_spinor, back_nbr_spinor, 1, 3, mySpinorSiteSize);
307  sFloat *third_neighbor_spinor = spinorNeighbor_5d_mgpu<QUDA_4D_PC>(sid, dir, oddBit, spinorField, fwd_nbr_spinor, back_nbr_spinor, 3, 3, mySpinorSiteSize);
308 
309  sFloat gaugedSpinor[mySpinorSiteSize];
310 
311  if (dir % 2 == 0){
312  su3Mul(gaugedSpinor, fatlnk, first_neighbor_spinor);
313  sum(&res[offset], &res[offset], gaugedSpinor, mySpinorSiteSize);
314  su3Mul(gaugedSpinor, longlnk, third_neighbor_spinor);
315  sum(&res[offset], &res[offset], gaugedSpinor, mySpinorSiteSize);
316  } else {
317  su3Tmul(gaugedSpinor, fatlnk, first_neighbor_spinor);
318  sub(&res[offset], &res[offset], gaugedSpinor, mySpinorSiteSize);
319  su3Tmul(gaugedSpinor, longlnk, third_neighbor_spinor);
320  sub(&res[offset], &res[offset], gaugedSpinor, mySpinorSiteSize);
321  }
322  }
323 
324  if (daggerBit) negx(&res[offset], mySpinorSiteSize);
325  } // 4-d volume
326  } // right-hand-side
327 
328 }
329 
330 
331 
332 void staggered_dslash_mg4dir(cpuColorSpinorField* out, void **fatlink, void** longlink, void** ghost_fatlink,
333  void** ghost_longlink, cpuColorSpinorField* in, int oddBit, int daggerBit,
334  QudaPrecision sPrecision, QudaPrecision gPrecision)
335 {
336  const int nSrc = in->X(4);
337 
338  QudaParity otherparity = QUDA_INVALID_PARITY;
339  if (oddBit == QUDA_EVEN_PARITY) {
340  otherparity = QUDA_ODD_PARITY;
341  } else if (oddBit == QUDA_ODD_PARITY) {
342  otherparity = QUDA_EVEN_PARITY;
343  } else {
344  errorQuda("ERROR: full parity not supported in function %s", __FUNCTION__);
345  }
346  const int nFace = 3;
347 
348  in->exchangeGhost(otherparity, nFace, daggerBit);
349 
350  void** fwd_nbr_spinor = in->fwdGhostFaceBuffer;
351  void** back_nbr_spinor = in->backGhostFaceBuffer;
352 
353  if (sPrecision == QUDA_DOUBLE_PRECISION) {
354  if (gPrecision == QUDA_DOUBLE_PRECISION) {
355  dslashReference_mg4dir((double*)out->V(), (double**)fatlink, (double**)longlink, (double**)ghost_fatlink, (double**)ghost_longlink,
356  (double*)in->V(), (double**)fwd_nbr_spinor, (double**)back_nbr_spinor, oddBit, daggerBit, nSrc);
357  } else {
358  dslashReference_mg4dir((double*)out->V(), (float**)fatlink, (float**)longlink, (float**)ghost_fatlink, (float**)ghost_longlink,
359  (double*)in->V(), (double**)fwd_nbr_spinor, (double**)back_nbr_spinor, oddBit, daggerBit, nSrc);
360  }
361  } else {
362  if (gPrecision == QUDA_DOUBLE_PRECISION) {
363  dslashReference_mg4dir((float*)out->V(), (double**)fatlink, (double**)longlink, (double**)ghost_fatlink, (double**)ghost_longlink,
364  (float*)in->V(), (float**)fwd_nbr_spinor, (float**)back_nbr_spinor, oddBit, daggerBit, nSrc);
365  } else {
366  dslashReference_mg4dir((float*)out->V(), (float**)fatlink, (float**)longlink, (float**)ghost_fatlink, (float**)ghost_longlink,
367  (float*)in->V(), (float**)fwd_nbr_spinor, (float**)back_nbr_spinor, oddBit, daggerBit, nSrc);
368  }
369  }
370 
371 }
372 
373 void
374 matdagmat_mg4dir(cpuColorSpinorField* out, void **fatlink, void** longlink, void** ghost_fatlink, void** ghost_longlink,
375  cpuColorSpinorField* in, double mass, int dagger_bit,
377 {
378  //assert sPrecision and gPrecision must be the same
379  if (sPrecision != gPrecision){
380  errorQuda("Spinor precision and gPrecison is not the same");
381  }
382 
383  QudaParity otherparity = QUDA_INVALID_PARITY;
384  if (parity == QUDA_EVEN_PARITY){
385  otherparity = QUDA_ODD_PARITY;
386  } else if (parity == QUDA_ODD_PARITY) {
387  otherparity = QUDA_EVEN_PARITY;
388  } else {
389  errorQuda("ERROR: full parity not supported in function %s\n", __FUNCTION__);
390  }
391 
392  staggered_dslash_mg4dir(tmp, fatlink, longlink, ghost_fatlink, ghost_longlink,
393  in, otherparity, dagger_bit, sPrecision, gPrecision);
394 
395  staggered_dslash_mg4dir(out, fatlink, longlink, ghost_fatlink, ghost_longlink,
396  tmp, parity, dagger_bit, sPrecision, gPrecision);
397 
398  double msq_x4 = mass*mass*4;
399  if (sPrecision == QUDA_DOUBLE_PRECISION){
400  axmy((double*)in->V(), (double)msq_x4, (double*)out->V(), out->X(4)*Vh*mySpinorSiteSize);
401  }else{
402  axmy((float*)in->V(), (float)msq_x4, (float*)out->V(), out->X(4)*Vh*mySpinorSiteSize);
403  }
404 
405 }
406 
407 #endif
408 
void display_link_internal(Float *link)
void free(void *)
void xpay(ColorSpinorField &x, const double &a, ColorSpinorField &y)
Definition: blas_quda.cu:173
enum QudaPrecision_s QudaPrecision
static void MatPC(sFloat *outEven, gFloat **fatlink, gFloat **longlink, sFloat *inEven, int dagger, QudaMatPCType matpc_type)
void staggered_matpc(void *outEven, void **fatlink, void **longlink, void *inEven, double kappa, QudaMatPCType matpc_type, int dagger_bit, QudaPrecision sPrecision, QudaPrecision gPrecision)
#define errorQuda(...)
Definition: util_quda.h:90
static void sub(Float *dst, Float *a, Float *b, int cnt)
Definition: dslash_util.h:14
void staggered_dslash(void *res, void **fatlink, void **longlink, void *spinorField, int oddBit, int daggerBit, QudaPrecision sPrecision, QudaPrecision gPrecision)
cudaColorSpinorField * tmp
Definition: covdev_test.cpp:44
void mat(void *out, void **fatlink, void **longlink, void *in, double kappa, int dagger_bit, QudaPrecision sPrecision, QudaPrecision gPrecision)
static void axmy(Float *x, Float a, Float *y, int len)
Definition: dslash_util.h:39
void matdagmat_mg4dir(cpuColorSpinorField *out, void **link, void **ghostLink, cpuColorSpinorField *in, int dagger_bit, int mu, QudaPrecision sPrecision, QudaPrecision gPrecision, cpuColorSpinorField *tmp, QudaParity parity)
void exchangeGhost(QudaParity parity, int nFace, int dagger, const MemoryLocation *pack_destination=nullptr, const MemoryLocation *halo_location=nullptr, bool gdr_send=false, bool gdr_recv=false) const
This is a unified ghost exchange function for doing a complete halo exchange regardless of the type o...
size_t size_t offset
int Ls
Definition: test_util.cpp:39
void * longlink[4]
void * malloc(size_t __size) __attribute__((__warn_unused_result__)) __attribute__((alloc_size(1)))
void matdagmat(void *out, void **fatlink, void **longlink, void *in, double mass, int dagger_bit, QudaPrecision sPrecision, QudaPrecision gPrecision, void *tmp, QudaParity parity)
int printf(const char *,...) __attribute__((__format__(__printf__
VOLATILE spinorFloat kappa
__host__ __device__ void sum(double &a, double &b)
cpuColorSpinorField * in
#define mySpinorSiteSize
int V
Definition: test_util.cpp:28
enum QudaMatPCType_s QudaMatPCType
#define gaugeSiteSize
Definition: test_util.h:6
static void * backGhostFaceBuffer[QUDA_MAX_DIM]
enum QudaParity_s QudaParity
QudaMatPCType matpc_type
Definition: test_util.cpp:1652
int fprintf(FILE *, const char *,...) __attribute__((__format__(__printf__
static void * fwdGhostFaceBuffer[QUDA_MAX_DIM]
static Float * gaugeLink(int i, int dir, int oddBit, Float **gaugeEven, Float **gaugeOdd, int nbr_distance)
Definition: dslash_util.h:104
cpuColorSpinorField * out
void * memset(void *s, int c, size_t n)
void * fatlink[4]
Main header file for the QUDA library.
int Vh
Definition: test_util.cpp:29
void staggered_dslash_mg4dir(cpuColorSpinorField *out, void **fatlink, void **longlink, void **ghost_fatlink, void **ghost_longlink, cpuColorSpinorField *in, int oddBit, int daggerBit, QudaPrecision sPrecision, QudaPrecision gPrecision)
int faceVolume[4]
Definition: test_util.cpp:32
const void * c
const int * X() const
static void su3Mul(sFloat *res, gFloat *mat, sFloat *vec)
Definition: dslash_util.h:80
static void su3Tmul(sFloat *res, gFloat *mat, sFloat *vec)
Definition: dslash_util.h:85
void dslashReference(sFloat *res, gFloat **fatlink, gFloat **longlink, sFloat *spinorField, int oddBit, int daggerBit)
QudaParity parity
Definition: covdev_test.cpp:53
void Matdagmat(sFloat *out, gFloat **fatlink, gFloat **longlink, sFloat *in, sFloat mass, int daggerBit, sFloat *tmp, QudaParity parity)
void Mat(sFloat *out, gFloat **fatlink, gFloat **longlink, sFloat *in, sFloat kappa, int daggerBit)
static void negx(Float *x, int len)
Definition: dslash_util.h:51