mlpack 3.4.2
dtree.hpp
Go to the documentation of this file.
1
13#ifndef MLPACK_METHODS_DET_DTREE_HPP
14#define MLPACK_METHODS_DET_DTREE_HPP
15
16#include <mlpack/prereqs.hpp>
17
18namespace mlpack {
19namespace det {
20
// A node of a density estimation tree (DET).  Each node stores the bounding box
// of its region (maxVals / minVals), the half-open range [start, end) of point
// indices it covers, its split (splitDim / splitValue) and pointers to its two
// children.
//
// Template parameters:
//   MatType - matrix type holding the data; must provide elem_type and
//             vec_type (defaults to arma::mat).
//   TagType - type used to tag nodes / leaves (defaults to int).
//
// NOTE(review): this is a line-numbered listing; the leading number on each
// line is the line number in the original header, not part of the code.
44template<typename MatType = arma::mat,
45 typename TagType = int>
46class DTree
47{
48 public:
// The actual, underlying element type we're working with.
50 typedef typename MatType::elem_type ElemType;
// The type of vector we are using (e.g. for query points).
52 typedef typename MatType::vec_type VecType;
// The statistic type we are holding (used for the per-dimension bounds).
54 typedef typename arma::Col<ElemType> StatType;
55
// NOTE(review): the member index of this file lists a default constructor
// `DTree()` ("Create an empty density estimation tree") that is missing from
// this listing -- presumably it was declared on this line; confirm against
// the original header.
60
// Create a tree that is the copy of the given tree.
66 DTree(const DTree& obj);
67
// Copy the given tree.
73 DTree& operator=(const DTree& obj);
74
// Create a tree by taking ownership of another tree (move constructor).
80 DTree(DTree&& obj);
81
// NOTE(review): the member index lists a move assignment operator
// `DTree& operator=(DTree&& obj)` ("Take ownership of the given tree") that is
// missing from this listing -- presumably declared on this line; confirm.
88
// Create a density estimation tree with the given bounds and the given number
// of total points.
//
// maxVals     - maximum values of the bounding box.
// minVals     - minimum values of the bounding box.
// totalPoints - total number of points in the dataset.
97 DTree(const StatType& maxVals,
98 const StatType& minVals,
99 const size_t totalPoints);
100
// Create a density estimation tree on the given data.
//
// data - dataset to build on; taken by non-const reference.
//        NOTE(review): whether the constructor modifies it is not visible
//        here -- check dtree_impl.hpp.
109 DTree(MatType& data);
110
// Create a child node of a density estimation tree given the bounding box
// specified by maxVals and minVals.
//
// maxVals     - maximum values of the bounding box.
// minVals     - minimum values of the bounding box.
// start       - starting index of the points contained in this node.
// end         - first index of a point not contained in this node.
// logNegError - log negative error to store for this node.
123 DTree(const StatType& maxVals,
124 const StatType& minVals,
125 const size_t start,
126 const size_t end,
127 const double logNegError);
128
// Create a child node of a density estimation tree given the bounding box
// specified by maxVals and minVals.  Unlike the overload above, this one takes
// the total number of points instead of a log negative error.
//
// maxVals     - maximum values of the bounding box.
// minVals     - minimum values of the bounding box.
// totalPoints - total number of points in the dataset.
// start       - starting index of the points contained in this node.
// end         - first index of a point not contained in this node.
141 DTree(const StatType& maxVals,
142 const StatType& minVals,
143 const size_t totalPoints,
144 const size_t start,
145 const size_t end);
146
// NOTE(review): the member index lists a destructor `~DTree()` ("Clean up
// memory allocated by the tree") that is missing from this listing --
// presumably declared on this line; confirm.
149
// Greedily expand the tree.
//
// data        - dataset to grow the tree on (non-const reference).
// oldFromNew  - index mapping; per the comment in the private section it maps
//               points back to their original indices.
// useVolReg   - defaults to false.  NOTE(review): presumably toggles volume
//               regularization; confirm in dtree_impl.hpp.
// maxLeafSize - defaults to 10.
// minLeafSize - defaults to 5.
//
// Returns a double; its meaning is not visible in this header.
160 double Grow(MatType& data,
161 arma::Col<size_t>& oldFromNew,
162 const bool useVolReg = false,
163 const size_t maxLeafSize = 10,
164 const size_t minLeafSize = 5);
165
// Perform alpha pruning on a tree.
//
// oldAlpha  - alpha value to prune with.
// points    - number of points (NOTE(review): presumably the dataset total).
// useVolReg - defaults to false.
//
// Returns a double; its meaning is not visible in this header.
174 double PruneAndUpdate(const double oldAlpha,
175 const size_t points,
176 const bool useVolReg = false);
177
// Compute the logarithm of the density estimate of a given query point.
183 double ComputeValue(const VecType& query) const;
184
// Index the buckets for possible usage later.
//
// tag       - starting tag; defaults to 0.
// everyNode - defaults to false.  NOTE(review): presumably tags internal
//             nodes as well as leaves when true; confirm.
194 TagType TagTree(const TagType& tag = 0, bool everyNode = false);
195
196
// Return the tag of the leaf containing the query.
203 TagType FindBucket(const VecType& query) const;
204
205
// Compute the variable importance of each dimension in the learned tree.
// The result is written into `importances`.
211 void ComputeVariableImportance(arma::vec& importances) const;
212
// Compute the log-negative-error for this node, given the total number of
// points in the dataset.
219 double LogNegativeError(const size_t totalPoints) const;
220
// Return whether a query point is within the range of this node.
224 bool WithinRange(const VecType& query) const;
225
226 private:
227 // The indices in the complete set of points
228 // (after all forms of swapping in the original data
229 // matrix to align all the points in a node
230 // consecutively in the matrix). The 'old_from_new' array
231 // maps the points back to their original indices.
232
// Starting index of the points contained in this node.
235 size_t start;
// First index of a point not contained in this node.
238 size_t end;
239
// Maximum values of the node's bounding box.
241 StatType maxVals;
// Minimum values of the node's bounding box.
243 StatType minVals;
244
// The split dimension of this node.
246 size_t splitDim;
247
// The split value of this node.
249 ElemType splitValue;
250
// The log negative error of this node.
252 double logNegError;
253
// The log negative error of all descendants of this node.
255 double subtreeLeavesLogNegError;
256
// The number of leaves which are descendants of this node.
258 size_t subtreeLeaves;
259
// Whether or not this is the root of the tree.
261 bool root;
262
// The ratio of points in this node to the points in the whole dataset.
264 double ratio;
265
// NOTE(review): by its name, the logarithm of this node's volume; the member
// index describes the accessor as "the inverse of the volume" -- confirm.
267 double logVolume;
268
// The current bucket's ID, if leaf, or -1 otherwise.
270 TagType bucketTag;
271
// The upper part of the alpha sum.
273 double alphaUpper;
274
// The left child (null when there are no children; see NumChildren()).
276 DTree* left;
// The right child.
278 DTree* right;
279
280 public:
// Return the starting index of points contained in this node.
282 size_t Start() const { return start; }
// Return the first index of a point not contained in this node.
284 size_t End() const { return end; }
// Return the split dimension of this node.
286 size_t SplitDim() const { return splitDim; }
// Return the split value of this node.
288 ElemType SplitValue() const { return splitValue; }
// Return the log negative error of this node.
290 double LogNegError() const { return logNegError; }
// Return the log negative error of all descendants of this node.
292 double SubtreeLeavesLogNegError() const { return subtreeLeavesLogNegError; }
// Return the number of leaves which are descendants of this node.
294 size_t SubtreeLeaves() const { return subtreeLeaves; }
// Return the ratio of points in this node to the points in the whole dataset.
297 double Ratio() const { return ratio; }
// Return the stored logVolume value.
299 double LogVolume() const { return logVolume; }
// Return the left child (may be null).
301 DTree* Left() const { return left; }
// Return the right child (may be null).
303 DTree* Right() const { return right; }
// Return whether or not this is the root of the tree.
305 bool Root() const { return root; }
// Return the upper part of the alpha sum.
307 double AlphaUpper() const { return alphaUpper; }
// Return the current bucket's ID, if leaf, or -1 otherwise.
309 TagType BucketTag() const { return bucketTag; }
// Return the number of children in this node: 0 when `left` is null,
// otherwise 2 (only `left` is inspected).
311 size_t NumChildren() const { return !left ? 0 : 2; }
312
// Return the specified child (0 will be left, any other value will be right).
// The pointer is dereferenced without a null check, so this must only be
// called on a node that has children.
319 DTree& Child(const size_t child) const { return !child ? *left : *right; }
320
// Return a modifiable reference to the child pointer itself (0 selects left,
// any other value selects right), allowing the caller to re-seat the child.
321 DTree*& ChildPtr(const size_t child) { return (!child) ? left : right; }
322
// Return the maximum values.
324 const StatType& MaxVals() const { return maxVals; }
325
// Return the minimum values.
327 const StatType& MinVals() const { return minVals; }
328
// Serialize the density estimation tree.  The version argument is unused.
332 template<typename Archive>
333 void serialize(Archive& ar, const unsigned int /* version */);
334
335 private:
336 // Utility methods.
337
// Search for a split of this node over `data`.  splitDim, splitValue,
// leftError and rightError are output parameters; minLeafSize defaults to 5.
// NOTE(review): presumably returns true when a valid split was found --
// confirm in dtree_impl.hpp.
341 bool FindSplit(const MatType& data,
342 size_t& splitDim,
343 ElemType& splitValue,
344 double& leftError,
345 double& rightError,
346 const size_t minLeafSize = 5) const;
347
// Partition `data` around splitValue in dimension splitDim, updating
// oldFromNew alongside it.
// NOTE(review): presumably returns the index of the first point of the right
// partition -- confirm in dtree_impl.hpp.
351 size_t SplitData(MatType& data,
352 const size_t splitDim,
353 const ElemType splitValue,
354 arma::Col<size_t>& oldFromNew) const;
355
// NOTE(review): presumably sets this node's bounding box from the given
// minimums and maximums -- confirm in dtree_impl.hpp.  Note the argument
// order is (mins, maxs), the reverse of the constructors' (maxVals, minVals).
356 void FillMinMax(const StatType& mins,
357 const StatType& maxs);
358};
359
360} // namespace det
361} // namespace mlpack
362
363#include "dtree_impl.hpp"
364
365#endif // MLPACK_METHODS_DET_DTREE_HPP
A density estimation tree is similar to both a decision tree and a space partitioning tree (like a kd-tree).
Definition: dtree.hpp:47
double Ratio() const
Return the ratio of points in this node to the points in the whole dataset.
Definition: dtree.hpp:297
DTree & operator=(const DTree &obj)
Copy the given tree.
const StatType & MinVals() const
Return the minimum values.
Definition: dtree.hpp:327
double LogNegativeError(const size_t totalPoints) const
Compute the log-negative-error for this node, given the total number of points in the dataset.
DTree(const StatType &maxVals, const StatType &minVals, const size_t start, const size_t end, const double logNegError)
Create a child node of a density estimation tree given the bounding box specified by maxVals and minVals...
DTree *& ChildPtr(const size_t child)
Definition: dtree.hpp:321
void ComputeVariableImportance(arma::vec &importances) const
Compute the variable importance of each dimension in the learned tree.
size_t NumChildren() const
Return the number of children in this node.
Definition: dtree.hpp:311
DTree(const DTree &obj)
Create a tree that is the copy of the given tree.
MatType::elem_type ElemType
The actual, underlying type we're working with.
Definition: dtree.hpp:50
size_t SubtreeLeaves() const
Return the number of leaves which are descendants of this node.
Definition: dtree.hpp:294
TagType BucketTag() const
Return the current bucket's ID, if leaf, or -1 otherwise.
Definition: dtree.hpp:309
double Grow(MatType &data, arma::Col< size_t > &oldFromNew, const bool useVolReg=false, const size_t maxLeafSize=10, const size_t minLeafSize=5)
Greedily expand the tree.
DTree(DTree &&obj)
Create a tree by taking ownership of another tree (move constructor).
ElemType SplitValue() const
Return the split value of this node.
Definition: dtree.hpp:288
DTree * Left() const
Return the left child.
Definition: dtree.hpp:301
double LogNegError() const
Return the log negative error of this node.
Definition: dtree.hpp:290
DTree * Right() const
Return the right child.
Definition: dtree.hpp:303
double ComputeValue(const VecType &query) const
Compute the logarithm of the density estimate of a given query point.
size_t SplitDim() const
Return the split dimension of this node.
Definition: dtree.hpp:286
DTree(const StatType &maxVals, const StatType &minVals, const size_t totalPoints, const size_t start, const size_t end)
Create a child node of a density estimation tree given the bounding box specified by maxVals and minVals...
MatType::vec_type VecType
The type of vector we are using.
Definition: dtree.hpp:52
arma::Col< ElemType > StatType
The statistic type we are holding.
Definition: dtree.hpp:54
DTree & operator=(DTree &&obj)
Take ownership of the given tree (move operator).
bool WithinRange(const VecType &query) const
Return whether a query point is within the range of this node.
double LogVolume() const
Return the logarithm of the volume of this node.
Definition: dtree.hpp:299
DTree & Child(const size_t child) const
Return the specified child (0 will be left, 1 will be right).
Definition: dtree.hpp:319
size_t End() const
Return the first index of a point not contained in this node.
Definition: dtree.hpp:284
const StatType & MaxVals() const
Return the maximum values.
Definition: dtree.hpp:324
double AlphaUpper() const
Return the upper part of the alpha sum.
Definition: dtree.hpp:307
double PruneAndUpdate(const double oldAlpha, const size_t points, const bool useVolReg=false)
Perform alpha pruning on a tree.
TagType FindBucket(const VecType &query) const
Return the tag of the leaf containing the query.
double SubtreeLeavesLogNegError() const
Return the log negative error of all descendants of this node.
Definition: dtree.hpp:292
DTree(const StatType &maxVals, const StatType &minVals, const size_t totalPoints)
Create a density estimation tree with the given bounds and the given number of total points.
bool Root() const
Return whether or not this is the root of the tree.
Definition: dtree.hpp:305
DTree()
Create an empty density estimation tree.
TagType TagTree(const TagType &tag=0, bool everyNode=false)
Index the buckets for possible usage later; this results in every leaf in the tree having a specific ...
size_t Start() const
Return the starting index of points contained in this node.
Definition: dtree.hpp:282
void serialize(Archive &ar, const unsigned int)
Serialize the density estimation tree.
DTree(MatType &data)
Create a density estimation tree on the given data.
~DTree()
Clean up memory allocated by the tree.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.