mlpack 3.4.2
dueling_dqn.hpp
Go to the documentation of this file.
1
12#ifndef MLPACK_METHODS_RL_DUELING_DQN_HPP
13#define MLPACK_METHODS_RL_DUELING_DQN_HPP
14
15#include <mlpack/prereqs.hpp>
21
22namespace mlpack {
23namespace rl {
24
25using namespace mlpack::ann;
26
48template <
49 typename OutputLayerType = EmptyLoss<>,
50 typename InitType = GaussianInitialization,
51 typename CompleteNetworkType = FFN<OutputLayerType, InitType>,
52 typename FeatureNetworkType = Sequential<>,
53 typename AdvantageNetworkType = Sequential<>,
54 typename ValueNetworkType = Sequential<>
55>
57{
58 public:
60 DuelingDQN() : isNoisy(false)
61 {
62 featureNetwork = new Sequential<>();
63 valueNetwork = new Sequential<>();
64 advantageNetwork = new Sequential<>();
65 concat = new Concat<>(true);
66
67 concat->Add(valueNetwork);
68 concat->Add(advantageNetwork);
69 completeNetwork.Add(new IdentityLayer<>());
70 completeNetwork.Add(featureNetwork);
71 completeNetwork.Add(concat);
72 }
73
85 DuelingDQN(const int inputDim,
86 const int h1,
87 const int h2,
88 const int outputDim,
89 const bool isNoisy = false,
90 InitType init = InitType(),
91 OutputLayerType outputLayer = OutputLayerType()):
92 completeNetwork(outputLayer, init),
93 isNoisy(isNoisy)
94 {
95 featureNetwork = new Sequential<>();
96 featureNetwork->Add(new Linear<>(inputDim, h1));
97 featureNetwork->Add(new ReLULayer<>());
98
99 valueNetwork = new Sequential<>();
100 advantageNetwork = new Sequential<>();
101
102 if (isNoisy)
103 {
104 noisyLayerIndex.push_back(valueNetwork->Model().size());
105 valueNetwork->Add(new NoisyLinear<>(h1, h2));
106 advantageNetwork->Add(new NoisyLinear<>(h1, h2));
107
108 valueNetwork->Add(new ReLULayer<>());
109 advantageNetwork->Add(new ReLULayer<>());
110
111 noisyLayerIndex.push_back(valueNetwork->Model().size());
112 valueNetwork->Add(new NoisyLinear<>(h2, 1));
113 advantageNetwork->Add(new NoisyLinear<>(h2, outputDim));
114 }
115 else
116 {
117 valueNetwork->Add(new Linear<>(h1, h2));
118 valueNetwork->Add(new ReLULayer<>());
119 valueNetwork->Add(new Linear<>(h2, 1));
120
121 advantageNetwork->Add(new Linear<>(h1, h2));
122 advantageNetwork->Add(new ReLULayer<>());
123 advantageNetwork->Add(new Linear<>(h2, outputDim));
124 }
125
126 concat = new Concat<>(true);
127 concat->Add(valueNetwork);
128 concat->Add(advantageNetwork);
129
130 completeNetwork.Add(new IdentityLayer<>());
131 completeNetwork.Add(featureNetwork);
132 completeNetwork.Add(concat);
133 this->ResetParameters();
134 }
135
144 DuelingDQN(FeatureNetworkType& featureNetwork,
145 AdvantageNetworkType& advantageNetwork,
146 ValueNetworkType& valueNetwork,
147 const bool isNoisy = false):
148 featureNetwork(featureNetwork),
149 advantageNetwork(advantageNetwork),
150 valueNetwork(valueNetwork),
151 isNoisy(isNoisy)
152 {
153 concat = new Concat<>(true);
154 concat->Add(valueNetwork);
155 concat->Add(advantageNetwork);
156 completeNetwork.Add(new IdentityLayer<>());
157 completeNetwork.Add(featureNetwork);
158 completeNetwork.Add(concat);
159 this->ResetParameters();
160 }
161
163 DuelingDQN(const DuelingDQN& /* model */) : isNoisy(false)
164 { /* Nothing to do here. */ }
165
167 void operator = (const DuelingDQN& model)
168 {
169 *valueNetwork = *model.valueNetwork;
170 *advantageNetwork = *model.advantageNetwork;
171 *featureNetwork = *model.featureNetwork;
172 isNoisy = model.isNoisy;
173 noisyLayerIndex = model.noisyLayerIndex;
174 }
175
187 void Predict(const arma::mat state, arma::mat& actionValue)
188 {
189 arma::mat advantage, value, networkOutput;
190 completeNetwork.Predict(state, networkOutput);
191 value = networkOutput.row(0);
192 advantage = networkOutput.rows(1, networkOutput.n_rows - 1);
193 actionValue = advantage.each_row() +
194 (value - arma::mean(advantage));
195 }
196
203 void Forward(const arma::mat state, arma::mat& actionValue)
204 {
205 arma::mat advantage, value, networkOutput;
206 completeNetwork.Forward(state, networkOutput);
207 value = networkOutput.row(0);
208 advantage = networkOutput.rows(1, networkOutput.n_rows - 1);
209 actionValue = advantage.each_row() +
210 (value - arma::mean(advantage));
211 this->actionValues = actionValue;
212 }
213
221 void Backward(const arma::mat state, arma::mat& target, arma::mat& gradient)
222 {
223 arma::mat gradLoss;
224 lossFunction.Backward(this->actionValues, target, gradLoss);
225
226 arma::mat gradValue = arma::sum(gradLoss);
227 arma::mat gradAdvantage = gradLoss.each_row() - arma::mean(gradLoss);
228
229 arma::mat grad = arma::join_cols(gradValue, gradAdvantage);
230 completeNetwork.Backward(state, grad, gradient);
231 }
232
237 {
238 completeNetwork.ResetParameters();
239 }
240
245 {
246 for (size_t i = 0; i < noisyLayerIndex.size(); i++)
247 {
248 boost::get<NoisyLinear<>*>
249 (valueNetwork->Model()[noisyLayerIndex[i]])->ResetNoise();
250 boost::get<NoisyLinear<>*>
251 (advantageNetwork->Model()[noisyLayerIndex[i]])->ResetNoise();
252 }
253 }
254
256 const arma::mat& Parameters() const { return completeNetwork.Parameters(); }
258 arma::mat& Parameters() { return completeNetwork.Parameters(); }
259
260 private:
262 CompleteNetworkType completeNetwork;
263
265 Concat<>* concat;
266
268 FeatureNetworkType* featureNetwork;
269
271 AdvantageNetworkType* advantageNetwork;
272
274 ValueNetworkType* valueNetwork;
275
277 bool isNoisy;
278
280 std::vector<size_t> noisyLayerIndex;
281
283 arma::mat actionValues;
284
286 MeanSquaredError<> lossFunction;
287};
288
289} // namespace rl
290} // namespace mlpack
291
292#endif
Implementation of the base layer.
Definition: base_layer.hpp:66
Implementation of the Concat class.
Definition: concat.hpp:46
void Add(Args... args)
Definition: concat.hpp:147
The empty loss does nothing, letting the user calculate the loss outside the model.
Definition: empty_loss.hpp:36
Implementation of a standard feed forward network.
Definition: ffn.hpp:53
This class is used to initialize the weight matrix with a Gaussian distribution.
Implementation of the Linear layer class.
Definition: linear.hpp:39
The mean squared error performance function measures the network's performance according to the mean ...
void Backward(const InputType &input, const TargetType &target, OutputType &output)
Ordinary feed backward pass of a neural network.
Implementation of the NoisyLinear layer class.
Definition: noisylinear.hpp:34
Implementation of the Dueling Deep Q-Learning network.
Definition: dueling_dqn.hpp:57
void ResetNoise()
Resets noise of the network, if the network is of type noisy.
arma::mat & Parameters()
Modify the Parameters.
DuelingDQN()
Default constructor.
Definition: dueling_dqn.hpp:60
const arma::mat & Parameters() const
Return the Parameters.
void Forward(const arma::mat state, arma::mat &actionValue)
Perform the forward pass of the states in real batch mode.
void ResetParameters()
Resets the parameters of the network.
DuelingDQN(const DuelingDQN &)
Copy constructor.
DuelingDQN(const int inputDim, const int h1, const int h2, const int outputDim, const bool isNoisy=false, InitType init=InitType(), OutputLayerType outputLayer=OutputLayerType())
Construct an instance of DuelingDQN class.
Definition: dueling_dqn.hpp:85
void Backward(const arma::mat state, arma::mat &target, arma::mat &gradient)
Perform the backward pass of the state in real batch mode.
DuelingDQN(FeatureNetworkType &featureNetwork, AdvantageNetworkType &advantageNetwork, ValueNetworkType &valueNetwork, const bool isNoisy=false)
Construct an instance of DuelingDQN class from a pre-constructed network.
void Predict(const arma::mat state, arma::mat &actionValue)
Predict the responses to a given set of predictors.
void operator=(const DuelingDQN &model)
Copy assignment operator.
Artificial Neural Network.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.