mlpack 3.4.2
adaboost.hpp
Go to the documentation of this file.
1
28#ifndef MLPACK_METHODS_ADABOOST_ADABOOST_HPP
29#define MLPACK_METHODS_ADABOOST_ADABOOST_HPP
30
31#include <mlpack/prereqs.hpp>
34
35namespace mlpack {
36namespace adaboost {
37
79template<typename WeakLearnerType = mlpack::perceptron::Perceptron<>,
80 typename MatType = arma::mat>
82{
83 public:
97 AdaBoost(const MatType& data,
98 const arma::Row<size_t>& labels,
99 const size_t numClasses,
100 const WeakLearnerType& other,
101 const size_t iterations = 100,
102 const double tolerance = 1e-6);
103
108 AdaBoost(const double tolerance = 1e-6);
109
111 double Tolerance() const { return tolerance; }
113 double& Tolerance() { return tolerance; }
114
116 size_t NumClasses() const { return numClasses; }
117
119 size_t WeakLearners() const { return alpha.size(); }
120
122 double Alpha(const size_t i) const { return alpha[i]; }
124 double& Alpha(const size_t i) { return alpha[i]; }
125
127 const WeakLearnerType& WeakLearner(const size_t i) const { return wl[i]; }
129 WeakLearnerType& WeakLearner(const size_t i) { return wl[i]; }
130
146 double Train(const MatType& data,
147 const arma::Row<size_t>& labels,
148 const size_t numClasses,
149 const WeakLearnerType& learner,
150 const size_t iterations = 100,
151 const double tolerance = 1e-6);
152
162 void Classify(const MatType& test,
163 arma::Row<size_t>& predictedLabels,
164 arma::mat& probabilities);
165
173 void Classify(const MatType& test,
174 arma::Row<size_t>& predictedLabels);
175
179 template<typename Archive>
180 void serialize(Archive& ar, const unsigned int /* version */);
181
182 private:
184 size_t numClasses;
185 // The tolerance for change in rt and when to stop.
186 double tolerance;
187
189 std::vector<WeakLearnerType> wl;
191 std::vector<double> alpha;
192}; // class AdaBoost
193
194} // namespace adaboost
195} // namespace mlpack
196
198namespace boost {
199namespace serialization {
200
201template<typename WeakLearnerType, typename MatType>
202struct version<mlpack::adaboost::AdaBoost<WeakLearnerType, MatType>>
203{
204 BOOST_STATIC_CONSTANT(int, value = 1);
205};
206
207} // namespace serialization
208} // namespace boost
209
210// Include implementation.
211#include "adaboost_impl.hpp"
212
213#endif
The AdaBoost class.
Definition: adaboost.hpp:82
size_t NumClasses() const
Get the number of classes this model is trained on.
Definition: adaboost.hpp:116
AdaBoost(const double tolerance=1e-6)
Create the AdaBoost object without training.
void Classify(const MatType &test, arma::Row< size_t > &predictedLabels)
Classify the given test points.
size_t WeakLearners() const
Get the number of weak learners in the model.
Definition: adaboost.hpp:119
double & Alpha(const size_t i)
Modify the weight for the given weak learner (be careful!).
Definition: adaboost.hpp:124
AdaBoost(const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses, const WeakLearnerType &other, const size_t iterations=100, const double tolerance=1e-6)
Constructor.
double & Tolerance()
Modify the tolerance for stopping the optimization during training.
Definition: adaboost.hpp:113
double Tolerance() const
Get the tolerance for stopping the optimization during training.
Definition: adaboost.hpp:111
WeakLearnerType & WeakLearner(const size_t i)
Modify the given weak learner (be careful!).
Definition: adaboost.hpp:129
void Classify(const MatType &test, arma::Row< size_t > &predictedLabels, arma::mat &probabilities)
Classify the given test points.
const WeakLearnerType & WeakLearner(const size_t i) const
Get the given weak learner.
Definition: adaboost.hpp:127
double Train(const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses, const WeakLearnerType &learner, const size_t iterations=100, const double tolerance=1e-6)
Train AdaBoost on the given dataset.
void serialize(Archive &ar, const unsigned int)
Serialize the AdaBoost model.
double Alpha(const size_t i) const
Get the weight for the given weak learner.
Definition: adaboost.hpp:122
Set the serialization version of the AdaBoost class.
Definition: adaboost.hpp:198
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.