mlpack 3.4.2
swish_function.hpp
Go to the documentation of this file.
1
13#ifndef MLPACK_METHODS_ANN_ACTIVATION_FUNCTIONS_SWISH_FUNCTION_HPP
14#define MLPACK_METHODS_ANN_ACTIVATION_FUNCTIONS_SWISH_FUNCTION_HPP
15
16#include <mlpack/prereqs.hpp>
17
18namespace mlpack {
19namespace ann {
20
31{
32 public:
/**
 * Computes the swish function for a single value, x / (1 + exp(-x)).
 *
 * @param x Input value.
 * @return The swish activation of x.
 */
static double Fn(const double x)
{
  // 1 + e^{-x} is the reciprocal of the logistic sigmoid of x.
  const double denominator = 1.0 + std::exp(-x);
  return x / denominator;
}
43
50 template<typename eT>
51 static void Fn(const arma::Mat<eT>& x, arma::Mat<eT>& y)
52 {
53 y = x / (1.0 + arma::exp(-x));
54 }
55
62 template<typename InputVecType, typename OutputVecType>
63 static void Fn(const InputVecType& x, OutputVecType& y)
64 {
65 y.set_size(arma::size(x));
66
67 for (size_t i = 0; i < x.n_elem; ++i)
68 y(i) = Fn(x(i));
69 }
70
/**
 * Computes the first derivative of the swish function at a single point.
 *
 * With f(t) = t / (1 + exp(-t)) and s(t) = 1 / (1 + exp(-t)), the value
 * returned is f(y) + s(y) * (1 - f(y)); the argument is used directly as the
 * evaluation point.
 *
 * @param y Point at which the derivative is evaluated.
 * @return The first derivative of the swish function at y.
 */
static double Deriv(const double y)
{
  // Evaluate the exponential once; the formula needs this denominator in
  // three places.
  const double denominator = 1 + std::exp(-y);
  const double swish = y / denominator;

  return swish + (1 - swish) / denominator;
}
82
89 template<typename InputVecType, typename OutputVecType>
90 static void Deriv(const InputVecType& y, OutputVecType& x)
91 {
92 x = y / (1 + arma::exp(-y)) + (1 - y / (1 + arma::exp(-y))) /
93 (1 + arma::exp(-y));
94 }
95}; // class SwishFunction
96
97} // namespace ann
98} // namespace mlpack
99
100#endif
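The swish function, defined by $f(x) = \frac{x}{1 + e^{-x}}$, with first derivative $f'(x) = f(x) + \sigma(x)\,(1 - f(x))$, where $\sigma(x) = \frac{1}{1 + e^{-x}}$.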
static double Fn(const double x)
Computes the swish function.
static double Deriv(const double y)
Computes the first derivative of the swish function.
static void Fn(const arma::Mat< eT > &x, arma::Mat< eT > &y)
Computes the swish function using a matrix as input.
static void Deriv(const InputVecType &y, OutputVecType &x)
Computes the first derivatives of the swish function.
static void Fn(const InputVecType &x, OutputVecType &y)
Computes the swish function.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.