From c2df5e1612eab5d1d5d071b411ff27e797a98a21 Mon Sep 17 00:00:00 2001
From: Erik Strand <erik.strand@cba.mit.edu>
Date: Fri, 17 Apr 2020 10:29:48 -0400
Subject: [PATCH] Make objectives compatible with lambda functions

Now everything uses operator() instead of eval.
---
 optimization/objectives/paraboloid.h             |  4 ++--
 optimization/objectives/rosenbrock.h             |  4 ++--
 optimization/optimizers/cma_es/cma_es.h          |  2 +-
 .../conjugate_gradient_descent.h                 |  4 ++--
 .../gradient_descent/gradient_descent.h          |  2 +-
 optimization/optimizers/line_search/bracket.h    | 16 ++++++++--------
 optimization/optimizers/line_search/brent.h      |  2 +-
 .../optimizers/line_search/golden_section.h      |  8 ++++----
 .../optimizers/line_search/line_objective.h      |  8 ++++----
 .../optimizers/nelder_mead/nelder_mead.h         |  6 +++---
 test/optimizers/line_search/bracket.cpp          |  2 +-
 test/optimizers/line_search/brent.cpp            |  2 +-
 test/optimizers/line_search/golden_section.cpp   |  2 +-
 13 files changed, 31 insertions(+), 31 deletions(-)

diff --git a/optimization/objectives/paraboloid.h b/optimization/objectives/paraboloid.h
index 7352f60..c10e1db 100644
--- a/optimization/objectives/paraboloid.h
+++ b/optimization/objectives/paraboloid.h
@@ -20,14 +20,14 @@ public:
     uint32_t dim() const { return dim_; }
     uint32_t& dim() { return dim_; }
 
-    void eval(Input const& x, Scalar& value) const {
+    void operator()(Input const& x, Scalar& value) const {
         value = 0;
         for (uint32_t d = 0; d < dim_; ++d) {
             value += x[d] * x[d];
         }
     }
 
-    void eval(Input const& x, Scalar& value, Gradient& gradient) const {
+    void operator()(Input const& x, Scalar& value, Gradient& gradient) const {
         value = 0;
         for (uint32_t d = 0; d < dim_; ++d) {
             value += x[d] * x[d];
diff --git a/optimization/objectives/rosenbrock.h b/optimization/objectives/rosenbrock.h
index 7c5aaae..607436d 100644
--- a/optimization/objectives/rosenbrock.h
+++ b/optimization/objectives/rosenbrock.h
@@ -16,7 +16,7 @@ public:
     uint32_t dim() const { return dim_; }
     uint32_t& dim() { return dim_; }
 
-    void eval(VectorNs<N> const& x, Scalar& value) {
+    void operator()(VectorNs<N> const& x, Scalar& value) {
         value = Scalar(0);
         for (uint32_t i = 1; i < dim_; ++i) {
             Scalar const x_squared = x[i - 1] * x[i - 1];
@@ -26,7 +26,7 @@ public:
         }
     }
 
-    void eval(VectorNs<N> const& x, Scalar& value, VectorNs<N>& gradient) {
+    void operator()(VectorNs<N> const& x, Scalar& value, VectorNs<N>& gradient) {
         value = Scalar(0);
         gradient.resize(dim_);
         gradient.setZero();
diff --git a/optimization/optimizers/cma_es/cma_es.h b/optimization/optimizers/cma_es/cma_es.h
index dbf8f12..7cdfa00 100644
--- a/optimization/optimizers/cma_es/cma_es.h
+++ b/optimization/optimizers/cma_es/cma_es.h
@@ -171,7 +171,7 @@ Eigen::Map<const Eigen::MatrixXd> CmaEs::optimize(
             auto point_map = Eigen::Map<const Eigen::MatrixXd>(points_[i], dim_, 1);
             // TODO: make objectives handle Maps so we don't have to copy the data.
             point_vec = point_map;
-            objective.eval(point_vec, values_[i]);
+            objective(point_vec, values_[i]);
         }
 
         log.push_back(points_, values_, dim_, pop_size_);
diff --git a/optimization/optimizers/conjugate_gradient_descent/conjugate_gradient_descent.h b/optimization/optimizers/conjugate_gradient_descent/conjugate_gradient_descent.h
index 13a5d5c..1dcb6b9 100644
--- a/optimization/optimizers/conjugate_gradient_descent/conjugate_gradient_descent.h
+++ b/optimization/optimizers/conjugate_gradient_descent/conjugate_gradient_descent.h
@@ -94,7 +94,7 @@ VectorNs<N> const& ConjugateGradientDescent<N>::optimize(
     BracketFinder bracket;
     Brent line_minimizer;
 
-    objective.eval(point_, value_, gradient_);
+    objective(point_, value_, gradient_);
     ++n_evaluations_;
     direction_ = gradient_;
 
@@ -119,7 +119,7 @@ VectorNs<N> const& ConjugateGradientDescent<N>::optimize(
         // which point it evaluated last. Probably best to do this after I've written dbrent.
         last_value_ = value_;
         gradient_.swap(last_gradient_);
-        objective.eval(point_, value_, gradient_);
+        objective(point_, value_, gradient_);
         ++n_iterations_;
 
         log.push_back(
diff --git a/optimization/optimizers/gradient_descent/gradient_descent.h b/optimization/optimizers/gradient_descent/gradient_descent.h
index 1a92096..e8a7bf6 100644
--- a/optimization/optimizers/gradient_descent/gradient_descent.h
+++ b/optimization/optimizers/gradient_descent/gradient_descent.h
@@ -69,7 +69,7 @@ VectorNs<N> const& GradientDescent<N>::optimize(
     log.initialize(objective);
 
     while (true) {
-        objective.eval(point_, value_, gradient_);
+        objective(point_, value_, gradient_);
         ++n_evaluations_;
         log.push_back(point_, value_, gradient_);
 
diff --git a/optimization/optimizers/line_search/bracket.h b/optimization/optimizers/line_search/bracket.h
index 8d94f05..89de998 100644
--- a/optimization/optimizers/line_search/bracket.h
+++ b/optimization/optimizers/line_search/bracket.h
@@ -61,8 +61,8 @@ void BracketFinder::bracket(Objective& objective, Scalar x_1, Scalar x_2) {
     // Copy in starting values and perform first evaluations.
     x_1_ = x_1;
     x_2_ = x_2;
-    objective.eval(x_1_, y_1_);
-    objective.eval(x_2_, y_2_);
+    objective(x_1_, y_1_);
+    objective(x_2_, y_2_);
     n_evaluations_ = 2;
 
     // Ensure that y_2_ < y_1_.
@@ -75,7 +75,7 @@ void BracketFinder::bracket(Objective& objective, Scalar x_1, Scalar x_2) {
     // First try a golden ratio step.
     // From here on out, either x_1_ < x_2_ < x_3_ or x_1_ > x_2_ > x_3_.
     x_3_ = x_2_ + golden_ratio_ * (x_2_ - x_1_);
-    objective.eval(x_3_, y_3_);
+    objective(x_3_, y_3_);
     ++n_evaluations_;
 
     // Search until we have a bracket.
@@ -100,7 +100,7 @@ void BracketFinder::bracket(Objective& objective, Scalar x_1, Scalar x_2) {
 
         if ((x_new - x_2_) * (x_3_ - x_new) > Scalar(0)) {
             // If the new point between x_2_ and x_3_, try it.
-            objective.eval(x_new, y_new);
+            objective(x_new, y_new);
             ++n_evaluations_;
 
             if (y_new < y_3_) {
@@ -118,7 +118,7 @@ void BracketFinder::bracket(Objective& objective, Scalar x_1, Scalar x_2) {
             }
         } else if ((x_new - x_3_) * (x_lim - x_new) > Scalar(0)) {
             // If the new point seems to be downhill and is not super far away, try it.
-            objective.eval(x_new, y_new);
+            objective(x_new, y_new);
             ++n_evaluations_;
 
             if (y_new < y_3_) {
@@ -129,18 +129,18 @@ void BracketFinder::bracket(Objective& objective, Scalar x_1, Scalar x_2) {
                 x_new += step;
                 y_2_ = y_3_;
                 y_3_ = y_new;
-                objective.eval(x_new, y_new);
+                objective(x_new, y_new);
                 ++n_evaluations_;
             }
         } else if ((x_lim - x_3_) * (x_new - x_lim) >= Scalar(0)) {
             // If the new point is past x_lim, try x_lim next.
             x_new = x_lim;
-            objective.eval(x_new, y_new);
+            objective(x_new, y_new);
             ++n_evaluations_;
         } else {
             // If the new point seems to be too far uphill, just do a golden ratio step.
             x_new = x_3_ + golden_ratio_ * (x_3_ - x_2_);
-            objective.eval(x_new, y_new);
+            objective(x_new, y_new);
             ++n_evaluations_;
         }
 
diff --git a/optimization/optimizers/line_search/brent.h b/optimization/optimizers/line_search/brent.h
index 26f65d6..4c5b9ac 100644
--- a/optimization/optimizers/line_search/brent.h
+++ b/optimization/optimizers/line_search/brent.h
@@ -156,7 +156,7 @@ Sample<Scalar> Brent::optimize(Objective& objective, Bracket const& bracket) {
         // Take the step and evaluate the result.
         Scalar const x_new = x_1 + step;
         Scalar y_new;
-        objective.eval(x_new, y_new);
+        objective(x_new, y_new);
         ++n_evaluations_;
 
         if (y_new <= y_1) {
diff --git a/optimization/optimizers/line_search/golden_section.h b/optimization/optimizers/line_search/golden_section.h
index 17eb700..0d19417 100644
--- a/optimization/optimizers/line_search/golden_section.h
+++ b/optimization/optimizers/line_search/golden_section.h
@@ -61,14 +61,14 @@ Sample<Scalar> GoldenSection::optimize(Objective& objective, Bracket const& brac
         x_1 = bracket.x_2();
         y_1 = bracket.y_2();
         x_2 = x_1 + golden_ratio_small_ * (x_3 - x_1);
-        objective.eval(x_2, y_2);
+        objective(x_2, y_2);
     } else {
         // If the middle of the bracket is closer to the right edge, interpolate into the left half
         // of the bracket.
         x_2 = bracket.x_2();
         y_2 = bracket.y_2();
         x_1 = x_2 - golden_ratio_small_ * (x_2 - x_0);
-        objective.eval(x_1, y_1);
+        objective(x_1, y_1);
     }
     n_evaluations_ = 1;
 
@@ -83,7 +83,7 @@ Sample<Scalar> GoldenSection::optimize(Objective& objective, Bracket const& brac
             x_1 = x_2;
             x_2 = golden_ratio_big_ * x_2 + golden_ratio_small_ * x_3;
             y_1 = y_2;
-            objective.eval(x_2, y_2);
+            objective(x_2, y_2);
             ++n_evaluations_;
         } else {
             // x_2 is our new left edge; interpolate between x_0 and x_1
@@ -91,7 +91,7 @@ Sample<Scalar> GoldenSection::optimize(Objective& objective, Bracket const& brac
             x_2 = x_1;
             x_1 = golden_ratio_small_ * x_0 + golden_ratio_big_ * x_1;
             y_2 = y_1;
-            objective.eval(x_1, y_1);
+            objective(x_1, y_1);
             ++n_evaluations_;
         }
     }
diff --git a/optimization/optimizers/line_search/line_objective.h b/optimization/optimizers/line_search/line_objective.h
index e52b8f9..7a7bb55 100644
--- a/optimization/optimizers/line_search/line_objective.h
+++ b/optimization/optimizers/line_search/line_objective.h
@@ -13,14 +13,14 @@ public:
         objective_(o), x0_(x0), dir_(dir), x_(x)
     {}
 
-    void eval(Scalar t, Scalar& value) {
+    void operator()(Scalar t, Scalar& value) {
         x_ = x0_ + t * dir_;
-        objective_.eval(x_, value);
+        objective_(x_, value);
     }
 
-    void eval(Scalar t, Scalar& value, VectorNs<N>& gradient) {
+    void operator()(Scalar t, Scalar& value, VectorNs<N>& gradient) {
         x_ = x0_ + t * dir_;
-        objective_.eval(x_, value, gradient);
+        objective_(x_, value, gradient);
     }
 
 private:
diff --git a/optimization/optimizers/nelder_mead/nelder_mead.h b/optimization/optimizers/nelder_mead/nelder_mead.h
index 477065c..2f0fbcb 100644
--- a/optimization/optimizers/nelder_mead/nelder_mead.h
+++ b/optimization/optimizers/nelder_mead/nelder_mead.h
@@ -156,7 +156,7 @@ decltype(auto) NelderMead<D>::optimize(Objective& objective, MatrixDN const& sim
 
     // Evaluate the objective at all simplex vertices.
     for (uint32_t i = 0u; i < n_vertices_; ++i) {
-        objective.eval(simplex_vertices_.col(i), simplex_values_[i]);
+        objective(simplex_vertices_.col(i), simplex_values_[i]);
     }
     n_evaluations_ = n_vertices_;
 
@@ -224,7 +224,7 @@ decltype(auto) NelderMead<D>::optimize(Objective& objective, MatrixDN const& sim
             if (i != i_lowest_) {
                 simplex_vertices_.col(i) =
                     shrinking_coefficient_ * (simplex_vertices_.col(i_lowest_) + simplex_vertices_.col(i));
-                objective.eval(simplex_vertices_.col(i), simplex_values_[i]);
+                objective(simplex_vertices_.col(i), simplex_values_[i]);
             }
         }
         n_evaluations_ += dim_;
@@ -244,7 +244,7 @@ Scalar NelderMead<D>::try_new_point(Objective& objective, Scalar factor) {
 
     // Evaluate the new point.
     Scalar y_new;
-    objective.eval(x_new, y_new);
+    objective(x_new, y_new);
     ++n_evaluations_;
 
     // If the new point is an improvement, keep it.
diff --git a/test/optimizers/line_search/bracket.cpp b/test/optimizers/line_search/bracket.cpp
index b2dd1ae..6440509 100644
--- a/test/optimizers/line_search/bracket.cpp
+++ b/test/optimizers/line_search/bracket.cpp
@@ -9,7 +9,7 @@ TEST_CASE("Bracket", "[Bracket]") {
 
     SECTION("parabola") {
         struct Parabola {
-            void eval(Scalar x, Scalar& y) const { y = x * x; }
+            void operator()(Scalar x, Scalar& y) const { y = x * x; }
         };
         Parabola parabola;
 
diff --git a/test/optimizers/line_search/brent.cpp b/test/optimizers/line_search/brent.cpp
index 2cca032..0c3e5f6 100644
--- a/test/optimizers/line_search/brent.cpp
+++ b/test/optimizers/line_search/brent.cpp
@@ -13,7 +13,7 @@ TEST_CASE("Brent", "[Brent]") {
 
     SECTION("parabola") {
         struct Parabola {
-            void eval(Scalar x, Scalar& y) const { y = x * x; }
+            void operator()(Scalar x, Scalar& y) const { y = x * x; }
         };
         Parabola parabola;
 
diff --git a/test/optimizers/line_search/golden_section.cpp b/test/optimizers/line_search/golden_section.cpp
index 875c659..7f81325 100644
--- a/test/optimizers/line_search/golden_section.cpp
+++ b/test/optimizers/line_search/golden_section.cpp
@@ -13,7 +13,7 @@ TEST_CASE("GoldenSection", "[GoldenSection]") {
 
     SECTION("parabola") {
         struct Parabola {
-            void eval(Scalar x, Scalar& y) const { y = x * x; }
+            void operator()(Scalar x, Scalar& y) const { y = x * x; }
         };
         Parabola parabola;
 
-- 
GitLab