mlpack 3.4.2
fastmks_rules.hpp
Go to the documentation of this file.
1
12#ifndef MLPACK_METHODS_FASTMKS_FASTMKS_RULES_HPP
13#define MLPACK_METHODS_FASTMKS_FASTMKS_RULES_HPP
14
#include <mlpack/prereqs.hpp>
#include <mlpack/core/tree/traversal_info.hpp>
#include <boost/heap/priority_queue.hpp>
20
21namespace mlpack {
22namespace fastmks {
23
33template<typename KernelType, typename TreeType>
35{
36 public:
46 FastMKSRules(const typename TreeType::Mat& referenceSet,
47 const typename TreeType::Mat& querySet,
48 const size_t k,
49 KernelType& kernel);
50
57 void GetResults(arma::Mat<size_t>& indices, arma::mat& products);
58
60 double BaseCase(const size_t queryIndex, const size_t referenceIndex);
61
70 double Score(const size_t queryIndex, TreeType& referenceNode);
71
80 double Score(TreeType& queryNode, TreeType& referenceNode);
81
93 double Rescore(const size_t queryIndex,
94 TreeType& referenceNode,
95 const double oldScore) const;
96
108 double Rescore(TreeType& queryNode,
109 TreeType& referenceNode,
110 const double oldScore) const;
111
113 size_t BaseCases() const { return baseCases; }
115 size_t& BaseCases() { return baseCases; }
116
118 size_t Scores() const { return scores; }
120 size_t& Scores() { return scores; }
121
123
124 const TraversalInfoType& TraversalInfo() const { return traversalInfo; }
125 TraversalInfoType& TraversalInfo() { return traversalInfo; }
126
129 size_t MinimumBaseCases() const { return k; }
130
131 private:
133 const typename TreeType::Mat& referenceSet;
135 const typename TreeType::Mat& querySet;
136
138 typedef std::pair<double, size_t> Candidate;
139
141 struct CandidateCmp {
142 bool operator()(const Candidate& c1, const Candidate& c2) const
143 {
144 return c1.first > c2.first;
145 };
146 };
147
152 typedef boost::heap::priority_queue<Candidate,
153 boost::heap::compare<CandidateCmp>> CandidateList;
154
156 std::vector<CandidateList> candidates;
157
159 const size_t k;
160
162 arma::vec queryKernels;
164 arma::vec referenceKernels;
165
167 KernelType& kernel;
168
170 size_t lastQueryIndex;
172 size_t lastReferenceIndex;
174 double lastKernel;
175
177 double CalculateBound(TreeType& queryNode) const;
178
186 void InsertNeighbor(const size_t queryIndex,
187 const size_t index,
188 const double product);
189
191 size_t baseCases;
193 size_t scores;
194
195 TraversalInfoType traversalInfo;
196};
197
198} // namespace fastmks
199} // namespace mlpack
200
201// Include implementation.
202#include "fastmks_rules_impl.hpp"
203
204#endif
The FastMKSRules class is a template helper class used by the FastMKS class when performing exact max-kernel search.
double Score(TreeType &queryNode, TreeType &referenceNode)
Get the score for recursion order.
size_t BaseCases() const
Get the number of times BaseCase() was called.
size_t MinimumBaseCases() const
Get the minimum number of base cases we need to perform to have acceptable results.
size_t & BaseCases()
Modify the number of times BaseCase() was called.
size_t Scores() const
Get the number of times Score() was called.
double Rescore(TreeType &queryNode, TreeType &referenceNode, const double oldScore) const
Re-evaluate the score for recursion order.
double Score(const size_t queryIndex, TreeType &referenceNode)
Get the score for recursion order.
const TraversalInfoType & TraversalInfo() const
TraversalInfoType & TraversalInfo()
void GetResults(arma::Mat< size_t > &indices, arma::mat &products)
Store the list of candidates for each query point in the given matrices.
tree::TraversalInfo< TreeType > TraversalInfoType
double BaseCase(const size_t queryIndex, const size_t referenceIndex)
Compute the base case (kernel value) between two points.
double Rescore(const size_t queryIndex, TreeType &referenceNode, const double oldScore) const
Re-evaluate the score for recursion order.
FastMKSRules(const typename TreeType::Mat &referenceSet, const typename TreeType::Mat &querySet, const size_t k, KernelType &kernel)
Construct the FastMKSRules object.
size_t & Scores()
Modify the number of times Score() was called.
The TraversalInfo class holds traversal information which is used in dual-tree (and single-tree) traversals.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.