#ifndef MLPACK_METHODS_ANN_LAYER_RECURRENT_ATTENTION_HPP
#define MLPACK_METHODS_ANN_LAYER_RECURRENT_ATTENTION_HPP

#include <mlpack/prereqs.hpp>
#include <boost/ptr_container/ptr_vector.hpp>

#include "../visitor/delta_visitor.hpp"
#include "../visitor/gradient_visitor.hpp"
#include "../visitor/output_parameter_visitor.hpp"
#include "../visitor/reset_visitor.hpp"
#include "../visitor/weight_size_visitor.hpp"

#include "layer_types.hpp"
namespace mlpack {
namespace ann {

/**
 * This class implements the Recurrent Model for Visual Attention, using a
 * variety of possible layer implementations.
 *
 * @tparam InputDataType Type of the input data (arma::colvec, arma::mat,
 *         arma::sp_mat or arma::cube).
 * @tparam OutputDataType Type of the output data (arma::colvec, arma::mat,
 *         arma::sp_mat or arma::cube).
 */
template<
    typename InputDataType = arma::mat,
    typename OutputDataType = arma::mat
>
class RecurrentAttention
{
 public:
  /**
   * Default constructor: this will not give a usable RecurrentAttention
   * object, so be sure to set all the parameters before use.
   */
  RecurrentAttention();
  /**
   * Create the RecurrentAttention object using the specified modules.
   *
   * @param outSize The module output size.
   * @param rnn The recurrent neural network module.
   * @param action The action module.
   * @param rho Maximum number of steps to backpropagate through time (BPTT).
   */
  template<typename RNNModuleType, typename ActionModuleType>
  RecurrentAttention(const size_t outSize,
                     const RNNModuleType& rnn,
                     const ActionModuleType& action,
                     const size_t rho);
  /**
   * Ordinary feed forward pass of a neural network, evaluating the function
   * f(x) by propagating the activity forward through f.
   *
   * @param input Input data used for evaluating the specified function.
   * @param output Resulting output activation.
   */
  template<typename eT>
  void Forward(const arma::Mat<eT>& input, arma::Mat<eT>& output);
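  // Implementation note: following the Recurrent Model for Visual Attention,
  // the forward pass unrolls the attention recurrence for rho steps,
  // alternating between the action module (which produces the next attention
  // input from the recurrent state) and the recurrent module; the output of
  // the final step becomes the layer output. See recurrent_attention_impl.hpp
  // for the exact loop.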
  /**
   * Ordinary feed backward pass of a neural network, calculating the function
   * f(x) by propagating x backwards through f, using the results from the
   * feed forward pass.
   *
   * @param input The propagated input activation.
   * @param gy The backpropagated error.
   * @param g The calculated gradient.
   */
  template<typename eT>
  void Backward(const arma::Mat<eT>& /* input */,
                const arma::Mat<eT>& gy,
                arma::Mat<eT>& g);
  /*
   * Calculate the gradient using the output delta and the input activation.
   */
  template<typename eT>
  void Gradient(const arma::Mat<eT>& /* input */,
                const arma::Mat<eT>& /* error */,
                arma::Mat<eT>& /* gradient */);
  //! Get the model modules.
  std::vector<LayerTypes<>>& Model() { return network; }

  //! The value of the deterministic parameter.
  bool Deterministic() const { return deterministic; }
  //! Modify the value of the deterministic parameter.
  bool& Deterministic() { return deterministic; }
  //! Get the parameters.
  OutputDataType const& Parameters() const { return parameters; }
  //! Modify the parameters.
  OutputDataType& Parameters() { return parameters; }

  //! Get the output parameter.
  OutputDataType const& OutputParameter() const { return outputParameter; }
  //! Modify the output parameter.
  OutputDataType& OutputParameter() { return outputParameter; }
  //! Get the delta.
  OutputDataType const& Delta() const { return delta; }
  //! Modify the delta.
  OutputDataType& Delta() { return delta; }
  //! Get the gradient.
  OutputDataType const& Gradient() const { return gradient; }
  //! Modify the gradient.
  OutputDataType& Gradient() { return gradient; }

  //! Get the module output size.
  size_t OutSize() const { return outSize; }
  //! Get the number of steps to backpropagate through time.
  size_t const& Rho() const { return rho; }
  /**
   * Serialize the layer.
   */
  template<typename Archive>
  void serialize(Archive& ar, const unsigned int /* version */);

 private:
  //! Calculate the gradient of the attention module.
  void IntermediateGradient()
  {
    intermediateGradient.zeros();

    // Gradient of the action module: at the first BPTT step the action
    // module saw the initial input, afterwards its own previous output.
    if (backwardStep == (rho - 1))
    {
      boost::apply_visitor(GradientVisitor(initialInput, actionError),
          actionModule);
    }
    else
    {
      boost::apply_visitor(GradientVisitor(boost::apply_visitor(
          outputParameterVisitor, actionModule), actionError), actionModule);
    }

    // Gradient of the recurrent module.
    boost::apply_visitor(GradientVisitor(boost::apply_visitor(
        outputParameterVisitor, rnnModule), recurrentError), rnnModule);

    attentionGradient += intermediateGradient;
  }
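  // IntermediateGradient() runs once per backward step, so attentionGradient
  // ends up holding the sum of the per-step gradients over all rho unrolled
  // steps; the shared module weights are then updated from that accumulated
  // gradient.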
  //! Locally-stored module output size.
  size_t outSize;

  //! Locally-stored start module.
  LayerTypes<> rnnModule;

  //! Locally-stored input module.
  LayerTypes<> actionModule;

  //! Number of steps to backpropagate through time (BPTT).
  size_t rho;

  //! Locally-stored number of forward steps.
  size_t forwardStep;

  //! Locally-stored number of backward steps.
  size_t backwardStep;

  //! If true, dropout and scaling are disabled (evaluation mode).
  bool deterministic;
  //! Locally-stored weight object.
  OutputDataType parameters;

  //! Locally-stored model modules.
  std::vector<LayerTypes<>> network;

  //! Locally-stored weight size visitor.
  WeightSizeVisitor weightSizeVisitor;

  //! Locally-stored delta visitor.
  DeltaVisitor deltaVisitor;

  //! Locally-stored output parameter visitor.
  OutputParameterVisitor outputParameterVisitor;

  //! Locally-stored feedback output parameters.
  std::vector<arma::mat> feedbackOutputParameter;

  //! List of all module parameters for the backward pass (BPTT).
  std::vector<arma::mat> moduleOutputParameter;

  //! Locally-stored delta object.
  OutputDataType delta;

  //! Locally-stored gradient object.
  OutputDataType gradient;

  //! Locally-stored output parameter object.
  OutputDataType outputParameter;

  //! Locally-stored recurrent error parameter.
  arma::mat recurrentError;

  //! Locally-stored action error parameter.
  arma::mat actionError;

  //! Locally-stored action delta.
  arma::mat actionDelta;

  //! Locally-stored initial action input.
  arma::mat initialInput;

  //! Locally-stored reset visitor.
  ResetVisitor resetVisitor;

  //! Locally-stored attention gradient.
  arma::mat attentionGradient;

  //! Locally-stored intermediate gradient for the attention module.
  arma::mat intermediateGradient;
}; // class RecurrentAttention

} // namespace ann
} // namespace mlpack
// Include implementation.
#include "recurrent_attention_impl.hpp"

#endif
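// A minimal usage sketch, not part of the original header. It assumes the
// boost::variant-based layer API (mlpack 3.x) with Sequential and Linear
// modules; the sizes and the rho value below are illustrative assumptions
// only.
//
//   using namespace mlpack::ann;
//
//   const size_t inSize = 5, hiddenSize = 10, outSize = 3, rho = 5;
//
//   // Recurrent module: maps the attention input back to the hidden state.
//   Sequential<>* rnn = new Sequential<>();
//   rnn->Add<Linear<> >(inSize, hiddenSize);
//   rnn->Add<Linear<> >(hiddenSize, outSize);
//
//   // Action module: produces the next attention input from the state.
//   Sequential<>* action = new Sequential<>();
//   action->Add<Linear<> >(outSize, inSize);
//
//   // Unroll the attention recurrence for rho steps.
//   RecurrentAttention<> attention(outSize, *rnn, *action, rho);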