mlpack 3.4.2
celu.hpp
Go to the documentation of this file.
1
23#ifndef MLPACK_METHODS_ANN_LAYER_CELU_HPP
24#define MLPACK_METHODS_ANN_LAYER_CELU_HPP
25
26#include <mlpack/prereqs.hpp>
27
28namespace mlpack {
29namespace ann {
30
56template <
57 typename InputDataType = arma::mat,
58 typename OutputDataType = arma::mat
59>
60class CELU
61{
62 public:
70 CELU(const double alpha = 1.0);
71
79 template<typename InputType, typename OutputType>
80 void Forward(const InputType& input, OutputType& output);
81
91 template<typename DataType>
92 void Backward(const DataType& input, const DataType& gy, DataType& g);
93
95 OutputDataType const& OutputParameter() const { return outputParameter; }
97 OutputDataType& OutputParameter() { return outputParameter; }
98
100 OutputDataType const& Delta() const { return delta; }
102 OutputDataType& Delta() { return delta; }
103
105 double const& Alpha() const { return alpha; }
107 double& Alpha() { return alpha; }
108
110 bool Deterministic() const { return deterministic; }
112 bool& Deterministic() { return deterministic; }
113
117 template<typename Archive>
118 void serialize(Archive& ar, const unsigned int /* version */);
119
120 private:
122 OutputDataType delta;
123
125 OutputDataType outputParameter;
126
128 arma::mat derivative;
129
131 double alpha;
132
134 bool deterministic;
135}; // class CELU
136
137} // namespace ann
138} // namespace mlpack
139
140// Include implementation.
141#include "celu_impl.hpp"
142
143#endif
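The CELU activation function, defined by $f(x) = \max(0, x) + \min\left(0, \alpha\left(e^{x/\alpha} - 1\right)\right)$, i.e. $f(x) = x$ for $x \ge 0$ and $f(x) = \alpha\left(e^{x/\alpha} - 1\right)$ for $x < 0$.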
Definition: celu.hpp:61
void Forward(const InputType &input, OutputType &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
OutputDataType const & Delta() const
Get the delta.
Definition: celu.hpp:100
double & Alpha()
Modify the alpha parameter of the CELU function.
Definition: celu.hpp:107
OutputDataType const & OutputParameter() const
Get the output parameter.
Definition: celu.hpp:95
bool & Deterministic()
Modify the value of deterministic parameter.
Definition: celu.hpp:112
bool Deterministic() const
Get the value of deterministic parameter.
Definition: celu.hpp:110
CELU(const double alpha=1.0)
Create the CELU object using the specified parameter.
OutputDataType & OutputParameter()
Modify the output parameter.
Definition: celu.hpp:97
double const & Alpha() const
Get the alpha parameter of the CELU function.
Definition: celu.hpp:105
void Backward(const DataType &input, const DataType &gy, DataType &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f, using the results from the feed forward pass.
void serialize(Archive &ar, const unsigned int)
Serialize the layer.
OutputDataType & Delta()
Modify the delta.
Definition: celu.hpp:102
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.