mlpack 3.4.2
gradient_update_visitor.hpp
Go to the documentation of this file.
1
13#ifndef MLPACK_METHODS_ANN_VISITOR_GRADIENT_UPDATE_VISITOR_HPP
14#define MLPACK_METHODS_ANN_VISITOR_GRADIENT_UPDATE_VISITOR_HPP
15
17
18#include <boost/variant.hpp>
19
20namespace mlpack {
21namespace ann {
22
26class GradientUpdateVisitor : public boost::static_visitor<size_t>
27{
28 public:
30 GradientUpdateVisitor(arma::mat& gradient, size_t offset = 0);
31
33 template<typename LayerType>
34 size_t operator()(LayerType* layer) const;
35
36 size_t operator()(MoreTypes layer) const;
37
38 private:
40 arma::mat& gradient;
41
43 size_t offset;
44
46 template<typename T>
47 typename std::enable_if<
48 HasGradientCheck<T, arma::mat&(T::*)()>::value &&
49 !HasModelCheck<T>::value, size_t>::type
50 LayerGradients(T* layer, arma::mat& input) const;
51
53 template<typename T>
54 typename std::enable_if<
55 !HasGradientCheck<T, arma::mat&(T::*)()>::value &&
56 HasModelCheck<T>::value, size_t>::type
57 LayerGradients(T* layer, arma::mat& input) const;
58
61 template<typename T>
62 typename std::enable_if<
63 HasGradientCheck<T, arma::mat&(T::*)()>::value &&
64 HasModelCheck<T>::value, size_t>::type
65 LayerGradients(T* layer, arma::mat& input) const;
66
69 template<typename T, typename P>
70 typename std::enable_if<
71 !HasGradientCheck<T, P&(T::*)()>::value &&
72 !HasModelCheck<T>::value, size_t>::type
73 LayerGradients(T* layer, P& input) const;
74};
75
76} // namespace ann
77} // namespace mlpack
78
79// Include implementation.
80#include "gradient_update_visitor_impl.hpp"
81
82#endif
GradientUpdateVisitor updates the gradient parameter given the gradient set.
size_t operator()(MoreTypes layer) const
GradientUpdateVisitor(arma::mat &gradient, size_t offset=0)
Update the gradient parameter given the gradient set.
size_t operator()(LayerType *layer) const
Update the gradient parameter.
boost::variant< Linear3D< arma::mat, arma::mat, NoRegularizer > *, Glimpse< arma::mat, arma::mat > *, Highway< arma::mat, arma::mat > *, MultiheadAttention< arma::mat, arma::mat, NoRegularizer > *, Recurrent< arma::mat, arma::mat > *, RecurrentAttention< arma::mat, arma::mat > *, ReinforceNormal< arma::mat, arma::mat > *, Reparametrization< arma::mat, arma::mat > *, Select< arma::mat, arma::mat > *, Sequential< arma::mat, arma::mat, false > *, Sequential< arma::mat, arma::mat, true > *, Subview< arma::mat, arma::mat > *, VRClassReward< arma::mat, arma::mat > *, VirtualBatchNorm< arma::mat, arma::mat > *, RBF< arma::mat, arma::mat, GaussianFunction > *, BaseLayer< GaussianFunction, arma::mat, arma::mat > *, PositionalEncoding< arma::mat, arma::mat > * > MoreTypes
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1