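/**
 * @file neighbor_search_rules.hpp
 *
 * mlpack 3.4.2: declares the NeighborSearchRules helper class used by
 * NeighborSearch.  See the generated documentation of this file for details.
 */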
#ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_RULES_HPP
#define MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_RULES_HPP

#include <mlpack/core/tree/traversal_info.hpp>

#include <queue>
#include <utility>
#include <vector>

20namespace mlpack {
21namespace neighbor {
22
34template<typename SortPolicy, typename MetricType, typename TreeType>
36{
37 public:
50 NeighborSearchRules(const typename TreeType::Mat& referenceSet,
51 const typename TreeType::Mat& querySet,
52 const size_t k,
53 MetricType& metric,
54 const double epsilon = 0,
55 const bool sameSet = false);
56
64 void GetResults(arma::Mat<size_t>& neighbors, arma::mat& distances);
65
74 double BaseCase(const size_t queryIndex, const size_t referenceIndex);
75
84 double Score(const size_t queryIndex, TreeType& referenceNode);
85
92 size_t GetBestChild(const size_t queryIndex, TreeType& referenceNode);
93
100 size_t GetBestChild(const TreeType& queryNode, TreeType& referenceNode);
101
113 double Rescore(const size_t queryIndex,
114 TreeType& referenceNode,
115 const double oldScore) const;
116
125 double Score(TreeType& queryNode, TreeType& referenceNode);
126
138 double Rescore(TreeType& queryNode,
139 TreeType& referenceNode,
140 const double oldScore) const;
141
143 size_t BaseCases() const { return baseCases; }
145 size_t& BaseCases() { return baseCases; }
146
148 size_t Scores() const { return scores; }
150 size_t& Scores() { return scores; }
151
154
159
162 size_t MinimumBaseCases() const { return k; }
163
164 protected:
166 const typename TreeType::Mat& referenceSet;
167
169 const typename TreeType::Mat& querySet;
170
172 typedef std::pair<double, size_t> Candidate;
173
176 bool operator()(const Candidate& c1, const Candidate& c2)
177 {
178 return !SortPolicy::IsBetter(c2.first, c1.first);
179 };
180 };
181
183 typedef std::priority_queue<Candidate, std::vector<Candidate>, CandidateCmp>
185
187 std::vector<CandidateList> candidates;
188
190 const size_t k;
191
193 MetricType& metric;
194
197
199 const double epsilon;
200
207
209 size_t baseCases;
211 size_t scores;
212
216
220 double CalculateBound(TreeType& queryNode) const;
221
229 void InsertNeighbor(const size_t queryIndex,
230 const size_t neighbor,
231 const double distance);
232};
233
234} // namespace neighbor
235} // namespace mlpack
236
237// Include implementation.
238#include "neighbor_search_rules_impl.hpp"
239
240#endif // MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_RULES_HPP
The NeighborSearchRules class is a template helper class used by the NeighborSearch class when performing distance-based neighbor searches.
double Score(TreeType &queryNode, TreeType &referenceNode)
Get the score for recursion order.
size_t baseCases
The number of base cases that have been performed.
std::pair< double, size_t > Candidate
Candidate represents a possible candidate neighbor (distance, index).
size_t GetBestChild(const size_t queryIndex, TreeType &referenceNode)
Get the child node with the best score.
std::vector< CandidateList > candidates
Set of candidate neighbors for each point.
const TreeType::Mat & querySet
The query set.
size_t BaseCases() const
Get the number of base cases that have been performed.
size_t MinimumBaseCases() const
Get the minimum number of base cases we need to perform to have acceptable results.
size_t & BaseCases()
Modify the number of base cases that have been performed.
size_t Scores() const
Get the number of scores that have been performed.
bool sameSet
Denotes whether or not the reference and query sets are the same.
double lastBaseCase
The last base case result.
double Rescore(TreeType &queryNode, TreeType &referenceNode, const double oldScore) const
Re-evaluate the score for recursion order.
double Score(const size_t queryIndex, TreeType &referenceNode)
Get the score for recursion order.
const TraversalInfoType & TraversalInfo() const
Get the traversal info.
TraversalInfoType & TraversalInfo()
Modify the traversal info.
void InsertNeighbor(const size_t queryIndex, const size_t neighbor, const double distance)
Helper function to insert a point into the list of candidate points.
NeighborSearchRules(const typename TreeType::Mat &referenceSet, const typename TreeType::Mat &querySet, const size_t k, MetricType &metric, const double epsilon=0, const bool sameSet=false)
Construct the NeighborSearchRules object.
size_t lastQueryIndex
The last query point BaseCase() was called with.
std::priority_queue< Candidate, std::vector< Candidate >, CandidateCmp > CandidateList
Use a priority queue to represent the list of candidate neighbors.
size_t GetBestChild(const TreeType &queryNode, TreeType &referenceNode)
Get the child node with the best score.
double CalculateBound(TreeType &queryNode) const
Recalculate the bound for a given query node.
TraversalInfoType traversalInfo
Traversal info for the parent combination; this is updated by the traversal before each call to Score().
tree::TraversalInfo< TreeType > TraversalInfoType
Convenience typedef.
double BaseCase(const size_t queryIndex, const size_t referenceIndex)
Get the distance from the query point to the reference point.
MetricType & metric
The instantiated metric.
const double epsilon
Relative error to be considered in approximate search.
size_t lastReferenceIndex
The last reference point BaseCase() was called with.
void GetResults(arma::Mat< size_t > &neighbors, arma::mat &distances)
Store the list of candidates for each query point in the given matrices.
double Rescore(const size_t queryIndex, TreeType &referenceNode, const double oldScore) const
Re-evaluate the score for recursion order.
const TreeType::Mat & referenceSet
The reference set.
size_t scores
The number of scores that have been performed.
size_t & Scores()
Modify the number of scores that have been performed.
const size_t k
Number of neighbors to search for.
The TraversalInfo class holds traversal information which is used in dual-tree (and single-tree) traversals.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
Compare two candidates based on the distance.
bool operator()(const Candidate &c1, const Candidate &c2)