mlpack 3.4.2
test_function_tools.hpp
Go to the documentation of this file.
1
12#ifndef MLPACK_TESTS_TEST_FUNCTION_TOOLS_HPP
13#define MLPACK_TESTS_TEST_FUNCTION_TOOLS_HPP
14
15#include <mlpack/core.hpp>
16
18
19using namespace mlpack;
20using namespace mlpack::distribution;
21using namespace mlpack::regression;
22
33inline void LogisticRegressionTestData(arma::mat& data,
34 arma::mat& testData,
35 arma::mat& shuffledData,
36 arma::Row<size_t>& responses,
37 arma::Row<size_t>& testResponses,
38 arma::Row<size_t>& shuffledResponses)
39{
40 // Generate a two-Gaussian dataset.
41 GaussianDistribution g1(arma::vec("1.0 1.0 1.0"), arma::eye<arma::mat>(3, 3));
42 GaussianDistribution g2(arma::vec("9.0 9.0 9.0"), arma::eye<arma::mat>(3, 3));
43
44 data = arma::mat(3, 1000);
45 responses = arma::Row<size_t>(1000);
46 for (size_t i = 0; i < 500; ++i)
47 {
48 data.col(i) = g1.Random();
49 responses[i] = 0;
50 }
51 for (size_t i = 500; i < 1000; ++i)
52 {
53 data.col(i) = g2.Random();
54 responses[i] = 1;
55 }
56
57 // Shuffle the dataset.
58 arma::uvec indices = arma::shuffle(arma::linspace<arma::uvec>(0,
59 data.n_cols - 1, data.n_cols));
60 shuffledData = arma::mat(3, 1000);
61 shuffledResponses = arma::Row<size_t>(1000);
62 for (size_t i = 0; i < data.n_cols; ++i)
63 {
64 shuffledData.col(i) = data.col(indices[i]);
65 shuffledResponses[i] = responses[indices[i]];
66 }
67
68 // Create a test set.
69 testData = arma::mat(3, 1000);
70 testResponses = arma::Row<size_t>(1000);
71 for (size_t i = 0; i < 500; ++i)
72 {
73 testData.col(i) = g1.Random();
74 testResponses[i] = 0;
75 }
76 for (size_t i = 500; i < 1000; ++i)
77 {
78 testData.col(i) = g2.Random();
79 testResponses[i] = 1;
80 }
81}
82
83#endif
A single multivariate Gaussian distribution.
arma::vec Random() const
Return a randomly generated observation according to the probability distribution defined by this obj...
Include all of the base components required to write mlpack methods, and the main mlpack Doxygen docu...
Probability distributions.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void LogisticRegressionTestData(arma::mat &data, arma::mat &testData, arma::mat &shuffledData, arma::Row< size_t > &responses, arma::Row< size_t > &testResponses, arma::Row< size_t > &shuffledResponses)
Create the data for a logistic regression test.