mlpack 3.4.2
svd_incomplete_incremental_learning.hpp
Go to the documentation of this file.
1
12#ifndef MLPACK_METHODS_AMF_SVD_INCOMPLETE_INCREMENTAL_LEARNING_HPP
13#define MLPACK_METHODS_AMF_SVD_INCOMPLETE_INCREMENTAL_LEARNING_HPP
14
15namespace mlpack
16{
17namespace amf
18{
19
44{
45 public:
// Constructor. Its opening line, which declares the step size u, is not shown
// in this listing; the listed defaults are u = 0.001, kw = 0, kh = 0.
// Stores the step size u and the regularization constants kw (for W) and
// kh (for H). currentUserIndex is not assigned by this initializer list;
// Initialize() (re)sets it to 0.
54 double kw = 0,
55 double kh = 0)
56 : u(u), kw(kw), kh(kh)
57 {
58 // Nothing to do.
59 }
60
69 template<typename MatType>
70 void Initialize(const MatType& /* dataset */, const size_t /* rank */)
71 {
72 // Set the current user to 0.
73 currentUserIndex = 0;
74 }
75
85 template<typename MatType>
86 inline void WUpdate(const MatType& V,
87 arma::mat& W,
88 const arma::mat& H)
89 {
90 arma::mat deltaW;
91 deltaW.zeros(V.n_rows, W.n_cols);
92
93 // Iterate through all the rating by this user to update corresponding item
94 // feature feature vector.
95 for (size_t i = 0; i < V.n_rows; ++i)
96 {
97 const double val = V(i, currentUserIndex);
98 // Update only if the rating is non-zero.
99 if (val != 0)
100 {
101 deltaW.row(i) += (val - arma::dot(W.row(i), H.col(currentUserIndex))) *
102 H.col(currentUserIndex).t();
103 }
104 // Add regularization.
105 if (kw != 0)
106 deltaW.row(i) -= kw * W.row(i);
107 }
108
109 W += u * deltaW;
110 }
111
120 template<typename MatType>
121 inline void HUpdate(const MatType& V,
122 const arma::mat& W,
123 arma::mat& H)
124 {
125 arma::vec deltaH;
126 deltaH.zeros(H.n_rows);
127
128 // Iterate through all the rating by this user to update corresponding item
129 // feature feature vector.
130 for (size_t i = 0; i < V.n_rows; ++i)
131 {
132 const double val = V(i, currentUserIndex);
133 // Update only if the rating is non-zero.
134 if (val != 0)
135 {
136 deltaH += (val - arma::dot(W.row(i), H.col(currentUserIndex))) *
137 W.row(i).t();
138 }
139 }
140 // Add regularization.
141 if (kh != 0)
142 deltaH -= kh * H.col(currentUserIndex);
143
144 // Update H matrix and move on to the next user.
145 H.col(currentUserIndex++) += u * deltaH;
146 currentUserIndex = currentUserIndex % V.n_cols;
147 }
148
149 private:
151 double u;
153 double kw;
155 double kh;
156
158 size_t currentUserIndex;
159};
160
163
165template<>
166inline void SVDIncompleteIncrementalLearning::WUpdate<arma::sp_mat>(
167 const arma::sp_mat& V, arma::mat& W, const arma::mat& H)
168{
169 arma::mat deltaW(V.n_rows, W.n_cols);
170 deltaW.zeros();
171 for (arma::sp_mat::const_iterator it = V.begin_col(currentUserIndex);
172 it != V.end_col(currentUserIndex); ++it)
173 {
174 double val = *it;
175 size_t i = it.row();
176 deltaW.row(i) += (val - arma::dot(W.row(i), H.col(currentUserIndex))) *
177 arma::trans(H.col(currentUserIndex));
178 if (kw != 0) deltaW.row(i) -= kw * W.row(i);
179 }
180
181 W += u*deltaW;
182}
183
184template<>
185inline void SVDIncompleteIncrementalLearning::HUpdate<arma::sp_mat>(
186 const arma::sp_mat& V, const arma::mat& W, arma::mat& H)
187{
188 arma::mat deltaH(H.n_rows, 1);
189 deltaH.zeros();
190
191 for (arma::sp_mat::const_iterator it = V.begin_col(currentUserIndex);
192 it != V.end_col(currentUserIndex); ++it)
193 {
194 double val = *it;
195 size_t i = it.row();
196 if ((val = V(i, currentUserIndex)) != 0)
197 {
198 deltaH += (val - arma::dot(W.row(i), H.col(currentUserIndex))) *
199 arma::trans(W.row(i));
200 }
201 }
202 if (kh != 0) deltaH -= kh * H.col(currentUserIndex);
203
204 H.col(currentUserIndex++) += u * deltaH;
205 currentUserIndex = currentUserIndex % V.n_cols;
206}
207
208} // namespace amf
209} // namespace mlpack
210
211#endif
This class computes SVD using incomplete incremental batch learning, as described in the following pa...
void WUpdate(const MatType &V, arma::mat &W, const arma::mat &H)
The update rule for the basis matrix W.
void HUpdate(const MatType &V, const arma::mat &W, arma::mat &H)
The update rule for the encoding matrix H.
void Initialize(const MatType &, const size_t)
Initialize parameters before factorization.
SVDIncompleteIncrementalLearning(double u=0.001, double kw=0, double kh=0)
Initialize the parameters of SVDIncompleteIncrementalLearning.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1