diff --git a/optimization/optimizers/conjugate_gradient_descent/conjugate_gradient_descent.h b/optimization/optimizers/conjugate_gradient_descent/conjugate_gradient_descent.h
index 9393909896046c26eba5b52c37a1e1bf6f06816d..c9d0d38bcf5f63fe33e419ce272ca5d32f48b6e2 100644
--- a/optimization/optimizers/conjugate_gradient_descent/conjugate_gradient_descent.h
+++ b/optimization/optimizers/conjugate_gradient_descent/conjugate_gradient_descent.h
@@ -47,17 +47,16 @@ private:
     // algorithm state
     uint32_t n_evaluations_;
     uint32_t n_iterations_;
-    // This object holds the current point (x0) and search direction (dir).
-    // TODO allow this to live here, somehow
-    //LineObjective<Objective, N> line_objective_;
-    VectorNs<N> point_;
-    VectorNs<N> direction_;
-    VectorNs<N> new_point_;
+    Vector point_;     // where we are currently
+    Vector direction_; // the direction we're searching
+    Vector gradient_;  // the gradient at point_
+    // This variable sees double duty. When we're line searching, it holds the candidate points.
+    // Between evaluating the function at the line minimum and computing the next search direction,
+    // it stores the gradient at the point of the previous iteration.
+    Vector new_point_or_last_gradient_;
     Scalar alpha_; // stores the most recent jump distance, which we use as a guess for the next one
     Scalar value_;
     Scalar last_value_;
-    Vector gradient_;
-    Vector last_gradient_;
 
     // algorithm constants
     static constexpr Scalar tiny_ = std::numeric_limits<Scalar>::epsilon();
@@ -85,13 +84,14 @@ VectorNs<N> const& ConjugateGradientDescent<N>::optimize(
     n_evaluations_ = 0;
     n_iterations_ = 0;
     point_ = initial_point;
-    alpha_ = -1;
+    direction_.resize(point_.size());
     gradient_.resize(point_.size());
-    last_gradient_.resize(point_.size());
+    new_point_or_last_gradient_.resize(point_.size());
+    alpha_ = -1;
 
     auto const line_objective = [&](Scalar t, Scalar& value) {
-        new_point_ = point_ + t * direction_;
-        objective(new_point_, value);
+        new_point_or_last_gradient_ = point_ + t * direction_;
+        objective(new_point_or_last_gradient_, value);
     };
     BracketFinder bracket;
     Brent line_minimizer;
@@ -108,19 +108,23 @@ VectorNs<N> const& ConjugateGradientDescent<N>::optimize(
     }
 
     while (true) {
-        // Find the minimum along this direction.
+        // Find the minimum along this direction. The next two lines are the only ones that use
+        // new_point_or_last_gradient_ as new_point_ (happens inside the lambda line_objective).
         bracket.bracket(line_objective, Scalar(0), alpha_);
         alpha_ = line_minimizer.optimize(line_objective, bracket).point;
         n_evaluations_ += bracket.n_evaluations();
         n_evaluations_ += line_minimizer.n_evaluations();
 
-        // Note: at some point new_point_ already had this new value, but right now there's no
-        // guarantee that it's the current value.
-        point_ += alpha_ * direction_;
-        // Can we re-use evaluation here? Would need to rewrite line search to have guarantees on
-        // which point it evaluated last. Probably best to do this after I've written dbrent.
         last_value_ = value_;
-        gradient_.swap(last_gradient_);
+        // Basically this line means new_point_or_last_gradient_ = std::move(gradient_), but I want
+        // to be clear that we're not releasing any memory here. For the rest of the loop,
+        // new_point_or_last_gradient_ stores the last gradient.
+        new_point_or_last_gradient_.swap(gradient_);
 
+        // Again, we evaluated the objective here already during the line search, but we probably
+        // threw that data out.
+        // Note: at some point new_point_or_last_gradient_ already had this new value, but right now
+        // there's no guarantee that it's the current value.
+        point_ += alpha_ * direction_;
         objective(point_, value_, gradient_);
         ++n_iterations_;
 
@@ -151,10 +155,11 @@ VectorNs<N> const& ConjugateGradientDescent<N>::optimize(
             return point_;
         }
 
-        // Choose the next search direction. It is conjugate to all prior directions.
-        // I view this as the start of the next iteration.
-        Scalar const gamma =
-            gradient_.dot(gradient_ - last_gradient_) / last_gradient_.dot(last_gradient_);
+        // Choose the next search direction. It is conjugate to all prior directions. Recall that
+        // new_point_or_last_gradient_ currently stores the last gradient. I view this as the start
+        // of the next iteration.
+        Scalar const gamma = gradient_.dot(gradient_ - new_point_or_last_gradient_)
+            / new_point_or_last_gradient_.dot(new_point_or_last_gradient_);
         direction_ = gradient_ + gamma * direction_;
     }
 }
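
For reference, the gamma computed in the last hunk is the Polak-Ribiere coefficient, gamma = g_k . (g_k - g_{k-1}) / (g_{k-1} . g_{k-1}), with the previous gradient read out of new_point_or_last_gradient_. Below is a minimal sketch of that step in isolation, assuming Eigen dense vectors; the free function next_direction is illustrative only and not part of this header.

    #include <Eigen/Dense>

    // Polak-Ribiere direction update. The arguments mirror the members in the diff
    // (gradient_, new_point_or_last_gradient_, direction_), but this helper is hypothetical,
    // shown only to isolate the formula from the buffer-reuse bookkeeping in the patch.
    Eigen::VectorXd next_direction(Eigen::VectorXd const& gradient,      // gradient at the new point
                                   Eigen::VectorXd const& last_gradient, // gradient at the previous point
                                   Eigen::VectorXd const& direction) {   // previous search direction
        // gamma = g_k . (g_k - g_{k-1}) / (g_{k-1} . g_{k-1})
        double const gamma =
            gradient.dot(gradient - last_gradient) / last_gradient.dot(last_gradient);
        return gradient + gamma * direction;
    }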