10 #include <Eigen/Dense>
// what() of the exception type reported when an Eigen eigen-solver fails
// (presumably eigenvector_exception, which calc_min_eigvec throws when
// es.info() != Eigen::Success -- confirm).
// Returns a fixed, static message string.
31 virtual const char*
what()
const throw()
override
33 return "Could not compute matrix eigenvectors.";
// Class template parameters:
//   n_states - dimension of a single input vector x (rows of x_t)
//   n_params - number of estimated parameters (rows of theta_t)
60 template <
int n_states,
int n_params>
// Compile-time check that the input dimension does not exceed the
// number of parameters.
63 static_assert(n_states <= n_params,
"Sorry, n_states must be smaller or equal to n_params");
// Short aliases for the template dimensions; lr is the length of the
// transformed vectors z (one less than the number of parameters).
67 static const int k = n_states;
68 static const int l = n_params;
69 static const int lr =
l - 1;
// Single input vector (k x 1).
71 using x_t = Eigen::Matrix<double, k, 1>;
// Input vectors stored one per column (k x dynamic; -1 == Eigen::Dynamic).
72 using xs_t = Eigen::Matrix<double,
k, -1>;
73 using u_t = Eigen::Matrix<double, l, 1>;
// k x k matrix attached to each input vector (presumably its covariance --
// confirm); it is propagated as dzdx*P*dzdx^T in calc_Bs.
74 using P_t = Eigen::Matrix<double, k, k>;
// One P_t per input vector.
75 using Ps_t = std::vector<P_t>;
// Transformed vector z (lr x 1), produced by the user-supplied f_z.
77 using z_t = Eigen::Matrix<double, lr, 1>;
// Transformed vectors, one per column (lr x dynamic).
78 using zs_t = Eigen::Matrix<double,
lr, -1>;
// Matrix mapping a k-dimensional x to the lr-dimensional z space
// (dz/dx by name).
81 using dzdx_t = Eigen::Matrix<double, lr, k>;
// Estimated parameter vector (l x 1).
85 using theta_t = Eigen::Matrix<double, l, 1>;
// lr x lr matrix; calc_A builds it as the outer product z*z^T.
89 using A_t = Eigen::Matrix<double, lr, lr>;
// One B_t per sample (B_t itself is declared outside this view).
91 using Bs_t = std::vector<B_t>;
// Per-sample scalar weights stored as a 1 x dynamic row vector.
94 using betas_t = Eigen::Matrix<double, 1, -1>;
// Duration type used to store the fitting timeout.
95 using ms_t = std::chrono::milliseconds;
// Tail of a constructor's member-initializer list (its signature is not
// visible here). The object starts not initialized and with no stored
// ALS or iterative estimate, so the asserts in fit()/accessors fire until
// a properly constructed object is used.
109 m_initialized(false),
116 m_ALS_theta_set(false),
117 m_last_theta_set(false)
// Constructs the estimator with a z-transform function and a dz/dx function.
//   f_z        - maps the matrix of input vectors (xs_t) to zs_t (it is
//                presumably stored as m_f_z, which fit() calls as m_f_z(xs)).
//   f_dzdx     - maps a single input vector x_t to its dz/dx matrix
//                (presumably stored as m_f_dzdx).
//   min_dtheta - fit() stops iterating once the change of theta between
//                two iterations drops below this value.
//   max_its    - upper bound on the number of fit() iterations.
// The timeout is set to zero, which disables the timeout check in fit().
135 RHEIV(
const f_z_t& f_z,
const f_dzdx_t& f_dzdx,
const double min_dtheta = 1e-15,
const unsigned max_its = 100)
140 m_min_dtheta(min_dtheta),
142 m_timeout(ms_t::zero()),
144 m_ALS_theta_set(false),
145 m_last_theta_set(false)
// Same as the (f_z, f_dzdx, min_dtheta, max_its) constructor, plus:
//   timeout      - any std::chrono duration; converted to milliseconds with
//                  duration_cast (truncating). A zero timeout disables the
//                  timeout check in fit().
//   debug_nth_it - when > 0, fit() prints progress every debug_nth_it-th
//                  iteration; the default -1 disables debug output.
165 template <
typename time_t>
166 RHEIV(
const f_z_t& f_z,
const f_dzdx_t& f_dzdx,
const double min_dtheta = 1e-15,
const unsigned max_its = 100,
const time_t& timeout = std::chrono::duration_cast<time_t>(ms_t::zero()),
const int debug_nth_it = -1)
171 m_min_dtheta(min_dtheta),
173 m_timeout(std::chrono::duration_cast<ms_t>(timeout)),
174 m_debug_nth_it(debug_nth_it),
175 m_ALS_theta_set(false),
176 m_last_theta_set(false)
// Variant taking a single fixed dz/dx matrix instead of a function. How it is
// turned into m_f_dzdx is not visible here (presumably a function that
// returns dzdx for every sample -- confirm).
//   min_dtheta, max_its - as in the function-based constructor.
// The timeout is set to zero, which disables the timeout check in fit().
189 RHEIV(
const f_z_t& f_z,
const dzdx_t& dzdx,
const double min_dtheta = 1e-15,
const unsigned max_its = 100)
201 m_min_dtheta(min_dtheta),
203 m_timeout(ms_t::zero()),
205 m_ALS_theta_set(
false),
206 m_last_theta_set(
false)
// Fixed dz/dx matrix variant with a timeout and debug output:
//   timeout      - any std::chrono duration, stored as milliseconds
//                  (duration_cast truncates); zero disables the check.
//   debug_nth_it - when > 0, progress is printed every debug_nth_it-th
//                  iteration of fit(); the default -1 disables it.
222 template <
typename time_t>
223 RHEIV(
const f_z_t& f_z,
const dzdx_t& dzdx,
const double min_dtheta = 1e-15,
const unsigned max_its = 100,
const time_t& timeout = std::chrono::duration_cast<time_t>(ms_t::zero()),
const int debug_nth_it = -1)
235 m_min_dtheta(min_dtheta),
237 m_timeout(std::chrono::duration_cast<ms_t>(timeout)),
238 m_debug_nth_it(debug_nth_it),
239 m_ALS_theta_set(
false),
240 m_last_theta_set(
false)
// Body of the iterative fit (its signature is not visible here; xs holds one
// input vector per column and Ps one matrix per input vector, as checked by
// the size assert below).
//
// Flow: compute z for every sample, use the ALS solution as the starting
// estimate, then repeatedly re-solve the generalized eigenproblem built by
// calc_MN until the change of theta drops below m_min_dtheta, m_max_its
// iterations are done, or the optional timeout expires.
260 assert(m_initialized);
261 assert((
size_t)xs.cols() == Ps.size());
// Elapsed time is measured with the monotonic steady_clock. system_clock is
// wall-clock time and can jump (NTP sync, manual changes), which could make
// the timeout fire spuriously or never fire.
262 const std::chrono::steady_clock::time_point fit_start = std::chrono::steady_clock::now();
264 const zs_t zs = m_f_z(xs);
265 const dzdxs_t dzdxs = precalculate_dxdzs(xs, m_f_dzdx);
// The ALS solution seeds the iteration and is also kept as m_ALS_theta.
268 m_ALS_theta = fit_ALS_impl(zs);
269 m_last_theta = m_ALS_theta;
270 m_ALS_theta_set = m_last_theta_set =
true;
// eta = the first lr elements of theta.
271 eta_t eta = m_ALS_theta.template block<lr, 1>(0, 0);
273 for (
unsigned it = 0; it < m_max_its; it++)
275 const std::chrono::steady_clock::time_point now = std::chrono::steady_clock::now();
276 const auto fit_dur = now - fit_start;
// A zero m_timeout disables the timeout check.
277 if (m_timeout > ms_t::zero() && fit_dur > m_timeout)
279 if (m_debug_nth_it > 0)
280 std::cerr <<
"[RHEIV]: Ending at iteration " << it <<
" (max " << m_max_its <<
") - timed out." << std::endl;
283 const theta_t prev_theta = m_last_theta;
// Rebuild M and N around the current eta and take the eigenvector of the
// smallest generalized eigenvalue as the new eta.
284 const auto [M, N, zc] = calc_MN(eta, zs, Ps, dzdxs);
285 eta = calc_min_eigvec(M, N);
286 m_last_theta = calc_theta(eta, zc);
287 const double dtheta = calc_dtheta(prev_theta, m_last_theta);
288 if (m_debug_nth_it > 0 && it % m_debug_nth_it == 0)
289 std::cout <<
"[RHEIV]: iteration " << it <<
" (max " << m_max_its <<
"), dtheta: " << dtheta <<
" (min " << m_min_dtheta <<
") " << std::endl;
// Convergence check; the statement it guards is not visible here
// (presumably ends the iteration).
290 if (dtheta < m_min_dtheta)
// Convenience overload taking iterator ranges: [xs_begin, xs_end) over input
// vectors and [Ps_begin, Ps_end) over their P matrices. The ranges are packed
// into an xs_t and a Ps_t via cont_to_eigen / cont_to_vector. The rest of the
// body (presumably forwarding to fit(xs, Ps)) is not visible here.
312 template <
class T_it1,
314 theta_t fit(
const T_it1& xs_begin,
const T_it1& xs_end,
const T_it2& Ps_begin,
const T_it2& Ps_end)
316 const xs_t xs = cont_to_eigen(xs_begin, xs_end);
317 const Ps_t Ps = cont_to_vector(Ps_begin, Ps_end);
// Body fragment of the ALS-only fit (signature not visible): transforms xs
// with m_f_z and stores the closed-form ALS estimate without running the
// iterative refinement. Only m_ALS_theta and its flag are updated in the
// visible lines.
337 assert(m_initialized);
339 const zs_t zs = m_f_z(xs);
340 m_ALS_theta = fit_ALS_impl(zs);
341 m_ALS_theta_set =
true;
// Precondition of an accessor (signature not visible): an iterative estimate
// must already be stored. m_last_theta_set is set by fit().
360 assert(m_last_theta_set);
// Precondition of an accessor (signature not visible): an ALS estimate must
// already be stored. m_ALS_theta_set is set by fit() and the ALS-only fit.
379 assert(m_ALS_theta_set);
// True once m_ALS_theta holds a valid ALS estimate.
396 bool m_ALS_theta_set;
// True once m_last_theta holds a valid estimate from fit().
398 bool m_last_theta_set;
// Most recent estimate produced by fit() (updated on every iteration).
399 theta_t m_last_theta;
// Builds the matrices of the generalized eigenproblem for one iteration,
// given the current estimate eta.
//   zs    - transformed samples, one per column
//   Ps    - per-sample P matrices
//   dzdxs - per-sample dz/dx matrices
// Returns (M, N, zc), where zc is the beta-weighted centroid of zs.
// Each sample i contributes (eta^T A_i eta) * B_i * beta_i^2 to N. A_i is the
// outer product of the centred z_i; B_i and beta_i come from calc_Bs and
// calc_betas. The initialisation of M/N and the M accumulation are not
// visible here.
403 std::tuple<M_t, N_t, z_t> calc_MN(
const eta_t& eta,
const zs_t& zs,
const Ps_t& Ps,
const dzdxs_t& dzdxs)
const
405 const int n = zs.cols();
407 const Bs_t Bs = calc_Bs(Ps, dzdxs);
408 const betas_t betas = calc_betas(eta, Bs);
// zrs = zs centred on the beta-weighted centroid zc.
409 const auto [zrs, zc] = reduce_zs(zs, betas);
413 for (
int it = 0; it < n; it++)
415 const double beta = betas(it);
416 const z_t& zr = zrs.col(it);
417 const B_t& B = Bs.at(it);
418 const A_t A = calc_A(zr);
420 N += (eta.transpose()*A*eta)*B*(beta*beta);
// Computes B_i = dzdx_i * P_i * dzdx_i^T for every sample, i.e. maps each
// k x k P matrix into the lr x lr z space.
// The sample count is taken from Ps; dzdxs must have at least as many
// entries (.at() throws std::out_of_range otherwise).
427 Bs_t calc_Bs(
const Ps_t& Ps,
const dzdxs_t& dzdxs)
const
429 const int n = Ps.size();
432 for (
int it = 0; it < n; it++)
434 const P_t& P = Ps.at(it);
435 const dzdx_t& dzdx = dzdxs.at(it);
436 const B_t B = dzdx*P*dzdx.transpose();
444 betas_t calc_betas(
const eta_t& eta,
const Bs_t& Bs)
const
446 const int n = Bs.size();
448 for (
int it = 0; it < n; it++)
450 const B_t& B = Bs.at(it);
451 const double beta = 1.0/(eta.transpose()*B*eta);
// Returns the outer product z * z^T (an lr x lr symmetric matrix).
459 A_t calc_A(
const z_t& z)
const
461 return z*z.transpose();
// Body of calc_dtheta (signature not visible): the distance between two
// estimates, taken as the smaller of |th1 - th2| and |th1 + th2|. theta and
// -theta therefore count as the same estimate (eigenvector sign is arbitrary).
468 return std::min( (th1 - th2).norm(), (th1 + th2).norm() );
// Assembles the full parameter vector from eta and the centroid zc.
// alpha = -zc^T eta is chosen so that [eta; alpha] . [zc; 1] == 0; the
// stacked vector is then normalised to unit length.
473 theta_t calc_theta(
const eta_t& eta,
const z_t& zc)
const
475 const double alpha = -zc.transpose()*eta;
476 const theta_t theta( (
theta_t() << eta, alpha).finished().normalized() );
// Solves the generalized symmetric eigenproblem M v = lambda N v and takes
// the unit-length eigenvector of the smallest eigenvalue. Eigen returns the
// eigenvalues in increasing order, so this is column 0.
// Throws eigenvector_exception if the solver does not report success.
// NOTE(review): Eigen's default mode for this solver expects N to be positive
// definite.
482 eta_t calc_min_eigvec(
const M_t& M,
const N_t& N)
const
484 const Eigen::GeneralizedSelfAdjointEigenSolver<M_t> es(M, N);
485 if (es.info() != Eigen::Success)
486 throw eigenvector_exception();
488 const eta_t evec = es.eigenvectors().col(0).normalized();
// Standard symmetric eigenproblem variant: takes the unit-length eigenvector
// of M's smallest eigenvalue (column 0, since Eigen sorts eigenvalues in
// increasing order). Throws eigenvector_exception on solver failure.
494 eta_t calc_min_eigvec(
const M_t& M)
const
496 const Eigen::SelfAdjointEigenSolver<M_t> es(M);
497 if (es.info() != Eigen::Success)
498 throw eigenvector_exception();
500 const eta_t evec = es.eigenvectors().col(0).normalized();
// Body fragment of precalculate_dxdzs (signature not visible; fit() calls it
// as precalculate_dxdzs(xs, m_f_dzdx)). Evaluates f_dzdx once per input
// column and appends each result, so the output is in column order.
508 const int n = xs.cols();
511 for (
int it = 0; it < n; it++)
513 const x_t& x = xs.col(it);
514 ret.push_back(f_dzdx(x));
// Beta-weighted centroid of the columns of zs: sum_i(beta_i * z_i) / sum_i(beta_i).
// NOTE(review): a zero sum of betas is not guarded.
521 z_t calc_centroid(
const zs_t& zs,
const betas_t& betas)
const
523 const z_t zc = (zs.array().rowwise() * betas.array()).rowwise().sum()/betas.sum();
// Centres the samples on their beta-weighted centroid. Returns the pair
// (centred samples, centroid), matching how calc_MN unpacks it as [zrs, zc].
529 std::pair<zs_t, z_t> reduce_zs(
const zs_t& zs,
const betas_t& betas)
const
531 const z_t zc = calc_centroid(zs, betas);
532 const zs_t zrs = zs.colwise() - zc;
// Packs the range [begin, end) of n_states-dimensional vectors into an xs_t,
// one vector per column. `end - begin` requires iterators that support
// subtraction (e.g. random-access). The column index i is declared and
// advanced in lines not visible here.
538 template <
typename T_it>
539 xs_t cont_to_eigen(
const T_it& begin,
const T_it& end)
541 const auto n = end-begin;
542 xs_t ret(n_states, n);
544 for (T_it it = begin; it != end; it++)
546 ret.template block<n_states, 1>(0, i) = *it;
// Copies the range [begin, end) of P matrices into a Ps_t. As with
// cont_to_eigen, `end - begin` requires subtractable iterators. The loop body
// is not visible here.
554 template <
typename T_it>
555 Ps_t cont_to_vector(
const T_it& begin,
const T_it& end)
557 const auto n = end-begin;
560 for (T_it it = begin; it != end; it++)
// Unweighted centroid: the mean of the columns of zs.
574 z_t calc_centroid(
const zs_t& zs)
const
576 return zs.rowwise().mean();
// Centres the samples on their unweighted mean. The return type is
// (centred samples, centroid), matching how fit_ALS_impl unpacks it as
// [zrs, zc]; the return statement itself is not visible here.
579 std::pair<zs_t, z_t> reduce_zs(
const zs_t& zs)
const
581 const z_t zc = calc_centroid(zs);
582 const zs_t zrs = zs.colwise() - zc;
// Body fragment of fit_ALS_impl (signature not visible; called as
// fit_ALS_impl(zs)). Centres the samples on their mean and forms the outer
// product of every centred sample. The accumulation into M is in lines not
// visible here. It then takes the eigenvector of M's smallest eigenvalue as
// eta and builds theta from it and the centroid.
588 const auto [zrs, zc] = reduce_zs(zs);
590 for (
int it = 0; it < zrs.cols(); it++)
592 const z_t& z = zrs.col(it);
593 const A_t A = calc_A(z);
596 const eta_t eta = calc_min_eigvec(M);
597 const theta_t theta = calc_theta(eta, zc);