mlpack 3.4.2
one_step_sarsa_worker.hpp
/**
 * @file one_step_sarsa_worker.hpp
 * @author Shangtong Zhang
 *
 * This file is the definition of OneStepSarsaWorker class,
 * which implements an episode for async one step Sarsa algorithm.
 *
 * mlpack is free software; you may redistribute it and/or modify it under the
 * terms of the 3-clause BSD license.  You should have received a copy of the
 * 3-clause BSD license along with mlpack.  If not, see
 * http://www.opensource.org/licenses/BSD-3-Clause for more information.
 */
#ifndef MLPACK_METHODS_RL_WORKER_ONE_STEP_SARSA_WORKER_HPP
#define MLPACK_METHODS_RL_WORKER_ONE_STEP_SARSA_WORKER_HPP

#include <ensmallen.hpp>

namespace mlpack {
namespace rl {

/**
 * One step sarsa worker.
 *
 * @tparam EnvironmentType The type of the reinforcement learning task.
 * @tparam NetworkType The type of the network model.
 * @tparam UpdaterType The type of the optimizer.
 * @tparam PolicyType The type of the behavior policy.
 */
template <
  typename EnvironmentType,
  typename NetworkType,
  typename UpdaterType,
  typename PolicyType
>
class OneStepSarsaWorker
{
 public:
  using StateType = typename EnvironmentType::State;
  using ActionType = typename EnvironmentType::Action;
  using TransitionType = std::tuple<StateType, ActionType, double, StateType,
      ActionType>;

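  // A transition bundles one SARSA step as the quintuple
  // (state, action, reward, nextState, nextAction).
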
  /**
   * Construct one step sarsa worker with the given parameters and
   * environment.
   *
   * @param updater The optimizer.
   * @param environment The reinforcement learning task.
   * @param config Hyper-parameters.
   * @param deterministic Whether it should be deterministic.
   */
  OneStepSarsaWorker(
      const UpdaterType& updater,
      const EnvironmentType& environment,
      const TrainingConfig& config,
      bool deterministic):
    updater(updater),
    #if ENS_VERSION_MAJOR >= 2
    updatePolicy(NULL),
    #endif
    environment(environment),
    config(config),
    deterministic(deterministic),
    pending(config.UpdateInterval())
  { Reset(); }

  /**
   * Copy another OneStepSarsaWorker.
   *
   * @param other OneStepSarsaWorker to copy.
   */
  OneStepSarsaWorker(const OneStepSarsaWorker& other) :
    updater(other.updater),
    #if ENS_VERSION_MAJOR >= 2
    updatePolicy(NULL),
    #endif
    environment(other.environment),
    config(other.config),
    deterministic(other.deterministic),
    steps(other.steps),
    episodeReturn(other.episodeReturn),
    pending(other.pending),
    pendingIndex(other.pendingIndex),
    network(other.network),
    state(other.state),
    action(other.action)
  {
    Reset();

    #if ENS_VERSION_MAJOR >= 2
    updatePolicy = new typename UpdaterType::template
        Policy<arma::mat, arma::mat>(updater,
        network.Parameters().n_rows,
        network.Parameters().n_cols);
    #endif
  }

  /**
   * Take ownership of another OneStepSarsaWorker.
   *
   * @param other OneStepSarsaWorker to take ownership of.
   */
  OneStepSarsaWorker(OneStepSarsaWorker&& other) :
    updater(std::move(other.updater)),
    #if ENS_VERSION_MAJOR >= 2
    updatePolicy(NULL),
    #endif
    environment(std::move(other.environment)),
    config(std::move(other.config)),
    deterministic(std::move(other.deterministic)),
    steps(std::move(other.steps)),
    episodeReturn(std::move(other.episodeReturn)),
    pending(std::move(other.pending)),
    pendingIndex(std::move(other.pendingIndex)),
    network(std::move(other.network)),
    state(std::move(other.state)),
    action(std::move(other.action))
  {
    #if ENS_VERSION_MAJOR >= 2
    other.updatePolicy = NULL;

    updatePolicy = new typename UpdaterType::template
        Policy<arma::mat, arma::mat>(updater,
        network.Parameters().n_rows,
        network.Parameters().n_cols);
    #endif
  }

  /**
   * Copy another OneStepSarsaWorker.
   *
   * @param other OneStepSarsaWorker to copy.
   */
  OneStepSarsaWorker& operator=(const OneStepSarsaWorker& other)
  {
    if (&other == this)
      return *this;

    #if ENS_VERSION_MAJOR >= 2
    delete updatePolicy;
    #endif

    updater = other.updater;
    environment = other.environment;
    config = other.config;
    deterministic = other.deterministic;
    steps = other.steps;
    episodeReturn = other.episodeReturn;
    pending = other.pending;
    pendingIndex = other.pendingIndex;
    network = other.network;
    state = other.state;
    action = other.action;

    #if ENS_VERSION_MAJOR >= 2
    updatePolicy = new typename UpdaterType::template
        Policy<arma::mat, arma::mat>(updater,
        network.Parameters().n_rows,
        network.Parameters().n_cols);
    #endif

    Reset();

    return *this;
  }

  /**
   * Take ownership of another OneStepSarsaWorker.
   *
   * @param other OneStepSarsaWorker to take ownership of.
   */
  OneStepSarsaWorker& operator=(OneStepSarsaWorker&& other)
  {
    if (&other == this)
      return *this;

    #if ENS_VERSION_MAJOR >= 2
    delete updatePolicy;
    #endif

    updater = std::move(other.updater);
    environment = std::move(other.environment);
    config = std::move(other.config);
    deterministic = std::move(other.deterministic);
    steps = std::move(other.steps);
    episodeReturn = std::move(other.episodeReturn);
    pending = std::move(other.pending);
    pendingIndex = std::move(other.pendingIndex);
    network = std::move(other.network);
    state = std::move(other.state);
    action = std::move(other.action);

    #if ENS_VERSION_MAJOR >= 2
    other.updatePolicy = NULL;

    updatePolicy = new typename UpdaterType::template
        Policy<arma::mat, arma::mat>(updater,
        network.Parameters().n_rows,
        network.Parameters().n_cols);
    #endif

    return *this;
  }

  /**
   * Clean memory.
   */
  ~OneStepSarsaWorker()
  {
    #if ENS_VERSION_MAJOR >= 2
    delete updatePolicy;
    #endif
  }

  /**
   * Initialize the worker.
   * @param learningNetwork The shared network.
   */
  void Initialize(NetworkType& learningNetwork)
  {
    #if ENS_VERSION_MAJOR == 1
    updater.Initialize(learningNetwork.Parameters().n_rows,
        learningNetwork.Parameters().n_cols);
    #else
    delete updatePolicy;

    updatePolicy = new typename UpdaterType::template
        Policy<arma::mat, arma::mat>(updater,
        learningNetwork.Parameters().n_rows,
        learningNetwork.Parameters().n_cols);
    #endif

    // Build local network.
    network = learningNetwork;
  }

  /**
   * The agent will execute one step.
   *
   * @param learningNetwork The shared learning network.
   * @param targetNetwork The shared target network.
   * @param totalSteps The shared counter for total steps.
   * @param policy The shared behavior policy.
   * @param totalReward This will be the episode return if the episode ends
   *     after this step.  Otherwise this is invalid.
   * @return Indicate whether current episode ends after this step.
   */
  bool Step(NetworkType& learningNetwork,
            NetworkType& targetNetwork,
            size_t& totalSteps,
            PolicyType& policy,
            double& totalReward)
  {
    // Interact with the environment.
    if (action.action == ActionType::size)
    {
      // Invalid action means we are at the beginning of an episode.
      arma::colvec actionValue;
      network.Predict(state.Encode(), actionValue);
      action = policy.Sample(actionValue, deterministic);
    }
    StateType nextState;
    double reward = environment.Sample(state, action, nextState);
    bool terminal = environment.IsTerminal(nextState);
    arma::colvec actionValue;
    network.Predict(nextState.Encode(), actionValue);
    ActionType nextAction = policy.Sample(actionValue, deterministic);

    episodeReturn += reward;
    steps++;

    terminal = terminal || steps >= config.StepLimit();
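    // A deterministic worker only evaluates the current policy: it plays the
    // episode out and reports its return, but never buffers transitions or
    // updates the shared networks.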
    if (deterministic)
    {
      if (terminal)
      {
        totalReward = episodeReturn;
        Reset();
        // Sync with latest learning network.
        network = learningNetwork;
        return true;
      }
      state = nextState;
      action = nextAction;
      return false;
    }

    #pragma omp atomic
    totalSteps++;

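    // Buffer the transition; gradients are computed and applied in one batch
    // once config.UpdateInterval() transitions have accumulated or the
    // episode ends.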
    pending[pendingIndex++] =
        std::make_tuple(state, action, reward, nextState, nextAction);

    if (terminal || pendingIndex >= config.UpdateInterval())
    {
      // Initialize the gradient storage.
      arma::mat totalGradients(learningNetwork.Parameters().n_rows,
          learningNetwork.Parameters().n_cols, arma::fill::zeros);
      for (size_t i = 0; i < pending.size(); ++i)
      {
        TransitionType& transition = pending[i];

        // Compute the target state-action value.
        arma::colvec actionValue;
        #pragma omp critical
        {
          targetNetwork.Predict(
              std::get<3>(transition).Encode(), actionValue);
        }
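        // One-step SARSA target: r + discount * Q(s', a'), where a' is the
        // action the behavior policy actually selected in s' (on-policy; one
        // step Q-learning would use max_a Q(s', a) instead).  For the last
        // transition of a terminal episode the bootstrap term is dropped.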
        double targetActionValue = 0;
        if (!(terminal && i == pending.size() - 1))
          targetActionValue = actionValue[std::get<4>(transition).action];
        targetActionValue = std::get<2>(transition) +
            config.Discount() * targetActionValue;

        // Compute the training target for current state.
        arma::mat input = std::get<0>(transition).Encode();
        network.Forward(input, actionValue);
        actionValue[std::get<1>(transition).action] = targetActionValue;

        // Compute gradient.
        arma::mat gradients;
        network.Backward(input, actionValue, gradients);

        // Accumulate gradients.
        totalGradients += gradients;
      }

      // Clamp the accumulated gradients.
      totalGradients.transform(
          [&](double gradient)
          { return std::min(std::max(gradient, -config.GradientLimit()),
              config.GradientLimit()); });
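
      // The shared parameters are written without a lock; only the target-
      // network sync below is serialized.  This mirrors the lock-free update
      // scheme of asynchronous gradient methods.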

      // Perform async update of the global network.
      #if ENS_VERSION_MAJOR == 1
      updater.Update(learningNetwork.Parameters(), config.StepSize(),
          totalGradients);
      #else
      updatePolicy->Update(learningNetwork.Parameters(),
          config.StepSize(), totalGradients);
      #endif

      // Sync the local network with the global network.
      network = learningNetwork;

      pendingIndex = 0;
    }

    // Update global target network.
    if (totalSteps % config.TargetNetworkSyncInterval() == 0)
    {
      #pragma omp critical
      { targetNetwork = learningNetwork; }
    }

    policy.Anneal();

    if (terminal)
    {
      totalReward = episodeReturn;
      Reset();
      return true;
    }
    state = nextState;
    action = nextAction;
    return false;
  }

 private:
  /**
   * Reset the worker for a new episode.
   */
  void Reset()
  {
    steps = 0;
    episodeReturn = 0;
    pendingIndex = 0;
    state = environment.InitialSample();
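    // Mark the current action as invalid (== ActionType::size); Step() uses
    // this sentinel to detect the start of a new episode and sample a fresh
    // initial action.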
    using actions = typename EnvironmentType::Action::actions;
    action.action = static_cast<actions>(ActionType::size);
  }

  //! Locally-stored optimizer.
  UpdaterType updater;
  #if ENS_VERSION_MAJOR >= 2
  typename UpdaterType::template Policy<arma::mat, arma::mat>* updatePolicy;
  #endif

  //! Locally-stored task.
  EnvironmentType environment;

  //! Locally-stored hyper-parameters.
  TrainingConfig config;

  //! Whether this episode is deterministic or not.
  bool deterministic;

  //! Total steps in current episode.
  size_t steps;

  //! Total reward in current episode.
  double episodeReturn;

  //! Buffer for delayed update.
  std::vector<TransitionType> pending;

  //! Current position of the buffer.
  size_t pendingIndex;

  //! Local network of the worker.
  NetworkType network;

  //! Currently selected state.
  StateType state;

  //! Currently selected action.
  ActionType action;
};

} // namespace rl
} // namespace mlpack

#endif
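
Below is a minimal, single-threaded sketch of the worker's life cycle, driving it directly through the API defined above. The concrete choices of environment, network, policy, and updater (CartPole, FFN, GreedyPolicy, ens::VanillaUpdate) and the hyper-parameter values are illustrative assumptions, not prescribed by this header; in practice the worker is run concurrently across several OpenMP threads by mlpack's asynchronous learning framework rather than called by hand.

// Illustrative driver; header paths and types as in mlpack 3.x.
#include <iostream>

#include <mlpack/methods/ann/ffn.hpp>
#include <mlpack/methods/ann/init_rules/gaussian_init.hpp>
#include <mlpack/methods/ann/layer/layer.hpp>
#include <mlpack/methods/ann/loss_functions/mean_squared_error.hpp>
#include <mlpack/methods/reinforcement_learning/environment/cart_pole.hpp>
#include <mlpack/methods/reinforcement_learning/policy/greedy_policy.hpp>
#include <mlpack/methods/reinforcement_learning/training_config.hpp>
#include <mlpack/methods/reinforcement_learning/worker/one_step_sarsa_worker.hpp>

using namespace mlpack::ann;
using namespace mlpack::rl;

int main()
{
  // Two copies of the action-value network: the shared network the worker
  // updates, and the periodically-synced target network.
  FFN<MeanSquaredError<>, GaussianInitialization> learningNetwork;
  learningNetwork.Add<Linear<>>(4, 20);  // CartPole states encode to 4 values.
  learningNetwork.Add<ReLULayer<>>();
  learningNetwork.Add<Linear<>>(20, 2);  // CartPole has two actions.
  learningNetwork.ResetParameters();
  FFN<MeanSquaredError<>, GaussianInitialization> targetNetwork(learningNetwork);

  CartPole environment;
  GreedyPolicy<CartPole> policy(/* initialEpsilon */ 1.0,
      /* annealInterval */ 1000, /* minEpsilon */ 0.1);
  ens::VanillaUpdate updater;

  // Hyper-parameter values here are arbitrary; UpdateInterval must be set
  // before construction, since it sizes the worker's pending buffer.
  TrainingConfig config;
  config.StepSize() = 0.01;
  config.Discount() = 0.99;
  config.StepLimit() = 200;
  config.UpdateInterval() = 5;
  config.TargetNetworkSyncInterval() = 100;

  OneStepSarsaWorker<CartPole, decltype(learningNetwork), ens::VanillaUpdate,
      GreedyPolicy<CartPole>> worker(updater, environment, config,
      /* deterministic */ false);
  worker.Initialize(learningNetwork);

  size_t totalSteps = 0;
  double episodeReturn = 0;
  while (totalSteps < 10000)
  {
    // Step() returns true exactly when an episode ends; only then does
    // episodeReturn hold a valid value.
    if (worker.Step(learningNetwork, targetNetwork, totalSteps, policy,
        episodeReturn))
      std::cout << "episode return: " << episodeReturn << std::endl;
  }
}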