mlpack 3.4.2
base_layer.hpp
Go to the documentation of this file.
1
13#ifndef MLPACK_METHODS_ANN_LAYER_BASE_LAYER_HPP
14#define MLPACK_METHODS_ANN_LAYER_BASE_LAYER_HPP
15
#include <mlpack/prereqs.hpp>
#include <mlpack/methods/ann/activation_functions/elish_function.hpp>
#include <mlpack/methods/ann/activation_functions/elliot_function.hpp>
#include <mlpack/methods/ann/activation_functions/gaussian_function.hpp>
#include <mlpack/methods/ann/activation_functions/gelu_function.hpp>
#include <mlpack/methods/ann/activation_functions/hard_sigmoid_function.hpp>
#include <mlpack/methods/ann/activation_functions/identity_function.hpp>
#include <mlpack/methods/ann/activation_functions/lisht_function.hpp>
#include <mlpack/methods/ann/activation_functions/logistic_function.hpp>
#include <mlpack/methods/ann/activation_functions/mish_function.hpp>
#include <mlpack/methods/ann/activation_functions/rectifier_function.hpp>
#include <mlpack/methods/ann/activation_functions/softplus_function.hpp>
#include <mlpack/methods/ann/activation_functions/swish_function.hpp>
#include <mlpack/methods/ann/activation_functions/tanh_function.hpp>

31namespace mlpack {
32namespace ann {
33
60template <
61 class ActivationFunction = LogisticFunction,
62 typename InputDataType = arma::mat,
63 typename OutputDataType = arma::mat
64>
66{
67 public:
72 {
73 // Nothing to do here.
74 }
75
83 template<typename InputType, typename OutputType>
84 void Forward(const InputType& input, OutputType& output)
85 {
86 ActivationFunction::Fn(input, output);
87 }
88
98 template<typename eT>
99 void Backward(const arma::Mat<eT>& input,
100 const arma::Mat<eT>& gy,
101 arma::Mat<eT>& g)
102 {
103 arma::Mat<eT> derivative;
104 ActivationFunction::Deriv(input, derivative);
105 g = gy % derivative;
106 }
107
109 OutputDataType const& OutputParameter() const { return outputParameter; }
111 OutputDataType& OutputParameter() { return outputParameter; }
112
114 OutputDataType const& Delta() const { return delta; }
116 OutputDataType& Delta() { return delta; }
117
121 template<typename Archive>
122 void serialize(Archive& /* ar */, const unsigned int /* version */)
123 {
124 /* Nothing to do here */
125 }
126
127 private:
129 OutputDataType delta;
130
132 OutputDataType outputParameter;
133}; // class BaseLayer
134
135// Convenience typedefs.
136
140template <
141 class ActivationFunction = LogisticFunction,
142 typename InputDataType = arma::mat,
143 typename OutputDataType = arma::mat
144>
146 ActivationFunction, InputDataType, OutputDataType>;
147
151template <
152 class ActivationFunction = IdentityFunction,
153 typename InputDataType = arma::mat,
154 typename OutputDataType = arma::mat
155>
157 ActivationFunction, InputDataType, OutputDataType>;
158
162template <
163 class ActivationFunction = RectifierFunction,
164 typename InputDataType = arma::mat,
165 typename OutputDataType = arma::mat
166>
168 ActivationFunction, InputDataType, OutputDataType>;
169
173template <
174 class ActivationFunction = TanhFunction,
175 typename InputDataType = arma::mat,
176 typename OutputDataType = arma::mat
177>
179 ActivationFunction, InputDataType, OutputDataType>;
180
184template <
185 class ActivationFunction = SoftplusFunction,
186 typename InputDataType = arma::mat,
187 typename OutputDataType = arma::mat
188>
190 ActivationFunction, InputDataType, OutputDataType>;
191
195template <
196 class ActivationFunction = HardSigmoidFunction,
197 typename InputDataType = arma::mat,
198 typename OutputDataType = arma::mat
199>
201 ActivationFunction, InputDataType, OutputDataType>;
202
206template <
207 class ActivationFunction = SwishFunction,
208 typename InputDataType = arma::mat,
209 typename OutputDataType = arma::mat
210>
212 ActivationFunction, InputDataType, OutputDataType>;
213
217template <
218 class ActivationFunction = MishFunction,
219 typename InputDataType = arma::mat,
220 typename OutputDataType = arma::mat
221>
223 ActivationFunction, InputDataType, OutputDataType>;
224
228template <
229 class ActivationFunction = LiSHTFunction,
230 typename InputDataType = arma::mat,
231 typename OutputDataType = arma::mat
232>
234 ActivationFunction, InputDataType, OutputDataType>;
235
239template <
240 class ActivationFunction = GELUFunction,
241 typename InputDataType = arma::mat,
242 typename OutputDataType = arma::mat
243>
245 ActivationFunction, InputDataType, OutputDataType>;
246
250template <
251 class ActivationFunction = ElliotFunction,
252 typename InputDataType = arma::mat,
253 typename OutputDataType = arma::mat
254>
256 ActivationFunction, InputDataType, OutputDataType>;
257
261template <
262 class ActivationFunction = ElishFunction,
263 typename InputDataType = arma::mat,
264 typename OutputDataType = arma::mat
265>
267 ActivationFunction, InputDataType, OutputDataType>;
268
272template <
273 class ActivationFunction = GaussianFunction,
274 typename InputDataType = arma::mat,
275 typename OutputDataType = arma::mat
276>
278 ActivationFunction, InputDataType, OutputDataType>;
279
280} // namespace ann
281} // namespace mlpack
282
283#endif
Implementation of the base layer.
Definition: base_layer.hpp:66
void Forward(const InputType &input, OutputType &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
Definition: base_layer.hpp:84
OutputDataType const & Delta() const
Get the delta.
Definition: base_layer.hpp:114
OutputDataType const & OutputParameter() const
Get the output parameter.
Definition: base_layer.hpp:109
BaseLayer()
Create the BaseLayer object.
Definition: base_layer.hpp:71
void Backward(const arma::Mat< eT > &input, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f, using the results from the feed forward pass.
Definition: base_layer.hpp:99
void serialize(Archive &, const unsigned int)
Serialize the layer.
Definition: base_layer.hpp:122
OutputDataType & OutputParameter()
Modify the output parameter.
Definition: base_layer.hpp:111
OutputDataType & Delta()
Modify the delta.
Definition: base_layer.hpp:116
The ELiSH function, defined by $f(x) = \frac{x}{1 + e^{-x}}$ for $x \ge 0$ and $f(x) = \frac{e^{x} - 1}{1 + e^{-x}}$ for $x < 0$.
The Elliot function, defined by $f(x) = \frac{x}{1 + |x|}$.
The GELU function, defined by $f(x) = 0.5\,x\left(1 + \tanh\left(\sqrt{2/\pi}\,\left(x + 0.044715\,x^{3}\right)\right)\right)$.
The Gaussian function, defined by $f(x) = e^{-x^{2}}$.
The hard sigmoid function, defined by $f(x) = \max\left(0, \min\left(1, 0.2x + 0.5\right)\right)$.
The identity function, defined by $f(x) = x$.
The LiSHT function, defined by $f(x) = x \tanh(x)$.
The Mish function, defined by $f(x) = x \tanh\left(\ln\left(1 + e^{x}\right)\right)$.
The rectifier function, defined by $f(x) = \max(0, x)$.
The softplus function, defined by $f(x) = \ln\left(1 + e^{x}\right)$.
The swish function, defined by $f(x) = \frac{x}{1 + e^{-x}}$.
The tanh function, defined by $f(x) = \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.