mlpack 3.4.2
cosine_search.hpp
Go to the documentation of this file.
1
12#ifndef MLPACK_METHODS_CF_COSINE_SEARCH_HPP
13#define MLPACK_METHODS_CF_COSINE_SEARCH_HPP
14
15#include <mlpack/prereqs.hpp>
17
18namespace mlpack {
19namespace cf {
20
45{
46 public:
53 CosineSearch(const arma::mat& referenceSet)
54 {
55 // Normalize all vectors to unit length.
56 arma::mat normalizedSet = arma::normalise(referenceSet, 2, 0);
57
58 neighborSearch.Train(std::move(normalizedSet));
59 }
60
70 void Search(const arma::mat& query, const size_t k,
71 arma::Mat<size_t>& neighbors, arma::mat& similarities)
72 {
73 // Normalize query vectors to unit length.
74 arma::mat normalizedQuery = arma::normalise(query, 2, 0);
75
76 neighborSearch.Search(normalizedQuery, k, neighbors, similarities);
77
78 // Resulting similarities from Search() are Euclidean distance.
79 // For unit vectors a and b, cos(a, b) = 1 - dis(a, b) ^ 2 / 2,
80 // where dis(a, b) is Euclidean distance.
81 // Furthermore, we restrict the range of similarity to be [0, 1]:
82 // similarities = (cos(a,b) + 1) / 2.0. As a result we have the following
83 // formula.
84 similarities = 1 - arma::pow(similarities, 2) / 4.0;
85 }
86
87 private:
89 neighbor::KNN neighborSearch;
90};
91
92} // namespace cf
93} // namespace mlpack
94
95#endif
Nearest neighbor search with cosine distance.
CosineSearch(const arma::mat &referenceSet)
Constructor with reference set.
void Search(const arma::mat &query, const size_t k, arma::Mat< size_t > &neighbors, arma::mat &similarities)
Given a set of query points, find the nearest k neighbors, and return similarities.
The NeighborSearch class is a template class for performing distance-based neighbor searches.
void Search(const MatType &querySet, const size_t k, arma::Mat< size_t > &neighbors, arma::mat &distances)
For each point in the query set, compute the nearest neighbors and store the output in the given matrices.
void Train(MatType referenceSet)
Set the reference set to a new reference set, and build a tree if necessary.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.