mlpack 3.4.2
decision_stump.hpp
Go to the documentation of this file.
1
12#ifndef MLPACK_METHODS_DECISION_STUMP_DECISION_STUMP_HPP
13#define MLPACK_METHODS_DECISION_STUMP_DECISION_STUMP_HPP
14
15#include <mlpack/prereqs.hpp>
16
17namespace mlpack {
18namespace decision_stump {
19
38template<typename MatType = arma::mat>
40{
41 public:
51 mlpack_deprecated DecisionStump(const MatType& data,
52 const arma::Row<size_t>& labels,
53 const size_t numClasses,
54 const size_t bucketSize = 10);
55
69 const MatType& data,
70 const arma::Row<size_t>& labels,
71 const size_t numClasses,
72 const arma::rowvec& weights);
73
80
92 mlpack_deprecated double Train(const MatType& data,
93 const arma::Row<size_t>& labels,
94 const size_t numClasses,
95 const size_t bucketSize);
96
109 mlpack_deprecated double Train(const MatType& data,
110 const arma::Row<size_t>& labels,
111 const arma::rowvec& weights,
112 const size_t numClasses,
113 const size_t bucketSize);
114
123 mlpack_deprecated void Classify(const MatType& test,
124 arma::Row<size_t>& predictedLabels);
125
127 size_t SplitDimension() const { return splitDimension; }
129 size_t& SplitDimension() { return splitDimension; }
130
132 const arma::vec& Split() const { return split; }
134 arma::vec& Split() { return split; }
135
137 const arma::Col<size_t> BinLabels() const { return binLabels; }
139 arma::Col<size_t>& BinLabels() { return binLabels; }
140
142 template<typename Archive>
143 void serialize(Archive& ar, const unsigned int /* version */);
144
145 private:
147 size_t numClasses;
149 size_t bucketSize;
150
152 size_t splitDimension;
154 arma::vec split;
156 arma::Col<size_t> binLabels;
157
166 template<bool UseWeights, typename VecType>
167 double SetupSplitDimension(const VecType& dimension,
168 const arma::Row<size_t>& labels,
169 const arma::rowvec& weightD);
170
178 template<typename VecType>
179 void TrainOnDim(const VecType& dimension,
180 const arma::Row<size_t>& labels);
181
186 void MergeRanges();
187
194 template<typename VecType>
195 double CountMostFreq(const VecType& subCols);
196
202 template<typename VecType>
203 int IsDistinct(const VecType& featureRow);
204
214 template<bool UseWeights, typename VecType, typename WeightVecType>
215 double CalculateEntropy(const VecType& labels,
216 const WeightVecType& weights);
217
228 template<bool UseWeights>
229 double Train(const MatType& data,
230 const arma::Row<size_t>& labels,
231 const arma::rowvec& weights);
232};
233
234} // namespace decision_stump
235} // namespace mlpack
236
237#include "decision_stump_impl.hpp"
238
239#endif
This class implements a decision stump.
mlpack_deprecated DecisionStump(const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses, const size_t bucketSize=10)
Constructor.
mlpack_deprecated double Train(const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses, const size_t bucketSize)
Train the decision stump on the given data.
mlpack_deprecated DecisionStump(const DecisionStump<> &other, const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses, const arma::rowvec &weights)
Alternate constructor which copies the parameters bucketSize and classes from an already initiated decision stump.
const arma::vec & Split() const
Access the splitting values.
mlpack_deprecated void Classify(const MatType &test, arma::Row< size_t > &predictedLabels)
Classification function.
arma::Col< size_t > & BinLabels()
Modify the labels for each split bin (be careful!).
size_t SplitDimension() const
Access the splitting dimension.
arma::vec & Split()
Modify the splitting values (be careful!).
mlpack_deprecated double Train(const MatType &data, const arma::Row< size_t > &labels, const arma::rowvec &weights, const size_t numClasses, const size_t bucketSize)
Train the decision stump on the given data, with the given weights.
size_t & SplitDimension()
Modify the splitting dimension (be careful!).
const arma::Col< size_t > BinLabels() const
Access the labels for each split bin.
DecisionStump()
Create a decision stump without training.
void serialize(Archive &ar, const unsigned int)
Serialize the decision stump.
#define mlpack_deprecated
Definition: deprecated.hpp:22
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.