mlpack 3.4.2
gradient_zero_visitor.hpp
Go to the documentation of this file.
13#ifndef MLPACK_METHODS_ANN_VISITOR_GRADIENT_ZERO_VISITOR_HPP
14#define MLPACK_METHODS_ANN_VISITOR_GRADIENT_ZERO_VISITOR_HPP
15
18
19#include <boost/variant.hpp>
20
21namespace mlpack {
22namespace ann {
23
24/*
25 * GradientZeroVisitor set the gradient to zero for the given module.
26 */
27class GradientZeroVisitor : public boost::static_visitor<void>
28{
29 public:
32
34 template<typename LayerType>
35 void operator()(LayerType* layer) const;
36
37 void operator()(MoreTypes layer) const;
38
39 private:
41 template<typename T>
42 typename std::enable_if<
43 HasGradientCheck<T, arma::mat&(T::*)()>::value, void>::type
44 LayerGradients(T* layer, arma::mat& input) const;
45
48 template<typename T, typename P>
49 typename std::enable_if<
50 !HasGradientCheck<T, P&(T::*)()>::value, void>::type
51 LayerGradients(T* layer, P& input) const;
52};
53
54} // namespace ann
55} // namespace mlpack
56
57// Include implementation.
58#include "gradient_zero_visitor_impl.hpp"
59
60#endif
void operator()(LayerType *layer) const
Set the gradient to zero.
GradientZeroVisitor()
Constructor. The visitor sets the gradient to zero for the given module.
void operator()(MoreTypes layer) const
boost::variant< Linear3D< arma::mat, arma::mat, NoRegularizer > *, Glimpse< arma::mat, arma::mat > *, Highway< arma::mat, arma::mat > *, MultiheadAttention< arma::mat, arma::mat, NoRegularizer > *, Recurrent< arma::mat, arma::mat > *, RecurrentAttention< arma::mat, arma::mat > *, ReinforceNormal< arma::mat, arma::mat > *, Reparametrization< arma::mat, arma::mat > *, Select< arma::mat, arma::mat > *, Sequential< arma::mat, arma::mat, false > *, Sequential< arma::mat, arma::mat, true > *, Subview< arma::mat, arma::mat > *, VRClassReward< arma::mat, arma::mat > *, VirtualBatchNorm< arma::mat, arma::mat > *, RBF< arma::mat, arma::mat, GaussianFunction > *, BaseLayer< GaussianFunction, arma::mat, arma::mat > *, PositionalEncoding< arma::mat, arma::mat > * > MoreTypes
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1