SG++-Doxygen-Documentation
FISTA Solver

This example demonstrates the FISTA (Fast Iterative Shrinkage-Thresholding Algorithm) solver on a toy dataset, using elastic net regularization with several regularization penalties.
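For reference, the problem solved below can be written as follows. This is a sketch under stated assumptions. B is the matrix of grid basis functions evaluated at the data points. The 1/2 scaling of the data term and the reading of gamma as the share of the L1 penalty are assumptions, and the exact convention inside ElasticNetFunction may differ. The FISTA update is the standard one with a fixed step size 1/L.

```latex
\min_{\alpha}\; F(\alpha) = \tfrac{1}{2}\lVert B\alpha - y\rVert_2^2
  + \lambda\Bigl(\gamma\lVert\alpha\rVert_1 + \tfrac{1-\gamma}{2}\lVert\alpha\rVert_2^2\Bigr),
\qquad B_{ij} = \varphi_j(x_i)

% One FISTA iteration with step 1/L, where L \ge \lambda_{\max}(B^\top B), t_1 = 1, z_1 = \alpha_0:
\alpha_k = \operatorname{prox}_{1/L}\bigl(z_k - \tfrac{1}{L}B^\top(Bz_k - y)\bigr),\quad
t_{k+1} = \tfrac{1 + \sqrt{1 + 4t_k^2}}{2},\quad
z_{k+1} = \alpha_k + \tfrac{t_k - 1}{t_{k+1}}(\alpha_k - \alpha_{k-1})

% Elastic net proximal map with step s, applied componentwise:
\operatorname{prox}_s(v)_j = \frac{\operatorname{sign}(v_j)\,\max\bigl(\lvert v_j\rvert - s\lambda\gamma,\;0\bigr)}{1 + s\lambda(1-\gamma)}
```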

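#include <sgpp/base/datatypes/DataMatrix.hpp>
#include <sgpp/base/datatypes/DataVector.hpp>
#include <sgpp/base/grid/Grid.hpp>
#include <sgpp/base/operation/BaseOpFactory.hpp>
#include <sgpp/base/operation/hash/OperationMultipleEval.hpp>
#include <sgpp/solver/sle/fista/ElasticNetFunction.hpp>
#include <sgpp/solver/sle/fista/Fista.hpp>

#include <cmath>
#include <iostream>
#include <memory>
#include <random>
#include <vector>

int main(int argc, char** argv) {
using sgpp::base::DataMatrix;
using sgpp::base::DataVector;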

Create a two-dimensional modified linear grid of level 6.

auto grid = std::unique_ptr<sgpp::base::Grid>(sgpp::base::Grid::createModLinearGrid(2));
grid->getGenerator().regular(6);

Next, create a two-dimensional toy dataset with 3000 points and noisy target values.

const auto num_examples = 3000;
auto dataset = DataMatrix(num_examples, 2);
auto y = DataVector(num_examples);
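// Fixed seed for reproducibility; additive Gaussian noise with standard deviation 0.1.
auto gen = std::mt19937_64(42);
auto dist = std::normal_distribution<double>(0, 0.1);
for (auto i = 0; i < num_examples; ++i) {
// Deterministic pseudo-random inputs in [0, 1].
const double x1 = std::abs(std::sin(i));
const double x2 = std::abs(std::sin(i * x1));
// Noisy target: sinh(x1) * (x1 + x2) plus noise.
const double comb = std::sinh(x1) * (x1 + x2) + dist(gen);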
const auto row = DataVector(std::vector<double>({x1, x2}));
dataset.setRow(i, row);
y.set(i, comb);
}
std::cout << "Created grid with size: " << grid->getSize() << std::endl;
// Operator that evaluates all grid basis functions at all data points.
auto op = std::unique_ptr<sgpp::base::OperationMultipleEval>(
sgpp::op_factory::createOperationMultipleEval(*grid, dataset));

Set up the weights (one per grid point), the initial regularization penalty lambda and the elastic net parameter gamma.

DataVector weights(grid->getSize());
double lambda = 1.0;
const double gamma = 0.98;

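Iterate over five regularization penalties. Lambda is divided by 10 at the start of each pass (0.1, 0.01, ..., 1e-5), and the weights are reset to zero before each solve.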

for (int i = 0; i < 5; ++i) {
weights.setAll(0.0);
lambda /= 10;

Create an elastic net regularization function and the corresponding FISTA solver.

std::cout << "Start solving." << std::endl;
// Elastic net regularizer with penalty lambda and parameter gamma,
// and a FISTA solver parameterized by this regularizer.
auto g = sgpp::solver::ElasticNetFunction(lambda, gamma);
auto solver = sgpp::solver::Fista<decltype(g)>(g);

Solve the penalized least-squares problem, here with at most 500 iterations and a convergence threshold of 1e-4.

solver.solve(*op, weights, y, 500, 1e-4);
std::cout << "Finished solving." << std::endl;
auto prediction = DataVector(num_examples);

Finally, compute the root-mean-squared error of the prediction on the training data and print it together with lambda and the L2 norm of the weights.

op->mult(weights, prediction);
prediction.sub(y);
prediction.sqr();
const double rmse = std::sqrt((1.0 / num_examples) * prediction.sum());
std::cout << "Lambda = " << lambda << " |weights|_2 = " << weights.l2Norm()
<< std::endl;
std::cout << "Residual = " << rmse << std::endl;
}
}