12#ifndef MLPACK_METHODS_RL_DUELING_DQN_HPP
13#define MLPACK_METHODS_RL_DUELING_DQN_HPP
67 concat->
Add(valueNetwork);
68 concat->
Add(advantageNetwork);
70 completeNetwork.Add(featureNetwork);
71 completeNetwork.Add(concat);
89 const bool isNoisy =
false,
90 InitType init = InitType(),
91 OutputLayerType outputLayer = OutputLayerType()):
92 completeNetwork(outputLayer, init),
96 featureNetwork->Add(
new Linear<>(inputDim, h1));
104 noisyLayerIndex.push_back(valueNetwork->Model().size());
111 noisyLayerIndex.push_back(valueNetwork->Model().size());
117 valueNetwork->Add(
new Linear<>(h1, h2));
119 valueNetwork->Add(
new Linear<>(h2, 1));
121 advantageNetwork->Add(
new Linear<>(h1, h2));
123 advantageNetwork->Add(
new Linear<>(h2, outputDim));
127 concat->
Add(valueNetwork);
128 concat->
Add(advantageNetwork);
131 completeNetwork.Add(featureNetwork);
132 completeNetwork.Add(concat);
145 AdvantageNetworkType& advantageNetwork,
146 ValueNetworkType& valueNetwork,
147 const bool isNoisy =
false):
148 featureNetwork(featureNetwork),
149 advantageNetwork(advantageNetwork),
150 valueNetwork(valueNetwork),
154 concat->
Add(valueNetwork);
155 concat->
Add(advantageNetwork);
157 completeNetwork.Add(featureNetwork);
158 completeNetwork.Add(concat);
169 *valueNetwork = *model.valueNetwork;
170 *advantageNetwork = *model.advantageNetwork;
171 *featureNetwork = *model.featureNetwork;
172 isNoisy = model.isNoisy;
173 noisyLayerIndex = model.noisyLayerIndex;
// Computes action values for `state` with completeNetwork.Predict() and
// writes them to `actionValue`.
//
// The raw network output is split by row: row 0 is treated as the state value
// and rows 1..n_rows-1 as the per-action advantages. The result is
// advantage + (value - mean(advantage)).
// NOTE(review): assumes the value network emits exactly one row; that holds
// for the built-in constructor (its last value layer is Linear<>(h2, 1)) but
// is not checked for user-supplied networks -- confirm.
// NOTE(review): `state` is received by value, so the matrix is copied on
// every call.
187 void Predict(
const arma::mat state, arma::mat& actionValue)
189 arma::mat advantage, value, networkOutput;
190 completeNetwork.Predict(state, networkOutput);
// Row 0 is read as the value; the remaining rows as the advantages.
191 value = networkOutput.row(0);
192 advantage = networkOutput.rows(1, networkOutput.n_rows - 1);
// arma::mean() without a dim argument averages down each column, giving one
// row; each_row() adds (value - mean) to every advantage row.
193 actionValue = advantage.each_row() +
194 (value - arma::mean(advantage));
// Forward pass: runs completeNetwork.Forward() on `state`, combines the value
// row and the advantage rows into `actionValue`, and stores a copy of the
// result in this->actionValues for later use by Backward().
//
// The combination is the same as in Predict():
// advantage + (value - mean(advantage)), with row 0 of the network output
// read as the value and rows 1..n_rows-1 as the advantages.
// NOTE(review): assumes the value network emits exactly one row -- confirm
// for user-supplied networks.
// NOTE(review): `state` is received by value, so the matrix is copied on
// every call.
203 void Forward(
const arma::mat state, arma::mat& actionValue)
205 arma::mat advantage, value, networkOutput;
206 completeNetwork.Forward(state, networkOutput);
// Row 0 is read as the value; the remaining rows as the advantages.
207 value = networkOutput.row(0);
208 advantage = networkOutput.rows(1, networkOutput.n_rows - 1);
// arma::mean() without a dim argument averages down each column, giving one
// row; each_row() adds (value - mean) to every advantage row.
209 actionValue = advantage.each_row() +
210 (value - arma::mean(advantage));
// Keep the combined output; Backward() passes it to the loss function.
211 this->actionValues = actionValue;
// Backward pass through the dueling aggregation and completeNetwork.
//
// lossFunction.Backward() is evaluated on the action values stored by
// Forward() (this->actionValues) against `target`, producing gradLoss. That
// gradient is converted into one value row and a block of advantage rows,
// stacked in that order, and passed to completeNetwork.Backward() together
// with `state`; the network gradient is written to `gradient`.
// NOTE(review): the declarations of gradLoss and lossFunction are outside the
// lines shown here -- presumably an arma::mat local and a loss-function
// member; confirm.
// NOTE(review): `state` is received by value, so the matrix is copied on
// every call.
221 void Backward(
const arma::mat state, arma::mat& target, arma::mat& gradient)
224 lossFunction.
Backward(this->actionValues, target, gradLoss);
// Value gradient: arma::sum() without a dim argument sums down each column,
// i.e. over all action rows.
226 arma::mat gradValue = arma::sum(gradLoss);
// Advantage gradient: the loss gradient with its per-column mean subtracted
// from every row.
227 arma::mat gradAdvantage = gradLoss.each_row() - arma::mean(gradLoss);
// Stack the value row on top of the advantage rows, mirroring the row layout
// that Forward() reads from the network output.
229 arma::mat grad = arma::join_cols(gradValue, gradAdvantage);
230 completeNetwork.Backward(state, grad, gradient);
238 completeNetwork.ResetParameters();
246 for (
size_t i = 0; i < noisyLayerIndex.size(); i++)
248 boost::get<NoisyLinear<>*>
249 (valueNetwork->Model()[noisyLayerIndex[i]])->
ResetNoise();
250 boost::get<NoisyLinear<>*>
251 (advantageNetwork->Model()[noisyLayerIndex[i]])->
ResetNoise();
// Read-only access to the parameter matrix; forwards to
// completeNetwork.Parameters().
256 const arma::mat&
Parameters()
const {
return completeNetwork.Parameters(); }
// Mutable access to the parameter matrix; forwards to
// completeNetwork.Parameters(), so writes go to completeNetwork's parameters.
258 arma::mat&
Parameters() {
return completeNetwork.Parameters(); }
262 CompleteNetworkType completeNetwork;
268 FeatureNetworkType* featureNetwork;
271 AdvantageNetworkType* advantageNetwork;
274 ValueNetworkType* valueNetwork;
280 std::vector<size_t> noisyLayerIndex;
283 arma::mat actionValues;
Implementation of the base layer.
Implementation of the Concat class.
The empty loss does nothing, letting the user calculate the loss outside the model.
Implementation of a standard feed forward network.
This class is used to initialize the weight matrix with a Gaussian.
Implementation of the Linear layer class.
The mean squared error performance function measures the network's performance according to the mean of squared errors.
void Backward(const InputType &input, const TargetType &target, OutputType &output)
Ordinary feed backward pass of a neural network.
Implementation of the NoisyLinear layer class.
Implementation of the Dueling Deep Q-Learning network.
void ResetNoise()
Resets noise of the network, if the network is of type noisy.
arma::mat & Parameters()
Modify the Parameters.
DuelingDQN()
Default constructor.
const arma::mat & Parameters() const
Return the Parameters.
void Forward(const arma::mat state, arma::mat &actionValue)
Perform the forward pass of the states in real batch mode.
void ResetParameters()
Resets the parameters of the network.
DuelingDQN(const DuelingDQN &)
Copy constructor.
DuelingDQN(const int inputDim, const int h1, const int h2, const int outputDim, const bool isNoisy=false, InitType init=InitType(), OutputLayerType outputLayer=OutputLayerType())
Construct an instance of DuelingDQN class.
void Backward(const arma::mat state, arma::mat &target, arma::mat &gradient)
Perform the backward pass of the state in real batch mode.
DuelingDQN(FeatureNetworkType &featureNetwork, AdvantageNetworkType &advantageNetwork, ValueNetworkType &valueNetwork, const bool isNoisy=false)
Construct an instance of DuelingDQN class from a pre-constructed network.
void Predict(const arma::mat state, arma::mat &actionValue)
Predict the responses to a given set of predictors.
void operator=(const DuelingDQN &model)
Copy assignment operator.
Artificial Neural Network.
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.