mlpack 3.4.2
greedy_policy.hpp
Go to the documentation of this file.
1
13#ifndef MLPACK_METHODS_RL_POLICY_GREEDY_POLICY_HPP
14#define MLPACK_METHODS_RL_POLICY_GREEDY_POLICY_HPP
15
16#include <mlpack/prereqs.hpp>
17
18namespace mlpack {
19namespace rl {
20
30template <typename EnvironmentType>
32{
33 public:
35 using ActionType = typename EnvironmentType::Action;
36
48 GreedyPolicy(const double initialEpsilon,
49 const size_t annealInterval,
50 const double minEpsilon,
51 const double decayRate = 1.0) :
52 epsilon(initialEpsilon),
53 minEpsilon(minEpsilon),
54 delta(((initialEpsilon - minEpsilon) * decayRate) / annealInterval)
55 { /* Nothing to do here. */ }
56
65 ActionType Sample(const arma::colvec& actionValue,
66 bool deterministic = false,
67 const bool isNoisy = false)
68 {
69 double exploration = math::Random();
70 ActionType action;
71
72 // Select the action randomly.
73 if (!deterministic && exploration < epsilon && isNoisy == false)
74 {
75 action.action = static_cast<decltype(action.action)>
76 (math::RandInt(ActionType::size));
77 }
78 // Select the action greedily.
79 else
80 {
81 action.action = static_cast<decltype(action.action)>(
82 arma::as_scalar(arma::find(actionValue == actionValue.max(), 1)));
83 }
84 return action;
85 }
86
90 void Anneal()
91 {
92 epsilon -= delta;
93 epsilon = std::max(minEpsilon, epsilon);
94 }
95
99 const double& Epsilon() const { return epsilon; }
100
101 private:
103 double epsilon;
104
106 double minEpsilon;
107
109 double delta;
110};
111
112} // namespace rl
113} // namespace mlpack
114
115#endif
Implementation for epsilon greedy policy.
void Anneal()
Exploration probability will anneal at each step.
GreedyPolicy(const double initialEpsilon, const size_t annealInterval, const double minEpsilon, const double decayRate=1.0)
Constructor for epsilon greedy policy class.
const double & Epsilon() const
typename EnvironmentType::Action ActionType
Convenient typedef for action.
ActionType Sample(const arma::colvec &actionValue, bool deterministic=false, const bool isNoisy=false)
Sample an action based on given action values.
double Random()
Generates a uniform random number between 0 and 1.
Definition: random.hpp:83
int RandInt(const int hiExclusive)
Generates a uniform random integer.
Definition: random.hpp:110
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.