#ifndef MLPACK_METHODS_RL_WORKER_ONE_STEP_SARSA_WORKER_HPP
#define MLPACK_METHODS_RL_WORKER_ONE_STEP_SARSA_WORKER_HPP

#include <ensmallen.hpp>
#include <mlpack/prereqs.hpp>

namespace mlpack {
namespace rl {

/**
 * One step Sarsa worker.
 *
 * @tparam EnvironmentType The type of the reinforcement learning task.
 * @tparam NetworkType The type of the network model.
 * @tparam UpdaterType The type of the optimizer.
 * @tparam PolicyType The type of the behavior policy.
 */
template <
  typename EnvironmentType,
  typename NetworkType,
  typename UpdaterType,
  typename PolicyType
>
class OneStepSarsaWorker
{
 public:
  using StateType = typename EnvironmentType::State;
  using ActionType = typename EnvironmentType::Action;
  using TransitionType = std::tuple<
      StateType, ActionType, double, StateType, ActionType>;

  /**
   * Construct a one-step Sarsa worker with the given parameters and
   * environment.
   */
  OneStepSarsaWorker(
      const UpdaterType& updater,
      const EnvironmentType& environment,
      const TrainingConfig& config,
      bool deterministic):
      updater(updater),
      #if ENS_VERSION_MAJOR >= 2
      updatePolicy(NULL),
      #endif
      environment(environment),
      config(config),
      deterministic(deterministic),
      pending(config.UpdateInterval())
  { Reset(); }

  /**
   * Copy another OneStepSarsaWorker.
   */
  OneStepSarsaWorker(const OneStepSarsaWorker& other) :
      updater(other.updater),
      #if ENS_VERSION_MAJOR >= 2
      updatePolicy(NULL),
      #endif
      environment(other.environment),
      config(other.config),
      deterministic(other.deterministic),
      steps(other.steps),
      episodeReturn(other.episodeReturn),
      pending(other.pending),
      pendingIndex(other.pendingIndex),
      network(other.network),
      state(other.state),
      action(other.action)
  {
    #if ENS_VERSION_MAJOR >= 2
    updatePolicy = new typename UpdaterType::template
        Policy<arma::mat, arma::mat>(updater,
                                     network.Parameters().n_rows,
                                     network.Parameters().n_cols);
    #endif
  }

  /**
   * Take ownership of another OneStepSarsaWorker.
   */
  OneStepSarsaWorker(OneStepSarsaWorker&& other) :
      updater(std::move(other.updater)),
      #if ENS_VERSION_MAJOR >= 2
      updatePolicy(NULL),
      #endif
      environment(std::move(other.environment)),
      config(std::move(other.config)),
      deterministic(std::move(other.deterministic)),
      steps(std::move(other.steps)),
      episodeReturn(std::move(other.episodeReturn)),
      pending(std::move(other.pending)),
      pendingIndex(std::move(other.pendingIndex)),
      network(std::move(other.network)),
      state(std::move(other.state)),
      action(std::move(other.action))
  {
    #if ENS_VERSION_MAJOR >= 2
    other.updatePolicy = NULL;

    updatePolicy = new typename UpdaterType::template
        Policy<arma::mat, arma::mat>(updater,
                                     network.Parameters().n_rows,
                                     network.Parameters().n_cols);
    #endif
  }

  /**
   * Copy another OneStepSarsaWorker.
   */
  OneStepSarsaWorker& operator=(const OneStepSarsaWorker& other)
  {
    if (&other == this)
      return *this;

    #if ENS_VERSION_MAJOR >= 2
    delete updatePolicy;
    #endif

    updater = other.updater;
    environment = other.environment;
    config = other.config;
    deterministic = other.deterministic;
    steps = other.steps;
    episodeReturn = other.episodeReturn;
    pending = other.pending;
    pendingIndex = other.pendingIndex;
    network = other.network;
    state = other.state;
    action = other.action;

    #if ENS_VERSION_MAJOR >= 2
    updatePolicy = new typename UpdaterType::template
        Policy<arma::mat, arma::mat>(updater,
                                     network.Parameters().n_rows,
                                     network.Parameters().n_cols);
    #endif

    return *this;
  }

  /**
   * Take ownership of another OneStepSarsaWorker.
   */
  OneStepSarsaWorker& operator=(OneStepSarsaWorker&& other)
  {
    if (&other == this)
      return *this;

    #if ENS_VERSION_MAJOR >= 2
    delete updatePolicy;
    #endif

    updater = std::move(other.updater);
    environment = std::move(other.environment);
    config = std::move(other.config);
    deterministic = std::move(other.deterministic);
    steps = std::move(other.steps);
    episodeReturn = std::move(other.episodeReturn);
    pending = std::move(other.pending);
    pendingIndex = std::move(other.pendingIndex);
    network = std::move(other.network);
    state = std::move(other.state);
    action = std::move(other.action);

    #if ENS_VERSION_MAJOR >= 2
    other.updatePolicy = NULL;

    updatePolicy = new typename UpdaterType::template
        Policy<arma::mat, arma::mat>(updater,
                                     network.Parameters().n_rows,
                                     network.Parameters().n_cols);
    #endif

    return *this;
  }

  /**
   * Clean memory.
   */
  ~OneStepSarsaWorker()
  {
    #if ENS_VERSION_MAJOR >= 2
    delete updatePolicy;
    #endif
  }

  /**
   * Initialize the worker.
   */
  void Initialize(NetworkType& learningNetwork)
  {
    #if ENS_VERSION_MAJOR == 1
    updater.Initialize(learningNetwork.Parameters().n_rows,
        learningNetwork.Parameters().n_cols);
    #else
    delete updatePolicy;
    updatePolicy = new typename UpdaterType::template
        Policy<arma::mat, arma::mat>(updater,
                                     learningNetwork.Parameters().n_rows,
                                     learningNetwork.Parameters().n_cols);
    #endif

    // Build local network.
    network = learningNetwork;
  }

  /**
   * The agent will execute one step.  Returns true if the episode ends
   * after this step; totalReward is only valid in that case.
   */
  bool Step(NetworkType& learningNetwork,
            NetworkType& targetNetwork,
            size_t& totalSteps,
            PolicyType& policy,
            double& totalReward)
  {
    // Interact with the environment.
    if (action.action == ActionType::size)
    {
      // Invalid action means we are at the beginning of an episode.
      arma::colvec actionValue;
      network.Predict(state.Encode(), actionValue);
      action = policy.Sample(actionValue, deterministic);
    }
    StateType nextState;
    double reward = environment.Sample(state, action, nextState);
    bool terminal = environment.IsTerminal(nextState);
    arma::colvec actionValue;
    network.Predict(nextState.Encode(), actionValue);
    ActionType nextAction = policy.Sample(actionValue, deterministic);

    episodeReturn += reward;
    steps++;

    terminal = terminal || steps >= config.StepLimit();
    if (deterministic)
    {
      if (terminal)
      {
        totalReward = episodeReturn;
        Reset();
        // Sync with latest learning network.
        network = learningNetwork;
        return true;
      }
      state = nextState;
      action = nextAction;
      return false;
    }

    #pragma omp atomic
    totalSteps++;

    pending[pendingIndex++] =
        std::make_tuple(state, action, reward, nextState, nextAction);

    if (terminal || pendingIndex >= config.UpdateInterval())
    {
      // Initialize the gradient storage.
      arma::mat totalGradients(learningNetwork.Parameters().n_rows,
          learningNetwork.Parameters().n_cols, arma::fill::zeros);
      for (size_t i = 0; i < pending.size(); ++i)
      {
        TransitionType& transition = pending[i];

        // Compute the target state-action value with the shared target
        // network; the prediction is guarded since the network is shared.
        arma::colvec actionValue;
        #pragma omp critical
        {
          targetNetwork.Predict(std::get<3>(transition).Encode(),
              actionValue);
        }
        double targetActionValue = 0;
        if (!(terminal && i == pending.size() - 1))
          targetActionValue = actionValue[std::get<4>(transition).action];
        targetActionValue = std::get<2>(transition) +
            config.Discount() * targetActionValue;

        // Compute the training target for the current state.
        arma::mat input = std::get<0>(transition).Encode();
        network.Forward(input, actionValue);
        actionValue[std::get<1>(transition).action] = targetActionValue;

        // Compute and accumulate gradients.
        arma::mat gradients;
        network.Backward(input, actionValue, gradients);
        totalGradients += gradients;
      }

      // Clamp the accumulated gradients.
      totalGradients.transform(
          [&](double gradient)
          { return std::min(std::max(gradient, -config.GradientLimit()),
          config.GradientLimit()); });

      // Perform async update of the global network.
      #if ENS_VERSION_MAJOR == 1
      updater.Update(learningNetwork.Parameters(), config.StepSize(),
          totalGradients);
      #else
      updatePolicy->Update(learningNetwork.Parameters(),
          config.StepSize(), totalGradients);
      #endif

      // Sync the local network with the global network.
      network = learningNetwork;

      pendingIndex = 0;
    }

    // Update global target network.
    if (totalSteps % config.TargetNetworkSyncInterval() == 0)
    {
      #pragma omp critical
      { targetNetwork = learningNetwork; }
    }

    if (terminal)
    {
      totalReward = episodeReturn;
      Reset();
      return true;
    }
    state = nextState;
    action = nextAction;
    return false;
  }

 private:
  //! Reset the worker for a new episode.
  void Reset()
  {
    steps = 0;
    episodeReturn = 0;
    pendingIndex = 0;
    state = environment.InitialSample();
    using actions = typename EnvironmentType::Action::actions;
    action.action = static_cast<actions>(ActionType::size);
  }

  //! Locally-stored optimizer.
  UpdaterType updater;
  #if ENS_VERSION_MAJOR >= 2
  typename UpdaterType::template Policy<arma::mat, arma::mat>* updatePolicy;
  #endif

  //! Locally-stored task.
  EnvironmentType environment;

  //! Locally-stored training configuration.
  TrainingConfig config;

  //! Whether this episode is deterministic or not.
  bool deterministic;

  //! Total steps in current episode.
  size_t steps;

  //! Total reward in current episode.
  double episodeReturn;

  //! Buffered transitions.
  std::vector<TransitionType> pending;

  //! Current position of the buffered transitions.
  size_t pendingIndex;

  //! Local network of the worker.
  NetworkType network;

  //! Current state of the agent.
  StateType state;

  //! Current action of the agent.
  ActionType action;
};

} // namespace rl
} // namespace mlpack

#endif
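Taken together, Initialize() sets up the worker's local copy of the shared network and Step() advances one environment transition, so a single worker can be driven by a plain loop. The sketch below is illustrative only, not part of this header: in mlpack these workers are normally constructed and stepped by AsyncLearning. TrainOneEpisode is a hypothetical helper, ens::VanillaUpdate is just one valid UpdaterType, the include paths assume the mlpack 3.x layout, and the hyper-parameter values are placeholders; it also assumes the usual mlpack accessor pattern where TrainingConfig's non-const overloads return mutable references.

#include <ensmallen.hpp>
#include <mlpack/methods/reinforcement_learning/training_config.hpp>
#include <mlpack/methods/reinforcement_learning/worker/one_step_sarsa_worker.hpp>

// Hypothetical driver: run one training episode with a single worker.
// EnvironmentType, NetworkType and PolicyType stand for any types matching
// the interface used above (e.g. CartPole, an FFN value network, and
// GreedyPolicy).
template<typename EnvironmentType, typename NetworkType, typename PolicyType>
double TrainOneEpisode(NetworkType& learningNetwork,
                       NetworkType& targetNetwork,
                       const EnvironmentType& environment,
                       PolicyType& policy)
{
  using namespace mlpack::rl;

  // Placeholder hyper-parameters consumed by the worker via TrainingConfig.
  TrainingConfig config;
  config.StepSize() = 0.01;
  config.Discount() = 0.99;
  config.StepLimit() = 200;
  config.UpdateInterval() = 10;  // Size of the pending transition buffer.
  config.TargetNetworkSyncInterval() = 100;
  config.GradientLimit() = 40;

  // deterministic = false so the worker trains instead of just evaluating.
  OneStepSarsaWorker<EnvironmentType, NetworkType, ens::VanillaUpdate,
      PolicyType> worker(ens::VanillaUpdate(), environment, config, false);
  worker.Initialize(learningNetwork);

  size_t totalSteps = 0;
  double episodeReturn = 0;
  // Step() returns true when the episode ends; episodeReturn is only valid
  // on that final call.
  while (!worker.Step(learningNetwork, targetNetwork, totalSteps, policy,
      episodeReturn)) { }
  return episodeReturn;
}

Passing deterministic = true instead makes Step() act greedily and return before any gradient computation, which is how evaluation episodes run.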
Forward declaration of OneStepSarsaWorker.
OneStepSarsaWorker(const UpdaterType &updater, const EnvironmentType &environment, const TrainingConfig &config, bool deterministic)
Construct a one-step Sarsa worker with the given parameters and environment.
OneStepSarsaWorker(const OneStepSarsaWorker &other)
Copy another OneStepSarsaWorker.
OneStepSarsaWorker & operator=(const OneStepSarsaWorker &other)
Copy another OneStepSarsaWorker.
std::tuple< StateType, ActionType, double, StateType, ActionType > TransitionType
OneStepSarsaWorker & operator=(OneStepSarsaWorker &&other)
Take ownership of another OneStepSarsaWorker.
OneStepSarsaWorker(OneStepSarsaWorker &&other)
Take ownership of another OneStepSarsaWorker.
~OneStepSarsaWorker()
Clean memory.
typename EnvironmentType::Action ActionType
void Initialize(NetworkType &learningNetwork)
Initialize the worker.
bool Step(NetworkType &learningNetwork, NetworkType &targetNetwork, size_t &totalSteps, PolicyType &policy, double &totalReward)
The agent will execute one step.
typename EnvironmentType::State StateType
double StepSize() const
Get the step size of the optimizer.
size_t UpdateInterval() const
Get the update interval.
double GradientLimit() const
Get the limit on the update gradient.
double Discount() const
Get the discount rate for future reward; see the target formula after this list.
size_t StepLimit() const
Get the maximum number of steps per episode.
size_t TargetNetworkSyncInterval() const
Get the interval for syncing the target network.
arma: Linear algebra utility functions, generally performed on matrices or vectors.
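For reference, the update that Step() accumulates over each buffered transition (s_t, a_t, r_t, s_{t+1}, a_{t+1}) is the standard one-step Sarsa target, bootstrapped from the shared target network on the actually-sampled next action. A sketch in LaTeX, where \gamma is Discount() and c is GradientLimit():

y_t =
  \begin{cases}
    r_t, & \text{if the buffered episode ends at } t,\\
    r_t + \gamma \, Q_{\text{target}}(s_{t+1}, a_{t+1}), & \text{otherwise,}
  \end{cases}

The gradients of the resulting TD errors, summed over the buffer, are clipped elementwise to [-c, c] before the shared learning network is updated.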