This example demonstrates the FISTA solver on a toy dataset, using the elastic net regularization method with various regularization penalties.
#include <cmath>
#include <vector>
#include <random>
int main(
int argc,
char** argv) {
A class to store two-dimensional data.
Definition DataMatrix.hpp:28
A class to store one-dimensional data.
Definition DataVector.hpp:25
int main()
Definition densityMultiplication.cpp:22
Create a two-dimensional grid with level 6.
grid->getGenerator().regular(6);
static Grid * createModLinearGrid(size_t dim)
creates a modified linear grid
Definition Grid.cpp:99
std::unique_ptr< sgpp::base::Grid > grid(nullptr)
Then create a two-dimensional dataset.
// Build a synthetic regression problem: num_examples samples with two
// features each (rows of `dataset`) and one noisy target per sample (`y`).
const auto num_examples = 3000;
auto dataset = DataMatrix(num_examples, 2);
auto y = DataVector(num_examples);
// Fixed seed so that the generated noise is reproducible between runs.
auto gen = std::mt19937_64(42);
// Gaussian noise with mean 0 and standard deviation 0.1, added to each target.
auto dist = std::normal_distribution<double>(0, 0.1);
for (auto i = 0; i < num_examples; ++i) {
// Both features are absolute values of a sine, so they lie in [0, 1].
const double x1 = std::abs(std::sin(i));
const double x2 = std::abs(std::sin(i * x1));
const double comb = std::sinh(x1) * (x1 + x2) + dist(gen);
const auto row = DataVector(std::vector<double>({x1, x2}));
// Store the feature row; without this the data matrix is never filled and
// `row` is constructed only to be discarded.
dataset.setRow(i, row);
y.set(i, comb);
}
std::cout <<
"Created grid with size: " <<
grid->getSize() << std::endl;
sgpp::datadriven::Dataset dataset
Definition multHPX.cpp:42
base::OperationMultipleEval * createOperationMultipleEval(base::Grid &grid, base::DataMatrix &dataset)
Factory method, returning an OperationMultipleEval for the grid at hand.
Definition BaseOpFactory.cpp:595
Set up the weights for the grid.
DataVector weights(
grid->getSize());
const double gamma = 0.98;
base::DataVector & lambda
Definition AugmentedLagrangian.cpp:79
Iterate over a few regularization penalties.
for (int i = 0; i < 5; ++i) {
weights.setAll(0.0);
Create an elastic net function and the corresponding FISTA solver.
std::cout << "Start solving." << std::endl;
base::VectorFunction & g
Definition AugmentedLagrangian.cpp:76
The ElasticNetFunction class.
Definition ElasticNetFunction.hpp:23
The Fast Iterative Shrinkage-Thresholding Algorithm (FISTA) is a solver for least-squares problems.
Definition Fista.hpp:31
Here we solve the penalised linear system.
solver.solve(*op, weights, y, 500, 1e-4);
std::cout << "Finished solving." << std::endl;
auto prediction = DataVector(num_examples);
Finally, we calculate the root-mean-squared error of the prediction and print it.
// Fill `prediction` from the current weights, then turn it in place into
// squared residuals against the targets `y`.
op->mult(weights, prediction);
prediction.sub(y);
prediction.sqr();
// Root-mean-squared error: square root of the mean of the squared residuals.
const double mean_squared_residual = (1.0 / num_examples) * prediction.sum();
const double residual = std::sqrt(mean_squared_residual);
// Report the penalty used, the size of the resulting weight vector and the error.
std::cout << "Lambda = " << lambda << " |weights|_2 = " << weights.l2Norm() << std::endl;
std::cout << "Residual = " << residual << std::endl;
}
}