mlpack 3.4.2
recurrent.hpp
Go to the documentation of this file.
1
13#ifndef MLPACK_METHODS_ANN_LAYER_RECURRENT_HPP
14#define MLPACK_METHODS_ANN_LAYER_RECURRENT_HPP
15
16#include <mlpack/core.hpp>
17
18#include "../visitor/delete_visitor.hpp"
19#include "../visitor/delta_visitor.hpp"
20#include "../visitor/copy_visitor.hpp"
21#include "../visitor/output_parameter_visitor.hpp"
22
23#include "layer_types.hpp"
24#include "add_merge.hpp"
25#include "sequential.hpp"
26
27namespace mlpack {
28namespace ann {
29
39template <
40 typename InputDataType = arma::mat,
41 typename OutputDataType = arma::mat,
42 typename... CustomLayers
43>
45{
46 public:
52
55
65 template<typename StartModuleType,
66 typename InputModuleType,
67 typename FeedbackModuleType,
68 typename TransferModuleType>
69 Recurrent(const StartModuleType& start,
70 const InputModuleType& input,
71 const FeedbackModuleType& feedback,
72 const TransferModuleType& transfer,
73 const size_t rho);
74
82 template<typename eT>
83 void Forward(const arma::Mat<eT>& input, arma::Mat<eT>& output);
84
94 template<typename eT>
95 void Backward(const arma::Mat<eT>& /* input */,
96 const arma::Mat<eT>& gy,
97 arma::Mat<eT>& g);
98
99 /*
100 * Calculate the gradient using the output delta and the input activation.
101 *
102 * @param input The input parameter used for calculating the gradient.
103 * @param error The calculated error.
104 * @param gradient The calculated gradient.
105 */
106 template<typename eT>
107 void Gradient(const arma::Mat<eT>& input,
108 const arma::Mat<eT>& error,
109 arma::Mat<eT>& /* gradient */);
110
112 std::vector<LayerTypes<CustomLayers...> >& Model() { return network; }
113
115 bool Deterministic() const { return deterministic; }
117 bool& Deterministic() { return deterministic; }
118
120 OutputDataType const& Parameters() const { return parameters; }
122 OutputDataType& Parameters() { return parameters; }
123
125 OutputDataType const& OutputParameter() const { return outputParameter; }
127 OutputDataType& OutputParameter() { return outputParameter; }
128
130 OutputDataType const& Delta() const { return delta; }
132 OutputDataType& Delta() { return delta; }
133
135 OutputDataType const& Gradient() const { return gradient; }
137 OutputDataType& Gradient() { return gradient; }
138
140 size_t const& Rho() const { return rho; }
141
145 template<typename Archive>
146 void serialize(Archive& ar, const unsigned int /* version */);
147
148 private:
150 DeleteVisitor deleteVisitor;
151
153 CopyVisitor<CustomLayers...> copyVisitor;
154
156 LayerTypes<CustomLayers...> startModule;
157
159 LayerTypes<CustomLayers...> inputModule;
160
162 LayerTypes<CustomLayers...> feedbackModule;
163
165 LayerTypes<CustomLayers...> transferModule;
166
168 size_t rho;
169
171 size_t forwardStep;
172
174 size_t backwardStep;
175
177 size_t gradientStep;
178
180 bool deterministic;
181
184 bool ownsLayer;
185
187 OutputDataType parameters;
188
190 LayerTypes<CustomLayers...> initialModule;
191
193 LayerTypes<CustomLayers...> recurrentModule;
194
196 std::vector<LayerTypes<CustomLayers...> > network;
197
199 LayerTypes<CustomLayers...> mergeModule;
200
202 DeltaVisitor deltaVisitor;
203
205 OutputParameterVisitor outputParameterVisitor;
206
208 std::vector<arma::mat> feedbackOutputParameter;
209
211 OutputDataType delta;
212
214 OutputDataType gradient;
215
217 OutputDataType outputParameter;
218
220 arma::mat recurrentError;
221}; // class Recurrent
222
223} // namespace ann
224} // namespace mlpack
225
226// Include implementation.
227#include "recurrent_impl.hpp"
228
229#endif
This visitor is to support copy constructor for neural network module.
DeleteVisitor executes the destructor of the instantiated object.
DeltaVisitor exposes the delta parameter of the given module.
OutputParameterVisitor exposes the output parameter of the given module.
Implementation of the Recurrent layer class.
Definition: recurrent.hpp:45
OutputDataType const & Delta() const
Get the delta.
Definition: recurrent.hpp:130
OutputDataType const & Parameters() const
Get the parameters.
Definition: recurrent.hpp:120
size_t const & Rho() const
Get the number of steps to backpropagate through time.
Definition: recurrent.hpp:140
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
OutputDataType const & OutputParameter() const
Get the output parameter.
Definition: recurrent.hpp:125
void Gradient(const arma::Mat< eT > &input, const arma::Mat< eT > &error, arma::Mat< eT > &)
Calculate the gradient using the output delta and the input activation.
Recurrent(const StartModuleType &start, const InputModuleType &input, const FeedbackModuleType &feedback, const TransferModuleType &transfer, const size_t rho)
Create the Recurrent object using the specified modules.
Recurrent()
Default constructor—this will create a Recurrent object that can't be used, so be careful! Make sure to set all the parameters before use.
std::vector< LayerTypes< CustomLayers... > > & Model()
Get the model modules.
Definition: recurrent.hpp:112
bool & Deterministic()
Modify the value of the deterministic parameter.
Definition: recurrent.hpp:117
bool Deterministic() const
The value of the deterministic parameter.
Definition: recurrent.hpp:115
OutputDataType const & Gradient() const
Get the gradient.
Definition: recurrent.hpp:135
Recurrent(const Recurrent &)
Copy constructor.
OutputDataType & Gradient()
Modify the gradient.
Definition: recurrent.hpp:137
void Backward(const arma::Mat< eT > &, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f, using the results from the feed forward pass.
OutputDataType & OutputParameter()
Modify the output parameter.
Definition: recurrent.hpp:127
void serialize(Archive &ar, const unsigned int)
Serialize the layer.
OutputDataType & Parameters()
Modify the parameters.
Definition: recurrent.hpp:122
OutputDataType & Delta()
Modify the delta.
Definition: recurrent.hpp:132
Include all of the base components required to write mlpack methods, and the main mlpack Doxygen documentation.
boost::variant< AdaptiveMaxPooling< arma::mat, arma::mat > *, AdaptiveMeanPooling< arma::mat, arma::mat > *, Add< arma::mat, arma::mat > *, AddMerge< arma::mat, arma::mat > *, AlphaDropout< arma::mat, arma::mat > *, AtrousConvolution< NaiveConvolution< ValidConvolution >, NaiveConvolution< FullConvolution >, NaiveConvolution< ValidConvolution >, arma::mat, arma::mat > *, BaseLayer< LogisticFunction, arma::mat, arma::mat > *, BaseLayer< IdentityFunction, arma::mat, arma::mat > *, BaseLayer< TanhFunction, arma::mat, arma::mat > *, BaseLayer< SoftplusFunction, arma::mat, arma::mat > *, BaseLayer< RectifierFunction, arma::mat, arma::mat > *, BatchNorm< arma::mat, arma::mat > *, BilinearInterpolation< arma::mat, arma::mat > *, CELU< arma::mat, arma::mat > *, Concat< arma::mat, arma::mat > *, Concatenate< arma::mat, arma::mat > *, ConcatPerformance< NegativeLogLikelihood< arma::mat, arma::mat >, arma::mat, arma::mat > *, Constant< arma::mat, arma::mat > *, Convolution< NaiveConvolution< ValidConvolution >, NaiveConvolution< FullConvolution >, NaiveConvolution< ValidConvolution >, arma::mat, arma::mat > *, CReLU< arma::mat, arma::mat > *, DropConnect< arma::mat, arma::mat > *, Dropout< arma::mat, arma::mat > *, ELU< arma::mat, arma::mat > *, FastLSTM< arma::mat, arma::mat > *, FlexibleReLU< arma::mat, arma::mat > *, GRU< arma::mat, arma::mat > *, HardTanH< arma::mat, arma::mat > *, Join< arma::mat, arma::mat > *, LayerNorm< arma::mat, arma::mat > *, LeakyReLU< arma::mat, arma::mat > *, Linear< arma::mat, arma::mat, NoRegularizer > *, LinearNoBias< arma::mat, arma::mat, NoRegularizer > *, LogSoftMax< arma::mat, arma::mat > *, Lookup< arma::mat, arma::mat > *, LSTM< arma::mat, arma::mat > *, MaxPooling< arma::mat, arma::mat > *, MeanPooling< arma::mat, arma::mat > *, MiniBatchDiscrimination< arma::mat, arma::mat > *, MultiplyConstant< arma::mat, arma::mat > *, MultiplyMerge< arma::mat, arma::mat > *, NegativeLogLikelihood< arma::mat, arma::mat > *, NoisyLinear< arma::mat, 
arma::mat > *, Padding< arma::mat, arma::mat > *, PReLU< arma::mat, arma::mat > *, Softmax< arma::mat, arma::mat > *, SpatialDropout< arma::mat, arma::mat > *, TransposedConvolution< NaiveConvolution< ValidConvolution >, NaiveConvolution< ValidConvolution >, NaiveConvolution< ValidConvolution >, arma::mat, arma::mat > *, WeightNorm< arma::mat, arma::mat > *, MoreTypes, CustomLayers *... > LayerTypes
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1