mlpack 3.4.2
standard_scaler.hpp
Go to the documentation of this file.
1
12#ifndef MLPACK_CORE_DATA_STANDARD_SCALE_HPP
13#define MLPACK_CORE_DATA_STANDARD_SCALE_HPP
14
15#include <mlpack/prereqs.hpp>
16
17namespace mlpack {
18namespace data {
19
48{
49 public:
55 template<typename MatType>
56 void Fit(const MatType& input)
57 {
58 itemMean = arma::mean(input, 1);
59 itemStdDev = arma::stddev(input, 1, 1);
60 // Handle zeros in scale vector.
61 itemStdDev.for_each([](arma::vec::elem_type& val) { val =
62 (val == 0) ? 1 : val; });
63 }
64
71 template<typename MatType>
72 void Transform(const MatType& input, MatType& output)
73 {
74 if (itemMean.is_empty() || itemStdDev.is_empty())
75 {
76 throw std::runtime_error("Call Fit() before Transform(), please"
77 " refer to the documentation.");
78 }
79 output.copy_size(input);
80 output = (input.each_col() - itemMean).each_col() / itemStdDev;
81 }
82
89 template<typename MatType>
90 void InverseTransform(const MatType& input, MatType& output)
91 {
92 output.copy_size(input);
93 output = (input.each_col() % itemStdDev).each_col() + itemMean;
94 }
95
97 const arma::vec& ItemMean() const { return itemMean; }
99 const arma::vec& ItemStdDev() const { return itemStdDev; }
100
//! Serialize the fitted statistics (Boost.Serialization entry point).
//! The version argument is currently unused.
template<typename Archive>
void serialize(Archive& ar, const unsigned int /* version */)
{
  // The order of these two statements determines the order of the fields in
  // the archive; keep it unchanged so previously saved scalers still load.
  ar & BOOST_SERIALIZATION_NVP(itemMean);
  ar & BOOST_SERIALIZATION_NVP(itemStdDev);
}
107
108 private:
// Vector which holds mean of each feature (one entry per row of the data
// passed to Fit()).
arma::vec itemMean;
// Vector which holds standard deviation of each feature; Fit() stores 1 for
// features whose standard deviation is zero.
arma::vec itemStdDev;
113}; // class StandardScaler
114
115} // namespace data
116} // namespace mlpack
117
118#endif
A simple Standard Scaler class.
void Fit(const MatType &input)
Function to fit features, computing the mean and standard deviation of each feature.
const arma::vec & ItemMean() const
Get the vector of per-feature means (one entry per row of the fitted data).
const arma::vec & ItemStdDev() const
Get the vector of per-feature standard deviations (one entry per row of the fitted data).
void Transform(const MatType &input, MatType &output)
Function to scale features.
void serialize(Archive &ar, const unsigned int)
void InverseTransform(const MatType &input, MatType &output)
Function to retrieve original dataset.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.