mlpack 3.4.2
batch_norm.hpp
Go to the documentation of this file.
1
13#ifndef MLPACK_METHODS_ANN_LAYER_BATCHNORM_HPP
14#define MLPACK_METHODS_ANN_LAYER_BATCHNORM_HPP
15
16#include <mlpack/prereqs.hpp>
17
18namespace mlpack {
19namespace ann {
20
52template <
53 typename InputDataType = arma::mat,
54 typename OutputDataType = arma::mat
55>
57{
58 public:
61
71 BatchNorm(const size_t size,
72 const double eps = 1e-8,
73 const bool average = true,
74 const double momentum = 0.1);
75
79 void Reset();
80
89 template<typename eT>
90 void Forward(const arma::Mat<eT>& input, arma::Mat<eT>& output);
91
99 template<typename eT>
100 void Backward(const arma::Mat<eT>& input,
101 const arma::Mat<eT>& gy,
102 arma::Mat<eT>& g);
103
111 template<typename eT>
112 void Gradient(const arma::Mat<eT>& input,
113 const arma::Mat<eT>& error,
114 arma::Mat<eT>& gradient);
115
117 OutputDataType const& Parameters() const { return weights; }
119 OutputDataType& Parameters() { return weights; }
120
122 OutputDataType const& OutputParameter() const { return outputParameter; }
124 OutputDataType& OutputParameter() { return outputParameter; }
125
127 OutputDataType const& Delta() const { return delta; }
129 OutputDataType& Delta() { return delta; }
130
132 OutputDataType const& Gradient() const { return gradient; }
134 OutputDataType& Gradient() { return gradient; }
135
137 bool Deterministic() const { return deterministic; }
139 bool& Deterministic() { return deterministic; }
140
142 OutputDataType const& TrainingMean() const { return runningMean; }
144 OutputDataType& TrainingMean() { return runningMean; }
145
147 OutputDataType const& TrainingVariance() const { return runningVariance; }
149 OutputDataType& TrainingVariance() { return runningVariance; }
150
152 size_t InputSize() const { return size; }
153
155 double Epsilon() const { return eps; }
156
158 double Momentum() const { return momentum; }
159
161 bool Average() const { return average; }
162
164 size_t WeightSize() const { return 2 * size; }
165
169 template<typename Archive>
170 void serialize(Archive& ar, const unsigned int /* version */);
171
172 private:
174 size_t size;
175
177 double eps;
178
181 bool average;
182
184 double momentum;
185
187 bool loading;
188
190 OutputDataType gamma;
191
193 OutputDataType beta;
194
196 OutputDataType mean;
197
199 OutputDataType variance;
200
202 OutputDataType weights;
203
208 bool deterministic;
209
211 size_t count;
212
215 double averageFactor;
216
218 OutputDataType runningMean;
219
221 OutputDataType runningVariance;
222
224 OutputDataType gradient;
225
227 OutputDataType delta;
228
230 OutputDataType outputParameter;
231
233 arma::cube normalized;
234
236 arma::cube inputMean;
237}; // class BatchNorm
238
239} // namespace ann
240} // namespace mlpack
241
242// Include the implementation.
243#include "batch_norm_impl.hpp"
244
245#endif
Declaration of the Batch Normalization layer class.
Definition: batch_norm.hpp:57
OutputDataType const & Delta() const
Get the delta.
Definition: batch_norm.hpp:127
OutputDataType const & TrainingVariance() const
Get the variance over the training data.
Definition: batch_norm.hpp:147
OutputDataType const & Parameters() const
Get the parameters.
Definition: batch_norm.hpp:117
void Reset()
Reset the layer parameters.
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Forward pass of the Batch Normalization layer.
double Momentum() const
Get the momentum value.
Definition: batch_norm.hpp:158
BatchNorm(const size_t size, const double eps=1e-8, const bool average=true, const double momentum=0.1)
Create the BatchNorm layer object for a specified number of input units.
OutputDataType & TrainingVariance()
Modify the variance over the training data.
Definition: batch_norm.hpp:149
OutputDataType const & OutputParameter() const
Get the output parameter.
Definition: batch_norm.hpp:122
size_t InputSize() const
Get the number of input units / channels.
Definition: batch_norm.hpp:152
void Backward(const arma::Mat< eT > &input, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Backward pass through the layer.
size_t WeightSize() const
Get the size of the weights.
Definition: batch_norm.hpp:164
BatchNorm()
Create the BatchNorm object.
bool & Deterministic()
Modify the value of the deterministic parameter.
Definition: batch_norm.hpp:139
bool Deterministic() const
Get the value of the deterministic parameter.
Definition: batch_norm.hpp:137
OutputDataType const & TrainingMean() const
Get the mean over the training data.
Definition: batch_norm.hpp:142
bool Average() const
Get the average parameter.
Definition: batch_norm.hpp:161
void Gradient(const arma::Mat< eT > &input, const arma::Mat< eT > &error, arma::Mat< eT > &gradient)
Calculate the gradient using the output delta and the input activations.
OutputDataType const & Gradient() const
Get the gradient.
Definition: batch_norm.hpp:132
OutputDataType & Gradient()
Modify the gradient.
Definition: batch_norm.hpp:134
OutputDataType & TrainingMean()
Modify the mean over the training data.
Definition: batch_norm.hpp:144
OutputDataType & OutputParameter()
Modify the output parameter.
Definition: batch_norm.hpp:124
void serialize(Archive &ar, const unsigned int)
Serialize the layer.
OutputDataType & Parameters()
Modify the parameters.
Definition: batch_norm.hpp:119
double Epsilon() const
Get the epsilon value.
Definition: batch_norm.hpp:155
OutputDataType & Delta()
Modify the delta.
Definition: batch_norm.hpp:129
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.