mlpack 3.4.2
rp_tree_mean_split.hpp
Go to the documentation of this file.
1
13#ifndef MLPACK_CORE_TREE_BINARY_SPACE_TREE_RP_TREE_MEAN_SPLIT_HPP
14#define MLPACK_CORE_TREE_BINARY_SPACE_TREE_RP_TREE_MEAN_SPLIT_HPP
15
16#include <mlpack/prereqs.hpp>
17#include "rp_tree_max_split.hpp"
20
21namespace mlpack {
22namespace tree {
23
32template<typename BoundType, typename MatType = arma::mat>
34{
35 public:
37 typedef typename MatType::elem_type ElemType;
39 struct SplitInfo
40 {
42 arma::Col<ElemType> direction;
44 arma::Col<ElemType> mean;
50 };
51
64 static bool SplitNode(const BoundType& /* bound */,
65 MatType& data,
66 const size_t begin,
67 const size_t count,
68 SplitInfo& splitInfo);
69
82 static size_t PerformSplit(MatType& data,
83 const size_t begin,
84 const size_t count,
85 const SplitInfo& splitInfo)
86 {
87 return split::PerformSplit<MatType, RPTreeMeanSplit>(data, begin, count,
88 splitInfo);
89 }
90
106 static size_t PerformSplit(MatType& data,
107 const size_t begin,
108 const size_t count,
109 const SplitInfo& splitInfo,
110 std::vector<size_t>& oldFromNew)
111 {
112 return split::PerformSplit<MatType, RPTreeMeanSplit>(data, begin, count,
113 splitInfo, oldFromNew);
114 }
115
122 template<typename VecType>
123 static bool AssignToLeftNode(const VecType& point, const SplitInfo& splitInfo)
124 {
125 if (splitInfo.meanSplit)
126 return arma::dot(point - splitInfo.mean, point - splitInfo.mean) <=
127 splitInfo.splitVal;
128
129 return (arma::dot(point, splitInfo.direction) <= splitInfo.splitVal);
130 }
131
132 private:
139 static ElemType GetAveragePointDistance(MatType& data,
140 const arma::uvec& samples);
141
151 static bool GetDotMedian(const MatType& data,
152 const arma::uvec& samples,
153 const arma::Col<ElemType>& direction,
154 ElemType& splitVal);
155
165 static bool GetMeanMedian(const MatType& data,
166 const arma::uvec& samples,
167 arma::Col<ElemType>& mean,
168 ElemType& splitVal);
169};
170
171} // namespace tree
172} // namespace mlpack
173
174// Include implementation.
175#include "rp_tree_mean_split_impl.hpp"
176
177#endif // MLPACK_CORE_TREE_BINARY_SPACE_TREE_RP_TREE_MEAN_SPLIT_HPP
This class splits a binary space tree.
MatType::elem_type ElemType
The element type held by the matrix type.
static bool SplitNode(const BoundType &, MatType &data, const size_t begin, const size_t count, SplitInfo &splitInfo)
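Compute a split for the node: either a hyperplane split (by projection onto a direction) or a split by squared distance to the mean of sampled points; the result is stored in splitInfo.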
static size_t PerformSplit(MatType &data, const size_t begin, const size_t count, const SplitInfo &splitInfo)
Perform the split process according to the information about the split.
static bool AssignToLeftNode(const VecType &point, const SplitInfo &splitInfo)
Indicates whether a point should be assigned to the left subtree.
static size_t PerformSplit(MatType &data, const size_t begin, const size_t count, const SplitInfo &splitInfo, std::vector< size_t > &oldFromNew)
Perform the split process according to the information about the split and return the list of changed indices (in oldFromNew).
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.
Information about the partition.
ElemType splitVal
The value according to which the split will be performed.
arma::Col< ElemType > mean
The mean of some sampled points.
arma::Col< ElemType > direction
The normal to the hyperplane that will split the node.
bool meanSplit
Indicates whether to use the mean split algorithm instead of the median split.