mlpack 3.4.2
decision_tree.hpp
Go to the documentation of this file.
1
13#ifndef MLPACK_METHODS_DECISION_TREE_DECISION_TREE_HPP
14#define MLPACK_METHODS_DECISION_TREE_DECISION_TREE_HPP
15
16#include <mlpack/prereqs.hpp>
17#include "gini_gain.hpp"
18#include "information_gain.hpp"
22#include <type_traits>
23
24namespace mlpack {
25namespace tree {
26
34template<typename FitnessFunction = GiniGain,
35 template<typename> class NumericSplitType = BestBinaryNumericSplit,
36 template<typename> class CategoricalSplitType = AllCategoricalSplit,
37 typename DimensionSelectionType = AllDimensionSelect,
38 typename ElemType = double,
39 bool NoRecursion = false>
41 public NumericSplitType<FitnessFunction>::template
42 AuxiliarySplitInfo<ElemType>,
43 public CategoricalSplitType<FitnessFunction>::template
44 AuxiliarySplitInfo<ElemType>
45{
46 public:
48 typedef NumericSplitType<FitnessFunction> NumericSplit;
50 typedef CategoricalSplitType<FitnessFunction> CategoricalSplit;
52 typedef DimensionSelectionType DimensionSelection;
53
71 template<typename MatType, typename LabelsType>
72 DecisionTree(MatType data,
73 const data::DatasetInfo& datasetInfo,
74 LabelsType labels,
75 const size_t numClasses,
76 const size_t minimumLeafSize = 10,
77 const double minimumGainSplit = 1e-7,
78 const size_t maximumDepth = 0,
79 DimensionSelectionType dimensionSelector =
80 DimensionSelectionType());
81
98 template<typename MatType, typename LabelsType>
99 DecisionTree(MatType data,
100 LabelsType labels,
101 const size_t numClasses,
102 const size_t minimumLeafSize = 10,
103 const double minimumGainSplit = 1e-7,
104 const size_t maximumDepth = 0,
105 DimensionSelectionType dimensionSelector =
106 DimensionSelectionType());
107
127 template<typename MatType, typename LabelsType, typename WeightsType>
129 MatType data,
130 const data::DatasetInfo& datasetInfo,
131 LabelsType labels,
132 const size_t numClasses,
133 WeightsType weights,
134 const size_t minimumLeafSize = 10,
135 const double minimumGainSplit = 1e-7,
136 const size_t maximumDepth = 0,
137 DimensionSelectionType dimensionSelector = DimensionSelectionType(),
138 const std::enable_if_t<arma::is_arma_type<
139 typename std::remove_reference<WeightsType>::type>::value>* = 0);
140
159 template<typename MatType, typename LabelsType, typename WeightsType>
161 const DecisionTree& other,
162 MatType data,
163 const data::DatasetInfo& datasetInfo,
164 LabelsType labels,
165 const size_t numClasses,
166 WeightsType weights,
167 const size_t minimumLeafSize = 10,
168 const double minimumGainSplit = 1e-7,
169 const std::enable_if_t<arma::is_arma_type<
170 typename std::remove_reference<WeightsType>::type>::value>* = 0);
189 template<typename MatType, typename LabelsType, typename WeightsType>
191 MatType data,
192 LabelsType labels,
193 const size_t numClasses,
194 WeightsType weights,
195 const size_t minimumLeafSize = 10,
196 const double minimumGainSplit = 1e-7,
197 const size_t maximumDepth = 0,
198 DimensionSelectionType dimensionSelector = DimensionSelectionType(),
199 const std::enable_if_t<arma::is_arma_type<
200 typename std::remove_reference<WeightsType>::type>::value>* = 0);
201
220 template<typename MatType, typename LabelsType, typename WeightsType>
222 const DecisionTree& other,
223 MatType data,
224 LabelsType labels,
225 const size_t numClasses,
226 WeightsType weights,
227 const size_t minimumLeafSize = 10,
228 const double minimumGainSplit = 1e-7,
229 const size_t maximumDepth = 0,
230 DimensionSelectionType dimensionSelector = DimensionSelectionType(),
231 const std::enable_if_t<arma::is_arma_type<
232 typename std::remove_reference<WeightsType>::type>::value>* = 0);
233
240 DecisionTree(const size_t numClasses = 1);
241
249
256
264
271
276
296 template<typename MatType, typename LabelsType>
297 double Train(MatType data,
298 const data::DatasetInfo& datasetInfo,
299 LabelsType labels,
300 const size_t numClasses,
301 const size_t minimumLeafSize = 10,
302 const double minimumGainSplit = 1e-7,
303 const size_t maximumDepth = 0,
304 DimensionSelectionType dimensionSelector =
305 DimensionSelectionType());
306
324 template<typename MatType, typename LabelsType>
325 double Train(MatType data,
326 LabelsType labels,
327 const size_t numClasses,
328 const size_t minimumLeafSize = 10,
329 const double minimumGainSplit = 1e-7,
330 const size_t maximumDepth = 0,
331 DimensionSelectionType dimensionSelector =
332 DimensionSelectionType());
333
355 template<typename MatType, typename LabelsType, typename WeightsType>
356 double Train(MatType data,
357 const data::DatasetInfo& datasetInfo,
358 LabelsType labels,
359 const size_t numClasses,
360 WeightsType weights,
361 const size_t minimumLeafSize = 10,
362 const double minimumGainSplit = 1e-7,
363 const size_t maximumDepth = 0,
364 DimensionSelectionType dimensionSelector =
365 DimensionSelectionType(),
366 const std::enable_if_t<arma::is_arma_type<typename
367 std::remove_reference<WeightsType>::type>::value>* = 0);
368
388 template<typename MatType, typename LabelsType, typename WeightsType>
389 double Train(MatType data,
390 LabelsType labels,
391 const size_t numClasses,
392 WeightsType weights,
393 const size_t minimumLeafSize = 10,
394 const double minimumGainSplit = 1e-7,
395 const size_t maximumDepth = 0,
396 DimensionSelectionType dimensionSelector =
397 DimensionSelectionType(),
398 const std::enable_if_t<arma::is_arma_type<typename
399 std::remove_reference<WeightsType>::type>::value>* = 0);
400
407 template<typename VecType>
408 size_t Classify(const VecType& point) const;
409
419 template<typename VecType>
420 void Classify(const VecType& point,
421 size_t& prediction,
422 arma::vec& probabilities) const;
423
431 template<typename MatType>
432 void Classify(const MatType& data,
433 arma::Row<size_t>& predictions) const;
434
445 template<typename MatType>
446 void Classify(const MatType& data,
447 arma::Row<size_t>& predictions,
448 arma::mat& probabilities) const;
449
453 template<typename Archive>
454 void serialize(Archive& ar, const unsigned int /* version */);
455
457 size_t NumChildren() const { return children.size(); }
458
460 const DecisionTree& Child(const size_t i) const { return *children[i]; }
462 DecisionTree& Child(const size_t i) { return *children[i]; }
463
466 size_t SplitDimension() const { return splitDimension; }
467
475 template<typename VecType>
476 size_t CalculateDirection(const VecType& point) const;
477
481 size_t NumClasses() const;
482
483 private:
485 std::vector<DecisionTree*> children;
487 size_t splitDimension;
490 size_t dimensionTypeOrMajorityClass;
498 arma::vec classProbabilities;
499
503 typedef typename NumericSplit::template AuxiliarySplitInfo<ElemType>
504 NumericAuxiliarySplitInfo;
505 typedef typename CategoricalSplit::template AuxiliarySplitInfo<ElemType>
506 CategoricalAuxiliarySplitInfo;
507
511 template<bool UseWeights, typename RowType, typename WeightsRowType>
512 void CalculateClassProbabilities(const RowType& labels,
513 const size_t numClasses,
514 const WeightsRowType& weights);
515
533 template<bool UseWeights, typename MatType>
534 double Train(MatType& data,
535 const size_t begin,
536 const size_t count,
537 const data::DatasetInfo& datasetInfo,
538 arma::Row<size_t>& labels,
539 const size_t numClasses,
540 arma::rowvec& weights,
541 const size_t minimumLeafSize,
542 const double minimumGainSplit,
543 const size_t maximumDepth,
544 DimensionSelectionType& dimensionSelector);
545
562 template<bool UseWeights, typename MatType>
563 double Train(MatType& data,
564 const size_t begin,
565 const size_t count,
566 arma::Row<size_t>& labels,
567 const size_t numClasses,
568 arma::rowvec& weights,
569 const size_t minimumLeafSize,
570 const double minimumGainSplit,
571 const size_t maximumDepth,
572 DimensionSelectionType& dimensionSelector);
573};
574
578template<typename FitnessFunction = GiniGain,
579 template<typename> class NumericSplitType = BestBinaryNumericSplit,
580 template<typename> class CategoricalSplitType = AllCategoricalSplit,
581 typename DimensionSelectType = AllDimensionSelect,
582 typename ElemType = double>
583using DecisionStump = DecisionTree<FitnessFunction,
584 NumericSplitType,
585 CategoricalSplitType,
586 DimensionSelectType,
587 ElemType,
588 false>;
589
598 double,
600} // namespace tree
601} // namespace mlpack
602
603// Include implementation.
604#include "decision_tree_impl.hpp"
605
606#endif
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the data...
The AllCategoricalSplit is a splitting function that will split categorical features into many childr...
This dimension selection policy allows any dimension to be selected for splitting.
The BestBinaryNumericSplit is a splitting function for decision trees that will exhaustively search a...
This class implements a generic decision tree learner.
DecisionTree(const DecisionTree &other, MatType data, LabelsType labels, const size_t numClasses, WeightsType weights, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType(), const std::enable_if_t< arma::is_arma_type< typename std::remove_reference< WeightsType >::type >::value > *=0)
Take ownership of another decision tree and train on the given data and labels with weights,...
size_t NumClasses() const
Get the number of classes in the tree.
const DecisionTree & Child(const size_t i) const
Get the child of the given index.
size_t NumChildren() const
Get the number of children.
DecisionTree(DecisionTree &&other)
Take ownership of another tree.
void Classify(const MatType &data, arma::Row< size_t > &predictions, arma::mat &probabilities) const
Classify the given points and also return estimates of the probabilities for each class in the given ...
DecisionTree(const DecisionTree &other, MatType data, const data::DatasetInfo &datasetInfo, LabelsType labels, const size_t numClasses, WeightsType weights, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const std::enable_if_t< arma::is_arma_type< typename std::remove_reference< WeightsType >::type >::value > *=0)
Take ownership of another decision tree and train on the given data and labels with weights,...
DecisionTree(MatType data, LabelsType labels, const size_t numClasses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Construct the decision tree on the given data and labels, assuming that the data is all of the numeri...
~DecisionTree()
Clean up memory.
DecisionTree(MatType data, const data::DatasetInfo &datasetInfo, LabelsType labels, const size_t numClasses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Construct the decision tree on the given data and labels, where the data can be both numeric and cate...
DecisionTree & Child(const size_t i)
Modify the child of the given index (be careful!).
DecisionTree(const DecisionTree &other)
Copy another tree.
DecisionTree(MatType data, LabelsType labels, const size_t numClasses, WeightsType weights, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType(), const std::enable_if_t< arma::is_arma_type< typename std::remove_reference< WeightsType >::type >::value > *=0)
Construct the decision tree on the given data and labels with weights, assuming that the data is all ...
double Train(MatType data, LabelsType labels, const size_t numClasses, WeightsType weights, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType(), const std::enable_if_t< arma::is_arma_type< typename std::remove_reference< WeightsType >::type >::value > *=0)
Train the decision tree on the given weighted data, assuming that all dimensions are numeric.
DimensionSelectionType DimensionSelection
Allow access to the dimension selection type.
size_t SplitDimension() const
Get the split dimension (only meaningful if this is a non-leaf in a trained tree).
DecisionTree(const size_t numClasses=1)
Construct a decision tree without training it.
size_t CalculateDirection(const VecType &point) const
Given a point and that this node is not a leaf, calculate the index of the child node this point woul...
size_t Classify(const VecType &point) const
Classify the given point, using the entire tree.
NumericSplitType< FitnessFunction > NumericSplit
Allow access to the numeric split type.
DecisionTree(MatType data, const data::DatasetInfo &datasetInfo, LabelsType labels, const size_t numClasses, WeightsType weights, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType(), const std::enable_if_t< arma::is_arma_type< typename std::remove_reference< WeightsType >::type >::value > *=0)
Construct the decision tree on the given data and labels with weights, where the data can be both num...
void Classify(const MatType &data, arma::Row< size_t > &predictions) const
Classify the given points, using the entire tree.
void Classify(const VecType &point, size_t &prediction, arma::vec &probabilities) const
Classify the given point and also return estimates of the probability for each class in the given vec...
double Train(MatType data, const data::DatasetInfo &datasetInfo, LabelsType labels, const size_t numClasses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Train the decision tree on the given data.
double Train(MatType data, const data::DatasetInfo &datasetInfo, LabelsType labels, const size_t numClasses, WeightsType weights, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType(), const std::enable_if_t< arma::is_arma_type< typename std::remove_reference< WeightsType >::type >::value > *=0)
Train the decision tree on the given weighted data.
CategoricalSplitType< FitnessFunction > CategoricalSplit
Allow access to the categorical split type.
DecisionTree & operator=(DecisionTree &&other)
Take ownership of another tree.
DecisionTree & operator=(const DecisionTree &other)
Copy another tree.
void serialize(Archive &ar, const unsigned int)
Serialize the tree.
double Train(MatType data, LabelsType labels, const size_t numClasses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Train the decision tree on the given data, assuming that all dimensions are numeric.
The Gini gain, a measure of set purity usable as a fitness function (FitnessFunction) for decision tr...
Definition: gini_gain.hpp:28
The standard information gain criterion, used for calculating gain in decision trees.
DecisionTree< InformationGain, BestBinaryNumericSplit, AllCategoricalSplit, AllDimensionSelect, double, true > ID3DecisionStump
Convenience typedef for ID3 decision stumps (single level decision trees made with the ID3 algorithm)...
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
typename enable_if< B, T >::type enable_if_t
Definition: prereqs.hpp:70
The core includes that mlpack expects; standard C++ includes and Armadillo.