mlpack 3.4.2
reparametrization.hpp
Go to the documentation of this file.
1
13#ifndef MLPACK_METHODS_ANN_LAYER_REPARAMETRIZATION_HPP
14#define MLPACK_METHODS_ANN_LAYER_REPARAMETRIZATION_HPP
15
16#include <mlpack/prereqs.hpp>
17
18#include "layer_types.hpp"
19#include "../activation_functions/softplus_function.hpp"
20
21namespace mlpack {
22namespace ann {
23
// Reparametrization layer class template.
//
// NOTE(review): this is a documentation-page listing. The leading number on
// each line is the listing's own line number, and the class-head line is not
// present between the template parameter list and the opening brace (the
// numbering jumps from 55 to 57). The closing comment at the end of the block
// names the class as Reparametrization.
//
// InputDataType / OutputDataType both default to arma::mat; OutputDataType is
// the type of every matrix member held by the layer.
52template <
53 typename InputDataType = arma::mat,
54 typename OutputDataType = arma::mat
55>
57{
58 public:
// NOTE(review): the member index for this page also lists a default
// constructor, Reparametrization() ("Create the Reparametrization object"),
// but its declaration line does not appear in this listing (58 -> 61).
61
// Create the Reparametrization layer object using the specified sample
// vector size.
//
// latentSize: the sample vector size; exposed through OutputSize().
// stochastic: flag exposed through Stochastic(); how it affects Forward() is
//             decided in the implementation file, not visible here.
// includeKl:  when false, Loss() returns 0 instead of the KL term.
// beta:       multiplier applied to the KL term in Loss(); defaults to 1.
70 Reparametrization(const size_t latentSize,
71 const bool stochastic = true,
72 const bool includeKl = true,
73 const double beta = 1);
74
// Ordinary feed forward pass of a neural network, evaluating the function
// f(x): reads `input` and writes the result to `output`.
// Only declared here; presumably defined in reparametrization_impl.hpp, which
// is included at the end of this header -- confirm there.
82 template<typename eT>
83 void Forward(const arma::Mat<eT>& input, arma::Mat<eT>& output);
84
// Ordinary feed backward pass of a neural network: takes the layer input and
// the backpropagated error `gy`, and writes the resulting gradient to `g`.
// Only declared here; presumably defined in reparametrization_impl.hpp.
94 template<typename eT>
95 void Backward(const arma::Mat<eT>& input,
96 const arma::Mat<eT>& gy,
97 arma::Mat<eT>& g);
98
//! Get the output parameter.
100 OutputDataType const& OutputParameter() const { return outputParameter; }
//! Modify the output parameter.
102 OutputDataType& OutputParameter() { return outputParameter; }
103
//! Get the delta.
105 OutputDataType const& Delta() const { return delta; }
//! Modify the delta.
107 OutputDataType& Delta() { return delta; }
108
//! Get the output size (this is the stored latentSize).
110 size_t const& OutputSize() const { return latentSize; }
//! Modify the output size (this is the stored latentSize).
112 size_t& OutputSize() { return latentSize; }
113
// Get the KL divergence with standard normal.
//
// Returns 0 when includeKl is false. Otherwise returns
//   -0.5 * beta * sum(2 * log(stdDev) - stdDev^2 - mean^2 + 1) / mean.n_cols
// where the sum runs over all elements and the result is divided by the
// number of columns of `mean`.
// NOTE(review): reads the `mean` and `stdDev` members, which are presumably
// filled in by Forward() -- confirm in the implementation file. The function
// is not declared const even though nothing in this body modifies the object.
115 double Loss()
116 {
117 if (!includeKl)
118 return 0;
119
120 return -0.5 * beta * arma::accu(2 * arma::log(stdDev) - arma::pow(stdDev, 2)
121 - arma::pow(mean, 2) + 1) / mean.n_cols;
122 }
123
//! Get the value of the stochastic parameter.
125 bool Stochastic() const { return stochastic; }
126
//! Get the value of the includeKl parameter.
128 bool IncludeKL() const { return includeKl; }
129
//! Get the value of the beta hyperparameter.
131 double Beta() const { return beta; }
132
//! Serialize the layer.
136 template<typename Archive>
137 void serialize(Archive& ar, const unsigned int /* version */);
138
139 private:
// Sample vector size; returned by OutputSize().
141 size_t latentSize;
142
// Value returned by Stochastic().
144 bool stochastic;
145
// When false, Loss() returns 0.
147 bool includeKl;
148
// Multiplier on the KL term in Loss(); returned by Beta().
150 double beta;
151
// Object returned by Delta().
153 OutputDataType delta;
154
// NOTE(review): presumably the Gaussian noise sample used by Forward() --
// not referenced anywhere in this header; confirm in the implementation.
156 OutputDataType gaussianSample;
157
// Mean term read by Loss(); its column count is the divisor there.
159 OutputDataType mean;
160
// NOTE(review): presumably the value from which stdDev is derived (the
// softplus activation header is included above) -- confirm in the
// implementation.
163 OutputDataType preStdDev;
164
// Standard deviation term read by Loss().
166 OutputDataType stdDev;
167
// Object returned by OutputParameter().
169 OutputDataType outputParameter;
170}; // class Reparametrization
171
172} // namespace ann
173} // namespace mlpack
174
175// Include implementation.
176#include "reparametrization_impl.hpp"
177
178#endif
Implementation of the Reparametrization layer class.
OutputDataType const & Delta() const
Get the delta.
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
Reparametrization(const size_t latentSize, const bool stochastic=true, const bool includeKl=true, const double beta=1)
Create the Reparametrization layer object using the specified sample vector size.
size_t & OutputSize()
Modify the output size.
OutputDataType const & OutputParameter() const
Get the output parameter.
void Backward(const arma::Mat< eT > &input, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f, using the results from the feed forward pass.
Reparametrization()
Create the Reparametrization object.
size_t const & OutputSize() const
Get the output size.
double Loss()
Get the KL divergence with standard normal.
bool IncludeKL() const
Get the value of the includeKl parameter.
double Beta() const
Get the value of the beta hyperparameter.
OutputDataType & OutputParameter()
Modify the output parameter.
bool Stochastic() const
Get the value of the stochastic parameter.
void serialize(Archive &ar, const unsigned int)
Serialize the layer.
OutputDataType & Delta()
Modify the delta.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.