mlpack 3.4.2
gradient_visitor.hpp
Go to the documentation of this file.
1
13#ifndef MLPACK_METHODS_ANN_VISITOR_GRADIENT_VISITOR_HPP
14#define MLPACK_METHODS_ANN_VISITOR_GRADIENT_VISITOR_HPP
15
18
19#include <boost/variant.hpp>
20
21namespace mlpack {
22namespace ann {
23
//! GradientVisitor executes the Gradient() method of the given module using
//! the input and delta parameter. It derives from boost::static_visitor<void>,
//! so visiting a layer yields no return value.
28class GradientVisitor : public boost::static_visitor<void>
29{
30 public:
//! Executes the Gradient() method of the given module using the input and
//! delta parameter.
//! NOTE(review): the class keeps `input` and `delta` as reference members,
//! so presumably both matrices must outlive the visitor - confirm in
//! gradient_visitor_impl.hpp.
33 GradientVisitor(const arma::mat& input, const arma::mat& delta);
34
//! Executes the Gradient() method for the layer with the specified index.
36 GradientVisitor(const arma::mat& input,
37 const arma::mat& delta,
38 const size_t index);
39
//! Executes the Gradient() method. Generic overload taking a pointer to any
//! layer type.
41 template<typename LayerType>
42 void operator()(LayerType* layer) const;
43
//! Overload for the MoreTypes variant of layer pointers.
//! NOTE(review): presumably re-dispatches this visitor onto the layer held
//! by the variant - the definition lives in gradient_visitor_impl.hpp.
44 void operator()(MoreTypes layer) const;
45
46 private:
//! Reference to the input matrix (not owned, not copied).
48 const arma::mat& input;
49
//! Reference to the delta matrix (not owned, not copied).
51 const arma::mat& delta;
52
//! Layer index; only the three-argument constructor accepts one.
54 size_t index;
55
//! NOTE(review): presumably records whether `index` was supplied (i.e. which
//! constructor was used) - confirm against the implementation.
57 bool hasIndex;
58
//! Overload selected (via std::enable_if) when HasGradientCheck holds for T
//! and HasRunCheck does NOT hold for T.
//! NOTE(review): both traits are declared outside this file; this is
//! presumably the case where the layer's Gradient() is called directly.
61 template<typename T>
62 typename std::enable_if<
63 HasGradientCheck<T, arma::mat&(T::*)()>::value &&
64 !HasRunCheck<T, bool&(T::*)(void)>::value, void>::type
65 LayerGradients(T* layer, arma::mat& input) const;
66
//! Overload selected (via std::enable_if) when BOTH HasGradientCheck and
//! HasRunCheck hold for T.
//! NOTE(review): how this differs from the overload above is defined in
//! gradient_visitor_impl.hpp, which is not visible here.
69 template<typename T>
70 typename std::enable_if<
71 HasGradientCheck<T, arma::mat&(T::*)()>::value &&
72 HasRunCheck<T, bool&(T::*)(void)>::value, void>::type
73 LayerGradients(T* layer, arma::mat& input) const;
74
//! Fallback overload selected (via std::enable_if) when HasGradientCheck does
//! NOT hold for T with input type P.
//! NOTE(review): presumably a no-op for layers without a gradient - confirm.
77 template<typename T, typename P>
78 typename std::enable_if<
79 !HasGradientCheck<T, P&(T::*)()>::value, void>::type
80 LayerGradients(T* layer, P& input) const;
81};
82
83} // namespace ann
84} // namespace mlpack
85
86// Include implementation.
87#include "gradient_visitor_impl.hpp"
88
89#endif
GradientVisitor executes the Gradient() method of the given module using the input and delta parame...
void operator()(LayerType *layer) const
Executes the Gradient() method.
GradientVisitor(const arma::mat &input, const arma::mat &delta, const size_t index)
Executes the Gradient() method for the layer with the specified index.
void operator()(MoreTypes layer) const
GradientVisitor(const arma::mat &input, const arma::mat &delta)
Executes the Gradient() method of the given module using the input and delta parameter.
boost::variant< Linear3D< arma::mat, arma::mat, NoRegularizer > *, Glimpse< arma::mat, arma::mat > *, Highway< arma::mat, arma::mat > *, MultiheadAttention< arma::mat, arma::mat, NoRegularizer > *, Recurrent< arma::mat, arma::mat > *, RecurrentAttention< arma::mat, arma::mat > *, ReinforceNormal< arma::mat, arma::mat > *, Reparametrization< arma::mat, arma::mat > *, Select< arma::mat, arma::mat > *, Sequential< arma::mat, arma::mat, false > *, Sequential< arma::mat, arma::mat, true > *, Subview< arma::mat, arma::mat > *, VRClassReward< arma::mat, arma::mat > *, VirtualBatchNorm< arma::mat, arma::mat > *, RBF< arma::mat, arma::mat, GaussianFunction > *, BaseLayer< GaussianFunction, arma::mat, arma::mat > *, PositionalEncoding< arma::mat, arma::mat > * > MoreTypes
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1