![]() |
SG++-Doxygen-Documentation
|
LearnerSGD learns the data using stochastic gradient descent. More...
#include <LearnerSGD.hpp>
Public Member Functions | |
double | getAccuracy (sgpp::base::DataMatrix &testData, sgpp::base::DataVector &testLabels, double threshold) |
Computes the classification accuracy on the given dataset. | |
void | initialize () |
Initializes the SGD learner (creates grid etc.). | |
LearnerSGD (base::RegularGridConfiguration &gridConfig, base::AdaptivityConfiguration &adaptivityConfig, base::DataMatrix &pTrainData, base::DataVector &pTrainLabels, base::DataMatrix &pTestData, base::DataVector &pTestLabels, base::DataMatrix *pValData, base::DataVector *pValLabels, double lambda, double gamma, size_t batchSize, bool useValidData) | |
Constructor. | |
void | storeResults (base::DataMatrix &testDataset) |
Stores classified data, grids and function evaluations to csv files. | |
void | train (size_t maxDataPasses, std::string refType, std::string refMonitor, size_t refPeriod, double errorDeclineThreshold, size_t errorDeclineBufferSize, size_t minRefInterval) |
Implements online learning using stochastic gradient descent. | |
~LearnerSGD () | |
Destructor. | |
Public Attributes | |
sgpp::base::DataVector | avgErrors |
double | error |
Protected Member Functions | |
std::unique_ptr< base::Grid > | createRegularGrid () |
Generates a regular grid. | |
double | getAccuracy (sgpp::base::DataVector &testLabels, double threshold, sgpp::base::DataVector &predictedLabels) |
Computes the classification accuracy. | |
void | getBatchError (sgpp::base::DataMatrix &data, const sgpp::base::DataVector &labels) |
Computes error contribution for each data point of the given data set (required for predictive refinement indicator). | |
double | getError (sgpp::base::DataMatrix &data, sgpp::base::DataVector &labels, std::string errorType) |
Computes the specified error type (e.g. MSE). | |
void | predict (base::DataMatrix &testData, base::DataVector &predictedLabels) |
Predicts class labels based on the trained model. | |
void | pushToBatch (sgpp::base::DataVector &x, double y) |
Stores the last 'batchSize' processed data points if no validation data is provided. | |
Protected Attributes | |
base::AdaptivityConfiguration | adaptivityConfig |
base::DataVector | alpha |
base::DataVector | alphaAvg |
base::DataMatrix * | batchData |
base::DataVector | batchError |
base::DataVector * | batchLabels |
size_t | batchSize |
double | currentGamma |
double | gamma |
std::unique_ptr< base::Grid > | grid |
base::RegularGridConfiguration | gridConfig |
double | lambda |
base::DataMatrix & | testData |
base::DataVector & | testLabels |
base::DataMatrix & | trainData |
base::DataVector & | trainLabels |
bool | useValidData |
LearnerSGD learns the data using stochastic gradient descent.
sgpp::datadriven::LearnerSGD::LearnerSGD | ( | base::RegularGridConfiguration & | gridConfig, |
base::AdaptivityConfiguration & | adaptivityConfig, | ||
base::DataMatrix & | pTrainData, | ||
base::DataVector & | pTrainLabels, | ||
base::DataMatrix & | pTestData, | ||
base::DataVector & | pTestLabels, | ||
base::DataMatrix * | pValData, | ||
base::DataVector * | pValLabels, | ||
double | lambda, | ||
double | gamma, | ||
size_t | batchSize, | ||
bool | useValidData | ||
) |
Constructor.
gridConfig | The grid configuration |
adaptivityConfig | The refinement configuration |
pTrainData | The training dataset |
pTrainLabels | The corresponding training labels |
pTestData | The test dataset |
pTestLabels | The corresponding test labels |
pValData | The validation dataset |
pValLabels | The corresponding validation labels |
lambda | The regularization parameter |
gamma | The learning rate (i.e. step size) |
batchSize | The number of data points which are considered to compute the error contributions for predictive refinement |
useValidData | Specifies if validation data should be used for all error computations |
References batchData, batchLabels, batchSize, sgpp::base::DataMatrix::getNcols(), sgpp::base::DataMatrix::reserveAdditionalRows(), sgpp::base::DataVector::setAll(), trainData, and useValidData.
sgpp::datadriven::LearnerSGD::~LearnerSGD | ( | ) |
Destructor.
|
protected |
Generates a regular grid.
References sgpp::base::Grid::createLinearGrid(), sgpp::base::Grid::createModLinearGrid(), sgpp::base::GeneralGridConfiguration::dim_, gridConfig, sgpp::base::GeneralGridConfiguration::level_, sgpp::base::Linear, sgpp::base::ModLinear, and sgpp::base::GeneralGridConfiguration::type_.
Referenced by initialize().
double sgpp::datadriven::LearnerSGD::getAccuracy | ( | sgpp::base::DataMatrix & | testData, |
sgpp::base::DataVector & | testLabels, | ||
double | threshold | ||
) |
Computes the classification accuracy on the given dataset.
testData | The data for which class labels should be predicted |
testLabels | The corresponding actual class labels |
threshold | The decision threshold (e.g. for class labels -1, 1 -> threshold = 0) |
References getAccuracy(), sgpp::base::DataMatrix::getNrows(), predict(), testData, and testLabels.
Referenced by getAccuracy(), and train().
|
protected |
Computes the classification accuracy.
testLabels | The actual class labels |
threshold | The decision threshold (e.g. for class labels -1, 1 -> threshold = 0) |
predictedLabels | The predicted class labels |
References sgpp::base::DataVector::get(), sgpp::base::DataVector::getSize(), and testLabels.
|
protected |
Computes error contribution for each data point of the given data set (required for predictive refinement indicator).
data | The data points |
labels | The corresponding class labels |
References alphaAvg, batchError, sgpp::op_factory::createOperationMultipleEval(), sgpp::base::DataVector::get(), grid, and sgpp::base::DataVector::set().
Referenced by train().
|
protected |
Computes the specified error type (e.g. MSE).
data | The data points |
labels | The corresponding class labels |
errorType | The type of the error measurement (MSE or Hinge loss) |
References alphaAvg, sgpp::op_factory::createOperationMultipleEval(), error, sgpp::base::DataVector::get(), and grid.
Referenced by train().
void sgpp::datadriven::LearnerSGD::initialize | ( | ) |
Initializes the SGD learner (creates grid etc.).
References alpha, alphaAvg, batchError, batchSize, createRegularGrid(), sgpp::base::GeneralGridConfiguration::dim_, sgpp::base::DataMatrix::getNcols(), grid, gridConfig, and trainData.
|
protected |
Predicts class labels based on the trained model.
testData | The data for which class labels should be predicted |
predictedLabels | The predicted class labels |
References alphaAvg, sgpp::op_factory::createOperationMultipleEval(), sgpp::base::DataVector::get(), sgpp::base::DataMatrix::getNrows(), grid, sgpp::base::DataVector::set(), and testData.
Referenced by getAccuracy(), storeResults(), and train().
|
protected |
Stores the last 'batchSize' processed data points if no validation data is provided.
x | The current data point |
y | The corresponding class label |
References sgpp::base::DataMatrix::appendRow(), batchData, batchSize, sgpp::base::DataMatrix::getNrows(), and sgpp::base::DataMatrix::setRow().
Referenced by train().
void sgpp::datadriven::LearnerSGD::storeResults | ( | base::DataMatrix & | testDataset | ) |
Stores classified data, grids and function evaluations to csv files.
testDataset | Data points for which the model is evaluated |
References alphaAvg, sgpp::base::DataMatrix::appendRow(), sgpp::base::HashGridStorage::begin(), sgpp::op_factory::createOperationEval(), sgpp::base::HashGridStorage::end(), sgpp::base::DataVector::get(), sgpp::base::HashGridStorage::getCoordinates(), sgpp::base::DataMatrix::getNcols(), sgpp::base::DataMatrix::getNrows(), sgpp::base::DataMatrix::getRow(), sgpp::base::DataVector::getSize(), grid, predict(), and sgpp::base::DataVector::set().
void sgpp::datadriven::LearnerSGD::train | ( | size_t | maxDataPasses, |
std::string | refType, | ||
std::string | refMonitor, | ||
size_t | refPeriod, | ||
double | errorDeclineThreshold, | ||
size_t | errorDeclineBufferSize, | ||
size_t | minRefInterval | ||
) |
Implements online learning using stochastic gradient descent.
maxDataPasses | The number of passes over the whole training data |
refType | The refinement indicator (surplus, zero-crossings or data-based) |
refMonitor | The refinement strategy (periodic or convergence-based) |
refPeriod | The refinement interval (if periodic refinement is chosen) |
errorDeclineThreshold | The convergence threshold (if convergence-based refinement is chosen) |
errorDeclineBufferSize | The number of error measurements which are used to check convergence (if convergence-based refinement is chosen) |
minRefInterval | The minimum number of data points which have to be processed before next refinement can be scheduled (if convergence-based refinement is chosen) |
References adaptivityConfig, alpha, alphaAvg, sgpp::base::DataVector::append(), avgErrors, sgpp::base::DataVector::axpy(), batchData, batchError, batchLabels, sgpp::op_factory::createOperationMultipleEval(), currentGamma, sgpp::base::DataVector::dotProduct(), error, sgpp::base::ImpurityRefinement::free_refine(), sgpp::base::PredictiveRefinement::free_refine(), gamma, sgpp::base::DataVector::get(), getAccuracy(), getBatchError(), getError(), sgpp::base::DataMatrix::getNcols(), sgpp::base::DataMatrix::getNrows(), sgpp::base::DataVector::getPointer(), sgpp::base::DataMatrix::getRow(), sgpp::base::DataVector::getSize(), grid, lambda, mu, sgpp::base::DataVector::mult(), sgpp::base::AdaptivityConfiguration::numRefinementPoints_, sgpp::base::AdaptivityConfiguration::numRefinements_, predict(), pushToBatch(), sgpp::datadriven::RefinementMonitor::pushToBuffer(), sgpp::datadriven::RefinementMonitor::refinementsNecessary(), sgpp::base::AdaptivityConfiguration::refinementThreshold_, sgpp::datadriven::residual, sgpp::base::DataVector::resizeZero(), testData, testLabels, trainData, trainLabels, and useValidData.
|
protected |
Referenced by train().
|
protected |
Referenced by python.learner.Learner.Learner::doLearningIteration(), initialize(), python.learner.Learner.Learner::learnData(), python.learner.Learner.Learner::learnDataWithTest(), python.uq.learner.Interpolant.Interpolant::learnDataWithTest(), python.uq.dists.SGDEdist.SGDEdist::pdf(), python.uq.dists.SGDEdist.SGDEdist::toJson(), and train().
|
protected |
Referenced by getBatchError(), getError(), initialize(), predict(), storeResults(), and train().
sgpp::base::DataVector sgpp::datadriven::LearnerSGD::avgErrors |
Referenced by train().
|
protected |
Referenced by LearnerSGD(), pushToBatch(), and train().
|
protected |
Referenced by getBatchError(), initialize(), and train().
|
protected |
Referenced by LearnerSGD(), and train().
|
protected |
Referenced by initialize(), LearnerSGD(), and pushToBatch().
|
protected |
Referenced by train().
double sgpp::datadriven::LearnerSGD::error |
Referenced by python.uq.dists.J.J::discretize(), python.learner.Regressor.Regressor::evalError(), python.uq.learner.Regressor.Regressor::evalError(), getError(), python.learner.Regressor.Regressor::getL2NormError(), python.uq.learner.Regressor.Regressor::getL2NormError(), python.learner.Regressor.Regressor::getMaxError(), python.uq.learner.Regressor.Regressor::getMaxError(), python.learner.Regressor.Regressor::getMinError(), python.uq.learner.Regressor.Regressor::getMinError(), and train().
|
protected |
Referenced by train().
|
protected |
Referenced by python.uq.dists.SGDEdist.SGDEdist::__str__(), python.learner.Learner.Learner::applyData(), python.tools.Matrix::ApplyMatrix(), python.uq.dists.SGDEdist.SGDEdist::cdf(), python.uq.learner.Interpolant.Interpolant::doLearningIteration(), python.learner.Learner.Learner::doLearningIteration(), python.learner.Classifier.Classifier::evalError(), python.uq.learner.Interpolant.Interpolant::evalError(), python.tools.Matrix::generateb(), python.controller.CheckpointController.CheckpointController::generateFoldValidationJob(), getBatchError(), python.uq.learner.SimulationLearner.SimulationLearner::getCollocationNodes(), getError(), python.uq.learner.SimulationLearner.SimulationLearner::getGrid(), python.uq.learner.SimulationLearner.SimulationLearner::getLearner(), initialize(), python.learner.Learner.Learner::learnData(), python.learner.Learner.Learner::learnDataWithFolding(), python.uq.learner.Regressor.Regressor::learnDataWithFolding(), python.learner.Learner.Learner::learnDataWithTest(), python.uq.learner.Regressor.Regressor::learnDataWithTest(), python.controller.CheckpointController.CheckpointController::loadAll(), python.uq.operations.forcePositivity.operationMakePositive.OperationMakePositive::makePositive(), python.uq.operations.forcePositivity.operationMakePositiveFast.OperationMakePositiveFast::makePositive(), python.uq.dists.SGDEdist.SGDEdist::mean(), python.uq.dists.SGDEdist.SGDEdist::pdf(), python.uq.dists.SGDEdist.SGDEdist::ppf(), predict(), python.learner.Classifier.Classifier::refineGrid(), python.learner.Regressor.Regressor::refineGrid(), python.uq.learner.Regressor.Regressor::refineGrid(), python.uq.learner.SimulationLearner.SimulationLearner::refineGrid(), python.controller.CheckpointController.CheckpointController::saveGrid(), python.controller.CheckpointController.CheckpointController::setGrid(), python.learner.Learner.Learner::setGrid(), python.uq.learner.Learner.Learner::setGrid(), 
python.uq.operations.forcePositivity.localFullGridSearch.LocalFullGrid::split(), storeResults(), python.uq.dists.SGDEdist.SGDEdist::toJson(), train(), and python.uq.dists.SGDEdist.SGDEdist::var().
|
protected |
Referenced by createRegularGrid(), and initialize().
|
protected |
Referenced by train().
|
protected |
|
protected |
Referenced by getAccuracy(), getAccuracy(), and train().
|
protected |
Referenced by python.uq.dists.KDEDist.KDEDist::__init__(), python.uq.dists.LibAGFDist.LibAGFDist::cdf(), python.uq.dists.EstimatedDist.EstimatedDist::getSamples(), initialize(), LearnerSGD(), python.uq.dists.KDEDist.KDEDist::marginalize(), python.uq.dists.SGDEdist.SGDEdist::marginalize(), python.uq.dists.KDEDist.KDEDist::marginalizeToDimX(), python.uq.dists.SGDEdist.SGDEdist::marginalizeToDimX(), python.uq.dists.LibAGFDist.LibAGFDist::pdf(), python.uq.dists.LibAGFDist.LibAGFDist::ppf(), python.uq.dists.SGDEdist.SGDEdist::toJson(), and train().
|
protected |
Referenced by train().
|
protected |
Referenced by LearnerSGD(), and train().