mlpack 3.4.2
lecun_normal_init.hpp
Go to the documentation of this file.
1
15#ifndef MLPACK_METHODS_ANN_INIT_RULES_LECUN_NORMAL_INIT_HPP
16#define MLPACK_METHODS_ANN_INIT_RULES_LECUN_NORMAL_INIT_HPP
17
18#include <mlpack/prereqs.hpp>
20
21namespace mlpack {
22namespace ann {
23
50{
51 public:
56 {
57 // Nothing to do here.
58 }
59
68 template <typename eT>
69 void Initialize(arma::Mat<eT>& W,
70 const size_t rows,
71 const size_t cols)
72 {
73 // He initialization rule says to initialize weights with random
74 // values taken from a gaussian distribution with mean = 0 and
75 // standard deviation = sqrt(1 / rows), i.e. variance = (1 / rows).
76 const double variance = 1.0 / ((double) rows);
77
78 if (W.is_empty())
79 W.set_size(rows, cols);
80
81 // Multipling a random variable X with variance V(X) by some factor c,
82 // then the variance V(cX) = (c ^ 2) * V(X).
83 W.imbue( [&]() { return sqrt(variance) * arma::randn(); } );
84 }
85
92 template <typename eT>
93 void Initialize(arma::Mat<eT>& W)
94 {
95 // He initialization rule says to initialize weights with random
96 // values taken from a gaussian distribution with mean = 0 and
97 // standard deviation = sqrt(1 / rows), i.e. variance = (1 / rows).
98 const double variance = 1.0 / (double) W.n_rows;
99
100 if (W.is_empty())
101 Log::Fatal << "Cannot initialize an empty matrix." << std::endl;
102
103 // Multipling a random variable X with variance V(X) by some factor c,
104 // then the variance V(cX) = (c ^ 2) * V(X).
105 W.imbue( [&]() { return sqrt(variance) * arma::randn(); } );
106 }
107
117 template <typename eT>
118 void Initialize(arma::Cube<eT> & W,
119 const size_t rows,
120 const size_t cols,
121 const size_t slices)
122 {
123 if (W.is_empty())
124 W.set_size(rows, cols, slices);
125
126 for (size_t i = 0; i < slices; ++i)
127 Initialize(W.slice(i), rows, cols);
128 }
129
136 template <typename eT>
137 void Initialize(arma::Cube<eT> & W)
138 {
139 if (W.is_empty())
140 Log::Fatal << "Cannot initialize an empty cube." << std::endl;
141
142 for (size_t i = 0; i < W.n_slices; ++i)
143 Initialize(W.slice(i));
144 }
145}; // class LecunNormalInitialization
146
147} // namespace ann
148} // namespace mlpack
149
150#endif
if(NOT BUILD_GO_SHLIB) macro(add_go_binding name) endmacro() return() endif() endmacro() macro(post_go_setup) if(BUILD_GO_BINDINGS) file(APPEND "$
Definition: CMakeLists.txt:3
static MLPACK_EXPORT util::PrefixedOutStream Fatal
Prints fatal messages prefixed with [FATAL], then terminates the program.
Definition: log.hpp:90
This class is used to initialize a weight matrix with the LeCun normal initialization rule.
LecunNormalInitialization()
Initialize the LecunNormalInitialization object.
void Initialize(arma::Cube< eT > &W, const size_t rows, const size_t cols, const size_t slices)
Initialize the elements of the specified weight 3rd-order tensor with the LeCun normal initialization rule.
void Initialize(arma::Mat< eT > &W, const size_t rows, const size_t cols)
Initialize the elements of the weight matrix with the Lecun Normal initialization rule.
void Initialize(arma::Cube< eT > &W)
Initialize the elements of the specified weight 3rd-order tensor with the LeCun normal initialization rule.
void Initialize(arma::Mat< eT > &W)
Initialize the elements of the weight matrix with the Lecun Normal initialization rule.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.
Miscellaneous math random-related routines.