mlpack 3.4.2
bernoulli_distribution.hpp
Go to the documentation of this file.
1
12#ifndef MLPACK_METHODS_ANN_DISTRIBUTIONS_BERNOULLI_DISTRIBUTION_HPP
13#define MLPACK_METHODS_ANN_DISTRIBUTIONS_BERNOULLI_DISTRIBUTION_HPP
14
15#include <mlpack/prereqs.hpp>
16#include "../activation_functions/logistic_function.hpp"
17
18namespace mlpack {
19namespace ann {
20
33template <typename DataType = arma::mat>
35{
36 public:
42
63 BernoulliDistribution(const DataType& param,
64 const bool applyLogistic = true,
65 const double eps = 1e-10);
66
72 double Probability(const DataType& observation) const
73 {
74 return std::exp(LogProbability(observation));
75 }
76
82 double LogProbability(const DataType& observation) const;
83
91 void LogProbBackward(const DataType& observation, DataType& output) const;
92
99 DataType Sample() const;
100
102 const DataType& Probability() const { return probability; }
103
105 DataType& Probability() { return probability; }
106
108 const DataType& Logits() const { return logits; }
109
111 DataType& Logits() { return logits; }
112
116 template<typename Archive>
117 void serialize(Archive& ar, const unsigned int /* version */)
118 {
119 // We just need to serialize each of the members.
120 ar & BOOST_SERIALIZATION_NVP(probability);
121 ar & BOOST_SERIALIZATION_NVP(logits);
122 ar & BOOST_SERIALIZATION_NVP(applyLogistic);
123 ar & BOOST_SERIALIZATION_NVP(eps);
124 }
125
126 private:
128 DataType probability;
129
132 DataType logits;
133
135 bool applyLogistic;
136
138 double eps;
139}; // class BernoulliDistribution
140
141} // namespace ann
142} // namespace mlpack
143
144// Include implementation.
145#include "bernoulli_distribution_impl.hpp"
146
147#endif
Multiple independent Bernoulli distributions.
double LogProbability(const DataType &observation) const
Return the log probabilities of the given matrix of observations.
DataType Sample() const
Return a matrix of randomly generated samples according to the probability distributions defined by this class.
double Probability(const DataType &observation) const
Return the probabilities of the given matrix of observations.
void LogProbBackward(const DataType &observation, DataType &output) const
Stores the gradient of the log probabilities of the observations in the output matrix.
DataType & Probability()
Return a modifiable reference to the probability matrix.
BernoulliDistribution(const DataType &param, const bool applyLogistic=true, const double eps=1e-10)
Create multiple independent Bernoulli distributions whose p values are given by the param parameter.
BernoulliDistribution()
Default constructor, which creates a Bernoulli distribution with zero dimension.
DataType & Logits()
Return a modifiable reference to the logits (pre-probability) matrix.
const DataType & Probability() const
Return the probability matrix.
void serialize(Archive &ar, const unsigned int)
Serialize the distribution.
const DataType & Logits() const
Return the logits matrix.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.