mlpack 3.4.2
mean_normalization.hpp
Go to the documentation of this file.
1
12#ifndef MLPACK_CORE_DATA_MEAN_NORMALIZATION_HPP
13#define MLPACK_CORE_DATA_MEAN_NORMALIZATION_HPP
14
15#include <mlpack/prereqs.hpp>
16
17namespace mlpack {
18namespace data {
19
47{
48 public:
54 template<typename MatType>
55 void Fit(const MatType& input)
56 {
57 itemMean = arma::mean(input, 1);
58 itemMin = arma::min(input, 1);
59 itemMax = arma::max(input, 1);
60 scale = itemMax - itemMin;
61 // Handling zeros in scale vector.
62 scale.for_each([](arma::vec::elem_type& val) { val =
63 (val == 0) ? 1 : val; });
64 }
65
72 template<typename MatType>
73 void Transform(const MatType& input, MatType& output)
74 {
75 if (itemMean.is_empty() || scale.is_empty())
76 {
77 throw std::runtime_error("Call Fit() before Transform(), please"
78 " refer to the documentation.");
79 }
80 output.copy_size(input);
81 output = (input.each_col() - itemMean).each_col() / scale;
82 }
83
90 template<typename MatType>
91 void InverseTransform(const MatType& input, MatType& output)
92 {
93 output.copy_size(input);
94 output = (input.each_col() % scale).each_col() + itemMean;
95 }
96
98 const arma::vec& ItemMean() const { return itemMean; }
100 const arma::vec& ItemMin() const { return itemMin; }
102 const arma::vec& ItemMax() const { return itemMax; }
104 const arma::vec& Scale() const { return scale; }
105
106 template<typename Archive>
107 void serialize(Archive& ar, const unsigned int /* version */)
108 {
109 ar & BOOST_SERIALIZATION_NVP(itemMin);
110 ar & BOOST_SERIALIZATION_NVP(itemMax);
111 ar & BOOST_SERIALIZATION_NVP(scale);
112 ar & BOOST_SERIALIZATION_NVP(itemMean);
113 }
114
115 private:
116 // Vector which holds mean of each feature.
117 arma::vec itemMean;
118 // Vector which holds minimum of each feature.
119 arma::vec itemMin;
120 // Vector which holds maximum of each feature.
121 arma::vec itemMax;
122 // Vector which is used to scale up each feature.
123 arma::vec scale;
124}; // class MeanNormalization
125
126} // namespace data
127} // namespace mlpack
128
129#endif
A simple Mean Normalization class.
void Fit(const MatType &input)
Fit the normalizer to the data: compute the per-feature mean, minimum, maximum, and scale.
const arma::vec & ItemMean() const
Get the vector of per-feature (per-row) means.
void Transform(const MatType &input, MatType &output)
Normalize features using the mean and scale computed by Fit().
const arma::vec & Scale() const
Get the vector of per-feature scales (ranges, with zero ranges replaced by 1).
const arma::vec & ItemMax() const
Get the vector of per-feature (per-row) maximums.
const arma::vec & ItemMin() const
Get the vector of per-feature (per-row) minimums.
void serialize(Archive &ar, const unsigned int)
void InverseTransform(const MatType &input, MatType &output)
Map mean-normalized data back to the original feature space.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.