mlpack 3.4.2
he_init.hpp
Go to the documentation of this file.
1
16#ifndef MLPACK_METHODS_ANN_INIT_RULES_HE_INIT_HPP
17#define MLPACK_METHODS_ANN_INIT_RULES_HE_INIT_HPP
18
19#include <mlpack/prereqs.hpp>
21
22namespace mlpack {
23namespace ann {
24
46{
47 public:
52 {
53 // Nothing to do here.
54 }
55
64 template <typename eT>
65 void Initialize(arma::Mat<eT>& W, const size_t rows, const size_t cols)
66 {
67 // He initialization rule says to initialize weights with random
68 // values taken from a gaussian distribution with mean = 0 and
69 // standard deviation = sqrt(2/rows), i.e. variance = (2/rows).
70 const double variance = 2.0 / (double) rows;
71
72 if (W.is_empty())
73 W.set_size(rows, cols);
74
75 // Multipling a random variable X with variance V(X) by some factor c,
76 // then the variance V(cX) = (c^2) * V(X).
77 W.imbue( [&]() { return sqrt(variance) * arma::randn(); } );
78 }
79
86 template <typename eT>
87 void Initialize(arma::Mat<eT>& W)
88 {
89 // He initialization rule says to initialize weights with random
90 // values taken from a gaussian distribution with mean = 0 and
91 // standard deviation = sqrt(2 / rows), i.e. variance = (2 / rows).
92 const double variance = 2.0 / (double) W.n_rows;
93
94 if (W.is_empty())
95 Log::Fatal << "Cannot initialize an empty matrix." << std::endl;
96
97 // Multipling a random variable X with variance V(X) by some factor c,
98 // then the variance V(cX) = (c^2) * V(X).
99 W.imbue( [&]() { return sqrt(variance) * arma::randn(); } );
100 }
101
111 template <typename eT>
112 void Initialize(arma::Cube<eT> & W,
113 const size_t rows,
114 const size_t cols,
115 const size_t slices)
116 {
117 if (W.is_empty())
118 W.set_size(rows, cols, slices);
119
120 for (size_t i = 0; i < slices; ++i)
121 Initialize(W.slice(i), rows, cols);
122 }
123
130 template <typename eT>
131 void Initialize(arma::Cube<eT> & W)
132 {
133 if (W.is_empty())
134 Log::Fatal << "Cannot initialize an empty matrix" << std::endl;
135
136 for (size_t i = 0; i < W.n_slices; ++i)
137 Initialize(W.slice(i));
138 }
139}; // class HeInitialization
140
141} // namespace ann
142} // namespace mlpack
143
144#endif
if(NOT BUILD_GO_SHLIB) macro(add_go_binding name) endmacro() return() endif() endmacro() macro(post_go_setup) if(BUILD_GO_BINDINGS) file(APPEND "$
Definition: CMakeLists.txt:3
static MLPACK_EXPORT util::PrefixedOutStream Fatal
Prints fatal messages prefixed with [FATAL], then terminates the program.
Definition: log.hpp:90
This class is used to initialize the weight matrix with the He initialization rule given by He et al.
Definition: he_init.hpp:46
HeInitialization()
Initialize the HeInitialization object.
Definition: he_init.hpp:51
void Initialize(arma::Cube< eT > &W, const size_t rows, const size_t cols, const size_t slices)
Initialize the elements of the specified weight 3rd order tensor with He initialization rule.
Definition: he_init.hpp:112
void Initialize(arma::Mat< eT > &W, const size_t rows, const size_t cols)
Initialize the elements of the weight matrix with the He initialization rule.
Definition: he_init.hpp:65
void Initialize(arma::Cube< eT > &W)
Initialize the elements of the specified weight 3rd order tensor with He initialization rule.
Definition: he_init.hpp:131
void Initialize(arma::Mat< eT > &W)
Initialize the elements of the weight matrix with the He initialization rule.
Definition: he_init.hpp:87
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.
Miscellaneous math random-related routines.