mlpack 3.4.2
oivs_init.hpp
Go to the documentation of this file.
1
27#ifndef MLPACK_METHODS_ANN_INIT_RULES_OIVS_INIT_HPP
28#define MLPACK_METHODS_ANN_INIT_RULES_OIVS_INIT_HPP
29
30#include <mlpack/prereqs.hpp>
32
33#include "random_init.hpp"
34
35namespace mlpack {
36namespace ann {
37
56template<
57 class ActivationFunction = LogisticFunction
58>
60{
61 public:
69 OivsInitialization(const double epsilon = 0.1,
70 const int k = 5,
71 const double gamma = 0.9) :
72 k(k), gamma(gamma),
73 b(std::abs(ActivationFunction::Inv(1 - epsilon) -
74 ActivationFunction::Inv(epsilon)))
75 {
76 }
77
85 template<typename eT>
86 void Initialize(arma::Mat<eT>& W, const size_t rows, const size_t cols)
87 {
88 RandomInitialization randomInit(-gamma, gamma);
89 randomInit.Initialize(W, rows, cols);
90
91 W = (b / (k * rows)) * arma::sqrt(W + 1);
92 }
93
99 template<typename eT>
100 void Initialize(arma::Mat<eT>& W)
101 {
102 RandomInitialization randomInit(-gamma, gamma);
103 randomInit.Initialize(W);
104
105 W = (b / (k * W.n_rows)) * arma::sqrt(W + 1);
106 }
107
117 template<typename eT>
118 void Initialize(arma::Cube<eT>& W,
119 const size_t rows,
120 const size_t cols,
121 const size_t slices)
122 {
123 if (W.is_empty())
124 W.set_size(rows, cols, slices);
125
126 for (size_t i = 0; i < slices; ++i)
127 Initialize(W.slice(i), rows, cols);
128 }
129
136 template<typename eT>
137 void Initialize(arma::Cube<eT>& W)
138 {
139 if (W.is_empty())
140 Log::Fatal << "Cannot initialize an empty cube." << std::endl;
141
142 for (size_t i = 0; i < W.n_slices; ++i)
143 Initialize(W.slice(i));
144 }
145
146 private:
148 int k;
149
151 double gamma;
152
154 double b;
155}; // class OivsInitialization
156
157
158} // namespace ann
159} // namespace mlpack
160
161#endif
static MLPACK_EXPORT util::PrefixedOutStream Fatal
Prints fatal messages prefixed with [FATAL], then terminates the program.
Definition: log.hpp:90
This class is used to initialize the weight matrix with the oivs method.
Definition: oivs_init.hpp:60
void Initialize(arma::Cube< eT > &W, const size_t rows, const size_t cols, const size_t slices)
Initialize the elements of the specified weight 3rd order tensor with the oivs method.
Definition: oivs_init.hpp:118
OivsInitialization(const double epsilon=0.1, const int k=5, const double gamma=0.9)
Initialize the oivs initialization rule with the given values.
Definition: oivs_init.hpp:69
void Initialize(arma::Mat< eT > &W, const size_t rows, const size_t cols)
Initialize the elements of the specified weight matrix with the oivs method.
Definition: oivs_init.hpp:86
void Initialize(arma::Cube< eT > &W)
Initialize the elements of the specified weight 3rd order tensor with the oivs method.
Definition: oivs_init.hpp:137
void Initialize(arma::Mat< eT > &W)
Initialize the elements of the specified weight matrix with the oivs method.
Definition: oivs_init.hpp:100
This class is used to randomly initialize the weight matrix.
Definition: random_init.hpp:25
void Initialize(arma::Mat< eT > &W, const size_t rows, const size_t cols)
Randomly initialize the elements of the specified weight matrix.
Definition: random_init.hpp:56
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
Definition: prereqs.hpp:67
The core includes that mlpack expects; standard C++ includes and Armadillo.