mlpack 3.4.2
layer_norm.hpp
Go to the documentation of this file.
1
12#ifndef MLPACK_METHODS_ANN_LAYER_LAYERNORM_HPP
13#define MLPACK_METHODS_ANN_LAYER_LAYERNORM_HPP
14
15#include <mlpack/prereqs.hpp>
16
17namespace mlpack {
18namespace ann {
19
61template <
62 typename InputDataType = arma::mat,
63 typename OutputDataType = arma::mat
64>
66{
67 public:
70
77 LayerNorm(const size_t size, const double eps = 1e-8);
78
82 void Reset();
83
92 template<typename eT>
93 void Forward(const arma::Mat<eT>& input, arma::Mat<eT>& output);
94
102 template<typename eT>
103 void Backward(const arma::Mat<eT>& input,
104 const arma::Mat<eT>& gy,
105 arma::Mat<eT>& g);
106
114 template<typename eT>
115 void Gradient(const arma::Mat<eT>& input,
116 const arma::Mat<eT>& error,
117 arma::Mat<eT>& gradient);
118
120 OutputDataType const& Parameters() const { return weights; }
122 OutputDataType& Parameters() { return weights; }
123
125 OutputDataType const& OutputParameter() const { return outputParameter; }
127 OutputDataType& OutputParameter() { return outputParameter; }
128
130 OutputDataType const& Delta() const { return delta; }
132 OutputDataType& Delta() { return delta; }
133
135 OutputDataType const& Gradient() const { return gradient; }
137 OutputDataType& Gradient() { return gradient; }
138
140 OutputDataType Mean() { return mean; }
141
143 OutputDataType Variance() { return variance; }
144
146 size_t InSize() const { return size; }
147
149 double Epsilon() const { return eps; }
150
154 template<typename Archive>
155 void serialize(Archive& ar, const unsigned int /* version */);
156
157 private:
159 size_t size;
160
162 double eps;
163
165 bool loading;
166
168 OutputDataType gamma;
169
171 OutputDataType beta;
172
174 OutputDataType weights;
175
177 OutputDataType mean;
178
180 OutputDataType variance;
181
183 OutputDataType gradient;
184
186 OutputDataType delta;
187
189 OutputDataType outputParameter;
190
192 OutputDataType normalized;
193
195 OutputDataType inputMean;
196}; // class LayerNorm
197
198} // namespace ann
199} // namespace mlpack
200
201// Include the implementation.
202#include "layer_norm_impl.hpp"
203
204#endif
Declaration of the Layer Normalization class.
Definition: layer_norm.hpp:66
LayerNorm(const size_t size, const double eps=1e-8)
Create the LayerNorm object for a specified number of input units.
OutputDataType const & Delta() const
Get the delta.
Definition: layer_norm.hpp:130
OutputDataType const & Parameters() const
Get the parameters.
Definition: layer_norm.hpp:120
void Reset()
Reset the layer parameters.
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Forward pass of Layer Normalization.
OutputDataType const & OutputParameter() const
Get the output parameter.
Definition: layer_norm.hpp:125
void Backward(const arma::Mat< eT > &input, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Backward pass through the layer.
OutputDataType Variance()
Get the variance computed across a single training data point.
Definition: layer_norm.hpp:143
LayerNorm()
Create the LayerNorm object.
void Gradient(const arma::Mat< eT > &input, const arma::Mat< eT > &error, arma::Mat< eT > &gradient)
Calculate the gradient using the output delta and the input activations.
OutputDataType const & Gradient() const
Get the gradient.
Definition: layer_norm.hpp:135
OutputDataType Mean()
Get the mean computed across a single training data point.
Definition: layer_norm.hpp:140
OutputDataType & Gradient()
Modify the gradient.
Definition: layer_norm.hpp:137
OutputDataType & OutputParameter()
Modify the output parameter.
Definition: layer_norm.hpp:127
size_t InSize() const
Get the number of input units.
Definition: layer_norm.hpp:146
void serialize(Archive &ar, const unsigned int)
Serialize the layer.
OutputDataType & Parameters()
Modify the parameters.
Definition: layer_norm.hpp:122
double Epsilon() const
Get the value of epsilon.
Definition: layer_norm.hpp:149
OutputDataType & Delta()
Modify the delta.
Definition: layer_norm.hpp:132
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.