mlpack 3.4.2
weight_norm.hpp
Go to the documentation of this file.
1
12#ifndef MLPACK_METHODS_ANN_LAYER_WEIGHTNORM_HPP
13#define MLPACK_METHODS_ANN_LAYER_WEIGHTNORM_HPP
14
15#include <mlpack/prereqs.hpp>
16#include "layer_types.hpp"
17
18#include "../visitor/delete_visitor.hpp"
19#include "../visitor/delta_visitor.hpp"
20#include "../visitor/output_parameter_visitor.hpp"
21#include "../visitor/reset_visitor.hpp"
22#include "../visitor/weight_size_visitor.hpp"
23#include "../visitor/weight_set_visitor.hpp"
24
25namespace mlpack {
26namespace ann {
27
56template <
57 typename InputDataType = arma::mat,
58 typename OutputDataType = arma::mat,
59 typename... CustomLayers
60>
62{
63 public:
70
73
77 void Reset();
78
88 template<typename eT>
89 void Forward(const arma::Mat<eT>& input, arma::Mat<eT>& output);
90
99 template<typename eT>
100 void Backward(const arma::Mat<eT>& input,
101 const arma::Mat<eT>& gy,
102 arma::Mat<eT>& g);
103
112 template<typename eT>
113 void Gradient(const arma::Mat<eT>& input,
114 const arma::Mat<eT>& error,
115 arma::Mat<eT>& gradient);
116
118 OutputDataType const& Delta() const { return delta; }
120 OutputDataType& Delta() { return delta; }
121
123 OutputDataType const& Gradient() const { return gradient; }
125 OutputDataType& Gradient() { return gradient; }
126
128 OutputDataType const& OutputParameter() const { return outputParameter; }
130 OutputDataType& OutputParameter() { return outputParameter; }
131
133 OutputDataType const& Parameters() const { return weights; }
135 OutputDataType& Parameters() { return weights; }
136
138 LayerTypes<CustomLayers...> const& Layer() { return wrappedLayer; }
139
143 template<typename Archive>
144 void serialize(Archive& ar, const unsigned int /* version */);
145
146 private:
148 size_t biasWeightSize;
149
151 DeleteVisitor deleteVisitor;
152
154 OutputDataType delta;
155
157 DeltaVisitor deltaVisitor;
158
160 OutputDataType gradient;
161
163 LayerTypes<CustomLayers...> wrappedLayer;
164
166 size_t layerWeightSize;
167
169 OutputDataType outputParameter;
170
172 OutputParameterVisitor outputParameterVisitor;
173
175 void ResetGradients(arma::mat& gradient);
176
178 ResetVisitor resetVisitor;
179
181 OutputDataType scalarParameter;
182
184 OutputDataType vectorParameter;
185
187 OutputDataType weights;
188
190 WeightSizeVisitor weightSizeVisitor;
191
193 OutputDataType layerGradients;
194
196 OutputDataType layerWeights;
197}; // class WeightNorm
198
199} // namespace ann
200} // namespace mlpack
201
202// Include the implementation.
203#include "weight_norm_impl.hpp"
204
205#endif
DeleteVisitor executes the destructor of the instantiated object.
DeltaVisitor exposes the delta parameter of the given module.
OutputParameterVisitor exposes the output parameter of the given module.
ResetVisitor executes the Reset() function.
Declaration of the WeightNorm layer class.
Definition: weight_norm.hpp:62
WeightNorm(LayerTypes< CustomLayers... > layer=LayerTypes< CustomLayers... >())
Create the WeightNorm layer object.
OutputDataType const & Delta() const
Get the delta.
OutputDataType const & Parameters() const
Get the parameters.
void Reset()
Reset the layer parameters.
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Forward pass of the WeightNorm layer.
OutputDataType const & OutputParameter() const
Get the output parameter.
void Backward(const arma::Mat< eT > &input, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Backward pass through the layer.
~WeightNorm()
Destructor to release allocated memory.
void Gradient(const arma::Mat< eT > &input, const arma::Mat< eT > &error, arma::Mat< eT > &gradient)
Calculate the gradient using the output delta, input activations and the weights of the wrapped layer...
OutputDataType const & Gradient() const
Get the gradient.
OutputDataType & Gradient()
Modify the gradient.
LayerTypes< CustomLayers... > const & Layer()
Get the wrapped layer.
OutputDataType & OutputParameter()
Modify the output parameter.
void serialize(Archive &ar, const unsigned int)
Serialize the layer.
OutputDataType & Parameters()
Modify the parameters.
OutputDataType & Delta()
Modify the delta.
WeightSizeVisitor returns the number of weights of the given module.
boost::variant< AdaptiveMaxPooling< arma::mat, arma::mat > *, AdaptiveMeanPooling< arma::mat, arma::mat > *, Add< arma::mat, arma::mat > *, AddMerge< arma::mat, arma::mat > *, AlphaDropout< arma::mat, arma::mat > *, AtrousConvolution< NaiveConvolution< ValidConvolution >, NaiveConvolution< FullConvolution >, NaiveConvolution< ValidConvolution >, arma::mat, arma::mat > *, BaseLayer< LogisticFunction, arma::mat, arma::mat > *, BaseLayer< IdentityFunction, arma::mat, arma::mat > *, BaseLayer< TanhFunction, arma::mat, arma::mat > *, BaseLayer< SoftplusFunction, arma::mat, arma::mat > *, BaseLayer< RectifierFunction, arma::mat, arma::mat > *, BatchNorm< arma::mat, arma::mat > *, BilinearInterpolation< arma::mat, arma::mat > *, CELU< arma::mat, arma::mat > *, Concat< arma::mat, arma::mat > *, Concatenate< arma::mat, arma::mat > *, ConcatPerformance< NegativeLogLikelihood< arma::mat, arma::mat >, arma::mat, arma::mat > *, Constant< arma::mat, arma::mat > *, Convolution< NaiveConvolution< ValidConvolution >, NaiveConvolution< FullConvolution >, NaiveConvolution< ValidConvolution >, arma::mat, arma::mat > *, CReLU< arma::mat, arma::mat > *, DropConnect< arma::mat, arma::mat > *, Dropout< arma::mat, arma::mat > *, ELU< arma::mat, arma::mat > *, FastLSTM< arma::mat, arma::mat > *, FlexibleReLU< arma::mat, arma::mat > *, GRU< arma::mat, arma::mat > *, HardTanH< arma::mat, arma::mat > *, Join< arma::mat, arma::mat > *, LayerNorm< arma::mat, arma::mat > *, LeakyReLU< arma::mat, arma::mat > *, Linear< arma::mat, arma::mat, NoRegularizer > *, LinearNoBias< arma::mat, arma::mat, NoRegularizer > *, LogSoftMax< arma::mat, arma::mat > *, Lookup< arma::mat, arma::mat > *, LSTM< arma::mat, arma::mat > *, MaxPooling< arma::mat, arma::mat > *, MeanPooling< arma::mat, arma::mat > *, MiniBatchDiscrimination< arma::mat, arma::mat > *, MultiplyConstant< arma::mat, arma::mat > *, MultiplyMerge< arma::mat, arma::mat > *, NegativeLogLikelihood< arma::mat, arma::mat > *, NoisyLinear< arma::mat, arma::mat > *, Padding< arma::mat, arma::mat > *, PReLU< arma::mat, arma::mat > *, Softmax< arma::mat, arma::mat > *, SpatialDropout< arma::mat, arma::mat > *, TransposedConvolution< NaiveConvolution< ValidConvolution >, NaiveConvolution< ValidConvolution >, NaiveConvolution< ValidConvolution >, arma::mat, arma::mat > *, WeightNorm< arma::mat, arma::mat > *, MoreTypes, CustomLayers *... > LayerTypes
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.