mlpack 3.4.2
regression_distribution.hpp
Go to the documentation of this file.
1
13#ifndef MLPACK_CORE_DISTRIBUTIONS_REGRESSION_DISTRIBUTION_HPP
14#define MLPACK_CORE_DISTRIBUTIONS_REGRESSION_DISTRIBUTION_HPP
15
16#include <mlpack/prereqs.hpp>
19
20namespace mlpack {
21namespace distribution {
22
32{
33 private:
38
39 public:
43 RegressionDistribution() { /* nothing to do */ }
44
52 mlpack_deprecated RegressionDistribution(const arma::mat& predictors,
53 const arma::vec& responses) :
54 RegressionDistribution(predictors, arma::rowvec(responses.t()))
55 {}
56
64 RegressionDistribution(const arma::mat& predictors,
65 const arma::rowvec& responses)
66 {
67 rf.Train(predictors, responses);
68 err = GaussianDistribution(1);
69 arma::mat cov(1, 1);
70 cov(0, 0) = rf.ComputeError(predictors, responses);
71 err.Covariance(std::move(cov));
72 }
73
/**
 * Serialize the distribution: the regression function (rf) followed by the
 * error distribution (err).
 */
77 template<typename Archive>
78 void serialize(Archive& ar, const unsigned int /* version */)
79 {
// The statement order below defines the archive layout; reordering would
// presumably break archives written by earlier builds.
80 ar & BOOST_SERIALIZATION_NVP(rf);
81 ar & BOOST_SERIALIZATION_NVP(err);
82 }
83
85 const regression::LinearRegression& Rf() const { return rf; }
88
90 const GaussianDistribution& Err() const { return err; }
92 GaussianDistribution& Err() { return err; }
93
99 void Train(const arma::mat& observations);
100
107 mlpack_deprecated void Train(const arma::mat& observations,
108 const arma::vec& weights);
109
116 void Train(const arma::mat& observations, const arma::rowvec& weights);
117
123 double Probability(const arma::vec& observation) const;
124
130 double LogProbability(const arma::vec& observation) const
131 {
132 return log(Probability(observation));
133 }
134
141 mlpack_deprecated void Predict(const arma::mat& points,
142 arma::vec& predictions) const;
143
150 void Predict(const arma::mat& points, arma::rowvec& predictions) const;
151
153 const arma::vec& Parameters() const { return rf.Parameters(); }
154
156 size_t Dimensionality() const { return rf.Parameters().n_elem; }
157};
158
159
160} // namespace distribution
161} // namespace mlpack
162
163#endif
A single multivariate Gaussian distribution.
const arma::mat & Covariance() const
Return the covariance matrix.
A class that represents a univariate conditionally Gaussian distribution.
void Train(const arma::mat &observations)
Estimate the Gaussian distribution directly from the given observations.
RegressionDistribution()
Default constructor, which creates an untrained distribution whose Gaussian error model has zero dimension.
const arma::vec & Parameters() const
Return the parameters (the b vector).
mlpack_deprecated void Train(const arma::mat &observations, const arma::vec &weights)
Estimate parameters using provided observation weights.
double LogProbability(const arma::vec &observation) const
Evaluate log probability density function of given observation.
void Predict(const arma::mat &points, arma::rowvec &predictions) const
Calculate y_i for each data point in points.
size_t Dimensionality() const
Return the dimensionality.
mlpack_deprecated void Predict(const arma::mat &points, arma::vec &predictions) const
Calculate y_i for each data point in points.
RegressionDistribution(const arma::mat &predictors, const arma::rowvec &responses)
Create a conditional Gaussian distribution whose conditional mean function is obtained by running linear regression on the given predictors and responses.
double Probability(const arma::vec &observation) const
Evaluate probability density function of given observation.
regression::LinearRegression & Rf()
Modify regression function.
GaussianDistribution & Err()
Modify error distribution.
const GaussianDistribution & Err() const
Return error distribution.
void Train(const arma::mat &observations, const arma::rowvec &weights)
Estimate parameters using provided observation weights.
const regression::LinearRegression & Rf() const
Return regression function.
mlpack_deprecated RegressionDistribution(const arma::mat &predictors, const arma::vec &responses)
Deprecated: create a conditional Gaussian distribution whose conditional mean function is obtained by running linear regression on the given predictors and (column-vector) responses.
void serialize(Archive &ar, const unsigned int)
Serialize the distribution.
A simple linear regression algorithm using ordinary least squares.
const arma::vec & Parameters() const
Return the parameters (the b vector).
double Train(const arma::mat &predictors, const arma::rowvec &responses, const bool intercept=true)
Train the LinearRegression model on the given data.
double ComputeError(const arma::mat &points, const arma::rowvec &responses) const
Calculate the L2 squared error on the given predictors and responses using this linear regression model.
#define mlpack_deprecated
Definition: deprecated.hpp:22
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.