// Pi to double precision.
#define HISQ_UNITARIZE_PI 3.14159265358979323846
// 2*pi/3. The expansion is parenthesized so the macro always behaves as a
// single operand: without the parentheses, an expression such as
// `x / HISQ_UNITARIZE_PI23` would expand to `x / PI * 2.0 / 3.0`.
#define HISQ_UNITARIZE_PI23 (HISQ_UNITARIZE_PI*2.0/3.0)
// Host-side copies of the unitarization parameters. Each one is assigned from
// the matching *_h argument in the assignment block further down this file,
// and each is the HOST_ alternative of the un-prefixed macro of the same name
// (HISQ_UNITARIZE_EPS, HISQ_FORCE_FILTER, MAX_DET_ERROR, REUNIT_*), whose
// other alternative is the corresponding DEV_ symbol.
27 static double HOST_HISQ_UNITARIZE_EPS;
28 static double HOST_HISQ_FORCE_FILTER;
// Printed as the bound in the "Error in determinant computed by SVD" warning.
29 static double HOST_MAX_DET_ERROR;
30 static bool HOST_REUNIT_ALLOW_SVD;
31 static bool HOST_REUNIT_SVD_ONLY;
32 static double HOST_REUNIT_SVD_REL_ERROR;
33 static double HOST_REUNIT_SVD_ABS_ERROR;
38 namespace fermion_force{
42 double max_det_error_h,
bool allow_svd_h,
bool svd_only_h,
43 double svd_rel_error_h,
double svd_abs_error_h)
47 static bool not_set=
true;
58 HOST_HISQ_UNITARIZE_EPS = unitarize_eps_h;
59 HOST_HISQ_FORCE_FILTER = hisq_force_filter_h;
60 HOST_MAX_DET_ERROR = max_det_error_h;
61 HOST_REUNIT_ALLOW_SVD = allow_svd_h;
62 HOST_REUNIT_SVD_ONLY = svd_only_h;
63 HOST_REUNIT_SVD_REL_ERROR = svd_rel_error_h;
64 HOST_REUNIT_SVD_ABS_ERROR = svd_abs_error_h;
// Declarations of the six polynomial helpers. In the definitions visible later
// in this file each takes its arguments in the order (u, v, w). Their values
// are divided by the common denominator 2*(w*(u*v-w))^3 and stored in b[0..5]
// in the order C00, C01, C02, C11, C12, C22.
78 Real computeC00(
const Real &,
const Real &,
const Real &);
80 Real computeC01(
const Real &,
const Real &,
const Real &);
82 Real computeC02(
const Real &,
const Real &,
const Real &);
84 Real computeC11(
const Real &,
const Real &,
const Real &);
86 Real computeC12(
const Real &,
const Real &,
const Real &);
88 Real computeC22(
const Real &,
const Real &,
const Real &);
// Takes the three values (u, v, w); the visible call site passes g[0], g[1], g[2].
// NOTE(review): presumably this is where b[0..5] are filled from the
// computeCxx helpers - confirm against the definition.
91 __device__ __host__
void set(
const Real & u,
const Real & v,
const Real & w);
108 Real DerivativeCoefficients<Real>::computeC00(
const Real & u,
const Real & v,
const Real & w){
109 Real result = -pow(w,3)*pow(u,6)
110 + 3*v*pow(w,3)*pow(u,4)
111 + 3*pow(v,4)*w*pow(u,4)
113 - 4*pow(w,4)*pow(u,3)
114 - 12*pow(v,3)*pow(w,2)*pow(u,3)
115 + 16*pow(v,2)*pow(w,3)*pow(u,2)
116 + 3*pow(v,5)*w*pow(u,2)
118 - 3*pow(v,4)*pow(w,2)*u
127 Real DerivativeCoefficients<Real>::computeC01(
const Real & u,
const Real & v,
const Real & w){
128 Real result = - pow(w,2)*pow(u,7)
129 - pow(v,2)*w*pow(u,6)
131 + 6*v*pow(w,2)*pow(u,5)
132 - 5*pow(w,3)*pow(u,4)
133 - pow(v,3)*w*pow(u,4)
134 - 2*pow(v,5)*pow(u,3)
135 - 6*pow(v,2)*pow(w,2)*pow(u,3)
136 + 10*v*pow(w,3)*pow(u,2)
137 + 6*pow(v,4)*w*pow(u,2)
139 - 6*pow(v,3)*pow(w,2)*u
140 + 2*pow(v,2)*pow(w,3);
146 Real DerivativeCoefficients<Real>::computeC02(
const Real & u,
const Real & v,
const Real & w){
147 Real result = pow(w,2)*pow(u,5)
148 + pow(v,2)*w*pow(u,4)
150 - 4*v*pow(w,2)*pow(u,3)
151 + 4*pow(w,3)*pow(u,2)
152 + 3*pow(v,3)*w*pow(u,2)
153 - 3*pow(v,2)*pow(w,2)*u
160 Real DerivativeCoefficients<Real>::computeC11(
const Real & u,
const Real & v,
const Real & w){
161 Real result = - w*pow(u,8)
164 + 4*pow(v,3)*pow(u,5)
165 - 5*pow(w,2)*pow(u,5)
166 - 16*pow(v,2)*w*pow(u,4)
167 - 4*pow(v,4)*pow(u,3)
168 + 16*v*pow(w,2)*pow(u,3)
169 - 3*pow(w,3)*pow(u,2)
170 + 12*pow(v,3)*w*pow(u,2)
171 - 12*pow(v,2)*pow(w,2)*u
178 Real DerivativeCoefficients<Real>::computeC12(
const Real & u,
const Real & v,
const Real & w){
179 Real result = w*pow(u,6)
182 - 2*pow(v,3)*pow(u,3)
183 + 4*pow(w,2)*pow(u,3)
184 + 6*pow(v,2)*w*pow(u,2)
192 Real DerivativeCoefficients<Real>::computeC22(
const Real & u,
const Real & v,
const Real & w){
193 Real result = - w*pow(u,4)
200 template <
class Real>
203 const Real & denominator = 2.0*pow(w*(u*v-w),3);
204 b[0] = computeC00(u,v,w)/denominator;
205 b[1] = computeC01(u,v,w)/denominator;
206 b[2] = computeC02(u,v,w)/denominator;
207 b[3] = computeC11(u,v,w)/denominator;
208 b[4] = computeC12(u,v,w)/denominator;
209 b[5] = computeC22(u,v,w)/denominator;
214 template<
class Cmplx>
219 for(
int k=0; k<3; ++k){
220 for(
int l=0; l<3; ++l){
223 result->operator()(k,l).
x += temp*right(k,l).x;
224 result->operator()(k,l).y += temp*right(k,l).y;
231 template<
class Cmplx>
235 Cmplx temp =
getTrace(left*outer_prod);
236 for(
int k=0; k<3; ++k){
237 for(
int l=0; l<3; ++l){
238 result->operator()(k,l) = temp*right(k,l);
248 T min = fabs(array[0]);
249 for(
int i=1; i<size; ++i){
250 T abs_val = fabs(array[i]);
251 if((abs_val) < min){ min = abs_val; }
261 if( fabs(a-b) < epsilon)
return true;
270 if( fabs((a-b)/b) < epsilon )
return true;
279 template<
class Cmplx>
290 #define REUNIT_SVD_ONLY DEV_REUNIT_SVD_ONLY
292 #define REUNIT_SVD_ONLY HOST_REUNIT_SVD_ONLY
302 g[0] = g[1] = g[2] = c[0]/3.;
304 s = c[1]/3. - c[0]*c[0]/18;
305 r = c[2]/2. - (c[0]/3.)*(c[1] - c[0]*c[0]/9.);
308 #define HISQ_UNITARIZE_EPS DEV_HISQ_UNITARIZE_EPS
310 #define HISQ_UNITARIZE_EPS HOST_HISQ_UNITARIZE_EPS
319 else{ theta = acos(cosTheta)/3.0; }
322 for(
int i=0; i<3; ++i){
341 #define REUNIT_ALLOW_SVD DEV_REUNIT_ALLOW_SVD
342 #define REUNIT_SVD_REL_ERROR DEV_REUNIT_SVD_REL_ERROR
343 #define REUNIT_SVD_ABS_ERROR DEV_REUNIT_SVD_ABS_ERROR
345 #define REUNIT_ALLOW_SVD HOST_REUNIT_ALLOW_SVD
346 #define REUNIT_SVD_REL_ERROR HOST_REUNIT_SVD_REL_ERROR
347 #define REUNIT_SVD_ABS_ERROR HOST_REUNIT_SVD_ABS_ERROR
351 bool perform_svd =
true;
362 computeSVD<Cmplx>(q,tempq,
tmp2,g);
370 #define MAX_DET_ERROR DEV_MAX_DET_ERROR
372 #define MAX_DET_ERROR HOST_MAX_DET_ERROR
375 #if (!defined(__CUDA_ARCH__) || (__COMPUTE_CAPABILITY__ >= 200))
376 printf(
"Warning: Error in determinant computed by SVD : %g > %g\n", fabs(gprod-determinant),
MAX_DET_ERROR);
381 atomicAdd(unitarization_failed,1);
383 (*unitarization_failed)++;
391 #define HISQ_FORCE_FILTER DEV_HISQ_FORCE_FILTER
393 #define HISQ_FORCE_FILTER HOST_HISQ_FORCE_FILTER
397 for(
int i=0; i<3; ++i){
407 for(
int i=0; i<3; ++i) c[i] = sqrt(g[i]);
410 g[0] = c[0]+c[1]+c[2];
411 g[1] = c[0]*c[1] + c[0]*c[2] + c[1]*c[2];
412 g[2] = c[0]*c[1]*c[2];
415 deriv_coeffs->set(g[0], g[1], g[2]);
418 c[0] = (g[0]*g[1]*g[1] - g[2]*(g[0]*g[0]+g[1]))/denominator;
419 c[1] = (-g[0]*g[0]*g[0] - g[2] + 2.*g[0]*g[1])/denominator;
420 c[2] = g[0]/denominator;
422 tempq = c[1]*q + c[2]*qsq;
424 tempq(0,0).x += c[0];
425 tempq(1,1).x += c[0];
426 tempq(2,2).x += c[0];
439 template<
class Cmplx>
453 reciprocalRoot<Cmplx>(&rsqrt_q, &deriv_coeffs, f, q, unitarization_failed);
456 b[0] = deriv_coeffs.getB00();
457 b[1] = deriv_coeffs.getB01();
458 b[2] = deriv_coeffs.getB02();
459 b[3] = deriv_coeffs.getB11();
460 b[4] = deriv_coeffs.getB12();
461 b[5] = deriv_coeffs.getB22();
466 local_result = rsqrt_q*outer_prod;
475 temp = f[1]*v_dagger + f[2]*qv_dagger;
479 temp = f[1]*v + f[2]*v*q;
480 local_result = local_result + outer_prod*temp*v_dagger + f[2]*q*outer_prod*vv_dagger;
482 local_result = local_result + v_dagger*conj_outer_prod*
conj(temp) + f[2]*qv_dagger*conj_outer_prod*v_dagger;
487 Matrix<Cmplx,3> pv_dagger = b[0]*v_dagger + b[1]*qv_dagger + b[2]*qsqv_dagger;
490 Matrix<Cmplx,3> rv_dagger = b[1]*v_dagger + b[3]*qv_dagger + b[4]*qsqv_dagger;
494 Matrix<Cmplx,3> sv_dagger = b[2]*v_dagger + b[4]*qv_dagger + b[5]*qsqv_dagger;
502 template<
class Cmplx>
504 const Cmplx* old_force_even,
const Cmplx* old_force_odd,
505 Cmplx* force_even, Cmplx* force_odd,
506 int* unitarization_failed)
509 int mem_idx = blockIdx.x*blockDim.x + threadIdx.x;
511 const int HALF_VOLUME = threads/2;
512 if(mem_idx >= threads)
return;
516 const Cmplx* old_force;
520 old_force = old_force_even;
521 if(mem_idx >= HALF_VOLUME){
522 mem_idx = mem_idx - HALF_VOLUME;
525 old_force = old_force_odd;
536 getUnitarizeForceSite<double2>(v, oprod, &result, unitarization_failed);
547 int num_failures = 0;
555 for(
int i=0; i<cpuGauge.
Volume(); ++i){
560 getUnitarizeForceSite<double2>(v, old_force, &new_force, &num_failures);
565 getUnitarizeForceSite<double2>(v, old_force, &new_force, &num_failures);
572 for(
int i=0; i<cpuGauge.
Volume(); ++i){
576 getUnitarizeForceSite<double2>(v, old_force, &new_force, &num_failures);
581 getUnitarizeForceSite<double2>(v, old_force, &new_force, &num_failures);
587 errorQuda(
"Only MILC and QDP gauge orders supported\n");
599 int sharedBytesPerThread()
const {
return 0; }
600 int sharedBytesPerBlock(
const TuneParam &)
const {
return 0; }
606 bool advanceBlockDim(
TuneParam &param)
const
608 const unsigned int max_threads =
deviceProp.maxThreadsDim[0];
609 const unsigned int max_blocks =
deviceProp.maxGridSize[0];
610 const unsigned int max_shared = 16384;
614 param.
block.x += step;
615 if (param.
block.x > max_threads || sharedBytesPerThread()*param.
block.x > max_shared) {
616 param.
block = dim3((threads+max_blocks-1)/max_blocks, 1, 1);
617 param.
block.x = ((param.
block.x+step-1) / step) * step;
618 if (param.
block.x > max_threads)
errorQuda(
"Local lattice volume is too large for device");
630 oldForce(oldForce), gauge(gauge), newForce(newForce), fails(fails) { ; }
638 (
const float2*)oldForce.
Even_p(), (
const float2*)oldForce.
Odd_p(),
639 (float2*)newForce.
Even_p(), (float2*)newForce.
Odd_p(),
643 (
const double2*)oldForce.
Even_p(), (
const double2*)oldForce.
Odd_p(),
644 (double2*)newForce.
Even_p(), (double2*)newForce.
Odd_p(),
650 void postTune() { cudaMemset(fails, 0,
sizeof(
int)); }
654 const unsigned int max_threads =
deviceProp.maxThreadsDim[0];
655 const unsigned int max_blocks =
deviceProp.maxGridSize[0];
658 param.
block = dim3((threads+max_blocks-1)/max_blocks, 1, 1);
659 param.
block.x = ((param.
block.x+step-1) / step) * step;
660 if (param.
block.x > max_threads)
errorQuda(
"Local lattice volume is too large for device");
662 param.
shared_bytes = sharedBytesPerThread()*param.
block.x > sharedBytesPerBlock(param) ?
663 sharedBytesPerThread()*param.
block.x : sharedBytesPerBlock(param);
671 long long flops()
const {
return 0; }
674 std::stringstream vol, aux;
675 vol << gauge.
X()[0] <<
"x";
676 vol << gauge.
X()[1] <<
"x";
677 vol << gauge.
X()[2] <<
"x";
678 vol << gauge.
X()[3] <<
"x";
680 aux <<
"stride=" << gauge.
Stride();
681 return TuneKey(vol.str(),
typeid(*this).name(), aux.str());
687 UnitarizeForceCuda unitarizeForce(cudaOldForce, cudaGauge, *cudaNewForce, unitarization_failed);
688 unitarizeForce.
apply(0);