QUDA  1.0.0
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
staggered_dslash_reference.cpp
Go to the documentation of this file.
1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <math.h>
4 #include <string.h>
5 
6 #include <test_util.h>
7 #include <quda_internal.h>
8 #include <quda.h>
9 #include <util_quda.h>
11 #include "misc.h"
12 #include <blas_quda.h>
13 
14 #include <blas_reference.h>
15 
16 extern void *memset(void *s, int c, size_t n);
17 
18 #include <dslash_util.h>
19 
//
// dslashReference() (the template defined after display_link_internal below)
//
// if oddBit is zero: calculate even parity spinor elements (using odd parity spinor)
// if oddBit is one: calculate odd parity spinor elements (using even parity spinor)
//
// if daggerBit is zero: perform ordinary dslash operator
// if daggerBit is one: perform hermitian conjugate of dslash
//
/**
 * Debug helper: print a 3x3 complex matrix stored as 18 consecutive reals,
 * row by row, each entry laid out as (real, imag). A blank line follows the matrix.
 *
 * @param link Pointer to at least 18 Float values.
 */
template <typename Float> void display_link_internal(Float *link)
{
  for (int row = 0; row < 3; ++row) {
    // Each row holds 3 complex entries = 6 reals.
    const Float *entry = link + row * 6;
    for (int col = 0; col < 3; ++col, entry += 2) printf("(%10f,%10f) \t", entry[0], entry[1]);
    printf("\n");
  }
  printf("\n");
}
43 
44 template <typename sFloat, typename gFloat>
45 void dslashReference(sFloat *res, gFloat **fatlink, gFloat **longlink, gFloat **ghostFatlink, gFloat **ghostLonglink,
46  sFloat *spinorField, sFloat **fwd_nbr_spinor, sFloat **back_nbr_spinor, int oddBit, int daggerBit, int nSrc,
48 {
49  for (int i=0; i<Vh*mySpinorSiteSize*nSrc; i++) res[i] = 0.0;
50 
51  gFloat *fatlinkEven[4], *fatlinkOdd[4];
52  gFloat *longlinkEven[4], *longlinkOdd[4];
53 
54 #ifdef MULTI_GPU
55  gFloat *ghostFatlinkEven[4], *ghostFatlinkOdd[4];
56  gFloat *ghostLonglinkEven[4], *ghostLonglinkOdd[4];
57 #endif
58 
59  for (int dir = 0; dir < 4; dir++) {
60  fatlinkEven[dir] = fatlink[dir];
61  fatlinkOdd[dir] = fatlink[dir] + Vh*gaugeSiteSize;
62  longlinkEven[dir] =longlink[dir];
63  longlinkOdd[dir] = longlink[dir] + Vh*gaugeSiteSize;
64 
65 #ifdef MULTI_GPU
66  ghostFatlinkEven[dir] = ghostFatlink[dir];
67  ghostFatlinkOdd[dir] = ghostFatlink[dir] + (faceVolume[dir]/2)*gaugeSiteSize;
68  ghostLonglinkEven[dir] = ghostLonglink[dir];
69  ghostLonglinkOdd[dir] = ghostLonglink[dir] + 3*(faceVolume[dir]/2)*gaugeSiteSize;
70 #endif
71  }
72 
73  for (int xs=0; xs<nSrc; xs++) {
74 
75  for (int i = 0; i < Vh; i++) {
76  int sid = i + xs*Vh;
77  int offset = mySpinorSiteSize*sid;
78 
79  for (int dir = 0; dir < 8; dir++) {
80 #ifdef MULTI_GPU
81  const int nFace = dslash_type == QUDA_ASQTAD_DSLASH ? 3 : 1;
82  gFloat* fatlnk = gaugeLink_mg4dir(i, dir, oddBit, fatlinkEven, fatlinkOdd, ghostFatlinkEven, ghostFatlinkOdd, 1, 1);
83  gFloat *longlnk = dslash_type == QUDA_ASQTAD_DSLASH ?
84  gaugeLink_mg4dir(i, dir, oddBit, longlinkEven, longlinkOdd, ghostLonglinkEven, ghostLonglinkOdd, 3, 3) :
85  nullptr;
86  sFloat *first_neighbor_spinor = spinorNeighbor_5d_mgpu<QUDA_4D_PC>(
87  sid, dir, oddBit, spinorField, fwd_nbr_spinor, back_nbr_spinor, 1, nFace, mySpinorSiteSize);
88  sFloat *third_neighbor_spinor = dslash_type == QUDA_ASQTAD_DSLASH ?
89  spinorNeighbor_5d_mgpu<QUDA_4D_PC>(
90  sid, dir, oddBit, spinorField, fwd_nbr_spinor, back_nbr_spinor, 3, nFace, mySpinorSiteSize) :
91  nullptr;
92 #else
93  gFloat *fatlnk = gaugeLink(i, dir, oddBit, fatlinkEven, fatlinkOdd, 1);
94  gFloat *longlnk
95  = dslash_type == QUDA_ASQTAD_DSLASH ? gaugeLink(i, dir, oddBit, longlinkEven, longlinkOdd, 3) : nullptr;
96  sFloat *first_neighbor_spinor = spinorNeighbor_5d<QUDA_4D_PC>(sid, dir, oddBit, spinorField, 1, mySpinorSiteSize);
97  sFloat *third_neighbor_spinor = dslash_type == QUDA_ASQTAD_DSLASH ?
98  spinorNeighbor_5d<QUDA_4D_PC>(sid, dir, oddBit, spinorField, 3, mySpinorSiteSize) :
99  nullptr;
100 #endif
101  sFloat gaugedSpinor[mySpinorSiteSize];
102 
103  if (dir % 2 == 0){
104  su3Mul(gaugedSpinor, fatlnk, first_neighbor_spinor);
105  sum(&res[offset], &res[offset], gaugedSpinor, mySpinorSiteSize);
106 
107  if (dslash_type == QUDA_ASQTAD_DSLASH) {
108  su3Mul(gaugedSpinor, longlnk, third_neighbor_spinor);
109  sum(&res[offset], &res[offset], gaugedSpinor, mySpinorSiteSize);
110  }
111  } else {
112  su3Tmul(gaugedSpinor, fatlnk, first_neighbor_spinor);
113  if (dslash_type == QUDA_LAPLACE_DSLASH) {
114  sum(&res[offset], &res[offset], gaugedSpinor, mySpinorSiteSize);
115  } else {
116  sub(&res[offset], &res[offset], gaugedSpinor, mySpinorSiteSize);
117  }
118 
119  if (dslash_type == QUDA_ASQTAD_DSLASH) {
120  su3Tmul(gaugedSpinor, longlnk, third_neighbor_spinor);
121  sub(&res[offset], &res[offset], gaugedSpinor, mySpinorSiteSize);
122  }
123  }
124  }
125 
126  if (daggerBit) negx(&res[offset], mySpinorSiteSize);
127  } // 4-d volume
128  } // right-hand-side
129 
130 }
131 
133  void **ghost_longlink, cpuColorSpinorField *in, int oddBit, int daggerBit, QudaPrecision sPrecision,
135 {
136  const int nSrc = in->X(4);
137 
138  QudaParity otherparity = QUDA_INVALID_PARITY;
139  if (oddBit == QUDA_EVEN_PARITY) {
140  otherparity = QUDA_ODD_PARITY;
141  } else if (oddBit == QUDA_ODD_PARITY) {
142  otherparity = QUDA_EVEN_PARITY;
143  } else {
144  errorQuda("ERROR: full parity not supported in function %s", __FUNCTION__);
145  }
146  const int nFace = dslash_type == QUDA_ASQTAD_DSLASH ? 3 : 1;
147 
148  in->exchangeGhost(otherparity, nFace, daggerBit);
149 
150  void** fwd_nbr_spinor = in->fwdGhostFaceBuffer;
151  void** back_nbr_spinor = in->backGhostFaceBuffer;
152 
153  if (sPrecision == QUDA_DOUBLE_PRECISION) {
154  if (gPrecision == QUDA_DOUBLE_PRECISION) {
155  dslashReference((double *)out->V(), (double **)fatlink, (double **)longlink, (double **)ghost_fatlink,
156  (double **)ghost_longlink, (double *)in->V(), (double **)fwd_nbr_spinor, (double **)back_nbr_spinor, oddBit,
157  daggerBit, nSrc, dslash_type);
158  } else {
159  dslashReference((double *)out->V(), (float **)fatlink, (float **)longlink, (float **)ghost_fatlink,
160  (float **)ghost_longlink, (double *)in->V(), (double **)fwd_nbr_spinor, (double **)back_nbr_spinor, oddBit,
161  daggerBit, nSrc, dslash_type);
162  }
163  } else {
164  if (gPrecision == QUDA_DOUBLE_PRECISION) {
165  dslashReference((float *)out->V(), (double **)fatlink, (double **)longlink, (double **)ghost_fatlink,
166  (double **)ghost_longlink, (float *)in->V(), (float **)fwd_nbr_spinor, (float **)back_nbr_spinor, oddBit,
167  daggerBit, nSrc, dslash_type);
168  } else {
169  dslashReference((float *)out->V(), (float **)fatlink, (float **)longlink, (float **)ghost_fatlink,
170  (float **)ghost_longlink, (float *)in->V(), (float **)fwd_nbr_spinor, (float **)back_nbr_spinor, oddBit,
171  daggerBit, nSrc, dslash_type);
172  }
173  }
174 }
175 
177  cpuColorSpinorField *in, double mass, int dagger_bit, QudaPrecision sPrecision, QudaPrecision gPrecision,
179 {
180  //assert sPrecision and gPrecision must be the same
181  if (sPrecision != gPrecision){
182  errorQuda("Spinor precision and gPrecison is not the same");
183  }
184 
185  QudaParity otherparity = QUDA_INVALID_PARITY;
186  if (parity == QUDA_EVEN_PARITY){
187  otherparity = QUDA_ODD_PARITY;
188  } else if (parity == QUDA_ODD_PARITY) {
189  otherparity = QUDA_EVEN_PARITY;
190  } else {
191  errorQuda("ERROR: full parity not supported in function %s\n", __FUNCTION__);
192  }
193 
194  staggered_dslash(tmp, fatlink, longlink, ghost_fatlink, ghost_longlink, in, otherparity, dagger_bit, sPrecision,
195  gPrecision, dslash_type);
196 
197  staggered_dslash(out, fatlink, longlink, ghost_fatlink, ghost_longlink, tmp, parity, dagger_bit, sPrecision,
198  gPrecision, dslash_type);
199 
200  double msq_x4 = mass*mass*4;
201  if (sPrecision == QUDA_DOUBLE_PRECISION){
202  axmy((double*)in->V(), (double)msq_x4, (double*)out->V(), out->X(4)*Vh*mySpinorSiteSize);
203  }else{
204  axmy((float*)in->V(), (float)msq_x4, (float*)out->V(), out->X(4)*Vh*mySpinorSiteSize);
205  }
206 
207 }
QudaDslashType dslash_type
Definition: test_util.cpp:1621
void display_link_internal(Float *link)
static void sum(Float *dst, Float *a, Float *b, int cnt)
Definition: dslash_util.h:8
enum QudaPrecision_s QudaPrecision
#define errorQuda(...)
Definition: util_quda.h:121
void ** ghost_fatlink
static void sub(Float *dst, Float *a, Float *b, int cnt)
Definition: dslash_util.h:14
cudaColorSpinorField * tmp
Definition: covdev_test.cpp:44
static void axmy(Float *x, Float a, Float *y, int len)
Definition: dslash_util.h:39
void ** ghost_longlink
cpuColorSpinorField * in
void staggered_dslash(cpuColorSpinorField *out, void **fatlink, void **longlink, void **ghost_fatlink, void **ghost_longlink, cpuColorSpinorField *in, int oddBit, int daggerBit, QudaPrecision sPrecision, QudaPrecision gPrecision, QudaDslashType dslash_type)
static void * backGhostFaceBuffer[QUDA_MAX_DIM]
enum QudaParity_s QudaParity
void exchangeGhost(QudaParity parity, int nFace, int dagger, const MemoryLocation *pack_destination=nullptr, const MemoryLocation *halo_location=nullptr, bool gdr_send=false, bool gdr_recv=false, QudaPrecision ghost_precision=QUDA_INVALID_PRECISION) const
This is a unified ghost exchange function for doing a complete halo exchange regardless of the type o...
static void * fwdGhostFaceBuffer[QUDA_MAX_DIM]
#define mySpinorSiteSize
static Float * gaugeLink(int i, int dir, int oddBit, Float **gaugeEven, Float **gaugeOdd, int nbr_distance)
Definition: dslash_util.h:104
cpuColorSpinorField * out
void * memset(void *s, int c, size_t n)
Main header file for the QUDA library.
__shared__ float s[]
void matdagmat(cpuColorSpinorField *out, void **fatlink, void **longlink, void **ghost_fatlink, void **ghost_longlink, cpuColorSpinorField *in, double mass, int dagger_bit, QudaPrecision sPrecision, QudaPrecision gPrecision, cpuColorSpinorField *tmp, QudaParity parity, QudaDslashType dslash_type)
enum QudaDslashType_s QudaDslashType
int faceVolume[4]
Definition: test_util.cpp:31
void * longlink
const int * X() const
static void su3Mul(sFloat *res, gFloat *mat, sFloat *vec)
Definition: dslash_util.h:80
void * fatlink
static void su3Tmul(sFloat *res, gFloat *mat, sFloat *vec)
Definition: dslash_util.h:85
QudaParity parity
Definition: covdev_test.cpp:54
void dslashReference(sFloat *res, gFloat **fatlink, gFloat **longlink, gFloat **ghostFatlink, gFloat **ghostLonglink, sFloat *spinorField, sFloat **fwd_nbr_spinor, sFloat **back_nbr_spinor, int oddBit, int daggerBit, int nSrc, QudaDslashType dslash_type)
static void negx(Float *x, int len)
Definition: dslash_util.h:51
#define gaugeSiteSize
Definition: face_gauge.cpp:34
int Vh
Definition: test_util.cpp:28