mlpack 3.4.2
ns_model.hpp
Go to the documentation of this file.
1
15#ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_NS_MODEL_HPP
16#define MLPACK_METHODS_NEIGHBOR_SEARCH_NS_MODEL_HPP
17
23#include <boost/variant.hpp>
24#include "neighbor_search.hpp"
25
26namespace mlpack {
27namespace neighbor {
28
// NSType is an alias template naming a NeighborSearch specialization that is
// parameterized by SortPolicy and TreeType, takes arma::mat as a template
// argument, and uses the given TreeType's DualTreeTraverser.
// NOTE(review): the fused source numbering jumps (36 -> 38 and 39 -> 42), so
// some template arguments of this alias (presumably the metric type and the
// TreeType<...> instantiation whose DualTreeTraverser is named on the last
// line) are missing from this extraction; restore them from the upstream
// header before compiling.
32template<typename SortPolicy,
33 template<typename TreeMetricType,
34 typename TreeStatType,
35 typename TreeMatType> class TreeType>
36using NSType = NeighborSearch<SortPolicy,
38 arma::mat,
39 TreeType,
42 arma::mat>::template DualTreeTraverser>;
43
48class MonoSearchVisitor : public boost::static_visitor<void>
49{
50 private:
52 const size_t k;
54 arma::Mat<size_t>& neighbors;
56 arma::mat& distances;
57
58 public:
60 template<typename NSType>
61 void operator()(NSType* ns) const;
62
64 MonoSearchVisitor(const size_t k,
65 arma::Mat<size_t>& neighbors,
66 arma::mat& distances) :
67 k(k),
68 neighbors(neighbors),
69 distances(distances)
70 {};
71};
72
// BiSearchVisitor executes a bichromatic neighbor search (a separate query set
// searched against the model) on the given NSType.  It holds references to
// the caller's query set and output matrices, so those objects must outlive
// the visitor.
79template<typename SortPolicy>
80class BiSearchVisitor : public boost::static_visitor<void>
81{
82 private:
// Query points to search with (held by reference, not copied).
84 const arma::mat& querySet;
// Number of neighbors to search for.
86 const size_t k;
// Output matrix that receives the neighbors found (caller-owned).
88 arma::Mat<size_t>& neighbors;
// Output matrix that receives the neighbor distances (caller-owned).
90 arma::mat& distances;
// Tree-building parameters; presumably used if a query tree must be built
// (leafSize for binary space trees, tau/rho for spill trees) -- confirm
// against ns_model_impl.hpp.
92 const size_t leafSize;
94 const double tau;
96 const double rho;
97
// Private search helper, presumably for tree types built with leafSize; its
// definition is not visible here.
99 template<typename NSType>
100 void SearchLeaf(NSType* ns) const;
101
102 public:
// NOTE(review): source line 107 is missing from this extraction.  It
// presumably declared the NSTypeT alias template that operator() uses below,
// so this template header is currently dangling.
104 template<template<typename TreeMetricType,
105 typename TreeStatType,
106 typename TreeMatType> class TreeType>
108
// Bichromatic neighbor search on the given NSTypeT<TreeType>.
110 template<template<typename TreeMetricType,
111 typename TreeStatType,
112 typename TreeMatType> class TreeType>
113 void operator()(NSTypeT<TreeType>* ns) const;
114
// NOTE(review): source lines 115-121 are missing.  Per this page's member
// list, they declared operator() overloads for NSTypeT<tree::KDTree>,
// NSTypeT<tree::BallTree> and NSTypeT<tree::Octree>.
117
120
// Bichromatic neighbor search specialized for SPTrees.
122 void operator()(SpillKNN* ns) const;
123
126
// Construct the BiSearchVisitor.
128 BiSearchVisitor(const arma::mat& querySet,
129 const size_t k,
130 arma::Mat<size_t>& neighbors,
131 arma::mat& distances,
132 const size_t leafSize,
133 const double tau,
134 const double rho);
135};
136
// TrainVisitor sets the reference set to a new reference set on the given
// NSType.
// NOTE(review): referenceSet is stored as an rvalue-reference member
// (arma::mat&&).  The visitor is only usable while the matrix bound in the
// constructor is still alive; do not keep a TrainVisitor beyond the
// full-expression or scope that owns that matrix.
143template<typename SortPolicy>
144class TrainVisitor : public boost::static_visitor<void>
145{
146 private:
// New reference set to train on (bound by rvalue reference; see note above).
148 arma::mat&& referenceSet;
// Leaf size used when building the reference tree.
// NOTE(review): unlike tau and rho this member is not const.
150 size_t leafSize;
// Spill-tree parameters, presumably (see the SpillKNN overload) -- confirm.
152 const double tau;
154 const double rho;
155
// Private training helper, presumably for tree types built with leafSize;
// its definition is not visible here.
157 template<typename NSType>
158 void TrainLeaf(NSType* ns) const;
159
160 public:
// NOTE(review): source line 165 is missing from this extraction.  It
// presumably declared the NSTypeT alias template used by operator() below,
// so this template header is currently dangling.
162 template<template<typename TreeMetricType,
163 typename TreeStatType,
164 typename TreeMatType> class TreeType>
166
// Train on the given NSTypeT<TreeType>.
168 template<template<typename TreeMetricType,
169 typename TreeStatType,
170 typename TreeMatType> class TreeType>
171 void operator()(NSTypeT<TreeType>* ns) const;
172
// NOTE(review): source lines 173-179 are missing.  Per this page's member
// list, they declared operator() overloads for NSTypeT<tree::KDTree>,
// NSTypeT<tree::BallTree> and NSTypeT<tree::Octree>.
175
178
// Train specialized for SPTrees.
180 void operator()(SpillKNN* ns) const;
181
184
// Construct the TrainVisitor with the given reference set, leafSize, tau and
// rho.
187 TrainVisitor(arma::mat&& referenceSet,
188 const size_t leafSize,
189 const double tau,
190 const double rho);
191};
192
196class SearchModeVisitor : public boost::static_visitor<NeighborSearchMode&>
197{
198 public:
200 template<typename NSType>
202};
203
207class EpsilonVisitor : public boost::static_visitor<double&>
208{
209 public:
211 template<typename NSType>
212 double& operator()(NSType *ns) const;
213};
214
218class ReferenceSetVisitor : public boost::static_visitor<const arma::mat&>
219{
220 public:
222 template<typename NSType>
223 const arma::mat& operator()(NSType *ns) const;
224};
225
229class DeleteVisitor : public boost::static_visitor<void>
230{
231 public:
233 template<typename NSType>
234 void operator()(NSType *ns) const;
235};
236
// NSModel provides an easy way to serialize a neighbor search model and
// abstracts away the different tree types.  The concrete search object is
// held through a boost::variant of pointers, one alternative per tree type;
// presumably these are released with DeleteVisitor when the model is
// destroyed ("clean memory, if necessary") -- confirm in ns_model_impl.hpp.
// NOTE(review): this extraction drops several source lines of the class:
// presumably the "class NSModel" head (248), the TreeTypes enum header and
// most enumerators (251-267; only OCTREE survives, and KD_TREE is referenced
// by the constructor default), most variant alternatives and the member name
// (294-307), the move-assignment operator and destructor (341-349), and the
// SearchMode() accessors (358-360).  Restore them from the upstream header
// before compiling.
247template<typename SortPolicy>
249{
250 public:
// Enum type to identify each accepted tree type.
253 {
268 OCTREE
269 };
270
271 private:
// Tree type used by this model (exposed through TreeType()).
273 TreeTypes treeType;
274
// Leaf size used when building trees (exposed through LeafSize()).
276 size_t leafSize;
277
// Parameters exposed through Tau() and Rho(); presumably only relevant to the
// spill-tree variant (SpillKNN) -- confirm.
279 double tau;
281 double rho;
282
// Whether a random basis should be used (set by the constructor).
284 bool randomBasis;
// NOTE(review): presumably the random basis matrix used when randomBasis is
// true; the code that fills it is not visible here.
286 arma::mat q;
287
// Pointer to the search object for the selected tree type (one variant
// alternative per tree type; only two alternatives survive this extraction).
293 boost::variant<NSType<SortPolicy, tree::KDTree>*,
305 SpillKNN*,
308
309 public:
// Initialize the NSModel with the given tree type and whether or not a random
// basis should be used.
318 NSModel(TreeTypes treeType = TreeTypes::KD_TREE, bool randomBasis = false);
319
// Copy the given NSModel.
325 NSModel(const NSModel& other);
326
// Take ownership of the given NSModel.
332 NSModel(NSModel&& other);
333
// Copy the given NSModel.
339 NSModel& operator=(const NSModel& other);
340
// NOTE(review): the move-assignment operator (take ownership of the given
// NSModel) and ~NSModel() were declared in the lines dropped here.
347
350
// Serialize the neighbor search model.
352 template<typename Archive>
353 void serialize(Archive& ar, const unsigned int /* version */);
354
// Expose the dataset.
356 const arma::mat& Dataset() const;
357
// NOTE(review): the SearchMode() const / SearchMode() accessors were declared
// in the lines dropped here.
361
// Expose epsilon, the approximation parameter (read-only and mutable forms).
363 double Epsilon() const;
364 double& Epsilon();
365
// Expose leafSize.
367 size_t LeafSize() const { return leafSize; }
368 size_t& LeafSize() { return leafSize; }
369
// Expose tau.
371 double Tau() const { return tau; }
372 double& Tau() { return tau; }
373
// Expose rho.
375 double Rho() const { return rho; }
376 double& Rho() { return rho; }
377
// Expose treeType.
379 TreeTypes TreeType() const { return treeType; }
380 TreeTypes& TreeType() { return treeType; }
381
// Expose randomBasis.
383 bool RandomBasis() const { return randomBasis; }
384 bool& RandomBasis() { return randomBasis; }
385
// Build the reference tree from referenceSet with the given leaf size, search
// mode and approximation epsilon.  referenceSet is taken by rvalue reference,
// so the caller's matrix is presumably moved into the model -- confirm.
387 void BuildModel(arma::mat&& referenceSet,
388 const size_t leafSize,
389 const NeighborSearchMode searchMode,
390 const double epsilon = 0);
391
// Perform neighbor search with the given query set.  The query set will be
// reordered.
393 void Search(arma::mat&& querySet,
394 const size_t k,
395 arma::Mat<size_t>& neighbors,
396 arma::mat& distances);
397
// Perform monochromatic neighbor search (no separate query set).
399 void Search(const size_t k,
400 arma::Mat<size_t>& neighbors,
401 arma::mat& distances);
402
// Return a string representation of the current tree type.
404 std::string TreeName() const;
405};
406
407} // namespace neighbor
408} // namespace mlpack
409
411BOOST_TEMPLATE_CLASS_VERSION(template<typename SortPolicy>,
413
414// Include implementation.
415#include "ns_model_impl.hpp"
416
417#endif
BiSearchVisitor executes a bichromatic neighbor search on the given NSType.
Definition: ns_model.hpp:81
void operator()(SpillKNN *ns) const
Bichromatic neighbor search specialized for SPTrees.
void operator()(NSTypeT< tree::KDTree > *ns) const
Bichromatic neighbor search on the given NSType specialized for KDTrees.
void operator()(NSTypeT< tree::Octree > *ns) const
Bichromatic neighbor search specialized for octrees.
void operator()(NSTypeT< tree::BallTree > *ns) const
Bichromatic neighbor search on the given NSType specialized for BallTrees.
BiSearchVisitor(const arma::mat &querySet, const size_t k, arma::Mat< size_t > &neighbors, arma::mat &distances, const size_t leafSize, const double tau, const double rho)
Construct the BiSearchVisitor.
DeleteVisitor deletes the given NSType instance.
Definition: ns_model.hpp:230
void operator()(NSType *ns) const
Delete the NSType object.
EpsilonVisitor exposes the Epsilon method of the given NSType.
Definition: ns_model.hpp:208
double & operator()(NSType *ns) const
Return epsilon, the approximation parameter.
MonoSearchVisitor executes a monochromatic neighbor search on the given NSType.
Definition: ns_model.hpp:49
void operator()(NSType *ns) const
Perform monochromatic nearest neighbor search.
MonoSearchVisitor(const size_t k, arma::Mat< size_t > &neighbors, arma::mat &distances)
Construct the MonoSearchVisitor object with the given parameters.
Definition: ns_model.hpp:64
The NSModel class provides an easy way to serialize a model, abstracts away the different types of trees, and also reflects the NeighborSearch API.
Definition: ns_model.hpp:249
const arma::mat & Dataset() const
Expose the dataset.
NeighborSearchMode SearchMode() const
Expose SearchMode.
~NSModel()
Clean memory, if necessary.
double Rho() const
Expose rho.
Definition: ns_model.hpp:375
double Tau() const
Expose tau.
Definition: ns_model.hpp:371
NSModel & operator=(const NSModel &other)
Copy the given NSModel.
void Search(const size_t k, arma::Mat< size_t > &neighbors, arma::mat &distances)
Perform monochromatic neighbor search.
TreeTypes & TreeType()
Definition: ns_model.hpp:380
NSModel(const NSModel &other)
Copy the given NSModel.
TreeTypes
Enum type to identify each accepted tree type.
Definition: ns_model.hpp:253
NSModel(TreeTypes treeType=TreeTypes::KD_TREE, bool randomBasis=false)
Initialize the NSModel with the given type and whether or not a random basis should be used.
NSModel & operator=(NSModel &&other)
Take ownership of the given NSModel.
NSModel(NSModel &&other)
Take ownership of the given NSModel.
std::string TreeName() const
Return a string representation of the current tree type.
void BuildModel(arma::mat &&referenceSet, const size_t leafSize, const NeighborSearchMode searchMode, const double epsilon=0)
Build the reference tree.
size_t LeafSize() const
Expose leafSize.
Definition: ns_model.hpp:367
NeighborSearchMode & SearchMode()
TreeTypes TreeType() const
Expose treeType.
Definition: ns_model.hpp:379
bool RandomBasis() const
Expose randomBasis.
Definition: ns_model.hpp:383
void serialize(Archive &ar, const unsigned int)
Serialize the neighbor search model.
void Search(arma::mat &&querySet, const size_t k, arma::Mat< size_t > &neighbors, arma::mat &distances)
Perform neighbor search. The query set will be reordered.
double Epsilon() const
Expose Epsilon.
Extra data for each node in the tree.
The NeighborSearch class is a template class for performing distance-based neighbor searches.
ReferenceSetVisitor exposes the referenceSet of the given NSType.
Definition: ns_model.hpp:219
const arma::mat & operator()(NSType *ns) const
Return the reference set.
SearchModeVisitor exposes the SearchMode() method of the given NSType.
Definition: ns_model.hpp:197
NeighborSearchMode & operator()(NSType *ns) const
Return the search mode.
TrainVisitor sets the reference set to a new reference set on the given NSType.
Definition: ns_model.hpp:145
void operator()(SpillKNN *ns) const
Train specialized for SPTrees.
void operator()(NSTypeT< tree::KDTree > *ns) const
Train on the given NSType specialized for KDTrees.
void operator()(NSTypeT< tree::Octree > *ns) const
Train specialized for octrees.
void operator()(NSTypeT< tree::BallTree > *ns) const
Train on the given NSType specialized for BallTrees.
TrainVisitor(arma::mat &&referenceSet, const size_t leafSize, const double tau, const double rho)
Construct the TrainVisitor object with the given reference set, leafSize for BinarySpaceTrees, and tau and rho for spill trees.
LMetric< 2, true > EuclideanDistance
The Euclidean (L2) distance.
Definition: lmetric.hpp:112
NeighborSearchMode
NeighborSearchMode represents the different neighbor search modes available.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
BOOST_TEMPLATE_CLASS_VERSION(template< typename SortPolicy >, mlpack::neighbor::NSModel< SortPolicy >, 1)
Set the serialization version of the NSModel class.