mlpack 3.4.2
nca.hpp
Go to the documentation of this file.
1
12#ifndef MLPACK_METHODS_NCA_NCA_HPP
13#define MLPACK_METHODS_NCA_NCA_HPP
14
15#include <mlpack/prereqs.hpp>
17#include <ensmallen.hpp>
18
20
21namespace mlpack {
22namespace nca {
23
47template<typename MetricType = metric::SquaredEuclideanDistance,
48 typename OptimizerType = ens::StandardSGD>
49class NCA
50{
51 public:
61 NCA(const arma::mat& dataset,
62 const arma::Row<size_t>& labels,
63 MetricType metric = MetricType());
64
77 template<typename... CallbackTypes>
78 void LearnDistance(arma::mat& outputMatrix, CallbackTypes&&... callbacks);
79
81 const arma::mat& Dataset() const { return dataset; }
83 const arma::Row<size_t>& Labels() const { return labels; }
84
86 const OptimizerType& Optimizer() const { return optimizer; }
87 OptimizerType& Optimizer() { return optimizer; }
88
89 private:
91 const arma::mat& dataset;
93 const arma::Row<size_t>& labels;
94
96 MetricType metric;
97
100
102 OptimizerType optimizer;
103};
104
105} // namespace nca
106} // namespace mlpack
107
108// Include the implementation.
109#include "nca_impl.hpp"
110
111#endif
An implementation of Neighborhood Components Analysis, both a linear dimensionality reduction technique and a distance learning technique.
Definition: nca.hpp:50
const arma::mat & Dataset() const
Get the dataset reference.
Definition: nca.hpp:81
OptimizerType & Optimizer()
Definition: nca.hpp:87
void LearnDistance(arma::mat &outputMatrix, CallbackTypes &&... callbacks)
Perform Neighborhood Components Analysis.
NCA(const arma::mat &dataset, const arma::Row< size_t > &labels, MetricType metric=MetricType())
Construct the Neighborhood Components Analysis object.
const arma::Row< size_t > & Labels() const
Get the labels reference.
Definition: nca.hpp:83
const OptimizerType & Optimizer() const
Get the optimizer.
Definition: nca.hpp:86
The "softmax" stochastic neighbor assignment probability function.
LMetric< 2, false > SquaredEuclideanDistance
The squared Euclidean (L2) distance.
Definition: lmetric.hpp:107
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.