mlpack 3.4.2
information_gain.hpp
Go to the documentation of this file.
1
13#ifndef MLPACK_METHODS_DECISION_TREE_INFORMATION_GAIN_HPP
14#define MLPACK_METHODS_DECISION_TREE_INFORMATION_GAIN_HPP
15
16#include <mlpack/prereqs.hpp>
17
18namespace mlpack {
19namespace tree {
20
26{
27 public:
/**
 * Compute the information gain (the negated base-2 Shannon entropy) of a class
 * distribution supplied as an array of per-class counts.
 *
 * @tparam UseWeights Not referenced by the body; kept so the signature stays
 *     compatible with existing callers.
 * @tparam CountType Arithmetic type of the counts (converted to double).
 * @param counts Pointer to `countLength` per-class counts.
 * @param countLength Number of entries readable through `counts`.
 * @param totalCount Normalizer that each count is divided by.
 * @return The sum over classes of f * log2(f), where f = counts[i] /
 *     totalCount; terms with f <= 0 are skipped.  Returns 0.0 when
 *     `totalCount` is zero.
 */
template<bool UseWeights, typename CountType>
static double EvaluatePtr(const CountType* counts,
                          const size_t countLength,
                          const CountType totalCount)
{
  const double total = static_cast<double>(totalCount);

  // Corner case: with nothing to normalize by, the gain is zero.  Without
  // this guard a zero total combined with a nonzero count would yield
  // f = +inf and therefore an (impossible) positive infinite gain.
  if (total == 0.0)
    return 0.0;

  double gain = 0.0;
  for (size_t i = 0; i < countLength; ++i)
  {
    const double f = static_cast<double>(counts[i]) / total;

    // Empty classes contribute nothing; this also avoids evaluating log2(0).
    if (f > 0.0)
      gain += f * std::log2(f);
  }

  return gain;
}
47
59 template<bool UseWeights>
60 static double Evaluate(const arma::Row<size_t>& labels,
61 const size_t numClasses,
62 const arma::Row<double>& weights)
63 {
64 // Edge case: if there are no elements, the gain is zero.
65 if (labels.n_elem == 0)
66 return 0.0;
67
68 // Calculate the information gain.
69 double gain = 0.0;
70
71 // Count the number of elements in each class. Use four auxiliary vectors
72 // to exploit SIMD instructions if possible.
73 arma::vec countSpace(4 * numClasses, arma::fill::zeros);
74 arma::vec counts(countSpace.memptr(), numClasses, false, true);
75 arma::vec counts2(countSpace.memptr() + numClasses, numClasses, false,
76 true);
77 arma::vec counts3(countSpace.memptr() + 2 * numClasses, numClasses, false,
78 true);
79 arma::vec counts4(countSpace.memptr() + 3 * numClasses, numClasses, false,
80 true);
81
82 if (UseWeights)
83 {
84 // Sum all the weights up.
85 double accWeights[4] = { 0.0, 0.0, 0.0, 0.0 };
86
87 // SIMD loop: add counts for four elements simultaneously (if the compiler
88 // manages to vectorize the loop).
89 for (size_t i = 3; i < labels.n_elem; i += 4)
90 {
91 const double weight1 = weights[i - 3];
92 const double weight2 = weights[i - 2];
93 const double weight3 = weights[i - 1];
94 const double weight4 = weights[i];
95
96 counts[labels[i - 3]] += weight1;
97 counts2[labels[i - 2]] += weight2;
98 counts3[labels[i - 1]] += weight3;
99 counts4[labels[i]] += weight4;
100
101 accWeights[0] += weight1;
102 accWeights[1] += weight2;
103 accWeights[2] += weight3;
104 accWeights[3] += weight4;
105 }
106
107 // Handle leftovers.
108 if (labels.n_elem % 4 == 1)
109 {
110 const double weight1 = weights[labels.n_elem - 1];
111 counts[labels[labels.n_elem - 1]] += weight1;
112 accWeights[0] += weight1;
113 }
114 else if (labels.n_elem % 4 == 2)
115 {
116 const double weight1 = weights[labels.n_elem - 2];
117 const double weight2 = weights[labels.n_elem - 1];
118
119 counts[labels[labels.n_elem - 2]] += weight1;
120 counts2[labels[labels.n_elem - 1]] += weight2;
121
122 accWeights[0] += weight1;
123 accWeights[1] += weight2;
124 }
125 else if (labels.n_elem % 4 == 3)
126 {
127 const double weight1 = weights[labels.n_elem - 3];
128 const double weight2 = weights[labels.n_elem - 2];
129 const double weight3 = weights[labels.n_elem - 1];
130
131 counts[labels[labels.n_elem - 3]] += weight1;
132 counts2[labels[labels.n_elem - 2]] += weight2;
133 counts3[labels[labels.n_elem - 1]] += weight3;
134
135 accWeights[0] += weight1;
136 accWeights[1] += weight2;
137 accWeights[2] += weight3;
138 }
139
140 accWeights[0] += accWeights[1] + accWeights[2] + accWeights[3];
141 counts += counts2 + counts3 + counts4;
142
143 // Corner case: return 0 if no weight.
144 if (accWeights[0] == 0.0)
145 return 0.0;
146
147 for (size_t i = 0; i < numClasses; ++i)
148 {
149 const double f = ((double) counts[i] / (double) accWeights[0]);
150 if (f > 0.0)
151 gain += f * std::log2(f);
152 }
153 }
154 else
155 {
156 // SIMD loop: add counts for four elements simultaneously (if the compiler
157 // manages to vectorize the loop).
158 for (size_t i = 3; i < labels.n_elem; i += 4)
159 {
160 counts[labels[i - 3]]++;
161 counts2[labels[i - 2]]++;
162 counts3[labels[i - 1]]++;
163 counts4[labels[i]]++;
164 }
165
166 // Handle leftovers.
167 if (labels.n_elem % 4 == 1)
168 {
169 counts[labels[labels.n_elem - 1]]++;
170 }
171 else if (labels.n_elem % 4 == 2)
172 {
173 counts[labels[labels.n_elem - 2]]++;
174 counts2[labels[labels.n_elem - 1]]++;
175 }
176 else if (labels.n_elem % 4 == 3)
177 {
178 counts[labels[labels.n_elem - 3]]++;
179 counts2[labels[labels.n_elem - 2]]++;
180 counts3[labels[labels.n_elem - 1]]++;
181 }
182
183 counts += counts2 + counts3 + counts4;
184
185 for (size_t i = 0; i < numClasses; ++i)
186 {
187 const double f = ((double) counts[i] / (double) labels.n_elem);
188 if (f > 0.0)
189 gain += f * std::log2(f);
190 }
191 }
192
193 return gain;
194 }
195
/**
 * Return the range of the information gain for the given number of classes.
 *
 * @param numClasses Number of classes.
 * @return log2(numClasses).
 */
static double Range(const size_t numClasses)
{
  // The gain is at most 0 (every point in one class) and at least
  // n * (1/n * log2(1/n)) = -log2(n) (points spread evenly over n classes),
  // so the interval of attainable values has width log2(n).
  const double classCount = static_cast<double>(numClasses);
  return std::log2(classCount);
}
210};
211
212} // namespace tree
213} // namespace mlpack
214
215#endif
The standard information gain criterion, used for calculating gain in decision trees.
static double EvaluatePtr(const CountType *counts, const size_t countLength, const CountType totalCount)
Evaluate the information gain given a vector of class weight counts.
static double Evaluate(const arma::Row< size_t > &labels, const size_t numClasses, const arma::Row< double > &weights)
Given a set of labels, calculate the information gain of those labels.
static double Range(const size_t numClasses)
Return the range of the information gain for the given number of classes.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.