mlpack 3.4.2
simple_tolerance_termination.hpp
Go to the documentation of this file.
1
12#ifndef _MLPACK_METHODS_AMF_SIMPLE_TOLERANCE_TERMINATION_HPP_INCLUDED
13#define _MLPACK_METHODS_AMF_SIMPLE_TOLERANCE_TERMINATION_HPP_INCLUDED
14
15#include <mlpack/prereqs.hpp>
16
17namespace mlpack {
18namespace amf {
19
30template <class MatType>
32{
33 public:
35 SimpleToleranceTermination(const double tolerance = 1e-5,
36 const size_t maxIterations = 10000,
37 const size_t reverseStepTolerance = 3) :
38 tolerance(tolerance),
39 maxIterations(maxIterations),
40 V(nullptr),
41 iteration(1),
42 residueOld(DBL_MAX),
43 residue(DBL_MIN),
44 reverseStepTolerance(reverseStepTolerance),
45 reverseStepCount(0),
46 isCopy(false),
47 c_indexOld(0),
48 c_index(0)
49 { }
50
56 void Initialize(const MatType& V)
57 {
58 residueOld = DBL_MAX;
59 iteration = 1;
60 residue = DBL_MIN;
61 reverseStepCount = 0;
62 isCopy = false;
63
64 this->V = &V;
65
66 c_index = 0;
67 c_indexOld = 0;
68 }
69
76 bool IsConverged(arma::mat& W, arma::mat& H)
77 {
78 arma::mat WH;
79
80 WH = W * H;
81
82 // compute residue
83 residueOld = residue;
84 size_t n = V->n_rows;
85 size_t m = V->n_cols;
86 double sum = 0;
87 size_t count = 0;
88 for (size_t i = 0; i < n; ++i)
89 {
90 for (size_t j = 0; j < m; ++j)
91 {
92 double temp = 0;
93 if ((temp = (*V)(i, j)) != 0)
94 {
95 temp = (temp - WH(i, j));
96 temp = temp * temp;
97 sum += temp;
98 count++;
99 }
100 }
101 }
102 residue = sum / count;
103 residue = sqrt(residue);
104
105 // increment iteration count
106 iteration++;
107 Log::Info << "Iteration " << iteration << "; residue "
108 << ((residueOld - residue) / residueOld) << ".\n";
109
110 // if residue tolerance is not satisfied
111 if ((residueOld - residue) / residueOld < tolerance && iteration > 4)
112 {
113 // check if this is a first of successive drops
114 if (reverseStepCount == 0 && isCopy == false)
115 {
116 // store a copy of W and H matrix
117 isCopy = true;
118 this->W = W;
119 this->H = H;
120 // store residue values
121 c_index = residue;
122 c_indexOld = residueOld;
123 }
124 // increase successive drop count
125 reverseStepCount++;
126 }
127 // if tolerance is satisfied
128 else
129 {
130 // initialize successive drop count
131 reverseStepCount = 0;
132 // if residue is droped below minimum scrap stored values
133 if (residue <= c_indexOld && isCopy == true)
134 {
135 isCopy = false;
136 }
137 }
138
139 // check if termination criterion is met
140 if (reverseStepCount == reverseStepTolerance || iteration > maxIterations)
141 {
142 // if stored values are present replace them with current value as they
143 // represent the minimum residue point
144 if (isCopy)
145 {
146 W = this->W;
147 H = this->H;
148 residue = c_index;
149 }
150 return true;
151 }
152 else return false;
153 }
154
156 const double& Index() const { return residue; }
157
159 const size_t& Iteration() const { return iteration; }
160
162 const size_t& MaxIterations() const { return maxIterations; }
163 size_t& MaxIterations() { return maxIterations; }
164
166 const double& Tolerance() const { return tolerance; }
167 double& Tolerance() { return tolerance; }
168
169 private:
171 double tolerance;
173 size_t maxIterations;
174
176 const MatType* V;
177
179 size_t iteration;
180
182 double residueOld;
183 double residue;
184
186 size_t reverseStepTolerance;
188 size_t reverseStepCount;
189
192 bool isCopy;
193
195 arma::mat W;
196 arma::mat H;
197 double c_indexOld;
198 double c_index;
199}; // class SimpleToleranceTermination
200
201} // namespace amf
202} // namespace mlpack
203
204#endif // _MLPACK_METHODS_AMF_SIMPLE_TOLERANCE_TERMINATION_HPP_INCLUDED
205
static MLPACK_EXPORT util::PrefixedOutStream Info
Prints informational messages if --verbose is specified, prefixed with [INFO ].
Definition: log.hpp:84
This class implements a residue tolerance termination policy.
bool IsConverged(arma::mat &W, arma::mat &H)
Check if the termination criterion is met.
const double & Index() const
Get current value of residue.
void Initialize(const MatType &V)
Initializes the termination policy before starting the factorization.
SimpleToleranceTermination(const double tolerance=1e-5, const size_t maxIterations=10000, const size_t reverseStepTolerance=3)
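Construct the termination policy with the given tolerance, maximum iteration count, and number of successive non-improving iterations to allow.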
const size_t & Iteration() const
Get current iteration count.
const double & Tolerance() const
Access tolerance value.
const size_t & MaxIterations() const
Access upper limit of iteration count.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.