#ifndef MLPACK_METHODS_RL_WORKER_ONE_STEP_Q_LEARNING_WORKER_HPP
#define MLPACK_METHODS_RL_WORKER_ONE_STEP_Q_LEARNING_WORKER_HPP
template<typename EnvironmentType, typename NetworkType,
         typename UpdaterType, typename PolicyType>
class OneStepQLearningWorker
{
 public:
  using StateType = typename EnvironmentType::State;
  using ActionType = typename EnvironmentType::Action;
  using TransitionType = std::tuple<StateType, ActionType, double, StateType>;

  // Construct a one-step Q-Learning worker for the given environment.
  OneStepQLearningWorker(const UpdaterType& updater,
                         const EnvironmentType& environment,
                         const TrainingConfig& config,
                         bool deterministic) :
      updater(updater),
      #if ENS_VERSION_MAJOR >= 2
      updatePolicy(NULL),
      #endif
      environment(environment),
      config(config),
      deterministic(deterministic),
      pending(config.UpdateInterval())
  { Reset(); }
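Note that pending is sized to config.UpdateInterval() up front: the worker buffers that many transitions and performs one accumulated gradient update per interval (or earlier, if the episode ends), as Step() below shows.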
  // Copy another OneStepQLearningWorker.
  OneStepQLearningWorker(const OneStepQLearningWorker& other) :
      updater(other.updater),
      #if ENS_VERSION_MAJOR >= 2
      updatePolicy(NULL),
      #endif
      environment(other.environment),
      config(other.config),
      deterministic(other.deterministic),
      steps(other.steps),
      episodeReturn(other.episodeReturn),
      pending(other.pending),
      pendingIndex(other.pendingIndex),
      network(other.network),
      state(other.state)
  {
    #if ENS_VERSION_MAJOR >= 2
    updatePolicy = new typename UpdaterType::template
        Policy<arma::mat, arma::mat>(updater,
                                     network.Parameters().n_rows,
                                     network.Parameters().n_cols);
    #endif
  }
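With ensmallen 2 and later, the optimizer's per-parameter state is a separate Policy&lt;arma::mat, arma::mat&gt; object constructed from the updater and the dimensions of the network's parameter matrix, so each copied or moved worker builds its own instance rather than sharing the source worker's pointer.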
  // Take ownership of another OneStepQLearningWorker.
  OneStepQLearningWorker(OneStepQLearningWorker&& other) :
      updater(std::move(other.updater)),
      #if ENS_VERSION_MAJOR >= 2
      updatePolicy(NULL),
      #endif
      environment(std::move(other.environment)),
      config(std::move(other.config)),
      deterministic(std::move(other.deterministic)),
      steps(std::move(other.steps)),
      episodeReturn(std::move(other.episodeReturn)),
      pending(std::move(other.pending)),
      pendingIndex(std::move(other.pendingIndex)),
      network(std::move(other.network)),
      state(std::move(other.state))
  {
    #if ENS_VERSION_MAJOR >= 2
    other.updatePolicy = NULL;

    updatePolicy = new typename UpdaterType::template
        Policy<arma::mat, arma::mat>(updater,
                                     network.Parameters().n_rows,
                                     network.Parameters().n_cols);
    #endif
  }
  // Copy another OneStepQLearningWorker.
  OneStepQLearningWorker& operator=(const OneStepQLearningWorker& other)
  {
    if (&other == this)
      return *this;

    #if ENS_VERSION_MAJOR >= 2
    delete updatePolicy;
    #endif

    updater = other.updater;
    environment = other.environment;
    config = other.config;
    deterministic = other.deterministic;
    steps = other.steps;
    episodeReturn = other.episodeReturn;
    pending = other.pending;
    pendingIndex = other.pendingIndex;
    network = other.network;
    state = other.state;

    #if ENS_VERSION_MAJOR >= 2
    updatePolicy = new typename UpdaterType::template
        Policy<arma::mat, arma::mat>(updater,
                                     network.Parameters().n_rows,
                                     network.Parameters().n_cols);
    #endif

    return *this;
  }
  // Take ownership of another OneStepQLearningWorker.
  OneStepQLearningWorker& operator=(OneStepQLearningWorker&& other)
  {
    if (&other == this)
      return *this;

    #if ENS_VERSION_MAJOR >= 2
    delete updatePolicy;
    #endif

    updater = std::move(other.updater);
    environment = std::move(other.environment);
    config = std::move(other.config);
    deterministic = std::move(other.deterministic);
    steps = std::move(other.steps);
    episodeReturn = std::move(other.episodeReturn);
    pending = std::move(other.pending);
    pendingIndex = std::move(other.pendingIndex);
    network = std::move(other.network);
    state = std::move(other.state);

    #if ENS_VERSION_MAJOR >= 2
    other.updatePolicy = NULL;

    updatePolicy = new typename UpdaterType::template
        Policy<arma::mat, arma::mat>(updater,
                                     network.Parameters().n_rows,
                                     network.Parameters().n_cols);
    #endif

    return *this;
  }
  // Clean memory.
  ~OneStepQLearningWorker()
  {
    #if ENS_VERSION_MAJOR >= 2
    delete updatePolicy;
    #endif
  }

  // Initialize the worker: size the optimizer for the shared network and
  // keep a local copy of that network.
  void Initialize(NetworkType& learningNetwork)
  {
    #if ENS_VERSION_MAJOR == 1
    updater.Initialize(learningNetwork.Parameters().n_rows,
        learningNetwork.Parameters().n_cols);
    #else
    delete updatePolicy;
    updatePolicy = new typename UpdaterType::template
        Policy<arma::mat, arma::mat>(updater,
                                     learningNetwork.Parameters().n_rows,
                                     learningNetwork.Parameters().n_cols);
    #endif
    network = learningNetwork;
  }
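Initialize() is expected to run before the first call to Step(): with ensmallen 1 it sizes the updater for the shared network's parameter matrix, with ensmallen 2 and later it builds the update policy, and in both cases it leaves the worker holding a local copy of the shared learning network.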
  // Execute one step in the environment; when the update interval is
  // reached, push the accumulated gradients to the shared learning network.
  bool Step(NetworkType& learningNetwork,
            NetworkType& targetNetwork,
            size_t& totalSteps,
            PolicyType& policy,
            double& totalReward)
  {
    // Interact with the environment.
    arma::colvec actionValue;
    network.Predict(state.Encode(), actionValue);
    ActionType action = policy.Sample(actionValue, deterministic);
    StateType nextState;
    double reward = environment.Sample(state, action, nextState);
    bool terminal = environment.IsTerminal(nextState);
    episodeReturn += reward;
    steps++;
    terminal = terminal || steps >= config.StepLimit();
    totalSteps++;

    // Buffer the transition for the delayed update.
    pending[pendingIndex++] = std::make_tuple(state, action, reward, nextState);

    if (terminal || pendingIndex >= config.UpdateInterval())
    {
      // Accumulate gradients over all pending transitions.
      arma::mat totalGradients(learningNetwork.Parameters().n_rows,
          learningNetwork.Parameters().n_cols, arma::fill::zeros);
      for (size_t i = 0; i < pending.size(); ++i)
      {
        TransitionType& transition = pending[i];

        // One-step Q target, estimated with the shared target network.
        arma::colvec actionValue;
        targetNetwork.Predict(std::get<3>(transition).Encode(), actionValue);
        double targetActionValue = actionValue.max();
        if (terminal && i == pending.size() - 1)
          targetActionValue = 0;
        targetActionValue = std::get<2>(transition) +
            config.Discount() * targetActionValue;

        // Regression target for the action that was actually taken.
        arma::mat input = std::get<0>(transition).Encode();
        network.Forward(input, actionValue);
        actionValue[std::get<1>(transition).action] = targetActionValue;

        // Compute and accumulate the gradients.
        arma::mat gradients;
        network.Backward(input, actionValue, gradients);
        totalGradients += gradients;
      }

      // Clamp the accumulated gradients to [-GradientLimit(), GradientLimit()].
      totalGradients.transform(
          [&](double gradient)
          { return std::min(std::max(gradient, -config.GradientLimit()),
              config.GradientLimit()); });

      // Asynchronously update the shared learning network.
      #if ENS_VERSION_MAJOR == 1
      updater.Update(learningNetwork.Parameters(), config.StepSize(),
          totalGradients);
      #else
      updatePolicy->Update(learningNetwork.Parameters(), config.StepSize(),
          totalGradients);
      #endif

      // Sync the local network with the shared network.
      network = learningNetwork;
      pendingIndex = 0;
    }

    // Periodically sync the shared target network.
    if (totalSteps % config.TargetNetworkSyncInterval() == 0)
    { targetNetwork = learningNetwork; }

    if (terminal)
    {
      totalReward = episodeReturn;
      Reset();
      network = learningNetwork;
      return true;
    }
    state = nextState;
    return false;
  }
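For reference, the inner loop above regresses the local network's value for the action that was taken toward the standard one-step Q-learning target computed from the shared target network,

    y_i = r_i + \gamma \max_{a'} Q_{\text{target}}(s'_i, a'),

where \gamma is config.Discount() and s'_i is the next state stored in transition i. For the final transition of a finished episode the bootstrap term is dropped (targetActionValue is zeroed before the reward is added), since a terminal state has no future value.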
 private:
  // Reset the per-episode state.
  void Reset()
  {
    steps = 0;
    episodeReturn = 0;
    pendingIndex = 0;
    state = environment.InitialSample();
  }

  //! Locally-stored optimizer and, for ensmallen >= 2, its update policy.
  UpdaterType updater;
  #if ENS_VERSION_MAJOR >= 2
  typename UpdaterType::template Policy<arma::mat, arma::mat>* updatePolicy;
  #endif
  //! Locally-stored task and training hyper-parameters.
  EnvironmentType environment;
  TrainingConfig config;
  //! Episode bookkeeping: evaluation-only flag, step count, and return.
  bool deterministic;
  size_t steps;
  double episodeReturn;
  //! Transitions buffered for the next delayed update, and the buffer cursor.
  std::vector<TransitionType> pending;
  size_t pendingIndex;
  //! Local copy of the shared network, and the current state.
  NetworkType network;
  StateType state;
};
Member reference for OneStepQLearningWorker, followed by the TrainingConfig accessors used above:
typename EnvironmentType::State StateType
typename EnvironmentType::Action ActionType
std::tuple<StateType, ActionType, double, StateType> TransitionType
OneStepQLearningWorker(const UpdaterType &updater, const EnvironmentType &environment, const TrainingConfig &config, bool deterministic)
Construct one step Q-Learning worker with the given parameters and environment.
OneStepQLearningWorker(const OneStepQLearningWorker &other)
Copy another OneStepQLearningWorker.
OneStepQLearningWorker(OneStepQLearningWorker &&other)
Take ownership of another OneStepQLearningWorker.
OneStepQLearningWorker & operator=(const OneStepQLearningWorker &other)
Copy another OneStepQLearningWorker.
OneStepQLearningWorker & operator=(OneStepQLearningWorker &&other)
Take ownership of another OneStepQLearningWorker.
~OneStepQLearningWorker()
Clean memory.
void Initialize(NetworkType &learningNetwork)
Initialize the worker.
bool Step(NetworkType &learningNetwork, NetworkType &targetNetwork, size_t &totalSteps, PolicyType &policy, double &totalReward)
The agent will execute one step.
double StepSize() const
Get the step size of the optimizer.
size_t UpdateInterval() const
Get the update interval.
double GradientLimit() const
Get the limit on the magnitude of each gradient update.
double Discount() const
Get the discount rate for future reward.
size_t StepLimit() const
Get the maximum number of steps per episode.
size_t TargetNetworkSyncInterval() const
Get the interval (in steps) for syncing the target network.
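To make the calling convention concrete, here is a minimal single-threaded driver sketch. It is not mlpack's asynchronous learner (which normally runs several such workers in parallel over shared networks, step counter, and policy); the function name and loop structure are illustrative assumptions, and only the worker interface documented above (Initialize() and Step()) is used.

#include <cstddef>

// Hypothetical helper: drive one worker until `episodes` episodes finish.
// WorkerType is any OneStepQLearningWorker instantiation; the concrete
// environment, network, updater, and policy types are left to the caller.
template<typename WorkerType, typename NetworkType, typename PolicyType>
void RunEpisodes(WorkerType& worker,
                 NetworkType& learningNetwork,
                 NetworkType& targetNetwork,
                 PolicyType& policy,
                 const size_t episodes)
{
  // Size the optimizer state and copy the shared network into the worker.
  worker.Initialize(learningNetwork);

  size_t totalSteps = 0;
  size_t finished = 0;
  while (finished < episodes)
  {
    double episodeReturn = 0;
    // Step() returns true exactly when an episode ends; only then does
    // episodeReturn hold a valid value (the return of that episode).
    if (worker.Step(learningNetwork, targetNetwork, totalSteps, policy,
        episodeReturn))
      ++finished;
  }
}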