mlpack 3.4.2
svd_complete_incremental_learning.hpp
Go to the documentation of this file.
1
12#ifndef MLPACK_METHODS_AMF_SVD_COMPLETE_INCREMENTAL_LEARNING_HPP
13#define MLPACK_METHODS_AMF_SVD_COMPLETE_INCREMENTAL_LEARNING_HPP
14
15#include <mlpack/prereqs.hpp>
16
17namespace mlpack
18{
19namespace amf
20{
21
44template <class MatType>
46{
47 public:
57 double kw = 0,
58 double kh = 0)
59 : u(u), kw(kw), kh(kh)
60 {
61 // Nothing to do.
62 }
63
72 void Initialize(const MatType& /* dataset */, const size_t /* rank */)
73 {
74 // Initialize the current score counters.
75 currentUserIndex = 0;
76 currentItemIndex = 0;
77 }
78
87 inline void WUpdate(const MatType& V,
88 arma::mat& W,
89 const arma::mat& H)
90 {
91 arma::mat deltaW;
92 deltaW.zeros(1, W.n_cols);
93
94 // Loop until a non-zero entry is found.
95 while (true)
96 {
97 const double val = V(currentItemIndex, currentUserIndex);
98 // Update feature vector if current entry is non-zero and break the loop.
99 if (val != 0)
100 {
101 deltaW += (val - arma::dot(W.row(currentItemIndex),
102 H.col(currentUserIndex))) * H.col(currentUserIndex).t();
103
104 // Add regularization.
105 if (kw != 0)
106 deltaW -= kw * W.row(currentItemIndex);
107 break;
108 }
109 }
110
111 W.row(currentItemIndex) += u * deltaW;
112 }
113
123 inline void HUpdate(const MatType& V,
124 const arma::mat& W,
125 arma::mat& H)
126 {
127 arma::mat deltaH;
128 deltaH.zeros(H.n_rows, 1);
129
130 const double val = V(currentItemIndex, currentUserIndex);
131
132 // Update H matrix based on the non-zero entry found in WUpdate function.
133 deltaH += (val - arma::dot(W.row(currentItemIndex),
134 H.col(currentUserIndex))) * W.row(currentItemIndex).t();
135 // Add regularization.
136 if (kh != 0)
137 deltaH -= kh * H.col(currentUserIndex);
138
139 // Move on to the next entry.
140 currentUserIndex = currentUserIndex + 1;
141 if (currentUserIndex == V.n_rows)
142 {
143 currentUserIndex = 0;
144 currentItemIndex = (currentItemIndex + 1) % V.n_cols;
145 }
146
147 H.col(currentUserIndex++) += u * deltaH;
148 }
149
150 private:
152 double u;
154 double kw;
156 double kh;
157
159 size_t currentUserIndex;
161 size_t currentItemIndex;
162};
163
166
168template<>
170{
171 public:
173 double kw = 0,
174 double kh = 0)
175 : u(u), kw(kw), kh(kh), it(NULL)
176 {}
177
179 {
180 delete it;
181 }
182
183 void Initialize(const arma::sp_mat& dataset, const size_t rank)
184 {
185 (void)rank;
186 n = dataset.n_rows;
187 m = dataset.n_cols;
188
189 it = new arma::sp_mat::const_iterator(dataset.begin());
190 isStart = true;
191 }
192
202 inline void WUpdate(const arma::sp_mat& V,
203 arma::mat& W,
204 const arma::mat& H)
205 {
206 if (!isStart)
207 ++(*it);
208 else isStart = false;
209
210 if (*it == V.end())
211 {
212 delete it;
213 it = new arma::sp_mat::const_iterator(V.begin());
214 }
215
216 size_t currentUserIndex = it->col();
217 size_t currentItemIndex = it->row();
218
219 arma::mat deltaW(1, W.n_cols);
220 deltaW.zeros();
221
222 deltaW += (**it - arma::dot(W.row(currentItemIndex),
223 H.col(currentUserIndex))) * arma::trans(H.col(currentUserIndex));
224 if (kw != 0) deltaW -= kw * W.row(currentItemIndex);
225
226 W.row(currentItemIndex) += u*deltaW;
227 }
228
238 inline void HUpdate(const arma::sp_mat& /* V */,
239 const arma::mat& W,
240 arma::mat& H)
241 {
242 arma::mat deltaH(H.n_rows, 1);
243 deltaH.zeros();
244
245 size_t currentUserIndex = it->col();
246 size_t currentItemIndex = it->row();
247
248 deltaH += (**it - arma::dot(W.row(currentItemIndex),
249 H.col(currentUserIndex))) * arma::trans(W.row(currentItemIndex));
250 if (kh != 0) deltaH -= kh * H.col(currentUserIndex);
251
252 H.col(currentUserIndex) += u * deltaH;
253 }
254
255 private:
256 double u;
257 double kw;
258 double kh;
259
260 size_t n;
261 size_t m;
262
263 arma::sp_mat dummy;
264 arma::sp_mat::const_iterator* it;
265
266 bool isStart;
267}; // class SVDCompleteIncrementalLearning
268
269} // namespace amf
270} // namespace mlpack
271
272#endif
void WUpdate(const arma::sp_mat &V, arma::mat &W, const arma::mat &H)
The update rule for the basis matrix W.
void Initialize(const arma::sp_mat &dataset, const size_t rank)
void HUpdate(const arma::sp_mat &, const arma::mat &W, arma::mat &H)
The update rule for the encoding matrix H.
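This class computes SVD using complete incremental batch learning, as described in the following paper: Chih-Chao Ma, "A Guide to Singular Value Decomposition for Collaborative Filtering", Department of Computer Science, National Taiwan University, 2008.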
void HUpdate(const MatType &V, const arma::mat &W, arma::mat &H)
The update rule for the encoding matrix H.
SVDCompleteIncrementalLearning(double u=0.0001, double kw=0, double kh=0)
Initialize the SVDCompleteIncrementalLearning class with the given parameters.
void WUpdate(const MatType &V, arma::mat &W, const arma::mat &H)
The update rule for the basis matrix W.
void Initialize(const MatType &, const size_t)
Initialize parameters before factorization.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.