mlpack 3.4.2
shuffle_data.hpp
Go to the documentation of this file.
1
12#ifndef MLPACK_CORE_MATH_SHUFFLE_DATA_HPP
13#define MLPACK_CORE_MATH_SHUFFLE_DATA_HPP
14
15#include <mlpack/prereqs.hpp>
16
17namespace mlpack {
18namespace math {
19
27template<typename MatType, typename LabelsType>
28void ShuffleData(const MatType& inputPoints,
29 const LabelsType& inputLabels,
30 MatType& outputPoints,
31 LabelsType& outputLabels,
32 const std::enable_if_t<!arma::is_SpMat<MatType>::value>* = 0,
33 const std::enable_if_t<!arma::is_Cube<MatType>::value>* = 0)
34{
35 // Generate ordering.
36 arma::uvec ordering = arma::shuffle(arma::linspace<arma::uvec>(0,
37 inputPoints.n_cols - 1, inputPoints.n_cols));
38
39 outputPoints = inputPoints.cols(ordering);
40 outputLabels = inputLabels.cols(ordering);
41}
42
50template<typename MatType, typename LabelsType>
51void ShuffleData(const MatType& inputPoints,
52 const LabelsType& inputLabels,
53 MatType& outputPoints,
54 LabelsType& outputLabels,
55 const std::enable_if_t<arma::is_SpMat<MatType>::value>* = 0,
56 const std::enable_if_t<!arma::is_Cube<MatType>::value>* = 0)
57{
58 // Generate ordering.
59 arma::uvec ordering = arma::shuffle(arma::linspace<arma::uvec>(0,
60 inputPoints.n_cols - 1, inputPoints.n_cols));
61
62 // Extract coordinate list representation.
63 arma::umat locations(2, inputPoints.n_nonzero);
64 arma::Col<typename MatType::elem_type> values(inputPoints.n_nonzero);
65 typename MatType::const_iterator it = inputPoints.begin();
66 size_t index = 0;
67 while (it != inputPoints.end())
68 {
69 locations(0, index) = it.row();
70 locations(1, index) = ordering[it.col()];
71 values(index) = (*it);
72 ++it;
73 ++index;
74 }
75
76 if (&inputPoints == &outputPoints || &inputLabels == &outputLabels)
77 {
78 MatType newOutputPoints(locations, values, inputPoints.n_rows,
79 inputPoints.n_cols, true);
80 LabelsType newOutputLabels(inputLabels.n_elem);
81 newOutputLabels.cols(ordering) = inputLabels;
82
83 outputPoints = std::move(newOutputPoints);
84 outputLabels = std::move(newOutputLabels);
85 }
86 else
87 {
88 outputPoints = MatType(locations, values, inputPoints.n_rows,
89 inputPoints.n_cols, true);
90 outputLabels.set_size(inputLabels.n_elem);
91 outputLabels.cols(ordering) = inputLabels;
92 }
93}
94
102template<typename MatType, typename LabelsType>
103void ShuffleData(const MatType& inputPoints,
104 const LabelsType& inputLabels,
105 MatType& outputPoints,
106 LabelsType& outputLabels,
107 const std::enable_if_t<!arma::is_SpMat<MatType>::value>* = 0,
108 const std::enable_if_t<arma::is_Cube<MatType>::value>* = 0,
109 const std::enable_if_t<arma::is_Cube<LabelsType>::value>* = 0)
110{
111 // Generate ordering.
112 arma::uvec ordering = arma::shuffle(arma::linspace<arma::uvec>(0,
113 inputPoints.n_cols - 1, inputPoints.n_cols));
114
115 // Properly handle the case where the input and output data are the same
116 // object.
117 MatType* outputPointsPtr = &outputPoints;
118 LabelsType* outputLabelsPtr = &outputLabels;
119 if (&inputPoints == &outputPoints)
120 outputPointsPtr = new MatType();
121 if (&inputLabels == &outputLabels)
122 outputLabelsPtr = new LabelsType();
123
124 outputPointsPtr->set_size(inputPoints.n_rows, inputPoints.n_cols,
125 inputPoints.n_slices);
126 outputLabelsPtr->set_size(inputLabels.n_rows, inputLabels.n_cols,
127 inputLabels.n_slices);
128 for (size_t i = 0; i < ordering.n_elem; ++i)
129 {
130 outputPointsPtr->tube(0, ordering[i], outputPointsPtr->n_rows - 1,
131 ordering[i]) = inputPoints.tube(0, i, inputPoints.n_rows - 1, i);
132 outputLabelsPtr->tube(0, ordering[i], outputLabelsPtr->n_rows - 1,
133 ordering[i]) = inputLabels.tube(0, i, inputLabels.n_rows - 1, i);
134 }
135
136 // Clean up memory if needed.
137 if (&inputPoints == &outputPoints)
138 {
139 outputPoints = std::move(*outputPointsPtr);
140 delete outputPointsPtr;
141 }
142
143 if (&inputLabels == &outputLabels)
144 {
145 outputLabels = std::move(*outputLabelsPtr);
146 delete outputLabelsPtr;
147 }
148}
149
159template<typename MatType, typename LabelsType, typename WeightsType>
160void ShuffleData(const MatType& inputPoints,
161 const LabelsType& inputLabels,
162 const WeightsType& inputWeights,
163 MatType& outputPoints,
164 LabelsType& outputLabels,
165 WeightsType& outputWeights,
166 const std::enable_if_t<!arma::is_SpMat<MatType>::value>* = 0,
167 const std::enable_if_t<!arma::is_Cube<MatType>::value>* = 0)
168{
169 // Generate ordering.
170 arma::uvec ordering = arma::shuffle(arma::linspace<arma::uvec>(0,
171 inputPoints.n_cols - 1, inputPoints.n_cols));
172
173 outputPoints = inputPoints.cols(ordering);
174 outputLabels = inputLabels.cols(ordering);
175 outputWeights = inputWeights.cols(ordering);
176}
177
187template<typename MatType, typename LabelsType, typename WeightsType>
188void ShuffleData(const MatType& inputPoints,
189 const LabelsType& inputLabels,
190 const WeightsType& inputWeights,
191 MatType& outputPoints,
192 LabelsType& outputLabels,
193 WeightsType& outputWeights,
194 const std::enable_if_t<arma::is_SpMat<MatType>::value>* = 0,
195 const std::enable_if_t<!arma::is_Cube<MatType>::value>* = 0)
196{
197 // Generate ordering.
198 arma::uvec ordering = arma::shuffle(arma::linspace<arma::uvec>(0,
199 inputPoints.n_cols - 1, inputPoints.n_cols));
200
201 // Extract coordinate list representation.
202 arma::umat locations(2, inputPoints.n_nonzero);
203 arma::Col<typename MatType::elem_type> values(inputPoints.n_nonzero);
204 typename MatType::const_iterator it = inputPoints.begin();
205 size_t index = 0;
206 while (it != inputPoints.end())
207 {
208 locations(0, index) = it.row();
209 locations(1, index) = ordering[it.col()];
210 values(index) = (*it);
211 ++it;
212 ++index;
213 }
214
215 if (&inputPoints == &outputPoints || &inputLabels == &outputLabels ||
216 &inputWeights == &outputWeights)
217 {
218 MatType newOutputPoints(locations, values, inputPoints.n_rows,
219 inputPoints.n_cols, true);
220 LabelsType newOutputLabels(inputLabels.n_elem);
221 WeightsType newOutputWeights(inputWeights.n_elem);
222 newOutputLabels.cols(ordering) = inputLabels;
223 newOutputWeights.cols(ordering) = inputWeights;
224
225 outputPoints = std::move(newOutputPoints);
226 outputLabels = std::move(newOutputLabels);
227 outputWeights = std::move(newOutputWeights);
228 }
229 else
230 {
231 outputPoints = MatType(locations, values, inputPoints.n_rows,
232 inputPoints.n_cols, true);
233 outputLabels.set_size(inputLabels.n_elem);
234 outputLabels.cols(ordering) = inputLabels;
235 outputWeights.set_size(inputWeights.n_elem);
236 outputWeights.cols(ordering) = inputWeights;
237 }
238}
239
240} // namespace math
241} // namespace mlpack
242
243#endif
void ShuffleData(const MatType &inputPoints, const LabelsType &inputLabels, MatType &outputPoints, LabelsType &outputLabels, const std::enable_if_t<!arma::is_SpMat< MatType >::value > *=0, const std::enable_if_t<!arma::is_Cube< MatType >::value > *=0)
Shuffle a dataset and associated labels (or responses).
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
typename enable_if< B, T >::type enable_if_t
Definition: prereqs.hpp:70
The core includes that mlpack expects; standard C++ includes and Armadillo.