diff --git a/main.cpp b/main.cpp index a5b71394801110ba532cd130d0f6a16a97d4301c..84a41cff9caacf18867e8a9192060a149de829c0 100644 --- a/main.cpp +++ b/main.cpp @@ -119,7 +119,7 @@ int main() { Vector subset_differences = subset_sample_values - subset_recovered_sample_values; Scalar loss = subset_differences.squaredNorm(); Vector gradient = -2 * subset_dct_matrix * subset_differences; - constexpr Scalar learning_rate = 0.1; + constexpr Scalar learning_rate = 0.5; while (loss > 1e-6) { recovered_dct -= learning_rate * gradient; recovered_sample_values = dct_matrix.transpose() * recovered_dct; @@ -129,6 +129,7 @@ int main() { gradient = -2 * subset_dct_matrix * subset_differences; //std::cout << loss << '\n'; } + Scalar const final_loss_e = loss; python_print("dct_e", recovered_dct); // Part (f) @@ -151,6 +152,7 @@ int main() { relative_change = (loss - last_loss) / loss; //std::cout << loss << '\n'; } + Scalar const final_loss_f = loss; python_print("dct_f", recovered_dct); // Part (g) @@ -175,7 +177,13 @@ int main() { relative_change = (loss - last_loss) / loss; //std::cout << loss << '\n'; } + Scalar const final_loss_g = loss; python_print("dct_g", recovered_dct); + std::cout << '\n'; + std::cout << "Final unregularized loss: " << final_loss_e << '\n'; + std::cout << "Final L2 regularized loss: " << final_loss_f << '\n'; + std::cout << "Final L1 regularized loss: " << final_loss_g << '\n'; + return 0; } diff --git a/plotter.py b/plotter.py index 10dd4996720cff86db8e9d883eba4ea4970fd445..04ad7aa976e69ef98f2fd19e03b94a13b432e987 100644 --- a/plotter.py +++ b/plotter.py @@ -7,7 +7,7 @@ dct = np.array([0.0493937, 1.53249, 0.0703373, 1.58311, 0.071832, 1.69718, 0.074 recovered_sample_values = np.array([3.71932e-15, 0.475338, 0.917539, 1.29619, 1.5861, 1.76932, 1.83651, 1.78761, 1.6316, 1.38558, 1.0731, 0.721928, 0.361558, 0.0205522, -0.27594, -0.508314, -0.664077, -0.738727, -0.735864, -0.666533, -0.547871, -0.401193, -0.249706, -0.116065, -0.0200166, 0.0236569, 0.00668195, -0.0718098, -0.205127, -0.379618, -0.575975, -0.771114, -0.94046, -1.06042, -1.1108, -1.07697, -0.951487, -0.73516, -0.43726, -0.0750177, 0.327653, 0.742071, 1.13718, 1.48217, 1.7492, 1.91577, 1.96681, 1.89599, 1.70642, 1.41042, 1.02857, 0.588008, 0.120163, -0.341839, -0.765935, -1.12376, -1.39294, -1.5588, -1.61544, -1.56598, -1.42199, -1.20225, -0.930797, -0.634626, -0.341124, -0.0755387, 0.141293, 0.294754, 0.377636, 0.390584, 0.341783, 0.245913, 0.122481, -0.00631234, -0.117955, -0.19203, -0.212412, -0.169044, -0.0591181, 0.112454, 0.333336, 0.584794, 0.843377, 1.08309, 1.27787, 1.40408, 1.44291, 1.38232, 1.21838, 0.95598, 0.608661, 0.197701, -0.249515, -0.701696, -1.12623, -1.49192, -1.77159, -1.94442, -1.99773, -1.92798, -1.74113, -1.45207, -1.08332, -0.663067, -0.222778, 0.205515, 0.591886, 0.910992, 1.14409, 1.28044, 1.31795, 1.26307, 1.12989, 0.938508, 0.713012, 0.478999, 0.261091, 0.0805585, -0.0466774, -0.11149, -0.112311, -0.0551393, 0.0472029, 0.176484, 0.311019, 0.427974, 0.505792, 0.526522, 0.477802, 0.354329, 0.158647, -0.0988122, -0.400459, -0.722977, -1.03934, -1.32121, -1.54155, -1.67714, -1.71079, -1.63312, -1.44356, -1.15069, -0.771708, -0.33113, 0.14113, 0.6122, 1.04895, 1.42074, 1.70194, 1.87404, 1.9271, 1.86053, 1.68299, 1.41149, 1.06984, 0.686454, 0.291854, -0.0840174, -0.414358, -0.677735, -0.859784, -0.954231, -0.963144, -0.896411, -0.770496, -0.606599, -0.428405, -0.259632, -0.121617, -0.0311657, 0.00112032, -0.0281042, -0.11461, -0.247029, -0.408011, -0.575983, -0.727348, -0.838913, -0.890317, -0.866221, -0.75805, -0.565119, -0.295029, 0.0367266, 0.407896, 0.791266, 1.15696, 1.47502, 1.71807, 1.86376, 1.89685, 1.81066, 1.60779, 1.30007, 0.907671, 0.457475, -0.019091, -0.488655, -0.918654, -1.28001, -1.54948, -1.71149, -1.75931, -1.69539, -1.53092, -1.28466, -0.981105, -0.648113, -0.314394, -0.00689592, 0.25156, 0.444213, 0.561694, 0.602629, 0.573475, 0.487607, 0.363754, 0.223933, 0.0910946, -0.0133094, -0.0715908, -0.0717057, -0.00857246, 0.115302, 0.289952, 0.498828, 0.720336, 0.929905, 1.10238, 1.21454, 1.24743, 1.18842, 1.03264, 0.783761, 0.453978, 0.0632207, -0.36236, -0.792506, -1.19536, -1.54017, -1.79994, -1.95375, -1.9886, -1.9006, -1.69531, -1.38735, -0.999124, -0.559002, -0.0988865, 0.348445]) subset_sample_times = np.array([0.000200803, 0.000240964, 0.000281124, 0.000321285, 0.000441767, 0.00064257, 0.000883534, 0.00100402, 0.00104418, 0.00116466, 0.00120482, 0.00124498, 0.00144578, 0.00148594, 0.00164659, 0.00180723, 0.00184739, 0.00192771, 0.00200803, 0.00212851, 0.00240964, 0.00248996, 0.00261044, 0.00289157, 0.00293173, 0.00305221, 0.00313253, 0.00317269, 0.00321285, 0.00325301, 0.00333333, 0.00345382, 0.00369478, 0.00373494, 0.0037751, 0.00389558, 0.00405622, 0.00413655, 0.00425703, 0.00429719, 0.00433735, 0.00437751, 0.00441767, 0.00449799, 0.00453815, 0.00465863, 0.0046988, 0.00473896, 0.00481928, 0.00485944, 0.0048996, 0.00518072, 0.00522088, 0.0053012, 0.00554217, 0.00570281, 0.00574297, 0.00586345, 0.00626506, 0.00630522, 0.00634538, 0.00646586, 0.00658635, 0.00662651, 0.00670683, 0.00674699, 0.00694779, 0.00698795, 0.00702811, 0.00706827, 0.00710843, 0.00714859, 0.00718876, 0.00726908, 0.00730924, 0.00742972, 0.0075502, 0.00763052, 0.00771084, 0.00779116, 0.00791165, 0.00799197, 0.00807229, 0.00811245, 0.00815261, 0.00819277, 0.00823293, 0.00827309, 0.00839357, 0.00843373, 0.00883534, 0.00891566, 0.00903614, 0.00923695, 0.00927711, 0.00935743, 0.00963855, 0.00975904, 0.00983936, 0.01]) subset_sample_values = np.array([1.76932, 1.83651, 1.78761, 1.6316, 0.721928, -0.664077, -0.249706, 0.0236569, 0.00668195, -0.379618, -0.575975, -0.771114, -0.951487, -0.73516, 0.742071, 1.91577, 1.96681, 1.70642, 1.02857, -0.341839, -1.42199, -0.930797, -0.0755387, 0.122481, -0.00631234, -0.212412, -0.0591181, 0.112454, 0.333336, 0.584794, 1.08309, 1.44291, -0.249515, -0.701696, -1.12623, -1.94442, -1.45207, -0.663067, 0.591886, 0.910992, 1.14409, 1.28044, 1.31795, 1.12989, 0.938508, 0.261091, 0.0805585, -0.0466774, -0.112311, -0.0551393, 0.0472029, 0.354329, 0.158647, -0.400459, -1.71079, -0.771708, -0.33113, 1.04895, 0.291854, -0.0840174, -0.414358, -0.954231, -0.770496, -0.606599, -0.259632, -0.121617, -0.247029, -0.408011, -0.575983, -0.727348, -0.838913, -0.890317, -0.866221, -0.565119, -0.295029, 0.791266, 1.71807, 1.89685, 1.60779, 0.907671, -0.488655, -1.28001, -1.71149, -1.75931, -1.69539, -1.53092, -1.28466, -0.981105, -0.00689592, 0.25156, -0.0715908, -0.00857246, 0.498828, 1.24743, 1.18842, 0.783761, -1.79994, -1.9006, -1.38735, 0.348445]) -dct_e = np.array([-0.181725, 1.31755, 0.649338, 0.146722, 0.577319, 0.52855, -0.420861, 0.00972246, 1.34705, 2.02479, -0.45287, 1.50728, 0.11511, 2.50369, -0.860719, -1.9832, -0.188749, -1.89087, 0.895381, -0.31801, -1.49248, 1.61864, -0.740006, 1.70221, 1.18965, -3.83239, -0.595318, -1.56708, -0.0476166, -0.958164, 0.308146, 0.538959, -0.567305, -1.54234, -0.312223, -1.05125, 0.279093, -0.555883, 0.263304, -1.87512, -0.893629, 1.52954, -0.18754, 0.129317, -0.0409772, -0.481374, -0.272716, -0.40792, -0.428772, -1.23366, 0.932427, -0.284411, 0.114686, -0.175491, 0.287177, 0.321669, -0.324072, 0.163224, 0.792369, -1.01108, 1.23269, 0.226345, 0.882266, -0.0194645, 0.858726, 0.459991, -0.664407, -0.433639, 0.909784, 0.119792, 0.347188, 0.0949223, -0.223123, 0.515855, 1.06883, 0.507396, 1.05056, -0.0203115, 1.16572, 0.471133, 0.787901, -0.74391, -0.192763, 0.246201, 0.433876, 0.56428, 0.922964, 0.921105, 0.782957, 1.06646, -0.642362, -0.560816, -0.619862, 0.406475, -0.944654, 0.524438, 0.105425, 0.192224, -0.588941, 0.361894, -0.0422242, -0.039336, -0.739689, -0.431406, -0.503528, -0.476026, 0.321553, -1.01571, -0.572337, 0.309822, 0.917963, 0.861005, -0.763123, -0.260667, 0.35751, 0.803767, -0.39887, -0.580729, -0.517367, -0.134686, -0.22344, -0.263698, -0.449559, -0.882012, -0.477269, 0.705133, 0.471159, -1.19926, -0.238682, -0.554693, 0.795113, 0.0240765, 0.412576, -0.0746031, -0.56287, 0.0267437, 0.132269, -0.577963, 0.144747, -0.0134209, -0.0511658, 1.41067, -0.414959, -0.290205, 0.312317, -0.102443, -0.115503, -0.133131, -2.58108, -0.421964, 0.238911, 0.822077, -1.31384, -0.674516, 0.270687, -0.216738, 0.0403316, 0.394182, -1.24846, 0.518855, -1.11974, -0.114353, -0.105765, -0.0767365, -0.242629, 0.767938, -0.770488, 1.15386, 1.03914, -0.08044, -1.16603, 1.0451, 0.494025, -0.362372, 0.0147945, -0.1952, -0.333689, -0.507866, 1.61599, -0.814214, 0.0306287, -0.659977, 0.803305, 0.150613, 1.02249, 0.340859, 1.74513, 0.55267, -0.036223, 0.240978, 0.132654, 0.227788, 0.0112584, -0.76642, 0.941365, -0.123865, 0.579842, 1.11127, -0.478658, 0.21837, -0.213266, -0.430659, 0.953601, 0.404303, 0.593639, 0.421635, -0.771913, -0.592505, -0.627325, 0.184693, 0.317796, -0.286112, -0.478074, 0.300756, 0.0458375, 0.821987, -0.0354288, 0.459321, 0.595002, 0.190285, 0.373979, 0.365756, -0.575853, -0.492914, -0.671376, -0.0743422, 0.626893, 0.863512, -0.00458193, 1.51513, 0.183877, 0.718056, -0.512423, -0.529361, 0.0781563, 0.336743, 0.653948, 0.683846, -0.797976, 1.4012, 0.294974, -0.716314, -0.451483, 0.0356109, 1.05136, -1.26643, 0.0326029, 0.0339741, -0.294437, 0.511468]) +dct_e = np.array([-0.999984, -0.736924, 0.511211, -0.0826997, 0.0655345, -0.562082, -0.905911, 0.357729, 0.358593, 0.869386, -0.232996, 0.0388327, 0.661931, -0.930856, -0.893077, 0.0594004, 0.342299, -0.984604, -0.233169, -0.866316, -0.165028, 0.373545, 0.177953, 0.860873, 0.692334, 0.0538576, -0.81607, 0.307838, -0.168001, 0.402381, 0.820642, 0.524396, -0.475094, -0.905071, 0.472164, -0.343532, 0.265277, 0.512821, 0.982075, -0.269323, -0.505922, 0.965101, 0.445321, 0.506712, 0.303037, -0.854628, 0.263269, 0.769414, -0.45458, -0.127177, 0.53299, -0.0445365, -0.524451, -0.450186, -0.28147, -0.666986, -0.0269652, 0.795313, 0.818416, -0.878871, 0.809306, 0.00904579, 0.0325839, -0.361934, 0.973284, -0.0120466, -0.467711, -0.818534, 0.895528, -0.852502, 0.00141419, -0.231716, -0.445836, 0.827635, 0.0594948, -0.0711084, 0.88196, -0.899832, 0.523029, 0.540409, 0.655635, -0.749269, -0.968265, 0.376911, 0.736494, 0.259087, 0.472449, 0.450824, 0.998916, 0.777144, -0.53361, -0.387356, -0.29797, 0.0265474, 0.182227, 0.691963, -0.175838, 0.683021, -0.461365, -0.169211, 0.074608, -0.0641653, -0.425575, -0.643345, -0.69256, 0.14331, 0.604811, -0.933892, 0.0688997, -0.00303976, 0.910722, 0.496585, 0.109168, 0.781475, 0.249699, 0.684079, -0.680465, -0.574497, 0.42942, -0.739145, -0.818019, -0.450824, -0.994001, -0.171413, -0.946247, 0.419639, 0.875795, -0.520178, -0.638208, -0.364921, 0.773981, 0.304117, -0.69933, 0.362692, -0.228371, -0.224549, -0.000517935, -0.704934, 0.174373, 0.691151, 0.180217, 0.910818, 0.112292, -0.703697, 0.96661, -0.182467, -0.71636, 0.129797, -0.495747, -0.0229709, -0.0719389, 0.92219, -0.747938, -0.600486, -0.361501, 0.258538, -0.746576, 0.302507, 0.243268, 0.606146, -0.504316, -0.0471364, -0.221372, -0.593499, -0.94325, 0.803347, -0.147005, -0.715958, 0.894974, -0.179374, -0.737623, 0.771297, -0.815653, -0.675603, -0.857873, -0.269322, -0.493885, -0.729781, 0.566306, -0.0893854, -0.300952, -0.0953997, 0.617889, 0.863349, 0.303292, -0.569503, 0.359185, 0.817844, -0.499749, 0.72172, -0.0574753, 0.0119118, 0.200787, 0.635123, 0.511687, -0.07551, 0.902735, 0.265477, -0.121339, 0.649395, 0.377962, 0.404413, 0.974291, 0.90883, 0.702539, -0.421368, 0.0748515, 0.0288693, -0.793132, -0.171943, 0.153433, 0.753131, -0.119923, 0.459495, 0.738527, 0.431285, 0.601441, 0.413071, 0.483432, -0.961815, 0.772062, 0.0499748, -0.0733545, -0.869612, 0.426845, -0.0221137, 0.335358, 0.364098, -0.600891, 0.833268, 0.731765, 0.780037, 0.0878966, -0.72161, -0.0993051, 0.978725, -0.568936, -0.107953, -0.368535, 0.0293189, 0.763008, -0.120549, -0.0649364, 0.6133, -0.269684, -0.576962, 0.998233, -0.692791, 0.260977, 0.232701]) dct_f = np.array([-0.99881, -0.998242, 0.546704, 0.45467, -0.361645, -0.164553, 0.364988, 0.361124, -0.589476, 0.67284, 0.417841, 0.657416, -0.810902, -0.836525, 0.52801, 0.259144, -0.572296, -0.572906, -0.837878, -0.222354, 0.904319, 0.89509, -0.220293, -0.461571, 0.384338, -0.43193, 0.553732, 0.56773, -0.15508, -0.435688, -0.612065, -0.977368, -0.616352, 0.966472, -0.511892, 0.639452, -0.72709, -0.203713, 0.20202, -0.646239, 0.656709, -0.684538, 0.975874, -0.485663, -0.532802, -0.796725, -0.561179, 0.269435, 0.392015, 0.58954, 0.392486, 0.505881, 0.339041, 0.26686, -0.887119, 0.196433, -0.545985, -0.362444, 0.399669, -0.765126, 0.525141, 0.0522466, 0.107823, 0.175978, -0.340667, 0.405979, -0.713909, -0.676625, -0.0293496, 0.720451, 0.626531, 0.113672, 0.477993, -0.367985, -0.729336, 0.0570965, -0.378522, 0.176238, 0.0361672, -0.138304, -0.482314, -0.259547, -0.213965, -0.106268, -0.0482076, -0.224338, -0.441414, -0.843474, -0.260516, -0.492155, 0.343568, 0.352474, 0.0278727, 0.457217, 0.441536, 0.889506, -0.0786053, 0.880326, -0.35688, -0.0791317, 0.0342906, 0.322711, -0.196334, 0.211279, 0.973005, -0.699212, 0.340197, -0.312365, 0.0834203, 0.0456154, 0.658477, -0.970099, -0.451581, 0.285958, 0.0922732, 0.835695, -0.466773, 0.940173, -0.506534, 0.687951, 0.388047, -0.0884968, 0.634203, -0.955833, -0.678502, 0.4139, 0.415653, -0.126723, 0.164842, 0.50342, 0.983066, 0.391358, -0.440977, 0.507725, -0.657642, -0.991492, -0.00861785, -0.840166, -0.673838, 0.806602, 0.565984, 0.493358, -0.127149, -0.987317, 0.169574, 0.0369568, -0.8663, 0.0880454, -0.220537, -0.568348, -0.228461, 0.253723, 0.318107, 0.41864, 0.0773226, -0.439422, 0.631864, -0.256181, 0.372203, -0.379662, -0.981305, -0.792405, 0.0419201, 0.551761, -0.55863, -0.889442, -0.85873, -0.679452, 0.447527, -0.405956, -0.898795, -0.0443914, -0.0856762, 0.039977, -0.107021, -0.699377, -0.428115, 0.667364, 0.391275, 0.155552, 0.362329, -0.33765, -0.885049, 0.973766, 0.0847863, -0.996685, 0.712385, -0.941289, -0.244899, -0.0206966, 0.152346, 0.477917, 0.358664, 0.0725335, -0.928825, -0.754986, 0.946711, -0.623448, -0.297363, 0.228251, 0.215223, -0.748134, 0.105724, 0.908057, -0.28615, 0.676203, 0.948505, -0.481972, -0.496945, -0.151328, 0.626096, 0.795828, -0.510828, 0.521235, 0.389713, -0.0962975, -0.47276, 0.331007, -0.766858, -0.580875, -0.772649, 0.0939497, -0.987841, -0.648051, 0.198719, -0.136183, -0.834987, 0.374829, -0.24938, 0.672413, -0.752725, 0.946692, -0.940729, -0.839294, -0.0115107, 0.538875, 0.868061, -0.499689, -0.280684, 0.538228, -8.93376e-05, 0.498503, 0.343807, 0.363332, 0.51354, -0.927265, -0.538856, -0.556665, 0.125204, 0.305244]) dct_g = np.array([0.233221, -0.257069, -0.550398, -0.532617, 0.306023, -0.668145, 0.488208, -0.680485, -0.904723, 0.323987, -0.75011, 0.898314, -0.0316391, 0.242091, 0.82805, -0.964111, 0.187091, 0.443743, -0.006836, -0.892611, -0.116735, 0.0383497, 0.543888, -0.869287, -0.114419, 0.95436, -0.0645173, -0.341888, -0.10814, 0.486548, -0.589317, -0.657241, -0.25496, 0.879451, 0.929702, -0.490085, -0.85933, -0.752302, 0.0526757, -0.679834, 0.0354144, -0.78963, 0.683152, -0.262876, -0.152219, -0.345545, 0.427093, 0.153552, 0.743768, 0.516492, 0.673196, 0.399582, -0.22188, 0.870483, 0.211205, -0.275833, 0.0790093, -0.0898552, -0.196616, -0.519119, -0.833829, -0.156843, -0.0552401, -0.420467, -0.790062, -0.567228, 0.604037, 0.0548782, 0.337641, 0.727286, -0.512499, 0.423129, -0.472674, -0.234357, -0.837989, -0.0804745, -0.534266, 0.590705, -0.0173138, -0.993542, -0.467762, 0.329378, -0.149618, -0.626351, 0.926951, -0.731922, 0.591711, 0.892089, -0.65769, 0.212003, -0.872393, -0.305017, -0.415064, 0.0193165, 0.651673, 0.668677, 0.452886, -0.341131, 0.604169, 0.263527, -0.909319, -0.917159, -0.69337, 0.53413, -0.876537, 0.0386064, 0.8583, -0.558969, -0.590874, -0.822015, 0.399382, 0.412923, -0.0107382, -0.476443, 0.423143, -0.243268, -0.598711, -0.534785, -0.127808, -0.061003, 0.722375, 0.954477, -0.110962, -0.931514, 0.0475435, -0.936722, 0.506348, 0.183176, 0.645782, -0.340842, -0.535844, 0.0730565, -0.138763, -0.192838, 0.974649, 0.933819, 0.701061, 0.735028, -0.37679, -0.703219, 0.995197, 0.279047, -0.0638505, 0.863829, 0.381545, 0.630814, 0.0830693, 0.145613, -0.681568, 0.890501, 0.642667, -0.689172, -0.91785, -0.312773, -0.779961, -0.806923, 0.0375465, -0.95519, 0.122967, 0.713752, 0.0338162, 0.349049, 0.47435, 0.396921, -0.954401, -0.616032, 0.341982, -0.308257, -0.867419, -0.718579, 0.842235, -0.556687, -0.239771, 0.172032, -0.665374, -0.947321, 0.382253, 0.524688, 0.427391, -0.845351, 0.179619, 0.851006, 0.859396, -0.139108, 0.0119392, 0.662512, 0.838472, 0.203275, 0.447673, 0.0386733, -0.018366, -0.677161, 0.952231, 0.148721, -0.453498, 0.0628546, 0.397116, 0.334099, -0.805135, 0.0974726, 0.221475, 0.335931, -0.00974932, 0.143165, 0.177776, -0.117374, -0.697566, 0.00454946, 0.462728, -0.936687, -0.896767, 0.0441269, -0.358715, -0.929881, -0.509715, -0.785639, -0.234063, 0.103659, 0.194337, 0.222221, 0.870503, 0.544061, 0.026297, -0.0266377, 0.300255, 0.393991, -0.184925, -0.026301, -0.0404343, 0.4209, 0.0683973, -0.446706, 0.211076, -0.439793, 0.406739, 0.0609791, 0.876508, -0.529195, -0.173012, 0.194952, 0.566005, 0.840505, 0.362982, 0.642355, 0.054826, -0.539461, -0.712632, 0.799042, -0.497156, 0.303029]) @@ -43,6 +43,6 @@ if __name__ == "__main__": plt.close() # Part (g) - plt.plot(np.arange(len(dct)), dct_f) + plt.plot(np.arange(len(dct)), dct_g) plt.savefig("fig_g.png") plt.close()