mlpack 3.4.2
deterministic_set_visitor.hpp
Go to the documentation of this file.
1
14#ifndef MLPACK_METHODS_ANN_VISITOR_DETERMINISTIC_SET_VISITOR_HPP
15#define MLPACK_METHODS_ANN_VISITOR_DETERMINISTIC_SET_VISITOR_HPP
16
18
19#include <boost/variant.hpp>
20
21namespace mlpack {
22namespace ann {
23
28class DeterministicSetVisitor : public boost::static_visitor<void>
29{
30 public:
32 DeterministicSetVisitor(const bool deterministic = true);
33
35 template<typename LayerType>
36 void operator()(LayerType* layer) const;
37
38 void operator()(MoreTypes layer) const;
39
40 private:
42 const bool deterministic;
43
46 template<typename T>
47 typename std::enable_if<
48 HasDeterministicCheck<T, bool&(T::*)(void)>::value &&
49 HasModelCheck<T>::value, void>::type
50 LayerDeterministic(T* layer) const;
51
54 template<typename T>
55 typename std::enable_if<
56 !HasDeterministicCheck<T, bool&(T::*)(void)>::value &&
57 HasModelCheck<T>::value, void>::type
58 LayerDeterministic(T* layer) const;
59
62 template<typename T>
63 typename std::enable_if<
64 HasDeterministicCheck<T, bool&(T::*)(void)>::value &&
65 !HasModelCheck<T>::value, void>::type
66 LayerDeterministic(T* layer) const;
67
70 template<typename T>
71 typename std::enable_if<
72 !HasDeterministicCheck<T, bool&(T::*)(void)>::value &&
73 !HasModelCheck<T>::value, void>::type
74 LayerDeterministic(T* layer) const;
75};
76
77} // namespace ann
78} // namespace mlpack
79
80// Include implementation.
81#include "deterministic_set_visitor_impl.hpp"
82
83#endif
DeterministicSetVisitor sets the deterministic parameter given the deterministic value.
void operator()(LayerType *layer) const
Set the deterministic parameter.
void operator()(MoreTypes layer) const
DeterministicSetVisitor(const bool deterministic=true)
Set the deterministic parameter given the current deterministic value.
boost::variant< Linear3D< arma::mat, arma::mat, NoRegularizer > *, Glimpse< arma::mat, arma::mat > *, Highway< arma::mat, arma::mat > *, MultiheadAttention< arma::mat, arma::mat, NoRegularizer > *, Recurrent< arma::mat, arma::mat > *, RecurrentAttention< arma::mat, arma::mat > *, ReinforceNormal< arma::mat, arma::mat > *, Reparametrization< arma::mat, arma::mat > *, Select< arma::mat, arma::mat > *, Sequential< arma::mat, arma::mat, false > *, Sequential< arma::mat, arma::mat, true > *, Subview< arma::mat, arma::mat > *, VRClassReward< arma::mat, arma::mat > *, VirtualBatchNorm< arma::mat, arma::mat > *, RBF< arma::mat, arma::mat, GaussianFunction > *, BaseLayer< GaussianFunction, arma::mat, arma::mat > *, PositionalEncoding< arma::mat, arma::mat > * > MoreTypes
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1