mlpack 3.4.2
cf_model.hpp
Go to the documentation of this file.
1
13#ifndef MLPACK_METHODS_CF_CF_MODEL_HPP
14#define MLPACK_METHODS_CF_CF_MODEL_HPP
15
16#include <mlpack/core.hpp>
17#include <boost/variant.hpp>
18#include "cf.hpp"
19
27
33
34namespace mlpack {
35namespace cf {
36
41class DeleteVisitor : public boost::static_visitor<void>
42{
43 public:
45 template <typename DecompositionPolicy,
46 typename NormalizationType = NoNormalization>
48};
49
53class GetValueVisitor : public boost::static_visitor<void*>
54{
55 public:
57 template <typename DecompositionPolicy,
58 typename NormalizationType = NoNormalization>
60};
61
66template <typename NeighborSearchPolicy,
67 typename InterpolationPolicy>
68class PredictVisitor : public boost::static_visitor<void>
69{
70 private:
72 const arma::Mat<size_t>& combinations;
74 arma::vec& predictions;
75
76 public:
78 template <typename DecompositionPolicy,
79 typename NormalizationType = NoNormalization>
81
83 PredictVisitor(const arma::Mat<size_t>& combinations,
84 arma::vec& predictions);
85};
86
91template <typename NeighborSearchPolicy,
92 typename InterpolationPolicy>
93class RecommendationVisitor : public boost::static_visitor<void>
94{
95 private:
97 const size_t numRecs;
99 arma::Mat<size_t>& recommendations;
101 const arma::Col<size_t>& users;
103 const bool usersGiven;
104
105 public:
107 RecommendationVisitor(const size_t numRecs,
108 arma::Mat<size_t>& recommendations,
109 const arma::Col<size_t>& users,
110 const bool usersGiven);
111
113 template <typename DecompositionPolicy,
114 typename NormalizationType = NoNormalization>
116};
117
// CFModel: the model to save to disk.
// NOTE(review): the line carrying the `class CFModel` head is absent from this
// extracted listing (hyperlinked declaration lines were dropped); restore it
// from the real header before compiling.
122{
123 private:
// Variant over pointers to CFType<> instantiations. Only the first alternative
// is visible in this extract.
// NOTE(review): the remaining alternatives and the member name are missing
// here (the cross-references refer to the member as `cf`); do not reconstruct
// the list by guesswork -- the order of variant alternatives may matter to
// serialization, so take it from the real header.
130 boost::variant<CFType<NMFPolicy, NoNormalization>*,
138
147
156
165
174
175 public:
// NOTE(review): the declaration of CFModel() ("Create an empty CF model",
// listing line 177) is absent from this extract.
178
// NOTE(review): the declaration of ~CFModel() ("Clean up memory") is absent
// from this extract.
181
// Template header for CFPtr() const, which gets the pointer to the CFType<>
// object.
// NOTE(review): the declaration line itself
// (`const CFType<DecompositionPolicy, NormalizationType>* CFPtr() const;`)
// is absent from this extract.
183 template <typename DecompositionPolicy,
184 typename NormalizationType = NoNormalization>
186
// Train the model. normalizationType defaults to "none".
188 template<typename DecompositionPolicy,
189 typename MatType>
190 void Train(const MatType& data,
191 const size_t numUsersForSimilarity,
192 const size_t rank,
193 const size_t maxIterations,
194 const double minResidue,
195 const bool mit,
196 const std::string& normalizationType = "none");
197
// Make predictions. Results are written through the `predictions` reference.
199 template <typename NeighborSearchPolicy,
200 typename InterpolationPolicy>
201 void Predict(const arma::Mat<size_t>& combinations,
202 arma::vec& predictions);
203
// Compute recommendations for query users (the `users` argument). Results are
// written through the `recommendations` reference.
205 template<typename NeighborSearchPolicy,
206 typename InterpolationPolicy>
207 void GetRecommendations(const size_t numRecs,
208 arma::Mat<size_t>& recommendations,
209 const arma::Col<size_t>& users);
210
// Compute recommendations for all users. Results are written through the
// `recommendations` reference.
212 template<typename NeighborSearchPolicy,
213 typename InterpolationPolicy>
214 void GetRecommendations(const size_t numRecs,
215 arma::Mat<size_t>& recommendations);
216
// Serialize the model. The version argument is unnamed and therefore unused.
218 template<typename Archive>
219 void serialize(Archive& ar, const unsigned int /* version */);
220};
221
222} // namespace cf
223} // namespace mlpack
224
225// Include implementation.
226#include "cf_model_impl.hpp"
227
228#endif
The model to save to disk.
Definition: cf_model.hpp:122
void GetRecommendations(const size_t numRecs, arma::Mat< size_t > &recommendations, const arma::Col< size_t > &users)
Compute recommendations for query users.
CFModel()
Create an empty CF model.
Definition: cf_model.hpp:177
void Predict(const arma::Mat< size_t > &combinations, arma::vec &predictions)
Make predictions.
~CFModel()
Clean up memory.
const CFType< DecompositionPolicy, NormalizationType > * CFPtr() const
Get the pointer to CFType<> object.
void Train(const MatType &data, const size_t numUsersForSimilarity, const size_t rank, const size_t maxIterations, const double minResidue, const bool mit, const std::string &normalizationType="none")
Train the model.
void GetRecommendations(const size_t numRecs, arma::Mat< size_t > &recommendations)
Compute recommendations for all users.
void serialize(Archive &ar, const unsigned int)
Serialize the model.
This class implements Collaborative Filtering (CF).
Definition: cf.hpp:71
DeleteVisitor deletes the CFType<> object which is pointed to by the variable cf in class CFModel.
Definition: cf_model.hpp:42
void operator()(CFType< DecompositionPolicy, NormalizationType > *c) const
Delete CFType object.
GetValueVisitor returns the pointer which points to the CFType object.
Definition: cf_model.hpp:54
void * operator()(CFType< DecompositionPolicy, NormalizationType > *c) const
Return stored pointer as void* type.
This normalization class doesn't perform any normalization.
PredictVisitor uses the CFType object to make predictions on the given combinations of users and items.
Definition: cf_model.hpp:69
PredictVisitor(const arma::Mat< size_t > &combinations, arma::vec &predictions)
Visitor constructor.
void operator()(CFType< DecompositionPolicy, NormalizationType > *c) const
Predict ratings for each user-item combination.
RecommendationVisitor uses the CFType object to get recommendations for the given users.
Definition: cf_model.hpp:94
RecommendationVisitor(const size_t numRecs, arma::Mat< size_t > &recommendations, const arma::Col< size_t > &users, const bool usersGiven)
Visitor constructor.
void operator()(CFType< DecompositionPolicy, NormalizationType > *c) const
Generates the given number of recommendations.
Include all of the base components required to write mlpack methods, and the main mlpack Doxygen documentation.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1