diff --git a/main.cpp b/main.cpp index acae7de332034a5f9b357d8b6a83a1a2df4ce0b1..45300a0f6b82b4dfcfd0b6144f4deae96dfc8fdd 100644 --- a/main.cpp +++ b/main.cpp @@ -45,11 +45,18 @@ Matrix compute_dct_matrix(uint32_t n_samples) { } //-------------------------------------------------------------------------------------------------- -int main() { - //Vector x = Vector::Random(5); - //std::cout << x << '\n'; - +// Selects certain columns of a matrix. +Matrix matrix_subset(Matrix const& matrix, std::vector<uint32_t> const& subset_indices) { + Matrix subset(matrix.rows(), subset_indices.size()); + for (uint32_t i = 0; i < subset_indices.size(); ++i) { + auto const index = subset_indices[i]; + subset.col(i) = matrix.col(index); + } + return subset; +} +//-------------------------------------------------------------------------------------------------- +int main() { constexpr Scalar f1 = 697; constexpr Scalar f2 = 1209; @@ -67,7 +74,7 @@ int main() { python_print("dct", dct); // Part (c) - Vector const recovered_sample_values = dct_matrix.transpose() * dct; + Vector recovered_sample_values = dct_matrix.transpose() * dct; python_print("recovered_sample_values", recovered_sample_values); // Part (d) @@ -93,5 +100,21 @@ int main() { python_print("subset_sample_times", subset_sample_times); python_print("subset_sample_values", subset_sample_values); + // Part (e) + Matrix const subset_dct_matrix = matrix_subset(dct_matrix, subset_indices); + Vector recovered_dct = Vector::Random(n_samples); + recovered_sample_values = dct_matrix.transpose() * recovered_dct; + Vector subset_recovered_sample_values(n_subsamples); + for (uint32_t i = 0; i < n_subsamples; ++i) { + auto const index = subset_indices[i]; + subset_recovered_sample_values[i] = recovered_sample_values[index]; + } + Vector subset_differences = subset_sample_values - subset_recovered_sample_values; + Scalar loss = subset_differences.squaredNorm(); + Vector gradient = -2 * subset_dct_matrix * subset_differences; + + std::cout << loss << '\n'; + std::cout << gradient << '\n'; + return 0; }