mlpack 3.4.2
aggregated_policy.hpp
Go to the documentation of this file.
1
14#ifndef MLPACK_METHODS_RL_POLICY_AGGREGATED_POLICY_HPP
15#define MLPACK_METHODS_RL_POLICY_AGGREGATED_POLICY_HPP
16
17#include <mlpack/prereqs.hpp>
19
20namespace mlpack {
21namespace rl {
22
26template <typename PolicyType>
28{
29 public:
31 using ActionType = typename PolicyType::ActionType;
32
39 AggregatedPolicy(std::vector<PolicyType> policies,
40 const arma::colvec& distribution) :
41 policies(std::move(policies)),
42 sampler({distribution})
43 { /* Nothing to do here. */ };
44
52 ActionType Sample(const arma::colvec& actionValue, bool deterministic = false)
53 {
54 if (deterministic)
55 return policies.front().Sample(actionValue, true);
56 size_t selected = arma::as_scalar(sampler.Random());
57 return policies[selected].Sample(actionValue, false);
58 }
59
63 void Anneal()
64 {
65 for (PolicyType& policy : policies)
66 policy.Anneal();
67 }
68
69 private:
71 std::vector<PolicyType> policies;
72
75};
76
77} // namespace rl
78} // namespace mlpack
79
80#endif
A discrete distribution where the only observations are discrete observations.
arma::vec Random() const
Return a randomly generated observation (one-dimensional vector; one observation) according to the probability distribution defined by this object.
void Anneal()
Exploration probability will anneal at each step.
ActionType Sample(const arma::colvec &actionValue, bool deterministic=false)
Sample an action based on given action values.
AggregatedPolicy(std::vector< PolicyType > policies, const arma::colvec &distribution)
typename PolicyType::ActionType ActionType
Convenient typedef for action.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
Definition: prereqs.hpp:67
The core includes that mlpack expects; standard C++ includes and Armadillo.