mlpack 3.4.2
gini_impurity.hpp
Go to the documentation of this file.
1
13#ifndef MLPACK_METHODS_HOEFFDING_TREES_GINI_INDEX_HPP
14#define MLPACK_METHODS_HOEFFDING_TREES_GINI_INDEX_HPP
15
16#include <mlpack/prereqs.hpp>
17
18namespace mlpack {
19namespace tree {
20
22{
23 public:
24 static double Evaluate(const arma::Mat<size_t>& counts)
25 {
26 // We need to sum over the difference between the un-split node and the
27 // split nodes. First we'll calculate the number of elements in each split
28 // and total.
29 size_t numElem = 0;
30 arma::vec splitCounts(counts.n_cols);
31 for (size_t i = 0; i < counts.n_cols; ++i)
32 {
33 splitCounts[i] = arma::accu(counts.col(i));
34 numElem += splitCounts[i];
35 }
36
37 // Corner case: if there are no elements, the impurity is zero.
38 if (numElem == 0)
39 return 0.0;
40
41 arma::Col<size_t> classCounts = arma::sum(counts, 1);
42
43 // Calculate the Gini impurity of the un-split node.
44 double impurity = 0.0;
45 for (size_t i = 0; i < classCounts.n_elem; ++i)
46 {
47 const double f = ((double) classCounts[i] / (double) numElem);
48 impurity += f * (1.0 - f);
49 }
50
51 // Now calculate the impurity of the split nodes and subtract them from the
52 // overall impurity.
53 for (size_t i = 0; i < counts.n_cols; ++i)
54 {
55 if (splitCounts[i] > 0)
56 {
57 double splitImpurity = 0.0;
58 for (size_t j = 0; j < counts.n_rows; ++j)
59 {
60 const double f = ((double) counts(j, i) / (double) splitCounts[i]);
61 splitImpurity += f * (1.0 - f);
62 }
63
64 impurity -= ((double) splitCounts[i] / (double) numElem) *
65 splitImpurity;
66 }
67 }
68
69 return impurity;
70 }
71
/**
 * Return the range of the Gini impurity for the given number of classes.
 *
 * @param numClasses Number of classes in the problem.
 * @return 1 - 1 / numClasses, or 0.0 when numClasses is 0.
 */
static double Range(const size_t numClasses)
{
  // With no classes there is no impurity at all; without this guard the
  // division below would be 1.0 / 0.0 and the result negative infinity.
  if (numClasses == 0)
    return 0.0;

  // The best possible case is that only one class exists, which gives a Gini
  // impurity of 0.  The worst possible case is that the classes are evenly
  // distributed, which gives n * (1/n * (1 - 1/n)) = 1 - 1/n.
  return 1.0 - (1.0 / static_cast<double>(numClasses));
}
84};
85
86} // namespace tree
87} // namespace mlpack
88
89#endif
static double Evaluate(const arma::Mat< size_t > &counts)
static double Range(const size_t numClasses)
Return the range of the Gini impurity for the given number of classes.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.