diff --git a/optimization/optimizers/conjugate_gradient_descent/conjugate_gradient_descent.h b/optimization/optimizers/conjugate_gradient_descent/conjugate_gradient_descent.h
index 9393909896046c26eba5b52c37a1e1bf6f06816d..c9d0d38bcf5f63fe33e419ce272ca5d32f48b6e2 100644
--- a/optimization/optimizers/conjugate_gradient_descent/conjugate_gradient_descent.h
+++ b/optimization/optimizers/conjugate_gradient_descent/conjugate_gradient_descent.h
@@ -47,17 +47,16 @@ private:
     // algorithm state
     uint32_t n_evaluations_;
     uint32_t n_iterations_;
-    // This object holds the current point (x0) and search direction (dir).
-    // TODO allow this to live here, somehow
-    //LineObjective<Objective, N> line_objective_;
-    VectorNs<N> point_;
-    VectorNs<N> direction_;
-    VectorNs<N> new_point_;
+    Vector point_;     // where we are currently
+    Vector direction_; // the direction we're searching
+    Vector gradient_;  // the gradient at point_
+    // This variable sees double duty. When we're line searching, it holds the candidate points.
+    // Between evaluating the function at the line minimum and computing the next search direction,
+    // it stores the gradient at the point of the previous iteration.
+    Vector new_point_or_last_gradient_;
     Scalar alpha_; // stores the most recent jump distance, which we use as a guess for the next one
     Scalar value_;
     Scalar last_value_;
-    Vector gradient_;
-    Vector last_gradient_;
 
     // algorithm constants
     static constexpr Scalar tiny_ = std::numeric_limits<Scalar>::epsilon();
@@ -85,13 +84,14 @@ VectorNs<N> const& ConjugateGradientDescent<N>::optimize(
     n_evaluations_ = 0;
     n_iterations_ = 0;
     point_ = initial_point;
-    alpha_ = -1;
+    direction_.resize(point_.size());
     gradient_.resize(point_.size());
-    last_gradient_.resize(point_.size());
+    new_point_or_last_gradient_.resize(point_.size());
+    alpha_ = -1;
 
     auto const line_objective = [&](Scalar t, Scalar& value) {
-        new_point_ = point_ + t * direction_;
-        objective(new_point_, value);
+        new_point_or_last_gradient_ = point_ + t * direction_;
+        objective(new_point_or_last_gradient_, value);
     };
     BracketFinder bracket;
     Brent line_minimizer;
@@ -108,19 +108,23 @@ VectorNs<N> const& ConjugateGradientDescent<N>::optimize(
     }
 
     while (true) {
-        // Find the minimum along this direction.
+        // Find the minimum along this direction. The next two lines are the only ones that use
+        // new_point_or_last_gradient_ as new_point_ (happens inside the lambda line_objective).
         bracket.bracket(line_objective, Scalar(0), alpha_);
         alpha_ = line_minimizer.optimize(line_objective, bracket).point;
         n_evaluations_ += bracket.n_evaluations();
         n_evaluations_ += line_minimizer.n_evaluations();
 
-        // Note: at some point new_point_ already had this new value, but right now there's no
-        // guarantee that it's the current value.
-        point_ += alpha_ * direction_;
-        // Can we re-use evaluation here? Would need to rewrite line search to have guarantees on
-        // which point it evaluated last. Probably best to do this after I've written dbrent.
         last_value_ = value_;
-        gradient_.swap(last_gradient_);
+        // Basically this line means new_point_or_last_gradient_ = std::move(gradient_), but I want
+        // to be clear that we're not releasing any memory here. For the rest of the loop,
+        // new_point_or_last_gradient_ stores the last gradient.
+        new_point_or_last_gradient_.swap(gradient_);
 
+        // Again, we evaluated the objective here already during the line search, but we probably
+        // threw that data out.
+        // Note: at some point new_point_or_last_gradient_ already had this new value, but right now
+        // there's no guarantee that it's the current value.
+        point_ += alpha_ * direction_;
         objective(point_, value_, gradient_);
         ++n_iterations_;
 
@@ -151,10 +155,11 @@ VectorNs<N> const& ConjugateGradientDescent<N>::optimize(
             return point_;
         }
 
-        // Choose the next search direction. It is conjugate to all prior directions.
-        // I view this as the start of the next iteration.
-        Scalar const gamma =
-            gradient_.dot(gradient_ - last_gradient_) / last_gradient_.dot(last_gradient_);
+        // Choose the next search direction. It is conjugate to all prior directions. Recall that
+        // new_point_or_last_gradient_ currently stores the last gradient. I view this as the start
+        // of the next iteration.
+        Scalar const gamma = gradient_.dot(gradient_ - new_point_or_last_gradient_)
+            / new_point_or_last_gradient_.dot(new_point_or_last_gradient_);
         direction_ = gradient_ + gamma * direction_;
     }
 }
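
For reference, the gamma computed in the last hunk is the Polak-Ribiere coefficient, gamma = g_k . (g_k - g_{k-1}) / (g_{k-1} . g_{k-1}), with the previous gradient read out of new_point_or_last_gradient_. Below is a minimal sketch of that step in isolation, assuming Eigen dense vectors; the free function next_direction is illustrative only and not part of this header.

    #include <Eigen/Dense>

    // Polak-Ribiere direction update. The arguments mirror the members in the diff
    // (gradient_, new_point_or_last_gradient_, direction_), but this helper is hypothetical,
    // shown only to isolate the formula from the buffer-reuse bookkeeping in the patch.
    Eigen::VectorXd next_direction(Eigen::VectorXd const& gradient,      // gradient at the new point
                                   Eigen::VectorXd const& last_gradient, // gradient at the previous point
                                   Eigen::VectorXd const& direction) {   // previous search direction
        // gamma = g_k . (g_k - g_{k-1}) / (g_{k-1} . g_{k-1})
        double const gamma =
            gradient.dot(gradient - last_gradient) / last_gradient.dot(last_gradient);
        return gradient + gamma * direction;
    }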