mlpack 3.4.2
svd_batch_learning.hpp
Go to the documentation of this file.
1
12#ifndef MLPACK_METHODS_AMF_UPDATE_RULES_SVD_BATCH_LEARNING_HPP
13#define MLPACK_METHODS_AMF_UPDATE_RULES_SVD_BATCH_LEARNING_HPP
14
15#include <mlpack/prereqs.hpp>
16
17namespace mlpack {
18namespace amf {
19
42{
43 public:
52 SVDBatchLearning(double u = 0.0002,
53 double kw = 0,
54 double kh = 0,
55 double momentum = 0.9)
56 : u(u), kw(kw), kh(kh), momentum(momentum)
57 {
58 // empty constructor
59 }
60
68 template<typename MatType>
69 void Initialize(const MatType& dataset, const size_t rank)
70 {
71 const size_t n = dataset.n_rows;
72 const size_t m = dataset.n_cols;
73
74 mW.zeros(n, rank);
75 mH.zeros(rank, m);
76 }
77
87 template<typename MatType>
88 inline void WUpdate(const MatType& V,
89 arma::mat& W,
90 const arma::mat& H)
91 {
92 size_t n = V.n_rows;
93 size_t m = V.n_cols;
94
95 size_t r = W.n_cols;
96
97 // initialize the momentum of this iteration.
98 mW = momentum * mW;
99
100 // Compute the step.
101 arma::mat deltaW;
102 deltaW.zeros(n, r);
103 for (size_t i = 0; i < n; ++i)
104 {
105 for (size_t j = 0; j < m; ++j)
106 {
107 const double val = V(i, j);
108 if (val != 0)
109 deltaW.row(i) += (val - arma::dot(W.row(i), H.col(j))) *
110 arma::trans(H.col(j));
111 }
112 // Add regularization.
113 if (kw != 0)
114 deltaW.row(i) -= kw * W.row(i);
115 }
116
117 // Add the step to the momentum.
118 mW += u * deltaW;
119 // Add the momentum to the W matrix.
120 W += mW;
121 }
122
132 template<typename MatType>
133 inline void HUpdate(const MatType& V,
134 const arma::mat& W,
135 arma::mat& H)
136 {
137 size_t n = V.n_rows;
138 size_t m = V.n_cols;
139
140 size_t r = W.n_cols;
141
142 // Initialize the momentum of this iteration.
143 mH = momentum * mH;
144
145 // Compute the step.
146 arma::mat deltaH;
147 deltaH.zeros(r, m);
148 for (size_t j = 0; j < m; ++j)
149 {
150 for (size_t i = 0; i < n; ++i)
151 {
152 const double val = V(i, j);
153 if (val != 0)
154 deltaH.col(j) += (val - arma::dot(W.row(i), H.col(j))) * W.row(i).t();
155 }
156 // Add regularization.
157 if (kh != 0)
158 deltaH.col(j) -= kh * H.col(j);
159 }
160
161 // Add this step to the momentum.
162 mH += u * deltaH;
163 // Add the momentum to H.
164 H += mH;
165 }
166
168 template<typename Archive>
169 void serialize(Archive& ar, const unsigned int /* version */)
170 {
171 ar & BOOST_SERIALIZATION_NVP(u);
172 ar & BOOST_SERIALIZATION_NVP(kw);
173 ar & BOOST_SERIALIZATION_NVP(kh);
174 ar & BOOST_SERIALIZATION_NVP(momentum);
175 ar & BOOST_SERIALIZATION_NVP(mW);
176 ar & BOOST_SERIALIZATION_NVP(mH);
177 }
178
179 private:
//! Step size: scales the computed step before it is added to the momentum.
181 double u;
//! Regularization weight for W; regularization is skipped when this is 0.
183 double kw;
//! Regularization weight for H; regularization is skipped when this is 0.
185 double kh;
//! Factor applied to the stored momentum at the start of every update.
187 double momentum;
188
//! Momentum matrix for W; zeroed by Initialize(), added to W in WUpdate().
190 arma::mat mW;
//! Momentum matrix for H; zeroed by Initialize(), added to H in HUpdate().
192 arma::mat mH;
193}; // class SVDBatchLearning
194
197
201template<>
202inline void SVDBatchLearning::WUpdate<arma::sp_mat>(const arma::sp_mat& V,
203 arma::mat& W,
204 const arma::mat& H)
205{
206 const size_t n = V.n_rows;
207 const size_t r = W.n_cols;
208
209 mW = momentum * mW;
210
211 arma::mat deltaW;
212 deltaW.zeros(n, r);
213
214 for (arma::sp_mat::const_iterator it = V.begin(); it != V.end(); ++it)
215 {
216 const size_t row = it.row();
217 const size_t col = it.col();
218 deltaW.row(it.row()) += (*it - arma::dot(W.row(row), H.col(col))) *
219 arma::trans(H.col(col));
220 }
221
222 if (kw != 0)
223 deltaW -= kw * W;
224
225 mW += u * deltaW;
226 W += mW;
227}
228
229template<>
230inline void SVDBatchLearning::HUpdate<arma::sp_mat>(const arma::sp_mat& V,
231 const arma::mat& W,
232 arma::mat& H)
233{
234 const size_t m = V.n_cols;
235 const size_t r = W.n_cols;
236
237 mH = momentum * mH;
238
239 arma::mat deltaH;
240 deltaH.zeros(r, m);
241
242 for (arma::sp_mat::const_iterator it = V.begin(); it != V.end(); ++it)
243 {
244 const size_t row = it.row();
245 const size_t col = it.col();
246 deltaH.col(col) += (*it - arma::dot(W.row(row), H.col(col))) *
247 W.row(row).t();
248 }
249
250 if (kh != 0)
251 deltaH -= kh * H;
252
253 mH += u * deltaH;
254 H += mH;
255}
256
257} // namespace amf
258} // namespace mlpack
259
260#endif // MLPACK_METHODS_AMF_UPDATE_RULES_SVD_BATCH_LEARNING_HPP
This class implements SVD batch learning with momentum.
void Initialize(const MatType &dataset, const size_t rank)
Initialize parameters before factorization.
void WUpdate(const MatType &V, arma::mat &W, const arma::mat &H)
The update rule for the basis matrix W.
void HUpdate(const MatType &V, const arma::mat &W, arma::mat &H)
The update rule for the encoding matrix H.
SVDBatchLearning(double u=0.0002, double kw=0, double kh=0, double momentum=0.9)
SVD Batch learning constructor.
void serialize(Archive &ar, const unsigned int)
Serialize the SVDBatchLearning object.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.