mlpack 3.4.2
lookup.hpp
Go to the documentation of this file.
1
13#ifndef MLPACK_METHODS_ANN_LAYER_LOOKUP_HPP
14#define MLPACK_METHODS_ANN_LAYER_LOOKUP_HPP
15
16#include <mlpack/prereqs.hpp>
18
19namespace mlpack {
20namespace ann /* Artificial Neural Network. */ {
21
37template <
38 typename InputDataType = arma::mat,
39 typename OutputDataType = arma::mat
40>
41class Lookup
42{
43 public:
50 Lookup(const size_t vocabSize = 0, const size_t embeddingSize = 0);
51
59 template<typename eT>
60 void Forward(const arma::Mat<eT>& input, arma::Mat<eT>& output);
61
71 template<typename eT>
72 void Backward(const arma::Mat<eT>& /* input */,
73 const arma::Mat<eT>& gy,
74 arma::Mat<eT>& g);
75
83 template<typename eT>
84 void Gradient(const arma::Mat<eT>& input,
85 const arma::Mat<eT>& error,
86 arma::Mat<eT>& gradient);
87
89 OutputDataType const& Parameters() const { return weights; }
91 OutputDataType& Parameters() { return weights; }
92
94 OutputDataType const& OutputParameter() const { return outputParameter; }
96 OutputDataType& OutputParameter() { return outputParameter; }
97
99 OutputDataType const& Delta() const { return delta; }
101 OutputDataType& Delta() { return delta; }
102
104 OutputDataType const& Gradient() const { return gradient; }
106 OutputDataType& Gradient() { return gradient; }
107
109 size_t VocabSize() const { return vocabSize; }
110
112 size_t EmbeddingSize() const { return embeddingSize; }
113
117 template<typename Archive>
118 void serialize(Archive& ar, const unsigned int /* version */);
119
120 private:
122 size_t vocabSize;
123
125 size_t embeddingSize;
126
128 OutputDataType weights;
129
131 OutputDataType delta;
132
134 OutputDataType gradient;
135
137 OutputDataType outputParameter;
138}; // class Lookup
139
140// Alias for using as embedding layer.
141template<typename MatType = arma::mat>
143
144} // namespace ann
145} // namespace mlpack
146
147// Include implementation.
148#include "lookup_impl.hpp"
149
150#endif
The Lookup class stores word embeddings and retrieves them using tokens.
Definition: lookup.hpp:42
OutputDataType const & Delta() const
Get the delta.
Definition: lookup.hpp:99
OutputDataType const & Parameters() const
Get the parameters.
Definition: lookup.hpp:89
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
OutputDataType const & OutputParameter() const
Get the output parameter.
Definition: lookup.hpp:94
void Gradient(const arma::Mat< eT > &input, const arma::Mat< eT > &error, arma::Mat< eT > &gradient)
Calculate the gradient using the output delta and the input activation.
OutputDataType const & Gradient() const
Get the gradient.
Definition: lookup.hpp:104
OutputDataType & Gradient()
Modify the gradient.
Definition: lookup.hpp:106
size_t VocabSize() const
Get the size of the vocabulary.
Definition: lookup.hpp:109
Lookup(const size_t vocabSize=0, const size_t embeddingSize=0)
Create the Lookup object using the specified vocabulary and embedding size.
void Backward(const arma::Mat< eT > &, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f, using the results from the feed forward pass.
OutputDataType & OutputParameter()
Modify the output parameter.
Definition: lookup.hpp:96
void serialize(Archive &ar, const unsigned int)
Serialize the layer.
OutputDataType & Parameters()
Modify the parameters.
Definition: lookup.hpp:91
OutputDataType & Delta()
Modify the delta.
Definition: lookup.hpp:101
size_t EmbeddingSize() const
Get the length of each embedding vector.
Definition: lookup.hpp:112
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.