mlpack 3.4.2
linear_svm.hpp
Go to the documentation of this file.
1
12#ifndef MLPACK_METHODS_LINEAR_SVM_LINEAR_SVM_HPP
13#define MLPACK_METHODS_LINEAR_SVM_LINEAR_SVM_HPP
14
15#include <mlpack/prereqs.hpp>
16#include <ensmallen.hpp>
17
19
20namespace mlpack {
21namespace svm {
22
79template <typename MatType = arma::mat>
81{
82 public:
101 template <typename OptimizerType, typename... CallbackTypes>
102 LinearSVM(const MatType& data,
103 const arma::Row<size_t>& labels,
104 const size_t numClasses,
105 const double lambda,
106 const double delta,
107 const bool fitIntercept,
108 OptimizerType optimizer,
109 CallbackTypes&&... callbacks);
110
126 template <typename OptimizerType = ens::L_BFGS>
127 LinearSVM(const MatType& data,
128 const arma::Row<size_t>& labels,
129 const size_t numClasses = 2,
130 const double lambda = 0.0001,
131 const double delta = 1.0,
132 const bool fitIntercept = false,
133 OptimizerType optimizer = OptimizerType());
134
146 LinearSVM(const size_t inputSize,
147 const size_t numClasses = 0,
148 const double lambda = 0.0001,
149 const double delta = 1.0,
150 const bool fitIntercept = false);
161 LinearSVM(const size_t numClasses = 0,
162 const double lambda = 0.0001,
163 const double delta = 1.0,
164 const bool fitIntercept = false);
165
175 void Classify(const MatType& data,
176 arma::Row<size_t>& labels) const;
177
189 void Classify(const MatType& data,
190 arma::Row<size_t>& labels,
191 arma::mat& scores) const;
192
199 void Classify(const MatType& data,
200 arma::mat& scores) const;
201
210 template<typename VecType>
211 size_t Classify(const VecType& point) const;
212
222 double ComputeAccuracy(const MatType& testData,
223 const arma::Row<size_t>& testLabels) const;
224
238 template <typename OptimizerType, typename... CallbackTypes>
239 double Train(const MatType& data,
240 const arma::Row<size_t>& labels,
241 const size_t numClasses,
242 OptimizerType optimizer,
243 CallbackTypes&&... callbacks);
244
255 template <typename OptimizerType = ens::L_BFGS>
256 double Train(const MatType& data,
257 const arma::Row<size_t>& labels,
258 const size_t numClasses = 2,
259 OptimizerType optimizer = OptimizerType());
260
261
263 size_t& NumClasses() { return numClasses; }
265 size_t NumClasses() const { return numClasses; }
266
268 double& Lambda() { return lambda; }
270 double Lambda() const { return lambda; }
271
273 double& Delta() { return delta; }
275 double Delta() const { return delta; }
276
278 bool& FitIntercept() { return fitIntercept; }
279
281 arma::mat& Parameters() { return parameters; }
283 const arma::mat& Parameters() const { return parameters; }
284
286 size_t FeatureSize() const
287 { return fitIntercept ? parameters.n_rows - 1 :
288 parameters.n_rows; }
289
293 template<typename Archive>
294 void serialize(Archive& ar, const unsigned int /* version */)
295 {
296 ar & BOOST_SERIALIZATION_NVP(parameters);
297 ar & BOOST_SERIALIZATION_NVP(numClasses);
298 ar & BOOST_SERIALIZATION_NVP(lambda);
299 ar & BOOST_SERIALIZATION_NVP(fitIntercept);
300 }
301
302 private:
304 arma::mat parameters;
306 size_t numClasses;
308 double lambda;
310 double delta;
312 bool fitIntercept;
313};
314
315} // namespace svm
316} // namespace mlpack
317
318// Include implementation.
319#include "linear_svm_impl.hpp"
320
321#endif // MLPACK_METHODS_LINEAR_SVM_LINEAR_SVM_HPP
The LinearSVM class implements an L2-regularized support vector machine model, and supports training with multiple optimizers and classification.
Definition: linear_svm.hpp:81
size_t NumClasses() const
Gets the number of classes.
Definition: linear_svm.hpp:265
arma::mat & Parameters()
Set the model parameters.
Definition: linear_svm.hpp:281
LinearSVM(const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses=2, const double lambda=0.0001, const double delta=1.0, const bool fitIntercept=false, OptimizerType optimizer=OptimizerType())
Construct the LinearSVM class with the provided data and labels.
const arma::mat & Parameters() const
Get the model parameters.
Definition: linear_svm.hpp:283
LinearSVM(const size_t numClasses=0, const double lambda=0.0001, const double delta=1.0, const bool fitIntercept=false)
Initialize the Linear SVM without performing training.
double & Lambda()
Sets the regularization parameter.
Definition: linear_svm.hpp:268
double Lambda() const
Gets the regularization parameter.
Definition: linear_svm.hpp:270
void Classify(const MatType &data, arma::mat &scores) const
Classify the given points, returning class scores for each point.
double ComputeAccuracy(const MatType &testData, const arma::Row< size_t > &testLabels) const
Computes accuracy of the learned model given the feature data and the labels associated with each data point.
double & Delta()
Sets the margin between the correct class and all other classes.
Definition: linear_svm.hpp:273
size_t FeatureSize() const
Gets the feature size (number of features) of the training data.
Definition: linear_svm.hpp:286
double Train(const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses, OptimizerType optimizer, CallbackTypes &&... callbacks)
Train the Linear SVM with the given training data.
size_t & NumClasses()
Sets the number of classes.
Definition: linear_svm.hpp:263
void Classify(const MatType &data, arma::Row< size_t > &labels, arma::mat &scores) const
Classify the given points, returning class scores and predicted class label for each point.
double Train(const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses=2, OptimizerType optimizer=OptimizerType())
Train the Linear SVM with the given training data.
void Classify(const MatType &data, arma::Row< size_t > &labels) const
Classify the given points, returning the predicted labels for each point.
size_t Classify(const VecType &point) const
Classify the given point.
LinearSVM(const size_t inputSize, const size_t numClasses=0, const double lambda=0.0001, const double delta=1.0, const bool fitIntercept=false)
Initialize the Linear SVM without performing training.
double Delta() const
Gets the margin between the correct class and all other classes.
Definition: linear_svm.hpp:275
bool & FitIntercept()
Sets the intercept term flag.
Definition: linear_svm.hpp:278
LinearSVM(const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses, const double lambda, const double delta, const bool fitIntercept, OptimizerType optimizer, CallbackTypes &&... callbacks)
Construct the LinearSVM class with the provided data and labels.
void serialize(Archive &ar, const unsigned int)
Serialize the LinearSVM model.
Definition: linear_svm.hpp:294
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.