SG++-Doxygen-Documentation
sgpp::datadriven::ModelFittingBase Class Reference (abstract)

Base class for arbitrary machine learning models based on adaptive sparse grids.

#include <ModelFittingBase.hpp>

Inheritance diagram for sgpp::datadriven::ModelFittingBase:
sgpp::datadriven::ModelFittingBaseSingleGrid
sgpp::datadriven::ModelFittingClassification
sgpp::datadriven::ModelFittingDensityEstimation
sgpp::datadriven::ModelFittingDensityRatioEstimation
sgpp::datadriven::ModelFittingLeastSquares
sgpp::datadriven::ModelFittingDensityDifferenceEstimationCG
sgpp::datadriven::ModelFittingDensityDifferenceEstimationOnOff
sgpp::datadriven::ModelFittingDensityEstimationCG
sgpp::datadriven::ModelFittingDensityEstimationCombi
sgpp::datadriven::ModelFittingDensityEstimationOnOff
sgpp::datadriven::ModelFittingDensityEstimationOnOffParallel

Public Member Functions

virtual bool adapt ()=0
 Improve accuracy of the model on the given training data by adaptive refinement of the grid.
 
virtual double computeResidual (DataMatrix &validationData) const =0
 Compute an implementation-defined residual that measures the fit of the model.
 
virtual double evaluate (const DataVector &sample)=0
 Evaluate the fitted model at a single data point.
 
virtual void evaluate (DataMatrix &samples, DataVector &results)=0
 Evaluate the fitted model on a set of data points.
 
virtual void fit (Dataset &dataset)=0
 Fit the grid to the dataset by determining the weights of an initial grid.
 
virtual void fit (Dataset &datasetP, Dataset &datasetQ)=0
 
Dataset * getDataset ()
 
FitterConfiguration & getFitterConfiguration ()
 Get or set the configuration of the fitter object.
 
const FitterConfiguration & getFitterConfiguration () const
 Get the configuration of the fitter object.
 
virtual Grid & getGrid ()
 Return the learned grid.
 
virtual std::shared_ptr< BlacsProcessGrid > getProcessGrid () const
 
virtual DataVector & getSurpluses ()
 Return the learned hierarchical surpluses.
 
 ModelFittingBase ()
 Default constructor.
 
 ModelFittingBase (const ModelFittingBase &rhs)=delete
 Copy constructor (deleted), because not all member variables can be deep-copied yet.
 
 ModelFittingBase (ModelFittingBase &&rhs)=default
 Move constructor.
 
ModelFittingBase & operator= (const ModelFittingBase &rhs)=delete
 Copy assignment operator (deleted), because not all member variables can be deep-copied yet.
 
ModelFittingBase & operator= (ModelFittingBase &&rhs)=default
 Move assignment operator.
 
virtual void reset ()=0
 Resets the state of the entire model.
 
virtual void resetTraining ()=0
 Resets any trained representations of the model, but does not reset the entire state.
 
virtual void update (Dataset &dataset)=0
 Train the grid of an existing model with new samples.
 
virtual void update (Dataset &datasetP, Dataset &datasetQ)=0
 
virtual void updateRegularization (double lambda)=0
 Updates the regularization parameter lambda of the underlying model.
 
virtual ~ModelFittingBase ()=default
 Virtual destructor.
 

Public Attributes

bool verboseSolver
 Whether the solver produces output.
 

Protected Member Functions

Grid * buildGrid (const sgpp::base::GeneralGridConfiguration &gridConfig) const
 Factory member function that generates a grid from configuration.
 
Grid * buildGrid (const sgpp::base::GeneralGridConfiguration &gridConfig, const GeometryConfiguration &geometryConfig) const
 Factory member function that generates a grid from configuration.
 
SLESolver * buildSolver (const SLESolverConfiguration &config) const
 Factory member function to build the solver for the least squares regression problem according to the config.
 
std::set< std::set< size_t > > getInteractions (const GeometryConfiguration &geometryConfig)
 
void reconfigureSolver (SLESolver &solver, const SLESolverConfiguration &config) const
 Configure solver based on the desired configuration.
 

Protected Attributes

std::unique_ptr< FitterConfiguration > config
 Configuration object for the fitter.
 
Dataset * dataset
 Pointer to sgpp::datadriven::Dataset.
 
Dataset * extraDataset
 
std::unique_ptr< std::set< std::set< size_t > > > interactions
 
std::unique_ptr< SLESolver > solver
 Solver for the learning problem.
 

Detailed Description

Base class for arbitrary machine learning models based on adaptive sparse grids.

A model aims to generalize high-dimensional training data by means of sparse grids. It can be trained on training data, its accuracy can be improved by exploiting the adaptivity of sparse grids, and its underlying grid(s) can be retrained on other data. Once trained, a model can be evaluated on unseen data.
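The lifecycle described above (fit, refine adaptively, retrain, evaluate) can be illustrated with a small self-contained sketch. It does not use SG++: ToyDataset, ToyModelBase and ToyPiecewiseConstant are hypothetical names that only mimic the shape of the pure virtual interface, and the toy "refinement" simply doubles the number of cells of a piecewise-constant regression on [0, 1] instead of refining a sparse grid.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical stand-in for a dataset: (x, y) pairs with x in [0, 1].
using ToyDataset = std::vector<std::pair<double, double>>;

// Hypothetical interface mirroring the lifecycle of ModelFittingBase
// (fit / adapt / update / evaluate). It is NOT the SG++ API.
class ToyModelBase {
 public:
  virtual ~ToyModelBase() = default;
  virtual void fit(const ToyDataset& data) = 0;
  virtual bool adapt() = 0;  // true if a refinement was performed
  virtual void update(const ToyDataset& data) = 0;
  virtual double evaluate(double x) const = 0;
};

// Toy model: piecewise-constant regression on [0, 1]. "adapt" uniformly
// doubles the number of cells, a crude analogue of grid refinement.
class ToyPiecewiseConstant : public ToyModelBase {
 public:
  explicit ToyPiecewiseConstant(std::size_t maxCells) : maxCells_(maxCells) {
    assert(maxCells >= 1);
  }

  void fit(const ToyDataset& data) override {
    data_ = data;
    cells_ = 1;  // start from the coarsest "grid"
    train();
  }

  bool adapt() override {
    if (cells_ * 2 > maxCells_) return false;  // refinement budget exhausted
    cells_ *= 2;
    train();  // retrain on the stored training data
    return true;
  }

  void update(const ToyDataset& data) override {
    data_ = data;  // keep the current resolution, retrain on the new samples
    train();
  }

  double evaluate(double x) const override {
    assert(!values_.empty() && "fit() must be called first");
    return values_[cellOf(x)];
  }

 private:
  std::size_t cellOf(double x) const {
    if (x <= 0.0) return 0;
    auto c = static_cast<std::size_t>(x * static_cast<double>(cells_));
    return c < cells_ ? c : cells_ - 1;  // clamp x >= 1 into the last cell
  }

  // Each cell value is the mean target of the samples falling into it (0 if empty).
  void train() {
    std::vector<double> sum(cells_, 0.0);
    std::vector<std::size_t> count(cells_, 0);
    for (const auto& [x, y] : data_) {
      sum[cellOf(x)] += y;
      ++count[cellOf(x)];
    }
    values_.assign(cells_, 0.0);
    for (std::size_t c = 0; c < cells_; ++c)
      if (count[c] > 0) values_[c] = sum[c] / static_cast<double>(count[c]);
  }

  std::size_t maxCells_;
  std::size_t cells_ = 1;
  ToyDataset data_;
  std::vector<double> values_;
};
```

A typical loop calls fit() once, then adapt() until it returns false (or until a validation score stops improving), and finally evaluate(). Here update() keeps the current resolution and retrains on the new samples, loosely matching the documented update(Dataset &dataset).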

Constructor & Destructor Documentation

◆ ModelFittingBase() [1/3]

sgpp::datadriven::ModelFittingBase::ModelFittingBase ( )

Default constructor.

◆ ModelFittingBase() [2/3]

sgpp::datadriven::ModelFittingBase::ModelFittingBase ( const ModelFittingBase & rhs)
delete

Copy constructor (deleted), because not all member variables can be deep-copied yet.

Parameters
rhs: const reference to the ModelFittingBase object to copy from.

◆ ModelFittingBase() [3/3]

sgpp::datadriven::ModelFittingBase::ModelFittingBase ( ModelFittingBase &&  rhs)
default

Move constructor.

Parameters
rhs: r-value reference to the ModelFittingBase object to move from.

◆ ~ModelFittingBase()

virtual sgpp::datadriven::ModelFittingBase::~ModelFittingBase ( )
virtual, default

Virtual destructor.
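The special-member policy documented above (copying deleted, moving defaulted) can be checked at compile time. The following sketch uses a hypothetical ToyFitterBase that, like ModelFittingBase, owns its members through std::unique_ptr; it is not the SG++ class, and ToyConfig, ToySolver and makeFitter are illustrative names only.

```cpp
#include <cassert>
#include <memory>
#include <type_traits>
#include <utility>

// Hypothetical stand-ins for owned members (not SG++ types).
struct ToyConfig { double lambda = 0.0; };
struct ToySolver { int maxIterations = 100; };

// Mirrors the documented policy: copy operations deleted, move operations defaulted.
class ToyFitterBase {
 public:
  ToyFitterBase()
      : config(std::make_unique<ToyConfig>()), solver(std::make_unique<ToySolver>()) {}
  ToyFitterBase(const ToyFitterBase&) = delete;
  ToyFitterBase& operator=(const ToyFitterBase&) = delete;
  ToyFitterBase(ToyFitterBase&&) = default;
  ToyFitterBase& operator=(ToyFitterBase&&) = default;
  virtual ~ToyFitterBase() = default;

  std::unique_ptr<ToyConfig> config;  // ownership is transferred on move
  std::unique_ptr<ToySolver> solver;
};

static_assert(!std::is_copy_constructible_v<ToyFitterBase>, "copy construction is deleted");
static_assert(!std::is_copy_assignable_v<ToyFitterBase>, "copy assignment is deleted");
static_assert(std::is_move_constructible_v<ToyFitterBase>, "move construction is allowed");
static_assert(std::is_move_assignable_v<ToyFitterBase>, "move assignment is allowed");

// A fitter can be returned by value (moved or elided), never copied.
inline ToyFitterBase makeFitter(double lambda) {
  ToyFitterBase f;
  f.config->lambda = lambda;
  return f;
}
```

Because the copy operations are deleted, such an object can be transferred (for example out of a factory function or into a container via std::move) but never duplicated; after a move, the source's unique_ptr members are null.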

Member Function Documentation

◆ adapt()

virtual bool sgpp::datadriven::ModelFittingBase::adapt ( )
pure virtual

Improve accuracy of the model on the given training data by adaptive refinement of the grid.

Returns
true if refinement was performed, else false.

Implemented in sgpp::datadriven::ModelFittingClassification, sgpp::datadriven::ModelFittingDensityEstimation, sgpp::datadriven::ModelFittingDensityEstimationCombi, sgpp::datadriven::ModelFittingDensityRatioEstimation, and sgpp::datadriven::ModelFittingLeastSquares.

◆ buildGrid() [1/2]

◆ buildGrid() [2/2]

Grid * sgpp::datadriven::ModelFittingBase::buildGrid ( const sgpp::base::GeneralGridConfiguration & gridConfig,
const GeometryConfiguration & geometryConfig 
) const
protected

Factory member function that generates a grid from configuration.

Parameters
gridConfig: configuration for the grid object
geometryConfig: configuration for the geometry parameters
Returns
new grid object that is owned by the caller.

References sgpp::datadriven::GridFactory::createGrid(), sgpp::datadriven::GridFactory::getInteractions(), interactions, and sgpp::datadriven::GeometryConfiguration::stencils_.

◆ buildSolver()

SLESolver * sgpp::datadriven::ModelFittingBase::buildSolver ( const SLESolverConfiguration & config) const
protected

Factory member function to build the solver for the least squares regression problem according to the config.

Parameters
config: configuration for the solver object

References sgpp::solver::SLESolverConfiguration::eps_, sgpp::solver::SLESolverConfiguration::maxIterations_, and sgpp::solver::SLESolverConfiguration::type_.

Referenced by sgpp::datadriven::ModelFittingDensityRatioEstimation::ModelFittingDensityRatioEstimation(), and sgpp::datadriven::ModelFittingLeastSquares::ModelFittingLeastSquares().
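buildSolver reads the solver configuration and returns a raw pointer, presumably owned by the caller as buildGrid documents for its result. A minimal sketch of such a factory follows; only the field names type_, eps_ and maxIterations_ come from the references listed above, while ToySolverType, the solver classes and buildToySolver are hypothetical. Since the protected solver member is a std::unique_ptr, wrapping the returned pointer immediately is a natural way to take ownership.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <stdexcept>

// Hypothetical stand-ins; only the field names are taken from the documentation.
enum class ToySolverType { CG, BiCGSTAB };

struct ToySolverConfiguration {
  ToySolverType type_ = ToySolverType::CG;
  double eps_ = 1e-10;
  std::size_t maxIterations_ = 100;
};

class ToySolver {
 public:
  ToySolver(std::size_t maxIterations, double eps) : maxIterations_(maxIterations), eps_(eps) {}
  virtual ~ToySolver() = default;
  std::size_t maxIterations_;
  double eps_;
};

class ToyCG : public ToySolver {
 public:
  using ToySolver::ToySolver;  // inherit the (maxIterations, eps) constructor
};

class ToyBiCGSTAB : public ToySolver {
 public:
  using ToySolver::ToySolver;
};

// Factory in the style of buildSolver: dispatch on type_ and return a new
// object that the caller must delete (or hand to a smart pointer).
ToySolver* buildToySolver(const ToySolverConfiguration& config) {
  switch (config.type_) {
    case ToySolverType::CG:
      return new ToyCG(config.maxIterations_, config.eps_);
    case ToySolverType::BiCGSTAB:
      return new ToyBiCGSTAB(config.maxIterations_, config.eps_);
  }
  throw std::invalid_argument("unknown solver type");
}
```

For example, `std::unique_ptr<ToySolver> solver(buildToySolver(cfg));` takes ownership right away, so the solver is released even if a later step throws.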

◆ computeResidual()

virtual double sgpp::datadriven::ModelFittingBase::computeResidual ( DataMatrix & validationData) const
pure virtual

Compute an implementation-defined residual that measures how well the model fits the validation data.

For density estimation, this is || R * alpha_lambda - b_val ||_2.

This is useful for unsupervised learning models, where ordinary evaluation is not possible because there are no targets.

Parameters
validationData: matrix of validation data
Returns
the residual score

Implemented in sgpp::datadriven::ModelFittingClassification, sgpp::datadriven::ModelFittingDensityDifferenceEstimationCG, sgpp::datadriven::ModelFittingDensityDifferenceEstimationOnOff, sgpp::datadriven::ModelFittingDensityEstimationCG, sgpp::datadriven::ModelFittingDensityEstimationCombi, sgpp::datadriven::ModelFittingDensityEstimationOnOff, sgpp::datadriven::ModelFittingDensityEstimationOnOffParallel, sgpp::datadriven::ModelFittingDensityRatioEstimation, sgpp::datadriven::ModelFittingLeastSquares, and sgpp::datadriven::ModelFittingDensityEstimation.

Referenced by sgpp::datadriven::ResidualScore::measure().
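The density-estimation residual || R * alpha_lambda - b_val ||_2 quoted above is a plain Euclidean norm of a matrix-vector residual. The sketch below computes it for a dense matrix stored as nested std::vector; the Matrix and Vector aliases and residualNorm are hypothetical, and how SG++ assembles R and b_val from the grid and the validation data is not modeled here.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<double>>;  // dense, row-major (assumption)
using Vector = std::vector<double>;

// Computes || R * alpha - b ||_2 for a dense matrix R.
double residualNorm(const Matrix& R, const Vector& alpha, const Vector& b) {
  assert(R.size() == b.size());
  double sumSquares = 0.0;
  for (std::size_t i = 0; i < R.size(); ++i) {
    assert(R[i].size() == alpha.size());
    double r = -b[i];  // i-th component of R * alpha - b
    for (std::size_t j = 0; j < alpha.size(); ++j) r += R[i][j] * alpha[j];
    sumSquares += r * r;
  }
  return std::sqrt(sumSquares);
}
```

With R = 2I, alpha = (1, 1) and b_val = (5, 6), the residual vector is (-3, -4) and its norm is 5.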

◆ evaluate() [1/2]

◆ evaluate() [2/2]

virtual void sgpp::datadriven::ModelFittingBase::evaluate ( DataMatrix & samples,
DataVector & results 
)
pure virtual
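One straightforward way for a derived class to provide the batch overload of evaluate is to loop over the rows and delegate to the single-sample overload. This is a hypothetical sketch (Sample, Samples, ToyEvaluator and SumModel are illustrative names), not necessarily how the SG++ subclasses do it; they may use vectorized operations instead. Note the using-declaration in the derived class: overriding one overload of evaluate would otherwise hide the other.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Sample = std::vector<double>;   // hypothetical stand-in for DataVector
using Samples = std::vector<Sample>;  // hypothetical stand-in for DataMatrix, one row per sample

class ToyEvaluator {
 public:
  virtual ~ToyEvaluator() = default;
  virtual double evaluate(const Sample& sample) = 0;

  // Batch overload: size the output, then evaluate row by row.
  virtual void evaluate(const Samples& samples, std::vector<double>& results) {
    results.resize(samples.size());
    for (std::size_t i = 0; i < samples.size(); ++i) results[i] = evaluate(samples[i]);
  }
};

// Toy model f(x) = sum of the coordinates of x.
class SumModel : public ToyEvaluator {
 public:
  using ToyEvaluator::evaluate;  // keep the batch overload visible
  double evaluate(const Sample& sample) override {
    double s = 0.0;
    for (double v : sample) s += v;
    return s;
  }
};
```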

◆ fit() [1/2]

◆ fit() [2/2]

◆ getDataset()

◆ getFitterConfiguration() [1/2]

FitterConfiguration & sgpp::datadriven::ModelFittingBase::getFitterConfiguration ( )

Get or set the configuration of the fitter object.

Returns
configuration of the fitter object

References config.
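The non-const getFitterConfiguration() provides the "set" half of "get or set": it returns a reference, so callers modify the stored configuration in place, while the const overload only allows reading. ToyFitter and ToyFitterConfiguration below are hypothetical stand-ins; whether such a change affects an already fitted model depends on the concrete subclass (for lambda, updateRegularization is the documented entry point).

```cpp
#include <cassert>
#include <memory>

// Hypothetical stand-in for FitterConfiguration.
struct ToyFitterConfiguration {
  double lambda = 1e-6;
};

class ToyFitter {
 public:
  ToyFitter() : config(std::make_unique<ToyFitterConfiguration>()) {}

  // Non-const overload: callers can read and modify the configuration.
  ToyFitterConfiguration& getFitterConfiguration() { return *config; }

  // Const overload: read-only access on const objects.
  const ToyFitterConfiguration& getFitterConfiguration() const { return *config; }

 protected:
  std::unique_ptr<ToyFitterConfiguration> config;  // mirrors the protected config member
};
```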

◆ getFitterConfiguration() [2/2]

◆ getGrid()

◆ getInteractions()

std::set< std::set< size_t > > sgpp::datadriven::ModelFittingBase::getInteractions ( const GeometryConfiguration & geometryConfig)
protected

◆ getProcessGrid()

virtual std::shared_ptr< BlacsProcessGrid > sgpp::datadriven::ModelFittingBase::getProcessGrid ( ) const
inline, virtual
Returns
the BLACS process grid, useful if the fitter uses ScaLAPACK

Reimplemented in sgpp::datadriven::ModelFittingDensityEstimationOnOffParallel.

◆ getSurpluses()

virtual DataVector & sgpp::datadriven::ModelFittingBase::getSurpluses ( )
inline, virtual

Return the learned hierarchical surpluses.

Reimplemented in sgpp::datadriven::ModelFittingBaseSingleGrid.

Referenced by sgpp::datadriven::ModelFittingClassification::adapt().

◆ operator=() [1/2]

ModelFittingBase & sgpp::datadriven::ModelFittingBase::operator= ( const ModelFittingBase & rhs)
delete

Copy assignment operator (deleted), because not all member variables can be deep-copied yet.

Parameters
rhs: const reference to the ModelFittingBase object to copy from.
Returns
reference to *this with updated values.

◆ operator=() [2/2]

ModelFittingBase & sgpp::datadriven::ModelFittingBase::operator= ( ModelFittingBase &&  rhs)
default

Move assignment operator.

Parameters
rhs: r-value reference to the ModelFittingBase object to move from.
Returns
reference to *this with updated values.

◆ reconfigureSolver()

void sgpp::datadriven::ModelFittingBase::reconfigureSolver ( SLESolver & solver,
const SLESolverConfiguration & config 
) const
protected

Configure the solver according to the desired configuration.

Parameters
solver: the solver object to be modified.
config: the configuration to apply to the solver.

References sgpp::solver::SLESolverConfiguration::eps_, sgpp::solver::SLESolverConfiguration::maxIterations_, and solver.
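reconfigureSolver adjusts an existing solver instead of building a new one. Judging from its reference list, which mentions eps_ and maxIterations_ but not type_, it presumably updates only the tolerance and the iteration limit. The sketch below shows that pattern; ToySolver, its setter names and reconfigureToySolver are hypothetical, and only the configuration field names come from the documentation.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical stand-ins; only eps_ and maxIterations_ are documented field names.
struct ToySolverConfiguration {
  double eps_ = 1e-10;
  std::size_t maxIterations_ = 100;
};

class ToySolver {
 public:
  void setEpsilon(double eps) { eps_ = eps; }
  void setMaxIterations(std::size_t n) { maxIterations_ = n; }
  double epsilon() const { return eps_; }
  std::size_t maxIterations() const { return maxIterations_; }

 private:
  double eps_ = 1e-6;
  std::size_t maxIterations_ = 10;
};

// In the spirit of reconfigureSolver: update tolerance and iteration limit in place.
// The solver type is left unchanged; switching type would need a newly built solver.
void reconfigureToySolver(ToySolver& solver, const ToySolverConfiguration& config) {
  solver.setEpsilon(config.eps_);
  solver.setMaxIterations(config.maxIterations_);
}
```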

◆ reset()

◆ resetTraining()

◆ update() [1/2]

◆ update() [2/2]

◆ updateRegularization()

Member Data Documentation

◆ config

std::unique_ptr<FitterConfiguration> sgpp::datadriven::ModelFittingBase::config
protected

Configuration object for the fitter.

Referenced by sgpp::datadriven::ModelFittingClassification::adapt(), sgpp::datadriven::ModelFittingDensityEstimation::adapt(), sgpp::datadriven::ModelFittingDensityRatioEstimation::adapt(), sgpp::datadriven::ModelFittingLeastSquares::adapt(), sgpp::datadriven::ModelFittingDensityDifferenceEstimationOnOff::adapt(), sgpp::datadriven::ModelFittingDensityEstimationOnOff::adapt(), sgpp::datadriven::ModelFittingDensityEstimationOnOffParallel::adapt(), sgpp::datadriven::ModelFittingDensityEstimationCombi::addNewModel(), sgpp::datadriven::ModelFittingDensityEstimationOnOff::computeResidual(), sgpp::datadriven::ModelFittingDensityEstimationOnOffParallel::computeResidual(), sgpp::datadriven::ModelFittingClassification::evaluate(), sgpp::datadriven::ModelFittingDensityEstimationOnOffParallel::evaluate(), sgpp::datadriven::ModelFittingDensityRatioEstimation::evaluate(), sgpp::datadriven::ModelFittingLeastSquares::evaluate(), sgpp::datadriven::ModelFittingDensityEstimationCG::fit(), sgpp::datadriven::ModelFittingDensityEstimationOnOff::fit(), sgpp::datadriven::ModelFittingDensityEstimationOnOffParallel::fit(), sgpp::datadriven::ModelFittingDensityDifferenceEstimationCG::fit(), sgpp::datadriven::ModelFittingDensityDifferenceEstimationOnOff::fit(), sgpp::datadriven::ModelFittingDensityEstimationCombi::fit(), sgpp::datadriven::ModelFittingLeastSquares::fit(), sgpp::datadriven::ModelFittingDensityRatioEstimation::fit(), sgpp::datadriven::ModelFittingDensityEstimation::getCoarseningFunctor(), getFitterConfiguration(), getFitterConfiguration(), sgpp::datadriven::ModelFittingDensityEstimation::getRefinementFunctor(), python.uq.dists.SGDEdist.SGDEdist::marginalize(), python.uq.dists.SGDEdist.SGDEdist::marginalizeToDimX(), sgpp::datadriven::ModelFittingClassification::ModelFittingClassification(), sgpp::datadriven::ModelFittingDensityDifferenceEstimationCG::ModelFittingDensityDifferenceEstimationCG(), 
sgpp::datadriven::ModelFittingDensityDifferenceEstimationOnOff::ModelFittingDensityDifferenceEstimationOnOff(), sgpp::datadriven::ModelFittingDensityEstimationCG::ModelFittingDensityEstimationCG(), sgpp::datadriven::ModelFittingDensityEstimationCombi::ModelFittingDensityEstimationCombi(), sgpp::datadriven::ModelFittingDensityEstimationOnOff::ModelFittingDensityEstimationOnOff(), sgpp::datadriven::ModelFittingDensityEstimationOnOffParallel::ModelFittingDensityEstimationOnOffParallel(), sgpp::datadriven::ModelFittingDensityEstimationOnOffParallel::ModelFittingDensityEstimationOnOffParallel(), sgpp::datadriven::ModelFittingDensityRatioEstimation::ModelFittingDensityRatioEstimation(), sgpp::datadriven::ModelFittingLeastSquares::ModelFittingLeastSquares(), python.uq.dists.SGDEdist.SGDEdist::toJson(), sgpp::datadriven::ModelFittingDensityEstimationCG::update(), sgpp::datadriven::ModelFittingDensityEstimationOnOff::update(), sgpp::datadriven::ModelFittingDensityEstimationOnOffParallel::update(), sgpp::datadriven::ModelFittingDensityDifferenceEstimationCG::update(), sgpp::datadriven::ModelFittingDensityDifferenceEstimationOnOff::update(), sgpp::datadriven::ModelFittingLeastSquares::update(), sgpp::datadriven::ModelFittingDensityRatioEstimation::update(), sgpp::datadriven::ModelFittingDensityEstimationOnOff::updateRegularization(), and sgpp::datadriven::ModelFittingDensityEstimationOnOffParallel::updateRegularization().

◆ dataset

◆ extraDataset

◆ interactions

std::unique_ptr<std::set<std::set<size_t> > > sgpp::datadriven::ModelFittingBase::interactions
protected

◆ solver

◆ verboseSolver

bool sgpp::datadriven::ModelFittingBase::verboseSolver

Whether the solver produces output.


The documentation for this class was generated from the following files: