Skip to content
Snippets Groups Projects
Commit 6ce152ca authored by Erik Strand's avatar Erik Strand
Browse files

Add max eval termination condition to cgd

parent 78b91604
Branches
No related tags found
No related merge requests found
......@@ -15,8 +15,16 @@ class ConjugateGradientDescent : public ConjugateGradientDescentLog<N> {
public:
using Vector = VectorNs<N>;
ConjugateGradientDescent(Scalar gt = 1e-8, Scalar pt = 1e-8, uint32_t mi = 1000)
: gradient_threshold_(gt), progress_threshold_(pt), max_iterations_(mi)
ConjugateGradientDescent(
Scalar gt = 1e-8,
Scalar pt = 1e-8,
uint32_t me = 10000,
uint32_t mi = 1000
) :
gradient_threshold_(gt),
progress_threshold_(pt),
max_evaluations_(me),
max_iterations_(mi)
{}
uint32_t n_evaluations() const { return n_evaluations_; }
......@@ -32,6 +40,7 @@ private:
// termination parameters
Scalar gradient_threshold_;
Scalar progress_threshold_;
uint32_t max_evaluations_;
uint32_t max_iterations_;
static constexpr Scalar tiny_ = std::numeric_limits<Scalar>::epsilon();
......@@ -106,6 +115,10 @@ VectorNs<N> ConjugateGradientDescent<N>::optimize(
std::cout << "reached progress threshold\n";
return line_objective.x0();
}
if (n_evaluations_ > max_evaluations_) {
std::cout << "reached max evaluations\n";
return line_objective.x0();
}
if (n_iterations_ > max_iterations_) {
std::cout << "reached max iterations\n";
return line_objective.x0();
......
......@@ -21,6 +21,7 @@ int main(int const argc, char const** argv) {
uint32_t dim = 2;
Scalar gradient_threshold = 1e-8;
Scalar progress_threshold = 1e-8;
uint32_t max_evaluations = 10000;
uint32_t max_iterations = 1000;
Scalar x0 = -1;
Scalar y0 = 2;
......@@ -31,7 +32,8 @@ int main(int const argc, char const** argv) {
clara::Opt(dim, "dim")["-d"]["--dim"]("Dimension of the search space") |
clara::Opt(gradient_threshold, "gradient_threshold")["-g"]["--gradient"]("Return if the gradient norm is this small") |
clara::Opt(progress_threshold, "progress_threshold")["-p"]["--progress"]("Return if the difference between two consecutive values is this small") |
clara::Opt(max_iterations, "max_iterations")["-n"]["--max-iterations"]("Maximum number of line minimization cycles") |
clara::Opt(max_evaluations, "max_evaluations")["-n"]["--max-evaluations"]("Maximum number of function evaluations") |
clara::Opt(max_iterations, "max_iterations")["-i"]["--max-iterations"]("Maximum number of line minimization cycles") |
clara::Opt(x0, "x0")["-x"]["--x0"]("X coordinate of initial point") |
clara::Opt(y0, "y0")["-y"]["--y0"]("Y coordinate of initial point");
auto const result = cli.parse(clara::Args(argc, argv));
......@@ -54,7 +56,12 @@ int main(int const argc, char const** argv) {
Objective objective;
objective.dim() = dim;
ConjugateGradientDescent<-1> optimizer(gradient_threshold, progress_threshold, max_iterations);
ConjugateGradientDescent<-1> optimizer(
gradient_threshold,
progress_threshold,
max_evaluations,
max_iterations
);
VectorXs minimum = optimizer.optimize(objective, initial_point);
std::cout << "n iterations: " << optimizer.n_iterations() << '\n';
std::cout << "n evaluations: " << optimizer.n_evaluations() << '\n';
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment