mlpack 3.4.2
sac.hpp
Go to the documentation of this file.
1
13#ifndef MLPACK_METHODS_RL_SAC_HPP
14#define MLPACK_METHODS_RL_SAC_HPP
15
16#include <mlpack/prereqs.hpp>
17
22#include "training_config.hpp"
23
24namespace mlpack {
25namespace rl {
26
56template <
57 typename EnvironmentType,
58 typename QNetworkType,
59 typename PolicyNetworkType,
60 typename UpdaterType,
61 typename ReplayType = RandomReplay<EnvironmentType>
62>
63class SAC
64{
65 public:
67 using StateType = typename EnvironmentType::State;
68
70 using ActionType = typename EnvironmentType::Action;
71
89 QNetworkType& learningQ1Network,
90 PolicyNetworkType& policyNetwork,
91 ReplayType& replayMethod,
92 UpdaterType qNetworkUpdater = UpdaterType(),
93 UpdaterType policyNetworkUpdater = UpdaterType(),
94 EnvironmentType environment = EnvironmentType());
95
100
107 void SoftUpdate(double rho);
108
112 void Update();
113
118
123 double Episode();
124
126 size_t& TotalSteps() { return totalSteps; }
128 const size_t& TotalSteps() const { return totalSteps; }
129
131 StateType& State() { return state; }
133 const StateType& State() const { return state; }
134
136 const ActionType& Action() const { return action; }
137
139 bool& Deterministic() { return deterministic; }
141 const bool& Deterministic() const { return deterministic; }
142
143
144 private:
146 TrainingConfig& config;
147
149 QNetworkType& learningQ1Network;
150 QNetworkType learningQ2Network;
151
153 QNetworkType targetQ1Network;
154 QNetworkType targetQ2Network;
155
157 PolicyNetworkType& policyNetwork;
158
160 ReplayType& replayMethod;
161
163 UpdaterType qNetworkUpdater;
164 #if ENS_VERSION_MAJOR >= 2
165 typename UpdaterType::template Policy<arma::mat, arma::mat>*
166 qNetworkUpdatePolicy;
167 #endif
168
170 UpdaterType policyNetworkUpdater;
171 #if ENS_VERSION_MAJOR >= 2
172 typename UpdaterType::template Policy<arma::mat, arma::mat>*
173 policyNetworkUpdatePolicy;
174 #endif
175
177 EnvironmentType environment;
178
180 size_t totalSteps;
181
183 StateType state;
184
186 ActionType action;
187
189 bool deterministic;
190
193};
194
195} // namespace rl
196} // namespace mlpack
197
198// Include implementation
199#include "sac_impl.hpp"
200#endif
The mean squared error performance function measures the network's performance according to the mean of squared errors.
Implementation of Soft Actor-Critic, a model-free off-policy actor-critic based deep reinforcement learning algorithm.
Definition: sac.hpp:64
StateType & State()
Modify the state of the agent.
Definition: sac.hpp:131
void SoftUpdate(double rho)
Softly update the learning Q network parameters to the target Q network parameters.
double Episode()
Execute an episode.
SAC(TrainingConfig &config, QNetworkType &learningQ1Network, PolicyNetworkType &policyNetwork, ReplayType &replayMethod, UpdaterType qNetworkUpdater=UpdaterType(), UpdaterType policyNetworkUpdater=UpdaterType(), EnvironmentType environment=EnvironmentType())
Create the SAC object with given settings.
size_t & TotalSteps()
Modify total steps from beginning.
Definition: sac.hpp:126
const bool & Deterministic() const
Get the indicator of training mode / test mode.
Definition: sac.hpp:141
const size_t & TotalSteps() const
Get total steps from beginning.
Definition: sac.hpp:128
bool & Deterministic()
Modify the training mode / test mode indicator.
Definition: sac.hpp:139
typename EnvironmentType::Action ActionType
Convenient typedef for action.
Definition: sac.hpp:70
const StateType & State() const
Get the state of the agent.
Definition: sac.hpp:133
const ActionType & Action() const
Get the action of the agent.
Definition: sac.hpp:136
void SelectAction()
Select an action, given an agent.
~SAC()
Clean memory.
typename EnvironmentType::State StateType
Convenient typedef for state.
Definition: sac.hpp:67
void Update()
Update the Q and policy networks.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.