mlpack 3.4.2
random_forest.hpp
Go to the documentation of this file.
1
12#ifndef MLPACK_METHODS_RANDOM_FOREST_RANDOM_FOREST_HPP
13#define MLPACK_METHODS_RANDOM_FOREST_RANDOM_FOREST_HPP
14
17#include "bootstrap.hpp"
18
19namespace mlpack {
20namespace tree {
21
// RandomForest: a forest of DecisionTree classifiers (stored in `trees`).
//
// Template parameters (all forwarded to the DecisionTreeType typedef below):
//   FitnessFunction        - defaults to GiniGain.
//   DimensionSelectionType - defaults to MultipleRandomDimensionSelect.
//   NumericSplitType       - defaults to BestBinaryNumericSplit.
//   CategoricalSplitType   - defaults to AllCategoricalSplit.
//   ElemType               - defaults to double.
//
// NOTE(review): this is an extracted source listing with the original file's
// line numbers fused onto each line.  Listing line 27 (presumably the
// `class RandomForest` class-head) and the default constructor declaration
// are not present in this listing -- confirm against the real header.
22template<typename FitnessFunction = GiniGain,
23 typename DimensionSelectionType = MultipleRandomDimensionSelect,
24 template<typename> class NumericSplitType = BestBinaryNumericSplit,
25 template<typename> class CategoricalSplitType = AllCategoricalSplit,
26 typename ElemType = double>
28{
29 public:
 // Allow access to the underlying decision tree type.
31 typedef DecisionTree<FitnessFunction, NumericSplitType, CategoricalSplitType,
32 DimensionSelectionType, ElemType> DecisionTreeType;
33
 // NOTE(review): the default constructor ("construct the random forest without
 // any training or specifying the number of trees") is expected around here
 // but its declaration is missing from this listing.
39
 // Create a random forest, training on the given labeled training data with
 // the given number of trees.
 //
 // Defaults visible in the signature: numTrees = 20, minimumLeafSize = 1,
 // minimumGainSplit = 1e-7, maximumDepth = 0, and a default-constructed
 // dimensionSelector (taken by value).
56 template<typename MatType>
57 RandomForest(const MatType& dataset,
58 const arma::Row<size_t>& labels,
59 const size_t numClasses,
60 const size_t numTrees = 20,
61 const size_t minimumLeafSize = 1,
62 const double minimumGainSplit = 1e-7,
63 const size_t maximumDepth = 0,
64 DimensionSelectionType dimensionSelector =
65 DimensionSelectionType());
66
 // Create a random forest, training on the given labeled training data with
 // the given dataset info.  Same parameters and defaults as the overload
 // above, plus `datasetInfo` as the second argument.
85 template<typename MatType>
86 RandomForest(const MatType& dataset,
87 const data::DatasetInfo& datasetInfo,
88 const arma::Row<size_t>& labels,
89 const size_t numClasses,
90 const size_t numTrees = 20,
91 const size_t minimumLeafSize = 1,
92 const double minimumGainSplit = 1e-7,
93 const size_t maximumDepth = 0,
94 DimensionSelectionType dimensionSelector =
95 DimensionSelectionType());
96
 // Create a random forest, training on the given weighted labeled training
 // data.  `weights` is a required argument placed after `numClasses`; the
 // remaining parameters keep the same defaults as the unweighted overloads.
112 template<typename MatType>
113 RandomForest(const MatType& dataset,
114 const arma::Row<size_t>& labels,
115 const size_t numClasses,
116 const arma::rowvec& weights,
117 const size_t numTrees = 20,
118 const size_t minimumLeafSize = 1,
119 const double minimumGainSplit = 1e-7,
120 const size_t maximumDepth = 0,
121 DimensionSelectionType dimensionSelector =
122 DimensionSelectionType());
123
 // Create a random forest, training on the given weighted labeled training
 // data with the given dataset info (combines the `datasetInfo` and `weights`
 // arguments of the two overloads above).
143 template<typename MatType>
144 RandomForest(const MatType& dataset,
145 const data::DatasetInfo& datasetInfo,
146 const arma::Row<size_t>& labels,
147 const size_t numClasses,
148 const arma::rowvec& weights,
149 const size_t numTrees = 20,
150 const size_t minimumLeafSize = 1,
151 const double minimumGainSplit = 1e-7,
152 const size_t maximumDepth = 0,
153 DimensionSelectionType dimensionSelector =
154 DimensionSelectionType());
155
 // Train the random forest on the given labeled training data with the given
 // number of trees.
 //
 // Returns a double; its meaning is not stated in this listing -- TODO confirm
 // against the implementation in random_forest_impl.hpp.
173 template<typename MatType>
174 double Train(const MatType& data,
175 const arma::Row<size_t>& labels,
176 const size_t numClasses,
177 const size_t numTrees = 20,
178 const size_t minimumLeafSize = 1,
179 const double minimumGainSplit = 1e-7,
180 const size_t maximumDepth = 0,
181 DimensionSelectionType dimensionSelector =
182 DimensionSelectionType());
183
 // Train the random forest on the given labeled training data with the given
 // dataset info.  Same defaults and return type as the overload above.
204 template<typename MatType>
205 double Train(const MatType& data,
206 const data::DatasetInfo& datasetInfo,
207 const arma::Row<size_t>& labels,
208 const size_t numClasses,
209 const size_t numTrees = 20,
210 const size_t minimumLeafSize = 1,
211 const double minimumGainSplit = 1e-7,
212 const size_t maximumDepth = 0,
213 DimensionSelectionType dimensionSelector =
214 DimensionSelectionType());
215
 // Train the random forest on the given weighted labeled training data with
 // the given number of trees.  `weights` is required and follows `numClasses`.
234 template<typename MatType>
235 double Train(const MatType& data,
236 const arma::Row<size_t>& labels,
237 const size_t numClasses,
238 const arma::rowvec& weights,
239 const size_t numTrees = 20,
240 const size_t minimumLeafSize = 1,
241 const double minimumGainSplit = 1e-7,
242 const size_t maximumDepth = 0,
243 DimensionSelectionType dimensionSelector =
244 DimensionSelectionType());
245
 // Train the random forest on the given weighted labeled training data with
 // the given dataset info.
266 template<typename MatType>
267 double Train(const MatType& data,
268 const data::DatasetInfo& datasetInfo,
269 const arma::Row<size_t>& labels,
270 const size_t numClasses,
271 const arma::rowvec& weights,
272 const size_t numTrees = 20,
273 const size_t minimumLeafSize = 1,
274 const double minimumGainSplit = 1e-7,
275 const size_t maximumDepth = 0,
276 DimensionSelectionType dimensionSelector =
277 DimensionSelectionType());
278
 // Predict the class of the given point; the predicted class is the return
 // value.
285 template<typename VecType>
286 size_t Classify(const VecType& point) const;
287
 // Predict the class of the given point and return the predicted class
 // probabilities for each class.  Results are written to the `prediction`
 // and `probabilities` output references.
297 template<typename VecType>
298 void Classify(const VecType& point,
299 size_t& prediction,
300 arma::vec& probabilities) const;
301
 // Predict the classes of each point in the given dataset; results are
 // written to `predictions`.
309 template<typename MatType>
310 void Classify(const MatType& data,
311 arma::Row<size_t>& predictions) const;
312
 // Predict the classes of each point in the given dataset, also returning the
 // predicted class probabilities via `probabilities`.
322 template<typename MatType>
323 void Classify(const MatType& data,
324 arma::Row<size_t>& predictions,
325 arma::mat& probabilities) const;
326
 // Access a tree in the forest.  Indexes `trees` with operator[] directly, so
 // no range check on `i` is performed here.
328 const DecisionTreeType& Tree(const size_t i) const { return trees[i]; }
 // Modify a tree in the forest (be careful!).  Same unchecked indexing.
330 DecisionTreeType& Tree(const size_t i) { return trees[i]; }
331
 // Get the number of trees in the forest.
333 size_t NumTrees() const { return trees.size(); }
334
 // Serialize the random forest.  The version argument is unnamed (unused).
338 template<typename Archive>
339 void serialize(Archive& ar, const unsigned int /* version */);
340
341 private:
 // Private Train() overload that takes every argument with no defaults, and
 // takes `dimensionSelector` by non-const reference (the public overloads take
 // it by value).  UseWeights / UseDatasetInfo are compile-time flags;
 // presumably they select whether `weights` / `datasetInfo` are honored and
 // the public overloads forward here -- TODO confirm in
 // random_forest_impl.hpp.
362 template<bool UseWeights, bool UseDatasetInfo, typename MatType>
363 double Train(const MatType& data,
364 const data::DatasetInfo& datasetInfo,
365 const arma::Row<size_t>& labels,
366 const size_t numClasses,
367 const arma::rowvec& weights,
368 const size_t numTrees,
369 const size_t minimumLeafSize,
370 const double minimumGainSplit,
371 const size_t maximumDepth,
372 DimensionSelectionType& dimensionSelector);
373
 // The trees in the forest (exposed through Tree() and NumTrees()).
375 std::vector<DecisionTreeType> trees;
376};
377
378} // namespace tree
379} // namespace mlpack
380
381// Include implementation.
382#include "random_forest_impl.hpp"
383
384#endif
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the datatype of each dimension.
This class implements a generic decision tree learner.
void Classify(const MatType &data, arma::Row< size_t > &predictions, arma::mat &probabilities) const
Predict the classes of each point in the given dataset, also returning the predicted class probabilities for each point.
const DecisionTreeType & Tree(const size_t i) const
Access a tree in the forest.
double Train(const MatType &data, const data::DatasetInfo &datasetInfo, const arma::Row< size_t > &labels, const size_t numClasses, const size_t numTrees=20, const size_t minimumLeafSize=1, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Train the random forest on the given labeled training data with the given dataset info and the given number of trees.
RandomForest()
Construct the random forest without any training or specifying the number of trees.
double Train(const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses, const arma::rowvec &weights, const size_t numTrees=20, const size_t minimumLeafSize=1, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Train the random forest on the given weighted labeled training data with the given number of trees.
double Train(const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses, const size_t numTrees=20, const size_t minimumLeafSize=1, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Train the random forest on the given labeled training data with the given number of trees.
DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, DimensionSelectionType, ElemType > DecisionTreeType
Allow access to the underlying decision tree type.
DecisionTreeType & Tree(const size_t i)
Modify a tree in the forest (be careful!).
size_t Classify(const VecType &point) const
Predict the class of the given point.
RandomForest(const MatType &dataset, const data::DatasetInfo &datasetInfo, const arma::Row< size_t > &labels, const size_t numClasses, const arma::rowvec &weights, const size_t numTrees=20, const size_t minimumLeafSize=1, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Create a random forest, training on the given weighted labeled training data with the given dataset info and the given number of trees.
void Classify(const MatType &data, arma::Row< size_t > &predictions) const
Predict the classes of each point in the given dataset.
void Classify(const VecType &point, size_t &prediction, arma::vec &probabilities) const
Predict the class of the given point and return the predicted class probabilities for each class.
size_t NumTrees() const
Get the number of trees in the forest.
RandomForest(const MatType &dataset, const arma::Row< size_t > &labels, const size_t numClasses, const arma::rowvec &weights, const size_t numTrees=20, const size_t minimumLeafSize=1, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Create a random forest, training on the given weighted labeled training data with the given number of trees.
double Train(const MatType &data, const data::DatasetInfo &datasetInfo, const arma::Row< size_t > &labels, const size_t numClasses, const arma::rowvec &weights, const size_t numTrees=20, const size_t minimumLeafSize=1, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Train the random forest on the given weighted labeled training data with the given dataset info and the given number of trees.
void serialize(Archive &ar, const unsigned int)
Serialize the random forest.
RandomForest(const MatType &dataset, const data::DatasetInfo &datasetInfo, const arma::Row< size_t > &labels, const size_t numClasses, const size_t numTrees=20, const size_t minimumLeafSize=1, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Create a random forest, training on the given labeled training data with the given dataset info and the given number of trees.
RandomForest(const MatType &dataset, const arma::Row< size_t > &labels, const size_t numClasses, const size_t numTrees=20, const size_t minimumLeafSize=1, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Create a random forest, training on the given labeled training data with the given number of trees.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1