mlpack 3.4.2
simple_dqn.hpp
Go to the documentation of this file.
1
12#ifndef MLPACK_METHODS_RL_SIMPLE_DQN_HPP
13#define MLPACK_METHODS_RL_SIMPLE_DQN_HPP
14
15#include <mlpack/prereqs.hpp>
20
21namespace mlpack {
22namespace rl {
23
24using namespace mlpack::ann;
25
31template<
32 typename OutputLayerType = MeanSquaredError<>,
33 typename InitType = GaussianInitialization,
34 typename NetworkType = FFN<OutputLayerType, InitType>
35>
37{
38 public:
42 SimpleDQN() : network(), isNoisy(false)
43 { /* Nothing to do here. */ }
44
56 SimpleDQN(const int inputDim,
57 const int h1,
58 const int h2,
59 const int outputDim,
60 const bool isNoisy = false,
61 InitType init = InitType(),
62 OutputLayerType outputLayer = OutputLayerType()):
63 network(outputLayer, init),
64 isNoisy(isNoisy)
65 {
66 network.Add(new Linear<>(inputDim, h1));
67 network.Add(new ReLULayer<>());
68 if (isNoisy)
69 {
70 noisyLayerIndex.push_back(network.Model().size());
71 network.Add(new NoisyLinear<>(h1, h2));
72 network.Add(new ReLULayer<>());
73 noisyLayerIndex.push_back(network.Model().size());
74 network.Add(new NoisyLinear<>(h2, outputDim));
75 }
76 else
77 {
78 network.Add(new Linear<>(h1, h2));
79 network.Add(new ReLULayer<>());
80 network.Add(new Linear<>(h2, outputDim));
81 }
82 }
83
90 SimpleDQN(NetworkType& network, const bool isNoisy = false):
91 network(network),
92 isNoisy(isNoisy)
93 { /* Nothing to do here. */ }
94
106 void Predict(const arma::mat state, arma::mat& actionValue)
107 {
108 network.Predict(state, actionValue);
109 }
110
117 void Forward(const arma::mat state, arma::mat& target)
118 {
119 network.Forward(state, target);
120 }
121
126 {
127 network.ResetParameters();
128 }
129
134 {
135 for (size_t i = 0; i < noisyLayerIndex.size(); i++)
136 {
137 boost::get<NoisyLinear<>*>
138 (network.Model()[noisyLayerIndex[i]])->ResetNoise();
139 }
140 }
141
143 const arma::mat& Parameters() const { return network.Parameters(); }
145 arma::mat& Parameters() { return network.Parameters(); }
146
154 void Backward(const arma::mat state, arma::mat& target, arma::mat& gradient)
155 {
156 network.Backward(state, target, gradient);
157 }
158
159 private:
161 NetworkType network;
162
164 bool isNoisy;
165
167 std::vector<size_t> noisyLayerIndex;
168};
169
170} // namespace rl
171} // namespace mlpack
172
173#endif
Implementation of the base layer.
Definition: base_layer.hpp:66
Implementation of a standard feed forward network.
Definition: ffn.hpp:53
This class is used to initialize the weight matrix with a Gaussian distribution.
Implementation of the Linear layer class.
Definition: linear.hpp:39
The mean squared error performance function measures the network's performance according to the mean ...
Implementation of the NoisyLinear layer class.
Definition: noisylinear.hpp:34
void ResetNoise()
Resets noise of the network, if the network is of type noisy.
Definition: simple_dqn.hpp:133
arma::mat & Parameters()
Modify the Parameters.
Definition: simple_dqn.hpp:145
SimpleDQN(NetworkType &network, const bool isNoisy=false)
Construct an instance of SimpleDQN class from a pre-constructed network.
Definition: simple_dqn.hpp:90
SimpleDQN()
Default constructor.
Definition: simple_dqn.hpp:42
const arma::mat & Parameters() const
Return the Parameters.
Definition: simple_dqn.hpp:143
void ResetParameters()
Resets the parameters of the network.
Definition: simple_dqn.hpp:125
void Forward(const arma::mat state, arma::mat &target)
Perform the forward pass of the states in real batch mode.
Definition: simple_dqn.hpp:117
SimpleDQN(const int inputDim, const int h1, const int h2, const int outputDim, const bool isNoisy=false, InitType init=InitType(), OutputLayerType outputLayer=OutputLayerType())
Construct an instance of SimpleDQN class.
Definition: simple_dqn.hpp:56
void Backward(const arma::mat state, arma::mat &target, arma::mat &gradient)
Perform the backward pass of the state in real batch mode.
Definition: simple_dqn.hpp:154
void Predict(const arma::mat state, arma::mat &actionValue)
Predict the responses to a given set of predictors.
Definition: simple_dqn.hpp:106
Artificial Neural Network.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.