mlpack 3.4.2
loss_visitor.hpp
Go to the documentation of this file.
1
13#ifndef MLPACK_METHODS_ANN_VISITOR_LOSS_VISITOR_HPP
14#define MLPACK_METHODS_ANN_VISITOR_LOSS_VISITOR_HPP
15
17
18#include <boost/variant.hpp>
19
20namespace mlpack {
21namespace ann {
22
26class LossVisitor : public boost::static_visitor<double>
27{
28 public:
30 template<typename LayerType>
31 double operator()(LayerType* layer) const;
32
33 double operator()(MoreTypes layer) const;
34
35 private:
37 template<typename T>
38 typename std::enable_if<
39 !HasLoss<T, double(T::*)()>::value &&
40 !HasModelCheck<T>::value, double>::type
41 LayerLoss(T* layer) const;
42
44 template<typename T>
45 typename std::enable_if<
46 HasLoss<T, double(T::*)()>::value &&
47 !HasModelCheck<T>::value, double>::type
48 LayerLoss(T* layer) const;
49
51 template<typename T>
52 typename std::enable_if<
53 !HasLoss<T, double(T::*)()>::value &&
54 HasModelCheck<T>::value, double>::type
55 LayerLoss(T* layer) const;
56
58 template<typename T>
59 typename std::enable_if<
60 HasLoss<T, double(T::*)()>::value &&
61 HasModelCheck<T>::value, double>::type
62 LayerLoss(T* layer) const;
63};
64
65} // namespace ann
66} // namespace mlpack
67
68// Include implementation.
69#include "loss_visitor_impl.hpp"
70
71#endif
LossVisitor exposes the Loss() method of the given module.
double operator()(MoreTypes layer) const
double operator()(LayerType *layer) const
Return the Loss.
boost::variant< Linear3D< arma::mat, arma::mat, NoRegularizer > *, Glimpse< arma::mat, arma::mat > *, Highway< arma::mat, arma::mat > *, MultiheadAttention< arma::mat, arma::mat, NoRegularizer > *, Recurrent< arma::mat, arma::mat > *, RecurrentAttention< arma::mat, arma::mat > *, ReinforceNormal< arma::mat, arma::mat > *, Reparametrization< arma::mat, arma::mat > *, Select< arma::mat, arma::mat > *, Sequential< arma::mat, arma::mat, false > *, Sequential< arma::mat, arma::mat, true > *, Subview< arma::mat, arma::mat > *, VRClassReward< arma::mat, arma::mat > *, VirtualBatchNorm< arma::mat, arma::mat > *, RBF< arma::mat, arma::mat, GaussianFunction > *, BaseLayer< GaussianFunction, arma::mat, arma::mat > *, PositionalEncoding< arma::mat, arma::mat > * > MoreTypes
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1