mlpack 3.4.2
k_fold_cv.hpp
Go to the documentation of this file.
1
12#ifndef MLPACK_CORE_CV_K_FOLD_CV_HPP
13#define MLPACK_CORE_CV_K_FOLD_CV_HPP
14
17
18namespace mlpack {
19namespace cv {
20
// KFoldCV implements k-fold cross-validation for regression and
// classification algorithms.
//
// NOTE(review): this listing carries its own line numbers and skips some of
// them (60 -> 62, 64 -> 66, 183 -> 186). The default argument for
// PredictionsType, the class-head line and the declaration of Base therefore
// appear to be missing from this view -- confirm against the real header.
//
// Template parameters:
//   MLAlgorithm     - the learning algorithm; Model() returns a reference to
//                     an object of this type.
//   Metric          - presumably the metric Evaluate() uses to score each
//                     fold -- TODO confirm in k_fold_cv_impl.hpp.
//   MatType         - type of the data matrix xs; defaults to arma::mat.
//   PredictionsType - type of the responses/labels ys (default not visible).
//   WeightsType     - type of the instance weights; defaults to the
//                     WeightsType reported by MetaInfoExtractor.
57template<typename MLAlgorithm,
58 typename Metric,
59 typename MatType = arma::mat,
60 typename PredictionsType =
62 typename WeightsType =
63 typename MetaInfoExtractor<MLAlgorithm, MatType,
64 PredictionsType>::WeightsType>
66{
67 public:
// Constructor for regression algorithms and for binary classification
// algorithms. Takes the fold count k, the data xs, the responses ys and a
// shuffle flag that defaults to true.
78 KFoldCV(const size_t k,
79 const MatType& xs,
80 const PredictionsType& ys,
81 const bool shuffle = true);
82
// Constructor for multiclass classification algorithms; additionally takes
// the number of classes.
92 KFoldCV(const size_t k,
93 const MatType& xs,
94 const PredictionsType& ys,
95 const size_t numClasses,
96 const bool shuffle = true);
97
// Constructor for multiclass classification algorithms that can take a
// data::DatasetInfo argument.
109 KFoldCV(const size_t k,
110 const MatType& xs,
111 const data::DatasetInfo& datasetInfo,
112 const PredictionsType& ys,
113 const size_t numClasses,
114 const bool shuffle = true);
115
// Constructor for regression and binary classification algorithms that
// support weighted learning; additionally takes the weights.
127 KFoldCV(const size_t k,
128 const MatType& xs,
129 const PredictionsType& ys,
130 const WeightsType& weights,
131 const bool shuffle = true);
132
// Constructor for multiclass classification algorithms that support weighted
// learning.
144 KFoldCV(const size_t k,
145 const MatType& xs,
146 const PredictionsType& ys,
147 const size_t numClasses,
148 const WeightsType& weights,
149 const bool shuffle = true);
150
// Constructor for multiclass classification algorithms that can take a
// data::DatasetInfo argument together with weights.
163 KFoldCV(const size_t k,
164 const MatType& xs,
165 const data::DatasetInfo& datasetInfo,
166 const PredictionsType& ys,
167 const size_t numClasses,
168 const WeightsType& weights,
169 const bool shuffle = true);
170
// Run k-fold cross-validation. The variadic args are presumably forwarded to
// the MLAlgorithm being trained -- TODO confirm in k_fold_cv_impl.hpp.
// Returns a double (the meaning of the value is defined by the
// implementation, which is not visible here).
177 template<typename... MLAlgorithmArgs>
178 double Evaluate(const MLAlgorithmArgs& ...args);
179
// Access and modify a model from the last run of k-fold cross-validation.
181 MLAlgorithm& Model();
182
183 private:
// NOTE(review): listing lines 184-185 are not shown; the Base type used below
// is presumably declared there (the auxiliary cross-validation class from
// cv_base.hpp) -- confirm.
186
187 public:
// Shuffle the data. This overload is enabled only when
// Base::MIE::SupportsWeights is false, i.e. MLAlgorithm does not support
// weighted learning.
192 template<bool Enabled = !Base::MIE::SupportsWeights,
193 typename = typename std::enable_if<Enabled>::type>
194 void Shuffle();
195
// Shuffle the data. This overload is enabled only when
// Base::MIE::SupportsWeights is true. The extra defaulted `typename = void`
// parameter gives it a template parameter list distinct from the overload
// above, so the two declarations do not collide.
200 template<bool Enabled = Base::MIE::SupportsWeights,
201 typename = typename std::enable_if<Enabled>::type,
202 typename = void>
203 void Shuffle();
204
205 private:
// Auxiliary cross-validation object of type Base.
207 Base base;
208
// Value of the constructor argument k (presumably the number of folds).
210 const size_t k;
211
// Data members of the same types as the constructor arguments xs, ys and
// weights. NOTE(review): whether these hold plain copies or rearranged data
// is decided in k_fold_cv_impl.hpp -- confirm there.
213 MatType xs;
215 PredictionsType ys;
217 WeightsType weights;
218
// Presumably the size of the final bin, which may differ from binSize --
// TODO confirm.
220 size_t lastBinSize;
221
// Presumably the number of data points per fold -- TODO confirm.
223 size_t binSize;
224
// Owning pointer to an MLAlgorithm instance; presumably the model that
// Model() exposes -- TODO confirm.
226 std::unique_ptr<MLAlgorithm> modelPtr;
227
// Private constructor taking an already-built Base by rvalue reference, for
// the case without weights. Presumably the public constructors delegate
// here -- confirm in k_fold_cv_impl.hpp.
232 KFoldCV(Base&& base,
233 const size_t k,
234 const MatType& xs,
235 const PredictionsType& ys,
236 const bool shuffle);
237
// Private constructor taking an already-built Base by rvalue reference, for
// the case with weights.
242 KFoldCV(Base&& base,
243 const size_t k,
244 const MatType& xs,
245 const PredictionsType& ys,
246 const WeightsType& weights,
247 const bool shuffle);
248
// Initializes `destination` from `source`; both have the same DataType.
// NOTE(review): the exact layout written to destination is not visible in
// this header -- see k_fold_cv_impl.hpp.
253 template<typename DataType>
254 void InitKFoldCVMat(const DataType& source, DataType& destination);
255
// Train and evaluate; overload enabled only when Base::MIE::SupportsWeights
// is false.
259 template<typename... MLAlgorithmArgs,
260 bool Enabled = !Base::MIE::SupportsWeights,
261 typename = typename std::enable_if<Enabled>::type>
262 double TrainAndEvaluate(const MLAlgorithmArgs& ...mlAlgorithmArgs);
263
// Train and evaluate; overload enabled only when Base::MIE::SupportsWeights
// is true. The trailing `typename = void` keeps its template parameter list
// distinct from the overload above.
267 template<typename... MLAlgorithmArgs,
268 bool Enabled = Base::MIE::SupportsWeights,
269 typename = typename std::enable_if<Enabled>::type,
270 typename = void>
271 double TrainAndEvaluate(const MLAlgorithmArgs& ...mlAlgorithmArgs);
272
// Presumably returns the index of the first column of the validation subset
// for fold i -- TODO confirm.
279 inline size_t ValidationSubsetFirstCol(const size_t i);
280
// Returns an arma::Mat built from matrix m for index i; presumably the
// training portion for fold i -- TODO confirm.
284 template<typename ElementType>
285 inline arma::Mat<ElementType> GetTrainingSubset(arma::Mat<ElementType>& m,
286 const size_t i);
287
// Row-vector counterpart of the overload above: takes and returns an
// arma::Row.
291 template<typename ElementType>
292 inline arma::Row<ElementType> GetTrainingSubset(arma::Row<ElementType>& r,
293 const size_t i);
294
// Returns an arma::Mat built from matrix m for index i; presumably the
// validation portion for fold i -- TODO confirm.
298 template<typename ElementType>
299 inline arma::Mat<ElementType> GetValidationSubset(arma::Mat<ElementType>& m,
300 const size_t i);
301
// Row-vector counterpart of the overload above: takes and returns an
// arma::Row.
305 template<typename ElementType>
306 inline arma::Row<ElementType> GetValidationSubset(arma::Row<ElementType>& r,
307 const size_t i);
308};
309
310} // namespace cv
311} // namespace mlpack
312
313// Include implementation
314#include "k_fold_cv_impl.hpp"
315
316#endif
An auxiliary class for cross-validation.
Definition: cv_base.hpp:40
The class KFoldCV implements k-fold cross-validation for regression and classification algorithms.
Definition: k_fold_cv.hpp:66
KFoldCV(const size_t k, const MatType &xs, const PredictionsType &ys, const size_t numClasses, const WeightsType &weights, const bool shuffle=true)
This constructor can be used for multiclass classification algorithms that support weighted learning.
KFoldCV(const size_t k, const MatType &xs, const PredictionsType &ys, const WeightsType &weights, const bool shuffle=true)
This constructor can be used for regression and binary classification algorithms that support weighted learning.
KFoldCV(const size_t k, const MatType &xs, const data::DatasetInfo &datasetInfo, const PredictionsType &ys, const size_t numClasses, const bool shuffle=true)
This constructor can be used for multiclass classification algorithms that can take a data::DatasetInfo parameter.
KFoldCV(const size_t k, const MatType &xs, const PredictionsType &ys, const bool shuffle=true)
This constructor can be used for regression algorithms and for binary classification algorithms.
double Evaluate(const MLAlgorithmArgs &...args)
Run k-fold cross-validation.
void Shuffle()
Shuffle the data.
MLAlgorithm & Model()
Access and modify a model from the last run of k-fold cross-validation.
void Shuffle()
Shuffle the data.
KFoldCV(const size_t k, const MatType &xs, const PredictionsType &ys, const size_t numClasses, const bool shuffle=true)
This constructor can be used for multiclass classification algorithms.
KFoldCV(const size_t k, const MatType &xs, const data::DatasetInfo &datasetInfo, const PredictionsType &ys, const size_t numClasses, const WeightsType &weights, const bool shuffle=true)
This constructor can be used for multiclass classification algorithms that can take a data::DatasetInfo parameter and support weighted learning.
typename Select< TF1, TF2, TF3, TF4, TF5 >::Type::PredictionsType PredictionsType
The type of predictions used in MLAlgorithm.
static const bool SupportsWeights
An indication whether MLAlgorithm supports weighted learning.
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the datatype of each dimension.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1