Skip to content
Snippets Groups Projects
Commit 78507241 authored by Erik Strand's avatar Erik Strand
Browse files

Generalize dimension in main CGD app

parent a23a6ef4
No related branches found
No related tags found
No related merge requests found
...@@ -17,7 +17,6 @@ public: ...@@ -17,7 +17,6 @@ public:
uint32_t& dim() { return dim_; } uint32_t& dim() { return dim_; }
void eval(VectorNs<N> const& x, Scalar& value) { void eval(VectorNs<N> const& x, Scalar& value) {
dim_ = x.size();
value = Scalar(0); value = Scalar(0);
for (uint32_t i = 1; i < dim_; ++i) { for (uint32_t i = 1; i < dim_; ++i) {
Scalar const x_squared = x[i - 1] * x[i - 1]; Scalar const x_squared = x[i - 1] * x[i - 1];
...@@ -28,7 +27,6 @@ public: ...@@ -28,7 +27,6 @@ public:
} }
void eval(VectorNs<N> const& x, Scalar& value, VectorNs<N>& gradient) { void eval(VectorNs<N> const& x, Scalar& value, VectorNs<N>& gradient) {
dim_ = x.size();
value = Scalar(0); value = Scalar(0);
gradient.resize(dim_); gradient.resize(dim_);
gradient.setZero(); gradient.setZero();
......
...@@ -57,8 +57,8 @@ VectorNs<N> ConjugateGradientDescent<N>::optimize( ...@@ -57,8 +57,8 @@ VectorNs<N> ConjugateGradientDescent<N>::optimize(
Scalar alpha = -1; Scalar alpha = -1;
Scalar value, last_value; Scalar value, last_value;
Vector gradient, last_gradient; Vector gradient, last_gradient;
gradient.resize(objective.dim()); gradient.resize(initial_point.size());
last_gradient.resize(objective.dim()); last_gradient.resize(initial_point.size());
objective.eval(line_objective.x0(), value, gradient); objective.eval(line_objective.x0(), value, gradient);
++n_evaluations_; ++n_evaluations_;
......
...@@ -67,8 +67,8 @@ void ConjugateGradientDescentLog<N>::initialize( ...@@ -67,8 +67,8 @@ void ConjugateGradientDescentLog<N>::initialize(
) { ) {
objective_name = Objective::name; objective_name = Objective::name;
states.emplace_back(); states.emplace_back();
states.back().direction.resize(N); states.back().direction.resize(point.size());
states.back().direction.setZero(N); states.back().direction.setZero();
states.back().point = point; states.back().point = point;
states.back().gradient = gradient; states.back().gradient = gradient;
states.back().value = value; states.back().value = value;
......
...@@ -14,6 +14,7 @@ using json = nlohmann::json; ...@@ -14,6 +14,7 @@ using json = nlohmann::json;
int main(int const argc, char const** argv) { int main(int const argc, char const** argv) {
std::string log_file_path; std::string log_file_path;
std::string vis_file_path; std::string vis_file_path;
uint32_t dim = 2;
Scalar gradient_threshold = 1e-8; Scalar gradient_threshold = 1e-8;
Scalar progress_threshold = 1e-8; Scalar progress_threshold = 1e-8;
uint32_t max_iterations = 1000; uint32_t max_iterations = 1000;
...@@ -23,6 +24,7 @@ int main(int const argc, char const** argv) { ...@@ -23,6 +24,7 @@ int main(int const argc, char const** argv) {
auto const cli = auto const cli =
clara::Arg(log_file_path, "log_file_path")("Location of the optimization log") | clara::Arg(log_file_path, "log_file_path")("Location of the optimization log") |
clara::Arg(vis_file_path, "vis_file_path")("Location of the visualization file") | clara::Arg(vis_file_path, "vis_file_path")("Location of the visualization file") |
clara::Opt(dim, "dim")["-d"]["--dim"]("Dimension of the search space") |
clara::Opt(gradient_threshold, "gradient_threshold")["-g"]["--gradient"]("Return if the gradient norm is this small") | clara::Opt(gradient_threshold, "gradient_threshold")["-g"]["--gradient"]("Return if the gradient norm is this small") |
clara::Opt(progress_threshold, "progress_threshold")["-p"]["--progress"]("Return if the difference between two consecutive values is this small") | clara::Opt(progress_threshold, "progress_threshold")["-p"]["--progress"]("Return if the difference between two consecutive values is this small") |
clara::Opt(max_iterations, "max_iterations")["-n"]["--max-iterations"]("Maximum number of line minimization cycles") | clara::Opt(max_iterations, "max_iterations")["-n"]["--max-iterations"]("Maximum number of line minimization cycles") |
...@@ -34,14 +36,21 @@ int main(int const argc, char const** argv) { ...@@ -34,14 +36,21 @@ int main(int const argc, char const** argv) {
exit(1); exit(1);
} }
Vector2<Scalar> initial_point = Vector2<Scalar>(x0, y0); VectorXs initial_point;
initial_point.resize(dim);
initial_point[0] = x0;
initial_point[1] = y0;
for (uint32_t i = 2; i < dim; ++i) {
initial_point[i] = -1;
}
//using Objective = Paraboloid<Vector2<Scalar>>; //using Objective = Paraboloid<Vector2<Scalar>>;
//Objective objective(dim); //Objective objective(dim);
using Objective = Rosenbrock<2>; using Objective = Rosenbrock<-1>;
Objective objective; Objective objective;
ConjugateGradientDescent<2> optimizer(gradient_threshold, progress_threshold, max_iterations); objective.dim() = dim;
VectorNs<2> minimum = optimizer.optimize(objective, initial_point); ConjugateGradientDescent<-1> optimizer(gradient_threshold, progress_threshold, max_iterations);
VectorXs minimum = optimizer.optimize(objective, initial_point);
std::cout << "n iterations: " << optimizer.n_iterations() << '\n'; std::cout << "n iterations: " << optimizer.n_iterations() << '\n';
std::cout << "n evaluations: " << optimizer.n_evaluations() << '\n'; std::cout << "n evaluations: " << optimizer.n_evaluations() << '\n';
std::cout << "final point: " << minimum << '\n'; std::cout << "final point: " << minimum << '\n';
...@@ -53,7 +62,7 @@ int main(int const argc, char const** argv) { ...@@ -53,7 +62,7 @@ int main(int const argc, char const** argv) {
} }
if (!vis_file_path.empty()) { if (!vis_file_path.empty()) {
json data = ConjugateGradientDescentVis<2>{optimizer}; json data = ConjugateGradientDescentVis<-1>{optimizer};
std::ofstream vis_file(vis_file_path); std::ofstream vis_file(vis_file_path);
vis_file << data.dump(4) << '\n'; vis_file << data.dump(4) << '\n';
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment