mlpack 3.4.2
regression_interpolation.hpp
Go to the documentation of this file.
1
12#ifndef MLPACK_METHODS_CF_REGRESSION_INTERPOLATION_HPP
13#define MLPACK_METHODS_CF_REGRESSION_INTERPOLATION_HPP
14
15#include <mlpack/prereqs.hpp>
16
17namespace mlpack {
18namespace cf {
19
57{
58 public:
63
69 RegressionInterpolation(const arma::sp_mat& cleanedData)
70 {
71 const size_t userNum = cleanedData.n_cols;
72 a.set_size(userNum, userNum);
73 b.set_size(userNum, userNum);
74 }
75
93 template <typename VectorType,
94 typename DecompositionPolicy>
95 void GetWeights(VectorType&& weights,
96 const DecompositionPolicy& decomposition,
97 const size_t queryUser,
98 const arma::Col<size_t>& neighbors,
99 const arma::vec& /* similarities*/,
100 const arma::sp_mat& cleanedData)
101 {
102 if (weights.n_elem != neighbors.n_elem)
103 {
104 Log::Fatal << "The size of the first parameter (weights) should "
105 << "be set to the number of neighbors before calling GetWeights()."
106 << std::endl;
107 }
108
109 const arma::mat& w = decomposition.W();
110 const arma::mat& h = decomposition.H();
111 const size_t itemNum = cleanedData.n_rows;
112 const size_t neighborNum = neighbors.size();
113
114 // Coeffcients of the linear equations used to compute weights.
115 arma::mat coeff(neighborNum, neighborNum);
116 // Constant terms of the linear equations used to compute weights.
117 arma::vec constant(neighborNum);
118
119 arma::vec userRating(cleanedData.col(queryUser));
120 const size_t support = arma::accu(userRating != 0);
121
122 // If user has no rating at all, average interpolation is used.
123 if (support == 0)
124 {
125 weights.fill(1.0 / neighbors.n_elem);
126 return;
127 }
128
129 for (size_t i = 0; i < neighborNum; ++i)
130 {
131 // Calculate coefficient.
132 arma::vec iPrediction;
133 for (size_t j = i; j < neighborNum; ++j)
134 {
135 if (a(neighbors(i), neighbors(j)) != 0)
136 {
137 // The coefficient has already been cached.
138 coeff(i, j) = a(neighbors(i), neighbors(j));
139 coeff(j, i) = coeff(i, j);
140 }
141 else
142 {
143 // Calculate the coefficient.
144 if (iPrediction.size() == 0)
145 // Avoid recalculation of iPrediction.
146 iPrediction = w * h.col(neighbors(i));
147 arma::vec jPrediction = w * h.col(neighbors(j));
148 coeff(i, j) = arma::dot(iPrediction, jPrediction) / itemNum;
149 if (coeff(i, j) == 0)
150 coeff(i, j) = std::numeric_limits<double>::min();
151 coeff(j, i) = coeff(i, j);
152 // Cache calcualted coefficient.
153 a(neighbors(i), neighbors(j)) = coeff(i, j);
154 a(neighbors(j), neighbors(i)) = coeff(i, j);
155 }
156 }
157
158 // Calculate constant terms.
159 if (b(neighbors(i), queryUser) != 0)
160 // The constant term has already been cached.
161 constant(i) = b(neighbors(i), queryUser);
162 else
163 {
164 // Calcuate the constant term.
165 if (iPrediction.size() == 0)
166 // Avoid recalculation of iPrediction.
167 iPrediction = w * h.col(neighbors(i));
168 constant(i) = arma::dot(iPrediction, userRating) / support;
169 if (constant(i) == 0)
170 constant(i) = std::numeric_limits<double>::min();
171 // Cache calculated constant term.
172 b(neighbors(i), queryUser) = constant(i);
173 }
174 }
175 weights = arma::solve(coeff, constant);
176 }
177
178 private:
180 arma::sp_mat a;
182 arma::sp_mat b;
183};
184
185} // namespace cf
186} // namespace mlpack
187
188#endif
static MLPACK_EXPORT util::PrefixedOutStream Fatal
Prints fatal messages prefixed with [FATAL], then terminates the program.
Definition: log.hpp:90
Implementation of regression-based interpolation method.
RegressionInterpolation(const arma::sp_mat &cleanedData)
Use cleanedData to perform necessary preprocessing.
void GetWeights(VectorType &&weights, const DecompositionPolicy &decomposition, const size_t queryUser, const arma::Col< size_t > &neighbors, const arma::vec &, const arma::sp_mat &cleanedData)
The regression-based interpolation problem can be solved by a linear system of equations.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.