mlpack 3.4.2
simple_cv.hpp
Go to the documentation of this file.
1
12#ifndef MLPACK_CORE_CV_SIMPLE_CV_HPP
13#define MLPACK_CORE_CV_SIMPLE_CV_HPP
14
17
18namespace mlpack {
19namespace cv {
20
// SimpleCV splits data into two sets - a training set and a validation set.
// Evaluate() trains on the training set and assesses performance on the
// validation set by using the class Metric.
//
// Template parameters:
//   MLAlgorithm     - the learning algorithm; Model() returns a reference to it.
//   Metric          - the class used by Evaluate() to assess performance.
//   MatType         - type of the stored data (xs, trainingXs, validationXs);
//                     defaults to arma::mat.
//   PredictionsType - type of the stored predictions (ys, trainingYs,
//                     validationYs).
//   WeightsType     - type of the stored weights; defaults to
//                     MetaInfoExtractor<MLAlgorithm, MatType,
//                     PredictionsType>::WeightsType.
//
// NOTE(review): the listing numbers below jump 63 -> 65 and 67 -> 69, so the
// default argument of PredictionsType and the class-head line are not visible
// in this listing - confirm against the original header.
60template<typename MLAlgorithm,
61 typename Metric,
62 typename MatType = arma::mat,
63 typename PredictionsType =
65 typename WeightsType =
66 typename MetaInfoExtractor<MLAlgorithm, MatType,
67 PredictionsType>::WeightsType>
69{
70 public:
// Constructor for regression algorithms and for binary classification
// algorithms.
//   validationSize - presumably the share of the data held out for validation;
//                    TODO confirm the accepted range in simple_cv_impl.hpp.
//   xs             - data points (forwarding reference).
//   ys             - predictions for the data points (forwarding reference).
84 template<typename MatInType, typename PredictionsInType>
85 SimpleCV(const double validationSize,
86 MatInType&& xs,
87 PredictionsInType&& ys);
88
// Constructor for multiclass classification algorithms.
//   numClasses - number of classes, in addition to the parameters of the
//                overload above.
101 template<typename MatInType, typename PredictionsInType>
102 SimpleCV(const double validationSize,
103 MatInType&& xs,
104 PredictionsInType&& ys,
105 const size_t numClasses);
106
// Constructor for multiclass classification algorithms that can take a
// data::DatasetInfo parameter.
//   datasetInfo - dataset description passed by const reference.
121 template<typename MatInType, typename PredictionsInType>
122 SimpleCV(const double validationSize,
123 MatInType&& xs,
124 const data::DatasetInfo& datasetInfo,
125 PredictionsInType&& ys,
126 const size_t numClasses);
127
// Constructor for regression and binary classification algorithms that
// support weighted learning.
//   weights - observation weights (forwarding reference).
143 template<typename MatInType,
144 typename PredictionsInType,
145 typename WeightsInType>
146 SimpleCV(const double validationSize,
147 MatInType&& xs,
148 PredictionsInType&& ys,
149 WeightsInType&& weights);
150
// Constructor for multiclass classification algorithms that support weighted
// learning.
166 template<typename MatInType,
167 typename PredictionsInType,
168 typename WeightsInType>
169 SimpleCV(const double validationSize,
170 MatInType&& xs,
171 PredictionsInType&& ys,
172 const size_t numClasses,
173 WeightsInType&& weights);
174
// Constructor for multiclass classification algorithms that can take a
// data::DatasetInfo parameter; this overload additionally accepts weights.
191 template<typename MatInType,
192 typename PredictionsInType,
193 typename WeightsInType>
194 SimpleCV(const double validationSize,
195 MatInType&& xs,
196 const data::DatasetInfo& datasetInfo,
197 PredictionsInType&& ys,
198 const size_t numClasses,
199 WeightsInType&& weights);
200
// Train on the training set and assess performance on the validation set by
// using the class Metric.
//   args - presumably extra arguments handed to MLAlgorithm when training;
//          TODO confirm in simple_cv_impl.hpp.
// Returns a double; presumably the value computed by Metric - confirm.
208 template<typename... MLAlgorithmArgs>
209 double Evaluate(const MLAlgorithmArgs&... args);
210
// Access and modify the last trained model.
212 MLAlgorithm& Model();
213
214 private:
217
// NOTE(review): the type Base is not declared in the visible lines (listing
// line 218 is missing); presumably an alias for the auxiliary CVBase class -
// confirm against the original header.
219 Base base;
220
// The complete data passed to the constructor: points, predictions, weights.
// NOTE(review): whether these own copies or alias the caller's data is not
// visible here.
222 MatType xs;
224 PredictionsType ys;
226 WeightsType weights;
227
// Training portion of the data (points, predictions, weights).
229 MatType trainingXs;
231 PredictionsType trainingYs;
233 WeightsType trainingWeights;
234
// Validation portion of the data (points, predictions). No validation weights
// member is declared.
236 MatType validationXs;
238 PredictionsType validationYs;
239
// Owning pointer to an MLAlgorithm instance; presumably the model exposed by
// Model() - confirm in simple_cv_impl.hpp.
241 std::unique_ptr<MLAlgorithm> modelPtr;
242
// Private constructor taking an already-built Base (no weights).
// NOTE(review): presumably the delegation target of the public non-weighted
// constructors - confirm in simple_cv_impl.hpp.
247 template<typename MatInType,
248 typename PredictionsInType>
249 SimpleCV(Base&& base,
250 const double validationSize,
251 MatInType&& xs,
252 PredictionsInType&& ys);
253
// Private constructor taking an already-built Base, with weights.
// NOTE(review): presumably the delegation target of the public weighted
// constructors - confirm in simple_cv_impl.hpp.
258 template<typename MatInType,
259 typename PredictionsInType,
260 typename WeightsInType>
261 SimpleCV(Base&& base,
262 const double validationSize,
263 MatInType&& xs,
264 PredictionsInType&& ys,
265 WeightsInType&& weights);
266
// Returns a size_t derived from validationSize; judging by the name, the
// number of training points, with validation of the argument. The exact
// checks and failure mode are not visible here - TODO confirm.
270 size_t CalculateAndAssertNumberOfTrainingPoints(const double validationSize);
271
// Matrix overload: returns an arma::Mat built from m for the column range
// given by firstCol and lastCol.
// NOTE(review): whether lastCol is inclusive, and whether the result copies
// or aliases m, is not visible here.
275 template<typename ElementType>
276 arma::Mat<ElementType> GetSubset(arma::Mat<ElementType>& m,
277 const size_t firstCol,
278 const size_t lastCol);
279
// Row-vector overload of GetSubset with the same parameters.
283 template<typename ElementType>
284 arma::Row<ElementType> GetSubset(arma::Row<ElementType>& r,
285 const size_t firstCol,
286 const size_t lastCol);
287
// TrainAndEvaluate overload that participates in overload resolution only
// when Base::MIE::SupportsWeights is false (selected via std::enable_if).
291 template<typename... MLAlgorithmArgs,
292 bool Enabled = !Base::MIE::SupportsWeights,
293 typename = typename std::enable_if<Enabled>::type>
294 double TrainAndEvaluate(const MLAlgorithmArgs&... args);
295
// TrainAndEvaluate overload that participates in overload resolution only
// when Base::MIE::SupportsWeights is true. The trailing `typename = void`
// gives this template a different parameter list from the overload above, so
// the two are distinct declarations rather than a redeclaration.
299 template<typename... MLAlgorithmArgs,
300 bool Enabled = Base::MIE::SupportsWeights,
301 typename = typename std::enable_if<Enabled>::type,
302 typename = void>
303 double TrainAndEvaluate(const MLAlgorithmArgs&... args);
304};
305
306} // namespace cv
307} // namespace mlpack
308
309// Include implementation
310#include "simple_cv_impl.hpp"
311
312#endif
An auxiliary class for cross-validation.
Definition: cv_base.hpp:40
typename Select< TF1, TF2, TF3, TF4, TF5 >::Type::PredictionsType PredictionsType
The type of predictions used in MLAlgorithm.
static const bool SupportsWeights
An indication whether MLAlgorithm supports weighted learning.
SimpleCV splits data into two sets - training and validation sets - and then runs training on the training set and evaluates performance on the validation set.
Definition: simple_cv.hpp:69
double Evaluate(const MLAlgorithmArgs &... args)
Train on the training set and assess performance on the validation set by using the class Metric.
SimpleCV(const double validationSize, MatInType &&xs, const data::DatasetInfo &datasetInfo, PredictionsInType &&ys, const size_t numClasses)
This constructor can be used for multiclass classification algorithms that can take a data::DatasetInfo parameter.
SimpleCV(const double validationSize, MatInType &&xs, PredictionsInType &&ys)
This constructor can be used for regression algorithms and for binary classification algorithms.
SimpleCV(const double validationSize, MatInType &&xs, PredictionsInType &&ys, const size_t numClasses)
This constructor can be used for multiclass classification algorithms.
SimpleCV(const double validationSize, MatInType &&xs, const data::DatasetInfo &datasetInfo, PredictionsInType &&ys, const size_t numClasses, WeightsInType &&weights)
This constructor can be used for multiclass classification algorithms that can take a data::DatasetInfo parameter and support weighted learning.
MLAlgorithm & Model()
Access and modify the last trained model.
SimpleCV(const double validationSize, MatInType &&xs, PredictionsInType &&ys, const size_t numClasses, WeightsInType &&weights)
This constructor can be used for multiclass classification algorithms that support weighted learning.
SimpleCV(const double validationSize, MatInType &&xs, PredictionsInType &&ys, WeightsInType &&weights)
This constructor can be used for regression and binary classification algorithms that support weighted learning.
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the datatype of each dimension.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1