12#ifndef MLPACK_METHODS_DECISION_STUMP_DECISION_STUMP_HPP
13#define MLPACK_METHODS_DECISION_STUMP_DECISION_STUMP_HPP
18namespace decision_stump {
38template<
typename MatType = arma::mat>
52 const arma::Row<size_t>& labels,
53 const size_t numClasses,
54 const size_t bucketSize = 10);
70 const arma::Row<size_t>& labels,
71 const size_t numClasses,
72 const arma::rowvec& weights);
93 const arma::Row<size_t>& labels,
94 const size_t numClasses,
95 const size_t bucketSize);
110 const arma::Row<size_t>& labels,
111 const arma::rowvec& weights,
112 const size_t numClasses,
113 const size_t bucketSize);
124 arma::Row<size_t>& predictedLabels);
132 const arma::vec&
Split()
const {
return split; }
134 arma::vec&
Split() {
return split; }
137 const arma::Col<size_t>
BinLabels()
const {
return binLabels; }
142 template<
typename Archive>
152 size_t splitDimension;
156 arma::Col<size_t> binLabels;
166 template<
bool UseWeights,
typename VecType>
167 double SetupSplitDimension(
const VecType& dimension,
168 const arma::Row<size_t>& labels,
169 const arma::rowvec& weightD);
178 template<
typename VecType>
179 void TrainOnDim(
const VecType& dimension,
180 const arma::Row<size_t>& labels);
194 template<
typename VecType>
195 double CountMostFreq(
const VecType& subCols);
202 template<
typename VecType>
203 int IsDistinct(
const VecType& featureRow);
214 template<
bool UseWeights,
typename VecType,
typename WeightVecType>
215 double CalculateEntropy(
const VecType& labels,
216 const WeightVecType& weights);
228 template<
bool UseWeights>
229 double Train(
const MatType& data,
230 const arma::Row<size_t>& labels,
231 const arma::rowvec& weights);
237#include "decision_stump_impl.hpp"
This class implements a decision stump: a one-level decision tree that splits the data along a single dimension.
mlpack_deprecated DecisionStump(const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses, const size_t bucketSize=10)
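Constructor. Trains the decision stump on the given data, unlike the default constructor, which creates an untrained stump.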
mlpack_deprecated double Train(const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses, const size_t bucketSize)
Train the decision stump on the given data.
mlpack_deprecated DecisionStump(const DecisionStump<> &other, const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses, const arma::rowvec &weights)
Alternate constructor which copies the parameters bucketSize and classes from an already-initialized decision stump, other, and trains on the given data using the given weights.
const arma::vec & Split() const
Access the splitting values.
mlpack_deprecated void Classify(const MatType &test, arma::Row< size_t > &predictedLabels)
Classification function.
arma::Col< size_t > & BinLabels()
Modify the labels for each split bin (be careful!).
size_t SplitDimension() const
Access the splitting dimension.
arma::vec & Split()
Modify the splitting values (be careful!).
mlpack_deprecated double Train(const MatType &data, const arma::Row< size_t > &labels, const arma::rowvec &weights, const size_t numClasses, const size_t bucketSize)
Train the decision stump on the given data, with the given weights.
size_t & SplitDimension()
Modify the splitting dimension (be careful!).
const arma::Col< size_t > BinLabels() const
Access the labels for each split bin.
DecisionStump()
Create a decision stump without training.
void serialize(Archive &ar, const unsigned int)
Serialize the decision stump.
#define mlpack_deprecated
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.