mlpack 3.4.2
mean_shift.hpp
Go to the documentation of this file.
1
13#ifndef MLPACK_METHODS_MEAN_SHIFT_MEAN_SHIFT_HPP
14#define MLPACK_METHODS_MEAN_SHIFT_MEAN_SHIFT_HPP
15
16#include <mlpack/prereqs.hpp>
20#include <boost/utility.hpp>
21
22namespace mlpack {
23namespace meanshift {
24
48template<bool UseKernel = false,
49 typename KernelType = kernel::GaussianKernel,
50 typename MatType = arma::mat>
52{
53 public:
65 MeanShift(const double radius = 0,
66 const size_t maxIterations = 1000,
67 const KernelType kernel = KernelType());
68
75 double EstimateRadius(const MatType& data, const double ratio = 0.2);
76
89 void Cluster(const MatType& data,
90 arma::Row<size_t>& assignments,
91 arma::mat& centroids,
92 bool forceConvergence = true,
93 bool useSeeds = true);
94
96 size_t MaxIterations() const { return maxIterations; }
98 size_t& MaxIterations() { return maxIterations; }
99
101 double Radius() const { return radius; }
103 void Radius(double radius);
104
106 const KernelType& Kernel() const { return kernel; }
108 KernelType& Kernel() { return kernel; }
109
110 private:
124 void GenSeeds(const MatType& data,
125 const double binSize,
126 const int minFreq,
127 MatType& seeds);
128
137 template<bool ApplyKernel = UseKernel>
138 typename std::enable_if<ApplyKernel, bool>::type
139 CalculateCentroid(const MatType& data,
140 const std::vector<size_t>& neighbors,
141 const std::vector<double>& distances,
142 arma::colvec& centroid);
143
152 template<bool ApplyKernel = UseKernel>
153 typename std::enable_if<!ApplyKernel, bool>::type
154 CalculateCentroid(const MatType& data,
155 const std::vector<size_t>& neighbors,
156 const std::vector<double>&, /*unused*/
157 arma::colvec& centroid);
158
164 double radius;
165
167 size_t maxIterations;
168
170 KernelType kernel;
171};
172
173} // namespace meanshift
174} // namespace mlpack
175
176// Include implementation.
177#include "mean_shift_impl.hpp"
178
179#endif // MLPACK_METHODS_MEAN_SHIFT_MEAN_SHIFT_HPP
The standard Gaussian kernel.
This class implements mean shift clustering.
Definition: mean_shift.hpp:52
double Radius() const
Get the radius.
Definition: mean_shift.hpp:101
double EstimateRadius(const MatType &data, const double ratio=0.2)
Give an estimate of the radius based on the given dataset.
KernelType & Kernel()
Modify the kernel.
Definition: mean_shift.hpp:108
MeanShift(const double radius=0, const size_t maxIterations=1000, const KernelType kernel=KernelType())
Create a mean shift object and set the parameters with which mean shift will be run.
size_t MaxIterations() const
Get the maximum number of iterations.
Definition: mean_shift.hpp:96
size_t & MaxIterations()
Set the maximum number of iterations.
Definition: mean_shift.hpp:98
const KernelType & Kernel() const
Get the kernel.
Definition: mean_shift.hpp:106
void Radius(double radius)
Set the radius.
void Cluster(const MatType &data, arma::Row< size_t > &assignments, arma::mat &centroids, bool forceConvergence=true, bool useSeeds=true)
Perform mean shift clustering on the data, returning a list of cluster assignments and centroids.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.