mlpack 3.4.2
validation_rmse_termination.hpp
Go to the documentation of this file.
1
12#ifndef _MLPACK_METHODS_AMF_VALIDATIONRMSETERMINATION_HPP_INCLUDED
13#define _MLPACK_METHODS_AMF_VALIDATIONRMSETERMINATION_HPP_INCLUDED
14
15#include <mlpack/prereqs.hpp>
16
17namespace mlpack
18{
19namespace amf
20{
21
36template <class MatType>
38{
39 public:
51 size_t num_test_points,
52 double tolerance = 1e-5,
53 size_t maxIterations = 10000,
54 size_t reverseStepTolerance = 3)
55 : tolerance(tolerance),
56 maxIterations(maxIterations),
57 num_test_points(num_test_points),
58 reverseStepTolerance(reverseStepTolerance)
59 {
60 size_t n = V.n_rows;
61 size_t m = V.n_cols;
62
63 // initialize validation set matrix
64 test_points.zeros(num_test_points, 3);
65
66 // fill validation set matrix with random chosen entries
67 for (size_t i = 0; i < num_test_points; ++i)
68 {
69 double t_val;
70 size_t t_row;
71 size_t t_col;
72
73 // pick a random non-zero entry
74 do
75 {
76 t_row = rand() % n;
77 t_col = rand() % m;
78 } while ((t_val = V(t_row, t_col)) == 0);
79
80 // add the entry to the validation set
81 test_points(i, 0) = t_row;
82 test_points(i, 1) = t_col;
83 test_points(i, 2) = t_val;
84
85 // nullify the added entry from data matrix (training set)
86 V(t_row, t_col) = 0;
87 }
88 }
89
95 void Initialize(const MatType& /* V */)
96 {
97 iteration = 1;
98
99 rmse = DBL_MAX;
100 rmseOld = DBL_MAX;
101
102 c_index = 0;
103 c_indexOld = 0;
104
105 reverseStepCount = 0;
106 isCopy = false;
107 }
108
115 bool IsConverged(arma::mat& W, arma::mat& H)
116 {
117 arma::mat WH;
118
119 WH = W * H;
120
121 // compute validation RMSE
122 if (iteration != 0)
123 {
124 rmseOld = rmse;
125 rmse = 0;
126 for (size_t i = 0; i < num_test_points; ++i)
127 {
128 size_t t_row = test_points(i, 0);
129 size_t t_col = test_points(i, 1);
130 double t_val = test_points(i, 2);
131 double temp = (t_val - WH(t_row, t_col));
132 temp *= temp;
133 rmse += temp;
134 }
135 rmse /= num_test_points;
136 rmse = sqrt(rmse);
137 }
138
139 // increment iteration count
140 iteration++;
141
142 // if RMSE tolerance is not satisfied
143 if ((rmseOld - rmse) / rmseOld < tolerance && iteration > 4)
144 {
145 // check if this is a first of successive drops
146 if (reverseStepCount == 0 && isCopy == false)
147 {
148 // store a copy of W and H matrix
149 isCopy = true;
150 this->W = W;
151 this->H = H;
152 // store residue values
153 c_indexOld = rmseOld;
154 c_index = rmse;
155 }
156 // increase successive drop count
157 reverseStepCount++;
158 }
159 // if tolerance is satisfied
160 else
161 {
162 // initialize successive drop count
163 reverseStepCount = 0;
164 // if residue is droped below minimum scrap stored values
165 if (rmse <= c_indexOld && isCopy == true)
166 {
167 isCopy = false;
168 }
169 }
170
171 // check if termination criterion is met
172 if (reverseStepCount == reverseStepTolerance || iteration > maxIterations)
173 {
174 // if stored values are present replace them with current value as they
175 // represent the minimum residue point
176 if (isCopy)
177 {
178 W = this->W;
179 H = this->H;
180 rmse = c_index;
181 }
182 return true;
183 }
184 else return false;
185 }
186
188 const double& Index() const { return rmse; }
189
191 const size_t& Iteration() const { return iteration; }
192
194 const size_t& NumTestPoints() const { return num_test_points; }
195
197 const size_t& MaxIterations() const { return maxIterations; }
198 size_t& MaxIterations() { return maxIterations; }
199
201 const double& Tolerance() const { return tolerance; }
202 double& Tolerance() { return tolerance; }
203
204 private:
206 double tolerance;
208 size_t maxIterations;
210 size_t num_test_points;
211
213 size_t iteration;
214
216 arma::mat test_points;
217
219 double rmseOld;
220 double rmse;
221
223 size_t reverseStepTolerance;
225 size_t reverseStepCount;
226
229 bool isCopy;
230
232 arma::mat W;
233 arma::mat H;
234 double c_indexOld;
235 double c_index;
236}; // class ValidationRMSETermination
237
238} // namespace amf
239} // namespace mlpack
240
241
242#endif // _MLPACK_METHODS_AMF_VALIDATIONRMSETERMINATION_HPP_INCLUDED
This class implements a validation termination policy based on the RMSE index.
ValidationRMSETermination(MatType &V, size_t num_test_points, double tolerance=1e-5, size_t maxIterations=10000, size_t reverseStepTolerance=3)
Creates a validation set according to the given parameters and nullifies this set in the data matrix (training set).
bool IsConverged(arma::mat &W, arma::mat &H)
Check if the termination criterion is met.
const double & Index() const
Get current value of residue.
const size_t & Iteration() const
Get current iteration count.
const size_t & NumTestPoints() const
Get number of validation points.
const double & Tolerance() const
Access tolerance value.
void Initialize(const MatType &)
Initializes the termination policy before starting the factorization.
const size_t & MaxIterations() const
Access upper limit of iteration count.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.