mlpack 3.4.2
cover_tree.hpp
Go to the documentation of this file.
1
12#ifndef MLPACK_CORE_TREE_COVER_TREE_COVER_TREE_HPP
13#define MLPACK_CORE_TREE_COVER_TREE_COVER_TREE_HPP
14
15#include <mlpack/prereqs.hpp>
17
18#include "../statistic.hpp"
20
21namespace mlpack {
22namespace tree {
23
95template<typename MetricType = metric::LMetric<2, true>,
96 typename StatisticType = EmptyStatistic,
97 typename MatType = arma::mat,
98 typename RootPointPolicy = FirstPointIsRoot>
100{
101 public:
//! So that other classes can access the matrix type.
103 typedef MatType Mat;
//! The type held by the matrix type.
105 typedef typename MatType::elem_type ElemType;
106
//! Create the cover tree with the given dataset and given base.
//!
//! @param dataset The dataset, taken by const reference.
//! @param base Base of the cover tree; defaults to 2.0.
//! @param metric Optional pointer to an instantiated metric; defaults to a
//!     null pointer (spelled nullptr rather than the NULL macro).
//!     NOTE(review): what happens when no metric is given is decided in the
//!     implementation file, which is not visible here.
118 CoverTree(const MatType& dataset,
119 const ElemType base = 2.0,
120 MetricType* metric = nullptr);
121
//! Create the cover tree with the given dataset and the given instantiated
//! metric.
//!
//! @param dataset The dataset, taken by const reference.
//! @param metric Instantiated metric, taken by non-const reference.
//! @param base Base of the cover tree; defaults to 2.0.
131 CoverTree(const MatType& dataset,
132 MetricType& metric,
133 const ElemType base = 2.0);
134
//! Create the cover tree with the given dataset, taking ownership of the
//! dataset.
//!
//! @param dataset The dataset, passed as an rvalue reference.
//! @param base Base of the cover tree; defaults to 2.0.
142 CoverTree(MatType&& dataset,
143 const ElemType base = 2.0);
144
//! Create the cover tree with the given dataset and the given instantiated
//! metric.  The dataset is passed as an rvalue reference, like the
//! (MatType&&, base) overload, which takes ownership of the dataset.
//!
//! @param dataset The dataset, passed as an rvalue reference.
//! @param metric Instantiated metric, taken by non-const reference.
//! @param base Base of the cover tree; defaults to 2.0.
153 CoverTree(MatType&& dataset,
154 MetricType& metric,
155 const ElemType base = 2.0);
156
//! Construct a child cover tree node.
//!
//! @param dataset The dataset, taken by const reference.
//! @param base Base of the cover tree.
//! @param pointIndex Index of a point in the dataset.
//!     NOTE(review): presumably the point this node represents — confirm.
//! @param scale Scale of the node.
//! @param parent Pointer to the parent node.
//! @param parentDistance Distance to the parent.
//! @param indices Working array of point indices (mutable reference).
//! @param distances Working array of distances (mutable reference).
//! @param nearSetSize Size of the near set (by value).
//! @param farSetSize Size of the far set (in-out reference).
//! @param usedSetSize Size of the used set (in-out reference).
//! @param metric Instantiated metric, taken by non-const reference.  No
//!     default is provided: a non-const reference cannot bind to a null
//!     pointer constant, so a `= NULL` default would be ill-formed whenever
//!     it was actually used.
189 CoverTree(const MatType& dataset,
190 const ElemType base,
191 const size_t pointIndex,
192 const int scale,
193 CoverTree* parent,
194 const ElemType parentDistance,
195 arma::Col<size_t>& indices,
196 arma::vec& distances,
197 size_t nearSetSize,
198 size_t& farSetSize,
199 size_t& usedSetSize,
200 MetricType& metric);
201
//! Manually construct a cover tree node; no tree assembly is done in this
//! constructor.
//!
//! @param dataset The dataset, taken by const reference.
//! @param base Base of the cover tree.
//! @param pointIndex Index of a point in the dataset.
//!     NOTE(review): presumably the point this node represents — confirm.
//! @param scale Scale of the node.
//! @param parent Pointer to the parent node.
//! @param parentDistance Distance to the parent.
//! @param furthestDescendantDistance Distance from the center of the node to
//!     the furthest descendant.
//! @param metric Optional pointer to an instantiated metric; defaults to a
//!     null pointer (spelled nullptr rather than the NULL macro).
218 CoverTree(const MatType& dataset,
219 const ElemType base,
220 const size_t pointIndex,
221 const int scale,
222 CoverTree* parent,
223 const ElemType parentDistance,
224 const ElemType furthestDescendantDistance,
225 MetricType* metric = nullptr);
226
//! Create a cover tree from another tree.
//! NOTE(review): how children, dataset and metric are copied is decided in
//! the implementation file, which is not visible here.
233 CoverTree(const CoverTree& other);
234
242
249
256
260 template<typename Archive>
262 Archive& ar,
264
269
272 template<typename RuleType>
274
//! A dual-tree cover tree traverser; see dual_tree_traverser.hpp.
276 template<typename RuleType>
277 class DualTreeTraverser;
278
279 template<typename RuleType>
281
//! Get a reference to the dataset.
283 const MatType& Dataset() const { return *dataset; }
284
//! Get the index of the point which this node represents.
286 size_t Point() const { return point; }
//! For compatibility with other trees; the argument is ignored.
288 size_t Point(const size_t) const { return point; }
289
//! Return true when this node has no children.
290 bool IsLeaf() const { return (children.size() == 0); }
//! Always returns 1.
//! NOTE(review): presumably because each node represents exactly one point
//! (see Point()) — confirm.
291 size_t NumPoints() const { return 1; }
292
//! Get a particular child node.  The index is not bounds-checked.
294 const CoverTree& Child(const size_t index) const { return *children[index]; }
//! Modify a particular child node.  The index is not bounds-checked.
296 CoverTree& Child(const size_t index) { return *children[index]; }
297
//! Get a modifiable reference to the pointer stored for the given child, so
//! the caller can repoint it.  The index is not bounds-checked.
298 CoverTree*& ChildPtr(const size_t index) { return children[index]; }
299
//! Get the number of children.
301 size_t NumChildren() const { return children.size(); }
302
//! Get the children.
304 const std::vector<CoverTree*>& Children() const { return children; }
//! Modify the children manually (maybe not a great idea).
306 std::vector<CoverTree*>& Children() { return children; }
307
//! Get the number of descendant points.
309 size_t NumDescendants() const;
310
//! Get the index of a particular descendant point.
312 size_t Descendant(const size_t index) const;
313
//! Get the scale of this node.
315 int Scale() const { return scale; }
//! Modify the scale of this node.  Be careful...
317 int& Scale() { return scale; }
318
//! Get the base.
320 ElemType Base() const { return base; }
//! Modify the base; don't do this, you'll break everything.
322 ElemType& Base() { return base; }
323
//! Get the statistic for this node.
325 const StatisticType& Stat() const { return stat; }
//! Modify the statistic for this node.
327 StatisticType& Stat() { return stat; }
328
333 template<typename VecType>
335 const VecType& point,
337
342 template<typename VecType>
344 const VecType& point,
346
//! Return the index of the nearest child node to the given query node.
351 size_t GetNearestChild(const CoverTree& queryNode);
352
//! Return the index of the furthest child node to the given query node.
357 size_t GetFurthestChild(const CoverTree& queryNode);
358
//! Return the minimum distance to another node.
360 ElemType MinDistance(const CoverTree& other) const;
361
//! Return the minimum distance to another node, given that the
//! point-to-point distance is already known and passed in as `distance`.
364 ElemType MinDistance(const CoverTree& other, const ElemType distance) const;
365
//! Return the minimum distance to another point.
367 ElemType MinDistance(const arma::vec& other) const;
368
//! Return the minimum distance to another point, given the distance from the
//! center to the point (passed in as `distance`).
371 ElemType MinDistance(const arma::vec& other, const ElemType distance) const;
372
//! Return the maximum distance to another node.
374 ElemType MaxDistance(const CoverTree& other) const;
375
//! Return the maximum distance to another node, given that the
//! point-to-point distance is already known and passed in as `distance`.
378 ElemType MaxDistance(const CoverTree& other, const ElemType distance) const;
379
//! Return the maximum distance to another point.
381 ElemType MaxDistance(const arma::vec& other) const;
382
//! Return the maximum distance to another point, given the distance from the
//! center to the point (passed in as `distance`).
385 ElemType MaxDistance(const arma::vec& other, const ElemType distance) const;
386
389
393 const ElemType distance) const;
394
//! Return the minimum and maximum distance to another point.
396 math::RangeType<ElemType> RangeDistance(const arma::vec& other) const;
397
401 const ElemType distance) const;
402
//! Get the parent node.
404 CoverTree* Parent() const { return parent; }
//! Modify the parent node.
406 CoverTree*& Parent() { return parent; }
407
//! Get the distance to the parent.
409 ElemType ParentDistance() const { return parentDistance; }
//! Modify the distance to the parent.
411 ElemType& ParentDistance() { return parentDistance; }
412
//! Get the distance to the furthest point.  This is always 0 for cover trees.
414 ElemType FurthestPointDistance() const { return 0.0; }
415
418 { return furthestDescendantDistance; }
//! Modify the distance from the center of the node to the furthest
//! descendant.
421 ElemType& FurthestDescendantDistance() { return furthestDescendantDistance; }
422
//! Get the minimum distance from the center to any bound edge (this is the
//! same as furthestDescendantDistance).
425 ElemType MinimumBoundDistance() const { return furthestDescendantDistance; }
426
//! Get the center of the node and store it in the given vector.
//!
//! @param center Output vector; overwritten with the dataset column at
//!     index `point`, converted to arma::vec.
428 void Center(arma::vec& center) const
429 {
430 center = arma::vec(dataset->col(point));
431 }
432
//! Get the instantiated metric.  Dereferences the stored metric pointer
//! without a null check and returns a non-const reference from a const
//! member function.
434 MetricType& Metric() const { return *metric; }
435
436 private:
//! Pointer to the dataset; dereferenced by Dataset() and Center().
438 const MatType* dataset;
//! Index of the point which this node represents; returned by Point().
440 size_t point;
//! Pointers to the child nodes; exposed through Child(), ChildPtr() and
//! Children().
442 std::vector<CoverTree*> children;
//! Scale of this node; see Scale().
444 int scale;
//! Base of the tree; see Base().
446 ElemType base;
//! Statistic for this node; see Stat().
448 StatisticType stat;
//! NOTE(review): presumably the count reported by NumDescendants(), whose
//! definition is not visible here — confirm.
450 size_t numDescendants;
//! Pointer to the parent node; see Parent().
452 CoverTree* parent;
//! Distance to the parent; see ParentDistance().
454 ElemType parentDistance;
//! Distance from the center of the node to the furthest descendant; see
//! FurthestDescendantDistance().
456 ElemType furthestDescendantDistance;
//! NOTE(review): presumably records whether this node owns `metric` —
//! confirm against the implementation.
458 bool localMetric;
//! NOTE(review): presumably records whether this node owns `dataset` —
//! confirm against the implementation.
460 bool localDataset;
//! Pointer to the instantiated metric; dereferenced by Metric().
462 MetricType* metric;
463
//! Private helper; declaration only, the definition is not visible here.
//! Takes the index and distance working arrays by mutable reference,
//! nearSetSize by value, and farSetSize / usedSetSize as in-out references.
//! NOTE(review): presumably builds this node's children from those sets —
//! confirm against the implementation.
467 void CreateChildren(arma::Col<size_t>& indices,
468 arma::vec& distances,
469 size_t nearSetSize,
470 size_t& farSetSize,
471 size_t& usedSetSize);
472
//! Private helper; declaration only, the definition is not visible here.
//! `indices` is read-only while `distances` is a mutable output reference.
//! NOTE(review): presumably fills `distances` with the distance from
//! `pointIndex` to each of the first `pointSetSize` entries of `indices` —
//! confirm against the implementation.
484 void ComputeDistances(const size_t pointIndex,
485 const arma::Col<size_t>& indices,
486 arma::vec& distances,
487 const size_t pointSetSize);
//! Private helper; declaration only, the definition is not visible here.
//! Both arrays are taken by mutable reference and a size_t is returned.
//! NOTE(review): presumably partitions the first `pointSetSize` entries
//! around `bound` and returns the size of the near part — confirm against
//! the implementation.
502 size_t SplitNearFar(arma::Col<size_t>& indices,
503 arma::vec& distances,
504 const ElemType bound,
505 const size_t pointSetSize);
506
//! Private helper; declaration only, the definition is not visible here.
//! Both arrays are taken by mutable reference and a size_t is returned.
//! NOTE(review): the exact reordering performed and the meaning of the
//! return value must be confirmed against the implementation.
526 size_t SortPointSet(arma::Col<size_t>& indices,
527 arma::vec& distances,
528 const size_t childFarSetSize,
529 const size_t childUsedSetSize,
530 const size_t farSetSize);
531
//! Private helper; declaration only, the definition is not visible here.
//! The working arrays, the three set sizes and `childIndices` are all taken
//! by mutable reference; the child set sizes are read-only inputs.
//! NOTE(review): presumably moves points consumed by a child into this
//! node's used set — confirm against the implementation.
532 void MoveToUsedSet(arma::Col<size_t>& indices,
533 arma::vec& distances,
534 size_t& nearSetSize,
535 size_t& farSetSize,
536 size_t& usedSetSize,
537 arma::Col<size_t>& childIndices,
538 const size_t childFarSetSize,
539 const size_t childUsedSetSize);
//! Private helper; declaration only, the definition is not visible here.
//! Both arrays are taken by mutable reference and a size_t is returned.
//! NOTE(review): presumably drops far-set points beyond `bound` and returns
//! the new far-set size — confirm against the implementation.
540 size_t PruneFarSet(arma::Col<size_t>& indices,
541 arma::vec& distances,
542 const ElemType bound,
543 const size_t nearSetSize,
544 const size_t pointSetSize);
545
//! Private helper taking no arguments; declaration only, the definition is
//! not visible here.
//! NOTE(review): the name suggests it removes implicit nodes created during
//! construction — confirm against the implementation.
550 void RemoveNewImplicitNodes();
551
552 protected:
560
//! Grant boost::serialization access to this class's non-public members.
562 friend class boost::serialization::access;
563
564 public:
//! Serialize the tree.
//!
//! @param ar Archive to serialize to or from.
//! The second (version) parameter is unnamed.
568 template<typename Archive>
569 void serialize(Archive& ar, const unsigned int /* version */);
570
//! Get the value of the distanceComps counter.
571 size_t DistanceComps() const { return distanceComps; }
//! Modify the distanceComps counter.
572 size_t& DistanceComps() { return distanceComps; }
573
574 private:
//! NOTE(review): presumably the number of distance computations performed
//! while building the tree — confirm against the implementation.
575 size_t distanceComps;
576};
577
578} // namespace tree
579} // namespace mlpack
580
581// Include implementation.
582#include "cover_tree_impl.hpp"
583
584// Include the rest of the pieces, if necessary.
585#include "../cover_tree.hpp"
586
587#endif
Simple real-valued range.
Definition: range.hpp:35
A dual-tree cover tree traverser; see dual_tree_traverser.hpp.
A single-tree cover tree traverser; see single_tree_traverser.hpp for implementation.
A cover tree is a tree specifically designed to speed up nearest-neighbor computation in high-dimensional spaces.
Definition: cover_tree.hpp:100
size_t GetFurthestChild(const VecType &point, typename std::enable_if_t< IsVector< VecType >::value > *=0)
Return the index of the furthest child node to the given query point.
ElemType MaxDistance(const CoverTree &other) const
Return the maximum distance to another node.
ElemType MaxDistance(const arma::vec &other, const ElemType distance) const
Return the maximum distance to another point given that the distance from the center to the point has...
CoverTree(const MatType &dataset, MetricType &metric, const ElemType base=2.0)
Create the cover tree with the given dataset and the given instantiated metric.
ElemType MinDistance(const arma::vec &other, const ElemType distance) const
Return the minimum distance to another point given that the distance from the center to the point has...
math::RangeType< ElemType > RangeDistance(const CoverTree &other, const ElemType distance) const
Return the minimum and maximum distance to another node given that the point-to-point distance has al...
size_t GetNearestChild(const VecType &point, typename std::enable_if_t< IsVector< VecType >::value > *=0)
Return the index of the nearest child node to the given query point.
size_t NumChildren() const
Get the number of children.
Definition: cover_tree.hpp:301
ElemType MinimumBoundDistance() const
Get the minimum distance from the center to any bound edge (this is the same as furthestDescendantDistance).
Definition: cover_tree.hpp:425
MatType::elem_type ElemType
The type held by the matrix type.
Definition: cover_tree.hpp:105
CoverTree(const MatType &dataset, const ElemType base, const size_t pointIndex, const int scale, CoverTree *parent, const ElemType parentDistance, arma::Col< size_t > &indices, arma::vec &distances, size_t nearSetSize, size_t &farSetSize, size_t &usedSetSize, MetricType &metric=NULL)
Construct a child cover tree node.
ElemType & ParentDistance()
Modify the distance to the parent.
Definition: cover_tree.hpp:411
size_t NumDescendants() const
Get the number of descendant points.
CoverTree *& Parent()
Modify the parent node.
Definition: cover_tree.hpp:406
size_t NumPoints() const
Definition: cover_tree.hpp:291
ElemType & Base()
Modify the base; don't do this, you'll break everything.
Definition: cover_tree.hpp:322
CoverTree(const CoverTree &other)
Create a cover tree from another tree.
StatisticType & Stat()
Modify the statistic for this node.
Definition: cover_tree.hpp:327
CoverTree(const MatType &dataset, const ElemType base=2.0, MetricType *metric=NULL)
Create the cover tree with the given dataset and given base.
CoverTree(Archive &ar, const typename std::enable_if_t< Archive::is_loading::value > *=0)
Create a cover tree from a boost::serialization archive.
size_t Point() const
Get the index of the point which this node represents.
Definition: cover_tree.hpp:286
CoverTree & operator=(const CoverTree &other)
Copy the given Cover Tree.
ElemType Base() const
Get the base.
Definition: cover_tree.hpp:320
ElemType MinDistance(const CoverTree &other, const ElemType distance) const
Return the minimum distance to another node given that the point-to-point distance has already been c...
MatType Mat
So that other classes can access the matrix type.
Definition: cover_tree.hpp:103
math::RangeType< ElemType > RangeDistance(const CoverTree &other) const
Return the minimum and maximum distance to another node.
CoverTree(const MatType &dataset, const ElemType base, const size_t pointIndex, const int scale, CoverTree *parent, const ElemType parentDistance, const ElemType furthestDescendantDistance, MetricType *metric=NULL)
Manually construct a cover tree node; no tree assembly is done in this constructor,...
CoverTree()
A default constructor.
math::RangeType< ElemType > RangeDistance(const arma::vec &other) const
Return the minimum and maximum distance to another point.
CoverTree & operator=(CoverTree &&other)
Take ownership of the given Cover Tree.
int & Scale()
Modify the scale of this node. Be careful...
Definition: cover_tree.hpp:317
size_t DistanceComps() const
Definition: cover_tree.hpp:571
const std::vector< CoverTree * > & Children() const
Get the children.
Definition: cover_tree.hpp:304
ElemType MaxDistance(const arma::vec &other) const
Return the maximum distance to another point.
int Scale() const
Get the scale of this node.
Definition: cover_tree.hpp:315
math::RangeType< ElemType > RangeDistance(const arma::vec &other, const ElemType distance) const
Return the minimum and maximum distance to another point given that the point-to-point distance has a...
ElemType MaxDistance(const CoverTree &other, const ElemType distance) const
Return the maximum distance to another node given that the point-to-point distance has already been c...
const CoverTree & Child(const size_t index) const
Get a particular child node.
Definition: cover_tree.hpp:294
ElemType MinDistance(const arma::vec &other) const
Return the minimum distance to another point.
ElemType & FurthestDescendantDistance()
Modify the distance from the center of the node to the furthest descendant.
Definition: cover_tree.hpp:421
ElemType MinDistance(const CoverTree &other) const
Return the minimum distance to another node.
size_t Descendant(const size_t index) const
Get the index of a particular descendant point.
void Center(arma::vec &center) const
Get the center of the node and store it in the given vector.
Definition: cover_tree.hpp:428
CoverTree(MatType &&dataset, MetricType &metric, const ElemType base=2.0)
Create the cover tree with the given dataset and the given instantiated metric, taking ownership of the dataset.
const StatisticType & Stat() const
Get the statistic for this node.
Definition: cover_tree.hpp:325
CoverTree(MatType &&dataset, const ElemType base=2.0)
Create the cover tree with the given dataset, taking ownership of the dataset.
size_t GetNearestChild(const CoverTree &queryNode)
Return the index of the nearest child node to the given query node.
CoverTree & Child(const size_t index)
Modify a particular child node.
Definition: cover_tree.hpp:296
CoverTree(CoverTree &&other)
Move constructor for a Cover Tree, possess all the members of the given tree.
size_t GetFurthestChild(const CoverTree &queryNode)
Return the index of the furthest child node to the given query node.
MetricType & Metric() const
Get the instantiated metric.
Definition: cover_tree.hpp:434
ElemType FurthestPointDistance() const
Get the distance to the furthest point. This is always 0 for cover trees.
Definition: cover_tree.hpp:414
CoverTree *& ChildPtr(const size_t index)
Definition: cover_tree.hpp:298
const MatType & Dataset() const
Get a reference to the dataset.
Definition: cover_tree.hpp:283
ElemType FurthestDescendantDistance() const
Get the distance from the center of the node to the furthest descendant.
Definition: cover_tree.hpp:417
~CoverTree()
Delete this cover tree node and its children.
CoverTree * Parent() const
Get the parent node.
Definition: cover_tree.hpp:404
size_t Point(const size_t) const
For compatibility with other trees; the argument is ignored.
Definition: cover_tree.hpp:288
void serialize(Archive &ar, const unsigned int)
Serialize the tree.
ElemType ParentDistance() const
Get the distance to the parent.
Definition: cover_tree.hpp:409
std::vector< CoverTree * > & Children()
Modify the children manually (maybe not a great idea).
Definition: cover_tree.hpp:306
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
typename enable_if< B, T >::type enable_if_t
Definition: prereqs.hpp:70
The core includes that mlpack expects; standard C++ includes and Armadillo.
Definition of the Range class, which represents a simple range with a lower and upper bound.
If value == true, then VecType is some sort of Armadillo vector or subview.
Definition: arma_traits.hpp:36