mlpack 3.4.2
poisson_nll_loss.hpp
Go to the documentation of this file.
1
13#ifndef MLPACK_METHODS_ANN_LOSS_FUNCTIONS_POISSON_NLL_LOSS_HPP
14#define MLPACK_METHODS_ANN_LOSS_FUNCTIONS_POISSON_NLL_LOSS_HPP
15
16#include <mlpack/prereqs.hpp>
17
18namespace mlpack {
19namespace ann {
20
32template <
33 typename InputDataType = arma::mat,
34 typename OutputDataType = arma::mat
35>
37{
38 public:
50 PoissonNLLLoss(const bool logInput = true,
51 const bool full = false,
52 const typename InputDataType::elem_type eps = 1e-08,
53 const bool mean = true);
54
62 template<typename InputType, typename TargetType>
63 typename InputDataType::elem_type Forward(const InputType& input,
64 const TargetType& target);
65
77 template<typename InputType, typename TargetType, typename OutputType>
78 void Backward(const InputType& input,
79 const TargetType& target,
80 OutputType& output);
81
83 InputDataType& InputParameter() const { return inputParameter; }
85 InputDataType& InputParameter() { return inputParameter; }
86
88 OutputDataType& OutputParameter() const { return outputParameter; }
90 OutputDataType& OutputParameter() { return outputParameter; }
91
94 bool LogInput() const { return logInput; }
97 bool& LogInput() { return logInput; }
98
101 bool Full() const { return full; }
104 bool& Full() { return full; }
105
108 typename InputDataType::elem_type Eps() const { return eps; }
111 typename InputDataType::elem_type& Eps() { return eps; }
112
115 bool Mean() const { return mean; }
118 bool& Mean() { return mean; }
119
123 template<typename Archive>
124 void serialize(Archive& ar, const unsigned int /* version */);
125
126 private:
128 template<typename eT>
129 void CheckProbs(const arma::Mat<eT>& probs)
130 {
131 for (size_t i = 0; i < probs.size(); ++i)
132 {
133 if (probs[i] > 1.0 || probs[i] < 0.0)
134 Log::Fatal << "Probabilities cannot be greater than 1 "
135 << "or smaller than 0." << std::endl;
136 }
137 }
138
140 InputDataType inputParameter;
141
143 OutputDataType outputParameter;
144
146 bool logInput;
147
149 // approximation term.
150 bool full;
151
153 typename InputDataType::elem_type eps;
154
156 bool mean;
157}; // class PoissonNLLLoss
158
159} // namespace ann
160} // namespace mlpack
161
162// Include implementation.
163#include "poisson_nll_loss_impl.hpp"
164
165#endif
static MLPACK_EXPORT util::PrefixedOutStream Fatal
Prints fatal messages prefixed with [FATAL], then terminates the program.
Definition: log.hpp:90
Implementation of the Poisson negative log likelihood loss.
InputDataType::elem_type Forward(const InputType &input, const TargetType &target)
Computes the Poisson negative log likelihood loss.
InputDataType & InputParameter() const
Get the input parameter.
bool & Mean()
Modify the value of mean.
InputDataType & InputParameter()
Modify the input parameter.
OutputDataType & OutputParameter() const
Get the output parameter.
bool & Full()
Modify the value of full.
PoissonNLLLoss(const bool logInput=true, const bool full=false, const typename InputDataType::elem_type eps=1e-08, const bool mean=true)
Create the PoissonNLLLoss object.
void Backward(const InputType &input, const TargetType &target, OutputType &output)
Ordinary feed backward pass of a neural network.
InputDataType::elem_type & Eps()
Modify the value of eps.
bool LogInput() const
Get the value of logInput.
InputDataType::elem_type Eps() const
Get the value of eps.
bool Mean() const
Get the value of mean.
bool Full() const
Get the value of full.
OutputDataType & OutputParameter()
Modify the output parameter.
bool & LogInput()
Modify the value of logInput.
void serialize(Archive &ar, const unsigned int)
Serialize the layer.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.