mlpack 3.4.2
n_step_q_learning_worker.hpp
/**
 * @file n_step_q_learning_worker.hpp
 *
 * This file is the definition of the NStepQLearningWorker class, which
 * implements one worker (episode runner) for async n-step Q-Learning.
 */
#ifndef MLPACK_METHODS_RL_WORKER_N_STEP_Q_LEARNING_WORKER_HPP
#define MLPACK_METHODS_RL_WORKER_N_STEP_Q_LEARNING_WORKER_HPP

#include <ensmallen.hpp>

#include "../training_config.hpp"

namespace mlpack {
namespace rl {
/**
 * N-step Q-Learning worker.
 *
 * @tparam EnvironmentType The type of the reinforcement learning task.
 * @tparam NetworkType The type of the network model.
 * @tparam UpdaterType The type of the optimizer.
 * @tparam PolicyType The type of the behavior policy.
 */
template <
  typename EnvironmentType,
  typename NetworkType,
  typename UpdaterType,
  typename PolicyType
>
class NStepQLearningWorker
{
 public:
  using StateType = typename EnvironmentType::State;
  using ActionType = typename EnvironmentType::Action;
  using TransitionType = std::tuple<StateType, ActionType, double, StateType>;

  /**
   * Construct N-step Q-Learning worker with the given parameters and
   * environment.
   *
   * @param updater The optimizer.
   * @param environment The reinforcement learning task.
   * @param config Hyper-parameters.
   * @param deterministic Whether the worker should be deterministic.
   */
  NStepQLearningWorker(
      const UpdaterType& updater,
      const EnvironmentType& environment,
      const TrainingConfig& config,
      bool deterministic):
      updater(updater),
    #if ENS_VERSION_MAJOR >= 2
      updatePolicy(NULL),
    #endif
      environment(environment),
      config(config),
      deterministic(deterministic),
      pending(config.UpdateInterval())
  { Reset(); }

  /**
   * Copy another NStepQLearningWorker.
   *
   * @param other NStepQLearningWorker to copy.
   */
  NStepQLearningWorker(const NStepQLearningWorker& other) :
      updater(other.updater),
    #if ENS_VERSION_MAJOR >= 2
      updatePolicy(NULL),
    #endif
      environment(other.environment),
      config(other.config),
      deterministic(other.deterministic),
      steps(other.steps),
      episodeReturn(other.episodeReturn),
      pending(other.pending),
      pendingIndex(other.pendingIndex),
      network(other.network),
      state(other.state)
  {
    #if ENS_VERSION_MAJOR >= 2
    updatePolicy = new typename UpdaterType::template
        Policy<arma::mat, arma::mat>(updater,
                                     network.Parameters().n_rows,
                                     network.Parameters().n_cols);
    #endif

    Reset();
  }

  /**
   * Take ownership of another NStepQLearningWorker.
   *
   * @param other NStepQLearningWorker to take ownership of.
   */
  NStepQLearningWorker(NStepQLearningWorker&& other) :
      updater(std::move(other.updater)),
    #if ENS_VERSION_MAJOR >= 2
      updatePolicy(NULL),
    #endif
      environment(std::move(other.environment)),
      config(std::move(other.config)),
      deterministic(std::move(other.deterministic)),
      steps(std::move(other.steps)),
      episodeReturn(std::move(other.episodeReturn)),
      pending(std::move(other.pending)),
      pendingIndex(std::move(other.pendingIndex)),
      network(std::move(other.network)),
      state(std::move(other.state))
  {
    #if ENS_VERSION_MAJOR >= 2
    other.updatePolicy = NULL;

    updatePolicy = new typename UpdaterType::template
        Policy<arma::mat, arma::mat>(updater,
                                     network.Parameters().n_rows,
                                     network.Parameters().n_cols);
    #endif
  }

  /**
   * Copy another NStepQLearningWorker.
   *
   * @param other NStepQLearningWorker to copy.
   */
  NStepQLearningWorker& operator=(const NStepQLearningWorker& other)
  {
    if (&other == this)
      return *this;

    #if ENS_VERSION_MAJOR >= 2
    delete updatePolicy;
    #endif

    updater = other.updater;
    environment = other.environment;
    config = other.config;
    deterministic = other.deterministic;
    steps = other.steps;
    episodeReturn = other.episodeReturn;
    pending = other.pending;
    pendingIndex = other.pendingIndex;
    network = other.network;
    state = other.state;

    #if ENS_VERSION_MAJOR >= 2
    updatePolicy = new typename UpdaterType::template
        Policy<arma::mat, arma::mat>(updater,
                                     network.Parameters().n_rows,
                                     network.Parameters().n_cols);
    #endif

    Reset();

    return *this;
  }

  /**
   * Take ownership of another NStepQLearningWorker.
   *
   * @param other NStepQLearningWorker to take ownership of.
   */
  NStepQLearningWorker& operator=(NStepQLearningWorker&& other)
  {
    if (&other == this)
      return *this;

    #if ENS_VERSION_MAJOR >= 2
    delete updatePolicy;
    #endif

    updater = std::move(other.updater);
    environment = std::move(other.environment);
    config = std::move(other.config);
    deterministic = std::move(other.deterministic);
    steps = std::move(other.steps);
    episodeReturn = std::move(other.episodeReturn);
    pending = std::move(other.pending);
    pendingIndex = std::move(other.pendingIndex);
    network = std::move(other.network);
    state = std::move(other.state);

    #if ENS_VERSION_MAJOR >= 2
    updatePolicy = new typename UpdaterType::template
        Policy<arma::mat, arma::mat>(updater,
                                     network.Parameters().n_rows,
                                     network.Parameters().n_cols);

    other.updatePolicy = NULL;
    #endif

    return *this;
  }

  /**
   * Clean up the memory of the worker.
   */
  ~NStepQLearningWorker()
  {
    #if ENS_VERSION_MAJOR >= 2
    delete updatePolicy;
    #endif
  }

  /**
   * Initialize the worker.
   *
   * @param learningNetwork The shared network.
   */
  void Initialize(NetworkType& learningNetwork)
  {
    #if ENS_VERSION_MAJOR == 1
    updater.Initialize(learningNetwork.Parameters().n_rows,
                       learningNetwork.Parameters().n_cols);
    #else
    delete updatePolicy;

    updatePolicy = new typename UpdaterType::template
        Policy<arma::mat, arma::mat>(updater,
                                     learningNetwork.Parameters().n_rows,
                                     learningNetwork.Parameters().n_cols);
    #endif

    // Build local network.
    network = learningNetwork;
  }

  /**
   * The agent will execute one step.
   *
   * @param learningNetwork The shared learning network.
   * @param targetNetwork The shared target network.
   * @param totalSteps The shared counter for total steps.
   * @param policy The shared behavior policy.
   * @param totalReward The episode return; only valid when this step ends the
   *     episode.
   * @return Whether the current episode ends after this step.
   */
  bool Step(NetworkType& learningNetwork,
            NetworkType& targetNetwork,
            size_t& totalSteps,
            PolicyType& policy,
            double& totalReward)
  {
    // Interact with the environment.
    arma::colvec actionValue;
    network.Predict(state.Encode(), actionValue);
    ActionType action = policy.Sample(actionValue, deterministic);
    StateType nextState;
    double reward = environment.Sample(state, action, nextState);
    bool terminal = environment.IsTerminal(nextState);

    episodeReturn += reward;
    steps++;

    terminal = terminal || steps >= config.StepLimit();
    if (deterministic)
    {
      if (terminal)
      {
        totalReward = episodeReturn;
        Reset();
        // Sync with latest learning network.
        network = learningNetwork;
        return true;
      }
      state = nextState;
      return false;
    }

    #pragma omp atomic
    totalSteps++;

    pending[pendingIndex] = std::make_tuple(state, action, reward, nextState);
    pendingIndex++;

    if (terminal || pendingIndex >= config.UpdateInterval())
    {
      // Initialize the gradient storage.
      arma::mat totalGradients(learningNetwork.Parameters().n_rows,
          learningNetwork.Parameters().n_cols, arma::fill::zeros);

      // Bootstrap from the value of next state.
      arma::colvec actionValue;
      double target = 0;
      if (!terminal)
      {
        #pragma omp critical
        { targetNetwork.Predict(nextState.Encode(), actionValue); };
        target = actionValue.max();
      }

      // Update in reverse order: walking backwards turns the n-step target
      // r_t + d*r_{t+1} + ... + d^{n-1}*r_{t+n-1} + d^n * max_a Q(s_{t+n}, a)
      // into the simple recursion target = r_i + d * target, where d is the
      // discount factor.
      for (int i = pending.size() - 1; i >= 0; --i)
      {
        TransitionType& transition = pending[i];
        target = config.Discount() * target + std::get<2>(transition);

        // Compute the training target for current state.
        arma::mat input = std::get<0>(transition).Encode();
        network.Forward(input, actionValue);
        actionValue[std::get<1>(transition).action] = target;

        // Compute gradient.
        arma::mat gradients;
        network.Backward(input, actionValue, gradients);

        // Accumulate gradients.
        totalGradients += gradients;
      }

      // Clamp the accumulated gradients.
      totalGradients.transform(
          [&](double gradient)
          { return std::min(std::max(gradient, -config.GradientLimit()),
                            config.GradientLimit()); });

      // Perform async update of the global network.
      #if ENS_VERSION_MAJOR == 1
      updater.Update(learningNetwork.Parameters(), config.StepSize(),
          totalGradients);
      #else
      updatePolicy->Update(learningNetwork.Parameters(),
          config.StepSize(), totalGradients);
      #endif

      // Sync the local network with the global network.
      network = learningNetwork;

      pendingIndex = 0;
    }

    // Update global target network.
    if (totalSteps % config.TargetNetworkSyncInterval() == 0)
    {
      #pragma omp critical
      { targetNetwork = learningNetwork; }
    }

    policy.Anneal();

    if (terminal)
    {
      totalReward = episodeReturn;
      Reset();
      return true;
    }
    state = nextState;
    return false;
  }

 private:
  /**
   * Reset the worker for a new episode.
   */
  void Reset()
  {
    steps = 0;
    episodeReturn = 0;
    pendingIndex = 0;
    state = environment.InitialSample();
  }

  //! Locally-stored optimizer.
  UpdaterType updater;
  #if ENS_VERSION_MAJOR >= 2
  typename UpdaterType::template Policy<arma::mat, arma::mat>* updatePolicy;
  #endif

  //! Locally-stored task.
  EnvironmentType environment;

  //! Locally-stored hyper-parameters.
  TrainingConfig config;

  //! Whether this episode is deterministic or not.
  bool deterministic;

  //! Total steps of the current episode.
  size_t steps;

  //! Total reward of the current episode.
  double episodeReturn;

  //! Buffer of transitions pending the delayed n-step update.
  std::vector<TransitionType> pending;

  //! Current position in the pending buffer.
  size_t pendingIndex;

  //! Local copy of the learning network.
  NetworkType network;

  //! Current state of the agent.
  StateType state;
};

} // namespace rl
} // namespace mlpack

#endif
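
For context (this is not part of the header above), the sketch below shows one plausible way a single NStepQLearningWorker could be driven, single-threaded, against a CartPole task; in mlpack itself this stepping loop is normally run across OpenMP threads by the AsyncLearning class rather than by hand. The concrete choices here are assumptions, not documented usage: the FFN network with MeanSquaredError/GaussianInitialization and Linear/ReLU layers, the GreedyPolicy constructor arguments, the TrainingConfig setters, and ens::VanillaUpdate are all taken from other parts of mlpack/ensmallen and the hyper-parameter values are purely illustrative.

#include <iostream>

#include <mlpack/methods/ann/ffn.hpp>
#include <mlpack/methods/ann/init_rules/gaussian_init.hpp>
#include <mlpack/methods/ann/layer/layer.hpp>
#include <mlpack/methods/reinforcement_learning/environment/cart_pole.hpp>
#include <mlpack/methods/reinforcement_learning/policy/greedy_policy.hpp>
#include <mlpack/methods/reinforcement_learning/training_config.hpp>
#include <mlpack/methods/reinforcement_learning/worker/n_step_q_learning_worker.hpp>
#include <ensmallen.hpp>

using namespace mlpack::ann;
using namespace mlpack::rl;

int main()
{
  // Shared "global" networks: one that learns, one used as the frozen target.
  using NetworkType = FFN<MeanSquaredError<>, GaussianInitialization>;
  NetworkType learningNetwork;
  learningNetwork.Add<Linear<>>(4, 64);   // CartPole has a 4-dimensional state.
  learningNetwork.Add<ReLULayer<>>();
  learningNetwork.Add<Linear<>>(64, 2);   // ... and two discrete actions.
  learningNetwork.ResetParameters();
  NetworkType targetNetwork = learningNetwork;

  // Behavior policy and training hyper-parameters (illustrative values).
  GreedyPolicy<CartPole> policy(1.0 /* initial epsilon */,
                                1000 /* anneal interval */,
                                0.1 /* minimum epsilon */);
  TrainingConfig config;
  config.StepSize() = 0.01;
  config.Discount() = 0.99;
  config.StepLimit() = 200;
  config.UpdateInterval() = 5;              // n, the bootstrapping horizon.
  config.TargetNetworkSyncInterval() = 100;
  config.GradientLimit() = 40;

  // One worker, stepped directly in this thread (deterministic = false, so it
  // trains).
  NStepQLearningWorker<CartPole, NetworkType, ens::VanillaUpdate,
      GreedyPolicy<CartPole>> worker(ens::VanillaUpdate(), CartPole(), config,
      false);
  worker.Initialize(learningNetwork);

  size_t totalSteps = 0;
  double episodeReturn = 0;
  while (totalSteps < 10000)
  {
    // Step() returns true when an episode ends; only then is episodeReturn
    // valid.
    if (worker.Step(learningNetwork, targetNetwork, totalSteps, policy,
        episodeReturn))
      std::cout << "episode return: " << episodeReturn << std::endl;
  }
}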