mlpack 3.4.2
orthogonal_init.hpp
Go to the documentation of this file.
1
12#ifndef MLPACK_METHODS_ANN_INIT_RULES_ORTHOGONAL_INIT_HPP
13#define MLPACK_METHODS_ANN_INIT_RULES_ORTHOGONAL_INIT_HPP
14
15#include <mlpack/prereqs.hpp>
16
17namespace mlpack {
18namespace ann {
19
25{
26 public:
32 OrthogonalInitialization(const double gain = 1.0) : gain(gain) { }
33
42 template<typename eT>
43 void Initialize(arma::Mat<eT>& W, const size_t rows, const size_t cols)
44 {
45 arma::Mat<eT> V;
46 arma::Col<eT> s;
47
48 arma::svd_econ(W, s, V, arma::randu<arma::Mat<eT> >(rows, cols));
49 W *= gain;
50 }
51
58 template<typename eT>
59 void Initialize(arma::Mat<eT>& W)
60 {
61 arma::Mat<eT> V;
62 arma::Col<eT> s;
63
64 arma::svd_econ(W, s, V, arma::randu<arma::Mat<eT> >(W.n_rows, W.n_cols));
65 W *= gain;
66 }
67
77 template<typename eT>
78 void Initialize(arma::Cube<eT>& W,
79 const size_t rows,
80 const size_t cols,
81 const size_t slices)
82 {
83 if (W.is_empty())
84 W.set_size(rows, cols, slices);
85
86 for (size_t i = 0; i < slices; ++i)
87 Initialize(W.slice(i), rows, cols);
88 }
89
96 template<typename eT>
97 void Initialize(arma::Cube<eT>& W)
98 {
99 if (W.is_empty())
100 Log::Fatal << "Cannot initialize an empty cube." << std::endl;
101
102 for (size_t i = 0; i < W.n_slices; ++i)
103 Initialize(W.slice(i));
104 }
105
106 private:
108 double gain;
109}; // class OrthogonalInitialization
110
111
112} // namespace ann
113} // namespace mlpack
114
115#endif
static MLPACK_EXPORT util::PrefixedOutStream Fatal
Prints fatal messages prefixed with [FATAL], then terminates the program.
Definition: log.hpp:90
This class is used to initialize the weight matrix with the orthogonal matrix initialization.
void Initialize(arma::Cube< eT > &W, const size_t rows, const size_t cols, const size_t slices)
Initialize the elements of the specified weight 3rd order tensor with the orthogonal matrix initialization method.
void Initialize(arma::Mat< eT > &W, const size_t rows, const size_t cols)
Initialize the elements of the specified weight matrix with the orthogonal matrix initialization method.
OrthogonalInitialization(const double gain=1.0)
Initialize the orthogonal matrix initialization rule with the given gain.
void Initialize(arma::Cube< eT > &W)
Initialize the elements of the specified weight 3rd order tensor with the orthogonal matrix initialization method.
void Initialize(arma::Mat< eT > &W)
Initialize the elements of the specified weight matrix with the orthogonal matrix initialization method.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.