Commit ca128b84 authored by Erik Strand

Reduce memory cost of conjugate gradient descent

Before I had five vectors; now just four. The realization is that the
last gradient and the temporary new point (used for the line search) are
never needed at the same time, so they can share memory.
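To make the sharing concrete, here is a minimal standalone sketch of the pattern, not the actual class: a single pre-allocated buffer serves as the line-search trial point while searching, and is then swapped with the gradient so it holds the previous gradient for the direction update. It assumes an Eigen-style dynamic vector; the struct and function names are illustrative only.

#include <Eigen/Dense>
#include <functional>

using Vector = Eigen::VectorXd;

struct CgBuffers {
    Vector point;      // current iterate
    Vector direction;  // current search direction
    Vector gradient;   // gradient at point
    Vector scratch;    // trial point during the line search, previous gradient afterwards
};

// While line searching, scratch holds the trial point x0 + t * d.
double line_value(CgBuffers& b, double t, const std::function<double(const Vector&)>& f) {
    b.scratch = b.point + t * b.direction;  // no allocation once scratch is sized
    return f(b.scratch);
}

// Once the step is taken the trial point is dead, so the same buffer can take
// over the role of "previous gradient": swap exchanges storage, no allocation.
void begin_next_iteration(CgBuffers& b) {
    b.scratch.swap(b.gradient);  // scratch now holds the old gradient
    // ...recompute b.gradient at the new point, then form gamma from both...
}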
@@ -47,17 +47,16 @@ private:
// algorithm state
uint32_t n_evaluations_;
uint32_t n_iterations_;
// This object holds the current point (x0) and search direction (dir).
// TODO allow this to live here, somehow
//LineObjective<Objective, N> line_objective_;
VectorNs<N> point_;
VectorNs<N> direction_;
VectorNs<N> new_point_;
Vector point_; // where we are currently
Vector direction_; // the direction we're searching
Vector gradient_; // the gradient at point_
// This variable sees double duty. When we're line searching, it holds the candidate points.
// Between evaluating the function at the line minimum and computing the next search direction,
// it stores the gradient at the point of the previous iteration.
Vector new_point_or_last_gradient_;
Scalar alpha_; // stores the most recent jump distance, which we use as a guess for the next one
Scalar value_;
Scalar last_value_;
Vector gradient_;
Vector last_gradient_;
// algorithm constants
static constexpr Scalar tiny_ = std::numeric_limits<Scalar>::epsilon();
@@ -85,13 +84,14 @@ VectorNs<N> const& ConjugateGradientDescent<N>::optimize(
n_evaluations_ = 0;
n_iterations_ = 0;
point_ = initial_point;
alpha_ = -1;
direction_.resize(point_.size());
gradient_.resize(point_.size());
last_gradient_.resize(point_.size());
new_point_or_last_gradient_.resize(point_.size());
alpha_ = -1;
auto const line_objective = [&](Scalar t, Scalar& value) {
new_point_ = point_ + t * direction_;
objective(new_point_, value);
new_point_or_last_gradient_ = point_ + t * direction_;
objective(new_point_or_last_gradient_, value);
};
BracketFinder bracket;
Brent line_minimizer;
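The lambda above is the whole interface the 1-D minimizer sees: it turns the N-dimensional objective into a function of the step length t, and every trial point it builds lands in the shared buffer. A self-contained illustration with a throwaway quadratic objective (hypothetical, not the project's Objective type):

#include <Eigen/Dense>
#include <iostream>

int main() {
    using Vector = Eigen::VectorXd;
    const Vector target = Vector::Constant(3, 1.0);
    auto objective = [&](const Vector& x) { return 0.5 * (x - target).squaredNorm(); };

    Vector point = Vector::Zero(3);      // x0
    Vector direction = Vector::Ones(3);  // search direction
    Vector trial(3);                     // pre-sized shared buffer

    auto line_objective = [&](double t, double& value) {
        trial = point + t * direction;   // same size, so Eigen reuses trial's storage
        value = objective(trial);
    };

    double value;
    for (double t : {0.0, 0.5, 1.0, 1.5}) {  // a bracketing/Brent routine would probe t like this
        line_objective(t, value);
        std::cout << "t = " << t << ", f = " << value << "\n";
    }
}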
@@ -108,19 +108,23 @@ VectorNs<N> const& ConjugateGradientDescent<N>::optimize(
}
while (true) {
// Find the minimum along this direction.
// Find the minimum along this direction. The next two lines are the only ones that use
// new_point_or_last_gradient_ as new_point_ (happens inside the lambda line_objective).
bracket.bracket(line_objective, Scalar(0), alpha_);
alpha_ = line_minimizer.optimize(line_objective, bracket).point;
n_evaluations_ += bracket.n_evaluations();
n_evaluations_ += line_minimizer.n_evaluations();
// Note: at some point new_point_ already had this new value, but right now there's no
// guarantee that it's the current value.
point_ += alpha_ * direction_;
// Can we re-use evaluation here? Would need to rewrite line search to have guarantees on
// which point it evaluated last. Probably best to do this after I've written dbrent.
last_value_ = value_;
gradient_.swap(last_gradient_);
// Basically this line means new_point_or_last_gradient_ = std::move(gradient_), but I want
// to be clear that we're not releasing any memory here. For the rest of the loop,
// new_point_or_last_gradient_ stores the last gradient.
new_point_or_last_gradient_.swap(gradient_);
// Again, we evaluated the objective here already during the line search, but we probably
// threw that data out.
// Note: at some point new_point_or_last_gradient_ already had this new value, but right now
// there's no guarantee that it's the current value.
point_ += alpha_ * direction_;
objective(point_, value_, gradient_);
++n_iterations_;
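The swap is what keeps the count at four buffers instead of five: for dynamic Eigen vectors (and std::vector), swap just exchanges the owned storage in O(1), so no scalars are copied and nothing is freed. A quick standalone check of that behaviour, assuming an Eigen-style vector:

#include <Eigen/Dense>
#include <iostream>

int main() {
    Eigen::VectorXd gradient = Eigen::VectorXd::LinSpaced(4, 1.0, 4.0);
    Eigen::VectorXd scratch = Eigen::VectorXd::Zero(4);  // was the line-search trial point

    const double* old_scratch_storage = scratch.data();
    scratch.swap(gradient);  // scratch now holds the last gradient's values

    // The two vectors exchanged their heap buffers; nothing was allocated or copied.
    std::cout << std::boolalpha
              << (gradient.data() == old_scratch_storage) << "\n";  // prints: true
}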
@@ -151,10 +155,11 @@ VectorNs<N> const& ConjugateGradientDescent<N>::optimize(
return point_;
}
// Choose the next search direction. It is conjugate to all prior directions.
// I view this as the start of the next iteration.
Scalar const gamma =
gradient_.dot(gradient_ - last_gradient_) / last_gradient_.dot(last_gradient_);
// Choose the next search direction. It is conjugate to all prior directions. Recall that
// new_point_or_last_gradient_ currently stores the last gradient. I view this as the start
// of the next iteration.
Scalar const gamma = gradient_.dot(gradient_ - new_point_or_last_gradient_)
/ new_point_or_last_gradient_.dot(new_point_or_last_gradient_);
direction_ = gradient_ + gamma * direction_;
}
}
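For reference, the coefficient computed in this last hunk is the Polak-Ribière formula. Writing g_k for the current gradient (gradient_) and g_{k-1} for the previous one (now living in new_point_or_last_gradient_), the update as written in the code is

\gamma = \frac{g_k^\top \,(g_k - g_{k-1})}{g_{k-1}^\top\, g_{k-1}},
\qquad
d \leftarrow g_k + \gamma\, d .

With exact line searches on a quadratic, the g_k^T g_{k-1} term vanishes and this reduces to the Fletcher-Reeves coefficient g_k^T g_k / (g_{k-1}^T g_{k-1}).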