13#ifndef MLPACK_METHODS_DECISION_TREE_DECISION_TREE_HPP
14#define MLPACK_METHODS_DECISION_TREE_DECISION_TREE_HPP
34template<
typename FitnessFunction = GiniGain,
35 template<
typename>
class NumericSplitType = BestBinaryNumericSplit,
36 template<
typename>
class CategoricalSplitType = AllCategoricalSplit,
37 typename DimensionSelectionType = AllDimensionSelect,
38 typename ElemType = double,
39 bool NoRecursion =
false>
41 public NumericSplitType<FitnessFunction>::template
42 AuxiliarySplitInfo<ElemType>,
43 public CategoricalSplitType<FitnessFunction>::template
44 AuxiliarySplitInfo<ElemType>
71 template<
typename MatType,
typename LabelsType>
75 const size_t numClasses,
76 const size_t minimumLeafSize = 10,
77 const double minimumGainSplit = 1e-7,
78 const size_t maximumDepth = 0,
79 DimensionSelectionType dimensionSelector =
80 DimensionSelectionType());
98 template<
typename MatType,
typename LabelsType>
101 const size_t numClasses,
102 const size_t minimumLeafSize = 10,
103 const double minimumGainSplit = 1e-7,
104 const size_t maximumDepth = 0,
105 DimensionSelectionType dimensionSelector =
106 DimensionSelectionType());
127 template<
typename MatType,
typename LabelsType,
typename WeightsType>
132 const size_t numClasses,
134 const size_t minimumLeafSize = 10,
135 const double minimumGainSplit = 1e-7,
136 const size_t maximumDepth = 0,
137 DimensionSelectionType dimensionSelector = DimensionSelectionType(),
139 typename std::remove_reference<WeightsType>::type>::value>* = 0);
159 template<
typename MatType,
typename LabelsType,
typename WeightsType>
165 const size_t numClasses,
167 const size_t minimumLeafSize = 10,
168 const double minimumGainSplit = 1e-7,
170 typename std::remove_reference<WeightsType>::type>::value>* = 0);
189 template<
typename MatType,
typename LabelsType,
typename WeightsType>
193 const size_t numClasses,
195 const size_t minimumLeafSize = 10,
196 const double minimumGainSplit = 1e-7,
197 const size_t maximumDepth = 0,
198 DimensionSelectionType dimensionSelector = DimensionSelectionType(),
200 typename std::remove_reference<WeightsType>::type>::value>* = 0);
220 template<
typename MatType,
typename LabelsType,
typename WeightsType>
225 const size_t numClasses,
227 const size_t minimumLeafSize = 10,
228 const double minimumGainSplit = 1e-7,
229 const size_t maximumDepth = 0,
230 DimensionSelectionType dimensionSelector = DimensionSelectionType(),
232 typename std::remove_reference<WeightsType>::type>::value>* = 0);
296 template<
typename MatType,
typename LabelsType>
300 const size_t numClasses,
301 const size_t minimumLeafSize = 10,
302 const double minimumGainSplit = 1e-7,
303 const size_t maximumDepth = 0,
304 DimensionSelectionType dimensionSelector =
305 DimensionSelectionType());
324 template<
typename MatType,
typename LabelsType>
327 const size_t numClasses,
328 const size_t minimumLeafSize = 10,
329 const double minimumGainSplit = 1e-7,
330 const size_t maximumDepth = 0,
331 DimensionSelectionType dimensionSelector =
332 DimensionSelectionType());
355 template<
typename MatType,
typename LabelsType,
typename WeightsType>
359 const size_t numClasses,
361 const size_t minimumLeafSize = 10,
362 const double minimumGainSplit = 1e-7,
363 const size_t maximumDepth = 0,
364 DimensionSelectionType dimensionSelector =
365 DimensionSelectionType(),
367 std::remove_reference<WeightsType>::type>::value>* = 0);
388 template<
typename MatType,
typename LabelsType,
typename WeightsType>
391 const size_t numClasses,
393 const size_t minimumLeafSize = 10,
394 const double minimumGainSplit = 1e-7,
395 const size_t maximumDepth = 0,
396 DimensionSelectionType dimensionSelector =
397 DimensionSelectionType(),
399 std::remove_reference<WeightsType>::type>::value>* = 0);
407 template<
typename VecType>
419 template<
typename VecType>
422 arma::vec& probabilities)
const;
431 template<
typename MatType>
433 arma::Row<size_t>& predictions)
const;
445 template<
typename MatType>
447 arma::Row<size_t>& predictions,
448 arma::mat& probabilities)
const;
453 template<
typename Archive>
475 template<
typename VecType>
// Per-node state of the decision tree. The leading 485/487/490/498 numbers
// are line-number residue from the listing this text was extracted from; they
// are not C++ tokens.

// Pointers to the child nodes of this node.
// NOTE(review): these are raw pointers; presumably the node owns them and
// releases them in ~DecisionTree() -- confirm against the implementation.
485 std::vector<DecisionTree*> children;
// The dimension this node splits on.
// NOTE(review): the SplitDimension() accessor is documented as only meaningful
// for a non-leaf node of a trained tree; assuming that accessor returns this
// field, the same caveat applies here -- confirm.
487 size_t splitDimension;
// NOTE(review): going by the name, this single field stores either the type of
// the split dimension or the majority class, presumably depending on whether
// the node is a leaf -- confirm against the implementation before relying on it.
490 size_t dimensionTypeOrMajorityClass;
// Vector of class probabilities held by this node.
// NOTE(review): presumably filled by CalculateClassProbabilities() and read by
// the Classify() overloads that return probabilities -- confirm.
498 arma::vec classProbabilities;
// Shorthand names for the AuxiliarySplitInfo member templates of the numeric
// and categorical split types, each instantiated for ElemType.
// NOTE(review): with NumericSplit = NumericSplitType<FitnessFunction> and
// CategoricalSplit = CategoricalSplitType<FitnessFunction>, these name the
// same two types that appear as public bases in the class header -- confirm
// they are used to refer to those base sub-objects.
// (The leading 503-506 numbers are listing residue, not C++ tokens.)
503 typedef typename NumericSplit::template AuxiliarySplitInfo<ElemType>
504 NumericAuxiliarySplitInfo;
505 typedef typename CategoricalSplit::template AuxiliarySplitInfo<ElemType>
506 CategoricalAuxiliarySplitInfo;
// Declaration of CalculateClassProbabilities(); the definition is not part of
// this listing.
//
// Template parameters:
//   UseWeights     - compile-time flag.
//   RowType        - type of the labels argument.
//   WeightsRowType - type of the weights argument.
// Parameters:
//   labels     - labels, taken by const reference.
//   numClasses - number of classes.
//   weights    - weights, taken by const reference.
// Returns nothing.
//
// NOTE(review): presumably this computes the per-class probabilities for the
// node from the labels (weighting them when UseWeights is true) and stores the
// result in the classProbabilities member -- confirm against the definition.
// (The leading 511-514 numbers are listing residue, not C++ tokens.)
511 template<
bool UseWeights,
typename RowType,
typename WeightsRowType>
512 void CalculateClassProbabilities(
const RowType& labels,
513 const size_t numClasses,
514 const WeightsRowType& weights);
// Declaration of a Train() overload templated on a compile-time UseWeights
// flag and on the matrix type; it returns a double.
//
// NOTE(review): the embedded listing numbers jump from 534 to 538, so the
// parameter lines between `data` and `labels` were dropped by the extraction.
// What is shown below is NOT the complete signature; consult the real header
// before editing or calling this overload.
//
// Visible parameters: the data matrix, the labels (arma::Row<size_t>) and the
// weights (arma::rowvec) are all taken by non-const reference, followed by
// numClasses, minimumLeafSize, minimumGainSplit, maximumDepth, and the
// dimension selector (also by non-const reference).
// NOTE(review): the non-const references suggest the arguments may be modified
// in place during training -- not verifiable from this listing; confirm.
533 template<
bool UseWeights,
typename MatType>
534 double Train(MatType& data,
538 arma::Row<size_t>& labels,
539 const size_t numClasses,
540 arma::rowvec& weights,
541 const size_t minimumLeafSize,
542 const double minimumGainSplit,
543 const size_t maximumDepth,
544 DimensionSelectionType& dimensionSelector);
// Declaration of a second Train() overload templated on a compile-time
// UseWeights flag and on the matrix type; it returns a double.
//
// NOTE(review): the embedded listing numbers jump from 563 to 566, so the
// parameter lines between `data` and `labels` were dropped by the extraction.
// As shown, this declaration is textually identical to the overload declared
// at listing line 534; the parameters that distinguish the two are among the
// missing lines. Consult the real header before editing.
//
// Visible parameters: the data matrix, the labels (arma::Row<size_t>) and the
// weights (arma::rowvec) are all taken by non-const reference, followed by
// numClasses, minimumLeafSize, minimumGainSplit, maximumDepth, and the
// dimension selector (also by non-const reference).
562 template<
bool UseWeights,
typename MatType>
563 double Train(MatType& data,
566 arma::Row<size_t>& labels,
567 const size_t numClasses,
568 arma::rowvec& weights,
569 const size_t minimumLeafSize,
570 const double minimumGainSplit,
571 const size_t maximumDepth,
572 DimensionSelectionType& dimensionSelector);
578template<
typename FitnessFunction =
GiniGain,
582 typename ElemType =
double>
585 CategoricalSplitType,
604#include "decision_tree_impl.hpp"
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the data...
The AllCategoricalSplit is a splitting function that will split categorical features into many childr...
This dimension selection policy allows any dimension to be selected for splitting.
The BestBinaryNumericSplit is a splitting function for decision trees that will exhaustively search a...
This class implements a generic decision tree learner.
DecisionTree(const DecisionTree &other, MatType data, LabelsType labels, const size_t numClasses, WeightsType weights, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType(), const std::enable_if_t< arma::is_arma_type< typename std::remove_reference< WeightsType >::type >::value > *=0)
Construct a decision tree from another one (passed by const reference, so it is only read, not taken over) and train it on the given data and labels with weights,...
size_t NumClasses() const
Get the number of classes in the tree.
const DecisionTree & Child(const size_t i) const
Get the child of the given index.
size_t NumChildren() const
Get the number of children.
DecisionTree(DecisionTree &&other)
Take ownership of another tree.
void Classify(const MatType &data, arma::Row< size_t > &predictions, arma::mat &probabilities) const
Classify the given points and also return estimates of the probabilities for each class in the given ...
DecisionTree(const DecisionTree &other, MatType data, const data::DatasetInfo &datasetInfo, LabelsType labels, const size_t numClasses, WeightsType weights, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const std::enable_if_t< arma::is_arma_type< typename std::remove_reference< WeightsType >::type >::value > *=0)
Construct a decision tree from another one (passed by const reference, so it is only read, not taken over) and train it on the given data and labels with weights,...
DecisionTree(MatType data, LabelsType labels, const size_t numClasses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Construct the decision tree on the given data and labels, assuming that the data is all of the numeri...
~DecisionTree()
Clean up memory.
DecisionTree(MatType data, const data::DatasetInfo &datasetInfo, LabelsType labels, const size_t numClasses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Construct the decision tree on the given data and labels, where the data can be both numeric and cate...
DecisionTree & Child(const size_t i)
Modify the child of the given index (be careful!).
DecisionTree(const DecisionTree &other)
Copy another tree.
DecisionTree(MatType data, LabelsType labels, const size_t numClasses, WeightsType weights, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType(), const std::enable_if_t< arma::is_arma_type< typename std::remove_reference< WeightsType >::type >::value > *=0)
Construct the decision tree on the given data and labels with weights, assuming that the data is all ...
double Train(MatType data, LabelsType labels, const size_t numClasses, WeightsType weights, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType(), const std::enable_if_t< arma::is_arma_type< typename std::remove_reference< WeightsType >::type >::value > *=0)
Train the decision tree on the given weighted data, assuming that all dimensions are numeric.
DimensionSelectionType DimensionSelection
Allow access to the dimension selection type.
size_t SplitDimension() const
Get the split dimension (only meaningful if this is a non-leaf in a trained tree).
DecisionTree(const size_t numClasses=1)
Construct a decision tree without training it.
size_t CalculateDirection(const VecType &point) const
Given a point and that this node is not a leaf, calculate the index of the child node this point woul...
size_t Classify(const VecType &point) const
Classify the given point, using the entire tree.
NumericSplitType< FitnessFunction > NumericSplit
Allow access to the numeric split type.
DecisionTree(MatType data, const data::DatasetInfo &datasetInfo, LabelsType labels, const size_t numClasses, WeightsType weights, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType(), const std::enable_if_t< arma::is_arma_type< typename std::remove_reference< WeightsType >::type >::value > *=0)
Construct the decision tree on the given data and labels with weights, where the data can be both num...
void Classify(const MatType &data, arma::Row< size_t > &predictions) const
Classify the given points, using the entire tree.
void Classify(const VecType &point, size_t &prediction, arma::vec &probabilities) const
Classify the given point and also return estimates of the probability for each class in the given vec...
double Train(MatType data, const data::DatasetInfo &datasetInfo, LabelsType labels, const size_t numClasses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Train the decision tree on the given data.
double Train(MatType data, const data::DatasetInfo &datasetInfo, LabelsType labels, const size_t numClasses, WeightsType weights, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType(), const std::enable_if_t< arma::is_arma_type< typename std::remove_reference< WeightsType >::type >::value > *=0)
Train the decision tree on the given weighted data.
CategoricalSplitType< FitnessFunction > CategoricalSplit
Allow access to the categorical split type.
DecisionTree & operator=(DecisionTree &&other)
Take ownership of another tree.
DecisionTree & operator=(const DecisionTree &other)
Copy another tree.
void serialize(Archive &ar, const unsigned int)
Serialize the tree.
double Train(MatType data, LabelsType labels, const size_t numClasses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Train the decision tree on the given data, assuming that all dimensions are numeric.
The Gini gain, a measure of set purity usable as a fitness function (FitnessFunction) for decision tr...
DecisionTree< InformationGain, BestBinaryNumericSplit, AllCategoricalSplit, AllDimensionSelect, double, true > ID3DecisionStump
Convenience typedef for ID3 decision stumps (single level decision trees made with the ID3 algorithm)...
Linear algebra utility functions, generally performed on matrices or vectors.
typename enable_if< B, T >::type enable_if_t
The core includes that mlpack expects; standard C++ includes and Armadillo.