From 4ab765b329440b60d9ae88ee8a9df132e3d9987d Mon Sep 17 00:00:00 2001 From: Erik Strand <erik.strand@cba.mit.edu> Date: Fri, 17 Apr 2020 10:52:32 -0400 Subject: [PATCH] Replace LineObjective with a lambda function --- .../conjugate_gradient_descent.h | 6 ++-- .../optimizers/line_search/line_objective.h | 35 ------------------- 2 files changed, 4 insertions(+), 37 deletions(-) delete mode 100644 optimization/optimizers/line_search/line_objective.h diff --git a/optimization/optimizers/conjugate_gradient_descent/conjugate_gradient_descent.h b/optimization/optimizers/conjugate_gradient_descent/conjugate_gradient_descent.h index 1dcb6b9..9393909 100644 --- a/optimization/optimizers/conjugate_gradient_descent/conjugate_gradient_descent.h +++ b/optimization/optimizers/conjugate_gradient_descent/conjugate_gradient_descent.h @@ -3,7 +3,6 @@ #include "logs/nothing.h" #include "optimizers/line_search/brent.h" -#include "optimizers/line_search/line_objective.h" #include <iostream> @@ -90,7 +89,10 @@ VectorNs<N> const& ConjugateGradientDescent<N>::optimize( gradient_.resize(point_.size()); last_gradient_.resize(point_.size()); - LineObjective<Objective, N> line_objective{objective, point_, direction_, new_point_}; + auto const line_objective = [&](Scalar t, Scalar& value) { + new_point_ = point_ + t * direction_; + objective(new_point_, value); + }; BracketFinder bracket; Brent line_minimizer; diff --git a/optimization/optimizers/line_search/line_objective.h b/optimization/optimizers/line_search/line_objective.h deleted file mode 100644 index 7a7bb55..0000000 --- a/optimization/optimizers/line_search/line_objective.h +++ /dev/null @@ -1,35 +0,0 @@ -#ifndef OPTIMIZATION_LINE_SEARCH_LINE_OBJECTIVE_H -#define OPTIMIZATION_LINE_SEARCH_LINE_OBJECTIVE_H - -#include "utils/vector.h" - -namespace optimization { - -//-------------------------------------------------------------------------------------------------- -template <typename Objective, int32_t N> -struct LineObjective { -public: - LineObjective(Objective& o, VectorNs<N>& x0, VectorNs<N>& dir, VectorNs<N>& x) : - objective_(o), x0_(x0), dir_(dir), x_(x) - {} - - void operator()(Scalar t, Scalar& value) { - x_ = x0_ + t * dir_; - objective_(x_, value); - } - - void operator()(Scalar t, Scalar& value, VectorNs<N>& gradient) { - x_ = x0_ + t * dir_; - objective_(x_, value, gradient); - } - -private: - Objective& objective_; - VectorNs<N>& x0_; - VectorNs<N>& dir_; - VectorNs<N>& x_; -}; - -} - -#endif -- GitLab