SG++-Doxygen-Documentation
Fitter object that encapsulates density-based classification using an instance of ModelFittingDensityEstimation for each class.
#include <ModelFittingClassification.hpp>
Public Member Functions | |
bool | adapt () override |
Improve the accuracy of the classification by refining or coarsening the grids of each class. | |
double | computeResidual (DataMatrix &validationData) const override |
Computes a residual to evaluate the fit of the model (not implemented for classification). | |
double | evaluate (const DataVector &sample) override |
Predict the class of a data sample based on the density of the sample for each model. | |
void | evaluate (DataMatrix &samples, DataVector &results) override |
Predicts the class for a set of data points based on the learned densities for each class. | |
void | fit (Dataset &dataset) override |
Fits the models for all classes based on the data given in the dataset parameter. | |
void | fit (Dataset &datasetP, Dataset &datasetQ) override |
std::map< double, size_t > | getClassIdx () |
Obtain the mapping from each class label to its model index. | |
std::vector< std::unique_ptr< ModelFittingDensityEstimation > > * | getModels () |
Obtain the density estimation models, one per class. | |
ModelFittingClassification (const FitterConfigurationClassification &config) | |
Constructor. | |
ModelFittingClassification (const FitterConfigurationClassification &config, std::shared_ptr< DBMatObjectStore > objectStore) | |
Constructor with specified object store. | |
void | reset () override |
Resets the state of the entire model. | |
void | resetTraining () override |
Resets any trained representations of the model, but does not reset the entire state. | |
void | storeClassificator () |
Store the fitter in a text file in the folder /datadriven/classificator/ | |
void | update (Dataset &dataset) override |
Updates the models for each class based on new data (streaming or batch learning). | |
void | update (Dataset &datasetP, Dataset &datasetQ) override |
void | updateRegularization (double lambda) override |
Updates the regularization parameter lambda of the underlying model. | |
Public Member Functions inherited from sgpp::datadriven::ModelFittingBase | |
Dataset * | getDataset () |
FitterConfiguration & | getFitterConfiguration () |
Get or set the configuration of the fitter object. | |
const FitterConfiguration & | getFitterConfiguration () const |
Get the configuration of the fitter object. | |
virtual Grid & | getGrid () |
Return the learned grid. | |
virtual std::shared_ptr< BlacsProcessGrid > | getProcessGrid () const |
virtual DataVector & | getSurpluses () |
Return the learned hierarchical surpluses. | |
ModelFittingBase () | |
Default constructor. | |
ModelFittingBase (const ModelFittingBase &rhs)=delete | |
Copy constructor - we cannot deep copy all member variables yet. | |
ModelFittingBase (ModelFittingBase &&rhs)=default | |
Move constructor. | |
ModelFittingBase & | operator= (const ModelFittingBase &rhs)=delete |
Copy assign operator - we cannot deep copy all member variables yet. | |
ModelFittingBase & | operator= (ModelFittingBase &&rhs)=default |
Move assign operator. | |
virtual | ~ModelFittingBase ()=default |
Virtual destructor. | |
Additional Inherited Members | |
bool | verboseSolver |
Whether the solver produces output. | |
Grid * | buildGrid (const sgpp::base::GeneralGridConfiguration &gridConfig) const |
Factory member function that generates a grid from configuration. | |
Grid * | buildGrid (const sgpp::base::GeneralGridConfiguration &gridConfig, const GeometryConfiguration &geometryConfig) const |
Factory member function that generates a grid from configuration. | |
SLESolver * | buildSolver (const SLESolverConfiguration &config) const |
Factory member function to build the solver for the least squares regression problem according to the config. | |
std::set< std::set< size_t > > | getInteractions (const GeometryConfiguration &geometryConfig) |
void | reconfigureSolver (SLESolver &solver, const SLESolverConfiguration &config) const |
Configure solver based on the desired configuration. | |
std::unique_ptr< FitterConfiguration > | config |
Configuration object for the fitter. | |
Dataset * | dataset |
Pointer to sgpp::datadriven::Dataset. | |
Dataset * | extraDataset |
std::unique_ptr< std::set< std::set< size_t > > > | interactions |
std::unique_ptr< SLESolver > | solver |
Solver for the learning problem. | |
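sgpp::datadriven::ModelFittingClassification::ModelFittingClassification | ( | const FitterConfigurationClassification & | config | ) |
explicit |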
Constructor.
config | Configuration object that specifies grid, refinement, and regularization |
References sgpp::datadriven::ModelFittingBase::config and sgpp::datadriven::FitterConfiguration::getParallelConfig().
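sgpp::datadriven::ModelFittingClassification::ModelFittingClassification | ( | const FitterConfigurationClassification & | config, | std::shared_ptr< DBMatObjectStore > | objectStore | ) |
explicit |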
Constructor with specified object store.
config | Configuration object that specifies grid, refinement, and regularization |
objectStore | Offline object store for already decomposed offline objects. |
bool sgpp::datadriven::ModelFittingClassification::adapt | ( | ) |
override virtual |
Improve the accuracy of the classification by refining or coarsening the grids of each class.
Coarsening is currently implemented only for RefinementFunctorType::Classification.
Implements sgpp::datadriven::ModelFittingBase.
References sgpp::datadriven::ClassificationRefinementFunctor::adaptAllGrids(), sgpp::base::ComponentGrid, sgpp::datadriven::ModelFittingBase::config, sgpp::datadriven::ModelFittingBase::getGrid(), sgpp::datadriven::GridFactory::getInteractions(), sgpp::datadriven::ModelFittingBase::getSurpluses(), sgpp::base::AdaptivityConfiguration::numRefinements_, sgpp::datadriven::MultiGridRefinementFunctor::preComputeEvaluations(), sgpp::base::AdaptivityConfiguration::precomputeEvaluations_, sgpp::datadriven::MultipleClassRefinementFunctor::refine(), sgpp::base::AdaptivityConfiguration::refinementFunctorType_, sgpp::datadriven::MultiGridRefinementFunctor::setGridIndex(), and sgpp::datadriven::GeometryConfiguration::stencils_.
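double sgpp::datadriven::ModelFittingClassification::computeResidual | ( | DataMatrix & | validationData | ) | const |
inline override virtual |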
Computes a residual to evaluate the fit of the model.
In the case of density estimation, this is $\lVert R\,\alpha_\lambda - b_{\mathrm{val}} \rVert_2$.
This is useful for unsupervised learning models, where the usual evaluation against targets cannot be used because there are no targets.
For classification, this method is not implemented; accuracy should be used instead.
validationData | Matrix of validation data |
Implements sgpp::datadriven::ModelFittingBase.
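In the density-estimation case, the residual is the Euclidean norm of a matrix-vector residual. The following is a minimal dense sketch of that formula only: it assumes R, alpha_lambda and b_val are already available as plain arrays, and the Matrix alias and the residualNorm helper are hypothetical, not part of the SG++ API.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Dense matrix stored as a vector of rows (illustrative only).
using Matrix = std::vector<std::vector<double>>;

// Computes || R * alpha - b ||_2 for a dense R (rows x cols),
// coefficients alpha (cols) and right-hand side b (rows).
double residualNorm(const Matrix& R, const std::vector<double>& alpha,
                    const std::vector<double>& b) {
  assert(R.size() == b.size());
  double sumSquares = 0.0;
  for (std::size_t i = 0; i < R.size(); ++i) {
    assert(R[i].size() == alpha.size());
    double rowDot = 0.0;  // (R * alpha)_i
    for (std::size_t j = 0; j < alpha.size(); ++j) {
      rowDot += R[i][j] * alpha[j];
    }
    const double diff = rowDot - b[i];
    sumSquares += diff * diff;
  }
  return std::sqrt(sumSquares);
}
```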
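double sgpp::datadriven::ModelFittingClassification::evaluate | ( | const DataVector & | sample | ) |
override virtual |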
Predict the class of a data sample based on the density of the sample for each model.
sample | the sample point to classify |
Implements sgpp::datadriven::ModelFittingBase.
Referenced by evaluate(), sgpp::datadriven::VisualizerClassification::storeHeatmapJsonClassification(), and sgpp::datadriven::VisualizerClassification::storeHeatmapJsonClassification().
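The per-sample prediction can be sketched as an argmax over the class models. This is an illustration under assumptions: ClassModel and predictLabel are hypothetical names, and weighting each density by a class prior is an assumption of the sketch (the documentation only says the prediction is based on the density of the sample for each model).

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <limits>
#include <vector>

// Hypothetical per-class model: label, prior, and estimated density p(x | class).
struct ClassModel {
  double label;
  double prior;  // e.g. fraction of training points in this class (assumption)
  std::function<double(const std::vector<double>&)> density;
};

// Returns the label whose prior-weighted density is largest at `sample`.
double predictLabel(const std::vector<ClassModel>& models,
                    const std::vector<double>& sample) {
  assert(!models.empty());
  double bestScore = -std::numeric_limits<double>::infinity();
  double bestLabel = models.front().label;
  for (const ClassModel& m : models) {
    const double score = m.prior * m.density(sample);
    if (score > bestScore) {
      bestScore = score;
      bestLabel = m.label;
    }
  }
  return bestLabel;
}
```

With equal priors the rule reduces to picking the class whose estimated density is largest at the sample.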
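void sgpp::datadriven::ModelFittingClassification::evaluate | ( | DataMatrix & | samples, | DataVector & | results | ) |
override virtual |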
Predicts the class for a set of data points based on the learned densities for each class.
samples | matrix where each row represents a data sample |
results | vector to output the predicted classes |
Implements sgpp::datadriven::ModelFittingBase.
References sgpp::datadriven::ModelFittingBase::config, evaluate(), sgpp::datadriven::DataVectorDistributed::getLocalPointer(), sgpp::datadriven::DataVectorDistributed::getLocalRows(), sgpp::datadriven::DataVectorDistributed::localToGlobalRowIndex(), sgpp::base::DataVector::mult(), sgpp::base::DataVector::set(), and sgpp::datadriven::DataVectorDistributed::toLocalDataVector().
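The references to DataVectorDistributed::getLocalRows() and localToGlobalRowIndex() suggest that, in a distributed run, each process classifies only the rows it owns and writes each result at its global position. The sketch below illustrates that pattern for a 1D block-cyclic row distribution with the first block on process 0, using the ScaLAPACK NUMROC and INDXL2G formulas. Whether SG++ uses exactly this layout for the result vector is an assumption, and all function names are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Rows owned by process `proc` when `n` rows are distributed block-cyclically
// with block size `nb` over `nprocs` processes (ScaLAPACK NUMROC, source proc 0).
std::size_t numLocalRows(std::size_t n, std::size_t nb, std::size_t proc,
                         std::size_t nprocs) {
  const std::size_t nblocks = n / nb;
  std::size_t count = (nblocks / nprocs) * nb;
  const std::size_t extraBlocks = nblocks % nprocs;
  if (proc < extraBlocks) {
    count += nb;
  } else if (proc == extraBlocks) {
    count += n % nb;  // partial last block
  }
  return count;
}

// 0-based local -> global row index (ScaLAPACK INDXL2G, source proc 0).
std::size_t localToGlobalRow(std::size_t local, std::size_t nb,
                             std::size_t proc, std::size_t nprocs) {
  return nprocs * nb * (local / nb) + local % nb + proc * nb;
}

// Classifies only the rows owned by `proc`, writing each prediction to its
// global slot in `results` (which must already hold one entry per row).
void evaluateLocalRows(
    const std::vector<std::vector<double>>& samples,
    const std::function<double(const std::vector<double>&)>& classify,
    std::size_t nb, std::size_t proc, std::size_t nprocs,
    std::vector<double>& results) {
  assert(results.size() == samples.size());
  const std::size_t localRows = numLocalRows(samples.size(), nb, proc, nprocs);
  for (std::size_t local = 0; local < localRows; ++local) {
    const std::size_t global = localToGlobalRow(local, nb, proc, nprocs);
    results[global] = classify(samples[global]);
  }
}
```

Running the function once per process rank fills every slot exactly once, since the block-cyclic layout partitions the rows.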
void sgpp::datadriven::ModelFittingClassification::fit | ( | Dataset & | dataset | ) |
override virtual |
Fits the models for all classes based on the data given in the dataset parameter.
dataset | the training dataset that is used to fit the models |
Implements sgpp::datadriven::ModelFittingBase.
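Fitting one density model per class starts by grouping the training points by label. A minimal sketch of that grouping step, assuming labels are stored as doubles as in the std::map< double, size_t > returned by getClassIdx(); splitByClass is a hypothetical helper, and assigning indices in ascending label order is this sketch's choice, not necessarily the library's.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <vector>

// Splits labelled data into one subset per class and fills `classIdx`
// with the label -> model index mapping (ascending label order, because
// std::map iterates its keys sorted).
std::vector<std::vector<std::vector<double>>> splitByClass(
    const std::vector<std::vector<double>>& data,
    const std::vector<double>& targets,
    std::map<double, std::size_t>& classIdx) {
  assert(data.size() == targets.size());
  classIdx.clear();
  for (double label : targets) classIdx.emplace(label, std::size_t{0});
  std::size_t next = 0;
  for (auto& entry : classIdx) entry.second = next++;

  std::vector<std::vector<std::vector<double>>> perClass(classIdx.size());
  for (std::size_t i = 0; i < data.size(); ++i) {
    perClass[classIdx.at(targets[i])].push_back(data[i]);
  }
  return perClass;  // perClass[k] is the training set for model k
}
```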
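void sgpp::datadriven::ModelFittingClassification::fit | ( | Dataset & | datasetP, | Dataset & | datasetQ | ) |
inline override virtual |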
Implements sgpp::datadriven::ModelFittingBase.
std::map< double, size_t > sgpp::datadriven::ModelFittingClassification::getClassIdx | ( | ) |
Obtain the mapping from each class label to its model index.
Intended for use in VisualizerClassification.
Referenced by sgpp::datadriven::VisualizerClassification::runVisualization().
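getClassIdx() maps each class label to the index of its model. Going the other way, from a model index back to a label, only requires inverting the map. A sketch, assuming the indices are exactly 0, 1, ..., k-1; indexToLabel is a hypothetical helper.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <vector>

// Inverts a label -> index map (the shape returned by getClassIdx()) into a
// vector whose entry i is the class label handled by model i.
std::vector<double> indexToLabel(const std::map<double, std::size_t>& classIdx) {
  std::vector<double> labels(classIdx.size());
  for (const auto& entry : classIdx) {
    assert(entry.second < labels.size());  // indices must be 0..k-1
    labels[entry.second] = entry.first;
  }
  return labels;
}
```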
std::vector< std::unique_ptr< ModelFittingDensityEstimation > > * sgpp::datadriven::ModelFittingClassification::getModels | ( | ) |
Obtain the density estimation models, one per class.
Intended for use in VisualizerClassification.
Referenced by sgpp::datadriven::VisualizerClassification::runVisualization(), sgpp::datadriven::VisualizerClassification::storeHeatmapJsonClassification(), and sgpp::datadriven::VisualizerClassification::storeHeatmapJsonClassification().
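getModels() returns a raw pointer to the fitter's vector of std::unique_ptr, so callers can read the per-class models without taking ownership. The sketch below shows that access pattern with a hypothetical DensityModel stand-in in place of ModelFittingDensityEstimation; evaluateAll is likewise hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

// Hypothetical stand-in for ModelFittingDensityEstimation: it only scales its
// input so the sketch has something to evaluate.
struct DensityModel {
  double scale = 1.0;
  double evaluate(double x) const { return scale * x; }
};

// Reads every per-class model through a non-owning pointer to the container,
// the same shape as the getModels() return type.
std::vector<double> evaluateAll(
    const std::vector<std::unique_ptr<DensityModel>>* models, double x) {
  std::vector<double> values;
  if (models == nullptr) return values;
  values.reserve(models->size());
  for (const auto& model : *models) {
    assert(model != nullptr);  // every class slot should hold a model
    values.push_back(model->evaluate(x));
  }
  return values;
}
```

Because the fitter keeps ownership of the models, such a pointer is only valid while the fitter object is alive.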
void sgpp::datadriven::ModelFittingClassification::reset | ( | ) |
override virtual |
Resets the state of the entire model.
Implements sgpp::datadriven::ModelFittingBase.
Referenced by fit().
void sgpp::datadriven::ModelFittingClassification::resetTraining | ( | ) |
override virtual |
Resets any trained representations of the model, but does not reset the entire state.
Decompositions are not discarded and can be reused.
Implements sgpp::datadriven::ModelFittingBase.
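The difference between reset() and resetTraining() can be illustrated with a toy solver that caches a matrix factorization: resetTraining() discards only the fitted coefficients, so the next fit reuses the cached factor, while reset() discards both. This is an analogy, not SG++ code: ToyFitter is hypothetical, and a Cholesky factorization is used only as an example of a reusable decomposition.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <utility>
#include <vector>

// Toy model solving A * alpha = b for symmetric positive definite A.
// L_ (Cholesky factor) is the reusable "decomposition"; alpha_ is the
// trained representation.
class ToyFitter {
 public:
  explicit ToyFitter(std::vector<std::vector<double>> A) : A_(std::move(A)) {}

  void fit(const std::vector<double>& b) {
    assert(b.size() == A_.size());
    if (L_.empty()) factorize();  // reuse the cached factor when present
    const std::size_t n = b.size();
    std::vector<double> y(n);
    for (std::size_t i = 0; i < n; ++i) {  // forward solve: L y = b
      double s = b[i];
      for (std::size_t k = 0; k < i; ++k) s -= L_[i][k] * y[k];
      y[i] = s / L_[i][i];
    }
    alpha_.assign(n, 0.0);
    for (std::size_t i = n; i-- > 0;) {  // back solve: L^T alpha = y
      double s = y[i];
      for (std::size_t k = i + 1; k < n; ++k) s -= L_[k][i] * alpha_[k];
      alpha_[i] = s / L_[i][i];
    }
  }

  void resetTraining() { alpha_.clear(); }      // keep the decomposition
  void reset() { alpha_.clear(); L_.clear(); }  // discard everything

  bool hasDecomposition() const { return !L_.empty(); }
  const std::vector<double>& alpha() const { return alpha_; }

 private:
  void factorize() {  // A = L L^T
    const std::size_t n = A_.size();
    L_.assign(n, std::vector<double>(n, 0.0));
    for (std::size_t j = 0; j < n; ++j) {
      double d = A_[j][j];
      for (std::size_t k = 0; k < j; ++k) d -= L_[j][k] * L_[j][k];
      assert(d > 0.0);  // A must be positive definite
      L_[j][j] = std::sqrt(d);
      for (std::size_t i = j + 1; i < n; ++i) {
        double s = A_[i][j];
        for (std::size_t k = 0; k < j; ++k) s -= L_[i][k] * L_[j][k];
        L_[i][j] = s / L_[j][j];
      }
    }
  }

  std::vector<std::vector<double>> A_;
  std::vector<std::vector<double>> L_;
  std::vector<double> alpha_;
};
```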
void sgpp::datadriven::ModelFittingClassification::storeClassificator | ( | ) |
Store the fitter in a text file in the folder /datadriven/classificator/
void sgpp::datadriven::ModelFittingClassification::update | ( | Dataset & | dataset | ) |
override virtual |
Updates the models for each class based on new data (streaming or batch learning).
dataset | the new data |
Implements sgpp::datadriven::ModelFittingBase.
References sgpp::datadriven::ModelFittingBase::dataset, sgpp::base::DataVector::get(), sgpp::datadriven::Dataset::getData(), sgpp::datadriven::Dataset::getDimension(), sgpp::datadriven::Dataset::getNumberInstances(), sgpp::base::DataMatrix::getRow(), and sgpp::datadriven::Dataset::getTargets().
Referenced by fit(), python.uq.refinement.RefinementStrategy.Ranking::rank(), and python.learner.LearnedKnowledge.LearnedKnowledge::setMemento().
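A streaming update routes each new sample to the model of its class. The sketch below uses a running mean per class as a stand-in for a density-model update; RunningMean and updateByClass are hypothetical, and creating a new model for a label not seen before is an assumption of the sketch, not documented behavior.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <vector>

// Toy per-class "model": running count and mean of 1D samples.
struct RunningMean {
  std::size_t count = 0;
  double mean = 0.0;
  void update(double x) {
    ++count;
    mean += (x - mean) / static_cast<double>(count);  // incremental mean
  }
};

// Routes each new sample to the model of its class; an unseen label gets a
// new model appended at the next free index (assumption of this sketch).
void updateByClass(const std::vector<double>& samples,
                   const std::vector<double>& targets,
                   std::map<double, std::size_t>& classIdx,
                   std::vector<RunningMean>& models) {
  assert(samples.size() == targets.size());
  for (std::size_t i = 0; i < samples.size(); ++i) {
    auto it = classIdx.find(targets[i]);
    if (it == classIdx.end()) {
      it = classIdx.emplace(targets[i], models.size()).first;
      models.emplace_back();
    }
    models[it->second].update(samples[i]);
  }
}
```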
void sgpp::datadriven::ModelFittingClassification::updateRegularization | ( | double | lambda | ) |
override virtual |
Updates the regularization parameter lambda of the underlying model.
lambda | the new lambda parameter |
Implements sgpp::datadriven::ModelFittingBase.
References lambda.
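Since the classifier holds one density model per class, updating lambda amounts to forwarding the new value to every model. A minimal sketch with a hypothetical RegularizedModel interface standing in for ModelFittingDensityEstimation; applying one shared lambda to all classes is an assumption of the sketch.

```cpp
#include <cassert>
#include <memory>
#include <vector>

// Hypothetical minimal interface for a per-class model with a
// regularization parameter.
struct RegularizedModel {
  virtual ~RegularizedModel() = default;
  virtual void updateRegularization(double lambda) = 0;
};

// Toy implementation that just stores the value it receives.
struct ToyModel : RegularizedModel {
  double lambda = 0.0;
  void updateRegularization(double newLambda) override { lambda = newLambda; }
};

// Forwards one lambda to every per-class model.
void updateAllRegularization(
    std::vector<std::unique_ptr<RegularizedModel>>& models, double lambda) {
  for (auto& model : models) {
    assert(model != nullptr);
    model->updateRegularization(lambda);
  }
}
```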