diff --git a/compressed_sensing.py b/compressed_sensing.py
index 360d74a671c9a79772ef42a1755212828f455f4b..b9620292de4c9500feff43f024e128994e0bcb9f 100644
--- a/compressed_sensing.py
+++ b/compressed_sensing.py
@@ -1,4 +1,5 @@
 import numpy as np
+import scipy
 import matplotlib.pyplot as plt
 
 
@@ -6,22 +7,35 @@ def sample_two_sins(f1, f2, sample_times):
     sample_rads = 2 * np.pi * sample_times
     return np.sin(f1 * sample_rads) + np.sin(f2 * sample_rads)
 
+
+def compute_dct_matrix(n_samples):
+    dct_matrix = np.zeros((n_samples, n_samples))
+    root_one_over_n = np.sqrt(1.0 / n_samples)
+    root_two_over_n = np.sqrt(2.0 / n_samples)
+    for j in range(0, n_samples):
+        dct_matrix[0, j] = root_one_over_n
+    for i in range(1, n_samples):
+        for j in range(0, n_samples):
+            dct_matrix[i, j] = root_two_over_n * np.cos(np.pi * (2 * j + 1) * i / (2 * n_samples))
+    return dct_matrix
+
+
 if __name__ == "__main__":
     f1 = 697 # Hz
     f2 = 1209 # Hz
 
+    # Part (a)
     sample_period = 0.01
-    n_samples = 1000
+    n_samples = 250
     sample_times = (sample_period / n_samples) * np.arange(n_samples)
     sample_values = sample_two_sins(f1, f2, sample_times)
     plt.plot(sample_times, sample_values)
     plt.savefig("fig_a.png")
     plt.close()
 
-    sample_period = 0.02
-    n_samples = 1000
-    sample_times = (sample_period / n_samples) * np.arange(n_samples)
-    sample_values = sample_two_sins(f1, f2, sample_times)
-    plt.plot(sample_times, sample_values)
+    # Part (b)
+    dct_matrix = compute_dct_matrix(n_samples)
+    dct = np.matmul(dct_matrix, sample_values)
+    plt.plot(np.arange(n_samples), dct)
     plt.savefig("fig_b.png")
     plt.close()