12#ifndef MLPACK_METHODS_RL_CATEGORICAL_DQN_HPP
13#define MLPACK_METHODS_RL_CATEGORICAL_DQN_HPP
21#include "../training_config.hpp"
76 const bool isNoisy =
false,
77 InitType init = InitType(),
78 OutputLayerType outputLayer = OutputLayerType()):
79 network(outputLayer, init),
80 atomSize(config.AtomSize()),
85 network.Add(
new Linear<>(inputDim, h1));
89 noisyLayerIndex.push_back(network.Model().size());
92 noisyLayerIndex.push_back(network.Model().size());
99 network.Add(
new Linear<>(h2, outputDim * atomSize));
112 const bool isNoisy =
false):
113 network(
std::move(network)),
114 atomSize(config.AtomSize()),
131 void Predict(
const arma::mat state, arma::mat& actionValue)
134 network.Predict(state, q_atoms);
135 activations.copy_size(q_atoms);
136 actionValue.set_size(q_atoms.n_rows / atomSize, q_atoms.n_cols);
137 arma::rowvec support = arma::linspace<arma::rowvec>(vMin, vMax, atomSize);
138 for (
size_t i = 0; i < q_atoms.n_rows; i += atomSize)
140 arma::mat activation = activations.rows(i, i + atomSize - 1);
141 arma::mat input = q_atoms.rows(i, i + atomSize - 1);
142 softMax.
Forward(input, activation);
143 activations.rows(i, i + atomSize - 1) = activation;
144 actionValue.row(i/atomSize) = support * activation;
154 void Forward(
const arma::mat state, arma::mat& dist)
157 network.Forward(state, q_atoms);
158 activations.copy_size(q_atoms);
159 for (
size_t i = 0; i < q_atoms.n_rows; i += atomSize)
161 arma::mat activation = activations.rows(i, i + atomSize - 1);
162 arma::mat input = q_atoms.rows(i, i + atomSize - 1);
163 softMax.
Forward(input, activation);
164 activations.rows(i, i + atomSize - 1) = activation;
174 network.ResetParameters();
182 for (
size_t i = 0; i < noisyLayerIndex.size(); i++)
184 boost::get<NoisyLinear<>*>
185 (network.Model()[noisyLayerIndex[i]])->
ResetNoise();
//! Return the parameters of the wrapped network as a read-only reference.
190 const arma::mat&
Parameters()
const {
return network.Parameters(); }
202 arma::mat& lossGradients,
205 arma::mat activationGradients(arma::size(activations));
206 for (
size_t i = 0; i < activations.n_rows; i += atomSize)
208 arma::mat activationGrad;
209 arma::mat lossGrad = lossGradients.rows(i, i + atomSize - 1);
210 arma::mat activation = activations.rows(i, i + atomSize - 1);
211 softMax.
Backward(activation, lossGrad, activationGrad);
212 activationGradients.rows(i, i + atomSize - 1) = activationGrad;
214 network.Backward(state, activationGradients, gradient);
234 std::vector<size_t> noisyLayerIndex;
240 arma::mat activations;
Implementation of the base layer.
The empty loss does nothing, letting the user calculate the loss outside the model.
Implementation of a standard feed forward network.
This class is used to initialize the weight matrix with a Gaussian distribution.
Implementation of the Linear layer class.
Implementation of the NoisyLinear layer class.
Implementation of the Softmax layer.
void Forward(const InputType &input, OutputType &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
void Backward(const arma::Mat< eT > &input, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f, using the results from the feed forward pass.
Implementation of the Categorical Deep Q-Learning network.
void ResetNoise()
Resets the noise of the network, if the network is noisy.
arma::mat & Parameters()
Modify the Parameters.
const arma::mat & Parameters() const
Return the Parameters.
CategoricalDQN()
Default constructor.
void Forward(const arma::mat state, arma::mat &dist)
Perform the forward pass of the states in real batch mode.
CategoricalDQN(const int inputDim, const int h1, const int h2, const int outputDim, TrainingConfig config, const bool isNoisy=false, InitType init=InitType(), OutputLayerType outputLayer=OutputLayerType())
Construct an instance of CategoricalDQN class.
CategoricalDQN(NetworkType &network, TrainingConfig config, const bool isNoisy=false)
Construct an instance of CategoricalDQN class from a pre-constructed network.
void ResetParameters()
Resets the parameters of the network.
void Predict(const arma::mat state, arma::mat &actionValue)
Predict the responses to a given set of predictors.
void Backward(const arma::mat state, arma::mat &lossGradients, arma::mat &gradient)
Perform the backward pass of the states in real batch mode.
Artificial Neural Network.
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.