mlpack 3.4.2
ra_search_rules.hpp
Go to the documentation of this file.
1
14#ifndef MLPACK_METHODS_RANN_RA_SEARCH_RULES_HPP
15#define MLPACK_METHODS_RANN_RA_SEARCH_RULES_HPP
16
18
19#include <queue>
20
21namespace mlpack {
22namespace neighbor {
23
32template<typename SortPolicy, typename MetricType, typename TreeType>
34{
35 public:
57 RASearchRules(const arma::mat& referenceSet,
58 const arma::mat& querySet,
59 const size_t k,
60 MetricType& metric,
61 const double tau = 5,
62 const double alpha = 0.95,
63 const bool naive = false,
64 const bool sampleAtLeaves = false,
65 const bool firstLeafExact = false,
66 const size_t singleSampleLimit = 20,
67 const bool sameSet = false);
68
76 void GetResults(arma::Mat<size_t>& neighbors, arma::mat& distances);
77
85 double BaseCase(const size_t queryIndex, const size_t referenceIndex);
86
109 double Score(const size_t queryIndex, TreeType& referenceNode);
110
134 double Score(const size_t queryIndex,
135 TreeType& referenceNode,
136 const double baseCaseResult);
137
155 double Rescore(const size_t queryIndex,
156 TreeType& referenceNode,
157 const double oldScore);
158
177 double Score(TreeType& queryNode, TreeType& referenceNode);
178
199 double Score(TreeType& queryNode,
200 TreeType& referenceNode,
201 const double baseCaseResult);
202
225 double Rescore(TreeType& queryNode,
226 TreeType& referenceNode,
227 const double oldScore);
228
229
230 size_t NumDistComputations() { return numDistComputations; }
232 {
233 if (numSamplesMade.n_elem == 0)
234 return 0;
235 else
236 return arma::sum(numSamplesMade);
237 }
238
240
241 const TraversalInfoType& TraversalInfo() const { return traversalInfo; }
242 TraversalInfoType& TraversalInfo() { return traversalInfo; }
243
247 size_t MinimumBaseCases() const { return k; }
248
249 private:
251 const arma::mat& referenceSet;
252
254 const arma::mat& querySet;
255
257 typedef std::pair<double, size_t> Candidate;
258
260 struct CandidateCmp {
261 bool operator()(const Candidate& c1, const Candidate& c2)
262 {
263 return !SortPolicy::IsBetter(c2.first, c1.first);
264 };
265 };
266
268 typedef std::priority_queue<Candidate, std::vector<Candidate>, CandidateCmp>
269 CandidateList;
270
272 std::vector<CandidateList> candidates;
273
275 const size_t k;
276
278 MetricType& metric;
279
281 bool sampleAtLeaves;
282
284 bool firstLeafExact;
285
287 size_t singleSampleLimit;
288
290 size_t numSamplesReqd;
291
293 arma::Col<size_t> numSamplesMade;
294
296 double samplingRatio;
297
299 size_t numDistComputations;
300
302 bool sameSet;
303
304 TraversalInfoType traversalInfo;
305
313 void InsertNeighbor(const size_t queryIndex,
314 const size_t neighbor,
315 const double distance);
316
320 double Score(const size_t queryIndex,
321 TreeType& referenceNode,
322 const double distance,
323 const double bestDistance);
324
328 double Score(TreeType& queryNode,
329 TreeType& referenceNode,
330 const double distance,
331 const double bestDistance);
332
333 static_assert(tree::TreeTraits<TreeType>::UniqueNumDescendants, "TreeType "
334 "must provide a unique number of descendants points.");
335}; // class RASearchRules
336
337} // namespace neighbor
338} // namespace mlpack
339
340// Include implementation.
341#include "ra_search_rules_impl.hpp"
342
343#endif // MLPACK_METHODS_RANN_RA_SEARCH_RULES_HPP
The RASearchRules class is a template helper class used by RASearch class when performing rank-approx...
double Score(TreeType &queryNode, TreeType &referenceNode)
Get the score for recursion order.
double Rescore(TreeType &queryNode, TreeType &referenceNode, const double oldScore)
Re-evaluate the score for recursion order.
size_t MinimumBaseCases() const
Get the minimum number of base cases that must be performed for each query point for an acceptable re...
double Score(const size_t queryIndex, TreeType &referenceNode)
Get the score for recursion order.
const TraversalInfoType & TraversalInfo() const
double Score(const size_t queryIndex, TreeType &referenceNode, const double baseCaseResult)
Get the score for recursion order.
TraversalInfoType & TraversalInfo()
double Rescore(const size_t queryIndex, TreeType &referenceNode, const double oldScore)
Re-evaluate the score for recursion order.
RASearchRules(const arma::mat &referenceSet, const arma::mat &querySet, const size_t k, MetricType &metric, const double tau=5, const double alpha=0.95, const bool naive=false, const bool sampleAtLeaves=false, const bool firstLeafExact=false, const size_t singleSampleLimit=20, const bool sameSet=false)
Construct the RASearchRules object.
tree::TraversalInfo< TreeType > TraversalInfoType
double BaseCase(const size_t queryIndex, const size_t referenceIndex)
Get the distance from the query point to the reference point.
void GetResults(arma::Mat< size_t > &neighbors, arma::mat &distances)
Store the list of candidates for each query point in the given matrices.
double Score(TreeType &queryNode, TreeType &referenceNode, const double baseCaseResult)
Get the score for recursion order, passing the base case result (in the situation where it may be nee...
The TraversalInfo class holds traversal information which is used in dual-tree (and single-tree) trav...
static const bool UniqueNumDescendants
This is true if the NumDescendants() method doesn't include duplicated points.
see subsection cli_alt_reg_tut Alternate DET regularization. The usual regularized error \f$R_\alpha(t)\f$ of a node \f$t\f$ is given by
Definition: det.txt:344
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1