mlpack 3.4.2
gini_gain.hpp
Go to the documentation of this file.
1
13#ifndef MLPACK_METHODS_DECISION_TREE_GINI_GAIN_HPP
14#define MLPACK_METHODS_DECISION_TREE_GINI_GAIN_HPP
15
16#include <mlpack/core.hpp>
17
18namespace mlpack {
19namespace tree {
20
28{
29 public:
33 template<bool UseWeights, typename CountType>
34 static double EvaluatePtr(const CountType* counts,
35 const size_t countLength,
36 const CountType totalCount)
37 {
38 if (totalCount == 0)
39 return 0.0;
40
41 CountType impurity = 0.0;
42 for (size_t i = 0; i < countLength; ++i)
43 impurity += counts[i] * (totalCount - counts[i]);
44
45 return -((double) impurity / ((double) std::pow(totalCount, 2)));
46 }
47
61 template<bool UseWeights, typename RowType, typename WeightVecType>
62 static double Evaluate(const RowType& labels,
63 const size_t numClasses,
64 const WeightVecType& weights)
65 {
66 // Corner case: if there are no elements, the impurity is zero.
67 if (labels.n_elem == 0)
68 return 0.0;
69
70 // Count the number of elements in each class. Use four auxiliary vectors
71 // to exploit SIMD instructions if possible.
72 arma::vec countSpace(4 * numClasses, arma::fill::zeros);
73 arma::vec counts(countSpace.memptr(), numClasses, false, true);
74 arma::vec counts2(countSpace.memptr() + numClasses, numClasses, false,
75 true);
76 arma::vec counts3(countSpace.memptr() + 2 * numClasses, numClasses, false,
77 true);
78 arma::vec counts4(countSpace.memptr() + 3 * numClasses, numClasses, false,
79 true);
80
81 // Calculate the Gini impurity of the un-split node.
82 double impurity = 0.0;
83
84 if (UseWeights)
85 {
86 // Sum all the weights up.
87 double accWeights[4] = { 0.0, 0.0, 0.0, 0.0 };
88
89 // SIMD loop: add counts for four elements simultaneously (if the compiler
90 // manages to vectorize the loop).
91 for (size_t i = 3; i < labels.n_elem; i += 4)
92 {
93 const double weight1 = weights[i - 3];
94 const double weight2 = weights[i - 2];
95 const double weight3 = weights[i - 1];
96 const double weight4 = weights[i];
97
98 counts[labels[i - 3]] += weight1;
99 counts2[labels[i - 2]] += weight2;
100 counts3[labels[i - 1]] += weight3;
101 counts4[labels[i]] += weight4;
102
103 accWeights[0] += weight1;
104 accWeights[1] += weight2;
105 accWeights[2] += weight3;
106 accWeights[3] += weight4;
107 }
108
109 // Handle leftovers.
110 if (labels.n_elem % 4 == 1)
111 {
112 const double weight1 = weights[labels.n_elem - 1];
113 counts[labels[labels.n_elem - 1]] += weight1;
114 accWeights[0] += weight1;
115 }
116 else if (labels.n_elem % 4 == 2)
117 {
118 const double weight1 = weights[labels.n_elem - 2];
119 const double weight2 = weights[labels.n_elem - 1];
120
121 counts[labels[labels.n_elem - 2]] += weight1;
122 counts2[labels[labels.n_elem - 1]] += weight2;
123
124 accWeights[0] += weight1;
125 accWeights[1] += weight2;
126 }
127 else if (labels.n_elem % 4 == 3)
128 {
129 const double weight1 = weights[labels.n_elem - 3];
130 const double weight2 = weights[labels.n_elem - 2];
131 const double weight3 = weights[labels.n_elem - 1];
132
133 counts[labels[labels.n_elem - 3]] += weight1;
134 counts2[labels[labels.n_elem - 2]] += weight2;
135 counts3[labels[labels.n_elem - 1]] += weight3;
136
137 accWeights[0] += weight1;
138 accWeights[1] += weight2;
139 accWeights[2] += weight3;
140 }
141
142 accWeights[0] += accWeights[1] + accWeights[2] + accWeights[3];
143 counts += counts2 + counts3 + counts4;
144
145 // Catch edge case: if there are no weights, the impurity is zero.
146 if (accWeights[0] == 0.0)
147 return 0.0;
148
149 for (size_t i = 0; i < numClasses; ++i)
150 {
151 const double f = ((double) counts[i] / (double) accWeights[0]);
152 impurity += f * (1.0 - f);
153 }
154 }
155 else
156 {
157 // SIMD loop: add counts for four elements simultaneously (if the compiler
158 // manages to vectorize the loop).
159 for (size_t i = 3; i < labels.n_elem; i += 4)
160 {
161 counts[labels[i - 3]]++;
162 counts2[labels[i - 2]]++;
163 counts3[labels[i - 1]]++;
164 counts4[labels[i]]++;
165 }
166
167 // Handle leftovers.
168 if (labels.n_elem % 4 == 1)
169 {
170 counts[labels[labels.n_elem - 1]]++;
171 }
172 else if (labels.n_elem % 4 == 2)
173 {
174 counts[labels[labels.n_elem - 2]]++;
175 counts2[labels[labels.n_elem - 1]]++;
176 }
177 else if (labels.n_elem % 4 == 3)
178 {
179 counts[labels[labels.n_elem - 3]]++;
180 counts2[labels[labels.n_elem - 2]]++;
181 counts3[labels[labels.n_elem - 1]]++;
182 }
183
184 counts += counts2 + counts3 + counts4;
185
186 for (size_t i = 0; i < numClasses; ++i)
187 {
188 const double f = ((double) counts[i] / (double) labels.n_elem);
189 impurity += f * (1.0 - f);
190 }
191 }
192
193 return -impurity;
194 }
195
/**
 * Return the range of the Gini impurity for the given number of classes.
 *
 * @param numClasses Number of classes n.
 * @return 1 - 1/n, the gap between the best and worst possible impurity.
 */
static double Range(const size_t numClasses)
{
  // Best case: a single class is present, so the impurity is 0.
  // Worst case: all n classes are equally frequent (fraction 1/n each), so the
  // impurity is n * (1/n) * (1 - 1/n) = 1 - 1/n.
  const double uniformFraction = 1.0 / static_cast<double>(numClasses);
  return 1.0 - uniformFraction;
}
210};
211
212} // namespace tree
213} // namespace mlpack
214
215#endif
The Gini gain, a measure of set purity usable as a fitness function (FitnessFunction) for decision trees.
Definition: gini_gain.hpp:28
static double Evaluate(const RowType &labels, const size_t numClasses, const WeightVecType &weights)
Evaluate the Gini impurity on the given set of labels.
Definition: gini_gain.hpp:62
static double EvaluatePtr(const CountType *counts, const size_t countLength, const CountType totalCount)
Evaluate the Gini impurity given a vector of class weight counts.
Definition: gini_gain.hpp:34
static double Range(const size_t numClasses)
Return the range of the Gini impurity for the given number of classes.
Definition: gini_gain.hpp:203
Include all of the base components required to write mlpack methods, and the main mlpack Doxygen documentation.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1