mlpack 3.4.2
reinforce_normal.hpp
Go to the documentation of this file.
1
13#ifndef MLPACK_METHODS_ANN_LAYER_REINFORCE_NORMAL_HPP
14#define MLPACK_METHODS_ANN_LAYER_REINFORCE_NORMAL_HPP
15
16#include <mlpack/prereqs.hpp>
17
18namespace mlpack {
19namespace ann {
20
30template <
31 typename InputDataType = arma::mat,
32 typename OutputDataType = arma::mat
33>
35{
36 public:
42 ReinforceNormal(const double stdev = 1.0);
43
51 template<typename eT>
52 void Forward(const arma::Mat<eT>& input, arma::Mat<eT>& output);
53
63 template<typename DataType>
64 void Backward(const DataType& input, const DataType& /* gy */, DataType& g);
65
67 OutputDataType& OutputParameter() const { return outputParameter; }
69 OutputDataType& OutputParameter() { return outputParameter; }
70
72 OutputDataType& Delta() const { return delta; }
74 OutputDataType& Delta() { return delta; }
75
77 bool Deterministic() const { return deterministic; }
79 bool& Deterministic() { return deterministic; }
80
82 double Reward() const { return reward; }
84 double& Reward() { return reward; }
85
87 double StandardDeviation() const { return stdev; }
88
92 template<typename Archive>
93 void serialize(Archive& ar, const unsigned int /* version */);
94
95 private:
97 double stdev;
98
100 double reward;
101
103 OutputDataType delta;
104
106 OutputDataType outputParameter;
107
109 std::vector<arma::mat> moduleInputParameter;
110
112 bool deterministic;
113}; // class ReinforceNormal
114
115} // namespace ann
116} // namespace mlpack
117
118// Include implementation.
119#include "reinforce_normal_impl.hpp"
120
121#endif
Implementation of the reinforce normal layer.
ReinforceNormal(const double stdev=1.0)
Create the ReinforceNormal object.
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
double StandardDeviation() const
Get the standard deviation used during forward and backward pass.
OutputDataType & OutputParameter() const
Get the output parameter.
OutputDataType & Delta() const
Get the delta.
bool & Deterministic()
Modify the value of the deterministic parameter.
bool Deterministic() const
Get the value of the deterministic parameter.
double & Reward()
Modify the value of the reward parameter.
OutputDataType & OutputParameter()
Modify the output parameter.
void Backward(const DataType &input, const DataType &, DataType &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f, using the results from the feed forward pass.
void serialize(Archive &ar, const unsigned int)
Serialize the layer.
double Reward() const
Get the value of the reward parameter.
OutputDataType & Delta()
Modify the delta.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.