mlpack 3.4.2
neighbor_search.hpp
Go to the documentation of this file.
1
13#ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_HPP
14#define MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_HPP
15
16#include <mlpack/prereqs.hpp>
17#include <vector>
18#include <string>
19
23
27
28namespace mlpack {
29// Neighbor-search routines. These include all-nearest-neighbors and
30// all-furthest-neighbors searches.
31namespace neighbor {
32
33// Forward declaration.
34template<typename SortPolicy>
35class TrainVisitor;
36
39{
44};
45
69template<typename SortPolicy = NearestNeighborSort,
70 typename MetricType = mlpack::metric::EuclideanDistance,
71 typename MatType = arma::mat,
72 template<typename TreeMetricType,
73 typename TreeStatType,
74 typename TreeMatType> class TreeType = tree::KDTree,
75 template<typename RuleType> class DualTreeTraversalType =
76 TreeType<MetricType,
77 NeighborSearchStat<SortPolicy>,
78 MatType>::template DualTreeTraverser,
79 template<typename RuleType> class SingleTreeTraversalType =
80 TreeType<MetricType,
81 NeighborSearchStat<SortPolicy>,
82 MatType>::template SingleTreeTraverser>
84{
85 public:
87 typedef TreeType<MetricType, NeighborSearchStat<SortPolicy>, MatType> Tree;
88
105 NeighborSearch(MatType referenceSet,
107 const double epsilon = 0,
108 const MetricType metric = MetricType());
109
133 NeighborSearch(Tree referenceTree,
135 const double epsilon = 0,
136 const MetricType metric = MetricType());
137
148 const double epsilon = 0,
149 const MetricType metric = MetricType());
150
158
166
173
180
186
196 void Train(MatType referenceSet);
197
207 void Train(Tree referenceTree);
208
226 void Search(const MatType& querySet,
227 const size_t k,
228 arma::Mat<size_t>& neighbors,
229 arma::mat& distances);
230
251 void Search(Tree& queryTree,
252 const size_t k,
253 arma::Mat<size_t>& neighbors,
254 arma::mat& distances,
255 bool sameSet = false);
256
271 void Search(const size_t k,
272 arma::Mat<size_t>& neighbors,
273 arma::mat& distances);
274
290 static double EffectiveError(arma::mat& foundDistances,
291 arma::mat& realDistances);
292
304 static double Recall(arma::Mat<size_t>& foundNeighbors,
305 arma::Mat<size_t>& realNeighbors);
306
309 size_t BaseCases() const { return baseCases; }
310
312 size_t Scores() const { return scores; }
313
315 NeighborSearchMode SearchMode() const { return searchMode; }
317 NeighborSearchMode& SearchMode() { return searchMode; }
318
320 double Epsilon() const { return epsilon; }
322 double& Epsilon() { return epsilon; }
323
325 const MatType& ReferenceSet() const { return *referenceSet; }
326
328 const Tree& ReferenceTree() const { return *referenceTree; }
330 Tree& ReferenceTree() { return *referenceTree; }
331
333 template<typename Archive>
334 void serialize(Archive& ar, const unsigned int /* version */);
335
336 private:
338 std::vector<size_t> oldFromNewReferences;
340 Tree* referenceTree;
342 const MatType* referenceSet;
343
345 NeighborSearchMode searchMode;
347 double epsilon;
348
350 MetricType metric;
351
353 size_t baseCases;
355 size_t scores;
356
359 bool treeNeedsReset;
360
362 template<typename SortPol>
363 friend class TrainVisitor;
364}; // class NeighborSearch
365
366} // namespace neighbor
367} // namespace mlpack
368
369// Include implementation.
370#include "neighbor_search_impl.hpp"
371
372// Include convenience typedefs.
373#include "typedef.hpp"
374
375#endif
Definition of generalized binary space partitioning tree (BinarySpaceTree).
The NeighborSearch class is a template class for performing distance-based neighbor searches.
void Search(const MatType &querySet, const size_t k, arma::Mat< size_t > &neighbors, arma::mat &distances)
For each point in the query set, compute the nearest neighbors and store the output in the given matrices.
NeighborSearchMode SearchMode() const
Access the search mode.
~NeighborSearch()
Delete the NeighborSearch object.
void Train(MatType referenceSet)
Set the reference set to a new reference set, and build a tree if necessary.
const MatType & ReferenceSet() const
Access the reference dataset.
NeighborSearch(NeighborSearch &&other)
Construct the NeighborSearch object by taking ownership of the given NeighborSearch object.
size_t BaseCases() const
Return the total number of base case evaluations performed during the last search.
NeighborSearch(Tree referenceTree, const NeighborSearchMode mode=DUAL_TREE_MODE, const double epsilon=0, const MetricType metric=MetricType())
Initialize the NeighborSearch object with a copy of the given pre-constructed reference tree (this is the tree built on the points that will be searched).
size_t Scores() const
Return the number of node combination scores during the last search.
NeighborSearch & operator=(NeighborSearch &&other)
Take ownership of the given NeighborSearch object.
void Train(Tree referenceTree)
Set the reference tree to a new reference tree.
double & Epsilon()
Modify the relative error to be considered in approximate search.
void Search(const size_t k, arma::Mat< size_t > &neighbors, arma::mat &distances)
Search for the nearest neighbors of every point in the reference set.
NeighborSearch(const NeighborSearch &other)
Construct the NeighborSearch object by copying the given NeighborSearch object.
NeighborSearch & operator=(const NeighborSearch &other)
Copy the given NeighborSearch object.
NeighborSearchMode & SearchMode()
Modify the search mode.
Tree & ReferenceTree()
Modify the reference tree.
TreeType< MetricType, NeighborSearchStat< SortPolicy >, MatType > Tree
Convenience typedef.
static double Recall(arma::Mat< size_t > &foundNeighbors, arma::Mat< size_t > &realNeighbors)
Calculate the recall (% of neighbors found) given the list of found neighbors and the true set of neighbors.
const Tree & ReferenceTree() const
Access the reference tree.
NeighborSearch(const NeighborSearchMode mode=DUAL_TREE_MODE, const double epsilon=0, const MetricType metric=MetricType())
Create a NeighborSearch object without any reference data.
void Search(Tree &queryTree, const size_t k, arma::Mat< size_t > &neighbors, arma::mat &distances, bool sameSet=false)
Given a pre-built query tree, search for the nearest neighbors of each point in the query tree, storing the output in the given matrices.
void serialize(Archive &ar, const unsigned int)
Serialize the NeighborSearch model.
double Epsilon() const
Access the relative error to be considered in approximate search.
static double EffectiveError(arma::mat &foundDistances, arma::mat &realDistances)
Calculate the average relative error (effective error) between the distances calculated and the true distances provided.
NeighborSearch(MatType referenceSet, const NeighborSearchMode mode=DUAL_TREE_MODE, const double epsilon=0, const MetricType metric=MetricType())
Initialize the NeighborSearch object, passing a reference dataset (this is the dataset which is searched).
TrainVisitor sets the reference set to a new reference set on the given NSType.
Definition: ra_model.hpp:128
LMetric< 2, true > EuclideanDistance
The Euclidean (L2) distance.
Definition: lmetric.hpp:112
NeighborSearchMode
NeighborSearchMode represents the different neighbor search modes available.
BinarySpaceTree< MetricType, StatisticType, MatType, bound::HRectBound, MidpointSplit > KDTree
The standard midpoint-split kd-tree.
Definition: typedef.hpp:63
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.