import math
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib.animation import FuncAnimation, writers

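'''
6.2)
Model the bending of a beam (equation 10.29) under an applied load. Use Hermite polynomial interpolation,
and boundary conditions fixing the displacement and slope at one end, and applying a force at the other end.
V = int_0^L [0.5*E*I*(d^2u/dx^2)^2 - u(x)*f(x)] dx

Note: the script below does not use 1D Hermite beam elements. It models the beam as a 2D plane-strain
continuum meshed with 4-node bilinear quadrilaterals, with the left edge fixed and a vertical load
applied to the nodes of the right edge.
Adapted from: https://polymerfem.com/full-finite-element-solver-in-100-lines-of-python/
'''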

# Supporting functions
def shape(xi):
    """Evaluate the four bilinear shape functions at a reference point.

    xi is an iterable of two reference coordinates. The returned 1-D array
    holds the values for the corners in the order (-1,-1), (+1,-1), (+1,+1),
    (-1,+1); each function is 1 at its own corner and 0 at the other three.
    """
    xi_1, xi_2 = tuple(xi)
    # Linear factors that vanish on the right/left and top/bottom edges.
    left, right = 1.0 - xi_1, 1.0 + xi_1
    bottom, top = 1.0 - xi_2, 1.0 + xi_2
    return 0.25 * np.array([left * bottom, right * bottom, right * top, left * top])

def gradshape(xi):
    """Evaluate the reference-coordinate derivatives of the shape functions.

    xi is an iterable of two reference coordinates. The result is a (2, 4)
    array: row 0 holds the derivatives with respect to the first coordinate,
    row 1 with respect to the second, with columns in the same corner order
    as shape().
    """
    xi_1, xi_2 = tuple(xi)
    left, right = 1.0 - xi_1, 1.0 + xi_1
    bottom, top = 1.0 - xi_2, 1.0 + xi_2
    d_dxi1 = [-bottom, bottom, top, -top]
    d_dxi2 = [-left, -right, right, left]
    return 0.25 * np.array([d_dxi1, d_dxi2])

# Input: element counts and physical size of the rectangular domain
mesh_ex = 49
mesh_ey = 9
mesh_lx = 50.0
mesh_ly = 10.0
# Derived: node counts, totals and element edge lengths
mesh_nx      = mesh_ex + 1
mesh_ny      = mesh_ey + 1
num_nodes    = mesh_nx * mesh_ny
num_elements = mesh_ex * mesh_ey
mesh_hx      = mesh_lx / mesh_ex
mesh_hy      = mesh_ly / mesh_ey

# Node coordinates, one [x, y] row per node. Rows are ordered with x varying
# fastest, so node (i, j) of the grid lives at row i + j*mesh_nx.
nodes = np.array([
    [x, y]
    for y in np.linspace(0.0, mesh_ly, mesh_ny)
    for x in np.linspace(0.0, mesh_lx, mesh_nx)
])

# Element connectivity: for each quad, the four node indices starting at the
# lower-left corner and proceeding counter-clockwise.
conn = [
    [n0, n0 + 1, n0 + 1 + mesh_nx, n0 + mesh_nx]
    for n0 in (i + j*mesh_nx for j in range(mesh_ey) for i in range(mesh_ex))
]

# Material model - Plane strain
# C is the 3x3 matrix applied to B in the element stiffness B^T C B below.
E = 100.0
v = 0.48
C = np.zeros((3, 3))
C[0, 0] = C[1, 1] = 1.0 - v   # normal terms
C[0, 1] = C[1, 0] = v         # coupling between the two normal components
C[2, 2] = 0.5 - v             # shear term
C *= E/(1.0+v)/(1.0-2.0*v)    # common prefactor

# Create global stiffness matrix
# K is dense, with two degrees of freedom per node: node n owns row/column
# 2n (x displacement) and 2n+1 (y displacement).
K = np.zeros((2*num_nodes, 2*num_nodes))
# 2x2 quadrature points at +-1/sqrt(3) in the reference square; every point
# contributes with weight 1 (no extra weight factor is applied below).
q4 = [[x/math.sqrt(3.0),y/math.sqrt(3.0)] for y in [-1.0,1.0] for x in [-1.0,1.0]]
# Strain-displacement matrix. It is allocated once and reused: the entries
# that are never written (B[0,1::2] and B[1,0::2]) stay zero.
B = np.zeros((3,8))
for c in conn:
    xIe = nodes[c,:]          # (4, 2) coordinates of this element's corners
    Ke = np.zeros((8,8))      # element stiffness, DOFs interleaved x0,y0,x1,y1,...
    for q in q4:
        dN_ref = gradshape(q)
        # jac[i, j] = d(x_j)/d(xi_i). By the chain rule dN_ref = jac . dN_xy,
        # so the Cartesian gradients are inv(jac) . dN_ref. The Jacobian must
        # NOT be transposed before inversion: that only gives the right answer
        # when jac is symmetric (e.g. the axis-aligned rectangles used here,
        # where it is diagonal up to round-off) and is wrong for skewed elements.
        jac = np.dot(dN_ref, xIe)
        dN = np.dot(np.linalg.inv(jac), dN_ref)
        B[0,0::2] = dN[0,:]
        B[1,1::2] = dN[1,:]
        B[2,0::2] = dN[1,:]
        B[2,1::2] = dN[0,:]
        Ke += np.dot(np.dot(B.T,C),B) * np.linalg.det(jac)
    # Scatter the 8x8 element matrix into K in one block operation. The node
    # indices within an element are distinct, so no index pair repeats.
    dofs = [dof for n in c for dof in (2*n, 2*n + 1)]
    K[np.ix_(dofs, dofs)] += Ke

# Assign nodal forces and boundary conditions
f = np.zeros((2*num_nodes))
for node in range(num_nodes):
    x_node, y_node = nodes[node]
    ux_dof, uy_dof = 2*node, 2*node + 1
    if x_node == 0.0:
        # Left edge: replace both equations of this node by "1 * u = f".
        # f stays 0 for these DOFs, so the solve returns zero displacement.
        # Only rows are overwritten (columns are left alone).
        for dof in (ux_dof, uy_dof):
            K[dof, :]   = 0.0
            K[dof, dof] = 1.0
    if x_node == mesh_lx:
        # Right edge: force in the y direction. The two corner nodes
        # (bottom and top of the edge) receive half the interior value.
        on_corner = y_node == 0.0 or y_node == mesh_ly
        f[uy_dof] = 0.25 if on_corner else 0.5

# Solving linear system
# u interleaves the displacement components: u[0::2] is x, u[1::2] is y.
u = np.linalg.solve(K, f)
print('max u=', max(u))

# Plotting
# Reshape each component onto the (row = y index, column = x index) node grid.
ux = u[0::2].reshape((mesh_ny, mesh_nx))
uy = u[1::2].reshape((mesh_ny, mesh_nx))
# Visit the grid column by column (x index outer, y index inner).
grid_ij = [(i, j) for i in range(mesh_nx) for j in range(mesh_ny)]
# Deformed position = undeformed grid position + displacement.
xvec = [i*mesh_hx + ux[j, i] for i, j in grid_ij]
yvec = [j*mesh_hy + uy[j, i] for i, j in grid_ij]
# Vertical displacement per point; only used by the optional contour plot below.
res  = [uy[j, i] for i, j in grid_ij]
# t = plt.tricontourf(xvec, yvec, res, levels=14, cmap=plt.cm.jet)
plt.scatter(xvec, yvec, marker='o', c='black', s=2)
# plt.grid()
# plt.colorbar(t)
plt.axis('equal')
plt.show()

# my_dpi = 96
# f = r"c://Users/david/Desktop/NMM/week6/img/beam.gif" 
# fps2 = 20
# writervideo = animation.FFMpegWriter(fps=fps2, extra_args=['-vcodec', 'libx264']) 
# anim.save(f, writer='imagemagick', fps=fps2, dpi=my_dpi)