mlpack 3.4.2
information_gain.hpp
Go to the documentation of this file.
1
13#ifndef MLPACK_METHODS_HOEFFDING_TREES_INFORMATION_GAIN_HPP
14#define MLPACK_METHODS_HOEFFDING_TREES_INFORMATION_GAIN_HPP
15
16namespace mlpack {
17namespace tree {
18
20{
21 public:
31 static double Evaluate(const arma::Mat<size_t>& counts)
32 {
33 // Calculate the number of elements in the unsplit node and also in each
34 // proposed child.
35 size_t numElem = 0;
36 arma::vec splitCounts(counts.n_elem);
37 for (size_t i = 0; i < counts.n_cols; ++i)
38 {
39 splitCounts[i] = arma::accu(counts.col(i));
40 numElem += splitCounts[i];
41 }
42
43 // Corner case: if there are no elements, the gain is zero.
44 if (numElem == 0)
45 return 0.0;
46
47 arma::Col<size_t> classCounts = arma::sum(counts, 1);
48
49 // Calculate the gain of the unsplit node.
50 double gain = 0.0;
51 for (size_t i = 0; i < classCounts.n_elem; ++i)
52 {
53 const double f = ((double) classCounts[i] / (double) numElem);
54 if (f > 0.0)
55 gain -= f * std::log2(f);
56 }
57
58 // Now calculate the impurity of the split nodes and subtract them from the
59 // overall gain.
60 for (size_t i = 0; i < counts.n_cols; ++i)
61 {
62 if (splitCounts[i] > 0)
63 {
64 double splitGain = 0.0;
65 for (size_t j = 0; j < counts.n_rows; ++j)
66 {
67 const double f = ((double) counts(j, i) / (double) splitCounts[i]);
68 if (f > 0.0)
69 splitGain += f * std::log2(f);
70 }
71
72 gain += ((double) splitCounts[i] / (double) numElem) * splitGain;
73 }
74 }
75
76 return gain;
77 }
78
84 static double Range(const size_t numClasses)
85 {
86 // The best possible case gives an information gain of 0. The worst
87 // possible case is even distribution, which gives n * (1/n * log2(1/n)) =
88 // log2(1/n) = -log2(n). So, the range is log2(n).
89 return std::log2(numClasses);
90 }
91};
92
93} // namespace tree
94} // namespace mlpack
95
96#endif
static double Evaluate(const arma::Mat< size_t > &counts)
Given the sufficient statistics of a proposed split, calculate the information gain if that split was to be used.
static double Range(const size_t numClasses)
Return the range of the information gain for the given number of classes.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1