QUDA  1.0.0
coarse_op_preconditioned.cuh
#include <gauge_field_order.h>
#include <index_helper.cuh>

namespace quda {

  template <typename Float, typename PreconditionedGauge, typename Gauge, int n> struct CalculateYhatArg {
    PreconditionedGauge Yhat;
    const Gauge Y;
    const Gauge Xinv;
    int dim[QUDA_MAX_DIM];
    int comm_dim[QUDA_MAX_DIM];
    int nFace;
    const int coarseVolumeCB;

    Float *max_h; // host scalar that stores the maximum element of Yhat; a pointer because the memory is pinned
    Float *max_d; // device scalar that stores the maximum element of Yhat

    CalculateYhatArg(const PreconditionedGauge &Yhat, const Gauge Y, const Gauge Xinv, const int *dim,
                     const int *comm_dim, int nFace) :
      Yhat(Yhat),
      Y(Y),
      Xinv(Xinv),
      nFace(nFace),
      coarseVolumeCB(Y.VolumeCB()),
      max_h(nullptr),
      max_d(nullptr)
    {
      for (int i = 0; i < 4; i++) {
        this->comm_dim[i] = comm_dim[i];
        this->dim[i] = dim[i];
      }
    }
  };

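  // Editorial summary, paraphrasing the comments inside computeYhat below: for
  // every coarse site, dimension d and colour indices (i,j) the kernels
  // accumulate the backwards-link product Y^{+\mu} * X^{-\dagger} (reading Y
  // from the ghost buffer when the backwards neighbour lives on another rank)
  // and the forwards-link product X^{-1} * Y^{-\mu}.  When compute_max_only is
  // set, only the maximum |element| of Yhat is tracked (into max_h on the CPU
  // path, max_d on the GPU path) and no links are written.
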
  // complex multiply-add with optimal use of fma
  template <typename Float>
  inline __device__ __host__ void caxpy(const complex<Float> &a, const complex<Float> &x, complex<Float> &y)
  {
    y.x += a.x * x.x;
    y.x -= a.y * x.y;
    y.y += a.y * x.x;
    y.y += a.x * x.y;
  }

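  // For reference (editorial note), the four statements above are the
  // expansion of the complex multiply-add y += a * x:
  //   Re(y) += Re(a)*Re(x) - Im(a)*Im(x)
  //   Im(y) += Im(a)*Re(x) + Re(a)*Im(x)
  // Writing them as separate += / -= updates lets the compiler map each pair
  // onto fused multiply-add instructions.
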
  template <typename Float, int n, bool compute_max_only, typename Arg>
  inline __device__ __host__ Float computeYhat(Arg &arg, int d, int x_cb, int parity, int i, int j)
  {
    constexpr int nDim = 4;
    int coord[nDim];
    getCoords(coord, x_cb, arg.dim, parity);

    const int ghost_idx = ghostFaceIndex<0, nDim>(coord, arg.dim, d, arg.nFace);

    Float yHatMax = 0.0;

    // first do the backwards links Y^{+\mu} * X^{-\dagger}
    if ( arg.comm_dim[d] && (coord[d] - arg.nFace < 0) ) {

      complex<Float> yHat = 0.0;
#pragma unroll
      for (int k = 0; k < n; k++) {
        caxpy(arg.Y.Ghost(d, 1 - parity, ghost_idx, i, k), conj(arg.Xinv(0, parity, x_cb, j, k)), yHat);
      }
      if (compute_max_only) {
        yHatMax = fmax(fabs(yHat.x), fabs(yHat.y));
      } else {
        arg.Yhat.Ghost(d, 1 - parity, ghost_idx, i, j) = yHat;
      }

    } else {
      const int back_idx = linkIndexM1(coord, arg.dim, d);

      complex<Float> yHat = 0.0;
#pragma unroll
      for (int k = 0; k < n; k++) {
        caxpy(arg.Y(d, 1 - parity, back_idx, i, k), conj(arg.Xinv(0, parity, x_cb, j, k)), yHat);
      }
      if (compute_max_only) {
        yHatMax = fmax(fabs(yHat.x), fabs(yHat.y));
      } else {
        arg.Yhat(d, 1 - parity, back_idx, i, j) = yHat;
      }
    }

    // now do the forwards links X^{-1} * Y^{-\mu}
    complex<Float> yHat = 0.0;
#pragma unroll
    for (int k = 0; k < n; k++) {
      caxpy(arg.Xinv(0, parity, x_cb, i, k), arg.Y(d + 4, parity, x_cb, k, j), yHat);
    }
    if (compute_max_only) {
      yHatMax = fmax(yHatMax, fmax(fabs(yHat.x), fabs(yHat.y)));
    } else {
      arg.Yhat(d + 4, parity, x_cb, i, j) = yHat;
    }

    return yHatMax;
  }

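  // Note on the branch above (editorial): when dimension d is partitioned
  // across ranks (arg.comm_dim[d]) and the backwards neighbour of this site
  // falls outside the local volume (coord[d] - nFace < 0), Y is read from the
  // ghost buffer at ghost_idx and the result is stored into Yhat's ghost zone;
  // otherwise the interior neighbour back_idx = linkIndexM1(...) is used.
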
  template <typename Float, int n, bool compute_max_only, typename Arg> void CalculateYhatCPU(Arg &arg)
  {
    Float max = 0.0;
    for (int d = 0; d < 4; d++) {
      for (int parity = 0; parity < 2; parity++) {
#pragma omp parallel for
        for (int x_cb = 0; x_cb < arg.Y.VolumeCB(); x_cb++) {
          for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++) {
              Float max_x = computeYhat<Float, n, compute_max_only>(arg, d, x_cb, parity, i, j);
              if (compute_max_only) max = max > max_x ? max : max_x;
            }
        }
      } // parity
    }   // dimension
    if (compute_max_only) *arg.max_h = max;
  }

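  // The CPU path above loops over dimension, parity, the (OpenMP-parallel)
  // checkerboarded coarse volume and then the n x n block of each link; when
  // compute_max_only is set, the running maximum ends up in the pinned host
  // scalar arg.max_h.
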
  template <typename Float, int n, bool compute_max_only, typename Arg> __global__ void CalculateYhatGPU(Arg arg)
  {
    int x_cb = blockDim.x * blockIdx.x + threadIdx.x;
    if (x_cb >= arg.coarseVolumeCB) return;
    int i_parity = blockDim.y * blockIdx.y + threadIdx.y;
    if (i_parity >= 2 * n) return;
    int j_d = blockDim.z * blockIdx.z + threadIdx.z;
    if (j_d >= 4 * n) return;

    int i = i_parity % n;
    int parity = i_parity / n;
    int j = j_d % n;
    int d = j_d / n;

    Float max = computeYhat<Float, n, compute_max_only>(arg, d, x_cb, parity, i, j);
    if (compute_max_only) atomicMax(arg.max_d, max);
  }

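  // Thread decomposition of the kernel above: x indexes the checkerboarded
  // coarse site, y packs (parity, row i) and z packs (dimension d, column j).
  // The hypothetical helper below is an editorial sketch, not part of the
  // original file: it shows one way such a kernel could be launched directly.
  // QUDA itself drives the launch through its autotuning framework, so the
  // block size here is illustrative only.
  template <typename Float, int n, bool compute_max_only, typename Arg>
  void launchCalculateYhatGPU(Arg &arg)
  {
    // one thread per (coarse site, parity * row, dimension * column) triple;
    // when compute_max_only is set, arg.max_d must point to a zero-initialised
    // device scalar, since the kernel reduces into it with atomicMax
    dim3 block(128, 1, 1);
    dim3 grid((arg.coarseVolumeCB + block.x - 1) / block.x, 2 * n, 4 * n);
    CalculateYhatGPU<Float, n, compute_max_only, Arg><<<grid, block>>>(arg);
  }
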
} // namespace quda