Skip to content
Snippets Groups Projects
Commit 8ec831f7 authored by Erik Strand's avatar Erik Strand
Browse files

Use reasonable basis functions

parent bb889547
Branches
No related tags found
No related merge requests found
......@@ -40,53 +40,56 @@ class LocalBasisFunction:
def right_curvature(self, xval):
return self.dd_right.subs(x, xval)
def derive_basis_function():
def cubic_basis_functions():
# helps determine cubic poly coefficients
poly_matrix = Matrix([
[1, 0, 0, 0],
[0, 1, 0, 0],
[1, h, h**2, h**3],
[0, 0, 2, 0 ],
[0, 0, 2, 6*h ]
[0, 1, 2*h, 3*h**2]
])
poly_matrix_inverse = poly_matrix**-1
print(latex(poly_matrix))
print(latex(poly_matrix_inverse))
print("")
c_1 = simplify(poly_matrix_inverse * Matrix([0, -1, 0, 1]))
c_2 = simplify(poly_matrix_inverse * Matrix([-1, 0, 1, 0]))
p_1 = c_1.dot(Matrix([1, x, x**2, x**3]))
p_2 = c_2.dot(Matrix([1, x, x**2, x**3]))
print("coeffs")
print(latex(c_1))
print(latex(c_2))
print("polys")
p_1 = (poly_matrix_inverse * Matrix([1, 0, 0, 0])).dot(Matrix([1, x, x**2, x**3]))
p_2 = (poly_matrix_inverse * Matrix([0, 1, 0, 0])).dot(Matrix([1, x, x**2, x**3]))
p_3 = (poly_matrix_inverse * Matrix([0, 0, 1, 0])).dot(Matrix([1, x, x**2, x**3]))
p_4 = (poly_matrix_inverse * Matrix([0, 0, 0, 1])).dot(Matrix([1, x, x**2, x**3]))
print(latex(p_1))
print(latex(p_2))
print("poly derivs")
print(latex(diff(p_1, x)))
print(latex(diff(p_2, x)))
print("poly 2nd derivs")
print(latex(diff(p_1, x, 2)))
print(latex(diff(p_2, x, 2)))
print(latex(p_3))
print(latex(p_4))
print("")
print("sum")
print(latex(simplify(p_1 + p_2)))
return LocalBasisFunction(p_1, p_2)
return [
LocalBasisFunction(p_3, p_1),
LocalBasisFunction(p_4, p_2)
]
# basis_functions is a list of LocalBasisFunction objects
# coefficients is a 2d numpy array; each row is for a basis function, each col is for a node
def plot_result(basis_functions, coefficients):
assert(len(basis_functions) == coefficients.shape[0])
def plot_result(basis_function, coefficients):
n_subdivs = 10
xvals = []
yvals = []
dyvals = []
ddyvals = []
for i in range(0, len(coefficients) - 1):
for i in range(0, coefficients.shape[1] - 1):
for xval in np.linspace(0, 1, n_subdivs, endpoint=False):
y = coefficients[i] * basis_function.right_value(xval).subs(h, 1)
y += coefficients[i + 1] * basis_function.left_value(xval).subs(h, 1)
dy = coefficients[i] * basis_function.right_slope(xval).subs(h, 1)
dy += coefficients[i + 1] * basis_function.left_slope(xval).subs(h, 1)
ddy = coefficients[i] * basis_function.right_curvature(xval).subs(h, 1)
ddy += coefficients[i + 1] * basis_function.left_curvature(xval).subs(h, 1)
y = 0
dy = 0
ddy = 0
for basis_function, row in zip(basis_functions, coefficients):
y += row[i] * basis_function.right_value(xval).subs(h, 1)
y += row[i + 1] * basis_function.left_value(xval).subs(h, 1)
dy += row[i] * basis_function.right_slope(xval).subs(h, 1)
dy += row[i + 1] * basis_function.left_slope(xval).subs(h, 1)
ddy += row[i] * basis_function.right_curvature(xval).subs(h, 1)
ddy += row[i + 1] * basis_function.left_curvature(xval).subs(h, 1)
xvals.append(i + xval)
yvals.append(y)
dyvals.append(dy)
......@@ -95,15 +98,15 @@ def plot_result(basis_function, coefficients):
fig1 = plt.figure()
left, bottom, width, height = 0.1, 0.1, 0.8, 0.8
ax1 = fig1.add_axes([left, bottom, width, height])
ax1.plot(xvals, yvals, label="value")
ax1.plot(xvals, yvals, label="displacement")
ax1.plot(xvals, dyvals, label="slope")
ax1.plot(xvals, ddyvals, label="curvature")
ax1.set_xlabel('h')
ax1.set_ylabel("displacement")
ax1.legend()
plt.show(fig1)
fig1.savefig("../../../assets/img/06_bad_basis_functions.png", transparent=True)
#fig1.savefig("../../../assets/img/06_basis_functions.png", transparent=True)
if __name__ == "__main__":
basis_function = derive_basis_function()
plot_result(basis_function, [0.1, 0.1, 0.1, 0.1, 0.1])
basis_function = cubic_basis_functions()
plot_result(basis_function, np.array([[0, 0.1, 0.2, 0.4, 0.4], [0, 0.2, 0.1, 0.1, -0.1]]))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment