mlpack 3.4.2
reward_set_visitor.hpp
Go to the documentation of this file.
1
13#ifndef MLPACK_METHODS_ANN_VISITOR_REWARD_SET_VISITOR_HPP
14#define MLPACK_METHODS_ANN_VISITOR_REWARD_SET_VISITOR_HPP
15
17
18#include <boost/variant.hpp>
19
20namespace mlpack {
21namespace ann {
22
26class RewardSetVisitor : public boost::static_visitor<void>
27{
28 public:
30 RewardSetVisitor(const double reward);
31
33 template<typename LayerType>
34 void operator()(LayerType* layer) const;
35
36 void operator()(MoreTypes layer) const;
37
38 private:
40 const double reward;
41
44 template<typename T>
45 typename std::enable_if<
46 HasRewardCheck<T, double&(T::*)()>::value &&
47 HasModelCheck<T>::value, void>::type
48 LayerReward(T* layer) const;
49
52 template<typename T>
53 typename std::enable_if<
54 !HasRewardCheck<T, double&(T::*)()>::value &&
55 HasModelCheck<T>::value, void>::type
56 LayerReward(T* layer) const;
57
60 template<typename T>
61 typename std::enable_if<
62 HasRewardCheck<T, double&(T::*)()>::value &&
63 !HasModelCheck<T>::value, void>::type
64 LayerReward(T* layer) const;
65
68 template<typename T>
69 typename std::enable_if<
70 !HasRewardCheck<T, double&(T::*)()>::value &&
71 !HasModelCheck<T>::value, void>::type
72 LayerReward(T* layer) const;
73};
74
75} // namespace ann
76} // namespace mlpack
77
78// Include implementation.
79#include "reward_set_visitor_impl.hpp"
80
81#endif
RewardSetVisitor sets the reward parameter given the reward value.
void operator()(LayerType *layer) const
Set the reward parameter.
void operator()(MoreTypes layer) const
RewardSetVisitor(const double reward)
Set the reward parameter given the reward value.
boost::variant< Linear3D< arma::mat, arma::mat, NoRegularizer > *, Glimpse< arma::mat, arma::mat > *, Highway< arma::mat, arma::mat > *, MultiheadAttention< arma::mat, arma::mat, NoRegularizer > *, Recurrent< arma::mat, arma::mat > *, RecurrentAttention< arma::mat, arma::mat > *, ReinforceNormal< arma::mat, arma::mat > *, Reparametrization< arma::mat, arma::mat > *, Select< arma::mat, arma::mat > *, Sequential< arma::mat, arma::mat, false > *, Sequential< arma::mat, arma::mat, true > *, Subview< arma::mat, arma::mat > *, VRClassReward< arma::mat, arma::mat > *, VirtualBatchNorm< arma::mat, arma::mat > *, RBF< arma::mat, arma::mat, GaussianFunction > *, BaseLayer< GaussianFunction, arma::mat, arma::mat > *, PositionalEncoding< arma::mat, arma::mat > * > MoreTypes
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1