mlpack 3.4.2
nmf_mult_div.hpp
Go to the documentation of this file.
1
12#ifndef MLPACK_METHODS_LMF_UPDATE_RULES_NMF_MULT_DIV_HPP
13#define MLPACK_METHODS_LMF_UPDATE_RULES_NMF_MULT_DIV_HPP
14
15#include <mlpack/prereqs.hpp>
16
17namespace mlpack {
18namespace amf {
19
49{
50 public:
51 // Empty constructor required for the WUpdateRule template.
53
/**
 * Initialize the factorization.  This update rule keeps no state, so both
 * arguments are ignored and the call has no effect.
 *
 * @tparam MatType Type of the input matrix.
 */
template<typename MatType>
void Initialize(const MatType& /* dataset */,
                const size_t /* rank */)
{ /* Intentionally empty: there is nothing to set up. */ }
63
79 template<typename MatType>
80 inline static void WUpdate(const MatType& V,
81 arma::mat& W,
82 const arma::mat& H)
83 {
84 // Simple implementation left in the header file.
85 arma::mat t1;
86 arma::rowvec t2;
87
88 t1 = W * H;
89 for (size_t i = 0; i < W.n_rows; ++i)
90 {
91 for (size_t j = 0; j < W.n_cols; ++j)
92 {
93 // Writing this as a single expression does not work as of Armadillo
94 // 3.920. This should be fixed in a future release, and then the code
95 // below can be fixed.
96 // t2 = H.row(j) % V.row(i) / t1.row(i);
97 t2.set_size(H.n_cols);
98 for (size_t k = 0; k < t2.n_elem; ++k)
99 {
100 t2(k) = H(j, k) * V(i, k) / t1(i, k);
101 }
102
103 W(i, j) = W(i, j) * sum(t2) / sum(H.row(j));
104 }
105 }
106 }
107
123 template<typename MatType>
124 inline static void HUpdate(const MatType& V,
125 const arma::mat& W,
126 arma::mat& H)
127 {
128 // Simple implementation left in the header file.
129 arma::mat t1;
130 arma::colvec t2;
131
132 t1 = W * H;
133 for (size_t i = 0; i < H.n_rows; ++i)
134 {
135 for (size_t j = 0; j < H.n_cols; ++j)
136 {
137 // Writing this as a single expression does not work as of Armadillo
138 // 3.920. This should be fixed in a future release, and then the code
139 // below can be fixed.
140 // t2 = W.col(i) % V.col(j) / t1.col(j);
141 t2.set_size(W.n_rows);
142 for (size_t k = 0; k < t2.n_elem; ++k)
143 {
144 t2(k) = W(k, i) * V(k, j) / t1(k, j);
145 }
146
147 H(i, j) = H(i, j) * sum(t2) / sum(W.col(i));
148 }
149 }
150 }
151
/**
 * Serialize the object.  This class holds no data, so both arguments are
 * ignored and nothing is read from or written to the archive.
 *
 * @tparam Archive Type of the archive.
 */
template<typename Archive>
void serialize(Archive& /* ar */, const unsigned int /* version */)
{
  // Nothing to serialize.
}
155};
156
157} // namespace amf
158} // namespace mlpack
159
160#endif
This follows a method described in the paper 'Algorithms for Non-negative Matrix Factorization' by D. D. Lee and H. S. Seung.
static void WUpdate(const MatType &V, arma::mat &W, const arma::mat &H)
The update rule for the basis matrix W.
void serialize(Archive &, const unsigned int)
Serialize the object (in this case, there is nothing to serialize).
void Initialize(const MatType &, const size_t)
Initialize the factorization.
static void HUpdate(const MatType &V, const arma::mat &W, arma::mat &H)
The update rule for the encoding matrix H.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.