mlpack 3.4.2
concat.hpp
Go to the documentation of this file.
1
13#ifndef MLPACK_METHODS_ANN_LAYER_CONCAT_HPP
14#define MLPACK_METHODS_ANN_LAYER_CONCAT_HPP
15
16#include <mlpack/prereqs.hpp>
17
18#include "../visitor/delete_visitor.hpp"
19#include "../visitor/delta_visitor.hpp"
20#include "../visitor/output_parameter_visitor.hpp"
21
22#include <boost/ptr_container/ptr_vector.hpp>
23
24#include "layer_types.hpp"
25
26namespace mlpack {
27namespace ann {
28
40template <
41 typename InputDataType = arma::mat,
42 typename OutputDataType = arma::mat,
43 typename... CustomLayers
44>
45class Concat
46{
47 public:
54 Concat(const bool model = false,
55 const bool run = true);
56
65 Concat(arma::Row<size_t>& inputSize,
66 const size_t axis,
67 const bool model = false,
68 const bool run = true);
69
74
82 template<typename eT>
83 void Forward(const arma::Mat<eT>& input, arma::Mat<eT>& output);
84
94 template<typename eT>
95 void Backward(const arma::Mat<eT>& /* input */,
96 const arma::Mat<eT>& gy,
97 arma::Mat<eT>& g);
98
108 template<typename eT>
109 void Backward(const arma::Mat<eT>& /* input */,
110 const arma::Mat<eT>& gy,
111 arma::Mat<eT>& g,
112 const size_t index);
113
114 /*
115 * Calculate the gradient using the output delta and the input activation.
116 *
117 * @param input The input parameter used for calculating the gradient.
118 * @param error The calculated error.
119 * @param gradient The calculated gradient.
120 */
121 template<typename eT>
122 void Gradient(const arma::Mat<eT>& /* input */,
123 const arma::Mat<eT>& error,
124 arma::Mat<eT>& /* gradient */);
125
126 /*
127 * This is the overload of Gradient() that runs a specific layer with the
128 * given input.
129 *
130 * @param input The input parameter used for calculating the gradient.
131 * @param error The calculated error.
132 * @param gradient The calculated gradient.
133 * @param The index of the layer to run.
134 */
135 template<typename eT>
136 void Gradient(const arma::Mat<eT>& input,
137 const arma::Mat<eT>& error,
138 arma::Mat<eT>& gradient,
139 const size_t index);
140
141 /*
142 * Add a new module to the model.
143 *
144 * @param args The layer parameter.
145 */
146 template <class LayerType, class... Args>
147 void Add(Args... args) { network.push_back(new LayerType(args...)); }
148
149 /*
150 * Add a new module to the model.
151 *
152 * @param layer The Layer to be added to the model.
153 */
154 void Add(LayerTypes<CustomLayers...> layer) { network.push_back(layer); }
155
157 std::vector<LayerTypes<CustomLayers...> >& Model()
158 {
159 if (model)
160 {
161 return network;
162 }
163
164 return empty;
165 }
166
168 const arma::mat& Parameters() const { return parameters; }
170 arma::mat& Parameters() { return parameters; }
171
173 bool Run() const { return run; }
175 bool& Run() { return run; }
176
177 arma::mat const& InputParameter() const { return inputParameter; }
179 arma::mat& InputParameter() { return inputParameter; }
180
182 arma::mat const& OutputParameter() const { return outputParameter; }
184 arma::mat& OutputParameter() { return outputParameter; }
185
187 arma::mat const& Delta() const { return delta; }
189 arma::mat& Delta() { return delta; }
190
192 arma::mat const& Gradient() const { return gradient; }
194 arma::mat& Gradient() { return gradient; }
195
197 size_t const& ConcatAxis() const { return axis; }
198
202 template<typename Archive>
203 void serialize(Archive& /* ar */, const unsigned int /* version */);
204
205 private:
207 arma::Row<size_t> inputSize;
208
210 size_t axis;
211
213 bool useAxis;
214
216 bool model;
217
220 bool run;
221
223 size_t channels;
224
226 std::vector<LayerTypes<CustomLayers...> > network;
227
229 arma::mat parameters;
230
232 DeltaVisitor deltaVisitor;
233
235 OutputParameterVisitor outputParameterVisitor;
236
238 DeleteVisitor deleteVisitor;
239
241 std::vector<LayerTypes<CustomLayers...> > empty;
242
244 arma::mat delta;
245
247 arma::mat inputParameter;
248
250 arma::mat outputParameter;
251
253 arma::mat gradient;
254}; // class Concat
255
256} // namespace ann
257} // namespace mlpack
258
259// Include implementation.
260#include "concat_impl.hpp"
261
262#endif
Implementation of the Concat class.
Definition: concat.hpp:46
arma::mat const & Gradient() const
Get the gradient.
Definition: concat.hpp:192
arma::mat & Parameters()
Modify the initial point for the optimization.
Definition: concat.hpp:170
void Gradient(const arma::Mat< eT > &, const arma::Mat< eT > &error, arma::Mat< eT > &)
Calculate the gradient using the output delta and the input activation.
bool & Run()
Modify the value of run parameter.
Definition: concat.hpp:175
arma::mat & OutputParameter()
Modify the output parameter.
Definition: concat.hpp:184
const arma::mat & Parameters() const
Return the initial point for the optimization.
Definition: concat.hpp:168
void Backward(const arma::Mat< eT > &, const arma::Mat< eT > &gy, arma::Mat< eT > &g, const size_t index)
This is the overload of Backward() that runs only a specific layer with the given input.
void Gradient(const arma::Mat< eT > &input, const arma::Mat< eT > &error, arma::Mat< eT > &gradient, const size_t index)
This is the overload of Gradient() that runs a specific layer with the given input.
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activ...
void Add(LayerTypes< CustomLayers... > layer)
Add a new module to the model.
Definition: concat.hpp:154
Concat(const bool model=false, const bool run=true)
Create the Concat object using the specified parameters.
arma::mat & Delta()
Modify the delta.
Definition: concat.hpp:189
size_t const & ConcatAxis() const
Get the axis of concatenation.
Definition: concat.hpp:197
std::vector< LayerTypes< CustomLayers... > > & Model()
Return the model modules.
Definition: concat.hpp:157
void Add(Args... args)
Add a new module to the model.
Definition: concat.hpp:147
arma::mat const & InputParameter() const
Get the input parameter.
Definition: concat.hpp:177
arma::mat const & Delta() const
Get the delta.
Definition: concat.hpp:187
void serialize(Archive &, const unsigned int)
Serialize the layer.
arma::mat & InputParameter()
Modify the input parameter.
Definition: concat.hpp:179
~Concat()
Destroy the layers held by the model.
arma::mat const & OutputParameter() const
Get the output parameter.
Definition: concat.hpp:182
arma::mat & Gradient()
Modify the gradient.
Definition: concat.hpp:194
void Backward(const arma::Mat< eT > &, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed backward pass of a neural network, using 3rd-order tensors as input,...
bool Run() const
Get the value of run parameter.
Definition: concat.hpp:173
Concat(arma::Row< size_t > &inputSize, const size_t axis, const bool model=false, const bool run=true)
Create the Concat object using the specified parameters.
DeleteVisitor executes the destructor of the instantiated object.
DeltaVisitor exposes the delta parameter of the given module.
OutputParameterVisitor exposes the output parameter of the given module.
boost::variant< AdaptiveMaxPooling< arma::mat, arma::mat > *, AdaptiveMeanPooling< arma::mat, arma::mat > *, Add< arma::mat, arma::mat > *, AddMerge< arma::mat, arma::mat > *, AlphaDropout< arma::mat, arma::mat > *, AtrousConvolution< NaiveConvolution< ValidConvolution >, NaiveConvolution< FullConvolution >, NaiveConvolution< ValidConvolution >, arma::mat, arma::mat > *, BaseLayer< LogisticFunction, arma::mat, arma::mat > *, BaseLayer< IdentityFunction, arma::mat, arma::mat > *, BaseLayer< TanhFunction, arma::mat, arma::mat > *, BaseLayer< SoftplusFunction, arma::mat, arma::mat > *, BaseLayer< RectifierFunction, arma::mat, arma::mat > *, BatchNorm< arma::mat, arma::mat > *, BilinearInterpolation< arma::mat, arma::mat > *, CELU< arma::mat, arma::mat > *, Concat< arma::mat, arma::mat > *, Concatenate< arma::mat, arma::mat > *, ConcatPerformance< NegativeLogLikelihood< arma::mat, arma::mat >, arma::mat, arma::mat > *, Constant< arma::mat, arma::mat > *, Convolution< NaiveConvolution< ValidConvolution >, NaiveConvolution< FullConvolution >, NaiveConvolution< ValidConvolution >, arma::mat, arma::mat > *, CReLU< arma::mat, arma::mat > *, DropConnect< arma::mat, arma::mat > *, Dropout< arma::mat, arma::mat > *, ELU< arma::mat, arma::mat > *, FastLSTM< arma::mat, arma::mat > *, FlexibleReLU< arma::mat, arma::mat > *, GRU< arma::mat, arma::mat > *, HardTanH< arma::mat, arma::mat > *, Join< arma::mat, arma::mat > *, LayerNorm< arma::mat, arma::mat > *, LeakyReLU< arma::mat, arma::mat > *, Linear< arma::mat, arma::mat, NoRegularizer > *, LinearNoBias< arma::mat, arma::mat, NoRegularizer > *, LogSoftMax< arma::mat, arma::mat > *, Lookup< arma::mat, arma::mat > *, LSTM< arma::mat, arma::mat > *, MaxPooling< arma::mat, arma::mat > *, MeanPooling< arma::mat, arma::mat > *, MiniBatchDiscrimination< arma::mat, arma::mat > *, MultiplyConstant< arma::mat, arma::mat > *, MultiplyMerge< arma::mat, arma::mat > *, NegativeLogLikelihood< arma::mat, arma::mat > *, NoisyLinear< arma::mat, 
arma::mat > *, Padding< arma::mat, arma::mat > *, PReLU< arma::mat, arma::mat > *, Softmax< arma::mat, arma::mat > *, SpatialDropout< arma::mat, arma::mat > *, TransposedConvolution< NaiveConvolution< ValidConvolution >, NaiveConvolution< ValidConvolution >, NaiveConvolution< ValidConvolution >, arma::mat, arma::mat > *, WeightNorm< arma::mat, arma::mat > *, MoreTypes, CustomLayers *... > LayerTypes
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.