diff --git a/optimization/optimizers/conjugate_gradient_descent/conjugate_gradient_descent.h b/optimization/optimizers/conjugate_gradient_descent/conjugate_gradient_descent.h index c9d0d38bcf5f63fe33e419ce272ca5d32f48b6e2..85fe5be55a12333c7949a187cf83fc91f9a3973a 100644 --- a/optimization/optimizers/conjugate_gradient_descent/conjugate_gradient_descent.h +++ b/optimization/optimizers/conjugate_gradient_descent/conjugate_gradient_descent.h @@ -32,12 +32,37 @@ public: VectorNs<N> const& gradient() const { return gradient_; } Scalar value() const { return value_; } + // There are four optimize overloads: with or without a log, and taking an lvalue or an rvalue + // initial point. All of them ultimately store the point in point_ and call optimize_impl. + // Note: the Vector const& and Vector&& cases are written out explicitly, rather than taking a + // Vector by value and moving it, because Eigen warns against passing its objects by value: + // https://eigen.tuxfamily.org/dox/group__TopicPassingByValue.html template <typename Objective> - Vector const& optimize(Objective& objective, Vector const& initial_point); + Vector const& optimize(Objective& objective, Vector const& initial_point) { + ConjugateGradientDescentLogNothing log; + return optimize(objective, initial_point, log); + } template <typename Objective, typename Log> - Vector const& optimize(Objective& objective, Vector const& initial_point, Log& log); + Vector const& optimize(Objective& objective, Vector const& initial_point, Log& log) { + point_ = initial_point; + return optimize_impl(objective, log); + } + template <typename Objective> + Vector const& optimize(Objective& objective, Vector&& initial_point) { + ConjugateGradientDescentLogNothing log; + return optimize(objective, std::move(initial_point), log); + } + template <typename Objective, typename Log> + Vector const& optimize(Objective& objective, Vector&& initial_point, Log& log) { + point_ = std::move(initial_point); + return optimize_impl(objective, log); + } private: + // Actually does the 
optimization, once point_ has been initialized. + template <typename Objective, typename Log> + Vector const& optimize_impl(Objective& objective, Log& log); + // hyperparameters Scalar gradient_threshold_; Scalar progress_threshold_; @@ -62,28 +87,16 @@ private: static constexpr Scalar tiny_ = std::numeric_limits<Scalar>::epsilon(); }; -//.................................................................................................. -template <int32_t N> -template <typename Objective> -VectorNs<N> const& ConjugateGradientDescent<N>::optimize( - Objective& objective, - Vector const& initial_point -) { - ConjugateGradientDescentLogNothing log; - return optimize(objective, initial_point, log); -} - //.................................................................................................. template <int32_t N> template <typename Objective, typename Log> -VectorNs<N> const& ConjugateGradientDescent<N>::optimize( +VectorNs<N> const& ConjugateGradientDescent<N>::optimize_impl( Objective& objective, - Vector const& initial_point, Log& log ) { n_evaluations_ = 0; n_iterations_ = 0; - point_ = initial_point; + // point_ has already been set by the optimize overload that called us direction_.resize(point_.size()); gradient_.resize(point_.size()); new_point_or_last_gradient_.resize(point_.size()); diff --git a/optimization/optimizers/conjugate_gradient_descent/main.cpp b/optimization/optimizers/conjugate_gradient_descent/main.cpp index 73276d594679f83d185a789988f26f8f7121a711..df00eca0c9bf0336f4dc01d2e0628cef453dfaa4 100644 --- a/optimization/optimizers/conjugate_gradient_descent/main.cpp +++ b/optimization/optimizers/conjugate_gradient_descent/main.cpp @@ -60,9 +60,9 @@ int main(int const argc, char const** argv) { // Only log stuff if we're going to use it. 
if (log_file_path.empty() && vis_file_path.empty()) { - optimizer.optimize(objective, initial_point); + optimizer.optimize(objective, std::move(initial_point)); } else { - optimizer.optimize(objective, initial_point, log); + optimizer.optimize(objective, std::move(initial_point), log); } std::cout << "n evaluations: " << optimizer.n_evaluations() << '\n';