mlpack 3.4.2
highway.hpp
Go to the documentation of this file.
1
13#ifndef MLPACK_METHODS_ANN_LAYER_HIGHWAY_HPP
14#define MLPACK_METHODS_ANN_LAYER_HIGHWAY_HPP
15
16#include <mlpack/prereqs.hpp>
17
18#include <boost/ptr_container/ptr_vector.hpp>
19
20#include "../visitor/delete_visitor.hpp"
21#include "../visitor/delta_visitor.hpp"
22#include "../visitor/output_height_visitor.hpp"
23#include "../visitor/output_parameter_visitor.hpp"
24#include "../visitor/output_width_visitor.hpp"
25
26#include "layer_types.hpp"
27#include "add_merge.hpp"
28
29namespace mlpack {
30namespace ann {
31
56template <
57 typename InputDataType = arma::mat,
58 typename OutputDataType = arma::mat,
59 typename... CustomLayers>
61{
62 public:
65
72 Highway(const size_t inSize, const bool model = true);
73
76
80 void Reset();
81
89 template<typename eT>
90 void Forward(const arma::Mat<eT>& input, arma::Mat<eT>& output);
91
101 template<typename eT>
102 void Backward(const arma::Mat<eT>& /* input */,
103 const arma::Mat<eT>& gy,
104 arma::Mat<eT>& g);
105
113 template<typename eT>
114 void Gradient(const arma::Mat<eT>& input,
115 const arma::Mat<eT>& error,
116 arma::Mat<eT>& gradient);
117
123 template <class LayerType, class... Args>
124 void Add(Args... args)
125 {
126 network.push_back(new LayerType(args...));
127 networkOwnerships.push_back(true);
128 }
129
136 {
137 network.push_back(layer);
138 networkOwnerships.push_back(false);
139 }
140
142 std::vector<LayerTypes<CustomLayers...> >& Model()
143 {
144 if (model)
145 {
146 return network;
147 }
148
149 return empty;
150 }
151
153 OutputDataType const& Parameters() const { return weights; }
155 OutputDataType& Parameters() { return weights; }
156
158 InputDataType const& InputParameter() const { return inputParameter; }
160 InputDataType& InputParameter() { return inputParameter; }
161
163 OutputDataType const& OutputParameter() const { return outputParameter; }
165 OutputDataType& OutputParameter() { return outputParameter; }
166
168 OutputDataType const& Delta() const { return delta; }
170 OutputDataType& Delta() { return delta; }
171
173 OutputDataType const& Gradient() const { return gradient; }
175 OutputDataType& Gradient() { return gradient; }
176
178 size_t InSize() const { return inSize; }
179
183 template<typename Archive>
184 void serialize(Archive& ar, const unsigned int /* version */);
185
186 private:
188 size_t inSize;
189
191 bool model;
192
194 bool reset;
195
197 std::vector<LayerTypes<CustomLayers...> > network;
198
200 std::vector<bool> networkOwnerships;
201
203 std::vector<LayerTypes<CustomLayers...> > empty;
204
206 OutputDataType weights;
207
209 OutputDataType delta;
210
212 OutputDataType gradient;
213
215 OutputDataType transformWeight;
216
218 OutputDataType transformBias;
219
221 OutputDataType transformGate;
222
224 OutputDataType transformGateActivation;
225
227 OutputDataType transformGateError;
228
230 InputDataType inputParameter;
231
233 OutputDataType outputParameter;
234
236 size_t width;
237
239 size_t height;
240
242 OutputDataType networkOutput;
243
245 DeltaVisitor deltaVisitor;
246
248 OutputParameterVisitor outputParameterVisitor;
249
251 DeleteVisitor deleteVisitor;
252
254 OutputWidthVisitor outputWidthVisitor;
255
257 OutputHeightVisitor outputHeightVisitor;
258}; // class Highway
259
260} // namespace ann
261} // namespace mlpack
262
263// Include implementation.
264#include "highway_impl.hpp"
265
266#endif
DeleteVisitor executes the destructor of the instantiated object.
DeltaVisitor exposes the delta parameter of the given module.
Implementation of the Highway layer.
Definition: highway.hpp:61
OutputDataType const & Delta() const
Get the delta.
Definition: highway.hpp:168
OutputDataType const & Parameters() const
Get the parameters.
Definition: highway.hpp:153
void Reset()
Reset the layer parameter.
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Ordinary feed-forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
InputDataType & InputParameter()
Modify the input parameter.
Definition: highway.hpp:160
OutputDataType const & OutputParameter() const
Get the output parameter.
Definition: highway.hpp:163
void Add(LayerTypes< CustomLayers... > layer)
Add a new module to the model.
Definition: highway.hpp:135
std::vector< LayerTypes< CustomLayers... > > & Model()
Return the modules of the model.
Definition: highway.hpp:142
void Add(Args... args)
Add a new module to the model.
Definition: highway.hpp:124
Highway(const size_t inSize, const bool model=true)
Create the Highway object.
void Gradient(const arma::Mat< eT > &input, const arma::Mat< eT > &error, arma::Mat< eT > &gradient)
Calculate the gradient using the output delta and the input activation.
OutputDataType const & Gradient() const
Get the gradient.
Definition: highway.hpp:173
InputDataType const & InputParameter() const
Get the input parameter.
Definition: highway.hpp:158
OutputDataType & Gradient()
Modify the gradient.
Definition: highway.hpp:175
void Backward(const arma::Mat< eT > &, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed-backward pass of a neural network, calculating the function f(x) by propagating x backwards through f, using the results from the feed-forward pass.
OutputDataType & OutputParameter()
Modify the output parameter.
Definition: highway.hpp:165
size_t InSize() const
Get the number of input units.
Definition: highway.hpp:178
Highway()
Create the Highway object.
~Highway()
Destroy the Highway object.
void serialize(Archive &ar, const unsigned int)
Serialize the layer.
OutputDataType & Parameters()
Modify the parameters.
Definition: highway.hpp:155
OutputDataType & Delta()
Modify the delta.
Definition: highway.hpp:170
OutputHeightVisitor exposes the OutputHeight() method of the given module.
OutputParameterVisitor exposes the output parameter of the given module.
OutputWidthVisitor exposes the OutputWidth() method of the given module.
boost::variant< AdaptiveMaxPooling< arma::mat, arma::mat > *, AdaptiveMeanPooling< arma::mat, arma::mat > *, Add< arma::mat, arma::mat > *, AddMerge< arma::mat, arma::mat > *, AlphaDropout< arma::mat, arma::mat > *, AtrousConvolution< NaiveConvolution< ValidConvolution >, NaiveConvolution< FullConvolution >, NaiveConvolution< ValidConvolution >, arma::mat, arma::mat > *, BaseLayer< LogisticFunction, arma::mat, arma::mat > *, BaseLayer< IdentityFunction, arma::mat, arma::mat > *, BaseLayer< TanhFunction, arma::mat, arma::mat > *, BaseLayer< SoftplusFunction, arma::mat, arma::mat > *, BaseLayer< RectifierFunction, arma::mat, arma::mat > *, BatchNorm< arma::mat, arma::mat > *, BilinearInterpolation< arma::mat, arma::mat > *, CELU< arma::mat, arma::mat > *, Concat< arma::mat, arma::mat > *, Concatenate< arma::mat, arma::mat > *, ConcatPerformance< NegativeLogLikelihood< arma::mat, arma::mat >, arma::mat, arma::mat > *, Constant< arma::mat, arma::mat > *, Convolution< NaiveConvolution< ValidConvolution >, NaiveConvolution< FullConvolution >, NaiveConvolution< ValidConvolution >, arma::mat, arma::mat > *, CReLU< arma::mat, arma::mat > *, DropConnect< arma::mat, arma::mat > *, Dropout< arma::mat, arma::mat > *, ELU< arma::mat, arma::mat > *, FastLSTM< arma::mat, arma::mat > *, FlexibleReLU< arma::mat, arma::mat > *, GRU< arma::mat, arma::mat > *, HardTanH< arma::mat, arma::mat > *, Join< arma::mat, arma::mat > *, LayerNorm< arma::mat, arma::mat > *, LeakyReLU< arma::mat, arma::mat > *, Linear< arma::mat, arma::mat, NoRegularizer > *, LinearNoBias< arma::mat, arma::mat, NoRegularizer > *, LogSoftMax< arma::mat, arma::mat > *, Lookup< arma::mat, arma::mat > *, LSTM< arma::mat, arma::mat > *, MaxPooling< arma::mat, arma::mat > *, MeanPooling< arma::mat, arma::mat > *, MiniBatchDiscrimination< arma::mat, arma::mat > *, MultiplyConstant< arma::mat, arma::mat > *, MultiplyMerge< arma::mat, arma::mat > *, NegativeLogLikelihood< arma::mat, arma::mat > *, NoisyLinear< arma::mat, 
arma::mat > *, Padding< arma::mat, arma::mat > *, PReLU< arma::mat, arma::mat > *, Softmax< arma::mat, arma::mat > *, SpatialDropout< arma::mat, arma::mat > *, TransposedConvolution< NaiveConvolution< ValidConvolution >, NaiveConvolution< ValidConvolution >, NaiveConvolution< ValidConvolution >, arma::mat, arma::mat > *, WeightNorm< arma::mat, arma::mat > *, MoreTypes, CustomLayers *... > LayerTypes
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.