13#ifndef MLPACK_METHODS_LMNN_FUNCTION_HPP
14#define MLPACK_METHODS_LMNN_FUNCTION_HPP
45template<
typename MetricType = metric::SquaredEucl
ideanDistance>
60 const arma::Row<size_t>& labels,
62 double regularization,
64 MetricType metric = MetricType());
/**
 * Evaluate the LMNN function for the given transformation matrix over the
 * whole dataset (a batched overload taking begin/batchSize also exists).
 *
 * @param transformation Transformation matrix at which to evaluate.
 * @return The objective value.
 */
79 double Evaluate(
const arma::mat& transformation);
93 double Evaluate(
const arma::mat& transformation,
95 const size_t batchSize = 1);
/**
 * Evaluate the gradient of the LMNN function for the given transformation
 * matrix, writing the result into gradient.
 *
 * @tparam GradType Type of the gradient output (presumably an Armadillo
 *     matrix type -- TODO confirm against lmnn_function_impl.hpp).
 * @param transformation Transformation matrix at which to evaluate.
 * @param gradient Output parameter that receives the computed gradient.
 */
106 template<
typename GradType>
107 void Gradient(
const arma::mat& transformation, GradType& gradient);
124 template<
typename GradType>
128 const size_t batchSize = 1);
140 template<
typename GradType>
159 template<
typename GradType>
163 const size_t batchSize = 1);
//! Return the dataset passed into the constructor (read-only reference).
175 const arma::mat&
Dataset()
const {
return dataset; }
//! Access the value of k (read-only).
183 const size_t&
K()
const {
return k; }
//! Modify the value of k (returns a mutable reference to the stored value).
185 size_t&
K() {
return k; }
//! Access the value of range (read-only).
188 const size_t&
Range()
const {
return range; }
//! Labels of the points, as given to the constructor (assumed one per
//! dataset column -- TODO confirm).
196 arma::Row<size_t> labels;
//! Initial point for the optimization (presumably what GetInitialPoint()
//! returns -- verify).
198 arma::mat initialPoint;
//! Cached transformed dataset (presumably transformation applied to the
//! dataset -- verify in lmnn_function_impl.hpp).
200 arma::mat transformedDataset;
//! Indices of target neighbors (layout not visible here -- TODO confirm).
202 arma::Mat<size_t> targetNeighbors;
//! Indices of impostors (layout not visible here -- TODO confirm).
204 arma::Mat<size_t> impostors;
//! NOTE(review): meaning of maxImpNorm is not visible in this view; see
//! lmnn_function_impl.hpp.
226 arma::mat maxImpNorm;
//! A previously used transformation matrix (presumably cached for reuse
//! across evaluations -- verify).
228 arma::mat transformationOld;
//! Previously used transformation matrices (presumably a cache -- verify).
230 std::vector<arma::mat> oldTransformationMatrices;
//! Counts associated with each entry of oldTransformationMatrices
//! (presumably usage counts -- verify).
232 std::vector<size_t> oldTransformationCounts;
//! Per-point indices into the cached transformations (presumably -- verify).
234 arma::vec lastTransformationIndices;
//! Declared inline here; defined out of line, presumably in
//! lmnn_function_impl.hpp (likely precomputes cached state -- TODO confirm).
244 inline void Precalculate();
246 inline void UpdateCache(
const arma::mat& transformation,
248 const size_t batchSize);
250 inline void TransDiff(std::map<size_t, double>& transformationDiffs,
251 const arma::mat& transformation,
253 const size_t batchSize);
259#include "lmnn_function_impl.hpp"
Interface for generating distance-based constraints on a given dataset, provided the corresponding true labels.
The Large Margin Nearest Neighbors function.
double EvaluateWithGradient(const arma::mat &transformation, GradType &gradient)
Evaluate the LMNN objective function together with gradient for the given transformation matrix.
const arma::mat & Dataset() const
Return the dataset passed into the constructor.
size_t & Range()
Modify the value of range.
size_t NumFunctions() const
Get the number of functions the objective function can be decomposed into.
void Shuffle()
Shuffle the points in the dataset.
double Evaluate(const arma::mat &transformation, const size_t begin, const size_t batchSize=1)
Evaluate the LMNN objective function for the given transformation matrix on a batch of batchSize points, starting at index begin of the dataset.
const size_t & Range() const
Access the value of range.
double EvaluateWithGradient(const arma::mat &transformation, const size_t begin, GradType &gradient, const size_t batchSize=1)
Evaluate the LMNN objective function together with its gradient for the given transformation matrix on a batch of batchSize points, starting at index begin of the dataset.
void Gradient(const arma::mat &transformation, GradType &gradient)
Evaluate the gradient of the LMNN function for the given transformation matrix.
double & Regularization()
Modify the regularization value.
double Evaluate(const arma::mat &transformation)
Evaluate the LMNN function for the given transformation matrix.
const double & Regularization() const
Access the regularization value.
size_t & K()
Modify the value of k.
const arma::mat & GetInitialPoint() const
Return the initial point for the optimization.
void Gradient(const arma::mat &transformation, const size_t begin, GradType &gradient, const size_t batchSize=1)
Evaluate the gradient of the LMNN function for the given transformation matrix on a batch of batchSize points, starting at index begin of the dataset.
const size_t & K() const
Access the value of k.
LMNNFunction(const arma::mat &dataset, const arma::Row< size_t > &labels, size_t k, double regularization, size_t range, MetricType metric=MetricType())
Constructor for LMNNFunction class.
To use this form of regularization
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.