mlpack 3.4.2
glorot_init.hpp
Go to the documentation of this file.
1
14#ifndef MLPACK_METHODS_ANN_INIT_RULES_GLOROT_INIT_HPP
15#define MLPACK_METHODS_ANN_INIT_RULES_GLOROT_INIT_HPP
16
17#include <mlpack/prereqs.hpp>
18#include "random_init.hpp"
19#include "gaussian_init.hpp"
20
21using namespace mlpack::math;
22
23namespace mlpack {
24namespace ann {
25
54template<bool Uniform = true>
56{
57 public:
62 {
63 // Nothing to do here.
64 }
65
73 template<typename eT>
74 void Initialize(arma::Mat<eT>& W,
75 const size_t rows,
76 const size_t cols);
77
83 template<typename eT>
84 void Initialize(arma::Mat<eT>& W);
85
95 template<typename eT>
96 void Initialize(arma::Cube<eT>& W,
97 const size_t rows,
98 const size_t cols,
99 const size_t slices);
100
107 template<typename eT>
108 void Initialize(arma::Cube<eT>& W);
109}; // class GlorotInitializationType
110
111template <>
112template<typename eT>
114 const size_t rows,
115 const size_t cols)
116{
117 if (W.is_empty())
118 W.set_size(rows, cols);
119
120 double var = 2.0 / double(rows + cols);
121 GaussianInitialization normalInit(0.0, var);
122 normalInit.Initialize(W, rows, cols);
123}
124
125template <>
126template<typename eT>
128{
129 if (W.is_empty())
130 Log::Fatal << "Cannot initialize and empty matrix." << std::endl;
131
132 double var = 2.0 / double(W.n_rows + W.n_cols);
133 GaussianInitialization normalInit(0.0, var);
134 normalInit.Initialize(W);
135}
136
137template <>
138template<typename eT>
140 const size_t rows,
141 const size_t cols)
142{
143 if (W.is_empty())
144 W.set_size(rows, cols);
145
146 // Limit of distribution.
147 double a = sqrt(6) / sqrt(rows + cols);
148 RandomInitialization randomInit(-a, a);
149 randomInit.Initialize(W, rows, cols);
150}
151
152template <>
153template<typename eT>
155{
156 if (W.is_empty())
157 Log::Fatal << "Cannot initialize an empty matrix." << std::endl;
158
159 // Limit of distribution.
160 double a = sqrt(6) / sqrt(W.n_rows + W.n_cols);
161 RandomInitialization randomInit(-a, a);
162 randomInit.Initialize(W);
163}
164
165template <bool Uniform>
166template<typename eT>
168 const size_t rows,
169 const size_t cols,
170 const size_t slices)
171{
172 if (W.is_empty())
173 W.set_size(rows, cols, slices);
174
175 for (size_t i = 0; i < slices; ++i)
176 Initialize(W.slice(i), rows, cols);
177}
178
179template <bool Uniform>
180template<typename eT>
182{
183 if (W.is_empty())
184 Log::Fatal << "Cannot initialize an empty matrix." << std::endl;
185
186 for (size_t i = 0; i < W.n_slices; ++i)
187 Initialize(W.slice(i));
188}
189
// Convenience typedefs.

// Uses normal distribution
202} // namespace ann
203} // namespace mlpack
204
205#endif
static MLPACK_EXPORT util::PrefixedOutStream Fatal
Prints fatal messages prefixed with [FATAL], then terminates the program.
Definition: log.hpp:90
This class is used to initialize the weight matrix with a Gaussian distribution.
void Initialize(arma::Mat< eT > &W, const size_t rows, const size_t cols)
Initialize the elements of the weight matrix using a Gaussian distribution.
This class is used to initialize the weight matrix with the Glorot Initialization method.
Definition: glorot_init.hpp:56
GlorotInitializationType()
Initialize the Glorot initialization object.
Definition: glorot_init.hpp:61
void Initialize(arma::Mat< eT > &W, const size_t rows, const size_t cols)
Initialize the elements of the weight matrix with the Glorot initialization method.
void Initialize(arma::Mat< eT > &W)
Initialize the elements of the weight matrix with the Glorot initialization method.
This class is used to randomly initialize the weight matrix.
Definition: random_init.hpp:25
void Initialize(arma::Mat< eT > &W, const size_t rows, const size_t cols)
Randomly initialize the elements of the specified weight matrix.
Definition: random_init.hpp:56
Miscellaneous math routines.
Definition: ccov.hpp:20
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.