mlpack 3.4.2
hardshrink.hpp
Go to the documentation of this file.
1
15#ifndef MLPACK_METHODS_ANN_LAYER_HARDSHRINK_HPP
16#define MLPACK_METHODS_ANN_LAYER_HARDSHRINK_HPP
17
18#include <mlpack/prereqs.hpp>
19
20namespace mlpack {
21namespace ann {
22
41template <
42 typename InputDataType = arma::mat,
43 typename OutputDataType = arma::mat
44>
46{
47 public:
56 HardShrink(const double lambda = 0.5);
57
65 template<typename InputType, typename OutputType>
66 void Forward(const InputType& input, OutputType& output);
67
77 template<typename DataType>
78 void Backward(const DataType& input,
79 DataType& gy,
80 DataType& g);
81
83 OutputDataType const& OutputParameter() const { return outputParameter; }
85 OutputDataType& OutputParameter() { return outputParameter; }
86
88 OutputDataType const& Delta() const { return delta; }
90 OutputDataType& Delta() { return delta; }
91
93 double const& Lambda() const { return lambda; }
95 double& Lambda() { return lambda; }
96
100 template<typename Archive>
101 void serialize(Archive& ar, const unsigned int /* version */);
102
103 private:
105 OutputDataType delta;
106
108 OutputDataType outputParameter;
109
111 double lambda;
112}; // class HardShrink
113
114} // namespace ann
115} // namespace mlpack
116
117// Include implementation.
118#include "hardshrink_impl.hpp"
119
120#endif
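Hard Shrink operator is defined as $f(x) = x$ if $x > \lambda$ or $x < -\lambda$, and $f(x) = 0$ otherwise, with $\lambda = 0.5$ by default.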
Definition: hardshrink.hpp:46
void Forward(const InputType &input, OutputType &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
OutputDataType const & Delta() const
Get the delta.
Definition: hardshrink.hpp:88
void Backward(const DataType &input, DataType &gy, DataType &g)
Ordinary feed backward pass of a neural network, calculating the gradient of f(x) by propagating x backwards through f, using the results from the feed forward pass.
HardShrink(const double lambda=0.5)
Create HardShrink object using specified hyperparameter lambda.
double & Lambda()
Modify the hyperparameter lambda.
Definition: hardshrink.hpp:95
OutputDataType const & OutputParameter() const
Get the output parameter.
Definition: hardshrink.hpp:83
double const & Lambda() const
Get the hyperparameter lambda.
Definition: hardshrink.hpp:93
OutputDataType & OutputParameter()
Modify the output parameter.
Definition: hardshrink.hpp:85
void serialize(Archive &ar, const unsigned int)
Serialize the layer.
OutputDataType & Delta()
Modify the delta.
Definition: hardshrink.hpp:90
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.