#include <cstdlib>
#include <iostream>
#include "n_body.h"

#include "optimizers/nelder_mead/nelder_mead.h"
#include "optimizers/conjugate_gradient_descent/conjugate_gradient_descent.h"
#include "objectives/rosenbrock.h"

//--------------------------------------------------------------------------------------------------
// Global grapher plus free-function trampolines. GLUT registers callbacks as plain function
// pointers (see glutDisplayFunc / glutIdleFunc / glutMouseFunc in main), so they forward to the
// global object instead of binding to its member functions.
GlutGrapher grapher;
void display() { grapher.display(); }
void idle() { grapher.idle(); }

// Any mouse event (press or release, any button) terminates the program.
// std::exit is declared in <cstdlib>, which is included explicitly rather than relied on
// transitively through other headers.
void mouse(int /*button*/, int /*state*/, int /*x*/, int /*y*/) { std::exit(0); }


//--------------------------------------------------------------------------------------------------
struct LaunchObjective {
    LaunchObjective() {
        masses_.resize(n_bodies_);
        masses_[0] = 1e1;  // planet
        masses_[1] = 1e-1; // satellite

        initial_conditions_.resize(2 * dim_ * n_bodies_);
        initial_conditions_.segment<dim_>(0 * dim_) = Vector2<Scalar>(0, 0);  // position of sun
        initial_conditions_.segment<dim_>(1 * dim_) = Vector2<Scalar>(-1, 0); // position of earth
        initial_conditions_.segment<dim_>(2 * dim_) = Vector2<Scalar>(0, 0);  // velocity of sun
        initial_conditions_.segment<dim_>(3 * dim_) = Vector2<Scalar>(0, 0);  // velocity of earth
    }

    NBody& nbody() { return nbody_; }

    void prep_initial_conditions(VectorX<Scalar> const& velocity) {
        initial_conditions_.segment<dim_>(3 * dim_) = velocity;
        nbody_.initialize(dim_, masses_, initial_conditions_, delta_t_);
    }

    void operator()(VectorX<Scalar> const& velocity, Scalar& loss) {
        // prepare initial conditions
        initial_conditions_.segment<dim_>(3 * dim_) = velocity;

        // run the simulation
        nbody_.initialize(dim_, masses_, initial_conditions_, delta_t_);
        for (uint32_t i = 0; i < n_steps_; ++i) {
            nbody_.step();
        }

        // compute loss
        Vector2<Scalar> const goal = Vector2<Scalar>(0, 1);
        Vector2<Scalar> const difference = nbody_.state().segment<dim_>(dim_) - goal;
        loss = difference.squaredNorm();

        // print stuff
        std::cout << "loss = " << loss << "\n";
    }

    void operator()(VectorX<Scalar> const& velocity, Scalar& loss, VectorX<Scalar>& gradient) {
        // prepare initial conditions
        initial_conditions_.segment<dim_>(3 * dim_) = velocity;

        // run the simulation
        nbody_.initialize(dim_, masses_, initial_conditions_, delta_t_);
        for (uint32_t i = 0; i < n_steps_; ++i) {
            nbody_.step();
        }

        // compute loss
        Vector2<Scalar> const goal = Vector2<Scalar>(0, 1);
        Vector2<Scalar> const difference = nbody_.state().segment<dim_>(dim_) - goal;
        loss = difference.squaredNorm();

        // compute gradient of loss
        // loss = difference^T * difference
        // d_loss = 2 * difference^T * d_difference
        gradient = 2 * difference.transpose() * nbody_.state_jacobian().block<dim_, dim_>(dim_, 3 * dim_);

        // print stuff
        std::cout << "loss = " << loss << "\n";
        std::cout << "loss gradient\n";
        std::cout << gradient << '\n';
        //std::cout << "state jacobian\n";
        //std::cout << nbody_.state_jacobian() << '\n';
    }

    constexpr static uint32_t dim_ = 2;
    constexpr static uint32_t n_bodies_ = 2;
    constexpr static Scalar sim_time_ = 1.4;
    constexpr static Scalar delta_t_ = 1e-3;
    constexpr static uint32_t n_steps_ = static_cast<uint32_t>(sim_time_ / delta_t_);

    VectorX<Scalar> masses_;
    VectorX<Scalar> initial_conditions_;

    NBody nbody_;
};

//--------------------------------------------------------------------------------------------------
int main() {
    LaunchObjective objective;

    VectorX<Scalar> satellite_velocity;
    satellite_velocity.resize(2);
    satellite_velocity[0] = 0;
    satellite_velocity[1] = 0.4;

    //Scalar loss;
    //VectorX<Scalar> gradient;
    //objective(satellite_velocity, loss, gradient);
    //std::cout << '\n';

    bool do_nm = false;
    bool do_cgd = true;
    bool do_vis = true;

    if (do_nm) {
        optimization::NelderMead<-1> nelder_mead(
            10000, // max evaluations
            1e-6   // relative y tolerance
        );
        nelder_mead.optimize(objective, satellite_velocity, 0.5);
        satellite_velocity = nelder_mead.point();
        std::cout << "n evaluations: " << nelder_mead.n_evaluations() << '\n';
        std::cout << "final point: " << nelder_mead.point() << '\n';
        std::cout << "final loss: " << nelder_mead.value() << '\n';
        std::cout << '\n';
    }

    if (do_cgd) {
        optimization::ConjugateGradientDescent<-1> cgd(
            1e-3,   // gradient threshold
            -1,     // abs threshold
            100000, // max evaluations
            10000   // max iterations
        );
        cgd.optimize(objective, satellite_velocity);
        satellite_velocity = cgd.point();
        std::cout << "n evaluations: " << cgd.n_evaluations() << '\n';
        std::cout << "n iterations: " << cgd.n_iterations() << '\n';
        std::cout << "final point: " << cgd.point() << '\n';
        std::cout << '\n';
    }

    if (do_vis) {
        objective.prep_initial_conditions(satellite_velocity);

        grapher.initialize(
            &objective.nbody(),
            BoundingBox<Vector2<Scalar>>(
                {-2, -2},
                {2, 2}
            )
        );
        grapher.init_glut();
        glutDisplayFunc(display);
        glutMouseFunc(mouse);
        glutIdleFunc(idle);
        grapher.run();
    }

    return 0;
}