mlpack 3.4.2
Implementation of the Categorical Deep Q-Learning network.
#include <categorical_dqn.hpp>
Public Member Functions
CategoricalDQN ()
    Default constructor.
CategoricalDQN (const int inputDim, const int h1, const int h2, const int outputDim, TrainingConfig config, const bool isNoisy=false, InitType init=InitType(), OutputLayerType outputLayer=OutputLayerType())
    Construct an instance of the CategoricalDQN class.
CategoricalDQN (NetworkType &network, TrainingConfig config, const bool isNoisy=false)
    Construct an instance of the CategoricalDQN class from a pre-constructed network.
void Backward (const arma::mat state, arma::mat &lossGradients, arma::mat &gradient)
    Perform the backward pass of the states in real batch mode.
void Forward (const arma::mat state, arma::mat &dist)
    Perform the forward pass of the states in real batch mode.
arma::mat & Parameters ()
    Modify the parameters.
const arma::mat & Parameters () const
    Return the parameters.
void Predict (const arma::mat state, arma::mat &actionValue)
    Predict the action values for a given set of states.
void ResetNoise ()
    Reset the noise of the network if the network is noisy.
void ResetParameters ()
    Reset the parameters of the network.
Implementation of the Categorical Deep Q-Learning network.
For more information, see the following.
OutputLayerType | The output layer type of the network. |
InitType | The initialization type used for the network. |
NetworkType | The type of network used by CategoricalDQN. |
Definition at line 50 of file categorical_dqn.hpp.
CategoricalDQN () [inline]
Default constructor.
Definition at line 56 of file categorical_dqn.hpp.
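CategoricalDQN (const int inputDim, const int h1, const int h2, const int outputDim, TrainingConfig config, const bool isNoisy=false, InitType init=InitType(), OutputLayerType outputLayer=OutputLayerType()) [inline]
Construct an instance of the CategoricalDQN class.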
inputDim | Number of inputs. |
h1 | Number of neurons in hidden layer 1. |
h2 | Number of neurons in hidden layer 2. |
outputDim | Number of neurons in the output layer. |
config | Hyperparameters for categorical DQN. |
isNoisy | Specifies whether the network should be a noisy network. |
init | Specifies the initialization rule for the network. |
outputLayer | Specifies the output layer type for the network. |
Definition at line 71 of file categorical_dqn.hpp.
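The constructor arguments describe a feed-forward network with two hidden layers. The sketch below is only a rough size estimate. It ASSUMES (this page does not say so) that the layers are fully connected in the order inputDim -> h1 -> h2 -> outputDim * atomSize, where atomSize is the number of atoms in each return distribution, and that every layer stores a weight matrix plus a bias vector. The helper names and the atomSize argument are hypothetical, and noisy layers (isNoisy = true) would carry extra noise parameters that are not counted here.

```cpp
#include <cassert>
#include <cstddef>

// Parameters of one fully connected layer: an outSize x inSize weight
// matrix plus an outSize bias vector.
std::size_t DenseLayerParams(std::size_t inSize, std::size_t outSize)
{
  return inSize * outSize + outSize;
}

// Hypothetical estimate of the flat parameter count, ASSUMING the layout
// inputDim -> h1 -> h2 -> (outputDim * atomSize) built from plain
// (non-noisy) linear layers; activations such as ReLU add no parameters.
std::size_t DenseParamCount(std::size_t inputDim, std::size_t h1,
                            std::size_t h2, std::size_t outputDim,
                            std::size_t atomSize)
{
  assert(atomSize > 0 && "atomSize must be positive");
  return DenseLayerParams(inputDim, h1) +
         DenseLayerParams(h1, h2) +
         DenseLayerParams(h2, outputDim * atomSize);
}
```

Under these assumptions, 4 inputs, hidden sizes 64 and 32, 2 actions and 51 atoms give 320 + 2080 + 3366 = 5766 parameters.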
CategoricalDQN (NetworkType &network, TrainingConfig config, const bool isNoisy=false) [inline]
Construct an instance of the CategoricalDQN class from a pre-constructed network.
network | The network to be used by the CategoricalDQN class. |
config | Hyperparameters for categorical DQN. |
isNoisy | Specifies whether the network should be a noisy network. |
Definition at line 110 of file categorical_dqn.hpp.
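void Backward (const arma::mat state, arma::mat &lossGradients, arma::mat &gradient) [inline]
Perform the backward pass of the states in real batch mode.
state | The input states. |
lossGradients | The gradients of the loss with respect to the predicted distributions. |
gradient | Output: the gradient of the loss with respect to the network parameters. |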
Definition at line 201 of file categorical_dqn.hpp.
References Softmax< InputDataType, OutputDataType >::Backward().
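Because the backward pass goes through Softmax::Backward(), the loss gradients with respect to each predicted distribution have to be mapped back to gradients with respect to the raw network outputs. The standalone sketch below shows that softmax Jacobian-vector product for one sample, ASSUMING the outputs are laid out as consecutive blocks of atomSize values, one block per action. SoftmaxBlocksBackward is a hypothetical name; the real method would then continue backpropagating through the network, which is not shown.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Sketch: backpropagate through a per-action softmax, ASSUMING `dist`
// holds one sample's distributions as consecutive blocks of `atomSize`
// probabilities (one block per action) and `lossGrad` holds dLoss/dDist
// in the same layout. Returns dLoss/dLogits.
std::vector<double> SoftmaxBlocksBackward(const std::vector<double>& dist,
                                          const std::vector<double>& lossGrad,
                                          std::size_t atomSize)
{
  assert(atomSize > 0 && dist.size() % atomSize == 0);
  assert(lossGrad.size() == dist.size());

  std::vector<double> logitGrad(dist.size());
  for (std::size_t start = 0; start < dist.size(); start += atomSize)
  {
    // Inner product <g, p> restricted to this action's block.
    double dot = 0.0;
    for (std::size_t j = start; j < start + atomSize; ++j)
      dot += lossGrad[j] * dist[j];

    // Softmax Jacobian-vector product: dx_j = p_j * (g_j - <g, p>).
    for (std::size_t j = start; j < start + atomSize; ++j)
      logitGrad[j] = dist[j] * (lossGrad[j] - dot);
  }
  return logitGrad;
}
```

Each block of the result sums to zero, since shifting all logits of one action by the same constant leaves that action's softmax unchanged.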
void Forward (const arma::mat state, arma::mat &dist) [inline]
Perform the forward pass of the states in real batch mode.
state | The input state. |
dist | The predicted distributions. |
Definition at line 154 of file categorical_dqn.hpp.
References Softmax< InputDataType, OutputDataType >::Forward().
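Since the forward pass references Softmax::Forward(), each action's outputs are normalized into a probability distribution over the atoms. A minimal standalone sketch for one sample, ASSUMING the raw output stacks atomSize logits per action in consecutive blocks (SoftmaxBlocks is a hypothetical name, not part of mlpack):

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Sketch: convert one sample's raw network output into per-action
// probability distributions, ASSUMING the output stacks `atomSize`
// logits per action in consecutive blocks.
std::vector<double> SoftmaxBlocks(const std::vector<double>& logits,
                                  std::size_t atomSize)
{
  assert(atomSize > 0 && logits.size() % atomSize == 0);

  std::vector<double> dist(logits.size());
  for (std::size_t start = 0; start < logits.size(); start += atomSize)
  {
    const std::size_t end = start + atomSize;

    // Subtract the block maximum before exponentiating, for stability.
    double maxLogit = logits[start];
    for (std::size_t j = start + 1; j < end; ++j)
      maxLogit = std::max(maxLogit, logits[j]);

    double sum = 0.0;
    for (std::size_t j = start; j < end; ++j)
    {
      dist[j] = std::exp(logits[j] - maxLogit);
      sum += dist[j];
    }
    // Normalize so each action's block sums to one.
    for (std::size_t j = start; j < end; ++j)
      dist[j] /= sum;
  }
  return dist;
}
```

For example, the logits {0, 0, 0, ln 3} with atomSize = 2 give the distributions {0.5, 0.5} and {0.25, 0.75}.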
arma::mat & Parameters () [inline]
Modify the parameters.
Definition at line 192 of file categorical_dqn.hpp.
const arma::mat & Parameters () const [inline]
Return the parameters.
Definition at line 190 of file categorical_dqn.hpp.
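The two Parameters() overloads follow the usual C++ const/non-const accessor pattern. The stand-in class below is hypothetical and uses std::vector instead of arma::mat; it only illustrates that the non-const overload permits in-place modification while the const overload permits reading. It does not show how CategoricalDQN itself stores its parameters.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical stand-in (not the mlpack class) illustrating the accessor
// pair: the non-const overload hands out a mutable reference, so writes
// go straight to the stored parameters; the const overload is read-only.
class ParameterHolder
{
 public:
  explicit ParameterHolder(std::vector<double> init)
      : parameters(std::move(init)) { }

  //! Modify the parameters.
  std::vector<double>& Parameters() { return parameters; }

  //! Return the parameters.
  const std::vector<double>& Parameters() const { return parameters; }

 private:
  std::vector<double> parameters;
};
```

Calling Parameters() on a non-const object selects the modifying overload; calling it through a const reference selects the read-only one.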
void Predict (const arma::mat state, arma::mat &actionValue) [inline]
Predict the action values for a given set of states.
The responses reflect the output of the given output layer, as returned by the output layer function.
If you want to pass in a parameter and discard the original parameter object, be sure to use std::move to avoid an unnecessary copy.
state | Input state. |
actionValue | Matrix to store the output action values for the input states. |
Definition at line 131 of file categorical_dqn.hpp.
References Softmax< InputDataType, OutputDataType >::Forward().
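In categorical DQN an action value is the expectation of the predicted return distribution over a fixed support of atoms. A minimal sketch for one sample, ASSUMING an evenly spaced support of atomSize atoms on [vMin, vMax] and one block of atomSize probabilities per action; the function and parameter names are hypothetical:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Sketch: expected action values Q(a) = sum_i z_i * p(a, i), ASSUMING the
// support z has `atomSize` evenly spaced atoms on [vMin, vMax] and `dist`
// stores one block of `atomSize` probabilities per action.
std::vector<double> ExpectedActionValues(const std::vector<double>& dist,
                                         std::size_t atomSize,
                                         double vMin, double vMax)
{
  assert(atomSize > 1 && dist.size() % atomSize == 0);

  const double deltaZ = (vMax - vMin) / static_cast<double>(atomSize - 1);
  std::vector<double> actionValue(dist.size() / atomSize, 0.0);
  for (std::size_t a = 0; a < actionValue.size(); ++a)
  {
    for (std::size_t i = 0; i < atomSize; ++i)
    {
      const double z = vMin + static_cast<double>(i) * deltaZ;  // Atom i.
      actionValue[a] += z * dist[a * atomSize + i];
    }
  }
  return actionValue;
}
```

With atomSize = 3 on [-1, 1] (atoms -1, 0, 1), the distributions {0, 0, 1} and {0.5, 0.5, 0} yield action values 1 and -0.5; a greedy policy would then pick the action with the largest value, for example via std::max_element.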
void ResetNoise () [inline]
Reset the noise of the network if the network is noisy.
Definition at line 180 of file categorical_dqn.hpp.
References CategoricalDQN< OutputLayerType, InitType, NetworkType >::ResetNoise().
Referenced by CategoricalDQN< OutputLayerType, InitType, NetworkType >::ResetNoise().
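This page does not describe how the noise is generated. Noisy networks commonly use factorized Gaussian noise, as in the NoisyNet paper (Fortunato et al.). The sketch below shows that scheme for a single inSize x outSize layer, ASSUMING it is representative of what resetting the noise involves; FactorizedNoise, ResampleNoise and NoiseScale are hypothetical names.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <random>
#include <vector>

// Scaling function f(x) = sign(x) * sqrt(|x|) used by factorized
// Gaussian noise in NoisyNet-style layers.
double NoiseScale(double x)
{
  return std::copysign(std::sqrt(std::abs(x)), x);
}

// Hypothetical noise buffers for an inSize -> outSize noisy layer.
struct FactorizedNoise
{
  std::vector<double> weightEpsilon;  // outSize x inSize, row-major.
  std::vector<double> biasEpsilon;    // outSize entries.
};

// Sketch: resample eps_w(i, j) = f(eps_out(i)) * f(eps_in(j)) and
// eps_b(i) = f(eps_out(i)), with eps_in and eps_out drawn from N(0, 1).
FactorizedNoise ResampleNoise(std::size_t inSize, std::size_t outSize,
                              std::mt19937& rng)
{
  assert(inSize > 0 && outSize > 0);
  std::normal_distribution<double> gauss(0.0, 1.0);

  std::vector<double> epsIn(inSize), epsOut(outSize);
  for (double& e : epsIn) e = NoiseScale(gauss(rng));
  for (double& e : epsOut) e = NoiseScale(gauss(rng));

  FactorizedNoise noise;
  noise.weightEpsilon.resize(outSize * inSize);
  for (std::size_t i = 0; i < outSize; ++i)
    for (std::size_t j = 0; j < inSize; ++j)
      noise.weightEpsilon[i * inSize + j] = epsOut[i] * epsIn[j];
  noise.biasEpsilon = epsOut;
  return noise;
}
```

Factorizing the noise means only inSize + outSize Gaussian samples are drawn per reset instead of one per weight.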
void ResetParameters () [inline]
Reset the parameters of the network.
Definition at line 172 of file categorical_dqn.hpp.
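The default InitType is not shown on this page. As a hedged illustration only, the sketch below re-initializes a flat parameter vector with a Gaussian rule; the rule, its mean and standard deviation, and the function name are all hypothetical stand-ins for whatever initialization the network was built with.

```cpp
#include <cassert>
#include <random>
#include <vector>

// Sketch: reset a flat parameter vector in place with a Gaussian rule,
// drawing every entry from N(mean, stddev^2). This is a hypothetical
// stand-in for whatever InitType the network was built with.
void ResetParametersGaussian(std::vector<double>& parameters,
                             std::mt19937& rng,
                             double mean = 0.0, double stddev = 1.0)
{
  assert(stddev > 0.0);
  std::normal_distribution<double> gauss(mean, stddev);
  for (double& p : parameters)
    p = gauss(rng);
}
```

Starting from the same seeded generator state reproduces the same initial parameters, which helps keep experiments repeatable.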