mlpack 3.4.2
network_init.hpp
#ifndef MLPACK_METHODS_ANN_INIT_RULES_NETWORK_INIT_HPP
#define MLPACK_METHODS_ANN_INIT_RULES_NETWORK_INIT_HPP

#include <mlpack/prereqs.hpp>

#include "../visitor/reset_visitor.hpp"
#include "../visitor/weight_size_visitor.hpp"
#include "../visitor/weight_set_visitor.hpp"
#include "init_rules_traits.hpp"

namespace mlpack {
namespace ann {

/**
 * This class is used to initialize the network with the given initialization
 * rule.
 */
template<typename InitializationRuleType, typename... CustomLayers>
class NetworkInitialization
{
 public:
  //! Use the given initialization rule to initialize the specified network.
  NetworkInitialization(
      const InitializationRuleType& initializeRule = InitializationRuleType()) :
      initializeRule(initializeRule)
  {
    // Nothing to do here.
  }

  //! Initialize the specified network and store the results in the given
  //! parameter matrix; parameterOffset is the offset into that matrix.
  template <typename eT>
  void Initialize(const std::vector<LayerTypes<CustomLayers...> >& network,
                  arma::Mat<eT>& parameter, size_t parameterOffset = 0)
  {
    // Determine the number of parameter/weights of the given network.
    if (parameter.is_empty())
    {
      size_t weights = 0;
      for (size_t i = 0; i < network.size(); ++i)
        weights += boost::apply_visitor(weightSizeVisitor, network[i]);
      parameter.set_size(weights, 1);
    }

    // Initialize the network layer by layer or the complete network.
    if (InitTraits<InitializationRuleType>::UseLayer)
    {
      for (size_t i = 0, offset = parameterOffset; i < network.size(); ++i)
      {
        // Initialize the layer with the specified parameter/weight
        // initialization rule.
        const size_t weight = boost::apply_visitor(weightSizeVisitor,
            network[i]);
        arma::Mat<eT> tmp(parameter.memptr() + offset, weight, 1, false,
            false);
        initializeRule.Initialize(tmp, tmp.n_elem, 1);

        // Increase the parameter/weight offset for the next layer.
        offset += weight;
      }
    }
    else
    {
      initializeRule.Initialize(parameter, parameter.n_elem, 1);
    }

    // Note: We can't merge this loop into the loop above, because
    // WeightSetVisitor also sets the parameters/weights of the inner modules.
    // Inner modules are held by the parent module; e.g. the Concat module can
    // hold various other modules.
    for (size_t i = 0, offset = parameterOffset; i < network.size(); ++i)
    {
      offset += boost::apply_visitor(WeightSetVisitor(parameter, offset),
          network[i]);

      boost::apply_visitor(resetVisitor, network[i]);
    }
  }

 private:
  //! Instantiated InitializationRule object for initializing the network
  //! parameter.
  InitializationRuleType initializeRule;

  //! Locally-stored reset visitor; executes each layer's Reset() function.
  ResetVisitor resetVisitor;

  //! Locally-stored weight size visitor; returns the number of weights of a
  //! given module.
  WeightSizeVisitor weightSizeVisitor;
}; // class NetworkInitialization

} // namespace ann
} // namespace mlpack

#endif
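Both initialization paths above only require the rule to expose an Initialize(W, rows, cols) member that fills a (possibly pre-sized) matrix. As an illustration, a minimal custom rule satisfying that interface could look like the sketch below; the class name ConstantFill and the default value are hypothetical and not part of mlpack (mlpack ships stock rules such as RandomInitialization for real use):

#include <mlpack/prereqs.hpp>

// Hypothetical example rule: fills every weight with a fixed value.
class ConstantFill
{
 public:
  ConstantFill(const double value = 0.5) : value(value) { }

  // NetworkInitialization calls this either with one layer's weights (the
  // aliased matrix tmp above) or with the whole parameter matrix.
  template<typename eT>
  void Initialize(arma::Mat<eT>& W, const size_t rows, const size_t cols)
  {
    if (W.is_empty())
      W.set_size(rows, cols);

    W.fill(value);
  }

 private:
  //! The value used to fill the weights.
  double value;
};

Such a rule would then be plugged in as NetworkInitialization<ConstantFill>.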
InitTraits
This is a template class that can provide information about various initialization methods.
NetworkInitialization
This class is used to initialize the network with the given initialization rule.
NetworkInitialization(const InitializationRuleType& initializeRule = InitializationRuleType())
Use the given initialization rule to initialize the specified network.
void Initialize(const std::vector<LayerTypes<CustomLayers...> >& network, arma::Mat<eT>& parameter, size_t parameterOffset = 0)
Initialize the specified network and store the results in the given parameter (see the usage sketch at the end of this page).
ResetVisitor executes the Reset() function.
WeightSetVisitor updates the module parameters given the parameter set.
WeightSizeVisitor returns the number of weights of the given module.
LayerTypes
A boost::variant over pointers to all of the available ANN layer types (Linear, Convolution, LSTM, BatchNorm, pooling and activation layers, and so on), extended with any user-supplied CustomLayers.
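As a point of reference, a minimal usage sketch of the class documented above might look like the following. It assumes mlpack 3.4's ann headers and the stock RandomInitialization rule; in ordinary use the FFN<> class constructs and calls NetworkInitialization internally, so invoking it directly like this is mainly for illustration:

#include <iostream>

#include <mlpack/core.hpp>
#include <mlpack/methods/ann/layer/layer.hpp>
#include <mlpack/methods/ann/layer/layer_types.hpp>
#include <mlpack/methods/ann/init_rules/random_init.hpp>
#include <mlpack/methods/ann/init_rules/network_init.hpp>
#include <mlpack/methods/ann/visitor/delete_visitor.hpp>

using namespace mlpack::ann;

int main()
{
  // A small 10 -> 5 -> 2 network, expressed as the layer-variant vector that
  // NetworkInitialization expects.
  std::vector<LayerTypes<> > network;
  network.push_back(new Linear<>(10, 5));
  network.push_back(new Linear<>(5, 2));

  // Initialize() sizes the flat parameter matrix ((10 * 5 + 5) + (5 * 2 + 2)
  // = 67 elements) and fills it according to the rule.
  arma::mat parameters;
  NetworkInitialization<RandomInitialization> networkInit;
  networkInit.Initialize(network, parameters);

  std::cout << "Number of parameters: " << parameters.n_elem << std::endl;

  // Clean up the manually allocated layers; FFN<> normally owns its layers
  // and does this itself.
  DeleteVisitor deleteVisitor;
  for (size_t i = 0; i < network.size(); ++i)
    boost::apply_visitor(deleteVisitor, network[i]);

  return 0;
}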