mlpack 3.4.2
cv_base.hpp
Go to the documentation of this file.
1
12#ifndef MLPACK_CORE_CV_CV_BASE_HPP
13#define MLPACK_CORE_CV_CV_BASE_HPP
14
16
17namespace mlpack {
18namespace cv {
19
/**
 * An auxiliary class for cross-validation.
 *
 * NOTE(review): this text is a line-numbered listing from which several
 * original lines are absent (see the individual notes below); confirm any
 * detail marked "presumably" against the original header.
 *
 * @tparam MLAlgorithm Type returned by Train() and TrainModel().
 * @tparam MatType Type of the data points argument (xs).
 * @tparam PredictionsType Type of the predictions argument (ys).
 * @tparam WeightsType Type of the weights argument (weights).
 */
35template<typename MLAlgorithm,
36 typename MatType,
37 typename PredictionsType,
38 typename WeightsType>
39class CVBase
40{
41 public:
  /**
   * Short alias through which the compile-time flags MIE::IsSupported and
   * MIE::TakesNumClasses are read below.
   *
   * NOTE(review): the right-hand side of this alias (listing line 44) is
   * missing from this listing; presumably a MetaInfoExtractor instantiation,
   * given the static_assert message further down -- confirm.
   */
43 using MIE =
45
  // NOTE(review): listing lines 46-50 are absent here. The member summary for
  // this class also lists a default constructor CVBase() ("Assert that
  // MLAlgorithm doesn't take any additional basic parameters like
  // numClasses"); presumably it is declared at this position -- confirm.
51
  /**
   * Assert that MLAlgorithm takes the numClasses parameter and store it.
   *
   * @param numClasses Value to store.
   */
57 CVBase(const size_t numClasses);
58
  /**
   * Assert that MLAlgorithm takes the numClasses parameter and a
   * data::DatasetInfo parameter, and store both.
   *
   * @param datasetInfo Dataset information to store.
   * @param numClasses Value to store.
   */
66 CVBase(const data::DatasetInfo& datasetInfo,
67 const size_t numClasses);
68
  /**
   * Assert that the number of data points equals the number of predictions.
   *
   * @param xs Data points.
   * @param ys Predictions.
   */
72 static void AssertDataConsistency(const MatType& xs,
73 const PredictionsType& ys);
74
  /**
   * Assert that weighted learning is supported and that the number of data
   * points equals the number of weights.
   *
   * @param xs Data points.
   * @param weights Weights.
   */
79 static void AssertWeightsConsistency(const MatType& xs,
80 const WeightsType& weights);
81
  /**
   * Train MLAlgorithm with the given data points, predictions, and
   * hyperparameters.
   *
   * @param xs Data points.
   * @param ys Predictions.
   * @param args Hyperparameters (any number, any types).
   * @return The trained MLAlgorithm, by value.
   */
86 template<typename... MLAlgorithmArgs>
87 MLAlgorithm Train(const MatType& xs,
88 const PredictionsType& ys,
89 const MLAlgorithmArgs&... args);
90
  /**
   * Train MLAlgorithm with the given data points, predictions, weights, and
   * hyperparameters.
   *
   * @param xs Data points.
   * @param ys Predictions.
   * @param weights Weights.
   * @param args Hyperparameters (any number, any types).
   * @return The trained MLAlgorithm, by value.
   */
95 template<typename... MLAlgorithmArgs>
96 MLAlgorithm Train(const MatType& xs,
97 const PredictionsType& ys,
98 const WeightsType& weights,
99 const MLAlgorithmArgs&... args);
100
101 private:
  // Reject at compile time any instantiation for which MIE::IsSupported is
  // false.
102 static_assert(MIE::IsSupported,
103 "The given MLAlgorithm is not supported by MetaInfoExtractor");
104
  //! Stored data::DatasetInfo; presumably the one given to the two-argument
  //! constructor -- confirm in cv_base_impl.hpp.
106 const data::DatasetInfo datasetInfo;
  //! Presumably records whether a data::DatasetInfo was given at
  //! construction -- confirm in cv_base_impl.hpp.
108 const bool isDatasetInfoPassed;
  //! Stored numClasses value (see the constructors above).
110 size_t numClasses;
111
  /**
   * Private helper taking data points and predictions; presumably performs the
   * size comparison behind AssertDataConsistency() -- confirm in
   * cv_base_impl.hpp.
   */
115 static void AssertSizeEquality(const MatType& xs,
116 const PredictionsType& ys);
117
  /**
   * Private helper taking data points and weights; presumably performs the
   * size comparison behind AssertWeightsConsistency() -- confirm in
   * cv_base_impl.hpp.
   */
121 static void AssertWeightsSize(const MatType& xs,
122 const WeightsType& weights);
123
  /**
   * Unweighted TrainModel overload that takes part in overload resolution
   * only when MIE::TakesNumClasses is false (std::enable_if on Enabled).
   */
128 template<typename... MLAlgorithmArgs,
129 bool Enabled = !MIE::TakesNumClasses,
130 typename = typename std::enable_if<Enabled>::type>
131 MLAlgorithm TrainModel(const MatType& xs,
132 const PredictionsType& ys,
133 const MLAlgorithmArgs&... args);
134
  /**
   * Unweighted TrainModel overload gated by std::enable_if on Enabled.
   *
   * NOTE(review): the line declaring Enabled for this overload (listing line
   * 140) is missing from this listing, so its condition cannot be read here.
   * The trailing `typename = void` gives this template a parameter list of a
   * different length from the overload above, so the two are distinct
   * declarations rather than redeclarations.
   */
139 template<typename... MLAlgorithmArgs,
141 typename = typename std::enable_if<Enabled>::type,
142 typename = void>
143 MLAlgorithm TrainModel(const MatType& xs,
144 const PredictionsType& ys,
145 const MLAlgorithmArgs&... args);
146
  /**
   * Unweighted TrainModel overload gated by std::enable_if on Enabled.
   *
   * NOTE(review): the line declaring Enabled for this overload (listing line
   * 152) is missing from this listing. The two trailing `typename = void`
   * parameters distinguish this template from the two overloads above.
   */
151 template<typename... MLAlgorithmArgs,
153 typename = typename std::enable_if<Enabled>::type,
154 typename = void,
155 typename = void>
156 MLAlgorithm TrainModel(const MatType& xs,
157 const PredictionsType& ys,
158 const MLAlgorithmArgs&... args);
159
  /**
   * Weighted TrainModel overload that takes part in overload resolution only
   * when MIE::TakesNumClasses is false (std::enable_if on Enabled).
   */
164 template<typename... MLAlgorithmArgs,
165 bool Enabled = !MIE::TakesNumClasses,
166 typename = typename std::enable_if<Enabled>::type>
167 MLAlgorithm TrainModel(const MatType& xs,
168 const PredictionsType& ys,
169 const WeightsType& weights,
170 const MLAlgorithmArgs&... args);
171
  /**
   * Weighted TrainModel overload gated by std::enable_if on Enabled.
   *
   * NOTE(review): the line declaring Enabled for this overload (listing line
   * 177) is missing from this listing. The trailing `typename = void`
   * distinguishes this template from the weighted overload above.
   */
176 template<typename... MLAlgorithmArgs,
178 typename = typename std::enable_if<Enabled>::type,
179 typename = void>
180 MLAlgorithm TrainModel(const MatType& xs,
181 const PredictionsType& ys,
182 const WeightsType& weights,
183 const MLAlgorithmArgs&... args);
184
  /**
   * Weighted TrainModel overload gated by std::enable_if on Enabled.
   *
   * NOTE(review): the line declaring Enabled for this overload (listing line
   * 190) is missing from this listing. The two trailing `typename = void`
   * parameters distinguish this template from the two weighted overloads
   * above.
   */
189 template<typename... MLAlgorithmArgs,
191 typename = typename std::enable_if<Enabled>::type,
192 typename = void,
193 typename = void>
194 MLAlgorithm TrainModel(const MatType& xs,
195 const PredictionsType& ys,
196 const WeightsType& weights,
197 const MLAlgorithmArgs&... args);
198
  /**
   * TrainModel overload enabled when ConstructableWithoutDatasetInfo is true.
   * ConstructableWithoutDatasetInfo has no default and does not appear in the
   * function parameters, so the caller must supply it explicitly.
   */
208 template<bool ConstructableWithoutDatasetInfo,
209 typename... MLAlgorithmArgs,
210 typename =
211 typename std::enable_if<ConstructableWithoutDatasetInfo>::type>
212 MLAlgorithm TrainModel(const MatType& xs,
213 const PredictionsType& ys,
214 const MLAlgorithmArgs&... args);
215
  /**
   * TrainModel overload enabled when ConstructableWithoutDatasetInfo is false.
   * The trailing `typename = void` distinguishes this template from the
   * overload above.
   */
220 template<bool ConstructableWithoutDatasetInfo,
221 typename... MLAlgorithmArgs,
222 typename =
223 typename std::enable_if<!ConstructableWithoutDatasetInfo>::type,
224 typename = void>
225 MLAlgorithm TrainModel(const MatType& xs,
226 const PredictionsType& ys,
227 const MLAlgorithmArgs&... args);
228};
229
230} // namespace cv
231} // namespace mlpack
232
233// Include implementation
234#include "cv_base_impl.hpp"
235
236#endif
An auxiliary class for cross-validation.
Definition: cv_base.hpp:40
CVBase(const size_t numClasses)
Assert that MLAlgorithm takes the numClasses parameter and store it.
CVBase()
Assert that MLAlgorithm doesn't take any additional basic parameters like numClasses.
MLAlgorithm Train(const MatType &xs, const PredictionsType &ys, const MLAlgorithmArgs &... args)
Train MLAlgorithm with given data points, predictions, and hyperparameters depending on what CVBase constructor has been called.
static void AssertWeightsConsistency(const MatType &xs, const WeightsType &weights)
Assert that weighted learning is supported and that there is an equal number of data points and weights.
CVBase(const data::DatasetInfo &datasetInfo, const size_t numClasses)
Assert that MLAlgorithm takes the numClasses parameter and a data::DatasetInfo parameter and store them.
static void AssertDataConsistency(const MatType &xs, const PredictionsType &ys)
Assert that there is an equal number of data points and predictions.
MLAlgorithm Train(const MatType &xs, const PredictionsType &ys, const WeightsType &weights, const MLAlgorithmArgs &... args)
Train MLAlgorithm with given data points, predictions, weights, and hyperparameters depending on what CVBase constructor has been called.
MetaInfoExtractor is a tool for extracting meta information about a given machine learning algorithm.
static const bool IsSupported
An indication whether PredictionsType has been identified (i.e. MLAlgorithm is supported by MetaInfoExtractor).
static const bool TakesNumClasses
An indication whether MLAlgorithm takes the numClasses (size_t) parameter.
static const bool TakesDatasetInfo
An indication whether MLAlgorithm takes a data::DatasetInfo parameter.
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the datatype of each dimension.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1