mlpack 3.4.2
split_data.hpp
Go to the documentation of this file.
1
13#ifndef MLPACK_CORE_DATA_SPLIT_DATA_HPP
14#define MLPACK_CORE_DATA_SPLIT_DATA_HPP
15
#include <mlpack/prereqs.hpp>

#include <stdexcept>
17
18namespace mlpack {
19namespace data {
50template<typename T, typename U>
51void Split(const arma::Mat<T>& input,
52 const arma::Row<U>& inputLabel,
53 arma::Mat<T>& trainData,
54 arma::Mat<T>& testData,
55 arma::Row<U>& trainLabel,
56 arma::Row<U>& testLabel,
57 const double testRatio,
58 const bool shuffleData = true)
59{
60 const size_t testSize = static_cast<size_t>(input.n_cols * testRatio);
61 const size_t trainSize = input.n_cols - testSize;
62 trainData.set_size(input.n_rows, trainSize);
63 testData.set_size(input.n_rows, testSize);
64 trainLabel.set_size(trainSize);
65 testLabel.set_size(testSize);
66
67 if (shuffleData)
68 {
69 arma::uvec order = arma::shuffle(arma::linspace<arma::uvec>(
70 0, input.n_cols - 1, input.n_cols));
71 if (trainSize > 0)
72 {
73 trainData = input.cols(order.subvec(0, trainSize - 1));
74 trainLabel = inputLabel.cols(order.subvec(0, trainSize - 1));
75 }
76 if (trainSize < input.n_cols)
77 {
78 testData = input.cols(order.subvec(trainSize, input.n_cols - 1));
79 testLabel = inputLabel.cols(order.subvec(trainSize, input.n_cols - 1));
80 }
81 }
82 else
83 {
84 if (trainSize > 0)
85 {
86 trainData = input.cols(0, trainSize - 1);
87 trainLabel = inputLabel.subvec(0, trainSize - 1);
88 }
89 if (trainSize < input.n_cols)
90 {
91 testData = input.cols(trainSize , input.n_cols - 1);
92 testLabel = inputLabel.subvec(trainSize , input.n_cols - 1);
93 }
94 }
95}
96
120template<typename T>
121void Split(const arma::Mat<T>& input,
122 arma::Mat<T>& trainData,
123 arma::Mat<T>& testData,
124 const double testRatio,
125 const bool shuffleData = true)
126{
127 const size_t testSize = static_cast<size_t>(input.n_cols * testRatio);
128 const size_t trainSize = input.n_cols - testSize;
129 trainData.set_size(input.n_rows, trainSize);
130 testData.set_size(input.n_rows, testSize);
131
132 if (shuffleData)
133 {
134 arma::uvec order = arma::shuffle(arma::linspace<arma::uvec>(
135 0, input.n_cols - 1, input.n_cols));
136
137 if (trainSize > 0)
138 trainData = input.cols(order.subvec(0, trainSize - 1));
139
140 if (trainSize < input.n_cols)
141 testData = input.cols(order.subvec(trainSize, input.n_cols - 1));
142 }
143 else
144 {
145 if (trainSize > 0)
146 trainData = input.cols(0, trainSize - 1);
147
148 if (trainSize < input.n_cols)
149 testData = input.cols(trainSize , input.n_cols - 1);
150 }
151}
152
174template<typename T, typename U>
175std::tuple<arma::Mat<T>, arma::Mat<T>, arma::Row<U>, arma::Row<U>>
176Split(const arma::Mat<T>& input,
177 const arma::Row<U>& inputLabel,
178 const double testRatio,
179 const bool shuffleData = true)
180{
181 arma::Mat<T> trainData;
182 arma::Mat<T> testData;
183 arma::Row<U> trainLabel;
184 arma::Row<U> testLabel;
185
186 Split(input, inputLabel, trainData, testData, trainLabel, testLabel,
187 testRatio, shuffleData);
188
189 return std::make_tuple(std::move(trainData),
190 std::move(testData),
191 std::move(trainLabel),
192 std::move(testLabel));
193}
194
213template<typename T>
214std::tuple<arma::Mat<T>, arma::Mat<T>>
215Split(const arma::Mat<T>& input,
216 const double testRatio,
217 const bool shuffleData = true)
218{
219 arma::Mat<T> trainData;
220 arma::Mat<T> testData;
221 Split(input, trainData, testData, testRatio, shuffleData);
222
223 return std::make_tuple(std::move(trainData),
224 std::move(testData));
225}
226
227} // namespace data
228} // namespace mlpack
229
230#endif
void Split(const arma::Mat< T > &input, const arma::Row< U > &inputLabel, arma::Mat< T > &trainData, arma::Mat< T > &testData, arma::Row< U > &trainLabel, arma::Row< U > &testLabel, const double testRatio, const bool shuffleData=true)
Given an input dataset and labels, split into a training set and test set.
Definition: split_data.hpp:51
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.