mlpack 3.4.2
backward_visitor.hpp
Go to the documentation of this file.
1
13#ifndef MLPACK_METHODS_ANN_VISITOR_BACKWARD_VISITOR_HPP
14#define MLPACK_METHODS_ANN_VISITOR_BACKWARD_VISITOR_HPP
15
18
19#include <boost/variant.hpp>
20
21namespace mlpack {
22namespace ann {
23
// BackwardVisitor is a boost::static_visitor<void> that executes the
// Backward() function of a layer given the input, error and delta parameters.
// Only declarations appear here; the definitions are pulled in from
// backward_visitor_impl.hpp at the end of this header.
28class BackwardVisitor : public boost::static_visitor<void>
29{
30 public:
// Construct a visitor that executes the Backward() function given the input,
// error and delta parameters. All three matrices are taken by reference
// (input and error as const, delta as non-const).
33 BackwardVisitor(const arma::mat& input,
34 const arma::mat& error,
35 arma::mat& delta);
36
// Construct a visitor that executes the Backward() function for the layer
// with the specified index.
38 BackwardVisitor(const arma::mat& input,
39 const arma::mat& error,
40 arma::mat& delta,
41 const size_t index);
42
// Execute the Backward() function for the given layer pointer. This is a
// template, so it accepts a pointer to any LayerType.
44 template<typename LayerType>
45 void operator()(LayerType* layer) const;
46
// Overload for MoreTypes, which is itself a boost::variant of layer pointers.
// NOTE(review): presumably re-dispatches this visitor onto the nested
// variant -- confirm in backward_visitor_impl.hpp.
47 void operator()(MoreTypes layer) const;
48
49 private:
// Const reference to the input matrix. The visitor does not hold a copy.
51 const arma::mat& input;
52
// Const reference to the error matrix.
54 const arma::mat& error;
55
// Non-const reference to the delta matrix.
// NOTE(review): presumably the output written by Backward() -- confirm.
57 arma::mat& delta;
58
// Layer index.
// NOTE(review): presumably set only by the index-taking constructor --
// confirm in backward_visitor_impl.hpp.
60 size_t index;
61
// NOTE(review): presumably records whether the index-taking constructor was
// used, i.e. whether `index` is meaningful -- confirm.
63 bool hasIndex;
64
// Overload of LayerBackward() that participates in overload resolution only
// when HasRunCheck<T, bool&(T::*)(void)>::value is false. HasRunCheck is
// declared elsewhere. Note that the `input` parameter is a non-const
// reference and shadows the const `input` member inside the definition.
67 template<typename T>
68 typename std::enable_if<
69 !HasRunCheck<T, bool&(T::*)(void)>::value, void>::type
70 LayerBackward(T* layer, arma::mat& input) const;
71
// Counterpart overload, enabled only when
// HasRunCheck<T, bool&(T::*)(void)>::value is true. The two conditions are
// mutually exclusive, so exactly one overload is viable for any given T.
73 template<typename T>
74 typename std::enable_if<
75 HasRunCheck<T, bool&(T::*)(void)>::value, void>::type
76 LayerBackward(T* layer, arma::mat& input) const;
77};
78
79} // namespace ann
80} // namespace mlpack
81
82// Include implementation.
83#include "backward_visitor_impl.hpp"
84
85#endif
BackwardVisitor executes the Backward() function given the input, error and delta parameters.
void operator()(LayerType *layer) const
Execute the Backward() function.
BackwardVisitor(const arma::mat &input, const arma::mat &error, arma::mat &delta)
Execute the Backward() function given the input, error and delta parameters.
BackwardVisitor(const arma::mat &input, const arma::mat &error, arma::mat &delta, const size_t index)
Execute the Backward() function for the layer with the specified index.
void operator()(MoreTypes layer) const
boost::variant< Linear3D< arma::mat, arma::mat, NoRegularizer > *, Glimpse< arma::mat, arma::mat > *, Highway< arma::mat, arma::mat > *, MultiheadAttention< arma::mat, arma::mat, NoRegularizer > *, Recurrent< arma::mat, arma::mat > *, RecurrentAttention< arma::mat, arma::mat > *, ReinforceNormal< arma::mat, arma::mat > *, Reparametrization< arma::mat, arma::mat > *, Select< arma::mat, arma::mat > *, Sequential< arma::mat, arma::mat, false > *, Sequential< arma::mat, arma::mat, true > *, Subview< arma::mat, arma::mat > *, VRClassReward< arma::mat, arma::mat > *, VirtualBatchNorm< arma::mat, arma::mat > *, RBF< arma::mat, arma::mat, GaussianFunction > *, BaseLayer< GaussianFunction, arma::mat, arma::mat > *, PositionalEncoding< arma::mat, arma::mat > * > MoreTypes
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1