13#ifndef MLPACK_METHODS_DECISION_TREE_GINI_GAIN_HPP
14#define MLPACK_METHODS_DECISION_TREE_GINI_GAIN_HPP
/**
 * Evaluate the Gini impurity given a vector of class weight counts.
 *
 * The value computed is
 *   -sum_i counts[i] * (totalCount - counts[i]) / totalCount^2,
 * which is non-positive whenever every count lies in [0, totalCount].
 *
 * @tparam UseWeights Not referenced by the arithmetic in this function.
 * @tparam CountType Arithmetic type of the counts (integral or floating).
 * @param counts Pointer to `countLength` per-class counts.
 * @param countLength Number of entries readable through `counts`.
 * @param totalCount Total count used as the normaliser.
 * @return The negated impurity as a double; 0.0 when `totalCount` is zero.
 */
template<bool UseWeights, typename CountType>
static double EvaluatePtr(const CountType* counts,
                          const size_t countLength,
                          const CountType totalCount)
{
  // Corner case: with nothing counted the ratio below would be 0 / 0 (NaN),
  // so report zero impurity instead.
  if (totalCount == 0)
    return 0.0;

  // Accumulate in CountType so integral counts stay exact.
  CountType impurity = 0;
  for (size_t i = 0; i < countLength; ++i)
    impurity += counts[i] * (totalCount - counts[i]);

  // Square in double precision; a plain multiply replaces std::pow(x, 2).
  const double total = static_cast<double>(totalCount);
  return -(static_cast<double>(impurity) / (total * total));
}
// Evaluate the Gini impurity on the given set of labels.
//
// NOTE(review): this listing is lossy. The signature line
// (`static double Evaluate(const RowType &labels, ...)`), the braces, the
// return statements, the `true);` terminators of three of the four count
// vector constructors, and whatever branch separates the weighted pass from
// the unweighted pass are not visible here. The comments below describe only
// the statements that are shown.
//
// Labels are used directly as indices into the count vectors, so they are
// assumed to lie in [0, numClasses) -- no bounds check is visible.
61 template<
bool UseWeights,
typename RowType,
typename WeightVecType>
63 const size_t numClasses,
64 const WeightVecType& weights)
// An empty label vector is special-cased; the action taken is on a line that
// is not shown -- TODO confirm.
67 if (labels.n_elem == 0)
// One zero-filled buffer of 4 * numClasses elements. counts, counts2, counts3
// and counts4 are built on consecutive numClasses-long slices of it (offsets
// 0, numClasses, 2 * numClasses, 3 * numClasses), giving four separate
// per-class accumulators.
// NOTE(review): the trailing `false, true` arguments are presumably
// Armadillo's copy_aux_mem / strict flags, i.e. the vectors alias countSpace
// rather than copy it -- confirm against the Armadillo documentation.
72 arma::vec countSpace(4 * numClasses, arma::fill::zeros);
73 arma::vec counts(countSpace.memptr(), numClasses,
false,
true);
74 arma::vec counts2(countSpace.memptr() + numClasses, numClasses,
false,
76 arma::vec counts3(countSpace.memptr() + 2 * numClasses, numClasses,
false,
78 arma::vec counts4(countSpace.memptr() + 3 * numClasses, numClasses,
false,
// Running sum of f * (1 - f) over the classes.
82 double impurity = 0.0;
// Weighted pass: one running weight total per accumulator lane.
// NOTE(review): presumably selected by UseWeights; the condition itself is
// not visible here.
87 double accWeights[4] = { 0.0, 0.0, 0.0, 0.0 };
// Main loop: labels are consumed four at a time. i is the index of the last
// element of each group, so the body runs floor(n_elem / 4) times and element
// k of a group always goes to lane k.
91 for (
size_t i = 3; i < labels.n_elem; i += 4)
93 const double weight1 = weights[i - 3];
94 const double weight2 = weights[i - 2];
95 const double weight3 = weights[i - 1];
96 const double weight4 = weights[i];
98 counts[labels[i - 3]] += weight1;
99 counts2[labels[i - 2]] += weight2;
100 counts3[labels[i - 1]] += weight3;
101 counts4[labels[i]] += weight4;
103 accWeights[0] += weight1;
104 accWeights[1] += weight2;
105 accWeights[2] += weight3;
106 accWeights[3] += weight4;
// Tail: the 1-3 trailing labels left over when n_elem is not a multiple of 4
// are added to the first one, two or three lanes respectively.
110 if (labels.n_elem % 4 == 1)
112 const double weight1 = weights[labels.n_elem - 1];
113 counts[labels[labels.n_elem - 1]] += weight1;
114 accWeights[0] += weight1;
116 else if (labels.n_elem % 4 == 2)
118 const double weight1 = weights[labels.n_elem - 2];
119 const double weight2 = weights[labels.n_elem - 1];
121 counts[labels[labels.n_elem - 2]] += weight1;
122 counts2[labels[labels.n_elem - 1]] += weight2;
124 accWeights[0] += weight1;
125 accWeights[1] += weight2;
127 else if (labels.n_elem % 4 == 3)
129 const double weight1 = weights[labels.n_elem - 3];
130 const double weight2 = weights[labels.n_elem - 2];
131 const double weight3 = weights[labels.n_elem - 1];
133 counts[labels[labels.n_elem - 3]] += weight1;
134 counts2[labels[labels.n_elem - 2]] += weight2;
135 counts3[labels[labels.n_elem - 1]] += weight3;
137 accWeights[0] += weight1;
138 accWeights[1] += weight2;
139 accWeights[2] += weight3;
// Fold the four lanes into lane 0: accWeights[0] becomes the total weight and
// counts holds the per-class weight totals.
142 accWeights[0] += accWeights[1] + accWeights[2] + accWeights[3];
143 counts += counts2 + counts3 + counts4;
// A zero total weight is special-cased before the division below; the action
// taken is on a line that is not shown -- TODO confirm.
146 if (accWeights[0] == 0.0)
// f is the class's share of the total weight; accumulate f * (1 - f).
149 for (
size_t i = 0; i < numClasses; ++i)
151 const double f = ((double) counts[i] / (
double) accWeights[0]);
152 impurity += f * (1.0 - f);
// Unweighted pass: same four-lane scheme, but every label adds 1 to its
// class count and `weights` is not read.
159 for (
size_t i = 3; i < labels.n_elem; i += 4)
161 counts[labels[i - 3]]++;
162 counts2[labels[i - 2]]++;
163 counts3[labels[i - 1]]++;
164 counts4[labels[i]]++;
// Tail handling for the 1-3 trailing labels, as in the weighted pass.
168 if (labels.n_elem % 4 == 1)
170 counts[labels[labels.n_elem - 1]]++;
172 else if (labels.n_elem % 4 == 2)
174 counts[labels[labels.n_elem - 2]]++;
175 counts2[labels[labels.n_elem - 1]]++;
177 else if (labels.n_elem % 4 == 3)
179 counts[labels[labels.n_elem - 3]]++;
180 counts2[labels[labels.n_elem - 2]]++;
181 counts3[labels[labels.n_elem - 1]]++;
// Fold the four lanes into counts.
184 counts += counts2 + counts3 + counts4;
// f is the class's share of the labels (normalised by n_elem); accumulate
// f * (1 - f).
// NOTE(review): the statement that returns the result is not visible in this
// listing.
186 for (
size_t i = 0; i < numClasses; ++i)
188 const double f = ((double) counts[i] / (
double) labels.n_elem);
189 impurity += f * (1.0 - f);
/**
 * Return the range of the Gini impurity for the given number of classes.
 *
 * @param numClasses Number of classes.
 * @return 1 - 1 / numClasses, or 0.0 when `numClasses` is zero.
 */
static double Range(const size_t numClasses)
{
  // Without this guard, zero classes would divide by zero and the function
  // would return -inf.
  if (numClasses == 0)
    return 0.0;

  return 1.0 - (1.0 / static_cast<double>(numClasses));
}
The Gini gain, a measure of set purity usable as a fitness function (FitnessFunction) for decision trees.
static double Evaluate(const RowType &labels, const size_t numClasses, const WeightVecType &weights)
Evaluate the Gini impurity on the given set of labels.
static double EvaluatePtr(const CountType *counts, const size_t countLength, const CountType totalCount)
Evaluate the Gini impurity given a vector of class weight counts.
static double Range(const size_t numClasses)
Return the range of the Gini impurity for the given number of classes.
Include all of the base components required to write mlpack methods, and the main mlpack Doxygen documentation.
Linear algebra utility functions, generally performed on matrices or vectors.