import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import sympy
from sympy import *

'''
Search week: plot the 2-D Rosenbrock function as a 3-D surface and overlay
the path taken by momentum gradient descent toward its minimum at (1, 1).
'''

# A) Plot the rosenbrock function in 3D
def rosenbrock_2d(coords):
    """Evaluate the 2-D Rosenbrock function (1 - x)^2 + 100*(y - x^2)^2.

    `coords` is any indexable pair (x, y); scalars and numpy arrays both
    work, so the function can be evaluated over a whole meshgrid at once.
    """
    x_val, y_val = coords[0], coords[1]
    return (1 - x_val) ** 2 + 100 * (y_val - x_val ** 2) ** 2

# Sample the function on a square grid and draw it as a 3-D surface.
bound = 1          # half-width of the plotting window in x and y
n_samples = 100    # grid resolution per axis (previously misnamed `z_bound`)
x = np.linspace(-bound, bound, n_samples)
y = np.linspace(-bound, bound, n_samples)

X, Y = np.meshgrid(x, y)
Z = rosenbrock_2d([X, Y])  # vectorized evaluation over the whole grid

fig = plt.figure()
ax = plt.axes(projection='3d')
ax.plot_surface(X, Y, Z, cmap='viridis', alpha=0.9)

# D) Rosenbrock search with gradient descent

# Symbolic variables and the Rosenbrock expression, used below to derive the
# analytic gradient. NOTE: this rebinds the plotting arrays `x` and `y` above
# to sympy symbols; the later descent functions close over these globals.
x, y = sympy.symbols('x y')
rosenbrock = (1 - x)**2 + 100*(y - x**2)**2

# Find the gradient of the rosenbrock function with sympy
def find_diff():
    """Symbolically differentiate the module-level `rosenbrock` expression.

    Returns a sympy array holding (d/dx, d/dy) of the Rosenbrock function.
    """
    return sympy.derive_by_array(rosenbrock, (x, y))

def evaluate_gradient(derivative, variables, substitute):
    """Substitute numeric values into a symbolic gradient.

    derivative: sympy expression/array (anything exposing ``.subs``)
    variables:  the symbols to replace, matched positionally with
    substitute: the numeric values to plug in for each symbol
    """
    return derivative.subs(dict(zip(variables, substitute)))

def gradient_descent(starting_point, learning_rate, iterations, tol=1e-2, momentum=0.65):
    """Minimise the Rosenbrock function with momentum gradient descent.

    Parameters:
        starting_point: (x, y) pair to start the descent from.
        learning_rate:  step size multiplying the (momentum-augmented) gradient.
        iterations:     maximum number of descent steps.
        tol:            stop early when successive objective values differ by
                        less than this amount.
        momentum:       decay factor carried from one step's direction into the
                        next (previously a hard-coded 0.65).

    Returns:
        ([xs, ys, zs], loop_count): the full descent path (including the
        starting point) and the number of iterations actually performed.

    Relies on the module globals `rosenbrock_diff` (symbolic gradient),
    `evaluate_gradient`, `rosenbrock_2d`, and the sympy symbols `x`, `y`.
    """
    x_num = starting_point[0]
    y_num = starting_point[1]
    xs = [x_num]
    ys = [y_num]
    # Float init: an int array here would be a dtype trap for in-place updates.
    inertia = np.array([0.0, 0.0])
    zs = [rosenbrock_2d([x_num, y_num])]
    loop_count = 0
    for _ in range(iterations):
        # Momentum is folded into the gradient before stepping; the combined
        # direction is then decayed and carried into the next iteration.
        gradient = evaluate_gradient(rosenbrock_diff, [x, y], [x_num, y_num]) + inertia
        x_num = x_num - learning_rate * gradient[0]
        y_num = y_num - learning_rate * gradient[1]
        z_num = rosenbrock_2d([x_num, y_num])
        xs.append(x_num)
        ys.append(y_num)
        zs.append(z_num)
        inertia = gradient * momentum
        loop_count += 1
        # Early exit once the objective has effectively stopped changing.
        if abs(zs[-1] - zs[-2]) < tol:
            break
    return [xs, ys, zs], loop_count

import time

# perf_counter is the right clock for elapsed-time measurement: monotonic and
# high resolution (time.time can jump with system clock adjustments).
start_time = time.perf_counter()
rosenbrock_diff = find_diff()  # symbolic gradient, reused by every iteration
starting_point = [-1, -1]
learning_rate = 0.001
iterations = 100

# Run the descent and time it (timing includes the symbolic differentiation).
descent_points, passes = gradient_descent(starting_point, learning_rate, iterations)
end_time = time.perf_counter()
print(f"Time taken: {end_time - start_time}")

# Overlay the descent path on the surface plot drawn earlier.
ax.plot(descent_points[0], descent_points[1], descent_points[2],
        c='r', marker='o', linewidth=1, markersize=3)
# Annotate with the final point (3 decimal places) and the pass count.
final_point = str([round(descent_points[0][-1], 3),
                   round(descent_points[1][-1], 3),
                   round(descent_points[2][-1], 3)])
ax.text2D(0.05, 1, "Final Point: " + final_point, transform=ax.transAxes)
ax.text2D(0.05, 0.95, "Passes: " + str(passes), transform=ax.transAxes)
ax.view_init(30, 45)
plt.show()