mlpack 3.4.2
linear_regression.hpp
Go to the documentation of this file.
1
13#ifndef MLPACK_METHODS_LINEAR_REGRESSION_LINEAR_REGRESSION_HPP
14#define MLPACK_METHODS_LINEAR_REGRESSION_LINEAR_REGRESSION_HPP
15
16#include <mlpack/prereqs.hpp>
17
18namespace mlpack {
19namespace regression {
20
27{
28 public:
37 LinearRegression(const arma::mat& predictors,
38 const arma::rowvec& responses,
39 const double lambda = 0,
40 const bool intercept = true);
41
51 LinearRegression(const arma::mat& predictors,
52 const arma::rowvec& responses,
53 const arma::rowvec& weights,
54 const double lambda = 0,
55 const bool intercept = true);
56
62 LinearRegression() : lambda(0.0), intercept(true) { }
63
76 double Train(const arma::mat& predictors,
77 const arma::rowvec& responses,
78 const bool intercept = true);
79
93 double Train(const arma::mat& predictors,
94 const arma::rowvec& responses,
95 const arma::rowvec& weights,
96 const bool intercept = true);
97
104 void Predict(const arma::mat& points, arma::rowvec& predictions) const;
105
123 double ComputeError(const arma::mat& points,
124 const arma::rowvec& responses) const;
125
127 const arma::vec& Parameters() const { return parameters; }
129 arma::vec& Parameters() { return parameters; }
130
132 double Lambda() const { return lambda; }
134 double& Lambda() { return lambda; }
135
137 bool Intercept() const { return intercept; }
138
142 template<typename Archive>
143 void serialize(Archive& ar, const unsigned int /* version */)
144 {
145 ar & BOOST_SERIALIZATION_NVP(parameters);
146 ar & BOOST_SERIALIZATION_NVP(lambda);
147 ar & BOOST_SERIALIZATION_NVP(intercept);
148 }
149
150 private:
155 arma::vec parameters;
156
161 double lambda;
162
164 bool intercept;
165};
166
167} // namespace regression
168} // namespace mlpack
169
170#endif // MLPACK_METHODS_LINEAR_REGRESSION_HPP
A simple linear regression algorithm using ordinary least squares, with optional per-point weights and optional Tikhonov (ridge) regularization.
LinearRegression(const arma::mat &predictors, const arma::rowvec &responses, const arma::rowvec &weights, const double lambda=0, const bool intercept=true)
Creates the model with weighted learning.
arma::vec & Parameters()
Modify the parameters (the b vector).
const arma::vec & Parameters() const
Return the parameters (the b vector).
double & Lambda()
Modify the Tikhonov regularization parameter for ridge regression.
double Lambda() const
Return the Tikhonov regularization parameter for ridge regression.
LinearRegression(const arma::mat &predictors, const arma::rowvec &responses, const double lambda=0, const bool intercept=true)
Creates the model.
double Train(const arma::mat &predictors, const arma::rowvec &responses, const bool intercept=true)
Train the LinearRegression model on the given data.
void Predict(const arma::mat &points, arma::rowvec &predictions) const
Calculate the predicted response y_i for each data point in points, storing the results in predictions.
bool Intercept() const
Return whether or not an intercept term is used in the model.
double Train(const arma::mat &predictors, const arma::rowvec &responses, const arma::rowvec &weights, const bool intercept=true)
Train the LinearRegression model on the given data and weights.
double ComputeError(const arma::mat &points, const arma::rowvec &responses) const
Calculate the L2 squared error on the given predictors and responses using this linear regression model.
void serialize(Archive &ar, const unsigned int)
Serialize the model.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.