mlpack 3.4.2
naive_bayes_classifier.hpp
Go to the documentation of this file.
1
15#ifndef MLPACK_METHODS_NAIVE_BAYES_NAIVE_BAYES_CLASSIFIER_HPP
16#define MLPACK_METHODS_NAIVE_BAYES_NAIVE_BAYES_CLASSIFIER_HPP
17
18#include <mlpack/prereqs.hpp>
19
20namespace mlpack {
21namespace naive_bayes {
22
/**
 * Naive Bayes classifier.  The model holds, for each class, the sample means,
 * the sample variances and the prior probability (see Means(), Variances()
 * and Probabilities()).
 *
 * @tparam ModelMatType Matrix type used to store the model (means, variances
 *     and probabilities); defaults to arma::mat.
 */
57template<typename ModelMatType = arma::mat>
58class NaiveBayesClassifier
59{
60 public:
61 // Convenience typedef.
62 typedef typename ModelMatType::elem_type ElemType;
63
  /**
   * Initializes the classifier as per the input and then trains it by
   * calculating the sample means and variances.
   *
   * incrementalVariance defaults to false and epsilon defaults to 1e-10.
   * NOTE(review): the exact effect of incrementalVariance and epsilon is
   * defined in naive_bayes_classifier_impl.hpp -- confirm there.
   */
83 template<typename MatType>
84 NaiveBayesClassifier(const MatType& data,
85 const arma::Row<size_t>& labels,
86 const size_t numClasses,
87 const bool incrementalVariance = false,
88 const double epsilon = 1e-10);
89
  /**
   * Initialize the Naive Bayes classifier without performing training.
   * All arguments are optional (dimensionality = 0, numClasses = 0,
   * epsilon = 1e-10).
   */
96 NaiveBayesClassifier(const size_t dimensionality = 0,
97 const size_t numClasses = 0,
98 const double epsilon = 1e-10);
99
  /**
   * Train the Naive Bayes classifier on the given dataset.
   *
   * Note that incremental defaults to true here, whereas the training
   * constructor's incrementalVariance defaults to false.
   */
117 template<typename MatType>
118 void Train(const MatType& data,
119 const arma::Row<size_t>& labels,
120 const size_t numClasses,
121 const bool incremental = true);
122
  //! Train the Naive Bayes classifier on the given point.
131 template<typename VecType>
132 void Train(const VecType& point, const size_t label);
133
  //! Classify the given point, using the trained NaiveBayesClassifier model.
  //! The predicted class is the return value.
140 template<typename VecType>
141 size_t Classify(const VecType& point) const;
142
  //! Classify the given point using the trained NaiveBayesClassifier model;
  //! the prediction and the class probability estimates are written to the
  //! output arguments.
153 template<typename VecType, typename ProbabilitiesVecType>
154 void Classify(const VecType& point,
155 size_t& prediction,
156 ProbabilitiesVecType& probabilities) const;
157
  //! Classify the given points using the trained NaiveBayesClassifier model;
  //! the predictions are written to the output argument.
172 template<typename MatType>
173 void Classify(const MatType& data,
174 arma::Row<size_t>& predictions) const;
175
  //! Classify the given points using the trained NaiveBayesClassifier model;
  //! the predictions and the class probability estimates are written to the
  //! output arguments.
197 template<typename MatType, typename ProbabilitiesMatType>
198 void Classify(const MatType& data,
199 arma::Row<size_t>& predictions,
200 ProbabilitiesMatType& probabilities) const;
201
  //! Get the sample means for each class.
203 const ModelMatType& Means() const { return means; }
  //! Modify the sample means for each class.
205 ModelMatType& Means() { return means; }
206
  //! Get the sample variances for each class.
208 const ModelMatType& Variances() const { return variances; }
  //! Modify the sample variances for each class.
210 ModelMatType& Variances() { return variances; }
211
  //! Get the prior probabilities for each class.
213 const ModelMatType& Probabilities() const { return probabilities; }
  //! Modify the prior probabilities for each class.
215 ModelMatType& Probabilities() { return probabilities; }
216
  //! Serialize the classifier.  The version argument is unnamed (unused here).
218 template<typename Archive>
219 void serialize(Archive& ar, const unsigned int /* version */);
220
221 private:
  //! Sample means for each class (exposed through Means()).
223 ModelMatType means;
  //! Sample variances for each class (exposed through Variances()).
225 ModelMatType variances;
  //! Prior probabilities for each class (exposed through Probabilities()).
227 ModelMatType probabilities;
  //! NOTE(review): presumably the number of training points seen so far --
  //! confirm against naive_bayes_classifier_impl.hpp.
229 size_t trainingPoints;
  //! NOTE(review): presumably holds the epsilon passed to the constructors --
  //! confirm against naive_bayes_classifier_impl.hpp.
231 double epsilon;
232
  //! NOTE(review): presumably computes, for the points in data, the
  //! per-class log-likelihoods into logLikelihoods -- confirm against
  //! naive_bayes_classifier_impl.hpp.
241 template<typename MatType>
242 void LogLikelihood(const MatType& data,
243 ModelMatType& logLikelihoods) const;
244};
245
246} // namespace naive_bayes
247} // namespace mlpack
248
249// Include implementation.
250#include "naive_bayes_classifier_impl.hpp"
251
252#endif
void Train(const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses, const bool incremental=true)
Train the Naive Bayes classifier on the given dataset.
void Classify(const VecType &point, size_t &prediction, ProbabilitiesVecType &probabilities) const
Classify the given point using the trained NaiveBayesClassifier model and also return estimates of the probability of each class.
void Train(const VecType &point, const size_t label)
Train the Naive Bayes classifier on the given point.
ModelMatType & Variances()
Modify the sample variances for each class.
ModelMatType & Means()
Modify the sample means for each class.
const ModelMatType & Means() const
Get the sample means for each class.
void Classify(const MatType &data, arma::Row< size_t > &predictions, ProbabilitiesMatType &probabilities) const
Classify the given points using the trained NaiveBayesClassifier model and also return estimates of the probabilities of each class.
size_t Classify(const VecType &point) const
Classify the given point, using the trained NaiveBayesClassifier model.
const ModelMatType & Probabilities() const
Get the prior probabilities for each class.
NaiveBayesClassifier(const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses, const bool incrementalVariance=false, const double epsilon=1e-10)
Initializes the classifier as per the input and then trains it by calculating the sample mean and variances.
ModelMatType & Probabilities()
Modify the prior probabilities for each class.
void Classify(const MatType &data, arma::Row< size_t > &predictions) const
Classify the given points using the trained NaiveBayesClassifier model.
const ModelMatType & Variances() const
Get the sample variances for each class.
NaiveBayesClassifier(const size_t dimensionality=0, const size_t numClasses=0, const double epsilon=1e-10)
Initialize the Naive Bayes classifier without performing training.
void serialize(Archive &ar, const unsigned int)
Serialize the classifier.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.