mlpack 3.4.2
q_learning.hpp
Go to the documentation of this file.
1
13#ifndef MLPACK_METHODS_RL_Q_LEARNING_HPP
14#define MLPACK_METHODS_RL_Q_LEARNING_HPP
15
16#include <mlpack/prereqs.hpp>
17
20#include "training_config.hpp"
21
22namespace mlpack {
23namespace rl {
24
51template <
52 typename EnvironmentType,
53 typename NetworkType,
54 typename UpdaterType,
55 typename PolicyType,
56 typename ReplayType = RandomReplay<EnvironmentType>
57>
59{
60 public:
62 using StateType = typename EnvironmentType::State;
63
65 using ActionType = typename EnvironmentType::Action;
66
81 NetworkType& network,
82 PolicyType& policy,
83 ReplayType& replayMethod,
84 UpdaterType updater = UpdaterType(),
85 EnvironmentType environment = EnvironmentType());
86
91
95 void TrainAgent();
96
101
106
111 double Episode();
112
114 size_t& TotalSteps() { return totalSteps; }
116 const size_t& TotalSteps() const { return totalSteps; }
117
119 StateType& State() { return state; }
121 const StateType& State() const { return state; }
122
124 const ActionType& Action() const { return action; }
125
127 EnvironmentType& Environment() { return environment; }
129 const EnvironmentType& Environment() const { return environment; }
130
132 bool& Deterministic() { return deterministic; }
134 const bool& Deterministic() const { return deterministic; }
135
137 const NetworkType& Network() const { return learningNetwork; }
139 NetworkType& Network() { return learningNetwork; }
140
141 private:
147 arma::Col<size_t> BestAction(const arma::mat& actionValues);
148
150 TrainingConfig& config;
151
153 NetworkType& learningNetwork;
154
156 NetworkType targetNetwork;
157
159 PolicyType& policy;
160
162 ReplayType& replayMethod;
163
165 UpdaterType updater;
166 #if ENS_VERSION_MAJOR >= 2
167 typename UpdaterType::template Policy<arma::mat, arma::mat>* updatePolicy;
168 #endif
169
171 EnvironmentType environment;
172
174 size_t totalSteps;
175
177 StateType state;
178
180 ActionType action;
181
183 bool deterministic;
184};
185
186} // namespace rl
187} // namespace mlpack
188
189// Include implementation
190#include "q_learning_impl.hpp"
191#endif
Implementation of various Q-Learning algorithms, such as DQN, double DQN.
Definition: q_learning.hpp:59
StateType & State()
Modify the state of the agent.
Definition: q_learning.hpp:119
double Episode()
Execute an episode.
QLearning(TrainingConfig &config, NetworkType &network, PolicyType &policy, ReplayType &replayMethod, UpdaterType updater=UpdaterType(), EnvironmentType environment=EnvironmentType())
Create the QLearning object with given settings.
~QLearning()
Clean memory.
void TrainCategoricalAgent()
Trains the DQN agent of categorical type.
NetworkType & Network()
Modify the learning network.
Definition: q_learning.hpp:139
size_t & TotalSteps()
Modify the total number of steps taken since the beginning.
Definition: q_learning.hpp:114
const bool & Deterministic() const
Get the indicator of training mode / test mode.
Definition: q_learning.hpp:134
const size_t & TotalSteps() const
Get the total number of steps taken since the beginning.
Definition: q_learning.hpp:116
bool & Deterministic()
Modify the training mode / test mode indicator.
Definition: q_learning.hpp:132
typename EnvironmentType::Action ActionType
Convenient typedef for action.
Definition: q_learning.hpp:65
const StateType & State() const
Get the state of the agent.
Definition: q_learning.hpp:121
const ActionType & Action() const
Get the action of the agent.
Definition: q_learning.hpp:124
void SelectAction()
Select an action, given an agent.
typename EnvironmentType::State StateType
Convenient typedef for state.
Definition: q_learning.hpp:62
void TrainAgent()
Trains the DQN agent (non-categorical).
EnvironmentType & Environment()
Modify the environment in which the agent is.
Definition: q_learning.hpp:127
const EnvironmentType & Environment() const
Get the environment in which the agent is.
Definition: q_learning.hpp:129
const NetworkType & Network() const
Get the learning network.
Definition: q_learning.hpp:137
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.