diff --git a/main.cpp b/main.cpp index fea2db59ee7e2e9cd8d4e0ebda34c2a3774a8629..70c0681caec53aa031df01ad2f653ecb3ba4a40e 100644 --- a/main.cpp +++ b/main.cpp @@ -131,5 +131,27 @@ int main() { } python_print("dct_e", recovered_dct); + // Part (f) + recovered_dct = Vector::Random(n_samples); + recovered_sample_values = dct_matrix.transpose() * recovered_dct; + subset_recovered_sample_values = vector_subset(recovered_sample_values, subset_indices); + subset_differences = subset_sample_values - subset_recovered_sample_values; + loss = subset_differences.squaredNorm() + recovered_dct.squaredNorm(); + gradient = -2 * subset_dct_matrix * subset_differences + 2 * recovered_dct; + Scalar last_loss = std::numeric_limits<Scalar>::infinity(); + Scalar relative_change = (loss - last_loss) / loss; + while (relative_change > 1e-6) { + recovered_dct -= learning_rate * gradient; + recovered_sample_values = dct_matrix.transpose() * recovered_dct; + subset_recovered_sample_values = vector_subset(recovered_sample_values, subset_indices); + subset_differences = subset_sample_values - subset_recovered_sample_values; + loss = subset_differences.squaredNorm() + recovered_dct.squaredNorm(); + gradient = -2 * subset_dct_matrix * subset_differences + 2 * recovered_dct; + last_loss = std::numeric_limits<Scalar>::infinity(); + relative_change = (loss - last_loss) / loss; + //std::cout << loss << '\n'; + } + python_print("dct_f", recovered_dct); + return 0; } diff --git a/plotter.py b/plotter.py index ac0436501f299529884771cf58bf3514636e6db0..7edc1aedeede1256b9125336943b3edee69c5a09 100644 --- a/plotter.py +++ b/plotter.py @@ -8,6 +8,7 @@ recovered_sample_values = np.array([3.71932e-15, 0.475338, 0.917539, 1.29619, 1. subset_sample_times = np.array([0.000200803, 0.000240964, 0.000281124, 0.000321285, 0.000441767, 0.00064257, 0.000883534, 0.00100402, 0.00104418, 0.00116466, 0.00120482, 0.00124498, 0.00144578, 0.00148594, 0.00164659, 0.00180723, 0.00184739, 0.00192771, 0.00200803, 0.00212851, 0.00240964, 0.00248996, 0.00261044, 0.00289157, 0.00293173, 0.00305221, 0.00313253, 0.00317269, 0.00321285, 0.00325301, 0.00333333, 0.00345382, 0.00369478, 0.00373494, 0.0037751, 0.00389558, 0.00405622, 0.00413655, 0.00425703, 0.00429719, 0.00433735, 0.00437751, 0.00441767, 0.00449799, 0.00453815, 0.00465863, 0.0046988, 0.00473896, 0.00481928, 0.00485944, 0.0048996, 0.00518072, 0.00522088, 0.0053012, 0.00554217, 0.00570281, 0.00574297, 0.00586345, 0.00626506, 0.00630522, 0.00634538, 0.00646586, 0.00658635, 0.00662651, 0.00670683, 0.00674699, 0.00694779, 0.00698795, 0.00702811, 0.00706827, 0.00710843, 0.00714859, 0.00718876, 0.00726908, 0.00730924, 0.00742972, 0.0075502, 0.00763052, 0.00771084, 0.00779116, 0.00791165, 0.00799197, 0.00807229, 0.00811245, 0.00815261, 0.00819277, 0.00823293, 0.00827309, 0.00839357, 0.00843373, 0.00883534, 0.00891566, 0.00903614, 0.00923695, 0.00927711, 0.00935743, 0.00963855, 0.00975904, 0.00983936, 0.01]) subset_sample_values = np.array([1.76932, 1.83651, 1.78761, 1.6316, 0.721928, -0.664077, -0.249706, 0.0236569, 0.00668195, -0.379618, -0.575975, -0.771114, -0.951487, -0.73516, 0.742071, 1.91577, 1.96681, 1.70642, 1.02857, -0.341839, -1.42199, -0.930797, -0.0755387, 0.122481, -0.00631234, -0.212412, -0.0591181, 0.112454, 0.333336, 0.584794, 1.08309, 1.44291, -0.249515, -0.701696, -1.12623, -1.94442, -1.45207, -0.663067, 0.591886, 0.910992, 1.14409, 1.28044, 1.31795, 1.12989, 0.938508, 0.261091, 0.0805585, -0.0466774, -0.112311, -0.0551393, 0.0472029, 0.354329, 0.158647, -0.400459, -1.71079, -0.771708, -0.33113, 1.04895, 0.291854, -0.0840174, -0.414358, -0.954231, -0.770496, -0.606599, -0.259632, -0.121617, -0.247029, -0.408011, -0.575983, -0.727348, -0.838913, -0.890317, -0.866221, -0.565119, -0.295029, 0.791266, 1.71807, 1.89685, 1.60779, 0.907671, -0.488655, -1.28001, -1.71149, -1.75931, -1.69539, -1.53092, -1.28466, -0.981105, -0.00689592, 0.25156, -0.0715908, -0.00857246, 0.498828, 1.24743, 1.18842, 0.783761, -1.79994, -1.9006, -1.38735, 0.348445]) dct_e = np.array([-0.181725, 1.31755, 0.649338, 0.146722, 0.577319, 0.52855, -0.420861, 0.00972246, 1.34705, 2.02479, -0.45287, 1.50728, 0.11511, 2.50369, -0.860719, -1.9832, -0.188749, -1.89087, 0.895381, -0.31801, -1.49248, 1.61864, -0.740006, 1.70221, 1.18965, -3.83239, -0.595318, -1.56708, -0.0476166, -0.958164, 0.308146, 0.538959, -0.567305, -1.54234, -0.312223, -1.05125, 0.279093, -0.555883, 0.263304, -1.87512, -0.893629, 1.52954, -0.18754, 0.129317, -0.0409772, -0.481374, -0.272716, -0.40792, -0.428772, -1.23366, 0.932427, -0.284411, 0.114686, -0.175491, 0.287177, 0.321669, -0.324072, 0.163224, 0.792369, -1.01108, 1.23269, 0.226345, 0.882266, -0.0194645, 0.858726, 0.459991, -0.664407, -0.433639, 0.909784, 0.119792, 0.347188, 0.0949223, -0.223123, 0.515855, 1.06883, 0.507396, 1.05056, -0.0203115, 1.16572, 0.471133, 0.787901, -0.74391, -0.192763, 0.246201, 0.433876, 0.56428, 0.922964, 0.921105, 0.782957, 1.06646, -0.642362, -0.560816, -0.619862, 0.406475, -0.944654, 0.524438, 0.105425, 0.192224, -0.588941, 0.361894, -0.0422242, -0.039336, -0.739689, -0.431406, -0.503528, -0.476026, 0.321553, -1.01571, -0.572337, 0.309822, 0.917963, 0.861005, -0.763123, -0.260667, 0.35751, 0.803767, -0.39887, -0.580729, -0.517367, -0.134686, -0.22344, -0.263698, -0.449559, -0.882012, -0.477269, 0.705133, 0.471159, -1.19926, -0.238682, -0.554693, 0.795113, 0.0240765, 0.412576, -0.0746031, -0.56287, 0.0267437, 0.132269, -0.577963, 0.144747, -0.0134209, -0.0511658, 1.41067, -0.414959, -0.290205, 0.312317, -0.102443, -0.115503, -0.133131, -2.58108, -0.421964, 0.238911, 0.822077, -1.31384, -0.674516, 0.270687, -0.216738, 0.0403316, 0.394182, -1.24846, 0.518855, -1.11974, -0.114353, -0.105765, -0.0767365, -0.242629, 0.767938, -0.770488, 1.15386, 1.03914, -0.08044, -1.16603, 1.0451, 0.494025, -0.362372, 0.0147945, -0.1952, -0.333689, -0.507866, 1.61599, -0.814214, 0.0306287, -0.659977, 0.803305, 0.150613, 1.02249, 0.340859, 1.74513, 0.55267, -0.036223, 0.240978, 0.132654, 0.227788, 0.0112584, -0.76642, 0.941365, -0.123865, 0.579842, 1.11127, -0.478658, 0.21837, -0.213266, -0.430659, 0.953601, 0.404303, 0.593639, 0.421635, -0.771913, -0.592505, -0.627325, 0.184693, 0.317796, -0.286112, -0.478074, 0.300756, 0.0458375, 0.821987, -0.0354288, 0.459321, 0.595002, 0.190285, 0.373979, 0.365756, -0.575853, -0.492914, -0.671376, -0.0743422, 0.626893, 0.863512, -0.00458193, 1.51513, 0.183877, 0.718056, -0.512423, -0.529361, 0.0781563, 0.336743, 0.653948, 0.683846, -0.797976, 1.4012, 0.294974, -0.716314, -0.451483, 0.0356109, 1.05136, -1.26643, 0.0326029, 0.0339741, -0.294437, 0.511468]) +dct_f = np.array([-0.0587958, 0.83392, 0.132366, -0.0302263, 0.455145, 0.250454, -0.0265924, 0.0759504, 0.451562, 0.719869, 0.0187715, 0.832717, -0.300425, 1.51969, -0.22468, -1.07992, -0.231819, -0.777481, 0.640932, 0.354693, -0.552675, 0.461616, -0.188974, 0.668277, 0.474778, -1.74631, 0.102872, -0.80092, -0.0307746, -0.435484, -0.293083, -0.129672, -0.0500061, -0.384728, -0.263845, -0.325387, 0.00869475, -0.544153, -0.0846679, -0.79423, -0.271229, 0.251245, -0.162325, -0.0986975, -0.0643201, 0.0771463, -0.0324127, -0.306365, -0.183463, -0.515214, 0.15253, -0.111942, 0.158193, -0.0575633, 0.0477503, 0.304364, -0.0150429, -0.273452, 0.0964331, -0.185356, 0.366054, 0.215494, 0.23831, 0.148894, -0.0347637, 0.21712, -0.193439, 0.0335857, -0.0563406, 0.135322, 0.195014, 0.222709, 0.153517, 0.169688, 0.419629, 0.42625, 0.291496, 0.221815, 0.348463, -0.0450442, 0.0882775, -0.141382, 0.0721275, -0.271627, 0.0412067, 0.271174, 0.25799, 0.159642, 0.181944, 0.378495, -0.21575, -0.193457, -0.0284752, 0.0601079, -0.474095, 0.0361101, -0.0881306, 0.00997256, -0.169241, 0.146956, 0.126388, -0.156709, -0.153306, 0.0247547, -0.193703, -0.0300769, 0.00274041, -0.134635, -0.189638, 0.0208125, 0.309971, 0.150421, -0.521753, -0.250083, -0.163288, -0.0104956, -0.0267689, -0.106925, -0.25884, 0.170088, 0.362609, 0.0639183, 0.0550957, -0.385564, -0.0467193, -0.0440537, 0.0231831, -0.134736, 0.173095, -0.186557, 0.248478, 0.0703752, 0.44137, -0.213527, -0.163014, 0.0776367, -0.00207879, -0.0706977, -0.0178802, -0.159089, 0.0192185, 0.258309, 0.0694495, -0.0188079, -0.0714662, 0.0960203, 0.09085, 0.0343693, -0.816958, -0.253304, 0.102688, 0.0605715, -0.307286, -0.0398879, 0.0733608, 0.0103113, 0.235709, 0.162043, -0.600436, -0.191684, -0.0795467, 0.0390502, 0.156636, 0.0648662, 0.0192842, -0.00614493, -0.00518692, 0.418032, 0.155772, 0.113786, -0.254948, 0.143302, 0.410395, 0.0680757, 0.102929, 0.0361711, -0.00267253, -0.0352654, 0.404044, -0.245883, 0.264461, -0.145808, 0.203461, -0.360337, 0.377393, -0.0408745, 0.659419, -0.0438794, 0.0504482, -0.320621, -0.0144115, -0.172558, 0.0525555, -0.279137, 0.335267, 0.0617953, 0.0855495, 0.478775, -0.16531, -0.0513437, -0.204274, -0.422788, 0.0842655, -0.148743, 0.145042, 0.192927, -0.5449, 0.0852716, 0.152811, 0.0766111, -0.0687772, -0.297389, -0.201822, 0.0150052, -0.22068, 0.296586, -0.274366, -0.0371884, 0.171577, 0.409762, -0.227992, 0.106775, -0.458331, -0.14934, -0.365863, 0.0957223, 0.205644, 0.195402, 0.146401, 0.58686, -0.304342, -0.0341664, -0.379284, -0.0907502, 0.111846, -0.193023, 0.157946, 0.108623, -0.22766, 0.469102, -0.15848, -0.390685, -0.197786, -0.287082, 0.586648, -0.238093, -0.209569, 0.203207, -0.28243, 0.377885]) if __name__ == "__main__": # Part (a) @@ -34,3 +35,8 @@ if __name__ == "__main__": plt.plot(np.arange(len(dct)), dct_e) plt.savefig("fig_e.png") plt.close() + + # Part (e) + plt.plot(np.arange(len(dct)), dct_f) + plt.savefig("fig_f.png") + plt.close()