mlpack 3.4.2
random.hpp
Go to the documentation of this file.
1
11#ifndef MLPACK_CORE_MATH_RANDOM_HPP
12#define MLPACK_CORE_MATH_RANDOM_HPP
13
#include <mlpack/prereqs.hpp>
#include <mlpack/mlpack_export.hpp>
#include <cmath>
#include <random>
17
18namespace mlpack {
19namespace math {
20
26// Global random object.
27extern MLPACK_EXPORT std::mt19937 randGen;
28// Global uniform distribution.
29extern MLPACK_EXPORT std::uniform_real_distribution<> randUniformDist;
30// Global normal distribution.
31extern MLPACK_EXPORT std::normal_distribution<> randNormalDist;
32
40inline void RandomSeed(const size_t seed)
41{
42 #if (!defined(BINDING_TYPE) || BINDING_TYPE != BINDING_TYPE_TEST)
43 randGen.seed((uint32_t) seed);
44 #if (BINDING_TYPE == BINDING_TYPE_R)
45 // To suppress Found ‘srand’, possibly from ‘srand’ (C).
46 (void) seed;
47 #else
48 srand((unsigned int) seed);
49 #endif
50 arma::arma_rng::set_seed(seed);
51 #else
52 (void) seed;
53 #endif
54}
55
63#if (BINDING_TYPE == BINDING_TYPE_TEST)
64inline void FixedRandomSeed()
65{
66 const static size_t seed = rand();
67 randGen.seed((uint32_t) seed);
68 srand((unsigned int) seed);
69 arma::arma_rng::set_seed(seed);
70}
71
72inline void CustomRandomSeed(const size_t seed)
73{
74 randGen.seed((uint32_t) seed);
75 srand((unsigned int) seed);
76 arma::arma_rng::set_seed(seed);
77}
78#endif
79
83inline double Random()
84{
86}
87
91inline double Random(const double lo, const double hi)
92{
93 return lo + (hi - lo) * randUniformDist(randGen);
94}
95
99inline double RandBernoulli(const double input)
100{
101 if (Random() < input)
102 return 1;
103 else
104 return 0;
105}
106
110inline int RandInt(const int hiExclusive)
111{
112 return (int) std::floor((double) hiExclusive * randUniformDist(randGen));
113}
114
118inline int RandInt(const int lo, const int hiExclusive)
119{
120 return lo + (int) std::floor((double) (hiExclusive - lo)
122}
123
127inline double RandNormal()
128{
129 return randNormalDist(randGen);
130}
131
139inline double RandNormal(const double mean, const double variance)
140{
141 return variance * randNormalDist(randGen) + mean;
142}
143
153inline void ObtainDistinctSamples(const size_t loInclusive,
154 const size_t hiExclusive,
155 const size_t maxNumSamples,
156 arma::uvec& distinctSamples)
157{
158 const size_t samplesRangeSize = hiExclusive - loInclusive;
159
160 if (samplesRangeSize > maxNumSamples)
161 {
162 arma::Col<size_t> samples;
163
164 samples.zeros(samplesRangeSize);
165
166 for (size_t i = 0; i < maxNumSamples; ++i)
167 samples [ (size_t) math::RandInt(samplesRangeSize) ]++;
168
169 distinctSamples = arma::find(samples > 0);
170
171 if (loInclusive > 0)
172 distinctSamples += loInclusive;
173 }
174 else
175 {
176 distinctSamples.set_size(samplesRangeSize);
177 for (size_t i = 0; i < samplesRangeSize; ++i)
178 distinctSamples[i] = loInclusive + i;
179 }
180}
181
182} // namespace math
183} // namespace mlpack
184
#endif // MLPACK_CORE_MATH_RANDOM_HPP
MLPACK_EXPORT std::normal_distribution randNormalDist
void FixedRandomSeed()
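Set the random seed to a fixed number (taken from rand() on the first call and reused on every later call); available only in test bindings.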
Definition: random.hpp:64
double Random()
Generates a uniform random number between 0 and 1.
Definition: random.hpp:83
double RandNormal()
Generates a normally distributed random number with mean 0 and variance 1.
Definition: random.hpp:127
MLPACK_EXPORT std::mt19937 randGen
MLPACK_EXPORT is required for global variables; it exports the symbols correctly on Windows.
double RandBernoulli(const double input)
Generates 1 if a uniform draw from Random() is below input, and 0 otherwise (a Bernoulli draw with success probability input).
Definition: random.hpp:99
MLPACK_EXPORT std::uniform_real_distribution randUniformDist
void RandomSeed(const size_t seed)
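Set the random seed used by the random functions (Random(), RandInt(), RandNormal()), as well as the C srand() generator (except in R bindings) and Armadillo's generator; a no-op in test bindings.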
Definition: random.hpp:40
void ObtainDistinctSamples(const size_t loInclusive, const size_t hiExclusive, const size_t maxNumSamples, arma::uvec &distinctSamples)
Obtains no more than maxNumSamples distinct samples.
Definition: random.hpp:153
int RandInt(const int hiExclusive)
Generates a uniform random integer.
Definition: random.hpp:110
void CustomRandomSeed(const size_t seed)
Definition: random.hpp:72
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.