mlpack 3.4.2
print_input_processing.hpp
Go to the documentation of this file.
1
13#ifndef MLPACK_BINDINGS_PYTHON_PRINT_INPUT_PROCESSING_HPP
14#define MLPACK_BINDINGS_PYTHON_PRINT_INPUT_PROCESSING_HPP
15
16#include <mlpack/prereqs.hpp>
17#include "get_arma_type.hpp"
18#include "get_numpy_type.hpp"
20#include "get_cython_type.hpp"
21#include "strip_type.hpp"
22
23namespace mlpack {
24namespace bindings {
25namespace python {
26
30template<typename T>
33 const size_t indent,
34 const typename boost::disable_if<util::IsStdVector<T>>::type* = 0,
35 const typename boost::disable_if<arma::is_arma_type<T>>::type* = 0,
36 const typename boost::disable_if<data::HasSerialize<T>>::type* = 0,
37 const typename boost::disable_if<std::is_same<T,
38 std::tuple<data::DatasetInfo, arma::mat>>>::type* = 0)
39{
40 // The copy_all_inputs parameter must be handled first, and therefore is
41 // outside the scope of this code.
42 if (d.name == "copy_all_inputs")
43 return;
44
45 const std::string prefix(indent, ' ');
46
47 std::string def = "None";
48 if (std::is_same<T, bool>::value)
49 def = "False";
50
51 // Make sure that we don't use names that are Python keywords.
52 std::string name = (d.name == "lambda") ? "lambda_" : d.name;
53
65 std::cout << prefix << "# Detect if the parameter was passed; set if so."
66 << std::endl;
67 if (!d.required)
68 {
69 if (GetPrintableType<T>(d) == "bool")
70 {
71 std::cout << prefix << "if isinstance(" << name << ", "
72 << GetPrintableType<T>(d) << "):" << std::endl;
73 std::cout << prefix << " if " << name << " is not " << def << ":"
74 << std::endl;
75 }
76 else
77 {
78 std::cout << prefix << "if " << name << " is not " << def << ":"
79 << std::endl;
80 std::cout << prefix << " if isinstance(" << name << ", "
81 << GetPrintableType<T>(d) << "):" << std::endl;
82 }
83
84 std::cout << prefix << " SetParam[" << GetCythonType<T>(d)
85 << "](<const string> '" << d.name << "', ";
86 if (GetCythonType<T>(d) == "string")
87 std::cout << name << ".encode(\"UTF-8\")";
88 else
89 std::cout << name;
90 std::cout << ")" << std::endl;
91 std::cout << prefix << " IO.SetPassed(<const string> '" << d.name
92 << "')" << std::endl;
93
94 // If this parameter is "verbose", then enable verbose output.
95 if (d.name == "verbose")
96 std::cout << prefix << " EnableVerbose()" << std::endl;
97
98 if (GetPrintableType<T>(d) == "bool")
99 {
100 std::cout << " else:" << std::endl;
101 std::cout << " raise TypeError(" <<"\"'"<< name
102 << "' must have type \'" << GetPrintableType<T>(d)
103 << "'!\")" << std::endl;
104 }
105 else
106 {
107 std::cout << " else:" << std::endl;
108 std::cout << " raise TypeError(" <<"\"'"<< name
109 << "' must have type \'" << GetPrintableType<T>(d)
110 << "'!\")" << std::endl;
111 }
112 }
113 else
114 {
115 if (GetPrintableType<T>(d) == "bool")
116 {
117 std::cout << prefix << "if isinstance(" << name << ", "
118 << GetPrintableType<T>(d) << "):" << std::endl;
119 std::cout << prefix << " if " << name << " is not " << def << ":"
120 << std::endl;
121 }
122 else
123 {
124 std::cout << prefix << "if " << name << " is not " << def << ":"
125 << std::endl;
126 std::cout << prefix << " if isinstance(" << name << ", "
127 << GetPrintableType<T>(d) << "):" << std::endl;
128 }
129
130 std::cout << prefix << " SetParam[" << GetCythonType<T>(d) << "](<const "
131 << "string> '" << d.name << "', ";
132 if (GetCythonType<T>(d) == "string")
133 std::cout << name << ".encode(\"UTF-8\")";
134 else if (GetCythonType<T>(d) == "vector[string]")
135 std::cout << "[i.encode(\"UTF-8\") for i in " << name << "]";
136 else
137 std::cout << name;
138 std::cout << ")" << std::endl;
139 std::cout << prefix << " IO.SetPassed(<const string> '"
140 << d.name << "')" << std::endl;
141
142 if (GetPrintableType<T>(d) == "bool")
143 {
144 std::cout << " else:" << std::endl;
145 std::cout << " raise TypeError(" <<"\"'"<< name
146 << "' must have type \'" << GetPrintableType<T>(d)
147 << "'!\")" << std::endl;
148 }
149 else
150 {
151 std::cout << " else:" << std::endl;
152 std::cout << " raise TypeError(" <<"\"'"<< name
153 << "' must have type \'" << GetPrintableType<T>(d)
154 << "'!\")" << std::endl;
155 }
156 }
157 std::cout << std::endl; // Extra line is to clear up the code a bit.
158}
159
163template<typename T>
166 const size_t indent,
167 const typename boost::disable_if<arma::is_arma_type<T>>::type* = 0,
168 const typename boost::disable_if<data::HasSerialize<T>>::type* = 0,
169 const typename boost::disable_if<std::is_same<T,
170 std::tuple<data::DatasetInfo, arma::mat>>>::type* = 0,
171 const typename boost::enable_if<util::IsStdVector<T>>::type* = 0)
172{
173 const std::string prefix(indent, ' ');
174
189 std::cout << prefix << "# Detect if the parameter was passed; set if so."
190 << std::endl;
191 if (!d.required)
192 {
193 std::cout << prefix << "if " << d.name << " is not None:"
194 << std::endl;
195 std::cout << prefix << " if isinstance(" << d.name << ", list):"
196 << std::endl;
197 std::cout << prefix << " if len(" << d.name << ") > 0:"
198 << std::endl;
199 std::cout << prefix << " if isinstance(" << d.name << "[0], "
200 << GetPrintableType<typename T::value_type>(d) << "):" << std::endl;
201 std::cout << prefix << " SetParam[" << GetCythonType<T>(d)
202 << "](<const string> '" << d.name << "', ";
203 // Strings need special handling.
204 if (GetCythonType<T>(d) == "vector[string]")
205 std::cout << "[i.encode(\"UTF-8\") for i in " << d.name << "]";
206 else
207 std::cout << d.name;
208 std::cout << ")" << std::endl;
209 std::cout << prefix << " IO.SetPassed(<const string> '" << d.name
210 << "')" << std::endl;
211 std::cout << prefix << " else:" << std::endl;
212 std::cout << prefix << " raise TypeError(" <<"\"'"<< d.name
213 << "' must have type \'" << GetPrintableType<T>(d)
214 << "'!\")" << std::endl;
215 std::cout << prefix << " else:" << std::endl;
216 std::cout << prefix << " raise TypeError(" <<"\"'"<< d.name
217 << "' must have type \'list'!\")" << std::endl;
218 }
219 else
220 {
221 std::cout << prefix << "if isinstance(" << d.name << ", list):"
222 << std::endl;
223 std::cout << prefix << " if len(" << d.name << ") > 0:"
224 << std::endl;
225 std::cout << prefix << " if isinstance(" << d.name << "[0], "
226 << GetPrintableType<typename T::value_type>(d) << "):" << std::endl;
227 std::cout << prefix << " SetParam[" << GetCythonType<T>(d)
228 << "](<const string> '" << d.name << "', ";
229 // Strings need special handling.
230 if (GetCythonType<T>(d) == "vector[string]")
231 std::cout << "[i.encode(\"UTF-8\") for i in " << d.name << "]";
232 else
233 std::cout << d.name;
234 std::cout << ")" << std::endl;
235 std::cout << prefix << " IO.SetPassed(<const string> '" << d.name
236 << "')" << std::endl;
237 std::cout << prefix << " else:" << std::endl;
238 std::cout << prefix << " raise TypeError(" <<"\"'"<< d.name
239 << "' must have type \'" << GetPrintableType<T>(d)
240 << "'!\")" << std::endl;
241 std::cout << prefix << "else:" << std::endl;
242 std::cout << prefix << " raise TypeError(" <<"\"'"<< d.name
243 << "' must have type \'list'!\")" << std::endl;
244 }
245}
246
250template<typename T>
253 const size_t indent,
254 const typename boost::disable_if<util::IsStdVector<T>>::type* = 0,
255 const typename boost::enable_if<arma::is_arma_type<T>>::type* = 0)
256{
257 const std::string prefix(indent, ' ');
258
274 std::cout << prefix << "# Detect if the parameter was passed; set if so."
275 << std::endl;
276 if (!d.required)
277 {
278 if (T::is_row || T::is_col)
279 {
280 std::cout << prefix << "if " << d.name << " is not None:" << std::endl;
281 std::cout << prefix << " " << d.name << "_tuple = to_matrix("
282 << d.name << ", dtype=" << GetNumpyType<typename T::elem_type>()
283 << ", copy=IO.HasParam('copy_all_inputs'))" << std::endl;
284 std::cout << prefix << " if len(" << d.name << "_tuple[0].shape) > 1:"
285 << std::endl;
286 std::cout << prefix << " if " << d.name << "_tuple[0]"
287 << ".shape[0] == 1 or " << d.name << "_tuple[0].shape[1] == 1:"
288 << std::endl;
289 std::cout << prefix << " " << d.name << "_tuple[0].shape = ("
290 << d.name << "_tuple[0].size,)" << std::endl;
291 std::cout << prefix << " " << d.name << "_mat = arma_numpy.numpy_to_"
292 << GetArmaType<T>() << "_" << GetNumpyTypeChar<T>() << "(" << d.name
293 << "_tuple[0], " << d.name << "_tuple[1])" << std::endl;
294 std::cout << prefix << " SetParam[" << GetCythonType<T>(d)
295 << "](<const string> '" << d.name << "', dereference("
296 << d.name << "_mat))"<< std::endl;
297 std::cout << prefix << " IO.SetPassed(<const string> '" << d.name
298 << "')" << std::endl;
299 std::cout << prefix << " del " << d.name << "_mat" << std::endl;
300 }
301 else
302 {
303 std::cout << prefix << "if " << d.name << " is not None:" << std::endl;
304 std::cout << prefix << " " << d.name << "_tuple = to_matrix("
305 << d.name << ", dtype=" << GetNumpyType<typename T::elem_type>()
306 << ", copy=IO.HasParam('copy_all_inputs'))" << std::endl;
307 std::cout << prefix << " if len(" << d.name << "_tuple[0].shape"
308 << ") < 2:" << std::endl;
309 std::cout << prefix << " " << d.name << "_tuple[0].shape = (" << d.name
310 << "_tuple[0].shape[0], 1)" << std::endl;
311 std::cout << prefix << " " << d.name << "_mat = arma_numpy.numpy_to_"
312 << GetArmaType<T>() << "_" << GetNumpyTypeChar<T>() << "(" << d.name
313 << "_tuple[0], " << d.name << "_tuple[1])" << std::endl;
314 std::cout << prefix << " SetParam[" << GetCythonType<T>(d)
315 << "](<const string> '" << d.name << "', dereference("
316 << d.name << "_mat))"<< std::endl;
317 std::cout << prefix << " IO.SetPassed(<const string> '" << d.name
318 << "')" << std::endl;
319 std::cout << prefix << " del " << d.name << "_mat" << std::endl;
320 }
321 }
322 else
323 {
324 if (T::is_row || T::is_col)
325 {
326 std::cout << prefix << d.name << "_tuple = to_matrix(" << d.name
327 << ", dtype=" << GetNumpyType<typename T::elem_type>()
328 << ", copy=IO.HasParam('copy_all_inputs'))" << std::endl;
329 std::cout << prefix << "if len(" << d.name << "_tuple[0].shape) > 1:"
330 << std::endl;
331 std::cout << prefix << " if " << d.name << "_tuple[0].shape[0] == 1 or "
332 << d.name << "_tuple[0].shape[1] == 1:" << std::endl;
333 std::cout << prefix << " " << d.name << "_tuple[0].shape = ("
334 << d.name << "_tuple[0].size,)" << std::endl;
335 std::cout << prefix << d.name << "_mat = arma_numpy.numpy_to_"
336 << GetArmaType<T>() << "_" << GetNumpyTypeChar<T>() << "(" << d.name
337 << "_tuple[0], " << d.name << "_tuple[1])" << std::endl;
338 std::cout << prefix << "SetParam[" << GetCythonType<T>(d)
339 << "](<const string> '" << d.name << "', dereference("
340 << d.name << "_mat))"<< std::endl;
341 std::cout << prefix << "IO.SetPassed(<const string> '" << d.name << "')"
342 << std::endl;
343 std::cout << prefix << "del " << d.name << "_mat" << std::endl;
344 }
345 else
346 {
347 std::cout << prefix << d.name << "_tuple = to_matrix(" << d.name
348 << ", dtype=" << GetNumpyType<typename T::elem_type>()
349 << ", copy=IO.HasParam('copy_all_inputs'))" << std::endl;
350 std::cout << prefix << "if len(" << d.name << "_tuple[0].shape) > 2:"
351 << std::endl;
352 std::cout << prefix << " " << d.name << "_tuple[0].shape = (" << d.name
353 << "_tuple[0].shape[0], 1)" << std::endl;
354 std::cout << prefix << d.name << "_mat = arma_numpy.numpy_to_"
355 << GetArmaType<T>() << "_" << GetNumpyTypeChar<T>() << "(" << d.name
356 << "_tuple[0], " << d.name << "_tuple[1])" << std::endl;
357 std::cout << prefix << "SetParam[" << GetCythonType<T>(d)
358 << "](<const string> '" << d.name << "', dereference(" << d.name
359 << "_mat))" << std::endl;
360 std::cout << prefix << "IO.SetPassed(<const string> '" << d.name << "')"
361 << std::endl;
362 std::cout << prefix << "del " << d.name << "_mat" << std::endl;
363 }
364 }
365 std::cout << std::endl;
366}
367
371template<typename T>
374 const size_t indent,
375 const typename boost::disable_if<util::IsStdVector<T>>::type* = 0,
376 const typename boost::disable_if<arma::is_arma_type<T>>::type* = 0,
377 const typename boost::enable_if<data::HasSerialize<T>>::type* = 0)
378{
379 // First, get the correct class name if needed.
380 std::string strippedType, printedType, defaultsType;
381 StripType(d.cppType, strippedType, printedType, defaultsType);
382
383 const std::string prefix(indent, ' ');
384
401 std::cout << prefix << "# Detect if the parameter was passed; set if so."
402 << std::endl;
403 if (!d.required)
404 {
405 std::cout << prefix << "if " << d.name << " is not None:" << std::endl;
406 std::cout << prefix << " try:" << std::endl;
407 std::cout << prefix << " SetParamPtr[" << strippedType << "]('" << d.name
408 << "', (<" << strippedType << "Type?> " << d.name << ").modelptr, "
409 << "IO.HasParam('copy_all_inputs'))" << std::endl;
410 std::cout << prefix << " except TypeError as e:" << std::endl;
411 std::cout << prefix << " if type(" << d.name << ").__name__ == '"
412 << strippedType << "Type':" << std::endl;
413 std::cout << prefix << " SetParamPtr[" << strippedType << "]('"
414 << d.name << "', (<" << strippedType << "Type> " << d.name
415 << ").modelptr, IO.HasParam('copy_all_inputs'))" << std::endl;
416 std::cout << prefix << " else:" << std::endl;
417 std::cout << prefix << " raise e" << std::endl;
418 std::cout << prefix << " IO.SetPassed(<const string> '" << d.name << "')"
419 << std::endl;
420 }
421 else
422 {
423 std::cout << prefix << "try:" << std::endl;
424 std::cout << prefix << " SetParamPtr[" << strippedType << "]('" << d.name
425 << "', (<" << strippedType << "Type?> " << d.name << ").modelptr, "
426 << "IO.HasParam('copy_all_inputs'))" << std::endl;
427 std::cout << prefix << "except TypeError as e:" << std::endl;
428 std::cout << prefix << " if type(" << d.name << ").__name__ == '"
429 << strippedType << "Type':" << std::endl;
430 std::cout << prefix << " SetParamPtr[" << strippedType << "]('" << d.name
431 << "', (<" << strippedType << "Type> " << d.name << ").modelptr, "
432 << "IO.HasParam('copy_all_inputs'))" << std::endl;
433 std::cout << prefix << " else:" << std::endl;
434 std::cout << prefix << " raise e" << std::endl;
435 std::cout << prefix << "IO.SetPassed(<const string> '" << d.name << "')"
436 << std::endl;
437 }
438 std::cout << std::endl;
439}
440
444template<typename T>
447 const size_t indent,
448 const typename boost::disable_if<util::IsStdVector<T>>::type* = 0,
449 const typename boost::enable_if<std::is_same<T,
450 std::tuple<data::DatasetInfo, arma::mat>>>::type* = 0)
451{
452 // The user should pass in a matrix type of some sort.
453 const std::string prefix(indent, ' ');
454
466 std::cout << prefix << "cdef np.ndarray " << d.name << "_dims" << std::endl;
467 std::cout << prefix << "# Detect if the parameter was passed; set if so."
468 << std::endl;
469 if (!d.required)
470 {
471 std::cout << prefix << "if " << d.name << " is not None:" << std::endl;
472 std::cout << prefix << " " << d.name << "_tuple = to_matrix_with_info("
473 << d.name << ", dtype=np.double, copy=IO.HasParam('copy_all_inputs'))"
474 << std::endl;
475 std::cout << prefix << " if len(" << d.name << "_tuple[0].shape"
476 << ") < 2:" << std::endl;
477 std::cout << prefix << " " << d.name << "_tuple[0].shape = (" << d.name
478 << "_tuple[0].shape[0], 1)" << std::endl;
479 std::cout << prefix << " " << d.name << "_mat = arma_numpy.numpy_to_mat_d("
480 << d.name << "_tuple[0], " << d.name << "_tuple[1])" << std::endl;
481 std::cout << prefix << " " << d.name << "_dims = " << d.name
482 << "_tuple[2]" << std::endl;
483 std::cout << prefix << " SetParamWithInfo[arma.Mat[double]](<const "
484 << "string> '" << d.name << "', dereference(" << d.name << "_mat), "
485 << "<const cbool*> " << d.name << "_dims.data)" << std::endl;
486 std::cout << prefix << " IO.SetPassed(<const string> '" << d.name
487 << "')" << std::endl;
488 std::cout << prefix << " del " << d.name << "_mat" << std::endl;
489 }
490 else
491 {
492 std::cout << prefix << d.name << "_tuple = to_matrix_with_info(" << d.name
493 << ", dtype=np.double, copy=IO.HasParam('copy_all_inputs'))"
494 << std::endl;
495 std::cout << prefix << "if len(" << d.name << "_tuple[0].shape"
496 << ") < 2:" << std::endl;
497 std::cout << prefix << " " << d.name << "_tuple[0].shape = (" << d.name
498 << "_tuple[0].shape[0], 1)" << std::endl;
499 std::cout << prefix << d.name << "_mat = arma_numpy.numpy_to_mat_d("
500 << d.name << "_tuple[0], " << d.name << "_tuple[1])" << std::endl;
501 std::cout << prefix << d.name << "_dims = " << d.name << "_tuple[2]"
502 << std::endl;
503 std::cout << prefix << "SetParamWithInfo[arma.Mat[double]](<const "
504 << "string> '" << d.name << "', dereference(" << d.name << "_mat), "
505 << "<const cbool*> " << d.name << "_dims.data)" << std::endl;
506 std::cout << prefix << "IO.SetPassed(<const string> '" << d.name << "')"
507 << std::endl;
508 std::cout << prefix << "del " << d.name << "_mat" << std::endl;
509 }
510 std::cout << std::endl;
511}
512
524template<typename T>
526 const void* input,
527 void* /* output */)
528{
529 PrintInputProcessing<typename std::remove_pointer<T>::type>(d,
530 *((size_t*) input));
531}
532
533} // namespace python
534} // namespace bindings
535} // namespace mlpack
536
537#endif
if(NOT BUILD_GO_SHLIB) macro(add_go_binding name) endmacro() return() endif() endmacro() macro(post_go_setup) if(BUILD_GO_BINDINGS) file(APPEND "$
Definition: CMakeLists.txt:3
python
Definition: CMakeLists.txt:6
void PrintInputProcessing(util::ParamData &d, const size_t indent, const typename boost::disable_if< util::IsStdVector< T > >::type *=0, const typename boost::disable_if< arma::is_arma_type< T > >::type *=0, const typename boost::disable_if< data::HasSerialize< T > >::type *=0, const typename boost::disable_if< std::is_same< T, std::tuple< data::DatasetInfo, arma::mat > > >::type *=0)
Print input processing for a standard option type.
void StripType(const std::string &inputType, std::string &strippedType, std::string &printedType, std::string &defaultsType)
Given an input type like, e.g., "LogisticRegression<>", return three types that can be used in Python...
Definition: strip_type.hpp:28
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.
Metaprogramming structure for vector detection.
This structure holds all of the information about a single parameter, including its value (which is s...
Definition: param_data.hpp:53
bool required
True if this option is required.
Definition: param_data.hpp:71
std::string name
Name of this parameter.
Definition: param_data.hpp:56
std::string cppType
The true name of the type, as it would be written in C++.
Definition: param_data.hpp:84