mlpack 3.4.2
lmnn.hpp
Go to the documentation of this file.
1
12#ifndef MLPACK_METHODS_LMNN_LMNN_HPP
13#define MLPACK_METHODS_LMNN_LMNN_HPP
14
15#include <mlpack/prereqs.hpp>
17#include <ensmallen.hpp>
18
19#include "lmnn_function.hpp"
20
21namespace mlpack {
22namespace lmnn {
23
53template<typename MetricType = metric::SquaredEuclideanDistance,
54 typename OptimizerType = ens::AMSGrad>
55class LMNN
56{
57 public:
68 LMNN(const arma::mat& dataset,
69 const arma::Row<size_t>& labels,
70 const size_t k,
71 const MetricType metric = MetricType());
72
73
85 template<typename... CallbackTypes>
86 void LearnDistance(arma::mat& outputMatrix, CallbackTypes&&... callbacks);
87
88
90 const arma::mat& Dataset() const { return dataset; }
91
93 const arma::Row<size_t>& Labels() const { return labels; }
94
96 const double& Regularization() const { return regularization; }
98 double& Regularization() { return regularization; }
99
101 const size_t& Range() const { return range; }
103 size_t& Range() { return range; }
104
106 const size_t& K() const { return k; }
108 size_t K() { return k; }
109
111 const OptimizerType& Optimizer() const { return optimizer; }
112 OptimizerType& Optimizer() { return optimizer; }
113
114 private:
116 const arma::mat& dataset;
117
119 const arma::Row<size_t>& labels;
120
122 size_t k;
123
125 double regularization;
126
128 size_t range;
129
131 MetricType metric;
132
134 OptimizerType optimizer;
135}; // class LMNN
136
137} // namespace lmnn
138} // namespace mlpack
139
140// Include the implementation.
141#include "lmnn_impl.hpp"
142
143#endif
An implementation of the Large Margin Nearest Neighbor metric learning technique.
Definition: lmnn.hpp:56
const arma::mat & Dataset() const
Get the dataset reference.
Definition: lmnn.hpp:90
size_t & Range()
Modify the range value.
Definition: lmnn.hpp:103
LMNN(const arma::mat &dataset, const arma::Row< size_t > &labels, const size_t k, const MetricType metric=MetricType())
Initialize the LMNN object, passing a dataset (distance metric is learned using this dataset) and labels.
OptimizerType & Optimizer()
Definition: lmnn.hpp:112
const size_t & Range() const
Access the range value.
Definition: lmnn.hpp:101
void LearnDistance(arma::mat &outputMatrix, CallbackTypes &&... callbacks)
Perform Large Margin Nearest Neighbors metric learning.
double & Regularization()
Modify the regularization value.
Definition: lmnn.hpp:98
const arma::Row< size_t > & Labels() const
Get the labels reference.
Definition: lmnn.hpp:93
const double & Regularization() const
Access the regularization value.
Definition: lmnn.hpp:96
const size_t & K() const
Access the value of k.
Definition: lmnn.hpp:106
const OptimizerType & Optimizer() const
Get the optimizer.
Definition: lmnn.hpp:111
size_t K()
Modify the value of k.
Definition: lmnn.hpp:108
LMetric< 2, false > SquaredEuclideanDistance
The squared Euclidean (L2) distance.
Definition: lmetric.hpp:107
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.