/**
 * @file em_fit.hpp
 *
 * mlpack 3.4.2.  Declares EMFit, a utility class that fits a Gaussian mixture
 * model (GMM) to observations using the EM algorithm.
 */
14#ifndef MLPACK_METHODS_GMM_EM_FIT_HPP
15#define MLPACK_METHODS_GMM_EM_FIT_HPP
16
17#include <mlpack/prereqs.hpp>
20
21// Default clustering mechanism.
23// Default covariance matrix constraint.
25
26namespace mlpack {
27namespace gmm {
28
42template<typename InitialClusteringType = kmeans::KMeans<>,
43 typename CovarianceConstraintPolicy = PositiveDefiniteConstraint,
44 typename Distribution = distribution::GaussianDistribution>
45class EMFit
46{
47 public:
64 EMFit(const size_t maxIterations = 300,
65 const double tolerance = 1e-10,
66 InitialClusteringType clusterer = InitialClusteringType(),
67 CovarianceConstraintPolicy constraint = CovarianceConstraintPolicy());
68
83 void Estimate(const arma::mat& observations,
84 std::vector<Distribution>& dists,
85 arma::vec& weights,
86 const bool useInitialModel = false);
87
104 void Estimate(const arma::mat& observations,
105 const arma::vec& probabilities,
106 std::vector<Distribution>& dists,
107 arma::vec& weights,
108 const bool useInitialModel = false);
109
111 const InitialClusteringType& Clusterer() const { return clusterer; }
113 InitialClusteringType& Clusterer() { return clusterer; }
114
116 const CovarianceConstraintPolicy& Constraint() const { return constraint; }
118 CovarianceConstraintPolicy& Constraint() { return constraint; }
119
121 size_t MaxIterations() const { return maxIterations; }
123 size_t& MaxIterations() { return maxIterations; }
124
126 double Tolerance() const { return tolerance; }
128 double& Tolerance() { return tolerance; }
129
131 template<typename Archive>
132 void serialize(Archive& ar, const unsigned int version);
133
134 private:
145 void InitialClustering(
146 const arma::mat& observations,
147 std::vector<Distribution>& dists,
148 arma::vec& weights);
149
160 double LogLikelihood(
161 const arma::mat& data,
162 const std::vector<Distribution>& dists,
163 const arma::vec& weights) const;
164
175 void ArmadilloGMMWrapper(
176 const arma::mat& observations,
177 std::vector<Distribution>& dists,
178 arma::vec& weights,
179 const bool useInitialModel);
180
182 size_t maxIterations;
184 double tolerance;
186 InitialClusteringType clusterer;
188 CovarianceConstraintPolicy constraint;
189};
190
191} // namespace gmm
192} // namespace mlpack
193
194// Include implementation.
195#include "em_fit_impl.hpp"
196
197#endif
// class EMFit (Definition: em_fit.hpp:46)
//   This class contains methods which can fit a GMM to observations using the
//   EM algorithm.
// EMFit(const size_t maxIterations=300, const double tolerance=1e-10, InitialClusteringType clusterer=InitialClusteringType(), CovarianceConstraintPolicy constraint=CovarianceConstraintPolicy())
//   Construct the EMFit object, optionally passing an InitialClusteringType
//   object (just in case it needs to store state).
// void Estimate(const arma::mat &observations, std::vector< Distribution > &dists, arma::vec &weights, const bool useInitialModel=false)
//   Fit the observations to a Gaussian mixture model (GMM) using the EM
//   algorithm.
// size_t MaxIterations() const (Definition: em_fit.hpp:121)
//   Get the maximum number of iterations of the EM algorithm.
// InitialClusteringType & Clusterer() (Definition: em_fit.hpp:113)
//   Modify the clusterer.
// void serialize(Archive &ar, const unsigned int version)
//   Serialize the fitter.
// size_t & MaxIterations() (Definition: em_fit.hpp:123)
//   Modify the maximum number of iterations of the EM algorithm.
// double & Tolerance() (Definition: em_fit.hpp:128)
//   Modify the tolerance for the convergence of the EM algorithm.
// double Tolerance() const (Definition: em_fit.hpp:126)
//   Get the tolerance for the convergence of the EM algorithm.
// const CovarianceConstraintPolicy & Constraint() const (Definition: em_fit.hpp:116)
//   Get the covariance constraint policy class.
// void Estimate(const arma::mat &observations, const arma::vec &probabilities, std::vector< Distribution > &dists, arma::vec &weights, const bool useInitialModel=false)
//   Fit the observations to a Gaussian mixture model (GMM) using the EM
//   algorithm, taking into account the probability of each observation given
//   in probabilities.
// const InitialClusteringType & Clusterer() const (Definition: em_fit.hpp:111)
//   Get the clusterer.
// CovarianceConstraintPolicy & Constraint() (Definition: em_fit.hpp:118)
//   Modify the covariance constraint policy class.
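// Doxygen link tooltip (presumably for namespace mlpack; Definition: cv.hpp:1):
//   Linear algebra utility functions, generally performed on matrices or
//   vectors.
// Doxygen link tooltip for <mlpack/prereqs.hpp>:
//   The core includes that mlpack expects; standard C++ includes and Armadillo.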