mlpack 3.4.2
hpt.hpp
Go to the documentation of this file.
1
12#ifndef MLPACK_CORE_HPT_HPT_HPP
13#define MLPACK_CORE_HPT_HPT_HPP
14
17#include <ensmallen.hpp>
18
19namespace mlpack {
20namespace hpt {
21
85template<typename MLAlgorithm,
86 typename Metric,
87 template<typename, typename, typename, typename, typename> class CV,
88 typename OptimizerType = ens::GridSearch,
89 typename MatType = arma::mat,
90 typename PredictionsType =
91 typename cv::MetaInfoExtractor<MLAlgorithm,
92 MatType>::PredictionsType,
93 typename WeightsType =
94 typename cv::MetaInfoExtractor<MLAlgorithm, MatType,
95 PredictionsType>::WeightsType>
97{
98 public:
106 template<typename... CVArgs>
107 HyperParameterTuner(const CVArgs& ...args);
108
110 OptimizerType& Optimizer() { return optimizer; }
111
121 double RelativeDelta() const { return relativeDelta; }
122
132 double& RelativeDelta() { return relativeDelta; }
133
142 double MinDelta() const { return minDelta; }
143
152 double& MinDelta() { return minDelta; }
153
177 template<typename... Args>
178 TupleOfHyperParameters<Args...> Optimize(const Args&... args);
179
181 double BestObjective() const { return bestObjective; }
182
184 const MLAlgorithm& BestModel() const { return bestModel; }
185
187 MLAlgorithm& BestModel() { return bestModel; }
188
189 private:
193 template<typename OriginalMetric>
194 struct Negated
195 {
196 static double Evaluate(MLAlgorithm& model,
197 const MatType& xs,
198 const PredictionsType& ys)
199 { return -OriginalMetric::Evaluate(model, xs, ys); }
200 };
201
203 using CVType = typename std::conditional<Metric::NeedsMinimization,
204 CV<MLAlgorithm, Metric, MatType, PredictionsType, WeightsType>,
205 CV<MLAlgorithm, Negated<Metric>, MatType, PredictionsType,
206 WeightsType>>::type;
207
208
210 CVType cv;
211
213 OptimizerType optimizer;
214
216 double bestObjective;
217
219 MLAlgorithm bestModel;
220
225 double relativeDelta;
226
231 double minDelta;
232
237 template<typename Tuple, size_t I>
238 using IsPreFixed = IsPreFixedArg<typename std::tuple_element<I, Tuple>::type>;
239
244 template<typename Tuple, size_t I>
245 using IsArithmetic = std::is_arithmetic<typename std::remove_reference<
246 typename std::tuple_element<I, Tuple>::type>::type>;
247
255 template<size_t I /* Index of the next argument to handle. */,
256 typename ArgsTuple,
257 typename... FixedArgs,
259 inline void InitAndOptimize(
260 const ArgsTuple& args,
261 arma::mat& bestParams,
262 data::DatasetMapper<data::IncrementPolicy, double>& datasetInfo,
263 FixedArgs... fixedArgs);
264
273 template<size_t I /* Index of the next argument to handle. */,
274 typename ArgsTuple,
275 typename... FixedArgs,
276 typename = std::enable_if_t<(I < std::tuple_size<ArgsTuple>::value)>,
278 inline void InitAndOptimize(
279 const ArgsTuple& args,
280 arma::mat& bestParams,
281 data::DatasetMapper<data::IncrementPolicy, double>& datasetInfo,
282 FixedArgs... fixedArgs);
283
292 template<size_t I /* Index of the next argument to handle. */,
293 typename ArgsTuple,
294 typename... FixedArgs,
295 typename = std::enable_if_t<(I < std::tuple_size<ArgsTuple>::value)>,
297 IsArithmetic<ArgsTuple, I>::value>,
298 typename = void>
299 inline void InitAndOptimize(
300 const ArgsTuple& args,
301 arma::mat& bestParams,
302 data::DatasetMapper<data::IncrementPolicy, double>& datasetInfo,
303 FixedArgs... fixedArgs);
304
313 template<size_t I /* Index of the next argument to handle. */,
314 typename ArgsTuple,
315 typename... FixedArgs,
316 typename = std::enable_if_t<(I < std::tuple_size<ArgsTuple>::value)>,
318 !IsArithmetic<ArgsTuple, I>::value>,
319 typename = void,
320 typename = void>
321 inline void InitAndOptimize(
322 const ArgsTuple& args,
323 arma::mat& bestParams,
324 data::DatasetMapper<data::IncrementPolicy, double>& datasetInfo,
325 FixedArgs... fixedArgs);
326
331 template<typename TupleType,
332 size_t I /* Index of the element in vector to handle. */,
333 typename... Args,
334 typename = typename
335 std::enable_if_t<(I < std::tuple_size<TupleType>::value)>>
336 inline TupleType VectorToTuple(const arma::vec& vector, const Args&... args);
337
341 template<typename TupleType,
342 size_t I /* Index of the element in vector to handle. */,
343 typename... Args,
344 typename = typename
346 typename = void>
347 inline TupleType VectorToTuple(const arma::vec& vector, const Args&... args);
348};
349
350} // namespace hpt
351} // namespace mlpack
352
353// Include implementation
354#include "hpt_impl.hpp"
355
356#endif
The class HyperParameterTuner for the given MLAlgorithm utilizes the provided Optimizer to find the values of hyper-parameters that optimize the value of the given Metric.
Definition: hpt.hpp:97
double & RelativeDelta()
Modify relative increase of arguments for calculation of partial derivatives (by the definition) in gradient-based optimization.
Definition: hpt.hpp:132
OptimizerType & Optimizer()
Access and modify the optimizer.
Definition: hpt.hpp:110
HyperParameterTuner(const CVArgs &...args)
Create a HyperParameterTuner object by passing constructor arguments for the given cross-validation strategy (the class CV).
double & MinDelta()
Modify minimum increase of arguments for calculation of partial derivatives (by the definition) in gradient-based optimization.
Definition: hpt.hpp:152
MLAlgorithm & BestModel()
Modify the best model from the last run.
Definition: hpt.hpp:187
const MLAlgorithm & BestModel() const
Get the best model from the last run.
Definition: hpt.hpp:184
TupleOfHyperParameters< Args... > Optimize(const Args &... args)
Find the best hyper-parameters by using the given Optimizer.
double RelativeDelta() const
Get relative increase of arguments for calculation of partial derivatives (by the definition) in gradient-based optimization.
Definition: hpt.hpp:121
double BestObjective() const
Get the performance measurement of the best model from the last run.
Definition: hpt.hpp:181
double MinDelta() const
Get minimum increase of arguments for calculation of partial derivatives (by the definition) in gradient-based optimization.
Definition: hpt.hpp:142
static const bool value
Definition: fixed.hpp:105
typename DeduceHyperParameterTypes< Args... >::TupleType TupleOfHyperParameters
A short alias for deducing types of hyper-parameters from types of arguments in the Optimize method in HyperParameterTuner.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
typename enable_if< B, T >::type enable_if_t
Definition: prereqs.hpp:70