mlpack 3.4.2
discrete_distribution.hpp
Go to the documentation of this file.
1
14#ifndef MLPACK_CORE_DISTRIBUTIONS_DISCRETE_DISTRIBUTION_HPP
15#define MLPACK_CORE_DISTRIBUTIONS_DISCRETE_DISTRIBUTION_HPP
16
17#include <mlpack/prereqs.hpp>
20
21namespace mlpack {
22namespace distribution {
23
46{
47 public:
53 probabilities(std::vector<arma::vec>(1)){ /* Nothing to do. */ }
54
63 DiscreteDistribution(const size_t numObservations) :
64 probabilities(std::vector<arma::vec>(1,
65 arma::ones<arma::vec>(numObservations) / numObservations))
66 { /* Nothing to do. */ }
67
76 DiscreteDistribution(const arma::Col<size_t>& numObservations)
77 {
78 for (size_t i = 0; i < numObservations.n_elem; ++i)
79 {
80 const size_t numObs = size_t(numObservations[i]);
81 if (numObs <= 0)
82 {
83 std::ostringstream oss;
84 oss << "number of observations for dimension " << i << " is 0, but "
85 << "must be greater than 0";
86 throw std::invalid_argument(oss.str());
87 }
88 probabilities.push_back(arma::ones<arma::vec>(numObs) / numObs);
89 }
90 }
91
98 DiscreteDistribution(const std::vector<arma::vec>& probabilities)
99 {
100 for (size_t i = 0; i < probabilities.size(); ++i)
101 {
102 arma::vec temp = probabilities[i];
103 double sum = accu(temp);
104 if (sum > 0)
105 this->probabilities.push_back(temp / sum);
106 else
107 {
108 this->probabilities.push_back(arma::ones<arma::vec>(temp.n_elem)
109 / temp.n_elem);
110 }
111 }
112 }
113
117 size_t Dimensionality() const { return probabilities.size(); }
118
127 double Probability(const arma::vec& observation) const
128 {
129 double probability = 1.0;
130 // Ensure the observation has the same dimension with the probabilities.
131 if (observation.n_elem != probabilities.size())
132 {
133 Log::Fatal << "DiscreteDistribution::Probability(): observation has "
134 << "incorrect dimension " << observation.n_elem << " but should have"
135 << " dimension " << probabilities.size() << "!" << std::endl;
136 }
137
138 for (size_t dimension = 0; dimension < observation.n_elem; dimension++)
139 {
140 // Adding 0.5 helps ensure that we cast the floating point to a size_t
141 // correctly.
142 const size_t obs = size_t(observation(dimension) + 0.5);
143
144 // Ensure that the observation is within the bounds.
145 if (obs >= probabilities[dimension].n_elem)
146 {
147 Log::Fatal << "DiscreteDistribution::Probability(): received "
148 << "observation " << obs << "; observation must be in [0, "
149 << probabilities[dimension].n_elem << "] for this distribution."
150 << std::endl;
151 }
152 probability *= probabilities[dimension][obs];
153 }
154
155 return probability;
156 }
157
166 double LogProbability(const arma::vec& observation) const
167 {
168 // TODO: consider storing log probabilities instead?
169 return log(Probability(observation));
170 }
171
179 void Probability(const arma::mat& x, arma::vec& probabilities) const
180 {
181 probabilities.set_size(x.n_cols);
182 for (size_t i = 0; i < x.n_cols; ++i)
183 probabilities(i) = Probability(x.unsafe_col(i));
184 }
185
194 void LogProbability(const arma::mat& x, arma::vec& logProbabilities) const
195 {
196 logProbabilities.set_size(x.n_cols);
197 for (size_t i = 0; i < x.n_cols; ++i)
198 logProbabilities(i) = log(Probability(x.unsafe_col(i)));
199 }
200
208 arma::vec Random() const;
209
217 void Train(const arma::mat& observations);
218
228 void Train(const arma::mat& observations,
229 const arma::vec& probabilities);
230
232 arma::vec& Probabilities(const size_t dim = 0) { return probabilities[dim]; }
234 const arma::vec& Probabilities(const size_t dim = 0) const
235 { return probabilities[dim]; }
236
240 template<typename Archive>
241 void serialize(Archive& ar, const unsigned int /* version */)
242 {
243 ar & BOOST_SERIALIZATION_NVP(probabilities);
244 }
245
246 private:
//! One probability vector per dimension. Probability() uses entry i of
//! vector d as the probability of observing value i in dimension d.
249 std::vector<arma::vec> probabilities;
250};
251
252} // namespace distribution
253} // namespace mlpack
254
255#endif
static MLPACK_EXPORT util::PrefixedOutStream Fatal
Prints fatal messages prefixed with [FATAL], then terminates the program.
Definition: log.hpp:90
A discrete distribution where the only observations are discrete observations.
arma::vec & Probabilities(const size_t dim=0)
Modify the vector of probabilities for the given dimension.
const arma::vec & Probabilities(const size_t dim=0) const
Return the vector of probabilities for the given dimension.
void Probability(const arma::mat &x, arma::vec &probabilities) const
Calculates the discrete probability mass function for each data point (column) in the given matrix...
void Train(const arma::mat &observations)
Estimate the probability distribution directly from the given observations.
arma::vec Random() const
Return a randomly generated observation (one-dimensional vector; one observation) according to the pr...
DiscreteDistribution()
Default constructor, which creates a distribution that has no observations.
double LogProbability(const arma::vec &observation) const
Return the log probability of the given observation.
size_t Dimensionality() const
Get the dimensionality of the distribution.
void LogProbability(const arma::mat &x, arma::vec &logProbabilities) const
Returns the log probability of each data point (column) in the given matrix.
double Probability(const arma::vec &observation) const
Return the probability of the given observation.
DiscreteDistribution(const std::vector< arma::vec > &probabilities)
Define the multidimensional discrete distribution as having the given probabilities for each observat...
DiscreteDistribution(const arma::Col< size_t > &numObservations)
Define the multidimensional discrete distribution as having numObservations possible observations.
void Train(const arma::mat &observations, const arma::vec &probabilities)
Estimate the probability distribution from the given observations, taking into account the probabilit...
DiscreteDistribution(const size_t numObservations)
Define the discrete distribution as having numObservations possible observations.
void serialize(Archive &ar, const unsigned int)
Serialize the distribution.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
Definition: prereqs.hpp:67
The core includes that mlpack expects; standard C++ includes and Armadillo.
Miscellaneous math random-related routines.