mlpack 3.4.2
random_init.hpp
Go to the documentation of this file.
1
13#ifndef MLPACK_METHODS_ANN_INIT_RULES_RANDOM_INIT_HPP
14#define MLPACK_METHODS_ANN_INIT_RULES_RANDOM_INIT_HPP
15
16#include <mlpack/prereqs.hpp>
17
18namespace mlpack {
19namespace ann {
20
25{
26 public:
34 RandomInitialization(const double lowerBound = -1,
35 const double upperBound = 1) :
36 lowerBound(lowerBound), upperBound(upperBound) { }
37
45 RandomInitialization(const double bound) :
46 lowerBound(-std::abs(bound)), upperBound(std::abs(bound)) { }
47
55 template<typename eT>
56 void Initialize(arma::Mat<eT>& W, const size_t rows, const size_t cols)
57 {
58 if (W.is_empty())
59 W.set_size(rows, cols);
60
61 W.randu();
62 W *= (upperBound - lowerBound);
63 W += lowerBound;
64 }
65
71 template<typename eT>
72 void Initialize(arma::Mat<eT>& W)
73 {
74 if (W.is_empty())
75 Log::Fatal << "Cannot initialize an empty matrix." << std::endl;
76
77 W.randu();
78 W *= (upperBound - lowerBound);
79 W += lowerBound;
80 }
81
90 template<typename eT>
91 void Initialize(arma::Cube<eT>& W,
92 const size_t rows,
93 const size_t cols,
94 const size_t slices)
95 {
96 if (W.is_empty())
97 W.set_size(rows, cols, slices);
98
99 for (size_t i = 0; i < slices; ++i)
100 Initialize(W.slice(i), rows, cols);
101 }
102
108 template<typename eT>
109 void Initialize(arma::Cube<eT>& W)
110 {
111 if (W.is_empty())
112 Log::Fatal << "Cannot initialize an empty cube." << std::endl;
113
114 for (size_t i = 0; i < W.n_slices; ++i)
115 Initialize(W.slice(i));
116 }
117
118 private:
//! Lower end of the interval that the random weight values are mapped onto.
120 double lowerBound;

//! Upper end of the interval that the random weight values are mapped onto.
123 double upperBound;
124}; // class RandomInitialization
125
126} // namespace ann
127} // namespace mlpack
128
129#endif
static MLPACK_EXPORT util::PrefixedOutStream Fatal
Prints fatal messages prefixed with [FATAL], then terminates the program.
Definition: log.hpp:90
This class is used to randomly initialize the weight matrix.
Definition: random_init.hpp:25
void Initialize(arma::Cube< eT > &W, const size_t rows, const size_t cols, const size_t slices)
Randomly initialize the elements of the specified third-order weight tensor.
Definition: random_init.hpp:91
void Initialize(arma::Mat< eT > &W, const size_t rows, const size_t cols)
Randomly initialize the elements of the specified weight matrix.
Definition: random_init.hpp:56
RandomInitialization(const double lowerBound=-1, const double upperBound=1)
Initialize the random initialization rule with the given lower bound and upper bound.
Definition: random_init.hpp:34
RandomInitialization(const double bound)
Initialize the random initialization rule with the symmetric interval [-|bound|, |bound|].
Definition: random_init.hpp:45
void Initialize(arma::Cube< eT > &W)
Randomly initialize the elements of the specified third-order weight tensor.
void Initialize(arma::Mat< eT > &W)
Randomly initialize the elements of the specified weight matrix.
Definition: random_init.hpp:72
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
Definition: prereqs.hpp:67
The core includes that mlpack expects; standard C++ includes and Armadillo.