QUDA  1.0.0
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
su3_project.cuh
Go to the documentation of this file.
1 #pragma once
2 
11 #include <quda_matrix.h>
12 
13 namespace quda {
14 
/**
 * @brief Check the unitarity of a matrix to a given absolute tolerance.
 *
 * Two tests are applied, the cheap one first:
 *   1. element-wise U == (U^{-1})^dagger, using the supplied inverse, i.e.
 *      Re U(r,c) == Re inv(c,r) and Im U(r,c) == -Im inv(c,r);
 *   2. element-wise U^dagger * U == identity (only evaluated if test 1 passed).
 * Every real and imaginary component is compared with @p tol via fabs(...) > tol.
 *
 * @param inv  Previously computed inverse of @p in
 * @param in   Matrix being tested
 * @param tol  Absolute tolerance applied to each real/imaginary component
 * @return false as soon as any component exceeds @p tol, true otherwise
 */
template <typename Matrix, typename Float>
__host__ __device__ inline bool checkUnitary(const Matrix &inv, const Matrix &in, const Float tol)
{
  // Stage 1: compare U against the conjugate transpose of its inverse.
#pragma unroll
  for (int r = 0; r < in.size(); r++) {
#pragma unroll
    for (int c = 0; c < in.size(); c++) {
      if (fabs(in(r, c).real() - inv(c, r).real()) > tol) return false;
      if (fabs(in(r, c).imag() + inv(c, r).imag()) > tol) return false;
    }
  }

  // Stage 2: U^dagger * U must be the identity. This needs a full matrix
  // product, so it is deferred until stage 1 has already succeeded.
  const Matrix gram = conj(in) * in;
#pragma unroll
  for (int r = 0; r < in.size(); r++) {
    // Diagonal entries must be 1 + 0i.
    if (fabs(gram(r, r).real() - static_cast<Float>(1.0)) > tol) return false;
    if (fabs(gram(r, r).imag()) > tol) return false;

    // Each strictly-lower entry and its mirrored upper entry must vanish.
#pragma unroll
    for (int c = 0; c < r; c++) {
      if (fabs(gram(r, c).real()) > tol || fabs(gram(r, c).imag()) > tol) return false;
      if (fabs(gram(c, r).real()) > tol || fabs(gram(c, r).imag()) > tol) return false;
    }
  }

  return true;
}
57 
/**
 * @brief Debug helper: print, for every (row, col), the components compared by
 *        the first stage of checkUnitary() together with their deviation.
 *
 * For each element two lines are written with printf:
 *   "TESTR:" Re in(row,col), Re inv(col,row), |Re in(row,col) - Re inv(col,row)|
 *   "TESTI:" Im in(row,col), Im inv(col,row), |Im in(row,col) + Im inv(col,row)|
 *
 * @param inv  Inverse of @p in
 * @param in   Matrix being inspected
 */
template <typename Matrix>
__host__ __device__ void checkUnitaryPrint(const Matrix &inv, const Matrix &in)
{
  for (int row = 0; row < in.size(); ++row) {
    for (int col = 0; col < in.size(); ++col) {
      // For a unitary matrix the real parts agree ...
      const auto re_u = in(row, col).real();
      const auto re_inv = inv(col, row).real();
      printf("TESTR: %+.13le %+.13le %+.13le\n", re_u, re_inv, fabs(re_u - re_inv));

      // ... and the imaginary parts are negatives of each other.
      const auto im_u = in(row, col).imag();
      const auto im_inv = inv(col, row).imag();
      printf("TESTI: %+.13le %+.13le %+.13le\n", im_u, im_inv, fabs(im_u + im_inv));
    }
  }
}
77 
/**
 * @brief Project a 3x3 complex matrix onto the SU(3) group.
 *
 * Step 1 (unitarize): Newton iteration for the unitary polar factor,
 *     U_{k+1} = (U_k + (U_k^{-1})^dagger) / 2,
 * starting from U_0 = in, repeated until checkUnitary() accepts the iterate or
 * max_iter iterations have been performed.
 * NOTE(review): reaching max_iter is not reported; the projection simply
 * continues with the last iterate.
 *
 * Step 2 (fix determinant): with d = det(W) of the unitary result W, scale by
 * |d|^{-1/3} * exp(-i arg(d) / 3), so the returned matrix has determinant 1.
 *
 * @param[in,out] in   Matrix to project; overwritten with the SU(3) result
 * @param[in]     tol  Absolute tolerance forwarded to checkUnitary()
 */
template <typename Float>
__host__ __device__ inline void polarSu3(Matrix<complex<Float>,3> &in, Float tol)
{
  constexpr Float negative_third = -1.0/3.0;
  constexpr Float negative_sixth = -1.0/6.0;
  // Float-typed so the scaling below stays in the matrix's own precision
  // instead of going through a double literal.
  constexpr Float half = 0.5;

  // Newton iteration state: `out` is the current iterate, `inv` its inverse.
  Matrix<complex<Float>,3> out = in;
  Matrix<complex<Float>,3> inv = inverse(in);

  constexpr int max_iter = 100;
  int i = 0;
  do { // iterate until matrix is unitary
    // conj() of a Matrix is its conjugate transpose (cf. U^dag in checkUnitary).
    out = half * (out + conj(inv));
    inv = inverse(out);
  } while (!checkUnitary(inv, out, tol) && ++i < max_iter);

  // Now project onto the special unitary group.
  complex<Float> det = getDeterminant(out);
  // norm() is |z|^2, so (|det|^2)^{-1/6} = |det|^{-1/3}.
  Float mod = pow(norm(det), negative_sixth);
  Float angle = arg(det);

  // cTemp = cos(-angle/3) + i sin(-angle/3) removes the determinant's phase.
  complex<Float> cTemp;
  sincos(negative_third * angle, &cTemp.y, &cTemp.x);

  in = (mod*cTemp)*out;
}
111 
112 
113 } // namespace quda
__host__ __device__ void polarSu3(Matrix< complex< Float >, 3 > &in, Float tol)
Project the input matrix onto the SU(3) group. First unitarize the matrix and then project onto the special unitary group.
Definition: su3_project.cuh:87
__device__ __host__ constexpr int size() const
Definition: quda_matrix.h:74
__host__ __device__ ValueType norm(const complex< ValueType > &z)
Returns the magnitude of z squared.
__host__ __device__ bool checkUnitary(const Matrix &inv, const Matrix &in, const Float tol)
Check the unitarity of the input matrix to a given tolerance.
Definition: su3_project.cuh:24
__host__ __device__ void checkUnitaryPrint(const Matrix &inv, const Matrix &in)
Print out deviation for each component (used for debugging only).
Definition: su3_project.cuh:66
double tol
Definition: test_util.cpp:1656
cpuColorSpinorField * in
__host__ __device__ ValueType pow(ValueType x, ExponentType e)
Definition: complex_quda.h:111
__device__ __host__ Matrix< T, 3 > inverse(const Matrix< T, 3 > &u)
Definition: quda_matrix.h:611
cpuColorSpinorField * out
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
static int mod(int a, int b)
__device__ __host__ T getDeterminant(const Mat< T, 3 > &a)
Definition: quda_matrix.h:422
__host__ __device__ ValueType conj(ValueType x)
Definition: complex_quda.h:130