12#ifndef MLPACK_METHODS_SOFTMAX_REGRESSION_SOFTMAX_REGRESSION_HPP
13#define MLPACK_METHODS_SOFTMAX_REGRESSION_SOFTMAX_REGRESSION_HPP
16#include <ensmallen.hpp>
72 const size_t numClasses = 0,
73 const bool fitIntercept =
false);
88 template<
typename OptimizerType = ens::L_BFGS>
90 const arma::Row<size_t>& labels,
91 const size_t numClasses,
92 const double lambda = 0.0001,
93 const bool fitIntercept =
false,
94 OptimizerType optimizer = OptimizerType());
112 template<
typename OptimizerType,
typename... CallbackTypes>
114 const arma::Row<size_t>& labels,
115 const size_t numClasses,
117 const bool fitIntercept,
118 OptimizerType optimizer,
119 CallbackTypes&&... callbacks);
128 void Classify(
const arma::mat& dataset, arma::Row<size_t>& labels)
const;
136 template<
typename VecType>
151 arma::Row<size_t>& labels,
152 arma::mat& probabilities)
const;
161 arma::mat& probabilities)
const;
172 const arma::Row<size_t>& labels)
const;
183 template<
typename OptimizerType = ens::L_BFGS>
185 const arma::Row<size_t>& labels,
186 const size_t numClasses,
187 OptimizerType optimizer = OptimizerType());
201 template<
typename OptimizerType = ens::L_BFGS,
typename... CallbackTypes>
203 const arma::Row<size_t>& labels,
204 const size_t numClasses,
205 OptimizerType optimizer,
206 CallbackTypes&&... callbacks);
228 {
return fitIntercept ? parameters.n_cols - 1:
234 template<
typename Archive>
237 ar & BOOST_SERIALIZATION_NVP(parameters);
238 ar & BOOST_SERIALIZATION_NVP(numClasses);
239 ar & BOOST_SERIALIZATION_NVP(lambda);
240 ar & BOOST_SERIALIZATION_NVP(fitIntercept);
245 arma::mat parameters;
258#include "softmax_regression_impl.hpp"
Softmax Regression is a classifier which can be used for classification when the data available can take two or more class values.
size_t NumClasses() const
Gets the number of classes.
arma::mat & Parameters()
Get a modifiable reference to the model parameters.
const arma::mat & Parameters() const
Get the model parameters.
void Classify(const arma::mat &dataset, arma::Row< size_t > &labels) const
Classify the given points, returning the predicted labels for each point.
bool FitIntercept() const
Gets the intercept term flag. We can't change this after training.
double Train(const arma::mat &data, const arma::Row< size_t > &labels, const size_t numClasses, OptimizerType optimizer=OptimizerType())
Train the softmax regression with the given training data.
double & Lambda()
Sets the regularization parameter.
double Lambda() const
Gets the regularization parameter.
void Classify(const arma::mat &dataset, arma::Row< size_t > &labels, arma::mat &probabilities) const
Classify the given points, returning class probabilities and predicted class label for each point.
void Classify(const arma::mat &dataset, arma::mat &probabilities) const
Classify the given points, returning class probabilities for each point.
size_t FeatureSize() const
Gets the feature size of the training data.
size_t & NumClasses()
Sets the number of classes.
SoftmaxRegression(const arma::mat &data, const arma::Row< size_t > &labels, const size_t numClasses, const double lambda, const bool fitIntercept, OptimizerType optimizer, CallbackTypes &&... callbacks)
Construct the SoftmaxRegression class with the provided data and labels.
SoftmaxRegression(const arma::mat &data, const arma::Row< size_t > &labels, const size_t numClasses, const double lambda=0.0001, const bool fitIntercept=false, OptimizerType optimizer=OptimizerType())
Construct the SoftmaxRegression class with the provided data and labels.
size_t Classify(const VecType &point) const
Classify the given point.
SoftmaxRegression(const size_t inputSize=0, const size_t numClasses=0, const bool fitIntercept=false)
Initialize the SoftmaxRegression without performing training.
double Train(const arma::mat &data, const arma::Row< size_t > &labels, const size_t numClasses, OptimizerType optimizer, CallbackTypes &&... callbacks)
Train the softmax regression with the given training data.
void serialize(Archive &ar, const unsigned int)
Serialize the SoftmaxRegression model.
double ComputeAccuracy(const arma::mat &testData, const arma::Row< size_t > &labels) const
Computes accuracy of the learned model given the feature data and the labels associated with each data point.
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.