mlpack 3.4.2
serialization_catch.hpp
Go to the documentation of this file.
1
12#ifndef MLPACK_TESTS_SERIALIZATION_CATCH_HPP
13#define MLPACK_TESTS_SERIALIZATION_CATCH_HPP
14
15#include <boost/serialization/serialization.hpp>
16#include <boost/archive/xml_iarchive.hpp>
17#include <boost/archive/xml_oarchive.hpp>
18#include <boost/archive/text_iarchive.hpp>
19#include <boost/archive/text_oarchive.hpp>
20#include <boost/archive/binary_iarchive.hpp>
21#include <boost/archive/binary_oarchive.hpp>
22#include <mlpack/core.hpp>
23
24
25#include "test_catch_tools.hpp"
26#include "catch.hpp"
27
28namespace mlpack {
29
30// Test function for loading and saving Armadillo objects.
31template<typename CubeType,
32 typename IArchiveType,
33 typename OArchiveType>
34void TestArmadilloSerialization(arma::Cube<CubeType>& x)
35{
36 // First save it.
37 // Use type_info name to get unique file name for serialization test files.
38 std::string fileName = FilterFileName(typeid(IArchiveType).name());
39 std::ofstream ofs(fileName, std::ios::binary);
40 bool success = true;
41
42 {
43 OArchiveType o(ofs);
44
45 try
46 {
47 o << BOOST_SERIALIZATION_NVP(x);
48 }
49 catch (boost::archive::archive_exception& e)
50 {
51 success = false;
52 }
53 }
54
55 REQUIRE(success == true);
56 ofs.close();
57
58 // Now load it.
59 arma::Cube<CubeType> orig(x);
60 success = true;
61 std::ifstream ifs(fileName, std::ios::binary);
62
63 {
64 IArchiveType i(ifs);
65
66 try
67 {
68 i >> BOOST_SERIALIZATION_NVP(x);
69 }
70 catch (boost::archive::archive_exception& e)
71 {
72 success = false;
73 }
74 }
75 ifs.close();
76
77 remove(fileName.c_str());
78
79 REQUIRE(success == true);
80
81 REQUIRE(x.n_rows == orig.n_rows);
82 REQUIRE(x.n_cols == orig.n_cols);
83 REQUIRE(x.n_elem_slice == orig.n_elem_slice);
84 REQUIRE(x.n_slices == orig.n_slices);
85 REQUIRE(x.n_elem == orig.n_elem);
86
87 for (size_t slice = 0; slice != x.n_slices; ++slice)
88 {
89 const auto& origSlice = orig.slice(slice);
90 const auto& xSlice = x.slice(slice);
91 for (size_t i = 0; i < x.n_cols; ++i)
92 {
93 for (size_t j = 0; j < x.n_rows; ++j)
94 {
95 if (double(origSlice(j, i)) == 0.0)
96 REQUIRE(double(xSlice(j, i)) == Approx(0.0).margin(1e-8 / 100));
97 else
98 REQUIRE(double(origSlice(j, i)) ==
99 Approx(double(xSlice(j, i))).epsilon(1e-8 / 100));
100 }
101 }
102 }
103}
104
105// Test all serialization strategies.
106template<typename CubeType>
107void TestAllArmadilloSerialization(arma::Cube<CubeType>& x)
108{
109 TestArmadilloSerialization<CubeType, boost::archive::xml_iarchive,
110 boost::archive::xml_oarchive>(x);
111 TestArmadilloSerialization<CubeType, boost::archive::text_iarchive,
112 boost::archive::text_oarchive>(x);
113 TestArmadilloSerialization<CubeType, boost::archive::binary_iarchive,
114 boost::archive::binary_oarchive>(x);
115}
116
117// Test function for loading and saving Armadillo objects.
118template<typename MatType,
119 typename IArchiveType,
120 typename OArchiveType>
121void TestArmadilloSerialization(MatType& x)
122{
123 // First save it.
124 std::string fileName = FilterFileName(typeid(IArchiveType).name());
125 std::ofstream ofs(fileName, std::ios::binary);
126 bool success = true;
127
128 {
129 OArchiveType o(ofs);
130
131 try
132 {
133 o << BOOST_SERIALIZATION_NVP(x);
134 }
135 catch (boost::archive::archive_exception& e)
136 {
137 success = false;
138 }
139 }
140
141 REQUIRE(success == true);
142 ofs.close();
143
144 // Now load it.
145 MatType orig(x);
146 success = true;
147 std::ifstream ifs(fileName, std::ios::binary);
148
149 {
150 IArchiveType i(ifs);
151
152 try
153 {
154 i >> BOOST_SERIALIZATION_NVP(x);
155 }
156 catch (boost::archive::archive_exception& e)
157 {
158 success = false;
159 }
160 }
161 ifs.close();
162
163 remove(fileName.c_str());
164
165 REQUIRE(success == true);
166
167 REQUIRE(x.n_rows == orig.n_rows);
168 REQUIRE(x.n_cols == orig.n_cols);
169 REQUIRE(x.n_elem == orig.n_elem);
170
171 for (size_t i = 0; i < x.n_cols; ++i)
172 for (size_t j = 0; j < x.n_rows; ++j)
173 if (double(orig(j, i)) == 0.0)
174 REQUIRE(double(x(j, i)) == Approx(0.0).margin(1e-8 / 100));
175 else
176 REQUIRE(double(orig(j, i)) ==
177 Approx(double(x(j, i))).epsilon(1e-8 / 100));
178}
179
180// Test all serialization strategies.
181template<typename MatType>
182void TestAllArmadilloSerialization(MatType& x)
183{
184 TestArmadilloSerialization<MatType, boost::archive::xml_iarchive,
185 boost::archive::xml_oarchive>(x);
186 TestArmadilloSerialization<MatType, boost::archive::text_iarchive,
187 boost::archive::text_oarchive>(x);
188 TestArmadilloSerialization<MatType, boost::archive::binary_iarchive,
189 boost::archive::binary_oarchive>(x);
190}
191
192// Save and load an mlpack object.
193// The re-loaded copy is placed in 'newT'.
194template<typename T, typename IArchiveType, typename OArchiveType>
195void SerializeObject(T& t, T& newT)
196{
197 std::string fileName = FilterFileName(typeid(T).name());
198 std::ofstream ofs(fileName, std::ios::binary);
199 bool success = true;
200
201 {
202 OArchiveType o(ofs);
203
204 try
205 {
206 o << BOOST_SERIALIZATION_NVP(t);
207 }
208 catch (boost::archive::archive_exception& e)
209 {
210 std::cerr << e.what() << std::endl;
211 success = false;
212 }
213 }
214 ofs.close();
215
216 REQUIRE(success == true);
217
218 std::ifstream ifs(fileName, std::ios::binary);
219
220 {
221 IArchiveType i(ifs);
222
223 try
224 {
225 i >> BOOST_SERIALIZATION_NVP(newT);
226 }
227 catch (boost::archive::archive_exception& e)
228 {
229 std::cout << e.what() << "\n";
230 success = false;
231 }
232 }
233 ifs.close();
234
235 remove(fileName.c_str());
236
237 REQUIRE(success == true);
238}
239
240// Test mlpack serialization with all three archive types.
241template<typename T>
242void SerializeObjectAll(T& t, T& xmlT, T& textT, T& binaryT)
243{
244 SerializeObject<T, boost::archive::xml_iarchive,
245 boost::archive::xml_oarchive>(t, xmlT);
246 SerializeObject<T, boost::archive::text_iarchive,
247 boost::archive::text_oarchive>(t, textT);
248 SerializeObject<T, boost::archive::binary_iarchive,
249 boost::archive::binary_oarchive>(t, binaryT);
250}
251
252// Save and load a non-default-constructible mlpack object.
253template<typename T, typename IArchiveType, typename OArchiveType>
254void SerializePointerObject(T* t, T*& newT)
255{
256 std::string fileName = FilterFileName(typeid(T).name());
257 std::ofstream ofs(fileName, std::ios::binary);
258 bool success = true;
259
260 {
261 OArchiveType o(ofs);
262 try
263 {
264 o << BOOST_SERIALIZATION_NVP(t);
265 }
266 catch (boost::archive::archive_exception& e)
267 {
268 std::cout << e.what() << "\n";
269 success = false;
270 }
271 }
272 ofs.close();
273
274 REQUIRE(success == true);
275
276 std::ifstream ifs(fileName, std::ios::binary);
277
278 {
279 IArchiveType i(ifs);
280
281 try
282 {
283 i >> BOOST_SERIALIZATION_NVP(newT);
284 }
285 catch (std::exception& e)
286 {
287 std::cout << e.what() << "\n";
288 success = false;
289 }
290 }
291 ifs.close();
292
293 remove(fileName.c_str());
294
295 REQUIRE(success == true);
296}
297
298template<typename T>
299void SerializePointerObjectAll(T* t, T*& xmlT, T*& textT, T*& binaryT)
300{
301 SerializePointerObject<T, boost::archive::text_iarchive,
302 boost::archive::text_oarchive>(t, textT);
303 SerializePointerObject<T, boost::archive::binary_iarchive,
304 boost::archive::binary_oarchive>(t, binaryT);
305 SerializePointerObject<T, boost::archive::xml_iarchive,
306 boost::archive::xml_oarchive>(t, xmlT);
307}
308
309// Utility function to check the equality of two Armadillo matrices.
310void CheckMatrices(const arma::mat& x,
311 const arma::mat& xmlX,
312 const arma::mat& textX,
313 const arma::mat& binaryX);
314
315void CheckMatrices(const arma::Mat<size_t>& x,
316 const arma::Mat<size_t>& xmlX,
317 const arma::Mat<size_t>& textX,
318 const arma::Mat<size_t>& binaryX);
319
320void CheckMatrices(const arma::cube& x,
321 const arma::cube& xmlX,
322 const arma::cube& textX,
323 const arma::cube& binaryX);
324
325} // namespace mlpack
326
327#endif
Include all of the base components required to write mlpack methods, and the main mlpack Doxygen documentation.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void SerializeObject(T &t, T &newT)
void SerializePointerObjectAll(T *t, T *&xmlT, T *&textT, T *&binaryT)
void TestAllArmadilloSerialization(arma::Cube< CubeType > &x)
void CheckMatrices(const arma::mat &x, const arma::mat &xmlX, const arma::mat &textX, const arma::mat &binaryX)
void SerializePointerObject(T *t, T *&newT)
void TestArmadilloSerialization(arma::Cube< CubeType > &x)
void SerializeObjectAll(T &t, T &xmlT, T &textT, T &binaryT)
std::string FilterFileName(const std::string &inputString)