mlpack 3.4.2
serialization.hpp
Go to the documentation of this file.
1
12#ifndef MLPACK_TESTS_SERIALIZATION_HPP
13#define MLPACK_TESTS_SERIALIZATION_HPP
14
15#include <boost/serialization/serialization.hpp>
16#include <boost/archive/xml_iarchive.hpp>
17#include <boost/archive/xml_oarchive.hpp>
18#include <boost/archive/text_iarchive.hpp>
19#include <boost/archive/text_oarchive.hpp>
20#include <boost/archive/binary_iarchive.hpp>
21#include <boost/archive/binary_oarchive.hpp>
22#include <mlpack/core.hpp>
23
24#include <boost/test/unit_test.hpp>
25#include "test_tools.hpp"
26
27namespace mlpack {
28
29// Test function for loading and saving Armadillo objects.
30template<typename CubeType,
31 typename IArchiveType,
32 typename OArchiveType>
33void TestArmadilloSerialization(arma::Cube<CubeType>& x)
34{
35 // First save it.
36 // Use type_info name to get unique file name for serialization test files.
37 std::string fileName = FilterFileName(typeid(IArchiveType).name());
38 std::ofstream ofs(fileName, std::ios::binary);
39 bool success = true;
40
41 {
42 OArchiveType o(ofs);
43
44 try
45 {
46 o << BOOST_SERIALIZATION_NVP(x);
47 }
48 catch (boost::archive::archive_exception& e)
49 {
50 success = false;
51 }
52 }
53
54 BOOST_REQUIRE_EQUAL(success, true);
55 ofs.close();
56
57 // Now load it.
58 arma::Cube<CubeType> orig(x);
59 success = true;
60 std::ifstream ifs(fileName, std::ios::binary);
61
62 {
63 IArchiveType i(ifs);
64
65 try
66 {
67 i >> BOOST_SERIALIZATION_NVP(x);
68 }
69 catch (boost::archive::archive_exception& e)
70 {
71 success = false;
72 }
73 }
74 ifs.close();
75
76 remove(fileName.c_str());
77
78 BOOST_REQUIRE_EQUAL(success, true);
79
80 BOOST_REQUIRE_EQUAL(x.n_rows, orig.n_rows);
81 BOOST_REQUIRE_EQUAL(x.n_cols, orig.n_cols);
82 BOOST_REQUIRE_EQUAL(x.n_elem_slice, orig.n_elem_slice);
83 BOOST_REQUIRE_EQUAL(x.n_slices, orig.n_slices);
84 BOOST_REQUIRE_EQUAL(x.n_elem, orig.n_elem);
85
86 for (size_t slice = 0; slice != x.n_slices; ++slice)
87 {
88 const auto& origSlice = orig.slice(slice);
89 const auto& xSlice = x.slice(slice);
90 for (size_t i = 0; i < x.n_cols; ++i)
91 {
92 for (size_t j = 0; j < x.n_rows; ++j)
93 {
94 if (double(origSlice(j, i)) == 0.0)
95 BOOST_REQUIRE_SMALL(double(xSlice(j, i)), 1e-8);
96 else
97 BOOST_REQUIRE_CLOSE(double(origSlice(j, i)), double(xSlice(j, i)),
98 1e-8);
99 }
100 }
101 }
102}
103
104// Test all serialization strategies.
105template<typename CubeType>
106void TestAllArmadilloSerialization(arma::Cube<CubeType>& x)
107{
108 TestArmadilloSerialization<CubeType, boost::archive::xml_iarchive,
109 boost::archive::xml_oarchive>(x);
110 TestArmadilloSerialization<CubeType, boost::archive::text_iarchive,
111 boost::archive::text_oarchive>(x);
112 TestArmadilloSerialization<CubeType, boost::archive::binary_iarchive,
113 boost::archive::binary_oarchive>(x);
114}
115
116// Test function for loading and saving Armadillo objects.
117template<typename MatType,
118 typename IArchiveType,
119 typename OArchiveType>
121{
122 // First save it.
123 std::string fileName = FilterFileName(typeid(IArchiveType).name());
124 std::ofstream ofs(fileName, std::ios::binary);
125 bool success = true;
126
127 {
128 OArchiveType o(ofs);
129
130 try
131 {
132 o << BOOST_SERIALIZATION_NVP(x);
133 }
134 catch (boost::archive::archive_exception& e)
135 {
136 success = false;
137 }
138 }
139
140 BOOST_REQUIRE_EQUAL(success, true);
141 ofs.close();
142
143 // Now load it.
144 MatType orig(x);
145 success = true;
146 std::ifstream ifs(fileName, std::ios::binary);
147
148 {
149 IArchiveType i(ifs);
150
151 try
152 {
153 i >> BOOST_SERIALIZATION_NVP(x);
154 }
155 catch (boost::archive::archive_exception& e)
156 {
157 success = false;
158 }
159 }
160 ifs.close();
161
162 remove(fileName.c_str());
163
164 BOOST_REQUIRE_EQUAL(success, true);
165
166 BOOST_REQUIRE_EQUAL(x.n_rows, orig.n_rows);
167 BOOST_REQUIRE_EQUAL(x.n_cols, orig.n_cols);
168 BOOST_REQUIRE_EQUAL(x.n_elem, orig.n_elem);
169
170 for (size_t i = 0; i < x.n_cols; ++i)
171 for (size_t j = 0; j < x.n_rows; ++j)
172 if (double(orig(j, i)) == 0.0)
173 BOOST_REQUIRE_SMALL(double(x(j, i)), 1e-8);
174 else
175 BOOST_REQUIRE_CLOSE(double(orig(j, i)), double(x(j, i)), 1e-8);
176}
177
178// Test all serialization strategies.
179template<typename MatType>
181{
182 TestArmadilloSerialization<MatType, boost::archive::xml_iarchive,
183 boost::archive::xml_oarchive>(x);
184 TestArmadilloSerialization<MatType, boost::archive::text_iarchive,
185 boost::archive::text_oarchive>(x);
186 TestArmadilloSerialization<MatType, boost::archive::binary_iarchive,
187 boost::archive::binary_oarchive>(x);
188}
189
190// Save and load an mlpack object.
191// The re-loaded copy is placed in 'newT'.
192template<typename T, typename IArchiveType, typename OArchiveType>
193void SerializeObject(T& t, T& newT)
194{
195 std::string fileName = FilterFileName(typeid(T).name());
196 std::ofstream ofs(fileName, std::ios::binary);
197 bool success = true;
198
199 {
200 OArchiveType o(ofs);
201
202 try
203 {
204 o << BOOST_SERIALIZATION_NVP(t);
205 }
206 catch (boost::archive::archive_exception& e)
207 {
208 MLPACK_CERR_STREAM << e.what() << std::endl;
209 success = false;
210 }
211 }
212 ofs.close();
213
214 BOOST_REQUIRE_EQUAL(success, true);
215
216 std::ifstream ifs(fileName, std::ios::binary);
217
218 {
219 IArchiveType i(ifs);
220
221 try
222 {
223 i >> BOOST_SERIALIZATION_NVP(newT);
224 }
225 catch (boost::archive::archive_exception& e)
226 {
227 MLPACK_COUT_STREAM << e.what() << "\n";
228 success = false;
229 }
230 }
231 ifs.close();
232
233 remove(fileName.c_str());
234
235 BOOST_REQUIRE_EQUAL(success, true);
236}
237
238// Test mlpack serialization with all three archive types.
239template<typename T>
240void SerializeObjectAll(T& t, T& xmlT, T& textT, T& binaryT)
241{
242 SerializeObject<T, boost::archive::xml_iarchive,
243 boost::archive::xml_oarchive>(t, xmlT);
244 SerializeObject<T, boost::archive::text_iarchive,
245 boost::archive::text_oarchive>(t, textT);
246 SerializeObject<T, boost::archive::binary_iarchive,
247 boost::archive::binary_oarchive>(t, binaryT);
248}
249
250// Save and load a non-default-constructible mlpack object.
251template<typename T, typename IArchiveType, typename OArchiveType>
252void SerializePointerObject(T* t, T*& newT)
253{
254 std::string fileName = FilterFileName(typeid(T).name());
255 std::ofstream ofs(fileName, std::ios::binary);
256 bool success = true;
257
258 {
259 OArchiveType o(ofs);
260 try
261 {
262 o << BOOST_SERIALIZATION_NVP(t);
263 }
264 catch (boost::archive::archive_exception& e)
265 {
266 MLPACK_COUT_STREAM << e.what() << "\n";
267 success = false;
268 }
269 }
270 ofs.close();
271
272 BOOST_REQUIRE_EQUAL(success, true);
273
274 std::ifstream ifs(fileName, std::ios::binary);
275
276 {
277 IArchiveType i(ifs);
278
279 try
280 {
281 i >> BOOST_SERIALIZATION_NVP(newT);
282 }
283 catch (std::exception& e)
284 {
285 MLPACK_COUT_STREAM << e.what() << "\n";
286 success = false;
287 }
288 }
289 ifs.close();
290
291 remove(fileName.c_str());
292
293 BOOST_REQUIRE_EQUAL(success, true);
294}
295
296template<typename T>
297void SerializePointerObjectAll(T* t, T*& xmlT, T*& textT, T*& binaryT)
298{
299 SerializePointerObject<T, boost::archive::text_iarchive,
300 boost::archive::text_oarchive>(t, textT);
301 SerializePointerObject<T, boost::archive::binary_iarchive,
302 boost::archive::binary_oarchive>(t, binaryT);
303 SerializePointerObject<T, boost::archive::xml_iarchive,
304 boost::archive::xml_oarchive>(t, xmlT);
305}
306
307// Utility function to check the equality of two Armadillo matrices.
308void CheckMatrices(const arma::mat& x,
309 const arma::mat& xmlX,
310 const arma::mat& textX,
311 const arma::mat& binaryX);
312
313void CheckMatrices(const arma::Mat<size_t>& x,
314 const arma::Mat<size_t>& xmlX,
315 const arma::Mat<size_t>& textX,
316 const arma::Mat<size_t>& binaryX);
317
318void CheckMatrices(const arma::cube& x,
319 const arma::cube& xmlX,
320 const arma::cube& textX,
321 const arma::cube& binaryX);
322
323} // namespace mlpack
324
325#endif
Include all of the base components required to write mlpack methods, and the main mlpack Doxygen documentation.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void SerializeObject(T &t, T &newT)
void SerializePointerObjectAll(T *t, T *&xmlT, T *&textT, T *&binaryT)
void TestAllArmadilloSerialization(arma::Cube< CubeType > &x)
void CheckMatrices(const arma::mat &x, const arma::mat &xmlX, const arma::mat &textX, const arma::mat &binaryX)
void SerializePointerObject(T *t, T *&newT)
void TestArmadilloSerialization(arma::Cube< CubeType > &x)
void SerializeObjectAll(T &t, T &xmlT, T &textT, T &binaryT)
#define MLPACK_COUT_STREAM
Definition: prereqs.hpp:45
#define MLPACK_CERR_STREAM
Definition: prereqs.hpp:51
std::string FilterFileName(const std::string &inputString)