Skip to content
Snippets Groups Projects
Commit 434cc3c3 authored by Erik Strand's avatar Erik Strand
Browse files

Implement Brent's method (without derivatives)

parent d39ebb29
Branches
Tags
No related merge requests found
#ifndef OPTIMIZATION_OPTIMIZERS_LINE_SEARCH_BRENT_H
#define OPTIMIZATION_OPTIMIZERS_LINE_SEARCH_BRENT_H
#include "bracket.h"
#include "objectives/samples.h"
#include <cmath>
#include <limits>
#include <iostream>
namespace optimization {
//--------------------------------------------------------------------------------------------------
class Brent {
public:
// Note:
// - the default absolute tolerance is equivalent to no absolute termination condition
// - the default relative tolerance is suitable for double precision floats
Brent(Scalar abs_tol = -1, Scalar rel_tol = 3e-8, uint32_t me = 100)
: abs_tol_(abs_tol), rel_tol_(rel_tol), max_evaluations_(me)
{}
// Brent's method evaluates the function exactly once per iteration.
uint32_t n_iterations() const { return n_evaluations_; }
uint32_t n_evaluations() const { return n_evaluations_; }
template <typename Objective>
Sample<Scalar> optimize(Objective& objective, Bracket const& bracket);
private:
uint32_t n_evaluations_;
// Goal for absolute width of the bracket. Put a negative number if you want no abs tol.
Scalar abs_tol_;
// Goal for width of bracket relative to central value. Should always have this (and we use its
// absolute value, so setting it negative won't help).
Scalar rel_tol_;
uint32_t max_evaluations_;
static constexpr Scalar golden_ratio_small_ = 0.3819660;
static constexpr Scalar tiny_ = std::numeric_limits<Scalar>::epsilon() * 1e-3;
};
//..................................................................................................
template <typename Objective>
Sample<Scalar> Brent::optimize(Objective& objective, Bracket const& bracket) {
n_evaluations_ = 0;
// These two points define our bracket. Invariant: a < b.
Scalar a, b;
if (bracket.x_1() < bracket.x_3()) {
a = bracket.x_1();
b = bracket.x_3();
} else {
a = bracket.x_3();
b = bracket.x_1();
}
// These are the points and values of the three best points found so far. Invariants:
// a <= x_1 <= b
// a <= x_2 <= b
// a <= x_3 <= b
// y_1 <= y_2 <= y_3.
Scalar x_1, x_2, x_3;
Scalar y_1, y_2, y_3;
x_1 = bracket.x_2();
y_1 = bracket.y_2();
if (bracket.y_1() <= bracket.y_3()) {
x_2 = bracket.x_1();
y_2 = bracket.y_1();
x_3 = bracket.x_3();
y_3 = bracket.y_3();
} else {
x_2 = bracket.x_3();
y_2 = bracket.y_3();
x_3 = bracket.x_1();
y_3 = bracket.y_1();
}
// This variable is used to store the next step (as a displacement from x_1).
Scalar step = 0;
// When we take parabolic steps, this variable records the last step size. When we take golden
// section steps, this variable records the size of the section of the bracket that we stepped
// into (i.e. what was the larger half of the bracket).
Scalar prev_step = std::abs(x_3 - x_2);
while (true) {
// Check the absolute termination condition.
if (b - a <= abs_tol_) {
return Sample<Scalar>(x_1, y_1);
}
// The midpoint of the current bracket.
Scalar const midpoint = 0.5 * (a + b);
// Note that tol_1 and tol_2 are non-negative.
Scalar const tol_1 = std::abs(rel_tol_ * x_1) + tiny_;
Scalar const tol_2 = 2.0 * tol_1;
// Check the relative termination condition.
if (std::abs(x_1 - midpoint) <= (tol_2 - 0.5 * (b - a))) {
return Sample<Scalar>(x_1, y_1);
}
// Try a parabolic fit using x_1, x_2, and x_3.
// The minimum of the parabola is at x_1 + numerator / denominator.
// Note that the denominator is always positive (we put the sign in the numerator).
Scalar const tmp_1 = (x_1 - x_2) * (y_1 - y_3);
Scalar const tmp_2 = (x_1 - x_3) * (y_1 - y_2);
Scalar const sgn = (tmp_2 >= tmp_1) ? 1 : -1;
Scalar const numerator = -sgn * ((x_1 - x_3) * tmp_2 - (x_1 - x_2) * tmp_1);
Scalar const denominator = sgn * 2.0 * (tmp_2 - tmp_1);
// For us to use this point, we require that the step size is sufficiently small, and
// that the point is within our bracket. The reason we've kept the numerator and
// denominator separate, and made the latter positive, is so we can write these tests in
// forms that work even when the denominator is zero. We end up with three conditions:
//
// 1. The proposed step size is less than half the second to last step size (prev_step).
// | numerator / denominator | < 0.5 * prev_step
// ==> | numerator | < 0.5 * | denominator * prev_step |
//
// 2. The parabolic minimum must be to the right of a
// x_1 + numerator / denominator > a
// ==> x_1 numerator > denominator * (a - x_1)
//
// 3. The parabolic minimum must be to the left of b
// x_1 + numerator / denominator < b
// ==> x_1 numerator < denominator * (b - x_1)
if (
std::abs(numerator) < 0.5 * std::abs(denominator * prev_step)
&& numerator > denominator * (a - x_1)
&& numerator < denominator * (b - x_1)
) {
// Take the parabolic step.
prev_step = step;
step = numerator / denominator;
Scalar const x_new = x_1 + step;
// Make sure we're not stepping too close to a or b.
if (x_new - a < tol_2 || b - x_new < tol_2) {
step = (midpoint - x_1 >= 0) ? tol_1 : -tol_1;
}
} else {
// Take a golden section step.
// This "resets" prev_step, to be the size of the half of the bracket we step into.
prev_step = (x_1 >= midpoint) ? (a - x_1) : (b - x_1);
step = golden_ratio_small_ * prev_step;
}
// Ensure the step isn't too small.
if (std::abs(step) < tol_1) {
step = (step >= 0) ? tol_1 : -tol_1;
}
// Take the step and evaluate the result.
Scalar const x_new = x_1 + step;
Scalar y_new;
objective.eval(x_new, y_new);
++n_evaluations_;
if (y_new <= y_1) {
// y_new is the best point we've seen so far, so x_1 becomes a bracket bound.
if (x_new >= x_1) {
a = x_1;
} else {
b = x_1;
}
// Update our three points.
x_3 = x_2;
x_2 = x_1;
x_1 = x_new;
y_3 = y_2;
y_2 = y_1;
y_1 = y_new;
} else {
// x_1 is still the best point, so x_new becomes a bracket bound.
if (x_new < x_1) {
a = x_new;
} else {
b = x_new;
}
// Update our three points.
if (y_new <= y_2) {
x_3 = x_2;
x_2 = x_new;
y_3 = y_2;
y_2 = y_new;
} else if (y_new <= y_3) {
x_3 = x_new;
y_3 = y_new;
}
}
// Check failsafe termination condition.
if (n_evaluations_ > max_evaluations_) {
return Sample<Scalar>(x_1, y_1);
}
}
}
}
#endif
add_executable(test
main.cpp
optimizers/line_search/bracket.cpp
optimizers/line_search/brent.cpp
optimizers/line_search/golden_section.cpp
)
target_link_libraries(test optimization_lib catch2)
#include "catch.hpp"
#include "optimizers/line_search/brent.h"
using namespace optimization;
//--------------------------------------------------------------------------------------------------
TEST_CASE("Brent", "[Brent]") {
// The 1e-8 sets the absolute tolerance on the width of the bracket.
// There's still a default relative tolerance in effect as well, but these tests don't hit it.
Brent brent(1e-8);
Bracket bracket;
Sample<Scalar> result;
SECTION("parabola") {
struct Parabola {
void eval(Scalar x, Scalar& y) const { y = x * x; }
};
Parabola parabola;
bracket = Bracket(-2, -1, 2, 4, 1, 4);
result = brent.optimize(parabola, bracket);
// The parabola is flat here so y accuracy had better exceed x accuracy.
REQUIRE(std::abs(result.point) < 1e-8);
REQUIRE(std::abs(result.value) < 1e-16);
// Just a sanity check.
REQUIRE(brent.n_evaluations() < 100);
bracket = Bracket(50, 10, -100, 2500, 100, 10000);
result = brent.optimize(parabola, bracket);
// The parabola is flat here so y accuracy had better exceed x accuracy.
REQUIRE(std::abs(result.point) < 1e-8);
REQUIRE(std::abs(result.value) < 1e-16);
// Just a sanity check.
REQUIRE(brent.n_evaluations() < 100);
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment