SG++-Doxygen-Documentation
Loading...
Searching...
No Matches
Regression Learner

This example demonstrates sparse grid regression learning.

getLearner

Parameters
dimension: the number of dimensions
regularizationConfig: the regularization configuration passed on to the learner
Returns
a sparse grid regression learner
size_t dimension, sgpp::datadriven::RegularizationConfiguration regularizationConfig) {
gridConfig.dim_ = dimension;
gridConfig.level_ = 3;
gridConfig.type_ = sgpp::base::GridType::ModLinear;
// gridConfig.type_ = sgpp::base::GridType::ModNakBspline;
gridConfig.maxDegree_ = 3;
auto solverConfig = sgpp::solver::SLESolverConfiguration();
solverConfig.type_ = sgpp::solver::SLESolverType::CG;
solverConfig.maxIterations_ = 1000;
solverConfig.eps_ = 1e-8;
solverConfig.threshold_ = 1e-5;
return sgpp::datadriven::RegressionLearner(gridConfig, adaptivityConfig, solverConfig,
solverConfig, regularizationConfig);
}
The RegressionLearner class Solves a regression problem with continuous target vector.
Definition RegressionLearner.hpp:31
sgpp::base::AdaptivityConfiguration adaptivityConfig
Definition multHPX.cpp:37
structure that can be used by application to define adaptivity strategies
Definition Grid.hpp:143
size_t numRefinementPoints_
max. number of points to be refined
Definition Grid.hpp:157
size_t numRefinements_
number of refinements
Definition Grid.hpp:145
structure that can be used by applications to cluster regular grid information
Definition Grid.hpp:111
Definition RegularizationConfiguration.hpp:17
Definition TypesSolver.hpp:19

showRegularizationConfiguration

Parameters
regularizationConfig: the regularization configuration to describe
Returns
type of the regularization method as string
/**
 * Formats a regularization configuration as a single tab-separated line.
 *
 * @param regularizationConfig the configuration to describe
 * @return the regularization type, lambda and exponent base as a string
 */
std::string showRegularizationConfiguration(
    const sgpp::datadriven::RegularizationConfiguration& regularizationConfig) {
  std::ostringstream ss;
  const auto regType = regularizationConfig.type_;
  // Map the regularization type to its label; anything else is "unknown".
  if (regType == sgpp::datadriven::RegularizationType::Diagonal) {
    ss << "type: DiagonalMatrix\t";
  } else if (regType == sgpp::datadriven::RegularizationType::Identity) {
    ss << "type: IdentityMatrix\t";
  } else if (regType == sgpp::datadriven::RegularizationType::Laplace) {
    ss << "type: Laplace\t";
  } else {
    ss << "type: unknown\t";
  }
  // exponentBase_ is printed under the label "multiplicationFactor".
  ss << "lambda: " << regularizationConfig.lambda_
     << "\tmultiplicationFactor: " << regularizationConfig.exponentBase_;
  return ss.str();
}
RegularizationType type_
Definition RegularizationConfiguration.hpp:18
double exponentBase_
Definition RegularizationConfiguration.hpp:21
double lambda_
Definition RegularizationConfiguration.hpp:19

gridSearch performs a hyper-parameter grid search over configs using a holdout validation set.

Parameters
configs: the regularization configs that will be tried
dimension: the number of dimensions
xTrain: the training predictors
yTrain: the training target
xValidation: the validation predictors
yValidation: the validation target
Returns
best found regularization configuration
std::vector<sgpp::datadriven::RegularizationConfiguration> configs, size_t dimension,
sgpp::base::DataMatrix& xValidation, sgpp::base::DataVector& yValidation) {
double bestMSE = std::numeric_limits<double>::max();
for (const auto& config : configs) {
// Step 1: Create a learner
auto learner = getLearner(dimension, config);
// Step 2: Train it with the hyperparameter
learner.train(xTrain, yTrain);
// Step 3: Evaluate accuracy
const double curMSE = learner.getMSE(xValidation, yValidation);
std::cout << "Tested parameters are\n" << showRegularizationConfiguration(config) << ".\n";
if (curMSE < bestMSE) {
std::cout << "Better! RMSE is now " << std::sqrt(curMSE) << std::endl;
bestConfig = config;
bestMSE = curMSE;
} else {
std::cout << "Worse! RMSE is now " << std::sqrt(curMSE) << std::endl;
}
}
std::cout << "gridSearch finished with parameters " << showRegularizationConfiguration(bestConfig)
<< std::endl;
return bestConfig;
}
A class to store two-dimensional data.
Definition DataMatrix.hpp:28
A class to store one-dimensional data.
Definition DataVector.hpp:25

getConfigs

Returns
some regularization configurations for seven lambdas between 1 and 0.000001 and for exponent bases 1.0, 0.5, 0.25, 0.125
std::vector<sgpp::datadriven::RegularizationConfiguration> getConfigs() {
decltype(getConfigs()) result;
std::vector<double> lambdas = {1.0, 0.1, 0.01, 0.001, 0.0001, 0.00001, 0.000001};
std::vector<double> exponentBases = {1.0, 0.5, 0.25, 0.125};
for (const auto lambda : lambdas) {
// Identity
const auto regularizationType = sgpp::datadriven::RegularizationType::Identity;
auto regularizationConfig = sgpp::datadriven::RegularizationConfiguration();
regularizationConfig.type_ = regularizationType;
regularizationConfig.lambda_ = lambda;
regularizationConfig.exponentBase_ = 0.25;
result.push_back(regularizationConfig);
{
// Laplace
const auto regularizationType = sgpp::datadriven::RegularizationType::Laplace;
auto regularizationConfig = sgpp::datadriven::RegularizationConfiguration();
regularizationConfig.type_ = regularizationType;
regularizationConfig.lambda_ = lambda;
regularizationConfig.exponentBase_ = 0.25;
result.push_back(regularizationConfig);
}
// Diagonal
for (const auto exponentBase : exponentBases) {
const auto regularizationType = sgpp::datadriven::RegularizationType::Diagonal;
auto regularizationConfig = sgpp::datadriven::RegularizationConfiguration();
regularizationConfig.type_ = regularizationType;
regularizationConfig.lambda_ = lambda;
regularizationConfig.exponentBase_ = exponentBase;
result.push_back(regularizationConfig);
}
}
return result;
}
base::DataVector & lambda
Definition AugmentedLagrangian.cpp:79

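main is an example for the RegressionLearner. It performs a grid search for the best regularization hyper-parameters (identity, Laplace and diagonal Tikhonov regularization) on the Friedman3 dataset, then retrains with the best configuration and reports its MSE on the test set.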

int main(int argc, char** argv) {
const auto filenameTrain = std::string("../datasets/friedman/friedman3_10k_train.arff");
const auto filenameValidation = std::string("../datasets/friedman/friedman3_10k_validation.arff");
const auto filenameTest = std::string("../datasets/friedman/friedman3_10k_test.arff");
auto dataTrain = sgpp::datadriven::ARFFTools::readARFFFromFile(filenameTrain);
std::cout << "Read file " << filenameTrain << "." << std::endl;
auto xTrain = dataTrain.getData();
auto yTrain = dataTrain.getTargets();
const auto dimensions = dataTrain.getDimension();
auto dataValidation = sgpp::datadriven::ARFFTools::readARFFFromFile(filenameValidation);
std::cout << "Read file " << filenameValidation << "." << std::endl;
auto xValidation = dataValidation.getData();
auto yValidation = dataValidation.getTargets();
const auto configs = getConfigs();
const auto bestConfig = gridSearch(configs, dimensions, xTrain, yTrain, xValidation, yValidation);
auto dataTest = sgpp::datadriven::ARFFTools::readARFFFromFile(filenameTest);
std::cout << "Read file " << filenameTest << "." << std::endl;
auto xTest = dataTest.getData();
auto yTest = dataTest.getTargets();
auto learner = getLearner(dimensions, bestConfig);
learner.train(xTrain, yTrain);
const auto MSETest = learner.getMSE(xTest, yTest);
std::cout << "Best config got a testing MSE of " << MSETest << "!" << std::endl;
}
static Dataset readARFFFromFile(const std::string &filename, bool hasTargets=true, size_t instanceCutoff=-1, std::vector< size_t > selectedCols=std::vector< size_t >(), std::vector< double > selectedTargets=std::vector< double >())
Wrapper from input type: File.
Definition ARFFTools.cpp:48
int main()
Definition densityMultiplication.cpp:22