mlpack 3.4.2
lsh_search.hpp
Go to the documentation of this file.
1
47#ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_LSH_SEARCH_HPP
48#define MLPACK_METHODS_NEIGHBOR_SEARCH_LSH_SEARCH_HPP
49
50#include <mlpack/prereqs.hpp>
51
54
55#include <queue>
56
57namespace mlpack {
58namespace neighbor {
59
68template<
69 typename SortPolicy = NearestNeighborSort,
70 typename MatType = arma::mat
71>
73{
74 public:
97 LSHSearch(MatType referenceSet,
98 const arma::cube& projections,
99 const double hashWidth = 0.0,
100 const size_t secondHashSize = 99901,
101 const size_t bucketSize = 500);
102
125 LSHSearch(MatType referenceSet,
126 const size_t numProj,
127 const size_t numTables,
128 const double hashWidth = 0.0,
129 const size_t secondHashSize = 99901,
130 const size_t bucketSize = 500);
131
137
143 LSHSearch(const LSHSearch& other);
144
151
158
165
191 void Train(MatType referenceSet,
192 const size_t numProj,
193 const size_t numTables,
194 const double hashWidth = 0.0,
195 const size_t secondHashSize = 99901,
196 const size_t bucketSize = 500,
197 const arma::cube& projection = arma::cube());
198
220 void Search(const MatType& querySet,
221 const size_t k,
222 arma::Mat<size_t>& resultingNeighbors,
223 arma::mat& distances,
224 const size_t numTablesToSearch = 0,
225 const size_t T = 0);
226
246 void Search(const size_t k,
247 arma::Mat<size_t>& resultingNeighbors,
248 arma::mat& distances,
249 const size_t numTablesToSearch = 0,
250 size_t T = 0);
251
261 static double ComputeRecall(const arma::Mat<size_t>& foundNeighbors,
262 const arma::Mat<size_t>& realNeighbors);
263
270 template<typename Archive>
271 void serialize(Archive& ar, const unsigned int version);
272
274 size_t DistanceEvaluations() const { return distanceEvaluations; }
276 size_t& DistanceEvaluations() { return distanceEvaluations; }
277
279 const MatType& ReferenceSet() const { return referenceSet; }
280
282 size_t NumProjections() const { return projections.n_slices; }
283
285 const arma::mat& Offsets() const { return offsets; }
286
288 const arma::vec& SecondHashWeights() const { return secondHashWeights; }
289
291 size_t BucketSize() const { return bucketSize; }
292
294 const std::vector<arma::Col<size_t>>& SecondHashTable() const
295 { return secondHashTable; }
296
298 const arma::cube& Projections() { return projections; }
299
301 void Projections(const arma::cube& projTables)
302 {
303 // Simply call Train() with the given projection tables.
304 Train(referenceSet, numProj, numTables, hashWidth, secondHashSize,
305 bucketSize, projTables);
306 }
307
308 private:
324 template<typename VecType>
325 void ReturnIndicesFromTable(const VecType& queryPoint,
326 arma::uvec& referenceIndices,
327 size_t numTablesToSearch,
328 const size_t T) const;
329
343 void BaseCase(const size_t queryIndex,
344 const arma::uvec& referenceIndices,
345 const size_t k,
346 arma::Mat<size_t>& neighbors,
347 arma::mat& distances) const;
348
363 void BaseCase(const size_t queryIndex,
364 const arma::uvec& referenceIndices,
365 const size_t k,
366 const MatType& querySet,
367 arma::Mat<size_t>& neighbors,
368 arma::mat& distances) const;
369
384 void GetAdditionalProbingBins(const arma::vec& queryCode,
385 const arma::vec& queryCodeNotFloored,
386 const size_t T,
387 arma::mat& additionalProbingBins) const;
388
396 double PerturbationScore(const std::vector<bool>& A,
397 const arma::vec& scores) const;
398
406 bool PerturbationShift(std::vector<bool>& A) const;
407
416 bool PerturbationExpand(std::vector<bool>& A) const;
417
425 bool PerturbationValid(const std::vector<bool>& A) const;
426
428 MatType referenceSet;
429
431 size_t numProj;
433 size_t numTables;
434
436 arma::cube projections; // should be [numProj x dims] x numTables slices
437
439 arma::mat offsets; // should be numProj x numTables
440
442 double hashWidth;
443
445 size_t secondHashSize;
446
448 arma::vec secondHashWeights;
449
451 size_t bucketSize;
452
455 std::vector<arma::Col<size_t>> secondHashTable;
456
459 arma::Col<size_t> bucketContentSize;
460
463 arma::Col<size_t> bucketRowInHashTable;
464
466 size_t distanceEvaluations;
467
469 typedef std::pair<double, size_t> Candidate;
470
472 struct CandidateCmp {
473 bool operator()(const Candidate& c1, const Candidate& c2)
474 {
475 return !SortPolicy::IsBetter(c2.first, c1.first);
476 };
477 };
478
480 typedef std::priority_queue<Candidate, std::vector<Candidate>, CandidateCmp>
481 CandidateList;
482}; // class LSHSearch
483
484} // namespace neighbor
485} // namespace mlpack
486
488BOOST_TEMPLATE_CLASS_VERSION(template<typename SortPolicy>,
490
491// Include implementation.
492#include "lsh_search_impl.hpp"
493
494#endif
The LSHSearch class; this class builds a hash on the reference set and uses this hash to compute the ...
Definition: lsh_search.hpp:73
size_t BucketSize() const
Get the bucket size of the second hash.
Definition: lsh_search.hpp:291
const arma::mat & Offsets() const
Get the offsets 'b' for each of the projections. (One 'b' per column.)
Definition: lsh_search.hpp:285
size_t NumProjections() const
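Get the number of projections. Note: this returns projections.n_slices, i.e. the number of projection tables (one slice per hash table).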
Definition: lsh_search.hpp:282
const MatType & ReferenceSet() const
Return the reference dataset.
Definition: lsh_search.hpp:279
LSHSearch(LSHSearch &&other)
Take ownership of the given LSH model.
LSHSearch()
Create an untrained LSH model.
const std::vector< arma::Col< size_t > > & SecondHashTable() const
Get the second hash table.
Definition: lsh_search.hpp:294
LSHSearch(MatType referenceSet, const arma::cube &projections, const double hashWidth=0.0, const size_t secondHashSize=99901, const size_t bucketSize=500)
This function initializes the LSH class.
LSHSearch & operator=(LSHSearch &&other)
Take ownership of the given LSH model.
void Search(const MatType &querySet, const size_t k, arma::Mat< size_t > &resultingNeighbors, arma::mat &distances, const size_t numTablesToSearch=0, const size_t T=0)
Compute the nearest neighbors of the points in the given query set and store the output in the given ...
static double ComputeRecall(const arma::Mat< size_t > &foundNeighbors, const arma::Mat< size_t > &realNeighbors)
Compute the recall (% of neighbors found) given the neighbors returned by LSHSearch::Search and a "gr...
const arma::vec & SecondHashWeights() const
Get the weights of the second hash.
Definition: lsh_search.hpp:288
void serialize(Archive &ar, const unsigned int version)
Serialize the LSH model.
const arma::cube & Projections()
Get the projection tables.
Definition: lsh_search.hpp:298
LSHSearch(const LSHSearch &other)
Copy the given LSH model.
LSHSearch(MatType referenceSet, const size_t numProj, const size_t numTables, const double hashWidth=0.0, const size_t secondHashSize=99901, const size_t bucketSize=500)
This function initializes the LSH class.
size_t & DistanceEvaluations()
Modify the number of distance evaluations performed.
Definition: lsh_search.hpp:276
void Train(MatType referenceSet, const size_t numProj, const size_t numTables, const double hashWidth=0.0, const size_t secondHashSize=99901, const size_t bucketSize=500, const arma::cube &projection=arma::cube())
Train the LSH model on the given dataset.
void Projections(const arma::cube &projTables)
Change the projection tables (this retrains the LSH model).
Definition: lsh_search.hpp:301
LSHSearch & operator=(const LSHSearch &other)
Copy the given LSH model.
void Search(const size_t k, arma::Mat< size_t > &resultingNeighbors, arma::mat &distances, const size_t numTablesToSearch=0, size_t T=0)
Compute the nearest neighbors and store the output in the given matrices.
size_t DistanceEvaluations() const
Return the number of distance evaluations performed.
Definition: lsh_search.hpp:274
BOOST_TEMPLATE_CLASS_VERSION(template< typename SortPolicy >, mlpack::neighbor::LSHSearch< SortPolicy >, 1)
Set the serialization version of the LSHSearch class.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.