mlpack 3.4.2
diagonal_gmm.hpp
Go to the documentation of this file.
1
14#ifndef MLPACK_METHODS_GMM_DIAGONAL_GMM_HPP
15#define MLPACK_METHODS_GMM_DIAGONAL_GMM_HPP
16
17#include <mlpack/prereqs.hpp>
19
20// This is the default fitting method class.
21#include "em_fit.hpp"
22
23// This is the default covariance matrix constraint.
25
26namespace mlpack {
27namespace gmm {
28
75{
76 private:
78 size_t gaussians;
80 size_t dimensionality;
81
83 std::vector<distribution::DiagonalGaussianDistribution> dists;
84
86 arma::vec weights;
87
88 public:
93 gaussians(0),
94 dimensionality(0)
95 {
96 // Warn the user. They probably don't want to do this. If this
97 // constructor is being used (because it is required by some template
98 // classes), the user should know that it is potentially dangerous.
99 Log::Debug << "DiagonalGMM::DiagonalGMM(): no parameters given;"
100 "Estimate() may fail " << "unless parameters are set." << std::endl;
101 }
102
110 DiagonalGMM(const size_t gaussians, const size_t dimensionality);
111
118 DiagonalGMM(const std::vector<distribution::DiagonalGaussianDistribution>&
119 dists, const arma::vec& weights) :
120 gaussians(dists.size()),
121 dimensionality((!dists.empty()) ? dists[0].Mean().n_elem : 0),
122 dists(dists),
123 weights(weights) { /* Nothing to do. */ }
124
127
130
132 size_t Gaussians() const { return gaussians; }
134 size_t Dimensionality() const { return dimensionality; }
135
142 {
143 return dists[i];
144 }
145
152 {
153 return dists[i];
154 }
155
157 const arma::vec& Weights() const { return weights; }
159 arma::vec& Weights() { return weights; }
160
167 double Probability(const arma::vec& observation) const;
168
175 double LogProbability(const arma::vec& observation) const;
176
184 double Probability(const arma::vec& observation,
185 const size_t component) const;
186
194 double LogProbability(const arma::vec& observation,
195 const size_t component) const;
202 arma::vec Random() const;
203
226 template<typename FittingType = EMFit<kmeans::KMeans<>, DiagonalConstraint,
227 distribution::DiagonalGaussianDistribution>>
228 double Train(const arma::mat& observations,
229 const size_t trials = 1,
230 const bool useExistingModel = false,
231 FittingType fitter = FittingType());
232
258 template<typename FittingType = EMFit<kmeans::KMeans<>, DiagonalConstraint,
259 distribution::DiagonalGaussianDistribution>>
260 double Train(const arma::mat& observations,
261 const arma::vec& probabilities,
262 const size_t trials = 1,
263 const bool useExistingModel = false,
264 FittingType fitter = FittingType());
265
283 void Classify(const arma::mat& observations,
284 arma::Row<size_t>& labels) const;
285
289 template<typename Archive>
290 void serialize(Archive& ar, const unsigned int /* version */);
291
292 private:
302 double LogLikelihood(
303 const arma::mat& observations,
304 const std::vector<distribution::DiagonalGaussianDistribution>& dists,
305 const arma::vec& weights) const;
306};
307
308} // namespace gmm
309} // namespace mlpack
310
311// Include implementation.
312#include "diagonal_gmm_impl.hpp"
313
314#endif // MLPACK_METHODS_GMM_DIAGONAL_GMM_HPP
static MLPACK_EXPORT util::NullOutStream Debug
MLPACK_EXPORT is required for global variables, so that they are properly exported by the Windows compiler.
Definition: log.hpp:79
A single multivariate Gaussian distribution with diagonal covariance.
A Diagonal Gaussian Mixture Model.
size_t Gaussians() const
Return the number of Gaussians in the model.
DiagonalGMM & operator=(const DiagonalGMM &other)
Copy assignment operator for DiagonalGMMs.
arma::vec Random() const
Return a randomly generated observation according to the probability distribution defined by this object.
distribution::DiagonalGaussianDistribution & Component(size_t i)
Return a reference to a component distribution.
void Classify(const arma::mat &observations, arma::Row< size_t > &labels) const
Classify the given observations as being from an individual component in this DiagonalGMM.
arma::vec & Weights()
Return a reference to the a priori weights of each Gaussian.
double LogProbability(const arma::vec &observation, const size_t component) const
Return the log probability that the given observation came from the given Gaussian component in this distribution.
const arma::vec & Weights() const
Return a const reference to the a priori weights of each Gaussian.
double Train(const arma::mat &observations, const arma::vec &probabilities, const size_t trials=1, const bool useExistingModel=false, FittingType fitter=FittingType())
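Estimate the probability distribution directly from the given observations, taking into account the probability of each observation, and using the given algorithm in the FittingType class to fit the data.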
double LogProbability(const arma::vec &observation) const
Return the log probability that the given observation came from this distribution.
DiagonalGMM()
Create an empty Diagonal Gaussian Mixture Model, with zero Gaussians.
size_t Dimensionality() const
Return the dimensionality of the model.
DiagonalGMM(const size_t gaussians, const size_t dimensionality)
Create a GMM with the given number of Gaussians, each of which has the specified dimensionality.
double Probability(const arma::vec &observation) const
Return the probability that the given observation came from this distribution.
double Train(const arma::mat &observations, const size_t trials=1, const bool useExistingModel=false, FittingType fitter=FittingType())
Estimate the probability distribution directly from the given observations, using the given algorithm in the FittingType class to fit the data.
DiagonalGMM(const DiagonalGMM &other)
Copy constructor for DiagonalGMMs.
void serialize(Archive &ar, const unsigned int)
Serialize the DiagonalGMM.
const distribution::DiagonalGaussianDistribution & Component(size_t i) const
Return a const reference to a component distribution.
DiagonalGMM(const std::vector< distribution::DiagonalGaussianDistribution > &dists, const arma::vec &weights)
Create a DiagonalGMM with the given dists and weights.
double Probability(const arma::vec &observation, const size_t component) const
Return the probability that the given observation came from the given Gaussian component in this distribution.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.