#ifndef MLPACK_CORE_DATA_SPLIT_DATA_HPP
#define MLPACK_CORE_DATA_SPLIT_DATA_HPP

#include <stdexcept>
50template<
typename T,
typename U>
51void Split(
const arma::Mat<T>& input,
52 const arma::Row<U>& inputLabel,
53 arma::Mat<T>& trainData,
54 arma::Mat<T>& testData,
55 arma::Row<U>& trainLabel,
56 arma::Row<U>& testLabel,
57 const double testRatio,
58 const bool shuffleData =
true)
60 const size_t testSize =
static_cast<size_t>(input.n_cols * testRatio);
61 const size_t trainSize = input.n_cols - testSize;
62 trainData.set_size(input.n_rows, trainSize);
63 testData.set_size(input.n_rows, testSize);
64 trainLabel.set_size(trainSize);
65 testLabel.set_size(testSize);
69 arma::uvec order = arma::shuffle(arma::linspace<arma::uvec>(
70 0, input.n_cols - 1, input.n_cols));
73 trainData = input.cols(order.subvec(0, trainSize - 1));
74 trainLabel = inputLabel.cols(order.subvec(0, trainSize - 1));
76 if (trainSize < input.n_cols)
78 testData = input.cols(order.subvec(trainSize, input.n_cols - 1));
79 testLabel = inputLabel.cols(order.subvec(trainSize, input.n_cols - 1));
86 trainData = input.cols(0, trainSize - 1);
87 trainLabel = inputLabel.subvec(0, trainSize - 1);
89 if (trainSize < input.n_cols)
91 testData = input.cols(trainSize , input.n_cols - 1);
92 testLabel = inputLabel.subvec(trainSize , input.n_cols - 1);
121void Split(
const arma::Mat<T>& input,
122 arma::Mat<T>& trainData,
123 arma::Mat<T>& testData,
124 const double testRatio,
125 const bool shuffleData =
true)
127 const size_t testSize =
static_cast<size_t>(input.n_cols * testRatio);
128 const size_t trainSize = input.n_cols - testSize;
129 trainData.set_size(input.n_rows, trainSize);
130 testData.set_size(input.n_rows, testSize);
134 arma::uvec order = arma::shuffle(arma::linspace<arma::uvec>(
135 0, input.n_cols - 1, input.n_cols));
138 trainData = input.cols(order.subvec(0, trainSize - 1));
140 if (trainSize < input.n_cols)
141 testData = input.cols(order.subvec(trainSize, input.n_cols - 1));
146 trainData = input.cols(0, trainSize - 1);
148 if (trainSize < input.n_cols)
149 testData = input.cols(trainSize , input.n_cols - 1);
174template<
typename T,
typename U>
175std::tuple<arma::Mat<T>, arma::Mat<T>, arma::Row<U>, arma::Row<U>>
177 const arma::Row<U>& inputLabel,
178 const double testRatio,
179 const bool shuffleData =
true)
181 arma::Mat<T> trainData;
182 arma::Mat<T> testData;
183 arma::Row<U> trainLabel;
184 arma::Row<U> testLabel;
186 Split(input, inputLabel, trainData, testData, trainLabel, testLabel,
187 testRatio, shuffleData);
189 return std::make_tuple(std::move(trainData),
191 std::move(trainLabel),
192 std::move(testLabel));
214std::tuple<arma::Mat<T>, arma::Mat<T>>
216 const double testRatio,
217 const bool shuffleData =
true)
219 arma::Mat<T> trainData;
220 arma::Mat<T> testData;
221 Split(input, trainData, testData, testRatio, shuffleData);
223 return std::make_tuple(std::move(trainData),
224 std::move(testData));
void Split(const arma::Mat< T > &input, const arma::Row< U > &inputLabel, arma::Mat< T > &trainData, arma::Mat< T > &testData, arma::Row< U > &trainLabel, arma::Row< U > &testLabel, const double testRatio, const bool shuffleData=true)
Given an input dataset and labels, split into a training set and test set.
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.