mlpack 3.4.2
async_learning.hpp
Go to the documentation of this file.
1
14#ifndef MLPACK_METHODS_RL_ASYNC_LEARNING_HPP
15#define MLPACK_METHODS_RL_ASYNC_LEARNING_HPP
16
17#include <mlpack/prereqs.hpp>
21#include "training_config.hpp"
22
23namespace mlpack {
24namespace rl {
25
50template <
51 typename WorkerType,
52 typename EnvironmentType,
53 typename NetworkType,
54 typename UpdaterType,
55 typename PolicyType
56>
58{
59 public:
70 NetworkType network,
71 PolicyType policy,
72 UpdaterType updater = UpdaterType(),
73 EnvironmentType environment = EnvironmentType());
74
88 template <typename Measure>
89 void Train(Measure& measure);
90
92 TrainingConfig& Config() { return config; }
94 const TrainingConfig& Config() const { return config; }
95
97 NetworkType& Network() { return learningNetwork; }
99 const NetworkType& Network() const { return learningNetwork; }
100
102 PolicyType& Policy() { return policy; }
104 const PolicyType& Policy() const { return policy; }
105
107 UpdaterType& Updater() { return updater; }
109 const UpdaterType& Updater() const { return updater; }
110
112 EnvironmentType& Environment() { return environment; }
114 const EnvironmentType& Environment() const { return environment; }
115
116 private:
118 TrainingConfig config;
119
121 NetworkType learningNetwork;
122
124 PolicyType policy;
125
127 UpdaterType updater;
128
130 EnvironmentType environment;
131};
132
/**
 * Forward declaration of OneStepQLearningWorker.
 *
 * The template parameters mirror the EnvironmentType, NetworkType,
 * UpdaterType and PolicyType parameters of AsyncLearning.
 */
template <
  typename EnvironmentType,
  typename NetworkType,
  typename UpdaterType,
  typename PolicyType
>
class OneStepQLearningWorker;
148
/**
 * Forward declaration of OneStepSarsaWorker.
 *
 * The template parameters mirror the EnvironmentType, NetworkType,
 * UpdaterType and PolicyType parameters of AsyncLearning.
 */
template <
  typename EnvironmentType,
  typename NetworkType,
  typename UpdaterType,
  typename PolicyType
>
class OneStepSarsaWorker;
164
/**
 * Forward declaration of NStepQLearningWorker.
 *
 * The template parameters mirror the EnvironmentType, NetworkType,
 * UpdaterType and PolicyType parameters of AsyncLearning.
 */
template <
  typename EnvironmentType,
  typename NetworkType,
  typename UpdaterType,
  typename PolicyType
>
class NStepQLearningWorker;
180
// Convenience alias template over the AsyncLearning wrapper.
// NOTE(review): this listing omits the alias-declaration line (listing line
// 195, which sits between the template header and the continuation below),
// so the alias name and the worker type it wraps are not visible here. The
// leading numbers on each line are listing line numbers fused in by text
// extraction, not code. The continuation suggests the form
// Outer<SomeWorker<..., NetworkType, UpdaterType, PolicyType>,
//       EnvironmentType, NetworkType, UpdaterType, PolicyType>,
// presumably AsyncLearning; by the order of the forward declarations above,
// the worker is presumably OneStepQLearningWorker -- TODO confirm against
// the real header.
189template <
190 typename EnvironmentType,
191 typename NetworkType,
192 typename UpdaterType,
193 typename PolicyType
194>
196 NetworkType, UpdaterType, PolicyType>, EnvironmentType, NetworkType,
197 UpdaterType, PolicyType>;
198
// Convenience alias template over the AsyncLearning wrapper.
// NOTE(review): this listing omits the alias-declaration line (listing line
// 213, which sits between the template header and the continuation below),
// so the alias name and the worker type it wraps are not visible here. The
// leading numbers on each line are listing line numbers fused in by text
// extraction, not code. The continuation suggests the form
// Outer<SomeWorker<..., NetworkType, UpdaterType, PolicyType>,
//       EnvironmentType, NetworkType, UpdaterType, PolicyType>,
// presumably AsyncLearning; by the order of the forward declarations above,
// the worker is presumably OneStepSarsaWorker -- TODO confirm against the
// real header.
207template <
208 typename EnvironmentType,
209 typename NetworkType,
210 typename UpdaterType,
211 typename PolicyType
212>
214 NetworkType, UpdaterType, PolicyType>, EnvironmentType, NetworkType,
215 UpdaterType, PolicyType>;
216
// Convenience alias template over the AsyncLearning wrapper.
// NOTE(review): this listing omits the alias-declaration line (listing line
// 231, which sits between the template header and the continuation below),
// so the alias name and the worker type it wraps are not visible here. The
// leading numbers on each line are listing line numbers fused in by text
// extraction, not code. The continuation suggests the form
// Outer<SomeWorker<..., NetworkType, UpdaterType, PolicyType>,
//       EnvironmentType, NetworkType, UpdaterType, PolicyType>,
// presumably AsyncLearning; by the order of the forward declarations above,
// the worker is presumably NStepQLearningWorker -- TODO confirm against the
// real header.
225template <
226 typename EnvironmentType,
227 typename NetworkType,
228 typename UpdaterType,
229 typename PolicyType
230>
232 NetworkType, UpdaterType, PolicyType>, EnvironmentType, NetworkType,
233 UpdaterType, PolicyType>;
234
235} // namespace rl
236} // namespace mlpack
237
238// Include implementation
239#include "async_learning_impl.hpp"
240
241#endif
Wrapper of various asynchronous learning algorithms, e.g. one-step Q-learning, one-step Sarsa and n-step Q-learning.
const TrainingConfig & Config() const
Get training config (read-only).
AsyncLearning(TrainingConfig config, NetworkType network, PolicyType policy, UpdaterType updater=UpdaterType(), EnvironmentType environment=EnvironmentType())
Construct an instance of the given async learning algorithm.
NetworkType & Network()
Modify learning network (returns a mutable reference).
TrainingConfig & Config()
Modify training config (returns a mutable reference).
const PolicyType & Policy() const
Get behavior policy (read-only).
PolicyType & Policy()
Modify behavior policy (returns a mutable reference).
const UpdaterType & Updater() const
Get optimizer (read-only).
void Train(Measure &measure)
Start async training.
UpdaterType & Updater()
Modify optimizer (returns a mutable reference).
EnvironmentType & Environment()
Modify the environment (returns a mutable reference).
const EnvironmentType & Environment() const
Get the environment (read-only).
const NetworkType & Network() const
Get learning network (read-only).
Forward declaration of NStepQLearningWorker.
Forward declaration of OneStepQLearningWorker.
Forward declaration of OneStepSarsaWorker.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.