SG++-Doxygen-Documentation
sgpp::datadriven::ModelFittingClassification Class Reference

Fitter object that encapsulates density-based classification using instances of ModelFittingDensityEstimation for each class.

#include <ModelFittingClassification.hpp>

Inherits sgpp::datadriven::ModelFittingBase.

Public Member Functions

bool adapt () override
 Improve the accuracy of the classification by refining or coarsening the grids of each class.
 
double computeResidual (DataMatrix &validationData) const override
 Should compute some kind of residual to evaluate the fit of the model; not implemented for classification.
 
double evaluate (const DataVector &sample) override
 Predict the class of a data sample based on the density of the sample for each model.
 
void evaluate (DataMatrix &samples, DataVector &results) override
 Predicts the class for a set of data points based on the learned densities for each class.
 
void fit (Dataset &dataset) override
 Fits the models for all classes based on the data given in the dataset parameter.
 
void fit (Dataset &datasetP, Dataset &datasetQ) override
 
std::map< double, size_t > getClassIdx ()
 Obtain the mapping from each class label to its class index.
 
std::vector< std::unique_ptr< ModelFittingDensityEstimation > > * getModels ()
 Obtain the density estimation models, one per class.
 
 ModelFittingClassification (const FitterConfigurationClassification &config)
 Constructor.
 
 ModelFittingClassification (const FitterConfigurationClassification &config, std::shared_ptr< DBMatObjectStore > objectStore)
 Constructor with specified object store.
 
void reset () override
 Resets the state of the entire model.
 
void resetTraining () override
 Resets any trained representations of the model, but does not reset the entire state.
 
void storeClassificator ()
 Store the fitter in a text file in the folder /datadriven/classificator/.
 
void update (Dataset &dataset) override
 Updates the models for each class based on new data (streaming or batch learning).
 
void update (Dataset &datasetP, Dataset &datasetQ) override
 
void updateRegularization (double lambda) override
 Updates the regularization parameter lambda of the underlying model.
 
- Public Member Functions inherited from sgpp::datadriven::ModelFittingBase
Dataset * getDataset ()
 
FitterConfiguration & getFitterConfiguration ()
 Get or set the configuration of the fitter object.
 
const FitterConfiguration & getFitterConfiguration () const
 Get the configuration of the fitter object.
 
virtual Grid & getGrid ()
 Return the learned grid.
 
virtual std::shared_ptr< BlacsProcessGrid > getProcessGrid () const
 
virtual DataVector & getSurpluses ()
 Return the learned hierarchical surpluses.
 
 ModelFittingBase ()
 Default constructor.
 
 ModelFittingBase (const ModelFittingBase &rhs)=delete
 Copy constructor - we cannot deep copy all member variables yet.
 
 ModelFittingBase (ModelFittingBase &&rhs)=default
 Move constructor.
 
ModelFittingBase & operator= (const ModelFittingBase &rhs)=delete
 Copy assign operator - we cannot deep copy all member variables yet.
 
ModelFittingBase & operator= (ModelFittingBase &&rhs)=default
 Move assign operator.
 
virtual ~ModelFittingBase ()=default
 Virtual destructor.
 

Additional Inherited Members

- Public Attributes inherited from sgpp::datadriven::ModelFittingBase
bool verboseSolver
 Whether the solver produces output.
 
- Protected Member Functions inherited from sgpp::datadriven::ModelFittingBase
Grid * buildGrid (const sgpp::base::GeneralGridConfiguration &gridConfig) const
 Factory member function that generates a grid from configuration.
 
Grid * buildGrid (const sgpp::base::GeneralGridConfiguration &gridConfig, const GeometryConfiguration &geometryConfig) const
 Factory member function that generates a grid from configuration.
 
SLESolver * buildSolver (const SLESolverConfiguration &config) const
 Factory member function to build the solver for the least squares regression problem according to the config.
 
std::set< std::set< size_t > > getInteractions (const GeometryConfiguration &geometryConfig)
 
void reconfigureSolver (SLESolver &solver, const SLESolverConfiguration &config) const
 Configure solver based on the desired configuration.
 
- Protected Attributes inherited from sgpp::datadriven::ModelFittingBase
std::unique_ptr< FitterConfiguration > config
 Configuration object for the fitter.
 
Dataset * dataset
 Pointer to sgpp::datadriven::Dataset.
 
Dataset * extraDataset
 
std::unique_ptr< std::set< std::set< size_t > > > interactions
 
std::unique_ptr< SLESolver > solver
 Solver for the learning problem.
 

Detailed Description

Fitter object that encapsulates density-based classification using instances of ModelFittingDensityEstimation for each class.
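The classification rule can be pictured as follows: each class label owns its own density estimator, and a sample is assigned to the label whose estimator returns the largest value. The sketch below is a standalone illustration, not SG++ code. The per-class densities are hypothetical std::function callbacks standing in for ModelFittingDensityEstimation objects, and the sketch does not model whether the real implementation also weights the densities by class priors.

```cpp
#include <functional>
#include <limits>
#include <map>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for a fitted per-class density model:
// maps a sample (one row of features) to a density value.
using Density = std::function<double(const std::vector<double>&)>;

// Return the label whose density model gives the highest value at `sample`.
// On ties, the label that comes first in the (sorted) map wins.
double classifyByDensity(const std::map<double, Density>& densities,
                         const std::vector<double>& sample) {
  if (densities.empty()) {
    throw std::invalid_argument("no class models");
  }
  double bestLabel = densities.begin()->first;
  double bestValue = std::numeric_limits<double>::lowest();
  for (const auto& entry : densities) {
    const double value = entry.second(sample);
    if (value > bestValue) {
      bestValue = value;
      bestLabel = entry.first;
    }
  }
  return bestLabel;
}
```

With two classes whose densities dominate on different halves of the input space, the predicted label switches where the two density values cross.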

Constructor & Destructor Documentation

◆ ModelFittingClassification() [1/2]

sgpp::datadriven::ModelFittingClassification::ModelFittingClassification ( const FitterConfigurationClassification & config)
explicit

Constructor.

Parameters
config: configuration object that specifies grid, refinement, and regularization

References sgpp::datadriven::ModelFittingBase::config, and sgpp::datadriven::FitterConfiguration::getParallelConfig().

◆ ModelFittingClassification() [2/2]

sgpp::datadriven::ModelFittingClassification::ModelFittingClassification ( const FitterConfigurationClassification & config,
std::shared_ptr< DBMatObjectStore > objectStore
)
explicit

Constructor with specified object store.

Parameters
config: configuration object that specifies grid, refinement, and regularization
objectStore: offline object store for already decomposed offline objects
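Passing the object store as a std::shared_ptr suggests that one cache of offline objects can be shared by several models, so an expensive decomposition computed for one class model can be reused by another with the same configuration. The following is a minimal sketch of that idea under this assumption; ToyObjectStore, ToyClassModel, and the string key are hypothetical and do not reflect the real DBMatObjectStore interface.

```cpp
#include <map>
#include <memory>
#include <string>
#include <vector>

// Hypothetical stand-in for an offline object store: caches objects
// ("decompositions") by a configuration key so each is computed only once.
class ToyObjectStore {
 public:
  // Return the cached object for `key`, creating it on first request.
  std::shared_ptr<std::vector<double>> getOrCreate(const std::string& key) {
    auto it = cache_.find(key);
    if (it != cache_.end()) {
      return it->second;
    }
    ++creations_;
    // Placeholder for an expensive offline decomposition.
    auto obj = std::make_shared<std::vector<double>>(4, 1.0);
    cache_[key] = obj;
    return obj;
  }

  int creations() const { return creations_; }

 private:
  std::map<std::string, std::shared_ptr<std::vector<double>>> cache_;
  int creations_ = 0;
};

// Hypothetical per-class model that pulls its offline object from a shared store.
struct ToyClassModel {
  ToyClassModel(const std::string& configKey, std::shared_ptr<ToyObjectStore> store)
      : offline(store->getOrCreate(configKey)) {}
  std::shared_ptr<std::vector<double>> offline;
};
```

Two models built with the same key and the same store end up pointing at one cached object, and the store records a single creation.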

Member Function Documentation

◆ adapt()

◆ computeResidual()

double sgpp::datadriven::ModelFittingClassification::computeResidual ( DataMatrix & validationData) const
inline override virtual

Should compute some kind of residual to evaluate the fit of the model.

In the case of density estimation, this is || R * alpha_lambda - b_val ||_2.

This is useful for unsupervised learning models, where the usual evaluation against targets cannot be used because there are no targets.

For classification this is not implemented, since accuracy should be used instead.

Parameters
validationData: matrix of validation data
Returns
the residual score

Implements sgpp::datadriven::ModelFittingBase.
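For the density estimation case, the residual || R * alpha_lambda - b_val ||_2 is an ordinary Euclidean norm of a matrix-vector residual. A dense sketch of that computation follows, assuming R is available as a plain row-major matrix, alpha_lambda is the coefficient vector learned for the chosen lambda, and b_val is the validation right-hand side; how SG++ actually assembles R and b_val is not described on this page.

```cpp
#include <cmath>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Euclidean norm of R * alpha - b, with R stored as a list of rows.
double residualNorm(const std::vector<std::vector<double>>& R,
                    const std::vector<double>& alpha,
                    const std::vector<double>& b) {
  if (R.size() != b.size()) {
    throw std::invalid_argument("R and b have incompatible sizes");
  }
  double sumSquares = 0.0;
  for (std::size_t i = 0; i < R.size(); ++i) {
    if (R[i].size() != alpha.size()) {
      throw std::invalid_argument("R and alpha have incompatible sizes");
    }
    // (R * alpha)_i
    double row = 0.0;
    for (std::size_t j = 0; j < alpha.size(); ++j) {
      row += R[i][j] * alpha[j];
    }
    const double diff = row - b[i];
    sumSquares += diff * diff;
  }
  return std::sqrt(sumSquares);
}
```

For example, with R the 2x2 identity, alpha = (3, 4) and b = (0, 0), the residual is sqrt(9 + 16) = 5.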

◆ evaluate() [1/2]

double sgpp::datadriven::ModelFittingClassification::evaluate ( const DataVector & sample)
override virtual

Predict the class of a data sample based on the density of the sample for each model.

Parameters
sample: the sample point to classify
Returns
the predicted class label

Implements sgpp::datadriven::ModelFittingBase.

Referenced by evaluate(), sgpp::datadriven::VisualizerClassification::storeHeatmapJsonClassification(), and sgpp::datadriven::VisualizerClassification::storeHeatmapJsonClassification().

◆ evaluate() [2/2]

void sgpp::datadriven::ModelFittingClassification::evaluate ( DataMatrix & samples,
DataVector & results
)
override virtual
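The batch overload predicts one class per sample. A minimal sketch of that contract, assuming the result for row i of samples ends up in results[i]; it uses hypothetical standard-library stand-ins for DataMatrix and DataVector, and resizing results inside the function is a choice of this sketch, since this page does not say whether the SG++ overload expects a preallocated vector.

```cpp
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical single-sample predictor, e.g. an argmax rule over class densities.
using Predictor = std::function<double(const std::vector<double>&)>;

// Predict one label per row of `samples` and write it to `results`,
// so that results[i] is the prediction for samples[i].
void evaluateBatch(const Predictor& predict,
                   const std::vector<std::vector<double>>& samples,
                   std::vector<double>& results) {
  results.resize(samples.size());
  for (std::size_t i = 0; i < samples.size(); ++i) {
    results[i] = predict(samples[i]);
  }
}
```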

◆ fit() [1/2]

void sgpp::datadriven::ModelFittingClassification::fit ( Dataset & dataset)
override virtual

Fits the models for all classes based on the data given in the dataset parameter.

Parameters
dataset: the training dataset that is used to fit the models

Implements sgpp::datadriven::ModelFittingBase.

References reset(), and update().
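The cross-references (fit() references reset() and update()) suggest that fitting amounts to resetting the model and then updating it with the full dataset. A toy sketch of that pattern, assuming exactly this call order; ToyFitter is hypothetical and only counts the samples it has been trained on.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical toy fitter illustrating "fit = reset, then update".
class ToyFitter {
 public:
  // Discard everything learned so far.
  void reset() { seen_ = 0; }

  // Incorporate new data on top of the current state (streaming/batch update).
  void update(const std::vector<double>& batch) { seen_ += batch.size(); }

  // Fit from scratch: forget the old state, then learn from `batch`.
  void fit(const std::vector<double>& batch) {
    reset();
    update(batch);
  }

  std::size_t seen() const { return seen_; }

 private:
  std::size_t seen_ = 0;
};
```

Calling update() after fit() accumulates data, while a second fit() starts over.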

◆ fit() [2/2]

void sgpp::datadriven::ModelFittingClassification::fit ( Dataset & datasetP,
Dataset & datasetQ
)
inline override virtual

◆ getClassIdx()

std::map< double, size_t > sgpp::datadriven::ModelFittingClassification::getClassIdx ( )

Obtain the mapping from each class label to its class index.

Intended for use in VisualizerClassification.

Referenced by sgpp::datadriven::VisualizerClassification::runVisualization().
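The return type std::map<double, size_t> maps each class label to an index. A sketch of how such a mapping might be built from the training labels; assigning indices in ascending label order is an assumption of this sketch, since this page does not say in which order SG++ assigns them.

```cpp
#include <cstddef>
#include <map>
#include <vector>

// Assign each distinct label a consecutive index 0, 1, 2, ... in ascending
// label order (std::map keeps its keys sorted).
std::map<double, std::size_t> buildClassIdx(const std::vector<double>& labels) {
  std::map<double, std::size_t> classIdx;
  for (double label : labels) {
    classIdx.emplace(label, 0);  // no-op if the label is already present
  }
  std::size_t next = 0;
  for (auto& entry : classIdx) {
    entry.second = next++;
  }
  return classIdx;
}
```

For the labels {2, -1, 2, 0}, this yields -1 -> 0, 0 -> 1, 2 -> 2.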

◆ getModels()

std::vector< std::unique_ptr< ModelFittingDensityEstimation > > * sgpp::datadriven::ModelFittingClassification::getModels ( )
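getModels() returns a pointer to a vector of owning pointers, one density model per class. A sketch of looking up the model for a given label, assuming the index stored in getClassIdx() addresses this vector; that pairing is how we read the two accessors, not something this page states. ToyDensityModel is a hypothetical stand-in for ModelFittingDensityEstimation.

```cpp
#include <cstddef>
#include <map>
#include <memory>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for ModelFittingDensityEstimation.
struct ToyDensityModel {
  double weight;  // placeholder for the model's learned state
};

// Look up the model belonging to `label`: the label-to-index map selects
// the slot in the vector returned by a getModels()-style accessor.
ToyDensityModel& modelForLabel(
    std::vector<std::unique_ptr<ToyDensityModel>>* models,
    const std::map<double, std::size_t>& classIdx, double label) {
  auto it = classIdx.find(label);
  if (it == classIdx.end()) {
    throw std::out_of_range("unknown class label");
  }
  // at() throws std::out_of_range if the index does not fit the vector.
  return *models->at(it->second);
}
```

Because the vector holds std::unique_ptr elements, callers borrow a reference to a model rather than taking ownership of it.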

◆ reset()

void sgpp::datadriven::ModelFittingClassification::reset ( )
override virtual

Resets the state of the entire model.

Implements sgpp::datadriven::ModelFittingBase.

Referenced by fit().

◆ resetTraining()

void sgpp::datadriven::ModelFittingClassification::resetTraining ( )
override virtual

Resets any trained representations of the model, but does not reset the entire state.

Decompositions are not discarded and can be reused.

Implements sgpp::datadriven::ModelFittingBase.
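The difference between reset() and resetTraining() can be sketched with a toy model that has an expensive offline part (a decomposition) and a cheap trained part (coefficients). This is a hypothetical illustration of the documented behavior, not the SG++ implementation: resetTraining() clears only the trained coefficients, while reset() also drops the decomposition, so the next training run has to recompute it.

```cpp
#include <memory>
#include <vector>

// Hypothetical model with an expensive offline part and a cheap trained part.
class ToyModel {
 public:
  // Full reset: drop the coefficients and the decomposition.
  void reset() {
    coefficients_.clear();
    decomposition_.reset();
  }

  // Training reset: drop only the coefficients; keep the decomposition
  // so that the next training run can reuse it.
  void resetTraining() { coefficients_.clear(); }

  void train() {
    if (!decomposition_) {
      decomposition_ = std::make_shared<int>(42);  // placeholder for costly setup
      ++decompositionsComputed_;
    }
    coefficients_.assign(3, 1.0);  // placeholder for the learned coefficients
  }

  int decompositionsComputed() const { return decompositionsComputed_; }
  bool trained() const { return !coefficients_.empty(); }

 private:
  std::vector<double> coefficients_;
  std::shared_ptr<int> decomposition_;
  int decompositionsComputed_ = 0;
};
```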

◆ storeClassificator()

void sgpp::datadriven::ModelFittingClassification::storeClassificator ( )

Store the fitter in a text file in the folder /datadriven/classificator/.

◆ update() [1/2]

◆ update() [2/2]

void sgpp::datadriven::ModelFittingClassification::update ( Dataset & datasetP,
Dataset & datasetQ
)
inline override virtual

◆ updateRegularization()

void sgpp::datadriven::ModelFittingClassification::updateRegularization ( double lambda)
override virtual

Updates the regularization parameter lambda of the underlying model.

Parameters
lambda: the new lambda parameter

Implements sgpp::datadriven::ModelFittingBase.

References lambda.
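Since this classifier is composed of one density model per class, one plausible reading is that a new lambda is forwarded to every per-class model. A sketch under that assumption; whether SG++ applies the same lambda to all classes is not stated on this page, and ToyRegularizedModel is hypothetical.

```cpp
#include <memory>
#include <vector>

// Hypothetical per-class model that only stores its regularization parameter.
struct ToyRegularizedModel {
  double lambda = 0.0;
  void updateRegularization(double newLambda) { lambda = newLambda; }
};

// Forward one lambda to every per-class model.
void updateAllRegularization(std::vector<std::unique_ptr<ToyRegularizedModel>>& models,
                             double lambda) {
  for (auto& model : models) {
    model->updateRegularization(lambda);
  }
}
```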


The documentation for this class was generated from the following files: