mlpack 3.4.2
categorical_dqn.hpp
Go to the documentation of this file.
1
12#ifndef MLPACK_METHODS_RL_CATEGORICAL_DQN_HPP
13#define MLPACK_METHODS_RL_CATEGORICAL_DQN_HPP
14
#include <mlpack/prereqs.hpp>

#include <mlpack/methods/ann/ffn.hpp>
#include <mlpack/methods/ann/init_rules/gaussian_init.hpp>
#include <mlpack/methods/ann/layer/layer.hpp>
#include <mlpack/methods/ann/loss_functions/empty_loss.hpp>

#include "../training_config.hpp"
22
23namespace mlpack {
24namespace rl {
25
26using namespace mlpack::ann;
27
45template<
46 typename OutputLayerType = EmptyLoss<>,
47 typename InitType = GaussianInitialization,
48 typename NetworkType = FFN<OutputLayerType, InitType>
49>
51{
52 public:
56 CategoricalDQN() : network(), isNoisy(false)
57 { /* Nothing to do here. */ }
58
71 CategoricalDQN(const int inputDim,
72 const int h1,
73 const int h2,
74 const int outputDim,
75 TrainingConfig config,
76 const bool isNoisy = false,
77 InitType init = InitType(),
78 OutputLayerType outputLayer = OutputLayerType()):
79 network(outputLayer, init),
80 atomSize(config.AtomSize()),
81 vMin(config.VMin()),
82 vMax(config.VMax()),
83 isNoisy(isNoisy)
84 {
85 network.Add(new Linear<>(inputDim, h1));
86 network.Add(new ReLULayer<>());
87 if (isNoisy)
88 {
89 noisyLayerIndex.push_back(network.Model().size());
90 network.Add(new NoisyLinear<>(h1, h2));
91 network.Add(new ReLULayer<>());
92 noisyLayerIndex.push_back(network.Model().size());
93 network.Add(new NoisyLinear<>(h2, outputDim * atomSize));
94 }
95 else
96 {
97 network.Add(new Linear<>(h1, h2));
98 network.Add(new ReLULayer<>());
99 network.Add(new Linear<>(h2, outputDim * atomSize));
100 }
101 }
102
110 CategoricalDQN(NetworkType& network,
111 TrainingConfig config,
112 const bool isNoisy = false):
113 network(std::move(network)),
114 atomSize(config.AtomSize()),
115 vMin(config.VMin()),
116 vMax(config.VMax()),
117 isNoisy(isNoisy)
118 { /* Nothing to do here. */ }
119
131 void Predict(const arma::mat state, arma::mat& actionValue)
132 {
133 arma::mat q_atoms;
134 network.Predict(state, q_atoms);
135 activations.copy_size(q_atoms);
136 actionValue.set_size(q_atoms.n_rows / atomSize, q_atoms.n_cols);
137 arma::rowvec support = arma::linspace<arma::rowvec>(vMin, vMax, atomSize);
138 for (size_t i = 0; i < q_atoms.n_rows; i += atomSize)
139 {
140 arma::mat activation = activations.rows(i, i + atomSize - 1);
141 arma::mat input = q_atoms.rows(i, i + atomSize - 1);
142 softMax.Forward(input, activation);
143 activations.rows(i, i + atomSize - 1) = activation;
144 actionValue.row(i/atomSize) = support * activation;
145 }
146 }
147
154 void Forward(const arma::mat state, arma::mat& dist)
155 {
156 arma::mat q_atoms;
157 network.Forward(state, q_atoms);
158 activations.copy_size(q_atoms);
159 for (size_t i = 0; i < q_atoms.n_rows; i += atomSize)
160 {
161 arma::mat activation = activations.rows(i, i + atomSize - 1);
162 arma::mat input = q_atoms.rows(i, i + atomSize - 1);
163 softMax.Forward(input, activation);
164 activations.rows(i, i + atomSize - 1) = activation;
165 }
166 dist = activations;
167 }
168
173 {
174 network.ResetParameters();
175 }
176
181 {
182 for (size_t i = 0; i < noisyLayerIndex.size(); i++)
183 {
184 boost::get<NoisyLinear<>*>
185 (network.Model()[noisyLayerIndex[i]])->ResetNoise();
186 }
187 }
188
190 const arma::mat& Parameters() const { return network.Parameters(); }
192 arma::mat& Parameters() { return network.Parameters(); }
193
201 void Backward(const arma::mat state,
202 arma::mat& lossGradients,
203 arma::mat& gradient)
204 {
205 arma::mat activationGradients(arma::size(activations));
206 for (size_t i = 0; i < activations.n_rows; i += atomSize)
207 {
208 arma::mat activationGrad;
209 arma::mat lossGrad = lossGradients.rows(i, i + atomSize - 1);
210 arma::mat activation = activations.rows(i, i + atomSize - 1);
211 softMax.Backward(activation, lossGrad, activationGrad);
212 activationGradients.rows(i, i + atomSize - 1) = activationGrad;
213 }
214 network.Backward(state, activationGradients, gradient);
215 }
216
217 private:
219 NetworkType network;
220
222 size_t atomSize;
223
225 double vMin;
226
228 double vMax;
229
231 bool isNoisy;
232
234 std::vector<size_t> noisyLayerIndex;
235
237 Softmax<> softMax;
238
240 arma::mat activations;
241};
242
243} // namespace rl
244} // namespace mlpack
245
246#endif
Implementation of the base layer.
Definition: base_layer.hpp:66
The empty loss does nothing, letting the user calculate the loss outside the model.
Definition: empty_loss.hpp:36
Implementation of a standard feed forward network.
Definition: ffn.hpp:53
This class is used to initialize the weight matrix with a Gaussian distribution.
Implementation of the Linear layer class.
Definition: linear.hpp:39
Implementation of the NoisyLinear layer class.
Definition: noisylinear.hpp:34
Implementation of the Softmax layer.
Definition: softmax.hpp:39
void Forward(const InputType &input, OutputType &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
void Backward(const arma::Mat< eT > &input, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f, using the results from the feed forward pass.
Implementation of the Categorical Deep Q-Learning network.
void ResetNoise()
Resets noise of the network, if the network is of type noisy.
arma::mat & Parameters()
Modify the Parameters.
const arma::mat & Parameters() const
Return the Parameters.
CategoricalDQN()
Default constructor.
void Forward(const arma::mat state, arma::mat &dist)
Perform the forward pass of the states in real batch mode.
CategoricalDQN(const int inputDim, const int h1, const int h2, const int outputDim, TrainingConfig config, const bool isNoisy=false, InitType init=InitType(), OutputLayerType outputLayer=OutputLayerType())
Construct an instance of CategoricalDQN class.
CategoricalDQN(NetworkType &network, TrainingConfig config, const bool isNoisy=false)
Construct an instance of CategoricalDQN class from a pre-constructed network.
void ResetParameters()
Resets the parameters of the network.
void Predict(const arma::mat state, arma::mat &actionValue)
Predict the responses to a given set of predictors.
void Backward(const arma::mat state, arma::mat &lossGradients, arma::mat &gradient)
Perform the backward pass of the state in real batch mode.
Artificial Neural Network.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
Definition: prereqs.hpp:67
The core includes that mlpack expects; standard C++ includes and Armadillo.